Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 90b3712

Browse files
committed
Reapply "[VPlan] Detect and create partial reductions in VPlan. (NFCI) (llvm#167851)"
This reverts commit d1e477b. Recommit with a extra checks making sure extends are VPWidenCastRecipes, rejecting VPReplicateRecipes. Original message: As a first step, move the existing partial reduction detection logic to VPlan, trying to preserve the existing code structure & behavior as closely as possible. With this, partial reductions are detected and created together in a single step. This allows forming partial reductions and bundling them up if profitable together in a follow-up. PR: llvm#167851
1 parent 1818b23 commit 90b3712

11 files changed

Lines changed: 493 additions & 368 deletions

File tree

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 7 additions & 267 deletions
Original file line numberDiff line numberDiff line change
@@ -8036,209 +8036,6 @@ VPReplicateRecipe *VPRecipeBuilder::handleReplication(VPInstruction *VPI,
80368036
return Recipe;
80378037
}
80388038

8039-
/// Find all possible partial reductions in the loop and track all of those that
8040-
/// are valid so recipes can be formed later.
8041-
void VPRecipeBuilder::collectScaledReductions(VFRange &Range) {
8042-
// Find all possible partial reductions, grouping chains by their PHI. This
8043-
// grouping allows invalidating the whole chain, if any link is not a valid
8044-
// partial reduction.
8045-
MapVector<Instruction *,
8046-
SmallVector<std::pair<PartialReductionChain, unsigned>>>
8047-
ChainsByPhi;
8048-
for (const auto &[Phi, RdxDesc] : Legal->getReductionVars()) {
8049-
if (Instruction *RdxExitInstr = RdxDesc.getLoopExitInstr())
8050-
getScaledReductions(Phi, RdxExitInstr, Range, ChainsByPhi[Phi]);
8051-
}
8052-
8053-
// A partial reduction is invalid if any of its extends are used by
8054-
// something that isn't another partial reduction. This is because the
8055-
// extends are intended to be lowered along with the reduction itself.
8056-
8057-
// Build up a set of partial reduction ops for efficient use checking.
8058-
SmallPtrSet<User *, 4> PartialReductionOps;
8059-
for (const auto &[_, Chains] : ChainsByPhi)
8060-
for (const auto &[PartialRdx, _] : Chains)
8061-
PartialReductionOps.insert(PartialRdx.ExtendUser);
8062-
8063-
auto ExtendIsOnlyUsedByPartialReductions =
8064-
[&PartialReductionOps](Instruction *Extend) {
8065-
return all_of(Extend->users(), [&](const User *U) {
8066-
return PartialReductionOps.contains(U);
8067-
});
8068-
};
8069-
8070-
// Check if each use of a chain's two extends is a partial reduction
8071-
// and only add those that don't have non-partial reduction users.
8072-
for (const auto &[_, Chains] : ChainsByPhi) {
8073-
for (const auto &[Chain, Scale] : Chains) {
8074-
if (ExtendIsOnlyUsedByPartialReductions(Chain.ExtendA) &&
8075-
(!Chain.ExtendB ||
8076-
ExtendIsOnlyUsedByPartialReductions(Chain.ExtendB)))
8077-
ScaledReductionMap.try_emplace(Chain.Reduction, Scale);
8078-
}
8079-
}
8080-
8081-
// Check that all partial reductions in a chain are only used by other
8082-
// partial reductions with the same scale factor. Otherwise we end up creating
8083-
// users of scaled reductions where the types of the other operands don't
8084-
// match.
8085-
for (const auto &[Phi, Chains] : ChainsByPhi) {
8086-
for (const auto &[Chain, Scale] : Chains) {
8087-
auto AllUsersPartialRdx = [ScaleVal = Scale, RdxPhi = Phi,
8088-
this](const User *U) {
8089-
auto *UI = cast<Instruction>(U);
8090-
if (isa<PHINode>(UI) && UI->getParent() == OrigLoop->getHeader())
8091-
return UI == RdxPhi;
8092-
return ScaledReductionMap.lookup_or(UI, 0) == ScaleVal ||
8093-
!OrigLoop->contains(UI->getParent());
8094-
};
8095-
8096-
// If any partial reduction entry for the phi is invalid, invalidate the
8097-
// whole chain.
8098-
if (!all_of(Chain.Reduction->users(), AllUsersPartialRdx)) {
8099-
for (const auto &[Chain, _] : Chains)
8100-
ScaledReductionMap.erase(Chain.Reduction);
8101-
break;
8102-
}
8103-
}
8104-
}
8105-
}
8106-
8107-
bool VPRecipeBuilder::getScaledReductions(
8108-
Instruction *PHI, Instruction *RdxExitInstr, VFRange &Range,
8109-
SmallVectorImpl<std::pair<PartialReductionChain, unsigned>> &Chains) {
8110-
if (!CM.TheLoop->contains(RdxExitInstr))
8111-
return false;
8112-
8113-
auto *Update = dyn_cast<BinaryOperator>(RdxExitInstr);
8114-
if (!Update)
8115-
return false;
8116-
8117-
Value *Op = Update->getOperand(0);
8118-
Value *PhiOp = Update->getOperand(1);
8119-
if (Op == PHI)
8120-
std::swap(Op, PhiOp);
8121-
8122-
using namespace llvm::PatternMatch;
8123-
// If Op is an extend, then it's still a valid partial reduction if the
8124-
// extended mul fulfills the other requirements.
8125-
// For example, reduce.add(ext(mul(ext(A), ext(B)))) is still a valid partial
8126-
// reduction since the inner extends will be widened. We already have oneUse
8127-
// checks on the inner extends so widening them is safe.
8128-
std::optional<TTI::PartialReductionExtendKind> OuterExtKind = std::nullopt;
8129-
if (match(Op, m_ZExtOrSExt(m_Mul(m_Value(), m_Value())))) {
8130-
auto *Cast = cast<CastInst>(Op);
8131-
OuterExtKind = TTI::getPartialReductionExtendKind(Cast->getOpcode());
8132-
Op = Cast->getOperand(0);
8133-
}
8134-
8135-
// Try and get a scaled reduction from the first non-phi operand.
8136-
// If one is found, we use the discovered reduction instruction in
8137-
// place of the accumulator for costing.
8138-
if (auto *OpInst = dyn_cast<Instruction>(Op)) {
8139-
if (getScaledReductions(PHI, OpInst, Range, Chains)) {
8140-
PHI = Chains.rbegin()->first.Reduction;
8141-
8142-
Op = Update->getOperand(0);
8143-
PhiOp = Update->getOperand(1);
8144-
if (Op == PHI)
8145-
std::swap(Op, PhiOp);
8146-
}
8147-
}
8148-
if (PhiOp != PHI)
8149-
return false;
8150-
8151-
// If the update is a binary operator, check both of its operands to see if
8152-
// they are extends. Otherwise, see if the update comes directly from an
8153-
// extend.
8154-
Instruction *Exts[2] = {nullptr};
8155-
BinaryOperator *ExtendUser = dyn_cast<BinaryOperator>(Op);
8156-
std::optional<unsigned> BinOpc;
8157-
Type *ExtOpTypes[2] = {nullptr};
8158-
TTI::PartialReductionExtendKind ExtKinds[2] = {TTI::PR_None};
8159-
8160-
auto CollectExtInfo = [this, OuterExtKind, &Exts, &ExtOpTypes,
8161-
&ExtKinds](SmallVectorImpl<Value *> &Ops) -> bool {
8162-
for (const auto &[I, OpI] : enumerate(Ops)) {
8163-
const APInt *C;
8164-
if (I > 0 && match(OpI, m_APInt(C)) &&
8165-
canConstantBeExtended(C, ExtOpTypes[0], ExtKinds[0])) {
8166-
ExtOpTypes[I] = ExtOpTypes[0];
8167-
ExtKinds[I] = ExtKinds[0];
8168-
continue;
8169-
}
8170-
Value *ExtOp;
8171-
if (!match(OpI, m_ZExtOrSExt(m_Value(ExtOp))) &&
8172-
!match(OpI, m_FPExt(m_Value(ExtOp))))
8173-
return false;
8174-
Exts[I] = cast<Instruction>(OpI);
8175-
8176-
// TODO: We should be able to support live-ins.
8177-
if (!CM.TheLoop->contains(Exts[I]))
8178-
return false;
8179-
8180-
ExtOpTypes[I] = ExtOp->getType();
8181-
ExtKinds[I] = TTI::getPartialReductionExtendKind(Exts[I]);
8182-
// The outer extend kind must be the same as the inner extends, so that
8183-
// they can be folded together.
8184-
if (OuterExtKind.has_value() && OuterExtKind.value() != ExtKinds[I])
8185-
return false;
8186-
}
8187-
return true;
8188-
};
8189-
8190-
if (ExtendUser) {
8191-
if (!ExtendUser->hasOneUse())
8192-
return false;
8193-
8194-
// Use the side-effect of match to replace BinOp only if the pattern is
8195-
// matched, we don't care at this point whether it actually matched.
8196-
match(ExtendUser, m_Neg(m_BinOp(ExtendUser)));
8197-
8198-
SmallVector<Value *> Ops(ExtendUser->operands());
8199-
if (!CollectExtInfo(Ops))
8200-
return false;
8201-
8202-
BinOpc = std::make_optional(ExtendUser->getOpcode());
8203-
} else if (match(Update, m_Add(m_Value(), m_Value())) ||
8204-
match(Update, m_FAdd(m_Value(), m_Value()))) {
8205-
// We already know the operands for Update are Op and PhiOp.
8206-
SmallVector<Value *> Ops({Op});
8207-
if (!CollectExtInfo(Ops))
8208-
return false;
8209-
8210-
ExtendUser = Update;
8211-
BinOpc = std::nullopt;
8212-
} else
8213-
return false;
8214-
8215-
PartialReductionChain Chain(RdxExitInstr, Exts[0], Exts[1], ExtendUser);
8216-
8217-
TypeSize PHISize = PHI->getType()->getPrimitiveSizeInBits();
8218-
TypeSize ASize = ExtOpTypes[0]->getPrimitiveSizeInBits();
8219-
if (!PHISize.hasKnownScalarFactor(ASize))
8220-
return false;
8221-
unsigned TargetScaleFactor = PHISize.getKnownScalarFactor(ASize);
8222-
8223-
if (LoopVectorizationPlanner::getDecisionAndClampRange(
8224-
[&](ElementCount VF) {
8225-
std::optional<FastMathFlags> FMF = std::nullopt;
8226-
if (Update->getOpcode() == Instruction::FAdd)
8227-
FMF = Update->getFastMathFlags();
8228-
InstructionCost Cost = TTI->getPartialReductionCost(
8229-
Update->getOpcode(), ExtOpTypes[0], ExtOpTypes[1],
8230-
PHI->getType(), VF, ExtKinds[0], ExtKinds[1], BinOpc,
8231-
CM.CostKind, FMF);
8232-
return Cost.isValid();
8233-
},
8234-
Range)) {
8235-
Chains.emplace_back(Chain, TargetScaleFactor);
8236-
return true;
8237-
}
8238-
8239-
return false;
8240-
}
8241-
82428039
VPRecipeBase *
82438040
VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
82448041
VFRange &Range) {
@@ -8269,9 +8066,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
82698066
VPI->getOpcode() == Instruction::Store)
82708067
return tryToWidenMemory(VPI, Range);
82718068

