🚀 Feature
Forward mode automatic differentiation (Forward AD) is useful for computing Jacobian-vector products of mappings 𝕂ⁿ → 𝕂ᵐ, especially when m ≫ n.

Forward AD was recently added to PyTorch and we need more rule coverage. The process of adding a forward AD rule is similar to the one for adding reverse AD rules.

When implementing a new rule, always start with `tools/autograd/derivatives.yaml`; this is the file where all autograd rules are registered. Let's look at the rule for `linalg_inv_ex`.

Let `A = self` and `C = inverse`; then the reverse AD rule is "A̅ = −Cᴴ C̅ Cᴴ" and the forward AD rule is "Ċ = −C Ȧ C". In the `derivatives.yaml` schema for this operator, `self:` includes the code to compute A̅ and `inverse:` has the code for computing Ċ. The tangents of input variables have the suffix `_t` in the code.

If the code for the derivative rule is too complex to fit in one line of the `derivatives.yaml` file, then a C++ function can be declared in `torch/csrc/autograd/FunctionsManual.h`, with the rule implemented in `torch/csrc/autograd/FunctionsManual.cpp`.

Tests of the forward AD support of a specific function can be enabled with `supports_forward_ad=True` in the OpInfo of the corresponding operator in `torch/testing/_internal/common_methods_invocations.py`.

Some useful resources for rules and derivations:
M. Giles "An extended collection of matrix derivative results for forward and reverse mode automatic differentiation"
https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
Seth Axen's blogpost on LU decomposition:
https://sethaxen.com/blog/2021/02/differentiating-the-lu-decomposition/
ChainRules.jl tutorial on deriving rules:
https://juliadiff.org/ChainRulesCore.jl/dev/arrays.html
Jin-Guo Liu's blogpost on SVD differentiation (focusing on complex input case):
https://giggleliu.github.io/2019/04/02/einsumbp.html
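
As a quick sanity check of the `linalg_inv_ex` rules quoted earlier, here is a minimal sketch in plain Python (no PyTorch required; this is not how PyTorch's own tests work). The forward rule follows from differentiating A C = I, which gives Ȧ C + A Ċ = 0 and hence Ċ = −C Ȧ C. The sketch compares that formula against a central finite difference on a 2×2 complex matrix and also checks that the reverse rule is its adjoint, i.e. Re tr(C̅ᴴ Ċ) = Re tr(A̅ᴴ Ȧ). All helper names (`matmul`, `inv2`, `forward_rule`, `reverse_rule`, ...) and the test matrices are assumptions made up for this illustration.

```python
# Pure-Python check of the linalg_inv_ex rules: Ċ = −C Ȧ C and A̅ = −Cᴴ C̅ Cᴴ.
# Every helper below is hypothetical and exists only for this sketch.

def matmul(X, Y):
    """Product of two matrices stored as nested lists."""
    return [[sum(X[i][p] * Y[p][j] for p in range(len(Y)))
             for j in range(len(Y[0]))] for i in range(len(X))]

def conj_transpose(X):
    """Conjugate transpose Xᴴ."""
    return [[X[j][i].conjugate() for j in range(len(X))]
            for i in range(len(X[0]))]

def neg(X):
    return [[-x for x in row] for row in X]

def inv2(X):
    """Closed-form inverse of a 2x2 matrix."""
    (a, b), (c, d) = X
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def inner(X, Y):
    """Real inner product Re tr(Xᴴ Y)."""
    return sum((x.conjugate() * y).real
               for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

def forward_rule(C, A_t):
    """Tangent of C = inv(A) given the input tangent A_t: −C A_t C."""
    return neg(matmul(matmul(C, A_t), C))

def reverse_rule(C, C_bar):
    """Gradient w.r.t. A given the output gradient C_bar: −Cᴴ C_bar Cᴴ."""
    CH = conj_transpose(C)
    return neg(matmul(matmul(CH, C_bar), CH))

# Arbitrary test data (assumed values, chosen so that A is invertible).
A = [[2 + 1j, 1 - 1j], [0.5j, 3 + 0j]]
A_t = [[0.3 - 0.2j, -1 + 0j], [0.7 + 0j, 0.1 + 0.4j]]
C_bar = [[1 + 0j, -0.5 + 2j], [0.25j, -1.5 + 0j]]

C = inv2(A)
C_t = forward_rule(C, A_t)

# Central finite difference of inv(A + eps * A_t) with respect to eps.
eps = 1e-6
A_plus = [[a + eps * t for a, t in zip(ra, rt)] for ra, rt in zip(A, A_t)]
A_minus = [[a - eps * t for a, t in zip(ra, rt)] for ra, rt in zip(A, A_t)]
fd = [[(p - m) / (2 * eps) for p, m in zip(rp, rm)]
      for rp, rm in zip(inv2(A_plus), inv2(A_minus))]
max_err = max(abs(x - y) for rx, ry in zip(C_t, fd) for x, y in zip(rx, ry))

# Adjoint consistency between the forward and reverse rules.
A_bar = reverse_rule(C, C_bar)
lhs = inner(C_bar, C_t)
rhs = inner(A_bar, A_t)

print("max |forward rule - finite difference| =", max_err)
print("adjoint identity gap =", abs(lhs - rhs))
```

Both printed values should be close to zero. Within PyTorch itself, the corresponding checks run through the OpInfo tests once `supports_forward_ad=True` is set for the operator.
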
Here's a task list for linear algebra functions:
torch.linalg.inv, torch.linalg.inv_ex
PR: Add forward mode differentiation for inverse and solve #62160
torch.linalg.solve
PR: Add forward mode differentiation for inverse and solve #62160
torch.linalg.cholesky, torch.linalg.cholesky_ex
PR: Add forward mode differentiation for torch.linalg.cholesky and transpose #62159
torch.linalg.eigh, torch.linalg.eigvalsh
PR: Add forward AD for torch.linalg.eigh #62163
torch.linalg.det
torch.linalg.slogdet
torch.linalg.qr
PR: torch.linalg.qr: forward AD support #67268
torch.linalg.eig, torch.linalg.eigvals
PRs: Added forward derivatives for neg, diag, inverse, linalg_eig #67837, Correct forward AD for linalg.eig and add checks #70527
torch.linalg.svd, torch.linalg.svdvals
PR in progress: Implement forward AD for linalg.svd and improve svd_backward #70253
torch.linalg.lstsq
PR: torch.linalg.lstsq: forward/backward AD support #65054
torch.linalg.solve_triangular
PR: Add linalg.solve_triangular #63568
torch.cholesky_solve, torch.lu_solve, torch.triangular_solve
PR: *_solve methods: implements forward AD #65546
torch.cholesky_inverse
torch.lu_unpack
PR in progress: torch.lu_unpack: forward AD support #64810

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233