59 changes: 39 additions & 20 deletions pyg_lib/csrc/classes/cpu/hash_map.cpp
@@ -20,6 +20,8 @@ struct HashMapImpl {
virtual ~HashMapImpl() = default;
virtual at::Tensor get(const at::Tensor& query) = 0;
virtual at::Tensor keys() = 0;
virtual int64_t size() = 0;
virtual at::ScalarType dtype() = 0;
};

template <typename KeyType>
@@ -60,16 +62,8 @@ struct CPUHashMapImpl : HashMapImpl {
}

at::Tensor keys() override {
const auto size = static_cast<int64_t>(map_.size());

at::Tensor key;
if (std::is_same<KeyType, int16_t>::value) {
key = at::empty({size}, at::TensorOptions().dtype(at::kShort));
} else if (std::is_same<KeyType, int32_t>::value) {
key = at::empty({size}, at::TensorOptions().dtype(at::kInt));
} else {
key = at::empty({size}, at::TensorOptions().dtype(at::kLong));
}
const at::Tensor key =
at::empty({size()}, at::TensorOptions().dtype(dtype()));
const auto key_data = key.data_ptr<KeyType>();

for (const auto& pair : map_) { // No efficient multi-threading possible :(
@@ -79,6 +73,18 @@ struct CPUHashMapImpl : HashMapImpl {
return key;
}

int64_t size() override { return static_cast<int64_t>(map_.size()); }

at::ScalarType dtype() override {
if (std::is_same<KeyType, int16_t>::value) {
return at::kShort;
} else if (std::is_same<KeyType, int32_t>::value) {
return at::kInt;
} else {
return at::kLong;
}
}

private:
phmap::flat_hash_map<KeyType, ValueType> map_;
};
@@ -129,16 +135,8 @@ struct ParallelCPUHashMapImpl : HashMapImpl {
}

at::Tensor keys() override {
const auto size = static_cast<int64_t>(map_.size());

at::Tensor key;
if (std::is_same<KeyType, int16_t>::value) {
key = at::empty({size}, at::TensorOptions().dtype(at::kShort));
} else if (std::is_same<KeyType, int32_t>::value) {
key = at::empty({size}, at::TensorOptions().dtype(at::kInt));
} else {
key = at::empty({size}, at::TensorOptions().dtype(at::kLong));
}
const at::Tensor key =
at::empty({size()}, at::TensorOptions().dtype(dtype()));
const auto key_data = key.data_ptr<KeyType>();

for (const auto& pair : map_) { // No efficient multi-threading possible :(
@@ -148,6 +146,18 @@ struct ParallelCPUHashMapImpl : HashMapImpl {
return key;
}

int64_t size() override { return static_cast<int64_t>(map_.size()); }

at::ScalarType dtype() override {
if (std::is_same<KeyType, int16_t>::value) {
return at::kShort;
} else if (std::is_same<KeyType, int32_t>::value) {
return at::kInt;
} else {
return at::kLong;
}
}

private:
phmap::parallel_flat_hash_map<
KeyType,
@@ -234,6 +244,12 @@ struct CPUHashMap : torch::CustomClassHolder {

at::Tensor keys() { return map_->keys(); }

int64_t size() { return map_->size(); }

at::ScalarType dtype() { return map_->dtype(); }

at::Device device() { return at::Device(at::kCPU); }

private:
std::unique_ptr<HashMapImpl> map_;
};
@@ -245,6 +261,9 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
.def(torch::init<at::Tensor&, int64_t>())
.def("get", &CPUHashMap::get)
.def("keys", &CPUHashMap::keys)
.def("size", &CPUHashMap::size)
.def("dtype", &CPUHashMap::dtype)
.def("device", &CPUHashMap::device)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CPUHashMap>& self) -> at::Tensor {
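A minimal Python usage sketch of the new CPU accessors, assuming that importing `pyg_lib` registers the `torch.classes.pyg.CPUHashMap` custom class and that `0` is a reasonable value for the second `int64_t` constructor argument in the binding above:

```python
import torch
import pyg_lib  # noqa: F401  # assumed to register torch.classes.pyg.CPUHashMap

key = torch.tensor([0, 10, 30, 20])
# Two-argument construction mirrors torch::init<at::Tensor&, int64_t>() above;
# the value 0 for the second argument is an assumption in this sketch.
hash_map = torch.classes.pyg.CPUHashMap(key, 0)

assert hash_map.size() == 4
assert hash_map.dtype() == 4                     # at::kLong arrives as its enum code
assert hash_map.device() == torch.device('cpu')
assert hash_map.keys().equal(key)
```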
54 changes: 39 additions & 15 deletions pyg_lib/csrc/classes/cuda/hash_map.cu
@@ -1,4 +1,5 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/library.h>
#include <limits>

@@ -23,6 +24,9 @@ struct HashMapImpl {
virtual ~HashMapImpl() = default;
virtual at::Tensor get(const at::Tensor& query) = 0;
virtual at::Tensor keys() = 0;
virtual int64_t size() = 0;
virtual at::ScalarType dtype() = 0;
virtual at::Device device() = 0;
};

#ifndef _WIN32
@@ -31,7 +35,10 @@ struct CUDAHashMapImpl : HashMapImpl {
public:
using ValueType = int64_t;

CUDAHashMapImpl(const at::Tensor& key, double load_factor) {
CUDAHashMapImpl(const at::Tensor& key, double load_factor)
: device_(key.device()) {
c10::cuda::MaybeSetDevice(key.get_device());

KeyType constexpr empty_key_sentinel = std::numeric_limits<KeyType>::min();
ValueType constexpr empty_value_sentinel = -1;

@@ -52,6 +59,8 @@ struct CUDAHashMapImpl : HashMapImpl {
}

at::Tensor get(const at::Tensor& query) override {
c10::cuda::MaybeSetDevice(query.get_device());

const auto options =
query.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
const auto out = at::empty({query.numel()}, options);
@@ -64,21 +73,12 @@ struct CUDAHashMapImpl : HashMapImpl {
}

at::Tensor keys() override {
// TODO This will not work in multi-GPU scenarios.
const auto options = at::TensorOptions()
.device(at::DeviceType::CUDA)
.dtype(c10::CppTypeToScalarType<ValueType>::value);
const auto size = static_cast<int64_t>(map_->size());
c10::cuda::MaybeSetDevice(device_.index());

at::Tensor key;
if (std::is_same<KeyType, int16_t>::value) {
key = at::empty({size}, options.dtype(at::kShort));
} else if (std::is_same<KeyType, int32_t>::value) {
key = at::empty({size}, options.dtype(at::kInt));
} else {
key = at::empty({size}, options);
}
const auto value = at::empty({size}, options);
const auto options = at::TensorOptions().device(device_);
const at::Tensor key = at::empty({size()}, options.dtype(dtype()));
const at::Tensor value = at::empty(
{size()}, options.dtype(c10::CppTypeToScalarType<ValueType>::value));
const auto key_data = key.data_ptr<KeyType>();
const auto value_data = value.data_ptr<ValueType>();

@@ -90,8 +90,23 @@ struct CUDAHashMapImpl : HashMapImpl {
return key.index_select(0, perm);
}

int64_t size() override { return static_cast<int64_t>(map_->size()); }

at::ScalarType dtype() override {
if (std::is_same<KeyType, int16_t>::value) {
return at::kShort;
} else if (std::is_same<KeyType, int32_t>::value) {
return at::kInt;
} else {
return at::kLong;
}
}

at::Device device() override { return device_; }

private:
std::unique_ptr<cuco::static_map<KeyType, ValueType>> map_;
at::Device device_;
};
#endif

@@ -135,6 +150,12 @@ struct CUDAHashMap : torch::CustomClassHolder {
#endif
}

int64_t size() { return map_->size(); }

at::ScalarType dtype() { return map_->dtype(); }

at::Device device() { return map_->device(); }

private:
#ifndef _WIN32
std::unique_ptr<HashMapImpl> map_;
@@ -148,6 +169,9 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) {
.def(torch::init<at::Tensor&, double>())
.def("get", &CUDAHashMap::get)
.def("keys", &CUDAHashMap::keys)
.def("size", &CUDAHashMap::size)
.def("dtype", &CUDAHashMap::dtype)
.def("device", &CUDAHashMap::device)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CUDAHashMap>& self) -> at::Tensor {
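The CUDA implementation now remembers the device of its key tensor and calls `c10::cuda::MaybeSetDevice` before allocating or launching work, so a map built on a non-default GPU remains usable. A hedged sketch, assuming that importing `pyg_lib` registers `torch.classes.pyg.CUDAHashMap` and that `0.5` is an acceptable load factor for the `torch::init<at::Tensor&, double>()` binding above:

```python
import torch
import pyg_lib  # noqa: F401  # assumed to register torch.classes.pyg.CUDAHashMap

if torch.cuda.device_count() >= 2:
    key = torch.tensor([0, 10, 30, 20], device='cuda:1')
    hash_map = torch.classes.pyg.CUDAHashMap(key, 0.5)  # load factor value assumed

    assert hash_map.device() == torch.device('cuda:1')
    # get() switches to the query tensor's device before launching the lookup:
    out = hash_map.get(torch.tensor([30, 10, 20, 40], device='cuda:1'))
    # out == tensor([2, 1, 3, -1], device='cuda:1'), matching the Python test below
```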
9 changes: 9 additions & 0 deletions test/classes/test_hash_map.py
@@ -6,6 +6,12 @@

from pyg_lib.testing import withCUDA

INT_TO_DTYPE = {
2: torch.short,
3: torch.int,
4: torch.long,
}


@withCUDA
@pytest.mark.parametrize('dtype', [torch.short, torch.int, torch.long])
@@ -22,6 +28,9 @@ def test_hash_map(dtype, device):
else:
raise NotImplementedError(f"Unsupported device '{device}'")

assert hash_map.size() == 4
assert INT_TO_DTYPE[hash_map.dtype()] == dtype
assert hash_map.device() == device
assert hash_map.keys().equal(key)
assert hash_map.keys().dtype == dtype
expected = torch.tensor([2, 1, 3, -1], device=device)
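For context on the integer keys in the integer-to-dtype mapping above: the custom-class `dtype()` method surfaces in Python as the raw `at::ScalarType` enum value, and in the c10 enum `Short`, `Int`, and `Long` have codes 2, 3, and 4. A small sketch of that correspondence, with the same caveat as above about the assumed second constructor argument:

```python
import torch
import pyg_lib  # noqa: F401  # assumed to register torch.classes.pyg.CPUHashMap

# at::ScalarType enum codes: Short == 2, Int == 3, Long == 4.
for code, dtype in {2: torch.short, 3: torch.int, 4: torch.long}.items():
    key = torch.arange(4, dtype=dtype)
    hash_map = torch.classes.pyg.CPUHashMap(key, 0)  # second argument value assumed
    assert hash_map.dtype() == code
```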
3 changes: 3 additions & 0 deletions test/csrc/classes/test_hash_map.cpp
@@ -8,6 +8,9 @@ TEST(HashMapTest, BasicAssertions) {
auto key = at::tensor({0, 10, 30, 20}, options);

auto map = pyg::classes::CPUHashMap(key);
EXPECT_EQ(map.size(), 4);
EXPECT_EQ(map.dtype(), at::kLong);
EXPECT_EQ(map.device(), at::Device(at::kCPU));
EXPECT_TRUE(at::equal(map.keys(), key));

auto query = at::tensor({30, 10, 20, 40}, options);