package lnsc.lstm;

import lnsc.*;
import lnsc.pmvf.*;

/** <P> Long Short-Term Memory (LSTM) network. </P>
 * <P> Implements the LSTM network of Gers, Schraudolph &amp; Schmidhuber 2002
 * (Journal of Machine Learning Research 3:115-143). See also Gers,
 * Schmidhuber, &amp; Cummins 2000 (Neural Computation 12(10):2451-2471) and
 * Hochreiter &amp; Schmidhuber 1997 (Neural Computation 9(8):1735-1780). </P>
 * <P> This network reads the inputs, passes them through a layer of memory
 * blocks, and weights the memory blocks' outputs before applying a layer of
 * output units (usually sigmoidal). Eventually it may support hidden units
 * (trainable using truncated gradient). The memory blocks are fully
 * recurrently connected. The inputs are expanded with a bias unit to feed the
 * memory blocks. The parameters are all the weights feeding the memory blocks
 * (forward and recurrent) and the weights feeding the output layer. The output
 * layer can be fed by everything (bias, input, memory block outputs, memory
 * block gates). The gradient from the network output to all the weights is
 * provided as in the paper. Only derivatives with respect to these parameters
 * are available. </P>
 * <P> Parameters are all the weights of the memory blocks and all the weights
 * of the output layer. The derivative to parameters represents the derivative
 * of each output z_k(t) with respect to each weight, based on the papers'
 * formulas. </P>
 * <P> Memory cells' internal states, outputs, and memory block gates are
 * available through keywords in {@link LSTMDataNames}. </P>
 * <P> Cloning is done through serialization; transient states are therefore
 * transient to cloning too (e.g. reset in clones). </P>
 * <P> Avoid using this class directly or deriving from it unless you really
 * know what you are doing. Use factories instead as much as possible. </P>
 *  @see FastLSTMMemoryBlock
 *  @see AbstractLSTMFactory
 *  @see LSTMFactory
 *  @see LSTMDataNames
 *  @author Francois Rivest
 *  @version 1.0
 *  @since 1.0
 */

public class FastLSTMNetwork extends AbstractFunctionalUnit2 {

    //Serialization support

    /** Serial version UID. */
    static final long serialVersionUID = -5612768043320548291L;

    //Architecture fields (declared protected, not private; set by the constructor)

    /** Indicates whether bias should be connected to the output layer. */
    protected boolean m_BiasToOutput;

    /** Indicates whether input should be connected to the output layer. */
    protected boolean m_InputToOutput;

    /** Indicates whether gates of memory block should be connected to
     * the output layer. */
    protected boolean m_GateToOutput;

    /** Indicates whether gates of memory block should be recurrently connected
     * to memory blocks. */
    protected boolean m_GateToGate;

    /** Number of memory blocks. */
    protected int m_MemoryBlockCount;

    /** Memory blocks (one entry per block, m_MemoryBlockCount in total). */
    protected FastLSTMMemoryBlock[] m_MemoryBlocks;

    /** Output layer. */
    protected FastSingleLayerNeuralNetwork m_OutputLayer;

    /** Scalar multiplied into the output layer weight derivatives when
     * processPattern assembles the parameter derivative (1.0 = unscaled). */
    protected double m_OutputWeightsLocalGradientFactor = 1.0;

    /** When true, processPattern prints its intermediate vectors and
     * derivatives to System.out. Public and mutable. */
    public boolean m_Debug = false;

    //Transient active/previous state fields (protected; see also reset)

    /** Previous output of memory blocks [MemoryBlockCount][(MemoryBlock)OutputCount].
     * Allocated by reset() and overwritten at the end of every processPattern
     * call; transient, so it is not serialized. */
    protected transient double[][] m_PrevMemoryBlocksOutput;


    /** Construct an LSTM network.
     * @param    newInputCount    Number of input to the network.
     * @param    newBlockCount    Number of LSTM memory block.
     * @param    newCellperBlock  Number of memory cells per block.
     * @param    newg             Memory cell input squashing function (g)
     * @param    newh             Memory cell output squashing function (h)
     * @param    newInputGate     Memory block input gate function
     * @param    newForgetGate    Memory block forget gate function
     * @param    newOutputGate    Memory block output gate function
     * @param    newOutputCount   Number of output of the network
     * @param    newSampleOutput  Sample of an output function (should have one
     *                            input and one output, default LogisticUnit(1,0))
     * @param    newGateToGate    Connects block gates to block (default false)
     * @param    newBiasToOutput  Connects bias to output layer (default true)
     * @param    newInputToOutput Connects input to output layer (default true)
     * @param    newGateToOutput  Connects block gates to output layer (default true)
    protected FastLSTMNetwork(int newInputCount,
                              int newBlockCount,
                              int newCellperBlock,
                              FunctionalUnit newg,
                              FunctionalUnit newh,
                              FunctionalUnit newInputGate,
                              FunctionalUnit newForgetGate,
                              FunctionalUnit newOutputGate,
                              int newOutputCount,
                              FunctionalUnit newSampleOutput,
                              boolean newGateToGate,
                              boolean newBiasToOutput,
                              boolean newInputToOutput,
                              boolean newGateToOutput)

        //Argument check
        if (newInputCount < 0) {
            throw new IllegalArgumentException(
                "Number of input must be non negative!");
        if (newBlockCount < 0) {
            throw new IllegalArgumentException(
                "Number of memory block must be non negative!");
        if (!newSampleOutput.isDifferentiable()) {
            throw new IllegalArgumentException(
                "Sample output function must be differentiable!");

        //FunctionalUnit2 properties
        m_InputCount = newInputCount;
        m_OutputCount = newOutputCount;
        m_ParameterCount = 0;//below
        m_IsDifferentiable = false;
        m_IsTwiceDifferentiable = false;
        m_IsParameterDifferentiable = true;
        m_IsParameterTwiceDifferentiable = false;

        //Other properties
        m_BiasToOutput = newBiasToOutput;
        m_InputToOutput = newInputToOutput;
        m_GateToOutput = newGateToOutput;
        m_GateToGate = newGateToGate;
        m_MemoryBlockCount = newBlockCount;

        //Memory blocks
        m_MemoryBlocks = new FastLSTMMemoryBlock[newBlockCount];
        for (int i=0; i<newBlockCount; i++)
            //diffs between the two constructors here
            m_MemoryBlocks[i] = new FastLSTMMemoryBlock(
                1+newInputCount + newBlockCount*(newCellperBlock+(newGateToGate?3:0)),
                newCellperBlock, newg, newh, newInputGate, newForgetGate, newOutputGate);

        //Output layer
        m_OutputLayer = new FastSingleLayerNeuralNetwork(
            (newBiasToOutput?1:0) + (newInputToOutput?newInputCount:0) +
            newBlockCount * (newCellperBlock+(newGateToOutput?3:0)),
            false, newOutputCount, newSampleOutput);

        //Parameter count
        for (int i=0; i<m_MemoryBlockCount; i++)
            m_ParameterCount += m_MemoryBlocks[i].getParameterCount();
        m_ParameterCount += m_OutputLayer.getParameterCount();

        //Transient states initialisation



    /** Sets a factor used when computing the gradient to the output weights.
     * This factor does not affect the calculation of the gradient of any other
     * weights but the output layer weights. The gradient of these weights are
     * often 2 order of magnitude larger than the gradient of the other weights,
     * allowing more pression on the bias and input to output connections than
     * on any internal weights. Original setting is factor=1, but suggested
     * setting by Rivest is factor=1E-2. This encouraged search on internal
     * weights adjustement.
    protected void setOutputWeightsLocalGradientFactor(double newFactor)
        m_OutputWeightsLocalGradientFactor = newFactor;

    //FunctionalUnit2 interface implementation

    public void reset()
        m_PrevMemoryBlocksOutput = new double[m_MemoryBlockCount][];
        for (int i=0; i<m_MemoryBlockCount; i++)
            m_PrevMemoryBlocksOutput[i] = new double[m_MemoryBlocks[i].getOutputCount()];

    /** Processes one input pattern through the memory blocks and the output
     * layer, optionally computing the derivative of the outputs with respect
     * to all weights.
     * <P> Visible steps: build the memory block input vector, run every memory
     * block, optionally record block internals into the result's extra data,
     * build the output layer input vector, run the output layer, copy its
     * output into the result, optionally assemble the parameter derivative
     * (memory block weights first, output layer weights last), and finally
     * store the block outputs as the "previous" outputs for the next call. </P>
     * <P> NOTE(review): several statements of this method are visibly cut off
     * in this copy of the file (closing braces, the initializer of ret, some
     * call argument lists). They are flagged inline and were not
     * reconstructed. </P>
     * @param    inputPattern                      Network input vector.
     * @param    computeDerivative                 Not referenced in the visible
     *           code (possibly consumed by the cut-off initializer of ret).
     * @param    computeSecondDerivative           Not referenced in the visible code.
     * @param    computeParameterDerivative        When true, the derivative of
     *           the outputs with respect to all weights is assembled into the
     *           result's parameterDerivative.
     * @param    computeParameterSecondDerivative  Not referenced in the visible code.
     * @param    recordList                        Keywords (see LSTMDataNames)
     *           selecting which internals are stored in the result's extraData;
     *           also forwarded to the memory blocks and the output layer.
     * @return   The result object ret, carrying outputPattern, extraData (null
     *           when nothing was recorded) and, on request, parameterDerivative.
     */
    public FunctionalUnit2.ProcessPatternResult2 processPattern(
        double[] inputPattern,
        boolean computeDerivative,
        boolean computeSecondDerivative,
        boolean computeParameterDerivative,
        boolean computeParameterSecondDerivative,
        String[] recordList)

        //*** Preprocessing
        //NOTE(review): the right-hand side of this assignment is missing here;
        //presumably a pre-processing call that validates arguments and
        //allocates the result arrays - confirm against the original file.
        FunctionalUnit2.ProcessPatternResult2 ret =

        //*** Forward pass

        //*Memory blocks input vector
        //The input vector to the memory blocks is the concatenation of:
        //1, input pattern, block[0].prevOutput, .... block[n-1].prevOutput.
        //Make space
        //(when gate-to-gate is off, 3 entries per block are left out of the count)
        int count = 1 + m_InputCount;
        for (int i=0; i<m_MemoryBlockCount; i++)
            count += m_MemoryBlocks[i].getOutputCount()-(m_GateToGate?0:3);
        double[] blockInput = new double[count];
        int index = 0;
        //Bias unit
        blockInput[index] = 1.0;
        //NOTE(review): as shown, index is still 0 when the input is written
        //below; an increment past the bias slot appears to be missing.
        LinearAlgebra.overwriteSubVector(index, inputPattern, blockInput);
        index += m_InputCount;
        //Memory blocks
        //NOTE(review): only the index advance is visible in both branches; the
        //statements copying m_PrevMemoryBlocksOutput[i] into blockInput and the
        //closing brace are presumably missing.
        for (int i=0; i<m_MemoryBlockCount; i++)
            if (m_GateToGate) {
                index += m_MemoryBlocks[i].getOutputCount();
            } else {
                index += m_MemoryBlocks[i].getOutputCount()-3;
        //Debug trace of the assembled memory block input
        if (m_Debug) {
            System.out.println("Bias: " + 1);
            System.out.println("Input: " + LinearAlgebra.toString(inputPattern));
            for (int i=0; i<m_MemoryBlockCount; i++)
                System.out.println("Previous Memory Block " + i + " Output: "
                                   + LinearAlgebra.toString(m_PrevMemoryBlocksOutput[i]));
            System.out.println("Gate2Gate: " + m_GateToGate);
            System.out.println("Resulting Memory Block Input: " + LinearAlgebra.toString(blockInput));

        //*Process memory blocks
        //Every block receives the same input vector; only the parameter
        //derivative is requested from the blocks.
        FunctionalUnit2.ProcessPatternResult2 memoryBlocks[] = new FunctionalUnit2.ProcessPatternResult2[m_MemoryBlockCount];
        for (int i=0; i<m_MemoryBlockCount; i++)
            memoryBlocks[i] = m_MemoryBlocks[i].processPattern(blockInput, false, false, computeParameterDerivative, false, recordList);

        //*Save internal state on request
        //Each keyword present in recordList adds one entry to ret.extraData.
        //The indices used below treat the last three entries of a block's
        //output pattern as input gate, forget gate and output gate, in that
        //order, and everything before them as cell activations.
         ret.extraData = new DataSet();
         if (DataNames.isMember(LSTMDataNames.LSTM_INTERNAL_STATES, recordList)) {
             double[][] int_states = new double[m_MemoryBlockCount][];
             for (int i=0; i<m_MemoryBlockCount; i++)
                 int_states[i] = (double[]) memoryBlocks[i].extraData.getData(LSTMDataNames.LSTM_INTERNAL_STATES);
             ret.extraData.setData(LSTMDataNames.LSTM_INTERNAL_STATES, LinearAlgebra.concatenateRows(int_states));
         if (DataNames.isMember(LSTMDataNames.LSTM_INTERNAL_ACTIVATIONS, recordList)) {
             double[][] int_acts = new double[m_MemoryBlockCount][];
             for (int i=0; i<m_MemoryBlockCount; i++)
                 //presumably an inclusive end index (setParameters uses
                 //extractVector the same way), i.e. everything but the 3 gates
                 int_acts[i] = LinearAlgebra.extractVector(memoryBlocks[i].outputPattern, 0, memoryBlocks[i].outputPattern.length-4);
             ret.extraData.setData(LSTMDataNames.LSTM_INTERNAL_ACTIVATIONS, LinearAlgebra.concatenateRows(int_acts));
         if (DataNames.isMember(LSTMDataNames.LSTM_INPUT_GATES, recordList)) {
             double[] gates = new double[m_MemoryBlockCount];
             for (int i=0; i<m_MemoryBlockCount; i++)
                 gates[i] = memoryBlocks[i].outputPattern[memoryBlocks[i].outputPattern.length-3];
             ret.extraData.setData(LSTMDataNames.LSTM_INPUT_GATES, gates);
         if (DataNames.isMember(LSTMDataNames.LSTM_FORGET_GATES, recordList)) {
             double[] gates = new double[m_MemoryBlockCount];
             for (int i=0; i<m_MemoryBlockCount; i++)
                 gates[i] = memoryBlocks[i].outputPattern[memoryBlocks[i].outputPattern.length-2];
             ret.extraData.setData(LSTMDataNames.LSTM_FORGET_GATES, gates);
         if (DataNames.isMember(LSTMDataNames.LSTM_OUTPUT_GATES, recordList)) {
             double[] gates = new double[m_MemoryBlockCount];
             for (int i=0; i<m_MemoryBlockCount; i++)
                 gates[i] = memoryBlocks[i].outputPattern[memoryBlocks[i].outputPattern.length-1];
             ret.extraData.setData(LSTMDataNames.LSTM_OUTPUT_GATES, gates);
         //Nothing recorded: hand back null instead of an empty data set
         if (ret.extraData.getDataCount() == 0) {ret.extraData = null;}

        //*Output layer input vector
        //The input vector to the output layer is the concatenation of:
        //1, inputPattern, block[0].output, ..., block[n-1].output.
        //But, some may be removed based on the 3 xxxToOutput switch.
        //Make space
        count = (m_BiasToOutput?1:0) + (m_InputToOutput?m_InputCount:0);
        for (int i=0; i<m_MemoryBlockCount; i++)
            count += m_MemoryBlocks[i].getOutputCount()-(m_GateToOutput?0:3);
        double[] outputInput = new double[count];
        index = 0;
        //NOTE(review): the index advance after the bias and the closing braces
        //of these two ifs are not visible here.
        if (m_BiasToOutput) {
            outputInput[index] = 1.0;
        if (m_InputToOutput) {
            LinearAlgebra.overwriteSubVector(index, inputPattern, outputInput);
            index += m_InputCount;
        //Memory blocks
        //NOTE(review): as above, only the index advance is visible; the copies
        //of memoryBlocks[i].outputPattern into outputInput are presumably missing.
        for (int i=0; i<m_MemoryBlockCount; i++)
            if (m_GateToOutput) {
                index += m_MemoryBlocks[i].getOutputCount();
            } else {
                index += m_MemoryBlocks[i].getOutputCount()-3;
        //Debug trace of the assembled output layer input
        if (m_Debug) {
            System.out.println("Bias: " + 1);
            System.out.println("Input: " + LinearAlgebra.toString(inputPattern));
            for (int i=0; i<m_MemoryBlockCount; i++)
                System.out.println("Memory Block " + i + " Output: "
                                   + LinearAlgebra.toString(memoryBlocks[i].outputPattern));
            System.out.println("Bias2Output: " + m_BiasToOutput);
            System.out.println("Input2Output: " + m_InputToOutput);
            System.out.println("Gate2Output: " + m_GateToOutput);
            System.out.println("Resulting Output Layer Input: " + LinearAlgebra.toString(outputInput));

        //*Output layer process
        //computeParameterDerivative is passed for both the input derivative
        //and the parameter derivative flags: output.derivative is read in the
        //derivative section below.
        FunctionalUnit2.ProcessPatternResult2 output = m_OutputLayer.processPattern(outputInput, computeParameterDerivative, false, computeParameterDerivative, false, recordList);

        //*Construct output vector
        LinearAlgebra.overwriteSubVector(0, output.outputPattern, ret.outputPattern);
        if (m_Debug) {
            System.out.println("Output Layer Weights " + LinearAlgebra.toString(m_OutputLayer.getWeights()));
            System.out.println("Output Layer weighted sum " + LinearAlgebra.toString(LinearAlgebra.multMatrixVector(m_OutputLayer.getWeights(), outputInput)));
            System.out.println("Output Layer output: " + LinearAlgebra.toString(output.outputPattern));
            System.out.println("Resulting Network Output: " + LinearAlgebra.toString(ret.outputPattern));

        //*** Derivative computation
        if (computeParameterDerivative) {

            //*Derivative from network output to output layer weights

            //D output[i]/ D weights[j]
            double[][] der2OutputWeights = output.parameterDerivative;
            //NOTE(review): the opening of the print call that owns the two
            //lines below is missing here.
            if (m_Debug) {
                    "Derivative from Output Layer to Ouput Layer Weights: "
                    + LinearAlgebra.toString(output.parameterDerivative));

            //*Derivative from network output to memory blocks weights

            //D output[i] /D block[j]weights[k] =
            //    D output[i] /D block[j]output[l] *
            //    D block[j]output[l] / block[j]param[k]
            double[][][] der2BlocksWeights = new double[m_MemoryBlockCount][][];
            //output derivative index
            //(first output layer input column belonging to block 0, i.e. after
            //the optional bias and optional network input)
            int start = (m_BiasToOutput?1:0) + (m_InputToOutput?m_InputCount:0);
            //number of gate entries per block that do not feed the output layer
            int decount = m_GateToOutput?0:3;
            //for each memory block
            //NOTE(review): the argument lists of extractColumns and extractRows
            //are cut off in this copy; what they select can only be inferred
            //from the comments.
            for (int j=0; j<m_MemoryBlockCount; j++)
                //extract output derivative sub matrix linking memory block j
                double[][] der2BlockOutput = LinearAlgebra.extractColumns(
                //remove gate output derivative from block parameter derivative
                double[][] blockOutputDer2BlockParam = LinearAlgebra.extractRows(
                //chain rule: (D output / D block output) * (D block output / D block weights)
                der2BlocksWeights[j] = LinearAlgebra.multMatrixMatrix(der2BlockOutput, blockOutputDer2BlockParam);
                //update output derivative index
                start += m_MemoryBlocks[j].getOutputCount()-decount;
            if (m_Debug) {
                System.out.println("Derivative from Output Layer: "
                                   + LinearAlgebra.toString(output.derivative));
                for (int i=0; i<m_MemoryBlockCount; i++)
                    System.out.println("Derivative from Memory Block " + i + " to Memory Block " + i + " Weights "
                                       + LinearAlgebra.toString(memoryBlocks[i].parameterDerivative));
                //NOTE(review): the opening of the print call is missing here too.
                for (int i=0; i<m_MemoryBlockCount; i++)
                        "Derivative from Output Layer to Memory Block " + i + " Weights: "
                        + LinearAlgebra.toString(der2BlocksWeights[i]));


            //*Construct parameters derivative

            //Parameters are constructed as follows:
            //    Weights for memory block 1 folowed by
            //    ...
            //    Weights for memory block N followed by
            //    Weights for output layer
            //start is reused as the column offset into ret.parameterDerivative
            start = 0;
            for (int i=0; i<m_MemoryBlockCount; i++)
                LinearAlgebra.overwriteSubMatrix(0, start, der2BlocksWeights[i], ret.parameterDerivative);
                start += m_MemoryBlocks[i].getParameterCount();
            //Only the output layer weight derivatives are scaled by the local
            //gradient factor
            LinearAlgebra.overwriteSubMatrix(0, start,
              LinearAlgebra.multScalarMatrix(m_OutputWeightsLocalGradientFactor, //Changes in 16Feb06
                 der2OutputWeights), ret.parameterDerivative);
            if (m_Debug) {
                System.out.println("OutputWeightsLocalGradientFactor: "
                                   + m_OutputWeightsLocalGradientFactor);
                System.out.println("Derivative from Output to all Weights: "
                                   + LinearAlgebra.toString(ret.parameterDerivative));


        //*** Finalisation & clean up

        //*Back up state values
        //The block result arrays are stored by reference (no copy) and become
        //the recurrent input of the next call.
        for (int i=0; i<m_MemoryBlockCount; i++)
            m_PrevMemoryBlocksOutput[i] = memoryBlocks[i].outputPattern;

        //*Return values
        return ret;

    public double[] getParameters()
        //Parameters are constructed as follows:
        //    Weights for memory block 1 folowed by
        //    ...
        //    Weights for memory block N followed by
        //    Weights for output layer

        double[] ret = new double[0];
        for (int i=0; i<m_MemoryBlockCount; i++)
            ret = LinearAlgebra.concatenateVectors(ret, m_MemoryBlocks[i].getParameters());
        ret = LinearAlgebra.concatenateVectors(ret, m_OutputLayer.getParameters());

        return ret;

    public void setParameters(double[] parameters)

        //Parameters check
        if (!LinearAlgebra.isVector(parameters, m_ParameterCount)) {
            throw new IllegalArgumentException("parameters is of the wrong size!");

        //Parameters are constructed as follows:
        //    Weights for memory block 1 folowed by
        //    ...
        //    Weights for memory block N followed by
        //    Weights for output layer
        int start = 0;
        for (int i=0; i<m_MemoryBlockCount; i++)
            m_MemoryBlocks[i].setParameters(LinearAlgebra.extractVector(parameters, start, start + m_MemoryBlocks[i].getParameterCount()-1));
            start += m_MemoryBlocks[i].getParameterCount();
        m_OutputLayer.setParameters(LinearAlgebra.extractVector(parameters, start, parameters.length-1));


    //toString method

    public String toString()

        String ret = super.toString() + "\n";
        ret += "Class: FASTLSTMNetwork\n";

        //Structure information
        ret += "\tMemoryBlockCount = " + m_MemoryBlockCount + "\n";
        ret += "\tGateToOutput = " + m_GateToOutput + "\n";
        ret += "\tBiasToOutput = " + m_BiasToOutput + "\n";
        ret += "\tInputToOutput = " + m_InputToOutput + "\n";
        ret += "\tGateToOutput = " + m_GateToOutput + "\n";

        //Internal blocks information
        for (int i=0; i<m_MemoryBlockCount; i++)
            ret += "\tMemoryBlocks[" + i + "] = \n";
            ret += Tools.tabText(m_MemoryBlocks[i].toString(),2) + "\n";
        ret += "\tOutputLayer = \n";
        ret += Tools.tabText(m_OutputLayer.toString(),2) + "\n";

        //Extra parameter
        ret += "\tOutputWeightsLocalGradientFactor = " + m_OutputWeightsLocalGradientFactor;

        return ret;

    //Cloneable/Serializable interface implementation

    public Object clone() {
        return Tools.copyObject(this);
