jaxlib.xla_extension.XlaRuntimeError: INTERNAL: unsupported operands type on f64 but not on f32 #32136

@Qazalbash

Description

I am trying to run some proprietary code in double precision (JAX_ENABLE_X64=1), and it fails with the error below. Surprisingly, the same code runs fine with 32-bit floating-point numbers. Any leads on what could be wrong, or is this an actual bug?
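For context on the flag: setting JAX_ENABLE_X64=1 has the same effect as calling `jax.config.update("jax_enable_x64", True)` before any arrays are created. After that, default floating-point arrays become `float64` instead of `float32`, so every op in the traced program can receive f64 operands. A minimal illustration (not the proprietary code):

```python
import jax

# Must run before any JAX arrays are created; same effect as JAX_ENABLE_X64=1.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.ones(3)
print(x.dtype)  # float64 with x64 enabled (float32 otherwise)

# Explicit casts still work, which is useful for checking whether a suspect
# piece of code compiles in f32 while the rest of the program stays in f64.
y = x.astype(jnp.float32)
print(y.dtype)  # float32
```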

I have not been able to produce a minimal reproducible example (MRE).
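One way to hunt for a smaller reproducer is to lower each jitted piece separately, list the StableHLO lines that mention `f64`, and compile that piece on its own. In the traceback below, the failure is raised from `backend.compile`, so it happens at exactly this compile step. This is a sketch under assumptions: `suspect` and `check_f64_compile` are hypothetical names, and `suspect` only stands in for one piece of the real model.

```python
import jax

# Enable double precision before any arrays exist (same as JAX_ENABLE_X64=1).
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def suspect(x):
    # Hypothetical stand-in for one piece of the real (proprietary) model.
    return jnp.sum(jnp.sin(x) * x)


def check_f64_compile(fn, *args):
    """Lower `fn` with jit, collect StableHLO lines mentioning f64, then compile.

    Returns (f64_lines, error), where error is None if XLA compiled it.
    """
    lowered = jax.jit(fn).lower(*args)
    f64_lines = [ln.strip() for ln in lowered.as_text().splitlines() if "f64" in ln]
    try:
        lowered.compile()  # the reported XlaRuntimeError is raised at this stage
        error = None
    except Exception as exc:
        error = exc
    return f64_lines, error


x = jnp.linspace(0.0, 1.0, 8)
f64_lines, error = check_f64_compile(suspect, x)
print(f"{len(f64_lines)} StableHLO lines use f64; compile error: {error}")
```

Applying the same check to successively smaller pieces of the real pipeline could isolate the op that XLA rejects, and the collected `f64_lines` show which ops receive f64 operands.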

Traceback (most recent call last):
  File "/home/meesum/academia/gwkokab/.venv/bin/f_monk_n_pls_m_gs", line 10, in <module>
    sys.exit(main())
             ^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/n_pls_m_gs/monk.py", line 420, in main
    ).run()
      ^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/monk.py", line 199, in run
    self.driver(
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/flowMC_based.py", line 862, in driver
    sampler.sample(initial_position, data)
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/flowMC_based.py", line 434, in sample
    ) = self.strategies[strategy](
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/flowMC/strategy/take_steps.py", line 127, in __call__
    positions, log_probs, do_accepts = eqx.filter_jit(
                                       ^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/equinox/_jit.py", line 209, in __call__
    return _call(self, False, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/equinox/_jit.py", line 263, in _call
    marker, _, _ = out = jit_wrapper._cached(
                         ^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 339, in cache_miss
    pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
    out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1659, in _pjit_call_impl_python
    ).compile()
      ^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2448, in compile
    executable = UnloadedMeshExecutable.from_hlo(
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2967, in from_hlo
    xla_executable = _cached_compilation(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2758, in _cached_compilation
    xla_executable = compiler.compile_or_get_cached(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 470, in compile_or_get_cached
    return _compile_and_write_cache(
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 687, in _compile_and_write_cache
    executable = backend_compile(
                 ^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 334, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 327, in backend_compile
    raise e
  File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 315, in backend_compile
    return backend.compile(
           ^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: unsupported operands type

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.5.3
jaxlib: 0.5.3
numpy:  2.3.3
python: 3.11.11 (main, Feb 12 2025, 14:51:05) [Clang 19.1.6 ]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='administrator-MS-7B17', release='6.14.0-29-generic', version='#29~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Aug 14 16:52:50 UTC 2', machine='x86_64')


$ nvidia-smi
Fri Sep 26 17:53:20 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.03              Driver Version: 575.64.03      CUDA Version: 12.9     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4090        Off |   00000000:01:00.0 Off |                  Off |
|  0%   39C    P8             37W /  450W |     439MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI              PID   Type   Process name                        GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A            1431      G   /usr/lib/xorg/Xorg                        9MiB |
|    0   N/A  N/A            1579      G   /usr/bin/gnome-shell                     10MiB |
|    0   N/A  N/A         2041816      C   python                                  390MiB |
+-----------------------------------------------------------------------------------------+
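The system info above matches the format printed by `jax.print_environment_info()`, which also appends `nvidia-smi` output when it is available. A short snippet for regenerating it, plus a quick check (an assumption about what is useful to include, not part of the original report) that double precision is actually active in the process:

```python
import jax
import jax.numpy as jnp

# Prints jax/jaxlib/numpy/python versions, device info, process count,
# platform, and nvidia-smi output when available.
jax.print_environment_info()

# Default float dtype is float64 only if jax_enable_x64 / JAX_ENABLE_X64 is on.
print("default float dtype:", jnp.array(1.0).dtype)
```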

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests