@@ -227,44 +227,60 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
227
227
Decorate function with image comparison for pytest.
228
228
229
229
This function creates a decorator that wraps a figure-generating function
230
- with image comparison code. Pytest can become confused if we change the
231
- signature of the function, so we indirectly pass anything we need via the
232
- `mpl_image_comparison_parameters` fixture and extra markers.
230
+ with image comparison code.
233
231
"""
234
232
import pytest
235
233
236
234
extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
235
+ KEYWORD_ONLY = inspect .Parameter .KEYWORD_ONLY
237
236
238
237
def decorator (func ):
238
+ old_sig = inspect .signature (func )
239
+
239
240
@functools .wraps (func )
240
- # Parameter indirection; see docstring above and comment below.
241
- @pytest .mark .usefixtures ('mpl_image_comparison_parameters' )
242
241
@pytest .mark .parametrize ('extension' , extensions )
243
- @pytest .mark .baseline_images (baseline_images )
244
- # END Parameter indirection.
245
242
@pytest .mark .style (style )
246
243
@_checked_on_freetype_version (freetype_version )
247
244
@functools .wraps (func )
248
- def wrapper (* args , ** kwargs ):
245
+ def wrapper (* args , extension , request , ** kwargs ):
249
246
__tracebackhide__ = True
247
+ if 'extension' in old_sig .parameters :
248
+ kwargs ['extension' ] = extension
249
+ if 'request' in old_sig .parameters :
250
+ kwargs ['request' ] = request
251
+
250
252
img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
251
253
savefig_kwargs = savefig_kwargs )
252
254
matplotlib .testing .set_font_settings_for_testing ()
253
255
func (* args , ** kwargs )
254
256
255
- # Parameter indirection:
256
- # This is hacked on via the mpl_image_comparison_parameters fixture
257
- # so that we don't need to modify the function's real signature for
258
- # any parametrization. Modifying the signature is very very tricky
259
- # and likely to confuse pytest.
260
- baseline_images , extension = func .parameters
257
+ if baseline_images is not None :
258
+ our_baseline_images = baseline_images
259
+ else :
260
+ # Allow baseline image list to be produced on the fly based on
261
+ # current parametrization.
262
+ our_baseline_images = request .getfixturevalue (
263
+ 'baseline_images' )
261
264
262
- assert len (plt .get_fignums ()) == len (baseline_images ), (
265
+ assert len (plt .get_fignums ()) == len (our_baseline_images ), (
263
266
"Test generated {} images but there are {} baseline images"
264
- .format (len (plt .get_fignums ()), len (baseline_images )))
265
- for idx , baseline in enumerate (baseline_images ):
267
+ .format (len (plt .get_fignums ()), len (our_baseline_images )))
268
+ for idx , baseline in enumerate (our_baseline_images ):
266
269
img .compare (idx , baseline , extension )
267
270
271
+ parameters = list (old_sig .parameters .values ())
272
+ if 'extension' not in old_sig .parameters :
273
+ parameters += [inspect .Parameter ('extension' , KEYWORD_ONLY )]
274
+ if 'request' not in old_sig .parameters :
275
+ parameters += [inspect .Parameter ("request" , KEYWORD_ONLY )]
276
+ new_sig = old_sig .replace (parameters = parameters )
277
+ wrapper .__signature__ = new_sig
278
+
279
+ # Reach a bit into pytest internals to hoist the marks from our wrapped
280
+ # function.
281
+ new_marks = getattr (func , 'pytestmark' , []) + wrapper .pytestmark
282
+ wrapper .pytestmark = new_marks
283
+
268
284
return wrapper
269
285
270
286
return decorator
0 commit comments