package lnsc.lstm;
import lnsc.*;
/**
* <p> Factory for LSTM networks that use eligibility traces to find
* associations across time more rapidly. </p>
*
* <p> Special memory blocks are used. These blocks maintain memory traces of
* their inputs and use them in derivatives instead of using the raw inputs.
* Traces are built as if the inputs were bounded to [-1, 1]. For a given input
* x_t, its trace is e_t = bound(lambda * e_{t-1} + x_t, [-1, 1]), unless
* e_{t-1} and x_t have opposite signs and OppSignResetTraces is true, in which
* case e_t = bound(x_t, [-1, 1]). This is a form of bounded cumulative
* traces. </p>
*
* <p> The implementation is totally transparent to the network, since it is
* the derivatives of the memory block parameters that are constructed
* accordingly. </p>
*
* <p> It is unclear whether each gate should have a different trace decay rate
* and whether the peepholes should have traces at all. Right now, they all
* use the same decay rate. </p>
*
* <p> Also note that gradients are defined recursively, so traces may be wrong. </p>
*
* @see ETLSTMNetwork1
* @see ETLSTMMemoryBlock1
*
* @author Francois Rivest
* @version 1.1
*/
public class ETLSTMFactory1 extends LSTMFactory {

    /*********************************************************************/
    //Serial Version UID

    // No explicit serialVersionUID is declared in this class.

    /*********************************************************************/
    //Protected fields (memory block)

    /** Eligibility trace decay rate (initialized to 0.8, overwritten by the constructor). */
    protected double m_Lambda = .8;

    /** Whether traces are reset on opposite sign (initialized to true, overwritten by the constructor). */
    protected boolean m_OppSignResetTraces = true;

    /*********************************************************************/
    //Constructors

    /**
     * Builds a factory producing {@link ETLSTMNetwork1} instances.
     *
     * <p> Every argument is stored as-is in the corresponding factory field;
     * no validation is performed here. The gate units later handed to the
     * network (m_InputGate, m_ForgetGate, m_OutputGate) are not assigned by
     * this constructor. </p>
     *
     * @param newInputCount     Number of inputs to the network.
     * @param newBlockCount     Number of memory blocks.
     * @param newCellperBlock   Number of memory cells per block.
     * @param newSquashInput    true to squash the cell input with
     *                          LogisticUnit(2,-1) (default), false to use a
     *                          LinearUnit instead.
     * @param newSquashOutput   true to squash the cell output with
     *                          LogisticUnit(2,-1) as in older papers, false
     *                          to use a LinearUnit instead (default).
     * @param newOutputCount    Number of outputs of the network.
     * @param newSampleOutput   Sample of an output function (should have one
     *                          input and one output, default
     *                          LogisticUnit(1,0)).
     * @param newGateToGate     Connects block gates to blocks (default false).
     * @param newBiasToOutput   Connects the bias to the output layer
     *                          (default true).
     * @param newInputToOutput  Connects the input to the output layer
     *                          (default true).
     * @param newGateToOutput   Connects block gates to the output layer
     *                          (default true).
     * @param newOutputWeightsLocalGradientFactor Internal scaling factor for
     *                          the gradient of the output weights.
     * @param newLambda         Eligibility trace decay rate.
     * @param newOppSignResetTraces true to reset traces on opposite sign.
     */
    public ETLSTMFactory1(int newInputCount,
                          int newBlockCount,
                          int newCellperBlock,
                          boolean newSquashInput,
                          boolean newSquashOutput,
                          int newOutputCount,
                          FunctionalUnit newSampleOutput,
                          boolean newGateToGate,
                          boolean newBiasToOutput,
                          boolean newInputToOutput,
                          boolean newGateToOutput,
                          double newOutputWeightsLocalGradientFactor,
                          double newLambda,
                          boolean newOppSignResetTraces)
    {
        // Network dimensions.
        this.m_InputCount = newInputCount;
        this.m_BlockCount = newBlockCount;
        this.m_CellperBlock = newCellperBlock;

        // Cell input function: logistic squashing or a linear unit.
        if (newSquashInput) {
            this.m_g = (FunctionalUnit) new LogisticUnit(2,-1);
        } else {
            this.m_g = (FunctionalUnit) new LinearUnit();
        }

        // Cell output function: same choice, made independently.
        if (newSquashOutput) {
            this.m_h = (FunctionalUnit) new LogisticUnit(2,-1);
        } else {
            this.m_h = (FunctionalUnit) new LinearUnit();
        }

        // Output layer.
        this.m_OutputCount = newOutputCount;
        this.m_SampleOutput = newSampleOutput;

        // Connectivity switches.
        this.m_GateToGate = newGateToGate;
        this.m_BiasToOutput = newBiasToOutput;
        this.m_InputToOutput = newInputToOutput;
        this.m_GateToOutput = newGateToOutput;

        // Gradient scaling for the output weights.
        this.m_OutputWeightsLocalGradientFactor = newOutputWeightsLocalGradientFactor;

        // Eligibility trace settings.
        this.m_Lambda = newLambda;
        this.m_OppSignResetTraces = newOppSignResetTraces;
    }

    /*********************************************************************/
    //FunctionalUnitFactory interface implementation

    /**
     * Creates a new network from the settings stored in this factory.
     *
     * <p> A fresh {@link ETLSTMNetwork1} is constructed from the factory
     * fields, its output weights local gradient factor is set, and it is then
     * passed to initializeWeights before being returned. </p>
     *
     * @return The newly created network.
     */
    public FunctionalUnit createUnit() {
        final ETLSTMNetwork1 network = new ETLSTMNetwork1(
                m_InputCount,
                m_BlockCount,
                m_CellperBlock,
                m_g,
                m_h,
                m_InputGate,
                m_ForgetGate,
                m_OutputGate,
                m_OutputCount,
                m_SampleOutput,
                m_GateToGate,
                m_BiasToOutput,
                m_InputToOutput,
                m_GateToOutput,
                m_Lambda,
                m_OppSignResetTraces);

        network.setOutputWeightsLocalGradientFactor(m_OutputWeightsLocalGradientFactor);
        initializeWeights(network);

        return network;
    }
}