
test_moe_blockscale_e2e harness bug: raw FP8-code activations cause f16 overflow and false 100% failures #642

@jhinpan


TL;DR

test_moe_blockscale_e2e reports the FlyDSL 2-stage block-scale MoE pipeline as
100% wrong (stage2 err_ratio = 1.0, all-NaN output). The kernel is not wrong.
The 1.0 is entirely a test-harness artifact: three compounding bugs in the test
make a correct kernel, and even aiter's own production CK stage2 kernel, look
100% wrong. Fixing the test's data prep makes the whole pipeline pass, with every
error ratio at or below 0.0016.

This supersedes the original framing ("the FlyDSL stage2 dequant / inter-stage
re-quant is structurally wrong"); that attribution is incorrect.

Reproduce

python -m pytest tests/kernels/test_moe_blockscale.py -k small-E8 -s -q

Original (broken) output on gfx950 / MI350X:

  stage1 err_ratio = 0.0001
  stage2 err_ratio = 1.0000        # FlyDSL stage2 vs torch ref
  flydsl: pipeline err_ratio vs ref = 1.0000
  aiter fused: 52.5us  (err_ratio=0.1390)
  ck stage1:  err_vs_ref=0.0001, err_vs_fly=0.0000
  ck stage2:  err_vs_ref=1.0000, err_vs_fly=0.8756   # aiter's PROD kernel ALSO 1.0

The decisive tell is ck stage2 err_vs_ref = 1.0000: aiter's known-good CK stage2
kernel, fed the same a2_bq/scales and compared to the same torch reference,
also scores 100% wrong. A test that condemns the vendor's production kernel
identically to ours is measuring the harness, not the kernel.

What's actually wrong: three test bugs

1. The activation data is ~2000x too large. The block-scale activation is
built from the raw fp8 codes instead of from the dequantized activation:

x_q, x_scale = pertoken_quant(x_fp32, quant_dtype=DTYPE_FP8)  # x_q are fp8 codes, ~448-scale
...
x_f32_for_blk = x_q.float()          # <-- BUG: fp8 codes (~448), not x_q*x_scale (~0.2)
a1_bq, a1_bscale = pertoken_quant(x_f32_for_blk.view(...))

Dropping x_scale inflates the activation by ~1/x_scale (~2000x). The stage1
intermediate then reaches absmax ≈ 2.4e7, absmean ≈ 1.1e6.
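A minimal pure-Python sketch of where the ~2000x comes from, assuming a simplified per-token quantizer that only scales each row so its absmax maps to the float8_e4m3fn maximum of 448 (no rounding to the fp8 grid). pertoken_quant_sketch is a hypothetical stand-in, not the harness's pertoken_quant:

```python
FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def pertoken_quant_sketch(row):
    """Hypothetical, simplified per-token quant: scale the row so its absmax
    maps to FP8_E4M3_MAX. Assumption: rounding to the fp8 grid is ignored."""
    absmax = max(abs(v) for v in row)
    scale = absmax / FP8_E4M3_MAX if absmax > 0 else 1.0
    codes = [v / scale for v in row]
    return codes, scale

x = [0.2, -0.1, 0.05, 0.15]                 # realistic activation magnitudes
codes, scale = pertoken_quant_sketch(x)

buggy = codes                               # like x_q.float(): raw codes, ~448-scale
fixed = [c * scale for c in codes]          # like x_q.float() * x_scale: ~0.2-scale

inflation = max(abs(v) for v in buggy) / max(abs(v) for v in fixed)
print(f"scale={scale:.3e}  inflation={inflation:.0f}x")
```

Dropping the scale multiplies every element by 1/scale, which is 448 / absmax for each token; with absmax around 0.2 that is roughly the 2000x reported above.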

2. f16 cannot hold those magnitudes (max 65504), so everything overflows.

  • The reference itself casts the intermediate to f16 before re-quantizing
    (out1_torch_ref.to(torch.float16)), which overflows to inf; the requant's
    nan_to_num clamps inf → fp8 max (448), collapsing the dequantized activation
    ~400x and scrambling its direction. Pure-torch check, no kernel involved:
    err(2-stage torch ref vs fused torch ref) = 0.9999, cosine = 0.04, ‖2stage‖/‖fused‖ = 0.0025.
    → This is why out2_ref_s2 / out_ref are garbage and every kernel scores 1.0 against them.
  • The kernels write f16 output (stage1's cshuffle epilog is f16-only; stage2 is f16
    here too) → 1.9e6 overflows → all-NaN output for both FlyDSL and aiter CK.
  • Even stage1 err_ratio = 0.0001 is an inf == inf artifact, not a real match.
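The overflow chain in the bullets above can be illustrated with a toy model in pure Python. to_f16 emulates a float-to-f16 cast through the standard struct 'e' (half-precision) format, which raises OverflowError where torch would produce inf; clamp_inf is a hypothetical stand-in for the requant's nan_to_num clamp to 448. The input values are illustrative magnitudes in the range the issue reports, not data from the test:

```python
import math
import struct

def to_f16(v):
    """Emulate a cast to float16: round to half precision, overflow -> +/-inf.
    struct's 'e' format raises OverflowError where torch would give inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", v))[0]
    except OverflowError:
        return math.copysign(math.inf, v)

def clamp_inf(v, limit=448.0):
    """Toy stand-in for nan_to_num(..., posinf=limit, neginf=-limit)."""
    return math.copysign(limit, v) if math.isinf(v) else v

def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

intermediate = [2.4e7, 1.1e6, -3.0e5, 1.2e4]   # illustrative stage1 magnitudes
as_f16 = [to_f16(v) for v in intermediate]      # [inf, inf, -inf, 12000.0]
requant_in = [clamp_inf(v) for v in as_f16]     # [448.0, 448.0, -448.0, 12000.0]

# The three largest entries collapse to the same 448, so the one value that
# happened to fit in f16 now dominates and the vector's direction is lost.
print(as_f16, requant_in, round(cosine(intermediate, requant_in), 3))
```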

3. The test returns instead of asserting (return us_fly_total, us_aiter_fused),
so all of the above passes silently; pytest only emits a warning ("Did you mean to use assert instead of return?").

Note: torch_stage2_blockscale_ref itself is correct. Fed a clean fp32
activation, it reproduces the fused reference exactly (err 0.0000, cos 1.0); it is
poisoned only by the overflowed input.

The kernel is correct

  • stage2 in isolation, clean in-range magnitudes, same a2_bq to both kernel
    and torch_stage2_blockscale_ref: err = 0.0000, cos = 1.0000, all finite.
  • Full pipeline after the data-prep fix (real pytest, with asserts):
    stage1 err = 0.0000
    stage2 err = 0.0000
    flydsl pipeline err vs ref = 0.0016
    aiter fused err = 0.0000
    ck stage1 err_vs_ref/vs_fly = 0.0000 / 0.0000
    ck stage2 err_vs_ref/vs_fly = 0.0000 / 0.0000
    2 passed
    

Fix

The minimal, correct fix is in the data prep, not switching to bf16 output (stage1's
cshuffle epilog only supports f16, and once magnitudes are realistic f16 is fine):

# build the block-scale activation from the dequantized x, not the fp8 codes
x_f32_for_blk = x_q.float() * x_scale
...
# same fix for the FlyDSL kernel input:
a1_bq, a1_scale_fly = per_group_quant_hip(
    (x_q.float() * x_scale).to(torch.bfloat16),
    quant_dtype=DTYPE_FP8, group_size=scale_blk_k, transpose_scale=True,
)

Then keep the assert on err_fly plus finiteness (already added) so a real
regression can't pass silently. (The trailing return of timings, kept for the
__main__ benchmark path, still triggers pytest's return-not-assert warning;
optionally split the benchmark out so the test returns None.)

Impact / corrected note

The multi-shape cold-cache sweep that flagged moe_blockscale as 0/12 correct
(all-NaN final output) was hitting the same f16 overflow, not a kernel bug. In
a realistic-magnitude / bf16-output deployment the kernel produces correct,
finite results.
