"""
Motoneuron Pool Optimization source code for
"Supercomputer framework for reverse engineering firing patterns of neuron populations to identify their synaptic inputs"
by Matthieu K. Chardon, Y. Curtis Wang, Marta Garcia, Emre Besler, J. Andrew Beauchamp,
Michael D’Mello, Randall K. Powers, and Charles J. Heckman.
This Python code is set up to run one combination for 20 motoneurons with MPI on one machine.
DOI: https://doi.org/10.7554/eLife.90624.2
Must be run with MPI, for instance:
mpiexec -n 20 python -u MN_pool_public.py
"""
from mpi4py import MPI
from neuron import h
import MN_types
import gc
import sys
import logging
import logging.handlers
import argparse
import time
import numpy as np
import pandas as pd
import copy
import pickle
import os
# Module-level setup. This runs on import in every process started by
# mpiexec, before the __main__ guard calls pc_main.runworker().
# Log everything down to DEBUG, prefixed only by the level name.
logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
# Initialise NEURON's MPI support before the ParallelContext below is created.
h.nrnmpi_init()
# Shared bulletin board: main() submits job() on it and job() submits inner_job().
pc_main = h.ParallelContext()
NUM_JOBS = 16  # NOTE(review): not referenced anywhere else in this file
# Maximum number of optimisation tries per combination (job() may stop early).
N_TRY = 20 # 20 # 8 # number of loops to try
# Pool size: job() submits one inner_job per motoneuron ('MN_0' .. 'MN_<NMN-1>') per try.
NMN = 20 # 5 # number of motoneuron or number of subjobs
# Scales (target - average_rate) when updating the excitatory command; job() reduces it for small msq.
GAIN = 0.2 # Error gain
# Initial excitatory command = target * FIRSTMULT.
FIRSTMULT = 0.6 # Varies the first guess to help with optimization
# Per-MN excitatory weight i in [0, NMN-1]:
#   SMULTSTRT + (i / (NMN - 1)) ** (1 + SMULTSKEW) * (SMULTEND - SMULTSTRT)
SMULTSTRT = 1 # Relative Value for lowest threshold motorunit. these parameters allow for unequal distribution of synaptic weights
SMULTEND = 1 # Relative Value for highest threshold motorunit.
SMULTSKEW = 1 # Distribution of spacing of input
# The single combination main() submits, as comb = [GINMULT, NMMULT, GINADD, RANDOM_SEED]:
# inhibitory = excitatory * GINMULT + GINADD; NMMULT scales dendritic gcabar_L_Ca_inact;
# RANDOM_SEED seeds the Gfluctdv noise of every MN.
GINMULT = -0.7
NMMULT = 1.2
GINADD = 6.5
RANDOM_SEED = 1
def _mn_weight(mn_index, nmn, start, end, skew):
    """Return the relative excitatory weight of motoneuron ``mn_index``.

    Implements ``start + (mn_index / (nmn - 1)) ** (1 + skew) * (end - start)``:
    MN 0 gets ``start``, MN ``nmn - 1`` gets ``end``, and ``skew`` shapes how
    the weights are spread in between.

    A pool with a single motoneuron has nothing to spread, so it gets
    ``start`` instead of raising ``ZeroDivisionError``.
    """
    if nmn <= 1:
        return float(start)
    return start + ((mn_index / (nmn - 1)) ** (1 + skew)) * (end - start)


def inner_job(std_noise_value, excitatory_input, inhibitory_input,
              try_str, MN_str, comb, nmmult, noise_name, stim_time,
              t_stop=22000):
    """Simulate one motoneuron for one optimisation try and return its spikes.

    Builds a fresh cell with ``generate_MN_skeleton``, applies the
    ``MN_types.load_<MN_str>`` parameter set, and scales every dendritic
    ``gcabar_L_Ca_inact`` by ``nmmult``. It then plays the weighted
    excitatory command into ``multex_Gfluctdv`` and the inhibitory command
    into ``multin_Gfluctdv`` on the time base ``stim_time``. Finally it
    advances the model with ``h.dt = 0.025`` until ``int(h.t)`` reaches
    ``t_stop``.

    Args:
        std_noise_value: noise SD passed to ``generate_MN_skeleton``.
        excitatory_input: excitatory command samples aligned with
            ``stim_time``; multiplied by this MN's weight before playback.
        inhibitory_input: inhibitory command samples, played back unscaled.
        try_str: label of the current try, used for logging only.
        MN_str: ``'MN_<i>'``. Selects the MN_types loader, and ``<i>`` sets
            the MN's position in the pool's weight distribution.
        comb: parameter combination; ``comb[3]`` seeds the Gfluctdv noise.
        nmmult: multiplier for the dendritic L_Ca_inact conductance.
        noise_name: not used here; kept so the submit signature stays stable.
        stim_time: time points used by Vector.play for both commands.
        t_stop: simulation end time, in the same units as ``h.t``
            (default 22000, the previously hard-coded value).

    Returns:
        ``(MN, out_times, MN_str)``. ``MN`` holds ``'weight'`` and
        ``'seed_correlated'``, plus empty ``spike_times``/``smspikes``/
        ``recruited`` placeholders that ``job`` overwrites. ``out_times`` is
        a NumPy copy of the soma spike times (threshold -20 mV). On any
        exception the traceback is logged and ``({}, None, MN_str)`` is
        returned.
    """
    try:
        logging.info('INNER: - inner job launched!')
        logging.info('INNER: - running try ' + try_str + ', MN file ' + MN_str)
        logging.debug('INNER: - generating MN skeleton...')
        soma, d1, d2, d3, d4, dend = generate_MN_skeleton(std_noise_value, h)
        logging.debug('INNER: - done generating MN skeleton!')
        target_time = h.Vector(stim_time)

        # Only the soma spike times leave this function, so nothing else is
        # recorded (every recorded Vector grows by one sample per time step).
        nc_soma = h.NetCon(soma(0.5)._ref_v, None, sec=soma)
        nc_soma.threshold = -20  # Threshold voltage (mV) for spike detection
        spike_times_soma = h.Vector()
        nc_soma.record(spike_times_soma)

        # Apply the MN-specific parameter set, then scale neuromodulation.
        load_mn_type = getattr(MN_types, 'load_' + str(MN_str))
        load_mn_type(soma, dend, d1, d2, d3, d4, h)
        for sec in dend:
            if h.ismembrane('L_Ca_inact', sec=sec):
                sec.gcabar_L_Ca_inact = nmmult * sec.gcabar_L_Ca_inact

        # The MN's index in the pool sets its share of the excitatory input.
        weight = _mn_weight(int(MN_str[3:]), NMN, SMULTSTRT, SMULTEND, SMULTSKEW)
        seed_correlated = comb[3]
        # Every MN of a combination gets the same seed (comb is shared).
        h.new_seed_Gfluctdv(seed_correlated)
        MN = {
            "spike_times": {},
            'smspikes': {},
            "recruited": {},
            'weight': weight,
            'seed_correlated': seed_correlated,
        }

        # Drive the Gfluctdv mechanism: excitation and inhibition are played
        # into its global multipliers along the target time base.
        stimvalex = h.Vector(excitatory_input * weight)
        stimvalex.play(h._ref_multex_Gfluctdv, target_time)
        stimvalin = h.Vector(inhibitory_input)
        stimvalin.play(h._ref_multin_Gfluctdv, target_time)

        # Run the simulation with a fixed-step loop.
        gc.collect()
        h.dt = 0.025
        h.finitialize()
        h.fcurrent()
        h.frecord_init()
        logging.info('INNER: - computing: ' + try_str + ', ' + MN_str + ' for ' + str(comb))
        while int(h.t) < t_stop:
            h.fadvance()
        out_times = np.copy(spike_times_soma.as_numpy())

        # Release the NEURON objects eagerly. generate_MN_skeleton inserts
        # mechanisms into every section in h.allsec(), so leftover sections
        # should not linger into a later call in the same process.
        del stimvalex, stimvalin, spike_times_soma, soma, dend, d1, d2, d3, d4
        del load_mn_type
        gc.collect()
        return MN, out_times, MN_str
    except Exception as e:
        logging.exception(e)
        gc.collect()
        return {}, None, MN_str
def _gain_for_msq(msq, base_gain):
    """Return the error-correction gain for a try with mean squared error ``msq``.

    The gain shrinks as the fit improves so the search settles instead of
    overshooting. ``msq < 1.5`` uses 25 % of ``base_gain`` and ``msq < 3``
    uses 50 %. Anything else, including NaN, uses the full ``base_gain``.
    """
    # Test the tighter bound first; testing ``< 3`` first would make the
    # ``< 1.5`` branch unreachable.
    if msq < 1.5:
        return 0.25 * base_gain
    if msq < 3:
        return 0.5 * base_gain
    return base_gain


def _smooth_spikes(spike_times, stim_time, conv_filt):
    """Turn spike times into a smoothed train sampled on ``stim_time``.

    Each spike sets the nearest ``stim_time`` sample to 1 (several spikes on
    one sample still give 1). The binary train is then convolved with
    ``conv_filt`` in 'same' mode and truncated to ``len(stim_time)`` samples.
    """
    binary_train = np.zeros(len(stim_time))
    for spike in spike_times:
        binary_train[np.abs(stim_time - spike).argmin()] = 1
    smoothed = np.convolve(binary_train, conv_filt, 'same')
    return smoothed[:len(stim_time)]


def job(comb, comb_id, noise_name):
    """Fit the pool's excitatory command for one parameter combination.

    ``comb`` is ``(ginmult, nmmult, ginadd, seed)``. The search starts from
    ``stim_target * FIRSTMULT``. Each of up to ``N_TRY`` tries:

    1. derives the inhibitory command ``excitatory * ginmult + ginadd``,
       flooring values <= 0 to 1e-7;
    2. submits one ``inner_job`` per motoneuron (``NMN`` of them) and
       collects their spike times;
    3. averages the smoothed spike trains over the recruited motoneurons
       and computes the mean squared error against the target;
    4. stops once ``msq < 0.5`` with every motoneuron recruited. Otherwise
       it adds ``gain * (target - average)`` to the excitatory command, with
       the gain taken from ``_gain_for_msq``. It then multiplies the command
       by ``mulvec`` when more than two motoneurons are silent, and floors it
       at 1e-7.

    Only the final try (converged, or try ``N_TRY - 1``) is pickled. It goes
    to ``new_data/<noise_name>/<SMULTSTRT><SMULTEND><SMULTSKEW>_<seed>/``
    under the current working directory.

    Returns:
        ``(False, comb)`` when the search ran to completion, or
        ``(True, comb)`` if an exception aborted it (the traceback is logged).
    """
    try:
        logging.info(f"{comb_id} >> OUTER: - job started! comb is: {comb}")
        std_noise_value = 1.2e-5
        # Target data: column 0 is the time base, column 1 is the target that
        # the pool's averaged smoothed spike trains are fitted to.
        target = np.loadtxt('targets_and_commands/glin16.txt')
        stim_time = target[:, 0]
        stim_target = target[:, 1]
        # Spike-smoothing kernel, used as -(hwave4w - 1).
        conv_filt = -(np.loadtxt('hwave4w.txt') - 1)
        # Loop-invariant correction applied when too few MNs are recruited.
        mulvec = np.loadtxt('targets_and_commands/mulvec.txt')

        ginmult = comb[0]  # excitatory -> inhibitory proportionality (negative for push-pull)
        nmmult = comb[1]   # relative amount of neuromodulation (Ca channel)
        ginadd = comb[2]   # starting value of inhibition
        seed_value = comb[3]
        path_name = os.path.join(os.getcwd(), 'new_data', noise_name,
                                 f'{SMULTSTRT}{SMULTEND}{SMULTSKEW}_{seed_value}')

        excitatory_input = stim_target * FIRSTMULT  # initial guess
        leave_for_loop = False
        for try_id in range(N_TRY):
            MN = {
                'n_try': try_id,
                'nmn': NMN,
                'ginmult': ginmult,
                'nmmult': nmmult,
                'ginadd': ginadd,
                'random_seed': comb[3],
                'gain': GAIN,
                'firstmult': FIRSTMULT,
                'smultstrt': SMULTSTRT,
                'smultend': SMULTEND,
                'smultskew': SMULTSKEW,
                'try_final': {},  # replaced by try_id on the saved (final) try
            }
            logging.info(f"{comb_id} >> OUTER: - Try: {try_id}")
            inhibitory_input = excitatory_input * ginmult + ginadd
            inhibitory_input[inhibitory_input <= 0] = 1e-7  # keep inhibition positive
            MN['excitatory_input'] = excitatory_input
            MN['inhibitory_input'] = inhibitory_input

            logging.info(f"{comb_id} >> OUTER: - launching inner job!")
            for ii in range(NMN):
                pc_main.submit(inner_job,
                               copy.deepcopy(std_noise_value),
                               copy.deepcopy(excitatory_input),
                               copy.deepcopy(inhibitory_input),
                               'try_' + str(try_id),
                               'MN_' + str(ii),
                               copy.deepcopy(comb),
                               copy.deepcopy(nmmult),
                               copy.deepcopy(noise_name),
                               copy.deepcopy(stim_time))
            logging.info(f"{comb_id} >> OUTER: - inner jobs launched, waiting for return!")

            # Fresh accumulator for every try, so one try's average never
            # leaks into the next.
            average_rate = np.zeros(len(stim_time))
            while pc_main.working():
                try:
                    out_dict, spike_times, MN_str = pc_main.pyret()
                    ii = int(MN_str[3:])
                    logging.info(f"{comb_id} >> OUTER: - >>> MN_{ii} completed!")
                    mn_entry = copy.deepcopy(out_dict)
                    if spike_times is None:
                        # inner_job failed (it logged the traceback): count
                        # this MN as silent instead of aborting the comb.
                        logging.error(f"{comb_id} >> OUTER: - MN_{ii} returned no spike times; counting it as not recruited")
                        spike_times = np.empty(0)
                    mn_entry["spike_times"] = copy.deepcopy(spike_times)
                    mn_entry["recruited"] = 1 if spike_times.size > 0 else 0
                    smspikes = _smooth_spikes(spike_times, stim_time, conv_filt)
                    mn_entry['smspikes'] = smspikes
                    MN[f"MN_{ii}"] = mn_entry
                    average_rate = average_rate + smspikes
                    logging.info(f"{comb_id} >> OUTER: - >>> MN_{ii} completed computing!")
                    del out_dict, spike_times
                    gc.collect()
                except Exception as e:
                    logging.exception(e)

            # MNs whose result could not be processed count as not recruited.
            mn_recruited = float(sum(
                MN.get(f'MN_{n}', {}).get("recruited", 0) for n in range(NMN)))
            MN['mn_recruited'] = mn_recruited
            if mn_recruited > 0:  # avoid 0/0 -> NaN propagating into the command
                average_rate = average_rate / mn_recruited
            MN['average_rate'] = average_rate

            msq = np.mean((stim_target - average_rate) ** 2)  # mean squared error
            MN['msq'] = msq
            if msq < 0.5 and mn_recruited == NMN:
                leave_for_loop = True
                logging.info('OUTER: - Mean Squared Error: ' + str(msq))
                logging.info('OUTER: - n_try = ' + str(try_id) + ', Search stopped')

            # Recalibrate the excitatory command from the remaining error.
            error = (stim_target - average_rate) * _gain_for_msq(msq, GAIN)
            excitatory_input = excitatory_input + error
            # Kludge factor if fewer than NMN - 2 MNs are recruited.
            if (NMN - mn_recruited) > 2:
                excitatory_input = np.multiply(excitatory_input, mulvec)
            excitatory_input[excitatory_input <= 0] = 1e-7  # keep excitation positive
            MN['error'] = error
            gc.collect()

            if leave_for_loop or try_id == N_TRY - 1:
                MN['try_final'] = try_id
                sweep_name = f"{ginmult:.3f}_{nmmult:.3f}_{ginadd:.3f}_n_try_{try_id}_{comb[3]}"
                os.makedirs(path_name, exist_ok=True)
                filename = f"pickle_{sweep_name}_FINAL.pkl"
                with open(os.path.join(path_name, filename), "wb") as f:
                    pickle.dump(MN, f)
            del MN
            gc.collect()
            if leave_for_loop:
                break
        gc.collect()
        return (False, comb)
    except Exception as e:
        logging.exception(e)
        return (True, comb)
