% DCNetwork
%
% Artificial neural network for the simulation of the
% comparison of single digit numbers
%
% Author: Stefan Huber (s.huber@iwm-tuebingen.de)
% Version: 1.0
%
% For numerical Stroop: some minor changes have been made, including adding
% the option to choose the type of training method. Author: Angela Rose.
classdef DCNetwork < handle
properties (Constant)
% Fixed hyper-parameters and stored weight sets shared by all instances.
ALPHA = 0.01; % learning rate
TAU = 0.01; % cascade rate of activation in a trial
THRESHOLD = 0.75; % threshold for activation function
% Huber's saved weights
% other than decimalsS
% NOTE(review): the meaning of "other than decimalsS" is not evident from
% this file -- confirm with the author.
% This matrix has 20 rows and 2 columns, i.e. it is sized for the former
% 2*10-node input layer; the constructor now builds a 2*9-node input layer.
SAVED_WEIGHTS = [-0.511864929228142 4.39272618516079;-0.724322338374192 2.77803059156427;0.00176841426989734 3.22726264972506;1.63543768053811 3.02522330191163;0.811278610158098 1.34327505427213;2.20759783931969 1.51536638099751;1.33504120609633 -0.241520665462216;2.99322799332285 -0.0996751717135838;2.89411652899791 -0.334236628835027;4.42060559891104 -0.785628038504714;3.79837236358670 -0.870777896644516;4.19367664194421 -0.0424334235584019;3.36917522197439 0.782453912512456;2.46255368411295 1.06376694144714;1.66670827248468 1.52276504448384;1.18145243490844 2.02458322600946;1.09237602341207 2.51717615997486;0.0720739423750520 2.32120452181259;-0.437315757860684 3.98194440199658;-1.05025975550086 3.49141796693560;];
% zero higher priority
% Also 20 rows by 2 columns (sized for the former 2*10-node input layer).
SAVED_WEIGHTS_Decimal = [-1.10967651429135 5.71441611112066;-0.426791163090014 3.89825235159990;-0.245671885889232 3.01217497332500;1.04709541399589 2.91818748427143;2.10460457412000 2.56789142071137;2.26123114756907 1.95567574327179;2.77684783760107 1.29404601590733;4.26485498853246 1.02411226894145;3.49151377669695 -0.810200963160015;3.93747825853358 -1.13999213579980;5.57904771984225 -0.656448429320572;5.22868846975300 0.623637023834364;3.98284365689981 1.07550862069857;2.75880620694170 0.540112561414901;2.62652004004480 1.98731956131432;1.92814351157943 2.54007857515607;0.937594421447878 3.27700034098162;1.31609821850266 3.80067595859943;-0.836561282568751 3.30992955758938;-1.02898275260240 4.34033555033896];
end
properties (Access = public)
% Per-trial test results; testRandomList writes one row per test trial.
% Initialised as a 100x6 matrix of zeros.
outputTest = zeros(100,6); % output matrix
%{
Format of outputTest (one row per trial):
col 1: input number 1
col 2: input number 2
col 3: output/actual activation left node
col 4: output/actual activation right node
col 5: reaction time (rt), i.e. the step count returned by propagateActivity
col 6: 1 (i.e. true) if successful trial - compared numbers correctly
       0 (i.e. false) if unsuccessful trial
%}
end
properties (Access = public)
% Network state. The sizes below reflect the removal of the zero node:
% each digit is now coded by 9 nodes instead of 10.
inputLayer % input vector for two digits (9*2, 1)
% first 9 entries code number 1, last 9 entries code number 2
% (node i of each half peaks for digit i, i = 1..9)
outputLayer = zeros(1,2); % output vector: (1,1) left node, (1,2) right node
weightsI2O % weights from input to output layer (9*2, 2)
end
methods (Access = public)
% Constructor: creates a network with an all-zero input layer and random
% input-to-output weights drawn uniformly from the interval (-1, 1).
function DCN = DCNetwork()
    % zero node removed: 2*9 = 18 input nodes (formerly 2*10 = 20)
    numInputs = 2*9;
    % one weight column per output node
    numOutputs = size(DCN.outputLayer, 2);
    DCN.inputLayer = zeros(numInputs, 1);
    DCN.weightsI2O = rand(numInputs, numOutputs)*2 - 1;
end
% Sets the input-to-output weights to Huber's saved weights (SAVED_WEIGHTS).
%
% SAVED_WEIGHTS has 20 rows (the former 2*10-node input layer), whereas the
% input layer built by the constructor has 2*9 = 18 nodes. Loading the 20-row
% matrix unchanged makes the next inputLayer'*weightsI2O product fail with a
% dimension mismatch, so the two zero-digit rows are dropped when the
% current input layer has 18 nodes.
function setWeightsToSavedWeights(DCN)
    saved = DCN.SAVED_WEIGHTS;
    % Assumed row layout (taken from the former 10-node activity vector,
    % where node index = digit + 1): rows 1..10 code digits 0..9 of number 1,
    % rows 11..20 code digits 0..9 of number 2. Rows 1 and 11 are therefore
    % the removed zero nodes -- TODO confirm against the original training.
    if size(saved, 1) == 20 && numel(DCN.inputLayer) == 18
        saved([1 11], :) = [];
    end
    DCN.weightsI2O = saved;
end
% Sets the input-to-output weights to the saved "zero higher priority"
% weight set (SAVED_WEIGHTS_Decimal).
%
% SAVED_WEIGHTS_Decimal has 20 rows (the former 2*10-node input layer),
% whereas the input layer built by the constructor has 2*9 = 18 nodes.
% Loading the 20-row matrix unchanged makes the next inputLayer'*weightsI2O
% product fail with a dimension mismatch, so the two zero-digit rows are
% dropped when the current input layer has 18 nodes.
function setWeightsToSavedWeightsDecimal(DCN)
    saved = DCN.SAVED_WEIGHTS_Decimal;
    % Assumed row layout (taken from the former 10-node activity vector,
    % where node index = digit + 1): rows 1..10 code digits 0..9 of number 1,
    % rows 11..20 code digits 0..9 of number 2. Rows 1 and 11 are therefore
    % the removed zero nodes -- TODO confirm against the original training.
    if size(saved, 1) == 20 && numel(DCN.inputLayer) == 18
        saved([1 11], :) = [];
    end
    DCN.weightsI2O = saved;
end
% Sets the input-to-output weights.
%   weightsI2O - new weight matrix; stored as given, without any size check
%                (the propagation code multiplies inputLayer' by it, so it
%                needs one row per input node and one column per output node)
function setWeightsI2O(DCN, weightsI2O)
DCN.weightsI2O = weightsI2O;
end
% Gets the input-to-output weights.
%   weightsI2O - current value of the weightsI2O property
function weightsI2O = getWeightsI2O(DCN)
weightsI2O = DCN.weightsI2O;
end
% Sets the input layer for a pair of digits.
%   digit1, digit2 - the two numbers to be compared
% For different digits the input layer is the activity vector of digit1
% stacked on top of the activity vector of digit2. For equal digits the
% input layer is all zeros.
function setInputActivity (DCN,digit1, digit2)
    % default: no input activity (2*9 nodes, zero node removed)
    activity = zeros(2*9,1);
    if (digit1 ~= digit2)
        firstHalf = DCN.calcActivityVector(digit1);
        secondHalf = DCN.calcActivityVector(digit2);
        activity = [firstHalf ; secondHalf];
    end
    DCN.inputLayer = activity;
end
% Sets the input-to-output weights from a weight matrix read from a file.
%   numANN                 - index of the ANN whose weights are wanted
%   numDCNetwork           - 1 or 2: which comparison network of that ANN
%   DCNWeightsI2O_FromFile - matrix holding four columns per ANN:
%                            columns 4*numANN-3 and 4*numANN-2 belong to
%                            network 1, columns 4*numANN-1 and 4*numANN to
%                            network 2; rows 1:18 are used
% Any other numDCNetwork leaves the weights unchanged and raises a warning
% (previously it was ignored silently).
function setWeightsFromExternalFile(DCN, numANN, numDCNetwork, DCNWeightsI2O_FromFile)
    lastCol = 4*numANN;
    switch numDCNetwork
        case 1
            cols = lastCol-3:lastCol-2;
        case 2
            cols = lastCol-1:lastCol;
        otherwise
            warning('DCNetwork:invalidNetworkNumber', ...
                'numDCNetwork must be 1 or 2; weights left unchanged.');
            return;
    end
    DCN.weightsI2O = DCNWeightsI2O_FromFile(1:18, cols);
end
% Trains the network from freshly randomised weights and then tests it.
%   trainPairs - training method:
%                'U0' or 'U1' : digit pairs from a uniform distribution
%                'G'          : digit pairs from the Google distribution
%                anything else: a message is printed and 'U0' is used
% After training, the network is tested with testRandomList, which reports
% the number of errors and fills outputTest.
function setWeightsFromTraining(DCN, trainPairs)
    numTrainTrials = 100000;
    numTestTrials = 10000;

    % initialise weights uniformly in (-1, 1); 2*9 inputs (zero node removed)
    numOutputs = size(DCN.outputLayer, 2);
    DCN.weightsI2O = rand(2*9, numOutputs)*2 - 1;

    switch trainPairs
        case {'U0', 'U1'}
            trainRandomList(DCN, numTrainTrials, trainPairs);
        case 'G'
            DCN.trainRandomGoogleList(numTrainTrials);
        otherwise
            % unknown method: report it and fall back to 'U0'
            fprintf('Error training comparison network, invalid trainPairs, default to U0: %s \n', trainPairs);
            trainRandomList(DCN, numTrainTrials, 'U0');
    end

    % test the network after training
    DCN.testRandomList(numTestTrials);
end
% Trains the network using randomly generated single-digit pairs (from a
% uniform distribution); the two digits of a pair are always different.
%   numItems   - number of training trials
%   trainPairs - 'U0': digits 1..9, all equally likely
%                'U1': digits 0..9, with zero twice as likely as the others
% Each trial sets the input, propagates the activity and adapts the weights
% with the delta rule.
% NOTE(review): 'U1' produces the digit 0, for which the 9-node activity
% vector has no dedicated node (zero node removed) -- confirm that 'U1' is
% still meant to be used.
function trainRandomList(DCN, numItems, trainPairs)
    % strcmp is used because == on char arrays compares element-wise and
    % errors when the lengths differ
    weightAllEvenly = strcmp(trainPairs, 'U0');
    weightZeroMore = strcmp(trainPairs, 'U1');
    if ~(weightAllEvenly || weightZeroMore)
        error('DCNetwork:invalidTrainPairs', ...
            'trainPairs must be ''U0'' or ''U1''.');
    end

    for i = 1:numItems
        if weightAllEvenly
            digit1 = randi(9);
            digit2 = randi(9);
            while (digit1 == digit2)
                digit2 = randi(9);
            end
        else
            % randi(11)-2 yields -1..9; -1 is folded onto 0, so zero is
            % drawn twice as often as any other digit
            digit1 = max(randi(11)-2, 0);
            digit2 = max(randi(11)-2, 0);
            while (digit1 == digit2)
                % redraw from the same distribution as the initial draw
                % (the former randi(10)-2 could never yield 9)
                digit2 = max(randi(11)-2, 0);
            end
        end
        setInputActivity(DCN, digit1, digit2);
        propagateActivity(DCN);
        adaptWeightsDelta(DCN, digit1, digit2);
    end
end
% Trains the network using number pairs drawn with
% GoogleDistribution.getRandomNumber(); the second number is redrawn until
% it differs from the first.
%   numItems - number of training trials
% Each trial sets the input, propagates the activity and adapts the weights
% with the delta rule.
function trainRandomGoogleList(DCN, numItems)
    for trial = 1:numItems
        digit1 = GoogleDistribution.getRandomNumber();
        digit2 = GoogleDistribution.getRandomNumber();
        while (digit1 == digit2)
            digit2 = GoogleDistribution.getRandomNumber();
        end

        DCN.setInputActivity(digit1, digit2);
        DCN.propagateActivity();
        DCN.adaptWeightsDelta(digit1, digit2);
    end
end
% Tests the network using randomly generated pairs of different digits
% (1..9, uniform distribution).
%   numItems - number of test trials
% Results are written to outputTest, one row per trial:
%   [digit1, digit2, left activation, right activation, rt, success]
% A trial is successful when the more active output node matches the larger
% digit (left node for digit1, right node for digit2). Every unsuccessful
% trial is printed, followed by a summary line with the error count.
function testRandomList(DCN, numItems)
    % Start from a clean matrix of exactly numItems rows: this avoids
    % growing the matrix inside the loop and prevents rows of an earlier,
    % longer test run from remaining in outputTest.
    DCN.outputTest = zeros(numItems, 6);
    numErrors = 0;

    for i = 1:numItems
        digit1 = randi(9);
        digit2 = randi(9);
        while (digit1 == digit2)
            digit2 = randi(9);
        end

        setInputActivity(DCN, digit1, digit2);
        rt = propagateActivity(DCN);

        leftActivation = DCN.outputLayer(1,1);
        rightActivation = DCN.outputLayer(1,2);
        isCorrect = (leftActivation > rightActivation && digit1 > digit2) || ...
            (leftActivation < rightActivation && digit1 < digit2);

        DCN.outputTest(i,:) = [digit1, digit2, leftActivation, rightActivation, rt, isCorrect];

        % report if errors in testing
        if (~isCorrect)
            numErrors = numErrors + 1;
            fprintf('Comparison network testing error - unsuccessful trial: %d, compare: %d and %d \n', i, digit1, digit2);
        end
    end

    fprintf('Comparison network testing - number of errors out of %d trials: %d \n', numItems, numErrors);
end
% Performs a single cascade step on the output layer (no reset, no loop):
%   1. blend the previous activation with the weighted input at rate TAU
%   2. subtract twice the activation of the other output node
%      (lateral inhibition between the two output nodes)
%   3. squash the result with the sigmoid 1/(1+exp(-2x))
% The result is stored in outputLayer.
function propagate(DCN)
    netInput = DCN.inputLayer' * DCN.weightsI2O;
    activation = DCN.outputLayer*(1-DCN.TAU) + netInput*DCN.TAU;
    % each node is inhibited by the other node
    activation = activation + [activation(1,2) * -2, activation(1,1) * -2];
    % (Huber's version added a small random term here; it is disabled)
    DCN.outputLayer = 1 ./ (1 + exp(-2*activation));
end
% Returns the thresholded output layer activity.
%   outputActivity - copy of outputLayer in which each of the two output
%                    nodes is 1 if its activation is strictly above
%                    THRESHOLD and 0 otherwise
% The outputLayer property itself is not modified.
function outputActivity = getOutputLayerActivity(DCN)
    outputActivity = DCN.outputLayer;
    for node = 1:2
        % the logical result is stored as 0 or 1 in the numeric vector
        outputActivity(1,node) = outputActivity(1,node) > DCN.THRESHOLD;
    end
end
% Returns the output layer.
%   outputLayer - current value of the outputLayer property, returned as
%                 stored (no thresholding is applied here)
function outputLayer = getOutputLayer(DCN)
outputLayer = DCN.outputLayer;
end
% Resets the output layer: both output activations are set back to zero.
function resetOutputLayer(DCN)
DCN.outputLayer = zeros(1,2);
end
end
methods (Access = private)
% Sets every output node whose activation is strictly above THRESHOLD to 1.
% Nodes at or below the threshold keep their current activation.
function setResponse(DCN)
    for node = 1:2
        if (DCN.outputLayer(1, node) > DCN.THRESHOLD)
            DCN.outputLayer(1, node) = 1;
        end
    end
end
% Changes the weights according to the delta rule.
%   digit1, digit2 - the pair just presented; getCorrect turns it into the
%                    target output vector
% The weight change is inputLayer * (slope .* (target - output) * ALPHA),
% where slope = output .* (1 - output), and is added to weightsI2O.
function adaptWeightsDelta(DCN, digit1, digit2)
    output = DCN.outputLayer;
    % slope term of the sigmoid at the current output
    slope = output.*(ones(1,2) - output);
    target = DCN.getCorrect(digit1,digit2);
    scaledError = slope.*(target - output)*DCN.ALPHA;
    % outer product: one row per input node, one column per output node
    DCN.weightsI2O = DCN.weightsI2O + DCN.inputLayer*scaledError;
end
% Propagates the activity until a response is reached and returns the
% simulated reaction time.
%   rt - number of cascade steps performed (at most 100)
% The output layer is reset to zero, then cascade steps are repeated while
% both output activations are below THRESHOLD and fewer than 100 steps have
% been made. Each step blends the previous activation with the weighted
% input at rate TAU, subtracts twice the activation of the other output
% node and applies the sigmoid 1/(1+exp(-2x)). The final activation is
% stored in outputLayer.
% Called by both the training routines and testRandomList.
function rt = propagateActivity(DCN)
    DCN.outputLayer = zeros(1,2);
    threshold = DCN.THRESHOLD;
    decay = 1-DCN.TAU;
    % input and weights do not change during a trial, so the weighted
    % input is computed once instead of on every step
    drive = (DCN.inputLayer'*DCN.weightsI2O)*DCN.TAU;

    activation = zeros(1,2);
    rt = 0;
    while (activation(1,1) < threshold && activation(1,2) < threshold && rt < 100)
        activation = activation*decay + drive;
        % each node is inhibited by the other node
        activation = activation + [activation(1,2) * -2, activation(1,1) * -2];
        % (Huber's version added a small random term here; it is disabled)
        activation = 1 ./ (1 + exp(-2*activation));
        rt = rt + 1;
    end
    DCN.outputLayer = activation;
end
end
methods(Static, Access = private)
% Calculates the activity vector for a digit.
%   digit          - scalar number to encode
%   activityVector - 9x1 column; node i (i = 1..9) has activity
%                    exp(-10*abs(i - digit)), i.e. 1 at the node equal to
%                    the digit and rapidly decaying activity elsewhere
% There is no node for zero (zero node removed), so digit 0 yields only
% small activities.
function activityVector = calcActivityVector(digit)
    nodeIndex = (1:9)';
    activityVector = exp(-10*abs(nodeIndex - digit));
end
% Returns the correct output vector for a number pair digit1 and digit2.
%   correctMatrix - [1 0] if digit1 is larger (left node should win),
%                   [0 1] if digit2 is larger (right node should win)
% Raises DCNetwork:equalDigits when neither digit is larger (equal digits
% or NaN); previously the output was left unassigned in that case, which
% surfaced as an "output argument not assigned" error in the caller.
function correctMatrix = getCorrect(digit1, digit2)
    if (digit1 > digit2)
        correctMatrix = [1 0];
    elseif (digit1 < digit2)
        correctMatrix = [0 1];
    else
        error('DCNetwork:equalDigits', ...
            'getCorrect requires two different, comparable digits.');
    end
end
end
end