# -*- coding: utf-8 -*-
"""
Created on Thu May  6 14:55:57 2021

@author: kblackw1
"""
import numpy as np
import glob
'''
Update figures to show mean responses per trial?  Currently showing mean responses per 10 trials
Q matrix dynamics
'''

import matplotlib.pyplot as plt
plt.ion() #interactive mode, so figures appear without blocking the script
#Colormaps used by the figure functions; a colormap is selected by list index and a shade by an integer passed to the colormap
colors=[plt.get_cmap('Blues'),plt.get_cmap('Reds'),plt.get_cmap('Greys'),plt.get_cmap('Purples')] #plasma, viridis, inferno or magma possible
color_offset=0 #added to the colormap index when stepping through shades
letters=['A','B','C','D','E'] #panel labels written by create_bandit_fig
linestyles=['solid','dashed','dotted','dashdot'] #one per traject entry when create_traject_fig is given a color_dict
fsize=11 #base font size for axis labels, tick labels and legends
import RL_utils as rlu

def create_traject_fig(traject,phases,actions,action_text,params,norm=None,leg_panel=0, leg_loc='best',leg_fs=fsize,sequential=False,leg_text={},color_dict={}):
    """Plot mean +/- standard error trajectories, one panel per action.

    Parameters
    ----------
    traject : dict
        Nested as traject[key][phase][action] -> {'mean': array, 'sterr': array}.
        Keys must be strings (str.isdigit is called on them); keys made of
        digits are labelled as numQ values in the legend.
    phases : list
        Phases to plot, in order.  With more phases than entries in the
        module-level ``colors`` list, colormaps are assigned per traject key
        and shades per phase; otherwise colormaps are assigned per phase.
    actions : list
        Keys into traject[key][phase]; one panel per action.  Actions missing
        from a phase are skipped for that phase.
    action_text : list
        Y axis label for each panel, same order as ``actions``.
    params : dict
        params['trials_per_block'][0] converts block number into trials.
    norm : dict or None
        If given, every action except 'rwd' is multiplied by norm['value']
        and its y limits are set from norm['units'].
    leg_panel : int
        Index of the panel that receives the legend; a negative value
        suppresses the legend.
    leg_loc, leg_fs :
        Legend location and font size; leg_fs<=0 suppresses the legend.
    sequential : bool
        If True, phases are concatenated along the x axis and dashed vertical
        lines mark the end of each phase.
    leg_text : dict
        Optional mapping from phase to the label shown in the legend.
    color_dict : dict
        Optional color_dict[key][phase index] giving explicit trace colors;
        when given, each traject key also gets its own line style.

    Returns
    -------
    fig : matplotlib figure, so the caller can resize it and redo tight_layout.
    """
    fig,ax=plt.subplots(len(actions),1,sharex=True)
    ax=fig.axes
    color_inc=int((255-color_offset)/(len(traject.keys())))
    trials_per_block=params['trials_per_block'][0] #data are per block, x axis is in trials
    ymin={act:0 for act in action_text}
    ymax={act:0 for act in action_text}
    for nQ,(numQ,data) in enumerate(traject.items()): #some figures look better if for pnum,phase comes first
        boundaries=[]
        for pnum,phase in enumerate(phases):
            if len(phases)<=len(colors):
                if sequential:
                    cmap=2 #0 gives blues - 2nd submission. 2 gives grays - 3rd submission
                    #cmap=pnum #for block case
                else:
                    cmap=pnum
                cnum=int(nQ+1)*color_inc+color_offset
                leg_cols=len(phases)
            else:
                color_inc=int((255-color_offset)/(len(data.keys())))
                cmap=nQ
                cnum=int(color_offset+(pnum+1)*color_inc)
                leg_cols=len(traject.keys())
            #print('traject_fit',phase,'Q=',numQ,nQ,'c=',cnum)
            if len(color_dict):
                color=color_dict[numQ][pnum]
                #cycle through the line styles so more than len(linestyles) entries do not raise IndexError
                linestyle=linestyles[nQ%len(linestyles)]
            else:
                color=colors[cmap](cnum)
                linestyle='solid'
            ########## Text for legend ###########
            if len(phases)>len(colors):
                trace_label=phase
                if len(leg_text):
                    trace_label=leg_text[phase]
                legend_title=''.join([key+'Q               ' for key in traject.keys()])
                #leg_fs=10
            elif len(phases)>1:
                trace_label=phase
                if len(leg_text):
                    trace_label=leg_text[phase]
                if sequential:
                    legend_title=''
                    trace_label='numQ='+numQ
                if numQ.isdigit(): #don't add 'numQ=' if traject.keys() are not numQ
                    if leg_panel>-1:
                        legend_title='      '.join(['numQ='+str(key) for key in traject.keys()])
                    else:
                        legend_title=''
                    trace_label=trace_label+'Q'
                else:
                    legend_title=''
                    trace_label=numQ#phase+', '+numQ
            else: #acquisition only
                trace_label=numQ
                legend_title='num Q'
            for anum,act in enumerate(actions):
                if act in data[phase].keys():
                    stats=data[phase][act]
                    num_blocks=len(stats['mean'])
                    if sequential:
                        #shift each phase to the right of the previous one
                        block=(np.arange(num_blocks)+(pnum*num_blocks))*trials_per_block #convert to trials
                        boundaries.append(block[-1]+trials_per_block)
                    else:
                        block=np.arange(num_blocks)*trials_per_block
                    #track the range of mean +/- sterr across all traces in this panel
                    ymax[action_text[anum]]=max(np.max(stats['mean']+stats['sterr']),ymax[action_text[anum]])
                    ymin[action_text[anum]]=min(ymin[action_text[anum]],np.min(stats['mean']-stats['sterr']))
                    #if pnum>0: #TEMPORARY - TO ELIMINATE LEGEND
                    #    trace_label='_'
                    if act=='rwd':
                        ax[anum].hlines(0,xmin=0,xmax=np.max(block),linestyle='dotted')
                    if norm is not None and act !='rwd': #### yvalues for reward already in units of "per trial"
                        yvals=stats['mean']*norm['value']
                        yerr=stats['sterr']*norm['value']
                        #normalized panels use fixed limits derived from the normalization units
                        ymin[action_text[anum]]=-norm['units']*0.05
                        ymax[action_text[anum]]=norm['units']*1.05
                    else:
                        yvals=stats['mean']
                        yerr=stats['sterr']
                    ax[anum].errorbar(block,yvals,yerr=yerr,label= trace_label,color=color,linestyle=linestyle)
        #make y axis limits the same for all panels after the first;
        #with a single panel there are no "other" panels, so use that panel (np.max([]) would raise)
        shared_axes=ax[1:] if len(ax)>1 else ax
        maxy=np.max([axis.get_ylim()[1] for axis in shared_axes])
        miny=np.min([axis.get_ylim()[0] for axis in shared_axes])
        for anum in range(len(ax)):
            ax[anum].set_ylabel(action_text[anum], fontsize=fsize+1)
            ax[anum].set_ylim([ymin[action_text[anum]],ymax[action_text[anum]]])
            if len(boundaries):
                for xval in boundaries:
                    ax[anum].vlines(xval,ymin[action_text[anum]],ymax[action_text[anum]],linestyles='dashed',color='grey')
            if action_text[anum]=='reward':
                #reward panel: integer limits and four evenly spaced integer tick labels
                ylim=ax[anum].get_ylim()
                print('ylim',anum,action_text[anum],ylim)
                ax[anum].set_ylim([np.floor(ylim[0]),np.ceil(ylim[1])*1.05])
                yticks=np.linspace(np.floor(ylim[0]),np.ceil(ylim[1]),4)
                ylabels=[str(round(f)) for f in yticks]
                ax[anum].set_yticks(yticks,ylabels)
            elif norm is None:
                ax[anum].set_ylim([np.floor(miny),np.ceil(maxy)])
            ax[anum].tick_params(axis='x', labelsize=fsize)
            ax[anum].tick_params(axis='y', labelsize=fsize)
            #num_blocks and pnum hold the values from the last trace plotted above
            if sequential:
                ax[anum].set_xlim([0,num_blocks*(pnum+1)*trials_per_block])
            else:
                ax[anum].set_xlim([0,num_blocks*trials_per_block])
        #(re)build the legend once per traject entry so it includes the traces just added
        if leg_panel>-1 and leg_fs>0:
            ax[leg_panel].legend(frameon=True,title=legend_title,ncol=leg_cols,loc=leg_loc,fontsize=leg_fs-1,title_fontsize=leg_fs-1,handletextpad=0.2,labelspacing=0.3,columnspacing=1)# markerfirst=True
    ax[anum].set_xlabel('Trial', fontsize=fsize+1) #bottom panel only, x axis is shared
    fig.tight_layout()
    return fig #so you can adjust size and then do fig.tight_layout()

def create_sequence_traject_fig(traject,actions,action_text,params,lenQ={},ept=7,norm=None):
    """Plot sequence-task trajectories, one panel per action, in shades of grey.

    traject[key][action] must hold 'mean' and 'sterr' arrays with one value per
    block; params['trials_per_block'][0] converts block number to trials.  If
    norm is given, every action except 'rwd' is multiplied by norm['value'].
    If lenQ is not empty, one extra panel is appended showing, for each key q,
    the mean over axis 0 of lenQ[q][int(q)], with the x axis divided by ept.
    When norm is None, fixed y ticks and a fixed reward y range are applied.
    Returns the figure so the caller can resize it and redo tight_layout.
    """
    has_lenQ=len(lenQ)>0
    numrows=len(actions)+1 if has_lenQ else len(actions)
    fig,ax=plt.subplots(numrows,1,sharex=False)
    ax=fig.axes
    for trace_idx,(label,series) in enumerate(traject.items()):
        shade_step=int(255/len(traject))
        shade=colors[2]((trace_idx+1)*shade_step)  #0 gives blues - 2nd submission. 2 gives grays - 3rd submission
        for row,act in enumerate(actions):
            stats=series[act]
            n_blocks=len(stats['mean'])
            per_block=params['trials_per_block'][0]
            trial_axis=np.arange(n_blocks)*per_block
            #reward is already expressed per trial, so it is never rescaled
            rescale=norm is not None and act !='rwd'
            yvals=stats['mean']*norm['value'] if rescale else stats['mean']
            yerr=stats['sterr']*norm['value'] if rescale else stats['sterr']
            panel=ax[row]
            panel.errorbar(trial_axis,yvals,yerr=yerr,label= label,color=shade)
            panel.set_ylabel(action_text[row], fontsize=fsize)
            if act=='rwd':
                panel.hlines(0,xmin=0,xmax=np.max(trial_axis),linestyle='dotted')
            panel.set_xlim([0,n_blocks*per_block])

        #row still refers to the last action panel
        ax[row].set_xlabel('Trial', fontsize=fsize)
        ax[0].legend(frameon=True)#,title='num Q')
    if has_lenQ:
        #lenQ[1]=lenQ[2] for numQ=2, so plot lenQ for both numQ=1 and 2
        qcolors=['magenta','purple']
        state_panel=ax[row+1] #panel directly below the last action panel
        for q,counts in lenQ.items():
            per_q=counts[int(q)]
            xvals=np.arange(np.shape(per_q)[1])/ept
            state_panel.plot(xvals,np.mean(per_q,axis=0),label='numQ='+str(q),color=qcolors[int(q)-1])
        state_panel.set_ylabel('States')
        state_panel.set_xlim([-1,601])
        ax[-1].set_xlabel('Trial')
    if norm is None: ######################### Revisit after plotting new data
        for row in range(1,numrows):
            is_state_row=row==numrows-1 and has_lenQ
            yticks=[0,20,40,60,80,100] if is_state_row else [0,5,10]
            ax[row].set_yticks(yticks,[str(round(tick)) for tick in yticks])
        ax[0].set_ylim(-7.5,10.0)
    fig.tight_layout()
    return fig #so you can adjust size and then do fig.tight_layout()

def create_bandit_fig(traject, params,numpanels=2,color_dict={}):
    """Plot the curves in traject[key][phase] against trial number.

    traject[key] maps a phase label to an array with one value per block;
    params['trials_per_block'][0] converts block number to trials.  With
    numpanels==1 all keys share one panel and one combined legend; otherwise
    entry number i of traject is drawn in panel i with its own legend.  Colors
    come from color_dict[key][phase index] when color_dict is not empty, else
    from the module-level colormaps.  Returns the figure.
    """
    fig,ax=plt.subplots(numpanels,1,sharex=True)
    ax=fig.axes
    margin=0.03
    letter_step=(1-2*margin)/len(ax) #vertical spacing of the panel letters
    combined_title=''
    for nQ, (numQ,curves) in enumerate(traject.items()):
        combined_title+=numQ+'Q               '
        color_inc=int((255-color_offset)/(len(curves.keys())))
        print(color_inc)
        axnum=0 if numpanels==1 else nQ
        panel=ax[axnum]
        for which in ('x','y'):
            panel.tick_params(axis=which, labelsize=fsize)
        for pnum,phs in enumerate(curves.keys()):
            if len(color_dict):
                color=color_dict[numQ][pnum]
            else:
                color=colors[nQ](int(color_offset+(pnum+1)*color_inc))
            num_blocks=len(curves[phs])
            trial_axis=np.arange(num_blocks)*params['trials_per_block'][0]
            panel.plot(trial_axis,curves[phs],label= phs,color=color)
            panel.set_ylabel('Prob (L)', fontsize=fsize)
            panel.hlines(0.5,xmin=0,xmax=np.max(trial_axis),linestyle='dotted')
        if numpanels>1:
            panel.legend(frameon=True, title=numQ+'Q', fontsize=fsize)
        if len(ax)>1:
            #subtract because 0 is at bottom
            fig.text(0.02,(1-margin)-(nQ*letter_step),letters[nQ], fontsize=fsize)
    if numpanels==1:
        ax[0].legend(frameon=True,title=combined_title,ncol=len(traject.keys()), fontsize=fsize-1,title_fontsize=fsize-1,handletextpad=0.2,labelspacing=0.3,columnspacing=1)
    #axnum and num_blocks hold the values from the last curve drawn
    last_panel=ax[axnum]
    last_panel.set_xlim([0,(num_blocks+1)*params['trials_per_block'][0]])
    last_panel.set_xlabel('Trial', fontsize=fsize)
    fig.tight_layout()
    return fig #so you can adjust size and then do fig.tight_layout()

def read_data(pattern, files=None, keys=None, dep_var=None):
    """Load trajectories from one or more .npz result files.

    Parameters
    ----------
    pattern : str
        Glob pattern, used only when ``files`` is empty or None.
    files : list of str, optional
        Explicit list of .npz files to read.
    keys : list, optional
        Label for each file (same order as ``files``); takes precedence.
    dep_var : list of str, optional
        Names of entries in the file's 'par' dictionary; their values are
        joined with ',' to form the label when ``keys`` is not given.
        If neither is given, the label is str(par['numQ']).

    Returns
    -------
    traject : dict, label -> contents of the file's 'traject' entry
    all_counts : dict, label -> 'shift_stay' entry, or 'all_lenQ' entry for
        files that contain 'all_beta'
    sa_error_keys : dict built from the LAST file's 'sa_errors' entry
        ({} if that file has none)
    par : the 'params' entry of the LAST file's 'results' dictionary
    traject_dict : dict, label -> 'traject_dict' entry (files with 'shift_stay')

    Raises
    ------
    FileNotFoundError if no file is supplied and none matches ``pattern``.
    """
    ########## only saves one traject for each numQ #########
    if not files:
        files=glob.glob(pattern)
    print('pattern',pattern,'files',files)
    if not files:
        #previously this fell through to an UnboundLocalError at the return statement
        raise FileNotFoundError('read_data: no files match pattern '+repr(pattern))
    traject={}
    all_counts={}
    traject_dict={}
    for k,f in enumerate(files):
        #NOTE: allow_pickle=True executes pickled content - only load trusted files
        #the with statement closes the underlying zip file when done
        with np.load(f,allow_pickle=True) as data:
            stored=data.files #names of the arrays saved in this file
            params=data['par'].item()
            par=data['results'].item()['params']
            if keys:
                numQ=keys[k]
            elif dep_var is not None:
                numQ=','.join([str(params[dv]) for dv in dep_var])
            else:
                numQ=str(params['numQ'])
            traject[numQ]=data['traject'].item()
            if 'shift_stay' in stored:
                all_counts[numQ]=data['shift_stay'].item()
                traject_dict[numQ]=data['traject_dict'].item()
            elif 'all_beta' in stored: #files beginning on 3 June 2022
                all_counts[numQ]=data['all_lenQ'].item()
            if 'sa_errors' in stored:
                sa_errors=data['sa_errors'].item()
                #convert each state-action entry to a tuple before building its key
                sa_error_keys={'switch':rlu.construct_key([tuple(sa) for sa in sa_errors['switch']]),
                               'overstay':rlu.construct_key([tuple(sa) for sa in sa_errors['stay']]),
                               'start':rlu.construct_key([tuple(sa) for sa in sa_errors['start']])}
            else:
                sa_error_keys={}
        #mean_rwd=np.sum(np.mean([phs['rwd__End'] for ph,phs in data['results'].item().items() if ph != 'params'],axis=2))
        #print('mean reward for', f,'=',mean_rwd)
    return traject,all_counts,sa_error_keys,par,traject_dict
  
################## Stat Analysis  #########################
def create_df(pattern,files=None,del_variables=(),params=('numQ',),keys=None ):
    """Collect the 'results' entry of several .npz files into one DataFrame.

    Parameters
    ----------
    pattern : str
        Glob pattern, used only when ``files`` is empty or None.
    files : list of str, optional
        Explicit list of .npz files to read.
    del_variables : iterable of str
        Column names to remove from each file's results (KeyError if absent).
    params : iterable of str
        Parameter names added as columns.  The value is the first element of
        results['params'][name]; if the name is not stored in the file, the
        file's entry in ``keys`` is used, else -1.  A stored None becomes -1.
        A list-valued parameter is also expanded into columns name0, name1...
    keys : list, optional
        One label per file, used for parameters missing from the file.

    Returns
    -------
    alldf : DataFrame with one row per run, all files concatenated
    newparams : list of parameter column names (list-valued parameters
        replaced by their expanded names), taken from the last file

    Raises
    ------
    ValueError if no file is supplied and none matches ``pattern``.
    """
    #pandas is only imported at module level under __main__, so import it here
    #to keep this function usable when the module is imported
    import pandas as pd
    if not files:
        files=glob.glob(pattern)
    print('pattern',pattern,'files',files)
    if not files:
        raise ValueError('create_df: no files match pattern '+repr(pattern))
    df=[]
    for fnum,f in enumerate(files):
        par={}
        #NOTE: allow_pickle=True executes pickled content - only load trusted files
        with np.load(f,allow_pickle=True) as data:
            print(f,list(data.keys()))
            results=data['results'].item()
        print('params',results['params'])
        for p in params:
            print('parameter',p)
            if p in results['params'].keys():
                print('parameter found', results['params'][p])
                par[p]=results['params'][p][0]
                #if p=='beta_min':
                #    if par[p]!=0.5:  #for statistical test on constant beta.
                #        par[p]='const'
                if par[p] is None:
                    par[p]=-1
            elif keys:
                par[p]=keys[fnum]
            else:
                par[p]=-1
        del results['params']
        key_combos=list(results.keys())
        new_results={}
        first_entry=results[key_combos[0]]
        if isinstance(first_entry,dict):
            #if results is dictionary of dictonaries, need to flatten it
            for phase in key_combos:
                for key,vals in results[phase].items():
                    new_results[phase+'_'+key]=vals[0] #shape is 1 row by blocks columns
            for zr in del_variables:
                del new_results[zr]
        elif len(np.shape(first_entry))==2:
            #replace characters that cannot appear in a column name used in a model formula
            for phs in key_combos:
                if '-' in phs:
                    new_key=phs.replace('-','s')
                else:
                    new_key=phs.replace('*','x')
                new_results[new_key]=results[phs][0]
            for zr in del_variables:
                del new_results[zr]
        dfsubset=pd.DataFrame.from_dict(new_results,orient='index').transpose()
        newparams=[]
        for p in params:
            dfsubset[p]=[par[p] for _ in range(len(dfsubset))]
            if isinstance(par[p],list):
                #one extra column per element of a list-valued parameter
                for j,element in enumerate(par[p]):
                    dfsubset[p+str(j)]=[element for _ in range(len(dfsubset))]
                    newparams.append(p+str(j))
            else:
                newparams.append(p)
        df.append(dfsubset)
    alldf=pd.concat(df).reset_index() #concatentate everything into one big dataframe
    return alldf,newparams

def barplot(mean,sterr,variables,varname,ylabel='Reward',transpose=False,norm=None,legend=True):
    """Draw a grouped bar plot of mean values with error bars.

    Parameters
    ----------
    mean, sterr : DataFrame
        Same layout: one row per condition, one column per variable.
        They are not modified.
    variables : list
        Columns of ``mean`` to plot, one set of bars per variable.  Ignored
        when ``transpose`` is True (the columns of the transposed table are
        used instead).
    varname : str
        Prefix for x tick labels; a row value of -1 is labelled 'Ctrl'.
    ylabel : str
        Y axis label.
    transpose : bool
        If True, rows and columns are swapped and x tick labels are
        characters 7-15 of each row name.
    norm : dict or None
        If given, variables whose name contains neither 'rwd' nor 'reward'
        are multiplied by norm['value'].
    legend : bool
        Whether to draw the legend.

    Returns
    -------
    figS : the matplotlib figure.
    """
    def _is_reward(label):
        #labels may be numbers (e.g. numQ values after transposing), so compare as text
        text=str(label)
        return 'rwd' in text or 'reward' in text

    figS,ax=plt.subplots(figsize=(2,2.5)) #width, height
    if transpose:
        mean=mean.T
        sterr=sterr.T
        variables=mean.columns
    rows=list(mean.index.values)
    if transpose:
        xlabels=[x[7:16] for x in rows]
    else:
        xlabels=['Ctrl' if x==-1 else varname+' '+str(x) for x in rows]
    xvalues=np.arange(len(rows))
    w = 1./(len(variables)+0.5)
    last_normalized=False #whether the last variable plotted was rescaled
    for a,tv in enumerate(variables):
        #print(tv, type(tv))
        yvalues=mean[tv].values
        yerr=sterr[tv].values
        ############### Normalize to % of optimal per trial, except reward already normalized to per trial ###########
        lbl='Ctrl' if tv==-1 else tv
        last_normalized=bool(norm) and not _is_reward(lbl)
        if last_normalized:
            #multiply into new arrays: scaling in place can modify the caller's
            #dataframe, and fails on read-only or integer arrays
            yvalues=yvalues*norm['value']
            yerr=yerr*norm['value']
        #offset each variable's bars so the group is centered on the x value
        ax.bar(xvalues+(a-(len(variables)-1)/2)*w, yvalues,width=w,yerr=yerr,label=str(lbl)[0:8])
    ax.set_ylabel(ylabel)
    ax.hlines(0,np.min(xvalues),np.max(xvalues),linestyles='dashed',colors='gray')
    ax.set_xlabel('Condition')
    ax.set_xticks(xvalues,xlabels)
    if last_normalized:
        ax.set_ylim([0,1.2*norm['units']])
    if legend:
        ax.legend(loc='best',fontsize=9,frameon=False)
    figS.suptitle('+'.join([str(v) for v in variables]),fontsize=8)
    return figS

def barplot_means(df,dep_var,test_variables):
    """Return per-group mean, count and standard error of the mean.

    df is grouped by the column(s) in dep_var, and the statistics are computed
    for the columns in test_variables.  A table of mean and sem is printed.
    Returns (mean, cnt, sterr), three DataFrames indexed by group.
    """
    grouped=df.groupby(dep_var)[test_variables]
    print(grouped.aggregate(['mean','sem'])) #,'count'
    mean=grouped.mean()
    cnt=grouped.count()
    #standard error = sample std (ddof=1) / sqrt(n); dividing the ddof=1 std by
    #sqrt(n-1) overestimated it and disagreed with the 'sem' printed above
    sterr=grouped.sem()
    return mean,cnt,sterr

def calc_norm(params,percent=False):
    """Return the factor that converts per-block values to per-trial values.

    'value' is units/params['trials_per_block'][0] and 'units' is 100 when
    percent is True, otherwise 1.
    """
    units=100 if percent else 1
    return {'value':units/params['trials_per_block'][0],'units':units}

