/*
Stimulus protocol using random synaptic activity to evoke BAC firing.
Similar to the approach of Shai et al. (2015).
Demonstrating the synaptic plasticity rule by Ebner et al. (2019)
https://doi.org/10.1016/j.celrep.2019.11.068
Made to work with the Hay et al. (2011) L5b pyramidal cell model
This code uses NEURON's vecevent.mod mechanism.
*/

load_file("nrngui.hoc")

// ## Parameters
// Global settings read by the procedures and top-level statements further down in this file.
MODE = 2					// 0 ... basal only / 1 ... apical only / 2 ... apical + basal (applied to the input weights below)
DT = 0.025 					// ms, integration step size
WARM_UP = 2300 				// ms, delay from beginning of simulation to approach steady state; stimulus event times start here
RUNTIME = 100				// ms, interval of actual synaptic activity
PAUSETIME = 100				// ms, interval of silence (only needed if repetitions > 1)
REPETITIONS = 1				// number of cycles (one cycle = run + pause)
FREQ = 10					// Hz, average event frequency for Poisson-distributed random activity
SYN_PAUSE = 0				// ms, prevents synapses from being active again for some time (deactivated by default)
A_SYN_WEIGHT = 0.0025		// µS, conductance of a single apical plastic synapse
B_SYN_WEIGHT = 0.0025		// µS, conductance of a single basal plastic synapse
A_INPUT_WEIGHT = 0.0025		// µS, conductance of a single apical input synapse
B_INPUT_WEIGHT = 0.0025		// µS, conductance of a single basal input synapse
N_SYN_APICAL = 10			// Number of apical plastic synapses
N_SYN_BASAL = 10			// Number of basal plastic synapses
N_INPUT_APICAL = 50			// Number of apical input/background synapses
N_INPUT_BASAL = 300			// Number of basal input/background synapses
SEED = 90					// (Example) seed for random number generator (deterministic simulations)

// MODE is implemented by zeroing the weight of the non-plastic input/background synapses
// on the unused side; the plastic-synapse weights (A_SYN_WEIGHT, B_SYN_WEIGHT) are not changed here.
if(MODE == 0) { A_INPUT_WEIGHT = 0.0 }
if(MODE == 1) { B_INPUT_WEIGHT = 0.0 }
	
// ## Object references
// Declared up front so that the procedures defined below can refer to them by name.
objref trec, vrec_apic, vrec_basal	// recorded time and local membrane voltages
objref a_dist, b_dist, a_syn_rand_dist, b_syn_rand_dist, a_input_rand_dist, b_input_rand_dist	// cumulated distance vectors and random candidate distances
objref a_syn_list, b_syn_list, a_nc_list, b_nc_list, a_stim_list, b_stim_list	// plastic synapses, NetCons and VecStims
objref a_input_list, b_input_list, a_syn_loc, b_syn_loc, a_input_loc, b_input_loc	// input synapses and location matrices
objref a_syn_weights, b_syn_weights	// recorded weight traces of the plastic synapses
objref r, tempobj	// random number generator and scratch object

// ## Create cell
// The cell is instantiated from the template L5PCtemplate with the morphology file path as argument.
// NOTE(review): the template and biophysics are presumably defined in the two model files loaded below — confirm.
load_file("import3d.hoc")
objref L5PC
strdef morphology_file
morphology_file = "morphologies/cell1.asc"
load_file("models/L5PCbiophys4.hoc")
load_file("models/L5PCtemplate.hoc")
L5PC = new L5PCtemplate(morphology_file)

/*
	Object notation hereafter:
	a ... apical / b ... basal
	syn ... plastic synapse / input ... non-plastic input/background synapse
*/

a_syn_list = new List()	// Lists to store multiple objects: plastic synapses, input synapses, NetCons and VecStims
b_syn_list = new List()
a_nc_list = new List()
b_nc_list = new List()
a_stim_list = new List()
b_stim_list = new List()
a_input_list = new List()
b_input_list = new List()
a_syn_loc = new Matrix(N_SYN_APICAL, 2)	// Matrices to store section numbers (column 0) and relative distances (column 1) of plastic and input synapses
b_syn_loc = new Matrix(N_SYN_BASAL, 2)
a_input_loc = new Matrix(N_INPUT_APICAL, 2)
b_input_loc = new Matrix(N_INPUT_BASAL, 2)
a_syn_weights = new List() // Lists to store recorded weight vectors
b_syn_weights = new List()

