diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index ca19796d..177c444e 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,8 +1,10 @@
-from .launcher import AgentKilledError, Launcher, LaunchResult, launch
+from .launcher import Launcher, LaunchResult, launch
 from .logging_utils import add_filter_to_handler, file_handler, stream_handler
+from .utils import AgentFailedError, WorkerFailedError
 
 __all__ = [
-    "AgentKilledError",
+    "AgentFailedError",
+    "WorkerFailedError",
     "Launcher",
     "launch",
     "LaunchResult",
diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 1860e444..2a94c20d 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -19,8 +19,8 @@
 from .utils import (
     AgentPayload,
     AgentStatus,
+    ExceptionFromWorker,
     LauncherAgentGroup,
-    WorkerException,
     get_open_port,
 )
 
@@ -52,7 +52,7 @@ def deserialize(self) -> WorkerArgs:
         return cloudpickle.loads(self.bytes)
 
 
-def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
+def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
     worker_args: WorkerArgs = serialized_worker_args.deserialize()
 
     logger = logging.getLogger()
@@ -96,7 +96,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
         return worker_args.function()
     except Exception as e:
         traceback.print_exc()
-        return WorkerException(exception=e)
+        return ExceptionFromWorker(exception=e)
     finally:
         sys.stdout.flush()
         sys.stderr.flush()
@@ -155,7 +155,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
             )
             for i in range(num_workers)
         },
+        # environment variables from agent are already automatically copied to workers
         envs={i: {} for i in range(num_workers)},
+        # we handle logging ourselves, so we can discard these
         **(
             {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
             if torch.__version__ >= "2.3"
@@ -167,8 +169,10 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
     status = None
     while True:
         if status is None or status.state == "running":
-            status = AgentStatus.from_result(ctx.wait(5))
+            # status can contain ExceptionFromWorker or WorkerFailedError
+            status = AgentStatus.from_result(result=ctx.wait(5))
 
+        # can raise AgentFailedError in launcher and all agents
         agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)
 
         all_done = all(s.state == "done" for s in agent_statuses)
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index ddd2deaf..64b871dc 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -22,11 +22,14 @@
 
 from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .logging_utils import LogRecordSocketReceiver, default_handlers
-from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
-
-
-class AgentKilledError(Exception):
-    pass
+from .utils import (
+    AgentStatus,
+    ExceptionFromWorker,
+    LauncherAgentGroup,
+    LauncherPayload,
+    WorkerFailedError,
+    get_open_port,
+)
 
 
 @dataclass
@@ -141,17 +144,16 @@ def run(  # noqa: C901, PLR0912
 
         # loop to monitor agent statuses (until failed or done)
         while True:
-            try:
-                agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
-            except RuntimeError as e:
-                # occurs if any agent dies and communication times out
-                raise AgentKilledError from e
+            # could raise AgentFailedError
+            agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
 
+            # raises specific exception if any agent fails
             for s in agent_statuses:
                 for value in s.return_values:
-                    if isinstance(value, WorkerException):
+                    if isinstance(value, ExceptionFromWorker):
                         raise value.exception
+                    if isinstance(value, WorkerFailedError):
+                        raise value
 
             if all(s.state == "done" for s in agent_statuses):
                 break
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index 3770e93d..b4c2e768 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -20,6 +20,14 @@ def get_open_port() -> int:
         return s.getsockname()[1]
 
 
+class AgentFailedError(Exception):
+    pass
+
+
+class WorkerFailedError(Exception):
+    pass
+
+
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
@@ -50,11 +58,15 @@ def _deserialize(self, serialized: bytes) -> Any:
 
     def _all_gather(self, obj: Any) -> list:
         """gather object from every rank to list on every rank"""
-        object_bytes = self._serialize(obj)
-        object_list = [b""] * self.world_size
-        # raises RuntimeError if timeout
-        dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
-        return [self._deserialize(o) for o in object_list]
+        try:
+            object_bytes = self._serialize(obj)
+            object_list = [b""] * self.world_size
+            # raises RuntimeError if timeout
+            dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
+            return [self._deserialize(o) for o in object_list]
+        except RuntimeError as e:
+            # occurs if launcher or any agent dies and communication times out
+            raise AgentFailedError from e
 
     def sync_payloads(
         self,
@@ -90,14 +102,14 @@ class AgentPayload:
 
 
 @dataclass
-class WorkerException:
+class ExceptionFromWorker:
     exception: Exception
 
 
 @dataclass
 class AgentStatus:
     state: Literal["running", "failed", "done"]
-    return_values: list[Any | WorkerException] = field(
+    return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field(
         default_factory=list
     )  # indexed by local rank
 
@@ -105,10 +117,10 @@ class AgentStatus:
     def from_result(cls, result: RunProcsResult | None) -> Self:
         if result is None:
             return cls(state="running")
-
+        for local_rank, failure in result.failures.items():
+            result.return_values[local_rank] = WorkerFailedError(failure.message)
         return_values = list(result.return_values.values())
-
-        failed = any(isinstance(v, WorkerException) for v in return_values)
+        failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)
         state = "failed" if failed else "done"
 
         return cls(
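
For reference, a minimal caller-side sketch of the error handling this patch introduces. The launch() call below is illustrative only (host and worker arguments are omitted; consult the torchrunx documentation for the exact signature). Per the diff: an exception raised inside the worker function is wrapped in ExceptionFromWorker and re-raised as-is by the launcher, a worker process that dies without raising a Python exception surfaces as WorkerFailedError, and a dead agent or a launcher-agent communication timeout surfaces as AgentFailedError.

    # Hypothetical usage sketch; train() and the launch() arguments are placeholders.
    import torchrunx
    from torchrunx import AgentFailedError, WorkerFailedError

    def train():
        ...  # distributed function executed on every worker

    try:
        result = torchrunx.launch(train)  # pass hostnames / workers-per-host options as documented
    except AgentFailedError:
        ...  # an agent process died or launcher-agent communication timed out
    except WorkerFailedError:
        ...  # a worker process terminated abnormally, with no Python exception to surface
    except Exception:
        ...  # an exception raised inside train() on some worker, re-raised by the launcher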