Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .github/workflows/cuda/Linux-env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,42 +4,42 @@ case ${1} in
cu124)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-12.4/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
;;
cu121)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-12.1/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
;;
cu118)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-11.8/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6;9.0"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6;9.0"
;;
cu117)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-11.7/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
;;
cu116)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-11.6/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
;;
cu115)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-11.5/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
;;
cu113)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-11.3/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5;8.0;8.6"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5;8.0;8.6"
;;
cu102)
export FORCE_CUDA=1
export PATH=/usr/local/cuda-10.2/bin:${PATH}
export TORCH_CUDA_ARCH_LIST="3.5;5.0+PTX;6.0;7.0;7.5"
export TORCH_CUDA_ARCH_LIST="6.0+PTX;7.0;7.5"
;;
*)
;;
Expand Down
3 changes: 2 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.15)
cmake_minimum_required(VERSION 3.18)
project(pyg)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
Expand Down Expand Up @@ -43,6 +43,7 @@ if(WITH_CUDA)
enable_language(CUDA)
add_definitions(-DWITH_CUDA)
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr -allow-unsupported-compiler")
set(CMAKE_CUDA_ARCHITECTURES "60;70;75;80;86;90")

if (NOT "$ENV{EXTERNAL_CUTLASS_INCLUDE_DIR}" STREQUAL "")
include_directories($ENV{EXTERNAL_CUTLASS_INCLUDE_DIR})
Expand Down
112 changes: 91 additions & 21 deletions pyg_lib/csrc/classes/cuda/hash_map_impl.cu
Original file line number Diff line number Diff line change
@@ -1,69 +1,139 @@
#include <ATen/ATen.h>
#include <torch/library.h>
#include <cuco/static_map.cuh>

#include "../hash_map_impl.h"
#include <limits>

namespace pyg {
namespace classes {

namespace {

#define DISPATCH_CASE_KEY(...) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define DISPATCH_KEY(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, DISPATCH_CASE_KEY(__VA_ARGS__))

struct HashMapImpl {
virtual ~HashMapImpl() = default;
virtual at::Tensor get(const at::Tensor& query) = 0;
virtual at::Tensor keys() = 0;
};

template <typename KeyType>
struct CUDAHashMapImpl : HashMapImpl {
public:
using ValueType = int64_t;

CUDAHashMapImpl(const at::Tensor& key) {
KeyType constexpr empty_key_sentinel = -1; // TODO
KeyType constexpr empty_key_sentinel = std::numeric_limits<KeyType>::min();
ValueType constexpr empty_value_sentinel = -1;

map_ = std::make_unique<cuco::static_map<KeyType, ValueType>>(
2 * key.numel(), // loader_factor = 0.5
2 * key.numel(), // load_factor = 0.5
cuco::empty_key{empty_key_sentinel},
cuco::empty_value{empty_value_sentinel});

const auto key_data = key.data_ptr<KeyType>();
const auto options =
at::TensorOptions().device(key.device()).dtype(at::kLong);
key.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
const auto value = at::arange(key.numel(), options);
const auto key_data = key.data_ptr<KeyType>();
const auto value_data = value.data_ptr<ValueType>();
const auto zipped =
thrust::make_zip_iterator(thrust::make_tuple(key_data, value_data));

map_->insert(key_data, value_data, key.numel());
map_->insert(zipped, zipped + key.numel());
}

at::Tensor get(const at::Tensor& query) override {
const auto options =
at::TensorOptions().device(query.device()).dtype(at::kLong);
query.options().dtype(c10::CppTypeToScalarType<ValueType>::value);
const auto out = at::empty({query.numel()}, options);
const auto query_data = query.data_ptr<KeyType>();
auto out_data = out.data_ptr<int64_t>();
const auto out_data = out.data_ptr<ValueType>();

map_->find(query_data, out_data, query.numel());
map_->find(query_data, query_data + query.numel(), out_data);

return out;
}

at::Tensor keys() override {
// TODO This will not work in multi-GPU scenarios.
const auto options = at::TensorOptions()
.device(at::DeviceType::CUDA)
.dtype(c10::CppTypeToScalarType<ValueType>::value);
const auto size = static_cast<int64_t>(map_->size());

at::Tensor key;
if (std::is_same<KeyType, int16_t>::value) {
key = at::empty({size}, options.dtype(at::kShort));
} else if (std::is_same<KeyType, int32_t>::value) {
key = at::empty({size}, options.dtype(at::kInt));
} else {
key = at::empty({size}, options);
}
const auto value = at::empty({size}, options);
const auto key_data = key.data_ptr<KeyType>();
const auto value_data = value.data_ptr<ValueType>();

map_->retrieve_all(key_data, value_data);

return key.index_select(0, value.argsort());
}

private:
std::unique_ptr<cuco::static_map<KeyType, ValueType>> map_;
};

// template struct CUDAHashMapImpl<bool>;
// template struct CUDAHashMapImpl<uint8_t>;
// template struct CUDAHashMapImpl<int8_t>;
// template struct CUDAHashMapImpl<int16_t>;
// template struct CUDAHashMapImpl<int32_t>;
// template struct CUDAHashMapImpl<int64_t>;
// template struct CUDAHashMapImpl<float>;
// template struct CUDAHashMapImpl<double>;

struct CUDAHashMap : torch::CustomClassHolder {
public:
CUDAHashMap(const at::Tensor& key) {}
CUDAHashMap(const at::Tensor& key) {
at::TensorArg key_arg{key, "key", 0};
at::CheckedFrom c{"CUDAHashMap.init"};
at::checkDeviceType(c, key, at::DeviceType::CUDA);
at::checkDim(c, key_arg, 1);
at::checkContiguous(c, key_arg);

DISPATCH_KEY(key.scalar_type(), "cuda_hash_map_init", [&] {
map_ = std::make_unique<CUDAHashMapImpl<scalar_t>>(key);
});
}

at::Tensor get(const at::Tensor& query) {
at::TensorArg query_arg{query, "query", 0};
at::CheckedFrom c{"CUDAHashMap.get"};
at::checkDeviceType(c, query, at::DeviceType::CUDA);
at::checkDim(c, query_arg, 1);
at::checkContiguous(c, query_arg);

at::Tensor get(const at::Tensor& query) { return query; }
return map_->get(query);
}

at::Tensor keys() { return map_->keys(); }

private:
std::unique_ptr<HashMapImpl> map_;
};

} // namespace

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.class_<CUDAHashMap>("CUDAHashMap")
.def(torch::init<at::Tensor&>())
.def("get", &CUDAHashMap::get)
.def("keys", &CUDAHashMap::keys)
.def_pickle(
// __getstate__
[](const c10::intrusive_ptr<CUDAHashMap>& self) -> at::Tensor {
return self->keys();
},
// __setstate__
[](const at::Tensor& state) -> c10::intrusive_ptr<CUDAHashMap> {
return c10::make_intrusive<CUDAHashMap>(state);
});
}

} // namespace classes
} // namespace pyg
Loading