function [traj, infStates] = tapas_hhmm(r, p, varargin)
% Estimates a hierarchical hidden Markov model (HHMM)
%
% This function is documented to be callable in two ways:
%
% (1) tapas_hhmm(r, p)
%
%     where r is the structure generated by fitModel and p is the parameter vector in native space;
%
% (2) tapas_hhmm(r, ptrans, 'trans')
%
%     where r is the structure generated by fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% NOTE(review): the model tree (pstruct.N) is only produced by hhmm_transp, which is called
% exclusively for form (2). Form (1) therefore cannot work as written and now raises an explicit
% error (previously: an undefined-variable error on pstruct). Confirm how the tree should be
% obtained for native-space parameters.
%
% Fields of r read here:
%     r.c_prc.n_outcomes   number of possible outcomes
%     r.u(:,1)             inputs, used as row indices into the outcome contingency matrix
%     r.ign                indices of trials to be ignored
%
% Output:
%     traj.alpr      one row per trial, one column per production node: normalized
%                    alpha-prime after each trial
%     traj.alprhat   same size; row k holds the values in place before trial k
%                    (first row: prior probabilities of the production nodes)
%     infStates      identical to traj.alpr
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    [p, pstruct] = hhmm_transp(r, p);
else
    % Without the transformation step there is no pstruct and hence no model tree
    error('tapas:hgf:hhmm:NoModelTree', 'Model tree unavailable: parameters have to be supplied in transformed space together with the ''trans'' flag.');
end

% Fixed configuration elements
%
% Number of possible outcomes
m = r.c_prc.n_outcomes;

% Tolerance for checking that probabilities sum to one. Exact comparison with 1
% rejects valid models because of floating-point rounding.
tol = 1e-10;

% Get model tree (N is for node)
N = pstruct.N;

% Check constraints
if ~isempty(N{1}.V)
    error('tapas:hgf:hhmm:IllegEntryProbRoot', 'Illegal entry probability for root node.');
end

for id = 1:length(N)
    % Force a row vector so that the for-loops below visit one child at a time
    children = N{id}.children(:)';

    % Exactly one of A (transition matrix) and B (outcome contingencies) has to be set
    if isempty(N{id}.A) == isempty(N{id}.B)
        error('tapas:hgf:hhmm:IllegCombOfAB', 'Illegal combination of A and B for node no. %d.', id);
    end
    
    if length(children) ~= size(N{id}.A,2)
        error('tapas:hgf:hhmm:NumOfChildIncons', 'Number of children inconsistent with A for node no. %d.', id);
    end
    
    % Row sums may fall short of 1; the remainder is used below as the
    % probability of leaving the level (see aend)
    if ~isempty(N{id}.A) && any(sum(N{id}.A,2) > 1+tol)
        error('tapas:hgf:hhmm:IllegA', 'Illegal transition matrix A for node no. %d: row sums have to be less than or equal to 1.', id);
    end
    
    if ~isempty(N{id}.A)
        for cid = children
            cidx = find(children==cid);
            if ~isempty(N{cid}.children) && N{id}.A(cidx,cidx) ~= 0
                error('tapas:hgf:hhmm:IllegASelf', 'Illegal transition matrix A for node no. %d: only production nodes may have self-transitions.', id);
            end
        end
    end
    
    % Written as ~(... <= tol) so that NaN entries are still rejected
    if ~isempty(N{id}.B) && ~(abs(sum(N{id}.B(:)) - 1) <= tol)
        error('tapas:hgf:hhmm:IllegB', 'Illegal outcome contingency vector B for node no. %d.', id);
    end
    
    if ~isempty(children)
        Vsum = 0;
        for cid = children
            Vsum = Vsum + N{cid}.V;
        end
        
        if ~(abs(Vsum - 1) <= tol)
            error('tapas:hgf:hhmm:IllegV', 'Illegal vertical transition probabilities V from node no. %d.', id);
        end
    end
end

% Flatten the tree into one large transition matrix
% ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

% Find production nodes (nodes without children)
pn = [];
for id = 1:length(N)
    if isempty(N{id}.children)
        pn = [pn, id];
    end
end
flatDim = length(pn);

% Initialize outcome contingency matrix (outcomes x production nodes)
Bflat = NaN(m,flatDim);

% Fill Bflat
for i = 1:flatDim
    Bflat(:,i) = N{pn(i)}.B';
end

% Check Bflat
if any(~isfinite(Bflat(:)))
    error('tapas:hgf:hhmm:NoOutConMat', 'Could not construct outcome contingency matrix.');
end

% Initialize flattened transition matrix
Aflat = NaN(flatDim);

% Fill Aflat
for i = 1:flatDim
    for j = 1:flatDim
        if N{pn(i)}.parent == N{pn(j)}.parent
            % If the production nodes are siblings, read out
            % their parent's A matrix
            pid = N{pn(i)}.parent;
            idx = find(N{pid}.children==pn(i));
            jdx = find(N{pid}.children==pn(j));
            Aflat(i,j) = N{pid}.A(idx,jdx);
        else
            % Otherwise, determine their lowest common ancestor
            % and use that to calculate the transition probability
            caid = ca(N,pn(i),pn(j));
            tp = 1;

            nid  = pn(i);
            pid  = N{nid}.parent;
            nidx = find(N{pid}.children==nid);

            % Move up to one node below lowest common ancestor from start node pn(i)
            while pid ~= caid
                % Probability mass not assigned to any sibling in this row of A
                aend = 1-sum(N{pid}.A(nidx,:));
                tp = tp*aend;

                nid = pid;
                pid = N{nid}.parent;
                nidx = find(N{pid}.children==nid);
            end

            % Do the horizontal transition to the ancestral line of target node pn(j)
            ancj = anc(N,pn(j));
            while ancj(1) ~= caid
                ancj(1) = [];
            end
            ancj(1) = [];

            caidxi = nidx;
            caidxj = find(N{caid}.children==ancj(1));
            tp = tp*N{caid}.A(caidxi,caidxj);
            
            % Go down to the target node pn(j), multiplying the vertical
            % transition probabilities along the way
            ancj(1) = [];
            while ~isempty(ancj)
                tp = tp*N{ancj(1)}.V;
                ancj(1) = [];
            end

            Aflat(i,j) = tp;
        end
    end
end

% Check Aflat
if any(~isfinite(Aflat(:)))
    error('tapas:hgf:hhmm:NoFlatTransMat', 'Could not flatten transition matrix.');
end

% Calculate prior probabilities of production nodes: product of the vertical
% transition probabilities V on the path from below the root to the node
pnp = NaN(1,flatDim);
for i = 1:flatDim
    anci = anc(N,pn(i));
    anci(1) = [];
    prior = 1;
    while ~isempty(anci)
        prior = prior*N{anci(1)}.V;
        anci(1) = [];
    end
    pnp(i) = prior;
end

% Check pnp
if ~(abs(sum(pnp) - 1) <= tol)
    error('tapas:hgf:hhmm:NoPriorProdNodes', 'Cannot calculate prior probabilities of production nodes.');
end

% Input and number of trials
u = r.u(:,1);
n = length(u);

% Initialize alpha-prime
alpr = NaN(n,flatDim);

% alpr(1,:)
if ismember(1, r.ign)
    % An ignored first trial carries no information: stay at the prior,
    % in line with how ignored trials are handled in the loop below
    alpr(1,:) = pnp;
else
    altmp = pnp.*Bflat(u(1),:);
    alpr(1,:) = altmp./sum(altmp);
end

% Pass through alpha-prime update loop
for k = 2:1:n
    if not(ismember(k, r.ign))
        
        %%%%%%%%%%%%%%%%%%%%%%
        % Effect of input u(k)
        %%%%%%%%%%%%%%%%%%%%%%
        
        altmp = Bflat(u(k),:).*(alpr(k-1,:)*Aflat);
        alpr(k,:) = altmp./sum(altmp);
    else
        % Ignored trial: carry the previous values forward unchanged
        alpr(k,:) = alpr(k-1,:);
    end
end

% Predicted states: alpr shifted down by one trial, starting from the prior
alprhat = [pnp; alpr];
alprhat(end,:) = [];

% Create result data structure
traj = struct;

traj.alpr    = alpr;
traj.alprhat = alprhat;

% Create matrix needed by observation model
infStates = traj.alpr;

end % function hhmm

% ----------------------------------------------------------------------------------------
% Find lowest common ancestor of nodes
function ca = ca(N,ida,idb)
    % Returns the id of the lowest common ancestor of nodes ida and idb in tree N.
    %
    % The ancestral lines returned by anc run from the topmost node down to the
    % node itself. The result is the last entry of their common prefix, so if one
    % node lies on the other's ancestral line (or ida == idb), that node itself is
    % returned. NaN is returned if the two lines share no node.

    % Find ancestors of ida and idb
    anca = anc(N,ida);
    ancb = anc(N,idb);
    
    % Walk down both lines in parallel for as long as they agree. Bounding the
    % scan by the shorter line avoids indexing past the end of either one.
    ca = NaN;
    for k = 1:min(length(anca), length(ancb))
        if anca(k) ~= ancb(k)
            break
        end
        ca = anca(k);
    end
end

% ----------------------------------------------------------------------------------------
% Find ancestors of a node
function anc = anc(N,id)
    % Returns the ancestral line of node id as a row vector, ordered from the
    % topmost node (the one whose parent field is empty) down to id itself.

    % Collect the chain bottom-up: start at id and follow the parent links
    chain   = id;
    current = N{id}.parent;
    while ~isempty(current)
        chain   = [chain, current];
        current = N{current}.parent;
    end

    % Reverse so that the topmost node comes first
    anc = chain(end:-1:1);
end