diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 7145b2a0..57f41d80 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import datetime
 import os
 import socket
 import sys
@@ -33,6 +34,7 @@ class WorkerArgs:
     local_world_size: int
     world_size: int
     log_file: os.PathLike
+    timeout: int
 
     def to_bytes(self) -> bytes:
         return cloudpickle.dumps(self)
@@ -81,7 +83,11 @@ def entrypoint(serialized_worker_args: bytes):
     if backend is None:
         backend = "nccl" if torch.cuda.is_available() else "gloo"
     dist.init_process_group(
-        backend=backend, world_size=worker_args.world_size, rank=worker_args.rank, store=store
+        backend=backend,
+        world_size=worker_args.world_size,
+        rank=worker_args.rank,
+        store=store,
+        timeout=datetime.timedelta(seconds=worker_args.timeout),
     )
 
     os.environ["RANK"] = str(worker_args.rank)
@@ -130,6 +136,7 @@ def main(launcher_agent_group: LauncherAgentGroup):
                     local_world_size=num_workers,
                     world_size=worker_world_size,
                     log_file=worker_log_files[i],
+                    timeout=launcher_payload.timeout,
                 ).to_bytes(),
             )
             for i in range(num_workers)
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 9f26b7b8..346f303b 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -96,6 +96,7 @@ class Launcher:
         ]
     )
     env_file: str | os.PathLike | None = None
+    timeout: int = 600
 
     def run(
         self,
@@ -209,6 +210,7 @@
             worker_global_ranks=worker_global_ranks,
             worker_log_files=worker_log_files,
             backend=self.backend,
+            timeout=self.timeout,
         )
 
         agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:]  # pyright: ignore[reportAssignmentType]
@@ -270,6 +272,7 @@
         "NCCL*",
     ],
     env_file: str | os.PathLike | None = None,
+    timeout: int = 600,
 ) -> dict[int, Any]:
     """
     Launch a distributed PyTorch function on the specified nodes.
@@ -292,6 +295,8 @@
     :type env_vars: list[str], optional
     :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
     :type env_file: str | os.PathLike | None, optional
+    :param timeout: Worker process group timeout in seconds, defaults to 600
+    :type timeout: int, optional
     :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
     :return: A dictionary mapping worker ranks to their output
     :rtype: dict[int, Any]
@@ -304,4 +309,5 @@
         log_dir=log_dir,
         env_vars=env_vars,
         env_file=env_file,
+        timeout=timeout,
     ).run(func=func, func_kwargs=func_kwargs)
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index a82a7e53..09bd2466 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -29,6 +29,7 @@ class LauncherPayload:
    worker_global_ranks: list[list[int]]
    worker_log_files: list[list[Path]]
    backend: Literal["mpi", "gloo", "nccl", "ucc", None]
+   timeout: int
 
 
 @dataclass
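
For reference, the new setting flows from launch() / Launcher through LauncherPayload and WorkerArgs into each worker's dist.init_process_group(timeout=datetime.timedelta(seconds=...)) call. A minimal usage sketch, assuming torchrunx.launch is re-exported at the package level; the train function and its empty kwargs are hypothetical placeholders, and all arguments other than timeout keep their defaults:

import torchrunx

def train():
    # Hypothetical worker function; runs once in each worker process.
    ...

# Raise the process-group timeout from the default 600 s to 30 min,
# e.g. for workloads with long stretches between collective ops.
results = torchrunx.launch(
    func=train,
    func_kwargs={},
    timeout=1800,  # seconds; converted to a datetime.timedelta in the worker
)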