8272-
if (std::optional<unsigned> ScaleFactor = getScalingForReduction(Instr))
8273-
return tryToCreatePartialReduction(VPI, ScaleFactor.value());
8274-
82758069
if (!shouldWiden(Instr, Range))
82768070
return nullptr;
82778071

@@ -8290,55 +8084,6 @@ VPRecipeBuilder::tryToCreateWidenNonPhiRecipe(VPSingleDefRecipe *R,
82908084
return tryToWiden(VPI);
82918085
}
82928086

8293-
VPRecipeBase *
8294-
VPRecipeBuilder::tryToCreatePartialReduction(VPInstruction *Reduction,
8295-
unsigned ScaleFactor) {
8296-
assert(Reduction->getNumOperands() == 2 &&
8297-
"Unexpected number of operands for partial reduction");
8298-
8299-
VPValue *BinOp = Reduction->getOperand(0);
8300-
VPValue *Accumulator = Reduction->getOperand(1);
8301-
VPRecipeBase *BinOpRecipe = BinOp->getDefiningRecipe();
8302-
if (isa<VPReductionPHIRecipe>(BinOpRecipe) ||
8303-
(isa<VPReductionRecipe>(BinOpRecipe) &&
8304-
cast<VPReductionRecipe>(BinOpRecipe)->isPartialReduction()))
8305-
std::swap(BinOp, Accumulator);
8306-
8307-
if (auto *RedPhiR = dyn_cast<VPReductionPHIRecipe>(Accumulator))
8308-
RedPhiR->setVFScaleFactor(ScaleFactor);
8309-
8310-
assert(ScaleFactor ==
8311-
vputils::getVFScaleFactor(Accumulator->getDefiningRecipe()) &&
8312-
"all accumulators in chain must have same scale factor");
8313-
8314-
auto *ReductionI = Reduction->getUnderlyingInstr();
8315-
assert(
8316-
Reduction->getOpcode() != Instruction::FAdd ||
8317-
(ReductionI->hasAllowReassoc() && ReductionI->hasAllowContract()) &&
8318-
"FAdd partial reduction requires allow-reassoc and allow-contract");
8319-
if (Reduction->getOpcode() == Instruction::Sub) {
8320-
SmallVector<VPValue *, 2> Ops;
8321-
Ops.push_back(Plan.getConstantInt(ReductionI->getType(), 0));
8322-
Ops.push_back(BinOp);
8323-
BinOp = new VPWidenRecipe(*ReductionI, Ops, VPIRFlags(*ReductionI),
8324-
VPIRMetadata(), ReductionI->getDebugLoc());
8325-
Builder.insert(BinOp->getDefiningRecipe());
8326-
}
8327-
8328-
VPValue *Cond = nullptr;
8329-
if (CM.blockNeedsPredicationForAnyReason(ReductionI->getParent()))
8330-
Cond = getBlockInMask(Builder.getInsertBlock());
8331-
8332-
return new VPReductionRecipe(
8333-
Reduction->getOpcode() == Instruction::FAdd ? RecurKind::FAdd
8334-
: RecurKind::Add,
8335-
Reduction->getOpcode() == Instruction::FAdd
8336-
? Reduction->getFastMathFlags()
8337-
: FastMathFlags(),
8338-
ReductionI, Accumulator, BinOp, Cond,
8339-
RdxUnordered{/*VFScaleFactor=*/ScaleFactor}, ReductionI->getDebugLoc());
8340-
}
8341-
83428087
void LoopVectorizationPlanner::buildVPlansWithVPRecipes(ElementCount MinVF,
83438088
ElementCount MaxVF) {
83448089
if (ElementCount::isKnownGT(MinVF, MaxVF))
@@ -8475,11 +8220,7 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
84758220
// Construct wide recipes and apply predication for original scalar
84768221
// VPInstructions in the loop.
84778222
// ---------------------------------------------------------------------------
8478-
VPRecipeBuilder RecipeBuilder(*Plan, OrigLoop, TLI, &TTI, Legal, CM, Builder,
8479-
BlockMaskCache);
8480-
// TODO: Handle partial reductions with EVL tail folding.
8481-
if (!CM.foldTailWithEVL())
8482-
RecipeBuilder.collectScaledReductions(Range);
8223+
VPRecipeBuilder RecipeBuilder(*Plan, TLI, Legal, CM, Builder, BlockMaskCache);
84838224

84848225
// Scan the body of the loop in a topological order to visit each basic block
84858226
// after having visited its predecessor basic blocks.
@@ -8601,13 +8342,16 @@ VPlanPtr LoopVectorizationPlanner::tryToBuildVPlanWithVPRecipes(
86018342
if (!RUN_VPLAN_PASS(VPlanTransforms::handleFindLastReductions, *Plan))
86028343
return nullptr;
86038344

8604-
// Transform recipes to abstract recipes if it is legal and beneficial and
8605-
// clamp the range for better cost estimation.
8345+
// Create partial reduction recipes for scaled reductions and transform
8346+
// recipes to abstract recipes if it is legal and beneficial and clamp the
8347+
// range for better cost estimation.
86068348
// TODO: Enable following transform when the EVL-version of extended-reduction
86078349
// and mulacc-reduction are implemented.
86088350
if (!CM.foldTailWithEVL()) {
86098351
VPCostContext CostCtx(CM.TTI, *CM.TLI, *Plan, CM, CM.CostKind, CM.PSE,
86108352
OrigLoop);
8353+
RUN_VPLAN_PASS(VPlanTransforms::createPartialReductions, *Plan, CostCtx,
8354+
Range);
86118355
RUN_VPLAN_PASS(VPlanTransforms::convertToAbstractRecipes, *Plan, CostCtx,
86128356
Range);
86138357
}
@@ -8902,11 +8646,7 @@ void LoopVectorizationPlanner::addReductionResultComputation(
89028646
VPBuilder PHBuilder(Plan->getVectorPreheader());
89038647
VPValue *Iden = Plan->getOrAddLiveIn(
89048648
getRecurrenceIdentity(RK, PhiTy, RdxDesc.getFastMathFlags()));
8905-
// If the PHI is used by a partial reduction, set the scale factor.
8906-
unsigned ScaleFactor =
8907-
RecipeBuilder.getScalingForReduction(RdxDesc.getLoopExitInstr())
8908-
.value_or(1);
8909-
auto *ScaleFactorVPV = Plan->getConstantInt(32, ScaleFactor);
8649+
auto *ScaleFactorVPV = Plan->getConstantInt(32, 1);
89108650
VPValue *StartV = PHBuilder.createNaryOp(
89118651
VPInstruction::ReductionStartVector,
89128652
{PhiR->getStartValue(), Iden, ScaleFactorVPV},

0 commit comments

Comments
 (0)