/* Author: Chris Fietkiewicz, Ph.D., Case Western Reserve University
   Contact: chris.fietkiewicz@case.edu

   Description: Reproduces spike raster used for Figure 11 from "Drive latencies 
    in hypoglossal motoneurons indicate developmental change in the brainstem 
    respiratory network" by C Fietkiewicz, KA Loparo, and CG Wilson in Journal of
    Neural Engineering, 2011 J. Neural Eng. 8:065011. Voltage clamp
    and current clamp traces seen in Figure 12 can be generated and saved to files
    by setting 'recordMode' to either 1 or 2, respectively (see variable description below).
    The raster plot shows 30 cells comprised of the following groups:
    pacemakers (1 - 14), premotoneurons/intermediate (15 - 28), and motoneurons (29 - 30).

   Acknowledgement: Based on source code developed by Timothy Anderson from T Anderson, 
    R Foglyano, and C Wilson, 2009. Altered respiratory rhythm in a preBotzinger
    complex model due to addition of low-threshold, non-inactivating K+ current and tonic
    input. BMC Neuroscience, 10, p.P323.
*/

// ---- Simulation and network parameters: plain hoc globals read by the code below ----
setStop = 11500 // Simulation length in msec (also used as the v-clamp duration and the raster x-axis extent)
strdef fileSuffix
fileSuffix = "figure11" // Used for output files (inserted into every file name written by writeFiles)
PMtoICweight1 = 0.003 // Synapse weight from pacemakers (PM) to intermediate cells (IC) (cluster #1)
PMtoICweight2 = 0.003 // Synapse weight from PM to IC (cluster #2)
ICtoHGweight = 0.004 // Synapse weight from IC to hypoglossal (HG) motoneurons
clusterToClusterWeight = 0.00004 // Inter-cluster synapse weight
clusterToClusterRate = 0.57 // Probability (0 - 1) of successful connection from cell in one cluster to cell in other cluster
seedClusterConnections = 2 // seed for randomized connections
seedCells = 2 // seed for randomized cell parameter heterogeneity
recordMode = 3 // 0 = All spike times, 1 = HG input currents under v-clamp, 2 = HG voltage traces, 3 = Plot spike times but do not save files (modes 0 - 2 save files and then quit)
clusterSize = 7 // Number of cells per cluster (two clusters are created)
pmConnectQty = 4 // Number of intra-cluster connections for each cell
PMtoPMweight = 0.00004 // Intra-cluster synapse weight
fanout = 0 // Number of IC cells to left and right that receive inputs from a given PM cell (0 = only the IC cell with the same index)

load_file("cellTemplates.hoc") // NOTE(review): presumably defines the PM, IC and HG templates instantiated below (file not shown) -- confirm
objref connections // Stores PM to PM connections (0 = no connection, 1 = connection); row = source cell index, column = target cell index
connections = new Matrix(2 * clusterSize, 2 * clusterSize) // One row and one column per PM cell across both clusters
print "seedClusterConnections, seedCells = ", seedClusterConnections, seedCells

//***************************Network specification interface****************
// Object references shared by the whole script.
// nclist: NetCons created by the connect procs; netcon: scratch reference for the NetCon being configured;
// tonicweight: not used in the visible code.
objref nclist, netcon, tonicweight
// Per-type cell lists. NOTE(review): tonicecells and tonicicells are declared here but never
// instantiated in the visible code.
objref pmcells, tonicecells, tonicicells, icells, hgcells
// ranlist: RandomStream objects appended by the create procs; cells: every cell regardless of type.
objref ranlist, cells

cells = new List()
ranlist = new List()
nclist = new List()
hgcells = new List()
icells = new List()
pmcells = new List()

// Registers the cell object $o1 in the HG list and in the master cell list.
//   $o1 - cell object to register
// Returns the zero-based position of the new cell within hgcells.
func hgcell_append() { local slot
	hgcells.append($o1)
	cells.append($o1)
	slot = hgcells.count() - 1
	return slot
}

// Registers the cell object $o1 in the PM list and in the master cell list.
//   $o1 - cell object to register
// Returns the zero-based position of the new cell within pmcells.
func pmcell_append() { local slot
	pmcells.append($o1)
	cells.append($o1)
	slot = pmcells.count() - 1
	return slot
}

// Registers the cell object $o1 in the IC list and in the master cell list.
//   $o1 - cell object to register
// Returns the zero-based position of the new cell within icells.
func icell_append() { local slot
	icells.append($o1)
	cells.append($o1)
	slot = icells.count() - 1
	return slot
}

//************************NEW CELLS*********************************
// Builds the pacemaker (PM) population.
//   $1 - number of PM cells to create
// The PM list is emptied first. Each new cell is registered through pmcell_append()
// and a RandomStream constructed with the loop index is appended to ranlist.
// The master cells list and ranlist are not emptied here.
proc create3() {
	pmcells.remove_all()
	for (i = 0; i <= $1 - 1; i += 1) {
		pmcell_append(new PM())
		ranlist.append(new RandomStream(i))
	}
}
// Builds the intermediate cell (IC) population.
//   $1 - number of IC cells to create
// The IC list is emptied first. Each new cell is registered through icell_append()
// and a RandomStream constructed with the loop index is appended to ranlist.
// The master cells list and ranlist are not emptied here.
proc create4() {
	icells.remove_all()
	for (i = 0; i <= $1 - 1; i += 1) {
		icell_append(new IC())
		ranlist.append(new RandomStream(i))
	}
}
// Builds the hypoglossal motoneuron (HG) population.
//   $1 - number of HG cells to create
// The HG list is emptied first. Each new cell is registered through hgcell_append()
// and a RandomStream constructed with the loop index is appended to ranlist.
// The master cells list and ranlist are not emptied here.
proc create5() {
	hgcells.remove_all()
	for (i = 0; i <= $1 - 1; i += 1) {
		hgcell_append(new HG())
		ranlist.append(new RandomStream(i))
	}
}
//***********************************CONNECTIONS***************************

// Tears down the network: drops every stored NetCon and RandomStream and
// empties the master cell list together with all per-type cell lists.
// Takes no arguments.
proc clearnetwork() {
	nclist.remove_all()
	ranlist.remove_all()
	cells.remove_all()
	pmcells.remove_all()
	// Keep the per-type lists consistent with the (now empty) master list
	icells.remove_all()
	hgcells.remove_all()
	// The tonic-cell lists are only declared in this file, never instantiated.
	// object_id() is 0 for an objref that points to no object, so a missing
	// list is skipped instead of raising a null-object error.
	if (object_id(tonicicells)) { tonicicells.remove_all() }
	if (object_id(tonicecells)) { tonicecells.remove_all() }
}
// connectFlag: per-source-cell vector marking targets that are already connected (see connectWithinCluster).
// rs and connMat are not used in the visible code.
objref rs, connectFlag, connMat
// Only connectivity4 is read in the visible code (number of intra-cluster connections per PM cell);
// connectivity1, 2, 3 and 5 are assigned here and not referenced elsewhere in this file.
connectivity1=0
connectivity2=0
connectivity3=0
connectivity4=0
connectivity5=0

// Sets how many intra-cluster connections each PM cell will make.
//   $1 - connection count; stored in connectivity4, which connectWithinCluster() reads
proc SetPMtoPMConnect() {
	connectivity4 = $1
}

// ******************************** PM -> PM (intra-cluster) ********************************
//objref row // For storing and printing connection matrix
objref randomConnection // Random number generator for network connections only (shared by the intra- and inter-cluster procs)
randomConnection = new Random(seedClusterConnections) // Constructed from the connection seed set in the parameter section

// Helper for connectWithinCluster(): wires up one cluster of pacemaker cells.
// Every cell with index $1..$2 becomes the source of connectivity4 connections
// to distinct, randomly chosen other cells within the same index range.
//   $1  - index (in pmcells) of the first cell of the cluster
//   $2  - index (in pmcells) of the last cell of the cluster
//   $3  - weight assigned to each new NetCon
//   $s4 - label printed in front of each "source target" pair
proc connectWithinOneCluster() { local src, tgt, made, wanted
	wanted = connectivity4
	// A cell has only ($2 - $1) possible partners. Asking for more distinct
	// targets can never be satisfied and the pick loop below would never end.
	if ($2 >= $1 && wanted > $2 - $1) {
		execerror("connectWithinCluster: more connections per cell requested than there are other cells in the cluster", "")
	}
	for src = $1, $2 {
		// Select the integer target range for this source cell.
		// NOTE(review): Random.discunif() is understood to return a pick itself, so it is
		// called once per source cell to keep the sequence of draws for a given seed -- confirm before moving it.
		randomConnection.discunif($1, $2)
		connectFlag = new Vector(pmcells.count()) // all zero: no target chosen yet
		made = 0
		while (made < wanted) {
			tgt = randomConnection.repick()
			// Reject self-connections and targets already used by this source
			if (connectFlag.x[tgt] == 0 && tgt != src) {
				netcon = pmcells.object(src).connect2target(pmcells.object(tgt).synlist.object(0))
				netcon.threshold = 0
				netcon.delay = 0
				netcon.weight = $3
				nclist.append(netcon) // each NetCon is stored exactly once
				connectFlag.x[tgt] = 1
				made += 1
				connections.x[src][tgt] = 1
				print $s4, src, tgt
			}
		}
	}
}

// Creates the intra-cluster PM -> PM connections for both clusters.
//   $1 - synaptic weight for every intra-cluster connection
// The first half of pmcells is cluster #1, the second half cluster #2.
// The number of connections per cell comes from connectivity4
// (see SetPMtoPMConnect). Each connection is marked in the connections
// matrix and printed as "Cluster #k: source target".
proc connectWithinCluster() { local half
	half = pmcells.count() / 2
	connectWithinOneCluster(0, half - 1, $1, "Cluster #1: ")
	connectWithinOneCluster(half, pmcells.count() - 1, $1, "Cluster #2: ")
}

// ******************************** PM -> PM (inter-cluster) ********************************
// Adds random PM -> PM connections between the two clusters, in both directions.
// For every ordered (source, target) pair that spans the clusters, a uniform(0, 1)
// value is drawn and the pair is connected when the value is below
// clusterToClusterRate. New NetCons get clusterToClusterWeight, zero delay and
// zero threshold, are stored in nclist and are marked in the connections matrix.
proc connectClusterToCluster() {
	randomConnection.uniform(0, 1)

	// Cluster #1 -> Cluster #2
	for (i = 0; i <= pmcells.count() / 2 - 1; i += 1) {
		// This value is overwritten before it is tested; the draw itself is kept
		// so that the generator is advanced exactly as many times per source cell.
		r = randomConnection.repick()
		for (j = pmcells.count() / 2; j <= pmcells.count() - 1; j += 1) {
			r = randomConnection.repick()
			if (r < clusterToClusterRate) { // Target j is selected to receive input from source i
				netcon = pmcells.object(i).connect2target(pmcells.object(j).synlist.object(0))
				netcon.weight = clusterToClusterWeight
				netcon.delay = 0
				netcon.threshold = 0
				connections.x[i][j] = 1
				nclist.append(netcon)
			}
		}
	}

	// Cluster #2 -> Cluster #1
	for (i = pmcells.count() / 2; i <= pmcells.count() - 1; i += 1) {
		// Discarded draw, kept for the same reason as above
		r = randomConnection.repick()
		for (j = 0; j <= pmcells.count() / 2 - 1; j += 1) {
			r = randomConnection.repick()
			if (r < clusterToClusterRate) { // Target j is selected to receive input from source i
				netcon = pmcells.object(i).connect2target(pmcells.object(j).synlist.object(0))
				netcon.weight = clusterToClusterWeight
				netcon.delay = 0
				netcon.threshold = 0
				connections.x[i][j] = 1
				nclist.append(netcon)
			}
		}
	}
}

// ******************************** PM -> IC ********************************
// Connects each PM cell to the IC cells around its own index.
//   $1 - weight for connections whose source index is in the first half (cluster #1)
//   $2 - weight for connections whose source index is in the second half (cluster #2)
// Source index i drives IC cells i - fanout .. i + fanout, clipped to the valid
// IC index range 0 .. icells.count() - 1. The loop bounds come from icells, and
// the same index is used to pick the source from pmcells.
// Every connection is printed as "Connect: source target".
proc connectPMtoIC() {
	// Cluster #1
	for (i = 0; i <= icells.count() / 2 - 1; i += 1) {
		if (i - fanout < 0) {
			targetStart = 0
		} else {
			targetStart = i - fanout
		}
		if (i + fanout > icells.count() - 1) {
			targetStop = icells.count() - 1
		} else {
			targetStop = i + fanout
		}
		for (target = targetStart; target <= targetStop; target += 1) {
			netcon = pmcells.object(i).connect2target(icells.object(target).synlist.object(0))
			print "Connect: ", i, target
			netcon.weight = $1
			netcon.delay = 0
			netcon.threshold = 0
			nclist.append(netcon)
		}
	}

	// Cluster #2
	for (i = icells.count() / 2; i <= icells.count() - 1; i += 1) {
		if (i - fanout < 0) {
			targetStart = 0
		} else {
			targetStart = i - fanout
		}
		if (i + fanout > icells.count() - 1) {
			targetStop = icells.count() - 1
		} else {
			targetStop = i + fanout
		}
		for (target = targetStart; target <= targetStop; target += 1) {
			print "Connect: ", i, target
			netcon = pmcells.object(i).connect2target(icells.object(target).synlist.object(0))
			netcon.weight = $2
			netcon.delay = 0
			netcon.threshold = 0
			nclist.append(netcon)
		}
	}
}

// ******************************** IC -> HG ********************************
// Connects every IC cell to the HG motoneuron of its cluster.
//   $1 - weight for every IC -> HG connection
// IC cells in the first half of icells project to hgcells.object(0),
// those in the second half to hgcells.object(1). All NetCons get zero
// threshold and zero delay and are stored in nclist.
proc connectICtoHG() {
	// Cluster #1 -> first HG cell
	for (i = 0; i <= icells.count() / 2 - 1; i += 1) {
		netcon = icells.object(i).connect2target(hgcells.object(0).synlist.object(0))
		netcon.weight = $1
		netcon.delay = 0
		netcon.threshold = 0
		nclist.append(netcon)
	}

	// Cluster #2 -> second HG cell
	for (i = icells.count() / 2; i <= icells.count() - 1; i += 1) {
		netcon = icells.object(i).connect2target(hgcells.object(1).synlist.object(0))
		netcon.weight = $1
		netcon.delay = 0
		netcon.threshold = 0
		nclist.append(netcon)
	}
}
//*********************** CELL HETEROGENEITY *******************************
objref rpml, rpmn, ql, qn, cl, cn, q // None of these references is used in the visible code
objref randomCells // Random number generator for cell heterogeneity
randomCells = new Random(seedCells) // Constructed from the cell-parameter seed set in the parameter section

// Helper for randomizeCells(): draws one value from a uniform distribution
// centred on a mean, using the randomCells generator.
//   $1 - mean value
//   $2 - half-width of the range, as a fraction of the mean (hard limit +/-)
// Returns a value in [$1 - $1 * $2, $1 + $1 * $2].
func pickAroundMean() {
	randomCells.uniform($1 - $1 * $2, $1 + $1 * $2) // Set distribution
	return randomCells.repick() // Pick value
}

// Assigns reversal potentials and maximal conductances to every cell.
// PM cells get randomized gmax_leak, gmax_K_Leak and gmax_nap (printed as one
// line per cell); IC and HG cells get fixed values. Cells are reached through
// pmcells / icells / hgcells, so the proc does not rely on template instance
// numbers (PM[i], IC[i], HG[i]) matching the list positions.
proc randomizeCells() { local i, gLeak, gKLeak, gNaP
	print "PM Parameters..."
	for i = 0, pmcells.count() - 1 {
		pmcells.object(i).soma e_na = 50
		pmcells.object(i).soma e_leak = -65
		pmcells.object(i).soma ek = -85

		// Set gmax_leak: mean 0.00005, +/- 40 %
		gLeak = pickAroundMean(0.00005, 0.40)
		pmcells.object(i).soma gmax_leak = gLeak

		// Set gmax_K_Leak: mean 0.00002, +/- 40 %
		gKLeak = pickAroundMean(0.00002, 0.40)
		pmcells.object(i).soma gmax_K_Leak = gKLeak

		// Set gmax_nap: mean 0.00009, +/- 15 %
		gNaP = pickAroundMean(0.00009, 0.15)
		pmcells.object(i).soma gmax_nap = gNaP

		print gLeak, gKLeak, gNaP
	}

	// Reduction of legacy code which previously randomized ICs
	for i = 0, icells.count() - 1 {
		icells.object(i).soma e_na = 50
		icells.object(i).soma e_leak = -65
		icells.object(i).soma ek = -85
		icells.object(i).soma gmax_K_Leak = 0.0000666667 * 0.707
		icells.object(i).soma gmax_leak = 0.0000666667 * 0.2922
		icells.object(i).soma gmax_nap = 0.0000246667
	}

	// Reduction of legacy code which previously randomized HGs
	for i = 0, hgcells.count() - 1 {
		hgcells.object(i).soma ena = 60
		hgcells.object(i).soma eca = 40
		hgcells.object(i).soma ek = -80
		hgcells.object(i).soma gmax_leak_hg = 0.0001476340694
		hgcells.object(i).soma gmax_NaP_hg = 0.00003974763407
	}
}

//*****************Changing Extracellular K****************
// Sets the bath potassium parameter on every PM soma.
//   $1 - value written to both kbath_pump and kbath_potassium
// Cells are reached through the pmcells list, so the proc does not rely on
// template instance numbers (PM[i]) matching the list positions.
proc kbath() { local i
	for i = 0, pmcells.count() - 1 {
		pmcells.object(i).soma kbath_pump = $1
		pmcells.object(i).soma kbath_potassium = $1
	}
}

//********************** RECORD *********************
// netcon: scratch NetCon reference (re-declared here); nil: never assigned, passed as the
// empty target of the spike-detecting NetCons in prepareRecording.
objref netcon, nil
// currentVec, vclamp and voltageVec hold one entry per HG cell; the arrays have room for two cells.
objref spikeTimeVec, spikeIdVec, currentVec[2], vclamp[2], voltageVec[2]
spikeTimeVec = new Vector() // Spike times of all recorded cells
spikeIdVec = new Vector() // Cell ID stored alongside each spike time

// Sets up the recordings selected by recordMode. Called from init().
//   recordMode 0 or 3 - spike times and IDs of every PM, IC and HG cell
//   recordMode 1      - current injected by a voltage clamp on each HG soma
//   recordMode 2      - somatic voltage of each HG cell
// HG cells are always reached through the hgcells list. The per-HG arrays
// (currentVec, vclamp, voltageVec) are declared with two elements.
proc prepareRecording() { local i
	// Output spike times in current clamp
	if (recordMode == 0 || recordMode == 3) {
		// One detector per cell: the source is the somatic voltage at 0.5 and there is
		// no target (nil). The remaining NetCon arguments are threshold 0, delay 1, weight 0.
		// IDs are 1-based and run through the PM cells, then the IC cells, then the HG cells.
		for i = 0, pmcells.count() - 1 {
			pmcells.object(i).soma netcon = new NetCon(&v(0.5), nil, 0, 1, 0)
			netcon.record(spikeTimeVec, spikeIdVec, i + 1)
		}

		for i = 0, icells.count() - 1 {
			icells.object(i).soma netcon = new NetCon(&v(0.5), nil, 0, 1, 0)
			netcon.record(spikeTimeVec, spikeIdVec, i + 1 + pmcells.count())
		}

		for i = 0, hgcells.count() - 1 {
			hgcells.object(i).soma netcon = new NetCon(&v(0.5), nil, 0, 1, 0)
			netcon.record(spikeTimeVec, spikeIdVec, i + 1 + pmcells.count() + icells.count())
		}
	}

	// Input current in voltage clamp
	if (recordMode == 1) {
		for i = 0, hgcells.count() - 1 {
			// Create v-clamps: held at -71 for the whole simulation (setStop)
			hgcells.object(i).soma vclamp[i] = new SEClamp(0.5)
			vclamp[i].dur1 = setStop
			vclamp[i].rs = 0.01
			vclamp[i].amp1 = -71

			// Set up recordings of the clamp current
			currentVec[i] = new Vector()
			currentVec[i].record(&vclamp[i].i)
		}
	}

	// Output voltage in current clamp
	if (recordMode == 2) {
		for i = 0, hgcells.count() - 1 {
			// Set up recordings of the somatic voltage
			voltageVec[i] = new Vector()
			voltageVec[i].record(&hgcells.object(i).soma.v(0.5))
		}
	}
}

//****************File Output*************
objref f, state // f: File object reused for every output file; state: not used in the visible code

// Saves the results that belong to the current recordMode.
//   $s1 - suffix inserted into every output file name
// recordMode 0: connections_<suffix>.txt, spikeTime_<suffix>.txt, spikeID_<suffix>.txt
// recordMode 1: in<k>_<suffix>.bin  - one file per HG cell, from currentVec
// recordMode 2: out<k>_<suffix>.bin - one file per HG cell, from voltageVec
// Any other recordMode writes nothing. <k> is the 1-based HG cell number.
proc writeFiles() {
	if (recordMode == 0) {
		strdef filename

		// Connection matrix
		sprint(filename, "connections_%s.txt", $s1)
		f = new File()
		f.wopen(filename)
		connections.fprint(0, f)
		f.close()

		// i-clamp spike times
		sprint(filename, "spikeTime_%s.txt", $s1)
		f = new File()
		f.wopen(filename)
		spikeTimeVec.printf(f)
		f.close()

		// i-clamp spike IDs
		sprint(filename, "spikeID_%s.txt", $s1)
		f = new File()
		f.wopen(filename)
		spikeIdVec.printf(f)
		f.close()
	} else if (recordMode == 1) {
		strdef fileName
		f = new File()
		// One binary file of clamp current per HG cell
		for (i = 0; i <= hgcells.count() - 1; i += 1) {
			sprint(fileName, "in%d_%s.bin", i + 1, $s1)
			f.wopen(fileName)
			currentVec[i].vwrite(f, 3)
			f.close()
		}
	} else if (recordMode == 2) {
		strdef fileName
		f = new File()
		// One binary file of somatic voltage per HG cell
		for (i = 0; i <= hgcells.count() - 1; i += 1) {
			sprint(fileName, "out%d_%s.bin", i + 1, $s1)
			f.wopen(fileName)
			voltageVec[i].vwrite(f, 3)
			f.close()
		}
	}
}

//*************************INITIALIZATION************************
objref f, state // Re-declaration of the references already declared in the file-output section

// Initialization routine: sets global temperature and ion concentrations, sets up the
// recordings for the current recordMode, then initializes the model state.
// NOTE(review): presumably overrides the init() of NEURON's standard run library so that
// stdinit() ends up calling it -- confirm.
proc init() {
	celsius = 27 // Simulation temperature
	ko0_k_ion = 5 // Initial extracellular K+ concentration
	ki0_k_ion = 120 // Initial intracellular K+ concentration
	cai0_ca_ion = 5e-05 // Initial intracellular Ca2+ concentration
	cao0_ca_ion = 2 // Initial extracellular Ca2+ concentration
	prepareRecording() // Must come before finitialize() so the recordings exist when the state is initialized
	finitialize(v_init) // v_init is not defined in this file; presumably supplied by the standard run library -- confirm
	if (cvode.active()) {
		cvode.re_init() // Variable-step integrator in use: re-initialize it
	} else {
		fcurrent() // Fixed step: make currents consistent with the initialized state
	}
	frecord_init() // Reset the recording vectors for the new run
}

//**************** Setup *************
// Build the cells, assign their parameters, then wire the layers: PM -> PM, PM -> IC, IC -> HG.
// The connect procs draw from a shared random generator, so their order affects the resulting wiring.
create3(clusterSize * 2) // Create PMs
create4(clusterSize * 2) // Create ICs
create5(2)               // Create HGs
randomizeCells() // Reversal potentials and conductances: randomized for PMs, fixed for ICs and HGs
kbath(8) // Write 8 to kbath_pump and kbath_potassium on every PM soma
SetPMtoPMConnect(pmConnectQty) // Stored in connectivity4 for connectWithinCluster()
connectWithinCluster(PMtoPMweight)
connectClusterToCluster()
connectPMtoIC(PMtoICweight1, PMtoICweight2)
connectICtoHG(ICtoHGweight)

//**************** Run ****************
stdinit() // NOTE(review): presumably the standard run library initializer, expected to invoke init() defined above -- confirm
batch_run(setStop, 0.1) // Run until setStop msec; NOTE(review): the second argument is presumably the batch step in msec -- confirm

//************* Print/save results ************
connections.printf() // Print connection matrix to screen for later retrieval
objref graster
if (recordMode != 3) {
	// A saving mode was chosen: write the output files and leave NEURON
	writeFiles(fileSuffix)
	quit()
} else {
	// Plot mode: raster of spike times, one row per cell ID, x axis 0 .. setStop
	graster = new Graph(0)
	graster.view(0, 0, setStop, pmcells.count() + icells.count() + hgcells.count(), 0, 0, 900, 600)
	spikeIdVec.mark(graster, spikeTimeVec, "|", 6)
}