From 7258ccfc1322fd0b9a9361fbc1d68877c61651fa Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 10 Aug 2021 12:27:52 +0200 Subject: [PATCH 1/8] MNT use check_scalar to validate scalar in AffinityPropagation --- sklearn/cluster/_affinity_propagation.py | 20 ++++- .../tests/test_affinity_propagation.py | 38 +++++++--- sklearn/utils/tests/test_validation.py | 73 ++++++++++++++++--- sklearn/utils/validation.py | 40 ++++++++-- 4 files changed, 140 insertions(+), 31 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index b167f1cb2c212..0c594f8ec38ef 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -5,12 +5,15 @@ # License: BSD 3 clause -import numpy as np +import numbers import warnings +import numpy as np + from ..exceptions import ConvergenceWarning from ..base import BaseEstimator, ClusterMixin from ..utils import as_float_array, check_random_state +from ..utils import check_scalar from ..utils.deprecation import deprecated from ..utils.validation import check_is_fitted from ..metrics import euclidean_distances @@ -132,8 +135,6 @@ def affinity_propagation( if preference is None: preference = np.median(S) - if damping < 0.5 or damping >= 1: - raise ValueError("damping must be >= 0.5 and < 1") preference = np.array(preference) @@ -456,6 +457,19 @@ def fit(self, X, y=None): % str(self.affinity) ) + check_scalar( + self.damping, + "damping", + numbers.Real, + min_val=0.5, + max_val=1, + strictly_less_max_val=True, + ) + check_scalar(self.max_iter, "max_iter", numbers.Integral, min_val=1) + check_scalar( + self.convergence_iter, "convergence_iter", numbers.Integral, min_val=1 + ) + ( self.cluster_centers_indices_, self.labels_, diff --git a/sklearn/cluster/tests/test_affinity_propagation.py b/sklearn/cluster/tests/test_affinity_propagation.py index ac3ef2a6a16fd..4fa9f14c01487 100644 --- a/sklearn/cluster/tests/test_affinity_propagation.py +++ b/sklearn/cluster/tests/test_affinity_propagation.py @@ -64,17 +64,35 @@ def test_affinity_propagation(): ) assert_array_equal(labels, labels_no_copy) - # Test input validation - with pytest.raises(ValueError): + +def test_affinity_propagation_affinity_shape(): + """Check the shape of the affinity matrix when using `affinity_propagation.""" + S = -euclidean_distances(X, squared=True) + err_msg = "S must be a square array" + with pytest.raises(ValueError, match=err_msg): affinity_propagation(S[:, :-1]) - with pytest.raises(ValueError): - affinity_propagation(S, damping=0) - af = AffinityPropagation(affinity="unknown", random_state=78) - with pytest.raises(ValueError): - af.fit(X) - af_2 = AffinityPropagation(affinity="precomputed", random_state=21) - with pytest.raises(TypeError): - af_2.fit(csr_matrix((3, 3))) + + +@pytest.mark.parametrize( + "input, params, err_type, err_msg", + [ + (X, {"damping": 0}, ValueError, "`damping`= 0, must be >= 0.5"), + (X, {"damping": 2}, ValueError, "`damping`= 2, must be < 1"), + (X, {"max_iter": 0}, ValueError, "`max_iter`= 0, must be >= 1."), + (X, {"convergence_iter": 0}, ValueError, "`convergence_iter`= 0, must be >= 1"), + (X, {"affinity": "unknown"}, ValueError, "Affinity must be"), + ( + csr_matrix((3, 3)), + {"affinity": "precomputed"}, + TypeError, + "A sparse matrix was passed, but dense data is required", + ), + ], +) +def test_affinity_propagation_params_validation(input, params, err_type, err_msg): + """Check the parameters validation in `AffinityPropagation`.""" + with pytest.raises(err_type, match=err_msg): + AffinityPropagation(**params).fit(input) def test_affinity_propagation_predict(): diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 1a1449ecc209f..2d323b299203e 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1,5 +1,6 @@ """Tests for input validation functions""" +import numbers import warnings import os import re @@ -1004,43 +1005,95 @@ def __init__(self): _num_samples(TestNoLenWeirdShape()) -@pytest.mark.parametrize( - "x, target_type, min_val, max_val", [(3, int, 2, 5), (2.5, float, 2, 5)] -) -def test_check_scalar_valid(x, target_type, min_val, max_val): +@pytest.mark.parametrize("x", [2, 3, 2.5, 5]) +def test_check_scalar_valid(x): """Test that check_scalar returns no error/warning if valid inputs are provided""" with pytest.warns(None) as record: check_scalar( - x, "test_name", target_type=target_type, min_val=min_val, max_val=max_val + x, + "test_name", + target_type=numbers.Real, + min_val=2, + strictly_greater_min_val=False, + max_val=5, + strictly_less_max_val=False, ) assert len(record) == 0 @pytest.mark.parametrize( - "x, target_name, target_type, min_val, max_val, err_msg", + "x, target_name, target_type, min_val, strictly_gt, max_val, strictly_lt, err_msg", [ ( 1, "test_name1", float, 2, + False, 4, + False, TypeError( "`test_name1` must be an instance of " ", not ." ), ), - (1, "test_name2", int, 2, 4, ValueError("`test_name2`= 1, must be >= 2.")), - (5, "test_name3", int, 2, 4, ValueError("`test_name3`= 5, must be <= 4.")), + ( + 1, + "test_name2", + int, + 2, + False, + 4, + False, + ValueError("`test_name2`= 1, must be >= 2."), + ), + ( + 5, + "test_name3", + int, + 2, + False, + 4, + False, + ValueError("`test_name3`= 5, must be <= 4."), + ), + ( + 2, + "test_name4", + int, + 2, + True, + 4, + False, + ValueError("`test_name4`= 2, must be > 2."), + ), + ( + 4, + "test_name5", + int, + 2, + False, + 4, + True, + ValueError("`test_name5`= 4, must be < 4."), + ), ], ) -def test_check_scalar_invalid(x, target_name, target_type, min_val, max_val, err_msg): +def test_check_scalar_invalid( + x, target_name, target_type, min_val, strictly_gt, max_val, strictly_lt, err_msg +): """Test that check_scalar returns the right error if a wrong input is given""" with pytest.raises(Exception) as raised_error: check_scalar( - x, target_name, target_type=target_type, min_val=min_val, max_val=max_val + x, + target_name, + target_type=target_type, + min_val=min_val, + strictly_greater_min_val=strictly_gt, + max_val=max_val, + strictly_less_max_val=strictly_lt, ) assert str(raised_error.value) == str(err_msg) assert type(raised_error.value) == type(err_msg) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 98bf6ac8bdb6a..3e233d68ccb72 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -12,6 +12,7 @@ from functools import wraps import warnings import numbers +import operator import numpy as np import scipy.sparse as sp @@ -1231,7 +1232,16 @@ def check_non_negative(X, whom): raise ValueError("Negative values in data passed to %s" % whom) -def check_scalar(x, name, target_type, *, min_val=None, max_val=None): +def check_scalar( + x, + name, + target_type, + *, + min_val=None, + strictly_greater_min_val=False, + max_val=None, + strictly_less_max_val=False, +): """Validate scalar parameters type and value. Parameters @@ -1249,12 +1259,18 @@ def check_scalar(x, name, target_type, *, min_val=None, max_val=None): The minimum valid value the parameter can take. If None (default) it is implied that the parameter does not have a lower bound. - max_val : float or int, default=None + strictly_greater_min_val : bool, default=True + Whether the parameter should be strictly greater to `min_val`. + + max_val : float or int, default=False The maximum valid value the parameter can take. If None (default) it is implied that the parameter does not have an upper bound. + strictly_less_max_val : bool, default=False + Whether the parameter should be strictly less to `max_val`. + Raises - ------- + ------ TypeError If the parameter's type does not match the desired type. @@ -1264,14 +1280,22 @@ def check_scalar(x, name, target_type, *, min_val=None, max_val=None): if not isinstance(x, target_type): raise TypeError( - "`{}` must be an instance of {}, not {}.".format(name, target_type, type(x)) + f"`{name}` must be an instance of {target_type}, not {type(x)}." ) - if min_val is not None and x < min_val: - raise ValueError("`{}`= {}, must be >= {}.".format(name, x, min_val)) + comparison_operator = operator.le if strictly_greater_min_val else operator.lt + if min_val is not None and comparison_operator(x, min_val): + raise ValueError( + f"`{name}`= {x}, must be {'>' if strictly_greater_min_val else '>='} " + f"{min_val}." + ) - if max_val is not None and x > max_val: - raise ValueError("`{}`= {}, must be <= {}.".format(name, x, max_val)) + comparison_operator = operator.ge if strictly_less_max_val else operator.gt + if max_val is not None and comparison_operator(x, max_val): + raise ValueError( + f"`{name}`= {x}, must be {'<' if strictly_less_max_val else '<='} " + f"{max_val}." + ) def _check_psd_eigenvalues(lambdas, enable_warnings=False): From 36f8c7ce50dfa3c64b41c2dc58aa94d4177906f9 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 10 Aug 2021 12:55:43 +0200 Subject: [PATCH 2/8] refactor --- sklearn/cluster/_affinity_propagation.py | 26 +++++++++++++----------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 0c594f8ec38ef..9f745e4a8a936 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -457,18 +457,20 @@ def fit(self, X, y=None): % str(self.affinity) ) - check_scalar( - self.damping, - "damping", - numbers.Real, - min_val=0.5, - max_val=1, - strictly_less_max_val=True, - ) - check_scalar(self.max_iter, "max_iter", numbers.Integral, min_val=1) - check_scalar( - self.convergence_iter, "convergence_iter", numbers.Integral, min_val=1 - ) + scalars_checks = { + "damping": { + "target_type": numbers.Real, + "min_val": 0.5, + "max_val": 1, + "strictly_less_max_val": True, + }, + "max_iter": {"target_type": numbers.Integral, "min_val": 1}, + "convergence_iter": {"target_type": numbers.Integral, "min_val": 1}, + } + for scalar_name in scalars_checks: + check_scalar( + getattr(self, scalar_name), scalar_name, **scalars_checks[scalar_name] + ) ( self.cluster_centers_indices_, From 47d71d05dddd815b14ae07edde24f657ed268b1c Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Sun, 22 Aug 2021 20:08:35 +0200 Subject: [PATCH 3/8] Apply suggestions from code review Co-authored-by: Olivier Grisel --- sklearn/cluster/tests/test_affinity_propagation.py | 8 ++++---- sklearn/utils/tests/test_validation.py | 8 ++++---- sklearn/utils/validation.py | 4 ++-- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sklearn/cluster/tests/test_affinity_propagation.py b/sklearn/cluster/tests/test_affinity_propagation.py index 4fa9f14c01487..e1b46cb728257 100644 --- a/sklearn/cluster/tests/test_affinity_propagation.py +++ b/sklearn/cluster/tests/test_affinity_propagation.py @@ -76,10 +76,10 @@ def test_affinity_propagation_affinity_shape(): @pytest.mark.parametrize( "input, params, err_type, err_msg", [ - (X, {"damping": 0}, ValueError, "`damping`= 0, must be >= 0.5"), - (X, {"damping": 2}, ValueError, "`damping`= 2, must be < 1"), - (X, {"max_iter": 0}, ValueError, "`max_iter`= 0, must be >= 1."), - (X, {"convergence_iter": 0}, ValueError, "`convergence_iter`= 0, must be >= 1"), + (X, {"damping": 0}, ValueError, "damping == 0, must be >= 0.5"), + (X, {"damping": 2}, ValueError, "damping == 2, must be < 1"), + (X, {"max_iter": 0}, ValueError, "max_iter == 0, must be >= 1."), + (X, {"convergence_iter": 0}, ValueError, "convergence_iter == 0, must be >= 1"), (X, {"affinity": "unknown"}, ValueError, "Affinity must be"), ( csr_matrix((3, 3)), diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 2d323b299203e..fce0e5a14fc7a 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1046,7 +1046,7 @@ def test_check_scalar_valid(x): False, 4, False, - ValueError("`test_name2`= 1, must be >= 2."), + ValueError("test_name2 == 1, must be >= 2."), ), ( 5, @@ -1056,7 +1056,7 @@ def test_check_scalar_valid(x): False, 4, False, - ValueError("`test_name3`= 5, must be <= 4."), + ValueError("test_name3 == 5, must be <= 4."), ), ( 2, @@ -1066,7 +1066,7 @@ def test_check_scalar_valid(x): True, 4, False, - ValueError("`test_name4`= 2, must be > 2."), + ValueError("test_name4 == 2, must be > 2."), ), ( 4, @@ -1076,7 +1076,7 @@ def test_check_scalar_valid(x): False, 4, True, - ValueError("`test_name5`= 4, must be < 4."), + ValueError("test_name5 == 4, must be < 4."), ), ], ) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 3e233d68ccb72..95cb952ab6ac9 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1286,14 +1286,14 @@ def check_scalar( comparison_operator = operator.le if strictly_greater_min_val else operator.lt if min_val is not None and comparison_operator(x, min_val): raise ValueError( - f"`{name}`= {x}, must be {'>' if strictly_greater_min_val else '>='} " + f" {name} == {x}, must be {'>' if strictly_greater_min_val else '>='} " f"{min_val}." ) comparison_operator = operator.ge if strictly_less_max_val else operator.gt if max_val is not None and comparison_operator(x, max_val): raise ValueError( - f"`{name}`= {x}, must be {'<' if strictly_less_max_val else '<='} " + f"{name} == {x}, must be {'<' if strictly_less_max_val else '<='} " f"{max_val}." ) From c0b04b014ffb43ee061d235db91ee2b89159c9d5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Aug 2021 13:37:27 +0200 Subject: [PATCH 4/8] iter --- sklearn/neighbors/tests/test_nca.py | 2 +- sklearn/utils/tests/test_validation.py | 3 +-- sklearn/utils/validation.py | 6 ++---- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/sklearn/neighbors/tests/test_nca.py b/sklearn/neighbors/tests/test_nca.py index f1ff623479a77..f9d7e5503a2c8 100644 --- a/sklearn/neighbors/tests/test_nca.py +++ b/sklearn/neighbors/tests/test_nca.py @@ -139,7 +139,7 @@ def test_params_validation(): ) with pytest.raises(ValueError, match=re.escape(msg)): NCA(init=1).fit(X, y) - with pytest.raises(ValueError, match="`max_iter`= -1, must be >= 1."): + with pytest.raises(ValueError, match="max_iter == -1, must be >= 1."): NCA(max_iter=-1).fit(X, y) init = rng.rand(5, 3) msg = ( diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index fce0e5a14fc7a..972ffbf1310ec 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1034,8 +1034,7 @@ def test_check_scalar_valid(x): 4, False, TypeError( - "`test_name1` must be an instance of " - ", not ." + "test_name1 must be an instance of , not ." ), ), ( diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 95cb952ab6ac9..54a622d06f03e 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1279,14 +1279,12 @@ def check_scalar( """ if not isinstance(x, target_type): - raise TypeError( - f"`{name}` must be an instance of {target_type}, not {type(x)}." - ) + raise TypeError(f"{name} must be an instance of {target_type}, not {type(x)}.") comparison_operator = operator.le if strictly_greater_min_val else operator.lt if min_val is not None and comparison_operator(x, min_val): raise ValueError( - f" {name} == {x}, must be {'>' if strictly_greater_min_val else '>='} " + f"{name} == {x}, must be {'>' if strictly_greater_min_val else '>='} " f"{min_val}." ) From be288d7ee59e1a339f73e455828da674a640f4a5 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Aug 2021 14:01:32 +0200 Subject: [PATCH 5/8] iter --- sklearn/cluster/_affinity_propagation.py | 29 ++++++++++++------------ sklearn/utils/tests/test_validation.py | 3 ++- sklearn/utils/validation.py | 7 ++++++ 3 files changed, 24 insertions(+), 15 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 9f745e4a8a936..5a09e6b81dce9 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -457,20 +457,21 @@ def fit(self, X, y=None): % str(self.affinity) ) - scalars_checks = { - "damping": { - "target_type": numbers.Real, - "min_val": 0.5, - "max_val": 1, - "strictly_less_max_val": True, - }, - "max_iter": {"target_type": numbers.Integral, "min_val": 1}, - "convergence_iter": {"target_type": numbers.Integral, "min_val": 1}, - } - for scalar_name in scalars_checks: - check_scalar( - getattr(self, scalar_name), scalar_name, **scalars_checks[scalar_name] - ) + check_scalar( + self.damping, + "damping", + target_type=numbers.Real, + min_val=0.5, + max_val=1, + strictly_less_max_val=True, + ) + check_scalar(self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1) + check_scalar( + self.convergence_iter, + "convergence_iter", + target_type=numbers.Integral, + min_val=1, + ) ( self.cluster_centers_indices_, diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 972ffbf1310ec..8f86504afacb9 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1010,7 +1010,7 @@ def test_check_scalar_valid(x): """Test that check_scalar returns no error/warning if valid inputs are provided""" with pytest.warns(None) as record: - check_scalar( + scalar = check_scalar( x, "test_name", target_type=numbers.Real, @@ -1020,6 +1020,7 @@ def test_check_scalar_valid(x): strictly_less_max_val=False, ) assert len(record) == 0 + assert scalar == x @pytest.mark.parametrize( diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 54a622d06f03e..c664d8de22cff 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1269,6 +1269,11 @@ def check_scalar( strictly_less_max_val : bool, default=False Whether the parameter should be strictly less to `max_val`. + Returns + ------- + x : object + The validated object. + Raises ------ TypeError @@ -1295,6 +1300,8 @@ def check_scalar( f"{max_val}." ) + return x + def _check_psd_eigenvalues(lambdas, enable_warnings=False): """Check the eigenvalues of a positive semidefinite (PSD) matrix. From b5ac07fa20626b30116316debbf6735f6f2ba521 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Aug 2021 14:12:54 +0200 Subject: [PATCH 6/8] iter --- sklearn/utils/validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index c664d8de22cff..c2caa3cd256ca 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1271,8 +1271,8 @@ def check_scalar( Returns ------- - x : object - The validated object. + x : numbers.Number + The validated number. Raises ------ From 9c2c61ba1dca461e7fdac59f82011d16a3994efe Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Aug 2021 19:35:50 +0200 Subject: [PATCH 7/8] iter --- sklearn/cluster/_affinity_propagation.py | 2 +- sklearn/utils/tests/test_validation.py | 8 ++++---- sklearn/utils/validation.py | 18 ++++++++---------- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index 5a09e6b81dce9..aefbe28ba67da 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -463,7 +463,7 @@ def fit(self, X, y=None): target_type=numbers.Real, min_val=0.5, max_val=1, - strictly_less_max_val=True, + strictly_lt_max_val=True, ) check_scalar(self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1) check_scalar( diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 8f86504afacb9..98e9365697cf1 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1015,9 +1015,9 @@ def test_check_scalar_valid(x): "test_name", target_type=numbers.Real, min_val=2, - strictly_greater_min_val=False, + strictly_gt_min_val=False, max_val=5, - strictly_less_max_val=False, + strictly_lt_max_val=False, ) assert len(record) == 0 assert scalar == x @@ -1091,9 +1091,9 @@ def test_check_scalar_invalid( target_name, target_type=target_type, min_val=min_val, - strictly_greater_min_val=strictly_gt, + strictly_gt_min_val=strictly_gt, max_val=max_val, - strictly_less_max_val=strictly_lt, + strictly_lt_max_val=strictly_lt, ) assert str(raised_error.value) == str(err_msg) assert type(raised_error.value) == type(err_msg) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index c2caa3cd256ca..51a89dbaf4092 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1238,9 +1238,9 @@ def check_scalar( target_type, *, min_val=None, - strictly_greater_min_val=False, + strictly_gt_min_val=False, max_val=None, - strictly_less_max_val=False, + strictly_lt_max_val=False, ): """Validate scalar parameters type and value. @@ -1259,14 +1259,14 @@ def check_scalar( The minimum valid value the parameter can take. If None (default) it is implied that the parameter does not have a lower bound. - strictly_greater_min_val : bool, default=True + strictly_gt_min_val : bool, default=True Whether the parameter should be strictly greater to `min_val`. max_val : float or int, default=False The maximum valid value the parameter can take. If None (default) it is implied that the parameter does not have an upper bound. - strictly_less_max_val : bool, default=False + strictly_lt_max_val : bool, default=False Whether the parameter should be strictly less to `max_val`. Returns @@ -1286,18 +1286,16 @@ def check_scalar( if not isinstance(x, target_type): raise TypeError(f"{name} must be an instance of {target_type}, not {type(x)}.") - comparison_operator = operator.le if strictly_greater_min_val else operator.lt + comparison_operator = operator.le if strictly_gt_min_val else operator.lt if min_val is not None and comparison_operator(x, min_val): raise ValueError( - f"{name} == {x}, must be {'>' if strictly_greater_min_val else '>='} " - f"{min_val}." + f"{name} == {x}, must be {'>' if strictly_gt_min_val else '>='} {min_val}." ) - comparison_operator = operator.ge if strictly_less_max_val else operator.gt + comparison_operator = operator.ge if strictly_lt_max_val else operator.gt if max_val is not None and comparison_operator(x, max_val): raise ValueError( - f"{name} == {x}, must be {'<' if strictly_less_max_val else '<='} " - f"{max_val}." + f"{name} == {x}, must be {'<' if strictly_lt_max_val else '<='} {max_val}." ) return x From 767cf7c318dcd85d58ebb26467bbce3cc7d692d4 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Mon, 30 Aug 2021 19:49:34 +0200 Subject: [PATCH 8/8] iter --- sklearn/cluster/_affinity_propagation.py | 2 +- sklearn/utils/tests/test_validation.py | 25 +++++++++--------------- sklearn/utils/validation.py | 25 +++++++++++++----------- 3 files changed, 24 insertions(+), 28 deletions(-) diff --git a/sklearn/cluster/_affinity_propagation.py b/sklearn/cluster/_affinity_propagation.py index aefbe28ba67da..d777309b1e2d5 100644 --- a/sklearn/cluster/_affinity_propagation.py +++ b/sklearn/cluster/_affinity_propagation.py @@ -463,7 +463,7 @@ def fit(self, X, y=None): target_type=numbers.Real, min_val=0.5, max_val=1, - strictly_lt_max_val=True, + closed="right", ) check_scalar(self.max_iter, "max_iter", target_type=numbers.Integral, min_val=1) check_scalar( diff --git a/sklearn/utils/tests/test_validation.py b/sklearn/utils/tests/test_validation.py index 98e9365697cf1..4352a1f56f0f6 100644 --- a/sklearn/utils/tests/test_validation.py +++ b/sklearn/utils/tests/test_validation.py @@ -1015,25 +1015,23 @@ def test_check_scalar_valid(x): "test_name", target_type=numbers.Real, min_val=2, - strictly_gt_min_val=False, max_val=5, - strictly_lt_max_val=False, + closed="neither", ) assert len(record) == 0 assert scalar == x @pytest.mark.parametrize( - "x, target_name, target_type, min_val, strictly_gt, max_val, strictly_lt, err_msg", + "x, target_name, target_type, min_val, max_val, closed, err_msg", [ ( 1, "test_name1", float, 2, - False, 4, - False, + "neither", TypeError( "test_name1 must be an instance of , not ." ), @@ -1043,9 +1041,8 @@ def test_check_scalar_valid(x): "test_name2", int, 2, - False, 4, - False, + "neither", ValueError("test_name2 == 1, must be >= 2."), ), ( @@ -1053,9 +1050,8 @@ def test_check_scalar_valid(x): "test_name3", int, 2, - False, 4, - False, + "neither", ValueError("test_name3 == 5, must be <= 4."), ), ( @@ -1063,9 +1059,8 @@ def test_check_scalar_valid(x): "test_name4", int, 2, - True, 4, - False, + "left", ValueError("test_name4 == 2, must be > 2."), ), ( @@ -1073,15 +1068,14 @@ def test_check_scalar_valid(x): "test_name5", int, 2, - False, 4, - True, + "right", ValueError("test_name5 == 4, must be < 4."), ), ], ) def test_check_scalar_invalid( - x, target_name, target_type, min_val, strictly_gt, max_val, strictly_lt, err_msg + x, target_name, target_type, min_val, max_val, closed, err_msg ): """Test that check_scalar returns the right error if a wrong input is given""" @@ -1091,9 +1085,8 @@ def test_check_scalar_invalid( target_name, target_type=target_type, min_val=min_val, - strictly_gt_min_val=strictly_gt, max_val=max_val, - strictly_lt_max_val=strictly_lt, + closed=closed, ) assert str(raised_error.value) == str(err_msg) assert type(raised_error.value) == type(err_msg) diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index 51a89dbaf4092..4ac33e0d02d22 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1238,9 +1238,8 @@ def check_scalar( target_type, *, min_val=None, - strictly_gt_min_val=False, max_val=None, - strictly_lt_max_val=False, + closed="neither", ): """Validate scalar parameters type and value. @@ -1259,15 +1258,13 @@ def check_scalar( The minimum valid value the parameter can take. If None (default) it is implied that the parameter does not have a lower bound. - strictly_gt_min_val : bool, default=True - Whether the parameter should be strictly greater to `min_val`. - max_val : float or int, default=False The maximum valid value the parameter can take. If None (default) it is implied that the parameter does not have an upper bound. - strictly_lt_max_val : bool, default=False - Whether the parameter should be strictly less to `max_val`. + closed : {"left", "right", "both", "neither"}, default="neither" + Whether the interval is closed on the left-side, right-side, both or + neither. Returns ------- @@ -1286,16 +1283,22 @@ def check_scalar( if not isinstance(x, target_type): raise TypeError(f"{name} must be an instance of {target_type}, not {type(x)}.") - comparison_operator = operator.le if strictly_gt_min_val else operator.lt + expected_closed = {"left", "right", "both", "neither"} + if closed not in expected_closed: + raise ValueError(f"Unknown value for `closed`: {closed}") + + comparison_operator = operator.le if closed in ("left", "both") else operator.lt if min_val is not None and comparison_operator(x, min_val): raise ValueError( - f"{name} == {x}, must be {'>' if strictly_gt_min_val else '>='} {min_val}." + f"{name} == {x}, must be" + f" {'>' if closed in ('left', 'both') else '>='} {min_val}." ) - comparison_operator = operator.ge if strictly_lt_max_val else operator.gt + comparison_operator = operator.ge if closed in ("right", "both") else operator.gt if max_val is not None and comparison_operator(x, max_val): raise ValueError( - f"{name} == {x}, must be {'<' if strictly_lt_max_val else '<='} {max_val}." + f"{name} == {x}, must be" + f" {'<' if closed in ('right', 'both') else '<='} {max_val}." ) return x