/*
Stimulus protocol as used by Letzkus et al. (2006)
Demonstrating the synaptic plasticity rule by Ebner et al. (2019)
https://doi.org/10.1016/j.celrep.2019.11.068
Made to work with the Hay et al. (2011) L5b pyramidal cell model
*/

load_file("nrngui.hoc")

// ## Object references
// Top-level handles that are assigned later in this file.
objref times, loclist, weight_changes	// timing offsets, synapse locations, resulting weight changes
objref loc_prox, loc_dist				// SectionRefs of the proximal and distal synapse sites
objref trec, vrec_prox, vrec_dist		// recorded time and dendritic voltages
objref syn, stim, nc, ic				// synapse, presynaptic NetStim, NetCon, somatic IClamp

// ## Parameters
// Global protocol settings read by init(), current_trace() and stimulation().
REPS = 3			// Number of postsynaptic spikes (current pulses) per protocol run
DT = 0.025 			// ms, integration step after the warm-up; also the sample interval of the played current trace
AMP = 5.5 			// nA, amplitude of current injection to trigger postsynaptic spikes
DUR = 2.0			// ms, duration of the current injection
WARM_UP = 2300 		// ms, delay from beginning of simulation; the presynaptic NetStim fires at this time
COOL_DOWN = 100		// ms, silence phase after stimulation (added to tstop)
WEIGHT = 0.0035		// µS, conductance of (single) synaptic potentials; used as the NetCon weight
FREQ = 200			// Hz, frequency of postsynaptic spikes

// ## Create cell
// Loads the model definition files and instantiates one cell from the
// morphology file. The statements are order-dependent: the template file
// must be loaded before L5PCtemplate can be instantiated.
load_file("import3d.hoc")	// presumably NEURON's Import3d tools used to read the .asc morphology -- confirm
objref L5PC
strdef morphology_file
morphology_file = "morphologies/cell1.asc"	// path is relative; assumes the script is started from the model directory -- TODO confirm
load_file("models/L5PCbiophys4.hoc")	// presumably the biophysics definition used by the template -- confirm
load_file("models/L5PCtemplate.hoc")	// expected to define L5PCtemplate, instantiated below
L5PC = new L5PCtemplate(morphology_file)	// the cell instance used by the rest of this script

// Spike-timing offsets (delta t, in ms) handed to stimulation(); one protocol run per entry
times = new Vector()
times.append(10, -10)

// Synapse locations: each SectionRef is created with its section as the
// currently accessed one and is then stored in loclist (proximal first).
loclist = new List()
L5PC.apic[1] loc_prox = new SectionRef()		// proximal location; distance from soma (center point): 90 µm
loclist.append(loc_prox)
L5PC.apic[50] loc_dist = new SectionRef()		// distal location; distance from soma (center point): 669 µm
loclist.append(loc_dist)

// ## Custom init procedure
// Replaces the init() of the standard run system. Initializes all voltages
// to -70 mV, advances through most of the warm-up phase with coarse steps and
// then restores the fine integration step DT for the rest of the simulation.
// Simulation time is NOT reset afterwards: t keeps counting from the warm-up,
// which is why the stimuli are scheduled relative to WARM_UP.
//
// The number of coarse steps is derived from WARM_UP so the coarse phase
// always ends fine_margin ms before the first stimulus; for WARM_UP = 2300
// this gives the 224 steps (t = 2240 ms) that were previously hard-coded.
// The loop counter is local so that calling init() (or run(), which calls
// it) does not overwrite a global loop variable of the caller.
proc init() { local istep, coarse_dt, fine_margin, n_coarse
	coarse_dt = 10		// ms, step size during the warm-up phase
	fine_margin = 60	// ms of warm-up that are still simulated with DT
	n_coarse = int((WARM_UP - fine_margin) / coarse_dt)

	finitialize(-70)
	dt = coarse_dt
	for istep = 1, n_coarse fadvance()	// Run warm-up phase using big steps
	print t
	dt = DT		// back to the fine step for the stimulation phase
}

// ## Current trace vector
// Builds the somatic current waveform as a Vector sampled every DT ms.
// Arguments:
//   $1 - frequency of the current pulses (Hz)
//   $2 - offset of the first pulse relative to WARM_UP (ms, may be negative)
//   $3 - total length of the trace (ms)
// Returns a Vector that is 0 everywhere except for REPS pulses of amplitude
// AMP; pulse k starts at WARM_UP + $2 + k * (1000 / $1) and is filled up to
// the sample index that corresponds to start + DUR.
obfunc current_trace() { local rate_hz, shift_ms, len_ms, isi_ms, pulse, t_on, t_off localobj wave
	rate_hz = $1
	shift_ms = $2
	len_ms = $3
	isi_ms = 1000.0 / rate_hz	// interval between successive pulses (ms)

	wave = new Vector(len_ms/DT, 0)
	pulse = 0
	while (pulse <= REPS - 1) {
		t_on = (shift_ms + pulse * isi_ms + WARM_UP)
		t_off = (t_on + DUR)	// this is the actual duration of the current injection
		wave.fill(AMP, t_on/DT, t_off/DT)	// times are converted to sample indices
		pulse = pulse + 1
	}
	return wave
}

// ## Main stimulation protocol; requires timing and location as input
// Runs one pairing of a single presynaptic event with REPS somatic current
// pulses and reports the resulting synaptic weight change.
// Arguments:
//   $1  - delta_t (ms): offset of the first current pulse relative to the
//         presynaptic event at WARM_UP (positive = pulse after the event)
//   $o2 - SectionRef of the section that receives the synapse (at 0.5)
// Returns the weight change in percent, extrapolated linearly from this
// single pairing to 100 repetitions with the pre and post factors clamped.
// Side effects: replaces the global objects syn, ic, stim, nc, trec,
// vrec_prox and vrec_dist, sets tstop, assigns the global scalars total_time,
// w_init, delta_w_pre, delta_w_post, w_final_pre, w_final_post, w_final and
// outcome, and prints a summary.
func stimulation() { local delta_t localobj syn_loc, ic_trace
	delta_t = $1	// timing
	syn_loc = $o2	// location

	// ## Create synapse
	syn_loc.sec syn = new Syn4P(0.5) 	// at the specified branch
	syn.tau_a = 0.2 					// time constant of EPSP rise
	syn.tau_b = 2 						// time constant of EPSP decay
	syn.e = 0							// reversal potential
	syn.w_pre_init = 0.5				// pre factor initial value
	syn.w_post_init = 2.0				// post factor initial value
	syn.s_ampa = 0.5					// contribution of AMPAR currents
	syn.s_nmda = 0.5					// contribution of NMDAR currents
	syn.tau_G_a = 2 					// time constant of presynaptic event G (rise)
	syn.tau_G_b = 50 					// time constant of presynaptic event G (decay)
	syn.m_G = 10						// slope of the saturation function for G
	syn.A_LTD_pre = 1.5e-3				// amplitude of pre-LTD
	syn.A_LTP_pre = 2.5e-4				// amplitude of pre-LTP
	syn.A_LTD_post = 7.5e-4				// amplitude of post-LTD
	syn.A_LTP_post = 7.8e-2				// amplitude of post-LTP
	syn.tau_u_T = 10 					// time constant for filtering u to calculate T
	syn.theta_u_T = -60					// voltage threshold applied to u to calculate T
	syn.m_T = 1.7						// slope of the saturation function for T
	syn.theta_u_N = -30					// voltage threshold applied to u to calculate N
	syn.tau_Z_a = 1						// time constant of presynaptic event Z (rise)
	syn.tau_Z_b = 15 					// time constant of presynaptic event Z (decay)
	syn.m_Z = 6							// slope of the saturation function for Z
	syn.tau_N_alpha = 7.5 				// time constant for calculating N-alpha
	syn.tau_N_beta = 30					// time constant for calculating N-beta
	syn.m_N_alpha = 2					// slope of the saturation function for N_alpha
	syn.m_N_beta = 10					// slope of the saturation function for N_beta
	syn.theta_N_X = 0.2					// threshold for N to calculate X
	syn.theta_u_C = -68					// voltage threshold applied to u to calculate C
	syn.theta_C_minus = 15				// threshold applied to C for post-LTD (P activation)
	syn.theta_C_plus = 35				// threshold applied to C for post-LTP (K-alpha activation)
	syn.tau_K_alpha = 15 				// time constant for filtering K_alpha to calculate K_alpha_bar
	syn.tau_K_gamma = 20 				// time constant for filtering K_beta to calculate K_gamma
	syn.m_K_alpha = 1.5					// slope of the saturation function for K_alpha
	syn.m_K_beta = 1.7					// slope of the saturation function for K_beta
	syn.s_K_beta = 100					// scaling factor for calculation of K_beta

	// Create objects
	// Somatic current clamp: permanently "on"; its amplitude is driven by the played trace below
	L5PC.soma ic = new IClamp(0.5)
	ic.del = 0
	ic.dur = 1e9 // not the actual duration
	// Single, noise-free presynaptic event at the end of the warm-up phase
	stim = new NetStim()
	stim.number = 1
	stim.start = WARM_UP
	stim.noise = 0
	nc = new NetCon(stim, syn, 0, 0, WEIGHT)	// threshold 0, delay 0, weight WEIGHT

	// Simulate the warm-up, the REPS pulse periods and the silent phase
	total_time = WARM_UP + REPS * (1000.0 / FREQ) + COOL_DOWN
	tstop = total_time

	ic_trace = current_trace(FREQ, delta_t, total_time)
	ic_trace.play(&ic.amp, DT) // play the trace vector into the IClamp protocol

	// Recording
	trec = new Vector()
	trec.record(&t)
	vrec_prox = new Vector()
	vrec_prox.record(&loc_prox.sec.v(0.5))	// voltage at proximal location
	vrec_dist = new Vector()
	vrec_dist.record(&loc_dist.sec.v(0.5))	// voltage at distal location

	// NOTE(review): the standard run() is expected to call init() itself
	// (via stdinit()), so this explicit call looks redundant -- it is kept
	// so that the printed output stays unchanged; confirm before removing.
	init()
	run()

	// Extrapolate the change caused by one pairing to 100 pairings
	w_init = syn.w_pre_init * syn.w_post_init
	delta_w_pre = syn.w_pre - syn.w_pre_init	// Difference of present and initial values
	delta_w_post = syn.w_post - syn.w_post_init
	w_final_pre = syn.w_pre_init + delta_w_pre * 100			// 100 repetitions
	w_final_post = syn.w_post_init + delta_w_post * 100
	// Clamp the extrapolated factors: pre factor to [0, 1], post factor to [0, 5]
	if (w_final_pre > 1.0) { w_final_pre = 1.0 }
	if (w_final_pre < 0.0) { w_final_pre = 0.0 }
	if (w_final_post > 5.0) { w_final_post = 5.0 }
	if (w_final_post < 0.0) { w_final_post = 0.0 }
	w_final = w_final_pre * w_final_post	// total weight is the product of both factors
	outcome = (w_final - w_init) / w_init * 100	// relative change in percent

	print "dw(pre)  = ", delta_w_pre, "--> ", (w_final_pre - syn.w_pre_init) / syn.w_pre_init * 100, "%"
	print "dw(post) = ", delta_w_post, "--> ", (w_final_post - syn.w_post_init) / syn.w_post_init * 100, "%"
	print "w(start) = ", w_init
	print "w(end)   = ", w_final
	print "Weight change --> ", outcome, "%"
	return outcome
}

