QMoE CUDA EP — FP4/FP8/WFP4AFP8 Quantized Mixture-of-Experts + MoE GEMM Refactor #28467
Draft
tianleiwu wants to merge 1 commit into
Description
Add a new `QMoE` contrib operator for the CUDA EP that supports quantized Mixture-of-Experts inference with INT4, INT8, FP4 (MXFP4 e2m1), FP8 (e4m3fn), and WFP4AFP8 (mixed FP4 weight × FP8 activation) quantization formats. This also refactors the existing MoE GEMM infrastructure to support TMA warp-specialized grouped GEMM on Hopper (SM90), native MXFP4 on Blackwell (SM120), and block-scaled tensor ops on SM100+, with automatic fallback to dequantization on older architectures.
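For reference, the sketch below decodes the MXFP4 (e2m1) storage format: one sign bit, two exponent bits (bias 1), and one mantissa bit per value, two values packed per byte, with a shared power-of-two scale per 32-element block. The lookup table is the standard e2m1 value set; the helper names and low-nibble-first packing order are illustrative assumptions, not this PR's exact utilities.

```cpp
#include <cstdint>

// Decode one e2m1 nibble: 1 sign bit, 2 exponent bits (bias 1), 1 mantissa bit.
inline float DecodeE2M1(uint8_t nibble) {
  // The eight representable magnitudes of e2m1.
  static const float kMagnitude[8] = {0.0f, 0.5f, 1.0f, 1.5f,
                                      2.0f, 3.0f, 4.0f, 6.0f};
  float v = kMagnitude[nibble & 0x7];
  return (nibble & 0x8) ? -v : v;
}

// Two FP4 values per byte; MXFP4 additionally applies a shared
// power-of-two scale per 32-element block (passed in as `scale`).
inline void DecodeFp4Byte(uint8_t packed, float scale, float out[2]) {
  out[0] = DecodeE2M1(packed & 0xF) * scale;         // low nibble (assumed first)
  out[1] = DecodeE2M1((packed >> 4) & 0xF) * scale;  // high nibble
}
```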
Summary of Changes

New QMoE Operator
- `onnxruntime/core/graph/contrib_ops/contrib_defs.cc` — `QMoE` op schema (com.microsoft domain, opset 1; a registration sketch follows this list)
- `onnxruntime/contrib_ops/cuda/moe/moe_quantization.cc/h`
- `onnxruntime/contrib_ops/cuda/moe/qmoe_kernels.cu/h`
- `onnxruntime/contrib_ops/cuda/moe/moe_base.h`
- `docs/contrib_ops/cuda/moe_qmoe.md`
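For context, a CUDA-EP contrib kernel is typically wired up with onnxruntime's registration macro roughly as sketched below; the exact type constraints this PR declares for `QMoE` may differ, and the snippet presumes onnxruntime's internal headers.

```cpp
// Hedged sketch of onnxruntime's usual CUDA contrib-op registration idiom.
namespace onnxruntime {
namespace contrib {
namespace cuda {

ONNX_OPERATOR_KERNEL_EX(
    QMoE,                     // op name in the com.microsoft domain
    kMSDomain, 1,             // domain and since-version (opset 1)
    kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
    QMoE);                    // kernel class implementing OpKernel::Compute

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
```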
MoE GEMM Refactor

- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemm_kernels.h` — `CutlassMoeFCRunner` template with FP4/FP8/WFP4AFP8 specializations
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemm_template_dispatch.h` (dispatch sketched after this list)
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/moe_gemm_profiler.cc/h`
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/common.h`
- `onnxruntime/contrib_ops/cuda/llm/moe_gemm/launchers/`
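The dispatch this refactor describes can be pictured as below: query the device's SM version and pick the most capable native path, falling back to dequantize-then-GEMM otherwise. The enum and function names are hypothetical; only the SM thresholds come from the PR description.

```cpp
#include <cuda_runtime.h>

// Hypothetical labels for the GEMM paths described in this PR.
enum class MoeGemmPath {
  kNativeMxfp4Sm120,        // Blackwell: native MXFP4
  kBlockScaledSm100,        // SM100+: block-scaled tensor ops
  kTmaWarpSpecializedSm90,  // Hopper: TMA warp-specialized grouped GEMM
  kDequantFallback          // older architectures: dequantize, then GEMM
};

inline int GetSmVersion() {
  int device = 0, major = 0, minor = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device);
  cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device);
  return major * 10 + minor;
}

// Selection mirrors the PR description's architecture tiers.
inline MoeGemmPath SelectPath(int sm) {
  if (sm >= 120) return MoeGemmPath::kNativeMxfp4Sm120;
  if (sm >= 100) return MoeGemmPath::kBlockScaledSm100;
  if (sm >= 90) return MoeGemmPath::kTmaWarpSpecializedSm90;
  return MoeGemmPath::kDequantFallback;
}
```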
CUTLASS Extensions

- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/arch/`
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/collective/`
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm/kernel/`
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/epilogue/`
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/system_barrier.h`
Common CUDA Utilities

- `onnxruntime/contrib_ops/cuda/llm/common/cuda_fp8_utils.cu/h` — FP8 conversion, quantization, dequantization kernels
- `onnxruntime/contrib_ops/cuda/llm/common/memory_utils.cu/h` — Device memory transpose, permute, type conversion utilities
- `onnxruntime/contrib_ops/cuda/llm/common/cuda_type_utils.cuh` — Unified type traits for half/bfloat16/float/fp8/fp4
- `onnxruntime/contrib_ops/cuda/llm/common/quantization.h` — Quantization parameter structs and helpers
- `onnxruntime/contrib_ops/cuda/llm/common/reduce_kernel_utils.cuh` — Warp/block reduction primitives (see the sketch after this list)
- `onnxruntime/contrib_ops/cuda/llm/kernels/quantization.cuh` — FP4/FP8 quantization kernels
- `onnxruntime/contrib_ops/cuda/llm/kernels/pre_quant_scale_kernel.cu/h` — Pre-quantization scaling kernel
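As an illustration of what `reduce_kernel_utils.cuh`-style primitives provide, here is the standard shuffle-based warp/block sum reduction; this is the common CUDA idiom, not a copy of the PR's code.

```cpp
#include <cuda_runtime.h>

__device__ inline float WarpReduceSum(float val) {
  // Butterfly reduction: after 5 shuffle steps every lane holds the warp sum.
  for (int mask = 16; mask > 0; mask >>= 1) {
    val += __shfl_xor_sync(0xffffffffu, val, mask);
  }
  return val;
}

__device__ inline float BlockReduceSum(float val) {
  __shared__ float partial[32];        // one slot per warp (max 1024 threads)
  const int lane = threadIdx.x & 31;
  const int warp = threadIdx.x >> 5;
  val = WarpReduceSum(val);
  if (lane == 0) partial[warp] = val;  // warp leaders publish partial sums
  __syncthreads();
  // The first warp reduces the per-warp partials.
  const int num_warps = (blockDim.x + 31) >> 5;
  val = (threadIdx.x < num_warps) ? partial[lane] : 0.0f;
  if (warp == 0) val = WarpReduceSum(val);
  return val;  // block-wide sum, valid in warp 0
}
```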
GEMM Profiler Refactor

- `onnxruntime/contrib_ops/cuda/llm/gemm_profiler.cc/h` (measurement loop sketched below)
- `onnxruntime/contrib_ops/cuda/llm/cutlass_heuristic.cc/h`
- `onnxruntime/contrib_ops/cuda/llm/cutlass_extensions/gemm_configs.h`
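The profiler's core job can be pictured as a measurement loop: run each candidate GEMM tactic under CUDA events and keep the fastest. The real `gemm_profiler.cc/h` adds heuristics and per-shape caching; the function below is a simplified, hypothetical stand-in.

```cpp
#include <cuda_runtime.h>
#include <functional>
#include <vector>

// Returns the index of the fastest tactic, or -1 if none were given.
// Each tactic is a callable that enqueues one GEMM on the stream.
inline int ProfileBestTactic(
    const std::vector<std::function<void(cudaStream_t)>>& tactics,
    cudaStream_t stream, int warmup = 3, int iters = 10) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  int best = -1;
  float best_ms = 1e30f;
  for (size_t i = 0; i < tactics.size(); ++i) {
    for (int w = 0; w < warmup; ++w) tactics[i](stream);  // warm up clocks/caches
    cudaEventRecord(start, stream);
    for (int it = 0; it < iters; ++it) tactics[i](stream);
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.0f;
    cudaEventElapsedTime(&ms, start, stop);  // total time; same iters for all
    if (ms < best_ms) { best_ms = ms; best = static_cast<int>(i); }
  }
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return best;
}
```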
Build System

- `cmake/CMakeLists.txt` — `ENABLE_FP4`, `ENABLE_FP8`, `ENABLE_CUDA_FP4_QMOE`, `ORT_QUICK_BUILD`, `PLACEHOLDER_KERNELS` options
- `cmake/external/cuda_configuration.cmake`
- `cmake/external/cutlass.cmake`
- `cmake/onnxruntime_providers_cuda.cmake`
- `cmake/onnxruntime_python.cmake` — `onnxruntime_pybind_quant.cc` for Python quantization bindings
Python Quantization Bindings

- `onnxruntime/python/onnxruntime_pybind_quant.cc` (see the sketch below)
- `onnxruntime/python/tools/quantization/quant_utils.py`
- `setup.py`
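A minimal sketch of what a binding in the spirit of `onnxruntime_pybind_quant.cc` might look like: a pybind11-exposed routine that packs two 4-bit values per byte so the Python quantization tools can produce packed weight buffers. The module and function names are assumptions for illustration, not this PR's actual API.

```cpp
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include <cstdint>

namespace py = pybind11;

// Pack a 1-D array of int4 values (stored one per int8) two-per-byte,
// low nibble first. Packing order is an assumption here.
py::array_t<uint8_t> PackInt4(py::array_t<int8_t, py::array::c_style> values) {
  auto in = values.unchecked<1>();
  const py::ssize_t n = in.shape(0);
  py::array_t<uint8_t> out((n + 1) / 2);
  auto o = out.mutable_unchecked<1>();
  for (py::ssize_t i = 0; i < n; i += 2) {
    const uint8_t lo = static_cast<uint8_t>(in(i)) & 0xF;
    const uint8_t hi = (i + 1 < n) ? (static_cast<uint8_t>(in(i + 1)) & 0xF) : 0;
    o(i / 2) = static_cast<uint8_t>(lo | (hi << 4));
  }
  return out;
}

PYBIND11_MODULE(qmoe_pack_demo, m) {  // hypothetical demo module name
  m.def("pack_int4", &PackInt4, "Pack signed int4 values two per byte.");
}
```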
Tests

- `onnxruntime/test/python/transformers/test_qmoe_cuda.py`
- `onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py`
- `onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py`
- `onnxruntime/test/python/transformers/test_qmoe_wfp4afp8_cuda.py`
- `onnxruntime/test/python/transformers/test_moe_cuda.py`
- `onnxruntime/test/contrib_ops/moe_test.cc`
Existing MoE Refactor

- `onnxruntime/contrib_ops/cuda/moe/moe.cc/h` — Refactored to share a base with QMoE
- `onnxruntime/contrib_ops/cuda/moe/ft_moe/` → `onnxruntime/contrib_ops/cuda/llm/moe_gemm/` — Relocated and rewritten MoE GEMM kernels
- Removed `cuda/quantization/moe_quantization.cc/h` in favor of the new `cuda/moe/moe_quantization.cc/h`
Testing

- `python -m pytest onnxruntime/test/python/transformers/test_qmoe_cuda.py -v` (requires CUDA GPU, SM75+)
- `python -m pytest onnxruntime/test/python/transformers/test_qmoe_fp4_cuda.py -v` (requires SM120+ for native; falls back on older architectures)
- `python -m pytest onnxruntime/test/python/transformers/test_qmoe_fp8_cuda.py -v` (requires SM90+ for native)
- `python -m pytest onnxruntime/test/python/transformers/test_qmoe_wfp4afp8_cuda.py -v` (requires SM100+)
- `python -m pytest onnxruntime/test/python/transformers/test_moe_cuda.py -v`
- `onnxruntime_test_all --gtest_filter=*MoE*`
Motivation and Context

Modern LLMs increasingly use Mixture-of-Experts architectures (e.g., Mixtral, DeepSeek, Phi-3.5-MoE) for efficient scaling. These models benefit significantly from weight quantization to reduce memory bandwidth and enable larger models on fewer GPUs. This PR: