19 #if !defined(__NORI_KDTREE_H)
20 #define __NORI_KDTREE_H
22 #include <nori/bbox.h>
43 typedef _PointType PointType;
44 typedef _DataRecord DataRecord;
45 typedef uint32_t IndexType;
46 typedef typename PointType::Scalar Scalar;
60 right(0), data(), flags(0) { }
63 : position(position), right(0), data(data), flags(0) {}
75 throw NoriException(
"GenericKDTreeNode::setLeftIndex(): Internal error!");
79 bool isLeaf()
const {
return flags & (uint8_t) ELeafFlag; }
83 flags |= (uint8_t) ELeafFlag;
85 flags &= (uint8_t) ~ELeafFlag;
89 uint16_t
getAxis()
const {
return flags & (uint8_t) EAxisMask; }
91 void setAxis(uint8_t axis) { flags = (flags & (uint8_t) ~EAxisMask) | axis; }
96 void setPosition(
const PointType &value) { position = value; }
101 const DataRecord &
getData()
const {
return data; }
103 void setData(
const DataRecord &val) { data = val; }
/// Forward declaration of the in-place permutation helper. PointKDTree::build()
/// calls it to reorder its node array once the hierarchy has been constructed.
///
/// \param data  Pointer to the first element of the array to reorder. It must
///              hold at least perm.size() elements.
/// \param perm  Permutation indices. They are passed by non-const reference,
///              so the helper may modify them while it works.
template <typename DataType, typename IndexType>
void permute_inplace(DataType *data, std::vector<IndexType> &perm);
126 typedef _NodeType NodeType;
127 typedef typename NodeType::PointType PointType;
128 typedef typename NodeType::IndexType IndexType;
129 typedef typename PointType::Scalar Scalar;
130 typedef typename PointType::VectorType VectorType;
134 Dimension = VectorType::RowsAtCompileTime
157 : distSquared(distSquared), index(index) { }
159 std::string toString()
const {
160 std::ostringstream oss;
161 oss <<
"SearchResult[distance=" << std::sqrt(distSquared)
162 <<
", index=" << index <<
"]";
167 return distSquared == r.distSquared &&
178 : m_nodes(nodes), m_heuristic(heuristic), m_depth(0) { }
190 size_t size()
const {
return m_nodes.size(); }
192 size_t capacity()
const {
return m_nodes.capacity(); }
195 m_nodes.push_back(node);
196 m_bbox.
expandBy(node.getPosition());
201 const NodeType &
operator[](
size_t idx)
const {
return m_nodes[idx]; }
221 void build(
bool recomputeBoundingBox =
false) {
222 if (m_nodes.size() == 0) {
223 std::cerr <<
"KDTree::build(): kd-tree is empty!" << endl;
227 cout <<
"Building a " << Dimension <<
"-dimensional kd-tree over "
228 << m_nodes.size() <<
" data points ("
229 << memString(m_nodes.size() *
sizeof(NodeType)).c_str() <<
") .. ";
232 if (recomputeBoundingBox) {
234 for (
size_t i=0; i<m_nodes.size(); ++i)
235 m_bbox.
expandBy(m_nodes[i].getPosition());
242 std::vector<IndexType> indirection(m_nodes.size());
243 for (
size_t i=0; i<m_nodes.size(); ++i)
244 indirection[i] = (IndexType) i;
247 build(1, indirection.begin(), indirection.begin(), indirection.end());
248 permute_inplace(&m_nodes[0], indirection);
250 cout <<
"done." << endl;
260 void search(
const PointType &p,
float searchRadius, std::vector<IndexType> &results)
const {
261 if (m_nodes.size() == 0)
264 IndexType *stack = (IndexType *) alloca((m_depth+1) *
sizeof(IndexType));
265 IndexType index = 0, stackPos = 1, found = 0;
266 float distSquared = searchRadius*searchRadius;
270 while (stackPos > 0) {
271 const NodeType &node = m_nodes[index];
275 if (!node.isLeaf()) {
276 float distToPlane = p[node.getAxis()]
277 - node.getPosition()[node.getAxis()];
279 bool searchBoth = distToPlane*distToPlane <= distSquared;
281 if (distToPlane > 0) {
286 stack[stackPos++] = node.getLeftIndex(index);
287 nextIndex = node.getRightIndex(index);
288 }
else if (searchBoth) {
289 nextIndex = node.getLeftIndex(index);
291 nextIndex = stack[--stackPos];
297 stack[stackPos++] = node.getRightIndex(index);
299 nextIndex = node.getLeftIndex(index);
302 nextIndex = stack[--stackPos];
306 const float pointDistSquared = (node.getPosition() - p).squaredNorm();
308 if (pointDistSquared < distSquared) {
310 results.push_back(index);
334 size_t nnSearch(
const PointType &p,
float &_sqrSearchRadius,
336 if (m_nodes.size() == 0)
339 IndexType *stack = (IndexType *) alloca((m_depth+1) *
sizeof(IndexType));
340 IndexType index = 0, stackPos = 1;
341 float sqrSearchRadius = _sqrSearchRadius;
342 size_t resultCount = 0;
346 while (stackPos > 0) {
347 const NodeType &node = m_nodes[index];
351 if (!node.isLeaf()) {
352 float distToPlane = p[node.getAxis()] - node.getPosition()[node.getAxis()];
354 bool searchBoth = distToPlane*distToPlane <= sqrSearchRadius;
356 if (distToPlane > 0) {
361 stack[stackPos++] = node.getLeftIndex(index);
362 nextIndex = node.getRightIndex(index);
363 }
else if (searchBoth) {
364 nextIndex = node.getLeftIndex(index);
366 nextIndex = stack[--stackPos];
372 stack[stackPos++] = node.getRightIndex(index);
374 nextIndex = node.getLeftIndex(index);
377 nextIndex = stack[--stackPos];
381 const float pointDistSquared = (node.getPosition() - p).squaredNorm();
383 if (pointDistSquared < sqrSearchRadius) {
386 if (resultCount < k) {
389 results[resultCount++] =
SearchResult(pointDistSquared, index);
392 return a.distSquared < b.distSquared;
397 std::make_heap(results, results + resultCount, comparator);
403 results[resultCount] =
SearchResult(pointDistSquared, index);
404 std::push_heap(results, end, comparator);
405 std::pop_heap(results, end, comparator);
408 sqrSearchRadius = results[0].distSquared;
413 _sqrSearchRadius = sqrSearchRadius;
432 float searchRadiusSqr = std::numeric_limits<float>::infinity();
433 return nnSearch(p, searchRadiusSqr, k, results);
439 return m_nodes[index].getRightIndex(index) != 0;
444 typename std::vector<IndexType>::iterator base,
445 typename std::vector<IndexType>::iterator rangeStart,
446 typename std::vector<IndexType>::iterator rangeEnd) {
447 if (rangeEnd <= rangeStart)
450 m_depth = std::max(depth, m_depth);
452 IndexType count = (IndexType) (rangeEnd-rangeStart);
456 m_nodes[*rangeStart].setLeaf(
true);
461 typename std::vector<IndexType>::iterator split;
463 switch (m_heuristic) {
466 split = rangeStart + count/2;
475 Scalar midpoint = (Scalar) 0.5f
476 * (m_bbox.
max[axis]+m_bbox.
min[axis]);
478 size_t nLT = std::count_if(rangeStart, rangeEnd,
480 return m_nodes[i].getPosition()[axis] <= midpoint;
485 split = rangeStart + nLT;
487 if (split == rangeStart)
489 else if (split == rangeEnd)
495 std::nth_element(rangeStart, split, rangeEnd,
496 [&](IndexType i1, IndexType i2) {
497 return m_nodes[i1].getPosition()[axis] < m_nodes[i2].getPosition()[axis];
501 NodeType &splitNode = m_nodes[*split];
502 splitNode.setAxis(axis);
503 splitNode.setLeaf(
false);
505 if (split+1 != rangeEnd)
506 splitNode.setRightIndex((IndexType) (rangeStart - base),
507 (IndexType) (split + 1 - base));
509 splitNode.setRightIndex((IndexType) (rangeStart - base), 0);
511 splitNode.setLeftIndex((IndexType) (rangeStart - base),
512 (IndexType) (rangeStart + 1 - base));
513 std::iter_swap(rangeStart, split);
516 Scalar temp = m_bbox.
max[axis],
517 splitPos = splitNode.getPosition()[axis];
518 m_bbox.
max[axis] = splitPos;
519 build(depth+1, base, rangeStart+1, split+1);
520 m_bbox.
max[axis] = temp;
522 if (split+1 != rangeEnd) {
523 temp = m_bbox.
min[axis];
524 m_bbox.
min[axis] = splitPos;
525 build(depth+1, base, split+1, rangeEnd);
526 m_bbox.
min[axis] = temp;
530 std::vector<NodeType> m_nodes;
531 BoundingBoxType m_bbox;
555 template <
typename DataType,
typename IndexType>
void permute_inplace(
556 DataType *data, std::vector<IndexType> &perm) {
557 for (
size_t i=0; i<perm.size(); i++) {
562 IndexType j = (IndexType) i;
563 DataType curval = data[i];
567 IndexType k = perm[j];
575 }
while (perm[j] != i);
Simple exception class, which stores a human-readable error description.
Generic multi-dimensional kd-tree data structure for point data.
size_t nnSearch(const PointType &p, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query without any search radius threshold.
NodeType & operator[](size_t idx)
Return one of the KD-tree nodes by index.
void search(const PointType &p, float searchRadius, std::vector< IndexType > &results) const
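Run a fixed-radius search query: append to results the index of every point whose distance to p is less than searchRadius.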
void clear()
Clear the kd-tree array.
void push_back(const NodeType &node)
Append a kd-tree node to the node array.
bool hasRightChild(IndexType index) const
Return whether or not the inner node of the specified index has a right child node.
size_t nnSearch(const PointType &p, float &_sqrSearchRadius, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query.
size_t size() const
Return the size of the kd-tree.
void build(size_t depth, typename std::vector< IndexType >::iterator base, typename std::vector< IndexType >::iterator rangeStart, typename std::vector< IndexType >::iterator rangeEnd)
Tree construction routine.
void resize(size_t size)
Resize the kd-tree array.
void setBoundingBox(const BoundingBoxType &bbox)
Set the BoundingBox of the underlying point data.
Heuristic
Supported tree construction heuristics.
@ Balanced
Create a balanced tree by splitting along the median.
@ SlidingMidpoint
Use the sliding midpoint tree construction rule. This ensures that cells do not become overly elongated.
const NodeType & operator[](size_t idx) const
Return one of the KD-tree nodes by index (const version)
PointKDTree(size_t nodes=0, Heuristic heuristic=SlidingMidpoint)
Create a KD-tree (not yet built) whose node array initially holds the specified number of default-constructed nodes.
size_t getDepth() const
Return the depth of the constructed KD-tree.
void reserve(size_t size)
Reserve a certain amount of memory for the kd-tree array.
void setDepth(size_t depth)
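Set the depth of the constructed KD-tree. Be careful: the search routines size their traversal stacks from this depth, so it must not be smaller than the actual tree depth.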
void build(bool recomputeBoundingBox=false)
Construct the KD-tree hierarchy.
size_t capacity() const
Return the capacity of the kd-tree.
const BoundingBoxType & getBoundingBox() const
Return the BoundingBox of the underlying point data.
Simple kd-tree node data structure for use with PointKDTree.
uint16_t getAxis() const
Return the split axis associated with this node.
IndexType getLeftIndex(IndexType self) const
Given the current node's index, return the index of the left child.
GenericKDTreeNode()
Initialize a KD-tree node.
const PointType & getPosition() const
Return the position associated with this node.
void setLeftIndex(IndexType self, IndexType value)
Given the current node's index, set the left child index.
void setData(const DataRecord &val)
Set the data record associated with this node.
DataRecord & getData()
Return the data record associated with this node.
GenericKDTreeNode(const PointType &position, const DataRecord &data)
Initialize a KD-tree node with the given data record.
void setAxis(uint8_t axis)
Set the split axis associated with this node.
bool isLeaf() const
Check whether this is a leaf node.
const DataRecord & getData() const
Return the data record associated with this node (const version)
IndexType getRightIndex(IndexType self) const
Given the current node's index, return the index of the right child.
void setPosition(const PointType &value)
Set the position associated with this node.
void setLeaf(bool value)
Specify whether this is a leaf node.
void setRightIndex(IndexType self, IndexType value)
Given the current node's index, set the right child index.
Result data type for k-nn queries.
int getLargestAxis() const
Return the index of the largest axis.
void expandBy(const PointType &p)
Expand the bounding box to contain another point.
void reset()
Mark the bounding box as invalid.
PointType max
Component-wise maximum.
PointType min
Component-wise minimum.