diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index c1726931..45a3b483 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -86,4 +86,4 @@ jobs:
           cache: false
           environments: default
           activate-environment: default
-      - run: pytest tests/test_CI.py
+      - run: pytest tests/test_ci.py
diff --git a/pixi.lock b/pixi.lock
index 8e8c95df..e4fae5ca 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -2601,9 +2601,9 @@ packages:
     requires_python: '>=3.8.0'
 - kind: pypi
   name: torchrunx
-  version: 0.1.3
+  version: 0.2.0
   path: .
-  sha256: 7352054b1212a4ce0d60c055288dd4f51cea2093a84d0a1a48ea97bdaa703fad
+  sha256: 1753f43bee54bc0da38cdd524dc501c0c2be9fbaaa7036bced9c9d03a7a8e810
   requires_dist:
   - cloudpickle>=3.0.0
   - fabric>=3.0.0
diff --git a/pyproject.toml b/pyproject.toml
index 6f029d8b..33acfaeb 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "torchrunx"
-version = "0.1.3"
+version = "0.2.0"
 authors = [
   {name = "Apoorv Khandelwal", email = "mail@apoorvkh.com"},
   {name = "Peter Curtin", email = "peter_curtin@brown.edu"},
@@ -41,7 +41,24 @@ include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
 line-length = 100
 src = ["src", "tests"]
 
 [tool.ruff.lint]
-select = ["E", "F", "B", "UP", "I"]
+select = ["ALL"]
+ignore = [
+    "D",  # documentation
+    "ANN101", "ANN102", "ANN401",  # self / cls / Any annotations
+    "BLE001",  # blind exceptions
+    "TD",  # todo syntax
+    "FIX002",  # existing todos
+    "PLR0913",  # too many arguments
+    "DTZ005",  # datetime timezone
+    "S301",  # bandit: pickle
+    "S603", "S607",  # bandit: subprocess
+    "COM812", "ISC001",  # conflict with formatter
+]
+[tool.ruff.lint.per-file-ignores]
+"tests/**/*.py" = [
+    "S101",  # allow asserts
+    "T201"  # allow prints
+]
 
 [tool.pyright]
 include = ["src", "tests"]
diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index 46b3b1b9..74214cb8 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,6 +1,10 @@
 from .launcher import Launcher, launch
+from .logging_utils import add_filter_to_handler, file_handler, stream_handler
 
 __all__ = [
     "Launcher",
     "launch",
+    "add_filter_to_handler",
+    "file_handler",
+    "stream_handler",
 ]
diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index f4dfab33..04d1ec92 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -6,14 +6,14 @@
 import socket
 import sys
 import tempfile
+import traceback
 from dataclasses import dataclass
 from typing import Any, Callable, Literal
 
 import cloudpickle
 import torch
 import torch.distributed as dist
-from torch.distributed.elastic.multiprocessing import start_processes
-from typing_extensions import Self
+import torch.distributed.elastic.multiprocessing as dist_mp
 
 from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
 from .utils import (
@@ -40,16 +40,20 @@ class WorkerArgs:
     hostname: str
     timeout: int
 
-    def to_bytes(self) -> bytes:
-        return cloudpickle.dumps(self)
+    def serialize(self) -> SerializedWorkerArgs:
+        return SerializedWorkerArgs(worker_args=self)
 
-    @classmethod
-    def from_bytes(cls, serialized: bytes) -> Self:
-        return cloudpickle.loads(serialized)
 
+class SerializedWorkerArgs:
+    def __init__(self, worker_args: WorkerArgs) -> None:
+        self.bytes = cloudpickle.dumps(worker_args)
 
-def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
-    worker_args = WorkerArgs.from_bytes(serialized_worker_args)
+    def deserialize(self) -> WorkerArgs:
+        return cloudpickle.loads(self.bytes)
+
+
+def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
+    worker_args: WorkerArgs = serialized_worker_args.deserialize()
 
     logger = logging.getLogger()
@@ -63,18 +67,14 @@ def entrypoint(serialized_worker_args: bytes) -> Any | WorkerException:
 
     redirect_stdio_to_logger(logger)
 
-    store = dist.TCPStore(  # pyright: ignore[reportPrivateImportUsage]
+    store = dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
         host_name=worker_args.main_agent_hostname,
         port=worker_args.main_agent_port,
         world_size=worker_args.world_size,
         is_master=(worker_args.rank == 0),
     )
 
-    backend = worker_args.backend
-    if backend is None:
-        backend = "nccl" if torch.cuda.is_available() else "gloo"
-
-    logger.debug(f"using backend: {backend}")
+    backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")
 
     dist.init_process_group(
         backend=backend,
@@ -91,19 +91,17 @@
     os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
     os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
 
-    logger.debug(f"executing function: {worker_args.function}")
-
     try:
         return worker_args.function()
     except Exception as e:
-        logger.error(e)
+        traceback.print_exc()
        return WorkerException(exception=e)
     finally:
         sys.stdout.flush()
         sys.stderr.flush()
 
 
-def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int):
+def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
     agent_rank = launcher_agent_group.rank - 1
 
     payload = AgentPayload(
@@ -132,16 +130,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
 
     redirect_stdio_to_logger(logger)
 
-    if torch.__version__ >= "2.3":
-        from torch.distributed.elastic.multiprocessing import DefaultLogsSpecs
-
-        log_kwargs = {"logs_specs": DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
-    else:
-        log_kwargs = {"log_dir": tempfile.mkdtemp()}
-
     # spawn workers
 
-    ctx = start_processes(
+    ctx = dist_mp.start_processes(
         name=f"{hostname}_",
         entrypoint=entrypoint,
         args={
@@ -159,31 +150,30 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
                     world_size=worker_world_size,
                     hostname=launcher_payload.hostnames[agent_rank],
                     timeout=launcher_payload.timeout,
-                ).to_bytes(),
+                ).serialize(),
             )
             for i in range(num_workers)
         },
         envs={i: {} for i in range(num_workers)},
-        **log_kwargs,  # pyright: ignore [reportArgumentType]
+        **(
+            {"logs_specs": dist_mp.DefaultLogsSpecs(log_dir=tempfile.mkdtemp())}
+            if torch.__version__ >= "2.3"
+            else {"log_dir": tempfile.mkdtemp()}
+        ),  # pyright: ignore [reportArgumentType]
     )
 
-    logger.info("starting processes")
-
     try:
         status = None
         while True:
             if status is None or status.state == "running":
-                status = AgentStatus.from_result(
-                    result=ctx.wait(5), worker_global_ranks=worker_global_ranks
-                )
+                status = AgentStatus.from_result(ctx.wait(5))
 
             agent_statuses = launcher_agent_group.sync_agent_statuses(status=status)
 
-            if all(s.state == "done" for s in agent_statuses):
-                break
-            elif any(s.state == "failed" for s in agent_statuses):
+            all_done = all(s.state == "done" for s in agent_statuses)
+            any_failed = any(s.state == "failed" for s in agent_statuses)
+            if all_done or any_failed:
                 break
-    except:
-        raise
     finally:
         ctx.close()
         sys.stdout.flush()
diff --git a/src/torchrunx/environment.py b/src/torchrunx/environment.py
index edf1431d..179cfb8d 100644
--- a/src/torchrunx/environment.py
+++ b/src/torchrunx/environment.py
@@ -17,7 +17,9 @@ def slurm_hosts() -> list[str]:
     :rtype: list[str]
     """
     # TODO: sanity check SLURM variables, commands
-    assert in_slurm_job()
+    if not in_slurm_job():
+        msg = "Not in a SLURM job"
+        raise RuntimeError(msg)
     return (
         subprocess.check_output(["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]])
         .decode()
@@ -35,15 +37,18 @@ def slurm_workers() -> int:
     :rtype: int
     """
     # TODO: sanity check SLURM variables, commands
-    assert in_slurm_job()
+    if not in_slurm_job():
+        msg = "Not in a SLURM job"
+        raise RuntimeError(msg)
+
     if "SLURM_JOB_GPUS" in os.environ:
         # TODO: is it possible to allocate uneven GPUs across nodes?
         return len(os.environ["SLURM_JOB_GPUS"].split(","))
-    elif "SLURM_GPUS_PER_NODE" in os.environ:
+    if "SLURM_GPUS_PER_NODE" in os.environ:
         return int(os.environ["SLURM_GPUS_PER_NODE"])
-    else:
-        # TODO: should we assume that we plan to do one worker per CPU?
-        return int(os.environ["SLURM_CPUS_ON_NODE"])
+
+    # TODO: should we assume that we plan to do one worker per CPU?
+    return int(os.environ["SLURM_CPUS_ON_NODE"])
 
 
 def auto_hosts() -> list[str]:
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 89cde697..4c826a6c 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -5,14 +5,15 @@
 import itertools
 import logging
 import os
+import shlex
 import socket
 import subprocess
 import sys
-from collections import ChainMap
 from dataclasses import dataclass
 from functools import partial
 from logging import Handler
 from multiprocessing import Process
+from pathlib import Path
 from typing import Any, Callable, Literal, Sequence
 
 import fabric
@@ -28,6 +29,99 @@
 )
 
 
+def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
+    if hostnames == "auto":
+        return auto_hosts()
+    if hostnames == "slurm":
+        return slurm_hosts()
+    return hostnames
+
+
+def resolve_workers_per_host(
+    workers_per_host: int | list[int] | Literal["auto", "slurm"],
+    num_hosts: int,
+) -> list[int]:
+    if workers_per_host == "auto":
+        workers_per_host = auto_workers()
+    elif workers_per_host == "slurm":
+        workers_per_host = slurm_workers()
+
+    if isinstance(workers_per_host, int):
+        workers_per_host = [workers_per_host] * num_hosts
+    elif len(workers_per_host) != num_hosts:
+        msg = "len(workers_per_host) != len(hostnames)"
+        raise ValueError(msg)
+
+    return workers_per_host
+
+
+def build_logging_server(
+    log_handlers: list[Handler] | Literal["auto"] | None,
+    launcher_hostname: str,
+    hostnames: list[str],
+    workers_per_host: list[int],
+    log_dir: str | os.PathLike,
+    log_level: int,
+) -> LogRecordSocketReceiver:
+    if log_handlers is None:
+        log_handlers = []
+    elif log_handlers == "auto":
+        log_handlers = default_handlers(
+            hostnames=hostnames,
+            workers_per_host=workers_per_host,
+            log_dir=log_dir,
+            log_level=log_level,
+        )
+
+    return LogRecordSocketReceiver(
+        host=launcher_hostname,
+        port=get_open_port(),
+        handlers=log_handlers,
+    )
+
+
+def build_command(
+    launcher_hostname: str,
+    launcher_port: int,
+    logger_port: int,
+    world_size: int,
+    rank: int,
+    env_vars: Sequence[str],
+    env_file: str | os.PathLike | None,
+) -> str:
+    # shlex.quote prevents shell injection here (resolves S602 in execute_command)
+
+    commands = []
+
+    current_dir = shlex.quote(str(Path.cwd()))
+    commands.append("cd " + current_dir)
+
+    env_exports = []
+    for k, v in os.environ.items():
+        if any(fnmatch.fnmatch(k, e) for e in env_vars):
+            env_exports.append(shlex.quote(f"{k}={v}"))
+
+    if len(env_exports) > 0:
+        commands.append("export " + " ".join(env_exports))
+
+    if env_file is not None:
+        commands.append("source " + shlex.quote(str(env_file)))
+
+    python = shlex.quote(sys.executable)
+    launcher_hostname = shlex.quote(launcher_hostname)
+
+    commands.append(
+        f"{python} -u -m torchrunx "
+        f"--launcher-hostname {launcher_hostname} "
+        f"--launcher-port {launcher_port} "
+        f"--logger-port {logger_port} "
+        f"--world-size {world_size} "
+        f"--rank {rank}",
+    )
+
+    return " && ".join(commands)
+
+
 def is_localhost(hostname_or_ip: str) -> bool:
     # check if host is "loopback" address (i.e. designated to send to self)
     try:
@@ -48,10 +142,13 @@ def execute_command(
     ssh_config_file: str | os.PathLike | None = None,
 ) -> None:
     if is_localhost(hostname):
-        subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+        # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.8/library/subprocess.html#security-considerations)
+        # Made sure to shlex.quote arguments in build_command to prevent shell injection
+        subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602
     else:
         with fabric.Connection(
-            host=hostname, config=fabric.Config(runtime_ssh_path=ssh_config_file)
+            host=hostname,
+            config=fabric.Config(runtime_ssh_path=ssh_config_file),
         ) as conn:
             conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
 
@@ -81,7 +178,7 @@ def run(
         func: Callable,
         func_args: tuple[Any] | None = None,
         func_kwargs: dict[str, Any] | None = None,
-    ) -> dict[int, Any]:
+    ) -> dict[str, dict[int, Any]]:
         """
         Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
@@ -96,93 +193,53 @@
         :rtype: dict[int, Any]
         """
         if not dist.is_available():
-            raise RuntimeError("The torch.distributed package is not available.")
-
-        if self.hostnames == "auto":
-            self.hostnames = auto_hosts()
-        elif self.hostnames == "slurm":
-            self.hostnames = slurm_hosts()
-
-        num_hosts = len(self.hostnames)
-
-        if self.workers_per_host == "auto":
-            self.workers_per_host = auto_workers()
-        elif self.workers_per_host == "slurm":
-            self.workers_per_host = slurm_workers()
+            msg = "The torch.distributed package is not available."
+            raise RuntimeError(msg)
 
-        if isinstance(self.workers_per_host, int):
-            self.workers_per_host = [self.workers_per_host] * num_hosts
-
-        assert num_hosts == len(self.workers_per_host)
-
-        #
+        hostnames = resolve_hostnames(self.hostnames)
+        workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames))
 
         launcher_hostname = socket.getfqdn()
+        launcher_port = get_open_port()
+        world_size = len(hostnames) + 1
 
-        # setup logging
-
-        if self.log_handlers is None:
-            self.log_handlers = []
-        elif self.log_handlers == "auto":
-            self.log_handlers = default_handlers(
-                hostnames=self.hostnames,
-                workers_per_host=self.workers_per_host,
-                log_dir=os.environ.get("TORCHRUNX_DIR", "./torchrunx_logs"),
-                log_level=logging._nameToLevel.get(
-                    os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO"), logging.NOTSET
-                ),
-            )
+        # start logging server
 
-        logger_port = get_open_port()
-        log_receiver = LogRecordSocketReceiver(
-            host=launcher_hostname, port=logger_port, handlers=self.log_handlers
+        log_receiver = build_logging_server(
+            log_handlers=self.log_handlers,
+            launcher_hostname=launcher_hostname,
+            hostnames=hostnames,
+            workers_per_host=workers_per_host,
+            log_dir=Path(os.environ.get("TORCHRUNX_LOG_DIR", "torchrunx_logs")),
+            log_level=logging._nameToLevel[os.environ.get("TORCHRUNX_LOG_LEVEL", "INFO")],  # noqa: SLF001
         )
+
         log_process = Process(
             target=log_receiver.serve_forever,
             daemon=True,
         )
-        log_process.start()
-
-        # launch command
-
-        current_dir = os.getcwd()
-
-        env_exports = []
-        for k, v in os.environ.items():
-            if any(fnmatch.fnmatch(k, e) for e in self.env_vars):
-                env_exports.append(f"{k}={v}")
-
-        env_export_string = ""
-        if len(env_exports) > 0:
-            env_export_string = f"export {' '.join(env_exports)} && "
-
-        env_file_string = ""
-        if self.env_file is not None:
-            env_file_string = f"source {self.env_file} && "
 
-        launcher_port = get_open_port()
-        world_size = num_hosts + 1  # launcher + agents
+        log_process.start()
 
         # start agents on each node
-        for i, hostname in enumerate(self.hostnames):
+
+        for i, hostname in enumerate(hostnames):
             execute_command(
-                command=(
-                    f"cd {current_dir} && "
-                    f"{env_export_string}"
-                    f"{env_file_string}"
-                    f"{sys.executable} -u -m torchrunx "
-                    f"--launcher-hostname {launcher_hostname} "
-                    f"--launcher-port {launcher_port} "
-                    f"--logger-port {logger_port} "
-                    f"--world-size {world_size} "
-                    f"--rank {i+1}"
+                command=build_command(
+                    launcher_hostname=launcher_hostname,
+                    launcher_port=launcher_port,
+                    logger_port=log_receiver.port,
+                    world_size=world_size,
+                    rank=i + 1,
+                    env_vars=self.env_vars,
+                    env_file=self.env_file,
                 ),
                 hostname=hostname,
                 ssh_config_file=self.ssh_config_file,
             )
 
-        # initialize launcher–agent process group
-        # ranks = (launcher, agent_0, ..., agent_{num_hosts-1})
+        # initialize launcher-agent process group
+        # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
 
         launcher_agent_group = LauncherAgentGroup(
             launcher_hostname=launcher_hostname,
@@ -193,51 +250,45 @@ def run(
 
         # build and sync payloads between launcher and agents
 
-        _cumulative_workers = [0] + list(itertools.accumulate(self.workers_per_host))
-
-        worker_world_size = _cumulative_workers[-1]
+        _cumulative_workers = [0, *itertools.accumulate(workers_per_host)]
 
-        worker_global_ranks = []  # list of worker ranks per host
-        for n in range(num_hosts):
-            host_ranks = range(_cumulative_workers[n], _cumulative_workers[n + 1])
-            worker_global_ranks.append(list(host_ranks))
-
-        if func_args is None:
-            func_args = tuple()
-        if func_kwargs is None:
-            func_kwargs = dict()
+        worker_global_ranks = [
+            list(range(_cumulative_workers[n], _cumulative_workers[n + 1]))
+            for n in range(len(hostnames))
+        ]
 
         payload = LauncherPayload(
-            fn=partial(func, *func_args, **func_kwargs),
-            hostnames=self.hostnames,
-            worker_world_size=worker_world_size,
+            fn=partial(func, *(func_args or ()), **(func_kwargs or {})),
+            hostnames=hostnames,
             worker_global_ranks=worker_global_ranks,
+            worker_world_size=sum(workers_per_host),
             backend=self.backend,
             timeout=self.timeout,
         )
 
         launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
-        agent_pids = [p.process_id for p in agent_payloads]
 
         # loop to monitor agent statuses (until failed or done)
+
         try:
             while True:
+                # raises exception if communication timeout due to death of any agent
                agent_statuses = launcher_agent_group.sync_agent_statuses(status=None)
 
+                # raises exception if any agent failed
                 for s in agent_statuses:
-                    if s.state == "failed":
-                        for value in s.return_values.values():
-                            if isinstance(value, WorkerException):
-                                raise value.exception
+                    for value in s.return_values.values():
+                        if isinstance(value, WorkerException):
+                            raise value.exception
 
                 if all(s.state == "done" for s in agent_statuses):
                     break
         except:
             # cleanup: SIGTERM all agents
-            for agent_pid, agent_hostname in zip(agent_pids, self.hostnames):
+            for agent_payload, agent_hostname in zip(agent_payloads, hostnames):
                 execute_command(
-                    command=f"kill {agent_pid}",
+                    command=f"kill {agent_payload.process_id}",
                     hostname=agent_hostname,
                     ssh_config_file=self.ssh_config_file,
                 )
@@ -248,8 +299,10 @@ def run(
             log_process.kill()
             dist.destroy_process_group()
 
-        return_values: dict[int, Any] = dict(ChainMap(*[s.return_values for s in agent_statuses]))
-        return return_values
+        return {
+            hostname: agent_status.return_values
+            for hostname, agent_status in zip(hostnames, agent_statuses)
+        }
 
 
 def launch(
@@ -273,7 +326,7 @@ def launch(
     ),
     env_file: str | os.PathLike | None = None,
     timeout: int = 600,
-) -> dict[int, Any]:
+) -> dict[str, dict[int, Any]]:
     """
     Launch a distributed PyTorch function on the specified nodes.
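
Note on the return-type change above: `launch()` / `Launcher.run()` now return one
dictionary per host ({hostname: {worker rank: return value}}) instead of the previous
flat {global rank: value} mapping. A minimal consumption sketch (the hostname
"localhost", the two-worker setup, and the `get_rank` function are assumptions for
illustration, not part of this patch):

    import os

    import torchrunx as trx

    def get_rank() -> int:
        # hypothetical worker function: each worker returns its own rank
        return int(os.environ["RANK"])

    results = trx.launch(
        func=get_rank,
        hostnames=["localhost"],
        workers_per_host=2,
        backend="gloo",
    )

    # one entry per host; worker return values are keyed by rank within it
    host_results = results["localhost"]  # e.g. {0: 0, 1: 1}
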
diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py
index 469c845f..d12b27f7 100644
--- a/src/torchrunx/logging_utils.py
+++ b/src/torchrunx/logging_utils.py
@@ -2,15 +2,115 @@
 
 import datetime
 import logging
-import os
 import pickle
 import struct
 from contextlib import redirect_stderr, redirect_stdout
+from dataclasses import dataclass
 from io import StringIO
 from logging import Handler, Logger
 from logging.handlers import SocketHandler
 from pathlib import Path
 from socketserver import StreamRequestHandler, ThreadingTCPServer
+from typing import TYPE_CHECKING
+
+from typing_extensions import Self
+
+if TYPE_CHECKING:
+    import os
+
+## Launcher utilities
+
+
+class LogRecordSocketReceiver(ThreadingTCPServer):
+    def __init__(self, host: str, port: int, handlers: list[Handler]) -> None:
+        self.host = host
+        self.port = port
+
+        class _LogRecordStreamHandler(StreamRequestHandler):
+            def handle(self) -> None:
+                while True:
+                    chunk_size = 4
+                    chunk = self.connection.recv(chunk_size)
+                    if len(chunk) < chunk_size:
+                        break
+                    slen = struct.unpack(">L", chunk)[0]
+                    chunk = self.connection.recv(slen)
+                    while len(chunk) < slen:
+                        chunk = chunk + self.connection.recv(slen - len(chunk))
+                    obj = pickle.loads(chunk)
+                    record = logging.makeLogRecord(obj)
+
+                    for handler in handlers:
+                        handler.handle(record)
+
+        super().__init__(
+            server_address=(host, port),
+            RequestHandlerClass=_LogRecordStreamHandler,
+            bind_and_activate=True,
+        )
+        self.daemon_threads = True
+
+    def shutdown(self) -> None:
+        """override BaseServer.shutdown() with added timeout"""
+        self._BaseServer__shutdown_request = True
+        self._BaseServer__is_shut_down.wait(timeout=3)  # pyright: ignore[reportAttributeAccessIssue]
+
+
+## Agent/worker utilities
+
+
+@dataclass
+class WorkerLogRecord(logging.LogRecord):
+    hostname: str
+    worker_rank: int | None
+
+    @classmethod
+    def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int | None) -> Self:
+        record.hostname = hostname
+        record.worker_rank = worker_rank
+        record.__class__ = cls
+        return record  # pyright: ignore [reportReturnType]
+
+
+def log_records_to_socket(
+    logger: Logger,
+    hostname: str,
+    worker_rank: int | None,
+    logger_hostname: str,
+    logger_port: int,
+) -> None:
+    logger.setLevel(logging.NOTSET)
+
+    old_factory = logging.getLogRecordFactory()
+
+    def record_factory(*args, **kwargs) -> WorkerLogRecord:  # noqa: ANN002, ANN003
+        record = old_factory(*args, **kwargs)
+        return WorkerLogRecord.from_record(record, hostname, worker_rank)
+
+    logging.setLogRecordFactory(record_factory)
+
+    logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
+
+
+def redirect_stdio_to_logger(logger: Logger) -> None:
+    class _LoggingStream(StringIO):
+        def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
+            super().__init__()
+            self.logger = logger
+            self.level = level
+
+        def flush(self) -> None:
+            super().flush()
+            value = self.getvalue()
+            if value != "":
+                self.logger.log(self.level, value)
+                self.truncate(0)
+                self.seek(0)
+
+    logging.captureWarnings(capture=True)
+    redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
+    redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
+
 
 ## Handler utilities
@@ -21,14 +121,27 @@ def add_filter_to_handler(
     worker_rank: int | None,
     log_level: int = logging.NOTSET,
 ) -> None:
-    def _filter(record: logging.LogRecord) -> bool:
+    def _filter(record: WorkerLogRecord) -> bool:
         return (
-            record.hostname == hostname  # pyright: ignore[reportAttributeAccessIssue]
-            and record.worker_rank == worker_rank  # pyright: ignore[reportAttributeAccessIssue]
+            record.hostname == hostname
+            and record.worker_rank == worker_rank
             and record.levelno >= log_level
         )
 
-    handler.addFilter(_filter)
+    handler.addFilter(_filter)  # pyright: ignore [reportArgumentType]
+
+
+def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
+    handler = logging.StreamHandler()
+    add_filter_to_handler(handler, hostname, rank, log_level=log_level)
+    handler.setFormatter(
+        logging.Formatter(
+            "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
+            if rank is not None
+            else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
+        ),
+    )
+    return handler
 
 
 def file_handler(
@@ -52,11 +165,11 @@ def file_handlers(
 ) -> list[Handler]:
     handlers = []
 
-    os.makedirs(log_dir, exist_ok=True)
+    Path(log_dir).mkdir(parents=True, exist_ok=True)
     timestamp = datetime.datetime.now().isoformat(timespec="seconds")
 
     for hostname, num_workers in zip(hostnames, workers_per_host):
-        for rank in [None] + list(range(num_workers)):
+        for rank in [None, *range(num_workers)]:
             file_path = (
                 f"{log_dir}/{timestamp}-{hostname}"
                 + (f"[{rank}]" if rank is not None else "")
@@ -67,19 +180,6 @@ def file_handlers(
     return handlers
 
 
-def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
-    handler = logging.StreamHandler()
-    add_filter_to_handler(handler, hostname, rank, log_level=log_level)
-    handler.setFormatter(
-        logging.Formatter(
-            "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
-            if rank is not None
-            else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s"
-        )
-    )
-    return handler
-
-
 def default_handlers(
     hostnames: list[str],
     workers_per_host: list[int],
@@ -89,83 +189,5 @@
     return [
         stream_handler(hostname=hostnames[0], rank=None, log_level=log_level),
         stream_handler(hostname=hostnames[0], rank=0, log_level=log_level),
-    ] + file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level)
-
-
-## Agent/worker utilities
-
-
-def log_records_to_socket(
-    logger: Logger,
-    hostname: str,
-    worker_rank: int | None,
-    logger_hostname: str,
-    logger_port: int,
-):
-    logger.setLevel(logging.NOTSET)
-
-    old_factory = logging.getLogRecordFactory()
-
-    def record_factory(*args, **kwargs):
-        record = old_factory(*args, **kwargs)
-        record.hostname = hostname
-        record.worker_rank = worker_rank
-        return record
-
-    logging.setLogRecordFactory(record_factory)
-
-    logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
-
-
-def redirect_stdio_to_logger(logger: Logger):
-    class _LoggingStream(StringIO):
-        def __init__(self, logger: Logger, level: int = logging.NOTSET):
-            super().__init__()
-            self.logger = logger
-            self.level = level
-
-        def flush(self):
-            super().flush()
-            value = self.getvalue()
-            if value != "":
-                self.logger.log(self.level, f"\n{value}")
-                self.truncate(0)
-                self.seek(0)
-
-    logging.captureWarnings(True)
-    redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__()
-    redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
-
-
-## Launcher utilities
-
-
-class LogRecordSocketReceiver(ThreadingTCPServer):
-    def __init__(self, host: str, port: int, handlers: list[Handler]):
-        class _LogRecordStreamHandler(StreamRequestHandler):
-            def handle(self):
-                while True:
-                    chunk = self.connection.recv(4)
-                    if len(chunk) < 4:
-                        break
-                    slen = struct.unpack(">L", chunk)[0]
-                    chunk = self.connection.recv(slen)
-                    while len(chunk) < slen:
-                        chunk = chunk + self.connection.recv(slen - len(chunk))
-                    obj = pickle.loads(chunk)
-                    record = logging.makeLogRecord(obj)
-                    #
-                    for handler in handlers:
-                        handler.handle(record)
-
-        super().__init__(
-            server_address=(host, port),
-            RequestHandlerClass=_LogRecordStreamHandler,
-            bind_and_activate=True,
-        )
-        self.daemon_threads = True
-
-    def shutdown(self):
-        """override BaseServer.shutdown() with added timeout"""
-        self._BaseServer__shutdown_request = True
-        self._BaseServer__is_shut_down.wait(timeout=3)  # pyright: ignore[reportAttributeAccessIssue]
+        *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level),
+    ]
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index 3a14d342..1bd25c52 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -4,19 +4,20 @@
 import socket
 from contextlib import closing
 from dataclasses import dataclass, field
-from typing import Any, Callable, Literal
+from typing import TYPE_CHECKING, Any, Callable, Literal
 
 import cloudpickle
 import torch.distributed as dist
-from torch.distributed.elastic.multiprocessing.api import RunProcsResult
 from typing_extensions import Self
 
+if TYPE_CHECKING:
+    from torch.distributed.elastic.multiprocessing.api import RunProcsResult
+
 
 def get_open_port() -> int:
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
         s.bind(("", 0))
-        port = s.getsockname()[1]
-        return port
+        return s.getsockname()[1]
 
 
 @dataclass
@@ -28,8 +29,8 @@ class WorkerException:
 class LauncherPayload:
     fn: Callable
     hostnames: list[str]
-    worker_world_size: int
     worker_global_ranks: list[list[int]]
+    worker_world_size: int
     backend: Literal["mpi", "gloo", "nccl", "ucc", None]
     timeout: int
@@ -47,7 +48,7 @@ class AgentStatus:
     return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
 
     @classmethod
-    def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[int]) -> Self:
+    def from_result(cls, result: RunProcsResult | None) -> Self:
         if result is None:
             return cls(state="running")
@@ -60,7 +61,7 @@ def from_result(cls, result: RunProcsResult | None, worker_global_ranks: list[in
 
         return cls(
             state=state,
-            return_values={worker_global_ranks[k]: v for k, v in return_values.items()},
+            return_values=return_values,
         )
@@ -76,7 +77,7 @@ def __post_init__(self) -> None:
             backend="gloo",
             world_size=self.world_size,
             rank=self.rank,
-            store=dist.TCPStore(  # pyright: ignore[reportPrivateImportUsage]
+            store=dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
                 host_name=self.launcher_hostname,
                 port=self.launcher_port,
                 world_size=self.world_size,
@@ -85,27 +86,27 @@ def __post_init__(self) -> None:
             timeout=datetime.timedelta(seconds=30),
         )
 
-    def _serialize(self, object: Any) -> bytes:
-        return cloudpickle.dumps(object)
+    def _serialize(self, obj: Any) -> bytes:
+        return cloudpickle.dumps(obj)
 
     def _deserialize(self, serialized: bytes) -> Any:
         return cloudpickle.loads(serialized)
 
-    def _all_gather(self, object: Any) -> list:
+    def _all_gather(self, obj: Any) -> list:
         """gather object from every rank to list on every rank"""
-        object_bytes = self._serialize(object)
+        object_bytes = self._serialize(obj)
         object_list = [b""] * self.world_size
         dist.all_gather_object(object_list=object_list, obj=object_bytes, group=self.group)
-        object_list = [self._deserialize(o) for o in object_list]
-        return object_list
+        return [self._deserialize(o) for o in object_list]
 
     def sync_payloads(
-        self, payload: LauncherPayload | AgentPayload
+        self,
+        payload: LauncherPayload | AgentPayload,
     ) -> tuple[LauncherPayload, list[AgentPayload]]:
-        payloads = self._all_gather(object=payload)
+        payloads = self._all_gather(payload)
         launcher_payload = payloads[0]
         agent_payloads = payloads[1:]
         return launcher_payload, agent_payloads
 
     def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
-        return self._all_gather(object=status)[1:]  # [0] is launcher (status=None)
+        return self._all_gather(status)[1:]  # [0] is launcher (status=None)
diff --git a/tests/test_CI.py b/tests/test_ci.py
similarity index 66%
rename from tests/test_CI.py
rename to tests/test_ci.py
index b86cad64..f72f3ef4 100644
--- a/tests/test_CI.py
+++ b/tests/test_ci.py
@@ -1,5 +1,7 @@
 import os
 import tempfile
+from pathlib import Path
+from typing import NoReturn
 
 import pytest
 import torch
@@ -8,14 +10,11 @@
 import torchrunx as trx
 
 
-def test_simple_localhost():
-    def dist_func():
+def test_simple_localhost() -> None:
+    def dist_func() -> torch.Tensor:
         rank = int(os.environ["RANK"])
 
-        if rank == 0:
-            w = torch.rand((100, 100))  # in_dim, out_dim
-        else:
-            w = torch.zeros((100, 100))
+        w = torch.rand((100, 100)) if rank == 0 else torch.zeros((100, 100))
 
         dist.broadcast(w, 0)
@@ -38,48 +37,50 @@ def dist_func():
         backend="gloo",
         # log_dir="./test_logs"
     )
 
-    assert torch.all(r[0] == r[1])
+    results = next(iter(r.values()))
+    assert torch.all(results[0] == results[1])
 
 
-def test_logging():
-    def dist_func():
+def test_logging() -> None:
+    def dist_func() -> None:
         rank = int(os.environ["RANK"])
         print(f"worker rank: {rank}")
 
     tmp = tempfile.mkdtemp()
-    os.environ["TORCHRUNX_DIR"] = tmp
+    os.environ["TORCHRUNX_LOG_DIR"] = tmp
+
+    num_workers = 2
 
     trx.launch(
         func=dist_func,
         func_kwargs={},
-        workers_per_host=2,
+        workers_per_host=num_workers,
         backend="gloo",
     )
 
     log_files = next(os.walk(tmp), (None, None, []))[2]
 
-    assert len(log_files) == 3
+    assert len(log_files) == num_workers + 1
 
     for file in log_files:
-        with open(f"{tmp}/{file}") as f:
+        with Path(f"{tmp}/{file}").open() as f:
             contents = f.read()
             print(contents)
             if file.endswith("[0].log"):
                 assert "worker rank: 0\n" in contents
             elif file.endswith("[1].log"):
                 assert "worker rank: 1\n" in contents
-            else:
-                assert "starting processes" in contents
 
 
-def test_error():
-    def error_func():
-        raise ValueError("abcdefg")
+def test_error() -> None:
+    def error_func() -> NoReturn:
+        msg = "abcdefg"
+        raise ValueError(msg)
 
     tmp = tempfile.mkdtemp()
     os.environ["TORCHRUNX_DIR"] = tmp
 
-    with pytest.raises(ValueError) as excinfo:
+    with pytest.raises(ValueError) as excinfo:  # noqa: PT011
         trx.launch(
             func=error_func,
             func_kwargs={},
diff --git a/tests/test_func.py b/tests/test_func.py
index 9db6454d..8fb264bf 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -6,21 +6,23 @@
 import torchrunx as trx
 
 
-def test_launch():
+def test_launch() -> None:
     result = trx.launch(
         func=simple_matmul,
         hostnames="slurm",
         workers_per_host="slurm",
     )
 
+    result_values = [v for host_results in result.values() for v in host_results.values()]
+
     t = True
-    for i in range(len(result)):
-        t = t and torch.all(result[i] == result[0])
+    for i in range(len(result_values)):
+        t = t and torch.all(result_values[i] == result_values[0])
 
     assert t, "Not all tensors equal"
 
 
-def simple_matmul():
+def simple_matmul() -> torch.Tensor:
     rank = int(os.environ["RANK"])
     local_rank = int(os.environ["LOCAL_RANK"])
     device = torch.device(local_rank) if torch.cuda.is_available() else torch.device("cpu")
diff --git a/tests/test_submitit.py b/tests/test_submitit.py
index 290f7aad..433e3382 100644
--- a/tests/test_submitit.py
+++ b/tests/test_submitit.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 
 import submitit
@@ -9,22 +11,22 @@
 
 class DummyDataset(Dataset):
-    def __init__(self, max_text_length=16, num_samples=20000) -> None:
+    def __init__(self, max_text_length: int = 16, num_samples: int = 20000) -> None:
         super().__init__()
         self.input_ids = torch.randint(0, 30522, (num_samples, max_text_length))
         self.labels = copy.deepcopy(self.input_ids)
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self.input_ids)
 
-    def __getitem__(self, index):
+    def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
         return {
             "input_ids": self.input_ids[index],
             "labels": self.labels[index],
         }
 
 
-def main():
+def main() -> None:
     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
     train_dataset = DummyDataset()
@@ -38,7 +40,7 @@ def main():
     )
 
     trainer = Trainer(
-        model=model,  # type: ignore
+        model=model,
         args=training_arguments,
         train_dataset=train_dataset,
     )
@@ -46,11 +48,11 @@ def main():
     trainer.train()
 
 
-def launch():
+def launch() -> None:
     trx.launch(func=main, func_kwargs={}, hostnames="slurm", workers_per_host="slurm")
 
 
-def test_submitit():
+def test_submitit() -> None:
     executor = submitit.SlurmExecutor(folder="logs")
 
     executor.update_parameters(
diff --git a/tests/test_train.py b/tests/test_train.py
index d28f5ef5..b654a8b7 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -3,23 +3,21 @@
 import torchrunx as trx
 
 
-def worker():
+def worker() -> None:
     import torch
 
-    class TwoLinLayerNet(torch.nn.Module):
-        def __init__(self):
+    class MLP(torch.nn.Module):
+        def __init__(self) -> None:
             super().__init__()
             self.a = torch.nn.Linear(10, 10, bias=False)
             self.b = torch.nn.Linear(10, 1, bias=False)
 
-        def forward(self, x):
-            a = self.a(x)
-            b = self.b(x)
-            return (a, b)
+        def forward(self, x: torch.Tensor) -> torch.Tensor:
+            return self.b(self.a(x))
 
     local_rank = int(os.environ["LOCAL_RANK"])
     print("init model")
-    model = TwoLinLayerNet().to(local_rank)
+    model = MLP().to(local_rank)
     print("init ddp")
     ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
 
@@ -28,11 +26,11 @@ def forward(self, x):
     for _ in range(20):
         output = ddp_model(inp)
-        loss = output[0] + output[1]
-        loss.sum().backward()
+        loss = output.sum()
+        loss.backward()
 
 
-def test_distributed_train():
+def test_distributed_train() -> None:
     trx.launch(
         worker,
         hostnames="slurm",
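
Related sketch for the error-handling changes: per the monitoring loop in
launcher.py above (and `test_error` in tests/test_ci.py), an exception raised in any
worker is wrapped in a WorkerException, gathered by the launcher, and re-raised from
`launch()`. A minimal example (the "localhost" host and single worker are assumed
values for illustration):

    import torchrunx as trx

    def error_func() -> None:
        msg = "abcdefg"
        raise ValueError(msg)  # raised inside the worker process

    try:
        trx.launch(
            func=error_func,
            func_kwargs={},
            hostnames=["localhost"],
            workers_per_host=1,
            backend="gloo",
        )
    except ValueError as e:
        # the original worker exception, re-raised on the launcher
        print(f"caught from worker: {e}")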