From 3a4b3fc424fcd75ccaa5b763bfed47ff7f4330e0 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 18 Jan 2023 09:26:05 +0100 Subject: [PATCH 1/9] Update LICENSEs --- LICENSE.txt | 8 ++++++++ LICENSE => LICENSE/LICENSE_pyg_lib | 0 LICENSE/LICENSE_radix_sort | 30 ++++++++++++++++++++++++++++++ 3 files changed, 38 insertions(+) create mode 100644 LICENSE.txt rename LICENSE => LICENSE/LICENSE_pyg_lib (100%) create mode 100644 LICENSE/LICENSE_radix_sort diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 000000000..dee647842 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,8 @@ +Copyright for the whole pyg-lib project (excluding radix_sort.h file located in +pyg_lib/csrc/cpu directory) is held by Copyright (c) 2022 PyG Team , +and is provided under the MIT license (definition located in +LICENSE/LICENSE_pyg_lib file). + +Copyright for radix_sort.h file located in pyg_lib/csrc/cpu directory is held by +Copyright (c) Meta Platforms, Inc. and affiliates, and is provided under the BSD +license (definition located in LICENSE/LICENSE_radix_sort file). \ No newline at end of file diff --git a/LICENSE b/LICENSE/LICENSE_pyg_lib similarity index 100% rename from LICENSE rename to LICENSE/LICENSE_pyg_lib diff --git a/LICENSE/LICENSE_radix_sort b/LICENSE/LICENSE_radix_sort new file mode 100644 index 000000000..0dd6cae7d --- /dev/null +++ b/LICENSE/LICENSE_radix_sort @@ -0,0 +1,30 @@ +BSD License + +For FBGEMM software + +Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
\ No newline at end of file From 06f262801ac4bccf34ce59923b8ab9d7ca7b004d Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 18 Jan 2023 09:26:59 +0100 Subject: [PATCH 2/9] Add `index_sort` definition --- pyg_lib/ops/__init__.py | 32 ++++++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py index 20f2af2fa..43811a163 100644 --- a/pyg_lib/ops/__init__.py +++ b/pyg_lib/ops/__init__.py @@ -1,7 +1,7 @@ -from typing import List, Optional +from typing import List, Optional, Tuple import torch -from torch import Tensor +from torch import Tensor, LongTensor from .scatter_reduce import fused_scatter_reduce @@ -227,6 +227,33 @@ def sampled_div( return out + +def index_sort(input: LongTensor, + max: Optional[int] = None) -> Tuple[LongTensor, LongTensor]: + r"""Sorts the elements of the :obj:`input` tensor in ascending order by + value. It is expected that :obj:`input` tensor is 1-dimensional and + contains only positive, integer values. If :obj:`max` is given, it can be + used by the underlying algorithm for better performance. + + .. note:: + + This operation works only for tensors associated with the CPU device. + + Args: + input (torch.LongTensor): 1-dimensional tensor with positive integer + values. + max (int, optional): A maximum value stored inside :obj:`input`. + This value can be an estimation, but needs to be greater than + or equal to the real maximum. (default: :obj:`None`) + + Returns: + Tuple[torch.LongTensor, torch.LongTensor]: + A tuple containing sorted values and indices of the elements in the + original :obj:`input` tensor. + """ + out = torch.ops.pyg.index_sort(input, max) + return out + + __all__ = [ 'grouped_matmul', 'segment_matmul', @@ -234,5 +261,6 @@ def sampled_div( 'sampled_sub', 'sampled_mul', 'sampled_div', + 'index_sort', 'fused_scatter_reduce', ] From 89694bad8e5de54b156ffaa1367c835a872d5f30 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 18 Jan 2023 09:27:35 +0100 Subject: [PATCH 3/9] Add `index_sort` implementation --- pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp | 83 +++++++++ pyg_lib/csrc/ops/cpu/radix_sort.h | 202 +++++++++++++++++++++ pyg_lib/csrc/ops/index_sort.cpp | 31 ++++ pyg_lib/csrc/ops/index_sort.h | 16 ++ 4 files changed, 332 insertions(+) create mode 100644 pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp create mode 100644 pyg_lib/csrc/ops/cpu/radix_sort.h create mode 100644 pyg_lib/csrc/ops/index_sort.cpp create mode 100644 pyg_lib/csrc/ops/index_sort.h diff --git a/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp new file mode 100644 index 000000000..1c9d0c043 --- /dev/null +++ b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp @@ -0,0 +1,83 @@ +#include + +#include +#include +#include + +#include "radix_sort.h" + +namespace pyg { +namespace ops { + +namespace { + +template <typename scalar_t> +void vectorized_copy(scalar_t* dst, const scalar_t* src, int64_t size) { + constexpr int64_t unfold_step = 4; + int64_t index = 0; + if (size >= unfold_step) { +#pragma omp simd + for (index = 0; index <= size - unfold_step; index += unfold_step) { + dst[index] = src[index]; + dst[index + 1] = src[index + 1]; + dst[index + 2] = src[index + 2]; + dst[index + 3] = src[index + 3]; + } + } + for (; index < size; ++index) { + dst[index] = src[index]; + } +} + +std::tuple<at::Tensor, at::Tensor> index_sort_kernel( + const at::Tensor& input, + const at::optional<int64_t> max) { + if (input.numel() > at::internal::GRAIN_SIZE && is_radix_sort_available()) { + const auto elements = input.numel(); + const 
auto maximum = max.value_or(at::max(input).item<int64_t>()); + auto out_vals = at::detach(input).clone(); + auto out_indices = at::arange( + 0, elements, at::TensorOptions().device(at::kCPU).dtype(at::kLong)); + + AT_DISPATCH_INTEGRAL_TYPES( + out_vals.scalar_type(), "index_sort_kernel", [&] { + scalar_t* vals = out_vals.data_ptr<scalar_t>(); + int64_t* indices = out_indices.data_ptr<int64_t>(); + std::vector<scalar_t> tmp_vals(elements); + std::vector<int64_t> tmp_indices(elements); + scalar_t* sorted_vals = nullptr; + int64_t* sorted_indices = nullptr; + std::tie(sorted_vals, sorted_indices) = + radix_sort_parallel(vals, indices, tmp_vals.data(), + tmp_indices.data(), elements, maximum); + + const bool sorted_in_place = vals == sorted_vals; + if (!sorted_in_place) { + const int num_threads = at::get_num_threads(); + const auto common_size = out_vals.numel(); + at::parallel_for( + 0, common_size, at::internal::GRAIN_SIZE / num_threads, + [&](int64_t begin, int64_t end) { + const auto job_size = end - begin; + vectorized_copy(vals + begin, sorted_vals + begin, job_size); + vectorized_copy(indices + begin, sorted_indices + begin, + job_size); + }); + } + }); + return std::tuple<at::Tensor, at::Tensor>(out_vals, out_indices); + } else { + TORCH_CHECK(at::isIntegralType(input.scalar_type(), /*includeBool=*/false), + "Input should contain integral values."); + return at::sort(input); + } +} + +} // namespace + +TORCH_LIBRARY_IMPL(pyg, CPU, m) { + m.impl(TORCH_SELECTIVE_NAME("pyg::index_sort"), TORCH_FN(index_sort_kernel)); +} + +} // namespace ops +} // namespace pyg diff --git a/pyg_lib/csrc/ops/cpu/radix_sort.h b/pyg_lib/csrc/ops/cpu/radix_sort.h new file mode 100644 index 000000000..7666bc881 --- /dev/null +++ b/pyg_lib/csrc/ops/cpu/radix_sort.h @@ -0,0 +1,202 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) Intel Corporation. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE_radix_sort file in the LICENSE directory of this source tree. + */ + +#pragma once + +#include +#include + +#if !defined(_OPENMP) + +namespace pyg { +namespace ops { + +bool inline is_radix_sort_available() { + return false; +} + +template <typename K, typename V> +std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf, + V* inp_value_buf, + K* tmp_key_buf, + V* tmp_value_buf, + int64_t elements_count, + int64_t max_value) { + TORCH_CHECK( + false, + "radix_sort_parallel: pyg-lib is not compiled with OpenMP support"); +} + +} // namespace ops +} // namespace pyg + +#else + +#include +#include + +namespace pyg { +namespace ops { + +namespace { + +// Copied from fbgemm implementation available here: +// https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/src/cpu_utils.cpp +// +// `radix_sort_parallel` is only available when pyg-lib is compiled with +// OpenMP, since the algorithm requires sync between omp threads, which cannot +// be perfectly mapped to `at::parallel_for` at the current stage. 
+ +// histogram size per thread +constexpr int RDX_HIST_SIZE = 256; + +template <typename K, typename V> +void radix_sort_kernel(K* input_keys, + V* input_values, + K* output_keys, + V* output_values, + int elements_count, + int* histogram, + int* histogram_ps, + int pass) { + int tid = omp_get_thread_num(); + int nthreads = omp_get_num_threads(); + int elements_count_4 = elements_count / 4 * 4; + + int* local_histogram = &histogram[RDX_HIST_SIZE * tid]; + int* local_histogram_ps = &histogram_ps[RDX_HIST_SIZE * tid]; + + // Step 1: compute histogram + for (int i = 0; i < RDX_HIST_SIZE; i++) { + local_histogram[i] = 0; + } + +#pragma omp for schedule(static) + for (int64_t i = 0; i < elements_count_4; i += 4) { + K key_1 = input_keys[i]; + K key_2 = input_keys[i + 1]; + K key_3 = input_keys[i + 2]; + K key_4 = input_keys[i + 3]; + + local_histogram[(key_1 >> (pass * 8)) & 0xFF]++; + local_histogram[(key_2 >> (pass * 8)) & 0xFF]++; + local_histogram[(key_3 >> (pass * 8)) & 0xFF]++; + local_histogram[(key_4 >> (pass * 8)) & 0xFF]++; + } + if (tid == (nthreads - 1)) { + for (int64_t i = elements_count_4; i < elements_count; ++i) { + K key = input_keys[i]; + local_histogram[(key >> (pass * 8)) & 0xFF]++; + } + } +#pragma omp barrier + // Step 2: prefix sum + if (tid == 0) { + int sum = 0, prev_sum = 0; + for (int bins = 0; bins < RDX_HIST_SIZE; bins++) { + for (int t = 0; t < nthreads; t++) { + sum += histogram[t * RDX_HIST_SIZE + bins]; + histogram_ps[t * RDX_HIST_SIZE + bins] = prev_sum; + prev_sum = sum; + } + } + histogram_ps[RDX_HIST_SIZE * nthreads] = prev_sum; + TORCH_CHECK(prev_sum == elements_count); + } +#pragma omp barrier + + // Step 3: scatter +#pragma omp for schedule(static) + for (int64_t i = 0; i < elements_count_4; i += 4) { + K key_1 = input_keys[i]; + K key_2 = input_keys[i + 1]; + K key_3 = input_keys[i + 2]; + K key_4 = input_keys[i + 3]; + + int bin_1 = (key_1 >> (pass * 8)) & 0xFF; + int bin_2 = (key_2 >> (pass * 8)) & 0xFF; + int bin_3 = (key_3 >> (pass * 8)) & 0xFF; + int bin_4 = (key_4 >> (pass * 8)) & 0xFF; + + int pos; + pos = local_histogram_ps[bin_1]++; + output_keys[pos] = key_1; + output_values[pos] = input_values[i]; + pos = local_histogram_ps[bin_2]++; + output_keys[pos] = key_2; + output_values[pos] = input_values[i + 1]; + pos = local_histogram_ps[bin_3]++; + output_keys[pos] = key_3; + output_values[pos] = input_values[i + 2]; + pos = local_histogram_ps[bin_4]++; + output_keys[pos] = key_4; + output_values[pos] = input_values[i + 3]; + } + if (tid == (nthreads - 1)) { + for (int64_t i = elements_count_4; i < elements_count; ++i) { + K key = input_keys[i]; + int pos = local_histogram_ps[(key >> (pass * 8)) & 0xFF]++; + output_keys[pos] = key; + output_values[pos] = input_values[i]; + } + } +} + +} // namespace + +bool inline is_radix_sort_available() { + return true; +} + +template <typename K, typename V> +std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf, + V* inp_value_buf, + K* tmp_key_buf, + V* tmp_value_buf, + int64_t elements_count, + int64_t max_value) { + int maxthreads = omp_get_max_threads(); + std::unique_ptr<int[]> histogram_tmp(new int[RDX_HIST_SIZE * maxthreads]); + std::unique_ptr<int[]> histogram_ps_tmp( + new int[RDX_HIST_SIZE * maxthreads + 1]); + int* histogram = histogram_tmp.get(); + int* histogram_ps = histogram_ps_tmp.get(); + if (max_value == 0) { + return std::make_pair(inp_key_buf, inp_value_buf); + } + + // __builtin_clz is not portable + int num_bits = + sizeof(K) * 8 - c10::llvm::countLeadingZeros( + static_cast<std::make_unsigned_t<K>>(max_value)); + unsigned int num_passes = (num_bits + 7) / 8; + +#pragma omp 
parallel + { + K* input_keys = inp_key_buf; + V* input_values = inp_value_buf; + K* output_keys = tmp_key_buf; + V* output_values = tmp_value_buf; + + for (unsigned int pass = 0; pass < num_passes; pass++) { + radix_sort_kernel(input_keys, input_values, output_keys, output_values, + elements_count, histogram, histogram_ps, pass); + + std::swap(input_keys, output_keys); + std::swap(input_values, output_values); +#pragma omp barrier + } + } + return (num_passes % 2 == 0 ? std::make_pair(inp_key_buf, inp_value_buf) + : std::make_pair(tmp_key_buf, tmp_value_buf)); +} + +} // namespace ops +} // namespace pyg + +#endif \ No newline at end of file diff --git a/pyg_lib/csrc/ops/index_sort.cpp b/pyg_lib/csrc/ops/index_sort.cpp new file mode 100644 index 000000000..a858ce14a --- /dev/null +++ b/pyg_lib/csrc/ops/index_sort.cpp @@ -0,0 +1,31 @@ +#include "index_sort.h" + +#include +#include + +namespace pyg { +namespace ops { + +PYG_API std::tuple<at::Tensor, at::Tensor> index_sort( + const at::Tensor& input, + const at::optional<int64_t> max) { + at::TensorArg input_arg{input, "input", 0}; + at::CheckedFrom c{"index_sort"}; + + at::checkAllDefined(c, {input_arg}); + at::checkContiguous(c, input_arg); + at::checkDim(c, input_arg, 1); + + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::index_sort", "") + .typed<decltype(index_sort)>(); + return op.call(input, max); +} + +TORCH_LIBRARY_FRAGMENT(pyg, m) { + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::index_sort(Tensor indices, int? max = None) -> (Tensor, Tensor)")); +} + +} // namespace ops +} // namespace pyg diff --git a/pyg_lib/csrc/ops/index_sort.h b/pyg_lib/csrc/ops/index_sort.h new file mode 100644 index 000000000..c977896a1 --- /dev/null +++ b/pyg_lib/csrc/ops/index_sort.h @@ -0,0 +1,16 @@ +#pragma once + +#include +#include +#include +#include "pyg_lib/csrc/macros.h" + +namespace pyg { +namespace ops { + +PYG_API std::tuple<at::Tensor, at::Tensor> index_sort( + const at::Tensor& input, + const at::optional<int64_t> max); + +} // namespace ops +} // namespace pyg From 8eaca80361d25043202e78c619fe4d14b8fce80a Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 18 Jan 2023 12:00:56 +0100 Subject: [PATCH 4/9] Add test for `index_sort` --- test/ops/test_index_sort.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 test/ops/test_index_sort.py diff --git a/test/ops/test_index_sort.py b/test/ops/test_index_sort.py new file mode 100644 index 000000000..34109ab6d --- /dev/null +++ b/test/ops/test_index_sort.py @@ -0,0 +1,21 @@ +import torch + +import pyg_lib + +torch.manual_seed(1234) + + +def test_index_sort(): + input = torch.randint(low=0, high=1024, size=(1000000, )) + ref_sorted_input, ref_indices = torch.sort(input, stable=True) + sorted_input, indices = pyg_lib.ops.index_sort(input) + assert torch.all(ref_sorted_input == sorted_input) + assert torch.all(ref_indices == indices) + + +def test_index_sort_negative(): + input = torch.randint(low=0, high=1024, size=(16, 32)) + # this should fail, as we do not support ndim > 1 + # check in pyg_lib/csrc/ops/index_sort.cpp is not performed + # TODO: fix this + sorted_input, indices = pyg_lib.ops.index_sort(input) From 8952471e33ed0697ceb418b9678617a09293e381 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 18 Jan 2023 12:02:51 +0100 Subject: [PATCH 5/9] Update CHANGELOG.md --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c1beb18..9c47350f5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ The format is based on [Keep a 
Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.2.0] - 2023-MM-DD ### Added +- Added `index_sort` implementation ([#181](https://github.com/pyg-team/pyg-lib/pull/181)) - Added `triton>=2.0` support ([#171](https://github.com/pyg-team/pyg-lib/pull/171)) - Added `bias` term to `grouped_matmul` and `segment_matmul` ([#161](https://github.com/pyg-team/pyg-lib/pull/161)) - Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159), [#160](https://github.com/pyg-team/pyg-lib/pull/160)) From 20360d451ea4d02f4b56081ed891b48394790eec Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Mon, 23 Jan 2023 10:21:58 +0100 Subject: [PATCH 6/9] Add support for cuda device --- pyg_lib/ops/__init__.py | 8 ++++++-- test/ops/test_index_sort.py | 15 +++++++++++---- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py index 43811a163..f6ac029c4 100644 --- a/pyg_lib/ops/__init__.py +++ b/pyg_lib/ops/__init__.py @@ -236,7 +236,8 @@ def index_sort(input: LongTensor, .. note:: - This operation works only for tensors associated with the CPU device. + This operation is optimized only for tensors associated with the CPU + device. Args: input (torch.LongTensor): 1-dimensional tensor with positive integer @@ -250,7 +251,10 @@ def index_sort(input: LongTensor, A tuple containing sorted values and indices of the elements in the original :obj:`input` tensor. """ - out = torch.ops.pyg.index_sort(input, max) + if input.is_cuda: + out = torch.sort(input) + else: + out = torch.ops.pyg.index_sort(input, max) return out diff --git a/test/ops/test_index_sort.py b/test/ops/test_index_sort.py index 34109ab6d..bb39958ab 100644 --- a/test/ops/test_index_sort.py +++ b/test/ops/test_index_sort.py @@ -1,20 +1,27 @@ import torch import pyg_lib +import pytest + +DEVICES = [torch.device('cpu')] +if torch.cuda.is_available(): + DEVICES.append(torch.device('cuda')) torch.manual_seed(1234) -def test_index_sort(): - input = torch.randint(low=0, high=1024, size=(1000000, )) +@pytest.mark.parametrize('device', DEVICES) +def test_index_sort(device): + input = torch.randint(low=0, high=1024, size=(1000000, ), device=device) ref_sorted_input, ref_indices = torch.sort(input, stable=True) sorted_input, indices = pyg_lib.ops.index_sort(input) assert torch.all(ref_sorted_input == sorted_input) assert torch.all(ref_indices == indices) -def test_index_sort_negative(): - input = torch.randint(low=0, high=1024, size=(16, 32)) +@pytest.mark.parametrize('device', DEVICES) +def test_index_sort_negative(device): + input = torch.randint(low=0, high=1024, size=(16, 32), device=device) # this should fail, as we do not support ndim > 1 # check in pyg_lib/csrc/ops/index_sort.cpp is not performed # TODO: fix this From 4014d35a070f94d0d172cd6fa91569ee3df14014 Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Mon, 23 Jan 2023 13:01:04 +0100 Subject: [PATCH 7/9] Move input validation to kernel implementation and remove redundant files --- pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp | 2 ++ pyg_lib/csrc/ops/index_sort.cpp | 18 ------------------ pyg_lib/csrc/ops/index_sort.h | 16 ---------------- test/ops/test_index_sort.py | 6 ++---- 4 files changed, 4 insertions(+), 38 deletions(-) delete mode 100644 pyg_lib/csrc/ops/index_sort.h diff --git a/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp index 1c9d0c043..d17005500 100644 --- a/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp 
+++ b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp @@ -32,6 +32,8 @@ void vectorized_copy(scalar_t* dst, const scalar_t* src, int64_t size) { std::tuple<at::Tensor, at::Tensor> index_sort_kernel( const at::Tensor& input, const at::optional<int64_t> max) { + TORCH_CHECK(input.is_contiguous(), "Input should be contiguous."); + TORCH_CHECK(input.dim() == 1, "Input should be 1-dimensional."); if (input.numel() > at::internal::GRAIN_SIZE && is_radix_sort_available()) { const auto elements = input.numel(); const auto maximum = max.value_or(at::max(input).item<int64_t>()); diff --git a/pyg_lib/csrc/ops/index_sort.cpp b/pyg_lib/csrc/ops/index_sort.cpp index a858ce14a..ef1203340 100644 --- a/pyg_lib/csrc/ops/index_sort.cpp +++ b/pyg_lib/csrc/ops/index_sort.cpp @@ -1,27 +1,9 @@ -#include "index_sort.h" - -#include -#include namespace pyg { namespace ops { -PYG_API std::tuple<at::Tensor, at::Tensor> index_sort( - const at::Tensor& input, - const at::optional<int64_t> max) { - at::TensorArg input_arg{input, "input", 0}; - at::CheckedFrom c{"index_sort"}; - - at::checkAllDefined(c, {input_arg}); - at::checkContiguous(c, input_arg); - at::checkDim(c, input_arg, 1); - - static auto op = c10::Dispatcher::singleton() - .findSchemaOrThrow("pyg::index_sort", "") - .typed<decltype(index_sort)>(); - return op.call(input, max); -} - TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::index_sort(Tensor indices, int? max = None) -> (Tensor, Tensor)")); diff --git a/pyg_lib/csrc/ops/index_sort.h b/pyg_lib/csrc/ops/index_sort.h deleted file mode 100644 index c977896a1..000000000 --- a/pyg_lib/csrc/ops/index_sort.h +++ /dev/null @@ -1,16 +0,0 @@ -#pragma once - -#include -#include -#include -#include "pyg_lib/csrc/macros.h" - -namespace pyg { -namespace ops { - -PYG_API std::tuple<at::Tensor, at::Tensor> index_sort( - const at::Tensor& input, - const at::optional<int64_t> max); - -} // namespace ops -} // namespace pyg diff --git a/test/ops/test_index_sort.py b/test/ops/test_index_sort.py index bb39958ab..e5d89ec7c 100644 --- a/test/ops/test_index_sort.py +++ b/test/ops/test_index_sort.py @@ -22,7 +22,5 @@ def test_index_sort(device): @pytest.mark.parametrize('device', DEVICES) def test_index_sort_negative(device): input = torch.randint(low=0, high=1024, size=(16, 32), device=device) - # this should fail, as we do not support ndim > 1 - # check in pyg_lib/csrc/ops/index_sort.cpp is not performed - # TODO: fix this - sorted_input, indices = pyg_lib.ops.index_sort(input) + with pytest.raises(RuntimeError): + sorted_input, indices = pyg_lib.ops.index_sort(input) From f1586766aec68de0cc957d4103264cc26753dacf Mon Sep 17 00:00:00 2001 From: Damian Szwichtenberg Date: Wed, 25 Jan 2023 13:51:49 +0100 Subject: [PATCH 8/9] Resolve review comments --- pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp | 4 ++-- pyg_lib/ops/__init__.py | 13 +++++++------ test/ops/test_index_sort.py | 5 ++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp index d17005500..55ffae173 100644 --- a/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp +++ b/pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp @@ -36,8 +36,8 @@ std::tuple<at::Tensor, at::Tensor> index_sort_kernel( TORCH_CHECK(input.dim() == 1, "Input should be 1-dimensional."); if (input.numel() > at::internal::GRAIN_SIZE && is_radix_sort_available()) { const auto elements = input.numel(); - const auto maximum = max.value_or(at::max(input).item<int64_t>()); - auto out_vals = at::detach(input).clone(); + const auto maximum = max.value_or(input.max().item<int64_t>()); + auto out_vals = input.detach().clone(); auto out_indices = at::arange( 0, elements, 
at::TensorOptions().device(at::kCPU).dtype(at::kLong)); diff --git a/pyg_lib/ops/__init__.py b/pyg_lib/ops/__init__.py index f6ac029c4..b24b86c4b 100644 --- a/pyg_lib/ops/__init__.py +++ b/pyg_lib/ops/__init__.py @@ -227,12 +227,13 @@ def sampled_div( return out -def index_sort(input: LongTensor, - max: Optional[int] = None) -> Tuple[LongTensor, LongTensor]: +def index_sort( + input: LongTensor, + max_value: Optional[int] = None) -> Tuple[LongTensor, LongTensor]: r"""Sorts the elements of the :obj:`input` tensor in ascending order by value. It is expected that :obj:`input` tensor is 1-dimensional and - contains only positive, integer values. If :obj:`max` is given, it can be - used by the underlying algorithm for better performance. + contains only positive, integer values. If :obj:`max_value` is given, it + can be used by the underlying algorithm for better performance. .. note:: @@ -242,7 +243,7 @@ def index_sort(input: LongTensor, Args: input (torch.LongTensor): 1-dimensional tensor with positive integer values. - max (int, optional): A maximum value stored inside :obj:`input`. + max_value (int, optional): A maximum value stored inside :obj:`input`. This value can be an estimation, but needs to be greater than or equal to the real maximum. (default: :obj:`None`) Returns: @@ -254,7 +255,7 @@ def index_sort(input: LongTensor, if input.is_cuda: out = torch.sort(input) else: - out = torch.ops.pyg.index_sort(input, max) + out = torch.ops.pyg.index_sort(input, max_value) return out diff --git a/test/ops/test_index_sort.py b/test/ops/test_index_sort.py index e5d89ec7c..5d4984e0f 100644 --- a/test/ops/test_index_sort.py +++ b/test/ops/test_index_sort.py @@ -19,8 +19,7 @@ def test_index_sort(device): assert torch.all(ref_indices == indices) -@pytest.mark.parametrize('device', DEVICES) -def test_index_sort_negative(device): - input = torch.randint(low=0, high=1024, size=(16, 32), device=device) +def test_index_sort_negative(): + input = torch.randint(low=0, high=1024, size=(16, 32), device='cpu') with pytest.raises(RuntimeError): sorted_input, indices = pyg_lib.ops.index_sort(input) From e2feec5ac66b645b0541422b9c8c901c5a6db549 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 25 Jan 2023 13:32:47 +0000 Subject: [PATCH 9/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- LICENSE.txt | 2 +- LICENSE/LICENSE_radix_sort | 2 +- pyg_lib/csrc/ops/cpu/radix_sort.h | 2 +- test/ops/test_index_sort.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/LICENSE.txt b/LICENSE.txt index dee647842..d3403de21 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -5,4 +5,4 @@ LICENSE/LICENSE_pyg_lib file). Copyright for radix_sort.h file located in pyg_lib/csrc/cpu directory is held by Copyright (c) Meta Platforms, Inc. and affiliates, and is provided under the BSD -license (definition located in LICENSE/LICENSE_radix_sort file). \ No newline at end of file +license (definition located in LICENSE/LICENSE_radix_sort file). 
diff --git a/LICENSE/LICENSE_radix_sort b/LICENSE/LICENSE_radix_sort index 0dd6cae7d..1c8dd93e6 100644 --- a/LICENSE/LICENSE_radix_sort +++ b/LICENSE/LICENSE_radix_sort @@ -27,4 +27,4 @@ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/pyg_lib/csrc/ops/cpu/radix_sort.h b/pyg_lib/csrc/ops/cpu/radix_sort.h index 7666bc881..8629958bb 100644 --- a/pyg_lib/csrc/ops/cpu/radix_sort.h +++ b/pyg_lib/csrc/ops/cpu/radix_sort.h @@ -199,4 +199,4 @@ std::pair<K*, V*> radix_sort_parallel(K* inp_key_buf, } // namespace ops } // namespace pyg -#endif \ No newline at end of file +#endif diff --git a/test/ops/test_index_sort.py b/test/ops/test_index_sort.py index 5d4984e0f..349d2293e 100644 --- a/test/ops/test_index_sort.py +++ b/test/ops/test_index_sort.py @@ -1,7 +1,7 @@ +import pytest import torch import pyg_lib -import pytest DEVICES = [torch.device('cpu')] if torch.cuda.is_available():
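
A usage sketch of the API introduced by this patch series, based on the `index_sort` docstring and on test/ops/test_index_sort.py above (a build of pyg-lib containing these patches is assumed, and the tensor contents here are illustrative only):

import torch
import pyg_lib

# 1-dimensional tensor of non-negative integers, e.g. a node index vector.
index = torch.randint(low=0, high=1024, size=(1_000_000, ))

# `max_value` is optional; passing an upper bound on the stored values lets the
# CPU radix-sort kernel skip computing `input.max()` itself.
sorted_index, perm = pyg_lib.ops.index_sort(index, max_value=1024)

# The second return value is the sorting permutation of the input.
assert torch.equal(sorted_index, index[perm])
assert torch.equal(sorted_index, torch.sort(index).values)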