% This code was used in: Masquelier T (2017) STDP allows close-to-optimal spatiotemporal spike pattern detection by single coincidence detector neurons. Neuroscience.
% https://doi.org/10.1016/j.neuroscience.2017.06.032
% Jan 2017
% timothee.masquelier@cnrs.fr
%
% Several independent LIF neurons (which can differ by their
% threshold and dw_post, see param.m) integrate the same input spikes.
% A frozen input spike sequence (or several) repeat(s) perdiodically.
% Between these repetitions input spikes are random (homogeneous Poisson).
%
% All the numerical parameters are gathered in param.m
%
% This code is clock-driven.
% The LIF equation: tau_m*dV/dt = -V + I
% is integrated using forward Euler.
% The refractory period is ignored for simplicity.
%
% Synapses are equipped with spike timing-dependent plasticity (STDP).
% We will use the "classic" additive, all-to-all spikes with exponential
% windows proposed by Song, Miller & Abbott 2000 Nat Neurosc
%
% In addition, at each postsynaptic spike, all the synaptic weights are
% decreased by a fixed value dw_post (homeostatic mechanism), like in Kempter, Gerstner& van Hemmen 1999 Phys Rev E
%
% This script can be launched individually, with or without specifying a
% random seed (eg matlab  -r "seed=3;main")
% One can also launch multi threads with different random seeds using
% batch.py
%
% If ../data/pattern.mat exist, this/these pattern(s) will be used.
% Otherwise fresh ones are randomly generated.
%
% If ../data/w.mat exist, these weights will be used as initial weights (this can be useful to continue a simulation)
% Otherwise homogeneous initial weights are used.
%
% A convergence index is periodically stored ../data/conv.mat (unless in batch mode)
% If this file already exists, then new values are appended
% Otherwise a new file is created

if exist('seed','var') % in case a seed is already defined (for example by calling matlab  -r "seed=3;main")
    rng(seed)
    timedLog(['Setting random seed to ' num2str(seed)])
    batch_mode = true;
else
    batch_mode = false;
end

% function sequence(seed)
% batch_mode = true;
% rng(seed)
% timedLog(['Setting random seed to ' num2str(seed)])

tic

%clear all

param

%__________________________________________
% INITIALIZATIONS
if tau_s > 0
    V_unit = tau_s/(tau_m-tau_s) * ( (tau_s/tau_m)^(tau_s/(tau_m-tau_s)) - (tau_s/tau_m)^(tau_m/(tau_m-tau_s)) ); % Maximum height of the postsynaptic potential caused by a unitary dirac input current
end

a_post = zeros(n_post,1); % LTP variable
a_pre = zeros(1,n_pre); % LTD variable

V_post = zeros(n_post,round(n_period_record*period/dt)); % Postsynaptic potential. Only most recent values are stored. Old values are overwritten.
I_post = zeros(n_post,1); % Postsynaptic current

if exist('../data/w.mat','file')
    disp('Loading previous weights')
    load ../data/w.mat
else
    w = ones(n_post,n_pre); % Synaptic weights
    % w = intial_weight*2*rand(n_post,n_pre); % Synaptic weights
    %w = w_max*[ ones(3,1) ; zeros(4,1)];
    if tau_s==0
        w = w .* repmat(ones(n_post,1).*thr/(tau_m*n_pre*f)*1/(1+(2*tau_m*n_pre*f)^-.5*initial_distance_to_threshold),[1 n_pre]);
    else
        %w = w .* repmat(ones(n_post,1).*thr/60*10/f*1000/n_pre,[1 n_pre]);
        w = w .* repmat(ones(n_post,1).*thr/80*10/f*1000/n_pre,[1 n_pre]);
    end
    if max(w(:))>1
        warning('Some initial weights are > 1')
    end
end

if exist('../data/pattern.mat','file')
    disp('Loading previous pattern(s)')
    load ../data/pattern.mat
else
    clear pattern
    pattern{n_pattern}={};
    for p=1:n_pattern
        pattern{p} = sparse( rand(n_involved,round(pattern_duration/dt))<dt*f );
    end
    if ~batch_mode
        save ../data/pattern pattern
    end
end

 
% disp('*** Learn pattern by cheating *** ')
% w = 0*w;
% w(:,sum(pattern{1},2)>2) = 1;
% w(:,sum(pattern{1}(:,round(40e-3/dt+(1:tau_pre/dt))),2)>0) = 1;

count = 0; % Postsynaptic spike counter


% record output spikes
spike_list = zeros(n_period_record_spike*10*n_post,2);
cursor = 1;

%__________________________________________
% INTEGRATION (FORWARD EULER)
for i=2:n_period*period/dt
    
    if ~batch_mode
        if mod(i*dt,100*period)==0
            disp(['Period ' int2str(i*dt/period)])
            
            if exist('../data/conv.mat','file')
                load('../data/conv.mat');
            else
                conv = [];
            end
            conv = [conv;mean(mean(abs((w-(w>.5)))))];
            save ../data/conv.mat conv
        end
    end
    
    % Presynaptic spikes
    j = round( mod(i,round(period/dt)) - (period-pattern_duration-2*jitter)/dt );
    if j>0 % inside pattern
        if j==1 % first time in pattern
            pattern_idx = 1+mod(floor((i-1)/(period/dt)),n_pattern);
            jittered_pattern = jitter_pattern(pattern{pattern_idx},jitter,f,dt);
        end
        pre_spikes = [ jittered_pattern(:,j) ; sparse( rand(n_pre-n_involved,1) < dt*f ) ] ;
    else
        pre_spikes =  sparse( rand(n_pre,1) < dt*f );
    end
    
    if tau_s > 0
        I_post = I_post*(1-dt/tau_s) + w * pre_spikes / V_unit;
    else
        I_post =  w * pre_spikes * tau_m/dt;
    end
   
    % Postsynaptic integration step
    V_post(:,mod(i-1,size(V_post,2))+1) = V_post(:,mod(i-2,size(V_post,2))+1) + dt/tau_m * ( -V_post(:,mod(i-2,size(V_post,2))+1) + I_post )  ;
    
%     % adaptive_thr
%     V_thr = V_thr + dt/tau_thr * ( -V_thr + V_post(i) )  ;

    % STDP
    if da_post(1)>0
        a_post = a_post*(1-dt/tau_post);
        % presynaptic spikes
        w(:,pre_spikes) = w(:,pre_spikes) - repmat(a_post,[1 sum(pre_spikes)]); % LTD
        w(:,pre_spikes) = max(w(:,pre_spikes),0); % Hard bounds
    end
    if da_pre(1)>0
        a_pre = a_pre*(1-dt/tau_pre);
        % presynaptic spikes
        a_pre(pre_spikes) = a_pre(pre_spikes)+da_pre; % update presynaptic traces
    end
    
    
    % postsynaptic spikes
    post_spikes = find(V_post(:,mod(i-1,size(V_post,2))+1)>=thr);
    if ~isempty(post_spikes)
        if i*dt >= (n_period-n_period_record_spike)*period
            spike_list(cursor:cursor+length(post_spikes)-1,:) = [ i*dt*ones(size(post_spikes)) post_spikes ];
            cursor = cursor+length(post_spikes);
        end
        if da_post(1)>0
            a_post(post_spikes) = a_post(post_spikes)+da_post(post_spikes); % update postsynaptic traces
        end
        if da_pre(1)>0
            w(post_spikes,:) = w(post_spikes,:) + repmat(a_pre,[length(post_spikes),1]); % LTP
            w(post_spikes,:) = min(w(post_spikes,:),1); % Hard bounds
            
            % Soft bounds
            %w(post_spikes,:) = w(post_spikes,:) + w(post_spikes,:).*(1-w(post_spikes,:)).*repmat(a_pre,[length(post_spikes),1]); % LTP
        end
        if dw_post(1)>0
            % coef = min(1,.01+i/((n_period-n_period_record)*period/dt)*.99);
            %w(post_spikes,:) = w(post_spikes,:) - repmat((dw_post_final-dw_post(post_spikes))*min(1,i/(500*period/dt))+dw_post(post_spikes),[1 n_pre]); % Homeostatic
            
            w(post_spikes,:) = w(post_spikes,:) - repmat(dw_post(post_spikes),[1 n_pre]); % Homeostatic
            w(post_spikes,:) = max(w(post_spikes,:),0); % Hard bounds
            
            % Soft bounds
            %w(post_spikes,:) = w(post_spikes,:) -  w(post_spikes,:).*(1-w(post_spikes,:)).*repmat(dw_post(post_spikes),[1 n_pre]); % Homeostatic
        end

        V_post(post_spikes,mod(i-1,size(V_post,2))+1)=0;
        count = count+length(post_spikes);
    end
    
end

% Put V_post in correct order
V_post = [ V_post(:,mod(i-1,size(V_post,2))+2:end) V_post(:,1:mod(i-1,size(V_post,2))+1)];

if cursor == length(spike_list)+1
    warning('Increase initial size for spike_list array')
end
spike_list(cursor:end,:) = []; % remove unused values

if ~batch_mode
    save ../data/w w
end

disp([num2str(count/n_post) ' postsynaptic spikes per neuron (' num2str(count/n_period/period/n_post,'%.1f') 'Hz)'])

% if n_post==1
%     disp(['Expected V_noise = ' num2str( tau_m*sum(w)*f ) ])
%     disp(['Estimated peak 1 = ' num2str( tau_m*sum(w)*f + (1-exp(-tau_pre/tau_m))*( tau_m*sum(sum(pattern{1}(:,round(10e-3/dt+(1:tau_pre/dt)))))/tau_pre - tau_m*sum(w)*f ) )])
%     %disp(['Estimated peak 2 = ' num2str( tau_m*sum(w)*f + (1-exp(-tau_pre/tau_m))*( tau_m*sum(sum(pattern{1}(:,round(80e-3/dt+(1:tau_pre/dt)))))/tau_pre - tau_m*sum(w)*f ) )])
% end


perf % computes hit rates, false alarm rate etc.
% estimate_SNR

toc

if batch_mode
    exit
end