diff --git a/doc/whats_new/upcoming_changes/sklearn.utils/31951.enhancement.rst b/doc/whats_new/upcoming_changes/sklearn.utils/31951.enhancement.rst
new file mode 100644
index 0000000000000..78df7fff40743
--- /dev/null
+++ b/doc/whats_new/upcoming_changes/sklearn.utils/31951.enhancement.rst
@@ -0,0 +1,4 @@
+- ``sklearn.utils.estimator_checks.parametrize_with_checks`` now lets you configure
+  strict mode for xfailing checks. Tests that unexpectedly pass will lead to a test
+  failure. The default behaviour is unchanged.
+  By :user:`Tim Head <betatim>`.
diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py
index 0841f9dd01d4d..d8cd13848a09d 100644
--- a/sklearn/utils/estimator_checks.py
+++ b/sklearn/utils/estimator_checks.py
@@ -424,6 +424,7 @@ def _maybe_mark(
     expected_failed_checks: dict[str, str] | None = None,
     mark: Literal["xfail", "skip", None] = None,
     pytest=None,
+    xfail_strict: bool | None = None,
 ):
     """Mark the test as xfail or skip if needed.
 
@@ -442,6 +443,13 @@
         Pytest module to use to mark the check. This is only needed if ``mark``
         is `"xfail"`. Note that one can run `check_estimator` without having
         `pytest` installed. This is used in combination with
         `parametrize_with_checks` only.
+    xfail_strict : bool, default=None
+        Whether to run checks in xfail strict mode. This option is ignored unless
+        `mark="xfail"`. If True, checks that are expected to fail but actually
+        pass will lead to a test failure. If False, unexpectedly passing tests
+        will be reported as xpass. If None, the default pytest behavior is used.
+
+        .. versionadded:: 1.8
     """
     should_be_marked, reason = _should_be_skipped_or_marked(
         estimator, check, expected_failed_checks
@@ -451,7 +459,14 @@
     estimator_name = estimator.__class__.__name__
     if mark == "xfail":
-        return pytest.param(estimator, check, marks=pytest.mark.xfail(reason=reason))
+        # With xfail_strict=None we want the value from the pytest config to
+        # take precedence, and that means not passing strict to the xfail
+        # mark at all.
+        if xfail_strict is None:
+            mark = pytest.mark.xfail(reason=reason)
+        else:
+            mark = pytest.mark.xfail(reason=reason, strict=xfail_strict)
+        return pytest.param(estimator, check, marks=mark)
     else:
 
         @wraps(check)
@@ -501,6 +516,7 @@ def estimator_checks_generator(
     legacy: bool = True,
     expected_failed_checks: dict[str, str] | None = None,
     mark: Literal["xfail", "skip", None] = None,
+    xfail_strict: bool | None = None,
 ):
     """Iteratively yield all check callables for an estimator.
 
@@ -528,6 +544,13 @@
         xfail(`pytest.mark.xfail`) or skip. Marking a test as "skip" is done via
         wrapping the check in a function that raises a
         :class:`~sklearn.exceptions.SkipTest` exception.
+    xfail_strict : bool, default=None
+        Whether to run checks in xfail strict mode. This option is ignored unless
+        `mark="xfail"`. If True, checks that are expected to fail but actually
+        pass will lead to a test failure. If False, unexpectedly passing tests
+        will be reported as xpass. If None, the default pytest behavior is used.
+
+        .. versionadded:: 1.8
 
     Returns
     -------
@@ -552,6 +575,7 @@
         expected_failed_checks=expected_failed_checks,
         mark=mark,
         pytest=pytest,
+        xfail_strict=xfail_strict,
     )
 
 
@@ -560,6 +584,7 @@ def parametrize_with_checks(
     *,
     legacy: bool = True,
     expected_failed_checks: Callable | None = None,
+    xfail_strict: bool | None = None,
 ):
     """Pytest specific decorator for parametrizing estimator checks.
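For context on the `_maybe_mark` change above: pytest gives an explicit `strict` argument on the mark precedence over the `xfail_strict` ini option, so the only way to defer to the configuration is to omit the argument entirely, which is why the `xfail_strict is None` branch does not pass `strict` at all. A minimal, self-contained sketch of these semantics follows; the test names are illustrative and not part of the diff.

```python
import pytest


# With an explicit strict=True, a passing test is reported as FAILED even if
# the configuration says xfail_strict = false: the mark overrides the ini.
@pytest.mark.xfail(reason="expected to fail", strict=True)
def test_explicit_strict():
    assert True


# Without a strict argument, the outcome follows the xfail_strict value in
# pytest.ini / pyproject.toml; this is the xfail_strict=None code path above.
@pytest.mark.xfail(reason="expected to fail")
def test_defers_to_config():
    assert True
```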
@@ -605,9 +630,16 @@ def parametrize_with_checks(
         Where `"check_name"` is the name of the check, and `"my reason"` is why
         the check fails. These tests will be marked as xfail if the check fails.
 
         .. versionadded:: 1.6
 
+    xfail_strict : bool, default=None
+        Whether to run checks in xfail strict mode. If True, checks that are
+        expected to fail but actually pass will lead to a test failure. If
+        False, unexpectedly passing tests will be reported as xpass. If None,
+        the default pytest behavior is used.
+
+        .. versionadded:: 1.8
+
     Returns
     -------
     decorator : `pytest.mark.parametrize`
@@ -640,7 +672,12 @@ def parametrize_with_checks(
 
     def _checks_generator(estimators, legacy, expected_failed_checks):
         for estimator in estimators:
-            args = {"estimator": estimator, "legacy": legacy, "mark": "xfail"}
+            args = {
+                "estimator": estimator,
+                "legacy": legacy,
+                "mark": "xfail",
+                "xfail_strict": xfail_strict,
+            }
             if callable(expected_failed_checks):
                 args["expected_failed_checks"] = expected_failed_checks(estimator)
             yield from estimator_checks_generator(**args)
diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py
index 2abe8caefd915..8048979640509 100644
--- a/sklearn/utils/tests/test_estimator_checks.py
+++ b/sklearn/utils/tests/test_estimator_checks.py
@@ -1324,6 +1324,61 @@ def test_all_estimators_all_public():
     run_tests_without_pytest()
 
 
+def test_estimator_checks_generator_strict_none():
+    # Check that no "strict" kwarg is set on the xfail marks of the generated
+    # checks.
+    est = next(_construct_instances(NuSVC))
+    expected_to_fail = _get_expected_failed_checks(est)
+    # If we don't pass xfail_strict, "strict" should not appear in the xfail
+    # mark either; this way the behaviour configured in pytest.ini takes
+    # precedence.
+    checks = estimator_checks_generator(
+        est,
+        legacy=True,
+        expected_failed_checks=expected_to_fail,
+        mark="xfail",
+    )
+    # make sure we use a class that has expected failures
+    assert len(expected_to_fail) > 0
+    marked_checks = [c for c in checks if hasattr(c, "marks")]
+    # make sure we have some checks with marks
+    assert len(marked_checks) > 0
+
+    for parameter_set in marked_checks:
+        first_mark = parameter_set.marks[0]
+        assert "strict" not in first_mark.kwargs
+
+
+def test_estimator_checks_generator_strict_xfail_tests():
+    # Make sure that the checks generator marks tests that are expected to
+    # fail as strict xfail.
+    est = next(_construct_instances(NuSVC))
+    expected_to_fail = _get_expected_failed_checks(est)
+    checks = estimator_checks_generator(
+        est,
+        legacy=True,
+        expected_failed_checks=expected_to_fail,
+        mark="xfail",
+        xfail_strict=True,
+    )
+    # make sure we use a class that has expected failures
+    assert len(expected_to_fail) > 0
+    strict_xfailed_checks = []
+
+    # xfail'ed checks are wrapped in a ParameterSet, so we detect them via
+    # the presence of a `marks` attribute and then inspect the first mark.
+    marked_checks = [c for c in checks if hasattr(c, "marks")]
+
+    for parameter_set in marked_checks:
+        _, check = parameter_set.values
+        first_mark = parameter_set.marks[0]
+        if first_mark.kwargs["strict"]:
+            strict_xfailed_checks.append(_check_name(check))
+
+    # all checks expected to fail are marked as strict xfail
+    assert set(expected_to_fail.keys()) == set(strict_xfailed_checks)
+
+
 @_mark_thread_unsafe_if_pytest_imported  # Some checks use warnings.
 def test_estimator_checks_generator_skipping_tests():
     # Make sure the checks generator skips tests that are expected to fail
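Putting the pieces together, here is a hedged sketch of how a downstream test suite might opt in to the new behaviour. `MyEstimator` and the expected-failure entry are hypothetical placeholders, not part of scikit-learn; only the `parametrize_with_checks` signature comes from the diff above.

```python
from sklearn.base import BaseEstimator
from sklearn.utils.estimator_checks import parametrize_with_checks


class MyEstimator(BaseEstimator):
    """Hypothetical third-party estimator, used only for illustration."""

    def fit(self, X, y=None):
        return self


@parametrize_with_checks(
    [MyEstimator()],
    # Hypothetical expected failure: the check name and reason are placeholders.
    expected_failed_checks=lambda est: {"check_estimators_unfitted": "TODO"},
    # With xfail_strict=True, a listed check that unexpectedly passes fails
    # the run; with None (the default) the pytest configuration decides.
    xfail_strict=True,
)
def test_sklearn_compatible_estimator(estimator, check):
    check(estimator)
```

Leaving `xfail_strict` at `None` preserves the current behaviour, where the `xfail_strict` setting in the project's pytest configuration, if any, applies.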