# -*- coding: utf-8 -*-
"""
Created on Fri Apr 23 21:14:25 2021

@author: kblackw1
"""
import numpy as np
import completeT_env as tone_discrim
import agent_twoQtwoSsplit as QL
from RL_TD2Q import RL
import RL_utils as rlu
from TD2Q_Qhx_graphs import Qhx_multiphase,Qhx_multiphaseNewQ
from scipy import optimize

def linear(x,m,b): #straight line for optimize.curve_fit of probability tracking
    return m*x+b

def plot_prob_tracking(ratio,fractionLeft,trials,showplot=True):
    ratio_list=[r for r in ratio.values()]
    fraction_list=[np.nanmean(f) for f in fractionLeft.values()]
    RMS=np.zeros(trials)
    for t in range(trials):
        #root summed squared difference between reward ratio and observed fraction, across conditions, for each run
        RMS[t]=np.sqrt(np.sum([np.square(ratio[k]-f[t]) for k,f in fractionLeft.items()]))
    RMSmean=np.mean(RMS)
    RMSstd=np.std(RMS)
    #same deviation measure, but computed from the across-run mean fraction per condition
    delta=np.sqrt(np.sum([np.square(ratio[k]-np.nanmean(f)) for k,f in fractionLeft.items()]))
    popt, pcov = optimize.curve_fit(linear,ratio_list,fraction_list)
    predict=popt[0]*np.array(ratio_list)+popt[1]
    if showplot:
        from matplotlib import pyplot as plt
        plt.figure()
        plt.plot(ratio_list,fraction_list,'k*',label='actual, diff='+str(round(delta,4)))
        plt.plot(ratio_list,predict,'ro',label='fit')
        plt.xlabel('prob reward ratio')
        plt.ylabel('prob observed')
        plt.legend()
        plt.show()
    return popt,pcov,delta,RMSmean,RMSstd,RMS
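
# Hedged usage sketch for plot_prob_tracking (hypothetical toy data, not simulation
# output): with three reward conditions and two runs each, near-perfect probability
# matching should yield a fitted slope near 1, intercept near 0, and small delta/RMS:
#   ratio={'90:10':0.9,'50:50':0.5,'10:90':0.1}
#   fractionLeft={'90:10':[0.88,0.92],'50:50':[0.49,0.51],'10:90':[0.11,0.09]}
#   popt,pcov,delta,RMSmean,RMSstd,RMS=plot_prob_tracking(ratio,fractionLeft,trials=2,showplot=False)
#   popt[0] (slope) ~1 and popt[1] (intercept) ~0 indicate good probability tracking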
    
def plot_histogram(fractionLeft,keys):
    ####### Histogram ###########
    from matplotlib import pyplot as plt
    plt.figure()
    plt.title('histogram '+','.join(keys))
    width=0.9/len(keys)
    for key in keys:
        hist,bins=np.histogram(fractionLeft[key])
        response1=hist[0]+hist[-1] #runs in the two extreme bins, i.e., near-exclusive choice of one side
        print('std of P(L) for',key,'=',np.nanstd(fractionLeft[key]))
        plt.hist(fractionLeft[key],rwidth=width,label=key+'='+str(response1))
    plt.xlabel('P(L)')
    plt.ylabel('number (out of '+str(len(fractionLeft[keys[0]]))+' runs)')
    plt.legend()
    plt.show()
    return

def plot_reaction_time(display_runs,all_RT):
    from matplotlib import pyplot as plt
    fig=None #remains None if display_runs is empty
    for r in display_runs:
        fig,ax=plt.subplots()
        fig.suptitle('agent RT '+str(r))
        ax.plot(all_RT[r])
    plt.show()
    return fig

def plot_prob_traject(data,params,show_plot=True,title=''):
    p_choose_L={}
    for k in data.keys():
        p_choose_L[k]=data[k][(('Pport', '6kHz'), 'left')]['mean']/(data[k][(('Pport', '6kHz'), 'left')]['mean']+data[k][(('Pport', '6kHz'), 'right')]['mean'])
    p_choose_k_sorted=dict(sorted(p_choose_L.items(),key=lambda item: float(item[0].split(':')[0])/float(item[0].split(':')[1]),reverse=True))
    if show_plot:
        from matplotlib import pyplot as plt   
        plt.figure()
        plt.suptitle(title+' wt learn:'+str(params['wt_learning']))
        
        colors=plt.get_cmap('inferno') #plasma, viridis, inferno or magma possible
        color_increment=int((len(colors.colors)-40)/max(len(data.keys())-1,1)) #subtract 40 to avoid colors that are too light
        for k,key in enumerate(p_choose_k_sorted.keys()):
            cnum=k*color_increment
            plt.plot(p_choose_k_sorted[key],color=colors.colors[cnum],label=key)
        plt.legend()
        plt.ylabel('prob(choose L)')
        plt.xlabel('block')
    return p_choose_k_sorted
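
# The per-block quantity plotted above is P(choose L) = mean_left/(mean_left+mean_right),
# computed from across-run mean response counts at the ('Pport','6kHz') state; conditions
# are sorted by their L:R reward-probability ratio so the legend runs from L-favored to R-favored.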

def shift_stay_list(acq,all_counts,rwd,loc,tone,act,r):
    def count_shift_stay(rwd_indices,same,different,counts,r):
        for phase in rwd_indices.keys(): 
            for index in rwd_indices[phase]:
                #classify the next poke-port response: same action = stay, other action = shift
                if index+1 in same[phase]:
                    counts[phase]['stay'][r]+=1
                elif index+1 in different[phase]:
                    counts[phase]['shift'][r]+=1
        return counts
    responses={};total={}
    actions=['left','right']
    yes_rwd={act:{} for act in actions}
    no_rwd={act:{} for act in actions}
    for phase,rl in acq.items():
        res=rl.results
        responses[phase]=[list(res['state'][i])+[(res['action'][i])]+[(res['reward'][i+1])] for i in range(len(res['reward'])-1) if res['state'][i]==(loc['Pport'],tone['6kHz'])]    
        for action in actions:
            yes_rwd[action][phase]=[i for i,lst in enumerate(responses[phase]) if lst==[loc['Pport'],tone['6kHz'],act[action],rwd['reward']]]
            no_rwd[action][phase]=[i for i,lst in enumerate(responses[phase]) if lst==[loc['Pport'],tone['6kHz'],act[action],rwd['base']]]
    for action in actions:
        total[action]={phase:sorted(yes_rwd[action][phase]+no_rwd[action][phase]) for phase in acq.keys()}

    all_counts['left_rwd']=count_shift_stay(yes_rwd['left'],total['left'],total['right'],all_counts['left_rwd'],r)
    all_counts['left_none']=count_shift_stay(no_rwd['left'],total['left'],total['right'],all_counts['left_none'],r)
    all_counts['right_rwd']=count_shift_stay(yes_rwd['right'],total['right'],total['left'],all_counts['right_rwd'],r)
    all_counts['right_none']=count_shift_stay(no_rwd['right'],total['right'],total['left'],all_counts['right_none'],r)
    return all_counts,responses
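
# Minimal sketch of the inner count_shift_stay logic (hypothetical indices): if the
# agent chose left on poke-port trials 0 and 2 and right on trial 1, a rewarded left
# on trial 0 followed by the right choice on trial 1 counts as a 'shift'; a rewarded
# left on trial 2 with no following poke-port trial counts as neither:
#   rwd_indices={'50:50':[0,2]}; same={'50:50':[0,2]}; different={'50:50':[1]}
#   counts={'50:50':{'stay':[0],'shift':[0]}}
#   after counting (r=0): counts['50:50']['shift'][0]==1 and counts['50:50']['stay'][0]==0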

def calc_fraction_left(traject_dict,runs):
    fractionLeft={k:[] for k in traject_dict.keys()}
    noL={k:0 for k in traject_dict.keys()};noR={k:0 for k in traject_dict.keys()}
    ratio={}
    for k in traject_dict.keys():
        for run in range(runs):
            a=np.sum(traject_dict[k][(('Pport', '6kHz'),'left')][run])
            b=np.sum(traject_dict[k][(('Pport', '6kHz'),'right')][run])
            if (a+b)>0:
                fractionLeft[k].append(round(a/(a+b),4))
            else:
                fractionLeft[k].append(-1) #sentinel: no left or right poke-port responses this run
            if a==0:
                noL[k]+=1
            if b==0:
                noR[k]+=1
        R=float(k.split(':')[1])
        L=float(k.split(':')[0])
        ratio[k]=L/(L+R)
    return fractionLeft,noL,noR,ratio
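
# Hedged sketch for calc_fraction_left (toy trajectories, hypothetical): per-run left and
# right response counts are summed across blocks, so a run with 9 left and 1 right presses
# gives fractionLeft 0.9; the condition key '90:10' yields ratio 90/(90+10)=0.9:
#   toy={'90:10':{(('Pport','6kHz'),'left'):[[5,4],[8,1]],
#                 (('Pport','6kHz'),'right'):[[1,0],[1,1]]}}
#   fL,noL,noR,ratio=calc_fraction_left(toy,runs=2)
#   fL['90:10']==[0.9, 0.8182] and ratio['90:10']==0.9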

def combined_bandit_Qhx_response(random_order,num_blocks,traject_dict,Qhx,boundaries,params,phases):
    from matplotlib.gridspec import GridSpec
    import matplotlib.pyplot as plt
    from TD2Q_Qhx_graphs import agent_response,plot_Qhx_2D
    ept=params['events_per_trial']
    trials_per_block=params['trials_per_block']
    fig=plt.figure()
    gs=GridSpec(2,2) # 2 rows, 2 columns
    ax1=fig.add_subplot(gs[0,:]) # First row, span all columns
    ax2=fig.add_subplot(gs[1,0]) # 2nd row, 1st column
    ax3=fig.add_subplot(gs[1,1]) # 2nd row, 2nd column
    agent_response([-1],random_order,num_blocks,traject_dict,trials_per_block,fig,ax1)
    fig=plot_Qhx_2D(Qhx,boundaries,ept,phases,fig=fig,ax=[ax2,ax3])  
    
    #add subplot labels 
    letters=['A','B']
    fsize=14
    blank=0.03
    label_inc=(1-2*blank)/len(letters)
    for row in range(len(letters)):
        y=(1-blank)-(row*label_inc) #subtract because 0 is at bottom
        fig.text(0.02,y,letters[row], fontsize=fsize)
    fig.tight_layout()
    return fig

def accum_meanQ(Qhx,mean_Qhx):
    for q in Qhx.keys():
        for st in Qhx[q].keys():
            for aa in Qhx[q][st].keys():
                mean_Qhx[q][st][aa].append(Qhx[q][st][aa])
    return mean_Qhx

def calc_meanQ(mean_Qhx,all_Qhx):
    for q in mean_Qhx.keys():
        for st in mean_Qhx[q].keys():
            for aa in mean_Qhx[q][st].keys():
                print('calc_mean',q,st,aa,np.shape(mean_Qhx[q][st][aa]))
                mean_Qhx[q][st][aa]=np.mean(mean_Qhx[q][st][aa],axis=0)
                print('after calc_mean',q,st,aa,np.shape(mean_Qhx[q][st][aa]))
    all_Qhx.append(mean_Qhx)
    return mean_Qhx,all_Qhx
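
# Structure assumed by accum_meanQ and calc_meanQ: Qhx is nested as
#   Qhx[q][state][action] -> 1D array of Q values over events, e.g.
#   Qhx[0][('Pport','6kHz')]['left']
# accum_meanQ appends one such trace per run; calc_meanQ then averages the stacked
# traces element-wise (axis=0). It is only applied when use_oldQ is False, since with
# a shared Q carried across randomly ordered phases the across-run mean would be
# meaningless (see the sorting comment near Qhx_multiphaseNewQ below).
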
def opal_params(params):
    ################# For OpAL ################
    params['numQ']=2
    params['Q2other']=0
    params['decision_rule']='delta'
    params['alpha']=[0.1,0.1]#[0.2,0.2]#
    params['beta_min']=1
    params['beta']=1
    params['gamma']=0.1 #alpha_c
    params['state_thresh']=[0.75,0.625]
    params['initQ']=1 #do not split states, initialize Q values to 1
    params['D2_rule']='Opal'
    return params

if __name__ == '__main__':
    step1=False
    if step1:
        from Bandit1stepParam import params, env_params, states,act, Rbandit, Tbandit
        from Bandit1stepParam import loc, tone, rwd
    else:
        from BanditTaskParam import params, env_params, states,act, Rbandit, Tbandit
        from BanditTaskParam import loc, tone, rwd,include_wander
    events_per_trial=params['events_per_trial']  #this is task specific
    trials=100 
    numevents= events_per_trial*trials
    runs=40 #Hamid et al uses 14 rats. 40 gives smooth trajectories
    noise=0.15 #0.15 if starting from oldQ, 0.05 if starting anew each phase. Make noise or state_thresh small enough to minimize new states during acquisition.
    #control output
    printR=False #print environment Reward matrix
    Info=False #print information for debugging
    plot_hist=1 #1: plot final Q; 2: plot the time since last reward, etc.
    plot_Qhx=1 #1: plot agent responses; 2: additionally 2D plot of Q dynamics; 3: additionally 3D plot
    print_shift_stay=True
    save_data=True
    risk_task=False
    action_items=[(('Pport','6kHz'),a) for a in ['left','right']] #act.keys()]+[(('start','blip'),a) for a in act.keys()]
    ########### #For each run, randomize the order of this sequence #############
    prob_sets={'50:50':{'L':0.5, 'R': 0.5},'10:50':{'L':0.1,'R':0.5},
            '90:50':{'L':0.9, 'R': 0.5},'90:10':{'L':0.9, 'R': 0.1},
            '50:90':{'L':0.5, 'R': 0.9},'50:10':{'L':0.5, 'R': 0.1},'10:90':{'L':0.1,'R':0.9}}
    #           '91:49':{'L':0.91, 'R': 0.49},'49:91':{'L':0.49, 'R': 0.91},'49:11':{'L':0.49, 'R': 0.11},'11:49':{'L':0.11,'R':0.49}} 
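    #sort conditions by L:R reward-probability ratio, highest first,
    #e.g. '90:10'->9.0 precedes '50:50'->1.0 precedes '10:90'->0.111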
    prob_sets=dict(sorted(prob_sets.items(),key=lambda item: float(item[0].split(':')[0])/float(item[0].split(':')[1]),reverse=True))

    #prob_sets={'20:80':{'L':0.2,'R':0.8},'80:20':{'L':0.8,'R':0.2}}
    if risk_task:
        prob_sets={'100:100':{'L':1,'R':1},'100:50':{'L':1,'R':0.5},'100:25':{'L':1,'R':0.25},'100:12.5':{'L':1,'R':0.125}}
    learn_phases=list(prob_sets.keys())
    figure_sets=[list(prob_sets.keys())]
    traject_items={phs:action_items+['rwd'] for phs in learn_phases}
    histogram_keys=[]#['50:50','10:50'] #to plot histogram of P(L)

    cues=[]
    trial_subset=int(0.1*numevents) #display mean reward and count actions over the first and last trial_subset events
    #update some parameters of the agent
    params['Q2other']=0.0
    params['numQ']=2
    params['wt_learning']=False
    params['distance']='Euclidean'
    params['beta_min']=0.5 #increased exploration when rewards are low
    params['beta']=1.5
    params['beta_GPi']=10
    params['gamma']=0.82
    params['moving_avg_window']=3  #in units of trials; the actual window is this times the number of events per trial
    params['decision_rule']= None #'delta' #'mult' #
    params['initQ']=-1 #-1 means do state splitting; if initQ=0, 1 or 10, initialize Q to that value and don't split
    params['D2_rule']= None #'Ndelta' #'Bogacz' #'Opal' ### Opal: use Opal update without critic; Ndelta: calculate delta for N matrix from N values
    params['step1']=step1
    use_oldQ=True
    params['use_Opal']=False
    divide_rwd_by_prob=False
    non_rwd=rwd['base'] #rwd['base'] or rwd['error'] #### base is better
    params['Da_factor']=1
    if params['distance']=='Euclidean':
        #state_thresh={'Q1':[0.875,0],'Q2':[0.875,0.75]} #For Euclidean distance
        #alpha={'Q1':[0.6,0],'Q2':[0.6,0.3]}    #For Euclidean distance
        state_thresh={'Q1':[0.875,0],'Q2':[0.75,0.625]} #For normalized Euclidean distance
        alpha={'Q1':[0.6,0],'Q2':[0.4,0.2]}    #For normalized Euclidean distance, 2x discrim values works with 100 trials

    else:
        state_thresh={'Q1':[0.22, 0.22],'Q2':[0.20, 0.22]} #for Gaussian mixture (?) distance
        alpha={'Q1':[0.4,0],'Q2':[0.4,0.2]}    #for Gaussian mixture (?); [0.62,0.19] for beta=0.6, 1Q or 2Q

    params['state_thresh']=state_thresh['Q'+str(params['numQ'])] #for Euclidean distance, no noise
    #a lower threshold produces more states under the Euclidean distance rule
    params['alpha']=alpha['Q'+str(params['numQ'])]

    ################# For OpAL ################
    if params['use_Opal']: #use critic instead of RPEs, and use Opal learning rule
        params=opal_params(params)
    ######################################  
    traject_title='num Q: '+str(params['numQ'])+', beta:'+str(params['beta_min'])+':'+str(params['beta'])+', non_rwd:'+str(non_rwd)+',rwd/p:'+str(divide_rwd_by_prob)
    epochs=['Beg','End']

    keys=rlu.construct_key(action_items +['rwd'],epochs)
    resultslist={phs:{k+'_'+ep:[] for k in keys.values() for ep in epochs} for phs in learn_phases}
    traject_dict={phs:{ta:[] for ta in traject_items[phs]} for phs in learn_phases}

    #count number of responses to the following actions:
    results={phs:{a:{'Beg':[],'End':[]} for a in action_items+['rwd']} for phs in learn_phases}

    ### to plot performance vs trial block
    trials_per_block=10
    events_per_block=trials_per_block* events_per_trial
    num_blocks=int((numevents+1)/events_per_block)
    params['events_per_block']=events_per_block
    params['trials_per_block']=trials_per_block
    params['trial_subset']=trial_subset
    resultslist['params']={p:[] for p in params.keys()}

    random_order=[]
    key_list=list(prob_sets.keys())
    ######## Initialize dictionaries storing stay/shift counts
    all_counts={'left_rwd':{},'left_none':{},'right_rwd':{},'right_none':{}}
    for key,counts in all_counts.items():
        for phase in learn_phases:
            counts[phase]={'stay':[0]*runs,'shift':[0]*runs}
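    #resulting structure (sketch): all_counts[outcome][phase]={'stay':[one count per run],'shift':[one count per run]},
    #e.g. all_counts['left_rwd']['50:50']['stay'][r] counts rewarded-left poke-port trials
    #on run r that were followed by another left choice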
    Qhx_states=[('Pport','6kHz'),('start','blip')]
    Qhx_actions=['left','right','center']
    if step1:
        extra_acts=[]
        Qhx_states=[('Pport','6kHz')]
        Qhx_actions=['left','right']
    else:
        extra_acts=['hold'] #to make task simpler and demonstrate results not due to complexity of task
        if include_wander:
            extra_acts=['hold', 'wander']
    wrong_actions={aaa:[0]*runs for aaa in extra_acts} 
    all_beta=[];all_lenQ=[];all_Qhx=[]; all_bounds=[]; all_ideals=[];all_RT=[]
    mean_Qhx={q:{st:{aa:[] for aa in Qhx_actions } for st in Qhx_states} for q in range(params['numQ'])}
    for r in range(runs):
        #randomize prob_sets
        acqQ={};acq={};beta=[];lenQ={q:[] for q in range(params['numQ'])};RT=[]
        random_order.append([k for k in key_list]) #keep track of order of probabilities
        print('*****************************************************\n************** run',r,'prob order',key_list)
        for phs_num,phs in enumerate(key_list):
            prob=prob_sets[phs]
            print('$$$$$$$$$$$$$$$$$$$$$ run',r,'prob',phs,prob, 'phase number',phs_num)
            #do not scale these rewards by prob since the experiments did not
            if not step1:
                Tbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['success']),prob['L']),((loc['Lport'],tone['error']),1-prob['L'])] #hear tone in poke port, go left, in left port/success
                Tbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['success']),prob['R']),((loc['Rport'],tone['error']),1-prob['R'])]
            if divide_rwd_by_prob:
                Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward']/prob['L'],prob['L']),(non_rwd,1-prob['L'])]   #left port rewarded with prob['L'], reward scaled by 1/prob
                Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward']/prob['R'],prob['R']),(non_rwd,1-prob['R'])]
            elif risk_task:
                Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],prob['L']),(non_rwd,1-prob['L'])]   #left is the safe lever
                Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward']*4,prob['R']),(non_rwd,1-prob['R'])] #right is the risky lever, providing 4x reward
            else:
                Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],prob['L']),(non_rwd,1-prob['L'])]   #lick in left port rewarded with prob['L']
                Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward'],prob['R']),(non_rwd,1-prob['R'])] #lick in right port rewarded with prob['R']
            #for k,v in Rbandit.items():
            #    print(k,v)
            if use_oldQ:
                acq[phs] = RL(tone_discrim.completeT, QL.QL, states,act,Rbandit,Tbandit,params,env_params,printR=printR,oldQ=acqQ)
            else:
                acq[phs] = RL(tone_discrim.completeT, QL.QL, states,act,Rbandit,Tbandit,params,env_params,printR=printR) #start each epoch from init
            acq[phs].agent.Da_factor=params['Da_factor']
            results,acqQ=rlu.run_sims(acq[phs], phs,numevents,trial_subset,action_items,noise,Info,cues,r,results,phist=plot_hist)
            for aaa in extra_acts:
                wrong_actions[aaa][r]+=acq[phs].results['action'].count(act[aaa])
            traject_dict=acq[phs].trajectory(traject_dict, traject_items,events_per_block)
            #print ('prob complete',acq.keys(), 'results',results[phs],'traject',traject_dict[phs])
            beta.append(acq[phs].agent.learn_hist['beta'])
            RT.append([np.mean(acq[phs].agent.RT[x*events_per_trial:(x+1)*events_per_trial]) for x in range(trials)] )
            for q,qlen in acq[phs].agent.learn_hist['lenQ'].items():
                lenQ[q].append(qlen)
            #print(' !!!!!!!!!!!!!!!! End of phase', len(acq),'phase=', acq[phs].name, 'Qhx shape=',np.shape(acq[phs].agent.Qhx[0]))
        np.random.shuffle(key_list) #shuffle after run complete, so that first run does 50:50 first    
        ###### Count stay vs shift 
        all_counts,responses=shift_stay_list(acq,all_counts,rwd,loc,tone,act,r)
        #store beta, lenQ, Qhx, boundaries, ideal_states from the set of phases of a single run/agent
        all_beta.append([b for bb in beta for b in bb])
        all_RT.append([b for bb in RT for b in bb])
        all_lenQ.append({q:[b for bb in lenQ[q] for b in bb] for q in lenQ.keys()})
        agents=list(acq.values()) 
        if use_oldQ:
            Qhx, boundaries,ideal_states=Qhx_multiphase(Qhx_states,Qhx_actions,agents,params['numQ'])
        else:  #### sort agents by name (prob), otherwise the mean will be meaningless
            Qhx, boundaries,ideal_states=Qhx_multiphaseNewQ(Qhx_states,Qhx_actions,agents,params['numQ'])
        all_bounds.append(boundaries)
        all_Qhx.append(Qhx)
        all_ideals.append(ideal_states)
        mean_Qhx=accum_meanQ(Qhx,mean_Qhx)
    if not use_oldQ: #do not average across Qvalues if a) starting from previous and b) random order
        mean_Qhx,all_Qhx=calc_meanQ(mean_Qhx,all_Qhx)
        #print('mean_Qhx shape',np.shape(all_Qhx[-1][0][('Pport', '6kHz')]['left']), np.shape(all_Qhx[-1][0][('Pport', '6kHz')]['left']) )
        #random_order.append(sorted(key_list, key=lambda x: float(x.split(':')[0])/float(x.split(':')[1])))
    all_ta=[];output_data={}
    for phs in traject_dict.keys():
        output_data[phs]={}
        for ta in traject_dict[phs].keys():
            all_ta.append(ta)
            output_data[phs][ta]={'mean':np.mean(traject_dict[phs][ta],axis=0),'sterr':np.std(traject_dict[phs][ta],axis=0)/np.sqrt(runs-1)}
    all_ta=list(set(all_ta))
    #move reward to front
    all_ta.insert(0, all_ta.pop(all_ta.index('rwd')))
    for p in resultslist['params'].keys():
        resultslist['params'][p].append(params[p])
    resultslist=rlu.save_results(results,keys,resultslist)

    print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
    print(' Using',params['numQ'], 'Q, alpha=',params['alpha'],'thresh',params['state_thresh'], 'beta=',params['beta'],'runs',runs,'of total events',numevents)
    print(' apply learning_weights:',[k+':'+str(params[k]) for k in params.keys() if k.startswith('wt')])
    print(' D2_rule=',params['D2_rule'],'decision rule=',params['decision_rule'],'split/initQ=',params['initQ'],'critic=',params['use_Opal'])
    print('counts from ',trial_subset,' events (',events_per_trial,' events per trial)          BEGIN    END    std over ',runs,'runs')
    for phase in results.keys():
        for sa,counts in results[phase].items():
            print(phase,prob_sets[phase], sa,':::',np.round(np.mean(counts['Beg']),2),'+/-',np.round(np.std(counts['Beg']),2),
                ',', np.round(np.mean(counts['End']),2),'+/-',np.round(np.std(counts['End']),2))
            if sa in resultslist[phase]:
                print( '            ',sa,':::',[round(val,3) for lst in resultslist[phase][sa] for val in lst] )

    print('$$$$$$$$$$$$$ total End reward=',np.sum([np.mean(results[k]['rwd']['End']) for k in results.keys()]))
    print('divide by reward prob=',divide_rwd_by_prob,',non reward value', non_rwd)

    print('\n************ fraction Left ***************')
    fractionLeft,noL,noR,ratio=calc_fraction_left(traject_dict,runs)
    import persever as p
    if '50:50' in prob_sets.keys():
        persever,prior_phase=p.perseverance(traject_dict,runs,random_order,'50:50')
    for k in fractionLeft.keys():
        print(k,round(ratio[k],2),'mean Left',round(np.nanmean(fractionLeft[k]),2),
              ', std',round(np.nanstd(fractionLeft[k]),2), '::: runs with: no response', fractionLeft[k].count(-1),
    #         ', no L',results[k][(('Pport', '6kHz'),'left')]['End'].count(0),', no R',results[k][(('Pport', '6kHz'),'right')]['End'].count(0))
              ', no L',noL[k],', no R',noR[k])
    if print_shift_stay:
        print('\n************ shift-stay ***************')
        for phs in ['50:50']:#all_counts['left_rwd'].keys():
            print('\n*******',phs,'******')
            for key,counts in all_counts.items():
                print(key,':::\n   stay',counts[phs]['stay'],'\n   shift',counts[phs]['shift'])
                ss_ratio=[stay/(stay+shift) for stay,shift in zip(counts[phs]['stay'],counts[phs]['shift']) if stay+shift>0 ]
                events=[(stay+shift) for stay,shift in zip(counts[phs]['stay'],counts[phs]['shift'])]
                print(key,'mean stay=',round(np.mean(ss_ratio),3),'+/-',round(np.std(ss_ratio),3), 'out of', np.mean(events), 'events per run')

    print('wrong actions',[(aaa,np.mean(wrong_actions[aaa])) for aaa in wrong_actions.keys()])
    if save_data:
        import datetime
        dt=datetime.datetime.today()
        date=str(dt).split()[0]
        key_params=['numQ','Q2other','beta_GPi','decision_rule','beta_min','beta','gamma','use_Opal','step1','D2_rule']
        fname_params=key_params+['initQ']
        fname='Bandit'+date+'_'.join([k+str(params[k]) for k in fname_params])+'_rwd'+str(rwd['reward'])+'_'+str(rwd['none'])+'_wander'+str(include_wander)
        #np.savez(fname,par=params,results=resultslist,traject=output_data)
        np.savez(fname,par=params,results=resultslist,traject=output_data,traject_dict=traject_dict,shift_stay=all_counts,rwd=rwd)
    rlu.plot_trajectory(output_data,traject_title,figure_sets)
    plot_prob_traject(output_data,params,title=traject_title)
    if len(histogram_keys):
        plot_histogram(fractionLeft,histogram_keys)
    popt,pcov,delta,RMSmean,RMSstd,_=plot_prob_tracking(ratio,fractionLeft,runs)
    tot_rwd=np.sum([np.mean(results[k]['rwd']['End']) for k in results.keys()])
    rwd_var=np.sum([np.var(results[k]['rwd']['End']) for k in results.keys()])
    print('$$$$$$$$$$$$$  beta min,max,Gpi=',params['beta_min'],params['beta'],params['beta_GPi'],'gamma=',params['gamma'],'rule=',params['decision_rule']\
    ,'step1=',step1,'\n$$$$$$$$$$$$$ total End reward=',round(tot_rwd,2),'+/-',round(np.sqrt(rwd_var),2))
    print({round(ratio[k],3): round(np.nanmean(fractionLeft[k]),3) for k in fractionLeft.keys()} )
    print('quality of prob tracking: slope=', round(popt[0],4),'+/-', round(np.sqrt(pcov[0,0]),4), 'delta=',round(delta,4),'RMS=',round(RMSmean,4),'+/-',round(RMSstd,4))
    if '50:50' in prob_sets.keys():
        print('perseverance',', no L',noL['50:50'],', no R',noR['50:50'],'fraction',(noL['50:50']+noR['50:50'])/runs)

    ag_dict={}
    for ag in agents:
        ag_dict[ag.name]= {'Q':ag.agent.Q, 'ideal_states':ag.agent.ideal_states,'states':states,'actions':act}
        if hasattr(ag.agent,'V'):
            ag_dict[ag.name]['V']=ag.agent.V
    if save_data:
        np.save('Q_V_'+fname,ag_dict) #save the full dict of final Q values, not just the last agent
    if plot_Qhx:
        from TD2Q_Qhx_graphs import agent_response
        display_runs=range(min(3,runs))
        figs=agent_response(display_runs,random_order,num_blocks,traject_dict,trials_per_block,norm=1/trials_per_block)
        phases=list(acq.keys())
        if save_data:
            if runs>50: #only save mean_Qhx to reduce file size
                all_Qhx=[mean_Qhx]
                all_bounds=[all_bounds[0]] #all are the same, no need to save all of them
                all_ideals=[all_ideals[0]] #not all the same, but not used in Qhx graph
            np.savez('Qhx'+fname,all_Qhx=all_Qhx,all_bounds=all_bounds,params=params,phases=phases,
                all_ideals=all_ideals,random_order=random_order,num_blocks=num_blocks,all_beta=all_beta,all_lenQ=all_lenQ,all_RT=all_RT)
    if plot_Qhx==2:
        from TD2Q_Qhx_graphs import plot_Qhx_2D 
        #plot Qhx and agent response       
        if use_oldQ:
            fig=plot_Qhx_2D(all_Qhx[display_runs[0]],boundaries,params['events_per_trial'],phases)
        else:
            fig=plot_Qhx_2D(mean_Qhx,boundaries,params['events_per_trial'],phases)
            
        ########## combined figure ##############   
        if len(Qhx[0].keys())==1:
            fig=combined_bandit_Qhx_response(random_order,num_blocks,traject_dict,Qhx,boundaries,params,phases)
    elif plot_Qhx==3: 
        ### B. 3D plot Q history for selected actions, for all states, one graph per phase
        for rl in acq.values():
            rl.agent.plot_Qdynamics(['left','right'],'surf',title=rl.name)
    '''
    select_states=['suc','Ppo']
    for i in range(params['numQ']):
        acq['50:50'].agent.visual(acq['50:50'].agent.Q[i],labels=acq['50:50'].state_to_words(i,noise),title='50:50 Q'+str(i),state_subset=select_states)
    '''