#Dainauskas JJ, Marie H, Migliore M and Saudargiene A.
#GluN2B-NMDAR subunit contribution on synaptic plasticity: a phenomenological model for CA3-CA1 synapses
#Frontiers in Synaptic Neuroscience 2023 DOI: 10.3389/fnsyn.2023.1113957.
#Model of synaptic plasticity
import numpy as np
from neuron_model import CA1_NeuronModel
from synapse_nmda import Synapse_NMDA
from synapse_ampa import Synapse_AMPA
from multiprocessing import Pool
class PylediPlasticity:
    """Phenomenological synaptic plasticity model driven by NMDA conductance traces.

    The object couples a ``CA1_NeuronModel`` with one ``Synapse_AMPA`` and one
    ``Synapse_NMDA`` and integrates, with a fixed step ``dt``, the LTP/LTD
    terms that change the synaptic ``weight``.  Every quantity named in
    ``records_list`` is stored after each call to :meth:`step`.
    """

    def __init__(self, nmdabar=1e-3, ampabar=0.5, ampabar_multiplier=1, save_last_record_only=False):
        """Build the neuron, the two synapses and the recording buffers.

        Args:
            nmdabar: factor applied to the NMDA current in :meth:`update`;
                also forwarded to ``Synapse_NMDA``.
            ampabar: base AMPA factor forwarded to ``Synapse_AMPA``.
            ampabar_multiplier: multiplies ``ampabar`` before it is forwarded.
            save_last_record_only: when true, each record keeps a single
                value that is overwritten on every step.
        """
        self.save_last_record_only = save_last_record_only
        self.set_records_list()
        self.init_parameters()
        self.init_records()
        self.nmdabar = nmdabar
        self.ampabar = ampabar * ampabar_multiplier
        self.stdp_range = list(range(-100, 110, 10))
        self.neuron = CA1_NeuronModel(self.dt, save_last_record_only=save_last_record_only)
        self.synapse_nmda = Synapse_NMDA(self.dt, nmdabar=self.nmdabar)
        self.synapse_ampa = Synapse_AMPA(self.dt, ampabar=self.ampabar)

    def init_parameters(self):
        """Set every state variable and model constant to its initial value."""
        self.t = 0.0
        self.dt = 0.1  # ms
        # Injected current amplitude during a somatic pulse / at rest
        # (both are handed to the neuron model's step()).
        self.Ip0_ampl = 20
        self.Ip0_ampl_rest = -0.5

        # Membrane parameters
        self.Vrest = -70  # mV
        self.v = self.Vrest
        # Placeholder until the first step() copies the neuron's somatic
        # voltage; guarantees every recorded attribute exists up front.
        self.v_soma = self.Vrest
        self.E_AMPA = 0
        self.E_NMDA = 0

        # AMPA
        self.g_ampa = 0.0
        self.synInputs = 0.0
        self.I_NMDA = 0.0
        self.I_AMPA = 0.0

        # NMDA
        self.g_nmda = 0.0
        self.g_nmda_LTP = 0
        self.g_nmda_LTD = 0
        self.g_nr2a = 0.0
        self.g_nr2b = 0.0

        # NMDA traces (low-pass filtered conductances, see dynamics())
        self.g_nmda_tau1 = 20
        self.g_nmda_tau2 = 1000
        self.g_nmda_trace1 = 0.0
        self.g_nmda_trace2 = 0.0
        self.k_gnmda_scaling = 1

        # Synaptic plasticity model
        self.v_trace1_threshed1 = 0.0
        self.v_trace2_threshed2 = 0.0
        self.pre_dirac = 0.0
        self.dirac_trace = 0.0
        self.A_dirac = 1
        self.hill_ltp_mid = 11e-5
        self.hill_ltd_mid = 9e-5
        self.hill_ltp_coef = 4
        self.hill_ltd_coef = 2
        # Hill outputs; recomputed by update() on every step.
        self.hilleq_ltp = 0.0
        self.hilleq_ltd = 0.0
        self.moving_threshold_hill_ltp = 0.0
        self.moving_threshold_hill_ltp_multiplier_from_ltd = 10
        self.moving_threshold_hill_ltp_tau = 100
        self.moving_threshold_hill_ltd = 0.0
        self.moving_threshold_hill_ltd_multiplier_from_ltp = 1000
        self.moving_threshold_hill_ltd_tau = 100
        self.A_ltp = 1
        self.A_ltd = 100
        self.v_trace1_tau = 10
        self.v_trace2_tau = 10
        self.v_trace_thresh1 = -65.0
        self.v_trace_thresh2 = -67.0
        self.dirac_trace_tau = 15
        self.wmin = 0.4
        self.wmax = 2.0
        self.weight = 1.0
        self.customweight = 1.0
        self.w_ltd = 0.0
        self.w_ltp = 0.0
        self.w_change = 0.0
        self.ltp_part = 0.0
        self.ltd_part = 0.0

    # -------------------------------------------------------------------------
    # Record
    # -------------------------------------------------------------------------
    def set_records_list(self):
        """Define the attribute names that are sampled after every step."""
        self.records_list = [
            "g_ampa",
            "g_nmda",
            "g_nmda_trace1",
            "g_nmda_trace2",
            "t",
            "v",
            "v_soma",
            "v_trace1_threshed1",
            "v_trace2_threshed2",
            "w_ltd",
            "w_ltp",
            "ltd_part",
            "ltp_part",
            "pre_dirac",
            "dirac_trace",
            "weight",
            "hilleq_ltp",
            "hilleq_ltd",
            "g_nr2a",
            "g_nr2b",
            "moving_threshold_hill_ltd",
            "moving_threshold_hill_ltp",
        ]

    def init_records(self):
        """Create one list per recorded name.

        In ``save_last_record_only`` mode each list starts as ``[0]`` so that
        :meth:`update_records` can overwrite its last element in place.
        """
        if self.save_last_record_only:
            self.records = {name: [0] for name in self.records_list}
        else:
            self.records = {name: [] for name in self.records_list}

    def update_records(self):
        """Sample every recorded attribute (append, or overwrite the last value)."""
        if self.save_last_record_only:
            for name in self.records_list:
                self.records[name][-1] = getattr(self, name)
        else:
            for name in self.records_list:
                self.records[name].append(getattr(self, name))

    def step(self, i, syn_input=False, soma_input=False):
        """Advance the neuron, both synapses and the plasticity model by ``dt``.

        Args:
            i: step index, passed through to the neuron and synapse models.
            syn_input: true when a presynaptic event arrives on this step.
            soma_input: true while the somatic current pulse is active.
        """
        injected = self.Ip0_ampl if soma_input else self.Ip0_ampl_rest
        # The neuron is driven with the currents computed on the previous step.
        self.neuron.step(i, self.synInputs, self.I_NMDA, injected)
        self.synapse_ampa.step(i, self.v, syn_input)
        self.synapse_nmda.step(i, self.v, syn_input)

        self.g_ampa = self.synapse_ampa.g_ampa * self.weight * self.customweight
        self.g_nmda = self.synapse_nmda.g_nmda
        self.g_nr2a = self.synapse_nmda.g_nr2a
        self.g_nr2b = self.synapse_nmda.g_nr2b
        self.v = self.neuron.Vdend
        self.v_soma = self.neuron.Vsoma

        # NMDA conductances used by the LTP and LTD pathways
        self.g_nmda_LTP = self.synapse_nmda.g_nmda_LTP
        self.g_nmda_LTD = self.synapse_nmda.g_nmda_LTD

        # The presynaptic impulse lasts for a single step only.
        self.pre_dirac = 0.0
        if syn_input:
            self.net_receive()

        self.dynamics()
        self.update()
        self.t += self.dt
        self.update_records()

    # --------------------------------------------------------------------------------
    # Synaptic plasticity model
    # --------------------------------------------------------------------------------
    def dynamics(self):
        """Euler-integrate the thresholded voltage, presynaptic and NMDA traces."""
        self.v_trace1_threshed1 += ((max(0, (self.v - self.v_trace_thresh1)) - self.v_trace1_threshed1) / self.v_trace1_tau) * self.dt
        self.v_trace2_threshed2 += ((max(0, (self.v - self.v_trace_thresh2)) - self.v_trace2_threshed2) / self.v_trace2_tau) * self.dt
        self.dirac_trace += ((self.pre_dirac - self.dirac_trace) / self.dirac_trace_tau) * self.dt

        # NMDA traces
        self.g_nmda_trace1 += (((self.g_nmda_LTP) - self.g_nmda_trace1) / self.g_nmda_tau1) * self.dt
        self.g_nmda_trace2 += (((self.g_nmda_LTD) - self.g_nmda_trace2) / self.g_nmda_tau2) * self.dt

    def net_receive(self):
        """Register a presynaptic event by setting the impulse to ``A_dirac``."""
        self.pre_dirac = self.A_dirac

    def update(self):
        """Compute synaptic currents, the LTP/LTD terms and the new weight."""
        self.I_AMPA = self.g_ampa * (self.v - self.E_AMPA)
        self.I_NMDA = self.g_nmda * (self.v - self.E_NMDA) * self.nmdabar
        self.synInputs = self.I_AMPA + self.I_NMDA

        # LTP: Hill function of the fast NMDA trace minus its moving
        # threshold, rectified at zero.
        hill_ltp = (self.k_gnmda_scaling * self.g_nmda_trace1) ** self.hill_ltp_coef
        self.hilleq_ltp = hill_ltp / ((self.hill_ltp_mid) ** self.hill_ltp_coef + hill_ltp) - self.moving_threshold_hill_ltp
        if self.hilleq_ltp < 0:
            self.hilleq_ltp = 0
        self.ltp_part = self.A_ltp * self.hilleq_ltp * self.v_trace1_threshed1

        # LTD: same construction on the slow NMDA trace, additionally gated
        # by the presynaptic trace.
        hill_ltd = (self.k_gnmda_scaling * self.g_nmda_trace2) ** self.hill_ltd_coef
        self.hilleq_ltd = hill_ltd / ((self.hill_ltd_mid) ** self.hill_ltd_coef + hill_ltd) - self.moving_threshold_hill_ltd
        if self.hilleq_ltd < 0:
            self.hilleq_ltd = 0
        self.ltd_part = self.A_ltd * self.hilleq_ltd * self.v_trace2_threshed2 * self.dirac_trace

        # LTP threshold relaxes towards the scaled LTD Hill output ...
        self.moving_threshold_hill_ltp += ((-self.moving_threshold_hill_ltp + self.hilleq_ltd * self.moving_threshold_hill_ltp_multiplier_from_ltd) / self.moving_threshold_hill_ltp_tau) * self.dt
        # ... and the LTD threshold towards the scaled LTP Hill output.
        self.moving_threshold_hill_ltd += ((-self.moving_threshold_hill_ltd + self.hilleq_ltp * self.moving_threshold_hill_ltd_multiplier_from_ltp) / self.moving_threshold_hill_ltd_tau) * self.dt

        # Soft bounds: potentiation scales with the distance to wmax,
        # depression with the distance to wmin.
        ltpmult = (self.wmax - self.weight)
        ltdmult = (self.weight - self.wmin)
        self.w_change = (self.ltp_part * ltpmult - self.ltd_part * ltdmult)
        self.w_ltp += self.ltp_part * ltpmult * self.dt
        self.w_ltd += self.ltd_part * ltdmult * self.dt
        self.weight += self.w_change * self.dt
        if self.weight < 0:
            self.weight = 0

    def get_record_data(self):
        """Return the dictionary of recorded traces (name -> list of values)."""
        return self.records

    @staticmethod
    def _as_step_set(values):
        """Convert a list/array of step indices into a set for O(1) lookups."""
        return set(np.asarray(values).ravel().tolist())

    def _run_schedule(self, n_steps, syn_steps, soma_steps, step_offset=0, soma_pulse_ms=5):
        """Step the model ``n_steps`` times following the given event schedules.

        A presynaptic event is delivered on every index in ``syn_steps``; an
        index in ``soma_steps`` (re)starts a somatic pulse lasting
        ``soma_pulse_ms`` (converted to steps with ``self.dt``).
        ``step_offset`` is added to the index passed to :meth:`step`.
        """
        syn_steps = self._as_step_set(syn_steps)
        soma_steps = self._as_step_set(soma_steps)
        pulse_steps_left = 0
        for i in range(n_steps):
            if pulse_steps_left > 0:
                pulse_steps_left -= 1
            if i in soma_steps:
                pulse_steps_left = soma_pulse_ms / self.dt
            self.step(i + step_offset, i in syn_steps, pulse_steps_left > 0)

    @staticmethod
    def run_stdp_tests_static(pairings = 2, frequency = 0.5, nmdabarmult = 1, pre_spikes = 1, post_spikes = 1, synparams = None, stdp_range=range(-100, 110, 10), first_post=False, save_last_record_only=False):
        """Run one STDP protocol per timing in ``stdp_range``, in parallel.

        One worker process is started per timing.  Returns the list of record
        dictionaries in the order of ``stdp_range``; an empty ``stdp_range``
        yields ``[]`` without starting a pool.
        """
        synparams = {} if synparams is None else synparams
        taskrange = [(pairings, frequency, stdp_dt, nmdabarmult, pre_spikes, post_spikes, synparams, first_post, save_last_record_only) for stdp_dt in stdp_range]
        if not taskrange:
            # Pool(0) raises ValueError; nothing to simulate anyway.
            return []
        with Pool(len(taskrange)) as p:
            experiments = p.map(PylediPlasticity.run_single_stdp_static, taskrange)
        return experiments

    @staticmethod
    def run_single_stdp_static(input_tuple):
        """Simulate one STDP pairing protocol and return its records.

        ``input_tuple`` is ``(pairings, frequency, stdp_dt, nmdabarmult,
        pre_spikes, post_spikes, synparams, first_post,
        save_last_record_only)``.  Each ``synparams`` entry is set as an
        attribute on the freshly built model before the run.
        """
        (pairings, frequency, stdp_dt, nmdabarmult, pre_spikes, post_spikes, synparams, first_post, save_last_record_only) = input_tuple
        # NOTE(review): nmdabarmult is unpacked but never applied to the
        # model here -- confirm whether it is meant to scale nmdabar.
        model = PylediPlasticity(save_last_record_only=save_last_record_only)
        for name, value in (synparams or {}).items():
            setattr(model, name, value)

        dt = model.dt
        pre_pair_interval = 10
        pre_pair_interval_dt = pre_pair_interval / dt
        post_pair_interval = 10
        post_pair_interval_dt = post_pair_interval / dt
        interval = int(1000/frequency)
        start = 100
        stop = interval * (pairings-1) + start + 300

        # Everything below is expressed in integration steps.
        dt_int_pair = int((interval * pairings)/dt)
        dt_start_dend = int((start + 2)/dt)
        dt_start_soma = int(start/dt)
        dt_stop = int(stop/dt)
        dt_interval = int(interval/dt)
        dt_stdp_dt = int(stdp_dt/model.dt)

        if first_post:
            pre_delay = 0
        else:
            pre_delay = post_pair_interval_dt * (post_spikes-1)

        syn_inputs = []
        for k in range(pre_spikes):
            syn_inputs += list(np.arange(dt_start_dend + pre_pair_interval_dt * k - dt_stdp_dt + pre_delay, dt_start_dend + dt_int_pair - dt_stdp_dt, dt_interval))

        soma_delta = -1.0 / dt
        soma_inputs = []
        for k in range(post_spikes):
            soma_inputs += list(np.arange(dt_start_soma + soma_delta + post_pair_interval_dt * k, dt_start_soma + soma_delta + dt_int_pair, dt_interval))

        model._run_schedule(dt_stop, syn_inputs, soma_inputs)
        return model.get_record_data()

    def run_freq_tests(self, frequency=5.0, pairings=2, soma_on=False, stdp_dt=10, start=300, afterstop=100, old_i = 0):
        """Deliver ``pairings`` presynaptic events at ``frequency`` on this model.

        Args:
            frequency: event rate; the interval is ``int(1000/frequency)``.
            pairings: number of presynaptic events.
            soma_on: when true, a somatic pulse is started ``stdp_dt - 5``
                (in the same time unit as ``start``) relative to each
                presynaptic event.
            stdp_dt: timing offset used for the somatic pulses.
            start: time of the first presynaptic event.
            afterstop: simulated time after the last event.
            old_i: offset added to the step index passed to :meth:`step`.

        Returns:
            The number of integration steps that were executed.  The records
            are also stored in ``self.run_data``.
        """
        # ``start`` and ``afterstop`` are honoured as passed (they used to be
        # overwritten by their default values).
        interval = int(1000/frequency)
        stop = interval * (pairings-1) + start + afterstop
        dt_interval = int(interval/self.dt)
        dt_start = int(start/self.dt)
        dt_stop = int(stop/self.dt)

        syn_inputs = np.arange(dt_start, dt_start + (dt_interval * pairings), dt_interval)
        soma_inputs = []
        if soma_on:
            soma_first = dt_start - int(5 / self.dt) + int(stdp_dt/self.dt)
            soma_inputs = np.arange(soma_first, soma_first + (dt_interval * pairings), dt_interval)

        self._run_schedule(dt_stop, syn_inputs, soma_inputs, step_offset=old_i)
        self.run_data = self.get_record_data()
        return dt_stop