classdef SparseCoding < handle
    % SparseCoding  Dictionary of unit-norm basis vectors with several encoders.
    %
    %   The dictionary (Basis) stores one basis vector per column. Input
    %   batches are matrices with one patch per column whose row count equals
    %   Basis_size. Patches can be encoded by stochastic (softmax) matching
    %   pursuit, greedy matching pursuit, or a full projection, and the
    %   dictionary is adapted with a normalised gradient step.
    %
    %   The greedy encoders call calculateRightBinocularity, which is defined
    %   outside this file and must be on the MATLAB path.

    properties
        Basis_num_used; % number of basis vectors picked per patch in sparse mode
        Basis_size;     % length of each basis vector
        Basis_num;      % total number of basis vectors
        Basis;          % current dictionary, Basis_size x Basis_num
        Basis_hist;     % dictionary snapshots, Basis_size x Basis_num x nSaves
        Basis_selected; % per basis: number of patches of the most recently
                        % greedily encoded batch with a non-zero coefficient
        eta;            % learning rate
        Temperature;    % temperature in softmax
        Dsratio;        % downsampling ratio (to produce 8x8)
        patch_size;     % size of extracted patches
    end

    methods
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Constructor
        %%%
        %%% PARAM is a cell array read as
        %%%   {Basis_num_used, Basis_size, Basis_num, eta, Temperature,
        %%%    Dsratio, <entry 7 is not read here>, patch_size}
        %%% nSaves is the number of snapshot slots in Basis_hist
        %%%   (optional, defaults to 1: only the initial basis is stored)
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function obj = SparseCoding(PARAM, nSaves)
            if nargin < 2
                nSaves = 1;
            end

            obj.Basis_num_used = PARAM{1};
            obj.Basis_size = PARAM{2};
            obj.Basis_num = PARAM{3};
            obj.Basis_selected = zeros(PARAM{3}, 1);
            obj.eta = PARAM{4};
            obj.Temperature = PARAM{5};
            obj.Dsratio = PARAM{6};
            obj.patch_size = PARAM{8};

            % initialize receptive fields as zero-centred uniform noise and
            % scale every column to unit length; the explicit dimension keeps
            % this correct when Basis_size is 1
            a = rand(obj.Basis_size, obj.Basis_num) - 0.5;
            a = bsxfun(@rdivide, a, sqrt(sum(a.^2, 1)));
            obj.Basis = a;

            % slot 1 of the history always holds the initial dictionary
            obj.Basis_hist = zeros(obj.Basis_size, obj.Basis_num, nSaves);
            obj.Basis_hist(:, :, 1) = a;
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Encode the images by stochastic matching pursuit: in each of the
        %%% Basis_num_used iterations one basis per patch is drawn from a
        %%% softmax distribution over |correlation| / Temperature
        %%%
        %%% Images is the input batch (one patch per column)
        %%%
        %%% Coef is the output coefficients for each basis and image
        %%% Error is the reconstruction error using these coefficients
        %%%
        %%% Uses softmax, which is not defined in this file
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [Coef, Error] = softmaxEncode(this, Images)
            batch_size = size(Images, 2);
            Coef = zeros(this.Basis_num, batch_size);
            residual = Images;
            for count = 1:this.Basis_num_used
                corr = this.Basis' * residual;

                % shift by the column maximum before the softmax
                logits = abs(corr) / this.Temperature;
                logits = bsxfun(@minus, logits, max(logits, [], 1));

                % inverse-CDF sampling: pick the first basis whose cumulative
                % probability reaches a uniform draw
                cdf = cumsum(softmax(logits), 1);
                cdf(end, :) = 1; % rounding must not leave a draw above the CDF
                draw = rand(1, batch_size);
                [~, index] = max(double(bsxfun(@ge, cdf, draw)), [], 1);

                linearindex = sub2ind(size(corr), index, 1:batch_size);
                Coef(linearindex) = Coef(linearindex) + corr(linearindex);
                residual = Images - this.Basis * Coef;
            end
            Error = residual;
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Encode the input images with the best matched basis
        %%% (greedy matching pursuit on the current dictionary)
        %%%
        %%% imageBatch is the input image batch (one patch per column)
        %%%
        %%% coef is the output coefficients
        %%% error is the reconstruction error using these coefficients
        %%% monocularity is the sum of the per-basis values returned by
        %%%   calculateRightBinocularity, weighted by each basis' share of
        %%%   the mean squared coefficients (NaN if all coefficients are 0)
        %%%
        %%% Side effect: overwrites Basis_selected
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [coef, error, monocularity] = sparseEncode(this, imageBatch)
            [coef, error, monocularity] = this.matchingPursuit(this.Basis, imageBatch);
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Project the input images onto all basis vectors
        %%%
        %%% Images is the input image batch (one patch per column)
        %%%
        %%% Coef is the correlation of every basis with every image
        %%% Error is the reconstruction error using these coefficients
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [Coef, Error] = fullEncode(this, Images)
            Coef = this.Basis' * Images;
            Error = Images - this.Basis * Coef;
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Update the basis: add eta times the batch-averaged outer product
        %%% of error and coefficients, then rescale every basis vector to
        %%% unit length
        %%%
        %%% coef is the input coefficient matrix
        %%% error is the input reconstruction error
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function updateBasis(this, coef, error)
            deltaBases = error * coef' / size(error, 2);
            updated = this.Basis + this.eta * deltaBases;
            this.Basis = bsxfun(@rdivide, updated, sqrt(sum(updated.^2, 1)));
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Train sparse coding for one step: encode the batch with
        %%% sparseEncode (matching pursuit), then update the basis
        %%%
        %%% Images is the input image batch
        %%%
        %%% error is the reconstruction error using the best matched
        %%%   coefficients, computed before the basis update
        %%% coef is the coefficient matrix
        %%% monocularity is the value returned by sparseEncode
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [error, coef, monocularity] = stepTrain(this, Images)
            [coef, error, monocularity] = this.sparseEncode(Images); % matching pursuit
            updateBasis(this, coef, error); % adapt RFs via gradient descent
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Same as stepTrain but encodes with suppressiveEncode and also
        %%% returns the per-basis weighted values (weightedMonocs)
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [error, coef, monocularity, weightedMonocs] = suppressiveStepTrain(this, Images)
            [coef, error, monocularity, weightedMonocs] = this.suppressiveEncode(Images);
            updateBasis(this, coef, error);
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Append the current basis to a file as variable 'Basis'
        %%%
        %%% configfile is the name of the file passed to save
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function saveClass(this, configfile)
            Basis = this.Basis; % save() stores workspace variables by name
            save(configfile, 'Basis', '-append');
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Save the basis during training
        %%%
        %%% index is the slot of Basis_hist that receives the current basis
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function saveBasis(this, index)
            this.Basis_hist(:, :, index) = this.Basis;
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Encode the input images with the best matched basis
        %%% and indicate the use of monocular basis functions
        %%%
        %%% imageBatch - preprocessed batch of input patches
        %%%
        %%% coef - correlation coefficients
        %%% error - reconstruction error
        %%% monocularity - sum of weightedMonocs
        %%% weightedMonocs - per-basis values from calculateRightBinocularity
        %%%   weighted by each basis' share of the mean squared coefficients
        %%%
        %%% Side effect: overwrites Basis_selected
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [coef, error, monocularity, weightedMonocs] = suppressiveEncode(this, imageBatch)
            [coef, error, monocularity, weightedMonocs] = ...
                this.matchingPursuit(this.Basis, imageBatch);
        end

        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Same as suppressiveEncode, but encodes with the dictionary
        %%% snapshot stored in Basis_hist(:, :, basisAt) instead of the
        %%% current basis
        %%%
        %%% Side effect: overwrites Basis_selected
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [coef, error, monocularity, weightedMonocs] = suppressiveEncodeAt(this, basisAt, imageBatch)
            [coef, error, monocularity, weightedMonocs] = ...
                this.matchingPursuit(this.Basis_hist(:, :, basisAt), imageBatch);
        end
    end

    methods (Access = private)
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        %%%
        %%% Greedy matching pursuit shared by sparseEncode,
        %%% suppressiveEncode and suppressiveEncodeAt
        %%%
        %%% basis is the dictionary to encode with (one basis per column)
        %%% imageBatch is the input batch (one patch per column)
        %%%
        %%% Outputs are those documented for suppressiveEncode.
        %%% Side effect: overwrites Basis_selected
        %%%
        %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
        function [coef, residual, monocularity, weightedMonocs] = matchingPursuit(this, basis, imageBatch)
            nPatches = size(imageBatch, 2);
            coef = zeros(this.Basis_num, nPatches);
            corr = basis' * imageBatch; % correlation of each basis with each patch
            corrBB = basis' * basis;    % correlation between basis vectors

            for count = 1:this.Basis_num_used
                % basis with the largest |correlation| for every patch; the
                % explicit dimension keeps this per-patch when Basis_num is 1
                [~, index] = max(abs(corr), [], 1);
                linearindex = sub2ind(size(corr), index, 1:nPatches);
                pCorr = corr(linearindex);
                coef(linearindex) = coef(linearindex) + pCorr;

                % remove the picked components from all correlations instead
                % of recomputing them from the residual
                corr = corr - bsxfun(@times, corrBB(:, index), pCorr);
            end
            residual = imageBatch - basis * coef;

            % per basis: number of patches with a non-zero coefficient
            this.Basis_selected = sum(coef ~= 0, 2);

            % share of every basis in the mean squared coefficients
            % (0/0 = NaN if every coefficient is zero)
            feature = mean(coef.^2, 2);
            feature = feature ./ sum(feature);

            % calculateRightBinocularity is defined outside this file; the
            % original comment describes its result as right monocular
            % dominance -- TODO confirm its exact shape and meaning
            binInds = calculateRightBinocularity(basis);
            weightedMonocs = binInds' .* feature;
            monocularity = sum(weightedMonocs);
        end
    end
end