% function [weights, err, t] = decode(signal, dt, spikes, randSD, bandSD, band, noiseRate, nt, ne, interval)
% 
% signal: target signal for ensemble to approximate
% dt: time step (s)
% spikes: a matrix of spike times with spike times for each neuron in a row
% randSD: see jitter.m
% bandSD: see jitter.m
% band: see jitter.m
% noiseRate: rate of extra Poisson-refractory spikes introduced as noise
% nt: number of "training" trials (repeated presentations used to find optimal weights)
% ne: number of evaluation trials (with novel noise, used to evaluate error)
% interval: subsampling factor, in time steps, applied to the signal and current when finding 
%       weights (e.g. with dt=.0002 and interval=5, weights are found using a 1000 Hz sampling rate)
% 
% weights: optimal weights for approximating given signal
% err: ne x 1 vector of mean-squared errors of the approximation, one per
%    evaluation trial
% t: elapsed time (s) to run this function  
% 
% function [..., corruptError, examples, estimates] = decode(..., plotEstimates)
% 
% plotEstimates: plots target signal and each approximation
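% corruptError: 1 x ne vector of errors recalculated with each weight 
%    perturbed by +/- eps (to test the effect of numerical precision limitations)
% examples: spike times of the first neuron (after jitter and added noise), 
%    one trial per row: the nt training trials (rows 1..nt) followed by the 
%    ne evaluation trials (rows nt+1..nt+ne)
% estimates: optimal approximations of target for each evaluation trial, 
%    one trial per row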

% IMPLEMENTATION NOTE: moderate performance gains can be had by using raw 
% spike signals instead of PSCs, to estimate deconvolution of given signal
% instead of signal itself. However, too much fussing is then needed with 
% large signals / numbers of neurons, to avoid running out of memory. The 
% method below (convolving each spike train with PSC kernel individually) 
% turns out to be not much slower, and is more straightforward. 

function [weights, err, t, varargout] = decode(signal, dt, spikes, randSD, bandSD, band, noiseRate, nt, ne, interval, varargin)  
    tic  % elapsed time is returned as t

    % Fits linear decoding weights on nt noisy repetitions of the spike
    % pattern, then measures the squared error of the weighted readout on ne
    % fresh noisy repetitions. signal is expected to be a row vector with one
    % sample per time step, and getCurrent is expected to return one row per
    % neuron and one column per time step.

    nSteps = length(signal);
    T = dt*nSteps;
    psc = PSC(dt);

    % --- Training: place the subsampled signal and currents of all nt trials
    % side by side, so a single weight vector is fit across every trial.
    sampleIdx = 1:interval:nSteps;
    nSamples = length(sampleIdx);
    compositeSignal = repmat(signal(sampleIdx), 1, nt);
    compositeCurrent = [];
    examples = [];
    for i = 1:nt
        [current, noised] = noisyTrialCurrent(spikes, randSD, bandSD, band, T, dt, noiseRate, nSteps, psc);
        if (i == 1)
            % Preallocate once the number of current rows (and its class) is
            % known, instead of growing the matrix on every trial.
            compositeCurrent = zeros(size(current, 1), nSamples*nt, class(current));
        end
        compositeCurrent(:, (i-1)*nSamples + (1:nSamples)) = current(:, sampleIdx);

        example = noised(1,:);
        examples(i,1:length(example)) = example;
    end

    % Free the last trial's data before the (potentially memory-heavy) fit.
    % PACK cannot be used for this inside a function.
    clear current noised example

    weights = optimalDecoders(compositeSignal, compositeCurrent, 0);
    clear compositeSignal compositeCurrent

    % Random +/- eps perturbation of each weight, used for corruptError. This
    % is drawn unconditionally so the random-number stream seen by the
    % evaluation trials does not depend on nargout. NOTE(review): eps is the
    % spacing of doubles near 1, so weights of magnitude 2 or more may be left
    % unchanged by this perturbation.
    corruption = (-1 + 2*(rand(size(weights))>.5)) * eps;

    % Optional 11th argument: plot the target and every evaluation estimate.
    % (Mirrors the truthiness of the original `if (nargin > 10 & flag)`.)
    plotEstimates = (nargin > 10) && ~isempty(varargin{1}) && all(varargin{1}(:));
    if (plotEstimates)
        time = dt*(1:nSteps);    
        figure, hold on, plot(time, signal, 'k--') 
    end

    % --- Evaluation on ne trials with novel noise.
    err = zeros(ne,1);
    corruptError = zeros(1,ne);  % row vector; defined even when ne == 0
    estimates = zeros(ne, nSteps);
    for i = 1:ne
        [current, noised] = noisyTrialCurrent(spikes, randSD, bandSD, band, T, dt, noiseRate, nSteps, psc);

        % Linear readout: weighted sum of the neurons' currents at each time
        % step. A matrix product keeps this a 1 x nSteps row even when there
        % is only one neuron (sum() of a single row would collapse to a scalar).
        estimate = weights' * current;

        if (plotEstimates) 
            plot(time, estimate, 'k');
        end

        err(i) = mean( (estimate - signal).^2 );

        if (nargout > 3)
            corruptEstimate = (weights + corruption)' * current;
            corruptError(i) = mean( (corruptEstimate - signal).^2 );
        end

        example = noised(1,:);
        examples(nt+i,1:length(example)) = example;

        estimates(i,:) = estimate;
    end

    if (nargout > 3)
        varargout{1} = corruptError;
    end
    if (nargout > 4)
        varargout{2} = examples;
    end
    if (nargout > 5) 
        varargout{3} = estimates;
    end

    t = toc;

function [current, noised] = noisyTrialCurrent(spikes, randSD, bandSD, band, T, dt, noiseRate, nSteps, psc)
    % One noisy repetition of the spike pattern: jitter the spike times, add
    % extra noise spikes, and convert the result to a current with getCurrent.
    jittered = jitter(spikes, randSD, bandSD, band);
    noised = addNoiseSpikes(jittered, T, dt, noiseRate);
    current = getCurrent(noised, dt, nSteps, psc);
    
function noised = addNoiseSpikes(spikes, T, dt, noiseRate, absRT) 
    % Merges extra noise spikes into each row of a spike-time matrix, then
    % passes every row through space() to enforce a minimum spike interval.
    %
    % spikes: spike-time matrix, one neuron per row
    % T: trial duration, passed to genUncorrelated
    % dt: time step (s), passed to genUncorrelated
    % noiseRate: noise spike rate passed to genUncorrelated; 0 adds no spikes
    % absRT: (optional) absolute refractory time passed to space (default .001)
    %
    % noised: one row per row of spikes, as returned by space; empty when
    %    spikes has no rows
    if (nargin < 5 || isempty(absRT))
        absRT = .001;
    end

    nNeurons = size(spikes,1);
    if (noiseRate == 0) 
        noiseSpikes = [];
    else 
        % second output is not needed here (renamed so it no longer shadows
        % the builtin COV)
        [noiseSpikes, noiseCov] = genUncorrelated(nNeurons, T, dt, noiseRate, [1 0 0]); %#ok<NASGU>
    end
    allSpikes = sort([spikes noiseSpikes], 2);

    % Merging can create ISIs smaller than the refractory time, so each row
    % is re-spaced with that minimum interval.
    noised = [];  % defined output even when there are no rows
    for i = 1:nNeurons
        noised(i,:) = space(allSpikes(i,:), absRT);
    end