Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Update JAX plugin to JAX 0.9.#6238

Merged
JanuszL merged 6 commits into
NVIDIA:mainfrom
JanuszL:update_jax
Mar 5, 2026
Merged

Update JAX plugin to JAX 0.9.#6238
JanuszL merged 6 commits into
NVIDIA:mainfrom
JanuszL:update_jax

Conversation

@JanuszL
Copy link
Copy Markdown
Contributor

@JanuszL JanuszL commented Feb 27, 2026

  • Bumps the minimum required JAX version from 0.4.11 to 0.6.0.
  • Replaces PositionalSharding (removed in JAX 0.9) with the Sharding
    base class in isinstance checks across iterator.py and
    fn/_function_transform.py.
  • Replaces deprecated jax.device_put_sharded with
    jax.make_array_from_single_device_arrays using a NamedSharding
    built from the per-shard devices.
  • Fixes jax_shard.device() call in _build_output_with_device_put to use
    the version-aware _jax_device helper (property in JAX >= 0.4.31).
  • Replaces PositionalSharding.shape[0] with len(sharding.device_set)
    when computing the global shape for non-NamedSharding types.
  • Bumps the JAX version in qa/setup_packages.py from 0.4.16 to 0.6.0 & 0.9.0.1.
  • Replaces all remaining PositionalSharding usages in
    test_multigpu.py with NamedSharding/Mesh equivalents and drop
    the now-unused PositionalSharding and mesh_utils imports.

Category:

Other (e.g. Documentation, Tests, Configuration)

Description:

  • Bumps the minimum required JAX version from 0.4.11 to 0.6.0.
  • Replaces PositionalSharding (removed in JAX 0.9) with the Sharding
    base class in isinstance checks across iterator.py and
    fn/_function_transform.py.
  • Replaces deprecated jax.device_put_sharded with
    jax.make_array_from_single_device_arrays using a NamedSharding
    built from the per-shard devices.
  • Fixes jax_shard.device() call in _build_output_with_device_put to use
    the version-aware _jax_device helper (property in JAX >= 0.4.31).
  • Replaces PositionalSharding.shape[0] with len(sharding.device_set)
    when computing the global shape for non-NamedSharding types.
  • Bumps the JAX version in qa/setup_packages.py from 0.4.16 to 0.6.0 & 0.9.0.1.
  • Replaces all remaining PositionalSharding usages in
    test_multigpu.py with NamedSharding/Mesh equivalents and drop
    the now-unused PositionalSharding and mesh_utils imports.

Additional information:

Affected modules and functionalities:

  • JAX plugin

Key points relevant for the review:

Tests:

  • Existing tests apply
    • test_fw_iterators.jax
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-4622

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 27, 2026

Greptile Summary

This PR successfully migrates the DALI JAX plugin from JAX 0.4.11 to 0.6.0+, addressing breaking API changes introduced in JAX 0.9. The migration involves three main changes:

  • Sharding API update: Replaces all PositionalSharding references (removed in JAX 0.9) with the Sharding base class in isinstance checks. For non-NamedSharding types, device counting now uses len(sharding.device_set) instead of sharding.shape[0].

  • Deprecated API replacement: Replaces jax.device_put_sharded with jax.make_array_from_single_device_arrays in iterator.py:_build_output_with_device_put. The implementation maintains backward compatibility by using jnp.expand_dims to add an extra dimension to each shard, preserving the stacked array behavior where arr[device_id] returns the device's shard.

  • Version-aware device access: Uses the _jax_device helper throughout to handle the .device() method vs .device property change across JAX versions (property since 0.4.31).

All tests have been updated to use NamedSharding+Mesh instead of PositionalSharding, and addressable_shards[i].data instead of the deprecated device_buffers[i] API. The test script workaround for CUDA 12/JAX version mismatch has been removed.

Minor note: PR description mentions minimum version 0.7.2, but code implements 0.6.0 (with 0.9.0.1 for CUDA 13.0). The PR description should be updated to match the implementation.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk - it correctly handles JAX API deprecations with proper backward compatibility
  • The migration is well-executed with comprehensive test coverage. All three core changes (PositionalSharding removal, device_put_sharded replacement, version-aware device access) are implemented correctly. The dimension expansion logic in _build_output_with_device_put maintains the expected pmap-compatible behavior. Tests verify both manual sharding and iterator-based workflows work correctly.
  • No files require special attention

Important Files Changed

Filename Overview
dali/python/nvidia/dali/plugin/jax/iterator.py Replaces deprecated jax.device_put_sharded with jax.make_array_from_single_device_arrays, uses _jax_device helper for version compatibility
dali/test/python/jax_plugin/test_multigpu.py Comprehensive test updates: PositionalSharding to NamedSharding migration, device_buffers to addressable_shards, uses _jax_device helper
qa/setup_packages.py Updates JAX versions: 0.6.0 for CUDA 12.0, 0.9.0.1 for CUDA 13.0; removes numpy<2 and jax version constraints from flax/clu

Last reviewed commit: ddc3d89

Copy link
Copy Markdown
Contributor

@greptile-apps greptile-apps Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

5 files reviewed, 6 comments

Edit Code Review Agent Settings | Greptile

Comment thread qa/setup_packages.py Outdated
@JanuszL JanuszL force-pushed the update_jax branch 2 times, most recently from 0c77355 to 43558b8 Compare February 27, 2026 11:41
@JanuszL
Copy link
Copy Markdown
Contributor Author

JanuszL commented Feb 27, 2026

@greptileai can you re-rereview?

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 27, 2026

Additional Comments (5)

dali/test/python/jax_plugin/test_multigpu.py, line 26
PositionalSharding was removed in JAX 0.9. This import will cause runtime errors.

from jax.sharding import NamedSharding, PartitionSpec, Mesh

dali/test/python/jax_plugin/test_multigpu.py, line 303
PositionalSharding is removed in JAX 0.9. Replace with NamedSharding using a Mesh.

    mesh = Mesh(jax.devices(), axis_names=("device",))
    sharding = NamedSharding(mesh, PartitionSpec("device"))

dali/test/python/jax_plugin/test_multigpu.py, line 317
PositionalSharding is removed in JAX 0.9. Replace with NamedSharding.

    sharding = NamedSharding(Mesh(mesh, axis_names=("device",)), PartitionSpec("device"))

dali/test/python/jax_plugin/test_multigpu.py, line 409
PositionalSharding is removed in JAX 0.9. Replace with NamedSharding.

    sharding = NamedSharding(Mesh(mesh, axis_names=("device",)), PartitionSpec("device"))

dali/test/python/jax_plugin/test_multigpu.py, line 520
PositionalSharding is removed in JAX 0.9. Replace with NamedSharding.

    sharding = NamedSharding(Mesh(mesh, axis_names=("device",)), PartitionSpec("device"))

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 27, 2026

Additional Comments (1)

dali/python/nvidia/dali/plugin/jax/iterator.py, line 15
Check that test files test_multigpu.py and jax_server.py still work - they import PositionalSharding which was removed in JAX 0.9

@JanuszL JanuszL force-pushed the update_jax branch 2 times, most recently from f912b1f to be0b8a4 Compare February 27, 2026 11:57
@JanuszL
Copy link
Copy Markdown
Contributor Author

JanuszL commented Feb 27, 2026

@greptileai can you re-rereview?

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 27, 2026

Additional Comments (2)

dali/test/python/jax_plugin/test_multigpu.py, line 108
jax.device_put_sharded is deprecated and removed in JAX 0.9 - replace with jax.make_array_from_single_device_arrays like in iterator.py

        # Build sharded array using the same approach as in iterator.py
        devices_arr = np.array([jax_shard_0.device(), jax_shard_1.device()])
        mesh = jax.sharding.Mesh(devices_arr, axis_names=("device",))
        sharding = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec("device"))
        global_shape = (len([jax_shard_0, jax_shard_1]) * jax_shard_0.shape[0], *jax_shard_0.shape[1:])
        jax_array = jax.make_array_from_single_device_arrays(
            global_shape, sharding, [jax_shard_0, jax_shard_1]
        )

dali/test/python/jax_plugin/test_multigpu.py, line 103
.device() method deprecated in JAX >= 0.4.31 - use _jax_device helper or .device property for consistency with iterator.py:207

@JanuszL
Copy link
Copy Markdown
Contributor Author

JanuszL commented Feb 27, 2026

@greptileai can you re-rereview?

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44967529]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44975089]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44984317]: BUILD STARTED

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Feb 27, 2026

Additional Comments (3)

dali/test/python/jax_plugin/test_multigpu.py, line 162
Stale .device() method calls missed during migration

Earlier in this same function (lines 103–104) the PR correctly replaced jax_shard_0.device() / jax_shard_1.device() with _jax_device(...) because .device became a read-only property in JAX ≥ 0.4.31. The assertions on lines 161–162 were not updated: calling .device() will invoke __call__ on the returned Device object, which is not callable, and will raise TypeError at runtime on any JAX ≥ 0.4.31 (including the targeted 0.9.x).

Additionally, device_buffers is a long-deprecated attribute that may be absent in JAX 0.9; the preferred replacement is addressable_shards.

        assert _jax_device(jax_array.addressable_shards[0].data) == _jax_device(jax_shard_0)
        assert _jax_device(jax_array.addressable_shards[1].data) == _jax_device(jax_shard_1)

The same device_buffers access pattern also appears at lines 142 and 151 — those only read values (not devices) so they may be migrated to jax_array.addressable_shards[i].data as well.


qa/setup_packages.py, line 672
flax and clu still pin jax==0.4.16, conflicting with new minimum JAX 0.7.2

Both flax 0.7.4 and clu 0.0.12 carry constraints=["jax==0.4.16"]. Now that the minimum required JAX is 0.7.2 (and the tested version is 0.9.0.1), any environment that installs these packages alongside the new JAX will hit an unsatisfiable constraint — the solver will either refuse to install or silently downgrade JAX to 0.4.16, breaking the JAX plugin.

These entries should either be updated to compatible versions that support JAX 0.9, or their constraints should be relaxed / removed if the packages are now dropped from the test matrix.


dali/python/nvidia/dali/plugin/jax/iterator.py, line 231
Output shape semantics changed for the pmap-compatible path

jax.device_put_sharded(shards, devices) produced a global array of shape (N, *shard_shape) — matching the pmap convention where the leading axis is the device axis. The replacement using make_array_from_single_device_arrays with global_shape = (N * shard_shape[0], *shard_shape[1:]) produces (N·B, H, W) instead.

For any caller that passes the iterator output directly into jax.pmap, the leading-axis size changes from N to N·B, which would cause pmap to attempt to use N·B devices and fail. The docstring still reads "compatible with pmapped JAX functions" — consider updating it to reflect the new semantics, or add a note that pmap usage is no longer the primary intended consumption pattern (consistent with pmap being deprecated in JAX 0.9).

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44986533]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44986533]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44967529]: BUILD FAILED

- Bump the minimum required JAX version from 0.4.11 to 0.7.2.
- Replace `PositionalSharding` (removed in JAX 0.9) with the `Sharding`
  base class in isinstance checks across `iterator.py` and
  `fn/_function_transform.py`.
- Replace deprecated `jax.device_put_sharded` with
  `jax.make_array_from_single_device_arrays` using a `NamedSharding`
  built from the per-shard devices.
- Fix `jax_shard.device()` call in `_build_output_with_device_put` to use
  the version-aware `_jax_device` helper (property in JAX >= 0.4.31).
- Replace `PositionalSharding.shape[0]` with `len(sharding.device_set)`
  when computing the global shape for non-`NamedSharding` types.
- Bump the JAX version in `qa/setup_packages.py` from 0.4.16 to 0.9.0.1.
- Replace all remaining `PositionalSharding` usages in `test_multigpu.py`
  and `jax_server.py` with `NamedSharding`/`Mesh` equivalents and drop
  the now-unused `PositionalSharding` and `mesh_utils` imports.

Signed-off-by: Janusz Lisiecki <[email protected]>
@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44996910]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [44996910]: BUILD FAILED

Signed-off-by: Janusz Lisiecki <[email protected]>
@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45049610]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45049783]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45049783]: BUILD PASSED

Signed-off-by: Janusz Lisiecki <[email protected]>
@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45082620]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45082620]: BUILD FAILED

JanuszL added 2 commits March 2, 2026 08:40
Replace deprecated `device_buffers` with `addressable_shards` in
qa/TL3_JAX_multiprocess/jax_server.py, consistent with the fixes
applied to test_multigpu.py.

Signed-off-by: Janusz Lisiecki <[email protected]>
Signed-off-by: Janusz Lisiecki <[email protected]>
@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45118272]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45118316]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45118536]: BUILD STARTED

@mzient
Copy link
Copy Markdown
Contributor

mzient commented Mar 2, 2026

qa/TL0_multigpu/test_body.sh:58 still contains
python -m pip install --upgrade "jax[cuda12_pip]==0.4.26" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html which breaks the tests

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45145570]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45145786]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45145796]: BUILD STARTED

@greptile-apps
Copy link
Copy Markdown
Contributor

greptile-apps Bot commented Mar 2, 2026

Additional Comments (1)

qa/TL0_multigpu/test_body.sh, line 52
JAX 0.4.26 is below the new minimum requirement of 0.6.0 set in __init__.py. Tests will fail with assertion error when importing the JAX plugin.

    python -m pip install --upgrade "jax[cuda12_pip]>=0.6.0" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Signed-off-by: Janusz Lisiecki <[email protected]>
@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45153895]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45154138]: BUILD STARTED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45154138]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45153895]: BUILD FAILED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45154138]: BUILD PASSED

@dali-automaton
Copy link
Copy Markdown
Collaborator

CI MESSAGE: [45153895]: BUILD PASSED

@JanuszL JanuszL merged commit ff8655b into NVIDIA:main Mar 5, 2026
6 checks passed
@JanuszL JanuszL deleted the update_jax branch March 5, 2026 18:12
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants