From fab3e078badc3782cb5eba59dfa05de7c9951740 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Thu, 17 Aug 2023 13:38:42 +0000 Subject: [PATCH 01/13] Included RandintEngine code --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 259 ++++++++++++------- pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 7 +- pyg_lib/csrc/sampler/neighbor.cpp | 19 +- pyg_lib/csrc/sampler/neighbor.h | 7 +- 4 files changed, 192 insertions(+), 100 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 9e5cde4ec..f3ba996ec 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -36,20 +36,39 @@ class NeighborSampler { "No valid temporal strategy found"); } + void multinomial_sample(const node_t global_src_node, + const scalar_t local_src_node, + const at::Tensor& weights, + const int64_t count, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { + const auto row_start = rowptr_[to_scalar_t(global_src_node)]; + const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + at::Tensor weights_neighborhood = weights.index( + {at::indexing::Slice(row_start, row_end)}); + + _sample(global_src_node, local_src_node, row_start, row_end, count, weights_neighborhood, true, + dst_mapper, generator, out_global_dst_nodes); + } + void uniform_sample(const node_t global_src_node, const scalar_t local_src_node, + const at::Tensor& weights, const int64_t count, pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - _sample(global_src_node, local_src_node, row_start, row_end, count, + + _sample(global_src_node, local_src_node, row_start, row_end, count, weights, false, dst_mapper, generator, out_global_dst_nodes); } void temporal_sample(const node_t global_src_node, const scalar_t local_src_node, + const at::Tensor& weights, const int64_t count, const temporal_t seed_time, const temporal_t* time, @@ -58,6 +77,7 @@ class NeighborSampler { std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + auto population = row_end - row_start; // Find new `row_end` such that all neighbors fulfill temporal constraints: auto it = std::upper_bound( @@ -74,7 +94,7 @@ class NeighborSampler { "Found invalid non-sorted temporal neighborhood"); } - _sample(global_src_node, local_src_node, row_start, row_end, count, + _sample(global_src_node, local_src_node, row_start, row_end, count, weights, false, dst_mapper, generator, out_global_dst_nodes); } @@ -114,6 +134,8 @@ class NeighborSampler { const scalar_t row_start, const scalar_t row_end, const int64_t count, + const at::Tensor& weight_vector, + bool multinomial_mode, pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { @@ -131,51 +153,65 @@ class NeighborSampler { add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } - } - - // Case 2: Sample with replacement: - else if (replace) { - if (row_end < (1 << 16)) { - const auto arr = std::move( - generator.generate_range_of_ints(row_start, row_end, count)); - for (const auto edge_id : arr) - add(edge_id, global_src_node, local_src_node, dst_mapper, + } else { + if (multinomial_mode) { + // Multinomial sampling + at::Tensor edges = at::multinomial(weight_vector, count, replace); + + // Add edges to the sampled list one-by-one + for (int i = 0; i < edges.numel(); ++i) { + const auto edge = row_start + edges.index({i}).item(); + add(edge, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); + } // Sampling complete } else { - for (int64_t i = 0; i < count; ++i) { - const auto edge_id = generator(row_start, row_end); - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } - } - - // Case 3: Sample without replacement: - else { - auto index_tracker = IndexTracker(population); - if (population < (1 << 16)) { - const auto arr = - std::move(generator.generate_range_of_ints(0, population, count)); - for (auto i = 0; i < arr.size(); ++i) { - auto rnd = arr[i]; - if (!index_tracker.try_insert(rnd)) { - rnd = population - count + i; - index_tracker.insert(population - count + i); + // Multinomial mode not enabled + + // Case 2: Sample with replacement: + if (replace) { + if (row_end < (1 << 16)) { + const auto arr = std::move( + generator.generate_range_of_ints(row_start, row_end, count)); + for (const auto edge_id : arr) + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } else { + for (int64_t i = 0; i < count; ++i) { + const auto edge_id = generator(row_start, row_end); + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } } - const auto edge_id = row_start + rnd; - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); } - } else { - for (auto i = population - count; i < population; ++i) { - auto rnd = generator(0, i + 1); - if (!index_tracker.try_insert(rnd)) { - rnd = i; - index_tracker.insert(i); + + // Case 3: Sample without replacement: + else { + auto index_tracker = IndexTracker(population); + if (population < (1 << 16)) { + const auto arr = + std::move(generator.generate_range_of_ints(0, population, count)); + for (auto i = 0; i < arr.size(); ++i) { + auto rnd = arr[i]; + if (!index_tracker.try_insert(rnd)) { + rnd = population - count + i; + index_tracker.insert(population - count + i); + } + const auto edge_id = row_start + rnd; + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } + } else { + for (auto i = population - count; i < population; ++i) { + auto rnd = generator(0, i + 1); + if (!index_tracker.try_insert(rnd)) { + rnd = i; + index_tracker.insert(i); + } + const auto edge_id = row_start + rnd; + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } } - const auto edge_id = row_start + rnd; - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); } } } @@ -216,7 +252,7 @@ class NeighborSampler { // Homogeneous neighbor sampling /////////////////////////////////////////////// -template +template std::tuple> sample(const at::Tensor& rowptr, const at::Tensor& col, + const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, @@ -297,17 +334,27 @@ sample(const at::Tensor& rowptr, const auto count = num_neighbors[ell]; sampler.num_sampled_edges_per_hop.push_back(0); if (!time.has_value()) { - for (size_t i = begin; i < end; ++i) { - sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, count, mapper, generator, - /*out_global_dst_nodes=*/sampled_nodes); + if (!multinomial_mode) { + for (size_t i = begin; i < end; ++i) { + sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, weights, count, mapper, + generator, + /*out_global_dst_nodes=*/sampled_nodes); + } + } else { + for (size_t i = begin; i < end; ++i) { + sampler.multinomial_sample(/*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, weights, count, mapper, + generator, + /*out_global_dst_nodes=*/sampled_nodes); + } } } else if constexpr (!std::is_scalar::value) { // Temporal: const auto time_data = time.value().data_ptr(); for (size_t i = begin; i < end; ++i) { const auto batch_idx = sampled_nodes[i].first; sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, count, + /*local_src_node=*/i, weights, count, seed_times[batch_idx], time_data, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); @@ -334,7 +381,7 @@ sample(const at::Tensor& rowptr, // Heterogeneous neighbor sampling ///////////////////////////////////////////// -template +template std::tuple, c10::Dict, c10::Dict, @@ -532,7 +579,7 @@ sample(const std::vector& node_types, for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, count, dst_mapper, generator, + /*local_src_node=*/i, /*PH for weights*/at::ones({1,1}), count, dst_mapper, generator, dst_sampled_nodes); } } else if constexpr (!std::is_scalar< @@ -543,7 +590,7 @@ sample(const std::vector& node_types, const auto batch_idx = src_sampled_nodes[i].first; sampler.temporal_sample( /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, count, seed_times[batch_idx], + /*local_src_node=*/i, /*PH for weights*/at::ones({1,1}), count, seed_times[batch_idx], dst_time_data, dst_mapper, generator, dst_sampled_nodes); } @@ -596,39 +643,72 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, ...) \ - if (replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id) \ - return sample(__VA_ARGS__); +#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, multinomial_mode, ...) \ + if (replace && directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); + } // namespace @@ -640,6 +720,7 @@ std::tuple> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, + const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, @@ -649,8 +730,9 @@ neighbor_sample_kernel(const at::Tensor& rowptr, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + bool return_edge_id, + bool multinomial_mode) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, multinomial_mode, rowptr, col, weights, seed, num_neighbors, time, seed_time, csc, temporal_strategy); } @@ -674,8 +756,9 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, + bool return_edge_id, + bool multinomial_mode) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, multinomial_mode, node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, temporal_strategy); diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 0e0a532f2..e8863a7ad 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -13,6 +13,7 @@ std::tuple> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, + const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, @@ -22,7 +23,8 @@ neighbor_sample_kernel(const at::Tensor& rowptr, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id); + bool return_edge_id, + bool multinomial_mode); std::tuple, c10::Dict, @@ -44,7 +46,8 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id); + bool return_edge_id, + bool multinomial_mode); } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index f0550b78a..185f09a81 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -16,6 +16,7 @@ std::tuple> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, + const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, @@ -25,7 +26,8 @@ neighbor_sample(const at::Tensor& rowptr, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { + bool return_edge_id, + bool multinomial_mode) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg seed_t{seed, "seed", 1}; @@ -37,9 +39,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc, + return op.call(rowptr, col, weights, seed, num_neighbors, time, seed_time, csc, replace, directed, disjoint, temporal_strategy, - return_edge_id); + return_edge_id, multinomial_mode); } std::tuple, @@ -62,7 +64,8 @@ hetero_neighbor_sample( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) { + bool return_edge_id, + bool multinomial_mode) { TORCH_CHECK(rowptr_dict.size() == col_dict.size(), "Number of edge types in 'rowptr_dict' and 'col_dict' must match") @@ -88,15 +91,15 @@ hetero_neighbor_sample( .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, replace, - directed, disjoint, temporal_strategy, return_edge_id); + directed, disjoint, temporal_strategy, return_edge_id, multinomial_mode); } TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " + "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor weights, Tensor seed, int[] " "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True) " + "False, str temporal_strategy = 'uniform', bool return_edge_id = True, bool multinomial_mode = False) " "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " @@ -105,7 +108,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " "= None, bool csc = False, bool replace = False, bool directed = True, " "bool disjoint = False, str temporal_strategy = 'uniform', bool " - "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " + "return_edge_id = True, bool multinomial_mode = False) -> (Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " "Dict(str, int[]))")); } diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 55114450a..cd28a7c0c 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -19,6 +19,7 @@ std::tuple> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, + const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time = c10::nullopt, @@ -28,7 +29,8 @@ neighbor_sample(const at::Tensor& rowptr, bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true); + bool return_edge_id = true, + bool multinomial_mode = false); // Recursively samples neighbors from all node indices in `seed_dict` // in the heterogeneous graph given by `(rowptr_dict, col_dict)`. @@ -56,7 +58,8 @@ hetero_neighbor_sample( bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true); + bool return_edge_id = true, + bool multinomial_mode = false); } // namespace sampler } // namespace pyg From 0250aade537af1b00291a7ef4485fa1d63747860 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Thu, 17 Aug 2023 13:57:56 +0000 Subject: [PATCH 02/13] Changed edges index item to scalar_t type --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index f3ba996ec..d0abe0c35 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -160,7 +160,7 @@ class NeighborSampler { // Add edges to the sampled list one-by-one for (int i = 0; i < edges.numel(); ++i) { - const auto edge = row_start + edges.index({i}).item(); + const auto edge = row_start + edges.index({i}).item(); add(edge, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } // Sampling complete From aac6c0aa051b546bb1540dbd713a37b8649aa0cc Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Thu, 17 Aug 2023 20:44:58 +0000 Subject: [PATCH 03/13] Fixed issues with pyg_lib.sampler.neighbor_sample (changed __init__.py) --- pyg_lib/sampler/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 2cfea286c..3e4d49ed1 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -13,6 +13,7 @@ def neighbor_sample( col: Tensor, seed: Tensor, num_neighbors: List[int], + weights: Optional[Tensor] = None, time: Optional[Tensor] = None, seed_time: Optional[Tensor] = None, csc: bool = False, @@ -21,6 +22,7 @@ def neighbor_sample( disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, + multinomial_mode: bool = False, ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]: r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. @@ -72,10 +74,10 @@ def neighbor_sample( Lastly, returns information about the sampled amount of nodes and edges per hop. """ - return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, + return torch.ops.pyg.neighbor_sample(rowptr, col, weights, seed, num_neighbors, time, seed_time, csc, replace, directed, disjoint, temporal_strategy, - return_edge_id) + return_edge_id, multinomial_mode) def hetero_neighbor_sample( From ded35dfef6287eb2f937a78ebd49423855cff841 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Fri, 25 Aug 2023 15:14:11 +0000 Subject: [PATCH 04/13] Tests update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 7 ++++--- pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 3 +-- test/csrc/sampler/test_neighbor.cpp | 14 +++++++------- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index d0abe0c35..694ddf062 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -756,9 +756,10 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, multinomial_mode, node_types, + bool return_edge_id) + // bool multinomial_mode) + { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, 0, node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, temporal_strategy); diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index e8863a7ad..6588a372c 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -46,8 +46,7 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode); + bool return_edge_id); } // namespace sampler } // namespace pyg diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 2c3c75709..1ec1db545 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -13,8 +13,8 @@ TEST(FullNeighborTest, BasicAssertions) { std::vector num_neighbors = {-1, -1}; auto out = pyg::sampler::neighbor_sample(/*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, - num_neighbors); + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), + seed, num_neighbors); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -40,7 +40,7 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false); auto expected_row = at::tensor({0, 1, 2, 3}, options); @@ -63,7 +63,7 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true); auto expected_row = at::tensor({0, 1, 2, 3}, options); @@ -85,7 +85,7 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); @@ -115,7 +115,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); auto out1 = pyg::sampler::neighbor_sample( - rowptr, col, seed, /*num_neighbors=*/{2, 2}, /*time=*/time, + rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{2, 2}, /*time=*/time, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); @@ -131,7 +131,7 @@ TEST(TemporalNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); auto out2 = pyg::sampler::neighbor_sample( - rowptr, col, seed, /*num_neighbors=*/{1, 2}, /*time=*/time, + rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{1, 2}, /*time=*/time, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last"); From bfe7e3ad595e7bd52050970627703bfb2eebdc2b Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Mon, 28 Aug 2023 10:03:48 +0000 Subject: [PATCH 05/13] Fixed pip install --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 151 +++++++++++-------- pyg_lib/csrc/sampler/neighbor.cpp | 7 +- pyg_lib/csrc/sampler/neighbor.h | 3 +- 3 files changed, 88 insertions(+), 73 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 694ddf062..1697edf9d 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -54,7 +54,6 @@ class NeighborSampler { void uniform_sample(const node_t global_src_node, const scalar_t local_src_node, - const at::Tensor& weights, const int64_t count, pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, @@ -62,13 +61,12 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - _sample(global_src_node, local_src_node, row_start, row_end, count, weights, false, + _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); } void temporal_sample(const node_t global_src_node, const scalar_t local_src_node, - const at::Tensor& weights, const int64_t count, const temporal_t seed_time, const temporal_t* time, @@ -94,7 +92,7 @@ class NeighborSampler { "Found invalid non-sorted temporal neighborhood"); } - _sample(global_src_node, local_src_node, row_start, row_end, count, weights, false, + _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); } @@ -134,8 +132,6 @@ class NeighborSampler { const scalar_t row_start, const scalar_t row_end, const int64_t count, - const at::Tensor& weight_vector, - bool multinomial_mode, pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { @@ -153,70 +149,93 @@ class NeighborSampler { add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } - } else { - if (multinomial_mode) { - // Multinomial sampling - at::Tensor edges = at::multinomial(weight_vector, count, replace); - - // Add edges to the sampled list one-by-one - for (int i = 0; i < edges.numel(); ++i) { - const auto edge = row_start + edges.index({i}).item(); - add(edge, global_src_node, local_src_node, dst_mapper, + } + + // Case 2: Sample with replacement: + else if (replace) { + if (row_end < (1 << 16)) { + const auto arr = std::move( + generator.generate_range_of_ints(row_start, row_end, count)); + for (const auto edge_id : arr) + add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); - } // Sampling complete } else { - // Multinomial mode not enabled - - // Case 2: Sample with replacement: - if (replace) { - if (row_end < (1 << 16)) { - const auto arr = std::move( - generator.generate_range_of_ints(row_start, row_end, count)); - for (const auto edge_id : arr) - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } else { - for (int64_t i = 0; i < count; ++i) { - const auto edge_id = generator(row_start, row_end); - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } + for (int64_t i = 0; i < count; ++i) { + const auto edge_id = generator(row_start, row_end); + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); } + } + } - // Case 3: Sample without replacement: - else { - auto index_tracker = IndexTracker(population); - if (population < (1 << 16)) { - const auto arr = - std::move(generator.generate_range_of_ints(0, population, count)); - for (auto i = 0; i < arr.size(); ++i) { - auto rnd = arr[i]; - if (!index_tracker.try_insert(rnd)) { - rnd = population - count + i; - index_tracker.insert(population - count + i); - } - const auto edge_id = row_start + rnd; - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } - } else { - for (auto i = population - count; i < population; ++i) { - auto rnd = generator(0, i + 1); - if (!index_tracker.try_insert(rnd)) { - rnd = i; - index_tracker.insert(i); - } - const auto edge_id = row_start + rnd; - add(edge_id, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } + // Case 3: Sample without replacement: + else { + auto index_tracker = IndexTracker(population); + if (population < (1 << 16)) { + const auto arr = + std::move(generator.generate_range_of_ints(0, population, count)); + for (auto i = 0; i < arr.size(); ++i) { + auto rnd = arr[i]; + if (!index_tracker.try_insert(rnd)) { + rnd = population - count + i; + index_tracker.insert(population - count + i); } + const auto edge_id = row_start + rnd; + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); } + } else { + for (auto i = population - count; i < population; ++i) { + auto rnd = generator(0, i + 1); + if (!index_tracker.try_insert(rnd)) { + rnd = i; + index_tracker.insert(i); + } + const auto edge_id = row_start + rnd; + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } + } + } + } + + void _sample(const node_t global_src_node, + const scalar_t local_src_node, + const scalar_t row_start, + const scalar_t row_end, + const int64_t count, + const at::Tensor& weight_vector, + bool multinomial_mode, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { + if (count == 0) + return; + + const auto population = row_end - row_start; + + if (population == 0) + return; + + // Case 1: Sample the full neighborhood: + if (count < 0 || (!replace && count >= population)) { + for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { + add(edge_id, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); } + } else { + // Multinomial sampling + at::Tensor edges = at::multinomial(weight_vector, count, replace); + // Add edges to the sampled list one-by-one + for (int i = 0; i < edges.numel(); ++i) { + const auto edge = row_start + edges.index({i}).item(); + add(edge, global_src_node, local_src_node, dst_mapper, + out_global_dst_nodes); + } // Sampling complete } } + inline void add(const scalar_t edge_id, const node_t global_src_node, const scalar_t local_src_node, @@ -337,7 +356,7 @@ sample(const at::Tensor& rowptr, if (!multinomial_mode) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, weights, count, mapper, + /*local_src_node=*/i, count, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); } @@ -354,7 +373,7 @@ sample(const at::Tensor& rowptr, for (size_t i = begin; i < end; ++i) { const auto batch_idx = sampled_nodes[i].first; sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, weights, count, + /*local_src_node=*/i, count, seed_times[batch_idx], time_data, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); @@ -579,7 +598,7 @@ sample(const std::vector& node_types, for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, /*PH for weights*/at::ones({1,1}), count, dst_mapper, generator, + /*local_src_node=*/i, count, dst_mapper, generator, dst_sampled_nodes); } } else if constexpr (!std::is_scalar< @@ -590,7 +609,7 @@ sample(const std::vector& node_types, const auto batch_idx = src_sampled_nodes[i].first; sampler.temporal_sample( /*global_src_node=*/src_sampled_nodes[i], - /*local_src_node=*/i, /*PH for weights*/at::ones({1,1}), count, seed_times[batch_idx], + /*local_src_node=*/i, count, seed_times[batch_idx], dst_time_data, dst_mapper, generator, dst_sampled_nodes); } @@ -756,9 +775,7 @@ hetero_neighbor_sample_kernel( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id) - // bool multinomial_mode) - { + bool return_edge_id) { DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, 0, node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 185f09a81..508ccf255 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -64,8 +64,7 @@ hetero_neighbor_sample( bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode) { + bool return_edge_id) { TORCH_CHECK(rowptr_dict.size() == col_dict.size(), "Number of edge types in 'rowptr_dict' and 'col_dict' must match") @@ -91,7 +90,7 @@ hetero_neighbor_sample( .typed(); return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, replace, - directed, disjoint, temporal_strategy, return_edge_id, multinomial_mode); + directed, disjoint, temporal_strategy, return_edge_id); } TORCH_LIBRARY_FRAGMENT(pyg, m) { @@ -108,7 +107,7 @@ TORCH_LIBRARY_FRAGMENT(pyg, m) { "Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict " "= None, bool csc = False, bool replace = False, bool directed = True, " "bool disjoint = False, str temporal_strategy = 'uniform', bool " - "return_edge_id = True, bool multinomial_mode = False) -> (Dict(str, Tensor), Dict(str, Tensor), " + "return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), " "Dict(str, Tensor), Dict(str, Tensor)?, Dict(str, int[]), " "Dict(str, int[]))")); } diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index cd28a7c0c..db60fa5ff 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -58,8 +58,7 @@ hetero_neighbor_sample( bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true, - bool multinomial_mode = false); + bool return_edge_id = true); } // namespace sampler } // namespace pyg From 5a07278509c72b27f8425e0828c92fb215caa966 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Mon, 28 Aug 2023 13:40:25 +0000 Subject: [PATCH 06/13] Optimized neighbor_kernel.cpp --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 35 ++++++++++++-------- 1 file changed, 22 insertions(+), 13 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 1697edf9d..6cab5c267 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -45,6 +45,14 @@ class NeighborSampler { std::vector& out_global_dst_nodes) { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + + if (count == 0) + return; + + const auto population = row_end - row_start; + if (population == 0) + return; + at::Tensor weights_neighborhood = weights.index( {at::indexing::Slice(row_start, row_end)}); @@ -61,6 +69,13 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; + if (count == 0) + return; + + const auto population = row_end - row_start; + if (population == 0) + return; + _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); } @@ -75,7 +90,13 @@ class NeighborSampler { std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - auto population = row_end - row_start; + const auto population = row_end - row_start; + + if (count == 0) + return; + + if (population == 0) + return; // Find new `row_end` such that all neighbors fulfill temporal constraints: auto it = std::upper_bound( @@ -135,14 +156,8 @@ class NeighborSampler { pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { - if (count == 0) - return; - const auto population = row_end - row_start; - if (population == 0) - return; - // Case 1: Sample the full neighborhood: if (count < 0 || (!replace && count >= population)) { for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { @@ -209,14 +224,8 @@ class NeighborSampler { pyg::sampler::Mapper& dst_mapper, pyg::random::RandintEngine& generator, std::vector& out_global_dst_nodes) { - if (count == 0) - return; - const auto population = row_end - row_start; - if (population == 0) - return; - // Case 1: Sample the full neighborhood: if (count < 0 || (!replace && count >= population)) { for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) { From fc6a9552e2c328e2372697c014d9b1a51159b633 Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Tue, 29 Aug 2023 13:22:57 +0000 Subject: [PATCH 07/13] Second optimization of neighbor_kernel.cpp --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 6cab5c267..c42f7145e 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -46,9 +46,6 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - if (count == 0) - return; - const auto population = row_end - row_start; if (population == 0) return; @@ -68,9 +65,6 @@ class NeighborSampler { std::vector& out_global_dst_nodes) { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - - if (count == 0) - return; const auto population = row_end - row_start; if (population == 0) @@ -92,9 +86,6 @@ class NeighborSampler { auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; const auto population = row_end - row_start; - if (count == 0) - return; - if (population == 0) return; @@ -364,6 +355,8 @@ sample(const at::Tensor& rowptr, if (!time.has_value()) { if (!multinomial_mode) { for (size_t i = begin; i < end; ++i) { + if (count == 0) + continue; sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, count, mapper, generator, @@ -371,6 +364,8 @@ sample(const at::Tensor& rowptr, } } else { for (size_t i = begin; i < end; ++i) { + if (count == 0) + continue; sampler.multinomial_sample(/*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, weights, count, mapper, generator, @@ -380,6 +375,8 @@ sample(const at::Tensor& rowptr, } else if constexpr (!std::is_scalar::value) { // Temporal: const auto time_data = time.value().data_ptr(); for (size_t i = begin; i < end; ++i) { + if (count == 0) + continue; const auto batch_idx = sampled_nodes[i].first; sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, count, @@ -586,7 +583,7 @@ sample(const std::vector& node_types, } at::parallel_for( 0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) { - for (auto j = _s; j < _e; j++) { + for (auto j = _s; j < _e; ++j) { for (const auto& k : threads_edge_types[j]) { const auto src = !csc ? std::get<0>(k) : std::get<2>(k); const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); From 5321634933d465b6f3829be5e531233b9d3bb47e Mon Sep 17 00:00:00 2001 From: OlhaBabicheva Date: Sun, 3 Sep 2023 19:36:23 +0000 Subject: [PATCH 08/13] Changed code in neighbor_kernel.cpp and mapper.h --- pyg_lib/csrc/sampler/cpu/mapper.h | 24 +++------- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 48 ++++++++++++-------- 2 files changed, 35 insertions(+), 37 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/mapper.h b/pyg_lib/csrc/sampler/cpu/mapper.h index aff8b11a4..9b26dbfc7 100644 --- a/pyg_lib/csrc/sampler/cpu/mapper.h +++ b/pyg_lib/csrc/sampler/cpu/mapper.h @@ -19,20 +19,12 @@ class Mapper { // but slower hash map implementation. As a general rule of thumb, we are // safe to use vectors in case the number of nodes are small, or it is // expected that we sample a large amount of nodes. - use_vec = (num_nodes < 1000000) || (num_entries > (num_nodes / 10)); + use_vec = std::is_scalar::value && + (num_nodes > 0) && + ((num_nodes < 1000000) || (num_entries > (num_nodes / 10))); - if (num_nodes <= 0) { // == `num_nodes` is undefined: - use_vec = false; - } - - // We can only utilize vector mappings in case entries are scalar: - if (!std::is_scalar::value) { - use_vec = false; - } - - if (use_vec) { + if (use_vec) to_local_vec.resize(num_nodes, -1); - } } std::pair insert(const node_t& node) { @@ -49,7 +41,7 @@ class Mapper { res = std::pair(out.first->second, out.second); } if (res.second) { - curr++; + ++curr; } return res; } @@ -65,11 +57,7 @@ class Mapper { } bool exists(const node_t& node) { - if (use_vec) { - return to_local_vec[node] >= 0; - } else { - return to_local_map.count(node) > 0; - } + return use_vec ? to_local_vec[node] >= 0 : to_local_map.count(node) > 0; } scalar_t map(const node_t& node) { diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index c42f7145e..3130c55f2 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -349,30 +349,25 @@ sample(const at::Tensor& rowptr, num_sampled_nodes_per_hop.push_back(seed.numel()); size_t begin = 0, end = seed.size(0); - for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { - const auto count = num_neighbors[ell]; - sampler.num_sampled_edges_per_hop.push_back(0); - if (!time.has_value()) { - if (!multinomial_mode) { - for (size_t i = begin; i < end; ++i) { - if (count == 0) - continue; - sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, count, mapper, - generator, - /*out_global_dst_nodes=*/sampled_nodes); - } - } else { - for (size_t i = begin; i < end; ++i) { + if constexpr (multinomial_mode) { + for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { + const auto count = num_neighbors[ell]; + sampler.num_sampled_edges_per_hop.push_back(0); + for (size_t i = begin; i < end; ++i) { if (count == 0) continue; sampler.multinomial_sample(/*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, weights, count, mapper, generator, /*out_global_dst_nodes=*/sampled_nodes); - } } - } else if constexpr (!std::is_scalar::value) { // Temporal: + begin = end, end = sampled_nodes.size(); + num_sampled_nodes_per_hop.push_back(end - begin); + } + } else if constexpr (!std::is_scalar::value) { + for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { + const auto count = num_neighbors[ell]; + sampler.num_sampled_edges_per_hop.push_back(0); const auto time_data = time.value().data_ptr(); for (size_t i = begin; i < end; ++i) { if (count == 0) @@ -384,9 +379,24 @@ sample(const at::Tensor& rowptr, generator, /*out_global_dst_nodes=*/sampled_nodes); } + begin = end, end = sampled_nodes.size(); + num_sampled_nodes_per_hop.push_back(end - begin); + } + } else { + for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { + const auto count = num_neighbors[ell]; + sampler.num_sampled_edges_per_hop.push_back(0); + for (size_t i = begin; i < end; ++i) { + if (count == 0) + continue; + sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, count, mapper, + generator, + /*out_global_dst_nodes=*/sampled_nodes); + } + begin = end, end = sampled_nodes.size(); + num_sampled_nodes_per_hop.push_back(end - begin); } - begin = end, end = sampled_nodes.size(); - num_sampled_nodes_per_hop.push_back(end - begin); } out_node_id = pyg::utils::from_vector(sampled_nodes); From 1dc3d208802615ddc2dcb7b7c6481c0800b4e19d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Sep 2023 19:37:06 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pyg_lib/csrc/sampler/cpu/mapper.h | 3 +- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 190 +++++++++++-------- pyg_lib/csrc/sampler/neighbor.cpp | 10 +- pyg_lib/sampler/__init__.py | 9 +- test/csrc/sampler/test_neighbor.cpp | 21 +- 5 files changed, 132 insertions(+), 101 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/mapper.h b/pyg_lib/csrc/sampler/cpu/mapper.h index 9b26dbfc7..12b30ee2a 100644 --- a/pyg_lib/csrc/sampler/cpu/mapper.h +++ b/pyg_lib/csrc/sampler/cpu/mapper.h @@ -19,8 +19,7 @@ class Mapper { // but slower hash map implementation. As a general rule of thumb, we are // safe to use vectors in case the number of nodes are small, or it is // expected that we sample a large amount of nodes. - use_vec = std::is_scalar::value && - (num_nodes > 0) && + use_vec = std::is_scalar::value && (num_nodes > 0) && ((num_nodes < 1000000) || (num_entries > (num_nodes / 10))); if (use_vec) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 3130c55f2..8d100143f 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -37,12 +37,12 @@ class NeighborSampler { } void multinomial_sample(const node_t global_src_node, - const scalar_t local_src_node, - const at::Tensor& weights, - const int64_t count, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + const scalar_t local_src_node, + const at::Tensor& weights, + const int64_t count, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; @@ -50,11 +50,12 @@ class NeighborSampler { if (population == 0) return; - at::Tensor weights_neighborhood = weights.index( - {at::indexing::Slice(row_start, row_end)}); - - _sample(global_src_node, local_src_node, row_start, row_end, count, weights_neighborhood, true, - dst_mapper, generator, out_global_dst_nodes); + at::Tensor weights_neighborhood = + weights.index({at::indexing::Slice(row_start, row_end)}); + + _sample(global_src_node, local_src_node, row_start, row_end, count, + weights_neighborhood, true, dst_mapper, generator, + out_global_dst_nodes); } void uniform_sample(const node_t global_src_node, @@ -235,7 +236,6 @@ class NeighborSampler { } } - inline void add(const scalar_t edge_id, const node_t global_src_node, const scalar_t local_src_node, @@ -271,7 +271,11 @@ class NeighborSampler { // Homogeneous neighbor sampling /////////////////////////////////////////////// -template +template std::tuple +template std::tuple, c10::Dict, c10::Dict, @@ -678,73 +685,89 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, multinomial_mode, ...) \ - if (replace && directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && multinomial_mode) \ +#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, \ + multinomial_mode, ...) \ + if (replace && directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && \ + !multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ + if (replace && directed && disjoint && !return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ + if (replace && !directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && multinomial_mode) \ + if (replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && multinomial_mode) \ + if (replace && !directed && !disjoint && !return_edge_id && \ + multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ + if (!replace && directed && !disjoint && return_edge_id && multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id && \ + multinomial_mode) \ return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ + if (!replace && !directed && disjoint && return_edge_id && multinomial_mode) \ return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && multinomial_mode) \ + if (!replace && !directed && disjoint && !return_edge_id && \ + multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id && \ + multinomial_mode) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id && \ + multinomial_mode) \ return sample(__VA_ARGS__); - } // namespace std::tuple, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 508ccf255..b4773dee8 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -39,8 +39,8 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, weights, seed, num_neighbors, time, seed_time, csc, - replace, directed, disjoint, temporal_strategy, + return op.call(rowptr, col, weights, seed, num_neighbors, time, seed_time, + csc, replace, directed, disjoint, temporal_strategy, return_edge_id, multinomial_mode); } @@ -95,10 +95,12 @@ hetero_neighbor_sample( TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor weights, Tensor seed, int[] " + "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor weights, Tensor " + "seed, int[] " "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True, bool multinomial_mode = False) " + "False, str temporal_strategy = 'uniform', bool return_edge_id = True, " + "bool multinomial_mode = False) " "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 3e4d49ed1..e0767341d 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -74,10 +74,11 @@ def neighbor_sample( Lastly, returns information about the sampled amount of nodes and edges per hop. """ - return torch.ops.pyg.neighbor_sample(rowptr, col, weights, seed, num_neighbors, - time, seed_time, csc, replace, - directed, disjoint, temporal_strategy, - return_edge_id, multinomial_mode) + return torch.ops.pyg.neighbor_sample(rowptr, col, weights, seed, + num_neighbors, time, seed_time, csc, + replace, directed, disjoint, + temporal_strategy, return_edge_id, + multinomial_mode) def hetero_neighbor_sample( diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 1ec1db545..290d75f1f 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -12,9 +12,9 @@ TEST(FullNeighborTest, BasicAssertions) { auto seed = at::arange(2, 4, options); std::vector num_neighbors = {-1, -1}; - auto out = pyg::sampler::neighbor_sample(/*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), - seed, num_neighbors); + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -40,7 +40,8 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false); auto expected_row = at::tensor({0, 1, 2, 3}, options); @@ -63,7 +64,8 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true); auto expected_row = at::tensor({0, 1, 2, 3}, options); @@ -85,7 +87,8 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, /*time=*/c10::nullopt, + /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*time=*/c10::nullopt, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); @@ -115,7 +118,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); auto out1 = pyg::sampler::neighbor_sample( - rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{2, 2}, /*time=*/time, + rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{2, 2}, + /*time=*/time, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true); @@ -131,7 +135,8 @@ TEST(TemporalNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); auto out2 = pyg::sampler::neighbor_sample( - rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{1, 2}, /*time=*/time, + rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{1, 2}, + /*time=*/time, /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last"); From 72ec21aff92afbf74f76761b810ead20217c475f Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 4 Sep 2023 08:16:22 +0000 Subject: [PATCH 10/13] update --- .github/CONTRIBUTING.md | 2 +- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 300 +++++++------------ pyg_lib/csrc/sampler/cpu/neighbor_kernel.h | 5 +- pyg_lib/csrc/sampler/neighbor.cpp | 22 +- pyg_lib/csrc/sampler/neighbor.h | 2 +- pyg_lib/sampler/__init__.py | 10 +- test/csrc/sampler/test_neighbor.cpp | 59 +++- 7 files changed, 179 insertions(+), 221 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 595645d39..28b3c26c6 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -22,7 +22,7 @@ If you already previously cloned `pyg-lib`, update it: ``` git pull git submodule sync --recursive -git submodule update --init --recursive --jobs 0 +git submodule update --init --recursive ``` Then, build the library via: diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 8d100143f..f4310bc20 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -36,26 +36,23 @@ class NeighborSampler { "No valid temporal strategy found"); } - void multinomial_sample(const node_t global_src_node, - const scalar_t local_src_node, - const at::Tensor& weights, - const int64_t count, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + void biased_sample(const node_t global_src_node, + const scalar_t local_src_node, + const at::Tensor& edge_weight, + const int64_t count, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - const auto population = row_end - row_start; - if (population == 0) - return; + if (row_end - row_start == 0) + || (count == 0) return; - at::Tensor weights_neighborhood = - weights.index({at::indexing::Slice(row_start, row_end)}); + const auto weight = edge_weight.narrow(0, row_start, row_end - row_start); - _sample(global_src_node, local_src_node, row_start, row_end, count, - weights_neighborhood, true, dst_mapper, generator, - out_global_dst_nodes); + _biased_sample(global_src_node, local_src_node, row_start, row_end, count, + weight, dst_mapper, generator, out_global_dst_nodes); } void uniform_sample(const node_t global_src_node, @@ -67,9 +64,8 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - const auto population = row_end - row_start; - if (population == 0) - return; + if (row_end - row_start == 0) + || (count == 0) return; _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); @@ -85,10 +81,9 @@ class NeighborSampler { std::vector& out_global_dst_nodes) { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - const auto population = row_end - row_start; - if (population == 0) - return; + if (row_end - row_start == 0) + || (count == 0) return; // Find new `row_end` such that all neighbors fulfill temporal constraints: auto it = std::upper_bound( @@ -206,16 +201,15 @@ class NeighborSampler { } } - void _sample(const node_t global_src_node, - const scalar_t local_src_node, - const scalar_t row_start, - const scalar_t row_end, - const int64_t count, - const at::Tensor& weight_vector, - bool multinomial_mode, - pyg::sampler::Mapper& dst_mapper, - pyg::random::RandintEngine& generator, - std::vector& out_global_dst_nodes) { + void _biased_sample(const node_t global_src_node, + const scalar_t local_src_node, + const scalar_t row_start, + const scalar_t row_end, + const int64_t count, + const at::Tensor& weight, + pyg::sampler::Mapper& dst_mapper, + pyg::random::RandintEngine& generator, + std::vector& out_global_dst_nodes) { const auto population = row_end - row_start; // Case 1: Sample the full neighborhood: @@ -224,15 +218,16 @@ class NeighborSampler { add(edge_id, global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); } - } else { - // Multinomial sampling - at::Tensor edges = at::multinomial(weight_vector, count, replace); - // Add edges to the sampled list one-by-one - for (int i = 0; i < edges.numel(); ++i) { - const auto edge = row_start + edges.index({i}).item(); - add(edge, global_src_node, local_src_node, dst_mapper, - out_global_dst_nodes); - } // Sampling complete + } + + // Case 2: Multinomial sampling: + else { + const auto index = at::multinomial(weight, count, replace); + const auto index_data = index.data_ptr(); + for (size_t i = 0; i < index.numel(); ++i) { + add(row_start + index_data[i], global_src_node, local_src_node, + dst_mapper, out_global_dst_nodes); + } } } @@ -271,11 +266,7 @@ class NeighborSampler { // Homogeneous neighbor sampling /////////////////////////////////////////////// -template +template std::tuple> sample(const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, const bool csc, const std::string temporal_strategy) { TORCH_CHECK(!time.has_value() || disjoint, @@ -304,6 +295,8 @@ sample(const at::Tensor& rowptr, TORCH_CHECK(seed_time.value().is_contiguous(), "Non-contiguous 'seed_time'"); } + TORCH_CHECK(time.has_value() && edge_weight.has_value(), + "Biased temporal sampling not yet supported"); at::Tensor out_row, out_col, out_node_id; c10::optional out_edge_id = c10::nullopt; @@ -353,53 +346,45 @@ sample(const at::Tensor& rowptr, num_sampled_nodes_per_hop.push_back(seed.numel()); size_t begin = 0, end = seed.size(0); - if constexpr (multinomial_mode) { - for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { - const auto count = num_neighbors[ell]; - sampler.num_sampled_edges_per_hop.push_back(0); + for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { + const auto count = num_neighbors[ell]; + sampler.num_sampled_edges_per_hop.push_back(0); + if (edge_weight.has_value()) { for (size_t i = begin; i < end; ++i) { - if (count == 0) - continue; - sampler.multinomial_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, weights, count, - mapper, generator, - /*out_global_dst_nodes=*/sampled_nodes); + sampler.multinomial_sample( + /*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, + /*edge_weight=*/edge_weight.value(), + /*count=*/count, + /*dst_mapper=*/mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/sampled_nodes); } - begin = end, end = sampled_nodes.size(); - num_sampled_nodes_per_hop.push_back(end - begin); - } - } else if constexpr (!std::is_scalar::value) { - for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { - const auto count = num_neighbors[ell]; - sampler.num_sampled_edges_per_hop.push_back(0); - const auto time_data = time.value().data_ptr(); + } else if (!time.has_value) { for (size_t i = begin; i < end; ++i) { - if (count == 0) - continue; - const auto batch_idx = sampled_nodes[i].first; - sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, count, - seed_times[batch_idx], time_data, mapper, - generator, - /*out_global_dst_nodes=*/sampled_nodes); + sampler.uniform_sample( + /*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, + /*count=*/count, + /*dst_mapper=*/mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/sampled_nodes); } - begin = end, end = sampled_nodes.size(); - num_sampled_nodes_per_hop.push_back(end - begin); - } - } else { - for (size_t ell = 0; ell < num_neighbors.size(); ++ell) { - const auto count = num_neighbors[ell]; - sampler.num_sampled_edges_per_hop.push_back(0); + } else if constexpr (!std::is_scalar::value) { // Temporal: for (size_t i = begin; i < end; ++i) { - if (count == 0) - continue; - sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i], - /*local_src_node=*/i, count, mapper, generator, - /*out_global_dst_nodes=*/sampled_nodes); + const auto batch_idx = sampled_nodes[i].first; + sampler.temporal_sample( + /*global_src_node=*/sampled_nodes[i], + /*local_src_node=*/i, /*count=*/count, + /*seed_time=*/seed_times[batch_idx], + /*time=*/time_data, + /*dst_mapper=*/mapper, + /*generator=*/generator, + /*out_global_dst_nodes=*/sampled_nodes); } - begin = end, end = sampled_nodes.size(); - num_sampled_nodes_per_hop.push_back(end - begin); } + begin = end, end = sampled_nodes.size(); + num_sampled_nodes_per_hop.push_back(end - begin); } out_node_id = pyg::utils::from_vector(sampled_nodes); @@ -419,11 +404,7 @@ sample(const at::Tensor& rowptr, // Heterogeneous neighbor sampling ///////////////////////////////////////////// -template +template std::tuple, c10::Dict, c10::Dict, @@ -531,9 +512,9 @@ sample(const std::vector& node_types, temporal_strategy)}); if (parallel) { - // Each thread is assigned edge types that have the same dst node type. - // Thanks to this, each thread will operate on a separate mapper and - // separate sampler. + // Each thread is assigned edge types that have the same dst node + // type. Thanks to this, each thread will operate on a separate mapper + // and separate sampler. bool added = false; const auto dst = !csc ? std::get<2>(k) : std::get<0>(k); for (auto& e : threads_edge_types) { @@ -547,7 +528,8 @@ sample(const std::vector& node_types, threads_edge_types.push_back({k}); } } - if (!parallel) { // If not parallel then one thread handles all edge types. + if (!parallel) { // If not parallel then one thread handles all edge + // types. threads_edge_types.push_back({edge_types}); } @@ -685,88 +667,39 @@ sample(const std::vector& node_types, // Dispatcher ////////////////////////////////////////////////////////////////// -#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, \ - multinomial_mode, ...) \ - if (replace && directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && \ - !multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && disjoint && !return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && directed && !disjoint && !return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && disjoint && !return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (replace && !directed && !disjoint && !return_edge_id && \ - multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && disjoint && !return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && directed && !disjoint && !return_edge_id && \ - multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && return_edge_id && multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && disjoint && !return_edge_id && \ - multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && return_edge_id && \ - multinomial_mode) \ - return sample(__VA_ARGS__); \ - if (!replace && !directed && !disjoint && !return_edge_id && \ - multinomial_mode) \ - return sample(__VA_ARGS__); +#define DISPATCH_SAMPLE(replace, directed, disjount, return_edge_id, ...) \ + if (replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && return_edge_id) \ + return sample(__VA_ARGS__); \ + if (!replace && !directed && !disjoint && !return_edge_id) \ + return sample(__VA_ARGS__); } // namespace @@ -778,21 +711,20 @@ std::tuple> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, multinomial_mode, - rowptr, col, weights, seed, num_neighbors, time, seed_time, - csc, temporal_strategy); + bool return_edge_id) { + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col, + seed, num_neighbors, time, seed_time, edge_weight, csc, + temporal_strategy); } std::tuple, @@ -816,7 +748,7 @@ hetero_neighbor_sample_kernel( bool disjoint, std::string temporal_strategy, bool return_edge_id) { - DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, 0, node_types, + DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types, edge_types, rowptr_dict, col_dict, seed_dict, num_neighbors_dict, time_dict, seed_time_dict, csc, temporal_strategy); diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h index 6588a372c..d4cdfeb59 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.h @@ -13,18 +13,17 @@ std::tuple> neighbor_sample_kernel(const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode); + bool return_edge_id); std::tuple, c10::Dict, diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index b4773dee8..e709e0ffb 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -16,18 +16,17 @@ std::tuple> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time, const c10::optional& seed_time, + const c10::optional& edge_weight, bool csc, bool replace, bool directed, bool disjoint, std::string temporal_strategy, - bool return_edge_id, - bool multinomial_mode) { + bool return_edge_id) { at::TensorArg rowptr_t{rowptr, "rowtpr", 1}; at::TensorArg col_t{col, "col", 1}; at::TensorArg seed_t{seed, "seed", 1}; @@ -39,9 +38,9 @@ neighbor_sample(const at::Tensor& rowptr, static auto op = c10::Dispatcher::singleton() .findSchemaOrThrow("pyg::neighbor_sample", "") .typed(); - return op.call(rowptr, col, weights, seed, num_neighbors, time, seed_time, + return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight, csc, replace, directed, disjoint, temporal_strategy, - return_edge_id, multinomial_mode); + return_edge_id); } std::tuple, @@ -95,13 +94,12 @@ hetero_neighbor_sample( TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor weights, Tensor " - "seed, int[] " - "num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc " - "= False, bool replace = False, bool directed = True, bool disjoint = " - "False, str temporal_strategy = 'uniform', bool return_edge_id = True, " - "bool multinomial_mode = False) " - "-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])")); + "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " + "num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? " + "edge_weight = None, bool csc = False, bool replace = False, bool " + "directed = True, bool disjoint = False, str temporal_strategy = " + "'uniform', bool return_edge_id = True) -> " + "(Tensor, Tensor, Tensor, Tensor?, int[], int[])")); m.def(TORCH_SELECTIVE_SCHEMA( "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index db60fa5ff..86d4aa426 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -19,11 +19,11 @@ std::tuple> neighbor_sample(const at::Tensor& rowptr, const at::Tensor& col, - const at::Tensor& weights, const at::Tensor& seed, const std::vector& num_neighbors, const c10::optional& time = c10::nullopt, const c10::optional& seed_time = c10::nullopt, + const c10::optional& edge_weight = c10::nullopt, bool csc = false, bool replace = false, bool directed = true, diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index e0767341d..c58dc1295 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -13,16 +13,15 @@ def neighbor_sample( col: Tensor, seed: Tensor, num_neighbors: List[int], - weights: Optional[Tensor] = None, time: Optional[Tensor] = None, seed_time: Optional[Tensor] = None, + edge_weight: Optional[Tensor] = None, csc: bool = False, replace: bool = False, directed: bool = True, disjoint: bool = False, temporal_strategy: str = 'uniform', return_edge_id: bool = True, - multinomial_mode: bool = False, ) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor], List[int], List[int]]: r"""Recursively samples neighbors from all node indices in :obj:`seed` in the graph given by :obj:`(rowptr, col)`. @@ -74,11 +73,10 @@ def neighbor_sample( Lastly, returns information about the sampled amount of nodes and edges per hop. """ - return torch.ops.pyg.neighbor_sample(rowptr, col, weights, seed, - num_neighbors, time, seed_time, csc, + return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors, + time, seed_time, edge_weight, csc, replace, directed, disjoint, - temporal_strategy, return_edge_id, - multinomial_mode) + temporal_strategy, return_edge_id) def hetero_neighbor_sample( diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 290d75f1f..c2ac2d070 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -14,7 +14,7 @@ TEST(FullNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors); + /*col=*/std::get<1>(graph), seed, num_neighbors); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -40,9 +40,14 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false); + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -64,9 +69,14 @@ TEST(WithReplacementNeighborTest, BasicAssertions) { at::manual_seed(123456); auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true); + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/true); auto expected_row = at::tensor({0, 1, 2, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -87,10 +97,16 @@ TEST(DisjointNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), /*weights=*/at::ones(8), seed, num_neighbors, + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/num_neighbors, /*time=*/c10::nullopt, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/true); + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); @@ -118,10 +134,17 @@ TEST(TemporalNeighborTest, BasicAssertions) { col = std::get<0>(at::sort(col.view({-1, 2}), /*dim=*/1)).flatten(); auto out1 = pyg::sampler::neighbor_sample( - rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{2, 2}, + /*rowptr=*/rowptr, + /*col=*/col, + /*seed=*/seed, + /*num_neighbors=*/{2, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/true); + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true); // Expect only the earlier neighbors or the same node to be sampled: auto expected_row = at::tensor({0, 1, 2, 2, 3, 3}, options); @@ -135,10 +158,18 @@ TEST(TemporalNeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out1).value(), expected_edges)); auto out2 = pyg::sampler::neighbor_sample( - rowptr, col, /*weights=*/at::ones(8), seed, /*num_neighbors=*/{1, 2}, + /*rowptr=*/rowptr, + /*col=*/col, + /*seed=*/seed, + /*num_neighbors=*/{1, 2}, /*time=*/time, - /*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false, - /*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last"); + /*seed_time=*/c10::nullopt, + /*edge_weight=*/c10::nullopt, + /*csc=*/false, + /*replace=*/false, + /*directed=*/true, + /*disjoint=*/true, + /*temporal_strategy=*/"last"); EXPECT_TRUE(at::equal(std::get<0>(out1), std::get<0>(out2))); EXPECT_TRUE(at::equal(std::get<1>(out1), std::get<1>(out2))); From 7192146708f8a5733d99312ace0ceca0d8151727 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 4 Sep 2023 08:21:19 +0000 Subject: [PATCH 11/13] update --- CHANGELOG.md | 1 + pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 16 ++++++++-------- test/csrc/sampler/test_neighbor.cpp | 4 +++- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c784c0e0..1fba33a08 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [0.3.0] - 2023-MM-DD ### Added +- Added support for homogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247)) - Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243)) - Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229)) - Enable `hetero_neighbor_samplee` to work in parallel ([#211](https://github.com/pyg-team/pyg-lib/pull/211)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index f4310bc20..c5fa5b4c7 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -46,8 +46,8 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - if (row_end - row_start == 0) - || (count == 0) return; + if ((row_end - row_start == 0) || (count == 0)) + return; const auto weight = edge_weight.narrow(0, row_start, row_end - row_start); @@ -64,8 +64,8 @@ class NeighborSampler { const auto row_start = rowptr_[to_scalar_t(global_src_node)]; const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - if (row_end - row_start == 0) - || (count == 0) return; + if ((row_end - row_start == 0) || (count == 0)) + return; _sample(global_src_node, local_src_node, row_start, row_end, count, dst_mapper, generator, out_global_dst_nodes); @@ -82,8 +82,8 @@ class NeighborSampler { auto row_start = rowptr_[to_scalar_t(global_src_node)]; auto row_end = rowptr_[to_scalar_t(global_src_node) + 1]; - if (row_end - row_start == 0) - || (count == 0) return; + if ((row_end - row_start == 0) || (count == 0)) + return; // Find new `row_end` such that all neighbors fulfill temporal constraints: auto it = std::upper_bound( @@ -371,6 +371,7 @@ sample(const at::Tensor& rowptr, /*out_global_dst_nodes=*/sampled_nodes); } } else if constexpr (!std::is_scalar::value) { // Temporal: + const auto time_data = time.value().data_ptr(); for (size_t i = begin; i < end; ++i) { const auto batch_idx = sampled_nodes[i].first; sampler.temporal_sample( @@ -528,8 +529,7 @@ sample(const std::vector& node_types, threads_edge_types.push_back({k}); } } - if (!parallel) { // If not parallel then one thread handles all edge - // types. + if (!parallel) { // One thread handles all edge types. threads_edge_types.push_back({edge_types}); } diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index c2ac2d070..0ac4897a7 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -14,7 +14,9 @@ TEST(FullNeighborTest, BasicAssertions) { auto out = pyg::sampler::neighbor_sample( /*rowptr=*/std::get<0>(graph), - /*col=*/std::get<1>(graph), seed, num_neighbors); + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/num_neighbors); auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3}, options); EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); From 36d21d3f5b05164bb3526595659c73cec223fd79 Mon Sep 17 00:00:00 2001 From: rusty1s Date: Mon, 4 Sep 2023 08:42:38 +0000 Subject: [PATCH 12/13] update --- pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp | 8 ++--- pyg_lib/csrc/sampler/neighbor.h | 3 +- test/csrc/sampler/test_neighbor.cpp | 34 +++++++++++++++++++- 3 files changed, 38 insertions(+), 7 deletions(-) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index c5fa5b4c7..afb33aed4 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -223,7 +223,7 @@ class NeighborSampler { // Case 2: Multinomial sampling: else { const auto index = at::multinomial(weight, count, replace); - const auto index_data = index.data_ptr(); + const auto index_data = index.data_ptr(); for (size_t i = 0; i < index.numel(); ++i) { add(row_start + index_data[i], global_src_node, local_src_node, dst_mapper, out_global_dst_nodes); @@ -295,7 +295,7 @@ sample(const at::Tensor& rowptr, TORCH_CHECK(seed_time.value().is_contiguous(), "Non-contiguous 'seed_time'"); } - TORCH_CHECK(time.has_value() && edge_weight.has_value(), + TORCH_CHECK(!(time.has_value() && edge_weight.has_value()), "Biased temporal sampling not yet supported"); at::Tensor out_row, out_col, out_node_id; @@ -351,7 +351,7 @@ sample(const at::Tensor& rowptr, sampler.num_sampled_edges_per_hop.push_back(0); if (edge_weight.has_value()) { for (size_t i = begin; i < end; ++i) { - sampler.multinomial_sample( + sampler.biased_sample( /*global_src_node=*/sampled_nodes[i], /*local_src_node=*/i, /*edge_weight=*/edge_weight.value(), @@ -360,7 +360,7 @@ sample(const at::Tensor& rowptr, /*generator=*/generator, /*out_global_dst_nodes=*/sampled_nodes); } - } else if (!time.has_value) { + } else if (!time.has_value()) { for (size_t i = begin; i < end; ++i) { sampler.uniform_sample( /*global_src_node=*/sampled_nodes[i], diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 86d4aa426..80c99d656 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -29,8 +29,7 @@ neighbor_sample(const at::Tensor& rowptr, bool directed = true, bool disjoint = false, std::string strategy = "uniform", - bool return_edge_id = true, - bool multinomial_mode = false); + bool return_edge_id = true); // Recursively samples neighbors from all node indices in `seed_dict` // in the heterogeneous graph given by `(rowptr_dict, col_dict)`. diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 0ac4897a7..d12dc1012 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -5,7 +5,7 @@ #include "pyg_lib/csrc/utils/types.h" #include "test/csrc/graph.h" -TEST(FullNeighborTest, BasicAssertions) { +TEST(BasicNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); @@ -215,3 +215,35 @@ TEST(HeteroNeighborTest, BasicAssertions) { std::vector expected_num_edges = {4, 4}; EXPECT_TRUE(std::get<5>(out).at("paper__to__paper") == expected_num_edges); } + + +TEST(BiasedNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + auto seed = at::arange(0, 2, options); + std::vector num_neighbors = {1}; + + auto ones = at::ones(6).view({-1, 1}); + auto zeros = at::zeros(6).view({-1, 1}); + // Only sample even edges: + auto edge_weight = at::cat({ones, zeros}, -1).view(-1); + + auto out = pyg::sampler::neighbor_sample( + /*rowptr=*/std::get<0>(graph), + /*col=*/std::get<1>(graph), + /*seed=*/seed, + /*num_neighbors=*/num_neighbors, + /*time=*/c10::nullopt, + /*seed_time=*/c10::nullopt, + /*edge_weight=*/edge_weight); + + auto expected_row = at::tensor({0, 1}, options); + EXPECT_TRUE(at::equal(std::get<0>(out), expected_row)); + auto expected_col = at::tensor({2, 0}, options); + EXPECT_TRUE(at::equal(std::get<1>(out), expected_col)); + auto expected_nodes = at::tensor({0, 1, 5}, options); + EXPECT_TRUE(at::equal(std::get<2>(out), expected_nodes)); + auto expected_edges = at::tensor({0, 2}, options); + EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); +} From db890bcd7329b519c86940c8caee7665e6a652a6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 4 Sep 2023 08:43:03 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/csrc/sampler/test_neighbor.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index d12dc1012..e57fd72d2 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -216,7 +216,6 @@ TEST(HeteroNeighborTest, BasicAssertions) { EXPECT_TRUE(std::get<5>(out).at("paper__to__paper") == expected_num_edges); } - TEST(BiasedNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong);