Branch data Line data Source code
1 : : #pragma once
2 : :
3 : : #include <algorithm>
4 : : #include <cmath>
5 : : #include <cstdint>
6 : : #include <tuple>
7 : : #include <vector>
8 : : #include <cassert>
9 : :
10 : : namespace kdbush {
11 : :
12 : : template <std::uint8_t I, typename T>
13 : : struct nth {
14 : : inline static typename std::tuple_element<I, T>::type get(const T &t) {
15 : : return std::get<I>(t);
16 : : }
17 : : };
18 : :
19 : : template <typename TPoint, typename TContainer, typename TIndex = std::size_t>
20 : 0 : class KDBush {
21 : :
22 : : public:
23 : : using TNumber = decltype(nth<0, TPoint>::get(std::declval<TPoint>()));
24 : : static_assert(
25 : : std::is_same<TNumber, decltype(nth<1, TPoint>::get(std::declval<TPoint>()))>::value,
26 : : "point component types must be identical");
27 : :
28 : : static const std::uint8_t defaultNodeSize = 64;
29 : :
30 : 0 : KDBush(const std::uint8_t nodeSize_ = defaultNodeSize) : nodeSize(nodeSize_) {
31 : 0 : }
32 : :
33 : : KDBush(const std::vector<TPoint> &points_, const std::uint8_t nodeSize_ = defaultNodeSize)
34 : : : KDBush(std::begin(points_), std::end(points_), nodeSize_) {
35 : : }
36 : :
37 : : template <typename TPointIter>
38 : : KDBush(const TPointIter &points_begin,
39 : : const TPointIter &points_end,
40 : : const std::uint8_t nodeSize_ = defaultNodeSize)
41 : : : nodeSize(nodeSize_) {
42 : : fill(points_begin, points_end);
43 : : }
44 : :
45 : : void fill(const std::vector<TPoint> &points_) {
46 : : fill(std::begin(points_), std::end(points_));
47 : : }
48 : :
49 : : template <typename TPointIter>
50 : : void fill(const TPointIter &points_begin, const TPointIter &points_end) {
51 : : assert(points.empty());
52 : : const TIndex size = static_cast<TIndex>(std::distance(points_begin, points_end));
53 : :
54 : : if (size == 0) return;
55 : :
56 : : points.reserve(size);
57 : : //ids.reserve(size);
58 : :
59 : : TIndex i = 0;
60 : : for (auto p = points_begin; p != points_end; p++) {
61 : : points.emplace_back(nth<0, TPoint>::get(*p), nth<1, TPoint>::get(*p));
62 : : //ids.push_back(i++);
63 : : }
64 : :
65 : : sortKD(0, size - 1, 0);
66 : : }
67 : :
68 : : template <typename TVisitor>
69 : 0 : void range(const TNumber minX,
70 : : const TNumber minY,
71 : : const TNumber maxX,
72 : : const TNumber maxY,
73 : : const TVisitor &visitor) const {
74 : 0 : range(minX, minY, maxX, maxY, visitor, 0, static_cast<TIndex>(points.size() - 1), 0);
75 : 0 : }
76 : :
77 : : template <typename TVisitor>
78 : 0 : void within(const TNumber qx, const TNumber qy, const TNumber r, const TVisitor &visitor) const {
79 : 0 : within(qx, qy, r, visitor, 0, static_cast<TIndex>(points.size() - 1), 0);
80 : 0 : }
81 : :
82 : : protected:
83 : : //std::vector<TIndex> ids;
84 : : std::vector<TContainer> points;
85 : : const std::uint8_t nodeSize;
86 : :
87 : : template <typename TVisitor>
88 : 0 : void range(const TNumber minX,
89 : : const TNumber minY,
90 : : const TNumber maxX,
91 : : const TNumber maxY,
92 : : const TVisitor &visitor,
93 : : const TIndex left,
94 : : const TIndex right,
95 : : const std::uint8_t axis) const {
96 : 0 : if ( points.empty() )
97 : 0 : return;
98 : :
99 : 0 : if (right - left <= nodeSize) {
100 : 0 : for (auto i = left; i <= right; i++) {
101 : 0 : const TNumber x = std::get<0>(points[i].coords);
102 : 0 : const TNumber y = std::get<1>(points[i].coords);
103 : 0 : if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(points[i]);
104 : 0 : }
105 : 0 : return;
106 : : }
107 : :
108 : 0 : const TIndex m = (left + right) >> 1;
109 : 0 : const TNumber x = std::get<0>(points[m].coords);
110 : 0 : const TNumber y = std::get<1>(points[m].coords);
111 : :
112 : 0 : if (x >= minX && x <= maxX && y >= minY && y <= maxY) visitor(points[m]);
113 : :
114 : 0 : if (axis == 0 ? minX <= x : minY <= y)
115 : 0 : range(minX, minY, maxX, maxY, visitor, left, m - 1, (axis + 1) % 2);
116 : :
117 : 0 : if (axis == 0 ? maxX >= x : maxY >= y)
118 : 0 : range(minX, minY, maxX, maxY, visitor, m + 1, right, (axis + 1) % 2);
119 : 0 : }
120 : :
121 : : template <typename TVisitor>
122 : 0 : void within(const TNumber qx,
123 : : const TNumber qy,
124 : : const TNumber r,
125 : : const TVisitor &visitor,
126 : : const TIndex left,
127 : : const TIndex right,
128 : : const std::uint8_t axis) const {
129 : :
130 : 0 : if ( points.empty() )
131 : 0 : return;
132 : :
133 : 0 : const TNumber r2 = r * r;
134 : :
135 : 0 : if (right - left <= nodeSize) {
136 : 0 : for (auto i = left; i <= right; i++) {
137 : 0 : const TNumber x = std::get<0>(points[i].coords);
138 : 0 : const TNumber y = std::get<1>(points[i].coords);
139 : 0 : if (sqDist(x, y, qx, qy) <= r2) visitor(points[i]);
140 : 0 : }
141 : 0 : return;
142 : : }
143 : :
144 : 0 : const TIndex m = (left + right) >> 1;
145 : 0 : const TNumber x = std::get<0>(points[m].coords);
146 : 0 : const TNumber y = std::get<1>(points[m].coords);
147 : :
148 : 0 : if (sqDist(x, y, qx, qy) <= r2) visitor(points[m]);
149 : :
150 : 0 : if (axis == 0 ? qx - r <= x : qy - r <= y)
151 : 0 : within(qx, qy, r, visitor, left, m - 1, (axis + 1) % 2);
152 : :
153 : 0 : if (axis == 0 ? qx + r >= x : qy + r >= y)
154 : 0 : within(qx, qy, r, visitor, m + 1, right, (axis + 1) % 2);
155 : 0 : }
156 : :
157 : 0 : void sortKD(const TIndex left, const TIndex right, const std::uint8_t axis) {
158 : 0 : if (right - left <= nodeSize) return;
159 : 0 : const TIndex m = (left + right) >> 1;
160 : 0 : if (axis == 0) {
161 : 0 : select<0>(m, left, right);
162 : 0 : } else {
163 : 0 : select<1>(m, left, right);
164 : : }
165 : 0 : sortKD(left, m - 1, (axis + 1) % 2);
166 : 0 : sortKD(m + 1, right, (axis + 1) % 2);
167 : 0 : }
168 : :
169 : : template <std::uint8_t I>
170 : 0 : void select(const TIndex k, TIndex left, TIndex right) {
171 : :
172 : 0 : while (right > left) {
173 : 0 : if (right - left > 600) {
174 : 0 : const double n = static_cast<double>(right - left + 1);
175 : 0 : const double m = static_cast<double>(k - left + 1);
176 : 0 : const double z = std::log(n);
177 : 0 : const double s = 0.5 * std::exp(2 * z / 3);
178 : 0 : const double r =
179 : 0 : k - m * s / n + 0.5 * std::sqrt(z * s * (1 - s / n)) * (2 * m < n ? -1 : 1);
180 : 0 : select<I>(k, std::max(left, TIndex(r)), std::min(right, TIndex(r + s)));
181 : 0 : }
182 : :
183 : 0 : const TNumber t = std::get<I>(points[k].coords);
184 : 0 : TIndex i = left;
185 : 0 : TIndex j = right;
186 : :
187 : 0 : swapItem(left, k);
188 : 0 : if (std::get<I>(points[right].coords) > t) swapItem(left, right);
189 : :
190 : 0 : while (i < j) {
191 : 0 : swapItem(i++, j--);
192 : 0 : while (std::get<I>(points[i].coords) < t) i++;
193 : 0 : while (std::get<I>(points[j].coords) > t) j--;
194 : : }
195 : :
196 : 0 : if (std::get<I>(points[left].coords) == t)
197 : 0 : swapItem(left, j);
198 : : else {
199 : 0 : swapItem(++j, right);
200 : : }
201 : :
202 : 0 : if (j <= k) left = j + 1;
203 : 0 : if (k <= j) right = j - 1;
204 : : }
205 : 0 : }
206 : :
207 : 0 : void swapItem(const TIndex i, const TIndex j) {
208 : : // std::iter_swap(ids.begin() + static_cast<std::int32_t>(i), ids.begin() + static_cast<std::int32_t>(j));
209 : 0 : std::iter_swap(points.begin() + static_cast<std::int32_t>(i), points.begin() + static_cast<std::int32_t>(j));
210 : 0 : }
211 : :
212 : 0 : TNumber sqDist(const TNumber ax, const TNumber ay, const TNumber bx, const TNumber by) const {
213 : 0 : auto dx = ax - bx;
214 : 0 : auto dy = ay - by;
215 : 0 : return dx * dx + dy * dy;
216 : : }
217 : : };
218 : :
219 : : } // namespace kdbush
|