% function outSpikes = learnChannel(inSpikes, n, convergence, iterations)
% learns a communication channel between an ensemble with a given spike
% pattern and a new output ensemble of LIF neurons, and returns the spikes
% of the output ensemble.
% 
% inSpikes: spikes of the presynaptic ensemble
% n: size of the post-synaptic LIF ensemble
% convergence: number of pre-synaptic neurons connected to each
%   post-synaptic neuron
% iterations: training iterations over data 

function outSpikes = learnChannel(inSpikes, n, convergence, iterations) 

    psc = PSC(.0002);
    current = getCurrent(inSpikes, .0002, 1500, psc);
    meanCurrent = mean(current,2);

    rate = .00000000005; % learning rate
    
    tauRef = .002;
    tauRC = .02;
    intercept = 0 + rand(n, 1) * .9
    maxFR = 200 + rand(n,1) * 200;
	x = 1 ./ (1 - exp( (tauRef - (1 ./ maxFR)) / tauRC));
	scale = (x - 1) ./ (1 - intercept)	
	bias = 1 - scale .* intercept

    alifInput = [];
    for i = 1:n
        tic
        indices = randomIndices(size(inSpikes,1), convergence);
        magnitude = mean(sum(current(indices,:)));    
        weights = rand(length(indices),1) * 1/magnitude;        
        runningMean = 0;
        maxWeight = 1/magnitude;
        
        weighted = weights * ones(1,size(current,2)) .* current(indices,:);
%         figure, hold on
%         plot(sum(weighted), 'b');
        
        weightHistory = zeros(length(indices), iterations*size(current,2));
        
        for j = 1:iterations            
%             tic
            for k = 1:size(current, 2)
                weighted = weights .* current(indices,k);
                in = bias(i) + scale(i)*sum(weighted);
                
                if (in > 1)
                    b = 1 / ( tauRef - tauRC * log(1 - 1/in) );
                else 
                    b = 0;
                end

                nMean = (j-1)*size(current,2) + k;
                runningMean = ( (nMean-1)*runningMean + b ) / nMean;
                
%                 dw = rate * (current(indices,k) - meanCurrent(indices)) * (b - runningMean) * (b > 0) * (b < maxFR(i));
                dw = rate * (current(indices,k) - meanCurrent(indices)) * (b - runningMean);
%                 sprintf('%2.15f, %2.15f, %f, %f', min(dw), max(dw), b, b - runningMean)
                weights = weights + dw;
                weights = min(maxWeight, max(0, weights));

                weightHistory(:,nMean) = weights;
                
            end
%             toc
        end
        
        weighted = weights * ones(1,size(current,2)) .* current(indices,:);
        alifInput = [alifInput; sum(weighted)];
%         plot(sum(weighted), 'r.');
%         figure, plot(weightHistory')
       toc 
    end
    
    [outSpikes, firings] = ALIF(.0002, alifInput, bias, scale, ones(size(bias))*tauRC, ones(size(bias))*tauRef, ones(size(bias)), zeros(size(bias)));


    
function indices = randomIndices(n, convergence)
    indices = 1:n;
    while length(indices) > convergence
        remove = ceil(rand * length(indices));
        indices = [indices(1:remove-1) indices(remove+1:end)];
    end