1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added `[segment|grouped]_matmul` CPU implementation ([#111](https://github.com/pyg-team/pyg-lib/pull/111))
- Added `temporal_strategy` option to `neighbor_sample` ([#114](https://github.com/pyg-team/pyg-lib/pull/114))
- Added benchmarking tool (Google Benchmark) along with `pyg::sampler::Mapper` benchmark example ([#101](https://github.com/pyg-team/pyg-lib/pull/101))
- Added CSC mode to `pyg::sampler::neighbor_sample` and `pyg::sampler::hetero_neighbor_sample` ([#95](https://github.com/pyg-team/pyg-lib/pull/95), [#96](https://github.com/pyg-team/pyg-lib/pull/96))
58 changes: 58 additions & 0 deletions pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
@@ -0,0 +1,58 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include "pyg_lib/csrc/utils/convert.h"

namespace pyg {
namespace ops {

namespace {

void grouped_matmul_out_kernel(const at::TensorList input,
const at::TensorList other,
const at::TensorList out) {
for (size_t i = 0; i < out.size(); ++i)
at::matmul_out(const_cast<at::Tensor&>(out[i]), input[i], other[i]);
}

std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
const at::TensorList other) {
std::vector<at::Tensor> out(input.size());
for (size_t i = 0; i < input.size(); ++i)
out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)});

grouped_matmul_out_kernel(input, other, out);

return out;
}

at::Tensor segment_matmul_kernel(const at::Tensor& input,
const at::Tensor& ptr,
const at::Tensor& other) {
const auto size = pyg::utils::size_from_ptr(ptr).cpu();
const auto sizes = at::IntArrayRef(size.data_ptr<int64_t>(), size.numel());
const auto out = input.new_empty({input.size(0), other.size(-1)});

// Reuse the grouped kernel on per-segment views of `input` and `out`, and on
// the per-segment 2D weight matrices obtained by unbinding `other`:
grouped_matmul_out_kernel(
input.contiguous().split_with_sizes(/*split_size=*/sizes, /*dim=*/0),
other.contiguous().unbind(/*dim=*/0),
out.split_with_sizes(/*split_size=*/sizes, /*dim=*/0));

return out;
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::grouped_matmul"),
TORCH_FN(grouped_matmul_kernel));
m.impl(TORCH_SELECTIVE_NAME("pyg::segment_matmul"),
TORCH_FN(segment_matmul_kernel));
}

} // namespace ops
} // namespace pyg
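In plain terms, the CPU kernel above reduces `segment_matmul` to a grouped matmul over per-segment views: `ptr` delimits row blocks of `input`, and block `i` is multiplied with `other[i]`. A minimal Python sketch of the same semantics against plain `torch` (the helper name `segment_matmul_reference` is illustrative and not part of pyg-lib):

```python
import torch

def segment_matmul_reference(input, ptr, other):
    # input: [N, k], ptr: [num_segments + 1], other: [num_segments, k, n].
    sizes = (ptr[1:] - ptr[:-1]).tolist()  # rows per segment
    outs = [x_i @ w_i  # [n_i, k] @ [k, n] -> [n_i, n]
            for x_i, w_i in zip(input.split(sizes), other.unbind(dim=0))]
    return torch.cat(outs, dim=0)  # [N, n]

input = torch.randn(8, 12)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 12, 16)
out = segment_matmul_reference(input, ptr, other)  # shape: [8, 16]
```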
26 changes: 7 additions & 19 deletions pyg_lib/csrc/ops/cuda/matmul_kernel.cu
@@ -12,11 +12,9 @@ namespace ops {

namespace {

void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other,
const std::vector<at::Tensor>& out) {
// TODO (matthias) Check tensor devices.

void grouped_matmul_out_kernel(const at::TensorList input,
const at::TensorList other,
const at::TensorList out) {
const auto num_matrices = input.size();
std::vector<at::Tensor> new_input, new_other, new_out;

@@ -81,8 +79,6 @@ void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
for (size_t i = 0; i < num_matrices; ++i) {
auto m = new_input[i].size(0), k = new_input[i].size(1),
n = new_out[i].size(1);
TORCH_CHECK(new_input[i].size(-1) == new_other[i].size(-2),
"Shape mismatch");
all_problems[i] = cutlass::gemm::GemmCoord(m, n, k);
ld_A_host[i] = GemmKernel::LayoutA::packed({m, k}).stride(0);
ld_B_host[i] = GemmKernel::LayoutB::packed({k, n}).stride(0);
@@ -121,9 +117,8 @@ void grouped_matmul_out_kernel(const std::vector<at::Tensor>& input,
TORCH_CHECK(status == cutlass::Status::kSuccess, "GroupedGEMM run failed");
}

std::vector<at::Tensor> grouped_matmul_kernel(
const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other) {
std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
const at::TensorList other) {
std::vector<at::Tensor> out(input.size());
for (size_t i = 0; i < input.size(); ++i)
out[i] = input[i].new_empty({input[i].size(0), other[i].size(-1)});
@@ -151,17 +146,10 @@ at::Tensor segment_matmul_kernel(const at::Tensor& input,

} // namespace

TORCH_LIBRARY(pyg, m) {
m.def("pyg::cuda_grouped_matmul(Tensor[] input, Tensor[] other) -> Tensor[]");
m.def(
"pyg::cuda_segment_matmul(Tensor input, Tensor ptr, Tensor other) -> "
"Tensor");
}

TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::cuda_grouped_matmul"),
m.impl(TORCH_SELECTIVE_NAME("pyg::grouped_matmul"),
TORCH_FN(grouped_matmul_kernel));
m.impl(TORCH_SELECTIVE_NAME("pyg::cuda_segment_matmul"),
m.impl(TORCH_SELECTIVE_NAME("pyg::segment_matmul"),
TORCH_FN(segment_matmul_kernel));
}

47 changes: 39 additions & 8 deletions pyg_lib/csrc/ops/matmul.cpp
@@ -4,27 +4,58 @@
#include <torch/library.h>
#include <torch/script.h>

#include "pyg_lib/csrc/utils/check.h"
#include "pyg_lib/csrc/utils/convert.h"

namespace pyg {
namespace ops {

namespace {

std::vector<at::Tensor> _grouped_matmul(const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other) {
// TODO (matthias) Add TensorArg definitions.
// TODO (matthias) Add dispatcher support.
std::vector<at::Tensor> _grouped_matmul(const at::TensorList input,
const at::TensorList other) {
TORCH_CHECK(input.size() == other.size(),
"Number of 'input' tensors must match number of 'other' tensors");
const auto n_tensors = input.size();
std::vector<at::TensorArg> input_args;
std::vector<at::TensorArg> other_args;
pyg::utils::fill_tensor_args(input_args, input, "input", 0);
pyg::utils::fill_tensor_args(other_args, other, "other", 1);
at::CheckedFrom c{"grouped_matmul"};

at::checkAllDefined(c, input_args);
at::checkAllDefined(c, other_args);
at::checkAllSameType(c, input_args);
at::checkAllSameType(c, other_args);
at::checkSameType(c, input_args[0], other_args[0]);
for (size_t i = 0; i < n_tensors; ++i) {
at::checkDim(c, input_args[i], 2);
at::checkDim(c, other_args[i], 2);
at::checkSize(c, other_args[i], 0, input_args[i]->size(-1));
}

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::cuda_grouped_matmul", "")
.findSchemaOrThrow("pyg::grouped_matmul", "")
.typed<decltype(_grouped_matmul)>();
return op.call(input, other);
}

at::Tensor _segment_matmul(const at::Tensor& input,
const at::Tensor& ptr,
const at::Tensor& other) {
// TODO (matthias) Add TensorArg definitions.
at::TensorArg input_arg{input, "input", 0};
at::TensorArg ptr_arg{ptr, "ptr", 1};
at::TensorArg other_arg{other, "other", 2};
at::CheckedFrom c{"segment_matmul"};

at::checkAllDefined(c, {input_arg, ptr_arg, other_arg});
at::checkSameType(c, input_arg, other_arg);
at::checkDim(c, input_arg, 2);
at::checkDim(c, ptr_arg, 1);
at::checkDim(c, other_arg, 3);
at::checkSize(c, other_arg, 1, input_arg->size(-1));
at::checkNumel(c, ptr_arg, other_arg->size(0) + 1);

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::segment_matmul", "")
.typed<decltype(_segment_matmul)>();
@@ -130,8 +161,8 @@ class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
} // namespace

// Performs matrix multiplication across list of elements.
std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other) {
std::vector<at::Tensor> grouped_matmul(const at::TensorList input,
const at::TensorList other) {
// TODO (matthias) Add autograd support.
/* return GroupedMatmul::apply(input, other)[0]; */
return _grouped_matmul(input, other);
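The `at::TensorArg` checks added above pin down the shape contract for both ops. Expressed as a rough Python sketch (the assertions only mirror the checks in this file; the helper names are illustrative, not part of pyg-lib):

```python
def check_grouped_matmul_args(inputs, others):
    # One weight matrix per input matrix, all 2-D, inner dims must line up.
    assert len(inputs) == len(others)
    for x, w in zip(inputs, others):
        assert x.dim() == 2 and w.dim() == 2
        assert w.size(0) == x.size(-1)

def check_segment_matmul_args(input, ptr, other):
    # 2-D input, 1-D ptr, 3-D other; other's middle dim matches input's
    # feature dim, and ptr has one more entry than there are segments.
    assert input.dim() == 2 and ptr.dim() == 1 and other.dim() == 3
    assert other.size(1) == input.size(-1)
    assert ptr.numel() == other.size(0) + 1
    assert input.dtype == other.dtype
```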
5 changes: 2 additions & 3 deletions pyg_lib/csrc/ops/matmul.h
@@ -8,9 +8,8 @@ namespace ops {

// Performs matrix multiplication across list of elements.
// TODO (matthias) Support `out` argument.
PYG_API std::vector<at::Tensor> grouped_matmul(
const std::vector<at::Tensor>& input,
const std::vector<at::Tensor>& other);
PYG_API std::vector<at::Tensor> grouped_matmul(const at::TensorList input,
const at::TensorList other);

// Performs matrix multiplication according to segments.
// TODO (matthias) Support `out` argument.
18 changes: 18 additions & 0 deletions pyg_lib/csrc/utils/check.cpp
@@ -0,0 +1,18 @@
#include "check.h"

namespace pyg {
namespace utils {

void fill_tensor_args(std::vector<at::TensorArg>& args,
const at::TensorList tensors,
const std::string& name,
int pos) {
args.reserve(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {
const auto full_name = name + "[" + std::to_string(i) + "]";
args.emplace_back(tensors[i], full_name.c_str(), pos);
}
}

} // namespace utils
} // namespace pyg
14 changes: 14 additions & 0 deletions pyg_lib/csrc/utils/check.h
@@ -0,0 +1,14 @@
#pragma once

#include <ATen/ATen.h>

namespace pyg {
namespace utils {

void fill_tensor_args(std::vector<at::TensorArg>& args,
const at::TensorList tensors,
const std::string& name,
int pos);

} // namespace utils
} // namespace pyg
16 changes: 5 additions & 11 deletions pyg_lib/ops/__init__.py
@@ -7,19 +7,16 @@
class SegmentMatmul(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, ptr, other):
assert inputs.is_cuda
assert ptr.is_cuda
assert other.is_cuda
ctx.save_for_backward(inputs, ptr, other)
return torch.ops.pyg.cuda_segment_matmul(inputs, ptr, other)
return torch.ops.pyg.segment_matmul(inputs, ptr, other)

@staticmethod
def backward(ctx, out_grad):
inputs, ptr, other = ctx.saved_tensors

input_grad = None
if inputs.requires_grad:
input_grad = torch.ops.pyg.cuda_segment_matmul(
input_grad = torch.ops.pyg.segment_matmul(
out_grad, ptr, torch.transpose(other, -2, -1))

other_grad = None
@@ -40,11 +37,8 @@ def backward(ctx, out_grad):
class GroupedMatmul(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs: List[Tensor], others: List[Tensor]):
for x, other in zip(inputs, others):
assert x.is_cuda
assert other.is_cuda
ctx.save_for_backward(inputs, others)
outs = torch.ops.pyg.cuda_grouped_matmul(inputs, others)
outs = torch.ops.pyg.grouped_matmul(inputs, others)

# NOTE Autograd doesn't set out[i].requires_grad = True automatically
for i in range(len(outs)):
@@ -60,13 +54,13 @@ def backward(ctx, outs_grad: List[Tensor]):
if all([x.requires_grad for x in inputs]):
for i in range(len(others)):
others[i] = others[i].t()
inputs_grad = torch.ops.pyg.cuda_grouped_matmul(outs_grad, others)
inputs_grad = torch.ops.pyg.grouped_matmul(outs_grad, others)

others_grad = None
if all([other.requires_grad for other in others]):
for i in range(len(inputs)):
inputs[i] = inputs[i].t()
others_grad = torch.ops.pyg.cuda_grouped_matmul(inputs, outs_grad)
others_grad = torch.ops.pyg.grouped_matmul(inputs, outs_grad)

return inputs_grad, others_grad

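With the `cuda_*` prefixes dropped, both autograd functions call the same dispatcher-registered ops on CPU and CUDA. A hedged usage sketch (it assumes a pyg-lib build that registers the `pyg::*` kernels above and that `import pyg_lib` loads them; shapes follow the C++ tests below):

```python
import torch
import pyg_lib  # noqa: F401  -- assumed to load the custom `pyg::*` ops

# segment_matmul: row blocks of `input` (delimited by `ptr`) times `other[i]`.
input = torch.randn(8, 12)
ptr = torch.tensor([0, 5, 8])
other = torch.randn(2, 12, 16)
out = torch.ops.pyg.segment_matmul(input, ptr, other)  # [8, 16]

# grouped_matmul: one independent matmul per (input[i], other[i]) pair.
xs = [torch.randn(5, 8), torch.randn(3, 12)]
ws = [torch.randn(8, 16), torch.randn(12, 32)]
outs = torch.ops.pyg.grouped_matmul(xs, ws)  # shapes [5, 16] and [3, 32]
```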
30 changes: 19 additions & 11 deletions test/csrc/ops/test_matmul.cpp
@@ -1,11 +1,14 @@
#include <ATen/ATen.h>
#include <c10/core/DeviceType.h>
#include <gtest/gtest.h>

#include "pyg_lib/csrc/ops/matmul.h"

#ifdef WITH_CUDA
TEST(GroupedMatmulTest, BasicAssertions) {
auto options = at::TensorOptions().device(at::kCUDA);
class MatmulTest : public testing::TestWithParam<c10::DeviceType> {};

TEST_P(MatmulTest, GroupedMatmulForward) {
const auto param = ::testing::TestWithParam<c10::DeviceType>::GetParam();
auto options = at::TensorOptions().device(param);

std::vector<at::Tensor> input{at::randn({5, 8}, options),
at::randn({3, 12}, options)};
@@ -22,11 +25,10 @@ TEST(GroupedMatmulTest, BasicAssertions) {
auto expected_out1 = at::matmul(input[1], other[1]);
EXPECT_TRUE(at::allclose(out[1], expected_out1, 0.1, 0.1));
}
#endif

#ifdef WITH_CUDA
TEST(SegmentMatmulTest, BasicAssertions) {
auto options = at::TensorOptions().device(at::kCUDA);
TEST_P(MatmulTest, SegmentMatmulForward) {
const auto param = ::testing::TestWithParam<c10::DeviceType>::GetParam();
auto options = at::TensorOptions().device(param);

auto input = at::randn({8, 12}, options);
auto ptr = at::tensor({0, 5, 8}, options.dtype(at::kLong));
@@ -40,14 +42,13 @@
auto expected_out1 = at::matmul(input.narrow(0, 5, 3), other[1]);
EXPECT_TRUE(at::allclose(out.narrow(0, 5, 3), expected_out1, 0.1, 0.1));
}
#endif

// TODO (matthias) add a grouped matmul backward test.

#ifdef WITH_CUDA
TEST(SegmentMatmulBackwardTest, BasicAssertions) {
TEST_P(MatmulTest, SegmentMatmulBackward) {
return; // TODO (matthias) uncomment this.
auto options = at::TensorOptions().device(at::kCUDA);
const auto param = ::testing::TestWithParam<c10::DeviceType>::GetParam();
auto options = at::TensorOptions().device(param);

auto input = at::randn({8, 12}, options).requires_grad_();
auto ptr = at::tensor({0, 5, 8}, options.dtype(at::kLong));
@@ -58,4 +59,11 @@
EXPECT_TRUE(input.grad().numel() == input.numel());
EXPECT_TRUE(other.grad().numel() == other.numel());
}

INSTANTIATE_TEST_SUITE_P(OpsTest,
MatmulTest,
#ifdef WITH_CUDA
testing::Values(at::kCUDA, at::kCPU));
#else
testing::Values(at::kCPU));
#endif
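The same parameterized forward check translates naturally into a Python test against a plain `torch` reference (a sketch only; it assumes the `pyg::*` ops are registered via `import pyg_lib`, and uses `pytest` parametrization in place of `INSTANTIATE_TEST_SUITE_P`):

```python
import pytest
import torch
import pyg_lib  # noqa: F401  -- assumed to register the `pyg::*` ops

DEVICES = ['cpu'] + (['cuda'] if torch.cuda.is_available() else [])

@pytest.mark.parametrize('device', DEVICES)
def test_segment_matmul_forward(device):
    input = torch.randn(8, 12, device=device)
    ptr = torch.tensor([0, 5, 8], device=device)
    other = torch.randn(2, 12, 16, device=device)

    out = torch.ops.pyg.segment_matmul(input, ptr, other)
    # Tolerances mirror the 0.1 used in the C++ tests above.
    assert torch.allclose(out[:5], input[:5] @ other[0], atol=1e-1)
    assert torch.allclose(out[5:], input[5:] @ other[1], atol=1e-1)
```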