-
Notifications
You must be signed in to change notification settings - Fork 759
Open
Description
Input (MLIR):
// Element types: both matmul operands use the 4-bit float type f4E2M1FN;
// the per-block scales use the 8-bit float type f8E8M0FNU.
!lhs = f4E2M1FN
!rhs = f4E2M1FN
!scale_ty = f8E8M0FNU
// Data operands as passed in: i8 storage tensors, plus the i4 tensors they are
// truncated to (arith.trunci) inside @scaled_matmul.
// Dimension order follows #lhs_map / #rhs_map below: (M or N, Ko, Kb) =
// 1024 rows x 512 K-blocks x 32 elements per K-block.
!A = tensor<1024x512x32xi8>
!A_i4 = tensor<1024x512x32xi4>
!B = tensor<1024x512x32xi8>
!B_i4 = tensor<1024x512x32xi4>
// fp4 views of the data, produced by bitcasting the i4 tensors.
!A_fp4 = tensor<1024x512x32x!lhs>
!B_fp4 = tensor<1024x512x32x!rhs>
// One scale per (row, K-block): indexed by #scale_m / #scale_n below, so each
// scale value applies to all 32 elements of its K-block.
// NOTE(review): this looks like an MX-style block-scaled layout (block size 32) — confirm intent.
!A_scales = tensor<1024x512x!scale_ty>
!B_scales = tensor<1024x512x!scale_ty>
// i8 storage for the scales; bitcast to !A_scales / !B_scales inside @scaled_matmul.
!A_s = tensor<1024x512xi8>
!B_s = tensor<1024x512xi8>
// Result: M x N = 1024 x 1024, accumulated in f32.
!C = tensor<1024x1024xf32>
// Iteration space is (M, N, Ko, Kb). M and N are parallel; Ko and Kb are both
// reduction dimensions (see iterator_types in @scaled_matmul).
// LHS is indexed (M, Ko, Kb) and RHS is indexed (N, Ko, Kb). The RHS therefore
// stores K innermost for each output column, i.e. the product is A * B^T in 2-D terms.
#lhs_map = affine_map<(M, N, Ko, Kb) -> (M, Ko, Kb)>
#rhs_map = affine_map<(M, N, Ko, Kb) -> (N, Ko, Kb)>
// Scales ignore Kb: one scale per (row, Ko) block.
#scale_m = affine_map<(M, N, Ko, Kb) -> (M, Ko)>
#scale_n = affine_map<(M, N, Ko, Kb) -> (N, Ko)>
#out_map = affine_map<(M, N, Ko, Kb) -> (M, N)>
// Block-scaled fp4 matmul reproducer.
// Computes, for every (M, N):
//   D[M, N] = sum over (Ko, Kb) of
//     scaling_extf(A[M, Ko, Kb], A_scales[M, Ko]) * scaling_extf(B[N, Ko, Kb], B_scales[N, Ko])
// starting from a zero-initialized f32 accumulator.
// All operands arrive as integer storage and are reinterpreted in place:
//   data:   i8 --arith.trunci--> i4 --arith.bitcast--> f4E2M1FN
//   scales: i8 --arith.bitcast--> f8E8M0FNU
// Keep the op sequence as-is: this function is the compiler reproducer.
func.func @scaled_matmul(%lhs : !A, %lhs_scales : !A_s, %rhs : !B, %rhs_scales : !B_s) -> !C {
// Narrow each i8 data element to i4; only 4 bits of each element are kept.
%A2 = arith.trunci %lhs : !A to !A_i4
%B2 = arith.trunci %rhs : !B to !B_i4
// Reinterpret the i8 scale bits as f8E8M0FNU (same bit width, no value conversion).
%A_scales = arith.bitcast %lhs_scales : !A_s to !A_scales
%B_scales = arith.bitcast %rhs_scales : !B_s to !B_scales
// Reinterpret the i4 data bits as f4E2M1FN.
%A = arith.bitcast %A2 : !A_i4 to !A_fp4
%B = arith.bitcast %B2 : !B_i4 to !B_fp4
// Zero-filled f32 output tensor that serves as the reduction accumulator.
%cst = arith.constant 0.000000e+00 : f32
%empty = tensor.empty() : !C
%C = linalg.fill ins(%cst : f32) outs(%empty : !C) -> !C
%D = linalg.generic {
// Iteration dims (M, N, Ko, Kb): M and N parallel, Ko and Kb both reduced.
// The scale operands depend on Ko but not on Kb.
indexing_maps = [#lhs_map, #rhs_map, #scale_m, #scale_n, #out_map],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]
} ins(%A, %B, %A_scales, %B_scales : !A_fp4, !B_fp4, !A_scales, !B_scales) outs(%C : !C) {
^bb0(%a: !lhs, %b: !rhs, %a_scale: !scale_ty, %b_scale: !scale_ty, %out: f32):
// Extend each fp4 element to f32 with its block scale applied, then multiply-accumulate in f32.
%1 = arith.scaling_extf %a, %a_scale : !lhs, !scale_ty to f32
%2 = arith.scaling_extf %b, %b_scale : !rhs, !scale_ty to f32
%3 = arith.mulf %1, %2 : f32
%4 = arith.addf %out, %3 : f32
linalg.yield %4 : f32
} -> !C
return %D : !C
}
Reproduction:
iree-compile <input file> --iree-hal-target-device=hip --iree-hip-target=gfx950
Metadata
Metadata
Assignees
Labels
No labels