11import os
2+ import signal
3+ import subprocess
24import time
35from datetime import datetime
46from os .path import isdir , isfile , join
57from pathlib import Path
8+ from typing import Callable , Optional
69
710import pytest
811
1215from .utils .io import check_file_contents , read_file
1316from .utils .job import default_task_output , list_jobs
1417from .utils .table import parse_multiline_cell
15- from .utils .wait import wait_for_pid_exit , wait_until
18+ from .utils .wait import wait_for_pid_exit , wait_for_worker_state , wait_until
1619
1720
1821def test_job_submit (hq_env : HqEnv ):
@@ -468,38 +471,11 @@ def test_cancel_all(hq_env: HqEnv):
468471
469472
def test_cancel_terminate_process_children(hq_env: HqEnv):
    """
    Cancel a running job and verify, via `check_child_process_exited`, that the
    task process and the child it forked are both gone afterwards.
    """

    def cancel_job(_worker_process):
        # The worker handle is not needed: the job is cancelled through the CLI
        # and the worker itself keeps running.
        hq_env.command(["job", "cancel", "1"])
        wait_for_job_state(hq_env, 1, "CANCELED")

    check_child_process_exited(hq_env, stop_fn=cancel_job)
503479
504480
505481def test_cancel_send_sigint (hq_env : HqEnv ):
@@ -582,8 +558,8 @@ def signal_handler(sig, frame):
582558def test_reporting_state_after_worker_lost (hq_env : HqEnv ):
583559 hq_env .start_server ()
584560 hq_env .start_workers (2 , cpus = 1 )
585- hq_env .command (["submit" , "sleep" , "1 " ])
586- hq_env .command (["submit" , "sleep" , "1 " ])
561+ hq_env .command (["submit" , "sleep" , "2 " ])
562+ hq_env .command (["submit" , "sleep" , "2 " ])
587563
588564 wait_for_job_state (hq_env , [1 , 2 ], "RUNNING" )
589565
@@ -592,15 +568,17 @@ def test_reporting_state_after_worker_lost(hq_env: HqEnv):
592568 table .check_column_value ("State" , 1 , "RUNNING" )
593569 hq_env .kill_worker (1 )
594570
595- time .sleep (0.25 )
571+ def task_is_waiting ():
572+ table = list_jobs (hq_env )
573+ if table .get_column_value ("State" )[0 ] == "WAITING" :
574+ return 0 , 1
575+ elif table .get_column_value ("State" )[1 ] == "WAITING" :
576+ return 1 , 0
577+ else :
578+ return None
596579
580+ idx , other = wait_until (task_is_waiting )
597581 table = list_jobs (hq_env )
598- if table .get_column_value ("State" )[0 ] == "WAITING" :
599- idx , other = 0 , 1
600- elif table .get_column_value ("State" )[1 ] == "WAITING" :
601- idx , other = 1 , 0
602- else :
603- assert 0
604582 assert table .get_column_value ("State" )[other ] == "RUNNING"
605583
606584 wait_for_job_state (hq_env , other + 1 , "FINISHED" )
@@ -1422,20 +1400,22 @@ def test_zero_custom_error_message(hq_env: HqEnv):
14221400
14231401
@pytest.mark.parametrize("count", [None, 1, 7])
def test_crashing_job_status_default(count: Optional[int], hq_env: HqEnv):
    """
    Check that a job becomes FAILED after its task was running on `count`
    workers that were killed one after another.

    `count` is the value passed as `--crash-limit` on submit; `None` is mapped
    to 5 (presumably the server's default crash limit - confirm against the CLI).
    """
    hq_env.start_server()

    crash_limit = 5 if count is None else count

    # The argument must be exactly `--crash-limit=<n>` with no surrounding
    # whitespace, otherwise the value would carry a trailing space.
    hq_env.command(["submit", f"--crash-limit={crash_limit}", "sleep", "10"])

    # Worker ids passed to `kill_worker` are 1-based and follow start order.
    for worker_id in range(1, crash_limit + 1):
        hq_env.start_worker()
        wait_for_job_state(hq_env, 1, "RUNNING")
        hq_env.kill_worker(worker_id)
        if worker_id < crash_limit:
            # Below the limit the job is expected to go back to WAITING
            # before the next worker picks it up.
            wait_for_job_state(hq_env, 1, "WAITING")

    wait_for_job_state(hq_env, 1, "FAILED")

    table = list_jobs(hq_env)
    table.check_column_value("State", 0, "FAILED")
14411421
@@ -1494,11 +1474,39 @@ def get_pid():
14941474 wait_for_pid_exit (pid )
14951475
14961476
1497- # TODO: fix this somehow
def test_kill_task_subprocess_when_worker_is_interrupted(hq_env: HqEnv):
    """
    Send SIGINT to worker 1 and verify, via `check_child_process_exited`, that
    the task process and its forked child both exit.
    """

    def send_sigint(_worker_process):
        # The worker is addressed by id, so the process handle is unused.
        hq_env.kill_worker(1, signal=signal.SIGINT)

    check_child_process_exited(hq_env, stop_fn=send_sigint)
1482+
1483+
@pytest.mark.xfail
def test_kill_task_subprocess_when_worker_is_terminated(hq_env: HqEnv):
    """
    Send SIGTERM to worker 1 and check that the task process and its forked
    child both exit. Marked `xfail`, i.e. this is currently expected to fail.
    """

    def send_sigterm(_worker_process):
        # The worker is addressed by id, so the process handle is unused.
        hq_env.kill_worker(1, signal=signal.SIGTERM)

    check_child_process_exited(hq_env, stop_fn=send_sigterm)
1490+
1491+
def test_kill_task_subprocess_when_worker_is_stopped(hq_env: HqEnv):
    """
    Stop worker 1 through `worker stop` and verify, via
    `check_child_process_exited`, that the task process and its forked child
    both exit.
    """

    def stop_via_cli(worker_handle):
        hq_env.command(["worker", "stop", "1"])
        wait_for_worker_state(hq_env, 1, "STOPPED")
        # Also make sure the worker process itself has terminated.
        hq_env.check_process_exited(worker_handle)

    check_child_process_exited(hq_env, stop_fn=stop_via_cli)
1499+
1500+
1501+ def check_child_process_exited (
1502+ hq_env : HqEnv , stop_fn : Callable [[subprocess .Popen ], None ]
1503+ ):
1504+ """
1505+ Creates a task that spawns a child, and then calls `stop_fn`, which should kill either the task
1506+ or the worker. The function then checks that both the task process and its child have been killed.
1507+ """
15001508 hq_env .start_server ()
1501- hq_env .start_worker ()
1509+ worker_process = hq_env .start_worker ()
15021510
15031511 hq_env .command (
15041512 [
@@ -1507,28 +1515,23 @@ def test_kill_task_subprocess_when_worker_dies(hq_env: HqEnv):
15071515 * python (
15081516 """
15091517import os
1518+ import sys
15101519import time
1511-
1512- child_pid = os.fork()
1513- if child_pid == 0:
1514- time.sleep(3600)
1515- else:
1516- print(child_pid, flush=True)
1520+ print(os.getpid(), flush=True)
1521+ pid = os.fork()
1522+ if pid > 0:
1523+ print(pid, flush=True)
15171524time.sleep(3600)
15181525"""
15191526 ),
15201527 ]
15211528 )
15221529 wait_for_job_state (hq_env , 1 , "RUNNING" )
1530+ wait_until (lambda : len (read_file (default_task_output ()).splitlines ()) == 2 )
1531+ pids = [int (pid ) for pid in read_file (default_task_output ()).splitlines ()]
15231532
1524- def get_pid ():
1525- pid = read_file (default_task_output ()).strip ()
1526- if not pid :
1527- return None
1528- return int (pid )
1529-
1530- pid = wait_until (get_pid )
1531-
1532- hq_env .kill_worker (1 )
1533+ stop_fn (worker_process )
15331534
1534- wait_for_pid_exit (pid )
1535+ parent , child = pids
1536+ wait_for_pid_exit (parent )
1537+ wait_for_pid_exit (child )
0 commit comments