Thanks for visiting codestin.com
Credit goes to github.com

Skip to content

Commit 3bcc2d3

Browse files
authored
Add dynamic-shapes variant of the quantized_matmul-vs-matmul test (iree-org#8473)
1 parent 9a50d72 commit 3bcc2d3

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

‎iree/test/e2e/regression/linalg_quantized_matmul_vs_linalg_matmul.mlir‎

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,78 @@ func private @quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : ten
7575
return %quantized_matmul_from_matmul_result : tensor<3x5xi32>
7676
}
7777

78+
// Equivalent to linalg.quantized_matmul, but not using linalg.quantized_matmul
79+
func private @quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) -> tensor<?x?xi32> {
// Computes the same result as linalg.quantized_matmul on dynamically-shaped
// operands, decomposed into a plain linalg.matmul plus a correction term.
// It relies on the identity:
//   sum_k (lhs[m,k] - lhs_zp) * (rhs[k,n] - rhs_zp)
//     = matmul[m,n] - rhs_zp * lhs_sums[m] - lhs_zp * rhs_sums[n]
//       + k_size * lhs_zp * rhs_zp
// where lhs_sums[m] = sum_k lhs[m,k] and rhs_sums[n] = sum_k rhs[k,n].
80+
// compute the matmul itself, which would be the end result already in the case
81+
// where both zero-point values %lhs_zp and %rhs_zp are zero.
82+
%matmul_result = linalg.matmul ins(%lhs, %rhs : tensor<?x?xi8>, tensor<?x?xi8>) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
83+
84+
// Query the dynamic M, K, N extents. K also participates arithmetically in
// the quadratic correction term below, so it is cast to i32.
%c_0_index = arith.constant 0 : index
85+
%c_1_index = arith.constant 1 : index
86+
%m_size = tensor.dim %lhs, %c_0_index : tensor<?x?xi8>
87+
%k_size = tensor.dim %lhs, %c_1_index : tensor<?x?xi8>
88+
%n_size = tensor.dim %rhs, %c_1_index : tensor<?x?xi8>
89+
%k_size_i32 = arith.index_cast %k_size : index to i32
90+
91+
// Zero value used to initialize both reduction accumulators.
%c_0 = arith.constant 0 : i32
92+
93+
// compute the sums along rows of %lhs.
94+
// Widen i8 -> i32 first so the reductions accumulate without overflow.
%lhs_i32 = arith.extsi %lhs : tensor<?x?xi8> to tensor<?x?xi32>
95+
%init_lhs_sums_uninitialized = linalg.init_tensor [%m_size] : tensor<?xi32>
96+
%zero_lhs_sums = linalg.fill(%c_0, %init_lhs_sums_uninitialized) : i32, tensor<?xi32> -> tensor<?xi32>
97+
// Reduction over d1 (the K dimension): lhs_sums[d0] = sum_{d1} lhs[d0, d1].
%lhs_sums = linalg.generic {
98+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
99+
affine_map<(d0, d1) -> (d0)>],
100+
iterator_types = ["parallel", "reduction"]}
101+
ins(%lhs_i32 : tensor<?x?xi32>)
102+
outs(%zero_lhs_sums : tensor<?xi32>) {
103+
^bb0(%arg0: i32, %arg1: i32) :
104+
%1 = arith.addi %arg0, %arg1 : i32
105+
linalg.yield %1 : i32
106+
} -> tensor<?xi32>
107+
108+
// compute the sums along columns of %rhs.
109+
%rhs_i32 = arith.extsi %rhs : tensor<?x?xi8> to tensor<?x?xi32>
110+
%init_rhs_sums_uninitialized = linalg.init_tensor [%n_size] : tensor<?xi32>
111+
%zero_rhs_sums = linalg.fill(%c_0, %init_rhs_sums_uninitialized) : i32, tensor<?xi32> -> tensor<?xi32>
112+
// Reduction over d0 (the K dimension): rhs_sums[d1] = sum_{d0} rhs[d0, d1].
%rhs_sums = linalg.generic {
113+
indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
114+
affine_map<(d0, d1) -> (d1)>],
115+
iterator_types = ["reduction", "parallel"]}
116+
ins(%rhs_i32 : tensor<?x?xi32>)
117+
outs(%zero_rhs_sums : tensor<?xi32>) {
118+
^bb0(%arg0: i32, %arg1: i32) :
119+
%1 = arith.addi %arg0, %arg1 : i32
120+
linalg.yield %1 : i32
121+
} -> tensor<?xi32>
122+
123+
// add all the terms together.
124+
%init_acc_uninitialized = linalg.init_tensor [%m_size, %n_size] : tensor<?x?xi32>
125+
// Elementwise (parallel, parallel) map applying the zero-point correction:
//   out[d0, d1] = matmul[d0, d1]
//                 - (lhs_sums[d0] * rhs_zp + rhs_sums[d1] * lhs_zp)
//                 + k_size * lhs_zp * rhs_zp
// The scalar operands (%lhs_zp, %rhs_zp, %k_size_i32) are broadcast via
// empty affine maps `() -> ()` over the whole iteration space.
%quantized_matmul_from_matmul_result = linalg.generic {
126+
indexing_maps = [
127+
affine_map<(d0, d1) -> (d0, d1)>,
128+
affine_map<(d0, d1) -> (d0)>,
129+
affine_map<(d0, d1) -> (d1)>,
130+
affine_map<(d0, d1) -> ()>,
131+
affine_map<(d0, d1) -> ()>,
132+
affine_map<(d0, d1) -> ()>,
133+
affine_map<(d0, d1) -> (d0, d1)>],
134+
iterator_types = ["parallel", "parallel"]}
135+
ins(%matmul_result, %lhs_sums, %rhs_sums, %lhs_zp, %rhs_zp, %k_size_i32 : tensor<?x?xi32>, tensor<?xi32>, tensor<?xi32>, i32, i32, i32)
136+
outs(%init_acc_uninitialized : tensor<?x?xi32>) {
137+
^bb0(%matmul_result_val : i32, %lhs_sums_val: i32, %rhs_sums_val: i32, %lhs_zp_val: i32, %rhs_zp_val: i32, %k : i32, %acc_val: i32) :
138+
%linear_term_in_rhs_zp = arith.muli %lhs_sums_val, %rhs_zp_val : i32
139+
%linear_term_in_lhs_zp = arith.muli %rhs_sums_val, %lhs_zp_val : i32
140+
%linear_term = arith.addi %linear_term_in_rhs_zp, %linear_term_in_lhs_zp : i32
141+
%product_of_zp = arith.muli %lhs_zp_val, %rhs_zp_val : i32
142+
%quadratic_term = arith.muli %k, %product_of_zp : i32
143+
%corrected_for_linear_term = arith.subi %matmul_result_val, %linear_term : i32
144+
%corrected = arith.addi %corrected_for_linear_term, %quadratic_term : i32
145+
linalg.yield %corrected : i32
146+
} -> tensor<?x?xi32>
147+
return %quantized_matmul_from_matmul_result : tensor<?x?xi32>
148+
}
149+
78150
// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_3x4x5
79151
func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : tensor<4x5xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<3x5xi32>) {
80152
%result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<3x4xi8>, tensor<4x5xi8>, i32, i32) outs(%acc : tensor<3x5xi32>) -> tensor<3x5xi32>
@@ -83,7 +155,15 @@ func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>,
83155
return
84156
}
85157

86-
func @test_quantized_matmul_as_matmul_3x4x5() {
158+
// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_dynamic
159+
func private @check_one_quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) {
  // Reference result computed by the linalg.quantized_matmul named op.
  %reference_result = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
  // Result from the decomposed matmul-based implementation under test.
  %decomposed_result = call @quantized_matmul_as_matmul_dynamic(%lhs, %rhs, %lhs_zp, %rhs_zp, %acc) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32, tensor<?x?xi32>) -> tensor<?x?xi32>
  // The two results must agree elementwise.
  check.expect_eq(%reference_result, %decomposed_result) : tensor<?x?xi32>
  return
}
165+
166+
func @test_quantized_matmul_as_matmul() {
87167
%lhs_3x4_1 = util.unfoldable_constant dense<[
88168
[1, 2, 3, 4],
89169
[5, 6, 7, 8],
@@ -122,5 +202,11 @@ func @test_quantized_matmul_as_matmul_3x4x5() {
122202
call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_1, %rhs_4x5_1, %c_minus2, %c_plus3, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
123203
call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_plus41, %c_minus57, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
124204
call @check_one_quantized_matmul_as_matmul_3x4x5(%lhs_3x4_2, %rhs_4x5_2, %c_minus128, %c_plus127, %zero_acc) : (tensor<3x4xi8>, tensor<4x5xi8>, i32, i32, tensor<3x5xi32>) -> ()
205+
206+
%lhs_3x4_dynamic = tensor.cast %lhs_3x4_1 : tensor<3x4xi8> to tensor<?x?xi8>
207+
%rhs_4x5_dynamic = tensor.cast %rhs_4x5_1 : tensor<4x5xi8> to tensor<?x?xi8>
208+
%zero_acc_dynamic = tensor.cast %zero_acc : tensor<3x5xi32> to tensor<?x?xi32>
209+
call @check_one_quantized_matmul_as_matmul_dynamic(%lhs_3x4_dynamic, %rhs_4x5_dynamic, %c_minus128, %c_plus127, %zero_acc_dynamic) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32, tensor<?x?xi32>) -> ()
210+
125211
return
126212
}

0 commit comments

Comments
 (0)