# -*- coding: utf-8 -*-
"""
Created on Wed Dec 2 17:22:45 2020
@author: kblackw1
"""
from __future__ import print_function, division
def run_sim(params,numevents,noise,results,sa_combo,trial_subset,printR=False):
import sequence_env as env
import agent_twoQtwoSsplit as QL
from SequenceTask import RL
from SequenceTaskParam import env_params, states,act,Tloc,R
acq = RL(env.separable_T, QL.QL, states,act,R,Tloc,params,env_params,printR=printR)
acq.episode(numevents,noise=noise,info=False)
results=acq.count_actions(results,sa_combo,trial_subset,accum_type='count')
del acq
return results
def run_one_set(p):
st1,st2,q2o=p
import numpy as np
from SequenceTask import save_results,construct_key
numtrials=600 #allow agent to perform this many actions/events
runs=10
noise=0.01 #make noise small enough or state_thresh small enough to minimize new states in acquisition
printR=False #print environment Reward matrix
from SequenceTaskParam import Hx_len,params
if Hx_len==3:
#MINIMUM actions for reward = 6, so maximum rewards = 1 per 6 actions/events
state_action_combos=[(('*','*LL'), 'goR'),(('Rlever','*LL'),'press'),(('Rlever','LLR'),'press')]
events_per_trial=6
elif Hx_len==4:
#MINIMUM actions for reward = 7, so maximum rewards = 1 per 7 actions/events
state_action_combos=[(('*','**LL'), 'goR'),(('Rlever','**LL'),'press'),(('Rlever','*LLR'),'press'),(('Rlever','LLRR'),'goMag')]
events_per_trial=7
else:
print('unrecognized press history length')
params['events_per_trial']=events_per_trial
alpha1=[0.2,0.3,0.4,0.5]
alpha2=[0.1,0.15,0.2,0.25,0.3,0.35,0.4]
numevents=numtrials*events_per_trial #number of events/actions allowed for agent per run/trial
trial_subset=int(0.05*numevents)# display mean reward and count actions over 1st and last of these number of trials
max_correct=trial_subset/events_per_trial/100
epochs=['Beg','End']
keys=construct_key(state_action_combos +['rwd'],epochs)
### allresults - store mean performance vs parameter at begining and end of trials
allresults={k+'_'+ep:[] for k in keys.values() for ep in epochs}
allresults['params']={p:[] for p in params} #to store list of parameters
resultslist={k+'_'+ep:[] for k in keys.values() for ep in epochs}
resultslist['params']={p:[] for p in params}
for a1 in alpha1: #at least 2*a2, double increment
for a2 in alpha2:
print('************ NEW SIM *********',np.round(st1,3),np.round(st2,3),np.round(a1,3),np.round(a2,3))
#update some parameters
params['numQ']=2
params['state_thresh']=[np.round(st1,3),np.round(st2,3)] #threshold on prob for creating new state
# higher means more states.
params['alpha']=[np.round(a1,3),np.round(a2,3)]
params['Hx_len']=Hx_len
params['Q2other']=q2o
#results: initialize for each set of parameters
for p in allresults['params'].keys():
allresults['params'][p].append(params[p]) #
resultslist['params'][p].append(params[p]) #
results={sa:{'Beg':[],'End':[]} for sa in state_action_combos+['rwd']}
for r in range(runs):
results=run_sim(params,numevents,noise,results,state_action_combos,trial_subset,printR)
allresults,resultslist=save_results(results,epochs,allresults,keys,resultslist)
resultslist['params']['max_correct']=max_correct
fname='Sequence_HxLen'+str(params['Hx_len'])+'_Q'+str(params['numQ'])+'_q2o'+'_'.join([str(params['Q2other']),str(round(st1,3)),str(round(st2,3))])+'_all'
print('**************** End of a1,a2 loop ************', fname)
np.savez(fname,allresults=allresults,params=params,reslist=resultslist)
return
if __name__ == "__main__":
from multiprocessing.pool import Pool
import os
#loop over values for state_Thresh, alpha1,alpha2 here
state_thresh=[0.5,0.625,0.75,0.875,1.0]
Q2_other=[0.05,0.1,0.2]
############### This is not quite working - simulations randomly stop for some param combos with no error message ############
########## Possibly running out of memory? ###################
params=[(round(st1,3),round(st2,3),round(q2o,2)) for st1 in state_thresh for st2 in state_thresh for q2o in Q2_other]
max_pools=os.cpu_count()
#num_pools=min(len(params),max_pools) #needed on single workstation
num_pools=len(params)
print('************* number of processors',max_pools,' params',len(params),params)
p = Pool(num_pools,maxtasksperchild=1)
p.map(run_one_set,params)
print('#################### Returned from p.map ##################')