% This script calculates robustness of different properties
% as a function of the correlation radius (Fig. 1g, 2b,d, 3b sections)
% or of the GC threshold (Fig. 4c sections)
% This section generates Fig. 1g (bottom panel)

N_syn = 1:20; % Synaptic connectivity / Number inputs
f_mf = linspace(.05,.95,19); % fraction active MFs

% One robustness value is appended per correlation radius
robustness_speed = [];
for sigma = 0:5:30

    % Brings err_rms_grc and err_rms_mf into the workspace
    load(strcat('results_bp/grc_toy_r',num2str(sigma),'.mat'))

    err_grc = err_rms_grc;
    err_mf = err_rms_mf;
    thresh = 0.2; % error level at which learning counts as complete
    T = size(err_grc,3);

    % First epoch at which the error is at or below thresh,
    % for GC-based (grc) and MF-based (mf) learning; NaN if never reached
    grc = nan(length(N_syn),length(f_mf));
    mf = nan(length(N_syn),length(f_mf));
    for j = 1:length(N_syn)
        for k = 1:length(f_mf)
            first_grc = find(reshape(err_grc(j,k,:),1,T) <= thresh, 1);
            if ~isempty(first_grc)
                grc(j,k) = first_grc;
            end

            first_mf = find(reshape(err_mf(j,k,:),1,T) <= thresh, 1);
            if ~isempty(first_mf)
                mf(j,k) = first_mf;
            end
        end
    end

    % Relative speed (1/epochs) of GC vs. MF learning, minus 1:
    % positive entries are cases where GC learning finished earlier
    speedup = (1./grc)./(1./mf) - 1;
    speedup = speedup(isfinite(speedup)); % drop cases where either never reached thresh

    % Fraction of the remaining cases with positive speed-up
    robustness_speed(end+1) = sum(speedup(:) > 0)/length(speedup(:));

end

figure
plot(0:5:30,robustness_speed,'-ok','LineWidth',3,'MarkerFaceColor','k')
axis([-5,35,.2,.8])
xlabel('Correlation radius (\mum)')
ylabel('Robustness of learning')
set(gca,'FontSize',20)

%% This section generates Fig. 2b,d and Fig. 3b (bottom panel)
% For every correlation radius sigma, MF patterns are drawn N_repeats times
% for each (N_syn, f_mf) pair, passed through the thresholded GC layer, and
% three GC/MF ratios are stored: population sparseness, total variance and
% population correlation. The robustness of each property is the fraction
% of (N_syn, f_mf) cells (with a finite repeat-averaged ratio) in which the
% ratio is >1 (sparseness, total variance) or <1 (population correlation).

N_syn = 1:20; % Synaptic connectivity / Number inputs
f_mf = linspace(.05,.95,19); % fraction active MFs

N_mf = 187; N_grc = 487;
N_patt = 640;

theta_initial = 3; NADT = 0;

N_repeats = 25;

sigmas = 0:5:30; % correlation radii, um

robustness_sp = nan(1,length(sigmas));
robustness_total_var = nan(1,length(sigmas));
robustness_pop_corr = nan(1,length(sigmas));

