Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[LoopVectorizer] Bundle partial reductions inside VPMulAccumulateReductionRecipe #136173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: users/SamTebbs33/elvis-vp-arm-mve-transform
Choose a base branch
from

Conversation

SamTebbs33
Copy link
Collaborator

This PR bundles compatible partial reductions into a VPMulAccumulateReductionRecipe so that the cost of the extends and mul can be hidden. Depends on #113903.

At the moment only partial reductions with the same extension type are supported as VPMulAccumulateReductionRecipe only supports such a scheme.

@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-backend-aarch64
@llvm/pr-subscribers-llvm-analysis

@llvm/pr-subscribers-llvm-transforms

Author: Sam Tebbs (SamTebbs33)

Changes

This PR bundles compatible partial reductions into a VPMulAccumulateReductionRecipe so that the cost of the extends and mul can be hidden. Depends on #113903.

At the moment only partial reductions with the same extension type are supported as VPMulAccumulateReductionRecipe only supports such a scheme.


Patch is 206.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136173.diff

10 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+13)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+62-51)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+19-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+36-4)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+83-163)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+261-501)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+93-94)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+154-6)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 99e21aca97631..78ab6d3faaf33 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -219,6 +219,8 @@ class TargetTransformInfo {
   /// Get the kind of extension that an instruction represents.
   static PartialReductionExtendKind
   getPartialReductionExtendKind(Instruction *I);
+  static PartialReductionExtendKind
+  getPartialReductionExtendKind(Instruction::CastOps ExtOpcode);
 
   /// Construct a TTI object using a type implementing the \c Concept
   /// API below.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4df551aca30a7..fdef0c484e12a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -993,6 +993,19 @@ TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
   return PR_None;
 }
 
+TargetTransformInfo::PartialReductionExtendKind
+TargetTransformInfo::getPartialReductionExtendKind(
+    Instruction::CastOps ExtOpcode) {
+  switch (ExtOpcode) {
+  case Instruction::CastOps::ZExt:
+    return PR_ZeroExtend;
+  case Instruction::CastOps::SExt:
+    return PR_SignExtend;
+  default:
+    return PR_None;
+  }
+}
+
 TTI::CastContextHint
 TargetTransformInfo::getCastContextHint(const Instruction *I) {
   if (!I)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0233a32f99e6f..030bd86847e1c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8879,17 +8879,15 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
     ReductionOpcode = Instruction::Add;
   }
 
+  VPValue *Cond = nullptr;
   if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
     assert((ReductionOpcode == Instruction::Add ||
             ReductionOpcode == Instruction::Sub) &&
            "Expected an ADD or SUB operation for predicated partial "
            "reductions (because the neutral element in the mask is zero)!");
-    VPValue *Mask = getBlockInMask(Reduction->getParent());
-    VPValue *Zero =
-        Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
-    BinOp = Builder.createSelect(Mask, BinOp, Zero, Reduction->getDebugLoc());
+    Cond = getBlockInMask(Reduction->getParent());
   }
-  return new VPPartialReductionRecipe(ReductionOpcode, BinOp, Accumulator,
+  return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
                                       Reduction);
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 1724b12b23d41..660cea629d01d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2056,55 +2056,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   }
 };
 
-/// A recipe for forming partial reductions. In the loop, an accumulator and
-/// vector operand are added together and passed to the next iteration as the
-/// next accumulator. After the loop body, the accumulator is reduced to a
-/// scalar value.
-class VPPartialReductionRecipe : public VPSingleDefRecipe {
-  unsigned Opcode;
-
-public:
-  VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0,
-                           VPValue *Op1)
-      : VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1,
-                                 ReductionInst) {}
-  VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1,
-                           Instruction *ReductionInst = nullptr)
-      : VPSingleDefRecipe(VPDef::VPPartialReductionSC,
-                          ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
-        Opcode(Opcode) {
-    [[maybe_unused]] auto *AccumulatorRecipe =
-        getOperand(1)->getDefiningRecipe();
-    assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
-            isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
-           "Unexpected operand order for partial reduction recipe");
-  }
-  ~VPPartialReductionRecipe() override = default;
-
-  VPPartialReductionRecipe *clone() override {
-    return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
-                                        getUnderlyingInstr());
-  }
-
-  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
-
-  /// Generate the reduction in the loop.
-  void execute(VPTransformState &State) override;
-
-  /// Return the cost of this VPPartialReductionRecipe.
-  InstructionCost computeCost(ElementCount VF,
-                              VPCostContext &Ctx) const override;
-
-  /// Get the binary op's opcode.
-  unsigned getOpcode() const { return Opcode; }
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-  /// Print the recipe.
-  void print(raw_ostream &O, const Twine &Indent,
-             VPSlotTracker &SlotTracker) const override;
-#endif
-};
-
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
@@ -2376,6 +2327,58 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
   }
 };
 
+/// A recipe for forming partial reductions. In the loop, an accumulator and
+/// vector operand are added together and passed to the next iteration as the
+/// next accumulator. After the loop body, the accumulator is reduced to a
+/// scalar value.
+class VPPartialReductionRecipe : public VPReductionRecipe {
+  unsigned Opcode;
+
+public:
+  VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0,
+                           VPValue *Op1, VPValue *Cond)
+      : VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1, Cond,
+                                 ReductionInst) {}
+  VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1,
+                           VPValue *Cond, Instruction *ReductionInst = nullptr)
+      : VPReductionRecipe(VPDef::VPPartialReductionSC, RecurKind::Add,
+                          FastMathFlags(), ReductionInst,
+                          ArrayRef<VPValue *>({Op0, Op1}), Cond, false, {}),
+        Opcode(Opcode) {
+    [[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe();
+    assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
+            isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
+           "Unexpected operand order for partial reduction recipe");
+  }
+  ~VPPartialReductionRecipe() override = default;
+
+  VPPartialReductionRecipe *clone() override {
+    return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
+                                        getCondOp(), getUnderlyingInstr());
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+
+  /// Generate the reduction in the loop.
+  void execute(VPTransformState &State) override;
+
+  /// Return the cost of this VPPartialReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+  /// Get the binary op's opcode.
+  unsigned getOpcode() const { return Opcode; }
+
+  /// Get the binary op this reduction is applied to.
+  VPValue *getBinOp() const { return getOperand(1); }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe to represent inloop reduction operations with vector-predication
 /// intrinsics, performing a reduction on a vector operand with the explicit
 /// vector length (EVL) into a scalar value, and adding the result to a chain.
@@ -2496,6 +2499,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
 
   Type *ResultTy;
 
+  /// If the reduction this is based on is a partial reduction.
+  bool IsPartialReduction = false;
+
   /// For cloning VPMulAccumulateReductionRecipe.
   VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
       : VPReductionRecipe(
@@ -2505,7 +2511,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
         ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
-        ResultTy(MulAcc->getResultType()) {}
+        ResultTy(MulAcc->getResultType()),
+        IsPartialReduction(MulAcc->isPartialReduction()) {}
 
 public:
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2518,7 +2525,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
         ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
-        ResultTy(ResultTy) {
+        ResultTy(ResultTy),
+        IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateteReductionRecipe must "
@@ -2589,6 +2597,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
 
   /// Return the non negative flag of the ext recipe.
   bool isNonNeg() const { return IsNonNeg; }
+
+  /// Return if the underlying reduction recipe is a partial reduction.
+  bool isPartialReduction() const { return IsPartialReduction; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 72b4c6d885e98..7a8ab2c4144e6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -287,14 +287,9 @@ InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
   std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(0);
+  VPValue *BinOp = getBinOp();
 
-  // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
   // If BinOp is a negation, use the side effect of match to assign the actual
   // binary operation to BinOp
   match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
@@ -338,12 +333,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   assert(getOpcode() == Instruction::Add &&
          "Unhandled partial reduction opcode");
 
-  Value *BinOpVal = State.get(getOperand(0));
-  Value *PhiVal = State.get(getOperand(1));
+  Value *BinOpVal = State.get(getBinOp());
+  Value *PhiVal = State.get(getChainOp());
   assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
   Type *RetTy = PhiVal->getType();
 
+  /// Mask the bin op output.
+  if (VPValue *Cond = getCondOp()) {
+    Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
+    BinOpVal = Builder.CreateSelect(State.get(Cond), BinOpVal, Zero);
+  }
+
   CallInst *V = Builder.CreateIntrinsic(
       RetTy, Intrinsic::experimental_vector_partial_reduce_add,
       {PhiVal, BinOpVal}, nullptr, "partial.reduce");
@@ -2432,6 +2433,14 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  if (isPartialReduction()) {
+    return Ctx.TTI.getPartialReductionCost(
+        Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
+        Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
+        TTI::getPartialReductionExtendKind(getExtOpcode()),
+        TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
+  }
+
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
@@ -2509,6 +2518,8 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " + ";
+  if (isPartialReduction())
+    O << "partial.";
   O << "reduce."
     << Instruction::getOpcodeName(
            RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 74bf7c4d3a39e..c0c1329161db8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2158,9 +2158,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   Mul->insertBefore(MulAcc);
 
   // Generate VPReductionRecipe.
-  auto *Red = new VPReductionRecipe(
-      MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(), Mul,
-      MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());
+  VPReductionRecipe *Red = nullptr;
+  if (MulAcc->isPartialReduction())
+    Red = new VPPartialReductionRecipe(Instruction::Add, MulAcc->getChainOp(),
+                                       Mul, MulAcc->getCondOp());
+  else
+    Red = new VPReductionRecipe(MulAcc->getRecurrenceKind(), FastMathFlags(),
+                                MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+                                MulAcc->isOrdered(), MulAcc->getDebugLoc());
   Red->insertBefore(MulAcc);
 
   MulAcc->replaceAllUsesWith(Red);
@@ -2432,12 +2437,39 @@ static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
   Red->replaceAllUsesWith(AbstractR);
 }
 
+static void
+tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
+  if (PRed->getOpcode() != Instruction::Add)
+    return;
+
+  VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
+  auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
+  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+    return;
+
+  auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
+  auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
+  // TODO: Make work with extends of different signedness
+  if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
+      Ext1->hasMoreThanOneUniqueUser() ||
+      Ext0->getOpcode() != Ext1->getOpcode())
+    return;
+
+  auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
+                                                       Ext0->getResultType());
+  AbstractR->insertBefore(PRed);
+  PRed->replaceAllUsesWith(AbstractR);
+  PRed->eraseFromParent();
+}
+
 void VPlanTransforms::convertToAbstractRecipes(VPlan &Plan, VPCostContext &Ctx,
                                                VFRange &Range) {
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
+      if (auto *PRed = dyn_cast<VPPartialReductionRecipe>(&R))
+        tryToCreateAbstractPartialReductionRecipe(PRed);
+      else if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
         tryToCreateAbstractReductionRecipe(Red, Ctx, Range);
     }
   }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
index f622701308d21..eb2667de339a8 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
@@ -23,11 +23,11 @@ define i32 @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP3]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP9:%.*]] = mul <16 x i32> [[TMP8]], [[TMP5]]
 ; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP9]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
