From dae6148d0da5b268286af8df59d77744f603d492 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 6 Feb 2025 11:41:05 +0000 Subject: [PATCH 1/5] update --- pyg_lib/csrc/classes/cpu/hash_map.cpp | 129 +++++++++++++++++++++++++- test/csrc/classes/test_hash_map.cpp | 7 ++ 2 files changed, 131 insertions(+), 5 deletions(-) diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp index c09f958b5..83ab3f744 100644 --- a/pyg_lib/csrc/classes/cpu/hash_map.cpp +++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp @@ -27,8 +27,71 @@ struct CPUHashMapImpl : HashMapImpl { public: using ValueType = int64_t; - CPUHashMapImpl(const at::Tensor& key) { - map_.reserve(key.numel()); + CPUHashMapImpl(const at::Tensor& key, double load_factor) { + size_t capacity = std::ceil(key.numel() / load_factor); + map_.reserve(capacity); + + const auto key_data = key.data_ptr(); + for (int64_t i = 0; i < key.numel(); ++i) { + const auto [iterator, inserted] = map_.insert({key_data[i], i}); + TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'."); + } + } + + at::Tensor get(const at::Tensor& query) override { + const auto options = + query.options().dtype(c10::CppTypeToScalarType::value); + const auto out = at::empty({query.numel()}, options); + const auto query_data = query.data_ptr(); + const auto out_data = out.data_ptr(); + + const auto num_threads = at::get_num_threads(); + const auto grain_size = + std::max((query.numel() + num_threads - 1) / num_threads, + at::internal::GRAIN_SIZE); + + at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) { + for (int64_t i = b; i < e; ++i) { + const auto it = map_.find(query_data[i]); + out_data[i] = (it != map_.end()) ? 
it->second : -1; + } + }); + + return out; + } + + at::Tensor keys() override { + const auto size = static_cast(map_.size()); + + at::Tensor key; + if (std::is_same::value) { + key = at::empty({size}, at::TensorOptions().dtype(at::kShort)); + } else if (std::is_same::value) { + key = at::empty({size}, at::TensorOptions().dtype(at::kInt)); + } else { + key = at::empty({size}, at::TensorOptions().dtype(at::kLong)); + } + const auto key_data = key.data_ptr(); + + for (const auto& pair : map_) { // No efficient multi-threading possible :( + key_data[pair.second] = pair.first; + } + + return key; + } + + private: + phmap::flat_hash_map map_; +}; + +template +struct ParallelCPUHashMapImpl : HashMapImpl { + public: + using ValueType = int64_t; + + ParallelCPUHashMapImpl(const at::Tensor& key, double load_factor) { + size_t capacity = std::ceil(key.numel() / load_factor); + map_.reserve(capacity); const auto key_data = key.data_ptr(); @@ -94,14 +157,16 @@ struct CPUHashMapImpl : HashMapImpl { phmap::priv::hash_default_hash, phmap::priv::hash_default_eq, phmap::priv::Allocator>, - 12, + num_submaps, std::mutex> map_; }; struct CPUHashMap : torch::CustomClassHolder { public: - CPUHashMap(const at::Tensor& key) { + CPUHashMap(const at::Tensor& key, + int64_t num_submaps = 0, + double load_factor = 0.5) { at::TensorArg key_arg{key, "key", 0}; at::CheckedFrom c{"CPUHashMap.init"}; at::checkDeviceType(c, key, at::DeviceType::CPU); @@ -109,7 +174,61 @@ struct CPUHashMap : torch::CustomClassHolder { at::checkContiguous(c, key_arg); DISPATCH_KEY(key.scalar_type(), "cpu_hash_map_init", [&] { - map_ = std::make_unique>(key); + switch (num_submaps) { + case 0: + map_ = std::make_unique>(key, load_factor); + break; + case 2: + map_ = std::make_unique>( + key, load_factor); + break; + case 4: + map_ = std::make_unique>( + key, load_factor); + break; + case 8: + map_ = std::make_unique>( + key, load_factor); + break; + case 16: + map_ = std::make_unique>( + key, load_factor); + break; + 
case 32: + map_ = std::make_unique>( + key, load_factor); + break; + case 64: + map_ = std::make_unique>( + key, load_factor); + break; + case 128: + map_ = std::make_unique>( + key, load_factor); + break; + case 256: + map_ = std::make_unique>( + key, load_factor); + break; + case 512: + map_ = std::make_unique>( + key, load_factor); + break; + case 1024: + map_ = std::make_unique>( + key, load_factor); + break; + case 2048: + map_ = std::make_unique>( + key, load_factor); + break; + case 4096: + map_ = std::make_unique>( + key, load_factor); + break; + default: + TORCH_CHECK(false, "'num_submaps' needs to be a power of 2"); + } }); } diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp index 0f6ed2ece..a6e497bb6 100644 --- a/test/csrc/classes/test_hash_map.cpp +++ b/test/csrc/classes/test_hash_map.cpp @@ -13,4 +13,11 @@ TEST(HashMapTest, BasicAssertions) { auto query = at::tensor({30, 10, 20, 40}, options); auto expected = at::tensor({2, 1, 3, -1}, options); EXPECT_TRUE(at::equal(map.get(query), expected)); + + map = pyg::classes::CPUHashMap(key, 16); + EXPECT_TRUE(at::equal(map.keys(), key)); + + auto query = at::tensor({30, 10, 20, 40}, options); + auto expected = at::tensor({2, 1, 3, -1}, options); + EXPECT_TRUE(at::equal(map.get(query), expected)); } From 1fad5a1f2218b2eb1e4eb1490f0632005a162009 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 6 Feb 2025 11:52:10 +0000 Subject: [PATCH 2/5] update --- test/csrc/classes/test_hash_map.cpp | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp index a6e497bb6..d5698a6d3 100644 --- a/test/csrc/classes/test_hash_map.cpp +++ b/test/csrc/classes/test_hash_map.cpp @@ -16,8 +16,5 @@ TEST(HashMapTest, BasicAssertions) { map = pyg::classes::CPUHashMap(key, 16); EXPECT_TRUE(at::equal(map.keys(), key)); - - auto query = at::tensor({30, 10, 20, 40}, options); - auto expected = at::tensor({2, 1, 3, -1}, options); 
EXPECT_TRUE(at::equal(map.get(query), expected)); } From 8e8b85e4ba6f634133c525ec79bfa7ce09920f89 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 6 Feb 2025 12:13:13 +0000 Subject: [PATCH 3/5] Register new HashMap constructor signatures and add Python tests --- pyg_lib/csrc/classes/cpu/hash_map.cpp | 2 +- pyg_lib/csrc/classes/cuda/hash_map.cu | 2 +- test/classes/test_hash_map.py | 30 +++++++++++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 test/classes/test_hash_map.py diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp index 83ab3f744..765931fee 100644 --- a/pyg_lib/csrc/classes/cpu/hash_map.cpp +++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp @@ -252,7 +252,7 @@ struct CPUHashMap : torch::CustomClassHolder { TORCH_LIBRARY_FRAGMENT(pyg, m) { m.class_("CPUHashMap") - .def(torch::init()) + .def(torch::init()) .def("get", &CPUHashMap::get) .def("keys", &CPUHashMap::keys) .def_pickle( diff --git a/pyg_lib/csrc/classes/cuda/hash_map.cu b/pyg_lib/csrc/classes/cuda/hash_map.cu index 31eb85e52..beacac655 100644 --- a/pyg_lib/csrc/classes/cuda/hash_map.cu +++ b/pyg_lib/csrc/classes/cuda/hash_map.cu @@ -142,7 +142,7 @@ struct CUDAHashMap : torch::CustomClassHolder { TORCH_LIBRARY_FRAGMENT(pyg, m) { m.class_("CUDAHashMap") - .def(torch::init()) + .def(torch::init()) .def("get", &CUDAHashMap::get) .def("keys", &CUDAHashMap::keys) .def_pickle( diff --git a/test/classes/test_hash_map.py b/test/classes/test_hash_map.py new file mode 100644 index 000000000..4f87d6e70 --- /dev/null +++ b/test/classes/test_hash_map.py @@ -0,0 +1,30 @@ +import pytest +import torch + +from pyg_lib.testing import withCUDA + + +@withCUDA +@pytest.mark.parametrize('load_factor', [0.5, 0.25]) +@pytest.mark.parametrize('dtype', [torch.short, torch.int, torch.long]) +def test_hash_map(load_factor, dtype, device): + key = torch.tensor([0, 10, 30, 20], device=device, dtype=dtype) + query = torch.tensor([30, 10, 20, 40], device=device, dtype=dtype) + + if key.is_cpu: + HashMap =
torch.classes.pyg.CPUHashMap + hash_map = HashMap(key, 0, load_factor) + elif key.is_cuda: + HashMap = torch.classes.pyg.CUDAHashMap + hash_map = HashMap(key, load_factor) + else: + raise NotImplementedError(f"Unsupported device '{device}'") + + assert torch.equal(key, hash_map.keys()) + expected = torch.tensor([2, 1, 3, -1], device=device) + assert hash_map.get(query).equal(expected) + + if key.is_cpu: + hash_map = HashMap(key, 16, load_factor) + assert torch.equal(key, hash_map.keys()) + assert hash_map.get(query).equal(expected) From ddc546960febd1ad91592b2fd03192e5c4862dc9 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 6 Feb 2025 12:16:16 +0000 Subject: [PATCH 4/5] Drop cu113/cu116 from README CUDA list; check HashMap output dtypes in tests --- README.md | 2 +- test/classes/test_hash_map.py | 8 ++++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 4b529893a..d0019ad50 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,7 @@ pip install pyg-lib -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html where * `${TORCH}` should be replaced by either `1.13.0`, `2.0.0`, `2.1.0`, `2.2.0`, `2.3.0`, `2.4.0` or `2.5.0` -* `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu113`, `cu116`, `cu117`, `cu118`, `cu121`, or `cu124` +* `${CUDA}` should be replaced by either `cpu`, `cu102`, `cu117`, `cu118`, `cu121`, or `cu124` The following combinations are supported: diff --git a/test/classes/test_hash_map.py b/test/classes/test_hash_map.py index 4f87d6e70..b33709ac4 100644 --- a/test/classes/test_hash_map.py +++ b/test/classes/test_hash_map.py @@ -20,11 +20,15 @@ def test_hash_map(load_factor, dtype, device): else: raise NotImplementedError(f"Unsupported device '{device}'") - assert torch.equal(key, hash_map.keys()) + assert hash_map.keys().equal(key) + assert hash_map.keys().equal(key) expected = torch.tensor([2, 1, 3, -1], device=device) assert hash_map.get(query).equal(expected) + assert hash_map.get(query).dtype == torch.long if key.is_cpu: hash_map = HashMap(key, 16, load_factor) -
assert torch.equal(key, hash_map.keys()) + assert hash_map.keys().dtype == dtype + assert hash_map.keys().equal(key) assert hash_map.get(query).equal(expected) + assert hash_map.get(query).dtype == torch.long From 54fc3a3e9b67714929988ac828eb45796997cdd4 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Thu, 6 Feb 2025 12:30:36 +0000 Subject: [PATCH 5/5] Add TorchScript serialization test for HashMap --- test/classes/test_hash_map.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/classes/test_hash_map.py b/test/classes/test_hash_map.py index b33709ac4..f7283bc4d 100644 --- a/test/classes/test_hash_map.py +++ b/test/classes/test_hash_map.py @@ -1,5 +1,8 @@ +import os.path as osp + import pytest import torch +from torch import Tensor from pyg_lib.testing import withCUDA @@ -32,3 +35,29 @@ def test_hash_map(load_factor, dtype, device): assert hash_map.keys().equal(key) assert hash_map.get(query).equal(expected) assert hash_map.get(query).dtype == torch.long + + +class Foo(torch.nn.Module): + def __init__(self, key: Tensor): + super().__init__() + if key.is_cpu: + HashMap = torch.classes.pyg.CPUHashMap + self.map = HashMap(key, 0, 0.5) + elif key.is_cuda: + HashMap = torch.classes.pyg.CUDAHashMap + self.map = HashMap(key, 0.5) + + def forward(self, query: Tensor) -> Tensor: + return self.map.get(query) + + +@withCUDA +def test_serialization(device, tmp_path): + key = torch.tensor([0, 10, 30, 20], device=device) + scripted_foo = torch.jit.script(Foo(key)) + + path = osp.join(tmp_path, 'foo.pt') + scripted_foo.save(path) + loaded = torch.jit.load(path) + + assert loaded.map.keys().equal(key)