Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.2.0] - 2023-MM-DD
### Added
- Added `index_sort` implementation ([#181](https://github.com/pyg-team/pyg-lib/pull/181))
- Added `triton>=2.0` support ([#171](https://github.com/pyg-team/pyg-lib/pull/171))
- Added `bias` term to `grouped_matmul` and `segment_matmul` ([#161](https://github.com/pyg-team/pyg-lib/pull/161))
- Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159), [#160](https://github.com/pyg-team/pyg-lib/pull/160))
Expand Down
8 changes: 8 additions & 0 deletions LICENSE.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
The copyright for the whole pyg-lib project (excluding the radix_sort.h file
located in the pyg_lib/csrc/ops/cpu directory) is held by Copyright (c) 2022 PyG
Team <[email protected]>, and the project is provided under the MIT license
(definition located in the LICENSE/LICENSE_pyg_lib file).

The copyright for the radix_sort.h file located in the pyg_lib/csrc/ops/cpu
directory is held by Copyright (c) Meta Platforms, Inc. and affiliates, and the
file is provided under the BSD license (definition located in the
LICENSE/LICENSE_radix_sort file).
File renamed without changes.
30 changes: 30 additions & 0 deletions LICENSE/LICENSE_radix_sort
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
BSD License

For FBGEMM software

Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.

* Neither the name Facebook nor the names of its contributors may be used to
endorse or promote products derived from this software without specific
prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
85 changes: 85 additions & 0 deletions pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#include <tuple>

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <torch/library.h>

#include "radix_sort.h"

namespace pyg {
namespace ops {

namespace {

/// Copies `size` elements from `src` to `dst`.
///
/// The bulk of the range is copied in manually unrolled groups of
/// `unfold_step` elements (a loop marked for SIMD vectorization); the
/// remaining `size % unfold_step` elements are copied one at a time.
///
/// @param dst  Destination buffer; must hold at least `size` elements.
/// @param src  Source buffer; must hold at least `size` elements.
/// @param size Number of elements to copy. Non-positive values copy nothing.
template <typename scalar_t>
void vectorized_copy(scalar_t* dst, const scalar_t* src, int64_t size) {
  constexpr int64_t unfold_step = 4;
  // Restrict the unrolled loop to a multiple of `unfold_step` elements so that
  // the `index + 3` access never reaches past the end of either buffer.
  const int64_t unrolled_size = size - size % unfold_step;
#pragma omp simd
  for (int64_t index = 0; index < unrolled_size; index += unfold_step) {
    dst[index] = src[index];
    dst[index + 1] = src[index + 1];
    dst[index + 2] = src[index + 2];
    dst[index + 3] = src[index + 3];
  }
  // Tail: the fewer-than-`unfold_step` elements left over.
  for (int64_t index = unrolled_size; index < size; ++index) {
    dst[index] = src[index];
  }
}

/// CPU implementation of `pyg::index_sort`.
///
/// Sorts a contiguous, 1-dimensional tensor of integral values and returns
/// `(sorted_values, permutation)`. Inputs with more than
/// `at::internal::GRAIN_SIZE` elements use the parallel radix sort from
/// `radix_sort.h` when it is available; all other inputs fall back to
/// `at::sort`. On the radix path, `permutation` is an int64 tensor holding the
/// original position of each sorted value.
///
/// @param input 1-dimensional, contiguous tensor with an integral (non-bool)
///              dtype.
/// @param max   Optional upper bound on the values of `input`; it determines
///              the number of radix passes. When omitted, `input.max()` is
///              computed. NOTE(review): the radix path presumably requires
///              non-negative values that are all <= `max` — confirm with the
///              operator's documented contract.
std::tuple<at::Tensor, at::Tensor> index_sort_kernel(
    const at::Tensor& input,
    const at::optional<int64_t> max) {
  TORCH_CHECK(input.is_contiguous(), "Input should be contiguous.");
  TORCH_CHECK(input.dim() == 1, "Input should be 1-dimensional.");
  // Validate the dtype up-front so that both code paths reject non-integral
  // input with the same error message.
  TORCH_CHECK(at::isIntegralType(input.scalar_type(), /*includeBool=*/false),
              "Input should contain integral values.");

  if (input.numel() <= at::internal::GRAIN_SIZE ||
      !is_radix_sort_available()) {
    return at::sort(input);
  }

  const auto elements = input.numel();
  // Do not use `optional::value_or` here: it evaluates its argument eagerly,
  // which would run a full `input.max()` reduction even when `max` is given.
  const int64_t maximum =
      max.has_value() ? *max : input.max().item<int64_t>();
  auto out_vals = input.detach().clone();
  auto out_indices = at::arange(
      0, elements, at::TensorOptions().device(at::kCPU).dtype(at::kLong));

  AT_DISPATCH_INTEGRAL_TYPES(out_vals.scalar_type(), "index_sort_kernel", [&] {
    scalar_t* vals = out_vals.data_ptr<scalar_t>();
    int64_t* indices = out_indices.data_ptr<int64_t>();
    // Scratch buffers that the radix sort alternates with `vals`/`indices`.
    std::vector<scalar_t> tmp_vals(elements);
    std::vector<int64_t> tmp_indices(elements);
    scalar_t* sorted_vals = nullptr;
    int64_t* sorted_indices = nullptr;
    std::tie(sorted_vals, sorted_indices) =
        radix_sort_parallel(vals, indices, tmp_vals.data(),
                            tmp_indices.data(), elements, maximum);

    // Depending on the number of passes, the result lands either in the
    // output tensors or in the scratch buffers; copy it back in the latter
    // case.
    if (sorted_vals != vals) {
      const int num_threads = at::get_num_threads();
      at::parallel_for(
          0, elements, at::internal::GRAIN_SIZE / num_threads,
          [&](int64_t begin, int64_t end) {
            const auto job_size = end - begin;
            vectorized_copy(vals + begin, sorted_vals + begin, job_size);
            vectorized_copy(indices + begin, sorted_indices + begin,
                            job_size);
          });
    }
  });
  return std::tuple<at::Tensor, at::Tensor>(out_vals, out_indices);
}

} // namespace

// Registers `index_sort_kernel` as the CPU kernel of the `pyg::index_sort`
// operator. The operator's schema itself is declared separately, through
// `m.def` inside a `TORCH_LIBRARY_FRAGMENT(pyg, m)` block.
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::index_sort"), TORCH_FN(index_sort_kernel));
}

} // namespace ops
} // namespace pyg
202 changes: 202 additions & 0 deletions pyg_lib/csrc/ops/cpu/radix_sort.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* Copyright (c) Intel Corporation.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE_radix_sort file in the LICENSE directory of this source tree.
*/

#pragma once

#include <cstdint>
#include <limits>
#include <memory>
#include <type_traits>
#include <utility>

#if !defined(_OPENMP)

namespace pyg {
namespace ops {

/// Tells callers whether `radix_sort_parallel` may be used. This build has no
/// OpenMP, so the parallel radix sort is compiled out and the answer is `false`.
inline bool is_radix_sort_available() { return false; }

// Placeholder for builds without OpenMP: the parallel radix sort is not
// compiled in, so every call fails via `TORCH_CHECK(false, ...)`. Callers
// should check `is_radix_sort_available()` before calling it.
// NOTE(review): `TORCH_CHECK` is not declared by any header included in this
// file; it relies on the including translation unit having pulled in c10/ATen
// first.
template <typename K, typename V>
std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf,
V* inp_value_buf,
K* tmp_key_buf,
V* tmp_value_buf,
int64_t elements_count,
int64_t max_value) {
// Always throws; no value is ever returned.
TORCH_CHECK(
false,
"radix_sort_parallel: pyg-lib is not compiled with OpenMP support");
}

} // namespace ops
} // namespace pyg

#else

#include <c10/util/llvmMathExtras.h>
#include <omp.h>

namespace pyg {
namespace ops {

namespace {

// Copied from fbgemm implementation available here:
// https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/src/cpu_utils.cpp
//
// `radix_sort_parallel` is only available when pyg-lib is compiled with
// OpenMP, since the algorithm requires synchronization (barriers) between
// OpenMP threads, which cannot be mapped cleanly onto `at::parallel_for` at
// the current stage.

// Buckets in each per-thread histogram: one for every value that the 8-bit
// digit examined by a single radix pass can take.
constexpr int RDX_HIST_SIZE = 1 << 8;

// One pass of a least-significant-digit radix sort: scatters the
// `elements_count` (key, value) pairs from `input_*` into `output_*`, ordered
// by byte number `pass` of the key (bits [8 * pass, 8 * pass + 8)).
//
// Must be executed by every thread of an enclosing `#pragma omp parallel`
// region: it uses `omp for` worksharing and `omp barrier`. Each thread owns
// the RDX_HIST_SIZE-sized slice of both histograms at offset
// `RDX_HIST_SIZE * omp_get_thread_num()`.
//
// input_keys / input_values:   pairs to distribute.
// output_keys / output_values: destination buffers of the same length.
// histogram:    scratch of at least RDX_HIST_SIZE * nthreads ints that holds
//               per-thread bucket counts.
// histogram_ps: scratch of at least RDX_HIST_SIZE * nthreads + 1 ints that
//               receives the exclusive prefix sums used as scatter offsets;
//               the final slot receives the total element count.
// pass:         index of the key byte to sort by (0 = least significant).
//
// NOTE(review): buckets are taken from the raw key bytes, so negative signed
// keys would presumably be ordered after non-negative ones; callers appear to
// assume non-negative keys — confirm.
// NOTE(review): counts and positions are `int`, so `elements_count` must fit
// in an int.
// NOTE(review): `TORCH_CHECK` and `RDX_HIST_SIZE`'s users rely on c10 being
// included by the including translation unit.
template <typename K, typename V>
void radix_sort_kernel(K* input_keys,
V* input_values,
K* output_keys,
V* output_values,
int elements_count,
int* histogram,
int* histogram_ps,
int pass) {
int tid = omp_get_thread_num();
int nthreads = omp_get_num_threads();
// Largest multiple of 4 <= elements_count: the range the unrolled loops cover.
int elements_count_4 = elements_count / 4 * 4;

// This thread's private slices of the shared histogram arrays.
int* local_histogram = &histogram[RDX_HIST_SIZE * tid];
int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid];

// Step 1: compute histogram
for (int i = 0; i < RDX_HIST_SIZE; i++) {
local_histogram[i] = 0;
}

// Each thread counts the digits of its own static chunk of the unrolled range.
#pragma omp for schedule(static)
for (int64_t i = 0; i < elements_count_4; i += 4) {
K key_1 = input_keys[i];
K key_2 = input_keys[i + 1];
K key_3 = input_keys[i + 2];
K key_4 = input_keys[i + 3];

local_histogram[(key_1 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_2 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_3 >> (pass * 8)) & 0xFF]++;
local_histogram[(key_4 >> (pass * 8)) & 0xFF]++;
}
// The (< 4) leftover elements are counted by the last thread. Under
// `schedule(static)` chunks are handed out contiguously in thread-number
// order, so the last thread's chunk ends the unrolled range and the leftovers
// stay after every other element (this keeps the pass stable).
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; ++i) {
K key = input_keys[i];
local_histogram[(key >> (pass * 8)) & 0xFF]++;
}
}
#pragma omp barrier
// Step 2: prefix sum
// Thread 0 computes exclusive prefix sums in bin-major, thread-minor order,
// so within each bin thread t's elements are placed before thread t+1's.
if (tid == 0) {
int sum = 0, prev_sum = 0;
for (int bins = 0; bins < RDX_HIST_SIZE; bins++) {
for (int t = 0; t < nthreads; t++) {
sum += histogram[t * RDX_HIST_SIZE + bins];
histogram_ps[t * RDX_HIST_SIZE + bins] = prev_sum;
prev_sum = sum;
}
}
histogram_ps[RDX_HIST_SIZE * nthreads] = prev_sum;
// Sanity check: every element was counted exactly once.
// NOTE(review): an exception thrown inside an OpenMP parallel region cannot
// propagate out of it and presumably terminates the process — confirm.
TORCH_CHECK(prev_sum == elements_count);
}
#pragma omp barrier

// Step 3: scatter
// Same loop bounds and static schedule as Step 1, so each thread revisits
// exactly the elements it counted and writes each pair to the next free slot
// of its bucket.
#pragma omp for schedule(static)
for (int64_t i = 0; i < elements_count_4; i += 4) {
K key_1 = input_keys[i];
K key_2 = input_keys[i + 1];
K key_3 = input_keys[i + 2];
K key_4 = input_keys[i + 3];

int bin_1 = (key_1 >> (pass * 8)) & 0xFF;
int bin_2 = (key_2 >> (pass * 8)) & 0xFF;
int bin_3 = (key_3 >> (pass * 8)) & 0xFF;
int bin_4 = (key_4 >> (pass * 8)) & 0xFF;

int pos;
pos = local_histogram_ps[bin_1]++;
output_keys[pos] = key_1;
output_values[pos] = input_values[i];
pos = local_histogram_ps[bin_2]++;
output_keys[pos] = key_2;
output_values[pos] = input_values[i + 1];
pos = local_histogram_ps[bin_3]++;
output_keys[pos] = key_3;
output_values[pos] = input_values[i + 2];
pos = local_histogram_ps[bin_4]++;
output_keys[pos] = key_4;
output_values[pos] = input_values[i + 3];
}
// Leftovers are scattered by the last thread, mirroring Step 1. No trailing
// barrier here: the caller synchronizes before starting the next pass.
if (tid == (nthreads - 1)) {
for (int64_t i = elements_count_4; i < elements_count; ++i) {
K key = input_keys[i];
int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++;
output_keys[pos] = key;
output_values[pos] = input_values[i];
}
}
}

} // namespace

/// Tells callers whether `radix_sort_parallel` may be used. This build has
/// OpenMP, so the parallel radix sort is compiled in and the answer is `true`.
inline bool is_radix_sort_available() { return true; }

template <typename K, typename V>
std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf,
V* inp_value_buf,
K* tmp_key_buf,
V* tmp_value_buf,
int64_t elements_count,
int64_t max_value) {
int maxthreads = omp_get_max_threads();
std::unique_ptr<int[]> histogram_tmp(new int[RDX_HIST_SIZE * maxthreads]);
std::unique_ptr<int[]> histogram_ps_tmp(
new int[RDX_HIST_SIZE * maxthreads + 1]);
int* histogram = histogram_tmp.get();
int* histogram_ps = histogram_ps_tmp.get();
if (max_value == 0) {
return std::make_pair(inp_key_buf, inp_value_buf);
}

// __builtin_clz is not portable
int num_bits =
sizeof(K) * 8 - c10::llvm::countLeadingZeros(
static_cast<std::make_unsigned_t<K> >(max_value));
unsigned int num_passes = (num_bits + 7) / 8;

#pragma omp parallel
{
K* input_keys = inp_key_buf;
V* input_values = inp_value_buf;
K* output_keys = tmp_key_buf;
V* output_values = tmp_value_buf;

for (unsigned int pass = 0; pass < num_passes; pass++) {
radix_sort_kernel(input_keys, input_values, output_keys, output_values,
elements_count, histogram, histogram_ps, pass);

std::swap(input_keys, output_keys);
std::swap(input_values, output_values);
#pragma omp barrier
}
}
return (num_passes % 2 == 0 ? std::make_pair(inp_key_buf, inp_value_buf)
: std::make_pair(tmp_key_buf, tmp_value_buf));
}

} // namespace ops
} // namespace pyg

#endif
13 changes: 13 additions & 0 deletions pyg_lib/csrc/ops/index_sort.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <torch/library.h>

namespace pyg {
namespace ops {

// Declares the schema of the `pyg::index_sort` operator. It takes a tensor and
// an optional integer upper bound `max` (default `None`) and returns a pair of
// tensors. Device-specific kernels are attached separately through
// `TORCH_LIBRARY_IMPL(pyg, <backend>, m)` blocks.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::index_sort(Tensor indices, int? max = None) -> (Tensor, Tensor)"));
}

} // namespace ops
} // namespace pyg
Loading