Detection of Jax Arrays Breaks on Jax=0.4.* #2535

@EntilZha

Description

The detection for JAX arrays here:

    def is_jax_device_array(inst):

is broken in JAX 0.4.* because DeviceArray was renamed to Array. A simple fix that supports both the prior and the current versions would be to check for either DeviceArray or Array.
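A minimal sketch of the suggested fix follows. It detects JAX arrays by class name without importing jax, so it keeps working whether or not JAX is installed. The set of accepted class names is an assumption. DeviceArray and Array come from the issue, and ArrayImpl (the concrete array class in JAX 0.4.*) is an extra guess. The check also walks the MRO so subclasses match. This is not the project's actual implementation.

```python
# Class names treated as JAX arrays (assumption): "DeviceArray" for JAX < 0.4,
# "Array" for JAX 0.4.*, plus "ArrayImpl" as a guess at the 0.4.* concrete type.
_JAX_ARRAY_CLASS_NAMES = {"DeviceArray", "Array", "ArrayImpl"}
_JAX_TOP_LEVEL_MODULES = {"jax", "jaxlib"}


def is_jax_device_array(inst) -> bool:
    """Heuristically check whether inst is a JAX array, old or new API."""
    # Walk the class and all its base classes so subclasses are detected too.
    for cls in type(inst).__mro__:
        top_level = (cls.__module__ or "").split(".")[0]
        # Require both a JAX-owned module and a known array class name, so an
        # unrelated class that happens to be called "Array" is not matched.
        if top_level in _JAX_TOP_LEVEL_MODULES and cls.__name__ in _JAX_ARRAY_CLASS_NAMES:
            return True
    return False
```

Matching on names instead of calling isinstance(inst, jax.Array) avoids a hard dependency on jax and avoids version-specific imports. The trade-off is that the name list must be updated if JAX renames its array types again.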

Metadata

Assignees

No one assigned

Labels

area / integrations (Issue area: integrations with other tools and libs)
help wanted (Extra attention is needed)
type / bug (Issue type: something isn't working)

Type

No type

Projects

Status

Next-up

Relationships

None yet

Development

No branches or pull requests