"""
Motoneuron Pool Optimization source code for
"Supercomputer framework for reverse engineering firing patterns of neuron populations to identify their synaptic inputs"
by Matthieu K. Chardon, Y. Curtis Wang, Marta Garcia, Emre Besler, J. Andrew Beauchamp,
Michael D’Mello, Randall K. Powers, and Charles J. Heckman.

This script runs a single parameter combination (excitatory, inhibitory, and neuromodulatory drive) for a pool of 20 motoneurons, using MPI on one machine.

DOI: https://doi.org/10.7554/eLife.90624.2

Must be run with MPI, for instance:
mpiexec -n 20 python -u MN_pool_public.py

"""


from mpi4py import MPI  # ensures MPI is initialized before NEURON uses it
from neuron import h
import MN_types

import gc
import logging
import time
import numpy as np
import copy
import pickle
import os


logging.basicConfig(format='%(levelname)s:%(message)s', level=logging.DEBUG)
h.nrnmpi_init()  # hook NEURON into the MPI environment
pc_main = h.ParallelContext()  # bulletin-board context shared by master and workers


NUM_JOBS = 16  # number of parallel jobs (not used by this single-combination script)
N_TRY = 20  # maximum number of optimization iterations
NMN = 20  # number of motoneurons, i.e. number of subjobs per iteration
GAIN = 0.2  # error gain for the excitatory-command update
FIRSTMULT = 0.6  # scales the first guess to help the optimization
# SMULTSTRT/SMULTEND/SMULTSKEW allow an unequal distribution of synaptic weights:
SMULTSTRT = 1  # relative weight for the lowest-threshold motor unit
SMULTEND = 1  # relative weight for the highest-threshold motor unit
SMULTSKEW = 1  # skew of the weight spacing across the pool

GINMULT = -0.7  # proportionality constant converting excitatory to inhibitory drive
NMMULT = 1.2  # relative amount of neuromodulation (scales dendritic L-type Ca conductance)
GINADD = 6.5  # baseline (y-intercept) of the inhibitory drive
RANDOM_SEED = 1  # seed for the correlated Gfluctdv noise


def inner_job(std_noise_value, excitatory_input, inhibitory_input,
              try_str, MN_str, comb, nmmult, noise_name, stim_time):
    try:
        logging.info('INNER: - inner job launched!')
        logging.info('INNER: - running try ' + try_str + ', MN file ' + MN_str)
        logging.debug('INNER: - generating MN skeleton...')
        soma, d1, d2, d3, d4, dend = generate_MN_skeleton(std_noise_value, h)
        logging.debug('INNER: - done generating MN skeleton!')

        target_time = h.Vector(stim_time)

        # Record soma spike times
        nc_soma = h.NetCon(soma(0.5)._ref_v, None, sec=soma)
        nc_soma.threshold = -20  # Threshold voltage (mV) for spike detection
        spike_times_soma = h.Vector()
        nc_soma.record(spike_times_soma)

        # Record Gfluctdv conductances (kept for optional debugging;
        # cleared again before this function returns)
        g_e = h.Vector()
        g_e.record(d1(0.5)._ref_g_e_Gfluctdv)
        g_i = h.Vector()
        g_i.record(d1(0.5)._ref_g_i_Gfluctdv)

        multex_Gfluctdv_vec = h.Vector()
        multex_Gfluctdv_vec.record(h._ref_multex_Gfluctdv)

        # Record membrane potential
        v_vec = h.Vector()  # Membrane potential vector
        t_vec = h.Vector()  # Time stamp vector
        v_vec.record(soma(0.5)._ref_v)
        t_vec.record(h._ref_t)

        # Per-motoneuron results dictionary; spike_times, smspikes and
        # recruited are filled in by the outer job once this simulation returns
        MN = {}
        MN['spike_times'] = {}
        MN['smspikes'] = {}
        MN['recruited'] = {}
        MN['weight'] = {}

        # Apply the type-specific parameters for this motoneuron (load_MN_0 ... load_MN_19)
        my_neuron_transform = getattr(MN_types, 'load_' + str(MN_str))
        my_neuron_transform(soma, dend, d1, d2, d3, d4, h)

        # Scale the dendritic L-type Ca conductance by the neuromodulation multiplier
        for sec in dend:
            if h.ismembrane('L_Ca_inact', sec=sec):
                sec.gcabar_L_Ca_inact = nmmult * sec.gcabar_L_Ca_inact

        # Distribute excitatory input across the pool: weight rises from
        # SMULTSTRT (MN_0) to SMULTEND (MN_{NMN-1}) with spacing shaped by SMULTSKEW
        mn_index = int(MN_str[3:])
        weight = SMULTSTRT + ((mn_index / (NMN - 1)) ** (1 + SMULTSKEW)) * (SMULTEND - SMULTSTRT)
        MN['weight'] = weight
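        # Worked example (hypothetical values, not the defaults above): with
        # SMULTSTRT = 0.5, SMULTEND = 1.5, SMULTSKEW = 1 and NMN = 20,
        # weight = 0.5 + (ii/19)**2, so MN_0 gets 0.5, MN_9 about 0.72 and
        # MN_19 gets 1.5. With the defaults SMULTSTRT = SMULTEND = 1, every
        # unit is weighted 1.0.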

        seed_correlated = comb[3]
        h.new_seed_Gfluctdv(seed_correlated)  # same noise seed across the pool ("correlated" noise)
        MN['seed_correlated'] = seed_correlated

        excitatory_input_w = excitatory_input * weight
        stimvalex = h.Vector(excitatory_input_w)
        stimvalex.play(h._ref_multex_Gfluctdv, target_time)  # play excitation into multex_Gfluctdv inside Gfluctdv.mod
        stimvalin = h.Vector(inhibitory_input)
        stimvalin.play(h._ref_multin_Gfluctdv, target_time)  # play inhibition into multin_Gfluctdv inside Gfluctdv.mod
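        # Note: Vector.play() with a time vector and no continuity flag applies
        # each sample as a step change at its time in target_time, holding the
        # value constant until the next sample.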

        # Run simulation
        gc.collect()
        h.dt = 0.025  # integration time step (ms)
        h.finitialize()
        h.fcurrent()
        h.frecord_init()
        logging.info('INNER: - computing: ' + try_str + ', ' + MN_str + ' for ' + str(comb))
        while int(h.t) < 22000:  # simulate 22,000 ms of model time
            h.fadvance()
        h.stoprun = 1
        out_times = np.copy(spike_times_soma.as_numpy())

        # Clear recorded vectors and drop references before returning
        g_e.resize(0)
        g_i.resize(0)
        v_vec.resize(0)
        t_vec.resize(0)
        spike_times_soma.resize(0)

        del stimvalex, stimvalin, g_e, g_i, multex_Gfluctdv_vec, v_vec, t_vec, soma, dend, d1, d2, d3, d4
        del my_neuron_transform, excitatory_input_w, spike_times_soma
        gc.collect()
        return MN, out_times, MN_str
    except Exception as e:
        logging.exception(e)
        gc.collect()
        return {}, None, MN_str


def job(comb, comb_id, noise_name):
    try:
        logging.info(f"{comb_id} >> OUTER: - job started! comb is: {comb}")
        std_noise_value = 1.2e-5

        ###################################
        # Load Target Data
        ###################################
        target = np.loadtxt('targets_and_commands/glin16.txt')
        stim_time = target[:, 0]
        stim_target = target[:, 1]

        ###################################
        # Load Convolution Filter
        ###################################
        conv_filt = np.loadtxt('hwave4w.txt')
        conv_filt = -(conv_filt - 1)  # i.e. 1 - conv_filt: flip the stored filter
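        # The flip recovers the kernel from the hwave4w.txt encoding:
        # -(x - 1) = 1 - x, so a stored 1.0 becomes 0.0 and 0.25 becomes 0.75.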

        # comb unpacks as [ginmult, nmmult, ginadd, seed]
        ginmult = comb[0]  # proportionality constant converting excitation to inhibition;
                           # negative for push-pull (non-positive values are clamped below)
        ginadd = comb[2]  # baseline of the inhibition (y-intercept)
        nmmult = comb[1]  # relative amount of neuromodulation (L-type Ca channel)

        # Initial excitatory input: a scaled copy of the target profile
        excitatory_input = stim_target * FIRSTMULT
        average_rate = np.zeros(len(stim_time))
        leave_for_loop = False

        for try_id in range(N_TRY):
            # Reset the pooled-rate accumulator so this try's average_rate
            # reflects only this try's spikes
            average_rate = np.zeros(len(stim_time))

            # Record the sweep parameters alongside the results
            MN = {}
            MN['n_try'] = try_id
            MN['nmn'] = NMN
            MN['ginmult'] = ginmult
            MN['nmmult'] = nmmult
            MN['ginadd'] = ginadd
            MN['random_seed'] = comb[3]
            MN['gain'] = GAIN
            MN['firstmult'] = FIRSTMULT
            MN['smultstrt'] = SMULTSTRT
            MN['smultend'] = SMULTEND
            MN['smultskew'] = SMULTSKEW
            MN['try_final'] = {}  # set to the final try index when results are saved
            logging.info(f"{comb_id} >> OUTER: - Try: {try_id}")
            inhibitory_input = excitatory_input * ginmult + ginadd
            inhibitory_input[inhibitory_input <= 0] = 1e-7  # clamp non-positive inhibitory values to a small positive floor

            MN['excitatory_input'] = excitatory_input
            MN['inhibitory_input'] = inhibitory_input
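            # Example: with ginmult = -0.7 and ginadd = 6.5, an excitatory
            # command of 5 maps to inhibition 5 * (-0.7) + 6.5 = 3.0, while a
            # command of 10 maps to -0.5, which is clamped to 1e-7.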

            logging.info(f"{comb_id} >> OUTER: - launching inner job!")
            c = comb_id
            
            job_dict = [pc_main.submit(inner_job,
                                            copy.deepcopy(std_noise_value),
                                            copy.deepcopy(excitatory_input),
                                            copy.deepcopy(inhibitory_input),
                                            'try_'+str(try_id),
                                            'MN_'+str(ii),
                                            copy.deepcopy(comb),
                                            copy.deepcopy(nmmult),
                                            copy.deepcopy(noise_name),
                                            copy.deepcopy(stim_time)) for ii in range(NMN)]
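            # pc.submit() posts each inner_job to the ParallelContext bulletin
            # board; pc.pyret() below returns results in completion order, not
            # submission order, which is why MN_str identifies the motoneuron.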

            logging.info(f"{comb_id} >> OUTER: - inner jobs launched, waiting for return!")

            while pc_main.working(): # when running parallel
                try:
                    out_dict, spike_times, MN_str = pc_main.pyret() # when running parallel
                    ii = int(MN_str[3:])
                    logging.info(f"{comb_id} >> OUTER: - >>> MN_{ii} completed!")
                    MN[f"MN_{ii}"] = copy.deepcopy(out_dict)
                    del out_dict
                    gc.collect()

                    MN[f"MN_{ii}"]["spike_times"] = copy.deepcopy(spike_times)

                    if MN[f"MN_{ii}"]["spike_times"].size > 0:
                        MN[f"MN_{ii}"]["recruited"] = 1
                    else:
                        MN[f"MN_{ii}"]["recruited"] = 0

                    # Mark the stimulus-time bin nearest each spike, then
                    # convolve with the filter for a smoothed rate estimate
                    conv_vec = np.zeros(len(stim_time))
                    for spike_t in spike_times:
                        idx = (np.abs(stim_time - spike_t)).argmin()
                        conv_vec[idx] = 1

                    smspikes = np.convolve(conv_vec, conv_filt, 'same')
                    smspikes = smspikes[0:len(stim_time)]
                    MN[f"MN_{ii}"]['smspikes'] = smspikes
                    average_rate = average_rate + smspikes
                    logging.info(f"{comb_id} >> OUTER: - >>> MN_{ii} completed computing!")
                    del spike_times
                    gc.collect()
                except Exception as e:
                    logging.exception(e)

            mn_recruited = 0
            for n in range(NMN):
                mn_recruited = mn_recruited + float(MN['MN_' + str(n)]["recruited"])
            MN['mn_recruited'] = mn_recruited

            average_rate = average_rate / mn_recruited  # pooled average rate (assumes at least one unit recruited)
            MN['average_rate'] = average_rate

            # Convergence test: stop once msq < 0.5 with the whole pool recruited
            msq = np.mean((stim_target - average_rate) ** 2)  # mean squared error
            MN['msq'] = msq

            if msq < 0.5 and mn_recruited == NMN:
                leave_for_loop = True
                logging.info('OUTER: - Mean Squared Error: ' + str(msq))
                logging.info('OUTER: - n_try = ' + str(try_id) + ', Search stopped')

            # Calculate the error and recalibrate the excitatory input, using a
            # smaller gain as the fit improves (the msq < 1.5 test must come
            # first; after msq < 3 it would be unreachable)
            GAIN_NEW = GAIN
            if msq < 1.5:
                GAIN_NEW = 0.25 * GAIN
            elif msq < 3:
                GAIN_NEW = 0.5 * GAIN
            error = (stim_target - average_rate) * GAIN_NEW
            excitatory_input = excitatory_input + error
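            # Example: with GAIN = 0.2 and msq = 2.0, GAIN_NEW = 0.1, so each
            # point of the command moves 10% of its remaining rate error
            # toward the target on this iteration.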

            # Kludge: if more than 2 motoneurons failed to recruit, boost the
            # command with the mulvec profile
            mulvec = np.loadtxt('targets_and_commands/mulvec.txt')
            if (NMN - mn_recruited) > 2:
                excitatory_input = np.multiply(excitatory_input, mulvec)

            excitatory_input[excitatory_input <= 0] = 1e-7  # clamp non-positive excitatory values to a small positive floor

            MN['error'] = error
            gc.collect()

            ###################################
            # Save Data
            ###################################
            seed_value = comb[3]
            dir_name = f'{SMULTSTRT}{SMULTEND}{SMULTSKEW}_{seed_value}'
            path_name = os.path.join(os.getcwd(), 'new_data', noise_name, dir_name)
            os.makedirs(path_name, exist_ok=True)
            sweep_name = f"{ginmult:.3f}_{nmmult:.3f}_{ginadd:.3f}_n_try_{try_id}_{comb[3]}"

            if leave_for_loop or try_id == (N_TRY - 1):
                MN['try_final'] = try_id
                filename = f"pickle_{sweep_name}_FINAL.pkl"
                with open(os.path.join(path_name, filename), "wb") as f:
                    pickle.dump(MN, f)

            del MN
            gc.collect()

            if leave_for_loop:
                break

        del target, stim_time, stim_target, conv_filt, excitatory_input, average_rate
        gc.collect()
        return (False, comb)  # False: job completed without error
    except Exception as e:
        logging.exception(e)
        return (True, comb)  # True: job failed


def generate_MN_skeleton(std_noise_value, h_object):
    # Define sections
    soma = h_object.Section(name='soma')  # Define soma section
    d1 = h_object.Section(name='d1')  # Define dend section
    d2 = h_object.Section(name='d2')  # Define dend section
    d3 = h_object.Section(name='d3')  # Define dend section
    d4 = h_object.Section(name='d4')  # Define dend section

    dend = h_object.SectionList()
    dend.append(sec=d1)
    dend.append(sec=d2)
    dend.append(sec=d3)
    dend.append(sec=d4)

    # Connect dendrites to soma
    d1.connect(soma(1), 0)
    d2.connect(soma(1), 0)
    d3.connect(soma(0), 0)
    d4.connect(soma(0), 0)
    # h_object.topology()
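    # Resulting topology: d1 and d2 hang off the 1-end of the soma, d3 and d4
    # off the 0-end:
    #     d1, d2 -- soma(1) ...... soma(0) -- d3, d4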

    # Soma Channel Definition
    soma.insert('na3rp')  # Insert na3rp channel: Transient Na Channel
    soma.insert('naps')  # Insert naps channel: Persistent Na Channel
    soma.insert('kdrRL')  # Insert kdrRL channel: Delayed Rectifier K Channel
    soma.insert('mAHP')  # Insert AHP mechanism
    # kdrRL gating parameters (globals of the mechanism)
    h_object.tmin_kdrRL = 0.8
    h_object.taumax_kdrRL = 20
    h_object.mVh_kdrRL = -21

    # All-section channel definition
    for sec in h_object.allsec():
        sec.insert('pas')  # Insert passive (leak) mechanism
        sec.insert('gh')  # Insert gh mechanism: HCN channel
        # sec.e_pas = -70  # Leak reversal potential (mV); set in the MN_types files

    # Dendrite channel definition
    for sec in dend:
        sec.insert('L_Ca_inact')  # Insert L_Ca_inact channel: L-type Ca with inactivation
        sec.insert('Gfluctdv')  # Insert Gfluctdv mechanism: fluctuating synaptic conductances
        sec.g_e0_Gfluctdv = 1e-5  # average excitatory conductance (S/cm^2); multiplied by multex for the actual conductance
        sec.g_i0_Gfluctdv = 1e-5  # average inhibitory conductance (S/cm^2)
        sec.tau_e_Gfluctdv = 20  # time constant of the filtered noise; see description in Gfluctdv.mod
        sec.tau_i_Gfluctdv = 20
        sec.std_e_Gfluctdv = std_noise_value  # standard deviation of the noise (1.2e-5 by default)
        sec.std_i_Gfluctdv = std_noise_value
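    # Gfluctdv appears to be a variant of the Destexhe-style point-conductance
    # ("Gfluct") model: each conductance fluctuates around its mean (g_e0
    # scaled by multex, g_i0 by multin) with standard deviation std_e/std_i
    # and correlation time tau_e/tau_i; multex/multin are the command signals
    # driven by Vector.play() in inner_job().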
    return soma, d1, d2, d3, d4, dend


def main():
    start = time.time()
    noise_name = 'seed_correlated'
    comb_id = "A"

    logging.info('Starting Jobs')
    submitted = {pc_main.submit(job, [GINMULT, NMMULT, GINADD, RANDOM_SEED], comb_id, noise_name)}
    logging.info(submitted)

    while pc_main.working():
        failed, comb = pc_main.pyret()  # job() returns (error_flag, comb)
        if failed:
            logging.info('MAIN_LAUNCHER: - ERROR! comb failed: ' + str(comb))
        else:
            logging.info('MAIN_LAUNCHER: - SUCCESS! comb succeeded: ' + str(comb))

    stop = time.time()
    print("Sweep Time (s): " + str(stop - start))
    pc_main.done()
    return


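# Bulletin-board execution model: every MPI rank runs this file. On the
# master (rank 0) pc_main.runworker() returns immediately and main() executes;
# on worker ranks runworker() blocks, executing submitted jobs until
# pc_main.done() in main() releases them.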
if __name__ == '__main__':
    pc_main.runworker()
    main()