r = new Random(SEED) // deterministic: single generator shared by all rand_vec() calls

// NOTE(review): presumably sets the soma as origin for the distance(x) calls used below — confirm
L5PC.soma(0.5) distance()
	
// Total simulated time: warm-up followed by REPETITIONS cycles of activity + pause
total_time = WARM_UP + REPETITIONS * (RUNTIME + PAUSETIME)
tstop = total_time

proc init() {
	// Initialization: start all voltages at -70 mV, then cover the bulk of the
	// warm-up phase with coarse steps before switching to the fine step DT.
	// NOTE(review): the step count is hard-coded (224 steps at dt = 10) and does
	// not follow WARM_UP — confirm before changing either value.
	finitialize(-70)
	dt = 10
	i = 1
	while (i <= 224) {	// Run warm-up phase using big steps
		fadvance()
		i += 1
	}
	print t
	dt = DT	// fine integration step for the remainder of the run
}

obfunc rand_vec() { local count, lo, hi localobj samples
	// Return a new Vector holding random numbers from the shared generator r.
	//   $1 ... number of values
	//   $2 ... lower end of the uniform interval
	//   $3 ... upper end of the uniform interval
	count = $1
	lo = $2
	hi = $3
	r.uniform(lo, hi)	// select the uniform distribution on r
	samples = new Vector(count)
	samples.setrand(r)	// fill every element from r
	return samples
}

obfunc make_dist_vec() { local running_total, sec_span localobj sec_list, cumulative
	// Build the vector of cumulated section ending points.
	//   $o1 ... SectionList with all sections to be included
	// Returns a Vector that starts with 0 and then holds, for every section
	// in list order, the running sum of (distance(1.0) - distance(0.0)).
	sec_list = $o1
	cumulative = new Vector()
	cumulative.append(0.0)	// the first section starts at 0
	running_total = 0
	forsec sec_list {
		sec_span = distance(1.0) - distance(0.0)	// extent of this section in terms of distance()
		running_total += sec_span
		cumulative.append(running_total)
	}
	return cumulative
}

obfunc get_rel_loc() { local absdist, idx, n_sec, lo, hi, section, reldist localobj ending_points, rel_loc
	// Find the section and the relative location that correspond to an
	// absolute distance within a cumulated distance vector.
	//   $o1 ... ending points vector (first entry is the start of the first section)
	//   $2  ... absolute distance within the range of the ending points vector
	// Returns a Vector with two entries: (section index, relative distance 0..1).
	// Stops with an error if the distance lies outside the ending points vector.
	ending_points = $o1
	absdist = $2
	n_sec = ending_points.size() - 1	// number of sections described by the vector
	section = -1						// -1 ... no matching section found yet
	lo = 0
	hi = 0
	for idx = 0, n_sec - 1 {
		lo = ending_points.x(idx)
		hi = ending_points.x(idx + 1)
		// Half-open interval [lo, hi), so a distance exactly on an ending point
		// belongs to the section that starts there; the very last ending point
		// is assigned to the last section.
		if(absdist >= lo && (absdist < hi || (idx == n_sec - 1 && absdist == hi))) {
			section = idx // Index of the section in the vector containing specified sections
			break
		}
	}
	if(section < 0) {
		execerror("get_rel_loc: distance lies outside the cumulated distance vector", "")
	}
	reldist = 0
	if(hi > lo) { reldist = (absdist - lo) / (hi - lo) } // Relative distance along the section (0 for a zero-length section)
	rel_loc = new Vector()
	rel_loc.append(section, reldist)
	return rel_loc
}

obfunc make_stim() { local rep, ms, refractory, window_start, p_event localobj draws, spike_times, vs
	// Create a VecStim that plays random event times.
	// For every repetition, each 1 ms slot of the RUNTIME window produces an
	// event when its random draw is <= FREQ / 1000 and no synaptic pause is
	// pending. Event times are counted from WARM_UP; successive windows are
	// RUNTIME + PAUSETIME apart.
	spike_times = new Vector()
	p_event = FREQ / 1000	// event probability per 1 ms slot
	refractory = 0			// remaining synaptic pause (in slots)
	window_start = WARM_UP
	for rep = 1, REPETITIONS {
		draws = rand_vec(total_time, 0, 1)	// one fresh set of draws per repetition
		ms = 0
		while (ms < RUNTIME) {
			if(refractory == 0) {
				if(draws.x(ms) <= p_event) {
					spike_times.append(window_start + ms)	// append spike time
					refractory = SYN_PAUSE
				}
			}
			if(refractory > 0) { refractory -= 1 }	// count down the synaptic pause (deactivated by default)
			ms += 1
		}
		window_start += RUNTIME + PAUSETIME
	}
	vs = new VecStim()	// uses the VecStim class
	vs.play(spike_times)
	return vs
}

proc make_syn_apical() { local section, dist localobj syn, stim
	// Create one apical plastic synapse (Syn4P) with its own random VecStim.
	//   $1 ... index into L5PC.apic
	//   $2 ... relative location along that section
	// The synapse, its stimulus and the connecting NetCon (weight A_SYN_WEIGHT)
	// are appended to a_syn_list, a_stim_list and a_nc_list.
	section = $1
	dist = $2
	L5PC.apic[section] syn = new Syn4P(dist)	// at the specified location along the section
	a_syn_list.append(syn)
	stim = make_stim()
	a_stim_list.append(stim)
	a_nc_list.append(new NetCon(stim, syn, 0, 0, A_SYN_WEIGHT))

	// Synaptic conductance
	syn.tau_a = 0.2				// time constant of EPSP rise
	syn.tau_b = 2				// time constant of EPSP decay
	syn.e = 0					// reversal potential
	syn.w_pre_init = 0.5		// pre factor initial value
	syn.w_post_init = 2.0		// post factor initial value
	syn.s_ampa = 0.8			// contribution of AMPAR currents
	syn.s_nmda = 0.2			// contribution of NMDAR currents

	// Plasticity rule
	syn.tau_G_a = 2				// time constant of synaptic event G2 (rise)
	syn.tau_G_b = 50			// time constant of synaptic event G2 (decay)
	syn.m_G = 10				// slope of the saturation function for G2
	syn.A_LTD_pre = 1.5e-3		// amplitude of pre-LTD
	syn.A_LTP_pre = 2.5e-4		// amplitude of pre-LTP
	syn.A_LTD_post = 7.5e-4		// amplitude of post-LTD
	syn.A_LTP_post = 7.8e-2		// amplitude of post-LTP
	syn.tau_u_T = 10			// time constant for filtering u to calculate T
	syn.theta_u_T = -60			// voltage threshold applied to u to calculate T
	syn.m_T = 1.7				// slope of the saturation function for T
	syn.theta_u_N = -30			// voltage threshold applied to u to calculate N
	syn.tau_Z_a = 1				// time constant of presynaptic event Z (rise)
	syn.tau_Z_b = 15			// time constant of presynaptic event Z (decay)
	syn.m_Z = 6					// slope of the saturation function for Z
	syn.tau_N_alpha = 7.5		// time constant for calculating N-alpha
	syn.tau_N_beta = 30			// time constant for calculating N-beta
	syn.m_N_alpha = 2			// slope of the saturation function for N_alpha
	syn.m_N_beta = 10			// slope of the saturation function for N_beta
	syn.theta_N_X = 0.2			// threshold for N to calculate X
	syn.theta_u_C = -68			// voltage threshold applied to u to calculate C
	syn.theta_C_minus = 15		// threshold applied to C for post-LTD (P activation)
	syn.theta_C_plus = 35		// threshold applied to C for post-LTP (K-alpha activation)
	syn.tau_K_alpha = 15		// time constant for filtering K_alpha to calculate K_alpha_bar
	syn.tau_K_gamma = 20		// time constant for filtering K_beta to calculate K_gamma
	syn.m_K_alpha = 1.5			// slope of the saturation function for K_alpha
	syn.m_K_beta = 1.7			// slope of the saturation function for K_beta
	syn.s_K_beta = 100			// scaling factor for calculation of K_beta
}

