// Learning basis functions to implement functions of one population-encoded
// variable using STDP.
// The model has an Input Layer (cellLayer[0]) and a Training Layer
// (cellLayer[1]), each consisting of spike sources, and projecting to an Output
// Layer (cellLayer[2]) consisting of integrate-and-fire neurons.
// The synaptic weights from Training-->Output are fixed.
// The synaptic weights from Input-->Output are plastic and obey a STDP rule.
// During training, the Input Layer receives input x, and the Training Layer
// input f(x). After training, the Training Layer is silent, and an input x to
// the Input Layer produces an output f(x) in the Output Layer.
// Uses the NetStimVR2 mechanism, rather than VecStimMs
// Andrew P. Davison, UNIC, CNRS, July 2004-May 2006
// Start the stopwatch; the stopsw() calls in the progress messages below
// report the time elapsed since this call.
startsw()
// CVode object; it is configured (variable time step) just before training
// is started at the end of this file.
objref cvode
cvode = new CVode()
// Project-local hoc files. NOTE(review): presumably these define the NetLayer,
// LayerConn and ObjectArray templates, the IntFire4nc cell wrapper and the
// weight-plotting procedures used below -- confirm against those files.
xopen("netLayer.hoc")
xopen("layerConn.hoc")
xopen("ObjectArray.hoc")
xopen("intfire4nc.hoc")
xopen("plotweights.hoc")
// =-=-= Create objects and strings =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// random: shared random number generator; fileobj[]: general-purpose output
// files; histfileobj: file that receives the weight histograms.
objref random, fileobj[3], histfileobj
// cellLayer[0..2]: Input, Training and Output layers.
// conn[0..2]: Input-->Output, Training-->Output and (only when f_winhib != 0)
// Output-->Output connections.
// spikecontrol: ControlNSVR2 object linked by pointer to NetStimVR2 globals.
objref cellLayer[3], conn[3], spikecontrol
// cellParams: parameter vector for the output cells.
// spikerec[layer]: arrays of spike-time Vectors for the Input (0) and
// Training (1) layers.
objref cellParams, spikerec[2]
// deltat_vec[layer][class]: post-pre spike-time differences, filled by
// calc_delta_t(); deltat_hist: scratch Vector for their histogram.
objref deltat_vec[2][3], deltat_hist
strdef fileroot, infile, filename, save_fileroot
strdef command, funcstr, label, datadir
// m[layer][class]: number of spike-time differences stored in deltat_vec.
double m[2][3]
// =-=-= Global Parameters =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
seed = 0 // Seed for the random number generators (Random and ControlNSVR2)
ncells = 30 // Number of cells (spike trains) per layer
pconnect = 1.0 // Connection probability
wmax = 0.02 // Maximum synaptic weight
f_winhib = 0.0 // Inhibitory weight = f_winhib*wmax (fixed); 0 disables Output-->Output connections
f_wtr = 1.0 // Max training weight = f_wtr*wmax
syndelay = 0.0 // Synaptic delay: if > 0 it is applied to Training-->Output, if < 0 its magnitude is applied to Input-->Output
tauLTP_StdwaSA = 20 // (ms) Time constant for LTP
tauLTD_StdwaSA = 20 // (ms) Time constant for LTD
B = 1.06 // B = (aLTD*tauLTD)/(aLTP*tau_LTP)
aLTP = 0.01 // Amplitude parameter for LTP
Rmax = 60 // (Hz) Peak firing rate of input distribution
Rmin = 0 // (Hz) Minimum input firing rate
Rsigma = 0.2 // Width parameter for input distribution
alpha = 1.0 // Gain of Training Layer rates compared to Input Layer
correlation_time = 20 // (ms)
bgRate = 1000 // (Hz) Firing rate for background activity
bgWeight = 0.02 // Weight for background activity
funcstr = "sin" // Label for function to be approximated ("", "mul", "sin", "sq", "asin" or "sinn")
nfuncparam = 1 // Number of parameters of function
double k[nfuncparam]
k[0] = 0.0 // Function parameter(s)
wtr_square = 1 // Sets square (non-zero) or bell-shaped (zero) profile for T-->O weights
wtr_sigma = 0.15 // Width parameter for Training-->Output weights
noise = 1 // Noise parameter
histbins = 100 // Number of bins for weight histograms
record_spikes = 0 // Whether or not to record spikes
wfromfile = 0 // if non-zero, read connections/weights from file
infile = "" // File root to read connections/weights from (used only if wfromfile is set)
tstop = 1e7 // (ms) Time at which training stops
trw = 1e5 // (ms) Time between weight printouts during training
numhist = 10 // Number of histograms between each weight printout
label = "bfstdp_demo_" // Extra label for labelling output files
datadir = "" // Sub-directory of Data for writing output files
tau_m = 20 // Membrane time constant
// =-=-= Create utility objects =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// Shared random number generator and the file used for weight histograms.
random = new Random(seed)
histfileobj = new File()
// Three general-purpose File objects.
for (i = 0; i < 3; i += 1) {
    fileobj[i] = new File()
}
// One array of ncells spike-time Vectors for each of the two source layers.
for (i = 0; i < 2; i += 1) {
    spikerec[i] = new ObjectArray(1,ncells,"Vector","")
}
// =-=-= Create the network =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
// Input spike trains are implemented using NetStimVR2s.
print "Creating network layers (time ", stopsw(), "s)"
// Parameter vector handed to every IntFire4nc output cell. Element 0 is the
// membrane time constant. NOTE(review): the meaning of elements 1-3 is defined
// by the IntFire4nc wrapper -- confirm in intfire4nc.hoc.
cellParams = new Vector(4)
cellParams.x[0] = tau_m
cellParams.x[1] = 5
cellParams.x[2] = 10
cellParams.x[3] = 15
// Create network layers
// Input Layer (0) and Training Layer (1): ncells NetStimVR2 spike sources each,
// with theta values i/ncells spread evenly over [0,1).
for layer = 0,1 {
    cellLayer[layer] = new NetLayer(1,ncells,"NetStimVR2",0.5)
    // Apply the global noise parameter (previously hard-coded to 1) so that
    // the value reported in the .param file is the value actually used.
    cellLayer[layer].set("noise",noise)
    for i = 0,ncells-1 {
        cellLayer[layer].cell[i].theta = i/ncells
    }
}
// Rate profile shared by all NetStimVR2 sources.
Rmax_NetStimVR2 = Rmax
Rmin_NetStimVR2 = Rmin
sigma_NetStimVR2 = Rsigma
// The Input Layer always uses transform 0 (with parameter 0).
cellLayer[0].set("transform",0)
cellLayer[0].set("prmtr",0)
// The Training Layer uses the transform selected by funcstr. An unrecognised
// label used to be ignored silently; now it is reported, and the transform is
// still left at whatever default the NetStimVR2 sources have.
if (strcmp(funcstr,"") == 0) {
    cellLayer[1].set("transform",0)
} else if (strcmp(funcstr,"mul") == 0) {
    cellLayer[1].set("transform",1)
} else if (strcmp(funcstr,"sin") == 0) {
    cellLayer[1].set("transform",2)
} else if (strcmp(funcstr,"sq") == 0) {
    cellLayer[1].set("transform",3)
} else if (strcmp(funcstr,"asin") == 0) {
    cellLayer[1].set("transform",4)
} else if (strcmp(funcstr,"sinn") == 0) {
    cellLayer[1].set("transform",5)
} else {
    printf("Warning: unrecognised funcstr \"%s\"; Training Layer transform left at its default\n",funcstr)
}
cellLayer[1].set("prmtr",k[0])
cellLayer[1].set("alpha",alpha)
// Controller object for the spike sources; its thetastim and tchange pointers
// are bound to the corresponding NetStimVR2 global variables.
spikecontrol = new ControlNSVR2(0.5)
spikecontrol.tau_corr = correlation_time
spikecontrol.seed(seed)
setpointer spikecontrol.thetastim, thetastim_NetStimVR2
setpointer spikecontrol.tchange, tchange_NetStimVR2
// Output Layer (2): ncells IntFire4nc cells configured from cellParams.
cellLayer[2] = new NetLayer(1,ncells,"IntFire4nc",cellParams)
// Create synaptic connections
print "Creating synaptic connections (time ", stopsw(), "s)"
// Put the shared RNG into uniform(0,1) mode before it is handed to LayerConn.
// NOTE(review): presumably LayerConn compares draws against pconnect -- confirm
// in layerConn.hoc. The order of RNG calls in this section affects the result.
random.uniform(0,1)
if (wfromfile) { // read connections from file
// conn[0] and conn[1] are read from <infile>.conn1.conn and <infile>.conn2.conn
for i = 0,1 {
sprint(filename,"%s.conn%d.conn",infile,i+1)
fileobj[0].ropen(filename)
conn[i] = new LayerConn(cellLayer[i],"",cellLayer[2],"syn",4,fileobj[0])
fileobj[0].close()
}
// NOTE(review): conn[2] is read from the same file name (<infile>.conn2.conn)
// as conn[1], whereas save_connections() writes conn[2] to *.conn3.conn --
// confirm whether this is intended.
if (f_winhib != 0) {
sprint(filename,"%s.conn2.conn",infile)
fileobj[0].ropen(filename)
conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",4,fileobj[0])
fileobj[0].close()
}
} else { // or generate them according to the rules specified
conn[0] = new LayerConn(cellLayer[0],"",cellLayer[2],"syn",1,pconnect,random) // 1 for all:all
// Switch the RNG to uniform(0,wmax) for the initial Input-->Output weights.
// The returned draw r is not used again in this file, but the call must stay:
// it sets the distribution and advances the random stream.
r = random.uniform(0,wmax)
conn[0].randomize_weights(random)
conn[1] = new LayerConn(cellLayer[1],"",cellLayer[2],"syn",1,pconnect,random)
// A negative syndelay delays the Input-->Output pathway by its magnitude;
// a positive one delays the Training-->Output pathway. Zero leaves the
// delays at whatever LayerConn set by default.
if (syndelay < 0) {
conn[0].set_delays(-1*syndelay)
conn[1].set_delays(0)
} else if (syndelay > 0) {
conn[0].set_delays(0)
conn[1].set_delays(syndelay)
}
// Optional fixed-weight Output-->Output connections, without self-connections.
if (f_winhib != 0) {
conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",1)
conn[2].remove_self_connections()
conn[2].set_weights(wmax*f_winhib)
}
}
// Turn on STDP for Input-->Output connections
print "Setting up STDP for Input-->Output connections (time ", stopsw(), "s)"
// Attach the StdwaSA weight-adjustment rule to the plastic pathway only.
conn[0].stdp("StdwaSA")
conn[0].set_max_weight(wmax)
conn[0].wa_set("aLTP",aLTP)
// aLTD follows from the definition B = (aLTD*tauLTD)/(aLTP*tauLTP).
conn[0].wa_set("aLTD",B*aLTP*tauLTP_StdwaSA/tauLTD_StdwaSA)
// Set background input
print "Setting background activity (time ", stopsw(), "s)"
// Argument string for each output cell's set_background call: weight, rate,
// then three fixed values. NOTE(review): the meaning of "0, 1, 1e12" is
// defined by set_background in the cell template -- confirm there.
sprint(command,"%f, %f, 0, 1, 1e12",bgWeight,bgRate)
cellLayer[2].call("set_background",command)
// Turn on recording of spikes
// When requested, record output-cell spikes and, for each source cell i, the
// spikes seen by its i-->i connection in the Input and Training pathways.
if (record_spikes) {
    cellLayer[2].call("record","1")
    for (i = 0; i < ncells; i += 1) {
        conn[0].nc[i][i].record(spikerec[0].x[i])
        conn[1].nc[i][i].record(spikerec[1].x[i])
    }
}
// =-=-= Procedures =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
// Utility procedures ----------------------------------------------------------
proc set_fileroot() { local p
    // Builds the global string fileroot used as the prefix of all output files:
    //   Data/<datadir>/<label><funcstr>-<k[0]>...-<k[n-1]>_<YYYYMMDD_HHMM>
    // The timestamp is obtained from the shell via a scratch file "starttime"
    // in the current directory; save_fileroot is used as temporary storage.
    system("date '+%Y%m%d_%H%M' > starttime")
    fileobj[0].ropen("starttime")
    fileobj[0].scanstr(save_fileroot)
    fileobj[0].close()

    sprint(fileroot,"Data/%s/%s%s",datadir,label,funcstr)
    // Append each function parameter, one decimal place each.
    for (p = 0; p < nfuncparam; p += 1) {
        sprint(fileroot,"%s-%3.1f",fileroot,k[p])
    }
    sprint(fileroot,"%s_%s",fileroot,save_fileroot)
    print "fileroot = ", fileroot
}
// Procedures to read input spike trains from file (none are defined here: input spikes are generated by NetStimVR2 sources)
// Procedures to set weights ---------------------------------------------------
proc set_training_weights() { local src, tgt, offset, halfwidth, peak
    // Sets the fixed Training-->Output weights (conn[1]) as a function of the
    // index offset between source cell src and target cell tgt, folded so that
    // its magnitude does not exceed ncells/2.
    //   wtr_square != 0 : weight f_wtr*wmax inside a window of half-width
    //                     wtr_sigma*ncells; weights outside it are not touched.
    //   wtr_square == 0 : bell-shaped profile of width wtr_sigma, peak f_wtr*wmax.
    // Only connections that actually exist are modified.
    peak = f_wtr*wmax
    halfwidth = wtr_sigma*ncells
    for (src = 0; src < ncells; src += 1) {
        for (tgt = 0; tgt < ncells; tgt += 1) {
            if (object_id(conn[1].nc[src][tgt])) {
                offset = src - tgt
                if (offset > ncells/2) { offset = ncells - offset }
                if (offset < -ncells/2) { offset = ncells + offset }
                if (wtr_square) {
                    if (offset >= -halfwidth && offset <= halfwidth) {
                        conn[1].nc[src][tgt].weight = peak
                    }
                } else {
                    conn[1].nc[src][tgt].weight = peak*exp( (cos(2*PI*offset/ncells) - 1) / (wtr_sigma*wtr_sigma) )
                }
            }
        }
    }
}
// Procedures for writing results to file --------------------------------------
// Helper for save_parameters(): writes one "name = value" line with a
// floating-point value to fileobj[0]. $s1 = parameter name, $2 = value.
proc write_float_param() {
    fileobj[0].printf("%-17s = %f\n",$s1,$2)
}
proc save_parameters() { local idx
    // Writes the global parameters to <fileroot>.param as "name = value"
    // lines, using fileobj[0]. infile is written only when wfromfile is set.
    sprint(filename,"%s.param",fileroot)
    fileobj[0].wopen(filename)
    fileobj[0].printf("// Parameters for bfstdp_nsvr2.hoc\n")
    // Integer-formatted parameters
    fileobj[0].printf("%-17s = %d\n","seed",seed)
    fileobj[0].printf("%-17s = %d\n","ncells",ncells)
    // Connectivity and plasticity
    write_float_param("pconnect",pconnect)
    write_float_param("wmax",wmax)
    write_float_param("f_winhib",f_winhib)
    write_float_param("f_wtr",f_wtr)
    write_float_param("syndelay",syndelay)
    write_float_param("tauLTP_StdwaSA",tauLTP_StdwaSA)
    write_float_param("tauLTD_StdwaSA",tauLTD_StdwaSA)
    write_float_param("B",B)
    write_float_param("aLTP",aLTP)
    // Input statistics and background activity
    write_float_param("Rmax",Rmax)
    write_float_param("Rmin",Rmin)
    write_float_param("Rsigma",Rsigma)
    write_float_param("alpha",alpha)
    write_float_param("correlation_time",correlation_time)
    write_float_param("bgWeight",bgWeight)
    write_float_param("bgRate",bgRate)
    // Target function and its parameters
    fileobj[0].printf("%-17s = \"%s\"\n","funcstr",funcstr)
    write_float_param("nfuncparam",nfuncparam)
    for (idx = 0; idx < nfuncparam; idx += 1) {
        fileobj[0].printf("%-14s[%d] = %f\n","k",idx,k[idx])
    }
    // Training weight profile, noise and membrane time constant
    write_float_param("wtr_square",wtr_square)
    write_float_param("wtr_sigma",wtr_sigma)
    write_float_param("noise",noise)
    write_float_param("tau_m",tau_m)
    if (wfromfile) {
        fileobj[0].printf("%-17s = \"%s\"\n","infile",infile)
    }
    fileobj[0].close()
}
proc print_rasters() { local lyr, cellIdx, spk, nspk
    // Writes recorded spike times to raster files through the File object $o1.
    // Does nothing unless record_spikes is set.
    //   <fileroot>.cell1.ras, <fileroot>.cell2.ras : Input and Training layers,
    //       one "time <tab> cell index" line per spike, with a blank line
    //       after each cell
    //   <fileroot>.cell3.ras : Output layer, written by cellLayer[2].print_spikes
    // Plot using
    // gnuplot> plot "<fileroot>.cell1.ras" u 1:2 w d
    if (!record_spikes) { return }
    for (lyr = 0; lyr < 2; lyr += 1) {
        sprint(filename,"%s.cell%d.ras",fileroot,lyr+1)
        $o1.wopen(filename)
        for (cellIdx = 0; cellIdx < ncells; cellIdx += 1) {
            nspk = spikerec[lyr].x[cellIdx].size()
            for (spk = 0; spk < nspk; spk += 1) {
                $o1.printf("%15.5g\t%d\n",spikerec[lyr].x[cellIdx].x[spk],cellIdx)
            }
            $o1.printf("\n")
        }
        $o1.close()
    }
    sprint(filename,"%s.cell3.ras",fileroot)
    $o1.wopen(filename)
    cellLayer[2].print_spikes($o1)
    $o1.close()
}
proc print_weights() { local which
    // Writes the weights of conn[$1] to <fileroot>.conn<$1+1>.w via fileobj[0].
    which = $1
    sprint(filename,"%s.conn%d.w",fileroot,which+1)
    fileobj[0].wopen(filename)
    conn[which].print_weights(fileobj[0])
    fileobj[0].close()
}
proc save_connections() { local c, nsaved
    // Saves each connection set to <fileroot>.conn<N>.conn (N = 1, 2, 3).
    // conn[2] (Output-->Output) is included only when f_winhib is non-zero.
    nsaved = 3
    if (f_winhib == 0) { nsaved = 2 }
    for (c = 0; c < nsaved; c += 1) {
        sprint(filename,"%s.conn%d.conn",fileroot,c+1)
        fileobj[0].wopen(filename)
        conn[c].save_connections(fileobj[0])
        fileobj[0].close()
    }
}
proc print_weight_distribution() {
    // Writes a histogram (histbins bins) of the plastic Input-->Output weights
    // to the already-open histfileobj. The other connection sets have fixed
    // weights, so no histogram is produced for them.
    // NOTE(review): the meaning of the final argument (1) is defined by
    // LayerConn.print_weight_hist -- confirm in layerConn.hoc.
    conn[0].print_weight_hist(histfileobj,histbins,1)
}
proc print_delta_t() { local i, ii, nbins, range, total_size
    // Writes histograms of the post-pre spike-time differences collected by
    // calc_delta_t() to <fileroot>.conn<layer+1>.deltat<class>, one file per
    // source layer (0,1) and distance class (0..2), as "bin centre <tab> count"
    // lines.
    //   $1  bin width (ms); must be positive
    //   $2  range: bins are centred on -range, -range+$1, ..., +range
    //   $3  if 1, divide counts by the total number of differences stored for
    //       the layer (all three classes together)
    // binwidth and layer are deliberately left as globals, as before.
    binwidth = $1 // ms
    range = $2
    if (binwidth <= 0) {
        printf("print_delta_t: bin width must be positive (got %g)\n",binwidth)
        return
    }
    // Bin count and lower edge are derived from binwidth so that bin i is
    // centred on the value printed for it. For binwidth = 1 this gives
    // 2*range+1 bins starting at -range-0.5, exactly as before.
    nbins = int(2*range/binwidth + 0.5) + 1
    deltat_hist = new Vector(nbins)
    for layer = 0,1 {
        total_size = deltat_vec[layer][0].size() + deltat_vec[layer][1].size() + deltat_vec[layer][2].size()
        for ii = 0,2 {
            deltat_hist.hist(deltat_vec[layer][ii],-range-binwidth/2,nbins,binwidth)
            // Skip normalisation when nothing was recorded (avoids dividing by zero).
            if ($3 == 1 && total_size > 0) deltat_hist.div(total_size)
            sprint(filename,"%s.conn%d.deltat%d",fileroot,layer+1,ii)
            fileobj[0].wopen(filename)
            for i = 0, nbins-1 { //print in a column
                fileobj[0].printf("%g\t%g\n",-range+binwidth*i,deltat_hist.x[i])
            }
            fileobj[0].close()
        }
    }
}
// Procedures that process recorded data ---------------------------------------
proc calc_delta_t() { local i,j,k,l,ii, nspikes_post, nspikes_pre, deltat, d
// Calculate the distribution of spike-time differences (post-pre)
// in three classes: connections for which d < 0.1, d < 0.2, d >= 0.2
// Arguments (all three are passed on unchanged to print_delta_t):
//   $1  histogram bin width
//   $2  only differences with |post-pre| < $2 are kept
//   $3  normalisation flag for the histograms
// Does nothing unless record_spikes is set. Results are stored in the global
// deltat_vec[layer][class], with counts in m[layer][class]; layer 0 is the
// Input Layer and layer 1 the Training Layer. The loop variable layer and
// tpost are globals. The local k hides the global parameter array k here.
if (record_spikes) {
// Start with generous buffers; they are doubled on demand below and trimmed
// to the number of stored values at the end.
for ii = 0,2 {
for layer = 0,1 {
deltat_vec[layer][ii] = new Vector(1e6)
m[layer][ii] = 0
}
}
// i: output cell, j: output spike, k: source cell, l: source spike
for i = 0,ncells-1 {
nspikes_post = cellLayer[2].cell[i].spiketimes.size()
if (nspikes_post > 0) {
for j = 0, nspikes_post-1 {
tpost = cellLayer[2].cell[i].spiketimes.x[j]
for k = 0,ncells-1 {
for layer = 0,1 {
// d compares the output cell's position i/ncells with the position
// associated with source cell k.
// NOTE(review): for the Input Layer the mapping (sin(2*PI*x)+1)/2 is
// hard-coded here and does not depend on funcstr -- confirm that this is
// only meant to be used with funcstr = "sin".
if (layer==0) {
d = i/ncells - (sin(2*PI*k/ncells)+1)/2
} else {
d = i/ncells - k/ncells
}
// Wrap into [-0.5, 0.5) and take the magnitude
if (d < -0.5) d += 1
if (d >= 0.5) d -= 1
d = abs(d)
// Select the distance class ii
if (d < 0.1) {
ii = 0
} else {
if (d < 0.2) {
ii = 1
} else {
ii = 2
}
}
nspikes_pre = spikerec[layer].x[k].size()
if (nspikes_pre > 0) {
for l = 0, nspikes_pre-1 {
deltat = tpost - spikerec[layer].x[k].x[l]
if (deltat < $2 && deltat > -1*$2) {
deltat_vec[layer][ii].x[m[layer][ii]] = deltat
m[layer][ii] += 1
// Double the buffer just before it becomes full.
// NOTE(review): size is written without parentheses here, unlike the
// other size() calls in this file.
if (m[layer][ii] >= deltat_vec[layer][ii].size()-1) {
deltat_vec[layer][ii].resize(2*deltat_vec[layer][ii].size)
printf("deltat_vec[%d][%d] resized\n",layer,ii)
}
}
}
}
}
}
}
}
}
printf("Spike pairs: %d,%d %d,%d %d,%d\n",m[0][0],m[1][0],m[0][1],m[1][1],m[0][2],m[1][2])
// Trim each buffer to the number of values actually stored
for ii = 0,2 {
deltat_vec[0][ii].resize(m[0][ii])
deltat_vec[1][ii].resize(m[1][ii])
}
print_delta_t($1,$2,$3)
}
}
// Procedures that run simulations ---------------------------------------------
proc run_training() { local i, j, fileopen, thist
// Training the network. The weight histogram is written to
// <fileroot>.conn1.whist every thist = int(trw/numhist) ms. The
// Input-->Output weights are written to file after every numhist
// such steps (i.e. about every trw ms). Parameters, initial weights
// and connections are saved before the run; final weights, histogram,
// connections and (if record_spikes is set) spike rasters are written
// at the end, followed by the spike-time-difference histograms.
// The local fileopen is declared but not used.
on_StdwaSA = 1
thist = int(trw/numhist)
sprint(filename,"%s.conn1.whist",fileroot)
histfileobj.wopen(filename)
save_parameters()
// save_fileroot keeps the base name; fileroot temporarily carries a
// "_<time>" suffix so that weight and connection files are time-stamped.
save_fileroot = fileroot
sprint(fileroot,"%s_%d",save_fileroot,0)
print_weights(0)
print_weights(1)
save_connections()
// i counts histogram steps since the last weight printout, j counts all steps
i = 0
j = 0
running_ = 1
stoprun = 0
setup_weight_plot()
finitialize(-65)
plot_weights(conn[0])
starttime = startsw()
// Each pass writes a histogram, then advances the simulation by thist ms
while (t < tstop && stoprun == 0) {
sprint(fileroot,"%s_%d",save_fileroot,j*thist)
print_weight_distribution()
if (i == numhist) {
print_weights(0)
i = 0
printf("--- Simulated %d seconds in %d seconds\r",int(t/1000),startsw()-starttime)
flushf()
}
i += 1
j += 1
while (t < j*thist) {
fadvance()
}
//continuerun(j*thist)
plot_weights(conn[0])
}
printf("--- Simulated %d seconds in %d seconds\n",int(t/1000),stopsw())
sprint(fileroot,"%s_%d",save_fileroot,j*thist)
print_weights(0)
print_weights(1) // for debugging. Should not have changed since t = 0
print_weight_distribution()
save_connections()
// Restore the base file root for the raster and delta-t files
fileroot = save_fileroot
// This corrects the pre-synaptic spiketimes for syndelay.
// This is necessary because nc.record records spike times at the source
// whereas we want to know them at the target.
if (syndelay < 0) {
for i = 0,ncells-1 {
spikerec[0].x[i].add(-1*syndelay)
}
} else if (syndelay > 0) {
for i = 0,ncells-1 {
spikerec[1].x[i].add(syndelay)
}
}
print_rasters(fileobj[0])
histfileobj.close()
print "Training complete. Time ", stopsw()
// Bin width 1.0, differences within +/-1000, no normalisation
calc_delta_t(1.0,1000,0)
}
// =-=-= Initialize the network =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-
// Build the output file prefix (fileroot) from datadir, label, funcstr, the
// function parameters and a timestamp.
set_fileroot()
cvode.active(1)
cvode.use_local_dt(1) // The variable time step method must be used.
cvode.condition_order(2) // Improves threshold-detection.
// Fix the Training-->Output weight profile before training starts.
set_training_weights()
//steps_per_ms = 10
//dt = 0.1
print "Finished set-up (time ", stopsw(), "s)"
print "Running training ..."
run_training()