Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[InstCombine] Pull shuffles out of binops with splatted ops #137948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

lukel97
Copy link
Contributor

@lukel97 lukel97 commented Apr 30, 2025

Given a binary op on splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop:

define <4 x i32> @f(i32 %x) {
  %x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
  %x.splat = shufflevector <4 x i32> %x.insert, <4 x i32> poison, <4 x i32> zeroinitializer
  %res = add <4 x i32> %x.splat, splat (i32 42)
  ret <4 x i32> %res
}
define <4 x i32> @f(i32 %x) {
  %x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
  %1 = add <4 x i32> %x.insert, <i32 42, i32 poison, i32 poison, i32 poison>
  %res = shufflevector <4 x i32> %1, <4 x i32> poison, <4 x i32> zeroinitializer
  ret <4 x i32> %res
}

However, this currently only operates on fixed length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g:

define <vscale x 4 x i32> @f(i32 %x) {
  %x.insert = insertelement <vscale x 4 x i32> poison, i32 %x, i64 0
  %x.splat = shufflevector <vscale x 4 x i32> %x.insert, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
  %res = add <vscale x 4 x i32> %x.splat, splat (i32 42)
  ret <vscale x 4 x i32> %res
}

Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp, which would prevent the need for #137786. This also brings it in line for scalable binary ops with two non-constant operands: https://godbolt.org/z/M9f7ebzca

This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case:

So the whilst the existing combine looks like: Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask)

This patch adds: Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)

I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine.

@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Changes

Given a binary op on splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop:

define &lt;4 x i32&gt; @<!-- -->f(i32 %x) {
  %x.insert = insertelement &lt;4 x i32&gt; poison, i32 %x, i64 0
  %x.splat = shufflevector &lt;4 x i32&gt; %x.insert, &lt;4 x i32&gt; poison, &lt;4 x i32&gt; zeroinitializer
  %res = add &lt;4 x i32&gt; %x.splat, splat (i32 42)
  ret &lt;4 x i32&gt; %res
}
define &lt;4 x i32&gt; @<!-- -->f(i32 %x) {
  %x.insert = insertelement &lt;4 x i32&gt; poison, i32 %x, i64 0
  %1 = add &lt;4 x i32&gt; %x.insert, &lt;i32 42, i32 poison, i32 poison, i32 poison&gt;
  %res = shufflevector &lt;4 x i32&gt; %1, &lt;4 x i32&gt; poison, &lt;4 x i32&gt; zeroinitializer
  ret &lt;4 x i32&gt; %res
}

However, this currently only operates on fixed length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g:

define &lt;vscale x 4 x i32&gt; @<!-- -->f(i32 %x) {
  %x.insert = insertelement &lt;vscale x 4 x i32&gt; poison, i32 %x, i64 0
  %x.splat = shufflevector &lt;vscale x 4 x i32&gt; %x.insert, &lt;vscale x 4 x i32&gt; poison, &lt;vscale x 4 x i32&gt; zeroinitializer
  %res = add &lt;vscale x 4 x i32&gt; %x.splat, splat (i32 42)
  ret &lt;vscale x 4 x i32&gt; %res
}

Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp, which would prevent the need for #137786. This also brings it in line for scalable binary ops with two non-constant operands: https://godbolt.org/z/M9f7ebzca

This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case:

So the whilst the existing combine looks like: Op(shuffle(V1, Mask), C) -&gt; shuffle(Op(V1, NewC), Mask)

This patch adds: Op(shuffle(V1, 0), (splat C)) -&gt; shuffle(Op(V1, (splat C)), 0)

I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine.


Full diff: https://github.com/llvm/llvm-project/pull/137948.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+25)
  • (modified) llvm/test/Transforms/InstCombine/getelementptr.ll (+2-2)
  • (modified) llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll (+23)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f807f5f4519fc..ecb3899eafe67 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2271,6 +2271,31 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
     }
   }
 
+  // Similar to the combine above, but handles the case for scalable vectors
+  // where both V1 and C are splats.
+  //
+  // Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
+  if (isa<ScalableVectorType>(Inst.getType()) &&
+      match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
+                                                m_ZeroMask())),
+                             m_ImmConstant(C)))) {
+    if (Constant *Splat = C->getSplatValue()) {
+      bool ConstOp1 = isa<Constant>(RHS);
+      VectorType *V1Ty = cast<VectorType>(V1->getType());
+      Constant *NewC = ConstantVector::getSplat(V1Ty->getElementCount(), Splat);
+
+      Value *NewLHS = ConstOp1 ? V1 : NewC;
+      Value *NewRHS = ConstOp1 ? NewC : V1;
+      Value *XY = Builder.CreateBinOp(Opcode, NewLHS, NewRHS);
+      if (auto *BO = dyn_cast<BinaryOperator>(XY))
+        BO->copyIRFlags(&Inst);
+
+      VectorType *VTy = cast<VectorType>(Inst.getType());
+      SmallVector<int> NewM(VTy->getElementCount().getKnownMinValue(), 0);
+      return new ShuffleVectorInst(XY, NewM);
+    }
+  }
+
   // Try to reassociate to sink a splat shuffle after a binary operation.
   if (Inst.isAssociative() && Inst.isCommutative()) {
     // Canonicalize shuffle operand as LHS.
diff --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index feba952919b9a..61236df80bfa6 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -282,8 +282,8 @@ define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
 define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
 ; CHECK-LABEL: @test13_scalable_scalable(
 ; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
-; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT:    [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], splat (i64 3)
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLATINSERT]], splat (i64 3)
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <vscale x 2 x i64> [[TMP3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
 ; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index c6329af164623..926f272d36a4b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -1789,3 +1789,26 @@ define <4 x i32> @PR46872(<4 x i32> %x) {
   ret <4 x i32> %a
 }
 
+define <vscale x 4 x i32> @scalable_splat_binop_constant_rhs(<vscale x 4 x i32> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_rhs(
+; CHECK-NEXT:    [[R1:%.*]] = add <vscale x 4 x i32> [[R:%.*]], splat (i32 42)
+; CHECK-NEXT:    [[R2:%.*]] = shufflevector <vscale x 4 x i32> [[R1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[R2]]
+;
+
+  %splatx = shufflevector <vscale x 4 x i32> %x, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %r = add <vscale x 4 x i32> %splatx, splat (i32 42)
+  ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 4 x float> @scalable_splat_binop_constant_lhs(<vscale x 4 x float> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_lhs(
+; CHECK-NEXT:    [[R1:%.*]] = fadd <vscale x 4 x float> [[R:%.*]], splat (float 4.200000e+01)
+; CHECK-NEXT:    [[R2:%.*]] = shufflevector <vscale x 4 x float> [[R1]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 4 x float> [[R2]]
+;
+
+  %splatx = shufflevector <vscale x 4 x float> %x, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %r = fadd <vscale x 4 x float> splat (float 42.0), %splatx
+  ret <vscale x 4 x float> %r
+}

@@ -2271,6 +2271,27 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
}
}

// Similar to the combine above, but handles the case for scalable vectors
// where both V1 and C are splats.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// where both V1 and C are splats.
// where both shuffle(V1, 0) and C are splats.

V1 is not required to be a splat vector.

// Similar to the combine above, but handles the case for scalable vectors
// where both V1 and C are splats.
//
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This transformation may introduce immediate UB.
Counterexample: sdiv (splat C), shuffle(V1, 0) -> shuffle(Op((splat C), V1), 0) is invalid if V1 has a zero value at a non-zero index.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's a check for isSafeToSpeculativelyExecuteWithVariableReplaced earlier on in the function which the previous combine also relies on, so I think this should also be covered in the new combine. I've added a negative test in e9d51e7

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but the transform you're modeling off just above has several other checks for introducing UB. The mask logic is one, the getSafeVectorConstantForBinop is another. I'm not sure the mask reasoning is needed for this one, but the divrem/shift definitely is? If not, why is the previous block required?

Copy link
Contributor Author

@lukel97 lukel97 May 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

getSafeVectorConstantForBinop is needed in the previous block because if the mask has undef then the "new constant" will have undef for those lanes, which need swapped out e.g:

  %shuffle = shufflevector <4 x i32> %v1, <4 x i32> poison, <4 x i32> <i32 undef, i32 1, i32 2, i32 undef>
  %x = udiv %shuffle, <1, 2, 3, 4>

  ; undefs need replacing with safe constant
  %x = udiv %v1, <undef, 2, 3, undef>
  %shuffle = shufflevector <4 x i32> %t1, <4 x i32> poison, <4 x i32> <i32 undef, i32 1, i32 2, i32 undef>

We don't have that problem in the new combine because the mask is always zero

Copy link
Member

@dtcxzyw dtcxzyw left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you!

// Similar to the combine above, but handles the case for scalable vectors
// where both V1 and C are splats.
//
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but the transform you're modeling off just above has several other checks for introducing UB. The mask logic is one, the getSafeVectorConstantForBinop is another. I'm not sure the mask reasoning is needed for this one, but the divrem/shift definitely is? If not, why is the previous block required?

// where both shuffle(V1, 0) and C are splats.
//
// Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
if (isa<ScalableVectorType>(Inst.getType()) &&
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code structure wise: Remove the scalable vector check, move this above the other transform. That way fixed vector splats go through this too, and we have a much higher chance of finding any bugs.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally had this above the other transform but it results in regressions because we lose information about which lanes are poison when splatting. E.g. the original transform will do:

Op(shuffle(V1, 0), <42, poison>) --> shuffle(Op(V1, <42, poison>), 0)

But this transform which is less general will do

Op(shuffle(V1, 0), <42, poison>) --> shuffle(Op(V1, <42, 42>), 0)

The scalable vector check is an early exit just to avoid calling getSplatValue unnecessarily. But if I'm remembering correctly it's actually an invariant too, i.e. no fixed length vectors will actually pass through the new combine because they will have all been caught by the more general combine above.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants