import numpy as np
import glob
import mnerp_net_output as mno #must be imported first, because net_anal_utils imports from it
import net_anal_utils as nau #must be imported prior to na and isc - which import stuff from it
import net_anal_class as na
import input_spike_class as isc
import plot_utils as pu

####################################
# Which simulation output files to analyze.
# Several alternative configurations are kept below as comments; the active one is
# "example 2A".  To switch, comment out the active section and uncomment another one.

# ---- defaults shared by all examples (an example may override them) ----
file_root='ep'
confile_root=''  # root of connectivity file names to print per condition; '' = do not print
infiles=[]       # explicit input-spike files (without '.npz'); [] = search for tt*.npy files instead

# ---- example 1: effect of regular stimulation and STP on spontaneous firing ----
# filedir='/home/avrama/moose/moose_nerp/moose_nerp/ep_net/output/'
# param1=['GABA'] # condition - each of these occurs with each of the param2 sets.
# # param2: each entry occurs with each param1 entry.  The dictionary keys are the words
# # used to construct the file name and the figure legends.
# param2=[{'PSP_':'GPe','_freq':20,'_plas':11},{'PSP_':'GPe','_freq':20,'_plas':10},{'PSP_':'str','_freq':20,'_plas':11},{'PSP_':'str','_freq':20,'_plas':10}]
# suffix='_tg_GPe_lognorm*_ts_SPN_lognorm_ts_STN_lognorm'
# out_fname='stim20Hz'

# ---- example 2A (ACTIVE): 3 conditions, no "regular" stimulation, oscillatory input ----
filedir='/home/avrama/moose/mn_output/ep_net/'
param1=['GABAosc','POST-HFSosc','POST-NoDaosc']
param2=[{'PSP_':'non','_freq':0,'_plas':1}]
suffix=''  # oscillatory input
out_fname='osc'
infiles=['ep_net/STN_InhomPoisson_freq18_osc0.6',
         'ep_net/str_InhomPoisson_freq4.0_osc0.2',
         'ep_net/GPe_InhomPoisson_freq29.3_osc2.0']

# ---- example 2B: 3 conditions, no "regular" stimulation, log-normally distributed inputs ----
# filedir='/home/avrama/moose/moose_nerp/moose_nerp/ep_net/output/'
# param1=['GABA','POST-HFS','POST-NoDa'] # spontaneous inputs
# suffix='_tg_GPe_lognorm_freq29_ts_SPN_lognorm_ts_STN_lognorm' # spontaneous inputs have log normal distribution
# param2=[{'PSP_':'non','_freq':0,'_plas':1}]
# out_fname='lognorm'

# ---- example 3: bg network ----
# filedir='/home/avrama/moose/moose_nerp/moose_nerp/bg_net/'
# file_root='output/ctx7_Ctx_ramp'#'ctx7_Ctx_osc'
# param1=['dur0.3*','dur0.5*']#['10.0_STN_lognorm28.0-fb']
# param2=[{'_npas':3,'_lhx':0},{'_npas':3,'_lhx':5}]#,{'_npas':0,'_lhx':0},{'_npas':0,'_lhx':5}]
# suffix='-500um'
# # For input rasters the program searches for tt*.npy files whose names otherwise match
# # the output files; alternatively, list input files here for a simple raster:
# infiles=[filedir+'STN500_pulse_freq1.0_73dur0.05',
#          filedir+'Ctx10000_ramp_freq5.0_50dur0.3',
#          filedir+'Ctx10000_ramp_freq5.0_30dur0.5'] #input spikes
# confile_root=''#'ctx7_connectCtx_ramp' #set to '' to avoid printing confile

# ---- other analysis parameters ----
neur='ep'          # neuron type used for stimulus-response measures, FFT, cross-correlation and input firing rate
numbins=20         # number of histogram bins (when bins do not overlap)
bin_overlap=0.5    # fraction of the bin size to advance for the next bin; 1.0 = no overlap
# vmtab_index selects the Vm table(s) to use.  A single-neuron sim may save several tables
# per neuron, a network sim one table per neuron.  FFT and STA should use only the soma
# table: give 0 (or the soma's index when tables are sorted alphabetically by compartment
# name, as before May 21), or None to use all tables (e.g. for a network sim).
vmtab_index=0
maxfreq=100        # PSD from the FFT is only kept up to this frequency
# spike-triggered-average window (sec); set sta_start equal to sta_end to skip the STA
sta_start=-40e-3
sta_end=0
binsize_for_prespike_sta=(sta_end-sta_start)/numbins
entropy_bin_size=0.01  # Lavian J Neurosci used 10 ms bins
#numBinsForEnt=[int((1./freq)/entropy_bin_size)]#[2, 5, 10, 25] #set to [] to avoid calculating

# ---- output control ----
raster_plots=0       # rasters of both input and output spiking
individual_plots=0   # plots for each single condition (its set of trials)
group_plots=1        # plots comparing parameter sets
savetxt=True
calc_input_ff=False  # input firing frequency; slow, so it can be skipped
binsize_factor=0     # cross-correlation bin = dt*binsize_factor (10 was used before); 0 skips cross_corr

############# loop over sets of files ##################
def _accumulate_into_globals(results,names,accum_key):
    """Merge one parameter set's summary measures into module-level dicts.

    For each (result, name) pair, the module-level dictionary called ``name``
    is created the first time it is needed and then updated by
    ``nau.accumulate_over_params`` under ``accum_key``.  The group-plot and
    text-output sections further down read these dictionaries by name
    (e.g. ``spikerate_mean``, ``cross_corr_corrected``).

    results, names: parallel sequences of data and target dictionary names
    accum_key: key identifying the current parameter combination
    """
    module_vars=globals()
    for indata,dictname in zip(results,names):
        module_vars[dictname]=nau.accumulate_over_params(indata,module_vars.setdefault(dictname,{}),accum_key)

for cond in param1:
    for params in param2:
        # file-name/legend fragment built from the param2 dict, e.g. 'PSP_non_freq0_plas1'
        key=''.join([str(k)+str(v) for k,v in params.items()])
        #### Determine dictionary keys for accumulating results of multiple parameter combos
        # accum_key names what varies across combinations; ftitle collects what is constant.
        # If neither param1 nor param2 has more than one entry, group plots compare nothing.
        ftitle=''
        if len(param1)>1:
            accum_key=cond
        else:
            accum_key=''
            ftitle=cond
        if len(param2)>1:
            accum_key=accum_key+key
        else:
            ftitle=ftitle+key  # append so cond is kept when both lists have a single entry
        pattern=filedir+file_root+cond+key+suffix
        files=sorted(glob.glob(pattern+'*.npz'))
        if not files:
            # nothing to analyze for this combination (data[0] below would raise IndexError)
            print('************ no files found using',pattern)
            continue
        num_trials=len(files)
        ####create network_output object for each file
        data=[]
        for f in files:
            data.append(mno.network_output(f,numbins,bin_overlap,vmtab_index))
            data[-1].spikerate_func()
        #### Now create object that contains data from set of files
        #If neurtypes and files specified before na.network_output,
        #   could create alldata, and add to alldata without appending to data above
        alldata=na.network_fileset(files,data[0].neurtypes)
        for dat in data:
            alldata.st_arrays(dat)
            if len(dat.vmdat) and sta_end>sta_start:
                dat.calc_sta(sta_start,sta_end)
                alldata.sta_array(dat)
        ######### fft requires that vm was saved
        # could be calculated for all neur types
        if len(alldata.vmdat[neur]):
            alldata.fft_func(neur,init_time=0.05,maxfreq=maxfreq)#edit fft to only return 1st 500 values?  Or freq up to 500 Hz?
        print('cond,params',cond,key,'num neurons',alldata.num_neurs)
        ######################################
        ######## calculate summary measures (mean, std across trials)
        ######################################
        alldata.ISI_histogram(numbins)
        alldata.spikerate_mean()
        have_sta=np.any([len(sta_set) for sta_set in alldata.sta.values()])
        if have_sta:
            alldata.sta_mean()
            xsta={nr:alldata.time[0:len(stawave)] for nr,stawave in alldata.sta_mean.items()}
        ######## Some measures only relevant if regular stimulation
        if len(alldata.pre_post_stim): #only analyze a single neur type
            alldata.latency(neur)
            alldata.ISI_vs_time(neur)
            alldata.lat_isi_mean()
            alldata.calc_lat_shift(neur,entropy_bin_size)
        ####### transfer summary measures to dictionary to plot multiple conditions on one graph
        _accumulate_into_globals(alldata.accum_list,alldata.accum_names,accum_key)
        ############## input files - for raster or spike triggered average input
        # reset per combination so plots below never use another condition's inputs
        have_inputs=False
        mean_cc=None
        if not len(infiles):
            inpattern=filedir+'tt'+file_root+cond+key+suffix
            input_files=sorted(glob.glob(inpattern+'*.npy'))
            if len(input_files):
                have_inputs=True
                syn_input=isc.input_spikes(input_files,alldata.sim_time)
                if binsize_factor>0:
                    mean_cc,mean_cc_shuffle,cc_shuffle_corrected,cc_bins=nau.cross_corr(syn_input.spiketimes,alldata.spiketimes[neur],alldata.sim_time[neur],alldata.dt*binsize_factor)
                    accum_names=['cross_corr','cross_corr_shuffle','cross_corr_corrected']
                    accum_list=[mean_cc,mean_cc_shuffle,cc_shuffle_corrected]
                else:
                    accum_list=[]
                    accum_names=[]
                if calc_input_ff: #This is slow, so provide option to skip
                    syn_input.input_fire_freq(neur,binsize_for_prespike_sta)
                    accum_names=accum_names+syn_input.accum_names
                    accum_list=accum_list+syn_input.accum_list
                    if sta_end>sta_start:
                        prespike_sta,prespike_mean,prespike_std,prespike_xvals=nau.sta_fire_freq(syn_input.inst_rate,alldata.spiketimes[neur],sta_start,sta_end,syn_input.xbins)
                        accum_names=accum_names+['prespike_sta_mean','prespike_sta_std']
                        accum_list=accum_list+[prespike_mean,prespike_std]
                _accumulate_into_globals(accum_list,accum_names,accum_key)
        ###################################################
        ######## Single parameter set plots
        ###################################################
        if individual_plots:
            if binsize_factor>0 and mean_cc is not None:
                pu.plot_cross_corr(mean_cc,mean_cc_shuffle,cc_shuffle_corrected,cc_bins)
            pu.plot_dict_of_dicts(alldata.isi_hist_mean,alldata.isihist_bins,ylabel='counts',std_dict=alldata.isi_hist_std,xlabel='ISI (sec)',ftitle=cond+' '+key)
            #pu.plot_dict(alldata.spikerate_mean,alldata.ratebins,ylabel='Spike Rate (Hz)',std_dict=alldata.spikerate_std,ftitle=cond+key)
            if len(alldata.pre_post_stim):
                pu.plot_dict(alldata.isi_time_mean,alldata.timebins[neur],std_dict=alldata.isi_time_std,ylabel='counts',ftitle='ISI '+cond+key)
                pu.plot_dict(alldata.lat_mean,alldata.pre_post_stim[neur],std_dict=alldata.lat_std,ylabel='Latency (sec)',ftitle=cond+key)
            if len(alldata.vmdat[neur]):
                pu.fft_plot(alldata,maxfreq=60,title=cond+key,mean_fft=True) #COMPARE TO ELIFE
                if have_sta:  # xsta and alldata.sta_mean only exist when an STA was computed
                    pu.plot_dict(alldata.sta_mean,xsta,std_dict=alldata.sta_std,ylabel='Vm (Volts)',ftitle='STA '+cond+' '+key)
        if raster_plots:
            # NOTE(review): the 'output'/'input' titles look swapped relative to the data
            # passed (syn_input holds input spikes) - confirm against pu.plot_raster.
            if have_inputs:
                pu.plot_raster(syn_input.spiketimes[0],alldata.sim_time[neur],ftitle='output '+cond+key)
            pu.plot_raster(alldata.spiketimes,alldata.sim_time[neur],syntt=dat.syntt_info,ftitle='input '+cond+key)
        if len(confile_root):
            con_fname=confile_root+cond+key+suffix+'.npz'
            nau.print_con(con_fname)
####### read in inputs if files specified separately (and same for all outputs) #######
if len(infiles):
    import os
    pre_spikes={}
    for f in infiles:
        # np.load returns an NpzFile that keeps the archive open; the with-block closes it.
        # allow_pickle=True is kept from the original (presumably spikeTime needs it);
        # only load trusted files this way.
        with np.load(f+'.npz',allow_pickle=True) as npz:
            pre_spikes[os.path.basename(f)]=npz['spikeTime']
    pu.plot_raster(pre_spikes,alldata.sim_time[neur],ftitle='input')
#####################
# plots comparing data across param2
#####################
# Accumulated result dictionaries are module-level names created in the loop above;
# optional measures are only plotted when their dictionary exists.
if group_plots:
    # start time of every time bin, all epochs pooled and sorted
    rate_xvals=sorted([tbin[0] for binset in alldata.timebins[neur].values() for tbin in binset]) #list
    pu.plot_dict_of_dicts(spikerate_mean,xarray=rate_xvals,ylabel='Hz',std_dict=spikerate_std,ftitle='spike rate: '+ftitle)
    elph_xvals=np.linspace(0,alldata.sim_time[neur],np.shape(alldata.spikerate_elph[neur])[1]) #array
    pu.plot_dict_of_dicts(spikerate_elphmean,xarray=elph_xvals,ylabel='Hz',std_dict=spikerate_elphstd,ftitle='ELPH spike rate: '+ftitle,trials=num_trials)
    hist_xvals={p:{k:list(binset) for k,binset in alldata.isihist_bins[neur].items()} for p in isihist_mean[neur].keys()} #dict of dicts
    pu.plot_dict_of_dicts(isihist_mean[neur],std_dict=isihist_std[neur],xarray=hist_xvals,xlabel='ISI (sec)',ylabel='count',ftitle='ISI histogram: '+ftitle)
    if 'sta_mean' in globals():
        pu.plot_dict_of_dicts(sta_mean,xarray=xsta,ylabel='Vm (V)',std_dict=sta_std,ftitle='STA: '+ftitle)
    if 'inputrate_mean' in globals():
        pu.plot_dict_of_dicts(inputrate_mean,xarray=syn_input.xbins,ylabel='Hz',std_dict=inputrate_std,ftitle='Input firing rate')
    # prespike STA only exists when the STA window is non-empty (sta_end>sta_start)
    if 'prespike_sta_mean' in globals():
        pu.plot_dict_of_dicts(prespike_sta_mean,xarray=prespike_xvals,ylabel='Firing Rate (Hz)',std_dict=prespike_sta_std,ftitle='STA Input: '+ftitle)
    if 'mean_fft' in globals():
        pu.plot_dict_of_epochs(mean_fft,std_dict=std_fft,xarray=alldata.freqs,ylabel='PSD',xlabel='Frequency (Hz)', ftitle='PSD: '+ftitle)
    if 'cross_corr_corrected' in globals():
        #consider calculating std in nau.cross_corr, and adding _std to accum_list
        pu.plot_dict_of_dicts(cross_corr_corrected,xarray=cc_bins,ylabel='',ftitle='cross_corr')
    if len(alldata.pre_post_stim):
        stim_xvals={k: [val[0] for val in values] for k,values in alldata.timebins[neur].items()}
        pu.plot_dict_of_epochs(lat_mean,std_dict=lat_std,xarray=stim_xvals,ylabel='latency',ftitle='latency: '+ftitle)
        pu.plot_dict_of_epochs(isi_time_mean,std_dict=isi_time_std,xarray=stim_xvals,ylabel='mean ISI',ftitle='mean isi: '+ftitle)
        pu.plot_dict_of_dicts(entropy,ylabel='bits',ftitle='entropy: '+ftitle)
