// $Id: autotune.hoc,v 1.1 2011/04/17 17:40:22 samn Exp $

//* sim duration
// tstop, mytstop and htmax are all set to the same value
// (presumably in ms, i.e. 6000 s of simulated time -- TODO confirm units)
tstop=mytstop=htmax=6000e3

//* load sim

// all four gain variables start at 1; they are assigned before the sim files are loaded
// NOTE(review): presumably read by the loaded files as E->E, E->I, I->E, I->I gains -- confirm
EEGain = EIGain = IEGain = IIGain = 1

// flag assigned before loading; NOTE(review): presumably checked by the loaded files -- confirm
autotune = 1 

// load order matters: later files and the code below rely on names (col, CTYPi, ice, ...)
// that are not defined in this file
load_file("nqsnet.hoc")
load_file("network.hoc")
load_file("params.hoc")
load_file("run.hoc")
load_file("nload.hoc")

jrtm_INTF6 = tstop + 1 // less printouts from intf6.mod

//* declares

// myncl: NetCons used for spike recording (filled by settunerec)
// myspkl: one spike-time Vector per cell (filled by settunerec, emptied by updatewts)
// myspktyl: one Vector per cell type, pooled spike times (rebuilt by updatewts)
// vice: per-cell flag, set to ice(type) in settunerec (nonzero for I cells)
declare("myncl",new List(),"myspkl",new List(),"myspktyl",new List(),"vice",new Vector(col.allcells))
// vtarg: per-cell target number of spikes per update period (filled by settarg)
declare("vtarg",new Vector(col.allcells))
declare("syl","o[2]") // list of sywvs

// nqrat: NQS of per-type average rates, one row per type per update (created in mysend)
declare("nqrat","o[1]")
// ltab: List returned by conn2tab; contab/wtab[0]/wtab[1]/mtab are its four matrices
declare("ltab","o[1]","contab","o[1]","wtab","o[2]","mtab","o[1]")
declare("veinlyr",new Vector(CTYPi)) // corresponding E cells for I cells of a layer
veinlyr.x(I2)=E2
veinlyr.x(I2L)=E2
veinlyr.x(I4)=E4
veinlyr.x(I4L)=E4
veinlyr.x(I5)=E5R
veinlyr.x(I5L)=E5R
veinlyr.x(I6)=E6
veinlyr.x(I6L)=E6

// witer: number of weight updates performed so far (incremented in updatewts)
// updinc: time between weight updates (used as the cvode.event interval)
// uprob: only referenced by a commented-out line in updatewts
declare("witer",0,"updinc",2e3,"uprob",0.1)
// XXinc: fractional weight change per update for each pre->post class (used as 1 +/- inc)
// skipI: when nonzero, updatewts skips synapses whose postsynaptic cell is an I cell
declare("EEinc",0.01,"EIinc",0.01,"IEinc",0.01,"IIinc",0.01,"skipI",1)

// vrecw: ids of presynaptic cells whose outgoing weights are logged each update
// nqrec: NQS receiving those logged weights (created in mysend)
declare("vrecw",new Vector(),"nqrec","o[1]")

// IRate/ERate: target rates for I/E cells; settarg multiplies them by updinc/1e3
// IFctr: multiplier on the matching E type's per-cell spike count when forming the I-cell target
declare("IFctr",1.01,"IRate",4,"ERate",1)

//* lowrefrac - lower refractory period
proc lowrefrac () { local c localobj cell
  // shorter refractory periods leave more room for firing-rate changes:
  // I cells get 1.25, all other cells get 2.5
  for c=0,numcols-1 {
    for ltr(cell,col[c].ce) {
      if(ice(cell.type)) {
        cell.refrac=1.25
      } else {
        cell.refrac=2.5
      }
    }
  }
}

//* settarg - set target # of spikes (per period)
proc settarg () { local idx localobj cell
  // start from the per-type target rate of every cell in col.ce ...
  for ltr(cell,col.ce,&idx) {
    if(ice(cell.type)) {
      vtarg.x(idx) = IRate
    } else {
      vtarg.x(idx) = ERate
    }
  }
  // ... then scale by updinc/1e3 to get a spike count per update period
  vtarg.mul(updinc/1e3)
}

//* settunerec - set up recording of spikes used in tuning
proc settunerec () { local k localobj cell,ncon
  // one pooled spike vector per cell type
  for k=0,CTYPi-1 {
    myspktyl.append(new Vector())
  }
  // one NetCon + spike vector per cell, in col.ce order
  for ltr(cell,col.ce,&k) {
    cell.flag("out",1) // make sure NetCon events enabled from this cell
    myspkl.append(new Vector())
    ncon=new NetCon(cell,nil)
    myncl.append(ncon)
    ncon.record(myspkl.o(k)) // record each cell separately
    vice.x(k)=ice(cell.type) // remember which cells are inhibitory
  }
}

//* mksyl - setup lists of weight vectors
proc mksyl () { local c,ndiv localobj cell,wv0,wv1
  syl[0]=new List()
  syl[1]=new List()
  // for every cell, fetch both of its weight vectors (sized by getdvi) via getsywv
  for c=0,col.allcells-1 {
    cell=col.ce.o(c)
    ndiv=cell.getdvi()
    wv0=new Vector(ndiv)
    wv1=new Vector(ndiv)
    cell.getsywv(wv0,wv1)
    syl[0].append(wv0)
    syl[1].append(wv1)
  }
}

//* conn2tab - make lookup tables with connectivity info
// $o1: column object providing allcells and connsnq
// returns a List of four allcells x allcells matrices:
//   o(0) connection present (1/0), o(1) weight 1, o(2) weight 2, o(3) index into div vector
// returns nil (after printing an error) when $o1.connsnq is nil
obfunc conn2tab () { local r,pre,post,n localobj tabs,pcol,nq,conn,wa,wb,dvidx,cnt
  pcol=$o1
  if(pcol.connsnq==nil) {
    print "conn2tab ERR: col.connsnq is nil"
    return nil
  }
  n=pcol.allcells
  tabs=new List()
  for r=0,3 tabs.append(new Matrix(n,n))
  conn=tabs.o(0)
  wa=tabs.o(1)
  wb=tabs.o(2)
  dvidx=tabs.o(3)
  cnt=new Vector(n) // per-presynaptic-cell count of connections seen so far
  nq=pcol.connsnq
  nq.sort("del") // make sure order in NQS corresponds to getdvi order
  for r=0,nq.v.size-1 {
    pre=nq.v[0].x(r) // from pre
    post=nq.v[1].x(r) // to post
    conn.x(pre,post)=1 // is there a connection?
    wa.x(pre,post)=nq.v[4].x(r) // weight 1
    wb.x(pre,post)=nq.v[5].x(r) // weight 2
    dvidx.x(pre,post) = cnt.x(pre) // index into div vector -- assumes order in connsnq according to div
    cnt.x(pre) += 1
  }
  return tabs
}

//* updatewts - update the weights
// Runs once per update period (scheduled via cvode.event):
//  1. pools the spikes recorded since the last update by cell type
//  2. for every connection i->j whose presynaptic cell i spiked in this period,
//     compares j's spike count with its target and scales the weight in syl[0]
//     by 1+inc or 1-inc (unchanged when j is exactly on target), clamped at 0
//  3. pushes the updated weight vectors back into the cells with setsywv
//  4. logs weights for the cells listed in vrecw (nqrec) and per-type rates (nqrat)
//  5. clears the spike counts and schedules the next update at t+updinc
proc updatewts () { local i,j,df,fctr,inc,w0,idx,trg,ety localobj xo,ce
  print "t:", t, ", witer:",witer
  ce=col.ce
  for i=0,CTYPi-1 if(col.numc[i]) myspktyl.o(i).resize(0) //setup per-type counts
  // append each cell's spike times to the vector of its type; only the sizes are used below
  for ltr(xo,ce,&i) myspktyl.o(xo.type).append(myspkl.o(i))
  for ltr(xo,ce,&i) {
//    if(rdm.uniform(0,1) < uprob) continue
    if(vice.x(i)) { // presynaptic I cell
      for j=0,col.allcells-1 if(contab.x(i,j) && myspkl.o(i).size) {
        if(vice.x(j)) { // postsynaptic I cell
          if(skipI) continue
          inc = IIinc
          ety = veinlyr.x(ce.o(j).type)
          // target is the larger of the fixed target and IFctr * per-cell spike count of the matching E type
          trg = MAXxy(vtarg.x(j),IFctr*myspktyl.o(ety).size/col.numc[ety])
        } else { // postsynaptic E cell
          inc = IEinc
          trg = vtarg.x(j)
        }
        df = myspkl.o(j).size - trg
        // presynaptic I cell: raise the weight when the target cell is above its
        // target count, lower it when below
        if(df > 0) {
          fctr = 1 + inc
        } else if(df < 0) {
          fctr = 1 - inc
        } else {
          fctr = 1 // exactly on target: leave the weight unchanged
        }
        idx = mtab.x(i,j) // index into div,weight array
        w0 = syl[0].o(i).x(idx) * fctr // updated weight
        if(w0 < 0) syl[0].o(i).x(idx)=wtab[0].x(i,j)=0 else syl[0].o(i).x(idx)=wtab[0].x(i,j)=w0
      }
      xo.setsywv(syl[0].o(i),syl[1].o(i)) // reset weights
    } else { // presynaptic E cell
      for j=0,col.allcells-1 if(contab.x(i,j) && myspkl.o(i).size) {
        if(vice.x(j)) { // postsynaptic I cell
          if(skipI) continue
          inc = EIinc
          ety = veinlyr.x(ce.o(j).type)
          trg = MAXxy(vtarg.x(j),IFctr*myspktyl.o(ety).size/col.numc[ety])
        } else { // postsynaptic E cell
          inc = EEinc
          trg = vtarg.x(j)
        }
        df = myspkl.o(j).size - trg
        // presynaptic E cell: opposite sign -- lower the weight when the target
        // cell is above its target count, raise it when below
        if(df > 0) {
          fctr = 1 - inc
        } else if(df < 0) {
          fctr = 1 + inc
        } else {
          fctr = 1 // exactly on target: leave the weight unchanged
        }
        idx = mtab.x(i,j) // index into div,weight array
        w0 = syl[0].o(i).x(idx) * fctr // updated weight
        if(w0 < 0) syl[0].o(i).x(idx)=wtab[0].x(i,j)=0 else syl[0].o(i).x(idx)=wtab[0].x(i,j)=w0
        syl[1].o(i).x(idx) = wtab[1].x(i,j) = syl[0].o(i).x(idx) * NMAMR // NMDA weight
      }
      xo.setsywv(syl[0].o(i),syl[1].o(i)) // reset weights
    }
  }
  // log every outgoing weight of the cells whose ids are stored in vrecw
  for vtr(&i,vrecw) {
    for j=0,col.allcells-1 if(contab.x(i,j)) {
      idx = mtab.x(i,j)
      nqrec.append(i,j,ce.o(i).type,ce.o(j).type,syl[0].o(i).x(idx),syl[1].o(i).x(idx),witer)
    }
  }
  // NOTE(review): nqrec rows above carry the pre-increment witer, nqrat rows below
  // the post-increment value, so the same update is numbered k in nqrec and k+1 in nqrat
  witer += 1
  for i=0,myspkl.count-1 myspkl.o(i).resize(0) // reset spike counts for cells to 0
  for(i=CTYPi-1;i>=0;i-=1) if(col.numc[i]) {
    j=1e3*myspktyl.o(i).size/(col.numc[i]*updinc) // average rate of this type over the period
    print CTYP.o(i).s, " " , j , " avg Hz "
    nqrat.append(witer,i,j)
    myspktyl.o(i).resize(0) // reset spike counts for types to 0
  }
  cvode.event(t+updinc,"updatewts()") // set next update weights event
}

//* mysend - starts off the update q
proc mysend () {
  // pass any existing rate table to nqsdel, then start a fresh one
  nqsdel(nqrat)
  nqrat=new NQS("witer","ty","rate")
  // same for the recorded-weights table
  nqsdel(nqrec)
  nqrec=new NQS("id1","id2","ty1","ty2","w0","w1","witer")
  // first weight update fires at time updinc; updatewts reschedules itself afterwards
  cvode.event(updinc,"updatewts()")
}
// FInitializeHandler that calls mysend()
declare("fith",new FInitializeHandler("mysend()"))

//* calls

// one-time setup, executed when this file is loaded
lowrefrac()   // set refrac on all cells
settarg()     // fill vtarg with per-cell target spike counts
settunerec()  // attach a recording NetCon and spike vector to each cell
mksyl()       // fetch the current weight vectors into syl[0], syl[1]

// build the connectivity lookup tables and unpack them into the globals used by updatewts
// NOTE(review): conn2tab returns nil when col.connsnq is nil; the ltab.o(...) calls below do not check for that
ltab=conn2tab(col)
contab=ltab.o(0)
wtab[0]=ltab.o(1)
wtab[1]=ltab.o(2)
mtab=ltab.o(3)