%% Simulation of a hybrid oscillatory interference (OI) - continuous
% attractor network (CAN) model of grid cell firing
%
% Daniel Bush, UCL Institute of Cognitive Neuroscience (2016)
% www.danbush.co.uk
%
% For details and description of the model, see:
%
% Bush D, Burgess N (2014) Journal of Neuroscience 34: 5065-5079
%
% Bush D, Schmidt-Hieber C (in press) Hippocampal Microcircuits: A
% Computational Modeller's Resource Book. Springer, NY
%
% Note that 2D simulations of 20 minutes random foraging take around 20
% minutes to run on a standard desktop PC
%% Provide some settings for the simulation
Environment = '1D'; % Environment type ('1D' or '2D')
gridScale = 30; % Grid scale (cm)
%% Load tracking data, assign VCO orientations and initial phase offsets
% Each branch loads one struct from TrackingData.mat and derives:
%   track           - the tracking struct itself (later sections read its
%                     t_log, x_log, y_log, head_dir and dt fields)
%   nGC             - number of grid cells (columns of the connectivity template)
%   vcoOrientations - preferred direction of each VCO (rad), column vector
%   vcoPhases       - initial phase offset of each VCO (rad), column vector
%   vcoInput        - VCO to grid cell connectivity template
switch Environment
    case '1D'
        load('TrackingData.mat','OneD');
        track = OneD;
        nGC = size(OneD.vcoInput,2); % Number of grid cells
        % Three groups of nGC VCOs, one group per preferred direction
        vcoOrientations = [2*pi*ones(nGC,1) ; 5*pi/3*ones(nGC,1) ; pi/3*ones(nGC,1)];
        % NOTE(review): the first group starts at 4*pi but ends at 2*pi/nGC
        % like the other two groups - confirm this phase spacing is intended
        vcoPhases = [linspace(4*pi,2*pi/nGC,nGC)' ; linspace(2*pi,2*pi/nGC,nGC)' ; linspace(2*pi,2*pi/nGC,nGC)'];
        vcoInput = OneD.vcoInput; % VCO to GC connectivity template
        clear OneD
    case '2D'
        load('TrackingData.mat','TwoD');
        track = TwoD;
        nGC = size(TwoD.vcoInput,2); % Number of grid cells
        % Six preferred directions (multiples of pi/3), sqrt(nGC) VCOs each
        vcoOrientations = repmat(pi/3:pi/3:2*pi,sqrt(nGC),1); % VCO orientations
        vcoOrientations = vcoOrientations(:);
        vcoPhases = repmat(linspace(2*pi,2*pi/sqrt(nGC),sqrt(nGC))',1,6);
        vcoPhases = vcoPhases(:); % VCO phases
        vcoInput = TwoD.vcoInput; % VCO to GC connectivity template
        clear TwoD
    otherwise
        % Fail here with a clear message rather than later with an
        % unrelated "undefined variable" error
        error('Unrecognised Environment ''%s'': expected ''1D'' or ''2D''.',Environment);
end
%% Generate the VCO rate functions
% For every VCO: project the trajectory onto the VCO's preferred direction,
% turn the resulting velocity into a burst frequency, integrate that to a
% phase and gate the oscillation by heading (the VCO is only active when
% the head direction lies within pi/2 of its preferred direction).
fBaseline = 8; % Baseline oscillation frequency
beta = 1/gridScale; % Slope of the running speed - VCO burst frequency relationship
nSamples = length(track.t_log); % Number of tracking samples
nVCO = length(vcoPhases); % Number of VCOs
xy = [track.x_log' track.y_log']; % Position at each tracking sample
vcoSignal = nan(nSamples,nVCO); % Assign some memory
for vco = 1 : nVCO
    prefDir = vcoOrientations(vco,1);
    unitVec = [cos(prefDir) sin(prefDir)];
    % Displacement along the preferred direction and its rate of change
    proj = dot(repmat(unitVec,nSamples,1),xy,2);
    speedAlong = diff(proj)./track.dt;
    % Velocity controlled frequency, integrated from the initial phase offset
    burstFreq = fBaseline + beta .* speedAlong;
    vcoPhase = cumsum([vcoPhases(vco) ; burstFreq.*2*pi.*track.dt]);
    % Absolute angular distance between heading and preferred direction
    headOffset = abs(angle(exp(1i*track.head_dir)./exp(1i*prefDir)))';
    vcoSignal(:,vco) = (1+cos(vcoPhase)).*(headOffset<pi/2);
end
clear prefDir unitVec proj speedAlong burstFreq vcoPhase headOffset
clear nSamples nVCO xy
clear vco beta vcoOrientations vcoPhases fBaseline
%% Generate the grid cell input rate functions
% The input to each grid cell is the sum of the rate functions of the VCOs
% flagged (==1) in that cell's column of the connectivity template.
gcSignal = nan(length(track.t_log),nGC); % Assign some memory
for gc = 1 : nGC
    afferentVCOs = vcoInput(:,gc)==1; % VCOs that project to this grid cell
    gcSignal(:,gc) = sum(vcoSignal(:,afferentVCOs),2);
end
clear gc afferentVCOs vcoSignal vcoInput
%% Initialise the recurrent inhibitory connectivity
% Cells are indexed with all grid cells first (gcReps per spatial offset),
% followed by all interneurons (intReps per spatial offset). Weight
% matrices are indexed (presynaptic, postsynaptic).
gcReps = 48; % Number of grid cells that share a spatial offset (i.e. have the same VCO inputs)
intReps = 12; % Number of interneurons that share a spatial offset (i.e. receive input from the same grid cells)
gcInhW = 0.2; % Gain (i.e. relative strength) of grid cell to interneuron weights
gcInhSig = 0.2; % Standard deviation of Gaussian distributed grid cell to interneuron weights
gcInhConn = 0.5; % Mean grid cell to interneuron connection probability
inhGCW = 0.04; % Gain (i.e. relative strength) of inhibitory to grid cell weights
inhGCSig = 0.1; % Standard deviation of Gaussian distributed inhibitory to grid cell weights
inhGCConn = 0.7; % Mean interneuron to grid cell connection probability
nExcCells = nGC*gcReps; % Total number of grid cells
nCellsTotal = nExcCells + nGC*intReps; % Grid cells plus interneurons
% Interneuron to grid cell connectivity profile: raised cosine of the
% offset difference, zero between matching offsets
inhGcWeights = toeplitz((cos(linspace(pi,3*pi-2*pi/nGC,nGC))+1)/2);
excW = zeros(nCellsTotal,nCellsTotal); % Excitatory weight matrix
inhW = zeros(nCellsTotal,nCellsTotal); % Inhibitory weight matrix
for src = 1 : nGC
    gcIdx = (src-1)*gcReps+1 : src*gcReps; % Grid cells with this offset
    intIdx = nExcCells+1+(src-1)*intReps : nExcCells+src*intReps; % Interneurons with this offset
    % Grid cells excite only the interneurons that share their offset
    excW(gcIdx,intIdx) = (1+gcInhSig*randn(gcReps,intReps)).*gcInhW;
    % Those interneurons inhibit every grid cell group, scaled by the profile
    for dst = 1 : nGC
        tgtIdx = (dst-1)*gcReps+1 : dst*gcReps;
        inhW(intIdx,tgtIdx) = (1+inhGCSig*randn(intReps,gcReps)).*inhGcWeights(src,dst).*inhGCW;
    end
end
clear src dst gcIdx intIdx tgtIdx nExcCells nCellsTotal
excW(excW<0) = 0; % Ensure all weights are positive
excMask = rand(size(excW))<=gcInhConn; % Sparsify excitatory connectivity
excW = excW.*double(excMask);
inhW(inhW<0) = 0; % Ensure all weights are positive
inhMask = rand(size(inhW))<=inhGCConn; % Sparsify inhibitory connectivity
inhW = inhW.*double(inhMask);
clear excMask inhMask
clear gcInhConn gcInhSig gcInhW inhGCConn inhGCSig inhGCW inhGcWeights
%% Initialise the neural dynamics
% --- VCO input and persistent / noise currents ---
vcoReps = 30; % Number of VCO cells that share a phase offset, for each orientation / ring attractor
vcoRate = 50; % Mean firing rate of each VCO input
vcoGCw = 4.5e-3; % VCO to grid cell weights
excPersI = 8.5e-4; % Persistent current to grid cells
inhPersI = 1.25e-4; % Persistent current to interneurons
excSig = 1.25e-4; % Standard deviation of noise input to grid cells
inhSig = 2.5e-4; % Standard deviation of noise input to interneurons
% --- Membrane parameters ---
Cm = 0.5e-3; % Membrane capacitance (mF)
gm = 25e-6; % Leak conductance (mS)
Vl = -70; % Leak reversal potential (mV)
Vt = -50; % Firing threshold (mV)
Vr = -65; % Reset potential (mV)
% --- Synaptic parameters ---
tauAMPA = 5.26; % AMPA current decay constant (ms)
E_AMPA = 0; % AMPA reversal potential (mV)
g_AMPA = 0.215e-4; % Maximum AMPA conductance (mS)
tauGABA_r = 3; % GABA rise time constant (ms)
tauGABA_1 = 50; % GABA decay constant (ms)
tauGABA_2 = (tauGABA_r*tauGABA_1)/(tauGABA_r+tauGABA_1); % GABA time constant (ms)
% Scale factor applied to the difference of the two GABA exponentials
GABA_B = ((tauGABA_2/tauGABA_1)^(tauGABA_r/tauGABA_1)-(tauGABA_2/tauGABA_1)^(tauGABA_r/tauGABA_2))^-1;
E_GABA = -80; % GABA reversal potential (mV)
g_GABA = 0.14e-4; % Maximum GABA conductance (mS)
% --- State variables (one entry per cell, grid cells then interneurons) ---
nCellsAll = size(excW,1);
v = Vl*ones(nCellsAll,1); % Membrane voltage of all cells (mV)
AMPAExp = zeros(nCellsAll,1); % AMPA gating variable for all cells
GABAExp1 = 0.8*ones(nCellsAll,1); % GABA gating component decaying with tauGABA_1
GABAExp2 = 0.04*ones(nCellsAll,1); % GABA gating component decaying with tauGABA_2
clear nCellsAll
% --- Output logs ---
simDuration = max(track.t_log); % Last tracking time stamp (s)
if strncmp(Environment,'1D',2)
    % Log the voltage of one random grid cell and the spikes of every cell
    logged = ceil(rand*nGC*gcReps); % Choose a random grid cell to log
    v_log = nan(1,simDuration*1000); % Assign memory for the membrane voltage log
    spikeTimes = nan(simDuration*nGC*(gcReps+intReps),2); % Assign memory to log spike times for each cell (estimating mean firing rate of 1Hz)
elseif strncmp(Environment,'2D',2)
    % Log spikes only for one random spatial offset
    offset = floor(rand*nGC); % Select a random grid cell offset
    gcLogged = offset*gcReps + (1:gcReps); % All grid cells with that spatial offset...
    intLogged = nGC*gcReps+offset*intReps+(1:intReps); % ...and the corresponding population of interneurons
    logged = [gcLogged intLogged];
    spikeTimes = nan(simDuration*(gcReps+intReps),2); % Assign memory to log spike times for those cells (estimating mean firing rate of 1Hz)
    clear offset gcLogged intLogged
end
clear simDuration
sCount = 1; % Index for logging spike times
tic % Start the clock
%% Run the neural dynamics
% Integrates all cells in 1ms steps. The order within each step matters:
% VCO input -> currents -> synaptic decay -> voltage update -> spike
% detection / synaptic increments -> reset -> logging.
nExcSim = nGC*gcReps; % Number of grid cells
nInhSim = nGC*intReps; % Number of interneurons
gcRange = 1:nExcSim; % Grid cells occupy the first block of indices
msPerSample = track.dt*1000; % Tracking sample interval (ms)
decayAMPA = exp(-1/tauAMPA); % Per-step decay factors
decayGABA1 = exp(-1/tauGABA_1);
decayGABA2 = exp(-1/tauGABA_2);
halfStep = 0.5 * (1/Cm); % Gain of each half voltage update
is1D = strncmp(Environment,'1D',2);
is2D = strncmp(Environment,'2D',2);
for time = track.t_log(1)*1000 : max(track.t_log)*1000 % For each 1ms time step
    % On steps that coincide with a tracking sample, draw the total number
    % of Poisson input spikes to each grid cell from the VCOs and add those
    % to both GABA gating components of the grid cells
    if mod(time,msPerSample) == 0
        index = round(time/1000/track.dt - (track.t_log(1)-track.dt)/track.dt);
        meanCount = vcoReps*vcoRate*track.dt*gcSignal(index,:);
        vcoInputs = reshape(poissrnd(repmat(meanCount,gcReps,1)),nExcSim,1);
        vcoDrive = vcoInputs*vcoGCw;
        GABAExp1(gcRange) = GABAExp1(gcRange) + vcoDrive;
        GABAExp2(gcRange) = GABAExp2(gcRange) + vcoDrive;
    end
    % Compute the AMPA, GABA and (noisy) persistent currents
    AMPA_I = -g_AMPA .* AMPAExp .* (v-E_AMPA);
    GABA_I = -g_GABA .* GABA_B .* (GABAExp1-GABAExp2) .* (v-E_GABA);
    excNoisy = excPersI*ones(nExcSim,1)+randn(nExcSim,1)*excSig;
    inhNoisy = inhPersI*ones(nInhSim,1)+randn(nInhSim,1)*inhSig;
    inputI = AMPA_I + GABA_I + [excNoisy ; inhNoisy];
    % Decay the AMPA and GABA gating variables
    AMPAExp = AMPAExp .* decayAMPA;
    GABAExp1 = GABAExp1.* decayGABA1;
    GABAExp2 = GABAExp2.* decayGABA2;
    % Update the membrane voltages (in two half steps for numerical
    % stability; only the leak term is re-evaluated on the second)
    for half = 1 : 2
        v = v + halfStep * (inputI - (v - Vl) .* gm);
    end
    % Find the neurons that fired and pass their output weights on to the
    % AMPA and GABA gating variables of their targets
    fired = v>=Vt;
    AMPAExp = AMPAExp + sum(excW(fired,:),1)';
    inhDrive = sum(inhW(fired,:),1)';
    GABAExp1 = GABAExp1 + inhDrive;
    GABAExp2 = GABAExp2 + inhDrive;
    % Reset the membrane potential
    v(fired) = Vr;
    % Log the output: voltage of the logged cell and all spikes in 1D,
    % spikes of the logged sub-population only in 2D
    if is1D
        v_log(time) = v(logged);
    elseif is2D
        fired = fired(logged);
    end
    whoFired = find(fired);
    nFired = numel(whoFired);
    if nFired>0
        spikeTimes(sCount:sCount+nFired-1,:) = [whoFired repmat(time/1000,nFired,1)];
        sCount = sCount + nFired;
    end
end
clear index meanCount vcoInputs vcoDrive AMPA_I GABA_I excNoisy inhNoisy inputI
clear half fired inhDrive whoFired nFired
clear nExcSim nInhSim gcRange msPerSample decayAMPA decayGABA1 decayGABA2 halfStep is1D is2D
% Drop the unused tail of the pre-allocated spike log
spikeTimes(sCount:end,:) = [];
% Convert each spike time to the index of the matching tracking sample and
% look up the position at that sample (columns 3 and 4 = x and y location)
spikeInds = round((spikeTimes(:,2)-track.t_log(1))./track.dt)+1;
spikeTimes(:,3) = track.x_log(spikeInds)';
spikeTimes(:,4) = track.y_log(spikeInds)';
clear spikeInds
simTime = toc; % Total simulation time (s)
% Remove everything that is no longer needed from the workspace
clear time sCount fired v tic
clear AMPAExp GABAExp1 GABAExp2 GABA_B E_AMPA E_GABA g_AMPA g_GABA
clear tauAMPA tauGABA_1 tauGABA_2 tauGABA_r Cm gm Vl Vt
clear excPersI excSig inhPersI inhSig excW inhW
clear gcSignal gridScale vcoGCw vcoRate vcoReps
%% Analyse and plot the output
if strncmp(Environment,'1D',2)
    % 1D: analyse the membrane potential of the logged grid cell (theta
    % amplitude and slow ramp) inside and outside the firing fields of the
    % grid cells that share its spatial offset
    % Linearly interpolate the voltage signal to remove spikes
    interpWin = [-1 20]; % Interpolation window around each spike (ms)
    spikeInds = find(v_log==Vr)-1; % Sample preceding each reset of the logged cell
    nSamp = length(v_log);
    for spike = 1 : length(spikeInds) % Linearly interpolate the voltage trace around each spike
        % Clamp the window to the trace so that spikes close to either end
        % of the simulation cannot index outside v_log
        winStart = max(spikeInds(spike)+interpWin(1),1);
        winEnd = min(spikeInds(spike)+interpWin(2),nSamp);
        v_log(winStart:winEnd) = linspace(v_log(winStart),v_log(winEnd),winEnd-winStart+1);
    end
    clear interpWin spikeInds spike Vr nSamp winStart winEnd
    % Mean normalise the voltage signal
    vMean = nanmean(v_log); % Mean normalise the voltage trace (for filtering)
    v_log = v_log - vMean;
    v_log(isnan(v_log)) = 0; % Eliminate any nan entries (for filtering)
    % Filter in the <3Hz and 5-11Hz range with a 400th order FIR filter
    % (cut-offs are normalised assuming the 1kHz sampling of v_log)
    a = fir1(400, [5 11]*2/1000, 'band'); % Set up a theta band pass filter
    b = fir1(400, 3*2/1000, 'low'); % Set up a low pass ramp filter
    Theta = filtfilt(a,1,v_log); % Filter in the theta band
    Theta = abs(hilbert(Theta)); % Extract theta amplitude
    Ramp = filtfilt(b,1,v_log); clear a b % Filter in the ramp band
    % Compute the mean firing rate of all grid cells with that offset
    binSize = 0.2; % Temporal bin size (s)
    gridOffset = ceil(logged/gcReps); % Spatial offset group of the logged cell
    % Exactly the gcReps cells of that group: the range must start at
    % (gridOffset-1)*gcReps+1, otherwise the last cell of the previous
    % group is counted as well
    groupCells = (gridOffset-1)*gcReps+1 : gridOffset*gcReps;
    spikeInds = ismember(spikeTimes(:,1),groupCells);
    spikeTs = spikeTimes(spikeInds,2); clear gridOffset groupCells spikeInds % Extract the spike times of all those cells
    spikeRate = histc(spikeTs,0:binSize:max(track.t_log)); clear spikeTs
    spikeRate = spikeRate(1:end-1)/gcReps/binSize; % Compute the firing rate for that sub-population of cells
    fields = spikeRate>=(0.1*max(spikeRate)); % Identify candidate firing fields as rate >= 10% of the peak rate
    fields = regionprops(fields,'PixelIdxList'); % Group consecutive candidate bins into fields
    inField = zeros(size(v_log)); % Generate a firing field mask to examine changes in theta and ramp depolarisation
    for f = 1 : length(fields)
        if length(fields(f,1).PixelIdxList)>=3 % Keep only fields with at least three consecutive bins
            % Convert the first / last bin of the field to indices into v_log
            % NOTE(review): the field start uses the upper edge of its first
            % bin (PixelIdxList(1)*binSize) - confirm that is intended
            inds(1) = round(find(track.t_log>=(fields(f,1).PixelIdxList(1)*binSize),1,'first')*track.dt*1000);
            inds(2) = round(find(track.t_log<=(fields(f,1).PixelIdxList(end)*binSize),1,'last')*track.dt*1000);
            inField(inds(1):inds(2)) = 1; clear inds
        end
    end
    clear f fields
    % Plot the output
    tAxis = linspace(track.t_log(1),track.t_log(end),length(v_log)); % Time of each voltage sample (s)
    h1 = subplot(4,1,1); % Plot the mean firing rate along the track
    plot(h1,binSize/2:binSize:max(track.t_log)-binSize/2,spikeRate,'k','LineWidth',2)
    ylabel(h1,{'Firing','Rate (Hz)'},'FontSize',16)
    ylim(h1,[0 10])
    clear spikeRate
    h2 = subplot(4,1,2); % Plot the membrane voltage trace along the track
    plot(h2,tAxis,v_log+vMean,'k','LineWidth',2)
    hold(h2,'on')
    APs = round(spikeTimes(spikeTimes(:,1)==logged,2)*1000); % Add the action potentials back in
    for spike = 1 : length(APs)
        plot(h2,[APs(spike) APs(spike)]./1000,[v_log(APs(spike))+vMean 0],'k','LineWidth',2)
    end
    clear spike APs
    hold(h2,'off')
    ylabel(h2,{'Membrane','Potential (mV)'},'FontSize',16)
    ylim(h2,[-80 10])
    h3 = subplot(4,1,3); % Plot theta amplitude in the membrane potential along the track
    plot(h3,tAxis,Theta,'k','LineWidth',2)
    hold(h3,'on') % Add mean theta amplitude inside (red) and outside (blue) the firing fields
    plot(h3,tAxis,mean(Theta(inField==1))*ones(size(v_log)),'r--','LineWidth',2)
    plot(h3,tAxis,mean(Theta(inField==0))*ones(size(v_log)),'b--','LineWidth',2)
    hold(h3,'off')
    ylabel(h3,{'Theta';'Amplitude (mV)'},'FontSize',16)
    ylim(h3,[-5 5])
    h4 = subplot(4,1,4); % Plot ramp depolarisation in the membrane potential along the track
    plot(h4,tAxis,Ramp,'k','LineWidth',2)
    hold(h4,'on') % Add mean ramp depolarisation inside (red) and outside (blue) the firing fields
    plot(h4,tAxis,mean(Ramp(inField==1))*ones(size(v_log)),'r--','LineWidth',2)
    plot(h4,tAxis,mean(Ramp(inField==0))*ones(size(v_log)),'b--','LineWidth',2)
    hold(h4,'off')
    xlabel(h4,'Time (s)','FontSize',16)
    ylabel(h4,{'Ramp','Depolarisation (mV)'},'FontSize',16)
    ylim(h4,[-5 5])
    linkaxes([h1 h2 h3 h4],'x');
    xlim(h4,[1 max(track.t_log)]) % Set the x axis limits to movement periods only (applies to all linked axes)
    clear h1 h2 h3 h4 tAxis
    v_log = v_log + vMean; clear vMean % Return the voltage trace to its original offset
elseif strncmp(Environment,'2D',2)
    % 2D: build smoothed rate maps for the logged grid cells and
    % interneurons. In the spike log, ids 1..gcReps are the logged grid
    % cells and ids above gcReps are the logged interneurons.
    % Generate smoothed rate maps
    binSize = 2; % Spatial bin size (cm)
    smthKern = 5; % Size of boxcar smoothing kernel (bins)
    isGridSpike = spikeTimes(:,1)<=gcReps; % Spikes fired by grid cells
    isIntSpike = spikeTimes(:,1)>gcReps; % Spikes fired by interneurons
    % Size the maps from the full trajectory, so that occupancy and
    % interneuron spikes can never fall outside a map that was sized from
    % the grid cell spikes alone
    mapSize = ceil(max([track.x_log' track.y_log'],[],1)./binSize);
    gridInds = ceil(spikeTimes(isGridSpike,3:4)./binSize);
    gridInds = accumarray(gridInds,1,mapSize)/gcReps; % Raw grid cell spike count map, per grid cell
    intInds = ceil(spikeTimes(isIntSpike,3:4)./binSize);
    intInds = accumarray(intInds,1,mapSize)/intReps; % Raw interneuron spike count map, per interneuron
    locInds = ceil([track.x_log' track.y_log']./binSize);
    locInds = accumarray(locInds,1,mapSize)*track.dt; % Generate raw occupancy map (s)
    % Boxcar smoothing, normalised by the number of visited bins under the
    % kernel; bins with no visited neighbours become nan
    denom = filter2(ones(smthKern),double(locInds>0));
    denom(denom==0) = nan;
    locInds = filter2(ones(smthKern),locInds)./denom; % Smooth the occupancy map
    gridInds = filter2(ones(smthKern),gridInds)./denom; % Smooth the grid cell and interneuron maps
    intInds = filter2(ones(smthKern),intInds)./denom; clear denom smthKern
    % Rate = spikes / occupancy; transposed so that rows are y and columns are x
    gridMap = (gridInds./locInds)';
    intMap = (intInds./locInds)'; clear gridInds intInds locInds
    xCentres = binSize/2:binSize:mapSize(1)*binSize-binSize/2; % Bin centres along x (cm)
    yCentres = binSize/2:binSize:mapSize(2)*binSize-binSize/2; % Bin centres along y (cm)
    % Plot the grid cell spike locations
    subplot(2,2,1); % Plot animal trajectory and grid cell spike locations
    plot(track.x_log,track.y_log,'Color',[0.8 0.8 0.8])
    hold on
    scatter(spikeTimes(isGridSpike,3),spikeTimes(isGridSpike,4),'r.')
    hold off
    title('Grid Cells','FontSize',24)
    axis square
    xlabel('Position (cm)','FontSize',16)
    ylabel('Position (cm)','FontSize',16)
    % Plot the grid cell rate map
    subplot(2,2,2); % Plot the smoothed grid cell rate map
    imagesc(xCentres,yCentres,gridMap)
    set(gca,'YDir','normal')
    title(['Peak rate = ' num2str(max(gridMap(:)),3), 'Hz'],'FontSize',16)
    axis square
    xlabel('Position (cm)','FontSize',16)
    ylabel('Position (cm)','FontSize',16)
    % Plot the interneuron spike locations
    subplot(2,2,3); % Plot animal trajectory and interneuron spike locations
    plot(track.x_log,track.y_log,'Color',[0.8 0.8 0.8])
    hold on
    scatter(spikeTimes(isIntSpike,3),spikeTimes(isIntSpike,4),'b.')
    hold off
    title('Interneurons','FontSize',24)
    axis square
    xlabel('Position (cm)','FontSize',16)
    ylabel('Position (cm)','FontSize',16)
    % Plot the interneurons rate map
    subplot(2,2,4); % Plot the smoothed interneuron rate map
    imagesc(xCentres,yCentres,intMap)
    set(gca,'YDir','normal')
    title(['Peak rate = ' num2str(max(intMap(:)),3), 'Hz'],'FontSize',16)
    axis square
    xlabel('Position (cm)','FontSize',16)
    ylabel('Position (cm)','FontSize',16)
    linkaxes
    xlim([0 100])
    ylim([0 100])
    clear mapSize isGridSpike isIntSpike xCentres yCentres
end
clear Vr binSize gcReps logged intReps nGC