# pn_kc_ggn_plot.py --- 
# 
# Filename: pn_kc_ggn_plot.py
# Description: 
# Author: Subhasis Ray
# Maintainer: 
# Created: Fri Feb 16 13:08:41 2018 (-0500)
# Last-Updated: Mon Mar  4 11:07:59 2019 (-0500)
#           By: Subhasis  Ray
#     Update #: 1108
# 
# Code:
from __future__ import print_function
import warnings
import sys
import os
from timeit import default_timer as timer
import numpy as np
import random
import h5py as h5
import yaml
import pandas as pd
from matplotlib import pyplot as plt
from pint import UnitRegistry

import sys

import neurograph as ng
from sklearn.neighbors import KernelDensity
import network_data_analysis as nda

# Unit registry used to convert quantity strings from the simulation
# config (e.g. '500ms') into numeric values in consistent units.
ur = UnitRegistry()
Q_ = ur.Quantity


def despine(ax, locs='all'):
    """Hide the axis spines of `ax` named in `locs`.

    `locs` is either the string 'all' or an iterable containing any of
    'top', 'bottom', 'left' and 'right'.
    """
    sides = ('top', 'bottom', 'left', 'right') if locs == 'all' else locs
    for side in sides:
        ax.spines[side].set_visible(False)

def plot_population_psth(ax, spike_trains, ncell, bins, alpha=0.5, rate_sym='b-', cell_sym='r--'):
    """Plot the population PSTH on axis `ax`.

    `spike_trains` is a list of arrays, each holding the spike times of
    one cell.  `ncell` is the total number of cells and is used for
    normalizing the rates per cell.  `bins` gives the histogram bin
    edges (times in ms).  `rate_sym`/`cell_sym` are matplotlib format
    strings for the two curves; pass None to use default styling.

    Returns (ax, spike_rate, cell_frac).
    """
    start = timer()
    nbins = len(bins) - 1
    cell_counts = np.zeros(nbins)
    spike_counts = np.zeros(nbins)
    spiking_cell_count = 0
    for train in spike_trains:
        if len(train) > 0:
            spiking_cell_count += 1
        hist, bins = np.histogram(train, bins)
        cell_counts += (hist > 0)
        spike_counts += hist
    print('Total number of spiking cells', spiking_cell_count, 'out of', len(spike_trains))
    # Times are in ms, hence the factor 1e3 to get per-second rates.
    norm = 1e3 / (ncell * (bins[1] - bins[0]))
    spike_rate = spike_counts * norm
    cell_frac = cell_counts * norm
    centers = (bins[:-1] + bins[1:]) / 2.0
    rate_fmt = [] if rate_sym is None else [rate_sym]
    ax.plot(centers, spike_rate, *rate_fmt,
            label='spikes/ncells/binwidth(s)', alpha=alpha)
    cell_fmt = [] if cell_sym is None else [cell_sym]
    ax.plot(centers, cell_frac, *cell_fmt,
            label='cells spiking/ncells/binwidth(s)', alpha=alpha)
    end = timer()
    print('Plotted PSTH for {} spike trains in {} s'.format(
        len(spike_trains),
        end - start))
    return ax, spike_rate, cell_frac


def plot_population_KDE(ax, spike_trains, xgrid, bandwidth, color='k',
                        maxamp=1.0):
    """Plot a kernel density estimate of the pooled spike times.

    All trains in `spike_trains` are pooled into a single sample and the
    PDF is estimated over `xgrid` using bandwidth `bandwidth`.  The
    curve is scaled so that its peak equals `maxamp` (useful for
    bringing it to scale with a histogram).

    Returns (pdf, xgrid), or None (with a warning) if there are no
    spikes at all.
    """
    pooled = np.concatenate(spike_trains)
    if len(pooled) == 0:
        warnings.warn('No spikes in spike trains')
        return
    pdf, xgrid = nda.kde_pdf(pooled, bandwidth=bandwidth, xgrid=xgrid)
    peak = max(pdf)
    scale = maxamp / peak if peak > 0 else 1.0
    ax.plot(xgrid, pdf * scale / len(spike_trains), color=color, label='KDE')
    return pdf, xgrid


def plot_kc_spikes_by_cluster(ax, fd, ca_side, color='k', marker=',', hline=False, subsample_clus=0, subsample_kc=0):
    """Raster plot KC spike times ordered by cluster no.

    ca_side should be 'lca' or 'mca'.

    color: unused; kept for interface compatibility.

    subsample_clus: number of clusters to plot if subsampling (> 0).

    subsample_kc:  number of kcs in each cluster to plot if subsampling (> 0).

    Returns (list of plotted line lists, spike-time arrays, y-position arrays).
    """
    # `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole dataset.
    kc_spikes = {row['source'].decode(): row['data']
                 for row in fd['/map/event/kc/kc_spiketime'][()]}

    cluster_info = nda.extract_cluster_info(fd, ca_side)
    ii = 0
    spike_x, spike_y = [], []
    ret = []
    labels = list(set(cluster_info['label']))
    if subsample_clus > 0:
        labels = np.random.choice(labels, size=subsample_clus, replace=False)
    labelgrp = cluster_info.groupby('label')
    for label in labels:
        group = labelgrp.get_group(label)
        grpx, grpy = [], []
        kcs = np.char.decode(group.index.values.flatten().astype('S'))
        if subsample_kc > 0:
            kcs = np.random.choice(kcs, size=subsample_kc, replace=False)
        for kc in kcs:
            st = fd[kc_spikes[kc]][()]
            grpx.append(st)
            # one raster row per KC; ii is the running row index
            grpy.append(np.ones(len(st)) + ii)
            ii += 1
        ret.append(ax.plot(np.concatenate(grpx), np.concatenate(grpy),
                           marker=marker, linestyle='none'))
        spike_x += grpx
        spike_y += grpy
        if hline:
            # separator between successive clusters
            ax.axhline(y=ii-0.5, color='gray', linewidth=1.0)
    return ret, spike_x, spike_y


def plot_kc_spikes(ax, fd, ca_side='both', color='k', marker=','):
    """Raster plot KC spike times for KCs belonging to the specified side
    of calyx ('lca' or 'mca'), or all KCs when ca_side is 'both'.

    This function does not care about spatial clusters.

    If the spike data are not present in `fd` itself, falls back to a
    side file named 'kc_spikes_<basename>' in the same directory (early
    simulations did not record KC spike times in the main file).

    Returns: the line object (None if there were no spikes), list of
    spike times and list of their y-positions.
    """
    if ca_side == 'both':
        nodes = fd[nda.kc_st_path].keys()
    else:
        nodes = nda.get_kc_spike_nodes_by_region(fd, ca_side)
    spike_x, spike_y = [], []
    fname = fd.filename
    try:
        spike_x, spike_y = nda.get_event_times(
            fd['/data/event/kc/kc_spiketime'],
            nodes=nodes)
    except KeyError:
        dirname = os.path.dirname(fname)
        fname = 'kc_spikes_' + os.path.basename(fd.filename)
        # Open read-only: h5py < 3 defaulted to append mode, which would
        # create or allow modification of the side file instead of just
        # reading it.
        with h5.File(os.path.join(dirname, fname), 'r') as kc_file:
            spike_x, spike_y = nda.get_event_times(kc_file, nodes=nodes)
    if len(spike_x) > 0:
        ret = ax.plot(np.concatenate(spike_x), np.concatenate(spike_y),
                      color=color, marker=marker, linestyle='none')
    else:
        ret = None
    return ret, spike_x, spike_y


