import os
import sys
import json
from time import strftime, localtime
from time import time as TIME
import pickle
from itertools import chain
import argparse as arg
import logging

import numpy as np
from numpy.random import RandomState, SeedSequence, MT19937
import matplotlib.pyplot as plt

from neuron import h

from dlutils.cell import Cell, branch_order
from dlutils.synapse import AMPANMDAExp2Synapse
from dlutils.spine import Spine
from dlutils.utils import extract_mechanisms
from dlutils.morpho import Tree
from dlutils.graphics import plot_tree


prog_name = os.path.basename(sys.argv[0])


def make_gamma_spike_train(k, rate, tend=None, Nev=None, refractory_period=0, random_state=None):
    from scipy.stats import gamma
    if Nev is None and tend is not None:
        Nev = int(np.ceil(tend * rate))
    ISIs = []
    while len(ISIs) < Nev:
        ISI = gamma.rvs(k, loc=0, scale=1 / (k * rate), size=1, random_state=random_state)
        if ISI > refractory_period:
            ISIs.append(ISI)
    spks = np.cumsum(ISIs)
    if tend is not None:
        spks = spks[spks < tend]
    return spks


make_poisson_spike_train = lambda rate, tend=None, Nev=None, refractory_period=0, random_state=None: \
    make_gamma_spike_train(1, rate, tend, Nev, refractory_period, random_state)


