// $Id: KDTree.cpp,v 1.8 2010/03/09 17:38:00 samn Exp $
//kdtree class from bipython 1.43 - with some enhancements by sam neymotin
// #include "stdafx.h"
#include "KDTree.h"
// Global counter; it is only defined here and never touched in this file.
// NOTE(review): presumably incremented by MySqrt to count square-root calls - confirm in KDTree.h.
int iSQRTCalls = 0;
namespace NSKDTree
{
// Squared Euclidean distance between two points.
//
// coord1, coord2: arrays holding at least `dim` coordinates each.
// dim:            number of coordinates to compare.
// Returns the sum of the squared per-coordinate differences (no square root
// is taken); the result is 0 when dim <= 0.
float KDTREE_dist(float *coord1, float *coord2, int dim)
{
    float total=0;
    const float *p=coord1;
    const float *q=coord2;
    for (int k=0; k<dim; ++k, ++p, ++q)
    {
        const float delta=*p-*q;
        total+=delta*delta;
    }
    return total;
}
// DataPoint
// Coordinate axis used by operator< and operator== when comparing DataPoints;
// _build_tree and neighbor_simple_search_sq set it before sorting.
int DataPoint::current_dim=0;
// Number of coordinates per point; overwritten by the KDTree constructor,
// set_data and the search entry points.
int DataPoint::dim=3;
// Orders two DataPoints by their coordinate along DataPoint::current_dim only.
// Returns non-zero when self's coordinate is strictly smaller than other's.
int operator<(const DataPoint &self, const DataPoint &other)
{
    const int axis=DataPoint::current_dim;
    const float lhs=self._coord[axis];
    const float rhs=other._coord[axis];
    return lhs<rhs;
}
// Compares two DataPoints by their coordinate along DataPoint::current_dim
// only; the other coordinates and the index are ignored.
// Returns non-zero when the two coordinates compare equal.
int operator==(const DataPoint &self, const DataPoint &other)
{
    const int axis=DataPoint::current_dim;
    const float lhs=self._coord[axis];
    const float rhs=other._coord[axis];
    return lhs==rhs;
}
void DataPoint::set_data(long int index, float *coord)
{
_index=index;
_coord=coord;
}
float *DataPoint::get_coord(void)
{
return _coord;
}
long int DataPoint::get_index(void)
{
return _index;
}
// Node
// Builds a node without children.
// cut_value, cut_dim: splitting coordinate and the axis it applies to.
// start, end:         index range of this node's points in _data_point_list
//                     (the loops in this file iterate it as [start, end)).
Node::Node(float cut_value, int cut_dim, long int start, long int end)
{
    // range of points covered by this node
    _start=start;
    _end=end;
    // splitting plane
    _cut_dim=cut_dim;
    _cut_value=cut_value;
    // children are attached later through set_left_node / set_right_node
    _left=NULL;
    _right=NULL;
}
// Deletes both children; each child's destructor in turn releases its own
// subtree. Deleting a NULL child is a no-op.
Node::~Node()
{
    delete this->_left;
    delete this->_right;
}
// Attaches `node` as the left child. The destructor deletes it, so the Node
// owns it from here on. A previously attached child is not deleted.
void Node::set_left_node(Node *node)
{
    this->_left=node;
}
// Attaches `node` as the right child. The destructor deletes it, so the Node
// owns it from here on. A previously attached child is not deleted.
void Node::set_right_node(Node *node)
{
    this->_right=node;
}
// Returns the left child, or NULL when none is attached.
Node *Node::get_left_node(void)
{
    return this->_left;
}
// Returns the right child, or NULL when none is attached.
Node *Node::get_right_node(void)
{
    return this->_right;
}
long int Node::get_start(void)
{
return _start;
}
long int Node::get_end(void)
{
return _end;
}
float Node::get_cut_value(void)
{
return _cut_value;
}
int Node::get_cut_dim(void)
{
return _cut_dim;
}
// Returns 1 when the node has neither a left nor a right child, 0 otherwise.
int Node::is_leaf(void)
{
    const bool has_child=(_left!=NULL || _right!=NULL);
    return has_child ? 0 : 1;
}
// Returns 0 when the node covers exactly one point and 1 otherwise.
//
// Node ranges are half-open: KDTree::_build_tree creates nodes with
// [offset_begin, offset_end) and every loop over a node runs
// start <= i < end. A single-point node therefore satisfies _end==_start+1.
// The former test (_start==_end+1) could never hold for such a range, so
// every node was reported as a bucket.
int Node::is_bucket(void)
{
    const long int nr_points=_end-_start;
    if (nr_points==1)
    {
        // Node contains a single point
        return 0;
    }
    // Node contains several points (or none)
    return 1;
}
// Region
// Number of coordinates per Region corner; overwritten by the KDTree
// constructor, set_data and the search entry points.
int Region::dim=3;
// Builds an axis-aligned box with Region::dim coordinates per corner.
// left, right: lower and upper corner; the values are copied into arrays
//              owned by this Region. If either pointer is NULL the box
//              spans [-INF, INF] on every axis.
Region::Region(float *left, float *right)
{
    const int n=Region::dim;
    _left=new float[n];
    _right=new float[n];
    const bool unbounded=(left==NULL || right==NULL);
    for (int axis=0; axis<n; ++axis)
    {
        if (unbounded)
        {
            _left[axis]=-INF;
            _right[axis]=INF;
        }
        else
        {
            _left[axis]=left[axis];
            _right[axis]=right[axis];
        }
    }
}
// Frees the two corner arrays allocated by the constructor.
Region::~Region()
{
    delete [] this->_left;
    delete [] this->_right;
}
// Clips this region to the part at or above split_coord on axis current_dim.
// Returns a newly allocated Region (the caller must delete it), or NULL when
// the whole region lies below split_coord on that axis.
Region *Region::intersect_right(float split_coord, int current_dim)
{
    const float lower=_left[current_dim];
    const float upper=_right[current_dim];
    if (split_coord<=lower)
    {
        // split point lies to the left: the region is kept whole (as a copy)
        return new Region(_left, _right);
    }
    if (!(split_coord<=upper))
    {
        // interval lies to the left of the split point: nothing remains
        return NULL;
    }
    // split point inside the interval: raise the lower bound on this axis
    vector<float> clipped_left(_left, _left+Region::dim);
    clipped_left[current_dim]=split_coord;
    return new Region(&clipped_left[0], _right);
}
// Clips this region to the part at or below split_coord on axis current_dim.
// Returns a newly allocated Region (the caller must delete it), or NULL when
// split_coord is below the region's lower bound on that axis.
Region *Region::intersect_left(float split_coord, int current_dim)
{
    const float lower=_left[current_dim];
    const float upper=_right[current_dim];
    if (split_coord<lower)
    {
        // nothing to the left
        return NULL;
    }
    if (split_coord<upper)
    {
        // split point inside the interval: lower the upper bound on this axis
        vector<float> clipped_right(_right, _right+Region::dim);
        clipped_right[current_dim]=split_coord;
        return new Region(_left, &clipped_right[0]);
    }
    // split point at or beyond the upper bound: the region is kept whole
    return new Region(_left, _right);
}
// Tests whether a point lies inside this region, bounds included.
// coord: array of at least Region::dim coordinates.
// Returns 1 when every coordinate is within [_left, _right] on its axis,
// 0 as soon as one is not.
int Region::encloses(float *coord)
{
    for (int axis=0; axis<Region::dim; ++axis)
    {
        const float value=coord[axis];
        const bool within=(value>=_left[axis] && value<=_right[axis]);
        if (!within)
        {
            return 0;
        }
    }
    return 1;
}
float *Region::get_left(void)
{
return _left;
}
float *Region::get_right(void)
{
return _right;
}
// Classifies this region against query_region.
// radius: gap tolerated on an axis before the regions count as separate.
// Returns
//   0 - on some axis the two regions are more than `radius` apart (outside),
//   2 - this region lies within query_region on every axis (inside),
//   1 - anything else (overlap).
int Region::test_intersection(Region *query_region, float radius)
{
    const float *query_lo=query_region->get_left();
    const float *query_hi=query_region->get_right();
    bool inside_everywhere=true;
    for (int axis=0; axis<Region::dim; ++axis)
    {
        const float lo=_left[axis];
        const float hi=_right[axis];
        if (lo-query_hi[axis]>radius || query_lo[axis]-hi>radius)
        {
            // separated on this axis: outside
            return 0;
        }
        if (!(hi<=query_hi[axis] && lo>=query_lo[axis]))
        {
            // sticks out of the query on this axis: at best an overlap
            inside_everywhere=false;
        }
    }
    return inside_everywhere ? 2 : 1;
}
// KDTree
// Creates an empty tree; points are supplied later through set_data.
// dim:                number of coordinates per point; also written to the
//                     static DataPoint::dim and Region::dim.
// bucket_size:        _build_tree stops splitting once a range holds at most
//                     this many points.
// delete_user_coords: when true, the coordinate array handed to set_data is
//                     released with delete [] by this object.
KDTree::KDTree(int dim, int bucket_size, bool delete_user_coords)
{
    // dimension, shared with the helper classes
    this->dim=dim;
    DataPoint::dim=dim;
    Region::dim=dim;
    // construction parameters
    _bucket_size=bucket_size;
    _delete_user_coords = delete_user_coords;
    // nothing built or queried yet
    _root=NULL;
    _coords=NULL;
    _query_region=NULL;
    _center_coord=new float[dim];
    // result bookkeeping
    _count=0;
    _neighbor_count=0;
    _max_neighbors_to_find = -1;
    _min_radius_sq = INF;
}
// Releases the tree, the current query region and the query-center buffer.
// The coordinate array from set_data is freed only when the tree was
// constructed with delete_user_coords set.
KDTree::~KDTree()
{
    delete _root;
    delete _query_region;
    delete [] _center_coord;
    if (_delete_user_coords)
    {
        delete [] _coords;
    }
}
// Recursively builds the subtree for the points in
// _data_point_list[offset_begin, offset_end).
// depth: recursion depth. At depth 0 the two offsets are ignored and the whole
//        list is used. The split axis is depth % dim.
// Returns a newly allocated node: a leaf when the range holds at most
// _bucket_size points, otherwise an inner node whose cut value is the
// coordinate of the last point of the left half after sorting.
Node *KDTree::_build_tree(long int offset_begin, long int offset_end, int depth)
{
    int axis=0;
    if (depth==0)
    {
        // top-level call: cover the complete point list
        offset_begin=0;
        offset_end=_data_point_list.size();
    }
    else
    {
        axis=depth%dim;
    }
    const long int nr_points=offset_end-offset_begin;
    if (nr_points<=_bucket_size)
    {
        // small enough: leaf node
        return new Node(-1, axis, offset_begin, offset_end);
    }
    // sort the range along the split axis (operator< reads current_dim)
    DataPoint::current_dim=axis;
    sort(_data_point_list.begin()+offset_begin, _data_point_list.begin()+offset_end);
    // the left half receives the extra point when the count is odd
    const long int left_size=nr_points/2+nr_points%2;
    const long int split=offset_begin+left_size;
    const float cut_value=_data_point_list[split-1].get_coord()[axis];
    Node *parent=new Node(cut_value, axis, offset_begin, offset_end);
    Node *left_child=_build_tree(offset_begin, split, depth+1);
    Node *right_child=_build_tree(split, offset_end, depth+1);
    parent->set_left_node(left_child);
    parent->set_right_node(right_child);
    return parent;
}
void KDTree::_add_point(long int index, float *coord)
{
DataPoint data_point;
data_point.set_data(index, coord);
// add to list of points
_data_point_list.push_back(data_point);
}
void KDTree::_set_query_region(float *left, float *right)
{
delete _query_region;
_query_region=new Region(left, right);
}
// Range search: reports every point enclosed by _query_region that lives in
// the subtree rooted at `node`.
// region: bounding region of `node`; this function takes ownership and
//         deletes it before returning.
// node:   subtree to search.
// depth:  recursion depth. At depth 0 both `region` and `node` are ignored
//         and replaced by an unbounded region and the root.
//
// A NULL node is treated as an empty subtree. This covers a search issued
// before set_data has built a tree, which previously dereferenced a NULL
// root.
void KDTree::_search(Region *region, Node *node, int depth)
{
    if (depth==0)
    {
        // start with [-INF, INF] region and the root node
        region=new Region();
        node=_root;
    }
    if (node==NULL)
    {
        // nothing to search
        delete region;
        return;
    }
    if (node->is_leaf())
    {
        // leaf: test each point of the bucket against the query box
        for (long int i=node->get_start(); i<node->get_end(); i++)
        {
            DataPoint &candidate=_data_point_list[i];
            if (_query_region->encloses(candidate.get_coord()))
            {
                _report_point(candidate.get_index(), candidate.get_coord());
            }
        }
    }
    else
    {
        const int current_dim=depth%dim;
        const float cut_value=node->get_cut_value();
        // LEFT HALF PLANE - left_region is NULL if there is no overlap
        Region *left_region=region->intersect_left(cut_value, current_dim);
        if (left_region)
        {
            // _test_region takes ownership of left_region
            _test_region(node->get_left_node(), left_region, depth);
        }
        // RIGHT HALF PLANE - right_region is NULL if there is no overlap
        Region *right_region=region->intersect_right(cut_value, current_dim);
        if (right_region)
        {
            // _test_region takes ownership of right_region
            _test_region(node->get_right_node(), right_region, depth);
        }
    }
    delete region;
}
// Decides how to continue the range search for `node`, whose bounding region
// is `region`, by classifying that region against _query_region.
// `region` is always released: directly here, or by _search when the search
// recurses.
void KDTree::_test_region(Node *node, Region *region, int depth)
{
    switch (region->test_intersection(_query_region))
    {
    case 2:
        // inside - every point of the subtree is a candidate
        _report_subtree(node);
        delete region;
        break;
    case 1:
        // overlap - recurse; _search deletes region
        _search(region, node, depth+1);
        break;
    default:
        // outside - stop here
        delete region;
        break;
    }
}
// Passes every point stored below `node` to _report_point.
void KDTree::_report_subtree(Node *node)
{
    if (!node->is_leaf())
    {
        // inner node: collect from both subtrees
        _report_subtree(node->get_left_node());
        _report_subtree(node->get_right_node());
        return;
    }
    // leaf: report the point(s) of its bucket
    const long int last=node->get_end();
    for (long int i=node->get_start(); i<last; i++)
    {
        DataPoint &point=_data_point_list[i];
        _report_point(point.get_index(), point.get_coord());
    }
}
// Adds a candidate point to the result lists when its squared distance to
// _center_coord does not exceed _radius_sq.
// The stored value is the SQUARED distance; no square root is taken.
// _max_neighbors_to_find is not consulted here.
void KDTree::_report_point(long int index, float *coord)
{
    const float dist_sq=KDTREE_dist(_center_coord, coord, KDTree::dim);
    if (!(dist_sq<=_radius_sq))
    {
        // outside the query sphere
        return;
    }
    _index_list.push_back(index);
    _radius_list.push_back(dist_sq);
    _count++;
}
void KDTree::set_data(float *coords, long int nr_points)
{
long int i;
DataPoint::dim=dim;
Region::dim=dim;
// clean up stuff from previous use
delete _root;
if(_delete_user_coords) delete [] _coords;
_index_list.clear();
_radius_list.clear();
_count=0;
// keep pointer to coords to delete it
_coords=coords;
for (i=0; i<nr_points; i++)
{
_add_point(i, coords+i*dim);
}
// build KD tree
_root=_build_tree();
}
// Recursive nearest-neighbour search below `node` for the query point coord.
// The best hit so far is kept in _index_list[0] / _radius_list[0] (squared
// distance) and _min_radius_sq; both lists must already hold one element.
// allowzero: when false, points at squared distance 0 are skipped.
void KDTree::_search_r(Node* node,float* coord,bool allowzero)
{
    if (node==NULL)
    {
        return;
    }
    if (!node->is_leaf())
    {
        const int axis=node->get_cut_dim();
        const float cut=node->get_cut_value();
        const float offset=coord[axis]-cut;
        const float plane_dist_sq=offset*offset;
        // descend first into the side of the cut that holds the query point
        const bool query_is_left=coord[axis]<cut;
        Node *near_side=query_is_left ? node->get_left_node() : node->get_right_node();
        Node *far_side=query_is_left ? node->get_right_node() : node->get_left_node();
        _search_r(near_side,coord,allowzero);
        // the other side is visited only if the cut plane is within the best distance
        if (_min_radius_sq >= plane_dist_sq)
        {
            _search_r(far_side,coord,allowzero);
        }
        return;
    }
    // leaf: scan the bucket for a closer point
    const long int last=node->get_end();
    for (long int i=node->get_start(); i<last; i++)
    {
        DataPoint &candidate=_data_point_list[i];
        const float dist_sq=KDTREE_dist(candidate.get_coord(), coord, KDTree::dim);
        if (dist_sq<_min_radius_sq && (allowzero || dist_sq>0.0f))
        {
            _index_list[0] = candidate.get_index();
            _radius_list[0] = dist_sq;
            _count = 1;
            _min_radius_sq = dist_sq;
        }
    }
}
void KDTree::search_nn(float* coord,bool allowzero)
{
_min_radius_sq = INF;
_index_list.resize(1);
_radius_list.resize(1);
_search_r(_root,coord,allowzero);
}
void KDTree::search_center_radius_sq(float *coord, float radius_sq,int iNNToFind)
{
int i;
vector<float> left(dim),right(dim);
DataPoint::dim=dim;
Region::dim=dim;
_index_list.clear();
_radius_list.clear();
_count=0;
_index_list.resize(20); _index_list.resize(0);
_radius_list.resize(20); _radius_list.resize(0);
////////////////////////////////////////
////////// added by sam
_max_neighbors_to_find = iNNToFind;
_min_radius_sq = INF;
////////////////////////////////////////
_radius=MySqrt(radius_sq);
// use of r^2 to avoid sqrt use
_radius_sq=radius_sq;
float radius=_radius;
for (i=0; i<dim; i++)
{
left[i]=coord[i]-radius;
right[i]=coord[i]+radius;
// set center of query
_center_coord[i]=coord[i];
}
// clean up!
//if(_delete_user_coords) delete [] coord; //???????????
_set_query_region(&left[0], &right[0]);
_search();
}
long int KDTree::get_count(void)
{
return _count;
}
// Copies the indices of the found points into `indices`, which must have
// room for get_count() entries. Nothing is written when the count is 0.
void KDTree::copy_indices(long *indices)
{
    for (long int k=0; k<_count; ++k)
    {
        indices[k]=_index_list[k];
    }
}
// Copies the stored distances of the found points into `radii_sq`, which must
// have room for get_count() entries. Nothing is written when the count is 0.
void KDTree::copy_radii_sq(float *radii_sq)
{
    for (long int k=0; k<_count; ++k)
    {
        radii_sq[k]=_radius_list[k];
    }
}
// Copies the indices of all neighbour pairs into `indices`. Each pair takes
// two consecutive entries, so the buffer needs 2*neighbor_get_count() slots.
// Nothing is written when no pair was found.
void KDTree::neighbor_copy_indices(long int *indices)
{
    const long int nr_entries=_neighbor_count*2;
    for (long int k=0; k<nr_entries; ++k)
    {
        indices[k]=_neighbor_index_list[k];
    }
}
// Copies the stored distance of every neighbour pair into `radii_sq`, which
// must have room for neighbor_get_count() entries. Nothing is written when no
// pair was found.
void KDTree::neighbor_copy_radii_sq(float *radii_sq)
{
    for (long int k=0; k<_neighbor_count; ++k)
    {
        radii_sq[k]=_neighbor_radius_list[k];
    }
}
long int KDTree::neighbor_get_count(void)
{
return _neighbor_count;
}
void KDTree::neighbor_search_sq(float neighbor_radius_sq)
{
Region *region;
DataPoint::dim=dim;
Region::dim=dim;
_neighbor_index_list.clear();
_neighbor_radius_list.clear();
_neighbor_index_list.resize(20); _neighbor_index_list.resize(0);
_neighbor_radius_list.resize(20); _neighbor_radius_list.resize(0);
// note the use of r^2 to avoid use of sqrt
_neighbor_radius=MySqrt(neighbor_radius_sq);
_neighbor_radius_sq=neighbor_radius_sq;
_neighbor_count=0;
// start with [-INF, INF]
region=new Region();
if (_root->is_leaf())
{
// this is a boundary condition
// bucket_size>nr of points
_search_neighbors_in_bucket(_root);
}
else
{
// "normal" situation
_neighbor_search(_root, region, 0);
}
delete region;
}
// Finds neighbour pairs below the inner node `node`: pairs inside its left
// subtree, pairs inside its right subtree, then pairs that straddle the cut.
// region: bounding region of `node`; it is read but not deleted here.
// depth:  recursion depth; the split axis is depth % dim.
void KDTree::_neighbor_search(Node *node, Region *region, int depth)
{
    const int axis=depth%dim;
    Node *lower=node->get_left_node();
    Node *upper=node->get_right_node();
    // regions of the two children
    Region *lower_region=region->intersect_left(node->get_cut_value(), axis);
    Region *upper_region=region->intersect_right(node->get_cut_value(), axis);
    // pairs inside the left half
    if (lower->is_leaf())
    {
        _search_neighbors_in_bucket(lower);
    }
    else
    {
        _neighbor_search(lower, lower_region, depth+1);
    }
    // pairs inside the right half
    if (upper->is_leaf())
    {
        _search_neighbors_in_bucket(upper);
    }
    else
    {
        _neighbor_search(upper, upper_region, depth+1);
    }
    // pairs with one point in each half
    _neighbor_search_pairs(lower, lower_region, upper, upper_region, depth+1);
    // cleanup
    delete lower_region;
    delete upper_region;
}
// Records (p1, p2) as a neighbour pair when their squared distance does not
// exceed _neighbor_radius_sq. The two indices are appended to
// _neighbor_index_list and the SQUARED distance to _neighbor_radius_list.
void KDTree::_test_neighbors(DataPoint &p1, DataPoint &p2)
{
    const float dist_sq=KDTREE_dist(p1.get_coord(), p2.get_coord(), dim);
    if (!(dist_sq<=_neighbor_radius_sq))
    {
        // too far apart
        return;
    }
    // we found a neighbor pair!
    _neighbor_index_list.push_back(p1.get_index());
    _neighbor_index_list.push_back(p2.get_index());
    _neighbor_radius_list.push_back(dist_sq);
    _neighbor_count++;
}
// Tests every unordered pair of points inside one node's range.
void KDTree::_search_neighbors_in_bucket(Node *node)
{
    const long int first=node->get_start();
    const long int last=node->get_end();
    for (long int a=first; a<last; a++)
    {
        DataPoint p1=_data_point_list[a];
        // start at a+1 so that each pair is tested once
        for (long int b=a+1; b<last; b++)
        {
            DataPoint p2=_data_point_list[b];
            _test_neighbors(p1, p2);
        }
    }
}
// Tests every pair made of one point from node1's range and one point from
// node2's range.
void KDTree::_search_neighbors_between_buckets(Node *node1, Node *node2)
{
    const long int last1=node1->get_end();
    const long int first2=node2->get_start();
    const long int last2=node2->get_end();
    for (long int a=node1->get_start(); a<last1; a++)
    {
        DataPoint p1=_data_point_list[a];
        for (long int b=first2; b<last2; b++)
        {
            DataPoint p2=_data_point_list[b];
            _test_neighbors(p1, p2);
        }
    }
}
// Finds neighbour pairs with one point below `down` and the other below `up`.
// down, up:               the two subtrees to combine; a NULL node ends the call.
// down_region, up_region: bounding regions of those subtrees; a NULL region
//                         ends the call. Neither is deleted here - only the
//                         child regions created inside this function are.
// depth:                  recursion depth; depth % dim is the axis used to
//                         split any non-leaf node.
// The recursion ends when both nodes are leaves (their buckets are compared
// point by point) or when test_intersection reports the regions as outside
// each other for _neighbor_radius.
void KDTree::_neighbor_search_pairs(Node *down, Region *down_region,
Node *up, Region *up_region, int depth)
{
int down_is_leaf, up_is_leaf;
int localdim;
// if regions do not overlap - STOP
if (!down || !up || !down_region || !up_region)
{
// STOP
return;
}
// test_intersection returns 0 when some axis separates the regions by more than the radius
if (down_region->test_intersection(up_region, _neighbor_radius)==0)
{
// regions cannot contain neighbors
return;
}
// dim
localdim=depth%dim;
// are they leaves?
up_is_leaf=up->is_leaf();
down_is_leaf=down->is_leaf();
if (up_is_leaf && down_is_leaf)
{
// two leaf nodes
_search_neighbors_between_buckets(down, up);
}
else
{
// one or no leaf nodes
// each non-leaf node is split into its two children; a leaf is kept whole
// in the "left" slot and its "right" slot stays NULL
Node *up_right, *up_left, *down_left, *down_right;
Region *up_left_region, *up_right_region,
*down_left_region, *down_right_region;
if (down_is_leaf)
{
down_left=down;
// make a copy of down_region
// (a copy, so that the unconditional deletes below never free the caller's region)
down_left_region=new Region(down_region->get_left(), down_region->get_right());
down_right=NULL;
down_right_region=NULL;
}
else
{
float cut_value;
cut_value=down->get_cut_value();
down_left=down->get_left_node();
down_right=down->get_right_node();
down_left_region=down_region->intersect_left(cut_value, localdim);
down_right_region=down_region->intersect_right(cut_value, localdim);
}
if (up_is_leaf)
{
up_left=up;
// make a copy of up_region
// (a copy, so that the unconditional deletes below never free the caller's region)
up_left_region=new Region(up_region->get_left(), up_region->get_right());
up_right=NULL;
up_right_region=NULL;
}
else
{
float cut_value;
cut_value=up->get_cut_value();
up_left=up->get_left_node();
up_right=up->get_right_node();
up_left_region=up_region->intersect_left(cut_value, localdim);
up_right_region=up_region->intersect_right(cut_value, localdim);
}
// all four child combinations; calls with a NULL node or region return at once.
// Note that the "up" children are passed in the down slot and vice versa.
_neighbor_search_pairs(up_left, up_left_region, down_left, down_left_region, depth+1);
_neighbor_search_pairs(up_left, up_left_region, down_right, down_right_region, depth+1);
_neighbor_search_pairs(up_right, up_right_region, down_left, down_left_region, depth+1);
_neighbor_search_pairs(up_right, up_right_region, down_right, down_right_region, depth+1);
// release the child regions created above (deleting NULL is a no-op)
delete down_left_region;
delete down_right_region;
delete up_left_region;
delete up_right_region;
}
}
// Finds all pairs of points whose squared distance is at most radius_sq
// without using the tree: the points are sorted along axis 0 and each point
// is compared only with the following points while their axis-0 gap stays
// within the radius.
// Results are read back with neighbor_get_count / neighbor_copy_indices /
// neighbor_copy_radii_sq.
// NOTE(review): this re-sorts _data_point_list, which tree nodes index by
// position - a tree built earlier by set_data may no longer match; confirm
// that callers do not mix this with the tree-based searches.
void KDTree::neighbor_simple_search_sq(float radius_sq)
{
    DataPoint::dim=dim;
    Region::dim=dim;
    _neighbor_radius=MySqrt(radius_sq);
    _neighbor_radius_sq=radius_sq;
    const float max_gap=_neighbor_radius;
    // discard earlier results
    _neighbor_count=0;
    _neighbor_index_list.clear();
    _neighbor_radius_list.clear();
    // order the points by their first coordinate
    DataPoint::current_dim=0;
    sort(_data_point_list.begin(), _data_point_list.end());
    const long int nr_points=_data_point_list.size();
    for (long int a=0; a<nr_points; a++)
    {
        DataPoint p1=_data_point_list[a];
        const float x1=p1.get_coord()[0];
        for (long int b=a+1; b<nr_points; b++)
        {
            DataPoint p2=_data_point_list[b];
            const float x2=p2.get_coord()[0];
            if (!(fabs(x2-x1)<=max_gap))
            {
                // the list is sorted on axis 0, so later points are at least as far
                break;
            }
            _test_neighbors(p1, p2);
        }
    }
}
}