from datetime import datetime
import os
from preneuron import Pre
from postneuron import Post
from astrocyte import Astro
from tqdm import tqdm
from scipy import io
from collections import defaultdict
def state_var_to_be_saved(pre, post, astro):
    """Select the model state variables that are recorded at every time step.

    Parameters
    ----------
    pre, post, astro : mapping
        State mappings of the presynaptic neuron, postsynaptic neuron and
        astrocyte (``main`` passes ``pre.x``, ``post.x`` and ``astro.x``).

    Returns
    -------
    dict
        A new dict with 27 entries keyed by variable name, in a fixed order.
        Note that ``Glu_syncleft`` is read from *post*, not from *pre*.

    Raises
    ------
    KeyError
        If an expected variable is missing from its source mapping.
    """
    # (source mapping, variable names) in the exact order the entries are stored.
    layout = (
        (pre, ("Ca_CaNHVA_pre", "Ca_NMDAR_pre", "CaN_pre")),
        (post, ("Glu_syncleft",)),
        (pre, ("Prel_pre", "RA2_O_pre", "Rrel_pre", "V_pre")),
        (post, ("AG_post", "Ca_post", "Ca_ER_post", "Ca_DAG_GaGTP_PLC_post",
                "Ca_DAG_PLC_post", "Ca_GaGTP_PLC_post", "Ca_PLC_post",
                "DAG_post", "GaGTP_PLC_post", "h_IP3R_post", "IP3_post",
                "PLC_post", "V_dend_post", "V_soma_post")),
        (astro, ("Ca_astro", "Glu_extsyn", "h_astro", "IP3_astro",
                 "Rrel_astro")),
    )
    return {name: source[name] for source, names in layout for name in names}
def _ensure_directory(path):
    """Make sure the output directory *path* exists, creating parents as needed.

    An already existing directory is not treated as a failure. A genuine
    failure (e.g. a permission error) raises ``OSError`` immediately, instead
    of only surfacing when the results are saved at the end of a long run.
    """
    if os.path.isdir(path):
        print("Using existing directory %s" % path)
        return
    # exist_ok also covers the directory appearing between the check and here.
    os.makedirs(path, exist_ok=True)
    print("Successfully created the directory %s " % path)


def _calcium_leak_parameters(post, astro):
    """Compute leak parameters for *post* and *astro* from their current Ca fluxes.

    Returns
    -------
    tuple
        ``(Ca_par_post, Ca_par_astro)`` as returned by the models'
        ``calcium_leak_parameters`` methods. ``main`` reads
        ``r_leakCell_post``/``r_leakER_post`` and ``r_leakER_astro`` from them.
    """
    Ca_flux_post = post.calcium_other_fluxes()
    Ca_par_post = post.calcium_leak_parameters(
        Ca_flux_post["J_CaL_post"],
        Ca_flux_post["J_IP3R_post"],
        Ca_flux_post["J_NMDAR_post"],
        Ca_flux_post["J_PMCA_post"],
        Ca_flux_post["J_SERCA_post"],
    )
    Ca_flux_astro = astro.calcium_other_fluxes()
    Ca_par_astro = astro.calcium_leak_parameters(
        Ca_flux_astro["J_IP3R_astro"],
        Ca_flux_astro["J_SERCA_astro"],
    )
    return Ca_par_post, Ca_par_astro


def main(T_shift):
    """Run one post-pre pairing simulation and save its results as .mat files.

    The postsynaptic neuron receives a train of current pulses starting at
    ``stim_start``. The presynaptic neuron receives the same train delayed by
    ``T_shift``. The pre/post/astrocyte models are integrated with a fixed
    step ``dt``, and the state after every step is recorded.

    Parameters
    ----------
    T_shift : int or float
        Delay of the presynaptic stimulus relative to the postsynaptic one,
        in the same time unit as ``dt`` (ms, judging by the output directory
        name). It is converted to ``round(T_shift / dt)`` steps.

    Side effects
    ------------
    Writes ``state_var_results.mat``, ``other_var_results.mat``,
    ``time_stimuli.mat`` and ``stimulation_parameters.mat`` into
    ``./results_post_pre_pairing_100x/<T_shift>ms/``, creating the directory
    if needed.

    Raises
    ------
    OSError
        If the output directory cannot be created.
    """
    time_start = datetime.now()
    print("Simulation started at", time_start)

    path_template = "./results_post_pre_pairing_100x/%sms/"
    path = path_template % T_shift
    _ensure_directory(path)

    # --- Stimulation protocol ---------------------------------------------
    stim_start = 20000          # time before the first train starts
    trainlengthtime = 500000    # duration of one pulse train
    restlengthtime = 20000      # silent period after each train
    no_trains = 1
    pulserate = 0.2             # pulses per 1000 time units (Hz if times are ms)
    pulselengthtime = 10
    A_stim_post = 25
    A_stim_pre = 10
    dt = 0.05

    T_end = stim_start + no_trains * (trainlengthtime + restlengthtime)
    Nsteps = round(T_end / dt)
    pulselength = round(pulselengthtime / dt)
    no_pulses = round(pulserate * trainlengthtime * 1e-3)
    # The pauses evenly fill the part of the train not occupied by pulses.
    pauselengthtime = (trainlengthtime - no_pulses * pulselengthtime) / no_pulses
    pauselength = round(pauselengthtime / dt)
    restlength = round(restlengthtime / dt)
    steps_T_shift = round(T_shift / dt)
    stim_start_steps = round(stim_start / dt)

    stim_pause_post = ([A_stim_post] * pulselength + [0] * pauselength) * no_pulses
    stim_pause_pre = ([A_stim_pre] * pulselength + [0] * pauselength) * no_pulses
    # The loop reads sample i + 1 during step i. The pre train gets
    # steps_T_shift extra leading zeros, so it starts T_shift later.
    I_ext_post = ([0] * (stim_start_steps + 1)
                  + (stim_pause_post + [0] * restlength) * no_trains)
    I_ext_pre = ([0] * (stim_start_steps + steps_T_shift + 1)
                 + (stim_pause_pre + [0] * restlength) * no_trains)

    # Leak parameters are recomputed at these steps (1/2 and 3/4 of
    # stim_start, i.e. before stimulation begins). Integer step indices avoid
    # relying on exact float equality of stim_start * k / dt with the step
    # counter, which only holds for some dt values.
    recalc_steps = {round(stim_start / 2 / dt), round(stim_start * 3 / 4 / dt)}

    # --- Models -----------------------------------------------------------
    pre_params = Pre.get_parameters()
    pre_init = Pre.get_initial_values(pre_params)
    pre = Pre(pre_params, pre_init)
    post_params = Post.get_parameters()
    post_init = Post.get_initial_values(post_params)
    post = Post(post_params, post_init)
    astro_params = Astro.get_parameters()
    astro_init = Astro.get_initial_values(astro_params)
    astro = Astro(astro_params, astro_init)

    Ca_par_post, Ca_par_astro = _calcium_leak_parameters(post, astro)

    # Record the initial state, then append one entry per step below.
    state_var = state_var_to_be_saved(pre.x, post.x, astro.x)
    saved_state_var = {key: [value] for key, value in state_var.items()}
    saved_other_var = defaultdict(list)

    # Time since the last upward zero-crossing of V_pre. Starting at (and
    # resetting to) 10 disables the threshold check below until the next
    # crossing.
    t_spike = 10
    for i in tqdm(range(Nsteps)):
        if i in recalc_steps:
            Ca_par_post, Ca_par_astro = _calcium_leak_parameters(post, astro)

        # Values from before this step, needed by the threshold logic below.
        Ca_pre_old = pre.x["Ca_CaNHVA_pre"]
        Prel_pre_old = pre.x["Prel_pre"]
        Rrel_pre_old = pre.x["Rrel_pre"]
        V_pre_old = pre.x["V_pre"]
        Ca_astro_old = astro.x["Ca_astro"]
        Rrel_astro_old = astro.x["Rrel_astro"]
        f_pre = pre.x["X_ac_pre"] / pre_params["X_total_pre"]

        # All three derivatives are computed before any solve_deriv call, so
        # (assuming derivative() does not mutate state) they see the same
        # pre-step state.
        deriv_pre, other_var_pre = pre.derivative(
            astro.x["Glu_extsyn"], post.x["Glu_syncleft"], I_ext_pre[i + 1])
        deriv_post, other_var_post = post.derivative(
            pre.params["f_Glu_pre"], I_ext_post[i + 1],
            Ca_par_post["r_leakCell_post"], Ca_par_post["r_leakER_post"])
        deriv_ast, _ = astro.derivative(
            post.x["AG_post"], Ca_par_astro["r_leakER_astro"])
        pre.solve_deriv(deriv_pre, dt)
        post.solve_deriv(deriv_post, dt)
        astro.solve_deriv(deriv_ast, dt)

        # Spike detection: upward zero-crossing of V_pre.
        if pre.x["V_pre"] >= 0 and V_pre_old < 0:
            t_spike = 0
        else:
            t_spike = t_spike + dt
        # Apply the pre/post solve_deltaf updates at most once per spike: Ca
        # must cross C_thr_pre within 10 time units of the spike, and t_spike
        # is then pushed to 10 so later steps cannot trigger it again.
        if pre.x["Ca_CaNHVA_pre"] >= pre.params["C_thr_pre"] and t_spike < 10:
            pre.solve_deltaf(Ca_pre_old, f_pre, Prel_pre_old, Rrel_pre_old)
            post.solve_deltaf(pre, Rrel_pre_old)
            t_spike = 10
        # Astrocyte update only on an upward crossing of its Ca threshold.
        if (astro.x["Ca_astro"] >= astro.params["C_thr_astro"]
                and Ca_astro_old < astro.params["C_thr_astro"]):
            astro.solve_deltaf(Rrel_astro_old)

        state_var = state_var_to_be_saved(pre.x, post.x, astro.x)
        for key, value in state_var.items():
            saved_state_var[key].append(value)
        other_var = {
            "f_pre": f_pre,
            "Glu_NMDAR_pre": other_var_pre["Glu_NMDAR_pre"],
            "ICaNHVA_pre": other_var_pre["ICaNHVA_pre"],
            "ICa_NMDAR_pre": other_var_pre["ICa_NMDAR_pre"],
            "I_AMPAR_post": other_var_post["I_AMPAR_post"],
            "ICaLHVA_dend_post": other_var_post["ICaLHVA_dend_post"],
            "ICaLLVA_dend_post": other_var_post["ICaLLVA_dend_post"],
            "ICa_NMDAR_post": other_var_post["ICa_NMDAR_post"],
            "J_CaL_post": other_var_post["J_CaL_post"],
            "J_IP3R_post": other_var_post["J_IP3R_post"],
            "J_leakCell_post": other_var_post["J_leakCell_post"],
            "J_leakER_post": other_var_post["J_leakER_post"],
            "J_NMDAR_post": other_var_post["J_NMDAR_post"],
            "J_PMCA_post": other_var_post["J_PMCA_post"],
            "J_SERCA_post": other_var_post["J_SERCA_post"],
        }
        for key, value in other_var.items():
            saved_other_var[key].append(value)

    # --- Output -----------------------------------------------------------
    io.savemat(os.path.join(path, "state_var_results.mat"), saved_state_var)
    io.savemat(os.path.join(path, "other_var_results.mat"), saved_other_var)
    # The time vector has Nsteps + 1 samples (initial state + one per step),
    # divided by 1000 (presumably ms -> s).
    # NOTE(review): I_ext_pre is steps_T_shift samples longer than the time
    # vector; its tail lies beyond T_end and is never simulated.
    io.savemat(os.path.join(path, "time_stimuli.mat"),
               {"time": [k * dt / 1000 for k in range(Nsteps + 1)],
                "I_ext_pre": [I_ext_pre],
                "I_ext_post": [I_ext_post]})
    io.savemat(os.path.join(path, "stimulation_parameters.mat"),
               {"dt": dt, "pulserate": pulserate, "T_shift": T_shift})

    time_end = datetime.now()
    # total_seconds() includes whole days; timedelta.seconds alone drops them,
    # so runs longer than 24 h would report a wrapped duration.
    total_time = (time_end - time_start).total_seconds() / 60.
    print("\n")
    print("Simulation finished at", time_end)
    print("Total time = {0:.2f} minutes".format(total_time))
if __name__ == "__main__":
    # Sweep the pre-stimulus delay over 10, 20, ..., 200; each value gets its
    # own "<T_shift>ms" results directory (see main).
    shifts = range(10, 210, 10)
    for shift in shifts:
        main(shift)