-
-
Notifications
You must be signed in to change notification settings - Fork 26k
ENH allows checks generator to be pluggable #18750
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -777,3 +777,28 @@ The reason for this setup is reproducibility: | |
when an estimator is ``fit`` twice to the same data, | ||
it should produce an identical model both times, | ||
hence the validation in ``fit``, not ``__init__``. | ||
|
||
Reuse the testing infrastructure from scikit-learn | ||
-------------------------------------------------- | ||
|
||
Scikit-learn provides two utilities: | ||
:func:`~sklearn.utils.estimator_checks.check_estimator` and | ||
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These | ||
utilities test whether a custom estimator is compatible with the different tools | ||
available in scikit-learn. It is common for a third-party library to extend | ||
the test suite with its own estimator checks. | ||
|
||
Both functions accept a `checks_generator` parameter which can be used for this | ||
purpose. This parameter is a generator function that takes an estimator and | ||
yields check callables such as:: | ||
|
||
|
||
def check_estimator_has_fit(name, instance, strict_mode=True): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this will be api_only There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yep but we need to merge your PR first. |
||
assert hasattr(instance, "fit"), f"{name} does not implement fit" | ||
|
||
|
||
def checks_generator(estimator): | ||
yield check_estimator_has_fit | ||
|
||
.. warning:: | ||
This feature is experimental. The signature of the `check` functions can | ||
change without any deprecation cycle. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -431,7 +431,9 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode): | |
return False, 'placeholder reason that will never be used' | ||
|
||
|
||
def parametrize_with_checks(estimators, strict_mode=True): | ||
def parametrize_with_checks( | ||
estimators, strict_mode=True, checks_generator=None, | ||
): | ||
"""Pytest specific decorator for parametrizing estimator checks. | ||
|
||
The `id` of each check is set to be a pprint version of the estimator | ||
|
@@ -464,6 +466,12 @@ def parametrize_with_checks(estimators, strict_mode=True): | |
|
||
.. versionadded:: 0.24 | ||
|
||
checks_generator : callable, default=None | ||
The generator yielding checks for the estimators. By default, the | ||
common checks from scikit-learn will be yielded. | ||
|
||
.. versionadded:: 0.24 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. also note experimental |
||
|
||
Returns | ||
------- | ||
decorator : `pytest.mark.parametrize` | ||
|
@@ -488,18 +496,24 @@ def parametrize_with_checks(estimators, strict_mode=True): | |
"Please pass an instance instead.") | ||
raise TypeError(msg) | ||
|
||
def checks_generator(): | ||
if checks_generator is None: | ||
checks_generator = _yield_all_checks | ||
|
||
def _checks_generator(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit but we don't need the newly-added leading underscore here since there's no notion of private/public (it might simplify the diff also) |
||
for estimator in estimators: | ||
name = type(estimator).__name__ | ||
for check in _yield_all_checks(estimator): | ||
for check in checks_generator(estimator): | ||
check = partial(check, name, strict_mode=strict_mode) | ||
yield _maybe_mark_xfail(estimator, check, strict_mode, pytest) | ||
|
||
return pytest.mark.parametrize("estimator, check", checks_generator(), | ||
ids=_get_check_estimator_ids) | ||
return pytest.mark.parametrize( | ||
"estimator, check", _checks_generator(), ids=_get_check_estimator_ids | ||
) | ||
|
||
|
||
def check_estimator(Estimator, generate_only=False, strict_mode=True): | ||
def check_estimator( | ||
Estimator, generate_only=False, strict_mode=True, checks_generator=None, | ||
): | ||
"""Check if estimator adheres to scikit-learn conventions. | ||
|
||
This estimator will run an extensive test-suite for input validation, | ||
|
@@ -550,6 +564,12 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): | |
|
||
.. versionadded:: 0.24 | ||
|
||
checks_generator : callable, default=None | ||
The generator yielding checks for the estimators. By default, the | ||
common checks from scikit-learn will be yielded. | ||
|
||
.. versionadded:: 0.24 | ||
|
||
Returns | ||
------- | ||
checks_generator : generator | ||
|
@@ -565,15 +585,18 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True): | |
estimator = Estimator | ||
name = type(estimator).__name__ | ||
|
||
def checks_generator(): | ||
for check in _yield_all_checks(estimator): | ||
if checks_generator is None: | ||
checks_generator = _yield_all_checks | ||
|
||
def _checks_generator(): | ||
for check in checks_generator(estimator): | ||
check = _maybe_skip(estimator, check, strict_mode) | ||
yield estimator, partial(check, name, strict_mode=strict_mode) | ||
|
||
if generate_only: | ||
return checks_generator() | ||
return _checks_generator() | ||
|
||
for estimator, check in checks_generator(): | ||
for estimator, check in _checks_generator(): | ||
try: | ||
check(estimator) | ||
except SkipTest as exception: | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,21 @@ | ||
import unittest | ||
import sys | ||
import warnings | ||
|
||
import numpy as np | ||
import scipy.sparse as sp | ||
import joblib | ||
|
||
from sklearn.base import BaseEstimator, ClassifierMixin | ||
from sklearn.utils import deprecated | ||
from sklearn.utils._testing import (assert_raises_regex, | ||
ignore_warnings, | ||
assert_warns, assert_raises, | ||
SkipTest) | ||
from sklearn.utils._testing import ( | ||
assert_raises_regex, | ||
ignore_warnings, | ||
assert_warns, | ||
assert_warns_message, | ||
assert_raises, | ||
SkipTest, | ||
) | ||
from sklearn.utils.estimator_checks import check_estimator, _NotAnArray | ||
from sklearn.utils.estimator_checks \ | ||
import check_class_weight_balanced_linear_classifier | ||
|
@@ -21,6 +26,7 @@ | |
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init | ||
from sklearn.utils.estimator_checks import check_classifier_data_not_an_array | ||
from sklearn.utils.estimator_checks import check_regressor_data_not_an_array | ||
from sklearn.utils.estimator_checks import parametrize_with_checks | ||
from sklearn.utils.validation import check_is_fitted | ||
from sklearn.utils.estimator_checks import check_outlier_corruption | ||
from sklearn.utils.fixes import np_version, parse_version | ||
|
@@ -677,3 +683,44 @@ def test_xfail_ignored_in_check_estimator(): | |
# Make sure checks marked as xfail are just ignored and not run by | ||
# check_estimator(), but still raise a warning. | ||
assert_warns(SkipTestWarning, check_estimator, NuSVC()) | ||
|
||
|
||
def my_own_check(name, instance, strict_mode=True): | ||
warnings.warn("my_own_check was executed", UserWarning) | ||
|
||
|
||
def my_own_generator(estimator): | ||
yield my_own_check | ||
|
||
|
||
def test_check_estimator_checks_generator(): | ||
# Check that we can pass a custom checks generator in `check_estimator` | ||
assert_warns_message( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
UserWarning, | ||
"my_own_check was executed", | ||
check_estimator, | ||
BaseEstimator(), | ||
checks_generator=my_own_generator, | ||
) | ||
|
||
|
||
def test_parametrize_with_checks_checks_generator(): | ||
# Check that we can pass a custom checks generator in | ||
# `parametrize_with_checks` | ||
decorator = parametrize_with_checks( | ||
[BaseEstimator()], checks_generator=my_own_generator | ||
) | ||
|
||
def test_estimator(estimator, check): | ||
check(estimator) | ||
|
||
test_estimator = decorator(test_estimator) | ||
for _mark in test_estimator.pytestmark: | ||
for estimator, check in _mark.args[1]: | ||
assert_warns_message( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
UserWarning, | ||
"my_own_check was executed", | ||
test_estimator, | ||
estimator, | ||
check, | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like your usage of "common" here :D :D