% mkfig_Fig14

%-----
% This file is associated with the following article, which has been provisionally accepted for publication in PLOS Computational Biology
% (initially submitted on May 11, 2016, and provisionally accepted on Sep 14, 2016):
% Authors: Ayaka Kato (1) & Kenji Morita (2)
% Affiliations:
%  (1) Department of Biological Sciences, Graduate School of Science, The University of Tokyo, Tokyo, Japan
%  (2) Physical and Health Education, Graduate School of Education, The University of Tokyo, Tokyo, Japan
% Title: Forgetting in Reinforcement Learning Links Sustained Dopamine Signals to Motivation
% Short title: Dynamic Equilibrium in Reinforcement Learning
% Correspondence: Kenji Morita (morita@p.u-tokyo.ac.jp)
%-----

% To reproduce the paper's figures exactly, restore the random-number state
% that was used for the Fig. 14 simulations. The MAT-file is expected to
% contain the variable used_rand_twister_for_Fig14 (read below).
load('used_rand_twister_for_Fig14');

% parameters
% (alpha/beta/gamma are passed to RLdecayTmaze2; presumably learning rate,
%  inverse temperature and discount factor -- confirm in RLdecayTmaze2)
num_sim = 20;     % simulations per reward condition
num_trial = 1000; % trials per simulation
RLtype = 'Q';
p_alpha = 0.5;
p_beta = 5;
p_gamma = 1;
decay_rate = 0.01;
DAdep_paras = [0.25,501]; % depletion to 1/4 after 500 trials
velo_Stay_factor = 0.5;
rewarded_state = [5,8; 5,6]; % NOTE(review): not referenced elsewhere in this script

% simulations
% Start from an empty results struct. Scripts share the base workspace, so
% without this, fields or cells left over from an earlier run (e.g. with a
% larger num_sim) would silently end up in the saved Dvelo.mat.
Dsim = struct();
% To draw fresh random numbers instead of the stored ones, use:
%   rand('twister',sum(100*clock));
%   Dsim.rand_twister = rand('twister');
Dsim.rand_twister = used_rand_twister_for_Fig14;
% legacy-syntax seeding of the Mersenne Twister generator, kept so that the
% random stream matches the original simulations
rand('twister',Dsim.rand_twister);

% reward
% One 1x20 reward vector per condition (k1 = 1, 2); the meaning of the entry
% indices is defined by RLdecayTmaze2, which receives the vector unchanged.
Rews = cell(1,2);
Rews{1} = zeros(1,20); Rews{1}([11,17]) = [0.5 1];
Rews{2} = zeros(1,20); Rews{2}([11,12]) = [0.5 1];

% main
% Run num_sim simulations for each reward condition. The output cells are
% preallocated, which avoids growing them inside the loop and also guarantees
% that no cells from an earlier (larger) run remain in Dsim.Out.
Dsim.Out = cell(1,2);
for k1 = 1:2
    Dsim.Out{k1} = cell(1,num_sim);
    for k2 = 1:num_sim
        fprintf('%d-%d\n',k1,k2); % progress: condition-simulation
        Dsim.Out{k1}{k2} = RLdecayTmaze2(num_trial,RLtype,[p_alpha,p_beta,p_gamma],Rews{k1},decay_rate,DAdep_paras,velo_Stay_factor);
    end
end

% choose 2 ratio
% For each simulation, the fraction of "choice 2" within consecutive blocks
% of bin trials (assumes ArmChoices holds 1 or 2, so ArmChoices-1 is 0 or 1).
% Result: Dsim.choose2ratio{condition} is num_sim x (num_trial/bin).
bin = 10; % num_trial/bin should be an integer
n_bins = num_trial/bin;
for i_cond = 1:2
    ratio_mat = NaN(num_sim,n_bins);
    for i_sim = 1:num_sim
        choices01 = Dsim.Out{i_cond}{i_sim}.ArmChoices - 1;
        per_bin = reshape(choices01,bin,n_bins); % one column per bin
        ratio_mat(i_sim,:) = sum(per_bin,1)/bin;
    end
    Dsim.choose2ratio{i_cond} = ratio_mat;
end

% time to reach state 4
% For each trial, count the time steps from the trial's first step up to and
% including the first step spent in state 4, then average the counts over
% blocks of bin trials with mean2. Result: Dsim.avetime{condition} is
% num_sim x (num_trial/bin).
% A trial in which state 4 never occurs is left as NaN (instead of aborting
% with an assignment error); whether such NaNs propagate into the bin average
% depends on the project helper mean2.
bin = 10; % num_trial/bin should be an integer
for k1 = 1:2
    Dsim.avetime{k1} = NaN(num_sim,num_trial/bin); % all, to reach state 4
    for k2 = 1:num_sim
        % trial k occupies time steps trial_bounds(k)+1 .. trial_bounds(k+1);
        % built once per simulation rather than once per trial
        trial_bounds = [0;Dsim.Out{k1}{k2}.endsteps];
        tmp_times3 = NaN(num_trial,1);
        for k_trial = 1:num_trial
            tmp_tsteps = (trial_bounds(k_trial)+1):trial_bounds(k_trial+1); % time steps of k_trial
            first4 = find(Dsim.Out{k1}{k2}.States(tmp_tsteps)==4,1,'first');
            if ~isempty(first4)
                tmp_times3(k_trial) = first4;
            end
        end
        Dsim.avetime{k1}(k2,:) = mean2(reshape(tmp_times3,bin,num_trial/bin),1);
    end
end

% save: write the results struct Dsim to Dvelo.mat in the current folder
save Dvelo Dsim

% plot
save_fig = 1; % nonzero: also print every figure to a color EPS file

% Fig. 14C,G
% Proportion of "choice 2" per bin of trials: mean (solid) and mean +/- sem2
% (dashed) across simulations, one figure per reward condition.
tmpletters = 'CG';
x_bins = 0.5:1:num_trial/bin-0.5; % bin centers in bin units
for k1 = 1:2
    F = figure;
    A = axes;
    hold on;
    m = mean2(Dsim.choose2ratio{k1},1);
    s = sem2(Dsim.choose2ratio{k1},1,1,0);
    % dotted reference lines at 0, 0.5 and 1
    plot([0 num_trial/bin],[0 0],'k:');
    plot([0 num_trial/bin],[0.5 0.5],'k:');
    plot([0 num_trial/bin],[1 1],'k:');
    % vertical marker at trial num_trial/2 (= 500, where DA depletion starts
    % according to DAdep_paras for num_trial = 1000)
    plot((1/2)*(num_trial/bin)*[1 1],[-0.1 1.1],'k.-');
    plot(x_bins,m+s,'k--');
    plot(x_bins,m-s,'k--');
    plot(x_bins,m,'k');
    axis([0 num_trial/bin -0.1 1.1]);
    set(A,'Box','off');
    %set(A,'PlotBoxAspectRatio',[1 1 1]);
    set(A,'FontName','Arial','FontSize',20); % was misspelled 'Ariel'
    set(A,'XTick',0:100/bin:num_trial/bin,'XTickLabel',0:100:num_trial);
    set(A,'YTick',0:0.1:1,'YTickLabel',0:0.1:1);
    if save_fig
        print(F,'-depsc',['Fig14' tmpletters(k1)]);
    end
end

% Fig. 14D,H
% Time steps needed to reach state 4, per bin of trials: mean (solid) and
% mean +/- sem2 (dashed) across simulations, one figure per reward condition.
tmpletters = 'DH';
x_bins = 0.5:1:num_trial/bin-0.5; % bin centers in bin units
for k1 = 1:2
    F = figure;
    A = axes;
    hold on;
    m = mean2(Dsim.avetime{k1},1);
    s = sem2(Dsim.avetime{k1},1,1,0);
    plot([0 num_trial/bin],[4 4],'k:'); % reference line at 4 time steps
    % vertical marker at trial num_trial/2 (= 500, where DA depletion starts
    % according to DAdep_paras for num_trial = 1000)
    plot((1/2)*(num_trial/bin)*[1 1],[3 8],'k.-');
    plot(x_bins,m+s,'k--');
    plot(x_bins,m-s,'k--');
    plot(x_bins,m,'k');
    axis([0 num_trial/bin 3 8]);
    set(A,'Box','off');
    %set(A,'PlotBoxAspectRatio',[1 1 1]);
    set(A,'FontName','Arial','FontSize',20); % was misspelled 'Ariel'
    set(A,'XTick',0:100/bin:num_trial/bin,'XTickLabel',0:100:num_trial);
    set(A,'YTick',3:1:8,'YTickLabel',3:1:8);
    if save_fig
        print(F,'-depsc',['Fig14' tmpletters(k1)]);
    end
end

% Fig. 14E,I
% Per-state mean velocity (mean2 within each simulation, then mean2 +/- sem2
% across simulations) for four subsets of time steps:
%   avelos{1}: all steps,        trials win_before(1)+1 .. win_before(2)
%   avelos{2}: all steps,        trials win_after(1)+1  .. win_after(2)
%   avelos{3}: "after stay" steps in the win_before trials
%   avelos{4}: "after stay" steps in the win_after trials
% where an "after stay" step is one whose state equals the previous step's.
% Colors: k = avelos{1}, r = {2}, c = {3}, m = {4}; each is drawn for two
% groups of states (dashed and dotted lines, x-offset around x = 5..7).
tmpletters = 'EI';
num_state = 10;
% Trial windows before / after DA depletion. NOTE: these assume
% num_trial = 1000 and depletion from trial 501 (DAdep_paras(2)).
win_before = [250 500];
win_after = [750 1000];
cond_colors = 'krcm';
dot_states = [4 5 9];       % drawn with dotted lines
dot_x = [4 5.15 6.15];
for k1 = 1:2
    for k_cond = 1:4
        avelos{k_cond}{k1} = NaN(num_sim,num_state);
    end
    for k2 = 1:num_sim
        tmp_out = Dsim.Out{k1}{k2};
        n_steps = length(tmp_out.States);
        % time-step masks do not depend on the state, so build them once
        % per simulation instead of once per state
        in_before = false(n_steps,1);
        in_before(tmp_out.endsteps(win_before(1))+1:tmp_out.endsteps(win_before(2))) = true;
        in_after = false(n_steps,1);
        in_after(tmp_out.endsteps(win_after(1))+1:tmp_out.endsteps(win_after(2))) = true;
        tmp_afterstay = (tmp_out.States == [0;tmp_out.States(1:end-1)]);
        for k_state = 1:num_state
            in_state = (tmp_out.States==k_state);
            avelos{1}{k1}(k2,k_state) = mean2(tmp_out.Velocities(in_state&in_before),1);
            avelos{2}{k1}(k2,k_state) = mean2(tmp_out.Velocities(in_state&in_after),1);
            avelos{3}{k1}(k2,k_state) = mean2(tmp_out.Velocities(in_state&in_before&tmp_afterstay),1);
            avelos{4}{k1}(k2,k_state) = mean2(tmp_out.Velocities(in_state&in_after&tmp_afterstay),1);
        end
    end
    F = figure;
    A = axes;
    hold on;
    % blue reference lines at 0 and 1
    plot([0.5 7.5],[0 0],'b');
    plot([0.5 7.5],[1 1],'b');
    % the dashed-line state group differs between the two reward conditions
    if k1 == 1
        dash_states = [1:4 6 8 10];
        dash_x = [1:4 4.85 5.85 6.85];
    else
        dash_states = [1:4 6 10];
        dash_x = [1:4 4.85 5.85];
    end
    for k_cond = 1:4
        c = cond_colors(k_cond);
        m_dash = mean2(avelos{k_cond}{k1}(:,dash_states),1);
        errorbar(dash_x,m_dash,sem2(avelos{k_cond}{k1}(:,dash_states),1,1,0),c);
        plot(dash_x,m_dash,[c '--']);
        m_dot = mean2(avelos{k_cond}{k1}(:,dot_states),1);
        errorbar(dot_x,m_dot,sem2(avelos{k_cond}{k1}(:,dot_states),1,1,0),c);
        plot(dot_x,m_dot,[c ':']);
    end
    axis([0.5 7.5 -0.1 1.1]);
    set(A,'Box','off');
    %set(A,'PlotBoxAspectRatio',[1 1 1]);
    set(A,'FontName','Arial','FontSize',20); % was misspelled 'Ariel'
    set(A,'XTick',[1:4 4.85 5.15 5.85 6.15 6.85 7.15],'XTickLabel',[]);
    set(A,'YTick',0:0.1:1,'YTickLabel',0:0.1:1);
    if save_fig
        print(F,'-depsc',['Fig14' tmpletters(k1)]);
    end
end