function[in,pos,LFP,gridCells,cycle,cycleDecoding] = humanGridSim(in)
%% Function to simulate grid cell activity in a 1D or 2D environment using
%  a phenomenological model adapted from Chadwick et al. eLife 4:e03542 
%  (2015), and then decode movement trajectory from grid cell firing rate 
%  and phase in each oscillatory cycle. 
%  Daniel Bush, UCL (2019) drdanielbush@gmail.com
%
%  Described in Bush and Burgess (2020) Detection and advantages of phase
%  coding in the absence of rhythmicity. Hippocampus (in press)
%  https://doi.org/10.1002/hipo.23199
%
%  NOTE:
%  Requires Tom O'Haver's fast smoothing function from:
%  https://www.mathworks.com/matlabcentral/fileexchange/19998-fast-smoothing-function
%
%  USAGE:
%  [in,pos,LFP,gridCells,cycle,cycleDecoding] = humanGridSim(in)
%  'in' may be omitted or empty, in which case every setting takes its
%  default value.
%
%  INPUTS, as fields of the 'in' structure, default values indicated in []:
%  gridScales   = Grid scales (cm), for five modules [30*1.4.^(0:4)]
%  nGCs         = Number of grid cells per module [40]
%  sampleRate   = Simulation sampling rate (Hz) [200]
%  sigmas       = Firing field width parameter (cm) [in.gridScales./10]
%  meanRate     = Mean firing rate (Hz) [1]
%  phaseLock    = Phase locking (true) or precession (false)? [false]
%  phaseMod     = Extent of phase coding 'k' (au) [1.5]
%  speedSlope   = Slope of running speed v firing rate (Hz/cm/s) [5/30]
%  environment  = Environment type; the tracking data generator handles
%                 '1D', '1Dlong', '2D' and '2Dlin' ['2D']
%  trackFile    = 2D tracking data file [ceil(rand*3)]
%  vRange       = Moving speed range for 1D environments (cm/s) [2 30]
%  lfpType      = LFP type ('constant' or 'human') ['human']
%  freqRange    = LFP frequency range for human LFP (Hz) [2 20]
%  nPhaseBins   = Number of phase bins for decoding analysis [5]
%  fieldRateVar = Standard deviation of in-field firing variability [0]
%
%  OUTPUTS:
%  in            = Settings structure with all defaults filled in
%  pos           = Tracking data (x_log, y_log, v_log, t_log, dt)
%  LFP           = LFP data (eeg, phase, freq, anxiety)
%  gridCells     = Simulated grid cell output; only the spike trains are
%                  returned, the r_rate / r_phase / r_pure fields are
%                  removed before returning to save memory
%  cycle         = Data chunked by oscillatory cycle
%  cycleDecoding = Decoding results for each oscillatory cycle



%% Provide some general parameters for the analysis
%  ...allow the function to be called without any settings at all, as every
%  field of 'in' has a default value
if nargin < 1 || isempty(in)
    in          = struct();
end
in              = dealWithInputs(in);

%  ...generate or load some tracking data
disp('Generating / importing tracking data...'); drawnow
pos             = trackingData(in);

%  ...generate or load some LFP data
disp('Generating / importing LFP data...'); drawnow
LFP             = lfpData(in,pos);

%  ...simulate grid cell firing patterns
disp('Generating grid cell spike trains...'); drawnow
gridCells       = gridCellFiring(in,pos,LFP);

%  ...chunk the data by oscillatory cycle
disp('Chunking the data by oscillatory cycle...'); drawnow
cycle           = chunkData(in,pos,LFP,gridCells);

%  ...decode location / movement direction / running speed / anxiety
disp('Decoding movement trajectory from activity in each cycle...'); drawnow
cycleDecoding   = decodeCycles(in,pos,gridCells,cycle);

%  ...remove superfluous variables to save memory (r_pure only exists when
%  in-field rate variability was simulated)
gridCells       = rmfield(gridCells,'r_rate');
gridCells       = rmfield(gridCells,'r_phase');
if isfield(gridCells,'r_pure')
    gridCells   = rmfield(gridCells,'r_pure');
end

%  ...and plot some summary decoding figures, then clear the progress
%  messages from the command window
plotResults(cycle,cycleDecoding);
clc

end



%% Function to fill in default values for any settings that are absent
%  from (or empty in) the input structure. Fields are added in a fixed
%  order because some defaults depend on settings resolved earlier
%  (sigmas on gridScales, trackFile on environment).
function[in]	= dealWithInputs(in)

%  A setting is 'unset' if the field does not exist or holds an empty value
unset   = @(s,name) ~isfield(s,name) || isempty(s.(name));

if unset(in,'gridScales'),      in.gridScales   = 30*1.4.^(0:4);        end     % Grid scales (cm), five modules
if unset(in,'nGCs'),            in.nGCs         = 40;                   end     % Number of grid cells per module
if unset(in,'sampleRate'),      in.sampleRate   = 200;                  end     % Simulation time step (Hz)
if unset(in,'sigmas'),          in.sigmas       = in.gridScales./10;    end     % Firing field width parameter (cm)
if unset(in,'meanRate'),        in.meanRate     = 1;                    end     % Mean firing rate (Hz)
if unset(in,'phaseLock'),       in.phaseLock    = false;                end     % Phase locking (true) or precession (false)?
if unset(in,'phaseMod'),        in.phaseMod     = 1.5;                  end     % Extent of phase coding 'k' (au)
if unset(in,'speedSlope'),      in.speedSlope   = 5/30;                 end     % Slope of running speed v firing rate (Hz/cm/s)
if unset(in,'environment'),     in.environment  = '2D';                 end     % Environment type (1D or 2D)

%  A tracking file is only chosen (at random, from 1 to 3) for the '2D'
%  environment, so the random number generator is untouched otherwise
if strcmp(in.environment,'2D') && unset(in,'trackFile')
    in.trackFile    = ceil(rand*3);                                             % 2D tracking data file
end

if unset(in,'vRange'),          in.vRange       = [2 30];               end     % Moving speed range for 1D environments
if unset(in,'lfpType'),         in.lfpType      = 'human';              end     % LFP type ('constant' or 'human')
if unset(in,'freqRange'),       in.freqRange    = [2 20];               end     % LFP frequency range for human LFP (Hz)
if unset(in,'nPhaseBins'),      in.nPhaseBins   = 5;                    end     % Number of phase bins for decoding analysis
if unset(in,'fieldRateVar'),    in.fieldRateVar = 0;                    end     % Standard deviation of in-field firing variability
end



