66from pathlib import Path
77import shutil
88import string
9+ import subprocess
910import sys
1011import unittest
1112import warnings
13+ import pytest
14+ try :
15+ from contextlib import nullcontext
16+ except ImportError :
17+ from contextlib import ExitStack as nullcontext # Py3.6.
1218
1319import matplotlib .style
1420import matplotlib .units
@@ -199,7 +205,27 @@ def copy_baseline(self, baseline, extension):
199205 f"{ orig_expected_path } " ) from err
200206 return expected_fname
201207
202- def compare (self , idx , baseline , extension , * , _lock = False ):
208+
209+ def create_baseline (self , baseline , extension ):
210+ src_path = self .result_dir / baseline
211+ orig_src_path = src_path .with_suffix (f'.{ extension } ' )
212+ if extension == 'eps' and not orig_src_path .exists ():
213+ orig_src_path = orig_src_path .with_suffix ('.pdf' )
214+ dest_path = Path (self .baseline_dir / orig_src_path .name )
215+ try :
216+ if dest_path .exists ():
217+ return dest_path
218+ Path (dest_path ).parent .mkdir (parents = True , exist_ok = True )
219+ try :
220+ os .symlink (orig_src_path , dest_path )
221+ except OSError : # On Windows, symlink *may* be unavailable.
222+ shutil .copyfile (orig_src_path , dest_path )
223+ except OSError as err :
224+ raise ValueError ("Failed to put the image in the right place" )
225+ return dest_path
226+
227+
228+ def compare (self , idx , baseline , extension , * , _lock = False , generating = False ):
203229 __tracebackhide__ = True
204230 fignum = plt .get_fignums ()[idx ]
205231 fig = plt .figure (fignum )
@@ -218,8 +244,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
218244 if _lock else contextlib .nullcontext ())
219245 with lock :
220246 fig .savefig (actual_path , ** kwargs )
221- expected_path = self .copy_baseline (baseline , extension )
222- _raise_on_image_difference (expected_path , actual_path , self .tol )
247+ if generating :
248+ self .create_baseline (baseline , extension )
249+ else :
250+ expected_path = self .copy_baseline (baseline , extension )
251+ _raise_on_image_difference (expected_path , actual_path , self .tol )
252+
223253
224254
225255def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -244,12 +274,14 @@ def decorator(func):
244274 @pytest .mark .style (style )
245275 @_checked_on_freetype_version (freetype_version )
246276 @functools .wraps (func )
277+ @pytest .mark .matplotlib_baseline_image_generation
247278 def wrapper (* args , extension , request , ** kwargs ):
248279 __tracebackhide__ = True
249280 if 'extension' in old_sig .parameters :
250281 kwargs ['extension' ] = extension
251282 if 'request' in old_sig .parameters :
252283 kwargs ['request' ] = request
284+ matplotlib_baseline_image_generation = request .config .getoption ("--matplotlib_baseline_image_generation" )
253285
254286 img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
255287 savefig_kwargs = savefig_kwargs )
@@ -275,7 +307,7 @@ def wrapper(*args, extension, request, **kwargs):
275307 "Test generated {} images but there are {} baseline images"
276308 .format (len (plt .get_fignums ()), len (our_baseline_images )))
277309 for idx , baseline in enumerate (our_baseline_images ):
278- img .compare (idx , baseline , extension , _lock = needs_lock )
310+ img .compare (idx , baseline , extension , _lock = needs_lock , generating = matplotlib_baseline_image_generation )
279311
280312 parameters = list (old_sig .parameters .values ())
281313 if 'extension' not in old_sig .parameters :
@@ -483,25 +515,9 @@ def _image_directories(func):
483515 doesn't exist.
484516 """
485517 module_path = Path (sys .modules [func .__module__ ].__file__ )
486- if func .__module__ .startswith ("matplotlib." ):
487- try :
488- import matplotlib_baseline_images
489- except :
490- raise ImportError ("Not able to import matplotlib_baseline_images" )
491- baseline_dir = (Path (matplotlib_baseline_images .__file__ ).parent /
492- module_path .stem )
493- elif func .__module__ .startswith ("mpl_toolkits." ):
494- try :
495- import mpl_toolkits_baseline_images
496- except :
497- raise ImportError ("Not able to import "
498- "mpl_toolkits_baseline_images" )
499- baseline_dir = (Path (mpl_toolkits_baseline_images .__file__ ).parent /
500- module_path .stem )
501- else :
502- baseline_dir = (module_path .parent /
503- "baseline_images" /
504- module_path .stem )
518+ baseline_dir = (module_path .parent /
519+ "baseline_images" /
520+ module_path .stem )
505521 result_dir = Path ().resolve () / "result_images" / module_path .stem
506522 result_dir .mkdir (parents = True , exist_ok = True )
507523 return baseline_dir , result_dir
0 commit comments