#-*- coding: utf-8 -*-
"""
Created on Fri Nov 20 13:53:52 2020

@author: kblackw1
"""
'''    
Initial analysis of parameter variations.  Just looks at mean response
re-run best params (and save data) to show trajectories - calculate moving average of reward and some actions
'''

import numpy as np
import glob
import pandas as pd
import operator
from matplotlib import pyplot as plt
import matplotlib as mpl

pd.set_option("display.max_columns", 20)
pd.set_option('display.width',240)

def read_data(fnames,list_params,numQ=2,blockDA=False):
    '''Read parameter-sweep results from npz files into a single DataFrame.

    Each file must contain a pickled dictionary stored under 'allresults'.
    Its 'params' entry maps parameter names to one value per parameter
    combination; every other entry is a result column, either flat
    ({column: values}) or nested ({phase: {key: values}}, flattened here
    into 'phase_key' columns).  If the file also holds 'reslist' (same
    layout, with a collection of values per parameter combination), a
    '<column>_STD' column is added using np.std of each collection.

    Parameters
    ----------
    fnames : sequence of npz file names
    list_params : parameter names whose entries are indexable (one value per
        Q matrix); each is split into columns <name>1 .. <name><numQ>
    numQ : number of elements taken from each list parameter
    blockDA : if True and the alpha2 column is constant, alpha2 is replaced
        by a tiled range derived from the alpha1 values

    Returns
    -------
    newdf : concatenated results with duplicate parameter combinations dropped
    key_combos : result column names of the last file read (flattened names
        for nested results; parameter and _STD columns are not included)
    param_names : parameter column names of the last file read
    max_correct : reslist['params']['max_correct'] if present, otherwise 1

    Raises
    ------
    ValueError : if fnames is empty
    '''
    if not len(fnames):
        raise ValueError('read_data: no files to read')
    df=[]
    max_correct=1
    for f in fnames:
        #NOTE: allow_pickle=True unpickles arbitrary objects - only read trusted files
        #the context manager closes the archive once the dictionaries are extracted
        with np.load(f,allow_pickle=True) as data:
            results=data['allresults'].item()
            reslist=data['reslist'].item() if 'reslist' in data.files else None
        par=results['params'] #extract parameters
        del results['params']
        phases=list(results.keys()) #top level keys, needed again to walk through reslist
        if isinstance(results[phases[0]],dict):
            #if results is dictionary of dictonaries, need to flatten it
            results={phase+'_'+key:vals for phase in phases for key,vals in results[phase].items()}
        #names of result columns, captured before parameter and STD columns are added
        key_combos=list(results.keys())
        #add parameter values to the dictionary
        new_list_params=[];single_params=[]
        for k in par.keys():
            if k in list_params:
                #split list parameters (e.g. state_thresh) into one series per Q
                for i in range(numQ):
                    results[k+str(i+1)]=[round(p[i],3) for p in par[k]]
                    new_list_params.append(k+str(i+1))
            elif isinstance(par[k][0],float):
                results[k]=[round(p,3) for p in par[k]]
                single_params.append(k)
            elif isinstance(par[k][0],dict): #NEEDS TESTING
                #one column per dictionary key; these are not added to the parameter names
                for xx in par[k][0].keys():
                    results[k+xx]=[p[xx] for p in par[k]]
            else:
                results[k]=par[k]
                single_params.append(k)
        #put back in correct alpha
        if blockDA and np.std(results['alpha2'])==0:
            a2inc=np.max(np.diff(results['alpha1']))/2
            a2range=np.arange(np.min(results['alpha1'])/2,1.1*np.max(results['alpha1'])/2,a2inc)
            repeats=int(len(results['alpha2'])/len(a2range))
            results['alpha2']=np.tile(a2range,repeats)
        ############# stdev
        if reslist is not None:
            if isinstance(reslist[phases[0]],dict):
                for phase in phases:
                    for key,vals in reslist[phase].items():
                        results[phase+'_'+key+'_STD']=[np.std(par_result) for par_result in vals]
            else:
                for col in phases: #corresponds to phase/action/epoch combo
                    results[col+'_STD']=[np.std(par_result) for par_result in reslist[col]]
            if 'max_correct' in reslist['params']:
                max_correct=reslist['params']['max_correct']
        #one row per parameter combination, one column per dictionary key
        dfsubset=pd.DataFrame.from_dict(results,orient='index').transpose()
        df.append(dfsubset)
    allresults=pd.concat(df).reset_index() #concatenate everything into one big dataframe
    #drop duplicate values
    param_names=new_list_params+single_params
    newdf=allresults.drop_duplicates(subset=param_names)
    print('size of df, before',len(allresults),'after',len(newdf),'params',param_names, 'max_correct', max_correct)
    return newdf,key_combos,param_names,max_correct

