Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 8c976f8

Browse files
committed
squashed
1 parent 5005428 commit 8c976f8

File tree

3 files changed

+167
-25
lines changed

3 files changed

+167
-25
lines changed

mypy/plugins/attrs.py

Lines changed: 85 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@
22

33
from __future__ import annotations
44

5-
from typing import Iterable, List, cast
5+
from collections import defaultdict
6+
from functools import reduce
7+
from typing import Iterable, List, Mapping, cast
68
from typing_extensions import Final, Literal
79

810
import mypy.plugin # To avoid circular imports.
911
from mypy.applytype import apply_generic_arguments
1012
from mypy.checker import TypeChecker
1113
from mypy.errorcodes import LITERAL_REQ
12-
from mypy.expandtype import expand_type
14+
from mypy.expandtype import expand_type, expand_type_by_instance
1315
from mypy.exprtotype import TypeTranslationError, expr_to_unanalyzed_type
16+
from mypy.meet import meet_types
1417
from mypy.messages import format_type_bare
1518
from mypy.nodes import (
1619
ARG_NAMED,
@@ -67,6 +70,7 @@
6770
Type,
6871
TypeOfAny,
6972
TypeVarType,
73+
UninhabitedType,
7074
UnionType,
7175
get_proper_type,
7276
)
@@ -943,12 +947,79 @@ def _get_attrs_init_type(typ: Instance) -> CallableType | None:
943947
return init_method.type
944948

945949

946-
def _get_attrs_cls_and_init(typ: ProperType) -> tuple[Instance | None, CallableType | None]:
950+
def _format_not_attrs_class_failure(t: Type, parent_t: Type) -> str:
951+
t_name = format_type_bare(t)
952+
if parent_t is t:
953+
return (
954+
f'Argument 1 to "evolve" has a variable type "{t_name}" not bound to an attrs class'
955+
if isinstance(t, TypeVarType)
956+
else f'Argument 1 to "evolve" has incompatible type "{t_name}"; expected an attrs class'
957+
)
958+
else:
959+
pt_name = format_type_bare(parent_t)
960+
return (
961+
f'Argument 1 to "evolve" has type "{pt_name}" whose item "{t_name}" is not bound to an attrs class'
962+
if isinstance(t, TypeVarType)
963+
else f'Argument 1 to "evolve" has incompatible type "{pt_name}" whose item "{t_name}" is not an attrs class'
964+
)
965+
966+
967+
def _get_expanded_attr_types(
968+
ctx: mypy.plugin.FunctionSigContext,
969+
typ: ProperType,
970+
display_typ: ProperType,
971+
parent_typ: ProperType,
972+
) -> list[Mapping[str, Type]] | None:
973+
"""
974+
For a given type, determine what attrs classes it can be, and returns the field types for each class.
975+
For generic classes, the field types are expanded.
976+
If the type contains Any or a non-attrs type, returns None; in the latter case, also reports an error.
977+
"""
978+
if isinstance(typ, AnyType):
979+
return None
980+
if isinstance(typ, UnionType):
981+
types = []
982+
had_errors = False
983+
for item in typ.relevant_items():
984+
item = get_proper_type(item)
985+
item_types = _get_expanded_attr_types(ctx, item, item, parent_typ)
986+
if isinstance(item_types, list):
987+
types += item_types
988+
else:
989+
had_errors = True
990+
if had_errors:
991+
return None
992+
return types
947993
if isinstance(typ, TypeVarType):
948-
typ = get_proper_type(typ.upper_bound)
994+
return _get_expanded_attr_types(
995+
ctx, get_proper_type(typ.upper_bound), display_typ, parent_typ
996+
)
949997
if not isinstance(typ, Instance):
950-
return None, None
951-
return typ, _get_attrs_init_type(typ)
998+
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
999+
return None
1000+
init_func = _get_attrs_init_type(typ)
1001+
if init_func is None:
1002+
ctx.api.fail(_format_not_attrs_class_failure(display_typ, parent_typ), ctx.context)
1003+
return None
1004+
init_func = expand_type_by_instance(init_func, typ)
1005+
return [dict(zip(init_func.arg_names[1:], init_func.arg_types[1:]))]
1006+
1007+
1008+
def _meet_fields(types: list[Mapping[str, Type]]) -> Mapping[str, Type]:
1009+
"""
1010+
"Meets" the fields of a list of attrs classes, i.e. for each field, its new type will be the lower bound.
1011+
"""
1012+
field_to_types = defaultdict(list)
1013+
for fields in types:
1014+
for name, typ in fields.items():
1015+
field_to_types[name].append(typ)
1016+
1017+
return {
1018+
name: get_proper_type(reduce(meet_types, f_types))
1019+
if len(f_types) == len(types)
1020+
else UninhabitedType()
1021+
for name, f_types in field_to_types.items()
1022+
}
9521023