for i_sigma = 1:length(sigmas)
    sigma = sigmas(i_sigma) % no semicolon: echoes progress

    % Rs and gs depend only on sigma, so the file is read once per sigma
    % instead of once per (k1,k2,k3) iteration. Only Rs and gs are
    % requested so nothing else in the file can overwrite script variables.
    if sigma > 0
        load(strcat('../input_statistics/mf_patterns_r',num2str(sigma),'.mat'),'Rs','gs')
    end

    % GC/MF ratios, indexed by position in N_syn, position in f_mf, repeat
    total_var_norm = nan(length(N_syn),length(f_mf),N_repeats);
    pop_corr_norm =  nan(length(N_syn),length(f_mf),N_repeats);
    sp_norm = nan(length(N_syn),length(f_mf),N_repeats);

    for i_syn = 1:length(N_syn)
        k1 = N_syn(i_syn); % number of inputs; also selects the connectivity file
        load(strcat('../network_structures/GCLconnectivity_',int2str(k1),'.mat'))
        conn_mat = double(conn_mat);

        for k2 = 1:length(f_mf)

            for k3 = 1:N_repeats

                % Input MF patterns
                if sigma == 0 % Independent case
                    x_mf = zeros(N_mf,N_patt);
                    for i = 1:N_patt
                        mf_on = randsample(N_mf,round(f_mf(k2)*N_mf));
                        x_mf(mf_on,i) = 1.;
                    end
                elseif sigma >0 % Correlated case -- generated following Macke et al. 2009
                    R = Rs(:,:,k2); g = gs(k2);
                    z = R' * randn(N_mf,N_patt); % correlated Gaussian samples
                    x_mf = double(z > -g(1)); % binarise at -g
                end

                theta = theta_initial + NADT*f_mf(k2); % threshold
                in = 4/k1*conn_mat'*x_mf; % input
                x_grc = max(in-theta,0); % Output GC activity

                x_grc = double(x_grc);

                % Eigenvalues of covariance matrix of MF patterns
                C_mf = cov(x_mf');
                [~,L_mf] = eig(C_mf); L_mf = diag(L_mf);
                L_mf = real(sqrt(L_mf)); % negative eigenvalues contribute 0

                % Total variance of MFs
                total_var_mf = sum(L_mf.^2);

                % Population correlation of MFs
                pop_corr_mf = (max(L_mf)/sum(L_mf) - 1./N_mf)/(1-1/N_mf);

                % Avg. population sparseness (all-zero patterns give NaN
                % and are skipped by nanmean)
                sptemp = zeros(1,N_patt);
                for p = 1:N_patt
                    sptemp(p) =(N_mf-sum(x_mf(:,p))^2/sum(x_mf(:,p).^2))/(N_mf-1);
                end
                sp_mf = nanmean(sptemp);

                if max(x_grc(:)) > 0

                    % Eigenvalues of covariance matrix of GC patterns
                    C_grc = cov(x_grc');
                    [~,L_grc] = eig(C_grc);  L_grc = diag(L_grc);
                    L_grc =real(sqrt(L_grc));

                    % Total variance of GCs
                    total_var_grc = sum(L_grc.^2); % total variation

                    % Population correlation of GCs
                    pop_corr_grc = (max(L_grc)/sum(L_grc) - 1./N_grc)/(1-1/N_grc);

                    % Avg. population sparseness
                    sptemp = zeros(1,N_patt);
                    for p = 1:N_patt
                        sptemp(p) =(N_grc-sum(x_grc(:,p))^2/sum(x_grc(:,p).^2))/(N_grc-1);
                    end
                    sp_grc = nanmean(sptemp);
                else
                    % No GC is active in any pattern: statistics undefined
                    sp_grc = NaN; pop_corr_grc = NaN; total_var_grc = NaN;
                end

                sp_norm(i_syn,k2,k3) = sp_grc./sp_mf;
                total_var_norm(i_syn,k2,k3) = total_var_grc./total_var_mf;
                pop_corr_norm(i_syn,k2,k3) = pop_corr_grc./pop_corr_mf;

            end
        end
    end

    % Fraction of cells where GCs are sparser than MFs (ratio > 1)
    temp = mean(sp_norm,3)-1;
    temp = temp(isfinite(temp)); % only count cells whose repeat-average is finite

    robustness_sp(i_sigma) = sum(temp(:)>0)/length(temp(:));

    % Fraction of cells where GC total variance exceeds MF total variance
    temp =  mean(total_var_norm,3)-1;
    temp = temp(isfinite(temp)); % only count cells whose repeat-average is finite

    robustness_total_var(i_sigma) = sum(temp(:)>0)/length(temp(:));

    % Fraction of cells where GC population correlation is below the MF one
    temp = mean(pop_corr_norm,3)-1;
    temp = temp(isfinite(temp)); % only count cells whose repeat-average is finite

    robustness_pop_corr(i_sigma) = sum(temp(:)<0)/length(temp(:));
end

figure, plot(sigmas,robustness_sp,'-ok','LineWidth',3,'MarkerFaceColor','k')
axis([-5,35,0.6,1])
xlabel('Correlation radius (\mum)'), ylabel('Robustness of sparsening')
set(gca,'FontSize',20)

figure, plot(sigmas,robustness_total_var,'-ok','LineWidth',3,'MarkerFaceColor','k')
axis([-5,35,0,1])
xlabel('Correlation radius (\mum)'), ylabel('Robustness of expansion')
set(gca,'FontSize',20)

figure, plot(sigmas,robustness_pop_corr,'-ok','LineWidth',3,'MarkerFaceColor','k')
axis([-5,35,0,.3])
xlabel('Correlation radius (\mum)'), ylabel('Robustness of decorrelation')
set(gca,'FontSize',20)

%% This section generates Fig. 4c (bottom panel)
% Robustness of learning speed as a function of the GC threshold theta,
% at a fixed correlation radius sigma.

N_syn = 1:20; % Synaptic connectivity / Number inputs
f_mf = linspace(.05,.95,19); % fraction active MFs

% Modify for different input correlations
sigma = 20; % correlation radius, um

thetas = [0:.5:2,2.25:.25:3.75]; % GC thresholds swept on the x-axis

robustness_speed = nan(1,length(thetas));
for i_theta = 1:length(thetas)
    theta = thetas(i_theta) % no semicolon: echoes progress

    % Results for theta == 3 are stored in results_bp/ without a theta
    % suffix; all other thresholds are read from results_bp_th/
    if theta == 3
        load(strcat('results_bp/grc_toy_r',num2str(sigma),'.mat'))
    else
        load(strcat('results_bp_th/grc_toy_r',num2str(sigma),'_',num2str(theta,'%.2f'),'.mat'))
    end

    err_grc = err_rms_grc; err_mf = err_rms_mf;
    thresh = 0.2; % threshold for determining learning speed
    T = size(err_grc,3);

    % Get number of epochs until learning is complete
    % both for GC- and for MF-based learning (NaN if thresh is never reached)
    grc = nan(length(N_syn),length(f_mf));
    mf = nan(length(N_syn),length(f_mf));
    for j = 1:length(N_syn)
        for k = 1:length(f_mf)
            e_grc = reshape(err_grc(j,k,:),1,T);
            e_mf = reshape(err_mf(j,k,:),1,T);

            temp = find(e_grc<=thresh);
            if numel(temp) > 0
                grc(j,k) = temp(1);
            end

            temp = find(e_mf<=thresh);
            if numel(temp) > 0
                mf(j,k) = temp(1);
            end
        end
    end

    % Relative speed of GC vs. MF learning, minus 1
    temp = (1./grc)./(1./mf)-1;

    % NOTE: unlike the Fig. 1g section, the denominator here is the number
    % of ALL (N_syn, f_mf) cases; non-finite entries compare as false and
    % so count as "not faster". The comparison yields a logical vector
    % (never NaN), so plain sum gives the same result as nansum.
    robustness_speed(i_theta) = sum(temp(:)>0)/numel(temp(:));

end

figure, plot(thetas,robustness_speed,'-ok','LineWidth',3,'MarkerFaceColor','k')
hold on, plot([3,3],[0,1],':k','LineWidth',2), axis([0,4,0,1]) % dotted line marks theta = 3
% x-axis is the swept threshold theta, not the correlation radius
xlabel('GC threshold'), ylabel('Robustness of learning')
set(gca,'FontSize',20)

%% This section generates Fig. 4c (top panel)
% Robustness of the GC/MF total-variance and population-correlation ratios
% as a function of the GC threshold theta, at a fixed correlation radius.

N_syn = 1:20; % Synaptic connectivity / Number inputs
f_mf = linspace(.05,.95,19); % fraction active MFs

N_mf = 187; N_grc = 487;
N_patt = 640;

N_repeats = 25;

% Modify for different input correlations
sigma = 20; % correlation radius, um

thetas = [0:.5:2,2.25:.25:3.75]; % GC thresholds swept on the x-axis

% Rs and gs depend only on sigma, which is fixed for this section, so the
% file is read once instead of on every (theta,k1,k2,k3) iteration. Only Rs
% and gs are requested so nothing else in the file can overwrite script
% variables.
if sigma > 0
    load(strcat('../input_statistics/mf_patterns_r',num2str(sigma),'.mat'),'Rs','gs')
end

robustness_total_var = nan(1,length(thetas));
robustness_pop_corr = nan(1,length(thetas));

for i_theta = 1:length(thetas)
    theta = thetas(i_theta) % no semicolon: echoes progress

    % GC/MF ratios, indexed by position in N_syn, position in f_mf, repeat
    total_var_norm = nan(length(N_syn),length(f_mf),N_repeats);
    pop_corr_norm =  nan(length(N_syn),length(f_mf),N_repeats);

    for i_syn = 1:length(N_syn)
        k1 = N_syn(i_syn); % number of inputs; also selects the connectivity file
        load(strcat('../network_structures/GCLconnectivity_',int2str(k1),'.mat'))
        conn_mat = double(conn_mat);

        for k2 = 1:length(f_mf)

            for k3 = 1:N_repeats

                % Input MF patterns
                if sigma == 0 % Independent case
                    x_mf = zeros(N_mf,N_patt);
                    for i = 1:N_patt
                        mf_on = randsample(N_mf,round(f_mf(k2)*N_mf));
                        x_mf(mf_on,i) = 1.;
                    end
                elseif sigma >0 % Correlated case -- generated following Macke et al. 2009
                    R = Rs(:,:,k2); g = gs(k2);
                    z = R' * randn(N_mf,N_patt); % correlated Gaussian samples
                    x_mf = double(z > -g(1)); % binarise at -g
                end

                in = 4/k1*conn_mat'*x_mf; % input
                x_grc = max(in-theta,0); % Output GC activity

                x_grc = double(x_grc);

                % Eigenvalues of covariance matrix of MF patterns
                C_mf = cov(x_mf');
                [~,L_mf] = eig(C_mf); L_mf = diag(L_mf);
                L_mf = real(sqrt(L_mf)); % negative eigenvalues contribute 0

                % Total variance of MFs
                total_var_mf = sum(L_mf.^2);

                % Population correlation of MFs
                pop_corr_mf = (max(L_mf)/sum(L_mf) - 1./N_mf)/(1-1/N_mf);

                if max(x_grc(:)) > 0

                    % Eigenvalues of covariance matrix of GC patterns
                    C_grc = cov(x_grc');
                    [~,L_grc] = eig(C_grc);  L_grc = diag(L_grc);
                    L_grc =real(sqrt(L_grc));

                    % Total variance of GCs
                    total_var_grc = sum(L_grc.^2); % total variation

                    % Population correlation of GCs
                    pop_corr_grc = (max(L_grc)/sum(L_grc) - 1./N_grc)/(1-1/N_grc);

                else
                    % No GC is active in any pattern: statistics undefined
                    pop_corr_grc = NaN; total_var_grc = NaN;
                end

                total_var_norm(i_syn,k2,k3) = total_var_grc./total_var_mf;
                pop_corr_norm(i_syn,k2,k3) = pop_corr_grc./pop_corr_mf;

            end
        end
    end

    % Fraction of cells where GC total variance exceeds MF total variance
    temp =  mean(total_var_norm,3)-1;
    temp = temp(isfinite(temp)); % only count cells whose repeat-average is finite

    robustness_total_var(i_theta) = sum(temp(:)>0)/length(temp(:));

    % Fraction of cells where GC population correlation is below the MF one
    temp = mean(pop_corr_norm,3)-1;
    temp = temp(isfinite(temp)); % only count cells whose repeat-average is finite

    robustness_pop_corr(i_theta) = sum(temp(:)<0)/length(temp(:));
end

green = [0,.3,0]; purple = [.4,0,.6];
figure, plot(thetas,robustness_total_var,'-o','Color',green,'LineWidth',3,'MarkerFaceColor',green)
hold on, plot(thetas,robustness_pop_corr,'-o','Color',purple,'LineWidth',3,'MarkerFaceColor',purple)
plot([3,3],[0,1],':k','LineWidth',2), axis([0,4,0,1]) % dotted line marks theta = 3
% x-axis is the swept threshold theta, not the correlation radius
xlabel('GC threshold'), ylabel('Robustness')
set(gca,'FontSize',20)