proc make_syn_basal() { local section, dist localobj syn, stim
	// Create one basal plastic synapse (Syn4P) with its own random VecStim.
	//   $1 ... index into L5PC.dend
	//   $2 ... relative location along that section
	// The synapse, its stimulus and the connecting NetCon (weight B_SYN_WEIGHT)
	// are appended to b_syn_list, b_stim_list and b_nc_list.
	section = $1
	dist = $2
	L5PC.dend[section] syn = new Syn4P(dist)	// at the specified location along the section
	b_syn_list.append(syn)
	stim = make_stim()
	b_stim_list.append(stim)
	b_nc_list.append(new NetCon(stim, syn, 0, 0, B_SYN_WEIGHT))

	// Synaptic conductance
	syn.tau_a = 0.2				// time constant of EPSP rise
	syn.tau_b = 2				// time constant of EPSP decay
	syn.e = 0					// reversal potential
	syn.w_pre_init = 0.5		// pre factor initial value
	syn.w_post_init = 2.0		// post factor initial value
	syn.s_ampa = 0.8			// contribution of AMPAR currents
	syn.s_nmda = 0.2			// contribution of NMDAR currents

	// Plasticity rule
	syn.tau_G_a = 2				// time constant of synaptic event G2 (rise)
	syn.tau_G_b = 50			// time constant of synaptic event G2 (decay)
	syn.m_G = 10				// slope of the saturation function for G2
	syn.A_LTD_pre = 1.5e-3		// amplitude of pre-LTD
	syn.A_LTP_pre = 2.5e-4		// amplitude of pre-LTP
	syn.A_LTD_post = 7.5e-4		// amplitude of post-LTD
	syn.A_LTP_post = 7.8e-2		// amplitude of post-LTP
	syn.tau_u_T = 10			// time constant for filtering u to calculate T
	syn.theta_u_T = -60			// voltage threshold applied to u to calculate T
	syn.m_T = 1.7				// slope of the saturation function for T
	syn.theta_u_N = -30			// voltage threshold applied to u to calculate N
	syn.tau_Z_a = 1				// time constant of presynaptic event Z (rise)
	syn.tau_Z_b = 15			// time constant of presynaptic event Z (decay)
	syn.m_Z = 6					// slope of the saturation function for Z
	syn.tau_N_alpha = 7.5		// time constant for calculating N-alpha
	syn.tau_N_beta = 30			// time constant for calculating N-beta
	syn.m_N_alpha = 2			// slope of the saturation function for N_alpha
	syn.m_N_beta = 10			// slope of the saturation function for N_beta
	syn.theta_N_X = 0.2			// threshold for N to calculate X
	syn.theta_u_C = -68			// voltage threshold applied to u to calculate C
	syn.theta_C_minus = 15		// threshold applied to C for post-LTD (P activation)
	syn.theta_C_plus = 35		// threshold applied to C for post-LTP (K-alpha activation)
	syn.tau_K_alpha = 15		// time constant for filtering K_alpha to calculate K_alpha_bar
	syn.tau_K_gamma = 20		// time constant for filtering K_beta to calculate K_gamma
	syn.m_K_alpha = 1.5			// slope of the saturation function for K_alpha
	syn.m_K_beta = 1.7			// slope of the saturation function for K_beta
	syn.s_K_beta = 100			// scaling factor for calculation of K_beta
}

proc make_input_apical() { local section, dist localobj syn, stim
	// Create one apical non-plastic input synapse (Exp2Syn) with its own random VecStim.
	//   $1 ... index into L5PC.apic
	//   $2 ... relative location along that section
	// The synapse, its stimulus and the connecting NetCon (weight A_INPUT_WEIGHT)
	// are appended to a_input_list, a_stim_list and a_nc_list.
	section = $1
	dist = $2
	L5PC.apic[section] syn = new Exp2Syn(dist)	// at the specified location along the section
	a_input_list.append(syn)
	stim = make_stim()
	a_stim_list.append(stim)
	a_nc_list.append(new NetCon(stim, syn, 0, 0, A_INPUT_WEIGHT))
	syn.tau1 = 0.2	// same values as tau_a / tau_b of the plastic synapses
	syn.tau2 = 2
}

proc make_input_basal() { local section, dist, index_input, index_stim
	// Create one basal non-plastic input synapse (Exp2Syn) with its own random VecStim.
	//   $1 ... index into L5PC.dend
	//   $2 ... relative location along that section
	// The synapse, its stimulus and the connecting NetCon (weight B_INPUT_WEIGHT)
	// are appended to b_input_list, b_stim_list and b_nc_list.
	// index_input is declared local (as in make_input_apical) so that it does
	// not end up as a global variable.
	section = $1
	dist = $2
	L5PC.dend[section] b_input_list.append(new Exp2Syn(dist))	// at the specified location along the section
	index_input = b_input_list.count() - 1	// position of the synapse just appended
	b_stim_list.append(make_stim())
	index_stim = b_stim_list.count() - 1	// position of the stimulus just appended
	b_nc_list.append(new NetCon(b_stim_list.o(index_stim), b_input_list.o(index_input), 0, 0, B_INPUT_WEIGHT))
	b_input_list.o(index_input).tau1 = 0.2
	b_input_list.o(index_input).tau2 = 2
}

// ## Main
runs = 0	// NOTE(review): not referenced anywhere else in the visible part of this file
a_dist = make_dist_vec(L5PC.apical) // Create cumulated distance vector of the specified sections (whole apical and basal trees)
b_dist = make_dist_vec(L5PC.basal)
// 5000 random candidate distances per group, each between 0 and the last (largest) cumulated ending point.
// The *_syn_rand_dist vectors are not read again in the visible code, but generating them draws from the
// shared generator r, so removing or reordering these calls changes all random numbers drawn afterwards.
a_syn_rand_dist = rand_vec(5000, 0, a_dist.x(a_dist.size() - 1)) // Create random distances vector within the max distance range
b_syn_rand_dist = rand_vec(5000, 0, b_dist.x(b_dist.size() - 1))
a_input_rand_dist = rand_vec(5000, 0, a_dist.x(a_dist.size() - 1))
b_input_rand_dist = rand_vec(5000, 0, b_dist.x(b_dist.size() - 1))

// Apical plastic synapses: all placed at the same spot (apic[50], relative location 0.5);
// the location is also written to row `counter` of a_syn_loc.
counter = 0
while(counter < N_SYN_APICAL) {
	a_syn_loc.x(counter, 0) = 50
	a_syn_loc.x(counter, 1) = 0.5
	make_syn_apical(50, 0.5)
	counter += 1
}

// Basal plastic synapses: all placed at the same spot (dend[39], relative location 0.2);
// the location is also written to row `counter` of b_syn_loc.
counter = 0
while(counter < N_SYN_BASAL) {
	b_syn_loc.x(counter, 0) = 39
	b_syn_loc.x(counter, 1) = 0.2
	make_syn_basal(39, 0.2)
	counter += 1
}

// Make apical input synapses: walk through the random candidate distances and
// keep those that lie at least 700 µm from the soma until N_INPUT_APICAL are placed.
counter = 0				// index of the next candidate in a_input_rand_dist
n = N_INPUT_APICAL		// number of synapses still to be placed
while(n > 0) {
	if(counter >= a_input_rand_dist.size()) {	// stop with a clear message instead of indexing past the candidates
		execerror("Not enough random candidate locations for the apical input synapses", "")
	}
	tempobj = get_rel_loc(a_dist, a_input_rand_dist.x(counter))	// (section index, relative location)
	L5PC.apic[tempobj.x(0)] temp = distance(tempobj.x(1))
	if(temp >= 700) {	// Check if distance requirements (700 µm from soma) are fulfilled; otherwise, use next random location
		make_input_apical(tempobj.x(0), tempobj.x(1))
		a_input_loc.x(N_INPUT_APICAL - n, 0) = tempobj.x(0)
		a_input_loc.x(N_INPUT_APICAL - n, 1) = tempobj.x(1)
		n -= 1
	}
	counter += 1
}

// Make basal input synapses: walk through the random candidate distances
// until N_INPUT_BASAL synapses are placed.
counter = 0				// index of the next candidate in b_input_rand_dist
n = N_INPUT_BASAL		// number of synapses still to be placed
while(n > 0) {
	if(counter >= b_input_rand_dist.size()) {	// stop with a clear message instead of indexing past the candidates
		execerror("Not enough random candidate locations for the basal input synapses", "")
	}
	tempobj = get_rel_loc(b_dist, b_input_rand_dist.x(counter))	// (section index, relative location)
	L5PC.dend[tempobj.x(0)] temp = distance(tempobj.x(1))
	if(temp >= 0) {	// Check if distance requirements are fulfilled (no requirements by default); otherwise, use next random location
		make_input_basal(tempobj.x(0), tempobj.x(1))
		b_input_loc.x(N_INPUT_BASAL - n, 0) = tempobj.x(0)
		b_input_loc.x(N_INPUT_BASAL - n, 1) = tempobj.x(1)
		n -= 1
	}
	counter += 1
}

