1818from matplotlib import pyplot as plt
1919from matplotlib import ticker
2020from . import is_called_from_pytest
21- from .compare import comparable_formats , compare_images , make_test_filename
21+ from .compare import compare_images , make_test_filename , _skip_if_incomparable
2222from .exceptions import ImageComparisonFailure
2323
2424
@@ -135,54 +135,33 @@ def _raise_on_image_difference(expected, actual, tol):
135135 % err )
136136
137137
138- def _skip_if_format_is_uncomparable (extension ):
139- import pytest
140- return pytest .mark .skipif (
141- extension not in comparable_formats (),
142- reason = 'Cannot compare {} files on this system' .format (extension ))
143-
144-
145- def _mark_skip_if_format_is_uncomparable (extension ):
146- import pytest
147- if isinstance (extension , str ):
148- name = extension
149- marks = []
150- elif isinstance (extension , tuple ):
151- # Extension might be a pytest ParameterSet instead of a plain string.
152- # Unfortunately, this type is not exposed, so since it's a namedtuple,
153- # check for a tuple instead.
154- name , = extension .values
155- marks = [* extension .marks ]
156- else :
157- # Extension might be a pytest marker instead of a plain string.
158- name , = extension .args
159- marks = [extension .mark ]
160- return pytest .param (name ,
161- marks = [* marks , _skip_if_format_is_uncomparable (name )])
162-
163-
164- class _ImageComparisonBase :
138+ def _make_image_comparator (func = None ,
139+ baseline_images = None , * , extension = None , tol = 0 ,
140+ remove_text = False , savefig_kwargs = None ):
165141 """
166- Image comparison base class
142+ Image comparison base helper.
167143
168- This class provides *just* the comparison-related functionality and avoids
144+ This helper provides *just* the comparison-related functionality and avoids
169145 any code that would be specific to any testing framework.
170146 """
147+ if func is None :
148+ return functools .partial (
149+ _make_image_comparator ,
150+ baseline_images = baseline_images , extension = extension , tol = tol ,
151+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
152+
153+ if savefig_kwargs is None :
154+ savefig_kwargs = {}
171155
172- def __init__ (self , func , tol , remove_text , savefig_kwargs ):
173- self .func = func
174- self .baseline_dir , self .result_dir = _image_directories (func )
175- self .tol = tol
176- self .remove_text = remove_text
177- self .savefig_kwargs = savefig_kwargs
156+ baseline_dir , result_dir = _image_directories (func )
178157
179- def copy_baseline ( self , baseline , extension ):
180- baseline_path = self . baseline_dir / baseline
158+ def _copy_baseline ( baseline ):
159+ baseline_path = baseline_dir / baseline
181160 orig_expected_path = baseline_path .with_suffix (f'.{ extension } ' )
182161 if extension == 'eps' and not orig_expected_path .exists ():
183162 orig_expected_fname = baseline_path .with_suffix ('.pdf' )
184163 expected_fname = make_test_filename (
185- self . result_dir / orig_expected_path .name , 'expected' )
164+ result_dir / orig_expected_path .name , 'expected' )
186165 try :
187166 # os.symlink errors if the target already exists.
188167 with contextlib .suppress (OSError ):
@@ -197,24 +176,33 @@ def copy_baseline(self, baseline, extension):
197176 f"following file cannot be accessed: { orig_expected_fname } " )
198177 return expected_fname
199178
200- def compare (self , idx , baseline , extension ):
179+ @functools .wraps (func )
180+ def wrapper (* args , ** kwargs ):
201181 __tracebackhide__ = True
202- fignum = plt .get_fignums ()[idx ]
203- fig = plt .figure (fignum )
204-
205- if self .remove_text :
206- remove_ticks_and_titles (fig )
207-
208- actual_path = (self .result_dir / baseline ).with_suffix (f'.{ extension } ' )
209- kwargs = self .savefig_kwargs .copy ()
210- if extension == 'pdf' :
211- kwargs .setdefault ('metadata' ,
212- {'Creator' : None , 'Producer' : None ,
213- 'CreationDate' : None })
214- fig .savefig (actual_path , ** kwargs )
215-
216- expected_path = self .copy_baseline (baseline , extension )
217- _raise_on_image_difference (expected_path , actual_path , self .tol )
182+ _skip_if_incomparable (extension )
183+
184+ func (* args , ** kwargs )
185+
186+ fignums = plt .get_fignums ()
187+ assert len (fignums ) == len (baseline_images ), (
188+ "Test generated {} images but there are {} baseline images"
189+ .format (len (fignums ), len (baseline_images )))
190+ for baseline_image , fignum in zip (baseline_images , fignums ):
191+ fig = plt .figure (fignum )
192+ if remove_text :
193+ remove_ticks_and_titles (fig )
194+ actual_path = ((result_dir / baseline_image )
195+ .with_suffix (f'.{ extension } ' ))
196+ kwargs = savefig_kwargs .copy ()
197+ if extension == 'pdf' :
198+ kwargs .setdefault ('metadata' ,
199+ {'Creator' : None , 'Producer' : None ,
200+ 'CreationDate' : None })
201+ fig .savefig (actual_path , ** kwargs )
202+ expected_path = _copy_baseline (baseline_image )
203+ _raise_on_image_difference (expected_path , actual_path , tol )
204+
205+ return wrapper
218206
219207
220208def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -230,8 +218,6 @@ def _pytest_image_comparison(baseline_images, extensions, tol,
230218 """
231219 import pytest
232220
233- extensions = map (_mark_skip_if_format_is_uncomparable , extensions )
234-
235221 def decorator (func ):
236222 @functools .wraps (func )
237223 # Parameter indirection; see docstring above and comment below.
@@ -244,23 +230,19 @@ def decorator(func):
244230 @functools .wraps (func )
245231 def wrapper (* args , ** kwargs ):
246232 __tracebackhide__ = True
247- img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
248- savefig_kwargs = savefig_kwargs )
249- matplotlib .testing .set_font_settings_for_testing ()
250- func (* args , ** kwargs )
251-
252233 # Parameter indirection:
253234 # This is hacked on via the mpl_image_comparison_parameters fixture
254235 # so that we don't need to modify the function's real signature for
255236 # any parametrization. Modifying the signature is very very tricky
256237 # and likely to confuse pytest.
257238 baseline_images , extension = func .parameters
258239
259- assert len (plt .get_fignums ()) == len (baseline_images ), (
260- "Test generated {} images but there are {} baseline images"
261- .format (len (plt .get_fignums ()), len (baseline_images )))
262- for idx , baseline in enumerate (baseline_images ):
263- img .compare (idx , baseline , extension )
240+ matplotlib .testing .set_font_settings_for_testing ()
241+ comparator = _make_image_comparator (
242+ func ,
243+ baseline_images = baseline_images , extension = extension , tol = tol ,
244+ remove_text = remove_text , savefig_kwargs = savefig_kwargs )
245+ comparator (* args , ** kwargs )
264246
265247 return wrapper
266248
@@ -344,8 +326,6 @@ def image_comparison(baseline_images, extensions=None, tol=0,
344326 if extensions is None :
345327 # Default extensions to test, if not set via baseline_images.
346328 extensions = ['png' , 'pdf' , 'svg' ]
347- if savefig_kwarg is None :
348- savefig_kwarg = dict () # default no kwargs to savefig
349329 return _pytest_image_comparison (
350330 baseline_images = baseline_images , extensions = extensions , tol = tol ,
351331 freetype_version = freetype_version , remove_text = remove_text ,
@@ -390,6 +370,7 @@ def decorator(func):
390370 # Free-standing function.
391371 @pytest .mark .parametrize ("ext" , extensions )
392372 def wrapper (ext ):
373+ _skip_if_incomparable (ext )
393374 fig_test = plt .figure ("test" )
394375 fig_ref = plt .figure ("reference" )
395376 func (fig_test , fig_ref )
@@ -405,6 +386,7 @@ def wrapper(ext):
405386 # Method.
406387 @pytest .mark .parametrize ("ext" , extensions )
407388 def wrapper (self , ext ):
389+ _skip_if_incomparable (ext )
408390 fig_test = plt .figure ("test" )
409391 fig_ref = plt .figure ("reference" )
410392 func (self , fig_test , fig_ref )
0 commit comments