// ## Main run procedure
// Runs stimulation() once for every (location, timing) combination and
// collects the returned weight changes, location-major, in weight_changes.
weight_changes = new Vector()
n_times = times.size()		// number of different timings
n_locs = loclist.count() 	// number of different locations
runs = 1					// 1-based counter of the current protocol run
for i_locs = 0, n_locs - 1 {
	for i_times = 0, n_times - 1 {
		print "\n*** Starting interval ", runs, "of ", (n_locs * n_times), "***\n// Section = "
		loclist.object(i_locs).sec print secname()
		weight_changes.append(stimulation(times.x[i_times], loclist.object(i_locs)))
		runs += 1
	}
}

// ## Comparison Graph
// Two windows: the dendritic voltage traces of the last protocol run, and
// the simulated weight changes next to the experimental reference values.
objref voltplot
voltplot = new Graph(0)		// proximal & distal voltage over time (final run only)
voltplot.size(0, 7225, -85, 10)
voltplot.view(-0.3, -85, 7225, 105, 65, 105, 300, 200)
// NOTE(review): no x-vector is passed to line(), so the traces are presumably
// drawn against sample index rather than time (trec is recorded but not used
// here) -- confirm.
vrec_prox.line(voltplot, 1, 2)
vrec_dist.line(voltplot, 3, 2)
objref compplot, compvec
compvec = new Vector()
// presumably ordered like weight_changes (proximal +10, proximal -10,
// distal +10, distal -10), values in percent -- verify against the paper
compvec.append(30, -11, -21, 43)	// Experimental data (Letzkus et al., 2006)
compplot = new Graph(0)
compplot.size(0, 3, -30, 50)
compplot.view(-0.3, -30, 3.3, 80, 65, 105, 300, 200)
compvec.line(compplot, 9, 2)		// experimental reference
weight_changes.line(compplot, 2, 2)	// simulated weight changes from the main run loop