#include <math.h>
#include "mex.h"
/* Generic two-argument max/min helpers, only defined if no other header
   already provides them. Each argument may be evaluated twice. */
#ifndef MAX
#define MAX(A, B) ((B) < (A) ? (A) : (B))
#endif
#ifndef MIN
#define MIN(A, B) ((B) > (A) ? (A) : (B))
#endif
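/*
 * MEX simulation of a single model cell (parameters from Noonan et al. 2003)
 * that receives one plastic input weight per parallel-fibre (PF) time segment
 * of a periodic cycle. Detected 2-spike and 4-spike bursts depress the weights
 * of PFs active near each burst, and every time step all weights relax toward
 * wmax. Entry point: mexFunction(); simulation core: timeloop().
 */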
/* Revised June 26, 2011 by Kieran Bol - kieran_bol@hotmail.com */
/* Map a (non-negative) spike time into [T, 2T), where T is one period, so a
   burst can act on PF spikes in the previous, current and next period that are
   stored in pfspike (3 periods). */
static double burst_phase_time(double t, double period)
{
    return fmod(t, period) + period;
}

/* Fill L[0..n) with the rectified quadratic learning window
   max(0, 1 - ((pfspike[j] - burstT) / width)^2). */
static void learning_window(double *L, const double *pfspike, mwSize n,
                            double burstT, double width)
{
    mwSize j;
    for (j = 0; j < n; j++) {
        L[j] = 1 - pow((pfspike[j] - burstT) / width, 2);
        if (L[j] < 0) {
            L[j] = 0;
        }
    }
}

/* Depress each weight w[j] by rate * w[j] * (window summed over that PF's
   3 periods), clipping at 0. L must hold 3*numw entries. */
static void depress_weights(double *w, const double *L, mwSize numw, double rate)
{
    mwSize j;
    for (j = 0; j < numw; j++) {
        w[j] = w[j] - rate * w[j] * (L[j] + L[j + numw] + L[j + 2 * numw]);
        if (w[j] < 0) {
            w[j] = 0;
        }
    }
}

/* Inverse of depress_weights (ignoring any clipping that happened): divide by
   (1 - rate * summed window). A zero divisor yields inf, as before. */
static void undo_depression(double *w, const double *L, mwSize numw, double rate)
{
    mwSize j;
    for (j = 0; j < numw; j++) {
        w[j] = w[j] / (1 - rate * (L[j] + L[j + numw] + L[j + 2 * numw]));
        if (w[j] < 0) {
            w[j] = 0;
        }
    }
}

/*
 * Simulation core: integrates the cell for m time steps of size delt
 * (time is in units of tau_m throughout), records spikes and 2-/4-spike
 * bursts, and applies the burst learning rules plus continuous potentiation
 * to the weights.
 *
 * Outputs:
 *   sptime      spike times; entries 0..3 are -100 sentinels, spikes start at 4.
 *   bursttime4  first-spike time of each detected 4-spike burst.
 *   bursttime2  first-spike time of each (surviving) 2-spike burst.
 *   w           final weights (numw entries), initialised from weight.
 *   The three time buffers must each hold (mwSize)(m*delt) entries, which is
 *   how mexFunction sizes them; overflow is reported with mexErrMsgTxt.
 * Inputs:
 *   signal      external drive, one value per time step (m entries).
 *   weight      initial weights (numw entries, numw >= 1).
 *   rec         unused.
 *   f           cycle frequency; multiplied by tau_m, so presumably Hz with
 *               tau_m in seconds (TODO confirm with caller).
 *   g, Lambda   gains in dv/dt = -v + signal + Lambda*(w[k] - g*v).
 *   eta, eta2   learning rates of the 4-spike and 2-spike rules.
 *   tau_m       time unit (15 ms is written as 0.015/tau_m below).
 *   tau_w       potentiation time constant, same units as tau_m.
 *   delt        step size in units of tau_m.
 *   wmax        value the weights relax toward.
 */
void timeloop(double sptime[], double bursttime4[], double bursttime2[],
double w[], double signal[], double weight[], double rec[],double f, double g, double Lambda,
double eta, double eta2, mwSize numw, mwSize m, double tau_m, double tau_w, double delt, double wmax)
{
    /*Parameters from Noonan et al. 2003*/
    double A=0.15*4, B=2, alpha=20,
    beta=0.35, D=0.1, E=3.5, tauref=0.1,taudend=50., b=0,
    somawidth=0.05*4, dendwidth=1.0, taudecay=1., thresh=1.0;
    double realt, Lwidth, Lwidth2, Bdef4, Bdef2, nper, period, burstT, *pfspike, *L;
    double v=0.025, tref=0, Dxh, Dwh, Dyh, Dsh, Dx=0, Dy=0, Ds=0, Dw=0;
    mwSize i, k, j, sp_capacity;
    mwSize index=4, index4=0, index2=0;
    /* Spikes still needed before the next burst may be recorded. Must be
       signed: they keep decrementing past 0 until a burst resets them, and an
       unsigned mwSize would wrap to SIZE_MAX and disable detection forever. */
    int count4=4, count2=2;

    (void) rec; /* kept for interface compatibility; not used */

    if (numw < 1) {
        mexErrMsgTxt("weight must contain at least one element.");
    }
    /* Same capacity the caller allocates: spsize = endi*delt. */
    sp_capacity = (mwSize) (m * delt);
    if (sp_capacity < 4) {
        mexErrMsgTxt("Output buffers too short: length(signal)*delt must be at least 4.");
    }

    /*PF initialization: scratch buffers (zero-filled), freed at the end*/
    pfspike = mxCalloc(3*numw, sizeof(double));
    L = mxCalloc(3*numw, sizeof(double));

    /*Burst definition and width*/
    Bdef2=0.015/tau_m; /*2-spike burst: 15 ms expressed in units of tau_m*/
    Bdef4=3.0*Bdef2;   /*4-spike burst has 3 ISIs so 3 x 2-spike burst definition */
    Lwidth=0.1/tau_m;  /*100 ms 4-spike burst learning rule width*/
    Lwidth2=0.01/tau_m; /*10 ms 2-spike learning rule width*/
    tau_w=tau_w/tau_m;
    f=f*tau_m;
    nper=1/f/delt;       /*number of time-steps in 1 period*/
    period=nper*delt;    /*1 period in units of tau_m*/

    /*Mapping given weight distribution to dynamic weight matrix w */
    for (i=0;i<numw;i++) {
        w[i]=weight[i];
    }
    /* pfspike gives start (firing) times of each PF over 3 periods;
       nper/numw*delt = T/numw = time span of each segment */
    for (i=0;i<3*numw;i++) {
        pfspike[i]=ceil(nper/numw*i)*delt;
    }
    /* Sentinel "previous spikes" far in the past so index-1..index-5 are valid */
    sptime[0]=-100; sptime[1]=-100; sptime[2]=-100; sptime[3]=-100;

    /*----Time loop----*/
    for (i=0;i<m;i++) {
        realt=delt*i; /*in units of tau_m, not seconds*/
        /* Index of the PF segment active now; wraps to 0 each period.
           fmod (double) rather than fmodf: the argument grows with i and
           float precision would misassign segments in long runs. */
        k=(mwSize) floor(fmod(numw*realt/nper/delt,numw));

        if (realt>tref) {
            /*GOVERNING EQUATION*/
            v=v+delt*(-v + signal[i] + Lambda*(w[k]-g*v));
            /* Add the rectified dendritic after-potential (DAP) only if the
               last ISI exceeds the dendritic refractory period */
            if (taudend<(sptime[index-1]-sptime[index-2])) {
                if (dendwidth*Dx-somawidth*Ds > 0) {
                    v=v+delt*alpha*(dendwidth*Dx-somawidth*Ds);
                }
            }
        }

        /*if V > threshold, a spike is fired and a burst might be recorded*/
        if (v>thresh) {
            if (index>=sp_capacity) {
                mexErrMsgTxt("Spike buffer full: average firing rate exceeds 1/tau_m.");
            }
            v=0;
            tref=realt+tauref+delt/2; /*refractory period is updated*/
            sptime[index]=realt;
            index++;
            count4--; /*COMMENT OUT THIS LINE TO REMOVE DETECTION OF 4-SPIKE BURSTS*/
            count2--; /*COMMENT OUT THIS LINE TO REMOVE DETECTION OF 2-SPIKE BURSTS*/
            /*DAP parameters are updated*/
            b=b+A+B*b*b;
            taudend=D+E*b;
            dendwidth=beta*b;
            Dy=Dy+1/dendwidth/dendwidth;
            Dw=Dw+1/somawidth/somawidth;

            /* 2-spike burst: previous spike within Bdef2 and no spike shared
               with an earlier burst. Depress weights (pre-post and post-pre,
               since future PF times are known). */
            if ((realt-sptime[index-2]<Bdef2) && (count2<1)) {
                bursttime2[index2]=sptime[index-2];
                index2++;
                count2=2;
                burstT=burst_phase_time(sptime[index-2],period);
                learning_window(L,pfspike,3*numw,burstT,Lwidth2);
                depress_weights(w,L,numw,eta2);
            }

            /* 4-spike burst: 4th-last spike within Bdef4 */
            if ((realt-sptime[index-4]<Bdef4) && (count4<1)) {
                bursttime4[index4]=sptime[index-4];
                index4++;
                count4=4;  /*no overlapping 4-spike bursts*/
                count2=2;  /*2-spike burst can't use last spike of a 4-spike burst*/
                /* Undo every 2-spike burst that started inside this 4-spike
                   burst, so it is not double counted. index2 > 0 guards
                   against reading bursttime2[-1] once all are removed. */
                while (index2>0 && bursttime4[index4-1]-bursttime2[index2-1] < delt) {
                    burstT=burst_phase_time(bursttime2[index2-1],period);
                    learning_window(L,pfspike,3*numw,burstT,Lwidth2);
                    undo_depression(w,L,numw,eta2);
                    index2--;
                }
                /* Undo a 2-spike burst that began on the spike just before
                   this 4-spike burst (i.e. used its 1st spike as its 2nd) */
                if (index2>0 && bursttime2[index2-1]==sptime[index-5]) {
                    burstT=burst_phase_time(bursttime2[index2-1],period);
                    learning_window(L,pfspike,3*numw,burstT,Lwidth2);
                    undo_depression(w,L,numw,eta2);
                    index2--;
                }
                /* 4-spike rule; the wide window can reach one PF in several
                   of the 3 stored periods, and is clipped beyond them */
                burstT=burst_phase_time(sptime[index-4],period);
                learning_window(L,pfspike,3*numw,burstT,Lwidth);
                depress_weights(w,L,numw,eta);
            }
        }

        /*Dendritic alpha function, integrated as a 2nd-order system*/
        Dxh=Dx+delt*Dy;
        Dyh=Dy+delt*(-Dx/(dendwidth*dendwidth)-2*Dy/dendwidth);
        Dx=Dxh;
        Dy=Dyh;
        /*Somatic alpha function*/
        Dsh=Ds+delt*Dw;
        Dwh=Dw+delt*(-Ds/(somawidth*somawidth)-2*Dw/somawidth);
        Ds=Dsh;
        Dw=Dwh;
        b=b+delt*(-b/taudecay); /*b decays*/
        /*Potentiation rule: every weight relaxes toward wmax*/
        for (j=0;j<numw;j++) {
            w[j]=w[j]+delt/tau_w*(wmax-w[j]);
        }
    }/*----end of time loop----*/

    mxFree(pfspike);
    mxFree(L);
} /*end of function*/
/*
 * MATLAB gateway. Call as
 *   [sptime, bursttime2, bursttime4, w, rec] =
 *       fn(signal, f, weight, g, Lambda, eta, eta2, tau_m, tau_w, delt, wmax)
 * where signal and weight are real double vectors and the rest are scalars.
 * sptime/bursttime2/bursttime4/rec have floor(length(signal)*delt) rows,
 * which assumes an average firing rate of at most 1/tau_m; w has the shape
 * of weight. rec is returned all zeros.
 */
void mexFunction( int nlhs, mxArray *plhs[],
int nrhs, const mxArray *prhs[] )
{
    double g, wmax, eta, eta2, delt, f, *rec, *sptime, *bursttime4,
           Lambda, *signal, *weight, *w, *bursttime2, tau_m, tau_w;
    mwSize mrows, ncols, mrows2, ncols2, endi, spsize, numw;
    mxArray *out[5];
    int nout, idx;

    /* Check for proper number of arguments. */
    if (nrhs!=11) {
        mexErrMsgTxt("Eleven inputs required.");
    } else if (nlhs>5) {
        mexErrMsgTxt("Too many output arguments");
    }
    /* mxGetPr is only meaningful for real double arrays */
    if (!mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) ||
        !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2])) {
        mexErrMsgTxt("signal and weight must be real double arrays.");
    }

    signal=mxGetPr(prhs[0]);
    mrows=mxGetM(prhs[0]);
    ncols=mxGetN(prhs[0]);
    endi=MAX(mrows,ncols);
    f=mxGetScalar(prhs[1]);
    weight=mxGetPr(prhs[2]);
    g=mxGetScalar(prhs[3]);
    Lambda=mxGetScalar(prhs[4]);
    eta=mxGetScalar(prhs[5]);
    eta2=mxGetScalar(prhs[6]);
    tau_m=mxGetScalar(prhs[7]);
    tau_w=mxGetScalar(prhs[8]);
    delt=mxGetScalar(prhs[9]);
    wmax=mxGetScalar(prhs[10]);
    mrows2=mxGetM(prhs[2]);
    ncols2=mxGetN(prhs[2]);
    numw=MAX(mrows2,ncols2);

    if (numw<1) {
        mexErrMsgTxt("weight must contain at least one element.");
    }
    /* Non-positive values make the period/segment arithmetic undefined and
       the buffer-size conversion below negative */
    if (!(f>0) || !(tau_m>0) || !(delt>0)) {
        mexErrMsgTxt("f, tau_m and delt must be positive.");
    }

    /* Output buffers; sized assuming avg firing rate <= 1/tau_m (timeloop
       reports an error rather than overflowing them) */
    spsize=(mwSize) (endi*delt);
    out[0]=mxCreateDoubleMatrix(spsize,1,mxREAL);       /* sptime */
    out[1]=mxCreateDoubleMatrix(spsize,1,mxREAL);       /* bursttime2 */
    out[2]=mxCreateDoubleMatrix(spsize,1,mxREAL);       /* bursttime4 */
    out[3]=mxCreateDoubleMatrix(mrows2,ncols2,mxREAL);  /* w */
    out[4]=mxCreateDoubleMatrix(spsize,1,mxREAL);       /* rec */

    sptime=mxGetPr(out[0]);
    bursttime2=mxGetPr(out[1]);
    bursttime4=mxGetPr(out[2]);
    w=mxGetPr(out[3]);
    rec=mxGetPr(out[4]);

    /* Call the timeloop subroutine. */
    timeloop(sptime,bursttime4,bursttime2,w,signal,weight,rec,f,g,Lambda,eta,eta2,numw,endi,tau_m,tau_w,delt,wmax);

    /* plhs only has room for max(nlhs,1) entries: hand back those outputs
       and destroy the ones the caller did not request. */
    nout=(nlhs<1) ? 1 : nlhs;
    for (idx=0; idx<5; idx++) {
        if (idx<nout) {
            plhs[idx]=out[idx];
        } else {
            mxDestroyArray(out[idx]);
        }
    }
}