From 8687e330bb8209eb5a0fcde9c1f703b8f17a4661 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 25 Jul 2021 09:40:10 -0500 Subject: [PATCH 1/8] Add forward AD for torch.linalg.eigh --- tools/autograd/derivatives.yaml | 12 ++++++ torch/csrc/autograd/FunctionsManual.cpp | 35 +++++++++++++++++ torch/csrc/autograd/FunctionsManual.h | 2 + .../_internal/common_methods_invocations.py | 39 +++++++++++++++++-- 4 files changed, 85 insertions(+), 3 deletions(-) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 2314044ec0d0a..e90b74b7cfa1e 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -469,6 +469,7 @@ - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2) + result: auto_linear - name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor self: norm_backward(grad, self - other, p, result) @@ -566,10 +567,12 @@ - name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) self: zeros_like(grad) + result: self_t.fill_(0) - name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) self: zeros_like(grad) value: grad.sum() + result: self_t.fill_(value_t) - name: floor(Tensor self) -> Tensor self: zeros_like(grad) @@ -1079,13 +1082,16 @@ - name: pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor self: pow_backward(grad, self, exponent) + result: auto_element_wise - name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor self: pow_backward_self(grad, self, exponent) exponent: pow_backward_exponent(grad, self, exponent, result) + result: (pow_backward_self(self_t.conj(), self_p, exponent_p) + pow_backward_exponent(exponent_t.conj(), self_p, exponent_p, result)).conj() - name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor exponent: pow_backward_exponent(grad, self, exponent, result) + result: auto_element_wise - name: prod(Tensor self, *, ScalarType? dtype=None) -> Tensor self: prod_backward(grad, self.to(grad.scalar_type()), result) @@ -1250,6 +1256,7 @@ - name: sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) other: handle_r_to_c(other.scalar_type(), -grad * alpha.conj()) + result: self_t - maybe_multiply(other_t, alpha) - name: sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor self: handle_r_to_c(self.scalar_type(), grad) @@ -1281,6 +1288,8 @@ - name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) self: eigh_backward(grads, self, /*eigenvectors=*/true, eigenvalues, eigenvectors) + eigenvalues: eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors) + eigenvectors: eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors) - name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) self: linalg_eig_backward(grads, self, eigenvalues, eigenvectors) @@ -1378,9 +1387,11 @@ - name: _unsafe_view(Tensor self, int[] size) -> Tensor self: grad.reshape(self.sizes()) + result: auto_linear - name: unsqueeze(Tensor(a) self, int dim) -> Tensor(a) self: grad.squeeze(dim) + result: auto_linear - name: unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) self: grad.squeeze(dim) @@ -1393,6 +1404,7 @@ - name: view(Tensor(a) self, int[] size) -> Tensor(a) self: grad.reshape(self.sizes()) + result: auto_linear - name: view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) output_differentiability: [False] diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7b894cac08b30..a9e227570bc45 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2386,6 +2386,41 @@ Tensor linalg_eig_backward(const std::vector &grads, } } +// jvp functions for eigenvalues and eigenvectors are separate +// because currently forward AD only works with one rule per output +Tensor eigh_jvp_eigenvalues( + const Tensor& input_tangent, + const Tensor& eigenvalues, + const Tensor& eigenvectors) { + // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation + // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 + // Section 3.1 Eigenvalues and eigenvectors + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors); + auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); + if (eigenvalues_tangent.is_complex()) { + return at::real(eigenvalues_tangent); + } + return eigenvalues_tangent; +} + +Tensor eigh_jvp_eigenvectors( + const Tensor& input_tangent, + const Tensor& eigenvalues, + const Tensor& eigenvectors) { + // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation + // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 + // Section 3.1 Eigenvalues and eigenvectors + + auto F = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); + F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); + F = F.pow(-1); + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors); + auto eigenvectors_tangent = at::matmul(eigenvectors, F.mul(tmp)); + return eigenvectors_tangent; +} + Tensor eigh_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& L, const Tensor& V) { // This function is used for both torch.symeig and torch.linalg.eigh. diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index 3bd9d69696bbd..5426bdb936133 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -157,6 +157,8 @@ Tensor slice_backward_wrapper( int64_t step); Tensor linalg_eig_backward(const std::vector &grads, const Tensor& self, const Tensor& L, const Tensor& V); +Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors); +Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors); Tensor eigh_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& L, const Tensor& V); std::tuple triangular_solve_backward( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 6134c67bb312d..cc8bfe9020cfc 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -4905,6 +4905,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): aliases=('subtract',), dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16), assert_autodiffed=True, + supports_forward_ad=True, sample_inputs_func=partial(sample_inputs_binary_pwise, alpha=2), supports_inplace_autograd=False), OpInfo('addmm', @@ -5438,6 +5439,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): safe_casts_outputs=True), OpInfo('diff', op=torch.diff, + supports_forward_ad=True, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_diff), OpInfo('div', @@ -5524,6 +5526,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('diagonal', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_diagonal_diag_embed), OpInfo('eq', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), @@ -5780,6 +5783,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): torch.int32, torch.int64, torch.bfloat16, torch.half), supports_out=False, + supports_forward_ad=True, skips=( # following tests give a runtime error with undefined value tensor # see discussion : https://github.com/pytorch/pytorch/issues/56660 @@ -5883,16 +5887,25 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): aten_name='linalg_eigh', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_eigh, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck for complex hangs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),), + ), OpInfo('linalg.eigvalsh', aten_name='linalg_eigvalsh', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, sample_inputs_func=sample_inputs_linalg_eigh, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck hangs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD'),), + ), OpInfo('linalg.householder_product', aten_name='linalg_householder_product', op=torch.linalg.householder_product, @@ -5929,6 +5942,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []), supports_inplace_autograd=False, + supports_forward_ad=True, # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407) check_batched_grad=False, check_batched_gradgrad=False, @@ -6429,6 +6443,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): op=torch.outer, aliases=('ger', ), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, sample_inputs_func=sample_inputs_outer,), OpInfo('ormqr', op=torch.ormqr, @@ -6451,11 +6466,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): backward_dtypesIfCUDA=all_types_and_complex_and(torch.bfloat16, torch.half), sample_inputs_func=sample_inputs_pow, supports_inplace_autograd=False, + supports_forward_ad=True, assert_autodiffed=True, ), OpInfo('float_power', dtypes=all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_pow, + supports_forward_ad=True, skips=( SkipInfo('TestMathBits', 'test_conj_view', device_type='cuda'),),), OpInfo('prod', @@ -6684,6 +6701,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): backward_dtypesIfCPU=all_types_and_complex_and(torch.bfloat16, torch.bool), sample_inputs_func=sample_inputs_rbinops, supports_out=False, + supports_forward_ad=True, skips=( SkipInfo('TestJit', 'test_variant_consistency_jit',),), assert_autodiffed=True, @@ -6908,6 +6926,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('square', ref=np.square, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), + supports_forward_ad=True, decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), skips=( # Reference: https://github.com/pytorch/pytorch/issues/52549 @@ -6995,7 +7014,11 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): check_batched_gradgrad=False, sample_inputs_func=sample_inputs_linalg_pinv_hermitian, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck hangs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD'),), + ), OpInfo('eig', op=torch.eig, dtypes=floating_and_complex_types(), @@ -7196,12 +7219,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('ravel', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_ravel, ), OpInfo('reshape', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_view_reshape, supports_out=False, + supports_forward_ad=True, ), OpInfo('reshape_as', op=lambda x, other: x.reshape_as(other), @@ -7211,11 +7236,13 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): # Because reshape_as does not have a function variant. SkipInfo('TestJit', 'test_variant_consistency_jit'),), supports_out=False, + supports_forward_ad=True, ), OpInfo('view', op=lambda x, shape: x.view(shape), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, skips=( # Because view does not have a function variant. SkipInfo('TestJit', 'test_variant_consistency_jit'),), @@ -7225,6 +7252,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): op=lambda x, other: x.view_as(other), dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, + supports_forward_ad=True, skips=( # Because view_as does not have a function variant. SkipInfo('TestJit', 'test_variant_consistency_jit'),), @@ -7356,6 +7384,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): OpInfo('dstack', dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), sample_inputs_func=sample_inputs_hstack_dstack_vstack, + supports_forward_ad=True, skips=( # dstack does not correctly warn when resizing out= inputs SkipInfo('TestCommon', 'test_out'),)), @@ -7411,6 +7440,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): op=lambda x, scalar: torch.fill_(x.clone(), scalar), method_variant=None, inplace_variant=torch.Tensor.fill_, + supports_forward_ad=True, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, skips=( @@ -7446,6 +7476,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_inplace_autograd=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_take_along_dim, gradcheck_nondet_tol=GRADCHECK_NONDET_TOL), ShapeFuncInfo('tile', @@ -7465,6 +7496,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, assert_autodiffed=True, + supports_forward_ad=True, sample_inputs_func=sample_unsqueeze), OpInfo('var', dtypes=floating_and_complex_types_and(torch.half), @@ -7548,6 +7580,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16), supports_inplace_autograd=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_kron), OpInfo('inner', dtypes=floating_and_complex_types_and(torch.half), From ce2944f6f2268bb508b984941d06602754a54d66 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sun, 25 Jul 2021 11:08:58 -0500 Subject: [PATCH 2/8] Add forward AD rule for transpose --- tools/autograd/derivatives.yaml | 1 + torch/testing/_internal/common_methods_invocations.py | 1 + 2 files changed, 2 insertions(+) diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index e90b74b7cfa1e..8be28bfd49c80 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -1328,6 +1328,7 @@ - name: transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) self: grad.transpose(dim0, dim1) + result: auto_linear - name: transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) self: grad.transpose(dim0, dim1) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cc8bfe9020cfc..f2464dd0cc40c 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -7569,6 +7569,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), dtypesIfCUDA=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.half), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_transpose_swapdims), OpInfo('tril', dtypes=all_types_and_complex_and(torch.bool, torch.half), From ce994c95993f9b60faf7af7fee86e08d6d361d63 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 28 Jul 2021 07:36:55 -0500 Subject: [PATCH 3/8] Use hermitian tangent --- torch/csrc/autograd/FunctionsManual.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index d58328e7f8890..b803bfe4eb1ec 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2406,7 +2406,10 @@ Tensor eigh_jvp_eigenvalues( // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 // Section 3.1 Eigenvalues and eigenvectors - auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors); + // see the note in the implementation of eigh_backward that tangent should be Hermitian + auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); if (eigenvalues_tangent.is_complex()) { return at::real(eigenvalues_tangent); @@ -2421,12 +2424,14 @@ Tensor eigh_jvp_eigenvectors( // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 // Section 3.1 Eigenvalues and eigenvectors - auto F = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); F = F.pow(-1); - auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), input_tangent), eigenvectors); + // see the note in the implementation of eigh_backward that tangent should be Hermitian + auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); auto eigenvectors_tangent = at::matmul(eigenvectors, F.mul(tmp)); return eigenvectors_tangent; } From 3b4baf69ab17df68903877f1dc5c2af4c6e3132b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 28 Jul 2021 07:38:16 -0500 Subject: [PATCH 4/8] Use div instead of reciprocal + mul --- torch/csrc/autograd/FunctionsManual.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index b803bfe4eb1ec..d1e1b15284887 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2424,15 +2424,14 @@ Tensor eigh_jvp_eigenvectors( // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 // Section 3.1 Eigenvalues and eigenvectors - auto F = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); - F.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); - F = F.pow(-1); + auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); + E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); // see the note in the implementation of eigh_backward that tangent should be Hermitian auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); - auto eigenvectors_tangent = at::matmul(eigenvectors, F.mul(tmp)); + auto eigenvectors_tangent = at::matmul(eigenvectors, tmp.div(E)); return eigenvectors_tangent; } From bb094ca5c587dbc36c4e989161496ef70d775e3b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Wed, 28 Jul 2021 08:37:43 -0500 Subject: [PATCH 5/8] Fix merge conflicts --- torch/testing/_internal/common_methods_invocations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4b3b7f8dcf692..aa67697bca046 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -5520,7 +5520,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): safe_casts_outputs=True), OpInfo('diff', op=torch.diff, - supports_forward_ad=True, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_forward_ad=True, sample_inputs_func=sample_inputs_diff), @@ -6038,7 +6037,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): dtypesIfCPU=all_types_and_complex_and(torch.half, torch.bfloat16), dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if CUDA11OrLater else []), supports_inplace_autograd=False, - supports_forward_ad=True, # Batched grad checks fail for empty input tensors (see https://github.com/pytorch/pytorch/issues/53407) check_batched_grad=False, check_batched_gradgrad=False, @@ -7046,7 +7044,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): UnaryUfuncInfo('square', ref=np.square, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), - supports_forward_ad=True, decorators=(precisionOverride({torch.complex64: 3e-4, torch.bfloat16: 3e-1}),), supports_forward_ad=True, skips=( @@ -7634,7 +7631,6 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): supports_out=False, supports_forward_ad=True, assert_autodiffed=True, - supports_forward_ad=True, sample_inputs_func=sample_unsqueeze), OpInfo('var', dtypes=floating_and_complex_types_and(torch.half), From 47873563e9b4f08788b1e32a475c0f0e9c8c8a24 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 6 Sep 2021 14:26:09 -0500 Subject: [PATCH 6/8] Seems like einsum now supports forward AD --- torch/testing/_internal/common_methods_invocations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 543040bdece15..2c1de39cc4ef2 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -8287,6 +8287,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs): backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_einsum, skips=( # test does not work with passing lambda for op From ff75f52004d3a7ce70163ebfaab18bcbb4043124 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 13 Sep 2021 10:15:57 -0500 Subject: [PATCH 7/8] Raise an error for complex inputs, the tests do not work yet --- torch/csrc/autograd/FunctionsManual.cpp | 10 ++++++++++ torch/testing/_internal/common_methods_invocations.py | 2 +- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index ae86f4fcc500b..7a1312ee8abc4 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2419,6 +2419,11 @@ Tensor eigh_jvp_eigenvalues( // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 // Section 3.1 Eigenvalues and eigenvectors + // TODO: gradcheck from test_ops.py hangs with complex inputs + TORCH_CHECK_NOT_IMPLEMENTED( + !input_tangent.is_complex(), + "the derivative for 'eigh' with complex inputs is not implemented."); + // see the note in the implementation of eigh_backward that tangent should be Hermitian auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); @@ -2437,6 +2442,11 @@ Tensor eigh_jvp_eigenvectors( // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 // Section 3.1 Eigenvalues and eigenvectors + + TORCH_CHECK_NOT_IMPLEMENTED( + !input_tangent.is_complex(), + "the derivative for 'eigh' with complex inputs is not implemented."); + auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 004215c77fac0..eef7beaae8636 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6967,7 +6967,7 @@ def wrapper(x: np.ndarray, *args, **kwargs): gradcheck_wrapper=gradcheck_wrapper_hermitian_input, decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], skips=( - # Gradcheck for complex hangs for this function + # Gradcheck for complex hangs for this function, therefore it raises NotImplementedError for now SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),), ), OpInfo('linalg.eigvalsh', From c3fc5edbbfe091f2165a89afcd862419d740babb Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 13 Sep 2021 10:16:51 -0500 Subject: [PATCH 8/8] Return directly without extra assignment --- torch/csrc/autograd/FunctionsManual.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7a1312ee8abc4..7232f0b741f14 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2454,8 +2454,7 @@ Tensor eigh_jvp_eigenvectors( auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); - auto eigenvectors_tangent = at::matmul(eigenvectors, tmp.div(E)); - return eigenvectors_tangent; + return at::matmul(eigenvectors, tmp.div(E)); } Tensor eigh_backward(const std::vector &grads, const Tensor& self,