% Simulation setup for the results presented in Section 5.
% Analysis of synaptic integration of spikes from multiple heterogeneous neurons
% via the transient semi-analytical method.

%-------------------------------------------------------------------------------------
% Copyright 2018 by Koc University and Deniz Kilinc, A. Gokcen Mahmutoglu, Alper Demir 
% All Rights Reserved
%-------------------------------------------------------------------------------------

% Start from a clean session: drop cached class definitions and workspace
% variables, close all figures, and clear the command window.
clear('classes');
clear('all');
close('all');
clc;

% Put the cirsiumNeuron library and its two solver sub-folders on the
% MATLAB path. Paths are relative to the current working folder and are
% added one at a time in this order.
cirsiumRoot = '../cirsiumNeuron';
addpath(cirsiumRoot);
addpath([cirsiumRoot '/steady_state_method']);
addpath([cirsiumRoot '/transient_method']);
clear('cirsiumRoot');

% Stimulus timing and membrane/current parameters.
% NOTE(review): units look like SI (m^2, A/m^2, s) -- confirm.
T_vIn = 500e-3;                 % time parameter passed to cosTapRect; also the simulation end time
membraneArea = 1000e-12;        % membrane area shared by every neuron
currentDensity = 25e-2;         % input current density
Ihigh = currentDensity*membraneArea;   % amplitude given to each currentSource

% Network topology: Ninput input neurons, each driven by its own current
% source and each making an excitatory ('exReceptorMC') synapse onto a
% single output neuron. The network size is controlled only by Ninput, so
% it can no longer drift out of sync with the neuron_out arity, Kind, or
% the output row index used further down.
Ninput = 4;

% Input neurons Neuron1..NeuronN.
neurons = cell(1, Ninput);
for k = 1:Ninput
    neurons{k} = neuronMC(membraneArea, 0, 0, 0, 'MC', sprintf('Neuron%d', k));
end
% The output neuron receives Ninput as its second constructor argument
% (presumably its number of synaptic inputs -- TODO confirm in neuronMC).
neuron_out = neuronMC(membraneArea, Ninput, 0, 0, 'MC', 'Neuron_out');

% Excitatory synapse from every input neuron onto the output neuron.
for k = 1:Ninput
    neurons{k}.synapse(neuron_out, 'exReceptorMC');
end

% One current source per input neuron, all sharing the same waveform.
f_vIn = @(t)(cosTapRect(t, T_vIn, 1/1000, 1, 1));
iIn = cell(1, Ninput);
for k = 1:Ninput
    iIn{k} = currentSource(Ihigh, f_vIn, sprintf('Iin%d', k));
end

ckt = circuitMC('Single Neuron Test');

% Registration order: input neurons on Node1..NodeN, then the output
% neuron on Node_out, then the current sources. Keeping this order keeps
% the output neuron's variable immediately after the Ninput input nodes.
for k = 1:Ninput
    ckt.addComponent(neurons{k}, sprintf('Node%d', k), 'gnd');
end
ckt.addComponent(neuron_out, 'Node_out', 'gnd');
for k = 1:Ninput
    ckt.addComponent(iIn{k}, 'gnd', sprintf('Node%d', k));
end
ckt.setGroundNode('gnd');
ckt.seal();

% solver settings
solverType = 'SEQ-TRPZ';
linSolverType = 'GMRES';
currentVariables = 'on';   % must be 'on' or 'off'; selects the variable count below
showSolution = 'off';
t0 = 0;
tend = 1*T_vIn;            % simulation end time equals T_vIn
timeStep = 1e-5;

% initial conditions
nTr = ckt.numMCs;          % number of MC components; iterated when building s0
switch currentVariables
    case 'on'
        nV = ckt.numVars;
        nVV = ckt.numIndepVoltVars;
    case 'off'
        nV = ckt.numVarsnc;
        nVV = nV;
    otherwise
        % Without this branch nV would be left undefined and the script
        % would fail later with an unrelated "undefined variable" error.
        error('Unsupported currentVariables setting ''%s''; expected ''on'' or ''off''.', ...
              currentVariables);
end

% Initial node-variable vector: all nV entries start at -0.065.
% NOTE(review): presumably -65 mV resting potential in volts; with
% currentVariables 'on' this also seeds the current variables -- confirm.
y0 = repmat(-0.065, nV, 1);

inhibitory_channels = 20000; % number of inhibitory receptor channels per synapse
excitatory_channels = 20000; % number of excitatory receptor channels per synapse

% Initial ion-channel state: each MC's stateVector, scaled by its channel
% count. Receptors use the per-synapse counts above; all other MCs are
% scaled by membraneArea/1e-12 (area in units of 1e-12, i.e. um^2 if
% membraneArea is in m^2 -- TODO confirm). Pieces are stacked in MC order.
stateParts = cell(nTr, 1);
for m = 1:nTr
    mc = ckt.MCs{m};
    if isa(mc, 'inReceptorMC')
        stateParts{m} = inhibitory_channels*mc.stateVector;
    elseif isa(mc, 'exReceptorMC')
        stateParts{m} = excitatory_channels*mc.stateVector;
    else
        stateParts{m} = mc.stateVector*(membraneArea/1e-12);
    end
end
s0 = vertcat(stateParts{:});

%% noisy transient simulation
% Name/value options for transientSolverMC, one pair per row; the cell is
% transposed so that {:} expands as name1, value1, name2, value2, ...
transientOpts = {
    'solverType',        solverType
    'linSolverType',     linSolverType
    'showSolution',      showSolution
    'currentVariables',  currentVariables
    'y0',                y0
    'yp0',               []
    't0',                t0
    's0',                s0
    'sp0',               []
    'seq0',              s0
    'timeStep',          timeStep
    'tend',              tend
    'breakPoints',       []
    'reltol',            1e-3
    'abstol_v',          1e-6
    'abstol_c',          1e-9
    'abstol_q',          1e-21
    'chargeScaleFactor', 1e4/T_vIn
    'lmax',              1e12
    };
transientOpts = transientOpts.';
solver = transientSolverMC(ckt, transientOpts{:});

% Time the solve; elapsed seconds are stored in duration (not printed).
tic
solver.solve();
duration = toc;

solver.displaySolution;

%% variance calculation with transient method
% Note: solverType and timeStep are overwritten here for the second solver.
solverType = 'TRPZ-RK';
stateSpaceType = 'standard';
timeStep = 2.0e-6;
tref = tend;
Kind = 1:(Ninput+1);   % indices 1..Ninput+1 passed as 'Kind' (presumably the variables whose K is computed -- confirm)

% Name/value options for nonMonteCarloSolverMC, one pair per row; the cell
% is transposed so that {:} expands as name1, value1, name2, value2, ...
tdOpts = {
    'solverType',        solverType
    'linSolverType',     linSolverType
    'showSolution',      showSolution
    'currentVariables',  currentVariables
    'y0',                y0
    'yp0',               []
    't0',                t0
    's0',                s0
    'sp0',               []
    'seq0',              s0
    'timeStep',          timeStep
    'tend',              tend
    'tref',              tref
    'Kind',              Kind
    'breakPoints',       []
    'reltol',            1e-3
    'abstol_v',          1e-6
    'abstol_c',          1e-9
    'abstol_q',          1e-21
    'chargeScaleFactor', 1e2/T_vIn
    'lmax',              1e12
    };
tdOpts = tdOpts.';
solverTD = nonMonteCarloSolverMC(ckt, tdOpts{:});

% Time the solve and print the elapsed time.
tic
solverTD.solve();
toc

solverTD.displaySolution;
%% output-voltage slope and Var[alpha_out]
% Row Ninput+1 of solverTD.Y is taken as the output neuron's voltage (the
% variable registered right after the Ninput input nodes -- TODO confirm).
Vout = solverTD.Y(Ninput+1,:);
t = solverTD.T;

% Numerical dVout/dt using the actual step sizes, so the result stays
% correct on a non-uniform time grid (a single t(3)-t(2) step is only
% valid for uniform stepping). t is forced to a row to match Vout and
% avoid implicit broadcasting if the solver returns a column.
dVout = diff(Vout) ./ diff(t(:).');

% Peak slope over (up to) the last 10001 derivative samples; the start
% index is clamped so a short run does not produce an invalid index.
winStart = max(1, numel(dVout) - 10000);
dVout_max = max(dVout(winStart:end));

% K of the output variable divided by the squared peak slope.
Var_alpha_out = solverTD.K(Ninput+1,:)/dVout_max^2;

figure
plot(t,Var_alpha_out); grid on;
title('Var[\alpha_{out}(t)]');