########################################## Write output to file for generating nicer plots
if savetxt:
    # x values are computed here too so text output does not depend on group_plots being on
    rate_xvals=sorted([tbin[0] for binset in alldata.timebins[neur].values() for tbin in binset])
    elph_xvals=np.linspace(0,alldata.sim_time[neur],np.shape(alldata.spikerate_elph[neur])[1])
    hist_xvals={p:{k:list(binset) for k,binset in alldata.isihist_bins[neur].items()} for p in isihist_mean[neur].keys()}
    nau.write_dict_of_dicts(spikerate_mean,rate_xvals,'spike_rate_'+out_fname,'rate',spikerate_std)
    nau.write_dict_of_dicts(spikerate_elphmean,elph_xvals,'elph_spike_rate_'+out_fname,'Erate',spikerate_elphstd)
    nau.write_triple_dict(isihist_mean,'isi_histogram_'+out_fname,'isiN',isihist_std,xdata=hist_xvals,xheader='isi_bin') #possibly delete triple dict and loop over neur type?
    # optional measures: only written when they were accumulated in the loop above
    if 'sta_mean' in globals():
        nau.write_dict_of_dicts(sta_mean,xsta,'sta_vm_'+out_fname,'stavm',sta_std)
    if 'prespike_sta_mean' in globals():
        nau.write_dict_of_dicts(prespike_sta_mean,prespike_xvals,'sta_spike_'+out_fname,'stapre',prespike_sta_std)
    if 'mean_fft' in globals():
        nau.write_dict_of_dicts(mean_fft,alldata.freqs,'fft_'+out_fname,'fft',std_fft,xheader='freq') #this may need triple dict if do fft for multiple neur types
    if len(alldata.pre_post_stim):
        #x values will be the same for all data.  Possibly concatenate pre, post and stim?  write_dict_of_epochs.
        stim_xvals={k: [val[0] for val in values] for k,values in alldata.timebins[neur].items()}
        num_conditions=len(param1)*len(param2)
        nau.write_dict_of_epochs(lat_mean,stim_xvals,'latency_'+out_fname,'lat',num_conditions,stddata=lat_std)
        nau.write_dict_of_epochs(isi_time_mean,stim_xvals,'isi_time_'+out_fname,'itiT_N',num_conditions,stddata=isi_time_std)
        ent_xvals=sorted([v for val in stim_xvals.values() for v in val])
        nau.write_dict_of_dicts(entropy,ent_xvals,'entropy_'+out_fname,'ent')

'''
NEXT:
1. latency vs latency phase - check calculation - compare with previous code, change from % to / for phase?
2. Edit fft func to allow multiple neurons per type (possibly create new function?)
'''