Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Fix for logging server serialization problems #80

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "torchrunx"
version = "0.2.0"
version = "0.2.1"
authors = [
{name = "Apoorv Khandelwal", email = "[email protected]"},
{name = "Peter Curtin", email = "[email protected]"},
Expand Down
53 changes: 16 additions & 37 deletions src/torchrunx/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from dataclasses import dataclass
from functools import partial, reduce
from logging import Handler
from multiprocessing import Process
from multiprocessing import Event, Process
from operator import add
from pathlib import Path
from typing import Any, Callable, Literal
Expand All @@ -34,7 +34,7 @@
ExceptionFromWorker,
WorkerFailedError,
)
from .utils.logging import LogRecordSocketReceiver, default_handlers
from .utils.logging import LoggingServerArgs, start_logging_server


@dataclass
Expand Down Expand Up @@ -76,27 +76,32 @@ def run( # noqa: C901, PLR0912

launcher_hostname = socket.getfqdn()
launcher_port = get_open_port()
logging_port = get_open_port()
world_size = len(hostnames) + 1

log_receiver = None
stop_logging_event = None
log_process = None
launcher_agent_group = None
agent_payloads = None

try:
# Start logging server (receives LogRecords from agents/workers)

log_receiver = _build_logging_server(
logging_server_args = LoggingServerArgs(
log_handlers=log_handlers,
launcher_hostname=launcher_hostname,
logging_hostname=launcher_hostname,
logging_port=logging_port,
hostnames=hostnames,
workers_per_host=workers_per_host,
log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")], # noqa: SLF001
)

stop_logging_event = Event()

log_process = Process(
target=log_receiver.serve_forever,
target=start_logging_server,
args=(logging_server_args.serialize(), stop_logging_event),
daemon=True,
)

Expand All @@ -109,7 +114,7 @@ def run( # noqa: C901, PLR0912
command=_build_launch_command(
launcher_hostname=launcher_hostname,
launcher_port=launcher_port,
logger_port=log_receiver.port,
logger_port=logging_port,
world_size=world_size,
rank=i + 1,
env_vars=(self.default_env_vars + self.extra_env_vars),
Expand Down Expand Up @@ -166,11 +171,10 @@ def run( # noqa: C901, PLR0912
if all(s.state == "done" for s in agent_statuses):
break
finally:
if log_receiver is not None:
log_receiver.shutdown()
if log_process is not None:
log_receiver.server_close()
log_process.kill()
if stop_logging_event is not None:
stop_logging_event.set()
if log_process is not None:
log_process.kill()

if launcher_agent_group is not None:
launcher_agent_group.shutdown()
Expand Down Expand Up @@ -307,31 +311,6 @@ def _resolve_workers_per_host(
return workers_per_host


def _build_logging_server(
    log_handlers: list[Handler] | Literal["auto"] | None,
    launcher_hostname: str,
    hostnames: list[str],
    workers_per_host: list[int],
    log_dir: str | os.PathLike,
    log_level: int,
) -> LogRecordSocketReceiver:
    """Construct a :class:`LogRecordSocketReceiver` bound to the launcher host.

    ``log_handlers`` semantics: ``None`` means no handlers, ``"auto"`` builds
    the default handlers for the given hosts/workers, and any other value is
    used as-is.
    """
    if log_handlers == "auto":
        handlers = default_handlers(
            hostnames=hostnames,
            workers_per_host=workers_per_host,
            log_dir=log_dir,
            log_level=log_level,
        )
    elif log_handlers is None:
        handlers = []
    else:
        handlers = log_handlers

    return LogRecordSocketReceiver(
        host=launcher_hostname,
        port=get_open_port(),
        handlers=handlers,
    )


def _build_launch_command(
launcher_hostname: str,
launcher_port: int,
Expand Down
67 changes: 64 additions & 3 deletions src/torchrunx/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from __future__ import annotations

__all__ = [
"LogRecordSocketReceiver",
"LoggingServerArgs",
"start_logging_server",
"redirect_stdio_to_logger",
"log_records_to_socket",
"add_filter_to_handler",
Expand All @@ -25,12 +26,14 @@
from logging.handlers import SocketHandler
from pathlib import Path
from socketserver import StreamRequestHandler, ThreadingTCPServer
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal

import cloudpickle
from typing_extensions import Self

if TYPE_CHECKING:
import os
from multiprocessing.synchronize import Event as EventClass

## Handler utilities

Expand Down Expand Up @@ -139,7 +142,7 @@ def default_handlers(
## Launcher utilities


class LogRecordSocketReceiver(ThreadingTCPServer):
class _LogRecordSocketReceiver(ThreadingTCPServer):
    """TCP server for receiving Agent/Worker log records in Launcher.

Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process).
Expand Down Expand Up @@ -180,6 +183,64 @@ def shutdown(self) -> None:
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]


@dataclass
class LoggingServerArgs:
    """Arguments for starting a :class:`_LogRecordSocketReceiver`.

    Bundled into one object so the whole configuration can be serialized
    (via :meth:`serialize`) and handed to a separate logging-server process.
    """

    # Handlers to install on the server: None -> no handlers,
    # "auto" -> build defaults from the hosts/workers below.
    log_handlers: list[Handler] | Literal["auto"] | None
    # Host/port the logging TCP server binds to.
    logging_hostname: str
    logging_port: int
    # Hostnames and per-host worker counts (used to build "auto" handlers).
    hostnames: list[str]
    workers_per_host: list[int]
    # Directory for log files and minimum level (used to build "auto" handlers).
    log_dir: str | os.PathLike
    log_level: int

    def serialize(self) -> SerializedLoggingServerArgs:
        """Serialize :class:`LoggingServerArgs` for passing to a new process."""
        return SerializedLoggingServerArgs(args=self)


class SerializedLoggingServerArgs:
    """Byte-level wrapper around :class:`LoggingServerArgs`.

    Uses ``cloudpickle`` (rather than the default pickler) so that arbitrary
    :class:`logging.Handler` objects survive the trip into a spawned process.
    """

    def __init__(self, args: LoggingServerArgs) -> None:
        # NOTE: attribute deliberately keeps its existing public name "bytes"
        # (shadows the builtin) so external accessors are unaffected.
        self.bytes = cloudpickle.dumps(args)

    def deserialize(self) -> LoggingServerArgs:
        """Reconstruct the original :class:`LoggingServerArgs` instance."""
        return cloudpickle.loads(self.bytes)


def start_logging_server(
    serialized_args: SerializedLoggingServerArgs,
    stop_event: EventClass,
) -> None:
    """Serve :class:`_LogRecordSocketReceiver` until ``stop_event`` is set.

    Intended to run as the target of a dedicated :class:`multiprocessing.Process`.

    Arguments:
        serialized_args: Pickled :class:`LoggingServerArgs`; handler
            construction is deferred to this process.
        stop_event: Inter-process event signalling that the server should
            shut down.
    """
    import threading  # local import: only needed inside the server process

    args: LoggingServerArgs = serialized_args.deserialize()

    log_handlers = args.log_handlers
    if log_handlers is None:
        log_handlers = []
    elif log_handlers == "auto":
        log_handlers = default_handlers(
            hostnames=args.hostnames,
            workers_per_host=args.workers_per_host,
            log_dir=args.log_dir,
            log_level=args.log_level,
        )

    log_receiver = _LogRecordSocketReceiver(
        host=args.logging_hostname,
        port=args.logging_port,
        handlers=log_handlers,
    )

    # BUG FIX: serve_forever() blocks until shutdown() is called, so the code
    # after it (the stop_event polling loop and the cleanup calls) was
    # unreachable — the server could only die via Process.kill(), skipping
    # server_close(). Run the serve loop on a background thread and block this
    # process on the event instead (wait() also avoids the original
    # CPU-burning busy-wait).
    server_thread = threading.Thread(target=log_receiver.serve_forever, daemon=True)
    server_thread.start()

    stop_event.wait()

    log_receiver.shutdown()
    log_receiver.server_close()
    server_thread.join(timeout=3)


## Agent/worker utilities


Expand Down