Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 2f2364d

Browse files
SanjayMarreddi, thomasjpfan, glemaitre
authored
MNT Use check_scalar in BIRCH and DBSCAN (#20816)
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
1 parent b2ee0f4 commit 2f2364d

File tree

4 files changed

+136
-17
lines changed

4 files changed

+136
-17
lines changed

sklearn/cluster/_birch.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from ..metrics.pairwise import euclidean_distances
1414
from ..base import TransformerMixin, ClusterMixin, BaseEstimator
1515
from ..utils.extmath import row_norms
16-
from ..utils import deprecated
16+
from ..utils import check_scalar, deprecated
1717
from ..utils.validation import check_is_fitted
1818
from ..exceptions import ConvergenceWarning
1919
from . import AgglomerativeClustering
@@ -512,7 +512,31 @@ def fit(self, X, y=None):
512512
self
513513
Fitted estimator.
514514
"""
515-
# TODO: Remove deprecated flags in 1.2
515+
516+
# Validating the scalar parameters.
517+
check_scalar(
518+
self.threshold,
519+
"threshold",
520+
target_type=numbers.Real,
521+
min_val=0.0,
522+
include_boundaries="neither",
523+
)
524+
check_scalar(
525+
self.branching_factor,
526+
"branching_factor",
527+
target_type=numbers.Integral,
528+
min_val=1,
529+
include_boundaries="neither",
530+
)
531+
if isinstance(self.n_clusters, numbers.Number):
532+
check_scalar(
533+
self.n_clusters,
534+
"n_clusters",
535+
target_type=numbers.Integral,
536+
min_val=1,
537+
)
538+
539+
# TODO: Remove deprecated flags in 1.2
516540
self._deprecated_fit, self._deprecated_partial_fit = True, False
517541
return self._fit(X, partial=False)
518542

@@ -526,8 +550,6 @@ def _fit(self, X, partial):
526550
threshold = self.threshold
527551
branching_factor = self.branching_factor
528552

529-
if branching_factor <= 1:
530-
raise ValueError("Branching_factor should be greater than one.")
531553
n_samples, n_features = X.shape
532554

533555
# If partial_fit is called for the first time or fit is called, we
@@ -700,7 +722,7 @@ def _global_clustering(self, X=None):
700722
if len(centroids) < self.n_clusters:
701723
not_enough_centroids = True
702724
elif clusterer is not None and not hasattr(clusterer, "fit_predict"):
703-
raise ValueError(
725+
raise TypeError(
704726
"n_clusters should be an instance of ClusterMixin or an int"
705727
)
706728

sklearn/cluster/_dbscan.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
# License: BSD 3 clause
1111

1212
import numpy as np
13+
import numbers
1314
import warnings
1415
from scipy import sparse
1516

17+
from ..utils import check_scalar
1618
from ..base import BaseEstimator, ClusterMixin
1719
from ..utils.validation import _check_sample_weight
1820
from ..neighbors import NearestNeighbors
@@ -345,9 +347,6 @@ def fit(self, X, y=None, sample_weight=None):
345347
"""
346348
X = self._validate_data(X, accept_sparse="csr")
347349

348-
if not self.eps > 0.0:
349-
raise ValueError("eps must be positive.")
350-
351350
if sample_weight is not None:
352351
sample_weight = _check_sample_weight(sample_weight, X)
353352

@@ -361,6 +360,39 @@ def fit(self, X, y=None, sample_weight=None):
361360
warnings.simplefilter("ignore", sparse.SparseEfficiencyWarning)
362361
X.setdiag(X.diagonal()) # XXX: modifies X's internals in-place
363362

363+
# Validating the scalar parameters.
364+
check_scalar(
365+
self.eps,
366+
"eps",
367+
target_type=numbers.Real,
368+
min_val=0.0,
369+
include_boundaries="neither",
370+
)
371+
check_scalar(
372+
self.min_samples,
373+
"min_samples",
374+
target_type=numbers.Integral,
375+
min_val=1,
376+
include_boundaries="left",
377+
)
378+
check_scalar(
379+
self.leaf_size,
380+
"leaf_size",
381+
target_type=numbers.Integral,
382+
min_val=1,
383+
include_boundaries="left",
384+
)
385+
if self.p is not None:
386+
check_scalar(
387+
self.p,
388+
"p",
389+
target_type=numbers.Real,
390+
min_val=0.0,
391+
include_boundaries="left",
392+
)
393+
if self.n_jobs is not None:
394+
check_scalar(self.n_jobs, "n_jobs", target_type=numbers.Integral)
395+
364396
neighbors_model = NearestNeighbors(
365397
radius=self.eps,
366398
algorithm=self.algorithm,

sklearn/cluster/tests/test_birch.py

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,8 @@ def test_n_clusters():
8585
# Test that the wrong global clustering step raises an Error.
8686
clf = ElasticNet()
8787
brc3 = Birch(n_clusters=clf)
88-
with pytest.raises(ValueError):
88+
err_msg = "n_clusters should be an instance of ClusterMixin or an int"
89+
with pytest.raises(TypeError, match=err_msg):
8990
brc3.fit(X)
9091

9192
# Test that a small number of clusters raises a warning.
@@ -141,11 +142,6 @@ def test_branching_factor():
141142
brc.fit(X)
142143
check_branching_factor(brc.root_, branching_factor)
143144

144-
# Raises error when branching_factor is set to one.
145-
brc = Birch(n_clusters=None, branching_factor=1, threshold=0.01)
146-
with pytest.raises(ValueError):
147-
brc.fit(X)
148-
149145

150146
def check_threshold(birch_instance, threshold):
151147
"""Use the leaf linked list for traversal"""
@@ -187,3 +183,39 @@ def test_birch_fit_attributes_deprecated(attribute):
187183

188184
with pytest.warns(FutureWarning, match=msg):
189185
getattr(brc, attribute)
186+
187+
188+
@pytest.mark.parametrize(
189+
"params, err_type, err_msg",
190+
[
191+
({"threshold": -1.0}, ValueError, "threshold == -1.0, must be > 0.0."),
192+
({"threshold": 0.0}, ValueError, "threshold == 0.0, must be > 0.0."),
193+
({"branching_factor": 0}, ValueError, "branching_factor == 0, must be > 1."),
194+
({"branching_factor": 1}, ValueError, "branching_factor == 1, must be > 1."),
195+
(
196+
{"branching_factor": 1.5},
197+
TypeError,
198+
"branching_factor must be an instance of <class 'numbers.Integral'>, not"
199+
" <class 'float'>.",
200+
),
201+
({"branching_factor": -2}, ValueError, "branching_factor == -2, must be > 1."),
202+
({"n_clusters": 0}, ValueError, "n_clusters == 0, must be >= 1."),
203+
(
204+
{"n_clusters": 2.5},
205+
TypeError,
206+
"n_clusters must be an instance of <class 'numbers.Integral'>, not <class"
207+
" 'float'>.",
208+
),
209+
(
210+
{"n_clusters": "whatever"},
211+
TypeError,
212+
"n_clusters should be an instance of ClusterMixin or an int",
213+
),
214+
({"n_clusters": -3}, ValueError, "n_clusters == -3, must be >= 1."),
215+
],
216+
)
217+
def test_birch_params_validation(params, err_type, err_msg):
218+
"""Check the parameters validation in `Birch`."""
219+
X, _ = make_blobs(n_samples=80, centers=4)
220+
with pytest.raises(err_type, match=err_msg):
221+
Birch(**params).fit(X)

sklearn/cluster/tests/test_dbscan.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,11 +272,8 @@ def test_input_validation():
272272
@pytest.mark.parametrize(
273273
"args",
274274
[
275-
{"eps": -1.0},
276275
{"algorithm": "blah"},
277276
{"metric": "blah"},
278-
{"leaf_size": -1},
279-
{"p": -1},
280277
],
281278
)
282279
def test_dbscan_badargs(args):
@@ -428,3 +425,39 @@ def test_dbscan_precomputed_metric_with_initial_rows_zero():
428425
matrix = sparse.csr_matrix(ar)
429426
labels = DBSCAN(eps=0.2, metric="precomputed", min_samples=2).fit(matrix).labels_
430427
assert_array_equal(labels, [-1, -1, 0, 0, 0, 1, 1])
428+
429+
430+
@pytest.mark.parametrize(
431+
"params, err_type, err_msg",
432+
[
433+
({"eps": -1.0}, ValueError, "eps == -1.0, must be > 0.0."),
434+
({"eps": 0.0}, ValueError, "eps == 0.0, must be > 0.0."),
435+
({"min_samples": 0}, ValueError, "min_samples == 0, must be >= 1."),
436+
(
437+
{"min_samples": 1.5},
438+
TypeError,
439+
"min_samples must be an instance of <class 'numbers.Integral'>, not <class"
440+
" 'float'>.",
441+
),
442+
({"min_samples": -2}, ValueError, "min_samples == -2, must be >= 1."),
443+
({"leaf_size": 0}, ValueError, "leaf_size == 0, must be >= 1."),
444+
(
445+
{"leaf_size": 2.5},
446+
TypeError,
447+
"leaf_size must be an instance of <class 'numbers.Integral'>, not <class"
448+
" 'float'>.",
449+
),
450+
({"leaf_size": -3}, ValueError, "leaf_size == -3, must be >= 1."),
451+
({"p": -2}, ValueError, "p == -2, must be >= 0.0."),
452+
(
453+
{"n_jobs": 2.5},
454+
TypeError,
455+
"n_jobs must be an instance of <class 'numbers.Integral'>, not <class"
456+
" 'float'>.",
457+
),
458+
],
459+
)
460+
def test_dbscan_params_validation(params, err_type, err_msg):
461+
"""Check the parameters validation in `DBSCAN`."""
462+
with pytest.raises(err_type, match=err_msg):
463+
DBSCAN(**params).fit(X)

0 commit comments

Comments
 (0)