function [spikeTime, spikePeak, n] = findspikes(traces, fs, thres, varargin)
% 
% FINDSPIKES: Threshold-based spike discrimination. A spike is a stretch of
% signal that crosses the threshold; it can additionally be restricted by a
% maximum peak amplitude and/or by the width of the crossing window.
%
% Syntax:
%
%  spikeTime = findspikes(traces, fs, threshold)
%  [spikeTime spikePeak n] = findspikes(traces, fs, threshold [,direction] [,win_range] [,'plot'])
% 
% Description:
%
%  traces   : Multiple traces of signal, each trace in a column.
%             A row vector is treated as one single trace.
%  fs       : Sampling frequency, in kHz
%  threshold: Either a scalar, or [thres1 thres2] to define a range.
%             With a range, spikes whose peak exceeds the far limit are rejected.
%  direction: Optional. 
%             A positive number to find positive-going spikes, and vice versa. 
%             The default value is +1 when threshold is a scalar, and
%             sign(thres2-thres1) when threshold is a vector.
%  win_range: Optional. 
%             [win_min win_max] to define a time window range of the width
%             at threshold, in ms.
%  'plot'   : Optional (case-insensitive). Plot the result of the last trace.
%             Nothing is plotted by default, unless the function is called
%             without output arguments.
%  (optional parameters: direction, win_range and 'plot' can be in any order;
%   an empty numeric [] is accepted as a placeholder and ignored)
%
%  spikeTime: Spike times in ms. A vector when traces is a single trace,
%             otherwise a cell array with one vector per trace.
%  spikePeak: The peak value of each spike, organized like spikeTime.
%  n        : The total number of spikes over all traces.
% 
% Samples:
%
%  spikeTime=findspikes(signal, 10, 0.25);
%  spikeTime=findspikes(signal, 10, 0.25, 'plot');
%  spikeTime=findspikes(signal, 10, [-0.2 -0.5]);
%  spikeTime=findspikes(signal, 10, [-0.5 -0.2], -1);
%  spikeTime=findspikes(signal, 10, [-0.2 -0.3], [0.1 2], 'plot');
%
% Author: Li, Su based on the original of Alfonso Delagado-Reyes

% Copyright (c) 2007 Cengiz Gunay <cengique@users.sf.net>; Li, Su.
% This work is licensed under the Academic Free License ("AFL")
% v. 3.0. To view a copy of this license, please look at the COPYING
% file distributed with this software or visit
% http://opensource.org/licenses/afl-3.0.php.

% assign the arguments========================
if nargin < 3
    error('Not enough input arguments.')
elseif nargin > 6
    error('Too many input arguments.')
end

if isempty(thres)
    error('threshold must be a scalar or a two-element vector [thres1 thres2]')
end

plotit = '';
direction = [];
win_range = [];
for k = 1:numel(varargin)
    opt = varargin{k};
    if ischar(opt)
        plotit = opt;
    elseif isnumeric(opt)
        if isempty(opt)
            % [] is only a placeholder: keep the default for this option.
        elseif length(opt) == 1
            direction = opt;
        else
            win_range = opt;
        end
    else
        error('optional arguments must be a direction (numeric scalar), a window range [win_min win_max] or the string ''plot''')
    end
end

if isempty(direction)
    if length(thres) == 1
        direction = 1;
    else
        direction = thres(2) - thres(1);
    end
end

direction = sign(direction);
if direction == 0 || isnan(direction)
    error('threshold range or direction is zero')
end

% A single trace given as a row vector would otherwise be read as many
% one-sample traces, in which no spike can ever be found.
if isvector(traces) && size(traces, 1) == 1
    traces = traces(:);
end

% start to find spikes ===========================
% Negative-going spikes are found by flipping both the threshold and the
% traces, so that the search below only has to deal with upward spikes.
thres = thres * direction;
thresh_min = min(thres);
thresh_max = max(thres);
has_max = length(thres) > 1;

use_window = ~isempty(win_range);
if use_window
    data_range = win_range .* fs;      % ms * kHz -> width in samples
end
refractory = 1;                        % ms
min_gap = refractory * fs;             % in samples

nTraces = size(traces, 2);
single_trace = (min(size(traces)) == 1);
if single_trace
    spikeTime = [];
    spikePeak = [];
else
    spikeTime = cell(1, nTraces);
    spikePeak = cell(1, nTraces);
end
n = 0;
trace = []; timeidx = []; tips = [];   % kept after the loop for plotting

for idx = 1:nTraces

    trace = traces(:, idx) * direction;
    timeidx = []; tips = [];

    % last sample below the threshold before each upward crossing
    left_edges = find(trace(1:end-1) < thresh_min & trace(2:end) >= thresh_min);
    % last sample at/above the threshold before each downward crossing
    right_edges = find(trace(1:end-1) >= thresh_min & trace(2:end) < thresh_min);

    % match the left and right edges of each window: a falling edge ahead
    % of the first rising edge belongs to a spike that started before the
    % recording, and a trailing unmatched rising edge is dropped as well.
    if ~isempty(left_edges) && ~isempty(right_edges) && right_edges(1) < left_edges(1)
        right_edges(1) = [];
    end
    spike_num = min(length(left_edges), length(right_edges));
    left_edges = left_edges(1:spike_num);
    right_edges = right_edges(1:spike_num);

    % eliminate the time windows out of the range.
    if spike_num > 0 && use_window
        % linear interpolation gives the sub-sample crossing positions
        left_time = left_edges + (thresh_min - trace(left_edges)) ./ (trace(left_edges+1) - trace(left_edges));
        right_time = right_edges + (thresh_min - trace(right_edges)) ./ (trace(right_edges+1) - trace(right_edges));
        win_width = right_time - left_time;
        out_of_range = win_width < data_range(1) | win_width > data_range(2);
        left_edges(out_of_range) = [];
        right_edges(out_of_range) = [];
        spike_num = length(left_edges);
    end

    % merge the neighbor windows too close to each other: two windows less
    % than the refractory time apart are one spike if the trace between
    % them stays above the midpoint between the trace mean and the threshold.
    if spike_num > 1
        gap = left_edges(2:end) - right_edges(1:end-1);
        candidates = find(gap < min_gap);
        merge_level = mean(trace);
        merge_level = merge_level + (thresh_min - merge_level) * 0.5;
        do_merge = false(size(candidates));
        for k = 1:numel(candidates)
            c = candidates(k);
            do_merge(k) = min(trace(right_edges(c):left_edges(c+1))) > merge_level;
        end
        merged = candidates(do_merge);
        right_edges(merged) = [];
        left_edges(merged+1) = [];
        spike_num = length(left_edges);
    end

    if spike_num > 0
        % locate the peak inside each window
        tips = zeros(spike_num, 1);
        timeidx = zeros(spike_num, 1);
        for k = 1:spike_num
            [tips(k), offset] = max(trace(left_edges(k):right_edges(k)));
            timeidx(k) = left_edges(k) + offset - 1;
        end

        % eliminate the spikes exceeding the max threshold.
        if has_max
            too_big = tips > thresh_max;
            tips(too_big) = [];
            timeidx(too_big) = [];
        end
        spike_num = length(timeidx);

        % flip the peaks back to the original polarity.
        tips = tips * direction;
    end

    if single_trace
        spikeTime = timeidx / fs;
        spikePeak = tips;
    else
        spikeTime{idx} = timeidx / fs;
        spikePeak{idx} = tips;
    end
    n = n + spike_num;
end

% plot the last trace with the thresholds and the detected peaks.
if (strcmpi(plotit, 'plot') || nargout == 0) && ~isempty(trace)
    % flip the trace and the thresholds back.
    thres = thres * direction;
    trace = trace * direction;

    m_time = (1:size(trace, 1))' / fs;
    plot(m_time, trace, 'k'); hold on
    for I = 1:length(thres)
        plot([0 m_time(end)], [thres(I) thres(I)], 'b');
    end
    plot(timeidx / fs, tips, 'ro');
    hold off
    ylabel('V/uV')
    xlabel('t/ms')
    zoom on
end