// *****************  Test Inputs  **********************

// ------------  Introduction   --------------------------------  
// 01/01/2005
// This program generates periodic input spikes that simulate afferent inputs to cells in the Medial Superior Olive (MSO) 
// of the mammalian brainstem as described in 
// Zhou, Carney, and Colburn (2005) J. Neuroscience, 25(12):3046-3058   





// ******* Codes start here ****************************************

// Ask the user for the base random seed. The GaussStim generators created
// further down are seeded with seed_x+i (contra) and seed_x+10+i (ipsi).
seed_x=xred("Please choose a random seed number <1,1000>", 1, 1, 1000) //string,default,min,max

// Input/output directories. Both are assigned here but are not referenced
// anywhere else in the visible part of this file.
// NOTE(review): absolute, user-specific paths - adjust before reuse.
strdef inputDir,outputDir
outputDir="/home/yizhou/Research/Publishing/MSO_structure"
inputDir="/home/yizhou/Research/Publishing/MSO_structure"

  
  // Load the standard NEURON GUI/run library.
  load_file("nrngui.hoc")
 
  // Single section that hosts all point processes created below.
  create cell
  
  // Make it the default (currently accessed) section.
  access cell





  //**** Excitatory and inhibitory synapses **************** 
  //---------------------------------------------------------
  
  //----     parameter description     ----------------------
  //syn0[]: E alpha synapses on contralateral dendrite, dend0
  //syn1[]: E alpha synapses on ipsilateral dendrite, dend1
  //gen0[]: E stimuli to syn0
  //gen1[]: E stimuli to syn1
  //netcon0[]: input- E synapse connection on dend0
  //netcon1[]: input- E synapse connection on dend1

  //syn_inhi0[]:  I alpha synapse on the soma
  //gen_inhi0[]:  I stimuli to syn_inhi0
  //netcon_inhi0[]: input- I synapse connection on the soma
  
 
  // ************ Excitatory synapses on two dendrites  ********
  objectvar syn0[10], gen0[10], syn1[10], gen1[10]

  // Create ten GaussStim spike generators per side (gen0: contra, gen1: ipsi).
  // Both sides get the same period, onset, event budget and phase-locking
  // factor; only the random seeds differ (seed_x+0..9 vs. seed_x+10..19).
  for (i = 0; i <= 9; i += 1) {
      gen0[i] = new GaussStim(0.5)
      gen1[i] = new GaussStim(0.5)

      // [ms] input period
      gen0[i].interval = 2
      gen1[i].interval = 2

      // [ms] arrival time of input spikes
      gen0[i].start = 0
      gen1[i].start = 0

      // maximum number of events per generator
      gen0[i].number = 10000
      gen1[i].number = 10000

      // phase-locking factor F
      gen0[i].factor = 20
      gen1[i].factor = 20

      // seeding: contra first, then ipsi
      gen0[i].seed(seed_x + i)
      gen1[i].seed(seed_x + 10 + i)
  }
 
 	//----- add EE expsyn at dend0
  	cell {
		// Both synapse groups are inserted into the same section, scattered
		// over its first half at locations (i + 1/2)/n, i = 0..9.
		n = 20   // n=nseg  doesn't change after nseg

		// ---- contra side: Alpha synapses syn0[] ----
		for (i = 0; i <= 9; i += 1) {
			syn0[i] = new Alpha(i/n + 1/n/2)
			syn0[i].e = 0        // mV
			syn0[i].con = 10     // nS
			syn0[i].tau1 = 0.1   // ms
			syn0[i].tau2 = 0.1   // ms
		}

		// ---- ipsi side: Alpha synapses syn1[] (conductance set to 0) ----
		for (i = 0; i <= 9; i += 1) {
			syn1[i] = new Alpha(i/n + 1/n/2)
			syn1[i].e = 0        // mV
			syn1[i].con = 0      // nS
			syn1[i].tau1 = 0.1   // ms   may add some jitters
			syn1[i].tau2 = 0.1   // ms
		}
	}
  
   
   //---  Connect every generator gen?[i] to its synapse syn?[i]
   objref netcon0[10], netcon1[10]

   // Requested average input rates
   Rave_E_contra = 240   // spikes/sec
   Rave_E_ipsi = 240     // spikes/sec

   // NetCon threshold on the generator variable: prob(x>1-p)=p, where
   // p = rate*period/1000. A non-positive result is replaced by 1e-4.
   x_thres_contra = 1 - Rave_E_contra*gen0[0].interval/1000
   if (x_thres_contra <= 0) { x_thres_contra = 1e-4 }

   x_thres_ipsi = 1 - Rave_E_ipsi*gen1[0].interval/1000
   if (x_thres_ipsi <= 0) { x_thres_ipsi = 1e-4 }

   e_delay = 0  // no synaptic delay

   // NetCon arguments: source, target, threshold, delay, gain [uS]
   for (i = 0; i <= 9; i += 1) {
       netcon0[i] = new NetCon(gen0[i], syn0[i], x_thres_contra, e_delay, 0.001)
       netcon1[i] = new NetCon(gen1[i], syn1[i], x_thres_ipsi, e_delay, 0.001)
   }
  
 

  
  objref v_ISI, v_voltage
  objref pre_E

  //---- simulation control ----
  v_init = -65      // value handed to finitialize() in rerun()
  tstop = 1000      // [ms]  stimulus duration
  dt = 0.025        // [ms]  simulation time step
  bin = 0.1         // [ms]  binwidth of APs
  cvode.active(0)   // constant integration time step

  //---- recording buffers ----
  // v_ISI:     event times delivered through netcon0[0]
  // v_voltage: gen0[0].x sampled every dt
  v_ISI = new Vector()
  v_voltage = new Vector(tstop/dt + 2, 0)

  //==== for input ISIs
  netcon0[0].record(v_ISI)
  v_voltage.record(&gen0[0].x, dt)

  // The recorded trace is the generator variable rather than a membrane
  // voltage, so spikes are detected at the NetCon threshold.
  AP_threshold = x_thres_contra
  
 
  
  //****  function  "pth": period time histogram  **** 
  //-----------------------------------------------
  //----  data vector description     ------------
  // v_save: bin v_voltage for each run
  // v_pth: wrap v_all with length=gen.interval/bin and normalize
  // vs_real: vector strength
  // phase_real: mean phase 

  // get_pth(): get v_save from each run
  
  
  objref v_save, v_all, v_pth, g_pth, x_pth, v_stand0

  // get_pth(): fold the trace of the run that just finished into v_all.
  //   reads : v_voltage, AP_threshold, L_PTH (set by integrate()), bin, dt
  //   writes: v_save (this run), v_all (sum over runs)
  proc get_pth() {
      v_save = new Vector(L_PTH)

      // mark the threshold crossings of the recorded trace
      v_save.spikebin(v_voltage, AP_threshold)

      // coarsen the time axis from dt to bin
      v_save.rebin(v_save, bin/dt)

      // accumulate over trials
      v_all.add(v_save)
  }
  	
  // pth(): period histogram of the pooled spike train in v_all.
  //   Wraps v_all onto one stimulus period, normalizes it, computes the vector
  //   strength (vs_real) and the mean phase (angel_real, radians) and draws the
  //   histogram together with the expected Gaussian (Gauss0) on g_pth.
  //   reads : v_all, bin, gen0[0].interval, gen0[0].start
  //   writes: x, y, vs_real, angel_real, length, N_fold, v_pth, x_pth, v_stand0
  //   The spelling "angel_real" is kept because it is a global name.
  proc pth() { local total
  		x=0
  		y=0
  		vs_real=0
  		angel_real=0
		length=gen0[0].interval/bin  //length of period points (also read by Gauss0)
		
		N_fold=v_all.size/length     // number of periods contained in v_all
		v_pth= new Vector(length,0)
		x_pth= new Vector(length,0)  // phase vector
		x_pth.indgen(1/length)
		
		//==== wrapping over one period
		for i=0,N_fold-1 { 
			for j=0,length-1 {
					v_pth.x[j]=v_pth.x[j]+v_all.x[length*i+j]
							}
			}

		//==== normalize; skipped when no spike was recorded so that the
		//==== histogram stays all-zero instead of being divided by zero
		total=v_pth.sum()
		if (total>0) {
			v_pth.div(total)
		} else {
			print "Warning: no input spikes recorded, period histogram is empty"
		}

		v_stand0=new Vector(x_pth.size,0)     // dend[0] input PSH
		v_stand0.indgen(1/length)
  		v_stand0.apply("Gauss0")
		v_stand0.rotate(gen0[0].start/bin)
		
		//==== calculate the VS and mean phase of responses
		for i=0,length-1 {
			x=x+v_pth.x[i]*cos(2*PI*i/length)
			y=y+v_pth.x[i]*sin(2*PI*i/length)
		}
		vs_real=sqrt(x^2+y^2)     //compute VS
		print "VS=",vs_real
		
		//==== mean phase mapped onto [0, 2*PI)
		if (x==0) {
			// hoc aborts on division by zero, so atan(y/x) cannot be used here
			if (y>0) {
				angel_real=PI/2
			} else if (y<0) {
				angel_real=3*PI/2
			} else {
				angel_real=0   // no resultant vector: phase undefined, report 0
			}
		} else {
			angel_real=atan(y/x)
			if (x<0) {
				angel_real=angel_real+PI      // 2nd and 3rd quadrant
			} else if (y<0) {
				angel_real=angel_real+2*PI    // 4th quadrant
			}
		}
		print "Phase=",angel_real/(2*PI)
		
		
  		g_pth.size(0,1,0,0.3)
  		g_pth.align(0.5,0.2)
  		//g_pth.label(0.9,0.1,"Phase")
		g_pth.color(1)
  		g_pth.label(0.3,0.9,"Period Histogram")
  		g_pth.begin() 
		v_pth.mark(g_pth,x_pth,"+",9,1)
		v_pth.line(g_pth,x_pth,1,2)
		g_pth.color(2)
		v_stand0.line(g_pth,x_pth,2,2)
		
  				  		
}

  // VS(): store exp(-PI^2/(2*F^2)) in the global "vs", where F is the
  // phase-locking factor of gen0[0]. Bound to a GUI button.
  proc VS() { local f
	f = gen0[0].factor
	vs = exp(-PI^2/(2*f^2))
  }
  
  // Gauss0(phase): Gaussian centred on phase 0.5 with std 0.5/F, divided by
  // "length" (bins per period, set by pth()). Applied element-wise to build
  // the reference curve of the period histogram.
  func Gauss0() { local dphi
	sigma2 = 0.5/gen0[0].factor
	dphi = $1 - 0.5
	return exp(-dphi^2/(2*sigma2^2))/((sqrt(2*PI)*sigma2)*length)
  }
  
  
  //-------------- END pth () -------------------------
  //---------------------------------------------------





  //****  function  "ISI": interspike interval  **** 
  //-----------------------------------------------
  //----  data vector description     ------------
  
  // v_ISI: event times
  // bin_ISI : bin width of ISI	
  // hist_ISI: histogram of ISI for all run
  // hist: histogram of ISI for each run
  // mean_T, sq_T, num_T: ISI statistics for each run

  // get_ISI(): get data for each run
  
  // Vectors and graph used by the ISI analysis
  objref hist_ISI, hist, xtime, v_stand, g_ISI
  // Per-run statistics: mean, sum of squared deviations, interval count
  objref mean_T, sq_T, num_T
  // Scratch vectors
  objref temp, t1, v1, v2

  // Histogram summaries
  mean_ISI = 0
  OVF = 0         // overflow
  EI = 0

  // Pooled statistics over all runs
  mean_ISI2 = 0
  std_ISI = 0
  sq_sum = 0
  CV_ISI = 0
  
  // get_ISI(): interval statistics of the run that just finished.
  //   reads : v_ISI, L_ISI (set by integrate()), N_run, topbin, OVFbin, bin_ISI
  //   writes: v1, v2, temp, hist, mean_T/num_T/sq_T at index N_run-1,
  //           mean_ISI, hist_ISI
  proc get_ISI() { local m
	// Two copies of the event times; v1 is shifted right by one so that
	// v2 - v1 holds t[k] - t[k-1].
	v1 = new Vector(L_ISI, 0)
	v2 = new Vector(L_ISI, 0)
	v1.add(v_ISI)
	v2.add(v_ISI)
	v1.rotate(1)
	v2.sub(v1)

	// The first element wraps around (t[0] - t[last]); move it to the end
	// and drop it.
	v2.rotate(-1)
	v2.resize(v2.size - 1)

	m = v2.mean
	print "mean ISI interval   ", m
	print "ISI std    ",v2.stdev

	// deviations from the mean, for the sum of squares
	temp = v2.c()
	temp.sub(m)

	mean_T.x[N_run-1] = m
	num_T.x[N_run-1] = v2.size
	sq_T.x[N_run-1] = temp.sumsq()

	mean_ISI = m + mean_ISI

	print "**********************************"

	// histogram of this run, trimmed by one bin if it came out one longer
	// than the accumulator
	hist = v2.histogram(0, topbin+OVFbin, bin_ISI)
	if (hist.size == hist_ISI.size + 1) {
		hist.resize(hist_ISI.size)
	}
	hist_ISI.add(hist)
  }
  	
  // ISI(): finish the interval analysis after all runs.
  //   Normalizes hist_ISI, derives OVF and EI, plots the histogram with the
  //   Gaussian reference (Gauss) on g_ISI, then pools the per-run statistics
  //   into mean_ISI2, std_ISI and CV_ISI.
  proc ISI() { local n_total
	//==== histogram summaries
	hist_ISI.div(hist_ISI.sum)  // normalize
	OVF = hist_ISI.sum(topbin/bin_ISI, (topbin+OVFbin)/bin_ISI-1)
	EI = hist_ISI.sum(0, 1.5*T/bin_ISI-1)   // the first whole interval

	// drop the overflow bins before plotting
	hist_ISI.resize(topbin/bin_ISI+1)

	//==== abscissa and reference curve
	xtime = new Vector(hist_ISI.size, 0)
	xtime.indgen(bin_ISI)

	v_stand = xtime.c()
	v_stand.apply("Gauss")

	//==== plot
	g_ISI.size(0, topbin, 0, 0.2)
	g_ISI.align(0.5, 0.5)
	g_ISI.label(0.2, 0.9, "ISI")
	g_ISI.beginline()
	hist_ISI.mark(g_ISI, xtime, "o", 6, 1)
	hist_ISI.line(g_ISI, xtime, 1, 2)
	v_stand.line(g_ISI, xtime, 2, 2)

	//==== pooled mean: per-run means weighted by the interval counts
	n_total = num_T.sum()

	t1 = new Vector(N_run, 0)
	t1.add(mean_T)
	t1.mul(num_T)
	t2 = t1.sum()
	mean_ISI2 = t2/n_total

	//==== pooled sum of squares: within-run part plus between-run part
	t1 = new Vector(N_run, 0)
	t1.add(mean_T)
	t1.sub(mean_ISI2)
	t1.mul(t1)
	t1.mul(num_T)
	sq_sum = sq_T.sum() + t1.sum()

	//==== overall std
	if (n_total > 1) {
		std_ISI = sqrt(sq_sum/(n_total-1))
	} else {
		std_ISI = sqrt(sq_sum)
	}

	//==== coefficient of variation (1 when the mean is not positive)
	if (mean_ISI2 > 0) {
		CV_ISI = std_ISI/mean_ISI2
	} else {
		CV_ISI = 1
	}
  }

  // Gauss(t): Gaussian centred on the input period with
  // std = sqrt(2)*T/2/factor, scaled by the bin width. Applied element-wise
  // to build the reference curve of the ISI histogram.
  func Gauss() { local period
	period = gen0[0].interval
	sigma = sqrt(2)/2*(period/gen0[0].factor)
	return exp(-($1-period)^2/(2*sigma^2))*bin/(sqrt(2*PI)*sigma)
  }
  

  //----------- END ISI() -------------------------
  //-----------------------------------------------


  N_run = 10   // ten trials

  //----- reset(): bring synapses, inputs and all accumulators back to the
  //----- state expected at the start of a batch of runs
  proc reset() {
	//==== push the GUI-edited values of element 0 to every synapse and
	//==== generator, then refresh the NetCon thresholds
	getcon()
	getstim()
	changelevel()

	//==== run counter and per-run statistics
	N_run = 10
	mean_T = new Vector(N_run, 0)
	sq_T = new Vector(N_run, 0)
	num_T = new Vector(N_run, 0)

	//==== trace of the generator variable
	v_voltage = new Vector(tstop/dt+2, 0)
	v_voltage.record(&gen0[0].x, dt)

	//==== ISI histogram geometry
	T = gen0[0].interval
	bin_ISI = bin
	topbin = int(10*T/bin_ISI)*bin_ISI   // 10 intervals and 1 for OVF
	OVFbin = int(1*T/bin_ISI)*bin_ISI

	//==== scalar results
	mean_ISI = 0
	mean_ISI2 = 0
	std_ISI = 0
	sq_sum = 0
	CV_ISI = 0
	OVF = 0
	EI = 0

	//==== accumulators over runs
	hist_ISI = new Vector((topbin+OVFbin)/bin_ISI+1, 0)
	v_all = new Vector(tstop/dt+2, 0)
	v_all.rebin(v_all, bin/dt)   // same length as a rebinned trace
  }
  
  //---------- end reset() ----------------
  //---------------------------------------
  




  
  //----- Main simulation function --------------------
  // Measure model response at one ITD condition  
  //---------------------------------------------------


  // rerun(): reset, simulate N_run trials, then show the pooled ISI and
  // period histograms. N_run counts down because get_ISI() stores its
  // per-run results at index N_run-1.
  proc rerun() {
	reset()

	while (N_run > 0) {
		//==== fresh recording buffers for this trial (input ISI and PS)
		v_ISI = new Vector()
		v_voltage = new Vector(tstop/dt+2, 0)
		netcon0[0].record(v_ISI)           // input event times
		v_voltage.record(&gen0[0].x, dt)   // generator variable

		//==== simulate
		finitialize(v_init)
		fcurrent()
		integrate()

		//==== per-trial analysis
		get_pth()
		get_ISI()

		N_run -= 1
	}

	N_run = 10

	//==== pooled analysis and plots
	g_ISI.erase_all()
	g_pth.erase_all()
	ISI()
	pth()

	print "----------- END -----------------------"
  }
  
  
 // integrate(): advance the simulation to tstop, report the run, and publish
 // the vector lengths (L_ISI, L_PTH) needed by get_ISI() and get_pth().
 proc integrate() {
	while (t < tstop) {
		fadvance()   // advance solution
	}

	print "N_run=",N_run   // show run time
	print "-------------------------"
	print "Number_input_E=",v_ISI.size

	// With fewer than three events, substitute a zero vector of length 3.
	if (v_ISI.size < 3) {
		v_ISI = new Vector(3, 0)
	}

	L_ISI = v_ISI.size
	L_PTH = v_voltage.size
 }




//******* update synapse and input information   ****


  // changelevel(): recompute the spike thresholds from the requested average
  // rates (Rave_E_contra, Rave_E_ipsi [spikes/sec]) and the generator periods
  // [ms], and apply them to every NetCon.
  //   Threshold rule: prob(x>1-p)=p, with p = rate*period/1000.
  //   NOTE(review): this presumes the GaussStim variable x is uniform on
  //   [0,1]; the mechanism is not part of this file - confirm.
  //   AP_threshold is refreshed as well, because get_pth() detects spikes in
  //   the recorded gen0[0].x trace with it. Without the refresh it kept the
  //   load-time value after the rate or period was edited in the GUI.
  proc changelevel() {
	
	x_thres_contra=1-Rave_E_contra*gen0[0].interval/1000        //   prob(x>1-p)=p
    	x_thres_ipsi=1-Rave_E_ipsi*gen1[0].interval/1000        

	
      
	// p >= 1 would give a non-positive threshold; use 1e-4 instead
	if (x_thres_contra<=0) {x_thres_contra=1e-4} 
        if (x_thres_ipsi<=0) {x_thres_ipsi=1e-4} 
        
	// keep the histogram's detection threshold equal to the NetCon threshold
	AP_threshold=x_thres_contra

  	for i=0,9  { 
    	netcon0[i].threshold=x_thres_contra  
	netcon1[i].threshold=x_thres_ipsi
       }
      
  }

  // getcon(): copy the conductance of element 0 (the GUI-editable one) to all
  // ten synapses on each side, then report one value per side.
  proc getcon() {
	for (i = 0; i <= 9; i += 1) {
		syn0[i].con = syn0[0].con
		syn1[i].con = syn1[0].con
	}

	print "*******************"
	print "syn0 con =",syn0[1].con
	print "syn1 con =",syn1[1].con

	print "--------------------------"
  }

  // getstim(): copy period, onset and phase-locking factor of element 0 to
  // all ten generators on each side, then report onset and period.
  proc getstim() {
	//------- EE -------------
	for (i = 0; i <= 9; i += 1) {
		// contra side
		gen0[i].interval = gen0[0].interval
		gen0[i].start = gen0[0].start     // identify ITD
		gen0[i].factor = gen0[0].factor

		// ipsi side
		gen1[i].interval = gen1[0].interval
		gen1[i].start = gen1[0].start     // identify ITD
		gen1[i].factor = gen1[0].factor
	}

	print "E0 start =",gen0[1].start
	print "E0 interval =",gen0[1].interval
	print "E1 start =",gen1[1].start
	print "E1 interval =",gen1[1].interval

	print "*******************"
  }
  //----------- END rerun() -----------------------------------
  //-----------------------------------------------------------



  //**** GUI  "synapse and stimuli"  ****
  //---------------------------------------------
  
  // Boxes for the two parameter windows ("SYN" and "GEN").
  objref syn_con, gen_start
  
  // "SYN" window: conductance of the contralateral excitatory synapses.
  syn_con = new HBox()
  syn_con.intercept(1)       // panels created from here on go into the box
  
  xpanel("")
  xlabel("E0 syn contra")
  // Only syn0[0].con is editable; getcon() copies it to syn0[1..9] on reset.
  xvalue("G [nS]","syn0[0].con")
  xpanel()
  
  
  syn_con.intercept(0)       // stop capturing panels
  syn_con.map("SYN",0,150,-1,50)              


 //--------------------------------------------------------------------------

  // "GEN" window: timing and rate of the contralateral input generators.
  gen_start = new HBox()
  gen_start.intercept(1)       // panels created from here on go into the box
  
  xpanel("")
  xlabel("E0 input contra")
  // Fields edit gen0[0] only; getstim() copies them to gen0[1..9] on reset.
  xvalue("delay [ms]","gen0[0].start")
  xvalue("period","gen0[0].interval")
  xvalue("synchrony factor","gen0[0].factor")
  // The rate takes effect through changelevel(), which reset() calls.
  xvalue("rate [spikes/sec]","Rave_E_contra")
  xpanel()
  
  
 
  gen_start.intercept(0)       // stop capturing panels
  gen_start.map("GEN",0,350,-1,50)              

  
//--------------------------------------------------------------------------


  // "CONTROL" window: action buttons plus the two result graphs.
  objref boxv
  boxv=new VBox()
  boxv.intercept(1)
  xpanel("")
  // VS() stores its result in the global "vs", shown by the field below.
  xbutton("calculate VS of E0 input","VS()")
  xvalue("vs")
  xbutton("Reset","reset()")
  // rerun() calls reset() itself before simulating.
  xbutton("Single Run", "rerun()")
  xbutton("Quit", "quit()")
  
  // Graphs drawn by ISI() and pth(); created while the box is intercepting.
  g_ISI=new Graph()
  g_pth=new Graph()
  xpanel()
  boxv.intercept(0)
  boxv.map("CONTROL",600,50,-1,50)
  //----------------------------------------------------------------