function [PerCorrect, FiringRate, FIndex, W12perTrial, W23perTrial, RasterPlot] = ...
    spikingNetworkContextLearning(nTrial)
% spikingNetworkContextLearning
%   nTrial  - Number of trials. Numerous other parameters are specified in 
%             the function itself.
%
% RETURN
%   PerCorrect  - Reward flag (1 = rewarded dig, 0 = otherwise) per trial 
%                 with dimensions: nTrial x 1.
%   FiringRate  - Firing rate as computed by rasterPlotToFiringRate, with 
%                 dimensions: nTrial x nStim x nHippo.
%   FIndex      - Functional index as binary vector. A strong weight 
%                 connection from the 1st to 2nd layer indicates this 
%                 connection being functional. This index has the 
%                 dimensions: 1 x nHippo.
%   W12perTrial - Weight matrix from 1st to 2nd layer for all trials. This 
%                 matrix has the dimensions: nTrial x nInput x nHippo.
%   W23perTrial - Weight matrix from 2nd to 3rd layer for all trials. This
%                 matrix has the dimensions: nTrial x nHippo x nOutput.
%   RasterPlot  - ManySlotBuffer with nStim*nHippo slots. Each entry added
%                 to a slot is a column [iTrial; nSample; Spikes], where 
%                 Spikes flags for every time step of the current step of 
%                 the trial whether the hippocampal cell was at V_PEAK.
%
% DESCRIPTION
%   This is the main network simulation. It includes the
%   - initialization of the network
%   - the spiking simulation for each trial with its phase of the trial and 
%     phase of replay.
%   - during replay the synaptic weights between 1st/2nd and 2nd/3rd layer
%     are adapted.
%

%   Florian Raudies, 09/07/2014, Boston University.

IntvlTrial  = [0 4000]; % ms
IntvlReplay = [0 400];  % ms
dt          = 0.5;      % time step
dtWindow    = 10;       % ms
TimeTrial   = ( IntvlTrial(1)  : dt : IntvlTrial(2)  )';
TimeReplay  = ( IntvlReplay(1) : dt : IntvlReplay(2) )';
nTimeTrial  = length(TimeTrial);
nTimeReplay = length(TimeReplay);
nWindow     = dtWindow/dt;  % Length of the STDP window in time steps.
V_PEAK      = 0;            % in Volts, these are 0 mV.
V_TH        = -50*10^-3;    % in Volts, these are -50 mV.
V_RESET     = -70*10^-3;    % in Volts, these are -70 mV.
ETA         = 10^-6;        % Threshold for equal

%            Context and Place   Odor
% Input for: A1 | B1 | A2 | B2 | X | Y
InVec       = [1 0 0 0 1 0];
nPlace      = 4;            % Number of context/place units (A1,B1,A2,B2).

% Per definition the rewarded stimuli are: A1X  A2X  B1Y  B2Y.
reward      = @(X) (X(1) && X(5)) || (X(3) && X(5)) ...
                || (X(2) && X(6)) || (X(4) && X(6));
% Stimulus index in 1...8: the place index (1...4) plus 4 for odor Y.
stimulusIndex = @(X) 1*X(1)+2*X(2)+3*X(3)+4*X(4)+4*X(6);

% Output action vector: Dig | Move
OutVec  = [0 0];

% Number of neurons to simulate the hippocampus.
nHippo  = 8;
nIn     = length(InVec);
nOut    = length(OutVec);
nStim   = 8;

% Randomly initialize the synaptic coupling strengths (weights).
% Per random some of these have place cell selectivity.
W12 = rand(nIn,nHippo);
W23 = rand(nHippo,nOut);
W22 = ones(nHippo,nHippo)   -eye(nHippo); % Inhibition weights
W33 = ones(nOut,nOut)       -eye(nOut);

% Number of steps in the history.
% NOTE(review): HippoVecHst and OutVecHst are not cleared between trials; a
% row is only overwritten when a spike occurs in the corresponding step.
nHist       = 2;
InVecHst    = zeros(nHist,nIn);
HippoVecHst = zeros(nHist,nHippo);
OutVecHst   = zeros(nHist,nOut);

% Buffer for spikes within the STDP window.
nMaxSpike   = 10;
T1          = TimeBuffer(nMaxSpike,nIn,dtWindow);
T2          = TimeBuffer(nMaxSpike,nHippo,dtWindow);
T3          = TimeBuffer(nMaxSpike,nOut,dtWindow);

% For performance reasons randomize all indices.
Index       = rand(nTrial,2);

% Percent correct detected.
PerCorrect  = zeros(nTrial,1);
nMaxTime    = 1200;
nMaxSample  = 100;
RasterPlot  = ManySlotBuffer(nStim*nHippo,nMaxSample,nMaxTime);

% Set the options for the LIF neuron and STDP rule.
opt.V_PEAK  = V_PEAK;
opt.V_TH    = V_TH;
opt.V_RESET = V_RESET;
opt.dt      = dt;

% Define matrices for weights per trial.
W12perTrial = zeros(nTrial,nIn,nHippo);
W23perTrial = zeros(nTrial,nHippo,nOut);

% *************************************************************************
% Loop over all trials.
% *************************************************************************
for iTrial = 1:nTrial
    % Start trial in a random state: exactly one place unit and one odor
    % unit are active. floor(u*nPlace) selects each of the nPlace places 
    % with equal probability; round(u*(nPlace-1)) would select the first 
    % and last place only half as often as the places in between.
    InVec = zeros(1,nIn);
    InVec(1+floor(Index(iTrial,1)*nPlace))      = 1;
    InVec(1+nPlace+(Index(iTrial,2)>0.5))       = 1;
    % Reset counter for buffers.
    nHistCount  = 0;
    InVecHst(1+nHistCount,:) = InVec;
    % Membrane potential of the 2nd layer per time step. This trace is used
    % to extract the spikes for the raster plot.
    TraceV2     = nan(nTimeTrial,nHippo);
    rewarded    = 0;
    nMoveSpike  = 0;
    nDigSpike   = 0;
    % Initialize membrane potentials.
    V1      = repmat(V_RESET,[1 nIn]);
    V2      = repmat(V_RESET,[1 nHippo]);
    V3      = repmat(V_RESET,[1 nOut]);
    % One noise sample per time step and layer; the same sample is added to
    % all neurons of a layer.
    NoiseV1 = 10^-6*randn(nTimeTrial,1);
    NoiseV2 = 10^-6*randn(nTimeTrial,1);
    NoiseV3 = 10^-6*randn(nTimeTrial,1);
    nThDigSpike     = 5;
    nThMoveSpike    = 5;
    % First time index of the current step within the trial.
    iLastTime       = 1;
    for iTime = 1 : nTimeTrial
        t           = TimeTrial(iTime);
        % 1st layer is driven by the input vector.
        opt.I       = InVec;
        opt.G       = repmat(.1,[1 nIn]);
        V1          = lifModel(t, V1,opt) + NoiseV1(iTime);
        % 2nd layer: Only the neuron with the largest weighted drive minus 
        % lateral inhibition receives an input current.
        [~,mi]      = max((V1-V_RESET)*W12 - (V2-V_RESET)*W22);
        opt.I       = zeros(1,nHippo);
        opt.I(mi)   = .98; % nA
        V2          = lifModel(t, V2,opt) + NoiseV2(iTime);
        % 3rd layer: Same winner-take-all selection for the output.
        [~,mi]      = max((V2-V_RESET)*W23 - (V3-V_RESET)*W33);
        opt.I       = zeros(1,nOut);
        opt.I(mi)   = .96; % nA
        V3          = lifModel(t, V3,opt) + NoiseV3(iTime);
        TraceV2(iTime,:) = V2;
        % A neuron counts as spiking if its potential is at V_PEAK.
        SpikeV2 = abs(V2-V_PEAK)<=ETA;
        SpikeV3 = abs(V3-V_PEAK)<=ETA;
        % Register any spikes at the output in the output vector.
        OutVec = zeros(1,nOut);
        OutVec(SpikeV3) = 1;
        % Keep the history of the states/firings.
        if any(SpikeV2)
            HippoVecHst(1+nHistCount,:) = double(SpikeV2);
        end
        if any(SpikeV3)
            OutVecHst(1+nHistCount,:) = double(SpikeV3);
        end
        nMoveSpike  = nMoveSpike    + OutVec(2);
        nDigSpike   = nDigSpike     + OutVec(1);
        % Dig ?
        if nDigSpike>=nThDigSpike
            OutVec(1)       = 1;
            OutVecHst(1+nHistCount,:) = OutVec;
            rewarded = reward(InVec);
            % Store the spikes of all hippocampal cells for this step.
            for iHippo = 1:nHippo
                iSlot   = sub2ind([nStim nHippo],stimulusIndex(InVec),iHippo);
                DataRow = [iTrial; iTime-iLastTime+1; ...
                    abs(TraceV2(iLastTime:iTime,iHippo)-V_PEAK)<=ETA];
                RasterPlot.addEntryToSlot(iSlot,DataRow);
            end
            % Digging ends the trial.
            break;
        end
        % Move?
        if nMoveSpike>=nThMoveSpike
            nThMoveSpike    = 5;
            % Every move lowers the number of spikes required for digging.
            nThDigSpike     = max(nThDigSpike - 1,  0);
            nMoveSpike      = 0;
            OutVec(2)       = 1;
            OutVecHst(1+nHistCount,:) = OutVec;
            % Store the spikes of all hippocampal cells for this step.
            for iHippo = 1:nHippo
                iSlot   = sub2ind([nStim nHippo],stimulusIndex(InVec),iHippo);
                DataRow = [iTrial; iTime-iLastTime+1; ...
                    abs(TraceV2(iLastTime:iTime,iHippo)-V_PEAK)<=ETA];
                RasterPlot.addEntryToSlot(iSlot,DataRow);
            end
            iLastTime = iTime+1;
            % Move to the other place.
            Tmp = InVec(3:4);
            InVec(3:4) = InVec(1:2);
            InVec(1:2) = Tmp;
            % Then percept changes too (the odor toggles between X and Y).
            InVec(5) = InVec(6);
            InVec(6) = ~InVec(5);
            % Increment the counter for the buffer (wraps around at nHist).
            nHistCount = mod(nHistCount + 1,nHist);
            InVecHst(1+nHistCount,:) = InVec;
            % Assume there is a break and all the membrane potential return
            % to their resting state.
            V1 = repmat(V_RESET,[1 nIn]);
            V2 = repmat(V_RESET,[1 nHippo]);
            V3 = repmat(V_RESET,[1 nOut]);
        end
    end
    PerCorrect(iTrial) = rewarded;
    
    % Replay the sequence with the last 1+nHistCount steps.
    for iHist = 1 : (1+nHistCount)
        InVec       = InVecHst(iHist,:);
        HippoVec    = HippoVecHst(iHist,:);
        OutVec      = OutVecHst(iHist,:);
        % Assume there was a break and all membrane potentials return
        % to their resting state value.
        V1          = repmat(V_RESET,[1 nIn]);
        V2          = repmat(V_RESET,[1 nHippo]);
        V3          = repmat(V_RESET,[1 nOut]);
        % Clear out any remaining spike times.
        T1.clear();
        T2.clear();
        T3.clear();
        % Start the replay sequence.
        for iTime = 1 : nTimeReplay
            t       = TimeReplay(iTime);
            % Replay in forward direction --- per STDP strengthening
            if rewarded
                opt.I   = InVec;
                V1      = lifModel(t, V1, opt);
                opt.I   = .98*HippoVec;
                V2      = lifModel(t, V2, opt);
                opt.I   = .96*OutVec;
                V3      = lifModel(t, V3, opt);
            % Replay in inverse direction --- per STDP weakening
            else
                opt.I   = .96*InVec;
                V1      = lifModel(t, V1, opt);
                opt.I   = .98*HippoVec;
                V2      = lifModel(t, V2, opt);
                opt.I   = OutVec;
                V3      = lifModel(t, V3, opt);
            end
            % Retire spike times which are too old.
            T1.retire(t);
            T2.retire(t);
            T3.retire(t);
            % Register new spike times.
            T1.addTime(t,abs(V1-V_PEAK)<=ETA);
            T2.addTime(t,abs(V2-V_PEAK)<=ETA);
            T3.addTime(t,abs(V3-V_PEAK)<=ETA);
            % Update the synaptic weights, but only after nWindow time 
            % steps of the replay have elapsed.
            if iTime >= nWindow
                % Weights from the 1st to the 2nd layer.
                for iIn = 1:nIn
                    TimePre = T1.time(iIn);
                    if isempty(TimePre), continue; end
                    for iHippo = 1:nHippo
                        TimePost = T2.time(iHippo);
                        % Are there any spikes for the pre- and the 
                        % post-synaptic neuron in the time window?
                        if ~isempty(TimePost)
                            opt.TimePre     = TimePre;
                            opt.TimePost    = TimePost;
                            W12(iIn,iHippo) = stdpModel(...
                                                t,W12(iIn,iHippo),opt);
                        end
                    end
                end
                % Weights from the 2nd to the 3rd layer.
                for iHippo = 1:nHippo
                    TimePre = T2.time(iHippo);
                    if isempty(TimePre), continue; end
                    for iOut = 1:nOut
                        TimePost = T3.time(iOut);
                        % Are there any spikes for the pre- and the 
                        % post-synaptic neuron in the time window?
                        if ~isempty(TimePost)
                            opt.TimePre     = TimePre;
                            opt.TimePost    = TimePost;
                            W23(iHippo,iOut) = stdpModel(...
                                                t,W23(iHippo,iOut),opt);
                        end
                    end
                end
            end
        end
    end
    % Keep the weights as they are after the replay of this trial.
    W12perTrial(iTrial,:,:) = W12;
    W23perTrial(iTrial,:,:) = W23;
end

opt.nTrial      = nTrial;
opt.nStim       = nStim;
opt.nCell       = nHippo;
opt.nMaxSample  = nMaxSample;

% Calculate the firing rate for each hippocampal cell, trial, and stimulus. 
FiringRate = rasterPlotToFiringRate(RasterPlot, opt);
% Calculate a binary vector indicating whether a hippocampal cell is part
% of the functional network or not. A cell is part of it if at least one of
% its incoming weights from the 1st layer exceeds 1-ETA.
FIndex = max(W12,[],1)>(1-ETA);