// Learning basis functions to implement functions of one population-encoded
// variable using STDP.

// The model has an Input Layer (cellLayer[0]) and a Training Layer
// (cellLayer[1]), each consisting of spike sources, and projecting to an Output
// Layer (cellLayer[2]) consisting of integrate-and-fire neurons.

// The synaptic weights from Training-->Output are fixed.
// The synaptic weights from Input-->Output are plastic and obey a STDP rule.

// During training, the Input Layer receives input x, and the Training Layer
// input f(x). After training, the Training Layer is silent, and an input x to
// the Input Layer produces an output f(x) in the Output Layer.

// Uses the NetStimVR2 mechanism, rather than VecStimMs

// Andrew P. Davison, UNIC, CNRS, July 2004-May 2006

// Start the wall-clock stopwatch; the "(time ..., s)" progress messages below
// print stopsw(), the time since the last startsw() call.
startsw()
objref cvode
cvode = new CVode()
// Load helper code that is used below but not defined in this file.
// Presumably these provide the NetLayer, LayerConn and ObjectArray templates,
// the IntFire4nc cell, and setup_weight_plot()/plot_weights() -- confirm.
xopen("netLayer.hoc")
xopen("layerConn.hoc")
xopen("ObjectArray.hoc")
xopen("intfire4nc.hoc")
xopen("plotweights.hoc")

// =-=-= Create objects and strings  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

// random: RNG for connectivity/initial weights; fileobj[]: general output
// files (only fileobj[0] is used in this file); histfileobj: weight-histogram
// output file.
objref random, fileobj[3], histfileobj
// cellLayer[0] = Input, [1] = Training, [2] = Output.
// conn[0] = Input-->Output (plastic), conn[1] = Training-->Output (fixed),
// conn[2] = Output-->Output inhibition (only created if f_winhib != 0).
// spikecontrol: ControlNSVR2 object that sets the NetStimVR2 globals
// thetastim/tchange through pointers.
objref cellLayer[3], conn[3], spikecontrol
// cellParams: IntFire4nc parameter vector; spikerec[layer].x[cell]: recorded
// spike times of the Input (0) and Training (1) spike sources.
objref cellParams, spikerec[2]
// deltat_vec[layer][class]: post-pre spike-time differences per source layer
// and distance class; deltat_hist: histogram workspace (see calc_delta_t).
objref deltat_vec[2][3], deltat_hist
strdef fileroot, infile, filename, save_fileroot
strdef command, funcstr, label, datadir
// m[layer][class]: number of entries filled in deltat_vec[layer][class].
double m[2][3]

// =-=-= Global Parameters =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

seed             = 0           // Seed for the random number generator
ncells           = 30          // Number of input spike trains per layer
pconnect         = 1.0         // Connection probability
wmax             = 0.02        // Maximum synaptic weight
f_winhib         = 0.0         // Inhibitory weight = f_winhib*wmax (fixed); 0 = no Output-->Output connections
f_wtr            = 1.0         // Max training weight = f_wtr*wmax
syndelay         = 0.0         // Synaptic delay: >0 delays Training-->Output, <0 delays Input-->Output by |syndelay|
tauLTP_StdwaSA   = 20          // (ms) Time constant for LTP
tauLTD_StdwaSA   = 20          // (ms) Time constant for LTD
B                = 1.06        // B = (aLTD*tauLTD)/(aLTP*tauLTP); used to derive aLTD
aLTP             = 0.01        // Amplitude parameter for LTP
Rmax             = 60          // (Hz) Peak firing rate of input distribution
Rmin             = 0           // (Hz) Minimum input firing rate
Rsigma           = 0.2         // Width parameter for input distribution
alpha            = 1.0         // Gain of Training Layer rates compared to Input Layer
correlation_time = 20          // (ms) Assigned to the ControlNSVR2 parameter tau_corr
bgRate           = 1000        // (Hz) Firing rate for background activity
bgWeight         = 0.02        // Weight for background activity
funcstr          = "sin"       // Label for function to be approximated ("", "mul", "sin", "sq", "asin", "sinn")
nfuncparam       = 1           // Number of parameters of function
double k[nfuncparam]
k[0]             = 0.0         // Function parameter(s)
wtr_square       = 1           // T-->O weight profile: nonzero = square window, 0 = bell-shaped (exp of cos)
wtr_sigma        = 0.15        // Width parameter for Training-->Output weights (fraction of ncells)
noise            = 1           // Noise parameter
histbins         = 100         // Number of bins for weight histograms
record_spikes    = 0           // Whether or not to record spikes
wfromfile        = 0           // if nonzero, read connections/weights from file
infile           = ""          // File prefix to read connections/weights from (<infile>.conn<n>.conn)
tstop            = 1e7         // (ms) Duration of training
trw              = 1e5         // (ms) Time between weight printouts
numhist          = 10          // Number of histograms between each weight printout
label            = "bfstdp_demo_" // Extra label for labelling output files
datadir          = ""          // Sub-directory of Data for writing output files
tau_m            = 20          // Membrane time constant

// =-=-= Create utility objects  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

// Seeded random stream and the File objects used for output.
random = new Random(seed)
histfileobj = new File()
for (i = 0; i <= 2; i += 1) {
  fileobj[i] = new File()
}
// One spike-time Vector per source cell, for each of the two spike-source
// layers (0 = Input, 1 = Training).
for (i = 0; i <= 1; i += 1) {
  spikerec[i] = new ObjectArray(1,ncells,"Vector","")
}

// =-=-= Create the network  =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=

// Input spike trains are implemented using NetStimVR2s.
print "Creating network layers (time ", stopsw(), "s)"

// Parameter vector for the IntFire4nc Output cells. Element 0 is the membrane
// time constant tau_m; elements 1-3 are passed through unchanged (their
// meaning is defined by the IntFire4nc template -- confirm there).
cellParams = new Vector(4)
cellParams.x[0] = tau_m
cellParams.x[1] = 5
cellParams.x[2] = 10
cellParams.x[3] = 15

// Create the two spike-source layers: Input (0) and Training (1). Each has
// ncells NetStimVR2 sources whose preferred values theta tile [0,1) evenly.
for layer = 0,1 {
  cellLayer[layer] = new NetLayer(1,ncells,"NetStimVR2",0.5)
  // Use the global `noise` parameter. It used to be hard-coded to 1, so
  // changing `noise` (which is written to the .param file) had no effect.
  cellLayer[layer].set("noise",noise)
  for i = 0,ncells-1 {
    cellLayer[layer].cell[i].theta = i/ncells
  }
}
// NetStimVR2 GLOBAL parameters, shared by all spike sources.
Rmax_NetStimVR2 = Rmax
Rmin_NetStimVR2 = Rmin
sigma_NetStimVR2 = Rsigma

// The Input Layer uses transform 0 (the same code selected for an empty
// funcstr). The Training Layer transform code is selected from funcstr.
cellLayer[0].set("transform",0)
cellLayer[0].set("prmtr",0)
if (strcmp(funcstr,"") == 0) {
  cellLayer[1].set("transform",0)
} else if (strcmp(funcstr,"mul") == 0) {
  cellLayer[1].set("transform",1)
} else if (strcmp(funcstr,"sin") == 0) {
  cellLayer[1].set("transform",2)
} else if (strcmp(funcstr,"sq") == 0) {
  cellLayer[1].set("transform",3)
} else if (strcmp(funcstr,"asin") == 0) {
  cellLayer[1].set("transform",4)
} else if (strcmp(funcstr,"sinn") == 0) {
  cellLayer[1].set("transform",5)
} else {
  // An unrecognised label used to leave the transform silently unset.
  printf("Warning: unknown funcstr \"%s\"; Training Layer transform left unchanged\n",funcstr)
}
cellLayer[1].set("prmtr",k[0])
cellLayer[1].set("alpha",alpha)

// ControlNSVR2 sets the NetStimVR2 globals thetastim/tchange through
// pointers; tau_corr is given the correlation_time parameter.
spikecontrol = new ControlNSVR2(0.5)
spikecontrol.tau_corr = correlation_time
spikecontrol.seed(seed)
setpointer spikecontrol.thetastim, thetastim_NetStimVR2
setpointer spikecontrol.tchange, tchange_NetStimVR2

// Output Layer: ncells IntFire4nc integrate-and-fire cells.
cellLayer[2] = new NetLayer(1,ncells,"IntFire4nc",cellParams)

// Create synaptic connections
print "Creating synaptic connections (time ", stopsw(), "s)"

// Select a uniform(0,1) distribution for `random` before it is handed to
// LayerConn (presumably used for pconnect-based connection sampling).
random.uniform(0,1)
if (wfromfile) { // read connections from file
  // Files are named <infile>.conn<n>.conn, matching what save_connections()
  // writes. Mode 4 of LayerConn presumably means "read from File" -- confirm.
  for i = 0,1 {
    sprint(filename,"%s.conn%d.conn",infile,i+1)
    fileobj[0].ropen(filename)
    conn[i] = new LayerConn(cellLayer[i],"",cellLayer[2],"syn",4,fileobj[0])
    fileobj[0].close()
  }
  // Output-->Output inhibition is only present when f_winhib != 0.
  if (f_winhib != 0) {
    sprint(filename,"%s.conn2.conn",infile)
    fileobj[0].ropen(filename)
    conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",4,fileobj[0])
    fileobj[0].close()
  }
} else {         // or generate them according to the rules specified
  conn[0] = new LayerConn(cellLayer[0],"",cellLayer[2],"syn",1,pconnect,random) // 1 for all:all
  // Switch to a uniform(0,wmax) distribution for the initial Input-->Output
  // weights (presumably drawn by randomize_weights). The value stored in r
  // is not used in this file.
  r = random.uniform(0,wmax)
  conn[0].randomize_weights(random)
  conn[1] = new LayerConn(cellLayer[1],"",cellLayer[2],"syn",1,pconnect,random)
  // A negative syndelay delays Input-->Output by |syndelay|; a positive one
  // delays Training-->Output. With syndelay == 0 no delays are set here.
  if (syndelay < 0) {
    conn[0].set_delays(-1*syndelay)
    conn[1].set_delays(0)
  } else if (syndelay > 0) {
    conn[0].set_delays(0)
    conn[1].set_delays(syndelay)
  }
  // Optional Output-->Output inhibition without self-connections; every
  // weight is set to wmax*f_winhib.
  if (f_winhib != 0) {
    conn[2] = new LayerConn(cellLayer[2],"syn",cellLayer[2],"syn",1)
    conn[2].remove_self_connections()
    conn[2].set_weights(wmax*f_winhib)
  }
}

// Turn on STDP for Input-->Output connections
// Only conn[0] is made plastic; the Training-->Output weights stay fixed.
print "Setting up STDP for Input-->Output connections (time ", stopsw(), "s)"
conn[0].stdp("StdwaSA")
conn[0].set_max_weight(wmax)
conn[0].wa_set("aLTP",aLTP)
// aLTD is chosen so that (aLTD*tauLTD)/(aLTP*tauLTP) == B.
conn[0].wa_set("aLTD",B*aLTP*tauLTP_StdwaSA/tauLTD_StdwaSA)

// Set background input
// Calls set_background on the Output Layer (presumably on every cell) with
// the argument string "bgWeight, bgRate, 0, 1, 1e12". The meaning of the
// trailing 0, 1, 1e12 arguments is defined by the cell template, not here.
print "Setting background activity (time ", stopsw(), "s)"
sprint(command,"%f, %f, 0, 1, 1e12",bgWeight,bgRate)
cellLayer[2].call("set_background",command)

// Turn on recording of spikes
// Output cells are asked to record their own spikes (read back later via
// cell[i].spiketimes). For source cell i of the Input (spikerec[0]) and
// Training (spikerec[1]) layers, spike times are recorded from the NetCon
// conn[.].nc[i][i].
// NOTE(review): nc[i][i] may be nil when pconnect < 1 (set_training_weights
// checks object_id() for this case) -- TODO guard before enabling recording.
if (record_spikes) {
  cellLayer[2].call("record","1")
  for i = 0,ncells-1 {
    conn[0].nc[i][i].record(spikerec[0].x[i])
    conn[1].nc[i][i].record(spikerec[1].x[i])
  }
}


// =-=-= Procedures =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

// Utility procedures ----------------------------------------------------------

proc set_fileroot() { local n
  // Build the global output prefix
  //   fileroot = "Data/<datadir>/<label><funcstr>-<k[0]>...-<k[nfuncparam-1]>_<YYYYmmdd_HHMM>"
  // The timestamp comes from running `date` into the scratch file
  // "starttime" and reading it back into save_fileroot.
  system("date '+%Y%m%d_%H%M' > starttime")
  fileobj[0].ropen("starttime")
  fileobj[0].scanstr(save_fileroot)
  fileobj[0].close()

  sprint(fileroot,"Data/%s/%s%s",datadir,label,funcstr)
  n = 0
  while (n < nfuncparam) {
    sprint(fileroot,"%s-%3.1f",fileroot,k[n])
    n += 1
  }
  sprint(fileroot,"%s_%s",fileroot,save_fileroot)
  print "fileroot = ", fileroot
}

// Procedures to read input spike trains from file -----------------------------

// Procedures to set weights ---------------------------------------------------

proc set_training_weights() { local row, col, dist, halfwidth
  // Set the fixed Training-->Output weights conn[1].nc[row][col] from the
  // circular (ring of ncells) index distance dist between row and col:
  //   wtr_square != 0 : f_wtr*wmax when |dist| <= wtr_sigma*ncells;
  //                     weights outside that window are left untouched
  //   wtr_square == 0 : f_wtr*wmax*exp((cos(2*PI*dist/ncells) - 1)/wtr_sigma^2)
  // Nil NetCons (absent connections) are skipped.
  halfwidth = wtr_sigma*ncells
  for (row = 0; row < ncells; row += 1) {
    for (col = 0; col < ncells; col += 1) {
      if (object_id(conn[1].nc[row][col])) {
        // Wrap the index difference onto the ring; only |dist| matters below.
        dist = row - col
        if (dist > ncells/2) {
          dist = ncells - dist
        } else if (dist < -ncells/2) {
          dist = ncells + dist
        }
        if (wtr_square) {
          if (abs(dist) <= halfwidth) {
            conn[1].nc[row][col].weight = f_wtr*wmax
          }
        } else {
          conn[1].nc[row][col].weight = f_wtr*wmax*exp((cos(2*PI*dist/ncells) - 1)/(wtr_sigma*wtr_sigma))
        }
      }
    }
  }
}

// Procedures for writing results to file --------------------------------------

proc save_parameters() { local i
  // Write the simulation parameters to "<fileroot>.param" as hoc-style
  // "name = value" lines so that a run can be documented and reproduced.
  // Integer-valued settings are written with %d, the others with %f.
  sprint(filename,"%s.param",fileroot)
  fileobj[0].wopen(filename)
  fileobj[0].printf("// Parameters for bfstdp_nsvr2.hoc\n")
  fileobj[0].printf("%-17s = %d\n","seed",seed)
  fileobj[0].printf("%-17s = %d\n","ncells",ncells)
  fileobj[0].printf("%-17s = %f\n","pconnect",pconnect)
  fileobj[0].printf("%-17s = %f\n","wmax",wmax)
  fileobj[0].printf("%-17s = %f\n","f_winhib",f_winhib)
  fileobj[0].printf("%-17s = %f\n","f_wtr",f_wtr)
  fileobj[0].printf("%-17s = %f\n","syndelay",syndelay)
  fileobj[0].printf("%-17s = %f\n","tauLTP_StdwaSA",tauLTP_StdwaSA)
  fileobj[0].printf("%-17s = %f\n","tauLTD_StdwaSA",tauLTD_StdwaSA)
  fileobj[0].printf("%-17s = %f\n","B",B)
  fileobj[0].printf("%-17s = %f\n","aLTP",aLTP)
  fileobj[0].printf("%-17s = %f\n","Rmax",Rmax)
  fileobj[0].printf("%-17s = %f\n","Rmin",Rmin)
  fileobj[0].printf("%-17s = %f\n","Rsigma",Rsigma)
  fileobj[0].printf("%-17s = %f\n","alpha",alpha)
  fileobj[0].printf("%-17s = %f\n","correlation_time",correlation_time)
  fileobj[0].printf("%-17s = %f\n","bgWeight",bgWeight)
  fileobj[0].printf("%-17s = %f\n","bgRate",bgRate)
  fileobj[0].printf("%-17s = \"%s\"\n","funcstr",funcstr)
  fileobj[0].printf("%-17s = %d\n","nfuncparam",nfuncparam)
  for i = 0, nfuncparam-1 {
    fileobj[0].printf("%-14s[%d] = %f\n","k",i,k[i])
  }
  fileobj[0].printf("%-17s = %f\n","wtr_square",wtr_square)
  fileobj[0].printf("%-17s = %f\n","wtr_sigma",wtr_sigma)
  fileobj[0].printf("%-17s = %f\n","noise",noise)
  fileobj[0].printf("%-17s = %f\n","tau_m",tau_m)
  // Run-control and recording settings; without these a run cannot be
  // reproduced from the .param file alone.
  fileobj[0].printf("%-17s = %f\n","tstop",tstop)
  fileobj[0].printf("%-17s = %f\n","trw",trw)
  fileobj[0].printf("%-17s = %d\n","numhist",numhist)
  fileobj[0].printf("%-17s = %d\n","histbins",histbins)
  fileobj[0].printf("%-17s = %d\n","record_spikes",record_spikes)
  fileobj[0].printf("%-17s = %d\n","wfromfile",wfromfile)
  if (wfromfile) {
    fileobj[0].printf("%-17s = \"%s\"\n","infile",infile)
  }
  fileobj[0].close()
}

proc print_rasters() { local i, j, s
  // Write recorded spike times to raster files (does nothing unless
  // record_spikes != 0):
  //   <fileroot>.cell1.ras  Input Layer sources    ("time<TAB>cell" lines,
  //   <fileroot>.cell2.ras  Training Layer sources  blank line after each cell)
  //   <fileroot>.cell3.ras  Output Layer cells, via cellLayer[2].print_spikes()
  // $o1 : File object used for writing; it is opened and closed here.
  // Plot using
  //   gnuplot> plot "<fileroot>.cell1.ras" u 1:2 w d

  if (record_spikes) {
    for i = 0,1 {
      sprint(filename,"%s.cell%d.ras",fileroot,i+1)
      $o1.wopen(filename)
      for j = 0,ncells-1 {
        // Spike index is `s`, not `k`, to avoid shadowing the global array k[].
        for s = 0,spikerec[i].x[j].size()-1 {
          $o1.printf("%15.5g\t%d\n",spikerec[i].x[j].x[s],j)
        }
        $o1.printf("\n")
      }
      $o1.close()
    }
    sprint(filename,"%s.cell3.ras",fileroot)
    $o1.wopen(filename)
    cellLayer[2].print_spikes($o1)
    $o1.close()
  }
}

proc print_weights() { local which
  // Write the current weights of conn[$1] to "<fileroot>.conn<$1+1>.w".
  //   $1 : index into conn[] (0 = Input-->Output, 1 = Training-->Output)
  which = $1
  sprint(filename,"%s.conn%d.w",fileroot,which+1)
  fileobj[0].wopen(filename)
  conn[which].print_weights(fileobj[0])
  fileobj[0].close()
}

proc save_connections() { local c, nconn
  // Save each existing LayerConn to "<fileroot>.conn<c+1>.conn", the naming
  // that is read back when wfromfile != 0. conn[2] (Output-->Output
  // inhibition) exists only when f_winhib != 0.
  if (f_winhib == 0) {
    nconn = 2
  } else {
    nconn = 3
  }
  c = 0
  while (c < nconn) {
    sprint(filename,"%s.conn%d.conn",fileroot,c+1)
    fileobj[0].wopen(filename)
    conn[c].save_connections(fileobj[0])
    fileobj[0].close()
    c += 1
  }
}

proc print_weight_distribution() {
  // Write a histogram (histbins bins) of the plastic Input-->Output weights
  // conn[0] to the already-open histfileobj. No histogram is written for
  // conn[1] (fixed Training-->Output weights) or conn[2] (fixed inhibitory
  // Output-->Output weights), since those do not change during training.
  conn[0].print_weight_hist(histfileobj,histbins,1)
}

proc print_delta_t() { local i, ii, layer, histbins, range, total_size, binwidth
  // Write histograms of the post-pre spike-time differences collected by
  // calc_delta_t() in deltat_vec[layer][ii] to the files
  //   <fileroot>.conn<layer+1>.deltat<ii>
  // as two columns: bin centre (ms) and count (or fraction, see $3).
  //   $1 : bin width (ms)
  //   $2 : half-range (ms); bins are centred on -$2, -$2+$1, ..., up to +$2
  //   $3 : if 1, divide each histogram by the total number of spike pairs
  //        for that layer (all three distance classes together)
  binwidth = $1
  range = $2
  // Enough bins, centred on multiples of binwidth, to span [-range, range].
  // For binwidth = 1 this is 2*range+1 bins whose lower edge is -range-0.5,
  // as before; the old code assumed binwidth = 1 for any $1.
  histbins = int(2*range/binwidth + 0.5) + 1
  deltat_hist = new Vector(histbins)
  for layer = 0,1 {
    total_size = deltat_vec[layer][0].size() + deltat_vec[layer][1].size() + deltat_vec[layer][2].size()
    for ii = 0,2 {
      deltat_hist.hist(deltat_vec[layer][ii],-range-binwidth/2,histbins,binwidth)
      // Skip normalisation when no pairs were found (avoid dividing by zero).
      if ($3 == 1 && total_size > 0) deltat_hist.div(total_size)
      sprint(filename,"%s.conn%d.deltat%d",fileroot,layer+1,ii)
      fileobj[0].wopen(filename)
      for i = 0, histbins-1 { // one "centre<TAB>value" line per bin
        fileobj[0].printf("%g\t%g\n",-range+binwidth*i,deltat_hist.x[i])
      }
      fileobj[0].close()
    }
  }
}

// Procedures that process recorded data ---------------------------------------

proc calc_delta_t() { local i,j,k,l,ii,layer, nspikes_post, nspikes_pre, deltat, d, tpost
  // Collect the spike-time differences deltat = tpost - tpre with
  // |deltat| < $2 for every (Output cell i, source cell k) pair of both
  // source layers (layer 0 = Input, layer 1 = Training). Pairs are sorted by
  // the circular distance d (on [0,1)) between the Output cell's position
  // i/ncells and the position expected for source cell k:
  //   class 0: d < 0.1,  class 1: 0.1 <= d < 0.2,  class 2: d >= 0.2
  // Results go into deltat_vec[layer][class] (m[layer][class] entries) and
  // are then histogrammed by print_delta_t($1, $2, $3).
  //   $1 : histogram bin width (ms)
  //   $2 : maximum |deltat| (ms)
  //   $3 : normalisation flag passed on to print_delta_t
  // Does nothing unless record_spikes != 0.
  // NOTE(review): for the Input Layer the expected position of source cell k
  // is hard-coded as (sin(2*PI*k/ncells)+1)/2, i.e. it assumes funcstr is
  // "sin"; the classes are presumably meaningless for other functions.
  // `k` here is a local scalar that shadows the global parameter array k[].
  if (record_spikes) {
    for ii = 0,2 {
      for layer = 0,1 {
        deltat_vec[layer][ii] = new Vector(1e6) // grown on demand below
        m[layer][ii] = 0
      }
    }
    for i = 0,ncells-1 {
      nspikes_post = cellLayer[2].cell[i].spiketimes.size()
      if (nspikes_post > 0) {
        for j = 0, nspikes_post-1 {
          tpost = cellLayer[2].cell[i].spiketimes.x[j]
          for k = 0,ncells-1 {
            for layer = 0,1 {
              if (layer==0) {
                d  = i/ncells - (sin(2*PI*k/ncells)+1)/2
              } else {
                d = i/ncells - k/ncells
              }
              // Wrap onto the unit ring and take the absolute distance.
              if (d < -0.5) d += 1
              if (d >= 0.5) d -= 1
              d = abs(d)
              if (d < 0.1) {
                ii = 0
              } else {
                if (d < 0.2) {
                  ii = 1
                } else {
                  ii = 2
                }
              }
              nspikes_pre = spikerec[layer].x[k].size()
              if (nspikes_pre > 0) {
                for l = 0, nspikes_pre-1 {
                  deltat = tpost - spikerec[layer].x[k].x[l]
                  if (deltat < $2 && deltat > -1*$2) {
                    deltat_vec[layer][ii].x[m[layer][ii]] = deltat
                    m[layer][ii] += 1
                    // Double the storage before it runs out.
                    if (m[layer][ii] >= deltat_vec[layer][ii].size()-1) {
                      deltat_vec[layer][ii].resize(2*deltat_vec[layer][ii].size())
                      printf("deltat_vec[%d][%d] resized\n",layer,ii)
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
    printf("Spike pairs: %d,%d  %d,%d  %d,%d\n",m[0][0],m[1][0],m[0][1],m[1][1],m[0][2],m[1][2])
    // Trim each vector to the number of entries actually filled.
    for ii = 0,2 {
      deltat_vec[0][ii].resize(m[0][ii])
      deltat_vec[1][ii].resize(m[1][ii])
    }
    print_delta_t($1,$2,$3)
  }
}



// Procedures that run simulations ---------------------------------------------

proc run_training() { local i, j, thist, starttime, setup_time
  // Train the network (STDP on the Input-->Output connections conn[0]).
  // Every thist = int(trw/numhist) ms a histogram of the conn[0] weights is
  // written to "<fileroot>.conn1.whist" and the weight plot is updated. Every
  // numhist such steps (i.e. about every trw ms) the conn[0] weights are
  // written to "<fileroot>_<time>.conn1.w" and progress is printed.
  // Weights and connections are also saved at t = 0 and at the end. After
  // the run, presynaptic spike times are corrected for syndelay, rasters are
  // written, and spike-time-difference histograms are computed.
  // Timing: startsw() restarts the stopwatch read by stopsw(), so elapsed
  // times are measured as startsw()-starttime throughout.

  on_StdwaSA = 1
  thist = int(trw/numhist)
  if (thist < 1) {
    // With thist == 0 the loop below would never advance t.
    execerror("run_training: trw/numhist must be at least 1 ms")
  }

  sprint(filename,"%s.conn1.whist",fileroot)
  histfileobj.wopen(filename)

  save_parameters()
  save_fileroot = fileroot
  sprint(fileroot,"%s_%d",save_fileroot,0)
  print_weights(0)
  print_weights(1)
  save_connections()

  i = 0
  j = 0

  running_ = 1
  stoprun = 0
  setup_weight_plot()
  finitialize(-65)
  plot_weights(conn[0])
  // Keep the time elapsed since program start before startsw() restarts the
  // stopwatch, so the final message can still report the total.
  setup_time = stopsw()
  starttime = startsw()
  while (t < tstop && stoprun == 0) {
    sprint(fileroot,"%s_%d",save_fileroot,j*thist)
    print_weight_distribution()
    if (i == numhist) {
      print_weights(0)
      i = 0
      printf("--- Simulated %d seconds in %d seconds\r",int(t/1000),startsw()-starttime)
      flushf()
    }
    i += 1
    j += 1
    while (t < j*thist) {
      fadvance()
    }
    //continuerun(j*thist)
    plot_weights(conn[0])
  }
  // Use the same clock as the progress line (stopsw() would only measure the
  // time since the last progress message).
  printf("--- Simulated %d seconds in %d seconds\n",int(t/1000),startsw()-starttime)

  sprint(fileroot,"%s_%d",save_fileroot,j*thist)
  print_weights(0)
  print_weights(1) // for debugging. Should not have changed since t = 0
  print_weight_distribution()
  save_connections()

  fileroot = save_fileroot

  // This corrects the pre-synaptic spiketimes for syndelay.
  // This is necessary because nc.record records spike times at the source
  // whereas we want to know them at the target.

  if (syndelay < 0) {
    for i = 0,ncells-1 {
      spikerec[0].x[i].add(-1*syndelay)
    }
  } else if (syndelay > 0) {
    for i = 0,ncells-1 {
      spikerec[1].x[i].add(syndelay)
    }
  }

  print_rasters(fileobj[0])

  histfileobj.close()
  // Total time since program start, consistent with the earlier "(time ...)" messages.
  print "Training complete. Time ", setup_time + startsw() - starttime
  calc_delta_t(1.0,1000,0)
}

// =-=-= Initialize the network =-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-

// Build the output-file prefix (this also writes a scratch file "starttime").
set_fileroot()
cvode.active(1)
cvode.use_local_dt(1)         // The variable time step method must be used.
cvode.condition_order(2)      // Improves threshold-detection.
// Set the fixed Training-->Output weights. This runs after any weights were
// read from file, so it overwrites those within its profile.
set_training_weights()
//steps_per_ms = 10
//dt = 0.1

print "Finished set-up (time ", stopsw(), "s)"

print "Running training ..."

// Run STDP training, save results and compute spike-timing statistics.
run_training()