
Add Scaled Dot Product Attention (SDPA) from PyTorch #5994

Merged
sw005320 merged 5 commits into espnet:master from pyf98:owsm-test
Dec 28, 2024

Conversation

@pyf98
Collaborator

@pyf98 pyf98 commented Dec 25, 2024

What?

This PR adds the SDPA implementation from PyTorch. It supports several implementations of the attention mechanism (including flash attention), some of which are more efficient than the default implementation.

See https://pytorch.org/docs/2.5/generated/torch.nn.functional.scaled_dot_product_attention.html for more information.

Note: SDPA does not appear to support flash attention with variable-length inputs, so we still need to keep the previous flash attention implementation.
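For reference, the computation that `torch.nn.functional.scaled_dot_product_attention` performs can be sketched in plain NumPy (a minimal illustration of the math only; the actual PyTorch op fuses these steps into a single kernel and may dispatch to flash attention or other efficient backends):

```python
import numpy as np

def scaled_dot_product_attention(q, k, v, mask=None):
    """Reference sketch of SDPA: softmax(q @ k^T / sqrt(d_k)) @ v.

    q, k: arrays of shape (..., seq, d_k); v: shape (..., seq, d_v).
    mask: optional boolean array; False positions are excluded.
    """
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -2, -1) / np.sqrt(d_k)
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)
    # numerically stable softmax over the key dimension
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v
```

With a single key, the softmax weight is 1 and the output equals the value vector; with a boolean mask, masked positions receive zero attention weight.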

@mergify mergify bot added the ESPnet1 label Dec 25, 2024
@sw005320 sw005320 added the Enhancement Enhancement label Dec 26, 2024
@sw005320 sw005320 added this to the v.202503 milestone Dec 26, 2024
@sw005320
Contributor

Thanks!
Can you fix some CI errors?

@codecov

codecov bot commented Dec 26, 2024

Codecov Report

Attention: Patch coverage is 37.50000% with 5 lines in your changes missing coverage. Please review.

Project coverage is 56.08%. Comparing base (9f8c41a) to head (805ed6d).
Report is 7 commits behind head on master.

Files with missing lines Patch % Lines
...pnet/nets/pytorch_backend/transformer/attention.py 37.50% 5 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #5994       +/-   ##
===========================================
+ Coverage   28.45%   56.08%   +27.63%     
===========================================
  Files         829      830        +1     
  Lines       78052    78031       -21     
===========================================
+ Hits        22208    43763    +21555     
+ Misses      55844    34268    -21576     
Flag                       Coverage Δ
test_integration_espnetez  38.21% <37.50%> (?)
test_python_espnet1        ?
test_python_espnet2        51.59% <37.50%> (?)
test_python_espnetez       ?
test_utils                 20.66% <37.50%> (+<0.01%) ⬆️

Flags with carried forward coverage won't be shown.


@sw005320
Contributor

Can you check how to handle this library?
Our CI complains about it:
https://github.com/espnet/espnet/actions/runs/12507998828/job/34895445316?pr=5994

@sw005320 sw005320 merged commit 081a376 into espnet:master Dec 28, 2024
@sw005320
Contributor

Thx!

@pyf98 pyf98 deleted the owsm-test branch December 30, 2024 01:17
Shikhar-S pushed a commit to Shikhar-S/espnet that referenced this pull request Mar 13, 2025
Add Scaled Dot Product Attention (SDPA) from PyTorch