package stimdelrew;
import stimulusdelayreward.*;
import lnsc.page.*;
import lnsc.DataSetCollection;
import lnsc.DataSet;
import java.util.Random;
import java.io.*;
/** Main routine training networks until there are 5 successful networks.
 *  It saves successful networks and final training data in .ds76 and .dsc76
 *  files respectively.
 *
 *  @author Francois Rivest
 *  @version 1.0
 */
public class MasterTraining {
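    // Outcome codes for run(): FAIL = never learned; LEARN = learned but did
    // not stay successful; KEPT = learned and still successful near the end;
    // ERROR = saving the results failed. CRASH is never returned by run();
    // main() counts it when run() throws. In main()'s switch, KEPT falls
    // through to LEARN, so kept networks are counted as learned too.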
    static final int ERROR = -2;
    static final int CRASH = -1;
    static final int FAIL = 0;
    static final int LEARN = 1;
    static final int KEPT = 3;
    /** No arguments required. */
    public static void main(String[] args) {
        //Generate 5 fully successful networks
        int total = 0;
        int error = 0;
        int crash = 0;
        int fail = 0;
        int learn = 0;
        int kept = 0;
        while (kept < 5) {
            System.out.println("Trial " + (total+1));
            try {
                switch (run(.5, 2, true, true)) {
                    case FAIL:
                        fail++;
                        break;
                    case KEPT:
                        kept++;
                        //falls through: a kept network also counts as learned
                    case LEARN:
                        learn++;
                        break;
                    case ERROR:
                    default:
                        error++;
                }
            } catch (Exception e) {
                crash++;
                System.out.println("run(.5, 2, true, true) crashed: " + e);
            }
            total++;
            System.out.println("Kept so far: " + kept);
        }
        try {
            PrintStream out = new PrintStream(new FileOutputStream("Result.log", true));
            out.println("Total: " + total);
            out.println("Error: " + error);
            out.println("Crash: " + crash);
            out.println("Fail: " + fail);
            out.println("Learn: " + learn);
            out.println("Kept: " + kept);
            out.close();
        } catch (Exception e) {
            System.err.println("Can't write log!");
        }
System.out.println("Total: " + total);
System.out.println("Error: " + error);
System.out.println("Crash: " + crash);
System.out.println("Fail: " + fail);
System.out.println("Learn: " + learn);
System.out.println("Kept: " + kept);
}
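    /** Trains one network and reports the outcome.
     *  @param lr        learning rate for the agent
     *  @param uc        unit count (used for both unit-count arguments of the agent)
     *  @param inSquash  input-squashing flag forwarded to the agent constructor
     *  @param outSquash output-squashing flag forwarded to the agent constructor
     *  @return FAIL, LEARN, KEPT, or ERROR (when saving the results fails)
     */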
    public static int run(double lr, int uc, boolean inSquash, boolean outSquash)
    {
        double learningRate = lr;
        int unitCount = uc;
        boolean gate2gate = false;
        boolean in2out = false;
        boolean success = false;
        boolean lastSuccess = false;
        int firstSuccess = -1;
        boolean prevLastSuccess = false; //18Sep06: tracks success of the second-to-last train block
        //Create the agent
        AbstractObservableAgent a =
            new ActorCritic_PDAETLSTM_Monkey2(
                unitCount, unitCount,
                inSquash, outSquash,
                gate2gate, in2out,
                learningRate, .1,
                //TD: Rivest06, {no bias, cue only}->AC
                4, new FlexibleSignalStateRepresentation(false, true, false),
                .8, true);//LSTM: e-trace, reset
        //Create space for all data
        DataSetCollection dataCol = new DataSetCollection(22);
        int stop = 100*10;
        int lowbound = stop - 21;
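        //Only the last blocks (i > lowbound) get recorded: i-lowbound runs
        //from 1 to 21, so the 22-slot collection keeps the final 21 blocks
        //(slot 0 stays unused).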
        //Run multiple 4-minute blocks (about 40 trials at 6 s per trial)
        for (int i = 0; i <= stop; i++)
        {
            //INSERT TEST BLOCK HERE
            if (i % 10 == 0) {
                //Environment
                SingleAgentEnvironment env = new SingleAgentEnvironment(0, 0, 10*60*1);
                //Data collector
                DataSetCollector dc = new DataSetCollector();
                env.addObserver(dc);
                a.addObserver(dc);
                //Run
                env.go(a, new ExperimentTestState(1000));//test
                //Collect data
                if (i > lowbound) {
                    dataCol.setData("TestState", i-lowbound, dc.StateHistory);
                    dataCol.setData("TestMonkey", i-lowbound, dc.MonkeyHistory);
                }
                //Clean collector
                a.deleteObserver(dc);
            }
            //TEST BLOCK ENDS HERE
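            //Training block: a full-length session (twice the test-session duration)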
            //Environment
            SingleAgentEnvironment env = new SingleAgentEnvironment(0, 0, 10*60*2);
            //Data collector
            prevLastSuccess = lastSuccess; //18Sep06: save previous train block success
            lastSuccess = false;
            DataSetCollector dc = new DataSetCollector();
            env.addObserver(dc);
            a.addObserver(dc);
            //Run
            env.go(a, new ExperimentState(1000));
            //Test for success
            if (dc.m_CorrectFound) {
                System.out.println("Learning succeeded at block " + i + " step " + dc.m_FirstCorrectStep
                    + " with alpha = " + learningRate + " and " + unitCount + " units!");
                success = true;
                lastSuccess = true;
                if (firstSuccess == -1) {
                    firstSuccess = i;
                }
            }
            //Collect data
            if (i > lowbound) {
                dataCol.setData("State", i-lowbound, dc.StateHistory);
                dataCol.setData("Monkey", i-lowbound, dc.MonkeyHistory);
                dataCol.setData("LSTMSuccess", i-lowbound, Boolean.valueOf(dc.m_CorrectFound));
            }
            //Clean collector
            a.deleteObserver(dc);
            if (i % 100 == 0) { System.out.println(); }
        }
        //Tools.dumpV(dc);
        //Tools.dumpE(dc);
        //Tools.dumpR(dc);
        if (!success) {
            System.out.println("Learning failed (" + learningRate + "_" + unitCount
                + "_" + inSquash + "_" + outSquash + ")!");
            return FAIL;
        } else if (prevLastSuccess) { //Save only if it remained successful
            Random rnd = new Random(java.lang.System.currentTimeMillis());
            long id = rnd.nextInt();
            String name = "Result" + id + ".dsc76";
            try {
                System.out.print("Saving history in " + name + " ...");
                Tools.saveDataSetCollection(name, dataCol);
                DataSet dat = new DataSet();
                dat.setData("Agent", a);
                dat.setData("Learn", Boolean.valueOf(success));
                dat.setData("Kept", Boolean.valueOf(prevLastSuccess));//18Sep06
                dat.setData("First", Integer.valueOf(firstSuccess));
                dat.setData("VeryLast", Boolean.valueOf(lastSuccess));//18Sep06
                lnsc.Tools.saveDataSet("TrainedAgent" + id + ".ds76", dat);
                System.out.println(" done!");
                //System.out.println("Not saved!");
            } catch (Exception e) {
                System.err.println(e.toString());
                return ERROR;
            }
        }
        return (prevLastSuccess ? KEPT : LEARN);//18Sep06
    }
}