# fig_ggn_dep_gbar_corr_kc_spikes.py --- 
# Author: Subhasis  Ray
# Created: Thu Dec 27 17:17:29 2018 (-0500)
# Last-Updated: Fri Jan 11 17:23:30 2019 (-0500)
#           By: Subhasis  Ray
# Version: $Id$

# Code:

"""Attempt to correlate GGN depolarization and GGN-KC gbar to KC spike stats"""


from __future__ import print_function
import sys
if sys.platform == 'win32':
    sys.path.append('D:/subhasis_ggn/model/analysis')
else:
    sys.path += ['/home/rays3/projects/ggn/analysis',
                 '/home/rays3/projects/ggn/morphutils',
                 '/home/rays3/projects/ggn/nrn']
import os
import h5py as h5
import numpy as np
import random
from matplotlib import pyplot as plt
import yaml
import pint
from collections import defaultdict
import pandas as pd
import network_data_analysis as nda
from matplotlib.backends.backend_pdf import PdfPages
import pn_kc_ggn_plot_mpl as myplot
import neurograph as ng
import timeit
from sklearn import cluster, preprocessing
plt.style.use('classic')

_ur = pint.UnitRegistry()
Q_ = _ur.Quantity
# Directory holding the simulation HDF5 outputs. It is picked per platform in
# the same way as the sys.path setup at the top of the file, so the script
# runs unmodified on both machines.
if sys.platform == 'win32':
    datadir = 'D:\\biowulf_stage\\olfactory_network\\'
else:
    datadir = '/data/rays3/ggn/olfactory_network/'


def get_dv_kc_data(fname):
    """Create a dataset where each row contains a KC, its presynaptic GGN
    section, maximum conductance from GGN, peak depolarization of that
    section and the number of spikes generated by that KC.

    Parameters
    ----------
    fname : str
        Path of the simulation HDF5 file.

    Returns
    -------
    pandas.DataFrame
        Columns 'gbar', 'kc', 'spike_count', 'sec', 'vm'. There is one row
        per GGN->KC synapse whose KC has a spike-time dataset and whose
        presynaptic section appears in the recorded GGN Vm (both joins are
        inner joins, so unmatched synapses are dropped).
    """
    with h5.File(fname, 'r') as fd:
        # Diagnostics: number of spiking KCs and the simulation config.
        # (Previously printed the module-global `jid`, which only exists
        # while the script loop is running.)
        print(fname, len(nda.get_spiking_kcs(fd)))
        print(yaml.dump(nda.load_config(fd), default_style=''))
        # NOTE(review): the synapse dataset appears to be 2-D; column 0 is
        # taken as in the original -- confirm the layout.
        ggn_kc_gbar = pd.DataFrame(
            fd[nda.ggn_kc_syn_path]['pre', 'post', 'gbar'][:, 0])
        ggn_vm = fd['/data/uniform/ggn_output/GGN_output_Vm']
        # `Dataset.value` was removed in h5py 3.0; `[()]` reads the whole
        # dataset. Rows are indexed by dimension 0 (the section names below);
        # the max over axis 1 is presumably over time.
        ggn_peak_vm = ggn_vm[()].max(axis=1)
        ggn_sec = ggn_vm.dims[0]['source'][()]
        ggn_sec_vm = pd.DataFrame(data={'sec': ggn_sec, 'vm': ggn_peak_vm})
        # Spike count per KC = length of its spike-time dataset.
        kc_sc = {kc_st.attrs['source']: len(kc_st)
                 for kc_st in fd[nda.kc_st_path].values()}
    df_kc_sc = pd.DataFrame(data={'kc': list(kc_sc.keys()),
                                  'spike_count': list(kc_sc.values())})
    combined = pd.merge(ggn_kc_gbar, df_kc_sc, left_on='post', right_on='kc')
    combined = pd.merge(combined, ggn_sec_vm, left_on='pre', right_on='sec')
    combined.drop(columns=['pre', 'post'], inplace=True)
    return combined


