From 06c36253831b01539fa10f464136bc17ddc0bb6e Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 19 Aug 2023 16:03:49 +0100 Subject: [PATCH 1/4] Experiment: lenient handling of trivial Callable suffixes --- mypy/messages.py | 3 ++ mypy/subtypes.py | 20 ++++++++++-- test-data/unit/check-callable.test | 31 +++++++++++++++++++ test-data/unit/check-modules.test | 12 +++---- .../unit/check-parameter-specification.test | 7 ++++- 5 files changed, 63 insertions(+), 10 deletions(-) diff --git a/mypy/messages.py b/mypy/messages.py index aab30ee29108..9a38f76e03ae 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -2123,6 +2123,9 @@ def report_protocol_problems( not is_subtype(subtype, erase_type(supertype), options=self.options) or not subtype.type.defn.type_vars or not supertype.type.defn.type_vars + # Always show detailed message for ParamSpec + or subtype.type.has_param_spec_type + or supertype.type.has_param_spec_type ): type_name = format_type(subtype, self.options, module_names=True) self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code) diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 11847858c62c..98117a2edbc2 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1474,6 +1474,18 @@ def are_trivial_parameters(param: Parameters | NormalizedCallableType) -> bool: ) +def is_trivial_suffix(param: Parameters | NormalizedCallableType) -> bool: + param_star = param.var_arg() + param_star2 = param.kw_arg() + return ( + param.arg_kinds[-2:] == [ARG_STAR, ARG_STAR2] + and param_star is not None + and isinstance(get_proper_type(param_star.typ), AnyType) + and param_star2 is not None + and isinstance(get_proper_type(param_star2.typ), AnyType) + ) + + def are_parameters_compatible( left: Parameters | NormalizedCallableType, right: Parameters | NormalizedCallableType, @@ -1497,6 +1509,8 @@ def are_parameters_compatible( if are_trivial_parameters(right): return True + trivial_suffix = is_trivial_suffix(right) + # Match up corresponding arguments and check them for compatibility. In # every pair (argL, argR) of corresponding arguments from L and R, argL must # be "more general" than argR if L is to be a subtype of R. @@ -1526,7 +1540,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N if right_arg is None: return False if left_arg is None: - return not allow_partial_overlap + return not allow_partial_overlap and not trivial_suffix return not is_compat(right_arg.typ, left_arg.typ) if _incompatible(left_star, right_star) or _incompatible(left_star2, right_star2): @@ -1549,7 +1563,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1c: Check var args. Right has an infinite series of optional positional # arguments. Get all further positional args of left, and make sure # they're more general then the corresponding member in right. - if right_star is not None: + if right_star is not None and not trivial_suffix: # Synthesize an anonymous formal argument for the right right_by_position = right.try_synthesizing_arg_from_vararg(None) assert right_by_position is not None @@ -1576,7 +1590,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N # Phase 1d: Check kw args. Right has an infinite series of optional named # arguments. Get all further named args of left, and make sure # they're more general then the corresponding member in right. - if right_star2 is not None: + if right_star2 is not None and not trivial_suffix: right_names = {name for name in right.arg_names if name is not None} left_only_names = set() for name, kind in zip(left.arg_names, left.arg_kinds): diff --git a/test-data/unit/check-callable.test b/test-data/unit/check-callable.test index 07c42de74bb3..8a611a689be5 100644 --- a/test-data/unit/check-callable.test +++ b/test-data/unit/check-callable.test @@ -598,3 +598,34 @@ a: A a() # E: Missing positional argument "other" in call to "__call__" of "A" a(a) a(lambda: None) + +[case testCallableSubtypingTrivialSuffix] +from typing import Any, Protocol + +class Call(Protocol): + def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ... + +def f1() -> None: ... +a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f2(x: str) -> None: ... +a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f3(y: int) -> None: ... +a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f4(x: int) -> None: ... +a4: Call = f4 + +def f5(x: int, y: int) -> None: ... +a5: Call = f5 + +def f6(x: int, y: int = 0) -> None: ... +a6: Call = f6 + +def f7(x: int, *, y: int) -> None: ... +a7: Call = f7 + +def f8(x: int, *args: int, **kwargs: str) -> None: ... +a8: Call = f8 +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test index 3da5996ed274..94368f6c1113 100644 --- a/test-data/unit/check-modules.test +++ b/test-data/unit/check-modules.test @@ -3193,7 +3193,7 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa" import b [file a.py] class Foo: - def frobnicate(self, x, *args, **kwargs): pass + def frobnicate(self, x: str, *args, **kwargs): pass [file b.py] from a import Foo class Bar(Foo): @@ -3201,21 +3201,21 @@ class Bar(Foo): [file b.py.2] from a import Foo class Bar(Foo): - def frobnicate(self, *args) -> None: pass + def frobnicate(self, *args: int) -> None: pass [file b.py.3] from a import Foo class Bar(Foo): - def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know + def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know [builtins fixtures/dict.pyi] [out1] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: tmp/b.py:3: note: def frobnicate(self) -> None [out2] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: -tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None +tmp/b.py:3: note: def frobnicate(self, *args: int) -> None diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index dee8a971f925..80bfe23a4948 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1669,7 +1669,12 @@ class A(Protocol[P]): ... def bar(b: A[P]) -> A[Concatenate[int, P]]: - return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") + return b # E: Incompatible return value type (got "A[P]", expected "A[[int, **P]]") \ + # N: Following member(s) of "A[P]" have conflicts: \ + # N: Expected: \ + # N: def foo(self, int, /, *args: P.args, **kwargs: P.kwargs) -> Any \ + # N: Got: \ + # N: def foo(self, *args: P.args, **kwargs: P.kwargs) -> Any [builtins fixtures/paramspec.pyi] [case testParamSpecPrefixSubtypingValidNonStrict] From daf25c5a1b90560653122c60f7110b61c519bd7b Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 26 Aug 2023 18:53:38 +0100 Subject: [PATCH 2/4] Limit the lenient approach only to erased types --- mypy/checker.py | 4 +- mypy/erasetype.py | 7 ++ mypy/subtypes.py | 3 +- mypy/typeops.py | 4 + mypy/types.py | 7 ++ test-data/unit/check-callable.test | 31 ----- test-data/unit/check-modules.test | 12 +- .../unit/check-parameter-specification.test | 106 ++++++++++++++++++ test-data/unit/fixtures/paramspec.pyi | 1 + 9 files changed, 136 insertions(+), 39 deletions(-) diff --git a/mypy/checker.py b/mypy/checker.py index a44601b83e21..631585b3a1ff 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1207,7 +1207,9 @@ def check_func_def( ): if defn.is_class or defn.name == "__new__": ref_type = mypy.types.TypeType.make_normalized(ref_type) - erased = get_proper_type(erase_to_bound(arg_type)) + # This level of erasure matches the one in checkmember.check_self_arg(), + # better keep these two checks consistent. + erased = get_proper_type(erase_typevars(erase_to_bound(arg_type))) if not is_subtype(ref_type, erased, ignore_type_params=True): if ( isinstance(erased, Instance) diff --git a/mypy/erasetype.py b/mypy/erasetype.py index fbbb4f80b578..a969784df134 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -175,6 +175,13 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: return self.replacement return t + def visit_callable_type(self, t: CallableType) -> Type: + result = super().visit_callable_type(t) + if t.param_spec(): + assert isinstance(result, ProperType) and isinstance(result, CallableType) + result.erased = True + return result + def visit_type_alias_type(self, t: TypeAliasType) -> Type: # Type alias target can't contain bound type variables (not bound by the type # alias itself), so it is safe to just erase the arguments. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index af123b903c74..410206d6e511 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1510,7 +1510,8 @@ def are_parameters_compatible( if are_trivial_parameters(right): return True - trivial_suffix = is_trivial_suffix(right) + # Parameters should not contain nested ParamSpec, so erasure doesn't make them less general. + trivial_suffix = isinstance(right, CallableType) and right.erased and is_trivial_suffix(right) # Match up corresponding arguments and check them for compatibility. In # every pair (argL, argR) of corresponding arguments from L and R, argL must diff --git a/mypy/typeops.py b/mypy/typeops.py index 0e0bc348942e..4020745c5dd1 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -253,6 +253,10 @@ def supported_self_type(typ: ProperType) -> bool: """ if isinstance(typ, TypeType): return supported_self_type(typ.item) + if isinstance(typ, CallableType): + # Special case: allow class callable instead of Type[...] as cls annotation, + # as well as callable self for callback protocols. + return True return isinstance(typ, TypeVarType) or ( isinstance(typ, Instance) and typ != fill_typevars(typ.type) ) diff --git a/mypy/types.py b/mypy/types.py index cf2c343655dd..e7b4a65cbe59 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1775,6 +1775,7 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? + "erased", # Is this callable created as an erased form of a more precise type? ) def __init__( @@ -1800,6 +1801,7 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, + erased: bool = False, ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1847,6 +1849,7 @@ def __init__( self.def_extras = {} self.type_guard = type_guard self.unpack_kwargs = unpack_kwargs + self.erased = erased def copy_modified( self: CT, @@ -1870,6 +1873,7 @@ def copy_modified( from_concatenate: Bogus[bool] = _dummy, imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, + erased: Bogus[bool] = _dummy, ) -> CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1900,6 +1904,7 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, + erased=erased if erased is not _dummy else self.erased, ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. @@ -2213,6 +2218,7 @@ def serialize(self) -> JsonDict: "from_concatenate": self.from_concatenate, "imprecise_arg_kinds": self.imprecise_arg_kinds, "unpack_kwargs": self.unpack_kwargs, + "erased": self.erased, } @classmethod @@ -2237,6 +2243,7 @@ def deserialize(cls, data: JsonDict) -> CallableType: from_concatenate=data["from_concatenate"], imprecise_arg_kinds=data["imprecise_arg_kinds"], unpack_kwargs=data["unpack_kwargs"], + erased=data["erased"], ) diff --git a/test-data/unit/check-callable.test b/test-data/unit/check-callable.test index 8a611a689be5..07c42de74bb3 100644 --- a/test-data/unit/check-callable.test +++ b/test-data/unit/check-callable.test @@ -598,34 +598,3 @@ a: A a() # E: Missing positional argument "other" in call to "__call__" of "A" a(a) a(lambda: None) - -[case testCallableSubtypingTrivialSuffix] -from typing import Any, Protocol - -class Call(Protocol): - def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ... - -def f1() -> None: ... -a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \ - # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" -def f2(x: str) -> None: ... -a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \ - # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" -def f3(y: int) -> None: ... -a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \ - # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" -def f4(x: int) -> None: ... -a4: Call = f4 - -def f5(x: int, y: int) -> None: ... -a5: Call = f5 - -def f6(x: int, y: int = 0) -> None: ... -a6: Call = f6 - -def f7(x: int, *, y: int) -> None: ... -a7: Call = f7 - -def f8(x: int, *args: int, **kwargs: str) -> None: ... -a8: Call = f8 -[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test index 94368f6c1113..3da5996ed274 100644 --- a/test-data/unit/check-modules.test +++ b/test-data/unit/check-modules.test @@ -3193,7 +3193,7 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa" import b [file a.py] class Foo: - def frobnicate(self, x: str, *args, **kwargs): pass + def frobnicate(self, x, *args, **kwargs): pass [file b.py] from a import Foo class Bar(Foo): @@ -3201,21 +3201,21 @@ class Bar(Foo): [file b.py.2] from a import Foo class Bar(Foo): - def frobnicate(self, *args: int) -> None: pass + def frobnicate(self, *args) -> None: pass [file b.py.3] from a import Foo class Bar(Foo): - def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know + def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know [builtins fixtures/dict.pyi] [out1] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: tmp/b.py:3: note: def frobnicate(self) -> None [out2] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: -tmp/b.py:3: note: def frobnicate(self, *args: int) -> None +tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index c1f159e0c334..586cb46c12ef 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1829,3 +1829,109 @@ class C(Generic[P]): ... c: C[int, [int, str], str] # E: Nested parameter specifications are not allowed reveal_type(c) # N: Revealed type is "__main__.C[Any]" [builtins fixtures/paramspec.pyi] + +[case testParamSpecConcatenateSelfType] +from typing import Callable +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +class A: + def __init__(self, a_param_1: str) -> None: ... + + @classmethod + def add_params(cls: Callable[P, A]) -> Callable[Concatenate[float, P], A]: + def new_constructor(i: float, *args: P.args, **kwargs: P.kwargs) -> A: + return cls(*args, **kwargs) + return new_constructor + + @classmethod + def remove_params(cls: Callable[Concatenate[str, P], A]) -> Callable[P, A]: + def new_constructor(*args: P.args, **kwargs: P.kwargs) -> A: + return cls("my_special_str", *args, **kwargs) + return new_constructor + +reveal_type(A.add_params()) # N: Revealed type is "def (builtins.float, a_param_1: builtins.str) -> __main__.A" +reveal_type(A.remove_params()) # N: Revealed type is "def () -> __main__.A" +[builtins fixtures/paramspec.pyi] + +[case testParamSpecConcatenateCallbackProtocol] +from typing import Protocol, TypeVar +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +R = TypeVar("R", covariant=True) + +class Path: ... + +class Function(Protocol[P, R]): + def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R: ... + +def file_cache(fn: Function[Concatenate[Path, P], R]) -> Function[P, R]: + def wrapper(*args: P.args, **kw: P.kwargs) -> R: + return fn(Path(), *args, **kw) + return wrapper + +@file_cache +def get_thing(path: Path, *, some_arg: int) -> int: ... +reveal_type(get_thing) # N: Revealed type is "__main__.Function[[*, some_arg: builtins.int], builtins.int]" +get_thing(some_arg=1) # OK +[builtins fixtures/paramspec.pyi] + +[case testParamSpecConcatenateKeywordOnly] +from typing import Callable, TypeVar +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") +R = TypeVar("R") + +class Path: ... + +def file_cache(fn: Callable[Concatenate[Path, P], R]) -> Callable[P, R]: + def wrapper(*args: P.args, **kw: P.kwargs) -> R: + return fn(Path(), *args, **kw) + return wrapper + +@file_cache +def get_thing(path: Path, *, some_arg: int) -> int: ... +reveal_type(get_thing) # N: Revealed type is "def (*, some_arg: builtins.int) -> builtins.int" +get_thing(some_arg=1) # OK +[builtins fixtures/paramspec.pyi] + +[case testParamSpecConcatenateCallbackApply] +from typing import Callable, Protocol +from typing_extensions import ParamSpec, Concatenate + +P = ParamSpec("P") + +class FuncType(Protocol[P]): + def __call__(self, x: int, s: str, *args: P.args, **kw_args: P.kwargs) -> str: ... + +def forwarder1(fp: FuncType[P], *args: P.args, **kw_args: P.kwargs) -> str: + return fp(0, '', *args, **kw_args) + +def forwarder2(fp: Callable[Concatenate[int, str, P], str], *args: P.args, **kw_args: P.kwargs) -> str: + return fp(0, '', *args, **kw_args) + +def my_f(x: int, s: str, d: bool) -> str: ... +forwarder1(my_f, True) # OK +forwarder2(my_f, True) # OK +forwarder1(my_f, 1.0) # E: Argument 2 to "forwarder1" has incompatible type "float"; expected "bool" +forwarder2(my_f, 1.0) # E: Argument 2 to "forwarder2" has incompatible type "float"; expected "bool" +[builtins fixtures/paramspec.pyi] + +[case testParamSpecCallbackProtocolSelf] +from typing import Callable, Protocol, TypeVar +from typing_extensions import ParamSpec, Concatenate + +Params = ParamSpec("Params") +Result = TypeVar("Result", covariant=True) + +class FancyMethod(Protocol): + def __call__(self, arg1: int, arg2: str) -> bool: ... + def return_me(self: Callable[Params, Result]) -> Callable[Params, Result]: ... + def return_part(self: Callable[Concatenate[int, Params], Result]) -> Callable[Params, Result]: ... + +m: FancyMethod +reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2: builtins.str) -> builtins.bool" +reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool" +[builtins fixtures/paramspec.pyi] diff --git a/test-data/unit/fixtures/paramspec.pyi b/test-data/unit/fixtures/paramspec.pyi index 9b0089f6a7e9..dfb5e126f242 100644 --- a/test-data/unit/fixtures/paramspec.pyi +++ b/test-data/unit/fixtures/paramspec.pyi @@ -16,6 +16,7 @@ class object: class function: ... class ellipsis: ... +class classmethod: ... class type: def __init__(self, *a: object) -> None: ... From 7165b6c35d426228152542ce4836be02f8b47f55 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 7 Sep 2023 23:14:27 +0100 Subject: [PATCH 3/4] Consider trivial suffix unconditionally after all --- mypy/checkmember.py | 2 ++ mypy/erasetype.py | 7 ----- mypy/subtypes.py | 4 +-- mypy/types.py | 7 ----- .../unit/check-parameter-specification.test | 26 +++++++++++++++++++ 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/mypy/checkmember.py b/mypy/checkmember.py index 60430839ff62..59af0d402e14 100644 --- a/mypy/checkmember.py +++ b/mypy/checkmember.py @@ -896,6 +896,8 @@ def f(self: S) -> T: ... return functype else: selfarg = get_proper_type(item.arg_types[0]) + # This level of erasure matches the one in checker.check_func_def(), + # better keep these two checks consistent. if subtypes.is_subtype(dispatched_arg_type, erase_typevars(erase_to_bound(selfarg))): new_items.append(item) elif isinstance(selfarg, ParamSpecType): diff --git a/mypy/erasetype.py b/mypy/erasetype.py index a969784df134..fbbb4f80b578 100644 --- a/mypy/erasetype.py +++ b/mypy/erasetype.py @@ -175,13 +175,6 @@ def visit_param_spec(self, t: ParamSpecType) -> Type: return self.replacement return t - def visit_callable_type(self, t: CallableType) -> Type: - result = super().visit_callable_type(t) - if t.param_spec(): - assert isinstance(result, ProperType) and isinstance(result, CallableType) - result.erased = True - return result - def visit_type_alias_type(self, t: TypeAliasType) -> Type: # Type alias target can't contain bound type variables (not bound by the type # alias itself), so it is safe to just erase the arguments. diff --git a/mypy/subtypes.py b/mypy/subtypes.py index 410206d6e511..a8748ba2d59a 100644 --- a/mypy/subtypes.py +++ b/mypy/subtypes.py @@ -1509,9 +1509,7 @@ def are_parameters_compatible( # Treat "def _(*a: Any, **kw: Any) -> X" similarly to "Callable[..., X]" if are_trivial_parameters(right): return True - - # Parameters should not contain nested ParamSpec, so erasure doesn't make them less general. - trivial_suffix = isinstance(right, CallableType) and right.erased and is_trivial_suffix(right) + trivial_suffix = is_trivial_suffix(right) # Match up corresponding arguments and check them for compatibility. In # every pair (argL, argR) of corresponding arguments from L and R, argL must diff --git a/mypy/types.py b/mypy/types.py index 118a61f128af..cee4595b67cc 100644 --- a/mypy/types.py +++ b/mypy/types.py @@ -1778,7 +1778,6 @@ class CallableType(FunctionLike): # (this is used for error messages) "imprecise_arg_kinds", "unpack_kwargs", # Was an Unpack[...] with **kwargs used to define this callable? - "erased", # Is this callable created as an erased form of a more precise type? ) def __init__( @@ -1804,7 +1803,6 @@ def __init__( from_concatenate: bool = False, imprecise_arg_kinds: bool = False, unpack_kwargs: bool = False, - erased: bool = False, ) -> None: super().__init__(line, column) assert len(arg_types) == len(arg_kinds) == len(arg_names) @@ -1852,7 +1850,6 @@ def __init__( self.def_extras = {} self.type_guard = type_guard self.unpack_kwargs = unpack_kwargs - self.erased = erased def copy_modified( self: CT, @@ -1876,7 +1873,6 @@ def copy_modified( from_concatenate: Bogus[bool] = _dummy, imprecise_arg_kinds: Bogus[bool] = _dummy, unpack_kwargs: Bogus[bool] = _dummy, - erased: Bogus[bool] = _dummy, ) -> CT: modified = CallableType( arg_types=arg_types if arg_types is not _dummy else self.arg_types, @@ -1907,7 +1903,6 @@ def copy_modified( else self.imprecise_arg_kinds ), unpack_kwargs=unpack_kwargs if unpack_kwargs is not _dummy else self.unpack_kwargs, - erased=erased if erased is not _dummy else self.erased, ) # Optimization: Only NewTypes are supported as subtypes since # the class is effectively final, so we can use a cast safely. @@ -2225,7 +2220,6 @@ def serialize(self) -> JsonDict: "from_concatenate": self.from_concatenate, "imprecise_arg_kinds": self.imprecise_arg_kinds, "unpack_kwargs": self.unpack_kwargs, - "erased": self.erased, } @classmethod @@ -2250,7 +2244,6 @@ def deserialize(cls, data: JsonDict) -> CallableType: from_concatenate=data["from_concatenate"], imprecise_arg_kinds=data["imprecise_arg_kinds"], unpack_kwargs=data["unpack_kwargs"], - erased=data["erased"], ) diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index dd5bdb3dc535..da831d29dd43 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -1936,6 +1936,32 @@ reveal_type(m.return_me()) # N: Revealed type is "def (arg1: builtins.int, arg2 reveal_type(m.return_part()) # N: Revealed type is "def (arg2: builtins.str) -> builtins.bool" [builtins fixtures/paramspec.pyi] +[case testParamSpecInferenceCallableAgainstAny] +from typing import Callable, TypeVar, Any +from typing_extensions import ParamSpec, Concatenate + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +class A: ... +a = A() + +def a_func( + func: Callable[Concatenate[A, _P], _R], +) -> Callable[Concatenate[Any, _P], _R]: + def wrapper(__a: Any, *args: _P.args, **kwargs: _P.kwargs) -> _R: + return func(a, *args, **kwargs) + return wrapper + +def test(a, *args): ... +x: Any +y: object + +a_func(test) +x = a_func(test) +y = a_func(test) +[builtins fixtures/paramspec.pyi] + [case testParamSpecInferenceWithCallbackProtocol] from typing import Protocol, Callable, ParamSpec From ffe43e392e18a33e588f8b28372f5dfcf9d88987 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Thu, 7 Sep 2023 23:45:07 +0100 Subject: [PATCH 4/4] Update tests --- test-data/unit/check-callable.test | 31 ++++++++++++++++++++++++++++++ test-data/unit/check-modules.test | 12 ++++++------ 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/test-data/unit/check-callable.test b/test-data/unit/check-callable.test index 07c42de74bb3..8a611a689be5 100644 --- a/test-data/unit/check-callable.test +++ b/test-data/unit/check-callable.test @@ -598,3 +598,34 @@ a: A a() # E: Missing positional argument "other" in call to "__call__" of "A" a(a) a(lambda: None) + +[case testCallableSubtypingTrivialSuffix] +from typing import Any, Protocol + +class Call(Protocol): + def __call__(self, x: int, *args: Any, **kwargs: Any) -> None: ... + +def f1() -> None: ... +a1: Call = f1 # E: Incompatible types in assignment (expression has type "Callable[[], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f2(x: str) -> None: ... +a2: Call = f2 # E: Incompatible types in assignment (expression has type "Callable[[str], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f3(y: int) -> None: ... +a3: Call = f3 # E: Incompatible types in assignment (expression has type "Callable[[int], None]", variable has type "Call") \ + # N: "Call.__call__" has type "Callable[[Arg(int, 'x'), VarArg(Any), KwArg(Any)], None]" +def f4(x: int) -> None: ... +a4: Call = f4 + +def f5(x: int, y: int) -> None: ... +a5: Call = f5 + +def f6(x: int, y: int = 0) -> None: ... +a6: Call = f6 + +def f7(x: int, *, y: int) -> None: ... +a7: Call = f7 + +def f8(x: int, *args: int, **kwargs: str) -> None: ... +a8: Call = f8 +[builtins fixtures/tuple.pyi] diff --git a/test-data/unit/check-modules.test b/test-data/unit/check-modules.test index 3da5996ed274..94368f6c1113 100644 --- a/test-data/unit/check-modules.test +++ b/test-data/unit/check-modules.test @@ -3193,7 +3193,7 @@ from test1 import aaaa # E: Module "test1" has no attribute "aaaa" import b [file a.py] class Foo: - def frobnicate(self, x, *args, **kwargs): pass + def frobnicate(self, x: str, *args, **kwargs): pass [file b.py] from a import Foo class Bar(Foo): @@ -3201,21 +3201,21 @@ class Bar(Foo): [file b.py.2] from a import Foo class Bar(Foo): - def frobnicate(self, *args) -> None: pass + def frobnicate(self, *args: int) -> None: pass [file b.py.3] from a import Foo class Bar(Foo): - def frobnicate(self, *args) -> None: pass # type: ignore[override] # I know + def frobnicate(self, *args: int) -> None: pass # type: ignore[override] # I know [builtins fixtures/dict.pyi] [out1] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: tmp/b.py:3: note: def frobnicate(self) -> None [out2] tmp/b.py:3: error: Signature of "frobnicate" incompatible with supertype "Foo" tmp/b.py:3: note: Superclass: -tmp/b.py:3: note: def frobnicate(self, x: Any, *args: Any, **kwargs: Any) -> Any +tmp/b.py:3: note: def frobnicate(self, x: str, *args: Any, **kwargs: Any) -> Any tmp/b.py:3: note: Subclass: -tmp/b.py:3: note: def frobnicate(self, *args: Any) -> None +tmp/b.py:3: note: def frobnicate(self, *args: int) -> None