
jeertmans/fpt-jax


Fermat path-tracing with JAX

fpt-jax is a standalone library for differentiable path-tracing using the Fermat principle, implemented with JAX.

Installation

You can install this package from PyPI:

pip install fpt-jax

Usage

This library implements a single function, trace_rays, which traces rays undergoing specular reflections and diffractions on planar objects defined by origins and basis vectors:

>>> from fpt_jax import trace_rays; help(trace_rays)

trace_rays
   (tx: jax.Array, rx: jax.Array,
    object_origins: jax.Array, object_vectors: jax.Array, *,
    num_iters: int, unroll: int | bool = 1,
    num_iters_linesearch: int = 1, unroll_linesearch: int | bool = 1,
    implicit_diff: bool = True) -> jax.Array:

Compute the points of interaction of rays with objects using Fermat's principle.

Each ray is obtained by minimizing the total travel distance from transmitter to receiver, using a quasi-Newton optimization algorithm (BFGS). At each iteration, a line search is performed to find the optimal step size along the descent direction.

This function accepts batched inputs, where the leading dimensions must be broadcast-compatible.

Args:
    tx: Transmitter positions of shape (..., 3).
    rx: Receiver positions of shape (..., 3).
    object_origins: Origins of the objects of shape (..., num_interactions, 3).
    object_vectors: Vectors defining the objects of shape (..., num_interactions, num_dims, 3).
    num_iters: Number of iterations for the optimization algorithm.
    unroll: If an integer, the number of optimization iterations to unroll in the JAX scan.
        If True, unroll all iterations. If False, do not unroll.
    num_iters_linesearch: Number of iterations for the line search fixed-point iteration.
    unroll_linesearch: If an integer, the number of fixed-point iterations to unroll in the JAX scan.
        If True, unroll all iterations. If False, do not unroll.
    implicit_diff: Whether to use implicit differentiation for computing the gradient.
        If True, assumes that the solution has converged and applies the implicit function theorem
        to differentiate the optimization problem with respect to the input parameters:
            tx, rx, object_origins, and object_vectors.
        If False, the gradient is computed by backpropagating through all iterations of the optimization algorithm.

        Using implicit differentiation is more memory- and computationally efficient,
        as it does not require storing intermediate values from all iterations,
        but it may be less accurate if the optimization has not fully converged.
        Moreover, implicit differentiation is not compatible with forward-mode autodiff in JAX.

Returns:
    The points of interaction of shape (..., num_interactions, 3).
    To include the transmitter and receiver positions, concatenate tx and rx to the result.
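The docstring describes each path as the minimizer of the total travel distance. As a rough, unbatched illustration of that idea, the sketch below writes each interaction point as an object origin plus a linear combination of that object's vectors, then minimizes the summed segment lengths over the coefficients. It is not the fpt-jax implementation: it swaps in `jax.scipy.optimize.minimize(..., method="BFGS")`, which stops on a gradient tolerance, for the library's fixed `num_iters` scan with its own line search. The helpers `interaction_points`, `path_length` and `fermat_points` are hypothetical names.

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize


def interaction_points(t, origins, vectors):
    # Hypothetical helper: each point lies on its object,
    # origin[n] + sum_k t[n, k] * vectors[n, k].
    return origins + jnp.einsum("nd,ndc->nc", t, vectors)


def path_length(t_flat, tx, rx, origins, vectors):
    # Hypothetical helper: total distance tx -> interaction points -> rx
    # for a single, unbatched path.
    t = t_flat.reshape(vectors.shape[:2])  # (num_interactions, num_dims)
    points = interaction_points(t, origins, vectors)
    path = jnp.concatenate([tx[None, :], points, rx[None, :]], axis=0)
    return jnp.sum(jnp.linalg.norm(jnp.diff(path, axis=0), axis=-1))


def fermat_points(tx, rx, origins, vectors):
    # Hypothetical helper: minimize the path length with JAX's BFGS,
    # starting from the object origins (t = 0), and return the points.
    t0 = jnp.zeros(vectors.shape[0] * vectors.shape[1])
    result = minimize(
        path_length, t0, args=(tx, rx, origins, vectors), method="BFGS"
    )
    t = result.x.reshape(vectors.shape[:2])
    return interaction_points(t, origins, vectors)


# Specular reflection on the plane z = 0 (one origin, two basis vectors).
tx = jnp.array([0.0, 0.0, 1.0])
rx = jnp.array([2.0, 0.0, 1.0])
plane_origins = jnp.zeros((1, 3))
plane_vectors = jnp.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]])
print(fermat_points(tx, rx, plane_origins, plane_vectors))  # approximately [[1, 0, 0]]

# Diffraction on an edge along the y axis (one origin, one direction vector).
edge_origins = jnp.zeros((1, 3))
edge_vectors = jnp.array([[[0.0, 1.0, 0.0]]])
print(fermat_points(jnp.array([-1.0, 0.0, 1.0]), jnp.array([1.0, 2.0, 1.0]),
                    edge_origins, edge_vectors))  # approximately [[0, 1, 0]]
```

Each segment length is the norm of an affine function of the coefficients, so the objective in these toy cases is convex and a single local optimizer started at t = 0 reaches the specular (or diffraction) point.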

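The Returns section says to concatenate tx and rx to the result to obtain full paths. Since the leading dimensions only need to be broadcast-compatible, tx and rx may first have to be broadcast to the batch shape of the points. The minimal sketch below does that with the hypothetical helpers `full_paths` and `path_lengths`; the interaction points are written by hand (specular points on the plane z = 0 from the mirror-image construction) rather than returned by `trace_rays`.

```python
import jax.numpy as jnp


def full_paths(tx, rx, points):
    """Prepend tx and append rx to the interaction points (hypothetical helper).

    tx, rx: (..., 3); points: (..., num_interactions, 3), with
    broadcast-compatible leading dimensions.
    Returns an array of shape (..., num_interactions + 2, 3).
    """
    batch = jnp.broadcast_shapes(tx.shape[:-1], rx.shape[:-1], points.shape[:-2])
    tx = jnp.broadcast_to(tx, (*batch, 3))[..., None, :]
    rx = jnp.broadcast_to(rx, (*batch, 3))[..., None, :]
    points = jnp.broadcast_to(points, (*batch, *points.shape[-2:]))
    return jnp.concatenate([tx, points, rx], axis=-2)


def path_lengths(paths):
    # Sum of segment lengths along the vertex axis; result has shape (...,).
    return jnp.linalg.norm(jnp.diff(paths, axis=-2), axis=-1).sum(axis=-1)


# Two transmitters, one shared receiver (broadcast), one reflection on z = 0.
tx = jnp.array([[0.0, 0.0, 1.0], [0.0, 0.0, 3.0]])  # (2, 3)
rx = jnp.array([2.0, 0.0, 1.0])  # (3,)
points = jnp.array([[[1.0, 0.0, 0.0]], [[1.5, 0.0, 0.0]]])  # (2, 1, 3)

paths = full_paths(tx, rx, points)
print(paths.shape)  # (2, 3, 3)
print(path_lengths(paths))  # approximately [2.828, 4.472], i.e. [2*sqrt(2), sqrt(20)]
```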

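The implicit_diff option relies on the implicit function theorem. If t* minimizes the path length L(t, θ), then ∇_t L(t*, θ) = 0, and differentiating this identity gives dt*/dθ = -(∇²_tt L)^-1 ∇²_tθ L, which needs only the converged solution and no stored iterates. It is also why accuracy degrades when the optimizer has not converged: the identity only holds at a stationary point. The sketch below applies this textbook formula to one reflection on the plane z = 0, with θ the transmitter height h; it illustrates the idea and is not the fpt-jax code. The names `path_length`, `implicit_grad` and `RX` are hypothetical.

```python
import jax
import jax.numpy as jnp

RX = jnp.array([2.0, 0.0, 1.0])  # hypothetical fixed receiver


def path_length(t, h):
    # One reflection on z = 0: tx = (0, 0, h), interaction point (t[0], t[1], 0).
    point = jnp.concatenate([t, jnp.zeros(1)])
    tx = h * jnp.array([0.0, 0.0, 1.0])
    return jnp.linalg.norm(point - tx) + jnp.linalg.norm(RX - point)


def implicit_grad(t_star, h):
    # Implicit function theorem at a stationary point (grad_t L(t*, h) = 0):
    #   dt*/dh = -(d2L/dt2)^-1 (d2L/dt dh)
    # Only t* is needed; no optimizer iterates are stored.
    hess_tt = jax.hessian(path_length, argnums=0)(t_star, h)  # (2, 2)
    mixed = jax.jacobian(jax.grad(path_length, argnums=0), argnums=1)(t_star, h)  # (2,)
    return -jnp.linalg.solve(hess_tt, mixed)


h = 1.0
t_star = jnp.array([1.0, 0.0])  # converged reflection point for h = 1
print(implicit_grad(t_star, h))  # approximately [0.5, 0.0]

# Cross-check: the mirror-image construction gives x*(h) = 2h / (h + 1).
print(jax.grad(lambda h: 2 * h / (h + 1))(h))  # 0.5
```

Both routes give dx*/dh = 0.5 at h = 1, but the implicit one never differentiates through optimizer iterations, which is the memory saving the docstring refers to.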
This algorithm is also available within DiffeRT, our differentiable ray tracing library for radio propagation.

Getting help

For any question about the method or its implementation, make sure to first read the related paper.

If you want to report a bug in this library or the underlying algorithm, please open an issue on this GitHub repository. If you want to request a new feature, please consider opening an issue on DiffeRT's GitHub repository instead.

Citing

If you use this library in your research, please cite our paper:

@misc{eertmans2025fpt,
  title         = {Fast, Differentiable, GPU-Accelerated Ray Tracing for Multiple Diffraction and Reflection Paths},
  author        = {Jérome Eertmans and Sophie Lequeu and Benoît Legat and Laurent Jacques and Claude Oestges},
  year          = 2025,
  url           = {https://arxiv.org/abs/2510.16172},
  eprint        = {2510.16172},
  archiveprefix = {arXiv},
  primaryclass  = {eess.SP}
}