*_solve methods: implements forward AD #65546

Closed
nikitaved wants to merge 18 commits into master from nikitaved/solve_methods_forward_AD

@nikitaved
Collaborator

@nikitaved nikitaved commented Sep 23, 2021

This PR adds forward AD for `*_solve` methods.
Additionally, `cholesky_solve` gets an OpInfo and a fix for a bug in which wrong leading dimensions could be passed to LAPACK,
and `lu_solve` gets forward AD implemented with 2x `lu_solve` instead of 1x `lu_solve` + 2x `triangular_solve`.
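
For readers unfamiliar with forward AD of a linear solve, the following is a minimal pure-Python sketch of the textbook identity: if `A X = B`, then the tangent satisfies `dX = solve(A, dB - dA @ X)`. The helper names (`solve2`, `solve_jvp`, and so on) are hypothetical and are not taken from this PR, whose `generic_solve_jvp` is not shown in full here; whether the PR's code has exactly this shape is an assumption.

```python
# Sketch only: check the JVP identity for X = A^{-1} B on a 2x2 example.
# Every helper below is hypothetical and unrelated to PyTorch internals.

def matmul(P, Q):
    """Plain nested-list matrix product."""
    return [[sum(P[i][k] * Q[k][j] for k in range(len(Q)))
             for j in range(len(Q[0]))] for i in range(len(P))]

def sub(P, Q):
    """Elementwise P - Q."""
    return [[p - q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def add_scaled(P, Q, s):
    """Elementwise P + s * Q."""
    return [[p + s * q for p, q in zip(rp, rq)] for rp, rq in zip(P, Q)]

def solve2(A, B):
    """Solve A X = B for a 2x2 matrix A via the explicit inverse."""
    (a, b), (c, d) = A
    det = a * d - b * c
    inv = [[d / det, -b / det], [-c / det, a / det]]
    return matmul(inv, B)

def solve_jvp(A, B, dA, dB):
    """Tangent of X = solve(A, B): differentiate A X = B to get
    dA X + A dX = dB, hence dX = solve(A, dB - dA @ X)."""
    X = solve2(A, B)
    return solve2(A, sub(dB, matmul(dA, X)))

A = [[4.0, 1.0], [2.0, 3.0]]
B = [[1.0], [2.0]]
dA = [[0.5, -1.0], [0.25, 2.0]]   # tangent direction for A
dB = [[1.0], [-3.0]]              # tangent direction for B

analytic = solve_jvp(A, B, dA, dB)

# Central finite difference along the same direction, for comparison.
eps = 1e-6
plus = solve2(add_scaled(A, dA, eps), add_scaled(B, dB, eps))
minus = solve2(add_scaled(A, dA, -eps), add_scaled(B, dB, -eps))
numeric = [[(p - m) / (2 * eps) for p, m in zip(rp, rm)]
           for rp, rm in zip(plus, minus)]

max_err = max(abs(a - n)
              for ra, rn in zip(analytic, numeric)
              for a, n in zip(ra, rn))
print(max_err)
```

The tangent costs one extra solve with the same `A` on top of the primal solve, which is why a factorization-based solver can reuse its factors for both.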

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @lezcano @Varal7 @jianyuh @mruberry @walterddr @IvanYashchuk @xwang233

@pytorch-probot

pytorch-probot Bot commented Sep 23, 2021

CI Flow Status

⚛️ CI Flow

Ruleset - Version: v1
Ruleset - File: https://github.com/pytorch/pytorch/blob/99f07670fb688f8445d16f4eec4d0edbe882c424/.github/generated-ciflow-ruleset.json
PR ciflow labels: ciflow/default

| Workflows | Labels (bold enabled) | Status |
| --- | --- | --- |
| Triggered Workflows | | |
| linux-bionic-py3.6-clang9 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/noarch, ciflow/xla | ✅ triggered |
| linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-clang7-asan | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux, ciflow/sanitizers | ✅ triggered |
| linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| linux-xenial-py3.6-gcc7-bazel-test | ciflow/all, ciflow/bazel, ciflow/cpu, ciflow/default, ciflow/linux | ✅ triggered |
| win-vs2019-cpu-py3 | ciflow/all, ciflow/cpu, ciflow/default, ciflow/win | ✅ triggered |
| win-vs2019-cuda11.3-py3 | ciflow/all, ciflow/cuda, ciflow/default, ciflow/win | ✅ triggered |
| Skipped Workflows | | |
| libtorch-linux-xenial-cuda10.2-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| libtorch-linux-xenial-cuda11.3-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux | 🚫 skipped |
| linux-bionic-cuda10.2-py3.9-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow | 🚫 skipped |
| linux-xenial-cuda10.2-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/slow | 🚫 skipped |
| parallelnative-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |
| periodic-libtorch-linux-xenial-cuda11.1-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/libtorch, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-linux-xenial-cuda11.1-py3.6-gcc7 | ciflow/all, ciflow/cuda, ciflow/linux, ciflow/scheduled | 🚫 skipped |
| periodic-win-vs2019-cuda11.1-py3 | ciflow/all, ciflow/cuda, ciflow/scheduled, ciflow/win | 🚫 skipped |
| puretorch-linux-xenial-py3.6-gcc5.4 | ciflow/all, ciflow/cpu, ciflow/linux | 🚫 skipped |

You can add a comment to the PR and tag @pytorchbot with the following commands:
# ciflow rerun, "ciflow/default" will always be added automatically
@pytorchbot ciflow rerun

# ciflow rerun with additional labels "-l <ciflow/label_name>", which is equivalent to adding these labels manually and trigger the rerun
@pytorchbot ciflow rerun -l ciflow/scheduled -l ciflow/slow

For more information, please take a look at the CI Flow Wiki.

@facebook-github-bot
Contributor

facebook-github-bot commented Sep 23, 2021

🔗 Helpful links

💊 CI failures summary and remediations

As of commit 99f0767 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



1 job timed out:

  • pytorch_xla_linux_bionic_py3_6_clang9_test

❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_xla_linux_bionic_py3_6_clang9_test (1/1)

Step: "Test" (full log | diagnosis details | 🔁 rerun) ❄️

Oct 05 10:10:13 2021-10-05 10:10:13.244681: E t...r_lib.cc:561] Unknown: Could not start gRPC server
Oct 05 10:09:44 + python3 /var/lib/jenkins/workspace/xla/test/test_mp_mesh_reduce.py
Oct 05 10:09:48 + run_test python3 /var/lib/jenkins/workspace/xla/test/test_mp_sync_batch_norm.py
Oct 05 10:09:48 + python3 /var/lib/jenkins/workspace/xla/test/test_mp_sync_batch_norm.py
Oct 05 10:10:11 sync_bn1d_no_channel_test ok
Oct 05 10:10:11 sync_bn1d_multi_channel_test ok
Oct 05 10:10:11 sync_bn2d_test ok
Oct 05 10:10:11 sync_bn3d_test ok
Oct 05 10:10:12 + run_test python3 /var/lib/jenkins/workspace/xla/test/test_async_closures.py
Oct 05 10:10:12 + python3 /var/lib/jenkins/workspace/xla/test/test_async_closures.py
Oct 05 10:10:13 E1005 10:10:13.243885330  150245 server_chttp2.cc:40]        {"created":"@1633428613.243841873","description":"No address added out of total 1 resolved","file":"external/com_github_grpc_grpc/src/core/ext/transport/chttp2/server/chttp2_server.cc","file_line":395,"referenced_errors":[{"created":"@1633428613.243840524","description":"Failed to add any wildcard listeners","file":"external/com_github_grpc_grpc/src/core/lib/iomgr/tcp_server_posix.cc","file_line":342,"referenced_errors":[{"created":"@1633428613.243813841","description":"Address family not supported by protocol","errno":97,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/socket_utils_common_posix.cc","file_line":420,"os_error":"Address family not supported by protocol","syscall":"socket","target_address":"[::]:40912"},{"created":"@1633428613.243839958","description":"Unable to configure socket","fd":6,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":216,"referenced_errors":[{"created":"@1633428613.243832612","description":"Address already in use","errno":98,"file":"external/com_github_grpc_grpc/src/core/lib/iomgr/tcp_server_utils_posix_common.cc","file_line":189,"os_error":"Address already in use","syscall":"bind"}]}]}]}
Oct 05 10:10:13 2021-10-05 10:10:13.244681: E tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:561] Unknown: Could not start gRPC server


Too long with no output (exceeded 1h30m0s): context deadline exceeded


This comment was automatically generated by Dr. CI (expand for details). Follow this link to opt-out of these comments for your Pull Requests.

Please report bugs/suggestions to the (internal) Dr. CI Users group.


@nikitaved nikitaved marked this pull request as draft September 23, 2021 17:52
Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
@nikitaved nikitaved marked this pull request as ready for review September 23, 2021 19:54
@nikitaved nikitaved added the labels "module: linear algebra" (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul) and "module: autograd" (Related to torch.autograd, and the autograd engine in general) Sep 23, 2021
Collaborator

@lezcano lezcano left a comment

This is a great PR! I had thought about implementing something like this, but I never got around to doing it.

The code is very clean, thanks!

  scalar_t* A_working_ptr = &A_data[i * A_mat_stride];
  scalar_t* b_working_ptr = &b_data[i * b_mat_stride];
- lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, n, b_working_ptr, n, &info);
+ lapackCholeskySolve<scalar_t>(uplo, n, nrhs, A_working_ptr, ldab, b_working_ptr, ldab, &info);
Collaborator

Thanks for the fix.
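
As background on why the leading-dimension arguments matter: LAPACK's `xPOTRS` routines reject the call unless `LDA >= max(1, N)` and `LDB >= max(1, N)`, so passing `n` itself is illegal when `n == 0`. The sketch below mirrors those argument checks in pure Python. The function name `potrs_argument_info` is hypothetical, and since this excerpt does not show how `ldab` is defined, treating it as `max(1, n)` is an assumption rather than a statement about the PR's code.

```python
# Sketch only: mimic the argument validation at the top of LAPACK's xPOTRS.
# The function name is hypothetical; the checks follow the documented
# constraints LDA >= max(1, N) and LDB >= max(1, N).

def potrs_argument_info(uplo, n, nrhs, lda, ldb):
    """Return 0 if the arguments are legal, or -k if the k-th argument
    (1-based, order: UPLO, N, NRHS, A, LDA, B, LDB) is illegal."""
    if uplo.upper() not in ("U", "L"):
        return -1
    if n < 0:
        return -2
    if nrhs < 0:
        return -3
    if lda < max(1, n):
        return -5
    if ldb < max(1, n):
        return -7
    return 0

n, nrhs = 0, 3  # an empty system, e.g. a batch of 0x0 matrices

# Passing n as the leading dimension is rejected for n == 0 ...
naive_info = potrs_argument_info("L", n, nrhs, n, n)

# ... while clamping to at least 1 (assumed meaning of `ldab`) is accepted.
ld = max(1, n)
safe_info = potrs_argument_info("L", n, nrhs, ld, ld)

print(naive_info, safe_info)  # prints: -5 0
```

For non-empty matrices both choices give the same value, so a problem of this kind only shows up on degenerate shapes.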

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
const bool unitriangular
) {
return generic_solve_jvp(
[=](const Tensor& A, const Tensor& B) {
Collaborator

Suggested change
- [=](const Tensor& A, const Tensor& B) {
+ [upper, transpose, unitriangular](const Tensor& A, const Tensor& B) {

Otherwise it might need to copy all the tensors (I don't know if the compiler is allowed to elide the copies).
Same below.

Collaborator Author

Not sure it will happen either way since afaik Tensors are not being copied on operator=.

Collaborator

It was more about the copies of tensors and the reference bumps. A very very minor thing really.

Collaborator

You can also capture everything by reference because the lambda never outlives the current scope?

Collaborator

@albanD albanD left a comment

Very nice refactoring! Thanks!

Comment thread torch/csrc/autograd/FunctionsManual.cpp Outdated
Collaborator

@albanD albanD left a comment

Looks good.
Thanks for the updates.

@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@mruberry
Collaborator

Unlanding as this may have broken periodic_pytorch_xenial_cuda10_2_cudnn7_gcc7_old_gradcheck_test1. Relevant snippet:

Traceback (most recent call last):
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 1396, in wrapper
    method(*args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 373, in instantiated_test
    raise rte
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 368, in instantiated_test
    result = test(self, **param_kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 769, in dep_fn
    return fn(slf, *args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 769, in dep_fn
    return fn(slf, *args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 769, in dep_fn
    return fn(slf, *args, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_device_type.py", line 734, in test_wrapper
    return test(*args, **kwargs)
  File "test_ops.py", line 616, in test_fn_gradgrad
    self._gradgrad_test_helper(device, dtype, op, op.get_op())
  File "test_ops.py", line 582, in _gradgrad_test_helper
    return self._check_helper(device, dtype, op, variant, 'gradgradcheck')
  File "test_ops.py", line 568, in _check_helper
    fast_mode=op.gradcheck_fast_mode))
  File "/opt/conda/lib/python3.6/site-packages/torch/testing/_internal/common_utils.py", line 2706, in gradgradcheck
    return torch.autograd.gradgradcheck(fn, inputs, grad_outputs, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 1391, in gradgradcheck
    check_grad_dtypes=check_grad_dtypes, check_batched_grad=check_batched_grad, fast_mode=fast_mode)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 1263, in gradcheck
    return _gradcheck_helper(**args)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 1277, in _gradcheck_helper
    rtol, atol, check_grad_dtypes, check_forward_ad=check_forward_ad, nondet_tol=nondet_tol)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 947, in _gradcheck_real_imag
    rtol, atol, check_grad_dtypes, nondet_tol)
  File "/opt/conda/lib/python3.6/site-packages/torch/autograd/gradcheck.py", line 995, in _slow_gradcheck
    raise GradcheckError(_get_notallclose_msg(a, n, i, j, complex_indices, test_imag))
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 1 with respect to input 0,
numerical:tensor([[-2.1992e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -9.2681e+08, -1.8873e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.7728e+08,  3.6097e+08, -2.9815e+09,  0.0000e+00,  0.0000e+00,
         -3.6879e+07, -7.5028e+07,  6.1982e+08, -1.2263e+09,  0.0000e+00,
         -6.5864e+06, -1.3416e+07,  1.1079e+08, -2.1920e+08, -1.2401e+10],
        [-2.3319e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -9.8306e+08, -2.1532e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.8794e+08,  4.1188e+08, -2.9857e+09,  0.0000e+00,  0.0000e+00,
         -3.9066e+07, -8.5636e+07,  6.2070e+08, -1.1899e+09,  0.0000e+00,
         -6.9824e+06, -1.5306e+07,  1.1095e+08, -2.1270e+08, -1.2403e+10],
        [-8.6484e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.6419e+08, -4.8366e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          6.9656e+07,  9.2520e+07, -5.9743e+08,  0.0000e+00,  0.0000e+00,
         -1.4480e+07, -1.9233e+07,  1.2419e+08, -1.7001e+08,  0.0000e+00,
         -2.5859e+06, -3.4382e+06,  2.2201e+07, -3.0392e+07, -2.0132e+09],
        [-1.8727e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.8900e+08, -6.3067e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.5091e+08,  1.2063e+09, -1.1023e+10,  0.0000e+00,  0.0000e+00,
         -3.1383e+07, -2.5078e+08,  2.2917e+09, -5.5897e+09,  0.0000e+00,
         -5.6089e+06, -4.4825e+07,  4.0962e+08, -9.9912e+08, -5.2264e+10],
        [-2.4300e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0237e+09, -7.4280e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9583e+08,  1.4207e+09, -1.1041e+10,  0.0000e+00,  0.0000e+00,
         -4.0719e+07, -2.9536e+08,  2.2953e+09, -5.4365e+09,  0.0000e+00,
         -7.2734e+06, -5.2795e+07,  4.1028e+08, -9.7176e+08, -5.2275e+10],
        [-2.4439e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0297e+09, -1.7713e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9698e+08,  3.3869e+08, -2.2676e+09,  0.0000e+00,  0.0000e+00,
         -4.0930e+07, -7.0437e+07,  4.7139e+08, -7.8493e+08,  0.0000e+00,
         -7.3159e+06, -1.2588e+07,  8.4263e+07, -1.4031e+08, -8.4850e+09],
        [ 3.5609e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.5125e+08,  5.5589e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -2.8906e+07, -1.0634e+08,  2.1953e+09,  0.0000e+00,  0.0000e+00,
          6.0039e+06,  2.2093e+07, -4.5639e+08,  1.0997e+09,  0.0000e+00,
          1.0728e+06,  3.9520e+06, -8.1576e+07,  1.9656e+08,  9.9960e+09],
        [ 4.6359e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9581e+08,  7.7062e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.7469e+07, -1.4729e+08,  2.1987e+09,  0.0000e+00,  0.0000e+00,
          7.7891e+06,  3.0645e+07, -4.5710e+08,  1.0704e+09,  0.0000e+00,
          1.3931e+06,  5.4749e+06, -8.1703e+07,  1.9133e+08,  9.9979e+09],
        [ 4.6750e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9688e+08,  2.3312e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.7625e+07, -4.4555e+07,  4.4782e+08,  0.0000e+00,  0.0000e+00,
          7.8398e+06,  9.2666e+06, -9.3096e+07,  1.5509e+08,  0.0000e+00,
          1.3994e+06,  1.6588e+06, -1.6641e+07,  2.7724e+07,  1.6228e+09],
        [-7.5781e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.1312e+07, -1.1534e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          5.9375e+06,  2.2105e+07, -2.3180e+08,  0.0000e+00,  0.0000e+00,
         -1.2383e+06, -4.5942e+06,  4.8199e+07, -2.7880e+08,  0.0000e+00,
         -2.1973e+05, -8.2129e+05,  8.6125e+06, -4.9828e+07, -2.0779e+09],
        [-9.6250e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -4.1062e+07, -1.6031e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          7.8594e+06,  3.0602e+07, -2.3247e+08,  0.0000e+00,  0.0000e+00,
         -1.6016e+06, -6.3745e+06,  4.8338e+07, -2.7272e+08,  0.0000e+00,
         -2.9346e+05, -1.1371e+06,  8.6373e+06, -4.8742e+07, -2.0783e+09],
        [-9.7500e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -4.0625e+07, -4.8734e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          7.8750e+06,  9.2422e+06, -5.6648e+07,  0.0000e+00,  0.0000e+00,
         -1.6250e+06, -1.9229e+06,  1.1776e+07, -4.0388e+07,  0.0000e+00,
         -2.9395e+05, -3.4277e+05,  2.1046e+06, -7.2192e+06, -3.3734e+08],
        [-1.2656e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -5.6250e+06, -2.0781e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.1094e+06,  3.9609e+06, -4.1430e+07,  0.0000e+00,  0.0000e+00,
         -2.4219e+05, -8.1104e+05,  8.6133e+06, -1.9686e+07,  0.0000e+00,
         -4.1504e+04, -1.4697e+05,  1.5402e+06, -3.5206e+06, -3.7161e+08],
        [-1.7500e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.1250e+06, -2.8906e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.3750e+06,  5.5156e+06, -4.1551e+07,  0.0000e+00,  0.0000e+00,
         -3.0078e+05, -1.1431e+06,  8.6367e+06, -1.8594e+07,  0.0000e+00,
         -5.2734e+04, -2.0337e+05,  1.5447e+06, -3.3254e+06, -3.7169e+08],
        [-1.8438e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.3125e+06, -8.4688e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.3906e+06,  1.6523e+06, -1.0121e+07,  0.0000e+00,  0.0000e+00,
         -2.8516e+05, -3.4766e+05,  2.1045e+06, -2.3256e+06,  0.0000e+00,
         -5.1270e+04, -6.3965e+04,  3.7634e+05, -4.1592e+05, -6.0331e+07]],
       device='cuda:0')
analytical:tensor([[-2.1994e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -9.2701e+08, -1.8872e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.7730e+08,  3.6095e+08, -2.9815e+09,  0.0000e+00,  0.0000e+00,
         -3.6860e+07, -7.5039e+07,  6.1982e+08, -1.2263e+09,  0.0000e+00,
         -6.5886e+06, -1.3413e+07,  1.1079e+08, -2.1920e+08, -1.2401e+10],
        [-2.3317e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -9.8278e+08, -2.1533e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.8797e+08,  4.1185e+08, -2.9857e+09,  0.0000e+00,  0.0000e+00,
         -3.9077e+07, -8.5620e+07,  6.2070e+08, -1.1899e+09,  0.0000e+00,
         -6.9850e+06, -1.5305e+07,  1.1095e+08, -2.1270e+08, -1.2403e+10],
        [-8.6450e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.6437e+08, -4.8364e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          6.9689e+07,  9.2501e+07, -5.9743e+08,  0.0000e+00,  0.0000e+00,
         -1.4487e+07, -1.9229e+07,  1.2419e+08, -1.7001e+08,  0.0000e+00,
         -2.5898e+06, -3.4375e+06,  2.2201e+07, -3.0392e+07, -2.0132e+09],
        [-1.8724e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.8920e+08, -6.3070e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.5095e+08,  1.2063e+09, -1.1023e+10,  0.0000e+00,  0.0000e+00,
         -3.1381e+07, -2.5078e+08,  2.2917e+09, -5.5897e+09,  0.0000e+00,
         -5.6091e+06, -4.4825e+07,  4.0962e+08, -9.9912e+08, -5.2264e+10],
        [-2.4288e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0237e+09, -7.4283e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9579e+08,  1.4208e+09, -1.1041e+10,  0.0000e+00,  0.0000e+00,
         -4.0704e+07, -2.9537e+08,  2.2953e+09, -5.4365e+09,  0.0000e+00,
         -7.2756e+06, -5.2796e+07,  4.1028e+08, -9.7176e+08, -5.2275e+10],
        [-2.4427e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -1.0296e+09, -1.7710e+09,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9692e+08,  3.3873e+08, -2.2675e+09,  0.0000e+00,  0.0000e+00,
         -4.0939e+07, -7.0418e+07,  4.7140e+08, -7.8493e+08,  0.0000e+00,
         -7.3175e+06, -1.2587e+07,  8.4263e+07, -1.4031e+08, -8.4850e+09],
        [ 3.5813e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.5095e+08,  5.5602e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -2.8870e+07, -1.0635e+08,  2.1953e+09,  0.0000e+00,  0.0000e+00,
          6.0020e+06,  2.2109e+07, -4.5639e+08,  1.0997e+09,  0.0000e+00,
          1.0728e+06,  3.9517e+06, -8.1576e+07,  1.9656e+08,  9.9960e+09],
        [ 4.6453e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9579e+08,  7.7038e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.7448e+07, -1.4735e+08,  2.1987e+09,  0.0000e+00,  0.0000e+00,
          7.7853e+06,  3.0633e+07, -4.5710e+08,  1.0704e+09,  0.0000e+00,
          1.3915e+06,  5.4752e+06, -8.1703e+07,  1.9133e+08,  9.9979e+09],
        [ 4.6720e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.9692e+08,  2.3317e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.7664e+07, -4.4596e+07,  4.4782e+08,  0.0000e+00,  0.0000e+00,
          7.8301e+06,  9.2714e+06, -9.3096e+07,  1.5509e+08,  0.0000e+00,
          1.3995e+06,  1.6571e+06, -1.6641e+07,  2.7724e+07,  1.6228e+09],
        [-7.4450e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -3.1381e+07, -1.1559e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          6.0020e+06,  2.2109e+07, -2.3181e+08,  0.0000e+00,  0.0000e+00,
         -1.2480e+06, -4.5970e+06,  4.8199e+07, -2.7880e+08,  0.0000e+00,
         -2.2300e+05, -8.2144e+05,  8.6127e+06, -4.9828e+07, -2.0779e+09],
        [-9.6570e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -4.0704e+07, -1.6016e+08,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          7.7853e+06,  3.0633e+07, -2.3248e+08,  0.0000e+00,  0.0000e+00,
         -1.6188e+06, -6.3693e+06,  4.8337e+07, -2.7272e+08,  0.0000e+00,
         -2.8926e+05, -1.1381e+06,  8.6375e+06, -4.8742e+07, -2.0783e+09],
        [-9.7125e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -4.0939e+07, -4.8474e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          7.8301e+06,  9.2714e+06, -5.6641e+07,  0.0000e+00,  0.0000e+00,
         -1.6281e+06, -1.9278e+06,  1.1777e+07, -4.0388e+07,  0.0000e+00,
         -2.9092e+05, -3.4447e+05,  2.1044e+06, -7.2192e+06, -3.3734e+08],
        [-1.3308e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -5.6091e+06, -2.0661e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.0728e+06,  3.9517e+06, -4.1432e+07,  0.0000e+00,  0.0000e+00,
         -2.2300e+05, -8.2144e+05,  8.6127e+06, -1.9686e+07,  0.0000e+00,
         -3.9882e+04, -1.4691e+05,  1.5403e+06, -3.5207e+06, -3.7161e+08],
        [-1.7263e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.2756e+06, -2.8627e+07,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.3915e+06,  5.4752e+06, -4.1552e+07,  0.0000e+00,  0.0000e+00,
         -2.8926e+05, -1.1381e+06,  8.6375e+06, -1.8594e+07,  0.0000e+00,
         -5.1732e+04, -2.0355e+05,  1.5447e+06, -3.3254e+06, -3.7169e+08],
        [-1.7362e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,
         -7.3175e+06, -8.6644e+06,  0.0000e+00,  0.0000e+00,  0.0000e+00,
          1.3995e+06,  1.6571e+06, -1.0124e+07,  0.0000e+00,  0.0000e+00,
         -2.9092e+05, -3.4447e+05,  2.1044e+06, -2.3257e+06,  0.0000e+00,
         -5.2029e+04, -6.1607e+04,  3.7636e+05, -4.1592e+05, -6.0331e+07]],
       device='cuda:0')

We had better run the periodic gradcheck job on this before relanding.

@mruberry mruberry reopened this Sep 29, 2021
@IvanYashchuk
Collaborator

Here is a link to the failing CI: https://app.circleci.com/pipelines/github/pytorch/pytorch/386383/workflows/8e2b4619-efd2-44dd-bd5f-24a1289b65b5/jobs/16239595
It's on CircleCI so you can rerun it with SSH.
You didn't touch backward for cholesky_solve, but you did introduce an OpInfo, which creates the test_fn_gradgrad_cholesky_solve_cuda_float64 test.
I assume the reason could be that the sample inputs are not valid: cholesky_solve expects Cholesky factors as input, not general matrices. I'd try reusing sample_inputs_linalg_cholesky_inverse.
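
To make the "Cholesky factors, not general matrices" point concrete, here is a minimal pure-Python sketch of building a valid input: form a symmetric positive-definite matrix, take its Cholesky factor, and only then solve. The helpers `cholesky_lower` and `cholesky_solve_lower` are hypothetical stand-ins, not PyTorch functions, and the sketch does not claim to reproduce what `sample_inputs_linalg_cholesky_inverse` does.

```python
# Sketch only: construct a valid Cholesky factor and use it in a solve.
# Helper names are hypothetical; this is not PyTorch's sample-input code.
import math

def cholesky_lower(A):
    """Return lower-triangular L with A = L L^T (A must be SPD)."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                d = A[i][i] - s
                if d <= 0.0:
                    raise ValueError("matrix is not positive definite")
                L[i][j] = math.sqrt(d)
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]
    return L

def cholesky_solve_lower(b, L):
    """Solve (L L^T) x = b by forward then backward substitution."""
    n = len(L)
    y = [0.0] * n
    for i in range(n):                      # L y = b
        y[i] = (b[i] - sum(L[i][k] * y[k] for k in range(i))) / L[i][i]
    x = [0.0] * n
    for i in reversed(range(n)):            # L^T x = y
        x[i] = (y[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

# Any matrix M gives an SPD matrix A = M M^T + I.
M = [[1.0, 2.0, 0.0], [0.5, -1.0, 3.0], [2.0, 0.0, 1.0]]
n = len(M)
A = [[sum(M[i][k] * M[j][k] for k in range(n)) + (1.0 if i == j else 0.0)
      for j in range(n)] for i in range(n)]

L = cholesky_lower(A)          # a valid factor, suitable as solver input
b = [1.0, -2.0, 0.5]
x = cholesky_solve_lower(b, L)

# Residual of the original system A x = b.
residual = max(abs(sum(A[i][j] * x[j] for j in range(n)) - b[i])
               for i in range(n))
print(residual)
```

Feeding an arbitrary matrix in place of `L` still defines a system `L L^T x = b`, but that system can be badly conditioned, which would be consistent with the very large Jacobian entries in the traceback above; that link is a plausible reading rather than something confirmed in this thread.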

@nikitaved
Collaborator Author

nikitaved commented Sep 29, 2021

Ah, good point, @IvanYashchuk ! That makes sense, thank you!

@IvanYashchuk
Collaborator

That periodic job runs gradchecks with fast_mode=False. You could try testing it locally by adding gradcheck_fast_mode=False to OpInfo.
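
Roughly speaking, gradcheck with fast_mode=False builds the full numerical Jacobian by finite differences and compares it entry by entry with the analytical one, whereas fast mode compares only a random projection of it. The sketch below shows the slow-mode idea in pure Python on a toy function. The helper names are hypothetical, the tolerances are chosen to match the documented `gradcheck` defaults (`eps=1e-6`, `atol=1e-5`, `rtol=1e-3`), and the layout of the Jacobian is an illustrative choice rather than PyTorch's internal one.

```python
# Sketch only: entry-by-entry Jacobian comparison on a toy function.
# Helper names are hypothetical; this is not PyTorch's gradcheck code.

def f(x):
    """Toy function R^2 -> R^2."""
    return [x[0] * x[1], x[0] ** 2]

def analytical_jacobian(x):
    """Hand-derived Jacobian of f, rows = outputs, columns = inputs."""
    return [[x[1], x[0]],
            [2.0 * x[0], 0.0]]

def numerical_jacobian(fn, x, eps=1e-6):
    """Central finite differences, one perturbed input per column."""
    cols = []
    for j in range(len(x)):
        xp, xm = list(x), list(x)
        xp[j] += eps
        xm[j] -= eps
        fp, fm = fn(xp), fn(xm)
        cols.append([(p - m) / (2 * eps) for p, m in zip(fp, fm)])
    # cols[j][i] holds d f_i / d x_j; transpose to rows = outputs.
    return [[cols[j][i] for j in range(len(x))]
            for i in range(len(cols[0]))]

def allclose(a, n, rtol=1e-3, atol=1e-5):
    """Elementwise |a - n| <= atol + rtol * |n|."""
    return all(abs(ai - ni) <= atol + rtol * abs(ni)
               for ra, rn in zip(a, n)
               for ai, ni in zip(ra, rn))

x = [1.5, -2.0]
jacobians_match = allclose(analytical_jacobian(x), numerical_jacobian(f, x))
print(jacobians_match)
```

Because every entry is checked individually, a few inaccurate finite-difference entries on an ill-conditioned input are enough to fail the comparison even when most of the Jacobian agrees.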

@albanD
Collaborator

albanD commented Sep 30, 2021

With this label, we should have the "old_gradcheck" job in the CI for this PR.

@nikitaved
Collaborator Author

This one looks good now with the updated OpInfo.

Collaborator

@albanD albanD left a comment

Thanks for the update!
Try 2!

@facebook-github-bot
Contributor

@albanD has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@github-actions github-actions Bot deleted the nikitaved/solve_methods_forward_AD branch February 13, 2024 01:56
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Summary:
This PR adds forward AD for `*_solve` methods.
Additionally, `cholesky_solve` gets OpInfo + a bug fix when wrong leading dimensions could be passed to LAPACK,
and `lu_solve` gets forward AD with 2x`lu_solve` instead of 1x`lu_solve` + 2x`triangular_solve`.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#65546

Reviewed By: gchanan

Differential Revision: D31206837

Pulled By: albanD

fbshipit-source-id: 040beda97442e7a88a9df9abc7bb18313ce55bc3
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 25, 2026
Summary:
This PR adds forward AD for `*_solve` methods.
Additionally, `cholesky_solve` gets OpInfo + a bug fix when wrong leading dimensions could be passed to LAPACK,
and `lu_solve` gets forward AD with 2x`lu_solve` instead of 1x`lu_solve` + 2x`triangular_solve`.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry walterddr IvanYashchuk xwang233

Pull Request resolved: pytorch#65546

Reviewed By: dagitses

Differential Revision: D31431847

Pulled By: albanD

fbshipit-source-id: 0e343e0d9da3c3d2051fca215fad289d77275251

Labels

cla signed
module: autograd (Related to torch.autograd, and the autograd engine in general)
module: linear algebra (Issues related to specialized linear algebra operations in PyTorch; includes matrix multiply matmul)
open source

7 participants