diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index cbed1c5be3b9e..46a00ed3f0740 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -119,6 +119,7 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): return tuple(my_mean), len(points_within), completed_iterations +@validate_params({"X": ["array-like"]}) def mean_shift( X, *, @@ -141,9 +142,9 @@ def mean_shift( Input data. bandwidth : float, default=None - Kernel bandwidth. + Kernel bandwidth. If not None, must be in the range [0, +inf). - If bandwidth is not given, it is determined using a heuristic based on + If None, the bandwidth is determined using a heuristic based on the median of all pairwise distances. This will take quadratic time in the number of samples. The sklearn.cluster.estimate_bandwidth function can be used to do this more efficiently. diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 9b2b56cdb3eb8..cc88bc0afcb48 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -155,6 +155,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), + ("sklearn.cluster.mean_shift", "sklearn.cluster.MeanShift"), ("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ("sklearn.covariance.oas", "sklearn.covariance.OAS"),