Fused softmax-topk-sort HIP kernel for MoE decode dispatch #3526
Draft
akii96 wants to merge 1 commit into
Combine the softmax, top-k selection, and counting sort into a single HIP kernel launch for the MoE decode path. The intermediate topk_ids and topk_weights are kept in LDS instead of being materialized in global memory, eliminating the roundtrip between topk_softmax and moe_sorting_opus_fwd.

The fused kernel handles decode-sized batches that fit in one CU's LDS (M<=122 for E=128/K=4, M<=60 for E=256/K=8). Larger problems fall back to the existing separate kernel chain transparently.

Three phases share a single dynamic LDS allocation:
1. Softmax + iterative argmax top-K per row
2. Counting sort (histogram, prefix sum, scatter)
3. Vectorized moe_buf zeroing

Measured on gfx950 (MI355X) with openai/gpt-oss-120b, 3-repeat:
conc=16: +1.3% throughput
conc=32: +3.9% throughput, -3.5% TPOT

Signed-off-by: Aakif Nawaz <[email protected]>
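As a reading aid, the eligibility check and the first two phases named in the commit message can be sketched on the host in plain Python. This is a minimal reference sketch under stated assumptions, not the HIP kernel: the names MAX_FUSED_TOKENS, fits_fused_path, softmax_topk, and counting_sort_by_expert are hypothetical, the token limits are simply the two quoted above, ties in the argmax are broken toward the lowest expert index (an assumption), and the MOCK_ID encoding, any padding, and the moe_buf zeroing phase are omitted.

```python
import math

# Token limits quoted in the commit message. The dict is a hypothetical
# stand-in; the real limit is derived from the LDS budget by the kernel code.
MAX_FUSED_TOKENS = {(128, 4): 122, (256, 8): 60}


def fits_fused_path(num_tokens, num_experts, topk):
    """Return True when the batch is within the quoted single-CU limits."""
    limit = MAX_FUSED_TOKENS.get((num_experts, topk))
    return limit is not None and num_tokens <= limit


def softmax_topk(logits, topk):
    """Phase 1: numerically stable row softmax, then top-K by repeated argmax."""
    row_max = max(logits)
    exps = [math.exp(x - row_max) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ids, weights, taken = [], [], set()
    for _ in range(topk):
        # max() returns the first maximal candidate, so ties go to the
        # lowest expert index (assumed tie-break, not confirmed by the PR).
        best = max((e for e in range(len(probs)) if e not in taken),
                   key=lambda e: probs[e])
        taken.add(best)
        ids.append(best)
        weights.append(probs[best])
    return ids, weights


def counting_sort_by_expert(topk_ids, num_experts):
    """Phase 2: histogram, exclusive prefix sum, then scatter by expert."""
    counts = [0] * num_experts
    for row in topk_ids:
        for expert in row:
            counts[expert] += 1
    offsets, running = [0] * num_experts, 0
    for expert in range(num_experts):
        offsets[expert] = running
        running += counts[expert]
    sorted_pairs = [None] * running
    cursor = list(offsets)
    for token, row in enumerate(topk_ids):
        for slot, expert in enumerate(row):
            # Store (token, top-k slot) grouped by expert, in token order.
            sorted_pairs[cursor[expert]] = (token, slot)
            cursor[expert] += 1
    return counts, sorted_pairs
```

In this sketch the intermediate ids and weights are ordinary Python lists passed from one function to the next; in the fused kernel the corresponding intermediates stay in LDS, which is the point of the change.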
Motivation
The MoE decode path currently requires three separate kernel launches for token-to-expert routing: topk_softmax (gating + top-K selection), followed by moe_sorting_opus_fwd (counting sort + moe_buf zeroing). Between these launches, topk_ids and topk_weights are materialized in global memory only to be immediately read back, a pure roundtrip that costs ~2.7 ms per iteration on GPT-OSS-120b (1.18 ms topk + 1.55 ms sorting).

Technical Details
A single HIP kernel (csrc/include/fused_topk_moe_sorting.h) replaces both launches for decode-sized batches. Three phases share one dynamic LDS allocation:
1. Softmax + iterative argmax top-K per row.
2. Counting sort: histogram, prefix sum, and sorted_ids/sorted_weights/sorted_expert_ids scatter. Same MOCK_ID encoding and SkipExpertsWithZeroTokens semantics as the opus oneshot path.
3. Vectorized moe_buf zeroing with uint4 stores.
LDS budget determines eligibility: M≤122 for (E=128, K=4), M≤60 for (E=256, K=8). Larger batches (prefill) fall back to the existing separate kernel chain transparently via topk_softmax_sorting() in aiter/fused_moe.py.

Files
- csrc/include/fused_topk_moe_sorting.h
- csrc/py_itfs_cu/fused_topk_moe_sorting_kernels.cu
- csrc/pybind/fused_topk_moe_sorting_pybind.cu
- csrc/include/rocm_ops.hpp: FUSED_TOPK_MOE_SORTING_PYBIND macro
- aiter/ops/moe_op.py: fused_topk_moe_sorting_fwd and _max_tokens bindings
- aiter/jit/optCompilerConfig.json: module_fused_topk_moe_sorting
- aiter/fused_moe.py: topk_softmax_sorting() integration with fallback
- op_tests/test_fused_topk_moe_sorting.py

Safety
- dispatch_policy=1 (force oneshot) and dispatch_policy=2 (force MP) bypass the fused path entirely.
- Exact match on sorted_ids/sorted_expert_ids/num_valid_ids and allclose on sorted_weights versus the separate chain, verified across 60 configurations.

Test Plan
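The comparison against the separate chain can be pictured with a small host-side helper. This is a hedged sketch, not the code in op_tests/test_fused_topk_moe_sorting.py: the function name routing_outputs_match, the dict layout, and the tolerances are all hypothetical, and it assumes the integer outputs are compared for equality while sorted_weights is compared with an allclose-style tolerance.

```python
import math


def routing_outputs_match(fused, reference, rel_tol=1e-5, abs_tol=1e-6):
    """Compare fused-path outputs with the separate-chain outputs.

    Both arguments are dicts of plain sequences (hypothetical layout).
    The tolerances are illustrative defaults, not values from the PR.
    """
    # Integer outputs: require element-wise equality.
    for name in ("sorted_ids", "sorted_expert_ids", "num_valid_ids"):
        if list(fused[name]) != list(reference[name]):
            return False
    # Floating-point weights: tolerance-based comparison.
    fused_w = fused["sorted_weights"]
    ref_w = reference["sorted_weights"]
    if len(fused_w) != len(ref_w):
        return False
    return all(
        math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol)
        for a, b in zip(fused_w, ref_w)
    )
```

Splitting the check this way keeps any reordering or off-by-one in the sort visible as a hard failure, while small floating-point differences in the softmax weights are tolerated.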
op_tests/test_fused_topk_moe_sorting.py verifies correctness across (E=128, K=4), (E=256, K=8), and various token counts against the reference topk_softmax + moe_sorting_opus_fwd chain.

Test Result
Measured on gfx950 (MI355X), 3-repeat, ISL=1000/OSL=100: