diff --git a/pyproject.toml b/pyproject.toml index 3fb9d1ab..33169f50 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "torchrunx" -version = "0.2.0" +version = "0.2.1" authors = [ {name = "Apoorv Khandelwal", email = "mail@apoorvkh.com"}, {name = "Peter Curtin", email = "peter_curtin@brown.edu"}, diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index b745ac7a..f12d1138 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -16,7 +16,7 @@ from dataclasses import dataclass from functools import partial, reduce from logging import Handler -from multiprocessing import Process +from multiprocessing import Event, Process from operator import add from pathlib import Path from typing import Any, Callable, Literal @@ -34,7 +34,7 @@ ExceptionFromWorker, WorkerFailedError, ) -from .utils.logging import LogRecordSocketReceiver, default_handlers +from .utils.logging import LoggingServerArgs, start_logging_server @dataclass @@ -76,9 +76,10 @@ def run( # noqa: C901, PLR0912 launcher_hostname = socket.getfqdn() launcher_port = get_open_port() + logging_port = get_open_port() world_size = len(hostnames) + 1 - log_receiver = None + stop_logging_event = None log_process = None launcher_agent_group = None agent_payloads = None @@ -86,17 +87,21 @@ def run( # noqa: C901, PLR0912 try: # Start logging server (recieves LogRecords from agents/workers) - log_receiver = _build_logging_server( + logging_server_args = LoggingServerArgs( log_handlers=log_handlers, - launcher_hostname=launcher_hostname, + logging_hostname=launcher_hostname, + logging_port=logging_port, hostnames=hostnames, workers_per_host=workers_per_host, log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")), log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001 ) + stop_logging_event = Event() + log_process = Process( - target=log_receiver.serve_forever, + target=start_logging_server, + args=(logging_server_args.serialize(), stop_logging_event), daemon=True, ) @@ -109,7 +114,7 @@ def run( # noqa: C901, PLR0912 command=_build_launch_command( launcher_hostname=launcher_hostname, launcher_port=launcher_port, - logger_port=log_receiver.port, + logger_port=logging_port, world_size=world_size, rank=i + 1, env_vars=(self.default_env_vars + self.extra_env_vars), @@ -166,11 +171,10 @@ def run( # noqa: C901, PLR0912 if all(s.state == "done" for s in agent_statuses): break finally: - if log_receiver is not None: - log_receiver.shutdown() - if log_process is not None: - log_receiver.server_close() - log_process.kill() + if stop_logging_event is not None: + stop_logging_event.set() + if log_process is not None: + log_process.kill() if launcher_agent_group is not None: launcher_agent_group.shutdown() @@ -307,31 +311,6 @@ def _resolve_workers_per_host( return workers_per_host -def _build_logging_server( - log_handlers: list[Handler] | Literal["auto"] | None, - launcher_hostname: str, - hostnames: list[str], - workers_per_host: list[int], - log_dir: str | os.PathLike, - log_level: int, -) -> LogRecordSocketReceiver: - if log_handlers is None: - log_handlers = [] - elif log_handlers == "auto": - log_handlers = default_handlers( - hostnames=hostnames, - workers_per_host=workers_per_host, - log_dir=log_dir, - log_level=log_level, - ) - - return LogRecordSocketReceiver( - host=launcher_hostname, - port=get_open_port(), - handlers=log_handlers, - ) - - def _build_launch_command( launcher_hostname: str, launcher_port: int, diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py index efd013a8..d9be67a0 100644 --- a/src/torchrunx/utils/logging.py +++ b/src/torchrunx/utils/logging.py @@ -3,7 +3,8 @@ from __future__ import annotations __all__ = [ - "LogRecordSocketReceiver", + "LoggingServerArgs", + "start_logging_server", "redirect_stdio_to_logger", "log_records_to_socket", "add_filter_to_handler", @@ -25,12 +26,14 @@ from logging.handlers import SocketHandler from pathlib import Path from socketserver import StreamRequestHandler, ThreadingTCPServer -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal +import cloudpickle from typing_extensions import Self if TYPE_CHECKING: import os + from multiprocessing.synchronize import Event as EventClass ## Handler utilities @@ -139,7 +142,7 @@ def default_handlers( ## Launcher utilities -class LogRecordSocketReceiver(ThreadingTCPServer): +class _LogRecordSocketReceiver(ThreadingTCPServer): """TCP server for recieving Agent/Worker log records in Launcher. Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process). @@ -180,6 +183,64 @@ def shutdown(self) -> None: self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] +@dataclass +class LoggingServerArgs: + """Arguments for starting a :class:`_LogRecordSocketReceiver`.""" + + log_handlers: list[Handler] | Literal["auto"] | None + logging_hostname: str + logging_port: int + hostnames: list[str] + workers_per_host: list[int] + log_dir: str | os.PathLike + log_level: int + + def serialize(self) -> SerializedLoggingServerArgs: + """Serialize :class:`LoggingServerArgs` for passing to a new process.""" + return SerializedLoggingServerArgs(args=self) + + +class SerializedLoggingServerArgs: + def __init__(self, args: LoggingServerArgs) -> None: + self.bytes = cloudpickle.dumps(args) + + def deserialize(self) -> LoggingServerArgs: + return cloudpickle.loads(self.bytes) + + +def start_logging_server( + serialized_args: SerializedLoggingServerArgs, + stop_event: EventClass, +) -> None: + """Serve :class:`_LogRecordSocketReceiver` until stop event triggered.""" + args: LoggingServerArgs = serialized_args.deserialize() + + log_handlers = args.log_handlers + if log_handlers is None: + log_handlers = [] + elif log_handlers == "auto": + log_handlers = default_handlers( + hostnames=args.hostnames, + workers_per_host=args.workers_per_host, + log_dir=args.log_dir, + log_level=args.log_level, + ) + + log_receiver = _LogRecordSocketReceiver( + host=args.logging_hostname, + port=args.logging_port, + handlers=log_handlers, + ) + + log_receiver.serve_forever() + + while not stop_event.is_set(): + pass + + log_receiver.shutdown() + log_receiver.server_close() + + ## Agent/worker utilities