Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 646aa7e

Browse files
gh-90155: Fix bug in asyncio.Semaphore and strengthen FIFO guarantee (GH-93222)
The main problem was that an unluckily timed task cancellation could cause the semaphore to be stuck. There were also doubts about strict FIFO ordering of tasks allowed to pass. The Semaphore implementation was rewritten to be more similar to Lock. Many tests for edge cases (including cancellation) were added. (cherry picked from commit 24e0379) Co-authored-by: Cyker Way <[email protected]>
1 parent c967049 commit 646aa7e

File tree

3 files changed

+143
-22
lines changed

3 files changed

+143
-22
lines changed

Lib/asyncio/locks.py

Lines changed: 42 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -349,9 +349,8 @@ def __init__(self, value=1, *, loop=mixins._marker):
349349
super().__init__(loop=loop)
350350
if value < 0:
351351
raise ValueError("Semaphore initial value must be >= 0")
352+
self._waiters = None
352353
self._value = value
353-
self._waiters = collections.deque()
354-
self._wakeup_scheduled = False
355354

356355
def __repr__(self):
357356
res = super().__repr__()
@@ -360,16 +359,8 @@ def __repr__(self):
360359
extra = f'{extra}, waiters:{len(self._waiters)}'
361360
return f'<{res[1:-1]} [{extra}]>'
362361

363-
def _wake_up_next(self):
364-
while self._waiters:
365-
waiter = self._waiters.popleft()
366-
if not waiter.done():
367-
waiter.set_result(None)
368-
self._wakeup_scheduled = True
369-
return
370-
371362
def locked(self):
372-
"""Returns True if semaphore can not be acquired immediately."""
363+
"""Returns True if semaphore counter is zero."""
373364
return self._value == 0
374365

375366
async def acquire(self):
@@ -381,28 +372,57 @@ async def acquire(self):
381372
called release() to make it larger than 0, and then return
382373
True.
383374
"""
384-
# _wakeup_scheduled is set if *another* task is scheduled to wakeup
385-
# but its acquire() is not resumed yet
386-
while self._wakeup_scheduled or self._value <= 0:
387-
fut = self._get_loop().create_future()
388-
self._waiters.append(fut)
375+
if (not self.locked() and (self._waiters is None or
376+
all(w.cancelled() for w in self._waiters))):
377+
self._value -= 1
378+
return True
379+
380+
if self._waiters is None:
381+
self._waiters = collections.deque()
382+
fut = self._get_loop().create_future()
383+
self._waiters.append(fut)
384+
385+
# Finally block should be called before the CancelledError
386+
# handling as we don't want CancelledError to call
387+
# _wake_up_first() and attempt to wake up itself.
388+
try:
389389
try:
390390
await fut
391-
# reset _wakeup_scheduled *after* waiting for a future
392-
self._wakeup_scheduled = False
393-
except exceptions.CancelledError:
394-
self._wake_up_next()
395-
raise
391+
finally:
392+
self._waiters.remove(fut)
393+
except exceptions.CancelledError:
394+
if not self.locked():
395+
self._wake_up_first()
396+
raise
397+
396398
self._value -= 1
399+
if not self.locked():
400+
self._wake_up_first()
397401
return True
398402

399403
def release(self):
400404
"""Release a semaphore, incrementing the internal counter by one.
405+
401406
When it was zero on entry and another coroutine is waiting for it to
402407
become larger than zero again, wake up that coroutine.
403408
"""
404409
self._value += 1
405-
self._wake_up_next()
410+
self._wake_up_first()
411+
412+
def _wake_up_first(self):
413+
"""Wake up the first waiter if it isn't done."""
414+
if not self._waiters:
415+
return
416+
try:
417+
fut = next(iter(self._waiters))
418+
except StopIteration:
419+
return
420+
421+
# .done() necessarily means that a waiter will wake up later on and
422+
# either take the lock, or, if it was cancelled and lock wasn't
423+
# taken already, will hit this again and wake up a new waiter.
424+
if not fut.done():
425+
fut.set_result(True)
406426

407427

408428
class BoundedSemaphore(Semaphore):

Lib/test/test_asyncio/test_locks.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66

77
import asyncio
8+
import collections
89

910
STR_RGX_REPR = (
1011
r'^<(?P<class>.*?) object at (?P<address>.*?)'
@@ -782,6 +783,9 @@ async def test_repr(self):
782783
self.assertTrue('waiters' not in repr(sem))
783784
self.assertTrue(RGX_REPR.match(repr(sem)))
784785

786+
if sem._waiters is None:
787+
sem._waiters = collections.deque()
788+
785789
sem._waiters.append(mock.Mock())
786790
self.assertTrue('waiters:1' in repr(sem))
787791
self.assertTrue(RGX_REPR.match(repr(sem)))
@@ -855,6 +859,7 @@ async def c4(result):
855859
sem.release()
856860
self.assertEqual(2, sem._value)
857861

862+
await asyncio.sleep(0)
858863
await asyncio.sleep(0)
859864
self.assertEqual(0, sem._value)
860865
self.assertEqual(3, len(result))
@@ -897,6 +902,7 @@ async def test_acquire_cancel_before_awoken(self):
897902
t2.cancel()
898903
sem.release()
899904

905+
await asyncio.sleep(0)
900906
await asyncio.sleep(0)
901907
num_done = sum(t.done() for t in [t3, t4])
902908
self.assertEqual(num_done, 1)
@@ -917,9 +923,32 @@ async def test_acquire_hang(self):
917923
t1.cancel()
918924
sem.release()
919925
await asyncio.sleep(0)
926+
await asyncio.sleep(0)
920927
self.assertTrue(sem.locked())
921928
self.assertTrue(t2.done())
922929

930+
async def test_acquire_no_hang(self):
931+
932+
sem = asyncio.Semaphore(1)
933+
934+
async def c1():
935+
async with sem:
936+
await asyncio.sleep(0)
937+
t2.cancel()
938+
939+
async def c2():
940+
async with sem:
941+
self.assertFalse(True)
942+
943+
t1 = asyncio.create_task(c1())
944+
t2 = asyncio.create_task(c2())
945+
946+
r1, r2 = await asyncio.gather(t1, t2, return_exceptions=True)
947+
self.assertTrue(r1 is None)
948+
self.assertTrue(isinstance(r2, asyncio.CancelledError))
949+
950+
await asyncio.wait_for(sem.acquire(), timeout=1.0)
951+
923952
def test_release_not_acquired(self):
924953
sem = asyncio.BoundedSemaphore()
925954

@@ -959,6 +988,77 @@ async def coro(tag):
959988
result
960989
)
961990

991+
async def test_acquire_fifo_order_2(self):
992+
sem = asyncio.Semaphore(1)
993+
result = []
994+
995+
async def c1(result):
996+
await sem.acquire()
997+
result.append(1)
998+
return True
999+
1000+
async def c2(result):
1001+
await sem.acquire()
1002+
result.append(2)
1003+
sem.release()
1004+
await sem.acquire()
1005+
result.append(4)
1006+
return True
1007+
1008+
async def c3(result):
1009+
await sem.acquire()
1010+
result.append(3)
1011+
return True
1012+
1013+
t1 = asyncio.create_task(c1(result))
1014+
t2 = asyncio.create_task(c2(result))
1015+
t3 = asyncio.create_task(c3(result))
1016+
1017+
await asyncio.sleep(0)
1018+
1019+
sem.release()
1020+
sem.release()
1021+
1022+
tasks = [t1, t2, t3]
1023+
await asyncio.gather(*tasks)
1024+
self.assertEqual([1, 2, 3, 4], result)
1025+
1026+
async def test_acquire_fifo_order_3(self):
1027+
sem = asyncio.Semaphore(0)
1028+
result = []
1029+
1030+
async def c1(result):
1031+
await sem.acquire()
1032+
result.append(1)
1033+
return True
1034+
1035+
async def c2(result):
1036+
await sem.acquire()
1037+
result.append(2)
1038+
return True
1039+
1040+
async def c3(result):
1041+
await sem.acquire()
1042+
result.append(3)
1043+
return True
1044+
1045+
t1 = asyncio.create_task(c1(result))
1046+
t2 = asyncio.create_task(c2(result))
1047+
t3 = asyncio.create_task(c3(result))
1048+
1049+
await asyncio.sleep(0)
1050+
1051+
t1.cancel()
1052+
1053+
await asyncio.sleep(0)
1054+
1055+
sem.release()
1056+
sem.release()
1057+
1058+
tasks = [t1, t2, t3]
1059+
await asyncio.gather(*tasks, return_exceptions=True)
1060+
self.assertEqual([2, 3], result)
1061+
9621062

9631063
if __name__ == '__main__':
9641064
unittest.main()
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix broken :class:`asyncio.Semaphore` when acquire is cancelled.

0 commit comments

Comments
 (0)