[fix] Apply with_sharding_constraint recursively to pytrees
#4
[generated by copilot]
This pull request refactors and enhances the handling of sharding constraints in the `eformer/escale/partition/constraints.py` file. The most notable changes are renaming and improving the existing function for applying sharding constraints, and introducing a new function that handles PyTrees of JAX arrays with validation and correction logic.

Function renaming and improvement:

* Renamed `with_sharding_constraint` to `array_with_sharding_constraint` to make clear that it operates on a single JAX array, and updated its type annotations to use `jax.Array` for clarity and consistency.

New functionality for PyTrees:

* Added a new `with_sharding_constraint` function that applies sharding constraints to PyTrees of JAX arrays. It checks that the input PyTree structure is compatible with the sharding specification, verifies that every element of the specification is of a valid type, and corrects incompatible sharding axes.
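The three steps described for the PyTree function (structure check, type check, axis correction) can be sketched with plain JAX as below. This is not the eformer implementation: `fit_spec_to_shape`, `constrain_array`, and `constrain_tree` are hypothetical names, and the correction rule (drop a mesh axis to `None` when its size does not evenly divide the array dimension) is an assumption about what "incompatible sharding axes" means. Note that `jax.lax.with_sharding_constraint` itself already accepts pytrees; the sketch only makes the validation and correction explicit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def fit_spec_to_shape(shape, spec, mesh):
    """Hypothetical correction rule: unshard any dim the mesh axes can't divide."""
    entries = tuple(spec) + (None,) * (len(shape) - len(spec))
    fixed = []
    for dim, axis in zip(shape, entries):
        if axis is None:
            fixed.append(None)
            continue
        names = axis if isinstance(axis, tuple) else (axis,)
        size = int(np.prod([mesh.shape[n] for n in names]))
        fixed.append(axis if dim % size == 0 else None)
    return P(*fixed)


def constrain_array(x, spec, mesh):
    """Constrain a single array after correcting its PartitionSpec."""
    spec = fit_spec_to_shape(x.shape, spec, mesh)
    return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, spec))


def constrain_tree(tree, specs, mesh):
    """Apply per-leaf constraints; a None spec leaves that array untouched."""
    # Treat PartitionSpec and None as leaves so the two trees line up.
    is_spec = lambda s: s is None or isinstance(s, P)
    spec_leaves, spec_def = jax.tree_util.tree_flatten(specs, is_leaf=is_spec)
    leaves, tree_def = jax.tree_util.tree_flatten(tree)
    if spec_def != tree_def:
        raise ValueError(f"spec structure {spec_def} != tree structure {tree_def}")
    for s in spec_leaves:
        if not is_spec(s):
            raise TypeError(f"invalid sharding spec element: {s!r}")
    out = [
        x if s is None else constrain_array(x, s, mesh)
        for x, s in zip(leaves, spec_leaves)
    ]
    return jax.tree_util.tree_unflatten(tree_def, out)
```

A usage example on a one-device mesh, where every axis trivially divides:

```python
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
params = {"w": jnp.ones((4, 3)), "b": jnp.ones((3,))}
specs = {"w": P("dp", None), "b": None}
out = jax.jit(lambda t: constrain_tree(t, specs, mesh))(params)
```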