#Functions for calculating ISI, latency, psp_amplitude for a set of files
import glob
import numpy as np
import moose
from scipy import fftpack, signal

import detect

def flatten(isiarray):
    return [item for sublist in isiarray for item in sublist]

def file_set(pattern):
    files=sorted(glob.glob(pattern))
    if len(files)==0:
        print('********* no files found for ',pattern)
    return files

def find_somatabs(tabset,soma_name,tt=None):
    #find the table(s) with vm from the soma
    comp_names=[tab.neighbors['requestOut'][0].name for tab in tabset]
    soma_tabs=[tab for tab in tabset if tab.neighbors['requestOut'][0].name==soma_name]
    print ('ISI_ANAL: vm tables {}, soma vmtab={}, comp={}'.format(comp_names,soma_tabs,[st.neighbors['requestOut'][0].path for st in soma_tabs]))
    #if no soma tables found (perhaps wrong name) use the last one, which might be soma
    #or send back number of tables equal to number of time tables 
    ######## Needs more debugging for network simulation #################3
    if len(soma_tabs)==0:
        if tt:
            num_tabs=len(tt)
        else:
            num_tabs=1
        soma_tabs=comp_names[-num_tabs:]
    return soma_tabs
    
def spike_isi_from_vm(vmtab,simtime,soma='soma'):
    spike_time={key:[] for key in vmtab.keys()}
    numspikes={key:[] for key in vmtab.keys()}
    isis={key:[] for key in vmtab.keys()}
    for neurtype, tabset in vmtab.items():
        soma_tabs=find_somatabs(tabset,soma)
        for tab in soma_tabs:
            spike_time[neurtype].append(detect.detect_peaks(tab.vector)*tab.dt)
            isis[neurtype].append(np.diff(spike_time[neurtype][-1]))
            numspikes[neurtype]=[len(st) for st in spike_time[neurtype]]
        print(neurtype,'mean:',np.mean(numspikes[neurtype]),'rate',np.mean(numspikes[neurtype])/simtime,'from',numspikes[neurtype],
              'spikes, ISI mean&STD: ',[np.mean(isi) for isi in isis[neurtype]], [np.std(isi) for isi in isis[neurtype]] )
    return spike_time,isis

def stim_spikes(spike_time,timetables,soma='soma'):
    stim_spikes={key:[] for key in spike_time.keys()}
    for neurtype, tabset in spike_time.items():
        for tab,tt in zip(tabset,timetables[neurtype].values()):
            stim_spikes[neurtype].append([st for st in spike_time[neurtype][-1] if st>np.min(tt.vector) and st<np.max(tt.vector)])
    return stim_spikes

def psp_amp(vmtab,timetables,soma='soma'):
    psp_amp={key:[] for key in vmtab.keys()}
    psp_norm={key:[] for key in vmtab.keys()}
    for neurtype, tabset in vmtab.items():
        soma_tabs=find_somatabs(tabset,soma,tt=timetables[neurtype].values())
        for tab,tt in zip(soma_tabs,timetables[neurtype].values()):
            vm_init=[tab.vector[int(t/tab.dt)] for t in tt.vector]
            #use np.min for IPSPs and np.max for EPSPs
            vm_peak=[np.min(tab.vector[int(tt.vector[i]/tab.dt):int(tt.vector[i+1]/tab.dt)]) for i in range(len(tt.vector)-1)]
            psp_amp[neurtype].append([(vm_init[i]-vm_peak[i]) for i in range(len(vm_peak))])
            psp_norm[neurtype].append([amp/psp_amp[neurtype][-1][0] for amp in psp_amp[neurtype][-1]])
    return psp_amp,psp_norm

def isi_vs_time(spike_time,isi_vals,bins,binsize,isi_set):
    st_isi=dict(zip(spike_time[1:],isi_vals))
    for pre_post in bins.keys():
        for binmin in bins[pre_post]:
            binmax=binmin+binsize
            isi_set[pre_post][binmin].append([isi_val for st, isi_val in st_isi.items() if st>=binmin and st<binmax])
    return isi_set

def set_up_bins(file0,freq,numbins,neurtype):
    bins={}
    with np.load(file0,'r') as dat:
        params=dat['params'].item()
        #'syn_tt' has one tuple (but could have multiple), 
        #1st [0] selects the 1st tuple
        #2nd [1] selects stim_times (array of spike times) from tuple (synapse,stim_times), 
        stim_tt=params[neurtype]['syn_tt'][0][1]
        #simtime=params[simtime]
        simtime=4.0
    bin_size=(stim_tt[-1]+1/float(freq)-stim_tt[0])/numbins
    #bins['stim']={stim_tt[0]+i*bin_size : stim_tt[0]+(i+1)*bin_size for i in range(numbins)}
    bins['stim']=[stim_tt[0]+i*bin_size for i in range(numbins)]
    num_bins=min(numbins,int(stim_tt[0]/bin_size))
    bins['pre']=[bins['stim'][0]-(i+1)*bin_size for i in range(num_bins)]
    bins['post']=[bins['stim'][-1]+(i+1)*bin_size for i in range(num_bins)]
    return bins,bin_size,stim_tt,simtime

def setup_stimtimes(freq,stim_tt,isi,simtime):
    pre_post_stim={}
    pre_post_stim['stim']=stim_tt
    num_pre=int(stim_tt[0]/isi)
    pre_post_stim['pre']=[stim_tt[0]-(i+1)*isi for i in range(min(num_pre-1,freq))]
    num_post=int((simtime-stim_tt[-1])/isi)
    pre_post_stim['post']=[stim_tt[-1]+(i+1)*isi for i in range(min(num_post-1,freq))]
    return pre_post_stim

def latency(files,freq,neurtype,numbins):
    isi=1.0/freq
    latency={'pre':np.zeros((freq,len(files))),'post':np.zeros((freq,len(files))), 'stim':np.zeros((freq,len(files)))}
    bins,bin_size,stim_tt,simtime=set_up_bins(files[0],freq,numbins,neurtype)
    isi_set={'pre':{k:[] for k in bins['pre']},'post':{k:[] for k in bins['post']},'stim':{k:[] for k in bins['stim']}}
    pre_post_stim=setup_stimtimes(freq,stim_tt,isi,simtime)
    for fnum,fname in enumerate(files):
        dat=np.load(fname,'r')
        params=dat['params'].item()
        if 'spike_time' in dat.keys() and params['freq']==freq:
            spike_time=dat['spike_time'].item()[neurtype][0]
            for pre_post in pre_post_stim.keys():
                for i,time in enumerate(pre_post_stim[pre_post]):
                    next_spike=np.min(spike_time[np.where(spike_time>time)])
                    latency[pre_post][i,fnum]=next_spike-time
            isi_vals=dat['isi'].item()[neurtype][0]
            isi_set=isi_vs_time(spike_time,isi_vals,bins,bin_size,isi_set)
        else:
            print('whoops, wrong file',fname,'for freq', freq,'file contains', dat.keys())
    lat_mean={}
    lat_std={}
    for pre_post in latency.keys():
        lat_mean[pre_post]=np.mean(latency[pre_post],axis=1)
        lat_std[pre_post]=np.std(latency[pre_post],axis=1)
        #print('latency {}: mean {} \n std {}'.format(pre_post,lat_mean[pre_post],lat_std[pre_post]))
    isi_mean={}
    isi_std={}
    for pre_post in isi_set.keys():
        for binmin,isilist in isi_set[pre_post].items():
            isi_set[pre_post][binmin]=[item for sublist in isilist for item in sublist]
    for pre_post in isi_set.keys():
        isi_mean[pre_post]=[np.mean(isis) for isis in isi_set[pre_post].values()]
        isi_std[pre_post]=[np.std(isis) for isis in isi_set[pre_post].values()]
        #print('isi {}: mean {} \n std {}'.format(pre_post,isi_mean[pre_post],isi_std[pre_post]))
    return lat_mean,lat_std,isi_mean,isi_std,bins

def freq_dependence(fileroot,presyn,suffix):
    pattern=fileroot+presyn+'*'+suffix
    files=file_set(pattern)
    if len(files)==0:
        return
    frequency_set=np.unique([int(fname.split('freq')[-1].split('_')[0]) for fname in files])
    results={freq:{} for freq in frequency_set}
    xval_set={freq:{} for freq in frequency_set}
    for fname in files:
        dat=np.load(fname,'r')
        params=dat['params'].item()
        if 'norm' in dat.keys():
            print ('freq dep fname', fname, dat.keys())
            numplots=len(dat['norm'].item().keys())
            results[params['freq']]={ntype:[] for ntype in dat['norm'].item().keys()}
            for neurtype in dat['norm'].item().keys():
                results[params['freq']][neurtype]=dat['norm'].item()[neurtype]
                #'syn_tt' has one tuple (but could have multiple), 
                #1st [0] selects the 1st tuple
                #2nd [1] selects stim_times (array of spike times) from tuple (synapse,stim_times), 
                xval_set[params['freq']][neurtype]=dat['params'].item()['ep']['syn_tt'][0][1]
            ylabel='normalized PSP amp'
            xlabel='pulse'
        elif 'isi' in dat.keys():
            print ('freq dep fname', fname, dat.keys())
            numplots=len(dat['isi'].item().keys())
            results[params['freq']]={ntype:[] for ntype in dat['isi'].item().keys()}
            for neurtype in dat['isi'].item().keys():
                results[params['freq']][neurtype]=dat['isi'].item()[neurtype]
                xval_set[params['freq']][neurtype]=dat['spike_time'].item()[neurtype][0]
            ylabel='isi (sec)'
            xlabel='time (sec)'
        else:
            print('issue with file {} keys {}'.format(fname,dat.keys))
    return numplots,results,xval_set,xlabel,ylabel

def get_spiketimes(files,neurtype):
    spiketimes=[]
    #creates list of spike times from set of trials
    if len(files)>0:
        for fname in files:
            dat=np.load(fname,'r')
            spiketimes.append(dat['spike_time'].item()[neurtype][0])
        #Get start and end of stimulation only from last file
        #'syn_tt' has one tuple (but could have multiple), 1st item is comp, 2nd item is array of spike times
        if neurtype in dat['params'].item().keys():
            xstart=dat['params'].item()[neurtype]['syn_tt'][0][1][0]
            xend=dat['params'].item()[neurtype]['syn_tt'][0][1][-1]
            maxt=max([max(st) for st in spiketimes])
            syntt_info={'xstart':xstart,'xend':xend,'maxt':maxt}
        else:
            syntt_info={}
    return spiketimes,syntt_info

def ISI_histogram(files,stim_freq,neurtype):
    #set-up pre, during and post-stimulation time frames (bins)
    bins,binsize,stim_tt,simtime=set_up_bins(files[0],stim_freq,1,neurtype)
    isi_set={'pre':[],'post':[],'stim':[]}
    #read in ISI data and separate into 3 time frames
    for fname in files:
        dat=np.load(fname,'r')
        params=dat['params'].item()
        if 'spike_time' in dat.keys() and params['freq']==stim_freq:
            spike_time=dat['spike_time'].item()[neurtype][0]
            isi_vals=dat['isi'].item()[neurtype][0]
            #separate into pre, post and during stimulation (optional)
            st_isi=dict(zip(spike_time[1:],isi_vals))
            for pre_post,binlist in bins.items():
                binmin=binlist[0]
                binmax=binmin+binsize
                isi_set[pre_post].append([isi_val for st, isi_val in st_isi.items() if st>=binmin and st<binmax])
        else:
            print('wrong frequency')
    return isi_set

#################### Spike triggered averages
def calc_sta(spike_time,window,vmdat,plotdt):
    numspikes=len(spike_time)
    samplesize=window[-1]-window[0]
    sta_array=np.zeros((numspikes,samplesize))
    for i,st in enumerate(spike_time):
        endpt=int(st/plotdt)+window[1]
        if endpt<len(vmdat):
            startpt=endpt-samplesize
            if startpt<0:
                sta_start=-startpt
                startpt=0
            else:
                sta_start=0
            sta_array[i,sta_start:]=vmdat[startpt:endpt]
    sta=np.mean(sta_array,axis=0)
    xvals=np.arange(window[0]*plotdt,window[1]*plotdt,plotdt)
    return xvals,sta

def sta_set(files,spike_time,neurtype,sta_start,sta_end):
    vmdat=[]
    sta_list=[]
    for trial,fname in enumerate(files):
        dat=np.load(fname,'r')
        params=dat['params'].item()
        plotdt=params['dt']
        window=(int(sta_start/plotdt),int(sta_end/plotdt))
        if 'vm' in dat.keys():
            vmdat.append(dat['vm'].item()[neurtype])
            #if (dat['vm'] has multiple traces, choose the last one
            trace=np.shape(vmdat)[1]-1
            xvals,sta=calc_sta(spike_time[trial],window,vmdat[trial][trace],plotdt)
            sta_list.append(sta)
            '''
            vmsignal=AnalogSignal(vmdat[trial][1],units='V',sampling_rate=plotdt*q.Hz)
            spikes=SpikeTrain(spike_time*q.s,t_stop=vmsignal.times[-1])
            e_sta=elephant.sta.spike_triggered_average(vmsignal,spikes,(-window*q.s,0*q.s))
            plt.plot(xvals,e_sta.magnitude,label='e_sta') 
            '''
        else:
            print('wrong spike file')
    #calculate mean over trials
    return sta_list,xvals,plotdt,vmdat

def input_raster(files):
    pre_spikes=[{} for f in files]
    for trial,infile in enumerate(files):
        ######### End temp stuff
        tt=np.load(infile).item()
        for ax,syntype in enumerate(tt.keys()):
            for presyn in tt[syntype].keys():
                spiketimes=[]
                for branch in sorted(tt[syntype][presyn].keys()):
                    #axis[ax].eventplot(tt[syntype][presyn][branch])
                    spiketimes.append(tt[syntype][presyn][branch])
                #flatten the spiketime array to use for prospective STA
                pre_spikes[trial][syntype+presyn]=spiketimes
    return pre_spikes

def post_sta_set(pre_spikes,sta_start,sta_end,plotdt,vmdat):
    window=(int(sta_start/plotdt),int(sta_end/plotdt))
    post_sta={key:[] for key in pre_spikes[0]}
    for trial in range(len(pre_spikes)):
        for ax,(key,spiketimes) in enumerate(pre_spikes[trial].items()):
            spikes=flatten(spiketimes)
            trace=np.shape(vmdat)[1]-1
            xvals,sta=calc_sta(spikes,window,vmdat[trial][trace],plotdt)
            post_sta[key].append(sta)
    mean_sta={}
    for ax,key in enumerate(post_sta.keys()):
        mean_sta[key]=np.mean(post_sta[key],axis=0)
    return post_sta,mean_sta,xvals

def sta_fire_freq(inst_rate,spike_list,sta_start,sta_end,weights,xbins):
    binsize=xbins[1]-xbins[0]
    window=(int(sta_start/binsize),int(sta_end/binsize))
    weighted_inst_rate=[np.zeros(len(xbins)) for trial in range(len(inst_rate))] 
    prespike_sta=[{} for t in range(len(inst_rate))]
    for trial in range(len(inst_rate)):
        spike_time=spike_list[trial]
        for key in inst_rate[trial].keys():
            weighted_inst_rate[trial]+=weights[key]*inst_rate[trial][key] 
            xbins,prespike_sta[trial][key]=calc_sta(spike_time,window,inst_rate[trial][key],binsize)
        xbins,prespike_sta[trial]['sum']=calc_sta(spike_time,window,weighted_inst_rate[trial],binsize)
    mean_sta={k:np.zeros(len(prespike_sta[0][k])) for k in prespike_sta[0]}
    for ax,key in enumerate(prespike_sta[0].keys()):
        for trial in range(len(prespike_sta)):
            mean_sta[key]+=prespike_sta[trial][key]
        mean_sta[key]=mean_sta[key]/len(prespike_sta)
    return prespike_sta,mean_sta,xbins

def input_fire_freq(pre_spikes,binsize):
    import elephant
    from neo.core import AnalogSignal,SpikeTrain
    import quantities as q
    inst_rate1=[{} for t in range(len(pre_spikes))]
    inst_rate2=[{} for t in range(len(pre_spikes))]
    for trial in range(len(pre_spikes)):
        print('inst firing rate for trial',trial)
        for key,spike_set in pre_spikes[trial].items():
            if isinstance(spike_set, list):
                spikes = np.sort(np.concatenate([st for st in spike_set]))
            else:
                spikes=spike_set
            train=SpikeTrain(spikes*q.s,t_stop=np.ceil(spikes[-1])*q.s)
            inst_rate1[trial][key]=elephant.statistics.instantaneous_rate(train,binsize*q.s).magnitude[:,0]
            xbins=np.arange(0,np.ceil(spikes[-1]),binsize)
            inst_rate2[trial][key]=np.zeros(len(xbins))
            for i,binmin in enumerate(xbins):
                inst_rate2[trial][key][i]=len([st for st in spikes if st>=binmin and st<binmin+binsize])/binsize
    return inst_rate1,inst_rate2,xbins

def fft_func(wave_array,ts,init_time,endtime):
    fft_wave=[]
    phase=[]
    init_point=np.min(np.where(ts>init_time))
    endpoint=np.max(np.where(ts<endtime))
    for wave in wave_array:
        #wave is an analog signal - either Vm or binary spike signal.  Do not use spiketime
        fft_wave.append(np.fft.rfft(wave[0][init_point:endpoint]))
        #Note that maximum frequency for fft is fs/2; the frequency unit is cycles/time units.
        #freqs is x axis. Two ways to obtain correct frequencies:
        #specify sampling spacing as 2nd parameter, note that fs=1/ts
        #freqs = np.fft.fftfreq(len(wave))*fs
        #multiply by sample spacing by max frequency (=1/ts):
        phase.append(np.arctan2(fft_wave[-1].imag,fft_wave[-1].real))
    freqs=np.fft.rfftfreq(len(wave[0][init_point:endpoint]),ts[1])
    mean_wave=np.mean(wave_array,axis=0)[0]
    mean_fft=np.fft.rfft(mean_wave[init_point:endpoint])
    mean_phase=np.arctan2(mean_fft.imag,mean_fft.real)
    return fft_wave,phase,freqs,mean_wave,{'mag':mean_fft,'phase':mean_phase}

############# Call this from multisim, after import ISI_anal
#  ISI_anal.save_tt(connections)
def save_tt(connections,param_sim):
    import moose
    used_tt={}
    for syntype in connections['ep']['/ep'].keys():
        used_tt[syntype]={}
        for ext in connections['ep']['/ep'][syntype].keys():
            used_tt[syntype][ext]={}
            for syn in connections['ep']['/ep'][syntype][ext].keys():
                tt=moose.element(connections['ep']['/ep'][syntype][ext][syn])
                used_tt[syntype][ext][syn]=tt.vector
    np.save('tt'+param_sim.fname,used_tt)