Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
86 commits
Select commit Hold shift + click to select a range
a7a1519
drafting
puririshi98 Jul 7, 2022
439a1ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2022
4b8b77b
drafting
puririshi98 Jul 7, 2022
db660bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2022
ca1aa96
drafting
puririshi98 Jul 7, 2022
ece2b98
drafting
puririshi98 Jul 7, 2022
541f928
drafting
puririshi98 Jul 7, 2022
f0b8a2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2022
a71cba8
drafting
puririshi98 Jul 7, 2022
c166a20
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 7, 2022
a86093f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 7, 2022
a891c25
drafting
puririshi98 Jul 7, 2022
46adc54
drafting
puririshi98 Jul 11, 2022
88e2488
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
cb019c6
addressed previous comments
puririshi98 Jul 11, 2022
c69a4b8
minor cleanups
puririshi98 Jul 11, 2022
6567f85
removing pytest to make pep8 happy
puririshi98 Jul 11, 2022
afcd7fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
01d706e
backward part of the test
puririshi98 Jul 11, 2022
89ca3d0
backward part of the test
puririshi98 Jul 11, 2022
5ddab2b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
5885e26
minor fixups
puririshi98 Jul 11, 2022
db2d532
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 11, 2022
2ea1ca1
minor fixups
puririshi98 Jul 11, 2022
4ebfe1f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 11, 2022
5c7e81b
just segment matmul for now and i will ask internally for the solutio…
puririshi98 Jul 12, 2022
e0aa898
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
773f478
minor fix
puririshi98 Jul 12, 2022
5e6c308
minor fix
puririshi98 Jul 12, 2022
a48a618
minor fix
puririshi98 Jul 12, 2022
379d26b
minor fix
puririshi98 Jul 12, 2022
34c374d
minor fix
puririshi98 Jul 12, 2022
8cd6462
minor fix
puririshi98 Jul 12, 2022
c638537
minor fix
puririshi98 Jul 12, 2022
aa3cbd5
minor fix
puririshi98 Jul 12, 2022
4d95568
minor fix
puririshi98 Jul 12, 2022
a5c58d0
minor fix
puririshi98 Jul 12, 2022
aa6384e
minor fix
puririshi98 Jul 12, 2022
f2d4f9d
minor fix
puririshi98 Jul 12, 2022
61507d3
minor fix
puririshi98 Jul 12, 2022
a1aa34b
minor fix
puririshi98 Jul 12, 2022
240ca9f
minor fix
puririshi98 Jul 12, 2022
4dbedfc
minor fix
puririshi98 Jul 12, 2022
0f565a0
minor fix
puririshi98 Jul 12, 2022
fb8dd46
minor fix
puririshi98 Jul 12, 2022
34a435c
trying to get CI to pass...
puririshi98 Jul 12, 2022
5ecd197
trying to get CI to pass...
puririshi98 Jul 12, 2022
f3da492
trying to get CI to pass...
puririshi98 Jul 12, 2022
b6206f9
trying to get CI to pass...
puririshi98 Jul 12, 2022
143d231
trying to get CI to pass...
puririshi98 Jul 12, 2022
020e12a
Merge branch 'master' into master
puririshi98 Jul 12, 2022
b737091
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
4cb1f5a
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
87e007f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
c0b0f40
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
61b2613
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
6540aba
minor fix
puririshi98 Jul 12, 2022
187faeb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
e2df485
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
462928e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
fa3328f
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
e443f12
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
de272be
minor fix
puririshi98 Jul 12, 2022
38ba159
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
c682cbb
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
7f49c4d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
8d604cc
freezing all pyg version at working state
puririshi98 Jul 12, 2022
94aa3de
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
8bbd1f5
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
adb7c3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
e827c44
minor fix
puririshi98 Jul 12, 2022
deaf841
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
7384457
minor fix
puririshi98 Jul 12, 2022
ae230fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
5abb7df
trying to get CI to pass...
puririshi98 Jul 12, 2022
dcb3b7d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
b4cf268
trying to get CI to pass...
puririshi98 Jul 12, 2022
c27e135
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
eab8b12
trying to get CI to pass...
puririshi98 Jul 12, 2022
d28cf29
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 12, 2022
045ecc7
Merge branch 'master' of https://github.com/puririshi98/pyg-lib
puririshi98 Jul 12, 2022
365682b
update
rusty1s Jul 13, 2022
28d735a
update
rusty1s Jul 13, 2022
1b5671d
update
rusty1s Jul 13, 2022
fea1926
diff
rusty1s Jul 13, 2022
c0a534f
update
rusty1s Jul 13, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [Unreleased]
### Added
- Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61))
- Added `pyg::sampler::neighbor_sample` interface ([#54](https://github.com/pyg-team/pyg-lib/pull/54))
- Added `pyg::sampler::Mapper` utility for mapping global to local node indices ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
- Added benchmark script ([#45](https://github.com/pyg-team/pyg-lib/pull/45))
Expand Down
68 changes: 59 additions & 9 deletions pyg_lib/csrc/ops/matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,84 @@

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>
#include <torch/script.h>

namespace pyg {
namespace ops {

// Performs matrix multiplication across list of elements.
std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other) {
namespace {

std::vector<at::Tensor> _grouped_matmul(const std::vector<at::Tensor>& input,
                                        const std::vector<at::Tensor>& other) {
  // TODO (matthias) Add TensorArg definitions.
  // TODO (matthias) Add dispatcher support.

  // Resolve the registered `pyg::grouped_matmul` schema only once: the typed
  // handle lives in a function-local static and is reused by later calls.
  // Throws if no schema with that name has been registered.
  static auto grouped_op = [] {
    auto& dispatcher = c10::Dispatcher::singleton();
    auto handle = dispatcher.findSchemaOrThrow("pyg::grouped_matmul", "");
    return handle.typed<decltype(grouped_matmul)>();
  }();

  return grouped_op.call(input, other);
}

// Performs matrix multiplication according to segments.
at::Tensor segment_matmul(const at::Tensor& input,
const at::Tensor& ptr,
const at::Tensor& other) {
at::Tensor _segment_matmul(const at::Tensor& input,
                           const at::Tensor& ptr,
                           const at::Tensor& other) {
  // TODO (matthias) Add TensorArg definitions.

  // Raw dispatcher call for `pyg::segment_matmul`; autograd is layered on top
  // of this helper by the `SegmentMatmul` function. The typed handle is looked
  // up on first use and cached in a function-local static. Throws if the
  // schema has not been registered.
  static auto segment_op = [] {
    auto& dispatcher = c10::Dispatcher::singleton();
    auto handle = dispatcher.findSchemaOrThrow("pyg::segment_matmul", "");
    return handle.typed<decltype(segment_matmul)>();
  }();

  return segment_op.call(input, ptr, other);
}

using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

// Autograd function for segment-wise matrix multiplication.
//
// forward:  out = _segment_matmul(input, ptr, other); `input`, `ptr` and
//           `other` are saved for the backward pass.
// backward: grad_input = segment_matmul(grad_out, ptr, other^T), where the
//           last two dimensions of `other` are transposed. `ptr` is never
//           differentiable. The gradient w.r.t. `other` is NOT implemented
//           yet: an undefined tensor is returned for it even when `other`
//           requires grad.
class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
 public:
  static variable_list forward(AutogradContext* ctx,
                               const Variable& input,
                               const Variable& ptr,
                               const Variable& other) {
    Variable out = _segment_matmul(input, ptr, other);
    ctx->save_for_backward({input, ptr, other});
    return {out};
  }

  static variable_list backward(AutogradContext* ctx, variable_list grad_outs) {
    auto grad_out = grad_outs[0];
    auto saved = ctx->get_saved_variables();
    auto input = saved[0], ptr = saved[1], other = saved[2];

    auto input_grad = Variable(), other_grad = Variable();
    if (torch::autograd::any_variable_requires_grad({input})) {
      // TODO (matthias) get rid of unnecessary `contiguous` here.
      auto other_t = other.transpose(-2, -1).contiguous();
      // Use the autograd-aware public op instead of the raw dispatcher helper
      // `_segment_matmul`: with `create_graph=true` the returned gradient is
      // then itself recorded in the autograd graph, so higher-order gradients
      // are not silently dropped.
      input_grad = segment_matmul(grad_out, ptr, other_t);
    }
    if (torch::autograd::any_variable_requires_grad({other})) {
      // TODO (matthias) implement backward pass for `other`.
    }

    // One gradient per forward input, in order: `input`, `ptr`, `other`.
    return {input_grad, Variable(), other_grad};
  }
};

} // namespace

// Performs matrix multiplication across two lists of tensors by forwarding to
// the registered `pyg::grouped_matmul` operator (presumably pairwise,
// `input[i] @ other[i]` — confirm against the kernel implementation).
std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
                                       const std::vector<at::Tensor>& other) {
  // TODO (matthias) Add autograd support.
  // No custom autograd function wraps this op yet, so call the dispatcher
  // helper directly.
  std::vector<at::Tensor> result = _grouped_matmul(input, other);
  return result;
}

// Multiplies each row segment of `input` (boundaries given by `ptr`) with the
// matching matrix of `other`. Runs through the `SegmentMatmul` autograd
// function so gradients w.r.t. `input` are tracked.
at::Tensor segment_matmul(const at::Tensor& input,
                          const at::Tensor& ptr,
                          const at::Tensor& other) {
  const auto outputs = SegmentMatmul::apply(input, ptr, other);
  return outputs.front();
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::grouped_matmul(Tensor[] input, Tensor[] other) -> Tensor[]"));
Expand Down
15 changes: 15 additions & 0 deletions test/csrc/ops/test_matmul.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,18 @@ TEST(SegmentMatmulTest, BasicAssertions) {
at::matmul(input.narrow(0, 5, 3), other[1]), 1e-01));
}
#endif

#ifdef WITH_CUDA
// Checks the backward pass of `segment_matmul`: the gradient w.r.t. `input`
// must match the closed-form derivative of `mean(out)`, and no gradient is
// produced for `other` yet.
TEST(SegmentMatmulBackwardTest, BasicAssertions) {
  auto options = at::TensorOptions().device(at::kCUDA);

  auto input = at::randn({8, 12}, options).requires_grad_();
  auto ptr = at::tensor({0, 5, 8}, options.dtype(at::kLong));
  auto other = at::randn({2, 12, 16}, options).requires_grad_();

  auto out = pyg::ops::segment_matmul(input, ptr, other);
  out.mean().backward();

  // d mean(out) / d input[i, k] = sum_j other[s, k, j] / out.numel(), where
  // `s` is the segment of row `i` (rows 0-4 -> 0, rows 5-7 -> 1 per `ptr`).
  auto other_sum = other.detach().sum(-1) / out.numel();  // [2, 12]
  auto expected = at::cat(
      {other_sum[0].expand({5, 12}), other_sum[1].expand({3, 12})}, 0);

  ASSERT_TRUE(input.grad().defined());
  EXPECT_EQ(input.grad().sizes(), input.sizes());
  EXPECT_TRUE(at::allclose(input.grad(), expected, 1e-02, 1e-03));

  // No backward pass for `other` yet: its gradient must stay undefined.
  EXPECT_FALSE(other.grad().defined());
}
#endif