function Out = RLdecayTmaze2(num_trial,RLtype,RLparas,Rews,decay_rate,DAdep_paras,velo_Stay_factor)
%-----
% This file is associated with the following article, which has been provisionally accepted for publication in PLOS Computational Biology
% (initially submitted on May 11, 2016, and provisionally accepted on Sep 14, 2016):
% Authors: Ayaka Kato (1) & Kenji Morita (2)
% Affiliations:
% (1) Department of Biological Sciences, Graduate School of Science, The University of Tokyo, Tokyo, Japan
% (2) Physical and Health Education, Graduate School of Education, The University of Tokyo, Tokyo, Japan
% Title: Forgetting in Reinforcement Learning Links Sustained Dopamine Signals to Motivation
% Short title: Dynamic Equilibrium in Reinforcement Learning
% Correspondence: Kenji Morita (morita@p.u-tokyo.ac.jp)
%-----
% Out = RLdecayTmaze2(num_trial,RLtype,RLparas,Rews,decay_rate,DAdep_paras,velo_Stay_factor)
%
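% num_trial: number of trials
% RLtype: 'Q' (Q-learning) or 'S' (SARSA)
% RLparas: [learning rate, inverse temperature, time discount factor]
% Rews: rewards associated with the actions (indexed by action number)
% decay_rate: rate of the decay (forgetting) applied to all the action values at every time step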
% DAdep_paras: [DAdep_factor (multiplicative), DAdep_start_trial], e.g., [0.25, 501]
% velo_Stay_factor: the velocity is multiplied by velo_Stay_factor when Stay is selected (0: stop instantaneously (= original model); e.g., 0.5: the velocity is halved)
%
% <states>   1      2      3      4          5/6      7/8   reward-consume   9/10 (after reward)
% <actions>  1S,2G  4S,5G  7S,8G  10S,11,12  13S,14G   -         19           -        (Arm1)
%                                            16S,17G   -         20           -        (Arm2)
% (S = Stay, G = Go; at State 4, action 11/12 enters Arm1/Arm2; actions 19/20 represent reward consumption in Arm1/Arm2)
%
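% Illustrative example call (a sketch only; the numerical values below are placeholders rather than the settings used in the paper,
% apart from DAdep_paras = [0.25, 501], which follows the example above):
%   Rews = zeros(1,20); Rews(14) = 1; Rews(17) = 0.5; % assuming reward is delivered by the Go actions that enter the arm ends (actions 14 and 17)
%   Out = RLdecayTmaze2(1000, 'Q', [0.5, 5, 0.97], Rews, 0.01, [0.25, 501], 0); % velo_Stay_factor = 0 corresponds to the original (instantaneous-stop) model
%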
% RL parameters
p_alpha = RLparas(1); % learning rate
p_beta = RLparas(2); % inverse temperature for the softmax action selection
p_gamma = RLparas(3); % time discount factor
num_tstep = num_trial * 100; % upper limit on the total number of time steps (an error is raised below if this is exhausted)
% initialization of the variables
ArmChoices = zeros(num_trial,1);
States = zeros(num_tstep,1);
Choices = zeros(num_tstep,1);
Qs = zeros(num_trial,20);
TDs = zeros(num_tstep,1);
endsteps = zeros(num_trial,1); % time step at which each trial ends
Velocities = NaN(num_tstep,1); % velocities (Go velocity is fixed to 1), initialization
Positions = NaN(num_tstep,1); % positions (1=start(State1) to 7(trial end)), initialization
% main loop
Qnow = zeros(1,20); % current values of the 20 actions
CurrS = 1; % current state
CurrV = 0; % current velocity
CurrP = 1; % current position
prevA = []; % previously taken action (empty at the start of a trial)
k_trial = 1; % current trial index
for k_tstep = 1:num_tstep
% save the state
States(k_tstep) = CurrS;
Velocities(k_tstep) = CurrV;
Positions(k_tstep) = CurrP;
% set DA depletion: from trial DAdep_paras(2) onward, the TD-based value update is scaled by DAdep_paras(1)
if k_trial >= DAdep_paras(2)
DAfactor = DAdep_paras(1);
else
DAfactor = 1;
end
% main
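% each time step falls into one of four cases:
% (i) the start of a trial (prevA is empty; the agent is at State 1),
% (ii) the T-junction (State 4), where Stay / Arm1 / Arm2 is chosen,
% (iii) the trial end (State 9 or 10), reached after reward consumption,
% (iv) any other state, where the reward is consumed if the preceding action delivered it, and otherwise Stay or Go is chosen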
if isempty(prevA) % start of a trial: the agent is at State 1 and chooses Stay (1) or Go (2)
Choices(k_tstep) = 0 + actselect(Qnow([1,2]),p_beta);
if RLtype == 'Q' % Q-learning: use the maximum of the candidate action values for the TD target
TDs(k_tstep) = 0 + p_gamma*max(Qnow([1,2])) - 0;
elseif RLtype == 'S' % SARSA: use the value of the actually chosen action for the TD target
TDs(k_tstep) = 0 + p_gamma*Qnow(Choices(k_tstep)) - 0;
end
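% at the trial start there is no preceding action value to update; only the value decay (forgetting) below,
% which multiplies every action value by (1 - decay_rate) at every time step, is applied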
Qnow = Qnow * (1 - decay_rate);
prevA = Choices(k_tstep);
if Choices(k_tstep) == 2
CurrV = 1;
end
CurrP = CurrP + CurrV;
CurrS = floor(CurrP);
elseif CurrS == 4 % at the T-junction: choose Stay (10), enter Arm1 (11), or enter Arm2 (12)
Choices(k_tstep) = 9 + actselect(Qnow([10,11,12]),p_beta);
if RLtype == 'Q'
TDs(k_tstep) = Rews(prevA) + p_gamma*max(Qnow([10,11,12])) - Qnow(prevA);
elseif RLtype == 'S'
TDs(k_tstep) = Rews(prevA) + p_gamma*Qnow(Choices(k_tstep)) - Qnow(prevA);
end
Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
Qnow = Qnow * (1 - decay_rate);
prevA = Choices(k_tstep);
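% state transition at the junction: Stay (10) keeps the agent at State 4, whereas Arm1 (11) / Arm2 (12)
% moves it into the chosen arm (State 5 or 6) at position 5 with velocity 1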
if Choices(k_tstep) == 10
CurrV = 0;
CurrP = 4;
CurrS = 4;
elseif Choices(k_tstep) == 11
ArmChoices(k_trial) = 1;
CurrV = 1;
CurrP = 5;
CurrS = 5;
elseif Choices(k_tstep) == 12
ArmChoices(k_trial) = 2;
CurrV = 1;
CurrP = 5;
CurrS = 6;
end
elseif (CurrS == 9) || (CurrS == 10) % trial end after reward consumption
TDs(k_tstep) = 0 + p_gamma*0 - Qnow(prevA);
Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
Qnow = Qnow * (1 - decay_rate);
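% record the action values and the end time step of this trial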
Qs(k_trial,:) = Qnow;
endsteps(k_trial) = k_tstep;
if k_trial == num_trial
States = States(1:k_tstep);
Choices = Choices(1:k_tstep);
TDs = TDs(1:k_tstep);
Velocities = Velocities(1:k_tstep);
Positions = Positions(1:k_tstep);
break;
else
k_trial = k_trial + 1;
CurrS = 1;
CurrV = 0;
CurrP = 1;
prevA = [];
end
else
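% any other state (States 1-3 and 5-8, with a preceding action)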
if Rews(prevA) > 0 % if reward is obtained
TDs(k_tstep) = Rews(prevA) + 0 - Qnow(prevA);
Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
Qnow = Qnow * (1 - decay_rate);
prevA = ArmChoices(k_trial) + 18; % reward-consumption action: 19 (Arm1) or 20 (Arm2)
CurrV = 0;
CurrS = ArmChoices(k_trial) + 8; % post-reward state: 9 (Arm1) or 10 (Arm2)
else
Choices(k_tstep) = 3*(CurrS-1) + actselect(Qnow(3*(CurrS-1)+[1,2]),p_beta); % choose Stay (action 3*(CurrS-1)+1) or Go (action 3*(CurrS-1)+2)
if RLtype == 'Q'
TDs(k_tstep) = Rews(prevA) + p_gamma*max(Qnow(3*(CurrS-1)+[1,2])) - Qnow(prevA);
elseif RLtype == 'S'
TDs(k_tstep) = Rews(prevA) + p_gamma*Qnow(Choices(k_tstep)) - Qnow(prevA);
end
Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
Qnow = Qnow * (1 - decay_rate);
prevA = Choices(k_tstep);
% update of velocity
if Choices(k_tstep) == 3*(CurrS-1) + 2 % Go
CurrV = 1;
else
CurrV = CurrV * velo_Stay_factor;
end
% update of position, and modify the update of velocity when necessary
if (CurrS == 3) && (CurrP + CurrV > 4) % do not overshoot the T-junction (position 4)
CurrP = 4;
CurrV = 4 - CurrP; % i.e., the velocity becomes 0 on reaching the junction
elseif ((CurrS == 5) || (CurrS == 6)) && (CurrP + CurrV > 6) && (Rews(prevA) > 0) % do not overshoot the reward location (position 6)
CurrP = 6;
CurrV = 6 - CurrP; % i.e., the velocity becomes 0 on reaching the reward location
else
CurrP = CurrP + CurrV;
end
% update of state
if (CurrS == 1) || (CurrS == 2) || (CurrS == 3)
CurrS = floor(CurrP);
elseif (CurrS == 5) || (CurrS == 6)
CurrS = CurrS + 2*floor(CurrP - 5); % State 5/6 becomes State 7/8 once the position reaches 6
end
end
end
end
if k_tstep == num_tstep
error('number of time steps is not enough');
end
% output variables
Out.ArmChoices = ArmChoices;
Out.States = States;
Out.Choices = Choices;
Out.Qs = Qs;
Out.TDs = TDs;
Out.endsteps = endsteps;
Out.Velocities = Velocities;
Out.Positions = Positions;
% in-file function: softmax action selection
function chosen_option_index = actselect(optionQs,p_beta)
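% softmax (Boltzmann) action selection:
% option k is chosen with probability exp(p_beta*optionQs(k)) / sum(exp(p_beta*optionQs)),
% implemented by drawing a uniform random number and inverting the cumulative distribution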
Pchoose = zeros(1,length(optionQs));
for k = 1:length(optionQs)
Pchoose(k) = exp(p_beta*optionQs(k))/sum(exp(p_beta*optionQs));
end
chosen_option_index = [];
tmp_rand = rand;
tmp = 0;
for k = 1:length(optionQs)
tmp = tmp + Pchoose(k);
if tmp_rand <= tmp
chosen_option_index = k;
break;
end
end
if isempty(chosen_option_index)
error('choice is not made');
end