From 9ca8fc5ab190245728558d064abd59a5f9718377 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 11 Jan 2025 18:07:21 +0000
Subject: [PATCH 01/11] update

---
 benchmark/classes/hash_map.py | 81 ++++++++++++++++++++++++++++++++++-
 1 file changed, 80 insertions(+), 1 deletion(-)

diff --git a/benchmark/classes/hash_map.py b/benchmark/classes/hash_map.py
index 02dbc7d70..18d6c3a12 100644
--- a/benchmark/classes/hash_map.py
+++ b/benchmark/classes/hash_map.py
@@ -1,3 +1,82 @@
+import argparse
+import time
+
+import pandas as pd
+import torch
+
 from pyg_lib.classes import HashMap
 
-print(HashMap)
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--device', type=str, default='cuda')
+    parser.add_argument('--num_keys', type=int, default=10_000_000)
+    parser.add_argument('--num_queries', type=int, default=1_000_000)
+    args = parser.parse_args()
+
+    args.num_queries = min(args.num_queries, args.num_keys)
+
+    num_warmups, num_steps = 50, 100
+    if args.device == 'cpu':
+        num_warmups, num_steps = num_warmups // 10, num_steps // 10
+
+    key = torch.randperm(args.num_keys, device=args.device)
+    query = torch.randperm(args.num_queries, device=args.device)
+    query = query[:args.num_queries]
+
+    t_init = t_get = 0
+    for i in range(num_warmups + num_steps):
+        torch.cuda.synchronize()
+        t_start = time.perf_counter()
+        hash_map = HashMap(key)
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_init += time.perf_counter() - t_start
+
+        t_start = time.perf_counter()
+        hash_map.get(query)
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_get += time.perf_counter() - t_start
+
+    print(f'HashMap Init: {t_init / num_steps:.4f}s')
+    print(f' HashMap Get: {t_get / num_steps:.4f}s')
+    print('=====================')
+
+    t_init = t_get = 0
+    for i in range(num_warmups + num_steps):
+        torch.cuda.synchronize()
+        t_start = time.perf_counter()
+        hash_map = torch.full((args.num_keys, ), fill_value=-1,
+                              dtype=torch.long, device=args.device)
+        hash_map[key] = torch.arange(args.num_keys)
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_init += time.perf_counter() - t_start
+
+        t_start = time.perf_counter()
+        hash_map[query]
+        torch.cuda.synchronize()
+        if i >= num_warmups:
+            t_get += time.perf_counter() - t_start
+
+    print(f'Memory Init: {t_init / num_steps:.4f}s')
+    print(f' Memory Get: {t_get / num_steps:.4f}s')
+    print('=====================')
+
+    if key.is_cpu:
+        t_init = t_get = 0
+        for i in range(num_warmups + num_steps):
+            t_start = time.perf_counter()
+            hash_map = pd.CategoricalDtype(categories=key.numpy(),
+                                           ordered=True)
+            if i >= num_warmups:
+                t_init += time.perf_counter() - t_start
+
+            t_start = time.perf_counter()
+            ser = pd.Series(query.numpy(), dtype=hash_map)
+            torch.from_numpy(ser.cat.codes.to_numpy()).to(torch.long)
+            if i >= num_warmups:
+                t_get += time.perf_counter() - t_start
+
+        print(f'Pandas Init: {t_init / num_steps:.4f}s')
+        print(f' Pandas Get: {t_get / num_steps:.4f}s')

From 9284bc36ca2e5a5201c2175c106775db31194b75 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Sat, 11 Jan 2025 18:09:38 +0000
Subject: [PATCH 02/11] update

---
 benchmark/classes/hash_map.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/benchmark/classes/hash_map.py b/benchmark/classes/hash_map.py
index 18d6c3a12..e1058307f 100644
--- a/benchmark/classes/hash_map.py
+++ b/benchmark/classes/hash_map.py
@@ -39,7 +39,7 @@
             t_get += time.perf_counter() - t_start
 
     print(f'HashMap Init: {t_init / num_steps:.4f}s')
-    print(f' HashMap Get: {t_get / num_steps:.4f}s')
+    print(f'HashMap Get: {t_get / num_steps:.4f}s')
     print('=====================')
 
     t_init = t_get = 0
@@ -59,8 +59,8 @@
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
 
-    print(f'Memory Init: {t_init / num_steps:.4f}s')
-    print(f' Memory Get: {t_get / num_steps:.4f}s')
+    print(f' Memory Init: {t_init / num_steps:.4f}s')
+    print(f'  Memory Get: {t_get / num_steps:.4f}s')
     print('=====================')
 
     if key.is_cpu:
@@ -74,9 +74,9 @@
 
             t_start = time.perf_counter()
             ser = pd.Series(query.numpy(), dtype=hash_map)
-            torch.from_numpy(ser.cat.codes.to_numpy()).to(torch.long)
+            ser.cat.codes.to_numpy()
             if i >= num_warmups:
                 t_get += time.perf_counter() - t_start
 
-        print(f'Pandas Init: {t_init / num_steps:.4f}s')
-        print(f' Pandas Get: {t_get / num_steps:.4f}s')
+        print(f' Pandas Init: {t_init / num_steps:.4f}s')
+        print(f'  Pandas Get: {t_get / num_steps:.4f}s')

From 5376a11c8f1d78891f08019c8b594898f5239511 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 02:27:58 +0000
Subject: [PATCH 03/11] update

---
 benchmark/classes/hash_map.py         | 27 +++++++-----
 pyg_lib/csrc/classes/cpu/hash_map.cpp | 59 ++++++++++++++-------------
 pyg_lib/csrc/classes/cpu/hash_map.h   | 16 ++++++--
 3 files changed, 60 insertions(+), 42 deletions(-)

diff --git a/benchmark/classes/hash_map.py b/benchmark/classes/hash_map.py
index e1058307f..845496506 100644
--- a/benchmark/classes/hash_map.py
+++ b/benchmark/classes/hash_map.py
@@ -19,21 +19,28 @@
     if args.device == 'cpu':
         num_warmups, num_steps = num_warmups // 10, num_steps // 10
 
-    key = torch.randperm(args.num_keys, device=args.device)
-    query = torch.randperm(args.num_queries, device=args.device)
-    query = query[:args.num_queries]
+    max_value = torch.iinfo(torch.long).max
+
+    key1 = torch.randint(0, max_value, (args.num_keys, ), dtype=torch.long,
+                         device=args.device)
+    query1 = key1[torch.randperm(key1.size(0), device=args.device)]
+    query1 = query1[:args.num_queries]
+
+    key2 = torch.randperm(args.num_keys, device=args.device)
+    query2 = torch.randperm(args.num_queries, device=args.device)
+    query2 = query2[:args.num_queries]
 
     t_init = t_get = 0
     for i in range(num_warmups + num_steps):
         torch.cuda.synchronize()
         t_start = time.perf_counter()
-        hash_map = HashMap(key)
+        hash_map = HashMap(key1)
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_init += time.perf_counter() - t_start
 
         t_start = time.perf_counter()
-        hash_map.get(query)
+        hash_map.get(query1)
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
@@ -48,13 +55,13 @@
         t_start = time.perf_counter()
         hash_map = torch.full((args.num_keys, ), fill_value=-1,
                               dtype=torch.long, device=args.device)
-        hash_map[key] = torch.arange(args.num_keys)
+        hash_map[key2] = torch.arange(args.num_keys)
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_init += time.perf_counter() - t_start
 
         t_start = time.perf_counter()
-        hash_map[query]
+        hash_map[query2]
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
@@ -63,17 +70,17 @@
     print(f'  Memory Get: {t_get / num_steps:.4f}s')
     print('=====================')
 
-    if key.is_cpu:
+    if key1.is_cpu:
         t_init = t_get = 0
         for i in range(num_warmups + num_steps):
             t_start = time.perf_counter()
-            hash_map = pd.CategoricalDtype(categories=key.numpy(),
+            hash_map = pd.CategoricalDtype(categories=key1.numpy(),
                                            ordered=True)
             if i >= num_warmups:
                 t_init += time.perf_counter() - t_start
 
             t_start = time.perf_counter()
-            ser = pd.Series(query.numpy(), dtype=hash_map)
+            ser = pd.Series(query1.numpy(), dtype=hash_map)
             ser.cat.codes.to_numpy()
             if i >= num_warmups:
                 t_get += time.perf_counter() - t_start
diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp
index d596f2f6a..e7a8bb1cc 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.cpp
+++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -1,32 +1,36 @@
 #include "hash_map.h"
 
+#include <ATen/Parallel.h>
 #include <torch/library.h>
 
 namespace pyg {
 namespace classes {
 
-CPUHashMap::CPUHashMap(const at::Tensor& key) {
+template <typename KeyType>
+CPUHashMap<KeyType>::CPUHashMap(const at::Tensor& key) {
   at::TensorArg key_arg{key, "key", 0};
   at::CheckedFrom c{"HashMap.init"};
   at::checkDeviceType(c, key, at::DeviceType::CPU);
   at::checkDim(c, key_arg, 1);
   at::checkContiguous(c, key_arg);
 
-  // clang-format off
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
-                            key.scalar_type(),
-                            "cpu_hash_map_init",
-                            [&] {
-    const auto key_data = key.data_ptr<scalar_t>();
-    for (int64_t i = 0; i < key.numel(); ++i) {
-      auto [iterator, inserted] = map_.insert({key_data[i], i});
-      TORCH_CHECK(inserted, "Found duplicated key.");
-    }
-  });
-  // clang-format on
+  map_.reserve(key.numel());
+
+  const auto num_threads = at::get_num_threads();
+  const auto grain_size = std::max(
+      (key.numel() + num_threads - 1) / num_threads, at::internal::GRAIN_SIZE);
+  const auto key_data = key.data_ptr<KeyType>();
+
+  at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
+    for (int64_t i = beg; i < end; ++i) {
+      auto [iterator, inserted] = map_.insert({key_data[i], i});
+      TORCH_CHECK(inserted, "Found duplicated key.");
+    }
+  });
 };
 
-at::Tensor CPUHashMap::get(const at::Tensor& query) {
+template <typename KeyType>
+at::Tensor CPUHashMap<KeyType>::get(const at::Tensor& query) {
   at::TensorArg query_arg{query, "query", 0};
   at::CheckedFrom c{"HashMap.get"};
   at::checkDeviceType(c, query, at::DeviceType::CPU);
@@ -37,27 +41,26 @@ at::Tensor CPUHashMap::get(const at::Tensor& query) {
   const auto out = at::empty({query.numel()}, options);
   auto out_data = out.data_ptr<int64_t>();
 
-  // clang-format off
-  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
-                            query.scalar_type(),
-                            "cpu_hash_map_get",
-                            [&] {
-    const auto query_data = query.data_ptr<scalar_t>();
+  const auto num_threads = at::get_num_threads();
+  const auto grain_size =
+      std::max((query.numel() + num_threads - 1) / num_threads,
+               at::internal::GRAIN_SIZE);
+  const auto query_data = query.data_ptr<KeyType>();
 
-    for (size_t i = 0; i < query.numel(); ++i) {
-      auto it = map_.find(query_data[i]);
-      out_data[i] = (it != map_.end()) ? it->second : -1;
-    }
-  });
-  // clang-format on
+  at::parallel_for(0, query.numel(), grain_size, [&](int64_t beg, int64_t end) {
+    for (int64_t i = beg; i < end; ++i) {
+      auto it = map_.find(query_data[i]);
+      out_data[i] = (it != map_.end()) ? it->second : -1;
+    }
+  });
 
   return out;
 }
 
 TORCH_LIBRARY(pyg, m) {
-  m.class_<CPUHashMap>("CPUHashMap")
+  m.class_<CPUHashMap<int64_t>>("CPUHashMap")
       .def(torch::init<at::Tensor&>())
-      .def("get", &CPUHashMap::get);
+      .def("get", &CPUHashMap<int64_t>::get);
 }
 
 } // namespace classes
diff --git a/pyg_lib/csrc/classes/cpu/hash_map.h b/pyg_lib/csrc/classes/cpu/hash_map.h
index 2ec9169a8..6540acc87 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.h
+++ b/pyg_lib/csrc/classes/cpu/hash_map.h
@@ -1,21 +1,29 @@
 #pragma once
 
 #include <torch/library.h>
-#include <unordered_map>
+#include "parallel_hashmap/phmap.h"
 
 namespace pyg {
 namespace classes {
 
+template <typename KeyType>
 struct CPUHashMap : torch::CustomClassHolder {
  public:
-  using KeyType = std::
-      variant<bool, uint8_t, int8_t, int16_t, int32_t, int64_t, float, double>;
+  using ValueType = int64_t;
 
   CPUHashMap(const at::Tensor& key);
   at::Tensor get(const at::Tensor& query);
 
  private:
-  std::unordered_map<KeyType, int64_t> map_;
+  phmap::parallel_flat_hash_map<
+      KeyType,
+      ValueType,
+      phmap::priv::hash_default_hash<KeyType>,
+      phmap::priv::hash_default_eq<KeyType>,
+      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
+      8,
+      phmap::NullMutex>
+      map_;
 };
 
 } // namespace classes

From 7bec40f5f2ed0e93c402b7f79c4bad29afb4abad Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 02:30:41 +0000
Subject: [PATCH 04/11] update

---
 test/csrc/classes/test_hash_map.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp
index a54400b50..4705b3a84 100644
--- a/test/csrc/classes/test_hash_map.cpp
+++ b/test/csrc/classes/test_hash_map.cpp
@@ -7,7 +7,7 @@ TEST(CPUHashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
 
   auto key = at::tensor({0, 10, 30, 20}, options);
-  auto map = pyg::classes::CPUHashMap(key);
+  auto map = pyg::classes::CPUHashMap<int64_t>(key);
 
   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);

From 809ab3dbb33c8a4bea4de629b4858519642385e3 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 02:34:08 +0000
Subject: [PATCH 05/11] update

---
 benchmark/classes/hash_map.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/benchmark/classes/hash_map.py b/benchmark/classes/hash_map.py
index 845496506..092637f00 100644
--- a/benchmark/classes/hash_map.py
+++ b/benchmark/classes/hash_map.py
@@ -40,7 +40,7 @@
             t_init += time.perf_counter() - t_start
 
         t_start = time.perf_counter()
-        hash_map.get(query1)
+        out1 = hash_map.get(query1)
         torch.cuda.synchronize()
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
@@ -61,7 +61,7 @@
             t_init += time.perf_counter() - t_start
 
         t_start = time.perf_counter()
-        hash_map[query2]
+        out2 = hash_map[query2]
        torch.cuda.synchronize()
         if i >= num_warmups:
             t_get += time.perf_counter() - t_start
@@ -81,9 +81,11 @@
 
             t_start = time.perf_counter()
             ser = pd.Series(query1.numpy(), dtype=hash_map)
-            ser.cat.codes.to_numpy()
+            out3 = ser.cat.codes.to_numpy()
             if i >= num_warmups:
                 t_get += time.perf_counter() - t_start
 
         print(f' Pandas Init: {t_init / num_steps:.4f}s')
         print(f'  Pandas Get: {t_get / num_steps:.4f}s')
+
+        assert out1.equal(torch.tensor(out3))

From 9701e15804c379907cc8d1bcc22019cf26b868f5 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 07:35:37 +0000
Subject: [PATCH 06/11] update

---
 CMakeLists.txt | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 10afda777..c2d20125e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -72,6 +72,7 @@
 if (NOT "$ENV{EXTERNAL_PHMAP_INCLUDE_DIR}" STREQUAL "")
   include_directories($ENV{EXTERNAL_PHMAP_INCLUDE_DIR})
 else()
   set(PHMAP_DIR third_party/parallel-hashmap)
+  include_directories(${PHMAP_DIR})
   target_include_directories(${PROJECT_NAME} PRIVATE ${PHMAP_DIR})
 endif()

From a32c79434941f14c4a8c98879782ee7a4944136f Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 12:32:22 +0000
Subject: [PATCH 07/11] update

---
 pyg_lib/csrc/classes/cpu/hash_map.cpp | 24 ++++++++++++++++++++----
 pyg_lib/csrc/classes/cpu/hash_map.h   | 20 +++++++++++++++++---
 test/csrc/classes/test_hash_map.cpp   |  2 +-
 3 files changed, 38 insertions(+), 8 deletions(-)

diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp
index e7a8bb1cc..12f8443f8 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.cpp
+++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -7,7 +7,7 @@ namespace pyg {
 namespace classes {
 
 template <typename KeyType>
-CPUHashMap<KeyType>::CPUHashMap(const at::Tensor& key) {
+CPUHashMapImpl<KeyType>::CPUHashMapImpl(const at::Tensor& key) {
   at::TensorArg key_arg{key, "key", 0};
   at::CheckedFrom c{"HashMap.init"};
   at::checkDeviceType(c, key, at::DeviceType::CPU);
@@ -30,7 +30,7 @@ CPUHashMap<KeyType>::CPUHashMap(const at::Tensor& key) {
 };
 
 template <typename KeyType>
-at::Tensor CPUHashMap<KeyType>::get(const at::Tensor& query) {
+at::Tensor CPUHashMapImpl<KeyType>::get(const at::Tensor& query) {
   at::TensorArg query_arg{query, "query", 0};
   at::CheckedFrom c{"HashMap.get"};
   at::checkDeviceType(c, query, at::DeviceType::CPU);
@@ -57,10 +57,26 @@ at::Tensor CPUHashMap<KeyType>::get(const at::Tensor& query) {
   return out;
 }
 
+CPUHashMap::CPUHashMap(const at::Tensor& key) {
+  map_ = std::make_unique<CPUHashMapImpl<int64_t>>(key);
+  // clang-format off
+  AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
+                            key.scalar_type(),
+                            "cpu_hash_map_init",
+                            [&] {
+    map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
+  });
+  // clang-format on
+}
+
+at::Tensor CPUHashMap::get(const at::Tensor& query) {
+  return map_->get(query);
+}
+
 TORCH_LIBRARY(pyg, m) {
-  m.class_<CPUHashMap<int64_t>>("CPUHashMap")
+  m.class_<CPUHashMap>("CPUHashMap")
       .def(torch::init<at::Tensor&>())
-      .def("get", &CPUHashMap<int64_t>::get);
+      .def("get", &CPUHashMap::get);
 }
 
 } // namespace classes
diff --git a/pyg_lib/csrc/classes/cpu/hash_map.h b/pyg_lib/csrc/classes/cpu/hash_map.h
index 6540acc87..468421193 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.h
+++ b/pyg_lib/csrc/classes/cpu/hash_map.h
@@ -6,13 +6,18 @@
 namespace pyg {
 namespace classes {
 
+struct IHashMap {
+  virtual ~IHashMap() = default;
+  virtual at::Tensor get(const at::Tensor& query) = 0;
+};
+
 template <typename KeyType>
-struct CPUHashMap : torch::CustomClassHolder {
+struct CPUHashMapImpl : IHashMap {
  public:
   using ValueType = int64_t;
 
-  CPUHashMap(const at::Tensor& key);
-  at::Tensor get(const at::Tensor& query);
+  CPUHashMapImpl(const at::Tensor& key);
+  at::Tensor get(const at::Tensor& query) override;
 
  private:
   phmap::parallel_flat_hash_map<
@@ -26,5 +31,14 @@
       map_;
 };
 
+struct CPUHashMap : torch::CustomClassHolder {
+ public:
+  CPUHashMap(const at::Tensor& key);
+  at::Tensor get(const at::Tensor& query);
+
+ private:
+  std::unique_ptr<IHashMap> map_;
+};
+
 } // namespace classes
 } // namespace pyg
diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp
index 4705b3a84..a54400b50 100644
--- a/test/csrc/classes/test_hash_map.cpp
+++ b/test/csrc/classes/test_hash_map.cpp
@@ -7,7 +7,7 @@ TEST(CPUHashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
 
   auto key = at::tensor({0, 10, 30, 20}, options);
-  auto map = pyg::classes::CPUHashMap<int64_t>(key);
+  auto map = pyg::classes::CPUHashMap(key);
 
   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);

From 540f03e67f559bdbc8d393009c3b86436d421329 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 12:54:37 +0000
Subject: [PATCH 08/11] update

---
 pyg_lib/csrc/classes/cpu/hash_map.cpp | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp
index 12f8443f8..66b2977b9 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.cpp
+++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -58,7 +58,6 @@ at::Tensor CPUHashMapImpl<KeyType>::get(const at::Tensor& query) {
 }
 
 CPUHashMap::CPUHashMap(const at::Tensor& key) {
-  map_ = std::make_unique<CPUHashMapImpl<int64_t>>(key);
   // clang-format off
   AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
                             key.scalar_type(),

From cc52803cc826a3b605ca4c781bac8a4af15a1359 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 13:21:15 +0000
Subject: [PATCH 09/11] update

---
 pyg_lib/csrc/classes/cpu/hash_map.h | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyg_lib/csrc/classes/cpu/hash_map.h b/pyg_lib/csrc/classes/cpu/hash_map.h
index 468421193..a19081643 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.h
+++ b/pyg_lib/csrc/classes/cpu/hash_map.h
@@ -26,8 +26,8 @@ struct CPUHashMapImpl : IHashMap {
       phmap::priv::hash_default_hash<KeyType>,
       phmap::priv::hash_default_eq<KeyType>,
       phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
-      8,
-      phmap::NullMutex>
+      12,
+      std::mutex>
       map_;
 };

From 300cbd1a8e6e35fe22d60244157e61c597decb40 Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 13:26:31 +0000
Subject: [PATCH 10/11] update

---
 pyg_lib/csrc/classes/cpu/hash_map.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyg_lib/csrc/classes/cpu/hash_map.cpp b/pyg_lib/csrc/classes/cpu/hash_map.cpp
index 063952b3f..66b2977b9 100644
--- a/pyg_lib/csrc/classes/cpu/hash_map.cpp
+++ b/pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -73,9 +73,9 @@ at::Tensor CPUHashMap::get(const at::Tensor& query) {
 }
 
 TORCH_LIBRARY(pyg, m) {
-  m.class_<CPUHashMap<int64_t>>("CPUHashMap")
+  m.class_<CPUHashMap>("CPUHashMap")
       .def(torch::init<at::Tensor&>())
-      .def("get", &CPUHashMap<int64_t>::get);
+      .def("get", &CPUHashMap::get);
 }
 
 } // namespace classes

From 94bba2d752b2d6c96c83198f42ef15bb7b017d7d Mon Sep 17 00:00:00 2001
From: rusty1s
Date: Mon, 13 Jan 2025 13:56:00 +0000
Subject: [PATCH 11/11] update

---
 test/csrc/classes/test_hash_map.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/csrc/classes/test_hash_map.cpp b/test/csrc/classes/test_hash_map.cpp
index 4705b3a84..a54400b50 100644
--- a/test/csrc/classes/test_hash_map.cpp
+++ b/test/csrc/classes/test_hash_map.cpp
@@ -7,7 +7,7 @@ TEST(CPUHashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
 
   auto key = at::tensor({0, 10, 30, 20}, options);
-  auto map = pyg::classes::CPUHashMap<int64_t>(key);
+  auto map = pyg::classes::CPUHashMap(key);
 
   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);
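
Note (not part of the patch series): a minimal pure-PyTorch sketch of the key-to-index semantics the patches above implement and benchmark. HashMap(key) records each key's position and get(query) returns those positions, with -1 for keys that are absent, matching the expected tensor {2, 1, 3, -1} checked in test_hash_map.cpp. The helper name reference_get is illustrative only, not an API from pyg_lib.

import torch


def reference_get(key: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
    # Build a key -> position lookup, then resolve each query; misses map to -1.
    lookup = {int(k): i for i, k in enumerate(key.tolist())}
    return torch.tensor([lookup.get(int(q), -1) for q in query.tolist()])


key = torch.tensor([0, 10, 30, 20])
query = torch.tensor([30, 10, 20, 40])
print(reference_get(key, query))  # tensor([ 2,  1,  3, -1])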