|
5 | 5 | import os |
6 | 6 | from pathlib import Path |
7 | 7 | import shutil |
| 8 | +import string |
8 | 9 | import sys |
9 | 10 | import unittest |
10 | 11 | import warnings |
|
17 | 18 | from matplotlib import ft2font |
18 | 19 | from matplotlib import pyplot as plt |
19 | 20 | from matplotlib import ticker |
20 | | -from . import is_called_from_pytest |
| 21 | + |
21 | 22 | from .compare import comparable_formats, compare_images, make_test_filename |
22 | 23 | from .exceptions import ImageComparisonFailure |
23 | 24 |
|
@@ -381,34 +382,50 @@ def test_plot(fig_test, fig_ref): |
381 | 382 | fig_test.subplots().plot([1, 3, 5]) |
382 | 383 | fig_ref.subplots().plot([0, 1, 2], [1, 3, 5]) |
383 | 384 | """ |
384 | | - POSITIONAL_OR_KEYWORD = inspect.Parameter.POSITIONAL_OR_KEYWORD |
| 385 | + ALLOWED_CHARS = set(string.digits + string.ascii_letters + '_-[]()') |
| 386 | + KEYWORD_ONLY = inspect.Parameter.KEYWORD_ONLY |
385 | 387 | def decorator(func): |
386 | 388 | import pytest |
387 | 389 |
|
388 | 390 | _, result_dir = _image_directories(func) |
| 391 | + old_sig = inspect.signature(func) |
389 | 392 |
|
390 | 393 | @pytest.mark.parametrize("ext", extensions) |
391 | | - def wrapper(*args, ext, **kwargs): |
392 | | - fig_test = plt.figure("test") |
393 | | - fig_ref = plt.figure("reference") |
394 | | - func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) |
395 | | - test_image_path = result_dir / (func.__name__ + "." + ext) |
396 | | - ref_image_path = result_dir / ( |
397 | | - func.__name__ + "-expected." + ext |
398 | | - ) |
399 | | - fig_test.savefig(test_image_path) |
400 | | - fig_ref.savefig(ref_image_path) |
401 | | - _raise_on_image_difference( |
402 | | - ref_image_path, test_image_path, tol=tol |
403 | | - ) |
404 | | - |
405 | | - sig = inspect.signature(func) |
406 | | - new_sig = sig.replace( |
407 | | - parameters=([param |
408 | | - for param in sig.parameters.values() |
409 | | - if param.name not in {"fig_test", "fig_ref"}] |
410 | | - + [inspect.Parameter("ext", POSITIONAL_OR_KEYWORD)]) |
411 | | - ) |
| 394 | + def wrapper(*args, **kwargs): |
| 395 | + ext = kwargs['ext'] |
| 396 | + if 'ext' not in old_sig.parameters: |
| 397 | + kwargs.pop('ext') |
| 398 | + request = kwargs['request'] |
| 399 | + if 'request' not in old_sig.parameters: |
| 400 | + kwargs.pop('request') |
| 401 | + |
| 402 | + file_name = "".join(c for c in request.node.name |
| 403 | + if c in ALLOWED_CHARS) |
| 404 | + try: |
| 405 | + fig_test = plt.figure("test") |
| 406 | + fig_ref = plt.figure("reference") |
| 407 | + func(*args, fig_test=fig_test, fig_ref=fig_ref, **kwargs) |
| 408 | + test_image_path = result_dir / (file_name + "." + ext) |
| 409 | + ref_image_path = result_dir / (file_name + "-expected." + ext) |
| 410 | + fig_test.savefig(test_image_path) |
| 411 | + fig_ref.savefig(ref_image_path) |
| 412 | + _raise_on_image_difference( |
| 413 | + ref_image_path, test_image_path, tol=tol |
| 414 | + ) |
| 415 | + finally: |
| 416 | + plt.close(fig_test) |
| 417 | + plt.close(fig_ref) |
| 418 | + |
| 419 | + parameters = [ |
| 420 | + param |
| 421 | + for param in old_sig.parameters.values() |
| 422 | + if param.name not in {"fig_test", "fig_ref"} |
| 423 | + ] |
| 424 | + if 'ext' not in old_sig.parameters: |
| 425 | + parameters += [inspect.Parameter("ext", KEYWORD_ONLY)] |
| 426 | + if 'request' not in old_sig.parameters: |
| 427 | + parameters += [inspect.Parameter("request", KEYWORD_ONLY)] |
| 428 | + new_sig = old_sig.replace(parameters=parameters) |
412 | 429 | wrapper.__signature__ = new_sig |
413 | 430 |
|
414 | 431 | # reach a bit into pytest internals to hoist the marks from |
|
0 commit comments