Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 95f822d

Browse files
committed
Consider docstring and PEP 263 in lib_updater.py
1 parent 7a37446 commit 95f822d

File tree

1 file changed

+36
-4
lines changed

1 file changed

+36
-4
lines changed

scripts/lib_updater.py

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -280,13 +280,45 @@ def has_unittest_import(tree: ast.Module) -> bool:
280280
return False
281281

282282

283-
def find_import_insert_line(tree: ast.Module) -> int:
284-
"""Find the line number after the last import statement."""
283+
def find_import_insert_line(tree: ast.Module, lines: list[str] | None = None) -> int:
284+
"""Find the line number after the last import statement.
285+
286+
If no imports exist, returns the line after the module docstring
287+
(if present). If there's no docstring either, attempts to skip
288+
shebang and encoding lines to comply with PEP 263.
289+
"""
285290
last_import_line = 0
286291
for node in tree.body:
287292
if isinstance(node, (ast.Import, ast.ImportFrom)):
288293
last_import_line = node.end_lineno or node.lineno
289-
return last_import_line
294+
295+
# If we found imports, return after the last one
296+
if last_import_line > 0:
297+
return last_import_line
298+
299+
# No imports found - check for module docstring
300+
if tree.body:
301+
first_node = tree.body[0]
302+
if (
303+
isinstance(first_node, ast.Expr)
304+
and isinstance(first_node.value, ast.Constant)
305+
and isinstance(first_node.value.value, str)
306+
):
307+
return first_node.end_lineno or first_node.lineno
308+
309+
# No imports and no docstring - try to skip shebang/encoding if lines provided
310+
# PEP 263: encoding declaration must be in first or second line
311+
# and match: ^[ \t\f]*#.*?coding[:=][ \t]*([-_.a-zA-Z0-9]+)
312+
if lines:
313+
insert_line = 0
314+
for i, line in enumerate(lines[:2]): # Only check first two lines per PEP 263
315+
if line.startswith("#!") or re.match(
316+
r"^[ \t\f]*#.*?coding[:=][ \t]*[-_.a-zA-Z0-9]+", line
317+
):
318+
insert_line = i + 1
319+
return insert_line
320+
321+
return 0
290322

291323

292324
def apply_patches(contents: str, patches: Patches) -> str:
@@ -297,7 +329,7 @@ def apply_patches(contents: str, patches: Patches) -> str:
297329

298330
# If we have modifications and unittest is not imported, add it
299331
if modifications and not has_unittest_import(tree):
300-
import_line = find_import_insert_line(tree)
332+
import_line = find_import_insert_line(tree, lines)
301333
modifications.append(
302334
(
303335
import_line,

0 commit comments

Comments
 (0)