From 28c54b95b47af0e1a585088f6100d68fa45efd20 Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Thu, 23 Feb 2023 20:30:42 +0100 Subject: [PATCH 1/4] validate params for mean_shift() --- sklearn/cluster/_mean_shift.py | 4 ++-- sklearn/tests/test_public_functions.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index cbed1c5be3b9e..728d07a2d6735 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -118,7 +118,7 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): completed_iterations += 1 return tuple(my_mean), len(points_within), completed_iterations - +@validate_params({ "X": ["array-like"] }) def mean_shift( X, *, @@ -141,7 +141,7 @@ def mean_shift( Input data. bandwidth : float, default=None - Kernel bandwidth. + Kernel bandwidth. A float value > 0. If bandwidth is not given, it is determined using a heuristic based on the median of all pairwise distances. This will take quadratic time in diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 9b2b56cdb3eb8..cc88bc0afcb48 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -155,6 +155,7 @@ def test_function_param_validation(func_module): PARAM_VALIDATION_CLASS_WRAPPER_LIST = [ ("sklearn.cluster.affinity_propagation", "sklearn.cluster.AffinityPropagation"), + ("sklearn.cluster.mean_shift", "sklearn.cluster.MeanShift"), ("sklearn.cluster.spectral_clustering", "sklearn.cluster.SpectralClustering"), ("sklearn.covariance.ledoit_wolf", "sklearn.covariance.LedoitWolf"), ("sklearn.covariance.oas", "sklearn.covariance.OAS"), From 784a59e34bc598df16dc2fff0b7f527f73329432 Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Thu, 23 Feb 2023 20:40:29 +0100 Subject: [PATCH 2/4] iterations From 4eb4b72edcdb200a55cc289e5baacc0532e1389f Mon Sep 17 00:00:00 2001 From: "wishyut.pitawanik" Date: Thu, 23 Feb 2023 20:42:46 +0100 Subject: [PATCH 3/4] fixed linting yay --- sklearn/cluster/_mean_shift.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 728d07a2d6735..1ea23ee6c1052 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -118,7 +118,8 @@ def _mean_shift_single_seed(my_mean, X, nbrs, max_iter): completed_iterations += 1 return tuple(my_mean), len(points_within), completed_iterations -@validate_params({ "X": ["array-like"] }) + +@validate_params({"X": ["array-like"]}) def mean_shift( X, *, From e154d87815097696dda5275ee4ddba082d501d72 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 24 Feb 2023 10:25:48 +0100 Subject: [PATCH 4/4] rephrase --- sklearn/cluster/_mean_shift.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/cluster/_mean_shift.py b/sklearn/cluster/_mean_shift.py index 1ea23ee6c1052..46a00ed3f0740 100644 --- a/sklearn/cluster/_mean_shift.py +++ b/sklearn/cluster/_mean_shift.py @@ -142,9 +142,9 @@ def mean_shift( Input data. bandwidth : float, default=None - Kernel bandwidth. A float value > 0. + Kernel bandwidth. If not None, must be in the range [0, +inf). - If bandwidth is not given, it is determined using a heuristic based on + If None, the bandwidth is determined using a heuristic based on the median of all pairwise distances. This will take quadratic time in the number of samples. The sklearn.cluster.estimate_bandwidth function can be used to do this more efficiently.