Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 23 additions & 23 deletions pyg_lib/csrc/classes/cpu/neighbor_sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ struct NeighborSampler : torch::CustomClassHolder {
public:
NeighborSampler(const at::Tensor& rowptr,
const at::Tensor& col,
const std::optional<at::Tensor>& edge_weight,
const std::optional<at::Tensor>& node_time,
const std::optional<at::Tensor>& edge_time)
const c10::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time)
: rowptr_(rowptr),
col_(col),
edge_weight_(edge_weight),
Expand All @@ -24,13 +24,13 @@ struct NeighborSampler : torch::CustomClassHolder {
std::tuple<at::Tensor, // row
at::Tensor, // col
at::Tensor, // node_id
std::optional<at::Tensor>, // edge_id,
std::optional<at::Tensor>, // batch,
c10::optional<at::Tensor>, // edge_id,
c10::optional<at::Tensor>, // batch,
std::vector<int64_t>, // num_sampled_nodes,
std::vector<int64_t>> // num_sampled_edges,
sample(const std::vector<int64_t>& num_neighbors,
const at::Tensor& seed_node,
const std::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& seed_time,
bool disjoint = false,
std::string temporal_strategy = "uniform",
bool return_edge_id = true) {
Expand All @@ -49,9 +49,9 @@ struct NeighborSampler : torch::CustomClassHolder {
private:
const at::Tensor& rowptr_;
const at::Tensor& col_;
const std::optional<at::Tensor>& edge_weight_;
const std::optional<at::Tensor>& node_time_;
const std::optional<at::Tensor>& edge_time_;
const c10::optional<at::Tensor>& edge_weight_;
const c10::optional<at::Tensor>& node_time_;
const c10::optional<at::Tensor>& edge_time_;
};

struct HeteroNeighborSampler : torch::CustomClassHolder {
Expand All @@ -61,9 +61,9 @@ struct HeteroNeighborSampler : torch::CustomClassHolder {
const std::vector<edge_type>& edge_types,
const c10::Dict<rel_type, at::Tensor>& rowptr,
const c10::Dict<rel_type, at::Tensor>& col,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight,
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time)
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time)
: node_types_(node_types),
edge_types_(edge_types),
rowptr_(rowptr),
Expand All @@ -75,13 +75,13 @@ struct HeteroNeighborSampler : torch::CustomClassHolder {
std::tuple<c10::Dict<rel_type, at::Tensor>, // row
c10::Dict<rel_type, at::Tensor>, // col
c10::Dict<node_type, at::Tensor>, // node_id
std::optional<c10::Dict<rel_type, at::Tensor>>, // edge_id
std::optional<c10::Dict<node_type, at::Tensor>>, // batch
c10::optional<c10::Dict<rel_type, at::Tensor>>, // edge_id
c10::optional<c10::Dict<node_type, at::Tensor>>, // batch
c10::Dict<node_type, std::vector<int64_t>>, // num_sampled_nodes
c10::Dict<rel_type, std::vector<int64_t>>> // num_sampled_edges
sample(const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors,
const c10::Dict<node_type, at::Tensor>& seed_node,
const std::optional<c10::Dict<node_type, at::Tensor>>& seed_time,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time,
bool disjoint = false,
std::string temporal_strategy = "uniform",
bool return_edge_id = true) {
Expand All @@ -102,26 +102,26 @@ struct HeteroNeighborSampler : torch::CustomClassHolder {
const std::vector<edge_type>& edge_types_;
const c10::Dict<rel_type, at::Tensor>& rowptr_;
const c10::Dict<rel_type, at::Tensor>& col_;
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_;
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time_;
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_;
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_;
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_;
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_;
};

} // namespace

TORCH_LIBRARY_FRAGMENT(pyg, m) {
m.class_<NeighborSampler>("NeighborSampler")
.def(torch::init<at::Tensor&, at::Tensor&, std::optional<at::Tensor>,
std::optional<at::Tensor>, std::optional<at::Tensor>>())
.def(torch::init<at::Tensor&, at::Tensor&, c10::optional<at::Tensor>,
c10::optional<at::Tensor>, c10::optional<at::Tensor>>())
.def("sample", &NeighborSampler::sample);

m.class_<HeteroNeighborSampler>("HeteroNeighborSampler")
.def(torch::init<std::vector<node_type>, std::vector<edge_type>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
std::optional<c10::Dict<rel_type, at::Tensor>>,
std::optional<c10::Dict<node_type, at::Tensor>>,
std::optional<c10::Dict<rel_type, at::Tensor>>>())
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::optional<c10::Dict<node_type, at::Tensor>>,
c10::optional<c10::Dict<rel_type, at::Tensor>>>())
.def("sample", &HeteroNeighborSampler::sample);
}

Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/partition/cpu/metis_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace {
at::Tensor metis_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
int64_t num_partitions,
const std::optional<at::Tensor>& node_weight,
const std::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_weight,
const c10::optional<at::Tensor>& edge_weight,
bool recursive) {
#ifdef _WIN32
TORCH_INTERNAL_ASSERT(false, "METIS not yet supported on Windows");
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/partition/metis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ namespace partition {
at::Tensor metis(const at::Tensor& rowptr,
const at::Tensor& col,
int64_t num_partitions,
const std::optional<at::Tensor>& node_weight,
const std::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_weight,
const c10::optional<at::Tensor>& edge_weight,
bool recursive) {
at::TensorArg rowptr_t{rowptr, "rowtpr", 1};
at::TensorArg col_t{col, "col", 1};
Expand Down
2 changes: 1 addition & 1 deletion pyg_lib/csrc/random/cpu/biased_sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace pyg {
namespace random {

std::optional<at::Tensor> biased_to_cdf(const at::Tensor& rowptr,
c10::optional<at::Tensor> biased_to_cdf(const at::Tensor& rowptr,
at::Tensor& bias,
bool inplace) {
TORCH_CHECK(rowptr.is_cpu(), "'rowptr' must be a CPU tensor");
Expand Down
10 changes: 5 additions & 5 deletions pyg_lib/csrc/sampler/cpu/dist_merge_outputs_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace {
template <bool disjoint>
std::tuple<at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_outputs(
const std::vector<at::Tensor>& node_ids,
Expand All @@ -25,10 +25,10 @@ merge_outputs(
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const std::optional<at::Tensor>& batch) {
const c10::optional<at::Tensor>& batch) {
at::Tensor out_node_id;
at::Tensor out_edge_id;
std::optional<at::Tensor> out_batch = c10::nullopt;
c10::optional<at::Tensor> out_batch = c10::nullopt;

auto offset = num_neighbors;

Expand Down Expand Up @@ -140,7 +140,7 @@ merge_outputs(

std::tuple<at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs_kernel(
const std::vector<at::Tensor>& node_ids,
Expand All @@ -150,7 +150,7 @@ merge_sampler_outputs_kernel(
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const std::optional<at::Tensor>& batch,
const c10::optional<at::Tensor>& batch,
bool disjoint) {
DISPATCH_MERGE_OUTPUTS(
disjoint, node_ids, edge_ids, cumsum_neighbors_per_node, partition_ids,
Expand Down
8 changes: 4 additions & 4 deletions pyg_lib/csrc/sampler/cpu/dist_relabel_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ std::tuple<at::Tensor, at::Tensor> relabel(
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const std::optional<at::Tensor>& batch,
const c10::optional<at::Tensor>& batch,
const bool csc) {
if (disjoint) {
TORCH_CHECK(batch.has_value(),
Expand Down Expand Up @@ -103,7 +103,7 @@ relabel(
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
const bool csc) {
c10::Dict<rel_type, at::Tensor> out_row_dict, out_col_dict;

Expand Down Expand Up @@ -281,7 +281,7 @@ std::tuple<at::Tensor, at::Tensor> relabel_neighborhood_kernel(
const at::Tensor& sampled_nodes_with_duplicates,
const std::vector<int64_t>& num_sampled_neighbors_per_node,
const int64_t num_nodes,
const std::optional<at::Tensor>& batch,
const c10::optional<at::Tensor>& batch,
bool csc,
bool disjoint) {
DISPATCH_RELABEL(disjoint, seed, sampled_nodes_with_duplicates,
Expand All @@ -297,7 +297,7 @@ hetero_relabel_neighborhood_kernel(
const c10::Dict<rel_type, std::vector<std::vector<int64_t>>>&
num_sampled_neighbors_per_node_dict,
const c10::Dict<node_type, int64_t>& num_nodes_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& batch_dict,
bool csc,
bool disjoint) {
c10::Dict<rel_type, at::Tensor> out_row_dict, out_col_dict;
Expand Down
56 changes: 28 additions & 28 deletions pyg_lib/csrc/sampler/cpu/neighbor_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,12 +136,12 @@ class NeighborSampler {
dst_mapper, generator, out_global_dst_nodes);
}

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>>
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>
get_sampled_edges(bool csc = false) {
TORCH_CHECK(save_edges, "No edges have been stored")
const auto row = pyg::utils::from_vector(sampled_rows_);
const auto col = pyg::utils::from_vector(sampled_cols_);
std::optional<at::Tensor> edge_id = c10::nullopt;
c10::optional<at::Tensor> edge_id = c10::nullopt;
if (save_edge_ids) {
edge_id = pyg::utils::from_vector(sampled_edge_ids_);
}
Expand Down Expand Up @@ -330,18 +330,18 @@ template <bool replace,
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>,
std::vector<int64_t>>
sample(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const std::optional<at::Tensor>& node_time,
const std::optional<at::Tensor>& edge_time,
const std::optional<at::Tensor>& seed_time,
const std::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!node_time.has_value() || disjoint,
Expand Down Expand Up @@ -373,7 +373,7 @@ sample(const at::Tensor& rowptr,
"Biased edge temporal sampling not yet supported");

at::Tensor out_row, out_col, out_node_id;
std::optional<at::Tensor> out_edge_id = c10::nullopt;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;
std::vector<int64_t> num_sampled_nodes_per_hop;
std::vector<int64_t> num_sampled_edges_per_hop;
std::vector<int64_t> cumsum_neighbors_per_node =
Expand Down Expand Up @@ -516,7 +516,7 @@ template <bool replace,
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
std::optional<c10::Dict<rel_type, at::Tensor>>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
sample(const std::vector<node_type>& node_types,
Expand All @@ -525,10 +525,10 @@ sample(const std::vector<node_type>& node_types,
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
const bool csc,
const std::string temporal_strategy) {
TORCH_CHECK(!node_time_dict.has_value() || disjoint,
Expand Down Expand Up @@ -576,7 +576,7 @@ sample(const std::vector<node_type>& node_types,

c10::Dict<rel_type, at::Tensor> out_row_dict, out_col_dict;
c10::Dict<node_type, at::Tensor> out_node_id_dict;
std::optional<c10::Dict<node_type, at::Tensor>> out_edge_id_dict;
c10::optional<c10::Dict<node_type, at::Tensor>> out_edge_id_dict;
if (return_edge_id) {
out_edge_id_dict = c10::Dict<rel_type, at::Tensor>();
} else {
Expand Down Expand Up @@ -892,17 +892,17 @@ sample(const std::vector<node_type>& node_types,
std::tuple<at::Tensor,
at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>,
std::vector<int64_t>>
neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const std::vector<int64_t>& num_neighbors,
const std::optional<at::Tensor>& node_time,
const std::optional<at::Tensor>& edge_time,
const std::optional<at::Tensor>& seed_time,
const std::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
Expand All @@ -921,7 +921,7 @@ neighbor_sample_kernel(const at::Tensor& rowptr,
std::tuple<c10::Dict<rel_type, at::Tensor>,
c10::Dict<rel_type, at::Tensor>,
c10::Dict<node_type, at::Tensor>,
std::optional<c10::Dict<rel_type, at::Tensor>>,
c10::optional<c10::Dict<rel_type, at::Tensor>>,
c10::Dict<node_type, std::vector<int64_t>>,
c10::Dict<rel_type, std::vector<int64_t>>>
hetero_neighbor_sample_kernel(
Expand All @@ -931,10 +931,10 @@ hetero_neighbor_sample_kernel(
const c10::Dict<rel_type, at::Tensor>& col_dict,
const c10::Dict<node_type, at::Tensor>& seed_dict,
const c10::Dict<rel_type, std::vector<int64_t>>& num_neighbors_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const std::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const std::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& node_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_time_dict,
const c10::optional<c10::Dict<node_type, at::Tensor>>& seed_time_dict,
const c10::optional<c10::Dict<rel_type, at::Tensor>>& edge_weight_dict,
bool csc,
bool replace,
bool directed,
Expand All @@ -952,10 +952,10 @@ dist_neighbor_sample_kernel(const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& seed,
const int64_t num_neighbors,
const std::optional<at::Tensor>& node_time,
const std::optional<at::Tensor>& edge_time,
const std::optional<at::Tensor>& seed_time,
const std::optional<at::Tensor>& edge_weight,
const c10::optional<at::Tensor>& node_time,
const c10::optional<at::Tensor>& edge_time,
const c10::optional<at::Tensor>& seed_time,
const c10::optional<at::Tensor>& edge_weight,
bool csc,
bool replace,
bool directed,
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/sampler/cpu/subgraph_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace sampler {

namespace {

std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> subgraph_kernel(
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>> subgraph_kernel(
const at::Tensor& rowptr,
const at::Tensor& col,
const at::Tensor& nodes,
Expand All @@ -21,7 +21,7 @@ std::tuple<at::Tensor, at::Tensor, std::optional<at::Tensor>> subgraph_kernel(

const auto out_rowptr = rowptr.new_empty({nodes.size(0) + 1});
at::Tensor out_col;
std::optional<at::Tensor> out_edge_id = c10::nullopt;
c10::optional<at::Tensor> out_edge_id = c10::nullopt;

AT_DISPATCH_INTEGRAL_TYPES(nodes.scalar_type(), "subgraph_kernel", [&] {
const auto rowptr_data = rowptr.data_ptr<scalar_t>();
Expand Down
4 changes: 2 additions & 2 deletions pyg_lib/csrc/sampler/dist_merge_outputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ namespace sampler {

std::tuple<at::Tensor,
at::Tensor,
std::optional<at::Tensor>,
c10::optional<at::Tensor>,
std::vector<int64_t>>
merge_sampler_outputs(
const std::vector<at::Tensor>& node_ids,
Expand All @@ -20,7 +20,7 @@ merge_sampler_outputs(
const std::vector<int64_t>& partition_orders,
const int64_t num_partitions,
const int64_t num_neighbors,
const std::optional<at::Tensor>& batch,
const c10::optional<at::Tensor>& batch,
bool disjoint) {
std::vector<at::TensorArg> node_ids_args;
std::vector<at::TensorArg> edge_ids_args;
Expand Down
Loading
Loading