



@dvruette commented Aug 2, 2025

[generated by copilot]

This pull request refactors and enhances the handling of sharding constraints in eformer/escale/partition/constraints.py. The main changes are renaming and improving the existing function that applies a sharding constraint to a single array, and introducing a new function that applies sharding constraints to PyTrees of JAX arrays with added validation and correction logic.

Function renaming and improvement:

  • Renamed with_sharding_constraint to array_with_sharding_constraint to make clear that it operates on a single JAX array, and updated its type annotations to use jax.Array for clarity and consistency.
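The diff itself is not reproduced here, so the following is only a minimal sketch of what a single-array helper with jax.Array annotations could look like, assuming it is a thin wrapper around JAX's own jax.lax.with_sharding_constraint and takes a NamedSharding. The body, the mesh axis name "dp", and the single-device mesh used in the usage lines are illustrative assumptions, not eformer's actual code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def array_with_sharding_constraint(x: jax.Array, sharding: NamedSharding) -> jax.Array:
    """Sketch only: constrain a single jax.Array to `sharding`."""
    # Inside jit this tells the XLA partitioner how `x` should be laid out.
    return jax.lax.with_sharding_constraint(x, sharding)


# Usage: a one-axis mesh over a single device (a real setup would use more).
mesh = Mesh(np.array(jax.devices()[:1]), ("dp",))
sharding = NamedSharding(mesh, P("dp"))


@jax.jit
def scale(x: jax.Array) -> jax.Array:
    # Constrain the intermediate result to be sharded along the "dp" axis.
    return array_with_sharding_constraint(x * 2.0, sharding)


result = scale(jnp.arange(4.0))
```

Typing the argument as jax.Array rather than a NumPy or generic array type also matches tracers inside jit, which is where such constraints are normally applied.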

New functionality for PyTrees:

  • Introduced a new with_sharding_constraint function that applies sharding constraints to a PyTree of JAX arrays. It checks that the input PyTree and the sharding specification have compatible structures, verifies that every element of the sharding specification is of a valid type, and corrects incompatible sharding axes.
