Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 66d7a5d

Browse files
miss-islington1st1
authored andcommitted
bpo-34776: Fix dataclasses to support __future__ "annotations" mode (GH-9518) (#17532)
(cherry picked from commit d219cc4) Co-authored-by: Yury Selivanov <[email protected]>
1 parent a0078d9 commit 66d7a5d

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

Lib/dataclasses.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -368,23 +368,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
368368
# worries about external callers.
369369
if locals is None:
370370
locals = {}
371-
# __builtins__ may be the "builtins" module or
372-
# the value of its "__dict__",
373-
# so make sure "__builtins__" is the module.
374-
if globals is not None and '__builtins__' not in globals:
375-
globals['__builtins__'] = builtins
371+
if 'BUILTINS' not in locals:
372+
locals['BUILTINS'] = builtins
376373
return_annotation = ''
377374
if return_type is not MISSING:
378375
locals['_return_type'] = return_type
379376
return_annotation = '->_return_type'
380377
args = ','.join(args)
381-
body = '\n'.join(f' {b}' for b in body)
378+
body = '\n'.join(f' {b}' for b in body)
382379

383380
# Compute the text of the entire function.
384-
txt = f'def {name}({args}){return_annotation}:\n{body}'
381+
txt = f' def {name}({args}){return_annotation}:\n{body}'
385382

386-
exec(txt, globals, locals)
387-
return locals[name]
383+
local_vars = ', '.join(locals.keys())
384+
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
385+
386+
ns = {}
387+
exec(txt, globals, ns)
388+
return ns['__create_fn__'](**locals)
388389

389390

390391
def _field_assign(frozen, name, value, self_name):
@@ -395,7 +396,7 @@ def _field_assign(frozen, name, value, self_name):
395396
# self_name is what "self" is called in this function: don't
396397
# hard-code "self", since that might be a field name.
397398
if frozen:
398-
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
399+
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
399400
return f'{self_name}.{name}={value}'
400401

401402

@@ -472,7 +473,7 @@ def _init_param(f):
472473
return f'{f.name}:_type_{f.name}{default}'
473474

474475

475-
def _init_fn(fields, frozen, has_post_init, self_name):
476+
def _init_fn(fields, frozen, has_post_init, self_name, globals):
476477
# fields contains both real fields and InitVar pseudo-fields.
477478

478479
# Make sure we don't have fields without defaults following fields
@@ -490,12 +491,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
490491
raise TypeError(f'non-default argument {f.name!r} '
491492
'follows default argument')
492493

493-
globals = {'MISSING': MISSING,
494-
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
494+
locals = {f'_type_{f.name}': f.type for f in fields}
495+
locals.update({
496+
'MISSING': MISSING,
497+
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
498+
})
495499

496500
body_lines = []
497501
for f in fields:
498-
line = _field_init(f, frozen, globals, self_name)
502+
line = _field_init(f, frozen, locals, self_name)
499503
# line is None means that this field doesn't require
500504
# initialization (it's a pseudo-field). Just skip it.
501505
if line:
@@ -511,7 +515,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
511515
if not body_lines:
512516
body_lines = ['pass']
513517

514-
locals = {f'_type_{f.name}': f.type for f in fields}
515518
return _create_fn('__init__',
516519
[self_name] + [_init_param(f) for f in fields if f.init],
517520
body_lines,
@@ -520,20 +523,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
520523
return_type=None)
521524

522525

523-
def _repr_fn(fields):
526+
def _repr_fn(fields, globals):
524527
fn = _create_fn('__repr__',
525528
('self',),
526529
['return self.__class__.__qualname__ + f"(' +
527530
', '.join([f"{f.name}={{self.{f.name}!r}}"
528531
for f in fields]) +
529-
')"'])
532+
')"'],
533+
globals=globals)
530534
return _recursive_repr(fn)
531535

532536

533-
def _frozen_get_del_attr(cls, fields):
534-
# XXX: globals is modified on the first call to _create_fn, then
535-
# the modified version is used in the second call. Is this okay?
536-
globals = {'cls': cls,
537+
def _frozen_get_del_attr(cls, fields, globals):
538+
locals = {'cls': cls,
537539
'FrozenInstanceError': FrozenInstanceError}
538540
if fields:
539541
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
@@ -545,17 +547,19 @@ def _frozen_get_del_attr(cls, fields):
545547
(f'if type(self) is cls or name in {fields_str}:',
546548
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
547549
f'super(cls, self).__setattr__(name, value)'),
550+
locals=locals,
548551
globals=globals),
549552
_create_fn('__delattr__',
550553
('self', 'name'),
551554
(f'if type(self) is cls or name in {fields_str}:',
552555
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
553556
f'super(cls, self).__delattr__(name)'),
557+
locals=locals,
554558
globals=globals),
555559
)
556560

557561

558-
def _cmp_fn(name, op, self_tuple, other_tuple):
562+
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
559563
# Create a comparison function. If the fields in the object are
560564
# named 'x' and 'y', then self_tuple is the string
561565
# '(self.x,self.y)' and other_tuple is the string
@@ -565,14 +569,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
565569
('self', 'other'),
566570
[ 'if other.__class__ is self.__class__:',
567571
f' return {self_tuple}{op}{other_tuple}',
568-
'return NotImplemented'])
572+
'return NotImplemented'],
573+
globals=globals)
569574

570575

571-
def _hash_fn(fields):
576+
def _hash_fn(fields, globals):
572577
self_tuple = _tuple_str('self', fields)
573578
return _create_fn('__hash__',
574579
('self',),
575-
[f'return hash({self_tuple})'])
580+
[f'return hash({self_tuple})'],
581+
globals=globals)
576582

577583

578584
def _is_classvar(a_type, typing):
@@ -744,14 +750,14 @@ def _set_new_attribute(cls, name, value):
744750
# take. The common case is to do nothing, so instead of providing a
745751
# function that is a no-op, use None to signify that.
746752

747-
def _hash_set_none(cls, fields):
753+
def _hash_set_none(cls, fields, globals):
748754
return None
749755

750-
def _hash_add(cls, fields):
756+
def _hash_add(cls, fields, globals):
751757
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
752-
return _hash_fn(flds)
758+
return _hash_fn(flds, globals)
753759

754-
def _hash_exception(cls, fields):
760+
def _hash_exception(cls, fields, globals):
755761
# Raise an exception.
756762
raise TypeError(f'Cannot overwrite attribute __hash__ '
757763
f'in class {cls.__name__}')
@@ -793,6 +799,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
793799
# is defined by the base class, which is found first.
794800
fields = {}
795801

802+
if cls.__module__ in sys.modules:
803+
globals = sys.modules[cls.__module__].__dict__
804+
else:
805+
# Theoretically this can happen if someone writes
806+
# a custom string to cls.__module__. In which case
807+
# such dataclass won't be fully introspectable
808+
# (w.r.t. typing.get_type_hints) but will still function
809+
# correctly.
810+
globals = {}
811+
796812
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
797813
unsafe_hash, frozen))
798814

@@ -902,6 +918,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
902918
# if possible.
903919
'__dataclass_self__' if 'self' in fields
904920
else 'self',
921+
globals,
905922
))
906923

907924
# Get the fields as a list, and include only real fields. This is
@@ -910,7 +927,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
910927

911928
if repr:
912929
flds = [f for f in field_list if f.repr]
913-
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
930+
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
914931

915932
if eq:
916933
# Create _eq__ method. There's no need for a __ne__ method,
@@ -920,7 +937,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
920937
other_tuple = _tuple_str('other', flds)
921938
_set_new_attribute(cls, '__eq__',
922939
_cmp_fn('__eq__', '==',
923-
self_tuple, other_tuple))
940+
self_tuple, other_tuple,
941+
globals=globals))
924942

925943
if order:
926944
# Create and set the ordering methods.
@@ -933,13 +951,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
933951
('__ge__', '>='),
934952
]:
935953
if _set_new_attribute(cls, name,
936-
_cmp_fn(name, op, self_tuple, other_tuple)):
954+
_cmp_fn(name, op, self_tuple, other_tuple,
955+
globals=globals)):
937956
raise TypeError(f'Cannot overwrite attribute {name} '
938957
f'in class {cls.__name__}. Consider using '
939958
'functools.total_ordering')
940959

941960
if frozen:
942-
for fn in _frozen_get_del_attr(cls, field_list):
961+
for fn in _frozen_get_del_attr(cls, field_list, globals):
943962
if _set_new_attribute(cls, fn.__name__, fn):
944963
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
945964
f'in class {cls.__name__}')
@@ -952,7 +971,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
952971
if hash_action:
953972
# No need to call _set_new_attribute here, since by the time
954973
# we're here the overwriting is unconditional.
955-
cls.__hash__ = hash_action(cls, field_list)
974+
cls.__hash__ = hash_action(cls, field_list, globals)
956975

957976
if not getattr(cls, '__doc__'):
958977
# Create a class doc-string.

Lib/test/dataclass_textanno.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
5+
6+
class Foo:
7+
pass
8+
9+
10+
@dataclasses.dataclass
11+
class Bar:
12+
foo: Foo

Lib/test/test_dataclasses.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import unittest
1111
from unittest.mock import Mock
1212
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
13+
from typing import get_type_hints
1314
from collections import deque, OrderedDict, namedtuple
1415
from functools import total_ordering
1516

@@ -2918,6 +2919,17 @@ def test_classvar_module_level_import(self):
29182919
# won't exist on the instance.
29192920
self.assertNotIn('not_iv4', c.__dict__)
29202921

2922+
def test_text_annotations(self):
2923+
from test import dataclass_textanno
2924+
2925+
self.assertEqual(
2926+
get_type_hints(dataclass_textanno.Bar),
2927+
{'foo': dataclass_textanno.Foo})
2928+
self.assertEqual(
2929+
get_type_hints(dataclass_textanno.Bar.__init__),
2930+
{'foo': dataclass_textanno.Foo,
2931+
'return': type(None)})
2932+
29212933

29222934
class TestMakeDataclass(unittest.TestCase):
29232935
def test_simple(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix dataclasses to support forward references in type annotations

0 commit comments

Comments
 (0)