1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -61,7 +61,6 @@ file(GLOB_RECURSE ALL_SOURCES ${CSRC}/*.cpp)
 if (WITH_CUDA)
   file(GLOB_RECURSE ALL_SOURCES ${ALL_SOURCES} ${CSRC}/*.cu)
 endif()
-file(GLOB_RECURSE ALL_HEADERS ${CSRC}/*.h)
 add_library(${PROJECT_NAME} SHARED ${ALL_SOURCES})
 target_include_directories(${PROJECT_NAME} PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}")
 if(MKL_INCLUDE_FOUND)
2 changes: 1 addition & 1 deletion pyg_lib/classes/__init__.py
@@ -4,7 +4,7 @@

 class HashMap:
     def __init__(self, key: Tensor) -> Tensor:
-        self._map = torch.classes.pyg.CPUHashMap(key)
+        self._map = torch.classes.pyg.HashMap(key)

     def get(self, query: Tensor) -> Tensor:
         return self._map.get(query)
82 changes: 0 additions & 82 deletions pyg_lib/csrc/classes/cpu/hash_map.cpp

This file was deleted.

44 changes: 0 additions & 44 deletions pyg_lib/csrc/classes/cpu/hash_map.h

This file was deleted.

67 changes: 67 additions & 0 deletions pyg_lib/csrc/classes/cpu/hash_map_impl.h
@@ -0,0 +1,67 @@
#pragma once

#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include "../hash_map_impl.h"
#include "parallel_hashmap/phmap.h"

namespace pyg {
namespace classes {

template <typename KeyType>
struct CPUHashMapImpl : HashMapImpl {
 public:
  using ValueType = int64_t;

  CPUHashMapImpl(const at::Tensor& key) {
    map_.reserve(key.numel());

    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((key.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);
    const auto key_data = key.data_ptr<KeyType>();

    at::parallel_for(0, key.numel(), grain_size, [&](int64_t beg, int64_t end) {
      for (int64_t i = beg; i < end; ++i) {
        auto [iterator, inserted] = map_.insert({key_data[i], i});
        TORCH_CHECK(inserted, "Found duplicated key in 'HashMap'.");
      }
    });
  }

  at::Tensor get(const at::Tensor& query) override {
    const auto options = at::TensorOptions().dtype(at::kLong);
    const auto out = at::empty({query.numel()}, options);
    auto out_data = out.data_ptr<int64_t>();

    const auto num_threads = at::get_num_threads();
    const auto grain_size =
        std::max((query.numel() + num_threads - 1) / num_threads,
                 at::internal::GRAIN_SIZE);
    const auto query_data = query.data_ptr<KeyType>();

    at::parallel_for(0, query.numel(), grain_size, [&](int64_t b, int64_t e) {
      for (int64_t i = b; i < e; ++i) {
        auto it = map_.find(query_data[i]);
        out_data[i] = (it != map_.end()) ? it->second : -1;
      }
    });

    return out;
  }

 private:
  phmap::parallel_flat_hash_map<
      KeyType,
      ValueType,
      phmap::priv::hash_default_hash<KeyType>,
      phmap::priv::hash_default_eq<KeyType>,
      phmap::priv::Allocator<std::pair<const KeyType, ValueType>>,
      12,
      std::mutex>
      map_;
};

}  // namespace classes
}  // namespace pyg
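Note: the two non-default template arguments on the map above are what make the concurrent writes in the constructor safe. The container is sharded into 2^12 submaps and each submap is guarded by its own std::mutex, so at::parallel_for workers only contend when they hit the same shard. A minimal standalone sketch of the same pattern, independent of ATen (shard count and names here are illustrative, not from this PR):

#include <cstdint>
#include <mutex>
#include <thread>
#include <utility>
#include <vector>
#include "parallel_hashmap/phmap.h"

int main() {
  // Same template parameters as in CPUHashMapImpl, but with 2^4 submaps;
  // the per-submap std::mutex makes insert() safe across threads:
  phmap::parallel_flat_hash_map<
      int64_t, int64_t,
      phmap::priv::hash_default_hash<int64_t>,
      phmap::priv::hash_default_eq<int64_t>,
      phmap::priv::Allocator<std::pair<const int64_t, int64_t>>,
      4, std::mutex>
      map;

  std::vector<std::thread> threads;
  for (int64_t t = 0; t < 4; ++t) {
    threads.emplace_back([&map, t] {
      for (int64_t i = t; i < 1000; i += 4)
        map.insert({i, i});  // Locks only the submap that owns key i.
    });
  }
  for (auto& thread : threads)
    thread.join();
  return map.size() == 1000 ? 0 : 1;  // All 1000 inserts must land.
}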
47 changes: 47 additions & 0 deletions pyg_lib/csrc/classes/hash_map.cpp
@@ -0,0 +1,47 @@
#include "hash_map.h"

#include <torch/library.h>
#include "cpu/hash_map_impl.h"

namespace pyg {
namespace classes {

HashMap::HashMap(const at::Tensor& key) {
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"HashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CPU);
at::checkDim(c, key_arg, 1);
at::checkContiguous(c, key_arg);

// clang-format off
AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool,
key.scalar_type(),
"hash_map_init",
[&] {
/* if (key.is_cpu) { */
map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
/* } else { */
/* AT_ERROR("Received invalid device type for 'HashMap'."); */
/* } */
});
// clang-format on
}

at::Tensor HashMap::get(const at::Tensor& query) {
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"HashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CPU);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

return map_->get(query);
}

TORCH_LIBRARY(pyg, m) {
m.class_<HashMap>("HashMap")
.def(torch::init<at::Tensor&>())
.def("get", &HashMap::get);
}

} // namespace classes
} // namespace pyg
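Note: AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Bool, ...) selects the key's scalar type at runtime and instantiates CPUHashMapImpl for the matching C++ type, binding it to scalar_t inside the lambda. Conceptually, the macro call above expands to a switch of this shape (a simplified sketch showing two cases only; the real macro covers all standard dtypes plus Bool):

switch (key.scalar_type()) {
  case at::ScalarType::Long: {
    using scalar_t = int64_t;
    map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
    break;
  }
  case at::ScalarType::Int: {
    using scalar_t = int32_t;
    map_ = std::make_unique<CPUHashMapImpl<scalar_t>>(key);
    break;
  }
  // ... remaining dtype cases elided ...
  default:
    AT_ERROR("Unsupported dtype for 'HashMap'.");
}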
19 changes: 19 additions & 0 deletions pyg_lib/csrc/classes/hash_map.h
@@ -0,0 +1,19 @@
#pragma once

#include <ATen/ATen.h>
#include "hash_map_impl.h"

namespace pyg {
namespace classes {

struct HashMap : torch::CustomClassHolder {
 public:
  HashMap(const at::Tensor& key);
  at::Tensor get(const at::Tensor& query);

 private:
  std::unique_ptr<HashMapImpl> map_;
};

}  // namespace classes
}  // namespace pyg
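Note: the public class stays untemplated on purpose. The key dtype chosen by the dispatch in hash_map.cpp is erased behind the virtual HashMapImpl interface, so the TorchScript binding only ever sees a single HashMap type. A minimal sketch of this type-erasure pattern in isolation (all names here are illustrative, not from this PR):

#include <cstdint>
#include <memory>

struct Impl {  // Mirrors HashMapImpl: a dtype-agnostic interface.
  virtual ~Impl() = default;
  virtual int64_t lookup(int64_t query) const = 0;
};

template <typename KeyType>  // Mirrors CPUHashMapImpl<KeyType>.
struct TypedImpl : Impl {
  int64_t lookup(int64_t query) const override {
    // Round-trip through KeyType to emulate a typed backend:
    return static_cast<int64_t>(static_cast<KeyType>(query));
  }
};

struct Frontend {  // Mirrors HashMap: picks the type once, then erases it.
  explicit Frontend(bool narrow_keys) {
    if (narrow_keys)
      impl_ = std::make_unique<TypedImpl<int32_t>>();
    else
      impl_ = std::make_unique<TypedImpl<int64_t>>();
  }
  int64_t lookup(int64_t query) const { return impl_->lookup(query); }

 private:
  std::unique_ptr<Impl> impl_;
};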
14 changes: 14 additions & 0 deletions pyg_lib/csrc/classes/hash_map_impl.h
@@ -0,0 +1,14 @@
#pragma once

#include <ATen/ATen.h>

namespace pyg {
namespace classes {

struct HashMapImpl {
  virtual ~HashMapImpl() = default;
  virtual at::Tensor get(const at::Tensor& query) = 0;
};

}  // namespace classes
}  // namespace pyg
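Note: any backend only has to implement this two-member interface, which is presumably why the device check in hash_map.cpp is left as a commented-out branch. A hypothetical sketch of how a second backend could plug in (CUDAHashMapImpl is illustrative and not part of this PR):

// Hypothetical: a device backend would subclass HashMapImpl and be
// selected in HashMap's constructor based on key.device():
struct CUDAHashMapImpl : HashMapImpl {
  explicit CUDAHashMapImpl(const at::Tensor& key) {
    // A real implementation would build a GPU-resident hash table here.
  }
  at::Tensor get(const at::Tensor& query) override {
    // Placeholder only: probe the device table and return positions.
    const auto options = at::TensorOptions().dtype(at::kLong);
    return at::full({query.numel()}, -1, options);
  }
};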
6 changes: 3 additions & 3 deletions test/csrc/classes/test_hash_map.cpp
@@ -1,13 +1,13 @@
 #include <ATen/ATen.h>
 #include <gtest/gtest.h>

-#include "pyg_lib/csrc/classes/cpu/hash_map.h"
+#include "pyg_lib/csrc/classes/hash_map.h"

-TEST(CPUHashMapTest, BasicAssertions) {
+TEST(HashMapTest, BasicAssertions) {
   auto options = at::TensorOptions().dtype(at::kLong);
   auto key = at::tensor({0, 10, 30, 20}, options);

-  auto map = pyg::classes::CPUHashMap(key);
+  auto map = pyg::classes::HashMap(key);

   auto query = at::tensor({30, 10, 20, 40}, options);
   auto expected = at::tensor({2, 1, 3, -1}, options);
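Note: since the constructor rejects duplicated keys (the TORCH_CHECK in CPUHashMapImpl, which at::parallel_for rethrows as a c10::Error), a follow-up assertion along these lines could cover that path (a sketch, not part of this PR):

// Sketch of a possible additional test: duplicated keys must fail fast.
auto dup_key = at::tensor({0, 10, 10}, options);
EXPECT_THROW(pyg::classes::HashMap(dup_key), c10::Error);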