
[TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_120+ #152814


Conversation

Collaborator

@Aidyn-A Aidyn-A commented May 5, 2025

The float8 row-wise scaled matmuls are not supported on Blackwell yet. This PR adds skips to those tests to decrease the noise on sm_120+ machines.
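
For reference, the guard boils down to a compute-capability check wired into unittest.skipIf. A minimal sketch of what such a condition could look like follows; the flag name _IS_SM12X appears in the diff later in this conversation, but its definition here is an assumption, and the test class name is hypothetical:

```python
import unittest

import torch

# Assumed definition: True on devices with compute capability 12.x (sm_120+).
_IS_SM12X = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 12


class TestRowwiseScaledMM(unittest.TestCase):  # hypothetical class name
    @unittest.skipIf(_IS_SM12X, "rowwise scaled matmul not yet supported on sm_120+")
    def test_scaled_mm_vs_emulated_row_wise(self):
        ...  # real test body lives in test/test_matmul_cuda.py
```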

cc @ptrblck @msaroufim @eqy @jerryzh168


pytorch-bot bot commented May 5, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/152814

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6aba225 with merge base 5796212:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot bot added the topic: not user facing topic category label May 5, 2025
@Aidyn-A Aidyn-A requested a review from eqy May 5, 2025 09:28
Collaborator

@eqy eqy left a comment


Is this still needed after #148421? I'm seeing that on a fresh source build
I'm seeing that on a fresh source build
test_float8_rowwise_scaling_sanity_use_fast_accum_True_cuda
test_float8_rowwise_scaling_sanity_use_fast_accum_False_cuda
test_scaled_mm_vs_emulated_row_wise
all pass

Collaborator Author

Aidyn-A commented May 6, 2025

Is this still needed after #148421? I'm seeing that on a fresh source build test_float8_rowwise_scaling_sanity_use_fast_accum_True_cuda, test_float8_rowwise_scaling_sanity_use_fast_accum_False_cuda, and test_scaled_mm_vs_emulated_row_wise all pass

My bad, it is needed for sm_120, not sm_100.

@Aidyn-A Aidyn-A changed the title [TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on Blackwell [TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_12x May 6, 2025
@Aidyn-A Aidyn-A changed the title [TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_12x [TEST][ATen][CUDA] Skip row-wise scaled matrix multiplication tests on sm_120+ May 6, 2025
@Aidyn-A Aidyn-A requested a review from eqy May 6, 2025 10:08
@drisspg drisspg added module: cuda Related to torch.cuda, and CUDA support in general triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels May 6, 2025
Collaborator Author

Aidyn-A commented May 7, 2025

@pytorchbot rebase

@pytorchmergebot
Collaborator

@pytorchbot started a rebase job onto refs/remotes/origin/viable/strict. Check the current status here

@pytorchmergebot
Collaborator

Successfully rebased test_matmul_cuda_skip_rowwise_scaled_mm_on_blackwell onto refs/remotes/origin/viable/strict, please pull locally before adding more changes (for example, via git checkout test_matmul_cuda_skip_rowwise_scaled_mm_on_blackwell && git pull --rebase)

@pytorchmergebot pytorchmergebot force-pushed the test_matmul_cuda_skip_rowwise_scaled_mm_on_blackwell branch from dec9eab to c02fe40 on May 7, 2025 09:00
Collaborator

Skylion007 commented May 7, 2025

Hmm, is this because CUTLASS is missing those specializations? Or because we are missing something minor on our side that would unblock support? Like a missing template specialization, overly restrictive dispatch logic, or a missing CMake arch to build those kernels for SM120?

@Skylion007
Collaborator

Doesn't it support it here? Are we missing dispatch logic?

@Skylion007
Collaborator

Does just adding SM120 here fix it or is SM100 not compatible with SM120?

const bool sm10x = properties != nullptr && properties->major == 10;

Collaborator Author

Aidyn-A commented May 7, 2025

cc @eqy @malfet to review

Collaborator Author

Aidyn-A commented May 7, 2025

Hmm, is this because CUTLASS is missing those specializations? Or because we are missing something minor on our side that would unblock support? Like a missing template specialization, overly restrictive dispatch logic, or a missing CMake arch to build those kernels for SM120?

No, it is not as trivial as adding sm_120 to those places. I have tried that, and CUTLASS just fails as "uninitialized" (whatever that means).

@@ -1013,7 +1015,7 @@ def test_float8_scale_fast_accum(self, device) -> None:
         self.assertEqual(out_fp8, out_fp8_s)

     @unittest.skipIf(not PLATFORM_SUPPORTS_FP8 or IS_WINDOWS, f8_msg)
-    @unittest.skipIf(not SM89OrLater, "rowwise implementation is currently sm89+ specific")
+    @unittest.skipIf(not SM89OrLater or _IS_SM12X, "rowwise implementation is currently sm89-sm90 specific")
Collaborator


Should we skip or XFAIL?
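
For comparison, a skip never runs the test on sm_120+, while an expected failure still runs it and reports an "unexpected success" once the underlying support lands, which makes stale markers easier to spot. A rough sketch in plain unittest terms (PyTorch's test harness has its own helpers, and plain unittest.expectedFailure is unconditional, unlike skipIf):

```python
import unittest

_IS_SM12X = True  # placeholder; see the capability check sketched earlier


class Example(unittest.TestCase):
    # Skip: not executed at all on sm_120+, so a future fix goes unnoticed.
    @unittest.skipIf(_IS_SM12X, "rowwise scaled matmul unsupported on sm_120+")
    def test_skipped(self):
        ...

    # XFAIL: still executed; once support lands it shows up as an
    # "unexpected success", signalling that the marker can be removed.
    @unittest.expectedFailure
    def test_xfailed(self):
        ...
```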

Collaborator

Skylion007 commented May 8, 2025

Hmm, is this because CUTLASS is missing those specializations? Or because we are missing something minor on our side that would unblock support? Like a missing template specialization, overly restrictive dispatch logic, or a missing CMake arch to build those kernels for SM120?

No, it is not as trivial as adding sm_120 to those places. I have tried that, and CUTLASS just fails as "uninitialized" (whatever that means).

If you could post the trace, that could be helpful in figuring out how to enable it.

Collaborator Author

Aidyn-A commented May 8, 2025

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label May 8, 2025
@pytorchmergebot
Collaborator

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.
