# -*- coding: utf-8 -*-
"""
Created on Wed Dec 2 17:22:45 2020
@author: kblackw1
"""
import numpy as np
def run_sim(params,numevents,noise,results,sa_combo,trial_subset,printR=False):
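    """Run one acquisition simulation: build an RL agent, run an episode of
    numevents events with exploration noise, then count occurrences of each
    (state,action) combo in sa_combo (plus rewards) over the first and last
    trial_subset events, appending the counts to results."""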
import sequence_env as env
import agent_twoQtwoSsplit as QL
from SequenceTask import RL
from SequenceTaskParam import env_params, states,act,Tloc,R
acq = RL(env.separable_T, QL.QL, states,act,R,Tloc,params,env_params,printR=printR)
acq.episode(numevents,noise=noise,info=False)
results=acq.count_actions(results,sa_combo,trial_subset,accum_type='count')
#count number of times maximal reward obtained during trial_subset
del acq
return results
numtrials=600 #600 trials with Hx_len=4 and 10 runs takes < 10 min on a laptop; 25 parameter combinations thus take < 250 min (~4 hours) - HiPri
runs=10 #use 2 to test, 10 to simulate
noise=0.01 #keep noise (or state_thresh) small enough to minimize creation of new states during acquisition
printR=False #print environment Reward matrix
from SequenceTask import save_results,construct_key
from SequenceTaskParam import Hx_len,params
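#a state is a (location, press-history) tuple; in the combos below, '*' presumably acts as a wildcard matching any value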
if Hx_len==3:
    #minimum number of actions needed for a reward is 6, so the maximum reward rate is 1 per 6 events (one "trial")
state_action_combos=[(('*','*LL'), 'goR'),(('Rlever','*LL'),'press'),(('Rlever','LLR'),'press')]
events_per_trial=6
elif Hx_len==4:
    #minimum number of actions needed for a reward is 7, so the maximum reward rate is 1 per 7 events (one "trial")
state_action_combos=[(('*','**LL'), 'goR'),(('Rlever','**LL'),'press'),(('Rlever','*LLR'),'press'),(('Rlever','LLRR'),'goMag')]
events_per_trial=7
else:
    raise ValueError('unrecognized press history length: Hx_len='+str(Hx_len))
params['events_per_trial']=events_per_trial
epochs=['Beg','End'] #labels for the first and last trial_subset events of each run
#loop over values of state_thresh and alpha1 below (the 2nd-Q threshold and alpha stay fixed)
state_thresh=[0.5,0.625,0.75,0.875,1.0] # 5 values
alpha1=[0.1,0.2,0.3,0.4,0.5]
numevents=numtrials*events_per_trial #total number of events/actions the agent takes per run
trial_subset=int(0.05*numevents) #count actions and mean reward over the first and last trial_subset events
max_correct=trial_subset/events_per_trial/100 #maximum rewards obtainable in trial_subset events, scaled by 1/100
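#e.g. numtrials=600 and events_per_trial=7 give numevents=4200, trial_subset=210, max_correct=210/7/100=0.3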
keys=construct_key(state_action_combos +['rwd'],epochs)
allresults={k+'_'+ep:[] for k in keys.values() for ep in epochs}
allresults['params']={p:[] for p in params} #to store list of parameters
resultslist={k+'_'+ep:[] for k in keys.values() for ep in epochs}
resultslist['params']={p:[] for p in params}
st2=0 #state_thresh and alpha for the 2nd Q matrix; placeholders, since numQ=1 below
a2=0
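#the sweep below covers 5 state_thresh x 5 alpha1 values = 25 parameter combinations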
for st1 in state_thresh:
for a1 in alpha1:
params['numQ']=1
        params['state_thresh']=[round(st1,3),round(st2,3)] #prob threshold for creating a new state; higher threshold means more states
params['alpha']=[round(a1,3),round(a2,3)]
params['Hx_len']=Hx_len
for p in allresults['params'].keys():
            allresults['params'][p].append(params[p]) #record this combination's parameter values
            resultslist['params'][p].append(params[p])
results={sa:{'Beg':[],'End':[]} for sa in state_action_combos+['rwd']}
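        #each (state,action) key accumulates per-run counts for the 'Beg' and 'End' epochs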
for r in range(runs):
results=run_sim(params,numevents,noise,results,state_action_combos,trial_subset,printR)
allresults,resultslist=save_results(results,epochs,allresults,keys,resultslist)
resultslist['params']['max_correct']=max_correct
        fname=f"Sequence_paramsHxLen{params['Hx_len']}_Q{params['numQ']}_{round(st1,3)}_{round(a1,3)}"
np.savez(fname,allresults=allresults,params=params,reslist=resultslist)
fname=f"Sequence_paramsHxLen{params['Hx_len']}_Q{params['numQ']}_all"
np.savez(fname,allresults=allresults,params=params,reslist=resultslist)
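#To inspect the saved results later (a minimal sketch): the dicts are pickled inside
#the .npz file, so np.load needs allow_pickle=True and .item() to recover each dict:
#   dat=np.load('Sequence_paramsHxLen4_Q1_all.npz',allow_pickle=True)
#   allresults=dat['allresults'].item()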