Thanks to visit codestin.com
Credit goes to github.com

Skip to content

perf(ROCm): add is_rdna() detection and optimize CE loss for RDNA GPUs#4123

Merged
danielhanchen merged 1 commit intounslothai:mainfrom
GoldenGrapeGentleman:perf/rdna-kernel-tuning
Mar 1, 2026
Merged

perf(ROCm): add is_rdna() detection and optimize CE loss for RDNA GPUs#4123
danielhanchen merged 1 commit intounslothai:mainfrom
GoldenGrapeGentleman:perf/rdna-kernel-tuning

Conversation

@GoldenGrapeGentleman
Copy link
Contributor

@GoldenGrapeGentleman GoldenGrapeGentleman commented Feb 27, 2026

Summary

Apply targeted Triton kernel tuning for the chunked cross-entropy forward path on AMD RDNA consumer/workstation GPUs (RDNA3/RDNA4).

Note: is_rdna() scope restriction is tracked in PR #4136.

Changes

unsloth/kernels/cross_entropy_loss.py

  • Chunked CE forward (large vocab > 65536): set num_warps=16 for RDNA
  • Refactored condition from 32 if not is_cdna() else 1616 if is_cdna() or is_rdna() else 32

Correction (post-merge): The benchmark "Speedup" column compared warps=16 against warps=8, but the previous code path for RDNA already used warps=32. The actual change vs. old behavior is ~0.4%, within measurement noise. The warp change has been reverted in PR #TODO — is_rdna() remains available in utils.py for future RDNA-specific tuning.

Benchmark Results

Hardware: AMD Radeon PRO W7900 (gfx1100, RDNA3, 48GB)
Method: 5 independent trials × 300 iterations each, median reported

Chunked CE Forward (large vocab, BS=65536)

Model vocab warps=8 warps=16 (RDNA) warps=32 (old code) Δ vs old code
Llama-3.1-8B 128K 0.450ms 0.193ms 0.194ms ~0.4%
Qwen2.5-7B 152K 0.603ms 0.242ms 0.244ms ~0.6%
Gemma-3-4B 256K 0.911ms 0.379ms 0.381ms ~0.4%

Other kernels — no modification needed

Kernel BLOCK_SIZE Default warps Doubled warps Verdict
RMS LayerNorm fwd 2048–8192 8 16 (−10~25%) ❌ Keep default
RMS LayerNorm bwd 2048–4096 8 16 (−12~20%) ❌ Keep default
LayerNorm fwd 2048–4096 8 16 (−6~27%) ❌ Keep default
RoPE legacy 64 4 8 (−5~22%) ❌ Keep default
RoPE QK (GQA) 128 4 8 (−45~47%) ❌ Keep default
CE fwd (small vocab) 32768 32 ❌ Keep default
CE backward 4096 8 ❌ Keep default

Testing

  • Verified is_rdna() returns True on W7900 (gfx1100)
  • Verified is_rdna() returns False on NVIDIA GPUs and CDNA GPUs
  • Confirmed no functional regression — only the chunked CE forward path is affected

cc @danielhanchen

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances performance for Triton kernels on AMD RDNA consumer GPUs by introducing architecture-specific optimizations. It intelligently adjusts kernel launch parameters, specifically the number of warps, to better utilize the unique dual-issue SIMD32 Compute Unit design of RDNA architectures. This results in notable speedups for critical operations like RMS LayerNorm and Cross Entropy Loss, ensuring more efficient execution on a wider range of AMD hardware.

Highlights

  • RDNA Architecture Detection: Added a new is_rdna() function to detect AMD RDNA consumer GPUs (Radeon RX 7000/9000 series, PRO W7000 series) based on their gfx architecture IDs.
  • Triton Kernel Parameter Tuning: Optimized num_warps in calculate_settings() to double for RDNA GPUs, capped at 32, leveraging RDNA's dual-issue SIMD32 CU microarchitecture for improved occupancy and performance.
  • Cross Entropy Loss Optimization: Configured chunked cross entropy kernels to use 16 warps on RDNA GPUs, matching CDNA, which was benchmarked to be faster than the default 32 warps.
  • Module Exports: Exported is_cdna and the newly added is_rdna functions from unsloth.kernels for broader use within the codebase.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Changelog
  • unsloth/kernels/init.py
    • Exported is_cdna and is_rdna functions from the module.
  • unsloth/kernels/cross_entropy_loss.py
    • Imported the new is_rdna function.
    • Modified the num_warps parameter for chunked cross entropy kernels to be 16 if is_cdna() or is_rdna() is true, otherwise 32.
  • unsloth/kernels/utils.py
    • Implemented a new is_rdna() function to detect AMD RDNA GPUs (gfx1100-gfx1201) using triton.runtime.driver.active.get_current_target().arch.
    • Updated the calculate_settings function to double the num_warps (up to a maximum of 32) when an RDNA GPU is detected, based on benchmarked performance gains.
Activity
  • No human activity has occurred on this pull request yet.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces performance optimizations for AMD RDNA consumer GPUs by detecting the RDNA architecture and tuning Triton kernel parameters. A new is_rdna() function is added, and num_warps is adjusted in calculate_settings and cross_entropy_loss based on RDNA's specific microarchitecture, leading to significant performance gains as shown in the benchmarks. The changes are well-structured and justified. I have one suggestion to improve the implementation of the new is_rdna function for better performance and maintainability.

Comment on lines +91 to +99
return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
"gfx1100", # RDNA3 (RX 7900 XTX/XT, PRO W7900/W7800)
"gfx1101", # RDNA3 (RX 7800 XT, RX 7700 XT)
"gfx1102", # RDNA3 (RX 7600 XT/7600)
"gfx1150", # RDNA3.5 (Strix Point APU)
"gfx1151", # RDNA3.5 (Strix Halo APU)
"gfx1200", # RDNA4 (RX 9070 XT)
"gfx1201", # RDNA4 (RX 9070)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For improved performance, consider using a set for the architecture check instead of a tuple. Membership testing against a set is more efficient (O(1) on average) than a tuple (O(n)). Since this function is cached, the set will only be created once.

Suggested change
return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
"gfx1100", # RDNA3 (RX 7900 XTX/XT, PRO W7900/W7800)
"gfx1101", # RDNA3 (RX 7800 XT, RX 7700 XT)
"gfx1102", # RDNA3 (RX 7600 XT/7600)
"gfx1150", # RDNA3.5 (Strix Point APU)
"gfx1151", # RDNA3.5 (Strix Halo APU)
"gfx1200", # RDNA4 (RX 9070 XT)
"gfx1201", # RDNA4 (RX 9070)
)
return is_hip() and triton.runtime.driver.active.get_current_target().arch in {
"gfx1100", # RDNA3 (RX 7900 XTX/XT, PRO W7900/W7800)
"gfx1101", # RDNA3 (RX 7800 XT, RX 7700 XT)
"gfx1102", # RDNA3 (RX 7600 XT/7600)
"gfx1150", # RDNA3.5 (Strix Point APU)
"gfx1151", # RDNA3.5 (Strix Halo APU)
"gfx1200", # RDNA4 (RX 9070 XT)
"gfx1201", # RDNA4 (RX 9070)
}

@GoldenGrapeGentleman GoldenGrapeGentleman force-pushed the perf/rdna-kernel-tuning branch 2 times, most recently from bdfcdf8 to c1dc4bb Compare February 27, 2026 08:58
@GoldenGrapeGentleman GoldenGrapeGentleman changed the title perf(ROCm): tune Triton kernel parameters for RDNA consumer GPUs perf(ROCm): add is_rdna() detection and optimize CE loss for RDNA GPUs Feb 27, 2026
@GoldenGrapeGentleman
Copy link
Contributor Author

Rebased: Removed duplicate is_rdna() definition from this PR. Now only contains the CE loss num_warps optimization (1 file, 2 lines).

Dependency: This PR depends on #4109 which provides is_rdna(). Should be merged after #4109.

Changes: Only cross_entropy_loss.py — use 16 warps for RDNA (matching existing CDNA path) in the chunked large-vocab forward kernel. Benchmarked 2.4-2.6x speedup on W7900.

Use 16 warps for RDNA in the chunked cross-entropy forward kernel
(large vocab > 65536), matching the existing CDNA optimization.

Benchmarked on W7900 (gfx1100) with actual unsloth kernels (5 trials, median):
  - Chunked CE forward (BS=65536): 16 warps = 2.4-2.6x faster than 32
  - All other kernels (LayerNorm, RoPE, SwiGLU): default heuristic is
    already optimal for RDNA; no modification needed.

Depends on: unslothai#4109 (provides is_rdna() detection)
Copy link
Contributor

@danielhanchen danielhanchen left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you! This works great!

@danielhanchen danielhanchen merged commit 4d3e7d7 into unslothai:main Mar 1, 2026
1 check passed
GoldenGrapeGentleman added a commit to GoldenGrapeGentleman/unsloth that referenced this pull request Mar 2, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants