18
18
from matplotlib import pyplot as plt
19
19
from matplotlib import ticker
20
20
from . import is_called_from_pytest
21
- from .compare import comparable_formats , compare_images , make_test_filename
21
+ from .compare import compare_images , make_test_filename , _skip_if_incomparable
22
22
from .exceptions import ImageComparisonFailure
23
23
24
24
@@ -135,54 +135,33 @@ def _raise_on_image_difference(expected, actual, tol):
135
135
% err )
136
136
137
137
138
- def _skip_if_format_is_uncomparable (extension ):
139
- import pytest
140
- return pytest .mark .skipif (
141
- extension not in comparable_formats (),
142
- reason = 'Cannot compare {} files on this system' .format (extension ))
143
-
144
-
145
- def _mark_skip_if_format_is_uncomparable (extension ):
146
- import pytest
147
- if isinstance (extension , str ):
148
- name = extension
149
- marks = []
150
- elif isinstance (extension , tuple ):
151
- # Extension might be a pytest ParameterSet instead of a plain string.
152
- # Unfortunately, this type is not exposed, so since it's a namedtuple,
153
- # check for a tuple instead.
154
- name , = extension .values
155
- marks = [* extension .marks ]
156
- else :
157
- # Extension might be a pytest marker instead of a plain string.
158
- name , = extension .args
159
- marks = [extension .mark ]
160
- return pytest .param (name ,
161
- marks = [* marks , _skip_if_format_is_uncomparable (name )])
162
-
163
-
164
- class _ImageComparisonBase :
138
+ def _make_image_comparator (func = None ,
139
+ baseline_images = None , * , extension = None , tol = 0 ,
140
+ remove_text = False , savefig_kwargs = None ):
165
141
"""
166
- Image comparison base class
142
+ Image comparison base helper.
167
143
168
- This class provides *just* the comparison-related functionality and avoids
144
+ This helper provides *just* the comparison-related functionality and avoids
169
145
any code that would be specific to any testing framework.
170
146
"""
147
+ if func is None :
148
+ return functools .partial (
149
+ _make_image_comparator ,
150
+ baseline_images = baseline_images , extension = extension , tol = tol ,
151
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
152
+
153
+ if savefig_kwargs is None :
154
+ savefig_kwargs = {}
171
155
172
- def __init__ (self , func , tol , remove_text , savefig_kwargs ):
173
- self .func = func
174
- self .baseline_dir , self .result_dir = _image_directories (func )
175
- self .tol = tol
176
- self .remove_text = remove_text
177
- self .savefig_kwargs = savefig_kwargs
156
+ baseline_dir , result_dir = _image_directories (func )
178
157
179
- def copy_baseline ( self , baseline , extension ):
180
- baseline_path = self . baseline_dir / baseline
158
+ def _copy_baseline ( baseline ):
159
+ baseline_path = baseline_dir / baseline
181
160
orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
182
161
if extension == 'eps' and not orig_expected_path .exists ():
183
162
orig_expected_fname = baseline_path .with_suffix ('.pdf' )
184
163
expected_fname = make_test_filename (
185
- self . result_dir / orig_expected_path .name , 'expected' )
164
+ result_dir / orig_expected_path .name , 'expected' )
186
165
try :
187
166
# os.symlink errors if the target already exists.
188
167
with contextlib .suppress (OSError ):
@@ -197,24 +176,33 @@ def copy_baseline(self, baseline, extension):
197
176
f"following file cannot be accessed: { orig_expected_fname } " )
198
177
return expected_fname
199
178
200
- def compare (self , idx , baseline , extension ):
179
+ @functools .wraps (func )
180
+ def wrapper (* args , ** kwargs ):
201
181
__tracebackhide__ = True
202
- fignum = plt .get_fignums ()[idx ]
203
- fig = plt .figure (fignum )
204
-
205
- if self .remove_text :
206
- remove_ticks_and_titles (fig )
207
-
208
- actual_path = (self .result_dir / baseline ).with_suffix (f'.{ extension } ' )
209
- kwargs = self .savefig_kwargs .copy ()
210
- if extension == 'pdf' :
211
- kwargs .setdefault ('metadata' ,
212
- {'Creator' : None , 'Producer' : None ,
213
- 'CreationDate' : None })
214
- fig .savefig (actual_path , ** kwargs )
215
-
216
- expected_path = self .copy_baseline (baseline , extension )
217
- _raise_on_image_difference (expected_path , actual_path , self .tol )
182
+ _skip_if_incomparable (extension )
183
+
184
+ func (* args , ** kwargs )
185
+
186
+ fignums = plt .get_fignums ()
187
+ assert len (fignums ) == len (baseline_images ), (
188
+ "Test generated {} images but there are {} baseline images"
189
+ .format (len (fignums ), len (baseline_images )))
190
+ for baseline_image , fignum in zip (baseline_images , fignums ):
191
+ fig = plt .figure (fignum )
192
+ if remove_text :
193
+ remove_ticks_and_titles (fig )
194
+ actual_path = ((result_dir / baseline_image )
195
+ .with_suffix (f'.{ extension } ' ))
196
+ kwargs = savefig_kwargs .copy ()
197
+ if extension == 'pdf' :
198
+ kwargs .setdefault ('metadata' ,
199
+ {'Creator' : None , 'Producer' : None ,
200
+ 'CreationDate' : None })
201
+ fig .savefig (actual_path , ** kwargs )
202
+ expected_path = _copy_baseline (baseline_image )
203
+ _raise_on_image_difference (expected_path , actual_path , tol )
204
+
205
+ return wrapper
218
206
219
207
220
208
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -230,8 +218,6 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
230
218
"""
231
219
import pytest
232
220
233
- extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
234
-
235
221
def decorator (func ):
236
222
@functools .wraps (func )
237
223
# Parameter indirection; see docstring above and comment below.
@@ -244,23 +230,19 @@ def decorator(func):
244
230
@functools .wraps (func )
245
231
def wrapper (* args , ** kwargs ):
246
232
__tracebackhide__ = True
247
- img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
248
- savefig_kwargs = savefig_kwargs )
249
- matplotlib .testing .set_font_settings_for_testing ()
250
- func (* args , ** kwargs )
251
-
252
233
# Parameter indirection:
253
234
# This is hacked on via the mpl_image_comparison_parameters fixture
254
235
# so that we don't need to modify the function's real signature for
255
236
# any parametrization. Modifying the signature is very very tricky
256
237
# and likely to confuse pytest.
257
238
baseline_images , extension = func .parameters
258
239
259
- assert len (plt .get_fignums ()) == len (baseline_images ), (
260
- "Test generated {} images but there are {} baseline images"
261
- .format (len (plt .get_fignums ()), len (baseline_images )))
262
- for idx , baseline in enumerate (baseline_images ):
263
- img .compare (idx , baseline , extension )
240
+ matplotlib .testing .set_font_settings_for_testing ()
241
+ comparator = _make_image_comparator (
242
+ func ,
243
+ baseline_images = baseline_images , extension = extension , tol = tol ,
244
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
245
+ comparator (* args , ** kwargs )
264
246
265
247
return wrapper
266
248
@@ -344,8 +326,6 @@ def image_comparison(baseline_images, extensions=None, tol=0,
344
326
if extensions is None :
345
327
# Default extensions to test, if not set via baseline_images.
346
328
extensions = ['png' , 'pdf' , 'svg' ]
347
- if savefig_kwarg is None :
348
- savefig_kwarg = dict () # default no kwargs to savefig
349
329
return _pytest_image_comparison (
350
330
baseline_images = baseline_images , extensions = extensions , tol = tol ,
351
331
freetype_version = freetype_version , remove_text = remove_text ,
@@ -390,6 +370,7 @@ def decorator(func):
390
370
# Free-standing function.
391
371
@pytest .mark .parametrize ("ext" , extensions )
392
372
def wrapper (ext ):
373
+ _skip_if_incomparable (ext )
393
374
fig_test = plt .figure ("test" )
394
375
fig_ref = plt .figure ("reference" )
395
376
func (fig_test , fig_ref )
@@ -405,6 +386,7 @@ def wrapper(ext):
405
386
# Method.
406
387
@pytest .mark .parametrize ("ext" , extensions )
407
388
def wrapper (self , ext ):
389
+ _skip_if_incomparable (ext )
408
390
fig_test = plt .figure ("test" )
409
391
fig_ref = plt .figure ("reference" )
410
392
func (self , fig_test , fig_ref )
0 commit comments