function [traj, infStates] = tapas_hmm(r, p, varargin)
% Runs the scaled forward recursion of a hidden Markov model (HMM) over the outcome
% sequence r.u(:,1) and returns the normalized (filtered) state probabilities.
%
% This function can be called in two ways:
%
% (1) tapas_hmm(r, p)
%
% where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) tapas_hmm(r, ptrans, 'trans')
%
% where r is the structure generated by fitModel, ptrans is the parameter vector in
% transformed space, and 'trans' is a flag indicating this.
%
% Fields of r read here:
%   r.c_prc.n_states  - number of hidden states d
%   r.c_prc.B         - state-outcome contingency matrix: B(i,j) is the probability of
%                       outcome i given state j; every column must sum to 1
%   r.u(:,1)          - observed outcomes, used as row indices into B
%   r.ign             - indices of trials whose outcome is skipped (the previous beliefs
%                       are carried over); only trials 2..n are checked against it
%
% Layout of the native-space parameter vector p (d = number of states):
%   p(1:d-1)    - prior probabilities of states 1..d-1; state d gets 1 minus their sum
%   p(d:d^2-1)  - first d-1 columns of the d-by-d transition matrix A in column-major order,
%                 where A(i,j) is the probability of moving from state i to state j; the
%                 last column is filled in so that every row of A sums to 1
%
% Outputs:
%   traj.alpr     - n-by-d; row k holds the normalized forward probabilities of the states
%                   after outcomes 1..k
%   traj.alprhat  - n-by-d; row 1 is the state prior, row k (k > 1) is traj.alpr(k-1,:)
%   infStates     - identical to traj.alpr (the matrix passed to the observation model)
%
% Errors:
%   tapas:hgf:hmm:IllegB           - a column of B does not sum to 1
%   tapas:hgf:hmm:IllegStatePrior  - the implied prior of the last state is negative
%   tapas:hgf:hmm:IllegTransMat    - the implied last column of A has a negative entry
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1}, 'trans')
    % NOTE(review): assumes the toolbox's transformation helper carries the same tapas_
    % prefix as this function (tapas_hmm_transp) - confirm against the toolbox file list.
    p = tapas_hmm_transp(r, p);
end

% Fixed configuration elements
%
% Number of hidden states
d = r.c_prc.n_states;
% Outcome probabilities, given states (rows: outcomes, columns: states)
B = r.c_prc.B;
% Each column of B must sum to 1. Compare with a tolerance rather than exact equality so
% that matrices that sum to 1 only up to floating-point rounding are accepted. The test is
% written as ~(... <= tol) so that NaN sums are still rejected.
tol = sqrt(eps);
if any(~(abs(sum(B, 1) - 1) <= tol))
    error('tapas:hgf:hmm:IllegB', 'Illegal state-outcome contingency matrix B');
end

% Unpack estimated parameters
%
% Prior probabilities of states. Force a row so that a column-vector p also works.
ppired = reshape(p(1:d-1), 1, []);
ppilast = 1-sum(ppired);
if ppilast < 0
    error('tapas:hgf:hmm:IllegStatePrior', 'Illegal state prior.');
end
ppi = [ppired, ppilast];

% Transition matrix (d*d). Only the first d-1 columns are free parameters; the last
% column is whatever makes each row sum to 1.
Ared = reshape(p(d:d^2-1), d, d-1);
Alastcolumn = 1-sum(Ared,2);
% Every entry of the implied last column must be non-negative
if any(Alastcolumn < 0)
    error('tapas:hgf:hmm:IllegTransMat', 'Illegal transition matrix.');
end
A = [Ared, Alastcolumn];

% Input and number of trials
u = r.u(:,1);
n = length(u);

% Normalized forward probabilities (alpha-prime), one row per trial
alpr = NaN(n,d);

% First trial: prior weighted by the likelihood of the first outcome, then normalized.
% r.ign is not consulted for trial 1.
altmp = ppi.*B(u(1),:);
alpr(1,:) = altmp./sum(altmp);

% Remaining trials: propagate the previous beliefs through A, weight by the likelihood
% of the observed outcome, and renormalize. Ignored trials keep the previous row.
for k = 2:n
    if ismember(k, r.ign)
        alpr(k,:) = alpr(k-1,:);
    else
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        altmp = B(u(k),:).*(alpr(k-1,:)*A);
        alpr(k,:) = altmp./sum(altmp);
    end
end

% Predicted states: the prior for trial 1, the previous trial's alpha-prime afterwards
alprhat = [ppi; alpr(1:end-1,:)];

% Create result data structure
traj = struct;
traj.alpr = alpr;
traj.alprhat = alprhat;

% Create matrix needed by observation model
infStates = traj.alpr;

return;