diff --git a/sklearn/cluster/_hdbscan/_tree.pyx b/sklearn/cluster/_hdbscan/_tree.pyx index 0e493f28379eb..6e4df6cf12592 100644 --- a/sklearn/cluster/_hdbscan/_tree.pyx +++ b/sklearn/cluster/_hdbscan/_tree.pyx @@ -282,40 +282,37 @@ cdef max_lambdas(cnp.ndarray hierarchy): return deaths -cdef class TreeUnionFind (object): +@cython.final +cdef class TreeUnionFind: - cdef cnp.ndarray _data_arr - cdef cnp.intp_t[:, ::1] _data - cdef cnp.ndarray is_component + cdef cnp.intp_t[:, ::1] data + cdef cnp.uint8_t[::1] is_component def __init__(self, size): - self._data_arr = np.zeros((size, 2), dtype=np.intp) - self._data_arr.T[0] = np.arange(size) - self._data = self._data_arr - self.is_component = np.ones(size, dtype=bool) + cdef cnp.intp_t idx + self.data = np.zeros((size, 2), dtype=np.intp) + for idx in range(size): + self.data[idx, 0] = idx + self.is_component = np.ones(size, dtype=np.uint8) - cdef union_(self, cnp.intp_t x, cnp.intp_t y): + cdef void union(self, cnp.intp_t x, cnp.intp_t y): cdef cnp.intp_t x_root = self.find(x) cdef cnp.intp_t y_root = self.find(y) - if self._data[x_root, 1] < self._data[y_root, 1]: - self._data[x_root, 0] = y_root - elif self._data[x_root, 1] > self._data[y_root, 1]: - self._data[y_root, 0] = x_root + if self.data[x_root, 1] < self.data[y_root, 1]: + self.data[x_root, 0] = y_root + elif self.data[x_root, 1] > self.data[y_root, 1]: + self.data[y_root, 0] = x_root else: - self._data[y_root, 0] = x_root - self._data[x_root, 1] += 1 - + self.data[y_root, 0] = x_root + self.data[x_root, 1] += 1 return - cdef find(self, cnp.intp_t x): - if self._data[x, 0] != x: - self._data[x, 0] = self.find(self._data[x, 0]) + cdef cnp.intp_t find(self, cnp.intp_t x): + if self.data[x, 0] != x: + self.data[x, 0] = self.find(self.data[x, 0]) self.is_component[x] = False - return self._data[x, 0] - - cdef cnp.ndarray[cnp.intp_t, ndim=1] components(self): - return self.is_component.nonzero()[0] + return self.data[x, 0] cpdef cnp.ndarray[cnp.intp_t, ndim=1] labelling_at_cut( @@ -361,8 +358,8 @@ cpdef cnp.ndarray[cnp.intp_t, ndim=1] labelling_at_cut( cluster = n_samples for row in linkage: if row[2] < cut: - union_find.union_( row[0], cluster) - union_find.union_( row[1], cluster) + union_find.union( row[0], cluster) + union_find.union( row[1], cluster) cluster += 1 cluster_size = np.zeros(cluster, dtype=np.intp) @@ -416,7 +413,7 @@ cdef cnp.ndarray[cnp.intp_t, ndim=1] do_labelling( child = child_array[n] parent = parent_array[n] if child not in clusters: - union_find.union_(parent, child) + union_find.union(parent, child) for n in range(root_cluster): cluster = union_find.find(n)