# -*- coding: utf-8 -*-
"""
Created on Fri Dec 10 2021
@author: kblackw1
"""
import numpy as np
def plot_trajectory(output_data,title,figure_sets):
    '''Plot mean +/- standard error across blocks for each state-action item (and reward)
    in each set of phases: one figure per phase set, one panel per item, one colored trace per phase.
    Returns the figure for the last phase set.'''
    ############# Plots for publication ##########
    from matplotlib import pyplot as plt
    plt.ion()
    colors=plt.get_cmap('inferno') #plasma, viridis, inferno or magma possible
    #colors=['k','r','b','gray']
    for phases in figure_sets:
        if len(phases)>1:
            color_increment=int((len(colors.colors)-40)/(len(phases)-1)) #subtract 40 to avoid colors that are too light
        else:
            color_increment=127
        all_items=list(reversed(list(set([item for phs in phases for item in output_data[phs].keys()]))))
        panels=[len(output_data[phs].keys()) for phs in phases]
        fig,axis=plt.subplots(nrows=max(panels),ncols=1,sharex=True)
        fig.suptitle(title)
        ymin=0;ymax=0
        for phs in phases:
            #track min and max across phases for the y-axis limits
            ymax=max(ymax,np.max([np.max(vals['mean']+vals['sterr']) for vals in output_data[phs].values()]))
            ymin=min(ymin,np.min([np.min(vals['mean']+vals['sterr']) for vals in output_data[phs].values()]))
            cnum=phases.index(phs)*color_increment
            for ta,data in output_data[phs].items():
                ax=all_items.index(ta) #panel index for this state-action item
                #print(phs,ta,type(data['mean']))
                if isinstance(data['mean'],(np.ndarray,list)):
                    if np.all(np.isnan(data['sterr'])):
                        axis[ax].plot(range(len(data['mean'])),data['mean'],label=phs,c=colors.colors[cnum])
                    else:
                        axis[ax].errorbar(range(len(data['mean'])),data['mean'],yerr=data['sterr'],label=phs,capsize=5,c=colors.colors[cnum])
                    if ta=='rwd':
                        axis[ax].set_ylabel('reward')
                        axis[ax].set_ylim([np.floor(ymin),np.ceil(ymax)])
                    else:
                        if len(ta[0])>1:
                            ylabel=ta[0][0][0:4]+','+ta[0][1]+' '+ta[1][0:3] #abbreviate (state1,state2) and action
                        else:
                            ylabel=ta[0][0]+' '+ta[1]
                        axis[ax].set_ylabel(ylabel)
                        axis[ax].set_ylim([0,np.ceil(ymax)*1.05])
                    #if phs == 'discrim' or phs == 'reverse':
                    #    axis[ax].set_ylim([0,10])
                    #else:
                    #    axis[ax].set_ylim([0,11])
                    axis[ax].legend()
        axis[-1].set_xlabel('block')
        #plt.show()
    return fig
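#Illustrative usage sketch (an assumption, not part of the original pipeline): plot_trajectory
#expects output_data keyed by phase name, with each phase mapping a state-action item (or 'rwd')
#to {'mean':...,'sterr':...}.  The phase names and (state,action) item below are hypothetical.
def _demo_plot_trajectory():
    rng=np.random.default_rng(0)
    items=[(('Pport','6kHz'),'left'),'rwd'] #hypothetical state-action item plus reward
    fake_data={phs:{ta:{'mean':rng.random(10),'sterr':np.full(10,0.05)} for ta in items}
               for phs in ['acquire','discrim']}
    return plot_trajectory(fake_data,'demo trajectories',[['acquire','discrim']])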
def save_results(results,key_dict,resultslist):
    '''Append the per-epoch counts for each state-action combination onto the lists in
    resultslist, using key_dict to translate the combination into a column name.'''
    for phase in results.keys():
        if phase in resultslist.keys():
            for sacombo in results[phase].keys():
                for ep,counts in results[phase][sacombo].items():
                    resultslist[phase][key_dict[sacombo]+'_'+ep].append(counts)
    return resultslist
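#Hedged sketch (hypothetical names) of the nested layout save_results expects, inferred from
#the loops above: results[phase][sacombo][epoch] holds a count, key_dict maps the state-action
#combo to a column name (see construct_key below), and resultslist holds the lists appended to.
def _demo_save_results():
    key_dict={(('Pport','6kHz'),'left'):'Pport_6kHz_left'}
    results={'acquire':{(('Pport','6kHz'),'left'):{'Beg':3,'End':7}}}
    resultslist={'acquire':{'Pport_6kHz_left_Beg':[],'Pport_6kHz_left_End':[]}}
    return save_results(results,key_dict,resultslist)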
#Older version of save_results (different signature), kept for reference:
'''
def save_results(results,epochs,allresults,resultslist):
    for phase in results.keys():
        for ac,counts in results[phase].items():
            for ep in epochs:
                allresults[phase+'_'+ac+'_'+ep].append(np.round(np.mean(counts[ep]),3))
                print(phase,ac,ep,counts[ep])
                resultslist[phase+'_'+ac+'_'+ep].append(counts[ep])
    return allresults,resultslist
'''
def construct_key(state_actions,epochs=None):
    '''Build a string key for each state-action combination: the state elements joined by
    underscores, followed by the action.  The special 'rwd' entry maps to 'rwd_'.
    (epochs is currently unused.)'''
    keys={}
    for sacombo in state_actions:
        if sacombo=='rwd':
            env=['rwd']
            ac=''
        else:
            env=sacombo[0]
            ac=sacombo[1]
        keys[sacombo]='_'.join(env)+'_'+ac
    return keys
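#Worked example (hedged; the state-action tuples are hypothetical) of the keys construct_key builds.
def _demo_construct_key():
    #expected result: {(('Pport','6kHz'),'left'):'Pport_6kHz_left', 'rwd':'rwd_'}
    return construct_key([(('Pport','6kHz'),'left'),'rwd'])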
def run_sims(RL,phase,events,n_subset,action_items,noise,info,cues,rr,summary,phist=0,block_DA=False):
    '''Run one learning phase: simulate the episode, summarize state-action counts over the
    last n_subset trials, and collect the agent's Q matrices; plot if rr==0.'''
    RL.episode(events,noise=noise,info=info,cues=cues,name=phase,block_DA=block_DA)
    rwd_prob=np.mean(RL.agent.learn_hist['rwd_prob'][-n_subset:])
    #summary,t2=RL.count_actions(summary,action_items,trial_subset)
    summary,t2=RL.count_state_action(summary,action_items,n_subset)
    Q={'Q':RL.agent.Q,'ideal_states':RL.agent.ideal_states,'learn_weight':RL.agent.learn_weight,'rwd_prob':rwd_prob,'name':phase}
    if hasattr(RL.agent,'V'):
        Q['V']=RL.agent.V
    if rr==0:
        if np.max(RL.results['reward'])>0:
            t2=',mean reward='+str(np.round(np.mean(RL.results['reward'][-n_subset:]),2))
        RL.set_of_plots(phase,noise,t2,hist=phist)
    return summary,Q
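#Hedged usage sketch for run_sims: it assumes an RL object (built elsewhere in this code base)
#that provides episode(), count_state_action(), set_of_plots(), results['reward'], and an agent
#with Q, ideal_states, learn_weight and learn_hist.  A typical call, with hypothetical arguments:
#  summary,Q=run_sims(RL,'acquire',events,n_subset=100,action_items=action_items,
#                     noise=0.01,info=False,cues=cues,rr=0,summary=summary)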
def beta_lenQ(all_agents,all_phases,all_beta,all_lenQ,numQ):
    '''For each phase set, concatenate the per-trial beta values and Q-matrix lengths (lenQ)
    from the learn_hist of each phase's agent, and append them under the joined phase-set key.'''
    for phase_set in all_phases:
        beta=[];lenQ={q:[] for q in range(numQ)}
        key='_'.join(phase_set)
        for phs in phase_set:
            beta.append(all_agents[phs].agent.learn_hist['beta'])
            for q,qlen in all_agents[phs].agent.learn_hist['lenQ'].items():
                lenQ[q].append(qlen)
        all_beta[key].append([b for bb in beta for b in bb]) #flatten the list of per-phase lists
        for q in lenQ.keys():
            all_lenQ[key][q].append([b for bb in lenQ[q] for b in bb])
    return all_beta,all_lenQ
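#Hedged sketch (not part of the original callers): fake agents whose learn_hist mimics the
#structure beta_lenQ reads, showing how per-phase histories are flattened and appended under
#the joined phase-set key.  Phase names and values are hypothetical.
def _demo_beta_lenQ():
    from types import SimpleNamespace
    def fake_agent():
        return SimpleNamespace(agent=SimpleNamespace(learn_hist={'beta':[0.1,0.2],'lenQ':{0:[5,6],1:[3,4]}}))
    all_agents={'acquire':fake_agent(),'discrim':fake_agent()}
    all_beta={'acquire_discrim':[]}
    all_lenQ={'acquire_discrim':{0:[],1:[]}}
    return beta_lenQ(all_agents,[['acquire','discrim']],all_beta,all_lenQ,numQ=2)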