@@ -527,8 +527,14 @@ class _QueueShutdownTestMixin:
527527
528528 @staticmethod
529529 async def _ping_awaitable (a ):
530+ async def swallow (a_ ):
531+ try :
532+ return await a_
533+ except Exception :
534+ pass
535+
530536 try :
531- await asyncio .wait_for (asyncio .shield (a ), 0.01 )
537+ await asyncio .wait_for (asyncio .shield (swallow ( a ) ), 0.01 )
532538 except TimeoutError :
533539 pass
534540
@@ -550,6 +556,7 @@ async def test_shutdown_empty(self):
550556
551557 await self ._ping_awaitable (join_task )
552558 self .assertTrue (join_task .done ())
559+ await join_task
553560
554561 with self .assertRaisesShutdown ():
555562 await q .put ("data" )
@@ -561,17 +568,26 @@ async def test_shutdown_empty(self):
561568 with self .assertRaisesShutdown ():
562569 q .get_nowait ()
563570
564- await join_task
565-
566571 async def test_shutdown_nonempty (self ):
567- q = self .q_class ()
572+ q = self .q_class (maxsize = 1 )
568573 loop = asyncio .get_running_loop ()
574+
569575 q .put_nowait ("data" )
570576 join_task = loop .create_task (q .join ())
577+ put_task = loop .create_task (q .put ("data2" ))
578+
579+ await self ._ping_awaitable (put_task )
580+ self .assertFalse (put_task .done ())
581+
571582 q .shutdown (immediate = False ) # unfinished tasks: 1 -> 1
572583
573584 self .assertEqual (q .qsize (), 1 )
574585
586+ await self ._ping_awaitable (put_task )
587+ self .assertTrue (put_task .done ())
588+ with self .assertRaisesShutdown ():
589+ await put_task
590+
575591 self .assertEqual (await q .get (), "data" )
576592
577593 await self ._ping_awaitable (join_task )
@@ -604,6 +620,7 @@ async def test_shutdown_immediate(self):
604620
605621 await self ._ping_awaitable (join_task )
606622 self .assertTrue (join_task .done ())
623+ await join_task
607624
608625 with self .assertRaisesShutdown ():
609626 await q .put ("data" )
@@ -620,8 +637,6 @@ async def test_shutdown_immediate(self):
620637 ):
621638 q .task_done ()
622639
623- await join_task
624-
625640 async def test_shutdown_immediate_with_unfinished (self ):
626641 q = self .q_class ()
627642 loop = asyncio .get_running_loop ()
@@ -631,6 +646,8 @@ async def test_shutdown_immediate_with_unfinished(self):
631646 self .assertEqual (await q .get (), "data" )
632647 q .shutdown (immediate = True ) # unfinished tasks: 2 -> 1
633648
649+ self .assertEqual (q .qsize (), 0 )
650+
634651 await self ._ping_awaitable (join_task )
635652 self .assertFalse (join_task .done ())
636653
0 commit comments