/*
* Created on 1-Apr-08
*/
package com.bptripp.diff;
import ca.nengo.model.Projection;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Noise.Noisy;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.util.MU;
public class DualTCNetwork extends DifferentiatorNetwork {
private static final long serialVersionUID = 1L;
private static final String DIRECT = "direct";
private static final String DELAYED = "delayed";
private Projection myDirectProjection;
private Projection myDelayedProjection;
/**
* @param tauPSC Time constant of post-synaptic current decay in fast projection.
* @param slowTauPSC Time constant of post-synaptic current decay in slow projection.
* @param correlatedError If true, errors in two projections are identical; if false they are uncorrelated
*
* @throws StructuralException
*/
public DualTCNetwork(float tauPSC, float slowTauPSC, boolean correlatedError) throws StructuralException {
setName("dualTC");
float tauDifference = slowTauPSC - tauPSC;
getInputEnsemble().addDecodedTermination("input", MU.I(1), TAU_IO, false);
addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
NEFEnsemble output = getOutputEnsemble();
output.addDecodedTermination("direct", new float[][]{new float[]{1f / tauDifference}}, tauPSC, false);
output.addDecodedTermination("delayed", new float[][]{new float[]{-1f / tauDifference}}, slowTauPSC, false);
myDirectProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), output.getTermination("direct"));
if (correlatedError) {
myDelayedProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), output.getTermination("delayed"));
} else {
int n = getInputEnsemble().getNodes().length;
NEFEnsemble uncorrelated = myEnsembleFactory.make("input2", n, 1, "diff_input2_"+n, false);
uncorrelated.addDecodedTermination("input", MU.I(1), TAU_IO, false);
addNode(uncorrelated);
addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), uncorrelated.getTermination("input"));
myDelayedProjection = addProjection(uncorrelated.getOrigin(NEFEnsemble.X), output.getTermination("delayed"));
}
}
@Override
public void clearErrors() {
((Noisy) myDirectProjection.getOrigin()).setNoise(new NoiseFactory.NoiseImplNull());
((Noisy) myDelayedProjection.getOrigin()).setNoise(new NoiseFactory.NoiseImplNull());
}
@Override
public void setDistortion(int nInput, int nDiff) {
((Noisy) myDirectProjection.getOrigin()).setNoise(makeDistortion(nInput));
((Noisy) myDelayedProjection.getOrigin()).setNoise(makeDistortion(nInput));
}
@Override
public void setNoise(int nInput, int nDiff) {
((Noisy) myDirectProjection.getOrigin()).setNoise(makeNoise(nInput));
((Noisy) myDelayedProjection.getOrigin()).setNoise(makeNoise(nInput));
}
@Override
public void setTau(float tau) {
try {
getOutputEnsemble().getTermination(DELAYED).setTau(tau);
float tauDifference = tau - getOutputEnsemble().getTermination(DIRECT).getTau();
((DecodedTermination) getOutputEnsemble().getTermination(DIRECT)).setTransform(new float[][]{new float[]{1f / tauDifference}});
((DecodedTermination) getOutputEnsemble().getTermination(DELAYED)).setTransform(new float[][]{new float[]{-1f / tauDifference}});
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void disableParisien() {
myDirectProjection.removeBias();
myDelayedProjection.removeBias();
}
@Override
public void enableParisien(float propInhibitory) throws StructuralException {
int n = Math.round(propInhibitory * (float) getOutputEnsemble().getNodes().length);
enableParisien(myDirectProjection, n);
enableParisien(myDelayedProjection, n);
}
public static void main(String[] args) throws StructuralException {
new DualTCNetwork(.005f, .1f, false);
}
}