Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[gfx12] mla v3#3500

Open
feifei14119 wants to merge 1 commit into
ROCm:mainfrom
feifei14119:feiw/pr/mla2
Open

[gfx12] mla v3#3500
feifei14119 wants to merge 1 commit into
ROCm:mainfrom
feifei14119:feiw/pr/mla2

Conversation

@feifei14119
Copy link
Copy Markdown
Contributor

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Jun 2, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-300x Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X
ci:sglang SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy
ci:atom ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B
ci:atom_full ATOM accuracy suite for PR and main models from ATOM models_accuracy.json
ci:vllm vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5
ci:all All standard extended tests (excludes ci:atom_full)

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3500 --add-label <label>

@feifei14119 feifei14119 force-pushed the feiw/pr/mla2 branch 2 times, most recently from 8ccaaa6 to 583232d Compare June 3, 2026 06:01
@feifei14119 feifei14119 marked this pull request as ready for review June 3, 2026 06:04
@feifei14119 feifei14119 requested review from a team and Copilot June 3, 2026 06:04
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds gfx1250 (mi400) MLA v3 decode support by wiring in a dedicated asm dispatch path, a mi400-focused sweep in the existing MLA test driver, and config entries for gfx1250 MLA asm kernels. Also introduces an env-based runtime switch to fully disable the optional FlyDSL backend.

Changes:

  • Extend mla_decode_stage1_asm_fwd dispatch to route gfx1250 to a mi400-specific kernarg pack + kernel selection from hsa/gfx1250/mla/mla_asm.csv (with optional debug dumping under ASM_DEBUG).
  • Add a --mi400 {auto,on,off} sweep mode to op_tests/test_mla.py that builds fp8/rope-split2 packed inputs and validates numerics against a reference.
  • Add ENABLE_FLYDSL runtime opt-out plumbing for FlyDSL availability checks.

Reviewed changes

Copilot reviewed 6 out of 10 changed files in this pull request and generated 3 comments.

Show a summary per file
File Description
op_tests/test_mla.py Adds mi400 sweep mode, fp8 packing helpers, and a mi400-specific decode validation path.
hsa/gfx1250/mla/mla_asm.csv Introduces gfx1250 MLA asm kernel registry entries used by heuristic dispatch.
csrc/py_itfs_cu/asm_mla.cu Adds gfx1250 mi400 stage1 dispatch (kernargs ABI) and optional debug instrumentation/dumps.
aiter/ops/flydsl/utils.py Adds ENABLE_FLYDSL env opt-out to disable FlyDSL backend at runtime.
aiter/mla.py Adjusts decode buffer allocation/aliasing and kv_indptr handling for gfx1250 mi400 decode.
aiter/jit/core.py Adds global ENABLE_FLYDSL flag mirroring ENABLE_CK.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread aiter/mla.py Outdated
Comment on lines +234 to +241
_is_gfx1250 = get_gfx() == "gfx1250"
_can_alias_o_as_logits = (
num_kv_splits == 1
and (
q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
)
and not _is_gfx1250
)
Comment thread op_tests/test_mla.py
Comment on lines +358 to +367
def _cosine_diff(actual, expected):
actual = actual.detach().float().cpu()
expected = expected.detach().float().cpu()
assert torch.isfinite(actual).all()
assert torch.isfinite(expected).all()
numerator = 2 * (actual.double() * expected.double()).sum()
denominator = (
(actual.double().square() + expected.double().square()).sum().clamp_min(1e-12)
)
return (1 - (numerator / denominator)).item()
Comment thread op_tests/test_mla.py
is_causal=True,
dtype=out_dtype,
)
# troch implementation. mi400 uses its own _ref_mla_mi400 golden (built on
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants