/*
 * Decompiled with CFR 0.152.
 */
package com.bptripp.diff;

import ca.nengo.math.Function;
import ca.nengo.math.PDF;
import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Node;
import ca.nengo.model.Noise;
import ca.nengo.model.Origin;
import ca.nengo.model.Projection;
import ca.nengo.model.RealOutput;
import ca.nengo.model.SimulationException;
import ca.nengo.model.SimulationMode;
import ca.nengo.model.StructuralException;
import ca.nengo.model.impl.NodeFactory;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.NEFEnsembleFactory;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.neuron.impl.ALIFNeuronFactory;
import ca.nengo.model.neuron.impl.ALIFSpikeGenerator;
import ca.nengo.model.neuron.impl.SpikingNeuron;
import ca.nengo.util.MU;
import com.bptripp.diff.DifferentiatorNetwork;
import com.bptripp.diff.Util;

public class AdaptingNetwork
extends DifferentiatorNetwork {
    private static final long serialVersionUID = 1L;
    private static final String COMPENSATING = "compensating";
    private static final String ADAPTING = "adapting";
    private NEFEnsemble myAdapting;
    private NEFEnsemble myCompensating;
    private float myPropAdapting;
    private Projection myInputAdaptingProjection;
    private Projection myInputCompensatingProjection;
    private Projection myAdaptingOutputProjection;
    private Projection myCompensatingOutputProjection;

    public AdaptingNetwork(int nAdapting, int nCompensating, float tauPSC) throws StructuralException {
        this.setName(ADAPTING);
        this.myPropAdapting = (float)nAdapting / ((float)nAdapting + (float)nCompensating);
        this.getInputEnsemble().addDecodedTermination("input", MU.I((int)1), 0.005f, false);
        this.addProjection(this.getInput().getOrigin("origin"), this.getInputEnsemble().getTermination("input"));
        NEFEnsembleFactory aef = AdaptingNetwork.getALIFEnsembleFactory();
        this.myAdapting = aef.make(ADAPTING, nAdapting, 1, "adapting_diff_" + nAdapting, false);
        this.addNode((Node)this.myAdapting);
        NEFEnsembleFactoryImpl ef = new NEFEnsembleFactoryImpl();
        this.myCompensating = ef.make(COMPENSATING, nCompensating, 1);
        this.myCompensating.addDecodedTermination("input", MU.I((int)1), tauPSC, false);
        this.addNode((Node)this.myCompensating);
        this.myAdapting.addDecodedTermination("input", MU.I((int)1), tauPSC, false);
        NEFEnsemble output = this.getOutputEnsemble();
        float[][] scale = new float[][]{{15.0f}};
        output.addDecodedTermination(ADAPTING, (float[][])scale, tauPSC, false);
        output.addDecodedTermination(COMPENSATING, (float[][])scale, tauPSC, false);
        this.myInputAdaptingProjection = this.addProjection(this.getInputEnsemble().getOrigin("X"), this.myAdapting.getTermination("input"));
        this.myInputCompensatingProjection = this.addProjection(this.getInputEnsemble().getOrigin("X"), this.myCompensating.getTermination("input"));
        this.myAdaptingOutputProjection = this.addProjection(this.myAdapting.getOrigin("X"), output.getTermination(ADAPTING));
        this.setCompensation(0.1f);
        try {
            this.getSimulator().addProbe(this.myCompensating.getName(), COMPENSATING, true);
        }
        catch (SimulationException e) {
            throw new StructuralException((Throwable)e);
        }
    }

    @Override
    public void enableParisien(float propInhibitory) throws StructuralException {
        int nAdapting = Math.round(propInhibitory * (float)this.myAdapting.getNodes().length);
        int nCompensating = Math.round(propInhibitory * (float)this.myCompensating.getNodes().length);
        int nOutput = Math.round(propInhibitory * (float)this.getOutputEnsemble().getNodes().length);
        AdaptingNetwork.enableParisien(this.myInputAdaptingProjection, nAdapting);
        AdaptingNetwork.enableParisien(this.myInputCompensatingProjection, nCompensating);
        this.myAdaptingOutputProjection.addBias(nOutput, TAU_INTERNEURONS, this.myAdaptingOutputProjection.getTermination().getTau(), true, false);
        AdaptingNetwork.enableParisien(this.myCompensatingOutputProjection, nOutput);
    }

    @Override
    public void disableParisien() {
        this.myInputAdaptingProjection.removeBias();
        this.myInputCompensatingProjection.removeBias();
        this.myAdaptingOutputProjection.removeBias();
        this.myCompensatingOutputProjection.removeBias();
    }

    public static NEFEnsembleFactory getALIFEnsembleFactory() {
        NEFEnsembleFactoryImpl result = new NEFEnsembleFactoryImpl();
        float incN = 0.05f;
        float tauN = 0.2f;
        result.setNodeFactory((NodeFactory)new ALIFNeuronFactory((PDF)new IndicatorPDF(200.0f, 400.0f), (PDF)new IndicatorPDF(-2.5f, -1.5f), (PDF)new IndicatorPDF(incN), 5.0E-4f, 0.02f, tauN));
        return result;
    }

    @Override
    public void setTau(float tau) {
        Node[] neurons = this.myAdapting.getNodes();
        int i = 0;
        while (i < neurons.length) {
            SpikingNeuron neuron = (SpikingNeuron)neurons[i];
            ALIFSpikeGenerator generator = (ALIFSpikeGenerator)neuron.getGenerator();
            float alpha = AdaptingNetwork.getSlope(neuron) / neuron.getScale();
            float b = neuron.getBias();
            float c = neuron.getScale();
            float tauN = tau / 2.0f * (b / c + 1.0f);
            float A_N = (1.0f / tau - 1.0f / tauN) / alpha;
            generator.setIncN(A_N);
            generator.setTauN(tauN);
            ++i;
        }
        try {
            this.setCompensation(tau);
            float[][] scale = new float[][]{{2.5f / tau}};
            ((DecodedTermination)this.getOutputEnsemble().getTermination(ADAPTING)).setTransform((float[][])scale);
            ((DecodedTermination)this.getOutputEnsemble().getTermination(COMPENSATING)).setTransform((float[][])scale);
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    private void setCompensation(float tau) throws StructuralException {
        try {
            Function f = Util.getBiasCompensation(this.myAdapting, "X", tau * 8.0f);
            Origin origin = this.myCompensating.addDecodedOrigin(COMPENSATING, new Function[]{f}, "AXON");
            Projection[] projectionArray = this.getProjections();
            int n = projectionArray.length;
            int n2 = 0;
            while (n2 < n) {
                Projection p = projectionArray[n2];
                if (p.getTermination() == this.getOutputEnsemble().getTermination(COMPENSATING)) {
                    this.removeProjection(this.getOutputEnsemble().getTermination(COMPENSATING));
                }
                ++n2;
            }
            this.myCompensatingOutputProjection = this.addProjection(origin, this.getOutputEnsemble().getTermination(COMPENSATING));
        }
        catch (SimulationException e) {
            throw new StructuralException((Throwable)e);
        }
    }

    public static float getSlope(SpikingNeuron neuron) {
        SimulationMode mode = neuron.getMode();
        float slope = 0.0f;
        try {
            neuron.setMode(SimulationMode.CONSTANT_RATE);
            neuron.setRadialInput(-1.0f);
            neuron.run(0.0f, 0.0f);
            RealOutput low = (RealOutput)neuron.getOrigin("AXON").getValues();
            neuron.setRadialInput(1.0f);
            neuron.run(0.0f, 0.0f);
            RealOutput high = (RealOutput)neuron.getOrigin("AXON").getValues();
            slope = (high.getValues()[0] - low.getValues()[0]) / 2.0f;
            neuron.setMode(mode);
        }
        catch (SimulationException e) {
            throw new RuntimeException(e);
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
        return slope;
    }

    @Override
    public void clearErrors() {
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise((Noise)new NoiseFactory.NoiseImplNull());
            ((Noise.Noisy)this.myAdapting.getOrigin("X")).setNoise((Noise)new NoiseFactory.NoiseImplNull());
            ((Noise.Noisy)this.myCompensating.getOrigin(COMPENSATING)).setNoise((Noise)new NoiseFactory.NoiseImplNull());
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setDistortion(int nInput, int nDiff) {
        int nAdapting = Math.round((float)nDiff * this.myPropAdapting);
        int nCompensating = Math.round((float)nDiff * (1.0f - this.myPropAdapting));
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise(AdaptingNetwork.makeDistortion(nInput));
            ((Noise.Noisy)this.myAdapting.getOrigin("X")).setNoise(AdaptingNetwork.makeDistortion(nAdapting));
            ((Noise.Noisy)this.myCompensating.getOrigin(COMPENSATING)).setNoise(AdaptingNetwork.makeDistortion(nCompensating));
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setNoise(int nInput, int nDiff) {
        int nAdapting = Math.round((float)nDiff * this.myPropAdapting);
        int nCompensating = Math.round((float)nDiff * (1.0f - this.myPropAdapting));
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise(AdaptingNetwork.makeNoise(nInput));
            ((Noise.Noisy)this.myAdapting.getOrigin("X")).setNoise(AdaptingNetwork.makeNoise(nAdapting));
            ((Noise.Noisy)this.myCompensating.getOrigin(COMPENSATING)).setNoise(AdaptingNetwork.makeNoise(nCompensating));
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    public static void main(String[] args) throws StructuralException {
        AdaptingNetwork an = new AdaptingNetwork(200, 100, 0.01f);
        an.setTau(0.1f);
    }
}

