/*
* Created on 1-Apr-08
*/
package com.bptripp.diff;
import ca.nengo.math.Function;
import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Node;
import ca.nengo.model.Origin;
import ca.nengo.model.Projection;
import ca.nengo.model.RealOutput;
import ca.nengo.model.SimulationException;
import ca.nengo.model.SimulationMode;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Noise.Noisy;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.NEFEnsembleFactory;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.neuron.Neuron;
import ca.nengo.model.neuron.impl.ALIFNeuronFactory;
import ca.nengo.model.neuron.impl.ALIFSpikeGenerator;
import ca.nengo.model.neuron.impl.SpikingNeuron;
import ca.nengo.util.MU;
/**
* A differentiator network in which differentiation is achieved through
* adapting LIF neurons.
*
* @author Bryan Tripp
*/
public class AdaptingNetwork extends DifferentiatorNetwork {
private static final long serialVersionUID = 1L;
private static final String COMPENSATING = "compensating";
private static final String ADAPTING = "adapting";
private NEFEnsemble myAdapting;
private NEFEnsemble myCompensating;
private float myPropAdapting;
private Projection myInputAdaptingProjection;
private Projection myInputCompensatingProjection;
private Projection myAdaptingOutputProjection;
private Projection myCompensatingOutputProjection;
/**
* @param nAdapting Number of adapting neurons
* @param nCompensating Number of non-adapting neurons that compensate for non-zero adapted activity
* @param tauPSC Time constant of post-synaptic current within adapting and compensating neurons
*
* @throws StructuralException
*/
public AdaptingNetwork(int nAdapting, int nCompensating, float tauPSC) throws StructuralException {
setName("adapting");
myPropAdapting = (float) nAdapting / ((float) nAdapting + (float) nCompensating);
getInputEnsemble().addDecodedTermination("input", MU.I(1), TAU_IO, false);
addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
NEFEnsembleFactory aef = getALIFEnsembleFactory();
myAdapting = aef.make("adapting", nAdapting, 1, "adapting_diff_"+nAdapting, false);
addNode(myAdapting);
NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();
myCompensating = ef.make("compensating", nCompensating, 1);
myCompensating.addDecodedTermination("input", MU.I(1), tauPSC, false);
addNode(myCompensating);
myAdapting.addDecodedTermination("input", MU.I(1), tauPSC, false); //have to wait for bias compensation simulations
NEFEnsemble output = getOutputEnsemble();
float[][] scale = new float[][]{new float[]{15f}}; //this starting value roughly corresponds to the starting (non-uniform) time constant
output.addDecodedTermination(ADAPTING, scale, tauPSC, false);
output.addDecodedTermination(COMPENSATING, scale, tauPSC, false);
myInputAdaptingProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), myAdapting.getTermination("input"));
myInputCompensatingProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), myCompensating.getTermination("input"));
myAdaptingOutputProjection = addProjection(myAdapting.getOrigin(NEFEnsemble.X), output.getTermination(ADAPTING));
setCompensation(.1f);
try {
getSimulator().addProbe(myCompensating.getName(), COMPENSATING, true);
} catch (SimulationException e) {
throw new StructuralException(e);
}
}
/*
* (non-Javadoc)
* @see ca.bpt.diff.DifferentiatorNetwork#enableParisien(float)
*/
public void enableParisien(float propInhibitory) throws StructuralException {
int nAdapting = Math.round(propInhibitory * (float) myAdapting.getNodes().length);
int nCompensating = Math.round(propInhibitory * (float) myCompensating.getNodes().length);
int nOutput = Math.round(propInhibitory * (float) getOutputEnsemble().getNodes().length);
enableParisien(myInputAdaptingProjection, nAdapting);
enableParisien(myInputCompensatingProjection, nCompensating);
myAdaptingOutputProjection.addBias(nOutput, TAU_INTERNEURONS, myAdaptingOutputProjection.getTermination().getTau(), true, false);
enableParisien(myCompensatingOutputProjection, nOutput);
}
/*
* (non-Javadoc)
* @see ca.bpt.diff.DifferentiatorNetwork#disableParisien()
*/
public void disableParisien() {
myInputAdaptingProjection.removeBias();
myInputCompensatingProjection.removeBias();
myAdaptingOutputProjection.removeBias();
myCompensatingOutputProjection.removeBias();
}
/**
* @return factory for NEFEnsembles composed of adapting LIF neurons
*/
public static NEFEnsembleFactory getALIFEnsembleFactory() {
NEFEnsembleFactory result = new NEFEnsembleFactoryImpl();
float incN = .05f;
float tauN = .2f;
result.setNodeFactory(new ALIFNeuronFactory(new IndicatorPDF(200, 400), new IndicatorPDF(-2.5f, -1.5f), new IndicatorPDF(incN), .0005f, .02f, tauN));
return result;
}
@Override
public void setTau(float tau) {
Node[] neurons = myAdapting.getNodes();
for (int i = 0; i < neurons.length; i++) {
SpikingNeuron neuron = (SpikingNeuron) neurons[i];
ALIFSpikeGenerator generator = (ALIFSpikeGenerator) neuron.getGenerator();
float alpha = getSlope(neuron) / neuron.getScale();
float b = neuron.getBias();
float c = neuron.getScale();
float tauN = tau/2 * (b/c + 1);
float A_N = (1/tau - 1/tauN) / alpha;
generator.setIncN(A_N);
generator.setTauN(tauN);
}
try {
setCompensation(tau);
float[][] scale = new float[][]{new float[]{2.5f / tau}};
((DecodedTermination) getOutputEnsemble().getTermination(ADAPTING)).setTransform(scale);
((DecodedTermination) getOutputEnsemble().getTermination(COMPENSATING)).setTransform(scale);
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
private void setCompensation(float tau) throws StructuralException {
try {
Function f = Util.getBiasCompensation(myAdapting, NEFEnsemble.X, tau*8);
Origin origin = myCompensating.addDecodedOrigin(COMPENSATING, new Function[]{f}, Neuron.AXON);
for (Projection p : getProjections()) {
if (p.getTermination() == getOutputEnsemble().getTermination(COMPENSATING)) {
removeProjection(getOutputEnsemble().getTermination(COMPENSATING));
}
}
myCompensatingOutputProjection = addProjection(origin, getOutputEnsemble().getTermination(COMPENSATING));
} catch (SimulationException e) {
throw new StructuralException(e);
}
}
/**
* @param neuron A spiking neuron model
* @return mean derivative of spike rate wrt represented quantity, over the range [-1,1] (obtained by simulation)
*/
public static float getSlope(SpikingNeuron neuron) {
SimulationMode mode = neuron.getMode();
float slope = 0;
try {
neuron.setMode(SimulationMode.CONSTANT_RATE);
neuron.setRadialInput(-1);
neuron.run(0, 0);
RealOutput low = (RealOutput) neuron.getOrigin(Neuron.AXON).getValues();
neuron.setRadialInput(1);
neuron.run(0, 0);
RealOutput high = (RealOutput) neuron.getOrigin(Neuron.AXON).getValues();
slope = (high.getValues()[0] - low.getValues()[0]) / 2f;
neuron.setMode(mode);
} catch (SimulationException e) {
throw new RuntimeException(e);
} catch (StructuralException e) {
throw new RuntimeException(e);
}
return slope;
}
@Override
public void clearErrors() {
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
((Noisy) myAdapting.getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
((Noisy) myCompensating.getOrigin(COMPENSATING)).setNoise(new NoiseFactory.NoiseImplNull());
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void setDistortion(int nInput, int nDiff) {
int nAdapting = Math.round(nDiff*myPropAdapting);
int nCompensating = Math.round(nDiff*(1-myPropAdapting));
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nInput));
((Noisy) myAdapting.getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nAdapting));
((Noisy) myCompensating.getOrigin(COMPENSATING)).setNoise(makeDistortion(nCompensating));
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
@Override
public void setNoise(int nInput, int nDiff) {
int nAdapting = Math.round(nDiff*myPropAdapting);
int nCompensating = Math.round(nDiff*(1-myPropAdapting));
try {
((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nInput));
((Noisy) myAdapting.getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nAdapting));
((Noisy) myCompensating.getOrigin(COMPENSATING)).setNoise(makeNoise(nCompensating));
} catch (StructuralException e) {
throw new RuntimeException(e);
}
}
public static void main(String[] args) throws StructuralException {
AdaptingNetwork an = new AdaptingNetwork(200, 100, .01f);
an.setTau(.1f);
}
}