Hetero subgraph with dispatching #43
Conversation
Codecov Report

@@           Coverage Diff           @@
##           master      #43   +/-   ##
==========================================
- Coverage   97.27%   96.51%   -0.76%
==========================================
  Files          10       12       +2
  Lines         220      287      +67
==========================================
+ Hits          214      277      +63
- Misses          6       10       +4

for more information, see https://pre-commit.ci
void fill(const scalar_t* nodes_data, const scalar_t size) {
  if (use_vec) {
    for (scalar_t i = 0; i < size; ++i)
    for (scalar_t i = 0; i < size; ++i) {
Let me post my question here. From the documents I have read, my understanding is that scalar_t covers float, double, int32, and int64 at compile time. But in many of our use cases we are iterating over integers. How does PyTorch avoid compiling the float instantiations of these functions? Is there a better way to be more specific about the data types here?
There are some helper functions like is_integral for a dtype, but IMO that is mostly runtime checking. We could also use some STL type traits for compile-time checks.
The AT_DISPATCH_INTEGRAL_TYPES call handles which types scalar_t can take (during compile time).
});

return std::make_tuple(out_rowptr, out_col, out_edge_id);
return subgraph_bipartite(rowptr, col, nodes, nodes, return_edge_id);
The code structure looks a little weird to me, because csrc/sampler/cpu/subgraph_kernel exists to register TORCH_LIBRARY_IMPL, yet it uses a general implementation in csrc/sampler/subgraph.cpp. How about reorganizing the code like this:

csrc
- ops
# all ops exposed to PyTorch.
- sampler
# all general graph operations.
- sampler

We don't need to refactor the code structure now, but I want to hear your opinion.
nvm, it seems subgraph.cpp also defines a library. Why not merge them together, since sampler/subgraph.cpp also runs on CPU only?
We could follow the style in other PyG repos: put CPU/GPU-specific implementations in separate folders and provide a common interface in a higher-level directory.
auto res = subgraph_with_mapper<scalar_t>(rowptr, col, src_nodes,
                                          mapper, return_edge_id);
out_rowptr = std::get<0>(res);
Or maybe we could do std::tie(out_rowptr, out_col, out_edge_id) = res?
+1
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
for (const auto& kv : rowptr) {
  const auto& edge_type = kv.key();
  bool pass = filter_args_by_edge(edge_type, src_nodes_arg, dst_nodes_arg,
I'd still prefer

pass = src_nodes_args.filter_by_edge(edge_type) && dst_nodes_args.filter_by_edge(edge_type) && edge_id_arg.filter_by_edge(edge_type)

or, from an efficiency point of view:

auto dst = get_dst(edge_type);
auto src = get_src(edge_type);
bool pass = return_edge_id.counts(edge_type) > 0 && src_nodes.counts(src) > 0 && dst_nodes.counts(dst) > 0;
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
const auto& r = rowptr.at(edge_type);
const auto& c = col.at(edge_type);
res.insert(edge_type,
           subgraph_bipartite(r, c, std::get<0>(vals), std::get<1>(vals),
and here it would just be
subgraph_bipartite(r, c, src_nodes.at(src), dst_nodes.at(dst), return_edge_id.at(edge_type));
pyg_lib/csrc/sampler/subgraph.cpp
Outdated
return op.call(rowptr, col, nodes, return_edge_id);
}

c10::Dict<utils::edge_t,
I actually would have expected that we return a tuple of dictionaries, similar to how the input looks.
offset++;
}
AT_DISPATCH_INTEGRAL_TYPES(
    nodes.scalar_type(), "subgraph_kernel_with_mapper", [&] {
Can we make this a one-liner again?
TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph"), TORCH_FN(subgraph_kernel));
  m.impl(TORCH_SELECTIVE_NAME("pyg::subgraph_bipartite"),
Any reason we want to expose that? Looks more like an internal function to me.
If the user wants to build a subgraph of a bipartite graph, they can use it.
}

c10::Dict<utils::EdgeType,
          std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
IMO, the output should be a tuple of dictionaries (similar to the input).
if (pass) {
  const auto& r = rowptr.at(edge_type);
  const auto& c = col.at(edge_type);
  res.insert(edge_type, subgraph_bipartite(
Shouldn't we use the mapper here? Otherwise, we will re-map across every edge type.
Yes, it has a cost, but the mapper is more read-intensive. I will add a TODO here.
inline NodeType get_dst(const EdgeType& e) {
  return e.substr(e.find_last_of(SPLIT_TOKEN) + 1);
}
We could also add a function that maps tuples to strings and vice versa.
Good idea.
std::tuple<at::Tensor, at::Tensor, c10::optional<at::Tensor>>>
hetero_subgraph(const utils::EdgeTensorDict& rowptr,
                const utils::EdgeTensorDict& col,
                const utils::NodeTensorDict& src_nodes,
Not sure why we have both src_nodes and dst_nodes. IMO, these can be safely merged as in https://pytorch-geometric.readthedocs.io/en/latest/modules/data.html#torch_geometric.data.HeteroData.subgraph.
Separating src and dst just gives some flexibility. We could also have the merged API, though.
Co-authored-by: Matthias Fey <[email protected]>