package lnsc.pmvf;
import lnsc.*;
/** <P> Abstract class containing the basic implementation for the
* <code>FunctionalUnit2</code> interface. </P>
*
* <P> In order to implement the <code>FunctionalUnit2</code> interface,
* subclasses need the following 4 things: </P>
* <ol>
* <li>In the constructor, the fields <code>m_InputCount</code>,
* <code>m_OutputCount</code>, <code>m_IsDifferentiable</code>,
* <code>m_IsTwiceDifferentiable</code>, <code>m_ParameterCount</code>,
* <code>m_IsParameterDifferentiable</code>,
* <code>m_IsParameterTwiceDifferentiable</code>, and
* <code>m_IsStateless</code> must be filled appropriately. </li>
* <li>At least one of
* <code>processPattern(double[], boolean, boolean, boolean, boolean, String[])</code>
* or <code>processDataSet(DataSet, String[])</code> must be overridden,
* because their default implementations call each other. An overriding
* <code>processPattern</code> should preferably begin by calling
* <code>preProcessPattern(double[], boolean, boolean, boolean, boolean, String[])</code>,
* and an overriding <code>processDataSet</code> should preferably begin by
* calling <code>preProcessDataSet(DataSet, String[])</code>.
* For stateful (non-stateless) functions, <code>reset()</code> must also be
* overridden.</li>
* <li>If the function is parametric (i.e. has at least one parameter),
* the methods <code>getParameters()</code> and <code>setParameters(double[])</code>
* must be written accordingly. Note that get and set parameters must
* work by copying values, not by referencing whole arrays.</li>
* <li>Since <code>FunctionalUnit2</code> objects are <code>Serializable</code> and
* <code>Cloneable</code>, any extra code required to make these
* interfaces work properly should be added. At a minimum, the
* <code>private static serialVersionUID</code> variable must be set
* appropriately for the <code>Serializable</code> interface. For
* complex objects, the <code>Cloneable</code> interface can rely on
* <code>Tools.copyObject(Serializable)</code> for deep cloning. </li>
* </ol>
*
* @author Francois Rivest
* @version 1.0
* @since 1.0
*/
public abstract class AbstractFunctionalUnit2 implements FunctionalUnit2
{
    /*********************************************************************/
    //Serial Version UID

    /** Serial version UID. */
    static final long serialVersionUID = -6359518335520622462L;

    /*********************************************************************/
    //Protected fields (filled by derived classes)

    /** Indicates the number of variables of the function.
     * That is, the value returned by getInputCount().
     * Derived classes must fill this slot in their constructor.
     */
    protected int m_InputCount;

    /** Indicates the number of values returned by the function.
     * That is, the value returned by getOutputCount().
     * Derived classes must fill this slot in their constructor.
     */
    protected int m_OutputCount;

    /** Indicates whether or not the function is differentiable.
     * That is, the value returned by isDifferentiable().
     * Derived classes must fill this slot in their constructor.
     */
    protected boolean m_IsDifferentiable;

    /** Indicates whether or not the function is twice differentiable.
     * That is, the value returned by isTwiceDifferentiable().
     * Derived classes must fill this slot in their constructor.
     */
    protected boolean m_IsTwiceDifferentiable;

    /** Indicates the number of parameters for this function.
     * That is, the value returned by getParameterCount().
     * Derived classes must fill this slot in their constructor.
     */
    protected int m_ParameterCount;

    /** Indicates whether or not the function is differentiable with respect
     * to its parameters.
     * That is, the value returned by isParameterDifferentiable().
     * Derived classes must fill this slot in their constructor.
     */
    protected boolean m_IsParameterDifferentiable;

    /** Indicates whether or not the function is twice differentiable with
     * respect to its parameters.
     * That is, the value returned by isParameterTwiceDifferentiable().
     * Derived classes must fill this slot in their constructor.
     */
    protected boolean m_IsParameterTwiceDifferentiable;

    /** Indicates whether or not the function output depends solely on the
     * current input (and not on the previous patterns it has processed).
     * That is, the value returned by isStateless().
     * Derived classes must fill this slot in their constructor.
     * (Protected, like its sibling fields, so that subclasses outside this
     * package can fill it as required.)
     */
    protected boolean m_IsStateless;

    /*********************************************************************/
    //FunctionalUnit implementation

    /** @return The number of inputs, i.e. {@link #m_InputCount}. */
    public final int getInputCount() {return m_InputCount;}

    /** @return The number of outputs, i.e. {@link #m_OutputCount}. */
    public final int getOutputCount() {return m_OutputCount;}

    /** @return {@link #m_IsDifferentiable}. */
    public final boolean isDifferentiable() {return m_IsDifferentiable;}

    /** @return {@link #m_IsTwiceDifferentiable}. */
    public final boolean isTwiceDifferentiable() {return m_IsTwiceDifferentiable;}

    /** @return The number of parameters, i.e. {@link #m_ParameterCount}. */
    public final int getParameterCount() {return m_ParameterCount;}

    /** @return {@link #m_IsParameterDifferentiable}. */
    public final boolean isParameterDifferentiable() {return m_IsParameterDifferentiable;}

    /** @return {@link #m_IsParameterTwiceDifferentiable}. */
    public final boolean isParameterTwiceDifferentiable() {return m_IsParameterTwiceDifferentiable;}

    /** @return {@link #m_IsStateless}. */
    public final boolean isStateless() {return m_IsStateless;}

    /** Resets the internal state of the function. The default
     * implementation does nothing; stateful subclasses must override it.
     */
    public void reset() {}

    /** <code>FunctionalUnit</code> (version 1) entry point. Delegates to
     * the six-argument {@link #processPattern(double[], boolean, boolean,
     * boolean, boolean, String[])} with no parameter derivatives and an
     * empty record list, then repackages the output, derivative and second
     * derivative into a <code>FunctionalUnit.ProcessPatternResult</code>.
     */
    public FunctionalUnit.ProcessPatternResult processPattern(
        double[] inputPattern,
        boolean computeDerivative,
        boolean computeSecondDerivative)
    {
        FunctionalUnit2.ProcessPatternResult2 ret =
            processPattern(inputPattern,
                           computeDerivative,
                           computeSecondDerivative,
                           false,
                           false,
                           new String[0]);
        return new FunctionalUnit.ProcessPatternResult(ret.outputPattern,
                                                       ret.derivative,
                                                       ret.secondDerivative);
    }

    /** This function validates the arguments and creates the object to
     * return. It should be called at the very beginning of the method
     * {@link #processPattern}.
     * The returned result has an output array of size
     * <code>m_OutputCount</code> (as allocated by the
     * <code>ProcessPatternResult2</code> constructor) and zero-filled
     * arrays for each requested derivative:
     * <code>[out][in]</code>, <code>[out][in][in]</code>,
     * <code>[out][param]</code>, <code>[out][param][param]</code>.
     * @param    inputPattern    Input vector; its length must be
     *                           <code>m_InputCount</code>.
     * @param    computeDerivative                  Allocate derivative.
     * @param    computeSecondDerivative            Allocate second derivative.
     * @param    computeParameterDerivative         Allocate parameter derivative.
     * @param    computeParameterSecondDerivative   Allocate parameter second derivative.
     * @param    recordList      Not used by this validation.
     * @return   A freshly allocated result object.
     * @throws   IllegalArgumentException If the input has the wrong size or
     *           a derivative is requested that the function does not support.
     */
    protected final FunctionalUnit2.ProcessPatternResult2 preProcessPattern(
        double[] inputPattern,
        boolean computeDerivative,
        boolean computeSecondDerivative,
        boolean computeParameterDerivative,
        boolean computeParameterSecondDerivative,
        String[] recordList)
    {
        if (inputPattern.length != m_InputCount) {
            throw new IllegalArgumentException("inputPattern is of the wrong size!");
        }
        if (computeDerivative && !m_IsDifferentiable) {
            throw new IllegalArgumentException("computeDerivative requested on a non-differentiable function!");
        }
        if (computeSecondDerivative && !m_IsTwiceDifferentiable) {
            throw new IllegalArgumentException("computeSecondDerivative requested on a non-twicedifferentiable function!");
        }
        if (computeParameterDerivative && !m_IsParameterDifferentiable) {
            throw new IllegalArgumentException("computeParameterDerivative requested on a non-differentiable function!");
        }
        if (computeParameterSecondDerivative && !m_IsParameterTwiceDifferentiable) {
            throw new IllegalArgumentException("computeParameterSecondDerivative requested on a non-twicedifferentiable function!");
        }
        FunctionalUnit2.ProcessPatternResult2 ret = new FunctionalUnit2.ProcessPatternResult2(m_OutputCount);
        if (computeDerivative) {ret.derivative = new double[m_OutputCount][m_InputCount];}
        if (computeSecondDerivative) {ret.secondDerivative = new double[m_OutputCount][m_InputCount][m_InputCount];}
        if (computeParameterDerivative) {ret.parameterDerivative = new double[m_OutputCount][m_ParameterCount];}
        if (computeParameterSecondDerivative) {ret.parameterSecondDerivative = new double[m_OutputCount][m_ParameterCount][m_ParameterCount];}
        return ret;
    }

    /** Processes a single pattern.
     * The default implementation wraps the pattern into a one-pattern
     * <code>DataSet</code>, adds the requested derivative names to
     * <code>recordList</code>, calls {@link #processDataSet}, and extracts
     * the first (only) pattern's results. What remains of the data set after
     * {@link DataSet#removeAllBut} is returned as <code>extraData</code>.
     * <P> Since the default {@link #processDataSet} calls this method,
     * subclasses must override at least one of the two; otherwise the two
     * defaults call each other without end. </P>
     */
    public FunctionalUnit2.ProcessPatternResult2 processPattern(
        double[] inputPattern,
        boolean computeDerivative,
        boolean computeSecondDerivative,
        boolean computeParameterDerivative,
        boolean computeParameterSecondDerivative,
        String[] recordList)
    {
        //Preprocessing (validation + result allocation)
        FunctionalUnit2.ProcessPatternResult2 ret =
            preProcessPattern(inputPattern,
                              computeDerivative,
                              computeSecondDerivative,
                              computeParameterDerivative,
                              computeParameterSecondDerivative,
                              recordList);
        //Construct a one-pattern dataSet
        DataSet dat = new DataSet();
        dat.setData(DataNames.PATTERN_COUNT, Integer.valueOf(1));
        dat.setData(DataNames.INPUT_PATTERNS, new double[][] {inputPattern});
        //Request the derivatives through the recordList
        if (computeDerivative) {
            recordList = DataNames.concat(recordList, new String[]{DataNames.DERIVATIVES});
        }
        if (computeSecondDerivative) {
            recordList = DataNames.concat(recordList, new String[]{DataNames.SECOND_DERIVATIVES});
        }
        if (computeParameterDerivative) {
            recordList = DataNames.concat(recordList, new String[]{DataNames.PARAMETER_DERIVATIVES});
        }
        if (computeParameterSecondDerivative) {
            recordList = DataNames.concat(recordList, new String[]{DataNames.PARAMETER_SECOND_DERIVATIVES});
        }
        //Process
        processDataSet(dat, recordList);
        //Extract results of the single pattern
        ret.outputPattern = ((double[][]) dat.getData(DataNames.OUTPUT_PATTERNS))[0];
        if (computeDerivative) {
            ret.derivative = ((double[][][]) dat.getData(DataNames.DERIVATIVES))[0];
        }
        if (computeSecondDerivative) {
            ret.secondDerivative = ((double[][][][]) dat.getData(DataNames.SECOND_DERIVATIVES))[0];
        }
        if (computeParameterDerivative) {
            ret.parameterDerivative = ((double[][][]) dat.getData(DataNames.PARAMETER_DERIVATIVES))[0];
        }
        if (computeParameterSecondDerivative) {
            ret.parameterSecondDerivative = ((double[][][][]) dat.getData(DataNames.PARAMETER_SECOND_DERIVATIVES))[0];
        }
        dat.removeAllBut(recordList);
        ret.extraData = dat;
        //Return
        return ret;
    }

    /** This function validates the <code>dataSet</code>. It should be called at
     * the very beginning of the method {@link #processDataSet}.
     * @param    dataSet     Must contain a non-negative <code>Integer</code>
     *                       {@link DataNames#PATTERN_COUNT} and an
     *                       {@link DataNames#INPUT_PATTERNS} matrix of size
     *                       pattern count by <code>m_InputCount</code>.
     * @param    recordList  Not used by this validation.
     * @return   The same <code>dataSet</code>.
     * @throws   MissingDataException If a required entry is absent.
     * @throws   InvalidDataException If a required entry has the wrong type
     *           or size.
     */
    protected final DataSet preProcessDataSet(DataSet dataSet, String[] recordList)
    {
        //check requirement
        if (!dataSet.hasData(DataNames.PATTERN_COUNT)) {
            throw new MissingDataException("Missing DataNames.PATTERN_COUNT in data set!", DataNames.PATTERN_COUNT);
        }
        if (!dataSet.hasData(DataNames.INPUT_PATTERNS)) {
            throw new MissingDataException("Missing DataNames.INPUT_PATTERNS in data set!", DataNames.INPUT_PATTERNS);
        }
        //check types & sizes
        if (!(dataSet.getData(DataNames.PATTERN_COUNT) instanceof Integer)) {
            throw new InvalidDataException("DataNames.PATTERN_COUNT must be an Integer!", DataNames.PATTERN_COUNT);
        }
        int pCount = ((Integer) dataSet.getData(DataNames.PATTERN_COUNT)).intValue();
        if (pCount < 0) {
            throw new InvalidDataException("DataNames.PATTERN_COUNT must be non-negative!", DataNames.PATTERN_COUNT);
        }
        if (!LinearAlgebra.isMatrix(dataSet.getData(DataNames.INPUT_PATTERNS), pCount, m_InputCount)) {
            throw new InvalidDataException("DataNames.INPUT_PATTERNS does not match DataNames.PATTERN_COUNT and FunctionalUnit.getInputCount() in data set!", DataNames.INPUT_PATTERNS);
        }
        //return
        return dataSet;
    }

    /** Processes every pattern of a data set.
     * The default implementation validates the data set, calls
     * {@link #processPattern(double[], boolean, boolean, boolean, boolean, String[])}
     * on each input pattern, and stores the per-pattern results under
     * {@link DataNames#OUTPUT_PATTERNS}, the requested derivative names, and
     * {@link DataNames#EXTRA_DATA}. If {@link DataNames#ERROR_PATTERNS} is in
     * <code>recordList</code>, error patterns are computed as well.
     * <P> Since the default {@link #processPattern} calls this method,
     * subclasses must override at least one of the two; otherwise the two
     * defaults call each other without end. </P>
     * @return   The original <code>dataSet</code>, augmented with the results.
     */
    public DataSet processDataSet(DataSet dataSet, String[] recordList)
    {
        //Preprocessing
        preProcessDataSet(dataSet, recordList);
        //Extract dataSet
        int patternCount = ((Integer) dataSet.getData(DataNames.PATTERN_COUNT)).intValue();
        double[][] inputs = (double[][]) dataSet.getData(DataNames.INPUT_PATTERNS);
        //Extract recordList
        boolean computeDerivative = DataNames.isMember(DataNames.DERIVATIVES, recordList);
        boolean computeSecondDerivative = DataNames.isMember(DataNames.SECOND_DERIVATIVES, recordList);
        boolean computeParameterDerivative = DataNames.isMember(DataNames.PARAMETER_DERIVATIVES, recordList);
        boolean computeParameterSecondDerivative = DataNames.isMember(DataNames.PARAMETER_SECOND_DERIVATIVES, recordList);
        //Create result spaces (only for the requested quantities)
        double[][] outputs = new double[patternCount][];
        double[][][] derivatives =
            computeDerivative ? new double[patternCount][][] : null;
        double[][][][] secondDerivatives =
            computeSecondDerivative ? new double[patternCount][][][] : null;
        double[][][] parameterDerivatives =
            computeParameterDerivative ? new double[patternCount][][] : null;
        double[][][][] parameterSecondDerivatives =
            computeParameterSecondDerivative ? new double[patternCount][][][] : null;
        //Process patterns
        DataSetCollection extraData = new DataSetCollection(patternCount);
        for (int i=0; i<patternCount; i++)
        {
            FunctionalUnit2.ProcessPatternResult2 ret =
                processPattern(inputs[i],
                               computeDerivative,
                               computeSecondDerivative,
                               computeParameterDerivative,
                               computeParameterSecondDerivative,
                               recordList);
            outputs[i] = ret.outputPattern;
            if (computeDerivative) {derivatives[i] = ret.derivative;}
            if (computeSecondDerivative) {secondDerivatives[i] = ret.secondDerivative;}
            //Parameter derivatives go into their own arrays (previously they
            //were mistakenly written over the input derivatives).
            if (computeParameterDerivative) {parameterDerivatives[i] = ret.parameterDerivative;}
            if (computeParameterSecondDerivative) {parameterSecondDerivatives[i] = ret.parameterSecondDerivative;}
            extraData.setDataSet(i, ret.extraData, false);
        }
        //Save results
        dataSet.setData(DataNames.OUTPUT_PATTERNS, outputs);
        if (computeDerivative) {dataSet.setData(DataNames.DERIVATIVES, derivatives);}
        if (computeSecondDerivative) {dataSet.setData(DataNames.SECOND_DERIVATIVES, secondDerivatives);}
        if (computeParameterDerivative) {dataSet.setData(DataNames.PARAMETER_DERIVATIVES, parameterDerivatives);}
        if (computeParameterSecondDerivative) {dataSet.setData(DataNames.PARAMETER_SECOND_DERIVATIVES, parameterSecondDerivatives);}
        dataSet.setData(DataNames.EXTRA_DATA, extraData);
        //Compute ERROR_PATTERNS on request
        if (DataNames.isMember(DataNames.ERROR_PATTERNS, recordList)) {
            computeErrorPatterns(dataSet);
        }
        //return
        return dataSet;
    }

    /** Computes the error patterns given a data set that contains both the
     * real outputs {@link DataNames#OUTPUT_PATTERNS} and the target
     * outputs {@link DataNames#TARGET_PATTERNS}. Results (outputs minus
     * targets, via <code>LinearAlgebra.batchSubVectors</code>) are stored in
     * the data set under {@link DataNames#ERROR_PATTERNS}.
     * @param    dataSet     A data set containing the two required datas.
     * @return   The original <code>dataSet</code> augmented with
     *           the added data.
     * @throws   MissingDataException If either required data is absent.
     * @deprecated
     */
    @Deprecated
    protected DataSet computeErrorPatterns(DataSet dataSet)
    {
        //Report missing inputs explicitly, in the style of preProcessDataSet.
        if (!dataSet.hasData(DataNames.OUTPUT_PATTERNS)) {
            throw new MissingDataException("Missing DataNames.OUTPUT_PATTERNS in data set!", DataNames.OUTPUT_PATTERNS);
        }
        if (!dataSet.hasData(DataNames.TARGET_PATTERNS)) {
            throw new MissingDataException("Missing DataNames.TARGET_PATTERNS in data set!", DataNames.TARGET_PATTERNS);
        }
        double[][] outputs = (double[][]) dataSet.getData(DataNames.OUTPUT_PATTERNS);
        double[][] targets = (double[][]) dataSet.getData(DataNames.TARGET_PATTERNS);
        double[][] errors = LinearAlgebra.batchSubVectors(outputs, targets);
        dataSet.setData(DataNames.ERROR_PATTERNS, errors);
        return dataSet;
    }

    /*********************************************************************/
    //toString method

    /** Returns a multi-line, tab-indented description of the function's
     * dimensions and capabilities.
     */
    public String toString()
    {
        StringBuilder sb = new StringBuilder();
        sb.append("Interface: FunctionalUnit\n");
        sb.append("\tInputCount: ").append(m_InputCount).append("\n");
        sb.append("\tOutputCount: ").append(m_OutputCount).append("\n");
        sb.append("\tIsDifferentiable: ").append(m_IsDifferentiable).append("\n");
        sb.append("\tIsTwiceDifferentiable: ").append(m_IsTwiceDifferentiable).append("\n");
        sb.append("\tParameterCount: ").append(m_ParameterCount).append("\n");
        sb.append("\tIsParameterDifferentiable: ").append(m_IsParameterDifferentiable).append("\n");
        sb.append("\tIsParameterTwiceDifferentiable: ").append(m_IsParameterTwiceDifferentiable).append("\n");
        sb.append("\tIsStateless: ").append(m_IsStateless);
        return sb.toString();
    }

    /*********************************************************************/
    //Cloneable interface implementation

    /** Returns a shallow copy (via <code>Object.clone()</code>). Subclasses
     * holding mutable state should deep-copy it as needed.
     */
    public Object clone() throws CloneNotSupportedException
    {
        return super.clone();
    }
}