diff --git a/README.md b/README.md
index 51f1f772..e3271e91 100644
--- a/README.md
+++ b/README.md
@@ -1,6 +1,7 @@
# torchrunx 🔥
[](https://github.com/apoorvkh/torchrunx/blob/main/pyproject.toml)
+[](https://github.com/pytorch/pytorch)
[](https://pypi.org/project/torchrunx/)

[](https://torchrunx.readthedocs.io)
@@ -16,102 +17,78 @@ By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.co
pip install torchrunx
```
-Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0
+**Requires:** Linux (with shared filesystem & SSH access if using multiple machines)
-Shared filesystem & SSH access if using multiple machines
+## Demo
-## Minimal example
+Here's a simple example where we "train" a model on two nodes (with 2 GPUs each).
-Here's a simple example where we distribute `distributed_function` to two hosts (with 2 GPUs each):
+<details>
+  <summary>Training code</summary>
-```python
-def train_model(model, dataset):
- trained_model = train(model, dataset)
-
- if int(os.environ["RANK"]) == 0:
- torch.save(learned_model, 'model.pt')
- return 'model.pt'
-
- return None
-```
-
-```python
-import torchrunx as trx
-
-model_path = trx.launch(
- func=train_model,
- func_kwargs={'model': my_model, 'training_dataset': mnist_train},
- hostnames=["localhost", "other_node"],
- workers_per_host=2
-)["localhost"][0] # return from rank 0 (first worker on "localhost")
-```
-
-## Why should I use this?
-
-[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel.
-
-Whether you have 1 GPU, 8 GPUs, or 8 machines:
-
-Convenience:
+  ```python
+  import os
+  import torch
-- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
-- If you want to run `python myscript.py` instead of `torchrun myscript.py`
-- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures)
+  def train():
+      rank = int(os.environ['RANK'])
+      local_rank = int(os.environ['LOCAL_RANK'])
-Robustness:
+      model = torch.nn.Linear(10, 10).to(local_rank)
+      ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
+      optimizer = torch.optim.AdamW(ddp_model.parameters())
-- If you want to run a complex, _modular_ workflow in one script
- - no worries about memory leaks or OS failures
- - don't parallelize your entire script: just the functions you want
-
-Features:
+      optimizer.zero_grad()
+      outputs = ddp_model(torch.randn(5, 10))
+      labels = torch.randn(5, 10).to(local_rank)
+      torch.nn.functional.mse_loss(outputs, labels).backward()
+      optimizer.step()
-- Our launch utility is super _Pythonic_
-- If you want to run distributed PyTorch functions from Python Notebooks.
-- Automatic integration with SLURM
+      if rank == 0:
+          return model
+  ```
-Why not?
+ You could also use `transformers.Trainer` (or similar) to automatically handle all the multi-GPU / DDP code above.
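+
+  For instance, a minimal sketch of that approach (assumes `transformers` is installed; `dataset` is a hypothetical placeholder for your prepared dataset):
+
+  ```python
+  import os
+  from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
+
+  def train():
+      # Trainer detects the distributed environment (RANK, WORLD_SIZE, ...) and applies DDP itself
+      model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+      trainer = Trainer(
+          model=model,
+          args=TrainingArguments(output_dir="output", max_steps=10),
+          train_dataset=dataset,  # hypothetical: your tokenized dataset
+      )
+      trainer.train()
+      if int(os.environ["RANK"]) == 0:
+          return trainer.model
+  ```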
+</details>
+
-- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR.
-## More complicated example
+```python
+import torchrunx as trx
-We could also launch multiple functions, with different GPUs:
+if __name__ == "__main__":
+    trained_model = trx.launch(
+        func=train,
+        hostnames=["localhost", "other_node"],
+        workers_per_host=2  # num. GPUs
+    ).value(rank=0)  # get returned object
-```python
-def train_model(model, dataset):
- trained_model = train(model, dataset)
+
+    torch.save(trained_model.state_dict(), "model.pth")
+```
- if int(os.environ["RANK"]) == 0:
- torch.save(learned_model, 'model.pt')
- return 'model.pt'
+### [Full API](https://torchrunx.readthedocs.io/stable/api.html)
+### [Advanced Usage](https://torchrunx.readthedocs.io/stable/advanced.html)
- return None
+## Why should I use this?
-def test_model(model_path, test_dataset):
- model = torch.load(model_path)
- accuracy = inference(model, test_dataset)
- return accuracy
-```
+Whether you have 1 GPU, 8 GPUs, or 8 machines:
-```python
-import torchrunx as trx
+
+__Features:__
-model_path = trx.launch(
- func=train_model,
- func_kwargs={'model': my_model, 'training_dataset': mnist_train},
- hostnames=["localhost", "other_node"],
- workers_per_host=2
-)["localhost"][0] # return from rank 0 (first worker on "localhost")
+- Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_
+  - Return objects from your workers
+  - Run `python script.py` instead of `torchrun script.py`
+  - Launch multi-node functions, even from Python Notebooks
+- Fine-grained control over logging, environment variables, exception handling, etc. (see the sketch after this list)
+- Automatic integration with SLURM
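+
+For example, a sketch of that fine-grained control (the argument values here are illustrative, not defaults):
+
+```python
+import logging
+import torchrunx as trx
+
+trx.launch(
+    func=train,
+    hostnames=["localhost"],
+    workers_per_host=2,
+    timeout=300,                              # worker process group timeout (seconds)
+    extra_env_vars=("WANDB_*",),              # hypothetical: forward your own variables to workers
+    env_file=".env",                          # also source variables from a file
+    log_handlers=[logging.StreamHandler()],   # or "auto" (default) / None
+)
+```
+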
+__Robustness:__
+- If you want to run a complex, _modular_ workflow in __one__ script
+  - don't parallelize your entire script: just the functions you want!
+  - no worries about memory leaks or OS failures
-accuracy = trx.launch(
- func=test_model,
- func_kwargs={'model': learned_model, 'test_dataset': mnist_test},
- hostnames=["localhost"],
- workers_per_host=1
-)["localhost"][0]
+
+__Convenience:__
-print(f'Accuracy: {accuracy}')
-```
+- If you don't want to:
+  - set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
+  - manually SSH into every machine and run `torchrun --master-addr --master-port ...`, babysit failed processes, etc.
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 30373d03..06ac352c 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,4 +1,4 @@
sphinx==6.2.1
furo
myst-parser
-sphinx-toolbox
\ No newline at end of file
+sphinx-toolbox
diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst
index 59bdc333..21efff64 100644
--- a/docs/source/advanced.rst
+++ b/docs/source/advanced.rst
@@ -1,6 +1,33 @@
Advanced Usage
==============
+Multiple functions in one script
+--------------------------------
+
+We could also launch multiple functions (e.g. train on many GPUs, test on one GPU):
+
+.. code-block:: python
+
+    import torchrunx as trx
+
+    trained_model = trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8
+    ).value(rank=0)
+
+    accuracy = trx.launch(
+        func=test,
+        func_kwargs={'model': trained_model},
+        hostnames=["localhost"],
+        workers_per_host=1
+    ).value(rank=0)
+
+    print(f'Accuracy: {accuracy}')
+
+``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation.
+
+
Environment Detection
---------------------
@@ -61,18 +88,9 @@ For example, the `python ... --help` command will then result in:
Custom Logging
--------------
-Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a :mod:`torchrunx.DefaultLogSpec` is instantiated, causing logs at the worker and agent levels to be logged to files under ``'./logs'``, and the rank 0 worker's output streams are streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
-
-Custom logging classes can be subclassed from the :mod:`torchrunx.LogSpec` class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The :mod:`torchrunx.DefaultLogSpec` maps all the loggers to :mod:`logging.Filehandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs the launcher's ``stdout`` stream.
-
-.. autoclass:: torchrunx.LogSpec
- :members:
-
-.. autoclass:: torchrunx.DefaultLogSpec
- :members:
+Logs are generated at the worker and agent levels, and are managed via the ``log_handlers`` argument of :mod:`torchrunx.launch`. By default (``"auto"``), logs from all agents and workers are written to files under ``./logs``, and the rank 0 worker's output streams are also streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
-..
- TODO: example log structure
+For custom logging, pass a list of :mod:`logging.Handler` objects as ``log_handlers``. These handlers receive the log records of the agents and workers and may write them to any destination (e.g. a :mod:`logging.FileHandler` for files, or a :mod:`logging.StreamHandler` for the launcher's ``stdout`` stream). The ``torchrunx`` helpers :mod:`torchrunx.stream_handler`, :mod:`torchrunx.file_handler`, and :mod:`torchrunx.add_filter_to_handler` can be used to construct handlers scoped to a particular agent or worker.
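+
+For example, a minimal sketch (assuming that an unfiltered handler receives logs from every agent and worker):
+
+.. code-block:: python
+
+    import logging
+    import torchrunx as trx
+
+    trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8,
+        # write all agent and worker logs to a single shared file
+        log_handlers=[logging.FileHandler("torchrunx.log")],
+    )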
Propagating Exceptions
----------------------
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 5726f26b..608f5b34 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -1,8 +1,7 @@
API
=============
-..
- TODO: examples, environmental variables available to workers (e.g. RANK, LOCAL_RANK)
+.. autofunction:: torchrunx.launch(func: Callable, ...)
-.. automodule:: torchrunx
- :members: launch, slurm_hosts, slurm_workers
\ No newline at end of file
+.. autoclass:: torchrunx.LaunchResult
+ :members:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2edb7aee..6ea3a8b2 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -20,8 +20,13 @@
'myst_parser',
'sphinx_toolbox.sidebar_links',
'sphinx_toolbox.github',
+ 'sphinx.ext.autodoc.typehints',
+ #"sphinx_autodoc_typehints",
]
+autodoc_typehints = "both"
+#typehints_defaults = 'comma'
+
github_username = 'apoorvkh'
github_repository = 'torchrunx'
@@ -43,4 +48,4 @@
epub_show_urls = 'footnote'
# code block syntax highlighting
-#pygments_style = 'sphinx'
\ No newline at end of file
+#pygments_style = 'sphinx'
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 19063776..55900595 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,14 +1,9 @@
-Getting Started
-===============
-
.. include:: ../../README.md
:parser: myst_parser.sphinx_
-Contents
---------
-
.. toctree::
- :maxdepth: 2
+ :hidden:
+ :maxdepth: 1
api
advanced
@@ -17,4 +12,4 @@ Contents
.. sidebar-links::
:github:
- :pypi: torchrunx
\ No newline at end of file
+ :pypi: torchrunx
diff --git a/pixi.lock b/pixi.lock
index e67beb9f..e4fae5ca 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -2601,9 +2601,9 @@ packages:
requires_python: '>=3.8.0'
- kind: pypi
name: torchrunx
- version: 0.1.4
+ version: 0.2.0
path: .
- sha256: de986bf47e1c379e4de6b10ca352715d708bb5f9b4cfc8736e9ee592db5fe1ae
+ sha256: 1753f43bee54bc0da38cdd524dc501c0c2be9fbaaa7036bced9c9d03a7a8e810
requires_dist:
- cloudpickle>=3.0.0
- fabric>=3.0.0
diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index 74214cb8..c5a7d6fd 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,9 +1,10 @@
-from .launcher import Launcher, launch
+from .launcher import Launcher, LaunchResult, launch
from .logging_utils import add_filter_to_handler, file_handler, stream_handler
__all__ = [
"Launcher",
"launch",
+ "LaunchResult",
"add_filter_to_handler",
"file_handler",
"stream_handler",
diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 789af155..1860e444 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -73,7 +73,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce
os.environ["WORLD_SIZE"] = str(worker_args.world_size)
os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
-
+
if worker_args.backend is not None:
backend = worker_args.backend
if backend == "auto":
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index cd4a6098..4203b4f3 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -82,7 +82,7 @@ def build_launch_command(
logger_port: int,
world_size: int,
rank: int,
- env_vars: list[str] | tuple[str],
+ env_vars: tuple[str, ...],
env_file: str | os.PathLike | None,
) -> str:
# shlex.quote prevents shell injection here (resolves S602 in execute_command)
@@ -157,8 +157,9 @@ class Launcher:
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"
ssh_config_file: str | os.PathLike | None = None
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto"
+ timeout: int = 600
log_handlers: list[Handler] | Literal["auto"] | None = "auto"
- env_vars: tuple[str] = ( # pyright: ignore [reportAssignmentType]
+ default_env_vars: tuple[str, ...] = (
"PATH",
"LD_LIBRARY",
"LIBRARY_PATH",
@@ -168,8 +169,8 @@ class Launcher:
"PYTORCH*",
"NCCL*",
)
+ extra_env_vars: tuple[str, ...] = ()
env_file: str | os.PathLike | None = None
- timeout: int = 600
def run( # noqa: C901, PLR0912
self,
@@ -177,19 +178,6 @@ def run( # noqa: C901, PLR0912
func_args: tuple[Any] | None = None,
func_kwargs: dict[str, Any] | None = None,
) -> LaunchResult:
- """
- Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch`
-
- :param func: The distributed function to call on all workers
- :type func: Callable
- :param func_args: Any positional arguments to be provided when calling ``func``
- :type func_args: tuple[Any]
- :param func_kwargs: Any keyword arguments to be provided when calling ``func``
- :type func_kwargs: dict[str, Any]
- :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
- :return: A dictionary mapping worker ranks to their output
- :rtype: dict[int, Any]
- """
if not dist.is_available():
msg = "The torch.distributed package is not available."
raise RuntimeError(msg)
@@ -235,7 +223,7 @@ def run( # noqa: C901, PLR0912
logger_port=log_receiver.port,
world_size=world_size,
rank=i + 1,
- env_vars=self.env_vars,
+ env_vars=(self.default_env_vars + self.extra_env_vars),
env_file=self.env_file,
),
hostname=hostname,
@@ -316,8 +304,9 @@ def launch(
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto",
ssh_config_file: str | os.PathLike | None = None,
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto",
+ timeout: int = 600,
log_handlers: list[Handler] | Literal["auto"] | None = "auto",
- env_vars: tuple[str] = ( # pyright: ignore [reportArgumentType]
+ default_env_vars: tuple[str, ...] = (
"PATH",
"LD_LIBRARY",
"LIBRARY_PATH",
@@ -327,49 +316,37 @@ def launch(
"PYTORCH*",
"NCCL*",
),
+ extra_env_vars: tuple[str, ...] = (),
env_file: str | os.PathLike | None = None,
- timeout: int = 600,
) -> LaunchResult:
"""
Launch a distributed PyTorch function on the specified nodes.
- :param func: The distributed function to call on all workers
- :type func: Callable
- :param func_args: Any positional arguments to be provided when calling ``func``
- :type func_args: tuple[Any]
- :param func_kwargs: Any keyword arguments to be provided when calling ``func``
- :type func_kwargs: dict[str, Any]
- :param auto: Automatically determine allocation sizes, supports Slurm allocation. ``hostnames`` and ``workers_per_host`` are automatically assigned if they're set to ``None``, defaults to None
- :type auto: bool, optional
- :param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"]
- :type hostnames: list[str] | Literal["auto", "slurm"] | None, optional
- :param workers_per_host: The number of workers per node. Providing an ``int`` implies all nodes should have ``workers_per_host`` workers, meanwhile providing a list causes node ``i`` to have ``worker_per_host[i]`` workers, defaults to 1
- :type workers_per_host: int | list[int] | Literal["auto", "slurm"] | None, optional
- :param ssh_config_file: An SSH configuration file to use when connecting to nodes, defaults to None
- :type ssh_config_file: str | os.PathLike | None, optional
- :param backend: A ``torch.distributed`` `backend string `_, defaults to None
- :type backend: Literal['mpi', 'gloo', 'nccl', 'ucc', None], optional
- :param log_handlers: A list of handlers to manage agent and worker logs, defaults to []
- :type log_handlers: list[Handler] | Literal["auto"], optional
- :param env_vars: A list of environmental variables to be copied from the launcher environment to workers. Allows for bash pattern matching syntax, defaults to ["PATH", "LD_LIBRARY", "LIBRARY_PATH", "PYTHON*", "CUDA*", "TORCH*", "PYTORCH*", "NCCL*"]
- :type env_vars: list[str], optional
- :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
- :type env_file: str | os.PathLike | None, optional
- :param timeout: Worker process group timeout, defaults to 600
- :type timeout: int, optional
- :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
- :return: A dictionary mapping worker ranks to their output
- :rtype: dict[int, Any]
+    :param func: The distributed function to call on all workers
+    :param func_args: Any positional arguments to be provided when calling ``func``
+    :param func_kwargs: Any keyword arguments to be provided when calling ``func``
+ :param hostnames: Nodes to launch the function on. Default infers from a SLURM environment or runs on localhost.
+    :param workers_per_host: Number of processes to run per node. A ``list[int]`` specifies a different count for each node.
+ :param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
+    :param backend: `Backend <https://pytorch.org/docs/stable/distributed.html#torch.distributed.Backend>`_ to initialize worker process group with. Default uses NCCL (if GPUs are available) or GLOO. Disabled if ``None``.
+ :param timeout: Worker process group timeout (seconds).
+ :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme.
+    :param default_env_vars: A list of environment variables to be copied from the launcher process to workers. Supports bash pattern matching syntax.
+ :param extra_env_vars: Additional, user-specified variables to copy.
+ :param env_file: A file (like ``.env``) with additional environment variables to copy.
+    :raises RuntimeError: May fail if ``torch.distributed`` is unavailable or if communication between nodes times out
+ :raises Exception: Propagates exceptions raised in worker processes
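+
+    Usage (illustrative; assumes a ``train`` function as in the README demo)::
+
+        result = launch(func=train, hostnames=["localhost"], workers_per_host=2)
+        output = result.value(rank=0)  # object returned by the rank 0 worker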
""" # noqa: E501
return Launcher(
hostnames=hostnames,
workers_per_host=workers_per_host,
ssh_config_file=ssh_config_file,
backend=backend,
+ timeout=timeout,
log_handlers=log_handlers,
- env_vars=env_vars,
+ default_env_vars=default_env_vars,
+ extra_env_vars=extra_env_vars,
env_file=env_file,
- timeout=timeout,
).run(func=func, func_args=func_args, func_kwargs=func_kwargs)
@@ -391,6 +368,12 @@ def all(self, by: Literal["rank"]) -> list[Any]:
pass
def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
+ """
+ Get all worker return values by rank or hostname.
+
+ :param by: Whether to aggregate all return values by hostname, or just output all of them \
+ in order of rank, defaults to ``'hostname'``
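+
+        e.g. with ``hostnames=["host1", "host2"]`` and 2 workers per host (illustrative values)::
+
+            result.all()            # {"host1": [a, b], "host2": [c, d]}
+            result.all(by="rank")   # [a, b, c, d]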
+ """
if by == "hostname":
return dict(zip(self.hostnames, self.return_values))
elif by == "rank": # noqa: RET505
@@ -400,10 +383,20 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An
raise TypeError(msg)
def values(self, hostname: str) -> list[Any]:
+ """
+ Get worker return values for host ``hostname``.
+
+ :param hostname: The host to get return values from
+ """
host_idx = self.hostnames.index(hostname)
return self.return_values[host_idx]
def value(self, rank: int) -> Any:
+ """
+ Get worker return value from global rank ``rank``.
+
+ :param rank: Global worker rank to get return value from
+ """
if rank < 0:
msg = f"Rank {rank} must be larger than 0"
raise ValueError(msg)