package lnsc;
import java.util.*;
/** <P> Abstract class containing the basic implementation for the
* <code>FunctionalUnit</code> interface. </P>
*
* <P> In order to implement the <code>FunctionalUnit</code> interface,
* subclasses need the following 3 things: </P>
* <ol>
* <li>In the constructor, the fields <code>m_InputCount</code>,
* <code>m_OutputCount</code>, <code>m_IsDifferentiable</code>, and
* <code>m_IsStateless</code> must be filled appropriately. </li>
* <li>Either <code>processDataSet(DataSet, String[])</code> or
* <code>processPattern(double[], boolean)</code>
* must be implemented and should prefrerably begin by calling
* <code>preProcessDataSet(DataSet, String[])</code> or
* <code>preProcessPattern(double[], boolean)</code> respectively.
* For non stateless function, <code>reset()</code> must be added.</li>
* <li>Since <code>FunctionalUnit</code> are <code>Serializable</code> and
* <code>Cloneable</code>, any required extra code to make these
* interfaces work properly should be added. It is necessary to at
* least set the <code>private static serialVersionUID</code> variable
* appropriately for the <code>Seriablizable</code> interface. For
* complex objects, the <code>Cloneable</code> interface can rely on
* <code>Tools.copyObject(Serializable)</code>. </li>
* </ol>
*
* @see Tools#copyObject(Serializable)
*
* @author Francois Rivest
* @version 1.0
* @since 1.0
*/
public abstract class AbstractFunctionalUnit implements FunctionalUnit
{
/*********************************************************************/
//Serial Version UID
/** Serial version UID. */
static final long serialVersionUID = -6359518335520622462L;
/*********************************************************************/
//Private fields
/** Indicates the number of variables of the function.
* That is, the value returned by getInputCount().
* Derived classes must fill this slot in their constructor.
*/
protected int m_InputCount;
/** Indicates the number of values returned by the function.
* Tha is, the value returned by getOutputCount().
* Derived classes must fill this slot in their constructor.
*/
protected int m_OutputCount;
/** Indicates whether or not the function is differentiable.
* That is, the value returned by isDifferentiable().
* Derived classes must fill this slot in their constructor.
*/
protected boolean m_IsDifferentiable;
/** Indicates whether or not the function is twice differentiable.
* That is, the value returned by isTwiceDifferentiable().
* Derived classes must fill this slot in their constructor.
*/
protected boolean m_IsTwiceDifferentiable;
/** Indicates whether or not the function output depends solely of the
* current input (and not of the previous pattern it has processed).
* That is, the value returned by isStateless().
* Derived classes must fill this slot in their constructor.
*/
boolean m_IsStateless;
/*********************************************************************/
//FunctionalUnit implementation
public final int getInputCount() {return m_InputCount;}
public final int getOutputCount() {return m_OutputCount;}
public final boolean isDifferentiable() {return m_IsDifferentiable;}
public final boolean isTwiceDifferentiable() {return m_IsTwiceDifferentiable;}
public final boolean isStateless() {return m_IsStateless;}
//Default implementation does nothing
public void reset() {return;}
/** This function validates the argument and creates the object to return.
* It should be called at the very beginning of the method
* {@link #processPattern}.
*/
protected final FunctionalUnit.ProcessPatternResult preProcessPattern(double[] inputPattern, boolean computeDerivative, boolean computeSecondDerivative)
{
if (inputPattern.length != m_InputCount)
throw new IllegalArgumentException("inputPatten is of the wrong size!");
else if (computeSecondDerivative && (!m_IsTwiceDifferentiable))
throw new IllegalArgumentException("computeSecondDerivative requested on a non-twicedifferentiable function!");
else if (computeDerivative && (!m_IsDifferentiable))
throw new IllegalArgumentException("computeDerivative requested on a non-differentiable function!");
else
{
if (!computeDerivative && !computeSecondDerivative)
return new FunctionalUnit.ProcessPatternResult(new double[m_OutputCount]);
else if (!computeSecondDerivative)
return new FunctionalUnit.ProcessPatternResult(new double[m_OutputCount],
new double[m_OutputCount][m_InputCount]);
else
return new FunctionalUnit.ProcessPatternResult(new double[m_OutputCount],
new double[m_OutputCount][m_InputCount],
new double[m_OutputCount][m_InputCount][m_InputCount]);
}
}
//The default implementation calls processDataSet
public FunctionalUnit.ProcessPatternResult processPattern(double[] inputPattern, boolean computeDerivative, boolean computeSecondDerivative)
{
FunctionalUnit.ProcessPatternResult ret = preProcessPattern(inputPattern, computeDerivative, computeSecondDerivative);
DataSet dat = new DataSet();
dat.setData(DataNames.PATTERN_COUNT, new Integer(1));
dat.setData(DataNames.INPUT_PATTERNS, new double[][] {inputPattern});
String[] recordList;
if (computeSecondDerivative) {
recordList = new String[] {DataNames.SECOND_DERIVATIVES, DataNames.DERIVATIVES};
} else if (computeDerivative) {
recordList = new String[] {DataNames.DERIVATIVES};
} else {
recordList = new String[0];
}
processDataSet(dat, recordList);
ret.outputPattern = ((double[][]) dat.getData(DataNames.OUTPUT_PATTERNS))[0];
if (computeDerivative) {
ret.derivative = ((double[][][]) dat.getData(DataNames.DERIVATIVES))[0];
}
if (computeSecondDerivative) {
ret.secondDerivative = ((double[][][][]) dat.getData(DataNames.SECOND_DERIVATIVES))[0];
}
return ret;
}
/** This function validates the <code>dataSet</code>. It should be called at
* the very beginning of the method {@link #processDataSet}.
*/
protected final DataSet preProcessDataSet(DataSet dataSet, String[] recordList)
{
//check requirement
if (!dataSet.hasData(DataNames.PATTERN_COUNT)) {
throw new MissingDataException("Missing DataNames.PATTERN_COUNT in data set!", DataNames.PATTERN_COUNT);
}
if (!dataSet.hasData(DataNames.INPUT_PATTERNS)) {
throw new MissingDataException("Missing DataNames.INPUT_PATTERNS in data set!", DataNames.INPUT_PATTERNS);
}
//check types & sizes
if (!(dataSet.getData(DataNames.PATTERN_COUNT) instanceof Integer)) {
throw new InvalidDataException("DataNames.PATTERN_COUNT must be an Integer!", DataNames.PATTERN_COUNT);
}
int pCount = ((Integer) dataSet.getData(DataNames.PATTERN_COUNT)).intValue();
if (pCount < 0) {
throw new InvalidDataException("DataNames.PATTERN_COUNT must be non-negative!", DataNames.PATTERN_COUNT);
}
if (!LinearAlgebra.isMatrix(dataSet.getData(DataNames.INPUT_PATTERNS), pCount, m_InputCount)) {
throw new InvalidDataException("DataNames.INPUT_PATTERNS does match DataNames.PATTERN_COUNT and FunctionalUnit.getInputCount() in data set!", DataNames.INPUT_PATTERNS);
}
//return
return dataSet;
}
//The default implementation calls processPattern
public DataSet processDataSet(DataSet dataSet, String[] recordList)
{
int i;
FunctionalUnit.ProcessPatternResult ret;
double[][][] derivatives = null;
double[][][][] secondDerivatives = null;
//parameter check
preProcessDataSet(dataSet, recordList);
//gather input patterns
double[][] inputs = (double[][]) dataSet.getData(DataNames.INPUT_PATTERNS);
int count = inputs.length;
//create output patterns space
double[][] outputs = new double[count][];
//create derivative space if necessary
boolean computeDerivative = DataNames.isMember(DataNames.DERIVATIVES, recordList);
if (computeDerivative) {derivatives = new double[count][][];}
//create derivative space if necessary
boolean computeSecondDerivative = DataNames.isMember(DataNames.SECOND_DERIVATIVES, recordList);
if (computeSecondDerivative) {secondDerivatives = new double[count][][][];}
//for each input pattern
for (i=0; i<count; i++)
{
//compute output pattern (and derivative)
ret = processPattern(inputs[i], computeDerivative, computeSecondDerivative);
outputs[i] = ret.outputPattern;
if (computeDerivative) {derivatives[i] = ret.derivative;}
if (computeSecondDerivative) {secondDerivatives[i] = ret.secondDerivative;}
}
//put result back in data set
dataSet.setData(DataNames.OUTPUT_PATTERNS, outputs);
if (computeDerivative) {dataSet.setData(DataNames.DERIVATIVES, derivatives);}
if (computeSecondDerivative) {dataSet.setData(DataNames.SECOND_DERIVATIVES, secondDerivatives);}
//check for error computation
if (DataNames.isMember(DataNames.ERROR_PATTERNS, recordList)) {
computeErrorPatterns(dataSet);
}
//return
return dataSet;
}
/** Computes the error patterns given a data set that contains both the
* real outputs 'OutputPatterns' and the target outputs 'TargetPatterns'.
* Results are stored in the data set under 'ErrorPatterns'.
* @param dataSet A data set containing three required
* datas.
* @return The original <code>dataSet</code> augmented with
* 'ErrorPatterns'.
*/
protected DataSet computeErrorPatterns(DataSet dataSet)
{
//error checking
double[][] outputs = (double[][]) dataSet.getData(DataNames.OUTPUT_PATTERNS);
double[][] targets = (double[][]) dataSet.getData(DataNames.TARGET_PATTERNS);
double[][] errors = LinearAlgebra.batchSubVectors(outputs, targets);
dataSet.setData(DataNames.ERROR_PATTERNS, errors);
return dataSet;
}
/*********************************************************************/
//toString method
public String toString()
{
String ret = new String();
ret += "Interface: FunctionalUnit\n";
ret += "\tInputCount: " + Integer.toString(m_InputCount) + "\n";
ret += "\tOutputCount: " + Integer.toString(m_OutputCount) + "\n";
ret += "\tIsDifferentiable: " + new Boolean(m_IsDifferentiable).toString() + "\n";
ret += "\tIsTwiceDifferentiable: " + new Boolean(m_IsTwiceDifferentiable).toString() + "\n";
ret += "\tIsStateless: " + new Boolean(m_IsStateless).toString();
return ret;
}
/*********************************************************************/
//Cloneable interface implementation
public Object clone() throws CloneNotSupportedException
{
return super.clone();
}
}