Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.4.0] - 2023-MM-DD
### Added
- Added support for `bfloat16` data type in `segment_matmul` and `grouped_matmul` (CPU only) ([#272](https://github.com/pyg-team/pyg-lib/pull/272))
### Changed
- Added `--biased` parameter to run benchmarks for biased sampling ([#267](https://github.com/pyg-team/pyg-lib/pull/267))
- Improved speed of biased sampling ([#270](https://github.com/pyg-team/pyg-lib/pull/270))
Expand Down
6 changes: 4 additions & 2 deletions pyg_lib/csrc/ops/cpu/matmul_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ std::vector<at::Tensor> grouped_matmul_kernel(const at::TensorList input,
{input_contig[i].size(0), other_contig[i].size(-1)}));
}

AT_DISPATCH_ALL_TYPES(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
input_contig.front().scalar_type(), "grouped_matmul_kernel", [&] {
if (mkl_path_available<scalar_t>() &&
mkl_path_possible(input_contig, other_contig)) {
Expand Down Expand Up @@ -413,7 +414,8 @@ at::Tensor segment_matmul_kernel(const at::Tensor& input,
const auto other_contig = other.contiguous();
auto out = input_contig.new_empty({input.size(0), other.size(-1)});

AT_DISPATCH_ALL_TYPES(
AT_DISPATCH_ALL_TYPES_AND2(
at::ScalarType::Half, at::ScalarType::BFloat16,
input_contig.scalar_type(), "segment_matmul_kernel", [&] {
const auto n = other_contig.size(-1);
const auto k = input_contig.size(-1);
Expand Down
11 changes: 6 additions & 5 deletions pyg_lib/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@ def onlyTriton(func: Callable) -> Callable:


def withCUDA(func: Callable) -> Callable:
def wrapper(*args, **kwargs):
func(*args, device=torch.device('cpu'), **kwargs)
if torch.cuda.is_available():
func(*args, device=torch.device('cuda:0'), **kwargs)
import pytest

return wrapper
devices = [torch.device('cpu')]
if torch.cuda.is_available():
devices.append(torch.device('cuda:0'))

return pytest.mark.parametrize('device', devices)(func)


def withDataset(group: str, name: str) -> Callable:
Expand Down
21 changes: 16 additions & 5 deletions test/ops/test_matmul.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

import pytest
import torch

import pyg_lib
Expand All @@ -11,11 +12,17 @@


@withCUDA
def test_segment_matmul_autograd(device):
inputs = torch.randn((8, 16), requires_grad=True, device=device)
@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
def test_segment_matmul_autograd(dtype, device):
if device.type == 'cuda' and dtype == torch.bfloat16:
pytest.skip('CUDA does not support bfloat16')

inputs = torch.randn((8, 16), requires_grad=True, device=device,
dtype=dtype)
ptr = torch.tensor([0, 5, 8]).to(torch.device(device))
other = torch.randn((2, 16, 32), requires_grad=True, device=device)
bias = torch.randn((2, 32), requires_grad=True, device=device)
other = torch.randn((2, 16, 32), requires_grad=True, device=device,
dtype=dtype)
bias = torch.randn((2, 32), requires_grad=True, device=device, dtype=dtype)
out = pyg_lib.ops.segment_matmul(inputs, ptr, other, bias)
assert out.size() == (8, 32)

Expand All @@ -31,7 +38,11 @@ def test_segment_matmul_autograd(device):


@withCUDA
def test_grouped_matmul_autograd(device):
@pytest.mark.parametrize('dtype', [torch.float, torch.bfloat16])
def test_grouped_matmul_autograd(dtype, device):
if device.type == 'cuda' and dtype == torch.bfloat16:
pytest.skip('CUDA does not support bfloat16')

inputs = [
torch.randn(5, 16, device=device, requires_grad=True),
torch.randn(6, 9, device=device, requires_grad=True),
Expand Down