3
3
# Licensed under the MIT license as detailed in LICENSE.txt
4
4
5
5
import asyncio
6
- import time
7
6
from pathlib import Path
8
7
from unittest import mock
9
8
12
11
13
12
from aiolimiter import AsyncLimiter
14
13
15
- WAIT_LIMIT = 2 # seconds before we declare the test failed
14
+ # max number of wait_for rounds when waiting for all events to settle
15
+ MAX_WAIT_FOR_ITER = 5
16
16
17
17
18
18
def test_version ():
@@ -32,14 +32,16 @@ def test_version():
32
32
33
33
async def wait_for_n_done (tasks , n ):
34
34
"""Wait for n (or more) tasks to have completed"""
35
- start = time .time ()
36
- pending = tasks
35
+ iteration = 0
37
36
remainder = len (tasks ) - n
38
- while time .time () <= start + WAIT_LIMIT and len (pending ) > remainder :
37
+ while iteration <= MAX_WAIT_FOR_ITER :
38
+ iteration += 1
39
39
_ , pending = await asyncio .wait (
40
- tasks , timeout = WAIT_LIMIT , return_when = asyncio .FIRST_COMPLETED
40
+ tasks , timeout = 0 , return_when = asyncio .FIRST_COMPLETED
41
41
)
42
- assert len (pending ) >= remainder
42
+ if len (pending ) <= remainder :
43
+ break
44
+ assert len (pending ) <= remainder
43
45
return pending
44
46
45
47
@@ -73,33 +75,61 @@ async def async_contextmanager_task(limiter):
73
75
pass
74
76
75
77
76
- @pytest .mark .parametrize ("task" , [acquire_task , async_contextmanager_task ])
77
- async def test_acquire (task ):
78
- current_time = 0
78
+ class MockLoopTime :
79
+ def __init__ (self ):
80
+ self .current_time = 0
81
+ event_loop = asyncio .get_running_loop ()
82
+ self .patch = mock .patch .object (event_loop , "time" , self .mocked_time )
83
+
84
+ def mocked_time (self ):
85
+ return self .current_time
79
86
80
- def mocked_time ():
81
- return current_time
87
+ def __enter__ (self ):
88
+ self .patch .start ()
89
+ return self
82
90
91
+ def __exit__ (self , * _ ):
92
+ self .patch .stop ()
93
+
94
+
95
+ @pytest .mark .parametrize ("task" , [acquire_task , async_contextmanager_task ])
96
+ async def test_acquire (task ):
83
97
# capacity released every 2 seconds
84
98
limiter = AsyncLimiter (5 , 10 )
85
99
86
- event_loop = asyncio .get_running_loop ()
87
- with mock .patch .object (event_loop , "time" , mocked_time ):
100
+ with MockLoopTime () as mocked_time :
88
101
tasks = [asyncio .ensure_future (task (limiter )) for _ in range (10 )]
89
102
90
103
pending = await wait_for_n_done (tasks , 5 )
91
104
assert len (pending ) == 5
92
105
93
- current_time = 3 # releases capacity for one and some buffer
106
+ mocked_time . current_time = 3 # releases capacity for one and some buffer
94
107
assert limiter .has_capacity ()
95
108
96
109
pending = await wait_for_n_done (pending , 1 )
97
110
assert len (pending ) == 4
98
111
99
- current_time = 7 # releases capacity for two more, plus buffer
112
+ mocked_time . current_time = 7 # releases capacity for two more, plus buffer
100
113
pending = await wait_for_n_done (pending , 2 )
101
114
assert len (pending ) == 2
102
115
103
- current_time = 11 # releases the remainder
116
+ mocked_time . current_time = 11 # releases the remainder
104
117
pending = await wait_for_n_done (pending , 2 )
105
118
assert len (pending ) == 0
119
+
120
+
121
+ async def test_acquire_wait_time ():
122
+ limiter = AsyncLimiter (3 , 3 )
123
+
124
+ with MockLoopTime () as mocked_time :
125
+ # Fill the bucket with an amount of 1
126
+ await limiter .acquire (1 )
127
+
128
+ # Acquiring an amount of 3 now should take 1 second
129
+ task = asyncio .ensure_future (limiter .acquire (3 ))
130
+ pending = await wait_for_n_done ([task ], 0 )
131
+ assert pending
132
+
133
+ mocked_time .current_time = 1
134
+ pending = await wait_for_n_done ([task ], 1 )
135
+ assert not pending
0 commit comments