if __name__ == '__main__':

    parser = arg.ArgumentParser(description='Simulate the activation of spines on the dendrite of a neuron model.')
    parser.add_argument('config_file', type=str, action='store', help='configuration file')
    parser.add_argument('--plot-morpho', action='store_true',
                        help='plot the morphology with the spines highlighted (default: no)')
    parser.add_argument('--save-traces', action='store_true',
                        help='save also voltage and current time series (default: no)')
    args = parser.parse_args(args=sys.argv[1:])

    config_file = args.config_file
    if not os.path.isfile(args.config_file):
        print(f'{prog_name}: {config_file}: no such file')
        sys.exit(1)
        
    ts = strftime('%Y%m%d-%H%M%S', localtime())
    config = json.load(open(config_file, 'r'))

    optimization_folder = config['optimization_folder']
    if not os.path.isdir(optimization_folder):
        print(f'{prog_name}: {optimization_folder}: no such directory')
        sys.exit(2)
    if optimization_folder[-1] != os.path.sep:
        optimization_folder += os.path.sep

    log_fmt = logging.Formatter('%(asctime)s  |  %(message)s', '%Y-%m-%d  %H:%M:%S')
    logger = logging.getLogger()
    file_hndl = logging.FileHandler(f'synaptic_activation_{ts}.log')
    file_hndl.setFormatter(log_fmt)
    logger.addHandler(file_hndl)
    console_hndl = logging.StreamHandler(sys.stdout)
    console_hndl.setFormatter(log_fmt)
    logger.addHandler(console_hndl)
    logger.setLevel(logging.INFO)
    
    try:
        seed = config['seed']
    except:
        with open('/dev/urandom', 'rb') as fid:
            seed = int.from_bytes(fid.read(4), 'little')
            config['seed'] = seed

    logger.info(f'Random number generator seed: {seed}')
    rs = RandomState(MT19937(SeedSequence(seed)))
    
    cell_type = config['cell_type']
    prefix = cell_type.capitalize()
    base_folder = optimization_folder + prefix + os.path.sep + config['cell_name'] + os.path.sep + config['optimization_run'] + '/'
    swc_file = config['swc_file']
    cell_name = config['cell_name']
    individual = config['individual']

    swc_file = base_folder + swc_file
    params_file = base_folder + f'individual_{individual}.json'
    config_file = base_folder + 'parameters.json'

    parameters = json.load(open(params_file, 'r'))
    try:
        mechanisms = extract_mechanisms(config_file, cell_name)
    except:
        # be a little flexible in the naming of the cells
        cell_name += '_'
        mechanisms = extract_mechanisms(config_file, cell_name)
    try:
        sim_pars = pickle.load(open(base_folder + 'simulation_parameters.pkl','rb'))
        replace_axon = sim_pars['replace_axon']
        add_axon_if_missing = not sim_pars['no_add_axon']
    except:
        replace_axon, add_axon_if_missing = True, True


    #############################
    ### simulation parameters ###
    #############################

    stim_dur = config['sim']['stim_dur']
    delay    = config['sim']['delay']
    after    = config['sim']['after']
    tstop    = delay + stim_dur + after


    ######################################
    ### build and instantiate the cell ###
    ######################################

    cell = Cell('cell_%d' % int(rs.uniform()*1e5), swc_file, parameters, mechanisms)
    cell.instantiate(replace_axon, add_axon_if_missing, force_passive=False, TTX=False)
    section_num = config['section_num']
    section = cell.morpho.apic[section_num]
    Ra = section.Ra * config['Ra_neck_coeff']
    logger.info(f'Branch order of section {section.name()}: {branch_order(section)}')


    ##############################
    ### instantiate the spines ###
    ##############################

    # in the Harnett paper, the head is spherical with a diameter of 0.5 um: a cylinder
    # with diameter and length equal to 0.5 has the same (outer) surface area as the sphere
    head_L = config['spine']['head_L']           # [um]
    head_diam = config['spine']['head_diam']     # [um]
    neck_L = config['spine']['neck_L']           # [um]
    neck_diam = config['spine']['neck_diam']     # [um]
    spine_distance = config['spine_distance']    # [um] distance between neighboring spines
    n_spines = config['n_spines']                # number of spines
    L = spine_distance * (n_spines - 1)
    norm_L = L / section.L

    spine_loc = config['spine_loc']
    start, stop = spine_loc + norm_L/2 * np.array([-1,1])
    if start < 0:
        start = 0
        stop = start + norm_L
    if stop > 1:
        stop = 1
        start = stop - norm_L
    spines = [Spine(section, x, head_L, head_diam, neck_L, neck_diam, Ra, i) \
              for i,x in enumerate(np.linspace(start, stop, n_spines))]

    for spine in spines:
        spine.instantiate()
    logger.info(f'Spines axial resistivity: {Ra:.1f} Ohm cm')
    
    if args.plot_morpho:
        ### Show where the spine is located on the dendritic tree
        tree = Tree(swc_file)
        height = 4
        width = height * tree.xy_ratio
        fig,ax = plt.subplots(1, 1, figsize=(width, height))
        plot_tree(tree, type_ids=(1,3,4), ax=ax, scalebar_length=100, bounds=[tree.bounds[0,:], tree.bounds[1,:]])
        # label all the sections
        for sec in chain(cell.morpho.apic, cell.morpho.dend):
            if sec in cell.morpho.apic:
                color = 'g'
            else:
                color = 'm'
            lbl = sec.name().split('.')[1].split('[')[1][:-1]
            n = sec.n3d()
            sec_coords = np.zeros((n,2))
            for i in range(n):
                sec_coords[i,:] = np.array([sec.x3d(i), sec.y3d(i)])
            middle = int(n / 2)
            plt.text(sec_coords[middle,0], sec_coords[middle,1], lbl, fontsize=3, color=color)
        ax.plot(spine._points[:,0], spine._points[:,1], 'ro', markerfacecolor='r', markersize=2)
        plt.axis('equal')
        plt.axis('off')
        fig.tight_layout(pad=-0.1)
        fig.savefig(f'morpho_with_spines_{ts}.pdf')
        sys.exit(0)

    # check the location of the spines in terms of distinct segments
    segments = [section(spines[0]._sec_x)]
    segments_idx = [[0]]
    for i,spine in enumerate(spines[1:]):
        if section(spine._sec_x) == segments[-1]:
            segments_idx[-1].append(i+1)
        else:
            segments.append(section(spine._sec_x))
            segments_idx.append([i+1])
    if len(segments_idx) == 1:
        logger.info('All spines are connected to the same segment')
    elif len(segments_idx) == n_spines:
        logger.info('Each spine is connected to a different segment on the dendritic branch')
    else:
        for group in segments_idx:
            if len(group) > 1:
                logger.info(f'Spines {np.array(group)+1} are connected to the same segment')
            else:
                logger.info(f'Spine {group[0]+1} is connected to a distinct segment')


    ########################################
    ### insert a synapse into each spine ###
    ########################################

    MG_MODELS = {'MDS': 1, 'HRN': 2, 'JS': 3}
    Mg_unblock_model = config['NMDA']['model']

    E = config['E_syn'] # [mV] reversal potential of the synapses

    AMPA_taus = config['AMPA']['time_constants']
    NMDA_taus = config['NMDA']['time_constants']
    weights = np.array([config['AMPA']['weight'], config['NMDA']['weight']])

    
    logger.info('AMPA:')
    logger.info('    tau_rise = {:.3f} ms'.format(AMPA_taus['tau1']))
    logger.info('   tau_decay = {:.3f} ms'.format(AMPA_taus['tau2']))
    logger.info('NMDA:')
    logger.info('    tau_rise = {:.3f} ms'.format(NMDA_taus['tau1']))
    logger.info('   tau_decay = {:.3f} ms'.format(NMDA_taus['tau2']))

    synapses = [AMPANMDAExp2Synapse(spine.head, 1, E, weights, AMPA = AMPA_taus, \
                                    NMDA = NMDA_taus) for spine in spines]

    for syn in synapses:
        syn.nmda_syn.mg_unblock_model = MG_MODELS[Mg_unblock_model]
        if Mg_unblock_model == 'MDS':
            syn.nmda_syn.alpha_vspom = config['NMDA']['alpha_vspom']
            syn.nmda_syn.v0_block = config['NMDA']['v0_block']
            syn.nmda_syn.eta = config['NMDA']['eta']
        elif Mg_unblock_model == 'JS':
            syn.nmda_syn.Kd = config['NMDA']['Kd']
            syn.nmda_syn.gamma = config['NMDA']['gamma']
            syn.nmda_syn.sh = config['NMDA']['sh']

    if Mg_unblock_model == 'MDS':
        logger.info('Using Maex & De Schutter Mg unblock model. Modified parameters:')
        logger.info('       alpha = {:.3f} 1/mV'.format(synapses[0].nmda_syn.alpha_vspom))
        logger.info('    v0_block = {:.3f} mV'.format(synapses[0].nmda_syn.v0_block))
        logger.info('         eta = {:.3f}'.format(synapses[0].nmda_syn.eta))
    elif Mg_unblock_model == 'JS':
        logger.info('Using Jahr & Stevens Mg unblock model. Modified parameters:')
        logger.info('          Kd = {:.3f} 1/mV'.format(synapses[0].nmda_syn.Kd))
        logger.info('       gamma = {:.3f} 1/mV'.format(synapses[0].nmda_syn.gamma))
        logger.info('          sh = {:.3f} mV'.format(synapses[0].nmda_syn.sh))
    elif Mg_unblock_model == 'HRN':
        logger.info('Using Harnett Mg unblock model with default parameters')


    ###########################################
    ### compute the presynaptic spike times ###
    ###########################################

    # spines will be activated in a Poisson fashion with this average interval between activations
    F_burst = config['synaptic_activation_frequency']
    if F_burst > 0:
        n_bursts = stim_dur * F_burst
        if n_bursts > 0:
            presyn_burst_times = 2 * delay + make_poisson_spike_train(F_burst, Nev=n_bursts,
                                                                      refractory_period=0.1 / F_burst,
                                                                      random_state=rs) * 1e3
            presyn_burst_times = presyn_burst_times[presyn_burst_times < delay + stim_dur]
    
            try:
                F = config['poisson_frequency']
                if F <= 0:
                    raise Exception('poisson_frequency must be > 0')
                poisson = True
            except:
                poisson = False
                spike_dt = config['spike_dt']
        
            presyn_spike_times = [np.array([]) for _ in range(n_spines)]

            for t0 in presyn_burst_times:
                if poisson:
                    spks = make_poisson_spike_train(F, Nev=n_spines,
                                                    refractory_period=config['sim']['dt'] * 5 * 1e-3,
                                                    random_state=rs) * 1e3
                    for i,j in enumerate(rs.permutation(n_spines)):
                        presyn_spike_times[j] = np.append(presyn_spike_times[j], t0 + spks[i])
                else:
                    for i in range(n_spines):
                        presyn_spike_times[i] = np.append(presyn_spike_times[i], t0 + i * spike_dt)

        else:
            logger.info('No presynaptic stimulation')
            presyn_burst_times = np.array([])
            presyn_spike_times = np.array([])

    else:
        F_burst *= -1
        T = 1 / F_burst * 1e3
        presyn_spike_times = [np.sort(config['sim']['delay'] + (n_spines - 1) * T - np.arange(i) * T)
                              for i in range(n_spines, 0, -1)]
        for i in range(n_spines):
            for j in range(len(presyn_spike_times[i])):
                presyn_spike_times[i][j] += i * config['spike_dt']
        presyn_burst_times = presyn_spike_times[0]
        stim_dur = n_spines * T
        tstop = delay + stim_dur + after

    if len(presyn_spike_times) > 0:
        logger.info('Presynaptic spike times:')
        for i in range(n_spines):
            # sort the presynaptic spike times so that we never run in the situation that
            # the i-th spike should have arrived before the (i-1)-th, which messes up NEURON
            presyn_spike_times[i] = np.sort(presyn_spike_times[i])
            logger.info(f'Spine {i+1}: t = {presyn_spike_times[i]}')
        for syn, spks in zip(synapses, presyn_spike_times):
            syn.set_presynaptic_spike_times(spks)

    ############################
    ### make the OU stimulus ###
    ############################

    try:
        dt = config['sim']['dt']        # [ms]
    except:
        dt = 0.025 # Neuron default time step
    OU = {}
    OU['t'] = np.arange(0, tstop, dt)
    OU['x'] = np.zeros(OU['t'].size)
    OU['rnd'] = rs.normal(size=OU['t'].size)
    for par in 'mean', 'stddev', 'tau':
        OU[par] = config['OU'][par]
    OU['const'] = 2 * OU['stddev']**2 / OU['tau']
    OU['mu'] = np.exp(-dt / OU['tau'])
    OU['coeff'] = np.sqrt(OU['const'] * OU['tau'] / 2 * (1 - OU['mu'] ** 2))
    idx, = np.where((OU['t'] >= delay)  & (OU['t'] <= delay + stim_dur))
    OU['x'][idx[0]] = OU['mean']
    for i in idx[1:]:
        OU['x'][i] = OU['mean'] + OU['mu'] * (OU['x'][i-1] - OU['mean']) + OU['coeff'] * OU['rnd'][i]

    vec = {key: h.Vector(OU[key]) for key in ('t','x')}

    stim = h.IClamp(cell.morpho.soma[0](0.5))
    if OU['stddev'] != 0:
        stim.dur = 10 * tstop
        vec['x'].play(stim._ref_amp, vec['t'], 1)
    else:
        stim.dur = stim_dur
        stim.delay = delay
        stim.amp = OU['mean']
        logger.info('The standard deviation of the OU process is zero: using conventional current clamp stimulus')

    ##########################
    ### make the recorders ###
    ##########################

    recorders = {}
    for lbl in 'time', 'Vsoma', 'Vdend', 'spike_times':
        recorders[lbl] = h.Vector()
    recorders['time'].record(h._ref_t)
    recorders['Vsoma'].record(cell.morpho.soma[0](0.5)._ref_v)
    recorders['Vdend'].record(section(spines[0]._sec_x)._ref_v)
    apc = h.APCount(cell.morpho.soma[0](0.5))
    apc.thresh = -20
    apc.record(recorders['spike_times'])


    ##########################
    ### run the simulation ###
    ##########################

    if OU['stddev'] != 0:
        h.cvode_active(0)
        h.dt = dt
        logger.info(f'Not using CVode: dt set to {dt:.3f} ms')
    else:
        h.cvode_active(1)
        logger.info('Using CVode')

    h.tstop = tstop
    logger.info('Running simulation')
    start = TIME()
    h.run()
    end = TIME()
    dur = int(end - start)
    hours = dur // 3600
    minutes = (dur % 3600) // 60
    secs = (dur % 60) % 60
    logger.info(f'Elapsed time: {hours:02d}:{minutes:02d}:{secs:02d}')
    

    #####################
    ### save the data ###
    #####################

    spike_times = np.array(recorders['spike_times'])
    data = {
        'config': config,
        'Ra': Ra,
        'presyn_burst_times': presyn_burst_times,
        'presyn_spike_times': presyn_spike_times,
        'spike_times': spike_times
    }
    if args.save_traces:
        data['OU_t'] = OU['t']
        data['OU_x'] = OU['x']
        for key in recorders:
            if key != 'spike_times':
                data[key] = np.array(recorders[key])
    logger.info(f'Saving data to synaptic_activation_{ts}.npz')
    np.savez_compressed(f'synaptic_activation_{ts}.npz', **data)
    ISI = np.diff(spike_times) * 1e-3
    firing_rate = len(spike_times) / stim_dur * 1e3
    CV = ISI.std() / ISI.mean()
    logger.info(f'Firing rate = {firing_rate:.2f} spike/s')
    logger.info(f'CV = {CV:.4f}')

    #####################
    ### plot a figure ### 
    #####################

    logger.info(f'Plotting simulation results to synaptic_activation_{ts}.pdf')
    try:
        time = data['time']
        Vsoma = data['Vsoma']
    except:
        time = np.array(recorders['time'])
        Vsoma = np.array(recorders['Vsoma'])
    fig,ax = plt.subplots(1, 1, figsize=(6,4))
    ax.plot(time, Vsoma, 'k', lw=1)
    ax.set_xlabel('Time (ms)')
    ax.set_ylabel('Vm (mV)')
    for side in 'top', 'right':
        ax.spines[side].set_visible(False)
    plt.savefig(f'synaptic_activation_{ts}.pdf')
    fig.tight_layout()