Add GEMM + AllScatter benchmark sweep with multi-shape tiling analysis, optimal tile guidance, iris vectorization hints, roofline analysis, kernel comm-compute overlap optimization, comprehensive MNK roofline sweep, strong/weak scaling analysis, deep G...#411
Conversation
Co-authored-by: mawad-amd <[email protected]>
@copilot can you just run them and report back the results.
Co-authored-by: mawad-amd <[email protected]>
Ran the GEMM+AllScatter benchmarks on 8x AMD MI300X (304 CUs each) with fp16, sweeping M from 1 to 1024 (N=4096, K=14336). Results are committed in `benchmark/gemm_all_scatter_mi300x.md`.
At M=1024 the fused kernel achieves 172.8 TFLOPS in 0.696 ms end-to-end. Small M values are communication/launch-overhead bound (~0.39–0.46 ms baseline).
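The headline number can be sanity-checked from the definition used throughout this PR (2·M·N·K FLOPs over measured wall time); this is a quick arithmetic sketch, not a new measurement:

```python
def tflops(M, N, K, time_ms):
    """End-to-end throughput: 2*M*N*K FLOPs over the measured wall time."""
    return 2 * M * N * K / (time_ms * 1e-3) / 1e12

# M=1024, N=4096, K=14336 at 0.696 ms should land near the reported value
print(round(tflops(1024, 4096, 14336, 0.696), 1))  # → 172.8
```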
@copilot can you sweep the tiling parameters and the rest of the configuration space and make me a chart?
…MM+AllScatter on MI300X Co-authored-by: mawad-amd <[email protected]>
Done — swept tile sizes, `num_stages`, and `gsize_m`; tables and charts are committed in `benchmark/gemm_all_scatter_tiling_sweep_mi300x.md`. Key takeaways: smaller tiles win at this shape — `(64, 64, 64)` reaches 349 TFLOPS at M=1024, roughly 2× the default `(256, 64, 64)` config (177 TFLOPS).
Optimal config: `(BLK_M=64, BLK_N=64, BLK_K=64)`.
@copilot very cool. Can you sweep for more shapes and show me some plots. Would like your help understanding how to determine the optimal tiling parameters for any problem shape.
@copilot also, we recently added an option for the load/store iris APIs to support hints. Can you add these hints, analyze the assembly before and after, and the performance before and after? Tip: you can find the generated assembly inside the ~/.triton/cache dir. I typically do something like rm -rf ~/.triton/cache, then run my code and inspect the generated assembly. Please be careful when deleting stuff and don't delete anything else though!
…ling analysis for MI300X Co-authored-by: mawad-amd <[email protected]>
Extended the sweep to 4 representative LLM shapes across M=64..1024 (80 configs total). Results + guidance committed in `benchmark/gemm_all_scatter_multishape_analysis_mi300x.md`.
How to pick the optimal tile for any shape: the key metric is `total_tiles` relative to the 304 available SMs:

```python
if total_tiles < num_sms:
    BLK_M, BLK_N, BLK_K = 64, 64, 64    # SM-starved: maximize tile count
else:
    BLK_M, BLK_N, BLK_K = 128, 64, 64   # compute-bound: balance occupancy vs reuse
```
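The rule above can be wrapped into a runnable helper. This is a sketch only: `choose_tiles` is a hypothetical name, `num_sms=304` is the MI300X CU count, and the `total_tiles` formula is the one given in the multi-shape analysis later in this PR.

```python
import math

def choose_tiles(M, N, world_size=8, num_sms=304):
    """Hypothetical helper applying the rule above: count output tiles at the
    small-tile setting, fall back to a larger BLK_M once the SMs saturate."""
    n_local = N // world_size  # each rank owns N/world_size output columns
    total_tiles = math.ceil(M / 64) * math.ceil(n_local / 64)
    if total_tiles < num_sms:
        return 64, 64, 64      # SM-starved: maximize tile count
    return 128, 64, 64         # compute-bound: balance occupancy vs reuse

print(choose_tiles(64, 4096))     # typical LLM shape, small M → (64, 64, 64)
print(choose_tiles(4096, 16384))  # very large shape → (128, 64, 64)
```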
…erf analysis Co-authored-by: mawad-amd <[email protected]>
Added the hints. Assembly impact (BLK_M=BLK_N=BLK_K=64, gfx942): assembly size drops from 2014 → 1151 lines (−43%). All scalar fp16 scatter stores are replaced by 128-bit vectorized `global_store_dwordx4` instructions.
Performance (8×MI300X, fp16, stages=2): the BLK=64 configs (our optimal tile) benefit 6–10% because the scatter loop is spread across more tiles with more memory traffic per tile. Larger tiles (BLK_M=128+) benefit less, as store throughput is a smaller fraction of total runtime.
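As a back-of-envelope sketch (not measured data) of why the vectorization matters: a 64×64 fp16 output tile needs 8× fewer store instructions once 2-byte scalar stores become 16-byte `global_store_dwordx4` stores.

```python
TILE_ELEMS = 64 * 64          # fp16 elements in one (BLK_M=64, BLK_N=64) C tile
FP16_BYTES = 2
VEC_BYTES = 16                # global_store_dwordx4 writes 128 bits

scalar_stores = TILE_ELEMS                             # one instruction per element
vector_stores = TILE_ELEMS * FP16_BYTES // VEC_BYTES   # 8 fp16 elements per store

print(scalar_stores, vector_stores)   # 4096 vs 512: an 8x cut in store count
```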
@copilot, can we optimize the triton code? Maybe we can do a better job at overlapping communication and computation? We can't change M to get better performance -- we must change the code instead. Please optimize the triton code.
… → ctx.store) Co-authored-by: mawad-amd <[email protected]>
Two Triton-level optimizations applied and benchmarked; full analysis in `benchmark/gemm_all_scatter_optimization_mi300x.md`.
Optimization 1: scatter directly from accumulator registers. Replacing `ctx.put(C_ptr, ...)` with `ctx.store(c_global_ptr, c, ...)` eliminates the intermediate HBM read-back of C.
Optimization 2: for BLK_M=64, `num_stages=3` fits within the LDS budget and deepens the software pipeline.
Combined results (8×MI300X, fp16, BLK=64,64,64) and the recommended configuration matrix by M range are in the analysis file.
@copilot Ok, I want to see a roofline with clear lines for how far we can go and where we are now. You may want to consider an x-axis that is MNK and sweep many different variants of these. Take your time. You have access to the 8 GPUs. I need some comprehensive analysis with and without the optimizations you added.
… configs with hardware limits Co-authored-by: mawad-amd <[email protected]>
Done — comprehensive roofline analysis committed in `benchmark/gemm_all_scatter_comprehensive_roofline_mi300x.md`. 60 benchmark points: M ∈ {64,128,256,512,1024} × (N,K) ∈ {(4096,4096),(4096,14336),(8192,4096),(8192,28672)} × 3 configs. TFLOPS = 2·M·N·K / time, where N is the global output dimension. 4-panel chart:
- Panel A (roofline — AI on x-axis): all points are below the 10,459 TFLOPS compute ceiling. The bottleneck is NOT HBM bandwidth (all points are right of the ridge point) — it's SM under-utilization and MFMA latency chains.
- Panel B (TFLOPS vs 2·M·N·K): BLK=64 is 1.5–1.9× faster than the baseline (BLK=256) across all shapes. Best measured: 1002 TFLOPS at M=1024, N=8192, K=28672.
- Panel C (efficiency vs problem size): efficiency ranges from 0.2% (small M) to 9.6% of total FP16 peak for the optimized config.
- Panel D (speedup vs baseline): peak speedup of 1.86× at M=512, N=8192, K=28672.
Key takeaway: the ~10× gap from the 10,459 TFLOPS compute ceiling is explained by four compounding factors (SM utilization 42%, MFMA latency chains, LDS barriers, scatter setup), not bandwidth. The primary lever for closing the gap further is increasing M (more tiles per SM) or fusing multiple sequence positions.
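The ceiling and efficiency figures above cross-check with simple arithmetic (the 1307.4 TFLOPS/GPU FP16 peak is the hardware limit quoted in the roofline study later in this PR):

```python
PEAK_TFLOPS_PER_GPU = 1307.4      # MI300X FP16 tensor peak (from the analysis)
WORLD_SIZE = 8

ceiling = PEAK_TFLOPS_PER_GPU * WORLD_SIZE   # aggregate compute ceiling
best = 1002.0                                # best measured point (M=1024, N=8192, K=28672)
efficiency = best / ceiling * 100

print(round(ceiling, 1))      # ≈ 10459.2 TFLOPS, the quoted ceiling
print(round(efficiency, 1))   # ≈ 9.6 %, matching Panel C's upper end
```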
@copilot ok, let's go back to basics. Let's look into strong and weak scaling and rooflines. If there are no communication, we should be compute bound. As we add comm, the bottleneck will shift but ideally we remain compute bound. Would like to reason better about the bottleneck and what we can do better. |
…weak scaling breakdown Co-authored-by: mawad-amd <[email protected]>
Done — full scaling analysis committed in `benchmark/gemm_all_scatter_scaling_mi300x.md`. 80 data points: 40 GEMM-only (rocBLAS) and 40 GEMM+AllScatter (Triton BLK=64, s=3), over M ∈ {64,128,256,512,1024} × 4 shapes. Key findings: without communication (GEMM-only, rocBLAS):
With AllScatter (fused Triton kernel):
Strong scaling (doubling M = doubling work): the fused kernel scales at 1.77–1.83× vs the ideal 2×. Sub-linear because scatter setup is a fixed per-tile cost. Weak scaling (fixed per-GPU work): efficiency improves super-linearly with M (0.3% at M=64 → 4.4% at M=1024) because scatter setup amortizes as K-loop iterations per tile increase. What would help most:
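As a quick check (simple arithmetic, not new measurements), the 1.77–1.83× per-doubling speedups translate into step efficiencies relative to the ideal 2×:

```python
def doubling_efficiency(speedup):
    """Efficiency of one M-doubling step relative to the ideal 2x scaling."""
    return speedup / 2.0 * 100

# The measured 1.77–1.83x speedups correspond to roughly 89–92% step efficiency
print(round(doubling_efficiency(1.77), 1), round(doubling_efficiency(1.83), 1))
```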
@copilot ok, good analysis. Let's dig deeper into the utilization of the GEMMs. Are you sure we can't do better with the current kernels? Remember, everything is configurable: tiling, pipelining, the number of CUs we assign to the GEMM, and everything else. If it helps, check out the Triton source code (just do
…_warps/mfma/kpack knobs in matmul_wrapper Co-authored-by: mawad-amd <[email protected]>
Dug into GEMM utilization by sweeping the four knobs that were hardcoded in `matmul_wrapper.py`. Key finding: BLK_K=128 is a major win at M≤512. For K=14336, BLK_K=128 cuts the K-loop from 224 → 112 iterations, halving the per-tile `s_barrier` overhead.*
*M=1024 with BLK_K=128 was skipped (VGPR budget exceeded for the 64×128 A-tile × 4 wavefronts). Other findings:
The dominant bottleneck (SM underutilization, 42% at M=1024) can only be addressed by larger M or batching sequences. With BLK_K=128 for M≤512 the kernel is now at ~10% of FP16 SoL vs the previous 8%. Full analysis: `benchmark/gemm_all_scatter_deep_tuning_mi300x.md`.
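The iteration arithmetic behind the BLK_K claim, assuming the 2 `s_barrier` per K-iteration noted in the roofline analysis:

```python
import math

def k_loop_stats(K, BLK_K, barriers_per_iter=2):
    """Inner K-loop iteration count and resulting s_barrier count per tile."""
    iters = math.ceil(K / BLK_K)
    return iters, iters * barriers_per_iter

print(k_loop_stats(14336, 64))    # (224, 448): the counts cited above
print(k_loop_stats(14336, 128))   # (112, 224): BLK_K=128 halves both
```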
@copilot ok, upped your time limit on the GPUs. I want 1000 data points on the roofline graph. X-axis should be MNK, Y-axis should be TFLOPS. I want one single scatter plot per world size showing results for the various values of tile sizes, CU counts, and other kernel parameters. Each combination of parameter configuration should be colored differently. I know this will take time but we have all the time.
…onfigs, world_size=8 Co-authored-by: mawad-amd <[email protected]>
Done. Sweep axes: 40 kernel configurations × 30 problem sizes (1200 points; the full design is detailed below).
Chart: `benchmark/gemm_all_scatter_roofline_1000pt_mi300x.png`.
New findings from the sweep and full results are in `benchmark/gemm_all_scatter_roofline_1000pt_mi300x.md`.
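The sweep dimensions multiply out to the point count as follows (all factors taken from the sweep design described in the results file):

```python
tile_pipeline_configs = 5     # tile + pipeline variants
num_warps_options = 2
mfma_options = 2
sms_mode_options = 2          # full vs tiles
m_values = 6                  # M in {32, 64, 128, 256, 512, 1024}
nk_shapes = 5

configs = tile_pipeline_configs * num_warps_options * mfma_options * sms_mode_options
points = configs * m_values * nk_shapes
print(configs, points)        # 40 configurations, 1200 total points
```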




Adds infrastructure to run GEMM+AllScatter benchmarks across a range of problem sizes and tiling configurations, and includes benchmark results collected on 8x AMD MI300X GPUs across multiple LLM-representative problem shapes. Also adds iris vectorization hints to the kernel with assembly and performance analysis, a full speed-of-light (roofline) study, kernel-level communication-compute overlap optimizations, a comprehensive 60-point roofline sweep across many (M, N, K) shapes, a strong/weak scaling analysis comparing GEMM-only (rocBLAS) against the fused GEMM+AllScatter kernel, a deep GEMM utilization tuning sweep exposing low-level knobs (`BLK_K`, `num_warps`, `mfma`, `num_sms`) that were previously hardcoded, and a 1200-point roofline scatter plot sweeping 40 unique kernel configurations across 30 problem sizes.

New files
- `dataset/gemm_all_scatter.json` — sweep configurations sweeping M from 1→1024 with K=14336, N=4096 (typical LLM FF layer dims, divisible by 8 GPUs)
- `benchmark/examples/benchmark_gemm_all_scatter.py` — sweep harness using the kernel from `examples/23_gemm_all_scatter_tracing/`; mirrors the existing `benchmark_all_gather_gemm_pull.py` structure
- `benchmark/examples/benchmark_gemm_all_scatter_tiling_sweep.py` — tiling parameter sweep script exploring `(BLK_M, BLK_N, BLK_K)`, `num_stages`, and `gsize_m`; generates TFLOPS charts via matplotlib
- `benchmark/examples/benchmark_gemm_all_scatter_deep_tuning.py` — deep GEMM utilization sweep script exploring `num_warps`, `mfma` (16 vs 32), `BLK_K` (64 vs 128), `num_stages`, and `num_sms` mode (full vs tiles) across M ∈ {256,512,1024}; generates 4-panel analysis charts
- `benchmark/examples/benchmark_gemm_all_scatter_1000pt_roofline.py` — comprehensive 1200-point roofline scatter sweep script: 40 unique kernel configurations × 30 problem sizes (M ∈ {32,64,128,256,512,1024} × 5 (N,K) shapes); generates a single scatter plot per world size with X-axis = M×N×K, Y-axis = TFLOPS, each configuration uniquely colored
- `benchmark/gemm_all_scatter_mi300x.md` — baseline benchmark results on 8x AMD MI300X (304 CUs each), fp16, sweeping M from 1→1024
- `benchmark/gemm_all_scatter_tiling_sweep_mi300x.md` — single-shape tiling sweep results (N=4096, K=14336) with TFLOPS tables and analysis
- `benchmark/gemm_all_scatter_tiling_sweep_mi300x.png` — 3-panel TFLOPS chart: tile size sweep, pipeline stages sweep, and group-size-M sweep
- `benchmark/gemm_all_scatter_multishape_analysis_mi300x.md` — multi-shape tiling analysis across 4 LLM shapes with optimal tile selection guide
- `benchmark/gemm_all_scatter_multishape_guidance_mi300x.png` — 4-panel TFLOPS chart per shape + optimal tile heatmap + total_tiles scatter plot
- `benchmark/gemm_all_scatter_tile_heatmap_mi300x.png` — standalone heatmap of best tile config per shape × M
- `benchmark/gemm_all_scatter_hints_analysis_mi300x.md` — assembly diff and performance comparison before/after adding iris vectorization hints
- `benchmark/gemm_all_scatter_roofline_mi300x.png` — 2-panel roofline chart: per-GPU roofline (HBM bandwidth + compute peaks, SM-saturation ceiling, measured points) and performance hierarchy bars at each M value
- `benchmark/gemm_all_scatter_roofline_mi300x.md` — speed-of-light analysis: MI300X hardware limits, per-config arithmetic intensity, SoL vs measured efficiency, inner K-loop instruction-mix breakdown, and optimization experiments (kpack=2, waves_per_eu=2) with conclusions
- `benchmark/gemm_all_scatter_optimization_mi300x.md` — kernel comm-compute overlap optimization analysis: register-scatter (`ctx.store` vs `ctx.put`) assembly impact, `num_stages=3` LDS budget analysis and performance results, and recommended configuration matrix by M range
- `benchmark/gemm_all_scatter_comprehensive_roofline_mi300x.png` — 4-panel comprehensive roofline chart with x-axis = 2·M·N·K: (A) classical roofline (arithmetic intensity vs TFLOPS with hardware ceiling), (B) performance vs total GEMM FLOPs, (C) compute efficiency (% of 8-GPU FP16 peak) vs FLOPs, (D) speedup of optimized configs vs baseline
- `benchmark/gemm_all_scatter_comprehensive_roofline_mi300x.md` — 60-point comprehensive roofline analysis sweeping M ∈ {64,128,256,512,1024} × (N,K) ∈ {(4096,4096),(4096,14336),(8192,4096),(8192,28672)} × 3 configs vs the 10,459 TFLOPS hardware limit, with detailed per-shape tables and efficiency analysis
- `benchmark/gemm_all_scatter_scaling_mi300x.png` — 4-panel scaling analysis chart: (A) per-GPU roofline overlaying GEMM-only and GEMM+scatter points with HBM and XGMI ceilings, (B) total TFLOPS vs 2·M·N·K with XGMI scatter ceiling lines, (C) communication overhead ratio (fused / GEMM-only latency), (D) per-GPU compute efficiency vs FP16 tensor peak for both configs
- `benchmark/gemm_all_scatter_scaling_mi300x.md` — 80-point (40 GEMM-only + 40 GEMM+scatter) strong/weak scaling analysis: identifies that XGMI bandwidth is NOT the bottleneck (7–8% of ceiling), that the 3.5–4.3× overhead gap is from SM underutilization + Triton vs rocBLAS GEMM efficiency, and strong/weak scaling behavior with recommended optimization paths
- `benchmark/gemm_all_scatter_deep_tuning_mi300x.png` — 4-panel deep tuning chart: (A) BLK_K/stages impact, (B) num_warps impact, (C) mfma instruction size impact, (D) num_sms mode impact
- `benchmark/gemm_all_scatter_deep_tuning_mi300x.md` — 62-point deep GEMM utilization analysis: BLK_K=128 halves LDS barriers for +20% at M=256 and +8% at M=512; mfma=32 with num_warps=4 optimal for small M; updated recommended configuration matrix by M range
- `benchmark/gemm_all_scatter_roofline_1000pt_mi300x.png` — 1200-point roofline scatter chart: X-axis = M×N×K (log scale), Y-axis = TFLOPS (log scale), one point per (M, N, K, kernel config) combination, 40 uniquely colored configurations, hardware ceiling lines for FP16 peak and SM-utilization ceiling
- `benchmark/gemm_all_scatter_roofline_1000pt_mi300x.md` — 1200-point sweep analysis: sweep design, best configuration per (M, N, K) shape, key findings (BLK_K=128 wins at M≤256; BLK=(256,64,64) wins at M≥512), and performance spread across configurations

Modified files
- `examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py` — added `hint=(1, BLOCK_SIZE_N)` to scatter stores; replaced `ctx.put(C_ptr, ...)` with `ctx.store(c_global_ptr, c, ...)` to scatter accumulator values directly from registers to remote ranks, eliminating the intermediate HBM read-back
- `examples/23_gemm_all_scatter_tracing/matmul_wrapper.py` — exposed four previously hardcoded GEMM knobs as optional keyword arguments with backward-compatible defaults: `num_warps` (default: 8), `mfma` (default: 16), `kpack` (default: 1), `waves_per_eu` (default: 0)

Usage
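The exact usage commands are not reproduced here; below is a hypothetical sketch (not the actual `matmul_wrapper.py` signature) of the backward-compatible keyword-knob pattern the modified wrapper follows: callers that pass nothing get the previously hardcoded behavior, and tuning sweeps override only the knobs they vary.

```python
def matmul(a, b, *, num_warps=8, mfma=16, kpack=1, waves_per_eu=0):
    """Hypothetical stand-in: the defaults mirror the previously hardcoded
    values, so existing call sites keep their old behavior unchanged."""
    # A real wrapper would launch the Triton kernel; here we just surface
    # the resolved knob values to illustrate the pattern.
    return {"num_warps": num_warps, "mfma": mfma,
            "kpack": kpack, "waves_per_eu": waves_per_eu}

print(matmul(None, None))                   # defaults = old hardcoded values
print(matmul(None, None, mfma=32, kpack=2)) # a sweep overrides selectively
```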
Benchmark Results (MI300X, 8 GPUs, fp16)
Baseline M sweep (N=4096, K=14336, BLK_M=256, BLK_N=64, BLK_K=64, gsize_m=6, stages=2)
1200-point roofline scatter sweep
Sweep design: 40 kernel configurations (5 tile+pipeline configs × 2 `num_warps` × 2 `mfma` × 2 `sms_mode`) × 30 problem sizes (M ∈ {32,64,128,256,512,1024} × 5 (N,K) shapes). The chart shows X = M×N×K (log), Y = TFLOPS (log), each config uniquely colored.
New findings from the 1200-point sweep:
- `BLK_K=128` wins at M≤256 across all shapes — halving K-loop iterations cuts `s_barrier` overhead by 50%
- `BLK=(256,64,64)` wins at M≥512 — the larger tile's SRAM reuse benefit exceeds the SM-utilization advantage of smaller tiles at high M
- `mfma=32` helps at small M with BLK_K=128 by encoding 4× more MACs per instruction
- `sms=full` (304 CUs) is generally preferred at M≥128; `sms=tiles` helps only at very small M

Tiling parameter sweep (N=4096, K=14336)
Key findings:
- `(BLK_M=64, BLK_N=64, BLK_K=64)` achieves 349 TFLOPS at M=1024 — a 2× improvement over the default `(256,64,64)` config (177 TFLOPS). Smaller tiles keep all 304 SMs fully utilized.
- `stages=2` outperforms `stages=1` via software pipelining of global memory loads.
- `stages=3` exceeds the MI300X 64 KB LDS limit for BLK_M≥256.

Multi-shape analysis (80 configs across 4 LLM shapes)
Shapes tested: (N=4096,K=4096), (N=4096,K=14336), (N=8192,K=4096), (N=8192,K=28672) × M ∈ {64,128,256,512,1024} × 4 tile configs.
How to determine the optimal tile for any problem shape
The key metric is `total_tiles = ⌈M/BLK_M⌉ × ⌈N_local/BLK_N⌉` where `N_local = N/world_size`: `(64,64,64)` wins in 14 of 16 shape×M combinations because for typical LLM shapes `total_tiles` stays well below the 304-SM threshold. Full tables and guidance are in `benchmark/gemm_all_scatter_multishape_analysis_mi300x.md`.

Iris vectorization hints — assembly and performance impact
Added `hint=(1, BLOCK_SIZE_N)` to the iris scatter operations in `examples/23_gemm_all_scatter_tracing/gemm_all_scatter.py`. The hint informs the Triton backend that the N-dimension of the scatter destination pointer is contiguous and aligned to `BLOCK_SIZE_N` elements.

Assembly impact (BLK_M=BLK_N=BLK_K=64, gfx942):
- Before: `global_store_short` (2-byte scalar), `global_store_short_d16_hi` (2-byte scalar)
- After: `global_store_dwordx4` (16-byte vector)

Assembly size: 2014 → 1151 lines (−43%). All scalar fp16 scatter stores are replaced by 128-bit vectorized `global_store_dwordx4` instructions (8 fp16 elements per instruction).

Performance (8×MI300X, fp16, stages=2, BLK=64,64,64):
The recommended `(64,64,64)` tile config gains 6–10% end-to-end from this one-line change. Full assembly snippets and analysis are in `benchmark/gemm_all_scatter_hints_analysis_mi300x.md`.

Speed-of-light (roofline) analysis
A full roofline study was performed against MI300X hardware limits (1307.4 TFLOPS/GPU FP16 tensor, 5.3 TB/s HBM3, 3.15 TB/s aggregate XGMI). The arithmetic intensity for this workload (including A/B/C HBM traffic and XGMI scatter) ranges from 96–281 FLOPs/byte across M=128..1024, placing the kernel above the roofline ridge point (246.7 FLOPs/byte) at M≥512, making it theoretically compute-bound.
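The quoted ridge point follows directly from the two hardware limits:

```python
PEAK_FLOPS = 1307.4e12   # MI300X FP16 tensor peak per GPU, FLOPs/s
HBM_BW = 5.3e12          # HBM3 bandwidth, bytes/s

# Arithmetic intensity at which the bandwidth roof meets the compute roof
ridge = PEAK_FLOPS / HBM_BW
print(round(ridge, 1))   # ≈ 246.7 FLOPs/byte, the ridge point quoted above
```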
Speed-of-light vs measured (8 GPUs total):
The ~12× gap from SoL is explained by four compounding factors identified via assembly inspection: SM under-utilization (only 42% of 304 SMs active at M=1024), MFMA latency chains (4 sequential MFMAs per accumulator × 32-cycle latency), LDS barrier overhead (2 `s_barrier` per K-iteration × 224 iterations = 448 barriers per tile), and scatter heap-base loads (10 `global_load_dwordx2` per tile for iris symmetric-heap pointer translation).
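For context on those heap-base loads: symmetric-heap translation generally rewrites a local address into a peer's address space by keeping the heap offset and swapping the per-rank base. A minimal sketch of the general technique (not Iris's actual implementation; the base values are hypothetical):

```python
def translate(local_ptr, heap_bases, my_rank, peer_rank):
    """Generic symmetric-heap translation: same offset, peer's heap base.
    Sketch of the technique only, not the iris runtime's real code."""
    offset = local_ptr - heap_bases[my_rank]   # position within my heap
    return heap_bases[peer_rank] + offset      # same position in peer's heap

bases = [0x1000, 0x9000]                       # hypothetical base per rank
print(hex(translate(0x1400, bases, 0, 1)))     # 0x9400: offset 0x400 on rank 1
```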