18
18
from matplotlib import pyplot as plt
19
19
from matplotlib import ticker
20
20
from . import is_called_from_pytest
21
- from .compare import comparable_formats , compare_images , make_test_filename
21
+ from .compare import compare_images , make_test_filename , _skip_if_incomparable
22
22
from .exceptions import ImageComparisonFailure
23
23
24
24
@@ -135,54 +135,33 @@ def _raise_on_image_difference(expected, actual, tol):
135
135
% err )
136
136
137
137
138
- def _skip_if_format_is_uncomparable (extension ):
139
- import pytest
140
- return pytest .mark .skipif (
141
- extension not in comparable_formats (),
142
- reason = 'Cannot compare {} files on this system' .format (extension ))
143
-
144
-
145
- def _mark_skip_if_format_is_uncomparable (extension ):
146
- import pytest
147
- if isinstance (extension , str ):
148
- name = extension
149
- marks = []
150
- elif isinstance (extension , tuple ):
151
- # Extension might be a pytest ParameterSet instead of a plain string.
152
- # Unfortunately, this type is not exposed, so since it's a namedtuple,
153
- # check for a tuple instead.
154
- name , = extension .values
155
- marks = [* extension .marks ]
156
- else :
157
- # Extension might be a pytest marker instead of a plain string.
158
- name , = extension .args
159
- marks = [extension .mark ]
160
- return pytest .param (name ,
161
- marks = [* marks , _skip_if_format_is_uncomparable (name )])
162
-
163
-
164
- class _ImageComparisonBase :
138
+ def _make_image_comparator (func = None ,
139
+ baseline_images = None , * , extension = None , tol = 0 ,
140
+ remove_text = False , savefig_kwargs = None ):
165
141
"""
166
- Image comparison base class
142
+ Image comparison base helper.
167
143
168
- This class provides *just* the comparison-related functionality and avoids
144
+ This helper provides *just* the comparison-related functionality and avoids
169
145
any code that would be specific to any testing framework.
170
146
"""
147
+ if func is None :
148
+ return functools .partial (
149
+ _make_image_comparator ,
150
+ baseline_images = baseline_images , extension = extension , tol = tol ,
151
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
152
+
153
+ if savefig_kwargs is None :
154
+ savefig_kwargs = {}
171
155
172
- def __init__ (self , func , tol , remove_text , savefig_kwargs ):
173
- self .func = func
174
- self .baseline_dir , self .result_dir = _image_directories (func )
175
- self .tol = tol
176
- self .remove_text = remove_text
177
- self .savefig_kwargs = savefig_kwargs
156
+ baseline_dir , result_dir = _image_directories (func )
178
157
179
- def copy_baseline ( self , baseline , extension ):
180
- baseline_path = self . baseline_dir / baseline
158
+ def _copy_baseline ( baseline ):
159
+ baseline_path = baseline_dir / baseline
181
160
orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
182
161
if extension == 'eps' and not orig_expected_path .exists ():
183
162
orig_expected_path = baseline_path .with_suffix ('.pdf' )
184
163
expected_path = make_test_filename (
185
- self . result_dir / orig_expected_path .name , 'expected' )
164
+ result_dir / orig_expected_path .name , 'expected' )
186
165
try :
187
166
shutil .copyfile (orig_expected_path , expected_path )
188
167
except IOError as exc :
@@ -191,24 +170,33 @@ def copy_baseline(self, baseline, extension):
191
170
f"{ expected_path } " ) from exc
192
171
return expected_path
193
172
194
- def compare (self , idx , baseline , extension ):
173
+ @functools .wraps (func )
174
+ def wrapper (* args , ** kwargs ):
195
175
__tracebackhide__ = True
196
- fignum = plt .get_fignums ()[idx ]
197
- fig = plt .figure (fignum )
198
-
199
- if self .remove_text :
200
- remove_ticks_and_titles (fig )
201
-
202
- actual_path = (self .result_dir / baseline ).with_suffix (f'.{ extension } ' )
203
- kwargs = self .savefig_kwargs .copy ()
204
- if extension == 'pdf' :
205
- kwargs .setdefault ('metadata' ,
206
- {'Creator' : None , 'Producer' : None ,
207
- 'CreationDate' : None })
208
- fig .savefig (actual_path , ** kwargs )
209
-
210
- expected_path = self .copy_baseline (baseline , extension )
211
- _raise_on_image_difference (expected_path , actual_path , self .tol )
176
+ _skip_if_incomparable (extension )
177
+
178
+ func (* args , ** kwargs )
179
+
180
+ fignums = plt .get_fignums ()
181
+ assert len (fignums ) == len (baseline_images ), (
182
+ "Test generated {} images but there are {} baseline images"
183
+ .format (len (fignums ), len (baseline_images )))
184
+ for baseline_image , fignum in zip (baseline_images , fignums ):
185
+ fig = plt .figure (fignum )
186
+ if remove_text :
187
+ remove_ticks_and_titles (fig )
188
+ actual_path = ((result_dir / baseline_image )
189
+ .with_suffix (f'.{ extension } ' ))
190
+ kwargs = savefig_kwargs .copy ()
191
+ if extension == 'pdf' :
192
+ kwargs .setdefault ('metadata' ,
193
+ {'Creator' : None , 'Producer' : None ,
194
+ 'CreationDate' : None })
195
+ fig .savefig (actual_path , ** kwargs )
196
+ expected_path = _copy_baseline (baseline_image , extension )
197
+ _raise_on_image_difference (expected_path , actual_path , tol )
198
+
199
+ return wrapper
212
200
213
201
214
202
def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -224,8 +212,6 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
224
212
"""
225
213
import pytest
226
214
227
- extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
228
-
229
215
def decorator (func ):
230
216
@functools .wraps (func )
231
217
# Parameter indirection; see docstring above and comment below.
@@ -238,23 +224,19 @@ def decorator(func):
238
224
@functools .wraps (func )
239
225
def wrapper (* args , ** kwargs ):
240
226
__tracebackhide__ = True
241
- img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
242
- savefig_kwargs = savefig_kwargs )
243
- matplotlib .testing .set_font_settings_for_testing ()
244
- func (* args , ** kwargs )
245
-
246
227
# Parameter indirection:
247
228
# This is hacked on via the mpl_image_comparison_parameters fixture
248
229
# so that we don't need to modify the function's real signature for
249
230
# any parametrization. Modifying the signature is very very tricky
250
231
# and likely to confuse pytest.
251
232
baseline_images , extension = func .parameters
252
233
253
- assert len (plt .get_fignums ()) == len (baseline_images ), (
254
- "Test generated {} images but there are {} baseline images"
255
- .format (len (plt .get_fignums ()), len (baseline_images )))
256
- for idx , baseline in enumerate (baseline_images ):
257
- img .compare (idx , baseline , extension )
234
+ matplotlib .testing .set_font_settings_for_testing ()
235
+ comparator = _make_image_comparator (
236
+ func ,
237
+ baseline_images = baseline_images , extension = extension , tol = tol ,
238
+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
239
+ comparator (* args , ** kwargs )
258
240
259
241
return wrapper
260
242
@@ -338,8 +320,6 @@ def image_comparison(baseline_images, extensions=None, tol=0,
338
320
if extensions is None :
339
321
# Default extensions to test, if not set via baseline_images.
340
322
extensions = ['png' , 'pdf' , 'svg' ]
341
- if savefig_kwarg is None :
342
- savefig_kwarg = dict () # default no kwargs to savefig
343
323
return _pytest_image_comparison (
344
324
baseline_images = baseline_images , extensions = extensions , tol = tol ,
345
325
freetype_version = freetype_version , remove_text = remove_text ,
@@ -384,6 +364,7 @@ def decorator(func):
384
364
# Free-standing function.
385
365
@pytest .mark .parametrize ("ext" , extensions )
386
366
def wrapper (ext ):
367
+ _skip_if_incomparable (ext )
387
368
fig_test = plt .figure ("test" )
388
369
fig_ref = plt .figure ("reference" )
389
370
func (fig_test , fig_ref )
@@ -399,6 +380,7 @@ def wrapper(ext):
399
380
# Method.
400
381
@pytest .mark .parametrize ("ext" , extensions )
401
382
def wrapper (self , ext ):
383
+ _skip_if_incomparable (ext )
402
384
fig_test = plt .figure ("test" )
403
385
fig_ref = plt .figure ("reference" )
404
386
func (self , fig_test , fig_ref )
0 commit comments