diff --git a/releasenotes/notes/fix-async-retry-type-overloads-27f3e0c239ed6b.yaml b/releasenotes/notes/fix-async-retry-type-overloads-27f3e0c239ed6b.yaml new file mode 100644 index 00000000..c5127669 --- /dev/null +++ b/releasenotes/notes/fix-async-retry-type-overloads-27f3e0c239ed6b.yaml @@ -0,0 +1,9 @@ +--- +fixes: + - | + The ``@retry`` decorator's type overloads for the ``sleep=`` parameter + (e.g. ``sleep=trio.sleep``) have been improved. Previously, the + async-sleep overload used ``R | Awaitable[R]`` as the return type + bound, which was ambiguous: for ``async def f() -> T``, pyright could + infer ``R = Coroutine[Any, Any, T]`` instead of ``R = T``, producing + false-positive type errors in downstream code. diff --git a/tenacity/__init__.py b/tenacity/__init__.py index e734d6f5..0410031a 100644 --- a/tenacity/__init__.py +++ b/tenacity/__init__.py @@ -20,6 +20,7 @@ import sys import threading import time +import types import typing as t import warnings from abc import ABC, abstractmethod @@ -87,8 +88,6 @@ tornado = None if t.TYPE_CHECKING: - import types - from typing_extensions import Self from . import asyncio as tasyncio @@ -585,6 +584,23 @@ def __repr__(self) -> str: return f"<{clsname} {id(self)}: attempt #{self.attempt_number}; slept for {slept}; last result: {result}>" +class _AsyncRetryDecorator(t.Protocol): + @t.overload + def __call__( + self, fn: "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]" + ) -> "t.Callable[P, types.CoroutineType[t.Any, t.Any, R]]": ... + @t.overload + def __call__( + self, fn: t.Callable[P, t.Coroutine[t.Any, t.Any, R]] + ) -> t.Callable[P, t.Coroutine[t.Any, t.Any, R]]: ... + @t.overload + def __call__( + self, fn: t.Callable[P, t.Awaitable[R]] + ) -> t.Callable[P, t.Awaitable[R]]: ... + @t.overload + def __call__(self, fn: t.Callable[P, R]) -> t.Callable[P, t.Awaitable[R]]: ... + + @t.overload def retry(func: WrappedFn) -> WrappedFn: ... @@ -606,7 +622,7 @@ def retry( retry_error_callback: t.Optional[ t.Callable[["RetryCallState"], t.Union[t.Any, t.Awaitable[t.Any]]] ] = ..., -) -> t.Callable[[t.Callable[P, R | t.Awaitable[R]]], t.Callable[P, t.Awaitable[R]]]: ... +) -> _AsyncRetryDecorator: ... @t.overload