Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)
A JAX package for NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library through a Foreign Function Interface (FFI), exposing it to JAX as custom XLA calls. This approach can lead to:
- Kernel fusion issues on GPU: custom XLA calls act as optimization barriers, so XLA cannot fuse them with the surrounding operations
- CUDA version matching: GPU support requires JAX and the library to use matching CUDA versions
nufftax takes a different approach: it is implemented entirely in JAX.
- Fully differentiable: gradients with respect to both values and sample locations
- Pure JAX: works with `jit`, `grad`, `vmap`, `jvp`, and `vjp`, with no FFI barriers
- GPU ready: runs on CPU and GPU without code changes and benefits from XLA fusion
- All NUFFT types: Types 1, 2, and 3 in 1D, 2D, and 3D
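To make the three transform types concrete, here is a minimal pure-JAX sketch of the direct (non-fast) 1D sums that a NUFFT approximates. It costs O(MN) for M points and N modes, which is exactly what the fast algorithm avoids, so it is only useful as a reference for checking results. The names `nudft1d1`, `nudft1d2`, `nudft1d3`, and `mode_indices`, the centered mode ordering, and the sign conventions (+i for types 1 and 3, -i for type 2, matching FINUFFT's documented defaults) are assumptions for illustration, not nufftax's API.

```python
import jax
import jax.numpy as jnp

# Use float64 so the O(M * N) reference sums are accurate.
jax.config.update("jax_enable_x64", True)


def mode_indices(n_modes):
    # Centered integer frequencies -n_modes//2, ..., (n_modes - 1)//2 (ordering assumed).
    return jnp.arange(n_modes) - n_modes // 2


def nudft1d1(x, c, n_modes, isign=1):
    # Type 1 (nonuniform points -> uniform modes): f_k = sum_j c_j exp(isign * i * k * x_j)
    k = mode_indices(n_modes)
    return jnp.exp(isign * 1j * jnp.outer(k, x)) @ c


def nudft1d2(x, f, isign=-1):
    # Type 2 (uniform modes -> nonuniform points): c_j = sum_k f_k exp(isign * i * k * x_j)
    k = mode_indices(f.shape[0])
    return jnp.exp(isign * 1j * jnp.outer(x, k)) @ f


def nudft1d3(x, c, s, isign=1):
    # Type 3 (nonuniform points -> nonuniform frequencies):
    # f_k = sum_j c_j exp(isign * i * s_k * x_j)
    return jnp.exp(isign * 1j * jnp.outer(s, x)) @ c


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])
f = nudft1d1(x, c, n_modes=8)        # 8 uniform modes from 5 irregular samples
c_back = nudft1d2(x, f)              # evaluate those modes back at the 5 points
s = jnp.array([-3.5, 0.25, 2.75])    # arbitrary non-integer target frequencies
g = nudft1d3(x, c, s)
```

Because type 2 with the opposite sign is the adjoint of type 1, `jnp.vdot(f, nudft1d1(x, c, n))` equals `jnp.vdot(nudft1d2(x, f), c)`, which is a handy identity for testing any implementation.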
| Transform | `jit` | `grad`/`vjp` | `jvp` | `vmap` |
|---|---|---|---|---|
| Type 1 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 2 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
| Type 3 (1D/2D/3D) | ✅ | ✅ | ✅ | ✅ |
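Because a pure-JAX transform is ordinary JAX code, it composes with each transformation in the table like any other function. The sketch below illustrates this with a direct-sum stand-in, `nudft1d1` (hypothetical, not nufftax's `nufft1d1`). In this sketch `n_modes` determines the output shape, so it is passed as a static argument under `jit`.

```python
import jax
import jax.numpy as jnp


def nudft1d1(x, c, n_modes):
    # Hypothetical stand-in for a pure-JAX type-1 transform: direct sum
    # f_k = sum_j c_j exp(i k x_j) over centered modes k.
    k = jnp.arange(n_modes) - n_modes // 2
    return jnp.exp(1j * jnp.outer(k, x)) @ c


def loss(x, c):
    # Real-valued objective: total spectral energy.
    return jnp.sum(jnp.abs(nudft1d1(x, c, 16)) ** 2)


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# jit: n_modes fixes the output shape, so it must be static.
f = jax.jit(nudft1d1, static_argnums=2)(x, c, 16)

# vmap: batch over strength vectors that share the same sample points.
cs = jnp.stack([c, 2.0 * c, 1j * c])
fs = jax.vmap(lambda cc: nudft1d1(x, cc, 16))(cs)

# grad (reverse mode / vjp): derivative of the real loss w.r.t. the coordinates x.
g_x = jax.grad(loss)(x, c)

# jvp (forward mode): directional derivative of the loss along dx.
dx = jnp.ones_like(x)
loss_val, dloss = jax.jvp(lambda xx: loss(xx, c), (x,), (dx,))
```

Forward and reverse mode agree: `dloss` matches `jnp.dot(g_x, dx)` up to floating-point rounding.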
Differentiable inputs:
- Type 1: `grad` w.r.t. `c` (strengths) and `x, y, z` (coordinates)
- Type 2: `grad` w.r.t. `f` (Fourier modes) and `x, y, z` (coordinates)
- Type 3: `grad` w.r.t. `c` (strengths), `x, y, z` (source coordinates), and `s, t, u` (target frequencies)
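Coordinate gradients follow from differentiating the complex exponential. For L = sum_k |f_k|^2 with f_k = sum_j c_j exp(i k x_j), the chain rule gives dL/dx_j = 2 Re(c_j * sum_k conj(f_k) * i k * exp(i k x_j)). The sketch below checks this against `jax.grad`, using a dense direct sum as a hypothetical stand-in for a 1D type-1 transform (the +i sign is an assumption). It also shows JAX's convention for complex inputs such as `c`: for a real-valued loss, `jax.grad` returns dL/dRe(c) - i dL/dIm(c), which for this loss equals 2 conj(E^H f).

```python
import jax
import jax.numpy as jnp


def nudft1d1(x, c, n_modes):
    # Direct type-1 sum (hypothetical stand-in): f_k = sum_j c_j exp(i k x_j)
    k = jnp.arange(n_modes) - n_modes // 2
    return jnp.exp(1j * jnp.outer(k, x)) @ c


def energy(x, c):
    # L = sum_k |f_k|^2, the same objective as in the quick example.
    return jnp.sum(jnp.abs(nudft1d1(x, c, 32)) ** 2)


x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0 + 0.5j, 0.3 - 0.2j, 0.8 + 0.1j, 0.2 + 0.4j, 0.5 - 0.3j])

# Autodiff gradients with respect to both the coordinates and the strengths.
grad_x, grad_c = jax.grad(energy, argnums=(0, 1))(x, c)

# Hand-derived gradients for comparison.
k = jnp.arange(32) - 16
E = jnp.exp(1j * jnp.outer(k, x))  # (n_modes, M) matrix with f = E @ c
f = E @ c
# dL/dx_j = 2 Re( c_j * sum_k conj(f_k) * i k * exp(i k x_j) )
grad_x_manual = 2 * jnp.real(c * ((1j * k * jnp.conj(f)) @ E))
# JAX convention for real L of complex c: dL/dRe(c) - i dL/dIm(c) = 2 conj(E^H f)
grad_c_manual = 2 * jnp.conj(E.conj().T @ f)
```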
Installation:

```bash
uv pip install nufftax
```

Quick example:

```python
import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
# Complex strengths, one per sample location
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])

# Compute 32 Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)

# Differentiate a real-valued loss (spectral energy) through the transform
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)
```

Documentation:

- Quickstart: get running in 5 minutes
- Concepts: understand the mathematics
- Tutorials: MRI reconstruction, spectral analysis, optimization
- API Reference: complete function reference
License: MIT. The algorithm is based on FINUFFT by the Flatiron Institute.
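FINUFFT computes a type-1 transform in three steps: spread each strength onto an upsampled uniform grid with the "exponential of semicircle" (ES) kernel named in the paper cited below, take an FFT, then divide each mode by the kernel's Fourier transform. The following is a simplified 1D sketch of those steps, not nufftax's implementation: the function name `nufft1d1_sketch` is hypothetical, the width w = ceil(log10(1/eps)) + 1 and beta = 2.30 * w are FINUFFT's heuristics for upsampling factor 2 (assumed here), and there is no attempt at performance.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np


def es_kernel(z, beta):
    # "Exponential of semicircle": exp(beta * (sqrt(1 - z^2) - 1)) for |z| <= 1, else 0.
    inside = jnp.abs(z) <= 1.0
    z_safe = jnp.where(inside, z, 0.0)  # keep the sqrt argument non-negative off support
    return jnp.where(inside, jnp.exp(beta * (jnp.sqrt(1.0 - z_safe**2) - 1.0)), 0.0)


def nufft1d1_sketch(x, c, n_modes, eps=1e-6, sigma=2):
    # Assumed FINUFFT-style parameters for upsampling factor sigma = 2.
    w = math.ceil(math.log10(1.0 / eps)) + 1  # kernel width in fine-grid points
    beta = 2.30 * w
    n = sigma * n_modes              # upsampled grid size
    h = 2 * math.pi / n              # fine-grid spacing
    alpha = w * h / 2                # kernel half-width in x units

    # 1) Spread: b_l = sum_j c_j * psi(l*h - x_j), wrapping indices periodically.
    offsets = jnp.arange(-(w // 2) - 1, w // 2 + 2)
    base = jnp.floor(x / h).astype(jnp.int32)
    idx = base[:, None] + offsets[None, :]                      # (M, stencil)
    vals = es_kernel((idx * h - x[:, None]) / alpha, beta)      # zero outside support
    b = jnp.zeros(n, dtype=c.dtype).at[jnp.mod(idx, n).ravel()].add(
        (c[:, None] * vals).ravel()
    )

    # 2) FFT: n * ifft(b)[m] = sum_l b_l exp(+2*pi*i*l*m/n); scale by h * n = 2*pi.
    k = jnp.arange(n_modes) - n_modes // 2
    b_hat = 2 * math.pi * jnp.fft.ifft(b)[jnp.mod(k, n)]

    # 3) Deconvolve by psi_hat(k) = alpha * int_{-1}^{1} phi(z) cos(k*alpha*z) dz,
    #    evaluated with Gauss-Legendre quadrature.
    zq, wq = (jnp.asarray(a) for a in np.polynomial.legendre.leggauss(100))
    psi_hat = alpha * jnp.sum(
        wq * es_kernel(zq, beta) * jnp.cos(k[:, None] * alpha * zq), axis=1
    )
    return b_hat / psi_hat
```

Comparing `nufft1d1_sketch(x, c, 32)` against the direct sum `jnp.exp(1j * jnp.outer(k, x)) @ c` for random points should show a relative error near the requested `eps`, limited by floating-point precision.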
If you use nufftax in your research, please cite:
```bibtex
@software{nufftax,
  author = {Oudoumanessah, Geoffroy and Iollo, Jacopo},
  title = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
  url = {https://github.com/geoffroyO/nufftax},
  year = {2026}
}

@article{finufft,
  author = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
  title = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
  journal = {SIAM J. Sci. Comput.},
  volume = {41},
  number = {5},
  pages = {C479--C504},
  year = {2019}
}
```