18
18
from matplotlib import pyplot as plt
19
19
from matplotlib import ticker
20
20
from . import is_called_from_pytest
21
- from .compare import comparable_formats , compare_images , make_test_filename
21
+ from .compare import compare_images , make_test_filename , _skip_if_uncomparable
22
22
from .exceptions import ImageComparisonFailure
23
23
24
24
@@ -136,54 +136,33 @@ def _raise_on_image_difference(expected, actual, tol):
136
136
% err )
137
137
138
138
139
- def _skip_if_format_is_uncomparable (extension ):
140
- import pytest
141
- return pytest .mark .skipif (
142
- extension not in comparable_formats (),
143
- reason = 'Cannot compare {} files on this system' .format (extension ))
144
-
145
-
146
- def _mark_skip_if_format_is_uncomparable (extension ):
147
- import pytest
148
- if isinstance (extension , str ):
149
- name = extension
150
- marks = []
151
- elif isinstance (extension , tuple ):
152
- # Extension might be a pytest ParameterSet instead of a plain string.
153
- # Unfortunately, this type is not exposed, so since it's a namedtuple,
154
- # check for a tuple instead.
155
- name , = extension .values
156
- marks = [* extension .marks ]
157
- else :
158
- # Extension might be a pytest marker instead of a plain string.
159
- name , = extension .args
160
- marks = [extension .mark ]
161
- return pytest .param (name ,
162
- marks = [* marks , _skip_if_format_is_uncomparable (name )])
163
-
164
-
165
- class _ImageComparisonBase :
139
+ def _make_image_comparator (func = None ,
140
+ baseline_images = None , * , extension = None , tol = 0 ,
141
+ remove_text = False , savefig_kwargs = None ):
166
142
"""
167
- Image comparison base class
143
+ Image comparison base helper.
168
144
169
- This class provides *just* the comparison-related functionality and avoids
145
+ This helper provides *just* the comparison-related functionality and avoids
170
146
any code that would be specific to any testing framework.
171
147
"""
148
+ if func is None :
149
+ return functools .partial (
150
+ _make_image_comparator ,
151
+ baseline_images = baseline_images , extension = extension , tol = tol ,
152
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
153
+
154
+ if savefig_kwargs is None :
155
+ savefig_kwargs = {}
172
156
173
- def __init__ (self , func , tol , remove_text , savefig_kwargs ):
174
- self .func = func
175
- self .baseline_dir , self .result_dir = _image_directories (func )
176
- self .tol = tol
177
- self .remove_text = remove_text
178
- self .savefig_kwargs = savefig_kwargs
157
+ baseline_dir , result_dir = _image_directories (func )
179
158
180
- def copy_baseline ( self , baseline , extension ):
181
- baseline_path = self . baseline_dir / baseline
159
+ def _copy_baseline ( baseline ):
160
+ baseline_path = baseline_dir / baseline
182
161
orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
183
162
if extension == 'eps' and not orig_expected_path .exists ():
184
163
orig_expected_path = orig_expected_path .with_suffix ('.pdf' )
185
164
expected_fname = make_test_filename (
186
- self . result_dir / orig_expected_path .name , 'expected' )
165
+ result_dir / orig_expected_path .name , 'expected' )
187
166
try :
188
167
# os.symlink errors if the target already exists.
189
168
with contextlib .suppress (OSError ):
@@ -198,24 +177,33 @@ def copy_baseline(self, baseline, extension):
198
177
f"following file cannot be accessed: { orig_expected_path } " )
199
178
return expected_fname
200
179
201
- def compare (self , idx , baseline , extension ):
180
+ @functools .wraps (func )
181
+ def wrapper (* args , ** kwargs ):
202
182
__tracebackhide__ = True
203
- fignum = plt .get_fignums ()[idx ]
204
- fig = plt .figure (fignum )
205
-
206
- if self .remove_text :
207
- remove_ticks_and_titles (fig )
208
-
209
- actual_path = (self .result_dir / baseline ).with_suffix (f'.{ extension } ' )
210
- kwargs = self .savefig_kwargs .copy ()
211
- if extension == 'pdf' :
212
- kwargs .setdefault ('metadata' ,
213
- {'Creator' : None , 'Producer' : None ,
214
- 'CreationDate' : None })
215
- fig .savefig (actual_path , ** kwargs )
216
-
217
- expected_path = self .copy_baseline (baseline , extension )
218
- _raise_on_image_difference (expected_path , actual_path , self .tol )
183
+ _skip_if_uncomparable (extension )
184
+
185
+ func (* args , ** kwargs )
186
+
187
+ fignums = plt .get_fignums ()
188
+ assert len (fignums ) == len (baseline_images ), (
189
+ "Test generated {} images but there are {} baseline images"
190
+ .format (len (fignums ), len (baseline_images )))
191
+ for baseline_image , fignum in zip (baseline_images , fignums ):
192
+ fig = plt .figure (fignum )
193
+ if remove_text :
194
+ remove_ticks_and_titles (fig )
195
+ actual_path = ((result_dir / baseline_image )
196
+ .with_suffix (f'.{ extension } ' ))
197
+ kwargs = savefig_kwargs .copy ()
198
+ if extension == 'pdf' :
199
+ kwargs .setdefault ('metadata' ,
200
+ {'Creator' : None , 'Producer' : None ,
201
+ 'CreationDate' : None })
202
+ fig .savefig (actual_path , ** kwargs )
203
+ expected_path = _copy_baseline (baseline_image )
204
+ _raise_on_image_difference (expected_path , actual_path , tol )
205
+
206
+ return wrapper
219
207
220
208
221
209
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -231,8 +219,6 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
231
219
"""
232
220
import pytest
233
221
234
- extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
235
-
236
222
def decorator (func ):
237
223
@functools .wraps (func )
238
224
# Parameter indirection; see docstring above and comment below.
@@ -245,23 +231,19 @@ def decorator(func):
245
231
@functools .wraps (func )
246
232
def wrapper (* args , ** kwargs ):
247
233
__tracebackhide__ = True
248
- img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
249
- savefig_kwargs = savefig_kwargs )
250
- matplotlib .testing .set_font_settings_for_testing ()
251
- func (* args , ** kwargs )
252
-
253
234
# Parameter indirection:
254
235
# This is hacked on via the mpl_image_comparison_parameters fixture
255
236
# so that we don't need to modify the function's real signature for
256
237
# any parametrization. Modifying the signature is very very tricky
257
238
# and likely to confuse pytest.
258
239
baseline_images , extension = func .parameters
259
240
260
- assert len (plt .get_fignums ()) == len (baseline_images ), (
261
- "Test generated {} images but there are {} baseline images"
262
- .format (len (plt .get_fignums ()), len (baseline_images )))
263
- for idx , baseline in enumerate (baseline_images ):
264
- img .compare (idx , baseline , extension )
241
+ matplotlib .testing .set_font_settings_for_testing ()
242
+ comparator = _make_image_comparator (
243
+ func ,
244
+ baseline_images = baseline_images , extension = extension , tol = tol ,
245
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
246
+ comparator (* args , ** kwargs )
265
247
266
248
return wrapper
267
249
@@ -345,8 +327,6 @@ def image_comparison(baseline_images, extensions=None, tol=0,
345
327
if extensions is None :
346
328
# Default extensions to test, if not set via baseline_images.
347
329
extensions = ['png' , 'pdf' , 'svg' ]
348
- if savefig_kwarg is None :
349
- savefig_kwarg = dict () # default no kwargs to savefig
350
330
return _pytest_image_comparison (
351
331
baseline_images = baseline_images , extensions = extensions , tol = tol ,
352
332
freetype_version = freetype_version , remove_text = remove_text ,
@@ -389,6 +369,7 @@ def decorator(func):
389
369
390
370
@pytest .mark .parametrize ("ext" , extensions )
391
371
def wrapper (* args , ext , ** kwargs ):
372
+ _skip_if_uncomparable (ext )
392
373
try :
393
374
fig_test = plt .figure ("test" )
394
375
fig_ref = plt .figure ("reference" )
0 commit comments