
Run tracking a number stored as a jax array is not compatible with jax>=0.6.0 #3343

@blaserethan

Description

πŸ› Bug

Calling Run.track on a number stored as a JAX array fails with jax>=0.6.0.

To reproduce

from aim import Run
import jax.numpy as jnp
run = Run()
run.track(jnp.array(1.0), name='jax array', context={})

raises:

ValueError: Input type <class 'jaxlib._jax.ArrayImpl'> is neither python number nor AimObject

Expected behavior

If the array holds a single number, Aim should be able to log it as it did with earlier versions of JAX.
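
A minimal sketch of the expected behavior, assuming that any array-like object exposing a NumPy-style `size` attribute and `item()` method (as JAX and NumPy arrays do) should be unwrapped to a plain Python number before it is tracked. `to_python_number` is a hypothetical helper, not part of Aim's API. On the user side, converting the value yourself (for example `run.track(float(x), ...)`) sidesteps the type check in the meantime.

```python
from numbers import Number
from typing import Any


def to_python_number(value: Any) -> Any:
    """Unwrap a single-element array-like into a plain Python number.

    Hypothetical helper (not Aim's API): it relies only on the NumPy-style
    `size` attribute and `item()` method, which JAX and NumPy arrays provide.
    Anything else is returned unchanged so the caller can decide what to do.
    """
    if isinstance(value, Number):
        # Already a Python number (int, float, bool, ...): nothing to unwrap.
        return value
    if getattr(value, "size", None) == 1 and callable(getattr(value, "item", None)):
        # Single-element array such as jnp.array(1.0): convert via item().
        return value.item()
    # Multi-element arrays and unrelated objects pass through untouched.
    return value
```

With JAX installed, `to_python_number(jnp.array(1.0))` would hand `run.track` a plain `float` instead of a `jaxlib._jax.ArrayImpl`. Duck typing keeps the helper independent of where a given JAX release happens to define its concrete array class.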

Environment

  • Aim Version (e.g., 3.0.1)
  • Python version 3.10.17
  • pip version 25.1.1
  • OS: macOS
  • Jax version 0.6.1

Additional context

JAX recently deprecated the jaxlib.xla_extension API, so Aim's is_jax_device_array(inst) check, as currently implemented, does not recognize JAX arrays for jax>=0.6.0 (their concrete type is now jaxlib._jax.ArrayImpl, as the error above shows).
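
One possible direction for a fix, sketched under assumptions: rather than referencing array types from the deprecated `jaxlib.xla_extension` module, check against the public `jax.Array` type, which recent JAX releases (including 0.6.x) expose for isinstance checks. The name `is_jax_array` is hypothetical and this is not Aim's actual implementation. It looks JAX up in `sys.modules` instead of importing it, so it never pulls in JAX just to answer "no".

```python
import sys
from typing import Any


def is_jax_array(obj: Any) -> bool:
    """Return True if `obj` is a JAX array, without touching jaxlib internals.

    Hypothetical replacement for a check built on jaxlib.xla_extension.
    """
    jax = sys.modules.get("jax")
    if jax is None:
        # JAX was never imported, so `obj` cannot be a JAX array.
        return False
    array_type = getattr(jax, "Array", None)
    if array_type is None:
        # Very old JAX without the public jax.Array type: treat as not an array.
        return False
    # jax.Array is the public base type; concrete classes such as
    # jaxlib._jax.ArrayImpl are recognized through it.
    return isinstance(obj, array_type)
```

Because the check goes through the public type, it should keep working when JAX moves its concrete array class between private modules again.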

Metadata

Assignees

No one assigned

    Labels

    help wanted (Extra attention is needed), type / bug (Issue type: something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests