Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 5282373

Browse files
[MLIR][ArmSVE] Add an ArmSVE dialect operation which maps to svusmmla
1 parent dbf1aa0 commit 5282373

File tree

5 files changed

+96
-38
lines changed

5 files changed

+96
-38
lines changed

mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td

+57-38
Original file line numberDiff line numberDiff line change
@@ -147,11 +147,9 @@ class ScalableMaskedIOp<string mnemonic, string op_description,
147147
"$mask `,` $src1 `,` $src2 attr-dict `:` type($mask) `,` type($res)";
148148
}
149149

150-
def SdotOp : ArmSVE_Op<"sdot",
151-
[Pure,
152-
AllTypesMatch<["src1", "src2"]>,
153-
AllTypesMatch<["acc", "dst"]>,
154-
]> {
150+
def SdotOp : ArmSVE_Op<"sdot", [Pure,
151+
AllTypesMatch<["src1", "src2"]>,
152+
AllTypesMatch<["acc", "dst"]>]> {
155153
let summary = "Vector-vector dot product and accumulate op";
156154
let description = [{
157155
SDOT: Signed integer addition of dot product.
@@ -178,11 +176,9 @@ def SdotOp : ArmSVE_Op<"sdot",
178176
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
179177
}
180178

181-
def SmmlaOp : ArmSVE_Op<"smmla",
182-
[Pure,
183-
AllTypesMatch<["src1", "src2"]>,
184-
AllTypesMatch<["acc", "dst"]>,
185-
]> {
179+
def SmmlaOp : ArmSVE_Op<"smmla", [Pure,
180+
AllTypesMatch<["src1", "src2"]>,
181+
AllTypesMatch<["acc", "dst"]>]> {
186182
let summary = "Matrix-matrix multiply and accumulate op";
187183
let description = [{
188184
SMMLA: Signed integer matrix multiply-accumulate.
@@ -210,11 +206,9 @@ def SmmlaOp : ArmSVE_Op<"smmla",
210206
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
211207
}
212208

213-
def UdotOp : ArmSVE_Op<"udot",
214-
[Pure,
215-
AllTypesMatch<["src1", "src2"]>,
216-
AllTypesMatch<["acc", "dst"]>,
217-
]> {
209+
def UdotOp : ArmSVE_Op<"udot", [Pure,
210+
AllTypesMatch<["src1", "src2"]>,
211+
AllTypesMatch<["acc", "dst"]>]> {
218212
let summary = "Vector-vector dot product and accumulate op";
219213
let description = [{
220214
UDOT: Unsigned integer addition of dot product.
@@ -241,11 +235,9 @@ def UdotOp : ArmSVE_Op<"udot",
241235
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
242236
}
243237

244-
def UmmlaOp : ArmSVE_Op<"ummla",
245-
[Pure,
246-
AllTypesMatch<["src1", "src2"]>,
247-
AllTypesMatch<["acc", "dst"]>,
248-
]> {
238+
def UmmlaOp : ArmSVE_Op<"ummla", [Pure,
239+
AllTypesMatch<["src1", "src2"]>,
240+
AllTypesMatch<["acc", "dst"]>]> {
249241
let summary = "Matrix-matrix multiply and accumulate op";
250242
let description = [{
251243
UMMLA: Unsigned integer matrix multiply-accumulate.
@@ -273,14 +265,42 @@ def UmmlaOp : ArmSVE_Op<"ummla",
273265
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
274266
}
275267

268+
def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
269+
AllTypesMatch<["src1", "src2"]>,
270+
AllTypesMatch<["acc", "dst"]>]> {
271+
let summary = "Matrix-matrix multiply and accumulate op";
272+
let description = [{
273+
USMMLA: Unsigned by signed integer matrix multiply-accumulate.
274+
275+
The unsigned by signed integer matrix multiply-accumulate operation
276+
multiplies the 2×8 matrix of unsigned 8-bit integer values held
277+
in the first source vector by the 8×2 matrix of signed 8-bit integer
278+
values in the second source vector. The resulting 2×2 widened 32-bit
279+
integer matrix product is then added to the 32-bit integer matrix
280+
accumulator.
281+
282+
Source:
283+
https://developer.arm.com/documentation/100987/0000
284+
}];
285+
// Supports (vector<[16]xi8>, vector<[16]xi8>) -> (vector<[4]xi32>)
286+
let arguments = (ins
287+
ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
288+
ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
289+
ScalableVectorOfLengthAndType<[16], [I8]>:$src2
290+
);
291+
let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
292+
let assemblyFormat =
293+
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
294+
}
295+
276296
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
277297
"expected corresponding svbool type widened to [16]xi1",
278298
lhsArg, rhsArg,
279299
"VectorType(VectorType::Builder(::llvm::cast<VectorType>($_self)).setDim(::llvm::cast<VectorType>($_self).getRank() - 1, 16))">;
280300

281301
def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
282-
[Pure, SvboolTypeConstraint<"result", "source">]>
283-
{
302+
[Pure,
303+
SvboolTypeConstraint<"result", "source">]> {
284304
let summary = "Convert a svbool type to a SVE predicate type";
285305
let description = [{
286306
Converts svbool types (`vector<[16]xi1>` or vectors of that type, e.g.
@@ -313,8 +333,8 @@ def ConvertFromSvboolOp : ArmSVE_Op<"convert_from_svbool",
313333
}
314334

315335
def ConvertToSvboolOp : ArmSVE_Op<"convert_to_svbool",
316-
[Pure, SvboolTypeConstraint<"source", "result">]>
317-
{
336+
[Pure,
337+
SvboolTypeConstraint<"source", "result">]> {
318338
let summary = "Convert a SVE predicate type to a svbool type";
319339
let description = [{
320340
Converts SVE predicate types (or vectors of predicate types, e.g.
@@ -356,10 +376,9 @@ def ZipInputVectorType : AnyTypeOf<[
356376
Scalable1DVectorOfLength<16, [I8]>],
357377
"an SVE vector with element size <= 64-bit">;
358378

359-
def ZipX2Op : ArmSVE_Op<"zip.x2", [
360-
Pure,
361-
AllTypesMatch<["sourceV1", "sourceV2", "resultV1", "resultV2"]>]
362-
> {
379+
def ZipX2Op : ArmSVE_Op<"zip.x2", [Pure,
380+
AllTypesMatch<["sourceV1", "sourceV2",
381+
"resultV1", "resultV2"]>]> {
363382
let summary = "Multi-vector two-way zip op";
364383

365384
let description = [{
@@ -400,12 +419,11 @@ def ZipX2Op : ArmSVE_Op<"zip.x2", [
400419
}];
401420
}
402421

403-
def ZipX4Op : ArmSVE_Op<"zip.x4", [
404-
Pure,
405-
AllTypesMatch<[
406-
"sourceV1", "sourceV2", "sourceV3", "sourceV4",
407-
"resultV1", "resultV2", "resultV3", "resultV4"]>]
408-
> {
422+
def ZipX4Op
423+
: ArmSVE_Op<"zip.x4",
424+
[Pure,
425+
AllTypesMatch<["sourceV1", "sourceV2", "sourceV3", "sourceV4",
426+
"resultV1", "resultV2", "resultV3", "resultV4"]>]> {
409427
let summary = "Multi-vector four-way zip op";
410428

411429
let description = [{
@@ -463,10 +481,7 @@ def ZipX4Op : ArmSVE_Op<"zip.x4", [
463481
}];
464482
}
465483

466-
def PselOp : ArmSVE_Op<"psel", [
467-
Pure,
468-
AllTypesMatch<["p1", "result"]>,
469-
]> {
484+
def PselOp : ArmSVE_Op<"psel", [Pure, AllTypesMatch<["p1", "result"]>]> {
470485
let summary = "Predicate select";
471486

472487
let description = [{
@@ -571,6 +586,10 @@ def SmmlaIntrOp :
571586
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
572587
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
573588

589+
def UsmmlaIntrOp :
590+
ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
591+
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
592+
574593
def SdotIntrOp :
575594
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
576595
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;

mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
2424
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
2525
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
2626
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
27+
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
2728
using DupQLaneLowering =
2829
OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
2930
using ScalableMaskedAddIOpLowering =
@@ -206,6 +207,7 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
206207
SmmlaOpLowering,
207208
UdotOpLowering,
208209
UmmlaOpLowering,
210+
UsmmlaOpLowering,
209211
ZipX2OpLowering,
210212
ZipX4OpLowering,
211213
SdotOpLowering>(converter);
@@ -234,6 +236,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
234236
SmmlaIntrOp,
235237
UdotIntrOp,
236238
UmmlaIntrOp,
239+
UsmmlaIntrOp,
237240
WhileLTIntrOp,
238241
ZipX2IntrOp,
239242
ZipX4IntrOp,
@@ -254,6 +257,7 @@ void mlir::configureArmSVELegalizeForExportTarget(
254257
SmmlaOp,
255258
UdotOp,
256259
UmmlaOp,
260+
UsmmlaOp,
257261
ZipX2Op,
258262
ZipX4Op,
259263
SdotOp>();

mlir/test/Dialect/ArmSVE/legalize-for-llvm.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
4848

4949
// -----
5050

51+
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
52+
%b: vector<[16]xi8>,
53+
%c: vector<[4]xi32>)
54+
-> vector<[4]xi32> {
55+
// CHECK: arm_sve.intr.usmmla
56+
%0 = arm_sve.usmmla %c, %a, %b :
57+
vector<[16]xi8> to vector<[4]xi32>
58+
return %0 : vector<[4]xi32>
59+
}
60+
61+
// -----
62+
5163
func.func @arm_sve_arithi_masked(%a: vector<[4]xi32>,
5264
%b: vector<[4]xi32>,
5365
%c: vector<[4]xi32>,

mlir/test/Dialect/ArmSVE/roundtrip.mlir

+11
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,17 @@ func.func @arm_sve_ummla(%a: vector<[16]xi8>,
4444

4545
// -----
4646

47+
func.func @arm_sve_usmmla(%a: vector<[16]xi8>,
48+
%b: vector<[16]xi8>,
49+
%c: vector<[4]xi32>) -> vector<[4]xi32> {
50+
// CHECK: arm_sve.usmmla {{.*}}: vector<[16]xi8> to vector<[4]xi32>
51+
%0 = arm_sve.usmmla %c, %a, %b :
52+
vector<[16]xi8> to vector<[4]xi32>
53+
return %0 : vector<[4]xi32>
54+
}
55+
56+
// -----
57+
4758
func.func @arm_sve_masked_arithi(%a: vector<[4]xi32>,
4859
%b: vector<[4]xi32>,
4960
%c: vector<[4]xi32>,

mlir/test/Target/LLVMIR/arm-sve.mlir

+12
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ llvm.func @arm_sve_ummla(%arg0: vector<[16]xi8>,
4848
llvm.return %0 : vector<[4]xi32>
4949
}
5050

51+
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_usmmla
52+
llvm.func @arm_sve_usmmla(%arg0: vector<[16]xi8>,
53+
%arg1: vector<[16]xi8>,
54+
%arg2: vector<[4]xi32>)
55+
-> vector<[4]xi32> {
56+
// CHECK: call <vscale x 4 x i32> @llvm.aarch64.sve.usmmla.nxv4i32(<vscale x 4
57+
%0 = "arm_sve.intr.usmmla"(%arg2, %arg0, %arg1) :
58+
(vector<[4]xi32>, vector<[16]xi8>, vector<[16]xi8>)
59+
-> vector<[4]xi32>
60+
llvm.return %0 : vector<[4]xi32>
61+
}
62+
5163
// CHECK-LABEL: define <vscale x 4 x i32> @arm_sve_arithi
5264
llvm.func @arm_sve_arithi(%arg0: vector<[4]xi32>,
5365
%arg1: vector<[4]xi32>,

0 commit comments

Comments
 (0)