[InstCombine] Pull shuffles out of binops with splatted ops #137948

Open
wants to merge 4 commits into
base: main

@lukel97 lukel97 commented Apr 30, 2025

Given a binary op on a splatted vector and a splatted constant, InstCombine will normally pull the shuffle out in InstCombinerImpl::foldVectorBinop:

define <4 x i32> @f(i32 %x) {
  %x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
  %x.splat = shufflevector <4 x i32> %x.insert, <4 x i32> poison, <4 x i32> zeroinitializer
  %res = add <4 x i32> %x.splat, splat (i32 42)
  ret <4 x i32> %res
}

define <4 x i32> @f(i32 %x) {
  %x.insert = insertelement <4 x i32> poison, i32 %x, i64 0
  %1 = add <4 x i32> %x.insert, <i32 42, i32 poison, i32 poison, i32 poison>
  %res = shufflevector <4 x i32> %1, <4 x i32> poison, <4 x i32> zeroinitializer
  ret <4 x i32> %res
}

However, this currently only operates on fixed-length vectors. Splats of scalable vectors don't currently have their shuffle pulled out, e.g.:

define <vscale x 4 x i32> @f(i32 %x) {
  %x.insert = insertelement <vscale x 4 x i32> poison, i32 %x, i64 0
  %x.splat = shufflevector <vscale x 4 x i32> %x.insert, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
  %res = add <vscale x 4 x i32> %x.splat, splat (i32 42)
  ret <vscale x 4 x i32> %res
}

Having this canonical form with the shuffle pulled out is important as VectorCombine relies on it in order to scalarize binary ops in scalarizeBinopOrCmp, which would prevent the need for #137786. This also brings it in line with scalable binary ops with two non-constant operands: https://godbolt.org/z/M9f7ebzca

This adds a combine just after the fixed-length version, but restricted to splats at index 0 so that it also handles the scalable case:

So whilst the existing combine looks like: Op(shuffle(V1, Mask), C) -> shuffle(Op(V1, NewC), Mask)

This patch adds: Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
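
As a rough illustration of why the rewrite preserves the result, here is a minimal standalone C++ model. It is not LLVM code: `Vec`, `splatLane0`, `splatConst`, `binop`, `before` and `after` are hypothetical names, and a fixed 4-lane array stands in for a scalable vector on the assumption that the per-lane reasoning carries over. Lanes 1..3 of V1 play the role of the poison lanes and never reach the result.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>

// Fixed-width stand-in for a vector type (assumption: 4 lanes of i32).
using Vec = std::array<std::uint32_t, 4>;

// Models shuffle(V, zeroinitializer): broadcast lane 0 to every lane.
Vec splatLane0(const Vec &V) {
  Vec R;
  R.fill(V[0]);
  return R;
}

// Models splat (C): every lane holds the same constant.
Vec splatConst(std::uint32_t C) {
  Vec R;
  R.fill(C);
  return R;
}

// Lane-wise binary operation.
template <typename Op> Vec binop(Op F, const Vec &A, const Vec &B) {
  Vec R;
  for (std::size_t I = 0; I < R.size(); ++I)
    R[I] = F(A[I], B[I]);
  return R;
}

// Before the combine: Op(shuffle(V1, 0), (splat C))
template <typename Op> Vec before(Op F, const Vec &V1, std::uint32_t C) {
  return binop(F, splatLane0(V1), splatConst(C));
}

// After the combine: shuffle(Op(V1, (splat C)), 0)
template <typename Op> Vec after(Op F, const Vec &V1, std::uint32_t C) {
  return splatLane0(binop(F, V1, splatConst(C)));
}
```

Both forms only ever observe lane 0 of V1, so they agree for any lane-wise op. In this model, `before(std::plus<std::uint32_t>{}, V1, 42)` and `after(std::plus<std::uint32_t>{}, V1, 42)` compare equal for any `V1`.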

I think this could be generalized to other splat indexes that aren't zero, but I think it would be dead code since only fixed-length vectors can have non-zero shuffle indices, which would be covered by the existing combine.

@llvmbot
llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-llvm-transforms

Author: Luke Lau (lukel97)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/137948.diff

3 Files Affected:

  • (modified) llvm/lib/Transforms/InstCombine/InstructionCombining.cpp (+25)
  • (modified) llvm/test/Transforms/InstCombine/getelementptr.ll (+2-2)
  • (modified) llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll (+23)
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index f807f5f4519fc..ecb3899eafe67 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -2271,6 +2271,31 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
     }
   }
 
+  // Similar to the combine above, but handles the case for scalable vectors
+  // where both V1 and C are splats.
+  //
+  // Op(shuffle(V1, 0), (splat C)) -> shuffle(Op(V1, (splat C)), 0)
+  if (isa<ScalableVectorType>(Inst.getType()) &&
+      match(&Inst, m_c_BinOp(m_OneUse(m_Shuffle(m_Value(V1), m_Poison(),
+                                                m_ZeroMask())),
+                             m_ImmConstant(C)))) {
+    if (Constant *Splat = C->getSplatValue()) {
+      bool ConstOp1 = isa<Constant>(RHS);
+      VectorType *V1Ty = cast<VectorType>(V1->getType());
+      Constant *NewC = ConstantVector::getSplat(V1Ty->getElementCount(), Splat);
+
+      Value *NewLHS = ConstOp1 ? V1 : NewC;
+      Value *NewRHS = ConstOp1 ? NewC : V1;
+      Value *XY = Builder.CreateBinOp(Opcode, NewLHS, NewRHS);
+      if (auto *BO = dyn_cast<BinaryOperator>(XY))
+        BO->copyIRFlags(&Inst);
+
+      VectorType *VTy = cast<VectorType>(Inst.getType());
+      SmallVector<int> NewM(VTy->getElementCount().getKnownMinValue(), 0);
+      return new ShuffleVectorInst(XY, NewM);
+    }
+  }
+
   // Try to reassociate to sink a splat shuffle after a binary operation.
   if (Inst.isAssociative() && Inst.isCommutative()) {
     // Canonicalize shuffle operand as LHS.
diff --git a/llvm/test/Transforms/InstCombine/getelementptr.ll b/llvm/test/Transforms/InstCombine/getelementptr.ll
index feba952919b9a..61236df80bfa6 100644
--- a/llvm/test/Transforms/InstCombine/getelementptr.ll
+++ b/llvm/test/Transforms/InstCombine/getelementptr.ll
@@ -282,8 +282,8 @@ define <2 x i1> @test13_fixed_scalable(i64 %X, ptr %P, <2 x i64> %y) nounwind {
 define <vscale x 2 x i1> @test13_scalable_scalable(i64 %X, ptr %P, <vscale x 2 x i64> %y) nounwind {
 ; CHECK-LABEL: @test13_scalable_scalable(
 ; CHECK-NEXT:    [[DOTSPLATINSERT:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[X:%.*]], i64 0
-; CHECK-NEXT:    [[DOTSPLAT:%.*]] = shufflevector <vscale x 2 x i64> [[DOTSPLATINSERT]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
-; CHECK-NEXT:    [[A_IDX:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLAT]], splat (i64 3)
+; CHECK-NEXT:    [[TMP3:%.*]] = shl nsw <vscale x 2 x i64> [[DOTSPLATINSERT]], splat (i64 3)
+; CHECK-NEXT:    [[A_IDX:%.*]] = shufflevector <vscale x 2 x i64> [[TMP3]], <vscale x 2 x i64> poison, <vscale x 2 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.vscale.i64()
 ; CHECK-NEXT:    [[TMP2:%.*]] = shl i64 [[TMP1]], 4
 ; CHECK-NEXT:    [[DOTSPLATINSERT1:%.*]] = insertelement <vscale x 2 x i64> poison, i64 [[TMP2]], i64 0
diff --git a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
index c6329af164623..926f272d36a4b 100644
--- a/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
+++ b/llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll
@@ -1789,3 +1789,26 @@ define <4 x i32> @PR46872(<4 x i32> %x) {
   ret <4 x i32> %a
 }
 
+define <vscale x 4 x i32> @scalable_splat_binop_constant_rhs(<vscale x 4 x i32> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_rhs(
+; CHECK-NEXT:    [[R1:%.*]] = add <vscale x 4 x i32> [[R:%.*]], splat (i32 42)
+; CHECK-NEXT:    [[R2:%.*]] = shufflevector <vscale x 4 x i32> [[R1]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 4 x i32> [[R2]]
+;
+
+  %splatx = shufflevector <vscale x 4 x i32> %x, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
+  %r = add <vscale x 4 x i32> %splatx, splat (i32 42)
+  ret <vscale x 4 x i32> %r
+}
+
+define <vscale x 4 x float> @scalable_splat_binop_constant_lhs(<vscale x 4 x float> %x) {
+; CHECK-LABEL: @scalable_splat_binop_constant_lhs(
+; CHECK-NEXT:    [[R1:%.*]] = fadd <vscale x 4 x float> [[R:%.*]], splat (float 4.200000e+01)
+; CHECK-NEXT:    [[R2:%.*]] = shufflevector <vscale x 4 x float> [[R1]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+; CHECK-NEXT:    ret <vscale x 4 x float> [[R2]]
+;
+
+  %splatx = shufflevector <vscale x 4 x float> %x, <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
+  %r = fadd <vscale x 4 x float> splat (float 42.0), %splatx
+  ret <vscale x 4 x float> %r
+}

@@ -2271,6 +2271,27 @@ Instruction *InstCombinerImpl::foldVectorBinop(BinaryOperator &Inst) {
}
}

// Similar to the combine above, but handles the case for scalable vectors
// where both V1 and C are splats.
Suggested change
-  // where both V1 and C are splats.
+  // where both shuffle(V1, 0) and C are splats.

V1 is not required to be a splat vector.

@dtcxzyw dtcxzyw left a comment

LGTM. Thank you!

@preames preames left a comment

LGTM.
