diff --git a/CHANGELOG.md b/CHANGELOG.md
index 41c759997..4738e8ff1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [Unreleased]
 ### Added
 - Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
-- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56))
+- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61))
 - Added `pyg::sampler::neighbor_sample` interface ([#54](https://github.com/pyg-team/pyg-lib/pull/54))
 - Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45)))
 - Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
diff --git a/pyg_lib/csrc/ops/matmul.cpp b/pyg_lib/csrc/ops/matmul.cpp
index 06b6fed90..369b902cb 100644
--- a/pyg_lib/csrc/ops/matmul.cpp
+++ b/pyg_lib/csrc/ops/matmul.cpp
@@ -2,15 +2,16 @@
 
 #include <ATen/core/dispatch/Dispatcher.h>
 #include <torch/library.h>
+#include <torch/autograd.h>
 
 namespace pyg {
 namespace ops {
 
-// Performs matrix multiplication across list of elements.
-std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
-                                       const std::vector<at::Tensor>& other) {
+namespace {
+
+std::vector<at::Tensor> _grouped_matmul(const std::vector<at::Tensor>& input,
+                                        const std::vector<at::Tensor>& other) {
   // TODO (matthias) Add TensorArg definitions.
-  // TODO (matthias) Add autograd support.
   // TODO (matthias) Add dispatcher support.
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("pyg::grouped_matmul", "")
@@ -18,18 +19,67 @@ std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
   return op.call(input, other);
 }
 
-// Performs matrix multiplication according to segments.
-at::Tensor segment_matmul(const at::Tensor& input,
-                          const at::Tensor& ptr,
-                          const at::Tensor& other) {
+at::Tensor _segment_matmul(const at::Tensor& input,
+                           const at::Tensor& ptr,
+                           const at::Tensor& other) {
   // TODO (matthias) Add TensorArg definitions.
-  // TODO (matthias) Add autograd support.
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("pyg::segment_matmul", "")
                        .typed<decltype(segment_matmul)>();
   return op.call(input, ptr, other);
 }
 
+using torch::autograd::AutogradContext;
+using torch::autograd::Variable;
+using torch::autograd::variable_list;
+
+// Performs matrix multiplication according to segments.
+class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
+ public:
+  static variable_list forward(AutogradContext* ctx,
+                               Variable input,
+                               Variable ptr,
+                               Variable other) {
+    Variable out = _segment_matmul(input, ptr, other);
+    ctx->save_for_backward({input, ptr, other});
+    return {out};
+  }
+
+  static variable_list backward(AutogradContext* ctx, variable_list grad_outs) {
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0], ptr = saved[1], other = saved[2];
+
+    auto input_grad = Variable(), other_grad = Variable();
+    if (torch::autograd::any_variable_requires_grad({input})) {
+      // TODO (matthias) get rid of unnecessary `contiguous` here.
+      auto other_t = other.transpose(-2, -1).contiguous();
+      input_grad = _segment_matmul(grad_out, ptr, other_t);
+    }
+    if (torch::autograd::any_variable_requires_grad({other})) {
+      // TODO (matthias) implement backward pass for `other`.
+    }
+
+    return {input_grad, Variable(), other_grad};
+  }
+};
+
+}  // namespace
+
+// Performs matrix multiplication across list of elements.
+std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
+                                       const std::vector<at::Tensor>& other) {
+  // TODO (matthias) Add autograd support.
+  return _grouped_matmul(input, other);
+}
+
+// Performs matrix multiplication according to segments.
+at::Tensor segment_matmul(const at::Tensor& input,
+                          const at::Tensor& ptr,
+                          const at::Tensor& other) {
+  return SegmentMatmul::apply(input, ptr, other)[0];
+}
+
 TORCH_LIBRARY_FRAGMENT(pyg, m) {
   m.def(TORCH_SELECTIVE_SCHEMA(
       "pyg::grouped_matmul(Tensor[] input, Tensor[] other) -> Tensor[]"));
diff --git a/test/csrc/ops/test_matmul.cpp b/test/csrc/ops/test_matmul.cpp
index 4ffdd230c..4e62042e3 100644
--- a/test/csrc/ops/test_matmul.cpp
+++ b/test/csrc/ops/test_matmul.cpp
@@ -41,3 +41,18 @@ TEST(SegmentMatmulTest, BasicAssertions) {
       at::matmul(input.narrow(0, 5, 3), other[1]), 1e-01));
 }
 #endif
+
+#ifdef WITH_CUDA
+TEST(SegmentMatmulBackwardTest, BasicAssertions) {
+  auto options = at::TensorOptions().device(at::kCUDA);
+
+  auto input = at::randn({8, 12}, options).requires_grad_();
+  auto ptr = at::tensor({0, 5, 8}, options.dtype(at::kLong));
+  auto other = at::randn({2, 12, 16}, options).requires_grad_();
+
+  auto out = pyg::ops::segment_matmul(input, ptr, other);
+  out.mean().backward();
+  EXPECT_TRUE(input.grad().numel() == input.numel());
+  EXPECT_TRUE(other.grad().numel() == 0);  // No backward pass for `other` yet.
+}
+#endif
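Note on the remaining TODO: the gradient with respect to `other` is intentionally left out of this PR (see the TODO in `SegmentMatmul::backward` and the `other.grad().numel() == 0` assertion in the test). For reference, a minimal per-segment sketch of what that backward could compute is shown below. It is not part of the diff: the helper name `segment_matmul_other_grad` is hypothetical, it uses only plain ATen calls instead of a fused kernel, and it reads `ptr` on the CPU. For each segment `i`, the gradient is `input[ptr[i]:ptr[i+1]]^T @ grad_out[ptr[i]:ptr[i+1]]`, stacked over segments to match the shape of `other`.

```cpp
#include <ATen/ATen.h>
#include <vector>

// Hypothetical fallback (not part of this PR): per-segment gradient w.r.t. `other`.
// input:    [N, K], ptr: [B + 1] (int64), grad_out: [N, M]  ->  returns [B, K, M].
at::Tensor segment_matmul_other_grad(const at::Tensor& input,
                                     const at::Tensor& ptr,
                                     const at::Tensor& grad_out) {
  const auto ptr_cpu = ptr.cpu();
  const auto* p = ptr_cpu.data_ptr<int64_t>();
  std::vector<at::Tensor> grads;
  for (int64_t i = 0; i < ptr_cpu.numel() - 1; ++i) {
    const auto start = p[i];
    const auto length = p[i + 1] - p[i];
    // [K, N_i] x [N_i, M] -> [K, M]: gradient contribution of segment `i`.
    grads.push_back(at::matmul(input.narrow(0, start, length).transpose(0, 1),
                               grad_out.narrow(0, start, length)));
  }
  return at::stack(grads);  // [B, K, M], same shape as `other`
}
```

If wired into `SegmentMatmul::backward`, the result would be assigned to `other_grad` inside the `any_variable_requires_grad({other})` branch; the loop-over-segments form is a correctness reference only and would presumably be replaced by a grouped/segmented kernel in practice.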