def hist_2d_spiking_kc_data(data, vm_bins=5, gbar_bins=5):
    """Create 2D histograms of KC spiking, binned by GGN Vm and GGN->KC gbar.

    Parameters
    ----------
    data : pandas.DataFrame
        Must have columns 'vm', 'gbar' and 'spike_count' (e.g. the output
        of get_dv_kc_data).
    vm_bins, gbar_bins : int or sequence
        Passed as `bins` to pandas.cut for the 'vm' and 'gbar' columns.

    Returns
    -------
    kc_frac : defaultdict
        kc_frac[vm_interval][gbar_interval] is the fraction of KCs in that
        cell with at least one spike.
    spike_count_frac : defaultdict
        spike_count_frac[vm_interval][gbar_interval] is the mean spike
        count per KC in that cell.

    Cells containing no KCs are 0.0 in both results.
    """
    vm_cut = pd.cut(data.vm, bins=vm_bins)
    gbar_cut = pd.cut(data.gbar, bins=gbar_bins)
    kc_frac = defaultdict(dict)
    spike_count_frac = defaultdict(dict)
    # Pre-fill every (vm, gbar) cell so that empty cells report 0.0.
    for vm_int in vm_cut.cat.categories:
        for gbar_int in gbar_cut.cat.categories:
            kc_frac[vm_int][gbar_int] = 0.0
            spike_count_frac[vm_int][gbar_int] = 0.0
    for (vm_int, gbar_int), grp in data.groupby([vm_cut, gbar_cut]):
        if grp.empty:
            # Grouping by categoricals can yield unobserved (empty) cells;
            # keep their pre-filled 0.0 instead of dividing by zero.
            continue
        n_cells = len(grp)
        # Series.nonzero() was removed in pandas 1.0.
        n_spiking = np.count_nonzero(grp.spike_count.values)
        kc_frac[vm_int][gbar_int] = n_spiking * 1.0 / n_cells
        spike_count_frac[vm_int][gbar_int] = (
            grp.spike_count.sum() * 1.0 / n_cells)
    return kc_frac, spike_count_frac


def plot_hist(ddict):
    """Draw a nested-dict 2D histogram as a pseudocolor mesh.

    Parameters
    ----------
    ddict : dict
        ddict[a][b] -> value, where the keys a and b are interval objects
        with `.left` / `.right` (e.g. from hist_2d_spiking_kc_data). Every
        inner dict must contain the keys of the first inner dict.

    Returns
    -------
    (fig, ax)
        The newly created matplotlib figure and axes.

    Raises
    ------
    ValueError
        If `ddict` or its first inner dict is empty (no bin edges exist).
    """
    bins_a = sorted(ddict)
    if not bins_a:
        raise ValueError('plot_hist: histogram dict is empty')
    bins_b = sorted(ddict[bins_a[0]])
    if not bins_b:
        raise ValueError('plot_hist: inner histogram dict is empty')
    z = np.array([[ddict[k_a][k_b] for k_b in bins_b] for k_a in bins_a],
                 dtype=float)
    # Bin edges: left edge of every interval plus the right edge of the last.
    x = [ba.left for ba in bins_a] + [bins_a[-1].right]
    y = [bb.left for bb in bins_b] + [bins_b[-1].right]
    X, Y = np.meshgrid(x, y)
    fig, ax = plt.subplots()
    # meshgrid gives shape (len(y), len(x)), so z is transposed to match.
    ax.pcolormesh(X, Y, z.T)
    return fig, ax


# Simulation job IDs analysed by this script (see the gmax / spiking-KC
# summary table below).
jids = ['16377231', '16377233', '16377234', '16377237']

# for jid in jids:
#     fname = nda.find_h5_file(jid, datadir)
    # with h5.File(fname, 'r') as fd:
    #     print('jid: {}, spiking kcs: {}'.format(jid, len(nda.get_spiking_kcs(fd))))        
    #     print(yaml.dump(nda.load_config(fd), default_style='', default_flow_style=''))
#     myplot.plot_spike_rasters(fname)
#     myplot.plot_kc_spike_count_hist(fname)
# plt.show()

# |      jid | ggn-kc gmax | pn-kc gmax | spiking kcs |
# | 16377231 |         0.7 |        3.7 |       10525 |
# | 16377233 |         0.5 |        3.7 |       17708 |
# | 16377234 |         0.9 |        4.0 |       10536 |
# | 16377237 |         0.9 |        4.5 |       37060 |

## Taking 16377231 as a sample
# jid = '16377231'   # Had about 7031 KCs spiking
# fname =  nda.find_h5_file(jid, datadir)
# data =  get_dv_kc_data(fname)

# kc_frac, spike_frac = hist_2d_spiking_kc_data(data)
# for ka, va in kc_frac.items():
#     for kb,  vb in va.items():
#         print(ka, kb, vb)
# plot_hist(kc_frac)

