function [traj, infStates] = tapas_hmm(r, p, varargin)
% Estimates a hidden Markov model (HMM)
%
% This function can be called in two ways:
%
% (1) tapas_hmm(r, p)
%
%     where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) tapas_hmm(r, ptrans, 'trans')
%
%     where r is the structure generated by fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% With d = r.c_prc.n_states, the native-space parameter vector p is laid out as
%
%     p(1:d-1)     prior probabilities of the first d-1 hidden states (the last one is
%                  implied by normalization)
%     p(d:d^2-1)   the first d-1 columns of the d x d transition matrix A, column-major
%                  (the last column is implied by every row of A summing to 1)
%
% A(i,j) is used as the probability of moving from state i to state j, and
% r.c_prc.B(o,s) as the probability of outcome o given state s. Inputs r.u(:,1) are used
% as row indices into B, i.e. as outcome indices.
%
% Outputs:
%     traj.alpr     n x d normalized forward ("alpha-prime") state probabilities after
%                   each input has been processed
%     traj.alprhat  n x d state probabilities before each input: row 1 is the prior,
%                   row k > 1 is traj.alpr(k-1,:)
%     infStates     identical to traj.alpr
%
% Trials listed in r.ign are not updated: the state probabilities are carried over
% unchanged from the previous trial (from the prior, for trial 1).
%
% Errors (identifiers):
%     tapas:hgf:hmm:IllegB            B has the wrong size, negative entries, or a column
%                                     that does not sum to 1
%     tapas:hgf:hmm:IllegStatePrior   the state prior is not a probability vector
%     tapas:hgf:hmm:IllegTransMat     a row of the transition matrix is not a probability
%                                     vector
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
% NOTE(review): other toolbox functions carry a 'tapas_' prefix; confirm that hmm_transp
% (rather than tapas_hmm_transp) is the transformation function available on the path.
if ~isempty(varargin) && strcmp(varargin{1}, 'trans')
    p = hmm_transp(r, p);
end

% Work with p as a row vector so that column-vector input is accepted as well
p = reshape(p, 1, []);

% Tolerance for the probability checks below, so that harmless floating-point rounding
% in sums of probabilities is not mistaken for an illegal value
tol = 1e-10;

% Fixed configuration elements
%
% Number of hidden states
d = r.c_prc.n_states;
% Number of possible outcomes
m = r.c_prc.n_outcomes;
% Outcome probabilities, given states: B(o,s) = p(outcome o | state s), so B is m x d and
% every column must be a probability distribution over the outcomes. sum(B,1) is used
% explicitly so a single-row B is still summed per column.
B = r.c_prc.B;
if ~isequal(size(B), [m d]) || any(B(:) < 0) || any(abs(sum(B, 1) - 1) > tol)
    error('tapas:hgf:hmm:IllegB', 'Illegal state-outcome contingency matrix B');
end

% Unpack estimated parameters
%
% Prior probabilities of states (the last one is implied by normalization)
ppired = p(1:d-1);
ppilast = 1 - sum(ppired);
if any(ppired < 0) || ppilast < -tol
    error('tapas:hgf:hmm:IllegStatePrior', 'Illegal state prior.');
end
ppi = [ppired, ppilast];

% Transition matrix (d x d): row i is the distribution of the next state given state i,
% so the last column is implied by each row summing to 1
Ared = reshape(p(d:d^2-1), d, d-1);
Alastcolumn = 1 - sum(Ared, 2);
% Element-wise check: every explicit entry and every implied last-column entry must be
% non-negative
if any(Ared(:) < 0) || any(Alastcolumn < -tol)
    error('tapas:hgf:hmm:IllegTransMat', 'Illegal transition matrix.');
end
A = [Ared, Alastcolumn];

% Input and number of trials
u = r.u(:,1);
n = length(u);

% Trials whose input is ignored (computed once instead of per iteration)
ignored = ismember((1:n)', r.ign);

% Initialize alpha-prime
alpr = NaN(n, d);

% Pass through alpha-prime update loop
for k = 1:n
    if ignored(k)
        % Ignored input: carry the belief over without any update. On trial 1 this is the
        % prior, so an ignored (possibly NaN) first input is never used as an index into B.
        if k == 1
            alpr(k,:) = ppi;
        else
            alpr(k,:) = alpr(k-1,:);
        end
        continue
    end

    % Belief about the current state before seeing u(k): the prior on trial 1, otherwise
    % the previous belief propagated through the transition matrix
    if k == 1
        prior_k = ppi;
    else
        prior_k = alpr(k-1,:) * A;
    end

    %%%%%%%%%%%%%%%%%%%%%%
    % Effect of input u(k)
    %%%%%%%%%%%%%%%%%%%%%%

    % Weight by the likelihood of the observed outcome and renormalize. If u(k) has zero
    % probability under every state with non-zero belief, the sum is 0 and the row is NaN.
    altmp = B(u(k),:) .* prior_k;
    llh = sum(altmp);
    alpr(k,:) = altmp ./ llh;
end

% State probabilities before each input: the prior, followed by alpr shifted down by one
% trial
% NOTE(review): these are not propagated through A (i.e. not alpr(k-1,:)*A); confirm that
% this is what consumers of traj.alprhat expect.
alprhat = [ppi; alpr];
alprhat(end,:) = [];

% Create result data structure
traj = struct;

traj.alpr    = alpr;
traj.alprhat = alprhat;

% Create matrix needed by observation model
infStates = traj.alpr;

return;