/*
Stimulus protocol using either random uniform or oscillatory synaptic activity
to illustrate single-branch heterosynaptic effects and NMDA spikes.
Demonstrating the synaptic plasticity rule by Ebner et al. (2019)
https://doi.org/10.1016/j.celrep.2019.11.068
Made to work with the Hay et al. (2011) L5b pyramidal cell model
This code uses NEURON's vecevent.mod mechanism.
*/

load_file("nrngui.hoc")

// ## Parameters
// Global constants read by the stimulus generators, the synapse constructors and the main loop.
DT = 0.025 				// ms, integration step size
WARM_UP = 2350 			// ms, time at which the first stimulation window starts (events are offset by this value)
RUNTIME = 350			// ms, interval of synaptic activity
PAUSETIME = 0			// ms, interval of silence (only needed if repetitions > 1)
COOL_DOWN = 100			// ms, time after end of events
REPETITIONS = 1			// number of cycles (one cycle = run + pause)
A_SYN_WEIGHT = 0.0025	// µS, conductance of a single synapse at location A
B_SYN_WEIGHT = 0.0025	// µS, conductance of a single synapse at location B
N_SYN_A = 8				// Number of plastic synapses at A
N_SYN_B = 8				// Number of plastic synapses at B
FREQ = 8				// Hz, average event frequency in uniform mode (per-ms event probability is FREQ / 1000)
WAVE_FREQ = 8			// Hz, frequency of the oscillation itself
WAVE_AMP = 0.05			// peak-to-peak amplitude of the sine wave (oscillation() scales the sine by 0.5 * amp)
WAVE_HEIGHT = 0.005		// baseline (offset) of the sine wave event probability
SYN_PAUSE = 0			// ms, turn active synapses silent for this time (deactivated by default)
A_MODE = 1				// 0 ... uniform event probability (make_stim); otherwise oscillating probability (make_stim_wave)
B_MODE = 0				// 0 ... uniform event probability (make_stim); otherwise oscillating probability (make_stim_wave)
S_AMPA = 0.5			// AMPA current ratio
S_NMDA = 0.5			// NMDA current ratio
BRANCH = 53				// example branch (ID of apical section)
DIST_A = 0.8			// relative distance of cluster A on the branch
DIST_B = 0.2			// relative distance of cluster B on the branch

pi = 3.1415926			// truncated value of pi used by oscillation()

// One simulation run is performed per entry of SEED; each entry seeds the Random generator r.
objref SEED
SEED = new Vector()
// SEED.indgen(1,100,1)	// alternative: seeds 1, 2, ..., 100
SEED.append(1)			// deterministic

// ## References
// Declared here, (re)assigned inside the main loop for every run.
objref trec, vrec_a, vrec_b			// recorded time and local voltages at A and B
objref a_syn_list, b_syn_list, a_nc_list, b_nc_list, a_stim_list, b_stim_list	// synapses, NetCons and VecStims per location
objref a_syn_loc, b_syn_loc			// section ID / relative distance per synapse
objref a_syn_weights, b_syn_weights	// recorded weight traces per synapse
objref r, tempobj					// r: Random generator used by rand_vec(); tempobj is not used in this file

// ## Create cell
// Load the morphology import tools, the biophysics and the cell template,
// then instantiate the cell from the morphology file. The load order matters:
// the template file must be loaded before L5PCtemplate can be instantiated.
load_file("import3d.hoc")
objref L5PC
strdef morphology_file
morphology_file = "morphologies/cell1.asc"	// path is relative to the working directory
load_file("models/L5PCbiophys4.hoc")
load_file("models/L5PCtemplate.hoc")
L5PC = new L5PCtemplate(morphology_file)

proc init() { local warm_step
	// Custom initialization: set all membrane potentials to -70 mV, then
	// advance the model with a coarse time step before the fine-grained
	// simulation starts.
	// The warm-up covers 228 steps of 10 ms = 2280 ms, i.e. it ends before
	// WARM_UP (2350 ms), the time at which the first synaptic events are scheduled.
	// The loop index is local so that the global `counter` used by the main
	// script is not overwritten whenever the model is (re)initialized.
	finitialize(-70)
	dt = 10
	for warm_step = 1, 228 { fadvance() }	// Run warm-up phase using big steps
	dt = DT		// restore the regular integration step
}

obfunc rand_vec() { localobj samples
	// Return a new Vector filled with random numbers from the global generator r.
	//   $1 ... number of values
	//   $2 ... lower bound of the uniform distribution
	//   $3 ... upper bound of the uniform distribution
	samples = new Vector($1)
	r.uniform($2, $3)		// select the distribution used by the following setrand()
	samples.setrand(r)
	return samples
}

obfunc make_stim() { local tick, rep, refractory, window_start, p_event localobj draws, spike_times, source
	// Build a VecStim that plays events drawn with a constant probability of
	// FREQ / 1000 per 1-ms bin. One window of RUNTIME bins is generated per
	// repetition, starting at WARM_UP and separated by PAUSETIME.
	// Returns the VecStim object.
	spike_times = new Vector()
	refractory = 0
	window_start = WARM_UP
	p_event = FREQ / 1000
	for rep = 1, REPETITIONS {
		// NOTE(review): total_time values are drawn although only the first RUNTIME are read
		// (make_stim_wave draws RUNTIME). Kept as is because changing the number of draws
		// shifts the random stream and therefore the results obtained for a given seed.
		draws = rand_vec(total_time, 0, 1)
		for tick = 0, RUNTIME - 1 {
			if (refractory == 0) {
				if (draws.x[tick] <= p_event) {
					spike_times.append(window_start + tick)	// append spike times
					refractory = SYN_PAUSE
				}
			}
			// allow the next event to happen only after a synaptic pause (deactivated by default)
			if (refractory > 0) { refractory -= 1 }
		}
		window_start += RUNTIME + PAUSETIME
	}
	source = new VecStim()	// using the VecStim class
	source.play(spike_times)
	return source
}

obfunc oscillation() { local k, period localobj wave
	// Return a Vector with one entry per ms of RUNTIME holding a sine wave:
	//   $1 ... frequency in Hz
	//   $2 ... amplitude (the sine is scaled by 0.5 * $2, i.e. $2 is the peak-to-peak range)
	//   $3 ... baseline added to the scaled sine
	// The values are used as per-ms event probabilities by make_stim_wave();
	// entries below zero can never trigger an event there.
	wave = new Vector(RUNTIME)	// total stimulation time (in ms)
	period = 1000 / $1			// ms per cycle
	for k = 0, wave.size() - 1 {
		wave.x[k] = sin((k * 2 * pi) / period) * 0.5 * $2 + $3
	}
	return wave
}

obfunc make_stim_wave() { local tick, rep, refractory, window_start localobj p_wave, draws, spike_times, source
	// Build a VecStim whose event probability follows the Vector passed as $o1
	// (one probability per 1-ms bin, at least RUNTIME entries are read).
	// One window of RUNTIME bins is generated per repetition, starting at
	// WARM_UP and separated by PAUSETIME. Returns the VecStim object.
	p_wave = $o1
	spike_times = new Vector()
	refractory = 0
	window_start = WARM_UP
	for rep = 1, REPETITIONS {
		draws = rand_vec(RUNTIME, 0, 1)
		for tick = 0, RUNTIME - 1 {
			if (refractory == 0) {
				if (draws.x[tick] <= p_wave.x[tick]) {
					spike_times.append(window_start + tick)	// append spike times
					refractory = SYN_PAUSE
				}
			}
			// allow the next event to happen only after a synaptic pause (deactivated by default)
			if (refractory > 0) { refractory -= 1 }
		}
		window_start += RUNTIME + PAUSETIME
	}
	source = new VecStim()
	source.play(spike_times)
	return source
}

proc set_syn4p_params() { localobj syn
	// Apply the parameter set shared by all plastic synapses to the Syn4P object passed as $o1.
	syn = $o1
	syn.tau_a = 0.2 			// time constant of EPSP rise
	syn.tau_b = 2 				// time constant of EPSP decay
	syn.e = 0					// reversal potential
	syn.w_pre_init = 0.5		// pre factor initial value
	syn.w_post_init = 2.0		// post factor initial value
	syn.s_ampa = S_AMPA			// contribution of AMPAR currents
	syn.s_nmda = S_NMDA			// contribution of NMDAR currents
	syn.tau_G_a = 2 			// time constant of synaptic event G2 (rise)
	syn.tau_G_b = 50 			// time constant of synaptic event G2 (decay)
	syn.m_G = 10				// slope of the saturation function for G2
	syn.A_LTD_pre = 1.5e-3		// amplitude of pre-LTD
	syn.A_LTP_pre = 2.5e-4		// amplitude of pre-LTP
	syn.A_LTD_post = 7.5e-4		// amplitude of post-LTD
	syn.A_LTP_post = 7.8e-2		// amplitude of post-LTP
	syn.tau_u_T = 10 			// time constant for filtering u to calculate T
	syn.theta_u_T = -60			// voltage threshold applied to u to calculate T
	syn.m_T = 1.7				// slope of the saturation function for T
	syn.theta_u_N = -30			// voltage threshold applied to u to calculate N
	syn.tau_Z_a = 1				// time constant of presynaptic event Z (rise)
	syn.tau_Z_b = 15 			// time constant of presynaptic event Z (decay)
	syn.m_Z = 6					// slope of the saturation function for Z
	syn.tau_N_alpha = 7.5 		// time constant for calculating N-alpha
	syn.tau_N_beta = 30			// time constant for calculating N-beta
	syn.m_N_alpha = 2			// slope of the saturation function for N_alpha
	syn.m_N_beta = 10			// slope of the saturation function for N_beta
	syn.theta_N_X = 0.2			// threshold for N to calculate X
	syn.theta_u_C = -68			// voltage threshold applied to u to calculate C
	syn.theta_C_minus = 15		// threshold applied to C for post-LTD (P activation)
	syn.theta_C_plus = 35		// threshold applied to C for post-LTP (K-alpha activation)
	syn.tau_K_alpha = 15 		// time constant for filtering K_alpha to calculate K_alpha_bar
	syn.tau_K_gamma = 20 		// time constant for filtering K_beta to calculate K_gamma
	syn.m_K_alpha = 1.5			// slope of the saturation function for K_alpha
	syn.m_K_beta = 1.7			// slope of the saturation function for K_beta
	syn.s_K_beta = 100			// scaling factor for calculation of K_beta
}

proc make_syn_a() { local section, dist localobj syn, stim
	// Make one plastic synapse at location A and connect it to its own event source.
	//   $1 ... ID of the apical section
	//   $2 ... relative distance along the section
	// Appends the synapse, its VecStim and its NetCon to a_syn_list,
	// a_stim_list and a_nc_list. The order synapse -> stimulus -> NetCon is
	// kept because the stimulus generators consume random numbers from r.
	section = $1	// Section ID
	dist = $2		// Relative distance
	L5PC.apic[section] syn = new Syn4P(dist)	// at the specified location along the section
	a_syn_list.append(syn)
	if(A_MODE == 0) {		// check for uniform or oscillating probability
		stim = make_stim()
	} else {
		stim = make_stim_wave(oscillation(WAVE_FREQ, WAVE_AMP, WAVE_HEIGHT))
	}
	a_stim_list.append(stim)
	a_nc_list.append(new NetCon(stim, syn, 0, 0, A_SYN_WEIGHT))
	set_syn4p_params(syn)
}

proc make_syn_b() { local section, dist localobj syn, stim
	// Make one plastic synapse at location B and connect it to its own event source.
	//   $1 ... ID of the apical section
	//   $2 ... relative distance along the section
	// Appends the synapse, its VecStim and its NetCon to b_syn_list,
	// b_stim_list and b_nc_list. The order synapse -> stimulus -> NetCon is
	// kept because the stimulus generators consume random numbers from r.
	section = $1	// Section ID
	dist = $2		// Relative distance
	L5PC.apic[section] syn = new Syn4P(dist)	// at the specified location along the section
	b_syn_list.append(syn)
	if(B_MODE == 0) {		// check for uniform or oscillating probability
		stim = make_stim()
	} else {
		stim = make_stim_wave(oscillation(WAVE_FREQ, WAVE_AMP, WAVE_HEIGHT))
	}
	b_stim_list.append(stim)
	b_nc_list.append(new NetCon(stim, syn, 0, 0, B_SYN_WEIGHT))
	set_syn4p_params(syn)
}
	
// ## Main
// Runs one complete simulation per entry of SEED. All synapses, stimuli and
// recordings are rebuilt for every run; after the loop the global lists and
// recorded vectors hold the state of the last run.
objref a_weightfile, b_weightfile
objref a_g_file, b_g_file
objref a_voltagefile, b_voltagefile

for runs = 0, SEED.size() - 1 {		// Loop through the specified number of random seeds
	if(runs == 0) {	// Stopwatch Start
		startrun = startsw()
	}
	print "Run ", runs + 1, "of ", SEED.size()
	a_syn_list = new List()	// Lists to store multiple objects: plastic synapses, NetCons and VecStims
	b_syn_list = new List()
	a_nc_list = new List()
	b_nc_list = new List()
	a_stim_list = new List()
	b_stim_list = new List()
	a_syn_loc = new Matrix(N_SYN_A, 2)	// Matrices to store section numbers and relative distances of plastic synapses
	b_syn_loc = new Matrix(N_SYN_B, 2)
	a_syn_weights = new List() // Lists to store recorded weight vectors
	b_syn_weights = new List()

	r = new Random(SEED.x(runs)) // deterministic

	L5PC.soma(0.5) distance()

	// total_time must be set before the synapses are made: make_stim() reads it
	total_time = WARM_UP + REPETITIONS * (RUNTIME + PAUSETIME) + COOL_DOWN
	tstop = total_time

	for counter = 0, N_SYN_A - 1 {		// make synapses at location A
		make_syn_a(BRANCH, DIST_A)
		a_syn_loc.x(counter, 0) = BRANCH
		a_syn_loc.x(counter, 1) = DIST_A
	}
	for counter = 0, N_SYN_B - 1 {		// make synapses at location B
		make_syn_b(BRANCH, DIST_B)
		b_syn_loc.x(counter, 0) = BRANCH
		b_syn_loc.x(counter, 1) = DIST_B
	}

	// Recording
	trec = new Vector()
	trec.record(&t)
	vrec_a = new Vector()
	vrec_a.record(&L5PC.apic[BRANCH].v(DIST_A)) // voltage at location A
	vrec_b = new Vector()
	vrec_b.record(&L5PC.apic[BRANCH].v(DIST_B))	// voltage at location B
	for counter = 0, N_SYN_A - 1 {				// synaptic weights at location A
		a_syn_weights.append(new Vector())
		a_syn_weights.o(counter).record(&a_syn_list.o(counter).w)
	}
	for counter = 0, N_SYN_B - 1 {				// synaptic weights at location B
		b_syn_weights.append(new Vector())
		b_syn_weights.o(counter).record(&b_syn_list.o(counter).w)
	}

	// run() of the standard run system initializes the model through init()
	// before integrating up to tstop. A separate init() call in front of it
	// would only execute the warm-up phase twice, so it is not issued here.
	run()

	if(runs == 0) {	// Stopwatch Stop
		stopsw()
		runtime = startsw() - startrun
		print "----------------------------------\nRUNTIME ESTIMATION"
		print "----------------------------------\n"
		print "One run: ", runtime, "s, Number of runs: ", SEED.size()
		// extrapolate the duration of the first run to all seeds and split it into d/h/min/s
		runtime = runtime * SEED.size()
		days = int(runtime / 86400)
		temp = runtime % 86400
		hours = int(temp / 3600)
		temp = temp % 3600
		minutes = int(temp / 60)
		temp = temp % 60
		seconds = temp
		print "Total runtime: ", days, "days, ", hours, "hours, ", minutes, "minutes, ", int(seconds), "seconds."
		print "----------------------------------\n"
	}
}

