# -*- coding: utf-8 -*-
"""
Created on Thu May 13 10:44:25 2021
@author: kblackw1
"""
import numpy as np
import matplotlib.pyplot as plt
import copy
import sys
plt.ion()
# TODO: possibly accumulate Qhx, as is done for the Sequence task:
#   - separate this functionality into an accumulating step and a plotting step
#   - call the Qhx accumulation function inside the "for r in runs" loop
#   - possibly save Qhx to file, and read it back in for the plot
import string
# Uppercase letters, used below as subplot panel labels (A, B, C, ...)
letters=string.ascii_uppercase
# Margin, in figure-fraction units, used when positioning panel labels
blank=0.03
# Font size for panel labels and (minus 2) subplot titles
fsize=12
# Font size for legends and tick labels, and (plus 1) axis labels
fsizeSml=10
def Qhx_multiphaseNewQ(states,actions,agents,numQ):
    """Concatenate Q-value histories across phases when each phase has its own state numbering.

    Agents are processed in increasing order of the ratio encoded in their
    name ("a:b" -> a/b).  For every agent, Q matrix and requested state, the
    state number is looked up in that agent's own ideal_states, so this
    works when each phase starts from a new Q (no oldQ re-use).

    Parameters
    ----------
    states : iterable of (loc, tone) tuples to extract
        NOTE: only two-component states are supported (see lookup below).
    actions : iterable of action names (keys of rl.agent.actions)
    agents : iterable of objects exposing .name, .env.states,
        .agent.ideal_states, .agent.Qhx, .agent.actions, .agent.events_per_trial
    numQ : number of Q matrices per agent

    Returns
    -------
    Qhx : {q: {state: {action: [1-D array]}}} concatenated history
    boundary : {q: {state: [0, x1, x2, ...]}} cumulative phase ends, in trials
    ideal_states : {q: {state: [[str, ...]]}} rounded ideal state of the last
        agent processed, or [nan, nan] if that agent lacks the state

    Exits via sys.exit if more than one ideal state matches a requested state.
    A state that is missing, or whose number lies outside the Q array, is
    padded with zeros for that phase so histories stay aligned with boundary.
    """
    sorted_agents=sorted(agents, key=lambda x: float(x.name.split(':')[0])/float(x.name.split(':')[1]))
    state_digits=1
    ideal_states={q:{} for q in range(numQ)}
    Qhx={q:{state:{ac:[[]] for ac in actions} for state in states} for q in range(numQ)}
    #boundary stores x values for drawing phase boundaries
    boundary={q:{state:[0] for state in states} for q in range(numQ)}
    for rl in sorted_agents:
        loc_codes=rl.env.states['loc']
        tone_codes=rl.env.states['tone']
        for q in range(numQ):
            phase_ideals=rl.agent.ideal_states[q]
            phase_Q=rl.agent.Qhx[q]
            num_events=len(phase_Q)
            state_nums={}
            for state in states:
                ######## Note, [0] and [1] below implies each state has two components. Need to generalize this
                matches=[stnum for stnum,st in phase_ideals.items()
                         if int(round(st[0]))==loc_codes[state[0]] and int(round(st[1]))==tone_codes[state[1]]]
                if len(matches)>1:
                    sys.exit('Error, too many states {} for {} created for phase {}, Q {} starting from NEW '.format(matches,state,rl.name,q))
                elif len(matches)==0:
                    print(' @@@@@@@@@@@@@@@@@ state',state,'not found for Q=',q, 'agent',rl.name)
                    matches=[-1] #sentinel: state absent in this phase
                    ideal_states[q][state]=[np.nan,np.nan]
                else:
                    ideal_states[q][state]=[[str(round(a,state_digits)) for a in phase_ideals[stnum]] for stnum in matches]
                state_nums[state]=matches
            for state,stnums in state_nums.items():
                boundary[q][state].append(boundary[q][state][-1]+num_events/rl.agent.events_per_trial)
                for ac in Qhx[q][state]:
                    for arr_num,stn in enumerate(stnums):
                        if stn==-1: #arr_num = 0, no other stn values
                            print( ' @@@@@@ ready to add Q history for missing state',state,'for Q=',q, 'agent',rl.name,'1 state num:', stnums)
                        if 0<=stn<np.shape(phase_Q)[1]:
                            segment=phase_Q[:,stn,rl.agent.actions[ac]]
                        else:
                            #missing or out-of-range state: zeros keep the history the same length as the phase
                            segment=np.zeros(num_events)
                        previous=Qhx[q][state][ac][arr_num]
                        if len(previous):
                            Qhx[q][state][ac][arr_num]=np.concatenate((previous,segment))
                        else:
                            Qhx[q][state][ac][arr_num]=segment
    return Qhx, boundary,ideal_states
def Qhx_multiphase(states,actions,agents,numQ):
    """Concatenate Q-value histories across phases that share state numbering.

    First pass: for every agent (in the order given), collect the state
    numbers whose ideal state matches each requested (loc, tone) state.
    Second pass: for every matched state number, append each agent's
    Q history, using zeros when the number is beyond that agent's Q array.

    Returns (Qhx, boundary, ideal_states):
      Qhx[q][state][action] - list with one 1-D array per matched state number
      boundary[q][state] - [0, x1, x2, ...] cumulative phase ends, in trials
      ideal_states[q][state] - rounded ideal states (as strings), taken from
        the last agent processed
    This does not work if each phase starts from New (not using oldQ).
    """
    state_digits=1
    state_nums={q:{state:[] for state in states} for q in range(numQ)}
    ideal_states={q:{} for q in range(numQ)}
    #pass 1: identify the state number (index into ideal_states) for each state
    for rl in agents:
        for q in range(numQ):
            phase_ideals=rl.agent.ideal_states[q]
            for state in states:
                ######## Note, 'loc',[0] and 'tone',[1] below implies each state has two components. Need to generalize this
                matches=[stnum for stnum,st in phase_ideals.items()
                         if int(round(st[0]))==rl.env.states['loc'][state[0]] and int(round(st[1]))==rl.env.states['tone'][state[1]]]
                state_nums[q][state]=list(np.unique(state_nums[q][state]+matches))
                #next line may fail if subsequent agents don't use oldQs
                ideal_states[q][state]=[[str(round(a,state_digits)) for a in phase_ideals[stnum]] for stnum in state_nums[q][state]]
    #pass 2: concatenate Q value history across phases for above states and actions
    Qhx={q:{state:{ac:[[] for _ in stnums] for ac in actions} for state,stnums in per_state.items()} for q,per_state in state_nums.items()}
    #boundary stores x values for drawing phase boundaries
    boundary={q:{state:[0] for state in state_nums[q]} for q in Qhx}
    for q in Qhx:
        for rl in agents:
            for state,stnums in state_nums[q].items():
                phase_Q=rl.agent.Qhx[q]
                phase_end=boundary[q][state][-1]+len(phase_Q)/rl.agent.events_per_trial
                boundary[q][state].append(phase_end)
                for ac,histories in Qhx[q][state].items():
                    for arr_num,stn in enumerate(stnums):
                        if stn<np.shape(phase_Q)[1]:
                            segment=phase_Q[:,stn,rl.agent.actions[ac]]
                        else:
                            #add array of zeros if a state not represented in Q matrix
                            segment=np.zeros(len(phase_Q))
                        if len(histories[arr_num]):
                            histories[arr_num]=np.concatenate((histories[arr_num],segment))
                        else:
                            histories[arr_num]=segment
    return Qhx, boundary,ideal_states
def plot_Qhx_OpAL(Qhx,boundary,ept,ac,title='',labels=None):
    """Plot the Q history of one action, overlaying the phases on a common x axis.

    Each phase (the segment between consecutive boundaries) is drawn from
    x=0 in a different shade of blue.  When the Q keys are exactly [0, 1],
    a third column shows the sum of the two histories ('G+N').

    Parameters
    ----------
    Qhx : {q: {state: {action: [1-D arrays]}}}
    boundary : {q: {state: [0, x1, ...]}} phase boundaries, in trials
    ept : events per trial, converts boundaries into array indices
    ac : name of the action to plot
    title : figure title
    labels : legend entries for the first axis; default legend if None

    Returns the matplotlib figure.
    """
    from matplotlib import pyplot as plt
    cmap=plt.get_cmap('Blues').reversed()
    qname=['G','N']
    qkeys=list(Qhx.keys())
    has_total=(qkeys==[0,1])
    #the axes grid is ncols wide, so ncols (not len(Qhx)) is the row stride into fig.axes
    ncols=3 if has_total else len(qkeys)
    fig,_=plt.subplots(len(Qhx[qkeys[0]]),ncols,sharex=True)
    ax=fig.axes
    fig.suptitle(title)
    for col,q in enumerate(qkeys):
        for row,state in enumerate(Qhx[q].keys()):
            axnum=col+row*ncols
            axtot=2+row*ncols
            bounds=boundary[q][state]
            col_inc=cmap.N*(5/6)/len(bounds)
            for arr_num,arr in enumerate(Qhx[q][state][ac]):
                if has_total:
                    arrTot=Qhx[0][state][ac][arr_num]+Qhx[1][state][ac][arr_num]
                for b in range(len(bounds)-1):
                    color=cmap(int(b*col_inc))
                    bstart=int(bounds[b]*ept)
                    bend=int(bounds[b+1]*ept)
                    Xvals=np.arange(0,bend-bstart)
                    ax[axnum].plot(Xvals,arr[bstart:bend],label=ac,color=color)
                    if has_total:
                        ax[axtot].plot(Xvals,arrTot[bstart:bend],label=ac,color=color)
                        ax[axtot].set_ylabel('G+N')
                        ax[axtot].set_xlabel('Trial',fontsize=fsizeSml+1)
            ax[axnum].set_xlabel('Trial',fontsize=fsizeSml+1)
        ax[col].set_ylabel(qname[col])
    if labels is None:
        ax[0].legend() #legend(None) raises; fall back to the line labels
    else:
        ax[0].legend(labels)
    return fig
def plot_Qhx_2D(Qhx,boundary,ept,phases,ideal_states=None,fig=None,ax=None,title=''):
    """Plot Q-value histories for the bandit and Discrim tasks: one column per Q, one row per state.

    Parameters
    ----------
    Qhx : {q: {state: {action: [1-D arrays]}}}; keys may be swapped
        (q holds the state string, state holds the Q number) when the
        dictionary has been re-arranged by the caller
    boundary : {q: {state: [0, x1, ...]}} phase boundaries in trials, drawn
        as dashed vertical lines
    ept : events per trial, converts event index to trial number
    phases : kept for interface compatibility; phase names are currently not
        drawn (the text annotation is disabled)
    ideal_states : optional {q: {state: [[str, ...]]}}; when a state has
        several arrays, the last component of each ideal state selects the
        legend letter
    fig, ax : existing figure and list of axes to draw into; a new figure is
        created when fig is None, and fig.axes is used when only fig is given
    title : figure title

    Returns the matplotlib figure.
    """
    ######### plot Qhx for bandit and Discrim task #############
    from matplotlib import pyplot as plt
    colors=[plt.get_cmap('Blues'),plt.get_cmap('Reds'),plt.get_cmap('Purples'),plt.get_cmap('Greys')]
    Qname={0:'G',1:'N'}
    if fig is None:
        fig,_=plt.subplots(len(Qhx[list(Qhx.keys())[0]]),len(Qhx),sharex=True)
        ax=fig.axes
    elif ax is None:
        ax=fig.axes
    fig.suptitle(title)
    for col,q in enumerate(Qhx.keys()):
        if isinstance(q,int):
            #fall back to a generic name for Q numbers other than 0 and 1
            ax[col].set_title(Qname.get(q,'Q'+str(q+1))+' values',fontsize=fsize-2)
        for row,state in enumerate(Qhx[q].keys()):
            axnum=col+row*len(Qhx)
            print('ax',axnum,'Q',q,'state',state,'row',row)
            for cnum,ac in enumerate(Qhx[q][state].keys()):
                #cycle through the colormaps so more than len(colors) actions can be plotted
                cmap=colors[cnum%len(colors)].reversed()
                col_inc=cmap.N*(2/3)/len(Qhx[q][state][ac])
                for arr_num,arr in enumerate(Qhx[q][state][ac]):
                    color=cmap(int(arr_num*col_inc))
                    if ideal_states is not None and len(Qhx[q][state][ac])>1:
                        label=int(float(ideal_states[q][state][arr_num][-1])) #for Qhx figure in manuscript
                        leg_cols=2
                    else:
                        label=''
                        leg_cols=1
                    Xvals=np.arange(len(arr))/ept
                    ax[axnum].plot(Xvals,arr,label=ac+' '+ str(label),color=color)
            #Next 3 lines are just for block case.
            startx=0#160#
            endx=Xvals[-1]
            if endx>startx:
                ax[axnum].set_xlim(startx,endx+(endx-startx)*0.05)
            else:
                print(Xvals)
            handles,labels=ax[axnum].get_legend_handles_labels()
            if leg_cols==2:
                newlabels=[letters[int(lbl.split()[-1])] for lbl in labels]#for Qhx figure in manuscript
                leg_ttl=list(np.unique([lbl.split()[0] for lbl in labels]))
                leg_loc='center right'#'best'#
                ax[axnum].legend(handles,newlabels,loc=leg_loc,ncol=leg_cols,title=' '.join(leg_ttl),fontsize=fsizeSml,title_fontsize=fsizeSml,handletextpad=0.2,labelspacing=0.3,columnspacing=1)
            else:
                if len(ax) > 1:
                    ax[1].legend(loc='lower left',ncol=leg_cols,fontsize=fsizeSml)
                ax[0].legend(loc='upper right',ncol=leg_cols,fontsize=fsizeSml)
            if isinstance(state,tuple) or isinstance(state,list):
                ax[axnum].set_ylabel(','.join(list(state)),fontsize=fsizeSml+1)
            elif isinstance(state,str):
                ax[axnum].set_ylabel(state.split()[0],fontsize=fsizeSml+1)
            elif isinstance(q,str):#when re-arranging dictionary, then q has state and state has q val
                ax[axnum].set_ylabel('Q'+str(state+1)+' values',fontsize=fsizeSml+1)
            if row==len(Qhx[q].keys())-1:
                ax[axnum].set_xlabel('Trial',fontsize=fsizeSml+1)
            ylim=ax[axnum].get_ylim()
            maxQ=max(ylim)
            minQ=min(ylim)
            Qrange=maxQ-minQ
            ax[axnum].set_ylim([round(minQ-0.1*Qrange),round(maxQ+0.2*Qrange)]) #inset the curve to make room for text
            ax[axnum].tick_params(axis='both', which='major', labelsize=fsizeSml)
            #dashed vertical line at the end of each phase
            for xval in boundary[q][state][1:]:
                ax[axnum].vlines(xval,round(minQ),round(maxQ),linestyles='dashed',color='grey')
    return fig
#normalize to optimal? But what is optimal here?
def agent_response(runs,random_order,num_blocks,traject_dict,trials_per_block,fig=None,ax=None,norm=1):
    """Plot per-block 'left' and 'right' responses to (Pport, 6kHz) for each run.

    For every run r, the responses stored in traject_dict[phase] are laid
    end to end in the order given by random_order[r], scaled by norm, and
    plotted as dots; phase names and dashed dividers mark each phase.
    The supplied fig/ax are used for the first run only; every later run
    (or the first, when fig is None) gets a new figure titled with the run.

    Returns the figure used for the last run.
    """
    left_key=(('Pport', '6kHz'), 'left')
    right_key=(('Pport', '6kHz'), 'right')
    for run_index,r in enumerate(runs):
        if fig is None or run_index>0:
            fig,ax=plt.subplots()
            fig.suptitle('agent '+str(r))
        phase_order=random_order[r]
        num_points=len(phase_order)*num_blocks
        left=np.zeros(num_points)
        right=np.zeros(num_points)
        for k,key in enumerate(phase_order):
            segment=slice(k*num_blocks,(k+1)*num_blocks)
            left[segment]=np.array(traject_dict[key][left_key][r])*norm
            right[segment]=np.array(traject_dict[key][right_key][r])*norm
        #now plot the single trials
        ax.plot(left,'b.',label='left')
        ax.plot(right,'r.',label='right')
        ax.set_ylabel('Response Rate')
        ax.set_xlabel('Trial')
        print('num_blocks',num_blocks,'left',len(left),trials_per_block)
        #ticks are placed every 10 blocks and labelled in units of trials
        tick_positions=np.arange(0,len(left),10)
        ax.set_xticks(tick_positions,tick_positions*trials_per_block)
        ylim=ax.get_ylim()
        for k,key in enumerate(phase_order):
            ax.text((k+0.5)*num_blocks,ylim[1],key,ha='center',transform=ax.transData)
            ax.vlines(k*num_blocks,0,10.1,color='gray',linestyles='dashed',linewidths=1)
        ax.set_xlim([0,len(left)])
        ax.set_ylim([ylim[0],ylim[1]*1.1])
        ax.legend()#loc='center')
    return fig
def plot_Qhx_sequence(Qhx,actions,ept,numQ):
    """Plot run-averaged Q histories for the sequence task, one figure per state.

    Qhx is indexed as Qhx[state][q][press_history][action].  Press histories
    whose values are practically constant for any Q (floor of the maximum
    equals ceil of the minimum) are dropped from a deep copy before
    plotting.  actions maps action name to line color.

    Returns {state: figure}.
    """
    ### some states are practically zero, delete these from Qhx to be plotted
    import copy
    newQhx=copy.deepcopy(Qhx)
    for state,per_q in Qhx.items():
        maxQboth={row:[] for row in per_q[0]}
        minQboth={row:[] for row in per_q[0]}
        for q in per_q.keys():
            for press_hx,per_action in per_q[q].items():
                maxQboth[press_hx].append(np.floor(np.max([np.max(arr) for arr in per_action.values()])))
                minQboth[press_hx].append(np.ceil(np.min([np.min(arr) for arr in per_action.values()])))
        #a row is flagged once for every Q in which its rounded range collapses
        deleterow=[row for row in maxQboth.keys()
                   for hi,lo in zip(maxQboth[row],minQboth[row]) if hi==lo]
        print('for state=',state,', not plotting these press histories:',np.unique(deleterow))
        for row in np.unique(deleterow):
            for q in newQhx[state].keys():
                del newQhx[state][q][row]
    #
    ######### Now create the plot #############
    figures={}
    for state,per_q in newQhx.items():
        num_cols=len(per_q)
        fig,_=plt.subplots(len(per_q[0]),num_cols,sharex=True)
        figures[state]=fig
        ax=fig.axes
        for col,q in enumerate(per_q.keys()):
            ax[col].set_title(str(numQ)+'Q, Q'+str(q+1)+' values')
            for row,press_hx in enumerate(per_q[q]):
                axnum=col+row*num_cols
                per_action=per_q[q][press_hx]
                maxQ=round(np.max([np.max(arr) for arr in per_action.values()]),1)
                minQ=round(np.min([np.min(arr) for arr in per_action.values()]),1)
                for ac,color in actions.items():
                    #Average Q values across runs
                    Yvals=np.mean(per_action[ac],axis=0)
                    #trial number = event/events_per_trial (ept)
                    Xvals=np.arange(len(Yvals))/ept
                    ax[axnum].plot(Xvals,Yvals,label=ac,color=color)
                if col==0:
                    ax[axnum].set_ylabel(press_hx)
                ax[axnum].set_ylim([minQ,maxQ])
            #legend and x label go on the last (bottom) axis of the column
            ax[axnum].legend()
            ax[axnum].set_xlabel('Trial')
    plt.show()
    return figures
def replace_star(states):
    """Expand '*' wildcards in the second element of each (name, pattern) state.

    One '*' is replaced in turn by 'L', 'R' and '-'.  With two '*', the
    first is replaced by 'L', 'R' or '-' and the second by 'L' or 'R'.
    States containing no '*', or more than two, contribute nothing.

    Returns the list of expanded (name, pattern) tuples.
    """
    new_states=[]
    for state in states:
        name,pattern=state[0],state[1]
        num_stars=pattern.count('*')
        if num_stars==1:
            new_states.extend((name,pattern.replace('*',symb)) for symb in ['L','R','-'])
        elif num_stars==2:
            #fill the first wildcard for all symbols before expanding the second
            first_filled=[pattern.replace('*',symb,1) for symb in ['L','R','-']]
            for partial in first_filled:
                new_states.extend((name,partial.replace('*',symb)) for symb in ['L','R'])
    return new_states
def plot_Qhx_sequence_1fig(allQhx,plot_states,actions_colors,ept,actions_lines):
    """Plot run-averaged Q histories of selected press histories in a single figure.

    Rows are the entries of plot_states (after expanding '*' wildcards with
    replace_star); the column is int(numQ)-1+int(q) for every numQ key of
    allQhx and Q number q, so the grid has three columns.

    Parameters
    ----------
    allQhx : {numQ: {state: {q: {press_history: {action: array}}}}}
    plot_states : list of (state, press_history) tuples; the caller's list is
        not modified
    actions_colors : {action: color}; 'goMag' is never drawn
    ept : events per trial, converts event index to trial number
    actions_lines : {action: linestyle}

    Returns the matplotlib figure.
    """
    numcols=3 ######## change this to numQ if 1 figure per numQ
    star_states=[ps for ps in plot_states if '*' in ps[1]]
    if len(star_states):
        new_states=replace_star(star_states)
        plot_states=plot_states+new_states
        for st in star_states:
            plot_states.remove(st)
        print('********NEW',plot_states)
    #build both lookup lists AFTER wildcard expansion so that their indices match the subplot rows
    plot_state_trunc=[state[0][0:3]+','+state[1] for state in plot_states]
    plot_state_string=[','.join(list(ps)) for ps in plot_states]
    Qtype={'10':'Q','11': 'Q2', '20':'G', '21':'N'} #Q2: if plotting more than
    fig,axis=plt.subplots(len(plot_states),numcols,sharex=True)
    ax=fig.axes
    for numQ,Qhx in allQhx.items(): ####### Remove for 1 figure per numQ
        for state in Qhx.keys():
            for q in Qhx[state].keys(): #q 1 if nq=0, q is either 1 or 2 ir nq=1
                col=int(numQ)-1+int(q) ####### col = enumerate over Qhx[state] if 1 figure per numQ
                for press_hx,per_action in Qhx[state][q].items():
                    if press_hx in plot_state_trunc:
                        row=plot_state_trunc.index(press_hx)
                    elif press_hx in plot_state_string:
                        row=plot_state_string.index(press_hx)
                    else:
                        continue #this press history was not requested
                    axnum=row*numcols+col
                    print(numQ,q,row,col,axnum,state,press_hx)
                    if row==0:
                        ax[axnum].set_title(str(numQ)+'Q, '+Qtype[str(numQ)+str(q)]+' values') #replace Q'+str(q+1)+
                    maxQ=round(np.max([np.max(arr) for arr in per_action.values()]),1) #comment out when Q1 or Q2 inactivated
                    minQ=round(np.min([np.min(arr) for arr in per_action.values()]),1)
                    for ac,color in actions_colors.items():
                        #Average Q values across runs
                        Yvals=np.mean(per_action[ac],axis=0)
                        maxQ=max(maxQ,np.max(Yvals))
                        minQ=min(minQ,np.min(Yvals))
                        #trial number = event/events_per_trial (ept)
                        Xvals=np.arange(len(Yvals))/ept
                        if ac != 'goMag':
                            ax[axnum].plot(Xvals,Yvals,label=ac,color=color, linestyle=actions_lines[ac])
                    if ('1' in allQhx.keys() and col==0) or (col==1 and '1' not in allQhx.keys()):
                        ax[axnum].set_ylabel(','.join(list(plot_states[row])),fontsize=11)
                    if minQ<maxQ:
                        ax[axnum].set_ylim([minQ,maxQ])
                    if row==len(plot_states)-1:
                        ax[axnum].set_xlabel('Trial',fontsize=12)
    axis[0][1].legend(loc='upper left')
    #give every axis in a column the same y limits
    for col in range(numcols):
        column_axes=ax[col::numcols]
        ylim=[axe.get_ylim() for axe in column_axes]
        ymin=np.floor(np.min([a[0] for a in ylim]))*1.1
        ymax=np.ceil(np.max([a[1] for a in ylim]))*1.1
        if ymax>ymin:
            for axe in column_axes:
                axe.set_ylim([ymin,ymax])
    plt.show()
    return fig
def staticQ_barplot(Q,actions,title='',labels=None,state_subset=None):
    """Visualize the Q table by bar plot, one subplot per Q matrix.

    Only states (rows) with at least one non-zero Q value are shown.

    Parameters
    ----------
    Q : {i: 2-D array (states x actions)}; the keys are used to index the
        subplots and the panel letters, so they are assumed to be 0..len(Q)-1
    actions : dict whose keys are the action names (legend entries)
    title : currently unused (the suptitle call is disabled)
    labels : optional {i: [label per state]}; the first two components of
        each label form the x tick label
    state_subset : optional list; only states whose label contains one of
        these entries are kept (ignored when labels is None)

    Returns the matplotlib figure.
    """
    fig,_=plt.subplots(len(Q),1,sharex=True)
    axes=fig.axes
    #fig.suptitle(title)
    Na=len(actions)
    colors=plt.get_cmap('inferno') #plasma, viridis, inferno or magma
    #40 to avoid to light colors; max() avoids dividing by zero for a single action
    color_increment=int((len(colors.colors)-40)/max(Na-1,1))
    num_groups={} #number of bar groups per subplot, needed for the dividers below
    for i in Q.keys():
        newQ=[];statenums=[];xlabels=[]
        for s,row in enumerate(Q[i]):
            if np.any(row):
                newQ.append(row)
                statenums.append(s)
                if labels is not None:
                    xlabels.append(labels[i][s])
        if len(xlabels) and state_subset is not None:
            keep_state=[(ii,lbl) for ii,lbl in enumerate(xlabels) for ss in state_subset if ss in lbl]
            plotQ=np.array([newQ[j] for (j,lbl) in keep_state])
            xlabels=[ks[1] for ks in keep_state]
            statenums=[ks[0] for ks in keep_state]
        else:
            plotQ=np.array(newQ)
        num_groups[i]=len(plotQ)
        w = 1./(Na+0.5) # bar width
        if len(plotQ): #an empty table has no columns to index
            for a in range(Na):
                cnum=a*color_increment
                axes[i].bar(np.arange(len(statenums))+(a-(Na-1)/2)*w, plotQ[:,a], w,color=colors.colors[cnum])
        if labels is not None:
            xticks=[' '.join(lbl[0:2]) for lbl in xlabels]
        else:
            xticks=statenums
        axes[i].set_ylabel("Q"+str(i+1)+" value")
        axes[i].set_xticks(range(len(plotQ)),xticks)
    #axes[-1].set_xlabel("state")
    fig.legend(list(actions.keys()),bbox_to_anchor = (0.98, 0.98),ncol=2)#loc='upper right')
    label_inc=(1-blank)/len(Q)
    for i in Q.keys():
        for ll in range(num_groups[i]-1):
            ylim=axes[i].get_ylim()
            axes[i].vlines(ll+0.5,ylim[0],ylim[1],'grey',linestyles='dashed') #make vertical grid - between groups of bars
        x=1-blank-(i*label_inc) #subtract because 0 is at bottom
        fig.text(blank,x,letters[i], fontsize=fsize)
    return fig
