from .synapse_builder import SynapseBuilder
from neuron import h
import os
import pickle
import time
import numpy as np
from pprint import pprint
from pathlib import Path
import pathlib
class AMPANMDASynPlastBuilder(SynapseBuilder):
    """Builder for ``AMPANMDASynPlast`` synapses.

    It creates ``h.AMPANMDASynPlast`` point processes, pushes the plasticity
    parameters from ``plasticity_model_parameters`` onto them, sets up NEURON
    ``Vector`` recordings of synapse/soma state and time, and saves/loads
    synaptic weights to/from a ``.npy`` file.
    """

    # Timing constants; presumably in ms (NEURON's default time unit) --
    # TODO confirm against the code that reads them.
    epsp_start = 200
    epsp_end = 10000
    after_epsp_wait = 200
    stimulation_start = 1500
    # Evaluated once, when the class body executes (i.e. at import time),
    # not per instance.
    root_path = os.getcwd()

    # State variables recorded per synapse by set_up_synapse_records().
    # "v" is special-cased: it is read from the synapse's segment.
    _BASIC_SYNAPSE_RECORD_VARS = ("i", "v", "g_nmda", "weight")
    # State variables recorded per synapse by set_up_synapse_records_full().
    _FULL_SYNAPSE_RECORD_VARS = (
        "i",
        "v",
        "g_nmda",
        "g_nr2a",
        "g_nr2b",
        "i_nmda",
        "i_ampa",
        "v_threshed1",
        "v_threshed2",
        "g_nmda_trace_ltp",
        "g_nmda_trace_ltd",
        "w_ltp",
        "w_ltd",
        "weight",
    )

    def __init__(self):
        super().__init__()

    def init_object_vars(self):
        """Initialise per-instance bookkeeping lists and model/protocol settings.

        NOTE(review): presumably invoked from SynapseBuilder.__init__ -- confirm.
        """
        super().init_object_vars()
        self.democracy = False
        self.records_synapse = []
        self.records_ec_synapse = []
        self.stimulations = []
        self.clamp_connections = []
        self.network_connections = []
        self.plastic_synapses = []
        self.nonplastic_synapses = []
        self.artificial_cells = []
        self.gaba_synapses = []
        self.weight_factor = 1
        self.plasticity_model_parameters = {}
        # Multipliers consumed by default_parameters().
        self.nr2b_multiplier = 1
        self.gamma_pairings = 3
        self.theta_pairings = 20
        self.gamma_freq = 80
        self.theta_freq = 4
        self.gbarmult = 1.5
        self.nr2b_cure = 1
        self.A_pm_mult = 1
        self.gaba_weight = 0.01
        self.ampamult = 1
        self.nmdamult = 1
        self.nr2amult = 1
        self.nr2bmult = 1
        self.wmax = 2.5
        self.wmin = 0.2
        self.alpha_multiplier = 1
        self.ca3_starttime = 0  # 11000
        self.ca3_endtime = 612000  # 21000
        # Stimulation-protocol settings.
        self.ltp_pairings = 100
        self.ltp_freq = 100
        self.ltp_interburst_interval = 1000
        self.ltp_bursts = 1
        self.ltd_pairings = 500
        self.ltd_freq = 1
        self.stdp_pairings = 60
        self.stdp_freq = 5
        self.gaba_enabled = False
        self.plastmodifier = 1
        self.initial_weight = 1
        self.stdp_synweight = 0.5
        self.stdp_somaweight = 0.1
        self.stdp_deltat = 10
        self.epsp_measurement_enabled = True
        self.presynaptic_enabled = True
        # When True, set_up_records() records every synapse state variable.
        self.full_records = True

    def init_object_functions(self):
        """Declare the hoc object references used by this builder and build
        the default plasticity parameter set."""
        super().init_object_functions()
        h("objref vs")
        h("objref vs_ec")
        h("objref vs_gaba")
        self.default_parameters()

    def default_parameters(self):
        """(Re)build ``self.plasticity_model_parameters`` from the multipliers.

        The values are computed when this method runs. Changing e.g.
        ``gbarmult`` afterwards has no effect until this is called again.
        Every key is set as an attribute on each synapse created by
        add_plasticity_synapse().

        NOTE(review): ``self.A_mult`` and ``self.ltp_midpoint_multiplier`` are
        read here but never assigned in this class. Presumably the parent
        class or the caller sets them before this runs; otherwise this raises
        AttributeError -- confirm.
        """
        self.plasticity_model_parameters = {
            "gAMPAbar": 1.6e-4 * self.gbarmult * self.ampamult,
            "gNMDAbar": 3.0e-7 * self.gbarmult * self.nmdamult * 20.0,
            "gNR2Abar": 1 * self.nr2amult,
            "gNR2Bbar": 1 * self.nr2b_multiplier * self.nr2bmult,
            "Alpha_ampa": 1.1 * self.alpha_multiplier,
            "Alpha_nr2a": 0.5 * self.alpha_multiplier,
            "Alpha_nr2b": 0.1 * self.alpha_multiplier,
            "Beta_ampa": 0.19,
            "Beta_nr2a": 0.024,
            "Beta_nr2b": 0.0075,
            "AMPA_tau1": 0.5,
            "AMPA_tau2": 3,
            "A_ltp": 1.0e-4 * self.A_pm_mult * self.A_mult * 8.0e0,
            "A_ltd": 4.0e-3 * self.A_pm_mult * self.A_mult * 8.0e0,
            "v_thresh1": -63,
            "v_thresh2": -63,
            "v_trace1_thresh2": 0.2,
            "v_trace1_tau": 10,
            "v_trace2_tau": 10,
            "g_nmda_trace_ltp_tau": 100,
            "g_nmda_trace_ltd_tau": 3,
            "X_trace_tau": 15,
            "wmin": self.wmin,
            "wmax": self.wmax,
            "hill_midpoint_ltp": 1.7e-2,  # * self.ltp_midpoint_multiplier,
            "hill_midpoint_ltp2": self.ltp_midpoint_multiplier,  # 9e-2,
            "hill_coef_ltp": 4,
            "hill_midpoint_ltd": 3.3e-2,
            "hill_coef_ltd": 2,
            "X_max": 1,
            "moving_threshold_hill_ltp_multiplier": 1,
            "moving_threshold_hill_ltp_tau": 100.0,
            "moving_threshold_hill_ltp2_multiplier": 1,
            "moving_threshold_hill_ltp2_tau": 100.0,
            "moving_threshold_hill_ltd_multiplier": 1,
            "moving_threshold_hill_ltd_tau": 100.0,
        }

    def set_up_records(self, synapses, soma):
        """Start NEURON Vector recordings of synapses, soma voltage and time.

        If ``self.full_records`` is true, all synapse variables in
        ``_FULL_SYNAPSE_RECORD_VARS`` are recorded, plus ``cai`` at
        ``self.monitor_dendrite`` into ``self.records_dendrite``. Otherwise only
        ``_BASIC_SYNAPSE_RECORD_VARS`` are recorded. Per-synapse record dicts
        are appended to ``self.records_synapse``. Soma voltage goes to
        ``self.records_soma["v"]`` and simulation time to ``self.records_t``.

        Args:
            synapses: iterable of synapse point processes to record.
            soma: a segment (anything exposing ``_ref_v``) for the soma.
        """
        if self.full_records:
            # monitor_dendrite is not set in this class -- presumably
            # provided by the parent/caller.
            self.records_dendrite = {"cai": h.Vector()}
            self.records_dendrite["cai"].record(self.monitor_dendrite._ref_cai)
            self.set_up_synapse_records_full(synapses, self.records_synapse)
        else:
            self.set_up_synapse_records(synapses, self.records_synapse)
        self.records_soma = {"v": h.Vector()}
        self.records_soma["v"].record(soma._ref_v)
        self.records_t = h.Vector()
        self.records_t.record(h._ref_t)

    def set_up_synapse_records(self, synapses, records_synapse):
        """Append one dict of Vectors per synapse recording i, v, g_nmda, weight."""
        self._append_synapse_records(
            synapses, records_synapse, self._BASIC_SYNAPSE_RECORD_VARS
        )

    def set_up_synapse_records_full(self, synapses, records_synapse):
        """Append one dict of Vectors per synapse recording every variable in
        ``_FULL_SYNAPSE_RECORD_VARS``."""
        self._append_synapse_records(
            synapses, records_synapse, self._FULL_SYNAPSE_RECORD_VARS
        )

    def _append_synapse_records(self, synapses, records_synapse, var_names):
        """Append to *records_synapse* one ``{name: h.Vector}`` dict per synapse
        and start recording each named variable into its Vector."""
        for synapse in synapses:
            record = {name: h.Vector() for name in var_names}
            records_synapse.append(record)
            for name, vector in record.items():
                if name == "v":
                    # Membrane voltage belongs to the segment the synapse sits
                    # on, not to the point process itself.
                    vector.record(synapse.get_segment()._ref_v)
                else:
                    vector.record(getattr(synapse, "_ref_" + name))

    @staticmethod
    def _debug_enabled():
        """Return True when the MILEDIDEBUG environment variable equals '1'."""
        # .get() so an unset variable means "debug off" instead of KeyError.
        return os.environ.get("MILEDIDEBUG") == "1"

    def add_plasticity_synapse(self, dendrite):
        """Create an ``h.AMPANMDASynPlast`` on *dendrite* via ``self.add_synapse``
        and set every entry of ``plasticity_model_parameters`` on it.

        Returns:
            The newly created synapse.
        """
        if self._debug_enabled():
            print("adding nmda plasticity syn")
            print(dendrite)
        synapse = self.add_synapse(dendrite, h.AMPANMDASynPlast)
        for name, value in self.plasticity_model_parameters.items():
            setattr(synapse, name, value)
        return synapse

    # ========================================================
    # ===================== =Analysis= =======================
    # ========================================================
    def get_weights(self):
        """Return the last recorded ``weight`` sample of each entry in
        ``self.records_synapse``, in order.

        Raises:
            IndexError: if a weight Vector has no samples yet (e.g. before
                the simulation has run).
        """
        return [list(record["weight"])[-1] for record in self.records_synapse]

    def save_weights(self, path="weights/weights.npy"):
        """Save get_weights() to *path* with ``np.save``, creating the parent
        directory if needed.

        Args:
            path: destination file; defaults to ``weights/weights.npy``
                relative to the current working directory.
        """
        weights = self.get_weights()
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            np.save(f, weights)

    def mult_weights(self, multi=100):
        """Multiply ``w`` of every plastic synapse by *multi*, with the result
        floored at -0.5."""
        for synapse in self.plastic_synapses:
            synapse.w = max(-0.5, synapse.w * multi)

    def load_weights(self, path="weights/weights.npy"):
        """Assign weights loaded from *path* to ``w`` of each plastic synapse.

        The i-th loaded value goes to ``self.plastic_synapses[i]``. Extra
        values are ignored; too few values raise IndexError.

        NOTE(review): save_weights() reads from ``records_synapse`` while this
        writes to ``plastic_synapses``. The two lists must be in the same
        order -- confirm.
        NOTE(review): ``allow_pickle=True`` lets np.load unpickle object
        arrays. The file written by save_weights() is a plain float array and
        does not need it. Avoid loading untrusted files.
        """
        weights = np.load(path, allow_pickle=True).tolist()
        debug = self._debug_enabled()
        for i, synapse in enumerate(self.plastic_synapses):
            if debug:
                pprint(synapse)
            synapse.w = weights[i]