function [traj, infStates] = tapas_hgf_jget(r, p, varargin)
% Calculates the trajectories of the agent's representations under the HGF model of the jumping
% Gaussian estimation task (JGET)
%
% This function can be called in two ways:
%
% (1) tapas_hgf_jget(r, p)
%
% where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) tapas_hgf_jget(r, ptrans, 'trans')
%
% where r is the structure generated by fitModel, ptrans is the parameter vector in
% transformed space, and 'trans' is a flag indicating this.
%
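% This function is normally not called directly; tapas_fitModel constructs r and
% invokes it internally. A minimal usage sketch (the observation model config and
% the variable names y and u are assumptions, not fixed by this file):
%
%     est = tapas_fitModel(y, u, 'tapas_hgf_jget_config', 'tapas_gaussian_obs_config');
%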
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
p = tapas_hgf_jget_transp(r, p);
end
% Number of levels
try
l = r.c_prc.n_levels;
catch
l = length(p)/8;
if l ~= floor(l)
error('tapas:hgf:UndetNumLevels', 'Cannot determine number of levels');
end
end
% Unpack parameters
mux_0 = p(1:l);
sax_0 = p(l+1:2*l);
mua_0 = p(2*l+1:3*l);
saa_0 = p(3*l+1:4*l);
kau = p(4*l+1);
kax = p(4*l+2:5*l);
kaa = p(5*l+1:6*l-1);
omu = p(6*l);
omx = p(6*l+1:7*l-1);
oma = p(7*l+1:8*l-1);
thx = exp(p(7*l));
tha = exp(p(8*l));
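% Note: the parameter vector stores the theta values in log form (hence the exp()
% above); the kappas and omegas enter the variance terms exp(kappa*mu + omega)
% directly in the update equations below.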
% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
% Number of trials (including prior)
n = length(u);
% Construct time axis
if r.c_prc.irregular_intervals
if size(u,2) > 1
t = [0; r.u(:,end)];
else
error('tapas:hgf:InputSingleColumn', 'Input matrix must contain more than one column if irregular_intervals is set to true.');
end
else
t = ones(n,1);
end
% Initialize updated quantities
% Representations
mux = NaN(n,l);
pix = NaN(n,l);
mua = NaN(n,l);
pia = NaN(n,l);
% Other quantities
muuhat = NaN(n,1);
piuhat = NaN(n,1);
muxhat = NaN(n,l);
pixhat = NaN(n,l);
muahat = NaN(n,l);
piahat = NaN(n,l);
daux = NaN(n,1);
daua = NaN(n,1);
wx = NaN(n,l-1);
dax = NaN(n,l-1);
wa = NaN(n,l-1);
daa = NaN(n,l-1);
% Representation priors
% Note: first entries of the other quantities remain
% NaN because they are undefined and are thrown away
% at the end; their presence simply leads to consistent
% trial indices.
mux(1,:) = mux_0;
pix(1,:) = 1./sax_0;
mua(1,:) = mua_0;
pia(1,:) = 1./saa_0;
% Representation update loop
% Pass through trials
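% Two hierarchies are updated in parallel on every trial: the x-levels track the
% mean of the input and the alpha-levels track its noise. Trials listed in r.ign
% are ignored; their representations are simply carried over from the previous trial.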
for k = 2:1:n
if not(ismember(k-1, r.ign))
%%%%%%%%%%%%%%%%%%%%%%
% Effect of input u(k)
%%%%%%%%%%%%%%%%%%%%%%
% Input level
% ~~~~~~~~~~~
% Prediction (same as prediction of x_1, see below)
muuhat(k) = mux(k-1,1);
% Precision of prediction
piuhat(k) = 1/exp(kau *mua(k-1,1) +omu);
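% (the observation noise variance is exp(kau*mua(1) +omu), so the alpha-states
% set the log-variance of the input)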
% Mean prediction error
daux(k) = u(k) -muuhat(k);
% 1st level
% ~~~~~~~~~
% Predictions
muxhat(k,1) = mux(k-1,1);
muahat(k,1) = mua(k-1,1);
% Precisions of predictions
pixhat(k,1) = 1/(1/pix(k-1,1) +t(k) *exp(kax(1) *mux(k-1,2) +omx(1)));
piahat(k,1) = 1/(1/pia(k-1,1) +t(k) *exp(kaa(1) *mua(k-1,2) +oma(1)));
% x-updates
pix(k,1) = pixhat(k,1) +piuhat(k);
mux(k,1) = muxhat(k,1) +piuhat(k)/pix(k,1) *daux(k);
% Prediction error of input precision
daua(k) = (1/pix(k,1) +(mux(k,1) -u(k))^2) *piuhat(k) -1;
% alpha-updates
pia(k,1) = piahat(k,1) +1/2 *kau^2 *(1 +daua(k));
mua(k,1) = muahat(k,1) +1/2 *1/pia(k,1) *kau *daua(k);
% Volatility prediction errors
dax(k,1) = (1/pix(k,1) +(mux(k,1) -muxhat(k,1))^2) *pixhat(k,1) -1;
daa(k,1) = (1/pia(k,1) +(mua(k,1) -muahat(k,1))^2) *piahat(k,1) -1;
if l > 2
% Pass through higher levels
% ~~~~~~~~~~~~~~~~~~~~~~~~~~
for j = 2:l-1
% Predictions
muxhat(k,j) = mux(k-1,j);
muahat(k,j) = mua(k-1,j);
% Precisions of predictions
pixhat(k,j) = 1/(1/pix(k-1,j) +t(k) *exp(kax(j) *mux(k-1,j+1) +omx(j)));
piahat(k,j) = 1/(1/pia(k-1,j) +t(k) *exp(kaa(j) *mua(k-1,j+1) +oma(j)));
% Weighting factors
wx(k,j-1) = t(k) *exp(kax(j-1) *mux(k-1,j) +omx(j-1)) *pixhat(k,j-1);
wa(k,j-1) = t(k) *exp(kaa(j-1) *mua(k-1,j) +oma(j-1)) *piahat(k,j-1);
% Updates
pix(k,j) = pixhat(k,j) +1/2 *kax(j-1)^2 *wx(k,j-1) *(wx(k,j-1) +(2 *wx(k,j-1) -1) *dax(k,j-1));
pia(k,j) = piahat(k,j) +1/2 *kaa(j-1)^2 *wa(k,j-1) *(wa(k,j-1) +(2 *wa(k,j-1) -1) *daa(k,j-1));
if pix(k,j) <= 0 || pia(k,j) <= 0
error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
end
mux(k,j) = muxhat(k,j) +1/2 *1/pix(k,j) *kax(j-1) *wx(k,j-1) *dax(k,j-1);
mua(k,j) = muahat(k,j) +1/2 *1/pia(k,j) *kaa(j-1) *wa(k,j-1) *daa(k,j-1);
% Volatility prediction errors
dax(k,j) = (1/pix(k,j) +(mux(k,j) -muxhat(k,j))^2) *pixhat(k,j) -1;
daa(k,j) = (1/pia(k,j) +(mua(k,j) -muahat(k,j))^2) *piahat(k,j) -1;
end
end
% Last level
% ~~~~~~~~~~
% Predictions
muxhat(k,l) = mux(k-1,l);
muahat(k,l) = mua(k-1,l);
% Precision of prediction
pixhat(k,l) = 1/(1/pix(k-1,l) +t(k) *thx);
piahat(k,l) = 1/(1/pia(k-1,l) +t(k) *tha);
% Weighting factor
wx(k,l-1) = t(k) *exp(kax(l-1) *mux(k-1,l) +omx(l-1)) *pixhat(k,l-1);
wa(k,l-1) = t(k) *exp(kaa(l-1) *mua(k-1,l) +oma(l-1)) *piahat(k,l-1);
% Updates
pix(k,l) = pixhat(k,l) +1/2 *kax(l-1)^2 *wx(k,l-1) *(wx(k,l-1) +(2 *wx(k,l-1) -1) *dax(k,l-1));
pia(k,l) = piahat(k,l) +1/2 *kaa(l-1)^2 *wa(k,l-1) *(wa(k,l-1) +(2 *wa(k,l-1) -1) *daa(k,l-1));
if pix(k,l) <= 0 || pia(k,l) <= 0
error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
end
mux(k,l) = muxhat(k,l) +1/2 *1/pix(k,l) *kax(l-1) *wx(k,l-1) *dax(k,l-1);
mua(k,l) = muahat(k,l) +1/2 *1/pia(k,l) *kaa(l-1) *wa(k,l-1) *daa(k,l-1);
% Volatility prediction error
dax(k,l) = (1/pix(k,l) +(mux(k,l) -muxhat(k,l))^2) *pixhat(k,l) -1;
daa(k,l) = (1/pia(k,l) +(mua(k,l) -muahat(k,l))^2) *piahat(k,l) -1;
else
mux(k,:) = mux(k-1,:);
mua(k,:) = mua(k-1,:);
pix(k,:) = pix(k-1,:);
pia(k,:) = pia(k-1,:);
muuhat(k) = muuhat(k-1);
piuhat(k) = piuhat(k-1);
muxhat(k,:) = muxhat(k-1,:);
muahat(k,:) = muahat(k-1,:);
pixhat(k,:) = pixhat(k-1,:);
piahat(k,:) = piahat(k-1,:);
daux(k) = daux(k-1);
daua(k) = daua(k-1);
wx(k,:) = wx(k-1,:);
wa(k,:) = wa(k-1,:);
dax(k,:) = dax(k-1,:);
daa(k,:) = daa(k-1,:);
end
end
% Remove representation priors
mux(1,:) = [];
mua(1,:) = [];
pix(1,:) = [];
pia(1,:) = [];
% Check validity of trajectories
if any(isnan(mux(:))) || any(isnan(pix(:))) || any(isnan(mua(:))) || any(isnan(pia(:)))
error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
else
% Check for implausible jumps in trajectories
dmux = diff(mux);
dmua = diff(mua);
dpix = diff(pix);
dpia = diff(pia);
rmdmux = repmat(sqrt(mean(dmux.^2)),length(dmux),1);
rmdmua = repmat(sqrt(mean(dmua.^2)),length(dmua),1);
rmdpix = repmat(sqrt(mean(dpix.^2)),length(dpix),1);
rmdpia = repmat(sqrt(mean(dpia.^2)),length(dpia),1);
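% rmd* hold the root-mean-square step size of each trajectory, replicated across
% trials so that each step can be compared against jumpTol times that scale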
jumpTol = 256;
if any(abs(dmux(:)) > jumpTol*rmdmux(:)) || any(abs(dmua(:)) > jumpTol*rmdmua(:)) || any(abs(dpix(:)) > jumpTol*rmdpix(:)) || any(abs(dpia(:)) > jumpTol*rmdpia(:))
error('tapas:hgf:VarApproxInvalid', 'Variational approximation invalid. Parameters are in a region where model assumptions are violated.');
end
end
% Remove other dummy initial values
muuhat(1) = [];
piuhat(1) = [];
muxhat(1,:) = [];
muahat(1,:) = [];
pixhat(1,:) = [];
piahat(1,:) = [];
wx(1,:) = [];
wa(1,:) = [];
daux(1) = [];
daua(1) = [];
dax(1,:) = [];
daa(1,:) = [];
% Extract learning rates
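% The learning rate is the factor weighting the prediction error in the corresponding
% mean update: piuhat/pix(:,1) at the input level (kau/(2*pia(:,1)) for alpha), and
% kappa(j-1)/2 *w(:,j-1)/pi(:,j) at the higher levels.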
lrx = NaN(n-1,l);
lra = NaN(n-1,l);
lrx(:,1) = piuhat./pix(:,1);
lrx(:,2:end) = bsxfun(@times, kax/2, wx)./pix(:,2:end); % broadcast kax across trials
lra(:,1) = 1/2 *kau./pia(:,1);
lra(:,2:end) = bsxfun(@times, kaa/2, wa)./pia(:,2:end); % broadcast kaa across trials
% Create result data structure
traj = struct;
traj.mux = mux;
traj.mua = mua;
traj.sax = 1./pix;
traj.saa = 1./pia;
traj.muuhat = muuhat;
traj.muxhat = muxhat;
traj.muahat = muahat;
traj.sauhat = 1./piuhat;
traj.saxhat = 1./pixhat;
traj.saahat = 1./piahat;
traj.wx = wx;
traj.wa = wa;
traj.daux = daux;
traj.daua = daua;
traj.dax = dax;
traj.daa = daa;
traj.lrx = lrx;
traj.lra = lra;
% Create matrices for use by the observation model
infStates = NaN(n-1,1,10);
infStates(:,1,1) = traj.muuhat;
infStates(:,1,2) = traj.sauhat;
infStates(:,1,3) = traj.muxhat(:,1);
infStates(:,1,4) = traj.saxhat(:,1);
infStates(:,1,5) = traj.muahat(:,1);
infStates(:,1,6) = traj.saahat(:,1);
infStates(:,1,7) = traj.mux(:,1);
infStates(:,1,8) = traj.sax(:,1);
infStates(:,1,9) = traj.mua(:,1);
infStates(:,1,10) = traj.saa(:,1);
return;