Open
Labels
help wanted (Extra attention is needed); type / bug (Issue type: something isn't working)
Description
🐛 Bug
Tracking a number stored as a JAX array with Run.track() fails with jax>=0.6.0.
To reproduce
from aim import Run
import jax.numpy as jnp

run = Run()
run.track(jnp.array(1.0), name='jax array', context={})
raises:
ValueError: Input type <class 'jaxlib._jax.ArrayImpl'> is neither python number nor AimObject
Expected behavior
If the array holds a single number, Aim should be able to log it, as it did with earlier versions of JAX.
Environment
- Aim Version (e.g., 3.0.1)
- Python version: 3.10.17
- pip version: 25.1.1
- OS: macOS
- JAX version: 0.6.1
Additional context
JAX recently deprecated the jaxlib.xla_extension API. As a result, the is_jax_device_array(inst) check, as currently implemented, does not recognize JAX arrays for jax>=0.6.0, so a scalar JAX array falls through to the "neither python number nor AimObject" error.
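A minimal sketch of a version-robust check, assuming the fix is to detect arrays through the public jax.Array type instead of classes in private jaxlib modules. The helper names is_jax_array and to_python_number are hypothetical: they are not part of Aim's API and this is not the project's actual fix.

```python
import numbers
from typing import Any


def is_jax_array(value: Any) -> bool:
    """Hypothetical helper: report whether `value` is a JAX array.

    Checks against the public jax.Array type rather than classes in
    private jaxlib modules (e.g. jaxlib.xla_extension), whose layout
    changed in jax 0.6 (the error above shows jaxlib._jax.ArrayImpl).
    """
    try:
        import jax  # optional dependency; if missing, nothing is a JAX array
    except ImportError:
        return False
    return isinstance(value, jax.Array)


def to_python_number(value: Any) -> numbers.Number:
    """Hypothetical helper: return a plain Python number for `value`.

    Python numbers pass through unchanged; a JAX array holding exactly
    one element is converted with .item(). Anything else is rejected.
    """
    if isinstance(value, numbers.Number):
        return value
    if is_jax_array(value) and value.size == 1:
        return value.item()
    raise ValueError(
        f"Input type {type(value)} is neither a Python number "
        "nor a single-element JAX array"
    )
```

With JAX installed, to_python_number(jnp.array(1.0)) should return the Python float 1.0. Until a fix lands, the same conversion can be done at the call site as a workaround, e.g. run.track(x.item(), name='jax array', context={}), assuming x holds exactly one element.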