# Script entry point: choose a task, draw its trajectory figures, then build a
# dataframe from the same files and run the statistical tests on it.
if __name__ == "__main__":
    #for stats, only run one at a time
    task= 'bandit'# 'discrim' #'AIP'# 'sequence' #
    add_barplot=0 #only relevant for sequence
    shift_stay=0 #only relevant for bandit
    test_var=[] #extra variables to test; replaced in the sequence bar plot loop below
    traject_fig=True #set to False to skip the trajectory figures
    ######################### DISCRIM #########################
    if task=='discrim':
        from discrimFiles import pattern,dep_var,files,test_variables,actions,action_text,keys
        traject,_,_,params,_=read_data(pattern,files=files,keys=keys) 
        norm=calc_norm(params,percent=False)
        ##### acquisition ######
        phase=['acquire']
        if traject_fig:
            figA=create_traject_fig(traject,phase,actions,action_text,params,norm=norm) #Fig 2A,B
                
        ##### extinction, renewal #####
        phase=['acquire','extinc','renew'] #phase=['extinc','renew']
        phase_text={'acquire':'','extinc':'Context B','renew':'Context A'}
        #phase_text={'acquire':'','extinc':'','renew':''}
        if traject_fig:
            figE=create_traject_fig(traject,phase,actions,action_text,params,leg_text=phase_text,norm=norm,sequential=True,) #Fig 2C
            #figE=create_traject_fig(traject,phase,actions,action_text,params,leg_panel=-1,sequential=True,norm=norm) #Fig 2C
        
        ##### discrimination, reversal #####
        phase=['acquire','discrim','reverse']
        #override the imported actions: only reward and the two 10 kHz responses are shown
        actions=['rwd', (('Pport', '10kHz'),'right'),(('Pport', '10kHz'),'left')]#,(('Pport', '6kHz'),'left')]
        action_text=['reward', '10 kHz Right', '10 kHz Left']#,'6 kHz Left'] #
        if traject_fig:
            figD=create_traject_fig(traject,phase,actions,action_text,params,leg_panel=-1, sequential=True,norm=norm) #Fig 4
            #zoom in on a fixed range of trials in every panel
            xlim=[190,610]
            ax=figD.axes
            for axis in ax:
                axis.set_xlim(xlim)

    ######################### block Dopamine #########################
    elif task=='AIP':
        from discrimFiles import pattern,dep_var,files,test_variables,actions,action_text,keys
        print(actions)
        traject,_,_,params,_=read_data(pattern, files, keys) 
        norm=calc_norm(params)
        phase=['acquire','discrim']
        ###fix colors here
        figB=create_traject_fig(traject,phase,actions,action_text,params,leg_panel=1,leg_loc='upper left', leg_fs=12, sequential=True,norm=norm) #Fig 6
        ax=figB.axes
        #ax[0].set_ylim([0,11])
        #ax[1].set_ylim([-0.3,8]) 
        #ax[2].set_ylim([-0.3,8]) 
    #################### Sequence trajectory ##################
    elif task=='sequence':
        from sequenceFiles import pattern, dep_var, files,barplot_files,keys,test_variables,actions,action_text
        if add_barplot:
            files=barplot_files
            traject,all_lenQ,error_keys,params,_=read_data(pattern,barplot_files,keys=keys)   
        else:
            traject,all_lenQ,error_keys,params,_=read_data(pattern,files,dep_var=dep_var,keys=keys) 
        norm=calc_norm(params)
        if traject_fig:
            #the extra panel is only drawn when exactly two trajectories were loaded
            if len(traject)==2:
                figS=create_sequence_traject_fig(traject,actions,action_text,params,lenQ=all_lenQ) #fig 7
            else:
                figS=create_sequence_traject_fig(traject,actions,action_text,params) #fig 7
        ######## Currently not used ###########
        switch=[kk.replace('*','x')+'_End' for kk in error_keys['switch'].values() if 'press' in kk ]
        premature=[kk+'_End' for kk in ['Llever_xxRL_goR', 'Llever_---L_goR','Rlever_xxRL_press','Rlever_---L_press','Rlever_xLLR_goMag']]
        overstay=[kk.replace('*','x')+'_End' for kk in error_keys['overstay'].values() if 'press' in kk]
        start=[kk.replace('*','x')+'_End' for kk in error_keys['start'].values()]

    #################### Bandit Task Probabilities ##################
    elif task=='bandit':
        from banditFiles import pattern,dep_var,files,test_variables,actions,action_text,keys
        from BanditTask import calc_fraction_left,plot_prob_tracking,plot_prob_traject
        runs=40
        #explicit trace colors: seven entries, the same list for every trajectory
        colors2=[plt.get_cmap('seismic').__call__(c) for c in [18,50]]+[plt.get_cmap('hsv').__call__(c) for c in [188,204,225]]+[plt.get_cmap('seismic').__call__(c) for c in [192,242]]
        newcol={'2':colors2,'1':colors2}
        if keys:
            newcol={k:colors2 for k in keys}
        traject,shift_stay,_,params,traject_dict=read_data(pattern,files=files,keys=keys,dep_var=dep_var)
        #shift_stay=False
        norm=calc_norm(params)
        p_choose_L={q:{} for q in traject.keys()}
        RMS={}
        if len(traject.keys())<=2:
            for numQ, data in traject.items():
                fractionLeft,noL,noR,ratio=calc_fraction_left(traject_dict[numQ],runs) #probability of Left over entire trial
                popt,pcov,delta,RMSmean,RMSstd,RMS[numQ]=plot_prob_tracking(ratio,fractionLeft,runs,showplot=False)
                #NOTE(review): sqrt(40) is hard coded here; it equals sqrt(runs) only while runs=40
                print('numQ=',numQ, 'ratio:',{round(ratio[k],3): (round(np.nanmean(fractionLeft[k]),3),round(np.nanstd(fractionLeft[k])/np.sqrt(40),3)) for k in fractionLeft.keys()} )
                p_choose_L[numQ]=plot_prob_traject(data,params,show_plot=False)
            if traject_fig:
                figB=create_bandit_fig(p_choose_L,params,numpanels=2,color_dict=newcol) 
                tkeys=list(traject.keys())
                tmp_phs=list(traject[tkeys[0]].keys())
                #phase names look like 'a:b'; order them by a-b, largest first
                phases=sorted(tmp_phs,key=lambda tmp_phs: float(tmp_phs.split(':')[0])-float(tmp_phs.split(':')[1]),reverse=True)
                figBT=create_traject_fig(traject,phases,actions,action_text,params,leg_fs=8,color_dict=newcol,norm=norm) #Fig 10C,D
        if shift_stay:
            #for each numQ, phase and outcome type: fraction of responses that were "stay"
            for numQ,all_counts in shift_stay.items():
                print('\n ############################ numQ=',numQ)
                for phs in all_counts['left_rwd'].keys():
                    print('\n*******',phs,'******')
                    #print('left_rwd=',all_counts['left_rwd'][phs],'left_none=',all_counts['left_none'][phs],
                    #      'right_rwd=',all_counts['right_rwd'][phs],'right_none=',all_counts['right_none'][phs])
                    for key,counts in all_counts.items():
                        #runs with no events are left out of the ratio, but are included in the mean number of events
                        ratio=[stay/(stay+shift) for stay,shift in zip(counts[phs]['stay'],counts[phs]['shift']) if stay+shift>0 ]
                        events=[(stay+shift) for stay,shift in zip(counts[phs]['stay'],counts[phs]['shift'])]
                        print(key,round(np.mean(ratio),3),round(np.std(ratio)/np.sqrt(len(ratio)),3), 'out of', np.mean(events), 'responses')

