Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Add forward mode differentiation for torch.linalg.cholesky and transpose#62159

Closed
IvanYashchuk wants to merge 5 commits into
pytorch:masterfrom
IvanYashchuk:cholesky-jvp
Closed

Add forward mode differentiation for torch.linalg.cholesky and transpose#62159
IvanYashchuk wants to merge 5 commits into
pytorch:masterfrom
IvanYashchuk:cholesky-jvp

Conversation

@IvanYashchuk
Copy link
Copy Markdown
Collaborator

@IvanYashchuk IvanYashchuk commented Jul 25, 2021

This PR adds forward mode differentiation for torch.linalg.cholesky, torch.linalg.cholesky_ex, and transpose functions.
Complex tests for Cholesky fail because for some reason the gradcheck sends matrices full of zeros to cholesky_jvp function.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @heitorschueroff @walterddr @IvanYashchuk @xwang233

@IvanYashchuk IvanYashchuk added module: autograd Related to torch.autograd, and the autograd engine in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul labels Jul 25, 2021
@IvanYashchuk IvanYashchuk requested a review from albanD July 25, 2021 09:29
@IvanYashchuk IvanYashchuk requested a review from soulitzer as a code owner July 25, 2021 09:29
@facebook-github-bot
Copy link
Copy Markdown
Contributor

facebook-github-bot commented Jul 25, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 14dcdec (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI (expand for details).Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.

Click here to manually regenerate this comment.

@jbschlosser jbschlosser added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jul 26, 2021
Copy link
Copy Markdown
Collaborator

@lezcano lezcano left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just left a comment. The rest of the code LGTM. Nice use of triangular_solve to compute the inverse :)

decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck for complex generates invalid inputs for this function
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this be because it does not create Hermitian inputs? It's not clear whether we can assume that the inputs are Hermitian or not. See #62163 (comment)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ho yes, I encountered that one but forgot to fix it.
We should update

return apply_to_c_inps(fn, lambda x: x + 0 * 1j), apply_to_c_inps(fn, lambda x: x * 1j)
to make it keep the original imag/real respectively instead of setting them to 0s.

Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!
Thanks

decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck for complex generates invalid inputs for this function
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ho yes, I encountered that one but forgot to fix it.
We should update

return apply_to_c_inps(fn, lambda x: x + 0 * 1j), apply_to_c_inps(fn, lambda x: x * 1j)
to make it keep the original imag/real respectively instead of setting them to 0s.


- name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)
self: grad.transpose(dim0, dim1)
result: auto_linear
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note that this will conflict with #59993
Not a big deal as the rebase will be simple, but we should not land them at the same time.

@pytorch-probot
Copy link
Copy Markdown

pytorch-probot Bot commented Sep 6, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/IvanYashchuk/pytorch/blob/14dcdecbd232848798202d46c1e5e2f42c526496/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

Workflows Labels (bold enabled) Status
Triggered Workflows
linux-bionic-py3.6-clang9 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla ✅ triggered
linux-bionic-py3.8-gcc9-coverage ciflow/all, ciflow/coverage, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc5.4 ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
linux-xenial-py3.6-gcc7-bazel-test ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux ✅ triggered
win-vs2019-cpu-py3 ciflow/all, ciflow/cpu, ciflow/default, ciflow/win ✅ triggered
win-vs2019-cuda10.1-py3 ciflow/all, ciflow/cuda, ciflow/default, ciflow/win ✅ triggered
Skipped Workflows
libtorch-linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
libtorch-linux-xenial-cuda11.3-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux 🚫 skipped
linux-bionic-cuda10.2-py3.9-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
linux-xenial-cuda10.2-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow 🚫 skipped
periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-linux-xenial-cuda11.1-py3.6-gcc7 ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled 🚫 skipped
periodic-win-vs2019-cuda11.1-py3 ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win 🚫 skipped
win-vs2019-cuda11.3-py3 ciflow/all, ciflow/cuda, ciflow/win 🚫 skipped

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@codecov
Copy link
Copy Markdown

codecov Bot commented Sep 6, 2021

Codecov Report

Merging #62159 (14dcdec) into master (544c8e6) will increase coverage by 0.08%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master   #62159      +/-   ##
==========================================
+ Coverage   66.63%   66.72%   +0.08%     
==========================================
  Files         707      707              
  Lines       92338    92338              
==========================================
+ Hits        61534    61612      +78     
+ Misses      30804    30726      -78     

@IvanYashchuk
Copy link
Copy Markdown
Collaborator Author

torch.linalg.cholesky now has the upper argument. I updated the code and OpInfo's sample inputs to test that.
@albanD, could you please take another look and merge if it's good to go?

Copy link
Copy Markdown
Collaborator

@albanD albanD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@facebook-github-bot
Copy link
Copy Markdown
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@lezcano
Copy link
Copy Markdown
Collaborator

lezcano commented Sep 8, 2021

A small note: We should have a look at the performance of solve_triangular once #63568 is merged (or even solve_triangular_out). If it's not much worse than matmul, we should do two solves rather than inverse + 2 matmuls.

@facebook-github-bot
Copy link
Copy Markdown
Contributor

@albanD merged this pull request in dd8f6ac.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
…ose (pytorch#62159)

Summary:
This PR adds forward mode differentiation for `torch.linalg.cholesky`, `torch.linalg.cholesky_ex`, and `transpose` functions.
Complex tests for Cholesky fail because for some reason the gradcheck sends matrices full of zeros to `cholesky_jvp` function.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#62159

Reviewed By: mrshenli

Differential Revision: D30776829

Pulled By: albanD

fbshipit-source-id: 32e5539ed6423eed8c18cce16271330ab0ea8d5e
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed Merged module: autograd Related to torch.autograd, and the autograd engine in general module: linear algebra Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul open source triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants