import numpy as np
import os
from matplotlib import pyplot as plt
import ISI_anal
# Matplotlib colour codes (red, black, blue) picked by position index for the
# traces/conditions in the line plots of this module.
colors=['r','k','b']

def plot_latency(rootname,lat_mean_dict,lat_std_dict,filesuffix):
    """Plot per-stimulus latency mean (solid) and std (dashed) for each input condition.

    Intended for network neuron simulations driven by one regular input train,
    with statistics pooled over multiple trials. One subplot row is created per
    ``synstim``.

    Parameters
    ----------
    rootname : str
        Prefix used in the figure title.
    lat_mean_dict, lat_std_dict : dict
        ``{synstim: {key: sequence_of_latencies}}`` with the same ``synstim``/``key``
        structure. The text of ``synstim`` before the first ``'_'`` names the
        presynaptic input in the y-axis label; ``key`` must be a string.
    filesuffix : str
        Appended to the title after stripping everything from the first ``'.'``.

    Traces of the k-th ``key`` use the k-th entry of the module-level ``colors``
    list, cycling when there are more keys than colours.
    """
    fig,axes =plt.subplots(len(lat_mean_dict),1,sharex=True)
    axis=fig.axes  # flat list of Axes, even for a single row
    fig.suptitle('Latency: '+rootname+filesuffix.split('.')[0])
    for i,synstim in enumerate(lat_mean_dict.keys()):
        presyn=synstim.split('_')[0]
        for k,key in enumerate(lat_mean_dict[synstim].keys()):
            color=colors[k%len(colors)]  # cycle instead of IndexError for many keys
            mean_vals=lat_mean_dict[synstim][key]
            std_vals=lat_std_dict[synstim][key]
            # x is the stimulus index; each series uses its own length
            axis[i].plot(range(len(mean_vals)),mean_vals,label=key+' mean',color=color)
            axis[i].plot(range(len(std_vals)),std_vals,label=key+' std',linestyle='dashed',color=color)
        axis[i].set_xlabel('stim number')
        axis[i].set_ylabel(presyn+' input, latency (sec)')
        axis[i].legend()

def plot_ISI(rootname,isi_mean_dict,isi_std_dict,bins,filesuffix):
    """Plot binned ISI mean (solid) and std (dashed) for each input condition.

    Intended for network neuron simulations with one regular input train and
    multiple trials. One subplot row is created per ``synstim``.

    Parameters
    ----------
    rootname : str
        Prefix used in the figure title.
    isi_mean_dict, isi_std_dict : dict
        ``{synstim: {key: values}}``. The text of ``synstim`` before the first
        ``'_'`` names the presynaptic input in the y-axis label.
    bins : dict
        ``{key: x_values}``; selects which ``key`` series are drawn (``key`` must be
        a string) and supplies their x-values (time, sec).
    filesuffix : str
        Appended to the title after stripping everything from the first ``'.'``.

    Traces of the k-th ``key`` use the k-th entry of the module-level ``colors``
    list, cycling when there are more keys than colours.
    """
    fig,axes =plt.subplots(len(isi_mean_dict),1,sharex=True)
    axis=fig.axes  # flat list of Axes, even for a single row
    fig.suptitle('ISI: '+rootname+filesuffix.split('.')[0])
    for i,synstim in enumerate(isi_mean_dict.keys()):
        presyn=synstim.split('_')[0]
        for k,key in enumerate(bins.keys()):
            color=colors[k%len(colors)]  # cycle instead of IndexError for many keys
            axis[i].plot(bins[key],isi_mean_dict[synstim][key],label=key+' mean',color=color)
            axis[i].plot(bins[key],isi_std_dict[synstim][key],label=key+' std',linestyle='dashed',color=color)
        axis[i].set_xlabel('time (sec)')
        axis[i].set_ylabel(presyn+' input, isi (sec)')
        axis[i].legend()

def plot_ISI_cond(all_isi_mean,bins):
    """Overlay the mean ISI curves of several conditions, one subplot per input.

    Parameters
    ----------
    all_isi_mean : dict
        ``{cond: {synstim: {key: isi_values}}}``. The number of rows is taken from
        the first condition. Rows are matched by position, so every condition is
        assumed (not checked) to list the same ``synstim`` entries in the same order.
        The text of ``synstim`` before the first ``'_'`` goes in the y-axis label.
    bins : dict
        ``{key: x_values}``; selects which ``key`` curves are drawn and their
        x-values (time, sec).

    Each condition gets one colour from the module-level ``colors`` list (cycled
    when there are more conditions than colours) and a single legend entry.

    Raises
    ------
    ValueError
        If ``all_isi_mean`` is empty.
    """
    if not all_isi_mean:
        raise ValueError('plot_ISI_cond: all_isi_mean is empty')
    # Size the figure from the first condition; all conditions share the row layout.
    first_cond=next(iter(all_isi_mean))
    fig,axes =plt.subplots(len(all_isi_mean[first_cond]),1,sharex=True)
    axis=fig.axes
    fig.suptitle('ISI mean: all conditions')
    for j,cond in enumerate(all_isi_mean.keys()):
        color=colors[j%len(colors)]
        for i,synstim in enumerate(all_isi_mean[cond].keys()):
            presyn=synstim.split('_')[0]
            for k,key in enumerate(bins.keys()):
                # label only the first curve per condition so the legend lists each cond once
                label=cond if k==0 else ""
                axis[i].plot(bins[key],all_isi_mean[cond][synstim][key],label=label,color=color)
            axis[i].set_xlabel('time (sec)')
            axis[i].set_ylabel(presyn+' input, isi (sec)')
    axis[-1].legend(loc='lower left')
    
def plot_postsyn_raster(rootname,suffix,spiketime_dict,syntt_info):
    """Raster plot of post-synaptic spike times, one subplot per input condition.

    Parameters
    ----------
    rootname, suffix : str
        Concatenated into the figure title ``'output ' + rootname + suffix``.
    spiketime_dict : dict
        ``{key: [trial0_spike_times, trial1_spike_times, ...]}`` (times in sec);
        ``key`` must be a string. Trials without any spikes are allowed.
    syntt_info : dict
        ``{key: info}``. When ``info`` is present and non-empty it must provide
        ``'xstart'``, ``'xend'`` and ``'maxt'``, used to place 'stim onset' and
        'offset' arrows. Keys absent from ``syntt_info`` get no arrows.

    For each key this prints the largest spike time seen so far (a running
    maximum over the keys processed, not a per-key value) and the across-trial
    mean of ``spike count / that maximum``. The result is NaN while no spike has
    been seen. The subplots share an x-axis whose left limit is 1.0 s; the limit
    is only set when at least one spike exists.
    """
    fig,axes =plt.subplots(len(spiketime_dict), 1,sharex=True)
    fig.suptitle('output '+rootname+suffix)
    axis=fig.axes
    maxtime=0
    for ax,key in enumerate(spiketime_dict.keys()):
        trials=spiketime_dict[key]
        # np.max raises ValueError on empty input, so only non-empty trials contribute
        trial_maxes=[np.max(m) for m in trials if len(m)]
        if trial_maxes:
            maxtime=max(maxtime,np.max(trial_maxes))
        mean_freq=np.mean([len(m)/maxtime for m in trials]) if maxtime>0 else np.nan
        print(key,'max time=',maxtime, 'mean freq',mean_freq)
        axis[ax].eventplot(trials)
        axis[ax].set_ylabel(key+' trial')
        info=syntt_info.get(key)
        if info is not None and len(info):
            xstart=info['xstart']
            xend=info['xend']
            maxt=info['maxt']
            # arrow tip at (x, 0) in data coordinates; text below the axes with x
            # given as an axes fraction (x/maxt)
            axis[ax].annotate('stim onset',xy=(xstart,0),xytext=(xstart/maxt, -0.2),
                              textcoords='axes fraction', arrowprops=dict(facecolor='black', shrink=0.05))
            axis[ax].annotate('offset',xy=(xend,0),xytext=(xend/maxt, -0.2),
                              textcoords='axes fraction', arrowprops=dict(facecolor='red', shrink=0.05))
    axis[-1].set_xlabel('time (sec)')
    if maxtime>0:
        axis[0].set_xlim(1.0,np.round(maxtime))
    return

#################### plot the set of results from single neuron simulations, range of input frequencies
##### either normalized PSPs if no spikes, or ISIs if spikes
def plot_freq_dep_psp(fileroot,presyn_set,suffix,neurtype):
    """Plot results of single-neuron simulations across a range of input frequencies.

    For each ``presyn`` in ``presyn_set`` the data come from
    ``ISI_anal.freq_dependence(fileroot, presyn, suffix)``, which returns
    ``(numplots, results, xval_set, xlabel, ylabel)``. The y-values are presumably
    normalized PSPs when the neuron does not spike and ISIs when it does. TODO:
    confirm in ISI_anal.

    Parameters
    ----------
    fileroot, suffix : str
        Passed through to ``ISI_anal.freq_dependence``; ``suffix`` is also
        appended to the figure title.
    presyn_set : sequence
        Presynaptic input names. The figure grid has ``numplots`` rows and
        ``len(presyn_set)`` columns.
    neurtype : str
        Prefix of the figure title.

    Returns
    -------
    all_results, all_xvals : dict
        ``{presyn: results}`` and ``{presyn: xval_set}`` as returned by
        ``ISI_anal.freq_dependence``. Both are indexed here as
        ``[presyn][freq][ntype]``. ``results[freq][ntype]`` is iterated as a
        sequence of y-value sequences. ``xval_set[freq][ntype]`` must support
        slicing and scalar subtraction, presumably a numpy array (verify).
    """
    all_results={};all_xvals={}
    for i,presyn in enumerate(presyn_set):
        # numplots/xlabel/ylabel are overwritten on each iteration; the last presyn's values are used below
        numplots,results,xval_set,xlabel,ylabel=ISI_anal.freq_dependence(fileroot,presyn,suffix)    
        all_results[presyn]=results
        all_xvals[presyn]=xval_set
    fig,axes =plt.subplots(numplots, len(presyn_set),sharex=True, sharey=True)
    fig.suptitle(neurtype+suffix)
    axis=fig.axes  # flat, row-major list of the numplots x len(presyn_set) grid
    for i,presyn in enumerate(presyn_set):
        for k,freq in enumerate(sorted(all_results[presyn].keys())):
            for j,ntype in enumerate(all_results[presyn][freq].keys()):
                # NOTE(review): each presyn takes consecutive flat indices. In the row-major
                # grid, presyn i lands in column i only when len(presyn_set)==1 or there is a
                # single ntype; otherwise panels are not column-aligned. IndexError occurs if
                # the number of ntypes exceeds numplots. TODO: confirm the intended layout.
                axisnum=i*len(all_results[presyn][freq].keys())+j
                for yval in all_results[presyn][freq][ntype]:
                    # x-values truncated to this y sequence; each frequency shifted left by
                    # 0.02*k so overlapping points stay visible
                    axis[axisnum].scatter(all_xvals[presyn][freq][ntype][0:len(yval)]-k*0.02,yval,label=str(ntype)+str(freq),marker='.')
                axis[axisnum].set_ylabel(str(presyn)+' '+ylabel)
            axis[axisnum].legend()  # once per freq, on the last ntype's panel
        axis[axisnum].set_xlabel(xlabel)  # only on the last panel used for this presyn
    return all_results,all_xvals

####### Membrane potential  #############
def plot_freq_dep_vm(fileroot,presyn_set,plasYN,inj,neurtype):
    """Plot Vm traces at every available input frequency, one subplot per presynaptic input.

    For each ``presyn``, the files returned by ``ISI_anal.file_set`` are loaded with
    ``np.loadtxt`` in sorted file-name order. The pattern passed is
    ``fileroot + presyn + '*_plas' + str(plasYN) + '_inj' + inj + '*Vm.txt'``.

    Column 0 is plotted as time and column 1 as Vm, scaled by 1000 for display in
    mV. The legend label is the file-name text between the last ``'freq'`` and the
    next ``'_'``; files with the same label overwrite each other. Successive traces
    are shifted up by 2 mV each (the first by 2 mV) so they do not overlap.

    ``inj`` must already be a string because it is concatenated. ``neurtype`` is
    unused.
    """
    fig,axes=plt.subplots(len(presyn_set),1,sharex=True)
    fig.suptitle(' plasticity='+str(plasYN))
    axis=fig.axes
    for row,presyn in enumerate(presyn_set):
        panel=axis[row]
        files=ISI_anal.file_set(fileroot+presyn+'*_plas'+str(plasYN)+'_inj'+inj+'*Vm.txt')
        if len(files):
            traces={}
            for fname in sorted(files):
                data=np.loadtxt(fname,skiprows=0)
                traces[fname.split('freq')[-1].split('_')[0]]=(data[:,0],data[:,1])
            for n,(label,(tim,vm)) in enumerate(traces.items(),start=1):
                panel.plot(tim,1000*vm+2*n,label=label)  # stack traces 2 mV apart
        panel.set_ylabel(presyn+' Vm (mV)')
        panel.legend()
    axis[-1].set_xlabel('Time (sec)')

#################### Raster plot of pre-synaptic inputs 
def plot_input_raster(pre_spikes,pattern,maxplots=None):
    """Draw one raster figure per trial showing the spike trains of every presynaptic input.

    Parameters
    ----------
    pre_spikes : sequence
        Per-trial dicts ``{input_name: [cell0_spikes, cell1_spikes, ...]}``; each
        dict entry becomes one subplot row.
    pattern : str
        Path or pattern whose basename (up to the first '.') labels the figure titles.
    maxplots : int or None
        If truthy, plot at most this many trials; otherwise plot every trial.

    Cells within one input are coloured evenly across the 'viridis' colormap.
    """
    cmap=plt.get_cmap('viridis')  # alternative: plt.get_cmap('gist_heat')
    num_trials=len(pre_spikes)
    numplots=min(maxplots,num_trials) if maxplots else num_trials
    base=os.path.basename(pattern).split('.')[0]
    for trial in range(numplots):
        inputs=pre_spikes[trial]
        fig,axes=plt.subplots(len(inputs.keys()),1,sharex=True)
        fig.suptitle('input raster '+base+'_'+str(trial))
        axis=fig.axes
        for row,(key,spikes) in enumerate(inputs.items()):
            ncells=len(spikes)
            cell_colors=np.array([cmap(int(cell*(cmap.N/ncells))) for cell in range(ncells)])
            axis[row].eventplot(spikes,color=cell_colors)
            axis[row].set_ylabel(key)
        axis[-1].set_xlabel('time (s)')

def plot_sta_post_vm(pre_spikes,post_sta_dict,mean_sta_dict,post_xvals):
    """For each stimulus condition, plot post-synaptic Vm spike-triggered averages.

    Parameters
    ----------
    pre_spikes : dict
        Only used to size each figure: it gets ``len(pre_spikes[synstim][0].keys())`` rows.
    post_sta_dict : dict
        ``{synstim: {key: [sta_trial0, sta_trial1, ...]}}``; one figure per
        ``synstim`` and one row per ``key``. Traces are plotted against ``post_xvals``.
    mean_sta_dict : dict
        ``{synstim: {key: mean_trace}}``, overlaid as a thick black dashed line.
    post_xvals : array-like
        Shared x-values (time, s).
    """
    for synstim,sta_by_key in post_sta_dict.items():
        nrows=len(pre_spikes[synstim][0].keys())
        fig,axes=plt.subplots(nrows,1)
        fig.suptitle('post sta '+synstim)
        axis=fig.axes
        for row,(key,trial_stas) in enumerate(sta_by_key.items()):
            panel=axis[row]
            for trial,sta in enumerate(trial_stas):
                panel.plot(post_xvals,sta,label=str(trial))
                panel.set_ylabel(key+' trig')
            panel.plot(post_xvals,mean_sta_dict[synstim][key],'k--',lw=3)
        axis[-1].set_xlabel('time (s)')

def plot_sta_vm(pre_xvals,sta_list_dict,fileroot,suffix):
    """Plot per-trial Vm spike-triggered averages and their mean, one subplot per condition.

    Parameters
    ----------
    pre_xvals : array-like
        Shared x-values (time, s).
    sta_list_dict : dict
        ``{synstim: [sta_trial0, sta_trial1, ...]}``. Every trace is drawn, and
        the across-trial mean (``np.mean(..., axis=0)``) is overlaid as a black
        dashed line.
    fileroot, suffix : str
        The title uses ``basename(fileroot + suffix)`` up to its first '_'.

    Only the bottom subplot gets a legend.
    """
    fig,axes=plt.subplots(len(sta_list_dict),1,sharex=True)
    axis=fig.axes
    title_tag=os.path.basename(fileroot+suffix).split('_')[0]
    fig.suptitle('ep STA '+title_tag)
    for row,(synstim,trial_stas) in enumerate(sta_list_dict.items()):
        panel=axis[row]
        for trial,sta in enumerate(trial_stas):
            panel.plot(pre_xvals,sta,label='sta'+str(trial))
        panel.plot(pre_xvals,np.mean(trial_stas,axis=0),'k--',lw=2)
        panel.set_ylabel(synstim+' Vm (V)')
    axis[-1].set_xlabel('time (s)')
    axis[-1].legend()

