Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[CIR] Add cir-simplify pass #138317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 7, 2025
Merged

[CIR] Add cir-simplify pass #138317

merged 2 commits into from
May 7, 2025

Conversation

mmha
Copy link
Contributor

@mmha mmha commented May 2, 2025

This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.

This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
@mmha mmha requested review from erichkeane and andykaylor May 2, 2025 18:06
@mmha mmha requested review from lanza and bcardosolopes as code owners May 2, 2025 18:06
@llvmbot llvmbot added clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project labels May 2, 2025
@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-clang

Author: Morris Hafner (mmha)

Changes

This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.


Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138317.diff

16 Files Affected:

  • (modified) clang/include/clang/CIR/CIRToCIRPasses.h (+2-1)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+2)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+2)
  • (modified) clang/include/clang/CIR/Dialect/Passes.h (+1)
  • (modified) clang/include/clang/CIR/Dialect/Passes.td (+14)
  • (modified) clang/include/clang/CIR/FrontendAction/CIRGenAction.h (+1-1)
  • (modified) clang/include/clang/CIR/MissingFeatures.h (-1)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+30)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+1-2)
  • (added) clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (+184)
  • (modified) clang/lib/CIR/Dialect/Transforms/CMakeLists.txt (+1)
  • (modified) clang/lib/CIR/FrontendAction/CIRGenAction.cpp (+12-9)
  • (modified) clang/lib/CIR/Lowering/CIRPasses.cpp (+5-1)
  • (added) clang/test/CIR/Transforms/select.cir (+60)
  • (added) clang/test/CIR/Transforms/ternary-fold.cir (+60)
  • (modified) clang/tools/cir-opt/cir-opt.cpp (+3)
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirCtx,
                                       clang::ASTContext &astCtx,
-                                      bool enableVerifier);
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify);
 
 } // namespace cir
 
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..818a605ab74d3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 0;
   let useDefaultTypePrinterParser = 0;
 
