-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[InstCombine] Pull shuffles out of binops with splatted ops #137948
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[InstCombine] Pull shuffles out of binops with splatted ops #137948
Conversation
@llvm/pr-subscribers-llvm-transforms Author: Luke Lau (lukel97) Changes: Given a binary op on a splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop: define <4 x i32> @f(i32 %x) {
%x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
%x.splat = shufflevector <4 x i32> %x.insert, <4 x i32> poison, <4 x i32> zeroinitializer
%res = add <4 x i32> %x.splat, splat (i32 42)
ret <4 x i32> %res
}
define <4 x i32> @f(i32 %x) {
%x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
%1 = add <4 x i32> %x.insert, <i32 42, i32 poison, i32 poison, i32 poison>
%res = shufflevector <4 x i32> %1, <4 x i32> poison, <4 x i32> zeroinitializer
ret <4 x i32> %res
} However, this currently only operates on fixed length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g:
Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp. This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case. So whilst the existing combine looks like: Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask), this patch adds: Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0). I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine. Full diff: https://github.com/llvm/llvm-project/pull/137948.diff 3 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f807f5f4519fc..ecb3899eafe67 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2271,6 +2271,31 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
}
}
+ // Similar to the combine above, but handles the case for scalable vectors
+ // where both V1 and C are splats.
+ //
+ // Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
+ if (isa<ScalableVectorType>(Inst.getType()) &&
+ match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
+ m_ZeroMask())),
+ m_ImmConstant(C)))) {
+ if (Constant *Splat = C->getSplatValue()) {
+ bool ConstOp1 = isa<Constant>(RHS);
+ VectorType *V1Ty = cast<VectorType>(V1->getType());
+ Constant *NewC = ConstantVector::getSplat(V1Ty->getElementCount(), Splat);
+
+ Value *NewLHS = ConstOp1 ? V1 : NewC;
+ Value *NewRHS = ConstOp1 ? NewC : V1;
+ Value *XY = Builder.CreateBinOp(Opcode, NewLHS, NewRHS);
+ if (auto *BO = dyn_cast<BinaryOperator>(XY))
+ BO->copyIRFlags(&Inst);
+
+ VectorType *VTy = cast<VectorType>(Inst.getType());
+ SmallVector<int> NewM(VTy->getElementCount().getKnownMinValue(), 0);
+ return new ShuffleVectorInst(XY, NewM);
+ }
+ }
+
// Try to reassociate to sink a splat shuffle after a binary operation.
if (Inst.isAssociative() && Inst.isCommutative()) {
// Canonicalize shuffle operand as LHS.
diff --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index feba952919b9a..61236df80bfa6 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -282,8 +282,8 @@ define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
; CHECK-LABEL: @test13_scalable_scalable(
; CHECK-NEXT: [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
-; CHECK-NEXT: [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT: [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], splat (i64 3)
+; CHECK-NEXT: [[TMP3:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLATINSERT]], splat (i64 3)
+; CHECK-NEXT: [[A_IDX:%.*]] = shufflevector <vscale x 2 x i64> [[TMP3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[TMP2:%.*]] = shl i64 [[TMP1]], 4
; CHECK-NEXT: [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index c6329af164623..926f272d36a4b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -1789,3 +1789,26 @@ define <4 x i32> @PR46872(<4 x i32> %x) {
ret <4 x i32> %a
}
+define <vscale x 4 x i32> @scalable_splat_binop_constant_rhs(<vscale x 4 x i32> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_rhs(
+; CHECK-NEXT: [[R1:%.*]] = add <vscale x 4 x i32> [[R:%.*]], splat (i32 42)
+; CHECK-NEXT: [[R2:%.*]] = shufflevector <vscale x 4 x i32> [[R1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: ret <vscale x 4 x i32> [[R2]]
+;
+
+ %splatx = shufflevector <vscale x 4 x i32> %x, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+ %r = add <vscale x 4 x i32> %splatx, splat (i32 42)
+ ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 4 x float> @scalable_splat_binop_constant_lhs(<vscale x 4 x float> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_lhs(
+; CHECK-NEXT: [[R1:%.*]] = fadd <vscale x 4 x float> [[R:%.*]], splat (float 4.200000e+01)
+; CHECK-NEXT: [[R2:%.*]] = shufflevector <vscale x 4 x float> [[R1]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT: ret <vscale x 4 x float> [[R2]]
+;
+
+ %splatx = shufflevector <vscale x 4 x float> %x, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+ %r = fadd <vscale x 4 x float> splat (float 42.0), %splatx
+ ret <vscale x 4 x float> %r
+}
|
@@ -2271,6 +2271,27 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) { | |||
} | |||
} | |||
|
|||
// Similar to the combine above, but handles the case for scalable vectors | |||
// where both V1 and C are splats. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
// where both V1 and C are splats. | |
// where both shuffle(V1, 0) and C are splats. |
V1 is not required to be a splat vector.
// Similar to the combine above, but handles the case for scalable vectors | ||
// where both V1 and C are splats. | ||
// | ||
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This transformation may introduce immediate UB.
Counterexample: sdiv (splat C), shuffle(V1, 0) -> shuffle(Op((splat C), V1), 0) is invalid if V1 has a zero value at a non-zero index.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's a check for isSafeToSpeculativelyExecuteWithVariableReplaced earlier on in the function, which the previous combine also relies on, so I think this should also be covered in the new combine. I've added a negative test in e9d51e7
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but the transform you're modeling off just above has several other checks for introducing UB. The mask logic is one, the getSafeVectorConstantForBinop is another. I'm not sure the mask reasoning is needed for this one, but the divrem/shift definitely is? If not, why is the previous block required?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
getSafeVectorConstantForBinop is needed in the previous block because if the mask has undef then the "new constant" will have undef for those lanes, which need to be swapped out, e.g.:
%shuffle = shufflevector <4 x i32> %v1, <4 x i32> poison, <4 x i32> <i32 undef, i32 1, i32 2, i32 undef>
%x = udiv %shuffle, <1, 2, 3, 4>
; undefs need replacing with safe constant
%x = udiv %v1, <undef, 2, 3, undef>
%shuffle = shufflevector <4 x i32> %t1, <4 x i32> poison, <4 x i32> <i32 undef, i32 1, i32 2, i32 undef>
We don't have that problem in the new combine because the mask is always zero
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you!
// Similar to the combine above, but handles the case for scalable vectors | ||
// where both V1 and C are splats. | ||
// | ||
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Right, but the transform you're modeling off just above has several other checks for introducing UB. The mask logic is one, the getSafeVectorConstantForBinop is another. I'm not sure the mask reasoning is needed for this one, but the divrem/shift definitely is? If not, why is the previous block required?
// where both shuffle(V1, 0) and C are splats. | ||
// | ||
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0) | ||
if (isa<ScalableVectorType>(Inst.getType()) && |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code structure wise: Remove the scalable vector check, move this above the other transform. That way fixed vector splats go through this too, and we have a much higher chance of finding any bugs.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I originally had this above the other transform but it results in regressions because we lose information about which lanes are poison when splatting. E.g. the original transform will do:
Op(shuffle(V1, 0), <42, poison>) --> shuffle(Op(V1, <42, poison>), 0)
But this transform which is less general will do
Op(shuffle(V1, 0), <42, poison>) --> shuffle(Op(V1, <42, 42>), 0)
The scalable vector check is an early exit just to avoid calling getSplatValue unnecessarily. But if I'm remembering correctly it's actually an invariant too, i.e. no fixed length vectors will actually pass through the new combine because they will have all been caught by the more general combine above.
Given a binary op on a splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop. However, this currently only operates on fixed length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g:
Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp, which would prevent the need for #137786. This also brings it in line with scalable binary ops with two non-constant operands: https://godbolt.org/z/M9f7ebzca
This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case:
So whilst the existing combine looks like:
Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask)
This patch adds:
Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine.