/* $Id: samutils.hoc,v 1.59 2009/12/11 03:56:04 samn Exp $ */

//use_samn_utils: 1 if the samnutils.mod mechanism is present (detected by
//its INSTALLED_sn symbol, in which case install_sn() is called), else 0.
//VEuclidDist and VSQDist check this flag before calling LDist_sn.
use_samn_utils=0
if(name_declared("INSTALLED_sn")){
  use_samn_utils=1
  install_sn()
} else {
  printf("Warning: couldn't install samnutils.mod\n")
}

//whether to print debug info (nonzero enables extra printf output in
//functions that test mydbg, e.g. ExpSmooth, FixEEGOutliers, plot_kmeans)
mydbg=0


//////////////////////////////////////////////////////////////////////////////////////
//color index constants, used as color arguments to Graph/Vector plotting
//calls (e.g. PlotSpikeTimes marks spikes with red)
black=1  red=2  blue=3  green=4
//////////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////////////////
//time functions



//////////////////////////////////////////////////////////////////////////////////////


//////////////////////////////////////////////////////////////////////////////////////
//drawing functions
//plot and mark vector $o1 on the global Graph g
//$o1 = vector to plot
//$2  = color index
//$3  = x increment between points [optional, default=1]
//note: requires a global Graph objref named g
//clr is now declared local (it previously leaked into a global variable)
proc plotit () { local xi,clr localobj vv
  vv=$o1 clr=$2
  if(numarg()>2)xi=$3 else xi=1
  vv.plot(g,xi,clr,4)
  vv.mark(g,xi,"O",12,clr,1)
}

//plot vector $o1 with error bars on the global Graph g
//$o1 = y values
//$o2 = error bar sizes
//$o3 = x values
//$4  = color index
//note: requires a global Graph objref named g
proc plotite () { local color localobj yv,errv,xv
  yv=$o1
  errv=$o2
  xv=$o3
  color=$4
  yv.ploterr(g,xv,errv,15,color,4)
  yv.plot(g,xv,color,4)
  yv.mark(g,xv,"O",15,color,1)
}
//////////////////////////////////////////////////////////////////////////////////////


///////////////////////////////////////////////////////////
//string functions

//return the length of a string
//$s1 = strdef, or $o1 = object with a string field s
//prints an error and returns 0 for any other argument type
func strlen(){ local atype localobj sf
  sf = new StringFunctions()
  atype = argtype(1)
  if(atype==2) return sf.len($s1)
  if(atype==1) return sf.len($o1.s)
  printf("strlen ERRA: invalid argtype\n")
  return 0
}

//copy a string into $o1.s
//$o1 = destination object with a string field s
//$o2 / $s2 = source: object with a string field s, or a strdef
//returns 1 on success, 0 if the argument types are unsupported
func strcpy(){
  if(argtype(1)!=1) return 0
  if(argtype(2)==2){
    sprint($o1.s,$s2)
    return 1
  }
  if(argtype(2)==1){
    sprint($o1.s,$o2.s)
    return 1
  }
  return 0
}
///////////////////////////////////////////////////////////

//utility functions

//return the smaller of $1 and $2 ($1 when they are equal)
func MIN(){
  if(!($1 <= $2)) return $2
  return $1
}

//return the larger of $1 and $2 ($1 when they are equal)
func MAX(){
  if(!($1 >= $2)) return $2
  return $1
}

//return 1 iff int($1) is a power of 2 (1,2,4,8,...), else 0
//$o1 may instead be a Vector or List, in which case the test is applied
//to its size / count
//returns -1 (with an error message) for any other argument type
func powof2(){ local val

  if(argtype(1)==0){
    val = int($1)
  } else if(isobj($o1,"Vector")){
    val = $o1.size
  } else if(isobj($o1,"List")){ //was isobj($01,...): $01 is a numeric arg, breaking List input
    val = $o1.count
  } else {
    printf("powof2 ERRA: invalid arg type\n")
    return -1
  }

  if(val<=0) return 0

  //halve repeatedly; meeting an odd value > 1 means val is not a power of 2
  while(val>=2){
    if(val%2==1) return 0
    val /= 2
  }

  return 1
}

//return the smallest power of 2 that is >= $1 (1 when $1 <= 1)
//$o1 may instead be a Vector or List, in which case its size / count is used
//returns -1 (with an error message) for any other argument type
func ceilpowof2(){ local p,n
  if(argtype(1)==0){
    n = $1
  } else if(isobj($o1,"Vector")){
    n = $o1.size
  } else if(isobj($o1,"List")){
    n = $o1.count
  } else {
    printf("ceilpowof2 ERRA: invalid arg type\n")
    return -1
  }
  p = 1
  while(p < n){
    p = 2*p
  }
  return p
}

//return the largest power of 2 that is <= $1
//$o1 may instead be a Vector or List, in which case its size / count is used
//returns 0 when the value is < 1 (no power of 2 fits below it),
//and -1 (with an error message) for an invalid argument type
func floorpowof2(){ local floorval,val

  if(argtype(1)==0){
    val = $1
  } else if(isobj($o1,"Vector")){
    val = $o1.size
  } else if(isobj($o1,"List")){
    val = $o1.count
  } else {
    printf("floorpowof2 ERRA: invalid arg type\n")
    return -1
  }

  if(val < 1) return 0

  //double upward instead of decrementing one at a time and calling powof2
  //on every candidate: O(log n) work, and fractional inputs (e.g. 5.5)
  //now yield a true power of 2 (4) instead of a non-power (4.5)
  floorval = 1
  while(floorval*2 <= val) floorval *= 2

  return floorval
}

//discrete wavelet transform using db4, via Vector.d4t (d4t_wavelet mechanism)
//$o1 = input vector
//$2  = floor flag [optional, default=1]: if the input size is not a power of
//      2, it is resampled (Vector.samp) to floorpowof2(size) when nonzero,
//      else to ceilpowof2(size)
//$3  = flag passed straight through as d4t's 2nd argument [optional, default=1]
//      NOTE(review): original header calls this "inverse transform"; which
//      value selects forward vs inverse depends on d4t itself — confirm there
//returns a new vector of 2^n dwt coefficients, or an empty vector on error
obfunc dwt(){ local npts,dir,usefloor localobj vout,vin
  vout = new Vector()

  if(!INSTALLED_d4t_wavelet){
    printf("dwt ERRA: d4t not installed!\n")
    return vout
  }
  if(!isobj($o1,"Vector")){
    printf("dwt ERRB: arg 1 must be vector!\n")
    return vout
  }

  usefloor = 1
  dir = 1
  if(numarg() > 1) usefloor = $2
  if(numarg() > 2) dir = $3

  if(!powof2($o1)){
    if(usefloor) npts = floorpowof2($o1) else npts = ceilpowof2($o1)
    vin = new Vector(npts)
    vin.samp($o1,npts)
  } else {
    vin = $o1
    npts = $o1.size
  }

  vout = new Vector(npts)
  vout.d4t(vin,dir,0)
  return vout
}

//zero one or more levels of dwt coefficients in place (for filtering)
//levels are 0 - (n-1) where 2^n is the size of the coefficient vector;
//level L occupies indices [2^L, 2^(L+1))
//$o1 = vector of dwt coefficients (modified in place)
//$2  = level to zero out, or $o2 = vector of levels to zero out
//returns 1 on success, 0 if arg 1 is not a vector
func zerodwtlevel(){ local i,lo,hi localobj levels
  if(!isobj($o1,"Vector")){
    printf("zerodwtlevel ERRA: arg 1 must be vector!\n")
    return 0
  }
  if(argtype(2)!=0){
    levels = $o2
    for(i=0;i<levels.size;i+=1) zerodwtlevel($o1,levels.x(i))
    return 1
  }
  lo = 2^$2
  hi = 2*lo
  for(i=lo;i<hi;i+=1) $o1.x(i) = 0
  return 1
}

//upsample input by repeating samples (step / sample-and-hold upsampling)
//$o1 = input vec
//$2 = output size (must be >= $o1.size)
//returns a new vector of size $2 in which output sample k holds input
//sample int(k*$o1.size/$2). When $2 is a multiple of $o1.size this repeats
//every input sample $2/$o1.size times; for other ratios the output is still
//filled exactly (the old repeat loop ran ceil(ratio) times per sample and
//indexed past the end of the output vector)
//returns a zero vector of size $2 (with a message) on error
obfunc stepupsamp(){ local kdx,fctr,nin localobj vtmp
  vtmp=new Vector($2)

  if(!isobj($o1,"Vector")){
    printf("stepupsamp ERRA: arg 1 must be Vector!\n")
    return vtmp
  }

  nin = $o1.size
  if(nin == 0){
    printf("stepupsamp ERRB: arg 1 Vector size is 0!\n")
    return vtmp
  }

  fctr = $2 / nin

  if(fctr < 1){
    printf("stepupsamp ERRC: downsampling not allowed!\n")
    return vtmp
  }

  //map each output index back to its source sample; kdx*nin/$2 < nin,
  //so the source index is always in range
  for(kdx=0;kdx<vtmp.size;kdx+=1){
    vtmp.x(kdx) = $o1.x(int(kdx*nin/$2))
  }

  return vtmp
}

//returns dwt coefficient levels in a list of vectors
//$o1 = vector of dwt coefficients
//$2 = resample each level to this size [optional; <=0 or absent = keep size]
//$3 = nonzero to use step resampling (stepupsamp) instead of Vector.samp
//     [optional, default=0]
//level idx holds the 2^idx coefficients extracted by Vector.d4tlncoeffs;
//there are ceil(log2($o1.size)) levels
//returns an empty list on error
obfunc dwtlevels(){ local nlevels,rsz,idx,stp,sz localobj lv,vtmp,vtmp2

  lv = new List()

  if(!INSTALLED_d4t_wavelet){
    printf("dwtlevels ERRA: d4t not installed!\n")
    return lv
  }

  if(!isobj($o1,"Vector")){
    printf("dwtlevels ERRB: arg 1 must be vector!\n")
    return lv
  }

  if(numarg() > 1) rsz = $2 else rsz = -1
  if(numarg() > 2) stp = $3 else stp = 0

  //count levels by exact doubling; the former log(size)/log(2) can round to
  //slightly above an integer for a power-of-2 size and then added a bogus
  //extra level (and it fails outright for size 0)
  nlevels = 0
  sz = 1
  while(sz < $o1.size){
    sz *= 2
    nlevels += 1
  }

  for(idx=0;idx<nlevels;idx+=1){
    vtmp = new Vector(2^idx)
    vtmp.d4tlncoeffs($o1,idx)
    if(rsz > 0){
      if(stp){
        vtmp2 = stepupsamp(vtmp,rsz)
      } else {
        vtmp2 = new Vector(rsz)
        vtmp2.samp(vtmp,rsz)
      }
      lv.append(vtmp2)
    } else {
      lv.append(vtmp)
    }
  }

  return lv
}

//element-wise average of a list of vectors
//$o1 = list of vectors (sizes may differ)
//returns a vector whose size is the largest input size; element j is the sum
//of element j over every vector that has one, divided by $o1.count (so shorter
//vectors effectively contribute 0 at missing positions)
//returns nil if the list is empty or every vector is empty
obfunc GetListAvg(){ local i,j,maxsz localobj vsum,vcur
  maxsz = 0
  for i=0,$o1.count-1 {
    if($o1.o(i).size > maxsz) maxsz = $o1.o(i).size
  }
  if(maxsz==0) return nil
  vsum = new Vector(maxsz)
  for i=0,$o1.count-1 {
    vcur = $o1.o(i)
    for j=0,vcur.size-1 {
      vsum.x(j) = vsum.x(j) + vcur.x(j)
    }
  }
  vsum.div($o1.count)
  return vsum
}

//intervals(TRAIN,OUTPUT)
//successive differences of $o1, written into $o2 (Vector.deriv, dx=1, method 1)
//$o1 = spike times / voltages
//$o2 = output interval vector (ends up with $o1.size-1 elements)
//returns the number of intervals, or 0 (with a message) if $o1 has < 2 elements
func intervals () {
  if ($o1.size>1) {
    $o2.deriv($o1,1,1)
    return $o2.size
  }
  printf("%s size <2 in intervals()\n",$o1)
  return 0
}

//return the derivative of an input vector (Vector.deriv with dx=1, method 2)
//the output vector has the same size as the input
//$o1 = input vector
//returns a new vector; left all zeros (with an error message) if size < 2
obfunc Deriv(){ localobj vout
  vout = new Vector($o1.size)
  if($o1.size >= 2){
    vout.deriv($o1,1,2)
  } else {
    printf("Deriv ERRAA: input vec size < 2!\n")
  }
  return vout
}
 
//return a list [v, v', v''] where v = $o1 and the derivatives come from Deriv
//(so each has the same size as $o1)
obfunc Traj(){ localobj out,d1
  out = new List()
  out.append($o1)
  d1 = Deriv($o1)
  out.append(d1)
  out.append(Deriv(d1))
  return out
}
 
//plot a trajectory list [v, v', v''] (e.g. from Traj) in two new Graphs:
//list.o(1) (labelled v') vs list.o(0), and list.o(2) (labelled v'') vs list.o(1)
//$o1 = trajectory list
//returns a list of the two graphs
obfunc PlotTraj(){ localobj graphs,g1,g2,dv,ddv
  dv = $o1.o(1)
  ddv = $o1.o(2)
  dv.label("v'")
  ddv.label("v''")
  graphs = new List()
  g1 = new Graph()
  graphs.append(g1)
  dv.plot(g1,$o1.o(0))
  g2 = new Graph()
  graphs.append(g2)
  ddv.plot(g2,dv)
  return graphs
}

//sets up recording of spike times via new_printlist_nc
//$o1 = cell
//$o2 = printlist obj (output; replaced with what new_printlist_nc returns)
//$s3 = string
//$4  = spike thresh [optional]
//calls with any other number of arguments do nothing
proc RecordCell(){
  if(numarg()==4){
    $o2 = new_printlist_nc($o1,0,$s3,$4)
  } else if(numarg()==3){
    $o2 = new_printlist_nc($o1,0,$s3)
  }
}

//create an empty vector after first resizing it to $2 elements, intended to
//pre-allocate room for later appends
//NOTE(review): relies on Vector.resize(0) keeping the grown buffer — confirm
//$o1 = vec (output; replaced with the new vector)
//$2  = number of elements to pre-allocate
proc VecMalloc(){ localobj v
  v = new Vector()
  v.resize($2)
  v.resize(0)
  $o1 = v
}

//print the 4 values produced by GetBurstInfo
//$o1 = [num bursts, spikes per burst, inter-spike time, inter-burst time]
//prints an error instead if $o1 has fewer than 4 elements
proc PrintBurstInfo(){ localobj bi
  bi = $o1
  if(bi.size >= 4){
    printf("num bursts = %d\n",bi.x(0))
    printf("spikes per burst = %d\n",bi.x(1))
    printf("inter-spike time = %g\n",bi.x(2))
    printf("inter-burst time = %g\n",bi.x(3))
  } else {
    printf("invalid size for burst info vector\n")
  }
}

//summarize bursting in a spike train
//$o1 = spike times
//$o2 = output vector[num_bursts,spikes_per_burst,inter_spike_time,inter_burst_time]
//      (replaced with a new size-4 vector; entries not computed are left at
//      whatever Vector.resize fills in — presumably 0, confirm)
//intervals >= lg_thresh count as inter-burst gaps, intervals <= sm_thresh as
//within-burst inter-spike gaps; intervals in between are ignored.
//num_bursts = (# inter-burst gaps)+1, spikes_per_burst = #spikes/num_bursts,
//and the two times are the medians of each interval class.
//returns 1 on success, 0 if there are < 2 spikes or no interval fell into
//either class
func GetBurstInfo(){ local i,lg_thresh,sm_thresh,spikes_per_burst localobj v_inter_spike,v_inter_burst,v_intervals

  lg_thresh = 20   //min time btwn bursts
  sm_thresh = 3.55  //max time btwn spikes within a burst

  if($o1.size < 2){
    $o2 = new Vector()
    $o2.resize(4)
    $o2.x(0)=$o2.x(1)=$o2.x(2)=$o2.x(3)=0
    printf("need at least two spikes to get burst info\n")
    return 0
  }

  //successive spike-time differences
  v_intervals = new Vector() intervals($o1,v_intervals)
  //printf("here are the intervals:\n")
  //v_intervals.printf

  //VecMalloc replaces these localobjs with empty vectors (objref args are
  //passed by reference)
  VecMalloc(v_inter_spike,v_intervals.size)
  VecMalloc(v_inter_burst,v_intervals.size)

  //classify each interval; ones strictly between the thresholds are dropped
  for(i=0;i<v_intervals.size;i+=1){
    if(v_intervals.x(i) >= lg_thresh){
      //printf("\tinter-burst time = %g\n",v_intervals.x(i))
      v_inter_burst.append(v_intervals.x(i))
    } else if(v_intervals.x(i) <= sm_thresh) {
      //printf("\tinter-spike time = %g\n",v_intervals.x(i))
      v_inter_spike.append(v_intervals.x(i))
    }
  }

  $o2 = new Vector() $o2.resize(4)

  //if no valid inter-spike / inter-burst found, return 0
  if(v_inter_spike.size == 0 && v_inter_burst.size == 0) {
    printf("\tno valid inter-spike / inter-burst found\n")
    return 0
  }

  //if( ((v_inter_burst.max-v_inter_burst.min) / v_inter_burst.max) >= 0.3){
    //printf("\tno consistent inter-burst time\n")
  //  return 0
  //}

  $o2.x(0) = v_inter_burst.size() + 1 //# of bursts
  $o2.x(1) = $o1.size / $o2.x(0)  //# of spikes per burst, assumes they are equally distributed

  //median only when the class is non-empty; otherwise the entry is untouched
  if(v_inter_spike.size > 0) { $o2.x(2) = v_inter_spike.median() }  //inter-spike time
  if(v_inter_burst.size > 0) { $o2.x(3) = v_inter_burst.median() }  //inter-burst time

  //if no valid inter-spike / inter-burst found, return 0
  //if($o2.x(2) == 0 && $o2.x(3) == 0) {
  //  printf("\tno valid inter-spike / inter-burst found\n")
  //  return 0
  //}

  return 1
}

