Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added `pyg::sampler::hetero_neighbor_sample` implementation ([#90](https://github.com/pyg-team/pyg-lib/pull/90))
- Added `pyg::utils::to_vector` implementation ([#88](https://github.com/pyg-team/pyg-lib/pull/88))
- Added support for PyTorch 1.12 ([#57](https://github.com/pyg-team/pyg-lib/pull/57), [#58](https://github.com/pyg-team/pyg-lib/pull/58))
- Added `grouped_matmul` and `segment_matmul` CUDA implementations via `cutlass` ([#51](https://github.com/pyg-team/pyg-lib/pull/51), [#56](https://github.com/pyg-team/pyg-lib/pull/56), [#61](https://github.com/pyg-team/pyg-lib/pull/61), [#64](https://github.com/pyg-team/pyg-lib/pull/64), [#69](https://github.com/pyg-team/pyg-lib/pull/69))
Expand Down
35 changes: 34 additions & 1 deletion pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@
#include "pyg_lib/csrc/sampler/cpu/mapper.h"
#include "pyg_lib/csrc/sampler/subgraph.h"
#include "pyg_lib/csrc/utils/cpu/convert.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {

namespace {

// `node_t` is either a scalar or a pair of scalars of (example_id, node_id):
// Helper classes for bipartite neighbor sampling //////////////////////////////

// `node_t` is either a scalar or a pair of scalars (example_id, node_id):
template <typename node_t,
typename scalar_t,
bool replace,
Expand Down Expand Up @@ -191,6 +194,8 @@ class NeighborSampler {
std::vector<scalar_t> sampled_edge_ids_;
};

// Homogeneous neighbor sampling ///////////////////////////////////////////////

template <bool replace, bool directed, bool disjoint, bool return_edge_id>
std::tuple<at::Tensor, at::Tensor, at::Tensor, c10::optional<at::Tensor>>
sample(const at::Tensor& rowptr,
Expand Down Expand Up @@ -312,12 +317,40 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
return sample<false, false, false, false>(rowptr, col, seed, count, time);
}

// Heterogeneous neighbor sampling ////////////////////////////////////////////

// Placeholder CPU kernel for heterogeneous neighbor sampling.
//
// The sampling logic is not implemented yet: every argument except `col_dict`
// is ignored, and `col_dict` is echoed back in all four output slots so that
// the operator registration and call path can be exercised end-to-end.
// `col_dict` also fits the `node_t`-keyed slot because both `rel_t` and
// `node_t` are aliases of `std::string`.
//
// The kernel intentionally performs no console output: it is library code and
// must not write to (or flush) the caller's stdout.
std::tuple<c10::Dict<rel_t, at::Tensor>,
           c10::Dict<rel_t, at::Tensor>,
           c10::Dict<node_t, at::Tensor>,
           c10::optional<c10::Dict<rel_t, at::Tensor>>>
hetero_neighbor_sample_kernel(
    const std::vector<node_t>& node_types,
    const std::vector<edge_t>& edge_types,
    const c10::Dict<rel_t, at::Tensor>& rowptr_dict,
    const c10::Dict<rel_t, at::Tensor>& col_dict,
    const c10::Dict<node_t, at::Tensor>& seed_dict,
    const c10::Dict<rel_t, std::vector<int64_t>>& num_neighbors_dict,
    const c10::optional<c10::Dict<node_t, at::Tensor>>& time_dict,
    bool replace,
    bool directed,
    bool disjoint,
    bool return_edge_id) {
  // TODO: Replace this stub with the actual heterogeneous sampling routine.
  // The last element converts implicitly into the optional edge-id slot.
  return std::make_tuple(col_dict, col_dict, col_dict, col_dict);
}

} // namespace

// Registers `neighbor_sample_kernel` as the CPU implementation of the
// `pyg::neighbor_sample` operator.
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::neighbor_sample"),
TORCH_FN(neighbor_sample_kernel));
}

// Defines `pyg::hetero_neighbor_sample_cpu` as its own operator, bound
// directly to `hetero_neighbor_sample_kernel`, rather than attaching the
// kernel as a CPU implementation of another operator.
// NOTE(review): presumably a temporary workaround until the TODO below is
// resolved - confirm before relying on the `_cpu` operator name.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
// TODO (matthias) fix automatic dispatching
m.def(TORCH_SELECTIVE_NAME("pyg::hetero_neighbor_sample_cpu"),
TORCH_FN(hetero_neighbor_sample_kernel));
}

} // namespace sampler
} // namespace pyg
32 changes: 32 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,43 @@ neighbor_sample(const at::Tensor& rowptr,
disjoint, return_edge_id);
}

// Entry point for heterogeneous neighbor sampling: resolves the
// `pyg::hetero_neighbor_sample_cpu` operator through the dispatcher and
// forwards every argument to it unchanged.
std::tuple<c10::Dict<rel_t, at::Tensor>,
           c10::Dict<rel_t, at::Tensor>,
           c10::Dict<node_t, at::Tensor>,
           c10::optional<c10::Dict<rel_t, at::Tensor>>>
hetero_neighbor_sample(
    const std::vector<node_t>& node_types,
    const std::vector<edge_t>& edge_types,
    const c10::Dict<rel_t, at::Tensor>& rowptr_dict,
    const c10::Dict<rel_t, at::Tensor>& col_dict,
    const c10::Dict<node_t, at::Tensor>& seed_dict,
    const c10::Dict<rel_t, std::vector<int64_t>>& num_neighbors_dict,
    const c10::optional<c10::Dict<node_t, at::Tensor>>& time_dict,
    bool replace,
    bool directed,
    bool disjoint,
    bool return_edge_id) {
  // TODO (matthias) Add TensorArg definitions and type checks.

  // The operator shares this function's exact signature.
  using signature_t = decltype(hetero_neighbor_sample);

  // Function-local static: the schema lookup runs on the first call only and
  // throws if the operator has not been defined.
  static auto handle =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("pyg::hetero_neighbor_sample_cpu", "")
          .typed<signature_t>();

  return handle.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
                     num_neighbors_dict, time_dict, replace, directed,
                     disjoint, return_edge_id);
}

// Declares the operator schemas for neighbor sampling in the `pyg` library.
// Only schemas are defined in this fragment; no kernel is attached here.
TORCH_LIBRARY_FRAGMENT(pyg, m) {
// Homogeneous variant: the graph, seeds and optional timestamps are passed
// as plain tensors.
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time, bool replace, bool directed, bool "
"disjoint, bool return_edge_id) -> (Tensor, Tensor, Tensor, Tensor?)"));
// Heterogeneous variant: takes explicit node/edge type lists and replaces
// each tensor (and `num_neighbors`) with a string-keyed dictionary.
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict, bool replace, bool directed, bool "
"disjoint, bool return_edge_id) -> (Dict(str, Tensor), Dict(str, "
"Tensor), Dict(str, Tensor), Dict(str, Tensor)?)"));
}

} // namespace sampler
Expand Down
23 changes: 23 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <ATen/ATen.h>
#include "pyg_lib/csrc/macros.h"
#include "pyg_lib/csrc/utils/types.h"

namespace pyg {
namespace sampler {
Expand All @@ -21,5 +22,27 @@ neighbor_sample(const at::Tensor& rowptr,
bool disjoint = false,
bool return_edge_id = true);

// Recursively samples neighbors from all node indices in `seed_dict`
// in the heterogeneous graph given by `(rowptr_dict, col_dict)`.
//
// Arguments:
//   node_types:         Node type names.
//   edge_types:         Edge types as triplets of strings.
//   rowptr_dict:        One tensor per relation type (`rel_t` key).
//   col_dict:           One tensor per relation type (`rel_t` key).
//   seed_dict:          One tensor of seed node indices per node type.
//   num_neighbors_dict: One list of integers per relation type.
//                       NOTE(review): presumably the number of neighbors to
//                       sample per hop, as for `neighbor_sample` - confirm.
//   time_dict:          Optional per-node-type tensors; defaults to none.
//   replace:            Defaults to `false`.
//   directed:           Defaults to `true`.
//   disjoint:           Defaults to `false`.
//   return_edge_id:     Defaults to `true`.
//
// Returns: (row_dict, col_dict, node_id_dict, edge_id_dict)
// `node_id_dict` is keyed by node type, the others by relation type, and
// `edge_id_dict` is optional.
// NOTE(review): presumably `edge_id_dict` is empty when `return_edge_id` is
// `false` - confirm against the kernel implementation.
PYG_API
std::tuple<c10::Dict<rel_t, at::Tensor>,
c10::Dict<rel_t, at::Tensor>,
c10::Dict<node_t, at::Tensor>,
c10::optional<c10::Dict<rel_t, at::Tensor>>>
hetero_neighbor_sample(
const std::vector<node_t>& node_types,
const std::vector<edge_t>& edge_types,
const c10::Dict<rel_t, at::Tensor>& rowptr_dict,
const c10::Dict<rel_t, at::Tensor>& col_dict,
const c10::Dict<node_t, at::Tensor>& seed_dict,
const c10::Dict<rel_t, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_t, at::Tensor>>& time_dict =
c10::nullopt,
bool replace = false,
bool directed = true,
bool disjoint = false,
bool return_edge_id = true);

} // namespace sampler
} // namespace pyg
8 changes: 8 additions & 0 deletions pyg_lib/csrc/utils/types.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#pragma once

#include <string>
#include <tuple>

// String-based identifiers used to key heterogeneous graph data.

// Name of a node type.
using node_t = std::string;

// Name of a relation type, i.e. an edge type encoded as a single string.
using rel_t = std::string;

// Edge type expressed as a triplet of strings.
using edge_t = std::tuple<std::string, std::string, std::string>;
71 changes: 70 additions & 1 deletion pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import List, Tuple, Optional
from typing import List, Tuple, Optional, Dict

import torch
from torch import Tensor

# Name of a node type.
NodeType = str
# Relation type: an edge type encoded as a single string.
RelType = str
# Edge type as a triplet of strings; the first and last entries name the
# source and destination node types.
EdgeType = Tuple[str, str, str]


def neighbor_sample(
rowptr: Tensor,
Expand Down Expand Up @@ -51,6 +55,70 @@ def neighbor_sample(
return_edge_id)


def hetero_neighbor_sample(
    rowptr_dict: Dict[EdgeType, Tensor],
    col_dict: Dict[EdgeType, Tensor],
    seed_dict: Dict[NodeType, Tensor],
    num_neighbors_dict: Dict[EdgeType, List[int]],
    time_dict: Optional[Dict[NodeType, Tensor]] = None,
    replace: bool = False,
    directed: bool = True,
    disjoint: bool = False,
    return_edge_id: bool = True,
) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[
        NodeType, Tensor], Optional[Dict[EdgeType, Tensor]]]:
    r"""Recursively samples neighbors from all node indices in :obj:`seed_dict`
    in the heterogeneous graph given by :obj:`(rowptr_dict, col_dict)`.

    .. note ::
        Similar to :meth:`neighbor_sample`, but expects a dictionary of node
        types (:obj:`str`) and edge types (:obj:`Tuple[str, str, str]`) for
        each non-boolean argument.

    Args:
        rowptr_dict (Dict[Tuple[str, str, str], torch.Tensor]): The
            :obj:`rowptr` tensor of each edge type. Its keys define the set
            of edge types (and thereby node types) of the graph.
        col_dict (Dict[Tuple[str, str, str], torch.Tensor]): The :obj:`col`
            tensor of each edge type.
        seed_dict (Dict[str, torch.Tensor]): The seed node indices of each
            node type.
        num_neighbors_dict (Dict[Tuple[str, str, str], List[int]]): The
            :obj:`num_neighbors` list of each edge type.
        time_dict (Dict[str, torch.Tensor], optional): The :obj:`time` tensor
            of each node type. (default: :obj:`None`)
        replace (bool, optional): See :meth:`neighbor_sample`.
            (default: :obj:`False`)
        directed (bool, optional): See :meth:`neighbor_sample`.
            (default: :obj:`True`)
        disjoint (bool, optional): See :meth:`neighbor_sample`.
            (default: :obj:`False`)
        return_edge_id (bool, optional): See :meth:`neighbor_sample`.
            (default: :obj:`True`)

    Returns:
        (Dict[EdgeType, torch.Tensor], Dict[EdgeType, torch.Tensor],
        Dict[NodeType, torch.Tensor], Optional[Dict[EdgeType, torch.Tensor]]):
        Row indices, col indices, node indices and, if returned by the
        operator, edge indices of the sampled subgraph, grouped by edge type
        (respectively node type for the node indices).

    Raises:
        KeyError: If :obj:`col_dict` or :obj:`num_neighbors_dict` contains an
            edge type that is not a key of :obj:`rowptr_dict`.
    """
    src_node_types = {k[0] for k in rowptr_dict.keys()}
    dst_node_types = {k[-1] for k in rowptr_dict.keys()}
    # Sets of strings iterate in an order that may differ between interpreter
    # runs (hash randomization). Sort so that the list handed to the operator
    # is reproducible:
    node_types = sorted(src_node_types | dst_node_types)
    edge_types = list(rowptr_dict.keys())

    # The operator keys its dictionaries by a single string per relation, so
    # edge type triplets are joined by '__' on the way in and mapped back to
    # triplets on the way out:
    TO_REL_TYPE = {key: '__'.join(key) for key in edge_types}
    TO_EDGE_TYPE = {'__'.join(key): key for key in edge_types}

    rowptr_dict = {TO_REL_TYPE[k]: v for k, v in rowptr_dict.items()}
    col_dict = {TO_REL_TYPE[k]: v for k, v in col_dict.items()}
    num_neighbors_dict = {
        TO_REL_TYPE[k]: v
        for k, v in num_neighbors_dict.items()
    }

    # NOTE(review): the C++ entry point dispatches to
    # `pyg::hetero_neighbor_sample_cpu` - confirm that a kernel is reachable
    # under the `pyg::hetero_neighbor_sample` name used here.
    out = torch.ops.pyg.hetero_neighbor_sample(
        node_types,
        edge_types,
        rowptr_dict,
        col_dict,
        seed_dict,
        num_neighbors_dict,
        time_dict,
        replace,
        directed,
        disjoint,
        return_edge_id,
    )

    out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict = out
    out_row_dict = {TO_EDGE_TYPE[k]: v for k, v in out_row_dict.items()}
    out_col_dict = {TO_EDGE_TYPE[k]: v for k, v in out_col_dict.items()}
    if out_edge_id_dict is not None:
        out_edge_id_dict = {
            TO_EDGE_TYPE[k]: v
            for k, v in out_edge_id_dict.items()
        }

    return out_row_dict, out_col_dict, out_node_id_dict, out_edge_id_dict


def subgraph(
rowptr: Tensor,
col: Tensor,
Expand Down Expand Up @@ -103,6 +171,7 @@ def random_walk(rowptr: Tensor, col: Tensor, seed: Tensor, walk_length: int,

# Names exported by `from pyg_lib.sampler import *`.
__all__ = [
'neighbor_sample',
'hetero_neighbor_sample',
'subgraph',
'random_walk',
]
24 changes: 23 additions & 1 deletion test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include <gtest/gtest.h>

#include "pyg_lib/csrc/sampler/neighbor.h"
#include "pyg_lib/csrc/utils/types.h"
#include "test/csrc/graph.h"

TEST(NeighborTest, BasicAssertions) {
Expand All @@ -25,7 +26,7 @@ TEST(NeighborTest, BasicAssertions) {
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

TEST(NeighborDisjointTest, BasicAssertions) {
TEST(DisjointNeighborTest, BasicAssertions) {
auto options = at::TensorOptions().dtype(at::kLong);

auto graph = cycle_graph(/*num_nodes=*/6, options);
Expand All @@ -48,3 +49,24 @@ TEST(NeighborDisjointTest, BasicAssertions) {
at::tensor({4, 5, 6, 7, 2, 3, 6, 7, 4, 5, 8, 9}, options);
EXPECT_TRUE(at::equal(std::get<3>(out).value(), expected_edges));
}

// Smoke test for `pyg::sampler::hetero_neighbor_sample`: wraps a 6-node cycle
// graph as a heterogeneous graph with a single node type and a single
// relation, then invokes the sampler. The outputs are not inspected.
TEST(HeteroNeighborTest, BasicAssertions) {
  auto long_options = at::TensorOptions().dtype(at::kLong);

  // Type metadata: one node type connected to itself by one edge type.
  std::vector<node_t> node_types = {"paper"};
  std::vector<edge_t> edge_types = {{"paper", "to", "paper"}};

  // Graph structure, keyed by the relation string.
  auto cycle = cycle_graph(/*num_nodes=*/6, long_options);
  c10::Dict<rel_t, at::Tensor> rowptr_dict;
  c10::Dict<rel_t, at::Tensor> col_dict;
  rowptr_dict.insert("paper__to__paper", std::get<0>(cycle));
  col_dict.insert("paper__to__paper", std::get<1>(cycle));

  // Seed nodes 2 and 3 of the only node type.
  c10::Dict<node_t, at::Tensor> seed_dict;
  seed_dict.insert("paper", at::arange(2, 4, long_options));

  // Two entries of 2 for the only relation.
  std::vector<int64_t> fanouts = {2, 2};
  c10::Dict<rel_t, std::vector<int64_t>> num_neighbors_dict;
  num_neighbors_dict.insert("paper__to__paper", fanouts);

  auto result = pyg::sampler::hetero_neighbor_sample(
      node_types, edge_types, rowptr_dict, col_dict, seed_dict,
      num_neighbors_dict);
}