/*
 * Created on 1-Apr-08
 */
package com.bptripp.diff;

import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Projection;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Noise.Noisy;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.NEFEnsembleFactory;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.nef.impl.NEFEnsembleImpl;
import ca.nengo.model.neuron.impl.LIFNeuronFactory;
import ca.nengo.util.MU;
import ca.nengo.util.impl.RandomHypersphereVG;

/**
 * A differentiator network in which band-pass input-output behavior arises from feedack dynamics. 
 * 
 * @author Bryan Tripp
 */
public class FeedbackNetwork extends DifferentiatorNetwork {

	private static final long serialVersionUID = 1L;

	private static String INPUT = "input";
	private static String FEEDBACK = "feedback";

	private float[][] myA;
	private float[][] myB;
	private NEFEnsemble myDiff;
	private Projection myInputDiffProjection;
	private Projection myDiffDiffProjection;
	private Projection myDiffOutputProjection;
	
	public FeedbackNetwork(int[] numInterneurons, float tauPSC, float[][] A, float[][] B, float[][] C) throws StructuralException {
		myA = A;
		myB = B;
		
		setName("feedback");
						
		getInputEnsemble().addDecodedTermination("input", MU.I(1), TAU_IO, false);
		addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
		
		//2D differentiator ensemble with specified # neurons along each dim
		NEFEnsembleFactory ef = new NEFEnsembleFactoryImpl();
		DimensionRatioVG encoderFactory = new DimensionRatioVG(true, 1, 1);
		encoderFactory.setRatio(new float[]{numInterneurons[0], numInterneurons[1]});
		ef.setEncoderFactory(encoderFactory);
		ef.setNodeFactory(new LIFNeuronFactory(.02f, .0005f, new IndicatorPDF(200, 400), new IndicatorPDF(-1.2f, .95f)));
		
		int n = numInterneurons[0]+numInterneurons[1];
		myDiff = ef.make("diff", n, 2, "feedback_diff_"+numInterneurons[0]+"_"+numInterneurons[1], false);
		((NEFEnsembleImpl) myDiff).setEvalPoints(new RandomHypersphereVG(false, (float) Math.sqrt(2), 0).genVectors(300, 2));
		addNode(myDiff);
		
		getOutputEnsemble().addDecodedTermination("diff", C, TAU_IO, false);
		
		myDiff.addDecodedTermination(FEEDBACK, getA(tauPSC), tauPSC, false);
		myDiff.addDecodedTermination(INPUT, getB(tauPSC), tauPSC, false);

		myInputDiffProjection = addProjection(getInputEnsemble().getOrigin(NEFEnsemble.X), myDiff.getTermination("input"));
		myDiffDiffProjection = addProjection(myDiff.getOrigin(NEFEnsemble.X), myDiff.getTermination("feedback"));
		myDiffOutputProjection = addProjection(myDiff.getOrigin(NEFEnsemble.X), getOutputEnsemble().getTermination("diff"));
	}
		
	
	private float[][] getA(float tauPSC) {
		return MU.sum(MU.I(myA.length), MU.prod(myA, tauPSC));		
	}
	
	private float[][] getB(float tauPSC) {
		return MU.prod(myB, tauPSC);		
	}

	@Override
	public void disableParisien() {
		myInputDiffProjection.removeBias();
		myDiffDiffProjection.removeBias();
		myDiffOutputProjection.removeBias();
	}

	@Override
	public void enableParisien(float propInhibitory) throws StructuralException {
		int nDiff = Math.round(propInhibitory * (float) myDiff.getNodes().length);
		int nOut = Math.round(propInhibitory * (float) getOutputEnsemble().getNodes().length);
		enableParisien(myInputDiffProjection, nDiff);
		myDiffDiffProjection.addBias(nDiff, TAU_INTERNEURONS, myDiffDiffProjection.getTermination().getTau(), true, false);
		enableParisien(myDiffOutputProjection, nOut);
	}

	@Override
	public void clearErrors() {
		try {
			((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
			((Noisy) myDiff.getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}

	/**
	 * Note: sets equal error on each output of the differentiator ensemble
	 */
	@Override
	public void setDistortion(int nInput, int nDiff) {
		try {
			((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nInput));
			((Noisy) myDiff.getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nDiff));
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public void setNoise(int nInput, int nDiff) {
		try {
			((Noisy) getInputEnsemble().getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nInput));
			((Noisy) myDiff.getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nDiff));
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public void setTau(float tau) {
		try {
			DecodedTermination input = (DecodedTermination) myDiff.getTermination(INPUT);
			input.setTau(tau);
			input.setTransform(getB(tau));
			DecodedTermination feedback = (DecodedTermination) myDiff.getTermination(FEEDBACK);
			feedback.setTau(tau);
			feedback.setTransform(getA(tau));
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}
	
	/**
	 * Allows us to specify ratio of vectors generated along each dimension. 
	 * 
	 * @author Bryan Tripp
	 */
	public static class DimensionRatioVG extends RandomHypersphereVG {

		private static final long serialVersionUID = 1L;
		float[] myRatio;
		
		public DimensionRatioVG(boolean surface, float radius, float axisClusterFactor) {
			super(surface, radius, axisClusterFactor);
			myRatio = new float[]{1};
		}
		
		public void setRatio(float[] ratio) {
			myRatio = ratio;
		}
		
		public float[] getRatio() {
			return myRatio;
		}
		
		@Override
		public float[][] genVectors(int number, int dimension) {
			if (myRatio.length < dimension) {
				throw new RuntimeException("Not enough ratios");
			}
			
			float total = MU.sumToIndex(myRatio, dimension-1);
			int[] numNeeded = new int[dimension];
			for (int i = 0; i < numNeeded.length; i++) {
				numNeeded[i] = Math.round((float) number * myRatio[i] / total);
			}			
			int[] numGenerated = new int[dimension];
			boolean[] done = new boolean[dimension];
			
			float[][] result = new float[number][];
			int index = 0;
			while (!all(done)) {
				float[] vector = super.genVectors(1, dimension)[0];
				int biggestDim = biggestDimension(vector);
				if (!done[biggestDim]) {
					numGenerated[biggestDim]++;
					result[index] = vector;
					index++;
					if (numGenerated[biggestDim] == numNeeded[biggestDim]) done[biggestDim] = true;
				}
			}
			
			return result;
		}
		
		private static boolean all(boolean[] yes) {
			boolean result = true;
			for (boolean b : yes) {
				if (!b) result = false;
			}
			return result;
		}
		
		private static int biggestDimension(float[] vector) {
			float biggest = 0;
			int result = 0;
			
			for (int i = 0; i < vector.length; i++) {
				float size = Math.abs(vector[i]); 
				if (size > biggest) {
					biggest = size;
					result = i;
				}
			}
			
			return result;			
		}
		
	}
	
}