πŸ› Bug Report: GPU matmul fails with INTERNAL: an unsupported value or parameter was passed to the functionΒ #32168

@Ascarshen

Description

A minimal JAX program fails on a GPU matmul with XlaRuntimeError: INTERNAL: an unsupported value or parameter was passed to the function.

Repro code:

```python
import jax
import jax.numpy as jnp

print("Devices:", jax.devices())

# Jitted matmul of two 256x256 float32 matrices of ones
f = jax.jit(lambda x, y: x @ y)
a = jnp.ones((256, 256), jnp.float32)
b = a
print(f(a, b).block_until_ready())
```

Output:

```
Devices: [CudaDevice(id=0)]
jaxlib._jax.XlaRuntimeError: INTERNAL: an unsupported value or parameter was passed to the function
```
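One way to separate a GPU-specific failure from a problem in the program itself is to run the same jitted matmul on each available backend and check the [0, 0] element: with all-ones inputs, every output element is a sum of n products of 1.0, so the correct value is exactly n (256.0 for the repro). The sketch below assumes only standard JAX calls (jax.devices, jax.device_put, jax.jit); the helper name check_matmul and the sizes swept are hypothetical choices for illustration.

```python
import jax
import jax.numpy as jnp


def check_matmul(device, n=256, dtype=jnp.float32):
    """Run a jitted n x n all-ones matmul on `device`.

    Hypothetical diagnostic helper (not part of JAX). Returns None on
    success, otherwise a short description of the failure.
    """
    # device_put commits the input to `device`, so jit runs there.
    x = jax.device_put(jnp.ones((n, n), dtype), device)
    try:
        out = jax.jit(lambda a, b: a @ b)(x, x).block_until_ready()
    except Exception as e:  # the XlaRuntimeError from the report lands here
        return f"{type(e).__name__}: {e}"
    value = float(out[0, 0])
    # Each output element sums n products of 1.0, so the exact result is n.
    return None if value == n else f"wrong value {value}, expected {float(n)}"


if __name__ == "__main__":
    for platform in ("cpu", "gpu"):
        try:
            devices = jax.devices(platform)
        except RuntimeError:  # raised when no backend for `platform` exists
            print(f"{platform}: no devices")
            continue
        for n in (2, 16, 256):
            print(platform, n, check_matmul(devices[0], n) or "ok")
```

If the problem is specific to the GPU path, the cpu rows would print ok while the gpu rows report the INTERNAL error, including for the smallest size.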

Expected: a successful 256×256 matrix multiplication with [0, 0] == 256.0.
Observed: it always fails with the INTERNAL error, even for a tiny matmul.
Disabling command buffers (--xla_gpu_enable_command_buffer=), selectively removing command types from that flag, or disabling Triton GEMM did not resolve the error.
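The workarounds above are XLA flags, which are usually passed through the XLA_FLAGS environment variable and should be set before importing jax so the backend sees them when it initializes. A minimal sketch follows; the helper add_xla_flags is hypothetical, and the Triton GEMM flag name (--xla_gpu_enable_triton_gemm=false) is an assumption, since the report does not say which flag was used.

```python
import os


def add_xla_flags(*flags: str) -> str:
    """Append flags to XLA_FLAGS, keeping any flags that are already set.

    Hypothetical helper. Call it before importing jax so the flags are
    in place when JAX initializes its backends.
    """
    existing = os.environ.get("XLA_FLAGS", "").split()
    combined = " ".join(existing + [f for f in flags if f not in existing])
    os.environ["XLA_FLAGS"] = combined
    return combined


add_xla_flags(
    "--xla_gpu_enable_command_buffer=",  # empty value: disable command buffers
    "--xla_gpu_enable_triton_gemm=false",  # assumed flag name for disabling Triton GEMM
)

# import jax  # import only after XLA_FLAGS has been set
```

Appending rather than overwriting keeps any flags already exported in the shell, which makes it easier to combine these workarounds with other debugging options.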

System info (python version, jaxlib version, accelerator, etc.)

OS: WSL2 Ubuntu on Windows 11

Python: uv environment

JAX: 0.7.2

jaxlib: 0.7.2

jax-cuda12-pjrt: 0.7.2

jax-cuda12-plugin: 0.7.2

GPU: NVIDIA (CudaDevice id=0 detected)

CUDA/cuDNN wheels:

nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.3
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.4.5
nvidia-nvtx-cu12 12.8.90
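A package list like the one above can be regenerated from inside the same uv environment with the standard library's importlib.metadata; jax.print_environment_info() gives JAX's own summary of the jax/jaxlib versions and the GPU. The helper below is a hypothetical sketch, and the "jax"/"nvidia-" prefixes are an assumed filter (they also match packages such as jaxlib and jax-cuda12-plugin).

```python
from importlib import metadata


def installed_versions(prefixes=("jax", "nvidia-")):
    """Return sorted (name, version) pairs for installed distributions whose
    names start with one of `prefixes` (case-insensitive).

    Hypothetical helper for collecting the versions listed in a bug report.
    """
    found = set()
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        if name and name.lower().startswith(prefixes):
            found.add((name, dist.version))
    return sorted(found)


if __name__ == "__main__":
    for name, version in installed_versions():
        print(f"{name} {version}")
```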

Labels: bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions