diff --git a/sklearn/cluster/_hdbscan/_reachability.pyx b/sklearn/cluster/_hdbscan/_reachability.pyx index 4118732d6e623..c83fef742e82e 100644 --- a/sklearn/cluster/_hdbscan/_reachability.pyx +++ b/sklearn/cluster/_hdbscan/_reachability.pyx @@ -1,41 +1,59 @@ -# mutual reachability distance compiutations +# mutual reachability distance computations # Authors: Leland McInnes # Meekail Zain +# Guillaume Lemaitre # License: 3-clause BSD import numpy as np from scipy.sparse import issparse -from ...neighbors import BallTree, KDTree +cimport cython +from cython cimport floating cimport numpy as cnp from cython.parallel cimport prange -from libc.math cimport isfinite +from libc.math cimport isfinite, INFINITY +cnp.import_array() -def mutual_reachability(distance_matrix, min_points=5, max_dist=0.0): - """Compute the weighted adjacency matrix of the mutual reachability - graph of a distance matrix. Note that computation is performed in-place for - `distance_matrix`. If out-of-place computation is required, pass a copy to - this function. +ctypedef fused integral: + int + long long + + +def mutual_reachability_graph( + distance_matrix, min_samples=5, max_distance=0.0 +): + """Compute the weighted adjacency matrix of the mutual reachability graph. + + The mutual reachability distance used to build the graph is defined as:: + + max(d_core(x_p), d_core(x_q), d(x_p, x_q)) + + and the core distance `d_core` is defined as the distance between a point + `x_p` and its k-th nearest neighbor. + + Note that all computations are done in-place. Parameters ---------- - distance_matrix : ndarray or sparse matrix of shape (n_samples, n_samples) + distance_matrix : {ndarray, sparse matrix} of shape (n_samples, n_samples) Array of distances between samples. If sparse, the array must be in - `LIL` format. + `CSR` format. - min_points : int, default=5 + min_samples : int, default=5 The number of points in a neighbourhood for a point to be considered a core point. - max_dist : float, default=0.0 + max_distance : float, default=0.0 The distance which `np.inf` is replaced with. When the true mutual- reachability distance is measured to be infinite, it is instead - truncated to `max_dist`. + truncated to `max_dist`. Only used when `distance_matrix` is a sparse + matrix. Returns ------- - mututal_reachability: ndarray of shape (n_samples, n_samples) + mututal_reachability_graph: {ndarray, sparse matrix} of shape \ + (n_samples, n_samples) Weighted adjacency matrix of the mutual reachability graph. References @@ -45,78 +63,125 @@ def mutual_reachability(distance_matrix, min_points=5, max_dist=0.0): In Pacific-Asia Conference on Knowledge Discovery and Data Mining (pp. 160-172). Springer Berlin Heidelberg. """ - # Account for index offset - min_points -= 1 - - # Note that in both routines `distance_matrix` is operated on in-place. At - # this point, if out-of-place operation is desired then this function - # should have been passed a copy. + further_neighbor_idx = min_samples - 1 if issparse(distance_matrix): - return _sparse_mutual_reachability( - distance_matrix, - min_points=min_points, - max_dist=max_dist - ).tocsr() + if distance_matrix.format != "csr": + raise ValueError( + "Only sparse CSR matrices are supported for `distance_matrix`." + ) + _sparse_mutual_reachability_graph( + distance_matrix.data, + distance_matrix.indices, + distance_matrix.indptr, + distance_matrix.shape[0], + further_neighbor_idx=further_neighbor_idx, + max_distance=max_distance, + ) + else: + _dense_mutual_reachability_graph( + distance_matrix, further_neighbor_idx=further_neighbor_idx + ) + return distance_matrix - return _dense_mutual_reachability(distance_matrix, min_points=min_points) -cdef _dense_mutual_reachability( - cnp.ndarray[dtype=cnp.float64_t, ndim=2] distance_matrix, - cnp.intp_t min_points=5 +def _dense_mutual_reachability_graph( + cnp.ndarray[dtype=floating, ndim=2] distance_matrix, + cnp.intp_t further_neighbor_idx, ): - cdef cnp.intp_t i, j, n_samples = distance_matrix.shape[0] - cdef cnp.float64_t mr_dist - cdef cnp.float64_t[:] core_distances + """Dense implementation of mutual reachability graph. + + The computation is done in-place, i.e. the distance matrix is modified + directly. + + Parameters + ---------- + distance_matrix : ndarray of shape (n_samples, n_samples) + Array of distances between samples. - # Compute the core distances for all samples `x_p` corresponding - # to the distance of the k-th farthest neighbours (including - # `x_p`). + further_neighbor_idx : int + The index of the furthest neighbor to use to define the core distances. + """ + cdef: + cnp.intp_t i, j, n_samples = distance_matrix.shape[0] + floating mutual_reachibility_distance + floating[:] core_distances + + # We assume that the distance matrix is symmetric. We choose to sort every + # row to have the same implementation than the sparse case that requires + # CSR matrix. core_distances = np.partition( - distance_matrix, - min_points, - axis=0, - )[min_points] + distance_matrix, further_neighbor_idx, axis=1 + )[:, further_neighbor_idx] with nogil: for i in range(n_samples): for j in prange(n_samples): - mr_dist = max( + mutual_reachibility_distance = max( core_distances[i], core_distances[j], - distance_matrix[i, j] + distance_matrix[i, j], ) - distance_matrix[i, j] = mr_dist - return distance_matrix + distance_matrix[i, j] = mutual_reachibility_distance + -# Assumes LIL format. -# TODO: Rewrite for CSR. -cdef _sparse_mutual_reachability( - object distance_matrix, - cnp.intp_t min_points=5, - cnp.float64_t max_dist=0. +def _sparse_mutual_reachability_graph( + cnp.ndarray[floating, ndim=1, mode="c"] data, + cnp.ndarray[integral, ndim=1, mode="c"] indices, + cnp.ndarray[integral, ndim=1, mode="c"] indptr, + cnp.intp_t n_samples, + cnp.intp_t further_neighbor_idx, + floating max_distance, ): - cdef cnp.intp_t i, j, n, n_samples = distance_matrix.shape[0] - cdef cnp.float64_t mr_dist - cdef cnp.float64_t[:] core_distances - cdef cnp.int32_t[:] nz_row_data, nz_col_data - core_distances = np.empty(n_samples, dtype=np.float64) + """Sparse implementation of mutual reachability graph. + + The computation is done in-place, i.e. the distance matrix is modified + directly. This implementation only accepts `CSR` format sparse matrices. + + Parameters + ---------- + distance_matrix : sparse matrix of shape (n_samples, n_samples) + Sparse matrix of distances between samples. The sparse format should + be `CSR`. + + further_neighbor_idx : int + The index of the furthest neighbor to use to define the core distances. + + max_distance : float + The distance which `np.inf` is replaced with. When the true mutual- + reachability distance is measured to be infinite, it is instead + truncated to `max_dist`. Only used when `distance_matrix` is a sparse + matrix. + """ + cdef: + integral i, col_ind, row_ind + floating mutual_reachibility_distance + floating[:] core_distances + floating[:] row_data + + if floating is float: + dtype = np.float32 + else: + dtype = np.float64 + + core_distances = np.empty(n_samples, dtype=dtype) for i in range(n_samples): - if min_points < len(distance_matrix.data[i]): + row_data = data[indptr[i]:indptr[i + 1]] + if further_neighbor_idx < row_data.size: core_distances[i] = np.partition( - distance_matrix.data[i], - min_points - )[min_points] + row_data, further_neighbor_idx + )[further_neighbor_idx] else: - core_distances[i] = np.infty - - nz_row_data, nz_col_data = distance_matrix.nonzero() - for n in range(nz_row_data.shape[0]): - i = nz_row_data[n] - j = nz_col_data[n] - mr_dist = max(core_distances[i], core_distances[j], distance_matrix[i, j]) - if isfinite(mr_dist): - distance_matrix[i, j] = mr_dist - elif max_dist > 0: - distance_matrix[i, j] = max_dist - return distance_matrix + core_distances[i] = INFINITY + + with nogil: + for row_ind in range(n_samples): + for i in range(indptr[row_ind], indptr[row_ind + 1]): + col_ind = indices[i] + mutual_reachibility_distance = max( + core_distances[row_ind], core_distances[col_ind], data[i] + ) + if isfinite(mutual_reachibility_distance): + data[i] = mutual_reachibility_distance + elif max_distance > 0: + data[i] = max_distance diff --git a/sklearn/cluster/_hdbscan/hdbscan.py b/sklearn/cluster/_hdbscan/hdbscan.py index 4a8760503ae40..753c4145ccc70 100644 --- a/sklearn/cluster/_hdbscan/hdbscan.py +++ b/sklearn/cluster/_hdbscan/hdbscan.py @@ -22,7 +22,7 @@ from ...utils._param_validation import Interval, StrOptions from ...utils.validation import _assert_all_finite from ._linkage import label, mst_from_distance_matrix, mst_from_data_matrix -from ._reachability import mutual_reachability +from ._reachability import mutual_reachability_graph from ._tree import compute_stability, condense_tree, get_clusters, labelling_at_cut FAST_METRICS = KDTree.valid_metrics + BallTree.valid_metrics @@ -46,8 +46,8 @@ } -def _brute_mst(mutual_reachability, min_samples, sparse=False): - if not sparse: +def _brute_mst(mutual_reachability, min_samples): + if not issparse(mutual_reachability): return mst_from_distance_matrix(mutual_reachability) # Check connected component on mutual reachability @@ -63,7 +63,7 @@ def _brute_mst(mutual_reachability, min_samples, sparse=False): f"There exists points with fewer than {min_samples} neighbors. Ensure" " your distance matrix has non-zero values for at least" f" `min_sample`={min_samples} neighbors for each points (i.e. K-nn" - " graph), or specify a `max_dist` in `metric_params` to use when" + " graph), or specify a `max_distance` in `metric_params` to use when" " distances are missing." ) @@ -116,10 +116,6 @@ def _hdbscan_brute( **metric_params, ): if metric == "precomputed": - # Treating this case explicitly, instead of letting - # sklearn.metrics.pairwise_distances handle it, - # enables the usage of numpy.inf in the distance - # matrix to indicate missing distance information. distance_matrix = X.copy() if copy else X else: distance_matrix = pairwise_distances( @@ -127,28 +123,18 @@ def _hdbscan_brute( ) distance_matrix /= alpha - # max_dist is only relevant for sparse and is ignored for dense - max_dist = metric_params.get("max_dist", 0.0) - sparse = issparse(distance_matrix) - - # TODO: Investigate whether it is worth implementing a PWD backend for the - # combined operations of: - # - The pairwise distance calculation - # - The element-wise mutual-reachability calculation - # I suspect this would be better handled as one composite Cython routine to - # minimize memory-movement, however I (@micky774) am unsure whether it is - # narrow enough of a scope for the current PWD backend, or if it is better - # as a separate utility. - distance_matrix = distance_matrix.tolil() if sparse else distance_matrix + max_distance = metric_params.get("max_distance", 0.0) + if issparse(distance_matrix) and distance_matrix.format != "csr": + # we need CSR format to avoid a conversion in `_brute_mst` when calling + # `csgraph.connected_components` + distance_matrix = distance_matrix.tocsr() # Note that `distance_matrix` is manipulated in-place, however we do not # need it for anything else past this point, hence the operation is safe. - mutual_reachability_ = mutual_reachability( - distance_matrix, min_points=min_samples, max_dist=max_dist - ) - min_spanning_tree = _brute_mst( - mutual_reachability_, min_samples=min_samples, sparse=sparse + mutual_reachability_ = mutual_reachability_graph( + distance_matrix, min_samples=min_samples, max_distance=max_distance ) + min_spanning_tree = _brute_mst(mutual_reachability_, min_samples=min_samples) # Warn if the MST couldn't be constructed around the missing distances if np.isinf(min_spanning_tree.T[2]).any(): warn( @@ -358,10 +344,9 @@ class HDBSCAN(ClusterMixin, BaseEstimator): copy : bool, default=False If `copy=True` then any time an in-place modifications would be made that would overwrite data passed to :term:`fit`, a copy will first be - made, guaranteeing that the original data will be unchanged. Currently - this only makes a difference when passing in a dense precomputed - distance array (i.e. when `metric="precomputed"`) and using the - `"brute"` algorithm (see `algorithm` for details). + made, guaranteeing that the original data will be unchanged. + Currently, it only applies with `metric="precomputed"`, passing a dense + array or a sparse matrix of format CSR and algorithm used is `"brute"`. Attributes ---------- diff --git a/sklearn/cluster/_hdbscan/tests/__init__.py b/sklearn/cluster/_hdbscan/tests/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sklearn/cluster/_hdbscan/tests/test_reachibility.py b/sklearn/cluster/_hdbscan/tests/test_reachibility.py new file mode 100644 index 0000000000000..c8ba28d0af25b --- /dev/null +++ b/sklearn/cluster/_hdbscan/tests/test_reachibility.py @@ -0,0 +1,64 @@ +import numpy as np +import pytest + +from sklearn.utils._testing import ( + _convert_container, + assert_allclose, +) + +from sklearn.cluster._hdbscan._reachability import mutual_reachability_graph + + +def test_mutual_reachability_graph_error_sparse_format(): + """Check that we raise an error if the sparse format is not CSR.""" + rng = np.random.RandomState(0) + X = rng.randn(10, 10) + X = X.T @ X + np.fill_diagonal(X, 0.0) + X = _convert_container(X, "sparse_csc") + + err_msg = "Only sparse CSR matrices are supported" + with pytest.raises(ValueError, match=err_msg): + mutual_reachability_graph(X) + + +@pytest.mark.parametrize("array_type", ["array", "sparse_csr"]) +def test_mutual_reachability_graph_inplace(array_type): + """Check that the operation is happening inplace.""" + rng = np.random.RandomState(0) + X = rng.randn(10, 10) + X = X.T @ X + np.fill_diagonal(X, 0.0) + X = _convert_container(X, array_type) + + mr_graph = mutual_reachability_graph(X) + + assert id(mr_graph) == id(X) + + +def test_mutual_reachability_graph_equivalence_dense_sparse(): + """Check that we get the same results for dense and sparse implementation.""" + rng = np.random.RandomState(0) + X = rng.randn(5, 5) + X_dense = X.T @ X + X_sparse = _convert_container(X_dense, "sparse_csr") + + mr_graph_dense = mutual_reachability_graph(X_dense, min_samples=3) + mr_graph_sparse = mutual_reachability_graph(X_sparse, min_samples=3) + + assert_allclose(mr_graph_dense, mr_graph_sparse.A) + + +@pytest.mark.parametrize("array_type", ["array", "sparse_csr"]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_mutual_reachability_graph_preserve_dtype(array_type, dtype): + """Check that the computation preserve dtype thanks to fused types.""" + rng = np.random.RandomState(0) + X = rng.randn(10, 10) + X = (X.T @ X).astype(dtype) + np.fill_diagonal(X, 0.0) + X = _convert_container(X, array_type) + + assert X.dtype == dtype + mr_graph = mutual_reachability_graph(X) + assert mr_graph.dtype == dtype