# -*- coding: utf-8 -*-
"""
Created on Tue Jun 25 13:55:10 2019

@author: Subhasis
"""


import sys
import os
import shutil
import numpy as np
import h5py as h5
import pandas as pd
import yaml
import network_data_analysis as nda
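# Note: `network_data_analysis` (nda) is assumed to be the project's local
# analysis module; this script only relies on its HDF5 path constants
# `kc_st_path` and `pn_kc_syn_path` and the helper `find_h5_file(jid, datadir)`.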


"""Find the highest spiking KCs so we can disconnect them and repeat the 
network model simulation"""
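
# Assumed HDF5 layout (inferred from the accesses below, not from a schema):
#   fd[nda.kc_st_path]     : group of 1D spike-time datasets, one per KC,
#                            with the KC identity in the 'source' attribute
#   fd[nda.pn_kc_syn_path] : 2D compound dataset of PN->KC synapses; column 0
#                            holds records with fields 'pre', 'post', 'gmax'
#   fd.attrs['original']   : present only in files generated from a template;
#                            path of the original simulation output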

# jids = [
#     22072442,
#     22087964,
#     22087965,
#     22087966,
#     22087967,
#     22087970,
#     22087971,
#     22087972,
#     22087973,
#     ]
# datadir = 'Y:/Subhasis/ggn_model_data/olfactory_network'
datadir = '/data/rays3/ggn/fixed_net'
templatedir = '/data/rays3/ggn/fixed_net_templates'
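
# The job IDs below identify simulation runs whose output files live under
# `datadir`; the modified template copies created by this script are written
# to `templatedir`.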

jids = [
    30184869, 
    30184873,
    30184874,
    30184876,
    30184878,
    30184880,
    30184882,
    30184884,
    30184886,
    30184888,
]


for jid in jids:
    fpath = nda.find_h5_file(jid, datadir)
    forig_path = fpath
    kc_spike_count = []
    with h5.File(fpath, 'r') as fd:
        # Count spikes per KC; the KC identity is in the 'source' attribute.
        for kc, spikes in fd[nda.kc_st_path].items():
            kc_spike_count.append((spikes.attrs['source'], len(spikes)))
        try:
            # Template-derived files record the original simulation output path.
            forig_path = fd.attrs['original']
            if isinstance(forig_path, bytes):
                forig_path = forig_path.decode('utf-8')  # h5py may return bytes
        except KeyError:
            forig_path = fpath
        # PN->KC synapse table: column 0 carries fields 'pre', 'post', 'gmax'.
        syn_orig = pd.DataFrame(data=fd[nda.pn_kc_syn_path][:, 0])
        
    kc_spikes = pd.DataFrame(kc_spike_count, columns=['kc', 'spikes'])
    # kc_spikes['spikes'].plot(kind='hist')
    print('# kcs > 5 spikes:', len(kc_spikes[kc_spikes['spikes'] > 5]))
    print('# kcs > 10 spikes:', len(kc_spikes[kc_spikes['spikes'] > 10]))
    print('# spiking kcs:', len(kc_spikes[kc_spikes['spikes'] > 0]))
    # Row indices of synapses that already have 0 conductance in the original
    orig_0 = set(np.where(syn_orig['gmax'] == 0)[0])
    limits = [5]
    for lim in limits:
        over = kc_spikes[kc_spikes['spikes'] > lim]['kc']
        print(f'{len(over)} KCs spiked more than {lim} times')
        over_syn = pd.merge(syn_orig, over, left_on='post', right_on='kc')
        print(f'Synapses to KCs with > {lim} spikes:', len(over_syn))
        print('Of these, synapses already set to 0:', len(np.flatnonzero(over_syn['gmax'].values == 0)))
        # Name the template after the original network file
        orig_fname = os.path.basename(forig_path).rpartition('.')[0]
        out_fname = f'{orig_fname}_kc{lim}.h5'
        outfile = os.path.join(templatedir, out_fname)
        shutil.copyfile(forig_path, outfile)
        print(f'Copied data from {forig_path} to {outfile}; '
              f'it will be updated using KC spiking in {fpath}')
        with h5.File(outfile, 'a') as ofd:
            syndf = ofd[nda.pn_kc_syn_path]
            # Set the conductance of every synapse onto a KC that spiked
            # more than `lim` times to 0
            changed_syn_count = 0
            for ii, kc in over.items():
                idx = np.where(syn_orig['post'] == kc)[0]
                print('Of these, already 0 in original:', len(set(idx).intersection(orig_0)))
                changed_syn_count += len(idx)
                syndf[idx, 0, 'gmax'] = 0.0
            print('Modified: synapses set to 0 conductance:', changed_syn_count)
        print('Original: synapses with 0 conductance:', len(syn_orig[syn_orig['gmax'] == 0.0]))
        
        # Sanity check: reopen the template and verify the conductance updates
        with h5.File(outfile, 'r') as o2:
            syn_new = pd.DataFrame(data=o2[nda.pn_kc_syn_path][:, 0])
            print('# synapses in updated file:', len(syn_new))
            print('# shape of synapse data in updated file:', syn_new.shape)
            print('# synapses with 0 conductance:', len(syn_new[syn_new['gmax'] == 0.0]))
            # This assumes none of the modified synapses were already at 0
            # conductance in the original file
            assert (len(syn_new[syn_new['gmax'] == 0.0])
                    - len(syn_orig[syn_orig['gmax'] == 0.0])) == changed_syn_count


# fd.close()

# # Test with some concrete samples
# forig_path = '/data/rays3/ggn/olfactory_network/mb_net_UTC2019_03_09__18_28_19-PID22056-JID22087969.h5'
# f_kc5_template_path = '/data/rays3/ggn/fixed_net_templates/mb_net_UTC2019_03_09__18_28_19-PID22056-JID22087969_kc5.h5'
# f_kc5_path = '/data/rays3/ggn/fixed_net/fixed_net_UTC2019_06_26__20_46_06-PID43476-JID30184880.h5'

# forig = h5.File(forig_path, 'r')
# f_tmp = h5.File(f_kc5_template_path, 'r')
# f_kc5 = h5.File(f_kc5_path, 'r')
# print(f_kc5.attrs['original'])

# orig_spike_count = pd.DataFrame([(st.attrs['source'], len(st)) for st in forig[nda.kc_st_path].values()], columns=['kc', 'spikes'])
# print(orig_spike_count.head())

# tmp_spike_count = pd.DataFrame([(st.attrs['source'], len(st)) for st in f_tmp[nda.kc_st_path].values()], columns=['kc', 'spikes'])
# print(tmp_spike_count.head())

# kc5_spike_count = pd.DataFrame([(st.attrs['source'], len(st)) for st in f_kc5[nda.kc_st_path].values()], columns=['kc', 'spikes'])
# print(kc5_spike_count.head())

# syn_orig = pd.DataFrame(data=forig[nda.pn_kc_syn_path][:, 0])
# print(syn_orig.head())

# syn_tmp = pd.DataFrame(data=f_tmp[nda.pn_kc_syn_path][:, 0])
# print(syn_tmp.head())

# syn_kc5 = pd.DataFrame(data=f_kc5[nda.pn_kc_syn_path][:, 0])
# print(syn_kc5.head())

# high_orig = orig_spike_count[orig_spike_count.spikes > 5]
# high_kc5 = kc5_spike_count[kc5_spike_count.spikes > 5]

# print('Original > 5:', len(high_orig))
# print('KC5 > 5:', len(high_kc5))

# syn_high_orig = pd.merge(high_orig, syn_orig, left_on='kc', right_on='post')
# print('Synapses to high KCs orig', len(syn_high_orig))

# syn_high_kc5 = pd.merge(high_kc5, syn_kc5, left_on='kc', right_on='post')
# print('Synapses to high KCs KC5', len(syn_high_kc5))
# print(len(pd.merge(syn_kc5, high_kc5, right_on='kc', left_on='post')))

# tmp_syn_0 = syn_tmp[syn_tmp.gmax == 0]
# print('0 syn in template', len(tmp_syn_0))

# kc5_syn_0 = syn_kc5[syn_kc5.gmax == 0]
# print('0 syn in kc5', len(kc5_syn_0))

# intersec = pd.merge(kc5_syn_0, syn_high_kc5, left_on=['pre', 'post'], right_on=['pre', 'post'])
# print('Common synapses', len(intersec))

# for fd in [forig, f_tmp, f_kc5]:
#     fd.close()