Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 1f11538

Browse files
gantsevdenisilevkivskyi
authored andcommitted
Type-check class keyword in an __init_subclass__() call (#7452)
Fixes #7190
1 parent 855076b commit 1f11538

File tree

3 files changed

+195
-2
lines changed

3 files changed

+195
-2
lines changed

mypy/checker.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
2727
CONTRAVARIANT, COVARIANT, INVARIANT, TypeVarExpr, AssignmentExpr,
2828
is_final_node,
29-
)
29+
ARG_NAMED)
3030
from mypy import nodes
3131
from mypy.literals import literal, literal_hash
3232
from mypy.typeanal import has_any_from_unimported_type, check_for_explicit_any
@@ -1675,6 +1675,7 @@ def visit_class_def(self, defn: ClassDef) -> None:
16751675
with self.scope.push_class(defn.info):
16761676
self.accept(defn.defs)
16771677
self.binder = old_binder
1678+
self.check_init_subclass(defn)
16781679
if not defn.has_incompatible_baseclass:
16791680
# Otherwise we've already found errors; more errors are not useful
16801681
self.check_multiple_inheritance(typ)
@@ -1704,6 +1705,50 @@ def visit_class_def(self, defn: ClassDef) -> None:
17041705
if typ.is_protocol and typ.defn.type_vars:
17051706
self.check_protocol_variance(defn)
17061707

1708+
def check_init_subclass(self, defn: ClassDef) -> None:
1709+
"""Check that keywords in a class definition are valid arguments for __init_subclass__().
1710+
1711+
In this example:
1712+
1 class Base:
1713+
2 def __init_subclass__(cls, thing: int):
1714+
3 pass
1715+
4 class Child(Base, thing=5):
1716+
5 def __init_subclass__(cls):
1717+
6 pass
1718+
7 Child()
1719+
1720+
Base.__init_subclass__(thing=5) is called at line 4. This is what we simulate here.
1721+
Child.__init_subclass__ is never called.
1722+
"""
1723+
# At runtime, only Base.__init_subclass__ will be called, so
1724+
# we skip the current class itself.
1725+
for base in defn.info.mro[1:]:
1726+
if '__init_subclass__' not in base.names:
1727+
continue
1728+
name_expr = NameExpr(defn.name)
1729+
name_expr.node = base
1730+
callee = MemberExpr(name_expr, '__init_subclass__')
1731+
args = list(defn.keywords.values())
1732+
arg_names = list(defn.keywords.keys()) # type: List[Optional[str]]
1733+
# 'metaclass' keyword is consumed by the rest of the type machinery,
1734+
# and is never passed to __init_subclass__ implementations
1735+
if 'metaclass' in arg_names:
1736+
idx = arg_names.index('metaclass')
1737+
arg_names.pop(idx)
1738+
args.pop(idx)
1739+
arg_kinds = [ARG_NAMED] * len(args)
1740+
call_expr = CallExpr(callee, args, arg_kinds, arg_names)
1741+
call_expr.line = defn.line
1742+
call_expr.column = defn.column
1743+
call_expr.end_line = defn.end_line
1744+
self.expr_checker.accept(call_expr,
1745+
allow_none_return=True,
1746+
always_allow_any=True)
1747+
# We are only interested in the first Base having __init_subclass__
1748+
# all other bases have already been checked.
1749+
break
1750+
return
1751+
17071752
def check_protocol_variance(self, defn: ClassDef) -> None:
17081753
"""Check that protocol definition is compatible with declared
17091754
variances of type variables.

test-data/unit/check-classes.test

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ class B(A):
379379
[case testOverride__init_subclass__WithDifferentSignature]
380380
class A:
381381
def __init_subclass__(cls, x: int) -> None: pass
382-
class B(A):
382+
class B(A): # E: Too few arguments for "__init_subclass__" of "A"
383383
def __init_subclass__(cls) -> None: pass
384384

385385
[case testOverrideWithDecorator]
@@ -6164,6 +6164,102 @@ class C(B[int, T]):
61646164
# TODO: error message could be better.
61656165
self.x: Tuple[str, T] # E: Incompatible types in assignment (expression has type "Tuple[str, T]", base class "A" defined the type as "Tuple[int, T]")
61666166

6167+
[case testInitSubclassWrongType]
6168+
class Base:
6169+
default_name: str
6170+
6171+
def __init_subclass__(cls, default_name: str):
6172+
super().__init_subclass__()
6173+
cls.default_name = default_name
6174+
return
6175+
6176+
class Child(Base, default_name=5): # E: Argument "default_name" to "__init_subclass__" of "Base" has incompatible type "int"; expected "str"
6177+
pass
6178+
[builtins fixtures/object_with_init_subclass.pyi]
6179+
6180+
[case testInitSubclassTooFewArgs]
6181+
class Base:
6182+
default_name: str
6183+
6184+
def __init_subclass__(cls, default_name: str, **kwargs):
6185+
super().__init_subclass__()
6186+
cls.default_name = default_name
6187+
return
6188+
6189+
class Child(Base): # E: Too few arguments for "__init_subclass__" of "Base"
6190+
pass
6191+
[builtins fixtures/object_with_init_subclass.pyi]
6192+
6193+
[case testInitSubclassTooFewArgs2]
6194+
class Base:
6195+
default_name: str
6196+
6197+
def __init_subclass__(cls, default_name: str, thing: int):
6198+
super().__init_subclass__()
6199+
cls.default_name = default_name
6200+
return
6201+
# TODO implement this, so that no error is raised?
6202+
d = {"default_name": "abc", "thing": 0}
6203+
class Child(Base, **d): # E: Too few arguments for "__init_subclass__" of "Base"
6204+
pass
6205+
[builtins fixtures/object_with_init_subclass.pyi]
6206+
6207+
[case testInitSubclassOK]
6208+
class Base:
6209+
default_name: str
6210+
thing: int
6211+
6212+
def __init_subclass__(cls, default_name: str, thing:int, **kwargs):
6213+
super().__init_subclass__()
6214+
cls.default_name = default_name
6215+
return
6216+
6217+
class Child(Base, thing=5, default_name=""):
6218+
pass
6219+
[builtins fixtures/object_with_init_subclass.pyi]
6220+
6221+
[case testInitSubclassWithMetaclassOK]
6222+
class Base(type):
6223+
thing: int
6224+
6225+
def __init_subclass__(cls, thing: int):
6226+
cls.thing = thing
6227+
6228+
class Child(Base, metaclass=Base, thing=0):
6229+
pass
6230+
6231+
[case testTooManyArgsForObject]
6232+
class A(thing=5):
6233+
pass
6234+
[out]
6235+
main:1: error: Unexpected keyword argument "thing" for "__init_subclass__" of "object"
6236+
tmp/builtins.pyi:5: note: "__init_subclass__" of "object" defined here
6237+
[builtins fixtures/object_with_init_subclass.pyi]
6238+
6239+
[case testInitSubclassWithImports]
6240+
from init_subclass.a import Base
6241+
class Child(Base, thing=5): # E: Missing positional arguments "default_name", "kwargs" in call to "__init_subclass__" of "Base"
6242+
pass
6243+
[file init_subclass/a.py]
6244+
class Base:
6245+
default_name: str
6246+
thing: int
6247+
6248+
def __init_subclass__(cls, default_name: str, thing:int, **kwargs):
6249+
pass
6250+
[file init_subclass/__init__.py]
6251+
[builtins fixtures/object_with_init_subclass.pyi]
6252+
6253+
[case testInitSubclassWithImportsOK]
6254+
from init_subclass.a import MidBase
6255+
class Main(MidBase, test=True): pass
6256+
[file init_subclass/a.py]
6257+
class Base:
6258+
def __init_subclass__(cls, **kwargs) -> None: pass
6259+
class MidBase(Base): pass
6260+
[file init_subclass/__init__.py]
6261+
[builtins fixtures/object_with_init_subclass.pyi]
6262+
61676263
[case testOverrideGenericSelfClassMethod]
61686264
from typing import Generic, TypeVar, Type, List
61696265

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from typing import Sequence, Iterator, TypeVar, Mapping, Iterable, Optional, Union, overload, Tuple, Generic
2+
3+
class object:
4+
def __init__(self) -> None: ...
5+
def __init_subclass__(cls) -> None: ...
6+
7+
T = TypeVar('T')
8+
KT = TypeVar('KT')
9+
VT = TypeVar('VT')
10+
# copy pased from primitives.pyi
11+
class type:
12+
def __init__(self, x) -> None: pass
13+
14+
class int:
15+
# Note: this is a simplification of the actual signature
16+
def __init__(self, x: object = ..., base: int = ...) -> None: pass
17+
def __add__(self, i: int) -> int: pass
18+
class float:
19+
def __float__(self) -> float: pass
20+
class complex: pass
21+
class bool(int): pass
22+
class str(Sequence[str]):
23+
def __add__(self, s: str) -> str: pass
24+
def __iter__(self) -> Iterator[str]: pass
25+
def __contains__(self, other: object) -> bool: pass
26+
def __getitem__(self, item: int) -> str: pass
27+
def format(self, *args) -> str: pass
28+
class bytes(Sequence[int]):
29+
def __iter__(self) -> Iterator[int]: pass
30+
def __contains__(self, other: object) -> bool: pass
31+
def __getitem__(self, item: int) -> int: pass
32+
class bytearray: pass
33+
class tuple(Generic[T]): pass
34+
class function: pass
35+
class ellipsis: pass
36+
37+
# copy-pasted from dict.pyi
38+
class dict(Mapping[KT, VT]):
39+
@overload
40+
def __init__(self, **kwargs: VT) -> None: pass
41+
@overload
42+
def __init__(self, arg: Iterable[Tuple[KT, VT]], **kwargs: VT) -> None: pass
43+
def __getitem__(self, key: KT) -> VT: pass
44+
def __setitem__(self, k: KT, v: VT) -> None: pass
45+
def __iter__(self) -> Iterator[KT]: pass
46+
def __contains__(self, item: object) -> int: pass
47+
def update(self, a: Mapping[KT, VT]) -> None: pass
48+
@overload
49+
def get(self, k: KT) -> Optional[VT]: pass
50+
@overload
51+
def get(self, k: KT, default: Union[KT, T]) -> Union[VT, T]: pass
52+
def __len__(self) -> int: ...

0 commit comments

Comments
 (0)