function [traj, infStates] = tapas_hgf_categorical(r, p, varargin)
% TAPAS_HGF_CATEGORICAL  Belief trajectories of a 3-level HGF for categorical inputs.
%
% On each trial the agent observes one of r.c_prc.n_outcomes categories.
% Every category has its own first- and second-level state, and all
% categories share a single third-level state.
%
% USAGE:
%     [traj, infStates] = tapas_hgf_categorical(r, p)
%     [traj, infStates] = tapas_hgf_categorical(r, p, 'trans')
%
% INPUT:
%     r         Struct with at least these fields:
%                 r.c_prc.n_outcomes  number of categories (no)
%                 r.u                 inputs; column 1 holds the observed
%                                     category index (an integer in 1..no)
%                 r.ign               trial indices to ignore; on these
%                                     trials all states are carried over
%                                     unchanged and r.u is not read
%     p         Parameter vector, laid out as
%                 p(1:no)        mu2_0  prior means of the level-2 states
%                 p(no+1:2*no)   sa2_0  prior variances of the level-2 states
%                 p(2*no+1)      mu3_0  prior mean of the level-3 state
%                 p(2*no+2)      sa3_0  prior variance of the level-3 state
%                 p(2*no+3)      ka     scales mu3 in the level-2 step
%                                       variance exp(ka*mu3 + om)
%                 p(2*no+4)      om     constant in exp(ka*mu3 + om)
%                 p(2*no+5)      th     step variance added at level 3
%     varargin  If the first element is 'trans', p is first passed through
%               tapas_hgf_categorical_transp(r, p).
%
% OUTPUT (nt = number of trials = size(r.u,1)):
%     traj.mu, traj.sa        nt x 3 x no posterior means / variances
%     traj.muhat, traj.sahat  nt x 3 x no prediction means / variances
%                             (level 3 is stored in outcome slice 1 only;
%                             the other level-3 slices are NaN)
%     traj.v    nt x 1   level-2 step variance exp(ka*mu3(k-1) + om)
%     traj.w    nt x no  v .* level-2 prediction precision
%     traj.da   nt x 2 x no  prediction errors: level 1 (mu1 - mu1hat)
%                            and level 2 (volatility prediction error)
%     traj.ud   traj.mu - traj.muhat
%     traj.psi  nt x 3 x no  precision weights (1/pi2, pi2hat/pi3)
%     traj.epsi nt x 3 x no  precision-weighted prediction errors
%     traj.wt   nt x 3 x no  full weights on prediction errors
%     infStates nt x 3 x no x 4: (:,:,:,1) muhat, (:,:,:,2) sahat,
%                                (:,:,:,3) mu,    (:,:,:,4) sa
%
% ERRORS:
%     tapas:hgf:ConfigRequired  r has no c_prc field
%     tapas:hgf:InvalidInput    a non-ignored input is not a category index
%     tapas:hgf:NegPostPrec     level-3 posterior precision became <= 0

% The configuration supplies n_outcomes below
if ~isfield(r,'c_prc')
    error('tapas:hgf:ConfigRequired', 'Configuration required: before calling tapas_hgf_categorical, tapas_hgf_categorical_config has to be called.');
end

% Transform parameters back to their native space if requested
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_hgf_categorical_transp(r, p);
end

% Number of outcome categories
no = r.c_prc.n_outcomes;

% Unpack parameters
mu2_0 = p(1:no);
sa2_0 = p(no+1:2*no);
mu3_0 = p(2*no+1);
sa3_0 = p(2*no+2);
ka    = p(2*no+3);
om    = p(2*no+4);
th    = p(2*no+5);

% Prepend a dummy "zeroth" trial whose row holds the priors. It is
% removed after the update loop, so row k of every output is trial k.
u = [1; r.u(:,1)];
n = length(u);

% Allocate state representations (n rows including the prior row)
mu1    = NaN(n,no);
pi1    = NaN(n,no);
mu2    = NaN(n,no);
pi2    = NaN(n,no);
mu3    = NaN(n,1);
pi3    = NaN(n,1);
mu1hat = NaN(n,no);
pi1hat = NaN(n,no);
mu2hat = NaN(n,no);
pi2hat = NaN(n,no);
mu3hat = NaN(n,1);
pi3hat = NaN(n,1);
v2     = NaN(n,1);
w2     = NaN(n,no);
da1    = NaN(n,no);
da2    = NaN(n,no);

% Priors in row 1. Quantities not set here stay NaN in row 1, which is
% discarded after the loop.
mu1(1,:) = tapas_sgm(mu2_0, 1);
pi1(1,:) = 1./(mu1(1,:).*(1-mu1(1,:)));
mu2(1,:) = mu2_0;
pi2(1,:) = 1./sa2_0;
mu3(1)   = mu3_0;
pi3(1)   = 1/sa3_0;

% Trial-by-trial updates
for k = 2:1:n
    if not(ismember(k-1, r.ign))
        % The input is used as a column index below. Reject anything that
        % is not a category index 1..no. Otherwise mu1(k,u(k)) = 1 either
        % fails with an uninformative indexing error, or silently widens
        % mu1 before a dimension mismatch a line later.
        if ~isfinite(u(k)) || u(k) ~= fix(u(k)) || u(k) < 1 || u(k) > no
            error('tapas:hgf:InvalidInput', 'Input on trial %d is %g, but it must be an integer category index between 1 and %d.', k-1, u(k), no);
        end

        % Level 1: predictions from the previous level-2 means
        mu1hat(k,:) = tapas_sgm(mu2(k-1,:), 1);
        pi1hat(k,:) = 1./(mu1hat(k,:).*(1 -mu1hat(k,:)));

        % Level 1: the observed category is set to 1, all others to 0
        mu1(k,:) = 0;
        mu1(k,u(k)) = 1;
        pi1(k,:) = Inf;

        % Level-1 prediction error
        da1(k,:) = mu1(k,:) -mu1hat(k,:);

        % Level 2: prediction and precision-weighted update
        mu2hat(k,:) = mu2(k-1,:);
        pi2hat(k,:) = 1./(1./pi2(k-1,:) +exp(ka *mu3(k-1) +om));
        pi2(k,:) = pi2hat(k,:) +1./pi1hat(k,:);
        mu2(k,:) = mu2hat(k,:) +1./pi2(k,:) .*da1(k,:);

        % Level-2 (volatility) prediction error
        da2(k,:) = (1./pi2(k,:) +(mu2(k,:) -mu2hat(k,:)).^2) .*pi2hat(k,:) -1;

        % Level 3: prediction
        mu3hat(k) = mu3(k-1);
        pi3hat(k) = 1/(1/pi3(k-1) +th);

        % Level-2 step variance and its weighting by the level-2
        % prediction precision
        v2(k) = exp(ka *mu3(k-1) +om);
        w2(k,:) = v2(k) *pi2hat(k,:);

        % Level 3: update, summing the contributions of all outcomes
        pi3(k) = pi3hat(k) +sum(1/2 *ka^2 *w2(k,:) .*(w2(k,:) +(2 *w2(k,:) -1) .*da2(k,:)));
        if pi3(k) <= 0
            error('tapas:hgf:NegPostPrec', 'Negative posterior precision. Parameters are in a region where model assumptions are violated.');
        end
        mu3(k) = mu3hat(k) +sum(1/2 *1/pi3(k) *ka *w2(k,:) .*da2(k,:));
    else
        % Ignored trial: carry every quantity over from the previous row
        mu1(k,:)    = mu1(k-1,:);
        pi1(k,:)    = pi1(k-1,:);
        mu2(k,:)    = mu2(k-1,:);
        pi2(k,:)    = pi2(k-1,:);
        mu3(k)      = mu3(k-1);
        pi3(k)      = pi3(k-1);
        mu1hat(k,:) = mu1hat(k-1,:);
        pi1hat(k,:) = pi1hat(k-1,:);
        mu2hat(k,:) = mu2hat(k-1,:);
        pi2hat(k,:) = pi2hat(k-1,:);
        mu3hat(k)   = mu3hat(k-1);
        pi3hat(k)   = pi3hat(k-1);
        v2(k)       = v2(k-1);
        w2(k,:)     = w2(k-1,:);
        da1(k,:)    = da1(k-1,:);
        da2(k,:)    = da2(k-1,:);
    end
