diff --git a/lib/matplotlib/testing/decorators.py b/lib/matplotlib/testing/decorators.py index 7965d45e4f16..7e9621cb3b1d 100644 --- a/lib/matplotlib/testing/decorators.py +++ b/lib/matplotlib/testing/decorators.py @@ -381,41 +381,40 @@ def test_plot(fig_test, fig_ref): fig_test.subplots().plot([1, 3, 5]) fig_ref.subplots().plot([0, 1, 2], [1, 3, 5]) """ - + POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD def decorator(func): import pytest _, result_dir = _image_directories(func) - if len(inspect.signature(func).parameters) == 2: - # Free-standing function. - @pytest.mark.parametrize("ext", extensions) - def wrapper(ext): - fig_test = plt.figure("test") - fig_ref = plt.figure("reference") - func(fig_test, fig_ref) - test_image_path = result_dir / (func.__name__ + "." + ext) - ref_image_path = ( - result_dir / (func.__name__ + "-expected." + ext)) - fig_test.savefig(test_image_path) - fig_ref.savefig(ref_image_path) - _raise_on_image_difference( - ref_image_path, test_image_path, tol=tol) - - elif len(inspect.signature(func).parameters) == 3: - # Method. - @pytest.mark.parametrize("ext", extensions) - def wrapper(self, ext): - fig_test = plt.figure("test") - fig_ref = plt.figure("reference") - func(self, fig_test, fig_ref) - test_image_path = result_dir / (func.__name__ + "." + ext) - ref_image_path = ( - result_dir / (func.__name__ + "-expected." + ext)) - fig_test.savefig(test_image_path) - fig_ref.savefig(ref_image_path) - _raise_on_image_difference( - ref_image_path, test_image_path, tol=tol) + @pytest.mark.parametrize("ext", extensions) + def wrapper(*args, ext, **kwargs): + fig_test = plt.figure("test") + fig_ref = plt.figure("reference") + func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) + test_image_path = result_dir / (func.__name__ + "." + ext) + ref_image_path = result_dir / ( + func.__name__ + "-expected." + ext + ) + fig_test.savefig(test_image_path) + fig_ref.savefig(ref_image_path) + _raise_on_image_difference( + ref_image_path, test_image_path, tol=tol + ) + + sig = inspect.signature(func) + new_sig = sig.replace( + parameters=([param + for param in sig.parameters.values() + if param.name not in {"fig_test", "fig_ref"}] + + [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)]) + ) + wrapper.__signature__ = new_sig + + # reach a bit into pytest internals to hoist the marks from + # our wrapped function + new_marks = getattr(func, "pytestmark", []) + wrapper.pytestmark + wrapper.pytestmark = new_marks return wrapper diff --git a/lib/matplotlib/tests/test_testing.py b/lib/matplotlib/tests/test_testing.py index 55415abb6a04..d1273ec1ba46 100644 --- a/lib/matplotlib/tests/test_testing.py +++ b/lib/matplotlib/tests/test_testing.py @@ -1,7 +1,17 @@ import warnings import pytest +from matplotlib.testing.decorators import check_figures_equal -@pytest.mark.xfail(strict=True, - reason="testing that warnings fail tests") + +@pytest.mark.xfail( + strict=True, reason="testing that warnings fail tests" +) def test_warn_to_fail(): warnings.warn("This should fail the test") + + +@pytest.mark.parametrize("a", [1]) +@check_figures_equal(extensions=["png"]) +@pytest.mark.parametrize("b", [1]) +def test_parametrize_with_check_figure_equal(a, fig_ref, b, fig_test): + assert a == b