%% Code to FORCE Train an interneuron network to oscillate at a frequency theta_int, while receiving a secondary frequency of theta_ms
clear all
clc
%close all
rng(1) %fix RNG just for refersion interneurons down the road (see documentation), user can change this otherwise.
%% Neuronal Parameters
NE = 2000; %number of excitatory neurons
NI = 2000; %number of inhibitory neuron
mu = NE/(NE+NI); %ratio of E to total
T = 40; %total simulation time (s)
N = NE + NI;
dt = 0.00005; %time step (s)
nt = round(T/dt);
tref = 0.002; %Refractory time constant in seconds
tm = 0.01; %Membrane time constant
vreset = -65; %Voltage reset
vpeak = -40; %Voltage peak.
td = 0.02; %decay time;
tr = 0.002; %rise time;
%%
lambda = dt*0.05; %Sets the rate of weight change, too fast is unstable, too slow is bad as well.
Pinv = eye(N)*lambda; %initialize the correlation weight matrix for RLMS
p = 0.1; %Set the network sparsity
%% Coupling Weight Matrix
disp('Initializing Weight Matrix')
G = 0.1; %Static Weight Magnitude
W = 15; %Recurrent Weight Magnitude
AW = 10; %Magnitude of Oscillatory Inputs
p = 0.1;
CE = round(p*NE); %Number of E connections
CI = round(p*NI); %Number of I connectiosn
tic
%create weight matrix
tempE = zeros(1,NE); tempE(1,1:CE) = 1;
tempI = zeros(1,NI); tempI(1,1:CI) = -1;
for i = 1:1:N
OMEGA(i,:) = G*[(sqrt(CI)/sqrt(CE))*tempE(randperm(NE))/sqrt(CE),tempI(randperm(NI))/sqrt(CI)];
end
toc
%% Kill EE and EI connections.
OMEGA(NE+1:N,1:NE)=0;
OMEGA(1:NE,1:NE)=0;
%% Input from medial septum
input = -(1+cos(2*pi*(1:1:nt)*dt*8)); %8hz input.
% Input weight
WIN(1:N,1) = AW;
WIN(1:NE,1) = 0; %Kill GABAergic inputs to Excitatory Neurons.
%% Basis and supervisor
nb = 100; %number of oscillators in the basis
for k = 1:nb
zx(k,:) = (cos(2*pi*(1:1:nt)*dt*8.5 + 2*pi*rand)); %theta_int, 8.5 hz.
end
%% FORCE parameters
imin = round(1/dt); %start RLS
icrit = round(21/dt); %stop RLS
step = 10; %implement RLS every step interval
%% initialization parameters for the network
k = min(size(zx)); %size of supervisor
IPSC = zeros(N,1); %post synaptic current storage variable
h = zeros(N,1); %Storage variable for filtered firing rates
r = zeros(N,1); %second storage variable for filtered rates
hr = zeros(N,1); %Third variable for filtered rates
JD = 0*IPSC; %storage variable required for each spike time
tspike = zeros(4*nt,2); %Storage variable for spike times
ns = 0; %Number of spikes, counts during simulation
z = zeros(k,1); %Initialize the approximant
BPhi = zeros(N,k); %The initial matrix that will be learned by FORCE method
v = vreset + rand(N,1)*(vpeak-vreset); %Initialize neuronal voltage with random distribtuions
v_ = v; %v_ is the voltage at previous time steps
RECB = zeros(nt,10); %Storage matrix for the synaptic weights (a subset of them)
kd = 0;
%% Encoders
E = zeros(N,k);
for j = 1:N
in = ceil(k*rand);
E(j,in) = W;
end
%% Storage matrices
nq = 20;
REC = zeros(nt,20);
REC2 = zeros(round(nt/nq),N);
current = zeros(nt,k); %storage variable for output current/approximant
i = 1;
%% auxiliary parameters to implement FORCE training in a plausible way.
z1 = z; z2 = z;
dec = zeros(N,2);
BPhi1 = 0*BPhi; BPhi2 = 0*BPhi;
EPlus = E;
EPlus(EPlus<0) = 0;
EMinus = E - EPlus;
mask1 = [ones(NE,k); -ones(NI, k)];
kd = 0;
tlast = zeros(N,1); %This vector is used to set the refractory times
%Parameter used to compute the histogram/population activity online
bin = zeros(round(nt/round(0.001/dt)),1);
binI = bin;
BIAS(1:NE,1)= -40; %Background current to Excitatory neurons
BIAS(NE+1:N,1)= 10; %Background current to Inhibitory neurons
bs = 0; bsI = 0;
ks = 0;
%% START INTEGRATION
for i = 1:1:nt
z1 = BPhi1' * r;
z2 = BPhi2' * r;
I = IPSC + BIAS + EPlus*z1 + EMinus*z2 + WIN.*(input(i))*(dt*i<30) ; %Current to Neurons
dv = (dt*i>tlast + tref).*( (-v+I)/tm) ; %Voltage equation with refractory period
v = v + dt*(dv);
index = find(v>=vpeak); %Find the neurons that have spiked
%Store spike times, and get the weight matrix column sum of spikers
if length(index)>0
JD = sum(OMEGA(:,index),2); %compute the increase in current due to spiking
tspike(ns+1:ns+length(index),:) = [index,0*index+dt*i]; %store spike times
ns = ns + length(index); % total number of spikes so far
bs = bs + length(index(index<NE));
bsI = bsI + length(index(index>=NE));
end
tlast = tlast + (dt*i -tlast).*(v>=vpeak); %Used to set the refractory period of LIF neurons
% Code if the rise time is 0, and if the rise time is positive
if tr == 0
IPSC = IPSC*exp(-dt/td)+ JD*(length(index)>0)/(td);
r = r *exp(-dt/td) + (v>=vpeak)/td;
else
IPSC = IPSC*exp(-dt/tr) + h*dt;
h = h*exp(-dt/td) + JD*(length(index)>0)/(tr*td); %Integrate the current
r = r*exp(-dt/tr) + hr*dt;
hr = hr*exp(-dt/td) + (v>=vpeak)/(tr*td);
end
%Implement RLS
z = BPhi'*r; %approximant
err = z - zx(:,i); %error
%% RLMS
if mod(i,step)==1
if i > imin
if i < icrit
cd = Pinv*r;
BPhi = BPhi - (cd*err')/(1+(r')*cd);
Pinv = Pinv -((cd)*(cd'))/( 1 + (r')*(cd));
BPhiEP = BPhi(1:NE,:).*(BPhi(1:NE,:)>0);
BPhiEM = BPhi(1:NE,:).*(BPhi(1:NE,:)<0);
BPhiIP = BPhi(NE+1:N,:).*(BPhi(NE+1:N,:)>0);
BPhiIM = BPhi(NE+1:N,:).*(BPhi(NE+1:N,:)<0);
BPhi1 = [BPhiEP;BPhiIM*mu/(1-mu)];
BPhi2 = [BPhiEM;BPhiIP*mu/(1-mu)];
end
end
end
v = v + (30 - v).*(v>=vpeak); %rest the voltage and apply a cosmetic spike.
REC(i,:) = [v(1:10);v(NE+1:NE+10)]; %Record a random voltage
v = v + (vreset - v).*(v>=vpeak); %reset spike time
current(i,:) = z; %store oscillators
RECB(i,:) = BPhi(NE+1:NE+10); %store 10 inhibitory weights
% Store filtered spike trains for 10 neurons.
if mod(i,nq)==1
kd = kd + 1;
REC2(kd,:) = v;
end
% compute histogram
if mod(i,round(0.001/dt))==1
ks = ks + 1;
bin(ks) = bs;
binI(ks) = bsI;
bs = 0; bsI = 0;
end
%% plotting results
if mod(i,round(0.5/dt))==1
prog = dt*i/T
drawnow
figure(100)
%plot voltage traces
for j = 1:1:20
if j > 10
plot((1:1:i)*dt,REC(1:1:i,j)/(30-vreset)+j,'b'), hold on
else
plot((1:1:i)*dt,REC(1:1:i,j)/(30-vreset)+j,'r'), hold on
end
end
xlabel('Time (s)')
ylabel('Voltage (mv)')
ylim([0,21])
hold off
%plot supervisor and decoded network approximant
figure(200)
for ffd = 1:3
plot(dt*(1:1:i),zx(ffd,1:1:i)/(max(zx(ffd,:))-min(zx(ffd,:)))+ffd,'k','LineWidth',2), hold on
plot(dt*(1:1:i),current(1:1:i,ffd)/(max(zx(ffd,:))-min(zx(ffd,:)))+ffd,'LineWidth',2)
end
xlim([dt*i-1,dt*i])
xlabel('Time (s)')
ylabel('Decoded Oscillators, \theta_{int}')
hold off
%plot histograms
figure(90)
plot(dt*i*(1:ks)/ks,bin(1:ks),'r'), hold on
plot(dt*i*(1:ks)/ks,binI(1:ks),'b'), hold off
xlabel('Time (s)')
ylim([0,100])
ylabel('Population Activity')
%plot the decoders as they are being learned.
figure(5)
plot(dt*(1:1:i),RECB(1:1:i,1:10),'.')
xlabel('Time (s)')
ylabel('Decoders')
end
end
%% save the data, this file will be about 1-2gig's (depending on what network size, much larger for O(10^3)>neurons).
save force_trained.mat -v7.3
%%
clear all
clc
sorting_script