Input:
!lhs = f4E2M1FN
!rhs = f4E2M1FN
!scale_ty = f8E8M0FNU

!A = tensor<1024x128x32xi8>
!A_i4 = tensor<1024x128x32xi4>
!B = tensor<1024x128x32xi8>
!B_i4 = tensor<1024x128x32xi4>
!A_fp4 = tensor<1024x128x32x!lhs>
!B_fp4 = tensor<1024x128x32x!rhs>
!A_scales = tensor<1024x128x!scale_ty>
!B_scales = tensor<1024x128x!scale_ty>
!A_s = tensor<1024x128xi8>
!B_s = tensor<1024x128xi8>
!C = tensor<1024x1024xf32>

#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>

func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
  %A2 = arith.trunci %lhs : !A to !A_i4
  %B2 = arith.trunci %rhs : !B to !B_i4
  %A_scales = arith.bitcast %lhs_scales : !A_s to !A_scales
  %B_scales = arith.bitcast %rhs_scales : !B_s to !B_scales
  %A = arith.bitcast %A2 : !A_i4 to !A_fp4
  %B = arith.bitcast %B2 : !B_i4 to !B_fp4
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : !C
  %C = linalg.fill ins(%cst : f32) outs(%empty : !C) -> !C
  %D = linalg.generic {
      indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
      iterator_types = ["parallel", "parallel", "reduction", "reduction"]
    } ins(%A, %B, %A_scales, %B_scales : !A_fp4, !B_fp4, !A_scales, !B_scales) outs(%C : !C) {
    ^bb0(%a: !lhs, %b: !rhs, %a_scale: !scale_ty, %b_scale: !scale_ty, %out: f32):
      %1 = arith.scaling_extf %a, %a_scale : !lhs, !scale_ty to f32
      %2 = arith.scaling_extf %b, %b_scale : !rhs, !scale_ty to f32
      %3 = arith.mulf %1, %2 : f32
      %4 = arith.addf %out, %3 : f32
      linalg.yield %4 : f32
  } -> !C
  return %D : !C
}
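
For context, the generic op is a block-scaled matmul: every 32-element K-block of the fp4 operands shares one f8E8M0FNU scale, and arith.scaling_extf extends an element to f32 and multiplies it by that block's scale. A minimal NumPy sketch of the intended math (plain f32 arrays stand in for the fp4/f8 storage, so the truncation/bitcast decoding above is assumed away):

  import numpy as np

  M, N, Ko, Kb = 1024, 1024, 128, 32

  # Stand-ins for the decoded fp4 values and the per-block f8E8M0FNU scales.
  A = np.random.rand(M, Ko, Kb).astype(np.float32)
  B = np.random.rand(N, Ko, Kb).astype(np.float32)
  A_scales = np.random.rand(M, Ko).astype(np.float32)
  B_scales = np.random.rand(N, Ko).astype(np.float32)

  # arith.scaling_extf: extend each element and multiply by the scale of its K-block.
  A_f32 = A * A_scales[:, :, None]
  B_f32 = B * B_scales[:, :, None]

  # The linalg.generic: contract over both reduction dims (Ko and Kb).
  C = np.einsum("mkb,nkb->mn", A_f32, B_f32)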
Repro:
iree-compile <input.mlir> --iree-hal-target-device=hip --iree-hip-target=gfx950
Error:
input.mlir:47:8: error: function 'scaled_matmul_dispatch_4_reduction_1024x1024x128x32_f4E2M1FNxf4E2M1FNxf8E8M0FNUxf8E8M0FNUxf32' uses 417792 bytes of shared memory; exceeded the limit of 163840 bytes
%D = linalg.generic {
^
input.mlir:33:1: note: called from
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
^
input.mlir:47:8: error: failed to run translation of source executable to target executable for backend #hal.executable.target<"rocm", "rocm-hsaco-fb", {abi = "hip", iree_codegen.target_info = #iree_gpu.target<arch = "gfx950", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [...], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 163840, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>, ukernels = "none"}>
%D = linalg.generic {
^
input.mlir:33:1: note: called from
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
^
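For a rough sense of scale only: the 417792-byte figure comes from whatever configuration the compiler actually selected, but the back-of-envelope below (with purely hypothetical tile sizes, not IREE's real choice) shows how staging the full 128x32 reduction per workgroup tile blows past the 163840-byte budget:

  # Back-of-envelope LDS estimate; tile sizes are hypothetical, for illustration only.
  tile_m, tile_n = 128, 128
  tile_k = 128 * 32            # staging the whole Ko x Kb reduction at once

  fp4_bytes = 0.5              # f4E2M1FN packs two elements per byte
  scale_bytes = 1              # one f8E8M0FNU scale per 32-element block

  lhs = tile_m * tile_k * fp4_bytes
  rhs = tile_n * tile_k * fp4_bytes
  scales = (tile_m + tile_n) * (tile_k // 32) * scale_bytes

  print(int(lhs + rhs + scales), "bytes vs the 163840-byte limit")  # 557056 vs 163840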