Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 #149282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from

Conversation

eqy
Copy link
Collaborator

@eqy eqy commented Mar 16, 2025

cleanup tuple/tensor boilerplate in cuDNN SDPA, preparation for nested/ragged tensor backward

@eqy eqy added open source topic: not user facing topic category module: sdpa All things related to torch.nn.functional.scaled_dot_product_attentiion labels Mar 16, 2025
@eqy eqy requested a review from syed-ahmed as a code owner March 16, 2025 21:09
Copy link

pytorch-bot bot commented Mar 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/149282

Note: Links to docs will display an error until the docs builds have been completed.

❌ 15 New Failures, 1 Unrelated Failure

As of commit 7b6fd2d with merge base b027cb8 (image):

NEW FAILURES - The following jobs have failed:

UNSTABLE - The following job is marked as unstable, possibly due to flakiness on trunk:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@eqy
Copy link
Collaborator Author

eqy commented Mar 16, 2025

@pytorchmergebot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Tried to rebase and push PR #149282, but it was already up to date. Try rebasing against main by issuing:
@pytorchbot rebase -b main

@eqy eqy changed the title [cuDNN][SDPA] cuDNN SDPA refactor/cleanup [WIP][cuDNN][SDPA] cuDNN SDPA refactor/cleanup Mar 17, 2025
@pytorch pytorch deleted a comment from pytorch-bot bot Mar 17, 2025
@eqy
Copy link
Collaborator Author

eqy commented Mar 17, 2025

@pytorchmergebot rebase

@pytorchmergebot
Copy link
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Copy link
Collaborator

Successfully rebased cudnnsdparefactor onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout cudnnsdparefactor && git pull --rebase)

Copy link

linux-foundation-easycla bot commented Mar 17, 2025

CLA Missing ID CLA Not Signed

}
auto workspace_size = mha_graph->get_workspace_size();
auto workspace_ptr =
c10::cuda::CUDACachingAllocator::get()->allocate(workspace_size);
TORCH_CHECK(
mha_graph->execute(handle, variant_pack, workspace_ptr.get()).is_good());
mhagraphcache.update(key, graph_and_tensors_values);
mhagraphcache.update(key, mha_graph);
Copy link
Collaborator

@Skylion007 Skylion007 Mar 18, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The update method for mhagraphcache should probably use perfect forward up where the update method is defined instead of an lref. And throughout the file should be to remove extra copies.

Suggested change
mhagraphcache.update(key, mha_graph);
mhagraphcache.update(key, std::move(mha_graph));

@eqy eqy force-pushed the cudnnsdparefactor branch from 3848e20 to bd4432a Compare April 7, 2025 23:31
@eqy eqy requested review from albanD and soulitzer as code owners April 15, 2025 00:44
@eqy eqy changed the title [WIP][cuDNN][SDPA] cuDNN SDPA refactor/cleanup [cuDNN][SDPA] cuDNN SDPA refactor/cleanup, nested tensor backward, test priority bump for sm90, sm100 Apr 28, 2025
@eqy eqy requested review from drisspg and jbschlosser April 28, 2025 21:58
@jerryzh168 jerryzh168 added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Apr 29, 2025
@eqy eqy force-pushed the cudnnsdparefactor branch from b6a75a5 to 0baac8a Compare April 30, 2025 18:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
module: sdpa All things related to torch.nn.functional.scaled_dot_product_attentiion open source topic: not user facing topic category triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants