__author__ = 'milsteina'
from function_lib import *
import matplotlib.lines as mlines
import matplotlib as mpl
import numpy as np
import scipy.signal as signal
import scipy.stats as stats

# Module-wide matplotlib styling applied at import time. 'svg.fonttype' = 'none' keeps text as text (not paths) in
# SVG output, base font size is 18, sans-serif font is Calibri, and text is rendered without LaTeX.
# Keys are assigned one at a time through rcParams item assignment.
for _rc_key, _rc_value in (('svg.fonttype', 'none'),
                           ('font.size', 18.),
                           ('font.sans-serif', 'Calibri'),  # alternative previously used: 'Arial'
                           ('text.usetex', False)):
    mpl.rcParams[_rc_key] = _rc_value
del _rc_key, _rc_value
# Disabled settings (relative font sizes), kept for reference:
# mpl.rcParams['axes.labelsize'] = 'larger'
# mpl.rcParams['axes.titlesize'] = 'xx-large'
# mpl.rcParams['xtick.labelsize'] = 'large'
# mpl.rcParams['ytick.labelsize'] = 'large'
# mpl.rcParams['legend.fontsize'] = 'x-large'

def plot_AR(rec_file_list, description_list="", title=None):
    """
    Plot spine-to-branch EPSP amplitude ratio, branch input resistance and spine neck resistance vs. distance.

    Expects each file in list to be generated by parallel_spine_attenuation_ratio.
    Files contain voltage recordings from spine and branch probed with EPSC-shaped current injections to measure spine
    to branch EPSP amplitude attenuation ratio, dendritic branch impedance, and spine neck resistance. Plots these
    parameters vs distance from dendrite origin, with one column per dendritic sec_type.
    Superimposes results from multiple files in list. The sec_types (columns) are taken from the first file only.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - spine AR.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            if sim.attrs['stim_loc'] != 'spine':
                continue
            rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'branch' else sim['rec']['1']
            sec_type = rec.attrs['type']
            if sec_type not in temp_sec_types:
                temp_sec_types.append(sec_type)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    distances = {}
    AR = {}
    dendR = {}
    neckR = {}
    fig, axes = plt.subplots(3, max(2, len(sec_types)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        index_dict = {}
        for sec_type in sec_types:
            distances[sec_type] = []
            AR[sec_type] = []
            dendR[sec_type] = []
            neckR[sec_type] = []
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            amp = f['0'].attrs['amp']
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # common time grid onto which every trace of this file is interpolated
            interp_t = np.arange(0., duration, 0.001)
            # following parallel execution and combine_rec_files, the order of simulation records is shuffled
            # here the indices of paired records from spine_stim and branch_stim are collected
            for simiter in f:
                sim = f[simiter]
                stim_loc = sim.attrs['stim_loc']
                spine_rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'spine' else sim['rec']['1']
                index_dict.setdefault(spine_rec.attrs['index'], {})[stim_loc] = simiter
            for indices in index_dict.values():
                spine_stim = f[indices['spine']]['rec']
                spine_tvec = f[indices['spine']]['time']
                branch_stim = f[indices['branch']]['rec']
                branch_tvec = f[indices['branch']]['time']
                # spine stimulation: ratio of the spine EPSP amplitude to the branch EPSP amplitude
                for rec in spine_stim.values():
                    if rec.attrs['description'] == 'branch':
                        branch_rec = rec
                        sec_type = rec.attrs['type']
                    elif rec.attrs['description'] == 'spine':
                        spine_rec = rec
                distances[sec_type].append(branch_rec.attrs['branch_distance'])
                interp_branch_vm = np.interp(interp_t, spine_tvec[:], branch_rec[:])
                interp_spine_vm = np.interp(interp_t, spine_tvec[:], spine_rec[:])
                # baseline is averaged over a window ending shortly before t = equilibrate
                left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
                baseline_branch = np.average(interp_branch_vm[left:right])
                baseline_spine = np.average(interp_spine_vm[left:right])
                # EPSP amplitude: maximum after t = equilibrate, relative to baseline
                left, right = time2index(interp_t, equilibrate, duration)
                peak_branch = np.max(interp_branch_vm[left:right]) - baseline_branch
                peak_spine = np.max(interp_spine_vm[left:right]) - baseline_spine
                this_AR = peak_spine / peak_branch
                AR[sec_type].append(this_AR)
                # branch stimulation: branch resistance = branch EPSP amplitude / injected current amplitude
                branch_rec = branch_stim['0'] if branch_stim['0'].attrs['description'] == 'branch' else branch_stim['1']
                interp_branch_vm = np.interp(interp_t, branch_tvec[:], branch_rec[:])
                left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
                baseline_branch = np.average(interp_branch_vm[left:right])
                left, right = time2index(interp_t, equilibrate, duration)
                peak_branch = np.max(interp_branch_vm[left:right]) - baseline_branch
                this_dendR = peak_branch / amp
                dendR[sec_type].append(this_dendR)
                # neck resistance estimated from the amplitude ratio and the branch resistance
                neckR[sec_type].append((this_AR - 1) * this_dendR)
        # rows: amplitude ratio, branch resistance, neck resistance; one column per sec_type
        for i, sec_type in enumerate(sec_types):
            for row, values in enumerate([AR, dendR, neckR]):
                axes[row][i].scatter(distances[sec_type], values[sec_type], label=labels[index], color=color)
                axes[row][i].set_xlabel('Distance from Dendrite Origin (um)')
                axes[row][i].set_title(sec_type)
    axes[0][0].set_ylabel('Amplitude Ratio')
    axes[1][0].set_ylabel('R_Dend (MOhm)')
    axes[2][0].set_ylabel('R_Neck (MOhm)')
    if description_list != [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig.subplots_adjust(hspace=0.45, wspace=0.3, left=0.06, right=0.94, top=0.94, bottom=0.06)
    if title is not None:
        fig.set_size_inches(20.8, 13)
        fig.savefig(data_dir+title+' - spine AR.svg', format='svg')
    plt.show()
    plt.close()


def plot_AR_EPSP_amp(rec_file_list, description_list="", title=None):
    """
    Plot spine and branch EPSP amplitudes vs. distance, for both spine and branch stimulation.

    Expects each file in list to be generated by parallel_spine_attenuation.
    Files contain voltage recordings from spine and branch while injecting EPSC-shaped currents into either spine or
    branch to measure the amplitude attenuation ratio.
    Creates a grid of scatter plots of EPSP amplitude vs. distance from dendrite origin, with one row per dendritic
    sec_type and four columns containing all stimulation and recording conditions (stim branch / record branch,
    stim branch / record spine, stim spine / record branch, stim spine / record spine).
    Superimposes results from multiple files in list. The sec_types (rows) are taken from the first file only.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - spine AR - EPSP amp.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            if sim.attrs['stim_loc'] != 'spine':
                continue
            rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'branch' else sim['rec']['1']
            sec_type = rec.attrs['type']
            if sec_type not in temp_sec_types:
                temp_sec_types.append(sec_type)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    distances = {}
    # EPSP amplitudes keyed first by stimulation location, then by sec_type
    spine_amp = {'spine': {}, 'branch': {}}
    branch_amp = {'spine': {}, 'branch': {}}
    fig, axes = plt.subplots(max(2, len(sec_types)), 4)
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        index_dict = {}
        for sec_type in sec_types:
            distances[sec_type] = []
            for stim_loc in ['spine', 'branch']:
                spine_amp[stim_loc][sec_type] = []
                branch_amp[stim_loc][sec_type] = []
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # common time grid onto which every trace of this file is interpolated
            interp_t = np.arange(0., duration, 0.001)
            # following parallel execution and combine_rec_files, the order of simulation records is shuffled
            # here the indices of paired records from spine_stim and branch_stim are collected
            for simiter in f:
                sim = f[simiter]
                stim_loc = sim.attrs['stim_loc']
                spine_rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'spine' else sim['rec']['1']
                index_dict.setdefault(spine_rec.attrs['index'], {})[stim_loc] = simiter
            for indices in index_dict.values():
                # sec_type and distance are read from the branch recording of the spine-stim simulation
                for rec in f[indices['spine']]['rec'].values():
                    if rec.attrs['description'] == 'branch':
                        branch_rec = rec
                        sec_type = rec.attrs['type']
                distances[sec_type].append(branch_rec.attrs['branch_distance'])
                for stim_loc in ['spine', 'branch']:
                    stim = f[indices[stim_loc]]['rec']
                    tvec = f[indices[stim_loc]]['time']
                    for rec in stim.values():
                        if rec.attrs['description'] == 'branch':
                            branch_rec = rec
                        else:
                            spine_rec = rec
                    interp_branch_vm = np.interp(interp_t, tvec[:], branch_rec[:])
                    interp_spine_vm = np.interp(interp_t, tvec[:], spine_rec[:])
                    # baseline is averaged over a window ending shortly before t = equilibrate
                    left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
                    baseline_branch = np.average(interp_branch_vm[left:right])
                    baseline_spine = np.average(interp_spine_vm[left:right])
                    # EPSP amplitude: maximum after t = equilibrate, relative to baseline
                    left, right = time2index(interp_t, equilibrate, duration)
                    peak_branch = np.max(interp_branch_vm[left:right]) - baseline_branch
                    peak_spine = np.max(interp_spine_vm[left:right]) - baseline_spine
                    spine_amp[stim_loc][sec_type].append(peak_spine)
                    branch_amp[stim_loc][sec_type].append(peak_branch)
        # column order matches the titles set below
        column_amps = [branch_amp['branch'], spine_amp['branch'], branch_amp['spine'], spine_amp['spine']]
        for i, sec_type in enumerate(sec_types):
            for j, amps in enumerate(column_amps):
                axes[i][j].scatter(distances[sec_type], amps[sec_type], label=labels[index], color=color)
    for i, sec_type in enumerate(sec_types):
        for j, label in enumerate(['Stim Branch - Record Branch', 'Stim Branch - Record Spine',
                                   'Stim Spine - Record Branch', 'Stim Spine - Record Spine']):
            axes[i][j].set_xlabel('Distance from Dendrite Origin (um)')
            axes[i][j].set_ylabel('Input Loc: '+sec_type+'\nEPSP Amplitude (mV)')
            axes[i][j].set_title(label)
    if description_list != [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(19.2, 12)
        fig.savefig(data_dir+title+' - spine AR - EPSP amp.svg', format='svg')
    plt.show()
    plt.close()


def plot_AR_vm(rec_file_list, description_list="", title=None):
    """
    Plot spine and branch vm traces for both spine and branch stimulation.

    Expects each file in list to be generated by parallel_spine_attenuation.
    Files contain voltage recordings from spine and branch while injecting EPSC-shaped currents into either spine or
    branch to measure the amplitude attenuation ratio.
    Creates a grid of vm vs. time plots, with one row per dendritic sec_type and four columns containing all
    stimulation and recording conditions (stim branch / record branch, stim branch / record spine,
    stim spine / record branch, stim spine / record spine).
    Superimposes results from multiple files in list. The sec_types (rows) are taken from the first file only.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - spine AR - traces.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'branch' else sim['rec']['1']
            sec_type = rec.attrs['type']
            if sec_type not in temp_sec_types:
                temp_sec_types.append(sec_type)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    fig, axes = plt.subplots(max(2, len(sec_types)), 4)
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    label_handles = []
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        index_dict = {}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # common time grid and plotting window (from 5 before t = equilibrate to the end of the simulation)
            interp_t = np.arange(0., duration, 0.01)
            left, right = time2index(interp_t, equilibrate-5.0, duration)
            # shift the plotted time axis so that the window starts at -5
            plot_t = interp_t[left:right] - (interp_t[left] + 5.)
            # following parallel execution and combine_rec_files, the order of simulation records is shuffled
            # here the indices of paired records from spine_stim and branch_stim are collected
            for simiter in f:
                sim = f[simiter]
                stim_loc = sim.attrs['stim_loc']
                spine_rec = sim['rec']['0'] if sim['rec']['0'].attrs['description'] == 'spine' else sim['rec']['1']
                index_dict.setdefault(spine_rec.attrs['index'], {})[stim_loc] = simiter
            for indices in index_dict.values():
                # the row is chosen by the sec_type of the branch recording in the spine-stim simulation
                for rec in f[indices['spine']]['rec'].values():
                    if rec.attrs['description'] == 'branch':
                        sec_type = rec.attrs['type']
                i = sec_types.index(sec_type)
                for stim_loc in ['spine', 'branch']:
                    stim = f[indices[stim_loc]]['rec']
                    tvec = f[indices[stim_loc]]['time']
                    for rec in stim.values():
                        if rec.attrs['description'] == 'branch':
                            branch_rec = rec
                        else:
                            spine_rec = rec
                    # columns 0/1: branch stim (record branch/spine); columns 2/3: spine stim (record branch/spine)
                    j = 0 if stim_loc == 'branch' else 2
                    interp_branch_vm = np.interp(interp_t, tvec[:], branch_rec[:])
                    interp_spine_vm = np.interp(interp_t, tvec[:], spine_rec[:])
                    axes[i][j].plot(plot_t, interp_branch_vm[left:right], color=color)
                    axes[i][j+1].plot(plot_t, interp_spine_vm[left:right], color=color)
        label_handles.append(mlines.Line2D([], [], color=color, label=labels[index]))
    for i, sec_type in enumerate(sec_types):
        for j, label in enumerate(['Stim Branch - Record Branch', 'Stim Branch - Record Spine',
                                   'Stim Spine - Record Branch', 'Stim Spine - Record Spine']):
            axes[i][j].set_xlabel('Time (ms)')
            axes[i][j].set_ylabel('Input Loc: '+sec_type+'\nVm (mV)')
            axes[i][j].set_title(label)
    if description_list != [""]:
        axes[0][0].legend(handles=label_handles, framealpha=0.5, frameon=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(19.2, 12)
        fig.savefig(data_dir+title+' - spine AR - traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_Rinp(rec_file_list, description_list="", title=None):
    """
    Plot input resistance, sag and resting vm vs. distance from dendrite origin.

    Expects each file in list to be generated by parallel_rinp.
    Files contain voltage recordings from dendritic compartments probed with hyperpolarizing current injections to
    measure 1) peak r_inp, 2) steady-state r_inp, 3) percent sag (100 * (1 - steady / peak)), and 4) v_rest. Plots
    these parameters vs distance from dendrite origin, with one column per sec_type.
    Superimposes results from multiple files in list. The sec_types (columns) are taken from the first file only.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - Rinp.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['soma', 'basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            sec_type = sim['rec']['0'].attrs['type']
            if sec_type not in temp_sec_types:
                temp_sec_types.append(sec_type)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    distances = {}
    peak = {}
    steady = {}
    sag = {}
    v_rest = {}
    fig, axes = plt.subplots(4, max(2, len(sec_types)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        for sec_type in sec_types:
            distances[sec_type] = []
            peak[sec_type] = []
            steady[sec_type] = []
            sag[sec_type] = []
            v_rest[sec_type] = []
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            # stimulus parameters are read from the first simulation and applied to every simulation in the file
            stim = f['0']['stim']['0']
            amp = stim.attrs['amp']
            start = stim.attrs['delay']
            stop = start + stim.attrs['dur']
            for sim in f.values():
                rec = sim['rec']['0']
                sec_type = rec.attrs['type']
                distances[sec_type].append(rec.attrs['branch_distance'])
                this_rest, this_peak, this_steady = get_Rinp(sim['time'][:], rec[:], start, stop, amp)
                peak[sec_type].append(this_peak)
                steady[sec_type].append(this_steady)
                sag[sec_type].append(100*(1-this_steady/this_peak))
                v_rest[sec_type].append(this_rest)
        # rows: peak r_inp, steady-state r_inp, % sag, resting vm; one column per sec_type
        for i, sec_type in enumerate(sec_types):
            # the column title belongs on the top row (it was previously set on row 1, over row 0's x label)
            axes[0][i].set_title(sec_type)
            for row, values in enumerate([peak, steady, sag, v_rest]):
                axes[row][i].scatter(distances[sec_type], values[sec_type], label=labels[index], color=color)
                axes[row][i].set_xlabel('Distance from Dendrite Origin (um)')
    # NOTE(review): y labels are placed on column 1 rather than column 0 (the 'soma' column when present); kept as is
    axes[0][1].set_ylabel('Input Resistance\nPeak (MOhm)')
    axes[1][1].set_ylabel('Input Resistance\nSteady-state (MOhm)')
    axes[2][1].set_ylabel('% Sag')
    axes[3][1].set_ylabel('Resting Vm (mV)')
    if description_list != [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig.subplots_adjust(hspace=0.45, wspace=0.45, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 13)
        fig.savefig(data_dir+title+' - Rinp.svg', format='svg')
    plt.show()
    plt.close()


def plot_Rinp_vm(rec_file_list, description_list="", title=None):
    """
    Plot vm traces recorded during input resistance measurements.

    Expects each file in list to be generated by parallel_rinp.
    Files contain voltage recordings from dendritic compartments probed with hyperpolarizing current injections.
    Plots vm vs. time, with one column (panel) per sec_type, from 5 before the current step until its end.
    Superimposes results from multiple files in list. The sec_types (panels) are taken from the first file only.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - Rinp - traces.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['soma', 'basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            sec_type = sim['rec']['0'].attrs['type']
            if sec_type not in temp_sec_types:
                temp_sec_types.append(sec_type)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    fig, axes = plt.subplots(1, max(2, len(sec_types)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    label_handles = []
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            start = f['0']['stim']['0'].attrs['delay']
            stop = start + f['0']['stim']['0'].attrs['dur']
            # common time grid and plotting window, shifted so that the window starts at -5
            interp_t = np.arange(0., stop, 0.01)
            left, right = time2index(interp_t, start-5.0, stop)
            plot_t = interp_t[left:right] - (interp_t[left] + 5.)
            for sim in f.values():
                rec = sim['rec']['0']
                i = sec_types.index(rec.attrs['type'])
                interp_vm = np.interp(interp_t, sim['time'][:], rec[:])
                axes[i].plot(plot_t, interp_vm[left:right], color=color)
        for i, sec_type in enumerate(sec_types):
            axes[i].set_xlabel('Time (ms)')
            axes[i].set_ylabel('Vm (mV)')
            axes[i].set_title(sec_type)
        label_handles.append(mlines.Line2D([], [], color=color, label=labels[index]))
    if description_list != [""]:
        axes[0].legend(handles=label_handles, framealpha=0.5, frameon=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(19.2, 12)
        fig.savefig(data_dir+title+' - Rinp - traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_Rinp_av_vm(rec_file_list, description_list="", title=None):
    """
    Plot baseline-subtracted vm traces averaged per recording location.

    Expects each file in list to be generated by parallel_rinp.
    Files contain voltage recordings from dendritic compartments probed with hyperpolarizing current injections.
    Plots vm vs. time in a grid of up to 4 panels per row, one panel per sec_type, averaging all responses with the
    same input and recording location.
    This method subgroups trunk and apical sections as proximal (<= 150 um) or distal. Trunk uses soma_distance;
    apical uses soma_distance - branch_distance.
    Superimposes results from multiple files in list. The sec_types (panels) are taken from the first file only.
    A panel whose group has no recordings in a given file is left without a trace for that file.
    :param rec_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param description_list: list of str (a single str is also accepted); one legend label per file. Files without a
        matching label are plotted with an empty label.
    :param title: str; when given, the figure is saved to data_dir as '<title> - Rinp - average traces.svg'
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    # one label per file: pad with '' so that several files plotted without descriptions do not raise an IndexError
    labels = description_list + [''] * (len(rec_file_list) - len(description_list))
    default_sec_types = ['soma', 'basal', 'trunk_prox', 'trunk_dist', 'apical_prox', 'apical_dist', 'tuft']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_sec_types = []
        for sim in f.values():
            sec_type = sim['rec']['0'].attrs['type']
            if sec_type in ['trunk', 'apical']:
                candidates = [sec_type+'_prox', sec_type+'_dist']
            else:
                candidates = [sec_type]
            for candidate in candidates:
                if candidate not in temp_sec_types:
                    temp_sec_types.append(candidate)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    sec_types = [sec_type for sec_type in default_sec_types if sec_type in temp_sec_types] + \
                [sec_type for sec_type in temp_sec_types if sec_type not in default_sec_types]
    cols = min(4, len(sec_types))
    # enough rows of 4 panels for every sec_type (ceiling division), but never fewer than 3 rows
    rows = max(3, -(-len(sec_types) // 4))
    # squeeze=False keeps axes 2-D even when there is a single column
    fig, axes = plt.subplots(rows, cols, squeeze=False)
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    label_handles = []
    for index, rec_filename in enumerate(rec_file_list):
        # cycle through the colors when more files than colors are plotted
        color = colors[index % len(colors)]
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            start = f['0']['stim']['0'].attrs['delay']
            stop = start + 200.  # fixed window instead of f['0']['stim']['0'].attrs['dur']
            # common time grid, baseline window and plotting window; the plotted time axis starts at -5
            interp_t = np.arange(0., stop, 0.01)
            base_left, base_right = time2index(interp_t, start-3.0, start-1.0)
            left, right = time2index(interp_t, start-5.0, stop)
            plot_t = interp_t[left:right] - (interp_t[left] + 5.)
            # sec_type -> {'count': number of traces, 'trace': running sum of baseline-subtracted traces}
            average_vm = {}
            for sim in f.values():
                rec = sim['rec']['0']
                sec_type = rec.attrs['type']
                if sec_type in ['trunk', 'apical']:
                    if sec_type == 'trunk':
                        distance = rec.attrs['soma_distance']
                    else:
                        distance = rec.attrs['soma_distance'] - rec.attrs['branch_distance']
                    sec_type += '_prox' if distance <= 150. else '_dist'
                interp_vm = np.interp(interp_t, sim['time'][:], rec[:])
                baseline = np.average(interp_vm[base_left:base_right])
                interp_vm = interp_vm[left:right] - baseline
                if sec_type in average_vm:
                    average_vm[sec_type]['count'] += 1
                    average_vm[sec_type]['trace'] += interp_vm
                else:
                    average_vm[sec_type] = {'count': 1, 'trace': interp_vm}
            for i, sec_type in enumerate(sec_types):
                row, col = i // 4, i % 4
                # a group may be empty, e.g. when every trunk recording in this file is proximal
                if sec_type in average_vm:
                    axes[row][col].plot(plot_t, average_vm[sec_type]['trace']/average_vm[sec_type]['count'],
                                        color=color)
                axes[row][col].set_xlabel('Time (ms)')
                axes[row][0].set_ylabel('Voltage (mV)')
                axes[row][col].set_title(sec_type)
        label_handles.append(mlines.Line2D([], [], color=color, label=labels[index]))
    if description_list != [""]:
        axes[0][0].legend(handles=label_handles, framealpha=0.5, frameon=False, fontsize=20)
    fig.subplots_adjust(hspace=0.5, wspace=0.4, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 13)
        fig.savefig(data_dir+title+' - Rinp - average traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_superimpose_conditions(rec_filename, legend=False):
    """
    Superimpose, per recorded vector, the traces from every simulation in one file.

    File contains simulation results from iterating through some changes in parameters or stimulation conditions.
    This function produces one plot per recorded vector. Each plot superimposes the recordings from each of the
    simulation iterations, labeled by the simulation's 'description' attribute when present.
    Recordings are matched across simulations by their 'description' attribute, or by type + index when absent.
    :param rec_filename: str; file name in data_dir without '.hdf5'
    :param legend: bool; add a legend to every panel
    """
    def get_rec_id(rec):
        # identifier used to match the same recording across simulations
        if 'description' in rec.attrs:
            return rec.attrs['description']
        return rec.attrs['type']+str(rec.attrs['index'])

    # the context manager closes the file even if plotting raises
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        rec_ids = []  # one entry per panel, in order of first appearance
        ylabels = {}
        for sim in f.values():
            for rec in sim['rec'].values():
                rec_id = get_rec_id(rec)
                if rec_id not in ylabels:
                    rec_ids.append(rec_id)
                    ylabels[rec_id] = rec.attrs['ylabel']+' ('+rec.attrs['units']+')'
        fig, axes = plt.subplots(1, max(2, len(rec_ids)))
        for ax, rec_id in zip(axes, rec_ids):
            ax.set_xlabel('Time (ms)')
            ax.set_ylabel(ylabels[rec_id])
            ax.set_title(rec_id)
        for sim in f.values():
            sim_id = sim.attrs['description'] if 'description' in sim.attrs else ''
            tvec = sim['time']
            for rec in sim['rec'].values():
                axes[rec_ids.index(get_rec_id(rec))].plot(tvec[:], rec[:], label=sim_id)
        if legend:
            for ax in axes[:len(rec_ids)]:
                ax.legend(loc='best', framealpha=0.5, frameon=False)
        plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.1)
        plt.show()
        plt.close()


def plot_EPSP_attenuation(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_attenuation.
    Files contain simultaneous voltage recordings from 4 locations (soma, trunk, branch, spine) during single spine
    stimulation. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of 16 plots of EPSP amplitude vs. distance from dendrite origin.
    Superimposes results from multiple files in list.
    EPSP amplitude = peak of the interpolated trace between equilibrate and duration, minus the mean of the
    interpolated trace between (equilibrate - 3) and (equilibrate - 1) ms.
    If fewer descriptions than files are given, the extra files are plotted without a legend label. Colors cycle when
    there are more files than colors.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch', 'spine']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc not in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    # at least a 2x2 grid so that axes is always indexable as axes[i][j]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        distances = {input_loc: [] for input_loc in input_locs}
        amps = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # the interpolation grid and analysis windows depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.001)
            left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
            start, end = time2index(interp_t, equilibrate, duration)
            for sim in f.itervalues():
                tvec = sim['time'][:]
                input_loc = sim.attrs['input_loc']
                # x-coordinate is read from recording '3' (presumably the spine recording - TODO confirm)
                distances[input_loc].append(sim['rec']['3'].attrs['branch_distance'])
                for rec in sim['rec'].itervalues():
                    rec_loc = rec.attrs['description']
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    baseline = np.average(interp_vm[left:right])
                    amps[input_loc][rec_loc].append(np.max(interp_vm[start:end]) - baseline)
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                axes[i][j].scatter(distances[input_loc], amps[input_loc][rec_loc], color=color, label=label)
                axes[i][j].set_xlabel('Distance from Dendrite Origin (um)', fontsize='x-large')
            axes[i][0].set_ylabel('Spine Location: '+input_loc+'\nEPSP Amp (mV)', fontsize='xx-large')
        for j, rec_loc in enumerate(rec_locs):
            axes[0][j].set_title('Recording Loc: '+rec_loc)
    if not description_list == [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - EPSP attenuation.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_kinetics(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_attenuation.
    Files contain simultaneous voltage recordings from 4 locations (soma, trunk, branch, spine) during single spine
    stimulation. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces a grid of 16 plots of EPSP kinetics vs. distance from dendrite origin.
    Produces one figure each for rise kinetics and decay kinetics.
    Superimposes results from multiple files in list.
    Rise kinetics = 10-90% rise time. Decay kinetics = width at half maximum (time from the 50% point of the rising
    phase to the 50% point of the decaying phase). A response with no positive deflection yields nan for both; a
    response that does not decay below 50% of its peak before the end of the simulation yields a nan decay value.
    If fewer descriptions than files are given, the extra files are plotted without a legend label.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    def _rise_and_halfwidth(t, vm):
        """
        Return (10-90% rise time, half-width) of a baseline-subtracted trace vm sampled on time axis t (same length).
        """
        if len(vm) == 0:
            return np.nan, np.nan
        t_peak = np.argmax(vm)  # first index of the maximum
        amp = vm[t_peak]
        if not amp > 0.:
            # no depolarizing response: normalizing by amp would be meaningless
            return np.nan, np.nan
        vm = vm / amp
        # include the peak sample itself (== 1.) so every rise threshold is guaranteed to be crossed
        rising = vm[:t_peak+1]
        rise_10 = np.where(rising >= 0.1)[0][0]
        rise_50 = np.where(rising >= 0.5)[0][0]
        rise_90 = np.where(rising >= 0.9)[0][0]
        rise_tau = t[rise_90] - t[rise_10]
        decay_50 = np.where(vm[t_peak:] <= 0.5)[0]
        if len(decay_50) == 0:
            return rise_tau, np.nan
        return rise_tau, t[t_peak+decay_50[0]] - t[rise_50]

    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch', 'spine']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc not in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    fig1, axes1 = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    fig2, axes2 = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        distances = {input_loc: [] for input_loc in input_locs}
        rise_taus = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        decay_taus = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # grid and windows depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.001)
            left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
            start, end = time2index(interp_t, equilibrate, duration)
            # time axis of the response window, re-referenced so that it starts at 0
            window_t = interp_t[start:end] - interp_t[start]
            for sim in f.itervalues():
                tvec = sim['time'][:]
                input_loc = sim.attrs['input_loc']
                # x-coordinate is read from recording '3' (presumably the spine recording - TODO confirm)
                distances[input_loc].append(sim['rec']['3'].attrs['branch_distance'])
                for rec in sim['rec'].itervalues():
                    rec_loc = rec.attrs['description']
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    # baseline from the interpolated trace, consistent with the other EPSP analyses in this file
                    baseline = np.average(interp_vm[left:right])
                    rise_tau, decay_tau = _rise_and_halfwidth(window_t, interp_vm[start:end] - baseline)
                    rise_taus[input_loc][rec_loc].append(rise_tau)
                    decay_taus[input_loc][rec_loc].append(decay_tau)
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                axes1[i][j].scatter(distances[input_loc], rise_taus[input_loc][rec_loc], color=color, label=label)
                axes1[i][j].set_xlabel('Distance from Dendrite Origin (um)')
                axes1[i][j].set_ylabel('Spine Location: '+input_loc+'\nEPSP Rise Tau (ms)')
                axes1[i][j].set_title('Recording Loc: '+rec_loc)
                axes2[i][j].scatter(distances[input_loc], decay_taus[input_loc][rec_loc], color=color, label=label)
                axes2[i][j].set_xlabel('Distance from Dendrite Origin (um)')
                axes2[i][j].set_ylabel('Spine Location: '+input_loc+'\nEPSP Decay Tau (ms)')
                axes2[i][j].set_title('Recording Loc: '+rec_loc)
    if not description_list == [""]:
        axes1[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
        axes2[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig1.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    fig2.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if title is not None:
        fig1.set_size_inches(20.8, 15.6)
        fig1.savefig(data_dir+title+' - EPSP attenuation - rise.svg', format='svg')
        fig2.set_size_inches(20.8, 15.6)
        fig2.savefig(data_dir+title+' - EPSP attenuation - decay.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_av_vm(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_attenuation.
    Files contain simultaneous voltage recordings from 4 locations (soma, trunk, branch, spine) during single spine
    stimulation. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft). This method subgroups
    trunk and apical sections as proximal or distal.
    Produces one figure containing a grid of 24 plots of EPSP vs. time, averaging all responses with the same input and
    recording location.
    Superimposes results from multiple files in list.
    Trunk inputs use the spine recording's 'soma_distance'; apical inputs use 'soma_distance' - 'branch_distance'.
    Inputs at <= 150 um are 'prox', otherwise 'dist'. A subgroup with no inputs in a file leaves its axes empty.
    If fewer descriptions than files are given, the extra files are plotted without a legend label.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['basal', 'trunk_prox', 'trunk_dist', 'apical_prox', 'apical_dist', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch', 'spine']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc in ['trunk', 'apical']:
                sub_locs = [input_loc+'_prox', input_loc+'_dist']
            else:
                sub_locs = [input_loc]
            # check each (possibly suffixed) name, so trunk/apical subgroups are not appended once per sim
            for sub_loc in sub_locs:
                if sub_loc not in temp_input_locs:
                    temp_input_locs.append(sub_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        # average_vm[input_loc][rec_loc] = {'count': number of summed traces, 'trace': running sum}
        average_vm = {input_loc: {} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # grid and windows depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.001)
            left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
            start, end = time2index(interp_t, equilibrate-5.0, duration)
            # plotted time axis: window starts at -5 ms, so 0 ms is approximately equilibrate
            window_t = interp_t[start:end] - (interp_t[start] + 5.)
            for sim in f.itervalues():
                tvec = sim['time'][:]
                input_loc = sim.attrs['input_loc']
                if input_loc in ['trunk', 'apical']:
                    # recording '3' is presumably the spine recording - TODO confirm
                    spine_rec = sim['rec']['3'].attrs
                    if input_loc == 'trunk':
                        distance = spine_rec['soma_distance']
                    else:
                        distance = spine_rec['soma_distance'] - spine_rec['branch_distance']
                    input_loc += '_prox' if distance <= 150. else '_dist'
                for rec in sim['rec'].itervalues():
                    rec_loc = rec.attrs['description']
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    baseline = np.average(interp_vm[left:right])
                    window_vm = interp_vm[start:end] - baseline  # a new array, safe to accumulate into
                    if rec_loc in average_vm[input_loc]:
                        average_vm[input_loc][rec_loc]['count'] += 1
                        average_vm[input_loc][rec_loc]['trace'] += window_vm
                    else:
                        average_vm[input_loc][rec_loc] = {'count': 1, 'trace': window_vm}
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                # a subgroup (e.g. distal trunk) may contain no spines in this file
                if rec_loc in average_vm[input_loc]:
                    summed = average_vm[input_loc][rec_loc]
                    axes[i][j].plot(window_t, summed['trace'] / summed['count'], color=color, label=label)
                axes[i][j].set_xlabel('Time (ms)', fontsize='x-large')
            axes[i][0].set_ylabel('Spine Location:\n'+input_loc+'\nEPSP (mV)', fontsize='xx-large')
        for j, rec_loc in enumerate(rec_locs):
            axes[0][j].set_title('Recording Loc: '+rec_loc)
    if not description_list == [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - EPSP average traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_i_attenuation(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_i_attenuation.
    Files contain simultaneous voltage recordings from 3 locations (soma, trunk, branch) during stimulation of a single
    branch with an EPSC-shaped current injection. Stimulated sec_types include (soma, basal, trunk, apical, tuft).
    Produces one figure containing a grid of 15 plots of EPSP amplitude vs. distance from soma.
    Superimposes results from multiple files in list.
    EPSP amplitude = peak of the interpolated trace between equilibrate and duration, minus the mean of the
    interpolated trace between (equilibrate - 3) and (equilibrate - 1) ms. The x-coordinate of each simulation is the
    'soma_distance' of its 'branch' recording.
    If fewer descriptions than files are given, the extra files are plotted without a legend label.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['soma', 'basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc not in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        distances = {input_loc: [] for input_loc in input_locs}
        amps = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # grid and windows depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.001)
            left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
            start, end = time2index(interp_t, equilibrate, duration)
            for sim in f.itervalues():
                tvec = sim['time'][:]
                input_loc = sim.attrs['input_loc']
                for rec in sim['rec'].itervalues():
                    rec_loc = rec.attrs['description']
                    if rec_loc == 'branch':
                        distances[input_loc].append(rec.attrs['soma_distance'])
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    baseline = np.average(interp_vm[left:right])
                    amps[input_loc][rec_loc].append(np.max(interp_vm[start:end]) - baseline)
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                axes[i][j].scatter(distances[input_loc], amps[input_loc][rec_loc], color=color, label=label)
                axes[i][j].set_xlabel('Distance from Soma (um)', fontsize='xx-large')
            axes[i][0].set_ylabel('Input Loc: '+input_loc+'\nEPSP Amp (mV)', fontsize='xx-large')
        for j, rec_loc in enumerate(rec_locs):
            axes[0][j].set_title('Recording Loc: '+rec_loc)
    if not description_list == [""]:
        axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig.subplots_adjust(hspace=0.45, wspace=0.25, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - EPSP_i attenuation.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_i_kinetics(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_i_attenuation.
    Files contain simultaneous voltage recordings from 3 locations (soma, trunk, branch) during stimulation of a single
    branch with an EPSC-shaped current injection. Stimulated sec_types include (soma, basal, trunk, apical, tuft).
    Produces a grid of 15 plots of EPSP kinetics vs. distance from soma.
    Produces one figure each for rise kinetics and decay kinetics.
    Superimposes results from multiple files in list.
    Rise kinetics = 10-90% rise time. Decay kinetics = width at half maximum (time from the 50% point of the rising
    phase to the 50% point of the decaying phase). A response with no positive deflection yields nan for both; a
    response that does not decay below 50% of its peak before the end of the simulation yields a nan decay value.
    If fewer descriptions than files are given, the extra files are plotted without a legend label.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    def _rise_and_halfwidth(t, vm):
        """
        Return (10-90% rise time, half-width) of a baseline-subtracted trace vm sampled on time axis t (same length).
        """
        if len(vm) == 0:
            return np.nan, np.nan
        t_peak = np.argmax(vm)  # first index of the maximum
        amp = vm[t_peak]
        if not amp > 0.:
            # no depolarizing response: normalizing by amp would be meaningless
            return np.nan, np.nan
        vm = vm / amp
        # include the peak sample itself (== 1.) so every rise threshold is guaranteed to be crossed
        rising = vm[:t_peak+1]
        rise_10 = np.where(rising >= 0.1)[0][0]
        rise_50 = np.where(rising >= 0.5)[0][0]
        rise_90 = np.where(rising >= 0.9)[0][0]
        rise_tau = t[rise_90] - t[rise_10]
        decay_50 = np.where(vm[t_peak:] <= 0.5)[0]
        if len(decay_50) == 0:
            return rise_tau, np.nan
        return rise_tau, t[t_peak+decay_50[0]] - t[rise_50]

    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['soma', 'basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc not in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    fig1, axes1 = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    fig2, axes2 = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        distances = {input_loc: [] for input_loc in input_locs}
        rise_taus = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        decay_taus = {input_loc: {rec_loc: [] for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # grid and windows depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.001)
            left, right = time2index(interp_t, equilibrate-3.0, equilibrate-1.0)
            start, end = time2index(interp_t, equilibrate, duration)
            # time axis of the response window, re-referenced so that it starts at 0
            window_t = interp_t[start:end] - interp_t[start]
            for sim in f.itervalues():
                tvec = sim['time'][:]
                input_loc = sim.attrs['input_loc']
                for rec in sim['rec'].itervalues():
                    rec_loc = rec.attrs['description']
                    if rec_loc == 'branch':
                        distances[input_loc].append(rec.attrs['soma_distance'])
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    baseline = np.average(interp_vm[left:right])
                    rise_tau, decay_tau = _rise_and_halfwidth(window_t, interp_vm[start:end] - baseline)
                    rise_taus[input_loc][rec_loc].append(rise_tau)
                    decay_taus[input_loc][rec_loc].append(decay_tau)
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                axes1[i][j].scatter(distances[input_loc], rise_taus[input_loc][rec_loc], color=color, label=label)
                axes1[i][j].set_xlabel('Distance from Soma (um)', fontsize='xx-large')
                axes2[i][j].scatter(distances[input_loc], decay_taus[input_loc][rec_loc], color=color, label=label)
                axes2[i][j].set_xlabel('Distance from Soma (um)', fontsize='xx-large')
            axes1[i][0].set_ylabel('Input Loc: '+input_loc+'\nEPSP Rise (ms)', fontsize='xx-large')
            axes2[i][0].set_ylabel('Input Loc: '+input_loc+'\nEPSP Decay (ms)', fontsize='xx-large')
        for j, rec_loc in enumerate(rec_locs):
            axes1[0][j].set_title('Recording Loc: '+rec_loc)
            axes2[0][j].set_title('Recording Loc: '+rec_loc)
    if not description_list == [""]:
        axes1[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
        axes2[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig1.subplots_adjust(hspace=0.45, wspace=0.25, left=0.05, right=0.95, top=0.95, bottom=0.05)
    fig2.subplots_adjust(hspace=0.45, wspace=0.25, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig1.set_size_inches(20.8, 15.6)
        fig1.savefig(data_dir+title+' - EPSP_i attenuation - rise.svg', format='svg')
        fig2.set_size_inches(20.8, 15.6)
        fig2.savefig(data_dir+title+' - EPSP_i attenuation - decay.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_i_vm(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_i_attenuation.
    Files contain simultaneous voltage recordings from 3 locations (soma, trunk, branch) during stimulation of a single
    branch with an EPSC-shaped current injection. Stimulated sec_types include (soma, basal, trunk, apical, tuft).
    Produces a grid of 15 plots of vm vs time, superimposing all responses with the same input and recording location.
    Superimposes results from multiple files in list.
    Each trace spans (equilibrate - 5) to (equilibrate + 50) ms on a 0.01 ms grid, with the plotted time axis shifted
    so that it starts at -5 ms. Files without a description get no legend entry.
    :param rec_file_list: list of str
    :param description_list: list of str
    :param title: str
    """
    if not isinstance(rec_file_list, list):
        rec_file_list = [rec_file_list]
    if not isinstance(description_list, list):
        description_list = [description_list]
    default_input_locs = ['soma', 'basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    # discover which input and recording locations are present, using the first file as the template
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc not in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if rec_loc not in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if input_loc not in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if rec_loc not in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    # traces are plotted without labels; the legend is built from one proxy line per described file
    label_handles = []
    for index, rec_filename in enumerate(rec_file_list):
        color = colors[index % len(colors)]
        # the default description_list ("") has only one entry; don't index past its end
        label = description_list[index] if index < len(description_list) else ''
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # grid, window and shifted time axis depend only on per-file attributes: compute them once
            interp_t = np.arange(0., duration, 0.01)
            left, right = time2index(interp_t, equilibrate-5., equilibrate+50.)
            plot_t = interp_t[left:right] - (interp_t[left] + 5.)
            for sim in f.itervalues():
                tvec = sim['time'][:]
                i = input_locs.index(sim.attrs['input_loc'])
                for rec in sim['rec'].itervalues():
                    j = rec_locs.index(rec.attrs['description'])
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    axes[i][j].plot(plot_t, interp_vm[left:right], color=color)
        for i, input_loc in enumerate(input_locs):
            for j, rec_loc in enumerate(rec_locs):
                axes[i][j].set_xlabel('Time (ms)')
                axes[i][j].set_ylabel('Input Loc: '+input_loc+'\nVm (mV)')
                axes[i][j].set_title('Recording Loc: '+rec_loc)
        if label:
            label_handles.append(mlines.Line2D([], [], color=color, label=label))
    if not description_list == [""]:
        axes[0][0].legend(handles=label_handles, framealpha=0.5, frameon=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(19.2, 12)
        fig.savefig(data_dir+title+' - EPSP_i attenuation - traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_EPSP_i_av_vm(rec_file_list, description_list="", title=None):
    """
    Expects each file in list to be generated by parallel_EPSP_i_attenuation.
    Files contain simultaneous voltage recordings from 3 locations (soma, trunk, branch) during stimulation of a single
    branch with an EPSC-shaped current injection. Stimulated sec_types include (basal, trunk, apical, tuft). This method
    subgroups trunk and apical sections as proximal or distal.
    Produces a grid of 18 plots of EPSP_i vs time, averaging all responses with the same input and recording location.
    Each file contributes one averaged trace per subplot, covering 5 ms before to 50 ms after the file's 'equilibrate'
    time, with the time axis shifted so that the window starts at -5 ms.
    Superimposes results from multiple files in list.
    NOTE(review): the '_prox'/'_dist' subgrouping only affects the subplot layout. Responses are still looked up by the
    raw 'input_loc' attr, so files containing 'trunk' or 'apical' inputs raise ValueError at input_locs.index(). The
    criterion that assigns a simulation to proximal vs. distal is not visible here - TODO confirm and implement.
    :param rec_file_list: list of str
    :param description_list: list of str; one legend label per file (missing entries default to "")
    :param title: str
    """
    if not type(rec_file_list) == list:
        rec_file_list = [rec_file_list]
    if not type(description_list) == list:
        description_list = [description_list]
    # pad so every file has a label; the default "" would otherwise cover only the first file
    descriptions = [description_list[k] if k < len(description_list) else "" for k in range(len(rec_file_list))]
    default_input_locs = ['basal', 'trunk_prox', 'trunk_dist', 'apical_prox', 'apical_dist', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['input_loc']
            if input_loc in ['trunk', 'apical']:
                sub_locs = [input_loc+'_prox', input_loc+'_dist']
            else:
                sub_locs = [input_loc]
            for sub_loc in sub_locs:
                if not sub_loc in temp_input_locs:
                    temp_input_locs.append(sub_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if not rec_loc in temp_rec_locs:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
                 [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    # at least a 2 x 2 grid so that axes is always indexable as axes[row][col]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    label_handles = []
    for index, rec_filename in enumerate(rec_file_list):
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            equilibrate = f['0'].attrs['equilibrate']
            duration = f['0'].attrs['duration']
            # every simulation in a file is interpolated onto the same grid (duration read from simulation '0'), so
            # responses sharing an (input, recording) location can be averaged sample by sample
            interp_t = np.arange(0., duration, 0.01)
            left, right = time2index(interp_t, equilibrate-5., equilibrate+50.)
            plot_t = interp_t[left:right] - (interp_t[left] + 5.)
            vm_sums = {}
            vm_counts = {}
            for sim in f.itervalues():
                tvec = sim['time'][:]
                # NOTE(review): raw 'trunk'/'apical' are not in input_locs (see docstring) -> ValueError here
                i = input_locs.index(sim.attrs['input_loc'])
                for rec in sim['rec'].itervalues():
                    j = rec_locs.index(rec.attrs['description'])
                    interp_vm = np.interp(interp_t, tvec, rec[:])
                    if (i, j) in vm_sums:
                        vm_sums[(i, j)] += interp_vm
                        vm_counts[(i, j)] += 1
                    else:
                        vm_sums[(i, j)] = interp_vm
                        vm_counts[(i, j)] = 1
            for (i, j), vm_sum in vm_sums.items():
                mean_vm = vm_sum / vm_counts[(i, j)]
                axes[i][j].plot(plot_t, mean_vm[left:right], color=colors[index])
        label_handles.append(mlines.Line2D([], [], color=colors[index], label=descriptions[index]))
    for i, input_loc in enumerate(input_locs):
        for j, rec_loc in enumerate(rec_locs):
            axes[i][j].set_xlabel('Time (ms)')
            axes[i][j].set_ylabel('Input Loc:\n'+input_loc+'\nVm (mV)')
            axes[i][j].set_title('Recording Loc: '+rec_loc)
    if any(descriptions):
        axes[0][0].legend(loc='best', handles=label_handles, framealpha=0.5, frameon=False)
    fig.subplots_adjust(hspace=0.5, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    if not title is None:
        fig.set_size_inches(19.2, 12)
        fig.savefig(data_dir+title+' - EPSP_i average traces.svg', format='svg')
    plt.show()
    plt.close()


def plot_synaptic_parameter(rec_file_list, description_list=None):
    """
    Expects each file in list to be generated by optimize_EPSP_amp.
    Files contain one group for each type of dendritic section. Groups contain distances from soma and values for all
    measured synaptic parameters. Produces one column of plots per sec_type, one row of plots per parameter, and
    superimposes data from each rec_file.
    The parameter names are taken from the first group of the first file (every dataset except 'distances'). Each
    file's attrs must provide 'syn_type' and one string per parameter name, which are combined into the y-axis label.
    :param rec_file_list: list of str (a single str is also accepted)
    :param description_list: list of str (a single str is also accepted); one legend label per file
    """
    if not type(rec_file_list) == list:
        rec_file_list = [rec_file_list]
    if description_list is None:
        description_list = [" " for rec in rec_file_list]
    elif not type(description_list) == list:
        # a bare string would otherwise be indexed one character per file
        description_list = [description_list]
    with h5py.File(data_dir+rec_file_list[0]+'.hdf5', 'r') as f:
        param_list = [dataset for dataset in f.itervalues().next() if not dataset == 'distances']
        # at least a 2 x 2 grid so that axes is always indexable as axes[row][col]
        fig, axes = plt.subplots(max(2, len(param_list)), max(2, len(f)))
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(rec_file_list):
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            for i, sec_type in enumerate(f):
                for j, dataset in enumerate(param_list):
                    axes[j][i].scatter(f[sec_type]['distances'][:], f[sec_type][dataset][:],
                                       label=description_list[index], color=colors[index])
                    axes[j][i].set_title(sec_type+' spines')
                    axes[j][i].set_xlabel('Distance to soma (um)')
                    axes[j][i].set_ylabel(f.attrs['syn_type']+': '+dataset+'\n'+f.attrs[dataset])
    # attach the legend to a populated subplot: plt.legend() would target the last-created (bottom-right) axes, which
    # stays empty whenever the grid was padded up to 2 x 2
    axes[0][0].legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5)
    fig.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.98, top=0.95, bottom=0.05)
    plt.show()
    plt.close()


def plot_synaptic_param_distribution(cell, syn_type, param_name, scale_factor=1., param_label=None,
                                 ylabel='Peak conductance', yunits='uS', svg_title=None):
    """
    Scatter one synaptic parameter against distance from the root for every spine synapse of a cell.

    Works directly on an already-specified cell, so no simulation or data file is needed. It is a quick visual check
    of how a synaptic parameter was assigned. Synapses on basal branches are drawn at negative distances. Dendritic
    types without any matching synapse are left out of the plot and legend.
    :param cell: :class:'HocCell'
    :param syn_type: str; only synapses whose _syn contains this key are included
    :param param_name: str; attribute read from syn.target(syn_type)
    :param scale_factor: float; multiplies every value before plotting
    :param param_label: str; optional plot title, also used to build the saved filename
    :param ylabel: str
    :param yunits: str
    :param svg_title: str; if given, font size is set to 20 while plotting and the figure is saved as svg to data_dir
    """
    palette = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    saving = svg_title is not None
    if saving:
        previous_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    fig, ax = plt.subplots(1)
    plotted_vals = []
    for color, sec_type in zip(palette, ['basal', 'trunk', 'apical', 'tuft']):
        candidates = []
        for branch in cell.get_nodes_of_subtype(sec_type):
            for spine in branch.spines:
                candidates.extend(spine.synapses)
        xs, ys = [], []
        for syn in candidates:
            if syn_type not in syn._syn:
                continue
            dist = cell.get_distance_to_node(cell.tree.root, syn.node.parent.parent, syn.loc)
            xs.append(-dist if sec_type == 'basal' else dist)
            ys.append(getattr(syn.target(syn_type), param_name) * scale_factor)
        if ys:
            ax.scatter(xs, ys, color=color, label=sec_type)
            plotted_vals.extend(ys)
    ax.set_ylabel(ylabel + ' (' + yunits + ')')
    if plotted_vals:
        top, bottom = max(plotted_vals), min(plotted_vals)
        margin = 0.1 * (top - bottom)
        ax.set_ylim(bottom - margin, top + margin)
    ax.set_xlabel('Distance to soma (um)')
    ax.set_xlim(-200., 525.)
    ax.set_xticks([-150., 0., 150., 300., 450.])
    plt.legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    if param_label is not None:
        plt.title(param_label, fontsize=mpl.rcParams['font.size'])
    clean_axes(ax)
    ax.tick_params(direction='out')
    if saving:
        if param_label is None:
            filename = svg_title+' - '+syn_type+'_'+param_name+' distribution.svg'
        else:
            filename = svg_title+' - '+param_label+'.svg'
        fig.set_size_inches(5.27, 4.37)
        fig.savefig(data_dir+filename, format='svg', transparent=True)
    plt.show()
    plt.close()
    if saving:
        mpl.rcParams['font.size'] = previous_font_size


def plot_mech_param_distribution(cell, mech_name, param_name, scale_factor=10000., param_label=None,
                                 ylabel='Conductance density', yunits='pS/um2', svg_title=None):
    """
    Scatter one mechanism parameter against distance from the root for every dendritic segment of a cell.

    Works directly on an already-specified cell, so no simulation or data file is needed. It is a quick visual check
    of how a mechanism parameter was distributed. Only segments where mech_name is present (hasattr on the segment)
    are included. Basal segments are drawn at negative distances. Dendritic types with no such segment are left out of
    the plot and legend.
    :param cell: :class:'HocCell'
    :param mech_name: str; mechanism looked up on each segment
    :param param_name: str; attribute read from that mechanism
    :param scale_factor: float; multiplies every value before plotting
    :param param_label: str; optional plot title, also used to build the saved filename
    :param ylabel: str
    :param yunits: str
    :param svg_title: str; if given, font size is set to 20 while plotting and the figure is saved as svg to data_dir
    """
    palette = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    saving = svg_title is not None
    if saving:
        previous_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    fig, ax = plt.subplots(1)
    plotted_vals = []
    for color, sec_type in zip(palette, ['basal', 'trunk', 'apical', 'tuft']):
        xs, ys = [], []
        for branch in cell.get_nodes_of_subtype(sec_type):
            for seg in branch.sec:
                if not hasattr(seg, mech_name):
                    continue
                dist = cell.get_distance_to_node(cell.tree.root, branch, seg.x)
                xs.append(-dist if sec_type == 'basal' else dist)
                ys.append(getattr(getattr(seg, mech_name), param_name) * scale_factor)
        if ys:
            ax.scatter(xs, ys, color=color, label=sec_type)
            plotted_vals.extend(ys)
    ax.set_xlabel('Distance to soma (um)')
    ax.set_xlim(-200., 525.)
    ax.set_xticks([-150., 0., 150., 300., 450.])
    ax.set_ylabel(ylabel+' ('+yunits+')')
    if plotted_vals:
        top, bottom = max(plotted_vals), min(plotted_vals)
        margin = 0.1 * (top - bottom)
        ax.set_ylim(bottom-margin, top+margin)
    if param_label is not None:
        plt.title(param_label, fontsize=mpl.rcParams['font.size'])
    plt.legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    clean_axes(ax)
    ax.tick_params(direction='out')
    if saving:
        if param_label is None:
            filename = svg_title+' - '+mech_name+'_'+param_name+' distribution.svg'
        else:
            filename = svg_title+' - '+param_label+'.svg'
        fig.set_size_inches(5.27, 4.37)
        fig.savefig(data_dir + filename, format='svg', transparent=True)
    plt.show()
    plt.close()
    if saving:
        mpl.rcParams['font.size'] = previous_font_size


def plot_sum_mech_param_distribution(cell, mech_param_list, scale_factor=10000., param_label=None,
                                 ylabel='Conductance density', yunits='pS/um2', svg_title=None):
    """
    Takes a cell as input rather than a file. No simulation is required. This method takes a fully specified cell and,
    for every dendritic segment, plots the SUM of the specified mechanism parameters vs. distance from the root. Used
    while debugging specification of mechanism parameters.
    A segment is included if at least one listed mechanism is present on it; mechanisms absent from a segment add
    nothing to its sum. Basal segments are plotted at negative distances. Dendritic types with no qualifying segment
    are left out of the plot and legend instead of raising.
    :param cell: :class:'HocCell'
    :param mech_param_list: list of tuple of str; (mech_name, param_name) pairs whose values are summed
    :param scale_factor: float; applied to each term of the sum
    :param param_label: str; optional plot title, also used to build the saved filename
    :param ylabel: str
    :param yunits: str
    :param svg_title: str; if given, font size is set to 20 while plotting and the figure is saved as svg to data_dir
    """
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    dend_types = ['basal', 'trunk', 'apical', 'tuft']

    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    fig, axes = plt.subplots(1)
    maxval, minval = None, None
    for i, sec_type in enumerate(dend_types):
        distances = []
        param_vals = []
        for branch in cell.get_nodes_of_subtype(sec_type):
            for seg in branch.sec:
                this_param_val = 0.
                this_distance = None
                for mech_name, param_name in mech_param_list:
                    if hasattr(seg, mech_name):
                        # distance is computed once, and only for segments carrying at least one listed mechanism
                        if this_distance is None:
                            this_distance = cell.get_distance_to_node(cell.tree.root, branch, seg.x)
                            if sec_type == 'basal':
                                this_distance *= -1
                        this_param_val += getattr(getattr(seg, mech_name), param_name) * scale_factor
                if this_distance is not None:
                    distances.append(this_distance)
                    param_vals.append(this_param_val)
        if not param_vals:
            # no segment of this type carries any listed mechanism; max()/min() of an empty list would raise
            continue
        axes.scatter(distances, param_vals, color=colors[i], label=sec_type)
        maxval = max(param_vals) if maxval is None else max(maxval, max(param_vals))
        minval = min(param_vals) if minval is None else min(minval, min(param_vals))
    axes.set_xlabel('Distance to soma (um)')
    axes.set_xlim(-200., 525.)
    axes.set_xticks([-150., 0., 150., 300., 450.])
    axes.set_ylabel(ylabel+' ('+yunits+')')
    if (maxval is not None) and (minval is not None):
        buffer = 0.1 * (maxval - minval)
        axes.set_ylim(minval-buffer, maxval+buffer)
    if param_label is not None:
        plt.title(param_label, fontsize=mpl.rcParams['font.size'])
    plt.legend(loc='best', scatterpoints=1, frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    clean_axes(axes)
    axes.tick_params(direction='out')
    if not svg_title is None:
        if param_label is not None:
            svg_title = svg_title+' - '+param_label+'.svg'
        else:
            # the filename is derived from the first (mech_name, param_name) pair only
            mech_name, param_name = mech_param_list[0]
            svg_title = svg_title+' - '+mech_name+'_'+param_name+' distribution.svg'
        fig.set_size_inches(5.27, 4.37)
        fig.savefig(data_dir + svg_title, format='svg', transparent=True)
    plt.show()
    plt.close()
    if svg_title is not None:
        mpl.rcParams['font.size'] = remember_font_size


def plot_expected_vs_actual_from_raw(expected_filename, actual_file_list, description_list="", location_list=None,
                                     title=None):
    """
    Expects each file in actual_file_list to be generated by parallel_clustered_ or
    parallel_distributed_branch_cooperativity. Files contain simultaneous voltage recordings from 3 locations (soma,
    trunk, dendrite origin) during synchronous stimulation of branches. Spines are distributed across 4 dendritic
    sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs) of expected EPSP amplitude
    vs. actual EPSP amplitude.
    Superimposes results from multiple branches in one color.
    Superimposes results from multiple files in list using different colors.
    Rows and columns are determined from the first file in actual_file_list; simulations in later files whose
    path_type does not appear in the first file are ignored.
    :param expected_filename: str
    :param actual_file_list: list of str
    :param description_list: list of str; one legend label per file (missing entries default to "")
    :param location_list: list of str; recording locations to plot (default ['soma', 'branch'])
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if not type(description_list) == list:
        description_list = [description_list]
    # pad so every file has a label; the default "" would otherwise cover only the first file
    descriptions = [description_list[k] if k < len(description_list) else "" for k in range(len(actual_file_list))]
    if location_list is None:
        location_list = ['soma', 'branch']
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
                 [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    # at least a 2 x 2 grid so that axes is always indexable as axes[row][col]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    label_handles = []
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    with h5py.File(data_dir+expected_filename+'.hdf5', 'r') as expected_file:
        expected_index_map = get_expected_spine_index_map(expected_file)
        for index, rec_filename in enumerate(actual_file_list):
            with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
                path_indexes = {input_loc: [] for input_loc in input_locs}
                for sim in actual_file.itervalues():
                    path_index = sim.attrs['path_index']
                    path_type = sim.attrs['path_type']
                    # test membership first: a path_type missing from the first file has no entry (and no row)
                    if path_type in path_indexes and not path_index in path_indexes[path_type]:
                        path_indexes[path_type].append(path_index)
                for i, input_loc in enumerate(input_locs):
                    for path_index in path_indexes[input_loc]:
                        # simulations of this branch, ordered by number of stimulated synapses
                        sorted_sim_keys = [key for key in actual_file if actual_file[key].attrs['path_index'] ==
                                                                                                        path_index]
                        sorted_sim_keys.sort(key=lambda x: len(actual_file[x].attrs['syn_indexes']))
                        expected_dict, actual_dict = get_expected_vs_actual(expected_file, actual_file,
                                                                        expected_index_map[path_index], sorted_sim_keys)
                        for j, location in enumerate(rec_locs):
                            axes[i][j].plot(expected_dict[location], actual_dict[location], color=colors[index])
            label_handles.append(mlines.Line2D([], [], color=colors[index], label=descriptions[index]))
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        axes[-1][j].set_xlabel('Expected EPSP Amp (mV)')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\nActual EPSP Amp (mV)')
    if any(descriptions):
        axes[0][0].legend(loc='best', handles=label_handles, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if not title is None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - branch cooperativity.svg', format='svg')
    plt.show()
    plt.close()


def _load_actual_cooperativity_traces(actual_file, path_index, rec_loc, plot_every, max_num, dt):
    """
    Read baseline-subtracted voltage traces for one branch from an open parallel_*_branch_cooperativity output file.

    Simulations whose 'path_index' attr matches are sorted by number of stimulated synapses. Every plot_every-th one,
    up to max_num inputs, is interpolated onto a uniform time base with step dt. Each is baseline-subtracted using the
    window from equilibrate - 3 to equilibrate - 1 ms and truncated to start 10 ms before equilibrate. The recording
    key for rec_loc is resolved from this file's own first simulation.
    :param actual_file: open :class:'h5py.File'
    :param path_index: int
    :param rec_loc: str; matched against each recording's 'description' attr
    :param plot_every: int
    :param max_num: int
    :param dt: float
    :return: tuple (list of array, syn_indexes, equilibrate, duration); the last three are read from the simulation
        with max_num inputs
    """
    sim_keys = [key for key in actual_file if actual_file[key].attrs['path_index'] == path_index]
    sim_keys.sort(key=lambda key: len(actual_file[key].attrs['syn_indexes']))
    last_trial = actual_file[sim_keys[max_num-1]]
    syn_indexes = last_trial.attrs['syn_indexes']
    equilibrate = last_trial.attrs['equilibrate']
    duration = last_trial.attrs['duration']
    first_trial = actual_file.itervalues().next()
    rec_key = (rec for rec in first_trial['rec'] if first_trial['rec'][rec].attrs['description'] == rec_loc).next()
    # round rather than truncate: e.g. 9.99999.../dt must map to the intended sample, not the one before it
    baseline_start = int(round((equilibrate - 3.) / dt))
    baseline_stop = int(round((equilibrate - 1.) / dt))
    trace_start = int(round((equilibrate - 10.) / dt))
    traces = []
    for i in range(0, len(syn_indexes), plot_every):
        trial = actual_file[sim_keys[i]]
        interp_t = np.arange(0., trial.attrs['duration'], dt)
        trace = np.interp(interp_t, trial['time'], trial['rec'][rec_key])
        trace -= np.mean(trace[baseline_start:baseline_stop])
        traces.append(trace[trace_start:])
    return traces, syn_indexes, equilibrate, duration


def plot_expected_vs_actual_traces(expected_filename, actual_filenames, path_index, plot_every=3, max_num=22,
                                   interval=0.3, rec_loc='trunk', titles=None, dt=0.02, svg_title=None):
    """
    Given an output file generated by parallel_clustered_branch_cooperativity, this method produces side-by-side plots
    of the expected linear sum vs. actual summation, with increasing number of stimulated inputs superimposed. Every
    plot_every number of inputs is shown, up to max_num inputs, showing recordings from the specified recording
    location.
    One limitation is that, if providing a list of actual_files, the one with the longest integration window should be
    listed first, as this will set the shared x limit for all plots. The first file also defines the synapse order used
    to build the expected linear sum. Each file uses its own recording key and 'equilibrate' for baseline subtraction.
    :param expected_filename: str
    :param actual_filenames: list of str
    :param path_index: int
    :param plot_every: int
    :param max_num: int
    :param interval: float (inter-stimulus interval used during synchronous stimulation in the actual_file)
    :param rec_loc: str
    :param titles: list of str
    :param dt: float
    :param svg_title: str
    """
    if not type(actual_filenames) == list:
        actual_filenames = [actual_filenames]
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    expected_traces = []
    actual_trace_array = []
    with h5py.File(data_dir+expected_filename+'.hdf5', 'r') as expected_file:
        expected_index_map = get_expected_spine_index_map(expected_file)[path_index]
        expected_equilibrate = expected_file.itervalues().next().attrs['equilibrate']
        expected_duration = expected_file.itervalues().next().attrs['duration']
        with h5py.File(data_dir+actual_filenames[0]+'.hdf5', 'r') as actual_file:
            actual_traces, syn_indexes, equilibrate, duration = \
                _load_actual_cooperativity_traces(actual_file, path_index, rec_loc, plot_every, max_num, dt)
        actual_trace_array.append(actual_traces)
        # the first file sets the shared time base; t starts 10 ms before stimulus onset
        t = np.arange(-10., duration - equilibrate, dt)
        running_expected = np.zeros(len(t))
        unit_start = int(round(2. / dt))
        for i, syn_index in enumerate(syn_indexes):
            expected_dict = get_expected_EPSP(expected_file, expected_index_map[syn_index], expected_equilibrate,
                                              expected_duration, dt)
            unit = expected_dict[rec_loc][unit_start:]
            # input i is offset by i * interval ms from stimulus onset (index 10 ms into t)
            start = int(round((10. + i * interval) / dt))
            stop = min(start + len(unit), len(t))
            if stop > start:
                running_expected[start:stop] += unit[:stop-start]
            if i % plot_every == 0:
                expected_traces.append(np.array(running_expected))
    for actual_filename in actual_filenames[1:]:
        with h5py.File(data_dir+actual_filename+'.hdf5', 'r') as actual_file:
            actual_traces = _load_actual_cooperativity_traces(actual_file, path_index, rec_loc, plot_every, max_num,
                                                              dt)[0]
        actual_trace_array.append(actual_traces)
    fig, axes = plt.subplots(1, 1+len(actual_filenames), sharey=True, sharex=True)
    for i, expected_trace in enumerate(expected_traces):
        axes[0].plot(t, expected_trace, c='grey')
        for j in range(len(actual_filenames)):
            actual_trace = actual_trace_array[j][i]
            # index-based time vector: a float-step np.arange may differ in length from the trace by one sample
            axes[j+1].plot(np.arange(len(actual_trace)) * dt - 10., actual_trace, c=colors[j])
    axes[0].set_title('Expected Linear Summation')
    for j in range(len(actual_filenames)):
        axes[j+1].set_title('Actual Summation')
        axes[j+1].set_xlabel('Time (ms)')
        if titles is not None:
            axes[j+1].legend(handles=[mlines.Line2D([], [], color=colors[j], label=titles[j])], loc='upper right',
                           frameon=False, framealpha=0.5)
    axes[0].set_ylabel('Amplitude (mV)')
    axes[0].set_xlabel('Time (ms)')
    plt.xlim(-10., 150.)
    clean_axes(axes)
    if svg_title is not None:
        plt.savefig(data_dir+svg_title+' - expected vs actual traces - '+str(path_index)+'.svg', format='svg')
    plt.show()
    plt.close()


def plot_expected_vs_actual_from_processed(actual_file_list, description_list=None, location_list=None, x='expected',
                                           title=None):
    """
    Expects each file in actual_file_list to be generated by passing the output of parallel_clustered_ or
    parallel_distributed_branch_cooperativity through export_nmdar_cooperativity. Files contain expected vs. actual
    measurements from simultaneous voltage recordings from 3 locations (soma, trunk, dendrite origin) during synchronous
    stimulation of branches. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs) of expected EPSP amplitude
    (or number of inputs) vs. actual EPSP amplitude.
    Superimposes results from multiple branches in up to three colors per file, separated by path_category: 'proximal',
    'intermediate', or any other category (e.g. 'terminal').
    Superimposes results from multiple files in list using different colors (the palette wraps around if exhausted).
    Rows and columns are determined from the first file; simulations in later files whose path_type does not appear in
    the first file are ignored.
    :param actual_file_list: list of str
    :param description_list: list of str; legend label per file. Entries missing for any file are filled from that
        file's 'description' attr (or "" if absent). The caller's list is never modified.
    :param location_list: list of str; recording locations to plot (default ['soma', 'branch'])
    :param x: str in ['expected', 'number']
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if description_list is None:
        description_list = []
    elif type(description_list) == list:
        # work on a copy so the caller's list is not extended as a side effect
        description_list = list(description_list)
    else:
        description_list = [description_list]
    if location_list is None:
        location_list = ['soma', 'branch']
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec_loc in sim:
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
                 [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    # at least a 2 x 2 grid so that axes is always indexable as axes[row][col]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    # per file: {input_loc: {path_category: legend handle}}
    label_handles = [{input_loc: {} for input_loc in input_locs} for i in range(len(actual_file_list))]
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    # offset within each file's block of three colors; any other category uses offset 2
    category_offsets = {'proximal': 0, 'intermediate': 1}
    for index, rec_filename in enumerate(actual_file_list):
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
            if index >= len(description_list):
                if 'description' in actual_file.attrs:
                    description_list.append(actual_file.attrs['description'])
                else:
                    description_list.append("")
            for sim in actual_file.itervalues():
                input_loc = sim.attrs['path_type']
                if input_loc not in label_handles[index]:
                    # path types absent from the first file have no row in the grid
                    continue
                path_category = sim.attrs['path_category']
                color = colors[(index*3 + category_offsets.get(path_category, 2)) % len(colors)]
                if path_category not in label_handles[index][input_loc]:
                    label_handles[index][input_loc][path_category] = mlines.Line2D([], [], color=color,
                                                        label=description_list[index]+': '+path_category)
                i = input_locs.index(input_loc)
                for j, rec_loc in enumerate(rec_locs):
                    actual = sim[rec_loc]['actual'][:]
                    if x == 'number':
                        axes[i][j].plot(range(1, len(actual)+1), actual, color=color)
                    else:
                        axes[i][j].plot(sim[rec_loc]['expected'][:], actual, color=color)
                    clean_axes(axes[i][j])
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        if x == 'number':
            axes[-1][j].set_xlabel('Number of Inputs')
        else:
            axes[-1][j].set_xlabel('Expected EPSP Amp (mV)')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\nActual EPSP Amp (mV)')
        label_handle = []
        for file_handles in label_handles:
            label_handle.extend(file_handles[input_loc].values())
        axes[i][0].legend(loc='best', handles=label_handle, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if not title is None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - vs ' + x + ' - branch cooperativity.svg', format='svg')
    plt.show()
    plt.close()


def plot_expected_vs_actual_specific_branch(actual_file_list, path_index=None, description_list=None,
                                            rec_loc='trunk', svg_title=None):
    """
    Expects each file in actual_file_list to be generated by passing the output of parallel_clustered_ or
    parallel_distributed_branch_cooperativity through export_nmdar_cooperativity. Files contain expected vs. actual
    measurements from simultaneous voltage recordings from 3 locations (soma, trunk, branch) during synchronous
    stimulation of branches. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure containing one plot of expected EPSP amplitude vs. actual EPSP amplitude corresponding to the
    specified branch.
    Superimposes results from multiple files in list using different colors.
    Legend labels come from description_list. Any file without a caller-supplied entry is labeled with its own
    'description' attribute (or "" if it has none). The caller's description_list is never modified.
    :param actual_file_list: list of str (a single str is also accepted); file names in data_dir without '.hdf5'
    :param path_index: int; the group str(path_index) is read from each file
    :param description_list: list of str, a single str, or None
    :param rec_loc: str; sub-group of the branch group holding the 'expected' and 'actual' datasets
    :param svg_title: str; if not None, the figure is saved to data_dir as an .svg file
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    # Copy so that filling in missing labels never mutates the caller's list. Only files the caller did not describe
    # get a label from the file attributes, which keeps labels aligned with files.
    if description_list is None:
        descriptions = []
    elif not type(description_list) == list:
        descriptions = [description_list]
    else:
        descriptions = list(description_list)
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    fig, axes = plt.subplots(1)
    for index, rec_filename in enumerate(actual_file_list):
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
            if len(descriptions) <= index:
                if 'description' in actual_file.attrs:
                    descriptions.append(actual_file.attrs['description'])
                else:
                    descriptions.append("")
            sim = actual_file[str(path_index)]
            # cycle through colors so that more files than colors does not raise IndexError
            axes.plot(sim[rec_loc]['expected'][:], sim[rec_loc]['actual'][:], color=colors[index % len(colors)],
                      label=descriptions[index])
    axes.set_xlabel('Expected EPSP Amplitude (mV)')  # , fontsize='x-large')
    axes.set_ylabel('Actual EPSP Amplitude (mV)')  # , fontsize='xx-large')
    plt.legend(loc='best', frameon=False, framealpha=0.5)
    clean_axes(axes)
    if svg_title is not None:
        plt.savefig(data_dir+svg_title+' - expected vs actual graph - '+str(path_index)+'.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_contribution_from_raw(expected_filename, actual_file_list, description_list="", location_list=None,
                                     title=None):
    """
    Expects each item in actual_file_list to be a tuple containing the names of files generated by parallel_clustered_
    or parallel_distributed_branch_cooperativity with and without NMDARs. Files contain simultaneous voltage recordings
    from 3 locations (soma, trunk, dendrite origin) during synchronous stimulation of branches. Spines are distributed
    across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs). Each plot shows expected
    EPSP amplitude vs. NMDAR contribution, computed per branch (path_index) as
    (actual_with_nmda - actual_without_nmda) / actual_without_nmda * 100. Each branch in the without-NMDAR file is
    paired with the same branch in the with-NMDAR file; branches missing from the with-NMDAR file are skipped.
    Superimposes results from multiple branches in one color.
    Superimposes results from multiple files in list using different colors.
    Prints the peak contribution for every branch and recording location.
    :param expected_filename: str
    :param actual_file_list: list of tuples of str: (with_nmda_filename, without_nmda_filename)
    :param description_list: list of str; files without an entry get the label ""
    :param location_list: list of str
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if not type(description_list) == list:
        description_list = [description_list]
    # Pad a copy so every file pair has a label. Otherwise indexing by file index raises IndexError whenever fewer
    # descriptions than files are given, which includes the default.
    descriptions = list(description_list) + [""] * (len(actual_file_list) - len(description_list))
    if location_list is None:
        location_list = ['soma', 'branch']
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0][0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]

    def get_path_indexes(rec_file):
        """Return {input_loc: [path_index, ...]} for every sim in rec_file whose path_type is being plotted."""
        path_indexes = {input_loc: [] for input_loc in input_locs}
        for sim in rec_file.itervalues():
            path_index = sim.attrs['path_index']
            path_type = sim.attrs['path_type']
            # test path_type first: types absent from the first file have no entry in path_indexes
            if path_type in input_locs and not path_index in path_indexes[path_type]:
                path_indexes[path_type].append(path_index)
        return path_indexes

    def get_sorted_sim_keys(rec_file, path_index):
        """Return keys of all sims in rec_file for path_index, ordered by number of stimulated synapses."""
        sim_keys = [key for key in rec_file if rec_file[key].attrs['path_index'] == path_index]
        sim_keys.sort(key=lambda key: len(rec_file[key].attrs['syn_indexes']))
        return sim_keys

    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    label_handles = []
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    with h5py.File(data_dir+expected_filename+'.hdf5', 'r') as expected_file:
        expected_index_map = get_expected_spine_index_map(expected_file)
        for index, (with_nmda_filename, without_nmda_filename) in enumerate(actual_file_list):
            color = colors[index % len(colors)]
            # Actual amplitudes with NMDARs, kept per branch so each branch is compared against the same branch
            # recorded without NMDARs. A single variable would only retain the last branch read.
            with_nmda_dicts = {}
            with h5py.File(data_dir+with_nmda_filename+'.hdf5', 'r') as with_nmda_file:
                path_indexes = get_path_indexes(with_nmda_file)
                for input_loc in input_locs:
                    for path_index in path_indexes[input_loc]:
                        sorted_sim_keys = get_sorted_sim_keys(with_nmda_file, path_index)
                        expected_dict, with_nmda_dict = get_expected_vs_actual(expected_file, with_nmda_file,
                                                                               expected_index_map[path_index],
                                                                               sorted_sim_keys)
                        with_nmda_dicts[path_index] = with_nmda_dict
            with h5py.File(data_dir+without_nmda_filename+'.hdf5', 'r') as without_nmda_file:
                path_indexes = get_path_indexes(without_nmda_file)
                for i, input_loc in enumerate(input_locs):
                    for path_index in path_indexes[input_loc]:
                        if path_index not in with_nmda_dicts:
                            # no matching branch was recorded with NMDARs, so there is nothing to compare against
                            continue
                        with_nmda_dict = with_nmda_dicts[path_index]
                        sorted_sim_keys = get_sorted_sim_keys(without_nmda_file, path_index)
                        expected_dict, without_nmda_dict = get_expected_vs_actual(expected_file, without_nmda_file,
                                                                                  expected_index_map[path_index],
                                                                                  sorted_sim_keys)
                        for j, location in enumerate(rec_locs):
                            without_nmda = np.array(without_nmda_dict[location])
                            nmda_contribution = (np.array(with_nmda_dict[location]) - without_nmda) / \
                                                without_nmda * 100.
                            axes[i][j].plot(expected_dict[location], nmda_contribution, color=color)
                            print(np.max(nmda_contribution))
            label_handles.append(mlines.Line2D([], [], color=color, label=descriptions[index]))
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        axes[-1][j].set_xlabel('Expected EPSP Amp (mV)')  # , fontsize='x-large')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\nNMDAR Contribution (%)')  # , fontsize='xx-large')
    # only draw a legend when at least one file has a non-empty label
    if any(descriptions):
        axes[0][0].legend(loc='best', handles=label_handles, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - NMDAR contribution.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_contribution_from_processed(actual_file_list, description_list=None, location_list=None, x='expected',
                                             title=None):
    """
    Expects each tuple in actual_file_list to be generated by passing the output of parallel_clustered_ or
    parallel_distributed_branch_cooperativity, with and without nmdars, through export_nmdar_cooperativity. Files
    contain expected vs. actual measurements from simultaneous voltage recordings from 3 locations (soma, trunk,
    dendrite origin) during synchronous stimulation of branches. Spines are distributed across 4 dendritic sec_types
    (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs) of % contribution of
    NMDARs to EPSP amplitude, computed as (with_nmda - without_nmda) / without_nmda * 100. The parameter 'x' determines
    whether to plot versus expected EPSP amplitude or distance from soma. In 'distance' mode, the peak contribution of
    each branch is scattered. For apical dendrites, this distance is from point of origin along the trunk to soma.
    Superimposes results from multiple branches in up to three colors per file, separating by the path_category
    (proximal, intermediate, other).
    Superimposes results from multiple files in list using different colors.
    :param actual_file_list: list of tuple of str: (with_nmda_filename, without_nmda_filename)
    :param description_list: list of str, a single str, or None. Files without an entry are labeled with the
        'description' attribute of their with-NMDAR file (or ""). The caller's list is not modified.
    :param location_list: list of str
    :param x: str in ['expected', 'distance']
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if location_list is None:
        location_list = ['soma', 'branch']
    if description_list is None:
        descriptions = []
    elif not type(description_list) == list:
        descriptions = [description_list]
    else:
        descriptions = list(description_list)
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0][0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec_loc in sim:
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    label_handles = [{input_loc: {} for input_loc in input_locs} for _ in range(len(actual_file_list))]
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    category_offsets = {'proximal': 0, 'intermediate': 1}

    def get_color(file_index, path_category):
        """Each file owns 3 consecutive colors: proximal, intermediate, then any other path_category (cycled)."""
        return colors[(file_index * 3 + category_offsets.get(path_category, 2)) % len(colors)]

    for index, (with_nmda_filename, without_nmda_filename) in enumerate(actual_file_list):
        # Collected per file, so legend entries are created for every file. This also stops the distance scatter for
        # this file from re-plotting points of previous files in this file's colors.
        distances = {input_loc: {} for input_loc in input_locs}
        peaks = {input_loc: {rec_loc: {} for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+with_nmda_filename+'.hdf5', 'r') as with_nmda_file:
            with h5py.File(data_dir+without_nmda_filename+'.hdf5', 'r') as without_nmda_file:
                if len(descriptions) <= index:
                    if 'description' in with_nmda_file.attrs:
                        descriptions.append(with_nmda_file.attrs['description'])
                    else:
                        descriptions.append("")
                for path_index, sim in with_nmda_file.iteritems():
                    input_loc = sim.attrs['path_type']
                    if input_loc not in input_locs:
                        # the plot grid is derived from the first file; skip branch types it does not contain
                        continue
                    if input_loc == 'apical':
                        distance = sim.attrs['origin_distance']
                    else:
                        distance = sim.attrs['soma_distance']
                    path_category = sim.attrs['path_category']
                    color = get_color(index, path_category)
                    if path_category not in distances[input_loc]:
                        distances[input_loc][path_category] = []
                        for rec_loc in rec_locs:
                            peaks[input_loc][rec_loc][path_category] = []
                        label = descriptions[index]+': '+path_category
                        if x == 'expected':
                            label_handles[index][input_loc][path_category] = mlines.Line2D([], [], color=color,
                                                                                           label=label)
                        else:
                            label_handles[index][input_loc][path_category] = mlines.Line2D(
                                [], [], color='none', marker='o', markeredgecolor=color, markerfacecolor=color,
                                label=label)
                    distances[input_loc][path_category].append(distance)
                    i = input_locs.index(input_loc)
                    for j, rec_loc in enumerate(rec_locs):
                        expected = sim[rec_loc]['expected'][:]
                        with_nmda = sim[rec_loc]['actual'][:]
                        without_nmda = without_nmda_file[path_index][rec_loc]['actual'][:]
                        nmda_contribution = (with_nmda - without_nmda) / without_nmda * 100.
                        peaks[input_loc][rec_loc][path_category].append(np.max(nmda_contribution))
                        if x == 'expected':
                            axes[i][j].plot(expected, nmda_contribution, color=color)
        if x == 'distance':
            for i, input_loc in enumerate(input_locs):
                for j, rec_loc in enumerate(rec_locs):
                    for path_category in distances[input_loc]:
                        axes[i][j].scatter(distances[input_loc][path_category],
                                           peaks[input_loc][rec_loc][path_category],
                                           color=get_color(index, path_category))
    if x == 'distance':
        xlabel = 'Distance to Soma (um)'
        ylabel = 'Peak NMDAR Contribution (%)'
    else:
        xlabel = 'Expected EPSP Amp (mV)'
        ylabel = 'NMDAR Contribution (%)'
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        axes[-1][j].set_xlabel(xlabel)  # , fontsize='x-large')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\n'+ylabel)  # , fontsize='xx-large')
        label_handle = []
        for file_handles in label_handles:
            label_handle.extend(file_handles[input_loc].values())
        axes[i][0].legend(loc='best', handles=label_handle, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - NMDAR contribution.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_supralinearity_from_raw(expected_filename, actual_file_list, description_list="", location_list=None,
                                       title=None):
    """
    Expects each file in actual_file_list to be generated by parallel_clustered_ or
    parallel_distributed_branch_cooperativity. Files contain simultaneous voltage recordings from 3 locations (soma,
    trunk, dendrite origin) during synchronous stimulation of branches. Spines are distributed across 4 dendritic
    sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs) of expected EPSP amplitude
    vs. percentage of actual EPSP amplitude greater than linear, computed as (actual - expected) / expected * 100.
    Superimposes results from multiple branches in one color.
    Superimposes results from multiple files in list using different colors.
    Prints the peak supralinearity for every branch and recording location.
    :param expected_filename: str
    :param actual_file_list: list of str
    :param description_list: list of str; files without an entry get the label ""
    :param location_list: list of str
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if not type(description_list) == list:
        description_list = [description_list]
    # Pad a copy so every file has a label. Otherwise indexing by file index raises IndexError whenever fewer
    # descriptions than files are given, which includes the default.
    descriptions = list(description_list) + [""] * (len(actual_file_list) - len(description_list))
    if location_list is None:
        location_list = ['soma', 'branch']
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec in sim['rec'].itervalues():
                rec_loc = rec.attrs['description']
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    label_handles = []
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    with h5py.File(data_dir+expected_filename+'.hdf5', 'r') as expected_file:
        expected_index_map = get_expected_spine_index_map(expected_file)
        for index, rec_filename in enumerate(actual_file_list):
            color = colors[index % len(colors)]
            with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
                path_indexes = {input_loc: [] for input_loc in input_locs}
                for sim in actual_file.itervalues():
                    path_index = sim.attrs['path_index']
                    path_type = sim.attrs['path_type']
                    # test path_type first: types absent from the first file have no entry in path_indexes
                    if path_type in input_locs and not path_index in path_indexes[path_type]:
                        path_indexes[path_type].append(path_index)
                for i, input_loc in enumerate(input_locs):
                    for path_index in path_indexes[input_loc]:
                        sorted_sim_keys = [key for key in actual_file if
                                           actual_file[key].attrs['path_index'] == path_index]
                        # order sims by number of stimulated synapses
                        sorted_sim_keys.sort(key=lambda key: len(actual_file[key].attrs['syn_indexes']))
                        expected_dict, actual_dict = get_expected_vs_actual(expected_file, actual_file,
                                                                            expected_index_map[path_index],
                                                                            sorted_sim_keys)
                        for j, location in enumerate(rec_locs):
                            expected = np.array(expected_dict[location])
                            actual = np.array(actual_dict[location])
                            supralinearity = (actual - expected) / expected * 100.
                            axes[i][j].plot(expected_dict[location], supralinearity, color=color)
                            print(np.max(supralinearity))
            label_handles.append(mlines.Line2D([], [], color=color, label=descriptions[index]))
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        axes[-1][j].set_xlabel('Expected EPSP Amp (mV)')  # , fontsize='x-large')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\nNMDAR Nonlinearity (%)')  # , fontsize='xx-large')
    # only draw a legend when at least one file has a non-empty label
    if any(descriptions):
        axes[0][0].legend(loc='best', handles=label_handles, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - branch supralinearity.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_supralinearity_from_processed(actual_file_list, description_list=None, location_list=None, x='expected',
                                             title=None):
    """
    Expects each file in actual_file_list to be generated by passing the output of parallel_clustered_ or
    parallel_distributed_branch_cooperativity through export_nmdar_cooperativity. Files contain expected vs. actual
    measurements from simultaneous voltage recordings from 3 locations (soma, trunk, dendrite origin) during synchronous
    stimulation of branches. Spines are distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure containing a grid of up to 12 plots (4 sec_types by 3 recording locs) of percentage of actual
    EPSP amplitude greater than linear, computed as (actual - expected) / expected * 100. The parameter 'x' determines
    whether to plot versus expected EPSP amplitude or distance from soma. In 'distance' mode, the peak supralinearity
    of each branch is scattered. For apical dendrites, this distance is from point of origin along the trunk to soma.
    Superimposes results from multiple branches in up to three colors per file, separating by the path_category
    (proximal, intermediate, other).
    Superimposes results from multiple files in list using different colors.
    :param actual_file_list: list of str
    :param description_list: list of str, a single str, or None. Files without an entry are labeled with their
        'description' attribute (or ""). The caller's list is not modified.
    :param location_list: list of str
    :param x: str in ['expected', 'distance']
    :param title: str
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if location_list is None:
        location_list = ['soma', 'branch']
    if description_list is None:
        descriptions = []
    elif not type(description_list) == list:
        descriptions = [description_list]
    else:
        descriptions = list(description_list)
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    default_rec_locs = ['soma', 'trunk', 'branch']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        temp_rec_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
            for rec_loc in sim:
                if not rec_loc in temp_rec_locs and rec_loc in location_list:
                    temp_rec_locs.append(rec_loc)
    # enforce the default order of input and recording locations for plotting, but allow for adding or subtracting
    # sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    rec_locs = [rec_loc for rec_loc in default_rec_locs if rec_loc in temp_rec_locs]+\
               [rec_loc for rec_loc in temp_rec_locs if not rec_loc in default_rec_locs]
    fig, axes = plt.subplots(max(2, len(input_locs)), max(2, len(rec_locs)))
    label_handles = [{input_loc: {} for input_loc in input_locs} for _ in range(len(actual_file_list))]
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    category_offsets = {'proximal': 0, 'intermediate': 1}

    def get_color(file_index, path_category):
        """Each file owns 3 consecutive colors: proximal, intermediate, then any other path_category (cycled)."""
        return colors[(file_index * 3 + category_offsets.get(path_category, 2)) % len(colors)]

    for index, rec_filename in enumerate(actual_file_list):
        # Collected per file, so legend entries are created for every file. This also stops the distance scatter for
        # this file from re-plotting points of previous files in this file's colors.
        distances = {input_loc: {} for input_loc in input_locs}
        peaks = {input_loc: {rec_loc: {} for rec_loc in rec_locs} for input_loc in input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
            if len(descriptions) <= index:
                if 'description' in actual_file.attrs:
                    descriptions.append(actual_file.attrs['description'])
                else:
                    descriptions.append("")
            for sim in actual_file.itervalues():
                input_loc = sim.attrs['path_type']
                if input_loc not in input_locs:
                    # the plot grid is derived from the first file; skip branch types it does not contain
                    continue
                if input_loc == 'apical':
                    distance = sim.attrs['origin_distance']
                else:
                    distance = sim.attrs['soma_distance']
                path_category = sim.attrs['path_category']
                color = get_color(index, path_category)
                if path_category not in distances[input_loc]:
                    distances[input_loc][path_category] = []
                    for rec_loc in rec_locs:
                        peaks[input_loc][rec_loc][path_category] = []
                    label = descriptions[index]+': '+path_category
                    if x == 'expected':
                        label_handles[index][input_loc][path_category] = mlines.Line2D([], [], color=color,
                                                                                       label=label)
                    else:
                        label_handles[index][input_loc][path_category] = mlines.Line2D(
                            [], [], color='none', marker='o', markeredgecolor=color, markerfacecolor=color,
                            label=label)
                distances[input_loc][path_category].append(distance)
                i = input_locs.index(input_loc)
                for j, rec_loc in enumerate(rec_locs):
                    expected = sim[rec_loc]['expected'][:]
                    actual = sim[rec_loc]['actual'][:]
                    supralinearity = (actual - expected) / expected * 100.
                    peaks[input_loc][rec_loc][path_category].append(np.max(supralinearity))
                    if x == 'expected':
                        axes[i][j].plot(expected, supralinearity, color=color)
        if x == 'distance':
            for i, input_loc in enumerate(input_locs):
                for j, rec_loc in enumerate(rec_locs):
                    for path_category in distances[input_loc]:
                        axes[i][j].scatter(distances[input_loc][path_category],
                                           peaks[input_loc][rec_loc][path_category],
                                           color=get_color(index, path_category))
                        clean_axes(axes[i][j])
    if x == 'distance':
        xlabel = 'Distance to Soma (um)'
        ylabel = 'Peak NMDAR Supralinearity (%)'
    else:
        xlabel = 'Expected EPSP Amp (mV)'
        ylabel = 'NMDAR Nonlinearity (%)'
    for j, location in enumerate(rec_locs):
        axes[0][j].set_title('Recording loc: '+location)
        axes[-1][j].set_xlabel(xlabel)  # , fontsize='x-large')
    for i, input_loc in enumerate(input_locs):
        axes[i][0].set_ylabel('Spine Location: '+input_loc+'\n'+ylabel)  # , fontsize='xx-large')
        label_handle = []
        for file_handles in label_handles:
            label_handle.extend(file_handles[input_loc].values())
        axes[i][0].legend(loc='best', handles=label_handle, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - branch supralinearity.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_conductance_from_raw(actual_file_list, description_list="", object_description='NMDA_g', title=None):
    """
    Expects each file in actual_file_list to be generated by parallel_clustered_ or
    parallel_distributed_branch_cooperativity. Files should contain recordings of NMDA_KIN conductance in single spines
    that were activated first in a sequence of stimulated spines, and labeled with the description 'NMDA_g'. Spines are
    distributed across 4 dendritic sec_types (basal, trunk, apical, tuft).
    Produces one figure with one subplot per sec_type (2 columns, at least 2 rows). Each subplot shows the number of
    activated spines vs. peak NMDA conductance.
    Superimposes results from multiple spines of the same sec_type using the same color.
    Superimposes results from multiple files in list using different colors.
    :param actual_file_list: list of str
    :param description_list: list of str; files without an entry get the label ""
    :param object_description: str; 'description' attribute of the recording whose peak is plotted
    :param title: str
    :raises ValueError: if a simulation contains no recording labeled object_description
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    if not type(description_list) == list:
        description_list = [description_list]
    # Pad a copy so every file has a label. Otherwise indexing by file index raises IndexError whenever fewer
    # descriptions than files are given, which includes the default.
    descriptions = list(description_list) + [""] * (len(actual_file_list) - len(description_list))
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
    # enforce the default order of input locations for plotting, but allow for adding or subtracting sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    # grow the grid with the number of sec_types instead of assuming at most 4 (2x2)
    n_cols = 2
    n_rows = max(2, (len(input_locs) + n_cols - 1) // n_cols)
    fig, axes = plt.subplots(n_rows, n_cols)
    label_handles = []
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    for index, rec_filename in enumerate(actual_file_list):
        color = colors[index % len(colors)]
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
            path_indexes = {input_loc: [] for input_loc in input_locs}
            for sim in actual_file.itervalues():
                path_index = sim.attrs['path_index']
                path_type = sim.attrs['path_type']
                # test path_type first: types absent from the first file have no entry in path_indexes
                if path_type in input_locs and not path_index in path_indexes[path_type]:
                    path_indexes[path_type].append(path_index)
            for i, input_loc in enumerate(input_locs):
                row, col = divmod(i, n_cols)
                for path_index in path_indexes[input_loc]:
                    sorted_sim_keys = [key for key in actual_file if
                                       actual_file[key].attrs['path_index'] == path_index]
                    # order sims by number of stimulated synapses
                    sorted_sim_keys.sort(key=lambda key: len(actual_file[key].attrs['syn_indexes']))
                    peak_conductance = []
                    for key in sorted_sim_keys:
                        sim = actual_file[key]
                        rec = next((rec[:] for rec in sim['rec'].itervalues() if
                                    rec.attrs['description'] == object_description), None)
                        if rec is None:
                            raise ValueError('plot_nmdar_conductance_from_raw: no recording with description %s in '
                                             'file %s, sim %s' % (object_description, rec_filename, key))
                        peak_conductance.append(np.max(rec))
                    axes[row][col].plot(range(1, len(sorted_sim_keys)+1), peak_conductance, color=color)
        label_handles.append(mlines.Line2D([], [], color=color, label=descriptions[index]))
    for i, input_loc in enumerate(input_locs):
        row, col = divmod(i, n_cols)
        # label every populated subplot, not only the last one visited
        axes[row][col].set_xlabel('Input Number')  # , fontsize='x-large')
        axes[row][col].set_ylabel('Spine Location: '+input_loc+'\nNMDAR Conductance (uS)')  # , fontsize='xx-large')
    # only draw a legend when at least one file has a non-empty label
    if any(descriptions):
        axes[0][0].legend(loc='best', handles=label_handles, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    if title is not None:
        fig.set_size_inches(20.8, 15.6)
        fig.savefig(data_dir+title+' - nmdar conductance.svg', format='svg')
    plt.show()
    plt.close()


def plot_nmdar_conductance_from_processed(actual_file_list, description_list=None, object_description='NMDA_g',
                                          x='number', svg_title=None, plot_input_locs=None):
    """
    Expects each file in actual_file_list to be generated by passing the output of parallel_clustered_ or
    parallel_distributed_branch_cooperativity through export_nmdar_cooperativity. Files should contain recordings of
    NMDA_KIN conductance in single spines that were activated first in a sequence of stimulated spines, and labeled with
    the description 'NMDA_g'. Spines are distributed across dendritic sec_types (e.g. basal, trunk, apical, tuft).
    Produces one figure with one row of axes per plotted sec_type (at least 2 rows). The parameter 'x' determines
    whether to plot conductance traces versus input number ('number') or peak conductance versus distance from soma
    ('distance'). For apical dendrites, this distance is the 'origin_distance' attribute (from point of origin along
    the trunk to soma); for other sec_types it is 'soma_distance'.
    Recorded values are multiplied by 1000 and labeled nS (presumably stored in uS - confirm against the exporter).
    Superimposes results from multiple branches in three colors per file, separating by the path_category
    ('proximal', 'intermediate', or any other category, e.g. 'terminal').
    Superimposes results from multiple files in list using different colors.
    :param actual_file_list: list of str (or a single str): hdf5 file names in data_dir, without extension
    :param description_list: list of str: legend label per file; missing entries are read from each file's
        'description' attribute (or ""). The caller's list is not modified.
    :param object_description: str: name of the group in each sim that contains the 'actual' conductance trace
    :param x: str in ['number', 'distance']
    :param svg_title: str: if not None, the figure is saved to data_dir as an svg
    :param plot_input_locs: list of str: allows selecting plotting of subset of branch types
    """
    if not type(actual_file_list) == list:
        actual_file_list = [actual_file_list]
    # work on a private copy so descriptions read from the files are never appended to the caller's list
    if description_list is None:
        description_list = []
    elif not type(description_list) == list:
        description_list = [description_list]
    else:
        description_list = list(description_list)
    default_input_locs = ['basal', 'trunk', 'apical', 'tuft']
    with h5py.File(data_dir+actual_file_list[0]+'.hdf5', 'r') as f:
        temp_input_locs = []
        for sim in f.itervalues():
            input_loc = sim.attrs['path_type']
            if not input_loc in temp_input_locs:
                temp_input_locs.append(input_loc)
    # enforce the default order of input locations for plotting, but allow for adding or subtracting sec_types
    input_locs = [input_loc for input_loc in default_input_locs if input_loc in temp_input_locs]+\
                 [input_loc for input_loc in temp_input_locs if not input_loc in default_input_locs]
    if plot_input_locs is None:
        plot_input_locs = input_locs
    if x == 'distance':
        xlabel = 'Distance to Soma (um)'
        ylabel = 'Peak NMDAR Conductance (nS)'
    else:
        xlabel = 'Input Number'
        ylabel = 'NMDAR Conductance (nS)'
    fig, axes = plt.subplots(max(2, len(plot_input_locs)))
    label_handles = [{input_loc: {} for input_loc in plot_input_locs} for i in range(len(actual_file_list))]
    colors = ['k', 'r', 'c', 'y', 'm', 'g', 'b']
    category_offsets = {'proximal': 0, 'intermediate': 1}

    def get_color(file_index, path_category):
        # three colors per file (proximal / intermediate / any other category); wrap around the palette instead of
        # raising IndexError once more files are plotted than the palette can hold
        return colors[(file_index * 3 + category_offsets.get(path_category, 2)) % len(colors)]

    def category_order(item):
        # legend order: proximal, intermediate, then remaining categories alphabetically
        return category_offsets.get(item[0], 2), item[0]

    for index, rec_filename in enumerate(actual_file_list):
        # collected per file, so that each file's scatter only shows that file's own branches
        distances = {input_loc: {} for input_loc in plot_input_locs}
        peaks = {input_loc: {} for input_loc in plot_input_locs}
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as actual_file:
            if len(description_list) <= index:
                if 'description' in actual_file.attrs:
                    description_list.append(actual_file.attrs['description'])
                else:
                    description_list.append("")
            for sim in actual_file.itervalues():
                input_loc = sim.attrs['path_type']
                if input_loc not in distances:
                    continue  # sec_type not selected for plotting
                if input_loc == 'apical':
                    distance = sim.attrs['origin_distance']
                else:
                    distance = sim.attrs['soma_distance']
                path_category = sim.attrs['path_category']
                color = get_color(index, path_category)
                if path_category not in distances[input_loc]:
                    distances[input_loc][path_category] = []
                    peaks[input_loc][path_category] = []
                    label = description_list[index]+': '+path_category
                    if x == 'number':
                        label_handles[index][input_loc][path_category] = mlines.Line2D([], [], color=color,
                                                                                       label=label)
                    else:
                        label_handles[index][input_loc][path_category] = mlines.Line2D(
                            [], [], color='none', marker='o', markeredgecolor=color, markerfacecolor=color,
                            label=label)
                distances[input_loc][path_category].append(distance)
                i = plot_input_locs.index(input_loc)
                actual = sim[object_description]['actual'][:]*1000.
                peaks[input_loc][path_category].append(np.max(actual))
                if x == 'number':
                    axes[i].plot(range(1, len(actual)+1), actual, color=color)
        if x == 'distance':
            for i, input_loc in enumerate(plot_input_locs):
                for path_category in distances[input_loc]:
                    axes[i].scatter(distances[input_loc][path_category], peaks[input_loc][path_category],
                                    color=get_color(index, path_category))
    axes[0].set_title('Synaptic NMDAR Conductance')
    axes[-1].set_xlabel(xlabel)  # , fontsize='x-large')
    for i, input_loc in enumerate(plot_input_locs):
        axes[i].set_ylabel('Spine Location: '+input_loc+'\n'+ylabel)  # , fontsize='xx-large')
        label_handle = []
        for file_handles in label_handles:
            label_handle.extend(handle for category, handle in sorted(file_handles[input_loc].items(),
                                                                      key=category_order))
        axes[i].legend(loc='best', handles=label_handle, frameon=False, framealpha=0.5)
    plt.subplots_adjust(hspace=0.4, wspace=0.3, left=0.05, right=0.95, top=0.95, bottom=0.05)
    clean_axes(axes)
    if svg_title is not None:
        #fig.set_size_inches(20.8, 15.6)  # 19.2, 12)19.2, 12)
        fig.savefig(data_dir+svg_title+' - NMDAR conductance.svg', format='svg')
    plt.show()
    plt.close()


def process_patterned_input_simulation_input_output(rec_filename, title, svg_title=None):
    """
    Load a patterned-input simulation file (one hdf5 group per trial) and plot, averaged across trials: the summed
    excitatory input rate (plus the summed release success rate when the file contains 'successes'), the summed
    inhibitory input rate, and the smoothed somatic output firing rate. If svg_title is None the figure is shown,
    followed by a histogram of input inter-spike intervals; otherwise the figure is saved to data_dir as an svg.
    :param rec_filename: str: hdf5 file name in data_dir, without extension
    :param title: str
    :param svg_title: str
    :return: tuple of array: (t, mean_input, mean_successes, mean_inh_input, mean_output) when the file contains
        'successes', otherwise (t, mean_input, mean_inh_input, mean_output); t starts after track_equilibrate
    """
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        sim = f.itervalues().next()
        equilibrate = sim.attrs['equilibrate']
        track_equilibrate = sim.attrs['track_equilibrate']
        input_field_duration = sim.attrs['input_field_duration']
        duration = sim.attrs['duration']
        stim_dt = sim.attrs['stim_dt']
        track_duration = duration - equilibrate - track_equilibrate
        stim_t = np.arange(-track_equilibrate, track_duration, stim_dt)
        start = int(track_equilibrate/stim_dt)
        spatial_bin = input_field_duration/50.
        intervals = []
        pop_input = []
        successes = []
        inh_input = []
        output = []
        # whether per-train 'successes' were recorded is decided from the first trial only
        stochastic = 'successes' in sim
        for sim in f.itervalues():
            exc_input_sum = None
            successes_sum = None
            inh_input_sum = None
            for key, train in sim['train'].iteritems():
                this_train = np.array(train)
                intervals.extend(np.diff(this_train))
                this_exc_rate = get_binned_firing_rate(this_train, stim_t)
                if exc_input_sum is None:
                    exc_input_sum = np.array(this_exc_rate)
                else:
                    exc_input_sum = np.add(exc_input_sum, this_exc_rate)
                if stochastic:
                    this_success_rate = get_binned_firing_rate(np.array(sim['successes'][key]), stim_t)
                    if successes_sum is None:
                        successes_sum = np.array(this_success_rate)
                    else:
                        successes_sum = np.add(successes_sum, this_success_rate)
            pop_input.append(exc_input_sum)
            if stochastic:
                successes.append(successes_sum)
            for train in sim['inh_train'].itervalues():
                this_inh_rate = get_binned_firing_rate(np.array(train), stim_t)
                if inh_input_sum is None:
                    inh_input_sum = np.array(this_inh_rate)
                else:
                    inh_input_sum = np.add(inh_input_sum, this_inh_rate)
            inh_input.append(inh_input_sum)
            this_output = get_smoothed_firing_rate(np.array(sim['output']), stim_t[start:], bin_dur=3.*spatial_bin,
                                                   bin_step=spatial_bin, dt=stim_dt)
            output.append(this_output)
        mean_input = np.mean(pop_input, axis=0)
        if stochastic:
            mean_successes = np.mean(successes, axis=0)
        mean_inh_input = np.mean(inh_input, axis=0)
        mean_output = np.mean(output, axis=0)
        fig, axes = plt.subplots(3, 1, sharex=True)
        axes[0].plot(stim_t[start:], mean_input[start:], label='Total Excitatory Input Spike Rate', c='b')
        if stochastic:
            axes[0].plot(stim_t[start:], mean_successes[start:], label='Total Excitatory Input Success Rate', c='g')
        axes[1].plot(stim_t[start:], mean_inh_input[start:], label='Total Inhibitory Input Spike Rate', c='k')
        axes[2].plot(stim_t[start:], mean_output, label='Single Cell Output Spike Rate', c='r')
        for ax in axes:
            ax.legend(loc='upper left', frameon=False, framealpha=0.5, fontsize=18)
        axes[2].set_xlabel('Time (ms)', fontsize=18)
        axes[1].set_ylabel('Event Rate (Hz)', fontsize=18)
        axes[0].set_title(title, fontsize=20)
        axes[0].set_ylim(0., max(16000., np.max(mean_input) * 1.2))
        axes[0].set_yticks(np.arange(0., max(16000., np.max(mean_input) * 1.2) + 1., 4000.))
        axes[1].set_ylim(0., max(20000., np.max(mean_inh_input) * 1.2))
        axes[2].set_ylim(0., max(50., np.max(mean_output) * 1.2))
        plt.xlim(0., track_duration)
        clean_axes(axes)
        if svg_title is not None:
            plt.savefig(data_dir+svg_title+' - input output - '+title+'.svg', format='svg')
            plt.close()
        else:
            plt.show()
            plt.close()
            # max([]) raises and a 0 bin count is rejected by hist, so guard both cases
            if len(intervals) > 0:
                plt.hist(intervals, bins=max(1, int(max(intervals)/3.)), normed=True)
                plt.xlim(0., 200.)
                plt.ylabel('Probability')
                plt.xlabel('Inter-Spike Interval (ms)')
                plt.title('Distribution of Input Inter-Spike Intervals - '+title)
                plt.show()
                plt.close()
    if stochastic:
        return stim_t[start:], mean_input[start:], mean_successes[start:], mean_inh_input[start:], mean_output
    else:
        return stim_t[start:], mean_input[start:], mean_inh_input[start:], mean_output


def process_patterned_input_simulation(rec_filename, title, dt=0.02):
    """
    Load a patterned-input simulation file (one hdf5 group per trial) and analyze input statistics and somatic Vm.
    Shows, each in its own figure:
    - the distribution of input inter-spike intervals;
    - the distribution of the 'peak_loc' attribute of the first trial's input trains;
    - the raw Vm recorded in rec['0'] for every trial;
    - the across-trial mean of binned Vm and the across-trial mean of binned Vm variance;
    - a mean-variance scatter;
    - the trial-averaged power spectral densities of the total population input and of intracellular Vm, each
      normalized to its peak between 4 and 11 Hz.
    Spike-removed Vm traces (from get_removed_spikes) are down-sampled to 2 kHz, mirror-padded, and filtered with
    2000 ms FIR windows into a ~5 - 10 Hz theta band, a ~2 Hz low-pass 'ramp', and a ~0.2 Hz low-pass slow component.
    Residuals are the theta-removed traces minus the slow component. Vm is binned in windows of
    3 * input_field_duration / 50 ms.
    :param rec_filename: str: hdf5 file name in data_dir, without extension
    :param title: str
    :param dt: float: time step (ms) used to resample recorded Vm
    :return: tuple: (rec_t, residuals, mean_theta_envelope, scatter_vm_mean, scatter_vm_var, binned_t,
        mean_binned_vm, mean_binned_var, mean_ramp, mean_output)
    # remember .attrs['phase_offset'] could be inside ['train'] for old files
    """
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        sim = f.itervalues().next()
        equilibrate = sim.attrs['equilibrate']
        track_equilibrate = sim.attrs['track_equilibrate']
        track_length = sim.attrs['track_length']
        input_field_duration = sim.attrs['input_field_duration']
        duration = sim.attrs['duration']
        stim_dt = sim.attrs['stim_dt']
        bins = int((1.5 + track_length) * input_field_duration / 20.)
        track_duration = duration - equilibrate - track_equilibrate
        stim_t = np.arange(-track_equilibrate, track_duration, stim_dt)
        start = int(track_equilibrate/stim_dt)
        spatial_bin = input_field_duration/50.
        intervals = []
        pop_input = []
        output = []
        for sim in f.itervalues():
            exc_input_sum = None
            for train in sim['train'].itervalues():
                this_train = np.array(train)
                intervals.extend(np.diff(this_train))
                this_exc_rate = get_binned_firing_rate(this_train, stim_t)
                if exc_input_sum is None:
                    exc_input_sum = np.array(this_exc_rate)
                else:
                    exc_input_sum = np.add(exc_input_sum, this_exc_rate)
            pop_input.append(exc_input_sum)
            this_output = get_smoothed_firing_rate(np.array(sim['output']), stim_t[start:], bin_dur=3.*spatial_bin,
                                                   bin_step=spatial_bin, dt=stim_dt)
            output.append(this_output)
        pop_psd = []
        for this_pop_input in pop_input:
            pop_freq, this_pop_psd = signal.periodogram(this_pop_input, 1000./stim_dt)
            pop_psd.append(this_pop_psd)
        pop_psd = np.mean(pop_psd, axis=0)
        left = np.where(pop_freq >= 4.)[0][0]
        right = np.where(pop_freq >= 11.)[0][0]
        pop_psd /= np.max(pop_psd[left:right])
        mean_output = np.mean(output, axis=0)
        # max([]) raises and a 0 bin count is rejected by hist, so guard both cases
        if len(intervals) > 0:
            plt.hist(intervals, bins=max(1, int((max(intervals) - min(intervals)) / 3.)), normed=True)
            plt.xlim(0., 200.)
            plt.ylabel('Probability')
            plt.xlabel('Inter-Spike Interval (ms)')
            plt.title('Distribution of Input Inter-Spike Intervals - '+title)
            plt.show()
            plt.close()
        peak_locs = [sim.attrs['peak_loc'] for sim in f.itervalues().next()['train'].itervalues()]
        plt.hist(peak_locs, bins=bins)
        plt.xlabel('Time (ms)')
        plt.ylabel('Count (20 ms Bins)')
        plt.title('Distribution of Input Peak Locations - '+title)
        plt.xlim((np.min(peak_locs), np.max(peak_locs)))
        plt.show()
        plt.close()
        # loop-invariant: every trial's Vm is resampled onto the same time base and clipped at the same index
        t = np.arange(0., duration, dt)
        vm_start = int((equilibrate + track_equilibrate)/dt)
        for sim in f.itervalues():
            vm = np.interp(t, sim['time'], sim['rec']['0'])
            plt.plot(np.subtract(t[vm_start:], equilibrate + track_equilibrate), vm[vm_start:])
        plt.xlabel('Time (ms)')
        plt.ylabel('Voltage (mV)')
        plt.title('Somatic Vm - '+title)
        plt.ylim((-70., -50.))
        plt.show()
        plt.close()
    rec_t = np.arange(0., track_duration, dt)
    #spikes_removed = get_removed_spikes_alt(rec_filename, plot=0)
    spikes_removed = get_removed_spikes(rec_filename, plot=0)
    # down_sample traces to 2 kHz after clipping spikes for theta and ramp filtering
    down_dt = 0.5
    down_t = np.arange(0., track_duration, down_dt)
    # 2000 ms Hamming window, ~2 Hz low-pass for ramp, ~5 - 10 Hz bandpass for theta, ~0.2 Hz low-pass for residuals
    window_len = int(2000. / down_dt)
    pad_len = int(window_len / 2.)
    theta_filter = signal.firwin(window_len, [5., 10.], nyq=1000. / 2. / down_dt, pass_zero=False)
    ramp_filter = signal.firwin(window_len, 2., nyq=1000. / 2. / down_dt)
    slow_vm_filter = signal.firwin(window_len, .2, nyq=1000. / 2. / down_dt)
    theta_traces = []
    theta_removed = []
    ramp_traces = []
    slow_vm_traces = []
    residuals = []
    intra_psd = []
    theta_envelopes = []
    for trace in spikes_removed:
        intra_freq, this_intra_psd = signal.periodogram(trace, 1000. / dt)
        intra_psd.append(this_intra_psd)
        down_sampled = np.interp(down_t, rec_t, trace)
        # mirror-pad both ends by half a filter window to reduce edge artifacts
        padded_trace = np.zeros(len(down_sampled) + window_len)
        padded_trace[pad_len:-pad_len] = down_sampled
        padded_trace[:pad_len] = down_sampled[::-1][-pad_len:]
        padded_trace[-pad_len:] = down_sampled[::-1][:pad_len]
        filtered = signal.filtfilt(theta_filter, [1.], padded_trace, padlen=pad_len)
        this_theta_envelope = np.abs(signal.hilbert(filtered))
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        theta_traces.append(up_sampled)
        this_theta_removed = trace - up_sampled
        theta_removed.append(this_theta_removed)
        this_theta_envelope = this_theta_envelope[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, this_theta_envelope)
        theta_envelopes.append(up_sampled)
        filtered = signal.filtfilt(ramp_filter, [1.], padded_trace, padlen=pad_len)
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        ramp_traces.append(up_sampled)
        filtered = signal.filtfilt(slow_vm_filter, [1.], padded_trace, padlen=pad_len)
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        slow_vm_traces.append(up_sampled)
        this_residual = this_theta_removed - up_sampled
        residuals.append(this_residual)
    intra_psd = np.mean(intra_psd, axis=0)
    left = np.where(intra_freq >= 4.)[0][0]
    right = np.where(intra_freq >= 11.)[0][0]
    intra_psd /= np.max(intra_psd[left:right])
    # mean_across_trials = np.mean(theta_removed, axis=0)
    # variance_across_trials = np.var(theta_removed, axis=0)
    binned_mean = [[] for i in range(len(residuals))]
    binned_variance = [[] for i in range(len(residuals))]
    binned_t = []
    bin_duration = 3. * spatial_bin
    interval = int(bin_duration / dt)
    for j in range(0, int(track_duration / bin_duration)):
        binned_t.append(j * bin_duration + bin_duration / 2.)
        for i, residual in enumerate(residuals):
            binned_variance[i].append(np.var(residual[j * interval:(j + 1) * interval]))
            binned_mean[i].append(np.mean(theta_removed[i][j * interval:(j + 1) * interval]))
    mean_theta_envelope = np.mean(theta_envelopes, axis=0)
    mean_ramp = np.mean(ramp_traces, axis=0)
    mean_binned_vm = np.mean(binned_mean, axis=0)
    mean_binned_var = np.mean(binned_variance, axis=0)
    scatter_vm_mean = np.array(binned_mean).flatten()
    scatter_vm_var = np.array(binned_variance).flatten()
    print('Mean Theta Envelope for %s: %.2f' % (title, np.mean(mean_theta_envelope)))
    plt.plot(binned_t, mean_binned_vm)
    # the bin width depends on input_field_duration, so it is not hard-coded in the label
    plt.xlabel('Time - %.0f ms bins' % bin_duration)
    plt.ylabel('Voltage (mV)')
    plt.title('Somatic Vm Mean - Across Trials - ' + title)
    plt.show()
    plt.close()
    plt.plot(binned_t, mean_binned_var)
    plt.xlabel('Time (ms)')
    plt.ylabel('Vm Variance (mV' + r'$^2$' + ')')
    plt.title('Somatic Vm Variance - Across Trials - ' + title)
    plt.show()
    plt.close()
    plt.scatter(scatter_vm_mean, scatter_vm_var)
    plt.xlabel('Mean Vm (mV)')
    plt.ylabel('Vm Variance (mV' + r'$^2$' + ')')
    plt.title('Mean - Variance Analysis - ' + title)
    plt.show()
    plt.close()
    plt.plot(pop_freq, pop_psd, label='Total Population Input Spikes')
    plt.plot(intra_freq, intra_psd, label='Single Cell Intracellular Vm')
    plt.xlim(4., 11.)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Normalized Power Density')
    plt.title('Power Spectral Density - ' + title)
    plt.legend(loc='best')
    plt.show()
    plt.close()
    return rec_t, residuals, mean_theta_envelope, scatter_vm_mean, scatter_vm_var, binned_t, mean_binned_vm, \
           mean_binned_var, mean_ramp, mean_output


def process_patterned_input_simulation_theta_freq(rec_filenames, conditions=None, theta_dur=None, field_center=4500.,
                                                  win_dur=1800., dt=0.02):
    """
    For every condition, theta-filter (~5 - 10 Hz) the per-trial summed input rates and the spike-removed
    intracellular Vm, then extract theta peak times, peak phases and inter-peak intervals (IPI).
    Signal groups:
    - 'exc': summed rates from the 'train' group;
    - 'successes': summed rates from the 'successes' group (only when present);
    - 'inh': summed rates from the 'inh_train' group;
    - 'intra': Vm returned by get_removed_spikes.
    Phases are computed against each reference LFP type: 'orig' uses each trial's 'phase_offset' attribute, and
    'modinh' uses 'mod_inh_time_offset' when the first trial has that attribute. Afterwards, plot_phase_precession
    (fit window of duration win_dur centered on field_center) and plot_IPI are applied to every group/condition.
    :param rec_filenames: dict of {str: str}: hdf5 file name (in data_dir, without extension) for each condition
    :param conditions: list of str
    :param theta_dur: dict of {str: float}: LFP cycle duration (ms) for each reference LFP type
    :param field_center: float
    :param win_dur: float
    :param dt: float: time step (ms) passed to get_removed_spikes and used for the intracellular time base
    :return: tuple: (t, rec_t, peaks, phases, IPI, binned_peaks, binned_phases, binned_t, binned_IPI, theta_env);
        t and rec_t come from the last condition processed
    """
    if conditions is None:
        conditions = ['modinh0', 'modinh3']
    if theta_dur is None:
        theta_dur = {'orig': 150., 'modinh': 145.}
    peaks, phases, IPI = {}, {}, {}
    binned_peaks, binned_phases, binned_t, binned_IPI = {}, {}, {}, {}
    theta_env = {}
    for container in (peaks, phases, IPI, binned_peaks, binned_phases, binned_t, binned_IPI):
        for group in ('exc', 'successes', 'inh', 'intra'):
            container[group] = {}

    def ensure_condition(group, condition):
        """Create empty per-condition containers for a group the first time that condition is seen."""
        if condition not in peaks[group]:
            peaks[group][condition] = []
            phases[group][condition] = {}
            IPI[group][condition] = []

    def record_theta_cycles(group, condition, times, theta_trace, offsets, trial_index):
        """Store one trial's theta peaks/phases for each reference LFP; peaks and IPI only once (for 'orig')."""
        for LFP_type in offsets:
            this_peaks, this_phases = get_waveform_phase_vs_time(times, theta_trace,
                                                                 cycle_duration=theta_dur[LFP_type],
                                                                 time_offset=offsets[LFP_type][trial_index])
            phases[group][condition].setdefault(LFP_type, []).append(this_phases)
            if LFP_type == 'orig':
                peaks[group][condition].append(this_peaks)
                IPI[group][condition].append(np.diff(this_peaks))

    for condition in conditions:
        rec_filename = rec_filenames[condition]
        with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
            first_trial = f.itervalues().next()
            equilibrate = first_trial.attrs['equilibrate']
            track_equilibrate = first_trial.attrs['track_equilibrate']
            duration = first_trial.attrs['duration']
            stim_dt = first_trial.attrs['stim_dt']
            track_duration = duration - equilibrate - track_equilibrate
            stim_t = np.arange(-track_equilibrate, track_duration, stim_dt)
            stim_t_start = int(track_equilibrate/stim_dt)
            t = stim_t[stim_t_start:] - stim_t[stim_t_start]
            rec_t = np.arange(0., track_duration, dt)
            down_dt = 0.5
            # ~5 - 10 Hz bandpass for theta
            window_len = int(2000. / down_dt)
            theta_filter = signal.firwin(window_len, [5., 10.], nyq=1000. / 2. / down_dt, pass_zero=False)
            time_offset = {'orig': [trial.attrs['phase_offset'] for trial in f.itervalues()]}
            if 'mod_inh_time_offset' in first_trial.attrs:
                time_offset['modinh'] = [trial.attrs['mod_inh_time_offset'] for trial in f.itervalues()]
            for trial_index, trial in enumerate(f.itervalues()):
                for group, key in (('exc', 'train'), ('successes', 'successes'), ('inh', 'inh_train')):
                    if key not in trial:
                        continue
                    input_sum = None
                    for train in trial[key].itervalues():
                        this_rate = get_binned_firing_rate(train[:], stim_t)
                        input_sum = np.array(this_rate) if input_sum is None else np.add(input_sum, this_rate)
                    ensure_condition(group, condition)
                    theta_trace = general_filter_trace(t, input_sum[stim_t_start:], filter=theta_filter,
                                                       duration=track_duration, dt=stim_dt)
                    record_theta_cycles(group, condition, t, theta_trace, time_offset, trial_index)
        spikes_removed = get_removed_spikes(rec_filename, dt=dt, plot=0)
        for trial_index, trace in enumerate(spikes_removed):
            # peaks['intra'] and theta_env gain a condition entry together, so setdefault mirrors that pairing
            theta_env.setdefault(condition, [])
            ensure_condition('intra', condition)
            theta_trace = general_filter_trace(rec_t, trace, filter=theta_filter,
                                               duration=track_duration, dt=dt)
            theta_env[condition].append(np.abs(signal.hilbert(theta_trace)))
            record_theta_cycles('intra', condition, rec_t, theta_trace, time_offset, trial_index)
    fit_start = field_center - win_dur / 2.
    fit_end = fit_start + win_dur
    for group in peaks:
        for condition in peaks[group]:
            binned_phases[group][condition] = {}
            for LFP_type in phases[group][condition]:
                binned_peaks[group][condition], binned_phases[group][condition][LFP_type] = \
                    plot_phase_precession(peaks[group][condition], phases[group][condition][LFP_type],
                                          group+'_'+condition+'; LFP: '+LFP_type, fit_start=fit_start,
                                          fit_end=fit_end)
            binned_t[group][condition], binned_IPI[group][condition] = plot_IPI(peaks[group][condition],
                                                                                IPI[group][condition],
                                                                                group+'_'+condition)
    return t, rec_t, peaks, phases, IPI, binned_peaks, binned_phases, binned_t, binned_IPI, theta_env


def plot_IPI(t_array, IPI_array, title, display_start=0., display_end=7500., bin_size=60., num_bins=5, svg_title=None,
             plot=True):
    """
    Scatter theta inter-peak intervals (IPI) vs. time for each trial. Each IPI is placed at the midpoint of the two
    consecutive peaks t[j], t[j + 1] that define it. The mean IPI is overlaid, computed in consecutive
    non-overlapping windows of duration num_bins * bin_size that tile [display_start, display_end]. Windows with no
    IPIs are omitted. The time axis is labeled in seconds.
    :param t_array: list of array: theta peak times (ms) for each trial
    :param IPI_array: list of array: inter-peak intervals (ms) for each trial; assumed to be np.diff of the matching
        entry of t_array (one entry shorter) - confirm for new callers
    :param title: str
    :param display_start: float: start of the plotted/binned time range (ms)
    :param display_end: float: end of the plotted/binned time range (ms)
    :param bin_size: float (ms)
    :param num_bins: int: the averaging window is num_bins * bin_size long
    :param svg_title: str: if not None, use 8 pt fonts and save the figure to data_dir as an svg
    :param plot: bool: whether to show the figure
    :return: tuple of array: (binned_t, binned_IPI): window centers (ms) and mean IPI (ms) of non-empty windows
    :raises ValueError: if num_bins * bin_size is not positive (the window loop would never terminate)
    """
    window_dur = float(num_bins) * bin_size
    if window_dur <= 0.:
        raise ValueError('plot_IPI: num_bins * bin_size must be positive')
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    try:
        fig, axes = plt.subplots(1)
        midpoints = []
        for i, t in enumerate(t_array):
            this_IPI = IPI_array[i]
            this_t = np.array([(t[j] + t[j + 1]) / 2. for j in range(len(this_IPI))])
            midpoints.append(this_t)
            axes.scatter(this_t, this_IPI, color='gray', s=0.1)
        binned_t = []
        binned_IPI = []
        start = display_start
        while start + window_dur <= display_end:
            this_IPI_bin = []
            for i, this_t in enumerate(midpoints):
                indexes = np.where((this_t >= start) & (this_t < start + window_dur))[0]
                this_IPI_bin.extend(np.asarray(IPI_array[i])[indexes])
            # test for emptiness; np.any() would also drop a window whose IPIs happen to be all zero
            if len(this_IPI_bin) > 0:
                binned_t.append(start + window_dur / 2.)
                binned_IPI.append(np.mean(this_IPI_bin))
            start += window_dur
        binned_t = np.array(binned_t)
        binned_IPI = np.array(binned_IPI)
        axes.plot(binned_t, binned_IPI, c='k')  # , linewidth=2.)
        axes.set_xlim(display_start, display_end)
        # ticks every 1.5 s within the displayed range (fixed ticks outside it would expand the axis), always
        # labeled in seconds to match the x-axis label
        xticks = np.arange(display_start, display_end + 1., 1500.)
        axes.set_xticks(xticks)
        axes.set_xticklabels(['%g' % (tick / 1000.) for tick in xticks])
        axes.set_ylabel('Theta inter-peak intervals (ms)')
        axes.set_xlabel('Time (s)')
        axes.set_title('Theta inter-peak intervals - ' + title, fontsize=mpl.rcParams['font.size'])
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(1.21, 2.)
            fig.savefig(data_dir + svg_title + ' - IPI - ' + title + '.svg', format='svg', transparent=True)
        if plot:
            plt.show()
        plt.close()
    finally:
        # restore the global font size even if plotting or saving fails
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size
    return binned_t, binned_IPI


def process_patterned_input_simulation_single_compartment(rec_filename, title, dt=0.1):
    """
    Load a single-compartment patterned-input simulation, display diagnostic plots of its inputs and somatic Vm, and
    decompose each trial's spike-removed somatic Vm into theta-band, ramp, slow, and residual components.

    Plots are shown interactively (plt.show()) and not saved: input inter-spike interval histogram, input peak location
    histogram, somatic Vm of every trial, binned mean Vm across trials, binned residual variance across trials, mean vs.
    variance scatter, and normalized power spectral density (4 - 11 Hz) of the summed input population and of the
    intracellular Vm. Also prints the time-averaged theta envelope.
    :param rec_filename: str, name of the .hdf5 file in data_dir, without the extension
    :param title: str, appended to plot titles and to the printed summary
    :param dt: float, time step (ms) used to resample somatic Vm; also passed to get_removed_spikes
    :return: tuple of (rec_t, residuals, mean_theta_envelope, scatter_vm_mean, scatter_vm_var, binned_t,
        mean_binned_vm, mean_binned_var, mean_ramp, an array of zeros shaped like mean_ramp)
    # remember .attrs['phase_offset'] could be inside ['train'] for old files
    """
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        # simulation parameters are read from the attributes of the first simulation in the file
        sim = f.itervalues().next()
        equilibrate = sim.attrs['equilibrate']
        track_equilibrate = sim.attrs['track_equilibrate']
        track_length = sim.attrs['track_length']
        input_field_duration = sim.attrs['input_field_duration']
        duration = sim.attrs['duration']
        stim_dt = sim.attrs['stim_dt']
        # number of ~20 ms bins for the input peak location histogram below
        bins = int((1.5 + track_length) * input_field_duration / 20.)
        track_duration = duration - equilibrate - track_equilibrate
        stim_t = np.arange(-track_equilibrate, track_duration, stim_dt)
        # NOTE(review): this value of start is never used; start is reassigned in the Vm plotting loop below
        start = int(track_equilibrate/stim_dt)
        spatial_bin = input_field_duration/50.
        intervals = []
        pop_input = []
        # for each simulation (trial): pool the inter-spike intervals of every input train, and sum the binned
        # firing rates of all trains into one population input rate per trial
        for sim in f.itervalues():
            exc_input_sum = None
            for key, train in sim['train'].iteritems():
                this_train = np.array(train)
                intervals.extend(np.diff(this_train))
                this_exc_rate = get_binned_firing_rate(this_train, stim_t)
                if exc_input_sum is None:
                    exc_input_sum = np.array(this_exc_rate)
                else:
                    exc_input_sum = np.add(exc_input_sum, this_exc_rate)
            pop_input.append(exc_input_sum)
        pop_psd = []
        # time steps are in ms, so 1000./stim_dt is the sampling frequency in Hz
        for this_pop_input in pop_input:
            pop_freq, this_pop_psd = signal.periodogram(this_pop_input, 1000./stim_dt)
            pop_psd.append(this_pop_psd)
        # average the PSD across trials and normalize it to its peak within the 4 - 11 Hz range
        pop_psd = np.mean(pop_psd, axis=0)
        left = np.where(pop_freq >= 4.)[0][0]
        right = np.where(pop_freq >= 11.)[0][0]
        pop_psd /= np.max(pop_psd[left:right])
        # histogram of all input inter-spike intervals, with bins about 3 ms wide
        plt.hist(intervals, bins=int((max(intervals)-min(intervals))/3.), normed=True)
        plt.xlim(0., 200.)
        plt.ylabel('Probability')
        plt.xlabel('Inter-Spike Interval (ms)')
        plt.title('Distribution of Input Inter-Spike Intervals - '+title)
        plt.show()
        plt.close()
        # peak locations of the input trains of the first simulation only
        peak_locs = [sim.attrs['peak_loc'] for sim in f.itervalues().next()['train'].itervalues()]
        plt.hist(peak_locs, bins=bins)
        plt.xlabel('Time (ms)')
        plt.ylabel('Count (20 ms Bins)')
        plt.title('Distribution of Input Peak Locations - '+title)
        plt.xlim((np.min(peak_locs), np.max(peak_locs)))
        plt.show()
        plt.close()
        for sim in f.itervalues():
            # resample somatic Vm (rec '0') onto a uniform dt grid; plot only the time after both equilibration
            # periods, shifted so that the track starts at 0 ms
            t = np.arange(0., duration, dt)
            vm = np.interp(t, sim['time'], sim['rec']['0'])
            start = int((equilibrate + track_equilibrate)/dt)
            plt.plot(np.subtract(t[start:], equilibrate + track_equilibrate), vm[start:])
            plt.xlabel('Time (ms)')
            plt.ylabel('Voltage (mV)')
            plt.title('Somatic Vm - '+title)
            plt.ylim((-70., -50.))
        plt.show()
        plt.close()
    rec_t = np.arange(0., track_duration, dt)
    spikes_removed = get_removed_spikes(rec_filename, plot=0, dt=dt)
    # down_sample traces to 2 kHz after clipping spikes for theta and ramp filtering
    down_dt = 0.5
    down_t = np.arange(0., track_duration, down_dt)
    # 2000 ms Hamming window, ~2 Hz low-pass for ramp, ~5 - 10 Hz bandpass for theta, ~0.2 Hz low-pass for residuals
    window_len = int(2000. / down_dt)
    pad_len = int(window_len / 2.)
    theta_filter = signal.firwin(window_len, [5., 10.], nyq=1000. / 2. / down_dt, pass_zero=False)
    ramp_filter = signal.firwin(window_len, 2., nyq=1000. / 2. / down_dt)
    slow_vm_filter = signal.firwin(window_len, .2, nyq=1000. / 2. / down_dt)
    # NOTE(review): theta_traces and slow_vm_traces are accumulated below but never used or returned
    theta_traces = []
    theta_removed = []
    ramp_traces = []
    slow_vm_traces = []
    residuals = []
    intra_psd = []
    theta_envelopes = []
    for trace in spikes_removed:
        # PSD of the full-resolution trace (sampling frequency 1000./dt Hz)
        intra_freq, this_intra_psd = signal.periodogram(trace, 1000. / dt)
        intra_psd.append(this_intra_psd)
        down_sampled = np.interp(down_t, rec_t, trace)
        # mirror-pad both ends by pad_len samples (reversed copies of the trace edges) to reduce filter edge effects
        padded_trace = np.zeros(len(down_sampled) + window_len)
        padded_trace[pad_len:-pad_len] = down_sampled
        padded_trace[:pad_len] = down_sampled[::-1][-pad_len:]
        padded_trace[-pad_len:] = down_sampled[::-1][:pad_len]
        # theta band-pass (forward-backward); the envelope is the magnitude of the analytic signal, computed on the
        # padded trace before the padding is stripped
        filtered = signal.filtfilt(theta_filter, [1.], padded_trace, padlen=pad_len)
        this_theta_envelope = np.abs(signal.hilbert(filtered))
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        theta_traces.append(up_sampled)
        this_theta_removed = trace - up_sampled
        theta_removed.append(this_theta_removed)
        this_theta_envelope = this_theta_envelope[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, this_theta_envelope)
        theta_envelopes.append(up_sampled)
        # ~2 Hz low-pass: ramp component
        filtered = signal.filtfilt(ramp_filter, [1.], padded_trace, padlen=pad_len)
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        ramp_traces.append(up_sampled)
        # ~0.2 Hz low-pass: slow Vm; the residual is the theta-removed trace minus this slow component
        filtered = signal.filtfilt(slow_vm_filter, [1.], padded_trace, padlen=pad_len)
        filtered = filtered[pad_len:-pad_len]
        up_sampled = np.interp(rec_t, down_t, filtered)
        slow_vm_traces.append(up_sampled)
        this_residual = this_theta_removed - up_sampled
        residuals.append(this_residual)
    # average intracellular PSD across trials, normalized to its peak within 4 - 11 Hz
    intra_psd = np.mean(intra_psd, axis=0)
    left = np.where(intra_freq >= 4.)[0][0]
    right = np.where(intra_freq >= 11.)[0][0]
    intra_psd /= np.max(intra_psd[left:right])
    # mean_across_trials = np.mean(theta_removed, axis=0)
    # variance_across_trials = np.var(theta_removed, axis=0)
    # per trial, split the track into bins 3 spatial bins wide (3/50 of input_field_duration); collect the mean of
    # the theta-removed Vm and the variance of the residual Vm in each bin
    binned_mean = [[] for i in range(len(residuals))]
    binned_variance = [[] for i in range(len(residuals))]
    binned_t = []
    bin_duration = 3. * spatial_bin
    interval = int(bin_duration / dt)
    for j in range(0, int(track_duration / bin_duration)):
        binned_t.append(j * bin_duration + bin_duration / 2.)
        for i, residual in enumerate(residuals):
            binned_variance[i].append(np.var(residual[j * interval:(j + 1) * interval]))
            binned_mean[i].append(np.mean(theta_removed[i][j * interval:(j + 1) * interval]))
    mean_theta_envelope = np.mean(theta_envelopes, axis=0)
    mean_ramp = np.mean(ramp_traces, axis=0)
    mean_binned_vm = np.mean(binned_mean, axis=0)
    mean_binned_var = np.mean(binned_variance, axis=0)
    scatter_vm_mean = np.array(binned_mean).flatten()
    scatter_vm_var = np.array(binned_variance).flatten()
    print 'Mean Theta Envelope for %s: %.2f' % (title, np.mean(mean_theta_envelope))
    # NOTE(review): the '180 ms bins' label below is hard-coded and matches bin_duration only when
    # input_field_duration == 3000
    plt.plot(binned_t, mean_binned_vm)
    plt.xlabel('Time - 180 ms bins')
    plt.ylabel('Voltage (mV)')
    plt.title('Somatic Vm Mean - Across Trials - ' + title)
    plt.show()
    plt.close()
    plt.plot(binned_t, mean_binned_var)
    plt.xlabel('Time (ms)')
    plt.ylabel('Vm Variance (mV' + r'$^2$' + ')')
    plt.title('Somatic Vm Variance - Across Trials - ' + title)
    plt.show()
    plt.close()
    plt.scatter(scatter_vm_mean, scatter_vm_var)
    plt.xlabel('Mean Vm (mV)')
    plt.ylabel('Vm Variance (mV' + r'$^2$' + ')')
    plt.title('Mean - Variance Analysis - ' + title)
    plt.show()
    plt.close()
    plt.plot(pop_freq, pop_psd, label='Total Population Input Spikes')
    plt.plot(intra_freq, intra_psd, label='Single Cell Intracellular Vm')
    plt.xlim(4., 11.)
    plt.xlabel('Frequency (Hz)')
    plt.ylabel('Normalized Power Density')
    plt.title('Power Spectral Density - ' + title)
    plt.legend(loc='best')
    plt.show()
    plt.close()
    # the last element (zeros shaped like mean_ramp) presumably fills the slot of a firing-rate output that a
    # single-compartment run does not compute -- TODO confirm against callers
    return rec_t, residuals, mean_theta_envelope, scatter_vm_mean, scatter_vm_var, binned_t, mean_binned_vm, \
           mean_binned_var, mean_ramp, np.zeros_like(mean_ramp)


def process_patterned_input_simulation_fix_bins(rec_filename, title, dt=0.02):
    """
    Compute binned mean and variance statistics from the spike-removed somatic Vm of a patterned-input simulation.

    For every trial, the theta band (~5 - 10 Hz) is removed to give the theta-removed Vm. A further ~0.2 Hz low-pass
    (slow Vm) is subtracted from that to give the residual Vm. The track is then split into bins 3 spatial bins wide
    (3/50 of input_field_duration). Each bin yields the mean of the theta-removed Vm and the variance of the residual
    Vm. No plots are produced.
    :param rec_filename: str, name of the .hdf5 file in data_dir, without the extension
    :param title: str, not used by this function; kept for interface compatibility
    :param dt: float, time step (ms) of the spike-removed traces. It is passed to get_removed_spikes so the traces
        are sampled on the same time base as rec_t.
    :return: tuple of (scatter_vm_mean, scatter_vm_var, binned_t, mean_binned_vm, mean_binned_var)
        scatter_vm_mean, scatter_vm_var: arrays of per-trial, per-bin values, flattened across trials
        binned_t: list of bin centers (ms)
        mean_binned_vm, mean_binned_var: per-bin values averaged across trials
    # remember .attrs['phase_offset'] could be inside ['train'] for old files
    """
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        # simulation parameters are read from the attributes of the first simulation in the file
        sim = f.itervalues().next()
        equilibrate = sim.attrs['equilibrate']
        track_equilibrate = sim.attrs['track_equilibrate']
        input_field_duration = sim.attrs['input_field_duration']
        duration = sim.attrs['duration']
    track_duration = duration - equilibrate - track_equilibrate
    spatial_bin = input_field_duration/50.
    bin_duration = 3. * spatial_bin
    rec_t = np.arange(0., track_duration, dt)
    # Request the traces at this function's dt. Without it, the traces come back at get_removed_spikes' own default
    # resolution, their length no longer matches rec_t, and np.interp below raises.
    spikes_removed = get_removed_spikes(rec_filename, plot=0, dt=dt)
    # down-sample traces to 2 kHz after clipping spikes, and filter with 2000 ms FIR windows:
    # ~5 - 10 Hz band-pass for theta, ~0.2 Hz low-pass for the slow Vm removed from the residuals
    down_dt = 0.5
    down_t = np.arange(0., track_duration, down_dt)
    window_len = int(2000. / down_dt)
    pad_len = int(window_len / 2.)
    nyq = 1000. / 2. / down_dt
    theta_filter = signal.firwin(window_len, [5., 10.], nyq=nyq, pass_zero=False)
    slow_vm_filter = signal.firwin(window_len, .2, nyq=nyq)

    def _filter_and_upsample(padded, taps):
        """Zero-phase filter the padded trace, strip the padding, and interpolate back onto rec_t."""
        filtered = signal.filtfilt(taps, [1.], padded, padlen=pad_len)[pad_len:-pad_len]
        return np.interp(rec_t, down_t, filtered)

    theta_removed = []
    residuals = []
    for trace in spikes_removed:
        down_sampled = np.interp(down_t, rec_t, trace)
        # mirror-pad both ends by pad_len samples (reversed copies of the trace edges) to reduce filter edge effects
        padded_trace = np.zeros(len(down_sampled) + window_len)
        padded_trace[pad_len:-pad_len] = down_sampled
        padded_trace[:pad_len] = down_sampled[::-1][-pad_len:]
        padded_trace[-pad_len:] = down_sampled[::-1][:pad_len]
        this_theta_removed = trace - _filter_and_upsample(padded_trace, theta_filter)
        theta_removed.append(this_theta_removed)
        residuals.append(this_theta_removed - _filter_and_upsample(padded_trace, slow_vm_filter))
    interval = int(bin_duration / dt)
    num_bins = int(track_duration / bin_duration)
    binned_t = [j * bin_duration + bin_duration / 2. for j in range(num_bins)]
    binned_mean = [[np.mean(this_theta_removed[j * interval:(j + 1) * interval]) for j in range(num_bins)]
                   for this_theta_removed in theta_removed]
    binned_variance = [[np.var(residual[j * interval:(j + 1) * interval]) for j in range(num_bins)]
                       for residual in residuals]
    mean_binned_vm = np.mean(binned_mean, axis=0)
    mean_binned_var = np.mean(binned_variance, axis=0)
    scatter_vm_mean = np.array(binned_mean).flatten()
    scatter_vm_var = np.array(binned_variance).flatten()
    return scatter_vm_mean, scatter_vm_var, binned_t, mean_binned_vm, mean_binned_var


def plot_patterned_input_individual_trial_traces(rec_t, vm_array, theta_traces, ramp_traces, residuals, index=None,
                                                 svg_title=None):
    """
    Accepts the output of get_patterned_input_component_traces, and either saves a single figure, or cycles through a
    series of plots so the user can choose an individual trial to save as a figure.
    Each figure stacks four panels that share x and y axes: raw Vm, subthreshold Vm, theta Vm, and residual Vm.
    When svg_title is given, every figure is drawn with font size 8 and a legend, and is saved to data_dir as .svg.
    The global font size is restored once all figures are done.
    :param rec_t: array
    :param vm_array: list of array
    :param theta_traces: list of array
    :param ramp_traces: list of array
    :param residuals: list of array
    :param index: int, plot only this trial; if None, plot every trial in vm_array
    :param svg_title: str
    """
    remember_font_size = mpl.rcParams['font.size']
    if svg_title is not None:
        mpl.rcParams['font.size'] = 8
    index_range = range(len(vm_array)) if index is None else [index]
    try:
        for i in index_range:
            fig, axes = plt.subplots(4, sharey=True, sharex=True)
            label_handles = []
            axes[0].plot(rec_t, vm_array[i], color='k', label='Raw Vm')
            axes[0].set_axis_off()
            label_handles.append(mlines.Line2D([], [], color='k', label='Raw Vm'))
            axes[1].plot(rec_t, ramp_traces[i], color='r', label='Subthreshold Vm')
            axes[1].set_axis_off()
            label_handles.append(mlines.Line2D([], [], color='r', label='Subthreshold Vm'))
            axes[2].plot(rec_t, theta_traces[i], color='c', label='Theta Vm')
            label_handles.append(mlines.Line2D([], [], color='c', label='Theta Vm'))
            # axes are shared, so these limits apply to every panel
            axes[2].set_xlim(0., 7500.)
            axes[2].set_ylim(-67., 30.)
            axes[2].set_axis_off()
            axes[3].plot(rec_t, residuals[i], color='purple', label='Residual Vm')
            label_handles.append(mlines.Line2D([], [], color='purple', label='Residual Vm'))
            clean_axes(axes)
            if svg_title is not None:
                axes[1].legend(handles=label_handles, loc='best', frameon=False, framealpha=0.5)
                fig.set_size_inches(2.7696, 3.1506)
                fig.savefig(data_dir+svg_title+str(i)+' - example traces.svg', format='svg', transparent=True)
            plt.show()
            plt.close()
    finally:
        # Restore once, after all figures. Restoring inside the loop made every figure after the first use the
        # original font size.
        mpl.rcParams['font.size'] = remember_font_size


def plot_vm_distribution(rec_filenames, key_list=None, i_bounds=[0., 1800., 3600., 5400.],
                         dt=0.02, bin_width=0.2, svg_title=None):
    """
    Given a set of simulation files, collapse all the trials for a given condition into a single occupancy distribution
    of absolute voltages, for the time periods corresponding to the inhibitory manipulation. Superimpose the 4
    conditions ['Control - Out of Field', 'Control - In Field', 'Reduced Inhibition - Out of Field',
    'Reduced Inhibition - In Field'] in a single plot.
    :param rec_filenames: dict of str
    :param key_list: list of str
    :param i_bounds: list of float, time points corresponding to inhibitory manipulation
    :param peak_bounds: list of float, time points corresponding to 10 "spatial bins" for averaging
    :param dt: float, temporal resolution
    :param bin_width: float, in mV, determines number of bins
    :param svg_title: str
    """
    remember_font_size = mpl.rcParams['font.size']
    mpl.rcParams['font.size'] = 8
    trial_array = {}
    vm_array = {}
    if key_list is None:
        key_list = ['modinh0', 'modinh1', 'modinh2']
    for condition in key_list:
        #trial_array[condition] = get_removed_spikes(rec_filenames[condition], plot=0)
        spikes_removed_interp, trial_array[condition] = get_removed_spikes_nangaps(rec_filenames[condition])
    key_list.extend([key_list[0]+'_out', key_list[0]+'_in'])
    for source_condition, target_condition in zip([key_list[1], key_list[0]], [key_list[1], key_list[3]]):
        start = int(i_bounds[0]/dt)
        end = int(i_bounds[1]/dt)
        for trial in trial_array[source_condition]:
            this_vm_chunk = trial[start:end]
            keep = ~np.isnan(this_vm_chunk)
            if target_condition not in vm_array:
                vm_array[target_condition] = np.array(this_vm_chunk[keep])
            else:
                vm_array[target_condition] = np.append(vm_array[target_condition], np.array(this_vm_chunk[keep]))
    for source_condition, target_condition in zip([key_list[2], key_list[0]], [key_list[2], key_list[4]]):
        start = int(i_bounds[2]/dt)
        end = int(i_bounds[3]/dt)
        for trial in trial_array[source_condition]:
            this_vm_chunk = trial[start:end]
            keep = ~np.isnan(this_vm_chunk)
            if target_condition not in vm_array:
                vm_array[target_condition] = np.array(this_vm_chunk[keep])
            else:
                vm_array[target_condition] = np.append(vm_array[target_condition], np.array(this_vm_chunk[keep]))
    hist, edges = {}, {}
    for condition in vm_array:
        num_bins = int((np.max(vm_array[condition]) - np.min(vm_array[condition])) / bin_width)
        hist[condition], edges[condition] = np.histogram(vm_array[condition], density=True, bins=num_bins)
        hist[condition] *= bin_width * 100.
        edges[condition] = edges[condition][1:]
    colors = ['k', 'grey', 'orange', 'y']
    fig, axes = plt.subplots(1)
    for i, (condition, title) in enumerate(zip([key_list[3], key_list[4]], ['Out of field', 'In field'])):
        axes.plot(edges[condition], hist[condition], color=colors[i], label=title)
    clean_axes(axes)
    axes.set_xlabel('Vm (mV)')
    axes.set_ylabel('Probability (%)')
    axes.set_ylim(0., 7.)
    axes.set_yticks([0., 2., 4., 6.])
    axes.set_xlim(-70., -45.)
    axes.set_title('Control', fontsize=mpl.rcParams['font.size'])
    axes.tick_params(direction='out')
    plt.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    if svg_title is not None:
        fig.set_size_inches(1.3198, 1.2169)
        #fig.set_size_inches(1.74, 1.43)
        fig.savefig(data_dir+svg_title+' - Vm Distributions - Control.svg', format='svg', transparent=True)
    plt.show()
    plt.close()
    fig, axes = plt.subplots(1)
    for i, (condition, title) in enumerate(zip([key_list[1], key_list[2]], ['Out of field', 'In field'])):
        axes.plot(edges[condition], hist[condition], color=colors[i+2], label=title)
    clean_axes(axes)
    axes.set_xlabel('Vm (mV)')
    axes.set_ylabel('Probability (%)')
    axes.set_ylim(0., 7.)
    axes.set_yticks([0., 2., 4., 6.])
    axes.set_xlim(-70., -45.)
    axes.set_title('Reduced inhibition', fontsize=mpl.rcParams['font.size'])
    axes.tick_params(direction='out')
    plt.legend(loc='best', frameon=False, framealpha=0.5, fontsize=8)
    if svg_title is not None:
        #fig.set_size_inches(1.74, 1.43)
        fig.set_size_inches(1.3198, 1.2169)
        fig.savefig(data_dir+svg_title+' - Vm Distributions - ModInh.svg', format='svg', transparent=True)
    plt.show()
    plt.close()
    fig, axes = plt.subplots(1)
    for i, (condition, title) in enumerate(zip([key_list[3], key_list[4], key_list[1], key_list[2]],
                                               ['Control - Out', 'Control - In', 'Reduced inhibition - Out',
                                                'Reduced inhibition - In'])):
        axes.plot(edges[condition], hist[condition], color=colors[i], label=title)
    clean_axes(axes)
    axes.set_xlabel('Vm (mV)')
    axes.set_ylabel('Probability (%)')
    axes.set_ylim(0., 7.)
    axes.set_yticks([0., 2., 4., 6.])
    axes.set_xlim(-70., -45.)
    axes.set_xticks([-70., -60., -50.])
    #axes.set_title('Simulated Vm Distributions', fontsize=8)
    axes.tick_params(direction='out')
    plt.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    if svg_title is not None:
        #fig.set_size_inches(1.945, 1.16)
        fig.set_size_inches(1.3198, 1.2169)
        fig.savefig(data_dir + svg_title + ' - Vm Distributions - All.svg', format='svg', transparent=True)
    plt.show()
    plt.close()
    mpl.rcParams['font.size'] = remember_font_size
    span_edges = {}
    for condition in hist:
        span_edges[condition] = {}
        test = np.where(np.array(hist[condition]) >= 0.02 * np.max(hist[condition]))[0]
        print condition, 'start:', edges[condition][test[0]], 'end:', edges[condition][test[-1]]
        span_edges[condition]['start'] = edges[condition][test[0]]
        span_edges[condition]['end'] = edges[condition][test[-1]]
    for condition in [key_list[3], key_list[1]]:
        print condition, 'distance to threshold:', span_edges[condition]['end'] - -52.
    print 'Control overlap span:', span_edges[key_list[3]]['end'] - span_edges[key_list[4]]['start']
    print 'Modinh overlap span:', span_edges[key_list[1]]['end'] - span_edges[key_list[2]]['start']
    return hist, edges


def plot_patterned_input_sim_summary(rec_t, mean_theta_envelope, binned_t,  mean_binned_var, mean_ramp, mean_output,
                                     key_list=None, titles=None, baseline_range=(0., 600.), dt=0.02, svg_title=None):
    """
    Expects the output of process_patterned_input_simulation.
    Produces summary plots for ramp, variance, theta, and firing rate.
    Every dict argument is keyed by condition (the entries of key_list). Conditions are drawn in the order
    key_list[0], key_list[2], key_list[1], paired with titles and with the colors 'k', 'y', 'orange'.
    The ramp is plotted relative to the mean of the key_list[0] ramp over baseline_range.
    The global font size is always restored on exit.
    :param rec_t: array
    :param mean_theta_envelope: dict of array
    :param binned_t: dict of array
    :param mean_binned_var: dict of array
    :param mean_ramp: dict of array
    :param mean_output: dict of array
    :param key_list: list of str
    :param titles: list of str
    :param baseline_range: sequence of float, [start, end] in the same time units as dt
    :param dt: float
    :param svg_title: str, if given, figures are saved to data_dir as .svg
    """
    # Always remember the font size: the variance panel forces it to 8 even when svg_title is None, and it must be
    # restored so the change does not leak into later figures.
    remember_font_size = mpl.rcParams['font.size']
    if svg_title is not None:
        mpl.rcParams['font.size'] = 20
    if key_list is None:
        key_list = ['modinh0', 'modinh1', 'modinh2']
    if titles is None:
        titles = ['Control', 'Reduced inhibition - In field', 'Reduced inhibition - Out of field']
    colors = ['k', 'y', 'orange']
    plot_order = [key_list[0], key_list[2], key_list[1]]

    def _plot_summary(y_by_condition, ylabel, ylim, suffix, size_inches, x_by_condition=None, offset=None,
                      legend=False, label_fontsize=None):
        """Draw one summary panel over time for each condition, optionally save it as .svg, and show it."""
        label_kwargs = {} if label_fontsize is None else {'fontsize': label_fontsize}
        fig, axes = plt.subplots(1)
        for i, (condition, title) in enumerate(zip(plot_order, titles)):
            x = rec_t if x_by_condition is None else x_by_condition[condition]
            y = y_by_condition[condition] if offset is None else np.subtract(y_by_condition[condition], offset)
            axes.plot(x, y, color=colors[i], label=title)
        clean_axes(axes)
        axes.set_xlabel('Time (s)', **label_kwargs)
        axes.set_ylabel(ylabel, **label_kwargs)
        axes.set_ylim(*ylim)
        axes.set_xlim(0., 7500.)
        axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
        axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
        axes.tick_params(direction='out')
        if legend:
            plt.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
        if svg_title is not None:
            fig.set_size_inches(*size_inches)
            fig.savefig(data_dir + svg_title + ' - Summary - ' + suffix + '.svg', format='svg', transparent=True)
        plt.show()
        plt.close()

    try:
        baseline = np.mean(mean_ramp[key_list[0]][int(baseline_range[0]/dt):int(baseline_range[1]/dt)])
        _plot_summary(mean_ramp, 'DVm (mV)', (-0.8, 9.), 'Ramp', (4.403, 3.631), offset=baseline, legend=True)
        _plot_summary(mean_output, 'Firing rate (Hz)', (0., 45.), 'Rate', (4.403, 3.631))
        _plot_summary(mean_theta_envelope, 'Thetaintra (mV)', (0., 2.5), 'Theta', (4.403, 3.631))
        mpl.rcParams['font.size'] = 8
        _plot_summary(mean_binned_var, 'Variance (mV2)', (0., 7.), 'Variance', (1.95, 1.16),
                      x_by_condition=binned_t, legend=True, label_fontsize=8)
    finally:
        mpl.rcParams['font.size'] = remember_font_size


def plot_place_field_summmary_across_cells(rec_t, mean_theta_envelope, binned_t, mean_binned_var, mean_ramp,
                                           mean_output, groups=None, key_list=None, titles=None,
                                           baseline_range=(0., 600.), dt=0.02, svg_title=None, plot_ramp=False,
                                           plot_theta=False, plot_variance=False):
    """
    Expects the output of process_patterned_input_simulation, collected per cell.
    Produces summary plots depicting mean and SEM across cells (groups). By default only the firing rate panel is
    drawn; the ramp, theta envelope, and variance panels are opt-in.
    Side effect: adds 'mean' and 'var' entries (dicts keyed by condition) to each of mean_theta_envelope, binned_t,
    mean_binned_var, mean_ramp, and mean_output.
    :param rec_t: array
    :param mean_theta_envelope: dict of the form {group: {condition: array}}
    :param binned_t: dict of the form {group: {condition: array}}
    :param mean_binned_var: dict of the form {group: {condition: array}}
    :param mean_ramp: dict of the form {group: {condition: array}}
    :param mean_output: dict of the form {group: {condition: array}}
    :param groups: list of dict keys; defaults to every key of mean_theta_envelope except 'mean' and 'var'
    :param key_list: list of str, conditions to plot
    :param titles: list of str, legend labels matching key_list
    :param baseline_range: sequence of float, [start, end] used to baseline the ramp panel
    :param dt: float
    :param svg_title: str, if given, figures are saved to data_dir as .svg
    :param plot_ramp: bool, also plot the baseline-subtracted ramp
    :param plot_theta: bool, also plot the theta envelope
    :param plot_variance: bool, also plot the binned variance
    """
    if groups is None:
        # Exclude the summary entries this function adds to the caller's dicts. A repeated call would otherwise treat
        # 'mean'/'var' as extra cells and fail with KeyError.
        groups = [group for group in mean_theta_envelope if group not in ('mean', 'var')]
    remember_font_size = mpl.rcParams['font.size']
    if svg_title is not None:
        mpl.rcParams['font.size'] = 8
    if key_list is None:
        key_list = ['modinh0', 'modinh3']
    if titles is None:
        titles = ['Control', 'Reduced inhibition']
    colors = [('k', 'grey'), ('orange', 'orange')]
    for parameter in mean_theta_envelope, binned_t, mean_binned_var, mean_ramp, mean_output:
        parameter['mean'], parameter['var'] = {}, {}
        for condition in key_list:
            values = [parameter[group][condition] for group in groups]
            parameter['mean'][condition] = np.mean(values, axis=0)
            parameter['var'][condition] = np.var(values, axis=0)
    # NOTE(review): np.var uses ddof=0, so the SEM below is based on the population standard deviation
    sqrt_n = np.sqrt(float(len(groups)))

    def _plot_mean_sem(parameter, ylabel, ylim, suffix, x_by_condition=None, offset=None, label_fontsize=None):
        """Plot the mean +/- SEM across groups for each condition, optionally save it as .svg, and show it."""
        label_kwargs = {} if label_fontsize is None else {'fontsize': label_fontsize}
        fig, axes = plt.subplots(1)
        for i, (condition, title) in enumerate(zip(key_list, titles)):
            x = rec_t if x_by_condition is None else x_by_condition[condition]
            this_mean = parameter['mean'][condition]
            if offset is not None:
                this_mean = np.subtract(this_mean, offset)
            this_SEM = np.divide(np.sqrt(parameter['var'][condition]), sqrt_n)
            axes.plot(x, np.subtract(this_mean, this_SEM), color=colors[i][1])
            axes.plot(x, np.add(this_mean, this_SEM), color=colors[i][1])
            axes.plot(x, this_mean, color=colors[i][0], label=title)
        clean_axes(axes)
        axes.set_xlabel('Time (s)', **label_kwargs)
        axes.set_ylabel(ylabel, **label_kwargs)
        axes.set_ylim(*ylim)
        axes.set_xlim(0., 7500.)
        axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
        axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(1.3198, 1.2169)
            fig.savefig(data_dir + svg_title + ' - Summary - ' + suffix + '.svg', format='svg', transparent=True)
        plt.show()
        plt.close()

    try:
        if plot_ramp:
            baseline = np.mean(mean_ramp['mean'][key_list[0]][int(baseline_range[0]/dt):int(baseline_range[1]/dt)])
            _plot_mean_sem(mean_ramp, 'DVm (mV)', (-1., 12.), 'Ramp', offset=baseline)
        if plot_theta:
            _plot_mean_sem(mean_theta_envelope, 'Thetaintra (mV)', (0., 4.), 'Theta')
        _plot_mean_sem(mean_output, 'Firing rate (Hz)', (-1., 5.), 'Rate')
        if plot_variance:
            _plot_mean_sem(mean_binned_var, 'Variance (mV2)', (0., 7.), 'Variance', x_by_condition=binned_t['mean'],
                           label_fontsize=8)
    finally:
        mpl.rcParams['font.size'] = remember_font_size


def plot_parameter_across_cells(rec_t, parameter_dict, parameter_title=None, ylabel=None, conditions=None,
                                nested_key=None, colors=None, subtract_baseline=False, baseline_range=[0., 600.],
                                dt=0.02, svg_title=None):
    """
    Generic method for plotting mean and variance of a parameter recorded from multiple simulations with different
    random seeds. Assumes a dict with structure {seed: {'condition': value}}, for example the output of
    get_low_pass_recs().
    :param rec_t: array
    :param parameter_dict: dict of dict of array
    :param parameter_title: str
    :param ylabel: str
    :param conditions: list of tuple
    :param nested_key: str
    :param colors: list of tuple
    :param subtract_baseline: bool
    :param baseline_range: list of float
    :param dt: float
    :param svg_title: str
    """
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    if conditions is None:
        conditions = [('modinh0', 'Control'), ('modinh3', 'Reduced inhibition')]
    condition_keys = [item[0] for item in conditions]
    condition_labels = [item[1] for item in conditions]
    if colors is None:
        colors = [('k', 'grey'), ('orange', 'orange')]
    seeds = parameter_dict.keys()
    if nested_key is None:
        parameter = dict(parameter_dict)
    elif nested_key in parameter_dict.itervalues().next().itervalues().next().keys():
        parameter = {seed: {condition: parameter_dict[seed][condition][nested_key] for condition in condition_keys}
                     for seed in seeds}
    else:
        print nested_key, 'not a key within the provided dictionary.'
        return None
    parameter['mean'], parameter['var'] = {}, {}
    for condition in condition_keys:
        parameter['mean'][condition] = np.mean([parameter[seed][condition] for seed in seeds], axis=0)
        parameter['var'][condition] = np.var([parameter[seed][condition] for seed in seeds], axis=0)
    fig, axes = plt.subplots(1)
    if subtract_baseline:
        baseline = np.mean(parameter['mean'][condition_keys[0]][int(baseline_range[0]/dt):int(baseline_range[1]/dt)])
        for condition in condition_keys:
            parameter['mean'][condition] -= baseline
    for i, (condition, title) in enumerate(conditions):
        this_mean = parameter['mean'][condition]
        this_variance = parameter['var'][condition]
        this_SEM = np.divide(np.sqrt(this_variance), np.sqrt(float(len(seeds))))
        axes.plot(rec_t, np.subtract(this_mean, this_SEM), color=colors[i][1])
        axes.plot(rec_t, np.add(this_mean, this_SEM), color=colors[i][1])
        axes.plot(rec_t, this_mean, color=colors[i][0], label=title)
    clean_axes(axes)
    axes.set_xlabel('Time (s)')
    if ylabel is not None:
        axes.set_ylabel(ylabel)
    # axes.set_ylim(-0.8, 10.)
    axes.set_xlim(0., 7500.)
    axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
    axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
    axes.tick_params(direction='out')
    plt.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    if parameter_title is not None:
        axes.set_title(parameter_title)
    if svg_title is not None:
        fig.set_size_inches(1.3198, 1.2169)
        if parameter_title is None:
            parameter_title = ''
        fig.savefig(data_dir+svg_title+' - Summary - ' + parameter_title + '.svg', format='svg', transparent=True)
    plt.show()
    plt.close()
    if svg_title is not None:
        mpl.rcParams['font.size'] = remember_font_size


def plot_phase_precession(t_array, phase_array, title, fit_start=3600., fit_end=5400., display_start=0.,
                          display_end=7500., bin_size=60., num_bins=5, svg_title=None, plot=True, adjust=True):
    """

    :param t_array: list of np.array
    :param phase_array: list of np.array
    :param title: str
    :param fit_start: float
    :param fit_end: float
    :param display_start: float
    :param display_end: float
    :param bin_size: float
    :param num_bins: int
    :param svg_title: str
    :param plot: bool
    :param adjust: bool
    """
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    fig, axes = plt.subplots(1)
    for i, t in enumerate(t_array):
        phases = phase_array[i]
        axes.scatter(t, phases, color='gray', s=0.1)  # colors[i])
        axes.scatter(t, np.add(phases, 360.), color='gray', s=0.1)  # colors[i])
    binned_times = []
    binned_phases = []
    all_spike_times = []
    all_spike_phases = []
    window_dur = float(num_bins) * bin_size
    start = display_start
    while start + window_dur <= display_end:
        index_array = [np.where((spike_times >= start) & (spike_times < start + window_dur))[0] for spike_times in
                                                                                                        t_array]
        spike_times = []
        spike_phases = []
        for i in range(len(t_array)):
            for spike_time, spike_phase in zip(t_array[i][index_array[i]], phase_array[i][index_array[i]]):
                spike_times.append(spike_time)
                spike_phases.append(spike_phase)
        if np.any(spike_times):
            binned_times.append(start+window_dur/2.)  # (np.mean(spike_times))
            binned_phases.append(stats.circmean(spike_phases, high=360., low=0.))
            all_spike_times.extend(spike_times)
            all_spike_phases.extend(spike_phases)
        start += window_dur
    all_spike_times = np.array(all_spike_times)
    all_spike_phases = np.array(all_spike_phases)
    binned_times = np.array(binned_times)
    binned_phases = np.array(binned_phases)
    #indexes = np.where((all_spike_times >= fit_start) & (all_spike_times <= fit_end))[0]
    indexes = np.where((binned_times > fit_start) & (binned_times < fit_end))[0]
    #m, b = np.polyfit(all_spike_times[indexes], all_spike_phases[indexes], 1)
    m, b = np.polyfit(binned_times[indexes], binned_phases[indexes], 1)
    #indexes = np.where((all_spike_times >= fit_start) & (all_spike_times <= fit_end))[0]
    #fit_t = np.arange(np.min(all_spike_times[indexes]), np.max(all_spike_times[indexes]), 10.)
    #fit_t = np.arange(np.min(binned_times[indexes]), np.max(binned_times[indexes]+window_dur/2.), window_dur)
    fit_t = np.arange(fit_start, fit_end+bin_size, bin_size)
    if adjust:
        for i in range(0, len(binned_phases)):
            if i == 0:
                error0 = abs(binned_phases[i + 1] - binned_phases[i])
                error1 = abs(binned_phases[i + 1] - (binned_phases[i] + 360.))
                error2 = abs(binned_phases[i + 1] - (binned_phases[i] - 360.))
                if (error1 < error0) and (error1 < error2):
                    binned_phases[i] = binned_phases[i] + 360.
                elif error2 < error0:
                    binned_phases[i] = binned_phases[i] - 360.
            else:
                error0 = abs(binned_phases[i] - binned_phases[i - 1])
                error1 = abs((binned_phases[i] + 360.) - binned_phases[i - 1])
                error2 = abs((binned_phases[i] - 360.) - binned_phases[i - 1])
                if (error1 < error0) and (error1 < error2):
                    binned_phases[i] = binned_phases[i] + 360.
                elif error2 < error0:
                    binned_phases[i] = binned_phases[i] - 360.
    axes.plot(binned_times, binned_phases, c='k')  # , linewidth=2.)
    axes.plot(fit_t, m * fit_t + b, c='r')  # , linewidth=2.)
    #axes.set_ylim(0., 360.)
    axes.set_ylim(0., 720.)
    axes.set_yticks([0., 180., 360., 540., 720.])
    axes.set_xlim(display_start, display_end)
    axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
    if svg_title is not None:
        axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
    axes.set_ylabel('Phase (Degrees)')  # , fontsize=20)
    axes.set_xlabel('Time (s)')  # , fontsize=20)
    axes.set_title('Phase Precession - '+title, fontsize=mpl.rcParams['font.size'])
    clean_axes(axes)
    axes.tick_params(direction='out')
    if svg_title is not None:
        fig.set_size_inches(1.21, 2.)
        fig.savefig(data_dir+svg_title+' - Precession - '+title+'.svg', format='svg', transparent=True)
    if plot:
        plt.show()
    plt.close()
    print title, abs(m * (fit_end - fit_start))
    if svg_title is not None:
        mpl.rcParams['font.size'] = remember_font_size
    return binned_times, binned_phases


def plot_phase_precession_paired(rec_filenames, conditions=None, titles=None, fit_start=2700., fit_end=6300.,
                                 display_start=0., display_end=7500., bin_size=60., num_bins=5, svg_title=None,
                                 plot=True, adjust=True):
    """
    :param rec_filenames: list of str
    :param conditions: list of str
    :param titles: list of str
    :param fit_start: float
    :param fit_end: float
    :param display_start: float
    :param display_end: float
    :param bin_size: float
    :param num_bins: int
    :param svg_title: str
    :param plot: bool
    :param adjust: bool
    """
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    if conditions is None:
        conditions = ['modinh0', 'modinh3']
    if titles is None:
        titles = ['Control', 'Reduced inhibition']
    spike_times, spike_phases, intra_peaks, intra_phases, binned_spike_times, binned_spike_phases, binned_intra_peaks, \
        binned_intra_phases = {}, {}, {}, {}, {}, {}, {}, {}
    for condition in conditions:
        spike_times[condition], spike_phases[condition], intra_peaks[condition], intra_phases[condition] = \
            get_phase_precession(rec_filenames[condition])
        binned_spike_times[condition], binned_spike_phases[condition], binned_intra_peaks[condition], \
            binned_intra_phases[condition] = [], [], [], []
    window_dur = float(num_bins) * bin_size
    orig_fit_start = fit_start
    orig_fit_end = fit_end
    for (event_times, event_phases), (binned_event_times, binned_event_phases), param_type in \
            zip([(spike_times, spike_phases),
                 (intra_peaks, intra_phases)],
                [(binned_spike_times, binned_spike_phases), (binned_intra_peaks, binned_intra_phases)],
                ['Spikes', 'Intra']):
        consistent = False
        fit_start = orig_fit_start
        fit_end = orig_fit_end
        for condition, title in zip(conditions, titles):
            fig, axes = plt.subplots(1)
            for trial in range(len(event_times[condition])):
                axes.scatter(event_times[condition][trial], event_phases[condition][trial], color='gray', s=0.1)
            start = display_start
            while start + window_dur <= display_end:
                event_time_buffer = []
                event_phase_buffer = []
                for trial in range(len(event_times[condition])):
                    indexes = np.where((np.array(event_times[condition][trial]) >= start) &
                                       (np.array(event_times[condition][trial]) < start + window_dur))[0]
                    if np.any(indexes):
                        event_time_buffer.extend(event_times[condition][trial][indexes])
                        event_phase_buffer.extend(event_phases[condition][trial][indexes])
                if np.any(event_time_buffer):
                    binned_event_times[condition].append(start+window_dur/2.)
                    binned_event_phases[condition].append(stats.circmean(event_phase_buffer, high=360.,
                                                                         low=0.))
                start += window_dur
            binned_event_times[condition] = np.array(binned_event_times[condition])
            binned_event_phases[condition] = np.array(binned_event_phases[condition])
            if not consistent:  # choose fit_start and fit_end based on control spikes, then use for other conditions
                indexes = np.where((binned_event_times[condition] > fit_start) &
                                   (binned_event_times[condition] < fit_end))[0]
                if np.any(indexes):
                    start_index = np.where(binned_event_phases[condition] ==
                                           np.max(binned_event_phases[condition][indexes]))[0][0]
                    end_index = np.where(binned_event_phases[condition] ==
                                           np.min(binned_event_phases[condition][indexes]))[0][0]
                    fit_start = binned_event_times[condition][start_index] - window_dur / 2.
                    fit_end = binned_event_times[condition][end_index] + window_dur / 2.
                consistent = True
            indexes = np.where((binned_event_times[condition] > fit_start) &
                               (binned_event_times[condition] < fit_end))[0]
            if np.any(indexes):
                m, b = np.polyfit(binned_event_times[condition][indexes], binned_event_phases[condition][indexes], 1)
                fit_t = np.arange(fit_start + window_dur / 2., fit_end, bin_size)
                if adjust:
                    for i in range(1, len(binned_event_phases[condition][:-1])):
                        if binned_event_phases[condition][i] < 90.:
                            if np.abs(binned_event_phases[condition][i] - binned_event_phases[condition][i+1]) > 180.:
                                binned_event_phases[condition][i] = binned_event_phases[condition][i] + 360.
                axes.plot(binned_event_times[condition], binned_event_phases[condition], c='k')
                axes.plot(fit_t, m * fit_t + b, c='r')
                axes.set_ylim(0., 360.)
                axes.set_yticks([0., 180., 360.])
                axes.set_xlim(display_start, display_end)
                axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
                if svg_title is not None:
                    axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
                axes.set_ylabel('Theta phase (o)')
                axes.set_xlabel('Time (s)')
                axes.set_title('Phase precession - '+param_type+' '+title, fontsize=mpl.rcParams['font.size'])
                clean_axes(axes)
                axes.tick_params(direction='out')
                if svg_title is not None:
                    fig.set_size_inches(1.1859, 1.035)
                    fig.savefig(data_dir+svg_title+' - Precession - '+param_type+' '+condition+'.svg', format='svg',
                                transparent=True)
                if plot:
                    plt.show()
                plt.close()
                print param_type, condition, abs(m * (fit_end - fit_start - window_dur))
    if svg_title is not None:
        mpl.rcParams['font.size'] = remember_font_size
    return binned_spike_times, binned_spike_phases, binned_intra_peaks, binned_intra_phases


def plot_phase_precession_sliding(t_array, phase_array, title, fit_start=3660., fit_end=5400., display_start=0.,
                          display_end=7500., bin_size=60., num_bins=3, svg_title=None):
    """

    :param t_array: list of np.array
    :param phase_array: list of np.array
    :param title: str
    :param fit_start: float
    :param fit_end: float
    :param display_start: float
    :param display_end: float
    :param bin_size: float
    :param num_bins: int
    """
    fig, axes = plt.subplots(1)
    for i, t in enumerate(t_array):
        phases = phase_array[i]
        axes.scatter(t, phases, color='lightgray')  # colors[i])
    binned_times = []
    binned_phases = []
    window_dur = float(num_bins) * bin_size
    start = display_start
    while start < display_end:
        index_array = [np.where((spike_times >= start) & (spike_times < start + window_dur))[0] for spike_times in
                                                                                                        t_array]
        spike_times = []
        spike_phases = []
        for i in range(len(t_array)):
            for spike_time, spike_phase in zip(t_array[i][index_array[i]], phase_array[i][index_array[i]]):
                spike_times.append(spike_time)
                spike_phases.append(spike_phase)
        if np.any(spike_times):
            binned_times.append(start+window_dur/2.)  # (np.mean(spike_times))
            binned_phases.append(stats.circmean(spike_phases, high=360., low=0.))
        start += bin_size
    binned_times = np.array(binned_times)
    binned_phases = np.array(binned_phases)
    indexes = np.where((binned_times >= fit_start+window_dur/2.) & (binned_times <= fit_end-window_dur/2.))[0]
    m, b = np.polyfit(binned_times[indexes], binned_phases[indexes], 1)
    fit_t = np.arange(fit_start, fit_end+bin_size, bin_size)
    axes.plot(binned_times, binned_phases, c='k', linewidth=2.)
    axes.plot(fit_t, m * fit_t + b, c='r', linewidth=2.)
    axes.set_ylim(0., 360.)
    axes.set_xlim(display_start, display_end)
    axes.set_ylabel('Phase (Degrees)', fontsize=20)
    axes.set_xlabel('Time (ms)', fontsize=20)
    axes.set_title('Phase Precession - '+ title, fontsize=20)
    clean_axes(axes)
    if svg_title is not None:
        plt.savefig(data_dir+svg_title+'.svg', format='svg')
    plt.show()
    plt.close()
    print title, abs(m * (fit_end - fit_start))
    return binned_times, binned_phases


def process_simple_axon_model_output(rec_filename):
    """
    Measure how a voltage plateau propagates to each recording site of the simple axon model.
    For every simulation in the file, each recording is resampled onto a uniform time base (dt = 0.01). A baseline is
    taken from the window (equilibrate - 3, equilibrate - 1). The minimum voltage in the 10 window units ending 1 unit
    before equilibrate + stim_dur, relative to that baseline, is then divided by the simulation's 'plateau'
    attribute. Timing values come from the attrs of simulation '0' (duration=400, equilibrate=250, stim_dur=50 when
    absent).
    :param rec_filename: str: name of an .hdf5 file in data_dir, without the extension
    :return: tuple: (distances, propagation)
            distances: list of the 'soma_distance' attribute of each recording in the first simulation
            propagation: dict {vm_amp_target: list of float}; one ratio is appended per recording for every
            simulation with that target
    """
    dt = 0.01
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        first_attrs = f['0'].attrs
        duration = first_attrs['duration'] if 'duration' in first_attrs else 400.
        equilibrate = first_attrs['equilibrate'] if 'equilibrate' in first_attrs else 250.
        stim_dur = first_attrs['stim_dur'] if 'stim_dur' in first_attrs else 50.
        t = np.arange(0., duration, dt)
        # baseline window: from 3 to 1 time units before equilibrate
        left = int((equilibrate-3.) / dt)
        right = int((equilibrate-1.) / dt)
        # measurement window: the 10 time units ending 1 unit before equilibrate + stim_dur
        start = int((equilibrate+stim_dur-11.) / dt)
        end = int((equilibrate+stim_dur-1.) / dt)
        distances = [rec.attrs['soma_distance'] for rec in f.itervalues().next()['rec'].itervalues()]
        propagation = {sim.attrs['vm_amp_target']: [] for sim in f.itervalues()}
        for sim in f.itervalues():
            target = sim.attrs['vm_amp_target']
            soma_plateau = sim.attrs['plateau']
            for rec in sim['rec'].itervalues():
                vm = np.interp(t, sim['time'], rec)
                baseline = np.mean(vm[left:right])
                plateau = np.min(vm[start:end]) - baseline
                propagation[target].append(plateau / soma_plateau)
    return distances, propagation


def get_spike_delay_vs_distance_simple_axon_model(rec_filename):
    """
    Measure, for each recording site, the delay of the local voltage peak relative to the peak at recording '0'
    (treated as the soma). The first simulation (in file iteration order) is used if it meets two conditions: the
    'amp' attribute of stim '0' is > 0, and dV/dt at recording '0' exceeds th_dvdt between equilibrate + 0.4 and
    equilibrate + stim_dur. The reference peak is searched from that threshold crossing to 5 time units after it.
    Every recording's peak is searched from 2 units before the crossing to that same end point.
    Timing values come from the attrs of simulation '0' (duration=400, equilibrate=250, stim_dur=50 when absent).
    :param rec_filename: str: name of an .hdf5 file in data_dir, without the extension
    :return: tuple: (distances, delays)
            distances: list of the 'soma_distance' attribute of each recording of the first simulation
            delays: list of (peak time - reference peak time), one per recording of the selected simulation; empty if
            no simulation reached threshold
    """
    dt = 0.01
    th_dvdt = 20.
    with h5py.File(data_dir+rec_filename+'.hdf5', 'r') as f:
        first_attrs = f['0'].attrs
        duration = first_attrs['duration'] if 'duration' in first_attrs else 400.
        equilibrate = first_attrs['equilibrate'] if 'equilibrate' in first_attrs else 250.
        stim_dur = first_attrs['stim_dur'] if 'stim_dur' in first_attrs else 50.
        t = np.arange(0., duration, dt)
        search_start = int((equilibrate+0.4) / dt)
        search_end = int((equilibrate+stim_dur) / dt)
        distances = []
        delays = []
        for sim in f.itervalues():
            if not distances:
                distances = [rec.attrs['soma_distance'] for rec in sim['rec'].itervalues()]
            if not sim['stim']['0'].attrs['amp'] > 0.:
                continue
            vm = np.interp(t, sim['time'], sim['rec']['0'])
            # scalar spacing: numpy >= 1.13 rejects a 1-element list unless it matches len(vm)
            dvdt = np.gradient(vm, dt)
            th_x = np.where(dvdt[search_start:search_end] > th_dvdt)[0]
            # test by length: th_x.any() would treat a crossing at relative index 0 as "no crossing"
            if not len(th_x):
                continue
            soma_th_x = th_x[0] + search_start
            peak_end = soma_th_x + int(5. / dt)
            soma_peak_x = np.argmax(vm[soma_th_x:peak_end]) + soma_th_x
            soma_peak_t = t[soma_peak_x]
            peak_start = soma_th_x - int(2. / dt)
            for rec in sim['rec'].itervalues():
                vm = np.interp(t, sim['time'], rec)
                peak_x = np.argmax(vm[peak_start:peak_end]) + peak_start
                delays.append(t[peak_x] - soma_peak_t)
            break
    return distances, delays


def plot_patterned_input_binned_rinp(t_dict, phase_dict, r_inp_dict, key_list=None,
                                     t_bounds=[0., 1800., 3600., 5400., 7500.], del_t=300., del_phase=36., plot=1,
                                     svg_title=None):
    """

    :param t_dict: :class:'np.array', results from appending hypo_ and depo_array outputs of get_patterned_input_r_inp()
    :param phase_dict: :class:'np.array'
    :param r_inp_dict: :class:'np.array'
    :param key_list: list of str
    :param t_bounds: list of float, time points defining start and end of track and inhibitory manipulation
    :param del_t: float
    :param del_phase: float
    :param plot: int in [0, 1]
    :param svg_title: str
    :return tuple of array: binned_phase, r_inp_by_phase
    """
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    if key_list is None:
        key_list = ['modinh0', 'modinh3']
    key_list.extend([key_list[0]+'_in', key_list[0]+'_out', key_list[1]+'_in', key_list[1]+'_out'])
    filtered_t_dict, filtered_phase_dict, filtered_r_inp_dict, indexes = {}, {}, {}, {}
    for i in [0, 1]:
        condition = key_list[i]
        indexes[condition] = np.where(r_inp_dict[condition] > 0.)[0]
    for target, source in [(key_list[i], key_list[j]) for (i, j) in zip([2, 4], [0, 1])]:
        indexes[target] = np.where((r_inp_dict[source] > 0.) & (t_dict[source] >= t_bounds[2]) &
                                      (t_dict[source] < t_bounds[3]))[0]
    for target, source in [(key_list[i], key_list[j]) for (i, j) in zip([3, 5], [0, 1])]:
        indexes[target] = np.where((r_inp_dict[source] > 0.) & (t_dict[source] >= t_bounds[0]) &
                                  (t_dict[source] < t_bounds[1]))[0]
    for target, source in [(key_list[i], key_list[j]) for (i, j) in zip(range(len(key_list)), [0, 1, 0, 0, 1, 1])]:
        filtered_t_dict[target] = t_dict[source][indexes[target]]
        filtered_phase_dict[target] = phase_dict[source][indexes[target]]
        filtered_r_inp_dict[target] = r_inp_dict[source][indexes[target]]
    binned_t, binned_phase, r_inp_by_t, r_inp_by_phase = {}, {}, {}, {}
    for condition in key_list:
        binned_t[condition], r_inp_by_t[condition] = [], []
        for t in np.arange(t_bounds[0], t_bounds[4], del_t):
            indexes = np.where((filtered_t_dict[condition] >= t) & (filtered_t_dict[condition] < t + del_t))[0]
            if np.any(indexes):
                binned_t[condition].append(t + del_t / 2.)
                r_inp_by_t[condition].append(np.mean(filtered_r_inp_dict[condition][indexes]))
        binned_phase[condition], r_inp_by_phase[condition] = [], []
        for phase in np.arange(0., 360., del_phase):
            indexes = np.where((filtered_phase_dict[condition] >= phase) &
                               (filtered_phase_dict[condition] < phase + del_phase))[0]
            if np.any(indexes):
                binned_phase[condition].append(phase + del_phase / 2.)
                r_inp_by_phase[condition].append(np.mean(filtered_r_inp_dict[condition][indexes]))
    fig, axes = plt.subplots(1, 2)
    axes[0].plot(binned_t[key_list[0]], r_inp_by_t[key_list[0]], c='k', label='Control')
    axes[0].plot(binned_t[key_list[1]], r_inp_by_t[key_list[1]], c='orange', label='Reduced inhibition')
    axes[0].set_xlabel('Time (ms)')
    axes[0].set_xlim(0., 7500.)
    axes[0].set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
    axes[0].set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
    axes[0].set_ylabel('Input Resistance (MOhm)')
    axes[0].legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    """
    axes[1].plot(binned_phase[key_list[3]], r_inp_by_phase[key_list[3]], c='k', label='Control - Out of field')
    axes[1].plot(binned_phase[key_list[4]], r_inp_by_phase[key_list[4]], c='grey', label='Control - In field')
    axes[1].plot(binned_phase[key_list[1]], r_inp_by_phase[key_list[1]], c='orange',
                 label='Reduced inhibition - Out of field')
    axes[1].plot(binned_phase[key_list[2]], r_inp_by_phase[key_list[2]], c='y', label='Reduced inhibition - In field')
    """
    axes[1].plot(binned_phase[key_list[0]], r_inp_by_phase[key_list[0]], c='k', label='Control')
    axes[1].plot(binned_phase[key_list[1]], r_inp_by_phase[key_list[1]], c='orange', label='Reduced inhibition')
    axes[1].set_xlabel('Theta phase (o)')
    axes[1].set_xlim(0., 360.)
    axes[1].set_xticks([0., 90., 180., 270., 360.])
    axes[1].set_ylabel('Input Resistance (MOhm)')
    axes[1].legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
    clean_axes(axes)
    axes[0].tick_params(direction='out')
    axes[1].tick_params(direction='out')
    fig.subplots_adjust(wspace=0.6)
    for condition in key_list:
        print condition, np.mean(filtered_r_inp_dict[condition])
    if svg_title is not None:
        fig.set_size_inches(4.2, 1.5)
        fig.savefig(data_dir+svg_title+' - r_inp.svg', format='svg', transparent=True)
    if plot:
        plt.show()
    plt.close()
    if svg_title is not None:
        mpl.rcParams['font.size'] = remember_font_size
    return binned_phase, r_inp_by_phase, binned_t, r_inp_by_t


def plot_somatic_rinp_across_cells(t_array, phase_array, r_inp_array, svg_title=None):
    """
    Plot somatic input resistance vs. theta phase, averaged across cells (random seeds), for the 'modinh0' (Control)
    and 'modinh3' (Reduced inhibition) conditions. Each seed is binned by plot_patterned_input_binned_rinp() with
    plot=0. The phase axis is repeated over two cycles (0 - 720).
    The band around each mean trace is +/- SEM across seeds. An earlier version plotted mean +/- variance, which has
    mismatched units.
    :param t_array: dict: {seed: dict of np.array}, passed to plot_patterned_input_binned_rinp() as t_dict
    :param phase_array: dict: {seed: dict of np.array}, passed as phase_dict
    :param r_inp_array: dict: {seed: dict of np.array}, passed as r_inp_dict
    :param svg_title: str: if provided, the figure is saved as .svg in data_dir and the global font size is set to 8
            for the duration of the call
    :raises ValueError: if t_array contains no seeds
    """
    seeds = list(t_array.keys())
    if not seeds:
        raise ValueError('plot_somatic_rinp_across_cells: t_array contains no seeds')
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    try:
        conditions = ['modinh0', 'modinh3']
        r_inp_by_phase = {}
        for seed in seeds:
            # the phase bins are taken from the last seed; all seeds are assumed to produce the same bins
            binned_phase, r_inp_by_phase[seed], _, _ = \
                plot_patterned_input_binned_rinp(t_array[seed], phase_array[seed], r_inp_array[seed], plot=0)
        sqrt_n = np.sqrt(float(len(seeds)))
        phases, mean_r_inp, sem_r_inp = {}, {}, {}
        for condition in conditions:
            across_seeds = [r_inp_by_phase[seed][condition] for seed in seeds]
            this_mean = np.mean(across_seeds, axis=0)
            this_sem = np.divide(np.sqrt(np.var(across_seeds, axis=0)), sqrt_n)
            # repeat the 0 - 360 bins over a second cycle (360 - 720) for display
            phases[condition] = np.append(binned_phase[condition], np.add(binned_phase[condition], 360.))
            mean_r_inp[condition] = np.append(this_mean, this_mean)
            sem_r_inp[condition] = np.append(this_sem, this_sem)
        fig, axes = plt.subplots(1)
        colors = [('k', 'grey'), ('orange', 'orange')]
        for i, (condition, label) in enumerate(zip(conditions, ['Control', 'Reduced inhibition'])):
            axes.plot(phases[condition], mean_r_inp[condition], color=colors[i][0], label=label)
            axes.plot(phases[condition], np.add(mean_r_inp[condition], sem_r_inp[condition]), color=colors[i][1])
            axes.plot(phases[condition], np.subtract(mean_r_inp[condition], sem_r_inp[condition]),
                      color=colors[i][1])
        axes.set_xlabel('Theta phase (o)')
        axes.set_xlim(0., 720.)
        axes.set_xticks([0., 180., 360., 540., 720.])
        axes.set_ylabel('Input Resistance (MOhm)')
        axes.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
        clean_axes(axes)
        axes.tick_params(direction='out')
        if svg_title is not None:
            fig.set_size_inches(2.6, 1.2169)
            fig.savefig(data_dir + svg_title + ' - r_inp.svg', format='svg', transparent=True)
        plt.show()
        plt.close()
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size


def plot_patterned_input_i_syn_summary(rec_filename_array, svg_title=None):
    """
    Plot low-pass filtered synaptic currents and the E:I ratio for three conditions: 'modinh0' (Control), 'modinh2'
    (Reduced inhibition - In field) and 'modinh1' (Reduced inhibition - Out of field).
    Generates 4 figures, each overlaying the 3 conditions: one each for i_AMPA, i_NMDA and i_GABA, and one for the
    E:I ratio |i_AMPA + i_NMDA| / i_GABA. Afterwards, every group (including 'ratio') is passed to
    get_i_syn_mean_values().
    :param rec_filename_array: dict: {condition: list of str} for conditions 'modinh0', 'modinh1' and 'modinh2'. Each
            file is processed with process_i_syn_rec() and contributes every current type it contains (presumably one
            file per i_AMPA / i_NMDA / i_GABA; confirm against callers)
    :param svg_title: str: if provided, each figure is saved as .svg in data_dir and the global font size is set to
            20 while plotting
    """
    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 20
    i_syn_dict = {'i_AMPA': {}, 'i_NMDA': {}, 'i_GABA': {}, 'ratio': {}}
    plot_conditions = ['modinh0', 'modinh2', 'modinh1']
    plot_labels = ['Control', 'Reduced inhibition - In field', 'Reduced inhibition - Out of field']
    colors = ['k', 'y', 'orange']

    def format_time_axis(axes, ylabel):
        """Apply the axis formatting shared by every figure in this function."""
        clean_axes(axes)
        axes.set_xlabel('Time (s)')
        axes.set_ylabel(ylabel)
        axes.set_xlim(0., 7500.)
        axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
        axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
        axes.tick_params(direction='out')

    try:
        for condition in ['modinh0', 'modinh1', 'modinh2']:
            for rec_filename in rec_filename_array[condition]:
                rec_t, _, i_syn_mean_low_pass_dict = process_i_syn_rec(rec_filename)
                for syn_type in i_syn_mean_low_pass_dict:
                    i_syn_dict[syn_type][condition] = i_syn_mean_low_pass_dict[syn_type]
        for group in ['i_AMPA', 'i_NMDA', 'i_GABA']:
            fig, axes = plt.subplots(1)
            for color, condition, label in zip(colors, plot_conditions, plot_labels):
                axes.plot(rec_t, i_syn_dict[group][condition], c=color, label=label, linewidth=1)
            format_time_axis(axes, 'Current (nA)')
            axes.set_title(group, fontsize=mpl.rcParams['font.size'])
            # inhibitory current is plotted as positive, excitatory currents as negative
            if group == 'i_GABA':
                axes.set_ylim(0., .7)
            else:
                axes.set_ylim(-.7, 0.)
            if svg_title is not None:
                fig.set_size_inches(4.403, 3.631)
                fig.savefig(data_dir+svg_title+' - '+group+'.svg', format='svg', transparent=True)
            plt.show()
            plt.close()
        for condition in ['modinh0', 'modinh1', 'modinh2']:
            excitatory = np.add(i_syn_dict['i_AMPA'][condition], i_syn_dict['i_NMDA'][condition])
            i_syn_dict['ratio'][condition] = np.divide(np.abs(excitatory), i_syn_dict['i_GABA'][condition])
        fig, axes = plt.subplots(1)
        for color, condition, label in zip(colors, plot_conditions, plot_labels):
            axes.plot(rec_t, i_syn_dict['ratio'][condition], c=color, label=label, linewidth=2)
        format_time_axis(axes, 'E:I ratio')
        axes.set_ylim(1., 2.8)
        if svg_title is not None:
            fig.set_size_inches(4.403, 3.631)
            fig.savefig(data_dir+svg_title+' - E_I ratio.svg', format='svg', transparent=True)
        plt.show()
        plt.close()
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size
    for group in i_syn_dict:
        get_i_syn_mean_values(i_syn_dict[group], group)


def plot_place_field_i_syn_across_cells(rec_filename_array, groups=None, svg_title=None):
    """
    Plot synaptic currents and the E:I ratio averaged across cells (mean +/- SEM).

    Expects a nested dictionary of rec_filenames {'i_AMPA', 'i_NMDA', 'i_GABA': {seed or group id: {'modinh0',
    'modinh3'}}}. For every syn_type, group and condition, the low-pass filtered mean current returned by
    process_i_syn_rec is collected. Per group, the E:I ratio is computed as |i_AMPA + i_NMDA| / i_GABA.
    Generates 3 figures: one per condition showing the 3 synaptic currents, and one showing the E:I ratio for both
    conditions. Each trace is drawn as the mean across groups, bracketed by mean - SEM and mean + SEM. The SEM uses the
    sample standard deviation (ddof=1). With a single group the SEM band has zero width.
    Finally, for each group, prints the group id and calls get_i_syn_mean_values on every key.
    :param rec_filename_array: dict {syn_type: {group: {condition: str}}}
    :param groups: iterable of dict keys (defaults to all group keys under rec_filename_array['i_AMPA'])
    :param svg_title: str; if not None, figures use an 8 pt font, are resized, and are saved as .svg in data_dir
    :raises ValueError: if there are no groups to average across
    """
    conditions = ['modinh0', 'modinh3']
    syn_types = ['i_AMPA', 'i_NMDA', 'i_GABA']

    if groups is None:
        groups = rec_filename_array['i_AMPA'].keys()
    # Materialize so a generator or dict view can be traversed repeatedly and counted.
    groups = list(groups)
    if not groups:
        raise ValueError('plot_place_field_i_syn_across_cells: no groups to average across')
    num_groups = len(groups)

    def _format_time_axis(axes):
        """Apply the shared 0 - 7.5 s time axis (x data in ms, labels in s)."""
        axes.set_xlabel('Time (s)')
        axes.set_xlim(0., 7500.)
        axes.set_xticks([0., 1500., 3000., 4500., 6000., 7500.])
        axes.set_xticklabels([0, 1.5, 3, 4.5, 6, 7.5])
        axes.tick_params(direction='out')

    def _plot_mean_sem(axes, rec_t, summary, condition, color_pair, label):
        """Draw mean - SEM and mean + SEM in color_pair[1], then the labeled mean trace in color_pair[0]."""
        this_mean = summary['mean'][condition]
        this_SEM = np.divide(np.sqrt(summary['var'][condition]), np.sqrt(float(num_groups)))
        axes.plot(rec_t, np.subtract(this_mean, this_SEM), color=color_pair[1])
        axes.plot(rec_t, np.add(this_mean, this_SEM), color=color_pair[1])
        axes.plot(rec_t, this_mean, color=color_pair[0], label=label)

    def _finish_figure(fig, suffix):
        """Optionally resize and save the figure as svg, then show and close it."""
        if svg_title is not None:
            fig.set_size_inches(1.3198, 1.2169)
            fig.savefig(data_dir + svg_title + suffix, format='svg', transparent=True)
        plt.show()
        plt.close()

    if svg_title is not None:
        remember_font_size = mpl.rcParams['font.size']
        mpl.rcParams['font.size'] = 8
    # Restore the global font size even if loading, plotting or saving raises.
    try:
        i_syn_dict = {key: {group: {} for group in groups} for key in syn_types + ['ratio']}
        for syn_type in syn_types:
            for group in groups:
                for condition in conditions:
                    rec_filename = rec_filename_array[syn_type][group][condition]
                    # Assumes every recording shares the same time base; the last rec_t loaded is used for plotting.
                    rec_t, i_syn_mean_dict, i_syn_mean_low_pass_dict = process_i_syn_rec(rec_filename)
                    i_syn_dict[syn_type][group][condition] = i_syn_mean_low_pass_dict[syn_type]

        for group in groups:
            for condition in conditions:
                excitation = np.add(i_syn_dict['i_AMPA'][group][condition], i_syn_dict['i_NMDA'][group][condition])
                # NOTE(review): any point where i_GABA is 0 yields inf in the ratio (unchanged from before).
                i_syn_dict['ratio'][group][condition] = np.divide(np.abs(excitation),
                                                                  i_syn_dict['i_GABA'][group][condition])

        for key in syn_types + ['ratio']:
            i_syn_dict[key]['mean'] = {}
            i_syn_dict[key]['var'] = {}
            for condition in conditions:
                traces = [i_syn_dict[key][group][condition] for group in groups]
                i_syn_dict[key]['mean'][condition] = np.mean(traces, axis=0)
                if num_groups > 1:
                    # Sample variance (ddof=1), so that sqrt(var / n) is the standard error of the mean.
                    i_syn_dict[key]['var'][condition] = np.var(traces, axis=0, ddof=1)
                else:
                    # Sample variance is undefined for one cell; draw no error band instead of NaNs.
                    i_syn_dict[key]['var'][condition] = np.zeros_like(i_syn_dict[key]['mean'][condition])

        colors = [('c', 'c'), ('k', 'grey'), ('purple', 'purple')]
        for condition, title in zip(conditions, ['Control', 'Reduced inhibition']):
            fig, axes = plt.subplots(1)
            for color_pair, key in zip(colors, syn_types):
                _plot_mean_sem(axes, rec_t, i_syn_dict[key], condition, color_pair, key)
            clean_axes(axes)
            axes.set_ylabel('Current (nA)')
            axes.set_ylim(-0.7, 0.7)
            axes.set_yticks([-0.6, -0.3, 0., 0.3, 0.6])
            _format_time_axis(axes)
            axes.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
            axes.set_title(title, fontsize=mpl.rcParams['font.size'])
            _finish_figure(fig, ' - i_syn - ' + condition + '.svg')

        key = 'ratio'
        ratio_colors = [('k', 'grey'), ('orange', 'orange')]
        fig, axes = plt.subplots(1)
        for color_pair, condition, title in zip(ratio_colors, conditions, ['Control', 'Reduced inhibition']):
            _plot_mean_sem(axes, rec_t, i_syn_dict[key], condition, color_pair, title)
        clean_axes(axes)
        axes.set_ylabel('Ratio')
        axes.set_ylim(1., 2.8)
        axes.set_yticks([1.0, 1.25, 1.5, 1.75, 2.0, 2.25, 2.5, 2.75])
        _format_time_axis(axes)
        axes.set_title('E:I ratio', fontsize=mpl.rcParams['font.size'])
        axes.legend(loc='best', frameon=False, framealpha=0.5, fontsize=mpl.rcParams['font.size'])
        _finish_figure(fig, ' - i_syn - ' + key + '.svg')
    finally:
        if svg_title is not None:
            mpl.rcParams['font.size'] = remember_font_size

    for group in groups:
        print('%s :' % (group,))
        for key in i_syn_dict:
            # The 'modinh3' traces are duplicated under 'modinh3_out' because that name appears in the condition
            # list passed to get_i_syn_mean_values (presumably it labels the results by condition; confirm there).
            i_syn_dict[key][group]['modinh3_out'] = copy.deepcopy(i_syn_dict[key][group]['modinh3'])
            get_i_syn_mean_values(i_syn_dict[key][group], key, ['modinh0', 'modinh3_out', 'modinh3'])