# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:12:28 2020

@author: kblackw1
"""

############ reward   ################
rwd={'error':-1,'reward':15,'base':-1,'none':0} 
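#rwd['base'] is the per-action cost used when building R below, and rwd['reward'] is the
#payoff for the correct press sequence; 'error' and 'none' are not used in this file and
#are presumably consumed by the agent/environment code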

######### Parameters for the agent ##################################
params={}
params['beta_GPi']=10
params['wt_learning']=False
params['wt_noise']=False #whether to multiply noise by learning_rate - not helpful
params['numQ']=1
params['alpha']=[0.3,0.06]  # learning rates; 0.3 and 0.06 produce learning in 400 trials, slower for Q2 (D2 neurons)
params['beta']=0.9  # inverse temperature, controls exploration
params['beta_min']=0.5
params['gamma']=0.9  #discount factor 
params['hist_len']=40
params['state_thresh']=[0.12,0.2] #similarity of noisy state to ideal state
#if state_creation_threshold is lower for Q[0] than for Q[1], then Q[0] will have fewer states
#possibly multiply state_thresh by learning rate to change it dynamically?
params['sigma']=0.25 #std used in the Mahalanobis distance that measures similarity of noisy state to ideal state
params['time_inc']=0.1 #increment time since reward by this much when no reward is received
params['moving_avg_window']=3 #in units of trials; the actual window is this times the number of events per trial
params['decision_rule']=None #options: 'combo', 'delta', 'sumQ2'
params['Q2other']=0.1
params['forgetting']=0
params['reward_cues']=None #options: 'TSR', 'RewHx3', 'reward'
params['distance']='Euclidean'
params['split']=True
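
######### Illustrative use of alpha, beta, gamma (hypothetical sketch) ##############
#The functions below are only a sketch of how the learning rate (alpha), inverse
#temperature (beta) and discount factor (gamma) are conventionally used in a tabular
#agent with softmax (Boltzmann) action selection, as the "inverse temperature" label
#above suggests.  They are not the agent implementation, which lives in a separate module.
import numpy as np

def softmax_choice(Q_row, beta):
    #higher beta -> greedier choices; lower beta -> more exploration
    #(beta_min presumably bounds beta from below when it is varied)
    q = np.array(Q_row, dtype=float)
    prob = np.exp(beta*(q - q.max()))  #subtract max for numerical stability
    prob = prob/prob.sum()
    return np.random.choice(len(q), p=prob)

def td_update(Q, state, action, reward, new_state, alpha, gamma):
    #one-step temporal-difference update; Q is assumed to be a dict of dicts: Q[state][action] -> value
    target = reward + gamma*max(Q[new_state].values())
    Q[state][action] += alpha*(target - Q[state][action])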

############### Make sure you have all the state transitions needed ##########
def validate_T(T,msg=''):
    print(msg)
    for st in T.keys():
        for newst_list in T[st].values():
           for newst in newst_list:
               if newst[0] not in T.keys():
                   print('new state', newst[0],'not in T')

def validate_R(T,R,msg=''):
    print(msg)
    for st in T.keys():
        if st not in R.keys():
            print('state', st,'in T, but not in R')
        else:
            for a in T[st].keys():
                if a not in R[st].keys():
                    print('state/action:', st,a,'in T, but not in R')
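#Both validators assume the structure used for Tloc and R below:
#   T[state][action] = list of (new_state, probability) tuples
#   R[state][action] = list of (reward, probability) tuples
#validate_T checks that every reachable new_state is itself a key of T;
#validate_R checks that every (state,action) pair in T has an entry in R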
   
######### Parameters for the environment ##################################
act={'goMag':0,'goL':1,'goR':2,'press':3,'other':4}  #other includes grooming, not moving
Hx_len=4 # length of the press history; this sets the number of nested for loops below. Hx_len=3 works better than 4

hx_values=['L','R'] #possible characters in the history
start_hx='-'*Hx_len #starting press history is 'empty'
'''
Worse performance if there is no '-' in the starting history
hx_values=['L','R'] #possible characters in the history
import numpy
#starting press history is random
start_presses=numpy.random.randint(len(hx_values),size=Hx_len)
start_hx=''.join([hx_values[s] for s in start_presses]) 
'''
### Enumerate all possible press-history sequences of length Hx_len
sequences={}
value=0
if Hx_len==3:
    for c1 in hx_values:
        for c2 in hx_values:
            for c3 in hx_values:
               sequences[c1+c2+c3]=value
               value+=1  #sequences are numbered only to provide values for the states dictionary
    for c3 in hx_values:
        sequences['--'+c3]=value
        value+=1
        for c2 in hx_values:
            sequences['-'+c2+c3]=value
            value+=1
    sequences['---']=value
    params['events_per_trial']=6
elif Hx_len==4:
    for c1 in hx_values:
        for c2 in hx_values:
            for c3 in hx_values:
                for c4 in hx_values:
                    sequences[c1+c2+c3+c4]=value
                    value+=1  #sequences are numbered only to provide values for the states dictionary
    for c4 in hx_values:
        sequences['---'+c4]=value
        value+=1
        for c3 in hx_values:
            sequences['--'+c3+c4]=value
            value+=1
            for c2 in hx_values:
                sequences['-'+c2+c3+c4]=value
                value+=1
    sequences['----']=value
    params['events_per_trial']=7
else:
    print('unanticipated Hx_len in press history')
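#A more compact way to enumerate the same press-history sequences, using
#itertools.product, is sketched below.  It is not used here (the explicit loops
#above are unchanged) and the numbers assigned to each sequence may differ from
#the loop version, but the set of keys is identical for any Hx_len.
import itertools
def enumerate_sequences(hx_len,values):
    seqs={}
    for num_known in range(hx_len,0,-1): #full histories first, then left-padded ones
        pad='-'*(hx_len-num_known)
        for combo in itertools.product(values,repeat=num_known):
            seqs[pad+''.join(combo)]=len(seqs)
    seqs['-'*hx_len]=len(seqs) #the all '-' (empty) starting history
    return seqs
#sequences=enumerate_sequences(Hx_len,hx_values)  #would replace the loops above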
           
#create state dictionary
states={'loc':{'mag':0,'Llever':1,'Rlever':2,'other':3},
        'hx': sequences} 
params['state_units']={'loc':False,'hx':False}
#some convenient variables
loc=states['loc'] #used only to define R and T
hx=states['hx'] #used only to define R and T

Tloc={loc[location]:{} for location in loc}  #dictionaries to improve readability/prevent mistakes

#Transition matrix for locations - NOTE, this handles locations only
#The transition for lever-press history is a function specified in the environment

for location in ['Llever','Rlever','other','mag']:
    Tloc[loc[location]][act['goL']]=[(loc['Llever'],1)]
    Tloc[loc[location]][act['goR']]=[(loc['Rlever'],1)]
    Tloc[loc[location]][act['goMag']]=[(loc['mag'],1)]
    
for location in ['Llever','Rlever','other','mag']:
    for action in ['press','other']:
        Tloc[loc[location]][act[action]]=[(loc[location],1)]
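#Example of the resulting structure: locations transition deterministically, e.g.
#   Tloc[loc['mag']][act['goL']]   == [(loc['Llever'],1)]
#while 'press' and 'other' leave the location unchanged:
#   Tloc[loc['mag']][act['press']] == [(loc['mag'],1)]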


#where to start episodes, and also re-start trial after reward  
start=(states['loc']['mag'],states['hx'][start_hx]) 
#put some environment values into dictionary for ease of param passing
env_params={'start':start,'hx_len':Hx_len,'hx_act':'press'}

#Reward matrix: enumerates all (location, press-history) state pairs.
#It would be nice to avoid such enumeration and instead use a function, as is done for T
R={}
for k in Tloc.keys(): #T determines which state pairs need reward values
    for st in states['hx'].values():
        R[(k,st)]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
if Hx_len==3:
    R[(loc['Rlever'],states['hx']['LLR'])][act['press']]=[(rwd['reward'],0.95),(rwd['base'],0.05)]  #95% reward for correct press sequence
elif Hx_len==4:
    for location in loc:
        R[(loc[location],states['hx']['LLRR'])][act['goMag']]=[(rwd['reward'],0.95),(rwd['base'],0.05)]   #95% reward for going to the magazine after the correct press sequence
else:
    print('unanticipated Hx_len in reward assignment')
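#Example of the resulting entries (with Hx_len=4, the value set above):
#   R[(loc['mag'],hx['LLRR'])][act['goMag']] == [(15,0.95),(-1,0.05)]
#all other (state,action) pairs keep the default [(rwd['base'],1)] == [(-1,1)]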

if __name__== '__main__':
    ######## Make sure all needed transitions have been created
    validate_T(Tloc,msg='validate Tloc')
    validate_R(Tloc,R,msg='validate R')
    print('press history length=',Hx_len,', start press',start_hx,', hx values',hx_values)