-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[InstCombine] Pull shuffles out of binops with splatted ops #137948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[InstCombine] Pull shuffles out of binops with splatted ops #137948
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) Changes: Given a binary op on a splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop: define <4 x i32> @f(i32 %x) {
%x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
%x.splat = shufflevector <4 x i32> %x.insert, <4 x i32> poison, <4 x i32> zeroinitializer
%res = add <4 x i32> %x.splat, splat (i32 42)
ret <4 x i32> %res
} define <4 x i32> @f(i32 %x) {
%x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
%1 = add <4 x i32> %x.insert, <i32 42, i32 poison, i32 poison, i32 poison>
%res = shufflevector <4 x i32> %1, <4 x i32> poison, <4 x i32> zeroinitializer
ret <4 x i32> %res
} However, this currently only operates on fixed-length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g.:
Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp, which would prevent the need for #137786. This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case. So whilst the existing combine looks like: Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask) This patch adds: Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0) I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine. Full diff: https://github.com/llvm/llvm-project/pull/137948.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f807f5f4519fc..ecb3899eafe67 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2271,6 +2271,31 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
}
}
+ // Similar to the combine above, but handles the case for scalable vectors
+ // where both V1 and C are splats.
+ //
+ // Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
+ if (isa<ScalableVectorType>(Inst.getType()) &&
+ match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
+ m_ZeroMask())),
+ m_ImmConstant(C)))) {
+ if (Constant *Splat = C->getSplatValue()) {
+ bool ConstOp1 = isa<Constant>(RHS);
+ VectorType *V1Ty = cast<VectorType>(V1->getType());
+ Constant *NewC = ConstantVector::getSplat(V1Ty->getElementCount(), Splat);
+
+ Value *NewLHS = ConstOp1 ? V1 : NewC;
+ Value *NewRHS = ConstOp1 ? NewC : V1;
+ Value *XY = Builder.CreateBinOp(Opcode, NewLHS, NewRHS);
+ if (auto *BO = dyn_cast<BinaryOperator>(XY))
+ BO->copyIRFlags(&Inst);
+
+ VectorType *VTy = cast<VectorType>(Inst.getType());
+ SmallVector<int> NewM(VTy->getElementCount().getKnownMinValue(), 0);
+ return new ShuffleVectorInst(XY, NewM);
+ }
+ }
+
// Try to reassociate to sink a splat shuffle after a binary operation.
if (Inst.isAssociative() && Inst.isCommutative()) {
// Canonicalize shuffle operand as LHS.
diff --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index feba952919b9a..61236df80bfa6 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -282,8 +282,8 @@ define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
; CHECK-LABEL: @test13_scalable_scalable(
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
-; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT: [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], splat (i64 3)
+; CHECK-NEXT: [[TMP3:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLATINSERT]], splat (i64 3)
+; CHECK-NEXT: [[A_IDX:%.*]] = shufflevector <vscale x 2 x i64> [[TMP3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 4
; CHECK-NEXT: [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index c6329af164623..926f272d36a4b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -1789,3 +1789,26 @@ define <4 x i32> @PR46872(<4 x i32> %x) {
ret <4 x i32> %a
}
+define <vscale x 4 x i32> @scalable_splat_binop_constant_rhs(<vscale x 4 x i32> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_rhs(
+; CHECK-NEXT: [[R1:%.*]] = add <vscale x 4 x i32> [[R:%.*]], splat (i32 42)
+; CHECK-NEXT: [[R2:%.*]] = shufflevector <vscale x 4 x i32> [[R1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: ret <vscale x 4 x i32> [[R2]]
+;
+
+ %splatx = shufflevector <vscale x 4 x i32> %x, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+ %r = add <vscale x 4 x i32> %splatx, splat (i32 42)
+ ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 4 x float> @scalable_splat_binop_constant_lhs(<vscale x 4 x float> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_lhs(
+; CHECK-NEXT: [[R1:%.*]] = fadd <vscale x 4 x float> [[R:%.*]], splat (float 4.200000e+01)
+; CHECK-NEXT: [[R2:%.*]] = shufflevector <vscale x 4 x float> [[R1]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: ret <vscale x 4 x float> [[R2]]
+;
+
+ %splatx = shufflevector <vscale x 4 x float> %x, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+ %r = fadd <vscale x 4 x float> splat (float 42.0), %splatx
+ ret <vscale x 4 x float> %r
+}
|
@@ -2271,6 +2271,27 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { | |||
} | |||
} | |||
|
|||
// Similar to the combine above, but handles the case for scalable vectors | |||
// where both V1 and C are splats. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// where both V1 and C are splats. | |
// where both shuffle(V1, 0) and C are splats. |
V1
is not required to be a splat vector.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Given a binary op on splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in
InstCombinerImpl::foldVectorBinop
: However, this currently only operates on fixed-length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g.:
Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in
scalarizeBinopOrCmp
, which would prevent the need for #137786. This also brings it in line with scalable binary ops with two non-constant operands: https://godbolt.org/z/M9f7ebzca This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case:
So whilst the existing combine looks like:
Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask)
This patch adds:
Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine.