@@ -8036,209 +8036,6 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
80368036 return Recipe;
80378037}
80388038
8039- // / Find all possible partial reductions in the loop and track all of those that
8040- // / are valid so recipes can be formed later.
8041- void VPRecipeBuilder::collectScaledReductions (VFRange &Range) {
8042- // Find all possible partial reductions, grouping chains by their PHI. This
8043- // grouping allows invalidating the whole chain, if any link is not a valid
8044- // partial reduction.
8045- MapVector<Instruction *,
8046- SmallVector<std::pair<PartialReductionChain, unsigned >>>
8047- ChainsByPhi;
8048- for (const auto &[Phi, RdxDesc] : Legal->getReductionVars ()) {
8049- if (Instruction *RdxExitInstr = RdxDesc.getLoopExitInstr ())
8050- getScaledReductions (Phi, RdxExitInstr, Range, ChainsByPhi[Phi]);
8051- }
8052-
8053- // A partial reduction is invalid if any of its extends are used by
8054- // something that isn't another partial reduction. This is because the
8055- // extends are intended to be lowered along with the reduction itself.
8056-
8057- // Build up a set of partial reduction ops for efficient use checking.
8058- SmallPtrSet<User *, 4 > PartialReductionOps;
8059- for (const auto &[_, Chains] : ChainsByPhi)
8060- for (const auto &[PartialRdx, _] : Chains)
8061- PartialReductionOps.insert (PartialRdx.ExtendUser );
8062-
8063- auto ExtendIsOnlyUsedByPartialReductions =
8064- [&PartialReductionOps](Instruction *Extend) {
8065- return all_of (Extend->users (), [&](const User *U) {
8066- return PartialReductionOps.contains (U);
8067- });
8068- };
8069-
8070- // Check if each use of a chain's two extends is a partial reduction
8071- // and only add those that don't have non-partial reduction users.
8072- for (const auto &[_, Chains] : ChainsByPhi) {
8073- for (const auto &[Chain, Scale] : Chains) {
8074- if (ExtendIsOnlyUsedByPartialReductions (Chain.ExtendA ) &&
8075- (!Chain.ExtendB ||
8076- ExtendIsOnlyUsedByPartialReductions (Chain.ExtendB )))
8077- ScaledReductionMap.try_emplace (Chain.Reduction , Scale);
8078- }
8079- }
8080-
8081- // Check that all partial reductions in a chain are only used by other
8082- // partial reductions with the same scale factor. Otherwise we end up creating
8083- // users of scaled reductions where the types of the other operands don't
8084- // match.
8085- for (const auto &[Phi, Chains] : ChainsByPhi) {
8086- for (const auto &[Chain, Scale] : Chains) {
8087- auto AllUsersPartialRdx = [ScaleVal = Scale, RdxPhi = Phi,
8088- this ](const User *U) {
8089- auto *UI = cast<Instruction>(U);
8090- if (isa<PHINode>(UI) && UI->getParent () == OrigLoop->getHeader ())
8091- return UI == RdxPhi;
8092- return ScaledReductionMap.lookup_or (UI, 0 ) == ScaleVal ||
8093- !OrigLoop->contains (UI->getParent ());
8094- };
8095-
8096- // If any partial reduction entry for the phi is invalid, invalidate the
8097- // whole chain.
8098- if (!all_of (Chain.Reduction ->users (), AllUsersPartialRdx)) {
8099- for (const auto &[Chain, _] : Chains)
8100- ScaledReductionMap.erase (Chain.Reduction );
8101- break ;
8102- }
8103- }
8104- }
8105- }
8106-
8107- bool VPRecipeBuilder::getScaledReductions (
8108- Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8109- SmallVectorImpl<std::pair<PartialReductionChain, unsigned >> &Chains) {
8110- if (!CM.TheLoop ->contains (RdxExitInstr))
8111- return false ;
8112-
8113- auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
8114- if (!Update)
8115- return false ;
8116-
8117- Value *Op = Update->getOperand (0 );
8118- Value *PhiOp = Update->getOperand (1 );
8119- if (Op == PHI)
8120- std::swap (Op, PhiOp);
8121-
8122- using namespace llvm ::PatternMatch;
8123- // If Op is an extend, then it's still a valid partial reduction if the
8124- // extended mul fulfills the other requirements.
8125- // For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
8126- // reduction since the inner extends will be widened. We already have oneUse
8127- // checks on the inner extends so widening them is safe.
8128- std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt ;
8129- if (match (Op, m_ZExtOrSExt (m_Mul (m_Value (), m_Value ())))) {
8130- auto *Cast = cast<CastInst>(Op);
8131- OuterExtKind = TTI::getPartialReductionExtendKind (Cast->getOpcode ());
8132- Op = Cast->getOperand (0 );
8133- }
8134-
8135- // Try and get a scaled reduction from the first non-phi operand.
8136- // If one is found, we use the discovered reduction instruction in
8137- // place of the accumulator for costing.
8138- if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8139- if (getScaledReductions (PHI, OpInst, Range, Chains)) {
8140- PHI = Chains.rbegin ()->first .Reduction ;
8141-
8142- Op = Update->getOperand (0 );
8143- PhiOp = Update->getOperand (1 );
8144- if (Op == PHI)
8145- std::swap (Op, PhiOp);
8146- }
8147- }
8148- if (PhiOp != PHI)
8149- return false ;
8150-
8151- // If the update is a binary operator, check both of its operands to see if
8152- // they are extends. Otherwise, see if the update comes directly from an
8153- // extend.
8154- Instruction *Exts[2 ] = {nullptr };
8155- BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8156- std::optional<unsigned > BinOpc;
8157- Type *ExtOpTypes[2 ] = {nullptr };
8158- TTI::PartialReductionExtendKind ExtKinds[2 ] = {TTI::PR_None};
8159-
8160- auto CollectExtInfo = [this , OuterExtKind, &Exts, &ExtOpTypes,
8161- &ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
8162- for (const auto &[I, OpI] : enumerate(Ops)) {
8163- const APInt *C;
8164- if (I > 0 && match (OpI, m_APInt (C)) &&
8165- canConstantBeExtended (C, ExtOpTypes[0 ], ExtKinds[0 ])) {
8166- ExtOpTypes[I] = ExtOpTypes[0 ];
8167- ExtKinds[I] = ExtKinds[0 ];
8168- continue ;
8169- }
8170- Value *ExtOp;
8171- if (!match (OpI, m_ZExtOrSExt (m_Value (ExtOp))) &&
8172- !match (OpI, m_FPExt (m_Value (ExtOp))))
8173- return false ;
8174- Exts[I] = cast<Instruction>(OpI);
8175-
8176- // TODO: We should be able to support live-ins.
8177- if (!CM.TheLoop ->contains (Exts[I]))
8178- return false ;
8179-
8180- ExtOpTypes[I] = ExtOp->getType ();
8181- ExtKinds[I] = TTI::getPartialReductionExtendKind (Exts[I]);
8182- // The outer extend kind must be the same as the inner extends, so that
8183- // they can be folded together.
8184- if (OuterExtKind.has_value () && OuterExtKind.value () != ExtKinds[I])
8185- return false ;
8186- }
8187- return true ;
8188- };
8189-
8190- if (ExtendUser) {
8191- if (!ExtendUser->hasOneUse ())
8192- return false ;
8193-
8194- // Use the side-effect of match to replace BinOp only if the pattern is
8195- // matched, we don't care at this point whether it actually matched.
8196- match (ExtendUser, m_Neg (m_BinOp (ExtendUser)));
8197-
8198- SmallVector<Value *> Ops (ExtendUser->operands ());
8199- if (!CollectExtInfo (Ops))
8200- return false ;
8201-
8202- BinOpc = std::make_optional (ExtendUser->getOpcode ());
8203- } else if (match (Update, m_Add (m_Value (), m_Value ())) ||
8204- match (Update, m_FAdd (m_Value (), m_Value ()))) {
8205- // We already know the operands for Update are Op and PhiOp.
8206- SmallVector<Value *> Ops ({Op});
8207- if (!CollectExtInfo (Ops))
8208- return false ;
8209-
8210- ExtendUser = Update;
8211- BinOpc = std::nullopt ;
8212- } else
8213- return false ;
8214-
8215- PartialReductionChain Chain (RdxExitInstr, Exts[0 ], Exts[1 ], ExtendUser);
8216-
8217- TypeSize PHISize = PHI->getType ()->getPrimitiveSizeInBits ();
8218- TypeSize ASize = ExtOpTypes[0 ]->getPrimitiveSizeInBits ();
8219- if (!PHISize.hasKnownScalarFactor (ASize))
8220- return false ;
8221- unsigned TargetScaleFactor = PHISize.getKnownScalarFactor (ASize);
8222-
8223- if (LoopVectorizationPlanner::getDecisionAndClampRange (
8224- [&](ElementCount VF) {
8225- std::optional<FastMathFlags> FMF = std::nullopt ;
8226- if (Update->getOpcode () == Instruction::FAdd)
8227- FMF = Update->getFastMathFlags ();
8228- InstructionCost Cost = TTI->getPartialReductionCost (
8229- Update->getOpcode (), ExtOpTypes[0 ], ExtOpTypes[1 ],
8230- PHI->getType (), VF, ExtKinds[0 ], ExtKinds[1 ], BinOpc,
8231- CM.CostKind , FMF);
8232- return Cost.isValid ();
8233- },
8234- Range)) {
8235- Chains.emplace_back (Chain, TargetScaleFactor);
8236- return true ;
8237- }
8238-
8239- return false ;
8240- }
8241-
82428039VPRecipeBase *
82438040VPRecipeBuilder::tryToCreateWidenNonPhiRecipe (VPSingleDefRecipe *R,
82448041 VFRange &Range) {
@@ -8269,9 +8066,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
82698066 VPI->getOpcode () == Instruction::Store)
82708067 return tryToWidenMemory (VPI, Range);
82718068
8272- if (std::optional<unsigned > ScaleFactor = getScalingForReduction (Instr))
8273- return tryToCreatePartialReduction (VPI, ScaleFactor.value ());
8274-
82758069 if (!shouldWiden (Instr, Range))
82768070 return nullptr ;
82778071
@@ -8290,55 +8084,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
82908084 return tryToWiden (VPI);
82918085}
82928086
8293- VPRecipeBase *
8294- VPRecipeBuilder::tryToCreatePartialReduction (VPInstruction *Reduction,
8295- unsigned ScaleFactor) {
8296- assert (Reduction->getNumOperands () == 2 &&
8297- " Unexpected number of operands for partial reduction" );
8298-
8299- VPValue *BinOp = Reduction->getOperand (0 );
8300- VPValue *Accumulator = Reduction->getOperand (1 );
8301- VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe ();
8302- if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8303- (isa<VPReductionRecipe>(BinOpRecipe) &&
8304- cast<VPReductionRecipe>(BinOpRecipe)->isPartialReduction ()))
8305- std::swap (BinOp, Accumulator);
8306-
8307- if (auto *RedPhiR = dyn_cast<VPReductionPHIRecipe>(Accumulator))
8308- RedPhiR->setVFScaleFactor (ScaleFactor);
8309-
8310- assert (ScaleFactor ==
8311- vputils::getVFScaleFactor (Accumulator->getDefiningRecipe ()) &&
8312- " all accumulators in chain must have same scale factor" );
8313-
8314- auto *ReductionI = Reduction->getUnderlyingInstr ();
8315- assert (
8316- Reduction->getOpcode () != Instruction::FAdd ||
8317- (ReductionI->hasAllowReassoc () && ReductionI->hasAllowContract ()) &&
8318- " FAdd partial reduction requires allow-reassoc and allow-contract" );
8319- if (Reduction->getOpcode () == Instruction::Sub) {
8320- SmallVector<VPValue *, 2 > Ops;
8321- Ops.push_back (Plan.getConstantInt (ReductionI->getType (), 0 ));
8322- Ops.push_back (BinOp);
8323- BinOp = new VPWidenRecipe (*ReductionI, Ops, VPIRFlags (*ReductionI),
8324- VPIRMetadata (), ReductionI->getDebugLoc ());
8325- Builder.insert (BinOp->getDefiningRecipe ());
8326- }
8327-
8328- VPValue *Cond = nullptr ;
8329- if (CM.blockNeedsPredicationForAnyReason (ReductionI->getParent ()))
8330- Cond = getBlockInMask (Builder.getInsertBlock ());
8331-
8332- return new VPReductionRecipe (
8333- Reduction->getOpcode () == Instruction::FAdd ? RecurKind::FAdd
8334- : RecurKind::Add,
8335- Reduction->getOpcode () == Instruction::FAdd
8336- ? Reduction->getFastMathFlags ()
8337- : FastMathFlags (),
8338- ReductionI, Accumulator, BinOp, Cond,
8339- RdxUnordered{/* VFScaleFactor=*/ ScaleFactor}, ReductionI->getDebugLoc ());
8340- }
8341-
83428087void LoopVectorizationPlanner::buildVPlansWithVPRecipes (ElementCount MinVF,
83438088 ElementCount MaxVF) {
83448089 if (ElementCount::isKnownGT (MinVF, MaxVF))
@@ -8475,11 +8220,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84758220 // Construct wide recipes and apply predication for original scalar
84768221 // VPInstructions in the loop.
84778222 // ---------------------------------------------------------------------------
8478- VPRecipeBuilder RecipeBuilder (*Plan, OrigLoop, TLI, &TTI, Legal, CM, Builder,
8479- BlockMaskCache);
8480- // TODO: Handle partial reductions with EVL tail folding.
8481- if (!CM.foldTailWithEVL ())
8482- RecipeBuilder.collectScaledReductions (Range);
8223+ VPRecipeBuilder RecipeBuilder (*Plan, TLI, Legal, CM, Builder, BlockMaskCache);
84838224
84848225 // Scan the body of the loop in a topological order to visit each basic block
84858226 // after having visited its predecessor basic blocks.
@@ -8601,13 +8342,16 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
86018342 if (!RUN_VPLAN_PASS (VPlanTransforms::handleFindLastReductions, *Plan))
86028343 return nullptr ;
86038344
8604- // Transform recipes to abstract recipes if it is legal and beneficial and
8605- // clamp the range for better cost estimation.
8345+ // Create partial reduction recipes for scaled reductions and transform
8346+ // recipes to abstract recipes if it is legal and beneficial and clamp the
8347+ // range for better cost estimation.
86068348 // TODO: Enable following transform when the EVL-version of extended-reduction
86078349 // and mulacc-reduction are implemented.
86088350 if (!CM.foldTailWithEVL ()) {
86098351 VPCostContext CostCtx (CM.TTI , *CM.TLI , *Plan, CM, CM.CostKind , CM.PSE ,
86108352 OrigLoop);
8353+ RUN_VPLAN_PASS (VPlanTransforms::createPartialReductions, *Plan, CostCtx,
8354+ Range);
86118355 RUN_VPLAN_PASS (VPlanTransforms::convertToAbstractRecipes, *Plan, CostCtx,
86128356 Range);
86138357 }
@@ -8902,11 +8646,7 @@ void LoopVectorizationPlanner::addReductionResultComputation(
89028646 VPBuilder PHBuilder (Plan->getVectorPreheader ());
89038647 VPValue *Iden = Plan->getOrAddLiveIn (
89048648 getRecurrenceIdentity (RK, PhiTy, RdxDesc.getFastMathFlags ()));
8905- // If the PHI is used by a partial reduction, set the scale factor.
8906- unsigned ScaleFactor =
8907- RecipeBuilder.getScalingForReduction (RdxDesc.getLoopExitInstr ())
8908- .value_or (1 );
8909- auto *ScaleFactorVPV = Plan->getConstantInt (32 , ScaleFactor);
8649+ auto *ScaleFactorVPV = Plan->getConstantInt (32 , 1 );
89108650 VPValue *StartV = PHBuilder.createNaryOp (
89118651 VPInstruction::ReductionStartVector,
89128652 {PhiR->getStartValue (), Iden, ScaleFactorVPV},
0 commit comments