load_file("nrngui.hoc")
load_file("BLcells_template_LFP_segconsider_all_Iinject_recordingimembrane.hoc")
load_file("interneuron_template_gj_LFP_Iinject_recordingimembrane.hoc")
load_file("function_NetStimOR.hoc")
load_file("function_LoadMatrix.hoc")


//// Network size: every population count scales with `upscale`
upscale = 27
NCELL = 1000*upscale                 // total number of cells
TotalCellNum = NCELL                 // same total, under the name used for array sizes
CellNum_p = 900*upscale              // principal (pyramidal) cells
CellNum_interneuron = 100*upscale    // interneurons

//// Principal-cell morphology
// lengths (um)
adend_L_p = 270
pdend_L_p = 555
// diameters (um)
diam_soma_p = 25
diam_soma_p1 = 24.75
diam_adend_p = 3
diam_pdend_p = 5
// segments per section
nseg_soma_p = 1
nseg_adend_p = 8
nseg_pdend_p = 7

// compartments per principal cell = sum of the segment counts above
nseg_all_p = nseg_adend_p + nseg_pdend_p + nseg_soma_p
modelcompartment_num = nseg_all_p

//// Interneuron morphology
dend_L_I = 150        // um
diam_soma_I = 15      // um
diam_dend_I = 10      // um
nseg_soma_I = 1
nseg_dend_I = 1

// compartments per interneuron
nseg_all_I = nseg_dend_I + nseg_soma_I
modelcompartment_num_ITN = nseg_all_I

//// Electrode-related constants
diam_shank = 25    // 20-50 um
extralimit = 50

// Parallel context, the per-host cell list, and scratch references used while building cells
objref pc,cells,cell[TotalCellNum],nc,nil//,r
pc = new ParallelContext()

cells = new List()


CellNum = 1000*upscale
celsius = 31.0  

// Simulation length: the first number in ./input/sim_length becomes tstop
objref Sim_length_file
objref Sim_length_rec
Sim_length_file = new File()
Sim_length_file.ropen("./input/sim_length")   ////simulation length
Sim_length_rec = new Vector()
Sim_length_rec.scanf(Sim_length_file)
// The contents are now in Sim_length_rec; release the file handle.
// Braces keep the return value from being echoed at top level.
{ Sim_length_file.close() }


tstop = Sim_length_rec.x[0]

// Loaded after tstop is assigned; keep this order in case the loaded file reads tstop
{load_file("function_TimeMonitor.hoc")}
dt = 0.05

steps_per_ms= 20

v_init = -70
//t0 = startsw()
//t0 = startsw()

//////////////////////// Choose the cell type /////////////////////

//// Read Cell_type.txt: one number per cell, used by the cell-creation loop to choose the template ////
objref Cell_type_file
objref Cell_type_rec
Cell_type_file = new File()
Cell_type_file.ropen("./input/Cell_type.txt")
Cell_type_rec = new Vector()
Cell_type_rec.scanf(Cell_type_file)
// Data is held in Cell_type_rec; release the file handle (braces suppress the echoed return value)
{ Cell_type_file.close() }

//// Read NM.txt ////   Randomly chosen number to decide whether a cell has DA and NE or not
objref NM_file   
objref NM_rec
NM_file = new File()
NM_file.ropen("./input/NM.txt")
NM_rec = new Vector()
NM_rec.scanf(NM_file)
{ NM_file.close() }

//// Read 3D location information (soma): TotalCellNum rows x 3 columns ////
objref Location
Location = new Matrix()
strdef locationstr
locationstr="./input/location.txt"

Location = LoadMatrix(locationstr,TotalCellNum,3)


// Create the cells owned by this host. gids are dealt out round-robin:
// host pc.id owns gids pc.id, pc.id + nhost, pc.id + 2*nhost, ...
for (i = pc.id; i < NCELL; i += pc.nhost) {
	// Neuromodulation index. It is read for every gid but does not change which
	// template is built: the DA/NE variants (Cell_?DA for 1, Cell_?NE for 2,
	// Cell_?DANE otherwise) are disabled and every value maps to the plain template.
	NM_ind = NM_rec.x[i]

	if (i >= CellNum_p) {
		// gids from CellNum_p upward are interneurons
		cell = new InterneuronCell()
	} else {
		// Principal cells: the template is chosen from the Cell_type.txt value
		Cell_type_ind = Cell_type_rec.x[i]
		if (Cell_type_ind <= 5) {
			cell = new Cell_A()
		} else if (Cell_type_ind <= 8) {
			cell = new Cell_B()
		} else {
			cell = new Cell_C()
		}
	}

	// The drive amplitude is zeroed for every cell (random amplitudes are disabled)
	cell.drv.amp = 0.0

	// Keep a reference in the list, otherwise the cell would be lost
	cells.append(cell)

	// Register gid i on this host
	pc.set_gid2node(i, pc.id)

	// Spike detector: a NetCon from the cell with no target
	nc = cell.connect2target(nil)
	nc.delay = 2
	nc.weight = 1

	// Associate gid i with the detector so spikes can be attributed to this cell
	pc.cell(i, nc)
}

//////////////////////////////////////////////////////////////
//////////////  Connections for BL NET  //////////////////////
//////////////////////////////////////////////////////////////

// cellid: scratch reference to the cell currently being configured
// fluc[gid][0] / fluc[gid][1]: excitatory / inhibitory background-noise point processes
objref cellid,fluc[TotalCellNum][2]


//// Read Cell_list.txt: gids of the cells whose output is to be written out ////
	objref op_file
	objref op_rec
	op_file = new File()
	op_file.ropen("./input/Cell_list.txt")
	op_rec = new Vector()
	op_rec.scanf(op_file)
	// The gids are now in op_rec; release the file handle (braces suppress the echoed return value)
	{ op_file.close() }
	cell_plots = op_rec.size    // number of cells listed
  


/////////////////////////////////////////////////////////
///////////////Pyramid cells bg-noises/////////////////	
/////////////////////////////////////////////////////////
// One random stream per noise source: [gid][0] excitatory, [gid][1] inhibitory
objref noiseRandObj[TotalCellNum][2]

////////////////////// Single-point background noise for pyramidal cells //////////////////////
for m = 0, CellNum_p-1 {
	// Only configure cells whose gid lives on this host
	if (pc.gid_exists(m)) {
		cellid = pc.gid2cell(m)    // cell object for this gid

		//// excitatory fluctuating conductance at the middle of the soma
		noiseRandObj[m][0] = new Random()
		noiseRandObj[m][0].Random123(m+100)    // stream id derived from the gid
		noiseRandObj[m][0].normal(0,1)

		cellid.soma fluc[m][0] = new Gfluct2_exc(0.5)
		fluc[m][0].g_e0 = 0.0032    // must not be too high, otherwise it saturates (same for interneurons)
		fluc[m][0].std_e = 0.003
		fluc[m][0].setRandObj(noiseRandObj[m][0])

		//// inhibitory fluctuating conductance at the middle of the soma
		noiseRandObj[m][1] = new Random()
		noiseRandObj[m][1].Random123(m+100+TotalCellNum)    // offset keeps it distinct from the excitatory stream
		noiseRandObj[m][1].normal(0,1)

		cellid.soma fluc[m][1] = new Gfluct2_inh(0.5)
		fluc[m][1].g_i0 = 0.021
		fluc[m][1].std_i = 0.008
		fluc[m][1].setRandObj(noiseRandObj[m][1])
	}
}


///connect extrinsic inputs

////import connections/////
//// Extrinsic connections onto principal cells (PNs)
strdef E2P_str,E2P_size_str
objref E2P_size_matrix,E2P_matrix

E2P_str = "./input/E2P"              // E2P stores the external connections to PNs
E2P_size_str = "./input/E2P_size"    // 1 x 2 file holding the dimensions of E2P

E2P_size_matrix = new Matrix()
E2P_size_matrix = LoadMatrix(E2P_size_str,1,2)

Inputnum_E = E2P_size_matrix.x[0][0]      // number of inputs
E_connectnum = E2P_size_matrix.x[0][1]    // number of cells each input can connect to

E2P_matrix = new Matrix()
E2P_matrix = LoadMatrix(E2P_str,Inputnum_E,E_connectnum)


//// Extrinsic connections onto interneurons (INTs)
strdef E2I_str,E2I_size_str
objref E2I_size_matrix,E2I_matrix

E2I_str = "./input/E2I"
E2I_size_str = "./input/E2I_size"    // 1 x 2 file holding the dimensions of E2I

E2I_size_matrix = new Matrix()
E2I_size_matrix = LoadMatrix(E2I_size_str,1,2)

Inputnum_I = E2I_size_matrix.x[0][0]      // number of inputs
I_connectnum = E2I_size_matrix.x[0][1]    // number of cells each input can connect to

E2I_matrix = new Matrix()
E2I_matrix = LoadMatrix(E2I_str,Inputnum_I,I_connectnum)


////import external spikes /////
// External spike data and its index file, each read into a flat Vector
strdef Etn_spikes_str,Etn_spikes_ind_str
sprint(Etn_spikes_str,"./input/spikesmatrix_op")
sprint(Etn_spikes_ind_str,"./input/spikes_ind")

objref Etn_spikes_file,Etn_spikes_ind_file
Etn_spikes_file=new File()
Etn_spikes_ind_file=new File()
Etn_spikes_file.ropen(Etn_spikes_str)
Etn_spikes_ind_file.ropen(Etn_spikes_ind_str)

objref Etn_spikes,Etn_spikes_ind
Etn_spikes=new Vector()
Etn_spikes_ind=new Vector()
Etn_spikes.scanf(Etn_spikes_file)
Etn_spikes_ind.scanf(Etn_spikes_ind_file)

// Both files are fully read; release the handles (braces suppress the echoed return values)
{ Etn_spikes_file.close() }
{ Etn_spikes_ind_file.close() }

//////call Procs for introducing spikes into network
{load_file("function_ConnectInputs_invivo_op.hoc")}

// External spikes onto PNs. Hosts take turns (0, 1, 2, ...): only the host whose
// id equals `rank` does the work, and everyone meets at the barrier each round.
for (rank = 0; rank < pc.nhost; rank += 1) {
	if (rank == pc.id) {
		ConnectInputs_E()
	}
	pc.barrier()
}

// External spikes onto FSIs, serialized across hosts in the same way
for (rank = 0; rank < pc.nhost; rank += 1) {
	if (rank == pc.id) {
		ConnectInputs_I()
	}
	pc.barrier()
}


    
    


//////////////////////single point BACKGROUNG fluctuation FOR INTERNEURONS//////////////////////
// Interneurons occupy gids CellNum_p .. TotalCellNum-1
for m = CellNum_p, TotalCellNum-1 {
	// Only configure cells whose gid lives on this host
	if (pc.gid_exists(m)) {
		cellid = pc.gid2cell(m)    // cell object for this gid

		//// excitatory fluctuating conductance at the middle of the soma
		noiseRandObj[m][0] = new Random()
		noiseRandObj[m][0].Random123(m+100)    // stream id derived from the gid
		noiseRandObj[m][0].normal(0,1)

		cellid.soma fluc[m][0] = new Gfluct2_exc(0.5)
		fluc[m][0].g_e0 = 0.00121
		fluc[m][0].std_e = 0.00012
		fluc[m][0].setRandObj(noiseRandObj[m][0])

		//// inhibitory fluctuating conductance at the middle of the soma
		noiseRandObj[m][1] = new Random()
		noiseRandObj[m][1].Random123(m+100+TotalCellNum)    // offset keeps it distinct from the excitatory stream
		noiseRandObj[m][1].normal(0,1)

		cellid.soma fluc[m][1] = new Gfluct2_inh(0.5)
		fluc[m][1].g_i0 = 0.00573
		fluc[m][1].std_i = 0.00264
		fluc[m][1].setRandObj(noiseRandObj[m][1])
	}
}


////////////////////////////////////////////////////
//////////  set up all internal connections  ///////////
////////////////////////////////////////////////////

// Chemical-synapse connectivity: connection ids plus an index file, each read into a flat Vector
strdef syn_str,syn_ind_str
sprint(syn_str,"./input/active_syn_op")  ////file stores all connection ids.
sprint(syn_ind_str,"./input/active_syn_ind")

objref syn_connect_file,syn_ind_file
syn_connect_file=new File()
syn_ind_file=new File()
syn_connect_file.ropen(syn_str)
syn_ind_file.ropen(syn_ind_str)

objref syn_connect,syn_ind
syn_connect=new Vector()
syn_ind=new Vector()
syn_connect.scanf(syn_connect_file)
syn_ind.scanf(syn_ind_file)

// Both files are fully read; release the handles (braces suppress the echoed return values)
{ syn_connect_file.close() }
{ syn_ind_file.close() }


////load gap junction info

strdef GAP_str,GAP_ind_str,GAP_transferID_str
sprint(GAP_str,"./input/active_syn_GAP_op")  ///gap junction connected cells
sprint(GAP_ind_str,"./input/active_GAP_ind")
sprint(GAP_transferID_str,"./input/GAPid")

objref GAP_connect_file,GAP_ind_file,GAP_transferID_file
GAP_connect_file=new File()
GAP_ind_file=new File()
GAP_transferID_file=new File()
GAP_connect_file.ropen(GAP_str)
GAP_ind_file.ropen(GAP_ind_str)
GAP_transferID_file.ropen(GAP_transferID_str)

objref GAP_connect,GAP_ind,GAP_transferID
GAP_connect=new Vector()
GAP_ind=new Vector()
GAP_transferID=new Vector()
GAP_connect.scanf(GAP_connect_file)
GAP_ind.scanf(GAP_ind_file)
GAP_transferID.scanf(GAP_transferID_file)// gid only used for gap junction transfer

// All three gap-junction files are fully read; release the handles
{ GAP_connect_file.close() }
{ GAP_ind_file.close() }
{ GAP_transferID_file.close() }

////recording internal synaptic connectivity
// File object for the internal-connectivity record
// (presumably appended to by the connection procedures loaded below -- confirm)
objref saveM
saveM = new File()

// Host 0 alone opens "Matrix_NEW" for writing once, which clears its contents
if (pc.id == 0) {
	saveM.wopen("Matrix_NEW")
	saveM.close()
}

{load_file("function_ConnectTwoCells.hoc")}
{load_file("function_ConnectInternal_simplify_online_op.hoc")}

// Internal connections. Hosts take turns (0, 1, 2, ...): only the host whose
// id equals `rank` does the work, and everyone meets at the barrier each round.
for (rank = 0; rank < pc.nhost; rank += 1) {
	if (rank == pc.id) {
		ConnectInternal()
	}
	pc.barrier()
}


{load_file("function_ConnectInternal_gj_simplify.hoc")}	///for gap junction connection

// Gap-junction connections, serialized across hosts in the same way
for (rank = 0; rank < pc.nhost; rank += 1) {
	if (rank == pc.id) {
		ConnectInternal_gj_simplify()
	}
	pc.barrier()
}

// Finalize the gap-junction variable transfer maps.
// setup_transfer() is a collective call that every host makes together, so a single
// call is enough. The former per-rank loop did not serialize anything: it only made
// every host repeat the same call nhost times.
{ pc.setup_transfer() }


//////voltage recording////////////////////////////////////////////
// Volrec[gid] holds the somatic voltage trace of a cell; vollist collects the same vectors
objref Volrec[TotalCellNum], vollist
vollist = new List()

// RecVol(): start recording soma.v(0.5) for every gid listed in op_rec that
// lives on this host. The second record() argument (1) is the sampling interval.
// Note: `op` and `cellid` are globals, as in the rest of this script.
proc RecVol() { local idx
	for (idx = 0; idx < cell_plots; idx += 1) {
		op = op_rec.x[idx]
		if (pc.gid_exists(op)) {
			cellid = pc.gid2cell(op)
			Volrec[op] = new Vector()
			Volrec[op].record(&cellid.soma.v(0.5), 1)
			vollist.append(Volrec[op])
		}
	}
}
//RecVol()   //call if needed


//////recording spikes
// tvec collects all spike times on this host; idvec collects the gid of the cell that fired
objref tvec, idvec

// spikerecord(): (re)create tvec/idvec and register every local cell for spike recording.
// A temporary target-less NetCon is built per cell only to call record(); the Vectors
// keep recording spike times even after that NetCon has been destroyed.
proc spikerecord() { local k  localobj detector, nullobj
	tvec = new Vector()
	idvec = new Vector()
	for (k = 0; k < cells.count; k += 1) {
		detector = cells.object(k).connect2target(nullobj)
		detector.record(tvec, idvec, detector.srcgid)
	}
}

// Run the simulation. The order below matters: spike recording is registered first,
// then the parallel max step is set, the model is initialized, and finally integrated.
spikerecord()
// set_maxstep is called with 10 here
{pc.set_maxstep(10)}
stdinit()
//t1 = startsw()
// Integrate up to tstop, the value read from ./input/sim_length
{pc.psolve(tstop)}
//t2 = startsw()

//print "model setup time ", t1-t0, " run time ", t2-t1, " total ", t2-t0


/////save voltage////
// f_volt: file currently being written; vols: its path
objref f_volt
strdef vols

// SavVol(): write the recorded somatic voltage of every gid listed in op_rec that
// lives on this host to ./volts/volts_<gid>.
// The trace is taken directly from Volrec[gid]; this is the same Vector that RecVol()
// appended to vollist, so no list search is needed. If no trace was recorded for a gid
// (e.g. RecVol() was never called) the gid is reported and skipped, rather than
// indexing vollist with the -1 that List.index returns for a missing object.
// Note: `op` is a global, as in RecVol().
proc SavVol() { local i
	for i = 0, cell_plots-1 {
		op = op_rec.x[i]
		if (pc.gid_exists(op)) {
			if (object_id(Volrec[op]) == 0) {
				printf("SavVol: no voltage trace recorded for gid %d (was RecVol() called?)\n", op)
			} else {
				sprint(vols,"./volts/volts_%d",op)
				f_volt = new File()
				f_volt.wopen(vols)
				Volrec[op].printf(f_volt)
				f_volt.close()
			}
		}
	}
}
//SavVol()   //call if needed

////// Everything below is for the LFP calculation
//// Load the electrode array positions and the per-cell rotation information
objref oritation,elec_coords
oritation = new Matrix()
elec_coords = new Matrix()

strdef oritationstr,elec_coords_str
oritationstr="./input/oritation.txt"    /// controls the rotation of each cell's dendrites
elec_coords_str="./input/elec_coords.txt"  //// electrode array coordinates (um)

// Scale the soma locations in place by 1e3
Location.muls(1e3)  // convert to um (presumably the input file is in mm -- confirm)
oritation = LoadMatrix(oritationstr,TotalCellNum,3)
// Only one electrode row (1*1*1) is read here
elec_coords = LoadMatrix(elec_coords_str,1*1*1,3)  // if recording LFPs from multiple electrodes, change the row count used to read the coordinates file
load_file("function_calcconduc.hoc")  /// for calculating extracellular conductances

// Scratch vectors that multiLFP() fills with one matrix row each
objref Location_single,oritation_single,elec_single
Location_single = new Vector()
oritation_single = new Vector()
elec_single = new Vector()

sigma=0.3/// conductivity value
objref LFPrecording_len
LFPrecording_len=new Vector(1)
root_cell_id=3
/// Determine the length of the LFP traces automatically:
/// the host that owns gid root_cell_id reads the size of that cell's first tlist
/// vector; every other host leaves its one-element vector untouched.
if (pc.gid_exists(root_cell_id)) { 
cellid = pc.gid2cell(root_cell_id)
myid = pc.id()
LFPrecording_len.x[0]=cellid.tlist.o(0).size()// the length of the recorded LFP depends on simulation time and recording resolution
}
pc.barrier()
// Combine the value across hosts (allreduce type 1) so that every host has the length
pc.allreduce(LFPrecording_len,1)


//objref LFP_multi[elec_coords.nrow()]   /// store LFP for multiple electrodes
objref LFP_per_core  //// sum of the LFP contributions on this host for the current electrode
// Per-cell, per-segment scaled copies produced inside multiLFP(); they remain allocated afterwards
objref LFP_comp_temp_PN[CellNum_p][modelcompartment_num],LFP_comp_temp_ITN[CellNum_interneuron][modelcompartment_num_ITN]
objref LFP_single_file


// multiLFP(): for every row of elec_coords, build the LFP trace at that electrode and
// write it to ./LFPs/LFP_elec_<row>.
// For each cell on this host, every vector in cellid.tlist (one per segment; presumably
// the recorded membrane currents -- confirm against the cell templates) is copied,
// multiplied by the factor returned by calcconduc_PN / calcconduc_ITN, and added to
// LFP_per_core. The per-host sums are then combined with allreduce and written by host 0.
// m, op and i are local; dis and conduct are not declared local, so they are globals.
proc multiLFP() { local m,op,i

for (m=0; m<elec_coords.nrow();m+=1) { // calculate LFP per electrode
elec_single=elec_coords.bcopy(m,0,1,3).to_vector()
LFP_per_core=new Vector(LFPrecording_len.x[0])    //// sum of the LFP contributions on this host for the current electrode
     for op = 0,CellNum-1 {
        if(pc.gid_exists(op)){ 
         // Euclidean distance between the electrode and this cell's Location row
         dis=sqrt((elec_single.x[0]-Location.x[op][0])^2+(elec_single.x[1]-Location.x[op][1])^2+(elec_single.x[2]-Location.x[op][2])^2)
          // The threshold 10e120 is so large that this distance cutoff is effectively disabled
          if (dis<=10e120)  {    // only cells within a certain distance of the electrode are counted
         cellid = pc.gid2cell(op)
         Location_single=Location.bcopy(op,0,1,3).to_vector()  /// copy this cell's location row
         oritation_single=oritation.bcopy(op,0,1,3).to_vector() /// copy this cell's orientation row
  
             
             if (op<CellNum_p) {  /// for PNs               
                 for (i=0; i<cellid.tlist.count(); i+=1) {  /// i stands for segment
                 conduct=calcconduc_PN(Location_single,i,oritation_single,elec_single,dis)  //i=0 for soma, 1-8 for adend, 9-15 for pdend
                 LFP_comp_temp_PN[op][i]=new Vector(LFPrecording_len.x[0])
                 LFP_comp_temp_PN[op][i].copy(cellid.tlist.o(i))
                 LFP_comp_temp_PN[op][i].mul(conduct)
                 LFP_per_core.add(LFP_comp_temp_PN[op][i])  /// accumulate into this host's LFP_per_core
                    }
                } else {  /// for ITNs 
                  for (i=0; i<cellid.tlist.count(); i+=1) {  /// i stands for segment
                   conduct=calcconduc_ITN(Location_single,i,oritation_single,elec_single,dis)  //i=0 for soma, else for dends
                   // Interneuron rows are indexed relative to the first interneuron gid
                   LFP_comp_temp_ITN[op-CellNum_p][i]=new Vector(LFPrecording_len.x[0])
                   LFP_comp_temp_ITN[op-CellNum_p][i].copy(cellid.tlist.o(i))
                   LFP_comp_temp_ITN[op-CellNum_p][i].mul(conduct)
                   LFP_per_core.add(LFP_comp_temp_ITN[op-CellNum_p][i])  /// accumulate into this host's LFP_per_core
                   }
              }

            }
           }
          }
  
  pc.barrier ()    // wait for all hosts to get to this point
  pc.allreduce(LFP_per_core,1)    /// sum up the contributions from each host to get the final LFP
   
  if(pc.id==0){     // only host 0 writes the output file
   strdef LFPs
   sprint(LFPs,"./LFPs/LFP_elec_%d",m) 
   LFP_single_file=new File()
   LFP_single_file.wopen(LFPs)
   LFP_per_core.printf(LFP_single_file)
   LFP_single_file.close()
   }
  pc.barrier ()    // wait for all hosts to get to this point


}
}
// Compute and write the LFP now that the simulation has finished
multiLFP()


// PROCEDURE TO SEND SPIKES TO A FILE "SPIKERASTER"--------
// Spike raster output file, named "data"
objref spikefile

spikefile = new File("data")

if(pc.id==0){     //"wopen" once by host 0 to clear the contents of the file
spikefile.wopen()
spikefile.close()
}


// spikefileout(): append every spike recorded on this host (tvec/idvec) to the
// spike file as "<time>\t <gid>" lines.
// Hosts write one after another (0, 1, 2, ...) with a barrier after each turn, so
// their output is never interleaved. The file is opened and closed once per host
// turn rather than once per spike.
proc spikefileout() { local i, rank
	pc.barrier() // wait for all hosts to get to this point

	for rank = 0, pc.nhost-1 { // host 0 first, then 1, 2, etc.
		if (rank == pc.id) {
			spikefile.aopen()                               //"aopen" to append data
			for i = 0, tvec.size()-1 {
				spikefile.printf("%8.4f\t %d\n", tvec.x[i], idvec.x[i])
			}
			spikefile.close()
		}
		pc.barrier() // wait for the writing host to finish before the next one starts
	}
}

// Write the spike raster, then shut down the parallel run
spikefileout()

// Braces keep the return values from being echoed at top level
{pc.runworker()}
{pc.done()}