Support allow_paid_broadcast in AIORateLimiter #4627

Merged (8 commits) on Jan 23, 2025
60 changes: 41 additions & 19 deletions telegram/ext/_aioratelimiter.py
@@ -32,6 +32,7 @@
except ImportError:
AIO_LIMITER_AVAILABLE = False

from telegram import constants
from telegram._utils.logging import get_logger
from telegram._utils.types import JSONDict
from telegram.error import RetryAfter
@@ -86,7 +87,8 @@ class AIORateLimiter(BaseRateLimiter[int]):
* A :exc:`~telegram.error.RetryAfter` exception will halt *all* requests for
:attr:`~telegram.error.RetryAfter.retry_after` + 0.1 seconds. This may be stricter than
necessary in some cases, e.g. the bot may hit a rate limit in one group but might still
be allowed to send messages in another group.
be allowed to send messages in another group or with
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` set to :obj:`True`.

Tip:
With `Bot API 7.11 <https://core.telegram.org/bots/api-changelog#october-31-2024>`_
@@ -96,10 +98,10 @@ class AIORateLimiter(BaseRateLimiter[int]):
:tg-const:`telegram.constants.FloodLimit.PAID_MESSAGES_PER_SECOND` messages per second by
paying a fee in Telegram Stars.

.. caution::
This class currently doesn't take the
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account.
This means that the rate limiting is applied just like for any other message.
.. versionchanged:: NEXT.VERSION
This class automatically takes the
:paramref:`~telegram.Bot.send_message.allow_paid_broadcast` parameter into account and
throttles the requests accordingly.

Note:
This class is to be understood as minimal effort reference implementation.
@@ -114,23 +116,25 @@ class AIORateLimiter(BaseRateLimiter[int]):
Args:
overall_max_rate (:obj:`float`): The maximum number of requests allowed for the entire bot
per :paramref:`overall_time_period`. When set to 0, no rate limiting will be applied.
Defaults to ``30``.
Defaults to :tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_SECOND`.
overall_time_period (:obj:`float`): The time period (in seconds) during which the
:paramref:`overall_max_rate` is enforced. When set to 0, no rate limiting will be
applied. Defaults to 1.
applied. Defaults to ``1``.
group_max_rate (:obj:`float`): The maximum number of requests allowed for requests related
to groups and channels per :paramref:`group_time_period`. When set to 0, no rate
limiting will be applied. Defaults to 20.
limiting will be applied. Defaults to
:tg-const:`telegram.constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP`.
group_time_period (:obj:`float`): The time period (in seconds) during which the
:paramref:`group_max_rate` is enforced. When set to 0, no rate limiting will be
applied. Defaults to 60.
applied. Defaults to ``60``.
max_retries (:obj:`int`): The maximum number of retries to be made in case of a
:exc:`~telegram.error.RetryAfter` exception.
If set to 0, no retries will be made. Defaults to ``0``.

"""

__slots__ = (
"_apb_limiter",
"_base_limiter",
"_group_limiters",
"_group_max_rate",
Expand All @@ -141,9 +145,9 @@ class AIORateLimiter(BaseRateLimiter[int]):

def __init__(
self,
overall_max_rate: float = 30,
overall_max_rate: float = constants.FloodLimit.MESSAGES_PER_SECOND,
overall_time_period: float = 1,
group_max_rate: float = 20,
group_max_rate: float = constants.FloodLimit.MESSAGES_PER_MINUTE_PER_GROUP,
group_time_period: float = 60,
max_retries: int = 0,
) -> None:
@@ -167,6 +171,9 @@ def __init__(
self._group_time_period = 0

self._group_limiters: dict[Union[str, int], AsyncLimiter] = {}
self._apb_limiter: AsyncLimiter = AsyncLimiter(
max_rate=constants.FloodLimit.PAID_MESSAGES_PER_SECOND, time_period=1
)
self._max_retries: int = max_retries
self._retry_after_event = asyncio.Event()
self._retry_after_event.set()
@@ -201,21 +208,30 @@ async def _run_request(
self,
chat: bool,
group: Union[str, int, bool],
allow_paid_broadcast: bool,
callback: Callable[..., Coroutine[Any, Any, Union[bool, JSONDict, list[JSONDict]]]],
args: Any,
kwargs: dict[str, Any],
) -> Union[bool, JSONDict, list[JSONDict]]:
base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
group_context = (
self._get_group_limiter(group) if group and self._group_max_rate else null_context()
)

async with group_context, base_context:
async def inner() -> Union[bool, JSONDict, list[JSONDict]]:
# In case a retry_after was hit, we wait with processing the request
await self._retry_after_event.wait()

return await callback(*args, **kwargs)

if allow_paid_broadcast:
async with self._apb_limiter:
return await inner()
else:
base_context = self._base_limiter if (chat and self._base_limiter) else null_context()
group_context = (
self._get_group_limiter(group)
if group and self._group_max_rate
else null_context()
)

async with group_context, base_context:
return await inner()

# mypy doesn't understand that the last run of the for loop raises an exception
async def process_request(
self,
@@ -242,6 +258,7 @@ async def process_request(
group: Union[int, str, bool] = False
chat: bool = False
chat_id = data.get("chat_id")
allow_paid_broadcast = data.get("allow_paid_broadcast", False)
if chat_id is not None:
chat = True

@@ -257,7 +274,12 @@
for i in range(max_retries + 1):
try:
return await self._run_request(
chat=chat, group=group, callback=callback, args=args, kwargs=kwargs
chat=chat,
group=group,
allow_paid_broadcast=allow_paid_broadcast,
callback=callback,
args=args,
kwargs=kwargs,
)
except RetryAfter as exc:
if i == max_retries:
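The change above routes every request whose payload sets `allow_paid_broadcast` through the new dedicated `_apb_limiter` instead of the overall and per-group limiters. Below is a minimal usage sketch; it is not part of the diff, the token and chat id are placeholders, and `AIORateLimiter` requires installing python-telegram-bot with the optional `rate-limiter` extra (aiolimiter).

```python
# Minimal usage sketch (not part of this PR). "TOKEN" and chat_id=123 are
# placeholders; AIORateLimiter needs the optional aiolimiter dependency.
import asyncio

from telegram.ext import AIORateLimiter, ExtBot


async def main() -> None:
    bot = ExtBot(token="TOKEN", rate_limiter=AIORateLimiter())
    async with bot:
        # Throttled by the overall limiter (plus the group limiter for
        # group/channel chat ids), as before.
        await bot.send_message(chat_id=123, text="regular message")
        # With this PR, throttled by the dedicated allow_paid_broadcast limiter
        # (up to FloodLimit.PAID_MESSAGES_PER_SECOND requests per second) instead.
        await bot.send_message(
            chat_id=123, text="paid broadcast", allow_paid_broadcast=True
        )


if __name__ == "__main__":
    asyncio.run(main())
```

The `RetryAfter` handling is unchanged: the retry loop in `process_request` and the shared `_retry_after_event` wait apply to both paths.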
74 changes: 70 additions & 4 deletions tests/ext/test_ratelimiter.py
@@ -26,6 +26,7 @@
import json
import platform
import time
from collections import Counter
from http import HTTPStatus

import pytest
@@ -148,7 +149,9 @@ async def do_request(self, *args, **kwargs):
@pytest.mark.flaky(10, 1) # Timings aren't quite perfect
class TestAIORateLimiter:
count = 0
apb_count = 0
call_times = []
apb_call_times = []

class CountRequest(BaseRequest):
def __init__(self, retry_after=None):
@@ -161,8 +164,16 @@ async def shutdown(self) -> None:
pass

async def do_request(self, *args, **kwargs):
TestAIORateLimiter.count += 1
TestAIORateLimiter.call_times.append(time.time())
request_data = kwargs.get("request_data")
allow_paid_broadcast = request_data.parameters.get("allow_paid_broadcast", False)

if allow_paid_broadcast:
TestAIORateLimiter.apb_count += 1
TestAIORateLimiter.apb_call_times.append(time.time())
else:
TestAIORateLimiter.count += 1
TestAIORateLimiter.call_times.append(time.time())

if self.retry_after:
raise RetryAfter(retry_after=1)

@@ -190,10 +201,10 @@ async def do_request(self, *args, **kwargs):

@pytest.fixture(autouse=True)
def _reset(self):
self.count = 0
TestAIORateLimiter.count = 0
self.call_times = []
TestAIORateLimiter.call_times = []
TestAIORateLimiter.apb_count = 0
TestAIORateLimiter.apb_call_times = []

@pytest.mark.parametrize("max_retries", [0, 1, 4])
async def test_max_retries(self, bot, max_retries):
@@ -358,3 +369,58 @@ async def test_group_caching(self, bot, intermediate):
finally:
TestAIORateLimiter.count = 0
TestAIORateLimiter.call_times = []

async def test_allow_paid_broadcast(self, bot):
try:
rl_bot = ExtBot(
token=bot.token,
request=self.CountRequest(retry_after=None),
rate_limiter=AIORateLimiter(),
)

async with rl_bot:
apb_tasks = {}
non_apb_tasks = {}
for i in range(3000):
apb_tasks[i] = asyncio.create_task(
rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=True)
)

number = 2
for i in range(number):
non_apb_tasks[i] = asyncio.create_task(
rl_bot.send_message(chat_id=-1, text="test")
)
non_apb_tasks[i + number] = asyncio.create_task(
rl_bot.send_message(chat_id=-1, text="test", allow_paid_broadcast=False)
)

await asyncio.sleep(0.1)
# We expect 5 non-apb requests:
# 1: `get_me` from `async with rl_bot`
# 2-5: `send_message`
assert TestAIORateLimiter.count == 5
assert sum(1 for task in non_apb_tasks.values() if task.done()) == 4

# ~2 second after start
# We do the checks once all apb_tasks are done as apparently getting the timings
# right to check after 1 second is hard
await asyncio.sleep(2.1 - 0.1)
assert all(task.done() for task in apb_tasks.values())

apb_call_times = [
ct - TestAIORateLimiter.apb_call_times[0]
for ct in TestAIORateLimiter.apb_call_times
]
apb_call_times_dict = Counter(map(int, apb_call_times))

# We expect ~2000 apb requests after the first second
# 2000 (>>1000), since we have a floating window logic such that an initial
# burst is allowed that is hard to measure in the tests
assert apb_call_times_dict[0] <= 2000
assert apb_call_times_dict[0] + apb_call_times_dict[1] < 3000
assert sum(apb_call_times_dict.values()) == 3000

finally:
# cleanup
await asyncio.gather(*apb_tasks.values(), *non_apb_tasks.values())
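The assertions above are deliberately loose because of how `aiolimiter` schedules acquisitions. The standalone sketch below, which is not part of this PR, illustrates that behaviour under the assumption that `AsyncLimiter` implements a leaky bucket allowing an initial burst of up to `max_rate` acquisitions.

```python
# Standalone sketch (not part of this PR) of the AsyncLimiter behaviour the test
# above relies on: an initial burst of up to max_rate passes immediately, after
# which acquisitions are throttled to max_rate / time_period.
import asyncio
import time
from collections import Counter

from aiolimiter import AsyncLimiter

from telegram import constants


async def main() -> None:
    limiter = AsyncLimiter(
        max_rate=constants.FloodLimit.PAID_MESSAGES_PER_SECOND, time_period=1
    )
    start = time.monotonic()
    finished_per_second: Counter = Counter()
    for _ in range(3000):
        async with limiter:
            finished_per_second[int(time.monotonic() - start)] += 1
    # Roughly 2 * max_rate acquisitions land in second 0 (initial burst plus the
    # steady leak rate), then about max_rate per second afterwards, which is why
    # the test tolerates up to 2000 calls in the first second.
    print(finished_per_second)


if __name__ == "__main__":
    asyncio.run(main())
```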