% Stand-alone script to train a single-digit number comparison network.
% Script written by Angela Rose (was Porter), which calls methods written by Stefan Huber.

% outout the weights to a (space delimitered) csv file to be loaded into NSCCN
% main code so they are set. This is for speed and repeatability.
saveWeightsToFile = true;

% num ANNs to train. Need to set this value and numTrainDCN will be
% automatically calculated. If not saving the weights to an output file
% then just set this to 1.
numANN = 30;

% Each single digit comparison network runs training once and generates 2
% rows in the output file. Every ANN requires 2 single digit comparison
% networks to be trained (num size and phys size), which equals 4 rows (sets of weights)
% in the output file: for weights into: num size L node,
% num size R node, phys size L node, phys size R node. 
% For the single-digit comparison task that is not numerical Stroop (i.e
% taskType=2), train the networks as per Stroop as when testing the network
% it will just ignore the phys size network.
numTrainDCN = numANN*2;

if saveWeightsToFile
    % each row in file contains 18 weights that connect to one of the nodes
    % in the comparison layer. (First row is weights into left node, second row is weights into
    % right node. So each single digit comparison network
    % training will have 2 rows in this file for the 2 columns in
    % DCN.weightsI2O
    outputFile = 'savedWeightsI2O_30ANN_LR01_G_30000_1.csv';
    fileID = fopen(outputFile, 'w');
    %floating point number with a total of 18 characters (includes minus
    %signs and decimal point, and 15 decimal places. (There are no leading
    %zeroes.)
    formatSpecDtl = '%18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f %18.15f\n';
    rng(1); %have set seed so can rerun and reproduce
end

% set number of training trials. Huber 2016 is 100,000 
numTrainTrials = 30000;

trainPairs = 'G';
   
for i=1:numTrainDCN
    
    DCN = DCNetwork();

    switch trainPairs
        case {'U0', 'U1'}
            % train network using random numbers from a uniform distribution
            DCN.trainRandomList(numTrainTrials, trainPairs)
        case 'G'
            % train network using random numbers from a Google survey
            DCN.trainRandomGoogleList(numTrainTrials)
        otherwise
            %error - need to process/handle this (haven't tested)
            fprintf('Error training comparison network, invalid trainPairs, default to U0: %s \n', trainPairs);
            DCN.trainRandomList(numTrainTrials, 'U0');
    end
            
    % test the network after training
    numTestTrials = 10000;
    DCN.testRandomList(numTestTrials)
    
    if saveWeightsToFile
        fprintf(fileID, formatSpecDtl, DCN.weightsI2O);
    end
end

%output weights to file
if saveWeightsToFile
    fclose(fileID);
end