// ## Recording
// Time, the two local membrane voltages and the weight variable w of every plastic synapse
// are recorded into Vectors. The recorded locations are the same ones used for the plastic synapses above.
trec = new Vector()
trec.record(&t)
vrec_basal = new Vector()
vrec_basal.record(&L5PC.dend[39].v(0.2))	// voltage at basal location
vrec_apic = new Vector()
vrec_apic.record(&L5PC.apic[50].v(0.5))		// voltage at apical location
for counter = 0, N_SYN_APICAL - 1 {				// apical synaptic weights
	a_syn_weights.append(new Vector())
	a_syn_weights.o(counter).record(&a_syn_list.o(counter).w)
}
for counter = 0, N_SYN_BASAL - 1 {				// basal synaptic weights
	b_syn_weights.append(new Vector())
	b_syn_weights.o(counter).record(&b_syn_list.o(counter).w)
}

// NOTE(review): if run() itself invokes init() (as the standard run system is expected to do),
// the explicit init() call below repeats the warm-up phase — confirm before removing it.
init()
run()

// ## Check for weight changes
// Collect the weight w of every plastic synapse after the run and print, for
// each side, the extreme values, the index of the first synapse holding them
// and the full list of weights.
objref weights_apical, weights_basal
weights_apical = new Vector()
counter = 0
while(counter < a_syn_list.count()) {
	weights_apical.append(a_syn_list.o(counter).w)
	counter += 1
}
weights_basal = new Vector()
counter = 0
while(counter < b_syn_list.count()) {
	weights_basal.append(b_syn_list.o(counter).w)
	counter += 1
}
print "----------------------------------"
print "Max apical: ", weights_apical.max(), ", Synapse #", weights_apical.indwhere("==",weights_apical.max())
print "Min apical: ", weights_apical.min(), ", Synapse #", weights_apical.indwhere("==",weights_apical.min())
print "Weights: ", weights_apical.printf()
print "----------------------------------"
print "Max basal: ", weights_basal.max(), ", Synapse #", weights_basal.indwhere("==",weights_basal.max())
print "Min basal: ", weights_basal.min(), ", Synapse #", weights_basal.indwhere("==",weights_basal.min())
print "Weights: ", weights_basal.printf()
print "----------------------------------"

// ## Visualization
// Three graphs: local voltages, apical weight traces and basal weight traces.
// The recorded vectors are drawn without an explicit x vector.
// NOTE(review): 10625 looks like the expected number of recorded samples for the default
// parameters; it is hard-coded and does not follow changes to the timing parameters — confirm.
objref voltplot
voltplot = new Graph(0)		// apical & basal voltage over time
voltplot.size(0, 10625, -85, 10)
voltplot.view(-0.3, -85, 10625, 105, 65, 105, 300, 200)
vrec_basal.line(voltplot, 7, 2)	// basal voltage, same color index as the basal weight traces
vrec_apic.line(voltplot, 4, 2)	// apical voltage, same color index as the apical weight traces
objref awplot, bwplot
awplot = new Graph(0)		// apical weights over time (green)
// Vertical range spans the smallest to the largest final apical weight
awplot.size(0, 10625, weights_apical.min(), weights_apical.max())
awplot.view(0, weights_apical.min(), 10625, weights_apical.max() - weights_apical.min(), 65, 105, 300, 200)
for counter = 0, a_syn_weights.count() - 1 {
	a_syn_weights.o(counter).line(awplot, 4, 2)
}
bwplot = new Graph(0)		// basal weights over time (violet)
// Vertical range spans the smallest to the largest final basal weight
bwplot.size(0, 10625, weights_basal.min(), weights_basal.max())
bwplot.view(0, weights_basal.min(), 10625, weights_basal.max() - weights_basal.min(), 65, 105, 300, 200)
for counter = 0, b_syn_weights.count() - 1 {
	b_syn_weights.o(counter).line(bwplot, 7, 2)
}