diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index 82eea32a31d8..afc3fe9754f1 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -131,3 +131,13 @@ def KwArg(type=Any): # Return type that indicates a function does not return class NoReturn: pass + + +def declared_type(t): + """Declare the type of a declaration. + + This is useful for declaring a more specific type for a decorated class or + function definition than the decorator provides as a return value. + """ + # Return the identity function -- calling this should be a noop + return lambda __x: __x diff --git a/mypy/checker.py b/mypy/checker.py index 98bb35474bf9..79f300baa630 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -2266,8 +2266,14 @@ def visit_decorator(self, e: Decorator) -> None: [nodes.ARG_POS], e, callable_name=fullname) sig = cast(FunctionLike, sig) - sig = set_callable_name(sig, e.func) - e.var.type = sig + if e.var.type is not None: + # We have a declared type, check it. + self.check_subtype(sig, e.var.type, e, + subtype_label="inferred decorated type", + supertype_label="declared decorated type") + else: + e.var.type = sig + e.var.type = set_callable_name(e.var.type, e.func) e.var.is_ready = True if e.func.is_property: self.check_incompatible_property_override(e) diff --git a/mypy/semanal.py b/mypy/semanal.py index 523edc8563e0..29c55813c4dc 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -2633,6 +2633,19 @@ def visit_decorator(self, dec: Decorator) -> None: elif refers_to_fullname(d, 'typing.no_type_check'): dec.var.type = AnyType() no_type_check = True + elif isinstance(d, CallExpr) and ( + refers_to_fullname(d.callee, 'typing.declared_type') + or refers_to_fullname(d.callee, 'mypy_extensions.declared_type')): + removed.append(i) + if i != 0: + self.fail('"declared_type" must be the topmost decorator', d) + elif len(d.args) != 1: + self.fail('"declared_type" takes exactly one argument', d) + else: + dec.var.type = self.expr_to_analyzed_type(d.args[0]) + elif (refers_to_fullname(d, 'typing.declared_type') or + refers_to_fullname(d, 'mypy_extensions.declared_type')): + self.fail('"declared_type" must have a type as an argument', d) for i in reversed(removed): del dec.decorators[i] if not dec.is_overload or dec.var.is_property: @@ -3838,6 +3851,13 @@ def visit_decorator(self, dec: Decorator) -> None: engine just for decorators. """ super().visit_decorator(dec) + if dec.var.type is not None: + # We already have a declared type for this decorated thing. + return + if dec.func.is_awaitable_coroutine: + # The type here will be fixed up by checker.py, but we can't infer + # anything here. + return if dec.var.is_property: # Decorators are expected to have a callable type (it's a little odd). if dec.func.type is None: diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index fc049053f855..9e901af721c3 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -734,6 +734,188 @@ a = None # type: A a.f() a.f(None) # E: Too many arguments for "f" of "A" +[case testMethodWithDeclaredDecoratedType] + +from typing import Callable, Any +from mypy_extensions import declared_type + +def dec(f): pass + +# Note that the decorated type must account for the `self` argument -- It's applied pre-binding +class Foo: + @declared_type(Callable[[Any, int], str]) + @dec + def f(self): pass + +foo = Foo() + +foo.f("a") # E: Argument 1 to "f" of "Foo" has incompatible type "str"; expected "int" +x: str = foo.f(1) +y: int = foo.f(1) # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +reveal_type(foo.f) # E: Revealed type is 'def (builtins.int) -> builtins.str' + +[builtins fixtures/dict.pyi] + +[case testClassMethodWithDeclaredDecoratedType] + +from typing import Callable, Any +from mypy_extensions import declared_type + +def dec(f): pass + +# Note that the decorated type must account for the `cls` argument -- It's applied pre-binding +class Foo: + @declared_type(Callable[[Any, int], str]) + @classmethod + @dec + def f(cls): pass + + +Foo.f("a") # E: Argument 1 to "f" of "Foo" has incompatible type "str"; expected "int" +x: str = Foo.f(1) +y: int = Foo.f(1) # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +reveal_type(Foo.f) # E: Revealed type is 'def (builtins.int) -> builtins.str' + +[builtins fixtures/dict.pyi] + +[case testStaticMethodWithDeclaredDecoratedType] + +from typing import Callable +from mypy_extensions import declared_type + +def dec(f): pass + +class Foo: + @declared_type(Callable[[int], str]) + @staticmethod + @dec + def f(): pass + + +Foo.f("a") # E: Argument 1 to "f" of "Foo" has incompatible type "str"; expected "int" +x: str = Foo.f(1) +y: int = Foo.f(1) # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +reveal_type(Foo.f) # E: Revealed type is 'def (builtins.int) -> builtins.str' + +[builtins fixtures/dict.pyi] + +[case testUntypedDecoratorWithDeclaredType] + +from typing import Callable +from mypy_extensions import declared_type + +def dec(f): pass + +@declared_type(Callable[[int], str]) +@dec +def f(): pass + +f("a") # E: Argument 1 to "f" has incompatible type "str"; expected "int" +x: str = f(1) +y: int = f(1) # E: Incompatible types in assignment (expression has type "str", variable has type "int") + +reveal_type(f) # E: Revealed type is 'def (builtins.int) -> builtins.str' + +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeBare] + +from typing import Callable +from mypy_extensions import declared_type + +@declared_type(Callable[[int], str]) +def f(x): pass + +f("a") # E: Argument 1 to "f" has incompatible type "str"; expected "int" +x: str = f(1) +y: int = f(1) # E: Incompatible types in assignment (expression has type "str", variable has type "int") +reveal_type(f) # E: Revealed type is 'def (builtins.int) -> builtins.str' + +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeIncompatible] + +from typing import Callable, TypeVar +from mypy_extensions import declared_type + +T = TypeVar('T') + +def dec(f: T) -> T: pass + +@declared_type(Callable[[int], str]) # E: Incompatible types (inferred decorated type Callable[[str], str], declared decorated type Callable[[int], str]) +@dec +def f(x: str) -> str: pass + +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeCompatible] + +from typing import Callable, TypeVar +from mypy_extensions import declared_type + +T = TypeVar('T') + +def dec(f: T) -> T: pass + +class A: pass +class B(A): pass + +@declared_type(Callable[[B], str]) +@dec +def f(x: A) -> str: pass + +reveal_type(f) # E: Revealed type is 'def (__main__.B) -> builtins.str' +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeNotTop] + +from typing import Callable +from mypy_extensions import declared_type + +def dec(f): pass + +@dec +@declared_type(Callable[[int], str]) # E: "declared_type" must be the topmost decorator +def f(): pass + +reveal_type(f) # E: Revealed type is 'Any' + +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeNoArgs] + +from typing import Callable +from mypy_extensions import declared_type + +def dec(f): pass + +@declared_type() # E: "declared_type" takes exactly one argument +@dec +def f(): pass + +reveal_type(f) # E: Revealed type is 'Any' + +[builtins fixtures/dict.pyi] + +[case testDecoratorWithDeclaredTypeNoCall] + +from typing import Callable +from mypy_extensions import declared_type + +def dec(f): pass + +@declared_type # E: "declared_type" must have a type as an argument +@dec +def f(): pass + +# NB: The revealed type below is technically correct, as weird as it looks +reveal_type(f) # E: Revealed type is 'def [T] (T`-1) -> T`-1' + +[builtins fixtures/dict.pyi] + [case testNestedDecorators] from typing import Any, Callable def dec1(f: Callable[[Any], None]) -> Callable[[], None]: pass diff --git a/test-data/unit/fixtures/dict.pyi b/test-data/unit/fixtures/dict.pyi index e920512274dd..84d74d2642f1 100644 --- a/test-data/unit/fixtures/dict.pyi +++ b/test-data/unit/fixtures/dict.pyi @@ -41,3 +41,7 @@ class bool: pass class ellipsis: pass class BaseException: pass + +# Because all tests that use mypy_extensions need dict, this is easier. +classmethod = object() +staticmethod = object() diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index fa540b99f4cd..518686391d74 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -1,4 +1,4 @@ -from typing import Dict, Type, TypeVar, Optional, Any +from typing import Dict, Type, TypeVar, Callable, Any, Optional _T = TypeVar('_T') @@ -19,3 +19,5 @@ def KwArg(type: _T = ...) -> _T: ... def TypedDict(typename: str, fields: Dict[str, Type[_T]]) -> Type[dict]: ... class NoReturn: pass + +def declared_type(t: Any) -> Callable[[T], T]: pass