@DamianSzwichtenberg
Member

@DamianSzwichtenberg DamianSzwichtenberg commented Jan 18, 2023

In GNN workloads we use the torch.sort and torch.argsort operations, for example when we create the permutation that converts CSR to CSC format. This case is slow in PyTorch because 1-dimensional sorting runs sequentially. In our case we generally sort positive integer values in ascending order, which is a good fit for a linear-time sorting algorithm such as radix_sort. That algorithm is already implemented in fbgemm, a PyTorch subproject. This PR ports the radix_sort operation to pyg-lib in the form of a new operation called index_sort.
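To illustrate the idea, here is a minimal pure-Python sketch of a least-significant-digit radix sort that returns both the sorted values and the permutation. It is not the fbgemm or pyg-lib implementation; the function name `radix_index_sort` and the `bits_per_pass` parameter are hypothetical and chosen only for this example.

```python
from typing import List, Tuple


def radix_index_sort(values: List[int],
                     bits_per_pass: int = 8) -> Tuple[List[int], List[int]]:
    """Stable LSD radix sort for non-negative integers (illustrative only).

    Returns (sorted_values, perm) such that sorted_values[k] == values[perm[k]].
    """
    if any(v < 0 for v in values):
        raise ValueError("this sketch only handles non-negative integers")

    n = len(values)
    perm = list(range(n))
    if n == 0:
        return [], []

    radix = 1 << bits_per_pass
    mask = radix - 1
    max_value = max(values)

    shift = 0
    # One pass per digit; the number of passes depends only on max_value.
    while (max_value >> shift) > 0:
        # Count how often each digit occurs.
        counts = [0] * radix
        for i in perm:
            counts[(values[i] >> shift) & mask] += 1

        # Exclusive prefix sum gives the first output slot of each digit.
        offsets = [0] * radix
        total = 0
        for d in range(radix):
            offsets[d] = total
            total += counts[d]

        # Scatter indices in their current order, which keeps the sort stable.
        out = [0] * n
        for i in perm:
            d = (values[i] >> shift) & mask
            out[offsets[d]] = i
            offsets[d] += 1

        perm = out
        shift += bits_per_pass

    return [values[i] for i in perm], perm


# Column indices of a small matrix stored in CSR (row-major) order.
col = [2, 0, 1, 0, 2]
sorted_col, perm = radix_index_sort(col)
# perm reorders the CSR entries into column-major (CSC) order.
```

Each pass is linear in the number of elements, and the number of passes is bounded by the bit width of the largest value, which is why this approach suits index tensors. Because the sort is stable, entries within the same column keep their row order.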

NOTE: In the future, similar optimization may be applied in PyTorch directly.

For reviewers:
@rusty1s could you please help with the LICENSE files? I followed some guidelines, but I feel it could be done more simply.

Profiling results (inference using ogb example, sort inside csr2csc was replaced with index_sort):

BEFORE:
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     torch_sparse::spmm_sum        60.63%       43.423s        60.63%       43.423s       14.474s             3  
                              aten::argsort         0.11%      79.531ms        31.13%       22.297s       11.149s             2  
                                 aten::sort        29.97%       21.464s        31.02%       22.218s       11.109s             2  
                                aten::index         1.68%        1.202s         1.93%        1.382s     138.222ms            10  
                               aten::linear         0.00%       1.217ms         1.76%        1.262s     420.825ms             3  
                               aten::matmul         0.00%     509.000us         1.76%        1.261s     420.399ms             3  
                                   aten::mm         1.76%        1.261s         1.76%        1.261s     420.230ms             3  
                                  aten::add         1.14%     816.860ms         1.14%     816.860ms     204.215ms             4  
                           aten::index_put_         0.00%     701.000us         1.11%     792.478ms     158.496ms             5  
                     aten::_index_put_impl_         0.68%     486.707ms         1.11%     791.777ms     158.355ms             5  
                               aten::arange         0.39%     276.297ms         0.77%     552.737ms      92.123ms             6  
                              aten::nonzero         0.67%     482.123ms         0.67%     483.326ms      80.554ms             6  
                                aten::copy_         0.67%     479.840ms         0.67%     479.840ms      79.973ms             6  
                                 aten::relu         0.00%       1.003ms         0.51%     365.745ms     182.873ms             2  
                            aten::clamp_min         0.51%     364.742ms         0.51%     364.742ms     182.371ms             2  
                  torch_scatter::gather_csr         0.47%     335.408ms         0.47%     335.472ms     335.472ms             1  
                                  aten::mul         0.36%     255.875ms         0.36%     255.875ms     127.938ms             2  
                torch_sparse::non_diag_mask         0.28%     197.670ms         0.30%     213.400ms     213.400ms             1  
                                  aten::max         0.19%     135.257ms         0.19%     135.301ms      27.060ms             5  
                                 aten::add_         0.11%      81.008ms         0.11%      81.008ms      40.504ms             2  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 71.621s

AFTER:
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     torch_sparse::spmm_sum        70.67%       43.037s        70.67%       43.037s       14.346s             3  
                              aten::argsort         0.05%      30.113ms        17.37%       10.580s       10.580s             1  
                                 aten::sort        17.06%       10.390s        17.32%       10.550s       10.550s             1  
                            pyg::index_sort         3.28%        1.999s         3.48%        2.122s        2.122s             1  
                                aten::index         2.00%        1.221s         2.33%        1.417s     141.723ms            10  
                               aten::linear         0.00%      34.000us         1.70%        1.036s     345.240ms             3  
                               aten::matmul         0.00%      25.000us         1.70%        1.036s     345.213ms             3  
                                   aten::mm         1.70%        1.036s         1.70%        1.036s     345.205ms             3  
                                  aten::add         1.21%     739.917ms         1.21%     739.917ms     184.979ms             4  
                           aten::index_put_        -0.00%    -148.000us         0.89%     543.333ms     108.667ms             5  
                     aten::_index_put_impl_         0.48%     290.251ms         0.89%     543.282ms     108.656ms             5  
                              aten::nonzero         0.74%     449.217ms         0.74%     449.265ms      74.877ms             6  
                                 aten::relu         0.00%      30.000us         0.53%     324.495ms     162.248ms             2  
                            aten::clamp_min         0.53%     324.465ms         0.53%     324.465ms     162.232ms             2  
                  torch_scatter::gather_csr         0.53%     320.724ms         0.53%     320.747ms     320.747ms             1  
                               aten::arange         0.17%     102.355ms         0.34%     204.799ms      34.133ms             6  
                torch_sparse::non_diag_mask         0.29%     176.260ms         0.31%     189.252ms     189.252ms             1  
                                aten::copy_         0.27%     164.309ms         0.27%     164.309ms      32.862ms             5  
                                  aten::max         0.25%     151.721ms         0.25%     151.766ms      25.294ms             6  
                                  aten::mul         0.20%     122.082ms         0.20%     122.082ms      61.041ms             2  
-------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 60.900s
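As a rough reading of the two tables, the short calculation below derives the end-to-end saving and a per-call comparison. All inputs are copied from the tables above; treating the average `aten::sort` call from the BEFORE run as comparable to the single `pyg::index_sort` call is an assumption, and each table reflects a single run.

```python
# Numbers copied from the BEFORE / AFTER tables (seconds).
before_total = 71.621      # Self CPU time total, BEFORE
after_total = 60.900       # Self CPU time total, AFTER
sort_avg_before = 11.109   # aten::sort, CPU time avg over 2 calls (BEFORE)
index_sort_after = 2.122   # pyg::index_sort, CPU total for 1 call (AFTER)

# End-to-end saving for the profiled inference run.
saved = before_total - after_total
saved_pct = 100.0 * saved / before_total

# Assumption: the replaced sort cost about as much as the average sort call.
per_call_speedup = sort_avg_before / index_sort_after

print(f"saved {saved:.3f}s ({saved_pct:.1f}% of total)")
print(f"index_sort vs. average aten::sort call: {per_call_speedup:.1f}x")
```

Under these assumptions the run is about 15% faster overall, and the replaced sort is roughly 5x faster. Only one of the two sort calls was replaced, so `aten::sort` still appears once in the AFTER table.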

@DamianSzwichtenberg
Member Author

@rusty1s PTAL.

@rusty1s
Member

rusty1s commented Jan 24, 2023

Yes, will do. Can you also ping someone from @pyg-team/intel-team for another review?

@DamianSzwichtenberg
Member Author

> Yes, will do. Can you also ping someone from @pyg-team/intel-team for another review?

Sure. @mszarma, @kgajdamo, @JakubPietrakIntel, please take a look.

Contributor

@mszarma mszarma left a comment

LGTM.

@DamianSzwichtenberg force-pushed the index-sort branch 2 times, most recently from a43fe6e to 4014d35, on January 25, 2023 12:24
@codecov-commenter

codecov-commenter commented Jan 25, 2023

Codecov Report

Merging #181 (e2feec5) into master (b9d8e60) will decrease coverage by 11.11%.
The diff coverage is 5.55%.

@@             Coverage Diff             @@
##           master     #181       +/-   ##
===========================================
- Coverage   93.28%   82.18%   -11.11%     
===========================================
  Files          23       26        +3     
  Lines         745      853      +108     
===========================================
+ Hits          695      701        +6     
- Misses         50      152      +102     
Impacted Files                                Coverage Δ
pyg_lib/csrc/ops/cpu/radix_sort.h             0.00% <0.00%> (ø)
pyg_lib/csrc/ops/cpu/index_sort_kernel.cpp    11.11% <11.11%> (ø)
pyg_lib/csrc/ops/index_sort.cpp               100.00% <100.00%> (ø)

@DamianSzwichtenberg DamianSzwichtenberg merged commit f980c23 into pyg-team:master Jan 25, 2023