From 485ded7b89b54bd666ca775af14c0a575bfdbac0 Mon Sep 17 00:00:00 2001 From: Vincent M Date: Thu, 17 Nov 2022 08:42:23 +0100 Subject: [PATCH 1/2] _kd_tree has joined the party --- setup.py | 2 + sklearn/neighbors/_binary_tree.pxi | 64 +++++++++++++++++------------- 2 files changed, 38 insertions(+), 28 deletions(-) diff --git a/setup.py b/setup.py index 13f44982dcfe1..d94edbe923eb5 100755 --- a/setup.py +++ b/setup.py @@ -98,6 +98,8 @@ "sklearn.metrics._pairwise_distances_reduction._argkmin", "sklearn.metrics._pairwise_distances_reduction._radius_neighbors", "sklearn.metrics._pairwise_fast", + "sklearn.neighbors._ball_tree", + "sklearn.neighbors._kd_tree", "sklearn.neighbors._partition_nodes", "sklearn.tree._splitter", "sklearn.tree._utils", diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 2b9ac839945cf..626670e74c7a7 100644 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -166,6 +166,7 @@ from ..utils._typedefs import DTYPE, ITYPE from ..utils._heap cimport heap_push from ..utils._sorting cimport simultaneous_sort as _simultaneous_sort +# TODO: use cnp.PyArray_ENABLEFLAGS when cython >= 3.0 is used. cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(cnp.ndarray arr, int flags) @@ -511,8 +512,8 @@ cdef class NeighborsHeap: n_nbrs : int the size of each heap. """ - cdef cnp.ndarray distances_arr - cdef cnp.ndarray indices_arr + cdef DTYPE_t[:, ::1] distances_arr + cdef ITYPE_t[:, ::1] indices_arr cdef DTYPE_t[:, ::1] distances cdef ITYPE_t[:, ::1] indices @@ -538,7 +539,7 @@ cdef class NeighborsHeap: """ if sort: self._sort() - return self.distances_arr, self.indices_arr + return self.distances_arr.base, self.indices_arr.base cdef inline DTYPE_t largest(self, ITYPE_t row) nogil except -1: """Return the largest distance in the given row""" @@ -643,8 +644,8 @@ cdef class NodeHeap: heap[i].val < min(heap[2 * i + 1].val, heap[2 * i + 2].val) """ - cdef cnp.ndarray data_arr - cdef NodeHeapData_t[::1] data + cdef NodeHeapData_t[:] data_arr + cdef NodeHeapData_t[:] data cdef ITYPE_t n def __cinit__(self): @@ -660,13 +661,16 @@ cdef class NodeHeap: cdef int resize(self, ITYPE_t new_size) except -1: """Resize the heap to be either larger or smaller""" - cdef NodeHeapData_t *data_ptr - cdef NodeHeapData_t *new_data_ptr - cdef ITYPE_t i - cdef ITYPE_t size = self.data.shape[0] - cdef cnp.ndarray new_data_arr = np.zeros(new_size, - dtype=NodeHeapData) - cdef NodeHeapData_t[::1] new_data = new_data_arr + cdef: + NodeHeapData_t *data_ptr + NodeHeapData_t *new_data_ptr + ITYPE_t i + ITYPE_t size = self.data.shape[0] + NodeHeapData_t[:] new_data_arr = np.zeros( + new_size, + dtype=NodeHeapData, + ) + NodeHeapData_t[:] new_data = new_data_arr if size > 0 and new_size > 0: data_ptr = &self.data[0] @@ -769,11 +773,11 @@ VALID_METRIC_IDS = get_valid_metric_ids(VALID_METRICS) # Binary Tree class cdef class BinaryTree: - cdef cnp.ndarray data_arr - cdef cnp.ndarray sample_weight_arr - cdef cnp.ndarray idx_array_arr - cdef cnp.ndarray node_data_arr - cdef cnp.ndarray node_bounds_arr + cdef const DTYPE_t[:, ::1] data_arr + cdef const DTYPE_t[::1] sample_weight_arr + cdef const ITYPE_t[::1] idx_array_arr + cdef const NodeData_t[::1] node_data_arr + cdef const DTYPE_t[:, :, ::1] node_bounds_arr cdef readonly const DTYPE_t[:, ::1] data cdef readonly const DTYPE_t[::1] sample_weight @@ -869,7 +873,7 @@ cdef class BinaryTree: # Allocate tree-specific data allocate_data(self, self.n_nodes, n_features) self._recursive_build( - node_data=self.node_data_arr, + node_data=self.node_data_arr.base, i_node=0, idx_start=0, idx_end=n_samples @@ -905,15 +909,15 @@ cdef class BinaryTree: """ if self.sample_weight is not None: # pass the numpy array - sample_weight_arr = self.sample_weight_arr + sample_weight_arr = self.sample_weight_arr.base else: # pass None to avoid confusion with the empty place holder # of size 1 from __cinit__ sample_weight_arr = None - return (self.data_arr, - self.idx_array_arr, - self.node_data_arr, - self.node_bounds_arr, + return (self.data_arr.base, + self.idx_array_arr.base, + self.node_data_arr.base, + self.node_bounds_arr.base, int(self.leaf_size), int(self.n_levels), int(self.n_nodes), @@ -993,8 +997,12 @@ cdef class BinaryTree: arrays: tuple of array Arrays for storing tree data, index, node data and node bounds. """ - return (self.data_arr, self.idx_array_arr, - self.node_data_arr, self.node_bounds_arr) + return ( + self.data_arr.base, + self.idx_array_arr.base, + self.node_data_arr.base, + self.node_bounds_arr.base, + ) cdef inline DTYPE_t dist(self, DTYPE_t* x1, DTYPE_t* x2, ITYPE_t size) nogil except -1: @@ -1340,14 +1348,14 @@ cdef class BinaryTree: # make a new numpy array that wraps the existing data indices_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_INTP, indices[i]) # make sure the data will be freed when the numpy array is garbage collected - PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_OWNDATA) + PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_ARRAY_OWNDATA) # make sure the data is not freed twice indices[i] = NULL # make a new numpy array that wraps the existing data distances_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_DOUBLE, distances[i]) # make sure the data will be freed when the numpy array is garbage collected - PyArray_ENABLEFLAGS(distances_npy[i], cnp.NPY_OWNDATA) + PyArray_ENABLEFLAGS(distances_npy[i], cnp.NPY_ARRAY_OWNDATA) # make sure the data is not freed twice distances[i] = NULL @@ -1360,7 +1368,7 @@ cdef class BinaryTree: # make a new numpy array that wraps the existing data indices_npy[i] = cnp.PyArray_SimpleNewFromData(1, &counts[i], cnp.NPY_INTP, indices[i]) # make sure the data will be freed when the numpy array is garbage collected - PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_OWNDATA) + PyArray_ENABLEFLAGS(indices_npy[i], cnp.NPY_ARRAY_OWNDATA) # make sure the data is not freed twice indices[i] = NULL From ae5648cfb929a285541f47aa3081833d9b435f5d Mon Sep 17 00:00:00 2001 From: Vincent M Date: Fri, 2 Dec 2022 16:32:41 +0100 Subject: [PATCH 2/2] Apply Jepherson suggestion --- sklearn/neighbors/_binary_tree.pxi | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/neighbors/_binary_tree.pxi b/sklearn/neighbors/_binary_tree.pxi index 626670e74c7a7..24234461bc9d7 100644 --- a/sklearn/neighbors/_binary_tree.pxi +++ b/sklearn/neighbors/_binary_tree.pxi @@ -166,7 +166,7 @@ from ..utils._typedefs import DTYPE, ITYPE from ..utils._heap cimport heap_push from ..utils._sorting cimport simultaneous_sort as _simultaneous_sort -# TODO: use cnp.PyArray_ENABLEFLAGS when cython >= 3.0 is used. +# TODO: use cnp.PyArray_ENABLEFLAGS when Cython>=3.0 is used. cdef extern from "numpy/arrayobject.h": void PyArray_ENABLEFLAGS(cnp.ndarray arr, int flags)