1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.2.0] - 2023-MM-DD
### Added
- Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156))
### Changed
- Improved `[segment|grouped]_matmul` CPU implementation via `at::matmul_out` and MKL BLAS `gemm_batch` ([#146](https://github.com/pyg-team/pyg-lib/pull/146))
### Removed
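For orientation, `pyg::sampled_op` fuses an optional row gather (`index_select` along dimension 0) on either operand with a pointwise binary op (`add`, `sub`, `mul`, or `div`). A minimal C++ usage sketch of the intended semantics (not part of this diff; shapes are illustrative only):

```cpp
#include <ATen/ATen.h>
#include "pyg_lib/csrc/ops/sampled.h"

// sampled_op(left, right, li, ri, "add") computes, row for row,
//   left.index_select(0, li) + right.index_select(0, ri),
// where a missing index means the operand is used as-is.
at::Tensor sampled_add_example() {
  auto left = at::randn({100, 16});
  auto right = at::randn({32, 16});
  auto left_index = at::randint(/*high=*/100, /*size=*/{32}, at::kLong);
  // Gather 32 rows of `left` and add them to the 32 rows of `right`:
  return pyg::ops::sampled_op(left, right, left_index,
                              /*right_index=*/c10::nullopt, "add");
}
```

On CUDA the gather and the arithmetic are fused into a single kernel (see `sampled_kernel.cu` below); on CPU the kernel simply composes `index_select` with the corresponding ATen operator.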
3 changes: 3 additions & 0 deletions pyg_lib/csrc/config.h
@@ -0,0 +1,3 @@
#pragma once

#define WITH_MKL_BLAS() 0
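`WITH_MKL_BLAS()` is a function-like feature flag, defaulted off here; presumably the build system redefines it to `1` when MKL is available (that toggle is not part of this diff). A sketch of how consuming code would typically guard MKL-specific paths:

```cpp
#include "pyg_lib/csrc/config.h"

#if WITH_MKL_BLAS()
#include <mkl.h>
#endif

void grouped_matmul_backend() {
#if WITH_MKL_BLAS()
  // MKL is available: batched gemm path (e.g. cblas_?gemm_batch) goes here.
#else
  // Portable fallback, e.g. at::matmul_out.
#endif
}
```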
55 changes: 55 additions & 0 deletions pyg_lib/csrc/ops/cpu/sampled_kernel.cpp
@@ -0,0 +1,55 @@
#include <ATen/ATen.h>
#include <torch/library.h>

#include <map>
#include <string>

namespace pyg {
namespace ops {

namespace {

enum FnType { ADD, SUB, MUL, DIV };
const std::map<std::string, FnType> to_fn_type = {
{"add", ADD},
{"sub", SUB},
{"mul", MUL},
{"div", DIV},
};

at::Tensor sampled_op_kernel(const at::Tensor& left,
const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
auto a = left;
if (left_index.has_value()) {
a = left.index_select(0, left_index.value());
}

auto b = right;
if (right_index.has_value()) {
b = right.index_select(0, right_index.value());
}

auto fn_type = to_fn_type.at(fn);

at::Tensor out;
if (fn_type == ADD) {
out = a + b;
} else if (fn_type == SUB) {
out = a - b;
} else if (fn_type == MUL) {
out = a * b;
} else if (fn_type == DIV) {
out = a / b;
}

return out;
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::sampled_op"), TORCH_FN(sampled_op_kernel));
}

} // namespace ops
} // namespace pyg
113 changes: 113 additions & 0 deletions pyg_lib/csrc/ops/cuda/sampled_kernel.cu
@@ -0,0 +1,113 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>

#include <map>
#include <string>

namespace pyg {
namespace ops {

namespace {

#define THREADS 1024
#define CDIV(N, M) (((N) + (M)-1) / (M))

enum FnType { ADD, SUB, MUL, DIV };
const std::map<std::string, FnType> to_fn_type = {
{"add", ADD},
{"sub", SUB},
{"mul", MUL},
{"div", DIV},
};

template <typename scalar_t>
__global__ void sampled_op_kernel_impl(const scalar_t* __restrict__ left,
const scalar_t* __restrict__ right,
scalar_t* __restrict__ out,
const int64_t* __restrict__ left_index,
const int64_t* __restrict__ right_index,
const FnType fn_type,
const bool has_left_index,
const bool has_right_index,
const int64_t num_feats,
const int64_t numel) {
int64_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x;

if (thread_idx >= numel)
return;

  // Each thread owns one output element: row `thread_idx / num_feats` and
  // feature `thread_idx % num_feats`. The row is remapped through the
  // optional index tensors before reading from `left`/`right`.
  int64_t i = thread_idx / num_feats;
  if (has_left_index) {
    i = left_index[i];
  }

  int64_t j = thread_idx / num_feats;
  if (has_right_index) {
    j = right_index[j];
  }

  int64_t k = thread_idx % num_feats;

scalar_t a = left[i * num_feats + k];
scalar_t b = right[j * num_feats + k];

scalar_t c;
if (fn_type == ADD) {
c = a + b;
} else if (fn_type == SUB) {
c = a - b;
} else if (fn_type == MUL) {
c = a * b;
} else if (fn_type == DIV) {
c = a / b;
}

out[thread_idx] = c;
}

at::Tensor sampled_op_kernel(const at::Tensor& left,
const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
  // The output has one row per sampled row: if only `left_index` is given,
  // this equals the (un-indexed) number of rows of `right`; if both indices
  // are given, it equals their common length; otherwise it is `left.size(0)`.
  auto dim_size = left.size(0);
  if (left_index.has_value() && !right_index.has_value()) {
    dim_size = right.size(0);
  } else if (left_index.has_value() && right_index.has_value()) {
    dim_size = left_index.value().size(0);
  }
const auto num_feats = left.size(1);
const auto numel = dim_size * num_feats;

const auto out = left.new_empty({dim_size, num_feats});

auto stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_ALL_TYPES(left.scalar_type(), "sampled_kernel_impl", [&] {
const auto left_data = left.data_ptr<scalar_t>();
const auto right_data = right.data_ptr<scalar_t>();
auto out_data = out.data_ptr<scalar_t>();

    int64_t* left_index_data = nullptr;
    if (left_index.has_value()) {
      left_index_data = left_index.value().data_ptr<int64_t>();
    }
    int64_t* right_index_data = nullptr;
    if (right_index.has_value()) {
      right_index_data = right_index.value().data_ptr<int64_t>();
    }

sampled_op_kernel_impl<scalar_t>
<<<CDIV(numel, THREADS), THREADS, 0, stream>>>(
left_data, right_data, out_data, left_index_data, right_index_data,
to_fn_type.at(fn), left_index.has_value(), right_index.has_value(),
num_feats, numel);
});
return out;
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, CUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::sampled_op"), TORCH_FN(sampled_op_kernel));
}

} // namespace ops
} // namespace pyg
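To make the kernel's flat indexing explicit: each thread handles one output element, derives its output row and feature column from `thread_idx`, and remaps the row through the optional index arrays. A serial C++ reference of the same scheme (illustrative only, shown for the `add` case):

```cpp
#include <cstdint>

// One loop iteration corresponds to one CUDA thread in the kernel above.
template <typename scalar_t>
void sampled_add_reference(const scalar_t* left, const scalar_t* right,
                           scalar_t* out, const int64_t* left_index,
                           const int64_t* right_index, int64_t num_feats,
                           int64_t numel) {
  for (int64_t t = 0; t < numel; ++t) {
    int64_t i = t / num_feats;  // output row
    int64_t j = t / num_feats;
    if (left_index != nullptr) i = left_index[i];    // sampled row of `left`
    if (right_index != nullptr) j = right_index[j];  // sampled row of `right`
    const int64_t k = t % num_feats;                 // feature column
    out[t] = left[i * num_feats + k] + right[j * num_feats + k];
  }
}
```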
63 changes: 63 additions & 0 deletions pyg_lib/csrc/ops/sampled.cpp
@@ -0,0 +1,63 @@
#include "sampled.h"

#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

namespace pyg {
namespace ops {

// Performs the operation `op` at sampled left and right indices.
PYG_API at::Tensor sampled_op(const at::Tensor& left,
const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
at::TensorArg left_arg{left, "left", 0};
at::TensorArg right_arg{right, "right", 1};
at::CheckedFrom c{"sampled_op"};

at::checkAllDefined(c, {left_arg, right_arg});
at::checkSameType(c, left_arg, right_arg);
at::checkContiguous(c, left_arg);
at::checkContiguous(c, right_arg);
at::checkDim(c, left_arg, 2);
at::checkDim(c, right_arg, 2);
at::checkSize(c, left_arg, 1, right_arg->size(1));

if (left_index.has_value()) {
at::TensorArg left_index_arg{left_index.value(), "left_index", 2};
at::checkContiguous(c, left_index_arg);
at::checkDim(c, left_index_arg, 1);
}

if (right_index.has_value()) {
at::TensorArg right_index_arg{right_index.value(), "right_index", 3};
at::checkContiguous(c, right_index_arg);
at::checkDim(c, right_index_arg, 1);
}

if (left_index.has_value() && right_index.has_value()) {
at::TensorArg left_index_arg{left_index.value(), "left_index", 2};
at::TensorArg right_index_arg{right_index.value(), "right_index", 3};
at::checkSameType(c, left_index_arg, right_index_arg);
at::checkSize(c, left_index_arg, 0, right_index_arg->size(0));
}

if (!left_index.has_value() && !right_index.has_value()) {
at::checkSize(c, left_arg, 0, right_arg->size(0));
}

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::sampled_op", "")
.typed<decltype(sampled_op)>();
return op.call(left, right, left_index, right_index, fn);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::sampled_op(Tensor left, Tensor right, Tensor? left_index, Tensor? "
"right_index, str op) -> Tensor"));
}

} // namespace ops
} // namespace pyg
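With the schema registered, the op is reachable through the C++ entry point above and, once the extension is loaded, via the dispatcher from Python as well. A C++ sketch of the both-indices case, which the argument checks above constrain to 1-D, contiguous index tensors of equal length over 2-D operands with the same number of columns:

```cpp
#include <ATen/ATen.h>
#include "pyg_lib/csrc/ops/sampled.h"

at::Tensor sampled_mul_example(const at::Tensor& left,
                               const at::Tensor& right,
                               const at::Tensor& left_index,
                               const at::Tensor& right_index) {
  // Requires: left/right are 2-D with matching column counts, and
  // left_index/right_index are 1-D int64 tensors of equal length.
  return pyg::ops::sampled_op(left, right, left_index, right_index, "mul");
}
```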
17 changes: 17 additions & 0 deletions pyg_lib/csrc/ops/sampled.h
@@ -0,0 +1,17 @@
#pragma once

#include <ATen/ATen.h>
#include "pyg_lib/csrc/macros.h"

namespace pyg {
namespace ops {

// Performs the operation `op` at sampled left and right indices.
PYG_API at::Tensor sampled_op(const at::Tensor& left,
const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn);

} // namespace ops
} // namespace pyg