classdef Neurons
    % NEURONS Population of neurons with stochastic spike-count responses.
    %
    % Holds the population layout and the response-variability model, and
    % implements information measures on top of it: Monte Carlo mutual
    % information (mi), stimulus-specific information (ssiss), Fisher
    % information (fisher) and Fisher-based approximations (Ifisher,
    % SSIfisher). meanR(obj, stim) and dMeanR(obj, stim) are called here but
    % not defined in this file; presumably subclasses supply them
    % (NOTE(review): confirm).
    %
    % For the Gaussian model the response covariance for a mean response
    % vector r is (see method Q), with elementwise powers:
    %   Q(r) = (add .* R + a .* R .* (r * r').^alpha) .^ exponent
    properties
        dimensionality = 1;         % stimulus dimensionality (only 1-D supported)
        popSize = 1;                % number of neurons
        preferredStimulus = 0;      % preferred stimuli, one row per neuron
        integrationTime = 0;        % spike counting time per trial
        distribution = 'Gaussian';  % 'Gaussian' or 'Poisson'
        a = 1.0;                    % Gaussian variability coefficient ([] for Poisson)
        alpha = 0.5;                % Gaussian variability exponent on r*r' ([] for Poisson)
        R = [];                     % popSize x popSize correlation matrix ([] for Poisson)
        add = 0.0;                  % additive term of the Gaussian covariance
        exponent = 1.0;             % elementwise exponent applied to the Gaussian covariance
        truncate = false;           % if true, negative Gaussian samples are clipped to zero
    end
    methods
        function obj = Neurons(varargin)
            % NEURONS Neuron population class constructor.
            % n = Neurons(dimensionality, preferredStimuli, integrationTime, variabilityScheme, variabilityOpts)
            % n = Neurons()
            %
            % dimensionality    - stimulus dimensionality (only 1-D stimuli currently supported)
            % preferredStimuli  - (dimensionality x N) array of characteristic stimuli, N = popSize
            % integrationTime   - spike counting time per trial
            % variabilityScheme - the type of variability model (case-insensitive):
            %     'poisson'              - variabilityOpts is ignored
            %     'gaussian-independent' - [a alpha] or [a alpha add]
            %     'gaussian-uniform'     - [a alpha c]; every off-diagonal correlation is c
            %     'gaussian-exponential' - [a alpha c rho]; R_ij = c*exp(-|s_i - s_j|/rho)
            %     'gaussian-gaussian'    - [a alpha c w]; R_ij = c*exp((cosd(s_i - s_j) - 1)/degToRad(w)^2)
            %     'cercal'               - [add a]; alpha = 0.5, exponent = 2
            % variabilityOpts   - numeric vector of arguments specific to the chosen variability model
            %
            % With no arguments the property defaults are kept. MATLAB uses this
            % form when loading saved objects, building object arrays, or when a
            % subclass constructor does not call this constructor explicitly.
            switch nargin
                case 0
                    % No-argument form: keep the property defaults
                case 5
                    % Standard constructor
                    if isnumeric(varargin{1}) && isscalar(varargin{1})
                        obj.dimensionality = varargin{1};
                    else
                        error([inputname(1) ' is not a valid stimulus dimensionality'])
                    end
                    if isnumeric(varargin{2}) && size(varargin{2}, 1) == obj.dimensionality
                        % Stored transposed: one row per neuron
                        obj.preferredStimulus = varargin{2}';
                        obj.popSize = size(varargin{2}, 2);
                    else
                        error([inputname(2) ' is not a valid preferred stimulus value or vector'])
                    end
                    if isnumeric(varargin{3}) && isscalar(varargin{3})
                        obj.integrationTime = double(varargin{3});
                    else
                        error([inputname(3) ' is not a valid integration time'])
                    end
                    opts = varargin{5};
                    switch lower(varargin{4})
                        case 'poisson'
                            obj.distribution = 'Poisson';
                            obj.a = [];
                            obj.alpha = [];
                            obj.R = [];
                            obj.add = 0.0;
                        case 'gaussian-independent'
                            checkVariabilityOpts(opts, 2, 'gaussian-independent');
                            obj.distribution = 'Gaussian';
                            obj.a = opts(1);
                            obj.alpha = opts(2);
                            obj.R = eye(obj.popSize);
                            if length(opts) == 3
                                obj.add = opts(3);
                            else
                                obj.add = 0.0;
                            end
                        case 'gaussian-uniform'
                            checkVariabilityOpts(opts, 3, 'gaussian-uniform');
                            obj.distribution = 'Gaussian';
                            obj.a = opts(1);
                            obj.alpha = opts(2);
                            obj.R = opts(3) * ~eye(obj.popSize) + eye(obj.popSize);
                            obj.add = 0.0;
                        case 'gaussian-exponential'
                            checkVariabilityOpts(opts, 4, 'gaussian-exponential');
                            obj.distribution = 'Gaussian';
                            obj.a = opts(1);
                            obj.alpha = opts(2);
                            c = opts(3);
                            rho = opts(4);
                            % Pairwise differences of preferred stimuli, s_i - s_j
                            prefDiff = repmat(obj.preferredStimulus, 1, obj.popSize);
                            prefDiff = prefDiff - prefDiff.';
                            obj.R = c .* exp(-abs(double(prefDiff)) ./ rho) .* ~eye(obj.popSize) + eye(obj.popSize);
                            obj.add = 0.0;
                        case 'gaussian-gaussian'
                            checkVariabilityOpts(opts, 4, 'gaussian-gaussian');
                            obj.distribution = 'Gaussian';
                            obj.a = opts(1);
                            obj.alpha = opts(2);
                            c = opts(3);
                            beta = 1.0 ./ degToRad(opts(4)).^2;
                            % Pairwise differences of preferred stimuli, s_i - s_j (degrees)
                            prefDiff = repmat(obj.preferredStimulus, 1, obj.popSize);
                            prefDiff = prefDiff - prefDiff.';
                            obj.R = c .* exp((cosd(double(prefDiff)) - 1) .* beta) .* ~eye(obj.popSize) + eye(obj.popSize);
                            obj.add = 0.0;
                        case 'cercal'
                            checkVariabilityOpts(opts, 2, 'cercal');
                            obj.distribution = 'Gaussian';
                            obj.add = opts(1);
                            obj.a = opts(2);
                            obj.alpha = 0.5;
                            obj.R = eye(obj.popSize);
                            obj.exponent = 2.0;
                        otherwise
                            error([varargin{4} ' is not a valid variability regime'])
                    end
                otherwise
                    error('Wrong number of arguments')
            end
        end
        function varargout = mi(obj, method, stim, tol, maxiter)
            % MI Monte Carlo estimate of the mutual information I(R;S), in bits.
            % i = n.mi(method, stim, tol, maxiter)
            % [i, iSem] = n.mi(...)
            % [i, iSem, miEst] = n.mi(...)
            %
            % method  - sampling method; only 'randMC' is implemented
            % stim    - StimulusEnsemble object
            % tol     - halting threshold on miEst.runDelta. Also the relative
            %           threshold for switching from the stimulus-grid sum to
            %           adaptive quadrature when computing log P(r).
            % maxiter - iteration cap (at least 1000 iterations are always run)
            %
            % i     - MI estimate (miEst.mean)
            % iSem  - its standard error (miEst.sem)
            % miEst - OnlineStats object holding the per-sample values
            if ~isa(stim, 'StimulusEnsemble')
                error([inputname(3) ' is not a StimulusEnsemble object'])
            end
            if ~strcmp(method, 'randMC')
                % Fail before the expensive setup rather than on the first iteration
                error('Unsupported method: %s', method)
            end
            % Mean responses, obj.popSize x stim.n, one cell per stimulus ordinate
            rMean = obj.integrationTime .* meanR(obj, stim);
            rMeanCell = squeeze(mat2cell(rMean, obj.popSize, ones(stim.n, 1)));
            % Distribution-specific one-time preparation
            switch obj.distribution
                case 'Gaussian'
                    % Mean-response-dependent covariance matrices Q, one per ordinate
                    QCell1 = obj.Q(rMeanCell);
                    % Lower triangular chol(Q) for sampling
                    cholQ = cellfun(@(q) chol(q)', QCell1, 'UniformOutput', false);
                    % Upper triangular chol(Q^-1) for fast PDF evaluation
                    cholInvQCell = cellfun(@(q) chol(inv(q)), QCell1, 'UniformOutput', false);
                    clear QCell1
                    % Multivariate Gaussian sample: mean + chol(Q)' * standard normal vector
                    if obj.truncate
                        fRand = @(m, c, z) max((m + c * z), 0.0); % truncate
                    else
                        fRand = @(m, c, z) m + c * z; % don't truncate
                    end
                case 'Poisson'
                    % Nothing to do here
                otherwise
                    error('Unsupported distribution: %s', obj.distribution)
            end
            % Loop invariants
            logPS = log(stim.pS);
            cpS = cumsum(stim.pS);
            iter = 0;
            miEst = OnlineStats(1, maxiter);
            adaptive = false;
            tic
            cont = true;
            while cont
                iter = iter + 1;
                % Display progress every 100 iterations
                if ~mod(iter, 100)
                    fprintf('mi() iter: %d val: %.4g rel. error: %.4g\n', iter, miEst.runMean, miEst.runDelta)
                end
                % Sample a stimulus ordinate from P(s) by inverting the cumulative distribution
                bin = find(rand() < cpS, 1);
                if isempty(bin)
                    % Rounding can leave cpS(end) just below 1
                    bin = stim.n;
                end
                % Sample r from P(r|s)
                switch obj.distribution
                    case 'Gaussian'
                        % Independent standard normal numbers (mu=0, sigma=1)
                        z = randn(obj.popSize, 1);
                        % !!! NOTE NEGATIVE RESPONSES MAY BE TRUNCATED TO ZERO depending on value of obj.truncate !!!
                        r = fRand(rMeanCell{bin}, cholQ{bin}, z);
                    case 'Poisson'
                        r = poissrnd(rMeanCell{bin});
                end
                if ~adaptive
                    % log P(r|s') for every ordinate s' (stim.n x 1)
                    rCell = repmat({r}, [stim.n 1]);
                    switch obj.distribution
                        case 'Gaussian'
                            lpRgS = cell2mat(cellsxfun(@mvnormpdfln, rCell, rMeanCell', cholInvQCell', {'inv'}));
                        case 'Poisson'
                            lpRgS = cell2mat(cellsxfun(@(x, l) sum(poisspdfln(x, l)), rCell, rMeanCell'));
                    end
                    % log P(r,s') = log P(r|s') + log P(s')
                    lpRS = lpRgS + logPS';
                    % log P(r): one-shot sum of the joint over the stimulus grid
                    lpR = logsumexp(lpRS);
                    % The same estimate from alternate grid points, used as an accuracy check
                    lpR_sparse = mean([logsumexp(lpRS(1:2:end) + log(2)), logsumexp(lpRS(2:2:end) + log(2))]);
                    % log P(r,s) for the sampled ordinate
                    lpRS = lpRS(bin);
                    if abs((lpR - lpR_sparse) / lpR) > tol
                        % One-shot sum is insufficiently accurate; use adaptive quadrature from now on
                        adaptive = true;
                    end
                end
                if adaptive
                    % log P(r) by adaptive quadrature (quad) of p(s,r) over s. The range is
                    % split at the grid points next to the sampled ordinate, presumably so
                    % quad resolves the peak of p(s,r) near the sampled stimulus.
                    pR = [0 0 0];
                    quadTrace = false; % debug flag passed to quad
                    fSR = @(s) obj.fpSR(s, r, stim);
                    % log p(r,s)
                    lpRS = log(obj.fpSR(stim.ensemble(bin), r, stim));
                    quadTol = exp(lpRS) * tol;
                    switch bin
                        case 1 % Bottom bin - do first 2 bins, remainder
                            pR(1) = quad(fSR, stim.lowerLimit, stim.ensemble(bin+1), quadTol, quadTrace);
                            pR(2) = quad(fSR, stim.ensemble(bin+1), stim.upperLimit, quadTol, quadTrace);
                        case stim.n % Top bin - do first bin, last bin, remainder
                            pR(1) = quad(fSR, stim.lowerLimit, stim.ensemble(1), quadTol, quadTrace);
                            pR(2) = quad(fSR, stim.ensemble(1), stim.ensemble(end-1), quadTol, quadTrace);
                            pR(3) = quad(fSR, stim.ensemble(end-1), stim.upperLimit, quadTol, quadTrace);
                        otherwise % Other bins - one bin either side, remainder above, remainder below
                            pR(1) = quad(fSR, stim.lowerLimit, stim.ensemble(bin-1), quadTol, quadTrace);
                            pR(2) = quad(fSR, stim.ensemble(bin-1), stim.ensemble(bin+1), quadTol, quadTrace);
                            pR(3) = quad(fSR, stim.ensemble(bin+1), stim.upperLimit, quadTol, quadTrace);
                    end
                    lpR = log(sum(pR));
                end
                % log P(s)
                lpS = logPS(bin);
                % Sample MI in bits: log2( P(r,s) / (P(r) P(s)) )
                miEst.appendSample((lpRS - (lpR + lpS)) ./ log(2));
                % Halting criteria (relative error, max iterations)
                cont = miEst.runDelta > tol & iter < maxiter;
                % Minimum iteration count so the SEM estimate is sensible
                cont = cont | iter < 1000;
            end
            % Trim unused samples from buffer and recompute MI, SEM cleanly
            miEst.trim;
            miMean = miEst.mean;
            miSem = miEst.sem;
            fprintf('mi() halting iter: %d val: %.4g SEM: %.4g\n', iter, miMean, miSem)
            switch nargout
                case {0, 1}
                    varargout = {miMean};
                case 2
                    varargout = {miMean miSem};
                case 3
                    varargout = {miMean miSem miEst};
                otherwise
                    error('Unsupported number of return values.')
            end
        end
        function varargout = ssiss(obj, n, method, stim, stimOrds, tol, maxiter, timeout)
            % SSISS Monte Carlo estimate of the stimulus-specific information (SSI).
            % SSI = n.ssiss(n, method, stim, stimOrds, tol, maxiter, timeout)
            %
            % n        - indices of the neurons whose marginal SSI is wanted, or [] for
            %            the whole population. The marginal SSI is the full-population
            %            SSI minus the SSI of the remaining neurons.
            % method   - 'randMC' ('quadrature' and 'quasirandMC' are recognised but
            %            raise 'Unsupported method')
            % stim     - StimulusEnsemble object
            % stimOrds - indices of the stimulus ordinates to evaluate, or [] for all.
            %            Results are in ascending ordinate order.
            % tol      - halting threshold on the mean relative error (runDelta)
            % maxiter  - iteration cap (at least 10 iterations are always run)
            % timeout  - wall-clock limit in seconds, measured with toc against a timer
            %            started by the caller. Skipped silently when toc or timeout is
            %            unavailable.
            %
            % Outputs by nargout (Isur = stimulus-specific surprise; Issi, Isur,
            % IssiMarg, IsurMarg are the OnlineStats sample accumulators):
            %   1: SSI                      2: SSI, iter
            %   3: fullSSI, remSSI, iter    4: fullSSI, remSSI, iter, Issi
            %   5: fullSSI, remSSI, fullIsur, remIsur, iter
            %   9: fullSSI, remSSI, fullIsur, remIsur, iter, Issi, Isur, IssiMarg, IsurMarg
            if ~isa(stim, 'StimulusEnsemble')
                error([inputname(4) ' is not a StimulusEnsemble object'])
            end
            try
                % Test sanity of neuron indices
                obj.preferredStimulus(n);
            catch
                error([inputname(2) ' is not a valid neuron index'])
            end
            try
                % Test sanity of stimulus ordinate indices
                stim.ensemble(stimOrds);
            catch
                error([inputname(5) ' is not a valid stimulus ordinate index'])
            end
            if ~any(strcmp(method, {'quadrature' 'randMC' 'quasirandMC'}))
                error([method ' is not a valid SSI calculation method'])
            end
            if ~strcmp(method, 'randMC')
                % Only random MC sampling is implemented; fail before the expensive setup
                error('Unsupported method: %s', method)
            end
            % Mask of the stimulus ordinates to evaluate
            if isempty(stimOrds)
                sMask = true(stim.n, 1);
            else
                sMask = false(stim.n, 1);
                sMask(stimOrds) = true;
            end
            % Evaluated ordinates in ascending order. Sampled responses and the
            % OnlineStats columns follow this order, so it (not stimOrds, which may
            % be unsorted or contain repeats) is used to pick P(r|s) for the surprise.
            sIdx = find(sMask);
            sMaskN = numel(sIdx);
            % Mean responses, obj.popSize x stim.n, one cell per ordinate
            rMean = obj.integrationTime .* obj.meanR(stim);
            rMeanCell = squeeze(mat2cell(rMean, obj.popSize, ones(stim.n, 1)));
            doMarg = ~isempty(n);
            if doMarg
                % Logical mask of the neurons that are *not* part of the marginal SSI
                margMask = true(obj.popSize, 1);
                margMask(n) = false;
                rMeanMargCell = cellfun(@(r) r(margMask), rMeanCell, 'UniformOutput', false);
            end
            switch obj.distribution
                case 'Gaussian'
                    % Mean-response-dependent covariance matrices Q, one per ordinate
                    QCell1 = obj.Q(rMeanCell);
                    % Lower triangular chol(Q) for sampling
                    cholQ = cellfun(@(q) chol(q)', QCell1, 'UniformOutput', false);
                    % Upper triangular chol(Q^-1) for fast PDF evaluation
                    cholInvQCell = cellfun(@(q) chol(inv(q)), QCell1, 'UniformOutput', false);
                    if doMarg
                        % Covariance of the remaining neurons, inverted and factorised
                        QCellMarg1 = cellfun(@(q) q(margMask, margMask), QCell1, 'UniformOutput', false);
                        cholInvQCellMarg = cellfun(@(q) chol(inv(q)), QCellMarg1, 'UniformOutput', false);
                        clear QCellMarg1
                    end
                    clear QCell1
                    % Multivariate Gaussian sample: mean + chol(Q)' * standard normal vector
                    if obj.truncate
                        fRand = @(m, c, z) max((m + c * z), 0.0); % truncate
                    else
                        fRand = @(m, c, z) m + c * z; % don't truncate
                    end
                case 'Poisson'
                    % Nothing to be done here
                otherwise
                    error('Unsupported distribution: %s', obj.distribution)
            end
            % log P(s') as a column, loop invariant
            lpPrior = log(stim.pS');
            % Initialise main loop, preallocate MC sample arrays
            iter = 0;
            cont = true;
            Issi = OnlineStats(sMaskN, maxiter);
            Isur = OnlineStats(sMaskN, maxiter);
            IssiMarg = OnlineStats(sMaskN, maxiter);
            IsurMarg = OnlineStats(sMaskN, maxiter);
            while cont
                iter = iter + 1;
                if ~mod(iter, 10)
                    fprintf('SSISS iter: %d of %d, rel. error: %.4g\n', iter, maxiter, mean(Issi.runDelta))
                end
                % Draw one response per evaluated ordinate (1 x sMaskN cell of popSize vectors)
                switch obj.distribution
                    case 'Gaussian'
                        zCell = mat2cell(randn(obj.popSize, sMaskN), obj.popSize, ones(sMaskN, 1));
                        % !!! NOTE NEGATIVE RESPONSES MAY BE TRUNCATED TO ZERO, SEE ABOVE !!!
                        rCell = cellfun(fRand, rMeanCell(sMask), cholQ(sMask), zCell, 'UniformOutput', false);
                    case 'Poisson'
                        rCell = cellfun(@poissrnd, rMeanCell(sMask), 'UniformOutput', false);
                end
                % log P(r|s'): rows are ordinates s', columns are samples
                switch obj.distribution
                    case 'Gaussian'
                        lpRgS = cell2mat(cellsxfun(@mvnormpdfln, rCell, rMeanCell', cholInvQCell', {'inv'}));
                    case 'Poisson'
                        lpRgS = cell2mat(cellsxfun(@(x, l) sum(poisspdfln(x, l)), rCell, rMeanCell'));
                end
                [issiSample, isurSample] = ssiSamples(lpRgS, lpPrior, stim.entropy, sIdx);
                Issi.appendSample(issiSample);
                Isur.appendSample(isurSample);
                if doMarg
                    % Same quantities for the population with neurons n removed
                    rCellMarg = cellfun(@(r) r(margMask), rCell, 'UniformOutput', false);
                    switch obj.distribution
                        case 'Gaussian'
                            lpRgSMarg = cell2mat(cellsxfun(@mvnormpdfln, rCellMarg, rMeanMargCell', cholInvQCellMarg', {'inv'}));
                        case 'Poisson'
                            lpRgSMarg = cell2mat(cellsxfun(@(x, l) sum(poisspdfln(x, l)), rCellMarg, rMeanMargCell'));
                    end
                    [issiSample, isurSample] = ssiSamples(lpRgSMarg, lpPrior, stim.entropy, sIdx);
                    IssiMarg.appendSample(issiSample);
                    IsurMarg.appendSample(isurSample);
                end
                % Halting criteria (relative error, max iterations, timeout)
                if doMarg
                    delta = mean([Issi.runDelta IssiMarg.runDelta]);
                else
                    delta = mean(Issi.runDelta);
                end
                cont = delta > tol & iter < maxiter;
                % If the wall clock is running (and timeout was given), check the elapsed time
                try
                    cont = cont & toc < timeout;
                catch
                end
                % Minimum iteration count so the SEM estimate is valid
                cont = cont | iter < 10;
            end
            try
                fprintf('SSISS iter: %d elapsed time: %.4f seconds\n', iter, toc)
            catch
                fprintf('SSISS iter: %d\n', iter)
            end
            % Trim sample arrays and recalculate means cleanly
            Issi.trim;
            Isur.trim;
            IssiMarg.trim;
            IsurMarg.trim;
            fullSSI = Issi.mean;
            fullIsur = Isur.mean;
            if doMarg
                remSSI = IssiMarg.mean;
                SSI = fullSSI - remSSI;
                remIsur = IsurMarg.mean;
            else
                remSSI = fullSSI;
                SSI = fullSSI;
                remIsur = fullIsur;
            end
            switch nargout
                case {0, 1}
                    varargout = {SSI};
                case 2
                    varargout = {SSI iter};
                case 3
                    varargout = {fullSSI remSSI iter};
                case 4
                    % NOTE(review): five values for four outputs; MATLAB drops IssiMarg
                    varargout = {fullSSI remSSI iter Issi IssiMarg};
                case 5
                    varargout = {fullSSI remSSI fullIsur remIsur iter};
                case 9
                    varargout = {fullSSI remSSI fullIsur remIsur iter Issi Isur IssiMarg IsurMarg};
                otherwise
                    error('Unsupported number of outputs')
            end
        end
        function varargout = fisher(obj, method, stim, tol, maxiter)
            % FISHER Fisher information of the population at each stimulus ordinate.
            % J = n.fisher(method, stim, tol, maxiter)
            %
            % method       - 'analytic' or 'randMC'
            % stim         - StimulusEnsemble object (the Gaussian analytic path also
            %                handles a numeric vector of stimulus values)
            % tol, maxiter - MC halting tolerance and iteration cap ('randMC' only)
            %
            % Outputs:
            %   'analytic', Gaussian: J = J_mean + J_cov, or [J_mean, J_cov]
            %   'analytic', Poisson:  J
            %   'randMC':             J, or [J, samples] (OnlineStats accumulator)
            switch method
                case 'analytic'
                    switch obj.distribution
                        case 'Gaussian'
                            [J_mean, J_cov] = obj.fisher_analytic_gauss(stim);
                            switch nargout
                                case {0, 1}
                                    varargout = {J_mean + J_cov};
                                case 2
                                    varargout = {J_mean J_cov};
                                otherwise
                                    error('Wrong number of outputs')
                            end
                        case 'Poisson'
                            J = obj.fisher_analytic_poiss(stim);
                            if nargout <= 1
                                varargout = {J};
                            else
                                error('Wrong number of outputs')
                            end
                        otherwise
                            error('Unsupported distribution: %s', obj.distribution)
                    end
                case 'randMC'
                    [J, samples] = obj.fisher_mc(method, stim, tol, maxiter);
                    switch nargout
                        case {0, 1}
                            varargout = {J};
                        case 2
                            varargout = {J samples};
                        otherwise
                            error('Wrong number of outputs')
                    end
                otherwise
                    error('"%s" is not a valid FI calculation method', method)
            end
        end
        function [fullFisher, remainderFisher] = margFisher(obj, nMarg, method, stim, tol, varargin)
            % MARGFISHER Fisher information with and without the neuron(s) nMarg.
            % [fullFisher, remainderFisher] = n.margFisher(nMarg, method, stim, tol)
            % [fullFisher, remainderFisher] = n.margFisher(nMarg, method, stim, tol, maxiter)
            %
            % Arguments after tol (e.g. maxiter, which 'randMC' requires) are
            % passed through to fisher.
            if ~isa(stim, 'StimulusEnsemble')
                error([inputname(4) ' is not a StimulusEnsemble object'])
            end
            % FI of the full population
            fullFisher = fisher(obj, method, stim, tol, varargin{:});
            % FI of the population with nMarg removed
            remainingNeurons = obj.remove(nMarg);
            remainderFisher = fisher(remainingNeurons, method, stim, tol, varargin{:});
        end
        function ifish = Ifisher(obj, stim)
            % IFISHER Fisher-information approximation to the mutual information, in bits:
            %   I_Fisher = H(S) - 0.5 * integral p(s) * log2(2*pi*e / J(s)) ds
            % where J is the analytic Fisher information and
            % log2(2*pi*e) = 1 + log2(pi) + log2(exp(1)). See:
            %
            % Brunel N, Nadal J (1998)
            % Mutual information, Fisher information, and population coding.
            % Neural Comput 10:1731-1757.
            if false
                % Disabled: discrete sum over the stimulus grid
                fish = obj.fisher('analytic', stim, 0);
                % Ignore zero Fisher information values
                zeroJ = find(fish == 0);
                pS = stim.pS;
                if ~isempty(zeroJ)
                    fish(zeroJ) = [];
                    pS(zeroJ) = [];
                    pS = pS ./ sum(pS);
                end
                ifish = stim.entropy - 0.5 .* sum(pS .* (1.0 + log2(pi) + log2(exp(1)) - log2(fish)));
            else
                % quad over the interpolated prior, from one grid step below the
                % first ordinate up to the last ordinate
                ifish = stim.entropy - 0.5 .* quad(@(s) stim.pSint(s) .* (1.0 + log2(pi) + log2(exp(1)) - log2(obj.fisher('analytic', s, 0))), stim.ensemble(1) - diff(stim.ensemble(1:2)), stim.ensemble(end));
            end
        end
        function varargout = mIfisher(obj, nMarg, stim)
            % MIFISHER Marginal I_Fisher of the neuron(s) nMarg.
            % dI = n.mIfisher(nMarg, stim)               full minus remainder I_Fisher
            % [iFull, iRem] = n.mIfisher(nMarg, stim)    both values separately
            ifishFull = obj.Ifisher(stim);
            % Population of remaining neurons
            remainingNeurons = obj.remove(nMarg);
            ifishRem = remainingNeurons.Ifisher(stim);
            switch nargout
                case {0, 1}
                    varargout = {ifishFull - ifishRem};
                case 2
                    varargout = {ifishFull ifishRem};
                otherwise
                    error('Wrong number of outputs')
            end
        end
        function ssif = SSIfisher(obj, n, fisherMethod, stim, tol, varargin)
            % SSIFISHER Fisher-information approximation to the SSI (SSI_Fisher). See:
            %
            % Yarrow S, Challis E, Series P (2012)
            % Fisher and Shannon information in finite neural populations.
            % Neural Comput (in review).
            %
            % The stimulus estimate sHat is modelled as Gaussian around s with
            % SD J(s)^-1/2. With n non-empty the result is the marginal SSI_Fisher
            % of neurons n (full population minus the population with n removed).
            % Arguments after tol (e.g. maxiter for 'randMC') are passed to the
            % Fisher computation. Returns a 1 x stim.n row.
            % Grids over stimulus S (columns) and estimate sHat (rows)
            sMat = repmat(stim.ensemble, [stim.n 1]);
            sHatMat = repmat(stim.ensemble', [1 stim.n]);
            % log p(S) on the same grid
            lpS = log(repmat(stim.pS, [stim.n 1]));
            if stim.circular
                % Circular difference, wrapped to within half a period of zero
                dS = mod(sHatMat - sMat, stim.circular);
                wrap = dS > (0.5 * stim.circular);
                dS(wrap) = dS(wrap) - stim.circular;
                wrap = dS < (-0.5 * stim.circular);
                dS(wrap) = dS(wrap) + stim.circular;
            else
                % Linear difference
                dS = sHatMat - sMat;
            end
            if isempty(n)
                fullFI = obj.fisher(fisherMethod, stim, tol, varargin{:});
                ssif = ssiFromFisher(fullFI, dS, lpS, stim.entropy);
            else
                [fullFI, remFI] = obj.margFisher(n, fisherMethod, stim, tol, varargin{:});
                ssif = ssiFromFisher(fullFI, dS, lpS, stim.entropy) - ssiFromFisher(remFI, dS, lpS, stim.entropy);
            end
        end
        function q = Q(obj, resp)
            % Q Response covariance matrices for the Gaussian variability model.
            % q = n.Q(resp)
            %
            % resp - cell array of mean response column vectors
            % q    - cell array (same shape) of matrices
            %        (add .* R + a .* R .* (r * r').^alpha) .^ exponent, powers elementwise
            q = cellfun(@(r) (obj.add .* obj.R + obj.a .* obj.R .* (r * r').^obj.alpha).^obj.exponent, resp, 'UniformOutput', false);
        end
        function retStr = char(obj)
            % CHAR Text representation of Neurons object
            % aString = char(object)
            if ~isscalar(obj)
                retStr = 'Array of Neurons objects';
                return
            end
            formatStr = ['Population size: %.0f\n' ...
                'Stimulus dimensionality: %.0f\n' ...
                'Preferred stimuli:\n%s\n' ...
                'Integration time: %.3f s\n' ...
                'Variability distribution: %s\n' ...
                'Variability coefficient a: %s\n' ...
                'Variability exponent alpha: %s\n' ...
                'Correlation matrix:\n%s\n'];
            % Matrices and possibly-empty coefficients are pre-rendered as text and
            % passed through %s, so empty values (e.g. Poisson) cannot shift arguments
            retStr = sprintf(formatStr, obj.popSize, double(obj.dimensionality), ...
                matrixText(obj.preferredStimulus.'), obj.integrationTime, obj.distribution, ...
                scalarText(obj.a), scalarText(obj.alpha), matrixText(obj.R));
        end
        function retVal = tcplot(obj, stim)
            % TCPLOT Plot the mean-response tuning curves over the stimulus ensemble.
            % Returns the line handles from plot.
            r = meanR(obj, stim);
            [stims, ind] = sort(double(stim.ensemble));
            retVal = plot(stims, r(:, ind));
        end
        function varargout = remove(obj, nMarg)
            % REMOVE Return a copy of the population with the neurons nMarg removed.
            % m = n.remove(nMarg)
            % [m, keepMask] = n.remove(nMarg)
            %
            % Drops the neurons' preferred stimuli and the matching rows and
            % columns of R (an empty R, as used for Poisson, stays empty).
            % keepMask is a logical column, true for the retained neurons.
            % NOTE(review): only preferredStimulus, popSize and R are updated here.
            if isempty(nMarg)
                error('Must specify index of a neuron or neurons')
            end
            % Logical mask of the neurons that are *not* removed
            keepMask = true(obj.popSize, 1);
            keepMask(nMarg) = false;
            nKeep = sum(keepMask);
            obj.preferredStimulus = obj.preferredStimulus(keepMask, :);
            obj.popSize = nKeep;
            if ~isempty(obj.R)
                obj.R = obj.R(keepMask, keepMask);
            end
            switch nargout
                case {0, 1}
                    varargout = {obj};
                case 2
                    varargout = {obj keepMask};
                otherwise
                    error('Wrong number of outputs')
            end
        end
    end
    methods (Access = protected)
        function [J_mean, J_cov] = fisher_analytic_gauss(obj, stim)
            % FISHER_ANALYTIC_GAUSS Analytic Fisher information for the Gaussian model.
            % [J_mean, J_cov] = fisher_analytic_gauss(obj, stim)
            %
            % stim - StimulusEnsemble object or numeric vector of stimulus values
            %
            % With f(s) = integrationTime * meanR(s), covariance Q(f) (see Q) and
            % Q' = dQ/ds (see dQds):
            %   J_mean(s) = f'(s)' * Q^-1 * f'(s)
            %   J_cov(s)  = 0.5 * trace(Q' * Q^-1 * Q' * Q^-1)
            % Both are 1 x nStim rows; the total FI is J_mean + J_cov.
            % Truncation of samples (obj.truncate) is not accounted for here.
            %
            % pseries@gatsby.ucl.ac.uk 2/20/2005
            % s.yarrow@ed.ac.uk 2008-2011
            f = obj.integrationTime .* obj.meanR(stim);
            f_prime = obj.integrationTime .* obj.dMeanR(stim);
            % d/ds log f
            g0 = f_prime ./ f;
            if isa(stim, 'StimulusEnsemble')
                nStim = stim.n;
            else
                nStim = length(stim);
            end
            fCell = squeeze(mat2cell(f, obj.popSize, ones(nStim, 1)));
            f_primeCell = squeeze(mat2cell(f_prime, obj.popSize, ones(nStim, 1)));
            g0Cell = squeeze(mat2cell(g0, obj.popSize, ones(nStim, 1)));
            % Covariance, its inverse and its stimulus derivative
            QCell1 = obj.Q(fCell);
            Q_inv = cellfun(@inv, QCell1, 'UniformOutput', false);
            Q_prime = cellfun(@(fv, gv) obj.dQds(fv, gv), fCell, g0Cell, 'UniformOutput', false);
            if obj.popSize > 1
                J_mean = cellfun(@(fp, qi) fp' * qi * fp, f_primeCell, Q_inv);
            else
                J_mean = cellfun(@(fp, q) fp.^2 / q, f_primeCell, QCell1);
            end
            J_cov = cellfun(@(qp, qi) 0.5 * trace(qp * qi * qp * qi), Q_prime, Q_inv);
        end
        function qp = dQds(obj, f, g)
            % DQDS Stimulus derivative of one covariance matrix from Q.
            % f - popSize x 1 mean response f(s); g - popSize x 1 f'(s) ./ f(s)
            %
            % Q = B.^exponent with B = add.*R + a.*R.*(f*f').^alpha, so
            %   dQ/ds = exponent .* B.^(exponent-1) .* dB/ds
            % and, because d/ds (f_i f_j)^alpha = alpha (f_i f_j)^alpha (g_i + g_j),
            %   dB/ds = alpha .* a.*R.*(f*f').^alpha .* (g_i + g_j).
            % The coefficient a is stimulus-independent, so it adds no extra term.
            % For add = 0 and exponent = 1 this is alpha*(diag(g)*Q + Q*diag(g)).
            varTerm = obj.a .* obj.R .* (f * f').^obj.alpha;
            base = obj.add .* obj.R + varTerm;
            dBase = obj.alpha .* varTerm .* bsxfun(@plus, g, g.');
            qp = obj.exponent .* base.^(obj.exponent - 1) .* dBase;
        end
        function J = fisher_analytic_poiss(obj, stim)
            % FISHER_ANALYTIC_POISS Analytic Fisher information for independent Poisson neurons:
            %   J(s) = integrationTime * sum_i f_i'(s)^2 / f_i(s)
            % See:
            % Bethge M, Rotermund D, Pawelzik K (2002)
            % Optimal short-term population coding: when Fisher information fails.
            % Neural Comput 14:2317-2351.
            f = obj.meanR(stim);
            f_prime = obj.dMeanR(stim);
            J = obj.integrationTime .* sum(f_prime.^2 ./ f, 1);
        end
        function [J, FI] = fisher_mc(obj, method, stim, tol, maxiter)
            % FISHER_MC Monte Carlo Fisher information estimate.
            % Draws r ~ P(r|s) at every ordinate and averages the squared finite
            % difference ((log P(r|s + dTheta) - log P(r|s)) / dTheta)^2, with
            % dTheta = 0.1 * stim.width.
            % J  - sample mean (FI.mean)
            % FI - OnlineStats accumulator of the samples
            %
            % NOTE(review): stim2 = stim copies only if StimulusEnsemble is a value
            % class; for a handle class the shift would also move stim. Confirm.
            stim2 = stim;
            dTheta = 0.1 .* stim2.width;
            stim2.ensemble = stim2.ensemble + dTheta;
            % Mean responses at s and s + dTheta, obj.popSize x stim.n
            rMean1 = obj.integrationTime .* obj.meanR(stim);
            rMeanCell1 = squeeze(mat2cell(rMean1, obj.popSize, ones(stim.n, 1)));
            rMean2 = obj.integrationTime .* obj.meanR(stim2);
            rMeanCell2 = squeeze(mat2cell(rMean2, obj.popSize, ones(stim2.n, 1)));
            % Row 1: s, row 2: s + dTheta
            rMeanCellDiff = [rMeanCell1 ; rMeanCell2];
            switch obj.distribution
                case 'Gaussian'
                    QCell1 = obj.Q(rMeanCell1);
                    QCell2 = obj.Q(rMeanCell2);
                    % Lower triangular chol(Q) at s for sampling
                    cholQ = cellfun(@(q) chol(q)', QCell1, 'UniformOutput', false);
                    % Upper triangular chol(Q^-1) at s and s + dTheta for PDF evaluation
                    cholInvQCell1 = cellfun(@(q) chol(inv(q)), QCell1, 'UniformOutput', false);
                    cholInvQCell2 = cellfun(@(q) chol(inv(q)), QCell2, 'UniformOutput', false);
                    clear QCell1 QCell2
                    cholInvQCellDiff = [cholInvQCell1 ; cholInvQCell2];
                    clear cholInvQCell1 cholInvQCell2
                    % Multivariate Gaussian sample: mean + chol(Q)' * standard normal vector
                    if obj.truncate
                        fRand = @(m, c, z) max((m + c * z), 0.0); % truncate
                    else
                        fRand = @(m, c, z) m + c * z; % don't truncate
                    end
                case 'Poisson'
                    % Nothing to do here
                otherwise
                    error('Unsupported distribution: %s', obj.distribution)
            end
            iter = 0;
            cont = true;
            FI = OnlineStats(stim.n, maxiter);
            while cont
                iter = iter + 1;
                if ~mod(iter, 100)
                    fprintf('Fisher iter: %d, rel. error: %.4g\n', iter, mean(FI.runDelta))
                end
                switch method
                    case 'randMC'
                        % One response per ordinate, drawn at s
                        switch obj.distribution
                            case 'Gaussian'
                                zCell = mat2cell(randn(obj.popSize, stim.n), obj.popSize, ones(stim.n, 1));
                                rCell = cellfun(fRand, rMeanCell1, cholQ, zCell, 'UniformOutput', false);
                            case 'Poisson'
                                rCell = cellfun(@poissrnd, rMeanCell1, 'UniformOutput', false);
                        end
                    otherwise
                        error('Unsupported method: %s', method)
                end
                % log P(r|s) (row 1) and log P(r|s + dTheta) (row 2)
                switch obj.distribution
                    case 'Gaussian'
                        lpRgS = cell2mat(cellsxfun(@mvnormpdfln, rCell, rMeanCellDiff, cholInvQCellDiff, {'inv'}));
                    case 'Poisson'
                        lpRgS = cell2mat(cellsxfun(@(x, l) sum(poisspdfln(x, l)), rCell, rMeanCellDiff));
                end
                % (d/ds log P(r|s))^2 by forward difference
                dlpRgS = diff(lpRgS, 1, 1) ./ dTheta;
                FI.appendSample(dlpRgS .^ 2);
                % Halting criteria (relative error, max iterations)
                delta = mean(FI.runDelta);
                cont = all([delta > tol, iter < maxiter]);
                % Minimum iteration count so the SEM estimate is sensible
                cont = cont | iter < 10;
            end
            fprintf('FI completed iter: %d\n', iter)
            FI.trim;
            J = FI.mean;
        end
        function pSR = fpSR(obj, s, r, stim)
            % FPSR Joint density p(s, r) for a fixed response r at stimulus values s.
            % s    - vector of stimulus values (e.g. the abscissae passed by quad)
            % r    - obj.popSize x 1 response vector
            % stim - StimulusEnsemble supplying the interpolated prior stim.pSint
            % pSR  - exp(log p(s) + log p(r|s)), one value per element of s
            nS = length(s);
            % Mean response at each s, obj.popSize x nS
            rMean = obj.integrationTime .* meanR(obj, s);
            rMeanCell = squeeze(mat2cell(rMean, obj.popSize, ones(nS, 1)));
            rCell = repmat({r}, [nS 1]);
            switch obj.distribution
                case 'Gaussian'
                    QCell = obj.Q(rMeanCell);
                    % log p(r|s) as a 1 x nS row
                    lpRgS = cell2mat(cellsxfun(@mvnormpdfln, rCell, rMeanCell', {[]}, QCell'))';
                case 'Poisson'
                    % Replicate r explicitly and sum over neurons (dim 1), so a single
                    % neuron still yields one value per s
                    lpRgS = sum(poisspdfln(repmat(r, [1 nS]), rMean), 1);
            end
            % log p(s) via linear piecewise interpolation on stimulus distribution
            lpS = log(stim.pSint(s));
            pSR = exp(lpS + lpRgS);
        end
    end
end

function checkVariabilityOpts(opts, nMin, scheme)
% CHECKVARIABILITYOPTS Error unless OPTS has at least NMIN elements for SCHEME.
if numel(opts) < nMin
    error('Neurons:badVariabilityOpts', ...
        'Variability scheme ''%s'' requires at least %d parameters', scheme, nMin)
end
end

function [issi, isur] = ssiSamples(lpRgS, lpPrior, hS, sIdx)
% SSISAMPLES One MC sample of specific information and specific surprise.
% lpRgS   - log P(r_j|s'): rows are all ordinates s', column j is the response
%           sampled at ordinate sIdx(j)
% lpPrior - log P(s') as a column
% hS      - stimulus entropy H(S) in bits
% issi    - H(S) - H(S|r_j) in bits, one value per column
% isur    - log2( P(r_j|s_sIdx(j)) / P(r_j) ), one value per column
lpRS = bsxfun(@plus, lpRgS, lpPrior);             % log P(r,s')
lpR = logsumexp(lpRS, 1);                         % log P(r)
lpSgR = bsxfun(@minus, lpRS, lpR);                % log P(s'|r)
hSgR = -sum(exp(lpSgR) .* (lpSgR ./ log(2)), 1);  % H(S|r) in bits
issi = hS - hSgR;
isur = (diag(lpRgS(sIdx, :))' - lpR) ./ log(2);
end

function ssif = ssiFromFisher(fi, dS, lpS, hS)
% SSIFROMFISHER SSI_Fisher for an estimator Gaussian around s with SD fi.^-0.5.
% fi  - 1 x nS Fisher information per stimulus ordinate
% dS  - nS x nS estimate-minus-stimulus differences (rows sHat, columns S)
% lpS - nS x nS log p(S), identical down each column
% hS  - stimulus entropy H(S) in bits
% NOTE(review): p(sHat|S) is a density evaluated on the grid and is not
% normalised over sHat before the final sum; confirm this is intended.
nS = size(dS, 1);
sigmaMat = repmat(fi .^ -0.5, [nS 1]);
% log p(sHat|S)
lpsHat_s = cellfun(@mvnormpdfln, num2cell(dS), repmat({0}, [nS nS]), num2cell(sigmaMat));
lpsHats = lpsHat_s + lpS;                       % log p(sHat,S)
lpsHat = logsumexp(lpsHats, 2);                 % log p(sHat)
lps_sHat = lpsHats - repmat(lpsHat, [1 nS]);    % log p(S|sHat)
Hs_sHat = -sum(exp(lps_sHat) .* lps_sHat ./ log(2), 2);  % H(S|sHat)
isp = hS - Hs_sHat;                             % specific information Isp(sHat)
% Average Isp(sHat) over p(sHat|S). Isp can be negative, so it is weighted
% directly instead of being added in the log domain (log would go complex).
ssif = sum(exp(lpsHat_s) .* repmat(isp, [1 nS]), 1);
end

function txt = matrixText(m)
% MATRIXTEXT Render a numeric matrix as text, one matrix row per line ('[]' if empty).
if isempty(m)
    txt = '[]';
else
    fmt = [repmat('%.4g ', 1, size(m, 2) - 1) '%.4g\n'];
    txt = sprintf(fmt, m.');
    txt = txt(1:end-1); % drop the final newline; the caller's format adds one
end
end

function txt = scalarText(x)
% SCALARTEXT Format a coefficient to three decimals, or 'n/a' when it is empty.
if isempty(x)
    txt = 'n/a';
else
    txt = sprintf('%.3f', x);
end
end