% Plot results from backprop learning
% i.e., Fig. 1f

% --- Parameter grid of the simulations being analysed ---
N_syn = 1:20;                   % synaptic connectivity / number of inputs
f_mf = linspace(.05,.95,19);    % fraction of active MFs

% Correlation radius (um); change this value to analyse a different
% input correlation
sigma = 0;

% Load the saved backprop results for the chosen correlation radius
results_file = ['results_bp/grc_toy_r', num2str(sigma), '.mat'];
load(results_file)

% RMS error arrays for GC- and MF-based learning, as provided by the
% loaded file
err_grc = err_rms_grc;
err_mf = err_rms_mf;

thresh = 0.2;           % threshold for determining learning speed
T = size(err_grc,3);    % length of the third dimension of the error array

% Number of epochs until the RMS error first drops to or below thresh,
% computed for every (N_syn, f_mf) combination and for both GC- and
% MF-based learning. Entries stay NaN when the threshold is never reached.
n_rows = length(N_syn);
n_cols = length(f_mf);
grc = nan(n_rows,n_cols);
mf = nan(n_rows,n_cols);
for row = 1:n_rows
    for col = 1:n_cols
        % First index along the third dimension with error <= thresh (GC)
        first_grc = find(reshape(err_grc(row,col,:),1,T) <= thresh, 1);
        if ~isempty(first_grc)
            grc(row,col) = first_grc;
        end

        % Same for MF-based learning
        first_mf = find(reshape(err_mf(row,col,:),1,T) <= thresh, 1);
        if ~isempty(first_mf)
            mf(row,col) = first_mf;
        end
    end
end

% Learning speed is the reciprocal of the number of epochs to threshold
speed_grc = 1./grc;
speed_mf = 1./mf;
% GC speed expressed relative to MF speed
speed_norm = speed_grc./speed_mf;

% Find points where MF and GC learning speeds are equal, i.e. where the
% difference in epochs-to-threshold (grc - mf) is exactly zero or changes
% sign between neighbouring f_mf grid points.
% For every such point, x receives the row index j and f_mf_zeros the
% corresponding f_mf value (linearly interpolated for a sign change).
f_mf_zeros = []; x = [];
for j = 1:length(N_syn)
    temp = grc(j,:) - mf(j,:);
    n_pts = length(temp);
    for i = 1:n_pts
        if temp(i) == 0
            % Exact tie on a grid point. The loop runs up to n_pts so that
            % a tie on the last grid point is recorded as well.
            f_mf_zeros = [f_mf_zeros, f_mf(i)];
            x = [x,j];
        elseif i < n_pts && sign(temp(i))*sign(temp(i+1)) < 0
            % Strict sign change between i and i+1: interpolate the zero.
            % NaN entries make the comparison false and are skipped.
            f_mf_zeros = [f_mf_zeros, interp1(temp([i,i+1]),f_mf([i,i+1]),0)];
            x = [x,j];
        end
    end
end

% Fit a two-term exponential through the equal-speed points
% (f_mf_zeros as a function of x)
f = fit(x',f_mf_zeros','exp2','StartPoint',[.5,.01,-1,-.5]);

% Fitted curve a*exp(b*n) + c*exp(d*n), built from the fit coefficients
boundary = @(n) f.a*exp(f.b*n) + f.c*exp(f.d*n);
y2 = boundary(x);

% Do not plot the region in which MF speed is faster than GC speed:
% for each row, blank the columns whose f_mf lies below the fitted curve
for d = N_syn
    below_boundary = f_mf < boundary(d);
    speed_norm(d,below_boundary) = NaN;
end

% Raw MF learning speed over the (fraction active MFs, number inputs) grid
figure;
imagesc(f_mf,1:20,speed_mf);
set(gca,'YDir','normal','FontSize',20)
title('Raw MF speed');
xlabel('Fraction active MFs');
ylabel('Number inputs')

% Raw GC learning speed over the same grid
figure;
imagesc(f_mf,1:20,speed_grc);
set(gca,'YDir','normal','FontSize',20)
title('Raw GC speed');
xlabel('Fraction active MFs');
ylabel('Number inputs')

% Plot Fig. 1f
figure;
img = imagesc(f_mf,1:20,speed_norm);
% NaN cells are given zero alpha so they are not drawn
set(img,'AlphaData',~isnan(speed_norm))
set(gca,'YDir','normal','FontSize',20)
title('Normalized learning speed');
xlabel('Fraction active MFs');
ylabel('Number inputs')
% Overlay the fitted curve evaluated at the equal-speed rows
hold on
plot(y2,x,'Color',[.3,.8,1],'LineWidth',5)


%% Here plot Fig. 1d

% Parameters to plot; both are indices into the first two dimensions of
% err_grc / err_mf and are used by the plot calls below

Nsyn = 4; % number inputs
f_mf_ix = 10; % corresponds to f_mf = 0.5

figure, hold on
% Dotted grey line marks the learning threshold
plot([0,T],[thresh,thresh],':','Color',[.5,.5,.5],'LineWidth',3)
% Error curves for the selected parameters: GC in red, MF in blue
plot(reshape(err_grc(Nsyn,f_mf_ix,:),1,T),'r','LineWidth',3)
plot(reshape(err_mf(Nsyn,f_mf_ix,:),1,T),'b','LineWidth',3)
axis([0,2000,0,.4])
xlabel('Training epochs'); ylabel('RMS error')
set(gca,'FontSize',20)


%% Plot speed as fn of correlation radius
% i.e., for Fig 1e,g (top panel)

N_syn = 1:20;                   % synaptic connectivity / number of inputs
f_mf = linspace(.05,.95,19);    % fraction of active MFs

% One value is appended per correlation radius in the loop below
speed_mf = [];
speed_grc_sparse = []; speed_norm_sparse = [];
speed_grc_dense = []; speed_norm_dense = [];

for sigma = 0:5:30
    sigma   % left unsuppressed on purpose: echoes the current radius

    load(['results_bp/grc_toy_r', num2str(sigma), '.mat'])

    err_grc = err_rms_grc;
    err_mf = err_rms_mf;
    thresh = 0.2;           % threshold for determining learning speed
    T = size(err_grc,3);

    % Rows 4 and 16 of the error arrays feed the "sparse" and "dense"
    % summaries respectively
    for j = [4,16]
        % Epochs until the error first reaches thresh, per f_mf value,
        % for GC- and MF-based learning (NaN if never reached)
        grc = nan(length(f_mf),1);
        mf = nan(length(f_mf),1);
        for k = 1:length(f_mf)
            hit = find(reshape(err_grc(j,k,:),1,T) <= thresh, 1);
            if ~isempty(hit)
                grc(k) = hit;
            end

            hit = find(reshape(err_mf(j,k,:),1,T) <= thresh, 1);
            if ~isempty(hit)
                mf(k) = hit;
            end
        end

        % Learning speed is 1 / number of epochs to reach threshold;
        % summarise across f_mf with the NaN-ignoring median
        inv_grc = 1./grc;
        inv_mf = 1./mf;
        switch j
            case 4
                speed_grc_sparse(end+1) = nanmedian(inv_grc);
                speed_norm_sparse(end+1) = nanmedian(inv_grc./inv_mf);
            case 16
                speed_grc_dense(end+1) = nanmedian(inv_grc);
                speed_norm_dense(end+1) = nanmedian(inv_grc./inv_mf);
        end
    end

    % NOTE(review): at this point mf holds the values from the last j
    % iteration (j = 16) only -- confirm that this is the intended MF speed.
    speed_mf(end+1) = nanmedian(1./mf);
end

% Correlation radii on the x-axis of both figures
radii = 0:5:30;

% Normalized speed vs. correlation radius: sparse (solid) and dense
% (dashed), with a reference line at 1
figure
plot(radii,speed_norm_sparse,'-ok','LineWidth',3,'MarkerFaceColor','k')
hold on
plot(radii,speed_norm_dense,'--ok','LineWidth',3,'MarkerFaceColor','k')
plot([-5,35],[1,1],'k')
axis([-5,35,0,5])
xlabel('Correlation radius (\mum)')
ylabel('Norm. speed')
set(gca,'FontSize',20)


% Raw speeds vs. correlation radius: GC sparse (solid blue), GC dense
% (dashed blue) and MF (red)
figure
plot(radii,speed_grc_sparse,'-ob','LineWidth',3,'MarkerFaceColor','b')
hold on
plot(radii,speed_grc_dense,'--ob','LineWidth',3,'MarkerFaceColor','b')
plot(radii,speed_mf,'-or','LineWidth',3,'MarkerFaceColor','r')
xlabel('Correlation radius (\mum)')
ylabel('Speed')
set(gca,'FontSize',20)