Add sharding propagation support in nnx.eval_shape (clone of #5111) #5247

Open
samanklesaria wants to merge 2 commits into google:main from samanklesaria:eval_shape_sharding

Conversation

@samanklesaria
Collaborator

Fixes #5110. This is a clone of #5111, rebased onto the current main.

@gemini-code-assist
Contributor

Summary of Changes

Hello @samanklesaria, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances nnx.eval_shape to correctly propagate sharding information, which is critical for models operating in a distributed environment using JAX's SPMD capabilities. By ensuring that eval_shape can infer and apply NamedSharding based on mesh configurations, it resolves a known issue where sharding properties were not correctly maintained during shape evaluation. This change improves the robustness and correctness of distributed model initialization and analysis.
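
For readers unfamiliar with shape evaluation, the sketch below uses plain jax.eval_shape (not nnx.eval_shape) to show the underlying idea: a function is traced abstractly and only shapes and dtypes come back, with no array memory allocated. The function name init_kernel is a hypothetical stand-in, not something from this PR.

```python
import jax
import jax.numpy as jnp


def init_kernel():
  # Hypothetical initializer; stands in for a layer's parameter creation.
  return jnp.zeros((16, 16), dtype=jnp.float32)


# eval_shape traces init_kernel abstractly and returns a ShapeDtypeStruct
# instead of a concrete array.
abstract_kernel = jax.eval_shape(init_kernel)
print(abstract_kernel.shape, abstract_kernel.dtype)
```

The PR is about making the NNX counterpart of this call carry sharding information on the returned abstract values as well.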

Highlights

  • Sharding Propagation in nnx.eval_shape: Implemented support for propagating sharding information, including NamedSharding and mesh metadata, within nnx.eval_shape when creating variables.
  • Circular Import Resolution: Addressed potential circular import issues by moving the get_var_pspec import inside the _to_variable function.
  • New Test Cases for Sharding: Introduced two new test cases to validate that nnx.eval_shape correctly handles sharding specifications, both when mesh metadata is explicitly provided and when a global mesh is active.

Changelog
  • flax/nnx/transforms/transforms.py
    • Imported get_var_pspec locally within _to_variable to prevent circular dependencies.
    • Modified _to_variable to check for and apply NamedSharding to variables based on mesh metadata or the global mesh during shape evaluation.
  • tests/nnx/spmd_test.py
    • Added test_eval_shape_with_sharding0 to verify sharding propagation with explicit mesh metadata for multiple linear layers.
    • Added test_eval_shape_with_sharding1 to confirm sharding propagation when a global mesh is set using jax.set_mesh.
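
The local import mentioned in the changelog follows a common Python pattern for breaking import cycles. The sketch below is only an illustration of that pattern: serialize_spec is a hypothetical helper, and the standard-library json module stands in for whichever module would otherwise create the cycle.

```python
def serialize_spec(spec):
  """Hypothetical helper; `json` stands in for the module that would form a cycle."""
  # Importing inside the function body delays loading the module until the
  # first call, so importing this file never pulls it in at import time.
  import json
  return json.dumps(list(spec))


encoded = serialize_spec(("a", "b"))
print(encoded)
```

The trade-off is a small lookup cost on each call in exchange for not having to restructure the modules involved.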
Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist gemini-code-assist bot left a comment

Code Review

The pull request introduces sharding propagation support in nnx.eval_shape by adding logic to handle jax.sharding.NamedSharding for variables. This ensures that the sharding information is correctly propagated during shape evaluation. Additionally, new test cases have been added to tests/nnx/spmd_test.py to validate this new functionality. The changes appear to be correct and well-tested.

Comment on lines +241 to +248
global_mesh = jax.sharding.get_mesh()
if global_mesh.axis_sizes == ():
  global_mesh = None
mesh = var.get_metadata("mesh", None) or global_mesh
if mesh is not None:
  pspec = get_var_pspec(var)
  sharding = jax.sharding.NamedSharding(mesh=mesh, spec=pspec)
  var.set_value(jax.ShapeDtypeStruct(shape=var.shape, dtype=var.dtype, sharding=sharding))

medium

The logic for applying sharding within _to_variable seems to correctly handle cases where a mesh is explicitly provided in metadata or a global mesh is available. This ensures that jax.ShapeDtypeStruct is created with the appropriate sharding information.
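
The mesh selection described in this comment can be sketched outside of Flax with stable JAX APIs. This is a minimal sketch under stated assumptions, not the PR's implementation: abstract_value and its parameters are hypothetical names, the fallback mesh is passed in explicitly instead of being read from the global mesh context, and a 1x1 mesh is used so the snippet runs on a single device.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def abstract_value(shape, dtype, spec, metadata_mesh=None, fallback_mesh=None):
  """Hypothetical helper: prefer the per-variable mesh, else the fallback."""
  mesh = metadata_mesh if metadata_mesh is not None else fallback_mesh
  if mesh is None:
    # No mesh available: return a plain abstract value without sharding.
    return jax.ShapeDtypeStruct(shape=shape, dtype=dtype)
  sharding = NamedSharding(mesh, spec)
  return jax.ShapeDtypeStruct(shape=shape, dtype=dtype, sharding=sharding)


# A 1x1 mesh over the first device keeps the example runnable anywhere.
devices = np.array(jax.devices()[:1]).reshape(1, 1)
mesh = Mesh(devices, ("a", "b"))
spec = PartitionSpec("a", "b")

with_metadata = abstract_value((16, 16), jnp.float32, spec, metadata_mesh=mesh)
with_fallback = abstract_value((16, 16), jnp.float32, spec, fallback_mesh=mesh)
without_mesh = abstract_value((16, 16), jnp.float32, spec)
```

In this sketch an abstract value only gains a sharding when some mesh is known; otherwise its sharding attribute stays unset.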

Comment on lines +362 to +378
  def test_eval_shape_with_sharding0(self):
    # based on https://github.com/google/flax/issues/5110
    mesh1 = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
    mesh2 = jax.make_mesh((1, 4), ("c", "d"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))

    class Model(nnx.Module):
      def __init__(self):
        self.p1 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b"), "mesh": mesh1})
        self.p2 = nnx.Linear(16, 16, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("c", "d"), "mesh": mesh2})

    abs_model = nnx.eval_shape(lambda: Model())
    assert isinstance(abs_model.p1.kernel.sharding, jax.sharding.NamedSharding)
    assert abs_model.p1.kernel.sharding.mesh is mesh1
    assert abs_model.p1.kernel.sharding.spec == jax.P("a", "b")
    assert isinstance(abs_model.p2.kernel.sharding, jax.sharding.NamedSharding)
    assert abs_model.p2.kernel.sharding.mesh is mesh2
    assert abs_model.p2.kernel.sharding.spec == jax.P("c", "d")

medium

The test_eval_shape_with_sharding0 test case effectively verifies the sharding propagation when meshes are explicitly defined in the kernel_metadata for different linear layers. This is a good test for ensuring that custom mesh configurations are respected.
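
The test above needs four devices for its 2x2 and 1x4 meshes. As a rough single-device illustration of the same per-layer idea, the sketch below attaches a different mesh to each abstract kernel using only core JAX. The kernel_metadata dictionary mirrors the metadata keys used in the test, but the way it is consumed here is an assumption made for illustration and does not go through nnx.Linear or nnx.eval_shape.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Two 1x1 meshes with different axis names, both over the first device.
device_grid = np.array(jax.devices()[:1]).reshape(1, 1)
mesh1 = Mesh(device_grid, ("a", "b"))
mesh2 = Mesh(device_grid, ("c", "d"))

# Illustrative stand-in for per-layer kernel metadata.
kernel_metadata = {
  "p1": {"out_sharding": ("a", "b"), "mesh": mesh1},
  "p2": {"out_sharding": ("c", "d"), "mesh": mesh2},
}

# Build one abstract kernel per layer, each bound to its own mesh.
abstract_kernels = {
  name: jax.ShapeDtypeStruct(
    shape=(16, 16),
    dtype=jnp.float32,
    sharding=NamedSharding(meta["mesh"], PartitionSpec(*meta["out_sharding"])),
  )
  for name, meta in kernel_metadata.items()
}
```

Each entry keeps the axis names of its own mesh, which is the property the test asserts for p1 and p2.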

Comment on lines +380 to +390
  def test_eval_shape_with_sharding1(self):
    class Model(nnx.Module):
      def __init__(self):
        self.linear = nnx.Linear(10, 10, rngs=nnx.Rngs(0), kernel_metadata={"out_sharding": ("a", "b")})

    mesh = jax.make_mesh((2, 2), ("a", "b"), (jax.sharding.AxisType.Auto, jax.sharding.AxisType.Auto))
    with jax.set_mesh(mesh):
      abs_model = nnx.eval_shape(lambda: Model())
      assert isinstance(abs_model.linear.kernel.sharding, jax.sharding.NamedSharding)
      assert abs_model.linear.kernel.sharding.mesh is mesh
      assert abs_model.linear.kernel.sharding.spec == jax.P("a", "b")

medium

The test_eval_shape_with_sharding1 test case correctly validates sharding propagation when a global mesh is set using jax.set_mesh. This covers a common use case where sharding is implicitly applied based on the global mesh context.

Development

Successfully merging this pull request may close these issues.

nnx.eval_shape should propagate 'sharding' and 'format' to ShapeDtypeStructs
