function [traj, infStates] = tapas_kf(r, p, varargin)
% TAPAS_KF  Scalar Kalman filter perceptual model.
%
% Usage:
%
%   [traj, infStates] = tapas_kf(r, p)
%       r is the structure generated by tapas_fitModel and p is the parameter
%       vector in native space.
%
%   [traj, infStates] = tapas_kf(r, ptrans, 'trans')
%       ptrans is the parameter vector in transformed space; the 'trans' flag
%       makes this function map it to native space with tapas_kf_transp first.
%
% Fields of r read here:
%   r.u    - inputs; only the first column is used
%   r.ign  - indices of trials to skip (no update is made on these trials)
%
% Parameter vector (native space):
%   p(1) - initial gain
%   p(2) - initial hidden state mean
%   p(3) - exp(p(3)) is used as the process variance
%   p(4) - observation precision
%
% Outputs:
%   traj      - struct with one entry per trial in each of the fields
%                 g     : Kalman gain
%                 muhat : predicted value (the mean before the trial's update)
%                 mu    : hidden state mean after the update
%                 da    : prediction error (0 on ignored trials)
%   infStates - [traj.muhat, traj.mu], the matrix needed by the observation model
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2016 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Map parameters back to native space when the caller passed the 'trans' flag
inTransformedSpace = ~isempty(varargin) && strcmp(varargin{1}, 'trans');
if inTransformedSpace
    p = tapas_kf_transp(r, p);
end

% Unpack the parameter vector
gainPrior = p(1);       % Initial gain
meanPrior = p(2);       % Initial hidden state mean
procVar   = exp(p(3));  % Process variance
obsPrec   = p(4);       % Observation precision

% The same product enters numerator and denominator of every gain update
drive = obsPrec*procVar;

% Prepend a dummy "zeroth" trial so that index 1 can hold the priors
inputs = [0; r.u(:,1)];
nSteps = length(inputs);

% Pre-allocate trajectories
predErr   = NaN(nSteps,1);  % Prediction error
gain      = NaN(nSteps,1);  % Kalman gain
stateMean = NaN(nSteps,1);  % Hidden state mean

% Slot 1 carries the priors
gain(1)      = gainPrior;
stateMean(1) = meanPrior;

% Trial-by-trial updates; step k corresponds to trial k-1
for k = 2:nSteps
    if ismember(k-1, r.ign)
        % Ignored trial: no prediction error, gain and mean are carried forward
        predErr(k)   = 0;
        gain(k)      = gain(k-1);
        stateMean(k) = stateMean(k-1);
    else
        % Prediction error of input k relative to the previous mean
        predErr(k) = inputs(k) - stateMean(k-1);

        % Gain update
        boosted = gain(k-1) + drive;
        gain(k) = boosted/(boosted + 1);

        % Hidden state mean update
        stateMean(k) = stateMean(k-1) + gain(k)*predErr(k);
    end
end

% The prediction for each trial is the mean that was in place before its update
predMean = stateMean;
predMean(end) = [];

% Drop the prior slot from the updated quantities
predErr(1)   = [];
gain(1)      = [];
stateMean(1) = [];

% Assemble the trajectory structure
traj = struct('g', gain, 'muhat', predMean, 'mu', stateMean, 'da', predErr);

% Matrix (here: two columns) needed by the observation model
infStates = [traj.muhat, traj.mu];

return;