import numpy as np
from matplotlib import pyplot as plt
# Side effect at import time: switch matplotlib to interactive mode, so figures
# created by the functions below are drawn as they are created.
plt.ion()
# Project helper; plot_raster uses it to collapse nested (per-trial) spike-train
# collections before plotting (presumably a one-level flatten - confirm in net_anal_class).
from net_anal_class import flatten

def colornum(list_index,whole_list,colormap):
    """Return the colormap index assigned to position ``list_index``.

    The ``colormap.N`` discrete levels are split into ``len(whole_list)``
    equal slots; the start of the slot for ``list_index`` is returned,
    truncated to an int.
    """
    slot_width = colormap.N / len(whole_list)
    return int(slot_width * list_index)

def plot_dict(measures,xaxis_vals,ylabel='',xlabel='Time (sec)',std_dict=None,ftitle='',trials=1):
    """Plot one line per entry of ``measures`` on a single new figure.

    Parameters
    ----------
    measures : dict
        Maps a trace label to its y values.
    xaxis_vals : sequence or dict
        Either x values shared by every trace, or a dict keyed like
        ``measures``.  When a per-key x sequence holds tuples, each tuple is
        treated as a (start, end) bin and plotted at its midpoint.
    ylabel, xlabel, ftitle : str
        Axis labels and figure title.
    std_dict : dict, optional
        Standard deviations keyed like ``measures``.  When non-empty, a band
        of +/- std/sqrt(trials) is shaded around every trace.
    trials : int
        Number of trials; std is divided by sqrt(trials).  The default of 1
        leaves the std unchanged.
    """
    if std_dict is None:
        std_dict={}
    colors=[plt.get_cmap('Greys'),plt.get_cmap('Blues'),plt.get_cmap('Reds'),plt.get_cmap('Purples'),plt.get_cmap('Oranges')]
    fig=plt.figure()
    fig.suptitle(ftitle)
    for i,(key,yvalues) in enumerate(measures.items()):
        # cycle through the colormaps when there are more traces than maps
        cmap=colors[i%len(colors)]
        if isinstance(xaxis_vals,dict):
            xvals=xaxis_vals[key]
            # guard the [0] lookup so an empty x sequence doesn't raise IndexError
            if len(xvals) and isinstance(xvals[0],tuple):
                xvals=[(x[0]+x[1])/2. for x in xvals]
        else:
            xvals=xaxis_vals
        plt.plot(xvals,yvalues,label=key,color=cmap(200))
        if std_dict:
            ste=np.array(std_dict[key])/np.sqrt(trials)
            yarr=np.array(yvalues)
            plt.fill_between(xvals,yarr+ste,yarr-ste,facecolor=cmap(100))
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.legend()

def choose_xvals(xarray,yvals,param1,param2,epoch=-1):
    """Return the x values to plot against ``yvals``.

    Lookup order:

    * ``xarray`` is a dict whose ``param1`` entry is itself a dict
      -> ``xarray[param1][param2]``
    * ``xarray`` is any other dict -> ``xarray[param1]``
    * ``xarray`` is any other non-None object -> ``xarray`` unchanged
    * ``xarray`` is None -> sample indices ``range(len(yvals))``.  When
      ``epoch`` >= 0 the indices are shifted by ``epoch * len(yvals)``, so
      consecutive epochs are laid out end to end along x.

    Parameters
    ----------
    xarray : None, dict, dict of dicts, or sequence
    yvals : sized
        Only its length is used, and only when ``xarray`` is None.
    param1, param2 : hashable
        First- and second-level keys into ``xarray``.
    epoch : int
        Epoch index used for the offset; -1 (default) means no offset.
    """
    if xarray is None:
        npts=len(yvals)
        # previously referenced an undefined name `i` here (NameError);
        # the epoch index is what callers pass for the offset
        start=epoch*npts if epoch>-1 else 0
        return range(start,start+npts)
    if isinstance(xarray,dict):
        entry=xarray[param1]
        # xarray may be a single dict or a dict of dicts
        return entry[param2] if isinstance(entry,dict) else entry
    return xarray

