package stimdelrew;

import stimulusdelayreward.*;
import lnsc.page.*;
import lnsc.DataSetCollection;
import lnsc.DataSet;
import java.util.Random;
import java.io.*;

/** Main routine to test 30 random networks on 2 control blocks (CS only or US only).
 *  It saves successful networks in .ds76 files and their recorded data in .dsc76 files.
 *
 * @author Francois Rivest
 * @version 1.0
 */
public class PreTraining {

    //Outcome codes; in this version run(...) only returns KEPT or ERROR,
    //and CRASH is counted directly in main's catch block.
    static final int ERROR = -2;
    static final int CRASH = -1;
    static final int FAIL = 0;
    static final int LEARN = 1;
    static final int KEPT = 3;

    /** No arguments required. */
    public static void main(String[] args) {

        //Generate 30 fully successful networks
        int total = 0;
        int error = 0;
        int crash = 0;
        int fail = 0;
        int learn = 0;
        int kept = 0;
        while (kept < 30) {
            System.out.println("Trial " + (total+1));
            try {
                switch (run(.5, 2, true, true)) {
                    case FAIL:
                        fail++;
                        break;
                    case KEPT:
                        kept++;
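                        //no break: a kept network also counts as having learned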
                    case LEARN:
                        learn++;
                        break;
                    case ERROR:
                    default:
                        error++;
                }
            } catch (Exception e) {
                crash++;
                System.out.println("run(.5, 2, true, true) crashed: " + e);
            }
            total++;
            System.out.println("So far " + (kept));
        }

        try {
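            //Open Result.log in append mode and write the summary counts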
            PrintStream out = new PrintStream(new FileOutputStream("Result.log", true));
            out.println("Total: " + total);
            out.println("Error: " + error);
            out.println("Crash: " + crash);
            out.println("Fail: " + fail);
            out.println("Learn: " + learn);
            out.println("Kept: " + kept);
            out.close();
        } catch (Exception e) {
            System.err.println("Can't write log!");
        }

        System.out.println("Total: " + total);
        System.out.println("Error: " + error);
        System.out.println("Crash: " + crash);
        System.out.println("Fail: " + fail);
        System.out.println("Learn: " + learn);
        System.out.println("Kept: " + kept);



    }

    public static int run(double lr, int uc, boolean inSquash, boolean outSquash) {

        double learningRate = lr;
        int unitCount = uc;

        boolean gate2gate = false;
        boolean in2out = false;

        //Success-tracking flags (unused in this pre-training routine;
        //presumably left over from the full training version)
        boolean success = false;
        boolean lastSuccess = false;
        int firstSuccess = -1;
        boolean prevLastSuccess = false; //Added 18Sep06 to use the previous-to-last training block

        //Create the agent
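        //(ActorCritic_PDAETLSTM_Monkey2 is an actor-critic model built on an
        // LSTM-style network; the two unitCount arguments presumably set the
        // numbers of memory blocks and cells)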
        AbstractObservableAgent a =
            new ActorCritic_PDAETLSTM_Monkey2(
                              unitCount,unitCount,
                              inSquash, outSquash,
                              gate2gate, in2out,
                              learningRate, .1,
                              //TD: Rivest06, {no bias, cue only}->AC,
                              4, new FlexibleSignalStateRepresentation(false, true, false),
                              .8, true);//LSTM: e-trace, reset


        //Create space for all data
        DataSetCollection dataCol = new DataSetCollection(22);
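        //(22 columns are allocated, but only indices 0 and 1 are filled by
        // the two control blocks below)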

        //int stop = 20;
        //int lowbound = stop-21;

        //Run two 4-minute blocks (about 40 trials at 6 s per trial)
        for (int i=0; i<2; i++)
        {

            //Environment
            SingleAgentEnvironment env = new SingleAgentEnvironment(0,0,10*60*2);
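            //(the third constructor argument, 10*60*2 = 1200, presumably
            // bounds the number of time steps in one block)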
            //Data collector
            DataSetCollector dc = new DataSetCollector();
            env.addObserver(dc);
            a.addObserver(dc);
            //Run
            env.go(a, new ExperimentControlState(1000));
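            //(env.go runs the block; the DataSetCollector registered above
            // records the state and agent histories read below)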
            //Collect data
            //if (i > lowbound) {
            dataCol.setData("ControlState", i, dc.StateHistory);
            dataCol.setData("ControlMonkey", i, dc.MonkeyHistory);
            //}
            //Clean collector
            a.deleteObserver(dc);
            if (i % 100 == 0) { System.out.println(); }
        }
        //Tools.dumpV(dc);
        //Tools.dumpE(dc);
        //Tools.dumpR(dc);

        //Always save the trained agent and its history
        Random rnd = new Random(System.currentTimeMillis());
        long id = rnd.nextInt(); //random file tag (may be negative)
        String name = "Result" + id + ".dsc76";
        try {
            System.out.print("Saving history in " + name + " ...");
            Tools.saveDataSetCollection(name, dataCol);
            DataSet dat = new DataSet();
            dat.setData("Agent", a);
            lnsc.Tools.saveDataSet("TrainedAgent" + id + ".ds76", dat);
            System.out.println(" done!");
            return KEPT;
        } catch (Exception e) {
            System.err.println(e.toString());
            return ERROR;
        }

    }
}