def plot_kc_vm(ax, fd, region, count, color='k', alpha=0.5):
    """Plot Vm of up to `count` randomly sampled KCs from specified region.

    Returns the list of selected (name, index) pairs (empty if no KC
    matches the region).
    """
    match = nda.get_kc_vm_idx_by_region(fd, region)
    if len(match) == 0:
        # Return the same type as the normal path below; the original
        # returned a two-tuple `[], []` here, inconsistent with the
        # single list returned otherwise.
        return []
    selected = random.sample(match, min(count, len(match)))
    kc_vm_node = fd['/data/uniform/kc/KC_Vm']
    try:
        # Use the recorded timestep when available ...
        t = np.arange(kc_vm_node.shape[1]) * kc_vm_node.attrs['dt']
    except KeyError:
        # ... otherwise fall back to sample indices on the x axis.
        t = np.arange(kc_vm_node.shape[1])
    for name, idx in selected:
        ax.plot(t, kc_vm_node[idx, :], label=name, color=color, alpha=alpha)
    return selected


def plot_ggn_vm(ax, fd, dataset, region=None, count=5, color='k', alpha=0.5):
    """Plot Vm traces of up to `count` GGN sections from `dataset`.

    `region` selects sections by structure id (via `ng.name_sid`); None
    means all sections.  Returns the list of selected (section, row
    index) pairs (empty if nothing matches).
    """
    sec_list = [sec.decode('utf-8')  for sec in dataset.dims[0]['source']]
    match = []
    if region is None:
        match = [(sec, ii) for ii, sec in enumerate(sec_list)]
    else:
        rsid = ng.name_sid[region]
        for ii, sec in enumerate(sec_list):
            sid = nda.ggn_sec_to_sid(sec)
            if sid == rsid:
                match.append((sec, ii))
    if len(match) == 0:
        # Keep the return type consistent with the normal path below;
        # the original returned `[], []` here but a single list otherwise.
        return []
    selected = random.sample(match, min(len(match), count))
    try:
        # Use the recorded timestep when available, else sample indices.
        t = np.arange(dataset.shape[1]) * dataset.attrs['dt']
    except KeyError:
        t = np.arange(dataset.shape[1])
    for sec, ii in selected:
        ax.plot(t, dataset[ii, :], label=sec, color=color, alpha=alpha)
    return selected
    
    
def plot_spike_rasters(fname, vm_samples=10, psth_bin_width=50.0,
                       kde_bw=50.0, by_cluster=False):
    """Plot an overview figure for a pn_kc_ggn simulation data file.

    Six stacked panels: PN spike raster (+PSTH/KDE), KC rasters and
    PSTHs for LCA and MCA, sampled KC Vm for both sides, and GGN Vm.

    fname: path of the HDF5 data file.
    vm_samples: number of Vm traces to sample per region.
    psth_bin_width: PSTH bin width in ms.
    kde_bw: kernel density estimation bandwidth in ms.
    by_cluster: order KC rasters by spatial cluster.

    In the early simulations KC spike times were not recorded in the
    main file; the KC plotting helpers fall back to side files then.

    Returns (fig, axes).
    """
    start = timer()
    print('Processing', fname)
    with h5.File(fname, 'r') as fd:
        try:
            print('Description:', fd.attrs['description'])
        except KeyError:
            print('No description available')
        # The config may be stored as a file attribute, inside the model
        # file tree, or in the template file this one was derived from.
        # Try each location in turn.  yaml.load requires an explicit
        # Loader since PyYAML 5.1 (mandatory in 6.0); the config is
        # plain data, so SafeLoader suffices.
        config = None
        try:
            config = yaml.load(fd.attrs['config'].decode(),
                               Loader=yaml.SafeLoader)
        except KeyError:
            try:
                config = yaml.load(
                    fd['model/filecontents/mb/network/config.yaml'][0].decode(),
                    Loader=yaml.SafeLoader)
            except KeyError: # possibly fixed network - look for config in template
                try:
                    original = fd.attrs['original'].decode()
                    with h5.File(original, 'r') as origfd:
                        try:
                            config = yaml.load(origfd.attrs['config'].decode(),
                                               Loader=yaml.SafeLoader)
                        except AttributeError:  # attribute may already be str in Python 3
                            config = yaml.load(origfd.attrs['config'],
                                               Loader=yaml.SafeLoader)
                except KeyError:
                    print('No config attribute or model file')
        if config is None:
            # Without the config the simulation time below cannot be
            # computed; fail with a clear message instead of a NameError.
            raise RuntimeError('Could not find simulation config in {}'.format(fname))
        # PN spike raster
        pn_st = fd['/data/event/pn/pn_spiketime']
        fig, axes = plt.subplots(nrows=6, ncols=1, sharex=True)
        try:
            if 'calyx' in fd.attrs['description'].decode():
                descr = 'KC->GGN in alphaL + CA'
            else:
                descr = 'KC->GGN in alphaL only'
        except KeyError:
            descr = ''

        fig.suptitle('{} {}'.format(os.path.basename(fname), descr))
        ax_pn_spike_raster = axes[0]
        print('Plotting PN spikes')
        ax_pn_spike_raster.set_title('PN spike raster')
        # Sort PN names numerically so pn_2 comes before pn_10
        nodes = [int(node.split('_')[-1]) for node in pn_st]
        nodes = ['pn_{}'.format(node) for node in sorted(nodes)]
        spike_x, spike_y = nda.get_event_times(pn_st, nodes)
        ax_pn_spike_raster.plot(np.concatenate(spike_x), np.concatenate(spike_y), 'k,')
        # Total simulated time in ms from the stimulus protocol
        simtime = Q_(config['stimulus']['onset']).to('ms').m +  \
                  Q_(config['stimulus']['duration']).to('ms').m + \
                  Q_(config['stimulus']['tail']).to('ms').m
        psth_bins = np.arange(0, simtime, psth_bin_width)
        # np.linspace requires an integer sample count (a float raises
        # TypeError in numpy >= 1.18)
        kde_grid = np.linspace(0, simtime, 100)
        ax_pn_psth = ax_pn_spike_raster.twinx()
        _, sr, cf = plot_population_psth(ax_pn_psth, spike_x,
                                         config['pn']['number'], psth_bins)
        plot_population_KDE(ax_pn_psth, spike_x, kde_grid, kde_bw, color='y',
                            maxamp=max(sr))
        ax_kc_lca_spike_raster = axes[1]
        print('Plotting KC PSTH in LCA')
        ax_kc_lca_spike_raster.set_title('KC LCA')
        if by_cluster:
            _, lca_spike_x, lca_spike_y = plot_kc_spikes_by_cluster(ax_kc_lca_spike_raster, fd, 'LCA')
        else:
            _, lca_spike_x, lca_spike_y = plot_kc_spikes(ax_kc_lca_spike_raster, fd, 'LCA')
        ax_kc_lca_psth = ax_kc_lca_spike_raster.twinx()
        _, sr, cf = plot_population_psth(ax_kc_lca_psth, lca_spike_x, len(lca_spike_x), psth_bins)
        plot_population_KDE(ax_kc_lca_psth, lca_spike_x, kde_grid, kde_bw, color='y', maxamp=max(sr))
        ax_kc_lca_psth.legend()
        ax_kc_mca_spike_raster = axes[2]
        print('Plotting KC PSTH in MCA')
        ax_kc_mca_spike_raster.set_title('KC MCA')
        if by_cluster:
            _, mca_spike_x, mca_spike_y = plot_kc_spikes_by_cluster(ax_kc_mca_spike_raster, fd, 'MCA')
        else:
            _, mca_spike_x, mca_spike_y = plot_kc_spikes(ax_kc_mca_spike_raster, fd, 'MCA')
        ax_kc_mca_psth = ax_kc_mca_spike_raster.twinx()
        _, sr, cf = plot_population_psth(ax_kc_mca_psth, mca_spike_x, len(mca_spike_x), psth_bins)
        plot_population_KDE(ax_kc_mca_psth, mca_spike_x, kde_grid, kde_bw, color='y', maxamp=max(sr))
        ax_kc_mca_psth.legend()
        # LCA KC Vm
        ax_kc_lca_vm = axes[3]
        ax_kc_lca_vm.set_title('KC LCA')
        plot_kc_vm(ax_kc_lca_vm, fd, 'LCA', vm_samples)
        # MCA KC Vm
        ax_kc_mca_vm = axes[4]
        ax_kc_mca_vm.set_title('KC MCA')
        plot_kc_vm(ax_kc_mca_vm, fd, 'MCA', vm_samples)
        # GGN MCA Vm, GGN LCA Vm
        ggn_vm_plot = axes[5]
        ggn_vm_plot.set_title('GGN Vm')
        ggn_output_vm = fd['/data/uniform/ggn_output/GGN_output_Vm']
        plot_ggn_vm(ggn_vm_plot, fd, ggn_output_vm, 'LCA', vm_samples, color='r')
        plot_ggn_vm(ggn_vm_plot, fd, ggn_output_vm, 'MCA', vm_samples, color='b')
        # Empty plots carry the legend entries for the colored traces above
        lca, = ggn_vm_plot.plot([], color='r', label='LCA')
        mca, = ggn_vm_plot.plot([], color='b', label='MCA')
        alpha, = ggn_vm_plot.plot([], color='k', label='alphaL')
        basal, = ggn_vm_plot.plot([], color='g', label='basal')
        ggn_vm_plot.legend(handles=[lca, mca, alpha, basal])
        # GGN alphaL Vm
        ggn_alphaL_vm = fd['/data/uniform/ggn_alphaL_input/GGN_alphaL_input_Vm']
        plot_ggn_vm(ggn_vm_plot, fd, ggn_alphaL_vm, 'alphaL', vm_samples, color='k')
        try:
            ggn_basal_vm = fd['/data/uniform/ggn_basal/GGN_basal_Vm']
            plot_ggn_vm(ggn_vm_plot, fd, ggn_basal_vm, 'basal', vm_samples, color='g')
        except KeyError:
            warnings.warn('No basal Vm recorded from GGN')
        end = timer()
        print('Time for plotting {}s'.format(end - start))
        return fig, axes



def plot_data_by_jobid(jid, datadir, save=False, by_cluster=False, figdir='figures'):
    """Plot spike rasters for the HDF5 data file in `datadir` whose name
    contains JID`jid`.

    save: if True, also write the figure as a PNG under `figdir`.
    by_cluster: order KC rasters by spatial cluster.

    Raises RuntimeError if no matching .h5 file is found.
    Returns (fig, axes) from `plot_spike_rasters`.
    """
    flist = os.listdir(datadir)
    match = [fname for fname in flist if 'JID{}'.format(jid) in fname]
    if len(match) > 1:
        print('Two files with same jobid.', match)
    # Pick the HDF5 file among the matches.  Previously, when no match
    # ended in '.h5', the loop fell through leaving `fname` stale (or
    # unbound for an empty match), causing a confusing failure later.
    fname = next((f for f in match if f.endswith('.h5')), None)
    if fname is None:
        raise RuntimeError(
            'No .h5 file with JID{} found in {}'.format(jid, datadir))
    fig, ax = plot_spike_rasters(os.path.join(datadir, fname), by_cluster=by_cluster)
    if save:
        figfile = os.path.join(figdir, fname.rpartition('.h5')[0] + '.png')
        fig.savefig(figfile)
        print('Saved figure in', figfile)
    return fig, ax


def plot_kc_spike_count_hist(fname, bins=None, save=False, figdir='figures'):
    """Plot histograms of KC spike counts for LCA and MCA.

    Counts are taken from the 1D event datasets under
    /data/event/kc/kc_spiketime in the file `fname`.  When `bins` is
    None, unit-width bins covering the observed range are used.  When
    `save` is True the figure is written as a PNG under `figdir`.

    Returns (fig, axes).
    """
    with h5.File(fname, 'r') as fd:
        kc_st_grp = fd['/data/event/kc/kc_spiketime']
        lca_kcs = nda.get_kc_spike_nodes_by_region(fd, 'LCA')
        lca_spike_count = [len(kc_st_grp[kc]) for kc in lca_kcs]
        # Use the same capitalization as every other call site in this
        # module (was 'mca').
        mca_kcs = nda.get_kc_spike_nodes_by_region(fd, 'MCA')
        mca_spike_count = [len(kc_st_grp[kc]) for kc in mca_kcs]
        fig, axes = plt.subplots(nrows=2, ncols=1, sharey='all', sharex='all')
        if bins is None:
            counts = lca_spike_count + mca_spike_count
            # Guard against max() on an empty sequence when no KC matched
            maxcount = max(counts) if counts else 1
            bins = np.arange(maxcount + 1)
        hist, bins, patches = axes[0].hist(lca_spike_count, bins=bins)
        if lca_spike_count:
            # Arrow marking the largest observed spike count
            axes[0].arrow(max(lca_spike_count), max(hist)/2, 0, -max(hist)/2.0,
                          head_width=0.5, head_length=max(hist)/20.0,
                          length_includes_head=True)
        axes[0].set_title('LCA')
        hist, bins, patches = axes[1].hist(mca_spike_count, bins=bins)
        if mca_spike_count:
            axes[1].arrow(max(mca_spike_count), max(hist)/2, 0, -max(hist)/2,
                          head_width=0.5, head_length=max(hist)/20.0,
                          length_includes_head=True)
        axes[1].set_title('MCA')
        fname = os.path.basename(fname)
        fig.suptitle(os.path.basename(fname))
        fname = fname.rpartition('.h5')[0]
        if save:
            fname = os.path.join(figdir, fname) + '_kc_spikecount_hist.png'
            fig.tight_layout()
            fig.savefig(fname)
            print('Saved spike count histogram of KCs in', fname)
    return fig, axes


def plot_kc_clusters_with_presynaptic_pn(fpath, region, clusters):
    """Plot the spike rasters for KCs by cluster labels in clusters and
    the spike rasters for presynaptic PNs.

    clusters: currently unused; 5 cluster labels are sampled at random.
    # NOTE(review): presumably `clusters` was meant to select the labels
    # to plot -- confirm intent before wiring it in.
    """
    lca_clusters = nda.get_kc_clusters(fpath, region)
    print(lca_clusters.shape)
    lca_clusters = pd.DataFrame(data=lca_clusters[:, 0])
    lca_grp = lca_clusters.groupby('label')
    sec_st = {}
    pn_st = {}
    with h5.File(fpath, 'r') as fd:
        sec_st_path = nda.get_kc_event_node_map(fd)
        for kc, path in sec_st_path.items():
            # `Dataset.value` was removed in h5py 3.0; `[()]` reads the
            # whole dataset.
            sec_st[kc] = fd[path][()]
        pn_kc_syn = fd[nda.pn_kc_syn_path]['pre', 'post']
        for node in fd['/data/event/pn/pn_spiketime'].values():
            pn_st[node.attrs['source'].decode('utf-8')] = node[()]

    # dict views are not sequences in Python 3; np.random.choice needs a
    # list (passing the view raises ValueError).
    keys = np.random.choice(list(lca_grp.groups.keys()), size=5, replace=False)
    pn_kc_syn = pd.DataFrame(pn_kc_syn[:, 0])
    kc_pre_pn = pn_kc_syn.groupby('post')

    fig, axes = plt.subplots(nrows=len(keys), ncols=2, sharex='all')
    prev = set()
    for ii, k in enumerate(keys):
        seclist = lca_grp.get_group(k)['sec']
        cluster_pre = set()
        for jj, sec in enumerate(seclist):
            st = sec_st[sec]
            # right column: KC raster, one row per KC section
            axes[ii, 1].plot(st, jj * np.ones(len(st)), 'k,')
            current = set(kc_pre_pn.get_group(sec)['pre'])
            cluster_pre.update(current)
            # report overlap of presynaptic PNs with the previous KC
            print(len(current.intersection(prev)))
            prev = current
        # left column: raster of all PNs presynaptic to this cluster
        for jj, pre in enumerate(cluster_pre):
            axes[ii, 0].plot(pn_st[pre], jj * np.ones(len(pn_st[pre])), 'k,')
        print('#################')
    return fig, axes


def compare_data(leftfiles, rightfiles, leftheader, rightheader):
    """Compare two simulations side by side.

    `leftfiles` and `rightfiles` are paired lists of file names inside
    the module-level `datadir`; one 6x2 figure is created per pair, the
    left column showing the left file.  `leftheader` and `rightheader`
    are printed above the corresponding columns.

    Returns (list of figures, list of axes arrays, list of PSTH
    twin-axes lists).
    """
    figs = []
    axeslist = []
    psthaxlist = []
    for left, right in zip(leftfiles, rightfiles):
        fig, axes = plt.subplots(nrows=6, ncols=2, sharey='row')
        psth_axes = []
        
        for ii, fname in enumerate([left, right]):
            # NOTE(review): depends on the module-level `datadir` global
            fpath = os.path.join(datadir, fname)                         
            with h5.File(fpath, 'r') as fd:
                config = nda.load_config(fd)
                # 50 ms PSTH bins over the whole simulated time
                bins = np.arange(0, nda.get_simtime(fd)+0.5, 50.0)
                try:
                    pns = list(fd[nda.pn_st_path].keys())
                except KeyError:
                    print('Could not find PNs in', fname)
                    return figs, axeslist, psthaxlist
                # sort PN names numerically so pn_2 comes before pn_10
                pns = sorted(pns, key=lambda x: int(x.split('_')[-1]))
                pn_st, pn_y = nda.get_event_times(fd[nda.pn_st_path], pns)
                # Row 0: PN spike raster with population PSTH on a twin axis
                axes[0, ii].plot(np.concatenate(pn_st), np.concatenate(pn_y), ',')
                psth_ax = axes[0, ii].twinx()
                psth_axes.append(psth_ax)
                plot_population_psth(psth_ax, pn_st, config['pn']['number'], bins)
                # Row 1: KC raster ordered by spatial cluster (LCA)
                lines, kc_st, kc_y = plot_kc_spikes_by_cluster(axes[1, ii], fd, 'LCA')
                # Row 2: KC population PSTH
                plot_population_psth(axes[2, ii], kc_st, len(kc_st), bins, rate_sym='b^', cell_sym='rv')
                stiminfo = nda.get_stimtime(fd)
                stimend = stiminfo['onset'] + stiminfo['duration'] + stiminfo['offdur']
                # per-KC firing rate (Hz) during the stimulus window;
                # times are in ms, hence the 1e3 factor
                rates = [len(st[(st > stiminfo['onset']) & (st < stimend)]) * 1e3 
                         / (stimend - stiminfo['onset']) for st in kc_st]
                print(rates[:5])
                # Row 3: distribution of those firing rates
                axes[3, ii].hist(rates, bins=np.arange(21))
                axes[3, ii].set_xlabel('Firing rate')
                # Row 4: sampled KC Vm; Row 5: GGN Vm (LCA and basal)
                plot_kc_vm(axes[4, ii], fd, 'LCA', 5)
                plot_ggn_vm(axes[5, ii], fd,
                                   fd['/data/uniform/ggn_output/GGN_output_Vm'],
                                   'LCA', 5, color='r')
                plot_ggn_vm(axes[5, ii], fd,
                                   fd['/data/uniform/ggn_basal/GGN_basal_Vm'],
                                   'basal', 5, color='g')
                axes[5, ii].set_ylim((-53, -35))
                axes[0, ii].set_title('{}\nFAKE? {}'.format(fname, nda.load_config(fd)['kc']['fake_clusters']))
        # Share the time axis across all time-series rows (all but row 3,
        # which is a histogram over firing rate, not time)
        time_axes = [axes[ii, jj] for ii in [0, 1, 2, 4, 5] for jj in [0, 1]]
        for ax in time_axes[:-1]:
            ax.set_xticks([])
        # NOTE(review): get_shared_x_axes().join() is deprecated in
        # matplotlib >= 3.6 (removed later; use Axes.sharex) -- kept as-is.
        axes[0, 0].get_shared_x_axes().join(*time_axes)
        axes[2, 0].get_shared_x_axes().join(*axes[2, :])
        # psth_axes[0].get_shared_y_axes().join(*psth_axes)
        psth_axes[0].autoscale()
        # axes[-1, -1].autoscale()
        fig.text(0.1, 0.95, leftheader, ha='left', va='bottom')
        fig.text(0.6, 0.95, rightheader, ha='left', va='bottom')
        fig.set_size_inches(15, 10)
        # fig.tight_layout()
        figs.append(fig)
        axeslist.append(axes)
        psthaxlist.append(psth_axes)
    return figs, axeslist, psthaxlist


# Default location of simulation output used by compare_data, and a
# sample list of data files from interactive analysis sessions.
datadir = '/data/rays3/ggn/olfactory_network'
files = [
    'pn_kc_ggn_UTC2018_02_15__14_47_23-PID127486-JID61581731.h5',
    'pn_kc_ggn_UTC2018_02_15__14_47_24-PID24307-JID61581730.h5',
    'pn_kc_ggn_UTC2018_02_16__23_55_29-PID78196-JID61754666.h5',
    'pn_kc_ggn_UTC2018_02_16__23_55_31-PID51661-JID61754888.h5',
    'pn_kc_ggn_UTC2018_02_17__00_01_30-PID113209-JID61756334.h5',
    'pn_kc_ggn_UTC2018_02_17__00_01_30-PID113210-JID61756338.h5'
    ]




import argparse 

def make_parser():
    """Create the command line argument parser for this script."""
    parser = argparse.ArgumentParser(
        description='Plot data from NSDF simulation file for olfactory network.')
    parser.add_argument('-f', '--file', action='append',
                        help='Path of a data file to plot'
                        ' (repeat the option for multiple files)')
    parser.add_argument('-c', '--cluster', action='store_true',
                        help='Order KC spike rasters by spatial cluster')
    parser.add_argument('-b', '--bandwidth', type=float, default=50.0,
                        help='Kernel Density estimation bandwidth')
    parser.add_argument('-s', '--save', action='store_true',
                        help='Save figures in files')
    return parser
                        
        
if __name__ == '__main__':
    # Plot the activity overview and the KC spike-count histogram for
    # every file given on the command line; optionally save both as PNGs.
    args = make_parser().parse_args()
    for fname in args.file:
        fig_activity, _ = plot_spike_rasters(fname, kde_bw=args.bandwidth,
                                             by_cluster=args.cluster)
        fig_hist, _ = plot_kc_spike_count_hist(fname)
        if args.save:
            base = os.path.basename(fname)
            factivity = '{}_activity.png'.format(base)
            fig_activity.savefig(factivity)
            histfile = '{}_kc_spike_count_hist.png'.format(base)
            fig_hist.savefig(histfile)
            print('Saved figures in {} and {}'.format(factivity, histfile))
    plt.show()
        
# 
# pn_kc_ggn_plot.py ends here