diff --git a/README.md b/README.md index 51f1f772..e3271e91 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@ # torchrunx 🔥 [![PyPI - Python Version](https://img.shields.io/pypi/pyversions/torchrunx)](https://github.com/apoorvkh/torchrunx/blob/main/pyproject.toml) +[![PyTorch Version](https://img.shields.io/badge/torch-%3E%3D2.0-orange)](https://github.com/pytorch/pytorch) [![PyPI - Version](https://img.shields.io/pypi/v/torchrunx)](https://pypi.org/project/torchrunx/) ![Tests](https://img.shields.io/github/actions/workflow/status/apoorvkh/torchrunx/.github%2Fworkflows%2Fmain.yml) [![Docs](https://readthedocs.org/projects/torchrunx/badge/?version=stable)](https://torchrunx.readthedocs.io) @@ -16,102 +17,78 @@ By [Apoorv Khandelwal](http://apoorvkh.com) and [Peter Curtin](https://github.co pip install torchrunx ``` -Requires: Linux, Python >= 3.8.1, PyTorch >= 2.0 +**Requires:** Linux (with shared filesystem & SSH access if using multiple machines) -Shared filesystem & SSH access if using multiple machines +## Demo -## Minimal example +Here's a simple example where we "train" a model on two nodes (with 2 GPUs each). -Here's a simple example where we distribute `distributed_function` to two hosts (with 2 GPUs each): +
+ Training code -```python -def train_model(model, dataset): - trained_model = train(model, dataset) - - if int(os.environ["RANK"]) == 0: - torch.save(learned_model, 'model.pt') - return 'model.pt' - - return None -``` - -```python -import torchrunx as trx - -model_path = trx.launch( - func=train_model, - func_kwargs={'model': my_model, 'training_dataset': mnist_train}, - hostnames=["localhost", "other_node"], - workers_per_host=2 -)["localhost"][0] # return from rank 0 (first worker on "localhost") -``` - -## Why should I use this? - -[`torchrun`](https://pytorch.org/docs/stable/elastic/run.html) is a hammer. `torchrunx` is a chisel. - -Whether you have 1 GPU, 8 GPUs, or 8 machines: - -Convenience: + ```python + import os + import torch -- If you don't want to set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself -- If you want to run `python myscript.py` instead of `torchrun myscript.py` -- If you don't want to manually SSH and run `torchrun --master-ip --master-port ...` on every machine (and if you don't want to babysit these machines for hanging failures) + def train(): + rank = int(os.environ['RANK']) + local_rank = int(os.environ['LOCAL_RANK']) -Robustness: + model = torch.nn.Linear(10, 10).to(local_rank) + ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) + optimizer = torch.optim.AdamW(ddp_model.parameters()) -- If you want to run a complex, _modular_ workflow in one script - - no worries about memory leaks or OS failures - - don't parallelize your entire script: just the functions you want - -Features: + optimizer.zero_grad() + outputs = ddp_model(torch.randn(5, 10)) + labels = torch.randn(5, 10).to(local_rank) + torch.nn.functional.mse_loss(outputs, labels).backward() + optimizer.step() -- Our launch utility is super _Pythonic_ -- If you want to run distributed PyTorch functions from Python Notebooks. -- Automatic integration with SLURM + if rank == 0: + return model + ``` -Why not? + You could also use `transformers.Trainer` (or similar) to automatically handle all the multi-GPU / DDP code above. +
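+
+ Alternative: training code with transformers.Trainer (sketch)
+
+  As noted above, `transformers.Trainer` (or a similar high-level trainer) can handle the DDP boilerplate for you. The snippet below is only a hedged sketch: the `transformers` dependency, the model name, and the `dataset` argument are illustrative assumptions, not part of this library.
+
+  ```python
+  from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments
+
+  def train(dataset):
+      # torchrunx sets RANK / LOCAL_RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT for each
+      # worker, which transformers uses to configure distributed (DDP) training automatically.
+      model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
+
+      trainer = Trainer(
+          model=model,
+          args=TrainingArguments(output_dir="output", per_device_train_batch_size=8),
+          train_dataset=dataset,  # hypothetical dataset, e.g. passed in via `func_kwargs`
+      )
+      trainer.train()
+      return trainer.model
+  ```
+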
-- We don't support fault tolerance via torch elastic. Probably only useful if you are using 1000 GPUs. Maybe someone can make a PR. -## More complicated example +```python +import torchrunx as trx -We could also launch multiple functions, with different GPUs: +if __name__ == "__main__": + trained_model = trx.launch( + func=train, + hostnames=["localhost", "other_node"], + workers_per_host=2 # num. GPUs + ).value(rank=0) # get returned object -```python -def train_model(model, dataset): - trained_model = train(model, dataset) + torch.save(trained_model.state_dict(), "model.pth") +``` - if int(os.environ["RANK"]) == 0: - torch.save(learned_model, 'model.pt') - return 'model.pt' +### [Full API](https://torchrunx.readthedocs.io/stable/api.html) +### [Advanced Usage](https://torchrunx.readthedocs.io/stable/advanced.html) - return None +## Why should I use this? -def test_model(model_path, test_dataset): - model = torch.load(model_path) - accuracy = inference(model, test_dataset) - return accuracy -``` +Whether you have 1 GPU, 8 GPUs, or 8 machines. -```python -import torchrunx as trx +__Features:__ -model_path = trx.launch( - func=train_model, - func_kwargs={'model': my_model, 'training_dataset': mnist_train}, - hostnames=["localhost", "other_node"], - workers_per_host=2 -)["localhost"][0] # return from rank 0 (first worker on "localhost") +- Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_ + - Return objects from your workers + - Run `python script.py` instead of `torchrun script.py` + - Launch multi-node functions, even from Python Notebooks +- Fine-grained control over logging, environment variables, exception handling, etc. +- Automatic integration with SLURM +__Robustness:__ +- If you want to run a complex, _modular_ workflow in __one__ script + - don't parallelize your entire script: just the functions you want! + - no worries about memory leaks or OS failures -accuracy = trx.launch( - func=test_model, - func_kwargs={'model': learned_model, 'test_dataset': mnist_test}, - hostnames=["localhost"], - workers_per_host=1 -)["localhost"][0] +__Convenience:__ -print(f'Accuracy: {accuracy}') -``` +- If you don't want to: + - set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself + - manually SSH into every machine and `torchrun --master-ip --master-port ...`, babysit failed processes, etc. diff --git a/docs/requirements.txt b/docs/requirements.txt index 30373d03..06ac352c 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,4 +1,4 @@ sphinx==6.2.1 furo myst-parser -sphinx-toolbox \ No newline at end of file +sphinx-toolbox diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 59bdc333..21efff64 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -1,6 +1,33 @@ Advanced Usage ============== +Multiple functions in one script +-------------------------------- + +We could also launch multiple functions (e.g. train on many GPUs, test on one GPU): + +.. 
code-block:: python
+
+    import torchrunx as trx
+
+    trained_model = trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8
+    ).value(rank=0)
+
+    accuracy = trx.launch(
+        func=test,
+        func_kwargs={'model': trained_model},
+        hostnames=["localhost"],
+        workers_per_host=1
+    ).value(rank=0)
+
+    print(f'Accuracy: {accuracy}')
+
+``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation.
+
+
 Environment Detection
 ---------------------
 
@@ -61,18 +88,9 @@ For example, the `python ... --help` command will then result in:
 Custom Logging
 --------------
 
-Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a :mod:`torchrunx.DefaultLogSpec` is instantiated, causing logs at the worker and agent levels to be logged to files under ``'./logs'``, and the rank 0 worker's output streams are streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
-
-Custom logging classes can be subclassed from the :mod:`torchrunx.LogSpec` class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The :mod:`torchrunx.DefaultLogSpec` maps all the loggers to :mod:`logging.Filehandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs the launcher's ``stdout`` stream.
-
-.. autoclass:: torchrunx.LogSpec
-   :members:
-
-.. autoclass:: torchrunx.DefaultLogSpec
-   :members:
+Logs are generated at the worker and agent levels and are configured via the ``log_handlers`` argument of :mod:`torchrunx.launch`. By default (``log_handlers="auto"``), worker and agent logs are written to files under ``'./logs'``, and the rank 0 worker's output streams are also streamed to the launcher's ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and worker logs have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
 
-..
-   TODO: example log structure
+For custom logging, pass your own list of :mod:`logging.Handler` objects via ``log_handlers``. Each handler determines where (and at what level) the agent and worker logs are written. ``torchrunx`` also exports ``stream_handler``, ``file_handler``, and ``add_filter_to_handler`` helpers for building such handlers.
 
 Propagating Exceptions
 ----------------------
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 5726f26b..608f5b34 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -1,8 +1,7 @@
 API
 =============
 
-..
-   TODO: examples, environmental variables available to workers (e.g. RANK, LOCAL_RANK)
+.. autofunction:: torchrunx.launch(func: Callable, ...)
 
-.. 
automodule:: torchrunx - :members: launch, slurm_hosts, slurm_workers \ No newline at end of file +.. autoclass:: torchrunx.LaunchResult + :members: diff --git a/docs/source/conf.py b/docs/source/conf.py index 2edb7aee..6ea3a8b2 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -20,8 +20,13 @@ 'myst_parser', 'sphinx_toolbox.sidebar_links', 'sphinx_toolbox.github', + 'sphinx.ext.autodoc.typehints', + #"sphinx_autodoc_typehints", ] +autodoc_typehints = "both" +#typehints_defaults = 'comma' + github_username = 'apoorvkh' github_repository = 'torchrunx' @@ -43,4 +48,4 @@ epub_show_urls = 'footnote' # code block syntax highlighting -#pygments_style = 'sphinx' \ No newline at end of file +#pygments_style = 'sphinx' diff --git a/docs/source/index.rst b/docs/source/index.rst index 19063776..55900595 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,14 +1,9 @@ -Getting Started -=============== - .. include:: ../../README.md :parser: myst_parser.sphinx_ -Contents --------- - .. toctree:: - :maxdepth: 2 + :hidden: + :maxdepth: 1 api advanced @@ -17,4 +12,4 @@ Contents .. sidebar-links:: :github: - :pypi: torchrunx \ No newline at end of file + :pypi: torchrunx diff --git a/pixi.lock b/pixi.lock index e67beb9f..e4fae5ca 100644 --- a/pixi.lock +++ b/pixi.lock @@ -2601,9 +2601,9 @@ packages: requires_python: '>=3.8.0' - kind: pypi name: torchrunx - version: 0.1.4 + version: 0.2.0 path: . - sha256: de986bf47e1c379e4de6b10ca352715d708bb5f9b4cfc8736e9ee592db5fe1ae + sha256: 1753f43bee54bc0da38cdd524dc501c0c2be9fbaaa7036bced9c9d03a7a8e810 requires_dist: - cloudpickle>=3.0.0 - fabric>=3.0.0 diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index 74214cb8..c5a7d6fd 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,9 +1,10 @@ -from .launcher import Launcher, launch +from .launcher import Launcher, LaunchResult, launch from .logging_utils import add_filter_to_handler, file_handler, stream_handler __all__ = [ "Launcher", "launch", + "LaunchResult", "add_filter_to_handler", "file_handler", "stream_handler", diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 789af155..1860e444 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -73,7 +73,7 @@ def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExce os.environ["WORLD_SIZE"] = str(worker_args.world_size) os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) - + if worker_args.backend is not None: backend = worker_args.backend if backend == "auto": diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index cd4a6098..4203b4f3 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -82,7 +82,7 @@ def build_launch_command( logger_port: int, world_size: int, rank: int, - env_vars: list[str] | tuple[str], + env_vars: tuple[str, ...], env_file: str | os.PathLike | None, ) -> str: # shlex.quote prevents shell injection here (resolves S602 in execute_command) @@ -157,8 +157,9 @@ class Launcher: workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto" ssh_config_file: str | os.PathLike | None = None backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto" + timeout: int = 600 log_handlers: list[Handler] | Literal["auto"] | None = "auto" - env_vars: tuple[str] = ( # pyright: ignore [reportAssignmentType] + default_env_vars: tuple[str, ...] 
= ( "PATH", "LD_LIBRARY", "LIBRARY_PATH", @@ -168,8 +169,8 @@ class Launcher: "PYTORCH*", "NCCL*", ) + extra_env_vars: tuple[str, ...] = () env_file: str | os.PathLike | None = None - timeout: int = 600 def run( # noqa: C901, PLR0912 self, @@ -177,19 +178,6 @@ def run( # noqa: C901, PLR0912 func_args: tuple[Any] | None = None, func_kwargs: dict[str, Any] | None = None, ) -> LaunchResult: - """ - Launch a distributed PyTorch function on the specified nodes. See :mod:`torchrunx.launch` - - :param func: The distributed function to call on all workers - :type func: Callable - :param func_args: Any positional arguments to be provided when calling ``func`` - :type func_args: tuple[Any] - :param func_kwargs: Any keyword arguments to be provided when calling ``func`` - :type func_kwargs: dict[str, Any] - :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func`` - :return: A dictionary mapping worker ranks to their output - :rtype: dict[int, Any] - """ if not dist.is_available(): msg = "The torch.distributed package is not available." raise RuntimeError(msg) @@ -235,7 +223,7 @@ def run( # noqa: C901, PLR0912 logger_port=log_receiver.port, world_size=world_size, rank=i + 1, - env_vars=self.env_vars, + env_vars=(self.default_env_vars + self.extra_env_vars), env_file=self.env_file, ), hostname=hostname, @@ -316,8 +304,9 @@ def launch( workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto", ssh_config_file: str | os.PathLike | None = None, backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto", + timeout: int = 600, log_handlers: list[Handler] | Literal["auto"] | None = "auto", - env_vars: tuple[str] = ( # pyright: ignore [reportArgumentType] + default_env_vars: tuple[str, ...] = ( "PATH", "LD_LIBRARY", "LIBRARY_PATH", @@ -327,49 +316,37 @@ def launch( "PYTORCH*", "NCCL*", ), + extra_env_vars: tuple[str, ...] = (), env_file: str | os.PathLike | None = None, - timeout: int = 600, ) -> LaunchResult: """ Launch a distributed PyTorch function on the specified nodes. - :param func: The distributed function to call on all workers - :type func: Callable - :param func_args: Any positional arguments to be provided when calling ``func`` - :type func_args: tuple[Any] - :param func_kwargs: Any keyword arguments to be provided when calling ``func`` - :type func_kwargs: dict[str, Any] - :param auto: Automatically determine allocation sizes, supports Slurm allocation. ``hostnames`` and ``workers_per_host`` are automatically assigned if they're set to ``None``, defaults to None - :type auto: bool, optional - :param hostnames: A list of node hostnames to start workers on, defaults to ["localhost"] - :type hostnames: list[str] | Literal["auto", "slurm"] | None, optional - :param workers_per_host: The number of workers per node. 
Providing an ``int`` implies all nodes should have ``workers_per_host`` workers, meanwhile providing a list causes node ``i`` to have ``worker_per_host[i]`` workers, defaults to 1
-    :type workers_per_host: int | list[int] | Literal["auto", "slurm"] | None, optional
-    :param ssh_config_file: An SSH configuration file to use when connecting to nodes, defaults to None
-    :type ssh_config_file: str | os.PathLike | None, optional
-    :param backend: A ``torch.distributed`` `backend string `_, defaults to None
-    :type backend: Literal['mpi', 'gloo', 'nccl', 'ucc', None], optional
-    :param log_handlers: A list of handlers to manage agent and worker logs, defaults to []
-    :type log_handlers: list[Handler] | Literal["auto"], optional
-    :param env_vars: A list of environmental variables to be copied from the launcher environment to workers. Allows for bash pattern matching syntax, defaults to ["PATH", "LD_LIBRARY", "LIBRARY_PATH", "PYTHON*", "CUDA*", "TORCH*", "PYTORCH*", "NCCL*"]
-    :type env_vars: list[str], optional
-    :param env_file: An additional environment file that will be sourced prior to executing ``func``, defaults to None
-    :type env_file: str | os.PathLike | None, optional
-    :param timeout: Worker process group timeout, defaults to 600
-    :type timeout: int, optional
-    :raises RuntimeError: May fail due to misconfiguration, or errors thrown by ``func``
-    :return: A dictionary mapping worker ranks to their output
-    :rtype: dict[int, Any]
+    :param func: The distributed function to run on each worker.
+    :param func_args: Positional arguments to pass when calling ``func``.
+    :param func_kwargs: Keyword arguments to pass when calling ``func``.
+    :param hostnames: Nodes to launch the function on. Default infers from a SLURM environment or runs on localhost.
+    :param workers_per_host: Number of processes to run per node. Can be specified per node by passing a ``list[int]``.
+    :param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
+    :param backend: `Backend `_ to initialize worker process group with. Default uses NCCL (if GPUs available) or GLOO. Disabled by ``None``.
+    :param timeout: Worker process group timeout (seconds).
+    :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme.
+    :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
+    :param extra_env_vars: Additional, user-specified variables to copy.
+    :param env_file: A file (like ``.env``) with additional environment variables to copy.
+    :raises RuntimeError: If ``torch.distributed`` is unavailable or communication between nodes times out
+    :raises Exception: Propagates exceptions raised in worker processes
     """  # noqa: E501
     return Launcher(
         hostnames=hostnames,
         workers_per_host=workers_per_host,
         ssh_config_file=ssh_config_file,
         backend=backend,
+        timeout=timeout,
         log_handlers=log_handlers,
-        env_vars=env_vars,
+        default_env_vars=default_env_vars,
+        extra_env_vars=extra_env_vars,
         env_file=env_file,
-        timeout=timeout,
     ).run(func=func, func_args=func_args, func_kwargs=func_kwargs)
 
 
@@ -391,6 +368,12 @@ def all(self, by: Literal["rank"]) -> list[Any]:
         pass
 
     def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
+        """
+        Get all worker return values by rank or hostname. 
+ + :param by: Whether to aggregate all return values by hostname, or just output all of them \ + in order of rank, defaults to ``'hostname'`` + """ if by == "hostname": return dict(zip(self.hostnames, self.return_values)) elif by == "rank": # noqa: RET505 @@ -400,10 +383,20 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An raise TypeError(msg) def values(self, hostname: str) -> list[Any]: + """ + Get worker return values for host ``hostname``. + + :param hostname: The host to get return values from + """ host_idx = self.hostnames.index(hostname) return self.return_values[host_idx] def value(self, rank: int) -> Any: + """ + Get worker return value from global rank ``rank``. + + :param rank: Global worker rank to get return value from + """ if rank < 0: msg = f"Rank {rank} must be larger than 0" raise ValueError(msg)
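For reference, here is a minimal end-to-end sketch of how the revised `launch()` / `LaunchResult` API above fits together. The worker function, hostnames, and the `WANDB_*` pattern are illustrative assumptions, not part of this patch:

```python
import os

import torchrunx as trx


def worker():
    # torchrunx initializes the process group and sets RANK / LOCAL_RANK /
    # WORLD_SIZE / MASTER_ADDR / MASTER_PORT in each worker's environment.
    return f"hello from rank {os.environ['RANK']}"


if __name__ == "__main__":
    result = trx.launch(
        func=worker,
        hostnames=["localhost", "second_node"],  # assumed hostnames
        workers_per_host=2,
        timeout=600,                  # worker process group timeout (seconds)
        extra_env_vars=("WANDB_*",),  # extra launcher env vars copied to workers (pattern syntax)
    )

    # LaunchResult accessors:
    rank0_output = result.value(rank=0)        # return value of the global rank 0 worker
    host_outputs = result.values("localhost")  # return values of one host's workers
    all_outputs = result.all(by="rank")        # flat list of return values, ordered by rank
```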