Description
I am trying to run (proprietary) code at higher precision (JAX_ENABLE_X64=1), and it fails with the error below. Surprisingly, the same code runs with 32-bit floating-point numbers. Any leads on what could be wrong, or is this an actual bug?
I have been unable to produce an MRE.
Traceback (most recent call last):
File "/home/meesum/academia/gwkokab/.venv/bin/f_monk_n_pls_m_gs", line 10, in <module>
sys.exit(main())
^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/n_pls_m_gs/monk.py", line 420, in main
).run()
^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/monk.py", line 199, in run
self.driver(
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/flowMC_based.py", line 862, in driver
sampler.sample(initial_position, data)
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/kokab/core/flowMC_based.py", line 434, in sample
) = self.strategies[strategy](
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/flowMC/strategy/take_steps.py", line 127, in __call__
positions, log_probs, do_accepts = eqx.filter_jit(
^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/equinox/_jit.py", line 209, in __call__
return _call(self, False, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/equinox/_jit.py", line 263, in _call
marker, _, _ = out = jit_wrapper._cached(
^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/traceback_util.py", line 180, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 339, in cache_miss
pgle_profiler) = _python_pjit_helper(fun, jit_info, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 194, in _python_pjit_helper
out_flat, compiled, profiler = _pjit_call_impl_python(*args_flat, **p.params)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/pjit.py", line 1659, in _pjit_call_impl_python
).compile()
^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2448, in compile
executable = UnloadedMeshExecutable.from_hlo(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2967, in from_hlo
xla_executable = _cached_compilation(
^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 2758, in _cached_compilation
xla_executable = compiler.compile_or_get_cached(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 470, in compile_or_get_cached
return _compile_and_write_cache(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 687, in _compile_and_write_cache
executable = backend_compile(
^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/profiler.py", line 334, in wrapper
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 327, in backend_compile
raise e
File "/home/meesum/academia/gwkokab/.venv/lib/python3.11/site-packages/jax/_src/compiler.py", line 315, in backend_compile
return backend.compile(
^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: unsupported operands type
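Since the traceback ends inside backend.compile, one possible way to narrow this down toward an MRE is to run tracing, lowering, and compilation as separate steps with 64-bit mode enabled. The following is only a minimal sketch: log_prob is a hypothetical stand-in for the real function, and this toy version is not expected to reproduce the error.

```python
import jax
import jax.numpy as jnp

# Same effect as JAX_ENABLE_X64=1; set it before creating any arrays.
jax.config.update("jax_enable_x64", True)


def log_prob(x):
    # Hypothetical stand-in for the real (proprietary) log-density.
    return -0.5 * jnp.sum(x**2)


x = jnp.ones(3)  # float64 once 64-bit mode is enabled

# Step 1: trace only. The jaxpr lists each primitive with its operand dtypes,
# which helps spot unexpected 32-bit/64-bit mixes.
jaxpr = jax.make_jaxpr(log_prob)(x)
print(jaxpr)

# Step 2: lower without compiling. The text form can be attached to a report
# even when the Python source cannot be shared.
lowered = jax.jit(log_prob).lower(x)
hlo_text = lowered.as_text()

# Step 3: compile explicitly. An XlaRuntimeError raised here, and not in the
# two steps above, points at the backend compiler.
compiled = lowered.compile()
result = compiled(x)
```

Replacing log_prob with progressively smaller pieces of the real function, and repeating step 3 on each, may help isolate the operation that fails to compile in 64-bit mode.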
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.5.3
jaxlib: 0.5.3
numpy: 2.3.3
python: 3.11.11 (main, Feb 12 2025, 14:51:05) [Clang 19.1.6 ]
device info: NVIDIA GeForce RTX 4090-1, 1 local devices"
process_count: 1
platform: uname_result(system='Linux', node='administrator-MS-7B17', release='6.14.0-29-generic', version='#29~24.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Aug 14 16:52:50 UTC 2', machine='x86_64')
$ nvidia-smi
Fri Sep 26 17:53:20 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 575.64.03 Driver Version: 575.64.03 CUDA Version: 12.9 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4090 Off | 00000000:01:00.0 Off | Off |
| 0% 39C P8 37W / 450W | 439MiB / 24564MiB | 0% Default |
| | | N/A |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
| 0 N/A N/A 1431 G /usr/lib/xorg/Xorg 9MiB |
| 0 N/A N/A 1579 G /usr/bin/gnome-shell 10MiB |
| 0 N/A N/A 2041816 C python 390MiB |
+-----------------------------------------------------------------------------------------+