Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added `temporal_strategy` option to `neighbor_sample` ([#114](https://github.com/pyg-team/pyg-lib/pull/114))
- Added benchmarking tool (Google Benchmark) along with `pyg::sampler::Mapper` benchmark example ([#101](https://github.com/pyg-team/pyg-lib/pull/101))
- Added CSC mode to `pyg::sampler::neighbor_sample` and `pyg::sampler::hetero_neighbor_sample` ([#95](https://github.com/pyg-team/pyg-lib/pull/95), [#96](https://github.com/pyg-team/pyg-lib/pull/96))
- Speed up `pyg::sampler::neighbor_sample` via `IndexTracker` implementation ([#84](https://github.com/pyg-team/pyg-lib/pull/84))
Expand Down
36 changes: 26 additions & 10 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,13 @@ template <typename node_t,
bool save_edge_ids>
class NeighborSampler {
public:
NeighborSampler(const scalar_t* rowptr, const scalar_t* col)
: rowptr_(rowptr), col_(col) {}
// Constructs a sampler over the graph arrays `rowptr` and `col` and records
// the requested `temporal_strategy`.
//
// The two pointers are stored as given (no copy of the underlying data is
// made here), and the strategy string is copied into `temporal_strategy_`.
// Only the strategy names "uniform" and "last" are accepted; any other value
// fails the TORCH_CHECK below with "No valid temporal strategy found".
NeighborSampler(const scalar_t* rowptr,
const scalar_t* col,
const std::string temporal_strategy)
: rowptr_(rowptr), col_(col), temporal_strategy_(temporal_strategy) {
// Validate once at construction time rather than on every sampling call.
TORCH_CHECK(temporal_strategy == "uniform" || temporal_strategy == "last",
"No valid temporal strategy found");
}

void uniform_sample(const node_t global_src_node,
const scalar_t local_src_node,
Expand All @@ -48,7 +53,7 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

// Find new `row_end` such that all neighbors fulfill temporal constraints:
Expand All @@ -57,6 +62,10 @@ class NeighborSampler {
[&](const scalar_t& a, const scalar_t& b) { return time[a] < b; });
row_end = it - col_;

if (temporal_strategy_ == "last") {
row_start = std::max(row_start, (scalar_t)(row_end - count));
}

_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}
Expand Down Expand Up @@ -164,6 +173,7 @@ class NeighborSampler {

const scalar_t* rowptr_;
const scalar_t* col_;
const std::string temporal_strategy_;
std::vector<scalar_t> sampled_rows_;
std::vector<scalar_t> sampled_cols_;
std::vector<scalar_t> sampled_edge_ids_;
Expand All @@ -178,7 +188,8 @@ sample(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const bool csc) {
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!time.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");

Expand All @@ -202,8 +213,9 @@ sample(const at::Tensor& rowptr,

std::vector<node_t> sampled_nodes;
auto mapper = Mapper<node_t, scalar_t>(/*num_nodes=*/rowptr.size(0) - 1);
auto sampler = NeighborSamplerImpl(rowptr.data_ptr<scalar_t>(),
col.data_ptr<scalar_t>());
auto sampler =
NeighborSamplerImpl(rowptr.data_ptr<scalar_t>(),
col.data_ptr<scalar_t>(), temporal_strategy);

const auto seed_data = seed.data_ptr<scalar_t>();
if constexpr (!disjoint) {
Expand Down Expand Up @@ -266,7 +278,8 @@ sample(const std::vector<node_type>& node_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const bool csc) {
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!time_dict.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");

Expand Down Expand Up @@ -337,7 +350,8 @@ sample(const std::vector<node_type>& node_types,
sampler_dict.insert(
{k, NeighborSamplerImpl(
rowptr_dict.at(to_rel_type(k)).data_ptr<scalar_t>(),
col_dict.at(to_rel_type(k)).data_ptr<scalar_t>())});
col_dict.at(to_rel_type(k)).data_ptr<scalar_t>(),
temporal_strategy)});
}

scalar_t i = 0;
Expand Down Expand Up @@ -476,9 +490,10 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, csc);
seed, num_neighbors, time, csc, temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -497,10 +512,11 @@ hetero_neighbor_sample_kernel(
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, csc);
num_neighbors_dict, time_dict, csc, temporal_strategy);
}

} // namespace
Expand Down
18 changes: 11 additions & 7 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ neighbor_sample(const at::Tensor& rowptr,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
Expand All @@ -29,7 +30,7 @@ neighbor_sample(const at::Tensor& rowptr,
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, csc, replace, directed,
disjoint, return_edge_id);
disjoint, temporal_strategy, return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -48,30 +49,33 @@ hetero_neighbor_sample(
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
// TODO (matthias) Add TensorArg definitions and type checks.
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::hetero_neighbor_sample_cpu", "")
.typed<decltype(hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, csc, replace, directed,
disjoint, return_edge_id);
disjoint, temporal_strategy, return_edge_id);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, bool csc = False, bool replace = "
"False, bool directed = True, bool disjoint = False, bool return_edge_id "
"= True) -> (Tensor, Tensor, Tensor, Tensor?)"));
"False, bool directed = True, bool disjoint = False, str "
"temporal_strategy = 'uniform', bool return_edge_id = True) -> (Tensor, "
"Tensor, Tensor, Tensor?)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, bool csc = False, bool replace = "
"False, bool directed = True, bool disjoint = False, bool return_edge_id "
"= True) -> (Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor)?)"));
"False, bool directed = True, bool disjoint = False, str "
"temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), Dict(str, "
"Tensor)?)"));
}

} // namespace sampler
Expand Down
2 changes: 2 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ neighbor_sample(const at::Tensor& rowptr,
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);

// Recursively samples neighbors from all node indices in `seed_dict`
Expand All @@ -44,6 +45,7 @@ hetero_neighbor_sample(
bool replace = false,
bool directed = true,
bool disjoint = false,
std::string strategy = "uniform",
bool return_edge_id = true);

} // namespace sampler
Expand Down
9 changes: 8 additions & 1 deletion pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def neighbor_sample(
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Tensor, Tensor, Tensor, Optional[Tensor]]:
r"""Recursively samples neighbors from all node indices in :obj:`seed`
Expand Down Expand Up @@ -51,6 +52,9 @@ def neighbor_sample(
edges between all sampled nodes. (default: :obj:`True`)
disjoint (bool, optional): If set to :obj:`True` , will create disjoint
subgraphs for every seed node. (default: :obj:`False`)
temporal_strategy (string, optional): The sampling strategy when using
temporal sampling (:obj:`"uniform"`, :obj:`"last"`).
(default: :obj:`"uniform"`)
return_edge_id (bool, optional): If set to :obj:`False`, will not
return the indices of edges of the original graph.
(default: :obj:`True`)
Expand All @@ -63,7 +67,8 @@ def neighbor_sample(
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, csc, replace, directed,
disjoint, return_edge_id)
disjoint, temporal_strategy,
return_edge_id)


def hetero_neighbor_sample(
Expand All @@ -76,6 +81,7 @@ def hetero_neighbor_sample(
replace: bool = False,
directed: bool = True,
disjoint: bool = False,
temporal_strategy: str = 'uniform',
return_edge_id: bool = True,
) -> Tuple[Dict[EdgeType, Tensor], Dict[EdgeType, Tensor], Dict[
NodeType, Tensor], Optional[Dict[EdgeType, Tensor]]]:
Expand Down Expand Up @@ -117,6 +123,7 @@ def hetero_neighbor_sample(
replace,
directed,
disjoint,
temporal_strategy,
return_edge_id,
)

Expand Down