def plot_Q2_results(Q2df,key_combos):
    '''Plot gray-scale maps of each result column for the 2 Q simulations.

    One figure is created per entry of key_combos.  Each figure holds one
    panel per block of rows that starts at a state-threshold transition;
    within a panel the column's values are drawn against alpha1 (x axis) and
    alpha2 (y axis).  All panels of a figure share one gray scale, shown in a
    horizontal colorbar.

    Assumes Q2df is sorted so that each (state_thresh1,state_thresh2) pair
    occupies consecutive rows, one per (alpha1,alpha2) combination; a warning
    is printed for any block that contains more than one threshold value.

    Parameters
    ----------
    Q2df : DataFrame with columns state_thresh1, state_thresh2, alpha1,
        alpha2, numQ and every name in key_combos
    key_combos : names of the columns to plot

    Returns
    -------
    The last figure created (None if key_combos is empty).
    '''
    st=[];alph=[]
    for i in range(2):
        st.append(Q2df['state_thresh'+str(i+1)].values)
        alph.append(Q2df['alpha'+str(i+1)].values)
    #add one because np.diff places transition from 1 to 2 into slot 1, but new values begin in slot 2
    st_trans=[np.where(np.diff(s)>0)[0]+1 for s in st]
    #all state threshold transitions locations, including the beginning of the array
    st_transition=sorted(np.concatenate(([0],st_trans[0],st_trans[1])))
    rows=len(np.unique(alph[0]))
    cols=len(np.unique(alph[1]))
    block_len=rows*cols #number of alpha combinations per threshold block

    #now do the plotting
    figset=[]
    fig=None #stays None if there is nothing to plot
    for pa in key_combos:
        #squeeze=False keeps axes 2-D even when there is a single row or column of panels
        fig,axes=plt.subplots(len(np.unique(st[1])),len(np.unique(st[0])),squeeze=False)
        ax=fig.axes #flat list of panels, taken before the colorbar axes is added
        fig.suptitle(pa+'; numQ='+str(int(Q2df['numQ'].iloc[0])))
        zvals=(Q2df[pa]).to_numpy(dtype=float)
        vmin=np.min(zvals);vmax=np.max(zvals)
        for i,at in enumerate(st_transition):
            if len(np.unique(st[0][at:at+block_len]))>1 or len(np.unique(st[1][at:at+block_len]))>1:
                print('**************** uh oh ************',at)
            #transpose, because a0 has greater range than a1
            plotz=np.reshape(zvals[at:at+block_len],(rows,cols)).T
            ax[i].set_title(str(round(st[0][at],3))+','+str(round(st[1][at],3)),fontsize=8)
            #transpose alpha values the same way; x and y give the image extent
            y=np.reshape(alph[1][at:at+block_len],(rows,cols)).T[:,0]
            x=np.reshape(alph[0][at:at+block_len],(rows,cols)).T[0]
            im=ax[i].imshow(plotz,extent=[np.min(x),np.max(x),np.min(y),np.max(y)],cmap=plt.get_cmap('gray'),vmin=vmin,vmax=vmax,origin='lower')
        for ii in range(np.shape(axes)[0]):
            axes[ii,0].set_ylabel('alpha 2')
        for jj in range(np.shape(axes)[1]):
            axes[-1,jj].set_xlabel('alpha 1')
        #vmin and vmax are shared by all panels, so the last image represents the whole figure
        cax = fig.add_axes([0.27, 0.9, 0.5, 0.03])
        fig.colorbar(im, cax=cax, orientation='horizontal')
        fig.show()
        figset.append(fig)
    #figset collects every figure, but only the last one is returned (unchanged interface)
    return fig
