Thanks to visit codestin.com
Credit goes to github.com

Skip to content

ENH allows checks generator to be pluggable #18750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions doc/developers/develop.rst
Original file line number Diff line number Diff line change
Expand Up @@ -777,3 +777,28 @@ The reason for this setup is reproducibility:
when an estimator is ``fit`` twice to the same data,
it should produce an identical model both times,
hence the validation in ``fit``, not ``__init__``.

Reuse the testing infrastructure from scikit-learn
--------------------------------------------------

Scikit-learn provides two utilities:
:func:`~sklearn.utils.estimator_checks.check_estimator` and
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. These
utilities tests if a custom estimator is compatible with the different tools
available in scikit-learn. It is common for a third-party library to extend
the test suite with its own estimator checks.
Comment on lines +788 to +789
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like your usage of "common" here :D :D


Both functions accept a parameter `checks_generator` which can be used for this
purpose. This parameter is a generator that yield callables such as::


def check_estimator_has_fit(name, instance, strict_mode=True):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this will be api_only

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep but we need to merge your PR first.

assert hasattr(instance, "fit"), f"{name} does not implement fit"


def checks_generator(estimator):
yield check_estimator_has_fit

.. warning::
This feature is experimental. The signature of the `check` functions can
change without any deprecation cycle.
7 changes: 7 additions & 0 deletions doc/whats_new/v0.24.rst
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,13 @@ Changelog
:func:`utils.sparse_func.incr_mean_variance_axis`.
By :user:`Maria Telenczuk <maikia>` and :user:`Alex Gramfort <agramfort>`.

- |Enhancement| Allows to pass a custom checks generator using the parameter
`checks_generator` in
:func:`~sklearn.utils.estimator_checks.check_estimator` and
:func:`~sklearn.utils.estimator_checks.parametrize_with_checks`. Note that
this feature is experimental in such way that the expected signature of the
`check` functions can change without deprecation cycle.
:pr:`18750` by :user:`Guillaume Lemaitre <glemaitre>`.

Miscellaneous
.............
Expand Down
43 changes: 33 additions & 10 deletions sklearn/utils/estimator_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,9 @@ def _should_be_skipped_or_marked(estimator, check, strict_mode):
return False, 'placeholder reason that will never be used'


def parametrize_with_checks(estimators, strict_mode=True):
def parametrize_with_checks(
estimators, strict_mode=True, checks_generator=None,
):
"""Pytest specific decorator for parametrizing estimator checks.

The `id` of each check is set to be a pprint version of the estimator
Expand Down Expand Up @@ -464,6 +466,12 @@ def parametrize_with_checks(estimators, strict_mode=True):

.. versionadded:: 0.24

checks_generator : callable, default=None
The generator yielding checks for the estimators. By default, the
common checks from scikit-learn will be yielded.

.. versionadded:: 0.24
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also note experimental


Returns
-------
decorator : `pytest.mark.parametrize`
Expand All @@ -488,18 +496,24 @@ def parametrize_with_checks(estimators, strict_mode=True):
"Please pass an instance instead.")
raise TypeError(msg)

def checks_generator():
if checks_generator is None:
checks_generator = _yield_all_checks

def _checks_generator():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit but we don't need the newly-added leading underscore here since there's no notion of private/public (it might simplify the diff also)

for estimator in estimators:
name = type(estimator).__name__
for check in _yield_all_checks(estimator):
for check in checks_generator(estimator):
check = partial(check, name, strict_mode=strict_mode)
yield _maybe_mark_xfail(estimator, check, strict_mode, pytest)

return pytest.mark.parametrize("estimator, check", checks_generator(),
ids=_get_check_estimator_ids)
return pytest.mark.parametrize(
"estimator, check", _checks_generator(), ids=_get_check_estimator_ids
)


def check_estimator(Estimator, generate_only=False, strict_mode=True):
def check_estimator(
Estimator, generate_only=False, strict_mode=True, checks_generator=None,
):
"""Check if estimator adheres to scikit-learn conventions.

This estimator will run an extensive test-suite for input validation,
Expand Down Expand Up @@ -550,6 +564,12 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True):

.. versionadded:: 0.24

checks_generator : callable, default=None
The generator yielding checks for the estimators. By default, the
common checks from scikit-learn will be yielded.

.. versionadded:: 0.24

Returns
-------
checks_generator : generator
Expand All @@ -565,15 +585,18 @@ def check_estimator(Estimator, generate_only=False, strict_mode=True):
estimator = Estimator
name = type(estimator).__name__

def checks_generator():
for check in _yield_all_checks(estimator):
if checks_generator is None:
checks_generator = _yield_all_checks

def _checks_generator():
for check in checks_generator(estimator):
check = _maybe_skip(estimator, check, strict_mode)
yield estimator, partial(check, name, strict_mode=strict_mode)

if generate_only:
return checks_generator()
return _checks_generator()

for estimator, check in checks_generator():
for estimator, check in _checks_generator():
try:
check(estimator)
except SkipTest as exception:
Expand Down
55 changes: 51 additions & 4 deletions sklearn/utils/tests/test_estimator_checks.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
import unittest
import sys
import warnings

import numpy as np
import scipy.sparse as sp
import joblib

from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.utils import deprecated
from sklearn.utils._testing import (assert_raises_regex,
ignore_warnings,
assert_warns, assert_raises,
SkipTest)
from sklearn.utils._testing import (
assert_raises_regex,
ignore_warnings,
assert_warns,
assert_warns_message,
assert_raises,
SkipTest,
)
from sklearn.utils.estimator_checks import check_estimator, _NotAnArray
from sklearn.utils.estimator_checks \
import check_class_weight_balanced_linear_classifier
Expand All @@ -21,6 +26,7 @@
from sklearn.utils.estimator_checks import check_no_attributes_set_in_init
from sklearn.utils.estimator_checks import check_classifier_data_not_an_array
from sklearn.utils.estimator_checks import check_regressor_data_not_an_array
from sklearn.utils.estimator_checks import parametrize_with_checks
from sklearn.utils.validation import check_is_fitted
from sklearn.utils.estimator_checks import check_outlier_corruption
from sklearn.utils.fixes import np_version, parse_version
Expand Down Expand Up @@ -677,3 +683,44 @@ def test_xfail_ignored_in_check_estimator():
# Make sure checks marked as xfail are just ignored and not run by
# check_estimator(), but still raise a warning.
assert_warns(SkipTestWarning, check_estimator, NuSVC())


def my_own_check(name, instance, strict_mode=True):
warnings.warn("my_own_check was executed", UserWarning)


def my_own_generator(estimator):
yield my_own_check


def test_check_estimator_checks_generator():
# Check that we can pass a custom checks generator in `check_estimator`
assert_warns_message(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with pytest.warns(...):?

UserWarning,
"my_own_check was executed",
check_estimator,
BaseEstimator(),
checks_generator=my_own_generator,
)


def test_parametrize_with_checks_checks_generator():
# Check that we can pass a custom checks generator in
# `parametrize_with_checks`
decorator = parametrize_with_checks(
[BaseEstimator()], checks_generator=my_own_generator
)

def test_estimator(estimator, check):
check(estimator)

test_estimator = decorator(test_estimator)
for _mark in test_estimator.pytestmark:
for estimator, check in _mark.args[1]:
assert_warns_message(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with pytest.warns?

UserWarning,
"my_own_check was executed",
test_estimator,
estimator,
check,
)