import os
os.environ['MKL_NUM_THREADS'] = "1"
from typing import Dict, Union
from AlonsoMarderModel import AlonsoMarderModel
import numpy as np
import matplotlib.pyplot as plt
import time
import pickle
models = {
'a': {
'conductances': {
'g_Na': 1076.392, # uS, transient sodium conductance
'g_CaT': 6.4056, # uS, low-threshold calcium conductance
'g_CaS': 10.048, # uS, slow calcium conductance
'g_A': 8.0384, # uS, transient potassium conductance
'g_KCa': 17.584, # uS, calcium-dependent potassium conductance
'g_Kd': 124.0928, # uS, potassium conductance
'g_H': 0.11304, # uS, hyperpolarization-activated cation conductance
'g_L': 0.17584, # uS, leak conductance
},
'tau_ca': 653.5
},
# add any other models desired here
}
def make_dict_model_from_mcmc_result_list(
row_values: np.ndarray, name: str = 'mcmc') -> Dict[str, Dict[str, Union[Dict[str, float], float]]]:
"""
makes dictionary model from mcmc results
@param row_values: initial row vector from MCMC sampler results
@param name: name of the model
@return: model
"""
return {
name: {
'conductances': {
'g_Na': row_values[1], # uS, transient sodium conductance
'g_CaT': row_values[2], # uS, low-threshold calcium conductance
'g_CaS': row_values[3], # uS, slow calcium conductance
'g_A': row_values[4], # uS, transient potassium conductance
'g_KCa': row_values[5], # uS, calcium-dependent potassium conductance
'g_Kd': row_values[6], # uS, potassium conductance
'g_H': row_values[7], # uS, hyperpolarization-activated cation conductance
'g_L': row_values[8], # uS, leak conductance
},
'tau_ca': row_values[9]
}
}
def plot_model_outputs(
key: str, dict_models: dict, current_injected: float = 0.0, silent: bool = False) -> np.array:
"""
plot all 6 models mV vs ms with its threshold
:param key: model type (a through f)
:param dict_models: model specification (a through f)
:param current_injected: amount of current to inject
:param silent: False by default to show plots
:return: time steps, voltage trace, spike times and threshold and their respective plots
"""
print(f'key : {key}')
time_in_seconds = 6.0
the_model = AlonsoMarderModel(injected_current=current_injected,
conductances=dict_models[key]['conductances'],
tau_ca=dict_models[key]['tau_ca'])
time_steps = np.arange(0.0, time_in_seconds * 1e3, 0.01)
start = time.time()
model_output = the_model.run_simulation(time_steps)
print(f'compute time: {time.time() - start}')
if not silent:
plt.plot(model_output["t"], model_output["y"])
# plt.plot(model_output["spike_times"], model_output["spike_threshold"], "ro")
plt.xlabel('time (ms)')
plt.ylabel('Voltage (mV)')
plt.title(f'model: {key}')
# plt.legend(the_model.get_state_vars_labels())
plt.show()
return model_output
def plot_model_comparison(dict_models: Dict[str, Dict[str, Dict[str, np.ndarray]]]) -> None:
for key, val in dict_models.items():
if key == 'a':
continue
for current_index in np.arange(0, .5, .1):
fig, ((g), (h)) = plt.subplots(2)
fig.suptitle(f'ground truth model (top) and mcmc (bot) - model: {key} current: {float(current_index)}')
g.plot(dict_models['a'][current_index]['t'], dict_models['a'][current_index]['y'], 'tab:orange')
h.plot(dict_models[key][current_index]['t'], dict_models[key][current_index]['y'])
plt.xlabel('samples')
for ax in fig.get_axes():
ax.label_outer()
plt.savefig(f'{current_index}.png')
plt.show()
return None
def convert_dict_of_lists_to_model(
model_dict: Dict[str, Union[np.ndarray, list]]) -> Dict[str, Dict[str, Union[Dict[str, float], float]]]:
all_models = {}
for key, val in model_dict.items():
model = {
key: {
'conductances': {
'g_Na': val[0], # uS, transient sodium conductance
'g_CaT': val[1], # uS, low-threshold calcium conductance
'g_CaS': val[2], # uS, slow calcium conductance
'g_A': val[3], # uS, transient potassium conductance
'g_KCa': val[4], # uS, calcium-dependent potassium conductance
'g_Kd': val[5], # uS, potassium conductance
'g_H': val[6], # uS, hyperpolarization-activated cation conductance
'g_L': val[7], # uS, leak conductance
},
'tau_ca': val[8]
}
}
all_models.update(model)
return all_models
def compute_and_save_models(
models_to_make: Dict[str, Dict[str, Union[Dict[str, float], float]]],
mcmc_model_desired, compute_single_current=False) -> None:
"""
Store all models in a pickle
:return: None
"""
model = {}
model_all = {}
if compute_single_current:
for key in mcmc_model_desired:
if key == 'all':
model_all.update({current_to_plot: plot_model_outputs(key, models_to_make, current_to_plot, False) for
current_to_plot in np.arange(0.0, 0.5, 0.1)})
else:
model.update({float(key): plot_model_outputs(key, models_to_make, float(key), False)})
model_outs = {
'a':
{current_to_plop: plot_model_outputs('a', models_to_make, current_to_plop, False)
for current_to_plop in np.arange(0.0, 0.5, 0.1)},
'mcmc_single_currents': model,
'mcmc_all_current': model_all,
}
else:
model_outs = {
key: {current_to_plop: plot_model_outputs(key, models_to_make, current_to_plop, current_to_plop != 0.0)
for current_to_plop in np.arange(0.0, 0.5, 0.1)}
for key in models_to_make.keys()}
# compare_models(model_outs)
plot_model_comparison(model_outs)
with open('AlonsoMarderModel_generated_data.pkl', 'wb') as filehandle:
pickle.dump(model_outs, filehandle, protocol=pickle.HIGHEST_PROTOCOL)
return None
if __name__ == '__main__':
# plot_model_outputs()
mcmc_model_dict = {
'0.0': [1.36504984e+03, 6.90471282e+00, 1.03010743e+01, 7.52944902e+01, 1.79958197e+01, 1.07390603e+02,
4.15844030e-01, 1.75782751e-01, 7.38407016e+02], # new uniform, 82.89
'0.1': [1510.54096, 7.84354, 10.96749, 54.15303, 16.48460, 139.69044, 0.27945, 0.19041, 607.57802], # 0.00212
'0.2': [1823.57888, 6.95997, 13.63717, 145.06658, 13.45328, 92.24636, 0.35482, 0.18666, 342.65233], # 0.02280
'0.30000000000000004': [1554.45888, 6.53608, 14.23725, 140.65041, 12.56214, 53.72467, 0.21759, 0.14271,
470.50603], # 0.00189
'0.4': [1326.94727, 7.60586, 8.80298, 35.66033, 10.20653, 91.64060, 0.30257, 0.21334, 238.34122], # 0.00439
'all': [1307.70985, 8.85968, 13.41245, 114.53560, 16.89606, 121.59780, 0.28586, 0.12557, 794.03210], # 18.76235
}
mcmc_model = convert_dict_of_lists_to_model(mcmc_model_dict)
models.update(mcmc_model)
# plot_model_comparison(models)
compute_and_save_models(models, mcmc_model_dict, compute_single_current=True)
model_outputs = pickle.load(open('AlonsoMarderModel_generated_data.pkl', 'rb'))