def plot_sta_vm_cond(pre_xvals,sta_list_dict,mean_sta_vm):
    """Overlay each condition's mean Vm spike-triggered average, one subplot per stimulus.

    Parameters
    ----------
    pre_xvals : array-like
        Shared x-values (time, s).
    sta_list_dict : dict
        Only its number of keys is used, to set the number of rows.
    mean_sta_vm : dict
        ``{cond: {synfreq: mean_trace}}``. Rows are matched by position in each
        condition's dict, and each condition becomes one legend entry.
    """
    nrows=len(sta_list_dict.keys())
    fig,axes=plt.subplots(nrows,sharex=True)
    axis=fig.axes
    fig.suptitle('mean sta vm')
    for cond,sta_by_stim in mean_sta_vm.items():
        for row,synfreq in enumerate(sta_by_stim):
            axis[row].plot(pre_xvals,sta_by_stim[synfreq],label=cond)
            axis[row].set_ylabel(synfreq+' Vm (V)')
    axis[-1].set_xlabel('time (s)')
    axis[-1].legend()
    
def plot_prespike_sta(prespike_sta,mean_sta,pre_xvals,title=''):
    """Plot per-trial presynaptic spike-triggered averages plus their mean, one subplot per input.

    Parameters
    ----------
    prespike_sta : sequence
        Per-trial dicts ``{input_name: sta_trace}``. The first trial's keys set
        the number of rows, and rows are matched by position.
    mean_sta : dict
        ``{input_name: mean_trace}``, overlaid as thick black dashed lines in the
        same row order.
    pre_xvals : array-like
        Shared x-values (time, s).
    title : str
        Appended to the figure title.
    """
    fig,axes=plt.subplots(len(prespike_sta[0].keys()),1)
    fig.suptitle('prespike sta '+title)
    axis=fig.axes
    for trial,sta_by_input in enumerate(prespike_sta):
        for row,(key,sta) in enumerate(sta_by_input.items()):
            axis[row].plot(pre_xvals,sta,label=str(trial))
            axis[row].set_ylabel(key)
    for row,mean_trace in enumerate(mean_sta.values()):
        axis[row].plot(pre_xvals,mean_trace,'k--',lw=3)
    axis[-1].set_xlabel('time (s)')
    axis[-1].legend()

def plot_prespike_sta_cond(mean_prespike_sta,bins):
    """Grid of mean spike-triggered presynaptic firing, overlaid across conditions.

    Parameters
    ----------
    mean_prespike_sta : dict
        ``{cond: {synfreq: {key: sta_trace}}}``. Columns correspond to ``synfreq``
        (used as column titles) and rows to ``key`` (used as row labels). The grid
        size is taken from the first condition. Every condition is assumed to have
        the same ``synfreq``/``key`` entries in the same order, since they are
        matched by position, not name.
    bins : array-like
        Shared x-values (time, s) for every trace.

    Raises
    ------
    ValueError
        If ``mean_prespike_sta`` or its first condition is empty.
    """
    if not mean_prespike_sta:
        raise ValueError('plot_prespike_sta_cond: mean_prespike_sta is empty')
    first_cond=next(iter(mean_prespike_sta))
    if not mean_prespike_sta[first_cond]:
        raise ValueError('plot_prespike_sta_cond: first condition has no entries')
    first_synfreq=next(iter(mean_prespike_sta[first_cond]))
    nrows=len(mean_prespike_sta[first_cond][first_synfreq].keys())
    ncols=len(mean_prespike_sta[first_cond].keys())
    # squeeze=False keeps `axis` 2-D so axis[row,col] also works for 1-row/1-column grids
    fig,axis=plt.subplots(nrows,ncols,sharex=True,squeeze=False)
    fig.suptitle('mean spike triggered average pre-synaptic firing')
    for cond,sta_by_synfreq in mean_prespike_sta.items():
        for axy,(synfreq,sta_by_key) in enumerate(sta_by_synfreq.items()):
            for axx,(key,sta) in enumerate(sta_by_key.items()):
                axis[axx,axy].plot(bins,sta,label=cond)
                axis[axx,0].set_ylabel(key)
            axis[-1,axy].set_xlabel('time (s)')
            axis[0,axy].title.set_text(synfreq)  # column title
    axis[0,0].legend()
    
