Fused softmax-topk-sort HIP kernel for MoE decode dispatch #3526

Draft
akii96 wants to merge 1 commit into main from fused_topk_sort

@akii96 akii96 commented Jun 4, 2026

Motivation

The MoE decode path currently requires two separate kernel launches for token-to-expert routing: topk_softmax (gating + top-K selection), followed by moe_sorting_opus_fwd (counting sort + moe_buf zeroing). Between these launches, topk_ids and topk_weights are materialized in global memory only to be read back immediately. Together the two kernels cost ~2.7ms per iteration on GPT-OSS-120b (1.18ms topk + 1.55ms sorting).

Technical Details

A single HIP kernel (csrc/include/fused_topk_moe_sorting.h) replaces both launches for decode-sized batches. Three phases share one dynamic LDS allocation:

  1. Softmax + top-K: per-row softmax over experts, then iterative argmax top-K selection. The selected ids and weights stay in LDS and feed the per-expert histogram directly, never touching global memory.
  2. Counting sort: per-expert histogram → padded prefix sum → sorted_ids/sorted_weights/sorted_expert_ids scatter. Same MOCK_ID encoding and SkipExpertsWithZeroTokens semantics as the opus oneshot path.
  3. moe_buf zeroing: vectorized uint4 stores.
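
As a rough host-side illustration of phases 1 and 2, the routing can be sketched in plain Python. This is a reference sketch under stated assumptions, not the kernel's code: the function names, the `block_m` padding unit, and the `PAD` sentinel are hypothetical, `sorted_ids` holds a bare token index rather than the MOCK_ID encoding, and experts with no tokens are assumed to receive no blocks.

```python
import math

PAD = -1  # hypothetical sentinel for padding slots (the kernel uses its own encoding)

def softmax(row):
    """Numerically stable softmax over one row of gating logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    s = sum(exps)
    return [e / s for e in exps]

def topk_iterative_argmax(probs, k):
    """Pick k experts by repeatedly taking the argmax and masking it out."""
    work = list(probs)
    ids, weights = [], []
    for _ in range(k):
        best = max(range(len(work)), key=work.__getitem__)
        ids.append(best)
        weights.append(probs[best])
        work[best] = -math.inf  # exclude this expert from later rounds
    return ids, weights

def route_and_sort(logits, k, block_m):
    """Softmax + top-K, then a counting sort of (token, expert) pairs by expert."""
    num_experts = len(logits[0])
    # Phase 1: per-row softmax and top-K selection.
    topk = [topk_iterative_argmax(softmax(row), k) for row in logits]
    # Phase 2a: per-expert histogram.
    counts = [0] * num_experts
    for ids, _ in topk:
        for e in ids:
            counts[e] += 1
    # Phase 2b: prefix sum with each expert's count padded up to a multiple of block_m.
    blocks = [-(-c // block_m) for c in counts]  # ceiling division
    offsets, total = [], 0
    for b in blocks:
        offsets.append(total)
        total += b * block_m
    # Phase 2c: scatter token indices and weights into their expert's segment.
    sorted_ids = [PAD] * total
    sorted_weights = [0.0] * total
    cursor = list(offsets)
    for token, (ids, ws) in enumerate(topk):
        for e, w in zip(ids, ws):
            sorted_ids[cursor[e]] = token
            sorted_weights[cursor[e]] = w
            cursor[e] += 1
    # One expert id per block of block_m slots.
    sorted_expert_ids = [e for e, b in enumerate(blocks) for _ in range(b)]
    return sorted_ids, sorted_weights, sorted_expert_ids, total
```

In the fused kernel the intermediate `topk` list corresponds to data held in LDS; here it is an ordinary Python list, which is exactly the global-memory roundtrip the PR removes.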

LDS budget determines eligibility: M≤122 for (E=128, K=4), M≤60 for (E=256, K=8). Larger batches (prefill) fall back to the existing separate kernel chain transparently via topk_softmax_sorting() in aiter/fused_moe.py.
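
The eligibility rule can be pictured as a simple lookup. The sketch below hard-codes only the two limits quoted above; the table and function names are hypothetical, and treating unlisted (E, K) shapes as ineligible is an assumption. The real `_max_tokens` binding presumably derives the limit from the LDS size rather than from a table.

```python
# Token limits quoted in this PR description, keyed by (num_experts, topk).
FUSED_MAX_TOKENS = {
    (128, 4): 122,
    (256, 8): 60,
}

def fused_path_eligible(num_tokens, num_experts, topk):
    """Return True if the batch fits the fused kernel's stated LDS budget.

    Shapes that are not in the table are treated as ineligible (an assumption).
    """
    limit = FUSED_MAX_TOKENS.get((num_experts, topk))
    return limit is not None and num_tokens <= limit
```

For example, a decode batch of 64 tokens with E=128, K=4 would qualify, while the same batch with E=256, K=8 would fall back to the separate chain.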

Files

| File | Change |
| --- | --- |
| csrc/include/fused_topk_moe_sorting.h | New: the fused HIP kernel |
| csrc/py_itfs_cu/fused_topk_moe_sorting_kernels.cu | New: kernel launch wrapper |
| csrc/pybind/fused_topk_moe_sorting_pybind.cu | New: pybind module |
| csrc/include/rocm_ops.hpp | Added FUSED_TOPK_MOE_SORTING_PYBIND macro |
| aiter/ops/moe_op.py | Added fused_topk_moe_sorting_fwd and _max_tokens bindings |
| aiter/jit/optCompilerConfig.json | Registered module_fused_topk_moe_sorting |
| aiter/fused_moe.py | Added topk_softmax_sorting() integration with fallback |
| op_tests/test_fused_topk_moe_sorting.py | New: correctness verification harness |

Safety

  • Shapes above the LDS budget are never routed to the fused kernel — transparent fallback to the existing separate chain.
  • Dense (non-MoE) models are unaffected — the kernel is only called from the MoE sorting dispatch.
  • dispatch_policy=1 (force oneshot) and dispatch_policy=2 (force MP) bypass the fused path entirely.
  • Output is bit-exact on sorted_ids/sorted_expert_ids/num_valid_ids and allclose on sorted_weights versus the separate chain, verified across 60 configurations.
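
The first and third bullets together amount to a small decision procedure. The following is a minimal sketch of it, not the code in aiter/fused_moe.py: the function name, the returned labels, and the assumption that `dispatch_policy=0` means automatic selection are all hypothetical.

```python
def select_sorting_path(num_tokens, max_fused_tokens, dispatch_policy=0):
    """Choose a sorting path for one MoE routing call.

    max_fused_tokens is the LDS-derived token limit for the current (E, K) shape.
    dispatch_policy=0 is assumed here to mean "choose automatically".
    """
    if dispatch_policy == 1:
        return "oneshot"   # forced oneshot: never use the fused kernel
    if dispatch_policy == 2:
        return "mp"        # forced MP: never use the fused kernel
    if num_tokens <= max_fused_tokens:
        return "fused"     # decode-sized batch that fits the LDS budget
    return "separate"      # fall back to topk_softmax + moe_sorting_opus_fwd
```

Checking the forced policies first keeps the fused path strictly opt-out: a caller that pins a policy gets the same kernel as before this change.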

Test Plan

  • op_tests/test_fused_topk_moe_sorting.py verifies correctness across (E=128, K=4), (E=256, K=8), and various token counts against the reference topk_softmax + moe_sorting_opus_fwd chain.
  • E2E serving benchmarks on GPT-OSS-120b (gfx950 / MI355X, vLLM 0.22.0).
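
The correctness criterion (exact match on the integer outputs, tolerance-based match on the weights) can be sketched as follows. This is an illustrative pure-Python check, not the harness in op_tests/test_fused_topk_moe_sorting.py; the dict layout and the tolerance values are assumptions.

```python
import math

# Integer outputs that must match exactly between the two paths.
EXACT_FIELDS = ("sorted_ids", "sorted_expert_ids", "num_valid_ids")

def outputs_match(fused, reference, rel_tol=1e-5, abs_tol=1e-6):
    """Compare fused-kernel outputs against the separate-chain reference.

    Both arguments are dicts of plain Python lists or ints (hypothetical layout).
    The tolerances are illustrative, not the values used by the real test.
    """
    for name in EXACT_FIELDS:
        if fused[name] != reference[name]:
            return False
    fw, rw = fused["sorted_weights"], reference["sorted_weights"]
    if len(fw) != len(rw):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(fw, rw)
    )
```

Weights get a tolerance because the fused softmax may accumulate in a different order than the reference; the ids are integers, so any difference there indicates a real routing mismatch.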

Test Result

Measured on gfx950 (MI355X), 3-repeat, ISL=1000/OSL=100:

| Model | conc=16 | conc=32 |
| --- | --- | --- |
| openai/gpt-oss-120b (128 experts, TP=1) | +1.3% throughput | +3.9% throughput, −3.5% TPOT |

github-actions Bot commented Jun 4, 2026

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests on MI35X (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-300x | Run an additional Triton test job on MI300X in PRs; main branch always runs both MI35X and MI300X |
| ci:sglang | SGLang integration tests: DeepSeek-R1-MXFP4 accuracy, Qwen 3.5 accuracy |
| ci:atom | ATOM benchmark: DeepSeek-R1-0528, GPT-OSS-120B |
| ci:atom_full | ATOM accuracy suite for PR and main models from ATOM models_accuracy.json |
| ci:vllm | vLLM benchmark: GPT-OSS-120B, DeepSeek-R1-0528, Kimi-K2.5 |
| ci:all | All standard extended tests (excludes ci:atom_full) |

Only add ci:atom_full for FlyDSL or Triton upgrades.
Add labels via the sidebar or gh pr edit 3526 --add-label <label>

Combine the softmax, top-k selection, and counting sort into a single
HIP kernel launch for the MoE decode path. The intermediate topk_ids
and topk_weights are kept in LDS instead of being materialized in
global memory, eliminating the roundtrip between topk_softmax and
moe_sorting_opus_fwd.

The fused kernel handles decode-sized batches that fit in one CU's LDS
(M<=122 for E=128/K=4, M<=60 for E=256/K=8). Larger problems fall back
to the existing separate kernel chain transparently.

Three phases share a single dynamic LDS allocation:
  1. Softmax + iterative argmax top-K per row
  2. Counting sort (histogram, prefix sum, scatter)
  3. Vectorized moe_buf zeroing

Measured on gfx950 (MI355X) with openai/gpt-oss-120b, 3-repeat:
  conc=16: +1.3% throughput
  conc=32: +3.9% throughput, -3.5% TPOT

Signed-off-by: Aakif Nawaz <[email protected]>
@akii96 akii96 force-pushed the fused_topk_sort branch from fddeda1 to adc0a27 Compare June 4, 2026 00:38
@akii96 akii96 added the ci:vllm label Jun 4, 2026