% This program models LGN-V1 pathways
% Feedforward and feedback connections
% Separate excitatory (positive) and inhibitory (negative) connections
% Author: Yanbo Lian
% Date: 26/02/2019
% Citation: Lian Y, Grayden DB, Kameneva T, Meffin H and Burkitt AN (2019) 
%           Toward a Biologically Plausible Model of LGN-V1 Pathways Based
%           on Efficient Coding.
%           Front. Neural Circuits 13:13. doi: 10.3389/fncir.2019.00013

% Start from a clean state: clear the command window, close all figures and
% wipe the workspace (so every variable used below is defined in this script
% or comes from the .mat file loaded next).
clc; close all; clear

%% Load image
% The .mat file must provide IMAGES_WHITENED, which is read directly below.
load('IMAGES_SparseCoding.mat')

numImages = size(IMAGES_WHITENED,3); % images are stacked along the 3rd dimension
imageSize = size(IMAGES_WHITENED,1); % 1st dimension; also used as the column extent when sampling patches, i.e. images are assumed square
imgVar = 0.2; % variance of the white-noise patches generated during pre-training
BUFF = 4; % the margin between the boundary of the image and the selected patch

histFlag = 1; % passed to Compute_S_U_LGN_V1_UpDown during training; when 1 the returned response history is plotted in figure 4
displayEvery = 200; % refresh the plots every 'displayEvery' iterations
resizeFactor = 3; % passed to DisplayA: higher resolution when displaying images

%% Define hyper parameters
numPretrain = 1e4; % number of pre-training iterations (one white-noise minibatch each)
numEpoches = 3e4; % number of training iterations (one natural-image minibatch each)
batchSize = 100; % number of natural image patches in a minibatch
batchSizePretrain = 100; % number of white-noise patches in a minibatch

normalizationMethod = 'L2 norm'; % second argument of NormalizeA
% Third argument of NormalizeA: l1 is used for A_Up_Pos and A_Down_Neg,
% l2 for A_Up_Neg and A_Down_Pos.
% NOTE(review): presumably the target norm of each set of weights - confirm in NormalizeA.
l1 = 1;
l2 = 1;

lambda = 0.6; % control sparseness; threshold of the F-I curve
aEta = 0.5; % initial learning rate of the connections; lowered to 0.2 and 0.1 after 1e4 and 2e4 training iterations
tau = 12; % ms
dt = 3; % ms
uEta = dt/tau; % updating rate of membrane potentials U
nU = 30; % number of iterations of calculating membrane potentials U
threshType = 'non-negative soft'; % type of thresholding function that computes firing rates of simple cells from membrane potentials

%% Definitions of symbols
sz = 16; L = sz^2; % side length of the square image patch; L ON units and L OFF units
OC = 256/L; % overcompleteness
M1 = OC *L; % number of simple cells (as written this always evaluates to 256, whatever sz is)

% feedforward (up) connections between 2L LGN cells and M1 simple cells
% Each matrix is 2L-by-M1. Positive and negative weights are kept in separate
% matrices, drawn from an exponential distribution and normalized by NormalizeA.
aInitialMean=0.5; % for exponential distribution: var = mean ^ 2;
initial='exponential'; % not read anywhere else in this script
A_Up_Pos = NormalizeA( exprnd(aInitialMean,[2*L M1]), normalizationMethod, l1 ); % positive connections
A_Up_Neg = NormalizeA( -exprnd(aInitialMean,[2*L M1]), normalizationMethod, l2 ); % negative connections
A_Up = A_Up_Pos + A_Up_Neg; % overall feedforward connections

% feedback (down) connections between 2L LGN cells and M1 simple cells
% Note the swapped norms relative to the feedforward weights: Down_Pos uses l2
% (like Up_Neg) and Down_Neg uses l1 (like Up_Pos).
A_Down_Pos = NormalizeA( exprnd(aInitialMean,[2*L M1]), normalizationMethod, l2 ); % positive connections
A_Down_Neg = NormalizeA( -exprnd(aInitialMean,[2*L M1]), normalizationMethod, l1 ); % negative connections
A_Down = A_Down_Pos + A_Down_Neg; % overall feedback connections

dA_Bound = 0.1; % maximal change of synaptic efficacy per update (element-wise clip applied to dA)

X_Data = zeros( L, batchSize ); % input image patches, one patch per column
X = zeros( 2*L, batchSize ); % input with ON (rows 1..L) and OFF (rows L+1..2L) channels

% State of the network; initialized randomly and carried over from one
% minibatch to the next by the training loop.
U_L = randn( 2*L, batchSize ); % membrane potential of ON-OFF LGN cells
S_L = rand( 2*L, batchSize ); % firing rate of ON-OFF LGN cells

U1 = randn( M1, batchSize ); % membrane potential of simple cells
S1 = rand( M1, batchSize ); % firing rate of simple cells

s_b = 2; % background firing rate that gives an offset of the reconstruction error
s1Max = 100; % maximum firing rate of simple cells
sL_Max = 100; % maximum firing rate of LGN cells

% Convergence traces. Entry 1 is the initial value, entries 2..1+numPretrain
% are filled during pre-training and the remaining numEpoches entries during
% training. Every measure is a sum of squares of (up + down), so it is zero
% when the feedback weights are the exact negative of the feedforward weights.
errorA_UpDown = ones( 1, 1+numEpoches+numPretrain ); % sum of squares of A_Up + A_Down during learning
errorA_UpPosDownPos = ones( 1, 1+numEpoches+numPretrain ); % despite the name: sum of squares of A_Up_Pos + A_Down_Neg during learning
errorA_UpNegDownNeg = ones( 1, 1+numEpoches+numPretrain ); % despite the name: sum of squares of A_Up_Neg + A_Down_Pos during learning
errorA_UpDown(1) = sum ( ( A_Up(:) + A_Down(:) ).^2 ); % initial difference
errorA_UpPosDownPos(1) = sum ( ( A_Up_Pos(:) + A_Down_Neg(:) ).^2 );
errorA_UpNegDownNeg(1) = sum ( ( A_Up_Neg(:) + A_Down_Pos(:) ).^2 );

%% Display A and S
% Show the randomly initialized weights and responses before any learning.

% Display the connections from ON and OFF LGN cells to simple cells
% (top row: 'ON', bottom row: 'OFF'; columns: positive part, negative part, sum)
figure(1);
subplot(231); DisplayA( 'ON', A_Up_Pos, resizeFactor ); title('A^{+}_{ON,Up}');
subplot(232); DisplayA( 'ON', A_Up_Neg, resizeFactor ); title('A^{-}_{ON,Up}');
subplot(233); DisplayA( 'ON', A_Up, resizeFactor ); title('A_{ON,Up}');
subplot(234); DisplayA( 'OFF', A_Up_Pos, resizeFactor ); title('A^{+}_{OFF,Up}');
subplot(235); DisplayA( 'OFF', A_Up_Neg, resizeFactor ); title('A^{-}_{OFF,Up}');
subplot(236); DisplayA( 'OFF', A_Up, resizeFactor ); title('A_{OFF,Up}');
colormap(Green2Magenta(64));

% Display the overall receptive fields of simple cells: Aon - Aoff
figure(2);
DisplayA( 'ONOFF', A_Up, resizeFactor); title('RFs: A_{ON,Up}-A_{OFF,Up}');
colormap(scm(256));

% Display the firing rates of LGN cells and simple cells
% (one stem series per patch, i.e. per column of S_L / S1)
figure(3);
subplot(211); stem(S_L); title(['S_L: LGN cell responses of ' num2str(batchSize) ' patches']);
xlabel('LGN cells'); ylabel('firing rates');
subplot(212); stem(S1); title(['S1: simple cell responses of ' num2str(batchSize) ' patches']);
xlabel('simple cells'); ylabel('firing rates');

drawnow; % make sure the last figure is rendered before the long loops start

%% Pre-train the model using white noise to make sure A_Up converges to A_Down
% Convergence is measured as sum((A_Up + A_Down).^2): the feedback weights
% receive the opposite update (-dA) to the feedforward weights (+dA), so the
% measure is zero when A_Down equals -A_Up.
X_DataPretrain = zeros( L, batchSizePretrain ); % input image patches
X_Pretrain = zeros( 2*L, batchSizePretrain ); % input with ON and OFF channels

U_L_Pretrain = randn( 2*L, batchSizePretrain ); % membrane potential of ON-OFF LGN cells
S_L_Pretrain = rand( 2*L, batchSizePretrain ); % firing rate of ON-OFF LGN cells

U1Pretrain = randn( M1, batchSizePretrain ); % membrane potential of simple cells
S1Pretrain = rand( M1, batchSizePretrain ); % firing rate of simple cells

for iPretrain = 1 : numPretrain
    
    % Generate white noise input with the variance of 'imgVar'
    X_DataPretrain = sqrt(imgVar) * randn(L, batchSizePretrain);
    
    % ON and OFF LGN input: positive part in rows 1..L, magnitude of the
    % negative part in rows L+1..2L (both channels are non-negative)
    X_Pretrain( 1:L, : ) = max( X_DataPretrain, 0 );
    X_Pretrain( L+1:2*L, : ) = -min( X_DataPretrain, 0 );
    
    % Compute S and U for LGN and simple cells using previous values
    [ S1Pretrain, U1Pretrain, S_L_Pretrain, U_L_Pretrain ] = ...
        Compute_S_U_LGN_V1_UpDown( S1Pretrain, U1Pretrain, S_L_Pretrain, U_L_Pretrain,...
            X_Pretrain, A_Up, A_Down, lambda, s_b, uEta, nU, threshType, s1Max, sL_Max);

    % Update up and down connections A1
    dA = aEta * ( S_L_Pretrain - s_b ) * S1Pretrain' / batchSizePretrain; % learning rule, averaged over the minibatch
    dA = max( min(dA, dA_Bound), -dA_Bound ); % keep the updated amount bounded
    
    % Feedforward weights move by +dA; the max/min with 0 keeps the positive
    % matrix non-negative and the negative matrix non-positive.
    A_Up_Pos = max( A_Up_Pos + dA, 0 );
    A_Up_Neg = min( A_Up_Neg + dA, 0 ); % -A_Up_Neg = max( -A_Up_Neg - dA, 0 );
    
    A_Up_Pos = NormalizeA( A_Up_Pos, normalizationMethod, l1 ); % positive connections
    A_Up_Neg = NormalizeA( A_Up_Neg, normalizationMethod, l2 ); % negative connections
    
    % Feedback weights move by -dA, with the same sign constraints.
    A_Down_Pos = max( A_Down_Pos - dA, 0 ); 
    A_Down_Neg = min( A_Down_Neg - dA, 0 ); % -A_Down_Neg = max( -A_Down_Neg + dA, 0 );
    
    A_Down_Pos = NormalizeA( A_Down_Pos, normalizationMethod, l2 ); % positive connections
    A_Down_Neg = NormalizeA( A_Down_Neg, normalizationMethod, l1 ); % negative connections
    
    A_Up = A_Up_Pos + A_Up_Neg; % overall feedforward connections
    A_Down = A_Down_Pos + A_Down_Neg; % overall feedback connections

    % Display A and S
    if ( mod(iPretrain,displayEvery) == 0 )
        % Report the range of the clipped update only here, so the console is
        % not flooded on every iteration.
        fprintf('Pretraining %6d: dA range [%g, %g]\n', iPretrain, min( dA(:) ), max( dA(:) ));
        
        figure(1); % Display the connections from ON and OFF LGN cells to simple cells
        subplot(231); I_A_ON_Up_Pos = DisplayA( 'ON', A_Up_Pos, resizeFactor ); title('A^{+}_{ON,Up}');
        subplot(232); I_A_ON_Up_Neg = DisplayA( 'ON', A_Up_Neg, resizeFactor ); title('A^{-}_{ON,Up}');
        subplot(233); I_A_ON_Up = DisplayA( 'ON', A_Up, resizeFactor ); title('A_{ON,Up}');
        subplot(234); I_A_OFF_Up_Pos = DisplayA( 'OFF', A_Up_Pos, resizeFactor ); title('A^{+}_{OFF,Up}');
        subplot(235); I_A_OFF_Up_Neg = DisplayA( 'OFF', A_Up_Neg, resizeFactor ); title('A^{-}_{OFF,Up}');
        subplot(236); I_A_OFF_Up = DisplayA( 'OFF', A_Up, resizeFactor ); title('A_{OFF,Up}');
        colormap(Green2Magenta(64));
        
        figure(2); % Display the overall receptive fields of simple cells: Aon - Aoff
        DisplayA( 'ONOFF', A_Up, resizeFactor); title('RFs: A_{ON,Up}-A_{OFF,Up}');
        colormap(scm(256));
        
        figure(3); % Display the firing rates of LGN cells and simple cells
        % The plotted responses have batchSizePretrain columns, so that is the
        % count reported in the titles.
        subplot(211); stem(S_L_Pretrain); title(['S_L: LGN cell responses of ' num2str(batchSizePretrain) ' patches']);
        xlabel('LGN cells'); ylabel('firing rates');
        subplot(212); stem(S1Pretrain); title(['S1: simple cell responses of ' num2str(batchSizePretrain) ' patches']);
        xlabel('simple cells'); ylabel('firing rates');
        
        drawnow; % render the last figure now instead of at the next figure() call
    end
    
    % Compute the difference between up and down connections
    % (entry 1 holds the initial value, hence the 1+iPretrain index)
    errorA_UpDown(1+iPretrain) = sum ( ( A_Up(:) + A_Down(:) ).^2 );
    errorA_UpPosDownPos(1+iPretrain) = sum ( ( A_Up_Pos(:) + A_Down_Neg(:) ).^2 );
    errorA_UpNegDownNeg(1+iPretrain) = sum ( ( A_Up_Neg(:) + A_Down_Pos(:) ).^2 );
    
    % print current status of learning
    % (the printed value is errorA_UpDown, i.e. sum((A_Up + A_Down).^2))
    fprintf('Pretraining %6d: ||Aup-Adown||^2: %4.4f\n',...
                iPretrain, errorA_UpDown(1+iPretrain));
end

        
%% train the model using whitened natural images
% Each iteration draws one minibatch of patches from a single randomly chosen
% image, lets the network settle, and applies opposite-signed updates to the
% feedforward (+dA) and feedback (-dA) weights.
for iEpoch = 1 : numEpoches
    
    % adjust the learning rate: lowered to 0.2 after 1e4 iterations and to
    % 0.1 after 2e4 iterations
    if iEpoch > 2e4
        aEta = 0.1;
    elseif iEpoch > 1e4
        aEta = 0.2;
    end
    
    % Choose an image at random from the images in the dataset
    iImage = ceil( numImages * rand );
    thisImage = IMAGES_WHITENED(:,:,iImage);
    
    % extract image patches at random from this image to make data vector
    % (top-left corners stay at least BUFF pixels away from the image border)
    for iBatch = 1 : batchSize
        r = BUFF + ceil((imageSize-sz-2*BUFF)*rand); % select y coordinate
        c = BUFF + ceil((imageSize-sz-2*BUFF)*rand); % select x coordinate
        X_Data( : , iBatch ) = reshape( thisImage(r:r+sz-1,c:c+sz-1), L, 1 );
    end
    
    % ON and OFF LGN input: positive part and magnitude of the negative part
    X_ON = max( X_Data, 0 );
    X_OFF = -min( X_Data, 0 );
    
    X( 1:L, : ) = X_ON;
    X( L+1:2*L, : ) = X_OFF;
    
    % Compute S and U for LGN and simple cells using previous values
    [ S1, U1, S_L, U_L, S1_hist] = Compute_S_U_LGN_V1_UpDown( S1, U1, S_L, U_L,...
        X, A_Up, A_Down, lambda, s_b, uEta, nU, threshType, s1Max, sL_Max, histFlag);

    % Update up and down connections A1
    dA = aEta * ( S_L - s_b ) * S1' / batchSize; % learning rule, averaged over the minibatch
    dA = max( min(dA, dA_Bound), -dA_Bound ); % keep the updated amount bounded
    
    % Feedforward weights move by +dA; the max/min with 0 keeps the positive
    % matrix non-negative and the negative matrix non-positive.
    A_Up_Pos = max( A_Up_Pos + dA, 0 );
    A_Up_Neg = min( A_Up_Neg + dA, 0 ); % -A_Up_Neg = max( -A_Up_Neg - dA, 0 );
    A_Up_Pos = NormalizeA( A_Up_Pos, normalizationMethod, l1 ); % positive connections
    A_Up_Neg = NormalizeA( A_Up_Neg, normalizationMethod, l2 ); % negative connections
    
    % Feedback weights move by -dA, with the same sign constraints.
    A_Down_Pos = max( A_Down_Pos - dA, 0 ); 
    A_Down_Neg = min( A_Down_Neg - dA, 0 );
    A_Down_Pos = NormalizeA( A_Down_Pos, normalizationMethod, l2 ); % positive connections
    A_Down_Neg = NormalizeA( A_Down_Neg, normalizationMethod, l1 ); % negative connections
    
    A_Up = A_Up_Pos + A_Up_Neg; % overall feedforward connections
    A_Down = A_Down_Pos + A_Down_Neg; % overall feedback connections

    % Display A and S
    if ( mod(iEpoch,displayEvery) == 0 )
        % Report the range of the clipped update only here, so the console is
        % not flooded on every iteration.
        fprintf('Iteration %6d: dA range [%g, %g]\n', iEpoch, min( dA(:) ), max( dA(:) ));
        
        figure(1); % Display the connections from ON and OFF LGN cells to simple cells
        subplot(231); DisplayA( 'ON', A_Up_Pos, resizeFactor ); title('A^{+}_{ON,Up}');
        subplot(232); DisplayA( 'ON', A_Up_Neg, resizeFactor ); title('A^{-}_{ON,Up}');
        subplot(233); DisplayA( 'ON', A_Up, resizeFactor ); title('A_{ON,Up}');
        subplot(234); DisplayA( 'OFF', A_Up_Pos, resizeFactor ); title('A^{+}_{OFF,Up}');
        subplot(235); DisplayA( 'OFF', A_Up_Neg, resizeFactor ); title('A^{-}_{OFF,Up}');
        subplot(236); DisplayA( 'OFF', A_Up, resizeFactor ); title('A_{OFF,Up}');
        colormap(Green2Magenta(64));
        
        figure(2); % Display the overall receptive fields of simple cells: Aon - Aoff
        DisplayA( 'ONOFF', A_Up, resizeFactor); title('RFs: A_{ON,Up}-A_{OFF,Up}');
        colormap(scm(256));
        
        figure(3); % Display the firing rates of LGN cells and simple cells
        subplot(211); stem(S_L); title(['S_L: LGN cell responses of ' num2str(batchSize) ' patches']);
        xlabel('LGN cells'); ylabel('firing rates');
        subplot(212); stem(S1); title(['S1: simple cell responses of ' num2str(batchSize) ' patches']);
        xlabel('simple cells'); ylabel('firing rates');
        
        % Display the trajectory of simple cells responses
        if histFlag == 1
            figure(4);
            plot(S1_hist);title('Trajectory of simple cells')
        end
        
        drawnow; % render the last figure now instead of at the next figure() call
    end
    
    % Compute the difference between up and down connections
    % (entries 1..1+numPretrain were filled before and during pre-training)
    errorA_UpDown(1+numPretrain+iEpoch) = sum ( ( A_Up(:) + A_Down(:) ).^2 );
    errorA_UpPosDownPos(1+numPretrain+iEpoch) = sum ( ( A_Up_Pos(:) + A_Down_Neg(:) ).^2 );
    errorA_UpNegDownNeg(1+numPretrain+iEpoch) = sum ( ( A_Up_Neg(:) + A_Down_Pos(:) ).^2 );
    
    % print current status of learning
    % (the printed value is errorA_UpDown, i.e. sum((A_Up + A_Down).^2))
    fprintf('Iteration %6d: ||Aup-Adown||^2: %4.4f\n',...
                iEpoch,errorA_UpDown(1+numPretrain+iEpoch));
end