9531024

9541025
def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> CallableType:
@@ -972,27 +1043,18 @@ def evolve_function_sig_callback(ctx: mypy.plugin.FunctionSigContext) -> Callabl
9721043
# </hack>
9731044

9741045
inst_type = get_proper_type(inst_type)
975-
if isinstance(inst_type, AnyType):
976-
return ctx.default_signature # evolve(Any, ....) -> Any
9771046
inst_type_str = format_type_bare(inst_type)
9781047

979-
attrs_type, attrs_init_type = _get_attrs_cls_and_init(inst_type)
980-
if attrs_type is None or attrs_init_type is None:
981-
ctx.api.fail(
982-
f'Argument 1 to "evolve" has a variable type "{inst_type_str}" not bound to an attrs class'
983-
if isinstance(inst_type, TypeVarType)
984-
else f'Argument 1 to "evolve" has incompatible type "{inst_type_str}"; expected an attrs class',
985-
ctx.context,
986-
)
1048+
attr_types = _get_expanded_attr_types(ctx, inst_type, inst_type, inst_type)
1049+
if attr_types is None:
9871050
return ctx.default_signature
1051+
fields = _meet_fields(attr_types)
9881052

989-
# AttrClass.__init__ has the following signature (or similar, if having kw-only & defaults):
990-
# def __init__(self, attr1: Type1, attr2: Type2) -> None:
991-
# We want to generate a signature for evolve that looks like this:
992-
# def evolve(inst: AttrClass, *, attr1: Type1 = ..., attr2: Type2 = ...) -> AttrClass:
993-
return attrs_init_type.copy_modified(
994-
arg_names=["inst"] + attrs_init_type.arg_names[1:],
995-
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT for _ in attrs_init_type.arg_kinds[1:]],
1053+
return CallableType(
1054+
arg_names=["inst", *fields.keys()],
1055+
arg_kinds=[ARG_POS] + [ARG_NAMED_OPT] * len(fields),
1056+
arg_types=[inst_type, *fields.values()],
9961057
ret_type=inst_type,
1058+
fallback=ctx.default_signature.fallback,
9971059
name=f"{ctx.default_signature.name} of {inst_type_str}",
9981060
)

test-data/unit/check-attr.test

Lines changed: 81 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1970,6 +1970,81 @@ reveal_type(ret) # N: Revealed type is "Any"
19701970

19711971
[typing fixtures/typing-medium.pyi]
19721972

1973+
[case testEvolveGeneric]
1974+
import attrs
1975+
from typing import Generic, TypeVar
1976+
1977+
T = TypeVar('T')
1978+
1979+
@attrs.define
1980+
class A(Generic[T]):
1981+
x: T
1982+
1983+
1984+
a = A(x=42)
1985+
reveal_type(a) # N: Revealed type is "__main__.A[builtins.int]"
1986+
a2 = attrs.evolve(a, x=42)
1987+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
1988+
a2 = attrs.evolve(a, x='42') # E: Argument "x" to "evolve" of "A[int]" has incompatible type "str"; expected "int"
1989+
reveal_type(a2) # N: Revealed type is "__main__.A[builtins.int]"
1990+
1991+
[builtins fixtures/attr.pyi]
1992+
1993+
[case testEvolveUnion]
1994+
# flags: --python-version 3.10
1995+
from typing import Generic, TypeVar
1996+
import attrs
1997+
1998+
T = TypeVar('T')
1999+
2000+
2001+
@attrs.define
2002+
class A(Generic[T]):
2003+
x: T # exercises meet(T=int, int) = int
2004+
y: bool # exercises meet(bool, int) = bool
2005+
z: str # exercises meet(str, bytes) = <nothing>
2006+
w: dict # exercises meet(dict, <nothing>) = <nothing>
2007+
2008+
2009+
@attrs.define
2010+
class B:
2011+
x: int
2012+
y: bool
2013+
z: bytes
2014+
2015+
2016+
a_or_b: A[int] | B
2017+
a2 = attrs.evolve(a_or_b, x=42, y=True)
2018+
a2 = attrs.evolve(a_or_b, x=42, y=True, z='42') # E: Argument "z" to "evolve" of "Union[A[int], B]" has incompatible type "str"; expected <nothing>
2019+
a2 = attrs.evolve(a_or_b, x=42, y=True, w={}) # E: Argument "w" to "evolve" of "Union[A[int], B]" has incompatible type "Dict[<nothing>, <nothing>]"; expected <nothing>
2020+
2021+
[builtins fixtures/attr.pyi]
2022+
2023+
[case testEvolveUnionOfTypeVar]
2024+
# flags: --python-version 3.10
2025+
import attrs
2026+
from typing import TypeVar
2027+
2028+
@attrs.define
2029+
class A:
2030+
x: int
2031+
y: int
2032+
z: str
2033+
w: dict
2034+
2035+
2036+
class B:
2037+
pass
2038+
2039+
TA = TypeVar('TA', bound=A)
2040+
TB = TypeVar('TB', bound=B)
2041+
2042+
def f(b_or_t: TA | TB | int) -> None:
2043+
a2 = attrs.evolve(b_or_t) # E: Argument 1 to "evolve" has type "Union[TA, TB, int]" whose item "TB" is not bound to an attrs class # E: Argument 1 to "evolve" has incompatible type "Union[TA, TB, int]" whose item "int" is not an attrs class
2044+
2045+
2046+
[builtins fixtures/attr.pyi]
2047+
19732048
[case testEvolveTypeVarBound]
19742049
import attrs
19752050
from typing import TypeVar
@@ -1997,11 +2072,12 @@ f(B(x=42))
19972072

19982073
[case testEvolveTypeVarBoundNonAttrs]
19992074
import attrs
2000-
from typing import TypeVar
2075+
from typing import Union, TypeVar
20012076

20022077
TInt = TypeVar('TInt', bound=int)
20032078
TAny = TypeVar('TAny')
20042079
TNone = TypeVar('TNone', bound=None)
2080+
TUnion = TypeVar('TUnion', bound=Union[str, int])
20052081

20062082
def f(t: TInt) -> None:
20072083
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TInt" not bound to an attrs class
@@ -2012,6 +2088,10 @@ def g(t: TAny) -> None:
20122088
def h(t: TNone) -> None:
20132089
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has a variable type "TNone" not bound to an attrs class
20142090

2091+
def x(t: TUnion) -> None:
2092+
_ = attrs.evolve(t, x=42) # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "str" is not an attrs class # E: Argument 1 to "evolve" has incompatible type "TUnion" whose item "int" is not an attrs class
2093+
2094+
20152095
[builtins fixtures/attr.pyi]
20162096

20172097
[case testEvolveTypeVarConstrained]

test-data/unit/fixtures/attr.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,13 @@ class object:
99
class type: pass
1010
class bytes: pass
1111
class function: pass
12-
class bool: pass
1312
class float: pass
1413
class int:
1514
@overload
1615
def __init__(self, x: Union[str, bytes, int] = ...) -> None: ...
1716
@overload
1817
def __init__(self, x: Union[str, bytes], base: int) -> None: ...
18+
class bool(int): pass
1919
class complex:
2020
@overload
2121
def __init__(self, real: float = ..., im: float = ...) -> None: ...

0 commit comments

Comments
 (0)