// ## Check for weight changes
// Collect the final weight w of every synapse of the last run and print the
// extremes (with the index of the first synapse that has this value), all
// values and the mean, separately for locations A and B.
objref weights_a, weights_b
weights_a = new Vector()
weights_b = new Vector()
for counter = 0, a_syn_list.count() - 1 {
	weights_a.append(a_syn_list.o(counter).w)
}
for counter = 0, b_syn_list.count() - 1 {
	weights_b.append(b_syn_list.o(counter).w)
}

a_w_max = weights_a.max()
a_w_min = weights_a.min()
b_w_max = weights_b.max()
b_w_min = weights_b.min()

print "----------------------------------"
print "Max A: ", a_w_max, ", Synapse #", weights_a.indwhere("==", a_w_max)
print "Min A: ", a_w_min, ", Synapse #", weights_a.indwhere("==", a_w_min)
// NOTE(review): printf() writes the elements itself; its return value is printed after them - confirm this output is intended
print "Weights: ", weights_a.printf()
print "Average A: ", weights_a.mean()
print "----------------------------------"
print "Max B: ", b_w_max, ", Synapse #", weights_b.indwhere("==", b_w_max)
print "Min B: ", b_w_min, ", Synapse #", weights_b.indwhere("==", b_w_min)
print "Weights: ", weights_b.printf()
print "Average B: ", weights_b.mean()
print "----------------------------------"

// ## Visualization
// The recorded vectors are drawn without an x vector, so the x axis is the
// sample index, not time in ms. The x range is therefore taken from the number
// of recorded samples instead of a constant that only fits the default
// WARM_UP / RUNTIME / COOL_DOWN / DT values.
n_samples = trec.size()

objref voltplot		// Voltage
voltplot = new Graph(0)		// A & B local voltage over time
voltplot.size(0, n_samples, -85, 10)
voltplot.view(-0.3, -85, n_samples, 105, 65, 105, 300, 200)
vrec_b.line(voltplot, 3, 2)		// voltage at B in blue (color index 3)
vrec_a.line(voltplot, 7, 2)		// voltage at A in violet (color index 7)

objref wplot		// Weights
// y range: smallest and largest final weight over both locations
if(weights_a.min() < weights_b.min()) {
	temp_min = weights_a.min()
} else {
	temp_min = weights_b.min()
}
if(weights_a.max() > weights_b.max()) {
	temp_max = weights_a.max()
} else {
	temp_max = weights_b.max()
}
if(temp_max == temp_min) {
	// all final weights are identical: widen the range so the view does not have zero height
	temp_min -= 0.5
	temp_max += 0.5
}
wplot = new Graph(0)
wplot.size(0, n_samples, temp_min, temp_max)
wplot.view(0, temp_min, n_samples, temp_max - temp_min, 65, 105, 300, 200)
for counter = 0, a_syn_weights.count() - 1 {
	a_syn_weights.o(counter).line(wplot, 7, 2)		// Weights at A in violet (color index 7)
}
for counter = 0, b_syn_weights.count() - 1 {
	b_syn_weights.o(counter).line(wplot, 3, 2)		// Weights at B in blue (color index 3)
}