# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:12:28 2020

@author: kblackw1
"""

############ reward   ################
# Reward delivered for each outcome type; used below to build the reward table.
rwd = dict(error=-5, reward=10, base=-1, none=-1, partial=0)
# Alternative reward schedules that were tried:
#  {'error':-1,'reward':8,'base':0,'none':0,'partial':0}  - Opal and Bogacz: shows 3 step tasks do not work (same mean rwd per trial)
#  {'error':-1,'reward':6,'base':0,'none':0,'partial':1}  - Opal and Bogacz: improves performance on 3 step tasks (same mean rwd per trial)
#  {'error':-2,'reward':5,'base':-0.5,'none':-0.5,'partial':0}  - meant to show more lose-shift and sampling behavior? No

######### Parameters for the agent ##################################
params = {
    'wt_learning': False,
    'wt_noise': False,  # multiply noise by learning_rate? (not helpful)
    'numQ': 2,  # number of Q tables; alpha and state_thresh carry one entry per table
    'alpha': [0.3, 0.06],  # learning rates; 0.3/0.06 learn within 400 trials; slower for Q2 (D2 neurons)
    'beta': 0.9,  # inverse temperature, controls exploration
    'beta_min': 0.1,
    'gamma': 0.9,  # discount factor
    'hist_len': 40,
    # Similarity threshold of a noisy state to the ideal state, one per Q table.
    # If the threshold for Q[0] is lower than for Q[1], Q[0] will have fewer states.
    # Idea: multiply state_thresh by the learning rate to change it dynamically?
    'state_thresh': [0.12, 0.2],
    'sigma': 0.25,  # std used in the Mahalanobis distance between noisy and ideal state
    'time_inc': 0.1,  # increment of time-since-reward when no reward is delivered
    'moving_avg_window': 3,  # in trials; actual window = this * events per trial
    'decision_rule': None,  # 'combo', 'delta', 'sumQ2', or None (None = direct negative of D1 rule)
    'Q2other': 0.1,
    'forgetting': 0,
    'reward_cues': None,  # 'TSR', 'RewHx3', 'reward', or None
    'distance': 'Euclidean',
    'initQ': -1,  # split states, initialize Q values to best matching
    'events_per_trial': 3,
}

############### Make sure you have all the state transitions needed ##########
def validate_T(T,msg=''):
    """Check that every destination state in a transition table is also a key of it.

    T maps state -> {action: [(new_state, probability), ...]}; only element 0
    of each outcome is inspected. A destination that is not a key of T has no
    transitions defined for it.

    Args:
        T: transition dictionary to check.
        msg: header printed before any diagnostics.

    Returns:
        list: the missing destination states, each listed once, in the order
        first encountered (empty when every destination is defined). Every
        missing occurrence is also printed.
    """
    print(msg)
    missing=[]
    for transitions in T.values():
        for outcomes in transitions.values():
            for outcome in outcomes:
                new_state=outcome[0]
                if new_state not in T:
                    # Generic wording: this helper validates any table (e.g. Tbandit), not just 'Tacq'.
                    print('new state', new_state,'not in T')
                    if new_state not in missing:
                        missing.append(new_state)
    return missing

def validate_R(T,R,msg=''):
    """Report entries of a transition table that have no matching reward entry.

    Prints ``msg`` first, then one line for each state of T that is absent
    from R, and one line for each action of a present state that is absent
    from R[state]. Returns None.
    """
    print(msg)
    for state, actions in T.items():
        if state not in R:
            print('state', state,'in T,but not in R')
            continue
        rewarded_actions = R[state]
        unrewarded = [action for action in actions if action not in rewarded_actions]
        for action in unrewarded:
            print('state/action:', state, action, 'in T, but not in R')
   
######### Parameters for the environment ##################################
# Set to False to drop the 'wander' action (and its transitions) from the task.
include_wander = True

# Action name -> action index; 'wander' (index 5) is appended only when enabled.
act = {'center': 0, 'left': 1, 'return': 2, 'right': 3, 'hold': 4}
if include_wander:
    act['wander'] = 5

# Each state component maps a name to its numeric code.
states = {
    'loc': {'start': 0, 'Pport': 2, 'Lport': 1, 'other': -1, 'Rport': 3},
    'tone': {'blip': 0, '6kHz': 6, 'success': 2, 'error': -2, '10kHz': 10},
}
params['state_units'] = {'loc': False, 'tone': True}  # Try false/true

# Shorthand aliases, used only to define R and T.
loc = states['loc']
tone = states['tone']

# Start state: start location with the 'blip' tone (used many times).
start = (loc['start'], tone['blip'])
env_params = {'start': start}

# Action names grouped for building default transitions.
stay = ['hold']
move = ['left', 'right', 'center'] + (['wander'] if include_wander else [])

################# Two arm bandit task of Josh Berke (Hamid et al. Nat Neuro V19)
# States are (location code, tone code) tuples built from `loc` and `tone`.
#   Tbandit[state][action index] = [(next_state, probability), ...]
#   Rbandit[state][action index] = [(reward, probability), ...]
# Each outcome list is either a single (x,1) entry or a (p, 1-p) pair.
Rbandit={};Tbandit={}  #dictionaries to improve readability/prevent mistakes
# Probability that choosing the right / left port ends in the success tone (and reward).
prwdR=0.8; prwdL=0.5 #initial values.  These change with each block of trials

####value of T dict is the new state
#Order matters below: each state first gets a default outcome for every action,
#and later assignments override that default for specific actions.
#from start - best response is poke (go to the center/poke port)
Tbandit[start]={a:[(start,1)] for a in act.values()} # default: stay at start, unless move
for a in ['left','right']:
    Tbandit[start][act[a]]=[((loc['other'],tone['error']),1)] #error if go to left or right port from start box
Tbandit[start][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #poke at start tone, go to poke port

if include_wander:
    Tbandit[start][act['wander']]=[((loc['other'],tone['blip']),1)] #meandering, not yet at poke port **
    #What happens if agent doesn't go to poke port immediately **
    Tbandit[(loc['other'],tone['blip'])]={a:[((loc['other'],tone['blip']),1)] for a in act.values()} #default: remain in other, unless center or return
    Tbandit[(loc['other'],tone['blip'])][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #go to center port, after wandering
    Tbandit[(loc['other'],tone['blip'])][act['return']]=[(start,1)] #return to start

#from poke port - left and right are both valid choices; each ends in success with its own probability (prwdL / prwdR)
Tbandit[(loc['Pport'],tone['6kHz'])]={a:[((loc['Pport'],tone['6kHz']),1)] for a in act.values()} #default - stay in poke port (applies to hold and center)
if include_wander:
    Tbandit[(loc['Pport'],tone['6kHz'])][act['wander']]=[((loc['other'],tone['error']),1)]  #incorrect movements
Tbandit[(loc['Pport'],tone['6kHz'])][act['return']]=[((loc['start'],tone['error']),1)]  #incorrect movement: back in start box with error tone
Tbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['success']),prwdR),((loc['Rport'],tone['error']),1-prwdR)] #go right: right port with success tone (prob prwdR) or error tone
Tbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['success']),prwdL),((loc['Lport'],tone['error']),1-prwdL)] #go left: left port with success tone (prob prwdL) or error tone

#(other,error) keeps the agent there for every action except 'return' (set in the loop at the end of this section)
Tbandit[(loc['other'],tone['error'])]={a:[((loc['other'],tone['error']),1)] for a in act.values()}#remain in other unless return
#(start,error): any move leads to (other,error); hold resets to the start state
Tbandit[(loc['start'],tone['error'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move}
Tbandit[(loc['start'],tone['error'])][act['hold']]=[(start,1)]

#from left port or right port - best response is return
Tbandit[(loc['Lport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
Tbandit[(loc['Lport'],tone['success'])][act['hold']]=[((loc['Lport'],tone['error']),1)] #staying at Lport, but not continued success (reward)
Tbandit[(loc['Lport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again

Tbandit[(loc['Rport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
Tbandit[(loc['Rport'],tone['success'])][act['hold']]=[((loc['Rport'],tone['error']),1)] #staying at Rport, but not continued success (reward)
Tbandit[(loc['Rport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again

#from right or left port or any with error tone - best response is return
Tbandit[(loc['Rport'],tone['error'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #any move: wander off into (other,error)
Tbandit[(loc['Rport'],tone['error'])][act['hold']]=[((loc['Rport'],tone['error']),1)] #remain in Rport if no movement

Tbandit[(loc['Lport'],tone['error'])] = {act[a]:[((loc['other'],tone['error']),1)] for a in move} #any move: wander off into (other,error)
Tbandit[(loc['Lport'],tone['error'])][act['hold']]=[((loc['Lport'],tone['error']),1)] #remain in Lport if no movement

#'return' from every error-tone state goes back to the start state
for location in ['Rport','Lport','start','other']:
    Tbandit[(loc[location],tone['error'])][act['return']]=[(start,1)] #return to start

#Reward table: gets one entry per state of Tbandit, so Tbandit must be fully built first.
#error tone is associated with penalty when agent makes incorrect movement
for k in Tbandit.keys(): #Tbandit determines what states pairs need reward values
    Rbandit[k]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
    Rbandit[k][act['hold']]=[(rwd['none'],1)] #not moving - costs rwd['none'] (same as rwd['base'] with the current values)

#partial (shaping) reward for intermediate correct steps; only applied when positive
if rwd['partial']>0:
    Rbandit[start][act['center']]=[(rwd['partial'],1)]
    Rbandit[(loc['Rport'],tone['success'])][act['return']]=[(rwd['partial'],1)]
    Rbandit[(loc['Lport'],tone['success'])][act['return']]=[(rwd['partial'],1)]
#reward for correct response: rwd['reward'] with the port's probability, otherwise rwd['base']
#NOTE(review): Tbandit uses the same probability for the success/error tone; whether the
#environment samples tone and reward jointly is not visible here - confirm.
Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward'],prwdR),(rwd['base'],1-prwdR)]  
Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],prwdL),(rwd['base'],1-prwdL)]

#Error if go anywhere but Left or Right port after tone (hold and center keep their default costs)
Rbandit[(loc['Pport'],tone['6kHz'])][act['return']]=[(rwd['error'],1)] 
if include_wander:
    Rbandit[(loc['Pport'],tone['6kHz'])][act['wander']]=[(rwd['error'],1)] 
#Error if go straight to left or right port from start box - same error for discrimination
for a in['right','left']:
    Rbandit[(loc['start'],tone['blip'])][act[a]]=[(rwd['error'],1)] 

###### Error if agent moves anywhere other than 'return' from Rport or Lport after success
###### (hold keeps the rwd['none'] cost; return keeps rwd['base'], or rwd['partial'] if enabled)
# These are needed to keep agent from not returning for new trial after success
for a in ['right','left','center']:
    Rbandit[(loc['Rport'],tone['success'])][act[a]]=[(rwd['error'],1)] 
    Rbandit[(loc['Lport'],tone['success'])][act[a]]=[(rwd['error'],1)] 
if include_wander:
    Rbandit[(loc['Rport'],tone['success'])][act['wander']]=[(rwd['error'],1)] 
    Rbandit[(loc['Lport'],tone['success'])][act['wander']]=[(rwd['error'],1)] 

############## End bandit parameters ###################

if __name__ == '__main__':
    # Running this file directly sanity-checks the bandit tables: every
    # destination state in Tbandit must have its own transitions, and every
    # (state, action) pair in Tbandit must have a reward entry in Rbandit.
    validate_T(Tbandit, 'validate bandit T')
    validate_R(Tbandit, Rbandit, 'validate bandit R')