// $Id: infot.hoc,v 1.43 2009/12/04 01:25:55 samn Exp $ 

// announce the load so it shows up in the interpreter output
print "Loading infot.hoc..."

// run install_infot() only while installed_infot is still 0 (both names come from outside this file)
// NOTE(review): the braces presumably keep the interpreter from echoing a return value at top level
if(!installed_infot) {install_infot()}

//* tentropsig(v1,v2,nshuf,nbins,[xpast,ypast,hval])
//significance test of transfer entropy using shuffling
//$o1,$o2 = input vectors (working copies are made, the originals are not passed to tentrop)
//$3 = number of shuffles, $4 = number of bins
//optional $5 = xpast (default 1), $6 = ypast (default 2), $7 = hval (default 0)
//returns the value tentrop stores in slot 0 of its output vector:
// (te - tes) / sds , where te is transfer entropy, tes is transfer entropy
// of shuffled data, sds is std-dev of transfer entropy of shuffled data
//should only accept as significant values > 4-6
func tentropsig () { local nshuf,xp,yp,hv,nbins,te localobj v1,v2,vo
  v1=new Vector() v2=new Vector() vo=new Vector(1)
  v1.copy($o1) v2.copy($o2) nshuf=$3 nbins=$4
  if(numarg()>4) xp=$5 else xp=1
  if(numarg()>5) yp=$6 else yp=2
  if(numarg()>6) hv=$7 else hv=0
  te=v1.tentrop(v2,nbins,xp,yp,nshuf,vo,hv)
  //diagnostic output is gated on verbosity only, like the other functions in this file
  if(verbose_infot>2) printf("te=%g,sig=%g\n",te,vo.x(0))
  return vo.x(0)
}

//** mutinfbshufv(v1,v2,[nshuf,nbins])
//returns a new vector holding one mutual-information value per shuffle round
//$o1,$o2 = input vectors (only private copies are shuffled)
//optional $3 = number of shuffle rounds (default 20), $4 = number of bins (default 10)
//used for significance test , i.e. : ((miorig - mishufmean) / mishufstdev) > 2
obfunc mutinfbshufv () { local nshuf,nbins,k localobj sa,sb,vmi
  if(numarg()>2) nshuf=$3 else nshuf=20
  if(numarg()>3) nbins=$4 else nbins=10
  sa=new Vector()
  sb=new Vector()
  vmi=new Vector()
  sa.copy($o1)
  sb.copy($o2)
  //every round reshuffles both copies and records the resulting mutual information
  for(k=0;k<=nshuf-1;k+=1) {
    sa.shuffle()
    sb.shuffle()
    vmi.append(sa.mutinfb(sb,nbins))
  }
  return vmi
}

//** mutinfbsig(v1,v2,[nshuf,nbins])
//significance of the mutual information between $o1 and $o2, should be at least > 2
//optional $3 = number of shuffles (default 20), $4 = number of bins (default 10)
//returns (MI - mean of shuffled MI) / std-dev of shuffled MI
func mutinfbsig () { local nshuf,nbins,sd,mi localobj vshuf
  nshuf=20
  nbins=10
  if(numarg()>2) nshuf=$3
  if(numarg()>3) nbins=$4
  vshuf=mutinfbshufv($o1,$o2,nshuf,nbins)
  sd=vshuf.stdev
  //a non-positive spread is replaced by 1 so the division below stays defined
  if(sd<=0) sd=1
  mi=$o1.mutinfb($o2,nbins)
  return (mi - vshuf.mean) / sd
}

//** tentropspksig(v1,v2,nshuffles)
//significance of tentropspks obtained by shuffling
//$o1,$o2 = input vectors, $3 = number of shuffles
//only the copy of $o1 is shuffled; the copy of $o2 is left in its original order
//returns (TE - AvgTEShuffle) / StdDevTEShuffle
func tentropspksig () { local nshuf,k,te,sd localobj sa,sb,vte
  nshuf=$3
  sa=new Vector()
  sb=new Vector()
  vte=new Vector()
  sa.copy($o1)
  sb.copy($o2)
  //reference value comes from the unshuffled inputs
  te=$o1.tentropspks($o2)
  for(k=0;k<=nshuf-1;k+=1) {
    sa.shuffle()
    vte.append(sa.tentropspks(sb))
  }
  if(verbose_infot>2) {
    printf("te=%g,ve.mean=%g,ve.stdev=%g\n",te,vte.mean,vte.stdev)
    vte.printf
  }
  sd=vte.stdev()
  //a non-positive spread is replaced by 1 so the division below stays defined
  if(sd<=0) sd=1
  return (te-vte.mean)/sd
}

//* normte() get normalized transfer entropy using tentropspks in output vector vo
//$o1,$o2 should both have same size and non-negative values. this func is meant for time-binned spike train data
//$3==number of shuffles
//returns a new vector vo:
//vo.x(0)=transfer entropy of $o1->$o2
//vo.x(1)=H($o2Future|$o2Past) , replaced by 1 when it is <= 0
//vo.x(2)=value returned by tentropspks (normalized transfer entropy in 0,1 range)
//only when $3>0: vo.x(3)=slot 3 exactly as tentropspks left it, vo.x(4)=std-dev of the slots from 3 on
obfunc normte () { local a,nshuf,te localobj ve,vo
  //nshuf and te are declared local so a call no longer overwrites globals of the same name
  a=allocvecs(ve) vo=new Vector()
  nshuf=$3 vrsz(3+nshuf,vo)
  te=$o1.tentropspks($o2,vo,nshuf)
  if(verbose_infot>2) vo.printf
  if(vo.x(1)<=0) {
    //only the warning depends on verbosity; the clamp itself is always applied
    if(verbose_infot>0) printf("WARNING H(X2F|X2P)==%g<=0\n",vo.x(1))
    vo.x(1)=1
  }
  if (nshuf>0) {
    //set the per-shuffle entries (slot 3 onward) aside before vo is truncated
    ve.copy(vo,3,vo.size-1)
    vo.resize(4)
    //sanity check: slot 2 is expected to equal the mean of the per-shuffle entries
    if (ve.mean!=vo.x[2]) printf("normte ERRA\n")
    vo.append(ve.stdev)
  }
  vo.x[2]=te
  dealloc(a)
  return vo
}

//* GetTENQ() get an nqs with useful transfer entropy info
//$o1,$o2 = input vectors, $3 = number of shuffles handed to normte
// ($3 must be > 0 because slots 3 and 4 of normte's output are read)
//optional $o4 = existing NQS to reuse (it is cleared first); otherwise a new one is created
//returns the NQS with two rows: direction 0->1 and direction 1->0
//prefdir = (TE_this - TE_other)/(TE_this + TE_other) on the normalized values,
// or 0 when neither value is positive or when the two values sum to 0
obfunc GetTENQ () { local te01,te10,pf01,pf10,tesum localobj nqte,vo1,vo2
  if(numarg()>3) nqte=$o4
  if(nqte==nil) {
    nqte=new NQS("from","to","TE","NTE","HX2|X2P","prefdir","TEshufavg","TEshufstd","sig")
  } else nqte.clear()
  vo1=normte($o1,$o2,$3)
  vo2=normte($o2,$o1,$3)
  te01=vo1.x(2)
  te10=vo2.x(2)
  //a non-positive shuffle std-dev is replaced by 1 so the sig column stays defined
  if(vo1.x(4)<=0)vo1.x(4)=1
  if(vo2.x(4)<=0)vo2.x(4)=1
  tesum=te01+te10
  //one positive and one negative value can cancel exactly, so the sum is checked as well
  if((te01>0 || te10>0) && tesum!=0) {
    pf01=(te01-te10)/tesum
    pf10=(te10-te01)/tesum
  } else {
    pf01=pf10=0
  }
  nqte.append(0,1,vo1.x(0),te01,vo1.x(1),pf01,vo1.x(3),vo1.x(4),(vo1.x(0)-vo1.x(3))/vo1.x(4))
  nqte.append(1,0,vo2.x(0),te10,vo2.x(1),pf10,vo2.x(3),vo2.x(4),(vo2.x(0)-vo2.x(3))/vo2.x(4))
  return nqte
}

//** prefdte() get preferred direction of transfer entropy
//$o1=vec 1, $o2=vec 2, $3 = # of times to shuffle
//returns (nTE(1->2) - nTE(2->1)) / (nTE(1->2) + nTE(2->1)) using slot 2 of normte's output,
// or 0 when the two values sum to 0 (the ratio is undefined there)
func prefdte () { local nshuf,te01,te10,tesum localobj v1,v2
  //no scratch vectors are allocated: normte returns fresh vectors and reads the inputs directly
  nshuf=$3
  v1=normte($o1,$o2,nshuf)
  v2=normte($o2,$o1,nshuf)
  te01=v1.x(2)
  te10=v2.x(2)
  tesum=te01+te10
  if(tesum==0) return 0
  return (te01-te10)/tesum
}

//** mkchist() averages entries in window into disc values and returns in new output vec
//$o1=input vec,$2=win size
//each window contributes int(mean of its entries); the last window is cut off at the end of the input
//a window that ends up with a single entry (or none) is skipped
obfunc mkchist () { local lo,hi,win,n localobj src,res
  src=$o1
  win=$2
  n=src.size
  res=new Vector()
  //grow once to the expected length, then empty again before appending
  res.resize(1+n/win)
  res.resize(0)
  lo=0
  while(lo<=n) {
    hi=lo+win-1
    if(hi>=n) hi=n-1
    if(hi>lo) res.append( int(src.mean(lo,hi)) )
    lo+=win
  }
  return res
}

//get magnitude of difference in a preferred direction - just abs of diff,
//but if neither value is positive (both are weak), return 0
//$1 = nTE_X->Y
//$2 = nTE_Y->X
func prefdmag () { local n1,n2
  n1=$1 n2=$2
  //both weak: zero counts as weak too, so (0,-x) and (-x,0) give 0 just like (-y,-x)
  if(n1<=0 && n2<=0) return 0
  //at least one value is positive: one strong and one weak, or both positive
  return abs(n1-n2)
}


//** simple test for nte
//global work vectors shared by mkspks() and testnte():
// vs[0],vs[1] receive the spike times, vb[0],vb[1] the binned versions of them
//NOTE(review): declare() comes from outside this file; presumably it creates the two 2-element objref arrays
{declare("vb","o[2]","vs","o[2]")}
//put an empty Vector into every slot (i is a global variable at this level)
for i=0,1 {
  vb[i]=new Vector()
  vs[i]=new Vector()
}

//mkspktrain(Random,rate,tmax) -- make a spike train with specified rate,tmax
//$o1 = Random obj, must be initialized by the caller
//$2 = rate; each interval is drawn with $o1.poisson(1e3/rate)
//$3 = tmax; intervals are accumulated for as long as the running time is <= tmax,
// so the last time that gets appended can lie beyond tmax
//returns a new vector with the accumulated spike times
obfunc mkspktrain () { local tend,hz,tcur,isi localobj rnd,vtimes
  rnd=$o1
  hz=$2
  tend=$3
  isi=1e3/hz
  vtimes=new Vector()
  tcur=0
  while(tcur<=tend) {
    tcur += rnd.poisson(isi)
    vtimes.append(tcur)
  }
  return vtimes
}

//make random spikes with frequency $1, tmax=$2, offset for spikes=$3, alpha=$4 -- ratio of spikes from
//vs[0] that get placed in vs[1] (optional, default 1)
//vs[0] is a fresh train from mkspktrain(); vs[1] holds the vs[0] spikes shifted forward by the $3 offset,
//either all of them (alpha >= 1) or each one kept when a uniform(0,1) draw is <= alpha
//so vs[0] 'drives' vs[1], or can be used to predict it, but vs[1] cant be used to predict vs[0]
//the generator is always seeded with ACG(1234), so repeated calls give the same trains
proc mkspks () { local tmax,rate,ts,intt,off,alpha localobj rdp
  rate=$1
  tmax=$2
  off=$3
  alpha=1
  if(numarg()>3) alpha=$4
  intt=1e3/rate
  rdp=new Random()
  rdp.ACG(1234)
  rdp.poisson(intt)
  //empty both global trains before refilling them
  vs[0].resize(0)
  vs[1].resize(0)
  vs[0]=mkspktrain(rdp,rate,tmax)
  if(alpha < 1.0) {
    for vtr(&ts,vs[0]) {
      if(rdp.uniform(0,1) <= alpha) vs[1].append(ts+off)
    }
  } else {
    vs[1].copy(vs[0])
    vs[1].add(off)
  }
}
//test nTE : nTE of X0 -> X1 should be much higher than nTE of X1 -> X0
//optional $1=offset == offset to shift spikes by, in ms (default 10)
//optional $2=rate == rate of spikes, in Hz (default 50)
//optional $3=bin size , in ms (default 10)
//optional $4=alpha == ratio of spikes of X0 that get placed in X1 with offset (default 1)
//optional $5=max time, in ms (default 10000)
//prints the reference output, builds vs[]/vb[], prints the GetTENQ table and returns 1
//binmin_infot is set to 0 and batch_flag to 1 while needed; both are restored afterwards
func testnte () { local i,bisv,bfsv,maxt,alpha,off,rate,binsz,dur  localobj nqt,nqout
  if(numarg()>0)off=$1 else off=10
  if(numarg()>1)rate=$2 else rate=50
  if(numarg()>2)binsz=$3 else binsz=10
  if(numarg()>3)alpha=$4 else alpha=1
  if(numarg()>4)dur=$5 else dur=10000
  bisv=binmin_infot binmin_infot=0
  print "output should be close to:\n\t0 1 0.6707 0.9975 0.6707 0.9407 0.003435 0.001672 399.1"
  print "\t1 0 0.02183 0.03049 0.6685 -0.9407 0.002828 0.001442 13.17"
  mkspks(rate,dur,off,alpha)
  //latest spike over BOTH trains, so the histogram range covers vs[0] as well as vs[1]
  maxt=vs[0].max
  if(vs[1].max>maxt)maxt=vs[1].max
  printf("maxt=%g\n",maxt)
  for i=0,1 vb[i].hist(vs[i],0,(maxt+binsz-1)/binsz,binsz)
  nqt=GetTENQ(vb[0],vb[1],200) 
  //keep a copy of the binned trains on disk
  //NOTE(review): the path below is hard-coded to one user's directory - confirm it is still wanted
  nqout=new NQS("X1","X2")
  nqout.odec("X1")
  nqout.odec("X2")
  //batch_flag is raised only around the save and then put back to the caller's value
  bfsv=batch_flag batch_flag=1
  nqout.append(vb[0],vb[1])
  nqout.sv("/u/samn/bpftest/data/09dec17.func.testnte.nqs")
  batch_flag=bfsv
  nqt.pr
  nqsdel(nqt)
  binmin_infot=bisv
  return 1
}

//get kernel smoothed prob distrib in an nqs
//$o1=input vector
//optional $2=increment in x , smaller values mean finer resolution (default 0.1)
//optional $3=bandwidth - higher means smoother output (default $o1.getbandwidth())
//optional $4=min value in output (default $o1.min()), $5=max value in output (default $o1.max())
//returns a new NQS with columns "x" (the grid) and "y" (kprob1D at each grid point,
// divided by the sum of all y values unless that sum is 0)
obfunc khist () { local lo,hi,step,bw,xv,k,tot localobj vx,vy,nq,vin
  vin=$o1 
  step=0.1
  if(numarg()>1) step=$2
  if(numarg()>2) bw=$3 else bw=vin.getbandwidth()
  if(numarg()>3) lo=$4 else lo=vin.min()
  if(numarg()>4) hi=$5 else hi=vin.max()
  vx=new Vector()
  vx.indgen(lo,hi,step)
  //the copy only gives vy the same length as vx; every entry is overwritten below
  vy=new Vector()
  vy.copy(vx)
  for vtr(&xv,vx,&k) vy.x(k) = vin.kprob1D(bw,xv)
  tot=vy.sum 
  if(tot!=0) vy.div(tot)
  nq=new NQS("x","y")
  nq.v[0]=vx
  nq.v[1]=vy
  return nq
}