%==========================================================================
%Title: Phasic dopamine changes and Hebbian mechanisms during probabilistic
%reversal learning in striatal circuits: a computational study
%
%Description: Script to perform the testing procedure: 4-choice
%experiment.
%
% The script calls the function used to simulate the task named
% 'BG_model_Response_Stimuli_Task1'.
% To be run after 'Basal_Training_synapses_4_channels.m' or
% 'Reversal_Training_synapses_4_channels.m'.
%
% 200 trials are executed, using each of the four inputs 50 times. During
% the test, random Gaussian noise is applied not only to the cortical neurons
% but also to the input stimuli (simulating a noisy environment).
%
% This testing procedure is performed on the networks obtained after
% different numbers of training epochs, listed in the vector 'Epoche'
% (Initialisation section).
%
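% The results are shown at the end of the simulation in a table organized
% as follows (mean and standard deviation are computed across the tested
% networks; counts refer to the 200 test trials of each network):
% 1st column: epoch
% 2nd column: mean number of 1st-action selections
% 3rd column: standard deviation of 1st-action selections
% 4th column: mean number of 2nd-action selections
% 5th column: standard deviation of 2nd-action selections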
%
% Mauro Ursino, Miriam Schirru 
% Jan. 2022
%==========================================================================
clc
clear
close all

% Ask for the folder that contains the trained-synapse .mat files and put it
% at the front of the MATLAB path so that load() can find them by name.
oldpath = path;
selpath = uigetdir;
if isequal(selpath, 0)
    % uigetdir returns 0 when the dialog is cancelled or closed
    error('No folder selected: choose the folder containing the trained synapses.');
end
path(selpath,oldpath)

% Menu choices (the prompts are in Italian):
%   answer  - learning rule: 1 = post-post, 2 = post-pre, 3 = pre-pre,
%             4 = total, 5 = Oja, 6 = adapt
%   answer1 - training protocol: 1 = basal, 2 = reversal
answer = input('scegli regola (1 = post-post, 2 = post-pre, 3 = pre-pre, 4 = total, 5 = Oja, 6 = adapt )');
answer1 = input ('scegli addestramento (1 = base, 2 = reversal)');
if ~(isscalar(answer) && ismember(answer, 1:6))
    error('Invalid learning rule: enter an integer from 1 to 6.');
end
if ~(isscalar(answer1) && ismember(answer1, [1 2]))
    error('Invalid training protocol: enter 1 (basal) or 2 (reversal).');
end

% %--------------------------------------------------------------------------
% % Initialisation
% %--------------------------------------------------------------------------
Nprove = 10;   % number of trained networks tested (files <prefix>_1 ... <prefix>_Nprove)
Ns = 4;        % number of stimuli
Nc = 4;        % number of cortical channels (candidate actions)

% Training epochs at which the saved synapses are tested. The list depends on
% the training protocol chosen in answer1 (1 = basal, 2 = reversal); the basal
% list stops at epoch 300 (presumably the length of basal training -- TODO
% confirm against the training scripts).
if answer1 == 1
    Epoche = [1 25 50 75 100 150 200 250 300];          % case basal
else
    Epoche = [1 50 75 100 150 200 250 300 350 400];     % case of reversal
end
L_epoche = length(Epoche);

% One row per tested epoch: [epoch, mean/std 1st action, mean/std 2nd action]
Matrice_risultati = zeros(L_epoche,5);

% Validate the menu choices and map them to the pieces of the trained-synapse
% file names: <train_prefixes{answer1}><rule_tags{answer}>_<network index>.
rule_tags = {'post_post', 'post_pre', 'pre_pre', 'total', 'oja', 'post_adapt'};  % indexed by answer
train_prefixes = {'W_tot_', 'Rev_W_tot_'};                                        % indexed by answer1
if ~(isscalar(answer) && ismember(answer, 1:numel(rule_tags)))
    error('Invalid learning rule: answer must be an integer from 1 to %d.', numel(rule_tags));
end
if ~(isscalar(answer1) && ismember(answer1, 1:numel(train_prefixes)))
    error('Invalid training protocol: answer1 must be 1 (basal) or 2 (reversal).');
end

% Test-set size: N_test trials per network, i.e. 'trials' presentations of
% each of the Ns stimuli.
N_test = 200;
if mod(N_test, Ns) ~= 0
    error('N_test (%d) must be a multiple of the number of stimuli Ns (%d).', N_test, Ns);
end
trials = N_test/Ns;

% A cortical channel whose final activity is >= act_thr counts as active.
% The same threshold is used for the validity check and for counting.
act_thr = 0.9;

% Settings passed unchanged to BG_model_Response_Stimuli_Task1
% (presumably STN_ON / T_ON enable the STN and thalamus -- TODO confirm).
STN_ON = 1;
T_ON = 1;
Dop_tonic = 1.0; % tonic dopamine level for the test (an older note gives 1.2 as default -- TODO confirm)

% Stimulus amplitudes: in the reversal protocol the high and small inputs
% are swapped.
if answer1 == 1
    S_high = 1.0;
    S_small = 0.3;
elseif answer1 == 2
    S_high = 0.3;
    S_small = 1.0;
end

% The four stimuli (column k of Stimuli is stimulus k) and, for each one,
% the channel designated as 1st action (Correct_winners) and as 2nd action
% (Small_winners).
S1 = zeros(Nc,1);  S1(1:4) = [S_high;  S_small; 0.1;     0.1];
S2 = zeros(Nc,1);  S2(1:4) = [S_small; S_high;  0.1;     0.1];
S3 = zeros(Nc,1);  S3(1:4) = [0.1;     0.1;     S_high;  S_small];
S4 = zeros(Nc,1);  S4(1:4) = [0.1;     0.1;     S_small; S_high];
Stimuli = [S1 S2 S3 S4];
Correct_winners = [1 2 3 4];
Small_winners   = [2 1 4 3];

% Load each trained network once, before the epoch loop. Only the four
% synapse histories are read, so the .mat files cannot overwrite this
% script's own variables (ii, Epoca, Ns, ...).
syn_vars = {'Wgc_epocs', 'Wgs_epocs', 'Wnc_epocs', 'Wns_epocs'};
Networks = cell(1, Nprove);
for ii = 1:Nprove
    name = [train_prefixes{answer1} rule_tags{answer} '_' num2str(ii)];
    Networks{ii} = load(name, syn_vars{:});
end

for indice_epoca = 1 : L_epoche

    Epoca = Epoche(indice_epoca);

    % selections per channel (rows) and per network (columns)
    Array_n_succ_tot = zeros(Nc,Nprove);
    Array_n_small_succ_tot = zeros(Nc,Nprove);

    for ii = 1:Nprove

        % Trained synapses of network ii at epoch Epoca. The saved histories
        % are indexed by 4*Epoca (presumably one snapshot per stimulus
        % presentation, 4 per epoch -- TODO confirm with the training script).
        Wgc = squeeze(Networks{ii}.Wgc_epocs(:,:,4*Epoca));
        Wgs = squeeze(Networks{ii}.Wgs_epocs(:,:,4*Epoca));
        Wnc = squeeze(Networks{ii}.Wnc_epocs(:,:,4*Epoca));
        Wns = squeeze(Networks{ii}.Wns_epocs(:,:,4*Epoca));

        % Re-seeding makes the noise, and the stimulus order drawn afterwards
        % with randperm, reproducible for every network and epoch.
        rng(21)
        noise1 = zeros(Nc,Ns*trials);
        noise1(1:Ns,:) = 0.15*randn(Ns,Ns*trials);   % noise to the cortex
        rng(31)
        noise2 = zeros(Nc,Ns*trials);
        noise2(1:Ns,:) = 0.2*randn(Ns,Ns*trials);    % noise to the stimuli

        % per-block buffers (one column per stimulus presentation)
        S_vett = zeros(Nc,Ns);
        exitC = zeros(Nc,Ns);
        Winner = zeros(1,Ns);
        Small_winner = zeros(1,Ns);
        n_err_act = zeros(1,Ns);

        % whole-test buffers (one column per trial)
        S_vett_tot = zeros(Nc,Ns*trials);
        exitC_tot = zeros(Nc,Ns*trials);
        Winner_tot = zeros(1,Ns*trials);
        Small_winner_tot = zeros(1,Ns*trials);
        n_err_act_tot = zeros(1,Ns*trials);

        for i = 1:trials

            % each stimulus is presented once per block, in random order
            action = randperm(Ns);

            for j = 1:Ns
                col = j + Ns*(i-1);
                noiseC = noise1(:,col);
                k = action(j);

                % noisy stimulus, clipped to [0, 1]
                S = Stimuli(:,k) + noise2(:,col);
                S = min(max(S,0),1);

                % Call to the function which simulates the basal ganglia response
                [Uc,C,Ugo,Go,IGo_DA_Ach,Unogo,NoGo,INoGo_DA_Ach,Ugpe,Gpe,Ugpi,Gpi,Ut,T,Ustn,STN,E,k_tap_vett,Uchi,ChI,t] = BG_model_Response_Stimuli_Task1(S,Wgc,Wgs,Wnc,Wns,STN_ON,T_ON,Dop_tonic,noiseC);

                S_vett(:,j) = S;
                Winner(j) = Correct_winners(k);
                Small_winner(j) = Small_winners(k);

                % final cortical activity of each channel
                exitC(:,j) = C(1:Nc,end);
                % A response is valid only if exactly one channel is active.
                % No action or multiple actions are flagged and zeroed out so
                % they are never counted as a selection.
                active = exitC(:,j) >= act_thr;
                if sum(active) == 1
                    n_err_act(j) = 0;
                else
                    n_err_act(j) = 1;
                    exitC(:,j) = 0;
                end
            end

            cols = Ns*(i-1)+1 : Ns*i;
            S_vett_tot(:,cols) = S_vett;
            exitC_tot(:,cols) = exitC;
            Winner_tot(1,cols) = Winner;
            Small_winner_tot(1,cols) = Small_winner;
            n_err_act_tot(1,cols) = n_err_act;   % no action or multiple actions

        end

        % For each channel ch, count the trials in which ch was selected while
        % it was the designated 1st action (n_succ_tot) or 2nd action
        % (n_small_succ_tot) for the presented stimulus.
        selected = exitC_tot >= act_thr;
        n_succ_tot = zeros(Ns,1);
        n_small_succ_tot = zeros(Ns,1);
        for ch = 1:Ns
            n_succ_tot(ch) = sum(Winner_tot == ch & selected(ch,:));
            n_small_succ_tot(ch) = sum(Small_winner_tot == ch & selected(ch,:));
        end

        n_succ = sum(n_succ_tot);
        n_small_succ = sum(n_small_succ_tot);
        Array_n_succ_tot(:,ii) = n_succ_tot;
        Array_n_small_succ_tot(:,ii) = n_small_succ_tot;
    end

    % per-network totals (summed over channels)
    Array_n_succ = sum(Array_n_succ_tot,1);
    Array_n_small_succ = sum(Array_n_small_succ_tot,1);
    % trials counted neither as 1st nor as 2nd action
    Array_error = N_test - Array_n_succ - Array_n_small_succ;

    % statistics across the tested networks
    mu_succ = mean(Array_n_succ);
    std_succ = std(Array_n_succ);
    mu_small_succ = mean(Array_n_small_succ);
    std_small_succ = std(Array_n_small_succ);
    Matrice_risultati(indice_epoca,:) = [Epoca mu_succ std_succ mu_small_succ std_small_succ];
end
% disp(Matrice_risultati)
% Wrap the per-epoch results in a labelled table and echo it to the
% command window (one row per tested epoch).
TEST = array2table(Matrice_risultati, 'VariableNames', ...
    {'Epoch','Mean 1st Action','Std 1st Action','Mean 2nd Action','Std 2nd Action'});
TEST