From 72b08ba4bb80f5f89067de6e1c55ce218a92e6db Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 25 Jul 2021 04:23:30 -0500 Subject: [PATCH 1/3] Add cholesky forward mode differentiation --- tools/autograd/derivatives.yaml | 2 ++ torch/csrc/autograd/FunctionsManual.cpp | 11 +++++++++++ torch/csrc/autograd/FunctionsManual.h | 1 + torch/testing/_internal/common_methods_invocations.py | 9 ++++++++- 4 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2314044ec0d0a..9a0207bbf35d8 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -339,6 +339,7 @@ - name: linalg_cholesky_ex(Tensor self, *, bool check_errors=False) -> (Tensor L, Tensor info) self: cholesky_backward(grad, false, L) + L: cholesky_jvp(self_t, self_p, L) - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) @@ -1319,6 +1320,7 @@ - name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) self: grad.transpose(dim0, dim1) + result: auto_linear - name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) self: grad.transpose(dim0, dim1) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7b894cac08b30..74cc84f7c4d7b 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -997,6 +997,17 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra return mask_selected.view(sizes); } +Tensor cholesky_jvp(Tensor input_tangent, Tensor input_primal, Tensor L) { + // Differentiation of the Cholesky decomposition, Iain Murray + // https://arxiv.org/abs/1602.07527 + // equation 8 + auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false)); + auto phi = at::matmul(at::matmul(L_inverse, input_tangent), L_inverse.transpose(-2, -1).conj()); + phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); + auto L_tangent = L.matmul(phi); + return L_tangent; +} + Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { // cf. Iain Murray (2016); arXiv 1602.07527 // This gradient is symmetric, and not triangular. diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 3bd9d69696bbd..59cf24b9a46f9 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -102,6 +102,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional dim, c10::optional correction, bool keepdim, bool is_std); at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes); at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L); +at::Tensor cholesky_jvp(at::Tensor input_tangent, at::Tensor input_primal, at::Tensor L); at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse); at::Tensor split_with_sizes_backward(const std::vector &grads, IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6134c67bb312d..46631ac7e17c1 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5841,6 +5841,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # got: vmap: Calling Tensor.as_strided is not supported # unless the batch dims being vmapped over are at the front of the tensor (in memory layout). check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack], @@ -5854,9 +5855,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): aten_name='linalg_cholesky_ex', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_cholesky, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck for complex generates invalid inputs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),), + ), OpInfo('linalg.cond', aten_name='linalg_cond', dtypes=floating_and_complex_types(), @@ -7537,6 +7543,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_transpose_swapdims), OpInfo('tril', dtypes=all_types_and_complex_and(torch.bool, torch.half), From 4344fa5bd52ab43b19f4002825e1b31837775ebd Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 25 Jul 2021 05:05:04 -0500 Subject: [PATCH 2/3] Use const Tensor& --- torch/csrc/autograd/FunctionsManual.cpp | 2 +- torch/csrc/autograd/FunctionsManual.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 74cc84f7c4d7b..a7222636f03ad 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -997,7 +997,7 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra return mask_selected.view(sizes); } -Tensor cholesky_jvp(Tensor input_tangent, Tensor input_primal, Tensor L) { +Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& input_primal, const Tensor& L) { // Differentiation of the Cholesky decomposition, Iain Murray // https://arxiv.org/abs/1602.07527 // equation 8 diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 59cf24b9a46f9..22e645fd7911e 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -102,7 +102,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional dim, c10::optional correction, bool keepdim, bool is_std); at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes); at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L); -at::Tensor cholesky_jvp(at::Tensor input_tangent, at::Tensor input_primal, at::Tensor L); +at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& input_primal, const at::Tensor& L); at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse); at::Tensor split_with_sizes_backward(const std::vector &grads, IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options); From 14dcdecbd232848798202d46c1e5e2f42c526496 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 6 Sep 2021 05:29:22 -0500 Subject: [PATCH 3/3] Add upper kwarg support to forward ad rule --- tools/autograd/derivatives.yaml | 2 +- torch/csrc/autograd/FunctionsManual.cpp | 13 ++++++++----- torch/csrc/autograd/FunctionsManual.h | 2 +- .../testing/_internal/common_methods_invocations.py | 4 ++-- 4 files changed, 12 insertions(+), 9 deletions(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 124a67e232b88..4bdb56558b327 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -339,7 +339,7 @@ - name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) self: cholesky_backward(grad, upper, L) - L: cholesky_jvp(self_t, self_p, L) + L: cholesky_jvp(self_t, L, upper) - name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor self, input2: cholesky_solve_backward(grad, self, input2, result, upper) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 156ebe2addf9a..9ccfbd162bafe 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -1006,15 +1006,18 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra return mask_selected.view(sizes); } -Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& input_primal, const Tensor& L) { +Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) { // Differentiation of the Cholesky decomposition, Iain Murray // https://arxiv.org/abs/1602.07527 // equation 8 - auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L, /*upper=*/false)); - auto phi = at::matmul(at::matmul(L_inverse, input_tangent), L_inverse.transpose(-2, -1).conj()); + auto input_tangent_ = upper ? input_tangent.transpose(-1, -2).conj() : input_tangent; + auto L_ = upper ? L.transpose(-1, -2).conj() : L; + + auto L_inverse = std::get<0>(at::triangular_solve(at::eye(L.size(-1), L.options()), L_, /*upper=*/false)); + auto phi = at::matmul(at::matmul(L_inverse, input_tangent_), L_inverse.transpose(-2, -1).conj()); phi.tril_().diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5); - auto L_tangent = L.matmul(phi); - return L_tangent; + auto L_tangent = L_.matmul(phi); + return upper ? L_tangent.transpose(-1, -2).conj() : L_tangent; } Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) { diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 3cc4ff386ef56..6684bcb68ff2f 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -101,7 +101,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional dim, c10::optional correction, bool keepdim, bool is_std); at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes); at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L); -at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& input_primal, const at::Tensor& L); +at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& L, bool upper); at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse); at::Tensor split_with_sizes_backward(const std::vector &grads, IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6b16557724143..6407dcc0e567b 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -3494,10 +3494,10 @@ def sample_inputs_linalg_cholesky(op_info, device, dtype, requires_grad=False, * batches = [(), (0, ), (2, ), (1, 1)] ns = [5, 0] out = [] - for batch, n in product(batches, ns): + for batch, n, upper in product(batches, ns, [True, False]): a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device) a.requires_grad = requires_grad - out.append(SampleInput(a)) + out.append(SampleInput(a, kwargs={"upper": upper})) return out def sample_inputs_symeig(op_info, device, dtype, requires_grad=False):