66from pathlib import Path
77import shutil
88import string
9- import subprocess
109import sys
1110import unittest
1211import warnings
13- import pytest
14- try :
15- from contextlib import nullcontext
16- except ImportError :
17- from contextlib import ExitStack as nullcontext # Py3.6.
1812
1913import matplotlib .style
2014import matplotlib .units
@@ -205,7 +199,6 @@ def copy_baseline(self, baseline, extension):
205199 f"{ orig_expected_path } " ) from err
206200 return expected_fname
207201
208-
209202 def create_baseline (self , baseline , extension ):
210203 src_path = self .result_dir / baseline
211204 orig_src_path = src_path .with_suffix (f'.{ extension } ' )
@@ -224,12 +217,11 @@ def create_baseline(self, baseline, extension):
224217 raise ValueError ("Failed to put the image in the right place" )
225218 return dest_path
226219
227-
228- def compare ( self , idx , baseline , extension , * , _lock = False , generating = False ):
220+ def compare ( self , idx , baseline , extension , * , _lock = False ,
221+ generating = False ):
229222 __tracebackhide__ = True
230223 fignum = plt .get_fignums ()[idx ]
231224 fig = plt .figure (fignum )
232-
233225 if self .remove_text :
234226 remove_ticks_and_titles (fig )
235227
@@ -248,8 +240,8 @@ def compare(self, idx, baseline, extension, *, _lock=False, generating=False):
248240 self .create_baseline (baseline , extension )
249241 else :
250242 expected_path = self .copy_baseline (baseline , extension )
251- _raise_on_image_difference (expected_path , actual_path , self . tol )
252-
243+ _raise_on_image_difference (expected_path , actual_path ,
244+ self . tol )
253245
254246
255247def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -307,7 +299,8 @@ def wrapper(*args, extension, request, **kwargs):
307299 "Test generated {} images but there are {} baseline images"
308300 .format (len (plt .get_fignums ()), len (our_baseline_images )))
309301 for idx , baseline in enumerate (our_baseline_images ):
310- img .compare (idx , baseline , extension , _lock = needs_lock , generating = generate_images )
302+ img .compare (idx , baseline , extension , _lock = needs_lock ,
303+ generating = generate_images )
311304
312305 parameters = list (old_sig .parameters .values ())
313306 if 'extension' not in old_sig .parameters :
@@ -515,9 +508,34 @@ def _image_directories(func):
515508 doesn't exist.
516509 """
517510 module_path = Path (sys .modules [func .__module__ ].__file__ )
518- baseline_dir = (module_path .parent /
519- "baseline_images" /
520- module_path .stem )
511+ if func .__module__ .startswith ("matplotlib." ):
512+ try :
513+ import matplotlib_baseline_images
514+ baseline_dir = (Path (matplotlib_baseline_images .__file__ ).parent /
515+ module_path .stem )
516+ except :
517+ if (Path (__file__ ).parent / 'baseline_images' ).exists ():
518+ baseline_dir = (module_path .parent /
519+ "baseline_images" /
520+ module_path .stem )
521+ else :
522+ raise ImportError ("Not able to import"
523+ "matplotlib_baseline_images" )
524+ elif func .__module__ .startswith ("mpl_toolkits." ):
525+ try :
526+ import mpl_toolkits_baseline_images
527+ baseline_file = mpl_toolkits_baseline_images .__file__
528+ baseline_dir = (Path (baseline_file ).parent /
529+ module_path .stem )
530+ except :
531+ if (Path (mpl_toolkits_baseline_images .__file__ ).parent /
532+ module_path .stem ).exists ():
533+ baseline_dir = (module_path .parent /
534+ "baseline_images" /
535+ module_path .stem )
536+ else :
537+ raise ImportError ("Not able to import "
538+ "mpl_toolkits_baseline_images" )
521539 result_dir = Path ().resolve () / "result_images" / module_path .stem
522540 result_dir .mkdir (parents = True , exist_ok = True )
523541 return baseline_dir , result_dir
0 commit comments