Add Shardy dialect support for JAX 0.8.2+ compatibility#1
Open
robtaylor wants to merge 2 commits into
Open
Conversation
JAX 0.8.2+ uses the Shardy partitioner by default (jax_use_shardy_partitioner=True),
which emits MLIR bytecode containing the sdy (Shardy) dialect. Without proper
dialect registration, IREE fails with "dialect 'sdy' does not implement the
bytecode interface".
This PR adds:
1. Shardy submodule (openxla/shardy) with CMake build support
- build_tools/third_party/shardy/ contains CMake build files since
upstream Shardy only has Bazel
2. New IREE input plugin at compiler/plugins/input/Shardy/
- Registers the sdy dialect via IREE's plugin architecture
- Provides StripShardyDialect pass to remove sdy ops/attributes for
single-device execution (sdy ops are metadata-only sharding annotations)
3. New IREE_INPUT_SHARDY CMake option (ON by default)
- Enables/disables Shardy dialect support in the compiler
4. Test for Shardy integration (test_shardy.py)
- Verifies JAX works with Shardy enabled on IREE PJRT backends
Technical details:
- ShardySession plugin class with DefaultActivated policy
- Dialect registration via mlir::sdy::registerAllDialects()
- Input conversion pipeline strips sdy.sharding attributes
- Properly integrated via iree_compiler_register_plugin() to ensure
symbols are included in libIREECompiler
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
Signed-off-by: Rob Taylor <[email protected]>
- Fix iterator invalidation bug in StripShardyDialect.cpp by collecting ops first then erasing in reverse order - Add warning for unexpected Shardy op patterns that can't be handled - Implement pass registration in registerShardyInputConversionPasses() - Add StableHLO dependency check in CMake - Update copyright years to 2025 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.5 <[email protected]> Signed-off-by: Rob Taylor <[email protected]>
robtaylor
added a commit
that referenced
this pull request
Dec 27, 2025
…sPass
This fixes a bug that caused ElideTimepointsPass to fail on IR containing
loops (scf.for, scf.while, etc.) with two distinct issues:
1. SSA dominance violations ("operand #1 does not dominate this use"):
- Added dominance check before absorbing an await timepoint into a
timeline op's await list
- Changed pendingReplacements to use replaceUsesWithIf with dominance
check instead of replaceAllUsesWith
2. Infinite iteration in fixed-point pipeline:
- The pass was creating join ops and then re-adding the same timepoints
because Value identity checks don't detect coverage by join ops
- Added explicit check for whether a timepoint is already covered by
a join in the timeline op's await list
- Fixed didChange tracking to only signal changes when actual
modifications occur
Fixes iree-org#21982
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
robtaylor
added a commit
that referenced
this pull request
Jan 3, 2026
…sPass
This fixes a bug that caused ElideTimepointsPass to fail on IR containing
loops (scf.for, scf.while, etc.) with two distinct issues:
1. SSA dominance violations ("operand #1 does not dominate this use"):
- Added dominance check before absorbing an await timepoint into a
timeline op's await list
- Changed pendingReplacements to use replaceUsesWithIf with dominance
check instead of replaceAllUsesWith
2. Infinite iteration in fixed-point pipeline:
- The pass was creating join ops and then re-adding the same timepoints
because Value identity checks don't detect coverage by join ops
- Added explicit check for whether a timepoint is already covered by
a join in the timeline op's await list
- Fixed didChange tracking to only signal changes when actual
modifications occur
Fixes iree-org#21982
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.5 <[email protected]>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Test PR to verify CI before upstream approval.
JAX 0.8.2+ uses the Shardy partitioner by default, which emits MLIR bytecode containing the
sdydialect. This PR adds full Shardy dialect support to IREE.Changes
third_party/shardy)build_tools/third_party/shardy/)compiler/plugins/input/Shardy/)test_shardy.py)🤖 Generated with Claude Code