function spikes = get_spikes (pm, thresh, refract)
% GET_SPIKES  Mark upward threshold crossings in each column of a signal.
%
%   spikes = get_spikes(pm, thresh, refract)
%
%   Inputs:
%     pm      - signal samples, one trace per column. A single row vector
%               is treated as one trace and converted to a column.
%     thresh  - threshold; it is subtracted from pm before detection.
%     refract - if > 0, the number of samples following an accepted
%               crossing in which any further crossings are discarded.
%
%   Output:
%     spikes  - double array with one row per sample and one column per
%               trace. An entry is 1 where the trace is strictly above
%               thresh and the previous sample was at or below thresh,
%               and 0 elsewhere. The first row is always 0 because it has
%               no previous sample.

% If it's a single row vector, make into a column
if size(pm,1) == 1
    pm = pm(:);
end

pm = pm - thresh;
pm_earlier = pm(1:end-1,:);
pm_later = pm(2:end,:);

% Compare against zero directly instead of differencing sign(): a sample
% sitting exactly on the threshold has sign 0, which made the sign
% difference only 1 on each side and the crossing went undetected.
% NaN samples fail both comparisons, so no crossing is reported next to them.
crossed = (pm_earlier <= 0) & (pm_later > 0);

% Add a row of zeroes so the output lines up with the input samples
% (concatenating with zeros also keeps the output class double).
spikes = [zeros(1,size(crossed,2)); crossed];

if refract > 0
    for kk = 1:size(spikes,2)
        hits = find(spikes(:,kk));
        if isempty(hits)
            continue;
        end
        % The first crossing is always kept; each later crossing is kept
        % only if it lies more than refract samples after the last kept one.
        last_kept = hits(1);
        for ii = 2:numel(hits)
            if hits(ii) - last_kept <= refract
                spikes(hits(ii),kk) = 0;
            else
                last_kept = hits(ii);
            end
        end
    end
end
end