# -*- coding: utf-8 -*-
"""
Created on Wed Aug  5 13:39:14 2020
2021 March: changed agent to use Euclidean distance instead of Gaussian mixture
        changed learning rule for Q2 to use decreases in predicted reward

@author: kblackw1
"""
import numpy as np
import copy
import completeT_env as tone_discrim
import agent_twoQtwoSsplit as QL
from RL_TD2Q import RL
import RL_utils as rlu
from TD2Q_Qhx_graphs import Qhx_multiphase         
from Discrim2stateProbT_twoQtwoS import select_phases, respond_decay

if __name__ == "__main__":
    from DiscriminationTaskParam2 import params,states,act
    events_per_trial=params['events_per_trial']  #this is task specific
    trials=200 #Iino: 180 trials for acquisition, then 160 trials for discrimination; or discrimination from the start using 60 trials of each tone per day (120 trials) * 3 days
    numevents= events_per_trial*trials
    runs=10 #10 for paper
    #control output
    plot_hist=0
    printR=False #print environment Reward matrix
    Info=False #print information for debugging
    #additional cues that are part of the state for the agent, but not environment
    #this means that they do not influence the state transition matrix
    context=[[0],[1]] #set of possible context cues
    noise=0.15 #make noise small enough or state_thresh small enough to minimize new states in acquisition
    #action_items is a subset of the state-action combinations that an agent can perform
    #count number of responses to the following state-action combos:
    action_items=[(('start','blip'),'center'),(('Pport','6kHz'),'left'),(('Pport','6kHz'),'right'),(('Pport','10kHz'),'left'),(('Pport','10kHz'),'right')]
    #action_items=['center','left','right']

    block_DA_dip=False #options: 'AIP' or 'no_dip' block homosynaptic LTP, 'no_dip' also blocks heterosynaptic LTD, False = control
    PREE=0
    savings='none' #options: 'none', 'in new context', 'after extinction' - for simulating discrim and reverse
    extinct='none' #AAB: acquire and extinguish in A, try to renew in B; ABB: acquire in A, extinguish in B, re-test renewal in B
    #Specify which learning protocols/phases to implement
    learn_phases,figure_sets,traject_items,acq_cue,ext_cue,ren_cue,dis_cue,rev_cue,acq2_cue=select_phases(block_DA_dip,PREE,savings,extinct,context,action_items)
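    #select_phases returns the phases to simulate, the figure groupings, the state-action trajectories to track, and the context cue presented in each phase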
    state_sets=[[('Pport','6kHz'),('Pport','10kHz')],[('Pport','6kHz')]]
    phases=[['acquire','discrim','reverse'],['acquire','extinc','renew']] 
    trial_subset=int(0.1*numevents) #mean reward and action counts are calculated over the first and last 10% of events
    #update some parameters of the agent
    params['decision_rule']=None #options: 'delta', 'mult', 'combo', 'sumQ2'; None means use direct negative of D1 rule
    params['Q2other']=0.1
    params['numQ']=2
    params['beta_min']=0.5
    params['beta']=1.5
    params['beta_GPi']=10
    params['gamma']=0.82
    params['state_units']['context']=False
    if params['distance']=='Euclidean':
        state_thresh={'Q1':[1.0,0],'Q2':[0.75,0.625]} #For normalized Euclidean distance
        alpha={'Q1':[0.3,0],'Q2':[0.2,0.1]}    #For normalized Euclidean distance
    else:
        state_thresh={'Q1':[0.22, 0.22],'Q2':[0.20, 0.22]} #For Gaussian mixture
        alpha={'Q1':[0.4,0],'Q2':[0.4,0.2]}    #For Gaussian mixture; [0.62,0.19] for beta=0.6, 1Q or 2Q

    params['state_thresh']=state_thresh['Q'+str(params['numQ'])] #for Euclidean distance with no noise
    #lower state_thresh means more states are created with the Euclidean distance rule
    params['alpha']=alpha['Q'+str(params['numQ'])]
    params['split']=True #if False - initialize new row in Q matrix to 0; if True - initialize to Q values of best matching state   
    ######################################
    from DiscriminationTaskParam2 import Racq,Tacq,env_params,Rext,Rdis,Tdis,Rrev,Trev
    epochs=['Beg','End']
    
    keys=rlu.construct_key(action_items +['rwd'],epochs)
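    #keys: labels for each counted state-action pair plus 'rwd'; combined with epochs ('Beg','End') below to index resultslist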
    
    ### to plot performance vs trial block
    trials_per_block=10
    events_per_block=trials_per_block* events_per_trial
    num_blocks=int((numevents+1)/events_per_block)
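    #each trajectory point summarizes one block of trials_per_block trials (events_per_block events)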
    params['events_per_block']=events_per_block
    params['trials_per_block']=trials_per_block
    params['trial_subset']=trial_subset

    output_summary=[]
    key_params=['numQ','Q2other','beta_GPi','decision_rule','beta_min','beta','gamma']
    header=','.join(key_params)+',phase,rwd_mean,rwd_std,half_rwd_block,half_block_std'
    output_summary.append(header)
    vary_param='beta' #or 'gamma'
    for new_val in [0.9, 1.5, 2, 3, 5]: #[0.3,0.45,0.6,0.75,0.82,0.9,0.95,0.98] for gamma
        params[vary_param]=new_val
        #params['beta_min']=new_val
        resultslist={phs:{k+'_'+ep:[] for k in keys.values() for ep in epochs} for phs in learn_phases}
        traject_dict={phs:{ta:[] for ta in traject_items[phs]} for phs in learn_phases}
        #count number of responses to the following actions:
        results={phs:{a:{'Beg':[],'End':[]} for a in action_items+['rwd']} for phs in learn_phases}
        resultslist['params']={p:[] for p in params.keys()} 
        all_beta={'_'.join(k):[] for k in phases}
        all_lenQ={k:{q:[] for q in range(params['numQ'])} for k in all_beta.keys()}
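        #all_beta and all_lenQ accumulate, across runs, the softmax beta time course and the number of states (rows) in each Q matrix for each set of phases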
        for r in range(runs):
            rl={}
            if 'acquire' in learn_phases:
                ######### acquisition trials, context A, only 6 kHz + left turn allowed #########
                rl['acquire'] = RL(tone_discrim.completeT, QL.QL, states,act,Racq,Tacq,params,env_params,printR=printR)
                results,acqQ=rlu.run_sims(rl['acquire'],'acquire',numevents,trial_subset,action_items,noise,Info,acq_cue,-1,results,phist=plot_hist,block_DA=block_DA_dip)
                traject_dict=rl['acquire'].trajectory(traject_dict, traject_items,events_per_block)
            if 'extinc' in learn_phases:
                rl['extinc'] = RL(tone_discrim.completeT, QL.QL, states,act,Rext,Tacq,params,env_params,printR=printR,oldQ=acqQ)
                results,extQ=rlu.run_sims(rl['extinc'],'extinc',numevents,trial_subset,action_items,noise,Info,ext_cue,-1,results,phist=plot_hist,block_DA=block_DA_dip)
                traject_dict=rl['extinc'].trajectory(traject_dict, traject_items,events_per_block)
            #### renewal - blocking D2 or DA dip not tested
            if 'renew' in learn_phases:
                rl['renew'] = RL(tone_discrim.completeT, QL.QL, states,act,Rext,Tacq,params,env_params,printR=printR,oldQ=extQ)
                results,renQ=rlu.run_sims(rl['renew'],'renew',numevents,trial_subset,action_items,noise,Info,ren_cue,-1,results)
                traject_dict=rl['renew'].trajectory(traject_dict, traject_items,events_per_block)
            ####### discrimination trials: add the 10 kHz tone, plus the needed reward and state-transition matrices ####
            if 'discrim' in learn_phases:
                #use last context in the list
                rl['discrim'] = RL(tone_discrim.completeT, QL.QL, states,act,Rdis,Tdis, params,env_params,oldQ=acqQ)
                acq_first=True
                results,disQ=rlu.run_sims(rl['discrim'],'discrim',int(numevents),trial_subset,action_items,noise,Info,dis_cue,-1,results,phist=plot_hist,block_DA=block_DA_dip)
                traject_dict=rl['discrim'].trajectory(traject_dict, traject_items,events_per_block)    
            ####### reverse trials, switch contingencies ####
            if 'reverse' in learn_phases:
                rl['reverse']=RL(tone_discrim.completeT, QL.QL, states,act,Rrev, Trev,params,env_params,oldQ=disQ)
                results,revQ=rlu.run_sims(rl['reverse'],'reverse',int(numevents),trial_subset,action_items,noise,Info,dis_cue,-1,results,phist=plot_hist)
                traject_dict=rl['reverse'].trajectory(traject_dict, traject_items,events_per_block)
            all_beta,all_lenQ=rlu.beta_lenQ(rl,phases,all_beta,all_lenQ,params['numQ'])

        ######### Average over runs; also compute standard error #########
        all_ta=[]; output_data={}
        for phs in traject_dict.keys():
            output_data[phs]={}
            for ta in traject_dict[phs].keys():
                all_ta.append(ta)
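                #mean and standard error (std/sqrt(runs-1)) of each trajectory across runs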
                output_data[phs][ta]={'mean':np.mean(traject_dict[phs][ta],axis=0),'sterr':np.std(traject_dict[phs][ta],axis=0)/np.sqrt(runs-1)}
        all_ta=list(set(all_ta))
        #move reward to front
        all_ta.insert(0, all_ta.pop(all_ta.index('rwd')))
        for p in resultslist['params'].keys():
            resultslist['params'][p].append(params[p])
        resultslist=rlu.save_results(results,keys,resultslist)
        state_act=(('Pport','6kHz'),'left')
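        #respond_decay estimates, for each phase, the trial block at which responding to ('Pport','6kHz')->left decays to half, used below to evaluate decay of responding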
        half,half_block=respond_decay(['acquire','extinc','renew'],state_act,traject_dict)
        ####### append summary results to list #############
        for phase in results.keys():
            newline=','.join([str(params[k]) for k in key_params])
            newline=newline+','+phase+','+str(np.round(np.mean(results[phase]['rwd']['End']),2))+','+str(np.round(np.std(results[phase]['rwd']['End']),2))
        ############## Evaluate decay of responding #####################
            if phase in half_block.keys():
                newline=newline+','+str(round(np.nanmean(half_block[phase]),2))+','+str(round(np.nanstd(half_block[phase]),2))
            else:
                newline=newline+',,'
            output_summary.append(newline)
        newline=','.join([str(params[k]) for k in key_params])
        total_rwd=np.array(results['acquire']['rwd']['End'])+np.array(results['discrim']['rwd']['End'])+np.array(results['reverse']['rwd']['End'])
        newline=newline+',TOTAL,'+str(np.round(np.mean(total_rwd),2))+','+str(np.round(np.std(total_rwd),2))
        output_summary.append(newline)        
        ########################## Save trajectories and Qhx
        actions=['left','right']
        agents=[[rl[phs] for phs in phaseset] for phaseset in phases]
        all_Qhx={q:{} for q in range(params['numQ'])};all_bounds={q:{} for q in range(params['numQ'])};all_ideals={q:{} for q in range(params['numQ'])}
        for ij,(state_subset,phase_set,agent_set) in enumerate(zip(state_sets,phases,agents)):
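            #Qhx: Q value history per Q matrix and state; boundaries: event indices separating the phases; ideal_states: the agent's stored state representations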
            Qhx, boundaries,ideal_states=Qhx_multiphase(state_subset,actions,agent_set,params['numQ'])
            for q in Qhx.keys():
                for st in Qhx[q].keys():
                    newstate=','.join(list(st))
                    all_Qhx[q][newstate+' '+str(ij)]=copy.deepcopy(Qhx[q][st])
                    all_ideals[q][newstate+' '+str(ij)]=copy.deepcopy(ideal_states[q][st])
                    all_bounds[q][newstate+' '+str(ij)]=copy.deepcopy(boundaries[q][st])
        del rl        
        import datetime
        dt=datetime.datetime.today()
        date=str(dt).split()[0]
        fname_params=key_params+['split']
        fname='Discrim'+date+'_'.join([k+str(params[k]) for k in fname_params])
        np.savez(fname,par=params,results=resultslist,traject=output_data)
        np.savez('Qhx'+fname,all_Qhx=all_Qhx,all_bounds=all_bounds,events_per_trial=events_per_trial,phases=phases,all_ideals=all_ideals,all_beta=all_beta,all_lenQ=all_lenQ)
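        #To reload the saved results later (minimal sketch; the dicts are stored as object arrays, so allow_pickle is required):
        #  data=np.load(fname+'.npz',allow_pickle=True)
        #  traject=data['traject'].item()  #dict: phase -> state-action -> {'mean','sterr'}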
    fname_params.remove(vary_param)
    #fname_params.remove('beta_min') #+'_bminbmax_'
    fname='Discrim'+date+'_'.join([k+str(params[k]) for k in fname_params])+vary_param+'range'
    np.save(fname,output_summary)
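    #the summary is saved as an array of comma-separated strings in fname+'.npy'; np.load(fname+'.npy') recovers it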