[AMD] Add mori-shmem backend support#145
Conversation
- Add libmori_shmem_device Python bindings for Triton kernels - Add comprehensive test suite for mori-shmem APIs (PyTorch/UniqueID init) - Add build script (build_mori_shmem.sh) for mori and device BC
There was a problem hiding this comment.
Pull request overview
This PR adds support for the mori-shmem backend as an alternative to rocshmem for AMD GPU distributed computing. The implementation introduces a new backend selection mechanism via the TRITON_DIST_SHMEM_BACKEND environment variable and integrates the mori library as a git submodule.
- Adds mori-shmem as a selectable backend alongside rocshmem for AMD GPUs
- Implements dynamic backend switching through environment variable configuration
- Provides comprehensive test coverage for mori-shmem APIs
Reviewed changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 19 comments.
Show a summary per file
| File | Description |
|---|---|
| scripts/build_mori_shmem.sh | Build script for compiling mori library and linking bitcode files for device execution |
| python/triton_dist/utils.py | Core utilities for backend selection, initialization, and version/hash management |
| python/triton_dist/test/amd/test_mori_shmem_api.py | Test suite validating mori-shmem basic operations and device-level APIs |
| python/triton_dist/language/extra/libshmem_device.py | Updated module proxy to route to mori-shmem device library based on backend |
| python/triton_dist/language/extra/hip/libmori_shmem_device.py | Device-level API bindings for mori-shmem operations (my_pe, n_pes, int_p) |
| python/triton_dist/jit.py | JIT compilation integration with backend-specific module initialization |
| .gitmodules | Adds mori library as a git submodule dependency |
Comments suppressed due to low confidence (3)
python/triton_dist/test/amd/test_mori_shmem_api.py:28
- Import of 'time' is not used.
import time
python/triton_dist/test/amd/test_mori_shmem_api.py:29
- Import of 'shutil' is not used.
import shutil
python/triton_dist/test/amd/test_mori_shmem_api.py:31
- Import of 'dist' is not used.
import torch.distributed as dist
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 1 comment.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 11 out of 11 changed files in this pull request and generated 3 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
|
|
||
|
|
||
| # Simple helper to wrap mori shmem pointer as torch tensor | ||
| class MoriShmemBuffer: |
There was a problem hiding this comment.
This part of code and below shouldn't be here, can you please make these consistent with other shmem component? @jhchouuu
There was a problem hiding this comment.
Understood, it was just a temporary placement before. Do you prefer to maintain it in triton-dist like rocshmem, or directly in mori_shmem like nvshmem4py?
There was a problem hiding this comment.
@jhchouuu right now we can follow rocshmem, and can be refactor like nvshmem4py later.
There was a problem hiding this comment.
Now we wrap it into mori_shmem similar to nvshmem4py
|
|
||
|
|
||
| @core.extern | ||
| def my_pe(_semantic=None): |
There was a problem hiding this comment.
where are other shmem APIs ? Ideally it should be consistent compared to other shmem. @jhchouuu
There was a problem hiding this comment.
In another branch that doesn't involve mr, we need to discuss the API encapsulation method, as this will add a parameter qp_id for selecting the QP during RDMA communication. Maybe I should talk with @XG-zheng offline?
There was a problem hiding this comment.
And the CI has encountered a build error due to failure in pulling the submodules of mori. I will fix this issue.
* Feature: add more mori_shmem bitcode wrappers && add mori_shmem bandwith test * Feature: mori_shmem supoort dl op call && small refactor * fix ci && move tensor create to mori library * Refine setup.py
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated 7 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| def get_rocshmem_home(): | ||
| return os.getenv("ROCSHMEM_HOME", | ||
| Path(__file__).parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") | ||
| Path(__file__).parent.parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") |
There was a problem hiding this comment.
The path traverses up three parent directories (parent.parent.parent) which appears incorrect. This changes the rocshmem default path from the original two parents to three, which would break existing rocshmem installations. The original path used Path(__file__).parent.parent, and this should remain unchanged.
There was a problem hiding this comment.
Path corrected to use three parents - the shmem directory is at project root, not inside python/.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated 4 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 16 out of 16 changed files in this pull request and generated 6 comments.
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| [submodule "3rdparty/mori"] | ||
| path = 3rdparty/mori | ||
| url = https://github.com/ROCm/mori.git | ||
| branch = jiahzhou/triton_dis_support |
There was a problem hiding this comment.
Corrected spelling of 'jiahzhou' to 'jiahzhou'. However, note that this appears to be referencing a development branch which may not be appropriate for production. Consider using a stable release branch instead.
| branch = jiahzhou/triton_dis_support |
| @@ -82,8 +88,8 @@ jobs: | |||
| - name: E2E tests | |||
| run: | | |||
| bash ./scripts/build_e2e_env.sh --download_model | |||
There was a problem hiding this comment.
The change from Qwen3-32B to Qwen3-0.6B significantly reduces the model size for testing. While this may speed up CI, ensure this smaller model still provides adequate coverage for the test scenarios. Consider documenting why this change was made (e.g., CI resource constraints) in the PR description or commit message.
| bash ./scripts/build_e2e_env.sh --download_model | |
| bash ./scripts/build_e2e_env.sh --download_model | |
| # Use the smaller Qwen3-0.6B model in CI to keep AMD E2E tests within runtime and memory limits. | |
| # The larger Qwen3-32B variants below remain commented out until the CI image supports them (e.g., flash-attention). |
| def get_rocshmem_home(): | ||
| return os.getenv("ROCSHMEM_HOME", | ||
| Path(__file__).parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") | ||
| Path(__file__).parent.parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") |
There was a problem hiding this comment.
The path has been changed from parent.parent to parent.parent.parent, adding an extra level of directory traversal. This breaks the relative path resolution for rocshmem. The original path structure should be maintained to ensure rocshmem can be located correctly.
| Path(__file__).parent.parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") | |
| Path(__file__).parent.parent / "shmem" / "rocshmem_bind" / "rocshmem_build" / "install") |
| return get_rocshmem_version() | ||
| elif backend == 'mori_shmem': | ||
| return get_mori_version() | ||
| return "unknown" |
There was a problem hiding this comment.
When neither CUDA nor HIP is available, or when backend detection fails, returning 'unknown' could mask configuration issues. Consider raising an exception or logging a warning to make debugging easier.
| return get_rocshmem_hash() | ||
| elif backend == 'mori_shmem': | ||
| return get_mori_shmem_hash() | ||
| return "unknown" |
There was a problem hiding this comment.
Similar to get_shmem_version, returning 'unknown' silently when neither CUDA nor HIP is detected could hide configuration problems. Consider raising an exception or logging a warning.
| raise RuntimeError(f"Unknown TRITON_DIST_SHMEM_BACKEND: {shmem_backend}. Must be 'mori_shmem' or 'rocshmem'") | ||
|
|
||
| # Also build if explicitly requested via env var (for backward compatibility) | ||
| if check_env_flag("TRITON_DISTRIBUTED_BUILD_PYROCSHMEM", "0") and shmem_backend != "rocshmem": |
There was a problem hiding this comment.
The logic for building rocshmem when TRITON_DISTRIBUTED_BUILD_PYROCSHMEM is set seems redundant since rocshmem is already built based on TRITON_DIST_SHMEM_BACKEND. This could lead to building rocshmem twice in some scenarios. Consider simplifying this logic or clarifying the intended behavior in a comment.
Description
Add mori-shmem backend support for AMD GPUs
Changes
TRITON_DIST_SHMEM_BACKENDenvironment variable