@@ -205,7 +205,7 @@ def copy_baseline(self, baseline, extension):
205205 f"{ orig_expected_path } " ) from err
206206 return expected_fname
207207
208- def compare (self , idx , baseline , extension , * , _lock = False ):
208+ def compare (self , idx , baseline , extension , * , _lock = False , generating = False ):
209209 __tracebackhide__ = True
210210 fignum = plt .get_fignums ()[idx ]
211211 fig = plt .figure (fignum )
@@ -222,9 +222,12 @@ def compare(self, idx, baseline, extension, *, _lock=False):
222222
223223 lock = cbook ._lock_path (actual_path ) if _lock else nullcontext ()
224224 with lock :
225- fig .savefig (actual_path , ** kwargs )
226- expected_path = self .copy_baseline (baseline , extension )
227- _raise_on_image_difference (expected_path , actual_path , self .tol )
225+ if (generating ):
226+ fig .savefig (actual_path , ** kwargs )
227+ else :
228+ fig .savefig (actual_path , ** kwargs )
229+ expected_path = self .copy_baseline (baseline , extension )
230+ _raise_on_image_difference (expected_path , actual_path , self .tol )
228231
229232
230233def _pytest_image_comparison (baseline_images , extensions , tol ,
@@ -255,6 +258,7 @@ def wrapper(*args, extension, request, **kwargs):
255258 kwargs ['extension' ] = extension
256259 if 'request' in old_sig .parameters :
257260 kwargs ['request' ] = request
261+ matplotlib_baseline_image_generation = request .config .getoption ("--matplotlib_baseline_image_generation" )
258262
259263 img = _ImageComparisonBase (func , tol = tol , remove_text = remove_text ,
260264 savefig_kwargs = savefig_kwargs )
@@ -280,7 +284,7 @@ def wrapper(*args, extension, request, **kwargs):
280284 "Test generated {} images but there are {} baseline images"
281285 .format (len (plt .get_fignums ()), len (our_baseline_images )))
282286 for idx , baseline in enumerate (our_baseline_images ):
283- img .compare (idx , baseline , extension , _lock = needs_lock )
287+ img .compare (idx , baseline , extension , _lock = needs_lock , generating = matplotlib_baseline_image_generation )
284288
285289 parameters = list (old_sig .parameters .values ())
286290 if 'extension' not in old_sig .parameters :
0 commit comments