Add softmax_csr implementation
#264
Conversation
Force-pushed 2b91285 to c56a93b
Codecov Report

```diff
@@           Coverage Diff            @@
##           master     #264    +/-   ##
==========================================
+ Coverage   85.65%   86.19%   +0.54%
==========================================
  Files          32       34       +2
  Lines        1115     1188      +73
==========================================
+ Hits          955     1024      +69
- Misses        160      164       +4
==========================================
```

View full report in Codecov by Sentry.
@pyg-team/intel-team Please take a look.
pyg_lib/ops/__init__.py (Outdated)

```python
            [0.0598, 0.2923, 0.1206, 0.0921],
            [0.7792, 0.3502, 0.1638, 0.2145]])
    """
    if src.dim() != 2 or not src.is_cpu or ptr is None or dim != 0:
```
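The guard above routes inputs that the fused kernel does not support to another code path. A generic sketch of this dispatch pattern (hypothetical `softmax_with_fallback` name, not the actual pyg-lib wrapper) could look like:

```python
import torch
from typing import Optional

def softmax_with_fallback(src: torch.Tensor,
                          ptr: Optional[torch.Tensor] = None,
                          dim: int = 0) -> torch.Tensor:
    # Hypothetical dispatch sketch: inputs outside the supported case
    # (2D CPU tensor with CSR groups over dim 0) fall back to a plain
    # PyTorch softmax over `dim`.
    if src.dim() != 2 or not src.is_cpu or ptr is None or dim != 0:
        return torch.softmax(src, dim=dim)  # generic fallback path
    # The specialized CSR kernel would be invoked here.
    raise NotImplementedError("fused softmax_csr kernel goes here")
```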
I wonder why `ptr` is optional, since if you don't provide it, you get an error.
I wanted to make the API in its final form; otherwise, each change here would require a corresponding change in pytorch_geometric. I'll add support for `index` in the near future.
@kgajdamo, after rethinking your suggestion, I decided to change the API and create a specialized `softmax_csr` operation that accepts `ptr` only. Rationale: `torch.compile` gives nice results for softmax with groups defined via `index`, so I don't see a reason to have a specialized kernel for that option.
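For reference, this is what a softmax with groups defined via `ptr` computes: a numerically stable softmax applied independently within each segment of rows delimited by consecutive CSR offsets. A pure-PyTorch sketch (hypothetical `softmax_csr_reference` helper, not the pyg-lib kernel itself) could look like:

```python
import torch

def softmax_csr_reference(src: torch.Tensor, ptr: torch.Tensor) -> torch.Tensor:
    # Apply softmax over dim 0 independently within each group of rows
    # delimited by consecutive offsets in `ptr` (CSR-style group boundaries).
    out = torch.empty_like(src)
    for start, end in zip(ptr[:-1].tolist(), ptr[1:].tolist()):
        group = src[start:end]
        # Subtract the per-column max for numerical stability.
        group = group - group.max(dim=0, keepdim=True).values
        exp = group.exp()
        out[start:end] = exp / exp.sum(dim=0, keepdim=True)
    return out
```

Each column of a group sums to one, since the softmax is taken over the rows of that group.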
Force-pushed 5dad205 to 167793a
softmax_csr implementation
Force-pushed 167793a to 6703779
@kgajdamo @rusty1s Please take a look. I made the softmax implementation a bit more general, so now it covers any
Hi @DamianSzwichtenberg, this PR looks good to me. The overall structure of the softmax kernel with sparse input is similar to that of the dense-input softmax kernel in PyTorch. With the sparsity, the performance boost comes from parallelism, right? And will this PR be upstreamed to PyTorch later? There is no SparseCsr support for softmax in PyTorch yet.
These kernels differ quite a bit. In
There are no plans to upstream this operation to PyTorch. As above, |
Looks good to me.
Force-pushed 68a2723 to cb4bb68
```python
class Softmax(torch.autograd.Function):
    @staticmethod
    def forward(
```
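The excerpt above is truncated. A minimal self-contained sketch of such an autograd `Function` (using pure-PyTorch math over a single group and a hypothetical `GroupedSoftmax` name, not the pyg-lib C++ kernel) could look like:

```python
import torch

class GroupedSoftmax(torch.autograd.Function):
    # Hypothetical sketch illustrating the Function structure: a softmax
    # over dim 0 with a single group, with a hand-written backward pass.
    @staticmethod
    def forward(ctx, src: torch.Tensor) -> torch.Tensor:
        out = torch.softmax(src, dim=0)
        ctx.save_for_backward(out)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor) -> torch.Tensor:
        (out,) = ctx.saved_tensors
        # Softmax Jacobian-vector product: out * (g - sum(g * out)).
        return out * (grad_out - (grad_out * out).sum(dim=0, keepdim=True))
```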
Can we define the autograd function directly in C++?
Should be possible, will check.
Change available at #282
This PR uses the optimized `softmax_csr` operation (introduced in [pyg-lib @ 264](pyg-team/pyg-lib#264)) when the input is a CPU tensor and the softmax groups are defined via `ptr`.
This PR adds the forward and backward implementation of the sparse softmax operation, as defined here.
In the `pytorch_geometric` implementation we cannot take advantage of model compilation when groups are defined via `ptr`. The `softmax_csr` operation introduced here provides a well-performing kernel for such a scenario.

Performance boost (achieved on a 28C, single-socket machine):
- ~7x for the forward pass
- ~8x for the backward pass

Additionally, GAT training time was reduced by ~5%.
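The two group representations mentioned above are interchangeable: CSR offsets (`ptr`) can be expanded into per-row group ids (`index`) and back. A small sketch of that conversion (hypothetical helper names, assuming `index` is sorted) could look like:

```python
import torch

def ptr_to_index(ptr: torch.Tensor) -> torch.Tensor:
    # Expand CSR offsets, e.g. [0, 2, 5], into per-row group ids
    # [0, 0, 1, 1, 1] by repeating each group id by its segment length.
    return torch.arange(ptr.numel() - 1).repeat_interleave(ptr.diff())

def index_to_ptr(index: torch.Tensor, num_groups: int) -> torch.Tensor:
    # Inverse direction: build offsets from cumulative group sizes.
    sizes = torch.bincount(index, minlength=num_groups)
    return torch.cat([torch.zeros(1, dtype=torch.long), sizes.cumsum(0)])
```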