import numpy as np
from scipy import fftpack, signal,stats
#from synth_trains import sig_filter as filt  - look into this if you want to use a Butterworth filter
#
def flatten(isiarray):
    """Concatenate the items of every sub-sequence of isiarray into one flat list.

    Only one level of nesting is removed; the order of the items is preserved.
    """
    flat=[]
    for sublist in isiarray:
        flat.extend(sublist)
    return flat

def calc_one_sta(spiketrain,window,vm_or_spikerate,dt):
    """Spike triggered average of one signal, using the spikes of one neuron.

    Parameters
    ----------
    spiketrain : iterable of spike times, in the same time units as dt
    window : (start,end) offsets relative to the spike, in samples; the
        averaged segment is window[-1]-window[0] samples long and ends
        window[1] samples after the sample containing the spike
    vm_or_spikerate : 1D sequence of signal samples, one sample every dt
    dt : sample interval of vm_or_spikerate

    Returns
    -------
    1D array of length window[-1]-window[0]: for every sample of the window,
    the mean over the spikes that contributed a value to that sample.
    Samples to which no spike contributed are NaN (e.g. the whole array when
    spiketrain is empty or every spike is too close to the end of the signal).

    Spikes at time <=0, and spikes whose window ends at or beyond the last
    sample, are ignored.  A spike whose window begins before the first sample
    contributes only to the part of the window that lies within the signal.
    """
    samplesize=window[-1]-window[0]
    numsamples=len(vm_or_spikerate)
    #accumulate sum and count per window sample, so that skipped spikes and
    #the missing part of truncated windows are not averaged in as zeros
    total=np.zeros(samplesize)
    count=np.zeros(samplesize)
    for st in spiketrain:
        endpt=int(st/dt)+window[1]
        #endpt<=0 means the entire window precedes the start of the signal
        if not (st>0 and 0<endpt<numsamples):
            continue
        startpt=endpt-samplesize
        if startpt<0:
            #window begins before the signal: fill only the tail of the window
            sta_start=-startpt
            startpt=0
        else:
            sta_start=0
        total[sta_start:]+=vm_or_spikerate[startpt:endpt]
        count[sta_start:]+=1
    sta=np.full(samplesize,np.nan)
    np.divide(total,count,out=sta,where=count>0)
    return sta

class network_output():
    def __init__(self,fname,numbins,overlap,table_index=None):
        dat=np.load(fname,'r',allow_pickle=True)
        self.fname=fname
        self.spiketime_dict=dat['spike_time'].item()
        self.isis=dat['isi'].item()
        self.sim_time=dat['params'].item()['simtime']
        self.num_neurs={k:len(st) for k,st in self.spiketime_dict.items()}
        self.neurtypes=self.spiketime_dict.keys()
        self.syntt_info={}
        self.freq={}
        self.timebins={k:{} for k in self.neurtypes}
        self.binsize={}
        self.pre_post_stim={}
        self.vmdat={}
        if 'vm' in dat.keys():
            vmtabs=dat['vm'].item()
            for neur in self.neurtypes:
                if table_index:
                    self.vmdat[neur]=vmtabs[neur][table_index]
                else:
                    self.vmdat[neur]=vmtabs[neur][0] #KLUGE - just do the first neuron.  If remove [0], need to edit fft_func
            if 'dt' in dat['params'].item():
                self.dt=dat['params'].item()['dt']
            else:
                self.dt=self.sim_time/len(vmtabs[neur][0])
            #arbitrarily use 1st table of last neur to construct time wave
            self.time=np.linspace(0,self.dt*len(vmtabs[neur][0]),len(vmtabs[neur][0]),endpoint=False)
            self.sim_time=max(self.time[-1],self.sim_time)
            
        for neur in self.neurtypes:
            if 'syn_tt' in dat['params'].item().keys():
                ############### Preferred, but new way, of storing regular time-table data for ep simulations #######
                stim_tt=dat['params'].item()['syn_tt'][neur][0][1] #take stim times from 1st item in list, could be multiple branches stimulated
            elif neur in dat['params'].item().keys():
                if 'syn_tt' in dat['params'].item()[neur].keys():
                    #OLD WAY of storing time-table data for ep simulations - use this for simulations prior to may 20
                    stim_tt=dat['params'].item()[neur]['syn_tt'][0][1] #take stim times from 1st item in list
            if 'stim_tt' in locals():
                xstart=stim_tt[0]
                xend=stim_tt[-1]
                self.syntt_info[neur]={'xstart':xstart,'xend':xend,'stim_tt':stim_tt}
                self.timebins[neur],self.freq[neur],self.binsize[neur]=self.set_up_bins(numbins,stim_tt)
                isi=stim_tt[1]-stim_tt[0]
                self.pre_post_stim[neur]=self.setup_stim_epochs(isi,stim_tt,self.sim_time)
            else:
                self.freq[neur]=0
                self.binsize[neur]=self.sim_time/numbins
                nbins=int(self.sim_time/self.binsize[neur]/overlap-(1./overlap-1))
                binlow=[i*self.binsize[neur]*overlap for i in range(nbins)]
                self.timebins[neur]['all']=[(bl,bl+self.binsize[neur]) for bl in binlow]

    #
    def set_up_bins(self,numbins,stim_tt):
        """Create time bins for the stimulation epoch and the epochs before and after it.

        Parameters
        ----------
        numbins : number of bins the stimulation epoch is divided into; also the
            maximum number of bins in each of the pre and post epochs
        stim_tt : sequence of stimulation times (at least two); the difference
            between the first two is used as the inter-stimulus interval

        Returns
        -------
        bins : dictionary with keys 'stim', 'pre' and 'post', each a list of
            (low,high) bin edges, sorted by time
        freq : stimulation frequency, 1/(stim_tt[1]-stim_tt[0])
        bin_size : width of every bin

        The stimulation epoch extends from the first stimulus to one
        inter-stimulus interval past the last stimulus.  Pre bins are only
        created if they begin at or after time 0, and post bins only if they
        end at or before self.sim_time.
        """
        bins={}
        freq=float(1/(stim_tt[1]-stim_tt[0]))
        bin_size=(stim_tt[-1]+1/freq-stim_tt[0])/numbins
        binlow=sorted([stim_tt[0]+i*bin_size for i in range(numbins)])
        bins['stim']=[(bl,bl+bin_size) for bl in binlow]
        #
        #as many bins as fit between time 0 and the first stimulus
        first_stim_low=bins['stim'][0][0]
        num_bins=min(numbins,int(stim_tt[0]/bin_size))
        binlow=sorted([first_stim_low-(i+1)*bin_size for i in range(num_bins)])
        bins['pre']=[(bl,bl+bin_size) for bl in binlow]
        #
        #as many bins as fit between the end of the last stim bin and sim_time
        last_stim_low=bins['stim'][-1][0]
        post_start=last_stim_low+bin_size
        num_bins=min(numbins,int((self.sim_time-post_start)/bin_size))
        binlow=sorted([last_stim_low+(i+1)*bin_size for i in range(num_bins)])
        bins['post']=[(bl,bl+bin_size) for bl in binlow]
        return bins,freq,bin_size
    #
    def setup_stim_epochs(self,isi,stim_tt,simtime):
        """Build lists of epoch times for the stimulation and the periods before and after it.

        Parameters
        ----------
        isi : spacing between successive epoch times
        stim_tt : sequence of stimulation times
        simtime : total simulation time

        Returns a dictionary with keys
        'stim' : stim_tt itself
        'pre' : times spaced isi apart, counting back from the first stimulus
        'post' : times spaced isi apart, counting forward from the last stimulus
        Both 'pre' and 'post' are sorted in ascending order, and contain at most
        len(stim_tt) times and at most one fewer than the number of whole
        isi that fit in the available time.
        """
        first_stim,last_stim=stim_tt[0],stim_tt[-1]
        max_epochs=len(stim_tt)
        #
        n_before=min(int(first_stim/isi)-1,max_epochs)
        before=[first_stim-k*isi for k in range(1,n_before+1)]
        before.sort()
        #
        n_after=min(int((simtime-last_stim)/isi)-1,max_epochs)
        after=[last_stim+k*isi for k in range(1,n_after+1)]
        after.sort()
        return {'stim':stim_tt,'pre':before,'post':after}
    #
    def spikerate_func(self):
        """Calculate population spike rate for each neuron type, in two ways.

        Requires the elephant, neo and quantities packages.  Spike times,
        sim_time and bin sizes are passed to them as seconds.

        Creates three dictionaries, keyed by neuron type:
        self.spike_rate_elph : array (number of spike trains x int(sim_time/binsize)),
            output of elephant's instantaneous_rate for each spike train, called
            with binsize as its second argument and a Gaussian kernel of sigma=binsize
        self.spike_rate_mean : mean of spike_rate_elph over spike trains
        self.spike_rate : one value per time bin (all bins of self.timebins,
            sorted by time): number of spikes of all neurons with
            bin low <= spike time < bin high, divided by binsize and by the
            number of neurons
        """
        import elephant as elph
        from neo.core import SpikeTrain
        import quantities as q
        self.spike_rate={};self.spike_rate_elph={};self.spike_rate_mean={}
        for ntype in self.neurtypes:
            ratebins=sorted([binpair for binset in self.timebins[ntype].values() for binpair in binset])
            binsize=self.binsize[ntype]
            numneurs=self.num_neurs[ntype]
            self.spike_rate_elph[ntype]=np.zeros((len(self.spiketime_dict[ntype]),int(self.sim_time/binsize)))
            self.spike_rate[ntype]=np.zeros(len(ratebins))
            #NOTE, instantaneous_rate fails with kernel=auto and small number of spikes
            #kernel depends only on binsize, so create it once per neuron type
            kernel=elph.kernels.GaussianKernel(sigma=binsize*q.s)
            for i,spiketrain in enumerate(self.spiketime_dict[ntype]):
                train=SpikeTrain(spiketrain*q.s,t_stop=self.sim_time*q.s)
                self.spike_rate_elph[ntype][i,:]=elph.statistics.instantaneous_rate(train,binsize*q.s,kernel=kernel).magnitude[:,0]
            self.spike_rate_mean[ntype]=np.mean(self.spike_rate_elph[ntype],axis=0)
            #Compare spike_rate_mean (from elephant) with spike_rate
            #pool the spikes of all neurons once, instead of once per bin
            all_spikes=flatten(self.spiketime_dict[ntype])
            for bb,(bl,bh) in enumerate(ratebins):
                spikes_in_bin=sum(1 for st in all_spikes if st>=bl and st<bh)
                self.spike_rate[ntype][bb]=spikes_in_bin/binsize/numneurs
        #separate into pre,post,stim epochs?  Or separate function - only do that if stim
    
    def calc_sta(self,sta_start,sta_end):
        self.sta={}
        window=(int(sta_start/self.dt),int(sta_end/self.dt))
        for ntype in self.vmdat.keys():
            sta_set=[]
            for i,spiketrain in enumerate(self.spiketime_dict[ntype]):
                sta_set.append(calc_one_sta(spiketrain,window,self.vmdat[ntype],self.dt)) #mean over spikes
            self.sta[ntype]=np.mean(sta_set,axis=0) #mean over neurons
        self.sta_xvals=np.arange(sta_start,sta_end,self.dt)
        #net_anal:calculate mean and std over trials