def plot_dict_of_dicts(mean_dict,xarray=None,ylabel='',xlabel='Time (sec)',std_dict=None,ftitle='',trials=1):
    """Plot a two-level dict with one subplot per top-level key.

    Parameters
    ----------
    mean_dict : dict
        param1 (e.g. neuron type) -> param2 (e.g. condition) -> y values.
        A third dict level is accepted only when it has a single key (one
        epoch).  Otherwise a message is printed and the function returns
        early; axes already drawn are kept.  If the y values are 2-D, only
        the first row is plotted.
    xarray : None, dict, dict of dicts, or sequence
        Passed to ``choose_xvals`` (None -> sample indices).
    ylabel, xlabel, ftitle : str
        Axis labels (the ylabel is prefixed with param1) and figure title.
    std_dict : dict, optional
        Same nesting as ``mean_dict``.  When non-empty, a band of
        +/- std/sqrt(trials) is shaded around each trace.
    trials : int
        Number of trials used to convert std into standard error.

    Each param2 gets a color from 'viridis'.  Its band uses the same color
    with every RGBA component (alpha included) divided by 4.
    """
    if std_dict is None:
        std_dict={}
    colormap=plt.get_cmap('viridis') #'plasma','inferno'
    fig,_=plt.subplots(len(mean_dict),1,sharex=True)
    fig.suptitle(ftitle)
    axis=fig.axes  # flat list of Axes, valid even for a single subplot
    for i,param1 in enumerate(mean_dict.keys()): #param1 is neurtype
        ax=axis[i]
        for k,param2 in enumerate(mean_dict[param1].keys()): #param2 is condition
            entry=mean_dict[param1][param2]
            sdvals=np.array([])
            #######################################################
            # Determine if there is a 3rd level of dictionary
            # if so, and the 3rd level has only 1 key - fine
            # else, cannot plot all the data with this function
            if isinstance(entry,dict):
                # std_dict may be empty (no error bars): don't index it just to print
                std_keys=std_dict[param1].keys() if param1 in std_dict else None
                print('3 level dictionary',param1,param2,'mean',mean_dict[param1].keys(), 'std',std_keys)
                if len(entry)>1:
                    print('plot_dict_of_dicts:PROBLEM - multiple epochs per condition per neuron, NO PLOTS',entry.keys(),' for ',ftitle)
                    return
                epoch=next(iter(entry)) #if 3 level dictionary, take just the one epoch
                yvals=np.array(entry[epoch])
                if std_dict:
                    sdvals=np.array(std_dict[param1][param2][epoch])
            else:
                yvals=np.array(entry)
                if yvals.ndim==2:
                    print(' ****** Only printing first item in array of size',np.shape(yvals))
                    yvals=yvals[0]
                if std_dict:
                    sdvals=np.array(std_dict[param1][param2])
                    if sdvals.ndim==2:
                        # keep the band aligned with the single row of yvals being plotted
                        sdvals=sdvals[0]
            xvals=choose_xvals(xarray,yvals,param1,param2)
            #######################################################
            main_color=colormap(colornum(k,mean_dict[param1],colormap))
            std_color=tuple(mc/4 for mc in main_color)
            ax.plot(xvals,yvals,color=main_color,label=str(param2))
            if sdvals.size:
                ste=sdvals/np.sqrt(trials)
                ax.fill_between(xvals,yvals+ste,yvals-ste,facecolor=std_color)
        ax.legend()
        ax.set_xlabel(xlabel)
        ax.set_ylabel(str(param1)+ ' '+ylabel)

#A. dict has N neuron keys, M condition keys - ABOVE WORKS: values are not dicts, so each list is plotted directly
#B. dict has N neuron keys, M condition keys (and 1 epoch key) - ABOVE WORKS: plots the list stored under that single epoch key
#C. dict has 1 neuron key, M condition keys - ABOVE WORKS
#D. dict has 1 neuron key, M condition keys, and P epoch keys - ABOVE prints a warning and makes no plots; send in only the 1 neuron's dictionary instead