//sweep stimulus amplitude and collect burst statistics at each amplitude
//$o1 = stim (object with del, dur and amp fields)
//$2 = min amp
//$3 = max amp
//$4 = amp inc (must be > 0)
//$o5 = vector of # of bursts (output, replaced)
//$o6 = vector of # of spikes per burst (output, replaced)
//$o7 = inter-spike times (output, replaced)
//$o8 = inter-burst times (output, replaced)
//$o9 = cell
//for each amp from min to max (inclusive) in steps of inc: sets $o1.amp,
//calls run(), and appends GetBurstInfo's 4 values to $o5..$o8
//side effects: sets $o1.del=40, $o1.dur=400 and global tstop=500
//returns 1, or 0 (with a message) if the amp increment is not positive
func GetBurstInfoVecs(){ local curamp,minamp,maxamp,ampinc,nsteps,istep localobj cell_rec,v_burst
   minamp = $2 maxamp = $3 ampinc = $4
   //a non-positive increment previously looped forever
   if(ampinc <= 0){
     printf("GetBurstInfoVecs ERRA: amp inc must be > 0\n")
     return 0
   }
   RecordCell($o9,cell_rec,"cell.soma.v")
   $o5 = new Vector()
   $o6 = new Vector()
   $o7 = new Vector()
   $o8 = new Vector()
   v_burst = new Vector()
   $o1.del = 40
   $o1.dur = 400
   tstop = 500
   //step with an integer counter so float drift in repeated curamp+=ampinc
   //cannot drop the final step; 1e-6 absorbs round-off in the ratio.
   //nsteps < 0 (min > max) runs no steps, as before
   nsteps = floor((maxamp-minamp)/ampinc + 1e-6)
   for(istep=0;istep<=nsteps;istep+=1){
     curamp = minamp + istep*ampinc
     $o1.amp = curamp
     printf("getting burst info for amp = %g\n",curamp)
     run()
     GetBurstInfo(cell_rec.tvec,v_burst)
     $o5.append(v_burst.x(0))
     $o6.append(v_burst.x(1))
     $o7.append(v_burst.x(2))
     $o8.append(v_burst.x(3))
   }
   return 1
}

//plot burst statistics against stimulus amplitude on one graph
//$o1 = graph (replaced with a new Graph when $8 is nonzero)
//$o2 = vector of # of bursts           (color 3)
//$o3 = vector of # of spikes per burst (color 1)
//$o4 = inter-spike times               (color 4)
//$o5 = inter-burst times               (color 2)
//$o6 = amp (x axis)
//$7 , whether to label the vectors
//$8 , whether to make new graph
proc PlotBurstInfoVecs(){ localobj gr,amp
  if($8) $o1 = new Graph()

  if($7){
    $o2.label("num bursts")
    $o3.label("spikes per burst")
    $o4.label("inter-spike time")
    $o5.label("inter-burst time")
  }

  gr = $o1
  amp = $o6
  //arguments: graph, x vector, color, brush
  $o3.plot(gr,amp,1,1)
  $o5.plot(gr,amp,2,1)
  $o2.plot(gr,amp,3,1)
  $o4.plot(gr,amp,4,1)

  gr.flush()
}

//detect spike peak times in a sampled voltage trace
//$o1 = voltage vector (x-axis=time, y-axis=voltage), one sample per dt
//$o2 = output spike times (replaced with a new vector)
//$3 = threshold for spike
//$4 = minimum interspike time
//a spike is triggered by a sample >= $3 that is at least $4 after the last
//stored spike time; the trace is then walked forward to its local peak, whose
//time (dt*index) is stored, and then forward again until it falls to
//<= dipthresh before searching resumes
proc GetSpikeTimes() { local ii,v,dipthresh
  $o2 = new Vector($o1.size)
  $o2.resize(0)
  //NOTE(review): 1.5*threshold lies below $3 only when $3 is negative
  //(e.g. a negative mV threshold) — confirm intent for thresholds >= 0
  dipthresh = 1.5 * $3
  //ii is also advanced inside the body (peak and dip searches), so the
  //for-increment resumes scanning just after the dip sample
  for(ii=0;ii<$o1.size;ii+=1) {
    v = $o1.x(ii)
    if(v >= $3) { //is v >= threshold?
      if($o2.size>0) { //make sure at least $4 time has passed
        if(dt*ii-$o2.x($o2.size-1) < $4) {
          continue
        }
      }
      while( ii+1<$o1.size) { //get peak of spike
        if( $o1.x(ii) > $o1.x(ii+1) ) {
          break
        }
        ii += 1
      }
      $o2.append(dt*ii) //store spike location

      while(ii<$o1.size) { //must dip down sufficiently
        if($o1.x(ii) <= dipthresh) {
          break
        }
        ii += 1
      }
    }
  }
}

//mark the samples that lie within a time window around each spike
//$o1 = voltage vector (one sample per dt; only its size is used)
//$o2 = spike times
//$3 = 1/2 time window around spikes [optional, default=1.25ms]
//returns a vector the size of $o1 holding 1 at every sample index in
//[int((st-$3)/dt), int((st+$3)/dt)] (clipped to the vector) for some spike
//time st, and 0 elsewhere
obfunc GetSpikeWindow(){ local left,right,idx,jdx,st,twin2 localobj vspikes,vvolt,vns

  vvolt = $o1
  vspikes = $o2

  if(numarg() > 2) twin2 = $3 else twin2 = 1.25

  vns = new Vector(vvolt.size)

  for(idx = 0; idx < vspikes.size; idx+=1){

    st = vspikes.x(idx)

    //integer sample bounds: fractional loop bounds used to skip the last
    //sample of the window
    left = int((st - twin2)/dt)
    if(left < 0) left = 0
    right = int((st + twin2)/dt)
    if(right >= vvolt.size) right = vvolt.size-1

    //a spike past the end of $o1 gives left > right and marks nothing
    //(the removed, unused reads of vvolt.x(left/right) crashed there)
    for(jdx=left;jdx<=right;jdx+=1){
      vns.x(jdx) = 1
    }
  }

  return vns
}



//replace each spike in a voltage trace by a straight line between the trace
//values where the spike starts and ends
//$o1 = voltage vector (one sample per dt)
//$o2 = spike times
//spike edges: start 5 samples either side of each spike time and move outward
//while v' (left side) or |v'| (right side) stays >= dthresh, where
//v' = Deriv($o1) and dthresh = mean(v') + stdev(v')/4
//returns a new vector; $o1 is not modified
obfunc CutOutSpikes(){ local leftv,rightv,meanv,dthresh,left,right,w,idx,jdx,st localobj vspikes,vvolt,vns,vderiv

  vvolt = $o1
  vspikes = $o2

  vderiv = Deriv(vvolt)

  meanv = vderiv.mean
  dthresh = meanv+vderiv.stdev/4

  printf("dthresh = %f\n",dthresh)

  vns = new Vector(vvolt.size)
  vns.copy(vvolt)

  for(idx = 0; idx < vspikes.size; idx+=1){

    st = vspikes.x(idx)

    left = st/dt - 5
    if(left < 0) left = 0
    right = st/dt + 5
    if(right >= vderiv.size) right = vderiv.size-1

    //spike lies past the end of the trace: nothing to cut (and indexing
    //vderiv/vvolt at left would be out of range)
    if(left > right) continue

    //NOTE(review): left edge tests v' while right edge tests |v'| — confirm
    while(left > 0){
      if(vderiv.x(left) < dthresh){
        break
      }
      left-=1
    }

    while(right < vvolt.size-1){
      if(abs(vderiv.x(right)) < dthresh){
        break
      }
      right+=1
    }

    w = right - left
    //single-sample span: nothing to interpolate, and dividing by w below
    //would be 0/0 (the old "if(w == 0) w = 0" did not prevent that)
    if(w <= 0) continue

    leftv = vvolt.x(left)
    rightv = vvolt.x(right)

    //linear ramp: leftv at jdx=left, rightv at jdx=right
    for(jdx=left;jdx<=right;jdx+=1){
      vns.x(jdx) = ((w-(jdx-left))*leftv)/w + ((w-(right-jdx))*rightv)/w
    }
  }

  return vns
}

//plot a voltage trace against time and mark each spike time in red
//$o1 = voltage vector (x-axis=time, y-axis=voltage), one sample per dt
//$o2 = spike time vector
//$o3 = graph obj
proc PlotSpikeTimes(){ local i,ts
  $o1.plot($o3,dt)
  for i=0,$o2.size-1 {
    ts = $o2.x(i)
    $o3.mark(ts,$o1.x(ts/dt),1,4,red,1)
  }
}

//get interspike interval windows, trimmed by 1.25 at each end
//$o1 = spike train
//$o2 = output ISI start times (replaced)
//$o3 = output ISI end times (replaced)
//for consecutive spikes a,b the window is [a+1.25, b-1.25]; when the
//trimmed window is empty or inverted it becomes [b-1.25, a+1.25+dt]
//returns 1, or 0 (with a message) when there are fewer than two spikes
func GetISITimes(){ local i,mindist,t0,t1
  mindist = 1.25 //min dist after & before spike to go
  $o2 = new Vector()
  $o3 = new Vector()
  if($o1.size < 2){
    printf("need at least two spikes to get ISI\n")
    return 0
  }
  for(i=0;i<$o1.size-1;i+=1){
    t0 = $o1.x(i)+mindist
    t1 = $o1.x(i+1)-mindist
    if(t1 <= t0){
      $o2.append(t1)
      $o3.append(t0+dt)
    } else {
      $o2.append(t0)
      $o3.append(t1)
    }
  }
  return 1
}

//get trajectory v,v',v'' over a time window
//$o1 = voltage vector (one sample per dt)
//$2 = start time
//$3 = end time
//$o4 = output copy of v over samples start/dt .. end/dt (replaced)
//$o5 = 1st deriv (v') (replaced, via Deriv)
//$o6 = 2nd deriv (v'') (replaced, via Deriv)
//returns 1 on success, 0 (with a message) if the range is invalid or too short
func GetTrajectory(){ local startt,endt,startidx,endidx
  $o4 = new Vector()  $o5 = new Vector() $o6 = new Vector()

  startt = $2
  endt = $3

  if(startt >= endt){
    printf("GetTrajectory ERRA: invalid time range (%f,%f)!\n",$2,$3)
    return 0
  }

  startidx = startt / dt
  endidx = endt / dt

  $o4.copy($o1,startidx,endidx) //copy v
  if($o4.size < 2){
    $o4.resize(0)
    //report the window in time units (this used to print the sample-index
    //span while labelling it "ms")
    printf("GetTrajectory ERRB: time range %f ms too small!\n",(endt-startt))
    return 0
  }

  $o5 = Deriv($o4)
  $o6 = Deriv($o5)

  return 1
}

//get all trajectories from start/end time vecs via GetTrajectory
//$o1 = voltage vector
//$o2 = start times
//$o3 = end times
//$o4 = list of v   (successful trajectories are appended)
//$o5 = list of v'
//$o6 = list of v''
//returns 1 if at least one window succeeded, else 0
func GetTrajectories(){ local i,nok localobj tv,tdv,tddv
  nok = 0
  for(i=0;i<$o2.size;i+=1){
    tv = new Vector() tdv = new Vector() tddv = new Vector()
    if(!GetTrajectory($o1,$o2.x(i),$o3.x(i),tv,tdv,tddv)) continue
    $o4.append(tv)
    $o5.append(tdv)
    $o6.append(tddv)
    nok += 1
  }
  return nok > 0
}

//plot phase-plane views of one trajectory
//$o1 = graph1: v' vs v (a new Graph is created if unassigned)
//$o2 = graph2: v'' vs v' (a new Graph is created if unassigned)
//$o3 = v
//$o4 = v'  (labelled "v'")
//$o5 = v'' (labelled "v''")
//returns 1
func PlotTrajectory(){ localobj dv,ddv
  dv = $o4
  ddv = $o5

  if(!isassigned($o1)) $o1 = new Graph()
  dv.label("v'")
  dv.plot($o1,$o3)

  if(!isassigned($o2)) $o2 = new Graph()
  ddv.label("v''")
  ddv.plot($o2,dv)

  return 1
}

//plot every trajectory with PlotTrajectory, two new graphs per trajectory
//$o1 = list of v
//$o2 = list of v'
//$o3 = list of v''
//$o4 = list of graphs (the two graphs for each trajectory are appended)
//returns 1
func PlotTrajectories(){ local k localobj gphase1,gphase2
  for k=0,$o1.count-1 {
    gphase1 = new Graph()
    gphase2 = new Graph()
    PlotTrajectory(gphase1,gphase2,$o1.o(k),$o2.o(k),$o3.o(k))
    $o4.append(gphase1)
    $o4.append(gphase2)
  }
  return 1
}

//Euclidean distance between 3D points ($1,$2,$3) and ($4,$5,$6)
func EuclidDist(){ local dx,dy,dz
  dx = $1-$4
  dy = $2-$5
  dz = $3-$6
  return sqrt(dx^2 + dy^2 + dz^2)
}

//Euclidean distance between two point sets given as coordinate vectors,
//computed by LDist_sn (in /usr/site/nrniv/local/mod/samnutils.mod)
//$o1,$o2,$o3 = vecs of x1,y1,z1
//$o4,$o5,$o6 = vecs of x2,y2,z2
//returns LDist_sn's result, or -1 (with a message) if samnutils.mod is absent
func VEuclidDist(){ localobj pa,pb
  if(use_samn_utils){
    pa=new List()
    pa.append($o1) pa.append($o2) pa.append($o3)
    pb=new List()
    pb.append($o4) pb.append($o5) pb.append($o6)
    return LDist_sn(pa,pb)
  }
  printf("VEuclidDist ERRA: samnutils.mod not available\n")
  return -1
}

//squared-difference distance between two point sets given as coordinate
//vectors, computed by LDist_sn with the SQDIFF_sn mode
//(in /usr/site/nrniv/local/mod/samnutils.mod)
//$o1,$o2,$o3 = vecs of x1,y1,z1
//$o4,$o5,$o6 = vecs of x2,y2,z2
//returns LDist_sn's result, or -1 (with a message) if samnutils.mod is absent
func VSQDist(){ localobj pa,pb
  if(use_samn_utils){
    pa=new List()
    pa.append($o1) pa.append($o2) pa.append($o3)
    pb=new List()
    pb.append($o4) pb.append($o5) pb.append($o6)
    return LDist_sn(pa,pb,SQDIFF_sn)
  }
  printf("VSQDist ERRA: samnutils.mod not available\n")
  return -1
}

//resample a vec in place to a new size using linear interpolation
//the first and last input samples map onto the first and last output
//samples, and interior outputs are evenly spaced between them
//$o1 = vec (modified in place)
//$2 = new size (>= 1; a size of 1 keeps the last sample, as before)
//returns 1 on success, 0 (with a message) on empty input or new size < 1
func Resample(){ local newsz,oldsz,idxdest,idxsrc,fctr,frac localobj vtmp
  newsz = $2
  oldsz = $o1.size
  if(newsz < 1 || oldsz < 1){
    printf("Resample ERRA: need non-empty vec and new size >= 1\n")
    return 0
  }
  vtmp = new Vector(newsz)
  if(newsz > 1){
    //map dest 0..newsz-1 onto src 0..oldsz-1; the old oldsz/newsz factor
    //misaligned interior points (e.g. 2->4 gave [a,(a+b)/2,b,b])
    fctr = (oldsz-1) / (newsz-1)
    for(idxdest=0;idxdest<newsz-1;idxdest+=1){
      idxsrc = idxdest * fctr
      frac = idxsrc - int(idxsrc)
      idxsrc = int(idxsrc)
      if(idxsrc+1 >= oldsz){
        vtmp.x(idxdest) = $o1.x(oldsz-1)
      } else {
        vtmp.x(idxdest) = (1-frac) * $o1.x(idxsrc) + frac * $o1.x(idxsrc+1)
      }
    }
  }
  vtmp.x(newsz-1) = $o1.x(oldsz-1)
  $o1.resize(newsz)
  $o1.copy(vtmp)
  return 1
}

//sweep stimulus amplitude and record firing rate (spike count per second of
//stimulus, taking $o1.dur as ms) at each amplitude
//$o1 = stim (object with amp and dur fields)
//$o2 = printlist obj (recreated by RecordCell at every amplitude)
//$3 = min amp
//$4 = max amp
//$5 = amp inc (must be > 0)
//$6 = threshold for spike
//$o7 = amp vec (output, replaced)
//$o8 = f vec (output, replaced)
//$o9 = cell
//calls run() once per amplitude; prints an error and returns if $5 <= 0
proc GetIFVecs() { local minamp,maxamp,ampinc,totalt,curamp,thresh,nsteps,istep
  minamp = $3 maxamp = $4 ampinc = $5 thresh = $6
  //a non-positive increment previously looped forever
  if(ampinc <= 0){
    printf("GetIFVecs ERRA: amp inc must be > 0\n")
    return
  }
  $o7 = new Vector()
  $o8 = new Vector()
  totalt = $o1.dur / 1000
  //integer step counter so float drift in repeated curamp+=ampinc cannot
  //drop the final step; nsteps < 0 (min > max) runs no steps, as before
  nsteps = floor((maxamp-minamp)/ampinc + 1e-6)
  for(istep=0;istep<=nsteps;istep+=1){
    curamp = minamp + istep*ampinc
    $o1.amp = curamp
    RecordCell($o9,$o2,"soma.v",thresh)
    run()
    $o7.append(curamp)
    $o8.append($o2.tvec.size * (1/totalt))
  }
}

//plot an I-F curve in a new Graph
//$o1 = frequency vector (labelled "Hz", y axis)
//$o2 = amplitude vector (labelled "nA", x axis)
proc PlotIFVec(){ localobj gr
  $o2.label("nA")
  $o1.label("Hz")
  gr = new Graph()
  $o1.plot(gr,$o2)
}

//save an I-F curve (amp vector then freq vector) with Vector.vwrite
//$s1 = output file name
//$o2 = amp vec
//$o3 = freq vec
//returns 1 on success, 0 on failure; the file is closed on every path
//(previously it was left open when a write failed)
func SaveIFCurve(){ localobj f
  f = new File()
  f.wopen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for writing\n",$s1)
    return 0
  }
  if(!$o2.vwrite(f)){
    printf("couldn't write amp vec to %s\n",$s1)
    f.close()
    return 0
  }
  if(!$o3.vwrite(f)){
    printf("couldn't write freq vec to %s\n",$s1)
    f.close()
    return 0
  }
  f.close()
  return 1
}

//read an I-F curve written by SaveIFCurve (amp vector then freq vector)
//$s1 = input file name
//$o2 = amp vec (output, replaced)
//$o3 = freq vec (output, replaced)
//returns 1 on success, 0 on failure; the file is closed on every path
//(previously it was left open when a read failed)
func ReadIFCurve(){ localobj f
  f = new File()
  f.ropen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for reading\n",$s1)
    return 0
  }
  $o2 = new Vector()
  if(!$o2.vread(f)){
    printf("couldn't read amp vec from %s\n",$s1)
    f.close()
    return 0
  }
  $o3 = new Vector()
  if(!$o3.vread(f)){
    printf("couldn't read freq vec from %s\n",$s1)
    f.close()
    return 0
  }
  f.close()
  return 1
}

//writes vecs to file, in argument order, with Vector.vwrite
//$s1 = path to file
//$o2 ... $on = vectors to be written to file
//returns 1 on success, 0 on failure; the file is closed on every path
//(previously it was left open when a write failed)
func WriteVecs(){ local i localobj f
  f = new File()
  f.wopen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for writing\n",$s1)
    return 0
  }
  for(i=2;i<=numarg();i+=1){
    printf("writing vector %d\n",i-1)
    if(!$oi.vwrite(f)){
      printf("couldn't write vector %d to %s\n",i-1,$s1)
      f.close()
      return 0
    }
  }
  f.close()
  return 1
}

//reads vecs from file, in argument order, with Vector.vread
//$s1 = path to file
//$o2 ... $on = vectors to be read from file (must already be Vector objects)
//returns 1 on success, 0 on failure; the file is closed on every path
//(previously it was left open when a read failed)
func ReadVecs(){ local i localobj f
  f = new File()
  f.ropen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for reading\n",$s1)
    return 0
  }
  for(i=2;i<=numarg();i+=1){
    printf("reading vector %d\n",i-1)
    if(!$oi.vread(f)){
      printf("couldn't read vector %d from %s\n",i-1,$s1)
      f.close()
      return 0
    }
  }
  f.close()
  return 1
}

//read the $2-th (0-based) vector from a file of vwrite'd vectors
//$s1 = path to file
//$2 = index of vector to read
//$o3 = vector to read to (created if unassigned)
//returns 1 on success, 0 if the file can't be opened or holds fewer than
//$2+1 vectors
func ReadVecX(){ local idx localobj f
  f = new File()
  f.ropen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for reading\n",$s1)
    return 0
  }
  if(!isassigned($o3)) $o3 = new Vector()
  //read vectors 0..$2 in turn, each overwriting $o3
  for(idx=0;idx<=$2;idx+=1){
    if(f.eof()){
      f.close()
      return 0
    }
    //eof is not set until a read actually runs past the end, so a failed
    //vread must also be treated as "not enough vectors" (it was ignored and
    //the function returned 1 with stale data)
    if(!$o3.vread(f)){
      f.close()
      return 0
    }
  }
  f.close()
  return 1
}

//return 1 if file $s1 can be opened for reading (it is closed again
//immediately), 0 otherwise
func FileExists(){ localobj fh
  fh = new File()
  fh.ropen($s1)
  if(!fh.isopen()) return 0
  fh.close()
  return 1
}

//reads vecs from file
//$s1 = path to file
//$o2 ... $on = vectors to be read from file
/*func Read2DVec(){ local i localobj f,cv
  f = new File()
  f.ropen($s1)
  if(!f.isopen()) {
    printf("couldn't open %s for reading\n",$s1)
    return 0
  }

  //first vector just has # of vectors in file
  cv = new Vector()
  if(!cv.vread(f)){
    printf("couldn't get # vectors in file\n")
    return 0
  }
  printf("there are %d vecs in file\n",cv.size)
  
  for(i=0;i<cv.size;i+=1){
    printf("reading vector %d\n",i+1)
    $3[i] = new Vector()
    if(!$3[i].vread(f)){
      printf("couldn't read vector %d from %s\n",i+1,$s1)
      return 0
    }
  }
  f.close()
  return 1
}*/

//write the five burst-info vectors to file $s1 (delegates to WriteVecs)
//$s1 = file path
//$o2 = vector of # of bursts
//$o3 = vector of # of spikes per burst
//$o4 = inter-spike times
//$o5 = inter-burst times
//$o6 = amp
//returns WriteVecs' result (1 success, 0 failure)
func WriteBurstInfoVecs(){ return WriteVecs($s1,$o2,$o3,$o4,$o5,$o6) }

//read the five burst-info vectors from file $s1 (delegates to ReadVecs)
//$s1 = file path
//$o2 = vector of # of bursts
//$o3 = vector of # of spikes per burst
//$o4 = inter-spike times
//$o5 = inter-burst times
//$o6 = amp
//returns ReadVecs' result (1 success, 0 failure)
func ReadBurstInfoVecs(){ return ReadVecs($s1,$o2,$o3,$o4,$o5,$o6) }

///////////////////////
//random related funcs

//global random stream shared by uniform() and GetWhiteNoise()
my_ran_ind = 0
objref r_uniform

//(re)create r_uniform with a fixed seed, the MCellRan4 generator and a
//uniform(0,1) distribution; my_ran_ind receives MCellRan4's return value
proc InitRand(){ localobj rng
  rng = new Random(1234567891011121314)
  my_ran_ind = rng.MCellRan4()
  rng.uniform(0,1)
  r_uniform = rng
}

//next draw from the global r_uniform stream (set up by InitRand)
func uniform(){ local u
  u = r_uniform.repick()
  return u
}

//approximately Gaussian noise: the sum of white_noise_N uniform() draws,
//shifted and scaled to the requested mean and stddev
//$1 - mean
//$2 - stddev
//white_noise_N may be changed at runtime; the scale factor sqrt(12/N) is
//recomputed on every call (the load-time sqrt_val went stale if N changed,
//giving the wrong variance). sqrt_val is kept for any external users.
//returns $1 unchanged if white_noise_N < 1
white_noise_N = 20
sqrt_val = sqrt(12/white_noise_N)
func GetWhiteNoise(){ local X,N,i,mu,std
  mu=$1
  std=$2

  N=white_noise_N
  if(N < 1) return mu
  X=0

  for(i=0;i<N;i+=1){
    X += uniform()
  }

  //uniform(0,1) draws have mean 0.5 and var 1/12, so the sum of N has
  //mean N/2 and var N/12: shift to mean 0, then scale to var 1
  X -= N/2
  X *= sqrt(12/N)

  return mu + std*X
}

//seed the global random stream at load time so uniform() and
//GetWhiteNoise() can be used immediately
InitRand()

//create a vector of $1 GetWhiteNoise samples
//$1 = size
//$2 = mean
//$3 = stddev
//$o4 = output vector (replaced)
proc WhiteNoiseVec(){ local k localobj out
  out = new Vector($1)
  for(k=0;k<out.size;k+=1) out.x(k) = GetWhiteNoise($2,$3)
  $o4 = out
}

//call savevec() on every vector in list $o1, in list order
proc SaveListOfVecs(){ local k localobj lst
  lst = $o1
  for(k=0;k<lst.count;k+=1) savevec(lst.o(k))
}

//centered moving-average smoothing of a vector, in place
//$o1 = vector (modified in place)
//$2  = half-width w of the window in samples [optional, default=5]
//each sample idx with w <= idx < size-w becomes the mean of the 2w+1 samples
//idx-w..idx+w; the first and last w samples are left unchanged
proc Smooth() { local idx,w,val,jdx,cnt localobj myv,myvtmp
  myv=$o1
  if(numarg()>1) w=$2 else w=5
  if(w < 1) return
  myvtmp=new Vector(myv.size)
  myvtmp.copy(myv)
  for(idx=w;idx<myv.size-w;idx+=1){
    cnt=0
    val=0
    //inclusive upper bound: stopping at idx+w-1 made the window lopsided
    //(w samples on the left, w-1 on the right)
    for(jdx=idx-w;jdx<=idx+w;jdx+=1){
      val += myv.x(jdx)
      cnt += 1
    }
    myvtmp.x(idx) = val / cnt
  }
  myv.copy(myvtmp)
}

//exponential-kernel smoothing of a vector, in place
//$o1 = vector (modified in place)
//$2  = step used to set the kernel decay constant taus = 3/$2 (in samples)
//      [optional, default uses the global dt]
//kernel half-width w = int(taus/2); each sample idx with w <= idx < size-w
//becomes sum_{k=-w..w} x(idx+k)*exp(-|k|/taus) divided by the kernel sum;
//the first and last w samples are left unchanged
proc Smooth2(){ local idx,w,val,jdx,wght,taus localobj myv,myvtmp
  myv=$o1
  //taus must be set before w is derived from it: it used to be read while
  //still 0, so w was 0, no sample was ever written and the whole vector
  //was overwritten with zeros
  if(numarg()>1) taus=3/$2 else taus=3/dt
  w = int(taus/2)
  if(w < 1) return
  myvtmp=new Vector(myv.size)
  myvtmp.copy(myv)
  //kernel sum over offsets -w..w (symmetric, matching the loop below)
  wght=0
  for(idx=0;idx<=2*w;idx+=1) wght += exp(-abs(idx-w)/taus)
  for(idx=w;idx<myv.size-w;idx+=1){
    val=0
    for(jdx=idx-w;jdx<=idx+w;jdx+=1){
      val += myv.x(jdx) * exp(-abs(jdx-idx)/taus)
    }
    myvtmp.x(idx) = val / wght
  }
  myv.copy(myvtmp)
}

//does smoothing of 1d vector using a symmetric exp kernel, in place
//$o1 = input vector , (also output)
//$2 = 1/2 width of kernel in samples [optional, default=4]
//$3 = tstep: per-sample decay; sample idx+k gets weight exp(-|k|*tstep)
//     [optional, default=1]
//each sample idx with w <= idx < size-w becomes the weighted sum of
//x(idx-w)..x(idx+w) divided by the kernel sum; the first and last w samples
//are left unchanged
proc ExpSmooth(){ local idx,w,val,jdx,wght,myt,tstep localobj myv,myvtmp
  myv = $o1
  if(numarg()>1) w = $2 else w = 4
  //tstep is local now (it used to leak into a global variable)
  if(numarg()>2) tstep = $3 else tstep=1
  myvtmp=new Vector(myv.size)
  myvtmp.copy(myv)
  //kernel sum: center weight 1 plus two sides of exp(-k*tstep), k=1..w
  //(previously every side term used exp(-tstep), mis-normalizing the output)
  wght = 0
  myt = tstep
  for(idx=0;idx<w;idx+=1){
    wght += exp(-myt)
    myt += tstep
  }
  wght = 2*wght + 1
  if(mydbg) printf("weight=%f\n",wght)
  for(idx=w;idx<myv.size-w;idx+=1){
    //center sample once with weight 1, then matching left/right pairs
    //(previously the center was counted twice and x(idx+w) was never used)
    val = myv.x(idx)
    myt = tstep
    for(jdx=1;jdx<=w;jdx+=1){
      val += (myv.x(idx-jdx) + myv.x(idx+jdx)) * exp(-myt)
      myt += tstep
    }
    myvtmp.x(idx)=val/wght
  }
  myv.copy(myvtmp)
}

//fix points from EEG trace where voltage diff
//between sample and median is > thresh since it's probably
//a problem with recording equipment ... assumes only 1 point
//in a row will be an outlier (since it replaces point with
//average of surrounding two neighbors, or the single neighbor at an end)
//$o1 = vector (modified in place)
//$2 = thresh [optional, default=20000]
//$3 = whether to only fix endpoints [optional, default=0]
//returns -1 on error, 1 if fixed any, 0 otherwise (also 0 for an empty vector)
func FixEEGOutliers(){ local idx,thresh,med,ept,fixedany localobj vtmp

  if(numarg() < 1){
    printf("FixEEGOutliers ERRA: incorrect # of args\n")
    return -1
  }

  if(!isobj($o1,"Vector")){
    printf("FixEEGOutliers ERRB: require vector arg\n")
    return -1
  }

  vtmp = $o1

  //nothing to fix; the median of an empty vector is undefined
  if(vtmp.size == 0) return 0

  med = vtmp.median

  thresh = 20000

  fixedany = 0

  if(numarg() > 1) if(argtype(2)==0) thresh = $2

  ept = 0//default don't ONLY fix endpoints
  if(numarg() > 2) if(argtype(3)==0) ept = $3

  if(mydbg) printf("FixEEGOutliers: thresh = %d\n",thresh)

  for(idx=0;idx<vtmp.size;idx+=1){
    if(abs(vtmp.x(idx)-med) >=thresh){

      fixedany = 1

      printf("FixEEGOutliers: fixing val=%f at idx=%d, ",vtmp.x(idx),idx)

      //take average of surrounding two points
      if(idx + 1 < vtmp.size && idx > 0){
        vtmp.x(idx) = ( vtmp.x(idx-1) + vtmp.x(idx+1) ) / 2
      } else if(idx > 0){ //take value from left
        vtmp.x(idx) = vtmp.x(idx - 1)
      } else if(idx + 1 < vtmp.size){ //take value from right
        vtmp.x(idx) = vtmp.x(idx+1)
      }
      printf("new val=%f, med=%f\n",vtmp.x(idx),med)
    }
    //endpoint-only mode: after index 0 jump to the last index. Only jump
    //when that moves forward: for a size-1 vector the old code set idx=-1
    //and revisited index 0 forever
    if(ept && idx==0 && vtmp.size > 2) idx = vtmp.size - 2
  }

  return fixedany
}

///////////////////////////////////////////////////////////////////////////
//////////////////////// filter functions /////////////////////////////////
//build a normalized alpha-function filter kernel
//$o1 = filter vector (output, replaced)
//$2 = size (number of samples)
//$3 = tau
//$4 = inc: time step; sample k holds t*exp(-t/tau) with t = inc, 2*inc, ...
//     (t accumulated by repeated addition)
//prints tau and inc; the kernel is divided by its sum
proc AlphaFilter(){ local k,inc,tau,sz,t
  sz = $2
  tau = $3
  inc = $4
  t = inc
  $o1 = new Vector(sz)
  printf("tau=%f inc=%f\n",tau,inc)
  for(k=0;k<sz;k+=1){
    $o1.x(k) = t * exp(-t/tau)
    t += inc
  }
  $o1.div($o1.sum)
}

//build a normalized symmetric kernel from the falling half of an alpha filter
//$o1 = filter vector (output, replaced)
//$2 = size of ~ 1/2 filter (passed to AlphaFilter)
//$3 = tau
//$4 = inc
//the alpha kernel is cut from its peak to its end; the result is that tail
//reversed followed by the tail itself (so the peak value appears twice in
//the middle), then divided by its sum
proc SymmetricAlphaFilter(){ localobj base,tail
  base=new Vector()
  AlphaFilter(base,$2,$3,$4)

  tail=new Vector()
  tail.copy(base,base.max_ind,base.size-1)

  $o1 = new Vector()
  tail.reverse()
  $o1.copy(tail)
  tail.reverse()
  $o1.append(tail)

  $o1.div($o1.sum)
}

//compute sample (Pearson) correlation
//$o1 = vector # 1
//$o2 = vector # 2
//must have same size (> 1)
//returns the correlation, or 0 (with a message) on size mismatch, size < 2,
//or when either vector is constant (zero stdev would divide by zero)
func cor(){ local idx,m1,m2,sum,denom
  if($o1.size!=$o2.size){
    printf("cross_cor err: input vecs not same size: %d & %d\n",$o1.size,$o2.size)
    return 0
  } else if($o1.size<2){
    printf("cross_cor err: input vecs must have size > 1\n")
    return 0
  }
  m1 = $o1.mean m2 = $o2.mean
  sum = 0
  for(idx=0;idx<$o1.size;idx+=1){
    sum += ($o1.x(idx)-m1)*($o2.x(idx)-m2)
  }
  denom = ($o1.size-1)*$o1.stdev*$o2.stdev
  if(denom == 0){
    printf("cross_cor err: input vec has zero variance\n")
    return 0
  }
  return sum / denom
}

//plot clusters in 2 dimensions
//$o1 = list of vecs (one vec per dimension)
//$o2 = classes (point i is drawn in color $o2.x(i)+2)
//$o3 = graph
//$4 = dim1 (x coordinate source)
//$5 = dim2 (y coordinate source)
//returns 1, or 0 (with a message) if a dimension index is out of range
//points missing from either dimension vector are reported and skipped
func plot_kmeans(){ local i localobj v1,v2
  if($4 >= $o1.count || $5 >= $o1.count){
     printf("idx out of bounds %d %d %d",$4,$5,$o1.count)
     return 0
  }
  v1 = $o1.o($4)
  v2 = $o1.o($5)
  if(mydbg){
    printf("$o1.o(%d).size=%d\n",$4,v1.size)
    printf("$o1.o(%d).size=%d\n",$5,v2.size)
  }
  for(i=0;i<$o2.size;i+=1){
    if(i >= v1.size || i >= v2.size){
      printf("idx out of $o1.o($4) or $o1.o($5) %d %d %d\n",i,v1.size,v2.size)
      continue
    }
    $o3.mark(v1.x(i),v2.x(i),"o",4,$o2.x(i)+2,1)
  }
  return 1
}


///////////////////////////////////////////////////////////////////////////

/////////////////////////////////////////////////////////
//named constants for the codes returned by argtype(i):
//number, objref, strdef, and pointer-to-number arguments
NUMBER_T=0  OBJREF_T=1  STRING_T=2  POINTER_T=3
/////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////
/////////////////// fitness functions for ga /////////////////

//similarity "fitness" of two spike trains (higher = more similar)
//$o1 = spike train 1
//$o2 = spike train 2
//a spike matches if some spike of the other train is within 1.5 of it.
//with n1,n2 the train sizes, m12 the # of train-1 spikes matched in train 2
//and m21 the # of train-2 spikes matched in train 1:
//  fitness = -|n1-n2|^3 + m12^2 - (n1-m12)^2 + m21^2 - (n2-m21)^2
func SpikeTrainFitness(){ local thresh,score,n1,n2,d,m12,m21,i,j
  thresh = 1.5         //max time between "matching" spikes
  n1 = $o1.size
  n2 = $o2.size
  d = abs(n1-n2)

  m12 = 0
  for(i=0;i<n1;i+=1){
    for(j=0;j<n2;j+=1){
      if(abs($o1.x(i)-$o2.x(j))<=thresh){
        m12 += 1
        break
      }
    }
  }

  m21 = 0
  for(j=0;j<n2;j+=1){
    for(i=0;i<n1;i+=1){
      if(abs($o1.x(i)-$o2.x(j))<=thresh){
        m21 += 1
        break
      }
    }
  }

  score = -(d*d*d)
  score += m12*m12
  score -= (n1-m12)*(n1-m12)
  score += m21*m21
  score -= (n2-m21)*(n2-m21)
  return score
}


//legacy globals: dtab was the fixed-size DP table for SpikeTrainEditDist.
//They are kept so other code referencing them still resolves, but
//SpikeTrainEditDist now allocates its own table and has no 512-spike limit.
maxspikes=512
double dtab[maxspikes][maxspikes]

//edit distance between two spike trains (Victor-Purpura D^spike[q] form):
//deleting or inserting a spike costs 1, shifting a spike by dt costs $3*|dt|
//$o1 = spike train 1
//$o2 = spike train 2
//$3 = cost per unit time to move spike
//returns the minimal total cost of turning train 1 into train 2
func SpikeTrainEditDist(){ local cost,nspi,nspj,row,col,shift localobj dmat
  nspi = $o1.size
  nspj = $o2.size
  cost = $3

  if(cost==0){
    return abs(nspi-nspj)   //shifts are free: only the count difference matters
  } else if(cost >= 9e24) {
    return nspi+nspj        //treated as infinite cost: delete all, insert all
  }

  if(nspi==0){
    return nspj
  } else if(nspj==0){
    return nspi
  }

  //dmat.x[r][c] = distance between the first r spikes of $o1 and the first c of $o2.
  //sized per call, so arbitrarily long trains fit (the old global table held only 511 spikes)
  dmat = new Matrix(nspi+1,nspj+1)
  for(row=0;row<nspi+1;row+=1) dmat.x[row][0] = row
  for(col=0;col<nspj+1;col+=1) dmat.x[0][col] = col

  for(row=1;row<nspi+1;row+=1){
    for(col=1;col<nspj+1;col+=1){
      shift = dmat.x[row-1][col-1]+cost*abs($o1.x(row-1)-$o2.x(col-1))
      dmat.x[row][col] = MIN(MIN(dmat.x[row-1][col]+1,dmat.x[row][col-1]+1),shift)
    }
  }
  return dmat.x[nspi][nspj]
}

//L1 (sum of absolute differences) distance between the parameters of two
//mainen cells: rho, kappa and the dend gmax_naz/gmax_Nca/gmax_km/gmax_kca values
//$o1 = cell 1
//$o2 = cell 2
func MainenDist(){ local total
  total = abs($o1.rho-$o2.rho)
  total += abs($o1.kappa-$o2.kappa)
  total += abs($o1.dend.gmax_naz-$o2.dend.gmax_naz)
  total += abs($o1.dend.gmax_Nca-$o2.dend.gmax_Nca)
  total += abs($o1.dend.gmax_km-$o2.dend.gmax_km)
  total += abs($o1.dend.gmax_kca-$o2.dend.gmax_kca)
  return total
}

//score of how close "guess" spike train $o1 is to target spike train $o2:
//(spikes of $o1 with a partner in $o2 - |size difference|) / size of $o1
//$o1 = spike train 1 - "guess" spike train - this tries to reach target
//$o2 = spike train 2 - target spike train - this is the end result we want
//$3  = max time between "matching" spikes (optional, default 0.5)
//returns -($o2.size) when $o1 is empty
func SpikeTrainFitness2(){ local thresh,idx1,idx2,spiket1,matchspikes
  if(numarg()>2) thresh = $3 else thresh = 0.5

  if($o1.size==0){
    return -$o2.size
  }

  //find # of spikes in set 1 that have a match in set 2
  matchspikes = 0
  for(idx1 = 0; idx1 < $o1.size ; idx1 += 1){
    spiket1 = $o1.x(idx1)
    for(idx2 = 0; idx2 < $o2.size ; idx2 += 1){
      if(abs(spiket1-$o2.x(idx2))<=thresh){
        matchspikes += 1
        break
      }
    }
  }

  return (matchspikes/$o1.size) - abs($o1.size-$o2.size) / $o1.size
}

//VSQDist distance between ISI trajectory $4 of set 1 and trajectory $8 of set 2.
//If the two voltage trajectories differ in length, the shorter (v,v',v'') triple
//is resampled with Vector.samp to the longer length before comparing.
//$o1,$o2,$o3 = lists of v1,v1',v1''   $4 = index into them
//$o5,$o6,$o7 = lists of v2,v2',v2''   $8 = index into them
//returns whatever VSQDist returns (callers treat -1 as "no valid distance")
func ISITrajPairDist(){ local i1,i2,n1,n2 localobj a1,a2,a3,b1,b2,b3
  i1 = $4
  i2 = $8
  a1 = $o1.o(i1) a2 = $o2.o(i1) a3 = $o3.o(i1)
  b1 = $o5.o(i2) b2 = $o6.o(i2) b3 = $o7.o(i2)
  n1 = a1.size
  n2 = b1.size
  if(n1 > n2){ //resample set-2 triple to the size of the set-1 trajectory
    b1 = new Vector(n1) b1.samp($o5.o(i2),n1)
    b2 = new Vector(n1) b2.samp($o6.o(i2),n1)
    b3 = new Vector(n1) b3.samp($o7.o(i2),n1)
  } else if(n1 < n2){ //resample set-1 triple to the size of the set-2 trajectory
    a1 = new Vector(n2) a1.samp($o1.o(i1),n2)
    a2 = new Vector(n2) a2.samp($o2.o(i1),n2)
    a3 = new Vector(n2) a3.samp($o3.o(i1),n2)
  }
  return VSQDist(a1,a2,a3,b1,b2,b3)
}

//returns min ISI distance btwn 2 sets of ISI trajectories
//$o1 = list of v1
//$o2 = list of v1'
//$o3 = list of v1''
//$o4 = list of v2
//$o5 = list of v2'
//$o6 = list of v2''
//if no pair yields a valid (!= -1) distance the 9e308 initial value is returned
//(NOTE(review): 9e308 exceeds the double range, so it presumably parses as +inf)
func ISIDistance(){ local idx1,idx2,dist,mindist
  mindist = 9e308
  for(idx1=0;idx1<$o1.count;idx1+=1){
    for(idx2=0;idx2<$o4.count;idx2+=1){
      dist = ISITrajPairDist($o1,$o2,$o3,idx1,$o4,$o5,$o6,idx2)
      if(dist != -1 && dist < mindist){
        mindist = dist
        if(mydbg){
          printf("idx1=%d idx2=%d mindist_so_far=%f\n",idx1,idx2,mindist)
        }
      }
    }
  }
  return mindist
}

//returns sum of min distance between each ISI trajectory in first set to 2nd set
//(trajectories of set 1 with no valid distance to any of set 2 contribute nothing)
//$o1 = list of v1
//$o2 = list of v1'
//$o3 = list of v1''
//$o4 = list of v2
//$o5 = list of v2'
//$o6 = list of v2''
//$&7 = best idx1 (optional) - set 1 index of the overall closest pair
//$&8 = best idx2 (optional) - set 2 index of the overall closest pair
//    (both are left untouched if no pair has a valid distance)
func ISIDistance2(){ local idx1,idx2,dist,rowmin,rowfound,totaldist,savebestidx,bestdist,bestfound
  savebestidx = (numarg()==8)
  totaldist = 0
  bestfound = 0
  bestdist = 0
  for(idx1=0;idx1<$o1.count;idx1+=1){
    rowfound = 0
    rowmin = 0
    for(idx2=0;idx2<$o4.count;idx2+=1){
      dist = ISITrajPairDist($o1,$o2,$o3,idx1,$o4,$o5,$o6,idx2)
      if(dist == -1.0) continue //no valid distance for this pair
      if(!rowfound || dist < rowmin){
        rowmin = dist
        rowfound = 1
        if(mydbg){
          printf("idx1=%d idx2=%d mindist_so_far=%f\n",idx1,idx2,rowmin)
        }
      }
      //track the best pair across ALL rows, not just the current one
      if(savebestidx && (!bestfound || dist < bestdist)){
        bestdist = dist
        bestfound = 1
        $&7 = idx1
        $&8 = idx2
      }
    }
    if(rowfound) totaldist += rowmin
  }
  return totaldist
}

//////////////////////////////////////////////////////////////

///some templates...

//////////////////////////////////////////////////////////////////////////////////////
//template for storing inter-spike interval information:
//ISI start/end times plus per-interval v, v', v'' trajectory lists
begintemplate isi_info

external GetISITimes,GetTrajectories,GetTrajectory
public start_times,end_times,lv1,lv2,lv3
objref start_times,end_times,lv1,lv2,lv3 //start,end isi times , v,v',v'' vectors
public setup

//fill this object from a spike train and its voltage trace
//$o1 = spike train
//$o2 = voltage vector
//returns 1 on success; 0 on a wrong argument count or when
//GetISITimes / GetTrajectories return 0
func setup(){ localobj sptrain,vtrace
  if(numarg()!=2) return 0
  sptrain = $o1
  vtrace = $o2
  if(!GetISITimes(sptrain,start_times,end_times)) return 0
  //fresh lists on every call, so trajectories from an earlier setup() are dropped
  lv1 = new List()
  lv2 = new List()
  lv3 = new List()
  if(!GetTrajectories(vtrace,start_times,end_times,lv1,lv2,lv3)) return 0
  return 1
}

//creates empty containers; with exactly two args it also calls setup($o1,$o2)
//$o1 = spike train (optional)
//$o2 = voltage vector (optional)
proc init(){
  start_times = new Vector()
  end_times = new Vector()
  lv1 = new List()
  lv2 = new List()
  lv3 = new List()
  if(numarg()==2) setup($o1,$o2)
}

endtemplate isi_info
//////////////////////////////////////////////////////////////////////////////////////

//////////////////////////////////////////////////////////////////////////////////////
//template for storing spike-related info , including inter-spike interval info
begintemplate spike_info

external GetSpikeTimes,GetISITimes,GetTrajectories,GetTrajectory
public isi,spikes,setup
objref isi,spikes

//detect spikes in a voltage trace and optionally fill the ISI info
//$o1 = voltage vector
//$2 = spike threshold
//$3 = min interspike time
//$4 = whether to initialize isi info
//returns 0 on a wrong arg count; isi.setup()'s result when $4 is set; else 1
func setup(){ local sz
  if(numarg()!=4){
    printf("invalid args to spike_info.setup\n")
    return 0
  }
  spikes = new Vector($o1.size)
  sz = spikes.GetSpikeTimes($o1,$2,$3) //presumably fills spikes and returns the spike count
  spikes.resize(sz)
  if($4) return isi.setup(spikes,$o1)
  return 1
}

//with no args only empty containers are created; otherwise $o1 is required
//and the remaining args are optional (previously 1-3 args crashed on a nil vector)
//$o1 = voltage vector
//$2 = spike threshold (optional, default -20)
//$3 = min interspike time (optional, default 2)
//$4 = whether to initialize isi info (optional, default 1)
proc init(){ local thresh,mintimediff,getisi
  isi=new isi_info()
  spikes=new Vector()
  if(numarg()==0) return
  thresh = -20
  mintimediff = 2
  getisi = 1
  if(numarg()>1) thresh = $2
  if(numarg()>2) mintimediff = $3
  if(numarg()>3) getisi = $4
  setup($o1,thresh,mintimediff,getisi)
}

endtemplate spike_info

///
//////////////////////////////////////////////////////////////////////////////////////