From ae9373bf44cf59e1111f1512dc9f4182e37a86ca Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Sat, 14 Sep 2024 02:17:40 -0400
Subject: [PATCH 01/10] no Sequence arguments in launcher

---
 src/torchrunx/launcher.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index fa73ae04..68de097e 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -14,7 +14,7 @@
 from logging import Handler
 from multiprocessing import Process
 from pathlib import Path
-from typing import Any, Callable, Literal, Sequence
+from typing import Any, Callable, Literal
 
 import fabric
 import torch.distributed as dist
@@ -86,7 +86,7 @@ def build_command(
     logger_port: int,
     world_size: int,
     rank: int,
-    env_vars: Sequence[str],
+    env_vars: list[str] | tuple[str],
     env_file: str | os.PathLike | None,
 ) -> str:
     # shlex.quote prevents shell injection here (resolves S602 in execute_command)
@@ -160,7 +160,7 @@ class Launcher:
     ssh_config_file: str | os.PathLike | None = None
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
     log_handlers: list[Handler] | Literal["auto"] | None = "auto"
-    env_vars: Sequence[str] = (
+    env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportAssignmentType]
         "PATH",
         "LD_LIBRARY",
         "LIBRARY_PATH",
@@ -321,8 +321,8 @@ def launch(
     workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
     ssh_config_file: str | os.PathLike | None = None,
     backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
-    log_handlers: list[Handler] | Literal["auto"] = "auto",
-    env_vars: Sequence[str] = (
+    log_handlers: list[Handler] | Literal["auto"] | None = "auto",
+    env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportArgumentType]
     "PATH",
     "LD_LIBRARY",
     "LIBRARY_PATH",

From fd47acf8d7f9f1f84db8302b6fda50145eb5e447 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Sat, 14 Sep 2024 14:48:22 -0400
Subject: [PATCH 02/10] auto and None options for backend

---
 src/torchrunx/agent.py    | 49 ++++++++++++++++++++-------------------
 src/torchrunx/launcher.py |  4 ++--
 src/torchrunx/utils.py    |  2 +-
 3 files changed, 28 insertions(+), 27 deletions(-)

diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 04d1ec92..0f43506b 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -32,7 +32,7 @@ class WorkerArgs:
     logger_port: int
     main_agent_hostname: str
     main_agent_port: int
-    backend: Literal["mpi", "gloo", "nccl", "ucc", None]
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
     rank: int
     local_rank: int
     local_world_size: int
@@ -67,29 +67,30 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException:
 
     redirect_stdio_to_logger(logger)
 
-    store = dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
-        host_name=worker_args.main_agent_hostname,
-        port=worker_args.main_agent_port,
-        world_size=worker_args.world_size,
-        is_master=(worker_args.rank == 0),
-    )
-
-    backend = worker_args.backend or ("nccl" if torch.cuda.is_available() else "gloo")
-
-    dist.init_process_group(
-        backend=backend,
-        world_size=worker_args.world_size,
-        rank=worker_args.rank,
-        store=store,
-        timeout=datetime.timedelta(seconds=worker_args.timeout),
-    )
-
-    os.environ["RANK"] = str(worker_args.rank)
-    os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
-    os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
-    os.environ["WORLD_SIZE"] = str(worker_args.world_size)
-    os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
-    os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+    if worker_args.backend is not None:
+        os.environ["RANK"] = str(worker_args.rank)
+        os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
+        os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
+        os.environ["WORLD_SIZE"] = str(worker_args.world_size)
+        os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
+        os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+
+        backend = worker_args.backend
+        if backend == "auto":
+            backend = "nccl" if torch.cuda.is_available() else "gloo"
+
+        dist.init_process_group(
+            backend=backend,
+            world_size=worker_args.world_size,
+            rank=worker_args.rank,
+            store=dist.TCPStore(  # pyright: ignore [reportPrivateImportUsage]
+                host_name=worker_args.main_agent_hostname,
+                port=worker_args.main_agent_port,
+                world_size=worker_args.world_size,
+                is_master=(worker_args.rank == 0),
+            ),
+            timeout=datetime.timedelta(seconds=worker_args.timeout),
+        )
 
     try:
         return worker_args.function()
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 68de097e..43097f1f 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -158,7 +158,7 @@ class Launcher:
     hostnames: list[str] | Literal["auto", "slurm"] = "auto"
     workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"
     ssh_config_file: str | os.PathLike | None = None
-    backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto"
     log_handlers: list[Handler] | Literal["auto"] | None = "auto"
     env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportAssignmentType]
         "PATH",
@@ -320,7 +320,7 @@ def launch(
     hostnames: list[str] | Literal["auto", "slurm"] = "auto",
     workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
     ssh_config_file: str | os.PathLike | None = None,
-    backend: Literal["mpi", "gloo", "nccl", "ucc", None] = None,
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto",
     log_handlers: list[Handler] | Literal["auto"] | None = "auto",
     env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportArgumentType]
         "PATH",
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index 0fafec9d..5bf6ce3e 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -31,7 +31,7 @@ class LauncherPayload:
     hostnames: list[str]
     worker_global_ranks: list[list[int]]
     worker_world_size: int
-    backend: Literal["mpi", "gloo", "nccl", "ucc", None]
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
     timeout: int
 
 

From 3707de28eef8c624864a6d4747829d3bad26a104 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Sat, 14 Sep 2024 15:20:32 -0400
Subject: [PATCH 03/10] refactoring

---
 src/torchrunx/launcher.py | 42 ++++++++++----------
 src/torchrunx/utils.py    | 80 +++++++++++++++++++--------------------
 2 files changed, 62 insertions(+), 60 deletions(-)

diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 43097f1f..aaf9fd5f 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -80,7 +80,7 @@ def build_logging_server(
     )
 
 
-def build_command(
+def build_launch_command(
     launcher_hostname: str,
     launcher_port: int,
     logger_port: int,
@@ -122,33 +122,35 @@ def build_command(
     return " && ".join(commands)
 
 
-def is_localhost(hostname_or_ip: str) -> bool:
-    # check if host is "loopback" address (i.e. designated to send to self)
-    try:
-        ip = ipaddress.ip_address(hostname_or_ip)
-    except ValueError:
-        ip = ipaddress.ip_address(socket.gethostbyname(hostname_or_ip))
-    if ip.is_loopback:
-        return True
-    # else compare local interface addresses between host and localhost
-    host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(ip), None)]
-    localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
-    return len(set(host_addrs) & set(localhost_addrs)) > 0
-
-
 def execute_command(
     command: str,
     hostname: str,
     ssh_config_file: str | os.PathLike | None = None,
 ) -> None:
-    if is_localhost(hostname):
+    is_localhost = True
+    _hostname_or_ip = hostname
+    try:
+        _ip = ipaddress.ip_address(_hostname_or_ip)
+    except ValueError:
+        _ip = ipaddress.ip_address(socket.gethostbyname(_hostname_or_ip))
+    if not _ip.is_loopback:
+        # compare local interface addresses between host and localhost
+        _host_addrs = [addr[4][0] for addr in socket.getaddrinfo(str(_ip), None)]
+        _localhost_addrs = [addr[4][0] for addr in socket.getaddrinfo(socket.gethostname(), None)]
+        is_localhost = len(set(_host_addrs) & set(_localhost_addrs)) > 0
+
+    if is_localhost:
         # S602: subprocess.Popen is called with shell=True (https://docs.python.org/3.8/library/subprocess.html#security-considerations)
         # Made sure to shlex.quote arguments in build_command to prevent shell injection
         subprocess.Popen(command, shell=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # noqa: S602
     else:
+        runtime_ssh_path = ssh_config_file
+        if isinstance(ssh_config_file, os.PathLike):
+            runtime_ssh_path = str(ssh_config_file)
+
         with fabric.Connection(
             host=hostname,
-            config=fabric.Config(runtime_ssh_path=ssh_config_file),
+            config=fabric.Config(runtime_ssh_path=runtime_ssh_path),
         ) as conn:
             conn.run(f"{command} >> /dev/null 2>&1 &", asynchronous=True)
 
@@ -160,7 +162,7 @@ class Launcher:
     ssh_config_file: str | os.PathLike | None = None
     backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto"
     log_handlers: list[Handler] | Literal["auto"] | None = "auto"
-    env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportAssignmentType]
+    env_vars: tuple[str] = (  # pyright: ignore [reportAssignmentType]
         "PATH",
         "LD_LIBRARY",
         "LIBRARY_PATH",
@@ -231,7 +233,7 @@ def run(  # noqa: C901, PLR0912
 
         for i, hostname in enumerate(hostnames):
             execute_command(
-                command=build_command(
+                command=build_launch_command(
                     launcher_hostname=launcher_hostname,
                     launcher_port=launcher_port,
                     logger_port=log_receiver.port,
@@ -322,7 +324,7 @@ def launch(
     ssh_config_file: str | os.PathLike | None = None,
     backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto",
     log_handlers: list[Handler] | Literal["auto"] | None = "auto",
-    env_vars: list[str] | tuple[str] = (  # pyright: ignore [reportArgumentType]
+    env_vars: tuple[str] = (  # pyright: ignore [reportArgumentType]
         "PATH",
         "LD_LIBRARY",
         "LIBRARY_PATH",
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index 5bf6ce3e..d4cd5684 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -25,46 +25,6 @@ class WorkerException:
     exception: Exception
 
 
-@dataclass
-class LauncherPayload:
-    fn: Callable
-    hostnames: list[str]
-    worker_global_ranks: list[list[int]]
-    worker_world_size: int
-    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
-    timeout: int
-
-
-@dataclass
-class AgentPayload:
-    hostname: str
-    port: int
-    process_id: int
-
-
-@dataclass
-class AgentStatus:
-    state: Literal["running", "failed", "done"]
-    return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
-
-    @classmethod
-    def from_result(cls, result: RunProcsResult | None) -> Self:
-        if result is None:
-            return cls(state="running")
-
-        return_values = result.return_values
-
-        if any(isinstance(v, WorkerException) for v in return_values.values()):
-            state = "failed"
-        else:
-            state = "done"
-
-        return cls(
-            state=state,
-            return_values=return_values,
-        )
-
-
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
@@ -115,3 +75,43 @@ def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
 
     def shutdown(self) -> None:
         dist.destroy_process_group(group=self.group)
+
+
+@dataclass
+class LauncherPayload:
+    fn: Callable
+    hostnames: list[str]
+    worker_global_ranks: list[list[int]]
+    worker_world_size: int
+    backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
+    timeout: int
+
+
+@dataclass
+class AgentPayload:
+    hostname: str
+    port: int
+    process_id: int
+
+
+@dataclass
+class AgentStatus:
+    state: Literal["running", "failed", "done"]
+    return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
+
+    @classmethod
+    def from_result(cls, result: RunProcsResult | None) -> Self:
+        if result is None:
+            return cls(state="running")
+
+        return_values = result.return_values
+
+        if any(isinstance(v, WorkerException) for v in return_values.values()):
+            state = "failed"
+        else:
+            state = "done"
+
+        return cls(
+            state=state,
+            return_values=return_values,
+        )

From c5f082a4bcd2b2db04b6caafd4a5429049b212c5 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Fri, 27 Sep 2024 19:04:08 +0000
Subject: [PATCH 04/10] Updates to docs

Co-authored-by: Peter Curtin

---
 CONTRIBUTING.md              |  3 ++
 README.md                    | 84 ++++++++++++++++++++++++++++--------
 docs/source/contributing.rst | 21 +++++----
 3 files changed, 80 insertions(+), 28 deletions(-)
 create mode 100644 CONTRIBUTING.md

diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 00000000..fa436244
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,3 @@
+# Contributing
+
+We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by Github Actions.
diff --git a/README.md b/README.md
index e65a3336..5b8f26bc 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,14 @@
 # torchrunx 🔥
 
+By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)
+
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/pyproject.toml)
 [![PyPI - Version](https://img.shields.io/pypi/v/torchrunx)](https://pypi.org/project/torchrunx/)
 ![Tests](https://img.shields.io/github/actions/workflow/status/apoorvkh/torchrunx/.github%2Fworkflows%2Fmain.yml)
 [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
 [![GitHub License](https://img.shields.io/github/license/apoorvkh/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)
 
-Automatically launch functions and initialize distributed PyTorch environments on multiple machines
+Automatically launch PyTorch functions onto multiple machines or GPUs
 
 ## Installation
 
@@ -18,39 +20,83 @@ Requirements:
 - Operating System: Linux
 - Python >= 3.8.1
 - PyTorch >= 2.0
-- Shared filesystem & passwordless SSH between hosts
+- Shared filesystem & SSH between hosts
+
+## Features
+
+- Distribute PyTorch functions to multiple GPUs or machines
+- `torchrun` with the convenience of a Python function
+- Integration with SLURM
+
+Advantages:
+
+- Self-cleaning: avoid memory leaks!
+- Better for complex workflows
+- Doesn't parallelize the whole script: just what you want
+- Run distributed functions from Python Notebooks
 
 ## Usage
 
+Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
+
 ```python
-# Simple example
-def distributed_function():
-    pass
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
 ```
 
 ```python
 import torchrunx as trx
 
-trx.launch(
-    func=distributed_function,
-    func_kwargs={},
-    hostnames=["node1", "node2"], # or just: ["localhost"]
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
     workers_per_host=2
-)
+)["localhost"][0] # return from rank 0 (first worker on "localhost")
 ```
 
-### In a SLURM allocation
+We could also launch multiple functions, using different GPUs:
 
 ```python
-trx.launch(
-    # ...
-    hostnames=trx.slurm_hosts(),
-    workers_per_host=trx.slurm_workers()
-)
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
+
+def test_model(model_path, test_dataset):
+    model = torch.load(model_path)
+    accuracy = inference(model, test_dataset)
+    return accuracy
 ```
 
-## Compared to other tools
+```python
+import torchrunx as trx
+
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
+    workers_per_host=2
+)["localhost"][0] # return from rank 0 (first worker on "localhost")
+
+
 
-## Contributing
+accuracy = trx.launch(
+    func=test_model,
+    func_kwargs={'model_path': model_path, 'test_dataset': mnist_test},
+    hostnames=["localhost"],
+    workers_per_host=1
+)["localhost"][0]
 
-We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI` and `conda-forge`. Our release pipeline is powered by Github Actions.
+print(f'Accuracy: {accuracy}')
+```
\ No newline at end of file
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index fb9554e3..707e376c 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -1,17 +1,20 @@
 Contributing
 ============
 
-Development environment
------------------------
+.. include:: ../../CONTRIBUTING.md
+   :parser: myst_parser.sphinx_
 
-Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.
+.. Development environment
+.. -----------------------
 
-Testing
--------
+.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi <https://pixi.sh/latest/#installation>`_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff <https://github.com/astral-sh/ruff>`_ for linting and formatting, `pyright <https://github.com/microsoft/pyright>`_ for type checking, and ``pytest`` for testing.
 
-``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
+.. Testing
+.. -------
 
-Contributing
-------------
+.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
 
-Make a pull request with your changes and we'll try to look at soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
\ No newline at end of file
+.. Contributing
+.. ------------
+
+.. Make a pull request with your changes and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
\ No newline at end of file

From c17be1f32609c797f151c9ff7b3efe7ba5765e77 Mon Sep 17 00:00:00 2001
From: Apoorv Khandelwal
Date: Sat, 28 Sep 2024 15:55:42 -0400
Subject: [PATCH 05/10] Update README.md

---
 README.md | 49 +++++++++++++++++++++++++++------------------
 1 file changed, 31 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index 5b8f26bc..09ef5171 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
 # torchrunx 🔥
 
-By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)
-
 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/pyproject.toml)
 [![PyPI - Version](https://img.shields.io/pypi/v/torchrunx)](https://pypi.org/project/torchrunx/)
 ![Tests](https://img.shields.io/github/actions/workflow/status/apoorvkh/torchrunx/.github%2Fworkflows%2Fmain.yml)
 [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io)
 [![GitHub License](https://img.shields.io/github/license/apoorvkh/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/LICENSE)
 
-Automatically launch PyTorch functions onto multiple machines or GPUs
+By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.com/pmcurtin)
+
+**Automatically distribute PyTorch functions onto multiple machines or GPUs**
 
 ## Installation
 
@@ -16,24 +16,37 @@
 pip install torchrunx
 ```
 
-Requirements:
-- Operating System: Linux
-- Python >= 3.8.1
-- PyTorch >= 2.0
-- Shared filesystem & SSH between hosts
+Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0
+
+Shared filesystem & SSH access if using multiple machines
+
+## Why should I use this?
+
+[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
+
+Whether you have 1 GPU, 8 GPUs, or 8 machines:
+
+Convenience:
 
-## Features
+- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
+- If you want to run `python myscript.py` instead of `torchrun myscript.py`
+- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures)
 
-- Distribute PyTorch functions to multiple GPUs or machines
-- `torchrun` with the convenience of a Python function
-- Integration with SLURM
+Robustness:
 
-Advantages:
+- If you want to run a complex, _modular_ workflow in one script
+  - no worries about memory leaks or OS failures
+  - don't parallelize your entire script: just the functions you want
 
-- Self-cleaning: avoid memory leaks!
-- Better for complex workflows
-- Doesn't parallelize the whole script: just what you want
-- Run distributed functions from Python Notebooks
+Features:
+
+- Our launch utility is super _Pythonic_
+- Run distributed PyTorch functions from Python Notebooks
+- Automatic integration with SLURM
+
+Why not?
+
+- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.
 
 ## Usage
 
@@ -99,4 +112,4 @@ accuracy = trx.launch(
 )["localhost"][0]
 
 print(f'Accuracy: {accuracy}')
-```
\ No newline at end of file
+```

From f9132214a0ad19141f52bf098e3a14821e2b509f Mon Sep 17 00:00:00 2001
From: Apoorv Khandelwal
Date: Sat, 28 Sep 2024 15:57:47 -0400
Subject: [PATCH 06/10] Update README.md

---
 README.md | 52 +++++++++++++++++++++++++++-------------------------
 1 file changed, 27 insertions(+), 25 deletions(-)

diff --git a/README.md b/README.md
index 09ef5171..51f1f772 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,32 @@ Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0
 
 Shared filesystem & SSH access if using multiple machines
 
+## Minimal example
+
+Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
+
+```python
+def train_model(model, dataset):
+    trained_model = train(model, dataset)
+
+    if int(os.environ["RANK"]) == 0:
+        torch.save(trained_model, 'model.pt')
+        return 'model.pt'
+
+    return None
+```
+
+```python
+import torchrunx as trx
+
+model_path = trx.launch(
+    func=train_model,
+    func_kwargs={'model': my_model, 'dataset': mnist_train},
+    hostnames=["localhost", "other_node"],
+    workers_per_host=2
+)["localhost"][0] # return from rank 0 (first worker on "localhost")
+```
+
 ## Why should I use this?
 
 [`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
@@ -48,31 +74,7 @@ Why not?
 
 - We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.
 
-## Usage
-
-Here's a simple example where we distribute `train_model` to two hosts (with 2 GPUs each):
-
-```python
-def train_model(model, dataset):
-    trained_model = train(model, dataset)
-
-    if int(os.environ["RANK"]) == 0:
-        torch.save(trained_model, 'model.pt')
-        return 'model.pt'
-
-    return None
-```
-
-```python
-import torchrunx as trx
-
-model_path = trx.launch(
-    func=train_model,
-    func_kwargs={'model': my_model, 'dataset': mnist_train},
-    hostnames=["localhost", "other_node"],
-    workers_per_host=2
-)["localhost"][0] # return from rank 0 (first worker on "localhost")
-```
+## More complicated example
 
 We could also launch multiple functions, using different GPUs:
 

From 79fb0a89603354476744a58d4130ed2fc69651f1 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Sat, 28 Sep 2024 18:05:36 -0400
Subject: [PATCH 07/10] LaunchResult, first draft

---
 src/torchrunx/launcher.py | 41 ++++++++++++++++++++++++++-------------
 1 file changed, 28 insertions(+), 13 deletions(-)

diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index aaf9fd5f..aa102dd3 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -14,19 +14,14 @@
 from logging import Handler
 from multiprocessing import Process
 from pathlib import Path
-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, overload
 
 import fabric
 import torch.distributed as dist
 
 from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
 from .logging_utils import LogRecordSocketReceiver, default_handlers
-from .utils import (
-    LauncherAgentGroup,
-    LauncherPayload,
-    WorkerException,
-    get_open_port,
-)
+from .utils import AgentStatus, LauncherAgentGroup, LauncherPayload, WorkerException, get_open_port
 
 
 def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
@@ -175,7 +170,7 @@ def run(  # noqa: C901, PLR0912
         func: Callable,
         func_args: tuple[Any] | None = None,
         func_kwargs: dict[str, Any] | None = None,
-    ) -> dict[str, dict[int, Any]]:
+    ) -> LaunchResult:
         """
         Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
 
@@ -309,10 +304,7 @@ def run(  # noqa: C901, PLR0912
             ssh_config_file=self.ssh_config_file,
         )
 
-        return {
-            hostname: agent_status.return_values
-            for hostname, agent_status in zip(hostnames, agent_statuses)
-        }
+        return LaunchResult(hostnames=hostnames, agent_statuses=agent_statuses)
 
 
 def launch(
@@ -336,7 +328,7 @@ def launch(
     ),
     env_file: str | os.PathLike | None = None,
     timeout: int = 600,
-) -> dict[str, dict[int, Any]]:
+) -> LaunchResult:
     """
     Launch a distributed PyTorch function on the specified nodes.
 
@@ -378,3 +370,26 @@ def launch(
         env_file=env_file,
         timeout=timeout,
     ).run(func=func, func_args=func_args, func_kwargs=func_kwargs)
+
+
+class LaunchResult:
+    def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
+        self.results = {
+            hostname: agent_status.return_values
+            for hostname, agent_status in zip(hostnames, agent_statuses)
+        }
+
+    def all(self) -> dict[str, list[Any]]:
+        return self.results
+
+    # all(by='rank')
+
+    # value(rank: int)
+
+    @overload
+    def value(self, hostname: str) -> list[Any]:
+        return list(self.results[hostname].values())
+
+    @overload
+    def value(self, hostname: str, rank: int) -> Any:
+        return self.results[hostname][rank]

From 4b227de51a43c3f6fc383443852b7af745ec5b6e Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Sun, 29 Sep 2024 23:51:09 -0400
Subject: [PATCH 08/10] updates to LaunchResult getters

---
 src/torchrunx/launcher.py | 53 ++++++++++++++++++++++++++++-----------
 src/torchrunx/utils.py    | 22 ++++++++--------
 2 files changed, 49 insertions(+), 26 deletions(-)

diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index aa102dd3..cd4a6098 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -10,9 +10,10 @@
 import subprocess
 import sys
 from dataclasses import dataclass
-from functools import partial
+from functools import partial, reduce
 from logging import Handler
 from multiprocessing import Process
+from operator import add
 from pathlib import Path
 from typing import Any, Callable, Literal, overload
 
@@ -279,7 +280,7 @@ def run(  # noqa: C901, PLR0912
 
         # raises specific exception if any agent fails
         for s in agent_statuses:
-            for value in s.return_values.values():
+            for value in s.return_values:
                 if isinstance(value, WorkerException):
                     raise value.exception
 
@@ -374,22 +375,44 @@ def launch(
 
 class LaunchResult:
     def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
-        self.results = {
-            hostname: agent_status.return_values
-            for hostname, agent_status in zip(hostnames, agent_statuses)
-        }
+        self.hostnames: list[str] = hostnames
+        self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses]
 
+    @overload
     def all(self) -> dict[str, list[Any]]:
-        return self.results
+        pass
 
-    # all(by='rank')
+    @overload
+    def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]:
+        pass
 
-    # value(rank: int)
+    @overload
+    def all(self, by: Literal["rank"]) -> list[Any]:
+        pass
 
-    @overload
-    def value(self, hostname: str) -> list[Any]:
-        return list(self.results[hostname].values())
+    def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
+        if by == "hostname":
+            return dict(zip(self.hostnames, self.return_values))
+        elif by == "rank":  # noqa: RET505
+            return reduce(add, self.return_values)
 
-    @overload
-    def value(self, hostname: str, rank: int) -> Any:
-        return self.results[hostname][rank]
+        msg = "Invalid argument: expected by=('hostname' | 'rank')"
+        raise TypeError(msg)
+
+    def values(self, hostname: str) -> list[Any]:
+        host_idx = self.hostnames.index(hostname)
+        return self.return_values[host_idx]
+
+    def value(self, rank: int) -> Any:
+        if rank < 0:
+            msg = f"Rank {rank} must be at least 0"
+            raise ValueError(msg)
+
+        for values in self.return_values:
+            if rank >= len(values):
+                rank -= len(values)
+            else:
+                return values[rank]
+
+        msg = f"Rank {rank} exceeds the world size"
+        raise ValueError(msg)
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py
index d4cd5684..3770e93d 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils.py
@@ -20,11 +20,6 @@ def get_open_port() -> int:
         return s.getsockname()[1]
 
 
-@dataclass
-class WorkerException:
-    exception: Exception
-
-
 @dataclass
 class LauncherAgentGroup:
     launcher_hostname: str
@@ -92,24 +87,29 @@ class AgentPayload:
     hostname: str
     port: int
     process_id: int
 
 
+@dataclass
+class WorkerException:
+    exception: Exception
+
+
 @dataclass
 class AgentStatus:
     state: Literal["running", "failed", "done"]
-    return_values: dict[int, Any | WorkerException] = field(default_factory=dict)
+    return_values: list[Any | WorkerException] = field(
+        default_factory=list
+    )  # indexed by local rank
 
     @classmethod
     def from_result(cls, result: RunProcsResult | None) -> Self:
         if result is None:
             return cls(state="running")
 
-        return_values = result.return_values
+        return_values = list(result.return_values.values())
 
-        if any(isinstance(v, WorkerException) for v in return_values.values()):
-            state = "failed"
-        else:
-            state = "done"
+        failed = any(isinstance(v, WorkerException) for v in return_values)
+        state = "failed" if failed else "done"
 
         return cls(
             state=state,

From ca478358a67d1dc5908e74d79474714c5555d31f Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Mon, 30 Sep 2024 00:09:08 -0400
Subject: [PATCH 09/10] update tests for LaunchResults

---
 tests/test_ci.py   | 3 +--
 tests/test_func.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/tests/test_ci.py b/tests/test_ci.py
index f72f3ef4..64cd1e93 100644
--- a/tests/test_ci.py
+++ b/tests/test_ci.py
@@ -37,8 +37,7 @@ def dist_func() -> torch.Tensor:
         backend="gloo",  # log_dir="./test_logs"
     )
 
-    results = next(iter(r.values()))
-    assert torch.all(results[0] == results[1])
+    assert torch.all(r.value(0) == r.value(1))
 
 
 def test_logging() -> None:
diff --git a/tests/test_func.py b/tests/test_func.py
index 8fb264bf..9ce99b9b 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -13,7 +13,7 @@ def test_launch() -> None:
         workers_per_host="slurm",
     )
 
-    result_values = [v for host_results in result.values() for v in host_results.values()]
+    result_values = result.all(by='rank')
 
     t = True
     for i in range(len(result_values)):

From 7850aeb193ddc7bd7bf870548e4958d1711494cc Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Mon, 30 Sep 2024 00:14:30 -0400
Subject: [PATCH 10/10] formatting

---
 tests/test_func.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_func.py b/tests/test_func.py
index 9ce99b9b..e8033b4e 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -13,7 +13,7 @@ def test_launch() -> None:
         workers_per_host="slurm",
    )
 
-    result_values = result.all(by='rank')
+    result_values = result.all(by="rank")
 
     t = True
     for i in range(len(result_values)):