package stimdelrew;
import stimulusdelayreward.*;
import java.io.*;
import lnsc.page.*;
import lnsc.*;
/** Main routine to test trained networks on a control block (CS only or US only).
* It saves successful networks and data in .ds76 and .dsc76 files respectively
* (beginning with "StopTest_").
*
* @author Francois Rivest
* @version 1.0
*/
public class PostTraining {

    /** File-name suffix of the data set files that hold trained networks. */
    private static final String NETWORK_SUFFIX = ".ds76";

    /** File-name suffix of the data set collection files that hold recorded histories. */
    private static final String HISTORY_SUFFIX = ".dsc76";

    /** Prefix given to every file written by {@link #testAgent}. */
    private static final String OUTPUT_PREFIX = "StopTest_";

    /**
     * Index under which the first test block is stored in the history
     * collection; the same value is saved as "BlockNumber" with the agent.
     */
    private static final int FIRST_BLOCK_NUMBER = 25;

    /**
     * Loads every .ds76 network found in a directory and runs a single test
     * block on each one that loads successfully.
     *
     * @param args The first argument must be the directory where to find the
     *             .ds76 files of trained networks to be tested. When it is
     *             missing, a usage message is printed and nothing is run.
     */
    public static void main(String[] args) {
        //Guard against a missing argument instead of failing on args[0]
        if (args == null || args.length < 1) {
            System.err.println("Usage: PostTraining <directory containing .ds76 files>");
            return;
        }

        //List files (an explanation was already printed when null is returned)
        File[] files = getNetworksList(args[0]);
        if (files == null) {
            return;
        }

        for (int z = 0; z < files.length; z++) {
            String a_name = files[z].getName();
            System.out.print(a_name + "\t");

            //Load agent
            System.out.println("Loading " + a_name);
            AbstractObservableAgent a = loadAgent(files[z]);

            //Test agent (skipped when loading failed; loadAgent reported why)
            if (a != null) {
                testAgent(a, a_name, 1);//single block
                System.out.println("\n\n");
            }
            System.out.println();
        }
    }

    /**
     * Runs <code>stop</code> test blocks with the given agent, recording the
     * state and monkey histories of each block, then saves the histories in a
     * .dsc76 file and the agent in a data set file, both prefixed "StopTest_".
     * Save failures are reported on standard error and do not propagate.
     *
     * @param a      Agent to test.
     * @param a_name File name the agent was loaded from; used to derive the
     *               names of the output files.
     * @param stop   Number of blocks to run.
     */
    static void testAgent(AbstractObservableAgent a, String a_name, int stop) {
        System.out.println("Testing " + a_name);

        //Create space for all data
        //NOTE(review): the collection is created with size 5 while blocks are
        //stored from index 25 on; presumably it grows on demand - confirm.
        DataSetCollection dataCol = new DataSetCollection(5);

        //Original author's note: multiple 4 minutes block (about 40 trials at
        //6s per trial) - the timing is not verifiable from this file.
        for (int i = 0; i < stop; i++) {
            //Fresh environment and data collector for every block
            SingleAgentEnvironment env = new SingleAgentEnvironment(0, 0, 10 * 60 * 2);
            DataSetCollector dc = new DataSetCollector();
            env.addObserver(dc);
            a.addObserver(dc);
            try {
                //Run
                env.go(a, new ExperimentControlState(1000));
                //Collect data
                dataCol.setData("ControlState", FIRST_BLOCK_NUMBER + i, dc.StateHistory);
                dataCol.setData("ControlMonkey", FIRST_BLOCK_NUMBER + i, dc.MonkeyHistory);
            } finally {
                //Always detach the collector so the reused agent does not keep
                //notifying it, even when the run fails
                a.deleteObserver(dc);
            }
            if (i % 100 == 0) {
                System.out.println();
            }
        }

        String name = OUTPUT_PREFIX + a_name;
        String historyName = historyFileName(name);
        try {
            System.out.print("Saving history in " + historyName + " ...");
            stimulusdelayreward.Tools.saveDataSetCollection(historyName, dataCol);
            DataSet dat = new DataSet();
            dat.setData("Agent", a);
            dat.setData("BlockCount", Integer.valueOf(stop));
            dat.setData("BlockNumber", Integer.valueOf(FIRST_BLOCK_NUMBER));
            lnsc.Tools.saveDataSet(name, dat);
            System.out.println(" done!");
        }
        catch (Exception e) {
            System.err.println(e.toString());
            System.err.println("Can't save agent " + name);
        }
    }

    /**
     * Derives the history (.dsc76) file name from a network (.ds76) file name.
     * Only a trailing ".ds76" is swapped, as literal text. If the name does
     * not end in ".ds76", ".dsc76" is appended so that the history file never
     * shares a path with the agent file saved under the unmodified name.
     *
     * @param name Network file name.
     * @return     Matching history file name.
     */
    private static String historyFileName(String name) {
        if (name.endsWith(NETWORK_SUFFIX)) {
            int stemLength = name.length() - NETWORK_SUFFIX.length();
            return name.substring(0, stemLength) + HISTORY_SUFFIX;
        }
        return name + HISTORY_SUFFIX;
    }

    /**
     * Loads the object stored under "Agent" in the given data set file.
     *
     * @param filename File to load the data set from.
     * @return         The agent, or <code>null</code> if loading or casting
     *                 failed (the failure is reported on standard error).
     */
    static AbstractObservableAgent loadAgent(File filename) {
        try {
            DataSet dat = lnsc.Tools.loadDataSet(filename.getAbsolutePath());
            return (AbstractObservableAgent) dat.getData("Agent");
        } catch (Exception e) {
            System.err.println(e.toString());
            System.err.println("Can't load agent " + filename.getName());
        }
        return null;
    }

    /**
     * Lists the .ds76 files of a directory.
     *
     * @param path Directory to look into.
     * @return     The files found, or <code>null</code> if the path is not a
     *             directory, cannot be listed, or contains no such file (the
     *             reason is reported on standard error).
     */
    static File[] getNetworksList(String path) {
        //Open directory
        File f = new File(path);
        if (!f.isDirectory()) {
            System.err.println("Not a directory!");
            return null;
        }

        //List files; listFiles returns null when an I/O error occurs
        File[] fs = f.listFiles(new ExtensionFilter(new String[] {"ds76"}));
        if (fs == null) {
            System.err.println("Can't list directory " + path);
            return null;
        }
        if (fs.length == 0) {
            System.err.println("No '.ds76' files found!");
            return null;
        }

        //TODO (original author's note): open them and check the success flag
        return fs;
    }
}