// Double-click on this file to run the code

load_file("nrngui.hoc")			// neuron main menu and some key parameter init (must come first)
xopen("ri06.nrn")			// geometry

// Re-discretize the loaded geometry into roughly 10 um compartments.
// The first section visited by forall (ii == 0) keeps the nseg it was loaded with;
// every other section gets an odd nseg computed from its length L.
// tot_nsegs accumulates the total number of segments over all sections.
ii = 0
tot_nsegs = 0
forall {
	if (ii != 0) {
		// int(...)*2 + 1 is always an odd integer >= 1, so no zero check is needed
		mynseg = int((L/10 + 0.9)/2)*2 + 1
		nseg = mynseg
	}
	tot_nsegs = tot_nsegs + nseg
	ii = ii + 1
}

// Load the simulated axon, then add the nseg of each of its sections
// (hill, iseg, node[0..1], inode[0..2]) to the running total tot_nsegs.
xopen("simulated_axon.nrn")
hill { tot_nsegs += nseg }
iseg { tot_nsegs += nseg }
for jj = 0, 1 {
	node[jj] { tot_nsegs += nseg }
}
for jj = 0, 2 {
	inode[jj] { tot_nsegs += nseg }
}


// ---- run parameters (set before the code_*.hoc files are loaded; presumably read there) ----

// constant Na conductance along basal dendrites
naslopebasal = 0
// constant K conductance along basal dendrites
dslopebasal = 0
// time to wait for membrane voltage to equilibrate
synwait = 1000
// simulation end time: 100 after the equilibration period
tstop = synwait + 100
// 1 for using initialization routine
init_routine = 1
// 1 for using cvode
usecvode = 0

// ---- supporting code ----
xopen("code_point_processes.hoc")	// point processes (original)
xopen("code_objects.hoc")		// additional objects (specific to geometry)
xopen("code_membrane.hoc")		// membrane mechanisms (was init)
xopen("code_routine_for_runs.hoc")	// contains regular run procedure


// these next few lines correct for the fact that the soma is elliptical rather than cylindrical
// due to an idiosyncrasy of NEURON, you have to do the area call first
// (correctsoma() is not defined in this file; presumably it comes from one of the
// code_*.hoc files loaded above)
access somaA
area(0.5)
correctsoma()

// Membrane setup. init_params(), insert_pass() and init_props() are defined outside
// this file; the original note says this initializes the active properties when
// called from the init() routine.
init_params()
// insert the dv mechanism into every section; max_voltage() later reads its
// dvmax_dv and vmaxt_dv range variables
forall {insert dv}
insert_pass()
init_props()

// Load the spine geometry, then create the synapse objects and the output File.
// 20 synampa and 20 synnm point processes are allocated, indexed 0..19.
xopen("spinearraygeom.nrn")
objref ampasyn[20], nmdasyn[20], locsyn, scalefile
locsyn = new Vector(20)
scalefile = new File()
for i = 0, 19 {
	ampasyn[i] = new synampa()
	nmdasyn[i] = new synnm()
}

// For each of the 20 spines: set the axial resistivity Ra of neck and head to 200
// and insert the same set of membrane mechanisms into both sections.
for ii = 0, 19 {
	spineneckarray[ii].sec {
		Ra = 200.00
		insert h
		insert pas
		insert kdr
		insert kap
		insert kad
		insert nafast2
		insert naslowcond2
	}
	spineheadarray[ii].sec {
		Ra = 200.00
		insert h
		insert pas
		insert kdr
		insert kap
		insert kad
		insert nafast2
		insert naslowcond2
	}
}

xopen("code_synapse_array_setup.hoc")

// Number of runs used to size and stride the random tables below
numruns = 50

// Work buffers: parvec carries the 4 job parameters, voltsvec holds 9 values
// for each of up to 20 synapse counts; filename is the output file name.
strdef filename
objref parvec, voltsvec
parvec = new Vector(4)
voltsvec = new Vector(9*20)

// Seeded random generator drawing uniformly from [0,1]. Two tables of
// 20*numruns*basalno values are filled in interleaved order (randomvals first,
// then randomvals2, for each index), so the draw sequence is reproducible.
// basalno is defined outside this file.
objref r, randomvals, randomvals2
r = new Random(2)
r.uniform(0, 1)
randomvals = new Vector(20*numruns*basalno)
randomvals2 = new Vector(20*numruns*basalno)
for tt = 0, randomvals.size()-1 {
	randomvals.x[tt] = r.repick()
	randomvals2.x[tt] = r.repick()
}

// Helper for makecluster(): make the currently accessed spine section passive.
//   $1 = index into dend[], $2 = location (0-1) on that dendrite
// g_pas and cm are copied from the dendrite at that location and every active
// conductance density (including gbar_h) is set to zero. To copy the dendritic
// active conductances instead, assign e.g. gbar_nafast2 = dend[$1].sec.gbar_nafast2($2).
proc makecluster_passive_spine() {
	insert pas
	g_pas=dend[$1].sec.g_pas($2)
	cm=dend[$1].sec.cm($2)
	gbar_nafast2=0
	gbar_naslowcond2=0
	gbar_kdr=0
	gbar_kap=0
	gbar_kad=0
	gbar_h=0
}

// Assign a location and strengths to synapse number $2 of a cluster on one branch.
//   $1 = dendrite index (into dend[])
//   $2 = synapse number, 1-based; spine/synapse array index is $2-1
//   $3 = scale flag; when 1 the location is remapped through findpoint() and the
//        strength may be chosen from the distance to the soma
//   $4 = run number (selects the slice of the pre-drawn random tables)
// Returns the synapse location multiplied by syntype: positive when the default
// NA/NN strengths were used, negative when calc_syn_strength() values were used.
// Note: dendno, randindex, seg, Aval, Nval, syntype, probval, dval and clustercount
// are global variables and are left global here, since other code may read them.
func makecluster() {local scale, synnumber, runval
	dendno=$1
	synnumber=$2
	scale=$3
	runval=$4
	// NOTE(review): the stride 20*numruns assumes runval <= numruns; callers that
	// pass a larger run number index past this dendrite's slice - confirm intended.
	randindex=dendno*20*numruns+runval*20+synnumber
	seg=randomvals.x[randindex]
	// default strengths (NA and NN are defined outside this file)
	Aval=NA
	Nval=NN
	syntype=1
	probval=randomvals2.x[randindex]
	if (probval<0.075) {
		Aval=calc_syn_strength(1,50)
		Nval=calc_syn_strength(2,50)
		syntype=-1
	}
	if (scale==1) {
		seg=findpoint(dend[dendno].sec.L,seg)
		soma.sec {distance()}
		dend[dendno].sec {dval=distance(seg)}
		// probability threshold: 0.09, or a linear function of distance below 135
		perfprob=0.09
		if (dval<135) {perfprob=(0.03*dval+0.9)/55}
		if (probval<perfprob) {
			Aval=calc_syn_strength(1,dval)
			Nval=calc_syn_strength(2,dval)
			syntype=-1
		}
	}
	print "seg ",seg
	// attach the spine neck to the dendrite at the chosen location
	dend[dendno].sec connect spineneckarray[synnumber-1].sec(0), seg
	spineheadarray[synnumber-1].sec { makecluster_passive_spine(dendno, seg) }
	spineneckarray[synnumber-1].sec { makecluster_passive_spine(dendno, seg) }
	spineheadarray[synnumber-1].sec {print gbar_nafast2, gbar_naslowcond2, gbar_kdr}
	syn_cc_array(dendno,seg,Aval,synnumber-1)
	syn_nmdacc_array(dendno,seg,Nval,synnumber-1)
	//disconnect and zero out the remaining values
	for clustercount=synnumber,20-1 {
		syn_cc_array(dendno,0,0,clustercount)
		syn_nmdacc_array(dendno,0,0,clustercount)
	}
	// encode the synapse type in the sign of the returned location
	seg = seg*syntype
	return seg
}


// set up parallel context		
// pc is the single ParallelContext used for job submission (calc_spikes) and for
// posting results (distspike); the print reports the host count and this host's id
objref pc
pc = new ParallelContext()
print "number of hosts: ", pc.nhost(), "\thost id: ", pc.id() 

// Job function farmed out to the worker hosts via pc.submit().
//   $1  = job key (also used as the bulletin-board key for the result)
//   $o2 = Vector of 4 parameters: dendrite index, max synapse count,
//         scaling flag, run number
// Runs activate_and_run_synapses() with those parameters, posts the resulting
// Vector under the job key, and returns the key.
func distspike() {localobj parvec, returnvec
        key_ds = $1
        parvec = $o2
        // no pre-allocation needed: the obfunc returns a freshly built Vector
        returnvec = activate_and_run_synapses(parvec.x[0],parvec.x[1],parvec.x[2],parvec.x[3])

        pc.pack(returnvec)
        pc.post(key_ds)

        return key_ds
}


// Scan dendrite $1 for the largest voltage change and the earliest peak time,
// using the dvmax_dv / vmaxt_dv range variables of the dv mechanism.
// Returns a 7-element Vector:
//   [0] largest dvmax_dv found      [1] location where it was found
//   [2] smallest vmaxt_dv found     [3] location where it was found
//   [4] dvmax_dv at that location
//   [5] dvmax_dv at soma(0.5)       [6] vmaxt_dv at soma(0.5)
obfunc max_voltage() {local peakdv, peakt localobj outvec
	dendval_mv = $1
	outvec = new Vector(7, 0)
	// start the earliest-time search at the end of the simulation
	outvec.x[2] = tstop
	outvec.x[5] = soma.sec.dvmax_dv(0.5)
	outvec.x[6] = soma.sec.vmaxt_dv(0.5)
	print tstop
	dend[dendval_mv].sec {
		// sample nseg+2 evenly spaced locations from 0 to 1 inclusive
		for maxcount = 0, nseg+1 {
			testloc = maxcount/(nseg+1)
			peakdv = dvmax_dv(testloc)
			peakt = vmaxt_dv(testloc)
			if (peakdv > outvec.x[0]) {
				outvec.x[0] = peakdv
				outvec.x[1] = testloc
			}
			if (peakt < outvec.x[2]) {
				outvec.x[2] = peakt
				outvec.x[3] = testloc
				outvec.x[4] = peakdv
			}
		}
	}
	return outvec
}

// Sequentially run simulations with 1, 2, ..., $2 synapses on one dendrite.
//   $1 = dendrite index, $2 = maximum number of synapses,
//   $3 = scaling flag passed to makecluster(), $4 = run number
// For each synapse count a synapse is added with makecluster(), the model is
// initialized and run, and 9 values are stored in the returned Vector:
//   offsets 0-6 = the 7 entries of max_voltage() for this dendrite
//   offset 7    = location of the synapse just added (absolute value)
//   offset 8    = sign of makecluster()'s return value (+1 or -1)
// All spine necks used are disconnected again before returning.
obfunc activate_and_run_synapses() {local synsign_ar localobj mv, returnvec
	dendval_ar=$1
	maxsyn_ar=$2
	scalefactor_ar=$3
	run_ar=$4
	returnvec = new Vector(maxsyn_ar*9)
	for scount_ar=1,maxsyn_ar {
		seg_ar=makecluster(dendval_ar,scount_ar,scalefactor_ar,run_ar)
		init()
		run()
		mv=max_voltage(dendval_ar)
		for count_ar=0,6 {
			returnvec.x[(scount_ar-1)*9+count_ar]=mv.x[count_ar]
		}
		// derive the sign by comparison: seg_ar/abs(seg_ar) raises a
		// division-by-zero error if the location happens to be exactly 0
		if (seg_ar < 0) {
			synsign_ar = -1
		} else {
			synsign_ar = 1
		}
		returnvec.x[(scount_ar-1)*9+7]=abs(seg_ar)
		returnvec.x[(scount_ar-1)*9+8]=synsign_ar
	}
	for tempcount=0,maxsyn_ar-1 {
		spineneckarray[tempcount].sec disconnect()
	}
	return returnvec
}

// Hand control to the ParallelContext job loop. Per the NEURON bulletin-board
// model, worker hosts stay inside this call executing submitted jobs, and only
// the master continues with the statements below - confirm against the docs.
pc.runworker()

// Run the synapse-cluster simulations on 5 specific long branches that include
// proximal, middle, and distal regions (dend indices 24, 21, 29, 5, 64).
// A branch is only simulated if its start lies between 50 and 100 from the soma
// and it is longer than 100. For each branch two files are written (without and
// with scaling); each file gets one row per (run, number of synapses).
proc calc_spikes() {
	somaA distance()
	for ttt=0,4 {
		if (ttt==0) {m=24}
		if (ttt==1) {m=21}
		if (ttt==2) {m=29}
		if (ttt==3) {m=5}
		if (ttt==4) {m=64}
		dend[m].sec startlen=distance(0)
		if ((startlen>50) && (dend[m].sec.L>100) && (startlen<100)) {
			for soption=0,1 {
				if (soption==0) {
					sprint(filename,"basal_spike_noscaling_branch_%d.dat",m)
				} else {
					sprint(filename,"basal_spike_scaling_branch_%d.dat",m)
				}
				scalefile.wopen(filename)
				// 13 columns, matching the 6 + 7 values written per row below
				scalefile.printf("Dendrite\tRun\tNum_synapses\tLast_synapse_location\tLast_synapse_distance\tSynapse_type\tMax_dV\tMax_dV_location\tFirst_peak_time\tFirst_peak_location\tFirst_peak_voltage\tSoma_dV\tSoma_peak_time\n")
				dend[m].sec() {
					// NOTE(review): nn runs to 200 while numruns is 50, and
					// makecluster() indexes the random tables with a stride of
					// 20*numruns per dendrite - confirm the intended run count.
					for nn=1,200 {
						parvec.x[0]=m
						parvec.x[1]=20
						parvec.x[2]=soption
						parvec.x[3]=nn
						// job key encodes run, dendrite and scaling option
						mmtag=10000*nn+100*m+soption
						print "submitting ", mmtag
						pc.submit("distspike",mmtag,parvec)	//send out simulation calculations
					}

					//collect results as jobs finish
					while (pc.working()) {
						key = pc.retval()	//retrieve the tag
						pc.look_take(key)	//remove the tag/job from the bulletin

						voltsvec = pc.upkvec()	//unpack the result vector associated with the tag

						print "received key ",key
						// decode run and dendrite from the key; int() drops the
						// soption/100 fraction so dno is a whole index
						rno=int(key/10000)
						dno=int((key-rno*10000)/100)
						// reset the distance origin once per result (a job run on
						// this host may have moved it)
						somaA distance()
						for ss=1,20 {
							dend[dno].sec ddist=distance(voltsvec.x[(ss-1)*9+7])
							scalefile.printf("%d\t%d\t%d\t%g\t%g\t%d",dno,rno,ss,voltsvec.x[(ss-1)*9+7],ddist,voltsvec.x[(ss-1)*9+8])
							for tt=0,6 {
								scalefile.printf("\t%g",voltsvec.x[(ss-1)*9+tt])
							}
							scalefile.printf("\n")
						}
						scalefile.flush()
					}
				}
				// all results for this file are in; release the file handle
				scalefile.close()
			}
		}
	}
}

		
// Entry point: run the whole batch of simulations.
// NOTE(review): nothing visible here calls pc.done() afterwards; confirm whether
// worker hosts are expected to be shut down when this is run in parallel.
calc_spikes()