Fix NNX jit static args with in_shardings issue #4989 #4996
@mohsinm-dev how about doing the following:

```python
from jax._src import api_util  # We use fun_signature and resolve_argnums

class JitWrapped(...):
  def __init__(self, ...):
    ...
    if isinstance(in_shardings, (list, tuple)):
      # We should reintroduce None values into in_shardings corresponding to static arguments
      fun_signature = api_util.fun_signature(fun)
      _, _, static_argnums, _ = api_util.resolve_argnums(
          fun,
          fun_signature,
          None,
          None,
          static_argnums,
          static_argnames,
      )
      in_shardings = list(in_shardings)
      for static_arg_index in sorted(static_argnums):
        in_shardings.insert(static_arg_index, None)
      in_shardings = tuple(in_shardings)

    self.jitted_fn = jax.jit(
        JitFn(fun, in_shardings, out_shardings, kwarg_shardings, self),
        in_shardings=self.jax_in_shardings,
        out_shardings=(None, None, self.jax_out_shardings),
        ...
    )
```

(full code: https://github.com/vfdev-5/flax/pull/new/nnx-jit-static-args)

The idea is to rely on JAX's own functions for argument manipulation. For the case when … We should certainly write various test cases. What do you think?
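A standalone sketch of the None-insertion loop from the suggestion, runnable without Flax or JAX. The helper name insert_static_placeholders is hypothetical, and it assumes static_argnums has already been resolved to non-negative positional indices (the job resolve_argnums does in the suggestion). Strings stand in for real Sharding objects:

```python
from typing import Any, Sequence


def insert_static_placeholders(
    in_shardings: Sequence[Any], static_argnums: Sequence[int]
) -> tuple[Any, ...]:
  """Hypothetical helper: re-insert None at each static argument position.

  in_shardings holds one entry per dynamic positional argument; the result
  holds one entry per positional argument, with None at static positions.
  Assumes non-negative, already-resolved indices.
  """
  padded = list(in_shardings)
  # Ascending order: each insert only shifts entries after it, so every
  # None ends up at its final index.
  for idx in sorted(static_argnums):
    padded.insert(idx, None)
  return tuple(padded)


# Shardings for two dynamic args, static args at positions 0 and 2.
print(insert_static_placeholders(("a", "b"), [2, 0]))  # (None, 'a', None, 'b')
```

Sorting is what makes the loop correct: inserting at index 2 before index 0 would give (None, 'a', 'b', None), because the later insert at 0 pushes the earlier None one slot to the right.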
@vfdev-5 Thanks for the suggestion. The current approach works as well, but, as you suggested, relying on JAX's own argument resolution is better: it keeps us consistent with JAX's behavior in edge cases such as mixed static_argnums and static_argnames. I've added 5 detailed test cases covering static args, edge cases, and StateSharding. Let me know if you want more scenarios or if any seem redundant.
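To illustrate what mixing the two forms involves, here is a deliberately naive sketch that maps static_argnames onto positional indices with inspect.signature and takes the union with static_argnums. The helper resolve_static_positions is hypothetical (non-negative indices only, keyword-only names ignored) and does not reproduce JAX's exact rules, which is the argument for delegating to api_util.resolve_argnums instead:

```python
import inspect
from typing import Callable, Iterable


def resolve_static_positions(
    fun: Callable,
    static_argnums: Iterable[int] = (),
    static_argnames: Iterable[str] = (),
) -> tuple[int, ...]:
  """Hypothetical helper: union of static_argnums and the positions of
  static_argnames, as a sorted tuple. Simplified; not JAX's algorithm."""
  positional = [
      p.name
      for p in inspect.signature(fun).parameters.values()
      if p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                    inspect.Parameter.POSITIONAL_OR_KEYWORD)
  ]
  # Assumes non-negative indices that refer to positional parameters.
  positions = set(static_argnums)
  for name in static_argnames:
    # Keyword-only names have no positional slot, so they are skipped.
    if name in positional:
      positions.add(positional.index(name))
  return tuple(sorted(positions))


def f(x, mode, y, *, scale=1.0):
  return x, mode, y, scale


print(resolve_static_positions(f, [1], ["mode"]))       # (1,)
print(resolve_static_positions(f, (), ["y", "scale"]))  # (2,)
```

Naming the same argument both ways collapses to one position, and a keyword-only static argument contributes no positional slot, so no None placeholder would be needed for it.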
vfdev-5 left a comment
Thanks for working on this PR @mohsinm-dev!
I left a few more comments.
Please also consider squashing the commit history into at most 4 commits.
1dd4306 to 7a86859
@mohsinm-dev please fix the failing test: https://github.com/google/flax/actions/runs/18873989510/job/53861798191?pr=4996
- Add proper static argument resolution using api_util.resolve_argnums
- Consolidate 10 individual test cases into 2 parametrized tests
- Add StateSharding test for static arguments
- Ensure in_shardings correctly handles static argument positions
Removes trailing whitespace that was causing pre-commit hook failures. No functional changes; formatting cleanup only.
d4cdec4 to 38285f7
38285f7 to dcf120b
Problem
When using nnx.jit with static arguments and in_shardings, users encountered errors:
Solution
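Fixed the sharding handling logic so that in_shardings lines up with the full positional argument list: None placeholders are inserted at the static argument positions before the shardings are used to build the function passed to JAX's jit.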
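A toy model of the misalignment the fix addresses, assuming the wrapper pairs in_shardings with every positional argument (static ones included) by index, while the user supplies shardings for the dynamic arguments only. pair_shardings is a hypothetical stand-in, not Flax's code, and the error it raises is illustrative rather than the one reported in the issue:

```python
from typing import Any, Sequence


def pair_shardings(args: Sequence[Any], in_shardings: Sequence[Any]) -> list:
  """Hypothetical positional matcher: one sharding per positional arg."""
  if len(args) != len(in_shardings):
    raise ValueError(
        f"got {len(in_shardings)} shardings for {len(args)} positional args"
    )
  return list(zip(args, in_shardings))


args = ("config", "x", "y")  # position 0 is a static argument
user_shardings = ("sx", "sy")  # shardings for the dynamic args only

try:
  pair_shardings(args, user_shardings)
except ValueError as err:
  print("unpadded:", err)  # unpadded: got 2 shardings for 3 positional args

padded = (None,) + user_shardings  # None placeholder at static position 0
print(pair_shardings(args, padded))
# [('config', None), ('x', 'sx'), ('y', 'sy')]
```

With the placeholder in place, the static argument is paired with None and each dynamic argument gets the sharding the user intended for it.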
Changes
- Updated JitWrapped.__init__ in compilation.py to add None entries to in_shardings at static argument positions
Testing
- Added tests/nnx/jit_static_sharding_test.py with tests for both failing scenarios
Fixes #4989