From b50f64eb5e47138b9f01d3878d2bc3e32b886662 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 1 Dec 2022 15:50:13 +0100 Subject: [PATCH 1/3] do not force 1 to 1 matching --- sklearn/utils/_param_validation.py | 7 ------- sklearn/utils/tests/test_param_validation.py | 14 -------------- 2 files changed, 21 deletions(-) diff --git a/sklearn/utils/_param_validation.py b/sklearn/utils/_param_validation.py index 797063a31dd96..5b6a0b2d78211 100644 --- a/sklearn/utils/_param_validation.py +++ b/sklearn/utils/_param_validation.py @@ -50,13 +50,6 @@ def validate_parameter_constraints(parameter_constraints, params, caller_name): caller_name : str The name of the estimator or function or method that called this function. """ - if len(set(parameter_constraints) - set(params)) != 0: - raise ValueError( - f"The parameter constraints {list(parameter_constraints)}" - " contain unexpected parameters" - f" {set(parameter_constraints) - set(params)}" - ) - for param_name, param_val in params.items(): # We allow parameters to not have a constraint so that third party estimators # can inherit from sklearn estimators without having to necessarily use the diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index fd73797582631..e4c3aa208392d 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -453,20 +453,6 @@ def test_validate_params(): _func(0, *[1, 2, 3], c="four", **{"e": 5}) -def test_validate_params_match_error(): - """Check that an informative error is raised when there are constraints - that have no matching function paramaters - """ - - @validate_params({"a": [int], "c": [int]}) - def func(a, b): - pass - - match = r"The parameter constraints .* contain unexpected parameters {'c'}" - with pytest.raises(ValueError, match=match): - func(1, 2) - - def test_validate_params_missing_params(): """Check that no error is raised when there are parameters without constraints From b5800501b729d298bbd0cc2e1ba0aa731c6c27de Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Thu, 1 Dec 2022 16:44:33 +0100 Subject: [PATCH 2/3] add test --- sklearn/utils/tests/test_param_validation.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index e4c3aa208392d..93d48f5d4c6d7 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -619,3 +619,22 @@ def test_cv_objects(): assert constraint.is_satisfied_by([([1, 2], [3, 4]), ([3, 4], [1, 2])]) assert constraint.is_satisfied_by(None) assert not constraint.is_satisfied_by("not a CV object") + + +def test_third_party_estimator(): + """Check that the validation from a scikit-learn estimator inherited by a third + party estimator does not impose a match between the dict of constraints and the + parameters of the estimator. + """ + + class ThirdPartyEstimator(_Estimator): + def __init__(self, b): + self.b = b + super().__init__(a=0) + + def fit(self, X=None, y=None): + super().fit(X, y) + + # does not raise, niether because "b" is not in the constraints dict, neither + # because "a" is not a parameter of the estimator. + ThirdPartyEstimator(b=0).fit() From 37480236066da452faf7ab71a3790797071be737 Mon Sep 17 00:00:00 2001 From: jeremie du boisberranger Date: Fri, 2 Dec 2022 11:33:10 +0100 Subject: [PATCH 3/3] reword --- sklearn/utils/tests/test_param_validation.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/tests/test_param_validation.py b/sklearn/utils/tests/test_param_validation.py index 93d48f5d4c6d7..074e08729764f 100644 --- a/sklearn/utils/tests/test_param_validation.py +++ b/sklearn/utils/tests/test_param_validation.py @@ -635,6 +635,6 @@ def __init__(self, b): def fit(self, X=None, y=None): super().fit(X, y) - # does not raise, niether because "b" is not in the constraints dict, neither - # because "a" is not a parameter of the estimator. + # does not raise, even though "b" is not in the constraints dict and "a" is not + # a parameter of the estimator. ThirdPartyEstimator(b=0).fit()