diff --git a/sklearn/cluster/_birch.py b/sklearn/cluster/_birch.py index 78c15bb8e1a15..b1f3fb5a64a74 100644 --- a/sklearn/cluster/_birch.py +++ b/sklearn/cluster/_birch.py @@ -13,7 +13,7 @@ from ..metrics.pairwise import euclidean_distances from ..base import TransformerMixin, ClusterMixin, BaseEstimator from ..utils.extmath import row_norms -from ..utils import deprecated +from ..utils import check_scalar, deprecated from ..utils.validation import check_is_fitted from ..exceptions import ConvergenceWarning from . import AgglomerativeClustering @@ -512,7 +512,31 @@ def fit(self, X, y=None): self Fitted estimator. """ - # TODO: Remove deprecated flags in 1.2 + + # Validating the scalar parameters. + check_scalar( + self.threshold, + "threshold", + target_type=numbers.Real, + min_val=0.0, + include_boundaries="neither", + ) + check_scalar( + self.branching_factor, + "branching_factor", + target_type=numbers.Integral, + min_val=1, + include_boundaries="neither", + ) + if isinstance(self.n_clusters, numbers.Number): + check_scalar( + self.n_clusters, + "n_clusters", + target_type=numbers.Integral, + min_val=1, + ) + + # TODO: Remove deprecated flags in 1.2 self._deprecated_fit, self._deprecated_partial_fit = True, False return self._fit(X, partial=False) @@ -526,8 +550,6 @@ def _fit(self, X, partial): threshold = self.threshold branching_factor = self.branching_factor - if branching_factor <= 1: - raise ValueError("Branching_factor should be greater than one.") n_samples, n_features = X.shape # If partial_fit is called for the first time or fit is called, we @@ -700,7 +722,7 @@ def _global_clustering(self, X=None): if len(centroids) < self.n_clusters: not_enough_centroids = True elif clusterer is not None and not hasattr(clusterer, "fit_predict"): - raise ValueError( + raise TypeError( "n_clusters should be an instance of ClusterMixin or an int" ) diff --git a/sklearn/cluster/_dbscan.py b/sklearn/cluster/_dbscan.py index 9826af09ec372..0994248e01697 100644 --- 
a/sklearn/cluster/_dbscan.py +++ b/sklearn/cluster/_dbscan.py @@ -10,9 +10,11 @@ # License: BSD 3 clause import numpy as np +import numbers import warnings from scipy import sparse +from ..utils import check_scalar from ..base import BaseEstimator, ClusterMixin from ..utils.validation import _check_sample_weight from ..neighbors import NearestNeighbors @@ -345,9 +347,6 @@ def fit(self, X, y=None, sample_weight=None): """ X = self._validate_data(X, accept_sparse="csr") - if not self.eps > 0.0: - raise ValueError("eps must be positive.") - if sample_weight is not None: sample_weight = _check_sample_weight(sample_weight, X) @@ -361,6 +360,39 @@ def fit(self, X, y=None, sample_weight=None): warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning) X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place + # Validating the scalar parameters. + check_scalar( + self.eps, + "eps", + target_type=numbers.Real, + min_val=0.0, + include_boundaries="neither", + ) + check_scalar( + self.min_samples, + "min_samples", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + check_scalar( + self.leaf_size, + "leaf_size", + target_type=numbers.Integral, + min_val=1, + include_boundaries="left", + ) + if self.p is not None: + check_scalar( + self.p, + "p", + target_type=numbers.Real, + min_val=0.0, + include_boundaries="left", + ) + if self.n_jobs is not None: + check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral) + neighbors_model = NearestNeighbors( radius=self.eps, algorithm=self.algorithm, diff --git a/sklearn/cluster/tests/test_birch.py b/sklearn/cluster/tests/test_birch.py index d32013e4e1314..5d8a3222ef156 100644 --- a/sklearn/cluster/tests/test_birch.py +++ b/sklearn/cluster/tests/test_birch.py @@ -85,7 +85,8 @@ def test_n_clusters(): # Test that the wrong global clustering step raises an Error. 
clf = ElasticNet() brc3 = Birch(n_clusters=clf) - with pytest.raises(ValueError): + err_msg = "n_clusters should be an instance of ClusterMixin or an int" + with pytest.raises(TypeError, match=err_msg): brc3.fit(X) # Test that a small number of clusters raises a warning. @@ -141,11 +142,6 @@ def test_branching_factor(): brc.fit(X) check_branching_factor(brc.root_, branching_factor) - # Raises error when branching_factor is set to one. - brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01) - with pytest.raises(ValueError): - brc.fit(X) - def check_threshold(birch_instance, threshold): """Use the leaf linked list for traversal""" @@ -187,3 +183,39 @@ def test_birch_fit_attributes_deprecated(attribute): with pytest.warns(FutureWarning, match=msg): getattr(brc, attribute) + + +@pytest.mark.parametrize( + "params, err_type, err_msg", + [ + ({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."), + ({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."), + ({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."), + ({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."), + ( + {"branching_factor": 1.5}, + TypeError, + "branching_factor must be an instance of <class 'numbers.Integral'>, not" + " <class 'float'>.", + ), + ({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."), + ({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."), + ( + {"n_clusters": 2.5}, + TypeError, + "n_clusters must be an instance of <class 'numbers.Integral'>, not <class 'float'>.", + ), + ( + {"n_clusters": "whatever"}, + TypeError, + "n_clusters should be an instance of ClusterMixin or an int", + ), + ({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."), + ], +) +def test_birch_params_validation(params, err_type, err_msg): + """Check the parameters validation in `Birch`.""" + X, _ = make_blobs(n_samples=80, centers=4) + with pytest.raises(err_type, match=err_msg): + Birch(**params).fit(X) diff --git a/sklearn/cluster/tests/test_dbscan.py 
b/sklearn/cluster/tests/test_dbscan.py index 1c5ef8e58b2c5..40949e81a24b1 100644 --- a/sklearn/cluster/tests/test_dbscan.py +++ b/sklearn/cluster/tests/test_dbscan.py @@ -272,11 +272,8 @@ def test_input_validation(): @pytest.mark.parametrize( "args", [ - {"eps": -1.0}, {"algorithm": "blah"}, {"metric": "blah"}, - {"leaf_size": -1}, - {"p": -1}, ], ) def test_dbscan_badargs(args): @@ -428,3 +425,39 @@ def test_dbscan_precomputed_metric_with_initial_rows_zero(): matrix = sparse.csr_matrix(ar) labels = DBSCAN(eps=0.2, metric="precomputed", min_samples=2).fit(matrix).labels_ assert_array_equal(labels, [-1, -1, 0, 0, 0, 1, 1]) + + +@pytest.mark.parametrize( + "params, err_type, err_msg", + [ + ({"eps": -1.0}, ValueError, "eps == -1.0, must be > 0.0."), + ({"eps": 0.0}, ValueError, "eps == 0.0, must be > 0.0."), + ({"min_samples": 0}, ValueError, "min_samples == 0, must be >= 1."), + ( + {"min_samples": 1.5}, + TypeError, + "min_samples must be an instance of <class 'numbers.Integral'>, not <class 'float'>.", + ), + ({"min_samples": -2}, ValueError, "min_samples == -2, must be >= 1."), + ({"leaf_size": 0}, ValueError, "leaf_size == 0, must be >= 1."), + ( + {"leaf_size": 2.5}, + TypeError, + "leaf_size must be an instance of <class 'numbers.Integral'>, not <class 'float'>.", + ), + ({"leaf_size": -3}, ValueError, "leaf_size == -3, must be >= 1."), + ({"p": -2}, ValueError, "p == -2, must be >= 0.0."), + ( + {"n_jobs": 2.5}, + TypeError, + "n_jobs must be an instance of <class 'numbers.Integral'>, not <class 'float'>.", + ), + ], +) +def test_dbscan_params_validation(params, err_type, err_msg): + """Check the parameters validation in `DBSCAN`.""" + with pytest.raises(err_type, match=err_msg): + DBSCAN(**params).fit(X)