function Out = RLdecayTmaze1(num_trial,RLtype,RLparas,Rews,decay_rate,DAdep_paras)

%-----
% This file is associated with the following article, which has been provisionally accepted for publication in PLOS Computational Biology
% (initially submitted on May 11, 2016, and provisionally accepted on Sep 14, 2016):
% Authors: Ayaka Kato (1) & Kenji Morita (2)
% Affiliations:
%  (1) Department of Biological Sciences, Graduate School of Science, The University of Tokyo, Tokyo, Japan
%  (2) Physical and Health Education, Graduate School of Education, The University of Tokyo, Tokyo, Japan
% Title: Forgetting in Reinforcement Learning Links Sustained Dopamine Signals to Motivation
% Short title: Dynamic Equilibrium in Reinforcement Learning
% Correspondence: Kenji Morita (morita@p.u-tokyo.ac.jp)
%-----
% Out = RLdecayTmaze1(num_trial,RLtype,RLparas,Rews,decay_rate,DAdep_paras)
%
%   num_trial: number of trials to simulate
%   RLtype: 'Q' (Q-learning) or 'S' (SARSA)
%   RLparas: [learning rate, inverse temperature, time discount factor]
%   Rews: 1 x 23 vector; Rews(a) is the reward obtained upon completion of action a
%   decay_rate: per-time-step decay (forgetting) rate of the learned values
%   DAdep_paras: [DAdep_factor (multiplicative), DAdep_start_trial], e.g., [0.25, 501]
%
% <states>   1     2     3     4         5/6     7/8     9
% <actions>  1S2G  4S5G  7S8G  10S11/12  13S14G  19S20G       Arm1
%                                        16S17G  22S23G       Arm2
%   (nS/nG = Stay/Go action with index n; at state 4, action 11 enters Arm1
%    and action 12 enters Arm2; state 9 is the trial end)
%
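% Example usage (a minimal sketch; the parameter values and the reward
% placement below are illustrative assumptions, not the paper's settings):
%   Rews = zeros(1,23);
%   Rews(20) = 1; % hypothetical: reward upon the final Go action (20) of Arm1
%   Out = RLdecayTmaze1(1000, 'Q', [0.5 5 0.8], Rews, 0.01, [0.25 501]);
%   % DA depletion (update factor 0.25) then applies from trial 501 onward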

% parameters
p_alpha = RLparas(1); % learning rate
p_beta = RLparas(2); % inverse temperature of the softmax choice
p_gamma = RLparas(3); % time discount factor
num_tstep = num_trial * 100; % generous upper bound on the total number of time steps

% initialization of the variables
ArmChoices = zeros(num_trial,1); % chosen arm (1 or 2) in each trial
States = zeros(num_tstep,1); % state at each time step
Choices = zeros(num_tstep,1); % chosen action at each time step
Qs = zeros(num_trial,23); % action values at the end of each trial
TDs = zeros(num_tstep,1); % TD error at each time step
endsteps = zeros(num_trial,1); % time steps for reaching the trial end

% main loop
Qnow = zeros(1,23); % current action values
CurrS = 1; % current state
prevA = []; % action taken at the previous time step (empty at trial start)
k_trial = 1; % current trial index
for k_tstep = 1:num_tstep
    
    % save the state
    States(k_tstep) = CurrS;

    % set DA depletion: from trial DAdep_paras(2) onward, value updates are
    % scaled by the multiplicative factor DAdep_paras(1)
    if k_trial >= DAdep_paras(2)
        DAfactor = DAdep_paras(1);
    else
        DAfactor = 1;
    end
    
    % main: branch on the current position in the maze
    if isempty(prevA) % trial start (state 1): no preceding action to update
        Choices(k_tstep) = 0 + actselect(Qnow([1,2]),p_beta);
        % TD error: Q-learning ('Q') bootstraps with the max over the upcoming
        % options, whereas SARSA ('S') uses the value of the option actually chosen
        if strcmp(RLtype,'Q')
            TDs(k_tstep) = 0 + p_gamma*max(Qnow([1,2])) - 0;
        elseif strcmp(RLtype,'S')
            TDs(k_tstep) = 0 + p_gamma*Qnow(Choices(k_tstep)) - 0;
        end
        Qnow = Qnow * (1 - decay_rate); % all values decay (are forgotten) at every time step
        prevA = Choices(k_tstep);
        if Choices(k_tstep) == 2
            CurrS = 2;
        end
    elseif CurrS == 4 % decision point: Stay (10), Go to Arm1 (11), or Go to Arm2 (12)
        Choices(k_tstep) = 9 + actselect(Qnow([10,11,12]),p_beta);
        if strcmp(RLtype,'Q')
            TDs(k_tstep) = Rews(prevA) + p_gamma*max(Qnow([10,11,12])) - Qnow(prevA);
        elseif strcmp(RLtype,'S')
            TDs(k_tstep) = Rews(prevA) + p_gamma*Qnow(Choices(k_tstep)) - Qnow(prevA);
        end
        Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep); % update scaled by the DA depletion factor
        Qnow = Qnow * (1 - decay_rate);
        prevA = Choices(k_tstep);
        if Choices(k_tstep) == 11
            ArmChoices(k_trial) = 1;
            CurrS = 5;
        elseif Choices(k_tstep) == 12
            ArmChoices(k_trial) = 2;
            CurrS = 6;
        end
    elseif CurrS == 9 % trial end
        TDs(k_tstep) = Rews(prevA) + p_gamma*0 - Qnow(prevA); % terminal state: no future value
        Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
        Qnow = Qnow * (1 - decay_rate);
        Qs(k_trial,:) = Qnow;
        endsteps(k_trial) = k_tstep;
        if k_trial == num_trial
            States = States(1:k_tstep);
            Choices = Choices(1:k_tstep);
            TDs = TDs(1:k_tstep);
            break;
        else
            k_trial = k_trial + 1;
            CurrS = 1;
            prevA = [];
        end
    else % corridor states (2,3,5,6,7,8): choose between Stay and Go
        Choices(k_tstep) = 3*(CurrS-1) + actselect(Qnow(3*(CurrS-1)+[1,2]),p_beta);
        if strcmp(RLtype,'Q')
            TDs(k_tstep) = Rews(prevA) + p_gamma*max(Qnow(3*(CurrS-1)+[1,2])) - Qnow(prevA);
        elseif strcmp(RLtype,'S')
            TDs(k_tstep) = Rews(prevA) + p_gamma*Qnow(Choices(k_tstep)) - Qnow(prevA);
        end
        Qnow(prevA) = Qnow(prevA) + DAfactor*p_alpha*TDs(k_tstep);
        Qnow = Qnow * (1 - decay_rate);
        prevA = Choices(k_tstep);
        if Choices(k_tstep) == 3*(CurrS-1) + 2 % Go
            if (CurrS == 1) || (CurrS == 2) || (CurrS == 3) || (CurrS == 8)
                CurrS = CurrS + 1;
            elseif (CurrS == 5) || (CurrS == 6) || (CurrS == 7)
                CurrS = CurrS + 2; % arm states are interleaved (Arm1: 5,7; Arm2: 6,8), so Go skips by 2
            end
        end
    end
end
if endsteps(num_trial) == 0 % the time-step budget ran out before all trials finished
    error('number of time steps is not enough');
end

% output variables
Out.ArmChoices = ArmChoices;
Out.States = States;
Out.Choices = Choices;
Out.Qs = Qs;
Out.TDs = TDs;
Out.endsteps = endsteps;


% inner-file function

function chosen_option_index = actselect(optionQs,p_beta)
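% softmax (Boltzmann) action selection:
%   P(choose option k) = exp(p_beta*Q_k) / sum_j exp(p_beta*Q_j)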

Pchoose = exp(p_beta*optionQs) / sum(exp(p_beta*optionQs));

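% draw one option by inverting the cumulative distribution of Pchoose
% (roulette-wheel sampling with a single uniform random number)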
chosen_option_index = [];
tmp_rand = rand;
tmp = 0;
for k = 1:length(optionQs)
    tmp = tmp + Pchoose(k);
    if tmp_rand <= tmp
        chosen_option_index = k;
        break;
    end
end
if isempty(chosen_option_index)
    error('choice is not made');
end