import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
import neo
import quantities
import elephant

# Matplotlib 3.6 renamed its bundled seaborn style sheets to 'seaborn-v0_8-*'
# and 3.8 removed the old names.  Use the new name when this installation
# provides it and fall back to the legacy name on older releases.
_PAPER_STYLE = ('seaborn-v0_8-paper'
                if 'seaborn-v0_8-paper' in plt.style.available
                else 'seaborn-paper')

# Module-wide figure defaults: applied once at import time, so they affect
# every figure created after this module is imported.
plt.style.use([_PAPER_STYLE,
                {'axes.spines.right':False,
                 'axes.spines.top': False,
                 'figure.constrained_layout.use': True,
                 'pdf.fonttype': 3,  # 42 was the alternative tried previously
                 'ps.fonttype': 3,
                 'savefig.dpi': 300,
                 'savefig.pad_inches': 0,
                 'figure.figsize':[8,8/1.333],  # previously [11.32,11.32/1.3333]
                 }])

def fractional_size(f,fractional_size,height=None):
    """Resize figure ``f`` relative to its current size in inches.

    Parameters
    ----------
    f : object exposing ``get_size_inches()`` / ``set_size_inches()``
        The figure to resize (modified in place; nothing is returned).
    fractional_size : scalar or per-dimension sequence
        Multiplier applied to the current size.
    height : scalar, optional
        When given, the new height is additionally multiplied by this value
        and the new size is passed on as a ``(width, height)`` tuple.
    """
    current = f.get_size_inches()
    if height is not None:
        scaled = (current[0]*fractional_size,
                  current[1]*fractional_size*height)
    else:
        scaled = current*fractional_size
    f.set_size_inches(scaled)

def new_fontsize(f,fs):
    """Apply font size ``fs`` to tick labels and axis labels of every axes in ``f``.

    Existing x/y label texts are kept and only re-set with the new size.
    Afterwards the x label of the first axes is always overwritten with
    'Time (s)'.  The figure must contain at least one axes.
    """
    panels = f.axes
    for panel in panels:
        for which in ('x', 'y'):
            panel.tick_params(axis=which, labelsize=fs)
        panel.set_ylabel(panel.get_ylabel(), fontsize=fs)
        panel.set_xlabel(panel.get_xlabel(), fontsize=fs)
    panels[0].set_xlabel('Time (s)', fontsize=fs)


def _raster_with_psth(raster_ax,psth_ax,spike_times,psth,raster_label):
    """Draw a spike raster on ``raster_ax`` and its PSTH on ``psth_ax``."""
    raster_ax.eventplot(spike_times,linewidths=1.5,linelengths=1.5)
    raster_ax.set_xlim(0,1.2)
    raster_ax.set_ylabel(raster_label)
    psth_ax.plot(psth.times,psth)
    psth_ax.set_ylabel('PSTH')


def input_plot(tt_Ctx_SPN,data,low,high):
    """Plot input spike trains, their PSTH and the SPN response.

    Three exploratory figures are created (low-variability raster,
    high-variability raster, and an overlay of the first and second trial of
    the high-variability input); two further figures are built and returned.

    Parameters
    ----------
    tt_Ctx_SPN : mapping
        ``tt_Ctx_SPN[key]['spikeTime']`` must yield one spike-time array per
        input neuron; the values are treated as seconds.
    data : mapping or structured array
        Must provide ``'time'`` and ``'/data/VmD1_0'``; the latter is
        multiplied by 1e3 and labelled in mV.
    low, high : hashable
        Keys of ``tt_Ctx_SPN`` selecting the low- and high-variability input.

    Returns
    -------
    f1 : Figure
        Raster + PSTH of the low-variability input.
    f2 : Figure
        Raster, PSTH and somatic Vm, with panel letters A-C.
    """
    low_spikes = tt_Ctx_SPN[low]['spikeTime']
    high_spikes = tt_Ctx_SPN[high]['spikeTime']

    # PSTH with 10 ms bins; only the low-variability PSTH is ever plotted, so
    # the high-variability one is no longer computed.
    trains = [neo.SpikeTrain(train*quantities.s, t_start=-1,t_stop=22) for train in low_spikes]
    psth = elephant.statistics.time_histogram(trains,quantities.s*.01)
    #1st trial is identical regardless of variability.  Use data['low'] to show SPN output

    ###### Plot of low and high variability
    for spikes,label in ((low_spikes,'Low Variability'),(high_spikes,'High Variability')):
        plt.figure()
        plt.eventplot(spikes)
        plt.xlim(0,4)
        plt.title(label)

    ##### plot of 1st and 2nd trial for high variability, to visualize movement of spikes
    plt.figure()
    trial_1 = [tt[tt<1.5] for tt in high_spikes]
    # Spikes between 1.5 and 3.5 are shifted back by 2 so they overlay trial 1
    # (presumably the trial period is 2 s - confirm against the stimulus protocol).
    trial_2 = [tt[(tt>1.5)&(tt<3.5)]-2 for tt in high_spikes]
    plt.eventplot(trial_1,label='Trial 1')
    plt.eventplot(trial_2,colors=plt.cm.tab10(1),alpha=.7,label='Trial 2')
    plt.title('Initial Trial vs. Second Trial for High Variability')
    plt.xlim(.5,.6)
    plt.ylim(125,150)

    ## Figure for single trial input
    f1,ax = plt.subplots(2,1,sharex=True)
    _raster_with_psth(ax[0],ax[1],low_spikes,psth,'Neuron #')
    ax[1].set_xlabel('Time (s)')
    sns.despine()
    fractional_size(f1,.75)

    ####### Figure for input (spike train & PSTH) and output (SPN response) of initial trial
    f2,ax = plt.subplots(3,1,sharex=True,constrained_layout=True)
    _raster_with_psth(ax[0],ax[1],low_spikes,psth,'Input Neuron #')
    fractional_size(f2,.75,2)
    a=ax[2]
    # Plot on the axes explicitly rather than through pyplot's implicit
    # "current axes", which only happened to be this panel.
    a.plot(data['time'],data['/data/VmD1_0']*1e3)
    a.set_xlabel('Time (s)')
    a.set_ylabel('SPN Soma Vm (mV)')
    f2.align_ylabels()
    a.set_xticks([0,.2,.4,.6,.8,1.0,1.2])
    a.set_yticks([-90,-60,-30,0,30])
    sns.despine(trim=True,offset=1)
    for panel,letter in zip(ax,['A','B','C']):
        panel.text(-.175,1,letter,transform=panel.transAxes,fontweight='bold')
    return f1,f2

