[AArch64] Don't allow mixed partial reductions without i8mm #137602


Merged

merged 7 commits into llvm:main from SamTebbs33:mixed-pred-i8mm on May 1, 2025

Conversation

SamTebbs33 (Collaborator)

Partial reductions with mixed extends should only be allowed if i8mm is present.
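
As a rough illustration (not taken from the PR itself), the dotp_z_s test in the diff below corresponds to a source loop along these lines, where one operand of the multiply is zero-extended and the other sign-extended; lowering the resulting partial reduction to usdot/sudot needs the +i8mm feature:

    // Hypothetical C++ sketch of the mixed-extend pattern; the actual tests
    // in this PR are written directly in LLVM IR.
    int dotp_z_s(const unsigned char *a, const signed char *b) {
      int sum = 0;
      for (int i = 0; i < 1024; ++i)
        sum += static_cast<int>(a[i]) *  // zero-extended i8 operand
               static_cast<int>(b[i]);   // sign-extended i8 operand
      return sum;
    }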

@llvmbot (Member) commented Apr 28, 2025

@llvm/pr-subscribers-backend-aarch64

@llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Changes

Partial reductions with mixed extends should only be allowed if i8mm is present.


Full diff: https://github.com/llvm/llvm-project/pull/137602.diff

2 Files Affected:

  • (modified) llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp (+2-3)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll (+70-44)
diff --git a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
index c670b2ae71bf3..4c6803ed310b2 100644
--- a/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
+++ b/llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
@@ -5345,10 +5345,9 @@ InstructionCost AArch64TTIImpl::getPartialReductionCost(
     return Invalid;
 
   // AArch64 supports lowering mixed extensions to a usdot but only if the
-  // i8mm or sve/streaming features are available.
+  // i8mm feature is available.
   if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
     return Invalid;
 
   if (!BinOp || *BinOp != Instruction::Mul)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
index cdd23b8f4d761..5950a3f84502b 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-mixed.ll
@@ -46,38 +46,51 @@ define i32 @dotp_z_s(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-LABEL: define i32 @dotp_z_s(
 ; CHECK-NOI8MM-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0:[0-9]+]] {
 ; CHECK-NOI8MM-NEXT:  entry:
+; CHECK-NOI8MM-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = mul i64 [[TMP0]], 8
 ; CHECK-NOI8MM-NEXT:    br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-NOI8MM:       vector.ph:
+; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP14]], 8
+; CHECK-NOI8MM-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NOI8MM-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = mul i64 [[TMP8]], 4
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP1]], i64 [[TMP9]]
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP2]], align 1
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = sext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = sext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
-; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
-; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
-; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul i64 [[TMP15]], 4
+; CHECK-NOI8MM-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP16]]
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 4 x i8>, ptr [[TMP17]], align 1
+; CHECK-NOI8MM-NEXT:    [[TMP18:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD3]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP19:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD4]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP20:%.*]] = mul <vscale x 4 x i32> [[TMP18]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[TMP21:%.*]] = mul <vscale x 4 x i32> [[TMP19]], [[TMP12]]
+; CHECK-NOI8MM-NEXT:    [[TMP22]] = add <vscale x 4 x i32> [[TMP20]], [[VEC_PHI]]
+; CHECK-NOI8MM-NEXT:    [[TMP23]] = add <vscale x 4 x i32> [[TMP21]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NOI8MM-NEXT:    [[TMP24:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NOI8MM-NEXT:    br i1 [[TMP24]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP0:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
-; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
-; CHECK-NOI8MM-NEXT:    br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[TMP23]], [[TMP22]]
+; CHECK-NOI8MM-NEXT:    [[TMP25:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NOI8MM-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK-NOI8MM:       scalar.ph:
 ;
 entry:
@@ -143,38 +156,51 @@ define i32 @dotp_s_z(ptr %a, ptr %b) #0 {
 ; CHECK-NOI8MM-LABEL: define i32 @dotp_s_z(
 ; CHECK-NOI8MM-SAME: ptr [[A:%.*]], ptr [[B:%.*]]) #[[ATTR0]] {
 ; CHECK-NOI8MM-NEXT:  entry:
+; CHECK-NOI8MM-NEXT:    [[TMP0:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = mul i64 [[TMP0]], 8
 ; CHECK-NOI8MM-NEXT:    br i1 false, label [[SCALAR_PH:%.*]], label [[VECTOR_PH:%.*]]
 ; CHECK-NOI8MM:       vector.ph:
+; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP14]], 8
+; CHECK-NOI8MM-NEXT:    [[N_MOD_VF:%.*]] = urem i64 1024, [[TMP3]]
+; CHECK-NOI8MM-NEXT:    [[N_VEC:%.*]] = sub i64 1024, [[N_MOD_VF]]
+; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = mul i64 [[TMP4]], 8
 ; CHECK-NOI8MM-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK-NOI8MM:       vector.body:
 ; CHECK-NOI8MM-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[VECTOR_BODY]] ]
-; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE5:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP22:%.*]], [[VECTOR_BODY]] ]
+; CHECK-NOI8MM-NEXT:    [[VEC_PHI1:%.*]] = phi <vscale x 4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[TMP23:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NOI8MM-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[A]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP2:%.*]] = getelementptr i8, ptr [[TMP1]], i32 0
-; CHECK-NOI8MM-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[TMP1]], i32 16
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP2]], align 1
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <16 x i8>, ptr [[TMP3]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP4:%.*]] = sext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP5:%.*]] = sext <16 x i8> [[WIDE_LOAD2]] to <16 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = mul i64 [[TMP8]], 4
+; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = getelementptr i8, ptr [[TMP1]], i64 [[TMP9]]
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD:%.*]] = load <vscale x 4 x i8>, ptr [[TMP2]], align 1
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD2:%.*]] = load <vscale x 4 x i8>, ptr [[TMP10]], align 1
+; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = sext <vscale x 4 x i8> [[WIDE_LOAD2]] to <vscale x 4 x i32>
 ; CHECK-NOI8MM-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[INDEX]]
 ; CHECK-NOI8MM-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
-; CHECK-NOI8MM-NEXT:    [[TMP8:%.*]] = getelementptr i8, ptr [[TMP6]], i32 16
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
-; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <16 x i8>, ptr [[TMP8]], align 1
-; CHECK-NOI8MM-NEXT:    [[TMP9:%.*]] = zext <16 x i8> [[WIDE_LOAD3]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP10:%.*]] = zext <16 x i8> [[WIDE_LOAD4]] to <16 x i32>
-; CHECK-NOI8MM-NEXT:    [[TMP11:%.*]] = mul <16 x i32> [[TMP9]], [[TMP4]]
-; CHECK-NOI8MM-NEXT:    [[TMP12:%.*]] = mul <16 x i32> [[TMP10]], [[TMP5]]
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP11]])
-; CHECK-NOI8MM-NEXT:    [[PARTIAL_REDUCE5]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI1]], <16 x i32> [[TMP12]])
-; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 32
-; CHECK-NOI8MM-NEXT:    [[TMP13:%.*]] = icmp eq i64 [[INDEX_NEXT]], 1024
-; CHECK-NOI8MM-NEXT:    br i1 [[TMP13]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
+; CHECK-NOI8MM-NEXT:    [[TMP15:%.*]] = call i64 @llvm.vscale.i64()
+; CHECK-NOI8MM-NEXT:    [[TMP16:%.*]] = mul i64 [[TMP15]], 4
+; CHECK-NOI8MM-NEXT:    [[TMP17:%.*]] = getelementptr i8, ptr [[TMP6]], i64 [[TMP16]]
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD3:%.*]] = load <vscale x 4 x i8>, ptr [[TMP7]], align 1
+; CHECK-NOI8MM-NEXT:    [[WIDE_LOAD4:%.*]] = load <vscale x 4 x i8>, ptr [[TMP17]], align 1
+; CHECK-NOI8MM-NEXT:    [[TMP18:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD3]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP19:%.*]] = zext <vscale x 4 x i8> [[WIDE_LOAD4]] to <vscale x 4 x i32>
+; CHECK-NOI8MM-NEXT:    [[TMP20:%.*]] = mul <vscale x 4 x i32> [[TMP18]], [[TMP11]]
+; CHECK-NOI8MM-NEXT:    [[TMP21:%.*]] = mul <vscale x 4 x i32> [[TMP19]], [[TMP12]]
+; CHECK-NOI8MM-NEXT:    [[TMP22]] = add <vscale x 4 x i32> [[TMP20]], [[VEC_PHI]]
+; CHECK-NOI8MM-NEXT:    [[TMP23]] = add <vscale x 4 x i32> [[TMP21]], [[VEC_PHI1]]
+; CHECK-NOI8MM-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP5]]
+; CHECK-NOI8MM-NEXT:    [[TMP24:%.*]] = icmp eq i64 [[INDEX_NEXT]], [[N_VEC]]
+; CHECK-NOI8MM-NEXT:    br i1 [[TMP24]], label [[MIDDLE_BLOCK:%.*]], label [[VECTOR_BODY]], !llvm.loop [[LOOP4:![0-9]+]]
 ; CHECK-NOI8MM:       middle.block:
-; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <4 x i32> [[PARTIAL_REDUCE5]], [[PARTIAL_REDUCE]]
-; CHECK-NOI8MM-NEXT:    [[TMP14:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[BIN_RDX]])
-; CHECK-NOI8MM-NEXT:    br i1 true, label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
+; CHECK-NOI8MM-NEXT:    [[BIN_RDX:%.*]] = add <vscale x 4 x i32> [[TMP23]], [[TMP22]]
+; CHECK-NOI8MM-NEXT:    [[TMP25:%.*]] = call i32 @llvm.vector.reduce.add.nxv4i32(<vscale x 4 x i32> [[BIN_RDX]])
+; CHECK-NOI8MM-NEXT:    [[CMP_N:%.*]] = icmp eq i64 1024, [[N_VEC]]
+; CHECK-NOI8MM-NEXT:    br i1 [[CMP_N]], label [[FOR_EXIT:%.*]], label [[SCALAR_PH]]
 ; CHECK-NOI8MM:       scalar.ph:
 ;
 entry:

   if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
+      (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
Collaborator

According to the specification, the instruction is undefined if !i8mm || (!sve && !sme), which means the instruction is defined if i8mm && (sve || sme). So this needs an additional && ST->isSVEorStreamingSVEAvailable()

Contributor

I guess this should be:

    (OpAExtend != OpBExtend && (!ST->hasMatMulInt8() || !ST->isSVEorStreamingSVEAvailable())))

?

Collaborator

It seems I forgot that there was also a corresponding NEON instruction. The code above already checks whether SVE/Streaming-SVE is available when the type is scalable, or whether NEON is available when the type is fixed-length, so only checking for i8mm here should indeed be sufficient. Could you add a test case for the NEON case?
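
A paraphrased sketch of the ordering described above (the surrounding availability checks are assumptions about the pre-existing code, not a quote of it): earlier checks in getPartialReductionCost already gate scalable VFs on SVE/Streaming-SVE and fixed-length VFs on NEON, so the mixed-extend check below them only needs to look at +i8mm.

    // Paraphrased sketch, not the exact upstream code.
    if (VF.isScalable() && !ST->isSVEorStreamingSVEAvailable())
      return Invalid;  // scalable VFs already require SVE or streaming SVE
    if (VF.isFixed() && !ST->isNeonAvailable())
      return Invalid;  // fixed-length VFs already require NEON
    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
        (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
      return Invalid;  // mixed extends lower to usdot/sudot, which need +i8mm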

@@ -225,8 +225,203 @@ for.exit: ; preds = %for.body
ret i32 %add
}

define i32 @dotp_z_s_neon(ptr %a, ptr %b) #1 {
Collaborator

Very minor nit: does it make more sense to call this usdot_neon? (Similar question for the other one.)

Collaborator Author

Done. I've called this one sudot_neon since it's mul(sext, zext).

Comment on lines +2 to +3
; RUN: opt -passes=loop-vectorize -enable-epilogue-vectorization=false -mattr=+i8mm,+dotprod --vectorizer-maximize-bandwidth -S < %s | FileCheck %s
; RUN: opt -passes=loop-vectorize -enable-epilogue-vectorization=false -mattr=+dotprod -S --vectorizer-maximize-bandwidth < %s | FileCheck %s --check-prefix=CHECK-NOI8MM
Collaborator

Is this necessary? If not, can you remove it?

Collaborator Author

It is, since otherwise we don't test the scalable pathway, as the maximum VF available is vscale x 4.

Collaborator

Sorry for the unfortunate timing, but I've just come to realise that this flag is probably required if you want the SVE sudot/usdot tests to enforce scalable vectors. Otherwise it would still pick a fixed-length VF for the +i8mm case, whereas you actually want to test that this specifically uses a scalable VF. Could you remove the last commit (486aa4c)?

Collaborator Author

Done.

   if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
-      (OpAExtend != OpBExtend && !ST->hasMatMulInt8() &&
-       !ST->isSVEorStreamingSVEAvailable()))
+      (OpAExtend != OpBExtend && VF.isFixed() && !ST->hasMatMulInt8()))
Collaborator

I think VF.isFixed() can be removed, because SVE also has usdot instructions (if i8mm is available)
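
For reference, the check as it reads after this change (matching the hunk at the top of this PR) drops both VF.isFixed() and the SVE/Streaming-SVE test:

    if (OpAExtend == TTI::PR_None || OpBExtend == TTI::PR_None ||
        (OpAExtend != OpBExtend && !ST->hasMatMulInt8()))
      return Invalid;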

Collaborator Author

Done.

@SamTebbs33 SamTebbs33 merged commit 2876dbc into llvm:main May 1, 2025
11 checks passed
@SamTebbs33 SamTebbs33 deleted the mixed-pred-i8mm branch May 1, 2025 15:06
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025

Partial reductions with mixed extends should only be allowed if i8mm is present.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025

Partial reductions with mixed extends should only be allowed if i8mm is present.