__author__ = 'milsteina'
from specify_cells import *
from plot_results import *
import random

# Input files used to build the CA1_Pyr cell at the bottom of this script: a morphology (.swc) file and,
# presumably, a mechanism-specification file (confirm against CA1_Pyr).
morph_filename = "EB2-late-bifurcation.swc"
mech_filename = "043016 Type A - km2_NMDA_KIN5_Pr"


def plot_rinp_figure(rec_filename, svg_title=None):
    """
    Expects an output file generated by parallel_rinp.
    File contains voltage recordings from dendritic compartments probed with hyperpolarizing current injections to
    measure steady-state r_inp. Plots r_inp vs distance to soma, with all dendritic sec_types superimposed.
    Basal distances are plotted as negative values, so basal compartments appear to the left of x = 0.
    The stimulus amplitude, delay and duration are read from the first trial and applied to every trial.
    :param rec_filename: str; name of the .hdf5 file in data_dir, without the extension
    :param svg_title: str; if not None, the font size is temporarily set to 20 and the figure is saved to data_dir
        as svg_title + ' - Rinp gradient.svg'
    """
    sec_types = ['basal', 'trunk', 'apical', 'tuft']
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    distances = {sec_type: [] for sec_type in sec_types}
    r_inp = {sec_type: [] for sec_type in sec_types}
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        # next(iter(...)) / .values() work under both Python 2 and 3, unlike .itervalues().next()
        stim = next(iter(f.values()))['stim']['0']
        amp = stim.attrs['amp']
        start = stim.attrs['delay']
        stop = start + stim.attrs['dur']
        for trial in f.values():
            rec = trial['rec']['0']
            sec_type = rec.attrs['type']
            if sec_type not in sec_types:
                continue
            distance = rec.attrs['soma_distance']
            if sec_type == 'basal':
                distance *= -1
            distances[sec_type].append(distance)
            _rest, _peak, this_steady = get_Rinp(trial['time'][:], rec[:], start, stop, amp)
            r_inp[sec_type].append(this_steady)
    all_values = [val for sec_type in sec_types for val in r_inp[sec_type]]
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    # restore the caller's font size even if plotting or saving raises
    try:
        fig, axes = plt.subplots(1)
        for color, sec_type in zip(colors, sec_types):
            axes.scatter(distances[sec_type], r_inp[sec_type], label=sec_type, color=color)
        axes.set_xlabel('Distance to soma (um)')
        axes.set_title('Input resistance gradient', fontsize=mpl.rcParams['font.size'])
        axes.set_ylabel('Input resistance (MOhm)')
        axes.set_xlim(-200., 525.)
        axes.set_xticks([-150., 0., 150., 300., 450.])
        if all_values:
            # pad the y range by 10% of the data span
            margin = 0.1 * (max(all_values) - min(all_values))
            axes.set_ylim(min(all_values) - margin, max(all_values) + margin)
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(5.27, 4.37)
            fig.savefig(data_dir+svg_title+' - Rinp gradient.svg', format='svg', transparent=True)
        plt.show()
        plt.close(fig)
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


def get_spike_shape(vm, equilibrate=250., th_dvdt=10., dt=0.01):
    """
    Measure the threshold voltage and peak voltage of the first spike in a uniformly sampled voltage trace.
    The first equilibrate + 0.4 time units of the trace are discarded. The threshold sample is the first sample where
    dV/dt exceeds th_dvdt, moved 1.6 time units earlier; if dV/dt never exceeds th_dvdt, the first sample above -30.
    moved 2.0 time units earlier is used instead. The threshold index is clipped at 0 so it never wraps around to
    the end of the trace. The peak is the maximum of the 5.0 time units starting at the threshold sample.
    Times are in the same units as dt (presumably ms, with vm in mV - confirm).
    :param vm: array; voltage trace sampled every dt
    :param equilibrate: float; initial duration of the trace to ignore
    :param th_dvdt: float; dV/dt value that marks spike onset
    :param dt: float; sampling interval of vm
    :return: tuple of float: (v_peak, th_v)
    :raises IndexError: if dV/dt never exceeds th_dvdt and no sample exceeds -30.
    """
    vm = vm[int((equilibrate+0.4)/dt):]
    # spacing must be a scalar: numpy interprets a 1-element list as a coordinate array and raises ValueError
    dvdt = np.gradient(vm, dt)
    crossings = np.where(dvdt > th_dvdt)[0]
    # test emptiness by size, not by value: .any() would ignore a crossing that occurs only at index 0
    if crossings.size:
        th_x = crossings[0] - int(1.6/dt)
    else:
        th_x = np.where(vm > -30.)[0][0] - int(2./dt)
    # a spike close to the start would otherwise give a negative index that wraps to the end of the trace
    th_x = max(th_x, 0)
    th_v = vm[th_x]
    v_peak = np.max(vm[th_x:th_x+int(5./dt)])
    return v_peak, th_v


def plot_bAP_attenuation_figure(rec_filename, svg_title=None, dt=0.01):
    """
    Expects an output file generated by record_bAP_attenuation.
    File contains voltage recordings from dendritic compartments probed with a somatic current injection to
    measure spike attenuation. Plots spike height vs distance to soma, with all dendritic sec_types superimposed.
    Only the first trial in the file is used. Recording '0' of that trial is used as the somatic reference. Each
    dendritic spike height is computed as the peak within 10 time units after the somatic threshold crossing minus
    the mean voltage from 0.2 to 0.1 time units before it. It is then normalized to the somatic spike height
    (v_peak - v_th from get_spike_shape). Basal distances are plotted as negative values.
    :param rec_filename: str; name of the .hdf5 file in data_dir, without the extension
    :param svg_title: str; if not None, the font size is temporarily set to 20 and the figure is saved to data_dir
        as svg_title + ' - bAP attenuation.svg'
    :param dt: float; interval at which all recordings are resampled; also passed to get_spike_shape
    """
    sec_types = ['basal', 'trunk', 'apical', 'tuft']
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    distances = {sec_type: [] for sec_type in sec_types}
    spike_height = {sec_type: [] for sec_type in sec_types}
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        trial = next(iter(f.values()))
        equilibrate = trial['stim']['0'].attrs['delay']
        duration = trial['stim']['1'].attrs['dur']
        t = np.arange(0., duration, dt)
        tvec = trial['time'][:]
        soma_vm = np.interp(t, tvec, trial['rec']['0'][:])
        # forward dt: get_spike_shape converts times to indices with its own dt, whose default is 0.01
        v_peak, v_th = get_spike_shape(soma_vm, equilibrate, dt=dt)
        soma_height = v_peak - v_th
        # first sample after equilibrate at which the somatic voltage reaches threshold
        eq_index = int(equilibrate/dt)
        x_th = np.where(soma_vm[eq_index:] >= v_th)[0][0] + eq_index
        pre_start, pre_stop = x_th-int(0.2/dt), x_th-int(0.1/dt)
        peak_stop = x_th+int(10./dt)
        for rec in trial['rec'].values():
            sec_type = rec.attrs['type']
            # 'soma' is not in sec_types, so the somatic recording is skipped here too
            if sec_type not in sec_types:
                continue
            distance = rec.attrs['soma_distance']
            if sec_type == 'basal':
                distance *= -1
            distances[sec_type].append(distance)
            local_vm = np.interp(t, tvec, rec[:])
            local_peak = np.max(local_vm[x_th:peak_stop])
            local_pre = np.mean(local_vm[pre_start:pre_stop])
            spike_height[sec_type].append((local_peak - local_pre) / soma_height)
    all_values = [val for sec_type in sec_types for val in spike_height[sec_type]]
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    # restore the caller's font size even if plotting or saving raises
    try:
        fig, axes = plt.subplots(1)
        for color, sec_type in zip(colors, sec_types):
            axes.scatter(distances[sec_type], spike_height[sec_type], label=sec_type, color=color)
        axes.set_xlabel('Distance to soma (um)')
        axes.set_title('bAP attenuation gradient', fontsize=mpl.rcParams['font.size'])
        axes.set_ylabel('Normalized spike amplitude')
        axes.set_xlim(-200., 525.)
        axes.set_xticks([-150., 0., 150., 300., 450.])
        if all_values:
            # pad the y range by 10% of the data span
            margin = 0.1 * (max(all_values) - min(all_values))
            axes.set_ylim(min(all_values) - margin, max(all_values) + margin)
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(5.27, 4.37)
            fig.savefig(data_dir+svg_title+' - bAP attenuation.svg', format='svg', transparent=True)
        plt.show()
        plt.close(fig)
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


def plot_EPSP_amplitude_figure(rec_filename, rec_loc=None, title=None, svg_title=None, dt=0.01):
    """
    Plots EPSP amplitude vs distance to soma, with all dendritic sec_types superimposed.
    Expects an .hdf5 file in which every trial has 'equilibrate', 'duration' and 'input_loc' attributes, a 'time'
    vector, and a 'rec' group of voltage recordings, each carrying a 'description' attribute.
    The recording whose description matches rec_loc is located in the first trial, and the same key is used for
    every trial. For each trial whose input_loc is a dendritic sec_type, that recording is resampled every dt. The
    EPSP amplitude is its peak between equilibrate and duration minus its mean baseline from equilibrate - 3.0 to
    equilibrate - 1.0, with both windows computed by time2index. Basal distances are plotted as negative values.
    :param rec_filename: str; name of the .hdf5 file in data_dir, without the extension
    :param rec_loc: str; description of the recording to measure (default 'soma')
    :param title: str; label used in the plot title and svg file name (default 'Soma')
    :param svg_title: str; if not None, the font size is temporarily set to 20 and the figure is saved to data_dir
        as svg_title + ' - ' + title + ' EPSP amplitude.svg'
    :param dt: float; resampling interval
    :raises ValueError: if no recording in the first trial has a description equal to rec_loc
    """
    if rec_loc is None:
        rec_loc = 'soma'
    if title is None:
        title = 'Soma'
    sec_types = ['basal', 'trunk', 'apical', 'tuft']
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    distances = {sec_type: [] for sec_type in sec_types}
    EPSP_amp = {sec_type: [] for sec_type in sec_types}
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        first_trial = next(iter(f.values()))
        for rec_key in first_trial['rec']:
            if first_trial['rec'][rec_key].attrs['description'] == rec_loc:
                break
        else:
            # without this, rec_key would silently be left at the last recording examined
            raise ValueError('plot_EPSP_amplitude_figure: no recording with description %s in %s' %
                             (rec_loc, rec_filename))
        equilibrate = first_trial.attrs['equilibrate']
        duration = first_trial.attrs['duration']
        t = np.arange(0., duration, dt)
        # the baseline and response windows depend only on t, so compute them once
        base_left, base_right = time2index(t, equilibrate - 3.0, equilibrate - 1.0)
        start, end = time2index(t, equilibrate, duration)
        for trial in f.values():
            input_loc = trial.attrs['input_loc']
            if input_loc not in sec_types:
                continue
            # NOTE(review): distance always comes from recording '3', independent of rec_loc; presumably the
            # recording at the input site - confirm
            distance = trial['rec']['3'].attrs['soma_distance']
            if input_loc == 'basal':
                distance *= -1
            distances[input_loc].append(distance)
            vm = np.interp(t, trial['time'][:], trial['rec'][rec_key][:])
            baseline = np.mean(vm[base_left:base_right])
            EPSP_amp[input_loc].append(np.max(vm[start:end]) - baseline)
    all_values = [val for sec_type in sec_types for val in EPSP_amp[sec_type]]
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    # restore the caller's font size even if plotting or saving raises
    try:
        fig, axes = plt.subplots(1)
        for color, sec_type in zip(colors, sec_types):
            axes.scatter(distances[sec_type], EPSP_amp[sec_type], label=sec_type, color=color)
        axes.set_xlabel('Distance to soma (um)')
        axes.set_title(title+' EPSP amplitude', fontsize=mpl.rcParams['font.size'])
        axes.set_ylabel('EPSP amplitude (mV)')
        axes.set_xlim(-200., 525.)
        axes.set_xticks([-150., 0., 150., 300., 450.])
        if all_values:
            # pad the y range by 10% of the data span
            margin = 0.1 * (max(all_values) - min(all_values))
            axes.set_ylim(min(all_values) - margin, max(all_values) + margin)
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(5.27, 4.37)
            fig.savefig(data_dir+svg_title+' - '+title+' EPSP amplitude.svg', format='svg', transparent=True)
        plt.show()
        plt.close(fig)
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


def plot_EPSP_duration_figure(rec_filename, rec_loc=None, title=None, svg_title=None, dt=0.01):
    """
    Plots EPSP duration (width at half-maximum) vs distance to soma, with all dendritic sec_types superimposed.
    Expects an .hdf5 file in which every trial has 'equilibrate', 'duration' and 'input_loc' attributes, a 'time'
    vector, and a 'rec' group of voltage recordings, each carrying a 'description' attribute.
    The recording whose description matches rec_loc is located in the first trial, and the same key is used for
    every trial. For each trial whose input_loc is a dendritic sec_type, that recording is resampled every dt.
    Within the response window (equilibrate to duration), the signal is baseline-subtracted (baseline: equilibrate
    - 3.0 to equilibrate - 1.0) and normalized to its peak. The duration runs from the first sample >= 0.5 (searched
    up to and including the peak) to the first sample <= 0.5 at or after the peak. Basal distances are plotted as
    negative values.
    :param rec_filename: str; name of the .hdf5 file in data_dir, without the extension
    :param rec_loc: str; description of the recording to measure (default 'soma')
    :param title: str; label used in the plot title and svg file name (default 'Soma')
    :param svg_title: str; if not None, the font size is temporarily set to 20 and the figure is saved to data_dir
        as svg_title + ' - ' + title + ' EPSP duration.svg'
    :param dt: float; resampling interval
    :raises ValueError: if no recording in the first trial has a description equal to rec_loc
    """
    if rec_loc is None:
        rec_loc = 'soma'
    if title is None:
        title = 'Soma'
    sec_types = ['basal', 'trunk', 'apical', 'tuft']
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    distances = {sec_type: [] for sec_type in sec_types}
    EPSP_duration = {sec_type: [] for sec_type in sec_types}
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        first_trial = next(iter(f.values()))
        for rec_key in first_trial['rec']:
            if first_trial['rec'][rec_key].attrs['description'] == rec_loc:
                break
        else:
            # without this, rec_key would silently be left at the last recording examined
            raise ValueError('plot_EPSP_duration_figure: no recording with description %s in %s' %
                             (rec_loc, rec_filename))
        equilibrate = first_trial.attrs['equilibrate']
        duration = first_trial.attrs['duration']
        t = np.arange(0., duration, dt)
        # the baseline and response windows depend only on t, so compute them once
        base_left, base_right = time2index(t, equilibrate - 3.0, equilibrate - 1.0)
        start, end = time2index(t, equilibrate, duration)
        # time relative to the start of the response window
        interp_t = t[start:end] - t[start]
        for trial in f.values():
            input_loc = trial.attrs['input_loc']
            if input_loc not in sec_types:
                continue
            # NOTE(review): distance always comes from recording '3', independent of rec_loc; presumably the
            # recording at the input site - confirm
            distance = trial['rec']['3'].attrs['soma_distance']
            if input_loc == 'basal':
                distance *= -1
            distances[input_loc].append(distance)
            vm = np.interp(t, trial['time'][:], trial['rec'][rec_key][:])
            baseline = np.mean(vm[base_left:base_right])
            vm = vm[start:end] - baseline
            # first occurrence of the maximum
            t_peak = np.argmax(vm)
            # NOTE(review): assumes a depolarizing EPSP (positive peak); a non-positive peak makes this meaningless
            vm /= vm[t_peak]
            # include the peak sample (always 1.0 after normalization) so a rise that only reaches half-maximum at
            # the peak, or a peak at index 0, is still found instead of raising IndexError
            rise_50 = np.where(vm[:t_peak+1] >= 0.5)[0][0]
            # decay_50 is an offset from t_peak
            decay_50 = np.where(vm[t_peak:] <= 0.5)[0][0]
            this_duration = interp_t[decay_50] + interp_t[t_peak] - interp_t[rise_50]
            EPSP_duration[input_loc].append(this_duration)
    all_values = [val for sec_type in sec_types for val in EPSP_duration[sec_type]]
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    # restore the caller's font size even if plotting or saving raises
    try:
        fig, axes = plt.subplots(1)
        for color, sec_type in zip(colors, sec_types):
            axes.scatter(distances[sec_type], EPSP_duration[sec_type], label=sec_type, color=color)
        axes.set_xlabel('Distance to soma (um)')
        axes.set_title(title+' EPSP duration', fontsize=mpl.rcParams['font.size'])
        axes.set_ylabel('EPSP duration (ms)')
        axes.set_xlim(-200., 525.)
        axes.set_xticks([-150., 0., 150., 300., 450.])
        if all_values:
            # pad the y range by 10% of the data span
            margin = 0.1 * (max(all_values) - min(all_values))
            axes.set_ylim(min(all_values) - margin, max(all_values) + margin)
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(5.27, 4.37)
            fig.savefig(data_dir+svg_title+' - '+title+' EPSP duration.svg', format='svg', transparent=True)
        plt.show()
        plt.close(fig)
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


