package stimulusdelayreward;
import lnsc.page.*;
import lnsc.lstm.*;
import lnsc.*;
import grsnc.binb.*;
/**
* This is the basic monkey using the full model.
*
* It uses and ActorCritic model of the basal ganglia and it uses a
* eligibility traces driven version of LSTM (ETLSTM1) as frontal cortex.
* Both system runs in parallel and basal ganglia receives input from LSTM
* output at previous time step. LSTM are updates on their next input. LSTM
* are trained to predict their next inputs.
*
* DA signal from BG is used to modulate the LSTM learning rate.
*
*
* @author Francois Rivest
* @version 1.0
*/
public class ActorCritic_PDAETLSTM_Monkey2 extends AbstractObservableAgent {
/*********************************************************************/
//Serial Version UID
/** Serial version UID. */
static final long serialVersionUID = 8854214947734172413L;
/*********************************************************************/
//Private fields (current state)
/** AC model. */
protected Agent m_ACMModel;
/** LSTM network model. */
protected ETLSTMNetwork1 m_LSTMNet;
/** LSTM trainer. */
protected OnlineSPMSELearning m_Trainer;
/** LSTM state representation. */
protected StateRepresentation m_LSTMStateRep;
/** AC state representation. */
protected StateRepresentation m_ACMStateRep;
/** AC state representation with previous LSTM. */
protected StateRepresentation m_ACMExtendedStateRep;
/** LSTM total outputs count. */
protected int m_LSTMCount;
/** Base LSTM learning rate. */ /***DA2***/
protected double m_LSTMlr;/***DA2***/
/** Last LSTM output, to use as input ot BG. */
protected transient double[] m_PrevLSTM;
/** Latest monkeys state. */
protected transient DataSet m_LatestState = null;
static private String[] m_RecordList = new String[] {
LSTMDataNames.INPUT_PATTERNS,
LSTMDataNames.OUTPUT_PATTERNS,
LSTMDataNames.TARGET_PATTERNS,
LSTMDataNames.ERROR_PATTERNS,
LSTMDataNames.SUM_SQUARED_ERROR,
LSTMDataNames.LSTM_INTERNAL_STATES,
LSTMDataNames.LSTM_INTERNAL_ACTIVATIONS,
LSTMDataNames.LSTM_INPUT_GATES,
LSTMDataNames.LSTM_FORGET_GATES,
LSTMDataNames.LSTM_OUTPUT_GATES
};
/*********************************************************************/
//Constructors
public ActorCritic_PDAETLSTM_Monkey2(int blockCount, int cellPerBlock,
boolean inSquash, boolean outSquash,
boolean gate2gate, boolean in2out,
double LSTMlr, double ACMlr,
int ACmodel, StateRepresentation ACStateRep,
double lambda, boolean oppSignResetTraces) {
ETLSTMFactory1 fact = new ETLSTMFactory1(
2, blockCount, cellPerBlock, inSquash, outSquash, 1, new LogisticUnit(),
gate2gate, true, in2out, false, 1, //gate2gate, bias2output, input2output, gate2output, outputfactor
lambda, oppSignResetTraces); //lambda, oppsignresettraces
m_LSTMNet = (ETLSTMNetwork1) fact.createUnit();
m_Trainer = new OnlineSPMSELearning(m_LSTMNet, LSTMlr, 1);
//m_LSTMCount = blockCount*(3+2*cellPerBlock) + m_LSTMNet.getOutputCount(); /***10Mar06***/
m_LSTMCount = blockCount*cellPerBlock + m_LSTMNet.getOutputCount(); /***10Mar06***/
m_LSTMlr = LSTMlr; /***DA2***/
m_LSTMStateRep = new TwoSignalStateRepresentation();
m_ACMStateRep = ACStateRep;//20080208
m_ACMExtendedStateRep = new OfflineStateRepresentation(m_ACMStateRep.getOutputCount()+m_LSTMCount); /***10Mar06***/
/*if (ACmodel == 1) {
m_ACMModel = new Rivest05(1, 1, m_ACMExtendedStateRep, ACMlr, .1);
} else if (ACmodel == 2) {
m_ACMModel = new Schultz97(1, 1, m_ACMExtendedStateRep, ACMlr, .1);
} else if (ACmodel == 3) {
m_ACMModel = new Pan05(1, 1, m_ACMExtendedStateRep, ACMlr, .1);
} else*/ if (ACmodel == 4) {
m_ACMModel = new Rivest06(1, 1, m_ACMExtendedStateRep, ACMlr, .1);
} else {
throw new RuntimeException("Unknown model!");
}
}
/*********************************************************************/
//Interface implementation
public void newEpisode(State newState) {
m_Trainer.reset();
m_ACMModel.newEpisode(newState);
m_PrevLSTM = new double[m_LSTMCount];/***10Mar06***/
//This assumes stateless representations
}
public void returnReward(State resultState, double reward) {
//useless in this framework
}
public Action requestAction(State currentState) {
//------------------------------------
//The first step is to process the ACM
//--Create ACM representation
double[] acm_input = LinearAlgebra.concatenateVectors(m_ACMStateRep.getRepresentation(currentState), m_PrevLSTM);
((OfflineStateRepresentation) m_ACMExtendedStateRep).setRep(acm_input);
//--Process ACM
m_ACMModel.returnReward(currentState, ((MonkeyObservableState)currentState).getRewardSignal());
Action a = m_ACMModel.requestAction(currentState);
//--Collect data
m_LatestState = m_ACMModel.toDataSet();
double da = ((Double) m_LatestState.getData(Rivest06.DOPAMINE)).doubleValue();
//--------------------------------------
//The second step is to process the LSTM
//--Create ACM representation
double[] lstm_input = m_LSTMStateRep.getRepresentation(currentState);
m_Trainer.setLearningRate(m_LSTMlr*(1+Math.abs(da))); /***DA2***/
//--Process LSTM model
DataSet lstm_data = m_Trainer.train(lstm_input, m_RecordList);
//--Collect data
m_LatestState.setData("LSTM", lstm_data);
m_PrevLSTM = (double[]) lstm_data.getData(DataNames.OUTPUT_PATTERNS);
/***10Mar06***/
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INTERNAL_ACTIVATIONS));
m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INTERNAL_STATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INPUT_GATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_FORGET_GATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_OUTPUT_GATES));
m_PrevLSTM = bound(m_PrevLSTM);
/***10Mar06***/
//---------------------------
//The third step notification
//--Notify obervers
setChanged();
notifyObservers();
//--Return null action
return a;
}
public void endEpisode(State finalState) {
//------------------------------------
//The first step is to process the ACM
//--Create ACM representation
double[] acm_input = LinearAlgebra.concatenateVectors(m_ACMStateRep.getRepresentation(finalState), m_PrevLSTM);
((OfflineStateRepresentation) m_ACMExtendedStateRep).setRep(acm_input);
//--Process ACM
m_ACMModel.returnReward(finalState, ((MonkeyObservableState)finalState).getRewardSignal());
Action a = m_ACMModel.requestAction(finalState);
//--Collect data
m_LatestState = m_ACMModel.toDataSet();
double da = ((Double) m_LatestState.getData(Rivest06.DOPAMINE)).doubleValue();
//--------------------------------------
//The second step is to process the LSTM
//--Create ACM representation
double[] lstm_input = m_LSTMStateRep.getRepresentation(finalState);
m_Trainer.setLearningRate(m_LSTMlr*(1+Math.abs(da))); /***DA2***/
//--Process LSTM model
DataSet lstm_data = m_Trainer.train(lstm_input, m_RecordList);
//--Collect data
m_LatestState.setData("LSTM", lstm_data);
m_PrevLSTM = (double[]) lstm_data.getData(DataNames.OUTPUT_PATTERNS);
/***10Mar06***/
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INTERNAL_ACTIVATIONS));
m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INTERNAL_STATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_INPUT_GATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_FORGET_GATES));
//m_PrevLSTM = LinearAlgebra.concatenateVectors(m_PrevLSTM, (double[]) lstm_data.getData(LSTMDataNames.LSTM_OUTPUT_GATES));
m_PrevLSTM = bound(m_PrevLSTM);
/***10Mar06***/
//---------------------------
//The third step notification
//--Notify obervers
setChanged();
notifyObservers();
}
/*********************************************************************/
//toDataSet
public DataSet toDataSet() {
return m_LatestState;
}
/*********************************************************************/
//toString
public String toString()
{
return m_ACMModel.toString() + "\n" + m_LSTMNet.toString();
}
/*********************************************************************/
//Helper
protected double[] bound(double[] p)
{
for (int i=0; i<p.length; i++)
{
p[i] = bound(p[i]);
}
return p;
}
protected double bound(double p)
{
double ubound = 1;
double lbound = 0;
return Math.max(Math.min(p,ubound),lbound);
}
}