/*
 * Decompiled with CFR 0.152.
 */
package com.bptripp.diff;

import ca.nengo.math.PDF;
import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Node;
import ca.nengo.model.Noise;
import ca.nengo.model.Projection;
import ca.nengo.model.StructuralException;
import ca.nengo.model.impl.NodeFactory;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.nef.impl.NEFEnsembleImpl;
import ca.nengo.model.neuron.impl.LIFNeuronFactory;
import ca.nengo.util.MU;
import ca.nengo.util.VectorGenerator;
import ca.nengo.util.impl.RandomHypersphereVG;
import com.bptripp.diff.DifferentiatorNetwork;

public class FeedbackNetwork
extends DifferentiatorNetwork {
    private static final long serialVersionUID = 1L;
    private static String INPUT = "input";
    private static String FEEDBACK = "feedback";
    private float[][] myA;
    private float[][] myB;
    private NEFEnsemble myDiff;
    private Projection myInputDiffProjection;
    private Projection myDiffDiffProjection;
    private Projection myDiffOutputProjection;

    public FeedbackNetwork(int[] numInterneurons, float tauPSC, float[][] A, float[][] B, float[][] C) throws StructuralException {
        this.myA = A;
        this.myB = B;
        this.setName("feedback");
        this.getInputEnsemble().addDecodedTermination("input", MU.I((int)1), 0.005f, false);
        this.addProjection(this.getInput().getOrigin("origin"), this.getInputEnsemble().getTermination("input"));
        NEFEnsembleFactoryImpl ef = new NEFEnsembleFactoryImpl();
        DimensionRatioVG encoderFactory = new DimensionRatioVG(true, 1.0f, 1.0f);
        encoderFactory.setRatio(new float[]{numInterneurons[0], numInterneurons[1]});
        ef.setEncoderFactory((VectorGenerator)encoderFactory);
        ef.setNodeFactory((NodeFactory)new LIFNeuronFactory(0.02f, 5.0E-4f, (PDF)new IndicatorPDF(200.0f, 400.0f), (PDF)new IndicatorPDF(-1.2f, 0.95f)));
        int n = numInterneurons[0] + numInterneurons[1];
        this.myDiff = ef.make("diff", n, 2, "feedback_diff_" + numInterneurons[0] + "_" + numInterneurons[1], false);
        ((NEFEnsembleImpl)this.myDiff).setEvalPoints(new RandomHypersphereVG(false, (float)Math.sqrt(2.0), 0.0f).genVectors(300, 2));
        this.addNode((Node)this.myDiff);
        this.getOutputEnsemble().addDecodedTermination("diff", C, 0.005f, false);
        this.myDiff.addDecodedTermination(FEEDBACK, this.getA(tauPSC), tauPSC, false);
        this.myDiff.addDecodedTermination(INPUT, this.getB(tauPSC), tauPSC, false);
        this.myInputDiffProjection = this.addProjection(this.getInputEnsemble().getOrigin("X"), this.myDiff.getTermination("input"));
        this.myDiffDiffProjection = this.addProjection(this.myDiff.getOrigin("X"), this.myDiff.getTermination("feedback"));
        this.myDiffOutputProjection = this.addProjection(this.myDiff.getOrigin("X"), this.getOutputEnsemble().getTermination("diff"));
    }

    private float[][] getA(float tauPSC) {
        return MU.sum((float[][])MU.I((int)this.myA.length), (float[][])MU.prod((float[][])this.myA, (float)tauPSC));
    }

    private float[][] getB(float tauPSC) {
        return MU.prod((float[][])this.myB, (float)tauPSC);
    }

    @Override
    public void disableParisien() {
        this.myInputDiffProjection.removeBias();
        this.myDiffDiffProjection.removeBias();
        this.myDiffOutputProjection.removeBias();
    }

    @Override
    public void enableParisien(float propInhibitory) throws StructuralException {
        int nDiff = Math.round(propInhibitory * (float)this.myDiff.getNodes().length);
        int nOut = Math.round(propInhibitory * (float)this.getOutputEnsemble().getNodes().length);
        FeedbackNetwork.enableParisien(this.myInputDiffProjection, nDiff);
        this.myDiffDiffProjection.addBias(nDiff, TAU_INTERNEURONS, this.myDiffDiffProjection.getTermination().getTau(), true, false);
        FeedbackNetwork.enableParisien(this.myDiffOutputProjection, nOut);
    }

    @Override
    public void clearErrors() {
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise((Noise)new NoiseFactory.NoiseImplNull());
            ((Noise.Noisy)this.myDiff.getOrigin("X")).setNoise((Noise)new NoiseFactory.NoiseImplNull());
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setDistortion(int nInput, int nDiff) {
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise(FeedbackNetwork.makeDistortion(nInput));
            ((Noise.Noisy)this.myDiff.getOrigin("X")).setNoise(FeedbackNetwork.makeDistortion(nDiff));
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setNoise(int nInput, int nDiff) {
        try {
            ((Noise.Noisy)this.getInputEnsemble().getOrigin("X")).setNoise(FeedbackNetwork.makeNoise(nInput));
            ((Noise.Noisy)this.myDiff.getOrigin("X")).setNoise(FeedbackNetwork.makeNoise(nDiff));
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public void setTau(float tau) {
        try {
            DecodedTermination input = (DecodedTermination)this.myDiff.getTermination(INPUT);
            input.setTau(tau);
            input.setTransform(this.getB(tau));
            DecodedTermination feedback = (DecodedTermination)this.myDiff.getTermination(FEEDBACK);
            feedback.setTau(tau);
            feedback.setTransform(this.getA(tau));
        }
        catch (StructuralException e) {
            throw new RuntimeException(e);
        }
    }

    public static class DimensionRatioVG
    extends RandomHypersphereVG {
        private static final long serialVersionUID = 1L;
        float[] myRatio = new float[]{1.0f};

        public DimensionRatioVG(boolean surface, float radius, float axisClusterFactor) {
            super(surface, radius, axisClusterFactor);
        }

        public void setRatio(float[] ratio) {
            this.myRatio = ratio;
        }

        public float[] getRatio() {
            return this.myRatio;
        }

        public float[][] genVectors(int number, int dimension) {
            if (this.myRatio.length < dimension) {
                throw new RuntimeException("Not enough ratios");
            }
            float total = MU.sumToIndex((float[])this.myRatio, (int)(dimension - 1));
            int[] numNeeded = new int[dimension];
            int i = 0;
            while (i < numNeeded.length) {
                numNeeded[i] = Math.round((float)number * this.myRatio[i] / total);
                ++i;
            }
            int[] numGenerated = new int[dimension];
            boolean[] done = new boolean[dimension];
            float[][] result = new float[number][];
            int index = 0;
            while (!DimensionRatioVG.all(done)) {
                float[] vector = super.genVectors(1, dimension)[0];
                int biggestDim = DimensionRatioVG.biggestDimension(vector);
                if (done[biggestDim]) continue;
                int n = biggestDim;
                numGenerated[n] = numGenerated[n] + 1;
                result[index] = vector;
                ++index;
                if (numGenerated[biggestDim] != numNeeded[biggestDim]) continue;
                done[biggestDim] = true;
            }
            return result;
        }

        private static boolean all(boolean[] yes) {
            boolean result = true;
            boolean[] blArray = yes;
            int n = yes.length;
            int n2 = 0;
            while (n2 < n) {
                boolean b = blArray[n2];
                if (!b) {
                    result = false;
                }
                ++n2;
            }
            return result;
        }

        private static int biggestDimension(float[] vector) {
            float biggest = 0.0f;
            int result = 0;
            int i = 0;
            while (i < vector.length) {
                float size = Math.abs(vector[i]);
                if (size > biggest) {
                    biggest = size;
                    result = i;
                }
                ++i;
            }
            return result;
        }
    }
}