def plot_spine_attenuation_ratio_figure(rec_filename, svg_title=None, dt=0.01):
    """
    Expects an output file generated by parallel_spine_attenuation_ratio.
    File contains voltage recordings from spines and parent branches during EPSC-shaped current injections.
    Trials are grouped by spine index and by 'stim_loc'. For each spine, the trial with stim_loc 'spine' is analyzed.
    Attenuation is measured as the ratio of spine to branch voltage amplitude, each amplitude being the peak between
    equilibrate and duration minus the mean from equilibrate - 3.0 to equilibrate - 1.0. Plots the attenuation ratio
    vs distance to soma, with all dendritic sec_types superimposed; basal distances are plotted as negative values.
    Spines without a 'spine' stim trial, trials missing a branch or spine recording, and branches whose type is not
    a plotted sec_type are skipped.
    :param rec_filename: str; name of the .hdf5 file in data_dir, without the extension
    :param svg_title: str; if not None, the font size is temporarily set to 20 and the figure is saved to data_dir
        as svg_title + ' - spine attenuation ratio.svg'
    :param dt: float; resampling interval
    """
    sec_types = ['basal', 'trunk', 'apical', 'tuft']
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    distances = {sec_type: [] for sec_type in sec_types}
    ratio = {sec_type: [] for sec_type in sec_types}
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        first_trial = next(iter(f.values()))
        equilibrate = first_trial.attrs['equilibrate']
        duration = first_trial.attrs['duration']
        t = np.arange(0., duration, dt)
        # the baseline and response windows depend only on t, so compute them once
        base_left, base_right = time2index(t, equilibrate - 3.0, equilibrate - 1.0)
        start, end = time2index(t, equilibrate, duration)
        # spine index -> {stim_loc: trial_key}
        index_dict = {}
        for trial_key in f:
            recs = f[trial_key]['rec']
            # the spine recording is expected to be either '0' or '1'
            spine_rec = recs['0'] if recs['0'].attrs['description'] == 'spine' else recs['1']
            stim_loc = f[trial_key].attrs['stim_loc']
            index_dict.setdefault(spine_rec.attrs['index'], {})[stim_loc] = trial_key
        for stim_trials in index_dict.values():
            if 'spine' not in stim_trials:
                continue
            spine_trial = f[stim_trials['spine']]
            # reset per spine so recordings from a previous iteration are never reused
            branch_rec, spine_rec, sec_type = None, None, None
            for rec in spine_trial['rec'].values():
                description = rec.attrs['description']
                if description == 'branch':
                    branch_rec = rec
                    sec_type = rec.attrs['type']
                elif description == 'spine':
                    spine_rec = rec
            if branch_rec is None or spine_rec is None or sec_type not in sec_types:
                continue
            distance = spine_rec.attrs['soma_distance']
            if sec_type == 'basal':
                distance *= -1
            distances[sec_type].append(distance)
            tvec = spine_trial['time'][:]
            branch_vm = np.interp(t, tvec, branch_rec[:])
            spine_vm = np.interp(t, tvec, spine_rec[:])
            peak_branch = np.max(branch_vm[start:end]) - np.mean(branch_vm[base_left:base_right])
            peak_spine = np.max(spine_vm[start:end]) - np.mean(spine_vm[base_left:base_right])
            ratio[sec_type].append(peak_spine / peak_branch)
    all_values = [val for sec_type in sec_types for val in ratio[sec_type]]
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    # restore the caller's font size even if plotting or saving raises
    try:
        fig, axes = plt.subplots(1)
        for color, sec_type in zip(colors, sec_types):
            axes.scatter(distances[sec_type], ratio[sec_type], label=sec_type, color=color)
        axes.set_xlabel('Distance to soma (um)')
        axes.set_title('Spine to branch EPSP attenuation', fontsize=mpl.rcParams['font.size'])
        axes.set_ylabel('EPSP attenuation ratio')
        axes.set_xlim(-200., 525.)
        axes.set_xticks([-150., 0., 150., 300., 450.])
        if all_values:
            # pad the y range by 10% of the data span
            margin = 0.1 * (max(all_values) - min(all_values))
            axes.set_ylim(min(all_values) - margin, max(all_values) + margin)
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(5.27, 4.37)
            fig.savefig(data_dir+svg_title+' - spine attenuation ratio.svg', format='svg', transparent=True)
        plt.show()
        plt.close(fig)
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


