feat: rmsnorm fuse quant and unitest #312
Conversation
Pull Request Overview
This PR adds support for fused FP8 per-token quantization across the ROCm GEMM and layer-normalization operations. For MoE models, it conditionally disables quantization in the post-layernorm step, and it extends the FP8 quantization path in the GEMM operations to accept input buffers that are already quantized; a hedged reference sketch of the fused semantics follows the change summary below.
- Added new test file for FP8 per-token, per-channel (PTPC) A8W8 GEMM operations
- Extended ROCm layer normalization to support fused FP8 per-token quantization for RMSNorm
- Modified GEMM operations to skip redundant quantization when inputs are already quantized
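The following minimal PyTorch sketch illustrates the unfused reference semantics that a fused RMSNorm + FP8 per-token quantization kernel would be expected to match numerically; the function names (`rmsnorm_ref`, `fp8_quant_per_token`) are illustrative assumptions, not the PR's actual API.

```python
# Hypothetical reference for RMSNorm followed by FP8 per-token quantization.
# Names here are illustrative only, not taken from rtp_llm.
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm: scale each row by the reciprocal RMS of its elements.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * weight.float()).to(x.dtype)

def fp8_quant_per_token(x: torch.Tensor):
    # Per-token quantization: one scale per row (token), derived from the
    # row's absolute maximum so each row fully uses the FP8 dynamic range.
    scales = x.abs().amax(dim=-1, keepdim=True).float().clamp(min=1e-12) / FP8_MAX
    q = (x.float() / scales).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scales

# Unfused reference a fused kernel should match numerically:
x = torch.randn(4, 128, dtype=torch.float16)
w = torch.ones(128, dtype=torch.float16)
q, scales = fp8_quant_per_token(rmsnorm_ref(x, w))
```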
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| tests/gemm/rocm_ptpc_a8w8_gemm_op_test.py | New test file for FP8 A8W8 GEMM with Chinese error messages and comprehensive tensor swizzling/shuffling utilities |
| tests/gemm/rocm_pertensor_int8_gemm_op_test.py | Removed trailing empty line |
| tests/gemm/gemm_op_test.cc | Added blank line after class constructor declaration |
| tests/BUILD | Added build configuration for new PTPC A8W8 GEMM test |
| rtp_llm/cpp/models/GptModel.cc | Added conditional logic to disable quantization for MoE models in post-layernorm |
| rtp_llm/cpp/devices/rocm_impl/ROCmLayernorm.cc | Refactored buffer allocation logic and added FP8 per-token quantization support for RMSNorm |
| rtp_llm/cpp/devices/rocm_impl/ROCmGemmOp.cc | Added logic to skip quantization when the input is already a QBuffer and extended the FP8 GEMM dispatch (see the sketch after this table) |
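As a rough illustration of the skip-requantization idea in ROCmGemmOp.cc, the sketch below models a quantized buffer as a plain data-plus-scales pair; `QBufferSketch` and `prepare_gemm_input` are stand-ins for explanation, not rtp_llm's actual types.

```python
# Illustrative only: if an activation already arrives quantized (e.g. from a
# fused RMSNorm + quant kernel), the GEMM reuses its payload and scales
# instead of quantizing a second time.
from dataclasses import dataclass
import torch

@dataclass
class QBufferSketch:            # stand-in for rtp_llm's QBuffer
    data: torch.Tensor          # FP8 payload
    scales: torch.Tensor        # per-token scales

def prepare_gemm_input(a, quantize_fn):
    if isinstance(a, QBufferSketch):
        return a                # already quantized upstream: skip re-quantization
    q, s = quantize_fn(a)       # otherwise quantize here, inside the GEMM op
    return QBufferSketch(q, s)
```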
Pull Request Overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 6 comments.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
Force-pushed from 4d57ed4 to bdaaa24
Pull Request Overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
Comments suppressed due to low confidence (1)
tests/gemm/rocm_ptpc_a8w8_gemm_op_test.py:1
- Help text contains Chinese characters '计算使用的数据类型' ("data type used for computation"). Comments and documentation should be in English for consistency with the rest of the codebase.
```python
# SPDX-License-Identifier: MIT
```
Force-pushed from 8057e6b to b1cad8a
Pull Request Overview
Copilot reviewed 8 out of 8 changed files in this pull request and generated 3 comments.
| f"相对误差={torch.abs((a[idx_tuple] - b[idx_tuple]) / b[idx_tuple])}\n" | ||
| ) | ||
| if len(mismatch_indices) > 10: | ||
| error_msg += f"...(共 {len(mismatch_indices)} 处不匹配)" |
Copilot AI commented on Nov 4, 2025
Error messages contain Chinese text which may not be accessible to all developers. Consider using English for error messages to maintain consistency with the rest of the codebase and ensure international accessibility.
```diff
-error_msg += f"...(共 {len(mismatch_indices)} 处不匹配)"
+error_msg += f"...(total {len(mismatch_indices)} mismatches)"
```
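Along the lines of that suggestion, here is a hedged sketch of an English-language mismatch report; the helper name and tolerances are illustrative, not taken from the test file.

```python
# Illustrative comparison helper: reports per-element relative error for the
# first few mismatches and summarizes the rest in English.
import torch

def report_mismatches(a, b, rtol=1e-2, atol=1e-3, limit=10) -> str:
    mismatch = ~torch.isclose(a, b, rtol=rtol, atol=atol)
    indices = mismatch.nonzero(as_tuple=False)
    lines = []
    for row in indices[:limit]:
        t = tuple(row.tolist())
        rel = torch.abs((a[t] - b[t]) / b[t]).item()
        lines.append(f"index={t}: a={a[t].item()}, b={b[t].item()}, relative error={rel}")
    if len(indices) > limit:
        lines.append(f"...(total {len(indices)} mismatches)")
    return "\n".join(lines)
```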
```cpp
    auto res_tensor = rmsnorm2d(input_tensor, weight_tensor, static_cast<double>(eps), 0);
    copy({*norm_output, *torchTensor2Buffer(res_tensor)});
}
if (params.qscheme == QScheme::Qfp8PerToken /* Do fuse fp8 pertoken*/) {
```
Copilot AI commented on Nov 4, 2025
Grammar correction: 'Do fuse' should be 'Fused' or 'Fuses' in comment.
```diff
-if (params.qscheme == QScheme::Qfp8PerToken /* Do fuse fp8 pertoken*/) {
+if (params.qscheme == QScheme::Qfp8PerToken /* Fused fp8 per-token */) {
```
```cpp
              autil::StringUtil::toString(arguments.Dshape).c_str(),
              params.D->debugString().c_str());
              params.D->debugString().c_str());
} else if (params.A.type() == DataType::TYPE_QFP8_E4M3 && params.B.type() == DataType::TYPE_QFP8_E4M3 /* if fused fp8 pertoken & rmsnorm */) {
```
Copilot AI commented on Nov 4, 2025
Grammar correction: 'if fused' should be 'fused' (remove 'if') in comment.
```diff
-} else if (params.A.type() == DataType::TYPE_QFP8_E4M3 && params.B.type() == DataType::TYPE_QFP8_E4M3 /* if fused fp8 pertoken & rmsnorm */) {
+} else if (params.A.type() == DataType::TYPE_QFP8_E4M3 && params.B.type() == DataType::TYPE_QFP8_E4M3 /* fused fp8 pertoken & rmsnorm */) {
```
Force-pushed from 1573862 to 34a8e0e
Force-pushed from 34a8e0e to d9f1d42
Pull Request Overview
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
```cpp
} else if (params.A.type() == DataType::TYPE_INT8 || params.A.type() == DataType::TYPE_QINT8){
    DDtype = DataType::TYPE_FP16;
} else if (params.A.type() == DataType::TYPE_FP8_E4M3 || params.A.type() == DataType::TYPE_QFP8_E4M3){
    // TO DO: When A is TYPE_FP8_E4M3, choose output dtype according to env "ACT_TYPE".
```
Copilot AI commented on Nov 6, 2025
Comment uses 'TO DO' instead of standard 'TODO' format.
```diff
-// TO DO: When A is TYPE_FP8_E4M3, choose output dtype according to env "ACT_TYPE".
+// TODO: When A is TYPE_FP8_E4M3, choose output dtype according to env "ACT_TYPE".
```
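For what that TODO might eventually look like, here is a hedged sketch of env-driven output-dtype selection; the "ACT_TYPE" variable name comes from the comment, while the value mapping and default below are assumptions.

```python
# Illustrative only: map an ACT_TYPE environment variable to a torch dtype
# for the FP8 GEMM output; defaults to bf16 when unset or unrecognized.
import os
import torch

_ACT_DTYPES = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp32": torch.float32}

def output_dtype_for_fp8_gemm() -> torch.dtype:
    return _ACT_DTYPES.get(os.environ.get("ACT_TYPE", "bf16").lower(), torch.bfloat16)
```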
```cpp
              arguments.DDtype,
              autil::StringUtil::toString(arguments.Dshape).c_str(),
              params.D->debugString().c_str());
              params.D->debugString().c_str());
```
Copilot AI commented on Nov 6, 2025
This line appears to have only trailing whitespace removed compared to the original, but the trailing spaces on line 380 should be fully cleaned up for consistency.
```diff
-params.D->debugString().c_str());
+params.D->debugString().c_str());
```
(The suggested change removes trailing whitespace only, so both lines render identically.)
```cpp
                  scale_N);
    }
    return std::move(output);

```
Copilot AI commented on Nov 6, 2025
[nitpick] Missing blank line after this return statement before the closing brace for consistency with surrounding code patterns (see line 504 which has similar structure).