# plot_active_kc_distr.py --- 
# Author: Subhasis  Ray
# Created: Mon Nov 12 15:32:18 2018 (-0500)
# Last-Updated: Fri Dec  7 11:32:07 2018 (-0500)
#           By: Subhasis  Ray
# Version: $Id$

# Code:

import sys
import os
import collections
import itertools as it
import numpy as np
import h5py as h5
import pandas as pd
import matplotlib.pyplot as plt
import yaml 
import network_data_analysis as nda

plt.rc('font', size=11)

os.chdir('/home/rays3/projects/ggn/analysis')

datadir = '/data/rays3/ggn/fixed_net'
trials_file = 'odor_trials_data.csv'


odor_trials_data = pd.read_csv(trials_file)

def find_jid_spiking_kcs(jid_list, datadir):
    ret = []
    for jid in jid_list:
        fname = nda.find_h5_file(str(jid), datadir)
        with h5.File(fname, 'r') as fd:
            ret.append(nda.get_spiking_kcs(fd))
    return ret

jid_spiking_kcs_fname = 'jid_spiking_kcs.yaml'
if not os.path.exists(jid_spiking_kcs_fname):
    spiking_kcs = find_jid_spiking_kcs(odor_trials_data.jid.values, datadir)
    jid_spiking_kcs = dict(zip(odor_trials_data.jid.values, spiking_kcs))
    with open('jid_spiking_kcs.yaml', 'w') as fd:
        yaml.dump(jid_spiking_kcs, fd)                
    print('Dumped spiking KCs for all trials in jid_spiking_kcs.yaml')
else:
    with open(jid_spiking_kcs_fname, 'r') as fd:
        jid_spiking_kcs = yaml.load(fd)
        
jid_kc_count = odor_trials_data.copy()
jid_kc_count['spiking_kcs'] = 0

common_kcs = odor_trials_data[['connection', 'template_jid', 'odor']].copy()
common_kcs.drop_duplicates(inplace=True)
common_kcs.loc[:, 'common_kcs'] = 0
common_kcs.loc[:, 'avg_kcs'] = 0

for odor, ogrp in odor_trials_data.groupby('odor'):
    for conn, cgrp in ogrp.groupby('connection'):
        print('Connection:', conn)
        for template, tgrp in cgrp.groupby('template_jid'):
            print('Template:', template)
            common = None
            counts = []
            for idx, jid in tgrp.jid.iteritems():
                counts.append(len(jid_spiking_kcs[jid]))
                jid_kc_count.loc[jid_kc_count.jid == jid, 'spiking_kcs']   \
                    = counts[-1]
                if common is None:
                    common = set(jid_spiking_kcs[jid])
                else:
                    common = common.intersection(jid_spiking_kcs[jid])
            pos = (common_kcs.odor == odor) & (common_kcs.connection == conn) & (common_kcs.template_jid == template)
            common_kcs.loc[pos, 'common_kcs'] = len(common)
            common_kcs.loc[pos, 'avg_kcs'] = np.mean(counts)


templates = odor_trials_data.template_jid.unique()
for t in templates:
    fname = nda.find_h5_file(str(t), datadir)
    print(fname)
    with h5.File(fname, 'r') as fd:
        print(yaml.dump(nda.load_config(fd), default_flow_style=''))

tmp_color_map = {
    10829002: '#e66101',       # iid
    10829014: '#b2abd2',      # clus
    9932209: '#fdb863',        # iid
    9932198: '#5e3c99'       # clus
}

tmp_label_map = {
    10829002: 'Diffuse 2',       # iid
    10829014: 'Clustered 2',      # clus
    9932209: 'Diffuse 1',        # iid
    9932198: 'Clustered 1'       # clus
}


########## Plot PN activity for each odor

ax = None
ii = 0
for odor, ogrp in odor_trials_data.groupby('odor'):
    fig = plt.figure()
    ax = fig.add_subplot(111, sharex=ax)
    print(odor)
    for odata in ogrp.itertuples():
        fname = nda.find_h5_file(str(odata.jid), datadir)
        with h5.File(fname, 'r') as fd:
            pns = list(fd[nda.pn_st_path].keys())
            pns = sorted(pns, key=lambda x: int(x.rpartition('_')[-1]))
            stlist = [fd[nda.pn_st_path][pn].value for pn in pns]
            ylist = [np.ones_like(st) * ii for ii, st in enumerate(stlist)]
            ax.plot(np.concatenate(stlist), np.concatenate(ylist), ',')
            # for sp in ['left', 'right', 'top', 'bottom']:
            #     ax.spines[sp].set_visible(False)
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)
            ax.set_ylim(0, len(pns))
            ax.set_xlim(200, 1500)
        ii += 1
        break
    fig.set_size_inches(0.8, 1)
    fig.set_frameon(False)
    fig.suptitle(odor.upper())
    fig.subplots_adjust(left=0.01, bottom=0.01, right=0.99, top=0.8)
    fig.savefig('Figure_4a_odor_{}.png'.format(odor), transparent=True)
plt.show()    


########## Plot the distribution of spiking KC counts
## Figure 4b
odor_idx = 0
fig, ax = plt.subplots()
odor_ranges = {}
for odor, ogrp in jid_kc_count.groupby('odor'):
    jj = 0    
    for conn, cgrp in ogrp.groupby('connection'):
        print('Connection:', conn)
        for template, tgrp in cgrp.groupby('template_jid'):
            # print('Template:', template)
            counts = tgrp.spiking_kcs.values
            # print(len(counts))
            c = tmp_color_map[template]
            ax.boxplot(counts, positions=[odor_idx + jj], notch=False,
                       boxprops=dict(color=c, facecolor=c),
                       capprops=dict(color=c),
                       whiskerprops=dict(color=c),
                       flierprops=dict(marker='o', color=c,
                                       markerfacecolor=c, markeredgecolor=c,
                                       linestyle='none', markersize=3),
                       medianprops=dict(color='white'),
                       patch_artist=True)
            # print(odor_idx + jj)
            jj += 1
    odor_ranges[odor] = (odor_idx, odor_idx + 4)
    odor_idx += 5

conn_template = jid_kc_count[['connection', 'template_jid']].copy().drop_duplicates()
lines = []
for conn, tgrp in conn_template.groupby('connection'):
    for template in sorted(tgrp.template_jid.values):
        print(template)
        lines.append(plt.Line2D([],[], linewidth=2, color=tmp_color_map[template], label=tmp_label_map[template]))
ax.legend(bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), handles=lines, loc=3, ncol=2, mode='expand')        
ax.set_xlim(-1, odor_idx)
texty = 1900.0
texth = 100
for odor, (oleft, oright) in odor_ranges.items():
    patch = plt.Rectangle((oleft-0.5, texty - texth/2.0), oright - oleft + 0.5, texth, edgecolor='black', facecolor='white')
    ax.add_patch(patch)
    bb = ax.text(0.5 * (oleft + oright), texty, odor.upper(), ha="center", va="center")
ax.set_ylabel('# of activated KCs')    
ax.spines['bottom'].set_visible(False)    
ax.spines['top'].set_visible(False)    
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
ax.xaxis.set_visible(False)
fig.set_frameon(False)
fig.set_size_inches(10/2.54, 7/2.54)
fig.tight_layout()
fig.savefig('active_kc_distr.svg')
plt.show()

############# Plot the distribution of common KCs vs average number of spiking KCs
## Figure 4c
odor_marker = {'a': 'p',
               'a1': 's',
               'a2': 'd',
               'b': 'o'}
fig, ax = plt.subplots()
for conn, cgrp in common_kcs.groupby('connection'):
    for template, tgrp in cgrp.groupby('template_jid'):
        for odor, ogrp in tgrp.groupby('odor'):
            assert(len(ogrp) == 1)
            print(ogrp)
            ax.plot(ogrp.avg_kcs.values, ogrp.common_kcs.values, color=tmp_color_map[template], marker=odor_marker[odor], mfc='none')
ax.spines['bottom'].set_visible(False)    
ax.spines['top'].set_visible(False)    
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
fig.set_frameon(False)
lines = []
for odor, marker in odor_marker.items():
    lines.append(plt.Line2D([],[], linestyle='', marker=marker, mec='black', mfc='none', label=odor.upper()))
ax.legend(bbox_to_anchor=(0.0, 1.02, 1.0, 0.102), handles=lines, loc=3, ncol=2, mode='expand')        
ax.set_xlabel('Average # of KCs')
ax.set_ylabel('# of common KCs')
fig.set_size_inches(7/2.54, 7/2.54)
fig.tight_layout()
fig.savefig('active_kc_common_vs_average.svg')
plt.show()

############ Plot distribution of common KCs between different odors
## Figure 4d

spiking_kcs_stats = pd.read_csv('shared_spiking_kc_stats.csv')
spiking_kcs_stats.loc[:, 'odor_distance'] = -1

def get_odist(oleft, oright):
    odist = -1
    if oleft == oright:
        odist = 0
    elif ((oleft == 'a') and (oright == 'a1')) or ((oleft == 'a1') and (oright == 'a2')):
        odist = 15
    elif (oleft == 'a') and (oright == 'a2'):
        odist = 30
    elif (oleft == 'a') and (oright == 'b'):
        odist = 100
    return odist

for (oleft, oright), ogrp in spiking_kcs_stats.groupby(['odor_left', 'odor_right']):
    spiking_kcs_stats.loc[(spiking_kcs_stats.odor_left == oleft) &
                          (spiking_kcs_stats.odor_right == oright),
                          'odor_distance'] = get_odist(oleft, oright)

good_stats = spiking_kcs_stats[spiking_kcs_stats.odor_distance >= 0].copy()

conn_color = {'iid': '#e66101', 'clus':  '#b2abd2'}

## Violin plot of distribution of shared spiking KCs by connection
fig, ax = plt.subplots()
pos = 0
print('******************')
for conn, cgrp in good_stats.groupby('connection'):
    print(conn, '~~~~~~~~~~~~~~~~')
    common_kcs = []
    for odist, distgrp in cgrp.groupby('odor_distance'):
        if odist == 30:
            print(odist, distgrp)
            
        common_kcs.append(distgrp.common.values)
    vp = ax.violinplot(common_kcs, np.arange(4) * 4 + pos, showmedians=True, points=20)
    plt.setp(vp['bodies'], color=conn_color[conn], alpha=0.5)
    pos += 1
plt.show()


################## Final plot Figure 4 d
## Violin plot of distribution of shared spiking KCs by connection and template
fig, ax = plt.subplots()
pos = 0
grpw = 6
for conn, cgrp in good_stats.groupby('connection'):
    for template, tgrp in cgrp.groupby('template'):
        common_kcs = []
        print('====')
        for odist, distgrp in tgrp.groupby('odor_distance'):
            print(odist)
            common_kcs.append(distgrp.common.values)
        print('---')
        c = tmp_color_map[template]
        vp = ax.violinplot(common_kcs, np.arange(4) * grpw + pos, showmedians=True, points=20)
        plt.setp(vp['bodies'], color=c)
        plt.setp(vp['cmins'], color=c)  #tmp_color_map[template])
        plt.setp(vp['cmedians'], color=c)  #tmp_color_map[template])
        plt.setp(vp['cmaxes'], color=c)  #tmp_color_map[template])
        plt.setp(vp['cbars'], color=c)  #tmp_color_map[template])
        pos += 1
