The following simple bit of code involving tridiagonal_solve works fine on the CPU but fails on the GPU:
Code:
CPU output:
$ env JAX_PLATFORMS=cpu python xlog.py
[[0.]
 [0.]]
GPU output:
$ env JAX_TRACEBACK_FILTERING=off python xlog.py
E1009 19:27:14.558310 1780750 pjrt_stream_executor_client.cc:3314] Execution of replica 0 failed: INTERNAL: jaxlib/gpu/sparse_kernels.cc:537: operation kernel(handle.get(), m, n, dl_data, d_data, du_data, out_data, m, workspace) failed: invalid value
Traceback (most recent call last):
  File "/tmp/xlog.py", line 5, in <module>
    jll.tridiagonal_solve(jnp.array([0, 1.]), jnp.array([4, 4.]),
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/lax/linalg.py", line 664, in tridiagonal_solve
    return tridiagonal_solve_p.bind(dl, d, du, b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 1189, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 268, in cache_miss
    executable, pgle_profiler, const_args) = _python_pjit_helper(
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 147, in _python_pjit_helper
    out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_call_impl_python
    return (compiled.unsafe_call(*computation.const_args, *args),
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1379, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: jaxlib/gpu/sparse_kernels.cc:537: operation kernel(handle.get(), m, n, dl_data, d_data, du_data, out_data, m, workspace) failed: invalid value
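One possible explanation, not verified against this setup: the failing call in the error message has the argument order of cuSPARSE's gtsv2 routines (handle, m, n, dl, d, du, B, ldb, buffer). The cuSPARSE documentation lists m ≥ 3 as a requirement for gtsv2, so a 2×2 system would be rejected with an invalid-value status, while the CPU run above has no trouble with it. If that is the cause, a workaround is to pad the system with decoupled identity rows up to size 3 and drop them from the result. The helper tridiagonal_solve_padded below is a hypothetical sketch of that idea, not part of JAX.

```python
import jax.numpy as jnp
import jax.lax.linalg as jll


def tridiagonal_solve_padded(dl, d, du, b, min_size=3):
    """Hypothetical workaround: solve a tridiagonal system of size m < min_size
    by padding it with decoupled identity rows, then dropping the padding.

    Relies on the tridiagonal_solve conventions: dl[..., 0] == 0,
    du[..., -1] == 0, and b has shape (..., m, n).
    """
    m = d.shape[-1]
    pad = max(0, min_size - m)
    if pad == 0:
        return jll.tridiagonal_solve(dl, d, du, b)
    diag_pad = [(0, 0)] * (d.ndim - 1) + [(0, pad)]
    # Padded rows get zero off-diagonals, a unit diagonal and a zero right-hand
    # side, so their solution is 0. They do not couple to the original rows
    # because the original last row already has du == 0.
    dl_p = jnp.pad(dl, diag_pad)
    d_p = jnp.pad(d, diag_pad, constant_values=1)
    du_p = jnp.pad(du, diag_pad)
    b_p = jnp.pad(b, [(0, 0)] * (b.ndim - 2) + [(0, pad), (0, 0)])
    x_p = jll.tridiagonal_solve(dl_p, d_p, du_p, b_p)
    return x_p[..., :m, :]


# Same 2x2 diagonals as in the report; du and b are hypothetical.
x = tridiagonal_solve_padded(
    jnp.array([0., 1.]), jnp.array([4., 4.]), jnp.array([1., 0.]),
    jnp.array([[1.], [2.]]),
)
print(x)
```

Whether this avoids the GPU error depends on the m ≥ 3 hypothesis above being correct.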
If useful, here are the installed NVIDIA packages:
nvidia-cuda-cupti-cu12 12.9.79
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12 9.14.0.64
nvidia-cufft-cu12 11.4.1.4
nvidia-cufile-cu12 1.14.1.1
nvidia-curand-cu12 10.3.10.19
nvidia-cusolver-cu12 11.7.5.82
nvidia-cusparse-cu12 12.5.10.65
nvidia-cusparselt-cu12 0.8.1
nvidia-nccl-cu12 2.28.3
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvshmem-cu12 3.4.5
nvidia-nvtx-cu12 12.9.79
System info (python version, jaxlib version, accelerator, etc.):
jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.1.3
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: Quadro RTX 4000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='milkyway', release='6.8.0-85-generic', version='#85-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 18 15:26:59 UTC 2025', machine='x86_64')
$ nvidia-smi
Thu Oct  9 19:24:36 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Quadro RTX 4000                Off |   00000000:73:00.0  On |                  N/A |
| 30%   46C    P0             37W /  125W |    1391MiB /   8192MiB |      5%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
EDITED OUT
+-----------------------------------------------------------------------------------------+
Thanks!