package com.bptripp.diff;
import ca.nengo.model.Projection;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Noise.Noisy;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.util.MU;
public class InterneuronNetwork extends DifferentiatorNetwork {
private static final long serialVersionUID = 1L;
private NEFEnsemble myInterneurons;
private Projection myInputOutputProjection;
private Projection myInputInterneuronProjection;
private Projection myInterneuronOutputProjection;
public InterneuronNetwork(float tauPSC, int numInterneurons) throws StructuralException {
setName("interneuron");
getInputEnsemble().addDecodedTermination("input", MU.I(1), TAU_IO, false);
addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
myInterneurons = myEnsembleFactory.make("interneurons", numInterneurons, 1, "diff_inter_"+numInterneurons, false);
myInterneurons.addDecodedTermination("input", MU.I(1), tauPSC, false);
myInputInterneuronProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), myInterneurons.getTermination("input"));
addNode(myInterneurons);
getOutputEnsemble().addDecodedTermination("direct", new float[][]{new float[]{1/tauPSC}}, tauPSC, false);
getOutputEnsemble().addDecodedTermination("indirect", new float[][]{new float[]{-1/tauPSC}}, tauPSC, false);
myInputOutputProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), getOutputEnsemble().getTermination("direct"));
myInterneuronOutputProjection = addProjection(myInterneurons.getOrigin(NEFEnsemble.X), getOutputEnsemble().getTermination("indirect"));
}
@Override
public void disableParisien() {
myInputOutputProjection.removeBias();
myInputInterneuronProjection.removeBias();
myInterneuronOutputProjection.removeBias();
}
@Override
public void enableParisien(float propInhibitory) throws StructuralException {
int nOutput = Math.round(propInhibitory * (float) getOutputEnsemble().getNodes().length);
int nInterneurons = Math.round(propInhibitory * (float) myInterneurons.getNodes().length);
enableParisien(myInputOutputProjection, nOutput);
enableParisien(myInputInterneuronProjection, nInterneurons);
enableParisien(myInterneuronOutputProjection, nOutput);
}
@Override
public void clearErrors() {
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
((Noisy) myInterneurons.getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void setDistortion(int nInput, int nDiff) {
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nInput));
((Noisy) myInterneurons.getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nDiff));
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void setNoise(int nInput, int nDiff) {
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nInput));
((Noisy) myInterneurons.getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nDiff));
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void setTau(float tau) {
try {
myInterneurons.getTermination("input").setTau(tau);
DecodedTermination direct = (DecodedTermination) getOutputEnsemble().getTermination("direct");
direct.setTau(tau);
direct.setTransform(new float[][]{new float[]{1/tau}});
DecodedTermination indirect = (DecodedTermination) getOutputEnsemble().getTermination("indirect");
indirect.setTau(tau);
indirect.setTransform(new float[][]{new float[]{-1/tau}});
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
}