# -*- coding: utf-8 -*-
"""
Created on Tue Jun 25 13:55:10 2019

@author: Subhasis
"""


import sys
import os
import shutil
import numpy as np
import h5py as h5
import pandas as pd
import yaml
import argparse
import network_data_analysis as nda


"""Find the highest spiking KCs so we can disconnect them and repeat the
network model simulation."""

def make_parser():
    """Build the command line parser for this script.

    Options: ``--limit`` (int, default 5) and the string options
    ``--sdir``, ``--tdir`` and ``--jid`` (all default to None).
    """
    cli = argparse.ArgumentParser(
        description='Disconnect high firing KCs from data file to create a template network',
    )
    cli.add_argument(
        '--limit',
        type=int,
        default=5,
        help='Upper limit of allowed KC spike count',
    )
    # The remaining options are plain strings that differ only in flag
    # name and help text; register them in the original order.
    string_options = (
        ('--sdir', 'source directory'),
        ('--tdir', 'target directory'),
        ('--jid', 'JID of source dataset'),
    )
    for flag, help_text in string_options:
        cli.add_argument(flag, type=str, help=help_text)
    return cli


def _synapse_frame(syn):
    """Return the synapse dataset `syn` as a pandas DataFrame.

    A 2-D dataset is read through its first column (``syn[:, 0]``); any
    other shape is read whole (``syn[:]``).
    """
    if len(syn.shape) == 2:
        return pd.DataFrame(data=syn[:, 0])
    return pd.DataFrame(data=syn[:])


def remove_high_firing_kcs(jid, limit, sdir, tdir):
    """Remove KCs firing more than `limit` from dataset of `jid`. `sdir`
    points to directory containing the data file, `tdir` is where the
    output file will be written. If source file is called `x.h5`,
    output file will be `x_kc{limit}.h5`

    The "removal" is done by setting `gmax` of every PN->KC synapse whose
    `post` field matches an over-limit KC to 0 in the output file.

    Returns the number of synapses targeted for zeroing (including any
    that already had zero conductance). Returns 0 when no KC exceeds
    `limit` or when the output file already exists.

    Raises AssertionError if the number of zero-conductance synapses in
    the written file does not match what the update should have produced.

    """
    fpath = nda.find_h5_file(jid, sdir)
    with h5.File(fpath, 'r') as fd:
        kc_spike_count = [
            (spikes.attrs['source'], len(spikes))
            for spikes in fd[nda.kc_st_path].values()
        ]
        try:
            forig_path = fd.attrs['original']
            print(jid, ': original template:', forig_path)
        except KeyError:
            # No recorded template: the data file itself is the template.
            forig_path = fpath
        syn_orig = _synapse_frame(fd[nda.pn_kc_syn_path])
    kc_spikes = pd.DataFrame(kc_spike_count, columns=['kc', 'spikes'])
    print('# kcs > 5 spikes:', len(kc_spikes[kc_spikes['spikes'] > 5]))
    print('# kcs > 10 spikes:', len(kc_spikes[kc_spikes['spikes'] > 10]))
    print('# spiking kcs:', len(kc_spikes[kc_spikes['spikes'] > 0]))
    over = kc_spikes[kc_spikes['spikes'] > limit]
    print('{} kcs spiked more than {} spikes'.format(len(over), limit))
    if len(over) == 0:
        print('No KCs meet criterion. Nothing to do')
        return 0
    over_syn = pd.merge(syn_orig, over, left_on='post', right_on='kc')
    print('Synapses to > limit kcs', len(over_syn))
    print('Synapse to these kcs set to 0?', len(np.flatnonzero(over_syn['gmax'].values == 0)))
    # This is tricky - a simulation based on a template used to
    # generate a datafile with external reference to the synapse
    # datasets in the template. So attempt to generate another
    # template by updating synapses in the produced datafile
    # referred back to the original template.
    fname = os.path.basename(fpath).rpartition('.')[0]
    out_fname = '{}_kc{}.h5'.format(fname, limit)
    outfile = os.path.join(tdir, out_fname)
    if os.path.exists(outfile):
        print(f'File already exists: {outfile}')
        return 0
    # 2019-07-15 - copying forig_path resulted in losing KC spike
    # information from new simulation (fpath). Now that I save the
    # synapses directly, instead of an external ref, there is no need
    # to copy forig.
    # NOTE(review): despite the note above, forig_path is still what gets
    # copied below - confirm which source file is intended.
    print('Copying data from {} as {} to update using {} KC spiking'.format(forig_path, outfile, fpath))    
    shutil.copyfile(forig_path, outfile)
    print('Disabling PN synapses to KCs firing > {} spikes'.format(limit))
    orig_zero_mask = (syn_orig['gmax'] == 0.0).values
    orig_zero_count = int(np.count_nonzero(orig_zero_mask))
    changed_syn_count = 0
    already_zero_count = 0
    with h5.File(outfile, 'a') as ofd:
        syndf = ofd[nda.pn_kc_syn_path]
        # Mirror the shape handling used when reading the synapses so a
        # 1-D dataset is updated correctly as well.
        two_dim = len(syndf.shape) == 2
        # Set the conductances to each KC spiking more than `limit` spikes to 0
        for row in over.itertuples():
            # Row positions are taken from the data file (fpath) and applied
            # to the copy of forig_path; this relies on both files listing
            # the synapses in the same order.
            idx = np.where(syn_orig['post'] == row.kc)[0]
            if len(idx) == 0:
                # Nothing to write for a KC without PN synapses.
                continue
            changed_syn_count += len(idx)
            already_zero_count += int(np.count_nonzero(orig_zero_mask[idx]))
            if two_dim:
                syndf[idx, 0, 'gmax'] = 0.0
            else:
                syndf[idx, 'gmax'] = 0.0
        print('Modified: synapses set to 0 conductance:', changed_syn_count)
    print('Original: synapses with 0 conductance:', orig_zero_count)

    with h5.File(outfile, 'r') as o2:
        syn_new = _synapse_frame(o2[nda.pn_kc_syn_path])
        new_zero_count = len(syn_new[syn_new['gmax'] == 0.0])
        print('# synapses in updated file', len(syn_new))
        print('# shape of synapse data in updated file', syn_new.shape)
        print('# synapses with 0 conductance', new_zero_count)
        # Synapses that were already zero do not add to the zero count, so
        # only the previously non-zero targets are expected as the increase.
        expected_increase = changed_syn_count - already_zero_count
        if new_zero_count - orig_zero_count != expected_increase:
            # Raised explicitly (not via `assert`) so the check also runs
            # under `python -O`.
            raise AssertionError(
                'Zero-conductance synapse count increased by {} but expected {}'.format(
                    new_zero_count - orig_zero_count, expected_increase))
        print('Finished checking updated file')
    return changed_syn_count


if __name__ == '__main__':
    # Script entry point: parse the command line, run the update and
    # exit with status 1 when no synapse was changed.
    cli_args = make_parser().parse_args()
    n_changed = remove_high_firing_kcs(
        args.jid if False else cli_args.jid,
        cli_args.limit,
        cli_args.sdir,
        cli_args.tdir,
    ) if True else 0
    if n_changed == 0:
        sys.exit(1)