package lnsc.lstm;

import java.io.Serializable;
import lnsc.pmvf.*;
import lnsc.unsup.*;
import lnsc.*;


/** Generic Predictive Minimized Sum-squared-error online learning procedure.
 *  It should be upgraded to use a generic Optimizer and to be exported into
 *  some other package. At the limit, it could be generalized by having a
 *  FunctionalUnit to optimize. (It minimizes half-MSE.) This procedure
 *  predicts a single component of its input vector.
 *
 *  Note that in the data set, input/output and network data are current-pattern
 *  data. Target, error, gradients, parameters, and everything related to
 *  learning are previous-pattern data. Since the goal of the net is to predict
 *  the next input, the correction can only be done on the next call to train,
 *  prior to processing the new input pattern.
 *
 *  Note that if the function is not stateless, it should not process data
 *  between train calls unless reset is called.
 *
 *  The network must have a single output.
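 *
 *  <p>A minimal usage sketch (hypothetical caller code; the network
 *  construction and the <code>sequences</code> array are assumptions,
 *  not part of this class):
 *  <pre>
 *      FunctionalUnit2 net = ...;                //any single-output FunctionalUnit2
 *      OnlineSPMSELearning learner =
 *          new OnlineSPMSELearning(net, 0.1, 0); //predict input component 0
 *      for (int s = 0; s &lt; sequences.length; s++) {
 *          learner.reset();                      //clear state between sequences
 *          for (int t = 0; t &lt; sequences[s].length; t++) {
 *              learner.train(sequences[s][t]);   //correct step t-1, then process step t
 *          }
 *      }
 *  </pre>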
 *
 * @author Francois Rivest
 * @version 1.0
 */


public class OnlineSPMSELearning implements OnlineUnsupervisedLearning, Serializable {

    /*********************************************************************/
    //Serial Version UID

    /** Serial version UID. */
    static final long serialVersionUID = -8476691339663537022L;

    /*********************************************************************/
    //Private fields

    /** Network to be trained. */
    protected FunctionalUnit2 m_Func;

    /** Learning rate. */
    protected double m_Alpha;

    /** Index of the input component to be predicted. */
    protected int m_Index;

    /*********************************************************************/
    //Private fields (state)

    /** Gradient at t-1. */
    protected transient double[][] m_PreviousGradient;

    /** Output at t-1. */
    protected transient double[] m_PreviousOutput;

    /*********************************************************************/
    //Constructors

    /** Constructs a learner for a predictive network.
     * @param   newFunc  FunctionalUnit2 to train (must have a single output)
     * @param   newAlpha Learning rate
     * @param   newIndex Index of the input component to predict
     */
    public OnlineSPMSELearning(FunctionalUnit2 newFunc, double newAlpha, int newIndex) {
        m_Func = newFunc;
        m_Alpha = newAlpha;
        m_Index = newIndex;
        m_PreviousOutput = new double[1];
        m_PreviousGradient = new double[1][m_Func.getParameterCount()];
    }

    /*********************************************************************/
    //Properties

    /** Allow the learning rate to be changed.
     * @param   newLearningRate   New learning rate value.
     */
    public void setLearningRate(double newLearningRate) {
        m_Alpha = newLearningRate;
    }

    /*********************************************************************/
    //Special

    /** Indicates the end of a sequence; resets the function's internal state
     *  and clears the stored previous output and gradient. */
    public void reset() {
        m_Func.reset();
        m_PreviousOutput = new double[1];
        m_PreviousGradient = new double[1][m_Func.getParameterCount()];
    }

    /*********************************************************************/
    //Methods


    /*********************************************************************/
    //OnlineUnsupervisedLearning interface implementation

    //The current input is used in conjunction with the previous pattern
    //information (output & gradient) to make the previous step's updates.
    //Then the current pattern is processed and its results are saved for
    //processing at the next pattern presentation.
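    //
    //In symbols (a sketch of the half-MSE update implemented below, with
    //o[t-1] the stored previous output, w the parameters, and
    //y[t] = inputPattern[m_Index] the realized target):
    //    e[t-1] = o[t-1] - y[t]
    //    E      = (1/2) * e[t-1]^2
    //    dE/dw  = e[t-1] * do[t-1]/dw
    //    w     <- w - alpha * dE/dw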

    public DataSet train(double[] inputPattern, String[] recordList)
    {

        //Compute the previous prediction's error against the current input's target component
        double[] errorPattern = LinearAlgebra.subVectors(m_PreviousOutput, new double[] {inputPattern[m_Index]});

        //Compute sse value
        Double sse_val = Double.valueOf(LinearAlgebra.sumSquares(errorPattern));

        //Compute the squared-error gradient with respect to the parameters
        double[] gradients = LinearAlgebra.multVectorMatrix(errorPattern, m_PreviousGradient);

        //Compute deltas
        double[] deltas = LinearAlgebra.multScalarVector(-m_Alpha, gradients);

        //Update weights
        double[] params = m_Func.getParameters();
        m_Func.setParameters(LinearAlgebra.addeVectors(params, deltas));

        //Process the pattern through the network and get derivatives with respect to the parameters
        FunctionalUnit2.ProcessPatternResult2 result = m_Func.processPattern(inputPattern, false, false, true, false, recordList);

        //Backup current data
        m_PreviousOutput = result.outputPattern;
        m_PreviousGradient = result.parameterDerivative;

        //Plots (debug, kept commented out)
        //System.out.println(inputPattern[0] + "\t" + inputPattern[1] + "\t" +
        //                   inputPattern[m_Index] + "\t" + result.outputPattern[0] + "\t" +
        //                   errorPattern[0]);
        //System.out.println("Deltas:\n" + Tools.tabText(LinearAlgebra.toString(deltas),2));


        //Return extra data (may be null when nothing was requested)
        DataSet ret;
        //Adaptive model specific data
        if ((result.extraData == null) && (recordList.length != 0)) {
            ret = new DataSet();
        } else {
            ret = result.extraData;
        }
        if (DataNames.isMember(DataNames.INPUT_PATTERNS, recordList)) {
            ret.setData(DataNames.INPUT_PATTERNS, inputPattern);
        }
        if (DataNames.isMember(DataNames.OUTPUT_PATTERNS, recordList)) {
            ret.setData(DataNames.OUTPUT_PATTERNS, result.outputPattern);
        }
        //Adaptation rule specific data
        if (DataNames.isMember(DataNames.ERROR_PATTERNS, recordList)) {
            ret.setData(DataNames.ERROR_PATTERNS, errorPattern);
        }
        if (DataNames.isMember(DataNames.TARGET_PATTERNS, recordList)) {
            //The previous prediction's target is the indexed component of the current input
            ret.setData(DataNames.TARGET_PATTERNS, new double[] {inputPattern[m_Index]});
        }
        if (DataNames.isMember(DataNames.SUM_SQUARED_ERROR, recordList)) {
            ret.setData(DataNames.SUM_SQUARED_ERROR, sse_val);
        }
        //Adaptation specific data
        if (DataNames.isMember(DataNames.VALUE, recordList)) {
            ret.setData(DataNames.VALUE, sse_val);
        }
        if (DataNames.isMember(DataNames.GRADIENT, recordList)) {
            ret.setData(DataNames.GRADIENT, gradients);
        }
        if (DataNames.isMember(DataNames.VARIABLES, recordList)) {
            ret.setData(DataNames.VARIABLES, params);
        }
        if (DataNames.isMember(DataNames.VARIABLE_CHANGES, recordList)) {
            ret.setData(DataNames.VARIABLE_CHANGES, deltas);
        }
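
        //Example of a caller requesting recorded data (a sketch; "learner" and
        //"pattern" are hypothetical names, and the DataSet getData accessor
        //for reading entries back is assumed):
        //    DataSet rec = learner.train(pattern,
        //        new String[] {DataNames.SUM_SQUARED_ERROR, DataNames.ERROR_PATTERNS});
        //    Double sse = (Double) rec.getData(DataNames.SUM_SQUARED_ERROR);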

        //Return
        return ret;

    }

    public void train(double[] inputPattern)
    {
        train(inputPattern, DataNames.EMPTY_RECORDLIST);
    }

    /*********************************************************************/
    //toString method


    public String toString()
    {
        String ret = super.toString() + "\n";
        ret += "Class: OnlineSPMSELearning\n";
        ret += "\tLearningRate: " + m_Alpha;
        ret += "\tIndex: " + m_Index;
        return ret;
    }


    /*********************************************************************/
    //Cloneable/Serializable interface implementation


}