/*
 * Created on 11-Apr-08
 */
package com.bptripp.diff;

import ca.nengo.dynamics.DynamicalSystem;
import ca.nengo.math.Function;
import ca.nengo.math.impl.IndicatorPDF;
import ca.nengo.model.Node;
import ca.nengo.model.Projection;
import ca.nengo.model.RealOutput;
import ca.nengo.model.SimulationException;
import ca.nengo.model.SimulationMode;
import ca.nengo.model.StructuralException;
import ca.nengo.model.Noise.Noisy;
import ca.nengo.model.impl.FunctionInput;
import ca.nengo.model.impl.NoiseFactory;
import ca.nengo.model.nef.NEFEnsemble;
import ca.nengo.model.nef.NEFEnsembleFactory;
import ca.nengo.model.nef.impl.BiasOrigin;
import ca.nengo.model.nef.impl.DecodedOrigin;
import ca.nengo.model.nef.impl.DecodedTermination;
import ca.nengo.model.nef.impl.NEFEnsembleFactoryImpl;
import ca.nengo.model.neuron.Neuron;
import ca.nengo.model.neuron.impl.LIFNeuronFactory;
import ca.nengo.model.neuron.impl.SpikingNeuron;
import ca.nengo.util.MU;
import ca.nengo.util.Probe;
import ca.nengo.util.TimeSeries;

/**
 * A DifferentiatorNetwork in which differentiation is achieved through short-term synaptic depression. 
 *   
 * @author Bryan Tripp
 */
public class DepressionNetwork extends DifferentiatorNetwork {

	private static final long serialVersionUID = 1L;
	
	private static final String DEPRESSING = "depressing";
	private static final String COMPENSATING = "compensating";

	private NEFEnsemble myDepressingEnsemble;
	private Probe myInputProbe;
	private Projection myDepressingProjection;
	private Projection myCompensatingProjection;

	/**
	 * @param n Number of neurons with depressing synapses (presynaptic depression mechanisms). 
	 *   
	 * @throws StructuralException
	 */
	public DepressionNetwork(int n) throws StructuralException {
		setName("depression");
		
		removeNode(super.getInputEnsemble().getName());
		for (Probe probe : getSimulator().getProbes()) {
			if (probe.getTarget().equals(super.getInputEnsemble())) {
				try {
					getSimulator().removeProbe(probe);
				} catch (SimulationException e) {
					throw new RuntimeException(e);
				}
			}
		}

		String name = super.getInputEnsemble().getName();
		
		myDepressingEnsemble = getLinearFactory().make(name, n, 1, "depression_input_"+n, true);
		myDepressingEnsemble.addDecodedTermination("input", MU.I(1), TAU_IO, false);
		addNode(myDepressingEnsemble);
		try {
			myInputProbe = getSimulator().addProbe(myDepressingEnsemble.getName(), NEFEnsemble.X, true);
		} catch (SimulationException e) {
			throw new RuntimeException(e);
		}
		addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
		
		int maxPoolSize = 100;
		float tauRecovery = 0.5f;
		float proportionReleased = .01f;
		DynamicalSystem depressionDynamics = new SynapticDepressionDynamics(maxPoolSize, tauRecovery, proportionReleased);
		((DecodedOrigin) myDepressingEnsemble.getOrigin(NEFEnsemble.X)).setSTPDynamics(depressionDynamics);

		NEFEnsemble output = getOutputEnsemble();
		float scale = 1f;
		output.addDecodedTermination(DEPRESSING, new float[][]{new float[]{scale}}, TAU_IO, false);
		output.addDecodedTermination(COMPENSATING, new float[][]{new float[]{scale}}, TAU_IO, false);
		
		setCompensation(.1f);		
		addProjections();
	}
	
	private static NEFEnsembleFactory getLinearFactory() {
		NEFEnsembleFactory result = new NEFEnsembleFactoryImpl();
		result.setNodeFactory(new LIFNeuronFactory(.02f, .0005f, new IndicatorPDF(200, 400), new IndicatorPDF(-2.5f, -1.5f)));
		return result;
	}
	
	private void removeProjections() throws StructuralException {
		removeProjection(getInputEnsemble().getTermination("input"));
		removeProjection(getOutputEnsemble().getTermination(DEPRESSING));					
		removeProjection(getOutputEnsemble().getTermination(COMPENSATING));					
	}
	
	private void addProjections() throws StructuralException {
		addProjection(getInput().getOrigin(FunctionInput.ORIGIN_NAME), getInputEnsemble().getTermination("input"));
		myDepressingProjection = addProjection(myDepressingEnsemble.getOrigin(NEFEnsemble.X), getOutputEnsemble().getTermination(DEPRESSING));
		myCompensatingProjection = addProjection(myDepressingEnsemble.getOrigin(COMPENSATING), getOutputEnsemble().getTermination(COMPENSATING));			
	}
	
	@Override
	public void disableParisien() {
		myCompensatingProjection.removeBias();
		myDepressingProjection.removeBias();
	}

	@Override
	public void enableParisien(float propInhibitory) throws StructuralException {
		int n = Math.round(propInhibitory * (float) getOutputEnsemble().getNodes().length);
		enableParisien(myCompensatingProjection, n);
		
		enableParisien(myDepressingProjection, n, false);
		DecodedOrigin o = ((DecodedOrigin) myDepressingEnsemble.getOrigin(NEFEnsemble.X));
		BiasOrigin bo = ((BiasOrigin) myDepressingEnsemble.getOrigin("output:depressing")); 
		bo.setSTPDynamics(o.getSTPDynamics());
		for (int i = 0; i < bo.getDecoders().length; i++) {
			((SynapticDepressionDynamics) bo.getSTPDynamics(i)).setTau(
					((SynapticDepressionDynamics) o.getSTPDynamics(i)).getTau());
			((SynapticDepressionDynamics) bo.getSTPDynamics(i)).setProportionReleased(
					((SynapticDepressionDynamics) o.getSTPDynamics(i)).getProportionReleased());
		}
	}

	@Override
	protected NEFEnsemble getInputEnsemble() {
		return myDepressingEnsemble;
	}

	@Override
	public TimeSeries getInputEnsembleData() {
		return myInputProbe.getData();
	}

	@Override
	public void setTau(float tau) {
		Node[] neurons = myDepressingEnsemble.getNodes();
		try {
			DecodedOrigin o = (DecodedOrigin) myDepressingEnsemble.getOrigin(NEFEnsemble.X);
			for (int i = 0; i < neurons.length; i++) {
				SpikingNeuron neuron = (SpikingNeuron) neurons[i];
				float r0 = getNominalRate(neuron);
				
				//choose F so at equilibrium S = 1/2 at r0 (this determines tauS)
				float F = 1 / (2*tau*r0);
				float tauS = 2*tau;
				SynapticDepressionDynamics d = (SynapticDepressionDynamics) o.getSTPDynamics(i); //note the index i
				d.setTau(tauS);
				d.setProportionReleased(F);
				System.out.println("tauS: " + tauS + " F:" + F);
			}		
			
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
		
		try {
			removeProjections();
			setCompensation(tau);
			addProjections();

			float[][] scale = new float[][]{new float[]{4.4f / tau}};
			((DecodedTermination) getOutputEnsemble().getTermination(DEPRESSING)).setTransform(scale);
			((DecodedTermination) getOutputEnsemble().getTermination(COMPENSATING)).setTransform(scale);
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}		
	}
	
	private void setCompensation(float tau) throws StructuralException {
		try {
			myDepressingEnsemble.removeDecodedTermination("input"); //have to remove this temporarily
			Function f = Util.getBiasCompensation(myDepressingEnsemble, NEFEnsemble.X, tau*8);
			myDepressingEnsemble.addDecodedTermination("input", MU.I(1), TAU_IO, false);
			myDepressingEnsemble.addDecodedOrigin(COMPENSATING, new Function[]{f}, Neuron.AXON);
		} catch (SimulationException e) {
			throw new StructuralException(e);
		}
	}	
	
	private static float getNominalRate(SpikingNeuron neuron) {
		SimulationMode mode = neuron.getMode();
		float rate = 0;
		
		try {
			neuron.setMode(SimulationMode.CONSTANT_RATE);
			neuron.setRadialInput(0);
			neuron.run(0, 0);
			rate = ((RealOutput) neuron.getOrigin(Neuron.AXON).getValues()).getValues()[0];
			neuron.setMode(mode);
		} catch (SimulationException e) {
			throw new RuntimeException(e);
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
		
		return rate;
	}

	@Override
	public void clearErrors() {
		try {
			((Noisy) myDepressingEnsemble.getOrigin(NEFEnsemble.X)).setNoise(new NoiseFactory.NoiseImplNull());
			((Noisy) myDepressingEnsemble.getOrigin(COMPENSATING)).setNoise(new NoiseFactory.NoiseImplNull());
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public void setDistortion(int nInput, int nDiff) {
		try {
			((Noisy) myDepressingEnsemble.getOrigin(NEFEnsemble.X)).setNoise(makeDistortion(nInput));
			((Noisy) myDepressingEnsemble.getOrigin(COMPENSATING)).setNoise(makeDistortion(nInput));
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}

	@Override
	public void setNoise(int nInput, int nDiff) {
		try {
			((Noisy) myDepressingEnsemble.getOrigin(NEFEnsemble.X)).setNoise(makeNoise(nInput));
			((Noisy) myDepressingEnsemble.getOrigin(COMPENSATING)).setNoise(makeNoise(nInput));
		} catch (StructuralException e) {
			throw new RuntimeException(e);
		}
	}
	
	public static void main(String[] args) {
		try {
			DepressionNetwork n = new DepressionNetwork(500);
			n.setTau(.5f);
		} catch (StructuralException e) {
			e.printStackTrace();
		}
	}

}