######### Key difference from plot_dict_of_dicts: here the top key specifies the trace group (e.g. epoch)
######### and the second key specifies the condition (e.g. param); everything is drawn on a single axis
### This function could be avoided if lat_mean, etc. were accumulated differently
def plot_dict_of_epochs(mean_dict,xarray=None,ylabel='',xlabel='Time (sec)',std_dict=None,ftitle=''):
    """Plot a two-level dict (epoch -> condition -> y values) on one axis.

    Unlike ``plot_dict_of_dicts``, the top-level key does not select a
    subplot.  Every epoch is drawn on the same axis, and the epoch's index is
    passed to ``choose_xvals`` (``epoch=i``) so the x values can be offset
    per epoch.  Colors follow the condition index.  Only the first epoch's
    traces are labelled, so the legend lists each condition once.

    Parameters
    ----------
    mean_dict : dict
        epoch -> condition -> y values.
    xarray : None, dict, dict of dicts, or sequence
        Passed to ``choose_xvals``.
    ylabel, xlabel, ftitle : str
        Axis labels and figure title.
    std_dict : dict, optional
        Same nesting as ``mean_dict``.  When non-empty, a +/- std band is
        shaded around each trace (std is not divided by any trial count).
    """
    if std_dict is None:
        std_dict={}
    colormap=plt.get_cmap('viridis') #'plasma','inferno'
    fig,_=plt.subplots(1,1,sharex=True)
    fig.suptitle(ftitle)
    ax=fig.axes[0]
    for i,param1 in enumerate(mean_dict.keys()): #param1 is epoch
        for k,param2 in enumerate(mean_dict[param1].keys()): #param2 is condition
            yvals=np.array(mean_dict[param1][param2])
            # always (re)assigned: previously undefined when std_dict was empty (NameError)
            sdvals=np.array(std_dict[param1][param2]) if std_dict else np.array([])
            xvals=choose_xvals(xarray,yvals,param1,param2,epoch=i)
            main_color=colormap(colornum(k,mean_dict[param1],colormap))
            # every RGBA component (alpha included) divided by 4
            std_color=tuple(mc/4 for mc in main_color)
            label_kw={'label':str(param2)} if i==0 else {}
            ax.plot(xvals,yvals,color=main_color,**label_kw)
            if sdvals.size:
                ax.fill_between(xvals,yvals+sdvals,yvals-sdvals,facecolor=std_color)
    ax.legend()
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

def plot_raster(spikes,max_time,max_trains=np.inf,ftitle='',syntt=None):
    """Draw one raster (eventplot) panel per entry of ``spikes``.

    Parameters
    ----------
    spikes : dict
        Panel label -> collection of spike trains.  When ``np.shape`` reports
        a 2-D structure, the set is first passed through ``flatten``.
    max_time : float
        Right x limit of the shared x axis.  It is also used to place the
        stimulus annotations as a fraction of the axes width.
    max_trains : int or np.inf
        Maximum number of trains drawn per panel.
    ftitle : str
        Appended to the 'raster' figure title.
    syntt : dict, optional
        Keyed like ``spikes``.  Entries with 'xstart' and 'xend' get
        'stim onset' / 'offset' arrows below the panel.
    """
    #will plot input spikes from single trials, or output spikes
    if syntt is None:
        syntt={}
    colors=plt.get_cmap('viridis')
    fig,_=plt.subplots(len(spikes),1,sharex=True)
    fig.suptitle('raster \n'+ftitle)
    axis=fig.axes
    for ax,(key,spikeset) in zip(axis,spikes.items()):
        # NOTE(review): np.shape raises ValueError on ragged nested lists in
        # NumPy >= 1.24; confirm the input is never ragged here.
        if len(np.shape(spikeset))==2: #multiple neurons per trial, need to reshape!
            spikeset=flatten(spikeset)
        numtrains=min(max_trains,len(spikeset))
        # colors are spread over all trains in the set, even those not drawn
        color_set=np.array([colors(colornum(cellnum,spikeset,colors)) for cellnum in range(numtrains)])
        ax.eventplot(spikeset[0:numtrains],color=color_set)
        ax.set_ylabel(key)
        if key in syntt:
            xstart=syntt[key]['xstart']
            xend=syntt[key]['xend']
            ax.annotate('stim onset',xy=(xstart,0),xytext=(xstart/max_time, -0.2),
                        textcoords='axes fraction', arrowprops=dict(facecolor='black', shrink=0.05))
            ax.annotate('offset',xy=(xend,0),xytext=(xend/max_time, -0.2),
                        textcoords='axes fraction', arrowprops=dict(facecolor='red', shrink=0.05))
    axis[-1].set_xlim([0,max_time])
    axis[-1].set_xlabel('time (s)')

def fft_plot(alldata,maxfreq=500,phase=True,title='',mean_fft=False):
    """Plot FFT power (and optionally phase) for every epoch in ``alldata``.

    Parameters
    ----------
    alldata : object
        Must provide ``freqs`` (frequency per bin), ``fft_wave`` (dict
        epoch -> sequence of spectra, presumably complex) and, when
        ``mean_fft`` is True, ``fft_of_mean`` (dict epoch -> spectrum).
    maxfreq : float
        Bins from the first one above ``maxfreq`` onward are not plotted.
        If no bin exceeds it, plotting stops at the largest-frequency bin.
    phase : bool
        When True, a second subplot shows ``np.angle`` of each spectrum.
    title : str
        Prefix for the ' fft' figure title.
    mean_fft : bool
        When True, also plot ``fft_of_mean`` for each epoch.

    Power is plotted as |X|**2.  Each epoch uses its own colormap.  The
    legend holds one entry per epoch.
    """
    colors=[plt.get_cmap('Greys'),plt.get_cmap('Blues'),plt.get_cmap('Reds'),plt.get_cmap('Purples'),plt.get_cmap('Oranges')]
    # squeeze=False always returns a 2-D array of Axes.  With phase=False,
    # plt.subplots(1,1) would return a bare Axes and axes[0] would fail.
    fig,axes=plt.subplots(2 if phase else 1,1,squeeze=False)
    axes=axes[:,0]
    fig.suptitle(title+' fft')

    freqs=alldata.freqs
    above=np.flatnonzero(np.asarray(freqs)>maxfreq)
    # first bin above maxfreq; np.min on an empty selection used to crash
    maxfreq_pt=above[0] if above.size else int(np.argmax(freqs))
    minpt=1  # skip index 0 (presumably the DC bin)
    band=slice(minpt,maxfreq_pt)
    # largest magnitude (from minpt on) across all spectra; sets the power y-limit
    maxval=np.max([np.max(np.abs(f[minpt:])) for fft_set in alldata.fft_wave.values() for f in fft_set])
    for i,(epoch,spectra) in enumerate(alldata.fft_wave.items()):
        cmap=colors[i%len(colors)]
        for jj,ft in enumerate(spectra):
            color=cmap(colornum(jj,spectra,cmap))
            # label only the first spectrum per epoch so the legend has one entry per epoch
            label=epoch if jj==0 else '_nolegend_'
            axes[0].plot(freqs[band],np.abs(ft*ft)[band],label=label,color=color)
            if phase:
                axes[1].plot(freqs[band],np.angle(ft)[band],'.',label=label,color=color)
        if mean_fft:
            mean_ft=alldata.fft_of_mean[epoch]
            axes[0].plot(freqs[band],np.abs(mean_ft**2)[band],color=cmap(80))
            if phase:
                axes[1].plot(freqs[band],np.angle(mean_ft)[band],'.',color=cmap(80))
    if phase:
        axes[1].set_ylabel('FFT Phase')
    axes[0].set_ylabel('FFT Power')
    axes[0].set_ylim(0,np.round(maxval)**2 )
    axes[-1].set_xlim(0 , freqs[maxfreq_pt] )
    axes[-1].set_xlabel('Frequency in Hertz [Hz]')
    axes[0].legend()

#Triple dict of dicts of dicts
def plot_prespike_sta_cond(mean_prespike_sta,bins):
    """Plot spike-triggered averages of presynaptic firing in a grid of subplots.

    Parameters
    ----------
    mean_prespike_sta : dict
        Three levels: cond -> synfreq -> key -> STA values.  Each synfreq is
        a column (titled with synfreq).  Each innermost key is a row
        (labelled with key).  Every cond is overlaid as a separately
        labelled line.
    bins : sequence
        x values shared by every STA trace.

    Raises
    ------
    ValueError
        If ``mean_prespike_sta`` is empty.
    """
    if not mean_prespike_sta:
        raise ValueError('plot_prespike_sta_cond: mean_prespike_sta is empty')
    # Grid size is taken from the first condition and its first synfreq.
    # (Previously it used the not-yet-defined names cond/synfreq -> NameError.)
    # NOTE(review): assumes every condition shares this layout - confirm.
    first_cond=next(iter(mean_prespike_sta))
    first_freq=next(iter(mean_prespike_sta[first_cond]))
    nrows=len(mean_prespike_sta[first_cond][first_freq])
    ncols=len(mean_prespike_sta[first_cond])
    # squeeze=False keeps axis 2-D so axis[row, col] works for 1-row/1-col grids
    fig,axis=plt.subplots(nrows,ncols,sharex=True,squeeze=False)
    fig.suptitle('mean spike triggered average pre-synaptic firing')
    #need titles for each of the three columns
    for cond,by_freq in mean_prespike_sta.items():
        for axy,(synfreq,by_key) in enumerate(by_freq.items()):
            for axx,(key,sta) in enumerate(by_key.items()):
                axis[axx,axy].plot(bins,sta,label=cond)
                axis[axx,0].set_ylabel(key)
            axis[-1,axy].set_xlabel('time (s)')
            axis[0,axy].title.set_text(synfreq)
    axis[0,0].legend()

#Triple dict of dicts of dicts, with scale and offset, either scatter or line plot
def plot_freq_dep(data,xvals,ylabel,title,num_neurs,xlabel='Time (sec)',scale=1,offset=0):
    """Plot a three-level dict (presyn -> freq -> ntype -> values) in a grid.

    The grid has one row per presyn and ``num_neurs`` columns, one per ntype.
    Within a panel, each freq (in sorted order) gets a color from 'plasma'.
    Its values are multiplied by ``scale`` and shifted by ``k*offset``, where
    k is the freq's sorted index.  When ``xlabel`` is the default
    'Time (sec)' a line is drawn; otherwise a scatter.  If a value array is
    2-D, only its first row is plotted.

    Parameters
    ----------
    data : dict
        presyn -> freq -> ntype -> values.
    xvals : sequence or dict
        Shared x values, or presyn -> freq -> ntype -> x values.  x is
        truncated to the length of the plotted values.
    ylabel, title, xlabel : str
        Labels (the ylabel is prefixed with presyn) and figure title.
    num_neurs : int
        Number of grid columns (ntypes per presyn).
    scale, offset : float
        Vertical scaling and per-freq offset.
    """
    colormap=plt.get_cmap('plasma')#'inferno'
    fig,_=plt.subplots(len(data),num_neurs,sharex=True, sharey=True)
    fig.suptitle(title)
    axis=fig.axes  # row-major flat list: index = row*num_neurs + col
    legend_title=None
    x=xvals
    for i,presyn in enumerate(data.keys()):
        for k,freq in enumerate(sorted(data[presyn].keys())):
            main_color=colormap(colornum(k,data[presyn],colormap))
            for j,ntype in enumerate(data[presyn][freq].keys()):
                if legend_title is None:
                    # the first ntype seen is the one drawn on axis[0]
                    legend_title=str(ntype)
                if isinstance(xvals,dict):
                    x=xvals[presyn][freq][ntype]
                # index with the grid width, consistent with the xlabel loop below
                axisnum=i*num_neurs+j
                raw=data[presyn][freq][ntype]
                y=np.array(raw[0]) if np.ndim(raw)==2 else np.array(raw)
                if xlabel=='Time (sec)':
                    axis[axisnum].plot(x[0:len(y)],scale*y+k*offset,color=main_color,label=freq)
                else:
                    # NOTE(review): scatter subtracts the offset while line plots add it - confirm intentional
                    axis[axisnum].scatter(x[0:len(y)],scale*y-k*offset,marker='.',color=main_color,label=freq)
                axis[axisnum].set_ylabel(str(presyn)+' '+ylabel)
    for j in range(num_neurs):
        axis[(len(data)-1)*num_neurs+j].set_xlabel(xlabel)
    if legend_title is not None:
        axis[0].legend(title=legend_title)
 
def plot_cross_corr(mean_cc,mean_cc_shuffle,cc_shuffle_corrected,xbins):
    """Plot raw, shuffled and shuffle-corrected cross-correlograms.

    Three vertically stacked panels share the x axis.  For every key of
    ``mean_cc``, one line per panel is drawn against ``xbins`` from the
    corresponding dict.  The legend appears on the bottom panel only.
    """
    fig,panels=plt.subplots(3,1,sharex=True)
    fig.suptitle('cross correlograms ')
    series=(mean_cc,mean_cc_shuffle,cc_shuffle_corrected)
    for name in mean_cc.keys():
        for panel,source in zip(panels,series):
            panel.plot(xbins,source[name],label=name)
    ylabels=('mean cc','mean cc shuffled','mean cc shuffled-corrected')
    for panel,text in zip(panels,ylabels):
        panel.set_ylabel(text)
    panels[2].legend()