ROCm: Add gfx950 (MI355X/CDNA4) to is_cdna()#4051

Merged
danielhanchen merged 1 commit into unslothai:main from GoldenGrapeGentleman:fix/is-cdna-gfx950-clean
Feb 14, 2026

Conversation

@GoldenGrapeGentleman (Contributor) commented Feb 14, 2026

Summary

Add AMD Instinct MI355X (gfx950 / CDNA4) to is_cdna() so Triton kernels use the correct num_warps.

Problem

is_cdna() only listed gfx940/941/942 (MI300 series). MI355X (gfx950, CDNA4) has the same 1024-thread workgroup limit and 64-thread wavefront size, but was missing. This caused all Triton kernels to use num_warps=32 (2048 threads) instead of 16 (1024 threads):

triton.runtime.errors.OutOfResources: out of resource: threads,
Required: 2048, Hardware limit: 1024

This blocked all training on MI355X.
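The failure arithmetic can be sketched as follows (the constants come from the PR description; the helper name is illustrative, not from the codebase):

```python
# Why num_warps=32 overflows on CDNA hardware.
WARP_SIZE = 64          # CDNA wavefront width
WORKGROUP_LIMIT = 1024  # max threads per workgroup on gfx940-gfx950

def required_threads(num_warps: int, warp_size: int = WARP_SIZE) -> int:
    """Threads a kernel launch requests for a given num_warps."""
    return num_warps * warp_size

print(required_threads(32))  # 2048 -> exceeds the 1024 limit, OutOfResources
print(required_threads(16))  # 1024 -> fits the workgroup limit exactly
```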

Change

 def is_cdna():
     return is_hip() and triton.runtime.driver.active.get_current_target().arch in (
         "gfx940",
         "gfx941",
         "gfx942",
+        "gfx950",  # CDNA4 (MI350/MI355X)
     )
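A self-contained sketch of the patched check and how it might drive warp selection. The real function reads the arch from `triton.runtime.driver.active.get_current_target()`; the pure-Python stand-in below takes the arch string as a parameter, and `pick_num_warps` is a hypothetical policy mirroring the reasoning in the PR:

```python
# Arch strings treated as CDNA after this PR.
CDNA_ARCHS = ("gfx940", "gfx941", "gfx942", "gfx950")

def is_cdna(arch: str) -> bool:
    """True for MI300-series (CDNA3) and MI350/MI355X (CDNA4) architectures."""
    return arch in CDNA_ARCHS

def pick_num_warps(arch: str) -> int:
    # Hypothetical policy: CDNA's 1024-thread workgroup limit with 64-wide
    # wavefronts allows at most 16 warps per kernel launch.
    return 16 if is_cdna(arch) else 32

print(pick_num_warps("gfx950"))  # 16 -- MI355X no longer over-requests threads
```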

Hardware verification

GPU: AMD Instinct MI355X
gcnArchName: gfx950:sramecc+:xnack-
warp_size: 64
workgroup thread limit: 1024 (same as gfx942)

Tested on 8× AMD Instinct MI355X (gfx950), ROCm 7.1
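Note that the `gcnArchName` reported above carries feature flags (`sramecc`, `xnack`) after the base architecture, while the CDNA list compares bare arch names. A small helper — hypothetical, for illustration only, since Triton's `target.arch` may already be the bare name — would strip the flags before the comparison:

```python
def base_arch(gcn_arch_name: str) -> str:
    """Drop feature-flag suffixes like ':sramecc+:xnack-' from a gcnArchName."""
    return gcn_arch_name.split(":")[0]

print(base_arch("gfx950:sramecc+:xnack-"))  # gfx950
```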

| Test | Result |
| --- | --- |
| Vision RL GRPO (Qwen2.5-VL-7B) | ✅ 5/5 steps |
| Code RL GRPO (gpt-oss-20b BF16) | ✅ 20/20 steps |
| gpt-oss-120b GRPO (8-GPU) | ✅ 5/5 steps |
| MoE expert LoRA + merge | ✅ 46.2M trainable, merge success |

Note

Full MI355X support also requires PR #4021 (ROCm GPT-OSS MXFP4→BF16 routing) by @danielhanchen. I have closed my combined change, PR #4050; this PR is the remaining piece needed for CDNA4 Triton kernel compatibility.

MI355X (gfx950) has the same 1024-thread workgroup limit as MI300X (gfx942),
but was missing from is_cdna(), causing all Triton kernels to use num_warps=32
(2048 threads) instead of 16 (1024 threads), resulting in OutOfResources crash.

Tested on: 8x AMD Instinct MI355X (gfx950), ROCm 7.1
@gemini-code-assist

Summary of Changes

Hello @GoldenGrapeGentleman, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical compatibility issue for AMD Instinct MI355X (CDNA4) GPUs within the Triton framework. By extending the is_cdna() function to recognize the gfx950 architecture, it enables Triton kernels to correctly determine the appropriate num_warps, thereby preventing OutOfResources errors and allowing successful execution of machine learning workloads on MI355X hardware.

Highlights

  • ROCm Support: Added gfx950 (AMD Instinct MI355X / CDNA4) to the is_cdna() function to ensure Triton kernels use the correct num_warps.
  • Bug Fix: Resolved an OutOfResources error that prevented Triton kernels from running on MI355X by ensuring the correct 1024-thread workgroup limit is recognized.


Changelog
  • unsloth/kernels/utils.py
    • Added gfx950 to the list of CDNA architectures recognized by the is_cdna() function.
Activity
  • Hardware verification was performed, confirming gfx950 has a 1024-thread workgroup limit and 64-thread wavefront size.
  • Extensive testing was conducted on 8× AMD Instinct MI355X (gfx950) with ROCm 7.1, including Vision RL GRPO, Code RL GRPO, gpt-oss-120b GRPO, and MoE expert LoRA + merge, all showing successful results.
  • This PR is noted as an additional piece required for full MI355X support, complementing PR #4021 (ROCm: default GPT-OSS to BF16 and disable AITER).

@gemini-code-assist (bot) left a comment

Code Review

This pull request adds support for AMD's MI355X (gfx950 / CDNA4) GPUs by including 'gfx950' in the list of CDNA architectures. This is a crucial fix that resolves an 'OutOfResources' error, enabling Triton kernels to use the correct number of warps on this hardware. The change is straightforward, well-justified, and supported by thorough testing results provided in the description. The implementation is correct and follows the existing pattern. Excellent work!

@danielhanchen

Oh thank you - also thanks for the other PRs! Will review

@danielhanchen

Oh thanks

@danielhanchen danielhanchen merged commit 5fedf82 into unslothai:main Feb 14, 2026
1 check passed
@danielhanchen

I was still working on the other PR!

@GoldenGrapeGentleman (Author)

> I was still working on the other PR!

Great Effort!
