function [mean_D1, mean_D2, new_marks] = find_mean_spikes(markers, D1_xpts, D2_xpts, trial_window)
% FIND_MEAN_SPIKES  Phase-segmented moving-average spike counts for D1/D2 groups.
%
% For each experiment number N in D1_xpts and D2_xpts, loads the variable
% 'post_spikes_ss' (a cell array with one entry per trial) from the file
% 'results<N>' and counts the spike times in each trial. The counts are
% averaged across the experiments of each group. They are then smoothed
% with a boxcar moving average of width TRIAL_WINDOW.
%
% The smoothing is 'windowed': it is done separately within each phase (the
% trials between consecutive MARKERS), and only fully covered windows are
% kept, so no average blurs across a phase boundary. A phase of n trials
% therefore contributes max(n - TRIAL_WINDOW + 1, 0) points.
%
% Inputs
%   MARKERS       last trial of each phase, EXCLUDING the final phase (the
%                 last trial is appended automatically). Expected to be
%                 increasing trial indices - not checked here.
%   D1_xpts       experiment numbers of the D1 neurons (often 1:10)
%   D2_xpts       experiment numbers of the D2 neurons (often 11:20)
%   TRIAL_WINDOW  width, in trials, of the moving-average window
%                 (positive integer)
%
% Outputs
%   MEAN_D1, MEAN_D2  smoothed group means, all phases concatenated (rows)
%   NEW_MARKS         phase boundaries in the smoothed series: 1, followed by
%                     the index of the last point of each phase
%
% Side effects: draws figures 1 and 2 and writes mean_D1_spikes.png and
% mean_D2_spikes.png to the current directory.
%
% All experiments must contain the same number of trials.

if ~(isscalar(trial_window) && trial_window >= 1 && trial_window == fix(trial_window))
    error('find_mean_spikes:badWindow', 'TRIAL_WINDOW must be a positive integer');
end

n_D1 = numel(D1_xpts);
xpt_nos = [D1_xpts(:).' D2_xpts(:).'];
No_xpts = numel(xpt_nos);

% ps_ss(i, j) = number of spikes in trial j of experiment xpt_nos(i).
% Row i corresponds to POSITION i in xpt_nos, not to experiment number i.
ps_ss = [];
for i = 1:No_xpts
    fname = ['results' num2str(xpt_nos(i))];
    S = load(fname, 'post_spikes_ss');
    if ~isfield(S, 'post_spikes_ss')
        error('find_mean_spikes:missingVar', ...
            'variable post_spikes_ss not found in %s', fname);
    end
    % recomputed per experiment so no counts leak from a previous file
    ps_spikes = reshape(cellfun(@length, S.post_spikes_ss), 1, []);
    if i > 1 && numel(ps_spikes) ~= size(ps_ss, 2)
        error('find_mean_spikes:trialCount', ...
            '%s has %d trials, expected %d', fname, numel(ps_spikes), size(ps_ss, 2));
    end
    ps_ss(i, :) = ps_spikes;
end

% Rows 1:n_D1 are the D1 experiments, the rest are D2 (order of xpt_nos).
% mean(..., 1) keeps a per-trial row even when a group has one experiment.
mean_ps_D1 = mean(ps_ss(1:n_D1, :), 1);
mean_ps_D2 = mean(ps_ss(n_D1+1:end, :), 1);

mask = ones(1, trial_window);
markers = [markers(:).' numel(mean_ps_D1)]; % include final trial
n_seg = numel(markers);
mean_D1 = [];
mean_D2 = [];
marks = zeros(1, n_seg);
start = 1;
for j = 1:n_seg
    last = markers(j);
    % 'valid' keeps only windows lying entirely inside this phase
    mu_D1_seg = conv(mean_ps_D1(start:last), mask, 'valid') ./ trial_window;
    mu_D2_seg = conv(mean_ps_D2(start:last), mask, 'valid') ./ trial_window;
    mean_D1 = [mean_D1 mu_D1_seg];
    mean_D2 = [mean_D2 mu_D2_seg];
    marks(j) = numel(mu_D1_seg);
    start = last + 1;
end
new_marks = cumsum(marks);

plot_segmented_mean(1, mean_D1, new_marks, 'mean_D1_spikes.png');
plot_segmented_mean(2, mean_D2, new_marks, 'mean_D2_spikes.png');

new_marks = [1 new_marks];

function plot_segmented_mean(fig_no, y, marks, fname)
% Plot Y in figure FIG_NO with dashed red vertical lines at each entry of
% MARKS, then save the figure as a 100-dpi, 20 x 14 cm PNG named FNAME.
figure(fig_no)
clf
plot(y);
hold on
y_top = max(y);
if isempty(y_top)
    y_top = 0; % nothing to plot: draw zero-height marker lines
end
n = numel(marks);
% one line per column: from (marks(k), 0) to (marks(k), y_top)
plot([marks; marks], [zeros(1, n); repmat(y_top, 1, n)], 'r-.');
set(gcf, 'PaperOrientation', 'portrait')
set(gcf, 'PaperUnits', 'centimeters')
set(gcf, 'PaperPosition', [0 0 20 14]) % 20 x 14 cm
print(gcf, '-dpng', fname, '-r100')
hold off