@@ -75,6 +75,78 @@ func private @quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>, %rhs : ten
75
75
return %quantized_matmul_from_matmul_result : tensor <3 x5 xi32 >
76
76
}
77
77
78
+ // Equivalent to linalg.quantized_matmul, but not using linalg.quantized_matmul
79
+ func private @quantized_matmul_as_matmul_dynamic (%lhs : tensor <?x?xi8 >, %rhs : tensor <?x?xi8 >, %lhs_zp : i32 , %rhs_zp : i32 , %acc : tensor <?x?xi32 >) -> tensor <?x?xi32 > {
80
+ // compute the matmul itself, which would be the end result already in the case
81
+ // where both zero-point values %lhs_zp and %rhs_zp are zero.
82
+ %matmul_result = linalg.matmul ins (%lhs , %rhs : tensor <?x?xi8 >, tensor <?x?xi8 >) outs (%acc : tensor <?x?xi32 >) -> tensor <?x?xi32 >
83
+
84
+ %c_0_index = arith.constant 0 : index
85
+ %c_1_index = arith.constant 1 : index
86
+ %m_size = tensor.dim %lhs , %c_0_index : tensor <?x?xi8 >
87
+ %k_size = tensor.dim %lhs , %c_1_index : tensor <?x?xi8 >
88
+ %n_size = tensor.dim %rhs , %c_1_index : tensor <?x?xi8 >
89
+ %k_size_i32 = arith.index_cast %k_size : index to i32
90
+
91
+ %c_0 = arith.constant 0 : i32
92
+
93
+ // compute the sums along rows of %lhs.
94
+ %lhs_i32 = arith.extsi %lhs : tensor <?x?xi8 > to tensor <?x?xi32 >
95
+ %init_lhs_sums_uninitialized = linalg.init_tensor [%m_size ] : tensor <?xi32 >
96
+ %zero_lhs_sums = linalg.fill (%c_0 , %init_lhs_sums_uninitialized ) : i32 , tensor <?xi32 > -> tensor <?xi32 >
97
+ %lhs_sums = linalg.generic {
98
+ indexing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>,
99
+ affine_map <(d0 , d1 ) -> (d0 )>],
100
+ iterator_types = [" parallel" , " reduction" ]}
101
+ ins (%lhs_i32 : tensor <?x?xi32 >)
102
+ outs (%zero_lhs_sums : tensor <?xi32 >) {
103
+ ^bb0 (%arg0: i32 , %arg1: i32 ) :
104
+ %1 = arith.addi %arg0 , %arg1 : i32
105
+ linalg.yield %1 : i32
106
+ } -> tensor <?xi32 >
107
+
108
+ // compute the sums along columns of %rhs.
109
+ %rhs_i32 = arith.extsi %rhs : tensor <?x?xi8 > to tensor <?x?xi32 >
110
+ %init_rhs_sums_uninitialized = linalg.init_tensor [%n_size ] : tensor <?xi32 >
111
+ %zero_rhs_sums = linalg.fill (%c_0 , %init_rhs_sums_uninitialized ) : i32 , tensor <?xi32 > -> tensor <?xi32 >
112
+ %rhs_sums = linalg.generic {
113
+ indexing_maps = [affine_map <(d0 , d1 ) -> (d0 , d1 )>,
114
+ affine_map <(d0 , d1 ) -> (d1 )>],
115
+ iterator_types = [" reduction" , " parallel" ]}
116
+ ins (%rhs_i32 : tensor <?x?xi32 >)
117
+ outs (%zero_rhs_sums : tensor <?xi32 >) {
118
+ ^bb0 (%arg0: i32 , %arg1: i32 ) :
119
+ %1 = arith.addi %arg0 , %arg1 : i32
120
+ linalg.yield %1 : i32
121
+ } -> tensor <?xi32 >
122
+
123
+ // add all the terms together.
124
+ %init_acc_uninitialized = linalg.init_tensor [%m_size , %n_size ] : tensor <?x?xi32 >
125
+ %quantized_matmul_from_matmul_result = linalg.generic {
126
+ indexing_maps = [
127
+ affine_map <(d0 , d1 ) -> (d0 , d1 )>,
128
+ affine_map <(d0 , d1 ) -> (d0 )>,
129
+ affine_map <(d0 , d1 ) -> (d1 )>,
130
+ affine_map <(d0 , d1 ) -> ()>,
131
+ affine_map <(d0 , d1 ) -> ()>,
132
+ affine_map <(d0 , d1 ) -> ()>,
133
+ affine_map <(d0 , d1 ) -> (d0 , d1 )>],
134
+ iterator_types = [" parallel" , " parallel" ]}
135
+ ins (%matmul_result , %lhs_sums , %rhs_sums , %lhs_zp , %rhs_zp , %k_size_i32 : tensor <?x?xi32 >, tensor <?xi32 >, tensor <?xi32 >, i32 , i32 , i32 )
136
+ outs (%init_acc_uninitialized : tensor <?x?xi32 >) {
137
+ ^bb0 (%matmul_result_val : i32 , %lhs_sums_val: i32 , %rhs_sums_val: i32 , %lhs_zp_val: i32 , %rhs_zp_val: i32 , %k : i32 , %acc_val: i32 ) :
138
+ %linear_term_in_rhs_zp = arith.muli %lhs_sums_val , %rhs_zp_val : i32
139
+ %linear_term_in_lhs_zp = arith.muli %rhs_sums_val , %lhs_zp_val : i32
140
+ %linear_term = arith.addi %linear_term_in_rhs_zp , %linear_term_in_lhs_zp : i32
141
+ %product_of_zp = arith.muli %lhs_zp_val , %rhs_zp_val : i32
142
+ %quadratic_term = arith.muli %k , %product_of_zp : i32
143
+ %corrected_for_linear_term = arith.subi %matmul_result_val , %linear_term : i32
144
+ %corrected = arith.addi %corrected_for_linear_term , %quadratic_term : i32
145
+ linalg.yield %corrected : i32
146
+ } -> tensor <?x?xi32 >
147
+ return %quantized_matmul_from_matmul_result : tensor <?x?xi32 >
148
+ }
149
+
78
150
// Checks that linalg.quantized_matmul agrees with @quantized_matmul_as_matmul_3x4x5
79
151
func private @check_one_quantized_matmul_as_matmul_3x4x5 (%lhs : tensor <3 x4 xi8 >, %rhs : tensor <4 x5 xi8 >, %lhs_zp : i32 , %rhs_zp : i32 , %acc : tensor <3 x5 xi32 >) {
80
152
%result_of_quantized_matmul = linalg.quantized_matmul ins (%lhs , %rhs , %lhs_zp , %rhs_zp : tensor <3 x4 xi8 >, tensor <4 x5 xi8 >, i32 , i32 ) outs (%acc : tensor <3 x5 xi32 >) -> tensor <3 x5 xi32 >
@@ -83,7 +155,15 @@ func private @check_one_quantized_matmul_as_matmul_3x4x5(%lhs : tensor<3x4xi8>,
83
155
return
84
156
}
85
157
86
- func @test_quantized_matmul_as_matmul_3x4x5 () {
158
// Checks that linalg.quantized_matmul agrees with
// @quantized_matmul_as_matmul_dynamic on the given dynamic-shaped operands.
func private @check_one_quantized_matmul_as_matmul_dynamic(%lhs : tensor<?x?xi8>, %rhs : tensor<?x?xi8>, %lhs_zp : i32, %rhs_zp : i32, %acc : tensor<?x?xi32>) {
  %result_of_quantized_matmul = linalg.quantized_matmul ins(%lhs, %rhs, %lhs_zp, %rhs_zp : tensor<?x?xi8>, tensor<?x?xi8>, i32, i32) outs(%acc : tensor<?x?xi32>) -> tensor<?x?xi32>
  %result_of_quantized_matmul_as_matmul = call @quantized_matmul_as_matmul_dynamic(%lhs, %rhs, %lhs_zp, %rhs_zp, %acc) : (tensor<?x?xi8>, tensor<?x?xi8>, i32, i32, tensor<?x?xi32>) -> tensor<?x?xi32>
  check.expect_eq(%result_of_quantized_matmul, %result_of_quantized_matmul_as_matmul) : tensor<?x?xi32>
  return
}
165
+
166
+ func @test_quantized_matmul_as_matmul () {
87
167
%lhs_3x4_1 = util.unfoldable_constant dense <[
88
168
[1 , 2 , 3 , 4 ],
89
169
[5 , 6 , 7 , 8 ],
@@ -122,5 +202,11 @@ func @test_quantized_matmul_as_matmul_3x4x5() {
122
202
call @check_one_quantized_matmul_as_matmul_3x4x5 (%lhs_3x4_1 , %rhs_4x5_1 , %c_minus2 , %c_plus3 , %zero_acc ) : (tensor <3 x4 xi8 >, tensor <4 x5 xi8 >, i32 , i32 , tensor <3 x5 xi32 >) -> ()
123
203
call @check_one_quantized_matmul_as_matmul_3x4x5 (%lhs_3x4_2 , %rhs_4x5_2 , %c_plus41 , %c_minus57 , %zero_acc ) : (tensor <3 x4 xi8 >, tensor <4 x5 xi8 >, i32 , i32 , tensor <3 x5 xi32 >) -> ()
124
204
call @check_one_quantized_matmul_as_matmul_3x4x5 (%lhs_3x4_2 , %rhs_4x5_2 , %c_minus128 , %c_plus127 , %zero_acc ) : (tensor <3 x4 xi8 >, tensor <4 x5 xi8 >, i32 , i32 , tensor <3 x5 xi32 >) -> ()
205
+
206
+ %lhs_3x4_dynamic = tensor.cast %lhs_3x4_1 : tensor <3 x4 xi8 > to tensor <?x?xi8 >
207
+ %rhs_4x5_dynamic = tensor.cast %rhs_4x5_1 : tensor <4 x5 xi8 > to tensor <?x?xi8 >
208
+ %zero_acc_dynamic = tensor.cast %zero_acc : tensor <3 x5 xi32 > to tensor <?x?xi32 >
209
+ call @check_one_quantized_matmul_as_matmul_dynamic (%lhs_3x4_dynamic , %rhs_4x5_dynamic , %c_minus128 , %c_plus127 , %zero_acc_dynamic ) : (tensor <?x?xi8 >, tensor <?x?xi8 >, i32 , i32 , tensor <?x?xi32 >) -> ()
210
+
125
211
return
126
212
}
0 commit comments