This library provides a PyTorch implementation of Taylor mode automatic differentiation, a generalization of forward mode to higher-order derivatives.
It is similar to JAX's Taylor mode (jax.experimental.jet).
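In a nutshell (this is the standard Taylor-mode semantics, not notation specific to this package): given a truncated Taylor expansion of the input, all output coefficients are propagated through the function in a single forward pass,

$$
x(t) = \sum_{k=0}^{K} x_k t^k
\quad\Longrightarrow\quad
f(x(t)) = \sum_{k=0}^{K} y_k t^k + \mathcal{O}(t^{K+1}),
$$

where ordinary first-order forward mode is the special case $K = 1$ with $y_1 = J_f(x_0)\, x_1$.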
The repository also hosts the Python functionality, experiments, and LaTeX source for our NeurIPS 2025 paper "Collapsing Taylor Mode Automatic Differentiation", which shows how to further accelerate Taylor mode for many practical differential operators.
🔪 Warning: expect rough edges! 🔪
This is a research prototype with various limitations (e.g., limited operator coverage). We highly recommend double-checking your results against PyTorch's autodiff. Please help us improve the package by providing feedback, filing issues, and opening pull requests.
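For low orders, such a cross-check is straightforward with nested forward-mode derivatives from `torch.func` (a minimal sketch using only built-in PyTorch autodiff; the function `f`, point `x`, and direction `v` below are placeholders, and the jet package's own API is described in the documentation):

```python
import torch
from torch.func import jvp

# Placeholder function, evaluation point, and direction for the check.
def f(x):
    return torch.sin(x).sum()

x = torch.randn(3)
v = torch.randn(3)

# First-order directional derivative: d/dt f(x + t*v) at t = 0.
_, df = jvp(f, (x,), (v,))

# Second order: differentiate the first-order derivative once more along v.
_, d2f = jvp(lambda x_: jvp(f, (x_,), (v,))[1], (x,), (v,))

# Taylor coefficients of t -> f(x + t*v): degree 1 is df, degree 2 is d2f / 2.
# These are the values a degree-2 Taylor-mode evaluation should reproduce.
print(df, d2f / 2)
```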
```bash
pip install jet-for-pytorch
```

See the documentation.
If you find the jet package useful for your research, consider citing:
```bibtex
@inproceedings{dangel2025collapsing,
  title     = {Collapsing Taylor Mode Automatic Differentiation},
  author    = {Felix Dangel and Tim Siebert and Marius Zeinhofer and Andrea Walther},
  year      = {2025},
  booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
}
```