From 3fbb58c3db24911cec0d7e77e88b98e4250b2cc8 Mon Sep 17 00:00:00 2001 From: Peter Curtin Date: Wed, 17 Jul 2024 12:26:04 -0400 Subject: [PATCH 1/2] add pg_timeout flag --- src/torchrunx/agent.py | 9 ++++++++- src/torchrunx/launcher.py | 6 ++++++ src/torchrunx/utils.py | 1 + 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 7145b2a0..86367acf 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import datetime import os import socket import sys @@ -33,6 +34,7 @@ class WorkerArgs: local_world_size: int world_size: int log_file: os.PathLike + pg_timeout: int def to_bytes(self) -> bytes: return cloudpickle.dumps(self) @@ -81,7 +83,11 @@ def entrypoint(serialized_worker_args: bytes): if backend is None: backend = "nccl" if torch.cuda.is_available() else "gloo" dist.init_process_group( - backend=backend, world_size=worker_args.world_size, rank=worker_args.rank, store=store + backend=backend, + world_size=worker_args.world_size, + rank=worker_args.rank, + store=store, + timeout=datetime.timedelta(seconds=worker_args.pg_timeout), ) os.environ["RANK"] = str(worker_args.rank) @@ -130,6 +136,7 @@ def main(launcher_agent_group: LauncherAgentGroup): local_world_size=num_workers, world_size=worker_world_size, log_file=worker_log_files[i], + pg_timeout=launcher_payload.pg_timeout, ).to_bytes(), ) for i in range(num_workers) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 9f26b7b8..f07ebf9f 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -96,6 +96,7 @@ class Launcher: ] ) env_file: str | os.PathLike | None = None + pg_timeout: int = 600 def run( self, @@ -209,6 +210,7 @@ def run( worker_global_ranks=worker_global_ranks, worker_log_files=worker_log_files, backend=self.backend, + pg_timeout=self.pg_timeout, ) agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType] @@ -270,6 +272,7 @@ def launch( "NCCL*", ], env_file: str | os.PathLike | None = None, + pg_timeout: int = 600, ) -> dict[int, Any]: """ Launch a distributed PyTorch function on the specified nodes. @@ -292,6 +295,8 @@ def launch( :type env_vars: list[str], optional :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None :type env_file: str | os.PathLike | None, optional + :param pg_timeout: Worker process group timeout, defaults to 600 + :type pg_timeout: int, optional :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func`` :return: A dictionary mapping worker ranks to their output :rtype: dict[int, Any] @@ -304,4 +309,5 @@ def launch( log_dir=log_dir, env_vars=env_vars, env_file=env_file, + pg_timeout=pg_timeout, ).run(func=func, func_kwargs=func_kwargs) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index a82a7e53..41e47ce3 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -29,6 +29,7 @@ class LauncherPayload: worker_global_ranks: list[list[int]] worker_log_files: list[list[Path]] backend: Literal["mpi", "gloo", "nccl", "ucc", None] + pg_timeout: int @dataclass From eea5998e25345f59c4b1220ca86cdab815967406 Mon Sep 17 00:00:00 2001 From: Peter Curtin Date: Thu, 18 Jul 2024 12:14:19 -0400 Subject: [PATCH 2/2] rename to timeout --- src/torchrunx/agent.py | 6 +++--- src/torchrunx/launcher.py | 12 ++++++------ src/torchrunx/utils.py | 2 +- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 86367acf..57f41d80 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -34,7 +34,7 @@ class WorkerArgs: local_world_size: int world_size: int log_file: os.PathLike - pg_timeout: int + timeout: int def to_bytes(self) -> bytes: return cloudpickle.dumps(self) @@ -87,7 +87,7 @@ def entrypoint(serialized_worker_args: bytes): world_size=worker_args.world_size, rank=worker_args.rank, store=store, - timeout=datetime.timedelta(seconds=worker_args.pg_timeout), + timeout=datetime.timedelta(seconds=worker_args.timeout), ) os.environ["RANK"] = str(worker_args.rank) @@ -136,7 +136,7 @@ def main(launcher_agent_group: LauncherAgentGroup): local_world_size=num_workers, world_size=worker_world_size, log_file=worker_log_files[i], - pg_timeout=launcher_payload.pg_timeout, + timeout=launcher_payload.timeout, ).to_bytes(), ) for i in range(num_workers) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index f07ebf9f..346f303b 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -96,7 +96,7 @@ class Launcher: ] ) env_file: str | os.PathLike | None = None - pg_timeout: int = 600 + timeout: int = 600 def run( self, @@ -210,7 +210,7 @@ def run( worker_global_ranks=worker_global_ranks, worker_log_files=worker_log_files, backend=self.backend, - pg_timeout=self.pg_timeout, + timeout=self.timeout, ) agent_payloads: list[AgentPayload] = launcher_agent_group.sync_payloads(payload=payload)[1:] # pyright: ignore[reportAssignmentType] @@ -272,7 +272,7 @@ def launch( "NCCL*", ], env_file: str | os.PathLike | None = None, - pg_timeout: int = 600, + timeout: int = 600, ) -> dict[int, Any]: """ Launch a distributed PyTorch function on the specified nodes. @@ -295,8 +295,8 @@ def launch( :type env_vars: list[str], optional :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None :type env_file: str | os.PathLike | None, optional - :param pg_timeout: Worker process group timeout, defaults to 600 - :type pg_timeout: int, optional + :param timeout: Worker process group timeout, defaults to 600 + :type timeout: int, optional :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func`` :return: A dictionary mapping worker ranks to their output :rtype: dict[int, Any] @@ -309,5 +309,5 @@ def launch( log_dir=log_dir, env_vars=env_vars, env_file=env_file, - pg_timeout=pg_timeout, + timeout=timeout, ).run(func=func, func_kwargs=func_kwargs) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 41e47ce3..09bd2466 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -29,7 +29,7 @@ class LauncherPayload: worker_global_ranks: list[list[int]] worker_log_files: list[list[Path]] backend: Literal["mpi", "gloo", "nccl", "ucc", None] - pg_timeout: int + timeout: int @dataclass