|
3 | 3 | import ast |
4 | 4 | import collections |
5 | 5 | import dataclasses |
| 6 | +import re |
6 | 7 | import secrets |
7 | 8 | import sys |
8 | 9 | from functools import lru_cache |
|
14 | 15 | else: |
15 | 16 | from typing_extensions import TypeGuard |
16 | 17 |
|
| 18 | +from black.mode import Mode |
17 | 19 | from black.output import out |
18 | 20 | from black.report import NothingChanged |
19 | 21 |
|
@@ -64,6 +66,34 @@ def jupyter_dependencies_are_installed(*, warn: bool) -> bool: |
64 | 66 | return installed |
65 | 67 |
|
66 | 68 |
|
| 69 | +def validate_cell(src: str, mode: Mode) -> None: |
| 70 | + """Check that cell does not already contain TransformerManager transformations, |
| 71 | + or non-Python cell magics, which might cause tokenize_rt to break because of |
| 72 | + indentation. |
| 73 | +
|
| 74 | + If a cell contains ``!ls``, then it'll be transformed to |
| 75 | + ``get_ipython().system('ls')``. However, if the cell originally contained |
| 76 | + ``get_ipython().system('ls')``, then it would get transformed in the same way: |
| 77 | +
|
| 78 | + >>> TransformerManager().transform_cell("get_ipython().system('ls')") |
| 79 | + "get_ipython().system('ls')\n" |
| 80 | + >>> TransformerManager().transform_cell("!ls") |
| 81 | + "get_ipython().system('ls')\n" |
| 82 | +
|
| 83 | + Due to the impossibility of safely roundtripping in such situations, cells |
| 84 | + containing transformed magics will be ignored. |
| 85 | + """ |
| 86 | + if any(transformed_magic in src for transformed_magic in TRANSFORMED_MAGICS): |
| 87 | + raise NothingChanged |
| 88 | + |
| 89 | + line = _get_code_start(src) |
| 90 | + if line.startswith("%%") and ( |
| 91 | + line.split(maxsplit=1)[0][2:] |
| 92 | + not in PYTHON_CELL_MAGICS | mode.python_cell_magics |
| 93 | + ): |
| 94 | + raise NothingChanged |
| 95 | + |
| 96 | + |
67 | 97 | def remove_trailing_semicolon(src: str) -> tuple[str, bool]: |
68 | 98 | """Remove trailing semicolon from Jupyter notebook cell. |
69 | 99 |
|
@@ -276,6 +306,21 @@ def unmask_cell(src: str, replacements: list[Replacement]) -> str: |
276 | 306 | return src |
277 | 307 |
|
278 | 308 |
|
| 309 | +def _get_code_start(src: str) -> str: |
| 310 | + """Return the first line where the code starts. |
| 311 | +
|
| 312 | + Iterates over the lines of the cell until it finds the first line that is |
| 313 | + neither whitespace-only nor a comment. Leading whitespace is stripped from |
| 314 | + that line before it is returned. If no such line exists, an empty string is |
| 315 | + returned. |
| 316 | + """ |
| 317 | + for match in re.finditer(".+", src): |
| 318 | + line = match.group(0).lstrip() |
| 319 | + if line and not line.startswith("#"): |
| 320 | + return line |
| 321 | + return "" |
| 322 | + |
| 323 | + |
279 | 324 | def _is_ipython_magic(node: ast.expr) -> TypeGuard[ast.Attribute]: |
280 | 325 | """Check if attribute is IPython magic. |
281 | 326 |
|
|
0 commit comments