[gfx12] mla v3 #3500
8ccaaa6 to 583232d
Pull request overview
Adds gfx1250 (mi400) MLA v3 decode support by wiring in a dedicated asm dispatch path, a mi400-focused sweep in the existing MLA test driver, and config entries for gfx1250 MLA asm kernels. Also introduces an env-based runtime switch to fully disable the optional FlyDSL backend.
Changes:
- Extend `mla_decode_stage1_asm_fwd` dispatch to route gfx1250 to a mi400-specific kernarg pack + kernel selection from `hsa/gfx1250/mla/mla_asm.csv` (with optional debug dumping under `ASM_DEBUG`).
- Add a `--mi400 {auto,on,off}` sweep mode to `op_tests/test_mla.py` that builds fp8/rope-split2 packed inputs and validates numerics against a reference.
- Add `ENABLE_FLYDSL` runtime opt-out plumbing for FlyDSL availability checks.
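
For readers unfamiliar with tri-state flags, the `--mi400 {auto,on,off}` option could be wired roughly as follows. This is a minimal sketch, not the PR's code: `build_parser` and `resolve_mi400` are hypothetical names, and both the `auto` default and the rule that `auto` enables the sweep only when the detected architecture is `gfx1250` are assumptions.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the tri-state --mi400 option."""
    parser = argparse.ArgumentParser(description="MLA test driver (sketch)")
    parser.add_argument(
        "--mi400",
        choices=["auto", "on", "off"],
        default="auto",  # assumption: the real default is not shown in the PR page
        help="run the mi400 sweep: always (on), never (off), or by detection (auto)",
    )
    return parser


def resolve_mi400(mode: str, gfx: str) -> bool:
    """Collapse the tri-state flag into a boolean.

    Assumption: 'auto' means "enable only when the detected arch is gfx1250".
    """
    if mode == "on":
        return True
    if mode == "off":
        return False
    return gfx == "gfx1250"


if __name__ == "__main__":
    # Explicit argv so the sketch does not depend on the caller's command line.
    args = build_parser().parse_args(["--mi400", "auto"])
    print(resolve_mi400(args.mi400, gfx="gfx1250"))
```

Keeping the resolution in a pure function lets the sweep decision be checked without a GPU present.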
Reviewed changes
Copilot reviewed 6 out of 10 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| op_tests/test_mla.py | Adds mi400 sweep mode, fp8 packing helpers, and a mi400-specific decode validation path. |
| hsa/gfx1250/mla/mla_asm.csv | Introduces gfx1250 MLA asm kernel registry entries used by heuristic dispatch. |
| csrc/py_itfs_cu/asm_mla.cu | Adds gfx1250 mi400 stage1 dispatch (kernargs ABI) and optional debug instrumentation/dumps. |
| aiter/ops/flydsl/utils.py | Adds ENABLE_FLYDSL env opt-out to disable FlyDSL backend at runtime. |
| aiter/mla.py | Adjusts decode buffer allocation/aliasing and kv_indptr handling for gfx1250 mi400 decode. |
| aiter/jit/core.py | Adds global ENABLE_FLYDSL flag mirroring ENABLE_CK. |
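
An environment-based opt-out such as `ENABLE_FLYDSL` is commonly implemented as a small boolean parser that is consulted before the backend is probed. The sketch below illustrates that pattern only; `env_flag` and `flydsl_available` are hypothetical helpers, the accepted "false" spellings are an assumption, and the PR's actual parsing is not shown on this page.

```python
import os


def env_flag(name: str, default: bool = True, environ=None) -> bool:
    """Read a boolean switch from the environment.

    Assumption: unset means `default`; "0", "false", "off" and "no"
    (case-insensitive) disable the feature; anything else enables it.
    """
    env = os.environ if environ is None else environ
    raw = env.get(name)
    if raw is None:
        return default
    return raw.strip().lower() not in {"0", "false", "off", "no"}


def flydsl_available(importable: bool, environ=None) -> bool:
    """Backend is usable only if it is both enabled and importable.

    `importable` stands in for whatever import probe the real code performs.
    """
    return env_flag("ENABLE_FLYDSL", environ=environ) and importable


if __name__ == "__main__":
    print(flydsl_available(importable=True, environ={"ENABLE_FLYDSL": "0"}))
```

Checking the flag before the import probe means a disabled backend is never touched, which is what "fully disable" would require.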

    _is_gfx1250 = get_gfx() == "gfx1250"
    _can_alias_o_as_logits = (
        num_kv_splits == 1
        and (
            q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
        )
        and not _is_gfx1250
    )
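
The aliasing condition above can be restated as a pure function, which makes the gfx1250 exclusion easy to exercise in isolation. This is a sketch under stated assumptions: `can_alias_o_as_logits` is a hypothetical helper, and plain strings stand in for `dtypes.fp8`, `dtypes.bf16`, and the value returned by `get_gfx()`.

```python
def can_alias_o_as_logits(
    num_kv_splits: int, q_dtype: str, max_seqlen_q: int, gfx: str
) -> bool:
    """Mirror of the quoted predicate, with strings in place of real dtypes."""
    dtype_ok = q_dtype == "fp8" or (q_dtype == "bf16" and max_seqlen_q == 4)
    # Per the quoted diff, gfx1250 never aliases the output as the logits buffer.
    return num_kv_splits == 1 and dtype_ok and gfx != "gfx1250"


if __name__ == "__main__":
    for gfx in ("gfx942", "gfx1250"):
        print(gfx, can_alias_o_as_logits(1, "fp8", 1, gfx))
```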

    def _cosine_diff(actual, expected):
        actual = actual.detach().float().cpu()
        expected = expected.detach().float().cpu()
        assert torch.isfinite(actual).all()
        assert torch.isfinite(expected).all()
        numerator = 2 * (actual.double() * expected.double()).sum()
        denominator = (
            (actual.double().square() + expected.double().square()).sum().clamp_min(1e-12)
        )
        return (1 - (numerator / denominator)).item()
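
To see how this metric behaves, the same formula, 1 - 2⟨a, e⟩ / (‖a‖² + ‖e‖²), can be evaluated on plain Python lists. The sketch below is a dependency-free restatement for illustration, not the test's code; `cosine_diff` and its `eps` parameter are names introduced here.

```python
import math


def cosine_diff(actual, expected, eps=1e-12):
    """List-based restatement of the quoted metric (no torch required)."""
    actual, expected = list(actual), list(expected)
    if len(actual) != len(expected):
        raise ValueError("inputs must have the same length")
    if not all(math.isfinite(x) for x in actual + expected):
        raise ValueError("inputs must be finite")
    numerator = 2.0 * sum(a * e for a, e in zip(actual, expected))
    # Clamp the denominator, as clamp_min(1e-12) does in the quoted code.
    denominator = max(sum(a * a + e * e for a, e in zip(actual, expected)), eps)
    return 1.0 - numerator / denominator


if __name__ == "__main__":
    base = [1.0, 2.0, 3.0]
    print(cosine_diff(base, base))                    # identical inputs
    print(cosine_diff(base, [-x for x in base]))      # sign-flipped inputs
    print(cosine_diff(base, [2.0 * x for x in base])) # same direction, 2x scale
```

Identical inputs give 0 and sign-flipped inputs give 2. Unlike plain cosine similarity, the metric also penalizes a scale mismatch: doubling one input yields about 0.2. Two all-zero inputs yield 1.0 because the numerator is zero while the denominator is clamped, so an all-zero output compared against an all-zero reference would register as a mismatch.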

        is_causal=True,
        dtype=out_dtype,
    )
    # torch implementation. mi400 uses its own _ref_mla_mi400 golden (built on
Motivation
Technical Details
Test Plan
Test Result
Submission Checklist