# -*- coding: utf-8 -*-
"""
Created on Wed Aug 5 13:39:14 2020
@author: kblackw1
"""
import numpy as np
import sequence_env as task_env
import agent_twoQtwoSsplit as QL
import RL_utils as rlu
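# This script trains Q-learning agents (agent_twoQtwoSsplit) on the lever-press sequence task
# defined by sequence_env and SequenceTaskParam.  The RL class couples an environment with an
# agent, runs episodes, and stores the state, action and reward histories that the __main__
# block below summarizes, plots, and optionally saves.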
##########################################################
class RL:
"""Reinforcement learning by interaction of Environment and Agent"""
def __init__(self, environment, agent, states,actions,R,T,Aparams,Eparams,oldQ={},printR=False):
"""Create the environment and the agent"""
self.env = environment(states,actions,R,T,Eparams,printR)
#self.agent = agent(self.env.T.keys(), self.env.actions,Aparams,oldQ)
self.agent = agent(self.env.actions,Aparams,oldQ)
self.vis = True # visualization
self.name=None
self.results={'state': [], 'reward':[],'action':[]}
def episode(self, tmax=50,noise=0,cues=[],info=False):
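        #run one episode: start the environment and agent, then interact for tmax steps,
        #appending each state, action and reward to self.results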
state = self.env.start() #state tuple, (0,0) to start
reward=0
action = self.agent.start(state,cues)
self.append_results(action,reward)
# Repeat interaction
if info:
print('start episode, from Q=', self.agent.Q,'\nresults',self.results)
for t in range(1, tmax+1):
reward, state = self.env.step(action,prn_info=info) #determine new state and reward from env
#print('t=',t,'state',state,'=',self.env.state_from_number(state),'reward=',reward)
action = self.agent.step(reward, state, noise,cues=cues,prn_info=info) #determine next action from current state and reward
self.append_results(action,reward)
return
def append_results(self,action,reward):
self.results['state'].append(self.env.state)
self.results['reward'].append(reward)
self.results['action'].append(action)
def visual(self,title=None):
"""Visualize state,action,reward of an eipsode"""
import matplotlib.pyplot as plt
plt.ion()
fig,ax=plt.subplots(nrows=3,ncols=1,sharex=True)
if title is not None:
fig.suptitle(title)
xvals=np.arange(len(self.results['reward']))
for i,key in enumerate(['reward','action']):
ax[i].plot(xvals,self.results[key], label=key)
ax[i].set_ylabel(key)
ax[i].legend()
ax[-1].set_xlabel('time')
offset=0.1
for i,((st,lbl),symbol) in enumerate(zip(self.env.state_types.items(),['k.','bx'])):
yval=[s_tup[st]+i*offset for s_tup in self.results['state']]
ax[2].plot(xvals,yval,marker=symbol[-1],color=symbol[0],label=lbl,linestyle='None')
ax[2].set_ylabel('state')
ax[2].legend()
def state_to_words(self,nn,noise,hx_len):
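        #convert the agent's ideal_states for Q matrix nn into readable labels: components that
        #round to a known environment state value become the first hx_len characters of that
        #state's name; other components are shown as '--'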
env_states=[];env_st_num=[]
env_bits=len(self.env.states.keys())
for st in self.agent.ideal_states[nn].values():
env_st_num.append([np.round(si,1) for si in st])
env_states.append([])
for si in st:
env_states[-1].append('--')
for ii,st in enumerate(env_st_num):
for jj,si in enumerate(st[0:env_bits]):
key=list(self.env.states.keys())[jj]
if np.abs(int(si))-np.abs(si)<=noise and int(np.round(si)) in self.env.states[key].values():
env_states[ii][jj]=list(self.env.states[key].keys())[list(self.env.states[key].values()).index(int(np.round(si)))][0:hx_len]
for jj,si in enumerate(st[env_bits:]):
env_states[ii][jj+env_bits]=str(si)
return env_states
def set_of_plots(self,numQ,noise,hx_len,title2='',hist=False):
import matplotlib.pyplot as plt
plt.ion()
        self.visual(numQ+'Q'+title2) #differs from RL_TD2Q in parameter numQ vs learn_phase, and in using hx_len
for ii in range(len(self.agent.Q)):
self.agent.visual(self.agent.Q[ii],labels=self.state_to_words(ii,noise,hx_len),
title=numQ+'Q, Q'+str(ii+1))
if hist:
self.agent.plot_learn_history(title=numQ+'Q, Q'+str(ii+1))
def get_statenum(self,state): ####### Not part of RL_TD2Q
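        #returns (location number, history number, matching histories) for a (loc,hx) state tuple;
        #a '*' wildcard (or other unknown entry) gives -1, and matching_state1 lists every
        #press-history string consistent with the wildcard pattern in state[1]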
if state[0] in self.env.states['loc']:
state0num=self.env.states['loc'][state[0]]
else:
            state0num=-1 #this occurs if a wildcard is specified for the location
        #a wildcard can be used to specify the location or a character in press_hx
#If wildcard is used, need to find all possible matching states
matching_state1=[]
if state[1] in self.env.states['hx']:
state1num=self.env.states['hx'][state[1]]
else:
state1num=-1
#star_index=state[1].find('*')#will only find first occurrence
star_index=[i for i, letter in enumerate(state[1]) if letter =='*']
#list of possible matching states to *LL
for st in self.env.states['hx']:
if np.all([state[1][i]==st[i] for i in range(len(st)) if i not in star_index]):
matching_state1.append(st)
return state0num,state1num,matching_state1
def count_actions(self,allresults,sa_combo,event_subset,accum_type='mean'): ####### Not part of RL_TD2Q
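        #count how often each (state,action) combination occurs, and accumulate reward, over the
        #first and last event_subset events; values are normalized per trial and appended to allresults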
        #2021 Jan 4: multiply reward by events_per_trial to get mean reward per trial
trial_subset=event_subset/self.agent.events_per_trial
for sa in sa_combo:
state=sa[0]
anum=self.env.actions[sa[1]]
state0num,state1num,matching_state1=self.get_statenum(state)
#count how many times that state=state and action=action
#print('sa',sa,'matching states',matching_state1)
timeframe={'Beg':range(event_subset),'End':range(-event_subset,0)}
actions=np.array(self.results['action'])
for tf,trials in timeframe.items():
sa_count=0
action_indices=np.where(actions[trials]==anum)[0]+trials[0] #indices with correct actions
#for tr in trials:
for tr in action_indices:
#if self.results['action'][tr]==anum:
#count number of times that agent state is state0 and state1
if (state[0]=='*' or self.results['state'][tr][0]==state0num) and \
(self.results['state'][tr][1]==state1num or self.env.state_from_number(self.results['state'][tr])[1] in matching_state1):
sa_count+=1
allresults[sa][tf].append(sa_count/trial_subset) #events per trial, fraction of responses in specified number of events
if accum_type=='count':
max_rwd=np.max(self.results['reward'])
            allresults['rwd']['Beg'].append(self.results['reward'][0:event_subset].count(max_rwd)/trial_subset) #number of rewards per trial
            allresults['rwd']['End'].append(self.results['reward'][-event_subset:].count(max_rwd)/trial_subset) #maximum = 1
else:
allresults['rwd']['Beg'].append(np.mean(self.results['reward'][0:event_subset])*self.agent.events_per_trial) #mean reward per trial
allresults['rwd']['End'].append(np.mean(self.results['reward'][-event_subset:])*self.agent.events_per_trial)
return allresults
def trajectory(self,traject,sa_combo, num_blocks,events_per_block,numQphs,accum_type='mean'): #differs from RL_TD2Q
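        #append per-block counts of each (state,action) combination, and per-block reward
        #(count of maximal rewards or mean reward), to traject[numQphs]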
for sa in sa_combo:
if sa=='rwd':
if accum_type=='count':
max_rwd=np.max(self.results['reward'])
traject[numQphs]['rwd'].append([self.results['reward'][block*events_per_block:(block+1)*events_per_block].count(max_rwd) for block in range(num_blocks)]) #rewards per block
else:
traject[numQphs]['rwd'].append([self.agent.events_per_trial*np.mean(self.results['reward'][block*events_per_block:(block+1)*events_per_block]) for block in range(num_blocks)])
else:
anum=self.env.actions[sa[1]]
state=sa[0]
state0num,state1num,matching_state1=self.get_statenum(state)
block_count=[]
for block in range(num_blocks):
sa_count=0
for tr in range(block*events_per_block,(block+1)*events_per_block):
if self.results['action'][tr]==anum:
#count number of times that agent state is state0 and state1
if (state[0]=='*' or self.results['state'][tr][0]==state0num) and \
(self.results['state'][tr][1]==state1num or self.env.state_from_number(self.results['state'][tr])[1] in matching_state1):
sa_count+=1
block_count.append(sa_count)
traject[numQphs][sa].append(block_count)
return traject
def accum_Qhx(states,actions,rl,numQ,Qhx=None):
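    #accumulate the time course of Q values (rl.agent.Qhx) for the specified states and actions,
    #organized as Qhx[state][q][state_label][action]; pass the returned Qhx back in on later runs
    #to accumulate across runs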
#find the state number corresponding to states for each learning phase
state_nums={state: {q: [] for q in range(numQ)} for state in states}
for q in range(numQ):
int_ideal_states=[(int(v[0]),int(v[1])) for v in rl.agent.ideal_states[q].values()]
int_ideal_state1=[int(v[1]) for v in rl.agent.ideal_states[q].values()]
for state in states:
st0,st1,matching_states=rl.get_statenum(state)
if len(matching_states)==0 and st1>-1:
matching_states=[state[1]]
for ms in matching_states:
hx_num=rl.env.states['hx'][ms]
if st0>-1:
if (st0,hx_num) in int_ideal_states:
qindex=int_ideal_states.index((st0,hx_num))
state_nums[state][q].append((state[0]+','+ms,qindex))
#print(state,',match:', ms,',num',st0,hx_num,'in Q:',qindex)
#else:
#print(state,',match:', ms,',num',st0,hx_num,'Not found')
else:
qindices=np.where(np.array(int_ideal_state1)==hx_num)[0]
#print(state,',match:', ms,',num',st0,hx_num,'in Q:',qindices)
for qnum in qindices:
state_pair=list(int_ideal_states[qnum])
state_words=rl.env.state_from_number(state_pair)
state_nums[state][q].append((state_words[0][0:3]+','+ms,qnum))
if not Qhx:
Qhx={st:{q:{ph[0]:{ac:[] for ac in actions} for ph in state_nums[st][q]} for q in state_nums[st].keys()} for st in state_nums.keys()}
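        #Qhx[state][q][state_label][action] holds one Q-value time series per run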
for st in state_nums.keys():
for qv in state_nums[st].keys():
for (ph,qindex) in state_nums[st][qv]:
if ph in Qhx[st][qv].keys(): #not all states are visited each run
for ac in actions.keys():
Qhx[st][qv][ph][ac].append(rl.agent.Qhx[qv][:,qindex,rl.agent.actions[ac]])
else:
Qhx[st][qv][ph]={ac:[] for ac in actions}
for ac in actions.keys():
Qhx[st][qv][ph][ac].append(rl.agent.Qhx[qv][:,qindex,rl.agent.actions[ac]])
    #also return state_nums, which may differ for each run
return Qhx,state_nums
##########################################################
if __name__ == "__main__":
from SequenceTaskParam import Hx_len,rwd
from SequenceTaskParam import params,states,act
from SequenceTaskParam import Tloc,R,env_params
numtrials=600 # 450 #
runs=15
    #If adding reward and time since reward to the cues, divide them by ~100
noise=0.01 #make noise small enough or state_thresh small enough to minimize new states in acquisition
#control output
printR=False #print environment Reward matrix
Info=False #print information for debugging
plot_hist=0#1: plot Q, 2: plot the time since last reward
other_plots=True
save_data=True #write output data in npz file
Qvalues=[1,2] #simulate using these values for numQ, make this [2] to simulate inactivation
    inactivate=None #'D1' or 'D2' to run the inactivation test at the end; None to skip it
    inactivate_blocks=0 #number of blocks with inactivation, e.g. 1 or 3 when inactivate is 'D1' or 'D2'; 0 otherwise
########## Plot Q values over time for these states and actions
plot_Qhx=True
actions_colors={'goL':'r','goR':'b','press':'k','goMag':'grey'}
if Hx_len==3:
#MINIMUM actions for reward = 6, so maximum rewards = 1 per 6 "trials"
        state_action_combos=[(('*','*LL'), 'goR'),(('Rlever','*LL'),'press'),(('Rlever','LLR'),'press')]
        sa_errors={} #no error categories defined for Hx_len==3; avoids a NameError when saving
elif Hx_len==4:
#MINIMUM actions for reward = 7, so maximum rewards = 1 per 7 "trials"
state_action_combos=[(('Llever','---L'), 'press'),(('Llever','**RL'), 'press'),(('Llever','**LL'), 'goR'),(('Rlever','**LL'),'press'),(('Rlever','*LLR'),'press'),(('Rlever','LLRR'),'goMag')]
overstay=[(('Llever','**LL'), act) for act in ['goL','goMag','press','other']]+\
[(('Rlever','LLRR'),act) for act in ['goL','goR','press','other']] #'**LL' instead of --LL
premature=[(('Llever','**RL'), act) for act in ['goL','goR','goMag','other']]+\
[(('Rlever','**RL'), act) for act in ['goR','goMag','other','press']]+\
[(('Llever','---L'), act) for act in ['goL','goR','goMag','other']]+\
[(('Rlever','---L'), act) for act in ['goR','goMag','other','press']]+\
[(('Rlever','*LLR'), act) for act in ['goL','goR','goMag','other']] #'*LLR' instead of -LLR
start=[(('mag','----'), act) for act in ['goL','goR','goMag','press','other']]
state_action_combos=state_action_combos+overstay+premature+start
sa_errors={'stay':overstay,'switch':premature,'start':start}
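        #sa_errors groups error types: 'stay' = continuing to respond at a lever whose presses appear
        #complete, 'switch' = responding at a lever out of sequence, 'start' = actions taken from the magazine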
else:
        raise ValueError('unrecognized press history length: '+str(Hx_len))
plot_Qstates=[state[0] for state in state_action_combos]
numevents=numtrials*params['events_per_trial'] #number of events/actions allowed for agent per run/trial
    trial_subset=int(0.05*numtrials)*params['events_per_trial'] #number of events (first 5% and last 5% of trials) over which to report mean reward and action counts
epochs=['Beg','End']
trials_per_block=10
events_per_block=trials_per_block* params['events_per_trial']
num_blocks=int((numevents+1)/events_per_block)
#optionally add blocks of runs with D1 or D2 inactivated
#update some parameters
params['distance']='Euclidean'
params['wt_learning']=False
params['decision_rule']=None #'delta'#'combo', 'delta', 'sumQ2', None means use choose_winner
    params['Q2other']=0.0 #heterosynaptic synaptic plasticity of Q2 for other actions
    params['forgetting']=0#0.2 #heterosynaptic decrease of Q1 for other actions
params['beta_min']=0.5#params['beta'] #0.1 is only slightly worse#
params['beta']=3
params['gamma']=0.95
params['beta_GPi']=10
params['moving_avg_window']=3
params['initQ']=-1 #-1 means do state splitting. If initQ=0, 1 or 10, it means initialize Q to that value and don't split
params['D2_rule']= None #'Ndelta' #'Bogacz' #'Opal'### Opal: use Opal update without critic, Ndelta: calculate delta for N matrix from N values
params['rwd']=rwd['reward']
#lower means more states
state_thresh={'Q1':[0.75,0],'Q2':[0.75,0.875]} #without normalized ED, with heterosynaptic LTP
state_thresh={'Q1':[0.75,0],'Q2':[0.75,0.625]} #or st2= 0.875 with normalized ED, with heterosynaptic LTD
alpha={'Q1':[0.2,0],'Q2':[0.2,0.35]}
#params['state_thresh']=[0.25,0.275]#[0.15,0.2] #threshold on prob for creating new state using Gaussian mixture
# higher means more states. Adjusted so that in new context, Q2 creates new states, but not Q1
#params['alpha']=[0.3,0.15] # [0.2,0.14] # double learning to learn in half the trials, slower for Q2 - D2 neurons
output_data={q:{} for q in Qvalues}
all_Qhx={q:[] for q in Qvalues}
all_beta={q:[] for q in Qvalues}
all_lenQ={q:{qq:[] for qq in range(1,q+1)} for q in Qvalues}
numchars=6
state_subset=['RRLL','RLLR']
params['events_per_block']=events_per_block
params['trials_per_block']=trials_per_block
params['trial_subset']=trial_subset
params['inact']=inactivate
params['inact_blocks']=inactivate_blocks
sa_keys=rlu.construct_key(state_action_combos +['rwd'],epochs)
results={numQ:{sa:{'Beg':[],'End':[]} for sa in state_action_combos+['rwd']} for numQ in Qvalues}
resultslist={numQ:{k+'_'+ep:[] for k in sa_keys.values() for ep in epochs} for numQ in Qvalues}
traject_dict={numQ:{k:[] for k in sa_keys.keys()} for numQ in Qvalues}
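    #results: per-run Beg/End counts for each state-action combo plus reward
    #resultslist: flattened copy of results (plus parameter values) used when saving
    #traject_dict: per-block counts for each state-action key, used for trajectory plots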
for numQ in Qvalues:
resultslist[numQ]['params']={p:[] for p in params}
Qhx=None
for r in range(runs):
params['numQ']=numQ
params['state_thresh']=state_thresh['Q'+str(numQ)]
params['alpha']= alpha['Q'+str(numQ)]
if runs==1:
print('&&&&&&&&&&&&&&&&&&&& STATES',states,'\n **** R:',R.keys(),'\n **** T:',Tloc.keys())
            ######### acquisition trials for the LLRR lever-press sequence #########
acq = RL(task_env.separable_T, QL.QL, states,act,R,Tloc,params,env_params,printR=printR)
acq.episode(numevents,noise=noise,info=Info)
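            #optional inactivation test: zero the D1 (Q[0]) or D2 (Q[1]) matrix and its learning rate,
            #adjust Da_factor, then run inact_blocks additional blocks of events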
if params['inact'] and numQ==2:
if params['inact']=='D2':
acq.agent.Q[1]=np.zeros(np.shape(acq.agent.Q[1]))
acq.agent.alpha[1]=0
#acq.agent.numQ=1
params['Da_factor']=acq.agent.Da_factor=0.5
elif params['inact']=='D1':
acq.agent.Q[0]=np.zeros(np.shape(acq.agent.Q[0]))
acq.agent.alpha[0]=0
params['Da_factor']=acq.agent.Da_factor=2
acq.episode(events_per_block*params['inact_blocks'],noise=noise,info=Info)
#acq.set_of_plots('LLRR, numQ='+str(params['numQ']),noise,Hx_len,hist=plot_hist)
results[numQ]=acq.count_actions(results[numQ],state_action_combos,trial_subset,accum_type='mean')#,accum_type='count')
traject_dict=acq.trajectory(traject_dict, sa_keys,num_blocks+params['inact_blocks'],events_per_block,numQ,accum_type='mean')#,accum_type='count')
if r<1 and other_plots:
acq.set_of_plots(str(numQ),noise,Hx_len,title2='',hist=plot_hist)
#acq.visual()
if plot_Qhx: #need to return state_nums, which may differ for each run
Qhx,state_nums=accum_Qhx(plot_Qstates,actions_colors,acq,params['numQ'],Qhx)
#del acq #to free up memory
print('numQ=',numQ,', run',r,'Q0 mat states=',len(acq.agent.Q[0]),'alpha',acq.agent.alpha)
all_beta[numQ].append(acq.agent.learn_hist['beta'])
for qq in all_lenQ[numQ].keys():
if qq-1 in acq.agent.learn_hist['lenQ'].keys():
all_lenQ[numQ][qq].append(acq.agent.learn_hist['lenQ'][qq-1])
all_Qhx[numQ]=Qhx
resultslist=rlu.save_results(results,sa_keys,resultslist)
for p in resultslist[numQ]['params'].keys(): #
resultslist[numQ]['params'][p].append(params[p]) #
for ta in traject_dict[numQ].keys():
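            #mean and standard error across runs of the per-block counts for this key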
output_data[numQ][ta]={'mean':np.mean(traject_dict[numQ][ta],axis=0),'sterr':np.std(traject_dict[numQ][ta],axis=0)/np.sqrt(runs-1)}
print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
print(' Using',params['numQ'], 'Q, alpha=',params['alpha'],'thresh',params['state_thresh'], 'runs',runs,'of total events',numevents)
print(' weights:',[k+':'+str(params[k]) for k in params.keys() if k.startswith('wt')])
print('Q2 hetero=',params['Q2other'],'decision rule=',params['decision_rule'],'beta=',params['beta_min'],params['beta'])
print('counts from ',trial_subset,' events: BEGIN END std over ',runs,'runs. Hx_len=',Hx_len)
norm={sac:100 for sac in results[numQ].keys()}
norm['rwd']=1
for sa_combo,counts in results[numQ].items():
print(sa_combo,':::',np.round(np.mean(counts['Beg'])*norm[sa_combo],2),'% +/-',np.round(np.std(counts['Beg'])*norm[sa_combo],2),
',', np.round(np.mean(counts['End'])*norm[sa_combo],2),'% +/-',np.round(np.std(counts['End'])*norm[sa_combo],2))
if other_plots:
for i in range(params['numQ']):
acq.agent.visual(acq.agent.Q[i],labels=acq.state_to_words(i,noise,numchars),title='numQ='+str(numQ)+',Q'+str(i),state_subset=state_subset)
if save_data:
import datetime
dt=datetime.datetime.today()
date=str(dt).split()[0]
key_params=['numQ','Q2other','beta_GPi','decision_rule','beta_min','beta','gamma','rwd']
fname_params=key_params+['initQ']
fname='Sequence'+date+'_'.join([k+str(params[k]) for k in fname_params])
#fname='Sequence'+date+'_HxLen'+str(Hx_len)+'_alpha'+'_'.join([str(a) for a in params['alpha']])+'_st'+'_'.join([str(st) for st in params['state_thresh']])+\
if params['inact']:
fname=fname +'_inactive'+params['inact']+'_'+str(params['Da_factor'])
np.savez(fname,par=params,results=resultslist[numQ],traject=output_data[numQ],Qhx=all_Qhx[numQ],all_beta=all_beta[numQ],all_lenQ=all_lenQ[numQ],sa_errors=sa_errors)
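            #the saved npz can be reloaded for offline analysis, e.g. (sketch):
            #  data=np.load(fname+'.npz',allow_pickle=True)
            #  traject=data['traject'].item(); par=data['par'].item()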
allQ={i:acq.agent.Q[i] for i in range(params['numQ'])}
all_labels={i:acq.state_to_words(i,noise,numchars) for i in range(params['numQ'])}
actions=acq.agent.actions
print('\nsummary for beta_min=',params['beta_min'],'beta_max=',params['beta'],'beta_GPi=', params['beta_GPi'],'gamma=',params['gamma'])
for numQ in Qvalues:
halfrwd=(np.max(output_data[numQ]['rwd']['mean'])+np.min(output_data[numQ]['rwd']['mean']))/2
####### replace this with 90% of maximal to measure effect of beta? #############
block=np.min(np.where(output_data[numQ]['rwd']['mean']>halfrwd))
print('rwd End',':::',round(np.mean(results[numQ]['rwd']['End'])*norm['rwd'],2), \
'per trial, +/-',round(np.std(results[numQ]['rwd']['End'])*norm['rwd'],2), \
', blocks to half reward=',block, 'for nQ=', numQ)
#
title='History Length '+str(Hx_len)+'\nminimum '+str(params['events_per_trial'])+' actions per reward'
if other_plots:
rlu.plot_trajectory(output_data,title,[Qvalues])
if plot_Qhx:
plot_states=[('Llever','--LL'),('Rlever','-LLR'),('Rlever','LLRR')]#[('Llever','RRLL'),('Rlever','LLLL'),('Rlever','RLLR'),('Rlever','RRLL')]
actions_lines={a:'solid' for a in actions_colors.keys()}
from TD2Q_Qhx_graphs import plot_Qhx_sequence, plot_Qhx_sequence_1fig
        figs=plot_Qhx_sequence_1fig(all_Qhx,plot_states,actions_colors,params['events_per_trial'],actions_lines)
#for numQ,Qhx in all_Qhx.items():
#figs=plot_Qhx_sequence(Qhx,actions_colors,params['events_per_trial'],numQ)
#figs=plot_Qhx_sequence_1fig (Qhx,plot_states,actions_colors,params['events_per_trial'])