
Conversation

@mohsinm-dev
Contributor

@mohsinm-dev mohsinm-dev commented Oct 3, 2025

Problem

When using nnx.jit with static arguments and in_shardings, users encountered errors:

  1. Not providing shardings for static args → "Tuple arity mismatch" error
  2. Providing shardings for static args → "pjit in_shardings specification must be a tree prefix" error

Solution

Fixed the sharding handling logic to properly filter out shardings for static argument positions before passing them to JAX's jit function.

Changes

  • Modified JitWrapped.__init__ in compilation.py to add None into shardings for static positions
  • Added comprehensive tests to verify both scenarios work correctly

Testing

  • Added tests/nnx/jit_static_sharding_test.py with tests for both failing scenarios
  • Verified no regressions in existing functionality

Fixes #4989

@vfdev-5
Collaborator

vfdev-5 commented Oct 21, 2025

@mohsinm-dev how about doing the following:

from jax._src import api_util  # We use fun_signature and resolve_argnums

class JitWrapped(...):
  def __init__(self, ...):
    ...
    if isinstance(in_shardings, (list, tuple)):
      # We should reintroduce None values into in_shardings corresponding to static arguments
      fun_signature = api_util.fun_signature(fun)
      _, _, static_argnums, _ = api_util.resolve_argnums(
          fun,
          fun_signature,
          None,
          None,
          static_argnums,
          static_argnames,
      )
      in_shardings = list(in_shardings)
      for static_arg_index in sorted(static_argnums):
        in_shardings.insert(static_arg_index, None)
      in_shardings = tuple(in_shardings)

    self.jitted_fn = jax.jit(
      JitFn(fun, in_shardings, out_shardings, kwarg_shardings, self),
      in_shardings=self.jax_in_shardings,
      out_shardings=(None, None, self.jax_out_shardings),
      ...

(full code: https://github.com/vfdev-5/flax/pull/new/nnx-jit-static-args)

The idea is to rely on JAX's own functions for argument manipulation.

When in_shardings is not a tuple but a single sharding value, we can pass it as is, without converting it to a tuple or inserting None for static args.
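The reinsertion step can be sketched in plain Python, independent of JAX. This is a minimal illustration, not the merged implementation: the helper name `reinsert_static_nones` is hypothetical, the strings stand in for real sharding objects, and the static positions are assumed to be already resolved to non-negative indices.

```python
def reinsert_static_nones(in_shardings, static_argnums):
    """Return in_shardings with None placed at each static argument position.

    Hypothetical helper for illustration. A single (non-sequence) sharding
    value is returned unchanged, as described above.
    """
    if not isinstance(in_shardings, (list, tuple)):
        return in_shardings
    padded = list(in_shardings)
    # Inserting in ascending order keeps later indices valid, because each
    # insert shifts only the entries to its right.
    for index in sorted(static_argnums):
        padded.insert(index, None)
    return tuple(padded)


# Placeholder strings stand in for sharding objects.
one_static = reinsert_static_nones(("shard_x", "shard_y"), static_argnums=(1,))
two_static = reinsert_static_nones(("shard_a", "shard_b"), static_argnums=(2, 0))
single = reinsert_static_nones("shard_all", static_argnums=(1,))
```

Sorting matters here: with static positions 0 and 2, inserting position 2 first and then position 0 would push the earlier None to index 3.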

We certainly should write various test cases.

What do you think?

@mohsinm-dev
Contributor Author

mohsinm-dev commented Oct 22, 2025

@vfdev-5 Thanks for the suggestion. The current approach works as well, but, as you suggested, relying on JAX's own utilities is better: we stay consistent with JAX's behavior for edge cases such as mixed static_argnums and static_argnames.
I've implemented your approach and added comprehensive tests, all of which passed successfully.

I've added 5 detailed test cases covering static args, edge cases, and StateSharding. Let me know if you want more scenarios or if any seem redundant.
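To show why the mixed static_argnums and static_argnames case needs care, here is a rough standard-library sketch that maps both kinds of specification to positional indices. It is a simplified stand-in and does not reproduce the rules of api_util.resolve_argnums, which may differ (which is the reason for delegating to JAX). The helper `resolve_static_positions` and the function `forward` are hypothetical names.

```python
import inspect


def resolve_static_positions(fun, static_argnums=(), static_argnames=()):
    """Map static argnums and argnames to sorted positional indices.

    Illustrative simplification only; JAX's own resolution logic is not
    reproduced here.
    """
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    positional = [
        name
        for name, param in inspect.signature(fun).parameters.items()
        if param.kind in positional_kinds
    ]
    # Normalize negative indices against the number of positional parameters.
    positions = {index % len(positional) for index in static_argnums}
    # Names that are not positional parameters have no position to fill.
    positions |= {
        positional.index(name) for name in static_argnames if name in positional
    }
    return tuple(sorted(positions))


def forward(x, scale, y, use_bias):
    """Hypothetical function with two array args and two static args."""
    return x, scale, y, use_bias


static_positions = resolve_static_positions(
    forward, static_argnums=(1,), static_argnames=("use_bias",)
)
```

Under these assumptions `scale` and `use_bias` resolve to positions 1 and 3, so a two-element in_shardings tuple for `x` and `y` would need None at both of those positions.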

Collaborator

@vfdev-5 vfdev-5 left a comment

Thanks for working on this PR, @mohsinm-dev!
I left a few more comments.
Please also consider squashing the commit history into at most 4 commits.

@mohsinm-dev mohsinm-dev force-pushed the fix-nnx-jit-static-sharding branch from 1dd4306 to 7a86859 Compare October 27, 2025 15:51
@vfdev-5
Collaborator

vfdev-5 commented Oct 29, 2025

@vfdev-5 vfdev-5 requested a review from cgarciae October 31, 2025 13:56
- Add proper static argument resolution using api_util.resolve_argnums
- Consolidate 10 individual test cases into 2 parametrized tests
- Add StateSharding test for static arguments
- Ensure in_shardings correctly handles static argument positions
Removes trailing whitespace that was causing pre-commit hook failures.
No functional changes - only formatting cleanup.
@vfdev-5 vfdev-5 force-pushed the fix-nnx-jit-static-sharding branch from d4cdec4 to 38285f7 Compare November 4, 2025 15:46
@vfdev-5 vfdev-5 force-pushed the fix-nnx-jit-static-sharding branch from 38285f7 to dcf120b Compare November 4, 2025 16:22
@copybara-service copybara-service bot merged commit d3754f6 into google:main Nov 6, 2025
18 checks passed

Development

Successfully merging this pull request may close these issues.

NNX JIT fails with static args and in shardings

4 participants