Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@

- name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)
result: auto_linear

- name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
self: norm_backward(grad, self - other, p, result)
Expand Down Expand Up @@ -579,10 +580,12 @@

- name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
self: zeros_like(grad)
result: self_t.fill_(0)

- name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
self: zeros_like(grad)
value: grad.sum()
result: self_t.fill_(value_t)

- name: floor(Tensor self) -> Tensor
self: zeros_like(grad)
Expand Down Expand Up @@ -1338,6 +1341,8 @@

- name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
self: eigh_backward(grads, self, /*eigenvectors=*/true, eigenvalues, eigenvectors)
eigenvalues: eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors)
eigenvectors: eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors)

- name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
self: linalg_eig_backward(grads, self, eigenvalues, eigenvectors)
Expand Down
48 changes: 48 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2409,6 +2409,54 @@ Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads,
}
}

// jvp functions for eigenvalues and eigenvectors are separate
// because currently forward AD only works with one rule per output
Tensor eigh_jvp_eigenvalues(
const Tensor& input_tangent,
const Tensor& eigenvalues,
const Tensor& eigenvectors) {
// An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
// https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
// Section 3.1 Eigenvalues and eigenvectors

// TODO: gradcheck from test_ops.py hangs with complex inputs
TORCH_CHECK_NOT_IMPLEMENTED(
!input_tangent.is_complex(),
"the derivative for 'eigh' with complex inputs is not implemented.");

// see the note in the implementation of eigh_backward that tangent should be Hermitian
auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());

auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
if (eigenvalues_tangent.is_complex()) {
return at::real(eigenvalues_tangent);
}
return eigenvalues_tangent;
}

Tensor eigh_jvp_eigenvectors(
const Tensor& input_tangent,
const Tensor& eigenvalues,
const Tensor& eigenvectors) {
// An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
// https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
// Section 3.1 Eigenvalues and eigenvectors

TORCH_CHECK_NOT_IMPLEMENTED(
!input_tangent.is_complex(),
"the derivative for 'eigh' with complex inputs is not implemented.");

auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);

// see the note in the implementation of eigh_backward that tangent should be Hermitian
auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());

auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
return at::matmul(eigenvectors, tmp.div(E));
}

Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, const Tensor& L, const Tensor& V) {
// This function is used for both torch.symeig and torch.linalg.eigh.
Expand Down
2 changes: 2 additions & 0 deletions torch/csrc/autograd/FunctionsManual.h
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ Tensor slice_backward_wrapper(
int64_t step);
Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
const Tensor& L, const Tensor& V);
Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
bool eigenvectors, const Tensor& L, const Tensor& V);
std::tuple<Tensor, Tensor> triangular_solve_backward(
Expand Down
22 changes: 19 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -6577,6 +6577,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
OpInfo('diagonal',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_diagonal_diag_embed),
OpInfo('eq',
dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
Expand Down Expand Up @@ -6961,16 +6962,25 @@ def wrapper(x: np.ndarray, *args, **kwargs):
aten_name='linalg_eigh',
dtypes=floating_and_complex_types(),
check_batched_gradgrad=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_linalg_eigh,
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck for complex hangs for this function, therefore it raises NotImplementedError for now
SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have an issue open tracking this? Do we want one if not?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no open issue about it. I added a NotImplementedError for complex inputs for now, we'll enable it later.

),
OpInfo('linalg.eigvalsh',
aten_name='linalg_eigvalsh',
dtypes=floating_and_complex_types(),
check_batched_gradgrad=False,
sample_inputs_func=sample_inputs_linalg_eigh,
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck hangs for this function
SkipInfo('TestGradients', 'test_forward_mode_AD'),),
),
OpInfo('linalg.householder_product',
aten_name='linalg_householder_product',
op=torch.linalg.householder_product,
Expand Down Expand Up @@ -8419,7 +8429,11 @@ def wrapper(x: np.ndarray, *args, **kwargs):
check_batched_gradgrad=False,
sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck hangs for this function
SkipInfo('TestGradients', 'test_forward_mode_AD'),),
),
OpInfo('eig',
op=torch.eig,
dtypes=floating_and_complex_types(),
Expand All @@ -8438,6 +8452,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half,
*[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_einsum,
skips=(
# test does not work with passing lambda for op
Expand Down Expand Up @@ -8867,6 +8882,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
op=lambda x, scalar: torch.fill_(x.clone(), scalar),
method_variant=None,
inplace_variant=torch.Tensor.fill_,
supports_forward_ad=True,
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
skips=(
Expand Down