% function [weights, errorHistory] = learnedDecoders(signal, spikes, weights, iterations, jitterSD, tau, [inhibitorySpikes, inhibitoryWeights])
% Returns learned weights with which to combine component signals in order
% to approximate a desired composite signal.
% 
% Note: in this version, 32 distinct spike-jitter results are cached and
% re-used for the remaining iterations (the last 5 iterations always use fresh
% jitter), which speeds things up considerably. This caching was not done in
% the published version, but it does not seem to affect the results.
% 
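% signal: the desired composite signal (row vector)
% spikes: incoming spike pattern (one row per input component)
% weights: initial weights (column vector, one per component)
% iterations: number of passes through the signal before returning
% jitterSD: SD (s) of the Gaussian jitter applied to spike times; 0 disables
%   jitter
% tau: time constant (s) of the exponential filter applied to the signal and
%   components before computing the learning error; 0 disables filtering.
%   Ignored, for simplicity, when inhibitorySpikes is given.
% inhibitorySpikes (optional): spike pattern for an inhibitory projection. If
%   not provided, weights may take either sign as a simplification; if
%   provided, excitatory weights are kept >= 0 and inhibitory weights <= 0.
% inhibitoryWeights (optional): initial weights for the inhibitory projection
%   (default: zeros)
%
% Outputs:
% weights: learned weights; when inhibitorySpikes is given, the inhibitory
%   weights are appended below the excitatory ones
% errorHistory: decoding error for the initial weights followed by the error
%   after each iteration (1 x iterations+1)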

function [weights, errorHistory] = learnedDecoders(signal, spikes, weights, iterations, jitterSD, tau, varargin)

    nSamples = length(signal);
    % Normalizer for the reported error history. Named signalRMS so it does
    % not shadow the toolbox function rms(); the semicolon keeps it from
    % being echoed on every call.
    signalRMS = sqrt(mean(signal.^2));

    dt = .0002; % time step (s); also used for the jittered currents below
    psc = PSC(dt);
    components = getCurrent(spikes, dt, nSamples, psc);
    k = .00001 / max(max(components)); % learning rate ... .0001 OK for no jitter

    % Optional inhibitory projection: varargin{1} = inhibitory spikes,
    % varargin{2} = initial inhibitory weights (default zeros).
    nInhibitory = 0;
    iWeights = [];
    iComponents = [];
    if nargin > 6
        iSpikes = varargin{1};
        nInhibitory = size(iSpikes,1);
        iComponents = getCurrent(iSpikes, dt, nSamples, psc);
        ki = .00001 / max(max(iComponents));
        if nargin > 7
            iWeights = varargin{2};
        else
            iWeights = zeros(nInhibitory,1);
        end
    end

    % Error-signal filter. Only the non-inhibitory update uses it, so it is
    % not computed when an inhibitory projection is present.
    useFilter = tau > 0 && nInhibitory == 0;
    if useFilter
        time = dt:dt:(tau*10);
        kernel = exp(-time/tau);
        kernel = kernel / sum(kernel);
        filtSignal = loopFilter(signal, kernel);
        filtComponents = loopFilter(components, kernel);
    end

    estimate = decodedEstimate(weights, components, iWeights, iComponents, nInhibitory);
    errorHistory = zeros(1,iterations+1);
    % Normalized the same way as the per-iteration entries below.
    errorHistory(1) = mean((estimate - signal).^2) / signalRMS;

    nCached = 32;    % distinct jitter draws that are cached and re-used
    nFreshTail = 5;  % the last iterations always draw fresh jitter
    for i = 1:iterations
        if jitterSD > 0
            if i <= nCached || i > iterations - nFreshTail
                jittered = jitter(spikes, [jitterSD 0 0], [0 0 0], [0 1]);
                components = getCurrent(jittered, dt, nSamples, psc);
                if useFilter
                    filtComponents = loopFilter(components, kernel);
                end
                if nInhibitory > 0
                    iJittered = jitter(iSpikes, [jitterSD 0 0], [0 0 0], [0 1]);
                    iComponents = getCurrent(iJittered, dt, nSamples, psc);
                end
                if i <= nCached
                    cacheComponents(:,:,i) = components;
                    if useFilter
                        cacheFiltComponents(:,:,i) = filtComponents;
                    end
                    if nInhibitory > 0
                        cacheIComponents(:,:,i) = iComponents;
                    end
                end
            else
                slot = mod(i,nCached) + 1;
                components = cacheComponents(:,:,slot);
                % Use the filtered copy that matches this cached draw (not
                % the one left over from the last fresh draw).
                if useFilter
                    filtComponents = cacheFiltComponents(:,:,slot);
                end
                if nInhibitory > 0
                    iComponents = cacheIComponents(:,:,slot);
                end
            end
        end

        % Sample-by-sample gradient descent on the squared error.
        for j = 1:nSamples
            if nInhibitory == 0
                if useFilter
                    E = sum(weights .* filtComponents(:,j)) - filtSignal(j);
                    dEdw = filtComponents(:,j) * E;
                else
                    E = sum(weights .* components(:,j)) - signal(j);
                    dEdw = components(:,j) * E;
                end
                weights = weights - k * dEdw;
            else
                E = sum(weights .* components(:,j)) + sum(iWeights .* iComponents(:,j)) - signal(j);
                dEdw = components(:,j) * E;
                weights = weights - k * dEdw;
                weights = max(0,weights);   % excitatory weights stay non-negative

                dEdiw = iComponents(:,j) * E;
                iWeights = iWeights - ki * dEdiw;
                iWeights = min(0, iWeights); % inhibitory weights stay non-positive
            end
        end

        estimate = decodedEstimate(weights, components, iWeights, iComponents, nInhibitory);
        mse = mean((estimate - signal).^2);
        errorHistory(i+1) = mse / signalRMS;
        if mod(i,200) == 1
            fprintf('%i\n', i); % progress
        end
    end

    if nargin > 6
        weights = [weights; iWeights];
    end

function estimate = decodedEstimate(weights, components, iWeights, iComponents, nInhibitory)
% Weighted sum of the component currents (one row per component), plus the
% inhibitory projection when present. The matrix product keeps the result a
% 1 x nSamples row even with a single component, where sum() over a single
% row would collapse to a scalar.
    estimate = weights.' * components;
    if nInhibitory > 0
        estimate = estimate + iWeights.' * iComponents;
    end
    
function result = loopFilter(signals, kernel)
% Circularly convolves each row of signals with kernel, treating each row as
% one period of a periodic signal: the part of the linear convolution that
% runs past the end of the row is wrapped back onto its start.
%
% signals: matrix with one signal per row
% kernel: filter kernel (vector); may be longer than the signals
% result: filtered signals, same size as signals
    nSamples = size(signals,2);
    result = zeros(size(signals));
    for i = 1:size(signals,1)
        convolved = conv(signals(i,:), kernel);
        % Pad to a whole number of periods, then fold every period onto the
        % first, so sample nSamples+m lands at position m. The previous
        % version wrapped it to m+1 and counted sample nSamples twice.
        nPadded = ceil(length(convolved) / nSamples) * nSamples;
        convolved(end+1:nPadded) = 0;
        result(i,:) = sum(reshape(convolved, nSamples, []), 2).';
    end