function [traj, infStates] = tapas_rw_binary_dual(r, p, varargin)
% Calculates the trajectories of v under the Rescorla-Wagner learning model for dual updates.
%
% Two option values are tracked. On every non-ignored trial the chosen option (column y) is
% updated toward the outcome with learning rate al. The unchosen option (column 3-y) is updated
% toward the complementary outcome with learning rate ka*al.
%
% This function can be called in two ways:
%
% (1) tapas_rw_binary_dual(r, p)
%
%     where r is the structure generated by tapas_fitModel and p is the parameter vector in native space;
%
% (2) tapas_rw_binary_dual(r, ptrans, 'trans')
%
%     where r is the structure generated by tapas_fitModel, ptrans is the parameter vector in
%     transformed space, and 'trans' is a flag indicating this.
%
% Parameters (native space):
%     p(1:2)  initial values v_0 of options 1 and 2
%     p(3)    learning rate al for the chosen option
%     p(4)    factor ka scaling al for the unchosen option
%
% Fields of r used here:
%     r.u(:,1)  outcomes; 1 and 0 are handled, any other value yields NaN prediction errors
%               (and hence NaN values) for that trial
%     r.y(:,1)  choices; used as column indices, so they must be 1 or 2 on every
%               non-ignored trial
%     r.ign     trial indices (numbered like the rows of r.u / r.y) to be skipped
%
% Outputs (N = number of trials):
%     traj.v     N x 2 values after each trial's update
%     traj.vhat  N x 2 values before each trial's update (predictions)
%     traj.da    N x 2 prediction errors (0 on ignored trials)
%     infStates  N x 1 x 2 array holding vhat for the observation model
%
% --------------------------------------------------------------------------------------------------
% Copyright (C) 2012-2013 Christoph Mathys, TNU, UZH & ETHZ
%
% This file is part of the HGF toolbox, which is released under the terms of the GNU General Public
% Licence (GPL), version 3. You can redistribute it and/or modify it under the terms of the GPL
% (either version 3 or, at your option, any later version). For further details, see the file
% COPYING or <http://www.gnu.org/licenses/>.

% Transform parameters back to their native space if needed
if ~isempty(varargin) && strcmp(varargin{1},'trans')
    p = tapas_rw_binary_dual_transp(r, p);
end

% Unpack parameters
v_0 = p(1:2);
al  = p(3);
ka  = p(4);

% Add dummy "zeroth" trial
u = [0; r.u(:,1)];
y = [0; r.y(:,1)];
n = length(u);

% Initialize updated quantities: value and prediction error
v  = NaN(n,2);
da = NaN(n,2);

% Prior
v(1,:) = v_0;

% Pass through value update loop
for k = 2:n
    % Because of the dummy zeroth trial, row k of u/y/v/da belongs to original
    % trial k-1, which is the numbering used by r.ign.
    if ismember(k-1, r.ign)
        da(k,:) = [0, 0];
        v(k,:)  = v(k-1,:);
        continue;
    end

    ch = y(k);     % chosen option
    un = 3 - ch;   % unchosen option

    %%%%%%%%%%%%%%%%%%%%%%
    % Effect of input u(k)
    %%%%%%%%%%%%%%%%%%%%%%

    % Prediction errors: the chosen option moves toward u(k), the unchosen toward 1-u(k).
    % For any u(k) other than 0 or 1 both stay NaN.
    if u(k) == 1
        da(k,ch) = 1 - v(k-1,ch);
        da(k,un) = 0 - v(k-1,un);
    elseif u(k) == 0
        da(k,ch) = 0 - v(k-1,ch);
        da(k,un) = 1 - v(k-1,un);
    end

    % Value updates: full learning rate for the chosen option, scaled by ka for the other
    v(k,ch) = v(k-1,ch) +    al*da(k,ch);
    v(k,un) = v(k-1,un) + ka*al*da(k,un);
end

% Predicted value (value before each trial's update)
vhat = v;
vhat(end,:) = [];

% Remove representation priors
v(1,:)  = [];
da(1,:) = [];

% Create result data structure
traj = struct;

traj.v     = v;
traj.vhat  = vhat;
traj.da    = da;

% Create array needed by observation model: (n-1) x 1 x 2, one slice per option
infStates = reshape(vhat, n-1, 1, 2);

return;