@@ -132,7 +132,6 @@ define void @dotp_small_epilogue_vf(i64 %idx.neg, i8 %a) #1 {
 ; CHECK-NEXT:    [[IV_NEXT:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <16 x i8> poison, i8 [[A]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <16 x i8> [[BROADCAST_SPLATINSERT]], <16 x i8> poison, <16 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT]] to <16 x i32>
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
@@ -141,6 +140,7 @@ define void @dotp_small_epilogue_vf(i64 %idx.neg, i8 %a) #1 {
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT2:%.*]] = insertelement <16 x i8> poison, i8 [[TMP2]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT3:%.*]] = shufflevector <16 x i8> [[BROADCAST_SPLATINSERT2]], <16 x i8> poison, <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
@@ -247,313 +247,233 @@ define i32 @dotp_predicated(i64 %N, ptr %a, ptr %b) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_LOAD_CONTINUE62:%.*]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <16 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[PRED_LOAD_CONTINUE62]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[PRED_LOAD_CONTINUE62]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
-; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
-; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
-; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
-; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
-; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
-; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
 ; CHECK-NEXT:    [[TMP16:%.*]] = icmp ule <16 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i1> [[TMP16]], i32 0
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = load i8, ptr [[TMP18]], align 1
 ; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <16 x i8> poison, i8 [[TMP19]], i32 0
+; CHECK-NEXT:    [[TMP99:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP101:%.*]] = lo...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Apr 17, 2025

@llvm/pr-subscribers-vectorizers

Author: Sam Tebbs (SamTebbs33)

Changes

This PR bundles compatible partial reductions into a VPMulAccumulateReductionRecipe so that the cost of the extends and mul can be hidden. Depends on #113903.

At the moment only partial reductions with the same extension type are supported as VPMulAccumulateReductionRecipe only supports such a scheme.


Patch is 206.82 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/136173.diff

10 Files Affected:

  • (modified) llvm/include/llvm/Analysis/TargetTransformInfo.h (+2)
  • (modified) llvm/lib/Analysis/TargetTransformInfo.cpp (+13)
  • (modified) llvm/lib/Transforms/Vectorize/LoopVectorize.cpp (+3-5)
  • (modified) llvm/lib/Transforms/Vectorize/VPlan.h (+62-51)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp (+19-8)
  • (modified) llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp (+36-4)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll (+83-163)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-neon.ll (+261-501)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product.ll (+93-94)
  • (modified) llvm/test/Transforms/LoopVectorize/AArch64/vplan-printing.ll (+154-6)
diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
index 99e21aca97631..78ab6d3faaf33 100644
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -219,6 +219,8 @@ class TargetTransformInfo {
   /// Get the kind of extension that an instruction represents.
   static PartialReductionExtendKind
   getPartialReductionExtendKind(Instruction *I);
+  static PartialReductionExtendKind
+  getPartialReductionExtendKind(Instruction::CastOps ExtOpcode);
 
   /// Construct a TTI object using a type implementing the \c Concept
   /// API below.
diff --git a/llvm/lib/Analysis/TargetTransformInfo.cpp b/llvm/lib/Analysis/TargetTransformInfo.cpp
index 4df551aca30a7..fdef0c484e12a 100644
--- a/llvm/lib/Analysis/TargetTransformInfo.cpp
+++ b/llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -993,6 +993,19 @@ TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
   return PR_None;
 }
 
+TargetTransformInfo::PartialReductionExtendKind
+TargetTransformInfo::getPartialReductionExtendKind(
+    Instruction::CastOps ExtOpcode) {
+  switch (ExtOpcode) {
+  case Instruction::CastOps::ZExt:
+    return PR_ZeroExtend;
+  case Instruction::CastOps::SExt:
+    return PR_SignExtend;
+  default:
+    return PR_None;
+  }
+}
+
 TTI::CastContextHint
 TargetTransformInfo::getCastContextHint(const Instruction *I) {
   if (!I)
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 0233a32f99e6f..030bd86847e1c 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -8879,17 +8879,15 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
     ReductionOpcode = Instruction::Add;
   }
 
+  VPValue *Cond = nullptr;
   if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
     assert((ReductionOpcode == Instruction::Add ||
             ReductionOpcode == Instruction::Sub) &&
            "Expected an ADD or SUB operation for predicated partial "
            "reductions (because the neutral element in the mask is zero)!");
-    VPValue *Mask = getBlockInMask(Reduction->getParent());
-    VPValue *Zero =
-        Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
-    BinOp = Builder.createSelect(Mask, BinOp, Zero, Reduction->getDebugLoc());
+    Cond = getBlockInMask(Reduction->getParent());
   }
-  return new VPPartialReductionRecipe(ReductionOpcode, BinOp, Accumulator,
+  return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
                                       Reduction);
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlan.h b/llvm/lib/Transforms/Vectorize/VPlan.h
index 1724b12b23d41..660cea629d01d 100644
--- a/llvm/lib/Transforms/Vectorize/VPlan.h
+++ b/llvm/lib/Transforms/Vectorize/VPlan.h
@@ -2056,55 +2056,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
   }
 };
 
-/// A recipe for forming partial reductions. In the loop, an accumulator and
-/// vector operand are added together and passed to the next iteration as the
-/// next accumulator. After the loop body, the accumulator is reduced to a
-/// scalar value.
-class VPPartialReductionRecipe : public VPSingleDefRecipe {
-  unsigned Opcode;
-
-public:
-  VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0,
-                           VPValue *Op1)
-      : VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1,
-                                 ReductionInst) {}
-  VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1,
-                           Instruction *ReductionInst = nullptr)
-      : VPSingleDefRecipe(VPDef::VPPartialReductionSC,
-                          ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
-        Opcode(Opcode) {
-    [[maybe_unused]] auto *AccumulatorRecipe =
-        getOperand(1)->getDefiningRecipe();
-    assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
-            isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
-           "Unexpected operand order for partial reduction recipe");
-  }
-  ~VPPartialReductionRecipe() override = default;
-
-  VPPartialReductionRecipe *clone() override {
-    return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
-                                        getUnderlyingInstr());
-  }
-
-  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
-
-  /// Generate the reduction in the loop.
-  void execute(VPTransformState &State) override;
-
-  /// Return the cost of this VPPartialReductionRecipe.
-  InstructionCost computeCost(ElementCount VF,
-                              VPCostContext &Ctx) const override;
-
-  /// Get the binary op's opcode.
-  unsigned getOpcode() const { return Opcode; }
-
-#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
-  /// Print the recipe.
-  void print(raw_ostream &O, const Twine &Indent,
-             VPSlotTracker &SlotTracker) const override;
-#endif
-};
-
 /// A recipe for vectorizing a phi-node as a sequence of mask-based select
 /// instructions.
 class VPBlendRecipe : public VPSingleDefRecipe {
@@ -2376,6 +2327,58 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
   }
 };
 
+/// A recipe for forming partial reductions. In the loop, an accumulator and
+/// vector operand are added together and passed to the next iteration as the
+/// next accumulator. After the loop body, the accumulator is reduced to a
+/// scalar value.
+class VPPartialReductionRecipe : public VPReductionRecipe {
+  unsigned Opcode;
+
+public:
+  VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0,
+                           VPValue *Op1, VPValue *Cond)
+      : VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1, Cond,
+                                 ReductionInst) {}
+  VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1,
+                           VPValue *Cond, Instruction *ReductionInst = nullptr)
+      : VPReductionRecipe(VPDef::VPPartialReductionSC, RecurKind::Add,
+                          FastMathFlags(), ReductionInst,
+                          ArrayRef<VPValue *>({Op0, Op1}), Cond, false, {}),
+        Opcode(Opcode) {
+    [[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe();
+    assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
+            isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
+           "Unexpected operand order for partial reduction recipe");
+  }
+  ~VPPartialReductionRecipe() override = default;
+
+  VPPartialReductionRecipe *clone() override {
+    return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
+                                        getCondOp(), getUnderlyingInstr());
+  }
+
+  VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
+
+  /// Generate the reduction in the loop.
+  void execute(VPTransformState &State) override;
+
+  /// Return the cost of this VPPartialReductionRecipe.
+  InstructionCost computeCost(ElementCount VF,
+                              VPCostContext &Ctx) const override;
+
+  /// Get the binary op's opcode.
+  unsigned getOpcode() const { return Opcode; }
+
+  /// Get the binary op this reduction is applied to.
+  VPValue *getBinOp() const { return getOperand(1); }
+
+#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
+  /// Print the recipe.
+  void print(raw_ostream &O, const Twine &Indent,
+             VPSlotTracker &SlotTracker) const override;
+#endif
+};
+
 /// A recipe to represent inloop reduction operations with vector-predication
 /// intrinsics, performing a reduction on a vector operand with the explicit
 /// vector length (EVL) into a scalar value, and adding the result to a chain.
@@ -2496,6 +2499,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
 
   Type *ResultTy;
 
+  /// If the reduction this is based on is a partial reduction.
+  bool IsPartialReduction = false;
+
   /// For cloning VPMulAccumulateReductionRecipe.
   VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
       : VPReductionRecipe(
@@ -2505,7 +2511,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
             MulAcc->getDebugLoc()),
         ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
-        ResultTy(MulAcc->getResultType()) {}
+        ResultTy(MulAcc->getResultType()),
+        IsPartialReduction(MulAcc->isPartialReduction()) {}
 
 public:
   VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2518,7 +2525,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
             WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
             R->getDebugLoc()),
         ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
-        ResultTy(ResultTy) {
+        ResultTy(ResultTy),
+        IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
     assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
                Instruction::Add &&
            "The reduction instruction in MulAccumulateteReductionRecipe must "
@@ -2589,6 +2597,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
 
   /// Return the non negative flag of the ext recipe.
   bool isNonNeg() const { return IsNonNeg; }
+
+  /// Return if the underlying reduction recipe is a partial reduction.
+  bool isPartialReduction() const { return IsPartialReduction; }
 };
 
 /// VPReplicateRecipe replicates a given instruction producing multiple scalar
diff --git a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
index 72b4c6d885e98..7a8ab2c4144e6 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
@@ -287,14 +287,9 @@ InstructionCost
 VPPartialReductionRecipe::computeCost(ElementCount VF,
                                       VPCostContext &Ctx) const {
   std::optional<unsigned> Opcode = std::nullopt;
-  VPValue *BinOp = getOperand(0);
+  VPValue *BinOp = getBinOp();
 
-  // If the partial reduction is predicated, a select will be operand 0 rather
-  // than the binary op
   using namespace llvm::VPlanPatternMatch;
-  if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
-    BinOp = BinOp->getDefiningRecipe()->getOperand(1);
-
   // If BinOp is a negation, use the side effect of match to assign the actual
   // binary operation to BinOp
   match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
@@ -338,12 +333,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
   assert(getOpcode() == Instruction::Add &&
          "Unhandled partial reduction opcode");
 
-  Value *BinOpVal = State.get(getOperand(0));
-  Value *PhiVal = State.get(getOperand(1));
+  Value *BinOpVal = State.get(getBinOp());
+  Value *PhiVal = State.get(getChainOp());
   assert(PhiVal && BinOpVal && "Phi and Mul must be set");
 
   Type *RetTy = PhiVal->getType();
 
+  /// Mask the bin op output.
+  if (VPValue *Cond = getCondOp()) {
+    Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
+    BinOpVal = Builder.CreateSelect(State.get(Cond), BinOpVal, Zero);
+  }
+
   CallInst *V = Builder.CreateIntrinsic(
       RetTy, Intrinsic::experimental_vector_partial_reduce_add,
       {PhiVal, BinOpVal}, nullptr, "partial.reduce");
@@ -2432,6 +2433,14 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
 InstructionCost
 VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
                                             VPCostContext &Ctx) const {
+  if (isPartialReduction()) {
+    return Ctx.TTI.getPartialReductionCost(
+        Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
+        Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
+        TTI::getPartialReductionExtendKind(getExtOpcode()),
+        TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
+  }
+
   Type *RedTy = Ctx.Types.inferScalarType(this);
   auto *SrcVecTy =
       cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
@@ -2509,6 +2518,8 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
   O << " = ";
   getChainOp()->printAsOperand(O, SlotTracker);
   O << " + ";
+  if (isPartialReduction())
+    O << "partial.";
   O << "reduce."
     << Instruction::getOpcodeName(
            RecurrenceDescriptor::getOpcode(getRecurrenceKind()))
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index 74bf7c4d3a39e..c0c1329161db8 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -2158,9 +2158,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
   Mul->insertBefore(MulAcc);
 
   // Generate VPReductionRecipe.
-  auto *Red = new VPReductionRecipe(
-      MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(), Mul,
-      MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());
+  VPReductionRecipe *Red = nullptr;
+  if (MulAcc->isPartialReduction())
+    Red = new VPPartialReductionRecipe(Instruction::Add, MulAcc->getChainOp(),
+                                       Mul, MulAcc->getCondOp());
+  else
+    Red = new VPReductionRecipe(MulAcc->getRecurrenceKind(), FastMathFlags(),
+                                MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
+                                MulAcc->isOrdered(), MulAcc->getDebugLoc());
   Red->insertBefore(MulAcc);
 
   MulAcc->replaceAllUsesWith(Red);
@@ -2432,12 +2437,39 @@ static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
   Red->replaceAllUsesWith(AbstractR);
 }
 
+static void
+tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
+  if (PRed->getOpcode() != Instruction::Add)
+    return;
+
+  VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
+  auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
+  if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
+    return;
+
+  auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
+  auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
+  // TODO: Make work with extends of different signedness
+  if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
+      Ext1->hasMoreThanOneUniqueUser() ||
+      Ext0->getOpcode() != Ext1->getOpcode())
+    return;
+
+  auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
+                                                       Ext0->getResultType());
+  AbstractR->insertBefore(PRed);
+  PRed->replaceAllUsesWith(AbstractR);
+  PRed->eraseFromParent();
+}
+
 void VPlanTransforms::convertToAbstractRecipes(VPlan &Plan, VPCostContext &Ctx,
                                                VFRange &Range) {
   for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
            vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
     for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
-      if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
+      if (auto *PRed = dyn_cast<VPPartialReductionRecipe>(&R))
+        tryToCreateAbstractPartialReductionRecipe(PRed);
+      else if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
         tryToCreateAbstractReductionRecipe(Red, Ctx, Range);
     }
   }
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
index f622701308d21..eb2667de339a8 100644
--- a/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/partial-reduce-dot-product-epilogue.ll
@@ -23,11 +23,11 @@ define i32 @dotp(ptr %a, ptr %b) #0 {
 ; CHECK-NEXT:    [[TMP3:%.*]] = getelementptr i8, ptr [[A]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr i8, ptr [[TMP3]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[TMP4]], align 1
-; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i8, ptr [[B]], i64 [[TMP2]]
 ; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i8, ptr [[TMP6]], i32 0
 ; CHECK-NEXT:    [[WIDE_LOAD1:%.*]] = load <16 x i8>, ptr [[TMP7]], align 1
 ; CHECK-NEXT:    [[TMP8:%.*]] = zext <16 x i8> [[WIDE_LOAD1]] to <16 x i32>
+; CHECK-NEXT:    [[TMP5:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP9:%.*]] = mul <16 x i32> [[TMP8]], [[TMP5]]
 ; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP9]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
@@ -132,7 +132,6 @@ define void @dotp_small_epilogue_vf(i64 %idx.neg, i8 %a) #1 {
 ; CHECK-NEXT:    [[IV_NEXT:%.*]] = sub i64 [[TMP0]], [[N_MOD_VF]]
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <16 x i8> poison, i8 [[A]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <16 x i8> [[BROADCAST_SPLATINSERT]], <16 x i8> poison, <16 x i32> zeroinitializer
-; CHECK-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT]] to <16 x i32>
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
@@ -141,6 +140,7 @@ define void @dotp_small_epilogue_vf(i64 %idx.neg, i8 %a) #1 {
 ; CHECK-NEXT:    [[BROADCAST_SPLATINSERT2:%.*]] = insertelement <16 x i8> poison, i8 [[TMP2]], i64 0
 ; CHECK-NEXT:    [[BROADCAST_SPLAT3:%.*]] = shufflevector <16 x i8> [[BROADCAST_SPLATINSERT2]], <16 x i8> poison, <16 x i32> zeroinitializer
 ; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT3]] to <16 x i32>
+; CHECK-NEXT:    [[TMP1:%.*]] = sext <16 x i8> [[BROADCAST_SPLAT]] to <16 x i32>
 ; CHECK-NEXT:    [[TMP4:%.*]] = mul <16 x i32> [[TMP3]], [[TMP1]]
 ; CHECK-NEXT:    [[PARTIAL_REDUCE]] = call <4 x i32> @llvm.experimental.vector.partial.reduce.add.v4i32.v16i32(<4 x i32> [[VEC_PHI]], <16 x i32> [[TMP4]])
 ; CHECK-NEXT:    [[INDEX_NEXT]] = add nuw i64 [[INDEX]], 16
@@ -247,313 +247,233 @@ define i32 @dotp_predicated(i64 %N, ptr %a, ptr %b) {
 ; CHECK-NEXT:    [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[INDEX_NEXT:%.*]], [[PRED_LOAD_CONTINUE62:%.*]] ]
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <16 x i64> [ <i64 0, i64 1, i64 2, i64 3, i64 4, i64 5, i64 6, i64 7, i64 8, i64 9, i64 10, i64 11, i64 12, i64 13, i64 14, i64 15>, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[PRED_LOAD_CONTINUE62]] ]
 ; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <4 x i32> [ zeroinitializer, [[VECTOR_PH]] ], [ [[PARTIAL_REDUCE:%.*]], [[PRED_LOAD_CONTINUE62]] ]
-; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
-; CHECK-NEXT:    [[TMP1:%.*]] = add i64 [[INDEX]], 1
-; CHECK-NEXT:    [[TMP2:%.*]] = add i64 [[INDEX]], 2
-; CHECK-NEXT:    [[TMP3:%.*]] = add i64 [[INDEX]], 3
-; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], 4
-; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[INDEX]], 5
-; CHECK-NEXT:    [[TMP6:%.*]] = add i64 [[INDEX]], 6
-; CHECK-NEXT:    [[TMP7:%.*]] = add i64 [[INDEX]], 7
-; CHECK-NEXT:    [[TMP8:%.*]] = add i64 [[INDEX]], 8
-; CHECK-NEXT:    [[TMP9:%.*]] = add i64 [[INDEX]], 9
-; CHECK-NEXT:    [[TMP10:%.*]] = add i64 [[INDEX]], 10
-; CHECK-NEXT:    [[TMP11:%.*]] = add i64 [[INDEX]], 11
-; CHECK-NEXT:    [[TMP12:%.*]] = add i64 [[INDEX]], 12
-; CHECK-NEXT:    [[TMP13:%.*]] = add i64 [[INDEX]], 13
-; CHECK-NEXT:    [[TMP14:%.*]] = add i64 [[INDEX]], 14
-; CHECK-NEXT:    [[TMP15:%.*]] = add i64 [[INDEX]], 15
 ; CHECK-NEXT:    [[TMP16:%.*]] = icmp ule <16 x i64> [[VEC_IND]], [[BROADCAST_SPLAT]]
 ; CHECK-NEXT:    [[TMP17:%.*]] = extractelement <16 x i1> [[TMP16]], i32 0
 ; CHECK-NEXT:    br i1 [[TMP17]], label [[PRED_LOAD_IF:%.*]], label [[PRED_LOAD_CONTINUE:%.*]]
 ; CHECK:       pred.load.if:
