# -*- coding: utf-8 -*-
"""
Created on 2022 Dec 5
@author: kblackw1
"""
############ reward ################
from BanditTaskParam import params,rwd,validate_T,validate_R
# Raise the full reward by twice the baseline and twice the partial reward.
# NOTE(review): the original author's remark was "for equivalence to 3 step
# task?" -- the rationale is unconfirmed.
rwd['reward'] = rwd['reward'] + 2 * rwd['base'] + 2 * rwd['partial']

# Action name -> integer code.
act = {
    'left': 0,
    'right': 1,
}

# State variable -> {value name: integer code}.
states = {
    'loc': {'Pport': 1},
    'tone': {'6kHz': 6},
}
params['state_units'] = {'loc': False, 'tone': False}  # Try false/true

# Short aliases; used only to define R and T.
loc = states['loc']
tone = states['tone']

# The only state in this task, as a (location, tone) tuple. Used many times.
start = (loc['Pport'], tone['6kHz'])
env_params = {'start': start}
params['events_per_trial'] = 1

# Initial values for the left/right reward settings.
# These change with each block of trials.
prwdL = 0.5
prwdR = 0.8

# T and R are nested dictionaries (state -> action -> list of pairs) to
# improve readability/prevent mistakes.
# Every action leads from start back to start; the pair is presumably
# (next state, probability).
Tbandit = {
    start: {
        action: [(start, 1)]
        for action in (act['left'], act['right'])
    }
}

# Each action yields the full reward paired with p and the baseline paired
# with 1 - p; the pairs are presumably (reward value, probability).
Rbandit = {
    start: {
        act[side]: [(rwd['reward'], p_rwd), (rwd['base'], 1 - p_rwd)]
        for side, p_rwd in (('left', prwdL), ('right', prwdR))
    }
}
if __name__ == '__main__':
    ######## Make sure all needed transitions have been created
    # Runs only when this file is executed directly, not when it is imported.
    # validate_T and validate_R are imported from BanditTaskParam; presumably
    # they check that the T and R tables are complete and consistent with
    # each other -- confirm against that module.
    validate_T(Tbandit, msg='validate bandit T')
    validate_R(Tbandit, Rbandit, msg='validate bandit R')