[Feat] Add GQA/MHA FlashAttention-3 baseline #56
Conversation
Summary of Changes

Hello @RMLYC, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request upgrades the benchmarking infrastructure by introducing a dedicated mechanism for comparing custom operations against established baselines. Specifically, it integrates FlashAttention-3 baselines for GQA and MHA, providing a robust way to measure performance improvements. The changes also refine the output reporting to cover forward- and backward-pass metrics for both the primary operation and its baseline.

Highlights

- New baseline_program interface in Benchmark.py
- New baseline_profile interface in Benchmark.py
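The review below refers to these two hooks repeatedly. As a rough orientation only, here is a minimal sketch of how such a hook pair could sit in a benchmark base class; the class shape, the methods other than baseline_program/baseline_profile, and the signatures are assumptions for illustration, not the PR's actual Benchmark.py code.

```python
class Benchmark:
    """Illustrative sketch only; the real Benchmark.py interface may differ."""

    def program(self, *inputs):
        """The custom operator under test (the 'tl-' metrics in the CSV)."""
        raise NotImplementedError

    def baseline_program(self, *inputs):
        """Reference implementation, e.g. a FlashAttention-3 call."""
        raise NotImplementedError

    def profile(self, warmup: int = 10, rep: int = 100) -> dict:
        """Measure self.program and return latency/TFLOPs/bandwidth metrics."""
        raise NotImplementedError

    def baseline_profile(self, warmup: int = 10, rep: int = 100) -> dict:
        """Apply the same measurement protocol to self.baseline_program."""
        raise NotImplementedError
```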
Code Review
This pull request introduces a baseline profiling mechanism for GQA and MHA operators using FlashAttention-3. The changes include adding baseline_profile and baseline_program to the benchmarking framework, implementing them for GQA and MHA, and updating the profiling script to capture and display these new metrics. My review has identified a few critical issues related to incorrect return values in the new baseline programs and a bug in the profiling script's error handling. I've also suggested an improvement to ensure profiling consistency by using torch.no_grad().
benchmarks/flash_attn/gqa.py (Outdated)
```python
if isinstance(out, tuple):
    out = out[0]

return out
```
benchmarks/flash_attn/mha.py (Outdated)
```python
if isinstance(out, tuple):
    out = out[0]

return out
```
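For context on the two snippets above: both baseline programs appear to normalize the return value of a FlashAttention call before handing it to the benchmark. A minimal sketch of what such a baseline might look like is shown below; the import path, function signature, and tensor layout are assumptions about the FlashAttention-3 package (some releases return an (out, softmax_lse) tuple, which is what the isinstance check guards against), not the code from this PR.

```python
import torch

# Assumption: FlashAttention-3's interface; depending on the release,
# flash_attn_func returns either the output tensor or (out, softmax_lse).
from flash_attn_interface import flash_attn_func


def fa3_attention_baseline(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                           causal: bool = True) -> torch.Tensor:
    """Hypothetical GQA/MHA baseline.

    q: (batch, seqlen, heads_q, head_dim); k, v: (batch, seqlen, heads_kv,
    head_dim), with heads_q a multiple of heads_kv (heads_q == heads_kv for MHA).
    """
    out = flash_attn_func(q, k, v, causal=causal)
    # Normalize so the benchmark always receives a plain tensor.
    if isinstance(out, tuple):
        out = out[0]
    return out
```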
tests/profile_run.py (Outdated)
```python
error_result = {
    **params, 'tl-latency(ms)': None,
    'tl-TFlops': None,
    'tl-Bandwidth(GB/s)': None,
    'Baseline-latency(ms)': None,
    'Baseline-TFlops': None,
    'Baseline-Bandwidth(GB/s)': None
}
```
The keys in error_result do not match the new fieldnames defined for the CSV output. The fieldnames now include prefixes like fwd- and bwd- (e.g., 'fwd-tl-latency(ms)'), but error_result uses keys without these prefixes (e.g., 'tl-latency(ms)'). This will cause issues when writing to the CSV file. The error_result dictionary should contain all the output fieldnames with None as their values.
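For reference, csv.DictWriter rejects rows containing keys outside fieldnames (its default is extrasaction='raise'), which is how the stale keys would surface at runtime. A small, self-contained illustration with shortened, made-up field names:

```python
import csv
import io

fieldnames = ["batch", "fwd-tl-latency(ms)", "bwd-tl-latency(ms)"]
writer = csv.DictWriter(io.StringIO(), fieldnames=fieldnames)
writer.writeheader()

# A row still keyed with the old, unprefixed name is rejected outright.
try:
    writer.writerow({"batch": 1, "tl-latency(ms)": None})
except ValueError as exc:
    print(exc)  # dict contains fields not in fieldnames: 'tl-latency(ms)'
```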
A more robust way to construct error_result would be to derive the keys from fieldnames dynamically.
```python
output_fields = [f for f in fieldnames if f not in params]
error_result = {**params, **{f: None for f in output_fields}}
```

Benchmark.py (baseline_profile)

```python
# Warmup to get rid of CUDA lazy initialization effects.
for _ in range(warmup):
    _ = baseline_op(*inputs)
torch.cuda.synchronize(device=device)

# CUDA event-based timing for higher precision.
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
for _ in range(rep):
    _ = baseline_op(*inputs)
end_event.record()

torch.cuda.synchronize(device=device)
total_ms = start_event.elapsed_time(end_event)
```
The profile method wraps the benchmarked operation within a torch.no_grad() context to prevent gradient computations from affecting performance measurements. The baseline_profile method is missing this, which could lead to inaccurate performance metrics if the baseline_op has gradient tracking enabled. It's recommended to add with torch.no_grad(): for consistency and correctness, especially for forward pass benchmarks.
Suggested change:

```python
with torch.no_grad():
    # Warmup to get rid of CUDA lazy initialization effects.
    for _ in range(warmup):
        _ = baseline_op(*inputs)
    torch.cuda.synchronize(device=device)

    # CUDA event-based timing for higher precision.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(rep):
        _ = baseline_op(*inputs)
    end_event.record()

    torch.cuda.synchronize(device=device)
    total_ms = start_event.elapsed_time(end_event)
```
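Two small follow-up notes, outside the scope of the suggestion above: elapsed_time here spans all rep calls, so the per-call latency is the total divided by rep, and torch.inference_mode() is a slightly stricter alternative to torch.no_grad() for forward-only timing. A self-contained sketch (the helper name and defaults are made up, not part of this PR):

```python
import torch


def time_forward_ms(op, inputs, warmup: int = 10, rep: int = 100) -> float:
    """Hypothetical helper: mean forward latency of `op` in milliseconds."""
    with torch.inference_mode():  # stricter than no_grad for inference-only timing
        for _ in range(warmup):
            op(*inputs)
        torch.cuda.synchronize()

        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(rep):
            op(*inputs)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / rep  # elapsed_time spans all reps
```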