%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% This script trains the read-out synapses %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Pre-allocate the spike recordings and the per-neuron state used by the
% training loop. The number of time samples is rounded once so that
% floating-point error in (tEnd - tStart)/tStep can never yield a non-integer
% array dimension; all other arrays reuse the resulting column count.
rast = zeros(neuronNum, round((tEnd - tStart)/tStep) + 1); %raster of the recurrent network: a spike of neuron j is stored as the value j
rast_binary = zeros(size(rast)); %same raster, but a spike is stored as 1
rast_A = zeros(actionNeuronNum, size(rast,2)); %spikes of read-out neurons (a spike of neuron j is stored as j)
rast_S = zeros(superNeuronNum, size(rast,2)); %spikes of supervisor neurons (a spike of neuron j is stored as j)
rast_H = zeros(interNeuronNum, size(rast,2)); %spikes of interneurons (binary)
% Time step of the most recent spike of each neuron, used for the refractory
% period check. Initialised to a negative step (just a big negative number).
lastAP = -50*ones(1,neuronNum);
lastAPA = -50*ones(1,actionNeuronNum);
lastAPS = -50*ones(1,superNeuronNum);
lastAPH = -50*ones(1,interNeuronNum);
% Membrane potentials, one row per neuron and one column per time step.
% Every entry starts at a random value between Vreset and V_T; the loop
% overwrites the columns it simulates, so column 1 acts as initial condition.
memVol = Vreset + (V_T - Vreset)*rand(neuronNum, size(rast,2));
memVolA = Vreset + (V_T - Vreset)*rand(actionNeuronNum, size(rast,2));
memVolS = Vreset + (V_T - Vreset)*rand(superNeuronNum, size(rast,2));
memVolH = Vreset + (V_T - Vreset)*rand(interNeuronNum, size(rast,2));
% For the supervisor: flags and start steps that keep track of when the
% stimulation begins and finishes.
begin = false;   %true while a stimulation window is open
begin_time = []; %time step at which each stimulation window was opened
finish = false;  %true once the current stimulation window has been closed
for i =2:(tEnd - tStart)/tStep
% Progress report: once every 1000 time units, display the elapsed simulated
% time (in ms). The time is computed from tStep instead of a hard-coded
% factor of 10, so the displayed value stays correct for any step size.
if mod(i,1000/tStep)==0
    i*tStep %left unsuppressed on purpose so that the value is displayed
