diff --git a/Lib/annotationlib.py b/Lib/annotationlib.py index 37f51e69f94127..cd24679f30abee 100644 --- a/Lib/annotationlib.py +++ b/Lib/annotationlib.py @@ -38,6 +38,7 @@ class Format(enum.IntEnum): "__weakref__", "__arg__", "__globals__", + "__extra_names__", "__code__", "__ast_node__", "__cell__", @@ -82,6 +83,7 @@ def __init__( # is created through __class__ assignment on a _Stringifier object. self.__globals__ = None self.__cell__ = None + self.__extra_names__ = None # These are initially None but serve as a cache and may be set to a non-None # value later. self.__code__ = None @@ -151,6 +153,8 @@ def evaluate(self, *, globals=None, locals=None, type_params=None, owner=None): if not self.__forward_is_class__ or param_name not in globals: globals[param_name] = param locals.pop(param_name, None) + if self.__extra_names__: + locals = {**locals, **self.__extra_names__} arg = self.__forward_arg__ if arg.isidentifier() and not keyword.iskeyword(arg): @@ -231,6 +235,10 @@ def __eq__(self, other): and self.__forward_is_class__ == other.__forward_is_class__ and self.__cell__ == other.__cell__ and self.__owner__ == other.__owner__ + and ( + (tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None) == + (tuple(sorted(other.__extra_names__.items())) if other.__extra_names__ else None) + ) ) def __hash__(self): @@ -241,6 +249,7 @@ def __hash__(self): self.__forward_is_class__, self.__cell__, self.__owner__, + tuple(sorted(self.__extra_names__.items())) if self.__extra_names__ else None, )) def __or__(self, other): @@ -274,6 +283,7 @@ def __init__( cell=None, *, stringifier_dict, + extra_names=None, ): # Either an AST node or a simple str (for the common case where a ForwardRef # represent a single name). @@ -285,6 +295,7 @@ def __init__( self.__code__ = None self.__ast_node__ = node self.__globals__ = globals + self.__extra_names__ = extra_names self.__cell__ = cell self.__owner__ = owner self.__stringifier_dict__ = stringifier_dict @@ -292,28 +303,63 @@ def __init__( def __convert_to_ast(self, other): if isinstance(other, _Stringifier): if isinstance(other.__ast_node__, str): - return ast.Name(id=other.__ast_node__) - return other.__ast_node__ - elif isinstance(other, slice): + return ast.Name(id=other.__ast_node__), other.__extra_names__ + return other.__ast_node__, other.__extra_names__ + elif ( + # In STRING format we don't bother with the create_unique_name() dance; + # it's better to emit the repr() of the object instead of an opaque name. + self.__stringifier_dict__.format == Format.STRING + or other is None + or type(other) in (str, int, float, bool, complex) + ): + return ast.Constant(value=other), None + elif type(other) is dict: + extra_names = {} + keys = [] + values = [] + for key, value in other.items(): + new_key, new_extra_names = self.__convert_to_ast(key) + if new_extra_names is not None: + extra_names.update(new_extra_names) + keys.append(new_key) + new_value, new_extra_names = self.__convert_to_ast(value) + if new_extra_names is not None: + extra_names.update(new_extra_names) + values.append(new_value) + return ast.Dict(keys, values), extra_names + elif type(other) in (list, tuple, set): + extra_names = {} + elts = [] + for elt in other: + new_elt, new_extra_names = self.__convert_to_ast(elt) + if new_extra_names is not None: + extra_names.update(new_extra_names) + elts.append(new_elt) + ast_class = {list: ast.List, tuple: ast.Tuple, set: ast.Set}[type(other)] + return ast_class(elts), extra_names + else: + name = self.__stringifier_dict__.create_unique_name() + return ast.Name(id=name), {name: other} + + def __convert_to_ast_getitem(self, other): + if isinstance(other, slice): + extra_names = {} + + def conv(obj): + if obj is None: + return None + new_obj, new_extra_names = self.__convert_to_ast(obj) + if new_extra_names is not None: + extra_names.update(new_extra_names) + return new_obj + return ast.Slice( - lower=( - self.__convert_to_ast(other.start) - if other.start is not None - else None - ), - upper=( - self.__convert_to_ast(other.stop) - if other.stop is not None - else None - ), - step=( - self.__convert_to_ast(other.step) - if other.step is not None - else None - ), - ) + lower=conv(other.start), + upper=conv(other.stop), + step=conv(other.step), + ), extra_names else: - return ast.Constant(value=other) + return self.__convert_to_ast(other) def __get_ast(self): node = self.__ast_node__ @@ -321,13 +367,19 @@ def __get_ast(self): return ast.Name(id=node) return node - def __make_new(self, node): + def __make_new(self, node, extra_names=None): + new_extra_names = {} + if self.__extra_names__ is not None: + new_extra_names.update(self.__extra_names__) + if extra_names is not None: + new_extra_names.update(extra_names) stringifier = _Stringifier( node, self.__globals__, self.__owner__, self.__forward_is_class__, stringifier_dict=self.__stringifier_dict__, + extra_names=new_extra_names or None, ) self.__stringifier_dict__.stringifiers.append(stringifier) return stringifier @@ -343,27 +395,37 @@ def __getitem__(self, other): if self.__ast_node__ == "__classdict__": raise KeyError if isinstance(other, tuple): - elts = [self.__convert_to_ast(elt) for elt in other] + extra_names = {} + elts = [] + for elt in other: + new_elt, new_extra_names = self.__convert_to_ast_getitem(elt) + if new_extra_names is not None: + extra_names.update(new_extra_names) + elts.append(new_elt) other = ast.Tuple(elts) else: - other = self.__convert_to_ast(other) + other, extra_names = self.__convert_to_ast_getitem(other) assert isinstance(other, ast.AST), repr(other) - return self.__make_new(ast.Subscript(self.__get_ast(), other)) + return self.__make_new(ast.Subscript(self.__get_ast(), other), extra_names) def __getattr__(self, attr): return self.__make_new(ast.Attribute(self.__get_ast(), attr)) def __call__(self, *args, **kwargs): - return self.__make_new( - ast.Call( - self.__get_ast(), - [self.__convert_to_ast(arg) for arg in args], - [ - ast.keyword(key, self.__convert_to_ast(value)) - for key, value in kwargs.items() - ], - ) - ) + extra_names = {} + ast_args = [] + for arg in args: + new_arg, new_extra_names = self.__convert_to_ast(arg) + if new_extra_names is not None: + extra_names.update(new_extra_names) + ast_args.append(new_arg) + ast_kwargs = [] + for key, value in kwargs.items(): + new_value, new_extra_names = self.__convert_to_ast(value) + if new_extra_names is not None: + extra_names.update(new_extra_names) + ast_kwargs.append(ast.keyword(key, new_value)) + return self.__make_new(ast.Call(self.__get_ast(), ast_args, ast_kwargs), extra_names) def __iter__(self): yield self.__make_new(ast.Starred(self.__get_ast())) @@ -378,8 +440,9 @@ def __format__(self, format_spec): def _make_binop(op: ast.AST): def binop(self, other): + rhs, extra_names = self.__convert_to_ast(other) return self.__make_new( - ast.BinOp(self.__get_ast(), op, self.__convert_to_ast(other)) + ast.BinOp(self.__get_ast(), op, rhs), extra_names ) return binop @@ -402,8 +465,9 @@ def binop(self, other): def _make_rbinop(op: ast.AST): def rbinop(self, other): + new_other, extra_names = self.__convert_to_ast(other) return self.__make_new( - ast.BinOp(self.__convert_to_ast(other), op, self.__get_ast()) + ast.BinOp(new_other, op, self.__get_ast()), extra_names ) return rbinop @@ -426,12 +490,14 @@ def rbinop(self, other): def _make_compare(op): def compare(self, other): + rhs, extra_names = self.__convert_to_ast(other) return self.__make_new( ast.Compare( left=self.__get_ast(), ops=[op], - comparators=[self.__convert_to_ast(other)], - ) + comparators=[rhs], + ), + extra_names, ) return compare @@ -459,13 +525,15 @@ def unary_op(self): class _StringifierDict(dict): - def __init__(self, namespace, globals=None, owner=None, is_class=False): + def __init__(self, namespace, *, globals=None, owner=None, is_class=False, format): super().__init__(namespace) self.namespace = namespace self.globals = globals self.owner = owner self.is_class = is_class self.stringifiers = [] + self.next_id = 1 + self.format = format def __missing__(self, key): fwdref = _Stringifier( @@ -478,6 +546,11 @@ def __missing__(self, key): self.stringifiers.append(fwdref) return fwdref + def create_unique_name(self): + name = f"__annotationlib_name_{self.next_id}__" + self.next_id += 1 + return name + def call_evaluate_function(evaluate, format, *, owner=None): """Call an evaluate function. Evaluate functions are normally generated for @@ -521,7 +594,7 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # possibly constants if the annotate function uses them directly). We then # convert each of those into a string to get an approximation of the # original source. - globals = _StringifierDict({}) + globals = _StringifierDict({}, format=format) if annotate.__closure__: freevars = annotate.__code__.co_freevars new_closure = [] @@ -544,9 +617,9 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): ) annos = func(Format.VALUE_WITH_FAKE_GLOBALS) if _is_evaluate: - return annos if isinstance(annos, str) else repr(annos) + return _stringify_single(annos) return { - key: val if isinstance(val, str) else repr(val) + key: _stringify_single(val) for key, val in annos.items() } elif format == Format.FORWARDREF: @@ -569,7 +642,13 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): # that returns a bool and an defined set of attributes. namespace = {**annotate.__builtins__, **annotate.__globals__} is_class = isinstance(owner, type) - globals = _StringifierDict(namespace, annotate.__globals__, owner, is_class) + globals = _StringifierDict( + namespace, + globals=annotate.__globals__, + owner=owner, + is_class=is_class, + format=format, + ) if annotate.__closure__: freevars = annotate.__code__.co_freevars new_closure = [] @@ -619,6 +698,16 @@ def call_annotate_function(annotate, format, *, owner=None, _is_evaluate=False): raise ValueError(f"Invalid format: {format!r}") +def _stringify_single(anno): + if anno is ...: + return "..." + # We have to handle str specially to support PEP 563 stringified annotations. + elif isinstance(anno, str): + return anno + else: + return repr(anno) + + def get_annotate_from_class_namespace(obj): """Retrieve the annotate function from a class namespace dictionary. diff --git a/Lib/test/test_annotationlib.py b/Lib/test/test_annotationlib.py index 404a8ccc9d3741..d9000b6392277e 100644 --- a/Lib/test/test_annotationlib.py +++ b/Lib/test/test_annotationlib.py @@ -121,6 +121,28 @@ def f( self.assertIsInstance(gamma_anno, ForwardRef) self.assertEqual(gamma_anno, support.EqualToForwardRef("some < obj", owner=f)) + def test_partially_nonexistent_union(self): + # Test unions with '|' syntax equal unions with typing.Union[] with some forwardrefs + class UnionForwardrefs: + pipe: str | undefined + union: Union[str, undefined] + + annos = get_annotations(UnionForwardrefs, format=Format.FORWARDREF) + + pipe = annos["pipe"] + self.assertIsInstance(pipe, ForwardRef) + self.assertEqual( + pipe.evaluate(globals={"undefined": int}), + str | int, + ) + union = annos["union"] + self.assertIsInstance(union, Union) + arg1, arg2 = typing.get_args(union) + self.assertIs(arg1, str) + self.assertEqual( + arg2, support.EqualToForwardRef("undefined", is_class=True, owner=UnionForwardrefs) + ) + class TestStringFormat(unittest.TestCase): def test_closure(self): @@ -251,6 +273,89 @@ def f( }, ) + def test_getitem(self): + def f(x: undef1[str, undef2]): + pass + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual(anno, {"x": "undef1[str, undef2]"}) + + anno = annotationlib.get_annotations(f, format=Format.FORWARDREF) + fwdref = anno["x"] + self.assertIsInstance(fwdref, ForwardRef) + self.assertEqual( + fwdref.evaluate(globals={"undef1": dict, "undef2": float}), dict[str, float] + ) + + def test_slice(self): + def f(x: a[b:c]): + pass + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual(anno, {"x": "a[b:c]"}) + + def f(x: a[b:c, d:e]): + pass + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual(anno, {"x": "a[b:c, d:e]"}) + + obj = slice(1, 1, 1) + def f(x: obj): + pass + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual(anno, {"x": "obj"}) + + def test_literals(self): + def f( + a: 1, + b: 1.0, + c: "hello", + d: b"hello", + e: True, + f: None, + g: ..., + h: 1j, + ): + pass + + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual( + anno, + { + "a": "1", + "b": "1.0", + "c": 'hello', + "d": "b'hello'", + "e": "True", + "f": "None", + "g": "...", + "h": "1j", + }, + ) + + def test_displays(self): + # Simple case first + def f(x: a[[int, str], float]): + pass + anno = annotationlib.get_annotations(f, format=Format.STRING) + self.assertEqual(anno, {"x": "a[[int, str], float]"}) + + def g( + w: a[[int, str], float], + x: a[{int, str}, 3], + y: a[{int: str}, 4], + z: a[(int, str), 5], + ): + pass + anno = annotationlib.get_annotations(g, format=Format.STRING) + self.assertEqual( + anno, + { + "w": "a[[int, str], float]", + "x": "a[{int, str}, 3]", + "y": "a[{int: str}, 4]", + "z": "a[(int, str), 5]", + }, + ) + def test_nested_expressions(self): def f( nested: list[Annotated[set[int], "set of ints", 4j]], @@ -296,6 +401,17 @@ def f(fstring_format: f"{a:02d}"): with self.assertRaisesRegex(TypeError, format_msg): get_annotations(f, format=Format.STRING) + def test_shenanigans(self): + # In cases like this we can't reconstruct the source; test that we do something + # halfway reasonable. + def f(x: x | (1).__class__, y: (1).__class__): + pass + + self.assertEqual( + get_annotations(f, format=Format.STRING), + {"x": "x | ", "y": ""}, + ) + class TestGetAnnotations(unittest.TestCase): def test_builtin_type(self): diff --git a/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst b/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst new file mode 100644 index 00000000000000..d62b95775a67c2 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2025-04-22-13-42-12.gh-issue-132805.r-dhmJ.rst @@ -0,0 +1,2 @@ +Fix incorrect handling of nested non-constant values in the FORWARDREF +format in :mod:`annotationlib`.