diff --git a/mypy/stubtest.py b/mypy/stubtest.py index e67992ace5f8..546ea96dd9a0 100644 --- a/mypy/stubtest.py +++ b/mypy/stubtest.py @@ -676,6 +676,18 @@ def _verify_signature( yield 'runtime does not have **kwargs argument "{}"'.format(stub.varkw.variable.name) +def _verify_coroutine( + stub: nodes.FuncItem, runtime: Any, *, runtime_is_coroutine: bool +) -> Optional[str]: + if stub.is_coroutine: + if not runtime_is_coroutine: + return 'is an "async def" function in the stub, but not at runtime' + else: + if runtime_is_coroutine: + return 'is an "async def" function at runtime, but not in the stub' + return None + + @verify.register(nodes.FuncItem) def verify_funcitem( stub: nodes.FuncItem, runtime: MaybeMissing[Any], object_path: List[str] @@ -693,19 +705,40 @@ def verify_funcitem( yield Error(object_path, "is inconsistent, " + message, stub, runtime) signature = safe_inspect_signature(runtime) + runtime_is_coroutine = inspect.iscoroutinefunction(runtime) + + if signature: + stub_sig = Signature.from_funcitem(stub) + runtime_sig = Signature.from_inspect_signature(signature) + runtime_sig_desc = f'{"async " if runtime_is_coroutine else ""}def {signature}' + else: + runtime_sig_desc = None + + coroutine_mismatch_error = _verify_coroutine( + stub, + runtime, + runtime_is_coroutine=runtime_is_coroutine + ) + + if coroutine_mismatch_error is not None: + yield Error( + object_path, + coroutine_mismatch_error, + stub, + runtime, + runtime_desc=runtime_sig_desc + ) + if not signature: return - stub_sig = Signature.from_funcitem(stub) - runtime_sig = Signature.from_inspect_signature(signature) - for message in _verify_signature(stub_sig, runtime_sig, function_name=stub.name): yield Error( object_path, "is inconsistent, " + message, stub, runtime, - runtime_desc="def " + str(signature), + runtime_desc=runtime_sig_desc, ) diff --git a/mypy/test/teststubtest.py b/mypy/test/teststubtest.py index 2852299548ed..78ae82b058cd 100644 --- a/mypy/test/teststubtest.py +++ b/mypy/test/teststubtest.py @@ -28,6 +28,34 @@ def use_tmp_dir() -> Iterator[None]: TEST_MODULE_NAME = "test_module" + +stubtest_typing_stub = """ +Any = object() + +class _SpecialForm: + def __getitem__(self, typeargs: Any) -> object: ... + +Callable: _SpecialForm = ... +Generic: _SpecialForm = ... + +class TypeVar: + def __init__(self, name, covariant: bool = ..., contravariant: bool = ...) -> None: ... + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_K = TypeVar("_K") +_V = TypeVar("_V") +_S = TypeVar("_S", contravariant=True) +_R = TypeVar("_R", covariant=True) + +class Coroutine(Generic[_T_co, _S, _R]): ... +class Iterable(Generic[_T_co]): ... +class Mapping(Generic[_K, _V]): ... +class Sequence(Iterable[_T_co]): ... +class Tuple(Sequence[_T_co]): ... +def overload(func: _T) -> _T: ... +""" + stubtest_builtins_stub = """ from typing import Generic, Mapping, Sequence, TypeVar, overload @@ -66,6 +94,8 @@ def run_stubtest( with use_tmp_dir(): with open("builtins.pyi", "w") as f: f.write(stubtest_builtins_stub) + with open("typing.pyi", "w") as f: + f.write(stubtest_typing_stub) with open("{}.pyi".format(TEST_MODULE_NAME), "w") as f: f.write(stub) with open("{}.py".format(TEST_MODULE_NAME), "w") as f: @@ -172,6 +202,29 @@ class X: error="X.mistyped_var", ) + @collect_cases + def test_coroutines(self) -> Iterator[Case]: + yield Case( + stub="async def foo() -> int: ...", + runtime="def foo(): return 5", + error="foo", + ) + yield Case( + stub="def bar() -> int: ...", + runtime="async def bar(): return 5", + error="bar", + ) + yield Case( + stub="def baz() -> int: ...", + runtime="def baz(): return 5", + error=None, + ) + yield Case( + stub="async def bingo() -> int: ...", + runtime="async def bingo(): return 5", + error=None, + ) + @collect_cases def test_arg_name(self) -> Iterator[Case]: yield Case(