function t1 = joinTrees(t1, t2, alpha)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% t12 = joinTrees(t1, t2, alpha)
%   Create a new KD-tree with t1 and t2 as the children of the root.
%   The t1 subtree receives weight alpha; the t2 subtree has weight 1-alpha.
%
%   Inputs:
%     t1, t2 - trees to merge; they must agree in dimension (field D) and
%              in kernel type (field type)
%     alpha  - optional weight given to t1 (default 0.5)
%
%   Output:
%     the merged tree, returned in (a modified copy of) t1
%
%   Merged column layout (1-based) of centers/ranges/weights/means and of
%   each bandwidth section:
%     [ root | t1 internal nodes | t2 internal nodes | zero pad | ...
%       t1 leaves | t2 leaves ]
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Copyright (C) 2003 Alexander Ihler; distributable under GPL -- see README.txt

% Validate the argument count before anything else: every check below
% dereferences t2, so testing nargin later could never report this error.
if(nargin < 2 || nargin > 3)
  error('wrong number of arguments');
end
if(nargin == 2)
  alpha = 0.5;
end

if(t1.D ~= t2.D)
  error('Input trees have different dimensionality!');
end
if(isfield(t1, 'type') ~= isfield(t2, 'type'))
  error('Cant merge BallTree with BallTreeDensity!');
end
if(t1.type ~= t2.type)
  error('Trees must have the same type of kernel!');
end

% Column indices (1-based) of the internal nodes and leaves of each input.
% Each input has a zero column between its nodes and its leaves; the merged
% tree keeps only one such column, so both are left off here.
% A single-point tree has no internal nodes and its leaf is column 1.
if(t1.N == 1)
  t1nodes = [];
  t1leaves = 1;
else
  t1nodes = 1:t1.N-1;
  t1leaves = t1.N+1:2*t1.N;
end
if(t2.N == 1)
  t2nodes = [];
  t2leaves = 1;
else
  t2nodes = 1:t2.N-1;
  t2leaves = t2.N+1:2*t2.N;
end

% Column (1-based) in the MERGED layout holding each input's old root.
% For a single-point input the "root" is its only leaf.
if(t1.N ~= 1)
  t1root = 2;
else
  t1root = 3 + length(t2nodes);
end
if(t2.N ~= 1)
  t2root = 2 + length(t1nodes);
else
  t2root = 3 + length(t1nodes) + length(t1leaves);
end

% From here on t1.N is the merged point count; t1N keeps the original.
t1N = t1.N;
t1.N = t1.N + t2.N;

% Os fills the new root column (overwritten below) and the pad column.
Os = zeros(t1.D, 1);
t1.centers = [Os, t1.centers(:,t1nodes), t2.centers(:,t2nodes), Os, ...
              t1.centers(:,t1leaves), t2.centers(:,t2leaves)];
t1.ranges  = [Os, t1.ranges(:,t1nodes), t2.ranges(:,t2nodes), Os, ...
              t1.ranges(:,t1leaves), t2.ranges(:,t2leaves)];
t1.weights = [0, alpha*t1.weights(t1nodes), (1-alpha)*t2.weights(t2nodes), 0, ...
              alpha*t1.weights(t1leaves), (1-alpha)*t2.weights(t2leaves)];
t1.means   = [Os, t1.means(:,t1nodes), t2.means(:,t2nodes), Os, ...
              t1.means(:,t1leaves), t2.means(:,t2leaves)];

% take care of variable BWs
% A tree is treated as variable-bandwidth when its bandwidth array has more
% than 2*N columns.  The merged tree is also treated as variable-bandwidth
% when the first bandwidth columns of the two inputs differ.
t1varBWs = size(t1.bandwidth, 2) > 2*t1N;
t2varBWs = size(t2.bandwidth, 2) > 2*t2.N;
varBWs = t1varBWs || t2varBWs || ...
         sum(t1.bandwidth(:,1) ~= t2.bandwidth(:,1)) > 0;
if(varBWs)
  % Three consecutive sections, each using the merged layout.  Sections two
  % and three read from column offsets 2*N and 4*N of a variable-bandwidth
  % input; for a fixed-bandwidth input the offset is zero, so its single
  % section is reused.  (The max/min update further down suggests these are
  % the per-node max and min bandwidths.)
  t1.bandwidth = [Os, t1.bandwidth(:,t1nodes), t2.bandwidth(:,t2nodes), ...
                  Os, t1.bandwidth(:,t1leaves), t2.bandwidth(:,t2leaves), ...
                  Os, t1.bandwidth(:,t1nodes+2*t1N*t1varBWs), ...
                  t2.bandwidth(:,t2nodes+2*t2.N*t2varBWs), Os, ...
                  t1.bandwidth(:,t1leaves+2*t1N*t1varBWs), ...
                  t2.bandwidth(:,t2leaves+2*t2.N*t2varBWs), Os, ...
                  t1.bandwidth(:,t1nodes+4*t1N*t1varBWs), ...
                  t2.bandwidth(:,t2nodes+4*t2.N*t2varBWs), Os, ...
                  t1.bandwidth(:,t1leaves+4*t1N*t1varBWs), ...
                  t2.bandwidth(:,t2leaves+4*t2.N*t2varBWs)];
else
  t1.bandwidth = [Os, t1.bandwidth(:,t1nodes), t2.bandwidth(:,t2nodes), ...
                  Os, t1.bandwidth(:,t1leaves), t2.bandwidth(:,t2leaves)];
end

%%%% Do stuff from calcStats because calcStats is protected.  Don't
%    change calc stats or this won't work.

% Root bounding box: the smallest axis-aligned box containing both child
% boxes, stored as center and half-width.
ax = max(t1.centers(:,t1root)+t1.ranges(:,t1root), ...
         t1.centers(:,t2root)+t1.ranges(:,t2root));
in = min(t1.centers(:,t1root)-t1.ranges(:,t1root), ...
         t1.centers(:,t2root)-t1.ranges(:,t2root));
t1.centers(:,1) = (ax+in)/2;
t1.ranges(:,1)  = (ax-in)/2;

% calculate weight: root weight is the sum of its children's; then rescale
% every weight so that the leaf weights sum to one.
t1.weights(1) = t1.weights(t1root) + t1.weights(t2root);
W = sum(t1.weights(t1.N+1:2*t1.N));
t1.weights = t1.weights / W;   % normalize
t1w = t1.weights(t1root) / t1.weights(1);   % relative weight of t1 subtree
t2w = t1.weights(t2root) / t1.weights(1);   % relative weight of t2 subtree

% calculate mean: weighted average of the two child means
t1.means(:,1) = t1w*t1.means(:,t1root) + t2w*t1.means(:,t2root);

% calculate bandwidth of the root from the children, per kernel type
type = getType(t1);
if(strcmp(type, 'Gaussian'))
  t1.bandwidth(:,1) = t1w * (t1.bandwidth(:,t1root) + t1.means(:,t1root).^2) ...
                    + t2w * (t1.bandwidth(:,t2root) + t1.means(:,t2root).^2) - ...
                      t1.means(:,1).^2;
elseif(strcmp(type, 'Epanetchnikov'))
  t1.bandwidth(:,1) = sqrt(.5 * (t1w * (2*t1.bandwidth(:,t1root).^2 ...
                                        + t1.means(:,t1root).^2) ...
                               + t2w * (2*t1.bandwidth(:,t2root).^2 ...
                                        + t1.means(:,t2root).^2) ...
                               - t1.means(:,1).^2));
elseif(strcmp(type, 'Laplacian'))
  t1.bandwidth(:,1) = sqrt(5 * (t1w * (.2*t1.bandwidth(:,t1root).^2 ...
                                       + t1.means(:,t1root).^2) ...
                              + t2w * (.2*t1.bandwidth(:,t2root).^2 ...
                                       + t1.means(:,t2root).^2) ...
                              - t1.means(:,1).^2));
else
  error(['unknown kernel type: ' type])
end

% take care of max and min BWs for variable bandwidths
if(varBWs)
  t1.bandwidth(:,1+2*t1.N) = max(t1.bandwidth(:,t1root+2*t1.N), t1.bandwidth(:,t2root+2*t1.N));
  t1.bandwidth(:,1+4*t1.N) = min(t1.bandwidth(:,t1root+4*t1.N), t1.bandwidth(:,t2root+4*t1.N));
end

% Offsets added to the stored (zero-based) indices of each input so they
% point at the same entries in the merged layout:
%   t1n / t2n - shift for references to internal nodes of t1 / t2
%   t1l / t2l - shift for references to leaves of t1 / t2
t1n = 1;
t2n = 1 + length(t1nodes);
t1l = 1 + length(t2nodes);
t2l = 1 + length(t1nodes) + length(t1leaves);

% arrays are zero indexed
% The new root spans all leaves: zero-based positions t1.N .. 2*t1.N-1.
t1.lower = [t1.N, addUints(t1.lower(t1nodes),t1l), ...
            addUints(t2.lower(t2nodes),t2l), 0, ...
            addUints(t1.lower(t1leaves),t1l), ...
            addUints(t2.lower(t2leaves),t2l)];
t1.upper = [2*t1.N-1, addUints(t1.upper(t1nodes),t1l), ...
            addUints(t2.upper(t2nodes),t2l), 0, ...
            addUints(t1.upper(t1leaves),t1l), ...
            addUints(t2.upper(t2leaves),t2l)];

% Child pointers: a value below the input's N refers to an internal node,
% otherwise to a leaf, and is shifted accordingly.  Right-child values of
% 4e9 or more get no shift (presumably a "no child" sentinel -- confirm
% against the tree-building code).
t1leftch  = addUints(t1.leftch,  (t1.leftch  <  t1N)  * t1n + ...
                                 (t1.leftch  >= t1N)  * t1l);
t2leftch  = addUints(t2.leftch,  (t2.leftch  <  t2.N) * t2n + ...
                                 (t2.leftch  >= t2.N) * t2l);
t1rightch = addUints(t1.rightch, (t1.rightch <  t1N)  * t1n + ...
                                 (t1.rightch >= t1N)  .* (t1.rightch < 4e9) * t1l);
t2rightch = addUints(t2.rightch, (t2.rightch <  t2.N) * t2n + ...
                                 (t2.rightch >= t2.N) .* (t2.rightch < 4e9) * t2l);

% The new root's children are the two old roots (converted to zero-based).
t1.leftch  = [d2uint(t1root), t1leftch(t1nodes), t2leftch(t2nodes), ...
              0, t1leftch(t1leaves), t2leftch(t2leaves)];
t1.rightch = [d2uint(t2root), t1rightch(t1nodes), t2rightch(t2nodes), ...
              0, t1rightch(t1leaves), t2rightch(t2leaves)];

% perm entries of t2's leaves are shifted by the number of t1 leaves.
t1.perm = [0, t1.perm(t1nodes), t2.perm(t2nodes), 0, t1.perm(t1leaves), ...
           addUints(t2.perm(t2leaves),length(t1leaves))];
function c = addUints(a, b)
% c = addUints(a, b)
%   Add a and b in double precision and return the sum converted to uint32.
total = double(a) + double(b);
c = uint32(total);
function u = d2uint(d)
% u = d2uint(d)
%   Convert a one-based index d to a zero-based uint32 index (d - 1).
zeroBased = d - 1;
u = uint32(zeroBased);