# Module-level script: builds the cell, places synapses, and produces figures when this file is run or imported.

# Name of the NMDA receptor mechanism; combined with 'AMPA_KIN' into the synapse types inserted below.
NMDA_type = 'NMDA_KIN5'
# Passed as stochastic=... to every excitatory Synapse created below.
excitatory_stochastic = 1
syn_types = ['AMPA_KIN', NMDA_type]

# NOTE(review): local_random is not used anywhere in the visible part of this file.
local_random = random.Random()

# Build the cell from the files named at the top of the file. The project calls below are kept in their original
# order; presumably later steps depend on earlier ones - confirm before reordering.
cell = CA1_Pyr(morph_filename, mech_filename, full_spines=True)
cell.set_terminal_branch_na_gradient()
cell.insert_inhibitory_synapses_in_subset()

# place synapses in every spine
for sec_type in ['basal', 'trunk', 'apical', 'tuft']:
    for node in cell.get_nodes_of_subtype(sec_type):
        for spine in node.spines:
            syn = Synapse(cell, spine, syn_types, stochastic=excitatory_stochastic)
cell.init_synaptic_mechanisms()

# Multiplied by 1000. as scale_factor for AMPA_KIN gmax (yunits='nS') in the disabled call below; presumably the
# AMPAR open probability - confirm.
AMPAR_Po = 0.3252

# The triple-quoted string below is an unassigned string literal, so none of the calls inside it run. Move a call
# out of the string to regenerate that figure.
"""
plot_mech_param_distribution(cell, 'pas', 'g', param_label='Leak conductance gradient', svg_title='051316 - cell107')
plot_mech_param_distribution(cell, 'h', 'ghbar', param_label='Ih conductance gradient', svg_title='051316 - cell107')
plot_mech_param_distribution(cell, 'nas', 'gbar', param_label='Na conductance gradient', svg_title='051316 - cell107')
plot_sum_mech_param_distribution(cell, [('kap', 'gkabar'), ('kad', 'gkabar')],
                                 param_label='A-type K conductance gradient', svg_title='051316 - cell107')
plot_synaptic_param_distribution(cell, 'AMPA_KIN', 'gmax', scale_factor=1000.*AMPAR_Po, yunits='nS',
                                 param_label='Synaptic AMPAR gradient', svg_title='051316 - cell107')
plot_rinp_figure('043016 Type A - km2_NMDA_KIN5_Pr - Rinp', '051316 - cell107')
plot_bAP_attenuation_figure('output051320161518-pid24213_bAP', '051316 - cell107')
plot_EPSP_amplitude_figure('043016 Type A - km2_NMDA_KIN5_Pr - epsp attenuation', svg_title='051316 - cell107')
plot_EPSP_amplitude_figure('043016 Type A - km2_NMDA_KIN5_Pr - epsp attenuation', rec_loc='branch', title='Branch',
                           svg_title='051316 - cell107')
plot_EPSP_duration_figure('043016 Type A - km2_NMDA_KIN5_Pr - epsp attenuation', svg_title='051316 - cell107')
plot_EPSP_duration_figure('043016 Type A - km2_NMDA_KIN5_Pr - epsp attenuation', rec_loc='branch', title='Branch',
                           svg_title='051316 - cell107')
"""
# The only figure currently generated.
plot_spine_attenuation_ratio_figure('043016 Type A - km2_NMDA_KIN5_Pr - spine AR', svg_title='051316 - cell107')