diff --git a/sklearn/cluster/_hdbscan/_linkage.pyx b/sklearn/cluster/_hdbscan/_linkage.pyx index fd9888ac4da82..34fa9c98278a5 100644 --- a/sklearn/cluster/_hdbscan/_linkage.pyx +++ b/sklearn/cluster/_hdbscan/_linkage.pyx @@ -10,6 +10,8 @@ from libc.float cimport DBL_MAX import numpy as np from ...metrics._dist_metrics cimport DistanceMetric from ...cluster._hierarchical_fast cimport UnionFind +from ...cluster._hdbscan._tree cimport HIERARCHY_t +from ...cluster._hdbscan._tree import HIERARCHY_dtype from ...utils._typedefs cimport ITYPE_t, DTYPE_t from ...utils._typedefs import ITYPE, DTYPE @@ -188,7 +190,7 @@ cpdef cnp.ndarray[MST_edge_t, ndim=1, mode='c'] mst_from_data_matrix( return mst -cpdef cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] make_single_linkage(const MST_edge_t[::1] mst): +cpdef cnp.ndarray[HIERARCHY_t, ndim=1] make_single_linkage(const MST_edge_t[::1] mst): """Construct a single-linkage tree from an MST. Parameters @@ -199,26 +201,20 @@ cpdef cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] make_single_linkage(const MST Returns ------- - single_linkage : ndarray of shape (n_samples - 1, 4) - The single-linkage tree tree (dendrogram) built from the MST. Each - of the array represents the following: - - - left node/cluster - - right node/cluster - - distance - - new cluster size + single_linkage : ndarray of shape (n_samples - 1,), dtype=HIERARCHY_dtype + The single-linkage tree tree (dendrogram) built from the MST. """ cdef: - cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] single_linkage + cnp.ndarray[HIERARCHY_t, ndim=1] single_linkage # Note mst.shape[0] is one fewer than the number of samples cnp.int64_t n_samples = mst.shape[0] + 1 - cnp.int64_t current_node_cluster, next_node_cluster + cnp.intp_t current_node_cluster, next_node_cluster cnp.int64_t current_node, next_node, index cnp.float64_t distance UnionFind U = UnionFind(n_samples) - single_linkage = np.zeros((n_samples - 1, 4), dtype=np.float64) + single_linkage = np.zeros(n_samples - 1, dtype=HIERARCHY_dtype) for i in range(n_samples - 1): @@ -231,10 +227,10 @@ cpdef cnp.ndarray[cnp.float64_t, ndim=2, mode='c'] make_single_linkage(const MST # TODO: Update this to an array of structs (AoS). # Should be done simultaneously in _tree.pyx to ensure compatability. - single_linkage[i][0] = current_node_cluster - single_linkage[i][1] = next_node_cluster - single_linkage[i][2] = distance - single_linkage[i][3] = U.size[current_node_cluster] + U.size[next_node_cluster] + single_linkage[i].left_node = current_node_cluster + single_linkage[i].right_node = next_node_cluster + single_linkage[i].value = distance + single_linkage[i].cluster_size = U.size[current_node_cluster] + U.size[next_node_cluster] U.union(current_node_cluster, next_node_cluster) diff --git a/sklearn/cluster/_hdbscan/_tree.pxd b/sklearn/cluster/_hdbscan/_tree.pxd new file mode 100644 index 0000000000000..83d5b38cb99fb --- /dev/null +++ b/sklearn/cluster/_hdbscan/_tree.pxd @@ -0,0 +1,9 @@ +cimport numpy as cnp +import numpy as np + + +ctypedef packed struct HIERARCHY_t: + cnp.intp_t left_node + cnp.intp_t right_node + cnp.float64_t value + cnp.intp_t cluster_size diff --git a/sklearn/cluster/_hdbscan/_tree.pyx b/sklearn/cluster/_hdbscan/_tree.pyx index 3d6bed3e8df34..92ba4309d71c3 100644 --- a/sklearn/cluster/_hdbscan/_tree.pyx +++ b/sklearn/cluster/_hdbscan/_tree.pyx @@ -2,48 +2,67 @@ # Authors: Leland McInnes # License: 3-clause BSD -import numpy as np -cimport numpy as np +cimport numpy as cnp import cython +import numpy as np +cdef cnp.float64_t INFTY = np.inf +cdef cnp.intp_t NOISE = -1 -cdef np.double_t INFTY = np.inf +HIERARCHY_dtype = np.dtype([ + ("left_node", np.intp), + ("right_node", np.intp), + ("value", np.float64), + ("cluster_size", np.intp), +]) -cdef list bfs_from_hierarchy(np.ndarray[np.double_t, ndim=2] hierarchy, - np.intp_t bfs_root): +cdef list bfs_from_hierarchy( + cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy, + cnp.intp_t bfs_root +): """ - Perform a breadth first search on a tree in scipy hclust format. + Perform a breadth first search on a single linkage hierarchy in + scipy.cluster.hierarchy format. """ - - cdef list to_process - cdef np.intp_t max_node - cdef np.intp_t num_points - cdef np.intp_t dim - - dim = hierarchy.shape[0] - max_node = 2 * dim - num_points = max_node - dim + 1 - - to_process = [bfs_root] + # NOTE: We keep `process_queue` as a list rather than a memory-view to + # retain semantics which make the below runtime algorithm convenient, such + # as not being required to preallocate space, and allowing easy checks to + # determine whether the list is emptey or not. + cdef list process_queue, next_queue + cdef cnp.intp_t n_samples = hierarchy.shape[0] + 1 + cdef cnp.intp_t node + process_queue = [bfs_root] result = [] - while to_process: - result.extend(to_process) - to_process = [x - num_points for x in - to_process if x >= num_points] - if to_process: - to_process = hierarchy[to_process, - :2].flatten().astype(np.intp).tolist() - + while process_queue: + result.extend(process_queue) + # By construction, node i is formed by the union of nodes + # hierarchy[i - n_samples, 0] and hierarchy[i - n_samples, 1] + process_queue = [ + x - n_samples + for x in process_queue + if x >= n_samples + ] + if process_queue: + next_queue = [] + for node in process_queue: + next_queue.extend( + [ + hierarchy[node].left_node, + hierarchy[node].right_node, + ] + ) + process_queue = next_queue return result - -cpdef np.ndarray condense_tree(np.ndarray[np.double_t, ndim=2] hierarchy, - np.intp_t min_cluster_size=10): - """Condense a tree according to a minimum cluster size. This is akin +cpdef cnp.ndarray condense_tree( + cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy, + cnp.intp_t min_cluster_size=10 +): + """Condense an MST according to a minimum cluster size. This is akin to the runt pruning procedure of Stuetzle. The result is a much simpler tree that is easier to visualize. We include extra information on the lambda value at which individual points depart clusters for later @@ -51,298 +70,251 @@ cpdef np.ndarray condense_tree(np.ndarray[np.double_t, ndim=2] hierarchy, Parameters ---------- - hierarchy : ndarray (n_samples, 4) + hierarchy : ndarray of shape (n_samples,), dtype=HIERARCHY_dtype A single linkage hierarchy in scipy.cluster.hierarchy format. min_cluster_size : int, optional (default 10) - The minimum size of clusters to consider. Smaller "runt" - clusters are pruned from the tree. + The minimum size of clusters to consider. Clusters smaler than this + are pruned from the tree. Returns ------- - condensed_tree : numpy recarray + condensed_tree : ndarray of shape (n_samples,), dtype=HIERARCHY_dtype Effectively an edgelist with a parent, child, lambda_val - and child_size in each row providing a tree structure. + and cluster_size in each row providing a tree structure. """ - cdef np.intp_t root - cdef np.intp_t num_points - cdef np.intp_t next_label - cdef list node_list - cdef list result_list - - cdef np.ndarray[np.intp_t, ndim=1] relabel - cdef np.ndarray[np.int_t, ndim=1] ignore - cdef np.ndarray[np.double_t, ndim=1] children - - cdef np.intp_t node - cdef np.intp_t sub_node - cdef np.intp_t left - cdef np.intp_t right - cdef double lambda_value - cdef np.intp_t left_count - cdef np.intp_t right_count - - root = 2 * hierarchy.shape[0] - num_points = root // 2 + 1 - next_label = num_points + 1 + cdef: + cnp.intp_t root = 2 * hierarchy.shape[0] + cnp.intp_t num_points = hierarchy.shape[0] + 1 + cnp.intp_t next_label = num_points + 1 + list result_list, node_list = bfs_from_hierarchy(hierarchy, root) + + cnp.intp_t[::1] relabel + cnp.uint8_t[::1] ignore + HIERARCHY_t children - node_list = bfs_from_hierarchy(hierarchy, root) + cnp.intp_t node, sub_node, left, right + cnp.float64_t lambda_value, distance + cnp.intp_t left_count, right_count relabel = np.empty(root + 1, dtype=np.intp) relabel[root] = num_points result_list = [] - ignore = np.zeros(len(node_list), dtype=int) + ignore = np.zeros(len(node_list), dtype=np.uint8) for node in node_list: + # Guarantee that node is a cluster node of interest if ignore[node] or node < num_points: continue children = hierarchy[node - num_points] - left = children[0] - right = children[1] - if children[2] > 0.0: - lambda_value = 1.0 / children[2] + left = children.left_node + right = children.right_node + distance = children.value + if distance > 0.0: + lambda_value = 1.0 / distance else: lambda_value = INFTY + # Guarantee that left is a cluster node if left >= num_points: - left_count = hierarchy[left - num_points][3] + left_count = hierarchy[left - num_points].cluster_size else: left_count = 1 + # Guarantee that right is a cluster node if right >= num_points: - right_count = hierarchy[right - num_points][3] + right_count = hierarchy[right - num_points].cluster_size else: right_count = 1 + # Each child is regarded as a proper cluster if left_count >= min_cluster_size and right_count >= min_cluster_size: relabel[left] = next_label next_label += 1 - result_list.append((relabel[node], relabel[left], lambda_value, - left_count)) + result_list.append( + (relabel[node], relabel[left], lambda_value, left_count) + ) relabel[right] = next_label next_label += 1 - result_list.append((relabel[node], relabel[right], lambda_value, - right_count)) + result_list.append( + (relabel[node], relabel[right], lambda_value, right_count) + ) + # Each child is regarded as a collection of single-sample clusters elif left_count < min_cluster_size and right_count < min_cluster_size: for sub_node in bfs_from_hierarchy(hierarchy, left): if sub_node < num_points: - result_list.append((relabel[node], sub_node, - lambda_value, 1)) + result_list.append( + (relabel[node], sub_node, lambda_value, 1) + ) ignore[sub_node] = True - for sub_node in bfs_from_hierarchy(hierarchy, right): if sub_node < num_points: - result_list.append((relabel[node], sub_node, - lambda_value, 1)) + result_list.append( + (relabel[node], sub_node, lambda_value, 1) + ) ignore[sub_node] = True + # One child is a collection of single-sample clusters, while the other + # is a persistance of the parent node cluster elif left_count < min_cluster_size: relabel[right] = relabel[node] for sub_node in bfs_from_hierarchy(hierarchy, left): if sub_node < num_points: - result_list.append((relabel[node], sub_node, - lambda_value, 1)) + result_list.append( + (relabel[node], sub_node, lambda_value, 1) + ) ignore[sub_node] = True + # One child is a collection of single-sample clusters, while the other + # is a persistance of the parent node cluster else: relabel[left] = relabel[node] for sub_node in bfs_from_hierarchy(hierarchy, right): if sub_node < num_points: - result_list.append((relabel[node], sub_node, - lambda_value, 1)) + result_list.append( + (relabel[node], sub_node, lambda_value, 1) + ) ignore[sub_node] = True - return np.array(result_list, dtype=[('parent', np.intp), - ('child', np.intp), - ('lambda_val', float), - ('child_size', np.intp)]) - - -cpdef dict compute_stability(np.ndarray condensed_tree): - - cdef np.ndarray[np.double_t, ndim=1] result_arr - cdef np.ndarray sorted_child_data - cdef np.ndarray[np.intp_t, ndim=1] sorted_children - cdef np.ndarray[np.double_t, ndim=1] sorted_lambdas - - cdef np.ndarray[np.intp_t, ndim=1] parents - cdef np.ndarray[np.intp_t, ndim=1] sizes - cdef np.ndarray[np.double_t, ndim=1] lambdas - - cdef np.intp_t child - cdef np.intp_t parent - cdef np.intp_t child_size - cdef np.intp_t result_index - cdef np.intp_t current_child - cdef np.float64_t lambda_ - cdef np.float64_t min_lambda - - cdef np.ndarray[np.double_t, ndim=1] births_arr - cdef np.double_t *births - - cdef np.intp_t largest_child = condensed_tree['child'].max() - cdef np.intp_t smallest_cluster = condensed_tree['parent'].min() - cdef np.intp_t num_clusters = (condensed_tree['parent'].max() - - smallest_cluster + 1) - - if largest_child < smallest_cluster: - largest_child = smallest_cluster - - sorted_child_data = np.sort(condensed_tree[['child', 'lambda_val']], - axis=0) - births_arr = np.nan * np.ones(largest_child + 1, dtype=np.double) - births = ( births_arr.data) - sorted_children = sorted_child_data['child'].copy() - sorted_lambdas = sorted_child_data['lambda_val'].copy() - - parents = condensed_tree['parent'] - sizes = condensed_tree['child_size'] - lambdas = condensed_tree['lambda_val'] - - current_child = -1 - min_lambda = 0 - - for row in range(sorted_child_data.shape[0]): - child = sorted_children[row] - lambda_ = sorted_lambdas[row] - - if child == current_child: - min_lambda = min(min_lambda, lambda_) - elif current_child != -1: - births[current_child] = min_lambda - current_child = child - min_lambda = lambda_ - else: - # Initialize - current_child = child - min_lambda = lambda_ + return np.array(result_list, dtype=HIERARCHY_dtype) - if current_child != -1: - births[current_child] = min_lambda - births[smallest_cluster] = 0.0 - result_arr = np.zeros(num_clusters, dtype=np.double) +cpdef dict compute_stability(cnp.ndarray[HIERARCHY_t, ndim=1] condensed_tree): - for i in range(condensed_tree.shape[0]): - parent = parents[i] - lambda_ = lambdas[i] - child_size = sizes[i] - result_index = parent - smallest_cluster + cdef: + cnp.float64_t[::1] result, births_arr + cnp.ndarray[cnp.intp_t, ndim=1] parents + + cnp.intp_t parent, cluster_size, result_index + cnp.float64_t lambda_val + HIERARCHY_t condensed_node + cnp.float64_t[:, :] result_pre_dict + + + parents = condensed_tree['left_node'] + cdef cnp.intp_t largest_child = condensed_tree['right_node'].max() + cdef cnp.intp_t smallest_cluster = parents.min() + cdef cnp.intp_t num_clusters = parents.max() - smallest_cluster + 1 + + largest_child = max(largest_child, smallest_cluster) + + births_arr = np.full(largest_child + 1, np.nan, dtype=np.float64) - result_arr[result_index] += (lambda_ - births[parent]) * child_size + for idx in range(condensed_tree.shape[0]): + condensed_node = condensed_tree[idx] + births_arr[condensed_node.right_node] = condensed_node.value - result_pre_dict = np.vstack((np.arange(smallest_cluster, - condensed_tree['parent'].max() + 1), - result_arr)).T + births_arr[smallest_cluster] = 0.0 + + result = np.zeros(num_clusters, dtype=np.float64) + for condensed_node in condensed_tree: + parent = condensed_node.left_node + lambda_val = condensed_node.value + cluster_size = condensed_node.cluster_size + + result_index = parent - smallest_cluster + result[result_index] += (lambda_val - births_arr[parent]) * cluster_size + + result_pre_dict = np.vstack( + ( + np.arange(smallest_cluster, parents.max() + 1), + result + ) + ).T return dict(result_pre_dict) -cdef list bfs_from_cluster_tree(np.ndarray tree, np.intp_t bfs_root): +cdef list bfs_from_cluster_tree( + cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy, + cnp.intp_t bfs_root, +): cdef list result - cdef np.ndarray[np.intp_t, ndim=1] to_process + cdef cnp.ndarray[cnp.intp_t, ndim=1] process_queue, children + children = hierarchy['right_node'] + + cdef cnp.intp_t[:] parents = hierarchy['left_node'] result = [] - to_process = np.array([bfs_root], dtype=np.intp) + process_queue = np.array([bfs_root], dtype=np.intp) - while to_process.shape[0] > 0: - result.extend(to_process.tolist()) - to_process = tree['child'][np.in1d(tree['parent'], to_process)] + while process_queue.shape[0] > 0: + result.extend(process_queue.tolist()) + process_queue = children[np.isin(parents, process_queue)] return result -cdef max_lambdas(np.ndarray tree): +cdef cnp.float64_t[::1] max_lambdas(cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy): - cdef np.ndarray sorted_parent_data - cdef np.ndarray[np.intp_t, ndim=1] sorted_parents - cdef np.ndarray[np.double_t, ndim=1] sorted_lambdas + cdef cnp.intp_t parent, current_parent, idx + cdef cnp.float64_t lambda_val, max_lambda - cdef np.intp_t parent - cdef np.intp_t current_parent - cdef np.float64_t lambda_ - cdef np.float64_t max_lambda + cdef cnp.float64_t[::1] deaths - cdef np.ndarray[np.double_t, ndim=1] deaths_arr - cdef np.double_t *deaths + cdef cnp.intp_t largest_parent = hierarchy['left_node'].max() - cdef np.intp_t largest_parent = tree['parent'].max() + deaths = np.zeros(largest_parent + 1, dtype=np.float64) - sorted_parent_data = np.sort(tree[['parent', 'lambda_val']], axis=0) - deaths_arr = np.zeros(largest_parent + 1, dtype=np.double) - deaths = ( deaths_arr.data) - sorted_parents = sorted_parent_data['parent'] - sorted_lambdas = sorted_parent_data['lambda_val'] + current_parent = hierarchy[0].left_node + max_lambda = hierarchy[0].value - current_parent = -1 - max_lambda = 0 - - for row in range(sorted_parent_data.shape[0]): - parent = sorted_parents[row] - lambda_ = sorted_lambdas[row] + for idx in range(1, hierarchy.shape[0]): + parent = hierarchy[idx].left_node + lambda_val = hierarchy[idx].value if parent == current_parent: - max_lambda = max(max_lambda, lambda_) - elif current_parent != -1: - deaths[current_parent] = max_lambda - current_parent = parent - max_lambda = lambda_ + max_lambda = max(max_lambda, lambda_val) else: - # Initialize + deaths[current_parent] = max_lambda current_parent = parent - max_lambda = lambda_ + max_lambda = lambda_val deaths[current_parent] = max_lambda # value for last parent - - return deaths_arr + return deaths -cdef class TreeUnionFind (object): +cdef class TreeUnionFind: - cdef np.ndarray _data_arr - cdef np.intp_t[:, ::1] _data - cdef np.ndarray is_component + cdef cnp.intp_t[:, ::1] data + cdef cnp.uint8_t[::1] is_component def __init__(self, size): - self._data_arr = np.zeros((size, 2), dtype=np.intp) - self._data_arr.T[0] = np.arange(size) - self._data = ( ( - self._data_arr.data)) + cdef cnp.ndarray[cnp.intp_t, ndim=2] data_arr + data_arr = np.zeros((size, 2), dtype=np.intp) + data_arr.T[0] = np.arange(size) + self.data = data_arr self.is_component = np.ones(size, dtype=bool) - cdef union_(self, np.intp_t x, np.intp_t y): - cdef np.intp_t x_root = self.find(x) - cdef np.intp_t y_root = self.find(y) + cdef void union(self, cnp.intp_t x, cnp.intp_t y): + cdef cnp.intp_t x_root = self.find(x) + cdef cnp.intp_t y_root = self.find(y) - if self._data[x_root, 1] < self._data[y_root, 1]: - self._data[x_root, 0] = y_root - elif self._data[x_root, 1] > self._data[y_root, 1]: - self._data[y_root, 0] = x_root + if self.data[x_root, 1] < self.data[y_root, 1]: + self.data[x_root, 0] = y_root + elif self.data[x_root, 1] > self.data[y_root, 1]: + self.data[y_root, 0] = x_root else: - self._data[y_root, 0] = x_root - self._data[x_root, 1] += 1 - + self.data[y_root, 0] = x_root + self.data[x_root, 1] += 1 return - cdef find(self, np.intp_t x): - if self._data[x, 0] != x: - self._data[x, 0] = self.find(self._data[x, 0]) + cdef cnp.intp_t find(self, cnp.intp_t x): + if self.data[x, 0] != x: + self.data[x, 0] = self.find(self.data[x, 0]) self.is_component[x] = False - return self._data[x, 0] + return self.data[x, 0] - cdef np.ndarray[np.intp_t, ndim=1] components(self): - return self.is_component.nonzero()[0] - - -cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( - np.ndarray linkage, - np.double_t cut, - np.intp_t min_cluster_size): +cpdef cnp.ndarray[cnp.intp_t, ndim=1] labelling_at_cut( + cnp.ndarray[HIERARCHY_t, ndim=1] linkage, + cnp.float64_t cut, + cnp.intp_t min_cluster_size + ): """Given a single linkage tree and a cut value, return the vector of cluster labels at that cut value. This is useful for Robust Single Linkage, and extracting DBSCAN results @@ -350,10 +322,10 @@ cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( Parameters ---------- - linkage : ndarray (n_samples, 4) + linkage : ndarray of shape (n_samples,), dtype=HIERARCHY_dtype The single linkage tree in scipy.cluster.hierarchy format. - cut : double + cut : float The cut value at which to find clusters. min_cluster_size : int @@ -362,35 +334,32 @@ cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( Returns ------- - labels : ndarray (n_samples,) + labels : ndarray of shape (n_samples,) The cluster labels for each point in the data set; a label of -1 denotes a noise assignment. """ - cdef np.intp_t root - cdef np.intp_t num_points - cdef np.ndarray[np.intp_t, ndim=1] result_arr - cdef np.ndarray[np.intp_t, ndim=1] unique_labels - cdef np.ndarray[np.intp_t, ndim=1] cluster_size - cdef np.intp_t *result - cdef TreeUnionFind union_find - cdef np.intp_t n - cdef np.intp_t cluster - cdef np.intp_t cluster_id + cdef: + cnp.intp_t root, num_points + cnp.intp_t[::1] unique_labels, cluster_size + cnp.ndarray[cnp.intp_t, ndim=1] result + TreeUnionFind union_find + cnp.intp_t n, cluster, cluster_id, cluster_label + dict cluster_label_map + HIERARCHY_t node root = 2 * linkage.shape[0] num_points = root // 2 + 1 - result_arr = np.empty(num_points, dtype=np.intp) - result = ( result_arr.data) + result = np.empty(num_points, dtype=np.intp) - union_find = TreeUnionFind( root + 1) + union_find = TreeUnionFind(root + 1) cluster = num_points - for row in linkage: - if row[2] < cut: - union_find.union_( row[0], cluster) - union_find.union_( row[1], cluster) + for node in linkage: + if node.value < cut: + union_find.union(node.left_node, cluster) + union_find.union(node.right_node, cluster) cluster += 1 cluster_size = np.zeros(cluster, dtype=np.intp) @@ -399,13 +368,13 @@ cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( cluster_size[cluster] += 1 result[n] = cluster - cluster_label_map = {-1: -1} + cluster_label_map = {-1: NOISE} cluster_label = 0 - unique_labels = np.unique(result_arr) + unique_labels = np.unique(result) for cluster in unique_labels: if cluster_size[cluster] < min_cluster_size: - cluster_label_map[cluster] = -1 + cluster_label_map[cluster] = NOISE else: cluster_label_map[cluster] = cluster_label cluster_label += 1 @@ -413,93 +382,77 @@ cpdef np.ndarray[np.intp_t, ndim=1] labelling_at_cut( for n in range(num_points): result[n] = cluster_label_map[result[n]] - return result_arr + return result -cdef np.ndarray[np.intp_t, ndim=1] do_labelling( - np.ndarray tree, +cdef cnp.ndarray[cnp.intp_t, ndim=1] do_labelling( + cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy, set clusters, dict cluster_label_map, - np.intp_t allow_single_cluster, - np.double_t cluster_selection_epsilon): - - cdef np.intp_t root_cluster - cdef np.ndarray[np.intp_t, ndim=1] result_arr - cdef np.ndarray[np.intp_t, ndim=1] parent_array - cdef np.ndarray[np.intp_t, ndim=1] child_array - cdef np.ndarray[np.double_t, ndim=1] lambda_array - cdef np.intp_t *result - cdef TreeUnionFind union_find - cdef np.intp_t parent - cdef np.intp_t child - cdef np.intp_t n - cdef np.intp_t cluster - - child_array = tree['child'] - parent_array = tree['parent'] - lambda_array = tree['lambda_val'] - - root_cluster = parent_array.min() - result_arr = np.empty(root_cluster, dtype=np.intp) - result = ( result_arr.data) - - union_find = TreeUnionFind(parent_array.max() + 1) - - for n in range(tree.shape[0]): - child = child_array[n] - parent = parent_array[n] + cnp.uint8_t allow_single_cluster, + cnp.float64_t cluster_selection_epsilon): + cdef: + cnp.ndarray[cnp.intp_t, ndim=1] result, parents, children + cnp.ndarray[cnp.float64_t, ndim=1] lambdas + TreeUnionFind union_find + cnp.intp_t root_cluster, parent, child, cluster, n + cnp.int64_t cluster_label, label + + children = hierarchy['right_node'] + parents = hierarchy['left_node'] + lambdas = hierarchy['value'] + + root_cluster = parents.min() + result = np.empty(root_cluster, dtype=np.intp) + + union_find = TreeUnionFind(parents.max() + 1) + + for n in range(hierarchy.shape[0]): + child = children[n] + parent = parents[n] if child not in clusters: - union_find.union_(parent, child) + union_find.union(parent, child) for n in range(root_cluster): cluster = union_find.find(n) - if cluster < root_cluster: - result[n] = -1 - elif cluster == root_cluster: - if len(clusters) == 1 and allow_single_cluster: - if cluster_selection_epsilon != 0.0: - if tree['lambda_val'][tree['child'] == n] >= 1 / cluster_selection_epsilon : - result[n] = cluster_label_map[cluster] - else: - result[n] = -1 - elif tree['lambda_val'][tree['child'] == n] >= \ - tree['lambda_val'][tree['parent'] == cluster].max(): - result[n] = cluster_label_map[cluster] - else: - result[n] = -1 - else: - result[n] = -1 - else: - result[n] = cluster_label_map[cluster] - - return result_arr - + label = NOISE + if cluster != root_cluster: + label = cluster_label_map[cluster] + result[n] = label + continue + if len(clusters) == 1 and allow_single_cluster: + if cluster_selection_epsilon != 0.0: + if lambdas[children == n] >= 1 / cluster_selection_epsilon : + label = cluster_label_map[cluster] + elif lambdas[children == n] >= lambdas[parents == cluster].max(): + label = cluster_label_map[cluster] + result[n] = label + return result -cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels): - cdef np.ndarray[np.double_t, ndim=1] result - cdef np.ndarray[np.double_t, ndim=1] deaths - cdef np.ndarray[np.double_t, ndim=1] lambda_array - cdef np.ndarray[np.intp_t, ndim=1] child_array - cdef np.ndarray[np.intp_t, ndim=1] parent_array - cdef np.intp_t root_cluster - cdef np.intp_t n - cdef np.intp_t point - cdef np.intp_t cluster_num - cdef np.intp_t cluster - cdef np.double_t max_lambda - cdef np.double_t lambda_ +cdef cnp.ndarray[cnp.float64_t, ndim=1] get_probabilities( + cnp.ndarray[HIERARCHY_t, ndim=1] hierarchy, + dict cluster_map, + cnp.intp_t[:] labels +): + cdef: + cnp.ndarray[cnp.float64_t, ndim=1] result + cnp.float64_t[:] lambdas + cnp.float64_t[::1] deaths + cnp.ndarray[cnp.intp_t, ndim=1] children, parents + cnp.intp_t n, point, cluster_num, cluster, root_cluster + cnp.float64_t max_lambda, lambda_val - child_array = tree['child'] - parent_array = tree['parent'] - lambda_array = tree['lambda_val'] + children = hierarchy['right_node'] + parents = hierarchy['left_node'] + lambdas = hierarchy['value'] result = np.zeros(labels.shape[0]) - deaths = max_lambdas(tree) - root_cluster = parent_array.min() + deaths = max_lambdas(hierarchy) + root_cluster = parents.min() - for n in range(tree.shape[0]): - point = child_array[n] + for n in range(hierarchy.shape[0]): + point = children[n] if point >= root_cluster: continue @@ -510,55 +463,87 @@ cdef get_probabilities(np.ndarray tree, dict cluster_map, np.ndarray labels): cluster = cluster_map[cluster_num] max_lambda = deaths[cluster] - if max_lambda == 0.0 or not np.isfinite(lambda_array[n]): + if max_lambda == 0.0 or not np.isfinite(lambdas[n]): result[point] = 1.0 else: - lambda_ = min(lambda_array[n], max_lambda) - result[point] = lambda_ / max_lambda + lambda_val = min(lambdas[n], max_lambda) + result[point] = lambda_val / max_lambda return result -cpdef list recurse_leaf_dfs(np.ndarray cluster_tree, np.intp_t current_node): - children = cluster_tree[cluster_tree['parent'] == current_node]['child'] - if len(children) == 0: +cdef list recurse_leaf_dfs( + cnp.ndarray[HIERARCHY_t, ndim=1] cluster_tree, + cnp.intp_t current_node, +): + cdef cnp.intp_t[:] children + cdef cnp.intp_t child + + children = cluster_tree[cluster_tree['left_node'] == current_node]['right_node'] + if children.shape[0] == 0: return [current_node,] else: return sum([recurse_leaf_dfs(cluster_tree, child) for child in children], []) -cpdef list get_cluster_tree_leaves(np.ndarray cluster_tree): +cdef list get_cluster_tree_leaves(cnp.ndarray[HIERARCHY_t, ndim=1] cluster_tree): + cdef cnp.intp_t root if cluster_tree.shape[0] == 0: return [] - root = cluster_tree['parent'].min() + root = cluster_tree['left_node'].min() return recurse_leaf_dfs(cluster_tree, root) -cpdef np.intp_t traverse_upwards(np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t leaf, np.intp_t allow_single_cluster): +cdef cnp.intp_t traverse_upwards( + cnp.ndarray[HIERARCHY_t, ndim=1] cluster_tree, + cnp.float64_t cluster_selection_epsilon, + cnp.intp_t leaf, + cnp.intp_t allow_single_cluster +): - root = cluster_tree['parent'].min() - parent = cluster_tree[cluster_tree['child'] == leaf]['parent'] + cdef cnp.intp_t root, parent + cdef cnp.float64_t parent_eps + + root = cluster_tree['left_node'].min() + parent = cluster_tree[cluster_tree['right_node'] == leaf]['left_node'] if parent == root: if allow_single_cluster: return parent else: return leaf #return node closest to root - parent_eps = 1/cluster_tree[cluster_tree['child'] == parent]['lambda_val'] + parent_eps = 1 / cluster_tree[cluster_tree['right_node'] == parent]['value'] if parent_eps > cluster_selection_epsilon: return parent else: - return traverse_upwards(cluster_tree, cluster_selection_epsilon, parent, allow_single_cluster) - -cpdef set epsilon_search(set leaves, np.ndarray cluster_tree, np.double_t cluster_selection_epsilon, np.intp_t allow_single_cluster): - - selected_clusters = list() - processed = list() + return traverse_upwards( + cluster_tree, + cluster_selection_epsilon, + parent, + allow_single_cluster + ) + +cdef set epsilon_search( + set leaves, + cnp.ndarray[HIERARCHY_t, ndim=1] cluster_tree, + cnp.float64_t cluster_selection_epsilon, + cnp.intp_t allow_single_cluster +): + + cdef list selected_clusters = list() + cdef list processed = list() + cdef cnp.intp_t leaf, epsilon_child, sub_node + cdef cnp.float64_t eps for leaf in leaves: - eps = 1/cluster_tree['lambda_val'][cluster_tree['child'] == leaf][0] + eps = 1 / cluster_tree['value'][cluster_tree['right_node'] == leaf][0] if eps < cluster_selection_epsilon: if leaf not in processed: - epsilon_child = traverse_upwards(cluster_tree, cluster_selection_epsilon, leaf, allow_single_cluster) + epsilon_child = traverse_upwards( + cluster_tree, + cluster_selection_epsilon, + leaf, + allow_single_cluster + ) selected_clusters.append(epsilon_child) for sub_node in bfs_from_cluster_tree(cluster_tree, epsilon_child): @@ -570,27 +555,30 @@ cpdef set epsilon_search(set leaves, np.ndarray cluster_tree, np.double_t cluste return set(selected_clusters) @cython.wraparound(True) -cpdef tuple get_clusters(np.ndarray tree, dict stability, - cluster_selection_method='eom', - allow_single_cluster=False, - cluster_selection_epsilon=0.0, - max_cluster_size=None): +cpdef tuple get_clusters( + cnp.ndarray hierarchy, + dict stability, + cluster_selection_method='eom', + cnp.uint8_t allow_single_cluster=False, + cnp.float64_t cluster_selection_epsilon=0.0, + max_cluster_size=None +): """Given a tree and stability dict, produce the cluster labels (and probabilities) for a flat clustering based on the chosen cluster selection method. Parameters ---------- - tree : numpy recarray + hierarchy : ndarray of shape (n_samples,), dtype=HIERARCHY_dtype The condensed tree to extract flat clusters from stability : dict A dictionary mapping cluster_ids to stability values - cluster_selection_method : string, optional (default 'eom') + cluster_selection_method : {"eom", "leaf"} The method of selecting clusters. The default is the - Excess of Mass algorithm specified by 'eom'. The alternate - option is 'leaf'. + Excess of Mass algorithm specified by "eom". The alternate + option is "leaf". allow_single_cluster : boolean, optional (default False) Whether to allow a single cluster to be selected by the @@ -606,27 +594,24 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, Returns ------- - labels : ndarray (n_samples,) + labels : ndarray of shape (n_samples,) An integer array of cluster labels, with -1 denoting noise. - probabilities : ndarray (n_samples,) + probabilities : ndarray of shape (n_samples,) The cluster membership strength of each sample. stabilities : ndarray (n_clusters,) The cluster coherence strengths of each cluster. """ - cdef list node_list - cdef np.ndarray cluster_tree - cdef np.ndarray child_selection - cdef dict is_cluster - cdef dict cluster_sizes - cdef float subtree_stability - cdef np.intp_t node - cdef np.intp_t sub_node - cdef np.intp_t cluster - cdef np.intp_t num_points - cdef np.ndarray labels - cdef np.double_t max_lambda + cdef: + list node_list + cnp.ndarray[HIERARCHY_t, ndim=1] cluster_tree + cnp.uint8_t[:] child_selection + cnp.ndarray[cnp.intp_t, ndim=1] labels + dict is_cluster, cluster_sizes + cnp.float64_t subtree_stability, max_lambda + cnp.intp_t node, sub_node, cluster, num_points + cnp.ndarray[cnp.float64_t, ndim=1] probs # Assume clusters are ordered by numeric id equivalent to # a topological sort of the tree; This is valid given the @@ -635,30 +620,37 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, if allow_single_cluster: node_list = sorted(stability.keys(), reverse=True) else: + # exclude root node_list = sorted(stability.keys(), reverse=True)[:-1] - # (exclude root) - cluster_tree = tree[tree['child_size'] > 1] + cluster_tree = hierarchy[hierarchy['cluster_size'] > 1] is_cluster = {cluster: True for cluster in node_list} - num_points = np.max(tree[tree['child_size'] == 1]['child']) + 1 - max_lambda = np.max(tree['lambda_val']) + num_points = np.max(hierarchy[hierarchy['cluster_size'] == 1]['right_node']) + 1 + max_lambda = np.max(hierarchy['value']) if max_cluster_size is None: - max_cluster_size = num_points + 1 # Set to a value that will never be triggered - cluster_sizes = {child: child_size for child, child_size - in zip(cluster_tree['child'], cluster_tree['child_size'])} + # Set to a value that will never be triggered + max_cluster_size = num_points + 1 + cluster_sizes = { + child: cluster_size for child, cluster_size + in zip(cluster_tree['right_node'], cluster_tree['cluster_size']) + } if allow_single_cluster: # Compute cluster size for the root node cluster_sizes[node_list[-1]] = np.sum( - cluster_tree[cluster_tree['parent'] == node_list[-1]]['child_size']) + cluster_tree[cluster_tree['left_node'] == node_list[-1]]['cluster_size']) if cluster_selection_method == 'eom': for node in node_list: - child_selection = (cluster_tree['parent'] == node) + child_selection = cluster_tree['left_node'] == node subtree_stability = np.sum([ stability[child] for - child in cluster_tree['child'][child_selection]]) - if subtree_stability > stability[node] or cluster_sizes[node] > max_cluster_size: + child in cluster_tree['right_node'][child_selection]] + ) + if ( + subtree_stability > stability[node] + or cluster_sizes[node] > max_cluster_size + ): is_cluster[node] = False stability[node] = subtree_stability else: @@ -670,26 +662,39 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, eom_clusters = [c for c in is_cluster if is_cluster[c]] selected_clusters = [] # first check if eom_clusters only has root node, which skips epsilon check. - if (len(eom_clusters) == 1 and eom_clusters[0] == cluster_tree['parent'].min()): + if ( + len(eom_clusters) == 1 + and eom_clusters[0] == cluster_tree['left_node'].min() + ): if allow_single_cluster: selected_clusters = eom_clusters else: - selected_clusters = epsilon_search(set(eom_clusters), cluster_tree, cluster_selection_epsilon, allow_single_cluster) + selected_clusters = epsilon_search( + set(eom_clusters), + cluster_tree, + cluster_selection_epsilon, + allow_single_cluster + ) for c in is_cluster: if c in selected_clusters: is_cluster[c] = True else: is_cluster[c] = False - elif cluster_selection_method == 'leaf': + else: leaves = set(get_cluster_tree_leaves(cluster_tree)) if len(leaves) == 0: for c in is_cluster: is_cluster[c] = False - is_cluster[tree['parent'].min()] = True + is_cluster[hierarchy['left_node'].min()] = True if cluster_selection_epsilon != 0.0: - selected_clusters = epsilon_search(leaves, cluster_tree, cluster_selection_epsilon, allow_single_cluster) + selected_clusters = epsilon_search( + leaves, + cluster_tree, + cluster_selection_epsilon, + allow_single_cluster + ) else: selected_clusters = leaves @@ -698,16 +703,18 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability, is_cluster[c] = True else: is_cluster[c] = False - else: - raise ValueError('Invalid Cluster Selection Method: %s\n' - 'Should be one of: "eom", "leaf"\n') - clusters = set([c for c in is_cluster if is_cluster[c]]) + clusters = {c for c in is_cluster if is_cluster[c]} cluster_map = {c: n for n, c in enumerate(sorted(list(clusters)))} reverse_cluster_map = {n: c for c, n in cluster_map.items()} - labels = do_labelling(tree, clusters, cluster_map, - allow_single_cluster, cluster_selection_epsilon) - probs = get_probabilities(tree, reverse_cluster_map, labels) + labels = do_labelling( + hierarchy, + clusters, + cluster_map, + allow_single_cluster, + cluster_selection_epsilon + ) + probs = get_probabilities(hierarchy, reverse_cluster_map, labels) return (labels, probs) diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index 4f1fcf1962d0b..d8cc8a93d70ac 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -21,6 +21,7 @@ from ...neighbors import BallTree, KDTree, NearestNeighbors from ...utils._param_validation import Interval, StrOptions from ...utils.validation import _assert_all_finite +from ._tree import HIERARCHY_dtype from ._linkage import ( make_single_linkage, mst_from_mutual_reachability, @@ -214,22 +215,24 @@ def remap_single_linkage_tree(tree, internal_to_raw, non_finite): outlier_count = len(non_finite) for i, (left, right, *_) in enumerate(tree): if left < finite_count: - tree[i, 0] = internal_to_raw[left] + tree[i]["left_node"] = internal_to_raw[left] else: - tree[i, 0] = left + outlier_count + tree[i]["left_node"] = left + outlier_count if right < finite_count: - tree[i, 1] = internal_to_raw[right] + tree[i]["right_node"] = internal_to_raw[right] else: - tree[i, 1] = right + outlier_count + tree[i]["right_node"] = right + outlier_count - outlier_tree = np.zeros((len(non_finite), 4)) - last_cluster_id = tree[tree.shape[0] - 1][0:2].max() - last_cluster_size = tree[tree.shape[0] - 1][3] + outlier_tree = np.zeros(len(non_finite), dtype=HIERARCHY_dtype) + last_cluster_id = max( + tree[tree.shape[0] - 1]["left_node"], tree[tree.shape[0] - 1]["right_node"] + ) + last_cluster_size = tree[tree.shape[0] - 1]["value"] for i, outlier in enumerate(non_finite): outlier_tree[i] = (outlier, last_cluster_id + 1, np.inf, last_cluster_size + 1) last_cluster_id += 1 last_cluster_size += 1 - tree = np.vstack([tree, outlier_tree]) + tree = np.concatenate([tree, outlier_tree]) return tree