diff --git a/llvm/include/llvm/CodeGen/TargetLowering.h b/llvm/include/llvm/CodeGen/TargetLowering.h index 00c36266a069f..c8f6967f7ab41 100644 --- a/llvm/include/llvm/CodeGen/TargetLowering.h +++ b/llvm/include/llvm/CodeGen/TargetLowering.h @@ -3353,8 +3353,10 @@ class TargetLoweringBase { /// Return true if pulling a binary operation into a select with an identity /// constant is profitable. This is the inverse of an IR transform. /// Example: X + (Cond ? Y : 0) --> Cond ? (X + Y) : X - virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const { + virtual bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, + unsigned SelectOpcode, + SDValue X, + SDValue Y) const { return false; } diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp index b175e35385ec6..7c8619aa29346 100644 --- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp +++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp @@ -2425,8 +2425,9 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, if (ShouldCommuteOperands) std::swap(N0, N1); - // TODO: Should this apply to scalar select too? - if (N1.getOpcode() != ISD::VSELECT || !N1.hasOneUse()) + unsigned SelOpcode = N1.getOpcode(); + if ((SelOpcode != ISD::VSELECT && SelOpcode != ISD::SELECT) || + !N1.hasOneUse()) return SDValue(); // We can't hoist all instructions because of immediate UB (not speculatable). @@ -2439,17 +2440,22 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, SDValue Cond = N1.getOperand(0); SDValue TVal = N1.getOperand(1); SDValue FVal = N1.getOperand(2); + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); // This transform increases uses of N0, so freeze it to be safe. // binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal) unsigned OpNo = ShouldCommuteOperands ? 0 : 1; - if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo)) { + if (isNeutralConstant(Opcode, N->getFlags(), TVal, OpNo) && + TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0, + FVal)) { SDValue F0 = DAG.getFreeze(N0); SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags()); return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO); } // binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0 - if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo)) { + if (isNeutralConstant(Opcode, N->getFlags(), FVal, OpNo) && + TLI.shouldFoldSelectWithIdentityConstant(Opcode, VT, SelOpcode, N0, + TVal)) { SDValue F0 = DAG.getFreeze(N0); SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags()); return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0); @@ -2459,26 +2465,23 @@ static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG, } SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); assert(TLI.isBinOp(BO->getOpcode()) && BO->getNumValues() == 1 && "Unexpected binary operator"); - const TargetLowering &TLI = DAG.getTargetLoweringInfo(); - auto BinOpcode = BO->getOpcode(); - EVT VT = BO->getValueType(0); - if (TLI.shouldFoldSelectWithIdentityConstant(BinOpcode, VT)) { - if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false)) - return Sel; + if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, false)) + return Sel; - if (TLI.isCommutativeBinOp(BO->getOpcode())) - if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true)) - return Sel; - } + if (TLI.isCommutativeBinOp(BO->getOpcode())) + if (SDValue Sel = foldSelectWithIdentityConstant(BO, DAG, true)) + return Sel; // Don't do this unless the old select is going away. We want to eliminate the // binary operator, not replace a binop with a select. // TODO: Handle ISD::SELECT_CC. unsigned SelOpNo = 0; SDValue Sel = BO->getOperand(0); + auto BinOpcode = BO->getOpcode(); if (Sel.getOpcode() != ISD::SELECT || !Sel.hasOneUse()) { SelOpNo = 1; Sel = BO->getOperand(1); @@ -2526,6 +2529,7 @@ SDValue DAGCombiner::foldBinOpIntoSelect(SDNode *BO) { SDLoc DL(Sel); SDValue NewCT, NewCF; + EVT VT = BO->getValueType(0); if (CanFoldNonConst) { // If CBO is an opaque constant, we can't rely on getNode to constant fold. diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 771eee1b3fecf..0a53a56006ded 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -18040,8 +18040,10 @@ bool AArch64TargetLowering::shouldFoldConstantShiftPairToMask( } bool AArch64TargetLowering::shouldFoldSelectWithIdentityConstant( - unsigned BinOpcode, EVT VT) const { - return VT.isScalableVector() && isTypeLegal(VT); + unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, + SDValue Y) const { + return VT.isScalableVector() && isTypeLegal(VT) && + SelectOpcode == ISD::VSELECT; } bool AArch64TargetLowering::shouldConvertConstantLoadToIntImm(const APInt &Imm, diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.h b/llvm/lib/Target/AArch64/AArch64ISelLowering.h index 0d51ef2be8631..2371505a5fc99 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.h +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.h @@ -786,8 +786,9 @@ class AArch64TargetLowering : public TargetLowering { bool shouldFoldConstantShiftPairToMask(const SDNode *N, CombineLevel Level) const override; - bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const override; + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, + unsigned SelectOpcode, SDValue X, + SDValue Y) const override; /// Returns true if it is beneficial to convert a load of a constant /// to just the constant itself. diff --git a/llvm/lib/Target/ARM/ARMISelLowering.cpp b/llvm/lib/Target/ARM/ARMISelLowering.cpp index 2290ac2728c6d..7f418e3f5f857 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.cpp +++ b/llvm/lib/Target/ARM/ARMISelLowering.cpp @@ -13960,9 +13960,11 @@ bool ARMTargetLowering::shouldFoldConstantShiftPairToMask( return false; } -bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const { - return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT); +bool ARMTargetLowering::shouldFoldSelectWithIdentityConstant( + unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, + SDValue Y) const { + return Subtarget->hasMVEIntegerOps() && isTypeLegal(VT) && + SelectOpcode == ISD::VSELECT; } bool ARMTargetLowering::preferIncOfAddToSubOfNot(EVT VT) const { diff --git a/llvm/lib/Target/ARM/ARMISelLowering.h b/llvm/lib/Target/ARM/ARMISelLowering.h index 9fad056edd3f1..87710ee29a249 100644 --- a/llvm/lib/Target/ARM/ARMISelLowering.h +++ b/llvm/lib/Target/ARM/ARMISelLowering.h @@ -758,8 +758,9 @@ class VectorType; bool shouldFoldConstantShiftPairToMask(const SDNode *N, CombineLevel Level) const override; - bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const override; + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, + unsigned SelectOpcode, SDValue X, + SDValue Y) const override; bool preferIncOfAddToSubOfNot(EVT VT) const override; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp index 98fba9e86e88a..907b817490bd6 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -2090,8 +2090,12 @@ bool RISCVTargetLowering::hasBitTest(SDValue X, SDValue Y) const { return C && C->getAPIntValue().ule(10); } -bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode, - EVT VT) const { +bool RISCVTargetLowering::shouldFoldSelectWithIdentityConstant( + unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, + SDValue Y) const { + if (SelectOpcode != ISD::VSELECT) + return false; + // Only enable for rvv. if (!VT.isVector() || !Subtarget.hasVInstructions()) return false; diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h index baf1b2e4d8e6e..e291505564945 100644 --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -585,8 +585,9 @@ class RISCVTargetLowering : public TargetLowering { unsigned &NumIntermediates, MVT &RegisterVT) const override; - bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const override; + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, + unsigned SelectOpcode, SDValue X, + SDValue Y) const override; /// Return true if the given shuffle mask can be codegen'd directly, or if it /// should be stack expanded. diff --git a/llvm/lib/Target/X86/X86ISelLowering.cpp b/llvm/lib/Target/X86/X86ISelLowering.cpp index 993118c52564e..9293f4e74399c 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.cpp +++ b/llvm/lib/Target/X86/X86ISelLowering.cpp @@ -35383,8 +35383,11 @@ bool X86TargetLowering::isNarrowingProfitable(SDNode *N, EVT SrcVT, return !(SrcVT == MVT::i32 && DestVT == MVT::i16); } -bool X86TargetLowering::shouldFoldSelectWithIdentityConstant(unsigned Opcode, - EVT VT) const { +bool X86TargetLowering::shouldFoldSelectWithIdentityConstant( + unsigned BinOpcode, EVT VT, unsigned SelectOpcode, SDValue X, + SDValue Y) const { + if (SelectOpcode != ISD::VSELECT) + return false; // TODO: This is too general. There are cases where pre-AVX512 codegen would // benefit. The transform may also be profitable for scalar code. if (!Subtarget.hasAVX512()) diff --git a/llvm/lib/Target/X86/X86ISelLowering.h b/llvm/lib/Target/X86/X86ISelLowering.h index 4a2b35e9efe7c..f5dafedfc5464 100644 --- a/llvm/lib/Target/X86/X86ISelLowering.h +++ b/llvm/lib/Target/X86/X86ISelLowering.h @@ -1460,8 +1460,9 @@ namespace llvm { /// from i32 to i16. bool isNarrowingProfitable(SDNode *N, EVT SrcVT, EVT DestVT) const override; - bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, - EVT VT) const override; + bool shouldFoldSelectWithIdentityConstant(unsigned BinOpcode, EVT VT, + unsigned SelectOpcode, SDValue X, + SDValue Y) const override; /// Given an intrinsic, checks if on the target the intrinsic will need to map /// to a MemIntrinsicNode (touches memory). If this is the case, it returns