@@ -381,41 +381,40 @@ def test_plot(fig_test, fig_ref):
381
381
fig_test.subplots().plot([1, 3, 5])
382
382
fig_ref.subplots().plot([0, 1, 2], [1, 3, 5])
383
383
"""
384
-
384
+ POSITIONAL_OR_KEYWORD = inspect . Parameter . POSITIONAL_OR_KEYWORD
385
385
def decorator (func ):
386
386
import pytest
387
387
388
388
_ , result_dir = _image_directories (func )
389
389
390
- if len (inspect .signature (func ).parameters ) == 2 :
391
- # Free-standing function.
392
- @pytest .mark .parametrize ("ext" , extensions )
393
- def wrapper (ext ):
394
- fig_test = plt .figure ("test" )
395
- fig_ref = plt .figure ("reference" )
396
- func (fig_test , fig_ref )
397
- test_image_path = result_dir / (func .__name__ + "." + ext )
398
- ref_image_path = (
399
- result_dir / (func .__name__ + "-expected." + ext ))
400
- fig_test .savefig (test_image_path )
401
- fig_ref .savefig (ref_image_path )
402
- _raise_on_image_difference (
403
- ref_image_path , test_image_path , tol = tol )
404
-
405
- elif len (inspect .signature (func ).parameters ) == 3 :
406
- # Method.
407
- @pytest .mark .parametrize ("ext" , extensions )
408
- def wrapper (self , ext ):
409
- fig_test = plt .figure ("test" )
410
- fig_ref = plt .figure ("reference" )
411
- func (self , fig_test , fig_ref )
412
- test_image_path = result_dir / (func .__name__ + "." + ext )
413
- ref_image_path = (
414
- result_dir / (func .__name__ + "-expected." + ext ))
415
- fig_test .savefig (test_image_path )
416
- fig_ref .savefig (ref_image_path )
417
- _raise_on_image_difference (
418
- ref_image_path , test_image_path , tol = tol )
390
+ @pytest .mark .parametrize ("ext" , extensions )
391
+ def wrapper (* args , ext , ** kwargs ):
392
+ fig_test = plt .figure ("test" )
393
+ fig_ref = plt .figure ("reference" )
394
+ func (* args , fig_test = fig_test , fig_ref = fig_ref , ** kwargs )
395
+ test_image_path = result_dir / (func .__name__ + "." + ext )
396
+ ref_image_path = result_dir / (
397
+ func .__name__ + "-expected." + ext
398
+ )
399
+ fig_test .savefig (test_image_path )
400
+ fig_ref .savefig (ref_image_path )
401
+ _raise_on_image_difference (
402
+ ref_image_path , test_image_path , tol = tol
403
+ )
404
+
405
+ sig = inspect .signature (func )
406
+ new_sig = sig .replace (
407
+ parameters = ([param
408
+ for param in sig .parameters .values ()
409
+ if param .name not in {"fig_test" , "fig_ref" }]
410
+ + [inspect .Parameter ("ext" , POSITIONAL_OR_KEYWORD )])
411
+ )
412
+ wrapper .__signature__ = new_sig
413
+
414
+ # reach a bit into pytest internals to hoist the marks from
415
+ # our wrapped function
416
+ new_marks = getattr (func , "pytestmark" , []) + wrapper .pytestmark
417
+ wrapper .pytestmark = new_marks
419
418
420
419
return wrapper
421
420
0 commit comments