end
% Synaptic input produced by the spikes of this time step. The buffers are
% filled while the populations are updated and are copied into the
% corresponding *Prev variables at the end of the step.
forwardInputsE = zeros(1,neuronNum);        %excitatory input to the recurrent network
forwardInputsI = zeros(1,neuronNum);        %inhibitory input to the recurrent network
forwardInputsAE = zeros(1,actionNeuronNum); %excitatory input to the read-out neurons
forwardInputsAI = zeros(1,actionNeuronNum); %inhibitory input to the read-out neurons
forwardInputsS = zeros(1,superNeuronNum);   %input to the supervisor neurons
forwardInputsH = zeros(1,interNeuronNum);   %input to the interneurons
%%%%%%%%%%%%%%%%%%%
%%%CLOCK NETWORK%%%
%%%%%%%%%%%%%%%%%%%
% Advance every neuron of the recurrent clock network by one time step.
% Neurons 1..EneuronNum follow the excitatory update, the remaining ones the
% inhibitory update.
for j = 1:neuronNum
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %EXTERNAL INPUT
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Consume every external event whose scheduled time has been passed and
    % schedule the next one after an exponentially distributed interval.
    while nextx(j) < i*tStep
        nextx(j) = nextx(j) + exprnd(1)/rx(j);
        if j > EneuronNum
            forwardInputsEPrev(j) = forwardInputsEPrev(j) + Jiex;
        else
            forwardInputsEPrev(j) = forwardInputsEPrev(j) + Jeex;
        end
    end

    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    %CONNECTIVITY CALCULATIONS
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Rise and decay traces of the excitatory and inhibitory synapses, driven
    % by the input collected during the previous step.
    xedecay(j) = xedecay(j) -tStep*xedecay(j)/tauedecay + forwardInputsEPrev(j);
    xerise(j) = xerise(j) -tStep*xerise(j)/tauerise + forwardInputsEPrev(j);
    xidecay(j) = xidecay(j) -tStep*xidecay(j)/tauidecay + forwardInputsIPrev(j);
    xirise(j) = xirise(j) -tStep*xirise(j)/tauirise + forwardInputsIPrev(j);
    % Conductances: difference of the two traces, scaled by the time constants
    gE = (xedecay(j) - xerise(j))/(tauedecay - tauerise);
    gI = (xidecay(j) - xirise(j))/(tauidecay - tauirise);

    vPrev = memVol(j,i-1); %membrane potential at the previous time step

    if j <= EneuronNum
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %EXCITATORY NEURONS
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        x(j) = x(j) - (tStep/tau_x)*x(j); %decay of the spike trace read by the plasticity rule
        w(j) = w(j) + (tStep/tau_w)*(a*(vPrev - V_E) - w(j)); %adaptation current
        EVthreshold(j) = EVthreshold(j) + (tStep/tau_T)*(V_T - EVthreshold(j)); %adapting threshold relaxes towards V_T

        %cell dynamics: leak, exponential spike term, synaptic and adaptation currents
        v = vPrev + (tStep/tau_E)*(-vPrev + V_E + DET*exp((vPrev-EVthreshold(j))/DET)) ...
            + (tStep/C)*(gE*(E_E - vPrev) + gI*(E_I - vPrev) - w(j));

        if i <= lastAP(j) + tau_abs/tStep %still inside the refractory period
            v = Vreset;
        end

        if v > Vthres %spike
            v = Vreset;
            lastAP(j) = i;
            rast_binary(j,i) = 1;
            rast(j,i) = j;
            % deliver the spike to the recurrent targets and to the read-out neurons
            forwardInputsE = forwardInputsE + [weightsEE(:,j)', weightsIE(:,j)'];
            forwardInputsAE = forwardInputsAE + W_AE(:,j)';
            % spike-triggered jumps of threshold, adaptation current and spike trace
            EVthreshold(j) = EVthreshold(j) + A_T;
            w(j) = w(j) + b;
            x(j) = x(j) + 1;
        end
    else
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %INHIBITORY NEURONS
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %cell dynamics: leak and synaptic currents only
        v = vPrev + (tStep/tau_I)*(-vPrev + V_I) + ...
            (tStep/C)*(gE*(E_E - vPrev) + gI*(E_I - vPrev));

        if i <= lastAP(j) + tau_abs/tStep %still inside the refractory period
            v = Vreset;
        end

        if v > V_T %spike (fixed threshold V_T for this population)
            v = Vreset;
            lastAP(j) = i;
            rast_binary(j,i) = 1;
            rast(j,i) = j;
            % inhibitory weight matrices are indexed relative to the first inhibitory neuron
            forwardInputsI = forwardInputsI + [weightsEI(:,j-EneuronNum)', weightsII(:,j-EneuronNum)'];
        end
    end

    memVol(j,i) = v; %store the new membrane potential
end
% Book-keeping for learning: decide whether the supervisor stimulation is on.
% The external input to the supervisor neurons is switched on when the
% sequential activity of the recurrent network starts (cluster 1 is the most
% active cluster) and switched off again once superLength has elapsed.
states = zeros(numClusters,1); %spike count of each excitatory cluster in the recent window
if i > 500 %burn-in of 500 steps (50 ms according to the original note)
    for k = 1:numClusters
        % number of spikes of cluster k during the last 101 time steps
        states(k) = nnz(rast_binary(1+(k-1)*EneuronNum/numClusters:k*EneuronNum/numClusters, i-100:i));
    end
    temp = find(states == max(states), 1); %first cluster that attains the maximal count

    % open a stimulation window when the first cluster leads and none is open
    if ~begin && temp == 1
        begin_time(end+1,1) = i;
        begin = true;
        finish = false;
    end

    % close the window once it has been open for more than superLength
    if begin && ~finish && (i - begin_time(end)) > superLength/tStep
        begin = false;
        finish = true;
    end
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% SUPERVISOR NEURONS %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Advance every supervisor neuron by one time step. A spiking supervisor
% neuron j adds Jas to the excitatory input of read-out neuron j.
for j = 1:superNeuronNum
    % Rate of the external input: baseline by default, taken from the
    % supervisor signal while a stimulation window is open.
    rsex = 1;
    if begin && ~finish
        rsex = supervisor(j,i-begin_time(end)+1);
    end

    % consume every external event that is due and schedule the next one
    while nextxS(j) < i*tStep
        nextxS(j) = nextxS(j) + exprnd(1)/rsex;
        forwardInputsSPrev(j) = forwardInputsSPrev(j) + Jsex;
    end

    % excitatory synaptic traces and resulting conductance
    xedecayS(j) = xedecayS(j) -tStep*xedecayS(j)/tauedecay + forwardInputsSPrev(j);
    xeriseS(j) = xeriseS(j) -tStep*xeriseS(j)/tauerise + forwardInputsSPrev(j);
    gE = (xedecayS(j) - xeriseS(j))/(tauedecay - tauerise);

    vPrev = memVolS(j,i-1); %membrane potential at the previous time step
    wS(j) = wS(j) + (tStep/tau_wS)*(aS*(vPrev - V_E) - wS(j)); %adaptation current
    EVthresholdS(j) = EVthresholdS(j) + (tStep/tau_T)*(V_T - EVthresholdS(j)); %adapting threshold

    %voltage dynamics
    v = vPrev + (tStep/tau_E)*(-vPrev + V_E + DET*exp((vPrev-EVthresholdS(j))/DET)) ...
        + (tStep/C)*( gE*(E_E - vPrev) - wS(j));

    if i <= lastAPS(j) + tau_absS/tStep %still inside the refractory period
        v = Vreset;
    end

    if v > Vthres %spike
        v = Vreset;
        lastAPS(j) = i;
        rast_S(j,i) = j;
        forwardInputsAE(j) = forwardInputsAE(j) + Jas;
        EVthresholdS(j) = EVthresholdS(j) + A_T;
        wS(j) = wS(j) + bS;
    end

    memVolS(j,i) = v;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% INTER NEURON %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Advance every interneuron by one time step. A spiking interneuron j adds
% jAH to the inhibitory input of read-out neuron j.
for j = 1:interNeuronNum
    % consume every external event that is due and schedule the next one
    while i*tStep > nextxH(j)
        nextxH(j) = nextxH(j) + exprnd(1)/rhx;
        forwardInputsHPrev(j) = forwardInputsHPrev(j) + jHX;
    end

    % excitatory synaptic traces and resulting conductance
    xeriseH(j) = xeriseH(j) -tStep*xeriseH(j)/tauerise + forwardInputsHPrev(j);
    xedecayH(j) = xedecayH(j) -tStep*xedecayH(j)/tauedecay + forwardInputsHPrev(j);
    gE = (xedecayH(j) - xeriseH(j))/(tauedecay - tauerise);

    %cell dynamics
    v = memVolH(j,i-1) + (tStep/tau_I)*(-memVolH(j,i-1) + V_I) + ...
        (tStep/C)*( gE*(E_E - memVolH(j,i-1)) );

    % Refractory period. lastAPH holds one entry per interneuron, so it has to
    % be indexed with j: testing the whole vector is only true when ALL
    % entries satisfy the condition, and assigning without an index would
    % replace the vector by a scalar shared by every interneuron.
    if ((lastAPH(j) + tau_absH/tStep)>=i)
        v = Vreset;
    end

    if (v > V_T) %Fire if exceed threshold
        v = Vreset;
        lastAPH(j) = i; %remember the spike time of this interneuron only
        rast_H(j,i) = 1;
        forwardInputsAI(j) = forwardInputsAI(j) + jAH;
    end

    memVolH(j,i) = v;