############################# Stat analysis ##########################################
    #these imports make pd etc. module-level names, which create_df relies on
    import pandas as pd
    from scipy.stats import ttest_ind
    import statsmodels.api as sm
    from statsmodels.formula.api import ols
    import scikit_posthocs as sp
    import os

    #dep_var is replaced by the list of parameter column names returned by create_df
    df,dep_var=create_df(pattern,files=files,params=dep_var)
    #NOTE(review): 'block_da' is not one of the task names listed above, so this
    #second call only runs when keys is not empty; it re-reads every file
    if task=='block_da' or keys:
        df,dep_var=create_df(pattern,files=files,params=dep_var,keys=keys)
    if 'Bandit' in pattern:
        key1=list(traject.keys())[0]
        ratio={}
        df['sum_squares']=0
        for prob in traject[key1].keys():
            #prob has the form 'L:R'
            R=float(prob.split(':')[1])
            L=float(prob.split(':')[0])
            ratio[prob]=L/(L+R)
            ###### probability of Left at the End #######
            df[prob+'_probL']=df[prob+'_Pport_6kHz_left_End']/(df[prob+'_Pport_6kHz_left_End']+df[prob+'_Pport_6kHz_right_End'])
            df['sum_squares']+=np.square(ratio[prob]-df[prob+'_probL'])
        #square root of the summed squared differences (the sum is not divided by the number of phases)
        df['RMS']=np.sqrt(df['sum_squares'])
        df['mean_reward']=df.loc[:,test_variables].mean(axis=1)
        test_variables=['mean_reward','RMS']
    if os.path.basename(pattern).upper().startswith('DISCRIM') and task != 'AIP':
        df['mean_reward']=df.loc[:,['reverse_rwd__End','discrim_rwd__End']].mean(axis=1)
        test_variables=['mean_reward','acquire_rwd__End','reverse_rwd__End','discrim_rwd__End']
        #test_variables=['acquire_rwd__End','discrim_rwd__End','extinc_Pport_6kHz_left_Beg']#,'renew_Pport_6kHz_left_Beg']
        if 'split' in dep_var:
            test_variables=['mean_reward','extinc_Pport_6kHz_left_Beg','renew_Pport_6kHz_left_Beg']#,'extinc_Pport_6kHz_left_End','renew_Pport_6kHz_left_End'
    mean,cnt,sterr=barplot_means(df,dep_var,test_variables)
    
    #summary text file: name is derived from the pattern, with wildcards replaced
    textname=pattern.replace('?','x').replace('*','all_').split('.npz')[0]+'summary.txt'
    columns=[tv[0:-4] for tv in test_variables] #drop the last 4 characters of each variable name
    header=' '.join(cnt.index.names)+' '+'  '.join(columns)
    rows=list(cnt.index.values)
    figBP=barplot(mean,sterr,test_variables,dep_var[0],norm=norm)
    #rows are stacked as counts, then means, then standard errors
    np.savetxt(textname,np.column_stack((rows*3,np.vstack([cnt,round(mean,3),round(sterr,3)]))),fmt='%s',header=header,comments='') #Ttest tables/Fig 5 - Igor
    if task=='bandit':
        newtv= [k+'_probL' for k in ['90:10','90:50','50:10', '50:50','10:50','50:90','10:90']] 
        print(df.groupby(dep_var)[newtv].mean()) #mean probability of Left at the end of the trial
        print(df.groupby(dep_var)[newtv].sem()) #std error of probability of Left at the end of the trial
        if 'numQ' in dep_var:
            print('******** Ttest on entire trial RMS, numQ:', ttest_ind(RMS['1'],RMS['2'],equal_var=False))
    if task=='sequence' and add_barplot:
        #combine response counts across states, then convert some to probabilities
        df['Llever_1L_goR_End']=df['Llever_xxRL_goR_End']+df['Llever_sssL_goR_End']
        if 'Rlever_xxRL_press_End' in df.columns:
            df['Rlever_1L_press_End']=df['Rlever_xxRL_press_End']+df['Rlever_sssL_press_End']
        df['Llever_1L_press_End']=df['Llever_xxRL_press_End']+df['Llever_sssL_press_End']
        df['mag_ssss_nostart_End']=df['mag_ssss_goR_End']+df['mag_ssss_other_End']+df['mag_ssss_goMag_End']+df['mag_ssss_press_End']
        df['Llever_1L_press_End_Prob']=(df['Llever_1L_press_End'])/(df['Llever_1L_press_End']+df['Llever_sssL_goR_End']+df['Llever_sssL_goMag_End']+df['Llever_sssL_other_End']+df['Llever_xxRL_other_End']+df['Llever_xxRL_goMag_End']+df['Llever_xxRL_goR_End'])
        df['Llever_2L_switch_End_Prob']=df['Llever_xxLL_goR_End']/(df['Llever_xxLL_goR_End']+df['Llever_xxLL_press_End']+df['Llever_xxLL_goL_End']+df['Llever_xxLL_goMag_End']+df['Llever_xxLL_other_End'])
        df['mag_sss_start_End_Prob']=df['mag_ssss_goL_End']/(df['mag_ssss_nostart_End']+df['mag_ssss_goL_End'])
        all_test_variables={'overstay':['Llever_xxLL_press_End','Llever_xxLL_goL_End'], #correct vs over-stay
                        #'B':['Rlever_LLRR_goMag_End','Rlever_LLRR_press_End','Rlever_LLRR_goR_End'], #correct vs stay
                        #'C':['Rlever_xLLR_press_End','Rlever_xLLR_goL_End'], #correct vs switch
                        'premature':['Llever_1L_goR_End','Rlever_1L_press_End'],
                        #'start':['mag_ssss_goL_End', 'mag_ssss_nostart_End', ],# incorrect start
                        'all_three':['mag_ssss_goL_End','Llever_1L_press_End','Llever_xxLL_goR_End'],
                        'prob':['mag_sss_start_End_Prob','Llever_1L_press_End_Prob','Llever_2L_switch_End_Prob']} 
        #one bar plot per group; test_var, mean and sterr keep the values of the last group ('prob')
        for kk,test_var in all_test_variables.items():
            mean,cnt,sterr=barplot_means(df,dep_var,test_var)
            #figBP=barplot(mean,sterr,test_var,dep_var[0],ylabel='Responses per Trial')
            if kk=='prob':
                figBP=barplot(mean,sterr,test_var,dep_var[0],ylabel='Probability',transpose=True,legend=False)
            else:
                figBP=barplot(mean,sterr,test_var,dep_var[0],ylabel='Responses per Trial',transpose=True,legend=False)#,norm=norm)
    #statistical test for each variable: t-test when there are exactly two groups, otherwise ANOVA
    for tv in test_variables+test_var: # all_test_variables['all_three']+['rwd__End']: #
        if df[tv].isna().sum():
            #dropna() removes rows with a NaN in any column, not only in tv
            testdf=df.dropna()
            print('new mean (after dropping Nans) for',tv, testdf.groupby(dep_var)[tv].aggregate(['mean','std','count']))
        else:
            testdf=df
        #keep only the dependent variables that have more than one level
        new_dep_var=[]
        for dv in dep_var:
            if testdf[dv].nunique()>1:
                new_dep_var.append(dv)
            else:
                print('STATS: only 1 level for variable=', dv)
        if len(dep_var)>len(new_dep_var):
            print('proposed dependent variables:',dep_var,', new dependent variables:',new_dep_var)
        dep_var=new_dep_var #the reduced list carries over to the remaining variables
        if testdf[dep_var].nunique().sum()==2:
            #Welch t-test between the two levels of the single remaining variable
            unique_vals=np.unique(testdf[dep_var[0]])
            tt=ttest_ind(testdf[testdf[dep_var[0]]==unique_vals[0]][tv], testdf[testdf[dep_var[0]]==unique_vals[1]][tv], equal_var=False)
            print('\n*******',tv,'\n',tt,'\n mean:\n',mean[tv],'\n sterr:\n',sterr[tv])      
        else:
            #ANOVA with each dependent variable entered as a categorical factor
            dependents=['C('+dv+')' for dv in dep_var]
            model_statement=' ~ '+'+'.join(dependents)
            print('\n****************************************\n',tv, '=',model_statement,'\n')
            model=ols(tv+model_statement,data=testdf).fit()
            print (sm.stats.anova_lm (model, typ=2), '\n',model.summary())
            if len(dep_var)==1: 
                print('\npost-hoc\n',tv,sp.posthoc_ttest(testdf, val_col=tv, group_col=dep_var[0], p_adjust='holm'))
### generate text file for Igor ##
'''header='block   '+'   '.join(['banditQ2_PL'+prob for prob in p_choose_L['2'].keys()])
num_blocks=len(p_choose_L['2']['50:10'])
output=np.arange(num_blocks)*params['trials_per_block'][0] 
for phs,arr in p_choose_L['2'].items():
    output=np.column_stack((output,arr))
np.savetxt(files[1].split('.npz')[0]+'.txt',output,fmt='%6.4f', header=header)
'''