Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@

- name: linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)
self: cholesky_backward(grad, upper, L)
L: cholesky_jvp(self_t, L, upper)

- name: cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor
self, input2: cholesky_solve_backward(grad, self, input2, result, upper)
Expand Down
14 changes: 14 additions & 0 deletions torch/csrc/autograd/FunctionsManual.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,20 @@ Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArra
return mask_selected.view(sizes);
}

// Forward-mode derivative (JVP) of the Cholesky factor.
//
// Follows "Differentiation of the Cholesky decomposition", Iain Murray,
// https://arxiv.org/abs/1602.07527, equation 8.
//
// input_tangent: tangent of the matrix that was factorized.
// L:             the factor produced by the forward computation.
// upper:         if true, both `input_tangent` and `L` are conjugate-transposed
//                before the computation and the result is conjugate-transposed
//                back before being returned.
// Returns the tangent associated with `L`.
Tensor cholesky_jvp(const Tensor& input_tangent, const Tensor& L, bool upper) {
  // Conjugate transpose over the last two dimensions.
  const auto adjoint = [](const Tensor& t) {
    return t.transpose(-1, -2).conj();
  };

  Tensor tangent = input_tangent;
  Tensor factor = L;
  if (upper) {
    tangent = adjoint(input_tangent);
    factor = adjoint(L);
  }

  // Invert the factor by solving factor * X = I; the solve is told to treat
  // `factor` as lower triangular.
  const Tensor identity = at::eye(L.size(-1), L.options());
  const Tensor factor_inv =
      std::get<0>(at::triangular_solve(identity, factor, /*upper=*/false));

  // phi = factor^{-1} * tangent * factor^{-H}
  Tensor phi = factor_inv.matmul(tangent).matmul(adjoint(factor_inv));

  // Keep only the lower triangle, then halve its diagonal (both in place).
  phi.tril_();
  phi.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).mul_(0.5);

  Tensor factor_tangent = at::matmul(factor, phi);
  if (upper) {
    return adjoint(factor_tangent);
  }
  return factor_tangent;
}

Tensor cholesky_backward(Tensor grad, bool upper, Tensor L) {
// cf. Iain Murray (2016); arXiv 1602.07527
// This gradient is symmetric, and not triangular.
Expand Down
1 change: 1 addition & 0 deletions torch/csrc/autograd/FunctionsManual.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ at::Tensor mean_backward(at::Tensor grad, const at::IntArrayRef sizes, int64_t n
at::Tensor var_std_mean_backward(const variable_list& grads, const at::Tensor& self, const at::Tensor& r1, const at::Tensor& r2, c10::optional<IntArrayRef> dim, c10::optional<int64_t> correction, bool keepdim, bool is_std);
at::Tensor masked_scatter_backward(const at::Tensor & grad, const at::Tensor & mask, at::IntArrayRef sizes);
at::Tensor cholesky_backward(at::Tensor grad, bool upper, at::Tensor L);
at::Tensor cholesky_jvp(const at::Tensor& input_tangent, const at::Tensor& L, bool upper);
at::Tensor cholesky_inverse_backward(at::Tensor grad, at::Tensor L, bool upper, at::Tensor inverse);
at::Tensor split_with_sizes_backward(const std::vector<torch::autograd::Variable> &grads,
IntArrayRef split_sizes, int64_t dim, IntArrayRef sizes, const at::TensorOptions &options);
Expand Down
12 changes: 9 additions & 3 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3494,10 +3494,10 @@ def sample_inputs_linalg_cholesky(op_info, device, dtype, requires_grad=False, *
batches = [(), (0, ), (2, ), (1, 1)]
ns = [5, 0]
out = []
for batch, n in product(batches, ns):
for batch, n, upper in product(batches, ns, [True, False]):
a = random_hermitian_pd_matrix(n, *batch, dtype=dtype, device=device)
a.requires_grad = requires_grad
out.append(SampleInput(a))
out.append(SampleInput(a, kwargs={"upper": upper}))
return out

def sample_inputs_symeig(op_info, device, dtype, requires_grad=False):
Expand Down Expand Up @@ -6804,6 +6804,7 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
# got: vmap: Calling Tensor.as_strided is not supported
# unless the batch dims being vmapped over are at the front of the tensor (in memory layout).
check_batched_gradgrad=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_linalg_cholesky,
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
Expand All @@ -6817,9 +6818,14 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
aten_name='linalg_cholesky_ex',
dtypes=floating_and_complex_types(),
check_batched_gradgrad=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_linalg_cholesky,
gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack]),
decorators=[skipCUDAIfNoMagmaAndNoCusolver, skipCUDAIfRocm, skipCPUIfNoLapack],
skips=(
# Gradcheck for complex generates invalid inputs for this function
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this be because it does not create Hermitian inputs? It's not clear whether we can assume that the inputs are Hermitian or not. See #62163 (comment)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yes, I encountered that one but forgot to fix it.
We should update

return apply_to_c_inps(fn, lambda x: x + 0 * 1j), apply_to_c_inps(fn, lambda x: x * 1j)
to make it keep the original imaginary and real parts, respectively, instead of setting them to zero.

SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
),
OpInfo('linalg.cond',
aten_name='linalg_cond',
dtypes=floating_and_complex_types(),
Expand Down