From 21368614a37c838325b141b9e0be857c052c9335 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 10 Feb 2025 08:39:24 +0000 Subject: [PATCH] update --- pyg_lib/csrc/classes/cpu/hash_map.cpp | 59 ++++++++++++++++++--------- pyg_lib/csrc/classes/cuda/hash_map.cu | 54 +++++++++++++++++------- test/classes/test_hash_map.py | 9 ++++ test/csrc/classes/test_hash_map.cpp | 3 ++ 4 files changed, 90 insertions(+), 35 deletions(-) diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp index af5c1fe09..6dc9898e3 100644 --- a/pyg_lib/csrc/classes/cpu/hash_map.cpp +++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp @@ -20,6 +20,8 @@ struct HashMapImpl { virtual ~HashMapImpl() = default; virtual at::Tensor get(const at::Tensor& query) = 0; virtual at::Tensor keys() = 0; + virtual int64_t size() = 0; + virtual at::ScalarType dtype() = 0; }; template @@ -60,16 +62,8 @@ struct CPUHashMapImpl : HashMapImpl { } at::Tensor keys() override { - const auto size = static_cast(map_.size()); - - at::Tensor key; - if (std::is_same::value) { - key = at::empty({size}, at::TensorOptions().dtype(at::kShort)); - } else if (std::is_same::value) { - key = at::empty({size}, at::TensorOptions().dtype(at::kInt)); - } else { - key = at::empty({size}, at::TensorOptions().dtype(at::kLong)); - } + const at::Tensor key = + at::empty({size()}, at::TensorOptions().dtype(dtype())); const auto key_data = key.data_ptr(); for (const auto& pair : map_) { // No efficient multi-threading possible :( @@ -79,6 +73,18 @@ struct CPUHashMapImpl : HashMapImpl { return key; } + int64_t size() override { return static_cast(map_.size()); } + + at::ScalarType dtype() override { + if (std::is_same::value) { + return at::kShort; + } else if (std::is_same::value) { + return at::kInt; + } else { + return at::kLong; + } + } + private: phmap::flat_hash_map map_; }; @@ -129,16 +135,8 @@ struct ParallelCPUHashMapImpl : HashMapImpl { } at::Tensor keys() override { - const auto size = static_cast(map_.size()); - - at::Tensor key; - if (std::is_same::value) { - key = at::empty({size}, at::TensorOptions().dtype(at::kShort)); - } else if (std::is_same::value) { - key = at::empty({size}, at::TensorOptions().dtype(at::kInt)); - } else { - key = at::empty({size}, at::TensorOptions().dtype(at::kLong)); - } + const at::Tensor key = + at::empty({size()}, at::TensorOptions().dtype(dtype())); const auto key_data = key.data_ptr(); for (const auto& pair : map_) { // No efficient multi-threading possible :( @@ -148,6 +146,18 @@ struct ParallelCPUHashMapImpl : HashMapImpl { return key; } + int64_t size() override { return static_cast(map_.size()); } + + at::ScalarType dtype() override { + if (std::is_same::value) { + return at::kShort; + } else if (std::is_same::value) { + return at::kInt; + } else { + return at::kLong; + } + } + private: phmap::parallel_flat_hash_map< KeyType, @@ -234,6 +244,12 @@ struct CPUHashMap : torch::CustomClassHolder { at::Tensor keys() { return map_->keys(); } + int64_t size() { return map_->size(); } + + at::ScalarType dtype() { return map_->dtype(); } + + at::Device device() { return at::Device(at::kCPU); } + private: std::unique_ptr map_; }; @@ -245,6 +261,9 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { .def(torch::init()) .def("get", &CPUHashMap::get) .def("keys", &CPUHashMap::keys) + .def("size", &CPUHashMap::size) + .def("dtype", &CPUHashMap::dtype) + .def("device", &CPUHashMap::device) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) -> at::Tensor { diff --git a/pyg_lib/csrc/classes/cuda/hash_map.cu b/pyg_lib/csrc/classes/cuda/hash_map.cu index 3217b8a82..1d9f81970 100644 --- a/pyg_lib/csrc/classes/cuda/hash_map.cu +++ b/pyg_lib/csrc/classes/cuda/hash_map.cu @@ -1,4 +1,5 @@ #include +#include #include #include @@ -23,6 +24,9 @@ struct HashMapImpl { virtual ~HashMapImpl() = default; virtual at::Tensor get(const at::Tensor& query) = 0; virtual at::Tensor keys() = 0; + virtual int64_t size() = 0; + virtual at::ScalarType dtype() = 0; + virtual at::Device device() = 0; }; #ifndef _WIN32 @@ -31,7 +35,10 @@ struct CUDAHashMapImpl : HashMapImpl { public: using ValueType = int64_t; - CUDAHashMapImpl(const at::Tensor& key, double load_factor) { + CUDAHashMapImpl(const at::Tensor& key, double load_factor) + : device_(key.device()) { + c10::cuda::MaybeSetDevice(key.get_device()); + KeyType constexpr empty_key_sentinel = std::numeric_limits::min(); ValueType constexpr empty_value_sentinel = -1; @@ -52,6 +59,8 @@ struct CUDAHashMapImpl : HashMapImpl { } at::Tensor get(const at::Tensor& query) override { + c10::cuda::MaybeSetDevice(query.get_device()); + const auto options = query.options().dtype(c10::CppTypeToScalarType::value); const auto out = at::empty({query.numel()}, options); @@ -64,21 +73,12 @@ struct CUDAHashMapImpl : HashMapImpl { } at::Tensor keys() override { - // TODO This will not work in multi-GPU scenarios. - const auto options = at::TensorOptions() - .device(at::DeviceType::CUDA) - .dtype(c10::CppTypeToScalarType::value); - const auto size = static_cast(map_->size()); + c10::cuda::MaybeSetDevice(device_.index()); - at::Tensor key; - if (std::is_same::value) { - key = at::empty({size}, options.dtype(at::kShort)); - } else if (std::is_same::value) { - key = at::empty({size}, options.dtype(at::kInt)); - } else { - key = at::empty({size}, options); - } - const auto value = at::empty({size}, options); + const auto options = at::TensorOptions().device(device_); + const at::Tensor key = at::empty({size()}, options.dtype(dtype())); + const at::Tensor value = at::empty( + {size()}, options.dtype(c10::CppTypeToScalarType::value)); const auto key_data = key.data_ptr(); const auto value_data = value.data_ptr(); @@ -90,8 +90,23 @@ struct CUDAHashMapImpl : HashMapImpl { return key.index_select(0, perm); } + int64_t size() override { return static_cast(map_->size()); } + + at::ScalarType dtype() override { + if (std::is_same::value) { + return at::kShort; + } else if (std::is_same::value) { + return at::kInt; + } else { + return at::kLong; + } + } + + at::Device device() override { return device_; } + private: std::unique_ptr> map_; + at::Device device_; }; #endif @@ -135,6 +150,12 @@ struct CUDAHashMap : torch::CustomClassHolder { #endif } + int64_t size() { return map_->size(); } + + at::ScalarType dtype() { return map_->dtype(); } + + at::Device device() { return map_->device(); } + private: #ifndef _WIN32 std::unique_ptr map_; @@ -148,6 +169,9 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { .def(torch::init()) .def("get", &CUDAHashMap::get) .def("keys", &CUDAHashMap::keys) + .def("size", &CUDAHashMap::size) + .def("dtype", &CUDAHashMap::dtype) + .def("device", &CUDAHashMap::device) .def_pickle( // __getstate__ [](const c10::intrusive_ptr& self) -> at::Tensor { diff --git a/test/classes/test_hash_map.py b/test/classes/test_hash_map.py index 2dc08dcba..fd5d81e13 100644 --- a/test/classes/test_hash_map.py +++ b/test/classes/test_hash_map.py @@ -6,6 +6,12 @@ from pyg_lib.testing import withCUDA +INT_TO_DTPYE = { + 2: torch.short, + 3: torch.int, + 4: torch.long, +} + @withCUDA @pytest.mark.parametrize('dtype', [torch.short, torch.int, torch.long]) @@ -22,6 +28,9 @@ def test_hash_map(dtype, device): else: raise NotImplementedError(f"Unsupported device '{device}'") + assert hash_map.size() == 4 + assert INT_TO_DTPYE[hash_map.dtype()] == dtype + assert hash_map.device() == device assert hash_map.keys().equal(key) assert hash_map.keys().dtype == dtype expected = torch.tensor([2, 1, 3, -1], device=device) diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp index d5698a6d3..8b10d3d0e 100644 --- a/test/csrc/classes/test_hash_map.cpp +++ b/test/csrc/classes/test_hash_map.cpp @@ -8,6 +8,9 @@ TEST(HashMapTest, BasicAssertions) { auto key = at::tensor({0, 10, 30, 20}, options); auto map = pyg::classes::CPUHashMap(key); + EXPECT_EQ(map.size(), 4); + EXPECT_EQ(map.dtype(), at::kLong); + EXPECT_EQ(map.device(), at::Device(at::kCPU)); EXPECT_TRUE(at::equal(map.keys(), key)); auto query = at::tensor({30, 10, 20, 40}, options);