# -*- coding: utf-8 -*-
"""
Created on Wed Aug 19 12:12:28 2020
@author: kblackw1
"""
############ reward ################
rwd={'error':-5,'reward':10,'base':-1,'none':0}
#rwd={'error':-1,'reward':6,'base':0,'none':0,'partial':1} #use these for Opal and Bogacz?
######### Parameters for the agent ##################################
params={}
params['wt_learning']=False
params['wt_noise']=False #whether to multiply noise by learning_rate - not helpful
params['numQ']=1
params['alpha']=[0.3,0.06] # learning rate 0.3 and 0.06 produce learning in 400 trials,#slower for Q2 - D2 neurons
params['beta']=0.9 # inverse temperature, controls exploration
params['beta_min']=0.5
params['beta_GPi']=10 #Should be similar to using max
params['gamma']=0.9 #discount factor
params['hist_len']=40
params['state_thresh']=[0.12,0.2] #similarity of noisy state to ideal state
#if lower state_creation_threshold for Q[0] compared to Q[1], then Q[0] will have fewer states
#possibly multiply state_thresh by learn_rate? to change dynamically?
params['sigma']=0.25 #similarity of noisey state to ideal state,std used in Mahalanobis distance.
params['time_inc']=0.1 #increment time since reward by this much in no reward
params['moving_avg_window']=3 #This in units of trials, the actual window is this times the number of events per trial
params['decision_rule']=None #'combo', 'delta', 'sumQ2', None ## None means use direct negative of D1 rule
params['Q2other']=0.1
params['forgetting']=0
params['reward_cues']=None ##options: 'TSR', 'RewHx3', 'reward', None
params['distance']='Euclidean'
params['initQ']=-1 #split states, initialize Q values to best matching
params['events_per_trial']=3
############### Make sure you have all the state transitions needed ##########
def validate_T(T,msg=''):
print(msg)
for st in T.keys():
for newst_list in T[st].values():
for newst in newst_list:
if newst[0] not in T.keys():
print('new state', newst[0],'not in Tacq')
def validate_R(T,R,msg=''):
print(msg)
for st in T.keys():
if st not in R.keys():
print('state', st,'in T,but not in R')
else:
for a in T[st].keys():
if a not in R[st].keys():
print('state/action:', st,a,'in T, but not in R')
######### Parameters for the environment ##################################
act={'center':0,'left':1,'return':2,'right':3,'wander':4,'hold':5,'groom':6,'other':7}
states={'loc':{'start':0,'Pport':2,'Lport':1,'other':4,'Rport':3},
'tone':{'blip':0,'6kHz':6,'success':2,'error':-2,'10kHz':10}} #These values need units
params['state_units']={'loc':False,'tone':True}
#some convenient variables
start=(states['loc']['start'],states['tone']['blip']) #used many times
env_params={'start':start}
loc=states['loc'] #used only to define R and T
tone=states['tone'] #used only to define R and T
move=['left','right','wander','center']
stay=['groom','other','hold']# ['hold'] - if eliminate groom and other#
Racq={};Tacq={} #dictionaries to improve readability/prevent mistakes
####value of T dict is the new state
#from start - best response is poke
Tacq[start]={a:[(start,1)] for a in act.values()} # stay at start
for a in ['left','right']:
Tacq[start][act[a]]=[((loc['other'],tone['error']),1)] #error if go to left or right port from start box
Tacq[start][act['wander']]=[((loc['other'],tone['blip']),1)] #meandering, not yet at poke port
Tacq[start][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #poke at start tone, go to poke port
#What happens if agent doesn't go to poke port immediately
Tacq[(loc['other'],tone['blip'])]={a:[((loc['other'],tone['blip']),1)] for a in act.values()} #default: remain in other unless
Tacq[(loc['other'],tone['blip'])][act['center']]=[((loc['Pport'],tone['6kHz']),1)] #go to center port, after wandering
Tacq[(loc['other'],tone['blip'])][act['return']]=[(start,1)] #return to start
#from poke port - correct response is 'left'
Tacq[(loc['Pport'],tone['6kHz'])]={a:[((loc['Pport'],tone['6kHz']),1)] for a in act.values()} #default - stay in poke port
for a in ['wander','return']: #incorrect movements
Tacq[(loc['Pport'],tone['6kHz'])][act[a]]=[((loc['other'],tone['error']),1)] #incorrect movements
Tacq[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['error']),1)]
Tacq[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['success']),1)] #hear tone in poke port, go left, in left port/success
Tacq[(loc['other'],tone['error'])]={a:[((loc['other'],tone['error']),1)] for a in act.values()}#remain in other unless
Tacq[(loc['other'],tone['error'])][act['return']]=[(start,1)] #return to start
#from left port - best response is return
Tacq[(loc['Lport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
for a in stay:
Tacq[(loc['Lport'],tone['success'])][act[a]]=[((loc['Lport'],tone['error']),1)] #staying at Lport, but not continued success (reward)
Tacq[(loc['Lport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again
#from right port - best response is return
Tacq[(loc['Rport'],tone['error'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move}
for a in stay:
Tacq[(loc['Rport'],tone['error'])][act[a]]=[((loc['Rport'],tone['error']),1)] #remain in Rport if no movement,
Tacq[(loc['Rport'],tone['error'])][act['return']]=[(start,1)] #return to start
Tacq[(loc['Lport'],tone['error'])] = {act[a]:[((loc['other'],tone['error']),1)] for a in move}
for a in stay:
Tacq[(loc['Lport'],tone['error'])][act[a]]=[((loc['Lport'],tone['error']),1)] #remain in Lport if no movement,
Tacq[(loc['Lport'],tone['error'])][act['return']]=[(start,1)] #return to start
#error tone is not associated with penalty except with incorrect response from Pport
for k in Tacq.keys(): #Tacq determines what states pairs need reward values
Racq[k]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
Racq[k][act['hold']]=[(rwd['none'],1)] #not moving - nocost
#reward for correct response
Racq[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],0.9),(rwd['base'],0.1)] #lick in left port - 90% reward
#Error if go anywhere but correct port after tone
for a in ['right','wander','return']:
Racq[(loc['Pport'],tone['6kHz'])][act[a]]=[(rwd['error'],1)]
#Error if go straight to left or right port from start box - same error for discrimination
for a in['right','left']:
Racq[(loc['start'],tone['blip'])][act[a]]=[(rwd['error'],1)]
############## End acquisition parameters ###################
### initialize extinction as same as acquistion
## But, no errors or rewards
import copy #needed to make copies of dict of dicts
#Liu J Neurosci 40-6409 - during extinction, syringe pump (success sound) still triggered
#THUS, transitions the same, but no reward (nor penalty?)
Rext=copy.deepcopy(Racq)
Text=copy.deepcopy(Tacq)
for a in ['right','wander','return','left']:
Rext[(loc['Pport'],tone['6kHz'])][act[a]]=[(rwd['base'],1)] #base cost
###### initialize discrimination as same as acquistion
Rdis=copy.deepcopy(Racq)
Tdis=copy.deepcopy(Tacq)
### add in some states
#Change transitions from start so that 50% of time 6kHz and 50% 10kHz
Tdis[start][act['center']]=[((loc['Pport'],tone['6kHz']),0.5),((loc['Pport'],tone['10kHz']),0.5)] #after poking, each tone presented 50% of time
Tdis[(loc['other'],tone['blip'])][act['center']]=[((loc['Pport'],tone['6kHz']),0.5),((loc['Pport'],tone['10kHz']),0.5)] #go to poke port late
#add in transitions to (Pport,10kHz)
Tdis[(loc['Pport'],tone['10kHz'])]={a:[((loc['Pport'],tone['10kHz']),1)] for a in act.values()} #default - stay in poke port
for a in ['wander','return']:
Tdis[(loc['Pport'],tone['10kHz'])][act[a]]=[((loc['other'],tone['error']),1)] #incorrect movements
Tdis[(loc['Pport'],tone['10kHz'])][act['right']]=[((loc['Rport'],tone['success']),1)] #hear tone in poke port, go left, in left port/success
Tdis[(loc['Pport'],tone['10kHz'])][act['left']]=[((loc['Lport'],tone['error']),1)] #hear tone in poke port, go left, in left port/success
Tdis[(loc['Rport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} ##wandering around
for a in stay:
Tdis[(loc['Rport'],tone['success'])][act[a]]=[((loc['Rport'],tone['error']),1)] #staying at Rport, but not continued success (reward)
Tdis[(loc['Rport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again
#add in rewards to (Pport,10kHz)
Rdis[(loc['Pport'],tone['10kHz'])]={a:[(rwd['base'],1)] for a in act.values()}
Rdis[(loc['Pport'],tone['10kHz'])][act['right']]=[(rwd['reward'],0.9),(rwd['base'],0.1)] #lick in left port - 90% reward
Rdis[(loc['Pport'],tone['10kHz'])][act['hold']]=[(rwd['none'],1)]
for a in ['left','wander','return']:
Rdis[(loc['Pport'],tone['10kHz'])][act[a]]=[(rwd['error'],1)]
#add in rewards to (Rport,success)
Rdis[(loc['Rport'],tone['success'])]={a:[(rwd['base'],1)] for a in act.values()}
Rdis[(loc['Rport'],tone['success'])][act['hold']]=[(rwd['none'],1)]
############## End discrimination parameters ###################
#complex reversal - change which behavior rewarded for which tone
#i.e., now 6 Khz means go right for reward (not left)
# 10 Khz means go left for rewerd (not right)
Rrev=copy.deepcopy(Rdis)
Trev=copy.deepcopy(Tdis)
Trev[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['success']),1)]
Trev[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['error']),1)] #hear tone in poke port, go left, in left port/success
Trev[(loc['Pport'],tone['10kHz'])][act['right']]=[((loc['Rport'],tone['error']),1)] #hear tone in poke port, go left, in left port/success
Trev[(loc['Pport'],tone['10kHz'])][act['left']]=[((loc['Rport'],tone['success']),1)] #hear tone in poke port, go left, in left port/success
Rrev[(loc['Pport'],tone['10kHz'])][act['left']]=[(rwd['reward'],0.9),(rwd['base'],0.1)] #lick in left port - 90% reward
Rrev[(loc['Pport'],tone['10kHz'])][act['right']]=[(rwd['error'],1)] #lick in left port - 90% reward
Rrev[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['error'],1)] #lick in left port - 90% reward
Rrev[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward'],0.9),(rwd['base'],0.1)] #lick in left port - 90% reward
################# Two arm bandit task of Josh Berke (Hamid et al. Nat Neuro V19)
Rbandit=copy.deepcopy(Racq)
Tbandit=copy.deepcopy(Tacq)
prwdR=0.8; prwdL=0.5 #initial values. These change with each block of trials
Tbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[((loc['Rport'],tone['success']),prwdR),((loc['Rport'],tone['error']),1-prwdR)]
Rbandit[(loc['Pport'],tone['6kHz'])][act['right']]=[(rwd['reward'],prwdR),(rwd['base'],1-prwdR)]
Tbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[((loc['Lport'],tone['success']),prwdL),((loc['Lport'],tone['error']),1-prwdL)] #hear tone in poke port, go left, in left port/success
Rbandit[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],prwdL),(rwd['base'],1-prwdL)]
Tbandit[(loc['Rport'],tone['success'])]={act[a]:[((loc['other'],tone['error']),1)] for a in move} #default, wandering around
for a in stay:
Tbandit[(loc['Rport'],tone['success'])][act[a]]=[((loc['Rport'],tone['error']),1)] #staying at Lport, but not continued success (reward)
Tbandit[(loc['Rport'],tone['success'])][act['return']]=[(start,1)] #go back to start to begin again
Rbandit[(loc['Rport'],tone['success'])]={a:[(rwd['base'],1)] for a in act.values()} #default: cost of basic action
Rbandit[(loc['Rport'],tone['success'])][act['hold']]=[(rwd['none'],1)]
if __name__== '__main__':
######## Make sure all needed transitions have been created
validate_T(Tacq,msg='validate Tacq')
validate_R(Tacq,Racq,msg='validate Racq')
validate_T(Tdis,msg='validate discrim T')
validate_R(Tdis,Rdis,msg='validate discrim R')
validate_T(Trev,msg='validate reversal T')
validate_R(Trev,Rrev,msg='validate reversal R')
validate_T(Tbandit,msg='validate bandit T')
validate_R(Tbandit,Rbandit,msg='validate bandit R')