From 068dccf2f457b0b4b2990c7ab8effc87573510b8 Mon Sep 17 00:00:00 2001 From: luis Date: Thu, 27 Jul 2023 16:09:21 +0200 Subject: [PATCH 1/4] add @validate_params to dbscan --- sklearn/cluster/_dbscan.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py index 30205f70ae157..43caa0d936a93 100644 --- a/sklearn/cluster/_dbscan.py +++ b/sklearn/cluster/_dbscan.py @@ -17,11 +17,29 @@ from ..base import BaseEstimator, ClusterMixin, _fit_context from ..metrics.pairwise import _VALID_METRICS from ..neighbors import NearestNeighbors -from ..utils._param_validation import Interval, StrOptions +from ..utils._param_validation import Interval, StrOptions, validate_params from ..utils.validation import _check_sample_weight from ._dbscan_inner import dbscan_inner +@validate_params( + { + "X": ["array-like", "sparse matrix"], + "eps": [Interval(Real, 0.0, None, closed="neither")], + "min_samples": [Interval(Integral, 1, None, closed="left")], + "metric": [ + StrOptions(set(_VALID_METRICS) | {"precomputed"}), + callable, + ], + "metric_params": [dict, None], + "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})], + "leaf_size": [Interval(Integral, 1, None, closed="left")], + "p": [Interval(Real, 0.0, None, closed="left"), None], + "sample_weight": ["array-like", None], + "n_jobs": [Integral, None], + }, + prefer_skip_nested_validation=False, +) def dbscan( X, eps=0.5, From 4bdc1052756fc83877d688fc9bd40c948c6a6e58 Mon Sep 17 00:00:00 2001 From: luis Date: Thu, 27 Jul 2023 16:10:36 +0200 Subject: [PATCH 2/4] update public function list --- sklearn/tests/test_public_functions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index bd20de37d405e..7f9d1f221bba4 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -133,6 +133,7 @@ def 
_check_function_param_validation( "sklearn.calibration.calibration_curve", "sklearn.cluster.cluster_optics_dbscan", "sklearn.cluster.compute_optics_graph", + "sklearn.cluster.dbscan", "sklearn.cluster.estimate_bandwidth", "sklearn.cluster.kmeans_plusplus", "sklearn.cluster.cluster_optics_xi", From 21e2bf80664afc6fc8227bf7c48800c9747ac3f3 Mon Sep 17 00:00:00 2001 From: Luis Silvestrin Date: Thu, 27 Jul 2023 16:51:33 +0200 Subject: [PATCH 3/4] replacing unnecessary annotator by comment --- sklearn/cluster/_dbscan.py | 20 ++------------------ sklearn/tests/test_public_functions.py | 1 - 2 files changed, 2 insertions(+), 19 deletions(-) diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py index 43caa0d936a93..0d0b3cfde89f4 100644 --- a/sklearn/cluster/_dbscan.py +++ b/sklearn/cluster/_dbscan.py @@ -22,24 +22,8 @@ from ._dbscan_inner import dbscan_inner -@validate_params( - { - "X": ["array-like", "sparse matrix"], - "eps": [Interval(Real, 0.0, None, closed="neither")], - "min_samples": [Interval(Integral, 1, None, closed="left")], - "metric": [ - StrOptions(set(_VALID_METRICS) | {"precomputed"}), - callable, - ], - "metric_params": [dict, None], - "algorithm": [StrOptions({"auto", "ball_tree", "kd_tree", "brute"})], - "leaf_size": [Interval(Integral, 1, None, closed="left")], - "p": [Interval(Real, 0.0, None, closed="left"), None], - "sample_weight": ["array-like", None], - "n_jobs": [Integral, None], - }, - prefer_skip_nested_validation=False, -) +# This function is not validated using validate_params because +# it's just a factory for DBSCAN. 
def dbscan( X, eps=0.5, diff --git a/sklearn/tests/test_public_functions.py b/sklearn/tests/test_public_functions.py index 7f9d1f221bba4..bd20de37d405e 100644 --- a/sklearn/tests/test_public_functions.py +++ b/sklearn/tests/test_public_functions.py @@ -133,7 +133,6 @@ def _check_function_param_validation( "sklearn.calibration.calibration_curve", "sklearn.cluster.cluster_optics_dbscan", "sklearn.cluster.compute_optics_graph", - "sklearn.cluster.dbscan", "sklearn.cluster.estimate_bandwidth", "sklearn.cluster.kmeans_plusplus", "sklearn.cluster.cluster_optics_xi", From 7aa1bead6f2f43a9072a8e41b2f7bd7339d479be Mon Sep 17 00:00:00 2001 From: Luis Silvestrin Date: Thu, 27 Jul 2023 16:52:26 +0200 Subject: [PATCH 4/4] remove unused imports --- sklearn/cluster/_dbscan.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py index 0d0b3cfde89f4..4dd09c9531c44 100644 --- a/sklearn/cluster/_dbscan.py +++ b/sklearn/cluster/_dbscan.py @@ -17,7 +17,7 @@ from ..base import BaseEstimator, ClusterMixin, _fit_context from ..metrics.pairwise import _VALID_METRICS from ..neighbors import NearestNeighbors -from ..utils._param_validation import Interval, StrOptions, validate_params +from ..utils._param_validation import Interval, StrOptions from ..utils.validation import _check_sample_weight from ._dbscan_inner import dbscan_inner