tridiagonal_solve for a tiny matrix works on CPU but not on GPU #32487

@segasai

Description

Hi,

The following simple bit of code involving tridiagonal_solve works fine on the CPU but fails on the GPU:
Code:

import jax.numpy as jnp
import jax.lax.linalg as jll

print(
    jll.tridiagonal_solve(jnp.array([0, 1.]), jnp.array([4, 4.]),
                          jnp.array([1, 0.]),
                          jnp.array([0., 0])[:, None]))

CPU output:

$ env JAX_PLATFORMS=cpu python  xlog.py 
[[0.]
 [0.]]
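As a sanity check that the CPU result is the mathematically expected one, here is a minimal pure-Python sketch of the Thomas algorithm applied to the same 2x2 system. The helper name `thomas_solve` is hypothetical and not part of JAX. It assumes the same argument layout as the call above (first entry of the lower diagonal and last entry of the upper diagonal are ignored) and does no pivoting.

```python
def thomas_solve(dl, d, du, b):
    """Solve a tridiagonal system with the Thomas algorithm (no pivoting).

    dl[0] and du[-1] are never read as matrix entries, matching the layout
    used in the report. b is a flat list holding a single right-hand side.
    """
    m = len(d)
    c = [0.0] * m  # modified super-diagonal
    y = [0.0] * m  # modified right-hand side

    # Forward elimination.
    c[0] = du[0] / d[0]
    y[0] = b[0] / d[0]
    for i in range(1, m):
        denom = d[i] - dl[i] * c[i - 1]
        c[i] = du[i] / denom
        y[i] = (b[i] - dl[i] * y[i - 1]) / denom

    # Back substitution.
    x = [0.0] * m
    x[-1] = y[-1]
    for i in range(m - 2, -1, -1):
        x[i] = y[i] - c[i] * x[i + 1]
    return x


# Same system as in the report: [[4, 1], [1, 4]] x = [0, 0].
x = thomas_solve([0.0, 1.0], [4.0, 4.0], [1.0, 0.0], [0.0, 0.0])
print(x)
```

The matrix is nonsingular (determinant 15), so with a zero right-hand side the only solution is the zero vector, which agrees with the CPU output.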

GPU output:

$ env JAX_TRACEBACK_FILTERING=off python  xlog.py 
E1009 19:27:14.558310 1780750 pjrt_stream_executor_client.cc:3314] Execution of replica 0 failed: INTERNAL: jaxlib/gpu/sparse_kernels.cc:537: operation kernel(handle.get(), m, n, dl_data, d_data, du_data, out_data, m, workspace) failed: invalid value
Traceback (most recent call last):
  File "/tmp/xlog.py", line 5, in <module>
    jll.tridiagonal_solve(jnp.array([0, 1.]), jnp.array([4, 4.]),
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/lax/linalg.py", line 664, in tridiagonal_solve
    return tridiagonal_solve_p.bind(dl, d, du, b)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 634, in bind
    return self._true_bind(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 650, in _true_bind
    return self.bind_with_trace(prev_trace, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 662, in bind_with_trace
    return trace.process_primitive(self, args, params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/core.py", line 1189, in process_primitive
    return primitive.impl(*args, **params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/dispatch.py", line 90, in apply_primitive
    outs = fun(*args)
           ^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 268, in cache_miss
    executable, pgle_profiler, const_args) = _python_pjit_helper(
                                             ^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 147, in _python_pjit_helper
    out_flat, compiled, profiler, const_args = _pjit_call_impl_python(
                                               ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/pjit.py", line 1780, in _pjit_call_impl_python
    return (compiled.unsafe_call(*computation.const_args, *args),
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/profiler.py", line 364, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/skoposov/pyenv_dir/pyenv312/lib/python3.12/site-packages/jax/_src/interpreters/pxla.py", line 1379, in __call__
    results = self.xla_executable.execute_sharded(input_bufs)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib._jax.XlaRuntimeError: INTERNAL: jaxlib/gpu/sparse_kernels.cc:537: operation kernel(handle.get(), m, n, dl_data, d_data, du_data, out_data, m, workspace) failed: invalid value
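A possible workaround sketch, under an assumption that this report does not confirm: the cuSPARSE gtsv2 routine that the failing kernel appears to wrap is documented as requiring a system size of at least 3, and here m is 2. If that is the cause, padding the system with decoupled identity rows before the call and dropping them afterwards might sidestep the error. `pad_tridiagonal` is a hypothetical helper, written with plain Python lists for illustration only.

```python
def pad_tridiagonal(dl, d, du, b, min_size=3):
    """Pad a tridiagonal system with identity rows up to min_size.

    Hypothetical helper, not part of JAX. b is a list of rows, each row a
    list with one value per right-hand side. Returns the padded diagonals,
    the padded right-hand side, and the original size m so the caller can
    slice the solution back to its first m rows.
    """
    m = len(d)
    extra = max(0, min_size - m)
    nrhs = len(b[0])
    # Each padded row reads 1 * x_i = 0 with zero off-diagonals, so the
    # padded unknowns are zero and do not change the original solution.
    dl_p = list(dl) + [0.0] * extra
    d_p = list(d) + [1.0] * extra
    du_p = list(du) + [0.0] * extra
    b_p = [list(row) for row in b] + [[0.0] * nrhs for _ in range(extra)]
    return dl_p, d_p, du_p, b_p, m


# The 2x2 system from the report becomes a 3x3 system.
dl_p, d_p, du_p, b_p, m = pad_tridiagonal(
    [0.0, 1.0], [4.0, 4.0], [1.0, 0.0], [[0.0], [0.0]])
print(dl_p, d_p, du_p, b_p, m)
```

The padded arrays would then be passed to tridiagonal_solve and only the first m rows of the result kept. Whether this actually avoids the GPU failure has not been tested here.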

If useful, here are the nvidia packages installed:

nvidia-cuda-cupti-cu12                  12.9.79
nvidia-cuda-nvcc-cu12                   12.9.86
nvidia-cuda-nvrtc-cu12                  12.9.86
nvidia-cuda-runtime-cu12                12.9.79
nvidia-cudnn-cu12                       9.14.0.64
nvidia-cufft-cu12                       11.4.1.4
nvidia-cufile-cu12                      1.14.1.1
nvidia-curand-cu12                      10.3.10.19
nvidia-cusolver-cu12                    11.7.5.82
nvidia-cusparse-cu12                    12.5.10.65
nvidia-cusparselt-cu12                  0.8.1
nvidia-nccl-cu12                        2.28.3
nvidia-nvjitlink-cu12                   12.9.86
nvidia-nvshmem-cu12                     3.4.5
nvidia-nvtx-cu12                        12.9.79

Thanks!

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.7.2
jaxlib: 0.7.2
numpy:  2.1.3
python: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
device info: Quadro RTX 4000-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='milkyway', release='6.8.0-85-generic', version='#85-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 18 15:26:59 UTC 2025', machine='x86_64')

$ nvidia-smi
Thu Oct  9 19:24:36 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.65.06              Driver Version: 580.65.06      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Quadro RTX 4000                Off |   00000000:73:00.0  On |                  N/A |
| 30%   46C    P0             37W /  125W |    1391MiB /   8192MiB |      5%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
EDITED OUT
+-----------------------------------------------------------------------------------------+

Labels: bug
