Description
A minimal JAX program fails on a GPU matmul with XlaRuntimeError: INTERNAL: an unsupported value or parameter was passed to the function.
Repro code:
import jax
import jax.numpy as jnp
print("Devices:", jax.devices())
f = jax.jit(lambda x, y: x @ y)  # jit-compiled matrix multiply
a = jnp.ones((256, 256), jnp.float32)
b = a
print(f(a, b).block_until_ready())
Output:
Devices: [CudaDevice(id=0)]
jaxlib._jax.XlaRuntimeError: INTERNAL: an unsupported value or parameter was passed to the function
Expected: successful 256×256 matrix multiplication with [0,0] == 256.0.
Observed: the call always fails with the INTERNAL error above, even for this tiny matmul.
Disabling command buffers entirely (--xla_gpu_enable_command_buffer=), selectively removing command-buffer types, and disabling Triton GEMM did not resolve the error; a sketch of how these flags can be set follows below.
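For reference, a minimal sketch of how these XLA flags can be applied, assuming they are passed via the XLA_FLAGS environment variable before JAX is imported (the exact invocation used while debugging may have differed):
import os
# Illustrative only: an empty value for --xla_gpu_enable_command_buffer disables
# command buffers, and --xla_gpu_enable_triton_gemm=false disables the Triton
# GEMM emitter. XLA_FLAGS must be set before jax is first imported.
os.environ["XLA_FLAGS"] = (
    "--xla_gpu_enable_command_buffer= "
    "--xla_gpu_enable_triton_gemm=false"
)

import jax
import jax.numpy as jnp

f = jax.jit(lambda x, y: x @ y)
a = jnp.ones((256, 256), jnp.float32)
print(f(a, a).block_until_ready())  # per the report, this still fails with the INTERNAL error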
System info (python version, jaxlib version, accelerator, etc.)
OS: WSL2 Ubuntu on Windows 11
Python: uv-managed environment
JAX: 0.7.2
jaxlib: 0.7.2
jax-cuda12-pjrt: 0.7.2
jax-cuda12-plugin: 0.7.2
GPU: NVIDIA (CudaDevice id=0 detected)
CUDA/cuDNN wheels:
nvidia-cublas-cu12 12.8.4.1
nvidia-cuda-cupti-cu12 12.8.90
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.8.93
nvidia-cuda-runtime-cu12 12.8.90
nvidia-cudnn-cu12 9.10.2.21
nvidia-cufft-cu12 11.3.3.83
nvidia-curand-cu12 10.3.9.90
nvidia-cusolver-cu12 11.7.3.90
nvidia-cusparse-cu12 12.5.8.93
nvidia-cusparselt-cu12 0.7.1
nvidia-nccl-cu12 2.27.3
nvidia-nvjitlink-cu12 12.8.93
nvidia-nvshmem-cu12 3.4.5
nvidia-nvtx-cu12 12.8.90
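For completeness, the same environment details can be gathered with JAX's built-in reporting helper (a suggestion for reproducing this listing, not part of the original report):
import jax
# Prints jax/jaxlib/numpy/Python versions, device information, and (where
# available) nvidia-smi output, which should mirror the version listing above.
jax.print_environment_info()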