%% This code runs based on the previously published 
%% paper: http://journals.plos.org/ploscompbiol/article?id=10.1371/journal.pcbi.1003787
%% and code: https://senselab.med.yale.edu/modeldb/showModel.cshtml?model=155565
%% note that all files in lib essentially stayed the same. So if you want to run a bigger sheet, please see previous code for the generation of the connectivity matrix.
%% contact yujiang.wang@ncl.ac.uk for more details


clear all
close all
addpath('lib')
%load('Conns_n150.mat') I recommend generating the connectivity matrix once
%and saving it and loading it in future...

tic;
n=150;

%calculate Connectivity Matrix for Py->Py 
CeLoc=GaussianLocConnFunc(n,@distTorus,5);%r_loc=500 micrometre, hence the standard deviation of the gaussian is 250 micrometre, which corresponds to 5 units.
CeLoc(CeLoc>0)=1;
%calculate Connectivity Matrix for Py->In
CeLocI=GaussianLocConnFunc(n,@distTorus,5);%r_loc=500 micrometer
CeLocI(CeLocI>0)=1;
toc;
%remote conn
tic;
nOut=round(mean(sum(CeLoc,1))*4/6);%number of outgoing connections per mini column
%patchSize*numPatches should be > nOut!!!
remRad=75;%3750 micrometers
nM=10;
patchSize=round(5^2*3.14/2);
numPatches=6;
nOverlap=3;
CeRem=ConnPatchyRemOverlap(n,nM,patchSize,numPatches,remRad,nOut,nOverlap,@distTorus,@makeCellClusterToroidal);
toc;


%% parameters
n=150;
tinterp=1;

parameters=getParam(n,CeRem,CeLoc,CeLocI);

P=-1.5;
parameters.PyInput=P*(ones(n^2,1));
%% prescan
        %prescan===============================
        tend=1.5;
        
        T=0:parameters.h*tinterp:tend;%for plotting
        nIt=(tend)/parameters.h+1;
        tic
            %using all 0 initial condition for Py
            initCond=zeros(2*n^2,1);
            parameters.NValue=getNoise(nIt,n);

            Y=runSheet(initCond,parameters);
            bgInitc=Y(end,:);
        toc

        
        %% real run
        
        parameters=getParam(n,CeRem,CeLoc,CeLocI);
        parameters.tauPy=1.5*1*ones(n^2,1)/25;%time scale of the populations
        parameters.tauInh=1.5*0.5*ones(n^2,1)/25;

        tend=15;
        T=0:parameters.h*tinterp:tend;%for plotting
        nIt=(tend)/parameters.h+1;
        %!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        ClusterNum=1; %change this number to change the number of subclusters
        % !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
        PercentHetP=0.05;
        
        parameters.PyInput=P*(ones(n^2,nIt));
        
        CellLocVAll=[];
        CellLocAll=zeros(n,n);
        Pt=1;
        
        for cn=1:ClusterNum
            [CellLoc,CellLocV] = makeCellCluster(1,PercentHetP/ClusterNum,n);%to scan: clustering coeff and percent of bad cells
            CellLocVAll=[CellLocVAll, CellLocV];
            CellLocAll=CellLocAll+CellLoc;
            
            r=1;
            Ramp=[P*ones(1,1000), P:abs((P-Pt))/(1500+r):Pt, Pt*ones(1,5000-r)]; 

%             
            parameters.PyInput(CellLocV,:)=repmat(Ramp,length(CellLocV),1);
            
        end
        CellLocVAll=unique(CellLocVAll);
        
        
        
 %%       
        %real runs===============================
        


        %using initial condition all 1
        tic
        initCond=bgInitc;
        
        parameters.NValue=getNoise(nIt,n);
        
        
        Y=runSheetPRamp(initCond,parameters);
        toc
        %%
        Py=Y(1:tinterp:end,1:n^2); 
        
        Inh=Y(1:tinterp:end,n^2+1:end);
        SCInput=parameters.NValue(:,1:tinterp:end);
        PyInput=parameters.PyInput(:,1:tinterp:end);
        LFP=parameters.Py2Py*double(Py') - parameters.Inh2Py*double(Inh') + 0.1*SCInput;% + PyInput + SCInput;
        LFP=single(LFP');

        mPy=mean(Py,2);sPy=std(Py,0,2);
        mLFP=mean(LFP,2);

        [mMacroCol,mMacroColLFP]=meanMacroCol(n,10,Py,LFP);


        
%%  plot
load('MayColourMap')


figure(1)
plot(T,mLFP)
title('Raw model output LFP')

samplingF=1/T(2);
filtmLFP=FilterEEG(mLFP, 1, samplingF, 'high', 3);
figure(2)
plot(T,filtmLFP)
title('Filtered model output LFP - with DC shift')



figure(5)
TPts=[1501:200:1501+5*200];
for k=1:length(TPts)
    subplot(1,length(TPts),k)
    imagesc(reshape(Py(TPts(k),:),n,n) )
    %colormap(mycmap)
    caxis([0 0.5])
    title(sprintf('T=%g',T(TPts(k))))
end