def plot_inst_firing(inst_rate,xbins,title=''):
    """Plot each trial's instantaneous presynaptic firing rate, one subplot per input.

    Parameters
    ----------
    inst_rate : sequence
        Per-trial dicts ``{input_name: rate_trace}``. The first trial's keys set
        the number of rows, and rows are matched by position.
    xbins : array-like
        Shared x-values (time, s).
    title : str
        Appended to the figure title.
    """
    fig,axes=plt.subplots(len(inst_rate[0].keys()),1)
    fig.suptitle('instantaneous pre-syn firing rate '+title)
    axis=fig.axes
    for trial,rate_by_input in enumerate(inst_rate):
        for ax,(key,frate) in enumerate(rate_by_input.items()):
            axis[ax].plot(xbins,frate,label=str(trial))
            axis[ax].set_ylabel(key)
    axis[-1].set_xlabel('time (s)')
    axis[-1].legend()

#Calculate and plot histograms
def plot_isi_hist(rootname,isi_set_dict,numbins,suffix,log_bins=False):
    """Plot ISI histograms, and print ISI statistics, for each stimulus condition.

    Parameters
    ----------
    rootname, suffix : str
        Figure title is ``'histogram ' + rootname + suffix``, with ``suffix`` cut
        at its first '.'.
    isi_set_dict : dict
        ``{synstim: {pre_post: nested_ISIs}}``. ``nested_ISIs`` is an iterable of
        per-trial ISI sequences, flattened with ``flatten``; each flattened series
        must be non-empty. The ``pre_post`` keys ``'stim'``, ``'pre'`` and
        ``'post'`` get dedicated line styles; any other key is drawn with ``'.-'``.
    numbins : int
        Number of bin *edges*, giving ``numbins - 1`` histogram bins. The edges are
        shared by all series of one ``synstim`` and span their combined min..max.
    log_bins : bool, optional
        If True, space the bin edges logarithmically, which requires all ISIs > 0.
        The default False uses linearly spaced edges.

    For every series this prints the ISI mean, std and coefficient of variation
    (CV = std / mean).
    """
    fig,axes =plt.subplots(len(isi_set_dict),1,sharex=True)
    axis=fig.axes
    fig.suptitle('histogram '+rootname+suffix.split('.')[0])
    symbol={'stim':'o-','pre':'.--','post':'.--'}
    for i,(synstim,isi_set) in enumerate(isi_set_dict.items()):
        # flatten each series once; it is reused for the range, the histogram and the stats
        flat_isis={pre_post:flatten(ISIs) for pre_post,ISIs in isi_set.items()}
        lo=np.min([np.min(v) for v in flat_isis.values()])
        hi=np.max([np.max(v) for v in flat_isis.values()])
        if log_bins:
            histbins=10 ** np.linspace(np.log10(lo), np.log10(hi), numbins)
        else:
            histbins=np.linspace(lo,hi, numbins)
        plot_bins=(histbins[:-1]+histbins[1:])/2  # bin centres, shared by all series
        for pre_post,isis in flat_isis.items():
            counts,_=np.histogram(isis,bins=histbins)
            axis[i].plot(plot_bins,counts,symbol.get(pre_post,'.-'), label=pre_post)
            isi_mean=np.mean(isis)
            isi_std=np.std(isis)
            cv=isi_std/isi_mean
            print(synstim,pre_post,': ISI mean, std=', isi_mean,isi_std,' CV=',cv)
        axis[i].set_ylabel(synstim+' events')
    axis[-1].legend()
    axis[-1].set_xlabel('ISI')

def flatten(isiarray):
    """Concatenate the items of every sub-sequence in ``isiarray`` into one flat list.

    Only one level is removed: ``[[1, 2], [3]]`` -> ``[1, 2, 3]``.
    """
    flat=[]
    for sublist in isiarray:
        flat.extend(sublist)
    return flat