def generate_MN_skeleton(std_noise_value, h_object):
    """Build the bare five-section motoneuron and insert its mechanisms.

    Topology: one soma with four dendrites. ``d1`` and ``d2`` attach at
    ``soma(1)`` and ``d3`` and ``d4`` at ``soma(0)``, each by its 0 end.

    Mechanisms:
      * soma: na3rp (transient Na), naps (persistent Na), kdrRL (delayed
        rectifier K) and mAHP; kdrRL globals tmin=0.8, taumax=20, mVh=-21.
      * every section returned by ``h_object.allsec()`` (not only the five
        created here): pas and gh.
      * dendrites: L_Ca_inact and Gfluctdv, with g_e0 = g_i0 = 1e-5,
        tau_e = tau_i = 20 and std_e = std_i = ``std_noise_value``.

    ``e_pas`` is deliberately not set here; the original in-code note says
    it is set by the MN_types files.

    Returns:
        ``(soma, d1, d2, d3, d4, dend)``, where ``dend`` is a SectionList of
        the four dendrites.
    """
    soma = h_object.Section(name='soma')
    d1, d2, d3, d4 = [h_object.Section(name=f'd{k}') for k in range(1, 5)]

    dend = h_object.SectionList()
    for dendrite in (d1, d2, d3, d4):
        dend.append(sec=dendrite)

    # (child, soma end it hangs from); every child attaches by its 0 end.
    for dendrite, soma_end in ((d1, 1), (d2, 1), (d3, 0), (d4, 0)):
        dendrite.connect(soma(soma_end), 0)

    # Soma-only active channels, then the kdrRL global parameters.
    for mechanism in ('na3rp', 'naps', 'kdrRL', 'mAHP'):
        soma.insert(mechanism)
    h_object.tmin_kdrRL = 0.8
    h_object.taumax_kdrRL = 20
    h_object.mVh_kdrRL = -21

    # Leak and HCN (gh) in every section that currently exists.
    for section in h_object.allsec():
        section.insert('pas')
        section.insert('gh')

    # Gfluctdv settings: the mean conductances (S/cm^2) are scaled by the
    # played multipliers; tau_* are the noise filter time constants.
    gfluct_settings = (
        ('g_e0_Gfluctdv', 1e-5),
        ('g_i0_Gfluctdv', 1e-5),
        ('tau_e_Gfluctdv', 20),
        ('tau_i_Gfluctdv', 20),
        ('std_e_Gfluctdv', std_noise_value),
        ('std_i_Gfluctdv', std_noise_value),
    )
    for dendrite in dend:
        dendrite.insert('L_Ca_inact')
        dendrite.insert('Gfluctdv')
        for attr_name, attr_value in gfluct_settings:
            setattr(dendrite, attr_name, attr_value)
    return soma, d1, d2, d3, d4, dend
def main():
    """Submit the configured parameter combination and report how it ended.

    The combination ``[GINMULT, NMMULT, GINADD, RANDOM_SEED]`` is submitted
    as job ``"A"`` with noise label ``'seed_correlated'``. ``job`` returns
    ``(failed, comb)``, and the outcome is logged. The total wall-clock time
    is printed before the bulletin board is shut down with ``pc_main.done()``.
    """
    t0 = time.time()
    noise_name = 'seed_correlated'
    comb_id = "A"
    logging.info('Starting Jobs')
    comb = [GINMULT, NMMULT, GINADD, RANDOM_SEED]
    submitted = {pc_main.submit(job, comb, comb_id, noise_name)}
    logging.info(submitted)
    while pc_main.working():
        result = pc_main.pyret()
        logging.info('ret: {}'.format(len(result)))
        failed, returned_comb = result
        if failed is False:
            logging.info('MAIN_LAUNCHER: - SUCCESS! comb succeeded: ' + str(returned_comb))
        elif failed is True:
            logging.info('MAIN_LAUNCHER: - ERROR! comb failed: ' + str(returned_comb))
    elapsed = time.time() - t0
    print("Sweep Time (s): " + str(elapsed))
    pc_main.done()


if __name__ == '__main__':
    # NOTE: per NEURON's ParallelContext docs, runworker() returns only on the
    # master rank; workers serve submitted jobs until pc_main.done().
    pc_main.runworker()
    main()