Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit d219cc4

Browse files
1st1ambv
authored andcommitted
bpo-34776: Fix dataclasses to support __future__ "annotations" mode (#9518)
1 parent bba873e commit d219cc4

File tree

4 files changed

+78
-34
lines changed

4 files changed

+78
-34
lines changed

Lib/dataclasses.py

Lines changed: 53 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -378,23 +378,24 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
378378
# worries about external callers.
379379
if locals is None:
380380
locals = {}
381-
# __builtins__ may be the "builtins" module or
382-
# the value of its "__dict__",
383-
# so make sure "__builtins__" is the module.
384-
if globals is not None and '__builtins__' not in globals:
385-
globals['__builtins__'] = builtins
381+
if 'BUILTINS' not in locals:
382+
locals['BUILTINS'] = builtins
386383
return_annotation = ''
387384
if return_type is not MISSING:
388385
locals['_return_type'] = return_type
389386
return_annotation = '->_return_type'
390387
args = ','.join(args)
391-
body = '\n'.join(f' {b}' for b in body)
388+
body = '\n'.join(f' {b}' for b in body)
392389

393390
# Compute the text of the entire function.
394-
txt = f'def {name}({args}){return_annotation}:\n{body}'
391+
txt = f' def {name}({args}){return_annotation}:\n{body}'
395392

396-
exec(txt, globals, locals)
397-
return locals[name]
393+
local_vars = ', '.join(locals.keys())
394+
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
395+
396+
ns = {}
397+
exec(txt, globals, ns)
398+
return ns['__create_fn__'](**locals)
398399

399400

400401
def _field_assign(frozen, name, value, self_name):
@@ -405,7 +406,7 @@ def _field_assign(frozen, name, value, self_name):
405406
# self_name is what "self" is called in this function: don't
406407
# hard-code "self", since that might be a field name.
407408
if frozen:
408-
return f'__builtins__.object.__setattr__({self_name},{name!r},{value})'
409+
return f'BUILTINS.object.__setattr__({self_name},{name!r},{value})'
409410
return f'{self_name}.{name}={value}'
410411

411412

@@ -482,7 +483,7 @@ def _init_param(f):
482483
return f'{f.name}:_type_{f.name}{default}'
483484

484485

485-
def _init_fn(fields, frozen, has_post_init, self_name):
486+
def _init_fn(fields, frozen, has_post_init, self_name, globals):
486487
# fields contains both real fields and InitVar pseudo-fields.
487488

488489
# Make sure we don't have fields without defaults following fields
@@ -500,12 +501,15 @@ def _init_fn(fields, frozen, has_post_init, self_name):
500501
raise TypeError(f'non-default argument {f.name!r} '
501502
'follows default argument')
502503

503-
globals = {'MISSING': MISSING,
504-
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}
504+
locals = {f'_type_{f.name}': f.type for f in fields}
505+
locals.update({
506+
'MISSING': MISSING,
507+
'_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY,
508+
})
505509

506510
body_lines = []
507511
for f in fields:
508-
line = _field_init(f, frozen, globals, self_name)
512+
line = _field_init(f, frozen, locals, self_name)
509513
# line is None means that this field doesn't require
510514
# initialization (it's a pseudo-field). Just skip it.
511515
if line:
@@ -521,7 +525,6 @@ def _init_fn(fields, frozen, has_post_init, self_name):
521525
if not body_lines:
522526
body_lines = ['pass']
523527

524-
locals = {f'_type_{f.name}': f.type for f in fields}
525528
return _create_fn('__init__',
526529
[self_name] + [_init_param(f) for f in fields if f.init],
527530
body_lines,
@@ -530,20 +533,19 @@ def _init_fn(fields, frozen, has_post_init, self_name):
530533
return_type=None)
531534

532535

533-
def _repr_fn(fields):
536+
def _repr_fn(fields, globals):
534537
fn = _create_fn('__repr__',
535538
('self',),
536539
['return self.__class__.__qualname__ + f"(' +
537540
', '.join([f"{f.name}={{self.{f.name}!r}}"
538541
for f in fields]) +
539-
')"'])
542+
')"'],
543+
globals=globals)
540544
return _recursive_repr(fn)
541545

542546

543-
def _frozen_get_del_attr(cls, fields):
544-
# XXX: globals is modified on the first call to _create_fn, then
545-
# the modified version is used in the second call. Is this okay?
546-
globals = {'cls': cls,
547+
def _frozen_get_del_attr(cls, fields, globals):
548+
locals = {'cls': cls,
547549
'FrozenInstanceError': FrozenInstanceError}
548550
if fields:
549551
fields_str = '(' + ','.join(repr(f.name) for f in fields) + ',)'
@@ -555,17 +557,19 @@ def _frozen_get_del_attr(cls, fields):
555557
(f'if type(self) is cls or name in {fields_str}:',
556558
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
557559
f'super(cls, self).__setattr__(name, value)'),
560+
locals=locals,
558561
globals=globals),
559562
_create_fn('__delattr__',
560563
('self', 'name'),
561564
(f'if type(self) is cls or name in {fields_str}:',
562565
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
563566
f'super(cls, self).__delattr__(name)'),
567+
locals=locals,
564568
globals=globals),
565569
)
566570

567571

568-
def _cmp_fn(name, op, self_tuple, other_tuple):
572+
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
569573
# Create a comparison function. If the fields in the object are
570574
# named 'x' and 'y', then self_tuple is the string
571575
# '(self.x,self.y)' and other_tuple is the string
@@ -575,14 +579,16 @@ def _cmp_fn(name, op, self_tuple, other_tuple):
575579
('self', 'other'),
576580
[ 'if other.__class__ is self.__class__:',
577581
f' return {self_tuple}{op}{other_tuple}',
578-
'return NotImplemented'])
582+
'return NotImplemented'],
583+
globals=globals)
579584

