Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [Unreleased]
### Added
- Added `seed_time` option to `neighbor_sample` to allow overriding `time` for seed nodes ([#118](https://github.com/pyg-team/pyg-lib/pull/118))
- Added `[segment|grouped]_matmul` CPU implementation ([#111](https://github.com/pyg-team/pyg-lib/pull/111))
- Added `temporal_strategy` option to `neighbor_sample` ([#114](https://github.com/pyg-team/pyg-lib/pull/114))
- Added benchmarking tool (Google Benchmark) along with `pyg::sampler::Mapper` benchmark example ([#101](https://github.com/pyg-team/pyg-lib/pull/101))
Expand Down
88 changes: 59 additions & 29 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,16 +193,21 @@ sample(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!time.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");

TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr' vector");
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col' vector");
TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed' vector");
TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr'");
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'");
TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'");
if (time.has_value()) {
TORCH_CHECK(time.value().is_contiguous(), "Non-contiguous 'time' vector");
TORCH_CHECK(time.value().is_contiguous(), "Non-contiguous 'time'");
}
if (seed_time.has_value()) {
TORCH_CHECK(seed_time.value().is_contiguous(),
"Non-contiguous 'seed_time'");
}

at::Tensor out_row, out_col, out_node_id;
Expand All @@ -221,16 +226,28 @@ sample(const at::Tensor& rowptr,
auto sampler =
NeighborSamplerImpl(rowptr.data_ptr<scalar_t>(),
col.data_ptr<scalar_t>(), temporal_strategy);
std::vector<scalar_t> seed_times;

const auto seed_data = seed.data_ptr<scalar_t>();
if constexpr (!disjoint) {
sampled_nodes = pyg::utils::to_vector<scalar_t>(seed);
mapper.fill(seed);
} else {
for (size_t i = 0; i < seed.numel(); i++) {
for (size_t i = 0; i < seed.numel(); ++i) {
sampled_nodes.push_back({i, seed_data[i]});
mapper.insert({i, seed_data[i]});
}
if (seed_time.has_value()) {
const auto seed_time_data = seed_time.value().data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); ++i) {
seed_times.push_back(seed_time_data[i]);
}
} else if (time.has_value()) {
const auto time_data = time.value().data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); ++i) {
seed_times.push_back(time_data[seed_data[i]]);
}
}
}

size_t begin = 0, end = seed.size(0);
Expand All @@ -246,11 +263,11 @@ sample(const at::Tensor& rowptr,
} else if constexpr (!std::is_scalar<node_t>::value) { // Temporal:
const auto time_data = time.value().data_ptr<scalar_t>();
for (size_t i = begin; i < end; ++i) {
const auto seed_node = seed_data[std::get<0>(sampled_nodes[i])];
const auto seed_time = time_data[seed_node];
const auto batch_idx = sampled_nodes[i].first;
sampler.temporal_sample(/*global_src_node=*/sampled_nodes[i],
/*local_src_node=*/i, count, seed_time,
time_data, mapper, generator,
/*local_src_node=*/i, count,
seed_times[batch_idx], time_data, mapper,
generator,
/*out_global_dst_nodes=*/sampled_nodes);
}
}
Expand Down Expand Up @@ -283,27 +300,34 @@ sample(const std::vector<node_type>& node_types,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!time_dict.has_value() || disjoint,
"Temporal sampling needs to create disjoint subgraphs");

for (const auto& kv : rowptr_dict) {
const at::Tensor& rowptr = kv.value();
TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr' vector");
TORCH_CHECK(rowptr.is_contiguous(), "Non-contiguous 'rowptr'");
}
for (const auto& kv : col_dict) {
const at::Tensor& col = kv.value();
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col' vector");
TORCH_CHECK(col.is_contiguous(), "Non-contiguous 'col'");
}
for (const auto& kv : seed_dict) {
const at::Tensor& seed = kv.value();
TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed' vector");
TORCH_CHECK(seed.is_contiguous(), "Non-contiguous 'seed'");
}
if (time_dict.has_value()) {
for (const auto& kv : time_dict.value()) {
const at::Tensor& time = kv.value();
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'time' vector");
TORCH_CHECK(time.is_contiguous(), "Non-contiguous 'time'");
}
}
if (seed_time_dict.has_value()) {
for (const auto& kv : seed_time_dict.value()) {
const at::Tensor& seed_time = kv.value();
TORCH_CHECK(seed_time.is_contiguous(), "Non-contiguous 'seed_time'");
}
}

Expand Down Expand Up @@ -343,9 +367,10 @@ sample(const std::vector<node_type>& node_types,
phmap::flat_hash_map<edge_type, NeighborSamplerImpl> sampler_dict;
phmap::flat_hash_map<node_type, std::pair<size_t, size_t>> slice_dict;
std::vector<scalar_t> seed_times;

for (const auto& k : node_types) {
const auto N = num_nodes_dict.count(k) > 0 ? num_nodes_dict.at(k) : 0;
sampled_nodes_dict[k]; // Initialize empty vector;
sampled_nodes_dict[k]; // Initialize empty vector.
mapper_dict.insert({k, Mapper<node_t, scalar_t>(N)});
slice_dict[k] = {0, 0};
}
Expand All @@ -359,7 +384,7 @@ sample(const std::vector<node_type>& node_types,
temporal_strategy)});
}

scalar_t i = 0;
scalar_t batch_idx = 0;
for (const auto& kv : seed_dict) {
const at::Tensor& seed = kv.value();
slice_dict[kv.key()] = {0, seed.size(0)};
Expand All @@ -371,20 +396,22 @@ sample(const std::vector<node_type>& node_types,
auto& sampled_nodes = sampled_nodes_dict.at(kv.key());
auto& mapper = mapper_dict.at(kv.key());
const auto seed_data = seed.data_ptr<scalar_t>();
if (!time_dict.has_value()) {
for (size_t j = 0; j < seed.numel(); j++) {
sampled_nodes.push_back({i, seed_data[j]});
mapper.insert({i, seed_data[j]});
i++;
for (size_t i = 0; i < seed.numel(); ++i) {
sampled_nodes.push_back({batch_idx, seed_data[i]});
mapper.insert({batch_idx, seed_data[i]});
batch_idx++;
}
if (seed_time_dict.has_value()) {
const at::Tensor& seed_time = seed_time_dict.value().at(kv.key());
const auto seed_time_data = seed_time.data_ptr<scalar_t>();
for (size_t i = 0; i < seed.numel(); ++i) {
seed_times.push_back(seed_time_data[i]);
}
} else {
} else if (time_dict.has_value()) {
const at::Tensor& time = time_dict.value().at(kv.key());
const auto time_data = time.data_ptr<scalar_t>();
for (size_t j = 0; j < seed.numel(); j++) {
sampled_nodes.push_back({i, seed_data[j]});
mapper.insert({i, seed_data[j]});
seed_times.push_back(time_data[seed_data[j]]);
i++;
for (size_t i = 0; i < seed.numel(); ++i) {
seed_times.push_back(time_data[seed_data[i]]);
}
}
}
Expand Down Expand Up @@ -412,7 +439,7 @@ sample(const std::vector<node_type>& node_types,
const at::Tensor& dst_time = time_dict.value().at(dst);
const auto dst_time_data = dst_time.data_ptr<scalar_t>();
for (size_t i = begin; i < end; ++i) {
const auto batch_idx = src_sampled_nodes[i].first;
batch_idx = src_sampled_nodes[i].first;
sampler.temporal_sample(/*global_src_node=*/src_sampled_nodes[i],
/*local_src_node=*/i, count,
seed_times[batch_idx], dst_time_data,
Expand Down Expand Up @@ -491,14 +518,15 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
bool csc,
bool replace,
bool directed,
bool disjoint,
std::string temporal_strategy,
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, rowptr, col,
seed, num_neighbors, time, csc, temporal_strategy);
seed, num_neighbors, time, seed_time, csc, temporal_strategy);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -513,6 +541,7 @@ hetero_neighbor_sample_kernel(
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
bool csc,
bool replace,
bool directed,
Expand All @@ -521,7 +550,8 @@ hetero_neighbor_sample_kernel(
bool return_edge_id) {
DISPATCH_SAMPLE(replace, directed, disjoint, return_edge_id, node_types,
edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, csc, temporal_strategy);
num_neighbors_dict, time_dict, seed_time_dict, csc,
temporal_strategy);
}

} // namespace
Expand Down
29 changes: 16 additions & 13 deletions pyg_lib/csrc/sampler/neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time,
const c10::optional<at::Tensor>& seed_time,
bool csc,
bool replace,
bool directed,
Expand All @@ -29,8 +30,9 @@ neighbor_sample(const at::Tensor& rowptr,
static auto op = c10::Dispatcher::singleton()
.findSchemaOrThrow("pyg::neighbor_sample", "")
.typed<decltype(neighbor_sample)>();
return op.call(rowptr, col, seed, num_neighbors, time, csc, replace, directed,
disjoint, temporal_strategy, return_edge_id);
return op.call(rowptr, col, seed, num_neighbors, time, seed_time, csc,
replace, directed, disjoint, temporal_strategy,
return_edge_id);
}

std::tuple<c10::Dict<rel_type, at::Tensor>,
Expand All @@ -45,6 +47,7 @@ hetero_neighbor_sample(
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
bool csc,
bool replace,
bool directed,
Expand All @@ -56,26 +59,26 @@ hetero_neighbor_sample(
.findSchemaOrThrow("pyg::hetero_neighbor_sample_cpu", "")
.typed<decltype(hetero_neighbor_sample)>();
return op.call(node_types, edge_types, rowptr_dict, col_dict, seed_dict,
num_neighbors_dict, time_dict, csc, replace, directed,
disjoint, temporal_strategy, return_edge_id);
num_neighbors_dict, time_dict, seed_time_dict, csc, replace,
directed, disjoint, temporal_strategy, return_edge_id);
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::neighbor_sample(Tensor rowptr, Tensor col, Tensor seed, int[] "
"num_neighbors, Tensor? time = None, bool csc = False, bool replace = "
"False, bool directed = True, bool disjoint = False, str "
"temporal_strategy = 'uniform', bool return_edge_id = True) -> (Tensor, "
"Tensor, Tensor, Tensor?)"));
"num_neighbors, Tensor? time = None, Tensor? seed_time = None, bool csc "
"= False, bool replace = False, bool directed = True, bool disjoint = "
"False, str temporal_strategy = 'uniform', bool return_edge_id = True) "
"-> (Tensor, Tensor, Tensor, Tensor?)"));
m.def(TORCH_SELECTIVE_SCHEMA(
"pyg::hetero_neighbor_sample(str[] node_types, (str, str, str)[] "
"edge_types, Dict(str, Tensor) rowptr_dict, Dict(str, Tensor) col_dict, "
"Dict(str, Tensor) seed_dict, Dict(str, int[]) num_neighbors_dict, "
"Dict(str, Tensor)? time_dict = None, bool csc = False, bool replace = "
"False, bool directed = True, bool disjoint = False, str "
"temporal_strategy = 'uniform', bool return_edge_id = True) -> "
"(Dict(str, Tensor), Dict(str, Tensor), Dict(str, Tensor), Dict(str, "
"Tensor)?)"));
"Dict(str, Tensor)? time_dict = None, Dict(str, Tensor)? seed_time_dict "
"= None, bool csc = False, bool replace = False, bool directed = True, "
"bool disjoint = False, str temporal_strategy = 'uniform', bool "
"return_edge_id = True) -> (Dict(str, Tensor), Dict(str, Tensor), "
"Dict(str, Tensor), Dict(str, Tensor)?)"));
}

} // namespace sampler
Expand Down
3 changes: 3 additions & 0 deletions pyg_lib/csrc/sampler/neighbor.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ neighbor_sample(const at::Tensor& rowptr,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const c10::optional<at::Tensor>& time = c10::nullopt,
const c10::optional<at::Tensor>& seed_time = c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
Expand All @@ -41,6 +42,8 @@ hetero_neighbor_sample(
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& time_dict =
c10::nullopt,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict =
c10::nullopt,
bool csc = false,
bool replace = false,
bool directed = true,
Expand Down
10 changes: 8 additions & 2 deletions pyg_lib/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def neighbor_sample(
seed: Tensor,
num_neighbors: List[int],
time: Optional[Tensor] = None,
seed_time: Optional[Tensor] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
Expand Down Expand Up @@ -44,6 +45,9 @@ def neighbor_sample(
If used, the :obj:`col` vector needs to be sorted according to time
within individual neighborhoods. Requires :obj:`disjoint=True`.
(default: :obj:`None`)
seed_time (torch.Tensor, optional): Optional values to override the
timestamp for seed nodes. If not set, will use timestamps in
:obj:`time` as default for seed nodes. (default: :obj:`None`)
csc (bool, optional): If set to :obj:`True`, assumes that the graph is
given in CSC format :obj:`(colptr, row)`. (default: :obj:`False`)
replace (bool, optional): If set to :obj:`True`, will sample with
Expand All @@ -66,8 +70,8 @@ def neighbor_sample(
In addition, may return the indices of edges of the original graph.
"""
return torch.ops.pyg.neighbor_sample(rowptr, col, seed, num_neighbors,
time, csc, replace, directed,
disjoint, temporal_strategy,
time, seed_time, csc, replace,
directed, disjoint, temporal_strategy,
return_edge_id)


Expand All @@ -77,6 +81,7 @@ def hetero_neighbor_sample(
seed_dict: Dict[NodeType, Tensor],
num_neighbors_dict: Dict[EdgeType, List[int]],
time_dict: Optional[Dict[NodeType, Tensor]] = None,
seed_time_dict: Optional[Dict[NodeType, Tensor]] = None,
csc: bool = False,
replace: bool = False,
directed: bool = True,
Expand Down Expand Up @@ -119,6 +124,7 @@ def hetero_neighbor_sample(
seed_dict,
num_neighbors_dict,
time_dict,
seed_time_dict,
csc,
replace,
directed,
Expand Down
14 changes: 8 additions & 6 deletions test/csrc/sampler/test_neighbor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ TEST(WithoutReplacementNeighborTest, BasicAssertions) {
auto out = pyg::sampler::neighbor_sample(
/*rowptr=*/std::get<0>(graph),
/*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt,
/*csc=*/false, /*replace=*/false);
/*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false);

auto expected_row = at::tensor({0, 1, 2, 3}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
Expand All @@ -60,7 +60,7 @@ TEST(WithReplacementNeighborTest, BasicAssertions) {
auto out = pyg::sampler::neighbor_sample(
/*rowptr=*/std::get<0>(graph),
/*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt,
/*csc=*/false, /*replace=*/true);
/*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/true);

auto expected_row = at::tensor({0, 1, 2, 3}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
Expand All @@ -82,7 +82,8 @@ TEST(DisjointNeighborTest, BasicAssertions) {
auto out = pyg::sampler::neighbor_sample(
/*rowptr=*/std::get<0>(graph),
/*col=*/std::get<1>(graph), seed, num_neighbors, /*time=*/c10::nullopt,
/*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true);
/*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false,
/*directed=*/true, /*disjoint=*/true);

auto expected_row = at::tensor({0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5}, options);
EXPECT_TRUE(at::equal(std::get<0>(out), expected_row));
Expand Down Expand Up @@ -111,7 +112,8 @@ TEST(TemporalNeighborTest, BasicAssertions) {

auto out1 = pyg::sampler::neighbor_sample(
rowptr, col, seed, /*num_neighbors=*/{2, 2}, /*time=*/time,
/*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true);
/*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false,
/*directed=*/true, /*disjoint=*/true);

// Expect only the earlier neighbors to be sampled:
auto expected_row = at::tensor({0, 1, 2, 3}, options);
Expand All @@ -126,8 +128,8 @@ TEST(TemporalNeighborTest, BasicAssertions) {

auto out2 = pyg::sampler::neighbor_sample(
rowptr, col, seed, /*num_neighbors=*/{1, 1}, /*time=*/time,
/*csc=*/false, /*replace=*/false, /*directed=*/true, /*disjoint=*/true,
/*temporal_strategy=*/"last");
/*seed_time=*/c10::nullopt, /*csc=*/false, /*replace=*/false,
/*directed=*/true, /*disjoint=*/true, /*temporal_strategy=*/"last");

EXPECT_TRUE(at::equal(std::get<0>(out1), std::get<0>(out2)));
EXPECT_TRUE(at::equal(std::get<1>(out1), std::get<1>(out2)));
Expand Down