2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -5,7 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264))
- Added `softmax_csr` implementation ([#264](https://github.com/pyg-team/pyg-lib/pull/264), [#282](https://github.com/pyg-team/pyg-lib/pull/282))
- Added support for edge-level sampling ([#280](https://github.com/pyg-team/pyg-lib/pull/280))
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
60 changes: 60 additions & 0 deletions pyg_lib/csrc/ops/autograd/softmax_kernel.cpp
@@ -0,0 +1,60 @@
#include "../softmax.h"

#include <torch/autograd.h>

namespace pyg {
namespace ops {

namespace {

using torch::autograd::Variable;
using torch::autograd::variable_list;

class SoftmaxCSR : public torch::autograd::Function<SoftmaxCSR> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
const Variable& src,
const at::Tensor& ptr,
const int64_t dim) {
at::AutoDispatchBelowADInplaceOrView g;

Variable out = softmax_csr(src, ptr, dim);
ctx->saved_data["dim"] = dim;
ctx->save_for_backward({src, out, ptr});

return {out};
}

static variable_list backward(torch::autograd::AutogradContext* ctx,
variable_list out_grads) {
const auto out_grad = out_grads[0];
const auto saved = ctx->get_saved_variables();
const auto src = saved[0];
const auto out = saved[1];
const auto ptr = saved[2];
const auto dim = ctx->saved_data["dim"].toInt();

auto src_grad = Variable();
if (torch::autograd::any_variable_requires_grad({src})) {
src_grad = softmax_csr_backward(out, out_grad, ptr, dim);
}

return {src_grad, Variable(), Variable()};
}
};

at::Tensor softmax_csr_autograd(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
return SoftmaxCSR::apply(src, ptr, dim)[0];
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
TORCH_FN(softmax_csr_autograd));
}

} // namespace ops
} // namespace pyg
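
Note: the Autograd kernel above wires `softmax_csr` (forward) and `softmax_csr_backward` (gradient) into a `torch::autograd::Function`, saving `src`, `out`, and `ptr` for the backward pass. As a rough orientation only, not the library's implementation, the per-group computation it stitches together can be sketched in Python as follows; the `*_ref` helper names and the generic slicing along `dim` are illustrative assumptions.

```python
import torch


def softmax_csr_ref(src: torch.Tensor, ptr: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Softmax applied independently inside each CSR group src[ptr[i]:ptr[i + 1]].
    out = torch.empty_like(src)
    for i in range(ptr.numel() - 1):
        sl = [slice(None)] * src.dim()
        sl[dim] = slice(int(ptr[i]), int(ptr[i + 1]))
        out[tuple(sl)] = src[tuple(sl)].softmax(dim=dim)
    return out


def softmax_csr_backward_ref(out: torch.Tensor, out_grad: torch.Tensor,
                             ptr: torch.Tensor, dim: int = 0) -> torch.Tensor:
    # Standard softmax vector-Jacobian product, applied group by group:
    #   src_grad = out * (out_grad - sum(out * out_grad, dim, keepdim=True))
    src_grad = torch.empty_like(out)
    for i in range(ptr.numel() - 1):
        sl = [slice(None)] * out.dim()
        sl[dim] = slice(int(ptr[i]), int(ptr[i + 1]))
        o, g = out[tuple(sl)], out_grad[tuple(sl)]
        src_grad[tuple(sl)] = o * (g - (o * g).sum(dim=dim, keepdim=True))
    return src_grad
```
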
2 changes: 1 addition & 1 deletion pyg_lib/csrc/ops/cpu/softmax_kernel.cpp
@@ -248,7 +248,7 @@ at::Tensor softmax_csr_backward_kernel(const at::Tensor& out,
} // namespace

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_forward"),
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr"),
TORCH_FN(softmax_csr_forward_kernel));
m.impl(TORCH_SELECTIVE_NAME("pyg::softmax_csr_backward"),
TORCH_FN(softmax_csr_backward_kernel));
16 changes: 8 additions & 8 deletions pyg_lib/csrc/ops/softmax.cpp
@@ -7,20 +7,20 @@ namespace pyg {
namespace ops {

// Performs softmax operations for each group.
PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
PYG_API at::Tensor softmax_csr(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim) {
at::TensorArg src_arg{src, "src", 0};
at::TensorArg ptr_arg{ptr, "ptr", 1};
at::CheckedFrom c{"softmax_forward"};
at::CheckedFrom c{"softmax_csr"};

at::checkAllDefined(c, {src_arg, ptr_arg});
at::checkContiguous(c, src_arg);
at::checkContiguous(c, ptr_arg);

static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::softmax_csr_forward", "")
.typed<decltype(softmax_csr_forward)>();
.findSchemaOrThrow("pyg::softmax_csr", "")
.typed<decltype(softmax_csr)>();
return op.call(src, ptr, dim);
}

@@ -32,7 +32,7 @@ PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
at::TensorArg out_arg{out, "out", 0};
at::TensorArg out_grad_arg{out_grad, "out_grad", 1};
at::TensorArg ptr_arg{ptr, "ptr", 2};
at::CheckedFrom c{"softmax_backward"};
at::CheckedFrom c{"softmax_csr_backward"};

at::checkAllDefined(c, {out_arg, out_grad_arg, ptr_arg});
at::checkContiguous(c, out_arg);
@@ -47,7 +47,7 @@

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(
TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr_forward(Tensor src, Tensor ptr, "
TORCH_SELECTIVE_SCHEMA("pyg::softmax_csr(Tensor src, Tensor ptr, "
"int dim=0) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::softmax_csr_backward(Tensor out, Tensor out_grad, "
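
With the schema renamed from `pyg::softmax_csr_forward` to `pyg::softmax_csr`, the operator is reachable through the PyTorch dispatcher under the new name. A minimal sketch of calling it from Python, assuming `pyg_lib` is installed and that importing it loads the compiled extension:

```python
import torch
import pyg_lib  # noqa: F401  (importing registers the `pyg::*` operators)

src = torch.rand(8, 4)
ptr = torch.tensor([0, 3, 4, 7, 8])

# Dispatched to the CPU kernel; the Autograd kernel is layered on top when
# `src` requires gradients.
out = torch.ops.pyg.softmax_csr(src, ptr, 0)
print(out.shape)  # torch.Size([8, 4])
```
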
6 changes: 3 additions & 3 deletions pyg_lib/csrc/ops/softmax.h
@@ -7,9 +7,9 @@ namespace pyg {
namespace ops {

// Performs softmax operations for each group.
PYG_API at::Tensor softmax_csr_forward(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim = 0);
PYG_API at::Tensor softmax_csr(const at::Tensor& src,
const at::Tensor& ptr,
const int64_t dim = 0);

// Computes gradient for grouped softmax operations.
PYG_API at::Tensor softmax_csr_backward(const at::Tensor& out,
27 changes: 2 additions & 25 deletions pyg_lib/ops/__init__.py
@@ -1,4 +1,4 @@
from typing import List, Optional, Tuple, Union
from typing import List, Optional, Tuple

import torch
import torch.utils._pytree as pytree
@@ -331,29 +331,6 @@ def index_sort(
return torch.ops.pyg.index_sort(inputs, max_value)


class Softmax(torch.autograd.Function):
@staticmethod
def forward(
ctx,
src: Tensor,
ptr: Tensor,
dim: int = 0,
) -> Tensor:
out = torch.ops.pyg.softmax_csr_forward(src, ptr, dim)
ctx.save_for_backward(out, ptr)
ctx.dim = dim

return out

@staticmethod
def backward(ctx, out_grad: Tensor) -> Tuple[Union[Tensor, int]]:
out, ptr = ctx.saved_tensors
in_grad = torch.ops.pyg.softmax_csr_backward(out, out_grad, ptr,
ctx.dim)

return in_grad, None, None


def softmax_csr(
src: Tensor,
ptr: Tensor,
@@ -384,7 +361,7 @@ def softmax_csr(
[0.7792, 0.3502, 0.1638, 0.2145]])
"""
dim = dim + src.dim() if dim < 0 else dim
return Softmax.apply(src, ptr, dim)
return torch.ops.pyg.softmax_csr(src, ptr, dim)


__all__ = [
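
With the Python-level `Softmax` autograd Function removed, gradients now flow through the C++ Autograd kernel, so the wrapper calls the operator directly. A quick sanity check of the new path, sketched under the assumption of CPU float tensors (the weighted-sum loss below is only there to produce a non-trivial gradient):

```python
import torch
from pyg_lib.ops import softmax_csr

src = torch.rand(8, 4, requires_grad=True)
ptr = torch.tensor([0, 3, 4, 7, 8])

out = softmax_csr(src, ptr, dim=0)
(out * torch.rand_like(out)).sum().backward()  # runs pyg::softmax_csr_backward

print(src.grad.shape)             # torch.Size([8, 4])
print(out[:3].detach().sum(0))    # each CSR group sums to ~1 along dim 0
```
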
14 changes: 12 additions & 2 deletions test/csrc/ops/test_softmax.cpp
@@ -34,25 +34,35 @@ TEST_P(CPUTest, SoftmaxCSRForward) {
const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
const auto expected_out = softmax2D_ref_impl(src, ptr, dim);

const auto out = pyg::ops::softmax_csr_forward(src, ptr, dim);
const auto out = pyg::ops::softmax_csr(src, ptr, dim);
EXPECT_EQ(expected_out.size(0), out.size(0));
EXPECT_EQ(expected_out.size(1), out.size(1));
EXPECT_TRUE(at::allclose(expected_out, out, 1e-04, 1e-04));
}

TEST_P(CPUTest, SoftmaxCSRBackward) {
TEST_P(CPUTest, SoftmaxCSRAutogradBackward) {
const auto dim = ::testing::TestWithParam<int64_t>::GetParam();
const auto src = at::rand({8, 8});
src.set_requires_grad(true);
const auto ptr = at::tensor({0, 3, 4, 7, 8}, at::kLong);
const auto out = softmax2D_ref_impl(src, ptr, dim);
const auto out_grad = at::rand({8, 8});

// use softmax_csr_backward directly
const auto in_grad = pyg::ops::softmax_csr_backward(out, out_grad, ptr, dim);
out.backward(out_grad);
EXPECT_EQ(src.grad().size(0), in_grad.size(0));
EXPECT_EQ(src.grad().size(1), in_grad.size(1));
EXPECT_TRUE(at::allclose(src.grad(), in_grad, 1e-04, 1e-04));

// use softmax backward via autograd module
const auto src2 = src.detach().clone();
src2.set_requires_grad(true);
const auto out2 = pyg::ops::softmax_csr(src2, ptr, dim);
out2.backward(out_grad);
EXPECT_EQ(src.grad().size(0), src2.grad().size(0));
EXPECT_EQ(src.grad().size(1), src2.grad().size(1));
EXPECT_TRUE(at::allclose(src.grad(), src2.grad(), 1e-04, 1e-04));
}

INSTANTIATE_TEST_SUITE_P(OpsTest,