Add forward AD for torch.linalg.eigh#62163

Closed
IvanYashchuk wants to merge 12 commits into
pytorch:masterfrom
IvanYashchuk:eigh-jvp

Conversation

@IvanYashchuk
Collaborator

@IvanYashchuk IvanYashchuk commented Jul 25, 2021

This PR adds forward mode differentiation for torch.linalg.eigh and a few other functions required for tests to pass.

For some reason running tests for torch.linalg.eigvalsh and complex torch.linalg.eigh hangs. These tests are skipped for now.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233

@IvanYashchuk IvanYashchuk added module: autograd Related to torch.autograd, and the autograd engine in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul labels Jul 25, 2021
@IvanYashchuk IvanYashchuk requested a review from albanD July 25, 2021 14:44
@IvanYashchuk IvanYashchuk requested a review from soulitzer as a code owner July 25, 2021 14:44
@facebook-github-bot
Contributor

facebook-github-bot commented Jul 25, 2021

💊 CI failures summary and remediations

As of commit c3fc5ed (more details on the Dr. CI page):


  • 1/1 failures introduced in this PR

🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See GitHub Actions build linux-bionic-py3.8-gcc9-coverage / test (distributed, 1, 1, linux.2xlarge) (1/1)

Step: "Test PyTorch"

2021-09-13T16:15:29.3957201Z AssertionError: Fa...true : Scalars failed to compare as equal! 0 != -6
2021-09-13T16:15:29.3950004Z ----------------------------------------------------------------------
2021-09-13T16:15:29.3950475Z Traceback (most recent call last):
2021-09-13T16:15:29.3951287Z   File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 418, in wrapper
2021-09-13T16:15:29.3951926Z     self._join_processes(fn)
2021-09-13T16:15:29.3952766Z   File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 637, in _join_processes
2021-09-13T16:15:29.3953446Z     self._check_return_codes(elapsed_time)
2021-09-13T16:15:29.3954290Z   File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_distributed.py", line 692, in _check_return_codes
2021-09-13T16:15:29.3954954Z     self.assertEqual(
2021-09-13T16:15:29.3955705Z   File "/opt/conda/lib/python3.8/site-packages/torch/testing/_internal/common_utils.py", line 1676, in assertEqual
2021-09-13T16:15:29.3956453Z     super().assertTrue(result, msg=self._get_assert_msg(msg, debug_msg=debug_msg))
2021-09-13T16:15:29.3957201Z AssertionError: False is not true : Scalars failed to compare as equal! 0 != -6
2021-09-13T16:15:29.3957927Z Expect process 1 exit code to match Process 0 exit code of -6, but got 0
2021-09-13T16:15:29.3958263Z 
2021-09-13T16:15:29.3958732Z ----------------------------------------------------------------------
2021-09-13T16:15:29.3959137Z Ran 85 tests in 116.391s
2021-09-13T16:15:29.3959342Z 
2021-09-13T16:15:29.3959759Z FAILED (failures=1, skipped=31)
2021-09-13T16:15:29.3960009Z 
2021-09-13T16:15:29.3960332Z Generating XML reports...
2021-09-13T16:15:29.3975698Z Generated XML report: test-reports/python-unittest/distributed.test_c10d_gloo/TEST-CommTest-20210913161332.xml
2021-09-13T16:15:29.4031334Z Generated XML report: test-reports/python-unittest/distributed.test_c10d_gloo/TEST-DistributedDataParallelTest-20210913161332.xml

This comment was automatically generated by Dr. CI.

@jbschlosser jbschlosser added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 26, 2021
@IvanYashchuk
Collaborator Author

@albanD, CI fails with "NotImplementedError: Trying to use forward AD with _reshape_alias that does not support it." Do you have any idea how to enable forward AD for _reshape_alias? There is no such entry in the derivatives.yaml file.

Collaborator

@lezcano lezcano left a comment

A few notes on forward mode AD for maps with constrained inputs.

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated

auto F = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);

Collaborator

It's better to divide by F than to compute the inverse explicitly and then multiply.
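
A minimal sketch of this suggestion, written in plain Python rather than ATen. The helper names (`gap_matrix`, `elementwise`) and the sample values are illustrative assumptions, not code from the PR. With +inf on the diagonal of the gap matrix, a single element-wise division already zeroes the diagonal, while the reciprocal-then-multiply form needs an extra pass and rounds twice.

```python
import math

def gap_matrix(eigenvalues):
    """E[i][j] = eigenvalues[j] - eigenvalues[i], with +inf on the diagonal."""
    n = len(eigenvalues)
    return [[math.inf if i == j else eigenvalues[j] - eigenvalues[i]
             for j in range(n)]
            for i in range(n)]

def elementwise(op, X, Y):
    """Apply a binary op entry by entry to two equally shaped matrices."""
    return [[op(x, y) for x, y in zip(xr, yr)] for xr, yr in zip(X, Y)]

# Illustrative values only (assumed, not taken from the PR).
eigenvalues = [1.0, 2.0, 4.0]
M = [[5.0, 6.0, 9.0],
     [6.0, 7.0, 8.0],
     [9.0, 8.0, 3.0]]

E = gap_matrix(eigenvalues)

# Suggested form: one division per entry; x / inf == 0.0 clears the diagonal.
divided = elementwise(lambda m, e: m / e, M, E)

# Original form: reciprocal first (inf ** -1 == 0.0), then a multiplication.
reciprocal = [[e ** -1 for e in row] for row in E]
multiplied = elementwise(lambda m, r: m * r, M, reciprocal)
```

Both forms agree up to rounding, so the change is about cost and clarity rather than a different result.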

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
// https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
// Section 3.1 Eigenvalues and eigenvectors

auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors);

Collaborator

@lezcano lezcano Jul 26, 2021

I believe that the input tangent needs to be symmetrised as: const auto in = 0.5 * (input_tangent + input_tangent.transpose(-2, -1).conj()).

Denote by Her(n) the n x n Hermitian matrices and by U(n) the unitary matrices. We have eigh : Her(n) -> R^n x U(n). This means that the differential of eigh maps the tangent space of Her(n) to the tangent space of R^n times the tangent space of U(n) (at the output matrices). Now, the tangent space of Her(n) is Her(n) itself, as Her(n) is just a linear subspace. As such, the input tangent needs to be Hermitian for this function to make sense.

For a matrix A, (A + A^H)/2 is the orthogonal projection of A onto the space Her(n). The theorem that formalises all this is the one that says that the differential of a map on an embedded manifold is the differential of the map on the total space restricted to the tangent space of the embedded manifold.

I wrote all this down in a comment in eigh_backward. In general, it would be good if the ways we compute forward and backward were somewhat equivalent (i.e. one is "taking the transpose" of the other).
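
The projection described above can be sketched in plain Python. This is only an illustration under assumptions: the helper names (`conj_transpose`, `hermitian_part`, `real_inner`) and the sample matrix are hypothetical and do not come from the PR. It checks the properties an orthogonal projection onto Her(n) should have: the result is Hermitian, projecting twice changes nothing, and the discarded part is orthogonal to the kept part under the real Frobenius inner product.

```python
def conj_transpose(A):
    """Return A^H for a square matrix stored as nested lists."""
    return [[entry.conjugate() for entry in col] for col in zip(*A)]

def hermitian_part(A):
    """Projection of A onto the Hermitian matrices: (A + A^H) / 2."""
    AH = conj_transpose(A)
    return [[0.5 * (a + b) for a, b in zip(ra, rb)] for ra, rb in zip(A, AH)]

def real_inner(A, B):
    """Real Frobenius inner product Re tr(A^H B)."""
    return sum((a.conjugate() * b).real
               for ra, rb in zip(A, B) for a, b in zip(ra, rb))

# An arbitrary non-Hermitian "tangent" (illustrative values).
tangent = [[1 + 2j, 3 - 1j],
           [0.5j, 4 + 0j]]

projected = hermitian_part(tangent)
# What the projection throws away: the skew-Hermitian part of the tangent.
residual = [[t - p for t, p in zip(rt, rp)]
            for rt, rp in zip(tangent, projected)]

print(projected == conj_transpose(projected))  # Hermitian
print(hermitian_part(projected) == projected)  # idempotent
print(abs(real_inner(projected, residual)) < 1e-12)  # orthogonal split
```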

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
F = F.pow(-1);

auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors);

Collaborator

@lezcano lezcano Jul 26, 2021

input_tangent should be projected onto the Hermitian matrices (see comment above).

@IvanYashchuk IvanYashchuk marked this pull request as draft July 28, 2021 18:09
@pytorch-probot

pytorch-probot Bot commented Sep 6, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/IvanYashchuk/pytorch/blob/c3fc5edbbfe091f2165a89afcd862419d740babb/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-bionic-py3.8-gcc9-coverage ciflow/all, ciflow/coverage, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda10.2-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
parallelnative-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
paralleltbb-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
puretorch-linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/linux 🚫 skipped
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@pytorch-probot pytorch-probot Bot assigned pytorchbot and unassigned pytorchbot Sep 6, 2021
@codecov

codecov Bot commented Sep 6, 2021

Codecov Report

Merging #62163 (c3fc5ed) into master (b37503e) will decrease coverage by 4.59%.
The diff coverage is n/a.

@@            Coverage Diff             @@
##           master   #62163      +/-   ##
==========================================
- Coverage   66.60%   62.00%   -4.60%     
==========================================
  Files         716      716              
  Lines       92689    92689              
==========================================
- Hits        61735    57472    -4263     
- Misses      30954    35217    +4263     

@IvanYashchuk IvanYashchuk marked this pull request as ready for review September 7, 2021 07:41
@IvanYashchuk
Collaborator Author

CI is green now. Thank you, @lezcano, I improved the PR with your suggestions. Could you take another look please?

@albanD, could you please review this PR?

Collaborator

@albanD albanD left a comment

LGTM
I'll wait for @lezcano's final approval and then I'll land.

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());

auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
auto eigenvectors_tangent = at::matmul(eigenvectors, tmp.div(E));

Collaborator

nit: you can return directly here to avoid the extra assignment.
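
To make the formula in the snippet above concrete, here is a small sketch in plain Python. It is not the code merged in this PR. It assumes a real symmetric 2x2 input with distinct eigenvalues and uses a closed-form `eigh_2x2` as a stand-in for torch.linalg.eigh. All helper names are hypothetical. The tangent is first projected onto the symmetric matrices, then P = V^T dA V gives the eigenvalue tangents on its diagonal and the eigenvector tangents as V (P / E), where E[i][j] = L[j] - L[i] with +inf on the diagonal. A central finite difference along the symmetric direction serves as a reference.

```python
import math

def matmul(X, Y):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
            for row in X]

def transpose(X):
    return [list(col) for col in zip(*X)]

def eigh_2x2(A):
    """Closed-form eigh for a real symmetric 2x2 matrix (assumed stand-in).

    Returns ascending eigenvalues and a matrix whose columns are eigenvectors.
    """
    a, b, c = A[0][0], A[0][1], A[1][1]
    mean, radius = 0.5 * (a + c), math.hypot(0.5 * (a - c), b)
    theta = 0.5 * math.atan2(2.0 * b, a - c)
    ct, st = math.cos(theta), math.sin(theta)
    return [mean - radius, mean + radius], [[-st, ct], [ct, st]]

def eigh_jvp(A, dA):
    """Tangents (dL, dV) of eigh at A in the direction dA."""
    L, V = eigh_2x2(A)
    n = len(L)
    # Project the tangent onto the symmetric matrices first.
    dA_sym = [[0.5 * (dA[i][j] + dA[j][i]) for j in range(n)] for i in range(n)]
    P = matmul(matmul(transpose(V), dA_sym), V)
    # E[i][j] = L[j] - L[i]; +inf on the diagonal makes P / E vanish there.
    E = [[math.inf if i == j else L[j] - L[i] for j in range(n)]
         for i in range(n)]
    dL = [P[i][i] for i in range(n)]
    dV = matmul(V, [[P[i][j] / E[i][j] for j in range(n)] for i in range(n)])
    return dL, dV

def central_difference(A, dA_sym, h=1e-6):
    """Finite-difference reference along a symmetric direction."""
    plus = [[A[i][j] + h * dA_sym[i][j] for j in range(2)] for i in range(2)]
    minus = [[A[i][j] - h * dA_sym[i][j] for j in range(2)] for i in range(2)]
    (Lp, Vp), (Lm, Vm) = eigh_2x2(plus), eigh_2x2(minus)
    dL = [(p - m) / (2 * h) for p, m in zip(Lp, Lm)]
    dV = [[(Vp[i][j] - Vm[i][j]) / (2 * h) for j in range(2)] for i in range(2)]
    return dL, dV

A = [[2.0, 1.0], [1.0, 3.0]]
dA = [[0.3, 0.7], [-0.1, 0.5]]       # deliberately not symmetric
dA_sym = [[0.3, 0.3], [0.3, 0.5]]    # its symmetric part

dL, dV = eigh_jvp(A, dA)
dL_fd, dV_fd = central_difference(A, dA_sym)
```

The analytic tangents `dL`, `dV` should agree with `dL_fd`, `dV_fd` up to finite-difference error. The formula breaks down when eigenvalues coincide, since the off-diagonal entries of E become zero.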

decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck for complex hangs for this function
SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),

Collaborator

Do we have an issue open tracking this? Do we want one if not?

Collaborator Author

There is no open issue about it. I added a NotImplementedError for complex inputs for now; we'll enable it later.

Collaborator

@lezcano lezcano left a comment

LGTM modulo @albanD's comments.
The only point is that it might be worth landing this after the support for several outputs in AD mode, as at the moment this AD does the same computations twice: once for the eigenvalues and once for the eigenvectors.

@albanD
Collaborator

albanD commented Sep 7, 2021

Now that this is ready, I think we can just add this to master and merge the implementation when the multi-output support is added.

@pytorch-probot pytorch-probot Bot assigned pytorchbot and unassigned pytorchbot Sep 13, 2021
@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@albanD merged this pull request in 0aef44c.

@IvanYashchuk IvanYashchuk deleted the eigh-jvp branch September 14, 2021 12:39
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Summary:
This PR adds forward mode differentiation for `torch.linalg.eigh` and a few other functions required for tests to pass.

For some reason running tests for `torch.linalg.eigvalsh` and complex `torch.linalg.eigh` hangs. These tests are skipped for now.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#62163

Reviewed By: jbschlosser

Differential Revision: D30903988

Pulled By: albanD

fbshipit-source-id: d6a74adb9e6d2f4be8ac707848ecabf06d629823

Labels

cla signed Merged module: autograd Related to torch.autograd, and the autograd engine in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

6 participants