Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 0b53b86

Browse files
Faster acquire (#218)
* Faster acquire * Addressed PR comments * Mock out loop time to make test faster * Add changelog entry --------- Co-authored-by: Martijn Pieters <[email protected]>
1 parent 2684871 commit 0b53b86

File tree

3 files changed

+54
-19
lines changed

3 files changed

+54
-19
lines changed

changelog.d/217.bugfix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed wait time calculation for waiting tasks, making acquisition faster (PR by @schoennenbeck)

src/aiolimiter/leakybucket.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,9 +102,13 @@ async def acquire(self, amount: float = 1) -> None:
102102
fut = loop.create_future()
103103
self._waiters[task] = fut
104104
try:
105-
await asyncio.wait_for(
106-
asyncio.shield(fut), 1 / self._rate_per_sec * amount
105+
# we need to wait until current_capacity is equal or greater than the
106+
# requested amount, where current_capacity is
107+
# (self.max_rate - self._level)
108+
wait_time = (
109+
1 / self._rate_per_sec * (amount - self.max_rate + self._level)
107110
)
111+
await asyncio.wait_for(asyncio.shield(fut), wait_time)
108112
except asyncio.TimeoutError:
109113
pass
110114
fut.cancel()

tests/test_aiolimiter.py

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
# Licensed under the MIT license as detailed in LICENSE.txt
44

55
import asyncio
6-
import time
76
from pathlib import Path
87
from unittest import mock
98

@@ -12,7 +11,8 @@
1211

1312
from aiolimiter import AsyncLimiter
1413

15-
WAIT_LIMIT = 2 # seconds before we declare the test failed
14+
# max number of wait_for rounds when waiting for all events to settle
15+
MAX_WAIT_FOR_ITER = 5
1616

1717

1818
def test_version():
@@ -32,14 +32,16 @@ def test_version():
3232

3333
async def wait_for_n_done(tasks, n):
3434
"""Wait for n (or more) tasks to have completed"""
35-
start = time.time()
36-
pending = tasks
35+
iteration = 0
3736
remainder = len(tasks) - n
38-
while time.time() <= start + WAIT_LIMIT and len(pending) > remainder:
37+
while iteration <= MAX_WAIT_FOR_ITER:
38+
iteration += 1
3939
_, pending = await asyncio.wait(
40-
tasks, timeout=WAIT_LIMIT, return_when=asyncio.FIRST_COMPLETED
40+
tasks, timeout=0, return_when=asyncio.FIRST_COMPLETED
4141
)
42-
assert len(pending) >= remainder
42+
if len(pending) <= remainder:
43+
break
44+
assert len(pending) <= remainder
4345
return pending
4446

4547

@@ -73,33 +75,61 @@ async def async_contextmanager_task(limiter):
7375
pass
7476

7577

76-
@pytest.mark.parametrize("task", [acquire_task, async_contextmanager_task])
77-
async def test_acquire(task):
78-
current_time = 0
78+
class MockLoopTime:
79+
def __init__(self):
80+
self.current_time = 0
81+
event_loop = asyncio.get_running_loop()
82+
self.patch = mock.patch.object(event_loop, "time", self.mocked_time)
83+
84+
def mocked_time(self):
85+
return self.current_time
7986

80-
def mocked_time():
81-
return current_time
87+
def __enter__(self):
88+
self.patch.start()
89+
return self
8290

91+
def __exit__(self, *_):
92+
self.patch.stop()
93+
94+
95+
@pytest.mark.parametrize("task", [acquire_task, async_contextmanager_task])
96+
async def test_acquire(task):
8397
# capacity released every 2 seconds
8498
limiter = AsyncLimiter(5, 10)
8599

86-
event_loop = asyncio.get_running_loop()
87-
with mock.patch.object(event_loop, "time", mocked_time):
100+
with MockLoopTime() as mocked_time:
88101
tasks = [asyncio.ensure_future(task(limiter)) for _ in range(10)]
89102

90103
pending = await wait_for_n_done(tasks, 5)
91104
assert len(pending) == 5
92105

93-
current_time = 3 # releases capacity for one and some buffer
106+
mocked_time.current_time = 3 # releases capacity for one and some buffer
94107
assert limiter.has_capacity()
95108

96109
pending = await wait_for_n_done(pending, 1)
97110
assert len(pending) == 4
98111

99-
current_time = 7 # releases capacity for two more, plus buffer
112+
mocked_time.current_time = 7 # releases capacity for two more, plus buffer
100113
pending = await wait_for_n_done(pending, 2)
101114
assert len(pending) == 2
102115

103-
current_time = 11 # releases the remainder
116+
mocked_time.current_time = 11 # releases the remainder
104117
pending = await wait_for_n_done(pending, 2)
105118
assert len(pending) == 0
119+
120+
121+
async def test_acquire_wait_time():
122+
limiter = AsyncLimiter(3, 3)
123+
124+
with MockLoopTime() as mocked_time:
125+
# Fill the bucket with an amount of 1
126+
await limiter.acquire(1)
127+
128+
# Acquiring an amount of 3 now should take 1 second
129+
task = asyncio.ensure_future(limiter.acquire(3))
130+
pending = await wait_for_n_done([task], 0)
131+
assert pending
132+
133+
mocked_time.current_time = 1
134+
pending = await wait_for_n_done([task], 1)
135+
assert not pending

0 commit comments

Comments
 (0)