Scaled MFMA failing at GPUCheckResourceUsagePass #22018

Description

@Muzammiluddin-Syed-ECE

Input:

!lhs = f4E2M1FN
!rhs = f4E2M1FN
!scale_ty = f8E8M0FNU

!A = tensor<1024x128x32xi8>
!A_i4 = tensor<1024x128x32xi4>
!B = tensor<1024x128x32xi8>
!B_i4 = tensor<1024x128x32xi4>
!A_fp4 = tensor<1024x128x32x!lhs>
!B_fp4 = tensor<1024x128x32x!rhs>

!A_scales = tensor<1024x128x!scale_ty>
!B_scales = tensor<1024x128x!scale_ty>
!A_s = tensor<1024x128xi8>
!B_s = tensor<1024x128xi8>

!C = tensor<1024x1024xf32>

#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
  %A2 = arith.trunci %lhs : !A to !A_i4
  %B2 = arith.trunci %rhs : !B to !B_i4
  %A_scales = arith.bitcast %lhs_scales : !A_s to !A_scales
  %B_scales = arith.bitcast %rhs_scales : !B_s to !B_scales
  %A = arith.bitcast %A2 : !A_i4 to !A_fp4
  %B = arith.bitcast %B2 : !B_i4 to !B_fp4
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : !C
  %C = linalg.fill ins(%cst : f32) outs(%empty : !C) -> !C
  %D = linalg.generic {
    indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
    iterator_types = ["parallel", "parallel", "reduction", "reduction"]
  } ins(%A, %B, %A_scales, %B_scales : !A_fp4, !B_fp4, !A_scales, !B_scales) outs(%C : !C) {
  ^bb0(%a: !lhs, %b: !rhs, %a_scale: !scale_ty, %b_scale: !scale_ty, %out: f32):
    %1 = arith.scaling_extf %a, %a_scale : !lhs, !scale_ty to f32
    %2 = arith.scaling_extf %b, %b_scale : !rhs, !scale_ty to f32
    %3 = arith.mulf %1, %2 : f32
    %4 = arith.addf %out, %3 : f32
    linalg.yield %4 : f32
  } -> !C
  return %D : !C
}

Repro:

iree-compile <input.mlir> --iree-hal-target-device=hip --iree-hip-target=gfx950

Error:

input.mlir:47:8: error: function 'scaled_matmul_dispatch_4_reduction_1024x1024x128x32_f4E2M1FNxf4E2M1FNxf8E8M0FNUxf8E8M0FNUxf32' uses 417792 bytes of shared memory; exceeded the limit of 163840 bytes
  %D = linalg.generic {
       ^
input.mlir:33:1: note: called from
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
^
input.mlir:47:8: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.target_info = #iree_gpu.target<arch = "gfx950", features = "", wgp = <compute =  fp64|fp32|fp16|int64|int32|int16|int8, storage =  b64|b32|b16|b8, subgroup =  shuffle|arithmetic, dot =  dp4xi8toi32, mma = [...], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 163840, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
  %D = linalg.generic {
       ^
input.mlir:33:1: note: called from
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
^
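For context, a minimal arithmetic check using only the figures from the diagnostic above and the target's max_workgroup_memory_bytes shows the default tiling requests roughly 2.55x the available workgroup memory on gfx950:

# Figures taken directly from the error message and the #iree_gpu.target info.
requested_bytes = 417792              # shared memory reported by GPUCheckResourceUsagePass
limit_bytes = 163840                  # max_workgroup_memory_bytes for gfx950 (160 KiB)

print(requested_bytes / 1024)         # 408.0 KiB requested
print(limit_bytes / 1024)             # 160.0 KiB available
print(requested_bytes / limit_bytes)  # ~2.55x over the limit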