def plot_spine_calcium_and_weight(data,spineP,spineD='',stoptime=2,xmax=1.5,camax=3,wtmax=1.5):
    """Plot spine calcium with the synaptic weight overlaid on a twin y-axis.

    One panel is drawn for ``spineP`` (titled 'Potentiation'); when ``spineD``
    is non-empty a second panel (titled 'Depression') is added for it.

    Parameters
    ----------
    data : mapping or structured array
        Must provide ``'time'``, the weight traces named by ``spineP`` /
        ``spineD`` and the calcium traces whose names are derived from them.
    spineP, spineD : str
        Names of the weight traces.  The calcium trace name is built from the
        character preceding 'headplas' and from the text between
        '/data/D1-extern1_to_' and '-sp'.
    stoptime : number
        Only samples with index below ``int(stoptime/dt)`` are plotted.
    xmax, camax, wtmax : number
        Upper x limit, upper calcium limit and upper weight limit.

    Returns
    -------
    The created figure.  Note that ``plt.show()`` is called before returning.
    """
    # Sample interval, taken from the first two time points.
    dt=np.diff(data['time'][0:2])[0]
    muM = u'(μM)'
    if len(spineD):
        f,ax = plt.subplots(1,2,sharex=True,sharey=True)
    else:
        f,ax = plt.subplots(1,squeeze=True)
        # Wrap the single axes so the rest of the code can index ax[0].
        ax=[ax]
    #x = data['low']['time'][0:int(stoptime/dt)]
    x = data['time'][0:int(stoptime/dt)]
    
    # Derive the calcium-shell trace name that belongs to the potentiated spine.
    pot_ex_ca = '/data/D1_sp{}{}Shell_0'.format(spineP[spineP.find('headplas')-1], spineP.split('/data/D1-extern1_to_')[-1].split('-sp')[0]) 
    # Calcium is scaled by 1e3 before plotting (axis label says μM).
    ax[0].plot(x, data[pot_ex_ca][0:int(stoptime/dt)]*1e3)
    ax[0].set_ylabel(r'Spine Calcium '+muM,color=plt.cm.tab10(0),fontweight='bold')
    [a.set_xlabel('Time (s)') for a in ax]
    ax[0].set_title('Potentiation')
    ax[0].set_ylim(0,camax)
    ax[0].set_xlim(0,xmax)
    if len(spineD):
        dep_ex_ca = '/data/D1_sp{}{}Shell_0'.format(spineD[spineD.find('headplas')-1], spineD.split('/data/D1-extern1_to_')[-1].split('-sp')[0]) 
        ax[1].plot(x, data[dep_ex_ca][0:int(stoptime/dt)]*1e3)
        ax[1].set_title('Depression')
    #second value is the change in threshold. Ideally specify as parameters
    pot_amp_thresh = .46*1.158
    dep_amp_thresh = 0.2*1.656
    # NOTE(review): the two duration thresholds below are never used in this function.
    pot_dur_thresh = .002*1.653
    dep_dur_thresh = .032*.867

    # Both amplitude thresholds are drawn as horizontal lines on every panel.
    for a in ax:
        a.axhline(y=pot_amp_thresh,linestyle='--',c=plt.cm.Blues(.4),zorder=-1)
        a.axhline(y=dep_amp_thresh,linestyle='-.',c=plt.cm.Blues(.4),zorder=-1)
    ax[0].annotate('Potentiation\nThreshold',(xmax,pot_amp_thresh),(-25,25),textcoords='offset points',arrowprops={'arrowstyle':'simple'})

    # Twin axis of the first panel carries the synaptic weight trace.
    ax0twin = ax[0].twinx()

    ax0twin.plot(x,data[spineP][0:int(stoptime/dt)],c=plt.cm.tab10(2))
    ax0twin.spines['right'].set_visible(True)
    ax0twin.spines['right'].set_color(plt.cm.tab10(2))
    ax0twin.spines['left'].set_visible(False) 
    ax0twin.tick_params(left=False,labelleft=False)
    ax[0].tick_params('y',color=plt.cm.tab10(0),labelcolor=plt.cm.tab10(0))
    ax0twin.spines['left'].set_color(plt.cm.tab10(0))
    # NOTE(review): this hides the weight ticks of the first twin axis in the
    # single-panel case as well - confirm that is intended.
    ax0twin.tick_params(right=False,labelright=False) 
    fractional_size(f,[1,.75])
    if len(spineD):
        # Second panel: weight trace of the depressed spine on its own twin axis.
        ax1twin = ax[1].twinx()
        ax1twin.plot(x,data[spineD][0:int(stoptime/dt)],c=plt.cm.tab10(2))
        ax1twin.spines['right'].set_visible(True) 
        ax1twin.spines['right'].set_color(plt.cm.tab10(2)) 
        ax[1].spines['left'].set_visible(False)
        ax1twin.set_ylabel('Synaptic Weight',color=plt.cm.tab10(2),fontweight='bold')
        ax[1].tick_params(left=False)
        ax1twin.spines['left'].set_visible(False)  
        # With two panels only the right-most twin axis keeps its right spine.
        ax0twin.spines['right'].set_visible(False)  
        ax1twin.tick_params('y',color=plt.cm.tab10(2),labelcolor=plt.cm.tab10(2))
        ax[1].annotate('Depression\nThreshold',(0,dep_amp_thresh),(-25,25),textcoords='offset points',arrowprops={'arrowstyle':'simple'})
        twinaxes=[ax0twin,ax1twin]
        ax[0].text(-.25,1.1,'A',transform=ax[0].transAxes,fontweight='bold') #do these after fractional_size, else the coordinates specified are off the graph
        ax[1].text(-.2,1.1,'B',transform=ax[1].transAxes,fontweight='bold')
    else:
        twinaxes=[ax0twin]
        ax0twin.set_ylabel('Synaptic Weight',color=plt.cm.tab10(2),fontweight='bold')
    # All weight axes share the same y range.
    for tw in twinaxes:
        tw.set_ylim(0.5,wtmax)
    # Displays all open figures; may block depending on the backend/interactive mode.
    plt.show()
    return f

################### Scatter plot and histogram of weight change at end of trials - unused figure
def _weight_change_bins(binmin,binmax,binsize):
    """Return histogram edges with a narrow 'no change' bin around zero.

    Edges run from ``binmin`` upwards in steps of ``binsize`` on the negative
    side, then ``-binsize/10`` and ``+binsize/10``, then from ``binsize``
    upwards in steps of ``binsize`` on the positive side.
    """
    # A ratio such as (0.3-0.1)/0.1 evaluates to 1.9999999999999998; plain
    # int() would truncate it to 1 and silently drop an edge.
    tol = 1e-9
    n_neg = int((-binmin-binsize)/binsize + tol) + 1
    n_pos = int((binmax-binsize)/binsize + tol) + 1
    negative = [binmin+i*binsize for i in range(n_neg)]
    positive = [binsize+i*binsize for i in range(n_pos)]
    return negative + [-binsize/10,binsize/10] + positive


def weight_change_plots(weight_change_event_df,binmin=-0.2,binmax=0.2,binsize=0.01):
    """Draw a per-trial scatter plot and a histogram of weight-change events.

    Parameters
    ----------
    weight_change_event_df : DataFrame
        Must contain the columns ``'time'``, ``'weightchange'`` and ``'trial'``.
    binmin, binmax, binsize : float
        Range and step of the histogram edges; a narrow bin of width
        ``binsize/5`` centred on zero separates negative from positive changes.

    Nothing is returned; the figures are created through seaborn / pandas.
    """
    #scatter plot of weight change at each trial - one panel for each variability type
    sns.catplot(x='time',y='weightchange',data=weight_change_event_df,col='trial')

    #histogram of weight change events.
    bins=_weight_change_bins(binmin,binmax,binsize)
    weight_change_event_df.hist('weightchange',bins=bins)

def weight_histogram(data):
    """Histogram of all synaptic weights at the sample closest to time 2.

    ``data`` must provide ``'time'`` and expose ``dtype.names``; every field
    whose name contains 'plas' is treated as a weight trace.  The count axis
    is logarithmic.  Returns the figure.
    """
    fig,axis = plt.subplots(1,1)
    # Index of the time sample nearest to 2.
    nearest = np.abs(data['time'] - 2).argmin()
    final_weights = []
    for name in data.dtype.names:
        if 'plas' in name:
            final_weights.append(data[name][nearest])
    axis.hist(final_weights,bins=20)
    axis.set_yscale('log')
    axis.set_xlabel('Synaptic weight')
    axis.set_ylabel('Number of synapses (log scale)')
    fractional_size(fig,.75)
    axis.set_xlim(.59,1.3)
    sns.despine(trim=True)
    return fig

def nearby_synapse_image(all_mean,binned_means,tmax=1.0,mean_sub=True,title=''):
    """Draw one heat map per weight bin of the neighbours' firing rate.

    Parameters
    ----------
    all_mean : 2-D array
        Subtracted from each bin's array when ``mean_sub`` is true.
    binned_means : dict
        Maps a ``(low, high)`` weight-bin tuple to a 2-D array.
    tmax : float
        End of the time axis used for the x tick labels.
    mean_sub : bool
        Selects mean-subtracted data (colour range -25..25) or raw data
        (colour range 0..50).
    title : truthy/falsy
        When truthy a fixed super-title is added; its own text is not used.

    Returns
    -------
    Only the figure created for the LAST bin (one figure is created per bin).
    """
    from scipy import ndimage
    for weight_bin,bin_mean in binned_means.items():
        if mean_sub:
            image = bin_mean[0:,:]-all_mean[0:,:]
            pre_title,vmin,vmax = 'Mean-subtracted a',-25,25
        else:
            image = bin_mean[0:,:]
            pre_title,vmin,vmax = 'A',0,50
        f,ax = plt.subplots(1,1,constrained_layout=True,figsize=(4,3))

        # Sigma is zero on both axes, so the filter currently leaves the data unchanged.
        smoothed = ndimage.gaussian_filter(image,[0,0])
        pc = ax.pcolormesh(smoothed,vmin=vmin,vmax=vmax,cmap='seismic')
        ax.set_yticks([0.5,9.5,18.5])
        ax.set_yticklabels([0,9,18])
        ax.set_ylim(0,20)
        ax.set_xticklabels(np.linspace(0,tmax,5))
        if tmax==1.0:
            ax.set_xticks([0,25,50,75,100])
        else:
            ax.set_xticks(np.arange(0,np.shape(image)[1],5))
        ax.set_xlabel('Time (s)',fontsize=12)
        bin_text = 'weight bin='+str(round(weight_bin[0],3))+' to '+str(round(weight_bin[1],3))
        ax.set_title(bin_text,fontsize=12)
        print('nearby synapse, min: ',np.min(smoothed), ', max: ',np.max(smoothed))

        ax.set_ylabel('Nearest neighboring synapses',fontsize=12)
        cbar = f.colorbar(pc)
        cbar.ax.set_ylabel('Firing rate (Hz)')
        if title:
            f.suptitle(pre_title+'vg firing of neighbors',fontsize=19)
    return f

def nearby_synapse_1image(all_mean,binned_means,mean_sub=True,title=False):
    """Draw all non-central weight bins as heat maps in one 3x2 figure.

    Bins lying entirely below zero go in the left column, bins entirely above
    zero in the right column; a bin that straddles zero (the middle bin) is
    skipped.  Rows are filled in iteration order, wrapping every three plots.

    Parameters
    ----------
    all_mean : 2-D array
        Subtracted from each bin's array when ``mean_sub`` is true.
    binned_means : dict
        Maps a ``(low, high)`` weight-bin tuple to a 2-D array.
    mean_sub : bool
        Selects mean-subtracted data (colour range -25..25) or raw data
        (colour range 0..50).
    title : bool
        When true a super-title is added.

    Returns
    -------
    The figure.
    """
    from scipy import ndimage
    f,axes=plt.subplots(3,2,constrained_layout=True,figsize=(12,12))
    if mean_sub:
        pre_title,vmin,vmax='Mean-subtracted a',-25,25
    else:
        pre_title,vmin,vmax='A',0,50
    pc=None
    j=0
    for k,v in binned_means.items():
        ar = v[0:,:]-all_mean[0:,:] if mean_sub else v[0:,:]
        # The original compared ``k[1] < [0]`` (number vs. list), which is a
        # typo for ``k[1] < 0`` and raises TypeError for plain Python floats.
        if k[0]>0 or k[1]<0: #skip middle bin
            col = 0 if k[0]<0 else 1
            ax=axes[j%3,col]
            # Sigma is zero on both axes, so the filter leaves the data unchanged.
            pc=ax.pcolormesh(ndimage.gaussian_filter(ar,[0,0]),vmin=vmin,vmax=vmax,cmap='seismic')
            ax.set_yticks([0.5,6.5,12.5,18.5])
            ax.set_yticklabels([0,6,12,18])
            ax.tick_params(labelsize=12)
            ax.set_ylim(0,20)
            ax.set_xticks([0,25,50,75,100])
            ax.set_xticklabels(np.linspace(0,1,5))
            ax.set_xlabel('Time (s)',fontsize=12)
            ax.set_title('weight bin='+str(round(k[0],3))+' to '+str(round(k[1],3)),fontsize=12)
            ax.set_ylabel('Neighboring synapses',fontsize=12)
            if title:
                f.suptitle(pre_title+'verage firing rate of neighboring synapses',fontsize=14)
            j+=1
        else:
            print('skipping',k)
    # The colour bar is tied to the last mesh drawn; skip it when every bin
    # was skipped instead of failing on an unbound name.
    if pc is not None:
        cbar=f.colorbar(pc)
        cbar.ax.set_ylabel('Firing rate (Hz)')
    return f

def weight_vs_variability(data,df,titles,keys,sigma=None):
    """Plot weight traces per variability level and summarise the final weights.

    Figure 1 has one panel of weight-versus-time traces per entry of ``keys``
    (the layout provides four such panels) plus a fifth panel connecting each
    spine's final weight across trials.  Figure 2 shows box plots of the final
    weight for depressed (< 0.99) and potentiated (> 1.01) synapses.  Pearson
    correlations between final weight and variability are printed.

    Parameters
    ----------
    data : mapping
        ``data[k]`` must provide ``'time'`` and expose ``dtype.names``; fields
        whose name contains 'plas' are plotted as weight traces.
    df : DataFrame
        Must contain the columns ``'spine'``, ``'trial'`` and ``'endweight'``.
        When ``titles[0]`` contains 'sigma', a ``'sigma'`` column is ADDED to
        this frame in place.
    titles : sequence of str
        Panel titles, one per key.
    keys : sequence
        Trial identifiers, used both as keys of ``data`` and as values of
        ``df.trial``.
    sigma : dict, optional
        Maps trial identifier to its sigma value; only used when
        ``titles[0]`` contains 'sigma'.

    Returns
    -------
    f1, f2 : the two figures.
    """
    from scipy.stats import pearsonr
    # Avoid a shared mutable default argument.
    sigma = {} if sigma is None else sigma

    f1,a = plt.subplots(1,5,gridspec_kw={'width_ratios':[3,3,3,3,2]}, sharey=True)
    axes=[a[0],a[1],a[2],a[3]]
    colors=sns.color_palette('colorblind')[0:4]

    for i,k in enumerate(keys):
        plas_names=[nm for nm in data[k].dtype.names if 'plas' in nm]
        for n in plas_names:
            # Every 100th sample is plotted to keep the figure light.
            axes[i].plot(data[k]['time'][::100],data[k][n][::100],c=colors[i],linewidth=1,alpha=.7)
        axes[i].set_title(titles[i])
        axes[i].set_ylim(0,2)
        axes[i].set_xlabel('Time (s)')
    axes[0].set_ylabel('Synaptic Weight')

    ax=a[4]
    cs=colors
    for spine in df['spine'].drop_duplicates():
        line = [ df.loc[ (df.spine == spine) & (df.trial == sv) ]['endweight'].iat[0] for sv in keys]
        # Only spines whose weight changed by at least 0.25 in some trial are drawn.
        # NOTE(review): an earlier comment quoted 0.025 - confirm which threshold is intended.
        if np.any(abs(np.array(line)-1)>=0.25):
            # One random horizontal jitter per spine keeps overlapping lines apart.
            x = np.arange(len(line))+np.random.uniform(-.1,.1)
            ax.plot(x,line,color=plt.cm.gray(.8),linewidth=.5)
            ax.scatter(x,line,marker='o',c=cs,zorder=10,s=30,edgecolor='white',alpha=.75)
    ax.axhline(1.0,color='grey',linestyle='--',linewidth=1.5)
    ax.set_xticks([0,3])
    ax.set_xticklabels(['Low','High'])
    ax.set_xlim(-1,4)
    ax.set_title('Final\nWeight')

    ########## Figure 4 bottom panels in manuscript
    f2,ax = plt.subplots(1,2,constrained_layout=True,sharey=False)
    sns.boxplot(x='trial',y='endweight',data=df[df['endweight']<.99],ax=ax[0],palette='colorblind')
    sns.boxplot(x='trial',y='endweight',data=df[df['endweight']>1.01],ax=ax[1],palette='colorblind')

    if 'sigma' in titles[0]:
        df['sigma']=df.trial.map(sigma)
        group_var='sigma'
    else:
        group_var='trial'
    # Select the column BEFORE averaging: averaging the whole frame fails on
    # non-numeric columns (e.g. spine names) with pandas >= 2.0.
    ltd=df[df['endweight']<.99]
    ltp=df[df['endweight']>1.01]
    means=ltd.groupby(group_var)['endweight'].mean()
    print('LTD, correlation of endweight with variability=',pearsonr(means.index,means.values))
    print('LTD, N=',len(ltd),'R,p=', pearsonr(ltd.endweight,ltd[group_var]))
    means=ltp.groupby(group_var)['endweight'].mean()
    print('LTP, N=4, correlation of endweight with variability=',pearsonr(means.index,means.values))
    print('LTP, N=',len(ltp),'R,p=', pearsonr(ltp.endweight,ltp[group_var]))

    # Could fit depression to ending weight vs. sigma; might be significant
    # Specify a consistent plasticity criterion: i.e. 1% change in synaptic weight
    for a in ax:
        if 'sigma' in titles[0]:
            # Raw string: '\s' is not a valid escape sequence.
            a.set_xlabel(r'$\sigma $ (ms)')
        elif 'move' in titles[0]:
            a.set_xlabel('P(move), %')
        a.set_ylabel('Final Weight')
    ax[0].set_title('Depression')
    ax[1].set_title('Potentiation')
    return f1,f2

######################## Weight Change Triggered Average Pre-Syn firing #####################
def colorbar7(cs,binned_weight_change_index,ax,fontsize=10):
    """Attach a discrete 'Weight Change' colour bar next to ``ax``.

    Parameters
    ----------
    cs : sequence of colours
        One colour per weight-change bin.
    binned_weight_change_index : sequence of float
        Bin boundaries; also used as the initial tick positions.  The tick
        clean-up below indexes positions 0-7, so at least eight boundaries
        (seven bins) are required.
    ax : Axes
        Axes the colour bar is attached to.
    fontsize : number
        Tick-label size of the colour bar.
    """
    cmap = matplotlib.colors.LinearSegmentedColormap.from_list('Custom cmap', cs,N=len(cs))
    norm = matplotlib.colors.BoundaryNorm(binned_weight_change_index, len(cs))
    cbar = plt.colorbar(plt.cm.ScalarMappable(norm=norm,cmap=cmap),
                  ax=ax,label='Weight Change',spacing='proportional',ticks=binned_weight_change_index,format='%.2f',aspect=40)

    # Drop the tick at index 4 and move the tick at index 3 to zero, so the
    # central bin carries a single label.
    updateticks = cbar.ax.get_yticks()
    updateticks = updateticks[[0,1,2,3,5,6,7]]
    updateticks[3]=0
    cbar.set_ticks(updateticks)
    labels = cbar.ax.get_yticklabels()
    # Raw string: '\p' is not a valid escape sequence in a normal literal.
    labels[3]=r'$\pm0.01$'
    cbar.set_ticklabels(labels)
    cbar.ax.tick_params(labelsize=fontsize)
    
#### Combined Figure
def combined_figure(binned_pre,binned_calcium,binned_weight_change_index,duration,std=None):
    """Stack binned presynaptic firing rate and calcium traces in one figure.

    Parameters
    ----------
    binned_pre, binned_calcium : dict
        Map a weight-change bin to a 1-D trace, or to a 2-D array whose first
        row is plotted.
    binned_weight_change_index : sequence of float
        Bin boundaries handed to ``colorbar7``.
    duration : float
        The x axis runs from 0 to ``duration``.
    std : sequence of two dicts, optional
        ``std[0]`` / ``std[1]`` hold, per bin key, the spread drawn as a shaded
        band around the presynaptic / calcium traces.

    Returns
    -------
    fig, cs : the figure and the list of per-bin colours.
    """
    fig=plt.figure(constrained_layout=True,figsize=(6,8))
    grid=fig.add_gridspec(4,40)
    ax_pre=fig.add_subplot(grid[0:2,0:-1])
    ax_cbar=fig.add_subplot(grid[:,-1])
    ax_cal=fig.add_subplot(grid[2:4,0:-1])
    ax_cbar.spines['left'].set_visible(False)
    ax_cbar.spines['bottom'].set_visible(False)
    ax_cbar.tick_params(left=False,labelleft=False)
    ax_cbar.tick_params(bottom=False,labelbottom=False)
    # coolwarm colours for all bins but one; grey is inserted for the middle bin.
    cs = list(plt.cm.coolwarm(np.linspace(0,1,len(binned_pre)-1)))
    cs.insert(len(binned_pre)//2,plt.cm.gray(0.5))
    panels=[ax_pre,ax_cal]
    traces=[binned_pre,binned_calcium]
    ylabels=['Presynaptic Firing Rate (Hz)','Calcium Concentration (mM)']
    for ax, binned_means, ylbl in zip(panels,traces,ylabels):
        for i,(k,v) in enumerate(binned_means.items()):
            x = np.linspace(0,duration,np.shape(v)[-1])
            if len(np.shape(v))==2:
                ax.plot(x,v[0,:],c=cs[i])
            elif len(np.shape(v))==1:
                ax.plot(x,v,c=cs[i])
            else:
                print('binned means has too many dimensions')
        # Label every panel here.  This call used to sit inside the ``std``
        # branch below, so the first panel was never labelled and the second
        # only when ``std`` was supplied.
        ax.set_ylabel(ylbl,fontsize=12)
    if std is not None:
        for ax,binned_means, stdy in zip(panels,traces, std):
            for i,(k,v) in enumerate(binned_means.items()):
                x = np.linspace(0,duration,np.shape(v)[-1])
                if len(np.shape(v))==2:
                    ax.fill_between(x, v[0,:]-stdy[k][0,:], v[0,:]+stdy[k][0,:],alpha=0.2,facecolor=cs[i])
                elif len(np.shape(v))==1:
                    ax.fill_between(x, v-stdy[k], v+stdy[k],alpha=0.2,facecolor=cs[i])
    ax_cal.set_xlabel('Time (s)',fontsize=12)
    colorbar7(cs,binned_weight_change_index,ax_cbar,fontsize=12)
    return fig,cs

def combined_spatial(combined_neighbors_array,binned_weight_change_dict,binned_weight_change_index,all_cor,cs):
    """Plot neighbours' mean firing rate and a correlation histogram per weight bin.

    Parameters
    ----------
    combined_neighbors_array : 2-D array
        Indexed as ``[:, v]`` with the index set of each bin and averaged
        over the selected columns.
    binned_weight_change_dict : dict
        Maps a bin key to the indices belonging to that bin.
    binned_weight_change_index : sequence of float
        Bin boundaries handed to ``colorbar7``.
    all_cor : array
        Indexed with the same index sets; the values are histogrammed over
        the range -1..1.
    cs : sequence of colours
        One colour per bin, in the dict's iteration order.

    Returns
    -------
    The figure.
    """
    fig = plt.figure(constrained_layout=True,figsize=(6,8))
    layout = fig.add_gridspec(4,40)
    ax_neighbor = fig.add_subplot(layout[0:2,0:-1])
    ax_cbar = fig.add_subplot(layout[:,-1])
    ax_hist = fig.add_subplot(layout[2:4,0:-1])
    # Strip the decorations of the narrow column reserved for the colour bar.
    for side in ('left','bottom'):
        ax_cbar.spines[side].set_visible(False)
    ax_cbar.tick_params(left=False,labelleft=False)
    ax_cbar.tick_params(bottom=False,labelbottom=False)

    time_axis = np.linspace(0,1,len(combined_neighbors_array))
    for color,(bin_key,members) in zip(cs,binned_weight_change_dict.items()):
        bin_mean = np.mean(combined_neighbors_array[:,members],axis=1)
        ax_neighbor.plot(time_axis,bin_mean,c = color)
        ax_hist.hist(all_cor[members],histtype='step',bins=21,range=(-1,1),density=True,label=bin_key,linewidth=2,color=color,alpha=.8)
    ax_neighbor.set_ylabel('Combined Firing\n Rate of Neighbors (Hz)')
    ax_neighbor.set_xlabel('Time (s)')

    ax_hist.set_xlim(-1,1)
    ax_hist.set_xlabel('Correlation between Direct and Neighbors Input')
    ax_hist.set_ylabel('Normalized Histogram')
    colorbar7(cs,binned_weight_change_index,ax_cbar)
    return fig
    
#### Use 7 bins
def weight_change_trig_avg(binned_means,binned_weight_change_index,duration,std=None,cs=None,colorbar=False,title='',ylabel='Instantaneous Presynaptic Firing Rate (Hz)'):
    """Plot one trace per weight-change bin against time.

    Parameters
    ----------
    binned_means : dict
        Maps a bin key to a 1-D trace, or to a 2-D array whose first row is
        plotted.  Tuple keys are rendered as 'low to high' in the legend.
    binned_weight_change_index : sequence of float
        Bin boundaries, only used when ``colorbar`` is true.
    duration : float
        The x axis runs from 0 to ``duration``.
    std : unused
        Accepted for interface compatibility; no spread is drawn.
    cs : sequence of colours, optional
        One colour per bin.  When missing or empty, a coolwarm palette with
        grey for the middle bin is generated.
    colorbar : bool
        Draw the discrete colour bar instead of a legend.
    title : str
        Figure super-title; a non-empty title also selects a smaller
        colour-bar font.
    ylabel : str
        Label of the y axis.

    Returns
    -------
    fig, cs : the figure and the colours that were used.
    """
    fig,ax = plt.subplots(1,1,constrained_layout=False)
    fig.suptitle(title)
    fontsize = 10 if len(title) else 14
    # ``not cs`` raises for a multi-element ndarray of colours, so test for
    # "missing or empty" explicitly.
    if cs is None or len(cs)==0:
        cs = list(plt.cm.coolwarm(np.linspace(0,1,len(binned_means)-1)))
        cs.insert(len(binned_means)//2,plt.cm.gray(0.5))
    for i,(k,v) in enumerate(binned_means.items()):
        x = np.linspace(0,duration,np.shape(v)[-1])
        if isinstance(k,tuple):
            lbl=' '.join([str(round(float(k[0]),3)),'to',str(round(float(k[1]),3))])
        else:
            lbl=k
        if len(np.shape(v))==2:
            ax.plot(x,v[0,:],c=cs[i],label=lbl)
        elif len(np.shape(v))==1:
            ax.plot(x,v,c=cs[i],label=lbl)
        else:
            print('binned means has too many dimensions')
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Time (s)')
    if not colorbar:
        ax.legend()
    else:
        colorbar7(cs,binned_weight_change_index,ax,fontsize=fontsize)
    return fig,cs

def endwt_plot(df,xcolumn,xlabel,titles):
    """Scatter the ending synaptic weight against ``xcolumn``, coloured by trial.

    Parameters
    ----------
    df : DataFrame
        Must contain ``xcolumn``, ``'endweight'`` and ``'trial'``.
    xcolumn : str
        Column plotted on the x axis.
    xlabel : str
        Label of the x axis.
    titles : sequence of str
        When ``titles[0]`` contains 'sigma' and a legend exists, the legend
        entries are replaced by these titles.

    Returns
    -------
    The figure.
    """
    # With more than ten trials the legend is suppressed entirely.
    if len(np.unique(df.trial))>10:
        leg=False
    else:
        leg='full'
    f_bcm,ax = plt.subplots()
    sns.scatterplot(x=xcolumn,y='endweight',data=df,hue='trial',ax=ax,palette='magma',legend=leg)
    ax.set_ylabel('Ending Synaptic Weight')
    ax.set_xlabel(xlabel)
    legend=ax.get_legend()
    # ``legend`` is None when the legend was suppressed; relabelling it used
    # to raise AttributeError in that case.
    if legend is not None and 'sigma' in titles[0]:
        texts=legend.get_texts()
        offset=0
        if len(texts) > len(titles):
            # An extra leading entry is the legend heading.
            texts[0].set_text('Variability')
            offset=1
        for text,t in zip(texts[offset:],titles):
            text.set_text(t)
    sns.despine(trim=True)
    return f_bcm

def fft_plot(fft_dict,freqs,binned_weight_change_index,cs=None,colorbar=False,title='',ylabel='FFT magnitude',min_pt=1,maxfreq=0.1):
    """Plot one FFT-magnitude curve per weight-change bin.

    Parameters
    ----------
    fft_dict : dict
        Maps a bin key to a 1-D array of magnitudes aligned with ``freqs``.
        Tuple keys are rendered as 'low to high' in the legend.
    freqs : 1-D array
        Frequencies shared by all curves.
    binned_weight_change_index : sequence of float
        Bin boundaries, only used when ``colorbar`` is true.
    cs : sequence of colours, optional
        One colour per bin.  When missing or empty, a coolwarm palette with
        grey for the middle bin is generated.
    colorbar : bool
        Draw the discrete colour bar instead of a legend.
    title : str
        Figure super-title; a non-empty title also selects a smaller
        colour-bar font.
    ylabel : str
        Label of the y axis.
    min_pt : int
        Index of the first plotted sample (the default 1 skips the first bin).
    maxfreq : float
        Upper limit of the x axis.

    Returns
    -------
    fig, cs : the figure and the colours that were used.
    """
    fig,ax = plt.subplots(1,1,constrained_layout=False)
    fig.suptitle(title)
    fontsize = 10 if len(title) else 14
    # The body previously referred to an undefined ``binned_means`` and an
    # undefined loop index ``i``, so every call raised NameError.
    if cs is None or len(cs)==0:
        cs = list(plt.cm.coolwarm(np.linspace(0,1,len(fft_dict)-1)))
        cs.insert(len(fft_dict)//2,plt.cm.gray(0.5))
    for i,(k,v) in enumerate(fft_dict.items()):
        if isinstance(k,tuple):
            lbl=' '.join([str(round(float(k[0]),3)),'to',str(round(float(k[1]),3))])
        else:
            lbl=k
        # Slice both arrays from ``min_pt`` so x and y always have equal length.
        ax.plot(freqs[min_pt:],v[min_pt:],c=cs[i],label=lbl)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Frequency (Hz)')
    ax.set_xlim([0,maxfreq])
    if not colorbar:
        ax.legend()
    else:
        colorbar7(cs,binned_weight_change_index,ax,fontsize=fontsize)
    return fig,cs
   
def rand_forest_plot(ytrain,ytest,fit,pred,title=''):
    """Scatter predicted against actual weight change for train and test sets.

    Parameters
    ----------
    ytrain, ytest : array-like with ``min()`` / ``max()``
        Actual weight changes of the training and test sets.
    fit, pred : array-like with ``min()`` / ``max()``
        Model output for the training and test sets.
    title : str
        Figure super-title.

    Each panel also shows dashed zero lines and the identity line in green.
    Nothing is returned.
    """
    f,axes = plt.subplots(1,2,sharey=True,sharex=True)
    axes[0].scatter(ytrain,fit)
    axes[0].set_title('train')
    axes[1].set_title('test')
    axes[1].scatter(ytest,pred)
    axes[0].set_ylabel('predicted weight change')
    # Data ranges rounded to one decimal.
    xmin=round(min(ytrain.min(),ytest.min()),1)
    xmax=round(max(ytrain.max(),ytest.max()),1)
    ymin=round(min(fit.min(),pred.min()),1)
    ymax=round(max(fit.max(),pred.max()),1)
    diagmin=min(xmin,ymin)
    # The upper end of the identity line must cover both axes; it used to be
    # max(ymin,ymax), which ignored the x range.
    diagmax=max(xmax,ymax)
    for ax in axes:
        ax.set_xlabel('weight change')
        ax.hlines(0,xmin,xmax,'gray','dashed')
        ax.vlines(0,ymin,ymax,'gray','dashed')
        ax.plot([diagmin,diagmax],[diagmin,diagmax],'g')
    f.suptitle(title)

def plot_features(list_features,epochs,ylabel):
    """Bar chart of one value per feature.

    Parameters
    ----------
    list_features : sequence of (name, weight) pairs
        Bars are drawn in the given order.
    epochs : str or number
        Shown in the title; converted with ``str()`` so numbers are accepted.
    ylabel : str
        Label of the y axis; also the first part of the title.

    The figure is created through pyplot; nothing is returned.
    """
    names=[name for name,_ in list_features]
    weights=[weight for _,weight in list_features]
    positions = np.arange(len(list_features))
    plt.figure(figsize=(6,4))

    plt.bar(positions, weights, align='center', alpha=0.5)
    plt.xticks(positions, names, rotation=90)
    plt.ylabel(ylabel)
    plt.xlabel('Feature')
    # str() prevents a TypeError when ``epochs`` is passed as an int.
    plt.title(ylabel+' over '+str(epochs)+' epochs')
    plt.tight_layout()

def plotPredictions(max_feat, train_test, predict_dict, class_labels, feature_order,title,colors):
    """Scatter consecutive feature pairs, coloured by predicted class.

    One figure is created per pair ``(0,1), (2,3), ...`` of feature indices
    below ``max_feat``.  Each entry of ``train_test`` is drawn into every
    figure; the first entry gets black marker edges, the second none (only
    two edge styles are defined, so further entries are not drawn).

    Parameters
    ----------
    max_feat : int
        Upper bound (exclusive) of the first index of each feature pair.
    train_test : dict
        Maps a set name to a sequence whose first element is a DataFrame
        holding the feature columns.
    predict_dict : dict
        Maps the same set names to the predicted class labels.
    class_labels : sequence
        All class labels; the position of a label selects its colour.
    feature_order : sequence
        Each element's first item is a feature (column) name.
    title : str
        Title of every figure.
    colors : sequence of colours
        Colours of the classes, used to build the colour map.

    Interactive mode is switched on as a side effect; nothing is returned.
    """
    from matplotlib.colors import ListedColormap
    ########## Graph the output using contour graph
    #inputdf contains the value of a subset of features used for classifier, i.e., two different columns from df
    feature_cols = [entry[0] for entry in feature_order]
    plt.ion()
    edgecolors = ['k','none']
    feature_axes = [(first,first+1) for first in range(0,max_feat,2)]
    print(feature_axes)
    for x_idx,y_idx in feature_axes:
        plt.figure()
        plt.title(title)
        for set_name,edge in zip(train_test.keys(),edgecolors):
            predicted = predict_dict[set_name]
            frame = train_test[set_name][0]
            class_index = []
            for label in predicted:
                class_index.append(list(class_labels).index(label))
            plt.scatter(frame[feature_cols[x_idx]], frame[feature_cols[y_idx]], c=class_index,cmap=ListedColormap(colors), edgecolor=edge, s=20,label=set_name)
            plt.xlabel(feature_cols[x_idx])
            plt.ylabel(feature_cols[y_idx])
            plt.legend()


def inst_wt_change_plot(weight_change_event_df,save_name=''):
    """Scatter the weight change 'dW' against 'isi' and against 'pre_rate'.

    Parameters
    ----------
    weight_change_event_df : DataFrame
        Must contain the columns ``'isi'``, ``'pre_rate'``, ``'dW'`` and
        ``'trial'`` (used for colouring).
    save_name : str
        Not used by this function.

    Returns
    -------
    dict mapping each x column name to its figure.
    """
    figs = {}
    for xcolumn in ('isi','pre_rate'):
        fig,axis = plt.subplots()
        sns.scatterplot(x=xcolumn,y='dW',data=weight_change_event_df,hue='trial',ax=axis,palette='magma',legend='full')
        axis.set_ylabel('Weight Change',fontsize=12)
        axis.set_xlabel(xcolumn+' (sec)',fontsize=12)
        if xcolumn=='isi':
            # Clip the ISI axis: lower bound no smaller than -0.5, upper bound 0.4.
            current_min = axis.get_xlim()[0]
            axis.set_xlim([max(-0.5,current_min),0.4])
        axis.tick_params(labelsize=12)
        axis.legend()
        figs[xcolumn] = fig
    return figs

def dW_2Dcolor_plot(inst_weight_change):
    """Scatter 'isi' against 'isi2' and against 'pre_interval', coloured by 'dW'.

    ``inst_weight_change`` must be a DataFrame with the columns ``'isi'``,
    ``'isi2'``, ``'pre_interval'`` and ``'dW'``.  The colour scale is fixed to
    -0.0003..0.0003.  Returns a dict mapping each y column name to its figure.
    """
    figs = {}
    for ycolumn in ('isi2','pre_interval'):
        fig,axis = plt.subplots()
        sns.scatterplot(x='isi',y=ycolumn,data=inst_weight_change,hue='dW',hue_norm=(-.0003,.0003),ax=axis,palette='seismic',legend='brief')
        axis.set_ylabel(ycolumn,fontsize=12)
        axis.set_xlabel('ISI (sec)',fontsize=12)
        column = inst_weight_change[ycolumn]
        # Lower limit: the data minimum clamped to the interval [-0.15, 0].
        lower = min(max(-0.15,column.min()),0)
        # Upper limit: the data maximum, capped at 0.2.
        upper = min(0.2,column.max())
        axis.set_ylim([lower,upper])
        axis.set_xlim([-0.15,0.1])
        axis.tick_params(labelsize=12)
        figs[ycolumn] = fig
    return figs

 
def spiketime_plot(binned_spiketime,binned_spiketime_std,save_name='',wt_change_title=''):
    """Bar chart (with error bars) of each spike-timing measure per weight-change bin.

    Parameters
    ----------
    binned_spiketime, binned_spiketime_std : dict of dict
        For each of the measures 'isi', 'pre_rate', 'isi2', 'pre_interval'
        and 'pre_spike_dt': a dict mapping a ``(low, high)`` weight-change bin
        to the bar height / error-bar size.
    save_name : str
        Not used by this function.
    wt_change_title : str
        Super-title of every figure.

    Returns
    -------
    dict mapping each measure name to its figure.
    """
    figs={}
    for kk in ['isi','pre_rate','isi2','pre_interval','pre_spike_dt']:
        f,ax = plt.subplots()
        f.suptitle(wt_change_title)
        # Each bar is labelled with the midpoint of its weight-change bin.
        labels=[str(round((b[0]+b[1])/2,5)) for b in binned_spiketime[kk].keys()]
        if kk.startswith('isi') or kk.startswith('pre'):
            # Materialise the dict views: NumPy does not treat ``dict.values()``
            # as a sequence, so they cannot be handed to ``ax.bar`` directly.
            values=list(binned_spiketime[kk].values())
            std=list(binned_spiketime_std[kk].values())
            ax.set_ylabel(kk+' (sec)',fontsize=12)
        else:
            # Not reached for the measures listed above; kept for other measure names.
            values=[np.log(val) for val in binned_spiketime[kk].values()]
            std=[np.log(val) for val in binned_spiketime_std[kk].values()]
            ax.set_ylabel('log('+kk+')(sec)',fontsize=12)
        ax.bar(labels,values,yerr=std)
        ax.set_xlabel('Weight Change Bin',fontsize=12)
        ax.tick_params(labelsize=12)
        figs[kk]=f
    return figs

def plot_3D_scatter(df,xname,yname,zname,binned_dict,cs):
    """3-D scatter of three DataFrame columns, one colour per weight bin.

    Parameters
    ----------
    df : DataFrame
        Source of the three plotted columns.
    xname, yname, zname : str
        Column names for the three axes.
    binned_dict : dict
        Maps a bin key (an iterable of numbers, used for the legend label) to
        the positional row indices belonging to that bin.
    cs : sequence of colours
        One colour per bin, in the dict's iteration order.

    Returns
    -------
    The figure.
    """
    # Select the rows of every bin first, then create the figure.
    groups = []
    for bin_key,rows in binned_dict.items():
        groups.append((bin_key,
                       df[xname].iloc[rows],
                       df[yname].iloc[rows],
                       df[zname].iloc[rows]))
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    for idx,(bin_key,xs,ys,zs) in enumerate(groups):
        bin_label = ' to '.join([str(round(wt,5)) for wt in bin_key])
        ax.scatter(xs.values,ys.values,zs.values,color=cs[idx],label=bin_label)
    ax.legend()
    ax.set_xlabel(xname)
    # Keep the automatic lower x limit but cap the upper one at 0.4.
    x_low = ax.get_xlim()[0]
    ax.set_xlim([x_low,0.4])
    # Start the y axis at zero and keep its automatic upper limit.
    y_high = ax.get_ylim()[1]
    ax.set_ylim([0,y_high])
    ax.set_ylabel(yname)
    ax.set_zlabel(zname)
    ax.tick_params(labelsize=12)
    return fig

#3D scatter/color plot of dW (in color) vs isi1(x) vs isi2(y), or vs y=pre_dt or y = pre_interval
# The triple-quoted string below is a bare string literal, so the code inside
# it is never run.  It archives a 3-D scatter experiment that, per its own
# note, produced an empty figure; it refers to a ``df`` that is not defined
# at module level in the visible part of this file.
'''
#nothing shows up???
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(df.spikecount,df.spinedist,df.endweight)
'''
'''
# ## Synaptic Weight all trials and synapses
# f,axes = plt.subplots(10,1,sharex=True,sharey=True,figsize=(12,20))
# for i,(k,d) in enumerate(data.items()):
#     for n in d.dtype.names:
#         if 'plas' in n:
#             axes[i].plot(d['time'][::100],d[n][::100])
#     axes[i].set_title(k)
#     axes[i].set_ylim(0,2.1)
#     axes[i].set_xlabel('Time (s)')
# axes[0].set_ylabel('Synaptic Weight')
#     #f.suptitle('Synaptic weight')
# #f.savefig(path+'all_syn_weight_over_time_figure.svg')

# cs=plt.cm.tab10(range(10))
# f,ax = plt.subplots()
# for spine in df['spine'].drop_duplicates():
#     #ignore no all no change:
#     line = [ df.loc[ (df.spine == spine) & (df.trial == sv) ]['endweight'].iat[0] for sv in range(10,101,10) ]
#     #if line.count(1.0)==len(line):
#     if ([l for l in line if abs(l-1)<.01]):
#         continue
#     x = np.arange(len(line))+np.random.uniform(-.1,.1)
#     ax.plot(x,line,color=plt.cm.gray(.8),linewidth=.5)
#     ax.scatter(x,line,marker='o',c=cs,zorder=10)
# #ax.set_xticks([0,1,2])
# #ax.set_xticklabels(['Low','Intermediate','High'])
# ax.set_ylabel('Synaptic Weight')
# ax.set_title('Ending Synaptic Weight')
# #f.savefig(path+'ending_synaptic_weight_figure.svg')


###### What is this figure ???
# for n in d.dtype.names:
#     if n.endswith('tertdend5_11-sp1head'):
#         print(n)
#         plt.figure()
#         peaks,_=find_peaks(d[n])
#         raster = d['time'][peaks]
#         plt.eventplot(np.sort(raster))
#         vmpeaks,_ = find_peaks(d[[vm for vm in d.dtype.names if 'Vm' in vm][0]],height=20e-3)
#         vmraster = d['time'][vmpeaks]
#         plt.eventplot(np.sort(vmraster),lineoffsets=-1,colors='r')
#         sorted_other_stim_spines = [s.replace('_sp','-sp').replace('ecdend','secdend') for s in stimspinetospinedist['tertdend5_11_sp1head'].sort_values().index]
#         for i,other in enumerate(sorted_other_stim_spines):
#             othername = '/data/D1-extern1_to_'+other
#             otherpeaks,_ = find_peaks(d[othername])
#             raster = d['time'][otherpeaks]
#             plt.eventplot(np.sort(raster),lineoffsets=i)
# plt.figure();plt.eventplot(np.sort(raster))


for i in spine_weights.index:
    plot_spine_calcium_and_weight(i)

f,axes = plt.subplots(1,1)
axes=[axes]
for i,(k,d) in enumerate(data.items()):
     for n in d.dtype.names:
         if 'sp' in n and 'Shell' in n:
             axes[i].plot(d['time'][::50],d[n][::50])
     axes[i].set_title(k)
     break
     #axes[i].set_ylim(0,2)
# f.suptitle('Spine Calcium')
'''