From 01a91970c2e9ba9484d034d46e6cf5a54b789c0c Mon Sep 17 00:00:00 2001 From: rusty1s Date: Tue, 28 Jan 2025 11:49:23 +0000 Subject: [PATCH] update --- pyg_lib/csrc/classes/cuda/hash_map_impl.cu | 69 ++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 pyg_lib/csrc/classes/cuda/hash_map_impl.cu diff --git a/pyg_lib/csrc/classes/cuda/hash_map_impl.cu b/pyg_lib/csrc/classes/cuda/hash_map_impl.cu new file mode 100644 index 000000000..101cdaf05 --- /dev/null +++ b/pyg_lib/csrc/classes/cuda/hash_map_impl.cu @@ -0,0 +1,69 @@ +#include +#include + +#include "../hash_map_impl.h" + +namespace pyg { +namespace classes { + +namespace { + +template +struct CUDAHashMapImpl : HashMapImpl { + public: + using ValueType = int64_t; + + CUDAHashMapImpl(const at::Tensor& key) { + KeyType constexpr empty_key_sentinel = -1; // TODO + ValueType constexpr empty_value_sentinel = -1; + + map_ = std::make_unique>( + 2 * key.numel(), // loader_factor = 0.5 + cuco::empty_key{empty_key_sentinel}, + cuco::empty_value{empty_value_sentinel}); + + const auto options = + at::TensorOptions().device(key.device()).dtype(at::kLong); + const auto value = at::arange(key.numel(), options); + const auto key_data = key.data_ptr(); + const auto value_data = value.data_ptr(); + + map_->insert(key_data, value_data, key.numel()); + } + + at::Tensor get(const at::Tensor& query) override { + const auto options = + at::TensorOptions().device(query.device()).dtype(at::kLong); + const auto out = at::empty({query.numel()}, options); + const auto query_data = query.data_ptr(); + auto out_data = out.data_ptr(); + + map_->find(query_data, out_data, query.numel()); + + return out; + } + + private: + std::unique_ptr> map_; +}; + +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct CUDAHashMapImpl; +// template struct
CUDAHashMapImpl; + +struct CUDAHashMap : torch::CustomClassHolder { + public: + CUDAHashMap(const at::Tensor& key) {} + + at::Tensor get(const at::Tensor& query) { return query; } +}; + +} // namespace + +} // namespace classes +} // namespace pyg
// NOTE(review): Purpose: the patch above adds
// pyg_lib/csrc/classes/cuda/hash_map_impl.cu. That file defines a cuco-backed map
// from each entry of a `key` tensor to its position (values come from
// at::arange(key.numel())), with a `get(query)` lookup returning an int64 tensor.
//
// NOTE(review): this copy of the patch appears to be corrupted and will neither
// apply (`git am`) nor compile as-is:
//   * All diff lines were flattened onto two physical lines. The first `//`
//     comment on a line (e.g. `// TODO`) therefore comments out everything
//     after it.
//   * Content inside angle brackets seems to have been stripped:
//     - the `#include` directives have no header names;
//     - `template` has no parameter list, so `KeyType` is never declared;
//     - `std::make_unique>(`, `std::unique_ptr> map_`, `HashMapImpl` and the
//       `data_ptr()` calls are missing their template arguments.
//   Regenerate the patch from the original commit rather than hand-repairing
//   this text.
//
// NOTE(review): issues visible in the patched code itself:
//   * CUDAHashMap is a stub. Its constructor ignores `key`, and `get` returns
//     `query` unchanged instead of delegating to CUDAHashMapImpl.
//   * Every explicit instantiation of CUDAHashMapImpl is commented out.
//   * The empty-key sentinel is -1 and is marked TODO. Presumably a real key
//     equal to -1 cannot be stored while -1 is the sentinel; confirm against
//     the cuco documentation.
//   * The capacity comment reads "loader_factor = 0.5". It presumably means
//     "load factor", since capacity is requested as 2 * key.numel().
//   * With an empty `key`, the requested capacity is 0; check that cuco
//     accepts this.
//   * `map_->insert(key_data, value_data, key.numel())` and
//     `map_->find(query_data, out_data, query.numel())` pass
//     (pointer, pointer, count). Verify this matches the cuco static_map bulk
//     API, which (to be confirmed) takes iterator ranges and optionally a CUDA
//     stream.