Nori  24
kdtree.h
1 /*
2  This file is part of Nori, a simple educational ray tracer
3 
4  Copyright (c) 2015 by Wenzel Jakob
5 
6  Nori is free software; you can redistribute it and/or modify
7  it under the terms of the GNU General Public License Version 3
8  as published by the Free Software Foundation.
9 
10  Nori is distributed in the hope that it will be useful,
11  but WITHOUT ANY WARRANTY; without even the implied warranty of
12  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13  GNU General Public License for more details.
14 
15  You should have received a copy of the GNU General Public License
16  along with this program. If not, see <http://www.gnu.org/licenses/>.
17 */
18 
19 #if !defined(__NORI_KDTREE_H)
20 #define __NORI_KDTREE_H
21 
22 #include <nori/bbox.h>
23 
24 NORI_NAMESPACE_BEGIN
25 
42 template <typename _PointType, typename _DataRecord> struct GenericKDTreeNode {
43  typedef _PointType PointType;
44  typedef _DataRecord DataRecord;
45  typedef uint32_t IndexType;
46  typedef typename PointType::Scalar Scalar;
47 
48  enum {
49  ELeafFlag = 0x10,
50  EAxisMask = 0x0F
51  };
52 
53  PointType position;
54  IndexType right;
55  DataRecord data;
56  uint8_t flags;
57 
59  GenericKDTreeNode() : position((Scalar) 0),
60  right(0), data(), flags(0) { }
62  GenericKDTreeNode(const PointType &position, const DataRecord &data)
63  : position(position), right(0), data(data), flags(0) {}
64 
66  IndexType getRightIndex(IndexType self) const { return right; }
68  void setRightIndex(IndexType self, IndexType value) { right = value; }
69 
71  IndexType getLeftIndex(IndexType self) const { return self + 1; }
73  void setLeftIndex(IndexType self, IndexType value) {
74  if (value != self+1)
75  throw NoriException("GenericKDTreeNode::setLeftIndex(): Internal error!");
76  }
77 
79  bool isLeaf() const { return flags & (uint8_t) ELeafFlag; }
81  void setLeaf(bool value) {
82  if (value)
83  flags |= (uint8_t) ELeafFlag;
84  else
85  flags &= (uint8_t) ~ELeafFlag;
86  }
87 
89  uint16_t getAxis() const { return flags & (uint8_t) EAxisMask; }
91  void setAxis(uint8_t axis) { flags = (flags & (uint8_t) ~EAxisMask) | axis; }
92 
94  const PointType &getPosition() const { return position; }
96  void setPosition(const PointType &value) { position = value; }
97 
99  DataRecord &getData() { return data; }
101  const DataRecord &getData() const { return data; }
103  void setData(const DataRecord &val) { data = val; }
104 };
105 
/* Declared here so PointKDTree::build() can use it; the definition
   appears at the end of this header. */
template <class DataType, class IndexType>
void permute_inplace(DataType *data, std::vector<IndexType> &perm);
109 
124 template <typename _NodeType> class PointKDTree {
125 public:
126  typedef _NodeType NodeType;
127  typedef typename NodeType::PointType PointType;
128  typedef typename NodeType::IndexType IndexType;
129  typedef typename PointType::Scalar Scalar;
130  typedef typename PointType::VectorType VectorType;
132 
133  enum {
134  Dimension = VectorType::RowsAtCompileTime
135  };
136 
138  enum Heuristic {
140  Balanced = 0,
141 
147  };
148 
150  struct SearchResult {
151  float distSquared;
152  IndexType index;
153 
154  SearchResult() {}
155 
156  SearchResult(float distSquared, IndexType index)
157  : distSquared(distSquared), index(index) { }
158 
159  std::string toString() const {
160  std::ostringstream oss;
161  oss << "SearchResult[distance=" << std::sqrt(distSquared)
162  << ", index=" << index << "]";
163  return oss.str();
164  }
165 
166  bool operator==(const SearchResult &r) const {
167  return distSquared == r.distSquared &&
168  index == r.index;
169  }
170  };
171 
172 public:
177  PointKDTree(size_t nodes = 0, Heuristic heuristic = SlidingMidpoint)
178  : m_nodes(nodes), m_heuristic(heuristic), m_depth(0) { }
179 
180  // =============================================================
182  // =============================================================
184  void clear() { m_nodes.clear(); m_bbox.reset(); }
186  void resize(size_t size) { m_nodes.resize(size); }
188  void reserve(size_t size) { m_nodes.reserve(size); }
190  size_t size() const { return m_nodes.size(); }
192  size_t capacity() const { return m_nodes.capacity(); }
194  void push_back(const NodeType &node) {
195  m_nodes.push_back(node);
196  m_bbox.expandBy(node.getPosition());
197  }
199  NodeType &operator[](size_t idx) { return m_nodes[idx]; }
201  const NodeType &operator[](size_t idx) const { return m_nodes[idx]; }
203  // =============================================================
204 
206  void setBoundingBox(const BoundingBoxType &bbox) { m_bbox = bbox; }
208  const BoundingBoxType &getBoundingBox() const { return m_bbox; }
210  size_t getDepth() const { return m_depth; }
212  void setDepth(size_t depth) { m_depth = depth; }
213 
221  void build(bool recomputeBoundingBox = false) {
222  if (m_nodes.size() == 0) {
223  std::cerr << "KDTree::build(): kd-tree is empty!" << endl;
224  return;
225  }
226 
227  cout << "Building a " << Dimension << "-dimensional kd-tree over "
228  << m_nodes.size() << " data points ("
229  << memString(m_nodes.size() * sizeof(NodeType)).c_str() << ") .. ";
230  cout.flush();
231 
232  if (recomputeBoundingBox) {
233  m_bbox.reset();
234  for (size_t i=0; i<m_nodes.size(); ++i)
235  m_bbox.expandBy(m_nodes[i].getPosition());
236  }
237 
238  /* Instead of shuffling around the node data itself, only modify
239  an indirection table initially. Once the tree construction
240  is done, this table will contain a indirection that can then
241  be applied to the data in one pass */
242  std::vector<IndexType> indirection(m_nodes.size());
243  for (size_t i=0; i<m_nodes.size(); ++i)
244  indirection[i] = (IndexType) i;
245 
246  m_depth = 0;
247  build(1, indirection.begin(), indirection.begin(), indirection.end());
248  permute_inplace(&m_nodes[0], indirection);
249 
250  cout << "done." << endl;
251  }
252 
260  void search(const PointType &p, float searchRadius, std::vector<IndexType> &results) const {
261  if (m_nodes.size() == 0)
262  return;
263 
264  IndexType *stack = (IndexType *) alloca((m_depth+1) * sizeof(IndexType));
265  IndexType index = 0, stackPos = 1, found = 0;
266  float distSquared = searchRadius*searchRadius;
267  stack[0] = 0;
268  results.clear();
269 
270  while (stackPos > 0) {
271  const NodeType &node = m_nodes[index];
272  IndexType nextIndex;
273 
274  /* Recurse on inner nodes */
275  if (!node.isLeaf()) {
276  float distToPlane = p[node.getAxis()]
277  - node.getPosition()[node.getAxis()];
278 
279  bool searchBoth = distToPlane*distToPlane <= distSquared;
280 
281  if (distToPlane > 0) {
282  /* The search query is located on the right side of the split.
283  Search this side first. */
284  if (hasRightChild(index)) {
285  if (searchBoth)
286  stack[stackPos++] = node.getLeftIndex(index);
287  nextIndex = node.getRightIndex(index);
288  } else if (searchBoth) {
289  nextIndex = node.getLeftIndex(index);
290  } else {
291  nextIndex = stack[--stackPos];
292  }
293  } else {
294  /* The search query is located on the left side of the split.
295  Search this side first. */
296  if (searchBoth && hasRightChild(index))
297  stack[stackPos++] = node.getRightIndex(index);
298 
299  nextIndex = node.getLeftIndex(index);
300  }
301  } else {
302  nextIndex = stack[--stackPos];
303  }
304 
305  /* Check if the current point is within the query's search radius */
306  const float pointDistSquared = (node.getPosition() - p).squaredNorm();
307 
308  if (pointDistSquared < distSquared) {
309  ++found;
310  results.push_back(index);
311  }
312 
313  index = nextIndex;
314  }
315  }
316 
334  size_t nnSearch(const PointType &p, float &_sqrSearchRadius,
335  size_t k, SearchResult *results) const {
336  if (m_nodes.size() == 0)
337  return 0;
338 
339  IndexType *stack = (IndexType *) alloca((m_depth+1) * sizeof(IndexType));
340  IndexType index = 0, stackPos = 1;
341  float sqrSearchRadius = _sqrSearchRadius;
342  size_t resultCount = 0;
343  bool isHeap = false;
344  stack[0] = 0;
345 
346  while (stackPos > 0) {
347  const NodeType &node = m_nodes[index];
348  IndexType nextIndex;
349 
350  /* Recurse on inner nodes */
351  if (!node.isLeaf()) {
352  float distToPlane = p[node.getAxis()] - node.getPosition()[node.getAxis()];
353 
354  bool searchBoth = distToPlane*distToPlane <= sqrSearchRadius;
355 
356  if (distToPlane > 0) {
357  /* The search query is located on the right side of the split.
358  Search this side first. */
359  if (hasRightChild(index)) {
360  if (searchBoth)
361  stack[stackPos++] = node.getLeftIndex(index);
362  nextIndex = node.getRightIndex(index);
363  } else if (searchBoth) {
364  nextIndex = node.getLeftIndex(index);
365  } else {
366  nextIndex = stack[--stackPos];
367  }
368  } else {
369  /* The search query is located on the left side of the split.
370  Search this side first. */
371  if (searchBoth && hasRightChild(index))
372  stack[stackPos++] = node.getRightIndex(index);
373 
374  nextIndex = node.getLeftIndex(index);
375  }
376  } else {
377  nextIndex = stack[--stackPos];
378  }
379 
380  /* Check if the current point is within the query's search radius */
381  const float pointDistSquared = (node.getPosition() - p).squaredNorm();
382 
383  if (pointDistSquared < sqrSearchRadius) {
384  /* Switch to a max-heap when the available search
385  result space is exhausted */
386  if (resultCount < k) {
387  /* There is still room, just add the point to
388  the search result list */
389  results[resultCount++] = SearchResult(pointDistSquared, index);
390  } else {
391  auto comparator = [](SearchResult &a, SearchResult &b) -> bool {
392  return a.distSquared < b.distSquared;
393  };
394 
395  if (!isHeap) {
396  /* Establish the max-heap property */
397  std::make_heap(results, results + resultCount, comparator);
398  isHeap = true;
399  }
400  SearchResult *end = results + resultCount + 1;
401 
402  /* Add the new point, remove the one that is farthest away */
403  results[resultCount] = SearchResult(pointDistSquared, index);
404  std::push_heap(results, end, comparator);
405  std::pop_heap(results, end, comparator);
406 
407  /* Reduce the search radius accordingly */
408  sqrSearchRadius = results[0].distSquared;
409  }
410  }
411  index = nextIndex;
412  }
413  _sqrSearchRadius = sqrSearchRadius;
414  return resultCount;
415  }
416 
430  size_t nnSearch(const PointType &p, size_t k,
431  SearchResult *results) const {
432  float searchRadiusSqr = std::numeric_limits<float>::infinity();
433  return nnSearch(p, searchRadiusSqr, k, results);
434  }
435 
436 protected:
438  bool hasRightChild(IndexType index) const {
439  return m_nodes[index].getRightIndex(index) != 0;
440  }
441 
443  void build(size_t depth,
444  typename std::vector<IndexType>::iterator base,
445  typename std::vector<IndexType>::iterator rangeStart,
446  typename std::vector<IndexType>::iterator rangeEnd) {
447  if (rangeEnd <= rangeStart)
448  throw NoriException("Internal error!");
449 
450  m_depth = std::max(depth, m_depth);
451 
452  IndexType count = (IndexType) (rangeEnd-rangeStart);
453 
454  if (count == 1) {
455  /* Create a leaf node */
456  m_nodes[*rangeStart].setLeaf(true);
457  return;
458  }
459 
460  int axis = 0;
461  typename std::vector<IndexType>::iterator split;
462 
463  switch (m_heuristic) {
464  case Balanced: {
465  /* Build a balanced tree */
466  split = rangeStart + count/2;
467  axis = m_bbox.getLargestAxis();
468  };
469  break;
470 
471  case SlidingMidpoint: {
472  /* Sliding midpoint rule: find a split that is close to the spatial median */
473  axis = m_bbox.getLargestAxis();
474 
475  Scalar midpoint = (Scalar) 0.5f
476  * (m_bbox.max[axis]+m_bbox.min[axis]);
477 
478  size_t nLT = std::count_if(rangeStart, rangeEnd,
479  [&](IndexType i) {
480  return m_nodes[i].getPosition()[axis] <= midpoint;
481  }
482  );
483 
484  /* Re-adjust the split to pass through a nearby point */
485  split = rangeStart + nLT;
486 
487  if (split == rangeStart)
488  ++split;
489  else if (split == rangeEnd)
490  --split;
491  };
492  break;
493  }
494 
495  std::nth_element(rangeStart, split, rangeEnd,
496  [&](IndexType i1, IndexType i2) {
497  return m_nodes[i1].getPosition()[axis] < m_nodes[i2].getPosition()[axis];
498  }
499  );
500 
501  NodeType &splitNode = m_nodes[*split];
502  splitNode.setAxis(axis);
503  splitNode.setLeaf(false);
504 
505  if (split+1 != rangeEnd)
506  splitNode.setRightIndex((IndexType) (rangeStart - base),
507  (IndexType) (split + 1 - base));
508  else
509  splitNode.setRightIndex((IndexType) (rangeStart - base), 0);
510 
511  splitNode.setLeftIndex((IndexType) (rangeStart - base),
512  (IndexType) (rangeStart + 1 - base));
513  std::iter_swap(rangeStart, split);
514 
515  /* Recursively build the children */
516  Scalar temp = m_bbox.max[axis],
517  splitPos = splitNode.getPosition()[axis];
518  m_bbox.max[axis] = splitPos;
519  build(depth+1, base, rangeStart+1, split+1);
520  m_bbox.max[axis] = temp;
521 
522  if (split+1 != rangeEnd) {
523  temp = m_bbox.min[axis];
524  m_bbox.min[axis] = splitPos;
525  build(depth+1, base, split+1, rangeEnd);
526  m_bbox.min[axis] = temp;
527  }
528  }
529 protected:
530  std::vector<NodeType> m_nodes;
531  BoundingBoxType m_bbox;
532  Heuristic m_heuristic;
533  size_t m_depth;
534 };
535 
/**
 * \brief Apply a permutation to an array in place.
 *
 * On return, <tt>data[i]</tt> holds the value that was stored at
 * <tt>data[perm[i]]</tt> before the call, and \c perm has been overwritten
 * with the identity permutation. Each cycle is walked once, using a single
 * temporary copy of its first element.
 *
 * \param data Array with at least <tt>perm.size()</tt> elements
 * \param perm Permutation of 0 .. perm.size()-1 (not validated)
 */
template <typename DataType, typename IndexType> void permute_inplace(
    DataType *data, std::vector<IndexType> &perm) {
    const size_t count = perm.size();

    for (size_t start = 0; start < count; ++start) {
        /* Already in place, or already handled as part of an earlier cycle */
        if (perm[start] == start)
            continue;

        /* data[start] is the first slot overwritten, so keep a copy */
        DataType saved = data[start];
        IndexType cur = (IndexType) start;

        /* Pull each element of the cycle into its destination slot and
           mark that slot as done, until the cycle wraps back to start */
        for (;;) {
            const IndexType src = perm[cur];
            data[cur] = data[src];
            perm[cur] = cur;
            cur = src;
            if (perm[cur] == start)
                break;
        }

        /* The last slot of the cycle receives the saved value */
        data[cur] = saved;
        perm[cur] = cur;
    }
}
583 
584 NORI_NAMESPACE_END
585 
586 #endif /* __NORI_KDTREE_H */
Simple exception class, which stores a human-readable error description.
Definition: common.h:148
Generic multi-dimensional kd-tree data structure for point data.
Definition: kdtree.h:124
size_t nnSearch(const PointType &p, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query without any search radius threshold.
Definition: kdtree.h:430
NodeType & operator[](size_t idx)
Return one of the KD-tree nodes by index.
Definition: kdtree.h:199
void search(const PointType &p, float searchRadius, std::vector< IndexType > &results) const
Run a search query.
Definition: kdtree.h:260
void clear()
Clear the kd-tree array.
Definition: kdtree.h:184
void push_back(const NodeType &node)
Append a kd-tree node to the node array.
Definition: kdtree.h:194
bool hasRightChild(IndexType index) const
Return whether or not the inner node of the specified index has a right child node.
Definition: kdtree.h:438
size_t nnSearch(const PointType &p, float &_sqrSearchRadius, size_t k, SearchResult *results) const
Run a k-nearest-neighbor search query.
Definition: kdtree.h:334
size_t size() const
Return the size of the kd-tree.
Definition: kdtree.h:190
void build(size_t depth, typename std::vector< IndexType >::iterator base, typename std::vector< IndexType >::iterator rangeStart, typename std::vector< IndexType >::iterator rangeEnd)
Tree construction routine.
Definition: kdtree.h:443
void resize(size_t size)
Resize the kd-tree array.
Definition: kdtree.h:186
void setBoundingBox(const BoundingBoxType &bbox)
Set the BoundingBox of the underlying point data.
Definition: kdtree.h:206
Heuristic
Supported tree construction heuristics.
Definition: kdtree.h:138
@ Balanced
Create a balanced tree by splitting along the median.
Definition: kdtree.h:140
@ SlidingMidpoint
Use the sliding midpoint tree construction rule. This ensures that cells do not become overly elongated.
Definition: kdtree.h:146
const NodeType & operator[](size_t idx) const
Return one of the KD-tree nodes by index (const version)
Definition: kdtree.h:201
PointKDTree(size_t nodes=0, Heuristic heuristic=SlidingMidpoint)
Create an empty KD-tree that can hold the specified number of points.
Definition: kdtree.h:177
size_t getDepth() const
Return the depth of the constructed KD-tree.
Definition: kdtree.h:210
void reserve(size_t size)
Reserve a certain amount of memory for the kd-tree array.
Definition: kdtree.h:188
void setDepth(size_t depth)
Set the depth of the constructed KD-tree (be careful with this)
Definition: kdtree.h:212
void build(bool recomputeBoundingBox=false)
Construct the KD-tree hierarchy.
Definition: kdtree.h:221
size_t capacity() const
Return the capacity of the kd-tree.
Definition: kdtree.h:192
const BoundingBoxType & getBoundingBox() const
Return the BoundingBox of the underlying point data.
Definition: kdtree.h:208
Simple kd-tree node data structure for use with PointKDTree.
Definition: kdtree.h:42
uint16_t getAxis() const
Return the split axis associated with this node.
Definition: kdtree.h:89
IndexType getLeftIndex(IndexType self) const
Given the current node's index, return the index of the left child.
Definition: kdtree.h:71
GenericKDTreeNode()
Initialize a KD-tree node.
Definition: kdtree.h:59
const PointType & getPosition() const
Return the position associated with this node.
Definition: kdtree.h:94
void setLeftIndex(IndexType self, IndexType value)
Given the current node's index, set the left child index.
Definition: kdtree.h:73
void setData(const DataRecord &val)
Set the data record associated with this node.
Definition: kdtree.h:103
DataRecord & getData()
Return the data record associated with this node.
Definition: kdtree.h:99
GenericKDTreeNode(const PointType &position, const DataRecord &data)
Initialize a KD-tree node with the given data record.
Definition: kdtree.h:62
void setAxis(uint8_t axis)
Set the split axis associated with this node.
Definition: kdtree.h:91
bool isLeaf() const
Check whether this is a leaf node.
Definition: kdtree.h:79
const DataRecord & getData() const
Return the data record associated with this node (const version)
Definition: kdtree.h:101
IndexType getRightIndex(IndexType self) const
Given the current node's index, return the index of the right child.
Definition: kdtree.h:66
void setPosition(const PointType &value)
Set the position associated with this node.
Definition: kdtree.h:96
void setLeaf(bool value)
Specify whether this is a leaf node.
Definition: kdtree.h:81
void setRightIndex(IndexType self, IndexType value)
Given the current node's index, set the right child index.
Definition: kdtree.h:68
Result data type for k-nn queries.
Definition: kdtree.h:150
int getLargestAxis() const
Return the index of the largest axis.
Definition: bbox.h:308
void expandBy(const PointType &p)
Expand the bounding box to contain another point.
Definition: bbox.h:288
void reset()
Mark the bounding box as invalid.
Definition: bbox.h:282
PointType max
Component-wise maximum.
Definition: bbox.h:396
PointType min
Component-wise minimum.
Definition: bbox.h:395