diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index fa436244..24ff0658 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -1,3 +1,17 @@
# Contributing
-We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by Github Actions.
+We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository to activate the environment.
+
+We use `ruff check` for linting, `ruff format` for formatting, `pyright` for static type checking, and `pytest` for testing.
+
+We build wheels with `python -m build` and upload them to [PyPI](https://pypi.org/project/torchrunx) with [twine](https://twine.readthedocs.io). Our release pipeline is powered by GitHub Actions.
+
+## Pull Requests
+
+Make a pull request with your changes on GitHub and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in __torchrunx__.
+
+## Testing
+
+`tests/` contains `pytest`-style tests for validating that code changes do not break the core functionality of our library.
+
+At the moment, we run `pytest tests/test_ci.py` (i.e. simple single-node CPU-only tests) in our GitHub Actions CI pipeline (`.github/workflows/release.yml`). You can manually run our more involved tests (e.g. on GPUs, or on multiple machines via SLURM) on your own hardware.
diff --git a/README.md b/README.md
index e3271e91..74f3ce25 100644
--- a/README.md
+++ b/README.md
@@ -56,12 +56,13 @@ Here's a simple example where we "train" a model on two nodes (with 2 GPUs each)
import torchrunx as trx
if __name__ == "__main__":
- trained_model = trx.launch(
+ result = trx.launch(
func=train,
hostnames=["localhost", "other_node"],
- workers_per_host=2 # num. GPUs
- ).value(rank=0) # get returned object
+ workers_per_host=2 # number of GPUs
+ )
+ trained_model = result.rank(0)
torch.save(trained_model.state_dict(), "model.pth")
```
@@ -70,9 +71,9 @@ if __name__ == "__main__":
## Why should I use this?
-Whether you have 1 GPU, 8 GPUs, or 8 machines.
+Whether you have 1 GPU, 8 GPUs, or 8 machines:
-__Features:__
+__Features__
- Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_
- Return objects from your workers
@@ -81,13 +82,13 @@ __Features:__
- Fine-grained control over logging, environment variables, exception handling, etc.
- Automatic integration with SLURM
-__Robustness:__
+__Robustness__
- If you want to run a complex, _modular_ workflow in __one__ script
- don't parallelize your entire script: just the functions you want!
- no worries about memory leaks or OS failures
-__Convenience:__
+__Convenience__
- If you don't want to:
- set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself
diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst
index 21efff64..6cc61d1f 100644
--- a/docs/source/advanced.rst
+++ b/docs/source/advanced.rst
@@ -14,43 +14,32 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP
func=train,
hostnames=["node1", "node2"],
workers_per_host=8
- ).value(rank=0)
+ ).rank(0)
accuracy = trx.launch(
func=test,
- func_kwargs={'model': model},
+ func_args=(trained_model,),
hostnames=["localhost"],
workers_per_host=1
- ).value(rank=0)
+ ).rank(0)
print(f'Accuracy: {accuracy}')
-``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation.
+:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and all used memory is released) before any subsequent invocation.
-Environment Detection
----------------------
-
-By default, the `hostnames` or `workers_per_host` :mod:`torchrunx.launch` parameters are set to "auto". These parameters are populated via `SLURM`_ if a SLURM environment is automatically detected. Otherwise, `hostnames = ["localhost"]` and `workers_per_host` is set to the number of GPUs or CPUs (in order of precedence) available locally.
-
-SLURM
-+++++
-
-If the `hostnames` or `workers_per_host` parameters are set to `"slurm"`, their values will be filled from the SLURM job. Passing `"slurm"` raises a `RuntimeError` if no SLURM allocation is detected from the environment.
-
-``Launcher`` class
-------------------
+Launcher class
+--------------
-We provide the ``torchrunx.Launcher`` class as an alternative to ``torchrunx.launch``.
+We provide the :mod:`torchrunx.Launcher` class as an alternative interface to :mod:`torchrunx.launch`, e.g. for sequential invocations or for populating arguments from the CLI.
.. autoclass:: torchrunx.Launcher
- :members:
-.. .. autofunction:: torchrunx.Launcher.run
+ :members:
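+
+For example, here is a minimal sketch of equivalent usage (reusing the ``train`` function and hosts from the example above):
+
+.. code:: python
+
+    launcher = trx.Launcher(hostnames=["node1", "node2"], workers_per_host=8)
+    trained_model = launcher.run(train).rank(0)
+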
-CLI Support
-+++++++++++
+CLI integration
+^^^^^^^^^^^^^^^
-This allows **torchrunx** arguments to be more easily populated by CLI packages like `tyro `_:
+We can use :mod:`torchrunx.Launcher` to populate arguments from the CLI (e.g. with `tyro `_):
.. code:: python
@@ -58,57 +47,68 @@ This allows **torchrunx** arguments to be more easily populated by CLI packages
import tyro
def distributed_function():
- print("Hello world!")
+ pass
if __name__ == "__main__":
launcher = tyro.cli(trx.Launcher)
- launcher.run(distributed_function, {})
+ launcher.run(distributed_function)
-For example, the `python ... --help` command will then result in:
+``python ... --help`` then results in:
.. code:: bash
- ╭─ options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮
- │ -h, --help show this help message and exit │
- │ --hostnames {[STR [STR ...]]}|{auto,slurm} │
- │ (default: auto) │
- │ --workers-per-host INT|{[INT [INT ...]]}|{auto,slurm} │
- │ (default: auto) │
- │ --ssh-config-file {None}|STR|PATH │
- │ (default: None) │
- │ --backend {None,nccl,gloo,mpi,ucc,auto} │
- │ (default: auto) │
- │ --log-handlers {fixed} (fixed to: a u t o) │
- │ --env-vars STR (default: PATH LD_LIBRARY LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*') │
- │ --env-file {None}|STR|PATH │
- │ (default: None) │
- │ --timeout INT (default: 600) │
- ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
-
-Custom Logging
---------------
+ ╭─ options ─────────────────────────────────────────────╮
+ │ -h, --help show this help message and exit │
+ │ --hostnames {[STR [STR ...]]}|{auto,slurm} │
+ │ (default: auto) │
+ │ --workers-per-host INT|{[INT [INT ...]]}|{auto,slurm} │
+ │ (default: auto) │
+ │ --ssh-config-file {None}|STR|PATH │
+ │ (default: None) │
+ │ --backend {None,nccl,gloo,mpi,ucc,auto} │
+ │ (default: auto) │
+ │ --timeout INT (default: 600) │
+ │ --default-env-vars [STR [STR ...]] │
+ │ (default: PATH LD_LIBRARY ...) │
+ │ --extra-env-vars [STR [STR ...]] │
+ │ (default: ) │
+ │ --env-file {None}|STR|PATH │
+ │ (default: None) │
+ ╰───────────────────────────────────────────────────────╯
+
+SLURM integration
+-----------------
+
+By default, the ``hostnames`` or ``workers_per_host`` arguments are populated from the current SLURM allocation. If no allocation is detected, we default to 1 machine (localhost) with N workers (the number of GPUs, or else CPUs, available locally).
+A ``RuntimeError`` is raised if ``hostnames="slurm"`` or ``workers_per_host="slurm"`` but no allocation is detected.
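+
+For example, here is a sketch that explicitly requires SLURM-derived values (instead of the ``"auto"`` fallback), reusing ``train`` from above:
+
+.. code:: python
+
+    result = trx.launch(
+        func=train,
+        hostnames="slurm",         # from the SLURM allocation (else RuntimeError)
+        workers_per_host="slurm",  # e.g. from SLURM_GPUS_PER_NODE or SLURM_CPUS_ON_NODE
+    )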
+
+Propagating exceptions
+----------------------
-Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a is instantiated, causing logs at the worker and agent levels to be logged to files under ``'./logs'``, and the rank 0 worker's output streams are streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``.
+Exceptions that are raised in workers will be re-raised by the launcher process and can be caught with a standard try-except around :mod:`torchrunx.launch`.
-Custom logging classes can be subclassed from the class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The maps all the loggers to :mod:`logging.Filehandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs the launcher's ``stdout`` stream.
+A :mod:`torchrunx.AgentFailedError` or :mod:`torchrunx.WorkerFailedError` will be raised if any agent or worker dies unexpectedly (e.g. if it receives an OS signal, such as from a segmentation fault or OOM kill).
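+
+For example, here is a minimal sketch of catching these in the launcher script:
+
+.. code:: python
+
+    try:
+        result = trx.launch(func=train, hostnames=["node1", "node2"], workers_per_host=8)
+    except trx.AgentFailedError:
+        ...  # an agent died unexpectedly (e.g. OS signal)
+    except trx.WorkerFailedError:
+        ...  # a worker died unexpectedly (e.g. segmentation fault, OOM)
+    except Exception:
+        ...  # an exception raised inside `train` is re-raised here
+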
-Propagating Exceptions
-----------------------
+Environment variables
+---------------------
-Exceptions that are raised in Workers will be raised in the Launcher process and can be caught by wrapping :mod:`torchrunx.launch` in a try-except clause.
+Environment variables in the launcher process that match the ``default_env_vars`` argument are automatically copied to agents and workers. We set useful defaults for Python and PyTorch. Environment variables are pattern-matched with this list using ``fnmatch``.
-If a worker is killed by the operating system (e.g. due to Segmentation Fault or SIGKILL by running out of memory), the Launcher process raises a RuntimeError.
+``default_env_vars`` can be overridden if desired. This list can be augmented using ``extra_env_vars``. Additional environment variables (and more custom bash logic) can be included via the ``env_file`` argument; our agents ``source`` this file.
-Environment Variables
----------------------
+We also set the following environment variables in each worker: ``LOCAL_RANK``, ``RANK``, ``LOCAL_WORLD_SIZE``, ``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``.
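+
+For example, here is a sketch that copies extra variables and sources an environment file (the ``WANDB_*`` pattern and ``.env`` path are purely illustrative):
+
+.. code:: python
+
+    result = trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8,
+        extra_env_vars=("WANDB_*",),  # fnmatch patterns, added to the defaults
+        env_file=".env",              # sourced on each node
+    )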
+
+Custom logging
+--------------
+
+We forward all logs (i.e. from :mod:`logging` and :mod:`sys.stdout`/:mod:`sys.stderr`) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank.
+
+For further customization (e.g. mapping specific agents/workers to custom output streams), you can pass :mod:`logging.Handler` objects via the ``log_handlers`` argument.
-The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify which environmental variables should be copied to the agents from the launcher environment. By default, it attempts to copy variables related to Python and important packages/technologies that **torchrunx** uses such as PyTorch, NCCL, CUDA, and more. Strings provided are matched with the names of environmental variables using ``fnmatch`` - standard UNIX filename pattern matching. The variables are inserted into the agent environments, and then copied to workers' environments when they are spawned.
+We provide some utilities to help:
-:mod:`torchrunx.launch` also accepts the ``env_file`` argument, which is designed to expose more advanced environmental configuration to the user. When a file is provided as this argument, the launcher will source the file on each node before executing the agent. This allows for custom bash scripts to be provided in the environmental variables, and allows for node-specific environmental variables to be set.
+.. autofunction:: torchrunx.file_handler
-..
- TODO: example env_file
+.. autofunction:: torchrunx.stream_handler
-Support for Numpy >= 2.0
-------------------------
-only supported if `torch>=2.3`
+.. autofunction:: torchrunx.add_filter_to_handler
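+
+For example, here is a sketch of a custom scheme built from these helpers: it prints the first worker on ``node1`` to the launcher's ``stdout`` and writes that node's agent logs to a file (the file name is illustrative):
+
+.. code:: python
+
+    import logging
+
+    handlers = [
+        trx.stream_handler(hostname="node1", local_rank=0, log_level=logging.INFO),
+        trx.file_handler(hostname="node1", local_rank=None, file_path="node1_agent.log"),
+    ]
+
+    result = trx.launch(
+        func=train,
+        hostnames=["node1", "node2"],
+        workers_per_host=8,
+        log_handlers=handlers,
+    )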
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 31b35fcf..518b323f 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -6,4 +6,6 @@ API
.. autoclass:: torchrunx.LaunchResult
:members:
-.. autoclass:: torchrunx.AgentKilledError
+.. autoclass:: torchrunx.AgentFailedError
+
+.. autoclass:: torchrunx.WorkerFailedError
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 6ea3a8b2..8e64563c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -1,51 +1,96 @@
import os
import sys
-sys.path.insert(0, os.path.abspath('../../src'))
+sys.path.insert(0, os.path.abspath("../../src"))
# Configuration file for the Sphinx documentation builder.
-# -- Project information
-
-project = 'torchrunx'
-
-# -- General configuration
+project = "torchrunx"
+github_username = "apoorvkh"
+github_repository = "torchrunx"
+html_theme = "furo"
extensions = [
- 'sphinx.ext.duration',
- 'sphinx.ext.doctest',
- 'sphinx.ext.autodoc',
- 'sphinx.ext.autosummary',
- 'sphinx.ext.intersphinx',
- 'myst_parser',
- 'sphinx_toolbox.sidebar_links',
- 'sphinx_toolbox.github',
- 'sphinx.ext.autodoc.typehints',
- #"sphinx_autodoc_typehints",
+ "sphinx.ext.duration",
+ "sphinx.ext.autodoc",
+ "sphinx.ext.intersphinx",
+ "myst_parser",
+ "sphinx_toolbox.sidebar_links",
+ "sphinx_toolbox.github",
+ "sphinx.ext.napoleon",
+ "sphinx.ext.autodoc.typehints",
+ "sphinx.ext.linkcode",
]
+autodoc_mock_imports = ["torch", "fabric", "cloudpickle", "sys", "logging", "typing_extensions"]
autodoc_typehints = "both"
-#typehints_defaults = 'comma'
-
-github_username = 'apoorvkh'
-github_repository = 'torchrunx'
+autodoc_typehints_description_target = "documented_params"
-autodoc_mock_imports = ['torch', 'fabric', 'cloudpickle', 'typing_extensions']
+maximum_signature_line_length = 100
intersphinx_mapping = {
- 'python': ('https://docs.python.org/3/', None),
- 'sphinx': ('https://www.sphinx-doc.org/en/master/', None),
+ "python": ("https://docs.python.org/3/", None),
}
-intersphinx_disabled_domains = ['std']
+intersphinx_disabled_domains = ["std"]
+
+
+## Link code to Github source
+# From: https://github.com/scikit-learn/scikit-learn/blob/main/doc/sphinxext/github_link.py
+
+import inspect
+import os
+import subprocess
+import sys
+from operator import attrgetter
+
+package = project
+
+try:
+ revision = (
+ subprocess.check_output("git rev-parse --short HEAD".split()).strip().decode("utf-8")
+ )
+except (subprocess.CalledProcessError, OSError):
+ print("Failed to execute git to get revision")
+ revision = None
+
+url_fmt = (
+ f"https://github.com/{github_username}/{github_repository}/"
+ "blob/{revision}/src/{package}/{path}#L{lineno}"
+)
+
+def linkcode_resolve(domain, info):
+ if revision is None:
+ return
+ if domain not in ("py", "pyx"):
+ return
+ if not info.get("module") or not info.get("fullname"):
+ return
-templates_path = ['_templates']
+ class_name = info["fullname"].split(".")[0]
+ module = __import__(info["module"], fromlist=[class_name])
+ obj = attrgetter(info["fullname"])(module)
-# -- Options for HTML output
+ # Unwrap the object to get the correct source
+ # file in case that is wrapped by a decorator
+ obj = inspect.unwrap(obj)
-html_theme = 'furo'
+ try:
+ fn = inspect.getsourcefile(obj)
+ except Exception:
+ fn = None
+ if not fn:
+ try:
+ fn = inspect.getsourcefile(sys.modules[obj.__module__])
+ except Exception:
+ fn = None
+ if not fn:
+ return
-# -- Options for EPUB output
-epub_show_urls = 'footnote'
+ fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__))
+ try:
+ lineno = inspect.getsourcelines(obj)[1]
+ except Exception:
+ lineno = ""
+ return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno)
-# code block syntax highlighting
-#pygments_style = 'sphinx'
+## End of "link code to Github source"
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index 707e376c..7a661a30 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -1,20 +1,2 @@
-Contributing
-============
-
.. include:: ../../CONTRIBUTING.md
:parser: myst_parser.sphinx_
-
-.. Development environment
-.. -----------------------
-
-.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi `_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff `_ for linting and formatting, `pyright `_ for type checking, and ``pytest`` for testing.
-
-.. Testing
-.. -------
-
-.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
-
-.. Contributing
-.. ------------
-
-.. Make a pull request with your changes and we'll try to look at soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
\ No newline at end of file
diff --git a/docs/source/how_it_works.rst b/docs/source/how_it_works.rst
index 3ceb4b4c..9550e8d1 100644
--- a/docs/source/how_it_works.rst
+++ b/docs/source/how_it_works.rst
@@ -1,15 +1,12 @@
How it works
============
-In order to organize processes on different nodes, **torchrunx** maintains the following hierarchy:
+If you want to train your model on several machines with **N** GPUs each, you should run your training function in **N** parallel processes on each machine. During training, each of these processes runs the same training code (i.e. your function) and communicates with the others (e.g. to synchronize gradients) using a `distributed process group `_.
-#. The launcher, the process in which ``torchrunx.Launcher.run`` is executed: Connects to remote hosts and initializes and configures "agents", passes errors and return values from agents to the caller, and is responsible for cleaning up.
-#. The agents, initialized on machines where computation is to be performed: Responsible for starting and monitoring "workers".
-#. The workers, spawned by agents: Responsible for initializing a ``torch.distributed`` process group, and running the distributed function provided by the user.
+Your script can call our library (via :mod:`torchrunx.launch`) and specify a function to distribute. The main process running your script is henceforth known as the **launcher** process.
-An example of how this hierarchy might look in practice is the following:
-Suppose we wish to distribute a training function over four GPUs, and we have access to a cluster where nodes have two available GPUs each. Say that a single instance of our training function can leverage multiple GPUs. We can choose two available nodes and use the launcher to launch our function on those two nodes, specifying that we only need one worker per node, since a single instance of our training function can use both GPUs on each node. The launcher will launch an agent on each node and pass our configuration to the agents, after which the agents will each initialize one worker to begin executing the training function. We could also run two workers per node, each with one GPU, giving us four workers, although this would be slower.
+Our launcher process spawns an **agent** process (via SSH) on each machine. Each agent then spawns **N** processes (known as **workers**) on its machine. All workers form a process group (with the ``backend`` specified in :mod:`torchrunx.launch`) and run your function in parallel.
-The launcher initializes the agents by simply SSHing into the provided hosts, and executing our agent code there. The launcher also provides key environmental variables from the launch environment to the sessions where the agents are started and tries to activate the same Python environment that was used to execute the launcher. This is one reason why all machines either running a launcher or agent process should share a filesystem.
+**Agent–Worker Communication.** Our agents poll their workers every second and time out if a poll is unresponsive for 5 seconds. Upon polling, an agent receives ``None`` (if its workers are still running) or a `RunProcsResult `_, indicating that the workers have either completed (carrying the object returned by, or the exception raised in, your function) or failed (e.g. due to a segmentation fault or OS signal).
-The launcher and agents perform exception handling such that any exceptions in the worker processes are appropriately raised by the launcher process. The launcher and agents communicate using a ``torch.distributed`` process group, separate from the group that the workers use.
\ No newline at end of file
+**Launcher–Agent Communication.** The launcher and agents form a distributed group (with the CPU-based `GLOO backend `_) for our library's internal communication. Our agents synchronize their own "statuses" with each other and the launcher. An agent's status includes whether it is running, failed, or done, as well as the objects returned (or exceptions raised) by its workers. If the launcher or any agent fails to synchronize, all raise a :mod:`torchrunx.AgentFailedError` and terminate. If any worker fails or raises an exception, the launcher raises a :mod:`torchrunx.WorkerFailedError` or that exception (respectively) and terminates along with all agents. If all agents succeed, the launcher returns the objects returned by each worker.
diff --git a/pixi.lock b/pixi.lock
index 49a19663..3f01b378 100644
--- a/pixi.lock
+++ b/pixi.lock
@@ -209,7 +209,9 @@ environments:
- pypi: https://files.pythonhosted.org/packages/ac/25/e715fa0bc24ac2114ed69da33adf451a38abb6f3f24ec207908112e9ba53/cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/20/8d/778b7d51b981a96554f29136cd59ca7880bf58094338085bcf2a979a0e6a/Deprecated-1.2.14-py2.py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/ac/ac/aa3d8e0acbcd71140420bc752d7c9779cf3a2a3bb1d7ef30944e38b2cd39/eval_type_backport-0.2.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/d6/1f/e99e23ee01847147fa194e8d41cfcf2535a2dbfcb51414c541cadb15c5d7/fabric-3.2.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b9/f8/feced7779d755758a52d1f6635d990b8d98dc0a29fa568bbe0625f18fdf3/filelock-3.16.1-py3-none-any.whl
@@ -267,7 +269,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/46/96/464058dd1d980014fb5aa0a1254e78799efb3096fc7a4823cd66a1621276/ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/5d/80/81ba44fc82afbf5ca553913ac49460e325dc5cf00c317b34c14d43ebd76b/safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/54/24/b4293291fa1dd830f353d2cb163295742fa87f179fcc8a20a306a81978b7/SecretStorage-3.3.3-py3-none-any.whl
- - pypi: https://files.pythonhosted.org/packages/31/2d/90165d51ecd38f9a02c6832198c13a4e48652485e2ccf863ebb942c531b6/setuptools-75.2.0-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/e2/d1/a1d3189e7873408b9dc396aef0d7926c198b0df2aa3ddb5b539d3e89a70f/shtab-1.7.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/dc/a4/90123871996bfb8a7148cd11d61e7a0ddc0118114c071730b3dc3a05c7bc/submitit-1.5.2-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/63/04/1ce87a3eae86aa4e301e863d42d6907deaf61f3e9178210d9ebe653e948c/tokenizers-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
@@ -278,6 +280,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/c4/69/57e0fed438d547524e08bfedc587078314176ad1c15c8be904d3f03149ec/triton-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/5d/ec/00f9d5fd040ae29867355e559a94e9a8429225a0284a3f5f091a3878bfc0/twine-5.1.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl
+ - pypi: https://files.pythonhosted.org/packages/29/32/daacb047883a572f452f3763e07135a45d023a068badb6d83f9b07bc9410/tyro-0.8.13-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b1/e7/459a8a4f40f2fa65eb73cb3f339e6d152957932516d18d0e996c7ae2d7ae/wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl
- pypi: https://files.pythonhosted.org/packages/62/8b/5ba542fa83c90e09eac972fc9baca7a88e7e7ca4b221a89251954019308b/zipp-3.20.2-py3-none-any.whl
@@ -535,12 +538,26 @@ packages:
- bump2version<1 ; extra == 'dev'
- sphinx<2 ; extra == 'dev'
requires_python: '>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*'
+- kind: pypi
+ name: docstring-parser
+ version: '0.16'
+ url: https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl
+ sha256: bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637
+ requires_python: '>=3.6,<4.0'
- kind: pypi
name: docutils
version: 0.21.2
url: https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl
sha256: dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2
requires_python: '>=3.9'
+- kind: pypi
+ name: eval-type-backport
+ version: 0.2.0
+ url: https://files.pythonhosted.org/packages/ac/ac/aa3d8e0acbcd71140420bc752d7c9779cf3a2a3bb1d7ef30944e38b2cd39/eval_type_backport-0.2.0-py3-none-any.whl
+ sha256: ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933
+ requires_dist:
+ - pytest ; extra == 'tests'
+ requires_python: '>=3.8'
- kind: pypi
name: exceptiongroup
version: 1.2.2
@@ -1645,66 +1662,15 @@ packages:
- jeepney>=0.6
requires_python: '>=3.6'
- kind: pypi
- name: setuptools
- version: 75.2.0
- url: https://files.pythonhosted.org/packages/31/2d/90165d51ecd38f9a02c6832198c13a4e48652485e2ccf863ebb942c531b6/setuptools-75.2.0-py3-none-any.whl
- sha256: a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8
+ name: shtab
+ version: 1.7.1
+ url: https://files.pythonhosted.org/packages/e2/d1/a1d3189e7873408b9dc396aef0d7926c198b0df2aa3ddb5b539d3e89a70f/shtab-1.7.1-py3-none-any.whl
+ sha256: 32d3d2ff9022d4c77a62492b6ec875527883891e33c6b479ba4d41a51e259983
requires_dist:
- - pytest-checkdocs>=2.4 ; extra == 'check'
- - pytest-ruff>=0.2.1 ; sys_platform != 'cygwin' and extra == 'check'
- - ruff>=0.5.2 ; sys_platform != 'cygwin' and extra == 'check'
- - packaging>=24 ; extra == 'core'
- - more-itertools>=8.8 ; extra == 'core'
- - jaraco-text>=3.7 ; extra == 'core'
- - wheel>=0.43.0 ; extra == 'core'
- - platformdirs>=2.6.2 ; extra == 'core'
- - jaraco-collections ; extra == 'core'
- - jaraco-functools ; extra == 'core'
- - packaging ; extra == 'core'
- - more-itertools ; extra == 'core'
- - importlib-metadata>=6 ; python_version < '3.10' and extra == 'core'
- - tomli>=2.0.1 ; python_version < '3.11' and extra == 'core'
- - importlib-resources>=5.10.2 ; python_version < '3.9' and extra == 'core'
- - pytest-cov ; extra == 'cover'
- - sphinx>=3.5 ; extra == 'doc'
- - jaraco-packaging>=9.3 ; extra == 'doc'
- - rst-linker>=1.9 ; extra == 'doc'
- - furo ; extra == 'doc'
- - sphinx-lint ; extra == 'doc'
- - jaraco-tidelift>=1.4 ; extra == 'doc'
- - pygments-github-lexers==0.0.5 ; extra == 'doc'
- - sphinx-favicon ; extra == 'doc'
- - sphinx-inline-tabs ; extra == 'doc'
- - sphinx-reredirects ; extra == 'doc'
- - sphinxcontrib-towncrier ; extra == 'doc'
- - sphinx-notfound-page<2,>=1 ; extra == 'doc'
- - pyproject-hooks!=1.1 ; extra == 'doc'
- - towncrier<24.7 ; extra == 'doc'
- - pytest-enabler>=2.2 ; extra == 'enabler'
- - pytest!=8.1.*,>=6 ; extra == 'test'
- - virtualenv>=13.0.0 ; extra == 'test'
- - wheel>=0.44.0 ; extra == 'test'
- - pip>=19.1 ; extra == 'test'
- - packaging>=23.2 ; extra == 'test'
- - jaraco-envs>=2.2 ; extra == 'test'
- - pytest-xdist>=3 ; extra == 'test'
- - jaraco-path>=3.2.0 ; extra == 'test'
- - build[virtualenv]>=1.0.3 ; extra == 'test'
- - filelock>=3.4.0 ; extra == 'test'
- - ini2toml[lite]>=0.14 ; extra == 'test'
- - tomli-w>=1.0.0 ; extra == 'test'
- - pytest-timeout ; extra == 'test'
- - pytest-home>=0.5 ; extra == 'test'
- - pytest-subprocess ; extra == 'test'
- - pyproject-hooks!=1.1 ; extra == 'test'
- - jaraco-test ; extra == 'test'
- - jaraco-develop>=7.21 ; (python_version >= '3.9' and sys_platform != 'cygwin') and extra == 'test'
- - pytest-perf ; sys_platform != 'cygwin' and extra == 'test'
- - pytest-mypy ; extra == 'type'
- - mypy==1.11.* ; extra == 'type'
- - importlib-metadata>=7.0.2 ; python_version < '3.10' and extra == 'type'
- - jaraco-develop>=7.21 ; sys_platform != 'cygwin' and extra == 'type'
- requires_python: '>=3.8'
+ - pytest>=6 ; extra == 'dev'
+ - pytest-cov ; extra == 'dev'
+ - pytest-timeout ; extra == 'dev'
+ requires_python: '>=3.7'
- kind: conda
name: sqlite
version: 3.46.0
@@ -1827,7 +1793,7 @@ packages:
name: torchrunx
version: 0.2.0
path: .
- sha256: 6e56e7865dd6c8758124a48d6eb086255880194f81b8c81f385bcf00429164d8
+ sha256: 2902e99e3cf4c53010007e6ad3243b25a8c7be8325ad6d9889730c9f075ed72d
requires_dist:
- cloudpickle>=3.0
- fabric>=3.2
@@ -2260,6 +2226,35 @@ packages:
url: https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl
sha256: 04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d
requires_python: '>=3.8'
+- kind: pypi
+ name: tyro
+ version: 0.8.13
+ url: https://files.pythonhosted.org/packages/29/32/daacb047883a572f452f3763e07135a45d023a068badb6d83f9b07bc9410/tyro-0.8.13-py3-none-any.whl
+ sha256: 8664cb409915c68d1b9ffc3c9e44b260387369b319a4d94399918c70d900b7de
+ requires_dist:
+ - docstring-parser>=0.16
+ - typing-extensions>=4.7.0
+ - rich>=11.1.0
+ - shtab>=1.5.6
+ - colorama>=0.4.0 ; platform_system == 'Windows'
+ - eval-type-backport>=0.1.3 ; python_version < '3.10'
+ - backports-cached-property>=1.0.2 ; python_version < '3.8'
+ - pyyaml>=6.0 ; extra == 'dev'
+ - frozendict>=2.3.4 ; extra == 'dev'
+ - pytest>=7.1.2 ; extra == 'dev'
+ - pytest-cov>=3.0.0 ; extra == 'dev'
+ - omegaconf>=2.2.2 ; extra == 'dev'
+ - attrs>=21.4.0 ; extra == 'dev'
+ - torch>=1.10.0 ; extra == 'dev'
+ - pyright!=1.1.379,>=1.1.349 ; extra == 'dev'
+ - ruff>=0.1.13 ; extra == 'dev'
+ - mypy>=1.4.1 ; extra == 'dev'
+ - numpy>=1.20.0 ; extra == 'dev'
+ - pydantic>=2.5.2 ; extra == 'dev'
+ - coverage[toml]>=6.5.0 ; extra == 'dev'
+ - eval-type-backport>=0.1.3 ; extra == 'dev'
+ - flax>=0.6.9 ; python_version >= '3.8' and extra == 'dev'
+ requires_python: '>=3.7'
- kind: conda
name: tzdata
version: 2024b
diff --git a/pixi.toml b/pixi.toml
index c6c14e2f..ae7bdf8a 100644
--- a/pixi.toml
+++ b/pixi.toml
@@ -17,10 +17,10 @@ build = "*"
twine = "*"
[feature.extra.pypi-dependencies]
-transformers = "*"
+tyro = "*"
submitit = "*"
-setuptools = "*"
accelerate = "*"
+transformers = "*"
[environments]
default = { features = ["package", "dev"], solve-group = "default" }
diff --git a/pyproject.toml b/pyproject.toml
index ecc572a1..3fb9d1ab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,12 +36,12 @@ Documentation = "https://torchrunx.readthedocs.io"
[tool.ruff]
include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"]
+exclude = ["docs"]
line-length = 100
src = ["src", "tests"]
[tool.ruff.lint]
select = ["ALL"]
ignore = [
- "D", # documentation
"ANN101", "ANN102", "ANN401", # self / cls / Any annotations
"BLE001", # blind exceptions
"TD", # todo syntax
@@ -54,9 +54,12 @@ ignore = [
]
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = [
+ "D",
"S101", # allow asserts
"T201" # allow prints
]
+[tool.ruff.lint.pydocstyle]
+convention = "google"
[tool.pyright]
include = ["src", "tests"]
diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py
index 177c444e..405a7715 100644
--- a/src/torchrunx/__init__.py
+++ b/src/torchrunx/__init__.py
@@ -1,6 +1,8 @@
+"""API for our torchrunx library."""
+
from .launcher import Launcher, LaunchResult, launch
-from .logging_utils import add_filter_to_handler, file_handler, stream_handler
-from .utils import AgentFailedError, WorkerFailedError
+from .utils.errors import AgentFailedError, WorkerFailedError
+from .utils.logging import add_filter_to_handler, file_handler, stream_handler
__all__ = [
"AgentFailedError",
diff --git a/src/torchrunx/__main__.py b/src/torchrunx/__main__.py
index c3b3099e..8626c1f8 100644
--- a/src/torchrunx/__main__.py
+++ b/src/torchrunx/__main__.py
@@ -1,7 +1,9 @@
+"""CLI entrypoint used for starting agents on different nodes."""
+
from argparse import ArgumentParser
from .agent import main
-from .utils import LauncherAgentGroup
+from .utils.comm import LauncherAgentGroup
if __name__ == "__main__":
parser = ArgumentParser()
diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py
index 2a94c20d..27dd8489 100644
--- a/src/torchrunx/agent.py
+++ b/src/torchrunx/agent.py
@@ -1,110 +1,44 @@
+"""Primary logic for agent processes."""
+
from __future__ import annotations
-import datetime
+__all__ = ["main"]
+
import logging
import os
import socket
import sys
import tempfile
-import traceback
-from dataclasses import dataclass
-from typing import Any, Callable, Literal
-import cloudpickle
import torch
-import torch.distributed as dist
import torch.distributed.elastic.multiprocessing as dist_mp
-from .logging_utils import log_records_to_socket, redirect_stdio_to_logger
-from .utils import (
+from .utils.comm import (
AgentPayload,
AgentStatus,
- ExceptionFromWorker,
LauncherAgentGroup,
get_open_port,
)
+from .utils.logging import log_records_to_socket, redirect_stdio_to_logger
+from .worker import WorkerArgs, worker_entrypoint
-@dataclass
-class WorkerArgs:
- function: Callable
- logger_hostname: str
- logger_port: int
- main_agent_hostname: str
- main_agent_port: int
- backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
- rank: int
- local_rank: int
- local_world_size: int
- world_size: int
- hostname: str
- timeout: int
-
- def serialize(self) -> SerializedWorkerArgs:
- return SerializedWorkerArgs(worker_args=self)
-
-
-class SerializedWorkerArgs:
- def __init__(self, worker_args: WorkerArgs) -> None:
- self.bytes = cloudpickle.dumps(worker_args)
-
- def deserialize(self) -> WorkerArgs:
- return cloudpickle.loads(self.bytes)
-
-
-def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
- worker_args: WorkerArgs = serialized_worker_args.deserialize()
-
- logger = logging.getLogger()
-
- log_records_to_socket(
- logger=logger,
- hostname=worker_args.hostname,
- worker_rank=worker_args.local_rank,
- logger_hostname=worker_args.logger_hostname,
- logger_port=worker_args.logger_port,
- )
-
- redirect_stdio_to_logger(logger)
-
- os.environ["RANK"] = str(worker_args.rank)
- os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
- os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
- os.environ["WORLD_SIZE"] = str(worker_args.world_size)
- os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
- os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
-
- if worker_args.backend is not None:
- backend = worker_args.backend
- if backend == "auto":
- backend = "nccl" if torch.cuda.is_available() else "gloo"
-
- dist.init_process_group(
- backend=backend,
- world_size=worker_args.world_size,
- rank=worker_args.rank,
- store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
- host_name=worker_args.main_agent_hostname,
- port=worker_args.main_agent_port,
- world_size=worker_args.world_size,
- is_master=(worker_args.rank == 0),
- ),
- timeout=datetime.timedelta(seconds=worker_args.timeout),
- )
-
- try:
- return worker_args.function()
- except Exception as e:
- traceback.print_exc()
- return ExceptionFromWorker(exception=e)
- finally:
- sys.stdout.flush()
- sys.stderr.flush()
+def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
+ """Main function for agent processes (started on each node).
+ This function spawns local worker processes (which run the target function). All agents monitor
+ their worker statuses (including returned objects and raised exceptions) and communicate these
+    with each other (and the launcher). All agents terminate if a failure occurs in any agent.
+
-def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None:
+ Arguments:
+ launcher_agent_group: The communication group between launcher and all agents.
+ logger_hostname: The hostname of the launcher (for logging).
+ logger_port: The port of the launcher (for logging).
+ """
agent_rank = launcher_agent_group.rank - 1
+ # Communicate initial payloads between launcher/agents
+
payload = AgentPayload(
hostname=socket.getfqdn(),
port=get_open_port(),
@@ -119,23 +53,25 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
worker_global_ranks = launcher_payload.worker_global_ranks[agent_rank]
num_workers = len(worker_global_ranks)
+ # Stream logs to logging server
+
logger = logging.getLogger()
log_records_to_socket(
logger=logger,
hostname=hostname,
- worker_rank=None,
+ local_rank=None,
logger_hostname=logger_hostname,
logger_port=logger_port,
)
redirect_stdio_to_logger(logger)
- # spawn workers
+ # Spawn worker processes
ctx = dist_mp.start_processes(
name=f"{hostname}_",
- entrypoint=entrypoint,
+ entrypoint=worker_entrypoint,
args={
i: (
WorkerArgs(
@@ -165,6 +101,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_
), # pyright: ignore [reportArgumentType]
)
+ # Monitor and communicate agent statuses
+ # Terminate gracefully upon failure
+
try:
status = None
while True:
diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py
index 64b871dc..b745ac7a 100644
--- a/src/torchrunx/launcher.py
+++ b/src/torchrunx/launcher.py
@@ -1,5 +1,9 @@
+"""For launching functions with our library."""
+
from __future__ import annotations
+__all__ = ["Launcher", "launch", "LaunchResult"]
+
import fnmatch
import ipaddress
import itertools
@@ -15,31 +19,33 @@
from multiprocessing import Process
from operator import add
from pathlib import Path
-from typing import Any, Callable, Literal, overload
+from typing import Any, Callable, Literal
import fabric
import torch.distributed as dist
-from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
-from .logging_utils import LogRecordSocketReceiver, default_handlers
-from .utils import (
- AgentStatus,
- ExceptionFromWorker,
+from .utils.comm import (
LauncherAgentGroup,
LauncherPayload,
- WorkerFailedError,
get_open_port,
)
+from .utils.environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers
+from .utils.errors import (
+ ExceptionFromWorker,
+ WorkerFailedError,
+)
+from .utils.logging import LogRecordSocketReceiver, default_handlers
@dataclass
class Launcher:
+ """Useful for sequential invocations or for specifying arguments via CLI."""
+
hostnames: list[str] | Literal["auto", "slurm"] = "auto"
workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto"
ssh_config_file: str | os.PathLike | None = None
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto"
timeout: int = 600
- log_handlers: list[Handler] | Literal["auto"] | None = "auto"
default_env_vars: tuple[str, ...] = (
"PATH",
"LD_LIBRARY",
@@ -58,13 +64,15 @@ def run( # noqa: C901, PLR0912
func: Callable,
func_args: tuple[Any] | None = None,
func_kwargs: dict[str, Any] | None = None,
+ log_handlers: list[Handler] | Literal["auto"] | None = "auto",
) -> LaunchResult:
+ """Run a function using the :mod:`torchrunx.Launcher` configuration."""
if not dist.is_available():
msg = "The torch.distributed package is not available."
raise RuntimeError(msg)
- hostnames = resolve_hostnames(self.hostnames)
- workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames))
+ hostnames = _resolve_hostnames(self.hostnames)
+ workers_per_host = _resolve_workers_per_host(self.workers_per_host, len(hostnames))
launcher_hostname = socket.getfqdn()
launcher_port = get_open_port()
@@ -76,10 +84,10 @@ def run( # noqa: C901, PLR0912
agent_payloads = None
try:
- # start logging server
+            # Start logging server (receives LogRecords from agents/workers)
- log_receiver = build_logging_server(
- log_handlers=self.log_handlers,
+ log_receiver = _build_logging_server(
+ log_handlers=log_handlers,
launcher_hostname=launcher_hostname,
hostnames=hostnames,
workers_per_host=workers_per_host,
@@ -94,11 +102,11 @@ def run( # noqa: C901, PLR0912
log_process.start()
- # start agents on each node
+ # Start agents on each node
for i, hostname in enumerate(hostnames):
- execute_command(
- command=build_launch_command(
+ _execute_command(
+ command=_build_launch_command(
launcher_hostname=launcher_hostname,
launcher_port=launcher_port,
logger_port=log_receiver.port,
@@ -111,7 +119,7 @@ def run( # noqa: C901, PLR0912
ssh_config_file=self.ssh_config_file,
)
- # initialize launcher-agent process group
+ # Initialize launcher-agent process group
# ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1])
launcher_agent_group = LauncherAgentGroup(
@@ -121,7 +129,7 @@ def run( # noqa: C901, PLR0912
rank=0,
)
- # build and sync payloads between launcher and agents
+ # Sync initial payloads between launcher and agents
_cumulative_workers = [0, *itertools.accumulate(workers_per_host)]
@@ -141,7 +149,7 @@ def run( # noqa: C901, PLR0912
launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload)
- # loop to monitor agent statuses (until failed or done)
+ # Monitor agent statuses (until failed or done)
while True:
# could raise AgentFailedError
@@ -170,13 +178,15 @@ def run( # noqa: C901, PLR0912
# cleanup: SIGTERM all agents
if agent_payloads is not None:
for agent_payload, agent_hostname in zip(agent_payloads, hostnames):
- execute_command(
+ _execute_command(
command=f"kill {agent_payload.process_id}",
hostname=agent_hostname,
ssh_config_file=self.ssh_config_file,
)
- return LaunchResult(hostnames=hostnames, agent_statuses=agent_statuses)
+ # if launch is successful: return objects from workers
+ return_values = [s.return_values for s in agent_statuses]
+ return LaunchResult(hostnames=hostnames, return_values=return_values)
def launch(
@@ -188,7 +198,6 @@ def launch(
ssh_config_file: str | os.PathLike | None = None,
backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto",
timeout: int = 600,
- log_handlers: list[Handler] | Literal["auto"] | None = "auto",
default_env_vars: tuple[str, ...] = (
"PATH",
"LD_LIBRARY",
@@ -201,101 +210,78 @@ def launch(
),
extra_env_vars: tuple[str, ...] = (),
env_file: str | os.PathLike | None = None,
+ log_handlers: list[Handler] | Literal["auto"] | None = "auto",
) -> LaunchResult:
+ """Launch a distributed PyTorch function on the specified nodes.
+
+ Arguments:
+ func: Function to run on each worker.
+ func_args: Positional arguments for ``func``.
+ func_kwargs: Keyword arguments for ``func``.
+ hostnames: Nodes on which to launch the function.
+ Defaults to nodes inferred from a SLURM environment or localhost.
+ workers_per_host: Number of processes to run per node.
+ Can specify different counts per node with a list.
+ ssh_config_file: Path to an SSH configuration file for connecting to nodes.
+ Defaults to ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
+ backend: `Backend `_
+ for worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set `None` to disable.
+ timeout: Worker process group timeout (seconds).
+ default_env_vars: Environment variables to copy from the launcher process to workers.
+ Supports bash pattern matching syntax.
+ extra_env_vars: Additional user-specified environment variables to copy.
+ env_file: Path to a file (e.g., `.env`) with additional environment variables to copy.
+ log_handlers: Handlers to manage agent and worker logs.
+ Defaults to an automatic basic logging scheme.
+
+ Raises:
+ RuntimeError: If there are configuration issues.
+ AgentFailedError: If an agent fails, e.g. from an OS signal.
+ WorkerFailedError: If a worker fails, e.g. from a segmentation fault.
+ Exception: Any exception raised in a worker process is propagated.
"""
- Launch a distributed PyTorch function on the specified nodes.
-
- :param func:
- :param func_args:
- :param func_kwargs:
- :param hostnames: Nodes to launch the function on. Default infers from a SLURM environment or runs on localhost.
- :param workers_per_host: Number of processes to run per node. Can define per node with :type:`list[int]`.
- :param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``.
- :param backend: `Backend `_ to initialize worker process group with. Default uses NCCL (if GPUs available) or GLOO. Disabled by ``None``.
- :param timeout: Worker process group timeout (seconds).
- :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme.
- :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax.
- :param extra_env_vars: Additional, user-specified variables to copy.
- :param env_file: A file (like ``.env``) with additional environment variables to copy.
- :raises RuntimeError: If ``torch.distributed`` not available
- :raises AgentKilledError: If any agent is killed
- :raises Exception: Propagates exceptions raised in worker processes
- """ # noqa: E501
return Launcher(
hostnames=hostnames,
workers_per_host=workers_per_host,
ssh_config_file=ssh_config_file,
backend=backend,
timeout=timeout,
- log_handlers=log_handlers,
default_env_vars=default_env_vars,
extra_env_vars=extra_env_vars,
env_file=env_file,
- ).run(func=func, func_args=func_args, func_kwargs=func_kwargs)
+ ).run(
+ func=func,
+ func_args=func_args,
+ func_kwargs=func_kwargs,
+ log_handlers=log_handlers,
+ )
+@dataclass
class LaunchResult:
- def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None:
- self.hostnames: list[str] = hostnames
- self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses]
-
- @overload
- def all(self) -> dict[str, list[Any]]:
- pass
-
- @overload
- def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]:
- pass
-
- @overload
- def all(self, by: Literal["rank"]) -> list[Any]:
- pass
-
- def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]:
- """
- Get all worker return values by rank or hostname.
-
- :param by: Whether to aggregate all return values by hostname, or just output all of them \
- in order of rank, defaults to ``'hostname'``
- """
- if by == "hostname":
- return dict(zip(self.hostnames, self.return_values))
- elif by == "rank": # noqa: RET505
- return reduce(add, self.return_values)
-
- msg = "Invalid argument: expected by=('hostname' | 'rank')"
- raise TypeError(msg)
-
- def values(self, hostname: str) -> list[Any]:
- """
- Get worker return values for host ``hostname``.
-
- :param hostname: The host to get return values from
- """
- host_idx = self.hostnames.index(hostname)
- return self.return_values[host_idx]
-
- def value(self, rank: int) -> Any:
- """
- Get worker return value from global rank ``rank``.
-
- :param rank: Global worker rank to get return value from
- """
- if rank < 0:
- msg = f"Rank {rank} must be larger than 0"
- raise ValueError(msg)
-
- for values in self.return_values:
- if rank >= len(values):
- rank -= len(values)
- else:
- return values[rank]
-
- msg = f"Rank {rank} larger than world_size"
- raise ValueError(msg)
+ """Container for objects returned from workers after successful launches."""
+
+ hostnames: list[str]
+ return_values: list[list[Any]]
+
+ def by_hostnames(self) -> dict[str, list[Any]]:
+ """All return values from workers, indexed by host and local rank."""
+ return dict(zip(self.hostnames, self.return_values))
+
+ def by_ranks(self) -> list[Any]:
+ """All return values from workers, indexed by global rank."""
+ return reduce(add, self.return_values)
+
+ def index(self, hostname: str, rank: int) -> Any:
+ """Get return value from worker by host and local rank."""
+ return self.return_values[self.hostnames.index(hostname)][rank]
+
+ def rank(self, i: int) -> Any:
+ """Get return value from worker by global rank."""
+ return self.by_ranks()[i]
-def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
+def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]:
if hostnames == "auto":
return auto_hosts()
if hostnames == "slurm":
@@ -303,7 +289,7 @@ def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[s
return hostnames
-def resolve_workers_per_host(
+def _resolve_workers_per_host(
workers_per_host: int | list[int] | Literal["auto", "slurm"],
num_hosts: int,
) -> list[int]:
@@ -321,7 +307,7 @@ def resolve_workers_per_host(
return workers_per_host
-def build_logging_server(
+def _build_logging_server(
log_handlers: list[Handler] | Literal["auto"] | None,
launcher_hostname: str,
hostnames: list[str],
@@ -346,7 +332,7 @@ def build_logging_server(
)
-def build_launch_command(
+def _build_launch_command(
launcher_hostname: str,
launcher_port: int,
logger_port: int,
@@ -388,7 +374,7 @@ def build_launch_command(
return " && ".join(commands)
-def execute_command(
+def _execute_command(
command: str,
hostname: str,
ssh_config_file: str | os.PathLike | None = None,
diff --git a/src/torchrunx/utils/__init__.py b/src/torchrunx/utils/__init__.py
new file mode 100644
index 00000000..b2aa5714
--- /dev/null
+++ b/src/torchrunx/utils/__init__.py
@@ -0,0 +1 @@
+"""Utility classes and functions."""
diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils/comm.py
similarity index 70%
rename from src/torchrunx/utils.py
rename to src/torchrunx/utils/comm.py
index b4c2e768..4ed16f0a 100644
--- a/src/torchrunx/utils.py
+++ b/src/torchrunx/utils/comm.py
@@ -1,5 +1,16 @@
+"""Utilities for Launcher-Agent communication."""
+
from __future__ import annotations
+__all__ = [
+ "get_open_port",
+ "LauncherAgentGroup",
+ "LauncherPayload",
+ "AgentPayload",
+ "ExceptionFromWorker",
+ "AgentStatus",
+]
+
import datetime
import socket
from contextlib import closing
@@ -10,33 +21,34 @@
import torch.distributed as dist
from typing_extensions import Self
+from .errors import AgentFailedError, ExceptionFromWorker, WorkerFailedError
+
if TYPE_CHECKING:
from torch.distributed.elastic.multiprocessing.api import RunProcsResult
def get_open_port() -> int:
+ """Return an open port number."""
with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind(("", 0))
return s.getsockname()[1]
-class AgentFailedError(Exception):
- pass
-
-
-class WorkerFailedError(Exception):
- pass
-
-
@dataclass
class LauncherAgentGroup:
+ """Initializes a GLOO distributed process group between launcher and all agents."""
+
launcher_hostname: str
launcher_port: int
world_size: int
rank: int
def __post_init__(self) -> None:
- # timeout will raise torch.distributed.DistStoreError
+ """Initialize process group.
+
+ Raises:
+ torch.distributed.DistStoreError: if group initialization times out.
+ """
self.group = dist.init_process_group(
backend="gloo",
world_size=self.world_size,
@@ -57,7 +69,11 @@ def _deserialize(self, serialized: bytes) -> Any:
return cloudpickle.loads(serialized)
def _all_gather(self, obj: Any) -> list:
- """gather object from every rank to list on every rank"""
+ """Gather object from every rank to list on every rank.
+
+ Raises:
+ AgentFailedError: if any agent fails (observed by this communication).
+ """
try:
object_bytes = self._serialize(obj)
object_list = [b""] * self.world_size
@@ -72,20 +88,25 @@ def sync_payloads(
self,
payload: LauncherPayload | AgentPayload,
) -> tuple[LauncherPayload, list[AgentPayload]]:
+ """All-gather payloads across launcher and all agents."""
payloads = self._all_gather(payload)
launcher_payload = payloads[0]
agent_payloads = payloads[1:]
return launcher_payload, agent_payloads
def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]:
+ """All-gather agent statuses across launcher and all agents."""
return self._all_gather(status)[1:] # [0] is launcher (status=None)
def shutdown(self) -> None:
+ """Terminate process group."""
dist.destroy_process_group(group=self.group)
@dataclass
class LauncherPayload:
+ """Payload from launcher to agents with runtime information."""
+
fn: Callable
hostnames: list[str]
worker_global_ranks: list[list[int]]
@@ -96,18 +117,22 @@ class LauncherPayload:
@dataclass
class AgentPayload:
+ """Payload corresponding to each agent."""
+
hostname: str
port: int
process_id: int
@dataclass
-class ExceptionFromWorker:
- exception: Exception
+class AgentStatus:
+ """Status of each agent (to be synchronized in LauncherAgentGroup).
+ Attributes:
+ state: Whether the agent is running, failed, or done.
+ return_values: Objects returned (or exceptions raised) by workers (indexed by local rank).
+ """
-@dataclass
-class AgentStatus:
state: Literal["running", "failed", "done"]
return_values: list[Any | WorkerFailedError | ExceptionFromWorker] = field(
default_factory=list
@@ -115,12 +140,16 @@ class AgentStatus:
@classmethod
def from_result(cls, result: RunProcsResult | None) -> Self:
+ """Convert RunProcsResult (from polling worker process context) to AgentStatus."""
if result is None:
return cls(state="running")
+
for local_rank, failure in result.failures.items():
result.return_values[local_rank] = WorkerFailedError(failure.message)
+
return_values = list(result.return_values.values())
- failed = any(isinstance(v, ExceptionFromWorker) for v in return_values)
+
+ failed = any(isinstance(v, (ExceptionFromWorker, WorkerFailedError)) for v in return_values)
state = "failed" if failed else "done"
return cls(
diff --git a/src/torchrunx/environment.py b/src/torchrunx/utils/environment.py
similarity index 63%
rename from src/torchrunx/environment.py
rename to src/torchrunx/utils/environment.py
index 179cfb8d..c297f027 100644
--- a/src/torchrunx/environment.py
+++ b/src/torchrunx/utils/environment.py
@@ -1,5 +1,9 @@
+"""Utilities for determining hosts and workers in environment."""
+
from __future__ import annotations
+__all__ = ["in_slurm_job", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"]
+
import os
import subprocess
@@ -7,15 +11,12 @@
def in_slurm_job() -> bool:
+ """Check if current process is running in a Slurm allocation."""
return "SLURM_JOB_ID" in os.environ
def slurm_hosts() -> list[str]:
- """Retrieves hostnames of Slurm-allocated nodes.
-
- :return: Hostnames of nodes in current Slurm allocation
- :rtype: list[str]
- """
+ """Retrieves hostnames of Slurm-allocated nodes."""
# TODO: sanity check SLURM variables, commands
if not in_slurm_job():
msg = "Not in a SLURM job"
@@ -29,13 +30,7 @@ def slurm_hosts() -> list[str]:
def slurm_workers() -> int:
- """
- | Determines number of workers per node in current Slurm allocation using
- | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables.
-
- :return: The implied number of workers per node
- :rtype: int
- """
+ """Determines number of workers per node in current Slurm allocation."""
# TODO: sanity check SLURM variables, commands
if not in_slurm_job():
msg = "Not in a SLURM job"
@@ -47,17 +42,11 @@ def slurm_workers() -> int:
if "SLURM_GPUS_PER_NODE" in os.environ:
return int(os.environ["SLURM_GPUS_PER_NODE"])
- # TODO: should we assume that we plan to do one worker per CPU?
return int(os.environ["SLURM_CPUS_ON_NODE"])
def auto_hosts() -> list[str]:
- """
- Automatically determine hostname list
-
- :return: Hostnames in Slurm allocation, or ['localhost']
- :rtype: list[str]
- """
+ """Automatically determine hostnames to launch to."""
if in_slurm_job():
return slurm_hosts()
@@ -65,12 +54,7 @@ def auto_hosts() -> list[str]:
def auto_workers() -> int:
- """
- Automatically determine number of workers per host
-
- :return: Workers per host
- :rtype: int
- """
+ """Automatically determine workers per host from SLURM or based on GPU/CPU count."""
if in_slurm_job():
return slurm_workers()
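For illustration, these helpers can also be called directly to see how `"auto"` resolution behaves on a given machine. A small sketch, importing from the internal `torchrunx.utils.environment` module introduced above (the printout itself is just for demonstration):

```python
from torchrunx.utils.environment import auto_hosts, auto_workers, in_slurm_job

# Outside a Slurm allocation: ["localhost"] and a GPU/CPU-derived worker count.
# Inside one: hostnames from the allocation and SLURM_GPUS_PER_NODE / SLURM_CPUS_ON_NODE.
hostnames = auto_hosts()
workers_per_host = auto_workers()

print(f"slurm={in_slurm_job()} hosts={hostnames} workers_per_host={workers_per_host}")
```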
diff --git a/src/torchrunx/utils/errors.py b/src/torchrunx/utils/errors.py
new file mode 100644
index 00000000..e6bb3d24
--- /dev/null
+++ b/src/torchrunx/utils/errors.py
@@ -0,0 +1,24 @@
+"""Exception classes for agents and workers."""
+
+from dataclasses import dataclass
+
+__all__ = [
+ "AgentFailedError",
+ "WorkerFailedError",
+ "ExceptionFromWorker",
+]
+
+
+class AgentFailedError(Exception):
+ """Raised if agent fails (e.g. if signal received)."""
+
+
+class WorkerFailedError(Exception):
+ """Raised if a worker fails (e.g. if signal recieved or segmentation fault)."""
+
+
+@dataclass
+class ExceptionFromWorker:
+ """Container for exceptions raised inside workers (from user script)."""
+
+ exception: Exception
diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/utils/logging.py
similarity index 61%
rename from src/torchrunx/logging_utils.py
rename to src/torchrunx/utils/logging.py
index d12b27f7..efd013a8 100644
--- a/src/torchrunx/logging_utils.py
+++ b/src/torchrunx/utils/logging.py
@@ -1,9 +1,23 @@
+"""Utilities for intercepting logs in worker processes and handling these in the Launcher."""
+
from __future__ import annotations
+__all__ = [
+ "LogRecordSocketReceiver",
+ "redirect_stdio_to_logger",
+ "log_records_to_socket",
+ "add_filter_to_handler",
+ "file_handler",
+ "stream_handler",
+ "file_handlers",
+ "default_handlers",
+]
+
import datetime
import logging
import pickle
import struct
+import sys
from contextlib import redirect_stderr, redirect_stdout
from dataclasses import dataclass
from io import StringIO
@@ -18,11 +32,121 @@
if TYPE_CHECKING:
import os
+## Handler utilities
+
+
+def add_filter_to_handler(
+ handler: Handler,
+ hostname: str,
+ local_rank: int | None, # None indicates agent
+ log_level: int = logging.NOTSET,
+) -> None:
+ """Apply a filter to :mod:`logging.Handler` so only specific worker logs are handled.
+
+ Args:
+ handler: Handler to be modified.
+ hostname: Name of specified host.
+ local_rank: Rank of specified worker on host (or ``None`` for agent itself).
+ log_level: Minimum log level to capture.
+ """
+
+ def _filter(record: WorkerLogRecord) -> bool:
+ return (
+ record.hostname == hostname
+ and record.local_rank == local_rank
+ and record.levelno >= log_level
+ )
+
+ handler.addFilter(_filter) # pyright: ignore [reportArgumentType]
+
+
+def stream_handler(
+ hostname: str, local_rank: int | None, log_level: int = logging.NOTSET
+) -> Handler:
+ """Handler builder function for writing logs from specified hostname/rank to stdout."""
+ handler = logging.StreamHandler(stream=sys.stdout)
+ add_filter_to_handler(handler, hostname, local_rank, log_level=log_level)
+ handler.setFormatter(
+ logging.Formatter(
+ "%(asctime)s:%(levelname)s:%(hostname)s[%(local_rank)s]: %(message)s"
+ if local_rank is not None
+ else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
+ ),
+ )
+ return handler
+
+
+def file_handler(
+ hostname: str,
+ local_rank: int | None,
+ file_path: str | os.PathLike,
+ log_level: int = logging.NOTSET,
+) -> Handler:
+ """Handler builder function for writing logs from specified hostname/rank to a file."""
+ handler = logging.FileHandler(file_path)
+ add_filter_to_handler(handler, hostname, local_rank, log_level=log_level)
+ formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
+ handler.setFormatter(formatter)
+ return handler
+
+
+def file_handlers(
+ hostnames: list[str],
+ workers_per_host: list[int],
+ log_dir: str | os.PathLike = Path("torchrunx_logs"),
+ log_level: int = logging.NOTSET,
+) -> list[Handler]:
+ """Handler builder function for writing logs for all workers/agents to a directory.
+
+ Files are named with timestamp, hostname, and the local_rank (for workers).
+ """
+ handlers = []
+
+ Path(log_dir).mkdir(parents=True, exist_ok=True)
+ timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+
+ for hostname, num_workers in zip(hostnames, workers_per_host):
+ for local_rank in [None, *range(num_workers)]:
+ file_path = (
+ f"{log_dir}/{timestamp}-{hostname}"
+ + (f"[{local_rank}]" if local_rank is not None else "")
+ + ".log"
+ )
+ handlers.append(file_handler(hostname, local_rank, file_path, log_level=log_level))
+
+ return handlers
+
+
+def default_handlers(
+ hostnames: list[str],
+ workers_per_host: list[int],
+ log_dir: str | os.PathLike = Path("torchrunx_logs"),
+ log_level: int = logging.INFO,
+) -> list[Handler]:
+ """Default :mod:`logging.Handler`s for ``log_handlers="auto"`` in :mod:`torchrunx.launch`.
+
+ Logs for ``host[0]`` and its ``local_rank[0]`` worker are written to launcher process stdout.
+ Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname,
+ local_rank).
+ """
+ return [
+ stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level),
+ stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level),
+ *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level),
+ ]
+
+
## Launcher utilities
class LogRecordSocketReceiver(ThreadingTCPServer):
+ """TCP server for recieving Agent/Worker log records in Launcher.
+
+ Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process).
+ """
+
def __init__(self, host: str, port: int, handlers: list[Handler]) -> None:
+ """Processing streamed bytes as LogRecord objects."""
self.host = host
self.port = port
@@ -51,7 +175,7 @@ def handle(self) -> None:
self.daemon_threads = True
def shutdown(self) -> None:
- """override BaseServer.shutdown() with added timeout"""
+ """Override BaseServer.shutdown() with added timeout (to avoid hanging)."""
self._BaseServer__shutdown_request = True
self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue]
@@ -59,40 +183,9 @@ def shutdown(self) -> None:
## Agent/worker utilities
-@dataclass
-class WorkerLogRecord(logging.LogRecord):
- hostname: str
- worker_rank: int | None
-
- @classmethod
- def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int | None) -> Self:
- record.hostname = hostname
- record.worker_rank = worker_rank
- record.__class__ = cls
- return record # pyright: ignore [reportReturnType]
-
-
-def log_records_to_socket(
- logger: Logger,
- hostname: str,
- worker_rank: int | None,
- logger_hostname: str,
- logger_port: int,
-) -> None:
- logger.setLevel(logging.NOTSET)
-
- old_factory = logging.getLogRecordFactory()
-
- def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003
- record = old_factory(*args, **kwargs)
- return WorkerLogRecord.from_record(record, hostname, worker_rank)
-
- logging.setLogRecordFactory(record_factory)
-
- logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
-
-
def redirect_stdio_to_logger(logger: Logger) -> None:
+ """Redirect stderr/stdout: send output to logger at every flush."""
+
class _LoggingStream(StringIO):
def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
super().__init__()
@@ -100,7 +193,7 @@ def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None:
self.level = level
def flush(self) -> None:
- super().flush()
+ super().flush() # Only emit at flush, so partial writes are not logged
value = self.getvalue()
if value != "":
self.logger.log(self.level, value)
@@ -112,82 +205,37 @@ def flush(self) -> None:
redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__()
-## Handler utilities
-
-
-def add_filter_to_handler(
- handler: Handler,
- hostname: str,
- worker_rank: int | None,
- log_level: int = logging.NOTSET,
-) -> None:
- def _filter(record: WorkerLogRecord) -> bool:
- return (
- record.hostname == hostname
- and record.worker_rank == worker_rank
- and record.levelno >= log_level
- )
-
- handler.addFilter(_filter) # pyright: ignore [reportArgumentType]
+@dataclass
+class WorkerLogRecord(logging.LogRecord):
+ """Adding hostname, local_rank attributes to LogRecord. local_rank=None for Agent."""
+ hostname: str
+ local_rank: int | None
-def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler:
- handler = logging.StreamHandler()
- add_filter_to_handler(handler, hostname, rank, log_level=log_level)
- handler.setFormatter(
- logging.Formatter(
- "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s"
- if rank is not None
- else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s",
- ),
- )
- return handler
+ @classmethod
+ def from_record(cls, record: logging.LogRecord, hostname: str, local_rank: int | None) -> Self:
+ record.hostname = hostname
+ record.local_rank = local_rank
+ record.__class__ = cls
+ return record # pyright: ignore [reportReturnType]
-def file_handler(
+def log_records_to_socket(
+ logger: Logger,
hostname: str,
- worker_rank: int | None,
- file_path: str | os.PathLike,
- log_level: int = logging.NOTSET,
-) -> Handler:
- handler = logging.FileHandler(file_path)
- add_filter_to_handler(handler, hostname, worker_rank, log_level=log_level)
- formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s")
- handler.setFormatter(formatter)
- return handler
-
-
-def file_handlers(
- hostnames: list[str],
- workers_per_host: list[int],
- log_dir: str | os.PathLike = Path("torchrunx_logs"),
- log_level: int = logging.NOTSET,
-) -> list[Handler]:
- handlers = []
-
- Path(log_dir).mkdir(parents=True, exist_ok=True)
- timestamp = datetime.datetime.now().isoformat(timespec="seconds")
+ local_rank: int | None, # None indicates agent
+ logger_hostname: str,
+ logger_port: int,
+) -> None:
+ """Encode LogRecords with hostname/local_rank. Send to TCP socket on Launcher."""
+ logger.setLevel(logging.NOTSET)
- for hostname, num_workers in zip(hostnames, workers_per_host):
- for rank in [None, *range(num_workers)]:
- file_path = (
- f"{log_dir}/{timestamp}-{hostname}"
- + (f"[{rank}]" if rank is not None else "")
- + ".log"
- )
- handlers.append(file_handler(hostname, rank, file_path, log_level=log_level))
+ old_factory = logging.getLogRecordFactory()
- return handlers
+ def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003
+ record = old_factory(*args, **kwargs)
+ return WorkerLogRecord.from_record(record, hostname, local_rank)
+ logging.setLogRecordFactory(record_factory)
-def default_handlers(
- hostnames: list[str],
- workers_per_host: list[int],
- log_dir: str | os.PathLike = Path("torchrunx_logs"),
- log_level: int = logging.INFO,
-) -> list[Handler]:
- return [
- stream_handler(hostname=hostnames[0], rank=None, log_level=log_level),
- stream_handler(hostname=hostnames[0], rank=0, log_level=log_level),
- *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level),
- ]
+ logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port))
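To show how these builders compose, here is a hedged sketch of passing custom handlers in place of the ``log_handlers="auto"`` default mentioned in `default_handlers`. The `train` placeholder is illustrative, and the exact `launch` keyword arguments beyond those appearing elsewhere in this PR are assumptions.

```python
import logging

import torchrunx as trx
from torchrunx.utils.logging import file_handlers, stream_handler


def train() -> None:
    print("hello from a worker")  # placeholder user function


if __name__ == "__main__":
    hostnames = ["localhost"]
    workers_per_host = [2]

    # Like the "auto" default, but only WARNING+ from worker 0 reaches launcher stdout,
    # and per-agent/per-worker log files go to a custom directory.
    handlers = [
        stream_handler(hostname=hostnames[0], local_rank=0, log_level=logging.WARNING),
        *file_handlers(hostnames, workers_per_host, log_dir="custom_logs"),
    ]

    trx.launch(
        func=train,
        hostnames=hostnames,
        workers_per_host=workers_per_host[0],  # launch() shown with an int, as in the README
        log_handlers=handlers,
    )
```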
diff --git a/src/torchrunx/worker.py b/src/torchrunx/worker.py
new file mode 100644
index 00000000..12caf827
--- /dev/null
+++ b/src/torchrunx/worker.py
@@ -0,0 +1,116 @@
+"""Arguments and entrypoint for the worker processes."""
+
+from __future__ import annotations
+
+import datetime
+import logging
+import os
+import sys
+import traceback
+from dataclasses import dataclass
+from typing import Any, Callable, Literal
+
+import cloudpickle
+import torch
+import torch.distributed as dist
+
+from .utils.errors import ExceptionFromWorker
+from .utils.logging import log_records_to_socket, redirect_stdio_to_logger
+
+__all__ = ["WorkerArgs", "worker_entrypoint"]
+
+
+@dataclass
+class WorkerArgs:
+ """Arguments passed from agent to spawned workers."""
+
+ function: Callable
+ logger_hostname: str
+ logger_port: int
+ main_agent_hostname: str
+ main_agent_port: int
+ backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None
+ rank: int
+ local_rank: int
+ local_world_size: int
+ world_size: int
+ hostname: str
+ timeout: int
+
+ def serialize(self) -> SerializedWorkerArgs:
+ """Arguments must be serialized (to bytes) before passed to spawned workers."""
+ return SerializedWorkerArgs(worker_args=self)
+
+
+class SerializedWorkerArgs:
+ """We use cloudpickle as a serialization backend (as it supports nearly all Python types)."""
+
+ def __init__(self, worker_args: WorkerArgs) -> None:
+ self.bytes = cloudpickle.dumps(worker_args)
+
+ def deserialize(self) -> WorkerArgs:
+ return cloudpickle.loads(self.bytes)
+
+
+def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker:
+ """Function called by spawned worker processes.
+
+ Workers first prepare a process group (for communicating with all other workers).
+ They then invoke the user-provided function.
+ Logs are transmitted to the launcher process.
+ """
+ worker_args: WorkerArgs = serialized_worker_args.deserialize()
+
+ # Start logging to the logging server (i.e. the launcher)
+
+ logger = logging.getLogger()
+
+ log_records_to_socket(
+ logger=logger,
+ hostname=worker_args.hostname,
+ local_rank=worker_args.local_rank,
+ logger_hostname=worker_args.logger_hostname,
+ logger_port=worker_args.logger_port,
+ )
+
+ redirect_stdio_to_logger(logger)
+
+ # Set rank/world environment variables
+
+ os.environ["RANK"] = str(worker_args.rank)
+ os.environ["LOCAL_RANK"] = str(worker_args.local_rank)
+ os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size)
+ os.environ["WORLD_SIZE"] = str(worker_args.world_size)
+ os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname
+ os.environ["MASTER_PORT"] = str(worker_args.main_agent_port)
+
+ # Prepare the process group (e.g. for communication within the user's function)
+
+ if worker_args.backend is not None:
+ backend = worker_args.backend
+ if backend == "auto":
+ backend = "nccl" if torch.cuda.is_available() else "gloo"
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=worker_args.world_size,
+ rank=worker_args.rank,
+ store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage]
+ host_name=worker_args.main_agent_hostname,
+ port=worker_args.main_agent_port,
+ world_size=worker_args.world_size,
+ is_master=(worker_args.rank == 0),
+ ),
+ timeout=datetime.timedelta(seconds=worker_args.timeout),
+ )
+
+ # Invoke the user's function on this worker
+
+ try:
+ return worker_args.function()
+ except Exception as e:
+ traceback.print_exc()
+ return ExceptionFromWorker(exception=e)
+ finally:
+ sys.stdout.flush()
+ sys.stderr.flush()
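Since `worker_entrypoint` exports `RANK`, `LOCAL_RANK`, `WORLD_SIZE`, etc. and (unless `backend=None`) initializes the default process group, a launched function can use plain `torch.distributed` collectives directly. A minimal sketch of such a function (illustrative, not taken from the repository):

```python
import os

import torch
import torch.distributed as dist


def all_reduce_demo() -> torch.Tensor:
    # Environment variables and the process group are prepared by worker_entrypoint.
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    t = torch.ones(1) * rank
    if torch.cuda.is_available():
        t = t.to(f"cuda:{local_rank}")

    dist.all_reduce(t)  # in-place sum across all workers
    return t.cpu()
```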
diff --git a/tests/test_ci.py b/tests/test_ci.py
index 64cd1e93..79d89433 100644
--- a/tests/test_ci.py
+++ b/tests/test_ci.py
@@ -37,7 +37,7 @@ def dist_func() -> torch.Tensor:
backend="gloo", # log_dir="./test_logs"
)
- assert torch.all(r.value(0) == r.value(1))
+ assert torch.all(r.rank(0) == r.rank(1))
def test_logging() -> None:
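These test updates track the renamed `LaunchResult` accessors: `.rank(i)` returns the object from the worker with global rank `i`, and `.by_ranks()` (used in `tests/test_func.py` below) appears to return all values ordered by global rank. A hedged usage sketch:

```python
import os

import torchrunx as trx


def report_rank() -> int:
    return int(os.environ["RANK"])


if __name__ == "__main__":
    result = trx.launch(func=report_rank, hostnames=["localhost"], workers_per_host=2)
    assert result.rank(0) == 0
    assert result.by_ranks() == [0, 1]  # assuming a flat list indexed by global rank
```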
diff --git a/tests/test_func.py b/tests/test_func.py
index e8033b4e..7b3ad7f6 100644
--- a/tests/test_func.py
+++ b/tests/test_func.py
@@ -13,7 +13,7 @@ def test_launch() -> None:
workers_per_host="slurm",
)
- result_values = result.all(by="rank")
+ result_values = result.by_ranks()
t = True
for i in range(len(result_values)):