From 1fb9940b064443cf8a4dafcc5af26151eb3b112e Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Nov 2020 11:39:05 +0100 Subject: [PATCH 1/5] MNT allows checks generator to be pluggable --- doc/developers/develop.rst | 25 ++++++++++++++++++ sklearn/utils/estimator_checks.py | 43 ++++++++++++++++++++++++------- 2 files changed, 58 insertions(+), 10 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index b7b5d2ac0316f..23956e36b16df 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -777,3 +777,28 @@ The reason for this setup is reproducibility: when an estimator is ``fit`` twice to the same data, it should produce an identical model both times, hence the validation in ``fit``, not ``__init__``. + +Reuse the testing infrastructure from scikit-learn +-------------------------------------------------- + +Scikit-learn provides two utilities: +:func:`~sklearn.utils.estimator_checks.check_estimator` and +:func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These +utilities test whether a custom estimator is compatible with the different tools +available in scikit-learn. It is common for a third-party library to extend +the test suite with its own estimator checks. + +Both functions accept a parameter `checks_generator` which can be used for this +purpose. This parameter is a generator that yields callables such as:: + + + def check_estimator_has_fit(name, instance, strict_mode=True): + assert hasattr(instance, "fit"), f"{name} does not implement fit" + + + def checks_generator(estimator): + yield check_estimator_has_fit + +.. warning:: + The API of the checks function is experimental and the expected signature + can change without notice. 
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 3cd19967ba9c1..0ba8f40650d98 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -431,7 +431,9 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): return False, 'placeholder reason that will never be used' -def parametrize_with_checks(estimators, strict_mode=True): +def parametrize_with_checks( + estimators, strict_mode=True, checks_generator=None, +): """Pytest specific decorator for parametrizing estimator checks. The `id` of each check is set to be a pprint version of the estimator @@ -464,6 +466,12 @@ def parametrize_with_checks(estimators, strict_mode=True): .. versionadded:: 0.24 + checks_generator : callable, default=None + The generator yielding checks for the estimators. By default, the + common checks from scikit-learn will be yielded. + + .. versionadded::0.24 + Returns ------- decorator : `pytest.mark.parametrize` @@ -488,18 +496,24 @@ def parametrize_with_checks(estimators, strict_mode=True): "Please pass an instance instead.") raise TypeError(msg) - def checks_generator(): + if checks_generator is None: + checks_generator = _yield_all_checks + + def _checks_generator(): for estimator in estimators: name = type(estimator).__name__ - for check in _yield_all_checks(estimator): + for check in checks_generator(estimator): check = partial(check, name, strict_mode=strict_mode) yield _maybe_mark_xfail(estimator, check, strict_mode, pytest) - return pytest.mark.parametrize("estimator, check", checks_generator(), - ids=_get_check_estimator_ids) + return pytest.mark.parametrize( + "estimator, check", _checks_generator(), ids=_get_check_estimator_ids + ) -def check_estimator(Estimator, generate_only=False, strict_mode=True): +def check_estimator( + Estimator, generate_only=False, strict_mode=True, checks_generator=None, +): """Check if estimator adheres to scikit-learn conventions. 
This estimator will run an extensive test-suite for input validation, @@ -550,6 +564,12 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): .. versionadded:: 0.24 + checks_generator : callable, default=None + The generator yielding checks for the estimators. By default, the + common checks from scikit-learn will be yielded. + + .. versionadded::0.24 + Returns ------- checks_generator : generator @@ -565,15 +585,18 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): estimator = Estimator name = type(estimator).__name__ - def checks_generator(): - for check in _yield_all_checks(estimator): + if checks_generator is None: + checks_generator = _yield_all_checks + + def _checks_generator(): + for check in checks_generator(estimator): check = _maybe_skip(estimator, check, strict_mode) yield estimator, partial(check, name, strict_mode=strict_mode) if generate_only: - return checks_generator() + return _checks_generator() - for estimator, check in checks_generator(): + for estimator, check in _checks_generator(): try: check(estimator) except SkipTest as exception: From f7f8761af88c2a3208f47a7fe1182f8e05b3b374 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Nov 2020 11:47:38 +0100 Subject: [PATCH 2/5] iter --- doc/developers/develop.rst | 4 ++-- doc/whats_new/v0.24.rst | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/doc/developers/develop.rst b/doc/developers/develop.rst index 23956e36b16df..ec7251ddfc445 100644 --- a/doc/developers/develop.rst +++ b/doc/developers/develop.rst @@ -800,5 +800,5 @@ purpose. This parameter is a generator that yield callables such as:: yield check_estimator_has_fit .. warning:: - The API of the checks function is experimental and the expected signature - can change without notice. + This feature is experimental. The signature of the `check` functions can + change without any deprecation cycle. 
diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index ee027d69f1e9b..bc365876f2121 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -770,6 +770,13 @@ Changelog :func:`utils.sparse_func.incr_mean_variance_axis`. By :user:`Maria Telenczuk ` and :user:`Alex Gramfort `. +- |Enhancement| Allows passing a custom checks generator using the parameter + `checks_generator` in + :func:`~sklearn.utils.estimator_checks.check_estimator` and + :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. Note that + this feature is experimental in such way that the expected signature of the + `check` functions can change without deprecation cycle. + :pr:`18750` by :user:`Guillaume Lemaitre`_. Miscellaneous ............. From 84b1e481c6042435aef55eb9502a32b0471584d1 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Nov 2020 13:42:45 +0100 Subject: [PATCH 3/5] TST add tests --- sklearn/utils/tests/test_estimator_checks.py | 56 ++++++++++++++++++-- 1 file changed, 52 insertions(+), 4 deletions(-) diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index ecbf7cb7be7f4..63a079cdc2ca0 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -1,5 +1,6 @@ import unittest import sys +import warnings import numpy as np import scipy.sparse as sp @@ -7,10 +8,14 @@ from sklearn.base import BaseEstimator, ClassifierMixin from sklearn.utils import deprecated -from sklearn.utils._testing import (assert_raises_regex, - ignore_warnings, - assert_warns, assert_raises, - SkipTest) +from sklearn.utils._testing import ( + assert_raises_regex, + ignore_warnings, + assert_warns, + assert_warns_message, + assert_raises, + SkipTest, +) from sklearn.utils.estimator_checks import check_estimator, _NotAnArray from sklearn.utils.estimator_checks \ import check_class_weight_balanced_linear_classifier @@ -21,6 +26,8 @@ from sklearn.utils.estimator_checks 
import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array +from sklearn.utils.estimator_checks import check_estimator +from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption from sklearn.utils.fixes import np_version, parse_version @@ -677,3 +684,44 @@ def test_xfail_ignored_in_check_estimator(): # Make sure checks marked as xfail are just ignored and not run by # check_estimator(), but still raise a warning. assert_warns(SkipTestWarning, check_estimator, NuSVC()) + + +def my_own_check(name, instance, strict_mode=True): + warnings.warn("my_own_check was executed", UserWarning) + + +def my_own_generator(estimator): + yield my_own_check + + +def test_check_estimator_checks_generator(): + # Check that we can pass a custom checks generator in `check_estimator` + assert_warns_message( + UserWarning, + "my_own_check was executed", + check_estimator, + BaseEstimator(), + checks_generator=my_own_generator, + ) + + +def test_parametrize_with_checks_checks_generator(): + # Check that we can pass a custom checks generator in + # `parametrize_with_checks` + decorator = parametrize_with_checks( + [BaseEstimator()], checks_generator=my_own_generator + ) + + def test_estimator(estimator, check): + check(estimator) + + test_estimator = decorator(test_estimator) + for _mark in test_estimator.pytestmark: + for estimator, check in _mark.args[1]: + assert_warns_message( + UserWarning, + "my_own_check was executed", + test_estimator, + estimator, + check, + ) From c2970a4ec4c262d66ef6dd5f322a01117a57faa8 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Nov 2020 14:06:24 +0100 Subject: [PATCH 4/5] iter --- doc/whats_new/v0.24.rst | 2 +- sklearn/utils/tests/test_estimator_checks.py | 1 - 2 files changed, 1 
insertion(+), 2 deletions(-) diff --git a/doc/whats_new/v0.24.rst b/doc/whats_new/v0.24.rst index bc365876f2121..e74a744c8f647 100644 --- a/doc/whats_new/v0.24.rst +++ b/doc/whats_new/v0.24.rst @@ -776,7 +776,7 @@ Changelog :func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. Note that this feature is experimental in such way that the expected signature of the `check` functions can change without deprecation cycle. - :pr:`18750` by :user:`Guillaume Lemaitre`_. + :pr:`18750` by :user:`Guillaume Lemaitre `. Miscellaneous ............. diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index 63a079cdc2ca0..3ba80879c13fa 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -26,7 +26,6 @@ from sklearn.utils.estimator_checks import check_no_attributes_set_in_init from sklearn.utils.estimator_checks import check_classifier_data_not_an_array from sklearn.utils.estimator_checks import check_regressor_data_not_an_array -from sklearn.utils.estimator_checks import check_estimator from sklearn.utils.estimator_checks import parametrize_with_checks from sklearn.utils.validation import check_is_fitted from sklearn.utils.estimator_checks import check_outlier_corruption From 85378362301379ed17a0cba96cfb3eaa10d05155 Mon Sep 17 00:00:00 2001 From: Guillaume Lemaitre Date: Tue, 3 Nov 2020 14:10:12 +0100 Subject: [PATCH 5/5] iter --- sklearn/utils/estimator_checks.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 0ba8f40650d98..a2b63153ec357 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -470,7 +470,7 @@ def parametrize_with_checks( The generator yielding checks for the estimators. By default, the common checks from scikit-learn will be yielded. - .. versionadded::0.24 + .. 
versionadded:: 0.24 Returns ------- @@ -568,7 +568,7 @@ def check_estimator( The generator yielding checks for the estimators. By default, the common checks from scikit-learn will be yielded. - .. versionadded::0.24 + .. versionadded:: 0.24 Returns -------