# * In all these runs, KCs turned out to spike only at the lowest GGN gbar
# for jid in jids:
#     fname = nda.find_h5_file(jid, datadir)
#     data = get_dv_kc_data(fname)
#     kc_frac, spike_frac = hist_2d_spiking_kc_data(data,  10, 10)
#     fig, ax = plot_hist(kc_frac)
#     fig2, ax2 = plot_hist(spike_frac)
#     # plt.show()

# plt.show()

# Added to the centre of each GGN peak-Vm bin to express it as depolarization
# on the x axis. NOTE(review): presumably the GGN baseline Vm is -51 mV --
# confirm against the model configuration.
GGN_VM_OFFSET = 51.0


def _spike_info_by_vm_bin(data, jid, nbins=10):
    """Summarise KC spiking in `nbins` equal-width bins of GGN peak Vm.

    `data` is a DataFrame with 'vm' and 'spike_count' columns (as returned
    by get_dv_kc_data). Returns one dict per non-empty bin with keys 'jid',
    'vm_interval', 'frac_spiking' and 'spike_count_per_cell'. Both
    statistics are divided by the bin width (`Interval.length`).
    """
    vm_bins = pd.cut(data.vm, bins=nbins)
    spike_info = []
    for vm_int, grp in data.groupby(vm_bins):
        if grp.empty:
            # Unobserved categorical bins can come through as empty groups;
            # dividing by their size would raise ZeroDivisionError.
            continue
        n_cells = len(grp)
        # Series.nonzero() was removed in pandas 1.0.
        n_spiking = np.count_nonzero(grp.spike_count.values)
        frac_spiking = n_spiking * 1.0 / n_cells
        spike_count_per_cell = grp.spike_count.sum() * 1.0 / n_cells
        spike_info.append({
            'jid': jid,
            'vm_interval': vm_int,
            'frac_spiking': frac_spiking / vm_int.length,
            'spike_count_per_cell': spike_count_per_cell / vm_int.length})
    return spike_info


def _ggn_dvm(spike_df):
    """Return x coordinates (bin centre + GGN_VM_OFFSET) for `spike_df`."""
    return [vint.mid + GGN_VM_OFFSET for vint in spike_df.vm_interval]


def _style_axes(ax, xlabel):
    """Apply the shared labels, ticks and spine-free look to both panels."""
    ax[1].set_xlabel(xlabel)
    ax[0].set_ylabel('Fraction of KCs spiking')
    ax[0].set_yticks([0.15, 0.25, 0.35, 0.45, 0.55])
    ax[1].set_ylabel('# of spikes per KC')
    ax[1].set_yticks([0.2, 1.0, 2.0])
    for axis in ax:
        for spine in axis.spines.values():
            spine.set_visible(False)
        axis.tick_params(top=False, right=False)


# Figure 1: all jobs overlaid, one marker style per panel.
spike_info_list = []
fig, ax = plt.subplots(nrows=2, sharex='all')
for jid in jids:
    fname = nda.find_h5_file(jid, datadir)
    data = get_dv_kc_data(fname)
    spike_info = pd.DataFrame(_spike_info_by_vm_bin(data, jid))
    spike_info_list.append(spike_info)
    dvm = _ggn_dvm(spike_info)
    ax[0].plot(dvm, spike_info.frac_spiking, 'o-',
               label='spiking KC fraction')
    ax[1].plot(dvm, spike_info.spike_count_per_cell, '^-',
               label='spikes per KC')
_style_axes(ax, 'GGN dVm')
fig.frameon = False
# Saved under its own name: both figures previously wrote the same file, so
# this one was silently overwritten by figure 2.
fig.savefig('ggn_peak_kc_spiking_relation_overlay.svg', transparent=True)
plt.show()

# Figure 2: the same data, one labelled line per job.
fig, ax = plt.subplots(nrows=2, sharex='all')
for spike_info in spike_info_list:
    label = spike_info.jid.iloc[0]
    dvm = _ggn_dvm(spike_info)
    ax[0].plot(dvm, spike_info.frac_spiking, 'o-', label=label)
    ax[1].plot(dvm, spike_info.spike_count_per_cell, 'o-', label=label)
ax[1].legend()
_style_axes(ax, 'GGN depolarization (mV)')
fig.frameon = False
fig.savefig('ggn_peak_kc_spiking_relation.svg', transparent=True)
plt.show()
          

                     
# 
# fig_ggn_dep_gbar_corr_kc_spikes.py ends here
