from neuron import h
from pprint import pprint
from pathlib import Path
import numpy as np
import pickle
import time
import pathlib
import os
class SynapseBuilder():
    """Build NEURON synapses, stimulations and recordings for plasticity protocols.

    The builder keeps Python references to every NEURON object it creates
    (``NetStim``, ``NetCon``, synapse point processes, artificial cells) in
    instance lists, assembles the ratio / LTP / LTD stimulation protocols and
    offers post-run analysis helpers (EPSP change, spike counts, debug dumps).

    Debug printing is controlled by two environment variables, each enabled
    when set to ``"1"``: ``MILEDIDEBUG`` and ``NRN_DEBUG``.
    """

    # Class-level timing constants; the protocol builders read them via self.
    # Times are passed straight to NetStim.start (presumably ms, NEURON's
    # time unit -- confirm).
    epsp_start = 200
    epsp_end = 2000
    after_epsp_wait = 200
    stimulation_start = 500
    # Working directory captured once, when the class body is executed.
    root_path = os.getcwd()

    def __init__(self):
        """Initialise all instance attributes and the hoc-side state."""
        self.init_object_vars()
        self.init_object_functions()

    @staticmethod
    def _debug_enabled(variable="MILEDIDEBUG"):
        """Return True when the environment variable *variable* equals "1".

        Uses ``os.environ.get`` so that an unset variable simply disables
        debug output instead of raising ``KeyError``.
        """
        return os.environ.get(variable) == "1"

    def init_object_vars(self):
        """Set every instance attribute to its default value."""
        self.democracy = False
        # Containers holding references to the objects created by the
        # builder methods below.
        self.records_synapse = []
        self.stimulations = []
        self.clamp_connections = []
        self.network_connections = []
        self.plastic_synapses = []
        self.nonplastic_synapses = []
        self.artificial_cells = []
        self.gaba_synapses = []
        # Every NetCon weight is multiplied by this factor
        # (see connect_synapses_stimulation).
        self.weight_factor = 1
        self.plasticity_model_parameters = {}
        self.nr2b_multiplier = 1
        self.gamma_pairings = 5
        self.theta_pairings = 20
        self.gamma_freq = 80
        self.theta_freq = 4
        self.gbarmult = 1
        self.nr2b_cure = 1
        self.A_pm_mult = 10
        self.A_mult = 1
        self.gaba_weight = 1.0
        self.ampamult = 1
        self.nmdamult = 1
        self.nr2amult = 1
        self.nr2bmult = 1
        self.wmax = 0.5
        self.wmin = -0.5
        self.ca3_starttime = 0  # 11000
        self.ca3_endtime = 612000  # 21000
        self.ltp_pairings = 100
        self.ltp_freq = 100
        self.ltp_interburst_interval = 1000
        self.ltp_bursts = 1
        self.ltd_pairings = 100
        self.ltd_freq = 1
        self.stdp_pairings = 30
        self.stdp_freq = 5
        self.gaba_enabled = False
        self.plastmodifier = 1
        self.initial_weight = 1
        self.ltp_midpoint_multiplier = 1
        self.stdp_synweight = 0.5
        self.stdp_somaweight = 0.1
        self.stdp_deltat = 10
        self.epsp_measurement_enabled = False
        self.presynaptic_enabled = True

    def __del__(self):
        """Drop all held object references when the builder is collected."""
        self.reset_object_vars()

    def reset_object_vars(self):
        """Release every stored stimulation, connection, synapse and cell.

        Rebinding the lists is what drops the references; the previous
        per-item ``item = None`` loops only rebound the loop variable and had
        no effect, so they were removed. ``records_synapse`` is intentionally
        left untouched, as before.
        """
        self.clamp_connections = []
        self.stimulations = []
        self.network_connections = []
        self.plastic_synapses = []
        self.nonplastic_synapses = []
        self.artificial_cells = []
        self.gaba_synapses = []

    def init_object_functions(self):
        """Declare the hoc ``vs`` objref and (re)initialise the parameters."""
        h("objref vs")
        self.default_parameters()

    def set_parameters(self, kwargs):
        """Assign each ``name: value`` pair of *kwargs* as an instance attribute.

        Args:
            kwargs: mapping of attribute name to new value.

        ``default_parameters()`` is re-run afterwards.
        """
        debug = self._debug_enabled()
        if debug:
            print("Setting synapse parameters")
        for kwarg in kwargs:
            if debug:
                print(f"- [{kwarg}]: {kwargs[kwarg]}")
            setattr(self, kwarg, kwargs[kwarg])
        self.default_parameters()

    def set_plasticity_model_parameters(self, kwargs):
        """Reset the plasticity parameters, then apply the overrides in *kwargs*.

        Args:
            kwargs: mapping of synapse attribute name to value; these are the
                entries later copied onto each synapse by
                ``add_plasticity_synapse``.
        """
        self.default_parameters()
        for kwarg in kwargs:
            self.plasticity_model_parameters[kwarg] = kwargs[kwarg]

    def default_parameters(self):
        """Recompute the plasticity parameter table, then clear it.

        NOTE(review): the table built below is immediately replaced by an
        empty dict, so after this call ``plasticity_model_parameters`` is
        always ``{}`` and only values supplied through
        ``set_plasticity_model_parameters`` ever reach a synapse. This is
        preserved as-is because removing the reset would change which
        attributes are written onto the synapses -- confirm whether the reset
        is intentional.
        """
        self.plasticity_model_parameters = {
            "gAMPAbar": 1.27e-5 * self.gbarmult * self.ampamult,
            "gNMDAbar": 1.0e-5 * self.gbarmult * self.nmdamult,
            "gNR2Abar": 1 * self.nr2amult,
            "gNR2Bbar": 1 * self.nr2b_multiplier * self.nr2bmult,
            "Alpha_nr2a": 0.5,
            "Alpha_nr2b": 0.1,
            "Beta_nr2a": 0.024,
            "Beta_nr2b": 0.0075,
            "tau1": 0.1,
            "tau2": 3,
            "A_p": 4e-5 * self.A_pm_mult * self.A_mult,
            "A_m": 5e-5 * self.A_pm_mult * self.A_mult,
            "tetam": -69,
            "tetap": -64,
            "tau_0": 10,
            "tau_y": 110,
            "nmt_tau": 10,
            "X_a": -5,
            "X_b": 5000,
            "nmda_thresh_p": 1e-7,
            "nmda_thresh_d": 1e-7,
            "gnmda_tanh_b": 1e4,
            "wmin": self.wmin,
            "wmax": self.wmax,
            "gbar": 0,
            "gbarmid": (self.initial_weight - 1)/18,
            "plastmodifier": self.plastmodifier,
        }
        self.plasticity_model_parameters = {}

    @staticmethod
    def aicd_mult(multi, aicd):
        """Interpolate linearly between 1 (``aicd == 0``) and *multi* (``aicd == 1``).

        Declared as a static method so it can be called on the class (as
        ``set_alzheimers`` does) and on instances alike.
        """
        return 1 + (multi-1) * aicd

    def set_alzheimers(self, aicd = 0):
        """Set ``nr2b_multiplier`` to ``aicd_mult(4, aicd)`` and refresh parameters.

        Args:
            aicd: interpolation factor; 0 leaves the multiplier at 1,
                1 raises it to 4.
        """
        self.nr2b_multiplier = SynapseBuilder.aicd_mult(4, aicd)
        if self._debug_enabled():
            print(f"NR2B Multiplier: {self.nr2b_multiplier}")
        self.default_parameters()

    # ========================================================
    # ===================== =Builders= =======================
    # ========================================================
    def add_network_connection(self, stimulation, synapse, w):
        """Connect *stimulation* to *synapse* with a NetCon of weight *w*.

        The NetCon is created with threshold 0 and delay 0 and is stored in
        ``network_connections``. Nothing is returned.
        """
        if self._debug_enabled("NRN_DEBUG"):
            print(f"Adding connection: {synapse}")
        netcon = h.NetCon(stimulation, synapse, 0, 0, w)
        self.network_connections.append(netcon)

    def add_stimulation(self, pairings, freq, start, noise=0):
        """Create, store and return a NetStim spike train.

        Args:
            pairings: number of spikes (``NetStim.number``).
            freq: spike frequency; the interval is ``1000.0 / freq``, so a
                frequency of 0 raises ``ZeroDivisionError``.
            start: onset time (``NetStim.start``).
            noise: ``NetStim.noise`` value, 0 by default.

        Returns:
            The new ``h.NetStim`` object.
        """
        stimulation = h.NetStim()
        stimulation.number = pairings
        stimulation.interval = 1000.0 / freq
        stimulation.start = start
        stimulation.noise = noise
        if self._debug_enabled("NRN_DEBUG"):
            print(f"Adding stimulation: {pairings} pair * {stimulation.interval} interval ({freq} freq). Start at: {start}")
        self.stimulations.append(stimulation)
        return stimulation

    # ========================================================
    # ===================== =Synapses= =======================
    # ========================================================
    def add_nonplastic_synapse(self, dendrite, plasticity_model):
        """Create a synapse via ``plasticity_model(dendrite)`` and track it as non-plastic."""
        synapse = plasticity_model(dendrite)
        self.nonplastic_synapses.append(synapse)
        return synapse

    def add_synapse(self, dendrite, plasticity_model):
        """Create a synapse via ``plasticity_model(dendrite)`` and track it as plastic."""
        synapse = plasticity_model(dendrite)
        self.plastic_synapses.append(synapse)
        return synapse

    def add_gaba_synapse(self, dendrite):
        """Add an Exp2Syn with tau1=35, tau2=100, e=-75 and track it as a GABA synapse."""
        synapse = self.add_nonplastic_synapse(dendrite, lambda x: h.Exp2Syn(0.5, x))
        self.gaba_synapses.append(synapse)
        synapse.tau1 = 35
        synapse.tau2 = 100
        synapse.e = -75
        return synapse

    def add_ampa_exp2syn_synapse(self, dendrite):
        """Add an Exp2Syn with tau1=0.1, tau2=3, e=0 (tracked as non-plastic)."""
        synapse = self.add_nonplastic_synapse(dendrite, lambda x: h.Exp2Syn(0.5, x))
        synapse.tau1 = 0.1
        synapse.tau2 = 3
        synapse.e = 0
        return synapse

    def add_artificial_gaba(self, stim_in, somagaba, w):
        """Relay *stim_in* through a new IntFire4 cell onto every synapse in *somagaba*.

        The input connection uses a fixed weight of 0.4; the connections from
        the artificial cell to the synapses use weight *w*.
        """
        artificial_cell = h.IntFire4()
        self.artificial_cells.append(artificial_cell)
        self.add_network_connection(stim_in, artificial_cell, 0.4)
        for somagaba_syn in somagaba:
            self.add_network_connection(artificial_cell, somagaba_syn, w)

    def add_plasticity_synapse(self, dendrite):
        """Add an ``h.NMDASyn`` on *dendrite* and apply ``plasticity_model_parameters`` to it."""
        if self._debug_enabled():
            print("adding nmda syn")
        synapse = self.add_synapse(dendrite, h.NMDASyn)
        for kwarg in self.plasticity_model_parameters:
            setattr(synapse, kwarg, self.plasticity_model_parameters[kwarg])
        return synapse

    def add_AMPA_democracy(self, synapses, weights):
        """Scale each synapse's ``gAMPAbar`` by the matching entry of *weights*.

        Does nothing unless ``self.democracy`` is truthy.
        """
        if not self.democracy:
            return
        for (i, synapse) in enumerate(synapses):
            synapse.gAMPAbar = synapse.gAMPAbar * weights[i]

    def build_synapses(self, dendrites, synapse_function):
        """Return ``[synapse_function(d) for d in dendrites]``."""
        return [synapse_function(dendrite) for dendrite in dendrites]

    def connect_iclamp(self, dendrites, current):
        """Call ``add_iclamp_connection(dendrite, current)`` for every dendrite.

        NOTE(review): ``add_iclamp_connection`` is not defined in this class;
        presumably a subclass or mixin provides it -- confirm.
        """
        for dendrite in dendrites:
            self.add_iclamp_connection(dendrite, current)

    def connect_vclamp(self, dendrites, voltage):
        """Call ``add_vclamp_connection(dendrite, voltage)`` for every dendrite.

        NOTE(review): ``add_vclamp_connection`` is not defined in this class;
        presumably a subclass or mixin provides it -- confirm.
        """
        for dendrite in dendrites:
            self.add_vclamp_connection(dendrite, voltage)

    def connect_synapses_stimulation(self, synapses, stimulation, w=1):
        """Connect *stimulation* to every synapse, scaling weights by ``weight_factor``.

        Args:
            synapses: iterable of synapse objects.
            stimulation: NetCon source (e.g. a NetStim).
            w: either one weight shared by all synapses, or a list/tuple with
                one weight per synapse (indexed in parallel with *synapses*;
                a too-short sequence raises ``IndexError``).
        """
        if isinstance(w, (list, tuple)):
            for i, synapse in enumerate(synapses):
                self.add_network_connection(stimulation, synapse, w[i] * self.weight_factor)
        else:
            for synapse in synapses:
                self.add_network_connection(stimulation, synapse, w * self.weight_factor)

    def add_gaba_synapses(self, dendrite, count):
        """Create *count* GABA synapses on *dendrite* and return them as a list."""
        return [self.add_gaba_synapse(dendrite) for _ in range(count)]

    def build_plasticity_synapses(self, dendrites):
        """Create one plasticity synapse per dendrite and return the list."""
        return self.build_synapses(dendrites, self.add_plasticity_synapse)

    # ========================================================
    # =================== =Stimulations= =====================
    # ========================================================
    def build_stimulations_ltp_gaba(self):
        """Build two bursts of paired NetStims offset by -1 and +1 around each burst start.

        Each burst is 100 spikes at frequency 100; the second burst starts
        one burst duration plus 1000 after the first.

        Returns:
            List of the four NetStim objects, in creation order.
        """
        stimulations = []
        freq = 100
        pairings = 100
        start = self.stimulation_start
        stimulations.append(self.add_stimulation(pairings, freq, start-1))
        stimulations.append(self.add_stimulation(pairings, freq, start+1))
        stimulation_time = (1000/freq) * pairings
        interburst_time = 1000
        start = start + stimulation_time + interburst_time
        stimulations.append(self.add_stimulation(pairings, freq, start-1))
        stimulations.append(self.add_stimulation(pairings, freq, start+1))
        return stimulations

    def build_stimulations_ratio(self):
        """Build a single one-spike stimulation at ``stimulation_start``.

        Returns:
            ``(stimulations, stimulation_time)``.
        """
        freq = 100
        pairings = 1
        stimulations = [self.add_stimulation(pairings, freq, self.stimulation_start)]
        stimulation_time = (1000/freq) * pairings
        return (stimulations, stimulation_time)

    def build_stimulations_ltp(self):
        """Build ``ltp_bursts`` bursts of ``ltp_pairings`` spikes at ``ltp_freq``.

        Consecutive bursts are separated by ``ltp_interburst_interval``.

        Returns:
            ``(stimulations, stimulation_time)`` where ``stimulation_time``
            spans all bursts plus the gaps between them.
        """
        freq = self.ltp_freq
        pairings = self.ltp_pairings
        start = self.stimulation_start
        stimulations = [self.add_stimulation(pairings, freq, start)]
        burst_time = (1000/freq) * pairings
        debug = self._debug_enabled()
        # The first burst already exists; add the remaining ones.
        for i in range(self.ltp_bursts-1):
            if debug:
                print(f"building burst: {i}")
            start = start + burst_time + self.ltp_interburst_interval
            stimulations.append(self.add_stimulation(pairings, freq, start))
        stimulation_time = burst_time * self.ltp_bursts + self.ltp_interburst_interval * (self.ltp_bursts - 1)
        return (stimulations, stimulation_time)

    def build_stimulations_ltd(self):
        """Build one train of ``ltd_pairings`` spikes at ``ltd_freq``.

        Returns:
            ``(stimulations, stimulation_time)``.
        """
        freq = self.ltd_freq
        pairings = self.ltd_pairings
        stimulations = [self.add_stimulation(pairings, freq, self.stimulation_start)]
        stimulation_time = (1000/freq) * pairings
        return (stimulations, stimulation_time)

    def build_epsp_measurement(self, synapses, totalruntime):
        """Create the two EPSP test pulses and connect them to *synapses*.

        Returns:
            The list of test-pulse NetStim objects.
        """
        stimulations = self.build_stimulations_epsp(totalruntime)
        for stimulation in stimulations:
            self.connect_synapses_stimulation(synapses, stimulation)
        return stimulations

    def build_stimulations_epsp(self, totalruntime):
        """Build single-spike test pulses at ``epsp_start`` and ``totalruntime - after_epsp_wait``."""
        freq = 1
        pairings = 1
        return [
            self.add_stimulation(pairings, freq, self.epsp_start),
            self.add_stimulation(pairings, freq, totalruntime - self.after_epsp_wait),
        ]

    # ========================================================
    # ======================= =Misc= =========================
    # ========================================================
    def set_up_records(self, synapses, soma):
        """Start recording synapse traces, the soma voltage and the time.

        Synapse records are appended to ``records_synapse``; ``records_soma``
        and ``records_t`` are replaced.
        """
        self.set_up_synapse_records(synapses, self.records_synapse)
        # record soma:
        self.records_soma = {"v": h.Vector()}
        self.records_soma['v'].record(soma._ref_v)
        # record time:
        self.records_t = h.Vector()
        self.records_t.record(h._ref_t)

    def set_up_synapse_records(self, synapses, records_synapse):
        """Append one ``{"i", "v"}`` record dict per synapse to *records_synapse*.

        ``"v"`` records the voltage of the synapse's segment and ``"i"`` the
        synapse current.
        """
        for synapse in synapses:
            records_synapse.append(
                {
                    "i": h.Vector(),
                    "v": h.Vector(),
                })
            records_synapse[-1]['v'].record(synapse.get_segment()._ref_v)
            records_synapse[-1]['i'].record(synapse._ref_i)

    def _first_index_after(self, t):
        """Return the index of the first recorded time strictly greater than *t*.

        Raises:
            StopIteration: if no recorded time exceeds *t*.
        """
        return next(index for index, value in enumerate(self.records_t) if value > t)

    def get_parameter(self, synapseid, parameter, t):
        """Return ``(time, value)`` of *parameter* at the first sample after *t*.

        Raises:
            StopIteration: if *t* is not before the last recorded time.
        """
        tid = self._first_index_after(t)
        rec = self.records_synapse[synapseid][parameter][tid]
        return (self.records_t[tid], rec)

    def get_parameter_range(self, synapseid, parameter, tfrom, tto):
        """Return ``(t_start, t_stop, values)`` for the samples between two times.

        The slice starts at the first sample after *tfrom* and stops before
        the first sample after *tto*.

        Raises:
            StopIteration: if either bound is not before the last recorded time.
        """
        tid_from = self._first_index_after(tfrom)
        tid_to = self._first_index_after(tto)
        data = list(self.records_synapse[synapseid][parameter])
        rec = data[tid_from:tid_to]
        return (self.records_t[tid_from], self.records_t[tid_to], rec)

    def add_vclamp(self, all_nodes, soma, voltage):
        """Voltage-clamp the ``.item`` of every node, then the soma."""
        dendrites = [x.item for x in all_nodes]
        self.connect_vclamp(dendrites, voltage)
        self.connect_vclamp([soma], voltage)

    def add_iclamp(self, all_nodes, soma, current):
        """Current-clamp the ``.item`` of every node, then the soma."""
        dendrites = [x.item for x in all_nodes]
        self.connect_iclamp(dendrites, current)
        self.connect_iclamp([soma], current)

    def set_monitor_dendrite(self, monitor_dendrite):
        """Store *monitor_dendrite* on the instance."""
        self.monitor_dendrite = monitor_dendrite

    # ========================================================
    # ==================== =Protocols= =======================
    # ========================================================
    def _finish_protocol(self, synapses, stimulation_time):
        """Compute, store and return the total run time of an LTP/LTD protocol.

        When EPSP measurement is enabled the run is extended by ``epsp_end``
        and the EPSP test pulses are attached to *synapses*.
        """
        if self.epsp_measurement_enabled:
            debug = self._debug_enabled()
            if debug:
                print(f"stimulation_time: {stimulation_time}")
            totalruntime = stimulation_time + self.stimulation_start + self.epsp_end + self.after_epsp_wait
            if debug:
                print(f"Measurement TotalTime: {totalruntime}")
            self.build_epsp_measurement(synapses, totalruntime)
        else:
            totalruntime = stimulation_time + self.stimulation_start + self.after_epsp_wait
        self.totalruntime = totalruntime
        return totalruntime

    def build_protocol_ratio(self, dendrite_nodes, soma):
        """Set up the single-pulse ratio protocol.

        Args:
            dendrite_nodes: objects exposing ``.item`` (synapse location) and
                ``.weight_by_distance``.
            soma: object whose ``_ref_v`` is recorded.

        Returns:
            The total run time, also stored in ``self.totalruntime``.
        """
        dendrites = [x.item for x in dendrite_nodes]
        synapses = self.build_plasticity_synapses(dendrites)
        self.set_up_records(synapses, soma)
        self.add_AMPA_democracy(synapses, [x.weight_by_distance for x in dendrite_nodes])
        (stimulations, stimulation_time) = self.build_stimulations_ratio()
        for stimulation in stimulations:
            self.connect_synapses_stimulation(synapses, stimulation)
        self.totalruntime = stimulation_time + self.stimulation_start + self.after_epsp_wait
        return self.totalruntime

    def build_protocol_ltp(self, dendrite_nodes, soma):
        """Set up the LTP protocol, optionally with GABA input and EPSP test pulses.

        Args:
            dendrite_nodes: objects exposing ``.item`` and
                ``.weight_by_distance``.
            soma: object whose ``_ref_v`` is recorded; one GABA synapse is
                also created on it.

        Returns:
            The total run time, also stored in ``self.totalruntime``.
        """
        dendrites = [x.item for x in dendrite_nodes]
        synapses = self.build_plasticity_synapses(dendrites)
        gabasynapses = self.add_gaba_synapses(soma, 1)
        self.set_up_records(synapses, soma)
        self.add_AMPA_democracy(synapses, [x.weight_by_distance for x in dendrite_nodes])
        (stimulations, stimulation_time) = self.build_stimulations_ltp()
        for stimulation in stimulations:
            self.connect_synapses_stimulation(synapses, stimulation)
            if self.gaba_enabled:
                self.connect_synapses_stimulation(gabasynapses, stimulation, self.gaba_weight)
                #self.add_artificial_gaba(stimulation, gabasynapses, self.gaba_weight)
        return self._finish_protocol(synapses, stimulation_time)

    def build_protocol_ltd(self, dendrite_nodes, soma):
        """Set up the LTD protocol, optionally with EPSP test pulses.

        Args:
            dendrite_nodes: objects exposing ``.item`` and
                ``.weight_by_distance``.
            soma: object whose ``_ref_v`` is recorded.

        Returns:
            The total run time, also stored in ``self.totalruntime``.
        """
        dendrites = [x.item for x in dendrite_nodes]
        synapses = self.build_plasticity_synapses(dendrites)
        self.set_up_records(synapses, soma)
        self.add_AMPA_democracy(synapses, [x.weight_by_distance for x in dendrite_nodes])
        (stimulations, stimulation_time) = self.build_stimulations_ltd()
        for stimulation in stimulations:
            self.connect_synapses_stimulation(synapses, stimulation)
        return self._finish_protocol(synapses, stimulation_time)

    # ========================================================
    # ===================== =Analysis= =======================
    # ========================================================
    def get_weights(self):
        """Return the last recorded ``"synweight"`` value of every synapse record.

        NOTE(review): ``set_up_synapse_records`` only creates ``"i"`` and
        ``"v"`` traces, so this raises ``KeyError`` unless another component
        adds a ``"synweight"`` trace -- confirm.
        """
        return [list(item["synweight"])[-1] for item in self.records_synapse]

    def save_weights(self):
        """Save the last recorded ``"gbar"`` of every synapse to ``weights/weights.npy``.

        NOTE(review): ``set_up_synapse_records`` does not create a ``"gbar"``
        trace, so this raises ``KeyError`` unless another component adds
        one -- confirm.
        """
        weights = [list(item["gbar"])[-1] for item in self.records_synapse]
        Path("weights").mkdir(parents=True, exist_ok=True)
        with open("weights/weights.npy", 'wb') as f:
            np.save(f, weights)

    def mult_weights(self, multi=100):
        """Multiply every plastic synapse's ``gbar`` by *multi*, with a floor of -0.5."""
        for synapse in self.plastic_synapses:
            synapse.gbar = max(-0.5, synapse.gbar*multi)

    def load_weights(self):
        """Load ``weights/weights.npy`` and assign the values to the plastic synapses' ``gbar``.

        Each synapse is pretty-printed as it is updated.
        """
        # NOTE(review): allow_pickle=True lets np.load execute pickled
        # payloads; only load weight files from trusted sources.
        weights = np.load("weights/weights.npy", allow_pickle=True).tolist()
        for i, synapse in enumerate(self.plastic_synapses):
            pprint(synapse)
            synapse.gbar = weights[i]

    def print_epsp_change(self):
        """Print the EPSP change as a percentage, if it can be computed."""
        epspchange = self.get_epsp_change()
        if epspchange is not None:
            print(f"EPSP change: {epspchange:6.2f}%")

    def get_epsp_change(self):
        """Return the post/pre EPSP amplitude ratio in percent, or None.

        The amplitude of each test pulse is the maximum soma voltage in the
        50 time units after the pulse minus the minimum in the 50 time units
        before it.

        Returns:
            ``after * 100 / before``, or ``None`` (with a printed message)
            when measurement is disabled, the simulation stopped before the
            second measurement window ended, or the baseline amplitude is
            zero.
        """
        if not self.epsp_measurement_enabled:
            print("EPSP measurement is not enabled.")
            return None
        post_pulse = self.totalruntime - self.after_epsp_wait
        if (h.t < int(post_pulse + 50)):
            print("Stopped early. Can't calculate EPSP.")
            return None

        # Convert the recorded vectors once instead of once per window bound.
        times = list(self.records_t)
        voltages = list(self.records_soma["v"])

        def samples_up_to(bound):
            # Number of samples with time <= bound, used as a slice index.
            return sum(1 for x in times if x <= bound)

        def amplitude(pulse_time):
            window_start = samples_up_to(pulse_time - 50)
            pulse_index = samples_up_to(pulse_time)
            window_end = samples_up_to(pulse_time + 50)
            baseline = min(voltages[window_start:pulse_index])
            peak = max(voltages[pulse_index:window_end])
            return peak - baseline

        before = amplitude(self.epsp_start)
        after = amplitude(post_pulse)
        if before == 0:
            # A flat baseline response would otherwise divide by zero.
            print("Baseline EPSP amplitude is zero. Can't calculate EPSP change.")
            return None
        return (after * 100) / before

    def write_debug_files(self, folder="0"):
        """Dump recordings, parameters, EPSP change and spike counts to ``debug/<folder>/``.

        ``records_ec_synapse`` and ``records_dendrite`` are not set anywhere
        in this class, so they are treated as optional: the corresponding
        file is skipped when the attribute is absent (and, for
        ``records_ec_synapse``, also when it is ``None``).
        """
        target = Path(f"debug/{folder}")
        target.mkdir(parents=True, exist_ok=True)

        def dump(name, payload):
            with open(target / f"{name}.npy", 'wb') as f:
                np.save(f, payload)

        ec_records = getattr(self, "records_ec_synapse", None)
        if ec_records is not None:
            dump("ec_synapse", ec_records)
        dump("synapse", self.records_synapse)
        if hasattr(self, "records_dendrite"):
            dump("dendrite", self.records_dendrite)
        dump("soma", self.records_soma)
        dump("t", self.records_t)
        dump("epsp", self.get_epsp_change())
        dump("parameters", self.plasticity_model_parameters)
        dump("spikecount", self.calculate_spikes())

    @staticmethod
    def _count_spikes(trace, threshold=-30):
        """Count upward crossings of *threshold* in *trace*.

        A spike starts when a sample rises above the threshold and ends when
        a later sample falls below it; samples equal to the threshold change
        nothing.
        """
        spikestarted = False
        spikecount = 0
        for vmem in trace:
            if vmem > threshold and not spikestarted:
                spikestarted = True
                spikecount += 1
            if vmem < threshold and spikestarted:
                spikestarted = False
        return spikecount

    def calculate_spikes_soma(self):
        """Return the number of spikes (crossings of -30) in the soma voltage trace."""
        return self._count_spikes(self.records_soma['v'])

    def calculate_spikes_synapse(self):
        """Return the spike count of every synapse voltage trace, in record order."""
        return [self._count_spikes(syndata['v']) for syndata in self.records_synapse]

    def calculate_spikes(self):
        """Return ``{"synapse_spikes": [...], "soma_spikes": int}``."""
        synapse_spikes = self.calculate_spikes_synapse()
        soma_spikes = self.calculate_spikes_soma()
        return {"synapse_spikes": synapse_spikes, "soma_spikes": soma_spikes }

    def get_parameters(self):
        """Return the current ``plasticity_model_parameters`` dict (not a copy)."""
        return self.plasticity_model_parameters

    def pprintinfo(self):
        """Pretty-print the current plasticity model parameters."""
        pprint(self.get_parameters())