diff --git a/pytest_examples/eval_example.py b/pytest_examples/eval_example.py index 2dc5b93..ddedbb4 100644 --- a/pytest_examples/eval_example.py +++ b/pytest_examples/eval_example.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable import pytest from _pytest.assertion.rewrite import AssertionRewritingHook @@ -30,6 +30,7 @@ def __init__(self, *, tmp_path: Path, pytest_request: pytest.FixtureRequest): self._test_id = pytest_request.node.nodeid self.to_update: list[CodeExample] = [] self.config: ExamplesConfig = ExamplesConfig() + self.print_callback: Callable[[str], str] | None = None def set_config( self, @@ -159,7 +160,9 @@ def _run( enable_print_mock = False python_file = self._write_file(example) - return run_code(example, python_file, loader, self.config, enable_print_mock, module_globals) + return run_code( + example, python_file, loader, self.config, enable_print_mock, self.print_callback, module_globals + ) def lint(self, example: CodeExample) -> None: """ diff --git a/pytest_examples/run_code.py b/pytest_examples/run_code.py index 8f16f3c..4706c15 100644 --- a/pytest_examples/run_code.py +++ b/pytest_examples/run_code.py @@ -10,7 +10,7 @@ from importlib.abc import Loader from pathlib import Path from textwrap import indent -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Callable from unittest.mock import patch import pytest @@ -34,6 +34,7 @@ def run_code( loader: Loader | None, config: ExamplesConfig, enable_print_mock: bool, + print_callback: Callable[[str], str] | None, module_globals: dict[str, Any] | None, ) -> tuple[InsertPrintStatements, dict[str, Any]]: __tracebackhide__ = True @@ -42,7 +43,7 @@ def run_code( module = importlib.util.module_from_spec(spec) # does nothing if insert_print_statements is False - insert_print = InsertPrintStatements(python_file, config, enable_print_mock) + insert_print = InsertPrintStatements(python_file, config, enable_print_mock, print_callback) if module_globals: module.__dict__.update(module_globals) @@ -123,10 +124,13 @@ def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None: class InsertPrintStatements: - def __init__(self, python_path: Path, config: ExamplesConfig, enable: bool): + def __init__( + self, python_path: Path, config: ExamplesConfig, enable: bool, print_callback: Callable[[str], str] | None + ): self.file = python_path self.config = config self.print_func = MockPrintFunction(python_path) if enable else None + self.print_callback = print_callback self.patch = None def __enter__(self) -> None: @@ -176,6 +180,8 @@ def _insert_print_args( self, lines: list[str], statement: PrintStatement, in_python: bool, line_index: int, col: int ) -> None: single_line = statement.sep.join(map(str, statement.args)) + if self.print_callback: + single_line = self.print_callback(single_line) indent_str = ' ' * col max_single_length = self.config.line_length - len(indent_str) if '\n' not in single_line and len(single_line) + len(comment_prefix) < max_single_length: @@ -185,6 +191,8 @@ def _insert_print_args( sep = f'{statement.sep}\n' indent_config = dataclasses.replace(self.config, line_length=max_single_length) output = sep.join(arg.format(indent_config).strip('\n') for arg in statement.args) + if self.print_callback: + output = self.print_callback(output) # remove trailing whitespace output = re.sub(r' +$', '', output, flags=re.MULTILINE) # have to use triple single quotes in python since we're already in a double quotes docstring diff --git a/tests/test_run_examples.py b/tests/test_run_examples.py index 28ebc84..6943329 100644 --- a/tests/test_run_examples.py +++ b/tests/test_run_examples.py @@ -248,3 +248,52 @@ def div(y): assert exc_info.traceback[-2].frame.code.path == md_file assert exc_info.traceback[-2].lineno == 9 + + +def test_print_sub(pytester: pytest.Pytester): + pytester.makefile( + '.md', + # language=Markdown + my_file=''' +# My file + +```py +print('hello') +#> hello +print('1/2/3') +#> X/X/X +print({f'{i} key': i for i in range(8)}) +""" +{ + 'X key': X, + 'X key': X, + 'X key': X, + 'X key': X, + 'X key': X, + 'X key': X, + 'X key': X, + 'X key': X, +} +""" +``` + ''', + ) + # language=Python + pytester.makepyfile( + r""" +import re +from pytest_examples import find_examples, CodeExample, EvalExample +import pytest + +def print_sub(print_statement): + return re.sub(r'[0-9]+', 'X', print_statement) + +@pytest.mark.parametrize('example', find_examples('.'), ids=str) +def test_find_run_examples(example: CodeExample, eval_example: EvalExample): + eval_example.print_callback = print_sub + eval_example.run_print_check(example, rewrite_assertions=False) +""" + ) + + result = pytester.runpytest('-p', 'no:pretty', '-v') + result.assert_outcomes(passed=1)