TL;DR
test_moe_blockscale_e2e reports the FlyDSL 2-stage block-scale MoE pipeline as
100% wrong (stage2 err_ratio = 1.0, all-NaN output). The kernel is not
wrong. The 1.0 is entirely a test-harness artifact: three compounding bugs in
the test make a correct kernel, and even aiter's own production CK stage2
kernel, look 100% wrong. Fixing the test data prep makes the whole pipeline
pass at err ≈ 0.000.
This supersedes the original framing ("the FlyDSL stage2 dequant / inter-stage
re-quant is structurally wrong") — that attribution is incorrect.
Reproduce
python -m pytest tests/kernels/test_moe_blockscale.py -k small-E8 -s -q
Original (broken) output on gfx950 / MI350X:
stage1 err_ratio = 0.0001
stage2 err_ratio = 1.0000 # FlyDSL stage2 vs torch ref
flydsl: pipeline err_ratio vs ref = 1.0000
aiter fused: 52.5us (err_ratio=0.1390)
ck stage1: err_vs_ref=0.0001, err_vs_fly=0.0000
ck stage2: err_vs_ref=1.0000, err_vs_fly=0.8756 # aiter's PROD kernel ALSO 1.0
The decisive tell is ck stage2 err_vs_ref = 1.0000: aiter's known-good CK stage2
kernel, fed the same a2_bq/scales and compared to the same torch reference,
also scores 100% wrong. A test that condemns the vendor's production kernel
identically to ours is measuring the harness, not the kernel.
What's actually wrong — three test bugs
1. Activation data prep is ~2000x too large. The block-scale activation is
built from the raw fp8 codes instead of the dequantized activation:
x_q, x_scale = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8) # x_q are fp8 codes, ~448-scale
...
x_f32_for_blk = x_q.float() # <-- BUG: fp8 codes (~448), not x_q*x_scale (~0.2)
a1_bq, a1_bscale = pertoken_quant(x_f32_for_blk.view(...))
Dropping x_scale inflates the activation by ~1/x_scale (~2000x). The stage1
intermediate then reaches absmax ≈ 2.4e7, absmean ≈ 1.1e6.
2. f16 cannot hold those magnitudes (max 65504), so everything overflows.
- The reference itself casts the intermediate to f16 before re-quantizing
(out1_torch_ref.to(torch.float16)), which overflows to inf; the requant's
nan_to_num clamps inf → fp8 max (448), collapsing the dequantized activation
~400x and scrambling its direction. Pure-torch check, no kernel involved:
err(2-stage torch ref vs fused torch ref) = 0.9999, cosine = 0.04, ‖2stage‖/‖fused‖ = 0.0025.
→ This is why out2_ref_s2 / out_ref are garbage and every kernel scores 1.0 against them.
- The kernels write f16 output (stage1's cshuffle epilog is f16-only; stage2 is f16
here too) → 1.9e6 overflows → all-NaN output for both FlyDSL and aiter CK.
- Even stage1 err_ratio = 0.0001 is an inf == inf artifact, not a real match.
3. The test returns instead of asserting (return us_fly_total, us_aiter_fused),
so all of the above passes silently (pytest warns "Did you mean to use assert instead of return?").
Note: torch_stage2_blockscale_ref itself is correct — fed a clean fp32
activation it reproduces the fused reference exactly (err 0.0000, cos 1.0). It is
poisoned only by the overflowed input.
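The arithmetic behind bugs 1 and 2 can be reproduced without a GPU or torch. The sketch below is illustrative only: pertoken_quant_sketch, to_f16 and clamp_inf_to_fp8_max are hypothetical stand-ins written for this note (not the test's helpers), fp8 rounding is omitted, and the intermediate values are made-up numbers of roughly the reported magnitude.

```python
import math
import struct

FP8_MAX = 448.0    # fp8 max as quoted in this report
F16_MAX = 65504.0  # largest finite float16 value


def pertoken_quant_sketch(row):
    """Hypothetical stand-in for pertoken_quant: scale the row so its absmax
    lands on FP8_MAX. Rounding to the fp8 grid is omitted for brevity."""
    scale = max(abs(v) for v in row) / FP8_MAX
    return [v / scale for v in row], scale


def to_f16(v):
    """Round-trip through float16; out-of-range values become +/-inf,
    which is what a tensor cast to f16 would produce."""
    try:
        return struct.unpack("<e", struct.pack("<e", v))[0]
    except OverflowError:
        return math.copysign(math.inf, v)


def clamp_inf_to_fp8_max(v):
    """Mimic the clamp described above: +/-inf is replaced by +/-FP8_MAX."""
    return math.copysign(FP8_MAX, v) if math.isinf(v) else v


# Bug 1: the fp8 codes are ~1/x_scale larger than the dequantized activation.
x = [0.2, -0.1, 0.05, 0.15]
x_q, x_scale = pertoken_quant_sketch(x)
inflation = max(abs(v) for v in x_q) / max(abs(v) for v in x)
print(f"x_scale = {x_scale:.3e}, inflation = {inflation:.0f}x")

# Bug 2: intermediates at the inflated scale do not survive an f16 cast.
intermediate = [1.0e6, -2.0e6, 3.0e6]        # illustrative values only
as_f16 = [to_f16(v) for v in intermediate]   # every entry exceeds F16_MAX
clamped = [clamp_inf_to_fp8_max(v) for v in as_f16]
print(as_f16, clamped)

# inf compares equal to inf, so two overflowed tensors can look identical...
assert to_f16(1.0e6) == to_f16(3.0e6)
# ...while arithmetic on them produces NaN.
assert math.isnan(to_f16(1.0e6) - to_f16(3.0e6))
```

After the clamp, the three entries keep only their signs: the 1:2:3 magnitude ratio is gone, which is the "collapsed and scrambled" effect described in bug 2.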
The kernel is correct
- stage2 in isolation, with clean in-range magnitudes and the same a2_bq fed to
both the kernel and torch_stage2_blockscale_ref: err = 0.0000, cos = 1.0000, all finite.
- Full pipeline after the data-prep fix (real pytest, with asserts):
stage1 err = 0.0000
stage2 err = 0.0000
flydsl pipeline err vs ref = 0.0016
aiter fused err = 0.0000
ck stage1 err_vs_ref/vs_fly = 0.0000 / 0.0000
ck stage2 err_vs_ref/vs_fly = 0.0000 / 0.0000
2 passed
Fix
The minimal, correct fix is the data prep — not "bf16 output" (stage1's
cshuffle epilog only supports f16; once magnitudes are realistic, f16 is fine):
# build the block-scale activation from the dequantized x, not the fp8 codes
x_f32_for_blk = x_q.float() * x_scale
...
# same fix for the FlyDSL kernel input:
a1_bq, a1_scale_fly = per_group_quant_hip(
    (x_q.float() * x_scale).to(torch.bfloat16),
    quant_dtype=DTYPE_FP8, group_size=scale_blk_k, transpose_scale=True,
)
Then keep the assert on err_fly + finiteness (already added) so a real
regression can't pass silently. (The trailing return of timings — kept for the
__main__ benchmark path — still triggers pytest's return-not-assert warning;
optionally split the benchmark out so the test returns None.)
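One possible shape for the optional split is sketched below, so that the pytest entry point asserts and returns None while the benchmark path still gets its timings. Every name here (run_pipeline_sketch, check_pipeline, benchmark_sketch, ERR_TOL) is hypothetical, and the values returned by the stand-in runner are placeholders, not measurements.

```python
import math

ERR_TOL = 0.05  # hypothetical tolerance; use the threshold the real test asserts on


def run_pipeline_sketch():
    """Hypothetical stand-in for launching the kernels.

    Returns (err_fly, flat output values, (us_fly_total, us_aiter_fused)).
    All three are placeholder values for illustration.
    """
    return 0.0, [0.1, -0.2, 0.3], (0.0, 0.0)


def check_pipeline(err_fly, out):
    """Fail loudly on non-finite output or an error ratio above tolerance."""
    assert all(math.isfinite(v) for v in out), "pipeline produced non-finite output"
    assert err_fly < ERR_TOL, f"err_ratio {err_fly:.4f} exceeds {ERR_TOL}"


def test_moe_blockscale_e2e_sketch():
    """pytest entry point: asserts only, so it implicitly returns None."""
    err_fly, out, _timings = run_pipeline_sketch()
    check_pipeline(err_fly, out)


def benchmark_sketch():
    """Benchmark entry point: the only place the timings are returned."""
    _err, _out, timings = run_pipeline_sketch()
    return timings


if __name__ == "__main__":
    print(benchmark_sketch())
```

Checking finiteness before the error ratio matters here: an all-NaN output would otherwise depend on how the error metric happens to treat NaN.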
Impact / corrected note
The multi-shape cold-cache sweep that flagged moe_blockscale as 0/12 correct
(all-NaN final output) was hitting the same f16 overflow, not a kernel bug. In
a realistic-magnitude / bf16-output deployment the kernel produces correct,
finite results.