Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit f431ae3

Browse files
committed
Cancel local task processes when a worker is stopped by a command or a SIG[INT|TERM|TSTP] signal
1 parent 466fc35 commit f431ae3

9 files changed

Lines changed: 132 additions & 83 deletions

File tree

crates/tako/src/internal/worker/rpc.rs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::future::Future;
22
use std::net::{Ipv4Addr, SocketAddr};
33
use std::rc::Rc;
44
use std::sync::Arc;
5-
use std::time::Duration;
5+
use std::time::{Duration, Instant};
66

77
use bytes::{Bytes, BytesMut};
88
use futures::{SinkExt, Stream, StreamExt};
@@ -221,7 +221,8 @@ pub async fn run_worker(
221221
_ = overview_fut => { unreachable!() }
222222
};
223223

224-
match result {
224+
// Handle sending stop info to the server and finishing running tasks gracefully.
225+
let result = match result {
225226
Ok(Some(msg)) => {
226227
// Worker wants to end gracefully, send message to the server
227228
{
@@ -239,11 +240,42 @@ pub async fn run_worker(
239240
// Server has disconnected
240241
tokio::select! {
241242
_ = &mut try_start_tasks => { unreachable!() }
242-
r = finish_tasks_on_server_lost(state) => r
243+
r = finish_tasks_on_server_lost(state.clone()) => r
243244
}
244245
Err(e)
245246
}
247+
};
248+
249+
// At this point, there can still be some tasks that are running.
250+
// We cancel them here to make sure that we do not leak their spawned processes, if possible.
251+
// The futures of the tasks are scheduled onto the current tokio Runtime using spawn_local,
252+
// therefore we do not need to await any specific future to drive them forward.
253+
// try_start_tasks is not being polled, therefore no new tasks should be started.
254+
{
255+
let mut state_mut = state.get_mut();
256+
for task in state_mut.running_tasks.clone() {
257+
state_mut.cancel_task(task);
258+
}
246259
}
260+
261+
let start = Instant::now();
262+
loop {
263+
if state.get().running_tasks.is_empty() {
264+
break;
265+
}
266+
267+
// Do not wait for the tasks forever
268+
if start.elapsed() > Duration::from_secs(5) {
269+
break;
270+
}
271+
272+
log::info!("Waiting for tasks to be shut down...");
273+
// The await will drive the event loop forward, giving the task futures a chance
274+
// to remove themselves from state.running_tasks
275+
tokio::time::sleep(Duration::from_secs(1)).await;
276+
}
277+
278+
result
247279
};
248280

249281
// Provide a local task set for spawning futures

tests/conftest.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ def is_process_alive():
6464
]
6565
return False
6666
raise Exception(f"Process with pid {process.pid} not found")
67+
6768
wait_until(lambda: not is_process_alive())
6869

6970
def check_running_processes(self):
@@ -90,13 +91,13 @@ def get_processes_by_name(
9091
if n == name:
9192
yield (i, p)
9293

93-
def kill_process(self, name: str):
94+
def kill_process(self, name: str, signal: int = signal.SIGTERM) -> subprocess.Popen:
9495
for i, p in self.get_processes_by_name(name):
9596
del self.processes[i]
9697
# Kill the whole group since the process may spawn a child
9798
if p.returncode is None and not p.poll():
98-
os.killpg(os.getpgid(p.pid), signal.SIGTERM)
99-
return
99+
os.killpg(os.getpgid(p.pid), signal)
100+
return p
100101
else:
101102
raise Exception("Process not found")
102103

@@ -220,14 +221,16 @@ def stop_server(self):
220221
def kill_server(self):
221222
self.kill_process("server")
222223

223-
def kill_worker(self, worker_id: int):
224+
def kill_worker(self, worker_id: int, signal: int = signal.SIGTERM, wait=True):
224225
table = self.command(["worker", "info", str(worker_id)], as_table=True)
225226
pid = table.get_row_value("Process pid")
226227
process = self.find_process_by_pid(int(pid))
227228
if process is None:
228229
raise Exception(f"Worker {worker_id} not found")
229230

230-
self.kill_process(process[0])
231+
process = self.kill_process(process[0], signal=signal)
232+
if wait:
233+
wait_until(lambda: process.poll() is not None)
231234

232235
def find_process_by_pid(self, pid: int) -> Optional[Tuple[str, subprocess.Popen]]:
233236
for name, process in self.processes:

tests/pyapi/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
import os.path
22
from typing import List, Tuple
33

4+
from hyperqueue import LocalCluster
45
from hyperqueue.client import Client
56
from hyperqueue.job import Job
67

7-
from hyperqueue import LocalCluster
8-
98
from ..conftest import HqEnv
109
from ..utils.mock import ProgramMock
1110

tests/pyapi/test_cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
23
from hyperqueue.cluster import LocalCluster, WorkerConfig
34
from hyperqueue.job import Job
45

tests/pyapi/test_function.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
import pytest
6+
67
from hyperqueue.client import Client, FailedJobsException, PythonEnv
78
from hyperqueue.ffi.protocol import ResourceRequest
89
from hyperqueue.job import Job

tests/pyapi/test_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import iso8601
66
import pytest
7+
78
from hyperqueue.client import FailedJobsException
89
from hyperqueue.ffi.protocol import ResourceRequest
910
from hyperqueue.job import Job

tests/test_job.py

Lines changed: 70 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
import os
2+
import signal
3+
import subprocess
24
import time
35
from datetime import datetime
46
from os.path import isdir, isfile, join
57
from pathlib import Path
8+
from typing import Callable, Optional
69

710
import pytest
811

@@ -12,7 +15,7 @@
1215
from .utils.io import check_file_contents, read_file
1316
from .utils.job import default_task_output, list_jobs
1417
from .utils.table import parse_multiline_cell
15-
from .utils.wait import wait_for_pid_exit, wait_until
18+
from .utils.wait import wait_for_pid_exit, wait_for_worker_state, wait_until
1619

1720

1821
def test_job_submit(hq_env: HqEnv):
@@ -468,38 +471,11 @@ def test_cancel_all(hq_env: HqEnv):
468471

469472

470473
def test_cancel_terminate_process_children(hq_env: HqEnv):
471-
hq_env.start_server()
472-
hq_env.start_worker()
474+
def cancel(worker_process):
475+
hq_env.command(["job", "cancel", "1"])
476+
wait_for_job_state(hq_env, 1, "CANCELED")
473477

474-
hq_env.command(
475-
[
476-
"submit",
477-
"--",
478-
*python(
479-
"""
480-
import os
481-
import sys
482-
import time
483-
print(os.getpid(), flush=True)
484-
pid = os.fork()
485-
if pid > 0:
486-
print(pid, flush=True)
487-
time.sleep(3600)
488-
"""
489-
),
490-
]
491-
)
492-
wait_for_job_state(hq_env, 1, "RUNNING")
493-
wait_until(lambda: len(read_file(default_task_output()).splitlines()) == 2)
494-
495-
hq_env.command(["job", "cancel", "1"])
496-
wait_for_job_state(hq_env, 1, "CANCELED")
497-
498-
pids = [int(pid) for pid in read_file(default_task_output()).splitlines()]
499-
500-
parent, child = pids
501-
wait_for_pid_exit(parent)
502-
wait_for_pid_exit(child)
478+
check_child_process_exited(hq_env, cancel)
503479

504480

505481
def test_cancel_send_sigint(hq_env: HqEnv):
@@ -582,8 +558,8 @@ def signal_handler(sig, frame):
582558
def test_reporting_state_after_worker_lost(hq_env: HqEnv):
583559
hq_env.start_server()
584560
hq_env.start_workers(2, cpus=1)
585-
hq_env.command(["submit", "sleep", "1"])
586-
hq_env.command(["submit", "sleep", "1"])
561+
hq_env.command(["submit", "sleep", "2"])
562+
hq_env.command(["submit", "sleep", "2"])
587563

588564
wait_for_job_state(hq_env, [1, 2], "RUNNING")
589565

@@ -592,15 +568,17 @@ def test_reporting_state_after_worker_lost(hq_env: HqEnv):
592568
table.check_column_value("State", 1, "RUNNING")
593569
hq_env.kill_worker(1)
594570

595-
time.sleep(0.25)
571+
def task_is_waiting():
572+
table = list_jobs(hq_env)
573+
if table.get_column_value("State")[0] == "WAITING":
574+
return 0, 1
575+
elif table.get_column_value("State")[1] == "WAITING":
576+
return 1, 0
577+
else:
578+
return None
596579

580+
idx, other = wait_until(task_is_waiting)
597581
table = list_jobs(hq_env)
598-
if table.get_column_value("State")[0] == "WAITING":
599-
idx, other = 0, 1
600-
elif table.get_column_value("State")[1] == "WAITING":
601-
idx, other = 1, 0
602-
else:
603-
assert 0
604582
assert table.get_column_value("State")[other] == "RUNNING"
605583

606584
wait_for_job_state(hq_env, other + 1, "FINISHED")
@@ -1422,20 +1400,22 @@ def test_zero_custom_error_message(hq_env: HqEnv):
14221400

14231401

14241402
@pytest.mark.parametrize("count", [None, 1, 7])
1425-
def test_crashing_job_status_default(count, hq_env: HqEnv):
1403+
def test_crashing_job_status_default(count: Optional[int], hq_env: HqEnv):
14261404
hq_env.start_server()
14271405

1428-
if count:
1429-
hq_env.command(["submit", f"--crash-limit={count}", "sleep", "10"])
1430-
else:
1431-
# Crashing tasks threshold is 5 by default
1432-
hq_env.command(["submit", "sleep", "10"])
1433-
count = 5
1406+
count = count if count is not None else 5
1407+
1408+
hq_env.command(["submit", f"--crash-limit={count}", "sleep", "10"])
14341409

14351410
for i in range(count):
14361411
hq_env.start_worker()
14371412
wait_for_job_state(hq_env, 1, "RUNNING")
14381413
hq_env.kill_worker(i + 1)
1414+
if i < count - 1:
1415+
wait_for_job_state(hq_env, 1, "WAITING")
1416+
1417+
wait_for_job_state(hq_env, 1, "FAILED")
1418+
14391419
table = list_jobs(hq_env)
14401420
table.check_column_value("State", 0, "FAILED")
14411421

@@ -1494,11 +1474,39 @@ def get_pid():
14941474
wait_for_pid_exit(pid)
14951475

14961476

1497-
# TODO: fix this somehow
1477+
def test_kill_task_subprocess_when_worker_is_interrupted(hq_env: HqEnv):
1478+
def interrupt_worker(worker_process):
1479+
hq_env.kill_worker(1, signal=signal.SIGINT)
1480+
1481+
check_child_process_exited(hq_env, interrupt_worker)
1482+
1483+
14981484
@pytest.mark.xfail
1499-
def test_kill_task_subprocess_when_worker_dies(hq_env: HqEnv):
1485+
def test_kill_task_subprocess_when_worker_is_terminated(hq_env: HqEnv):
1486+
def terminate_worker(worker_process):
1487+
hq_env.kill_worker(1, signal=signal.SIGTERM)
1488+
1489+
check_child_process_exited(hq_env, terminate_worker)
1490+
1491+
1492+
def test_kill_task_subprocess_when_worker_is_stopped(hq_env: HqEnv):
1493+
def stop_worker(worker_process):
1494+
hq_env.command(["worker", "stop", "1"])
1495+
wait_for_worker_state(hq_env, 1, "STOPPED")
1496+
hq_env.check_process_exited(worker_process)
1497+
1498+
check_child_process_exited(hq_env, stop_worker)
1499+
1500+
1501+
def check_child_process_exited(
1502+
hq_env: HqEnv, stop_fn: Callable[[subprocess.Popen], None]
1503+
):
1504+
"""
1505+
Creates a task that spawns a child, and then calls `stop_fn`, which should kill either the task
1506+
or the worker. The function then checks that both the task process and its child have been killed.
1507+
"""
15001508
hq_env.start_server()
1501-
hq_env.start_worker()
1509+
worker_process = hq_env.start_worker()
15021510

15031511
hq_env.command(
15041512
[
@@ -1507,28 +1515,23 @@ def test_kill_task_subprocess_when_worker_dies(hq_env: HqEnv):
15071515
*python(
15081516
"""
15091517
import os
1518+
import sys
15101519
import time
1511-
1512-
child_pid = os.fork()
1513-
if child_pid == 0:
1514-
time.sleep(3600)
1515-
else:
1516-
print(child_pid, flush=True)
1520+
print(os.getpid(), flush=True)
1521+
pid = os.fork()
1522+
if pid > 0:
1523+
print(pid, flush=True)
15171524
time.sleep(3600)
15181525
"""
15191526
),
15201527
]
15211528
)
15221529
wait_for_job_state(hq_env, 1, "RUNNING")
1530+
wait_until(lambda: len(read_file(default_task_output()).splitlines()) == 2)
1531+
pids = [int(pid) for pid in read_file(default_task_output()).splitlines()]
15231532

1524-
def get_pid():
1525-
pid = read_file(default_task_output()).strip()
1526-
if not pid:
1527-
return None
1528-
return int(pid)
1529-
1530-
pid = wait_until(get_pid)
1531-
1532-
hq_env.kill_worker(1)
1533+
stop_fn(worker_process)
15331534

1534-
wait_for_pid_exit(pid)
1535+
parent, child = pids
1536+
wait_for_pid_exit(parent)
1537+
wait_for_pid_exit(child)

tests/test_jobfile.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def test_job_file_resource_variants1(hq_env: HqEnv, tmp_path):
105105
"""
106106
[[task]]
107107
id = 0
108-
command = ["sleep", "0"]
108+
command = ["sleep", "1"]
109109
110110
[[task.request]]
111111
resources = { "cpus" = "8" }
@@ -116,6 +116,7 @@ def test_job_file_resource_variants1(hq_env: HqEnv, tmp_path):
116116
)
117117
hq_env.command(["job", "submit-file", "job.toml"])
118118

119+
wait_for_job_state(hq_env, 1, "RUNNING")
119120
table = hq_env.command(["task", "info", "1", "0"], as_table=True)
120121
table.check_row_value(
121122
"Resources",

0 commit comments

Comments
 (0)