ax.spines['bottom'].set_visible(False)    
ax.spines['top'].set_visible(False)    
ax.spines['right'].set_visible(False)
ax.spines['left'].set_visible(False)
fig.set_frameon(False)
ax.set_ylabel('# of common KCs')
odist = sorted(good_stats.odor_distance.unique())
texth = 30
texty = -10
for ii, od in enumerate(odist):
    patch = plt.Rectangle((-1 + ii * grpw, texty - texth/2.0), grpw-1, texth, edgecolor='black', facecolor='white')
    ax.add_patch(patch)
    bb = ax.text(ii * grpw + 2.0, texty, str(od), ha="center", va="center")
ax.set_xlabel('Shift in responsive PN population')    
# ax.xaxis.set_visible(False)
ax.set_xticks([])
fig.set_size_inches(10/2.54, 7/2.54)
fig.tight_layout()
fig.savefig('active_common_kc_distr.svg')
plt.show()
######## END Fig 4d

##################
## Recompute common KCs from scratch and plot for each type of connection
fig, ax = plt.subplots()
pos = 0
for conn, cgrp in odor_trials_data.groupby('connection'):
    shared_kc_count = collections.defaultdict(list)
    for template, tgrp in cgrp.groupby('template_jid'):
        print(template)
        for (oleft, oright) in it.combinations_with_replacement(tgrp.odor.unique(), 2):
            odist = get_odist(oleft, oright)
            print(oleft, oright, odist)
            if odist < 0:
                continue
            left = np.random.permutation(tgrp.loc[tgrp.odor == oleft, 'jid'].values)
            right = np.random.permutation(tgrp.loc[tgrp.odor == oright, 'jid'].values)
            if oleft == oright:            
                right = left[len(left) // 2:]
                left = left[:len(left) // 2]
            for jid_left, jid_right in zip(left, right):
                left_kcs = set(jid_spiking_kcs[jid_left])
                right_kcs = set(jid_spiking_kcs[jid_right])
                shared_kc_count[odist].append(len(left_kcs.intersection(right_kcs)))        
            
    sorted_dist_counts = sorted(list(shared_kc_count.items()), key=lambda x: x[0])
    print('==========')
    for k, v in sorted_dist_counts:
        print(k, len(v))
    print('------')
    sorted_counts = [x[1] for x in sorted_dist_counts]
    vp = ax.violinplot(sorted_counts, np.arange(4)* 4 + pos, showmedians=True, points=20)
    pos += 1
plt.show()

##################
## Recompute common KCs from scratch and plot for each type of connection
## and for each template
fig, ax = plt.subplots()
pos = 0
for conn, cgrp in odor_trials_data.groupby('connection'):
    for template, tgrp in cgrp.groupby('template_jid'):
        print(template)
        shared_kc_count = collections.defaultdict(list)
        for (oleft, oright) in it.combinations_with_replacement(tgrp.odor.unique(), 2):
            odist = get_odist(oleft, oright)
            print(oleft, oright, odist)
            if odist < 0:
                continue
            left = np.random.permutation(tgrp.loc[tgrp.odor == oleft, 'jid'].values)
            right = np.random.permutation(tgrp.loc[tgrp.odor == oright, 'jid'].values)
            if oleft == oright:            
                right = left[len(left) // 2:]
                left = left[:len(left) // 2]
            for jid_left, jid_right in zip(left, right):
                left_kcs = set(jid_spiking_kcs[jid_left])
                right_kcs = set(jid_spiking_kcs[jid_right])
                shared_kc_count[odist].append(len(left_kcs.intersection(right_kcs)))        
            
        sorted_dist_counts = sorted(list(shared_kc_count.items()), key=lambda x: x[0])
        print('==========')
        for k, v in sorted_dist_counts:
            print(k, len(v))
        print('------')
        sorted_counts = [x[1] for x in sorted_dist_counts]
        vp = ax.violinplot(sorted_counts, np.arange(4)* 4 + pos, showmedians=True, points=20)
        pos += 1
plt.show()


# 
# plot_active_kc_distr.py ends here