%% Function to generate or load tracking data
%  INPUT:  in  = settings structure; reads in.environment, in.sampleRate,
%                in.vRange (1D environments) and in.trackFile ('2D')
%  OUTPUT: pos = structure with column vectors x_log and y_log (position,
%                cm), v_log (speed, cm/s) and t_log (time, s), plus the
%                simulation time step dt = 1 / in.sampleRate (s)
%  Raises 'humanGridSim:unknownEnvironment' for an unrecognised environment
function[pos]   = trackingData(in)
switch in.environment
    
    % ...either randomly generate 1D tracking data (30s, or 300s for the
    % '1Dlong' environment)
    case {'1D','1Dlong'}
        if strcmp(in.environment,'1D')
            duration    = 30;
        else
            duration    = 300;
        end
        t_log   = 1/in.sampleRate : 1/in.sampleRate : duration;    % Create a time base (s)
        v_log   = cumsum(randn(length(t_log),1));                   % Random walk as time varying velocity (cm/s)
        v_log   = ((v_log-min(v_log))./range(v_log).*diff(in.vRange)+in.vRange(1))';   % Rescale to span in.vRange
        x_log   = cumsum(v_log ./ in.sampleRate);                   % Integrate velocity to get x co-ordinates
        y_log   = ones(size(x_log));                                % Assign constant y co-ordinates
        
    % ...or load and manipulate 2D tracking data. The file is read into a
    % structure so that the requested path can be selected by name (no
    % eval) and nothing in the file can overwrite local variables
    case '2D'
        data    = load('trackingData.mat');
        Fs      = data.Fs;                                          % Sample rate of the stored tracking data
        xy      = data.(['Square' int2str(in.trackFile)])./data.PixPerM*100;   % Convert from pixels to cm
        clear data
        t_log   = linspace(0,size(xy,1)/Fs,size(xy,1)*in.sampleRate/Fs);       % Simulation time base
        t_dat   = linspace(0,size(xy,1)/Fs,size(xy,1));                        % Time base of the stored data
        x_log   = xy(:,1);
        y_log   = xy(:,2); clear xy
        v_log   = nan(size(x_log));
        v_log(1:end-1)  = sqrt(diff(x_log).^2 + diff(y_log).^2) * Fs; clear Fs % Speed from displacement between samples
        v_log(end)      = v_log(end-1);                                        % Repeat the last value to match lengths
        x_log   = interp1(t_dat,x_log,t_log);                                  % Resample to the simulation time base
        y_log   = interp1(t_dat,y_log,t_log);
        v_log   = interp1(t_dat,v_log,t_log); clear t_dat
        
    % ...or generate a constant speed sweep back and forth along x, stepping
    % up by 1cm in y each time an edge of the box is crossed
    case '2Dlin'        
        boxSize     = [100 100];    % Box size, cm
        runSpeed    = 30;           % Constant running speed, cm/s
        nSamples    = ceil(prod(boxSize)/runSpeed*in.sampleRate);
        x_log       = nan(1,nSamples);
        y_log       = nan(1,nSamples);
        v_log       = runSpeed*ones(1,nSamples);
        t_log       = linspace(1/in.sampleRate,ceil(prod(boxSize)/runSpeed),nSamples);
        x_log(1)    = 1;
        y_log(1)    = 1;
        dir         = 1;
        for t       = 2 : nSamples
            x_log(1,t)  = x_log(1,t-1) + dir.*runSpeed/in.sampleRate;
            y_log(1,t)  = y_log(1,t-1);
            if x_log(1,t) > boxSize(1) || x_log(1,t) < 0
                dir         = dir * -1;
                y_log(1,t)  = y_log(1,t) + 1;
            end            
        end
        clear t dir nSamples
        
    % ...and fail clearly, rather than with an undefined variable error
    % below, if the environment is not recognised
    otherwise
        error('humanGridSim:unknownEnvironment',...
            'Unknown environment ''%s'': expected ''1D'', ''1Dlong'', ''2D'' or ''2Dlin''.',in.environment);
        
end

%  Every branch above produces row vectors; store them as columns
pos.x_log       = x_log'; clear x_log
pos.y_log       = y_log'; clear y_log
pos.v_log       = v_log'; clear v_log
pos.t_log       = t_log'; clear t_log
pos.dt          = 1 / in.sampleRate;

end



%% Function to generate or load LFP data
%  INPUTS:  in  = settings structure; reads in.lfpType and, for 'human'
%                 LFP, in.freqRange (Hz)
%           pos = tracking data; reads pos.t_log (sets the number of
%                 samples required) and pos.dt (time step, s)
%  OUTPUT:  LFP = structure of row vectors with one entry per sample in
%                 pos.t_log: eeg (signal), phase (rad, cumulative), freq
%                 (Hz) and anxiety (frequency rescaled to the range 0-1
%                 for 'human' LFP, all ones for 'constant' LFP)
%  Raises 'humanGridSim:eegTooShort' if the recorded EEG is shorter than
%  the trajectory, 'humanGridSim:unknownLfpType' for an unknown lfpType
function[LFP]   = lfpData(in,pos)
nSamples        = length(pos.t_log);
switch in.lfpType
    case 'human'
        % Load into a structure so nothing in the file can overwrite local
        % variables. The file supplies 'eeg' (used below with the signal in
        % column 1 and its time stamps in column 2) and the sample rate 'Fs'
        data            = load('sampleEEG.mat');
        eeg             = data.eeg;
        Fs              = data.Fs; clear data
        [b,a]           = butter(2,in.freqRange/(Fs/2));                % Generate second order Butterworth filter
        freq            = filtfilt(b,a,eeg(:,1)); clear b a             % Filter EEG data in frequency range of interest
        freq            = angle(hilbert(freq));                         % Get the phase at each time point
        freq            = angle(exp(1i.*diff(freq)));                   % Get the phase difference and wrap around
        freq(freq<0)    = (in.freqRange(1)*2*pi)/Fs;                    % Replace negative frequencies with the lower band edge
        freq            = fastsmooth(freq,Fs/20,3,1);                   % Smooth the data
        tSim            = pos.dt : pos.dt : eeg(end,2);                 % Simulation time base spanning the recording
        freq            = interp1(eeg(2:end,2), freq, tSim);            % Resample (diff above dropped the first sample)
        eeg             = interp1(eeg(:,2), eeg(:,1), tSim); clear tSim
        freq            = (freq./(2*pi)).*Fs; clear Fs                  % Convert from rad/sample to dynamic frequency (Hz)
        
        % Choose a random segment of the required length. The start index
        % is floored at one so that a recording of exactly the required
        % length can still be used (ceil(rand*0) would give index zero)
        nSpare          = length(freq) - nSamples;
        if nSpare < 0
            error('humanGridSim:eegTooShort',...
                'The sample EEG provides %d samples but the trajectory requires %d.',length(freq),nSamples);
        end
        startInd        = max(1,ceil(rand*nSpare)); clear nSpare
        freq            = freq(startInd:startInd+nSamples-1);           % Crop the frequency data
        eeg             = eeg(startInd:startInd+nSamples-1);            % Crop the eeg data
        phase           = cumsum([0 2.*pi.*pos.dt.*freq(2:end)]);       % Integrate frequency to get LFP phase at each timepoint
        anxiety         = (freq-min(freq))./range(freq); clear startInd % Ascribe dynamic 'anxiety' variable
    
    case 'constant'
        anxiety         = ones(size(pos.t_log))';                       % Fixed 'anxiety' value of one
        freq            = 8 .* anxiety;                                 % Fixed 8Hz oscillation
        phase           = cumsum([0 2.*pi.*pos.dt.*freq(2:end)]);       % Integrate frequency to get LFP phase
        eeg             = cos(phase);
        
    otherwise
        error('humanGridSim:unknownLfpType',...
            'Unknown lfpType ''%s'': expected ''human'' or ''constant''.',in.lfpType);
        
