% This code was used in: Masquelier & Kheradpisheh (2018) Optimal localist and distributed coding of spatiotemporal spike patterns through STDP and coincidence detection. Frontiers in Computational Neuroscience.
% with Matlab R2016b
% Aug 2018
% timothee.masquelier@cnrs.fr
%
% Several independent LIF neurons (which can differ by their
% threshold and dw_post, see param.m) integrate the same input spikes.
% A frozen input spike sequence (or several) repeat(s) periodically.
% Between these repetitions input spikes are random (homogeneous Poisson).
%
% All the numerical parameters are gathered in param.m
%
% This code is clock-driven.
% The LIF equation: tau_m*dV/dt = -V + I
% is integrated using forward Euler.
% The refractory period is ignored for simplicity.
%
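% Synapses are equipped with spike timing-dependent plasticity (STDP).
% We use the "classic" all-to-all spike interaction with exponential
% windows proposed by Song, Miller & Abbott 2000 Nat Neurosci.
% Note: in the current code LTD (at presynaptic spikes) is additive with a
% hard lower bound of 0, whereas LTP (at postsynaptic spikes) is scaled by
% w.*(1-w) and capped at 1; the purely additive LTP rule is kept as a
% commented-out line.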
%
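% In addition, at each postsynaptic spike, all the synaptic weights of the
% spiking neuron are decreased (homeostatic mechanism), like in Kempter,
% Gerstner & van Hemmen 1999 Phys Rev E. In the current code the decrease
% is dw_post.*w.*(1-w) (floored at 0) rather than the fixed value dw_post;
% the fixed-value rule is kept as a commented-out line.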
%
% This script can be launched individually, with or without specifying a
% random seed (eg matlab  -r "seed=3;main")
% One can also launch several runs in parallel, each with a different
% random seed, using batch.py.
%
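% If ../data/pattern.<seed>.mat exists (<seed> formatted with '%03d', i.e.
% '-01' when no seed is given), the pattern(s) it contains will be used.
% Otherwise fresh ones are randomly generated and saved to that file.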
%
% If ../data/w.<seed>.mat exists, these weights will be used as initial
% weights (this can be useful to continue a simulation).
% Otherwise homogeneous initial weights are used.
% The final weights are saved to that same file at the end of the run.
%
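% Every 100 periods, a convergence index (per neuron, the mean distance of
% its weights to the nearest of 0 and 1) is stored in ../data/conv.<seed>.mat.
% Note: the current code writes this file in batch mode too.
% If this file already exists, then new values are appended;
% otherwise a new file is created.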

if exist('seed','var') && seed>=0 % in case a seed is already defined (for example by calling matlab  -r "seed=3;main")
    % Batch mode. The user seed is combined with the wall-clock time,
    % presumably so that runs started with the same seed still differ.
    % rng only accepts a nonnegative integer below 2^32, but
    % seed*sum(100*clock) is generally fractional (clock returns fractional
    % seconds) and can exceed 2^32 for large seeds, so it is floored and
    % wrapped into the valid range. Valid values are left unchanged.
    % NOTE(review): the log line reports the user seed, not the value
    % actually passed to rng.
    rng(mod(floor(seed*sum(100*clock)), 2^32))
    timedLog(['Setting random seed to ' num2str(seed)])
    batch_mode = true;
else
    % Interactive mode: the RNG state is left untouched. seed = -1 is still
    % used below to build the ../data/*.<seed>.mat file names.
    batch_mode = false;
    seed = -1;
end

% Disabled alternative entry point: the script body could be wrapped as a
% function receiving the seed as an argument, as sketched below.
% function sequence(seed)
% batch_mode = true;
% rng(seed)
% timedLog(['Setting random seed to ' num2str(seed)])

% Start the wall-clock timer; the elapsed time is printed by the final toc.
tic

% NOTE(review): 'clear all' is disabled, presumably so that a 'seed'
% variable given on the command line survives until this point.
%clear all

% Run param.m, which (per the header) defines all numerical parameters in
% the workspace.
param

%__________________________________________
% INITIALIZATIONS
if tau_s > 0
    % Maximum height of the postsynaptic potential caused by a unitary dirac
    % input current. For a synaptic current that jumps by 1 and decays with
    % tau_s, the PSP is tau_s/(tau_m-tau_s)*(exp(-t/tau_m)-exp(-t/tau_s)).
    if tau_s ~= tau_m
        V_unit = tau_s/(tau_m-tau_s) * ( (tau_s/tau_m)^(tau_s/(tau_m-tau_s)) - (tau_s/tau_m)^(tau_m/(tau_m-tau_s)) );
    else
        % Limit tau_s -> tau_m: the PSP becomes (t/tau_m)*exp(-t/tau_m),
        % whose peak (at t = tau_m) is exp(-1). The general formula above
        % evaluates to Inf*0 = NaN in this case, which would propagate NaN
        % through the whole simulation.
        V_unit = exp(-1);
    end
end

a_post = zeros(n_post,1); % postsynaptic traces (one per neuron), drive LTD at presynaptic spikes
a_pre = zeros(1,n_pre); % presynaptic traces (one per afferent), drive LTP at postsynaptic spikes

% Postsynaptic potential: one column per time step, covering the most
% recent n_period_record periods. Old values are overwritten (the column
% index wraps around), so only the most recent values are kept.
V_post = zeros(n_post,round(n_period_record*period/dt));
I_post = zeros(n_post,1); % Postsynaptic current

% Adaptive threshold components: one row per neuron (assumes
% length(thr) == n_post -- TODO confirm in param.m), one column per time
% constant in tau_thr.
adaptive_thr = zeros(length(thr),length(tau_thr));

% Initial synaptic weights (n_post x n_pre). Resume from the weights saved
% by a previous run with the same seed when that file exists; otherwise
% start from weights that are homogeneous across afferents and scaled
% from each neuron's threshold.
w_file_ = ['../data/w.' sprintf('%03d',seed) '.mat'];
if ~exist(w_file_,'file')
    % One weight level per neuron (column vector), replicated over inputs.
    if tau_s == 0
        % Presumably chosen (shot-noise approximation: mean potential
        % n_pre*f*tau_m*w, std w*sqrt(n_pre*f*tau_m/2)) so that the mean
        % potential sits initial_distance_to_threshold std below threshold.
        w_scale_ = ones(n_post,1).*thr/(tau_m*n_pre*f)*1/(1+(2*tau_m*n_pre*f)^-.5*initial_distance_to_threshold);
    else
        % Empirical scaling; the constants are not explained here.
        % (An earlier variant divided by 60 instead of 80.)
        w_scale_ = ones(n_post,1).*thr/80*10/f*1000/n_pre;
    end
    w = ones(n_post,n_pre) .* repmat(w_scale_,[1 n_pre]);
    if any(w(:) > 1)
        warning('Some initial weights are > 1')
    end
    clear w_scale_
else
    disp(['Loading previous weights from ' w_file_])
    load(w_file_)
end
clear w_file_

% Frozen spike pattern(s). Reuse those saved for this seed when the file
% exists; otherwise draw n_pattern fresh ones and save them so later runs
% with the same seed can reuse them.
pattern_file_ = ['../data/pattern.' sprintf('%03d',seed) '.mat'];
if ~exist(pattern_file_,'file')
    % Each pattern is a sparse logical raster of n_involved afferents by
    % round(pattern_duration/dt) bins; each bin holds a spike with
    % probability dt*f. Patterns are drawn in order (RNG order matters).
    n_bin_ = round(pattern_duration/dt);
    pattern = {};
    pattern{n_pattern} = {};
    for p = 1:n_pattern
        pattern{p} = sparse(rand(n_involved,n_bin_) < dt*f);
    end
    save(pattern_file_,'pattern')
    clear n_bin_
else
    disp(['Loading previous pattern(s) from ' pattern_file_])
    load(pattern_file_)
end
clear pattern_file_

 
% disp('*** Learn pattern by cheating *** ')
% w = 0*w;
% w(:,sum(pattern{1},2)>2) = 1;
% w(:,sum(pattern{1}(:,round(40e-3/dt+(1:tau_pre/dt))),2)>0) = 1;

count = 0; % Postsynaptic spike counter (all neurons, whole simulation)


% record output spikes
% Rows are [time (i*dt), neuron index] pairs for the last
% n_period_record_spike periods; 'cursor' is the next free row. The buffer
% assumes a max firing rate of 10 Hz. ceil() keeps the size an integer even
% when the product is not exactly one (e.g. 10*1*3*0.1 evaluates to
% 3.0000000000000004), since zeros() rejects non-integer sizes.
spike_list = zeros(ceil(10*n_post*n_period_record_spike*period),2);
cursor = 1;

%__________________________________________
% INTEGRATION (FORWARD EULER)
% Time is counted in integer steps i (t = i*dt). period/dt is assumed to be
% an integer (the pattern indexing already relies on round(period/dt)), so
% all periodic bookkeeping is done on integer step counts. Floating-point
% tests such as mod(i*dt,100*period)==0 can silently miss because of
% rounding error.
n_step_per_period = round(period/dt); % time steps per period
n_col_V = size(V_post,2); % number of columns of the V_post buffer
for i=2:round(n_period*period/dt)
    col_now = mod(i-1,n_col_V)+1;  % V_post column for the current step
    col_prev = mod(i-2,n_col_V)+1; % V_post column for the previous step

    % Every 100 periods: append one row of convergence indices.
    if mod(i,100*n_step_per_period)==0
        if mod(i,1000*n_step_per_period)==0
            disp(['Period ' int2str(i/n_step_per_period)])
        end

        % Reload previously stored indices (if any) and append a new row:
        % per neuron, the mean distance of its weights to the nearest of 0
        % and 1 (for weights in [0,1]; 0 once all weights are binary).
        % NOTE(review): the variable name 'conv' shadows the built-in conv.
        if exist(['../data/conv.' sprintf('%03d',seed) '.mat'],'file')
            load(['../data/conv.' sprintf('%03d',seed) '.mat']);
        else
            conv = [];
        end
        conv = [conv;mean(abs((w-(w>.5))),2)'];
        save(['../data/conv.' sprintf('%03d',seed) '.mat'],'conv')

        % % Uncomment this for increasing learning rate
        %if i >= 0 && i < (n_period-n_period_record_spike)*n_step_per_period
        %     % arithmetic
        %     dw_post = dw_post + alpha * dw_post / da_pre;
        %     da_pre = da_pre + alpha;
        %     % geometric
        %     dw_post = dw_post * alpha ;
        %     da_pre = da_pre * alpha;
        %end
    end

    % Presynaptic spikes.
    % j is the time bin inside the (jittered) pattern, which is presented
    % during the last (pattern_duration+2*jitter)/dt steps of each period;
    % j<=0 means background Poisson activity.
    % NOTE(review): at the last step of each period mod(i,n_step_per_period)
    % is 0, so j<=0 there and j never reaches (pattern_duration+2*jitter)/dt.
    % Check whether jitter_pattern returns a final column that is never used.
    j = round( mod(i,n_step_per_period) - (period-pattern_duration-2*jitter)/dt );
    if j>0 % inside pattern
        if j==1 % first time in pattern
            % The patterns are presented in turn, one per period.
            pattern_idx = 1+mod(floor((i-1)/n_step_per_period),n_pattern);
            jittered_pattern = jitter_pattern(pattern{pattern_idx},jitter,f,dt);
        end
        if n_involved < n_pre
            % Only the first n_involved afferents belong to the pattern;
            % the remaining ones keep firing randomly.
            pre_spikes = [ jittered_pattern(:,j) ; sparse( rand(n_pre-n_involved,1) < dt*f ) ] ;
        else
            pre_spikes = jittered_pattern(:,j);
        end
    else
        %pre_spikes =  sparse( rand(n_pre,1) < dt*f ); % This is the easiest way to generate Poisson spikes, but it is very slow.

        % Background activity is read column by column from a pre-drawn
        % sparse logical poisson_spike.m x poisson_spike.n raster. It is
        % redrawn once all its columns have been used: a Poisson number of
        % spikes (mean m*n*p) placed at distinct random positions.
        if mod(poisson_spike.cursor-1,poisson_spike.n)==0 % time to generate new random input spikes
            % disp('Generating new random input spikes')
            n_spike = poissrnd(poisson_spike.m*poisson_spike.n*poisson_spike.p);
            position = randperm(poisson_spike.m*poisson_spike.n,n_spike);
            i_ = mod(position-1,poisson_spike.m)+1;
            j_ = floor((position-1)/poisson_spike.m)+1;
            poisson_spike.array = sparse(i_,j_,ones(1,n_spike),poisson_spike.m,poisson_spike.n);
            poisson_spike.array = poisson_spike.array>0;
            clear i_  j_ position n_spike
        end

        pre_spikes =  poisson_spike.array(:,mod(poisson_spike.cursor-1,poisson_spike.n)+1);
        poisson_spike.cursor = poisson_spike.cursor+1;
    end

    % Synaptic current. With tau_s>0 it decays exponentially and each input
    % spike adds w/V_unit, so that (in continuous time) a spike through a
    % synapse of weight w evokes a PSP peaking at w. With tau_s==0 the
    % current is instantaneous: each input spike raises V by its weight.
    if tau_s > 0
        I_post = I_post*(1-dt/tau_s) + w * pre_spikes / V_unit;
    else
        I_post =  w * pre_spikes * tau_m/dt;
    end

    % Postsynaptic integration step (forward Euler)
    V_post(:,col_now) = V_post(:,col_prev) + dt/tau_m * ( -V_post(:,col_prev) + I_post )  ;

    % Adaptive threshold components decay exponentially, each with its own
    % time constant tau_thr(t).
    for t = 1:length(tau_thr)
        adaptive_thr(:,t) = adaptive_thr(:,t) * ( 1 - dt/tau_thr(t) );
    end
%     V_thr = V_thr + dt/tau_thr * ( -V_thr + V_post(i) )  ;

    % STDP
    if da_post(1)>0
        % LTD: each presynaptic spike depresses its synapses by the current
        % postsynaptic traces a_post (additive, hard lower bound at 0).
        a_post = a_post*(1-dt/tau_post);
        % presynaptic spikes
        w(:,pre_spikes) = w(:,pre_spikes) - repmat(a_post,[1 sum(pre_spikes)]); % LTD
        w(:,pre_spikes) = max(w(:,pre_spikes),0); % Hard bounds
    end
    if da_pre(1)>0
        % Presynaptic traces (used for LTP at postsynaptic spikes): decay,
        % then increment by da_pre for the afferents that just fired.
        a_pre = a_pre*(1-dt/tau_pre);
        % presynaptic spikes
        a_pre(pre_spikes) = a_pre(pre_spikes)+da_pre; % update presynaptic traces
        % a_pre(pre_spikes) = da_pre; % update presynaptic traces
    end

    % postsynaptic spikes: neurons whose potential reached their threshold
    % (static part thr plus the sum of the adaptive components)
    post_spikes = find(V_post(:,col_now)>= thr + sum(adaptive_thr,2));
    if ~isempty(post_spikes)
        % Record [time, neuron index] during the last n_period_record_spike
        % periods only.
        if i >= (n_period-n_period_record_spike)*n_step_per_period
            if cursor+length(post_spikes)-1 > size(spike_list,1)
                spike_list = [ spike_list ; zeros(ceil(10*n_post*n_period_record_spike*period),2) ]; % max firing rate is increased by 10 Hz
                warning('Increase initial size for spike_list array')
            end
            spike_list(cursor:cursor+length(post_spikes)-1,:) = [ i*dt*ones(size(post_spikes)) post_spikes ];
            cursor = cursor+length(post_spikes);
        end
        if da_post(1)>0
            a_post(post_spikes) = a_post(post_spikes)+da_post(post_spikes); % update postsynaptic traces
        end
        if da_pre(1)>0
            % LTP: potentiate by the presynaptic traces, soft-bounded by
            % w.*(1-w) and capped at 1.
            for s=1:length(post_spikes)
                w(post_spikes(s),:) = min(1, w(post_spikes(s),:) + a_pre .* w(post_spikes(s),:) .* (1-w(post_spikes(s),:)));
                % w(post_spikes(s),:) = min(1, w(post_spikes(s),:) + a_pre );
            end
        end
        if dw_post(1)>0
            % Homeostatic depression of all synapses of the spiking neuron,
            % soft-bounded by w.*(1-w) and floored at 0.
            for s=1:length(post_spikes)
                w(post_spikes(s),:) = max(0, w(post_spikes(s),:) - dw_post(post_spikes(s)) * w(post_spikes(s),:) .* (1-w(post_spikes(s),:)));
                % w(post_spikes(s),:) = max(0, w(post_spikes(s),:) - dw_post(post_spikes(s)));
            end
        end

        % Each spike raises every adaptive-threshold component of the neuron
        % by d_thr times its static threshold.
        for s=1:length(post_spikes)
            adaptive_thr(post_spikes(s),:) = adaptive_thr(post_spikes(s),:) + d_thr * thr(post_spikes(s));
            % adaptive_thr(post_spikes(s),:) = adaptive_thr(post_spikes(s),:) + d_thr * ( thr(post_spikes(s)) + adaptive_thr(post_spikes(s),:) );
        end

        % Reset the potential of the neurons that fired (no refractory period).
        V_post(post_spikes,col_now)=0;
        count = count+length(post_spikes);

    end

end

% Unroll the V_post circular buffer so that its columns run from the oldest
% to the most recent stored time step (the last written column becomes the
% last one).
V_post = circshift(V_post, -(mod(i-1,size(V_post,2))+1), 2);

% Keep only the rows of spike_list that were actually filled.
spike_list = spike_list(1:cursor-1,:);

% Store the final weights; a later run with the same seed starts from them.
save(['../data/w.' sprintf('%03d',seed) '.mat'],'w')

% Summary: spikes per neuron and mean firing rate over the whole run.
mean_rate_ = count/n_period/period/n_post;
disp([num2str(count/n_post) ' postsynaptic spikes per neuron (' num2str(mean_rate_,'%.1f') 'Hz)'])
clear mean_rate_


perf % computes hit rates, false alarm rate etc.
% estimate_SNR

toc

% In batch mode MATLAB quits here; interactively, the results are plotted.
if batch_mode
    exit
end

plots