classdef DoubleContextLearnerDBNaLP < DoubleContextLearner
% DoubleContextLearnerDBNaLP
% Uses the combination of a Deep Belief Network (DBN) and a Linear
% Perceptron (LP) to learn the double-context task. The DBN is fitted
% first. Its last-layer activations are then weighted by the first row
% of the last layer's class weights and summed over hidden units. The
% resulting one value per sample is the input on which the LP is trained.
% Florian Raudies, 01/30/2014, Boston University.
properties (SetAccess = private)
    dbn % Deep belief network.
    lp  % Linear perceptron.
end
methods
    function obj = DoubleContextLearnerDBNaLP(...
            LetterLabel,NumberLabel,nHidden,nLayer)
        % Constructor.
        %   LetterLabel, NumberLabel - forwarded to DoubleContextLearner.
        %   nHidden - number of hidden units in each DBN layer.
        %   nLayer  - number of DBN layers.
        obj = obj@DoubleContextLearner(LetterLabel,NumberLabel);
        obj.dbn = DeepBeliefNetwork(repmat(nHidden,[1 nLayer]));
        obj.lp = LinearPerceptron;
    end
    function obj = learn(obj,nBlock,ExcludeState)
        % Fits the DBN on data from generateData, then trains the linear
        % perceptron on the class-weighted DBN activations of the data
        % block with ExcludeState removed.
        %   nBlock       - forwarded to generateData.
        %   ExcludeState - forwarded to generateData and to the
        %                  get*BlockExclude accessors.
        [Data, Label] = obj.generateData(nBlock, ExcludeState);
        obj.dbn.fit(Data,Label);
        % Freeze learning of DBN and train the linear perceptron.
        Data = obj.getDataBlockExclude(ExcludeState);
        Label = DoubleContextLearnerDBNaLP.toZeroBasedLabel(...
            obj.getLabelBlockExclude(ExcludeState));
        % The linear perceptron requires label numbers 0 and 1.
        obj.lp.train(obj.getClassWeightedActivation(Data),Label);
    end
    function err = testError(obj)
        % Returns the fraction of samples in the full data block whose
        % linear-perceptron prediction differs from the label.
        Data = obj.getDataBlock();
        Label = DoubleContextLearnerDBNaLP.toZeroBasedLabel(...
            obj.getLabelBlock());
        L = obj.lp.predict(obj.getClassWeightedActivation(Data));
        % Compare as column vectors so that a row-vector prediction is not
        % implicitly expanded into an nSample x nSample comparison.
        err = sum(L(:)~=Label(:))/numel(L);
    end
    function [A, Wc] = getDBNActivationSortedByWeights(obj)
        % Returns the DBN activations for the full data block with the
        % hidden units (dimension 2) reordered by descending absolute
        % weight in the first row of the last layer's class weights, and
        % that weight row in the same order.
        Data = obj.getDataBlock();
        A = obj.dbn.probe(Data); % nLayer x nHidden x nSample
        Wc = obj.dbn.getLastLayer.getWeightForClass(); % 2 x nHidden
        [~, Index] = sort(abs(Wc(1,:)),2,'descend');
        A = A(:,Index,:);
        Wc = Wc(1,Index);
    end
    function D = getLPData(obj)
        % Returns the linear-perceptron input for the full data block:
        % one class-weighted DBN activation per sample (nSample x 1).
        D = obj.getClassWeightedActivation(obj.getDataBlock());
    end
    function obj = setBlockTrain(obj,flag)
        % Sets the block-train flag on this learner and on its DBN.
        obj.blockTrain = flag;
        obj.dbn.setBlockTrain(flag);
    end
end
methods (Access = private)
    function D = getClassWeightedActivation(obj,Data)
        % Probes the DBN with Data and projects the last layer's
        % activations onto the first row of the class weights:
        %   D(s) = sum_h Wc(1,h) * A(end,h,s)
        % Returns an nSample x 1 column vector.
        A = obj.dbn.probe(Data); % nLayer x nHidden x nSample
        Wc = obj.dbn.getLastLayer.getWeightForClass(); % 2 x nHidden
        % reshape (rather than squeeze) keeps the nHidden x nSample
        % layout even when nHidden or nSample equals 1.
        Last = reshape(A(end,:,:),size(A,2),size(A,3));
        D = (Wc(1,:)*Last)';
    end
end
methods (Static = true, Access = private)
    function Index = toZeroBasedLabel(Label)
        % Converts a label matrix with one row per sample into a column
        % of 0-based indices of each row's maximum entry. This assumes
        % one-hot rows, which is not verified here.
        [~, Index] = max(Label,[],2);
        Index = Index - 1;
    end
end
methods (Static = true)
    function id = getIdentifier()
        % Short identifier for this learner.
        id = 'DBNaLP';
    end
end
end