+; CHECK-NEXT:    [[TMP0:%.*]] = add i64 [[INDEX]], 0
 ; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr inbounds i8, ptr [[A]], i64 [[TMP0]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = load i8, ptr [[TMP18]], align 1
 ; CHECK-NEXT:    [[TMP20:%.*]] = insertelement <16 x i8> poison, i8 [[TMP19]], i32 0
+; CHECK-NEXT:    [[TMP99:%.*]] = getelementptr inbounds i8, ptr [[B]], i64 [[TMP0]]
+; CHECK-NEXT:    [[TMP101:%.*]] = lo...
[truncated]

Copy link

github-actions bot commented Apr 17, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

/// vector operand are added together and passed to the next iteration as the
/// next accumulator. After the loop body, the accumulator is reduced to a
/// scalar value.
class VPPartialReductionRecipe : public VPReductionRecipe {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the classof for VPReductionRecipe now include VPPartialReductionRecipe?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Comment on lines +278 to +282
; CHECK-MAXBW-NEXT: [[WIDE_LOAD2:%.*]] = load <vscale x 16 x i8>, ptr [[TMP12]], align 1
; CHECK-MAXBW-NEXT: [[TMP15:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD2]] to <vscale x 16 x i64>
; CHECK-MAXBW-NEXT: [[TMP13:%.*]] = zext <vscale x 16 x i8> [[WIDE_LOAD]] to <vscale x 16 x i64>
; CHECK-MAXBW-NEXT: [[TMP14:%.*]] = mul nuw nsw <vscale x 16 x i64> [[TMP15]], [[TMP13]]
; CHECK-MAXBW-NEXT: [[PARTIAL_REDUCE]] = call <vscale x 2 x i64> @llvm.experimental.vector.partial.reduce.add.nxv2i64.nxv16i64(<vscale x 2 x i64> [[VEC_PHI]], <vscale x 16 x i64> [[TMP14]])
Copy link
Member

@MacDue MacDue Apr 22, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test is called "not_dotp" but now looks like it's dotp/partial reduction 🙂 IIRC this won't map directly a dot product instruction (as nxv16i64 to nxv2i64 is not supported at the moment).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah it looks like what was previously too high a cost for it to choose a 16i8 -> 2i64 partial reduction isn't sufficiently high now that the extend cost is hidden. I've made this permutation invalid.

if (AccumEVT == MVT::i64)
Cost *= 2;
else if (AccumEVT != MVT::i32)
if (AccumEVT != MVT::i32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you making this change?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's due to: #136173 (comment)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm.partial.reduce(nxv2i64 %acc, mul(zext(nxv16i8 %a) -> nxv16i64, zext(nxv16i8 %b) -> nxv16i64))
This case can be lowered with the correct support in selectiondag (see #130935 (comment))

I'd argue that this function will need a bit of a rewrite anyway once we get proper support in SelectionDAG for legalising these operations, so I wouldn't bother changing that in this PR.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we allow this case make sure to rename the test from "not_dotp" to "dotp".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good spot. Done.

@@ -2056,55 +2056,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
}
};

/// A recipe for forming partial reductions. In the loop, an accumulator and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be possible to make the change of VPPartialReductionRecipe : public VPSingleDefRecipe -> VPPartialReductionRecipe : public VPReductionRecipe as an NFC change? (For cases around VPMulAccumulateReductionRecipes you can initially add some asserts that the recipe isn't a partial reduction, because that won't be supported until this PR lands)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think I could make it an NFC change, since to conform to VPReductionRecipe, the accumulator and binop have to be swapped around.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(We discussed this offline) swapping the operands in the debug-print function of the recipe is not something that really concerns me, and I think there's still value making this functionally (from end-user perspective) NFC change.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've pre-committed the NFC but rebasing Elvis's changes on top of that has been pretty challenging considering the number of commits on that branch. So I will cherry-pick the NFC on to this branch and it'll just go away once Elvis's PR lands and I rebase this PR on top of main.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried downloading this patch and applying to the HEAD of LLVM and patch said this diff had already been applied. Does the PR need rebasing?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah perhaps this is my mistake. You did say it depends upon #113903. :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that's the case :). Let me know if you have any issues applying it after applying 113903 too.

Comment on lines 2447 to 2451
VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
return;
VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
if (!match(PRed->getBinOp(), m_Mul(m_ZExtOrSExt(Ext0), m_ZExtOrSExt(Ext1))))`
return;

Could you use pattern matching to combine some of the checks?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done 👍 .

if (AccumEVT == MVT::i64)
Cost *= 2;
else if (AccumEVT != MVT::i32)
if (AccumEVT != MVT::i32)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

llvm.partial.reduce(nxv2i64 %acc, mul(zext(nxv16i8 %a) -> nxv16i64, zext(nxv16i8 %b) -> nxv16i64))
This case can be lowered with the correct support in selectiondag (see #130935 (comment))

I'd argue that this function will need a bit of a rewrite anyway once we get proper support in SelectionDAG for legalising these operations, so I wouldn't bother changing that in this PR.

/// This function tries to create an abstract recipe from a partial reduction to
/// hide its mul and extends from cost estimation.
static void
tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be given the Range & and for that range to be clamped if it doesn't match or if the cost is higher than the individual operations (similar to what happens in tryToCreateAbstractReductionRecipe) ?

(note that the cost part is still missing)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At this point we've already created the partial reduction and clamped the range so I don't think we need to do any costing (like tryToMatchAndCreateMulAccumulateReduction does with getMulAccReductionCost) since we already know it's worthwhile (see getScaledReductions in LoopVectorize.cpp). This part of the code just puts the partial reduction inside the abstract recipe, which shouldn't need to consider any costing.

@@ -2056,55 +2056,6 @@ class VPReductionPHIRecipe : public VPHeaderPHIRecipe,
}
};

/// A recipe for forming partial reductions. In the loop, an accumulator and
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(We discussed this offline) swapping the operands in the debug-print function of the recipe is not something that really concerns me, and I think there's still value making this functionally (from end-user perspective) NFC change.

@SamTebbs33 SamTebbs33 force-pushed the mulacc-partial-reduction branch from 54d066c to 3f02b10 Compare May 1, 2025 15:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants