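// Because single parameter changes are being compared, the number of perturbation trials per model (n_var, set below) was increased.
// NOTE(review): this comment originally said 1000 trials, but n_var is currently 500 -- confirm which value is intended.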
// Load NEURON's standard GUI / run-control library.
load_file("nrngui.hoc")

// The five sections of the model cell.
create soma, resistor_near, capacitor_near, neurites_far, axon
access soma   // soma becomes the default section

// Topology -- every child is attached by its 0 end to the 1 end of its parent:
//   soma -> resistor_near -> capacitor_near -> neurites_far
//   soma -> axon
connect resistor_near(0), soma(1)
connect axon(0), soma(1)
connect capacitor_near(0), resistor_near(1)
connect neurites_far(0), capacitor_near(1)

// Geometry: each section is a single compartment (nseg = 1).
// L is the section length and diam its diameter, both in um.
soma {
    L    = 6.3407356878645703e+002
    diam = 4.0172070898427108e+001
    nseg = 1
}
resistor_near {
    L    = 5.6372590500769781e-001
    diam = 5.6465364347491598e-001
    nseg = 1
}
capacitor_near {
    L    = 2.2339749474466097e+002
    diam = 5.3332778335997841e+002
    nseg = 1
}
neurites_far {
    L    = 1.0290937340940929e+004
    diam = 1.7007780296616840e+001
    nseg = 1
}
axon {
    L    = 3.5454421750475530e+002
    diam = 3.5428095561959343e+000
    nseg = 1
}

// Cable properties applied to every section.
forall {
    cm = 1.0000000000000000e+000  // uF/cm^2
    Ra = 1.0000000000000000e+002  // ohm*cm
}
    
// Two section groups used when inserting mechanisms and setting parameters:
//   sn           - soma, capacitor_near and neurites_far
//   neurites_all - capacitor_near and neurites_far only
// resistor_near and axon belong to neither list.
objref sn, neurites_all
sn = new SectionList()
neurites_all = new SectionList()

soma sn.append()
capacitor_near {
    sn.append()
    neurites_all.append()
}
neurites_far {
    sn.append()
    neurites_all.append()
}

// Mechanisms of the soma/neurite sections (every section in sn).
forsec sn {
    insert leak
    insert kd
    insert af
    insert as
    insert h
    insert caint
    insert ca
    insert kca
    insert pr

    // parameters that stay fixed for every simulated parameter set
    e_h = -25             // mV
    Pbar1_caint = 1.1675  // um^3/s
    vol1_caint = 6.49     // um^3
    e_pr = -10            // mV
    ek = -80              // mV
}

// Mechanisms of the axon.
axon {
    insert leak
    insert na
    insert kda
    insert aa

    ena = 55   // mV
    ek = -80   // mV
    Ra = 2     // replaces the Ra value assigned to all sections earlier
}

// Synaptic mechanisms, present on the neurite sections only, with their reversal potentials.
forsec neurites_all {
    insert synab
    insert synpd
    insert synpy

    e_synab = -70  // mV
    e_synpd = -80  // mV
    e_synpy = -70  // mV
}

// Temperature and calcium concentrations.
celsius = 10             // degC
cai0_ca_ion = 0.57e-3    // mM
cao0_ca_ion = 13         // mM

// Run control.
tstop = 20000            // simulation end time
tstart = 14000           // results before this time are discarded by single_comp()
dt_rec = 2               // sampling interval of the recorded voltage traces
dt = 0.02                // integration time step
steps_per_ms = 1/dt

// Number of parameters that make up one model parameter set.
NP = 17
objref transvec,varvec,minvec,maxvec, tg, tg_std
transvec = new Vector(NP)
varvec = new Vector(NP)
objref m, p_mag,f

// p_mag: per-parameter magnitudes that scale the random perturbations in grid_search().
f = new File("./lp_model_mag.dat")
if (!f.ropen()) {
    execerror("cannot open for reading:", "./lp_model_mag.dat")
}
p_mag = new Vector()
p_mag.scanf(f)
f.close()
// grid_search() indexes p_mag up to NP-1, so fail early on a short file.
if (p_mag.size() < NP) {
    execerror("./lp_model_mag.dat", "holds fewer than NP values")
}

// Number of baseline models, i.e. rows of the parameter matrix m.
ncell = 131

// m: one baseline parameter set per row (ncell rows, NP columns).
f = new File("./lp_model_intersect_ra_02_500.dat")
if (!f.ropen()) {
    execerror("cannot open for reading:", "./lp_model_intersect_ra_02_500.dat")
}
m = new Matrix()
m.scanf(f,ncell,NP)   // column count taken from NP instead of a second hard-coded 17
f.close()

// Random streams: one independent MCellRan4 generator per parameter.
objref r[NP]

// Seeding. (An earlier configuration, previously left here as commented-out code,
// used highindex = 11 together with mcell_ran4_init(500).)
highindex = 15
lowindex = mcell_ran4_init(300)

// ---- Adjust the perturbation step sizes here ----
// var_step.x[j] is the half-width of the uniform distribution drawn for parameter j.
objref var_step
var_step = new Vector(NP)
var1 = 0.1    // soma and neurite parameters (indices 0..11)
var2 = 0.1    // axon parameters (indices 12..NP-1)
var_step.fill(var1, 0, 11)
var_step.fill(var2, 12, NP-1)

// Number of perturbation trials per baseline model.
n_var = 500

// Independent streams are obtained by giving each generator its own highindex; the spacing
// between consecutive highindex values (2^24) must exceed the number of values drawn from
// any one stream.
for (j = 0; j < NP; j += 1) {
    r[j] = new Random()
    r[j].MCellRan4(highindex + j*2^24)
    r[j].uniform(-var_step.x[j], var_step.x[j])
}

// Bulletin-board style parallel context used to farm out the simulations.
objref pc
pc = new ParallelContext()

// Bulletin-board job: simulate one parameter set and post its result.
//   $1  - numeric key identifying the job
//   $o2 - Vector holding the 17 parameter values, in the argument order of single_comp()
// The Vector returned by single_comp() is packed and posted under the key, and the key
// is returned so that the submitter can look the result up.
func distscale() { local key localobj par, result
	key = $1
	par = $o2

	// Pass the parameters straight through; no intermediate scalars or placeholder Vector needed.
	result = single_comp(par.x[0], par.x[1], par.x[2], par.x[3], par.x[4], par.x[5], par.x[6], par.x[7], par.x[8], par.x[9], par.x[10], par.x[11], par.x[12], par.x[13], par.x[14], par.x[15], par.x[16])

	pc.pack(result)
	pc.post(key)
	return key
}

// Simulate the model with one parameter set and summarize its spiking/bursting activity.
//
// Arguments (17 numbers):
//   $1  e_leak      (soma/neurites)      $10 gbar_synab  (neurites)
//   $2  gbar_leak   (soma/neurites)      $11 gbar_synpd  (neurites)
//   $3  gbar_kd                          $12 gbar_synpy  (neurites)
//   $4  gbar_as (gbar_af = 0.885*$4)     $13 e_leak      (axon)
//   $5  Pbar_caint and Pbar_ca           $14 gbar_leak   (axon)
//   $6  gbar_kca                         $15 gbar_na     (axon)
//   $7  gbar_h                           $16 gbar_kda    (axon)
//   $8  gbar_pr                          $17 gbar_aa     (axon)
//   $9  theta_pr
//
// Returns a new Vector with 25 entries: the 17 arguments followed by
//   meanSPB, DC, freq_burst, phase_on, max and min of the soma voltage,
//   max and min of the axon voltage (voltages taken after tstart).
// If no more than 2 spikes occur after tstart, the four burst statistics are all 0.
//
// DC, freq_burst, period and phase_on are not declared local and therefore remain
// global variables, as in earlier versions of this function.
obfunc single_comp() { local meanSPB, mid localobj burster_vec, apc1, spkt, spkt1, vol, vol2, isi, IBI, IBI_ind, first_spiketimes, Nspikelet, periods
	cai0_ca_ion=0.57e-3  // mM
	cao0_ca_ion=13  // mM

	// parameters of the soma/neurite sections
	forsec sn {
		e_leak = $1
		gbar_leak = $2
		gbar_kd = $3
		gbar_af = $4*0.885
		gbar_as = $4
		Pbar_caint = $5
		Pbar_ca = $5
		if ($5<1e-8) {
			// keep the calcium permeabilities at a small positive floor
			Pbar_caint = 1e-8
			Pbar_ca = 1e-8
		}
		gbar_kca = $6
		gbar_h = $7
		gbar_pr = $8
		theta_pr = $9
	}

	// synaptic strengths
	forsec neurites_all {
		gbar_synab = $10
		gbar_synpd = $11
		gbar_synpy = $12
	}

	// parameters of the axon
	axon {
		e_leak = $13
		gbar_leak = $14
		gbar_na = $15
		gbar_kda = $16
		gbar_aa = $17
	}

	// record axonal spike times (threshold 10) and the soma and axon voltage traces
	spkt1 = new Vector()
	axon apc1 = new APCount(0.5)
	apc1.thresh = 10
	apc1.record(spkt1)

	vol = new Vector()
	vol.record(&soma.v(0.5),dt_rec)

	vol2 = new Vector()
	vol2.record(&axon.v(0.5),dt_rec)

	finitialize(-65)
	continuerun(tstop)

	// Keep only what happens from tstart on, so that the model has had time to settle.
	spkt = new Vector()
	spkt.where(spkt1, ">=", tstart)
	vol.remove(0,tstart/dt_rec)
	vol2.remove(0,tstart/dt_rec)

	// Defaults reported when there are too few spikes to analyse. phase_on is reset here
	// as well: it used to be read without being assigned in that case, so it was either
	// undefined (first call) or left over from the previously simulated model.
	meanSPB = 0
	DC = 0
	freq_burst = 0
	phase_on = 0

	if (spkt.size()>2) {
		// inter-spike intervals: isi.x[k] = spkt.x[k+1] - spkt.x[k]
		isi = spkt.c.remove(0)
		isi = isi.c.sub(spkt.c.resize(spkt.size()-1))

		// Intervals at or above the midpoint between the shortest and the longest ISI are
		// treated as inter-burst intervals. At least isi.max() qualifies, so IBI and
		// first_spiketimes are never empty.
		mid = (isi.max()+isi.min())/2
		IBI = isi.c.where(">=", mid)
		IBI_ind = isi.c.indvwhere(isi.c, ">=", mid)
		IBI_ind.add(1)		// index of the spike that follows each inter-burst interval
		first_spiketimes = spkt.ind(IBI_ind)	// first spike of each burst

		// spikes from the first detected burst onset up to (not including) the last one
		Nspikelet = spkt.c.where(spkt.c,"[)",first_spiketimes.x(0),first_spiketimes.x(first_spiketimes.size()-1))

		if (first_spiketimes.size()>1) {
			meanSPB = Nspikelet.size()/(first_spiketimes.size()-1)	// mean number of spikes per burst
			periods = first_spiketimes.c.remove(0)
			periods = periods.c.sub(first_spiketimes.c.resize(first_spiketimes.size()-1))
		} else {
			// only one burst onset: there is no real period; this fallback just keeps
			// periods non-empty (acknowledged as not technically right)
			meanSPB = Nspikelet.size()
			periods = first_spiketimes.c
		}

		period = periods.mean()	// mean interval between burst onsets

		if (period>0) {
			freq_burst = 1000/period
			DC = (period - IBI.mean())/period	// approximate mean duty cycle
			// NOTE(review): 15016.2 is a hard-coded reference time; its origin is not
			// visible in this file -- confirm before changing tstart or the model.
			phase_on = (first_spiketimes.x(0)-15016.2)/period
		} else {
			freq_burst = 0
			DC = 0
			phase_on = 0
		}

		// ISIs that differ by less than 3 are treated as tonic spiking, not bursting
		if (isi.max()-isi.min()<3) {
			meanSPB = 1
			DC = 0
			freq_burst = 1000/isi.mean()
			phase_on = 0
		}
	}

	burster_vec = new Vector()
	burster_vec.append($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,meanSPB, DC,freq_burst,phase_on,vol.max(),vol.min(),vol2.max(),vol2.min())
	return burster_vec
}

// Start the bulletin-board workers. The statements below this call are the
// master's part of the script.
pc.runworker()

// State used by grid_search(): output directory, file name and shell command buffers,
// the output file, the most recently received result and the parameter set being submitted.
strdef outDir, tmpstr, cmd
objref outfile, vec_ret, grid_ind
outfile = new File()
vec_ret = new Vector()

// Perturb every baseline model n_var times, simulate all perturbed parameter sets through
// the bulletin board, and write one result line per simulation.
//   $1 - index of the last row of m to process (rows 0 .. $1 are used)
//
// Each job is tagged with key = row*1e5 + trial, which requires n_var < 1e5.
// Output file: simdata/data_ra<Ra>_pert_<var2*100>p.dat. Each line holds the model index,
// the trial index and then every value of the Vector returned by single_comp().
proc grid_search() { local n_done, i
	sprint(outDir,"simdata")
	sprint(cmd, "mkdir -p %s", outDir)
	system(cmd)

	sprint(tmpstr,"%s/data_ra%03d_pert_%03dp.dat",outDir,axon.Ra,var2*100)
	if (!outfile.wopen(tmpstr)) {
		execerror("cannot open output file", tmpstr)
	}

	// the key encoding below would produce colliding tags otherwise
	if (n_var >= 1e5) {
		execerror("grid_search:", "n_var must be smaller than 1e5")
	}

	// Submit all jobs. Reversal potentials and theta_pr (columns 0, 8, 12) are passed through
	// unperturbed; every other column is perturbed by a uniform draw scaled with p_mag and
	// then multiplied by 10.
	// NOTE(review): a2, a12, a14 and a16 are not clamped at 0 although the other
	// conductances/permeabilities are -- confirm whether that is intentional.
	for cu = 0, $1 {
		for j = 0, n_var-1 {
			mmtag = cu*1e5 + j
			grid_ind = new Vector()

			a1 = m.x[cu][0]	//eleak
			a2 = (m.x[cu][1]+r[1].repick()*p_mag.x[1])*10	//gleak
			a3 = (m.x[cu][2]+r[2].repick()*p_mag.x[2])*10	//gkd
			if (a3<0) {a3=0}
			a4 = (m.x[cu][3]+r[3].repick()*p_mag.x[3])*10	//ga
			if (a4<0) {a4=0}
			a5 = (m.x[cu][4]+r[4].repick()*p_mag.x[4])*10	//pca
			if (a5<0) {a5=0}
			a6 = (m.x[cu][5]+r[5].repick()*p_mag.x[5])*10	//gkca
			if (a6<0) {a6=0}
			a7 = (m.x[cu][6]+r[6].repick()*p_mag.x[6])*10	//gh
			if (a7<0) {a7=0}
			a8 = (m.x[cu][7]+r[7].repick()*p_mag.x[7])*10	//gpr
			if (a8<0) {a8=0}
			a9 = m.x[cu][8]	//theta_pr
			a10 = (m.x[cu][9]+r[9].repick()*p_mag.x[9])*10	//gsyn_ab
			if (a10<0) {a10=0}
			a11 = (m.x[cu][10]+r[10].repick()*p_mag.x[10])*10	//gsyn_pd
			if (a11<0) {a11=0}
			a12 = (m.x[cu][11]+r[11].repick()*p_mag.x[11])*10	//gsyn_py
			a13 = m.x[cu][12]	//eleak_ax
			a14 = (m.x[cu][13]+r[13].repick()*p_mag.x[13])*10	//gleak_ax
			a15 = (m.x[cu][14]+r[14].repick()*p_mag.x[14])*10	//gNa_ax
			if (a15<0) {a15=0}
			a16 = (m.x[cu][15]+r[15].repick()*p_mag.x[15])*10	//gkd_ax
			a17 = (m.x[cu][16]+r[16].repick()*p_mag.x[16])*10	//ga_ax
			if (a17<0) {a17=0}

			grid_ind.append(a1,a2,a3,a4,a5,a6,a7,a8,a9,a10,a11,a12,a13,a14,a15,a16,a17)
			pc.submit("distscale",mmtag,grid_ind)
		}
	}

	// Collect the results as they finish; they may arrive in any order.
	n_done = 0
	while (pc.working()) {
		key = pc.retval()	// key returned by distscale()
		pc.look_take(key)	// take the posted result off the bulletin board
		vec_ret = pc.upkvec()	// the Vector packed by distscale()

		// decode the key back into model index and trial index
		p_id = int(key/1e5)
		p_tr = key-p_id*1e5

		outfile.printf("%04d %04d", p_id, p_tr)
		for i = 0, vec_ret.size()-1 {
			outfile.printf(" %g", vec_ret.x[i])
		}
		outfile.printf("\n")
		outfile.flush()

		// Report real progress: count finished jobs instead of printing the encoded key.
		n_done += 1
		if (n_done%1000==0) {
			printf("%d parameter sets have been searched\n", n_done)
		}
	}
	outfile.close()
}

// Run the perturbation search over all baseline models (rows 0 .. ncell-1 of m).
// NOTE(review): no pc.done() call is visible after this point; confirm that the
// workers started by pc.runworker() are released elsewhere.
grid_search(ncell-1)