@@ -317,97 +317,136 @@ def test_shutdown_all_methods_in_one_thread(self):
317317 def test_shutdown_immediate_all_methods_in_one_thread (self ):
318318 return self ._shutdown_all_methods_in_one_thread (True )
319319
320- def _write_msg_thread (self , q , n , results , delay ,
321- i_when_exec_shutdown ,
322- event_start , event_end ):
323- event_start .wait ()
324- for i in range (1 , n + 1 ):
325- try :
326- q .put ((i , "YDLO" ))
327- results .append (True )
328- except self .queue .ShutDown :
329- results .append (False )
330- # triggers shutdown of queue
331- if i == i_when_exec_shutdown :
332- event_end .set ()
333- time .sleep (delay )
334- # end of all puts
335- q .join ()
320+ def _shutdown_all_methods_in_many_threads (self , immediate ):
321+ # Arrange
322+ q = self .type2test ()
323+
324+ start_puts = threading .Event ()
325+ start_gets = threading .Event ()
326+ put = threading .Event ()
327+ shutdown = threading .Event ()
328+
329+ n_gets_lock = threading .Lock ()
330+ n_gets = 0
336331
337- def _read_msg_thread (self , q , nb , results , delay , event_start ):
338- event_start .wait ()
339- block = True
340- while nb :
341- time .sleep (delay )
332+ calls = []
333+ results = []
334+ queue_size_after_join = []
335+
336+ def _record_call (f , * a ):
337+ calls .append ((f , a ))
338+ return f (* a )
339+
340+ def _record_result (f ):
342341 try :
343- # Get at least one message
344- q .get (block )
345- block = False
346- q .task_done ()
347- results .append (True )
348- nb -= 1
349- except self .queue .ShutDown :
350- results .append (False )
351- nb -= 1
352- except self .queue .Empty :
353- pass
354- q .join ()
342+ result = f ()
343+ except Exception as e :
344+ results .append ((f , e ))
345+ else :
346+ results .append ((f , result ))
355347
356- def _shutdown_thread (self , q , event_end , immediate ):
357- event_end .wait ()
358- q .shutdown (immediate )
359- q .join ()
348+ def put_worker ():
349+ start_puts .wait ()
360350
361- def _join_thread (self , q , delay , event_start ):
362- event_start .wait ()
363- time .sleep (delay )
364- q .join ()
351+ for i in range (5 ):
352+ _record_call (q .put , i )
365353
366- def _shutdown_all_methods_in_many_threads (self , immediate ):
367- q = self .type2test ()
368- ps = []
369- ev_start = threading .Event ()
370- ev_exec_shutdown = threading .Event ()
371- res_puts = []
372- res_gets = []
373- delay = 1e-4
374- read_process = 4
375- nb_msgs = read_process * 16
376- nb_msgs_r = nb_msgs // read_process
377- when_exec_shutdown = nb_msgs // 2
378- lprocs = (
379- (self ._write_msg_thread , 1 , (q , nb_msgs , res_puts , delay ,
380- when_exec_shutdown ,
381- ev_start , ev_exec_shutdown )),
382- (self ._read_msg_thread , read_process , (q , nb_msgs_r ,
383- res_gets , delay * 2 ,
384- ev_start )),
385- (self ._join_thread , 2 , (q , delay * 2 , ev_start )),
386- (self ._shutdown_thread , 1 , (q , ev_exec_shutdown , immediate )),
387- )
388- # start all threds
389- for func , n , args in lprocs :
390- for i in range (n ):
391- ps .append (threading .Thread (target = func , args = args ))
392- ps [- 1 ].start ()
393- # set event in order to run q.shutdown()
394- ev_start .set ()
354+ start_gets .set ()
395355
396- if not immediate :
397- assert (len (res_gets ) == len (res_puts ))
398- assert (res_gets .count (True ) == res_puts .count (True ))
399- else :
400- assert (len (res_gets ) <= len (res_puts ))
401- assert (res_gets .count (True ) <= res_puts .count (True ))
356+ for i in range (5 , 25 ):
357+ put .wait ()
358+ _record_call (q .put , i )
359+ put .clear ()
360+
361+ shutdown .set ()
362+
363+ # Should raise ShutDown
364+ put .wait ()
365+ _record_call (q .put , 25 )
366+
367+ def get_worker ():
368+ nonlocal n_gets
369+
370+ start_gets .wait ()
402371
403- for thread in ps [1 :]:
372+ while True :
373+ with n_gets_lock :
374+ if n_gets >= 25 :
375+ break
376+ n_gets += 1
377+
378+ put .set ()
379+ _record_call (q .get , False )
380+
381+ put .set ()
382+ _record_call (q .get , False ) # should raise ShutDown if immediate
383+
384+ def join_worker ():
385+ start_gets .wait ()
386+ _record_call (q .join )
387+ queue_size_after_join .append (q .qsize ())
388+
389+ def shutdown_worker ():
390+ shutdown .wait ()
391+ _record_call (q .shutdown , immediate )
392+
393+ def _start_thread (f ):
394+ thread = threading .Thread (target = _record_result , args = (f ,))
395+ thread .start ()
396+ return thread
397+
398+ threads = [
399+ _start_thread (put_worker ),
400+ * (_start_thread (get_worker ) for _ in range (4 )),
401+ * (_start_thread (join_worker ) for _ in range (2 )),
402+ _start_thread (shutdown_worker ),
403+ ]
404+
405+ # Act
406+ start_puts .set ()
407+ shutdown .wait ()
408+ for thread in threads :
404409 thread .join ()
405410
406- @unittest .skip ("test times out (gh-115258)" )
411+ # Assert
412+ self .assertEqual (q .qsize (), 0 )
413+
414+ if immediate :
415+ self .assertTrue (all (qs > 0 for qs in queue_size_after_join ))
416+ else :
417+ self .assertTrue (all (qs == 0 for qs in queue_size_after_join ))
418+
419+ self .assertListEqual (
420+ [a for f , a in calls if f is q .put ], [(i ,) for i in range (33 )]
421+ )
422+ self .assertListEqual (
423+ [a for f , a in calls if f is q .get ], [(False ,)] * 36
424+ )
425+ self .assertListEqual ([a for f , a in calls if f is q .join ], [(), ()])
426+ self .assertListEqual (
427+ [a for f , a in calls if f is q .shutdown ], [immediate ]
428+ )
429+
430+ put_worker_result = next (r for f , r in results if f is put_worker )
431+ self .assertIs (put_worker_result .__class__ , self .queue .ShutDown )
432+
433+ get_worker_results = [r for f , r in results if f is get_worker ]
434+ if immediate :
435+ self .assertListEqual (get_worker_results , [self .queue .ShutDown ] * 4 )
436+ else :
437+ self .assertListEqual (get_worker_results , [None ] * 4 )
438+
439+ join_worker_results = [r for f , r in results if f is join_worker ]
440+ self .assertListEqual (join_worker_results , [None , None ])
441+
442+ shutdown_worker_result = next (
443+ r for f , r in results if f is shutdown_worker
444+ )
445+ self .assertIsNone (shutdown_worker_result , None )
446+
407447 def test_shutdown_all_methods_in_many_threads (self ):
408448 return self ._shutdown_all_methods_in_many_threads (False )
409449
410- @unittest .skip ("test times out (gh-115258)" )
411450 def test_shutdown_immediate_all_methods_in_many_threads (self ):
412451 return self ._shutdown_all_methods_in_many_threads (True )
413452
0 commit comments