Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ If you already previously cloned `pyg-lib`, update it:
```
git pull
git submodule sync --recursive
git submodule update --init --recursive --jobs 0
git submodule update --init --recursive
```

Then, build the library via:
Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.3.0] - 2023-MM-DD
### Added
- Added support for homogeneous biased neighborhood sampling ([#247](https://github.com/pyg-team/pyg-lib/pull/247))
- Added dispatch for XPU device in `index_sort` ([#243](https://github.com/pyg-team/pyg-lib/pull/243))
- Added `metis` partitioning ([#229](https://github.com/pyg-team/pyg-lib/pull/229))
- Enable `hetero_neighbor_sample` to work in parallel ([#211](https://github.com/pyg-team/pyg-lib/pull/211))
Expand Down
23 changes: 5 additions & 18 deletions pyg_lib/csrc/sampler/cpu/mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,20 +19,11 @@ class Mapper {
// but slower hash map implementation. As a general rule of thumb, we are
// safe to use vectors in case the number of nodes are small, or it is
// expected that we sample a large amount of nodes.
use_vec = (num_nodes < 1000000) || (num_entries > (num_nodes / 10));
use_vec = std::is_scalar<node_t>::value && (num_nodes > 0) &&
((num_nodes < 1000000) || (num_entries > (num_nodes / 10)));

if (num_nodes <= 0) { // == `num_nodes` is undefined:
use_vec = false;
}

// We can only utilize vector mappings in case entries are scalar:
if (!std::is_scalar<node_t>::value) {
use_vec = false;
}

if (use_vec) {
if (use_vec)
to_local_vec.resize(num_nodes, -1);
}
}

std::pair<scalar_t, bool> insert(const node_t& node) {
Expand All @@ -49,7 +40,7 @@ class Mapper {
res = std::pair<scalar_t, bool>(out.first->second, out.second);
}
if (res.second) {
curr++;
++curr;
}
return res;
}
Expand All @@ -65,11 +56,7 @@ class Mapper {
}

bool exists(const node_t& node) {
if (use_vec) {
return to_local_vec[node] >= 0;
} else {
return to_local_map.count(node) > 0;
}
return use_vec ? to_local_vec[node] >= 0 : to_local_map.count(node) > 0;
}

scalar_t map(const node_t& node) {
Expand Down
115 changes: 94 additions & 21 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ class NeighborSampler {
"No valid temporal strategy found");
}

// Samples up to `count` neighbors of `global_src_node` with probability
// proportional to `edge_weight`, appending results to
// `out_global_dst_nodes` via `_biased_sample`.
// `edge_weight` is indexed by edge id in CSR order (assumes one weight per
// edge of the full graph — TODO confirm against caller); only the slice for
// this node's neighborhood is forwarded.
void biased_sample(const node_t global_src_node,
                   const scalar_t local_src_node,
                   const at::Tensor& edge_weight,
                   const int64_t count,
                   pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
                   pyg::random::RandintEngine<scalar_t>& generator,
                   std::vector<node_t>& out_global_dst_nodes) {
  // Neighborhood boundaries of `global_src_node` in the CSR `rowptr_`:
  const auto row_start = rowptr_[to_scalar_t(global_src_node)];
  const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

  // Nothing to sample for isolated nodes or a zero sampling budget:
  if ((row_end - row_start == 0) || (count == 0))
    return;

  // Restrict the weight vector to this node's outgoing edges:
  const auto weight = edge_weight.narrow(0, row_start, row_end - row_start);

  _biased_sample(global_src_node, local_src_node, row_start, row_end, count,
                 weight, dst_mapper, generator, out_global_dst_nodes);
}

void uniform_sample(const node_t global_src_node,
const scalar_t local_src_node,
const int64_t count,
Expand All @@ -44,6 +63,10 @@ class NeighborSampler {
std::vector<node_t>& out_global_dst_nodes) {
const auto row_start = rowptr_[to_scalar_t(global_src_node)];
const auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

if ((row_end - row_start == 0) || (count == 0))
return;

_sample(global_src_node, local_src_node, row_start, row_end, count,
dst_mapper, generator, out_global_dst_nodes);
}
Expand All @@ -59,6 +82,9 @@ class NeighborSampler {
auto row_start = rowptr_[to_scalar_t(global_src_node)];
auto row_end = rowptr_[to_scalar_t(global_src_node) + 1];

if ((row_end - row_start == 0) || (count == 0))
return;

// Find new `row_end` such that all neighbors fulfill temporal constraints:
auto it = std::upper_bound(
col_ + row_start, col_ + row_end, seed_time,
Expand Down Expand Up @@ -117,14 +143,8 @@ class NeighborSampler {
pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
pyg::random::RandintEngine<scalar_t>& generator,
std::vector<node_t>& out_global_dst_nodes) {
if (count == 0)
return;

const auto population = row_end - row_start;

if (population == 0)
return;

// Case 1: Sample the full neighborhood:
if (count < 0 || (!replace && count >= population)) {
for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
Expand Down Expand Up @@ -181,6 +201,36 @@ class NeighborSampler {
}
}

// Weighted neighbor sampling over the edge range `[row_start, row_end)`.
// `weight` holds one sampling weight per edge in that range. Sampled edges
// are forwarded to `add`, which appends the corresponding global destination
// nodes to `out_global_dst_nodes`.
// Callers are expected to have filtered out empty neighborhoods and
// `count == 0` beforehand (see `biased_sample`).
void _biased_sample(const node_t global_src_node,
                    const scalar_t local_src_node,
                    const scalar_t row_start,
                    const scalar_t row_end,
                    const int64_t count,
                    const at::Tensor& weight,
                    pyg::sampler::Mapper<node_t, scalar_t>& dst_mapper,
                    pyg::random::RandintEngine<scalar_t>& generator,
                    std::vector<node_t>& out_global_dst_nodes) {
  const auto population = row_end - row_start;

  // Case 1: Sample the full neighborhood:
  if (count < 0 || (!replace && count >= population)) {
    for (scalar_t edge_id = row_start; edge_id < row_end; ++edge_id) {
      add(edge_id, global_src_node, local_src_node, dst_mapper,
          out_global_dst_nodes);
    }
  }

  // Case 2: Multinomial sampling:
  else {
    const auto index = at::multinomial(weight, count, replace);
    const auto index_data = index.data_ptr<int64_t>();
    // `Tensor::numel()` returns `int64_t`: use a signed loop index (and
    // hoist the call) to avoid a signed/unsigned comparison per iteration.
    const int64_t num_samples = index.numel();
    for (int64_t i = 0; i < num_samples; ++i) {
      add(row_start + index_data[i], global_src_node, local_src_node,
          dst_mapper, out_global_dst_nodes);
    }
  }
}

inline void add(const scalar_t edge_id,
const node_t global_src_node,
const scalar_t local_src_node,
Expand Down Expand Up @@ -229,6 +279,7 @@ sample(const at::Tensor& rowptr,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!time.has_value() || disjoint,
Expand All @@ -244,6 +295,8 @@ sample(const at::Tensor& rowptr,
TORCH_CHECK(seed_time.value().is_contiguous(),
"Non-contiguous 'seed_time'");
}
TORCH_CHECK(!(time.has_value() && edge_weight.has_value()),
"Biased temporal sampling not yet supported");

at::Tensor out_row, out_col, out_node_id;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
Expand Down Expand Up @@ -296,21 +349,39 @@ sample(const at::Tensor& rowptr,
for (size_t ell = 0; ell < num_neighbors.size(); ++ell) {
const auto count = num_neighbors[ell];
sampler.num_sampled_edges_per_hop.push_back(0);
if (!time.has_value()) {
if (edge_weight.has_value()) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, count, mapper, generator,
/*out_global_dst_nodes=*/sampled_nodes);
sampler.biased_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i,
/*edge_weight=*/edge_weight.value(),
/*count=*/count,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
}
} else if (!time.has_value()) {
for (size_t i = begin; i < end; ++i) {
sampler.uniform_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i,
/*count=*/count,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
}
} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const auto time_data = time.value().data_ptr<temporal_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = sampled_nodes[i].first;
sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, count,
seed_times[batch_idx], time_data, mapper,
generator,
/*out_global_dst_nodes=*/sampled_nodes);
sampler.temporal_sample(
/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, /*count=*/count,
/*seed_time=*/seed_times[batch_idx],
/*time=*/time_data,
/*dst_mapper=*/mapper,
/*generator=*/generator,
/*out_global_dst_nodes=*/sampled_nodes);
}
}
begin = end, end = sampled_nodes.size();
Expand Down Expand Up @@ -442,9 +513,9 @@ sample(const std::vector<node_type>& node_types,
temporal_strategy)});

if (parallel) {
// Each thread is assigned edge types that have the same dst node type.
// Thanks to this, each thread will operate on a separate mapper and
// separate sampler.
// Each thread is assigned edge types that have the same dst node
// type. Thanks to this, each thread will operate on a separate mapper
// and separate sampler.
bool added = false;
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);
for (auto& e : threads_edge_types) {
Expand All @@ -458,7 +529,7 @@ sample(const std::vector<node_type>& node_types,
threads_edge_types.push_back({k});
}
}
if (!parallel) { // If not parallel then one thread handles all edge types.
if (!parallel) { // One thread handles all edge types.
threads_edge_types.push_back({edge_types});
}

Expand Down Expand Up @@ -511,7 +582,7 @@ sample(const std::vector<node_type>& node_types,
}
at::parallel_for(
0, threads_edge_types.size(), 1, [&](size_t _s, size_t _e) {
for (auto j = _s; j < _e; j++) {
for (auto j = _s; j < _e; ++j) {
for (const auto& k : threads_edge_types[j]) {
const auto src = !csc ? std::get<0>(k) : std::get<2>(k);
const auto dst = !csc ? std::get<2>(k) : std::get<0>(k);
Expand Down Expand Up @@ -644,14 +715,16 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, seed_time, csc, temporal_strategy);
seed, num_neighbors, time, seed_time, edge_weight, csc,
temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand Down
1 change: 1 addition & 0 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
Expand Down
14 changes: 8 additions & 6 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ neighbor_sample(const at::Tensor& rowptr,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
Expand All @@ -37,8 +38,8 @@ neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc,
replace, directed, disjoint, temporal_strategy,
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, edge_weight,
csc, replace, directed, disjoint, temporal_strategy,
return_edge_id);
}

Expand Down Expand Up @@ -94,10 +95,11 @@ hetero_neighbor_sample(
TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc "
"= False, bool replace = False, bool directed = True, bool disjoint = "
"False, str temporal_strategy = 'uniform', bool return_edge_id = True) "
"-> (Tensor, Tensor, Tensor, Tensor?, int[], int[])"));
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, Tensor? "
"edge_weight = None, bool csc = False, bool replace = False, bool "
"directed = True, bool disjoint = False, str temporal_strategy = "
"'uniform', bool return_edge_id = True) -> "
"(Tensor, Tensor, Tensor, Tensor?, int[], int[])"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
Expand Down
1 change: 1 addition & 0 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ neighbor_sample(const at::Tensor& rowptr,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
const c10::optional<at::Tensor>& edge_weight = c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
Expand Down
7 changes: 4 additions & 3 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def neighbor_sample(
num_neighbors: List[int],
time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
edge_weight: Optional[Tensor] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
Expand Down Expand Up @@ -73,9 +74,9 @@ def neighbor_sample(
per hop.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, seed_time, csc, replace,
directed, disjoint, temporal_strategy,
return_edge_id)
time, seed_time, edge_weight, csc,
replace, directed, disjoint,
temporal_strategy, return_edge_id)


def hetero_neighbor_sample(
Expand Down
Loading