From 920deb13e70cba18148a0298273a51b3d5bdba62 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Tue, 22 Jul 2025 16:16:04 -0400 Subject: [PATCH 1/3] Use nanobind to compile RTree wrapper [WIP] --- src/spatial_graph/_rtree/rtree.py | 76 +++-- src/spatial_graph/_rtree/src/rtree.c | 10 +- src/spatial_graph/_rtree/wrapper_template.cpp | 318 ++++++++++++++++++ 3 files changed, 364 insertions(+), 40 deletions(-) create mode 100644 src/spatial_graph/_rtree/wrapper_template.cpp diff --git a/src/spatial_graph/_rtree/rtree.py b/src/spatial_graph/_rtree/rtree.py index 72182db..178113b 100644 --- a/src/spatial_graph/_rtree/rtree.py +++ b/src/spatial_graph/_rtree/rtree.py @@ -15,6 +15,45 @@ EXTRA_COMPILE_ARGS = ["-O3", "-Wno-unreachable-code"] +def create_wrapper(cls, item_dtype, coord_dtype, dims): + item_dtype = DType(item_dtype) + coord_dtype = DType(coord_dtype) + + ############################################ + # create wrapper from template and compile # + ############################################ + + src_dir = Path(__file__).parent + wrapper_template = Template( + file=str(src_dir / "wrapper_template.cpp"), + compilerSettings={"directiveStartToken": "%"}, + ) + wrapper_template.item_dtype = item_dtype + wrapper_template.coord_dtype = coord_dtype + wrapper_template.dims = dims + wrapper_template.c_distance_function = cls.c_distance_function + wrapper_template.pyx_item_t_declaration = cls.pyx_item_t_declaration + wrapper_template.c_item_t_declaration = cls.c_item_t_declaration + wrapper_template.c_converter_functions = cls.c_converter_functions + wrapper_template.c_equal_function = cls.c_equal_function + + wrapper = witty.compile_nanobind( + str(wrapper_template), + source_files=[ + src_dir / "src" / "rtree.h", + src_dir / "src" / "rtree.c", + src_dir / "src" / "config.h", + ], + extra_compile_args=EXTRA_COMPILE_ARGS, + include_dirs=[str(src_dir)], + language="c++", + quiet=False, # quiet=True, + define_macros=DEFINE_MACROS, + ) + + return wrapper + + class RTree: """A generic RTree implementation, compiled on-the-fly during instantiation. @@ -176,41 +215,8 @@ def __new__( coord_dtype, dims, ): - item_dtype = DType(item_dtype) - coord_dtype = DType(coord_dtype) - - ############################################ - # create wrapper from template and compile # - ############################################ - - src_dir = Path(__file__).parent - wrapper_template = Template( - file=str(src_dir / "wrapper_template.pyx"), - compilerSettings={"directiveStartToken": "%"}, - ) - wrapper_template.item_dtype = item_dtype - wrapper_template.coord_dtype = coord_dtype - wrapper_template.dims = dims - wrapper_template.c_distance_function = cls.c_distance_function - wrapper_template.pyx_item_t_declaration = cls.pyx_item_t_declaration - wrapper_template.c_item_t_declaration = cls.c_item_t_declaration - wrapper_template.c_converter_functions = cls.c_converter_functions - wrapper_template.c_equal_function = cls.c_equal_function - - wrapper = witty.compile_module( - str(wrapper_template), - source_files=[ - src_dir / "src" / "rtree.h", - src_dir / "src" / "rtree.c", - src_dir / "src" / "config.h", - ], - extra_compile_args=EXTRA_COMPILE_ARGS, - include_dirs=[str(src_dir)], - language="c", - quiet=True, - define_macros=DEFINE_MACROS, - ) - RTreeType = type(cls.__name__, (cls, wrapper.RTree), {}) + wrapper = create_wrapper(cls, item_dtype, coord_dtype, dims) + RTreeType = type(cls.__name__, (wrapper.RTree,), cls.__dict__.copy()) return wrapper.RTree.__new__(RTreeType) def __init__(self, item_dtype, coord_dtype, dims): diff --git a/src/spatial_graph/_rtree/src/rtree.c b/src/spatial_graph/_rtree/src/rtree.c index 852004e..b858bac 100644 --- a/src/spatial_graph/_rtree/src/rtree.c +++ b/src/spatial_graph/_rtree/src/rtree.c @@ -157,7 +157,7 @@ void heapify_down(struct priority_queue* queue, size_t index) { bool enqueue(struct priority_queue* queue, struct element element) { if (queue->size == queue->capacity) { queue->capacity *= 2; - queue->elements = realloc(queue->elements, sizeof(struct element) * queue->capacity); + queue->elements = (struct element*)realloc(queue->elements, sizeof(struct element) * queue->capacity); if (!queue->elements) return false; } @@ -175,7 +175,7 @@ struct element dequeue(struct priority_queue* queue) { // reclaim some memory when the queue is shrinking if (queue->size < queue->capacity/4) { queue->capacity /= 2; - struct element *elements = realloc(queue->elements, sizeof(struct element) * queue->capacity); + struct element *elements = (struct element*)realloc(queue->elements, sizeof(struct element) * queue->capacity); if (!elements) { queue->capacity *= 2; } else { @@ -908,7 +908,6 @@ static bool node_delete(struct rtree *tr, struct rect *nr, struct node *node, if (!rect_contains(&node->rects[h], ir)) { continue; } - struct rect crect = node->rects[h]; cow_node_or(node->nodes[h], return false); if (!node_delete(tr, &node->rects[h], node->nodes[h], ir, item, depth+1, removed, shrunk, compare, udata)) @@ -919,6 +918,7 @@ static bool node_delete(struct rtree *tr, struct rect *nr, struct node *node, continue; } removed: + struct rect crect = node->rects[h]; if (node->nodes[h]->count == 0) { // underflow node_free(tr, node->nodes[h]); @@ -995,7 +995,7 @@ int rtree_delete(struct rtree *tr, const coord_t *min, const coord_t *max, return rtree_delete0(tr, min, max, item, NULL, NULL); } -int rtree_delete_with_comparator(struct rtree *tr, const coord_t *min, +bool rtree_delete_with_comparator(struct rtree *tr, const coord_t *min, const coord_t *max, const item_t item, int (*compare)(const item_t a, const item_t b, void *udata), void *udata) @@ -1005,7 +1005,7 @@ int rtree_delete_with_comparator(struct rtree *tr, const coord_t *min, struct rtree *rtree_clone(struct rtree *tr) { if (!tr) return NULL; - struct rtree *tr2 = tr->malloc(sizeof(struct rtree)); + struct rtree *tr2 = (struct rtree*)tr->malloc(sizeof(struct rtree)); if (!tr2) return NULL; memcpy(tr2, tr, sizeof(struct rtree)); if (tr2->root) rc_fetch_add(&tr2->root->rc, 1); diff --git a/src/spatial_graph/_rtree/wrapper_template.cpp b/src/spatial_graph/_rtree/wrapper_template.cpp new file mode 100644 index 0000000..9ead7b9 --- /dev/null +++ b/src/spatial_graph/_rtree/wrapper_template.cpp @@ -0,0 +1,318 @@ +#include +#include + +#define DIMS $dims +%if $c_distance_function +#define KNN_USE_EXACT_DISTANCE +%end if + +namespace nb = nanobind; + +typedef $coord_dtype.base_c_type coord_t; +typedef $item_dtype.base_c_type item_base_t; + +extern "C"{ + + + %if $item_dtype.is_array + typedef item_base_t pyx_item_t[$item_dtype.size]; + %else + typedef item_base_t pyx_item_t; + %end if + typedef pyx_item_t* pyx_items_t; + + %if $c_item_t_declaration + $c_item_t_declaration + %else + %if $item_dtype.is_array + typedef struct item_t { + item_base_t data[$item_dtype.size]; + } item_t; + %else + typedef item_base_t item_t; + %end if + %end if + + %if $c_equal_function + $c_equal_function + %else + inline bool equal(const item_t a, const item_t b) { + %if $item_dtype.is_array + return memcmp(&a, &b, sizeof(item_t)); + %else + return a == b; + %end if + } + %end if + + #include "src/rtree.h" + #include "src/rtree.c" + + %if $c_converter_functions + $c_converter_functions + %else + %if $item_dtype.is_array + inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { + item_t c_item; + memcpy(&c_item, *pyx_item, sizeof(item_t)); + return c_item; + } + inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { + memcpy(pyx_item, &c_item, sizeof(item_t)); + } + %else + // default PYX<->C converters, just casting + inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { + return (item_t)*pyx_item; + } + inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { + memcpy(pyx_item, &c_item, sizeof(item_t)); + } + %end if + %end if + + %if $c_distance_function + $c_distance_function + %end if + +} // extern "C" + +/************ + * TYPEDEFS * + ************/ + +using ItemsArray = nb::ndarray< + nb::numpy, + item_base_t, + nb::shape<-1>, + nb::c_contig +>; +using ItemsArrayObject = nb::detail::ndarray_object< + nb::numpy, + item_base_t, + nb::shape<-1>, + nb::c_contig +>; + +/************** + * CONVERTERS * + **************/ + +// default implementation for scalar item_t +inline item_t create_rtree_item(item_base_t *item, coord_t *min, coord_t *max) { + return (item_t)*item; +} + +class RTree { + +public: + + RTree() { + _rtree = rtree_new(); + } + + ~RTree() { + rtree_free(_rtree); + } + + void insert_point_items( + nb::ndarray, nb::c_contig> items, + nb::ndarray, nb::c_contig> points) { + + for (size_t i = 0; i < items.size(); i++) { + rtree_insert( + _rtree, + &points(i, 0), + NULL, + create_rtree_item(&items(i), &points(i, 0), NULL) + ); + } + } + + void insert_bb_items( + nb::ndarray, nb::c_contig> items, + nb::ndarray, nb::c_contig> bb_mins, + nb::ndarray, nb::c_contig> bb_maxs) { + + for (size_t i = 0; i < items.size(); i++) { + rtree_insert( + _rtree, + &bb_mins(i, 0), + &bb_maxs(i, 0), + create_rtree_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + ); + } + } + + size_t count( + nb::ndarray, nb::c_contig> bb_min, + nb::ndarray, nb::c_contig> bb_max) { + + auto count_iterator = []( + const coord_t* bb_min, + const coord_t* bb_max, + const item_t item, + void* udata) { + + size_t* count = (size_t*)udata; + *count += 1; + return true; + }; + + size_t num = 0; + rtree_search( + _rtree, + bb_min.data(), + bb_max.data(), + count_iterator, + &num); + + return num; + } + + nb::tuple bounding_box() { + + coord_t* bb_min = new coord_t[DIMS]; + coord_t* bb_max = new coord_t[DIMS]; + rtree_bb(_rtree, bb_min, bb_max); + + nb::capsule bb_min_owner(bb_min, [](void* p) noexcept { + delete[] (coord_t*)p; + }); + nb::capsule bb_max_owner(bb_max, [](void* p) noexcept { + delete[] (coord_t*)p; + }); + + return nb::make_tuple( + nb::ndarray>( + bb_min, + { DIMS }, + bb_min_owner), + nb::ndarray>( + bb_max, + { DIMS }, + bb_max_owner) + ); + } + + typedef typename std::vector Items; + typedef typename std::vector Distances; + + ItemsArrayObject search( + nb::ndarray, nb::c_contig> bb_min, + nb::ndarray, nb::c_contig> bb_max) { + + Items results; + auto search_iterator = []( + const coord_t *bb_min, + const coord_t *bb_max, + const item_t item, + void* results) { + static_cast(results)->push_back(item); + return true; + }; + + rtree_search( + _rtree, + bb_min.data(), + bb_max.data(), + search_iterator, + &results); + + return ItemsArray(results.data(), { results.size() }).cast(); + } + + ItemsArrayObject nearest( + nb::ndarray, nb::c_contig> point, + size_t k, + bool return_distances) { + + struct Results { + Items items; + Distances distances; + size_t k; + bool return_distances; + }; + Results results; + results.k = k; + results.return_distances = return_distances; + auto nearest_iterator = []( + const item_t item, + coord_t distance, + void* results) { + Results* r = static_cast(results); + r->items.push_back(item); + if (r->return_distances) + r->distances.push_back(distance); + return r->items.size() < r->k; + }; + + bool all_good = rtree_nearest( + _rtree, + point.data(), + nearest_iterator, + &results); + + // TODO + //if not all_good: + //raise RuntimeError("RTree nearest neighbor search ran out of memory.") + + if (return_distances) { + // TODO + //return items[:results.size], distances[:results.size] + } else { + return ItemsArray(results.items.data(), { results.items.size() }).cast(); + } + } + + size_t delete_items( + ItemsArray items, + nb::ndarray, nb::c_contig> bb_mins, + nb::ndarray, nb::c_contig> bb_maxs) { + + // TODO + //if bb_maxs is None: + //bb_maxs = bb_mins + + //cdef pyx_items_t pyx_items = memview_to_pyx_items_t(items) + + size_t total_deleted = 0; + for (size_t i = 0; i < items.size(); i++) { + size_t num_deleted = rtree_delete( + _rtree, + &bb_mins(i, 0), + &bb_maxs(i, 0), + create_rtree_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + ); + // TODO + //if (num_deleted == -1) + //raise RuntimeError("RTree delete ran out of memory.") + total_deleted += num_deleted; + } + + return total_deleted; + } + + size_t __len__() { + + return rtree_count(_rtree); + } + +private: + + rtree* _rtree; +}; + +NB_MODULE(rtree, m) { + nb::class_(m, "RTree") + .def(nb::init<>()) + .def("insert_point_items", &RTree::insert_point_items) + .def("insert_bb_items", &RTree::insert_bb_items) + .def("delete_items", &RTree::delete_items) + .def("bounding_box", &RTree::bounding_box) + .def("count", &RTree::count) + .def("search", &RTree::search) + .def("nearest", &RTree::nearest) + .def("__len__", &RTree::__len__) + ; +} From 853f7cecbca2bf262aac137e1c2104ba408ecac6 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 25 Jul 2025 11:03:09 -0400 Subject: [PATCH 2/3] Convert rtree wrapper to plain C++ and use nanobind --- pyproject.toml | 7 +- src/spatial_graph/_graph/graph_base.py | 2 +- src/spatial_graph/_rtree/line_rtree.py | 56 +-- src/spatial_graph/_rtree/rtree.py | 9 +- src/spatial_graph/_rtree/wrapper_template.cpp | 405 ++++++++++++------ src/spatial_graph/_rtree/wrapper_template.pyx | 356 --------------- 6 files changed, 322 insertions(+), 513 deletions(-) delete mode 100644 src/spatial_graph/_rtree/wrapper_template.pyx diff --git a/pyproject.toml b/pyproject.toml index f00ed94..5a163ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,12 @@ classifiers = [ "Programming Language :: Python :: 3.13", "Typing :: Typed", ] -dependencies = ["witty>=v0.2.1", "CT3>=3.3.3", "numpy", "setuptools>=75.8.0"] +dependencies = [ + "witty>=0.3.1", + "CT3>=3.3.3", + "numpy", + "setuptools>=75.8.0", +] [dependency-groups] test = ["pytest>=8.3.5", "pytest-cov>=6.1.1"] diff --git a/src/spatial_graph/_graph/graph_base.py b/src/spatial_graph/_graph/graph_base.py index c9e63ba..9dd44c6 100644 --- a/src/spatial_graph/_graph/graph_base.py +++ b/src/spatial_graph/_graph/graph_base.py @@ -76,7 +76,7 @@ def _compile_graph( edge_attr_dtypes=edge_attr_dtypes, directed=directed, ) - wrapper = witty.compile_module( + wrapper = witty.compile_cython( wrapper_template, source_files=[str(SRC_DIR / "src" / "graph_lite.h")], extra_compile_args=EXTRA_COMPILE_ARGS, diff --git a/src/spatial_graph/_rtree/line_rtree.py b/src/spatial_graph/_rtree/line_rtree.py index 90bdcbd..eb4e850 100644 --- a/src/spatial_graph/_rtree/line_rtree.py +++ b/src/spatial_graph/_rtree/line_rtree.py @@ -4,43 +4,39 @@ class LineRTree(RTree): - pyx_item_t_declaration = """ - cdef struct item_t: - item_base_t u - item_base_t v - bool corner_mask[DIMS] -""" - c_item_t_declaration = """ typedef struct item_t { - item_base_t u; - item_base_t v; + item_data_base_t u; + item_data_base_t v; bool corner_mask[DIMS]; } item_t; """ c_converter_functions = """ -inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, - coord_t *start, coord_t *end) { +inline void item_to_item_data( + const item_t& item, + item_data_t *item_data) { + + (*item_data)[0] = item.u; + (*item_data)[1] = item.v; +} +inline item_t item_data_to_item( + item_data_base_t *item_data, + coord_t *start, + coord_t *end) { + item_t item; - coord_t tmp; - item.u = (*pyx_item)[0]; - item.v = (*pyx_item)[1]; - for (int d = 0; d < DIMS; d++) { + item.u = item_data[0]; + item.v = item_data[1]; + for (unsigned int d = 0; d < DIMS; d++) { item.corner_mask[d] = (start[d] < end[d]); if (!item.corner_mask[d]) { // swap coordinates to create bounding box - tmp = start[d]; - start[d] = end[d]; - end[d] = tmp; + std::swap(start[d], end[d]); } } return item; } -inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { - (*pyx_item)[0] = c_item.u; - (*pyx_item)[1] = c_item.v; -} """ c_equal_function = """ @@ -51,6 +47,7 @@ class LineRTree(RTree): c_distance_function = """ inline coord_t length2(const coord_t x[]) { + coord_t length2 = 0; for (int d = 0; d < DIMS; d++) { length2 += pow(x[d], 2); @@ -58,8 +55,10 @@ class LineRTree(RTree): return length2; } -inline coord_t point_segment_dist2(const coord_t point[], const coord_t start[], - const coord_t end[]) { +inline coord_t point_segment_dist2( + const coord_t point[], + const coord_t start[], + const coord_t end[]) { coord_t a[DIMS]; coord_t b[DIMS]; @@ -79,7 +78,7 @@ class LineRTree(RTree): alpha /= length2(a); // clip at 0 and 1 (beginning and end of line segment) - alpha = min0(1, max0(0, alpha)); + alpha = std::min((coord_t)1, std::max((coord_t)0, alpha)); for (int d = 0; d < DIMS; d++) { @@ -95,7 +94,10 @@ class LineRTree(RTree): } extern inline coord_t distance( - const coord_t point[], const struct rect *rect, const struct item_t item) { + const coord_t point[], + const struct rect *rect, + const struct item_t item) { + coord_t start[DIMS]; coord_t end[DIMS]; for (int d = 0; d < DIMS; d++) { @@ -107,6 +109,8 @@ class LineRTree(RTree): end[d] = rect->min[d]; } } + + return point_segment_dist2(point, start, end); } """ diff --git a/src/spatial_graph/_rtree/rtree.py b/src/spatial_graph/_rtree/rtree.py index d5c478b..6d370c2 100644 --- a/src/spatial_graph/_rtree/rtree.py +++ b/src/spatial_graph/_rtree/rtree.py @@ -161,7 +161,7 @@ def delete_item(self, item, bb_min, bb_max=None): """ items = np.array([item], dtype=self.item_dtype.base) bb_mins = bb_min[np.newaxis, :] - bb_maxs = None if bb_max is None else bb_max[np.newaxis, :] + bb_maxs = bb_mins if bb_max is None else bb_max[np.newaxis, :] return self._ctree.delete_items(items, bb_mins, bb_maxs) def delete_items(self, items, bb_mins, bb_maxs=None): @@ -183,6 +183,8 @@ def delete_items(self, items, bb_mins, bb_maxs=None): Array of shape `(n, dims)`, the minimum/maximum points of the bounding boxes per item to delete. """ + if bb_maxs is None: + bb_maxs = bb_mins return self._ctree.delete_items(items, bb_mins, bb_maxs) def count(self, bb_min, bb_max): @@ -222,7 +224,10 @@ def nearest(self, point, k=1, return_distances=False): `distances` contains the distance of each found item to the query point. """ - return self._ctree.nearest(point, k, return_distances) + if return_distances: + return self._ctree.nearest_with_distances(point, k) + else: + return self._ctree.nearest(point, k) def insert_bb_items(self, items, bb_mins, bb_maxs): """Insert items with bounding boxes. diff --git a/src/spatial_graph/_rtree/wrapper_template.cpp b/src/spatial_graph/_rtree/wrapper_template.cpp index 9ead7b9..1b81392 100644 --- a/src/spatial_graph/_rtree/wrapper_template.cpp +++ b/src/spatial_graph/_rtree/wrapper_template.cpp @@ -1,110 +1,199 @@ +#include #include #include +namespace nb = nanobind; + +// number of spatial dimensions #define DIMS $dims -%if $c_distance_function -#define KNN_USE_EXACT_DISTANCE -%end if -namespace nb = nanobind; +/********************* + * TYPE DECLARATIONS * + *********************/ +// the base type of coordinates typedef $coord_dtype.base_c_type coord_t; -typedef $item_dtype.base_c_type item_base_t; -extern "C"{ +// the base type of item data +typedef $item_dtype.base_c_type item_data_base_t; +// the external representation of an item (referred to as item_data) +%if $item_dtype.is_array + // for arrays + typedef std::array item_data_t; +%else + // for scalars + typedef item_data_base_t item_data_t; +%end if +// the internal type of an item +%if $c_item_t_declaration + // custom item_t declaration + $c_item_t_declaration +%else + // default item_t declaration %if $item_dtype.is_array - typedef item_base_t pyx_item_t[$item_dtype.size]; + // for arrays + typedef struct item_t { + item_data_base_t data[$item_dtype.size]; + } item_t; %else - typedef item_base_t pyx_item_t; + // for scalars + typedef item_data_base_t item_t; %end if - typedef pyx_item_t* pyx_items_t; +%end if - %if $c_item_t_declaration - $c_item_t_declaration - %else +// shape of arrays holding data for multiple items +%if $item_dtype.is_array + using items_data_shape = nb::shape<-1, $item_dtype.size>; +%else + using items_data_shape = nb::shape<-1>; +%end if + +/************* + * FUNCTIONS * + *************/ + +/* CONVERTERS */ + +// item_t -> item_data_t and item_data_t -> item_t +%if $c_converter_functions + // custom converter functions + $c_converter_functions +%else %if $item_dtype.is_array - typedef struct item_t { - item_base_t data[$item_dtype.size]; - } item_t; + // default converters for array item_t + inline item_t item_data_to_item( + item_data_base_t *item_data, + coord_t *min, + coord_t *max) { + + item_t item; + memcpy(&item, item_data, sizeof(item_t)); + return item; + } + inline void item_to_item_data( + const item_t& item, + item_data_t *item_data) { + + memcpy(item_data, &item, sizeof(item_t)); + } %else - typedef item_base_t item_t; - %end if + // default converters for scalar item_t + inline item_t item_data_to_item( + item_data_base_t *item_data, + coord_t *min, + coord_t *max) { + + return (item_t)*item_data; + } + inline void item_to_item_data( + const item_t& item, + item_data_t *item_data) { + + *item_data = static_cast(item); + } %end if +%end if - %if $c_equal_function + +/* COMPARISON */ + +%if $c_equal_function + // custom comparison function $c_equal_function - %else +%else + // default comparison function inline bool equal(const item_t a, const item_t b) { - %if $item_dtype.is_array - return memcmp(&a, &b, sizeof(item_t)); - %else - return a == b; - %end if + %if $item_dtype.is_array + // for arrays + return memcmp(&a, &b, sizeof(item_t)); + %else + // for scalars + return a == b; + %end if } - %end if +%end if + +/******************* + * RTREE C BACKEND * + *******************/ + +/* DISTANCE */ + +%if $c_distance_function + // if a custom distance function is used, use this instead of the default + // bounding box distance function for kNN search + #define KNN_USE_EXACT_DISTANCE +%end if + +extern "C"{ #include "src/rtree.h" #include "src/rtree.c" - %if $c_converter_functions - $c_converter_functions - %else - %if $item_dtype.is_array - inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { - item_t c_item; - memcpy(&c_item, *pyx_item, sizeof(item_t)); - return c_item; - } - inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { - memcpy(pyx_item, &c_item, sizeof(item_t)); - } - %else - // default PYX<->C converters, just casting - inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { - return (item_t)*pyx_item; - } - inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { - memcpy(pyx_item, &c_item, sizeof(item_t)); - } - %end if - %end if +} // extern "C" - %if $c_distance_function +%if $c_distance_function + // custom distance function for exact computation of distances to items $c_distance_function - %end if - -} // extern "C" +%end if /************ * TYPEDEFS * ************/ -using ItemsArray = nb::ndarray< +using ItemsVec = std::vector; +using DistancesVec = std::vector; + +using Items = nb::ndarray< nb::numpy, - item_base_t, - nb::shape<-1>, + item_data_base_t, + items_data_shape, nb::c_contig >; -using ItemsArrayObject = nb::detail::ndarray_object< +using Distances = nb::ndarray< nb::numpy, - item_base_t, + coord_t, nb::shape<-1>, nb::c_contig >; -/************** - * CONVERTERS * - **************/ +using Point = nb::ndarray< + nb::numpy, + coord_t, + nb::shape, + nb::c_contig +>; +using Points = nb::ndarray< + nb::numpy, + coord_t, + nb::shape<-1, DIMS>, + nb::c_contig +>; -// default implementation for scalar item_t -inline item_t create_rtree_item(item_base_t *item, coord_t *min, coord_t *max) { - return (item_t)*item; -} +/*************** + * RTree CLASS * + ***************/ + +// Keep a unique C++ RTree typename for each module, so that multiple RTrees can +// coexist. The name in the resulting python module will still be just RTree. +#define RTree RTree_WITTY_MODULE_HASH class RTree { +private: + + // kNN search results + struct Results { + ItemsVec items; + DistancesVec distances; + size_t k; + bool return_distances; + }; + + rtree* _rtree; + public: RTree() { @@ -116,37 +205,45 @@ class RTree { } void insert_point_items( - nb::ndarray, nb::c_contig> items, - nb::ndarray, nb::c_contig> points) { + Items items, + Points points) { - for (size_t i = 0; i < items.size(); i++) { + for (size_t i = 0; i < items.shape(0); i++) { rtree_insert( _rtree, &points(i, 0), NULL, - create_rtree_item(&items(i), &points(i, 0), NULL) + %if $item_dtype.is_array + item_data_to_item(&items(i, 0), &points(i, 0), NULL) + %else + item_data_to_item(&items(i), &points(i, 0), NULL) + %end if ); } } void insert_bb_items( - nb::ndarray, nb::c_contig> items, - nb::ndarray, nb::c_contig> bb_mins, - nb::ndarray, nb::c_contig> bb_maxs) { + Items items, + Points bb_mins, + Points bb_maxs) { - for (size_t i = 0; i < items.size(); i++) { + for (size_t i = 0; i < items.shape(0); i++) { rtree_insert( _rtree, &bb_mins(i, 0), &bb_maxs(i, 0), - create_rtree_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + %if $item_dtype.is_array + item_data_to_item(&items(i, 0), &bb_mins(i, 0), &bb_maxs(i, 0)) + %else + item_data_to_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + %end if ); } } size_t count( - nb::ndarray, nb::c_contig> bb_min, - nb::ndarray, nb::c_contig> bb_max) { + Point bb_min, + Point bb_max) { auto count_iterator = []( const coord_t* bb_min, @@ -195,20 +292,20 @@ class RTree { ); } - typedef typename std::vector Items; - typedef typename std::vector Distances; + Items search( + Point bb_min, + Point bb_max) { - ItemsArrayObject search( - nb::ndarray, nb::c_contig> bb_min, - nb::ndarray, nb::c_contig> bb_max) { - - Items results; + ItemsVec* results = new ItemsVec(); auto search_iterator = []( - const coord_t *bb_min, - const coord_t *bb_max, - const item_t item, - void* results) { - static_cast(results)->push_back(item); + const coord_t *bb_min, + const coord_t *bb_max, + const item_t item, + void* results) { + + item_data_t item_data; + item_to_item_data(item, &item_data); + static_cast(results)->push_back(item_data); return true; }; @@ -217,31 +314,35 @@ class RTree { bb_min.data(), bb_max.data(), search_iterator, - &results); + results); - return ItemsArray(results.data(), { results.size() }).cast(); + nb::capsule results_owner(results, [](void* p) noexcept { + delete (ItemsVec*)p; + }); + + %if $item_dtype.is_array + return Items(results->data(), { results->size(), $item_dtype.size }, results_owner); + %else + return Items(results->data(), { results->size() }, results_owner); + %end if } - ItemsArrayObject nearest( - nb::ndarray, nb::c_contig> point, + Results* find_nearest( + Point point, size_t k, bool return_distances) { - struct Results { - Items items; - Distances distances; - size_t k; - bool return_distances; - }; - Results results; - results.k = k; - results.return_distances = return_distances; + Results* results = new Results(); + results->k = k; + results->return_distances = return_distances; auto nearest_iterator = []( const item_t item, coord_t distance, void* results) { Results* r = static_cast(results); - r->items.push_back(item); + item_data_t item_data; + item_to_item_data(item, &item_data); + r->items.push_back(item_data); if (r->return_distances) r->distances.push_back(distance); return r->items.size() < r->k; @@ -251,42 +352,91 @@ class RTree { _rtree, point.data(), nearest_iterator, - &results); + results); - // TODO - //if not all_good: - //raise RuntimeError("RTree nearest neighbor search ran out of memory.") + if (!all_good) + throw std::bad_alloc(); - if (return_distances) { - // TODO - //return items[:results.size], distances[:results.size] - } else { - return ItemsArray(results.items.data(), { results.items.size() }).cast(); - } + return results; } - size_t delete_items( - ItemsArray items, - nb::ndarray, nb::c_contig> bb_mins, - nb::ndarray, nb::c_contig> bb_maxs) { + Items nearest( + Point point, + size_t k) { + + Results* results = find_nearest(point, k, false); + + nb::capsule results_owner(results, [](void* p) noexcept { + delete (Results*)p; + }); + + %if $item_dtype.is_array + Items items( + results->items.data(), + { results->items.size(), $item_dtype.size }, + results_owner); + %else + Items items( + results->items.data(), + { results->items.size() }, + results_owner); + %end if + + return items; + } + + nb::tuple nearest_with_distances( + Point point, + size_t k) { + + Results* results = find_nearest(point, k, true); - // TODO - //if bb_maxs is None: - //bb_maxs = bb_mins + nb::capsule results_owner(results, [](void* p) noexcept { + delete (Results*)p; + }); + + %if $item_dtype.is_array + Items items( + results->items.data(), + { results->items.size(), $item_dtype.size }, + results_owner); + %else + Items items( + results->items.data(), + { results->items.size() }, + results_owner); + %end if + + Distances distances( + results->distances.data(), + { results->distances.size() }, + results_owner); + + return nb::make_tuple(items, distances); + + } - //cdef pyx_items_t pyx_items = memview_to_pyx_items_t(items) + size_t delete_items( + Items items, + Points bb_mins, + Points bb_maxs) { size_t total_deleted = 0; - for (size_t i = 0; i < items.size(); i++) { + for (size_t i = 0; i < items.shape(0); i++) { size_t num_deleted = rtree_delete( _rtree, &bb_mins(i, 0), &bb_maxs(i, 0), - create_rtree_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + %if $item_dtype.is_array + item_data_to_item(&items(i, 0), &bb_mins(i, 0), &bb_maxs(i, 0)) + %else + item_data_to_item(&items(i), &bb_mins(i, 0), &bb_maxs(i, 0)) + %end if ); - // TODO - //if (num_deleted == -1) - //raise RuntimeError("RTree delete ran out of memory.") + + if (num_deleted < 0) + throw std::bad_alloc(); + total_deleted += num_deleted; } @@ -297,12 +447,12 @@ class RTree { return rtree_count(_rtree); } - -private: - - rtree* _rtree; }; +/************************* + * NANOBIND REGISTRATION * + *************************/ + NB_MODULE(rtree, m) { nb::class_(m, "RTree") .def(nb::init<>()) @@ -313,6 +463,7 @@ NB_MODULE(rtree, m) { .def("count", &RTree::count) .def("search", &RTree::search) .def("nearest", &RTree::nearest) + .def("nearest_with_distances", &RTree::nearest_with_distances) .def("__len__", &RTree::__len__) ; } diff --git a/src/spatial_graph/_rtree/wrapper_template.pyx b/src/spatial_graph/_rtree/wrapper_template.pyx deleted file mode 100644 index f0844b6..0000000 --- a/src/spatial_graph/_rtree/wrapper_template.pyx +++ /dev/null @@ -1,356 +0,0 @@ -from libc.stdint cimport * -import numpy as np - - -ctypedef int bool - -cdef extern from *: - """ - typedef int bool; - #define false 0 - #define true 1 - - %if $c_distance_function - #define KNN_USE_EXACT_DISTANCE - %end if - #define DIMS $dims - - typedef $coord_dtype.to_pyxtype() coord_t; - typedef $item_dtype.base_c_type item_base_t; - %if $item_dtype.is_array - typedef item_base_t pyx_item_t[$item_dtype.size]; - %else - typedef item_base_t pyx_item_t; - %end if - typedef pyx_item_t* pyx_items_t; - - %if $c_item_t_declaration - $c_item_t_declaration - %else - %if $item_dtype.is_array - typedef struct item_t { - item_base_t data[$item_dtype.size]; - } item_t; - %else - typedef item_base_t item_t; - %end if - %end if - - %if $c_equal_function - $c_equal_function - %else - inline bool equal(const item_t a, const item_t b) { - %if $item_dtype.is_array - return memcmp(&a, &b, sizeof(item_t)); - %else - return a == b; - %end if - } - %end if - - #include "src/rtree.h" - #include "src/rtree.c" - - %if $c_converter_functions - $c_converter_functions - %else - %if $item_dtype.is_array - inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { - item_t c_item; - memcpy(&c_item, *pyx_item, sizeof(item_t)); - return c_item; - } - inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { - memcpy(pyx_item, &c_item, sizeof(item_t)); - } - %else - // default PYX<->C converters, just casting - inline item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t *max) { - return (item_t)*pyx_item; - } - inline void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) { - memcpy(pyx_item, &c_item, sizeof(item_t)); - } - %end if - %end if - - %if $c_distance_function - $c_distance_function - %end if - """ - cdef enum: - DIMS = $dims - ctypedef $coord_dtype.to_pyxtype() coord_t - ctypedef $item_dtype.base_c_type item_base_t - %if $item_dtype.is_array - ctypedef item_base_t pyx_item_t[$item_dtype.size] - %else - ctypedef item_base_t pyx_item_t - %end if - ctypedef pyx_item_t* pyx_items_t - - %if $pyx_item_t_declaration - $pyx_item_t_declaration - %else - %if $item_dtype.is_array - # item_t can't be an array in rtree, arrays can't be assigned to (and this - # is needed inside rtree). So we make item_t a struct with field `data` to - # hold the array. - cdef struct item_t: - item_base_t data[$item_dtype.size] - %else - ctypedef item_base_t item_t - %end if - %end if - - # PYX <-> C converters - cdef item_t convert_pyx_to_c_item(pyx_item_t *pyx_item, coord_t *min, coord_t* max) - cdef void copy_c_to_pyx_item(const item_t c_item, pyx_item_t *pyx_item) - - # rtree API - cdef struct rtree - cdef rtree *rtree_new() - cdef void rtree_free(rtree *tr) - cdef bool rtree_insert( - rtree *tr, - const coord_t *min, - const coord_t *max, - const item_t item) - cdef void rtree_search( - const rtree *tr, - const coord_t *min, - const coord_t *max, - bool (*iter)( - const coord_t *min, - const coord_t *max, - const item_t item, - void *udata), - void *udata) - cdef bool rtree_nearest( - rtree *tr, - const coord_t *point, - bool (*iter)( - const item_t item, - coord_t distance, - void *udata), - void *udata) - cdef int rtree_delete( - rtree *tr, - const coord_t *min, - const coord_t *max, - const item_t item) - cdef size_t rtree_count(const rtree *tr) - cdef void rtree_bb(const rtree *tr, coord_t *min, coord_t *max) - - -cdef pyx_items_t memview_to_pyx_items_t($item_dtype.to_pyxtype(add_dim=True) items): - # implementation depends on dimension of item - %if $item_dtype.is_array - return &items[0, 0] - %else - return &items[0] - %end if - - -cdef bint count_iterator( - const coord_t* bb_min, - const coord_t* bb_max, - const item_t item, - void* udata - ) noexcept: - - cdef size_t* count = udata - count[0] = count[0] + 1 - return True - - -cdef struct search_results: - size_t size - pyx_items_t items - - -cdef init_search_results_from_memview(search_results* r, $item_dtype.to_pyxtype(add_dim=True) items): - r.size = 0 - r.items = memview_to_pyx_items_t(items) - - -cdef bint search_iterator( - const coord_t* bb_min, - const coord_t* bb_max, - const item_t item, - void* udata - ) noexcept: - - cdef search_results* results = udata - copy_c_to_pyx_item(item, &results.items[results.size]) - results.size += 1 - return True - - -cdef struct nearest_results: - size_t size - size_t max_size - pyx_items_t items - coord_t *distances - - -cdef init_nearest_results_from_memview(nearest_results* r, - $item_dtype.to_pyxtype(add_dim=True) items, - coord_t[::1] distances): - r.size = 0 - r.max_size = len(items) - r.items = memview_to_pyx_items_t(items) - r.distances = &distances[0] if distances is not None else NULL - - -cdef bint nearest_iterator( - const item_t item, - coord_t distance, - void* udata - ) noexcept: - - cdef nearest_results* results = udata - copy_c_to_pyx_item(item, &results.items[results.size]) - if results.distances != NULL: - results.distances[results.size] = distance - results.size += 1 - return results.size < results.max_size - - -cdef class RTree: - - cdef rtree* _rtree - - def __cinit__(self): - self._rtree = rtree_new() - - def __dealloc__(self): - rtree_free(self._rtree) - - def insert_point_items( - self, - $item_dtype.to_pyxtype(add_dim=True) items, - coord_t[:, ::1] points - ): - - cdef pyx_items_t pyx_items = memview_to_pyx_items_t(items) - - for i in range(len(items)): - rtree_insert( - self._rtree, - &points[i, 0], - NULL, - convert_pyx_to_c_item(&pyx_items[i], &points[i, 0], NULL)) - - def insert_bb_items( - self, - $item_dtype.to_pyxtype(add_dim=True) items, - coord_t[:, ::1] bb_mins, - coord_t[:, ::1] bb_maxs - ): - - cdef pyx_items_t pyx_items = memview_to_pyx_items_t(items) - - for i in range(len(items)): - rtree_insert( - self._rtree, - &bb_mins[i, 0], - &bb_maxs[i, 0], - convert_pyx_to_c_item(&pyx_items[i], &bb_mins[i, 0], &bb_maxs[i, 0])) - - def count(self, coord_t[::1] bb_min, coord_t[::1] bb_max): - - cdef size_t num = 0 - rtree_search( - self._rtree, - &bb_min[0], - &bb_max[0], - &count_iterator, - &num) - - return num - - def bounding_box(self): - bb_min = np.empty(($dims,), dtype="$coord_dtype.base") - bb_max = np.empty(($dims,), dtype="$coord_dtype.base") - cdef coord_t[::1] _bb_min = bb_min - cdef coord_t[::1] _bb_max = bb_max - rtree_bb(self._rtree, &_bb_min[0], &_bb_max[0]) - return (bb_min, bb_max) - - def search(self, coord_t[::1] bb_min, coord_t[::1] bb_max): - - cdef search_results results - cdef size_t num_results = self.count(bb_min, bb_max) - - items = np.zeros((num_results, $item_dtype.size), dtype="$item_dtype.base") - if num_results == 0: - return items - init_search_results_from_memview(&results, items) - - rtree_search( - self._rtree, - &bb_min[0], - &bb_max[0], - &search_iterator, - &results) - - return items - - def nearest(self, coord_t[::1] point, size_t k, return_distances=False): - - cdef nearest_results results - - items = np.zeros((k, $item_dtype.size), dtype="$item_dtype.base") - if return_distances: - distances = np.zeros((k,), dtype="$coord_dtype.base") - else: - distances = None - if k == 0: - return items - init_nearest_results_from_memview(&results, items, distances) - - all_good = rtree_nearest( - self._rtree, - &point[0], - &nearest_iterator, - &results) - - if not all_good: - raise RuntimeError("RTree nearest neighbor search ran out of memory.") - - if return_distances: - return items[:results.size], distances[:results.size] - else: - return items[:results.size] - - def delete_items( - self, - $item_dtype.to_pyxtype(add_dim=True) items, - coord_t[:, ::1] bb_mins, - coord_t[:, ::1] bb_maxs=None - ): - - if bb_maxs is None: - bb_maxs = bb_mins - - cdef pyx_items_t pyx_items = memview_to_pyx_items_t(items) - - total_deleted = 0 - for i in range(len(items)): - num_deleted = rtree_delete( - self._rtree, - &bb_mins[i, 0], - &bb_maxs[i, 0], - convert_pyx_to_c_item(&pyx_items[i], &bb_mins[i, 0], &bb_maxs[i, 0])) - if num_deleted == -1: - raise RuntimeError("RTree delete ran out of memory.") - # if num_deleted == 0: - # print(f"Item {pyx_items[i]} not deleted!") - total_deleted += num_deleted - - return total_deleted - - def __len__(self): - - return rtree_count(self._rtree) From 2d468785dbc9666c7db60385e7760b1b0e811b44 Mon Sep 17 00:00:00 2001 From: Jan Funke Date: Fri, 25 Jul 2025 11:36:21 -0400 Subject: [PATCH 3/3] Fix bug in test --- tests/test_rtree.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rtree.py b/tests/test_rtree.py index 10282ed..579c0d7 100644 --- a/tests/test_rtree.py +++ b/tests/test_rtree.py @@ -60,7 +60,7 @@ def test_nearest(): np.arange(10_000_000, dtype="uint64"), all_points, ) - points = rtree.nearest(np.array([0.5, 0.5]), k=100_000) + points = rtree.nearest(np.array([0.5, 0.5, 0.5]), k=100_000) assert len(points) == 100_000 # ensure that we find the right item in a big tree