diff --git a/CHANGELOG.md b/CHANGELOG.md index 9693d008d..e9deac66c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] ### Added +- Added `pyg::utils::to_vector` implementation ([#88](https://github.com/pyg-team/pyg-lib/pull/88)) - Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58)) - Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69)) - Added `pyg::sampler::neighbor_sample` implementation ([#54](https://github.com/pyg-team/pyg-lib/pull/54), [#76](https://github.com/pyg-team/pyg-lib/pull/76), [#77](https://github.com/pyg-team/pyg-lib/pull/77), [#78](https://github.com/pyg-team/pyg-lib/pull/78), [#80](https://github.com/pyg-team/pyg-lib/pull/80), [#81](https://github.com/pyg-team/pyg-lib/pull/81)), [#85](https://github.com/pyg-team/pyg-lib/pull/85), [#86](https://github.com/pyg-team/pyg-lib/pull/86), [#87](https://github.com/pyg-team/pyg-lib/pull/87)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 1c395d3b9..23cacab72 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -214,11 +214,11 @@ sample(const at::Tensor& rowptr, std::vector sampled_nodes; const auto seed_data = seed.data_ptr(); - for (size_t i = 0; i < seed.numel(); i++) { - if constexpr (!disjoint) { - mapper.insert(seed_data[i]); - sampled_nodes.push_back(seed_data[i]); - } else { + if constexpr (!disjoint) { + sampled_nodes = pyg::utils::to_vector(seed); + mapper.fill(seed); + } else { + for (size_t i = 0; i < seed.numel(); i++) { mapper.insert({i, seed_data[i]}); sampled_nodes.push_back({i, seed_data[i]}); } diff --git a/pyg_lib/csrc/utils/cpu/convert.h b/pyg_lib/csrc/utils/cpu/convert.h index 8edaa862d..6b65fc7ad 100644 --- a/pyg_lib/csrc/utils/cpu/convert.h +++ b/pyg_lib/csrc/utils/cpu/convert.h @@ -23,5 +23,12 @@ at::Tensor from_vector(const std::vector>& vec, return inplace ? out : out.clone(); } +template +std::vector to_vector(const at::Tensor& vec) { + TORCH_CHECK(vec.is_contiguous(), "'vec' needs to be contiguous"); + const auto vec_data = vec.data_ptr(); + return std::vector(vec_data, vec_data + vec.numel()); +} + } // namespace utils } // namespace pyg