end
LFP.eeg         = eeg;
LFP.phase       = phase;
LFP.freq        = freq;
LFP.anxiety     = anxiety;
end



%% Function to generate grid cell firing patterns
%  For each of length(in.gridScales) modules and in.nGCs cells per module,
%  computes a spatial tuning term (r_rate) and an LFP phase tuning term
%  (r_phase) at every tracking sample, multiplies them together with LFP
%  frequency and running speed, and draws Poisson spike counts.
%  INPUTS:  in  = settings structure (gridScales, nGCs, sigmas, phaseLock,
%                 phaseMod, speedSlope, meanRate, fieldRateVar, environment)
%           pos = tracking data (x_log, y_log, v_log, t_log)
%           LFP = LFP data (phase and freq are used here)
%  OUTPUT:  gridCells with fields (cells x samples): r_rate, r_phase,
%           spikeTrains, and r_pure (spatial tuning before the per-field
%           gain is applied) if in.fieldRateVar > 0
%  NOTE: the loop variable 'cell' shadows the built-in function of the same
%  name inside this function
function[gridCells]     = gridCellFiring(in,pos,LFP)
nCells                  = length(in.gridScales)*in.nGCs;
switch in.environment
    case {'1D','1Dlong'}
        r_rate          = nan(nCells,length(pos.x_log));   % Assign memory for the rate code of each grid cell
        r_phase         = nan(nCells,length(pos.x_log));   % Assign memory for the phase code of each grid cell
        r_cell          = nan(nCells,length(pos.x_log));   % Assign memory for the combined firing output of each cell
        for module      = 1 : length(in.gridScales)
            for cell    = 1 : in.nGCs
                ind             = (module-1)*in.nGCs + cell;
                % Field centres are spaced one grid scale apart, shifted by
                % (cell-1)/nGCs of a grid scale for each cell, and extend
                % one grid scale either side of the track
                x_centres       = ((cell-1)/in.nGCs)*in.gridScales(module)-in.gridScales(module) : in.gridScales(module) : max(pos.x_log) + in.gridScales(module);
                % Distance from each position sample to the nearest centre
                offset          = min(abs(repmat(pos.x_log,1,length(x_centres)) - repmat(x_centres,length(pos.x_log),1)),[],2);
                % Gaussian shaped spatial tuning; note that the denominator
                % is (2*sigma)^2
                r_rate(ind,:)   = exp(-(offset.^2) ./ (2.*in.sigmas(module)).^2);
                if in.fieldRateVar > 0
                    if ind      == 1
                       r_pure   = nan(nCells,length(pos.x_log));
                    end
                    % Peaks in the distance to the nearest centre mark the
                    % boundaries between adjacent fields; the cumulative
                    % sum turns them into a field index for each sample
                    r_var       = zeros(size(r_rate(ind,:)));
                    [~,pks]     = findpeaks(offset);
                    r_var(pks)  = 1;
                    r_var       = 1 + cumsum(r_var);
                    % Draw one gain per field from a normal distribution
                    % with mean one, and clip negative gains to zero
                    fieldRates  = 1+randn(length(pks)+1,1) .* in.fieldRateVar;
                    fieldRates(fieldRates<0)    = 0; clear pks
                    r_var       = fieldRates(r_var)'; clear fieldRates
                    r_pure(ind,:)   = r_rate(ind,:);
                    r_rate(ind,:)   = r_rate(ind,:) .* r_var; clear r_var
                end
                clear offset
                % Preferred firing phase: fixed at pi for phase locking,
                % otherwise a linear function of position that wraps once
                % per grid scale
                if in.phaseLock
                    offset      = pi*ones(size(pos.x_log));
                else
                    offset     	= pos.x_log - x_centres(1) - in.gridScales(module)/2; clear x_centres
                    offset      = mod((-1/in.gridScales(module)).*offset,1)*2*pi;                    
                end
                % Phase tuning around the preferred phase, with
                % concentration in.phaseMod, scaled to a maximum of one
                r_phase(ind,:) 	= exp(in.phaseMod*cos(offset-LFP.phase')); clear offset
                r_phase(ind,:)  = r_phase(ind,:) ./ max(r_phase(ind,:));
                % Combine the rate and phase terms, modulate by LFP
                % frequency and running speed, then rescale so that the
                % values sum to in.meanRate * range(pos.t_log)
                r_cell(ind,:)   = r_rate(ind,:) .* r_phase(ind,:);
                r_cell(ind,:)   = r_cell(ind,:) .* LFP.freq .* in.speedSlope .* pos.v_log';
                r_cell(ind,:)   = r_cell(ind,:) ./ sum(r_cell(ind,:)) .* in.meanRate .* range(pos.t_log); clear ind
            end
        end
        % Poisson spike count for each cell at each sample
        firingRates     = poissrnd(r_cell);
        clear module cell r_cell sigmas
        
    case {'2D','2Dlin'}
        r_rate          = nan(nCells,length(pos.x_log));   % Assign memory for the rate code of each grid cell
        r_phase         = nan(nCells,length(pos.x_log));   % Assign memory for the phase code of each grid cell
        r_cell          = nan(nCells,length(pos.x_log));   % Assign memory for the combined firing output of each cell
        for module      = 1 : length(in.gridScales)
            for cell    = 1 : in.nGCs
                ind             = (module-1)*in.nGCs + cell;
                % Number of fields needed to span the visited area, plus
                % four extra in each dimension
                xFields         = ceil(range(pos.x_log)/in.gridScales(module)) + 4;
                yFields         = ceil(range(pos.y_log)/(in.gridScales(module)*sind(60))) + 4;
                [x_centres,y_centres] = gridFields2D(cell,in.nGCs,in.gridScales(module),xFields,yFields); clear xFields yFields
                
                % Distance from each position sample to the nearest field
                % centre (offset) and the index of that centre (c)
                centres         = [repmat(reshape(x_centres,[1 1 length(x_centres)]),length(pos.x_log),1) repmat(reshape(y_centres,[1 1 length(y_centres)]),length(pos.y_log),1)];
                [offset,c]      = min(sqrt(sum((repmat([pos.x_log pos.y_log],[1 1 length(x_centres)])-centres).^2,2)),[],3); clear centres
                % Gaussian shaped spatial tuning; note that the denominator
                % is (2*sigma)^2
                r_rate(ind,:)   = exp(-(offset.^2) ./ (2.*in.sigmas(module)).^2); clear offset
                % Preferred firing phase: fixed at pi for phase locking,
                % otherwise based on the signed distance to the nearest
                % field centre along the current direction of movement
                if in.phaseLock
                    offset          = pi*ones(size(pos.x_log));
                else
                    % distToFieldCentre returns one value fewer than the
                    % number of samples, so the final entry stays NaN
                    offset          = nan(size(pos.x_log));
                    offset(1:end-1)	= distToFieldCentre([pos.x_log pos.y_log],[x_centres(c) y_centres(c)]) - in.gridScales(module)/2;
                    offset          = mod((1/in.gridScales(module)).*offset,1)*2*pi;
                end
                r_phase(ind,:) 	= exp(in.phaseMod*cos(offset-LFP.phase')); clear offset
                r_phase(ind,:)  = r_phase(ind,:) ./ max(r_phase(ind,:));
                % NOTE(review): this test is true for any non-zero value,
                % whereas the 1D branch and the output below test for > 0
                if in.fieldRateVar
                    if ind      == 1
                        r_pure	= nan(nCells,length(pos.x_log));
                    end                    
                    % Draw one gain per field from a normal distribution
                    % with mean one, clip negative gains to zero, and look
                    % up the gain of the nearest field for each sample
                    fieldRates  = 1+randn(length(x_centres),1) .* in.fieldRateVar;
                    fieldRates(fieldRates<0)    = 0;
                    c           = fieldRates(c)';
                    r_pure(ind,:)   = r_rate(ind,:);
                    r_rate(ind,:)   = r_rate(ind,:) .* c; 
                end
                clear c                
                % Combine the rate and phase terms, modulate by LFP
                % frequency and running speed, then rescale so that the
                % non-NaN values sum to in.meanRate * range(pos.t_log)
                r_cell(ind,:)   = r_rate(ind,:) .* r_phase(ind,:);
                r_cell(ind,:)   = r_cell(ind,:) .* LFP.freq .* in.speedSlope .* pos.v_log';
                r_cell(ind,:)   = r_cell(ind,:) ./ nansum(r_cell(ind,:)) .* in.meanRate .* range(pos.t_log); clear ind
            end
        end
        % Poisson spike count for each cell at each sample
        firingRates     = poissrnd(r_cell);
        clear module cell r_cell sigmas
        
end
gridCells.r_rate        = r_rate; clear r_rate
gridCells.r_phase       = r_phase; clear r_phase
gridCells.spikeTrains   = firingRates; clear firingRates
if in.fieldRateVar > 0
    gridCells.r_pure    = r_pure; clear r_pure
end
% Any NaN spike counts are replaced with zero spikes
gridCells.spikeTrains(isnan(gridCells.spikeTrains)) = 0;
end



%% Function to compute field centres for arbitrary 2D grid
%  Builds xFields x yFields firing field centres on a triangular lattice
%  (every other row shifted by half a field spacing, rows sind(60) apart),
%  scales the lattice by 'scale', and then translates it by two grid scales
%  towards the origin plus an offset determined by the cell number.
%  Outputs are column vectors of length xFields*yFields, ordered row by row.
function[x_centres,y_centres] = gridFields2D(cell,nGCs,scale,xFields,yFields)

%  Column and row index of every field on a unit lattice, with the column
%  index varying fastest
[colIdx,rowIdx] = ndgrid(0 : xFields-1, 0 : yFields-1);
colIdx          = colIdx(:);
rowIdx          = rowIdx(:);

%  Odd rows are shifted along x by half a field spacing
oddRow              = mod(rowIdx,2) ~= 0;
x_centres           = colIdx;
x_centres(oddRow)   = colIdx(oddRow) + cosd(60);
y_centres           = rowIdx * sind(60);
clear colIdx rowIdx oddRow

%  Convert from the unit lattice to the requested grid scale
x_centres       = x_centres * scale;
y_centres       = y_centres * scale;

%  Cell dependent offset, as a fraction of the grid scale: cells are laid
%  out along the two lattice axes with sqrt(nGCs) steps per grid scale
perSide         = sqrt(nGCs);
stepA           = mod(cell-1,perSide);
stepB           = floor((cell-1)./perSide);
x_shift         = (stepA + stepB * cosd(60))/ perSide;
y_shift         = stepB * sind(60) / perSide;

%  Shift to the origin of choice
x_centres       = x_centres - 2*scale + x_shift*scale;
y_centres       = y_centres - 2*scale + y_shift*scale;
end



%% Function to compute linear distance to perpendicular line through field
%  centre along current trajectory
%  INPUTS:  loc  = n x 2 matrix of [x y] positions
%           grid = n x 2 matrix of [x y] field centres, one per position
%  OUTPUT:  dist = (n-1) x 1 vector. For each step from sample t to t+1,
%                  the distance from the position at t to the point where
%                  the line of travel crosses the perpendicular line
%                  through the field centre. The value is negated when the
%                  position at t+1 is further from that point than the
%                  position at t (i.e. when moving away from it)
function[dist]  = distToFieldCentre(loc,grid)

%  Only the first n-1 samples have a following sample to define a heading
start       = loc(1:end-1,:);
centre      = grid(1:end-1,:);
dx          = diff(loc(:,1));
dy          = diff(loc(:,2));

%  Line of travel (y = heading * x + intercept) and the perpendicular line
%  through the field centre (y = perp * x + perpInt)
heading     = dy ./ dx;
intercept   = start(:,2) - heading .* start(:,1);
perp        = -1./heading;
perpInt     = centre(:,2) - perp .* centre(:,1);

%  Point at which the two lines cross
crossX      = (intercept - perpInt) ./ (perp - heading); clear perp perpInt
crossY      = heading .* crossX + intercept; clear heading intercept

%  Exactly vertical movement has an infinite slope, so the expressions
%  above return NaN. The line of travel is then x = own x and the
%  perpendicular through the field centre is y = centre y
vertical            = dx == 0 & dy ~= 0;
crossX(vertical)    = start(vertical,1);
crossY(vertical)    = centre(vertical,2); clear vertical dx dy

%  Remaining NaN values arise from exactly horizontal movement (infinite
%  perpendicular slope), where the lines cross at x = centre x, y = own y.
%  Steps with no movement at all also end up here
nanX        = isnan(crossX);
crossX(nanX)    = centre(nanX,1); clear nanX
nanY        = isnan(crossY);
crossY(nanY)    = start(nanY,2); clear nanY

%  Distance from the current and the next position to the crossing point;
%  the sign records whether the crossing point is being approached
dist        = sqrt((start(:,1) - crossX).^2 + (start(:,2) - crossY).^2); 
dist2       = sqrt((loc(2:end,1) - crossX).^2   + (loc(2:end,2) - crossY).^2); clear crossX crossY
dist(dist2>dist)    = -dist(dist2>dist);

end



%% Function to split the data into expected and actual firing rate per 
%  oscillatory cycle, and record mean position in each cycle
%  INPUTS:  in        = settings structure (gridScales, nGCs, nPhaseBins,
%                       fieldRateVar)
%           pos       = tracking data (x_log, y_log, v_log, dt)
%           LFP       = LFP data (phase, anxiety)
%           gridCells = simulated cells (spikeTrains, r_rate, r_phase and,
%                       if in.fieldRateVar > 0, r_pure)
%  OUTPUT:  cycle     = structure of per-cycle spike counts, expected
%                       rates, mean location / speed / anxiety, cycle
%                       length (s) and movement direction (rad), both for
%                       whole cycles and for nPhaseBins sub-cycle bins
%  Cycles run from one peak of cos(LFP.phase) to the next. The sample
%  ranges are inclusive at both ends, so each peak sample belongs to two
%  adjacent cycles.
%  NOTE: the local variable 'norm' shadows the built-in function of the
%  same name inside this function
function[cycle] = chunkData(in,pos,LFP,gridCells)
[~,peaks]                   = findpeaks(cos(LFP.phase));
nCells                      = length(in.gridScales)*in.nGCs;
nCycles                     = length(peaks)-1;
cycle.nSpikes               = nan(nCells,nCycles);
cycle.expRate               = nan(nCells,length(peaks)-1);
% NOTE: actLoc is allocated with one column but gains a second (y) column
% when it is first assigned in the loop below
cycle.actLoc                = nan(nCycles,1);
cycle.meanSpeed             = nan(nCycles,1);
cycle.meanAnxiety           = nan(nCycles,1);
cycle.cycleLength           = nan(nCycles,1);
cycle.nSpikes_binned        = nan(nCells,in.nPhaseBins,nCycles);
cycle.expRate_binned        = nan(nCells,in.nPhaseBins,nCycles);
cycle.expPhase_binned       = nan(nCells,in.nPhaseBins,nCycles);
cycle.expComb_binned        = nan(nCells,in.nPhaseBins,nCycles);
cycle.actLoc_binned         = nan(in.nPhaseBins,2,nCycles);
cycle.realVec               = nan(nCycles,1);
% NOTE: expPurC_binned (used below) is not allocated here and grows inside
% the loop
if in.fieldRateVar>0
    cycle.expPure           = nan(nCells,length(peaks)-1);
    cycle.expPure_binned	= nan(nCells,in.nPhaseBins,nCycles);
end
for c                       = 1 : nCycles
        
    % Whole cycle summaries: duration, spike count per cell, mean spatial
    % tuning per cell, and mean position / speed / anxiety
    cycle.cycleLength(c,1)  = (peaks(c+1)-peaks(c)).*pos.dt;
    cycle.nSpikes(:,c)      = sum(gridCells.spikeTrains(:,peaks(c):peaks(c+1)),2);
    cycle.expRate(:,c)      = mean(gridCells.r_rate(:,peaks(c):peaks(c+1)),2);
    cycle.actLoc(c,1)       = mean(pos.x_log(peaks(c):peaks(c+1)));
    cycle.actLoc(c,2)       = mean(pos.y_log(peaks(c):peaks(c+1)));
    cycle.meanSpeed(c,1)    = mean(pos.v_log(peaks(c):peaks(c+1)));       
    cycle.meanAnxiety(c,1)  = mean(LFP.anxiety(peaks(c):peaks(c+1)));           
    
    % Assign each sample in the cycle to a bin. The bin edges divide the
    % cumulative population spike count into nPhaseBins equal parts, so
    % bins hold roughly equal numbers of spikes rather than equal phase
    % intervals. histc puts values equal to the last edge in an extra bin,
    % which is merged into the final bin
    runningSpikes           = cumsum(sum(gridCells.spikeTrains(:,peaks(c):peaks(c+1)),1));
    [~,phaseBin]            = histc(runningSpikes,linspace(0,runningSpikes(end),in.nPhaseBins+1)); clear runningSpikes    
    phaseBin                = phaseBin';
    phaseBin(phaseBin>in.nPhaseBins) = in.nPhaseBins;
    % x holds the bin and y the cell number for every (sample, cell) pair,
    % for use as accumarray subscripts on the samples x cells matrices
    spikes                  = gridCells.spikeTrains(:,peaks(c):peaks(c+1))';
    [x,y]                   = ndgrid(phaseBin,1:size(spikes,2));    
    % Spike counts are summed within each bin...
    cycle.nSpikes_binned(:,:,c)     = accumarray([x(:) y(:)],spikes(:))'; clear spikes    
    % ...whereas the rate, phase and combined tuning terms are averaged
    % within each bin (bins containing no samples are filled with zero)
    rate                    = gridCells.r_rate(:,peaks(c):peaks(c+1))';
    cycle.expRate_binned(:,:,c)     = accumarray([x(:) y(:)],rate(:),[in.nPhaseBins in.nGCs*length(in.gridScales)],@mean)'; clear rate
    phase                   = gridCells.r_phase(:,peaks(c):peaks(c+1))';
    cycle.expPhase_binned(:,:,c)    = accumarray([x(:) y(:)],phase(:),[in.nPhaseBins in.nGCs*length(in.gridScales)],@mean)'; clear phase
    comb                    = (gridCells.r_rate(:,peaks(c):peaks(c+1)) .* gridCells.r_phase(:,peaks(c):peaks(c+1)))';
    cycle.expComb_binned(:,:,c)     = accumarray([x(:) y(:)],comb(:),[in.nPhaseBins in.nGCs*length(in.gridScales)],@mean)'; clear comb
    % Mean x and y position within each bin
    cycle.actLoc_binned(:,1,c)      = accumarray(phaseBin,pos.x_log(peaks(c):peaks(c+1)),[in.nPhaseBins 1],@mean);
    cycle.actLoc_binned(:,2,c)      = accumarray(phaseBin,pos.y_log(peaks(c):peaks(c+1)),[in.nPhaseBins 1],@mean); 
    % The same quantities based on the spatial tuning without the per-field
    % gain (r_pure), alone and combined with the phase term
    if in.fieldRateVar>0
        cycle.expPure(:,c) 	= mean(gridCells.r_pure(:,peaks(c):peaks(c+1)),2);
        pure                = gridCells.r_pure(:,peaks(c):peaks(c+1))';
        cycle.expPure_binned(:,:,c)	= accumarray([x(:) y(:)],pure(:),[in.nPhaseBins in.nGCs*length(in.gridScales)],@mean)'; clear pure
        
        comb                = (gridCells.r_pure(:,peaks(c):peaks(c+1)) .* gridCells.r_phase(:,peaks(c):peaks(c+1)))';
        cycle.expPurC_binned(:,:,c)	= accumarray([x(:) y(:)],comb(:),[in.nPhaseBins in.nGCs*length(in.gridScales)],@mean)'; clear comb
    end    
    clear x y phaseBin
    
    % Actual movement direction in this cycle: regress the binned x and y
    % positions on bin number and take the angle of the two slopes
    realTraj        = cycle.actLoc_binned(:,:,c);
    bx              = regress(realTraj(:,1),[1:in.nPhaseBins ; ones(1,in.nPhaseBins)]');
    by              = regress(realTraj(:,2),[1:in.nPhaseBins ; ones(1,in.nPhaseBins)]');
    cycle.realVec(c,1)              = atan2(by(1),bx(1)); clear bx by realTraj
end
clear c

%  Normalise each binned rate function by the mean firing rate of that cell
%  Each expected rate function is rescaled, cell by cell, so that its total
%  matches that cell's total spike count; scale factors that come out as
%  NaN are set to zero
norm             	= sum(cycle.nSpikes,2) ./ sum(cycle.expRate,2); norm(isnan(norm)) = 0;
cycle.expRate                       = cycle.expRate .* repmat(norm,[1 size(cycle.expRate,2)]); clear norm
norm                = sum(cycle.nSpikes,2) ./ sum(sum(cycle.expRate_binned,2),3); norm(isnan(norm)) = 0;
cycle.expRate_binned                = cycle.expRate_binned .* repmat(norm,[1 size(cycle.expRate_binned,2) size(cycle.expRate_binned,3)]); clear norm
norm                = sum(cycle.nSpikes,2) ./ sum(sum(cycle.expPhase_binned,2),3); norm(isnan(norm)) = 0;
cycle.expPhase_binned               = cycle.expPhase_binned .* repmat(norm,[1 size(cycle.expPhase_binned,2) size(cycle.expPhase_binned,3)]); clear norm
norm                = sum(cycle.nSpikes,2) ./ sum(sum(cycle.expComb_binned,2),3); norm(isnan(norm)) = 0;
cycle.expComb_binned                = cycle.expComb_binned .* repmat(norm,[1 size(cycle.expComb_binned,2) size(cycle.expComb_binned,3)]); clear norm
if in.fieldRateVar>0
    norm            = sum(cycle.nSpikes,2) ./ sum(sum(cycle.expPure_binned,2),3); norm(isnan(norm)) = 0;
    cycle.expPure_binned            = cycle.expPure_binned .* repmat(norm,[1 size(cycle.expPure_binned,2) size(cycle.expPure_binned,3)]); clear norm   
    norm            = sum(cycle.nSpikes,2) ./ sum(sum(cycle.expPurC_binned,2),3); norm(isnan(norm)) = 0;
    cycle.expPurC_binned            = cycle.expPurC_binned .* repmat(norm,[1 size(cycle.expPurC_binned,2) size(cycle.expPurC_binned,3)]); clear norm
end

end



%% Function to decode movement trajectory from grid cell activity within
%  each oscillatory cycle
%  INPUTS:  in        = settings structure (nGCs, gridScales, environment,
%                       fieldRateVar)
%           pos       = tracking data (x_log, y_log)
%           gridCells = simulated cells (r_rate and, if in.fieldRateVar > 0,
%                       r_pure)
%           cycle     = per-cycle data produced by chunkData
%  OUTPUT:  decoded   = structure of decoded locations and errors (cm),
%                       decoded movement direction and its error (rad),
%                       and predicted running speed and 'anxiety' with
%                       their errors
%  NOTE: the local variable 'norm' shadows the built-in function of the
%  same name inside this function
function[decoded]   = decodeCycles(in,pos,gridCells,cycle)

%  Extract / specify some parameters for the analysis
nCells              = in.nGCs * length(in.gridScales);
nCycles             = size(cycle.nSpikes,2);
nPhaseBins          = size(cycle.nSpikes_binned,2);
binSize             = 2;        % Width of each location bin (cm)

%  Compute the mean rate function for each location bin
switch in.environment
    case {'1D','1Dlong'}
        posBins   	= 0:binSize:ceil(max(pos.x_log)/binSize)*binSize;
        [~,posBin]  = histc(pos.x_log,posBins);
        % histc places values equal to the final edge in an extra bin,
        % which happens when the maximum position is an exact multiple of
        % the bin size; merge those samples into the last real bin so that
        % the subscripts stay within the requested output size
        posBin(posBin>length(posBins)-1)   = length(posBins)-1;
        [x,y]       = ndgrid(posBin,1:nCells); clear posBin
        
        % Mean spatial tuning of each cell in each location bin, rescaled
        % so that its total matches the total spike count of that cell
        rate        = gridCells.r_rate';
        avgRate     = accumarray([x(:) y(:)],rate(:),[length(posBins)-1 nCells],@mean)'; clear rate 
        norm        = sum(cycle.nSpikes,2) ./ sum(avgRate,2);
        avgRate     = avgRate .* repmat(norm,1,size(avgRate,2)); clear norm        
        
        if in.fieldRateVar > 0            
            pure 	= gridCells.r_pure';            
            avgPure = accumarray([x(:) y(:)],pure(:),[length(posBins)-1 nCells],@mean)'; clear pure
            norm  	= sum(cycle.nSpikes,2) ./ sum(avgPure,2);
            avgPure = avgPure .* repmat(norm,1,size(avgPure,2)); clear norm
        end                
        clear x y 
        
    % Both 2D environments handled by the tracking data and grid cell
    % firing functions are binned in the same way
    case {'2D','2Dlin'}
        maxPos      = max([pos.x_log pos.y_log],[],1);
        minPos      = min([pos.x_log pos.y_log],[],1);
        posBins     = floor(min(minPos)/binSize)*binSize : binSize : ceil(max(maxPos)/binSize)*binSize;
        % Bin index of each sample in x and y, counted from the minimum
        % position; a distance of exactly zero is nudged up so that it
        % falls in bin one rather than bin zero
        posBinInd   = [pos.x_log-minPos(1) pos.y_log-minPos(2)] ./ binSize; clear minPos maxPos       
        posBinInd(posBinInd==0) = eps;
        posBinInd   = ceil(posBinInd);
        avgRate     = nan(range(posBinInd(:,1))+1,range(posBinInd(:,2))+1,nCells);        
        if in.fieldRateVar > 0
            avgPure	= nan(range(posBinInd(:,1))+1,range(posBinInd(:,2))+1,nCells);
        end
        % Mean spatial tuning of each cell in each location bin (bins that
        % were never visited are filled with zero)
        for c       = 1 : nCells
            rate    = gridCells.r_rate(c,:)';
            avgRate(:,:,c)      = accumarray(posBinInd,rate,[],@mean); clear rate
            if in.fieldRateVar > 0
                pure 	= gridCells.r_pure(c,:)';
                avgPure(:,:,c)	= accumarray(posBinInd,pure,[],@mean); clear pure
            end
        end
        clear c posBinInd        
        
    otherwise
        error('humanGridSim:unknownEnvironment',...
            'Unknown environment ''%s'': expected ''1D'', ''1Dlong'', ''2D'' or ''2Dlin''.',in.environment);
end

%  Decode location in each cycle / sub-cycle using firing rate across location bins
decodedLoc          = nan(nCycles,2);
binnedDecLoc        = nan(nPhaseBins,2,nCycles);
for bin             = 1 : nCycles    
    decodedLoc(bin,:)       = decodeLocation(cycle.nSpikes(:,bin),avgRate);
    if in.fieldRateVar > 0
        if bin      == 1
            decodedPure  	= nan(nCycles,2);
            binnedDecPure	= nan(nPhaseBins,2,nCycles);
        end
        decodedPure(bin,:)	= decodeLocation(cycle.nSpikes(:,bin),avgPure);
    end
    for phase       = 1 : nPhaseBins        
        binnedDecLoc(phase,:,bin)       = decodeLocation(cycle.nSpikes_binned(:,phase,bin),avgRate);
        if in.fieldRateVar > 0
            binnedDecPure(phase,:,bin)	= decodeLocation(cycle.nSpikes_binned(:,phase,bin),avgPure);
        end
    end
    clear phase
end
%  Convert decoded bin indices to the centre of each location bin (cm) and
%  compare with the actual location
decoded.decLoc              = posBins(decodedLoc) + binSize/2; clear bin decodedLoc 
decoded.decLocErr           = sqrt(sum((decoded.decLoc - cycle.actLoc).^2,2)); 
decoded.binDecLoc           = posBins(binnedDecLoc) + binSize/2; clear binnedDecLoc
decoded.binDecLocErr        = squeeze(sqrt(sum((decoded.binDecLoc - cycle.actLoc_binned).^2,2)))';
decoded.binDecLocDynErr     = squeeze(decoded.binDecLoc - cycle.actLoc_binned);
if in.fieldRateVar > 0
    decoded.decPur          = posBins(decodedPure) + binSize/2; clear decodedPure
    decoded.decPurErr       = sqrt(sum((decoded.decPur - cycle.actLoc).^2,2));
    decoded.binDecPur       = posBins(binnedDecPure) + binSize/2; clear binnedDecPure
    decoded.binDecPurErr	= squeeze(sqrt(sum((decoded.binDecPur - cycle.actLoc_binned).^2,2)))';
    decoded.binDecPurDynErr	= squeeze(decoded.binDecPur - cycle.actLoc_binned);
end
clear posBins avgRate avgPure

%  Decode location in each cycle / sub-cycle using firing rate across
%  cycles, and trajectory in each cycle
cycLoc              = nan(nCycles,2);
rateLoc             = nan(nCycles,2);
phaseLoc            = nan(nCycles,2);
combLoc             = nan(nCycles,2);
decVec              = nan(nCycles,1);
if in.fieldRateVar > 0
    pureLoc         = nan(nCycles,2);
    purcLoc         = nan(nCycles,2);
end
for bin             = 1 : nCycles    
    % Here the 'locations' being compared are the cycles themselves: the
    % decoder returns the index of the cycle whose expected activity best
    % matches the spikes observed in this cycle. The binned templates are
    % reshaped to (cells * phase bins) x cycles
    cycLoc(bin,:)	= decodeLocation(cycle.nSpikes(:,bin),cycle.expRate);
    rate            = cycle.nSpikes_binned(:,:,bin);
    rateLoc(bin,:)	= decodeLocation(rate(:),reshape(cycle.expRate_binned,[size(cycle.expRate_binned,1)*size(cycle.expRate_binned,2) size(cycle.expRate_binned,3)]));
    phaseLoc(bin,:)	= decodeLocation(rate(:),reshape(cycle.expPhase_binned,[size(cycle.expPhase_binned,1)*size(cycle.expPhase_binned,2) size(cycle.expPhase_binned,3)]));
    combLoc(bin,:)	= decodeLocation(rate(:),reshape(cycle.expComb_binned,[size(cycle.expComb_binned,1)*size(cycle.expComb_binned,2) size(cycle.expComb_binned,3)]));
    if in.fieldRateVar > 0
        pureLoc(bin,:)	= decodeLocation(rate(:),reshape(cycle.expPure_binned,[size(cycle.expPure_binned,1)*size(cycle.expPure_binned,2) size(cycle.expPure_binned,3)]));
        purcLoc(bin,:)	= decodeLocation(rate(:),reshape(cycle.expPurC_binned,[size(cycle.expPurC_binned,1)*size(cycle.expPurC_binned,2) size(cycle.expPurC_binned,3)]));
    end
    % Decoded movement direction: regress the locations decoded in each
    % phase bin on bin number and take the angle of the two slopes
    decTraj         = decoded.binDecLoc(:,:,bin);
    bx              = regress(decTraj(:,1),[1:nPhaseBins ; ones(1,nPhaseBins)]');
    by              = regress(decTraj(:,2),[1:nPhaseBins ; ones(1,nPhaseBins)]');
    decVec(bin,1)	= atan2(by(1),bx(1)); clear bx by decTraj rate
    
end
%  The decoded location is the actual mean location of the best matching
%  cycle
decoded.cycLoc      = cycle.actLoc(cycLoc(:,1),:); clear bin cycLoc
decoded.cycLocErr   = sqrt(sum((decoded.cycLoc - cycle.actLoc).^2,2)); 
decoded.rateLoc     = cycle.actLoc(rateLoc(:,1),:); clear rateLoc
decoded.rateLocErr  = sqrt(sum((decoded.rateLoc - cycle.actLoc).^2,2)); 
decoded.phaseLoc    = cycle.actLoc(phaseLoc(:,1),:); clear phaseLoc 
decoded.phaseLocErr = sqrt(sum((decoded.phaseLoc - cycle.actLoc).^2,2)); 
decoded.combLoc   	= cycle.actLoc(combLoc(:,1),:); clear combLoc 
decoded.combLocErr  = sqrt(sum((decoded.combLoc - cycle.actLoc).^2,2)); 
if in.fieldRateVar > 0
    decoded.pureLoc   	= cycle.actLoc(pureLoc(:,1),:); clear pureLoc
    decoded.pureLocErr  = sqrt(sum((decoded.pureLoc - cycle.actLoc).^2,2));    
    decoded.combPureLoc     = cycle.actLoc(purcLoc(:,1),:); clear purcLoc
    decoded.combPureLocErr  = sqrt(sum((decoded.combPureLoc - cycle.actLoc).^2,2));    
end
decoded.decVec      = decVec; clear decVec
decoded.vecError    = angle(exp(1i*cycle.realVec)./exp(1i*decoded.decVec));    % Wrapped angular difference

%  Then do linear regression on data from half of the cycles, use that to
%  predict running speed in the other half (odd numbered cycles are used
%  for fitting, even numbered cycles for prediction)
spikesPerCycle      = sum(cycle.nSpikes)';
b                   = regress(cycle.meanSpeed(1:2:end),[spikesPerCycle(1:2:end) ones(round(nCycles/2),1)]);
decoded.predSpeed   = spikesPerCycle(2:2:end).*b(1) + b(2); clear b
decoded.speedError  = cycle.meanSpeed(2:2:end) - decoded.predSpeed; clear spikesPerCycle

%  Then do the same for 'anxiety', using the reciprocal of cycle length
%  as the predictor; the error is expressed as a proportion of the actual
%  value
b                   = regress(cycle.meanAnxiety(1:2:end),[1./cycle.cycleLength(1:2:end) ones(round(nCycles/2),1)]);
decoded.predAnx     = b(1)./cycle.cycleLength(2:2:end) + b(2); clear b
decoded.anxError    = (decoded.predAnx ./ cycle.meanAnxiety(2:2:end)) - 1;

end



%% Maximum likelihood estimation function for decoding location based on 
%  firing rates of spatial cells
%  INPUTS:  firingRates  = column vector of observed spike counts, one per
%                          cell
%           rateFunction = expected counts, either cells x locations, or
%                          x bins x y bins x cells
%  OUTPUT:  decLoc       = [index 1] of the most likely location for a
%                          matrix rateFunction, or its [row column]
%                          subscripts for a 3D rateFunction. Ties are
%                          broken at random. NaN values are returned if no
%                          maximum can be identified
%  The likelihood of each location is the product across cells of the
%  Poisson probability of the observed count given the expected count.
function[decLoc]    = decodeLocation(firingRates,rateFunction)

%  Count dependent and count independent parts of the Poisson probability
decayTerm       = exp(-rateFunction);
countFact       = factorial(firingRates);
mapSize         = [size(rateFunction,1),size(rateFunction,2)];

if size(rateFunction,3)~=1
    
    %  x bins x y bins x cells: move the cell dimension of the counts to
    %  the third dimension and copy them across every location bin
    counts      = repmat(permute(firingRates,[3 2 1]), [mapSize, 1]);
    countFact   = repmat(permute(countFact,[3 2 1]), [mapSize, 1]);
    likelihood  = prod(((rateFunction.^counts)./countFact) .* decayTerm,3);
    best        = find(likelihood==max(likelihood(:)));
    if isempty(best)
        decLoc  = [nan nan];
    else
        [row,col]   = ind2sub(mapSize,best(randi(length(best),1)));
        decLoc      = [row col];
    end
    
else
    
    %  cells x locations: copy the counts across every location
    counts      = repmat(firingRates,[1 mapSize(2)]);
    countFact   = repmat(countFact,[1 mapSize(2)]);
    likelihood  = prod(((rateFunction.^counts)./countFact) .* decayTerm);
    best        = find(likelihood==max(likelihood(:)));
    if isempty(best)
        decLoc  = nan;
    else
        decLoc  = [best(randi(length(best),1)) 1];
    end
    
end
end



%% Function to plot movement trajectory decoding results
%  Opens a new figure with six panels: location decoding error for three
%  decoders (top row), then running speed decoding error, movement
%  direction decoding error and 'anxiety' decoding error (bottom row).
%  INPUTS:  cycle         = per-cycle data; only meanSpeed is used, to
%                           exclude slow cycles from the location and
%                           direction panels
%           cycleDecoding = decoding results (decLocErr, rateLocErr,
%                           combLocErr, speedError, vecError, anxError)
function[] = plotResults(cycle,cycleDecoding)

nBins   = 30;       % Number of histogram bins
vThresh = 5;        % Speed threshold for location and direction decoding
moving  = cycle.meanSpeed>=vThresh;     % Cycles at or above the speed threshold

figure

%  Location decoding error for each decoder, on a common axis
xAx     = linspace(0,50,nBins);
subplot(2,3,1)
plotLocErrPanel(cycleDecoding.decLocErr(moving),xAx,'Location (rate only)')
subplot(2,3,2)
plotLocErrPanel(cycleDecoding.rateLocErr(moving),xAx,'Location (rate only, in phase bins)')
subplot(2,3,3)
plotLocErrPanel(cycleDecoding.combLocErr(moving),xAx,'Location (rate and phase)')
clear xAx

%  Running speed decoding error (speed is in cm/s, so the error is too)
subplot(2,3,4)
xAx     = linspace(-10,10,nBins);
histDat = hist(cycleDecoding.speedError,xAx);
histDat = histDat./sum(histDat);
bar(xAx,histDat,'FaceColor',[0.8 0.8 0.8],'EdgeColor','none')
xlabel('Running speed decoding error (cm/s)','FontSize',18)
ylabel('Relative Frequency','FontSize',18)
axis square
clear histDat xAx
ylim([0 0.2])
title('Running Speed','FontSize',24)

%  Movement direction decoding error
subplot(2,3,5)
polarhistogram(cycleDecoding.vecError(moving),nBins,'Normalization','Probability','FaceColor',[0.8 0.8 0.8],'EdgeColor','none')
title('Movement Direction','FontSize',24)

%  'Anxiety' decoding error, converted from a proportion to a percentage
subplot(2,3,6)
xAx     = linspace(-10,10,nBins);
histDat = hist(cycleDecoding.anxError*100,xAx);
histDat = histDat./sum(histDat);
bar(xAx,histDat,'FaceColor',[0.8 0.8 0.8],'EdgeColor','none')
xlabel('Decoding error (%)','FontSize',18)
ylabel('Relative Frequency','FontSize',18)
axis square
ylim([0 0.15])
clear histDat xAx
title('Arbitrary Fourth Variable','FontSize',24)

end



%% Function to draw one location decoding error panel in the current axes:
%  a relative frequency histogram of the errors, with the median marked by
%  a red line and printed on the panel
function[] = plotLocErrPanel(locErr,xAx,panelTitle)

medErr  = median(locErr);
histDat = hist(locErr,xAx);
histDat = histDat./sum(histDat);
bar(xAx,histDat,'FaceColor',[0.8 0.8 0.8],'EdgeColor','none')
hold on
plot([medErr medErr],[0 1],'r','LineWidth',2)
hold off
xlabel('Absolute location decoding error (cm)','FontSize',18)
ylabel('Relative Frequency','FontSize',18)
text(10,0.75,['Median error = ' num2str(medErr,3) 'cm'],'FontSize',18);
title(panelTitle,'FontSize',24)
axis square

end