def combined_bandit_Qhx_response(random_order,num_blocks,traject_dict,Qhx,boundaries,ept,phases,agent_num=-1,all_beta=None,Qlen=None,norm=1,trials_per_block=None):
    """Stack agent responses, Q histories, beta and (optionally) Q sizes for one bandit agent.

    Rows of the figure: 0 - responses (agent_response), 1 and 2 - Q
    histories (plot_Qhx_2D), 3 - beta, 4 - number of states per Q (only
    when Qlen is given).

    Parameters
    ----------
    random_order, num_blocks, traject_dict, norm : passed to agent_response
    Qhx, boundaries, ept, phases : passed to plot_Qhx_2D
    agent_num : index of the agent (run) to plot
    all_beta : per-agent beta histories, indexed by agent_num; the beta row
        is left empty when this is empty or None
    Qlen : per-agent {q: history of number of states}; None or empty omits
        the fifth row
    trials_per_block : used for the x tick labels of the response row; when
        None, the module-level variable of that name (set when this file is
        run as a script) is used

    Returns the matplotlib figure.
    Raises NameError if trials_per_block is neither passed nor defined at
    module level.
    """
    from matplotlib.gridspec import GridSpec
    import matplotlib.pyplot as plt
    all_beta=[] if all_beta is None else all_beta
    Qlen=[] if Qlen is None else Qlen
    if trials_per_block is None:
        #backward compatibility: the script section defines this as a module-level variable
        try:
            trials_per_block=globals()['trials_per_block']
        except KeyError:
            raise NameError("trials_per_block must be passed to combined_bandit_Qhx_response or defined at module level") from None
    fig=plt.figure()
    numrows=5 if len(Qlen) else 4
    gs=GridSpec(numrows,2)
    #every row spans both grid columns
    ax=[fig.add_subplot(gs[row,:]) for row in range(numrows)]
    agent_response([agent_num],random_order,num_blocks,traject_dict,trials_per_block,fig,ax[0],norm=norm)
    fig=plot_Qhx_2D(Qhx,boundaries,ept,phases,fig=fig,ax=[ax[1],ax[2]])
    if len(all_beta):
        beta=all_beta[agent_num]
        ax[3].plot(np.arange(len(beta))/ept,beta)
    ax[3].set_ylabel(r'$\beta$1')
    if len(Qlen):
        for q,data in Qlen[agent_num].items():
            ax[4].plot(np.arange(len(data))/ept,data,label='Q'+str(q+1))
        ax[4].legend()
        ax[4].set_ylabel('States')
    ax[-1].set_xlabel('Trial')
    ##### add boundaries ####
    add_boundaries(numrows,Qhx,boundaries,ax)
    #add subplot labels
    label_fsize=14
    margin=0.03
    label_inc=1/len(fig.axes)
    for row in range(len(fig.axes)):
        y=(1-margin)-(row*label_inc) #subtract because 0 is at bottom
        fig.text(0.02,y,letters[row], fontsize=label_fsize)
    #fig.tight_layout()
    return fig
def staticQ(f,figs,nQ):
    """Load 'staticQ'+f+'.npz' and store its bar plot in figs[nQ]['static'].

    The file must contain 'allQ', 'labels', 'state_subset' and 'actions'.
    NOTE: allow_pickle=True unpickles the file contents; only use this on
    trusted files.

    Returns figs (modified in place).
    """
    with np.load('staticQ'+f+'.npz',allow_pickle=True) as data:
        Q=data['allQ'].item()
        labels=data['labels'].item()
        state_subset=list(data['state_subset'])
        actions=data['actions'].item()
    #keyword was misspelled (stabetate_subset), which raised TypeError
    figs[nQ]['static']=staticQ_barplot(Q,actions,title=str(nQ)+'Q',labels=labels,state_subset=state_subset)
    return figs
def add_labels(fig,numcols):
    """Label every axis of fig with column letter + row number (A1, B1, ..., A2, ...).

    The label is placed at the top of each axis, 0.06 (figure fraction) to
    the left of its left edge.  numcols is the number of subplot columns.
    """
    for axnum,axis in enumerate(fig.axes):
        box=axis.get_position()
        row=int(axnum/numcols)
        col=axnum%numcols
        panel=letters[col]+str(row+1)
        fig.text(box.x0-0.06,box.y1,panel, fontsize=fsize)
    return
def combined_discrim_Qhx(Qhx,boundaries,ept,phases,all_ideals,all_beta=None,Qlen=None,agent=-1):
    """Stack Q histories, beta and (optionally) the number of states per Q in one figure.

    Rows: one per entry of Qhx[first key] (drawn by plot_Qhx_2D), then the
    beta row, then a 'States' row when Qlen is non-empty.

    Parameters
    ----------
    Qhx, boundaries, ept, phases, all_ideals : passed to plot_Qhx_2D
    all_beta : either a list indexed by agent (a single trace is drawn), or
        a dict {key: 2-D array} whose mean over axis 0 is drawn per key
    Qlen : {key: per-agent histories} or {key1: {key2: per-agent histories}};
        None or empty omits the last row
    agent : index of the agent whose traces are drawn

    Returns the matplotlib figure.
    """
    from matplotlib.gridspec import GridSpec
    import matplotlib.pyplot as plt
    all_beta=[] if all_beta is None else all_beta
    Qlen=[] if Qlen is None else Qlen
    Qkeys=list(Qhx.keys())
    beta_axnum=len(Qhx[Qkeys[0]].keys())
    numrows=beta_axnum+2 if len(Qlen) else beta_axnum+1
    fig=plt.figure()
    gs=GridSpec(numrows,2)
    #every row spans both grid columns
    ax=[fig.add_subplot(gs[row,:]) for row in range(numrows)]
    fig=plot_Qhx_2D(Qhx,boundaries,ept,phases,all_ideals,fig=fig,ax=ax[0:beta_axnum])
    if isinstance(all_beta,list):
        beta=all_beta[agent]
        Xvals=np.arange(len(beta))/ept
        ax[beta_axnum].plot(Xvals,beta,color='green')
        #ax[beta_axnum].plot(Xvals,np.mean(beta,axis=0))
    else:
        for key in all_beta.keys():
            beta=all_beta[key]
            Xvals=np.arange(np.shape(beta)[1])/ept
            ax[beta_axnum].plot(Xvals,np.mean(beta,axis=0),label=key)
    ax[beta_axnum].set_ylabel(r'$\beta$1')
    #ticks every 0.5 over the range of the (last plotted) beta
    beta_ticks=np.arange(np.min(beta),np.max(beta)+.01,0.5)
    ax[beta_axnum].set_yticks(beta_ticks,[str(round(f,1)) for f in beta_ticks])
    #ax[beta_axnum].legend()
    if len(Qlen):
        qcolors=['magenta','purple']
        qnum=beta_axnum+1
        for key1,data in Qlen.items():
            if isinstance(data,dict):
                #key1 is a string here (it is concatenated into the label), so it
                #cannot index qcolors; color by position within the inner dict
                for qidx,(key2,d2) in enumerate(data.items()):
                    ax[qnum].plot(Xvals,d2[agent],label=key1+',Q'+str(key2),color=qcolors[qidx%len(qcolors)])
                ymax=np.max([d2[agent] for d2 in data.values()] )
            else:
                ax[qnum].plot(Xvals,data[agent],label='Q'+str(key1+1),color=qcolors[key1])
                ymax=np.max([d[agent] for d in Qlen.values()] )
        ax[qnum].legend()
        ax[qnum].set_ylabel('States',fontsize=fsizeSml+1)
        #three ticks from 0 up to ymax rounded up to an even number
        yticks = np.linspace(0,np.ceil(ymax)+np.ceil(ymax)%2,3,endpoint=True)
        ylabels=[str(round(f)) for f in yticks]
        ax[qnum].set_yticks(yticks,ylabels)
    ax[-1].set_xlabel('Trial',fontsize=fsizeSml+1)
    ##### add vertical line showing phase boundaries ####
    add_boundaries(numrows,Qhx,boundaries,ax)
    fig.tight_layout()
    return fig
def add_boundaries(numrows,Qhx,boundaries,ax):
    """Draw dashed grey phase-boundary lines on axes 3 .. numrows-1.

    The boundaries of the first state of the first Q in Qhx are used for
    every axis; each line spans the current y limits of its axis.
    """
    for axis in ax[3:numrows] if numrows>3 else []:
        first_q=next(iter(Qhx.keys()))
        first_state=next(iter(Qhx[first_q].keys()))
        ylow,yhigh=axis.get_ylim()
        for xval in boundaries[first_q][first_state][1:]:
            axis.vlines(xval,ylow,yhigh,linestyles='dashed',color='grey')
def new_xlim(figs,xlim):
    """Apply the x limits xlim to every axis of every figure in the dict figs."""
    for figure in figs.values():
        for axis in figure.axes:
            axis.set_xlim(xlim)
def load_data(Qfil,parname='params'):
    """Read Q histories and related data from an npz file written by the simulations.

    Entries saved as dictionaries are stored as 0-d object arrays and are
    unwrapped with .item(); entries saved as regular arrays are returned
    as arrays.  NOTE: allow_pickle=True unpickles the file contents; only
    use this on trusted files.

    Parameters
    ----------
    Qfil : name of the npz file
    parname : name of the entry holding the parameter dictionary

    Returns
    -------
    (allQhx, all_bounds, all_ideals, all_beta, all_lenQ, events_per_trial,
    trials_per_block, phases)
    all_beta and all_lenQ are {} when the file has no 'all_beta' entry.  When
    parname is absent, events_per_trial is read from the entry of that name
    and trials_per_block defaults to 10.
    """
    with np.load(Qfil,allow_pickle=True) as data:
        stored=data.files
        #.item() raises ValueError for arrays with more than one element
        try:
            allQhx=data['all_Qhx'].item()
        except ValueError:
            allQhx=data['all_Qhx']
        try:
            all_bounds=data['all_bounds'].item()
            all_ideals=data['all_ideals'].item()
        except ValueError:
            all_bounds=data['all_bounds']
            all_ideals=data['all_ideals']
        if 'all_beta' in stored: #files beginning on 3 June 2022
            try:
                all_beta=data['all_beta'].item()
                all_lenQ=data['all_lenQ'].item()
            except ValueError:
                all_beta=data['all_beta']
                all_lenQ=data['all_lenQ']
        else:
            all_beta={}
            all_lenQ={}
        if parname in stored:
            params=data[parname].item()
            events_per_trial=params['events_per_trial']
            trials_per_block=params['trials_per_block']
        else:
            events_per_trial=data['events_per_trial']
            trials_per_block=10 #
        phases=[[p for p in phs] for phs in data['phases']]
    return allQhx,all_bounds,all_ideals,all_beta,all_lenQ,events_per_trial,trials_per_block,phases
if __name__ == "__main__":
    # Script entry point: choose a task, load the saved simulation output
    # listed in the corresponding *Files module, and create the Qhx figures.
    import os
    task='Bandit' #'AIP'#'sequence'# 'Discrim' ##
    meanQ=False #make true in banditFiles.py if you know that meanQ has been added to file, e.g. if use_oldQ is False for bandit
    if task=='sequence':
        from sequenceFiles import fil,param_name
        figs={q:{} for q in fil.keys()}
        trial=-1
        #
        #alternative set of press histories to plot:
        #state_action_combos=[('Llever','---L'),('Llever','--LL'),('Llever', 'RRLL'),('Llever', 'RLLL'),('Rlever','--LL'),('Rlever','LLLL'),('Rlever','RLLL')] #,('Rlever','-LLL')
        state_action_combos=[('Llever','---L'),('Llever','--LL'),('Rlever','--LL'),('Rlever','LLLL')]#,('Rlever','-LLR'),('Rlever','LLLR')]#,('Rlever','LLRR')]
        actions_colors={'goL':'b','goR':'r','press':'k','goMag':'gray'}
        actions_lines={'goL':'solid','goR':'solid','press':'dashed','goMag':'dotted'}
        allQhx={}
        for nQ,f in fil.items():
            data=np.load(f+'.npz',allow_pickle=True)
            events_per_trial=data[param_name].item()['events_per_trial']
            trials_per_block=data[param_name].item()['trials_per_block']
            #.item() raises ValueError if Qhx was saved as a regular (non 0-d) array
            try:
                allQhx[nQ]=data['Qhx'].item()
            except ValueError:
                allQhx[nQ]=data['Qhx']
            #numcols+=data['par'].item()['numQ']
            #fig=plot_Qhx_sequence_1fig(allQhx[nQ],state_action_combos,actions_colors,events_per_trial)
            #figs=staticQ(f,figs,nQ)
        fig=plot_Qhx_sequence_1fig(allQhx,state_action_combos,actions_colors,events_per_trial,actions_lines)
    else:
        if task=='Bandit':
            from banditFiles import fil, meanQ
            trial=[4] #Bandit, [3,23,39] for June 3 2022 #[2,5,28] - for June 6, 2022
        elif task=='Discrim':
            from discrimFiles import fil
            trial=0
        elif task=='AIP':
            from discrimFiles import AIPfil as fil
            trial=0
        figs={q:{} for q in fil.keys()}
        for nQ,f in fil.items():
            #the Qhx file lives next to f and is named 'Qhx'+basename+'.npz'
            if len(os.path.dirname(f)):
                Qfil=os.path.dirname(f)+'/Qhx'+os.path.basename(f)+'.npz'
            else:
                Qfil='Qhx'+os.path.basename(f)+'.npz'
            all_Qhx,all_bounds,all_ideals,all_beta,all_lenQ,events_per_trial,trials_per_block,phases=load_data(Qfil)
            if task=='Bandit': #bandit task only, Qhx for all trials is saved for more recent simulations, need random_order
                data=np.load(Qfil,allow_pickle=True)
                norm=1/trials_per_block
                phases=data['phases']
                #if 'random_order' not in other.keys():
                #    random_order=[phases]
                if 'random_order' in data.keys():
                    random_order=data['random_order']
                else:
                    #order of probability not saved, except for last trial
                    random_order=[phases]
                    trial=[-1]
                if isinstance(all_Qhx,dict):
                    Qhx=all_Qhx
                else:
                    Qhx=all_Qhx[trial[0]]
                #the ('start','blip') state and the 'center' action are not plotted
                for nk in Qhx.keys():
                    if ('start','blip') in Qhx[nk].keys():
                        del Qhx[nk][('start','blip')]
                for nk in Qhx.keys():
                    for st in Qhx[nk].keys():
                        if 'center' in Qhx[nk][st]:
                            del Qhx[nk][st]['center']
                bounds=all_bounds[trial][0]
                ideals=all_ideals[trial]
                num_trials=int(len(all_lenQ[0][0])/events_per_trial)
                if 'num_blocks' in data.keys():
                    num_blocks=data['num_blocks'].item()
                else: #forgot to store num_blocks. calculate it
                    trials_per_phase=num_trials/len(phases)
                    num_blocks=int(trials_per_phase/trials_per_block) #10
                print(Qfil,num_blocks)
                data2=np.load(f+'.npz',allow_pickle=True)
                traject_dict=data2['traject_dict'].item()
                agent_response(trial,random_order,num_blocks,traject_dict,trials_per_block,norm=norm)
                if meanQ: #the mean over Q has been added, plot that
                    trial=-1 #mean Qhx
                    #fig=plot_Qhx_2D(Qhx,bounds,events_per_trial,phases)
                    lbl=sorted(random_order[0], key=lambda x: float(x.split(':')[0])/float(x.split(':')[1]))
                    fig=plot_Qhx_OpAL(Qhx,bounds,events_per_trial,'left',labels=lbl)
                    title='meanQ'
                    fig.suptitle(title)
                else:
                    fig_combined=combined_bandit_Qhx_response(random_order,num_blocks,traject_dict,Qhx,bounds,events_per_trial,phases,agent_num=trial[0],all_beta=all_beta,norm=norm)#,all_lenQ)
            else: #Discrim, or older Bandit files, random order not saved, all_Qhx is single dictionary,
                #states=['Pport,6kHz 0','Pport,10kHz 0'] #for discrim,reverse
                states=['Pport,6kHz 1'] #for acq,extinct
                if task=='AIP':
                    states=['Pport,10kHz 0'] #for acq,discrim with block
                if len(all_Qhx)>1:
                    Qhx=all_Qhx[trial]
                    bounds=all_bounds[trial]
                    ideals=all_ideals[trial]
                else:
                    Qhx=all_Qhx[0]
                    bounds=all_bounds[0]
                    ideals=all_ideals[0] #was assigned to 'ideal', leaving ideals undefined below
                for state in states:
                    if int(nQ)>1:
                        #re-arrange so that the state is the outer key and the Q number the inner key
                        Qhx_subset={state:{k:v[state] for k,v in Qhx.items()}}
                        bounds_subset={state:{k:v[state] for k,v in bounds.items()}}
                        ideals_subset={state:{k:v[state] for k,v in ideals.items()}}
                        keynum=int(state.split()[-1])
                        if len(all_beta):
                            bkey=list(all_beta.keys())[keynum]
                            beta=copy.deepcopy(all_beta[bkey])
                            if task=='AIP':
                                figs[nQ]['qhx'+state]=combined_discrim_Qhx(Qhx_subset,bounds_subset,events_per_trial,phases,ideals_subset,all_beta=beta)
                            else:
                                figs[nQ]['qhx'+state]=combined_discrim_Qhx(Qhx_subset,bounds_subset,events_per_trial,phases,ideals_subset,all_beta=beta,Qlen=all_lenQ[bkey])
                            print('figures for',state,f,',bkey =',bkey)
                        else:
                            figs[nQ]['qhx']=plot_Qhx_2D(Qhx_subset,bounds_subset,events_per_trial,phases,ideals_subset,title=state)
                print('figures for all:',f)
                if int(nQ)>1:
                    figs[nQ]['qhx_all']=plot_Qhx_2D(Qhx,bounds,events_per_trial,phases,ideals,title='all phases')
                else:
                    newQhx={1:Qhx,2:{}}
                    newbounds={1:bounds,2:{}}
                    newideals={1:ideals,2:{}}
                    figs[nQ]['qhx_all']=combined_discrim_Qhx(newQhx,newbounds,events_per_trial,phases,newideals,all_beta=all_beta,Qlen=all_lenQ)
                #figs=staticQ(f,figs,nQ)
                if 'Pport,10kHz 0' in states and task != 'AIP':
                    new_xlim(figs[nQ],[190,610])
    #figs['2']['qhxPport,6kHz 0'].tight_layout()
    #figs['2']['qhxPport,10kHz 0'].tight_layout()
    #axes=figs['2']['qhxPport,6kHz 1'].axes
    #axes=figs['2']['qhxPport,10kHz 0'].axes #for block case
    #for ax in axes:
    #    ax.set_xlim([-1,401])#601])