From 3bd7f1b53bb2b62361f52cebd50e1e2acd7a5006 Mon Sep 17 00:00:00 2001 From: Peter Curtin <98424367+pmcurtin@users.noreply.github.com> Date: Sat, 19 Oct 2024 16:34:15 -0400 Subject: [PATCH 01/31] Update launcher.py --- src/torchrunx/launcher.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 4203b4f3..3a2c19ed 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -351,6 +351,9 @@ def launch( class LaunchResult: + """ + A class that holds worker return values, created by :mod:``torchrunx.launch`` or :mod:``torchrunx.Launcher.run``. + """ def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None: self.hostnames: list[str] = hostnames self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses] @@ -370,9 +373,7 @@ def all(self, by: Literal["rank"]) -> list[Any]: def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: """ Get all worker return values by rank or hostname. - - :param by: Whether to aggregate all return values by hostname, or just output all of them \ - in order of rank, defaults to ``'hostname'`` + Returns a list of return values ordered by global rank, or a dictionary mapping hostnames to lists of return values ordered by local rank. """ if by == "hostname": return dict(zip(self.hostnames, self.return_values)) @@ -385,8 +386,6 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An def values(self, hostname: str) -> list[Any]: """ Get worker return values for host ``hostname``. - - :param hostname: The host to get return values from """ host_idx = self.hostnames.index(hostname) return self.return_values[host_idx] @@ -394,8 +393,6 @@ def values(self, hostname: str) -> list[Any]: def value(self, rank: int) -> Any: """ Get worker return value from global rank ``rank``. - - :param rank: Global worker rank to get return value from """ if rank < 0: msg = f"Rank {rank} must be larger than 0" From 0b9c1dff8ec208af7bded5825579e98f5152d9ab Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 14:19:07 -0400 Subject: [PATCH 02/31] moved log_handlers into .run() --- src/torchrunx/launcher.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index ddd2deaf..021e5e3b 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -36,7 +36,6 @@ class Launcher: ssh_config_file: str | os.PathLike | None = None backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto" timeout: int = 600 - log_handlers: list[Handler] | Literal["auto"] | None = "auto" default_env_vars: tuple[str, ...] = ( "PATH", "LD_LIBRARY", @@ -55,6 +54,7 @@ def run( # noqa: C901, PLR0912 func: Callable, func_args: tuple[Any] | None = None, func_kwargs: dict[str, Any] | None = None, + log_handlers: list[Handler] | Literal["auto"] | None = "auto" ) -> LaunchResult: if not dist.is_available(): msg = "The torch.distributed package is not available." 
@@ -76,7 +76,7 @@ def run( # noqa: C901, PLR0912 # start logging server log_receiver = build_logging_server( - log_handlers=self.log_handlers, + log_handlers=log_handlers, launcher_hostname=launcher_hostname, hostnames=hostnames, workers_per_host=workers_per_host, @@ -186,7 +186,6 @@ def launch( ssh_config_file: str | os.PathLike | None = None, backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None = "auto", timeout: int = 600, - log_handlers: list[Handler] | Literal["auto"] | None = "auto", default_env_vars: tuple[str, ...] = ( "PATH", "LD_LIBRARY", @@ -199,6 +198,7 @@ def launch( ), extra_env_vars: tuple[str, ...] = (), env_file: str | os.PathLike | None = None, + log_handlers: list[Handler] | Literal["auto"] | None = "auto", ) -> LaunchResult: """ Launch a distributed PyTorch function on the specified nodes. @@ -211,10 +211,10 @@ def launch( :param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``. :param backend: `Backend `_ to initialize worker process group with. Default uses NCCL (if GPUs available) or GLOO. Disabled by ``None``. :param timeout: Worker process group timeout (seconds). - :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme. :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax. :param extra_env_vars: Additional, user-specified variables to copy. :param env_file: A file (like ``.env``) with additional environment variables to copy. + :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme. :raises RuntimeError: If ``torch.distributed`` not available :raises AgentKilledError: If any agent is killed :raises Exception: Propagates exceptions raised in worker processes @@ -225,11 +225,10 @@ def launch( ssh_config_file=ssh_config_file, backend=backend, timeout=timeout, - log_handlers=log_handlers, default_env_vars=default_env_vars, extra_env_vars=extra_env_vars, env_file=env_file, - ).run(func=func, func_args=func_args, func_kwargs=func_kwargs) + ).run(func=func, func_args=func_args, func_kwargs=func_kwargs, log_handlers=log_handlers) class LaunchResult: From af8c829e46a3deff98c245d11f00288e0f5f3494 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 14:19:23 -0400 Subject: [PATCH 03/31] update contributing --- CONTRIBUTING.md | 16 +++++++++++++++- docs/source/contributing.rst | 18 ------------------ 2 files changed, 15 insertions(+), 19 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index fa436244..24ff0658 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,3 +1,17 @@ # Contributing -We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository. We use `ruff` for linting and formatting, `pyright` for static type checking, and `pytest` for testing. We build for `PyPI`. Our release pipeline is powered by Github Actions. +We use the [`pixi`](https://pixi.sh) package manager. Simply [install `pixi`](https://pixi.sh/latest/#installation) and run `pixi shell` in this repository to activate the environment. + +We use `ruff check` for linting, `ruff format` for formatting, `pyright` for static type checking, and `pytest` for testing. 
+
+We build wheels with `python -m build` and upload to [PyPI](https://pypi.org/project/torchrunx) with [twine](https://twine.readthedocs.io). Our release pipeline is powered by Github Actions.
+
+## Pull Requests
+
+Make a pull request with your changes on Github and we'll try to look at it soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in __torchrunx__.
+
+## Testing
+
+`tests/` contains `pytest`-style tests for validating that code changes do not break the core functionality of our library.
+
+At the moment, we run `pytest tests/test_ci.py` (i.e. simple single-node CPU-only tests) in our Github Actions CI pipeline (`.github/workflows/release.yml`). One can manually run our more involved tests (on GPUs, on multiple machines from SLURM) on their own hardware.
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index 707e376c..7a661a30 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -1,20 +1,2 @@
-Contributing
-============
-
 .. include:: ../../CONTRIBUTING.md
    :parser: myst_parser.sphinx_
-
-.. Development environment
-.. -----------------------
-
-.. Ensure you have the latest development environment installed. After cloning our repository, `install pixi `_ and run ``pixi shell`` in the repo's root directory. Additionally, we use `ruff `_ for linting and formatting, `pyright `_ for type checking, and ``pytest`` for testing.
-
-.. Testing
-.. -------
-
-.. ``tests/`` contains ``pytest``-style tests for validating that code changes do not break the core functionality of **torchrunx**. At the moment, we have a few simple CI tests powered by Github action, which are limited to single-agent CPU-only tests due to Github's infrastructure.
-
-.. Contributing
-.. ------------
-
-.. Make a pull request with your changes and we'll try to look at soon! If addressing a specific issue, mention it in the PR, and offer a short explanation of your fix. If adding a new feature, explain why it's meaningful and belongs in **torchrunx**.
\ No newline at end of file From 4ac384eb11db5eb4b9b3fce9fbc69dce17355406 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 14:19:35 -0400 Subject: [PATCH 04/31] add tyro, remove setuptools from extras --- pixi.lock | 115 ++++++++++++++++++++++++++---------------------------- pixi.toml | 4 +- 2 files changed, 57 insertions(+), 62 deletions(-) diff --git a/pixi.lock b/pixi.lock index 49a19663..0f1b3de1 100644 --- a/pixi.lock +++ b/pixi.lock @@ -209,7 +209,9 @@ environments: - pypi: https://files.pythonhosted.org/packages/ac/25/e715fa0bc24ac2114ed69da33adf451a38abb6f3f24ec207908112e9ba53/cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl - pypi: https://files.pythonhosted.org/packages/d5/50/83c593b07763e1161326b3b8c6686f0f4b0f24d5526546bee538c89837d6/decorator-5.1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/20/8d/778b7d51b981a96554f29136cd59ca7880bf58094338085bcf2a979a0e6a/Deprecated-1.2.14-py2.py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/ac/ac/aa3d8e0acbcd71140420bc752d7c9779cf3a2a3bb1d7ef30944e38b2cd39/eval_type_backport-0.2.0-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/02/cc/b7e31358aac6ed1ef2bb790a9746ac2c69bcb3c8588b41616914eb106eaf/exceptiongroup-1.2.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/d6/1f/e99e23ee01847147fa194e8d41cfcf2535a2dbfcb51414c541cadb15c5d7/fabric-3.2.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b9/f8/feced7779d755758a52d1f6635d990b8d98dc0a29fa568bbe0625f18fdf3/filelock-3.16.1-py3-none-any.whl @@ -267,7 +269,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/46/96/464058dd1d980014fb5aa0a1254e78799efb3096fc7a4823cd66a1621276/ruff-0.7.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/5d/80/81ba44fc82afbf5ca553913ac49460e325dc5cf00c317b34c14d43ebd76b/safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/54/24/b4293291fa1dd830f353d2cb163295742fa87f179fcc8a20a306a81978b7/SecretStorage-3.3.3-py3-none-any.whl - - pypi: https://files.pythonhosted.org/packages/31/2d/90165d51ecd38f9a02c6832198c13a4e48652485e2ccf863ebb942c531b6/setuptools-75.2.0-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/e2/d1/a1d3189e7873408b9dc396aef0d7926c198b0df2aa3ddb5b539d3e89a70f/shtab-1.7.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/dc/a4/90123871996bfb8a7148cd11d61e7a0ddc0118114c071730b3dc3a05c7bc/submitit-1.5.2-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b2/fe/81695a1aa331a842b582453b605175f419fe8540355886031328089d840a/sympy-1.13.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/63/04/1ce87a3eae86aa4e301e863d42d6907deaf61f3e9178210d9ebe653e948c/tokenizers-0.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl @@ -278,6 +280,7 @@ environments: - pypi: https://files.pythonhosted.org/packages/c4/69/57e0fed438d547524e08bfedc587078314176ad1c15c8be904d3f03149ec/triton-3.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: 
https://files.pythonhosted.org/packages/5d/ec/00f9d5fd040ae29867355e559a94e9a8429225a0284a3f5f091a3878bfc0/twine-5.1.1-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl + - pypi: https://files.pythonhosted.org/packages/29/32/daacb047883a572f452f3763e07135a45d023a068badb6d83f9b07bc9410/tyro-0.8.13-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl - pypi: https://files.pythonhosted.org/packages/b1/e7/459a8a4f40f2fa65eb73cb3f339e6d152957932516d18d0e996c7ae2d7ae/wrapt-1.16.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl - pypi: https://files.pythonhosted.org/packages/62/8b/5ba542fa83c90e09eac972fc9baca7a88e7e7ca4b221a89251954019308b/zipp-3.20.2-py3-none-any.whl @@ -535,12 +538,26 @@ packages: - bump2version<1 ; extra == 'dev' - sphinx<2 ; extra == 'dev' requires_python: '>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*' +- kind: pypi + name: docstring-parser + version: '0.16' + url: https://files.pythonhosted.org/packages/d5/7c/e9fcff7623954d86bdc17782036cbf715ecab1bec4847c008557affe1ca8/docstring_parser-0.16-py3-none-any.whl + sha256: bf0a1387354d3691d102edef7ec124f219ef639982d096e26e3b60aeffa90637 + requires_python: '>=3.6,<4.0' - kind: pypi name: docutils version: 0.21.2 url: https://files.pythonhosted.org/packages/8f/d7/9322c609343d929e75e7e5e6255e614fcc67572cfd083959cdef3b7aad79/docutils-0.21.2-py3-none-any.whl sha256: dafca5b9e384f0e419294eb4d2ff9fa826435bf15f15b7bd45723e8ad76811b2 requires_python: '>=3.9' +- kind: pypi + name: eval-type-backport + version: 0.2.0 + url: https://files.pythonhosted.org/packages/ac/ac/aa3d8e0acbcd71140420bc752d7c9779cf3a2a3bb1d7ef30944e38b2cd39/eval_type_backport-0.2.0-py3-none-any.whl + sha256: ac2f73d30d40c5a30a80b8739a789d6bb5e49fdffa66d7912667e2015d9c9933 + requires_dist: + - pytest ; extra == 'tests' + requires_python: '>=3.8' - kind: pypi name: exceptiongroup version: 1.2.2 @@ -1645,66 +1662,15 @@ packages: - jeepney>=0.6 requires_python: '>=3.6' - kind: pypi - name: setuptools - version: 75.2.0 - url: https://files.pythonhosted.org/packages/31/2d/90165d51ecd38f9a02c6832198c13a4e48652485e2ccf863ebb942c531b6/setuptools-75.2.0-py3-none-any.whl - sha256: a7fcb66f68b4d9e8e66b42f9876150a3371558f98fa32222ffaa5bced76406f8 + name: shtab + version: 1.7.1 + url: https://files.pythonhosted.org/packages/e2/d1/a1d3189e7873408b9dc396aef0d7926c198b0df2aa3ddb5b539d3e89a70f/shtab-1.7.1-py3-none-any.whl + sha256: 32d3d2ff9022d4c77a62492b6ec875527883891e33c6b479ba4d41a51e259983 requires_dist: - - pytest-checkdocs>=2.4 ; extra == 'check' - - pytest-ruff>=0.2.1 ; sys_platform != 'cygwin' and extra == 'check' - - ruff>=0.5.2 ; sys_platform != 'cygwin' and extra == 'check' - - packaging>=24 ; extra == 'core' - - more-itertools>=8.8 ; extra == 'core' - - jaraco-text>=3.7 ; extra == 'core' - - wheel>=0.43.0 ; extra == 'core' - - platformdirs>=2.6.2 ; extra == 'core' - - jaraco-collections ; extra == 'core' - - jaraco-functools ; extra == 'core' - - packaging ; extra == 'core' - - more-itertools ; extra == 'core' - - importlib-metadata>=6 ; python_version < '3.10' and extra == 'core' - - tomli>=2.0.1 ; python_version < '3.11' and extra == 'core' - - importlib-resources>=5.10.2 ; python_version < '3.9' and extra == 'core' - - pytest-cov ; extra == 'cover' - - sphinx>=3.5 ; extra == 'doc' - - 
jaraco-packaging>=9.3 ; extra == 'doc' - - rst-linker>=1.9 ; extra == 'doc' - - furo ; extra == 'doc' - - sphinx-lint ; extra == 'doc' - - jaraco-tidelift>=1.4 ; extra == 'doc' - - pygments-github-lexers==0.0.5 ; extra == 'doc' - - sphinx-favicon ; extra == 'doc' - - sphinx-inline-tabs ; extra == 'doc' - - sphinx-reredirects ; extra == 'doc' - - sphinxcontrib-towncrier ; extra == 'doc' - - sphinx-notfound-page<2,>=1 ; extra == 'doc' - - pyproject-hooks!=1.1 ; extra == 'doc' - - towncrier<24.7 ; extra == 'doc' - - pytest-enabler>=2.2 ; extra == 'enabler' - - pytest!=8.1.*,>=6 ; extra == 'test' - - virtualenv>=13.0.0 ; extra == 'test' - - wheel>=0.44.0 ; extra == 'test' - - pip>=19.1 ; extra == 'test' - - packaging>=23.2 ; extra == 'test' - - jaraco-envs>=2.2 ; extra == 'test' - - pytest-xdist>=3 ; extra == 'test' - - jaraco-path>=3.2.0 ; extra == 'test' - - build[virtualenv]>=1.0.3 ; extra == 'test' - - filelock>=3.4.0 ; extra == 'test' - - ini2toml[lite]>=0.14 ; extra == 'test' - - tomli-w>=1.0.0 ; extra == 'test' - - pytest-timeout ; extra == 'test' - - pytest-home>=0.5 ; extra == 'test' - - pytest-subprocess ; extra == 'test' - - pyproject-hooks!=1.1 ; extra == 'test' - - jaraco-test ; extra == 'test' - - jaraco-develop>=7.21 ; (python_version >= '3.9' and sys_platform != 'cygwin') and extra == 'test' - - pytest-perf ; sys_platform != 'cygwin' and extra == 'test' - - pytest-mypy ; extra == 'type' - - mypy==1.11.* ; extra == 'type' - - importlib-metadata>=7.0.2 ; python_version < '3.10' and extra == 'type' - - jaraco-develop>=7.21 ; sys_platform != 'cygwin' and extra == 'type' - requires_python: '>=3.8' + - pytest>=6 ; extra == 'dev' + - pytest-cov ; extra == 'dev' + - pytest-timeout ; extra == 'dev' + requires_python: '>=3.7' - kind: conda name: sqlite version: 3.46.0 @@ -2260,6 +2226,35 @@ packages: url: https://files.pythonhosted.org/packages/26/9f/ad63fc0248c5379346306f8668cda6e2e2e9c95e01216d2b8ffd9ff037d0/typing_extensions-4.12.2-py3-none-any.whl sha256: 04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d requires_python: '>=3.8' +- kind: pypi + name: tyro + version: 0.8.13 + url: https://files.pythonhosted.org/packages/29/32/daacb047883a572f452f3763e07135a45d023a068badb6d83f9b07bc9410/tyro-0.8.13-py3-none-any.whl + sha256: 8664cb409915c68d1b9ffc3c9e44b260387369b319a4d94399918c70d900b7de + requires_dist: + - docstring-parser>=0.16 + - typing-extensions>=4.7.0 + - rich>=11.1.0 + - shtab>=1.5.6 + - colorama>=0.4.0 ; platform_system == 'Windows' + - eval-type-backport>=0.1.3 ; python_version < '3.10' + - backports-cached-property>=1.0.2 ; python_version < '3.8' + - pyyaml>=6.0 ; extra == 'dev' + - frozendict>=2.3.4 ; extra == 'dev' + - pytest>=7.1.2 ; extra == 'dev' + - pytest-cov>=3.0.0 ; extra == 'dev' + - omegaconf>=2.2.2 ; extra == 'dev' + - attrs>=21.4.0 ; extra == 'dev' + - torch>=1.10.0 ; extra == 'dev' + - pyright!=1.1.379,>=1.1.349 ; extra == 'dev' + - ruff>=0.1.13 ; extra == 'dev' + - mypy>=1.4.1 ; extra == 'dev' + - numpy>=1.20.0 ; extra == 'dev' + - pydantic>=2.5.2 ; extra == 'dev' + - coverage[toml]>=6.5.0 ; extra == 'dev' + - eval-type-backport>=0.1.3 ; extra == 'dev' + - flax>=0.6.9 ; python_version >= '3.8' and extra == 'dev' + requires_python: '>=3.7' - kind: conda name: tzdata version: 2024b diff --git a/pixi.toml b/pixi.toml index c6c14e2f..ae7bdf8a 100644 --- a/pixi.toml +++ b/pixi.toml @@ -17,10 +17,10 @@ build = "*" twine = "*" [feature.extra.pypi-dependencies] -transformers = "*" +tyro = "*" submitit = "*" -setuptools = "*" accelerate = "*" 
+transformers = "*" [environments] default = { features = ["package", "dev"], solve-group = "default" } From cbf40b9f0c4c547a5a56e1e5e2fe0121dce1afbd Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 15:22:16 -0400 Subject: [PATCH 05/31] enabled linting for docs; clarified public/private functions --- pyproject.toml | 4 +- src/torchrunx/agent.py | 158 +++++++++++++++++---------------- src/torchrunx/environment.py | 11 ++- src/torchrunx/launcher.py | 38 ++++---- src/torchrunx/logging_utils.py | 53 ++++++----- src/torchrunx/utils.py | 11 ++- 6 files changed, 148 insertions(+), 127 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ecc572a1..b1a12e8f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,6 @@ src = ["src", "tests"] [tool.ruff.lint] select = ["ALL"] ignore = [ - "D", # documentation "ANN101", "ANN102", "ANN401", # self / cls / Any annotations "BLE001", # blind exceptions "TD", # todo syntax @@ -54,9 +53,12 @@ ignore = [ ] [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = [ + "D", "S101", # allow asserts "T201" # allow prints ] +[tool.ruff.lint.pydocstyle] +convention = "google" [tool.pyright] include = ["src", "tests"] diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 1860e444..396ec94d 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -1,5 +1,7 @@ from __future__ import annotations +__all__ = ["main"] + import datetime import logging import os @@ -25,83 +27,6 @@ ) -@dataclass -class WorkerArgs: - function: Callable - logger_hostname: str - logger_port: int - main_agent_hostname: str - main_agent_port: int - backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None - rank: int - local_rank: int - local_world_size: int - world_size: int - hostname: str - timeout: int - - def serialize(self) -> SerializedWorkerArgs: - return SerializedWorkerArgs(worker_args=self) - - -class SerializedWorkerArgs: - def __init__(self, worker_args: WorkerArgs) -> None: - self.bytes = cloudpickle.dumps(worker_args) - - def deserialize(self) -> WorkerArgs: - return cloudpickle.loads(self.bytes) - - -def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException: - worker_args: WorkerArgs = serialized_worker_args.deserialize() - - logger = logging.getLogger() - - log_records_to_socket( - logger=logger, - hostname=worker_args.hostname, - worker_rank=worker_args.local_rank, - logger_hostname=worker_args.logger_hostname, - logger_port=worker_args.logger_port, - ) - - redirect_stdio_to_logger(logger) - - os.environ["RANK"] = str(worker_args.rank) - os.environ["LOCAL_RANK"] = str(worker_args.local_rank) - os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) - os.environ["WORLD_SIZE"] = str(worker_args.world_size) - os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname - os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) - - if worker_args.backend is not None: - backend = worker_args.backend - if backend == "auto": - backend = "nccl" if torch.cuda.is_available() else "gloo" - - dist.init_process_group( - backend=backend, - world_size=worker_args.world_size, - rank=worker_args.rank, - store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] - host_name=worker_args.main_agent_hostname, - port=worker_args.main_agent_port, - world_size=worker_args.world_size, - is_master=(worker_args.rank == 0), - ), - timeout=datetime.timedelta(seconds=worker_args.timeout), - ) - - try: - return worker_args.function() - except Exception as e: - traceback.print_exc() - return 
WorkerException(exception=e) - finally: - sys.stdout.flush() - sys.stderr.flush() - - def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None: agent_rank = launcher_agent_group.rank - 1 @@ -135,7 +60,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ctx = dist_mp.start_processes( name=f"{hostname}_", - entrypoint=entrypoint, + entrypoint=_entrypoint, args={ i: ( WorkerArgs( @@ -179,3 +104,80 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ctx.close() sys.stdout.flush() sys.stderr.flush() + + +@dataclass +class WorkerArgs: + function: Callable + logger_hostname: str + logger_port: int + main_agent_hostname: str + main_agent_port: int + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None + rank: int + local_rank: int + local_world_size: int + world_size: int + hostname: str + timeout: int + + def serialize(self) -> SerializedWorkerArgs: + return SerializedWorkerArgs(worker_args=self) + + +class SerializedWorkerArgs: + def __init__(self, worker_args: WorkerArgs) -> None: + self.bytes = cloudpickle.dumps(worker_args) + + def deserialize(self) -> WorkerArgs: + return cloudpickle.loads(self.bytes) + + +def _entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerException: + worker_args: WorkerArgs = serialized_worker_args.deserialize() + + logger = logging.getLogger() + + log_records_to_socket( + logger=logger, + hostname=worker_args.hostname, + worker_rank=worker_args.local_rank, + logger_hostname=worker_args.logger_hostname, + logger_port=worker_args.logger_port, + ) + + redirect_stdio_to_logger(logger) + + os.environ["RANK"] = str(worker_args.rank) + os.environ["LOCAL_RANK"] = str(worker_args.local_rank) + os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) + os.environ["WORLD_SIZE"] = str(worker_args.world_size) + os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname + os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) + + if worker_args.backend is not None: + backend = worker_args.backend + if backend == "auto": + backend = "nccl" if torch.cuda.is_available() else "gloo" + + dist.init_process_group( + backend=backend, + world_size=worker_args.world_size, + rank=worker_args.rank, + store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] + host_name=worker_args.main_agent_hostname, + port=worker_args.main_agent_port, + world_size=worker_args.world_size, + is_master=(worker_args.rank == 0), + ), + timeout=datetime.timedelta(seconds=worker_args.timeout), + ) + + try: + return worker_args.function() + except Exception as e: + traceback.print_exc() + return WorkerException(exception=e) + finally: + sys.stdout.flush() + sys.stderr.flush() diff --git a/src/torchrunx/environment.py b/src/torchrunx/environment.py index 179cfb8d..ed27d1f6 100644 --- a/src/torchrunx/environment.py +++ b/src/torchrunx/environment.py @@ -1,5 +1,7 @@ from __future__ import annotations +__all__ = ["in_slurm_job", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"] + import os import subprocess @@ -29,8 +31,7 @@ def slurm_hosts() -> list[str]: def slurm_workers() -> int: - """ - | Determines number of workers per node in current Slurm allocation using + """| Determines number of workers per node in current Slurm allocation using | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables. 
:return: The implied number of workers per node @@ -52,8 +53,7 @@ def slurm_workers() -> int: def auto_hosts() -> list[str]: - """ - Automatically determine hostname list + """Automatically determine hostname list :return: Hostnames in Slurm allocation, or ['localhost'] :rtype: list[str] @@ -65,8 +65,7 @@ def auto_hosts() -> list[str]: def auto_workers() -> int: - """ - Automatically determine number of workers per host + """Automatically determine number of workers per host :return: Workers per host :rtype: int diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 021e5e3b..04504475 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -1,5 +1,7 @@ from __future__ import annotations +__all__ = ["AgentKilledError", "Launcher", "launch", "LaunchResult"] + import fnmatch import ipaddress import itertools @@ -54,14 +56,14 @@ def run( # noqa: C901, PLR0912 func: Callable, func_args: tuple[Any] | None = None, func_kwargs: dict[str, Any] | None = None, - log_handlers: list[Handler] | Literal["auto"] | None = "auto" + log_handlers: list[Handler] | Literal["auto"] | None = "auto", ) -> LaunchResult: if not dist.is_available(): msg = "The torch.distributed package is not available." raise RuntimeError(msg) - hostnames = resolve_hostnames(self.hostnames) - workers_per_host = resolve_workers_per_host(self.workers_per_host, len(hostnames)) + hostnames = _resolve_hostnames(self.hostnames) + workers_per_host = _resolve_workers_per_host(self.workers_per_host, len(hostnames)) launcher_hostname = socket.getfqdn() launcher_port = get_open_port() @@ -75,7 +77,7 @@ def run( # noqa: C901, PLR0912 try: # start logging server - log_receiver = build_logging_server( + log_receiver = _build_logging_server( log_handlers=log_handlers, launcher_hostname=launcher_hostname, hostnames=hostnames, @@ -94,8 +96,8 @@ def run( # noqa: C901, PLR0912 # start agents on each node for i, hostname in enumerate(hostnames): - execute_command( - command=build_launch_command( + _execute_command( + command=_build_launch_command( launcher_hostname=launcher_hostname, launcher_port=launcher_port, logger_port=log_receiver.port, @@ -168,7 +170,7 @@ def run( # noqa: C901, PLR0912 # cleanup: SIGTERM all agents if agent_payloads is not None: for agent_payload, agent_hostname in zip(agent_payloads, hostnames): - execute_command( + _execute_command( command=f"kill {agent_payload.process_id}", hostname=agent_hostname, ssh_config_file=self.ssh_config_file, @@ -200,8 +202,7 @@ def launch( env_file: str | os.PathLike | None = None, log_handlers: list[Handler] | Literal["auto"] | None = "auto", ) -> LaunchResult: - """ - Launch a distributed PyTorch function on the specified nodes. + """Launch a distributed PyTorch function on the specified nodes. :param func: :param func_args: @@ -249,8 +250,7 @@ def all(self, by: Literal["rank"]) -> list[Any]: pass def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: - """ - Get all worker return values by rank or hostname. + """Get all worker return values by rank or hostname. :param by: Whether to aggregate all return values by hostname, or just output all of them \ in order of rank, defaults to ``'hostname'`` @@ -264,8 +264,7 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An raise TypeError(msg) def values(self, hostname: str) -> list[Any]: - """ - Get worker return values for host ``hostname``. + """Get worker return values for host ``hostname``. 
:param hostname: The host to get return values from """ @@ -273,8 +272,7 @@ def values(self, hostname: str) -> list[Any]: return self.return_values[host_idx] def value(self, rank: int) -> Any: - """ - Get worker return value from global rank ``rank``. + """Get worker return value from global rank ``rank``. :param rank: Global worker rank to get return value from """ @@ -292,7 +290,7 @@ def value(self, rank: int) -> Any: raise ValueError(msg) -def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: +def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: if hostnames == "auto": return auto_hosts() if hostnames == "slurm": @@ -300,7 +298,7 @@ def resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[s return hostnames -def resolve_workers_per_host( +def _resolve_workers_per_host( workers_per_host: int | list[int] | Literal["auto", "slurm"], num_hosts: int, ) -> list[int]: @@ -318,7 +316,7 @@ def resolve_workers_per_host( return workers_per_host -def build_logging_server( +def _build_logging_server( log_handlers: list[Handler] | Literal["auto"] | None, launcher_hostname: str, hostnames: list[str], @@ -343,7 +341,7 @@ def build_logging_server( ) -def build_launch_command( +def _build_launch_command( launcher_hostname: str, launcher_port: int, logger_port: int, @@ -385,7 +383,7 @@ def build_launch_command( return " && ".join(commands) -def execute_command( +def _execute_command( command: str, hostname: str, ssh_config_file: str | os.PathLike | None = None, diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index d12b27f7..a0a0e2f4 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -1,5 +1,16 @@ from __future__ import annotations +__all__ = [ + "LogRecordSocketReceiver", + "redirect_stdio_to_logger", + "log_records_to_socket", + "add_filter_to_handler", + "file_handler", + "stream_handler", + "file_handlers", + "default_handlers", +] + import datetime import logging import pickle @@ -51,7 +62,7 @@ def handle(self) -> None: self.daemon_threads = True def shutdown(self) -> None: - """override BaseServer.shutdown() with added timeout""" + """Override BaseServer.shutdown() with added timeout""" self._BaseServer__shutdown_request = True self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] @@ -59,6 +70,26 @@ def shutdown(self) -> None: ## Agent/worker utilities +def redirect_stdio_to_logger(logger: Logger) -> None: + class _LoggingStream(StringIO): + def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: + super().__init__() + self.logger = logger + self.level = level + + def flush(self) -> None: + super().flush() + value = self.getvalue() + if value != "": + self.logger.log(self.level, value) + self.truncate(0) + self.seek(0) + + logging.captureWarnings(capture=True) + redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__() + redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__() + + @dataclass class WorkerLogRecord(logging.LogRecord): hostname: str @@ -92,26 +123,6 @@ def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003 logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port)) -def redirect_stdio_to_logger(logger: Logger) -> None: - class _LoggingStream(StringIO): - def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: - super().__init__() - self.logger = logger - self.level = level - - def flush(self) -> None: - 
super().flush() - value = self.getvalue() - if value != "": - self.logger.log(self.level, value) - self.truncate(0) - self.seek(0) - - logging.captureWarnings(capture=True) - redirect_stderr(_LoggingStream(logger, level=logging.ERROR)).__enter__() - redirect_stdout(_LoggingStream(logger, level=logging.INFO)).__enter__() - - ## Handler utilities diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 3770e93d..d0ec9451 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -1,5 +1,14 @@ from __future__ import annotations +__all__ = [ + "get_open_port", + "LauncherAgentGroup", + "LauncherPayload", + "AgentPayload", + "WorkerException", + "AgentStatus", +] + import datetime import socket from contextlib import closing @@ -49,7 +58,7 @@ def _deserialize(self, serialized: bytes) -> Any: return cloudpickle.loads(serialized) def _all_gather(self, obj: Any) -> list: - """gather object from every rank to list on every rank""" + """Gather object from every rank to list on every rank""" object_bytes = self._serialize(obj) object_list = [b""] * self.world_size # raises RuntimeError if timeout From 76aa20fd817e0504a90ed84bfd6d9ee871fc4ac5 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 15:40:33 -0400 Subject: [PATCH 06/31] docs for utils.py --- src/torchrunx/utils.py | 33 ++++++++++++++++++++++++++++----- 1 file changed, 28 insertions(+), 5 deletions(-) diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index d0ec9451..1704644f 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -1,3 +1,4 @@ +"""Common utility functions and classes.""" from __future__ import annotations __all__ = [ @@ -24,6 +25,7 @@ def get_open_port() -> int: + """Return an open port number.""" with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.bind(("", 0)) return s.getsockname()[1] @@ -31,13 +33,19 @@ def get_open_port() -> int: @dataclass class LauncherAgentGroup: + """Initializes a GLOO distributed process group between launcher and all agents.""" + launcher_hostname: str launcher_port: int world_size: int rank: int def __post_init__(self) -> None: - # timeout will raise torch.distributed.DistStoreError + """Initialize process group. + + Raises: + torch.distributed.DistStoreError: if group initialization times out. 
+ """ self.group = dist.init_process_group( backend="gloo", world_size=self.world_size, @@ -58,7 +66,7 @@ def _deserialize(self, serialized: bytes) -> Any: return cloudpickle.loads(serialized) def _all_gather(self, obj: Any) -> list: - """Gather object from every rank to list on every rank""" + """Gather object from every rank to list on every rank.""" object_bytes = self._serialize(obj) object_list = [b""] * self.world_size # raises RuntimeError if timeout @@ -69,20 +77,25 @@ def sync_payloads( self, payload: LauncherPayload | AgentPayload, ) -> tuple[LauncherPayload, list[AgentPayload]]: + """All-gather payloads across launcher and all agents.""" payloads = self._all_gather(payload) launcher_payload = payloads[0] agent_payloads = payloads[1:] return launcher_payload, agent_payloads def sync_agent_statuses(self, status: AgentStatus | None) -> list[AgentStatus]: + """All-gather agent statuses across launcher and all agents.""" return self._all_gather(status)[1:] # [0] is launcher (status=None) def shutdown(self) -> None: + """Terminate process group.""" dist.destroy_process_group(group=self.group) @dataclass class LauncherPayload: + """Payload from launcher to agents with runtime information.""" + fn: Callable hostnames: list[str] worker_global_ranks: list[list[int]] @@ -93,6 +106,8 @@ class LauncherPayload: @dataclass class AgentPayload: + """Payload corresponding to each agent.""" + hostname: str port: int process_id: int @@ -100,18 +115,26 @@ class AgentPayload: @dataclass class WorkerException: + """Wrapper for exception raised in worker process.""" + exception: Exception @dataclass class AgentStatus: + """Status of each agent (to be synchronized in LauncherAgentGroup). + + Attributes: + state: Whether the agent is running, failed, or done. + return_values: Objects returned (or exceptions raised) by workers (indexed by local rank). 
+ """ + state: Literal["running", "failed", "done"] - return_values: list[Any | WorkerException] = field( - default_factory=list - ) # indexed by local rank + return_values: list[Any | WorkerException] = field(default_factory=list) @classmethod def from_result(cls, result: RunProcsResult | None) -> Self: + """Convert RunProcsResult (from polling worker process context) to AgentStatus.""" if result is None: return cls(state="running") From de93aafa11c46a43c793f5b1e92ed61c1ebdf625 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 16:16:12 -0400 Subject: [PATCH 07/31] docs for logging_utils --- src/torchrunx/agent.py | 4 +- src/torchrunx/logging_utils.py | 79 +++++++++++++++++++++++++--------- 2 files changed, 60 insertions(+), 23 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 396ec94d..efbd12c7 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -49,7 +49,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ log_records_to_socket( logger=logger, hostname=hostname, - worker_rank=None, + local_rank=None, logger_hostname=logger_hostname, logger_port=logger_port, ) @@ -141,7 +141,7 @@ def _entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | WorkerExc log_records_to_socket( logger=logger, hostname=worker_args.hostname, - worker_rank=worker_args.local_rank, + local_rank=worker_args.local_rank, logger_hostname=worker_args.logger_hostname, logger_port=worker_args.logger_port, ) diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/logging_utils.py index a0a0e2f4..b6a4074b 100644 --- a/src/torchrunx/logging_utils.py +++ b/src/torchrunx/logging_utils.py @@ -1,3 +1,5 @@ +"""Utilities for intercepting logs in worker processes and handling these in the Launcher.""" + from __future__ import annotations __all__ = [ @@ -15,6 +17,7 @@ import logging import pickle import struct +import sys from contextlib import redirect_stderr, redirect_stdout from dataclasses import dataclass from io import StringIO @@ -33,7 +36,13 @@ class LogRecordSocketReceiver(ThreadingTCPServer): + """TCP server for recieving Agent/Worker log records in Launcher. + + Uses threading to avoid bottlenecks (i.e. "out-of-order" logs in Launcher process). + """ + def __init__(self, host: str, port: int, handlers: list[Handler]) -> None: + """Processing streamed bytes as LogRecord objects.""" self.host = host self.port = port @@ -62,7 +71,7 @@ def handle(self) -> None: self.daemon_threads = True def shutdown(self) -> None: - """Override BaseServer.shutdown() with added timeout""" + """Override BaseServer.shutdown() with added timeout (to avoid hanging).""" self._BaseServer__shutdown_request = True self._BaseServer__is_shut_down.wait(timeout=3) # pyright: ignore[reportAttributeAccessIssue] @@ -71,6 +80,8 @@ def shutdown(self) -> None: def redirect_stdio_to_logger(logger: Logger) -> None: + """Redirect stderr/stdout: send output to logger at every flush.""" + class _LoggingStream(StringIO): def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: super().__init__() @@ -78,7 +89,7 @@ def __init__(self, logger: Logger, level: int = logging.NOTSET) -> None: self.level = level def flush(self) -> None: - super().flush() + super().flush() # At "flush" to avoid logs of partial bytes value = self.getvalue() if value != "": self.logger.log(self.level, value) @@ -92,13 +103,15 @@ def flush(self) -> None: @dataclass class WorkerLogRecord(logging.LogRecord): + """Adding hostname, local_rank attributes to LogRecord. 
local_rank=None for Agent.""" + hostname: str - worker_rank: int | None + local_rank: int | None @classmethod - def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int | None) -> Self: + def from_record(cls, record: logging.LogRecord, hostname: str, local_rank: int | None) -> Self: record.hostname = hostname - record.worker_rank = worker_rank + record.local_rank = local_rank record.__class__ = cls return record # pyright: ignore [reportReturnType] @@ -106,17 +119,18 @@ def from_record(cls, record: logging.LogRecord, hostname: str, worker_rank: int def log_records_to_socket( logger: Logger, hostname: str, - worker_rank: int | None, + local_rank: int | None, # None indicates agent logger_hostname: str, logger_port: int, ) -> None: + """Encode LogRecords with hostname/local_rank. Send to TCP socket on Launcher.""" logger.setLevel(logging.NOTSET) old_factory = logging.getLogRecordFactory() def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003 record = old_factory(*args, **kwargs) - return WorkerLogRecord.from_record(record, hostname, worker_rank) + return WorkerLogRecord.from_record(record, hostname, local_rank) logging.setLogRecordFactory(record_factory) @@ -129,26 +143,38 @@ def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003 def add_filter_to_handler( handler: Handler, hostname: str, - worker_rank: int | None, + local_rank: int | None, # None indicates agent log_level: int = logging.NOTSET, ) -> None: + """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled. + + Args: + handler: ``logging.Handler`` to be modified. + hostname: Name of specified host. + local_rank: Rank of specified worker (or ``None`` for agent). + log_level: Minimum log level to capture. 
+ """ + def _filter(record: WorkerLogRecord) -> bool: return ( record.hostname == hostname - and record.worker_rank == worker_rank + and record.local_rank == local_rank and record.levelno >= log_level ) handler.addFilter(_filter) # pyright: ignore [reportArgumentType] -def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOTSET) -> Handler: - handler = logging.StreamHandler() - add_filter_to_handler(handler, hostname, rank, log_level=log_level) +def stream_handler( + hostname: str, local_rank: int | None, log_level: int = logging.NOTSET +) -> Handler: + """logging.Handler builder function for writing logs to stdout.""" + handler = logging.StreamHandler(stream=sys.stdout) + add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) handler.setFormatter( logging.Formatter( - "%(asctime)s:%(levelname)s:%(hostname)s[%(worker_rank)s]: %(message)s" - if rank is not None + "%(asctime)s:%(levelname)s:%(hostname)s[%(local_rank)s]: %(message)s" + if local_rank is not None else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", ), ) @@ -157,12 +183,13 @@ def stream_handler(hostname: str, rank: int | None, log_level: int = logging.NOT def file_handler( hostname: str, - worker_rank: int | None, + local_rank: int | None, file_path: str | os.PathLike, log_level: int = logging.NOTSET, ) -> Handler: + """logging.Handler builder function for writing logs to a file.""" handler = logging.FileHandler(file_path) - add_filter_to_handler(handler, hostname, worker_rank, log_level=log_level) + add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s") handler.setFormatter(formatter) return handler @@ -174,19 +201,23 @@ def file_handlers( log_dir: str | os.PathLike = Path("torchrunx_logs"), log_level: int = logging.NOTSET, ) -> list[Handler]: + """Builder function for writing logs for all workers/agents to a directory. + + Files are named with timestamp, hostname, and the local_rank (for workers). + """ handlers = [] Path(log_dir).mkdir(parents=True, exist_ok=True) timestamp = datetime.datetime.now().isoformat(timespec="seconds") for hostname, num_workers in zip(hostnames, workers_per_host): - for rank in [None, *range(num_workers)]: + for local_rank in [None, *range(num_workers)]: file_path = ( f"{log_dir}/{timestamp}-{hostname}" - + (f"[{rank}]" if rank is not None else "") + + (f"[{local_rank}]" if local_rank is not None else "") + ".log" ) - handlers.append(file_handler(hostname, rank, file_path, log_level=log_level)) + handlers.append(file_handler(hostname, local_rank, file_path, log_level=log_level)) return handlers @@ -197,8 +228,14 @@ def default_handlers( log_dir: str | os.PathLike = Path("torchrunx_logs"), log_level: int = logging.INFO, ) -> list[Handler]: + """A default set of logging.Handlers to be used when ``launch(log_handlers="auto")``. + + Logs for host[0] and its local_rank[0] worker are written to the launcher process stdout. + Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, + local_rank). 
+ """ return [ - stream_handler(hostname=hostnames[0], rank=None, log_level=log_level), - stream_handler(hostname=hostnames[0], rank=0, log_level=log_level), + stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level), + stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level), *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), ] From e697257e3209012661525c6661528f1f1eb1695d Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 17:55:59 -0400 Subject: [PATCH 08/31] advanced docs --- docs/source/advanced.rst | 101 ++++++++++++++++++-------------------- pixi.lock | 2 +- src/torchrunx/launcher.py | 2 +- 3 files changed, 49 insertions(+), 56 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 21efff64..3ccb1b1e 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -28,87 +28,80 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP ``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. -Environment Detection ---------------------- - -By default, the `hostnames` or `workers_per_host` :mod:`torchrunx.launch` parameters are set to "auto". These parameters are populated via `SLURM`_ if a SLURM environment is automatically detected. Otherwise, `hostnames = ["localhost"]` and `workers_per_host` is set to the number of GPUs or CPUs (in order of precedence) available locally. - -SLURM -+++++ +SLURM integration +----------------- -If the `hostnames` or `workers_per_host` parameters are set to `"slurm"`, their values will be filled from the SLURM job. Passing `"slurm"` raises a `RuntimeError` if no SLURM allocation is detected from the environment. +By default, the ``hostnames`` or ``workers_per_host`` :mod:`torchrunx.launch` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N GPUs or CPUs. +Raises a ``RuntimeError`` if ``hostnames`` or ``workers_per_host`` are intentionally set to ``"slurm"`` but no allocation is detected. -``Launcher`` class ------------------- +CLI support +----------- -We provide the ``torchrunx.Launcher`` class as an alternative to ``torchrunx.launch``. +We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.launch`. .. autoclass:: torchrunx.Launcher :members: -.. .. autofunction:: torchrunx.Launcher.run -CLI Support -+++++++++++ - -This allows **torchrunx** arguments to be more easily populated by CLI packages like `tyro `_: +We can use this class to populate arguments from the CLI (e.g. with `tyro `_): .. code:: python - import torchrunx as trx import tyro - def distributed_function(): - print("Hello world!") + def distributed_function(): print("Hello world!") if __name__ == "__main__": launcher = tyro.cli(trx.Launcher) - launcher.run(distributed_function, {}) + launcher.run(distributed_function) -For example, the `python ... --help` command will then result in: +``python ... --help`` then results in: .. 
code:: bash - - ╭─ options ─────────────────────────────────────────────────────────────────────────────────────────────────────╮ - │ -h, --help show this help message and exit │ - │ --hostnames {[STR [STR ...]]}|{auto,slurm} │ - │ (default: auto) │ - │ --workers-per-host INT|{[INT [INT ...]]}|{auto,slurm} │ - │ (default: auto) │ - │ --ssh-config-file {None}|STR|PATH │ - │ (default: None) │ - │ --backend {None,nccl,gloo,mpi,ucc,auto} │ - │ (default: auto) │ - │ --log-handlers {fixed} (fixed to: a u t o) │ - │ --env-vars STR (default: PATH LD_LIBRARY LIBRARY_PATH 'PYTHON*' 'CUDA*' 'TORCH*' 'PYTORCH*' 'NCCL*') │ - │ --env-file {None}|STR|PATH │ - │ (default: None) │ - │ --timeout INT (default: 600) │ - ╰───────────────────────────────────────────────────────────────────────────────────────────────────────────────╯ - -Custom Logging --------------- - -Logs are generated at the worker and agent level, and are specified to :mod:`torchrunx.launch` via the ``log_spec`` argument. By default, a is instantiated, causing logs at the worker and agent levels to be logged to files under ``'./logs'``, and the rank 0 worker's output streams are streamed to the launcher ``stdout``. Logs are prefixed with a timestamp by default. Agent logs have the format ``{timestamp}-{agent hostname}.log`` and workers have the format ``{timestamp}-{agent hostname}[{worker local rank}].log``. - -Custom logging classes can be subclassed from the class. Any subclass must have a ``get_map`` method returning a dictionary mapping logger names to lists of :mod:`logging.Handler` objects, in order to be passed to :mod:`torchrunx.launch`. The logger names are of the format ``{agent hostname}`` for agents and ``{agent hostname}[{worker local rank}]`` for workers. The maps all the loggers to :mod:`logging.Filehandler` object pointing to the files mentioned in the previous paragraph. It additionally maps the global rank 0 worker to a :mod:`logging.StreamHandler`, which writes logs the launcher's ``stdout`` stream. + ╭─ options ─────────────────────────────────────────────╮ + │ -h, --help show this help message and exit │ + │ --hostnames {[STR [STR ...]]}|{auto,slurm} │ + │ (default: auto) │ + │ --workers-per-host INT|{[INT [INT ...]]}|{auto,slurm} │ + │ (default: auto) │ + │ --ssh-config-file {None}|STR|PATH │ + │ (default: None) │ + │ --backend {None,nccl,gloo,mpi,ucc,auto} │ + │ (default: auto) │ + │ --timeout INT (default: 600) │ + │ --default-env-vars [STR [STR ...]] │ + │ (default: PATH LD_LIBRARY ...) │ + │ --extra-env-vars [STR [STR ...]] │ + │ (default: ) │ + │ --env-file {None}|STR|PATH │ + │ (default: None) │ + ╰───────────────────────────────────────────────────────╯ Propagating Exceptions ---------------------- -Exceptions that are raised in Workers will be raised in the Launcher process and can be caught by wrapping :mod:`torchrunx.launch` in a try-except clause. +Exceptions that are raised in Workers will be raised by :mod:`torchrunx.launch` or :mod:`torchrunx.Launcher.run`. -If a worker is killed by the operating system (e.g. due to Segmentation Fault or SIGKILL by running out of memory), the Launcher process raises a RuntimeError. +A :mod:`torchrunx.AgentKilledError` will be raised if any agent dies unexpectedly (e.g. if force-killed by the OS, due to segmentation faults or OOM). Environment Variables --------------------- -The :mod:`torchrunx.launch` ``env_vars`` argument allows the user to specify which environmental variables should be copied to the agents from the launcher environment. 
By default, it attempts to copy variables related to Python and important packages/technologies that **torchrunx** uses such as PyTorch, NCCL, CUDA, and more. Strings provided are matched with the names of environmental variables using ``fnmatch`` - standard UNIX filename pattern matching. The variables are inserted into the agent environments, and then copied to workers' environments when they are spawned. +Environment variables in the launcher process that match the :mod:`torchrunx.launch` ``default_env_vars`` argument are automatically copied to agents and workers. We set useful defaults for Python and PyTorch. Environment variables are pattern-matched with this list using ``fnmatch``. + +``default_env_vars`` can be overriden if desired. This list can be augmented using ``extra_env_vars``. Additional environment variables (and more custom bash logic) can be included via the ``env_file`` argument. Our agents ``source`` this file. + + +Custom Logging +-------------- + +We forward all logs (i.e. from ``logging`` and ``stdio``) from workers and agents to the Launcher. By default, the logs from the first agent and its first worker are printed into the Launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank. + +``logging.Handler`` objects can be provided via the :mod:`torchrunx.launch` ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams). + +We provide some utilities to help: -:mod:`torchrunx.launch` also accepts the ``env_file`` argument, which is designed to expose more advanced environmental configuration to the user. When a file is provided as this argument, the launcher will source the file on each node before executing the agent. This allows for custom bash scripts to be provided in the environmental variables, and allows for node-specific environmental variables to be set. +.. autofunction:: torchrunx.add_filter_to_handler -.. - TODO: example env_file +.. autofunction:: torchrunx.file_handler -Support for Numpy >= 2.0 ------------------------- -only supported if `torch>=2.3` +.. autofunction:: torchrunx.stream_handler diff --git a/pixi.lock b/pixi.lock index 0f1b3de1..14d646e1 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1793,7 +1793,7 @@ packages: name: torchrunx version: 0.2.0 path: . - sha256: 6e56e7865dd6c8758124a48d6eb086255880194f81b8c81f385bcf00429164d8 + sha256: 9d23ebc62f9f7c16307df31a08e09204d6de3aa4879029364f930e3f9c62a725 requires_dist: - cloudpickle>=3.0 - fabric>=3.2 diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 3965935d..60341914 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -216,7 +216,7 @@ def launch( :param extra_env_vars: Additional, user-specified variables to copy. :param env_file: A file (like ``.env``) with additional environment variables to copy. :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme. - :raises RuntimeError: If ``torch.distributed`` not available + :raises RuntimeError: If ``torch.distributed`` not available or "slurm" specified but no allocation is detected. 
:raises AgentKilledError: If any agent is killed :raises Exception: Propagates exceptions raised in worker processes """ # noqa: E501 From 748c2b75e7e2c642a28fe2b338e47b121e5bede8 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 21:36:20 -0400 Subject: [PATCH 09/31] adding napoleon for google docs --- docs/source/advanced.rst | 10 +++++---- docs/source/conf.py | 45 ++++++++++++++++++++------------------- src/torchrunx/launcher.py | 1 + src/torchrunx/utils.py | 1 + 4 files changed, 31 insertions(+), 26 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 3ccb1b1e..87ae72cd 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -25,13 +25,14 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP print(f'Accuracy: {accuracy}') + ``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. SLURM integration ----------------- -By default, the ``hostnames`` or ``workers_per_host`` :mod:`torchrunx.launch` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N GPUs or CPUs. +By default, the ``hostnames`` or ``workers_per_host`` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N GPUs or CPUs. Raises a ``RuntimeError`` if ``hostnames`` or ``workers_per_host`` are intentionally set to ``"slurm"`` but no allocation is detected. CLI support @@ -57,6 +58,7 @@ We can use this class to populate arguments from the CLI (e.g. with `tyro None: self.hostnames: list[str] = hostnames self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses] diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils.py index 1704644f..c669fcfb 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils.py @@ -1,4 +1,5 @@ """Common utility functions and classes.""" + from __future__ import annotations __all__ = [ From 24f4a981133882ab5f9957ca5de08f4cbd95617c Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 23:23:01 -0400 Subject: [PATCH 10/31] linkcode --- docs/source/advanced.rst | 4 +++- docs/source/conf.py | 43 +++++++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 87ae72cd..010e50f2 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -46,10 +46,12 @@ We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.la We can use this class to populate arguments from the CLI (e.g. with `tyro `_): .. 
code:: python + import torchrunx as trx import tyro - def distributed_function(): print("Hello world!") + def distributed_function(): + pass if __name__ == "__main__": launcher = tyro.cli(trx.Launcher) diff --git a/docs/source/conf.py b/docs/source/conf.py index 55171c95..22f7fcfc 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,16 +13,14 @@ extensions = [ "sphinx.ext.duration", - "sphinx.ext.doctest", "sphinx.ext.autodoc", - "sphinx.ext.autosummary", "sphinx.ext.intersphinx", "myst_parser", "sphinx_toolbox.sidebar_links", "sphinx_toolbox.github", "sphinx.ext.autodoc.typehints", "sphinx.ext.napoleon", - #"sphinx_autodoc_typehints", + "sphinx.ext.linkcode", ] autodoc_typehints = "both" @@ -50,3 +48,42 @@ # code block syntax highlighting #pygments_style = "sphinx" + +code_url = f"https://github.com/{github_username}/{github_repository}/blob/{commit}" + +import importlib +import inspect + +def linkcode_resolve(domain, info): + # Non-linkable objects from the starter kit in the tutorial. + if domain == "js" or info["module"] == "connect4": + return + + assert domain == "py", "expected only Python objects" + + mod = importlib.import_module(info["module"]) + if "." in info["fullname"]: + objname, attrname = info["fullname"].split(".") + obj = getattr(mod, objname) + try: + # object is a method of a class + obj = getattr(obj, attrname) + except AttributeError: + # object is an attribute of a class + return None + else: + obj = getattr(mod, info["fullname"]) + + try: + file = inspect.getsourcefile(obj) + lines = inspect.getsourcelines(obj) + except TypeError: + # e.g. object is a typing.Union + return None + file = os.path.relpath(file, os.path.abspath("..")) + if not file.startswith("src/websockets"): + # e.g. object is a typing.NewType + return None + start, end = lines[1], lines[1] + len(lines[0]) - 1 + + return f"{code_url}/{file}#L{start}-L{end}" From cb6620c5406bc96770d17a113b903f36cf58caf0 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 23:35:53 -0400 Subject: [PATCH 11/31] update linkcode --- docs/source/conf.py | 106 ++++++++++++++++++++++++-------------------- pyproject.toml | 1 + 2 files changed, 60 insertions(+), 47 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 22f7fcfc..449f1749 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -5,11 +5,10 @@ # Configuration file for the Sphinx documentation builder. 
-# -- Project information - project = "torchrunx" - -# -- General configuration +github_username = "apoorvkh" +github_repository = "torchrunx" +html_theme = "furo" extensions = [ "sphinx.ext.duration", @@ -23,67 +22,80 @@ "sphinx.ext.linkcode", ] -autodoc_typehints = "both" -#typehints_defaults = "comma" - -github_username = "apoorvkh" -github_repository = "torchrunx" +autodoc_typehints = "description" autodoc_mock_imports = ["torch", "fabric", "cloudpickle", "typing_extensions"] intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), - "sphinx": ("https://www.sphinx-doc.org/en/master/", None), } intersphinx_disabled_domains = ["std"] templates_path = ["_templates"] -# -- Options for HTML output -html_theme = "furo" +## Link code to Github source +# From: https://github.com/scikit-learn/scikit-learn/blob/main/doc/sphinxext/github_link.py -# -- Options for EPUB output -epub_show_urls = "footnote" +import inspect +import os +import subprocess +import sys +from functools import partial +from operator import attrgetter -# code block syntax highlighting -#pygments_style = "sphinx" -code_url = f"https://github.com/{github_username}/{github_repository}/blob/{commit}" +def _linkcode_resolve(domain, info, package, url_fmt, revision): + if revision is None: + return + if domain not in ("py", "pyx"): + return + if not info.get("module") or not info.get("fullname"): + return -import importlib -import inspect + class_name = info["fullname"].split(".")[0] + module = __import__(info["module"], fromlist=[class_name]) + obj = attrgetter(info["fullname"])(module) -def linkcode_resolve(domain, info): - # Non-linkable objects from the starter kit in the tutorial. - if domain == "js" or info["module"] == "connect4": + # Unwrap the object to get the correct source + # file in case that is wrapped by a decorator + obj = inspect.unwrap(obj) + + try: + fn = inspect.getsourcefile(obj) + except Exception: + fn = None + if not fn: + try: + fn = inspect.getsourcefile(sys.modules[obj.__module__]) + except Exception: + fn = None + if not fn: return - assert domain == "py", "expected only Python objects" + fn = os.path.relpath(fn, start=os.path.dirname(__import__(package).__file__)) + try: + lineno = inspect.getsourcelines(obj)[1] + except Exception: + lineno = "" + return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno) - mod = importlib.import_module(info["module"]) - if "." in info["fullname"]: - objname, attrname = info["fullname"].split(".") - obj = getattr(mod, objname) - try: - # object is a method of a class - obj = getattr(obj, attrname) - except AttributeError: - # object is an attribute of a class - return None - else: - obj = getattr(mod, info["fullname"]) +def make_linkcode_resolve(package, url_fmt): try: - file = inspect.getsourcefile(obj) - lines = inspect.getsourcelines(obj) - except TypeError: - # e.g. object is a typing.Union - return None - file = os.path.relpath(file, os.path.abspath("..")) - if not file.startswith("src/websockets"): - # e.g. 
object is a typing.NewType - return None - start, end = lines[1], lines[1] + len(lines[0]) - 1 - - return f"{code_url}/{file}#L{start}-L{end}" + revision = ( + subprocess.check_output("git rev-parse --short HEAD".split()).strip().decode("utf-8") + ) + except (subprocess.CalledProcessError, OSError): + print("Failed to execute git to get revision") + revision = None + return partial(_linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt) + + +linkcode_resolve = make_linkcode_resolve( + "torchrunx", + ( + f"https://github.com/{github_username}/{github_repository}/" + "blob/{revision}/{package}/{path}#L{lineno}" + ), +) diff --git a/pyproject.toml b/pyproject.toml index b1a12e8f..3fb9d1ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ Documentation = "https://torchrunx.readthedocs.io" [tool.ruff] include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] +exclude = ["docs"] line-length = 100 src = ["src", "tests"] [tool.ruff.lint] From 3eb297cccefd6a09f6c9ffad1b52b2eba88a9f62 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 23:39:23 -0400 Subject: [PATCH 12/31] try again --- docs/source/conf.py | 36 ++++++++++++++---------------------- 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index 449f1749..45146d93 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -41,11 +41,23 @@ import os import subprocess import sys -from functools import partial from operator import attrgetter -def _linkcode_resolve(domain, info, package, url_fmt, revision): +try: + revision = ( + subprocess.check_output("git rev-parse --short HEAD".split()).strip().decode("utf-8") + ) +except (subprocess.CalledProcessError, OSError): + print("Failed to execute git to get revision") + revision = None + +url_fmt = ( + f"https://github.com/{github_username}/{github_repository}/" + "blob/{revision}/{package}/{path}#L{lineno}" +) + +def linkcode_resolve(domain, info): if revision is None: return if domain not in ("py", "pyx"): @@ -79,23 +91,3 @@ def _linkcode_resolve(domain, info, package, url_fmt, revision): except Exception: lineno = "" return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno) - - -def make_linkcode_resolve(package, url_fmt): - try: - revision = ( - subprocess.check_output("git rev-parse --short HEAD".split()).strip().decode("utf-8") - ) - except (subprocess.CalledProcessError, OSError): - print("Failed to execute git to get revision") - revision = None - return partial(_linkcode_resolve, revision=revision, package=package, url_fmt=url_fmt) - - -linkcode_resolve = make_linkcode_resolve( - "torchrunx", - ( - f"https://github.com/{github_username}/{github_repository}/" - "blob/{revision}/{package}/{path}#L{lineno}" - ), -) From e609f54868a67cd7a6d893def348ff90ca632d3e Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 23:41:03 -0400 Subject: [PATCH 13/31] fix? 
--- docs/source/conf.py | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 45146d93..c364267d 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -43,6 +43,7 @@ import sys from operator import attrgetter +package = project try: revision = ( From e88e32012ed68717dd3dfea079c53c1ec6ad391b Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Sun, 20 Oct 2024 23:43:51 -0400 Subject: [PATCH 14/31] now linkcode works --- docs/source/conf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index c364267d..cfb2ef39 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -55,7 +55,7 @@ url_fmt = ( f"https://github.com/{github_username}/{github_repository}/" - "blob/{revision}/{package}/{path}#L{lineno}" + "blob/{revision}/src/{package}/{path}#L{lineno}" ) def linkcode_resolve(domain, info): From bef8b281b128cffd1d6a1641e7fcd9183ce697b0 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Mon, 21 Oct 2024 00:17:03 -0400 Subject: [PATCH 15/31] updates --- docs/source/advanced.rst | 6 +++--- docs/source/conf.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 010e50f2..1a82818b 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -26,13 +26,13 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP print(f'Accuracy: {accuracy}') -``trx.launch()`` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. +:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. SLURM integration ----------------- -By default, the ``hostnames`` or ``workers_per_host`` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N GPUs or CPUs. +By default, the ``hostnames`` or ``workers_per_host`` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N workers (num. GPUs or CPUs). Raises a ``RuntimeError`` if ``hostnames`` or ``workers_per_host`` are intentionally set to ``"slurm"`` but no allocation is detected. CLI support @@ -41,7 +41,7 @@ CLI support We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.launch`. .. autoclass:: torchrunx.Launcher - :members: + :members: run We can use this class to populate arguments from the CLI (e.g. 
with `tyro `_): diff --git a/docs/source/conf.py b/docs/source/conf.py index cfb2ef39..cabeb79c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -22,17 +22,15 @@ "sphinx.ext.linkcode", ] -autodoc_typehints = "description" - autodoc_mock_imports = ["torch", "fabric", "cloudpickle", "typing_extensions"] +autodoc_typehints = "both" +autodoc_typehints_description_target = "documented_params" intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), } intersphinx_disabled_domains = ["std"] -templates_path = ["_templates"] - ## Link code to Github source # From: https://github.com/scikit-learn/scikit-learn/blob/main/doc/sphinxext/github_link.py @@ -92,3 +90,5 @@ def linkcode_resolve(domain, info): except Exception: lineno = "" return url_fmt.format(revision=revision, package=package, path=fn, lineno=lineno) + +## End of "link code to Github source" From 86bb67b0e1674bf40331000ef0ecce049cb565c6 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Mon, 21 Oct 2024 00:46:34 -0400 Subject: [PATCH 16/31] automethod run for launcher --- docs/source/advanced.rst | 31 +++++++++++++++++-------------- 1 file changed, 17 insertions(+), 14 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 1a82818b..d4f26ce4 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -28,22 +28,19 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP :mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. - -SLURM integration ------------------ - -By default, the ``hostnames`` or ``workers_per_host`` arguments are populated from the current SLURM allocation. If no allocation is detected, we assume 1 machine (``localhost``) with N workers (num. GPUs or CPUs). -Raises a ``RuntimeError`` if ``hostnames`` or ``workers_per_host`` are intentionally set to ``"slurm"`` but no allocation is detected. - -CLI support ------------ +Launcher class +-------------- We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.launch`. .. autoclass:: torchrunx.Launcher - :members: run + :single-line-parameter-list: + .. automethod:: run -We can use this class to populate arguments from the CLI (e.g. with `tyro `_): +CLI integration +^^^^^^^^^^^^^^^ + +We can use :mod:`torchrunx.Launcher` to populate arguments from the CLI (e.g. with `tyro `_): .. code:: python @@ -80,14 +77,20 @@ We can use this class to populate arguments from the CLI (e.g. with `tyro Date: Mon, 21 Oct 2024 00:52:10 -0400 Subject: [PATCH 17/31] maximum_signature_line_length --- docs/source/advanced.rst | 1 - docs/source/conf.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index d4f26ce4..01c93a6e 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -34,7 +34,6 @@ Launcher class We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.launch`. .. autoclass:: torchrunx.Launcher - :single-line-parameter-list: .. 
automethod:: run CLI integration diff --git a/docs/source/conf.py b/docs/source/conf.py index cabeb79c..966816e9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,6 +26,8 @@ autodoc_typehints = "both" autodoc_typehints_description_target = "documented_params" +maximum_signature_line_length = 100 + intersphinx_mapping = { "python": ("https://docs.python.org/3/", None), } From 9950e96501e33defc60fe9877fdb7bc101e57265 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Mon, 21 Oct 2024 00:58:30 -0400 Subject: [PATCH 18/31] switch to members? --- docs/source/advanced.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 01c93a6e..49c66382 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -34,7 +34,7 @@ Launcher class We provide the :mod:`torchrunx.Launcher` class as an alias to :mod:`torchrunx.launch`. .. autoclass:: torchrunx.Launcher - .. automethod:: run + :members: CLI integration ^^^^^^^^^^^^^^^ From f335140d518ed8e441825811fdd4fbc66deab24f Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Mon, 28 Oct 2024 23:36:23 -0400 Subject: [PATCH 19/31] created utils/ --- src/torchrunx/__init__.py | 4 +-- src/torchrunx/__main__.py | 2 +- src/torchrunx/agent.py | 6 ++-- src/torchrunx/launcher.py | 29 ++++++++----------- src/torchrunx/utils/__init__.py | 0 src/torchrunx/{utils.py => utils/comm.py} | 16 ++-------- src/torchrunx/{ => utils}/environment.py | 0 src/torchrunx/utils/errors.py | 20 +++++++++++++ .../{logging_utils.py => utils/logging.py} | 0 9 files changed, 41 insertions(+), 36 deletions(-) create mode 100644 src/torchrunx/utils/__init__.py rename src/torchrunx/{utils.py => utils/comm.py} (96%) rename src/torchrunx/{ => utils}/environment.py (100%) create mode 100644 src/torchrunx/utils/errors.py rename src/torchrunx/{logging_utils.py => utils/logging.py} (100%) diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index 177c444e..39783314 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,6 +1,6 @@ from .launcher import Launcher, LaunchResult, launch -from .logging_utils import add_filter_to_handler, file_handler, stream_handler -from .utils import AgentFailedError, WorkerFailedError +from .utils.errors import AgentFailedError, WorkerFailedError +from .utils.logging import add_filter_to_handler, file_handler, stream_handler __all__ = [ "AgentFailedError", diff --git a/src/torchrunx/__main__.py b/src/torchrunx/__main__.py index c3b3099e..facdf7fe 100644 --- a/src/torchrunx/__main__.py +++ b/src/torchrunx/__main__.py @@ -1,7 +1,7 @@ from argparse import ArgumentParser from .agent import main -from .utils import LauncherAgentGroup +from .utils.comm import LauncherAgentGroup if __name__ == "__main__": parser = ArgumentParser() diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index f9b1576c..054dec52 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -17,14 +17,14 @@ import torch.distributed as dist import torch.distributed.elastic.multiprocessing as dist_mp -from .logging_utils import log_records_to_socket, redirect_stdio_to_logger -from .utils import ( +from .utils.comm import ( AgentPayload, AgentStatus, - ExceptionFromWorker, LauncherAgentGroup, get_open_port, ) +from .utils.errors import ExceptionFromWorker +from .utils.logging import log_records_to_socket, redirect_stdio_to_logger @dataclass diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index b6bfa11a..3fe0259c 100644 --- a/src/torchrunx/launcher.py 
+++ b/src/torchrunx/launcher.py @@ -1,6 +1,6 @@ from __future__ import annotations -__all__ = ["AgentKilledError", "Launcher", "launch", "LaunchResult"] +__all__ = ["Launcher", "launch", "LaunchResult"] import fnmatch import ipaddress @@ -22,16 +22,18 @@ import fabric import torch.distributed as dist -from .environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers -from .logging_utils import LogRecordSocketReceiver, default_handlers -from .utils import ( +from .utils.comm import ( AgentStatus, - ExceptionFromWorker, LauncherAgentGroup, LauncherPayload, - WorkerFailedError, get_open_port, ) +from .utils.environment import auto_hosts, auto_workers, slurm_hosts, slurm_workers +from .utils.errors import ( + ExceptionFromWorker, + WorkerFailedError, +) +from .utils.logging import LogRecordSocketReceiver, default_handlers @dataclass @@ -235,9 +237,7 @@ def launch( class LaunchResult: - """ - A class that holds worker return values, created by :mod:``torchrunx.launch`` or :mod:``torchrunx.Launcher.run``. - """ + """A class that holds worker return values, created by :mod:``torchrunx.launch`` or :mod:``torchrunx.Launcher.run``.""" def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None: self.hostnames: list[str] = hostnames @@ -256,8 +256,7 @@ def all(self, by: Literal["rank"]) -> list[Any]: pass def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: - """ - Get all worker return values by rank or hostname. + """Get all worker return values by rank or hostname. Returns a list of return values ordered by global rank, or a dictionary mapping hostnames to lists of return values ordered by local rank. """ if by == "hostname": @@ -269,16 +268,12 @@ def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[An raise TypeError(msg) def values(self, hostname: str) -> list[Any]: - """ - Get worker return values for host ``hostname``. - """ + """Get worker return values for host ``hostname``.""" host_idx = self.hostnames.index(hostname) return self.return_values[host_idx] def value(self, rank: int) -> Any: - """ - Get worker return value from global rank ``rank``. 
- """ + """Get worker return value from global rank ``rank``.""" if rank < 0: msg = f"Rank {rank} must be larger than 0" raise ValueError(msg) diff --git a/src/torchrunx/utils/__init__.py b/src/torchrunx/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/torchrunx/utils.py b/src/torchrunx/utils/comm.py similarity index 96% rename from src/torchrunx/utils.py rename to src/torchrunx/utils/comm.py index 02357a64..c1794e72 100644 --- a/src/torchrunx/utils.py +++ b/src/torchrunx/utils/comm.py @@ -1,4 +1,4 @@ -"""Common utility functions and classes.""" +"""Utilities for Launcher-Agent communication.""" from __future__ import annotations @@ -21,6 +21,8 @@ import torch.distributed as dist from typing_extensions import Self +from .errors import AgentFailedError, ExceptionFromWorker, WorkerFailedError + if TYPE_CHECKING: from torch.distributed.elastic.multiprocessing.api import RunProcsResult @@ -32,14 +34,6 @@ def get_open_port() -> int: return s.getsockname()[1] -class AgentFailedError(Exception): - pass - - -class WorkerFailedError(Exception): - pass - - @dataclass class LauncherAgentGroup: """Initializes a GLOO distributed process group between launcher and all agents.""" @@ -126,10 +120,6 @@ class AgentPayload: process_id: int -class ExceptionFromWorker: - exception: Exception - - @dataclass class AgentStatus: """Status of each agent (to be synchronized in LauncherAgentGroup). diff --git a/src/torchrunx/environment.py b/src/torchrunx/utils/environment.py similarity index 100% rename from src/torchrunx/environment.py rename to src/torchrunx/utils/environment.py diff --git a/src/torchrunx/utils/errors.py b/src/torchrunx/utils/errors.py new file mode 100644 index 00000000..c49f1726 --- /dev/null +++ b/src/torchrunx/utils/errors.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +__all__ = [ + "AgentFailedError", + "WorkerFailedError", + "ExceptionFromWorker", +] + + +class AgentFailedError(Exception): + pass + + +class WorkerFailedError(Exception): + pass + + +@dataclass +class ExceptionFromWorker: + exception: Exception diff --git a/src/torchrunx/logging_utils.py b/src/torchrunx/utils/logging.py similarity index 100% rename from src/torchrunx/logging_utils.py rename to src/torchrunx/utils/logging.py From 0b5e31607019eac1252e959ec512047fc9b16b94 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Mon, 28 Oct 2024 23:49:08 -0400 Subject: [PATCH 20/31] moved functions to worker.py --- src/torchrunx/agent.py | 87 +------------------------------------- src/torchrunx/worker.py | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 85 deletions(-) create mode 100644 src/torchrunx/worker.py diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 054dec52..3c571ea2 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -2,19 +2,13 @@ __all__ = ["main"] -import datetime import logging import os import socket import sys import tempfile -import traceback -from dataclasses import dataclass -from typing import Any, Callable, Literal -import cloudpickle import torch -import torch.distributed as dist import torch.distributed.elastic.multiprocessing as dist_mp from .utils.comm import ( @@ -23,85 +17,8 @@ LauncherAgentGroup, get_open_port, ) -from .utils.errors import ExceptionFromWorker from .utils.logging import log_records_to_socket, redirect_stdio_to_logger - - -@dataclass -class WorkerArgs: - function: Callable - logger_hostname: str - logger_port: int - main_agent_hostname: str - main_agent_port: int - backend: Literal["nccl", 
"gloo", "mpi", "ucc", "auto"] | None - rank: int - local_rank: int - local_world_size: int - world_size: int - hostname: str - timeout: int - - def serialize(self) -> SerializedWorkerArgs: - return SerializedWorkerArgs(worker_args=self) - - -class SerializedWorkerArgs: - def __init__(self, worker_args: WorkerArgs) -> None: - self.bytes = cloudpickle.dumps(worker_args) - - def deserialize(self) -> WorkerArgs: - return cloudpickle.loads(self.bytes) - - -def _entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: - worker_args: WorkerArgs = serialized_worker_args.deserialize() - - logger = logging.getLogger() - - log_records_to_socket( - logger=logger, - hostname=worker_args.hostname, - local_rank=worker_args.local_rank, - logger_hostname=worker_args.logger_hostname, - logger_port=worker_args.logger_port, - ) - - redirect_stdio_to_logger(logger) - - os.environ["RANK"] = str(worker_args.rank) - os.environ["LOCAL_RANK"] = str(worker_args.local_rank) - os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) - os.environ["WORLD_SIZE"] = str(worker_args.world_size) - os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname - os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) - - if worker_args.backend is not None: - backend = worker_args.backend - if backend == "auto": - backend = "nccl" if torch.cuda.is_available() else "gloo" - - dist.init_process_group( - backend=backend, - world_size=worker_args.world_size, - rank=worker_args.rank, - store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] - host_name=worker_args.main_agent_hostname, - port=worker_args.main_agent_port, - world_size=worker_args.world_size, - is_master=(worker_args.rank == 0), - ), - timeout=datetime.timedelta(seconds=worker_args.timeout), - ) - - try: - return worker_args.function() - except Exception as e: - traceback.print_exc() - return ExceptionFromWorker(exception=e) - finally: - sys.stdout.flush() - sys.stderr.flush() +from .worker import WorkerArgs, entrypoint def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None: @@ -137,7 +54,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ctx = dist_mp.start_processes( name=f"{hostname}_", - entrypoint=_entrypoint, + entrypoint=entrypoint, args={ i: ( WorkerArgs( diff --git a/src/torchrunx/worker.py b/src/torchrunx/worker.py new file mode 100644 index 00000000..53e27a90 --- /dev/null +++ b/src/torchrunx/worker.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +import datetime +import logging +import os +import sys +import traceback +from dataclasses import dataclass +from typing import Any, Callable, Literal + +import cloudpickle +import torch +import torch.distributed as dist + +from .utils.errors import ExceptionFromWorker +from .utils.logging import log_records_to_socket, redirect_stdio_to_logger + +__all__ = ["WorkerArgs", "entrypoint"] + +@dataclass +class WorkerArgs: + function: Callable + logger_hostname: str + logger_port: int + main_agent_hostname: str + main_agent_port: int + backend: Literal["nccl", "gloo", "mpi", "ucc", "auto"] | None + rank: int + local_rank: int + local_world_size: int + world_size: int + hostname: str + timeout: int + + def serialize(self) -> SerializedWorkerArgs: + return SerializedWorkerArgs(worker_args=self) + + +class SerializedWorkerArgs: + def __init__(self, worker_args: WorkerArgs) -> None: + self.bytes = cloudpickle.dumps(worker_args) + + def deserialize(self) -> WorkerArgs: + return 
cloudpickle.loads(self.bytes) + + +def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: + worker_args: WorkerArgs = serialized_worker_args.deserialize() + + logger = logging.getLogger() + + log_records_to_socket( + logger=logger, + hostname=worker_args.hostname, + local_rank=worker_args.local_rank, + logger_hostname=worker_args.logger_hostname, + logger_port=worker_args.logger_port, + ) + + redirect_stdio_to_logger(logger) + + os.environ["RANK"] = str(worker_args.rank) + os.environ["LOCAL_RANK"] = str(worker_args.local_rank) + os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) + os.environ["WORLD_SIZE"] = str(worker_args.world_size) + os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname + os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) + + if worker_args.backend is not None: + backend = worker_args.backend + if backend == "auto": + backend = "nccl" if torch.cuda.is_available() else "gloo" + + dist.init_process_group( + backend=backend, + world_size=worker_args.world_size, + rank=worker_args.rank, + store=dist.TCPStore( # pyright: ignore [reportPrivateImportUsage] + host_name=worker_args.main_agent_hostname, + port=worker_args.main_agent_port, + world_size=worker_args.world_size, + is_master=(worker_args.rank == 0), + ), + timeout=datetime.timedelta(seconds=worker_args.timeout), + ) + + try: + return worker_args.function() + except Exception as e: + traceback.print_exc() + return ExceptionFromWorker(exception=e) + finally: + sys.stdout.flush() + sys.stderr.flush() From 084061f9d8000ca2c52b7353f8deffe4ea8bf311 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 11:20:40 -0400 Subject: [PATCH 21/31] renamed to worker_entrypoint --- src/torchrunx/agent.py | 4 ++-- src/torchrunx/worker.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 3c571ea2..90ff193e 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -18,7 +18,7 @@ get_open_port, ) from .utils.logging import log_records_to_socket, redirect_stdio_to_logger -from .worker import WorkerArgs, entrypoint +from .worker import WorkerArgs, worker_entrypoint def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None: @@ -54,7 +54,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ctx = dist_mp.start_processes( name=f"{hostname}_", - entrypoint=entrypoint, + entrypoint=worker_entrypoint, args={ i: ( WorkerArgs( diff --git a/src/torchrunx/worker.py b/src/torchrunx/worker.py index 53e27a90..ca4df74b 100644 --- a/src/torchrunx/worker.py +++ b/src/torchrunx/worker.py @@ -15,7 +15,8 @@ from .utils.errors import ExceptionFromWorker from .utils.logging import log_records_to_socket, redirect_stdio_to_logger -__all__ = ["WorkerArgs", "entrypoint"] +__all__ = ["WorkerArgs", "worker_entrypoint"] + @dataclass class WorkerArgs: @@ -44,7 +45,7 @@ def deserialize(self) -> WorkerArgs: return cloudpickle.loads(self.bytes) -def entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: +def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: worker_args: WorkerArgs = serialized_worker_args.deserialize() logger = logging.getLogger() From 6cc931160f604b41ab7cfb42f929ac7896251acd Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 11:40:20 -0400 Subject: [PATCH 22/31] completed docs for utils --- src/torchrunx/__init__.py | 2 + 
src/torchrunx/__main__.py | 2 + src/torchrunx/utils/__init__.py | 1 + src/torchrunx/utils/environment.py | 23 +--- src/torchrunx/utils/errors.py | 8 +- src/torchrunx/utils/logging.py | 208 ++++++++++++++--------------- 6 files changed, 121 insertions(+), 123 deletions(-) diff --git a/src/torchrunx/__init__.py b/src/torchrunx/__init__.py index 39783314..405a7715 100644 --- a/src/torchrunx/__init__.py +++ b/src/torchrunx/__init__.py @@ -1,3 +1,5 @@ +"""API for our torchrunx library.""" + from .launcher import Launcher, LaunchResult, launch from .utils.errors import AgentFailedError, WorkerFailedError from .utils.logging import add_filter_to_handler, file_handler, stream_handler diff --git a/src/torchrunx/__main__.py b/src/torchrunx/__main__.py index facdf7fe..8626c1f8 100644 --- a/src/torchrunx/__main__.py +++ b/src/torchrunx/__main__.py @@ -1,3 +1,5 @@ +"""CLI entrypoint used for starting agents on different nodes.""" + from argparse import ArgumentParser from .agent import main diff --git a/src/torchrunx/utils/__init__.py b/src/torchrunx/utils/__init__.py index e69de29b..b2aa5714 100644 --- a/src/torchrunx/utils/__init__.py +++ b/src/torchrunx/utils/__init__.py @@ -0,0 +1 @@ +"""Utility classes and functions.""" diff --git a/src/torchrunx/utils/environment.py b/src/torchrunx/utils/environment.py index ed27d1f6..9e3c0be9 100644 --- a/src/torchrunx/utils/environment.py +++ b/src/torchrunx/utils/environment.py @@ -1,3 +1,5 @@ +"""Utilities for determining hosts and workers in environment.""" + from __future__ import annotations __all__ = ["in_slurm_job", "slurm_hosts", "slurm_workers", "auto_hosts", "auto_workers"] @@ -9,6 +11,7 @@ def in_slurm_job() -> bool: + """Check if current process is running in a Slurm allocation.""" return "SLURM_JOB_ID" in os.environ @@ -31,12 +34,7 @@ def slurm_hosts() -> list[str]: def slurm_workers() -> int: - """| Determines number of workers per node in current Slurm allocation using - | the ``SLURM_JOB_GPUS`` or ``SLURM_CPUS_ON_NODE`` environmental variables. - - :return: The implied number of workers per node - :rtype: int - """ + """Determines number of workers per node in current Slurm allocation.""" # TODO: sanity check SLURM variables, commands if not in_slurm_job(): msg = "Not in a SLURM job" @@ -48,16 +46,11 @@ def slurm_workers() -> int: if "SLURM_GPUS_PER_NODE" in os.environ: return int(os.environ["SLURM_GPUS_PER_NODE"]) - # TODO: should we assume that we plan to do one worker per CPU? return int(os.environ["SLURM_CPUS_ON_NODE"]) def auto_hosts() -> list[str]: - """Automatically determine hostname list - - :return: Hostnames in Slurm allocation, or ['localhost'] - :rtype: list[str] - """ + """Automatically determine hostnames to launch to.""" if in_slurm_job(): return slurm_hosts() @@ -65,11 +58,7 @@ def auto_hosts() -> list[str]: def auto_workers() -> int: - """Automatically determine number of workers per host - - :return: Workers per host - :rtype: int - """ + """Automatically determine workers per host from SLURM or based on GPU/CPU count.""" if in_slurm_job(): return slurm_workers() diff --git a/src/torchrunx/utils/errors.py b/src/torchrunx/utils/errors.py index c49f1726..e6bb3d24 100644 --- a/src/torchrunx/utils/errors.py +++ b/src/torchrunx/utils/errors.py @@ -1,3 +1,5 @@ +"""Exception classes for agents and workers.""" + from dataclasses import dataclass __all__ = [ @@ -8,13 +10,15 @@ class AgentFailedError(Exception): - pass + """Raised if agent fails (e.g. 
if signal received).""" class WorkerFailedError(Exception): - pass + """Raised if a worker fails (e.g. if signal recieved or segmentation fault).""" @dataclass class ExceptionFromWorker: + """Container for exceptions raised inside workers (from user script).""" + exception: Exception diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py index b6a4074b..30659f69 100644 --- a/src/torchrunx/utils/logging.py +++ b/src/torchrunx/utils/logging.py @@ -32,6 +32,110 @@ if TYPE_CHECKING: import os +## Handler utilities + + +def add_filter_to_handler( + handler: Handler, + hostname: str, + local_rank: int | None, # None indicates agent + log_level: int = logging.NOTSET, +) -> None: + """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled. + + Args: + handler: ``logging.Handler`` to be modified. + hostname: Name of specified host. + local_rank: Rank of specified worker (or ``None`` for agent). + log_level: Minimum log level to capture. + """ + + def _filter(record: WorkerLogRecord) -> bool: + return ( + record.hostname == hostname + and record.local_rank == local_rank + and record.levelno >= log_level + ) + + handler.addFilter(_filter) # pyright: ignore [reportArgumentType] + + +def stream_handler( + hostname: str, local_rank: int | None, log_level: int = logging.NOTSET +) -> Handler: + """logging.Handler builder function for writing logs to stdout.""" + handler = logging.StreamHandler(stream=sys.stdout) + add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) + handler.setFormatter( + logging.Formatter( + "%(asctime)s:%(levelname)s:%(hostname)s[%(local_rank)s]: %(message)s" + if local_rank is not None + else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", + ), + ) + return handler + + +def file_handler( + hostname: str, + local_rank: int | None, + file_path: str | os.PathLike, + log_level: int = logging.NOTSET, +) -> Handler: + """logging.Handler builder function for writing logs to a file.""" + handler = logging.FileHandler(file_path) + add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) + formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s") + handler.setFormatter(formatter) + return handler + + +def file_handlers( + hostnames: list[str], + workers_per_host: list[int], + log_dir: str | os.PathLike = Path("torchrunx_logs"), + log_level: int = logging.NOTSET, +) -> list[Handler]: + """Builder function for writing logs for all workers/agents to a directory. + + Files are named with timestamp, hostname, and the local_rank (for workers). + """ + handlers = [] + + Path(log_dir).mkdir(parents=True, exist_ok=True) + timestamp = datetime.datetime.now().isoformat(timespec="seconds") + + for hostname, num_workers in zip(hostnames, workers_per_host): + for local_rank in [None, *range(num_workers)]: + file_path = ( + f"{log_dir}/{timestamp}-{hostname}" + + (f"[{local_rank}]" if local_rank is not None else "") + + ".log" + ) + handlers.append(file_handler(hostname, local_rank, file_path, log_level=log_level)) + + return handlers + + +def default_handlers( + hostnames: list[str], + workers_per_host: list[int], + log_dir: str | os.PathLike = Path("torchrunx_logs"), + log_level: int = logging.INFO, +) -> list[Handler]: + """A default set of logging.Handlers to be used when ``launch(log_handlers="auto")``. + + Logs for host[0] and its local_rank[0] worker are written to the launcher process stdout. 
+ Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, + local_rank). + """ + return [ + stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level), + stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level), + *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), + ] + + ## Launcher utilities @@ -135,107 +239,3 @@ def record_factory(*args, **kwargs) -> WorkerLogRecord: # noqa: ANN002, ANN003 logging.setLogRecordFactory(record_factory) logger.addHandler(SocketHandler(host=logger_hostname, port=logger_port)) - - -## Handler utilities - - -def add_filter_to_handler( - handler: Handler, - hostname: str, - local_rank: int | None, # None indicates agent - log_level: int = logging.NOTSET, -) -> None: - """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled. - - Args: - handler: ``logging.Handler`` to be modified. - hostname: Name of specified host. - local_rank: Rank of specified worker (or ``None`` for agent). - log_level: Minimum log level to capture. - """ - - def _filter(record: WorkerLogRecord) -> bool: - return ( - record.hostname == hostname - and record.local_rank == local_rank - and record.levelno >= log_level - ) - - handler.addFilter(_filter) # pyright: ignore [reportArgumentType] - - -def stream_handler( - hostname: str, local_rank: int | None, log_level: int = logging.NOTSET -) -> Handler: - """logging.Handler builder function for writing logs to stdout.""" - handler = logging.StreamHandler(stream=sys.stdout) - add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) - handler.setFormatter( - logging.Formatter( - "%(asctime)s:%(levelname)s:%(hostname)s[%(local_rank)s]: %(message)s" - if local_rank is not None - else "%(asctime)s:%(levelname)s:%(hostname)s: %(message)s", - ), - ) - return handler - - -def file_handler( - hostname: str, - local_rank: int | None, - file_path: str | os.PathLike, - log_level: int = logging.NOTSET, -) -> Handler: - """logging.Handler builder function for writing logs to a file.""" - handler = logging.FileHandler(file_path) - add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) - formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s") - handler.setFormatter(formatter) - return handler - - -def file_handlers( - hostnames: list[str], - workers_per_host: list[int], - log_dir: str | os.PathLike = Path("torchrunx_logs"), - log_level: int = logging.NOTSET, -) -> list[Handler]: - """Builder function for writing logs for all workers/agents to a directory. - - Files are named with timestamp, hostname, and the local_rank (for workers). - """ - handlers = [] - - Path(log_dir).mkdir(parents=True, exist_ok=True) - timestamp = datetime.datetime.now().isoformat(timespec="seconds") - - for hostname, num_workers in zip(hostnames, workers_per_host): - for local_rank in [None, *range(num_workers)]: - file_path = ( - f"{log_dir}/{timestamp}-{hostname}" - + (f"[{local_rank}]" if local_rank is not None else "") - + ".log" - ) - handlers.append(file_handler(hostname, local_rank, file_path, log_level=log_level)) - - return handlers - - -def default_handlers( - hostnames: list[str], - workers_per_host: list[int], - log_dir: str | os.PathLike = Path("torchrunx_logs"), - log_level: int = logging.INFO, -) -> list[Handler]: - """A default set of logging.Handlers to be used when ``launch(log_handlers="auto")``. 
- - Logs for host[0] and its local_rank[0] worker are written to the launcher process stdout. - Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, - local_rank). - """ - return [ - stream_handler(hostname=hostnames[0], local_rank=None, log_level=log_level), - stream_handler(hostname=hostnames[0], local_rank=0, log_level=log_level), - *file_handlers(hostnames, workers_per_host, log_dir=log_dir, log_level=log_level), - ] From 490f2a86cb4de9e0683fba0ecd239bdb6be411c3 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 12:17:51 -0400 Subject: [PATCH 23/31] more launcher docs --- src/torchrunx/launcher.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 3fe0259c..b02535c3 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -1,3 +1,5 @@ +"""For launching functions with our library.""" + from __future__ import annotations __all__ = ["Launcher", "launch", "LaunchResult"] @@ -23,7 +25,6 @@ import torch.distributed as dist from .utils.comm import ( - AgentStatus, LauncherAgentGroup, LauncherPayload, get_open_port, @@ -38,6 +39,11 @@ @dataclass class Launcher: + """Alias class for ``torchrunx.launch``. + + Useful for sequential invocations on the same configuration or for specifying arguments via CLI. + """ + hostnames: list[str] | Literal["auto", "slurm"] = "auto" workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto" ssh_config_file: str | os.PathLike | None = None @@ -63,6 +69,7 @@ def run( # noqa: C901, PLR0912 func_kwargs: dict[str, Any] | None = None, log_handlers: list[Handler] | Literal["auto"] | None = "auto", ) -> LaunchResult: + """Run a function using the configuration in ``torchrunx.Launcher``.""" if not dist.is_available(): msg = "The torch.distributed package is not available." raise RuntimeError(msg) @@ -180,7 +187,8 @@ def run( # noqa: C901, PLR0912 ssh_config_file=self.ssh_config_file, ) - return LaunchResult(hostnames=hostnames, agent_statuses=agent_statuses) + return_values = [s.return_values for s in agent_statuses] + return LaunchResult(hostnames=hostnames, return_values=return_values) def launch( @@ -220,9 +228,10 @@ def launch( :param extra_env_vars: Additional, user-specified variables to copy. :param env_file: A file (like ``.env``) with additional environment variables to copy. :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme. - :raises RuntimeError: If ``torch.distributed`` not available or "slurm" specified but no allocation is detected. - :raises AgentKilledError: If any agent is killed - :raises Exception: Propagates exceptions raised in worker processes + :raises RuntimeError: Due to various misconfigurations. + :raises AgentFailedError: If any agent fails (e.g. due to signal from OS). + :raises WorkerFailedError: If any worker fails (e.g. due to segmentation faults). + :raises Exception: Propagates exceptions raised in worker processes. 
""" # noqa: E501 return Launcher( hostnames=hostnames, @@ -236,12 +245,12 @@ def launch( ).run(func=func, func_args=func_args, func_kwargs=func_kwargs, log_handlers=log_handlers) +@dataclass class LaunchResult: - """A class that holds worker return values, created by :mod:``torchrunx.launch`` or :mod:``torchrunx.Launcher.run``.""" + """Container for objects returned from workers after successful launches.""" - def __init__(self, hostnames: list[str], agent_statuses: list[AgentStatus]) -> None: - self.hostnames: list[str] = hostnames - self.return_values: list[list[Any]] = [s.return_values for s in agent_statuses] + hostnames: list[str] + return_values: list[list[Any]] @overload def all(self) -> dict[str, list[Any]]: From e54a5338450192e40ef385b08dd882e59315d1ad Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 12:55:22 -0400 Subject: [PATCH 24/31] more updates to docs --- src/torchrunx/agent.py | 22 +++++++++++- src/torchrunx/launcher.py | 54 ++++++++++++++++++------------ src/torchrunx/utils/comm.py | 17 +++++++--- src/torchrunx/utils/environment.py | 6 +--- src/torchrunx/utils/logging.py | 8 ++--- src/torchrunx/worker.py | 21 ++++++++++++ 6 files changed, 91 insertions(+), 37 deletions(-) diff --git a/src/torchrunx/agent.py b/src/torchrunx/agent.py index 90ff193e..27dd8489 100644 --- a/src/torchrunx/agent.py +++ b/src/torchrunx/agent.py @@ -1,3 +1,5 @@ +"""Primary logic for agent processes.""" + from __future__ import annotations __all__ = ["main"] @@ -22,8 +24,21 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_port: int) -> None: + """Main function for agent processes (started on each node). + + This function spawns local worker processes (which run the target function). All agents monitor + their worker statuses (including returned objects and raised exceptions) and communicate these + with each other (and launcher). All agents terminate if failure occurs in any agent. + + Arguments: + launcher_agent_group: The communication group between launcher and all agents. + logger_hostname: The hostname of the launcher (for logging). + logger_port: The port of the launcher (for logging). 
+ """ agent_rank = launcher_agent_group.rank - 1 + # Communicate initial payloads between launcher/agents + payload = AgentPayload( hostname=socket.getfqdn(), port=get_open_port(), @@ -38,6 +53,8 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ worker_global_ranks = launcher_payload.worker_global_ranks[agent_rank] num_workers = len(worker_global_ranks) + # Stream logs to logging server + logger = logging.getLogger() log_records_to_socket( @@ -50,7 +67,7 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ redirect_stdio_to_logger(logger) - # spawn workers + # Spawn worker processes ctx = dist_mp.start_processes( name=f"{hostname}_", @@ -84,6 +101,9 @@ def main(launcher_agent_group: LauncherAgentGroup, logger_hostname: str, logger_ ), # pyright: ignore [reportArgumentType] ) + # Monitor and communicate agent statuses + # Terminate gracefully upon failure + try: status = None while True: diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index b02535c3..092cab61 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -87,7 +87,7 @@ def run( # noqa: C901, PLR0912 agent_payloads = None try: - # start logging server + # Start logging server (recieves LogRecords from agents/workers) log_receiver = _build_logging_server( log_handlers=log_handlers, @@ -105,7 +105,7 @@ def run( # noqa: C901, PLR0912 log_process.start() - # start agents on each node + # Start agents on each node for i, hostname in enumerate(hostnames): _execute_command( @@ -122,7 +122,7 @@ def run( # noqa: C901, PLR0912 ssh_config_file=self.ssh_config_file, ) - # initialize launcher-agent process group + # Initialize launcher-agent process group # ranks = (launcher, agent_{hostnames[0]}, ..., agent[-1]) launcher_agent_group = LauncherAgentGroup( @@ -132,7 +132,7 @@ def run( # noqa: C901, PLR0912 rank=0, ) - # build and sync payloads between launcher and agents + # Sync initial payloads between launcher and agents _cumulative_workers = [0, *itertools.accumulate(workers_per_host)] @@ -152,7 +152,7 @@ def run( # noqa: C901, PLR0912 launcher_payload, agent_payloads = launcher_agent_group.sync_payloads(payload=payload) - # loop to monitor agent statuses (until failed or done) + # Monitor agent statuses (until failed or done) while True: # could raise AgentFailedError @@ -187,6 +187,7 @@ def run( # noqa: C901, PLR0912 ssh_config_file=self.ssh_config_file, ) + # if launch is successful: return objects from workers return_values = [s.return_values for s in agent_statuses] return LaunchResult(hostnames=hostnames, return_values=return_values) @@ -216,23 +217,32 @@ def launch( ) -> LaunchResult: """Launch a distributed PyTorch function on the specified nodes. - :param func: - :param func_args: - :param func_kwargs: - :param hostnames: Nodes to launch the function on. Default infers from a SLURM environment or runs on localhost. - :param workers_per_host: Number of processes to run per node. Can define per node with :type:`list[int]`. - :param ssh_config_file: An SSH configuration file for connecting to nodes, by default loads ``~/.ssh/config`` or ``/etc/ssh/ssh_config``. - :param backend: `Backend `_ to initialize worker process group with. Default uses NCCL (if GPUs available) or GLOO. Disabled by ``None``. - :param timeout: Worker process group timeout (seconds). - :param default_env_vars: A list of environmental variables to be copied from the launcher process to workers. Allows for bash pattern matching syntax. 
- :param extra_env_vars: Additional, user-specified variables to copy. - :param env_file: A file (like ``.env``) with additional environment variables to copy. - :param log_handlers: A list of handlers to manage agent and worker logs. Default uses an automatic basic logging scheme. - :raises RuntimeError: Due to various misconfigurations. - :raises AgentFailedError: If any agent fails (e.g. due to signal from OS). - :raises WorkerFailedError: If any worker fails (e.g. due to segmentation faults). - :raises Exception: Propagates exceptions raised in worker processes. - """ # noqa: E501 + Arguments: + func: Function to run on each worker. + func_args: Positional arguments for ``func``. + func_kwargs: Keyword arguments for ``func``. + hostnames: Nodes on which to launch the function. + Defaults to nodes inferred from a SLURM environment or localhost. + workers_per_host: Number of processes to run per node. + Can specify different counts per node with a list. + ssh_config_file: Path to an SSH configuration file for connecting to nodes. + Defaults to ``~/.ssh/config`` or ``/etc/ssh/ssh_config``. + backend: `Backend `_ + for worker process group. Defaults to NCCL (GPU) or GLOO (CPU). Set `None` to disable. + timeout: Worker process group timeout (seconds). + default_env_vars: Environment variables to copy from the launcher process to workers. + Supports bash pattern matching syntax. + extra_env_vars: Additional user-specified environment variables to copy. + env_file: Path to a file (e.g., `.env`) with additional environment variables to copy. + log_handlers: Handlers to manage agent and worker logs. + Defaults to an automatic basic logging scheme. + + Raises: + RuntimeError: If there are configuration issues. + AgentFailedError: If an agent fails, e.g. from an OS signal. + WorkerFailedError: If a worker fails, e.g. from a segmentation fault. + Exception: Any exception raised in a worker process is propagated. + """ return Launcher( hostnames=hostnames, workers_per_host=workers_per_host, diff --git a/src/torchrunx/utils/comm.py b/src/torchrunx/utils/comm.py index c1794e72..4ed16f0a 100644 --- a/src/torchrunx/utils/comm.py +++ b/src/torchrunx/utils/comm.py @@ -47,7 +47,7 @@ def __post_init__(self) -> None: """Initialize process group. Raises: - torch.distributed.DistStoreError: if group initialization times out. + torch.distributed.DistStoreError: if group initialization times out. """ self.group = dist.init_process_group( backend="gloo", @@ -69,7 +69,11 @@ def _deserialize(self, serialized: bytes) -> Any: return cloudpickle.loads(serialized) def _all_gather(self, obj: Any) -> list: - """Gather object from every rank to list on every rank.""" + """Gather object from every rank to list on every rank. + + Raises: + AgentFailedError: if any agent fails (observed by this communication). + """ try: object_bytes = self._serialize(obj) object_list = [b""] * self.world_size @@ -125,8 +129,8 @@ class AgentStatus: """Status of each agent (to be synchronized in LauncherAgentGroup). Attributes: - state: Whether the agent is running, failed, or done. - return_values: Objects returned (or exceptions raised) by workers (indexed by local rank). + state: Whether the agent is running, failed, or done. + return_values: Objects returned (or exceptions raised) by workers (indexed by local rank). 
""" state: Literal["running", "failed", "done"] @@ -139,10 +143,13 @@ def from_result(cls, result: RunProcsResult | None) -> Self: """Convert RunProcsResult (from polling worker process context) to AgentStatus.""" if result is None: return cls(state="running") + for local_rank, failure in result.failures.items(): result.return_values[local_rank] = WorkerFailedError(failure.message) + return_values = list(result.return_values.values()) - failed = any(isinstance(v, ExceptionFromWorker) for v in return_values) + + failed = any(isinstance(v, (ExceptionFromWorker, WorkerFailedError)) for v in return_values) state = "failed" if failed else "done" return cls( diff --git a/src/torchrunx/utils/environment.py b/src/torchrunx/utils/environment.py index 9e3c0be9..c297f027 100644 --- a/src/torchrunx/utils/environment.py +++ b/src/torchrunx/utils/environment.py @@ -16,11 +16,7 @@ def in_slurm_job() -> bool: def slurm_hosts() -> list[str]: - """Retrieves hostnames of Slurm-allocated nodes. - - :return: Hostnames of nodes in current Slurm allocation - :rtype: list[str] - """ + """Retrieves hostnames of Slurm-allocated nodes.""" # TODO: sanity check SLURM variables, commands if not in_slurm_job(): msg = "Not in a SLURM job" diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py index 30659f69..72be39b1 100644 --- a/src/torchrunx/utils/logging.py +++ b/src/torchrunx/utils/logging.py @@ -44,10 +44,10 @@ def add_filter_to_handler( """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled. Args: - handler: ``logging.Handler`` to be modified. - hostname: Name of specified host. - local_rank: Rank of specified worker (or ``None`` for agent). - log_level: Minimum log level to capture. + handler: ``logging.Handler`` to be modified. + hostname: Name of specified host. + local_rank: Rank of specified worker (or ``None`` for agent). + log_level: Minimum log level to capture. """ def _filter(record: WorkerLogRecord) -> bool: diff --git a/src/torchrunx/worker.py b/src/torchrunx/worker.py index ca4df74b..12caf827 100644 --- a/src/torchrunx/worker.py +++ b/src/torchrunx/worker.py @@ -1,3 +1,5 @@ +"""Arguments and entrypoint for the worker processes.""" + from __future__ import annotations import datetime @@ -20,6 +22,8 @@ @dataclass class WorkerArgs: + """Arguments passed from agent to spawned workers.""" + function: Callable logger_hostname: str logger_port: int @@ -34,10 +38,13 @@ class WorkerArgs: timeout: int def serialize(self) -> SerializedWorkerArgs: + """Arguments must be serialized (to bytes) before passed to spawned workers.""" return SerializedWorkerArgs(worker_args=self) class SerializedWorkerArgs: + """We use cloudpickle as a serialization backend (as it supports nearly all Python types).""" + def __init__(self, worker_args: WorkerArgs) -> None: self.bytes = cloudpickle.dumps(worker_args) @@ -46,8 +53,16 @@ def deserialize(self) -> WorkerArgs: def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | ExceptionFromWorker: + """Function called by spawned worker processes. + + Workers first prepare a process group (for communicating with all other workers). + They then invoke the user-provided function. + Logs are transmitted to the launcher process. + """ worker_args: WorkerArgs = serialized_worker_args.deserialize() + # Start logging to the logging server (i.e. 
the launcher) + logger = logging.getLogger() log_records_to_socket( @@ -60,6 +75,8 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc redirect_stdio_to_logger(logger) + # Set rank/world environment variables + os.environ["RANK"] = str(worker_args.rank) os.environ["LOCAL_RANK"] = str(worker_args.local_rank) os.environ["LOCAL_WORLD_SIZE"] = str(worker_args.local_world_size) @@ -67,6 +84,8 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc os.environ["MASTER_ADDR"] = worker_args.main_agent_hostname os.environ["MASTER_PORT"] = str(worker_args.main_agent_port) + # Prepare the process group (e.g. for communication within the user's function) + if worker_args.backend is not None: backend = worker_args.backend if backend == "auto": @@ -85,6 +104,8 @@ def worker_entrypoint(serialized_worker_args: SerializedWorkerArgs) -> Any | Exc timeout=datetime.timedelta(seconds=worker_args.timeout), ) + # Invoke the user's function on this worker + try: return worker_args.function() except Exception as e: From 455c3f394e932d9e7e4a9f6eb2564e07c66492e7 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 17:21:12 -0400 Subject: [PATCH 25/31] switched LaunchResult to get --- src/torchrunx/launcher.py | 93 ++++++++++++++++++++++++++++----------- tests/test_ci.py | 2 +- 2 files changed, 68 insertions(+), 27 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 092cab61..19009b05 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -252,7 +252,12 @@ def launch( default_env_vars=default_env_vars, extra_env_vars=extra_env_vars, env_file=env_file, - ).run(func=func, func_args=func_args, func_kwargs=func_kwargs, log_handlers=log_handlers) + ).run( + func=func, + func_args=func_args, + func_kwargs=func_kwargs, + log_handlers=log_handlers, + ) @dataclass @@ -262,10 +267,6 @@ class LaunchResult: hostnames: list[str] return_values: list[list[Any]] - @overload - def all(self) -> dict[str, list[Any]]: - pass - @overload def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]: pass @@ -275,36 +276,76 @@ def all(self, by: Literal["rank"]) -> list[Any]: pass def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: - """Get all worker return values by rank or hostname. - Returns a list of return values ordered by global rank, or a dictionary mapping hostnames to lists of return values ordered by local rank. - """ + """Get return values from all workers.""" if by == "hostname": return dict(zip(self.hostnames, self.return_values)) elif by == "rank": # noqa: RET505 return reduce(add, self.return_values) + else: + msg = "Invalid argument for 'by'. Must be 'hostname' or 'rank'." + raise TypeError(msg) + + @overload + def get(self, hostname: None, rank: None) -> dict[str, list[Any]]: ... + + @overload + def get(self, hostname: None, rank: int) -> Any: ... - msg = "Invalid argument: expected by=('hostname' | 'rank')" - raise TypeError(msg) + @overload + def get(self, hostname: None, rank: list[int]) -> list[Any]: ... - def values(self, hostname: str) -> list[Any]: - """Get worker return values for host ``hostname``.""" - host_idx = self.hostnames.index(hostname) - return self.return_values[host_idx] + @overload + def get(self, hostname: str, rank: None) -> list[Any]: ... 
- def value(self, rank: int) -> Any: - """Get worker return value from global rank ``rank``.""" - if rank < 0: - msg = f"Rank {rank} must be larger than 0" - raise ValueError(msg) + @overload + def get(self, hostname: list[str], rank: None) -> dict[str, list[Any]]: ... - for values in self.return_values: - if rank >= len(values): - rank -= len(values) - else: - return values[rank] + @overload + def get(self, hostname: str, rank: int) -> Any: ... - msg = f"Rank {rank} larger than world_size" - raise ValueError(msg) + @overload + def get(self, hostname: str, rank: list[int]) -> list[Any]: ... + + @overload + def get(self, hostname: list[str], rank: int) -> list[Any]: ... + + @overload + def get(self, hostname: list[str], rank: list[int]) -> dict[str, list[Any]]: ... + + def get( # noqa: PLR0911 + self, + hostname: str | list[str] | None = None, + rank: int | list[int] | None = None, + ) -> dict[str, list[Any]] | list[Any] | Any: + """Get return values from selected workers.""" + if hostname is None and isinstance(rank, int): + return self.all(by="rank")[rank] + + if hostname is None and isinstance(rank, list): + _values = self.all(by="rank") + return [_values[r] for r in rank] + + if isinstance(hostname, str) and rank is None: + self.return_values[self.hostnames.index(hostname)] + + if isinstance(hostname, list) and rank is None: + return {h: self.get(hostname=h) for h in hostname} + + if isinstance(hostname, str) and isinstance(rank, int): + return self.get(hostname=hostname)[rank] + + if isinstance(hostname, str) and isinstance(rank, list): + return self.get(hostname=hostname)[rank] + + if isinstance(hostname, list) and isinstance(rank, int): + return [self.get(hostname=h)[rank] for h in hostname] + + if isinstance(hostname, list) and isinstance(rank, list): + _values = self.get(hostname=hostname) + return {h: [_values[h][r] for r in rank] for h in hostname} + + # remaining case: hostname=None, rank=None + return self.all(by="hostname") def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: diff --git a/tests/test_ci.py b/tests/test_ci.py index 64cd1e93..fe23e987 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -37,7 +37,7 @@ def dist_func() -> torch.Tensor: backend="gloo", # log_dir="./test_logs" ) - assert torch.all(r.value(0) == r.value(1)) + assert torch.all(r.get(rank=0) == r.get(rank=1)) def test_logging() -> None: From f96721850ea1e361665794c38d2c67a815fc6391 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 17:22:52 -0400 Subject: [PATCH 26/31] bump hash in pixi lock --- pixi.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pixi.lock b/pixi.lock index 14d646e1..3f01b378 100644 --- a/pixi.lock +++ b/pixi.lock @@ -1793,7 +1793,7 @@ packages: name: torchrunx version: 0.2.0 path: . 
- sha256: 9d23ebc62f9f7c16307df31a08e09204d6de3aa4879029364f930e3f9c62a725 + sha256: 2902e99e3cf4c53010007e6ad3243b25a8c7be8325ad6d9889730c9f075ed72d requires_dist: - cloudpickle>=3.0 - fabric>=3.2 From 3a68eb6f3a2c8f5523ae2b90a96fa371798f35f5 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 18:33:43 -0400 Subject: [PATCH 27/31] removed overloading from LaunchResult --- src/torchrunx/launcher.py | 90 ++++++--------------------------------- tests/test_ci.py | 2 +- tests/test_func.py | 2 +- 3 files changed, 15 insertions(+), 79 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 19009b05..f53f7e0a 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -19,7 +19,7 @@ from multiprocessing import Process from operator import add from pathlib import Path -from typing import Any, Callable, Literal, overload +from typing import Any, Callable, Literal import fabric import torch.distributed as dist @@ -267,85 +267,21 @@ class LaunchResult: hostnames: list[str] return_values: list[list[Any]] - @overload - def all(self, by: Literal["hostname"]) -> dict[str, list[Any]]: - pass + def by_hostname(self) -> dict[str, list[Any]]: + """All return values from workers, indexed by host and local rank.""" + return dict(zip(self.hostnames, self.return_values)) - @overload - def all(self, by: Literal["rank"]) -> list[Any]: - pass + def by_rank(self) -> list[Any]: + """All return values from workers, indexed by global rank.""" + return reduce(add, self.return_values) - def all(self, by: Literal["hostname", "rank"] = "hostname") -> dict[str, list[Any]] | list[Any]: - """Get return values from all workers.""" - if by == "hostname": - return dict(zip(self.hostnames, self.return_values)) - elif by == "rank": # noqa: RET505 - return reduce(add, self.return_values) - else: - msg = "Invalid argument for 'by'. Must be 'hostname' or 'rank'." - raise TypeError(msg) + def get(self, hostname: str, rank: int) -> Any: + """Get return value from worker (indexed by host and local rank).""" + return self.return_values[self.hostnames.index(hostname)][rank] - @overload - def get(self, hostname: None, rank: None) -> dict[str, list[Any]]: ... - - @overload - def get(self, hostname: None, rank: int) -> Any: ... - - @overload - def get(self, hostname: None, rank: list[int]) -> list[Any]: ... - - @overload - def get(self, hostname: str, rank: None) -> list[Any]: ... - - @overload - def get(self, hostname: list[str], rank: None) -> dict[str, list[Any]]: ... - - @overload - def get(self, hostname: str, rank: int) -> Any: ... - - @overload - def get(self, hostname: str, rank: list[int]) -> list[Any]: ... - - @overload - def get(self, hostname: list[str], rank: int) -> list[Any]: ... - - @overload - def get(self, hostname: list[str], rank: list[int]) -> dict[str, list[Any]]: ... 
- - def get( # noqa: PLR0911 - self, - hostname: str | list[str] | None = None, - rank: int | list[int] | None = None, - ) -> dict[str, list[Any]] | list[Any] | Any: - """Get return values from selected workers.""" - if hostname is None and isinstance(rank, int): - return self.all(by="rank")[rank] - - if hostname is None and isinstance(rank, list): - _values = self.all(by="rank") - return [_values[r] for r in rank] - - if isinstance(hostname, str) and rank is None: - self.return_values[self.hostnames.index(hostname)] - - if isinstance(hostname, list) and rank is None: - return {h: self.get(hostname=h) for h in hostname} - - if isinstance(hostname, str) and isinstance(rank, int): - return self.get(hostname=hostname)[rank] - - if isinstance(hostname, str) and isinstance(rank, list): - return self.get(hostname=hostname)[rank] - - if isinstance(hostname, list) and isinstance(rank, int): - return [self.get(hostname=h)[rank] for h in hostname] - - if isinstance(hostname, list) and isinstance(rank, list): - _values = self.get(hostname=hostname) - return {h: [_values[h][r] for r in rank] for h in hostname} - - # remaining case: hostname=None, rank=None - return self.all(by="hostname") + def rank(self, idx: int) -> Any: + """Get return value from worker (indexed by global rank).""" + return self.by_rank()[idx] def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: diff --git a/tests/test_ci.py b/tests/test_ci.py index fe23e987..79d89433 100644 --- a/tests/test_ci.py +++ b/tests/test_ci.py @@ -37,7 +37,7 @@ def dist_func() -> torch.Tensor: backend="gloo", # log_dir="./test_logs" ) - assert torch.all(r.get(rank=0) == r.get(rank=1)) + assert torch.all(r.rank(0) == r.rank(1)) def test_logging() -> None: diff --git a/tests/test_func.py b/tests/test_func.py index e8033b4e..097c7d2f 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -13,7 +13,7 @@ def test_launch() -> None: workers_per_host="slurm", ) - result_values = result.all(by="rank") + result_values = result.by_rank() t = True for i in range(len(result_values)): From 9e2d5f4799e16567e7046b7eb6892f318d47f7d8 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 21:37:06 -0400 Subject: [PATCH 28/31] update all docs --- README.md | 15 ++++++++------- docs/source/advanced.rst | 20 ++++++++++---------- docs/source/api.rst | 4 +++- docs/source/conf.py | 2 +- src/torchrunx/launcher.py | 21 +++++++++------------ src/torchrunx/utils/logging.py | 14 +++++++------- 6 files changed, 38 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index e3271e91..74f3ce25 100644 --- a/README.md +++ b/README.md @@ -56,12 +56,13 @@ Here's a simple example where we "train" a model on two nodes (with 2 GPUs each) import torchrunx as trx if __name__ == "__main__": - trained_model = trx.launch( + result = trx.launch( func=train, hostnames=["localhost", "other_node"], - workers_per_host=2 # num. GPUs - ).value(rank=0) # get returned object + workers_per_host=2 # number of GPUs + ) + trained_model = result.rank(0) torch.save(trained_model.state_dict(), "model.pth") ``` @@ -70,9 +71,9 @@ if __name__ == "__main__": ## Why should I use this? -Whether you have 1 GPU, 8 GPUs, or 8 machines. 
+Whether you have 1 GPU, 8 GPUs, or 8 machines: -__Features:__ +__Features__ - Our [`launch()`](https://torchrunx.readthedocs.io/stable/api.html#torchrunx.launch) utility is super _Pythonic_ - Return objects from your workers @@ -81,13 +82,13 @@ __Features:__ - Fine-grained control over logging, environment variables, exception handling, etc. - Automatic integration with SLURM -__Robustness:__ +__Robustness__ - If you want to run a complex, _modular_ workflow in __one__ script - don't parallelize your entire script: just the functions you want! - no worries about memory leaks or OS failures -__Convenience:__ +__Convenience__ - If you don't want to: - set up [`dist.init_process_group`](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) yourself diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 49c66382..2c6ec9cd 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -14,19 +14,19 @@ We could also launch multiple functions (e.g. train on many GPUs, test on one GP func=train, hostnames=["node1", "node2"], workers_per_host=8 - ).value(rank=0) + ).rank(0) accuracy = trx.launch( func=test, - func_kwargs={'model': model}, + func_args=(trained_model,), hostnames=["localhost"], workers_per_host=1 - ).value(rank=0) + ).rank(0) print(f'Accuracy: {accuracy}') -:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) after each invocation. +:mod:`torchrunx.launch` is self-cleaning: all processes are terminated (and the used memory is completely released) before the subsequent invocation. Launcher class -------------- @@ -85,9 +85,9 @@ Raises a ``RuntimeError`` if ``hostnames="slurm"`` or ``workers_per_host="slurm" Propagating exceptions ---------------------- -Exceptions that are raised in Workers will be raised by the launcher process. +Exceptions that are raised in workers will be raised by the launcher process. -A :mod:`torchrunx.AgentKilledError` will be raised if any agent dies unexpectedly (e.g. if force-killed by the OS, due to segmentation faults or OOM). +A :mod:`torchrunx.AgentFailedError` or :mod:`torchrunx.WorkerFailedError` will be raised if any agent or worker dies unexpectedly (e.g. if sent a signal from the OS, due to segmentation faults or OOM). Environment variables --------------------- @@ -100,14 +100,14 @@ Environment variables in the launcher process that match the ``default_env_vars` Custom logging -------------- -We forward all logs (i.e. from ``logging`` and ``stdio``) from workers and agents to the Launcher. By default, the logs from the first agent and its first worker are printed into the Launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank. +We forward all logs (i.e. from :mod:`logging` and :mod:`sys.stdin`/:mod:`sys.stdout`) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank. -``logging.Handler`` objects can be provided via the ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams). 
+:mod:`logging.Handler` objects can be provided via the ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams). We provide some utilities to help: -.. autofunction:: torchrunx.add_filter_to_handler - .. autofunction:: torchrunx.file_handler .. autofunction:: torchrunx.stream_handler + +.. autofunction:: torchrunx.add_filter_to_handler diff --git a/docs/source/api.rst b/docs/source/api.rst index 31b35fcf..518b323f 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -6,4 +6,6 @@ API .. autoclass:: torchrunx.LaunchResult :members: -.. autoclass:: torchrunx.AgentKilledError +.. autoclass:: torchrunx.AgentFailedError + +.. autoclass:: torchrunx.WorkerFailedError diff --git a/docs/source/conf.py b/docs/source/conf.py index 966816e9..2b2aadee 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -17,8 +17,8 @@ "myst_parser", "sphinx_toolbox.sidebar_links", "sphinx_toolbox.github", - "sphinx.ext.autodoc.typehints", "sphinx.ext.napoleon", + "sphinx.ext.autodoc.typehints", "sphinx.ext.linkcode", ] diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index f53f7e0a..54f8a3d7 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -39,10 +39,7 @@ @dataclass class Launcher: - """Alias class for ``torchrunx.launch``. - - Useful for sequential invocations on the same configuration or for specifying arguments via CLI. - """ + """Useful for sequential invocations or for specifying arguments via CLI.""" hostnames: list[str] | Literal["auto", "slurm"] = "auto" workers_per_host: int | list[int] | Literal["auto", "slurm"] = "auto" @@ -69,7 +66,7 @@ def run( # noqa: C901, PLR0912 func_kwargs: dict[str, Any] | None = None, log_handlers: list[Handler] | Literal["auto"] | None = "auto", ) -> LaunchResult: - """Run a function using the configuration in ``torchrunx.Launcher``.""" + """Run a function using the :mod:`torchrunx.Launcher` configuration.""" if not dist.is_available(): msg = "The torch.distributed package is not available." 
raise RuntimeError(msg) @@ -267,21 +264,21 @@ class LaunchResult: hostnames: list[str] return_values: list[list[Any]] - def by_hostname(self) -> dict[str, list[Any]]: + def by_hostnames(self) -> dict[str, list[Any]]: """All return values from workers, indexed by host and local rank.""" return dict(zip(self.hostnames, self.return_values)) - def by_rank(self) -> list[Any]: + def by_ranks(self) -> list[Any]: """All return values from workers, indexed by global rank.""" return reduce(add, self.return_values) - def get(self, hostname: str, rank: int) -> Any: - """Get return value from worker (indexed by host and local rank).""" + def index(self, hostname: str, rank: int) -> Any: + """Get return value from worker by host and local rank.""" return self.return_values[self.hostnames.index(hostname)][rank] - def rank(self, idx: int) -> Any: - """Get return value from worker (indexed by global rank).""" - return self.by_rank()[idx] + def rank(self, i: int) -> Any: + """Get return value from worker by global rank.""" + return self.by_rank()[i] def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py index 72be39b1..cb786ccd 100644 --- a/src/torchrunx/utils/logging.py +++ b/src/torchrunx/utils/logging.py @@ -41,10 +41,10 @@ def add_filter_to_handler( local_rank: int | None, # None indicates agent log_level: int = logging.NOTSET, ) -> None: - """A filter for ``logging.Handler`` such that only specific agent/worker logs are handled. + """A filter for :mod:`logging.Handler` such that only specific agent/worker logs are handled. Args: - handler: ``logging.Handler`` to be modified. + handler: Handler to be modified. hostname: Name of specified host. local_rank: Rank of specified worker (or ``None`` for agent). log_level: Minimum log level to capture. @@ -63,7 +63,7 @@ def _filter(record: WorkerLogRecord) -> bool: def stream_handler( hostname: str, local_rank: int | None, log_level: int = logging.NOTSET ) -> Handler: - """logging.Handler builder function for writing logs to stdout.""" + """Handler builder function for writing logs from specified hostname/rank to stdout.""" handler = logging.StreamHandler(stream=sys.stdout) add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) handler.setFormatter( @@ -82,7 +82,7 @@ def file_handler( file_path: str | os.PathLike, log_level: int = logging.NOTSET, ) -> Handler: - """logging.Handler builder function for writing logs to a file.""" + """Handler builder function for writing logs from specified hostname/rank to a file.""" handler = logging.FileHandler(file_path) add_filter_to_handler(handler, hostname, local_rank, log_level=log_level) formatter = logging.Formatter("%(asctime)s:%(levelname)s: %(message)s") @@ -96,7 +96,7 @@ def file_handlers( log_dir: str | os.PathLike = Path("torchrunx_logs"), log_level: int = logging.NOTSET, ) -> list[Handler]: - """Builder function for writing logs for all workers/agents to a directory. + """Handler builder function for writing logs for all workers/agents to a directory. Files are named with timestamp, hostname, and the local_rank (for workers). """ @@ -123,9 +123,9 @@ def default_handlers( log_dir: str | os.PathLike = Path("torchrunx_logs"), log_level: int = logging.INFO, ) -> list[Handler]: - """A default set of logging.Handlers to be used when ``launch(log_handlers="auto")``. + """Default :mod:`logging.Handler`s for ``log_handlers="auto"`` in :mod:`torchrunx.launch`. 
- Logs for host[0] and its local_rank[0] worker are written to the launcher process stdout. + Logs for ``host[0]`` and its ``local_rank[0]`` worker are written to launcher process stdout. Logs for all agents/workers are written to files in ``log_dir`` (named by timestamp, hostname, local_rank). """ From a29212e068ad6f4e24b48d2def4a026b5f0e17be Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 21:39:30 -0400 Subject: [PATCH 29/31] fix --- src/torchrunx/launcher.py | 2 +- tests/test_func.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchrunx/launcher.py b/src/torchrunx/launcher.py index 54f8a3d7..b745ac7a 100644 --- a/src/torchrunx/launcher.py +++ b/src/torchrunx/launcher.py @@ -278,7 +278,7 @@ def index(self, hostname: str, rank: int) -> Any: def rank(self, i: int) -> Any: """Get return value from worker by global rank.""" - return self.by_rank()[i] + return self.by_ranks()[i] def _resolve_hostnames(hostnames: list[str] | Literal["auto", "slurm"]) -> list[str]: diff --git a/tests/test_func.py b/tests/test_func.py index 097c7d2f..7b3ad7f6 100644 --- a/tests/test_func.py +++ b/tests/test_func.py @@ -13,7 +13,7 @@ def test_launch() -> None: workers_per_host="slurm", ) - result_values = result.by_rank() + result_values = result.by_ranks() t = True for i in range(len(result_values)): From 7bf9222e30c160147583fa8958fcd3a8ffb65f19 Mon Sep 17 00:00:00 2001 From: apoorvkh Date: Tue, 29 Oct 2024 21:50:36 -0400 Subject: [PATCH 30/31] small edits --- docs/source/advanced.rst | 2 +- docs/source/conf.py | 2 +- src/torchrunx/utils/logging.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 2c6ec9cd..9e151d1a 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -100,7 +100,7 @@ Environment variables in the launcher process that match the ``default_env_vars` Custom logging -------------- -We forward all logs (i.e. from :mod:`logging` and :mod:`sys.stdin`/:mod:`sys.stdout`) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank. +We forward all logs (i.e. from :mod:`logging` and :mod:`sys.stdout`/:mod:`sys.stderr`) from workers and agents to the launcher. By default, the logs from the first agent and its first worker are printed into the launcher's ``stdout`` stream. Logs from all agents and workers are written to files in ``$TORCHRUNX_LOG_DIR`` (default: ``./torchrunx_logs``) and are named by timestamp, hostname, and local_rank. :mod:`logging.Handler` objects can be provided via the ``log_handlers`` argument to provide further customization (mapping specific agents/workers to custom output streams). 
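For illustration (an editor's sketch, not part of these patches): the ``log_handlers`` argument described above might be combined with the ``torchrunx.stream_handler`` and ``torchrunx.file_handler`` builders documented in this series. The hostnames, file path, and worker function below are made up.

```python
import logging

import torchrunx as trx


def hello() -> None:
    logging.info("hello from a worker")  # forwarded to the launcher's log handlers


if __name__ == "__main__":
    trx.launch(
        func=hello,
        hostnames=["node1", "node2"],  # hypothetical hosts
        workers_per_host=2,
        log_handlers=[
            # stdout: only logs from worker 0 on "node1"
            trx.stream_handler(hostname="node1", local_rank=0, log_level=logging.INFO),
            # file: only logs from the "node2" agent itself (local_rank=None)
            trx.file_handler(hostname="node2", local_rank=None, file_path="node2_agent.log"),
        ],
    )
```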
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 2b2aadee..8e64563c 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -22,7 +22,7 @@
     "sphinx.ext.linkcode",
 ]
 
-autodoc_mock_imports = ["torch", "fabric", "cloudpickle", "typing_extensions"]
+autodoc_mock_imports = ["torch", "fabric", "cloudpickle", "sys", "logging", "typing_extensions"]
 autodoc_typehints = "both"
 autodoc_typehints_description_target = "documented_params"
diff --git a/src/torchrunx/utils/logging.py b/src/torchrunx/utils/logging.py
index cb786ccd..efd013a8 100644
--- a/src/torchrunx/utils/logging.py
+++ b/src/torchrunx/utils/logging.py
@@ -41,12 +41,12 @@ def add_filter_to_handler(
     local_rank: int | None,  # None indicates agent
     log_level: int = logging.NOTSET,
 ) -> None:
-    """A filter for :mod:`logging.Handler` such that only specific agent/worker logs are handled.
+    """Apply a filter to :mod:`logging.Handler` so only specific worker logs are handled.
 
     Args:
         handler: Handler to be modified.
         hostname: Name of specified host.
-        local_rank: Rank of specified worker (or ``None`` for agent).
+        local_rank: Rank of specified worker on host (or ``None`` for agent itself).
         log_level: Minimum log level to capture.
     """
 

From 122febc9396f6833c7611c441c6a884d0d64e076 Mon Sep 17 00:00:00 2001
From: apoorvkh
Date: Tue, 29 Oct 2024 22:57:17 -0400
Subject: [PATCH 31/31] how it works

---
 docs/source/advanced.rst     |  1 +
 docs/source/how_it_works.rst | 13 +++++--------
 2 files changed, 6 insertions(+), 8 deletions(-)

diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst
index 9e151d1a..6cc61d1f 100644
--- a/docs/source/advanced.rst
+++ b/docs/source/advanced.rst
@@ -96,6 +96,7 @@ Environment variables in the launcher process that match the ``default_env_vars`
 ``default_env_vars`` can be overridden if desired. This list can be augmented using ``extra_env_vars``. Additional environment variables (and more custom bash logic) can be included via the ``env_file`` argument. Our agents ``source`` this file.
+We also set the following environment variables in each worker: ``LOCAL_RANK``, ``RANK``, ``LOCAL_WORLD_SIZE``, ``WORLD_SIZE``, ``MASTER_ADDR``, and ``MASTER_PORT``.
 
 Custom logging
 --------------
diff --git a/docs/source/how_it_works.rst b/docs/source/how_it_works.rst
index 3ceb4b4c..9550e8d1 100644
--- a/docs/source/how_it_works.rst
+++ b/docs/source/how_it_works.rst
@@ -1,15 +1,12 @@
 How it works
 ============
 
-In order to organize processes on different nodes, **torchrunx** maintains the following hierarchy:
-
-#. The launcher, the process in which ``torchrunx.Launcher.run`` is executed: Connects to remote hosts and initializes and configures "agents", passes errors and return values from agents to the caller, and is responsible for cleaning up.
-#. The agents, initialized on machines where computation is to be performed: Responsible for starting and monitoring "workers".
-#. The workers, spawned by agents: Responsible for initializing a ``torch.distributed`` process group, and running the distributed function provided by the user.
+If you want to (e.g.) train your model on several machines with **N** GPUs each, you should run your training function in **N** parallel processes on each machine. During training, each of these processes runs the same training code (i.e. your function) and communicates with the other processes (e.g. to synchronize gradients) using a `distributed process group `_.
 
-An example of how this hierarchy might look in practice is the following:
-Suppose we wish to distribute a training function over four GPUs, and we have access to a cluster where nodes have two available GPUs each. Say that a single instance of our training function can leverage multiple GPUs. We can choose two available nodes and use the launcher to launch our function on those two nodes, specifying that we only need one worker per node, since a single instance of our training function can use both GPUs on each node. The launcher will launch an agent on each node and pass our configuration to the agents, after which the agents will each initialize one worker to begin executing the training function. We could also run two workers per node, each with one GPU, giving us four workers, although this would be slower.
+Your script can call our library (via :mod:`torchrunx.launch`) and specify a function to distribute.
The main process running your script is henceforth known as the **launcher** process.
 
-The launcher initializes the agents by simply SSHing into the provided hosts, and executing our agent code there. The launcher also provides key environmental variables from the launch environment to the sessions where the agents are started and tries to activate the same Python environment that was used to execute the launcher. This is one reason why all machines either running a launcher or agent process should share a filesystem.
+Our launcher process spawns an **agent** process (via SSH) on each machine. Each agent then spawns **N** processes (known as **workers**) on its machine. All workers form a process group (with the specified :mod:`torchrunx.launch` ``backend``) and run your function in parallel.
 
-The launcher and agents perform exception handling such that any exceptions in the worker processes are appropriately raised by the launcher process. The launcher and agents communicate using a ``torch.distributed`` process group, separate from the group that the workers use.
\ No newline at end of file
+**Agent–Worker Communication.** Our agents poll their workers every second and time out if unresponsive for 5 seconds. Upon polling, our agents receive ``None`` (if the worker is still running) or a `RunProcsResult `_, indicating that the workers have either completed (providing the object returned, or the exception raised, by your function) or failed (e.g. due to a segmentation fault or OS signal).
 
+**Launcher–Agent Communication.** The launcher and agents form a distributed group (with the CPU-based `GLOO backend `_) for the communication purposes of our library. Our agents synchronize their own "statuses" with each other and the launcher. An agent's status includes whether it is running/failed/completed and the result of the function. If the launcher or any agent fails to synchronize, all raise a :mod:`torchrunx.AgentFailedError` and terminate. If any worker fails or raises an exception, the launcher raises a :mod:`torchrunx.WorkerFailedError` (or that exception) and terminates along with all the agents. If all agents succeed, the launcher returns the objects returned by each worker.
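For illustration (an editor's sketch, not part of these patches): retrieving worker results with the ``LaunchResult`` accessors this series settles on (``rank``, ``index``, ``by_hostnames``). The hostnames and ``train`` function below are made up.

```python
import torch

import torchrunx as trx


def train() -> torch.nn.Module:
    # hypothetical training function; runs on every worker
    return torch.nn.Linear(10, 1)


if __name__ == "__main__":
    result = trx.launch(
        func=train,
        hostnames=["localhost", "second_node"],  # hypothetical hosts
        workers_per_host=2,
    )

    model = result.rank(0)                  # return value of the global rank-0 worker
    per_host = result.by_hostnames()        # {hostname: [return values by local rank]}
    other = result.index("second_node", 0)  # by hostname and local rank
```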