package stimulusdelayrewardanalyzer;

import java.io.*;
import java.sql.*;
import lnsc.*;
import lnsc.lstm.*;

/** Main conversion thread to convert .dsc76 files into ODBC mdb.
 * This class can work on its own thread. It reads a given directory .dsc76
 * files made using StimulusDelayReward_28Jul06_1 based code and extract a
 * specific train or test block into a .txt file.
 *
 * Dump files are renaed to :
 *      m_TargetPath + "\\" + id + "_export_" + m_BlockIndex + xstr + ".txt"
 * Dumping is done by Tools.dumpBlock().
 *
 * @author Francois Rivest
 * @version 1.0
 */


public class Convert2DB implements Runnable {

	public final static int TRAIN = 0;
	public final static int TEST = 1;
	public final static int CONTROL = -1;

	/*********************************************************************/
    //Private fields

	/** Name of the database to fill. */
	protected String m_DBName;

	/** Path where the .dsc files are. */
	protected String m_DscPath;

	/** Date of the training sesion. */
	protected Date m_Date;

	/** Model of network trained. */
	protected int m_Model;

	/** Number of .dsc76 files completely processed. */
	protected int m_NetworkCount;

	/** Number of .dsc76 files in total. */
	protected int m_AllNetworkCount;

	/** Number of blocks within current .dsc file completely processed. */
	protected int m_BlockCount;

	/** Number of all blocks within current .dsc76 files. */
	protected int m_AllBlockCount = 25;

	/** Output text string buffer. */
	protected String m_Buff = "";

	/** Conneciton to databse. */
	protected Connection m_Conn;

	/*********************************************************************/
	//Constructors

	/** Construct a Runnable class that analyse a whole directory of .dsc76
	 * files to convert them into an access database.
	 * @param   newDatabaseName   Name of the access databse to write to.
	 * @param   newDscFilesPath   Name of the directory to read .dsc files from.
	 * @param   newDate           Date of training
	 * @param   newModel          Model of networks trained
	 */
	public Convert2DB(String newDatabaseName, String newDscFilesPath,
					  Date newDate, int newModel) {
		m_DBName = newDatabaseName;
		m_DscPath = newDscFilesPath;
		m_Date = newDate;
		m_Model = newModel;
	}

	/*********************************************************************/
	//Properties

	/** Number of .dsc76 files completely processed. */
	public int getNetworkCount() {
		return m_NetworkCount;
	}

	/** Number of blocks within current .dsc file completely processed. */
	public int getBlockCount() {
		return m_BlockCount;
	}

	/** Number of .dsc76 files in total. */
	public int getAllNetworkCount() {
		return m_AllNetworkCount;
	}

	/** Number of all blocks within current .dsc76 files. */
	public int getAllBlockCount() {
		return m_AllBlockCount;
	}

	/** Returns the current text buffer and empty it. */
	public synchronized String getText() {
		String tmp = m_Buff;
		m_Buff = "";
		return tmp;
	}

	/** Add some string to the text buffer. */
	protected synchronized void dbout(String text) {
		//m_Buff += text;
                System.out.print(text);//16 may 2008
	}

	/*********************************************************************/
	//Helper function

	/** Writes a new network row in the database.
	 * Returns the highest block number if the network exist, 0 if it does not.*/
	protected int createNetwork(int netID, int model) {

		//First check if it exist
		try {
			Statement stmt = m_Conn.createStatement();
			String select = "SELECT nNetworkID FROM _Networks WHERE " +
							"nNetworkID=" + netID;
			//dbout(select + "\n");
			ResultSet res = stmt.executeQuery(select);
			if (res.next()) {
				dbout("Network " + netID + " already exists!\n");
				//Get block number
				stmt = m_Conn.createStatement();
				select = "SELECT MAX(nBlockNumber) AS m FROM _Blocks WHERE " +
							"nNetworkID=" + netID;
				//dbout(select + "\n");
				res = stmt.executeQuery(select);
				if (!res.next()) {
					return 0;
				}
				return res.getInt("m");
			}
		} catch (SQLException e) {
			System.err.println(e.getMessage());
			e.printStackTrace();
			dbout(" -Unexpected problem!-");
			return 0;
		}
		//If it does not exist create it
		try {
			Statement stmt = m_Conn.createStatement();
			String insert = "INSERT INTO _Networks " +
				"(nNetworkModelID, nNetworkID) " +
				"VALUES (" + model + ", " + netID  + ")";
			//dbout(insert + "\n");
			stmt.execute(insert); //stmt.RETURN_GENERATED_KEYS does not work in access
								  //stmt.getGeneratedKeys();
		  } catch (SQLException e) {
			  System.err.println(e.getMessage());
			  e.printStackTrace();
			  dbout(" -Unexpected problem!-");
		  }
		  return 0;
	}

	/** Writes a new block row in the database and returns its new ID. */
	protected int createBlock(int netID, int trainDelay, Date date,
							  int blockNumber, int blockType) {
		try {
			Statement stmt = m_Conn.createStatement();
			String insert = "INSERT INTO _Blocks " +
				"(nNetworkID, nBlockTrainDelay, nBlockTypeID, " +
				"nBlockNumber, dBlockDate) VALUES (" + netID + ", " +
				trainDelay + ", " + blockType + ", " + blockNumber +
				", '" + date.toString() + "')";
			//dbout(insert + "\n");
			stmt.execute(insert); //stmt.RETURN_GENERATED_KEYS does not work in access
			                      //stmt.getGeneratedKeys();
			//so I must re-extract the key, hoping it is unique
			String select = "SELECT nBlockID FROM _Blocks WHERE " +
							"nNetworkID=" + netID + " AND " +
							"nBlockTrainDelay=" + trainDelay + " AND " +
							"nBlockNumber=" + blockNumber;
			//dbout(select + "\n");
			ResultSet res = stmt.executeQuery(select);
			if (!res.next()) {
				dbout(" -Problem: No key generated for block!- ");
				return 0;
			}
			int blockUID = res.getInt("nBlockID");
			if (res.next()) {
				dbout(" -Problem: Possible block duplicate!- ");
				return 0;
			}
			return blockUID;
		} catch (SQLException e) {
			System.err.println(e.getMessage());
			e.printStackTrace();
			dbout(" -Unexpected problem!-");
			return 0;
		}
	}

	/** Writes a new trial row in the database and returns its new ID. */
	protected int createTrial(int blockUID, int trialNumber,
							  int trialDelay, int trialOnset,
							  int trialType) {
		try {
			Statement stmt = m_Conn.createStatement();
			String insert = "INSERT INTO _Trials (" +
				"nBlockID, nTrialNumber, tTrialDelay, tTrialOnset, nTrialTypeID" +
				") VALUES (" +
				blockUID + ", " + trialNumber + ", " + trialDelay + ", " +
				trialOnset + ", " + trialType + ")";
			//dbout(insert + "\n");
			stmt.execute(insert); //stmt.RETURN_GENERATED_KEYS does not work in access
								  //stmt.getGeneratedKeys();
			//so I must re-extract the key, hoping it is unique
			String select = "SELECT nTrialID FROM _Trials WHERE " +
							"nBlockID=" + blockUID + " AND " +
							"nTrialNumber=" + trialNumber;
			//dbout(select + "\n");
			ResultSet res = stmt.executeQuery(select);
			if (!res.next()) {
				dbout(" -Problem: No key generated for trial!- ");
				return 0;
			}
			int trialUID = res.getInt("nTrialID");
			if (res.next()) {
				dbout(" -Problem: Possible trial duplicate!- ");
				return 0;
			}
			return trialUID;
		} catch (SQLException e) {
			System.err.println(e.getMessage());
			e.printStackTrace();
			dbout(" -Unexpected problem!-");
			return 0;
		}
	}

	/** Dumps signal from a block. */
	protected void dumpBlock(int blockUID, DataSetCollection state, DataSetCollection monkey) {

		boolean prevIsInTrial = false;

		for (int i = 0; i < state.getDataSetCount(); i++) {

			//Extract trial/state information
			int time = ( (Integer) state.getData(Tools.STEP, i)).intValue() *
				200; //tTime: block time (ms)
			boolean isInTrial = ( (Boolean) state.getData(Tools.IS_IN_TRIAL, i)).
				booleanValue(); //bIsInTrial: is in trial
			int trialNumber = ( (Integer) state.getData(Tools.CURRENT_TRIAL, i)).
				intValue(); //nTrialNumber: trial in block
			int trialTime = ( (Integer) state.getData(Tools.CURRENT_STEP, i)).
				intValue() * 200; //tTrialTime: time since onset (ms)
			int trialOnset = time - trialTime; //tTrialOnset: time of onset )ms)
			int trialDelay = ( (Integer) state.getData(Tools.CURRENT_DELAY, i)).
				intValue() * 200; //tTrialDelay: trial delay in (ms)
			int trialType = -1; //default is unknown because older simulations did not saved it.
			if (state.hasData(Tools.CURRENT_TRIALTYPE)) {
				trialType = ( (Integer) state.getData(Tools.CURRENT_TRIALTYPE,
					i)).intValue(); //nTrialTypeID: trial type
			}

			//Save trial info on new trial here
			if (isInTrial && !prevIsInTrial) {
				//Check for last test trial switch
				if (i == state.getDataSetCount() - 1) {
					System.out.println("Skipping sample " + time +
									   "ms of trial " + trialNumber +
									   " in blockUID " + blockUID);
					return; //skip last sample
				}
				int trialUID = createTrial(blockUID, trialNumber, trialDelay,
										   trialOnset, trialType);
				if (trialUID == 0) {
					break;
				}
			}
			prevIsInTrial = isInTrial;

			//Extract signals:
			double cs = ( (Double) state.getData(Tools.STIMULUS, i)).
				doubleValue();
			double us = ( (Double) state.getData(Tools.REWARD, i)).doubleValue();
			double[] bg_stim = (double[]) monkey.getData(Tools.STIMULUS, i);
			//The standard is
			//cs, LSTM(P(us)), MB1C1, MB1C2, MB2C1, MB2C2
			//Some very old simulations had US to TD
			//cs, us, LSTM(P(us)), MB1C1, MB1C2, MB2C1, MB2C2
			//Some TDBias simulations (30nov07) have a bias to TD
			//Bias, cs, LSTM(P(us)), MB1C1, MB1C2, MB2C1, MB2C2
			//Some LTSM2Outputs simulations (14Jan08) have 2 LSTM outputsd
			//cs, LSTM(P(cs)), LSTM(P(us)), MB1C1, MB1C2, MB2C1, MB2C2
			//The common point is that LSTM->BG are the last elements!
			double lstm = bg_stim[bg_stim.length - 5];
			double lstmCS = bg_stim[bg_stim.length - 6]; //if LSTM has 2 outputs
			double p = ( (Double) monkey.getData(Tools.PREDICTION, i)).
				doubleValue();
			double da = ( (Double) monkey.getData(Tools.DOPAMINE, i)).
				doubleValue();

			//double[][] w = (double[][]) monkey.getData(Tools.CRITICS_WEIGHTS, i);        //Final value should be suffient, not sampling
			//double[][] wc = (double[][]) monkey.getData(Tools.CRITICS_WEIGHTS_CHANGE, i);//I don't have this data
			//System.out.println(time + "\tw\t" + LinearAlgebra.toString(w) + "\n");

			//Extract lstm signals
			DataSet ds = (DataSet) monkey.getData("LSTM", i);
			double[] inputGates = (double[]) ds.getData(LSTMDataNames.
				LSTM_INPUT_GATES);
			double[] forgetGates = (double[]) ds.getData(LSTMDataNames.
				LSTM_FORGET_GATES);
			double[] outputGates = (double[]) ds.getData(LSTMDataNames.
				LSTM_OUTPUT_GATES);
			double[] states = (double[]) ds.getData(LSTMDataNames.
				LSTM_INTERNAL_STATES);
			double[] acts = (double[]) ds.getData(LSTMDataNames.
												  LSTM_INTERNAL_ACTIVATIONS);
			//Changes on 15Jan08
			boolean has2outputs = ( (double[]) ds.getData(DataNames.
				OUTPUT_PATTERNS)).length == 2 ? true : false;
			double[] output = ( (double[]) ds.getData(DataNames.OUTPUT_PATTERNS));
			double[] target = new double[] {
				0, 0}
				, error = new double[] {
				0, 0};
			if (i != state.getDataSetCount() - 1) {
				ds = (DataSet) monkey.getData("LSTM", i + 1);
				target = ( (double[]) ds.getData(DataNames.TARGET_PATTERNS));
				error = ( (double[]) ds.getData(DataNames.ERROR_PATTERNS));
			}

			//
			//(bg_Stim.length == 7) (bg_Stim.length == 7)
			//	 InputPatterns OutputPatterns
			//   LSTMOutputGates LSTMForgetGates LSTMInputGates
			//   LSTMInternalActivations LSTMInternalStates
			//   SumSquaredError ErrorPatterns  TargetPatterns

			try {
				Statement stmt = m_Conn.createStatement();
				String[] signals = {
						"nBlockID",
						"tTime",
						"bIsInTrial",
						"nTrialNumber",
						"tTrialTime",
						//"tTrialOnset",
						//"tTrialDelay",
						"fCS",
						"fUS",
						"fLSTM",
						"fP",
						"fDA",
						"fLSTMInputGate1",
						"fLSTMInputGate2",
						"fLSTMForgetGate1",
						"fLSTMForgetGate2",
						"fLSTMOutputGate1",
						"fLSTMOutputGate2",
						"fLSTMState11",
						"fLSTMState12",
						"fLSTMState21",
						"fLSTMState22",
						"fLSTMAct11",
						"fLSTMAct12",
						"fLSTMAct21",
						"fLSTMAct22",
						"fLSTMOutput",
						"fLSTMTarget",
						"fLSTMError",
						"fState11",
						"fState12",
						"fState21",
						"fState22"};
				if (has2outputs) {
					signals = concatenateStringArrays(signals, new String[]
						{"fLSTMcs",
						 "fLSTMOutputcs",
						 "fLSTMTargetcs",
						 "fLSTMErrorcs"});
				}
				String[] values = {
						Integer.toString(blockUID),
						Integer.toString(time),
						isInTrial ? "1" : "0",
						Integer.toString(trialNumber),
						Integer.toString(trialTime),
						//Integer.toString(trialOnset),
						//Integer.toString(trialDelay),
						Double.toString(cs),
						Double.toString(us),
						Double.toString(lstm),
						Double.toString(p),
						Double.toString(da),
						Double.toString(inputGates[0]),
						Double.toString(inputGates[1]),
						Double.toString(forgetGates[0]),
						Double.toString(forgetGates[1]),
						Double.toString(outputGates[0]),
						Double.toString(outputGates[1]),
						Double.toString(states[0]),
						Double.toString(states[1]),
						Double.toString(states[2]),
						Double.toString(states[3]),
						Double.toString(acts[0]),
						Double.toString(acts[1]),
						Double.toString(acts[2]),
						Double.toString(acts[3]),
						Double.toString(output[output.length - 1]),
						Double.toString(target[target.length - 1]),
						Double.toString(error[error.length - 1]),
						Double.toString(bg_stim[bg_stim.length - 4]),
						Double.toString(bg_stim[bg_stim.length - 3]),
						Double.toString(bg_stim[bg_stim.length - 2]),
						Double.toString(bg_stim[bg_stim.length - 1])};
					if (has2outputs) {
						values = concatenateStringArrays(values, new String[]
							{Double.toString(lstmCS),
							 Double.toString(output[0]),
							 Double.toString(target[0]),
							 Double.toString(error[0])});
					}
					//Run query
					String insert = SQLTools.buildSQLInsert("AllSignals",
						signals, values);
					//dbout(insert + "\n");
					stmt.execute(insert);
				}
				catch (SQLException e) {
					System.err.println(e.getMessage());
					e.printStackTrace();
					dbout(" -Unexpected problem!-");
					return;
				}

			}
			//DataSet ds = (DataSet) monkey.getData("LSTM", 3);
			//String[] lst = ds.dataNamesList();
			//for (int j=0; j<lst.length; j++)
			//{
			//	dbout(" " + lst[j]);
			//}
			//dbout("\n");

		}

		/*********************************************************************/
		//Runnable interface

		/** Function called by THREAD.start(). */
		public void run() {

			//Default for the moment
			int trainDelay = 1000; //in ms
			dbout("Model " + m_Model + ", Date " + m_Date.toString() +
				  ", Delay " + trainDelay + "ms\n");

			//Open connection
			m_Conn = Tools.openConnection(m_DBName);
			if (m_Conn == null) {
				dbout("Opening " + m_DBName + " failed!\n");
				return;
			}
			else {
				dbout("Opening " + m_DBName + " succeed!\n");
			}

			//Gather files
			File[] dscList = Tools.listDataSetCollections(m_DscPath);
			if (dscList == null) {
				dbout("Opening " + m_DscPath + " failed!\n");
				return;
			}
			else {
				dbout("Searching " + m_DscPath + ": " + dscList.length +
					  " files found!\n");
			}
			m_AllNetworkCount = dscList.length;

			//For all files
			for (int i = 0; i < m_AllNetworkCount; i++) {
				//Now check if this one is kept
				int netID = Integer.parseInt(Tools.getID(dscList[i].getName()));
                                System.out.println(netID);
				/*if (!isKept(dscList[i])) {//16 may 2008
                                  	dbout(netID + " not kept!\n");
					continue;
				}*/
				//Load the file
				dbout("Loading " + netID + " ... ");
				DataSetCollection dsc;
				try {
					dsc = Tools.loadDataSetColl(dscList[i]);
					dbout("done!\n");
				}
				catch (Exception e) {
					dbout(e.toString() + " Abort!");
					break;
				}
			        m_BlockCount++;

				//I should first check if the network as already been saved,
				//but here I assume it is the first time
				int blockCount = createNetwork(netID, m_Model);

				//for every block (train and test)
				for (int j = 0; j < dsc.getDataSetCount(); j++) {

					//Check for initial test block
					if (dsc.hasData("ControlState") &&
						(dsc.getData("ControlState", j) != null)) {
						blockCount++;
						dbout("Converting control block " + j +
							  " as block " + blockCount);
						int blockUID = createBlock(netID, trainDelay, m_Date,
							blockCount, CONTROL);
						dbout(" UID:" + blockUID + "\n");
						if (blockUID != 0) {
							dumpBlock(blockUID,
									  (DataSetCollection) dsc.getData(
								"ControlState", j),
									  (DataSetCollection) dsc.getData(
								"ControlMonkey", j));
						}
						m_BlockCount++;
					}

					//Check for test block
					if (dsc.hasData("TestState") && (dsc.getData("TestState", j) != null)) {
						blockCount++;
						dbout("Converting test block " + j + " as block " +
							  blockCount);
						int blockUID = createBlock(netID, trainDelay, m_Date,
							blockCount, TEST);
						dbout(" UID:" + blockUID + "\n");
						if (blockUID != 0) {
							dumpBlock(blockUID,
									  (DataSetCollection) dsc.getData(
								"TestState", j),
									  (DataSetCollection) dsc.getData(
								"TestMonkey", j));
						}
						m_BlockCount++;
					}

					//check for train block
					if (dsc.hasData("State") && (dsc.getData("State", j) != null)) {
						blockCount++;
						dbout("Converting train block " + j + " as block " +
							  blockCount);
						int blockUID = createBlock(netID, trainDelay, m_Date,
							blockCount, TRAIN);
						dbout(" UID:" + blockUID + "\n");
						if (blockUID != 0) {
							dumpBlock(blockUID,
									  (DataSetCollection) dsc.getData("State",
								j),
									  (DataSetCollection) dsc.getData("Monkey",
								j));
						}
						m_BlockCount++;
					}
				}

				//Track progress;
				synchronized (this) {
					m_BlockCount = 0;
					m_NetworkCount++;
				}
				;

			}

			//Close connection
			try {
				m_Conn.close();
			}
			catch (SQLException e) {
				dbout("Unexpected exception when closing connection!\n");
				e.printStackTrace();
			}
		}

	/** Given a .dsc file, check in the corresponding .ds file for the
	 * kept variable.
	 */
	boolean isKept(File f)
	{
		//Get .ds file name
		String temp = f.getAbsolutePath();
		//System.out.println(temp);
		File fs = new File(temp.replaceFirst("Result", "TrainedAgent").replaceFirst("dsc", "ds"));
		//System.out.println(fs.getName());
		//Check if file exist
		if (!fs.exists()) {
			System.err.println(fs.getName() + " not found!");
			return false;
		}
		//Read kept flag and return it
		//System.out.print(fs.getName());
		try {
			DataSet dat = lnsc.Tools.loadDataSet(fs.toString());
			boolean thissuc = ( (Boolean) dat.getData("Kept")).booleanValue();
			return thissuc;
		}
		catch (Exception exc) {
			System.err.println(exc.toString());
			return false;
		}
	}

	//HELPER function

	String[] concatenateStringArrays(String[] s1, String[] s2)
	{
		String[] ret = new String[s1.length+s2.length];
		for (int i=0; i<s1.length; i++)
		{
			ret[i] = s1[i];
		}
		for (int i=0; i<s2.length; i++)
		{
			ret[s1.length+i] = s2[i];
		}
		return ret;
	}
}