diff --git a/CHANGELOG.md b/CHANGELOG.md index d6d64b24e..33d9d10a4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ## [Unreleased] ### Added +- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90)) - Added `pyg::utils::to_vector` implementation ([#88](https://github.com/pyg-team/pyg-lib/pull/88)) - Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58)) - Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69)) diff --git a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp index 23cacab72..ba0500d21 100644 --- a/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp +++ b/pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp @@ -5,13 +5,16 @@ #include "pyg_lib/csrc/sampler/cpu/mapper.h" #include "pyg_lib/csrc/sampler/subgraph.h" #include "pyg_lib/csrc/utils/cpu/convert.h" +#include "pyg_lib/csrc/utils/types.h" namespace pyg { namespace sampler { namespace { -// `node_t` is either a scalar or a pair of scalars of (example_id, node_id): +// Helper classes for bipartite neighbor sampling ////////////////////////////// + +// `node_t` is either a scalar or a pair of scalars (example_id, node_id): template sampled_edge_ids_; }; +// Homogeneous neighbor sampling /////////////////////////////////////////////// + template std::tuple> sample(const at::Tensor& rowptr, @@ -312,6 +317,28 @@ neighbor_sample_kernel(const at::Tensor& rowptr, return sample(rowptr, col, seed, count, time); } +// Heterogeneous neighbor sampling 
//////////////////////////////////////////// + +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>> +hetero_neighbor_sample_kernel( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + bool replace, + bool directed, + bool disjoint, + bool return_edge_id) { + std::cout << "hetero_neighbor_sample_kernel" << std::endl; + return std::make_tuple(col_dict, col_dict, col_dict, col_dict); +} + } // namespace TORCH_LIBRARY_IMPL(pyg, CPU, m) { @@ -319,5 +346,11 @@ TORCH_LIBRARY_IMPL(pyg, CPU, m) { TORCH_FN(neighbor_sample_kernel)); } +TORCH_LIBRARY_FRAGMENT(pyg, m) { + // TODO (matthias) fix automatic dispatching + m.def(TORCH_SELECTIVE_NAME("pyg::hetero_neighbor_sample_cpu"), + TORCH_FN(hetero_neighbor_sample_kernel)); +} + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/sampler/neighbor.cpp b/pyg_lib/csrc/sampler/neighbor.cpp index 5dda7407f..c2d647d80 100644 --- a/pyg_lib/csrc/sampler/neighbor.cpp +++ b/pyg_lib/csrc/sampler/neighbor.cpp @@ -31,11 +31,43 @@ neighbor_sample(const at::Tensor& rowptr, disjoint, return_edge_id); } +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>> +hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict, + bool replace, + bool directed, + bool disjoint, + bool return_edge_id) { + // TODO (matthias) Add TensorArg definitions and type checks. 
+ static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("pyg::hetero_neighbor_sample_cpu", "") + .typed(); + return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict, time_dict, replace, directed, disjoint, + return_edge_id); +} + TORCH_LIBRARY_FRAGMENT(pyg, m) { m.def(TORCH_SELECTIVE_SCHEMA( "pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] " "num_neighbors, Tensor? time, bool replace, bool directed, bool " "disjoint, bool return_edge_id) -> (Tensor, Tensor, Tensor, Tensor?)")); + m.def(TORCH_SELECTIVE_SCHEMA( + "pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] " + "edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, " + "Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, " + "Dict(str, Tensor)? time_dict, bool replace, bool directed, bool " + "disjoint, bool return_edge_id) -> (Dict(str, Tensor), Dict(str, " + "Tensor), Dict(str, Tensor), Dict(str, Tensor)?)")); } } // namespace sampler diff --git a/pyg_lib/csrc/sampler/neighbor.h b/pyg_lib/csrc/sampler/neighbor.h index 68c27caec..f0e9aa7f1 100644 --- a/pyg_lib/csrc/sampler/neighbor.h +++ b/pyg_lib/csrc/sampler/neighbor.h @@ -2,6 +2,7 @@ #include #include "pyg_lib/csrc/macros.h" +#include "pyg_lib/csrc/utils/types.h" namespace pyg { namespace sampler { @@ -21,5 +22,27 @@ neighbor_sample(const at::Tensor& rowptr, bool disjoint = false, bool return_edge_id = true); +// Recursively samples neighbors from all node indices in `seed_dict` +// in the heterogeneous graph given by `(rowptr_dict, col_dict)`. 
+// Returns: (row_dict, col_dict, node_id_dict, edge_id_dict) +PYG_API +std::tuple, + c10::Dict, + c10::Dict, + c10::optional>> +hetero_neighbor_sample( + const std::vector& node_types, + const std::vector& edge_types, + const c10::Dict& rowptr_dict, + const c10::Dict& col_dict, + const c10::Dict& seed_dict, + const c10::Dict>& num_neighbors_dict, + const c10::optional>& time_dict = + c10::nullopt, + bool replace = false, + bool directed = true, + bool disjoint = false, + bool return_edge_id = true); + } // namespace sampler } // namespace pyg diff --git a/pyg_lib/csrc/utils/types.h b/pyg_lib/csrc/utils/types.h new file mode 100644 index 000000000..159f4daef --- /dev/null +++ b/pyg_lib/csrc/utils/types.h @@ -0,0 +1,8 @@ +#pragma once + +#include +#include + +typedef std::string node_t; +typedef std::string rel_t; +typedef std::tuple edge_t; diff --git a/pyg_lib/sampler/__init__.py b/pyg_lib/sampler/__init__.py index 7890c6428..61c7fb5cc 100644 --- a/pyg_lib/sampler/__init__.py +++ b/pyg_lib/sampler/__init__.py @@ -1,8 +1,12 @@ -from typing import List, Tuple, Optional +from typing import List, Tuple, Optional, Dict import torch from torch import Tensor +NodeType = str +RelType = str +EdgeType = Tuple[str, str, str] + def neighbor_sample( rowptr: Tensor, @@ -51,6 +55,71 @@ def neighbor_sample( return_edge_id) +def hetero_neighbor_sample( + rowptr_dict: Dict[EdgeType, Tensor], + col_dict: Dict[EdgeType, Tensor], + seed_dict: Dict[NodeType, Tensor], + num_neighbors_dict: Dict[EdgeType, List[int]], + time_dict: Optional[Dict[NodeType, Tensor]] = None, + replace: bool = False, + directed: bool = True, + disjoint: bool = False, + return_edge_id: bool = True, +) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[ + NodeType, Tensor], Optional[Dict[EdgeType, Tensor]]]: + r"""Recursively samples neighbors from all node indices in :obj:`seed_dict` + in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`. + + .. 
note :: + Similar to :meth:`neighbor_sample`, but expects a dictionary of node + types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) for + each non-boolean argument. + + Args: + kwargs: Arguments of :meth:`neighbor_sample`. + """ + src_node_types = {k[0] for k in rowptr_dict.keys()} + dst_node_types = {k[-1] for k in rowptr_dict.keys()} + node_types = sorted(src_node_types | dst_node_types) + edge_types = list(rowptr_dict.keys()) + + TO_REL_TYPE = {key: '__'.join(key) for key in edge_types} + TO_EDGE_TYPE = {'__'.join(key): key for key in edge_types} + + rowptr_dict = {TO_REL_TYPE[k]: v for k, v in rowptr_dict.items()} + col_dict = {TO_REL_TYPE[k]: v for k, v in col_dict.items()} + num_neighbors_dict = { + TO_REL_TYPE[k]: v + for k, v in num_neighbors_dict.items() + } + + # `pyg::hetero_neighbor_sample` only has a schema (no kernel), so call the CPU op directly, mirroring the C++ wrapper. + out = torch.ops.pyg.hetero_neighbor_sample_cpu( + node_types, + edge_types, + rowptr_dict, + col_dict, + seed_dict, + num_neighbors_dict, + time_dict, + replace, + directed, + disjoint, + return_edge_id, + ) + + out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict = out + out_row_dict = {TO_EDGE_TYPE[k]: v for k, v in out_row_dict.items()} + out_col_dict = {TO_EDGE_TYPE[k]: v for k, v in out_col_dict.items()} + if out_edge_id_dict is not None: + out_edge_id_dict = { + TO_EDGE_TYPE[k]: v + for k, v in out_edge_id_dict.items() + } + + return out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict + + def subgraph( rowptr: Tensor, col: Tensor, @@ -103,6 +172,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int, __all__ = [ 'neighbor_sample', + 'hetero_neighbor_sample', 'subgraph', 'random_walk', ] diff --git a/test/csrc/sampler/test_neighbor.cpp b/test/csrc/sampler/test_neighbor.cpp index 065cee604..f9d3a6978 100644 --- a/test/csrc/sampler/test_neighbor.cpp +++ b/test/csrc/sampler/test_neighbor.cpp @@ -2,6 +2,7 @@ #include #include "pyg_lib/csrc/sampler/neighbor.h" +#include "pyg_lib/csrc/utils/types.h" #include "test/csrc/graph.h" TEST(NeighborTest, 
BasicAssertions) { @@ -25,7 +26,7 @@ TEST(NeighborTest, BasicAssertions) { EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } -TEST(NeighborDisjointTest, BasicAssertions) { +TEST(DisjointNeighborTest, BasicAssertions) { auto options = at::TensorOptions().dtype(at::kLong); auto graph = cycle_graph(/*num_nodes=*/6, options); @@ -48,3 +49,24 @@ TEST(NeighborDisjointTest, BasicAssertions) { at::tensor({4, 5, 6, 7, 2, 3, 6, 7, 4, 5, 8, 9}, options); EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges)); } + +TEST(HeteroNeighborTest, BasicAssertions) { + auto options = at::TensorOptions().dtype(at::kLong); + + auto graph = cycle_graph(/*num_nodes=*/6, options); + std::vector node_types = {"paper"}; + std::vector edge_types = {{"paper", "to", "paper"}}; + c10::Dict rowptr_dict; + rowptr_dict.insert("paper__to__paper", std::get<0>(graph)); + c10::Dict col_dict; + col_dict.insert("paper__to__paper", std::get<1>(graph)); + c10::Dict seed_dict; + seed_dict.insert("paper", at::arange(2, 4, options)); + std::vector num_neighbors = {2, 2}; + c10::Dict> num_neighbors_dict; + num_neighbors_dict.insert("paper__to__paper", num_neighbors); + + auto out = pyg::sampler::hetero_neighbor_sample( + node_types, edge_types, rowptr_dict, col_dict, seed_dict, + num_neighbors_dict); +}