def plot_Q1_results(Q2df,key_combos):
    '''Plot gray-scale maps of each result column for the 1 Q simulations.

    One figure is created with one panel per entry of key_combos; each panel
    shows the column's values against alpha1 (x axis) and state_thresh1
    (y axis), with its own gray scale and vertical colorbar.

    Assumes the rows are ordered so that the values reshape into
    (number of unique state_thresh1, number of unique alpha1).

    Parameters
    ----------
    Q2df : DataFrame of 1 Q results with columns alpha1, state_thresh1 and
        every name in key_combos (the parameter name is kept for backward
        compatibility)
    key_combos : names of the columns to plot

    Returns
    -------
    The figure.
    '''
    from mpl_toolkits.axes_grid1 import make_axes_locatable
    #use the dataframe that was passed in, instead of relying on a global named Q1df
    Q1df=Q2df
    x=Q1df['alpha1'].to_numpy(dtype=float)
    y=Q1df['state_thresh1'].to_numpy(dtype=float)
    cmap = mpl.cm.gray
    #squeeze=False returns a 2-D array of axes even when there is only one panel
    fig,axes=plt.subplots(len(key_combos),squeeze=False)
    ax=axes[:,0]
    for i,pa in enumerate(key_combos):
        zvals=(Q1df[pa]).to_numpy(dtype=float)
        vmin=np.min(zvals);vmax=np.max(zvals)
        plotz=np.reshape(zvals,(len(np.unique(y)),len(np.unique(x))))
        ax[i].imshow(plotz,extent=[np.min(x),np.max(x),np.min(y),np.max(y)],cmap=plt.get_cmap('gray'),vmin=vmin,vmax=vmax,origin='lower')
        ax[i].set_xlabel('alpha')
        ax[i].set_ylabel('state_thresh')
        ax[i].set_title(pa)
        #colorbar in its own axes to the right of the panel, scaled to this panel's range
        divider = make_axes_locatable(ax[i])
        cax = divider.append_axes("right", size="10%", pad=0.1)
        norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
        fig.colorbar(mpl.cm.ScalarMappable(norm=norm, cmap=cmap),
                     cax=cax, orientation='vertical')
    fig.show()
    return fig

########################### MAIN ################################################
############# Task specific filenames and events ################################
''' Variation in parameters simulated in December
sequence task reward output was number of times max reward obtained during trial_subset
trial_subset=int(0.05*numevents) = 0.05*10,000 = 500
#therefore, optimal rewards = 500/events_per_trial, which is 83 for HxLen=3 and 71 for HxLen=4
discrim task reward output was mean reward.  Optimal performance is:
    reward of 10 received on every third event, basic effort of -1 on two other events
    thus optimal reward is 8 for trial (per 3 events)
    Since mean was calculated over events (and not multiplied by 3 prior to Jan 3)
    optimal reward per trial is 8/3 prior to Jan 3 and 8 after Jan 3
'''
#Task to analyze; other values used: 'discrim', 'block_Da', 'block_Da_b6', 'discrim_b6'
paradigm = 'sequence'
#Directory holding the npz files; other directories used:
#Euclid/discrimRuleNone/, discrim/, discrim_beta6/, Euclid/sequenceRuleNone/
datadir = 'NormEuclid2021dec10/'
#keep lowering the percentiles until this many "best" parameter combinations have been found
min_param_combos = 5
HxLen = '4'  #only relevant for sequence task
op = operator.gt  #comparison applied to the action criterion when selecting best parameters
optimal_rwd = 8  #used for histogram if maximum reward is not included in the npz file
block = False  #blockDa is false unless otherwise specified

if paradigm != 'sequence':
    #these are the interesting state-action combinations
    phase_act = [
        'acquire_rwd__End', 'discrim_rwd__End', 'reverse_rwd__End',
        'acquire_Pport_6kHz_left_End', 'extinc_Pport_6kHz_left_End',
        'renew_Pport_6kHz_left_Beg', 'renew_Pport_6kHz_left_End',
        'discrim_Pport_6kHz_left_End', 'discrim_Pport_10kHz_right_End',
        'reverse_Pport_10kHz_left_End', 'reverse_Pport_6kHz_right_End',
    ]
    rwd_column = 'tot_rwd'
    #first event: what is max observed, 2nd event - more difficult task
    rwd_events = ['acquire_rwd__End', 'discrim_rwd__End']
    action_events = ['acquire_Pport_6kHz_left_End', 'renew_Pport_6kHz_left_Beg']
else:
    #'_Q1*' (HxLen 4) or '_Q1_all' was used for the Gaussian Mixture model
    Q1end = '_Q1_all'
    file_pattern = {
        'Q1': datadir+'Sequence_paramsHxLen'+HxLen+Q1end,
        'Q2': datadir+'Sequence_HxLen'+HxLen+'_Q2*q2o0.2*_all',
    }
    phase_act = []  #if empty, will use key_combos from function
    rwd_column = 'rwd__End'

#These are updated for some task variants
if paradigm == 'discrim':
    #discrim task - results stored as action events
    file_pattern = {
        'Q1': datadir+'Discrim2021-12-10_Q1_all',
        'Q2': datadir+'Discrim2021-12-10_Q2_q2o0.1_*',
    }

    '''#uncomment to read earlier data using 1st version of code,
       #saved different state_action combos, used the Gaussian mixture model; Q1 & Q2 learning rules the same
    phase_act=[p+'_End' for p in ['extinc_left','renew_left','acquire_rwd','discrim_rwd', 'reverse_rwd']]+['renew_left_Beg'] # if empty, will use key_combos from function
    #the next 2 events used to find the best results
    rwd_events=['acquire_rwd_End','discrim_rwd_End' ] #first event: what is max observed, 2nd event - more difficult task
    action_events[1]=['acquire_left_End','renew_left_Beg']
    '''
elif paradigm == 'block_Da_b6':
    file_pattern = {'Q1': 'none', 'Q2': 'Discrim_beta0.6blockDadip_Q2_*_0'}
    phase_act = [
        'acquire_rwd__End', 'discrim_rwd__End',
        'acquire_Pport_6kHz_left_End', 'extinc_Pport_6kHz_left_End',
        'discrim_Pport_6kHz_left_End', 'discrim_Pport_10kHz_right_End',
    ]
    action_events[1] = 'extinc_Pport_6kHz_left_End'  #extinc should work with DA_block
    op = operator.le
    block = True
elif paradigm == 'block_Da':
    file_pattern = {'Q1': 'none', 'Q2': datadir+'Discrim_blockDadip_Q2_*_alphaEnd0'}
    #if empty, will use key_combos from function
    phase_act = ['extinc_left_End', 'acquire_rwd_End', 'discrim_rwd_End', 'extinc_left_Beg']
    #first event: what is max observed, 2nd event - more difficult task
    action_events = ['acquire_left_End', 'extinc_left_End']
    rwd_events = ['acquire_rwd_End', 'discrim_rwd_End']
    op = operator.le
    block = True

######## Parameters relevant to all tasks
list_params = ['state_thresh', 'alpha']  #parameters with one value per Q matrix
sort_pars = ['state_thresh1', 'state_thresh2', 'alpha1', 'alpha2']
#accumulators filled by the Q2 and Q1 analyses
both_df = []
both_labels = []
par_names = []
############################ Analysis for 2 Q matrices ##############
#All simulations didn't run.  Need to re-run some, read in and concatenate several files
#read in all the results

print(file_pattern)
fnamesQ2 = glob.glob(file_pattern['Q2'] + '.npz')
if fnamesQ2:
    #read every matching file into one dataframe, sorted by the parameters that are present
    Q2df, key_combos, Q2par_names, max_correct = read_data(fnamesQ2, list_params, numQ=2, blockDA=block)
    sort_order = [sp for sp in sort_pars if sp in Q2par_names]
    Q2df.sort_values(sort_order, inplace=True)
    #mean and std are calculated before the summed reward column (if needed) is added
    q2mean, q2std = Q2df.mean(), Q2df.std()
    if rwd_column not in Q2df.columns:
        #total reward = sum over the reward events
        Q2df[rwd_column] = Q2df[rwd_events].sum(axis=1)
    both_df.append(Q2df)
    both_labels.append('Q2')
    par_names.append(sort_order)

    #columns to plot: the preset list if there is one, otherwise the End columns of the sequence task
    if phase_act:
        key_combos = phase_act
    elif paradigm == 'sequence':
        key_combos = [pa for pa in key_combos if pa.endswith('End')]

    #now create color maps for each state-action, for each set of state transition thresholds
    plot_Q2_results(Q2df, key_combos)
    
####################### Repeat for 1 Q matrix ######################
#read in all the results
fnamesQ1 = glob.glob(file_pattern['Q1'] + '.npz')
if fnamesQ1:
    #read every matching file into one dataframe, sorted by the parameters that are present
    Q1df, key_combos, Q1par_names, max_correct = read_data(fnamesQ1, list_params)
    sort_order = [sp for sp in sort_pars if sp in Q1par_names]
    Q1df.sort_values(sort_order, inplace=True)
    both_df.append(Q1df)
    both_labels.append('Q1')
    par_names.append(sort_order)

    #columns to plot: the preset list if there is one, otherwise the End columns of the sequence task
    if phase_act:
        key_combos = phase_act
    elif paradigm == 'sequence':
        key_combos = [pa for pa in key_combos if pa.endswith('End')]
    #plot
    plot_Q1_results(Q1df, key_combos)
    ##create data series with mean and std (before the summed reward column, if needed, is added)
    q1mean, q1std = Q1df.mean(), Q1df.std()
    if rwd_column not in Q1df.columns:
        #total reward = sum over the reward events
        Q1df[rwd_column] = Q1df[rwd_events].sum(axis=1)
######################## Summary over both Q1 and Q2 #####################
#one column of means and one of standard deviations for each set of files that was found
if len(fnamesQ1) and len(fnamesQ2):
    summary=pd.DataFrame({'Q1mean':q1mean,'Q1std':q1std,'Q2mean':q2mean,'Q2std':q2std})
elif len(fnamesQ2):
    summary=pd.DataFrame({'Q2mean':q2mean,'Q2std':q2std})
elif len(fnamesQ1):
    summary=pd.DataFrame({'Q1mean':q1mean,'Q1std':q1std})
else:
    #nothing below can run without data, so stop here with an explicit error
    raise FileNotFoundError('no files found for Q1 or Q2!!!')

##############  Finish creating summary df
#the rows of summary are the column names of the results; remove the rows for the _STD columns
drop_idx=[x for x in summary.index if x.endswith('_STD')]
summary.drop(drop_idx,axis=0,inplace=True)
#print(summary)

############### print results for best parameters #############################

#For each set of results, lower the quantile criterion step by step and print the
#parameter combinations that meet it, stopping once more than min_param_combos are found.
#The loop variable is called pnames so that the par_names list is not overwritten.
if paradigm=='sequence':
    #sequence task: select on the reward column only
    for label,df,pnames in zip(both_labels,both_df,par_names):
        stdkeys=[kc+'_STD' for kc in key_combos if kc+'_STD' in df.columns]
        found1=0 #largest number of selected rows printed so far
        for crit in np.arange(0.99,0.8,-0.01):
            rwd_crit=df[rwd_column].quantile(crit)
            best=df.loc[(df[rwd_column] >= rwd_crit)]
            if len(best) and len(best)>found1:
                found1=len(best)
                #/max_correct to convert to percent
                print('*** crit =',round(crit,3),'for',label,'\n',best[pnames],'\n',best[key_combos],'\n',best[stdkeys])
                if len(best)>min_param_combos:
                    break #break out of loop, once best performance found
            else:
                #reached whenever the selection did not grow at this criterion
                print('**** no records found for', label,'using criteria as alow as', round(crit,3))
else:
    #other tasks: select on the second reward event and the second action event
    for label,df,pnames in zip(both_labels,both_df,par_names):
        stdkeys=[kc+'_STD' for kc in key_combos if kc+'_STD' in df.columns]
        found1=0 #largest number of selected rows printed so far
        #row(s) with the maximum reward; independent of the criterion, so calculated once
        max_rwd=df[rwd_column].max()
        best_row=df[df[rwd_column]==max_rwd]
        for crit in np.arange(0.99,0.7,-0.01):
            rwd_crit=df[rwd_events[1]].quantile(crit)
            #when op selects small values, use the matching low quantile for the action
            if op==operator.le or op==operator.lt:
                act_crit=df[action_events[1]].quantile(1-crit)
            else:
                act_crit=df[action_events[1]].quantile(crit)
            best=df.loc[(df[rwd_events[1]] >= rwd_crit) & op(df[action_events[1]], act_crit)]
            if len(best) and len(best)>found1:
                found1=len(best)
                print('*** crit=',round(crit,3),'on',rwd_events[1],action_events[1],'for',label,'\n',best[pnames],'\n',best[key_combos],'\n',best[stdkeys])
                if len(best)>min_param_combos:
                    break #break out of loop, once best performance found
            else:
                #reached whenever the selection did not grow at this criterion
                print('**** no records found for', label,'using criteria of ', round(crit,3),rwd_events[1],round(rwd_crit,3),action_events[1],round(act_crit,3))
print(summary.loc[key_combos])

########### Graph showing effect of parameters on reward ##########
#### only for Q1 because too many dimensions for Q2 #### 
#reward columns measured at the end of a phase
rwd_combos=[a for a in key_combos if 'rwd' in a and 'End' in a]
#The paradigm names are 'block_Da' and 'block_Da_b6', so test for the prefix 'block_Da'.
#Also skip the graph if no Q1 files were read or there are no reward columns to plot.
if not paradigm.startswith('block_Da') and len(fnamesQ1) and len(rwd_combos):
    fig,axis=plt.subplots(nrows=len(rwd_combos),ncols=1,sharex=True,sharey=True)
    ax=fig.axes
    fig.suptitle(paradigm+' 1Q')
    x=Q1df['alpha1'].to_numpy(dtype=float)
    y=Q1df['state_thresh1'].to_numpy(dtype=float)
    #reshape so that each row holds the values for one state_thresh1
    grid=(len(np.unique(y)),len(np.unique(x)))
    plotx=np.reshape(x,grid)
    ploty=np.reshape(y,grid)
    for jj,pa in enumerate(rwd_combos):
        zvals=(Q1df[pa]).to_numpy(dtype=float)/max_correct
        plotz=np.reshape(zvals,grid)
        #one line per state_thresh1, labeled with its value
        for i,row in enumerate(plotz):
            ax[jj].plot(plotx[i],row,label=str(ploty[i][0]))
        ax[jj].set_ylabel(pa[0:7]+', % max)')
    ax[-1].set_xlabel('alpha')
    ax[-1].legend()
#copy plotz into igor?  name columns according to st, create alpha wave

######## histogram (pdf and CDF) showing effect of parameters on reward #######
#upper limit of the histogram and the label suffix depend on whether max_correct was stored
if max_correct==1:
    binmax=optimal_rwd #or np.max(Q2df[rwd_combos].max())
    add_label=''
else:
    binmax=100
    add_label=' (% max)'
#The paradigm names are 'block_Da' and 'block_Da_b6', so test for the prefix 'block_Da'
if paradigm.startswith('block_Da'):
    #one pdf panel and one CDF panel, with one line per reward column
    fig,axis=plt.subplots(nrows=1,ncols=2)
    fig.suptitle(paradigm)
    df=both_df[-1] #most recently added results
    for pa in rwd_combos:
        hist,bin_edges=np.histogram(df[pa]/max_correct,bins=25,range=(0,binmax))
        #plot against the bin centers
        plot_bins=[(bin_edges[i]+bin_edges[i+1])/2 for i in range (len(hist))]
        axis[0].plot(plot_bins,hist/np.sum(hist),label=pa[0:7])
        axis[1].plot(plot_bins,np.cumsum(hist/np.sum(hist)),label=pa[0:7])
    axis[0].set_ylabel('pdf')
    axis[1].set_ylabel('CDF')
    for kk in [0,1]:
        axis[kk].set_xlabel(' (fraction of optimal)')
        axis[kk].legend()
else:
    #one row of panels (fraction, CDF) per reward column, with one line per set of results
    save_hist={}
    fig,axis=plt.subplots(nrows=len(rwd_combos),ncols=2)
    ax=fig.axes
    if paradigm=='sequence':
        fig.suptitle(paradigm+', press Hx len='+HxLen)
    else:
        fig.suptitle(paradigm)
    for df,lbl in zip(both_df,both_labels):
        for jj,pa in enumerate(rwd_combos):
            hist,bin_edges=np.histogram(df[pa]/max_correct,bins=25,range=(0,binmax))
            #only the histogram of the last reward column is kept for each label
            save_hist[lbl]=hist
            #plot against the bin centers
            plot_bins=[(bin_edges[i]+bin_edges[i+1])/2 for i in range (len(hist))]
            ax[jj*2].plot(plot_bins,hist/np.sum(hist),label=lbl)
            ax[jj*2].set_ylabel('fraction')
            ax[jj*2+1].plot(plot_bins,np.cumsum(hist/np.sum(hist)),label=lbl)
            ax[jj*2+1].set_ylabel('CDF')
            for kk in [0,1]:
                ax_num=jj*2+kk
                ax[ax_num].set_xlabel(pa[0:7]+add_label)
                ax[ax_num].legend()
    #save bin centers and one column of counts per label to a text file
    hist_txt=plot_bins
    header='plot_bins'
    for k,v in save_hist.items():
        hist_txt=np.column_stack(([hist_txt,v]))
        header=header+'   '+k
    np.savetxt(paradigm+'_HxLen'+str(HxLen)+'_histogram.txt',hist_txt,header=header,fmt='%.3f')

#mutliply plot_bins (x values) by 100 for percent of optimal reward