
Forward mode AD for linear algebra functions #64545

@IvanYashchuk

Description


🚀 Feature

Forward-mode automatic differentiation (forward AD) is useful for computing Jacobian-vector products (JVPs) of mappings 𝕂ⁿ → 𝕂ᵐ, especially when m ≫ n: a single forward pass yields J·v for all m outputs at once.
Forward AD was recently added to PyTorch, and many more operators need forward AD rules. Adding a forward AD rule works much like adding a reverse AD rule.
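To make the JVP idea concrete, here is a toy dual-number sketch in plain Python. It is not how PyTorch implements forward AD. Each input carries a primal value and a tangent, and every operation propagates both. `Dual`, `jvp`, `sin` and `f` are hypothetical names chosen for illustration.

```python
import math
from dataclasses import dataclass


@dataclass
class Dual:
    """A number a + b*eps with eps**2 == 0; `primal` is a, `tangent` is b."""
    primal: float
    tangent: float

    def __add__(self, other):
        other = _lift(other)
        return Dual(self.primal + other.primal, self.tangent + other.tangent)

    __radd__ = __add__

    def __mul__(self, other):
        other = _lift(other)
        # Product rule: (a + b*eps)(c + d*eps) = ac + (ad + bc)*eps
        return Dual(self.primal * other.primal,
                    self.primal * other.tangent + self.tangent * other.primal)

    __rmul__ = __mul__


def _lift(x):
    """Treat plain numbers as constants (zero tangent)."""
    return x if isinstance(x, Dual) else Dual(float(x), 0.0)


def sin(x):
    # Chain rule: d sin(u) = cos(u) * du
    return Dual(math.sin(x.primal), math.cos(x.primal) * x.tangent)


def jvp(f, x, v):
    """One forward pass: returns f(x) and J(x) @ v."""
    outs = f([Dual(xi, vi) for xi, vi in zip(x, v)])
    return [o.primal for o in outs], [o.tangent for o in outs]


def f(x):
    # Example mapping R^2 -> R^4 (m = 4 outputs, n = 2 inputs).
    x0, x1 = x
    return [x0 * x1, sin(x0), x0 + x1, 3 * x1 * x1]


primals, tangents = jvp(f, [2.0, 3.0], [1.0, 0.0])
print(primals, tangents)
```

With v = (1, 0), `tangents` is the first column of the 4×2 Jacobian, namely [3.0, cos(2), 1.0, 0.0]. The full Jacobian needs n passes, one per basis vector. That is why forward mode pays off when the number of inputs is small compared to the number of outputs.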

When implementing a new rule, always start with tools/autograd/derivatives.yaml; this is the file where all autograd rules are registered.

Let's look at the rule for linalg_inv_ex:

- name: linalg_inv_ex(Tensor self, *, bool check_errors=False) -> (Tensor inverse, Tensor info)
  self: -at::matmul(inverse.conj().transpose(-2, -1), at::matmul(grad, inverse.conj().transpose(-2, -1)))
  inverse: -at::matmul(at::matmul(inverse, self_t), inverse)

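Let A = self and C = inverse. Then the reverse AD rule is A̅ = −Cᴴ C̅ Cᴴ (where C̅ is grad in the code) and the forward AD rule is Ċ = −C Ȧ C (where Ȧ is self_t).
In the entry above, the line starting with self: holds the code that computes A̅, and the line starting with inverse: holds the code that computes Ċ.
In the code, the tangent of an input variable is named after that input with the suffix _t (here, self_t).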
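Both rules can be sanity-checked numerically. The sketch below uses plain-Python 2×2 real matrices, so Cᴴ reduces to Cᵀ, and all helper names are hypothetical. It compares the forward rule Ċ = −C Ȧ C against a central finite difference of the inverse. It also checks that the reverse rule is the adjoint of the forward rule, i.e. ⟨C̅, Ċ⟩ = ⟨A̅, Ȧ⟩ for the Frobenius inner product. This is a generic check of the math, not PyTorch's test machinery.

```python
# Nested-list 2x2 real matrices; all helpers are illustrative only.

def matmul(X, Y):
    """Matrix product X @ Y."""
    return [[sum(X[i][k] * Y[k][j] for k in range(len(Y))) for j in range(len(Y[0]))]
            for i in range(len(X))]

def transpose(X):
    return [list(row) for row in zip(*X)]

def neg(X):
    return [[-x for x in row] for row in X]

def lincomb(alpha, X, beta, Y):
    """Elementwise alpha * X + beta * Y."""
    return [[alpha * x + beta * y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]

def inner(X, Y):
    """Real Frobenius inner product: sum of elementwise products."""
    return sum(x * y for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

def max_abs_diff(X, Y):
    return max(abs(x - y) for rx, ry in zip(X, Y) for x, y in zip(rx, ry))

def inv2(A):
    """Closed-form inverse of a 2x2 matrix."""
    (a, b), (c, d) = A
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]

def inv_jvp(A, A_dot):
    """Forward rule: C_dot = -C @ A_dot @ C, with C = inv(A)."""
    C = inv2(A)
    return neg(matmul(matmul(C, A_dot), C))

def inv_vjp(A, C_bar):
    """Reverse rule: A_bar = -C^H @ C_bar @ C^H (C^H = C^T for real C)."""
    Ct = transpose(inv2(A))
    return neg(matmul(Ct, matmul(C_bar, Ct)))

def fd_jvp(A, A_dot, h=1e-6):
    """Central difference of inv along direction A_dot."""
    plus = inv2(lincomb(1.0, A, h, A_dot))
    minus = inv2(lincomb(1.0, A, -h, A_dot))
    return lincomb(1 / (2 * h), plus, -1 / (2 * h), minus)

A = [[4.0, 1.0], [2.0, 3.0]]
A_dot = [[0.5, -1.0], [0.25, 2.0]]   # input tangent
C_bar = [[1.0, 0.0], [-2.0, 0.5]]    # output cotangent

jvp_err = max_abs_diff(inv_jvp(A, A_dot), fd_jvp(A, A_dot))
adjoint_gap = abs(inner(C_bar, inv_jvp(A, A_dot)) - inner(inv_vjp(A, C_bar), A_dot))
print(jvp_err, adjoint_gap)
```

The adjoint identity holds because tr(C̅ᵀ C Ȧ C) = tr(C C̅ᵀ C Ȧ) by cyclicity of the trace. A forward rule and a reverse rule for the same function must always agree in this sense.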

If the code for a derivative rule is too complex to fit on one line of derivatives.yaml, add a C++ function for it instead: declare the function in torch/csrc/autograd/FunctionsManual.h and implement it in torch/csrc/autograd/FunctionsManual.cpp.

Forward AD tests for a specific function are enabled by setting supports_forward_ad=True in the OpInfo of the corresponding operator in torch/testing/_internal/common_methods_invocations.py.
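For a quick local check outside the OpInfo suite, a sketch like the following can help. It assumes a recent PyTorch release in which torch.linalg.inv supports forward AD. It pushes a tangent Ȧ through torch.linalg.inv with torch.autograd.forward_ad and compares the result with Ċ = −C Ȧ C from the linalg_inv_ex entry. It then runs torch.autograd.gradcheck with check_forward_ad=True. This is a spot check only, not a substitute for enabling the OpInfo tests.

```python
import torch
import torch.autograd.forward_ad as fwAD

torch.manual_seed(0)
# Double precision keeps finite-difference checks accurate.
# B @ B.T + 3I is symmetric positive definite, hence safely invertible.
B = torch.randn(3, 3, dtype=torch.float64)
A = B @ B.T + 3 * torch.eye(3, dtype=torch.float64)
A_dot = torch.randn(3, 3, dtype=torch.float64)  # tangent of A

with fwAD.dual_level():
    # Attach the tangent to A and let linalg.inv's forward AD rule propagate it.
    C_dual = torch.linalg.inv(fwAD.make_dual(A, A_dot))
    C, C_dot = fwAD.unpack_dual(C_dual)
    # Compare with the closed-form forward rule C_dot = -C @ A_dot @ C.
    rule_matches = torch.allclose(C_dot, -C @ A_dot @ C)

# gradcheck compares AD against finite differences; check_forward_ad=True
# adds the forward-mode comparison on top of the default reverse-mode one.
passed = torch.autograd.gradcheck(
    torch.linalg.inv,
    (A.clone().requires_grad_(),),
    check_forward_ad=True,
)
print(rule_matches, passed)
```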

Some useful resources for rules and derivations:

Here's a task list for linear algebra functions:

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233

Metadata


Assignees

No one assigned

    Labels

    module: autograd (Related to torch.autograd, and the autograd engine in general)
    module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