+  let hasConstantMaterializer = 1;
+
   let extraClassDeclaration = [{
     static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
 
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9215543ab67e6..8205718e0fc30 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
       qualified(type($false_value))
     `)` `->` qualified(type($result)) attr-dict
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
 
 std::unique_ptr<Pass> createCIRCanonicalizePass();
 std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
 std::unique_ptr<Pass> createHoistAllocasPass();
 
 void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..46fa97da04ca1 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
   let dependentDialects = ["cir::CIRDialect"];
 }
 
+def CIRSimplify : Pass<"cir-simplify"> {
+  let summary = "Performs CIR simplification and code optimization";
+  let description = [{
+    The pass performs code simplification and optimization on CIR.
+
+    Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
+    transformations that could significantly affect CIR-to-source fidelity.
+    Example transformations performed in this pass include ternary folding,
+    code hoisting, etc.
+  }];
+  let constructor = "mlir::createCIRSimplifyPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
 def HoistAllocas : Pass<"cir-hoist-allocas"> {
   let summary = "Hoist allocas to the entry of the function";
   let description = [{
diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
index 99495f4718c5f..b52166b58b882 100644
--- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
+++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
 public:
   ~CIRGenAction() override;
 
-  OutputType Action;
+  OutputType action;
 };
 
 class EmitCIRAction : public CIRGenAction {
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 3db13278261e6..b26144095792d 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -199,7 +199,6 @@ struct MissingFeatures {
   static bool labelOp() { return false; }
   static bool ptrDiffOp() { return false; }
   static bool ptrStrideOp() { return false; }
-  static bool selectOp() { return false; }
   static bool switchOp() { return false; }
   static bool ternaryOp() { return false; }
   static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f5d6a424a71f6..5356630ece196 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
   addInterfaces<CIROpAsmDialectInterface>();
 }
 
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                mlir::Attribute value,
+                                                mlir::Type type,
+                                                mlir::Location loc) {
+  return builder.create<cir::ConstantOp>(loc, type,
+                                         mlir::cast<mlir::TypedAttr>(value));
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
     result.addTypes(TypeRange{yield.getOperandTypes().front()});
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+  mlir::Attribute condition = adaptor.getCondition();
+  if (condition) {
+    bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
+    return conditionValue ? getTrueValue() : getFalseValue();
+  }
+
+  // cir.select if %0 then x else x -> x
+  mlir::Attribute trueValue = adaptor.getTrueValue();
+  mlir::Attribute falseValue = adaptor.getFalseValue();
+  if (trueValue == falseValue)
+    return trueValue;
+  if (getTrueValue() == getFalseValue())
+    return getTrueValue();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShiftOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
   getOperation()->walk([&](Operation *op) {
     assert(!cir::MissingFeatures::switchOp());
     assert(!cir::MissingFeatures::tryOp());
-    assert(!cir::MissingFeatures::selectOp());
     assert(!cir::MissingFeatures::complexCreateOp());
     assert(!cir::MissingFeatures::complexRealOp());
     assert(!cir::MissingFeatures::complexImagOp());
     assert(!cir::MissingFeatures::callOp());
     // CastOp and UnaryOp are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
-    if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..442801d062638
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,184 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, both of the true and
+/// the false branch must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+///   %0 = cir.ternary (%condition, true {
+///     %1 = cir.const ...
+///     cir.yield %1
+///   } false {
+///     cir.yield %2
+///   })
+///
+/// into the following sequence of operations:
+///
+///   %1 = cir.const ...
+///   %0 = cir.select if %condition then %1 else %2
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+  using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TernaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return mlir::failure();
+
+    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+        !isSimpleTernaryBranch(op.getFalseRegion()))
+      return mlir::failure();
+
+    cir::YieldOp trueBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+    cir::YieldOp falseBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+    rewriter.eraseOp(trueBranchYieldOp);
+    rewriter.eraseOp(falseBranchYieldOp);
+    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+                                               falseValue);
+
+    return mlir::success();
+  }
+
+private:
+  bool isSimpleTernaryBranch(mlir::Region &region) const {
+    if (!region.hasOneBlock())
+      return false;
+
+    mlir::Block &onlyBlock = region.front();
+    mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+    // The region/block could only contain at most 2 operations.
+    if (ops.size() > 2)
+      return false;
+
+    if (ops.size() == 1) {
+      // The region/block only contain a cir.yield operation.
+      return true;
+    }
+
+    // Check whether the region/block contains a cir.const followed by a
+    // cir.yield that yields the value.
+    auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
+    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+        yieldOp.getArgs()[0].getDefiningOp());
+    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+  }
+};
+
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const final {
+    mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
+    mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
+    auto trueValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
+    auto falseValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
+    if (!trueValueConstOp || !falseValueConstOp)
+      return mlir::failure();
+
+    auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+    auto falseValue =
+        mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+    if (!trueValue || !falseValue)
+      return mlir::failure();
+
+    // cir.select if %0 then #true else #false -> %0
+    if (trueValue.getValue() && !falseValue.getValue()) {
+      rewriter.replaceAllUsesWith(op, op.getCondition());
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // cir.select if %0 then #false else #true -> cir.unary not %0
+    if (!trueValue.getValue() && falseValue.getValue()) {
+      rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+                                                op.getCondition());
+      return mlir::success();
+    }
+
+    return mlir::failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
+  using CIRSimplifyBase::CIRSimplifyBase;
+
+  void runOnOperation() override;
+};
+
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    SimplifyTernary,
+    SimplifySelect
+  >(patterns.getContext());
+  // clang-format on
+}
+
+void CIRSimplifyPass::runOnOperation() {
+  // Collect rewrite patterns.
+  RewritePatternSet patterns(&getContext());
+  populateMergeCleanupPatterns(patterns);
+
+  // Collect operations to apply patterns.
+  llvm::SmallVector<Operation *, 16> ops;
+  getOperation()->walk([&](Operation *op) {
+    if (isa<TernaryOp, SelectOp>(op))
+      ops.push_back(op);
+  });
+
+  // Apply patterns.
+  if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
+    signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+  return std::make_unique<CIRSimplifyPass>();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_clang_library(MLIRCIRTransforms
   CIRCanonicalize.cpp
+  CIRSimplify.cpp
   FlattenCFG.cpp
   HoistAllocas.cpp
 
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..570403dda9d9f 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
   IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
   std::unique_ptr<CIRGenerator> Gen;
   const FrontendOptions &FEOptions;
+  CodeGenOptions &codeGenOptions;
 
 public:
   CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
+                 CodeGenOptions &codeGenOptions,
                  std::unique_ptr<raw_pwrite_stream> OS)
       : Action(Action), CI(CI), OutputStream(std::move(OS)),
         FS(&CI.getVirtualFileSystem()),
         Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
                                            CI.getCodeGenOpts())),
-        FEOptions(CI.getFrontendOpts()) {}
+        FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {}
 
   void Initialize(ASTContext &Ctx) override {
     assert(!Context && "initialized multiple times");
@@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
     if (!FEOptions.ClangIRDisablePasses) {
       // Setup and run CIR pipeline.
       if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
-                            !FEOptions.ClangIRDisableCIRVerifier)
+                            !FEOptions.ClangIRDisableCIRVerifier,
+                            codeGenOptions.OptimizationLevel > 0)
               .failed()) {
         CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
         return;
@@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
 void CIRGenConsumer::anchor() {}
 
 CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)
-    : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {}
+    : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), action(Act) {}
 
 CIRGenAction::~CIRGenAction() { MLIRMod.release(); }
 
@@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
 }
 
 std::unique_ptr<ASTConsumer>
-CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
-  std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();
+CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) {
+  std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream();
 
-  if (!Out)
-    Out = getOutputStream(CI, InFile, Action);
+  if (!out)
+    out = getOutputStream(ci, inFile, action);
 
-  auto Result =
-      std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+  auto Result = std::make_unique<cir::CIRGenConsumer>(
+      action, ci, ci.getCodeGenOpts(), std::move(out));
 
   return Result;
 }
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirContext,
                                       clang::ASTContext &astContext,
-                                      bool enableVerifier) {
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify) {
 
   llvm::TimeTraceScope scope("CIR To CIR Passes");
 
   mlir::PassManager pm(&mlirContext);
   pm.addPass(mlir::createCIRCanonicalizePass());
 
+  if (enableCIRSimplify)
+    pm.addPass(mlir::createCIRSimplifyPass());
+
   pm.enableVerifier(enableVerifier);
   (void)mlir::applyPassManagerCLOptions(pm);
   return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<false> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG1]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+    %0 = cir.const #cir.int<42> : !s32i
+    %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.const #cir.int<42> : !s32i
+  // CHECK-NEXT:   cir.return %[[#A]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.const #cir.bool<false> : !cir.bool
+    %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+    cir.return %2 : !cir.bool
+  }
+
+  //      CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !cir.bool
+  // CHECK-NEXT: }
+
+  cir.func @simplify_2(%ar...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented May 2, 2025

@llvm/pr-subscribers-clangir

Author: Morris Hafner (mmha)

Changes

This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.


Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138317.diff

16 Files Affected:

  • (modified) clang/include/clang/CIR/CIRToCIRPasses.h (+2-1)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIRDialect.td (+2)
  • (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+2)
  • (modified) clang/include/clang/CIR/Dialect/Passes.h (+1)
  • (modified) clang/include/clang/CIR/Dialect/Passes.td (+14)
  • (modified) clang/include/clang/CIR/FrontendAction/CIRGenAction.h (+1-1)
  • (modified) clang/include/clang/CIR/MissingFeatures.h (-1)
  • (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+30)
  • (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+1-2)
  • (added) clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp (+184)
  • (modified) clang/lib/CIR/Dialect/Transforms/CMakeLists.txt (+1)
  • (modified) clang/lib/CIR/FrontendAction/CIRGenAction.cpp (+12-9)
  • (modified) clang/lib/CIR/Lowering/CIRPasses.cpp (+5-1)
  • (added) clang/test/CIR/Transforms/select.cir (+60)
  • (added) clang/test/CIR/Transforms/ternary-fold.cir (+60)
  • (modified) clang/tools/cir-opt/cir-opt.cpp (+3)
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirCtx,
                                       clang::ASTContext &astCtx,
-                                      bool enableVerifier);
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify);
 
 } // namespace cir
 
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..818a605ab74d3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
   let useDefaultAttributePrinterParser = 0;
   let useDefaultTypePrinterParser = 0;
 
+  let hasConstantMaterializer = 1;
+
   let extraClassDeclaration = [{
     static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
 
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9215543ab67e6..8205718e0fc30 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
       qualified(type($false_value))
     `)` `->` qualified(type($result)) attr-dict
   }];
+
+  let hasFolder = 1;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
 
 std::unique_ptr<Pass> createCIRCanonicalizePass();
 std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
 std::unique_ptr<Pass> createHoistAllocasPass();
 
 void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..46fa97da04ca1 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
   let dependentDialects = ["cir::CIRDialect"];
 }
 
+def CIRSimplify : Pass<"cir-simplify"> {
+  let summary = "Performs CIR simplification and code optimization";
+  let description = [{
+    The pass performs code simplification and optimization on CIR.
+
+    Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
+    transformations that could significantly affect CIR-to-source fidelity.
+    Example transformations performed in this pass include ternary folding,
+    code hoisting, etc.
+  }];
+  let constructor = "mlir::createCIRSimplifyPass()";
+  let dependentDialects = ["cir::CIRDialect"];
+}
+
 def HoistAllocas : Pass<"cir-hoist-allocas"> {
   let summary = "Hoist allocas to the entry of the function";
   let description = [{
diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
index 99495f4718c5f..b52166b58b882 100644
--- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
+++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
 public:
   ~CIRGenAction() override;
 
-  OutputType Action;
+  OutputType action;
 };
 
 class EmitCIRAction : public CIRGenAction {
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 3db13278261e6..b26144095792d 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -199,7 +199,6 @@ struct MissingFeatures {
   static bool labelOp() { return false; }
   static bool ptrDiffOp() { return false; }
   static bool ptrStrideOp() { return false; }
-  static bool selectOp() { return false; }
   static bool switchOp() { return false; }
   static bool ternaryOp() { return false; }
   static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f5d6a424a71f6..5356630ece196 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
   addInterfaces<CIROpAsmDialectInterface>();
 }
 
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+                                                mlir::Attribute value,
+                                                mlir::Type type,
+                                                mlir::Location loc) {
+  return builder.create<cir::ConstantOp>(loc, type,
+                                         mlir::cast<mlir::TypedAttr>(value));
+}
+
 //===----------------------------------------------------------------------===//
 // Helpers
 //===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
     result.addTypes(TypeRange{yield.getOperandTypes().front()});
 }
 
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+  mlir::Attribute condition = adaptor.getCondition();
+  if (condition) {
+    bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
+    return conditionValue ? getTrueValue() : getFalseValue();
+  }
+
+  // cir.select if %0 then x else x -> x
+  mlir::Attribute trueValue = adaptor.getTrueValue();
+  mlir::Attribute falseValue = adaptor.getFalseValue();
+  if (trueValue == falseValue)
+    return trueValue;
+  if (getTrueValue() == getFalseValue())
+    return getTrueValue();
+
+  return {};
+}
+
 //===----------------------------------------------------------------------===//
 // ShiftOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
   getOperation()->walk([&](Operation *op) {
     assert(!cir::MissingFeatures::switchOp());
     assert(!cir::MissingFeatures::tryOp());
-    assert(!cir::MissingFeatures::selectOp());
     assert(!cir::MissingFeatures::complexCreateOp());
     assert(!cir::MissingFeatures::complexRealOp());
     assert(!cir::MissingFeatures::complexImagOp());
     assert(!cir::MissingFeatures::callOp());
     // CastOp and UnaryOp are here to perform a manual `fold` in
     // applyOpPatternsGreedily.
-    if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+    if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
       ops.push_back(op);
   });
 
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..442801d062638
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,184 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, both of the true and
+/// the false branch must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+///   %0 = cir.ternary (%condition, true {
+///     %1 = cir.const ...
+///     cir.yield %1
+///   } false {
+///     cir.yield %2
+///   })
+///
+/// into the following sequence of operations:
+///
+///   %1 = cir.const ...
+///   %0 = cir.select if %condition then %1 else %2
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+  using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(TernaryOp op,
+                                PatternRewriter &rewriter) const override {
+    if (op->getNumResults() != 1)
+      return mlir::failure();
+
+    if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+        !isSimpleTernaryBranch(op.getFalseRegion()))
+      return mlir::failure();
+
+    cir::YieldOp trueBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+    cir::YieldOp falseBranchYieldOp =
+        mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+    mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+    mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+    rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+    rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+    rewriter.eraseOp(trueBranchYieldOp);
+    rewriter.eraseOp(falseBranchYieldOp);
+    rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+                                               falseValue);
+
+    return mlir::success();
+  }
+
+private:
+  bool isSimpleTernaryBranch(mlir::Region &region) const {
+    if (!region.hasOneBlock())
+      return false;
+
+    mlir::Block &onlyBlock = region.front();
+    mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+    // The region/block could only contain at most 2 operations.
+    if (ops.size() > 2)
+      return false;
+
+    if (ops.size() == 1) {
+      // The region/block only contain a cir.yield operation.
+      return true;
+    }
+
+    // Check whether the region/block contains a cir.const followed by a
+    // cir.yield that yields the value.
+    auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
+    auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+        yieldOp.getArgs()[0].getDefiningOp());
+    return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+  }
+};
+
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+  using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(SelectOp op,
+                                PatternRewriter &rewriter) const final {
+    mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
+    mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
+    auto trueValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
+    auto falseValueConstOp =
+        mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
+    if (!trueValueConstOp || !falseValueConstOp)
+      return mlir::failure();
+
+    auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+    auto falseValue =
+        mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+    if (!trueValue || !falseValue)
+      return mlir::failure();
+
+    // cir.select if %0 then #true else #false -> %0
+    if (trueValue.getValue() && !falseValue.getValue()) {
+      rewriter.replaceAllUsesWith(op, op.getCondition());
+      rewriter.eraseOp(op);
+      return mlir::success();
+    }
+
+    // cir.select if %0 then #false else #true -> cir.unary not %0
+    if (!trueValue.getValue() && falseValue.getValue()) {
+      rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+                                                op.getCondition());
+      return mlir::success();
+    }
+
+    return mlir::failure();
+  }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
+  using CIRSimplifyBase::CIRSimplifyBase;
+
+  void runOnOperation() override;
+};
+
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+  // clang-format off
+  patterns.add<
+    SimplifyTernary,
+    SimplifySelect
+  >(patterns.getContext());
+  // clang-format on
+}
+
+void CIRSimplifyPass::runOnOperation() {
+  // Collect rewrite patterns.
+  RewritePatternSet patterns(&getContext());
+  populateMergeCleanupPatterns(patterns);
+
+  // Collect operations to apply patterns.
+  llvm::SmallVector<Operation *, 16> ops;
+  getOperation()->walk([&](Operation *op) {
+    if (isa<TernaryOp, SelectOp>(op))
+      ops.push_back(op);
+  });
+
+  // Apply patterns.
+  if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
+    signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+  return std::make_unique<CIRSimplifyPass>();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_clang_library(MLIRCIRTransforms
   CIRCanonicalize.cpp
+  CIRSimplify.cpp
   FlattenCFG.cpp
   HoistAllocas.cpp
 
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..570403dda9d9f 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
   IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
   std::unique_ptr<CIRGenerator> Gen;
   const FrontendOptions &FEOptions;
+  CodeGenOptions &codeGenOptions;
 
 public:
   CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
+                 CodeGenOptions &codeGenOptions,
                  std::unique_ptr<raw_pwrite_stream> OS)
       : Action(Action), CI(CI), OutputStream(std::move(OS)),
         FS(&CI.getVirtualFileSystem()),
         Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
                                            CI.getCodeGenOpts())),
-        FEOptions(CI.getFrontendOpts()) {}
+        FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {}
 
   void Initialize(ASTContext &Ctx) override {
     assert(!Context && "initialized multiple times");
@@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
     if (!FEOptions.ClangIRDisablePasses) {
       // Setup and run CIR pipeline.
       if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
-                            !FEOptions.ClangIRDisableCIRVerifier)
+                            !FEOptions.ClangIRDisableCIRVerifier,
+                            codeGenOptions.OptimizationLevel > 0)
               .failed()) {
         CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
         return;
@@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
 void CIRGenConsumer::anchor() {}
 
 CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)
-    : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {}
+    : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), action(Act) {}
 
 CIRGenAction::~CIRGenAction() { MLIRMod.release(); }
 
@@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
 }
 
 std::unique_ptr<ASTConsumer>
-CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
-  std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();
+CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) {
+  std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream();
 
-  if (!Out)
-    Out = getOutputStream(CI, InFile, Action);
+  if (!out)
+    out = getOutputStream(ci, inFile, action);
 
-  auto Result =
-      std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+  auto Result = std::make_unique<cir::CIRGenConsumer>(
+      action, ci, ci.getCodeGenOpts(), std::move(out));
 
   return Result;
 }
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
 mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
                                       mlir::MLIRContext &mlirContext,
                                       clang::ASTContext &astContext,
-                                      bool enableVerifier) {
+                                      bool enableVerifier,
+                                      bool enableCIRSimplify) {
 
   llvm::TimeTraceScope scope("CIR To CIR Passes");
 
   mlir::PassManager pm(&mlirContext);
   pm.addPass(mlir::createCIRCanonicalizePass());
 
+  if (enableCIRSimplify)
+    pm.addPass(mlir::createCIRSimplifyPass());
+
   pm.enableVerifier(enableVerifier);
   (void)mlir::applyPassManagerCLOptions(pm);
   return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+    %0 = cir.const #cir.bool<false> : !cir.bool
+    %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+  // CHECK-NEXT:   cir.return %[[ARG1]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+    %0 = cir.const #cir.int<42> : !s32i
+    %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+    cir.return %1 : !s32i
+  }
+
+  //      CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+  // CHECK-NEXT:   %[[#A:]] = cir.const #cir.int<42> : !s32i
+  // CHECK-NEXT:   cir.return %[[#A]] : !s32i
+  // CHECK-NEXT: }
+
+  cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+    %0 = cir.const #cir.bool<true> : !cir.bool
+    %1 = cir.const #cir.bool<false> : !cir.bool
+    %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+    cir.return %2 : !cir.bool
+  }
+
+  //      CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+  // CHECK-NEXT:   cir.return %[[ARG0]] : !cir.bool
+  // CHECK-NEXT: }
+
+  cir.func @simplify_2(%ar...
[truncated]

Copy link
Collaborator

@erichkeane erichkeane left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not the best reviewer for this, but found 1 thing I'd like a comment on. Else LGTM pending others being ok with it.

@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;

let hasConstantMaterializer = 1;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh boy would i love a comment explaining this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's some documentation here: https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes

But yes I can add a small comment explaining that we need this for canonicalization.

The pass performs code simplification and optimization on CIR.

Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
transformations that could significantly affect CIR-to-source fidelity.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The phrase "could significantly affect CIR-to-source fidelity" is somewhat alarming. I don't think we're doing anything that changes semantics are we? I think this comment is about transformations that could affect ease of debugging. Perhaps we can word it more gently?

@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
public:
~CIRGenAction() override;

OutputType Action;
OutputType action;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file uses the normal clang coding style, so variable names should be capitalized.

@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
CodeGenOptions &codeGenOptions;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
CodeGenOptions &codeGenOptions;
CodeGenOptions &CodeGenOptions;

Keep clang-style identifier naming throughout this file.

/// %1 = cir.const ...
/// cir.yield %1
/// } false {
/// cir.yield %2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This example is confusing. The identifier %2 gives the impression that it was defined sometime after %0 but I don't think that's the intention. This will only happen if the false case returns a value that exists prior to the ternary, right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably look a bit better:

///   %0 = ...
///   %1 = cir.ternary (%condition, true {
///     %2 = cir.const ...
///     cir.yield %2
///   } false {
///     cir.yield %0

}
};

struct SimplifySelect : public OpRewritePattern<SelectOp> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be great if this had an explanatory comment like SimplifyTernary does. If there are going to be more cases added later, a general comment is fine. Otherwise, the two cases handled here can be explained pretty easily here.

Copy link
Member

@bcardosolopes bcardosolopes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM with minor tip: can you also add a source to CIR test that guarantee this kicks in from the driver?

/// %1 = cir.const ...
/// cir.yield %1
/// } false {
/// cir.yield %2
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would probably look a bit better:

///   %0 = ...
///   %1 = cir.ternary (%condition, true {
///     %2 = cir.const ...
///     cir.yield %2
///   } false {
///     cir.yield %0

@mmha
Copy link
Contributor Author

mmha commented May 6, 2025

@bcardosolopes Adding a test for C++ -> CIR depends on #138003 which adds lowering for ?:, && and ||. Since this one's ready to be merged apart from your remark IMO I'd add that test in #138003

Copy link
Contributor

@andykaylor andykaylor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm

@Andres-Salamanca
Copy link
Contributor

@bcardosolopes Adding a test for C++ -> CIR depends on #138003 which adds lowering for ?:, && and ||. Since this one's ready to be merged apart from your remark IMO I'd add that test in #138003

I think you're referring to this PR: #138156

@mmha mmha merged commit 2eb6545 into llvm:main May 7, 2025
11 checks passed
@llvm-ci
Copy link
Collaborator

llvm-ci commented May 7, 2025

LLVM Buildbot has detected a new failure on builder llvm-clang-x86_64-gcc-ubuntu running on sie-linux-worker3 while building clang at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/174/builds/17384

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'AddressSanitizer-x86_64-linux-dynamic :: TestCases/asan_lsan_deadlock.cpp' FAILED ********************
Exit Code: 1

Command Output (stderr):
--
/home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/./bin/clang  --driver-mode=g++ -fsanitize=address -mno-omit-leaf-frame-pointer -fno-omit-frame-pointer -fno-optimize-sibling-calls -gline-tables-only  -m64  -shared-libasan -O0 /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp -o /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/runtimes/runtimes-bins/compiler-rt/test/asan/X86_64LinuxDynamicConfig/TestCases/Output/asan_lsan_deadlock.cpp.tmp # RUN: at line 4
+ /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/./bin/clang --driver-mode=g++ -fsanitize=address -mno-omit-leaf-frame-pointer -fno-omit-frame-pointer -fno-optimize-sibling-calls -gline-tables-only -m64 -shared-libasan -O0 /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp -o /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/runtimes/runtimes-bins/compiler-rt/test/asan/X86_64LinuxDynamicConfig/TestCases/Output/asan_lsan_deadlock.cpp.tmp
env ASAN_OPTIONS=detect_leaks=1 not  /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/runtimes/runtimes-bins/compiler-rt/test/asan/X86_64LinuxDynamicConfig/TestCases/Output/asan_lsan_deadlock.cpp.tmp 2>&1 | FileCheck /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp # RUN: at line 5
+ FileCheck /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp
+ env ASAN_OPTIONS=detect_leaks=1 not /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/build/runtimes/runtimes-bins/compiler-rt/test/asan/X86_64LinuxDynamicConfig/TestCases/Output/asan_lsan_deadlock.cpp.tmp
�[1m/home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp:58:12: �[0m�[0;1;31merror: �[0m�[1mCHECK: expected string not found in input
�[0m // CHECK: SUMMARY: AddressSanitizer: stack-buffer-overflow
�[0;1;32m           ^
�[0m�[1m<stdin>:1:1: �[0m�[0;1;30mnote: �[0m�[1mscanning from here
�[0m=================================================================
�[0;1;32m^
�[0m�[1m<stdin>:2:10: �[0m�[0;1;30mnote: �[0m�[1mpossible intended match here
�[0m==1549046==ERROR: AddressSanitizer: stack-buffer-overflow on address 0x7b358cfde034 at pc 0x55f6fbee2220 bp 0x7b358b1fdce0 sp 0x7b358b1fdcd8
�[0;1;32m         ^
�[0m
Input file: <stdin>
Check file: /home/buildbot/buildbot-root/llvm-clang-x86_64-gcc-ubuntu/llvm-project/compiler-rt/test/asan/TestCases/asan_lsan_deadlock.cpp

-dump-input=help explains the following input dump.

Input was:
<<<<<<
�[1m�[0m�[0;1;30m            1: �[0m�[1m�[0;1;46m================================================================= �[0m
�[0;1;31mcheck:58'0     X~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ error: no match found
�[0m�[0;1;30m            2: �[0m�[1m�[0;1;46m==1549046==ERROR: AddressSanitizer: stack-buffer-overflow on address 0x7b358cfde034 at pc 0x55f6fbee2220 bp 0x7b358b1fdce0 sp 0x7b358b1fdcd8 �[0m
�[0;1;31mcheck:58'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
�[0m�[0;1;35mcheck:58'1              ?                                                                                                                                    possible intended match
�[0m�[0;1;30m            3: �[0m�[1m�[0;1;46mWRITE of size 4 at 0x7b358cfde034 thread T2 �[0m
�[0;1;31mcheck:58'0     ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
�[0m>>>>>>

--

********************


@llvm-ci
Copy link
Collaborator

llvm-ci commented May 7, 2025

LLVM Buildbot has detected a new failure on builder clang-ppc64-aix running on aix-ppc64 while building clang at step 6 "test-build-unified-tree-check-all".

Full details are available at: https://lab.llvm.org/buildbot/#/builders/64/builds/3451

Here is the relevant piece of the build log for the reference
Step 6 (test-build-unified-tree-check-all) failure: test (failure)
******************** TEST 'lit :: timeout-hang.py' FAILED ********************
Exit Code: 1

Command Output (stdout):
--
# RUN: at line 13
not env -u FILECHECK_OPTS "/home/llvm/llvm-external-buildbots/workers/env/bin/python3.11" /home/llvm/llvm-external-buildbots/workers/aix-ppc64/clang-ppc64-aix/llvm-project/llvm/utils/lit/lit.py -j1 --order=lexical Inputs/timeout-hang/run-nonexistent.txt  --timeout=1 --param external=0 | "/home/llvm/llvm-external-buildbots/workers/env/bin/python3.11" /home/llvm/llvm-external-buildbots/workers/aix-ppc64/clang-ppc64-aix/build/utils/lit/tests/timeout-hang.py 1
# executed command: not env -u FILECHECK_OPTS /home/llvm/llvm-external-buildbots/workers/env/bin/python3.11 /home/llvm/llvm-external-buildbots/workers/aix-ppc64/clang-ppc64-aix/llvm-project/llvm/utils/lit/lit.py -j1 --order=lexical Inputs/timeout-hang/run-nonexistent.txt --timeout=1 --param external=0
# .---command stderr------------
# | lit.py: /home/llvm/llvm-external-buildbots/workers/aix-ppc64/clang-ppc64-aix/llvm-project/llvm/utils/lit/lit/main.py:72: note: The test suite configuration requested an individual test timeout of 0 seconds but a timeout of 1 seconds was requested on the command line. Forcing timeout to be 1 seconds.
# `-----------------------------
# executed command: /home/llvm/llvm-external-buildbots/workers/env/bin/python3.11 /home/llvm/llvm-external-buildbots/workers/aix-ppc64/clang-ppc64-aix/build/utils/lit/tests/timeout-hang.py 1
# .---command stdout------------
# | Testing took as long or longer than timeout
# `-----------------------------
# error: command failed with exit status: 1

--

********************


Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang Clang issues not falling into any other category ClangIR Anything related to the ClangIR project
Projects
None yet
Development

Successfully merging this pull request may close these issues.

7 participants