Update JAX plugin to JAX 0.9.#6238
Conversation
Greptile SummaryThis PR successfully migrates the DALI JAX plugin from JAX 0.4.11 to 0.6.0+, addressing breaking API changes introduced in JAX 0.9. The migration involves three main changes:
All tests have been updated to use Minor note: PR description mentions minimum version 0.7.2, but code implements 0.6.0 (with 0.9.0.1 for CUDA 13.0). The PR description should be updated to match the implementation. Confidence Score: 5/5
Important Files Changed
Last reviewed commit: ddc3d89 |
0c77355 to
43558b8
Compare
|
@greptileai can you re-rereview? |
Additional Comments (5)
|
Additional Comments (1)
|
f912b1f to
be0b8a4
Compare
|
@greptileai can you re-rereview? |
Additional Comments (2)
|
|
@greptileai can you re-rereview? |
|
CI MESSAGE: [44967529]: BUILD STARTED |
|
CI MESSAGE: [44975089]: BUILD STARTED |
|
CI MESSAGE: [44984317]: BUILD STARTED |
Additional Comments (3)
Earlier in this same function (lines 103–104) the PR correctly replaced Additionally, The same
Both These entries should either be updated to compatible versions that support JAX 0.9, or their constraints should be relaxed / removed if the packages are now dropped from the test matrix.
For any caller that passes the iterator output directly into |
|
CI MESSAGE: [44986533]: BUILD STARTED |
|
CI MESSAGE: [44986533]: BUILD FAILED |
|
CI MESSAGE: [44967529]: BUILD FAILED |
- Bump the minimum required JAX version from 0.4.11 to 0.7.2. - Replace `PositionalSharding` (removed in JAX 0.9) with the `Sharding` base class in isinstance checks across `iterator.py` and `fn/_function_transform.py`. - Replace deprecated `jax.device_put_sharded` with `jax.make_array_from_single_device_arrays` using a `NamedSharding` built from the per-shard devices. - Fix `jax_shard.device()` call in `_build_output_with_device_put` to use the version-aware `_jax_device` helper (property in JAX >= 0.4.31). - Replace `PositionalSharding.shape[0]` with `len(sharding.device_set)` when computing the global shape for non-`NamedSharding` types. - Bump the JAX version in `qa/setup_packages.py` from 0.4.16 to 0.9.0.1. - Replace all remaining `PositionalSharding` usages in `test_multigpu.py` and `jax_server.py` with `NamedSharding`/`Mesh` equivalents and drop the now-unused `PositionalSharding` and `mesh_utils` imports. Signed-off-by: Janusz Lisiecki <[email protected]>
|
CI MESSAGE: [44996910]: BUILD STARTED |
|
CI MESSAGE: [44996910]: BUILD FAILED |
Signed-off-by: Janusz Lisiecki <[email protected]>
|
CI MESSAGE: [45049610]: BUILD FAILED |
|
CI MESSAGE: [45049783]: BUILD STARTED |
|
CI MESSAGE: [45049783]: BUILD PASSED |
Signed-off-by: Janusz Lisiecki <[email protected]>
|
CI MESSAGE: [45082620]: BUILD STARTED |
|
CI MESSAGE: [45082620]: BUILD FAILED |
Replace deprecated `device_buffers` with `addressable_shards` in qa/TL3_JAX_multiprocess/jax_server.py, consistent with the fixes applied to test_multigpu.py. Signed-off-by: Janusz Lisiecki <[email protected]>
Signed-off-by: Janusz Lisiecki <[email protected]>
|
CI MESSAGE: [45118272]: BUILD FAILED |
|
CI MESSAGE: [45118316]: BUILD STARTED |
|
CI MESSAGE: [45118536]: BUILD STARTED |
|
|
|
CI MESSAGE: [45145570]: BUILD FAILED |
|
CI MESSAGE: [45145786]: BUILD STARTED |
|
CI MESSAGE: [45145796]: BUILD STARTED |
Additional Comments (1)
|
Signed-off-by: Janusz Lisiecki <[email protected]>
|
CI MESSAGE: [45153895]: BUILD STARTED |
|
CI MESSAGE: [45154138]: BUILD STARTED |
|
CI MESSAGE: [45154138]: BUILD FAILED |
|
CI MESSAGE: [45153895]: BUILD FAILED |
|
CI MESSAGE: [45154138]: BUILD PASSED |
|
CI MESSAGE: [45153895]: BUILD PASSED |
PositionalSharding(removed in JAX 0.9) with theShardingbase class in isinstance checks across
iterator.pyandfn/_function_transform.py.jax.device_put_shardedwithjax.make_array_from_single_device_arraysusing aNamedShardingbuilt from the per-shard devices.
jax_shard.device()call in_build_output_with_device_putto usethe version-aware
_jax_devicehelper (property in JAX >= 0.4.31).PositionalSharding.shape[0]withlen(sharding.device_set)when computing the global shape for non-
NamedShardingtypes.qa/setup_packages.pyfrom 0.4.16 to 0.6.0 & 0.9.0.1.PositionalShardingusages intest_multigpu.pywithNamedSharding/Meshequivalents and dropthe now-unused
PositionalShardingandmesh_utilsimports.Category:
Other (e.g. Documentation, Tests, Configuration)
Description:
PositionalSharding(removed in JAX 0.9) with theShardingbase class in isinstance checks across
iterator.pyandfn/_function_transform.py.jax.device_put_shardedwithjax.make_array_from_single_device_arraysusing aNamedShardingbuilt from the per-shard devices.
jax_shard.device()call in_build_output_with_device_putto usethe version-aware
_jax_devicehelper (property in JAX >= 0.4.31).PositionalSharding.shape[0]withlen(sharding.device_set)when computing the global shape for non-
NamedShardingtypes.qa/setup_packages.pyfrom 0.4.16 to 0.6.0 & 0.9.0.1.PositionalShardingusages intest_multigpu.pywithNamedSharding/Meshequivalents and dropthe now-unused
PositionalShardingandmesh_utilsimports.Additional information:
Affected modules and functionalities:
Key points relevant for the review:
Tests:
Checklist
Documentation
DALI team only
Requirements
REQ IDs: N/A
JIRA TASK: DALI-4622