end

% Remove the dummy zeroth (prior) row
mu1(1,:)    = [];
pi1(1,:)    = [];
mu2(1,:)    = [];
pi2(1,:)    = [];
mu3(1)      = [];
pi3(1)      = [];
mu1hat(1,:) = [];
pi1hat(1,:) = [];
mu2hat(1,:) = [];
pi2hat(1,:) = [];
mu3hat(1)   = [];
pi3hat(1)   = [];
v2(1)       = [];
w2(1,:)     = [];
da1(1,:)    = [];
da2(1,:)    = [];

% Assemble trajectories (level 3 lives in outcome slice 1 only)
traj = struct;

traj.mu = NaN(n-1,3,no);
traj.mu(:,1,:) = mu1;
traj.mu(:,2,:) = mu2;
traj.mu(:,3,1) = mu3;

traj.sa = NaN(n-1,3,no);
traj.sa(:,1,:) = 1./pi1;
traj.sa(:,2,:) = 1./pi2;
traj.sa(:,3,1) = 1./pi3;

traj.muhat = NaN(n-1,3,no);
traj.muhat(:,1,:) = mu1hat;
traj.muhat(:,2,:) = mu2hat;
traj.muhat(:,3,1) = mu3hat;

traj.sahat = NaN(n-1,3,no);
traj.sahat(:,1,:) = 1./pi1hat;
traj.sahat(:,2,:) = 1./pi2hat;
traj.sahat(:,3,1) = 1./pi3hat;

traj.v = v2;
traj.w = w2;

traj.da = NaN(n-1,2,no);
traj.da(:,1,:) = da1;
traj.da(:,2,:) = da2;

% Updates: posterior minus prediction
traj.ud = traj.mu -traj.muhat;

% Precision weights, computed for all trials at once. The (:) forces
% column vectors: when there are no trials, deleting row 1 of a 1x1
% array leaves a 1x0 array that would not broadcast against n-1 x no.
psi2 = 1./pi2;
psi3 = bsxfun(@rdivide, pi2hat, pi3(:));
psi = NaN(n-1,3,no);
psi(:,2,:) = psi2;
psi(:,3,:) = psi3;
traj.psi = psi;

% Precision-weighted prediction errors
epsi = NaN(n-1,3,no);
epsi(:,2,:) = psi2 .*da1;
epsi(:,3,:) = psi3 .*da2;
traj.epsi = epsi;

% Full weights on prediction errors. Level 1 is the implied learning rate
% (sgm(mu2) - mu1hat)./da1; this assumes tapas_sgm acts element-wise,
% as its use on both row and column vectors above suggests.
lr1 = (tapas_sgm(mu2, 1) -mu1hat)./da1;
wt = NaN(n-1,3,no);
wt(:,1,:) = lr1;
wt(:,2,:) = psi2;
wt(:,3,:) = 1/2 *ka *bsxfun(@times, v2(:), psi3);
traj.wt = wt;

% Inferred states handed to the observation model
infStates = NaN(n-1,3,no,4);
infStates(:,:,:,1) = traj.muhat;
infStates(:,:,:,2) = traj.sahat;
infStates(:,:,:,3) = traj.mu;
infStates(:,:,:,4) = traj.sa;