[AArch64] Don't allow mixed partial reductions without i8mm #137602
Conversation
@llvm/pr-subscribers-backend-aarch64 @llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Changes: Partial reductions with mixed extends should only be allowed if i8mm is present.

Full diff: https://github.com/llvm/llvm-project/pull/137602.diff

2 Files Affected:
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c670b2ae71bf3..4c6803ed310b2 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5345,10 +5345,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
return Invalid;
// AArch64 supports lowering mixed extensions to a usdot but only if the
- // i8mm or sve/streaming features are available.
+ // i8mm feature is available.
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
- (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
- !ST->isSVEorStreamingSVEAvailable()))
+ (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
return Invalid;
if (!BinOp || *BinOp != Instruction::Mul)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index cdd23b8f4d761..5950a3f84502b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -46,38 +46,51 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
; CHECK-NOI8MM-LABEL: define i32 @dotp_z_s(
; CHECK-NOI8MM-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NOI8MM-NEXT: entry:
+; CHECK-NOI8MM-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = mul i64 [[TMP0]], 8
; CHECK-NOI8MM-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-NOI8MM: vector.ph:
+; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = mul i64 [[TMP14]], 8
+; CHECK-NOI8MM-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NOI8MM-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 8
; CHECK-NOI8MM-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-NOI8MM: vector.body:
; CHECK-NOI8MM-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0
-; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 4
+; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[TMP1]], i64 [[TMP9]]
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP2]], align 1
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
-; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
-; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
-; CHECK-NOI8MM-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 4
+; CHECK-NOI8MM-NEXT: [[TMP17:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP16]]
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 4 x i8>, ptr [[TMP17]], align 1
+; CHECK-NOI8MM-NEXT: [[TMP18:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD3]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP19:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD4]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP20:%.*]] = mul <vscale x 4 x i32> [[TMP18]], [[TMP11]]
+; CHECK-NOI8MM-NEXT: [[TMP21:%.*]] = mul <vscale x 4 x i32> [[TMP19]], [[TMP12]]
+; CHECK-NOI8MM-NEXT: [[TMP22]] = add <vscale x 4 x i32> [[TMP20]], [[VEC_PHI]]
+; CHECK-NOI8MM-NEXT: [[TMP23]] = add <vscale x 4 x i32> [[TMP21]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NOI8MM-NEXT: [[TMP24:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NOI8MM-NEXT: br i1 [[TMP24]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
; CHECK-NOI8MM: middle.block:
-; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
-; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
-; CHECK-NOI8MM-NEXT: br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[TMP23]], [[TMP22]]
+; CHECK-NOI8MM-NEXT: [[TMP25:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NOI8MM-NEXT: br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
; CHECK-NOI8MM: scalar.ph:
;
entry:
@@ -143,38 +156,51 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
; CHECK-NOI8MM-LABEL: define i32 @dotp_s_z(
; CHECK-NOI8MM-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
; CHECK-NOI8MM-NEXT: entry:
+; CHECK-NOI8MM-NEXT: [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = mul i64 [[TMP0]], 8
; CHECK-NOI8MM-NEXT: br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
; CHECK-NOI8MM: vector.ph:
+; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = mul i64 [[TMP14]], 8
+; CHECK-NOI8MM-NEXT: [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NOI8MM-NEXT: [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = mul i64 [[TMP4]], 8
; CHECK-NOI8MM-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK-NOI8MM: vector.body:
; CHECK-NOI8MM-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT: [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT: [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
; CHECK-NOI8MM-NEXT: [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
; CHECK-NOI8MM-NEXT: [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0
-; CHECK-NOI8MM-NEXT: [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = mul i64 [[TMP8]], 4
+; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = getelementptr i8, ptr [[TMP1]], i64 [[TMP9]]
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP2]], align 1
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
; CHECK-NOI8MM-NEXT: [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
; CHECK-NOI8MM-NEXT: [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NOI8MM-NEXT: [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT: [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT: [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT: [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT: [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
-; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
-; CHECK-NOI8MM-NEXT: [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
-; CHECK-NOI8MM-NEXT: br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NOI8MM-NEXT: [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT: [[TMP16:%.*]] = mul i64 [[TMP15]], 4
+; CHECK-NOI8MM-NEXT: [[TMP17:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP16]]
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD3:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; CHECK-NOI8MM-NEXT: [[WIDE_LOAD4:%.*]] = load <vscale x 4 x i8>, ptr [[TMP17]], align 1
+; CHECK-NOI8MM-NEXT: [[TMP18:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD3]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP19:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD4]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT: [[TMP20:%.*]] = mul <vscale x 4 x i32> [[TMP18]], [[TMP11]]
+; CHECK-NOI8MM-NEXT: [[TMP21:%.*]] = mul <vscale x 4 x i32> [[TMP19]], [[TMP12]]
+; CHECK-NOI8MM-NEXT: [[TMP22]] = add <vscale x 4 x i32> [[TMP20]], [[VEC_PHI]]
+; CHECK-NOI8MM-NEXT: [[TMP23]] = add <vscale x 4 x i32> [[TMP21]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NOI8MM-NEXT: [[TMP24:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NOI8MM-NEXT: br i1 [[TMP24]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
; CHECK-NOI8MM: middle.block:
-; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
-; CHECK-NOI8MM-NEXT: [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
-; CHECK-NOI8MM-NEXT: br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NOI8MM-NEXT: [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[TMP23]], [[TMP22]]
+; CHECK-NOI8MM-NEXT: [[TMP25:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT: [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NOI8MM-NEXT: br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
; CHECK-NOI8MM: scalar.ph:
;
entry:
   if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
According to the specification, the instruction is undefined if !i8mm || (!sve && !sme), which means the instruction is defined if i8mm && (sve || sme). So this needs an additional && ST->isSVEorStreamingSVEAvailable().
I guess this should be: (OpAExtend != OpBExtend && (!ST->hasMatMulInt8() || !ST->isSVEorStreamingSVEAvailable()))) ?
It seems I forgot that there was also a corresponding NEON instruction. The code above already checks to see if SVE/Streaming-SVE is available when the type is scalable, or if NEON is available when the type is fixed-length, so only checking for i8mm here should indeed be sufficient. Could you add a test-case for the NEON case?
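A standalone sketch may help summarise the reasoning in this thread. The struct, enum, and function below are invented purely for illustration (they are not the LLVM/TTI API); the logic paraphrases the comment above: the vector flavour (scalable vs fixed-length) is already validated by earlier checks, so the mixed-extend case only needs the i8mm test.

// Illustrative only: models the legality decision discussed in this thread.
// Names below are made up for the sketch; they are not the actual LLVM API.
struct Features {
  bool HasI8MM;     // +i8mm
  bool HasSVEorSME; // SVE or streaming SVE available
  bool HasNeon;     // NEON available
};

enum class Extend { None, SignExt, ZeroExt };

// A partial reduction is only costed as a dot-product if the vector flavour
// is available at all (mirroring the earlier checks mentioned above), both
// operands are extended, and mixed sign/zero extension -- the usdot/sudot
// case -- has +i8mm behind it.
bool isPartialReductionLegal(Extend OpA, Extend OpB, bool ScalableVF,
                             const Features &F) {
  if (ScalableVF ? !F.HasSVEorSME : !F.HasNeon)
    return false;
  if (OpA == Extend::None || OpB == Extend::None)
    return false;
  if (OpA != OpB && !F.HasI8MM)
    return false;
  return true;
}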
@@ -225,8 +225,203 @@ for.exit: ; preds = %for.body
   ret i32 %add
 }

+define i32 @dotp_z_s_neon(ptr %a, ptr %b) #1 {
Very minor nit: does it make more sense to call this usdot_neon? (Similar question for the other one.)
Done. I've called this one sudot_neon since it's mul(sext, zext).
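For readers less familiar with the naming, a hedged C++ sketch of the two loop shapes involved follows; the function names and the 1024 trip count are illustrative, chosen to mirror the existing tests, and are not taken from the patch.

#include <cstdint>

// usdot-style: zext(a) * sext(b), the shape of the existing dotp_z_s test.
int32_t usdot_style(const uint8_t *a, const int8_t *b) {
  int32_t sum = 0;
  for (int i = 0; i < 1024; ++i)
    sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  return sum;
}

// sudot-style: sext(a) * zext(b), i.e. mul(sext, zext) as described above.
int32_t sudot_style(const int8_t *a, const uint8_t *b) {
  int32_t sum = 0;
  for (int i = 0; i < 1024; ++i)
    sum += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
  return sum;
}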
; RUN: opt -passes=loop-vectorize -enable-epilogue-vectorization=false -mattr=+i8mm,+dotprod --vectorizer-maximize-bandwidth -S < %s | FileCheck %s
; RUN: opt -passes=loop-vectorize -enable-epilogue-vectorization=false -mattr=+dotprod -S --vectorizer-maximize-bandwidth < %s | FileCheck %s --check-prefix=CHECK-NOI8MM |
Is this necessary? If not, can you remove it?
It is, since otherwise we don't test the scalable pathway as the max VF available is vscale x 4.
Sorry for the unfortunate timing, but I've just come to realise that this flag is probably required if you want the SVE sudot/usdot tests to enforce scalable vectors. Otherwise, it would still pick a fixed-length VF for the +i8mm case, where you actually want to test that this specifically uses a scalable VF. Could you remove the last commit (486aa4c)?
Done.
   if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
+      (OpAExtend != OpBExtend && VF.isFixed() && !ST->hasMatMulInt8()))
I think VF.isFixed() can be removed, because SVE also has usdot instructions (if i8mm is available).
Done.
Partial reductions with mixed extends should only be allowed if i8mm is present.
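For reference, the guard as it lands in the diff at the top of this page, with editorial comments added to spell out what each clause rejects (the inline comments are additions; the condition itself is from the patch):

// Both operands must be extended at all...
if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
    // ...and a mixed sign/zero extension (the usdot/sudot case) is only
    // lowerable when the i8mm feature is present.
    (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
  return Invalid;

// Only multiplies feed the dot-product lowering.
if (!BinOp || *BinOp != Instruction::Mul)
  return Invalid;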