@@ -0,0 +1,4 @@
- ``sklearn.utils.estimator_checks.parametrize_with_checks`` now lets you configure
  strict mode for xfailing checks. In strict mode, checks that are expected to fail
  but unexpectedly pass lead to a test failure. The default behaviour is unchanged.
  By :user:`Tim Head <betatim>`.
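
For illustration, a minimal usage sketch of the new option (assuming a scikit-learn version that includes this change and pytest installed; the estimators and the empty per-estimator failure mapping are placeholders):

from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeRegressor
from sklearn.utils.estimator_checks import parametrize_with_checks


# expected_failed_checks maps an estimator to {check_name: reason}; with
# xfail_strict=True any of those checks that unexpectedly passes is reported
# as a test failure instead of XPASS.
@parametrize_with_checks(
    [LogisticRegression(), DecisionTreeRegressor()],
    expected_failed_checks=lambda estimator: {},
    xfail_strict=True,
)
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)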
43 changes: 40 additions & 3 deletions sklearn/utils/estimator_checks.py
@@ -424,6 +424,7 @@ def _maybe_mark(
expected_failed_checks: dict[str, str] | None = None,
mark: Literal["xfail", "skip", None] = None,
pytest=None,
xfail_strict: bool | None = None,
):
"""Mark the test as xfail or skip if needed.

@@ -442,6 +443,13 @@ def _maybe_mark(
Pytest module to use to mark the check. This is only needed if ``mark`` is
`"xfail"`. Note that one can run `check_estimator` without having `pytest`
installed. This is used in combination with `parametrize_with_checks` only.
xfail_strict : bool, default=None
Whether to run checks in xfail strict mode. This option is ignored unless
`mark="xfail"`. If True, checks that are expected to fail but actually
pass will lead to a test failure. If False, unexpectedly passing tests
will be marked as xpass. If None, the default pytest behavior is used.

.. versionadded:: 1.8
"""
should_be_marked, reason = _should_be_skipped_or_marked(
estimator, check, expected_failed_checks
@@ -451,7 +459,14 @@

estimator_name = estimator.__class__.__name__
if mark == "xfail":
return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason))
# With xfail_strict=None we want the value from the pytest config to
# take precedence and that means not passing strict to the xfail
# mark at all.
if xfail_strict is None:
mark = pytest.mark.xfail(reason=reason)
else:
mark = pytest.mark.xfail(reason=reason, strict=xfail_strict)
return pytest.param(estimator, check, marks=mark)
else:

@wraps(check)
@@ -501,6 +516,7 @@ def estimator_checks_generator(
legacy: bool = True,
expected_failed_checks: dict[str, str] | None = None,
mark: Literal["xfail", "skip", None] = None,
xfail_strict: bool | None = None,
):
"""Iteratively yield all check callables for an estimator.

@@ -528,6 +544,13 @@ def estimator_checks_generator(
xfail(`pytest.mark.xfail`) or skip. Marking a test as "skip" is done via
wrapping the check in a function that raises a
:class:`~sklearn.exceptions.SkipTest` exception.
xfail_strict : bool, default=None
Whether to run checks in xfail strict mode. This option is ignored unless
`mark="xfail"`. If True, checks that are expected to fail but actually
pass will lead to a test failure. If False, unexpectedly passing tests
will be marked as xpass. If None, the default pytest behavior is used.

.. versionadded:: 1.8

Returns
-------
@@ -552,6 +575,7 @@ def estimator_checks_generator(
expected_failed_checks=expected_failed_checks,
mark=mark,
pytest=pytest,
xfail_strict=xfail_strict,
)


@@ -560,6 +584,7 @@ def parametrize_with_checks(
*,
legacy: bool = True,
expected_failed_checks: Callable | None = None,
xfail_strict: bool | None = None,
):
"""Pytest specific decorator for parametrizing estimator checks.

@@ -605,9 +630,16 @@ def parametrize_with_checks(
Where `"check_name"` is the name of the check, and `"my reason"` is why
the check fails. These tests will be marked as xfail if the check fails.


.. versionadded:: 1.6

xfail_strict : bool, default=None
Whether to run checks in xfail strict mode. If True, checks that are
expected to fail but actually pass will lead to a test failure. If
False, unexpectedly passing tests will be marked as xpass. If None,
the default pytest behavior is used.

.. versionadded:: 1.8

Returns
-------
decorator : `pytest.mark.parametrize`
@@ -640,7 +672,12 @@ def parametrize_with_checks(

def _checks_generator(estimators, legacy, expected_failed_checks):
for estimator in estimators:
args = {"estimator": estimator, "legacy": legacy, "mark": "xfail"}
args = {
"estimator": estimator,
"legacy": legacy,
"mark": "xfail",
"xfail_strict": xfail_strict,
}
if callable(expected_failed_checks):
args["expected_failed_checks"] = expected_failed_checks(estimator)
yield from estimator_checks_generator(**args)
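
To illustrate the pytest semantics that ``xfail_strict`` forwards, a standalone sketch (a hypothetical test module, not part of this diff):

import pytest


@pytest.mark.xfail(reason="known failure", strict=True)
def test_strict_xfail_that_passes():
    # Unexpectedly passes: with strict=True pytest reports this as FAILED.
    assert True


@pytest.mark.xfail(reason="known failure", strict=False)
def test_non_strict_xfail_that_passes():
    # Unexpectedly passes: with strict=False pytest reports XPASS and the
    # run stays green.
    assert True

When ``strict`` is omitted from the mark, as ``_maybe_mark`` does for ``xfail_strict=None``, pytest falls back to the ``xfail_strict`` setting in its configuration file.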
55 changes: 55 additions & 0 deletions sklearn/utils/tests/test_estimator_checks.py
@@ -1324,6 +1324,61 @@ def test_all_estimators_all_public():
run_tests_without_pytest()


def test_estimator_checks_generator_strict_none():
# Check that the generated xfail marks do not include the "strict" keyword
est = next(_construct_instances(NuSVC))
expected_to_fail = _get_expected_failed_checks(est)
# If we don't pass strict, it should not appear in the xfail mark either
# This way the behaviour configured in pytest.ini takes precedence.
checks = estimator_checks_generator(
est,
legacy=True,
expected_failed_checks=expected_to_fail,
mark="xfail",
)
# make sure we use a class that has expected failures
assert len(expected_to_fail) > 0
marked_checks = [c for c in checks if hasattr(c, "marks")]
# make sure we have some checks with marks
assert len(marked_checks) > 0

for parameter_set in marked_checks:
first_mark = parameter_set.marks[0]
assert "strict" not in first_mark.kwargs


def test_estimator_checks_generator_strict_xfail_tests():
# Make sure that the checks generator marks tests that are expected to fail
# as strict xfail
est = next(_construct_instances(NuSVC))
expected_to_fail = _get_expected_failed_checks(est)
checks = estimator_checks_generator(
est,
legacy=True,
expected_failed_checks=expected_to_fail,
mark="xfail",
xfail_strict=True,
)
# make sure we use a class that has expected failures
assert len(expected_to_fail) > 0
strict_xfailed_checks = []

# xfail'ed checks are wrapped in a pytest ParameterSet, so extract the
# entries that actually carry marks
marked_checks = [c for c in checks if hasattr(c, "marks")]
# make sure some checks were marked as expected failures
assert len(marked_checks) > 0

for parameter_set in marked_checks:
_, check = parameter_set.values
first_mark = parameter_set.marks[0]
if first_mark.kwargs["strict"]:
strict_xfailed_checks.append(_check_name(check))

# all checks expected to fail are marked as strict xfail
assert set(expected_to_fail.keys()) == set(strict_xfailed_checks)


@_mark_thread_unsafe_if_pytest_imported # Some checks use warnings.
def test_estimator_checks_generator_skipping_tests():
# Make sure the checks generator skips tests that are expected to fail
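
Outside of ``parametrize_with_checks``, the generated parameter sets can also be inspected directly, much like the tests above do; a minimal sketch (the estimator, the check name and the reason are placeholders, assuming a scikit-learn version with this change and pytest installed):

from sklearn.linear_model import LogisticRegression
from sklearn.utils.estimator_checks import estimator_checks_generator

checks = estimator_checks_generator(
    LogisticRegression(),
    legacy=True,
    expected_failed_checks={"some_check_name": "placeholder reason"},
    mark="xfail",
    xfail_strict=True,
)
for item in checks:
    # Checks listed in expected_failed_checks come back wrapped in
    # pytest.param with an xfail mark; with xfail_strict=True that mark
    # carries strict=True.
    if hasattr(item, "marks"):
        print(item.marks[0].kwargs)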