# -*- coding: utf-8 -*-
"""
Created on Wed Aug 5 13:39:14 2020
March 2021: changed the agent to use Euclidean distance instead of a Gaussian mixture,
and changed the learning rule for Q2 to use decreases in predicted reward.
@author: kblackw1
"""
import numpy as np
import completeT_env as tone_discrim
import agent_twoQtwoSsplit as QL
from RL_TD2Q import RL
import RL_utils as rlu
import copy
from TD2Q_Qhx_graphs import Qhx_multiphase
def select_phases(block_DA_dip,PREE,savings,extinct,context,action_items):
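    #Select the learning phases, figure sets, trajectory items and context cues for the chosen protocol.
    #Only one of block_DA_dip, PREE, savings or extinct is expected to be active at a time;
    #otherwise the default protocol (ABA renewal plus discrimination and reversal) is used.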
acq_cue=context[0]#use [] for no cues
ext_cue=context[1] #1 for extinction in separate context
dis_cue=context[0]
ren_cue=acq_cue
acq2_cue=[]
if block_DA_dip:
learn_phases=['acquire','extinc','discrim']
figure_sets=[['extinc','discrim']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd'],'discrim':action_items[1:]+['rwd']}
ext_cue=context[0] #0 for extinction in same context
elif PREE: #evaluate how reward prob during acquisition affects rate of extinction
learn_phases=['acquire','extinc']
figure_sets=[['acquire','extinc']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd']}
ext_cue=context[0] #0 for extinction in same context
elif savings=='after extinction': #evaluate if learn faster the second time
learn_phases=['acquire','extinc','acquire2']
figure_sets=[['acquire','extinc','acquire2']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]],'acquire2':[action_items[1]]+['rwd']}
ext_cue=context[0] #0 for extinction in same context
acq2_cue=context[0]
elif savings=='in new context':#evaluate if learn faster in new context
learn_phases=['acquire','acquire2']
figure_sets=[['acquire','acquire2']]
traject_items={'acquire':[action_items[1]]+['rwd'],'acquire2':[action_items[1]]+['rwd']}
acq2_cue=context[1] #acquisition in different context
elif extinct=='AAB':
learn_phases=['acquire','extinc','renew']
figure_sets=[['acquire','extinc','renew']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd'],'renew':[action_items[1]]+['rwd']}
ext_cue=context[0]
ren_cue=context[1]
elif extinct=='ABB':
learn_phases=['acquire','extinc','renew']
figure_sets=[['acquire','extinc','renew']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd'],'renew':[action_items[1]]+['rwd']}
ext_cue=context[1]
ren_cue=context[1]
elif extinct=='ABA':
learn_phases=['acquire','extinc','renew']
figure_sets=[['acquire','extinc','renew']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd'],'renew':[action_items[1]]+['rwd']}
ext_cue=context[1]
ren_cue=context[0]
else: #this is ABA, but with added discrim and reverse
learn_phases=['acquire','extinc','renew','discrim','reverse'] #
figure_sets=[['discrim','reverse'],['acquire','extinc','renew']]
traject_items={'acquire':[action_items[1]]+['rwd'],'extinc':[action_items[1]]+['rwd'],'renew':[action_items[1]]+['rwd'],
'discrim':action_items[1:]+['rwd'],'reverse':action_items[1:]+['rwd']}
    return learn_phases,figure_sets,traject_items,acq_cue,ext_cue,ren_cue,dis_cue,acq2_cue
def respond_decay(phases,state_act, output_data):
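    #For each phase, estimate the half-maximal response level of each run's trajectory and the trial block
    #at which responding first crosses that level; returns dicts of half levels and half-crossing blocks, keyed by phase.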
half={phs:[] for phs in phases};half_block={phs:[] for phs in phases}
for phs in phases:
all_data=output_data[phs][state_act]
mean_traject=np.mean(all_data,axis=0)
for data in all_data:
half[phs].append((np.max(data)+np.min(data))/2)
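            #crossing direction is taken from the mean trajectory: if responding increases over blocks,
            #find the first block above the half level; if it decreases (e.g. extinction), find the first block below it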
if mean_traject.argmax()>mean_traject.argmin():
h=np.where(data>half[phs][-1])
else:
h=np.where(data<half[phs][-1])
if len(h[0]):
half_block[phs].append(np.min(h))
else:
half_block[phs].append(np.nan)
print(phs,'half responding=',np.nanmean(half[phs]),'block=',np.nanmean(half_block[phs]))
return half,half_block
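#usage sketch (assumes traject_dict[phase][state_action] holds one per-block response array per run, as in the main block below):
# half,half_block = respond_decay(['acquire','extinc'], (('Pport','6kHz'),'left'), traject_dict)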
####################################################################################################################
if __name__ == "__main__":
from DiscriminationTaskParam2 import params,states,act
events_per_trial=params['events_per_trial'] #this is task specific
trials=200 #Iino: 180 trials for acq, then 160 trials for discrim; or discrim from the start using 60 trials of each per day (120 trials) * 3 days
numevents= events_per_trial*trials
runs=10 #10 for paper
#control output
printR=False #print environment Reward matrix
Info=False#print information for debugging
plot_hist=0#1: plot final Q, 2: plot the time since last reward
plot_Qhx=2 #2D or 3D plots of the dynamics of Q
save_reward_array=True
#additional cues that are part of the state for the agent, but not environment
#this means that they do not influence the state transition matrix
context=[[0],[1]] #set of possible context cues
    #RewHx3 means that the agent estimates the reward history as part of the context
    #the extinction context must be more similar to the acquisition context than the tone/location cues are to each other
    #If reward and time since reward are added to the cues, they need to be divided by ~100
    noise=0.15 #make noise small enough, or state_thresh small enough, to minimize creation of new states during acquisition
#action_items is a subset of the state-action combinations that an agent can perform
#count number of responses to the following state-action combos:
action_items=[(('start','blip'),'center'),(('Pport','6kHz'),'left'),(('Pport','6kHz'),'right'),(('Pport','10kHz'),'left'),(('Pport','10kHz'),'right')]
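    #e.g. (('Pport','6kHz'),'left') counts 'left' responses in the state ('Pport','6kHz')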
#action_items=['center','left','right']
    block_DA_dip=False #options: 'AIP' or 'no_dip'; AIP blocks homosynaptic LTP, no_dip blocks heterosynaptic LTD; False = control
    PREE=0 #reward probability during acquisition for the partial-reinforcement extinction effect; 0 disables this protocol
    savings='none' #'none', 'in new context' or 'after extinction'; use 'none' to simulate discrim and reverse
    extinct='none' #AAB: acquire and extinguish in context A, test renewal in B; ABB: acquire in A, extinguish and re-test renewal in B; ABA: acquire in A, extinguish in B, test renewal in A
#Specify which learning protocols/phases to implement
    learn_phases,figure_sets,traject_items,acq_cue,ext_cue,ren_cue,dis_cue,acq2_cue=select_phases(block_DA_dip,PREE,savings,extinct,context,action_items)
#state_sets and phases used for saving Qhx, beta and Qlen
if block_DA_dip:
state_sets=[[('Pport','6kHz'),('Pport','10kHz')]]
phases=[['acquire','discrim']]
elif extinct.startswith('A'):
state_sets=[[('Pport','6kHz')]]
        phases=[learn_phases]
else:
state_sets=[[('Pport','6kHz'),('Pport','10kHz')],[('Pport','6kHz')]]
phases=[['acquire','discrim','reverse'],['acquire','extinc','renew']]
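    #each entry of state_sets is paired with the corresponding entry of phases when saving and plotting Q value history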
    trial_subset=int(0.1*numevents) #count actions and mean reward over the first and last trial_subset events
#update some parameters of the agent
    params['decision_rule']=None #options: 'delta', 'mult', 'combo', 'sumQ2'; None means use the direct negative of the D1 rule
params['Q2other']=0.0
params['numQ']=2
params['wt_learning']=False
params['distance']='Euclidean'
params['beta_min']=0.5
params['beta']=1.5
params['beta_GPi']=10
params['gamma']=0.82
params['state_units']['context']=False
params['initQ']=-1 #-1 means do state splitting. If initQ=0, 1 or 10, it means initialize Q to that value and don't split
    params['D2_rule']= None #options: 'Ndelta', 'Bogacz', 'Opal' or None. Ndelta: calculate delta from the N matrix to update N; Opal: OpAL update without critic
if params['distance']=='Euclidean':
#state_thresh={'Q1':[0.875,0],'Q2':[0.875,1.0]} #For Euclidean distance
#alpha={'Q1':[0.2,0],'Q2':[0.2,0.1]} #For Euclidean distance
state_thresh={'Q1':[1.0,0],'Q2':[0.75,0.625]} #For normalized Euclidean distance
alpha={'Q1':[0.3,0],'Q2':[0.2,0.1]} #For normalized Euclidean distance
else:
        state_thresh={'Q1':[0.22, 0.22],'Q2':[0.20, 0.22]} #For Gaussian mixture
        alpha={'Q1':[0.4,0],'Q2':[0.4,0.2]} #For Gaussian mixture; [0.62,0.19] for beta=0.6, 1Q or 2Q
params['state_thresh']=state_thresh['Q'+str(params['numQ'])] #for euclidean distance, no noise
#lower means more states for Euclidean distance rule
params['alpha']=alpha['Q'+str(params['numQ'])] #
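    #state_thresh and alpha are selected by the number of Q matrices, using the key 'Q'+str(params['numQ'])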
traject_title='num Q: '+str(params['numQ'])+' rule:'+str( params['decision_rule'])+' forget:'+str(params['forgetting'])
################# For OpAL ################
params['use_Opal']=False
if params['use_Opal']:
params['numQ']=2
params['Q2other']=0
params['decision_rule']='delta'
params['alpha']=[0.1,0.1]#[0.2,0.2]#
params['beta_min']=1
params['beta']=1
params['gamma']=0.1 #called alpha_c in OpAL
params['initQ']=1 #do not split states, initialize Q values to 1
params['state_thresh']=[0.75,0.625]
params['D2_rule']='Opal'
noise=0.05
######################################
from DiscriminationTaskParam2 import Racq,Tacq,env_params, rwd
epochs=['Beg','End']
if PREE:
traject_title+=' PREE:'+str(PREE)
from DiscriminationTaskParam2 import loc, tone
        Racq[(loc['Pport'],tone['6kHz'])][act['left']]=[(rwd['reward'],PREE),(rwd['base'],1-PREE)] #lick in left port rewarded with probability PREE
keys=rlu.construct_key(action_items +['rwd'],epochs)
resultslist={phs:{k+'_'+ep:[] for k in keys.values() for ep in epochs} for phs in learn_phases}
traject_dict={phs:{ta:[] for ta in traject_items[phs]} for phs in learn_phases}
#count number of responses to the following actions:
results={phs:{a:{'Beg':[],'End':[]} for a in action_items+['rwd']} for phs in learn_phases}
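    #results[phase][state-action]['Beg'/'End'] collects, for each run, response counts over the first/last trial_subset events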
Qhx_actions=['left','right'] #actions of interest for Qhx
### to plot performance vs trial block
trials_per_block=10
events_per_block=trials_per_block* events_per_trial
num_blocks=int((numevents+1)/events_per_block)
params['events_per_block']=events_per_block
params['trials_per_block']=trials_per_block
params['trial_subset']=trial_subset
resultslist['params']={p:[] for p in params.keys()}
all_beta={'_'.join(k):[] for k in phases}
all_lenQ={k:{q:[] for q in range(params['numQ'])} for k in all_beta.keys()}
all_Qhx=[];all_bounds=[];all_ideals=[]
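    #per-run storage of Q value history (all_Qhx), phase boundaries (all_bounds) and ideal states (all_ideals) for the Qhx plots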
for r in range(runs):
rl={}
if 'acquire' in learn_phases:
print('&&&&&&&&&&&&&&&&&&&& acquire for run',r, states,'\n R:',Racq.keys(),'\n T:',Tacq.keys(),' cues:',acq_cue)
######### acquisition trials, context A, only 6 Khz + L turn allowed #########
rl['acquire'] = RL(tone_discrim.completeT, QL.QL, states,act,Racq,Tacq,params,env_params,printR=printR)
results,acqQ=rlu.run_sims(rl['acquire'],'acquire',numevents,trial_subset,action_items,noise,Info,acq_cue,r,results,phist=plot_hist,block_DA=block_DA_dip)
traject_dict=rl['acquire'].trajectory(traject_dict, traject_items,events_per_block)
if 'extinc' in learn_phases:
if runs==1:
print('&&&&&&&&&&&&&&&&&&&& extinction for run',r, states,' cues:',ext_cue)
from DiscriminationTaskParam2 import Rext,Tacq
rl['extinc'] = RL(tone_discrim.completeT, QL.QL, states,act,Rext,Tacq,params,env_params,printR=printR,oldQ=acqQ)
results,extQ=rlu.run_sims(rl['extinc'],'extinc',numevents,trial_subset,action_items,noise,Info,ext_cue,r,results,phist=plot_hist,block_DA=block_DA_dip)
traject_dict=rl['extinc'].trajectory(traject_dict, traject_items,events_per_block)
#### renewal - blocking D2 or Da Dip not tested
if 'renew' in learn_phases:
if runs==1:
print('&&&&&&&&&&&&&&&&&&&& renewal for run',r, states,' cues:',acq_cue)
rl['renew'] = RL(tone_discrim.completeT, QL.QL, states,act,Rext,Tacq,params,env_params,printR=printR,oldQ=extQ)
results,renQ=rlu.run_sims(rl['renew'],'renew',numevents,trial_subset,action_items,noise,Info,ren_cue,r,results)
traject_dict=rl['renew'].trajectory(traject_dict, traject_items,events_per_block)
####### discrimination trials, add in 10Khz tone, + needed reward and state transitions
if 'discrim' in learn_phases:
#use last context in the list
from DiscriminationTaskParam2 import Rdis,Tdis
if runs==1:
print('&&&&&&&&&&&&&&&&&&&& discrimination for run',r, states,'\n R:',Rdis.keys(),'\n T:',Tdis.keys(),' cues:',dis_cue)
if 'acquire' in learn_phases: #expand previous covariance matrix and Q with new states
rl['discrim'] = RL(tone_discrim.completeT, QL.QL, states,act,Rdis,Tdis, params,env_params,oldQ=acqQ)
acq_first=True
else:
rl['discrim'] = RL(tone_discrim.completeT, QL.QL, states,act,Rdis,Tdis, params,env_params)
acq_first=False
results,disQ=rlu.run_sims(rl['discrim'],'discrim',int(numevents),trial_subset,action_items,noise,Info,dis_cue,r,results,phist=plot_hist,block_DA=block_DA_dip)
traject_dict=rl['discrim'].trajectory(traject_dict, traject_items,events_per_block)
#rl['discrim'].set_of_plots('discrim, acquire 1st:'+str(acq_first),noise,t2,hist=plot_hist)
if Info:
print('discrim, acquire 1st:'+str(acq_first)+', mean reward=',np.round(np.mean(rl['discrim'].results['reward'][-trial_subset:]),2))
if 'reverse' in learn_phases:
from DiscriminationTaskParam2 import Rrev,Trev
if runs==1:
print('&&&&&&&&&&&&&&&&&&&& reversal',states,'\n R:',Rrev.keys(),'\n T:',Trev.keys(),' cues:',dis_cue)
rl['reverse']=RL(tone_discrim.completeT, QL.QL, states,act,Rrev, Trev,params,env_params,oldQ=disQ)
results,revQ=rlu.run_sims(rl['reverse'],'reverse',int(numevents),trial_subset,action_items,noise,Info,dis_cue,r,results,phist=plot_hist)
traject_dict=rl['reverse'].trajectory(traject_dict, traject_items,events_per_block)
if savings == 'in new context' or savings == 'after extinction':
if savings == 'in new context' :
rl['acquire2'] = RL(tone_discrim.completeT, QL.QL, states,act,Racq,Tacq,params,env_params,printR=printR,oldQ=acqQ)
if savings == 'after extinction':
rl['acquire2'] = RL(tone_discrim.completeT, QL.QL, states,act,Racq,Tacq,params,env_params,printR=printR,oldQ=extQ)
            results,acq2Q=rlu.run_sims(rl['acquire2'],'acquire2',numevents,trial_subset,action_items,noise,Info,acq2_cue,r,results,phist=plot_hist)
traject_dict=rl['acquire2'].trajectory(traject_dict, traject_items,events_per_block)
print ('>>>>>>>>>>>>>>>>>>>> savings', savings,'acq2 cue',acq2_cue)
all_beta,all_lenQ=rlu.beta_lenQ(rl,phases,all_beta,all_lenQ,params['numQ'])
agents=[[rl[phs] for phs in phaseset] for phaseset in phases]
one_Qhx={q:{} for q in range(params['numQ'])};one_ideals={q:{} for q in range(params['numQ'])};one_bounds={q:{} for q in range(params['numQ'])}
for ij,(state_subset,phase_set,agent_set) in enumerate(zip(state_sets,phases,agents)):
Qhx, boundaries,ideal_states=Qhx_multiphase(state_subset,Qhx_actions,agent_set,params['numQ'])
for q in Qhx.keys():
for st in Qhx[q].keys():
newstate=','.join(list(st))
one_Qhx[q][newstate+' '+str(ij)]=copy.deepcopy(Qhx[q][st])
one_ideals[q][newstate+' '+str(ij)]=copy.deepcopy(ideal_states[q][st])
one_bounds[q][newstate+' '+str(ij)]=copy.deepcopy(boundaries[q][st])
all_Qhx.append(one_Qhx)
all_ideals.append(one_ideals)
all_bounds.append(one_bounds)
    ######### Average over runs (mean and standard error)
all_ta=[]; output_data={}
for phs in traject_dict.keys():
output_data[phs]={}
for ta in traject_dict[phs].keys():
all_ta.append(ta)
output_data[phs][ta]={'mean':np.mean(traject_dict[phs][ta],axis=0),'sterr':np.std(traject_dict[phs][ta],axis=0)/np.sqrt(runs-1)}
all_ta=list(set(all_ta))
#move reward to front
all_ta.insert(0, all_ta.pop(all_ta.index('rwd'))) #
for p in resultslist['params'].keys(): #
resultslist['params'][p].append(params[p])
resultslist=rlu.save_results(results,keys,resultslist)
interesting_combos={'acquire':['Pport_6kHz_left_End','rwd__End'],
'acquire2':['Pport_6kHz_left_End','rwd__End'],
'extinc':['Pport_6kHz_left_Beg','Pport_6kHz_left_End'],
'renew':['Pport_6kHz_left_Beg','Pport_6kHz_left_End'],
'discrim':['Pport_6kHz_left_Beg','Pport_6kHz_left_End','Pport_10kHz_left_Beg','Pport_10kHz_right_End','rwd__End'],
'reverse':['Pport_6kHz_left_End','Pport_10kHz_right_End','Pport_6kHz_right_End','Pport_10kHz_left_End','rwd__End']}
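    #keys combine the state-action (or 'rwd') with the epoch ('Beg' or 'End'), e.g. 'Pport_6kHz_left_End'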
print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
print(' Using',params['numQ'], 'Q, alpha=',params['alpha'],'thresh',params['state_thresh'], 'beta=',params['beta'],'runs',runs,'of total events',numevents)
print(' apply learning_weights:',[k+':'+str(params[k]) for k in params.keys() if k.startswith('wt')])
print(' D2_rule=',params['D2_rule'],'decision rule=',params['decision_rule'],'split=',params['initQ'],'critic=',params['use_Opal'])
print('counts from ',trial_subset,' events (',events_per_trial,' events per trial) BEGIN END std over ',runs,'runs')
for phase in results.keys():
for sa,counts in results[phase].items():
print(phase.rjust(12), sa,':::',np.round(np.mean(counts['Beg']),2),'+/-',np.round(np.std(counts['Beg']),2),
',', np.round(np.mean(counts['End']),2),'+/-',np.round(np.std(counts['End']),2))
for sa in interesting_combos[phase]:
if sa in resultslist[phase]:
print( ' ',sa,':::',[round(val,3) for lst in resultslist[phase][sa] for val in lst] )
trajectfig=rlu.plot_trajectory(output_data,traject_title,figure_sets)
print('******* winner count, 1st run *****')
for phase,ag in rl.items():
print(phase,[(wck,np.sum(wcvals)) for wck,wcvals in ag.agent.winner_count.items()])
if plot_Qhx==2:
########## Plot Q values over time
### A. Qhx_multiphase plots only select state/actions, and concatenates multiple learning phases
from TD2Q_Qhx_graphs import plot_Qhx_2D
########################## NEXT:
for ij,(state_subset,phase_set,agent_set) in enumerate(zip(state_sets,phases,agents)):
fig=plot_Qhx_2D(Qhx,boundaries,events_per_trial,phase_set,ideal_states)
fig=plot_Qhx_2D(all_Qhx[0],all_bounds[0],events_per_trial,phases,all_ideals[0]) #fig 3
elif plot_Qhx==3:
### B. 3D plot Q history for selected actions, for all states, one graph per phase
for phase in ['discrim','reverse']:
rl[phase].agent.plot_Qdynamics(['center','left','right'],'surf',title=rl[phase].name)
if save_reward_array:
if block_DA_dip:
fname='DiscrimD2'+block_DA_dip
else:
fname='Discrim'
import datetime
dt=datetime.datetime.today()
date=str(dt).split()[0]
key_params=['numQ','Q2other','beta_GPi','decision_rule','beta_min','beta','gamma','use_Opal','D2_rule']
fname_params=key_params+['initQ']
fname=fname+date+'_'.join([k+str(params[k]) for k in fname_params])+'_rwd'+str(rwd['reward'])
np.savez(fname,par=params,results=resultslist,traject=output_data)
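        #the saved file contains the parameter dictionary, the per-run response counts and the run-averaged trajectories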
if plot_Qhx==2:
np.savez('Qhx'+fname,all_Qhx=all_Qhx,all_bounds=all_bounds,events_per_trial=events_per_trial,phases=phases,all_ideals=all_ideals,all_beta=all_beta,all_lenQ=all_lenQ)
# Q value plot of subset of states
select_states=['success','6kHz','10kHz']
numchars=3
state_subset=[ss[0:numchars] for ss in select_states]
for i in range(params['numQ']):
rl['discrim'].agent.visual(rl['discrim'].agent.Q[i],labels=rl['discrim'].state_to_words(i,noise,chars=numchars),title='dis Q'+str(i),state_subset=state_subset)
if save_reward_array:
numchars=8
allQ={i:rl['discrim'].agent.Q[i] for i in range(params['numQ'])}
all_labels={i:rl['discrim'].state_to_words(i,noise,chars=numchars) for i in range(params['numQ'])}
actions=rl['discrim'].agent.actions
np.savez('staticQ'+fname,allQ=allQ,labels=all_labels,actions=actions,state_subset=[ss[0:numchars] for ss in select_states])
if not block_DA_dip:
print('beta_min',params['beta_min'],'beta_max',params['beta'],'beta_GPi',params['beta_GPi'],'rwd',\
np.mean(results['acquire']['rwd']['End'])+np.mean(results['discrim']['rwd']['End'])+np.mean(results['reverse']['rwd']['End']) )
############## Evaluate decay of responding #####################
state_act=(('Pport','6kHz'),'left') #
half,half_block=respond_decay(['acquire','extinc','renew'],state_act,traject_dict)
#identify which panel for plotting fit
ax=[a for a in trajectfig.axes if 'Ppor' in a.get_ylabel() and '6kHz' in a.get_ylabel()]
symbols={'acquire':'o','extinc':'d','renew':'X'}
for jj,(key,x) in enumerate(half_block.items()):
ax[0].scatter(np.mean(x),np.mean(half[key]),marker=symbols[key],color=ax[0].get_lines()[3*jj].get_c())