end
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%% READ_OUT NEURONS %%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Advance every read-out neuron by one time step and then update its
% incoming weights W_AE(j,:) with the voltage-based plasticity rule.
for j = 1:actionNeuronNum
    vPrev = memVolA(j,i-1); %membrane potential at the previous time step

    %vSTDP: two low-pass filtered copies of the membrane potential
    u(j) = u(j) + (vPrev - u(j))*tStep/tau_u;
    vs(j) = vs(j) + (vPrev - vs(j))*tStep/tau_vs;

    % synaptic traces (excitatory, then inhibitory) and resulting conductances
    xedecayA(j) = xedecayA(j) -tStep*xedecayA(j)/tauedecay + forwardInputsAEPrev(j);
    xeriseA(j) = xeriseA(j) -tStep*xeriseA(j)/tauerise + forwardInputsAEPrev(j);
    xidecayA(j) = xidecayA(j) - tStep*xidecayA(j)/tauidecay + forwardInputsAIPrev(j);
    xiriseA(j) = xiriseA(j) - tStep*xiriseA(j)/tauirise + forwardInputsAIPrev(j);
    gE = (xedecayA(j) - xeriseA(j))/(tauedecay - tauerise);
    gI = (xidecayA(j) - xiriseA(j))/(tauidecay - tauirise);

    wA(j) = wA(j) + (tStep/tau_wA)*(aA*(vPrev - V_E) - wA(j)); %adaptation current
    EVthresholdA(j) = EVthresholdA(j) + (tStep/tau_T)*(V_T - EVthresholdA(j)); %adapting threshold

    %voltage dynamics
    v = vPrev + (tStep/tau_E)*(-vPrev + V_E + DET*exp((vPrev-EVthresholdA(j))/DET)) ...
        + (tStep/C)*(gE*(E_E - vPrev) + gI*(E_I - vPrev) - wA(j));

    if i <= lastAPA(j) + tau_absA/tStep %still inside the refractory period
        v = Vreset;
    end

    if v > Vthres %spike
        v = Vreset;
        lastAPA(j) = i;
        rast_A(j,i) = j;
        forwardInputsH(j) = forwardInputsH(j) + jHA; %read-out neuron j drives interneuron j
        EVthresholdA(j) = EVthresholdA(j) + A_T;
        wA(j) = wA(j) + bA;
    end

    memVolA(j,i) = v;

    %%%%%%%%%%%%%%%%%%%%%%
    %%% PLASTICITY %%%
    %%%%%%%%%%%%%%%%%%%%%%
    % Potentiation amplitude shrinks linearly as a weight approaches w_max
    A_LTPCORR = A_LTP*(w_max - W_AE(j,:))/w_max;
    % Potentiation: spike trace x gated by the rectified new voltage and the
    % rectified filtered voltage vs
    LTP = A_LTPCORR.*x*max(memVolA(j,i)-th_LTP,0)*max(vs(j)-th_LTD,0);
    % Depression: spikes emitted by the excitatory neurons in this step, gated
    % by the rectified filtered voltage u
    LTD = A_LTD*rast_binary(1:EneuronNum,i)'*max(u(j)-th_LTD,0);
    W_AE(j,:) = W_AE(j,:) + tStep*(LTP - LTD);
    W_AE(j, W_AE(j,:) < 0) = 0; %minimum weight is zero
end
% Hand the input collected during this time step over to the next one: the
% *Prev variables are the ones read by the synaptic trace updates.
[forwardInputsEPrev, forwardInputsIPrev] = deal(forwardInputsE, forwardInputsI);
[forwardInputsAEPrev, forwardInputsAIPrev] = deal(forwardInputsAE, forwardInputsAI);
[forwardInputsSPrev, forwardInputsHPrev] = deal(forwardInputsS, forwardInputsH);
end