580585

581-
def _hash_fn(fields):
586+
def _hash_fn(fields, globals):
582587
self_tuple = _tuple_str('self', fields)
583588
return _create_fn('__hash__',
584589
('self',),
585-
[f'return hash({self_tuple})'])
590+
[f'return hash({self_tuple})'],
591+
globals=globals)
586592

587593

588594
def _is_classvar(a_type, typing):
@@ -755,14 +761,14 @@ def _set_new_attribute(cls, name, value):
755761
# take. The common case is to do nothing, so instead of providing a
756762
# function that is a no-op, use None to signify that.
757763

758-
def _hash_set_none(cls, fields):
764+
def _hash_set_none(cls, fields, globals):
759765
return None
760766

761-
def _hash_add(cls, fields):
767+
def _hash_add(cls, fields, globals):
762768
flds = [f for f in fields if (f.compare if f.hash is None else f.hash)]
763-
return _hash_fn(flds)
769+
return _hash_fn(flds, globals)
764770

765-
def _hash_exception(cls, fields):
771+
def _hash_exception(cls, fields, globals):
766772
# Raise an exception.
767773
raise TypeError(f'Cannot overwrite attribute __hash__ '
768774
f'in class {cls.__name__}')
@@ -804,6 +810,16 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
804810
# is defined by the base class, which is found first.
805811
fields = {}
806812

813+
if cls.__module__ in sys.modules:
814+
globals = sys.modules[cls.__module__].__dict__
815+
else:
816+
# Theoretically this can happen if someone writes
817+
# a custom string to cls.__module__. In which case
818+
# such dataclass won't be fully introspectable
819+
# (w.r.t. typing.get_type_hints) but will still function
820+
# correctly.
821+
globals = {}
822+
807823
setattr(cls, _PARAMS, _DataclassParams(init, repr, eq, order,
808824
unsafe_hash, frozen))
809825

@@ -913,6 +929,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
913929
# if possible.
914930
'__dataclass_self__' if 'self' in fields
915931
else 'self',
932+
globals,
916933
))
917934

918935
# Get the fields as a list, and include only real fields. This is
@@ -921,7 +938,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
921938

922939
if repr:
923940
flds = [f for f in field_list if f.repr]
924-
_set_new_attribute(cls, '__repr__', _repr_fn(flds))
941+
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
925942

926943
if eq:
927944
# Create _eq__ method. There's no need for a __ne__ method,
@@ -931,7 +948,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
931948
other_tuple = _tuple_str('other', flds)
932949
_set_new_attribute(cls, '__eq__',
933950
_cmp_fn('__eq__', '==',
934-
self_tuple, other_tuple))
951+
self_tuple, other_tuple,
952+
globals=globals))
935953

936954
if order:
937955
# Create and set the ordering methods.
@@ -944,13 +962,14 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
944962
('__ge__', '>='),
945963
]:
946964
if _set_new_attribute(cls, name,
947-
_cmp_fn(name, op, self_tuple, other_tuple)):
965+
_cmp_fn(name, op, self_tuple, other_tuple,
966+
globals=globals)):
948967
raise TypeError(f'Cannot overwrite attribute {name} '
949968
f'in class {cls.__name__}. Consider using '
950969
'functools.total_ordering')
951970

952971
if frozen:
953-
for fn in _frozen_get_del_attr(cls, field_list):
972+
for fn in _frozen_get_del_attr(cls, field_list, globals):
954973
if _set_new_attribute(cls, fn.__name__, fn):
955974
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
956975
f'in class {cls.__name__}')
@@ -963,7 +982,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen):
963982
if hash_action:
964983
# No need to call _set_new_attribute here, since by the time
965984
# we're here the overwriting is unconditional.
966-
cls.__hash__ = hash_action(cls, field_list)
985+
cls.__hash__ = hash_action(cls, field_list, globals)
967986

968987
if not getattr(cls, '__doc__'):
969988
# Create a class doc-string.

Lib/test/dataclass_textanno.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from __future__ import annotations
2+
3+
import dataclasses
4+
5+
6+
class Foo:
7+
pass
8+
9+
10+
@dataclasses.dataclass
11+
class Bar:
12+
foo: Foo

Lib/test/test_dataclasses.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import unittest
1111
from unittest.mock import Mock
1212
from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
13+
from typing import get_type_hints
1314
from collections import deque, OrderedDict, namedtuple
1415
from functools import total_ordering
1516

@@ -2926,6 +2927,17 @@ def test_classvar_module_level_import(self):
29262927
# won't exist on the instance.
29272928
self.assertNotIn('not_iv4', c.__dict__)
29282929

2930+
def test_text_annotations(self):
2931+
from test import dataclass_textanno
2932+
2933+
self.assertEqual(
2934+
get_type_hints(dataclass_textanno.Bar),
2935+
{'foo': dataclass_textanno.Foo})
2936+
self.assertEqual(
2937+
get_type_hints(dataclass_textanno.Bar.__init__),
2938+
{'foo': dataclass_textanno.Foo,
2939+
'return': type(None)})
2940+
29292941

29302942
class TestMakeDataclass(unittest.TestCase):
29312943
def test_simple(self):
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix dataclasses to support forward references in type annotations

0 commit comments

Comments
 (0)