[MLIR][ArmSVE] Add lowering of vector.contract to SVE *MMLA instructions #135359
Conversation
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir-neon @llvm/pr-subscribers-mlir-vector

Author: Momchil Velikov (momchil-velikov)

Changes

Patch is 93.07 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/135359.diff

21 Files Affected:
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index bbba495e613b2..930d8b44abca0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1406,6 +1406,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
"bool", /*default=*/"false",
"Enables the use of ArmSVE dialect while lowering the vector "
"dialect.">,
+ Option<"armI8MM", "enable-arm-i8mm",
+ "bool", /*default=*/"false",
+ "Enables the use of Arm FEAT_I8MM instructions while lowering "
+ "the vector dialect.">,
Option<"x86Vector", "enable-x86vector",
"bool", /*default=*/"false",
"Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
index cdcf4d8752e87..a678d9b09e66d 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
+++ b/mlir/include/mlir/Dialect/ArmSVE/IR/ArmSVE.td
@@ -61,6 +61,13 @@ class Scalable1DVectorOfLength<int length, list<Type> elementTypes> : ShapedCont
"a 1-D scalable vector with length " # length,
"::mlir::VectorType">;
+def SVEVector : AnyTypeOf<[
+ Scalable1DVectorOfLength<2, [I64, F64]>,
+ Scalable1DVectorOfLength<4, [I32, F32]>,
+ Scalable1DVectorOfLength<8, [I16, F16, BF16]>,
+ Scalable1DVectorOfLength<16, [I8]>],
+ "an SVE vector with element size <= 64-bit">;
+
//===----------------------------------------------------------------------===//
// ArmSVE op definitions
//===----------------------------------------------------------------------===//
@@ -72,14 +79,22 @@ class ArmSVE_IntrOp<string mnemonic,
list<Trait> traits = [],
list<int> overloadedOperands = [],
list<int> overloadedResults = [],
- int numResults = 1> :
+ int numResults = 1,
+ list<int> immArgPositions = [],
+ list<string> immArgAttrNames = []> :
LLVM_IntrOpBase</*Dialect dialect=*/ArmSVE_Dialect,
/*string opName=*/"intr." # mnemonic,
/*string enumName=*/"aarch64_sve_" # !subst(".", "_", mnemonic),
/*list<int> overloadedResults=*/overloadedResults,
/*list<int> overloadedOperands=*/overloadedOperands,
/*list<Trait> traits=*/traits,
- /*int numResults=*/numResults>;
+ /*int numResults=*/numResults,
+ /*bit requiresAccessGroup=*/0,
+ /*bit requiresAliasAnalysis=*/0,
+ /*bit requiresFastmath=*/0,
+ /*bit requiresOpBundles=*/0,
+ /*list<int> immArgPositions=*/immArgPositions,
+ /*list<string> immArgAttrNames=*/immArgAttrNames>;
class ArmSVE_IntrBinaryOverloadedOp<string mnemonic,
list<Trait> traits = []>:
@@ -258,6 +273,34 @@ def UmmlaOp : ArmSVE_Op<"ummla",
"$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
}
+def UsmmlaOp : ArmSVE_Op<"usmmla", [Pure,
+ AllTypesMatch<["src1", "src2"]>,
+ AllTypesMatch<["acc", "dst"]>]> {
+ let summary = "Matrix-matrix multiply and accumulate op";
+ let description = [{
+ USMMLA: Unsigned by signed integer matrix multiply-accumulate.
+
+ The unsigned by signed integer matrix multiply-accumulate operation
+ multiplies the 2×8 matrix of unsigned 8-bit integer values held in
+ the first source vector by the 8×2 matrix of signed 8-bit integer
+ values in the second source vector. The resulting 2×2 widened 32-bit
+ integer matrix product is then added to the 32-bit integer matrix
+ accumulator.
+
+ Source:
+ https://developer.arm.com/documentation/100987/0000
+ }];
+ // Supports (vector<16xi8>, vector<16xi8>) -> (vector<4xi32>)
+ let arguments = (ins
+ ScalableVectorOfLengthAndType<[4], [I32]>:$acc,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src1,
+ ScalableVectorOfLengthAndType<[16], [I8]>:$src2
+ );
+ let results = (outs ScalableVectorOfLengthAndType<[4], [I32]>:$dst);
+ let assemblyFormat =
+ "$acc `,` $src1 `,` $src2 attr-dict `:` type($src1) `to` type($dst)";
+}
+
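For reference, a hedged usage sketch (not part of the patch) that follows the assembly format declared above — accumulator first, then the two i8 sources, with the shared source type before `to` and the accumulator/result type after:

```mlir
// Illustrative only; %acc, %a and %b are assumed SSA values of the shown types.
%d = arm_sve.usmmla %acc, %a, %b : vector<[16]xi8> to vector<[4]xi32>
```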
class SvboolTypeConstraint<string lhsArg, string rhsArg> : TypesMatchWith<
"expected corresponding svbool type widened to [16]xi1",
lhsArg, rhsArg,
@@ -509,6 +552,41 @@ def ScalableMaskedUDivIOp : ScalableMaskedIOp<"masked.divi_unsigned",
def ScalableMaskedDivFOp : ScalableMaskedFOp<"masked.divf", "division">;
+def DupQLaneOp : ArmSVE_Op<"dupq_lane", [Pure, AllTypesMatch<["src", "dst"]>]> {
+ let summary = "Broadcast indexed 128-bit segment to vector";
+
+ let description = [{
+ This operation fills each 128-bit segment of a vector with the elements
+ from the indexed 128-bit segment of the source vector. If the VL is
+ 128 bits, the operation is a NOP.
+
+ Example:
+ ```mlir
+ // VL == 256
+ // %X = [A B C D x x x x]
+ %Y = arm_sve.dupq_lane %X[0] : vector<[4]xi32>
// %Y = [A B C D A B C D]
+
+ // %U = [x x x x x x x x A B C D E F G H]
+ %V = arm_sve.dupq_lane %U[1] : vector<[8]xf16>
+ // %V = [A B C D E F G H A B C D E F G H]
+ ```
+ }];
+
+ let arguments = (ins SVEVector:$src,
+ I64Attr:$lane);
+ let results = (outs SVEVector:$dst);
+
+ let builders = [
+ OpBuilder<(ins "Value":$src, "int64_t":$lane), [{
+ build($_builder, $_state, src.getType(), src, lane);
+ }]>];
+
+ let assemblyFormat = [{
+ $src `[` $lane `]` attr-dict `:` type($dst)
+ }];
+}
+
def UmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"ummla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -517,6 +595,10 @@ def SmmlaIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"smmla">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+def UsmmlaIntrOp :
+ ArmSVE_IntrBinaryOverloadedOp<"usmmla">,
+ Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
+
def SdotIntrOp :
ArmSVE_IntrBinaryOverloadedOp<"sdot">,
Arguments<(ins AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank, AnyScalableVectorOfAnyRank)>;
@@ -610,4 +692,14 @@ def WhileLTIntrOp :
/*overloadedResults=*/[0]>,
Arguments<(ins I64:$base, I64:$n)>;
+def DupQLaneIntrOp : ArmSVE_IntrOp<"dupq_lane",
+ /*traits=*/[],
+ /*overloadedOperands=*/[0],
+ /*overloadedResults=*/[],
+ /*numResults=*/1,
+ /*immArgPositions*/[1],
+ /*immArgAttrNames*/["lane"]>,
+ Arguments<(ins Arg<ScalableVectorOfRank<[1]>, "v">:$v,
+ Arg<I64Attr, "lane">:$lane)>;
+
#endif // ARMSVE_OPS
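The `immArgPositions`/`immArgAttrNames` arguments added above mean the intrinsic's `lane` is modelled as an immediate attribute rather than an SSA operand. A minimal sketch of what that looks like in MLIR's generic op syntax (printed form assumed, for illustration):

```mlir
// Hypothetical generic form of the dupq_lane intrinsic op defined above.
%r = "arm_sve.intr.dupq_lane"(%v) <{lane = 1 : i64}>
    : (vector<[4]xi32>) -> vector<[4]xi32>
```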
diff --git a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
index 8665c8224cc45..232e2be29e574 100644
--- a/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h
@@ -20,6 +20,9 @@ class RewritePatternSet;
void populateArmSVELegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
+void populateLowerContractionToSVEI8MMPatternPatterns(
+ RewritePatternSet &patterns);
+
/// Configure the target to support lowering ArmSVE ops to ops that map to LLVM
/// intrinsics.
void configureArmSVELegalizeForExportTarget(LLVMConversionTarget &target);
diff --git a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
index 330474a718e30..8e2620029c354 100644
--- a/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
+++ b/mlir/lib/Conversion/VectorToLLVM/CMakeLists.txt
@@ -35,6 +35,7 @@ add_mlir_conversion_library(MLIRVectorToLLVMPass
MLIRVectorToLLVM
MLIRArmNeonDialect
+ MLIRArmNeonTransforms
MLIRArmSVEDialect
MLIRArmSVETransforms
MLIRAMXDialect
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 7082b92c95d1d..1e6c8122b1d0e 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -14,6 +14,7 @@
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
@@ -82,6 +83,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
populateVectorStepLoweringPatterns(patterns);
populateVectorRankReducingFMAPattern(patterns);
populateVectorGatherLoweringPatterns(patterns);
+ if (armI8MM) {
+ if (armNeon)
+ arm_neon::populateLowerContractionToSMMLAPatternPatterns(patterns);
+ if (armSVE)
+ populateLowerContractionToSVEI8MMPatternPatterns(patterns);
+ }
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
}
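As wired above, the new lowerings are opt-in: the SVE pattern fires only when both the SVE and I8MM options are enabled, and the NEON pattern additionally requires the NEON option. Assuming the option spellings from Passes.td, a typical invocation would be `mlir-opt --convert-vector-to-llvm="enable-arm-sve=1 enable-arm-i8mm=1" input.mlir`. The NEON pattern's benefit is raised to 2 in the hunk below, presumably so that it is attempted ahead of competing benefit-1 patterns when several are registered.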
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
index 2a1271dfd6bdf..e807b233aa7aa 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp
@@ -56,6 +56,9 @@ class LowerContractionToSMMLAPattern
// Avoid 0-D vectors and 1-D rhs:
if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
return failure();
+ // Avoid scalable vectors.
+ if (lhsType.isScalable() || rhsType.isScalable())
+ return failure();
auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
auto dimN = rhsType.getDimSize(0);
auto dimK = rhsType.getDimSize(1);
@@ -238,5 +241,5 @@ class LowerContractionToSMMLAPattern
void mlir::arm_neon::populateLowerContractionToSMMLAPatternPatterns(
RewritePatternSet &patterns) {
MLIRContext *context = patterns.getContext();
- patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/1);
+ patterns.add<LowerContractionToSMMLAPattern>(context, /*benefit=*/2);
}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
index a70c489a51fea..65f98b44b1b69 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
add_mlir_dialect_library(MLIRArmSVETransforms
LegalizeForLLVMExport.cpp
LegalizeVectorStorage.cpp
+ LowerContractionToSVEI8MMPattern.cpp
DEPENDS
MLIRArmSVEConversionsIncGen
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 2bdb640699d03..0bbe4717e0d17 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -24,6 +24,8 @@ using SdotOpLowering = OneToOneConvertToLLVMPattern<SdotOp, SdotIntrOp>;
using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
+using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
+using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
@@ -192,6 +194,8 @@ void mlir::populateArmSVELegalizeForLLVMExportPatterns(
SmmlaOpLowering,
UdotOpLowering,
UmmlaOpLowering,
+ UsmmlaOpLowering,
+ DupQLaneLowering,
ScalableMaskedAddIOpLowering,
ScalableMaskedAddFOpLowering,
ScalableMaskedSubIOpLowering,
@@ -219,6 +223,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaIntrOp,
UdotIntrOp,
UmmlaIntrOp,
+ UsmmlaIntrOp,
+ DupQLaneIntrOp,
ScalableMaskedAddIIntrOp,
ScalableMaskedAddFIntrOp,
ScalableMaskedSubIIntrOp,
@@ -238,6 +244,8 @@ void mlir::configureArmSVELegalizeForExportTarget(
SmmlaOp,
UdotOp,
UmmlaOp,
+ UsmmlaOp,
+ DupQLaneOp,
ScalableMaskedAddIOp,
ScalableMaskedAddFOp,
ScalableMaskedSubIOp,
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
new file mode 100644
index 0000000000000..c0620c71440bc
--- /dev/null
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -0,0 +1,304 @@
+//===- LowerContractionToSVEI8MMPattern.cpp - Contract to I8MM --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to
+// SVE I8MM operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmSVE/IR/ArmSVEDialect.h"
+#include "mlir/Dialect/ArmSVE/Transforms/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#include "mlir/Dialect/UB/IR/UBOps.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-sve-i8mm"
+
+using namespace mlir;
+using namespace mlir::arm_sve;
+
+namespace {
+// Check if the given value is the result of an operation `T` (which must be
+// a sign- or zero-extend) from i8 to i32. Return the value before the extension.
+template <typename T>
+inline std::enable_if_t<(std::is_base_of_v<arith::ExtSIOp, T> ||
+ std::is_base_of_v<arith::ExtUIOp, T>),
+ std::optional<Value>>
+extractExtOperand(Value v, Type i8Ty, Type i32Ty) {
+ auto extOp = dyn_cast_or_null<T>(v.getDefiningOp());
+ if (!extOp)
+ return {};
+
+ auto inOp = extOp.getIn();
+ auto inTy = dyn_cast<VectorType>(inOp.getType());
+ if (!inTy || inTy.getElementType() != i8Ty)
+ return {};
+
+ auto outTy = dyn_cast<VectorType>(extOp.getType());
+ if (!outTy || outTy.getElementType() != i32Ty)
+ return {};
+
+ return inOp;
+}
+
+// Designate the operation (resp. instruction) used to do sub-tile matrix
+// multiplications.
+enum class MMLA {
+ Signed, // smmla
+ Unsigned, // ummla
+ Mixed, // usmmla
+ MixedSwapped // usmmla with LHS and RHS swapped
+};
+
+// Create the matrix multiply and accumulate operation according to `op`.
+Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
+ mlir::VectorType accType, Value acc, Value lhs, Value rhs) {
+ switch (op) {
+ case MMLA::Signed:
+ return rewriter.create<arm_sve::SmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Unsigned:
+ return rewriter.create<arm_sve::UmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::Mixed:
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, lhs, rhs);
+ case MMLA::MixedSwapped:
+ // The accumulator comes transposed and the result will be transposed
+ // later, so all we have to do here is swap the operands.
+ return rewriter.create<arm_sve::UsmmlaOp>(loc, accType, acc, rhs, lhs);
+ }
+}
+
+class LowerContractionToSVEI8MMPattern
+ : public OpRewritePattern<vector::ContractionOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ LogicalResult matchAndRewrite(vector::ContractionOp op,
+ PatternRewriter &rewriter) const override {
+
+ Location loc = op.getLoc();
+ mlir::VectorType lhsType = op.getLhsType();
+ mlir::VectorType rhsType = op.getRhsType();
+
+ // For now handle LHS<Mx8> and RHS<8x[N]> - these are the types we
+ // eventually expect from MMT4D. M and N dimensions must be even and at
+ // least 2.
+ if (!lhsType.hasRank() || lhsType.getRank() != 2 || !rhsType.hasRank() ||
+ rhsType.getRank() != 2)
+ return failure();
+
+ if (lhsType.isScalable() || !rhsType.isScalable())
+ return failure();
+
+ // M, N, and K are the conventional names for matrix dimensions in the
+ // context of matrix multiplication.
+ auto M = lhsType.getDimSize(0);
+ auto N = rhsType.getDimSize(0);
+ auto K = rhsType.getDimSize(1);
+
+ if (lhsType.getDimSize(1) != K || K != 8 || M < 2 || M % 2 != 0 || N < 2 ||
+ N % 2 != 0 || !rhsType.getScalableDims()[0])
+ return failure();
+
+ // Check permutation maps. For now only accept
+ // lhs: (d0, d1, d2) -> (d0, d2)
+ // rhs: (d0, d1, d2) -> (d1, d2)
+ // acc: (d0, d1, d2) -> (d0, d1)
+ // Note: RHS is transposed.
+ if (op.getIndexingMapsArray()[0] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[1] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{1u, 2u},
+ op.getContext()) ||
+ op.getIndexingMapsArray()[2] !=
+ AffineMap::getMultiDimMapWithTargets(3, ArrayRef{0u, 1u},
+ op.getContext()))
+ return failure();
+
+ // Check iterator types for matrix multiplication.
+ auto itTypes = op.getIteratorTypesArray();
+ if (itTypes.size() != 3 || itTypes[0] != vector::IteratorType::parallel ||
+ itTypes[1] != vector::IteratorType::parallel ||
+ itTypes[2] != vector::IteratorType::reduction)
+ return failure();
+
+ // Check the combining kind is addition.
+ if (op.getKind() != vector::CombiningKind::ADD)
+ return failure();
+
+ // Check the output is a vector of i32 elements.
+ auto outTy = dyn_cast<VectorType>(op.getType());
+ if (!outTy || outTy.getElementType() != rewriter.getI32Type())
+ return failure();
+
+ // Check inputs are sign-/zero- extensions from i8 to i32. Get the values
+ // before the extension. All four signed/unsigned combinations for input
+ // operands are supported, but they are lowered to different operations.
+ // Determine the appropriate operation to lower to.
+ MMLA mmlaOp = MMLA::Signed;
+ auto maybeLhs = extractExtOperand<arith::ExtSIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (!maybeLhs) {
+ mmlaOp = MMLA::Unsigned;
+ maybeLhs = extractExtOperand<arith::ExtUIOp>(
+ op.getLhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeLhs)
+ return failure();
+
+ auto maybeRhs = extractExtOperand<arith::ExtSIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ if (maybeRhs) {
+ if (mmlaOp == MMLA::Unsigned)
+ mmlaOp = MMLA::Mixed;
+ } else {
+ if (mmlaOp == MMLA::Signed)
+ mmlaOp = MMLA::MixedSwapped;
+ maybeRhs = extractExtOperand<arith::ExtUIOp>(
+ op.getRhs(), rewriter.getI8Type(), rewriter.getI32Type());
+ }
+ if (!maybeRhs)
+ return failure();
+
+ // One-dimensional vector types for arm_sve.*mmla
+ auto nxv16i8 = VectorType::get(16, rewriter.getI8Type(), {true});
+ auto nxv4i32 = VectorType::get(4, rewriter.getI32Type(), {true});
+
+ // Extract LHS sub-tiles.
+ SmallVector<Value> lhsTile;
+ for (int64_t i = 0; i < M; i += 2) {
+ // Extract two consecutive rows of the LHS tile.
+ auto r0 = rew...
[truncated]
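For context, a hedged sketch of a `vector.contract` this (truncated) pattern is written to match: rank-2 operands, a fixed-width LHS, an RHS whose leading dimension is scalable, K == 8, i8-to-i32 extensions, and additive combining. With both operands sign-extended, as here, it would take the `smmla` route:

```mlir
// Illustrative input only; shapes chosen to satisfy the checks quoted above.
%lhs32 = arith.extsi %lhs : vector<4x8xi8> to vector<4x8xi32>
%rhs32 = arith.extsi %rhs : vector<[4]x8xi8> to vector<[4]x8xi32>
%res = vector.contract {
    indexing_maps = [
      affine_map<(d0, d1, d2) -> (d0, d2)>,   // LHS
      affine_map<(d0, d1, d2) -> (d1, d2)>,   // RHS (transposed)
      affine_map<(d0, d1, d2) -> (d0, d1)>],  // accumulator
    iterator_types = ["parallel", "parallel", "reduction"],
    kind = #vector.kind<add>}
  %lhs32, %rhs32, %acc
  : vector<4x8xi32>, vector<[4]x8xi32> into vector<4x[4]xi32>
```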
You can test this locally with the following command:

git-clang-format --diff HEAD~1 HEAD --extensions h,cpp -- mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp mlir/include/mlir/Dialect/ArmSVE/Transforms/Transforms.h mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToSMMLAPattern.cpp mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp

View the diff from clang-format here.

diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
index 0bbe4717e..b1846e151 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeForLLVMExport.cpp
@@ -25,7 +25,8 @@ using SmmlaOpLowering = OneToOneConvertToLLVMPattern<SmmlaOp, SmmlaIntrOp>;
using UdotOpLowering = OneToOneConvertToLLVMPattern<UdotOp, UdotIntrOp>;
using UmmlaOpLowering = OneToOneConvertToLLVMPattern<UmmlaOp, UmmlaIntrOp>;
using UsmmlaOpLowering = OneToOneConvertToLLVMPattern<UsmmlaOp, UsmmlaIntrOp>;
-using DupQLaneLowering = OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
+using DupQLaneLowering =
+ OneToOneConvertToLLVMPattern<DupQLaneOp, DupQLaneIntrOp>;
using ScalableMaskedAddIOpLowering =
OneToOneConvertToLLVMPattern<ScalableMaskedAddIOp,
ScalableMaskedAddIIntrOp>;
Depends on:

- svdupq_lane #135356
- svusmmla #135358

In case there is more than one commit per PR, commits that will eventually be squashed begin with [fixup].