Thanks to visit codestin.com
Credit goes to github.com

Skip to content

geoffroyO/nufftax

Repository files navigation

nufftax logo

Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)

CI Documentation Python 3.12+ License: MIT


MRI reconstruction example

Why nufftax?

A JAX package for NUFFT already exists: jax-finufft. However, it wraps the C++ FINUFFT library via Foreign Function Interface (FFI), exposing it through custom XLA calls. This approach can lead to:

  • Kernel fusion issues on GPU — custom XLA calls act as optimization barriers, preventing XLA from fusing operations
  • CUDA version matching — GPU support requires matching CUDA versions between JAX and the library

nufftax takes a different approach — pure JAX implementation:

  • Fully differentiable — gradients w.r.t. both values and sample locations
  • Pure JAX — works with jit, grad, vmap, jvp, vjp with no FFI barriers
  • GPU ready — runs on CPU/GPU without code changes, benefits from XLA fusion
  • All NUFFT types — Type 1, 2, 3 in 1D, 2D, 3D

JAX Transformation Support

Transform jit grad/vjp jvp vmap
Type 1 (1D/2D/3D)
Type 2 (1D/2D/3D)
Type 3 (1D/2D/3D)

Differentiable inputs:

  • Type 1: grad w.r.t. c (strengths) and x, y, z (coordinates)
  • Type 2: grad w.r.t. f (Fourier modes) and x, y, z (coordinates)
  • Type 3: grad w.r.t. c (strengths), x, y, z (source coordinates), and s, t, u (target frequencies)

Installation

uv pip install nufftax

Quick Example

import jax
import jax.numpy as jnp
from nufftax import nufft1d1

# Irregular sample locations in [-pi, pi)
x = jnp.array([0.1, 0.7, 1.3, 2.1, -0.5])
c = jnp.array([1.0+0.5j, 0.3-0.2j, 0.8+0.1j, 0.2+0.4j, 0.5-0.3j])

# Compute Fourier modes
f = nufft1d1(x, c, n_modes=32, eps=1e-6)

# Differentiate through the transform
grad_c = jax.grad(lambda c: jnp.sum(jnp.abs(nufft1d1(x, c, n_modes=32)) ** 2))(c)

Documentation

Read the full documentation →

License

MIT. Algorithm based on FINUFFT by the Flatiron Institute.

Citation

If you use nufftax in your research, please cite:

@software{nufftax,
  author = {Oudoumanessah, Geoffroy and Iollo, Jacopo},
  title = {nufftax: Pure JAX implementation of the Non-Uniform Fast Fourier Transform},
  url = {https://github.com/geoffroyO/nufftax},
  year = {2026}
}

@article{finufft,
  author = {Barnett, Alexander H. and Magland, Jeremy F. and af Klinteberg, Ludvig},
  title = {A parallel non-uniform fast Fourier transform library based on an ``exponential of semicircle'' kernel},
  journal = {SIAM J. Sci. Comput.},
  volume = {41},
  number = {5},
  pages = {C479--C504},
  year = {2019}
}

About

Pure JAX implementation of the Non-Uniform Fast Fourier Transform (NUFFT)

Topics

Resources

License

Stars

Watchers

Forks

Packages

No packages published

Contributors 3

  •  
  •  
  •  

Languages