diff --git a/src/coreclr/jit/optimizebools.cpp b/src/coreclr/jit/optimizebools.cpp index 580a64d2e6727f..7f44a53407b7b9 100644 --- a/src/coreclr/jit/optimizebools.cpp +++ b/src/coreclr/jit/optimizebools.cpp @@ -66,6 +66,7 @@ class OptBoolsDsc public: bool optOptimizeBoolsCondBlock(); bool optOptimizeCompareChainCondBlock(); + bool optOptimizeRangeTests(); bool optOptimizeBoolsReturnBlock(BasicBlock* b3); #ifdef DEBUG void optOptimizeBoolsGcStress(); @@ -404,6 +405,337 @@ bool OptBoolsDsc::FindCompareChain(GenTree* condition, bool* isTestCondition) return false; } +//------------------------------------------------------------------------------ +// GetIntersection: Given two ranges, return true if they intersect and form a closed range. +// Examples: +// >10 and <=20 -> [11,20] +// >10 and >100 -> false +// <10 and >10 -> false +// +// Arguments: +// type - The type of the compare nodes. +// cmp1 - The first compare node. +// cmp2 - The second compare node. +// cns1 - The constant value of the first compare node (always RHS). +// cns2 - The constant value of the second compare node (always RHS). +// pRangeStart - [OUT] The start of the intersection range (inclusive). +// pRangeEnd - [OUT] The end of the intersection range (inclusive). +// +// Returns: +// true if the ranges intersect and form a closed range. +// +static bool GetIntersection(var_types type, + genTreeOps cmp1, + genTreeOps cmp2, + ssize_t cns1, + ssize_t cns2, + ssize_t* pRangeStart, + ssize_t* pRangeEnd) +{ + if ((cns1 < 0) || (cns2 < 0)) + { + // We don't yet support negative ranges. + return false; + } + + // Convert to a canonical form with GT_GE or GT_LE (inclusive). + auto normalize = [](genTreeOps* cmp, ssize_t* cns) { + if (*cmp == GT_GT) + { + // "X > cns" -> "X >= cns + 1" + *cns = *cns + 1; + *cmp = GT_GE; + } + if (*cmp == GT_LT) + { + // "X < cns" -> "X <= cns - 1" + *cns = *cns - 1; + *cmp = GT_LE; + } + // whether these overflow or not is checked below. + }; + normalize(&cmp1, &cns1); + normalize(&cmp2, &cns2); + + if (cmp1 == cmp2) + { + // Ranges have the same direction (we don't yet support that yet). + return false; + } + + if (cmp1 == GT_GE) + { + *pRangeStart = cns1; + *pRangeEnd = cns2; + } + else + { + assert(cmp1 == GT_LE); + *pRangeStart = cns2; + *pRangeEnd = cns1; + } + + if ((*pRangeStart >= *pRangeEnd) || (*pRangeStart < 0) || (*pRangeEnd < 0) || !FitsIn(type, *pRangeStart) || + !FitsIn(type, *pRangeEnd)) + { + // TODO: If ranges don't intersect we might be able to fold the condition to true/false. + // Also, check again if any of the ranges are negative (in case of overflow after normalization) + // and fits into the given type. + return false; + } + + return true; +} + +//------------------------------------------------------------------------------ +// IsConstantRangeTest: Does the given compare node represent a constant range test? E.g. +// "X relop CNS" or "CNS relop X" where relop is [<, <=, >, >=] +// +// Arguments: +// tree - compare node +// varNode - [OUT] this will be set to the variable part of the constant range test +// cnsNode - [OUT] this will be set to the constant part of the constant range test +// cmp - [OUT] this will be set to a normalized compare operator so that the constant +// is always on the right hand side of the compare. +// +// Returns: +// true if the compare node represents a constant range test. +// +bool IsConstantRangeTest(GenTreeOp* tree, GenTree** varNode, GenTreeIntCon** cnsNode, genTreeOps* cmp) +{ + if (tree->OperIs(GT_LE, GT_LT, GT_GE, GT_GT) && !tree->IsUnsigned()) + { + GenTree* op1 = tree->gtGetOp1(); + GenTree* op2 = tree->gtGetOp2(); + if (varTypeIsIntegral(op1) && varTypeIsIntegral(op2) && op1->TypeIs(op2->TypeGet())) + { + if (op2->IsCnsIntOrI()) + { + // X relop CNS + *varNode = op1; + *cnsNode = op2->AsIntCon(); + *cmp = tree->OperGet(); + return true; + } + if (op1->IsCnsIntOrI()) + { + // CNS relop X + *varNode = op2; + *cnsNode = op1->AsIntCon(); + + // Normalize to "X relop CNS" + *cmp = GenTree::SwapRelop(tree->OperGet()); + return true; + } + } + } + return false; +} + +//------------------------------------------------------------------------------ +// FoldRangeTests: Given two compare nodes (cmp1 && cmp2) that represent a range check, +// fold them into a single compare node if possible, e.g.: +// 1) "X >= 10 && X <= 100" -> "(X - 10) u<= 90" +// 2) "X >= 0 && X <= 100" -> "X u<= 100" +// where 'u' stands for unsigned comparison. cmp1 is used as the target node for folding. +// It's also guaranteed to be first in the execution order (so can allow some side effects). +// +// Arguments: +// compiler - compiler instance +// cmp1 - first compare node +// cmp1IsReversed - true if cmp1 is in fact reversed +// cmp2 - second compare node +// cmp2IsReversed - true if cmp2 is in fact reversed +// +// Returns: +// true if cmp1 now represents the folded range check and cmp2 can be removed. +// +bool FoldRangeTests(Compiler* comp, GenTreeOp* cmp1, bool cmp1IsReversed, GenTreeOp* cmp2, bool cmp2IsReversed) +{ + GenTree* var1Node; + GenTree* var2Node; + GenTreeIntCon* cns1Node; + GenTreeIntCon* cns2Node; + genTreeOps cmp1Op; + genTreeOps cmp2Op; + + // Make sure both conditions are constant range checks, e.g. "X > CNS" + // TODO: support more cases, e.g. "X >= 0 && X < array.Length" -> "(uint)X < array.Length" + // Basically, we can use GenTree::IsNeverNegative() for it. + if (!IsConstantRangeTest(cmp1, &var1Node, &cns1Node, &cmp1Op) || + !IsConstantRangeTest(cmp2, &var2Node, &cns2Node, &cmp2Op)) + { + return false; + } + + // Reverse the comparisons if necessary so we'll get a canonical form "cond1 == true && cond2 == true" -> InRange. + cmp1Op = cmp1IsReversed ? GenTree::ReverseRelop(cmp1Op) : cmp1Op; + cmp2Op = cmp2IsReversed ? GenTree::ReverseRelop(cmp2Op) : cmp2Op; + + // Make sure variables are the same: + if (!var2Node->OperIs(GT_LCL_VAR) || !GenTree::Compare(var1Node->gtEffectiveVal(), var2Node)) + { + // Variables don't match in two conditions + // We use gtEffectiveVal() for the first block's variable to ignore COMMAs, e.g. + // + // m_b1: + // * JTRUE void + // \--* LT int + // +--* COMMA int + // | +--* STORE_LCL_VAR int V03 cse0 + // | | \--* CAST int <- ushort <- int + // | | \--* LCL_VAR int V01 arg1 + // | \--* LCL_VAR int V03 cse0 + // \--* CNS_INT int 97 + // + // m_b2: + // * JTRUE void + // \--* GT int + // +--* LCL_VAR int V03 cse0 + // \--* CNS_INT int 122 + // + // For the m_b2 we require the variable to be just a local with no side-effects (hence, no statements) + return false; + } + + ssize_t rangeStart; + ssize_t rangeEnd; + if (!GetIntersection(var1Node->TypeGet(), cmp1Op, cmp2Op, cns1Node->IconValue(), cns2Node->IconValue(), &rangeStart, + &rangeEnd)) + { + // The range we test via two conditions is not a closed range + // TODO: We should support overlapped ranges here, e.g. "X > 10 && x > 100" -> "X > 100" + return false; + } + assert(rangeStart < rangeEnd); + + if (rangeStart == 0) + { + // We don't need to subtract anything, it's already 0-based + cmp1->gtOp1 = var1Node; + } + else + { + // We need to subtract the rangeStartIncl from the variable to make the range start from 0 + cmp1->gtOp1 = comp->gtNewOperNode(GT_SUB, var1Node->TypeGet(), var1Node, + comp->gtNewIconNode(rangeStart, var1Node->TypeGet())); + } + cmp1->gtOp2->BashToConst(rangeEnd - rangeStart, var1Node->TypeGet()); + cmp1->SetOper(cmp2IsReversed ? GT_GT : GT_LE); + cmp1->SetUnsigned(); + return true; +} + +//------------------------------------------------------------------------------ +// optOptimizeRangeTests : Optimize two conditional blocks representing a constant range test. +// E.g. "X >= 10 && X <= 100" is optimized to "(X - 10) <= 90". +// +// Return Value: +// True if m_b1 and m_b2 are merged. +// +bool OptBoolsDsc::optOptimizeRangeTests() +{ + // At this point we have two consecutive conditional blocks (BBJ_COND): m_b1 and m_b2 + assert((m_b1 != nullptr) && (m_b2 != nullptr) && (m_b3 == nullptr)); + assert(m_b1->KindIs(BBJ_COND) && m_b2->KindIs(BBJ_COND) && m_b1->NextIs(m_b2)); + + if (m_b2->isRunRarely()) + { + // We don't want to make the first comparison to be slightly slower + // if the 2nd one is rarely executed. + return false; + } + + if (!BasicBlock::sameEHRegion(m_b1, m_b2) || ((m_b2->bbFlags & BBF_DONT_REMOVE) != 0)) + { + // Conditions aren't in the same EH region or m_b2 can't be removed + return false; + } + + if (m_b1->HasJumpTo(m_b1) || m_b1->HasJumpTo(m_b2) || m_b2->HasJumpTo(m_b2) || m_b2->HasJumpTo(m_b1)) + { + // Ignoring weird cases like a condition jumping to itself or when JumpDest == Next + return false; + } + + // We're interested in just two shapes for e.g. "X > 10 && X < 100" range test: + // + BasicBlock* notInRangeBb = m_b1->GetJumpDest(); + BasicBlock* inRangeBb; + if (notInRangeBb == m_b2->GetJumpDest()) + { + // Shape 1: both conditions jump to NotInRange + // + // if (X <= 10) + // goto NotInRange; + // + // if (X >= 100) + // goto NotInRange + // + // InRange: + // ... + inRangeBb = m_b2->Next(); + } + else if (notInRangeBb == m_b2->Next()) + { + // Shape 2: 2nd block jumps to InRange + // + // if (X <= 10) + // goto NotInRange; + // + // if (X > 100) + // goto InRange + // + // NotInRange: + // ... + inRangeBb = m_b2->GetJumpDest(); + } + else + { + // Unknown shape + return false; + } + + if (!m_b2->hasSingleStmt() || (m_b2->GetUniquePred(m_comp) != m_b1)) + { + // The 2nd block has to be single-statement to avoid side-effects between the two conditions. + // Also, make sure m_b2 has no other predecessors. + return false; + } + + // m_b1 and m_b2 are both BBJ_COND blocks with GT_JTRUE(cmp) root nodes + GenTreeOp* cmp1 = m_b1->lastStmt()->GetRootNode()->gtGetOp1()->AsOp(); + GenTreeOp* cmp2 = m_b2->lastStmt()->GetRootNode()->gtGetOp1()->AsOp(); + + // cmp1 is always reversed (see shape1 and shape2 above) + const bool cmp1IsReversed = true; + + // cmp2 can be either reversed or not + const bool cmp2IsReversed = m_b2->HasJumpTo(notInRangeBb); + + if (!FoldRangeTests(m_comp, cmp1, cmp1IsReversed, cmp2, cmp2IsReversed)) + { + return false; + } + + m_comp->fgAddRefPred(inRangeBb, m_b1); + if (!cmp2IsReversed) + { + // Re-direct firstBlock to jump to inRangeBb + m_b1->SetJumpDest(inRangeBb); + } + + // Remove the 2nd condition block as we no longer need it + m_comp->fgRemoveRefPred(m_b2, m_b1); + m_comp->fgRemoveBlock(m_b2, true); + + Statement* stmt = m_b1->lastStmt(); + m_comp->gtSetStmtInfo(stmt); + m_comp->fgSetStmtSeq(stmt); + m_comp->gtUpdateStmtSideEffects(stmt); + return true; +} + //----------------------------------------------------------------------------- // optOptimizeCompareChainCondBlock: Create a chain when when both m_b1 and m_b2 are BBJ_COND. // @@ -1506,6 +1838,12 @@ PhaseStatus Compiler::optOptimizeBools() change = true; numCond++; } + else if (optBoolsDsc.optOptimizeRangeTests()) + { + change = true; + retry = true; + numCond++; + } #ifdef TARGET_ARM64 else if (optBoolsDsc.optOptimizeCompareChainCondBlock()) {