From 76e046265ac7ecf3ef9288e5565845b1eab40346 Mon Sep 17 00:00:00 2001 From: Peter Curtin Date: Sat, 26 Oct 2024 13:45:47 -0400 Subject: [PATCH 1/7] basic WorkerKilledError --- src/torchrunx/launcher.py | 11 ++++++++++- src/torchrunx/utils.py | 9 +++++++-- tests/test_func.py | 1 + 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index ddd2deaf..7c9db86a 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -22,7 +22,14 @@ from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers from .logging_utils import LogRecordSocketReceiver, default_handlers -from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port +from .utils import ( + AgentStatus, + LauncherAgentGroup, + LauncherPayload, + WorkerException, + WorkerKilledError, + get_open_port, +) class AgentKilledError(Exception): @@ -152,6 +159,8 @@ def run( # noqa: C901, PLR0912 for value in s.return_values: if isinstance(value, WorkerException): raise value.exception + if isinstance(value, WorkerKilledError): + raise value if all(s.state == "done" for s in agent_statuses): break diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 3770e93d..3af34f02 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -94,6 +94,11 @@ class WorkerException: exception: Exception +@dataclass +class WorkerKilledError(Exception): + failure: str + + @dataclass class AgentStatus: state: Literal["running", "failed", "done"] @@ -105,9 +110,9 @@ class AgentStatus: def from_result(cls, result: RunProcsResult | None) -> Self: if result is None: return cls(state="running") - + for local_rank, failure in result.failures.items(): + result.return_values[local_rank] = WorkerKilledError(failure.message) return_values = list(result.return_values.values()) - failed = any(isinstance(v, WorkerException) for v in return_values) state = "failed" if failed else "done" diff --git a/tests/test_func.py b/tests/test_func.py index e8033b4e..f0a4fc23 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -1,4 +1,5 @@ import os +import time import torch import torch.distributed as dist From 54b6e42525d2b484d239b94c28b52b232ba91f63 Mon Sep 17 00:00:00 2001 From: Peter Curtin Date: Sat, 26 Oct 2024 14:09:41 -0400 Subject: [PATCH 2/7] remove time import --- tests/test_func.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_func.py b/tests/test_func.py index f0a4fc23..e8033b4e 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -1,5 +1,4 @@ import os -import time import torch import torch.distributed as dist From 4fee7f46b473a2149f725d44a785b6044f79f1d6 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 27 Oct 2024 14:30:46 -0400 Subject: [PATCH 3/7] clarified environment variables copied to workers --- src/torchrunx/agent.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 1860e444..c2b08f70 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -155,7 +155,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ) for i in range(num_workers) }, + # environment variables from agent are already automatically copied to workers envs={i: {} for i in range(num_workers)}, + # we handle logging ourselves, so we can discard these **( {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())} if torch.__version__ >= "2.3" From 5ec7fd921ec64fca4e81ddeaa8d3bd0868d99677 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 27 Oct 2024 14:34:00 -0400 Subject: [PATCH 4/7] renamed `WorkerException` to `ExceptionFromWorker` --- src/torchrunx/agent.py | 6 +++--- src/torchrunx/launcher.py | 4 ++-- src/torchrunx/utils.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index c2b08f70..aef0d75f 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -19,8 +19,8 @@ from .utils import ( AgentPayload, AgentStatus, + ExceptionFromWorker, LauncherAgentGroup, - WorkerException, get_open_port, ) @@ -52,7 +52,7 @@ def deserialize(self) -> WorkerArgs: return cloudpickle.loads(self.bytes) -def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException: +def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: worker_args: WorkerArgs = serialized_worker_args.deserialize() logger = logging.getLogger() @@ -96,7 +96,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce return worker_args.function() except Exception as e: traceback.print_exc() - return WorkerException(exception=e) + return ExceptionFromWorker(exception=e) finally: sys.stdout.flush() sys.stderr.flush() diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 7c9db86a..90701039 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -24,9 +24,9 @@ from .logging_utils import LogRecordSocketReceiver, default_handlers from .utils import ( AgentStatus, + ExceptionFromWorker, LauncherAgentGroup, LauncherPayload, - WorkerException, WorkerKilledError, get_open_port, ) @@ -157,7 +157,7 @@ def run( # noqa: C901, PLR0912 # raises specific exception if any agent fails for s in agent_statuses: for value in s.return_values: - if isinstance(value, WorkerException): + if isinstance(value, ExceptionFromWorker): raise value.exception if isinstance(value, WorkerKilledError): raise value diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 3af34f02..f20e7181 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -90,7 +90,7 @@ class AgentPayload: @dataclass -class WorkerException: +class ExceptionFromWorker: exception: Exception @@ -102,7 +102,7 @@ class WorkerKilledError(Exception): @dataclass class AgentStatus: state: Literal["running", "failed", "done"] - return_values: list[Any | WorkerException] = field( + return_values: list[Any | ExceptionFromWorker] = field( default_factory=list ) # indexed by local rank @@ -113,7 +113,7 @@ def from_result(cls, result: RunProcsResult | None) -> Self: for local_rank, failure in result.failures.items(): result.return_values[local_rank] = WorkerKilledError(failure.message) return_values = list(result.return_values.values()) - failed = any(isinstance(v, WorkerException) for v in return_values) + failed = any(isinstance(v, ExceptionFromWorker) for v in return_values) state = "failed" if failed else "done" return cls( From 10fa1a00d56e90f766a63cd350af6bbe0d4cc58d Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 27 Oct 2024 14:41:13 -0400 Subject: [PATCH 5/7] renamed to `WorkerFailedError` and `AgentFailedError` --- src/torchrunx/__init__.py | 6 ++++-- src/torchrunx/errors.py | 5 +++++ src/torchrunx/launcher.py | 10 +++------- src/torchrunx/utils.py | 9 +++------ 4 files changed, 15 insertions(+), 15 deletions(-) create mode 100644 src/torchrunx/errors.py diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index ca19796d..e9edf7ec 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,8 +1,10 @@ -from .launcher import AgentKilledError, Launcher, LaunchResult, launch +from .errors import AgentFailedError, WorkerFailedError +from .launcher import Launcher, LaunchResult, launch from .logging_utils import add_filter_to_handler, file_handler, stream_handler __all__ = [ - "AgentKilledError", + "AgentFailedError", + "WorkerFailedError", "Launcher", "launch", "LaunchResult", diff --git a/src/torchrunx/errors.py b/src/torchrunx/errors.py new file mode 100644 index 00000000..68c3ded7 --- /dev/null +++ b/src/torchrunx/errors.py @@ -0,0 +1,5 @@ +class AgentFailedError(Exception): + pass + +class WorkerFailedError(Exception): + pass diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 90701039..471642e8 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -21,21 +21,17 @@ import torch.distributed as dist from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers +from .errors import AgentFailedError, WorkerFailedError from .logging_utils import LogRecordSocketReceiver, default_handlers from .utils import ( AgentStatus, ExceptionFromWorker, LauncherAgentGroup, LauncherPayload, - WorkerKilledError, get_open_port, ) -class AgentKilledError(Exception): - pass - - @dataclass class Launcher: hostnames: list[str] | Literal["auto", "slurm"] = "auto" @@ -152,14 +148,14 @@ def run( # noqa: C901, PLR0912 agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) except RuntimeError as e: # occurs if any agent dies and communication times out - raise AgentKilledError from e + raise AgentFailedError from e # raises specific exception if any agent fails for s in agent_statuses: for value in s.return_values: if isinstance(value, ExceptionFromWorker): raise value.exception - if isinstance(value, WorkerKilledError): + if isinstance(value, WorkerFailedError): raise value if all(s.state == "done" for s in agent_statuses): diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index f20e7181..7ac79f19 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -10,6 +10,8 @@ import torch.distributed as dist from typing_extensions import Self +from .errors import WorkerFailedError + if TYPE_CHECKING: from torch.distributed.elastic.multiprocessing.api import RunProcsResult @@ -94,11 +96,6 @@ class ExceptionFromWorker: exception: Exception -@dataclass -class WorkerKilledError(Exception): - failure: str - - @dataclass class AgentStatus: state: Literal["running", "failed", "done"] @@ -111,7 +108,7 @@ def from_result(cls, result: RunProcsResult | None) -> Self: if result is None: return cls(state="running") for local_rank, failure in result.failures.items(): - result.return_values[local_rank] = WorkerKilledError(failure.message) + result.return_values[local_rank] = WorkerFailedError(failure.message) return_values = list(result.return_values.values()) failed = any(isinstance(v, ExceptionFromWorker) for v in return_values) state = "failed" if failed else "done" From c0baede0e418e32e73f92350cf39c3b0aa9ada27 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 27 Oct 2024 14:46:25 -0400 Subject: [PATCH 6/7] Moved failed errors to utils; raising AgentFailedError in _all_gather --- src/torchrunx/__init__.py | 2 +- src/torchrunx/errors.py | 5 ----- src/torchrunx/launcher.py | 9 +++------ src/torchrunx/utils.py | 23 ++++++++++++++++------- 4 files changed, 20 insertions(+), 19 deletions(-) delete mode 100644 src/torchrunx/errors.py diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index e9edf7ec..177c444e 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,6 +1,6 @@ -from .errors import AgentFailedError, WorkerFailedError from .launcher import Launcher, LaunchResult, launch from .logging_utils import add_filter_to_handler, file_handler, stream_handler +from .utils import AgentFailedError, WorkerFailedError __all__ = [ "AgentFailedError", diff --git a/src/torchrunx/errors.py b/src/torchrunx/errors.py deleted file mode 100644 index 68c3ded7..00000000 --- a/src/torchrunx/errors.py +++ /dev/null @@ -1,5 +0,0 @@ -class AgentFailedError(Exception): - pass - -class WorkerFailedError(Exception): - pass diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 471642e8..64b871dc 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -21,13 +21,13 @@ import torch.distributed as dist from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers -from .errors import AgentFailedError, WorkerFailedError from .logging_utils import LogRecordSocketReceiver, default_handlers from .utils import ( AgentStatus, ExceptionFromWorker, LauncherAgentGroup, LauncherPayload, + WorkerFailedError, get_open_port, ) @@ -144,11 +144,8 @@ def run( # noqa: C901, PLR0912 # loop to monitor agent statuses (until failed or done) while True: - try: - agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) - except RuntimeError as e: - # occurs if any agent dies and communication times out - raise AgentFailedError from e + # could raise AgentFailedError + agent_statuses = launcher_agent_group.sync_agent_statuses(status=None) # raises specific exception if any agent fails for s in agent_statuses: diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 7ac79f19..c1dac52b 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -10,8 +10,6 @@ import torch.distributed as dist from typing_extensions import Self -from .errors import WorkerFailedError - if TYPE_CHECKING: from torch.distributed.elastic.multiprocessing.api import RunProcsResult @@ -22,6 +20,13 @@ def get_open_port() -> int: return s.getsockname()[1] +class AgentFailedError(Exception): + pass + +class WorkerFailedError(Exception): + pass + + @dataclass class LauncherAgentGroup: launcher_hostname: str @@ -52,11 +57,15 @@ def _deserialize(self, serialized: bytes) -> Any: def _all_gather(self, obj: Any) -> list: """gather object from every rank to list on every rank""" - object_bytes = self._serialize(obj) - object_list = [b""] * self.world_size - # raises RuntimeError if timeout - dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group) - return [self._deserialize(o) for o in object_list] + try: + object_bytes = self._serialize(obj) + object_list = [b""] * self.world_size + # raises RuntimeError if timeout + dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group) + return [self._deserialize(o) for o in object_list] + except RuntimeError as e: + # occurs if launcher or any agent dies and communication times out + raise AgentFailedError from e def sync_payloads( self, From 2824b6a017553a0e16fbce7b090883be052bc316 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 27 Oct 2024 14:59:28 -0400 Subject: [PATCH 7/7] some extra error docs --- src/torchrunx/agent.py | 4 +++- src/torchrunx/utils.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index aef0d75f..2a94c20d 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -169,8 +169,10 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ status = None while True: if status is None or status.state == "running": - status = AgentStatus.from_result(ctx.wait(5)) + # status can contain ExceptionFromWorker or WorkerFailedError + status = AgentStatus.from_result(result=ctx.wait(5)) + # can raise AgentFailedError in launcher and all agents agent_statuses = launcher_agent_group.sync_agent_statuses(status=status) all_done = all(s.state == "done" for s in agent_statuses) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index c1dac52b..b4c2e768 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -23,6 +23,7 @@ def get_open_port() -> int: class AgentFailedError(Exception): pass + class WorkerFailedError(Exception): pass @@ -108,7 +109,7 @@ class ExceptionFromWorker: @dataclass class AgentStatus: state: Literal["running", "failed", "done"] - return_values: list[Any | ExceptionFromWorker] = field( + return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field( default_factory=list ) # indexed by local rank