10
10
from importlib .abc import Loader
11
11
from pathlib import Path
12
12
from textwrap import indent
13
- from typing import TYPE_CHECKING , Any
13
+ from typing import TYPE_CHECKING , Any , Callable
14
14
from unittest .mock import patch
15
15
16
16
import pytest
@@ -34,6 +34,7 @@ def run_code(
34
34
loader : Loader | None ,
35
35
config : ExamplesConfig ,
36
36
enable_print_mock : bool ,
37
+ print_callback : Callable [[str ], str ] | None ,
37
38
module_globals : dict [str , Any ] | None ,
38
39
) -> tuple [InsertPrintStatements , dict [str , Any ]]:
39
40
__tracebackhide__ = True
@@ -42,7 +43,7 @@ def run_code(
42
43
module = importlib .util .module_from_spec (spec )
43
44
44
45
# does nothing if insert_print_statements is False
45
- insert_print = InsertPrintStatements (python_file , config , enable_print_mock )
46
+ insert_print = InsertPrintStatements (python_file , config , enable_print_mock , print_callback )
46
47
47
48
if module_globals :
48
49
module .__dict__ .update (module_globals )
@@ -123,10 +124,13 @@ def __call__(self, *args: Any, sep: str = ' ', **kwargs: Any) -> None:
123
124
124
125
125
126
class InsertPrintStatements :
126
- def __init__ (self , python_path : Path , config : ExamplesConfig , enable : bool ):
127
+ def __init__ (
128
+ self , python_path : Path , config : ExamplesConfig , enable : bool , print_callback : Callable [[str ], str ] | None
129
+ ):
127
130
self .file = python_path
128
131
self .config = config
129
132
self .print_func = MockPrintFunction (python_path ) if enable else None
133
+ self .print_callback = print_callback
130
134
self .patch = None
131
135
132
136
def __enter__ (self ) -> None :
@@ -176,6 +180,8 @@ def _insert_print_args(
176
180
self , lines : list [str ], statement : PrintStatement , in_python : bool , line_index : int , col : int
177
181
) -> None :
178
182
single_line = statement .sep .join (map (str , statement .args ))
183
+ if self .print_callback :
184
+ single_line = self .print_callback (single_line )
179
185
indent_str = ' ' * col
180
186
max_single_length = self .config .line_length - len (indent_str )
181
187
if '\n ' not in single_line and len (single_line ) + len (comment_prefix ) < max_single_length :
@@ -185,6 +191,8 @@ def _insert_print_args(
185
191
sep = f'{ statement .sep } \n '
186
192
indent_config = dataclasses .replace (self .config , line_length = max_single_length )
187
193
output = sep .join (arg .format (indent_config ).strip ('\n ' ) for arg in statement .args )
194
+ if self .print_callback :
195
+ output = self .print_callback (output )
188
196
# remove trailing whitespace
189
197
output = re .sub (r' +$' , '' , output , flags = re .MULTILINE )
190
198
# have to use triple single quotes in python since we're already in a double quotes docstring
0 commit comments