diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 660c18827c3be..c5f943d9ed972 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -479,6 +479,7 @@ - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2) + result: auto_linear - name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor self: norm_backward(grad, self - other, p, result) @@ -579,10 +580,12 @@ - name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) self: zeros_like(grad) + result: self_t.fill_(0) - name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) self: zeros_like(grad) value: grad.sum() + result: self_t.fill_(value_t) - name: floor(Tensor self) -> Tensor self: zeros_like(grad) @@ -1338,6 +1341,8 @@ - name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) self: eigh_backward(grads, self, /*eigenvectors=*/true, eigenvalues, eigenvectors) + eigenvalues: eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors) + eigenvectors: eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors) - name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) self: linalg_eig_backward(grads, self, eigenvalues, eigenvectors) diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index e53a0c5b6ce7b..7232f0b741f14 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -2409,6 +2409,54 @@ Tensor linalg_eig_backward(const std::vector &grads, } } +// jvp functions for eigenvalues and eigenvectors are separate +// because currently forward AD only works with one rule per output +Tensor eigh_jvp_eigenvalues( + const Tensor& input_tangent, + const Tensor& eigenvalues, + const Tensor& eigenvectors) { + // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation + // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 + // Section 3.1 Eigenvalues and eigenvectors + + // TODO: gradcheck from test_ops.py hangs with complex inputs + TORCH_CHECK_NOT_IMPLEMENTED( + !input_tangent.is_complex(), + "the derivative for 'eigh' with complex inputs is not implemented."); + + // see the note in the implementation of eigh_backward that tangent should be Hermitian + auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); + auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1); + if (eigenvalues_tangent.is_complex()) { + return at::real(eigenvalues_tangent); + } + return eigenvalues_tangent; +} + +Tensor eigh_jvp_eigenvectors( + const Tensor& input_tangent, + const Tensor& eigenvalues, + const Tensor& eigenvectors) { + // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation + // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124 + // Section 3.1 Eigenvalues and eigenvectors + + TORCH_CHECK_NOT_IMPLEMENTED( + !input_tangent.is_complex(), + "the derivative for 'eigh' with complex inputs is not implemented."); + + auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1); + E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY); + + // see the note in the implementation of eigh_backward that tangent should be Hermitian + auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj()); + + auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors); + return at::matmul(eigenvectors, tmp.div(E)); +} + Tensor eigh_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& L, const Tensor& V) { // This function is used for both torch.symeig and torch.linalg.eigh. diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index b24dce774b0d2..2e1a36f1b6db9 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -157,6 +157,8 @@ Tensor slice_backward_wrapper( int64_t step); Tensor linalg_eig_backward(const std::vector &grads, const Tensor& self, const Tensor& L, const Tensor& V); +Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors); +Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors); Tensor eigh_backward(const std::vector &grads, const Tensor& self, bool eigenvectors, const Tensor& L, const Tensor& V); std::tuple triangular_solve_backward( diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index da7b3a080665b..eef7beaae8636 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -6577,6 +6577,7 @@ def wrapper(x: np.ndarray, *args, **kwargs): OpInfo('diagonal', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_diagonal_diag_embed), OpInfo('eq', dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16), @@ -6961,16 +6962,25 @@ def wrapper(x: np.ndarray, *args, **kwargs): aten_name='linalg_eigh', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_linalg_eigh, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck for complex hangs for this function, therefore it raises NotImplementedError for now + SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),), + ), OpInfo('linalg.eigvalsh', aten_name='linalg_eigvalsh', dtypes=floating_and_complex_types(), check_batched_gradgrad=False, sample_inputs_func=sample_inputs_linalg_eigh, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck hangs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD'),), + ), OpInfo('linalg.householder_product', aten_name='linalg_householder_product', op=torch.linalg.householder_product, @@ -8419,7 +8429,11 @@ def wrapper(x: np.ndarray, *args, **kwargs): check_batched_gradgrad=False, sample_inputs_func=sample_inputs_linalg_pinv_hermitian, gradcheck_wrapper=gradcheck_wrapper_hermitian_input, - decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]), + decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack], + skips=( + # Gradcheck hangs for this function + SkipInfo('TestGradients', 'test_forward_mode_AD'),), + ), OpInfo('eig', op=torch.eig, dtypes=floating_and_complex_types(), @@ -8438,6 +8452,7 @@ def wrapper(x: np.ndarray, *args, **kwargs): backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half, *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []), supports_out=False, + supports_forward_ad=True, sample_inputs_func=sample_inputs_einsum, skips=( # test does not work with passing lambda for op @@ -8867,6 +8882,7 @@ def wrapper(x: np.ndarray, *args, **kwargs): op=lambda x, scalar: torch.fill_(x.clone(), scalar), method_variant=None, inplace_variant=torch.Tensor.fill_, + supports_forward_ad=True, dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16), supports_out=False, skips=(