// $Id: run.hoc,v 1.253 2002/04/22 14:48:41 billl Exp $

//* prpatt(vec) presents pattern
// ipatt: index of the pattern currently chosen for presentation
ipatt = 0
// $o1 is the pattern vector; every incell[] whose entry in $o1 is positive
// has its v set to 50, and ininh.v is set to 50 unconditionally.
proc prpatt () { local k
  for k=0,$o1.size-1 {
    if ($o1.x[k]>0) { incell[k].v = 50 }
  }
  ininh.v = 50
}

//* outputData
// Reset the two globals used by outputData():
//   gotime - time threshold that outputData() compares t against
//   nextgo - amount outputData() adds to gotime after a presentation
proc initMisc1 () {
  nextgo = 1e10
  gotime = 5
}

// Compares t with gotime:
//   t > gotime+dt : set v back to v_init in all "incell" sections and
//                   advance gotime by nextgo
//   t > gotime    : present pattern number ipatt from ivl via prpatt()
//   otherwise     : nothing
proc outputData () {
  if (t>gotime+dt) {
    forsec "incell" { v = v_init }
    gotime += nextgo
  } else {
    if (t>gotime) { prpatt(ivl.object(ipatt)) }
  }
}

//* Printlist stuff
// Rebuild what gets recorded: empty printlist, put the existing "NRN"
// objects into tmplist, drop the last one of them, then hand the list to
// record() together with the variable name "soma.v(0.5)".
// NOTE(review): record() is defined elsewhere; presumably it appends one
// printlist entry per object left in tmplist -- confirm.
printlist.remove_all()
tmplist = new List("NRN")
tmplist.remove(tmplist.count-1)
record(tmplist,"soma.v(0.5)")

//* grast() marks each graph with "S" for a 1 and "o" for anything else
// Takes output pattern ipatt from ovl (left in YO), labels the first graph
// of panobjl.object(0) with "<datestr><runnum-1>:<ipatt>", then walks the
// graphs and the pattern entries together via lvtr(), placing one mark per
// graph at (tstop/4,40).
proc grast () {
  YO = ovl.object(ipatt)
  sprint(grvecstr,"%s%02d:%02d",datestr,runnum-1,ipatt)
  panobjl.object(0).glist.object(0).label(0.1,0.8,grvecstr)
  for lvtr(XO,&x,panobjl.object(0).glist,YO) {
    if (x!=1) {
      XO.mark(tstop/4,40,"o",8)
    } else {
      XO.mark(tstop/4,40,"S",8)
    }
  }
}

//* show(PATT#) run out a vector and then superimpose the squares
// $1 is the pattern number; it becomes the global ipatt.  The simulation is
// run, the graphs are refreshed with geall()/grall(), and grast() adds the
// marks for the expected output.
// Patterns are numbered 0..npatt-1 (testrun() iterates exactly that range),
// so npatt itself and negative numbers are rejected before anything is run.
proc show () { 
  if ($1<0 || $1>=npatt) {
    printf("%d out of range: valid patterns are 0..%d.\n",$1,npatt-1)
    return
  }
  ipatt = $1
  run()
  geall(0)
  grall(0,0,39)
  grast()
}

//* testrun() run through all the patterns and compare output
// For every pattern 0..npatt-1: run the simulation, read the firing pattern
// out of the printlist into a scratch vector, and store it with savevec().
// Afterwards the stored vectors in veclist are paired with the targets in
// ovl and the hamming distance of each pair is collected and printed,
// followed by the mean.  ipatt (global) is used as the loop variable.
proc testrun () { local p1,p2
  veclist.remove_all()
  p1 = allocvecs(2)
  p2 = p1+1
  mso[p2].resize(szout)
  for ipatt=0,npatt-1 {
    run()
    rdplist(mso[p2]) // scan the printlist for firing
    savevec(mso[p2])
    print ipatt
  }
  for ltr2(XO,YO,veclist,ovl) { pushvec(mso[p1],XO.hamming(YO)) }
  print ""
  mso[p1].printf
  print "mean=",mso[p1].mean
  dealloc(p1)
}

//* testrun1(PATT#[,RUNFLAG]) just test 1 of them
// Sets ipatt to $1.  The simulation is run first only when exactly two
// arguments are given; otherwise whatever is currently in the printlist is
// read.  Prints the target vector, the vector that was read, and the
// hamming distance between the two.
proc testrun1 () { local p1,p2
  p1 = allocvecs(2)
  p2 = p1+1
  mso[p2].resize(szout)
  ipatt = $1
  if (numarg()==2) {
    run()
    print ""
  }
  rdplist(mso[p2]) // scan the printlist for firing
  ovl.object(ipatt).vpr
  mso[p2].vpr
  print mso[p2].hamming(ovl.object(ipatt))
  dealloc(p1)
}

//* rdplist(vec) reads the printlist and sets vec entries to 1 if firing there
// thresh:  value a recorded trace has to exceed to count as a spike
// thresht: a first spike only counts when its time is below this value
thresh = 0
thresht = 300
// $o1 is emptied and then receives one entry per printlist item processed:
// 1 when the item's trace first exceeds thresh before thresht, 0 otherwise.
// arraynum() leaves the NRN[] index of each item in temp_string_; it is
// used below to check that the items arrive in index order.
proc rdplist() { local ii,max,tme
  $o1.resize(0)  
  // this value is overwritten at once by the for statement below
  ii = -1
  for ii=0,printlist.count-1 { // can't 'return' from an iterator
    XO = printlist.object(ii)
    arraynum(XO.var,"NRN","soma.v(0.5)")
    // NOTE(review): the branch below assigns to the loop variable itself;
    // if an item with index "0" ever sits at a printlist position other
    // than 0 this would restart the scan -- confirm that cannot happen.
    if (strcmp(temp_string_,"0")==0) { ii = 0  // first one
    } else if (ii>=szout) { return // all done
    } else if (ii>=0) {
      // the index parsed from the item name must equal the list position
      sprint(temp_string2_,"%d",ii)
      if (strcmp(temp_string_,temp_string2_)!=0) { 
        printf("%s out of order.\n",XO.var) return }
    }
    // position of the first sample above thresh, scaled by printStep
    // (presumably the spacing of the recorded samples -- confirm)
    tme = XO.vec.indwhere(">",thresh) * printStep // time of first spike    
    if (tme<0) { tme=1e10 } // no spikes were found
    // max holds the 0/1 firing flag, not a maximum
    max = (tme < thresht)
    pushvec($o1,max)
  }
}

//* arraynum(STR,NAME,TERM)
// Extracts the NAME array index from STR and leaves it, as a string, in the
// global temp_string_ (for a STR of the assumed form "NRN[12]...." that is
// "12").  temp_string_ is "" when STR does not begin with NAME.  This is a
// proc, so nothing is returned; callers inspect temp_string_.
// TERM ($s3) is accepted so existing three-argument calls keep working, but
// it is not examined.
proc arraynum () {
  temp_string_ = ""
  // substr() gives the position of NAME within STR, -1 when absent.  Only a
  // match at position 0 is accepted, so this single test also covers the
  // "not found" case.
  if (sfunc.substr($s1,$s2) != 0) { return }
  sfunc.tail($s1,$s2,temp_string_)            // keep what follows NAME
  sfunc.right(temp_string_,1)                 // remove '\['
  sfunc.head(temp_string_,"\\]",temp_string_) // keep what precedes ']'
}

//* variance ()
// Steps BVBASE from -1.5 up to 0.9 in increments of 0.1 and calls
// binom(iflip) at each step, after loading the global mean from vec.x[]
// (one entry per step, in order).  vec therefore has to hold the means
// before this is called.  Sets the globals tot, mean and BVBASE.
// NB: binom() changed
proc variance () { local k
  print "NB: must set up vec with means in order to calc stdev's."
  tot = fac(szinp)/fac(iflip)^2 // the total number of vectors
  k = 0
  for (BVBASE=-1.5;BVBASE<=0.9;BVBASE=BVBASE+0.1) {
    mean = vec.x[k]
    binom(iflip)
    k += 1
  }
}

//* check out questionable identity
// fh(M,m): log of M!/(M/m)!^m, computed as logfac(M) - m*logfac(M/m)
func fh () { local M,m
  M = $1
  m = $2
  return logfac(M)-m*logfac(M/m)
}
// lcb(n,k): logfac(n) - logfac(n-k) - logfac(k), the log form of cb()
func lcb () { local n,k
  n = $1
  k = $2
  return logfac(n)-logfac(n-k)-logfac(k)
}
// cb(n,k): exp of lcb(n,k)
func cb () {
  return exp(lcb($1,$2))
}
// pm(n,k): exp(logfac(n) - logfac(n-k))
func pm () { local n,k
  n = $1
  k = $2
  return exp(logfac(n)-logfac(n-k))
}
// idnty(M,m): compare with binomial expansion to the m power
// With F = M/m (M must be divisible by m), sums exp(m*lcb(F,k)) for
// k = 0..F and prints that sum, exp(fh(M,m)), and their difference.
// The global a is used as scratch and ends up holding exp(fh(M,m)).
proc idnty () { local M,m,k,acc,F
  M = $1
  m = $2
  F = int(M/m)
  if (F!=M/m) {
    printf("ERROR: %d not divisible by %d.\n",M,m)
    return
  }
  acc = 0
  for k=0,F {
    a = exp(m*lcb(F,k))
    acc += a
  }
  a = exp(fh(M,m))
  print acc," ",a," ",acc-a
}

//* binom(num) where num is F - the number of flips
// now calculates out the expected value
// Reads the globals BVBASE, tot and mean; uses the global a as scratch.
// For k = 0..num each term is weighted by (num!/((num-k)! k!))^2/tot.
// Prints BVBASE, the weighted sum, the square root of the weighted squared
// deviation from mean, the "full" value num*BVBASE^2+num, and full divided
// by the weighted sum.  Returns the weighted sum.
func binom () { local num,k,rem,nf,term,wsum,wdev,full
  num = $1
  nf = fac(num)
  wsum = 0
  wdev = 0
  for k=0,num { // doesn't make any difference if include X.X or not (eg 1,num)
    rem = num-k
    a = nf/fac(rem)/fac(k)
    a = a*a/tot
    term = rem*BVBASE*BVBASE+ rem + 2*k*BVBASE
    wdev = wdev + a*((term - mean)^2)
    wsum = wsum + a*term
  }
  full = num*BVBASE*BVBASE + num // num*1*1 actually
  printf("%g\t\t%g\t\t%g\t\t%g\t\t%g\n",BVBASE,wsum,sqrt(wdev),full,full/wsum)
  return wsum
}

//* compdots (): compare dotprods
// do allocvecs first
// Regenerates ivl/ovl with mkiovecs(), then steps BVBASE from -1.2 to -0.8
// in increments of 0.1.  At each step every ivl vector has the previous
// base value switched to the new one via sw(), and the dot products of ivl
// vector pairs are pushed onto mso[ll]; ll starts at 0 and goes up by one
// per step, so one mso[] slot is filled per BVBASE value.
// BVBASE (global) is left at the value that ended the sweep.
proc compdots () { local i,j,k,s1,s2,last,bnm,ll
  ll = 0
  mkiovecs(ivl,ovl)    
  last = BVBASE
  for (BVBASE=-1.2;BVBASE<=-0.8;BVBASE=BVBASE+0.1) {
    for ltr(XO,ivl) { XO.sw(last,BVBASE) }
    // s1,s2 are only referenced by the commented-out printf below
    s1=s2=0
    for ltr(XO,ivl) {
      // i1 is presumably the position of XO maintained by ltr() -- confirm;
      // starting at i1+1 visits each unordered pair once
      for i=(i1+1),ivl.count-1 {
        YO=ivl.object(i)
        if (!eqobj(XO,YO)) {
          pushvec(mso[ll],XO.dot(YO))
        }
      }
    }
    //    bnm = binom(iflip)
    //    printf("%g\t\t%g\t\t%g\t\t%g\t\t%g\n",BVBASE,XO.dot(XO),s1/s2,bnm,XO.dot(XO)/bnm)
    last = BVBASE
    ll = ll+1
  }
}

//* dotprods()
// For every vector in ovl, compares it with every other vector in ovl and
// prints: the vector, the smallest and largest hamming distance to the
// others, and the mean dot product with the others.  The final line is the
// grand total of dot products divided by the per-vector count and by
// ovl.count.
proc dotprods () { local acc,cnt,lo,hi,ham,grand,dot
  grand = 0
  acc = 0
  cnt = 0
  for ltr(XO,ovl) {
    acc = 0
    cnt = 0
    lo = 1e10
    hi = -1e10
    for ltr(YO,ovl) {
      if (! eqobj(XO,YO)) {
        ham = XO.hamming(YO)
        dot = XO.dot(YO)
        if (ham>hi) { hi=ham }
        if (ham<lo) { lo=ham }
        acc += dot
        cnt += 1
      }
    }
    print XO," ",lo," ",hi," ",acc/cnt
    grand += acc
  }
  print grand/cnt/ovl.count
}

//* stats(RUNS) runs some statistics on performance
// $1 is the number of repetitions; it is stored in the global rs.
// For each output size tried (currently only 40) and each repetition, new
// input/output vectors, matrix and inhibition are generated, every input is
// passed through kala(), and the hamming distance to its target output is
// accumulated.  The mean distance per pattern per repetition is printed.
// szout is changed by the case iterator while this runs; it is put back to
// its value on entry before returning.
proc stats () { local ii,jj,p1,p2,ham,savszout,oflip
  rs = $1
  savszout = szout
  p1 = allocvecs(2)  p2=p1+1
//  for case(&szout,300,150,100,75,50,25) 
  for case(&szout,40)  {
    // NOTE(review): oflip is declared local here, so this assignment is not
    // visible to the helpers called below -- confirm whether a global was
    // meant.
    oflip = szout/2 // density of excitation out (Fo)
    mso[p2].resize(szout)
    ham = 0
    for ii=0,rs-1 {
      mkiovecs(ivl,ovl)
      crosstalk()
      makemat(mat,ivl,ovl)
      makeinh(inhv,ovl)
      jj = 0
      for ltr2(XO,YO,ivl,ovl) {
        // NOTE(review): kala() as defined in this file reads a single list
        // argument; this two-argument call looks stale -- confirm.
        kala(mso[p2],XO)
        ham = ham + mso[p2].hamming(YO)
        jj = jj+1
      }
    }
    // jj holds the number of patterns seen in the last repetition
    ham = ham/jj/rs
    printf("%d/%d N=%d stats(%d)=%g\n",szinp,szout,npatt,rs,ham)
  }
  szout = savszout // undo the change made by the case iterator
  dealloc(p1)
}

//* proc nontarg ()
// For kalap = 2.5, 3.5, ... (below 8): for every target in ovl and for every
// flip count jj = 0..szf, copies of the matching input vectors get jj bits
// flipped, are run through kala(), and two numbers are stored per
// (pattern, jj): the hamming distance of test[0] from the target and the
// sum of test[0].  Mean and stdev over patterns are then taken per flip
// count and written through vprf() after a header line in tmpfile
// ("h" = hamming, "s" = sum).  Output is appended to output_file.
// NOTE(review): ii is not in the local list, so the loops below use a
// global ii; and unlike the other procs here there is no dealloc(p0) at the
// end -- confirm both are intended.
proc nontarg () { local jj,kk,p0,p1,p2,p3,p4,p5,p6,ham2,ham1,szi1,szf
  szf = 50
  szi1 = szf+1
  // ind = 0,1,...,szf : the flip counts, used as first column for vprf()
  ind.indgen(0,szf,1)
  tmpfile.aopen(output_file)
  // NOTE(review): szinp is used without an index here and as szinp[ii]
  // below -- confirm which element is wanted for the initial size.
  p0 = allocvecs(7,szinp)
  p1=p0+1 p2=p1+1 p3=p2+1 p4=p3+1 p5=p4+1 p6=p5+1
  mso[p2].resize(szout)
  // p5/p6 are used as npatt x szi1 tables via mset()/mcol():
  // p5 holds hamming distances, p6 holds output sums
  mso[p5].resize(szi1*npatt)
  mso[p6].resize(szi1*npatt)
  if (veclist.count!=convn) { veclist.remove_all() mkveclist(convn,szinp[convn-1]) }
  for (kalap = 2.5;kalap<8;kalap=kalap+1) {
    print kalap
    for ltr(YO,ovl) { // npatt rows go through input/output pairs
      for jj=0,szf {  // cols change from 0 to half the bits
        for ii=0,convn-1 {
          mso[p3].resize(szinp[ii]) mso[p4].resize(szinp[ii])
          // i1 is presumably the position of YO maintained by ltr() -- confirm
          mso[p3].copy(ivl[ii].object(i1))
          mso[p3].flipbits(mso[p4],jj) // flip them in the copy not the original
          veclist.object(ii).copy(mso[p3])
        }
        kala(veclist)
        ham1 = test[0].hamming(YO) // calc hamming distance from target
        ham2 = test[0].sum // calc activity in the output
        mso[p5].mset(i1,jj,szi1,ham1)  // save the value
        mso[p6].mset(i1,jj,szi1,ham2)  // save the value
      }
    }
    // from here on p0..p4 are reused: p4 = one table column at a time,
    // p2/p3 = mean/stdev of hamming, p0/p1 = mean/stdev of sum
    mso[p4].resize(npatt)
    mso[p0].resize(szi1) mso[p1].resize(szi1)
    mso[p2].resize(szi1) mso[p3].resize(szi1)
    for ii=0,szf { 
      mso[p4].mcol(mso[p5],ii,szi1)
      mso[p2].x[ii] = mso[p4].mean
      mso[p3].x[ii] = mso[p4].stdev
      mso[p4].mcol(mso[p6],ii,szi1)
      mso[p0].x[ii] = mso[p4].mean
      mso[p1].x[ii] = mso[p4].stdev
    }
    //  printf("hamming: mso[%d],mso[%d]; sum: mso[%d],mso[%d] (mean,sdev)\n",p2,p3,p0,p1)
    tmpfile.printf("\n\"h %d %g\n",conva,kalap)
    vprf(ind,mso[p2],mso[p3])
    tmpfile.printf("\n\"s %d %g\n",conva,kalap)
    vprf(ind,mso[p0],mso[p1])
  }
  tmpfile.close
}

//* kala (inputlist)
// kala = kohonen-anderson linear associator
// $o1 is a list with one input vector per matrix (convn of them).  Each
// input is passed through kala1() with its own matrix into test[k]; the
// results for k>0 are added into test[0], which is finally thresholded at
// kalap.  The combined result is left in test[0].
proc kala () { local k
  for k=0,convn-1 {
    kala1(test[k],mat[k],$o1.object(k),k)  // excitation
    if (k>=1) { test[0].add(test[k]) }
  }
  test[0].thresh(kalap)
}

// kala1(OUT,MAT,IN,MAT#): one pass through a single matrix.
// $o1 receives the result, $o2 is the matrix, $o3 the input vector and $4
// selects the inhv[] and ddot[] entries that belong to that matrix.
proc kala1 () { local which
  which = $4
  $o1.mmult($o2,$o3)            // excitation
  $o1.add(inhv[which])          // inhibition
  $o1.mul(npatt/ddot[which])    // normalization
  $o1.thresh((BVBASE+1)/2)      // thresholding
}

//* proc testmat ()
// testmat([MAT#,INV#]) test mat MAT# with INV#, comparing output with OUV#
// default is to combine all mats (also with MAT#=-1), and try all INV#
// For each selected pattern the hamming distance between test[0] and the
// target in ovl is collected; the distances and their mean are printed.
// Leaves the last target in YO and the last set of inputs in tmplist.
proc testmat () { local pat,m,p1,p2,just1,first,last
  just1 = -1
  if (numarg()>0) { just1 = $1 }
  if (numarg()>1) {
    first = $2
    last = $2
  } else {
    first = 0
    last = ovl.count-1
  }
  p1 = allocvecs(2)
  p2 = p1+1
  mso[p2].resize(szout)
  for pat=first,last {
    YO = ovl.object(pat)
    tmplist.remove_all
    for m=0,convn-1 { tmplist.append(ivl[m].object(pat)) }
    if (just1!=-1) {
      kala1(test[0],mat[just1],ivl[just1].object(pat),just1)
    } else {
      kala(tmplist)
    }
    pushvec(mso[p1],test[0].hamming(YO))
  }
  mso[p1].printf
  print "mean=",mso[p1].mean
  dealloc(p1)
}

// showtest(): run testmat() once for each kalap = 0.5, 1.5, ...,
// convn-0.5, calling dealloc() with no argument before every run.
proc showtest () { local k
//  newPlot(0,npatt,0,20)
  for k=0,convn-1 {
    dealloc()
    kalap = k+0.5
    testmat() // comment dealloc out of testmat() in order to show
//    mso[0].mark(graphItem,1,sym[k].s,6,k+1)
  }
}

// compare output without thresholding
// for ltr2(XO,YO,ovl,veclist) {
//   for vtr2(&x,&y,XO,YO) {printf("%g ",x-y) }
//   printf("\n\n")
// }

//* move lists into arrays
// in[] <- ivl, ou[] <- ovl, rr[] <- veclist; each array has npatt slots and
// is filled at index i1 as the list is traversed.
objref in[npatt],ou[npatt],rr[npatt]
for ltr(XO,ivl) {
  in[i1] = XO
}
for ltr(XO,ovl) {
  ou[i1] = XO
}
for ltr(XO,veclist) {
  rr[i1] = XO
}

//* proc row(), col() 
// row(N): size ind to szinp entries, load row N of mat into it (szinp
// entries per row) and print it.
proc row () { local r
  r = $1
  ind.resize(szinp)
  ind.mrow(mat,r,szinp)
  ind.printf()
}

// col(N): size ind to szout entries, load column N of mat into it and
// print it.  mat is laid out with szinp entries per row, which is the
// third argument mcol() needs in order to step from row to row.
proc col () {
  ind.resize(szout)
  ind.mcol(mat,$1,szinp) // was mrow(), which fetched a row instead
  ind.printf
}

// Display script executed when the file is loaded: set up the graph tray
// and axis ranges, run and show pattern 28, then draw vec against ind on
// every graph of panobj.
// NOTE(review): grast() addresses the graphs as panobjl.object(0).glist
// while the last line here uses panobj.glist -- confirm both refer to the
// same panel.
mktray(0,10,4)
setrange(0,0,tstop,-100,50)
setrange(0,-1)  // erase axes
show(28)  // show results for pattern#28
// revec() presumably refills the vector with the listed values, giving the
// two end points (240,-100) and (240,50) of the line drawn below -- confirm
revec(vec,-100,50)
revec(ind,240,240)
for ltr(XO,panobj.glist) vec.line(XO,ind,2,5)