-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[CIR] Add cir-simplify pass #138317
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[CIR] Add cir-simplify pass #138317
Conversation
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.
@llvm/pr-subscribers-clang Author: Morris Hafner (mmha) ChangesThis patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect. Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138317.diff 16 Files Affected:
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
- bool enableVerifier);
+ bool enableVerifier,
+ bool enableCIRSimplify);
} // namespace cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..818a605ab74d3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
+ let hasConstantMaterializer = 1;
+
let extraClassDeclaration = [{
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9215543ab67e6..8205718e0fc30 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..46fa97da04ca1 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def CIRSimplify : Pass<"cir-simplify"> {
+ let summary = "Performs CIR simplification and code optimization";
+ let description = [{
+ The pass performs code simplification and optimization on CIR.
+
+ Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
+ transformations that could significantly affect CIR-to-source fidelity.
+ Example transformations performed in this pass include ternary folding,
+ code hoisting, etc.
+ }];
+ let constructor = "mlir::createCIRSimplifyPass()";
+ let dependentDialects = ["cir::CIRDialect"];
+}
+
def HoistAllocas : Pass<"cir-hoist-allocas"> {
let summary = "Hoist allocas to the entry of the function";
let description = [{
diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
index 99495f4718c5f..b52166b58b882 100644
--- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
+++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
public:
~CIRGenAction() override;
- OutputType Action;
+ OutputType action;
};
class EmitCIRAction : public CIRGenAction {
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 3db13278261e6..b26144095792d 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -199,7 +199,6 @@ struct MissingFeatures {
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
static bool ptrStrideOp() { return false; }
- static bool selectOp() { return false; }
static bool switchOp() { return false; }
static bool ternaryOp() { return false; }
static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f5d6a424a71f6..5356630ece196 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>();
}
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+ mlir::Attribute value,
+ mlir::Type type,
+ mlir::Location loc) {
+ return builder.create<cir::ConstantOp>(loc, type,
+ mlir::cast<mlir::TypedAttr>(value));
+}
+
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute condition = adaptor.getCondition();
+ if (condition) {
+ bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
+ return conditionValue ? getTrueValue() : getFalseValue();
+ }
+
+ // cir.select if %0 then x else x -> x
+ mlir::Attribute trueValue = adaptor.getTrueValue();
+ mlir::Attribute falseValue = adaptor.getFalseValue();
+ if (trueValue == falseValue)
+ return trueValue;
+ if (getTrueValue() == getFalseValue())
+ return getTrueValue();
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..442801d062638
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,184 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, both of the true and
+/// the false branch must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+/// %0 = cir.ternary (%condition, true {
+/// %1 = cir.const ...
+/// cir.yield %1
+/// } false {
+/// cir.yield %2
+/// })
+///
+/// into the following sequence of operations:
+///
+/// %1 = cir.const ...
+/// %0 = cir.select if %condition then %1 else %2
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+ using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TernaryOp op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumResults() != 1)
+ return mlir::failure();
+
+ if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+ !isSimpleTernaryBranch(op.getFalseRegion()))
+ return mlir::failure();
+
+ cir::YieldOp trueBranchYieldOp =
+ mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+ cir::YieldOp falseBranchYieldOp =
+ mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+ mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+ mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+ rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+ rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+ rewriter.eraseOp(trueBranchYieldOp);
+ rewriter.eraseOp(falseBranchYieldOp);
+ rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+ falseValue);
+
+ return mlir::success();
+ }
+
+private:
+ bool isSimpleTernaryBranch(mlir::Region ®ion) const {
+ if (!region.hasOneBlock())
+ return false;
+
+ mlir::Block &onlyBlock = region.front();
+ mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+ // The region/block could only contain at most 2 operations.
+ if (ops.size() > 2)
+ return false;
+
+ if (ops.size() == 1) {
+ // The region/block only contain a cir.yield operation.
+ return true;
+ }
+
+ // Check whether the region/block contains a cir.const followed by a
+ // cir.yield that yields the value.
+ auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
+ auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+ yieldOp.getArgs()[0].getDefiningOp());
+ return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+ }
+};
+
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+ using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SelectOp op,
+ PatternRewriter &rewriter) const final {
+ mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
+ mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
+ auto trueValueConstOp =
+ mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
+ auto falseValueConstOp =
+ mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
+ if (!trueValueConstOp || !falseValueConstOp)
+ return mlir::failure();
+
+ auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+ auto falseValue =
+ mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+ if (!trueValue || !falseValue)
+ return mlir::failure();
+
+ // cir.select if %0 then #true else #false -> %0
+ if (trueValue.getValue() && !falseValue.getValue()) {
+ rewriter.replaceAllUsesWith(op, op.getCondition());
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // cir.select if %0 then #false else #true -> cir.unary not %0
+ if (!trueValue.getValue() && falseValue.getValue()) {
+ rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+ op.getCondition());
+ return mlir::success();
+ }
+
+ return mlir::failure();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
+ using CIRSimplifyBase::CIRSimplifyBase;
+
+ void runOnOperation() override;
+};
+
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+ // clang-format off
+ patterns.add<
+ SimplifyTernary,
+ SimplifySelect
+ >(patterns.getContext());
+ // clang-format on
+}
+
+void CIRSimplifyPass::runOnOperation() {
+ // Collect rewrite patterns.
+ RewritePatternSet patterns(&getContext());
+ populateMergeCleanupPatterns(patterns);
+
+ // Collect operations to apply patterns.
+ llvm::SmallVector<Operation *, 16> ops;
+ getOperation()->walk([&](Operation *op) {
+ if (isa<TernaryOp, SelectOp>(op))
+ ops.push_back(op);
+ });
+
+ // Apply patterns.
+ if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
+ signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+ return std::make_unique<CIRSimplifyPass>();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
+ CIRSimplify.cpp
FlattenCFG.cpp
HoistAllocas.cpp
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..570403dda9d9f 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
+ CodeGenOptions &codeGenOptions;
public:
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
+ CodeGenOptions &codeGenOptions,
std::unique_ptr<raw_pwrite_stream> OS)
: Action(Action), CI(CI), OutputStream(std::move(OS)),
FS(&CI.getVirtualFileSystem()),
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
CI.getCodeGenOpts())),
- FEOptions(CI.getFrontendOpts()) {}
+ FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {}
void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
@@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
if (!FEOptions.ClangIRDisablePasses) {
// Setup and run CIR pipeline.
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
- !FEOptions.ClangIRDisableCIRVerifier)
+ !FEOptions.ClangIRDisableCIRVerifier,
+ codeGenOptions.OptimizationLevel > 0)
.failed()) {
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
return;
@@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
void CIRGenConsumer::anchor() {}
CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)
- : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {}
+ : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), action(Act) {}
CIRGenAction::~CIRGenAction() { MLIRMod.release(); }
@@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
}
std::unique_ptr<ASTConsumer>
-CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
- std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();
+CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) {
+ std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream();
- if (!Out)
- Out = getOutputStream(CI, InFile, Action);
+ if (!out)
+ out = getOutputStream(ci, inFile, action);
- auto Result =
- std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+ auto Result = std::make_unique<cir::CIRGenConsumer>(
+ action, ci, ci.getCodeGenOpts(), std::move(out));
return Result;
}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirContext,
clang::ASTContext &astContext,
- bool enableVerifier) {
+ bool enableVerifier,
+ bool enableCIRSimplify) {
llvm::TimeTraceScope scope("CIR To CIR Passes");
mlir::PassManager pm(&mlirContext);
pm.addPass(mlir::createCIRCanonicalizePass());
+ if (enableCIRSimplify)
+ pm.addPass(mlir::createCIRSimplifyPass());
+
pm.enableVerifier(enableVerifier);
(void)mlir::applyPassManagerCLOptions(pm);
return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG1]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.const #cir.int<42> : !s32i
+ %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.return %[[#A]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.const #cir.bool<false> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
+ // CHECK-NEXT: }
+
+ cir.func @simplify_2(%ar...
[truncated]
|
@llvm/pr-subscribers-clangir Author: Morris Hafner (mmha) ChangesThis patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect. Patch is 23.47 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/138317.diff 16 Files Affected:
diff --git a/clang/include/clang/CIR/CIRToCIRPasses.h b/clang/include/clang/CIR/CIRToCIRPasses.h
index 361ebb9e9b840..4a23790ee8b76 100644
--- a/clang/include/clang/CIR/CIRToCIRPasses.h
+++ b/clang/include/clang/CIR/CIRToCIRPasses.h
@@ -32,7 +32,8 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirCtx,
clang::ASTContext &astCtx,
- bool enableVerifier);
+ bool enableVerifier,
+ bool enableCIRSimplify);
} // namespace cir
diff --git a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
index 73759cfa9c3c9..818a605ab74d3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIRDialect.td
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect {
let useDefaultAttributePrinterParser = 0;
let useDefaultTypePrinterParser = 0;
+ let hasConstantMaterializer = 1;
+
let extraClassDeclaration = [{
static llvm::StringRef getTripleAttrName() { return "cir.triple"; }
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 9215543ab67e6..8205718e0fc30 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1464,6 +1464,8 @@ def SelectOp : CIR_Op<"select", [Pure,
qualified(type($false_value))
`)` `->` qualified(type($result)) attr-dict
}];
+
+ let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/Dialect/Passes.h b/clang/include/clang/CIR/Dialect/Passes.h
index 133eb462dcf1f..dbecf81acf7bb 100644
--- a/clang/include/clang/CIR/Dialect/Passes.h
+++ b/clang/include/clang/CIR/Dialect/Passes.h
@@ -22,6 +22,7 @@ namespace mlir {
std::unique_ptr<Pass> createCIRCanonicalizePass();
std::unique_ptr<Pass> createCIRFlattenCFGPass();
+std::unique_ptr<Pass> createCIRSimplifyPass();
std::unique_ptr<Pass> createHoistAllocasPass();
void populateCIRPreLoweringPasses(mlir::OpPassManager &pm);
diff --git a/clang/include/clang/CIR/Dialect/Passes.td b/clang/include/clang/CIR/Dialect/Passes.td
index 74c255861c879..46fa97da04ca1 100644
--- a/clang/include/clang/CIR/Dialect/Passes.td
+++ b/clang/include/clang/CIR/Dialect/Passes.td
@@ -29,6 +29,20 @@ def CIRCanonicalize : Pass<"cir-canonicalize"> {
let dependentDialects = ["cir::CIRDialect"];
}
+def CIRSimplify : Pass<"cir-simplify"> {
+ let summary = "Performs CIR simplification and code optimization";
+ let description = [{
+ The pass performs code simplification and optimization on CIR.
+
+ Unlike the `cir-canonicalize` pass, this pass contains more aggresive code
+ transformations that could significantly affect CIR-to-source fidelity.
+ Example transformations performed in this pass include ternary folding,
+ code hoisting, etc.
+ }];
+ let constructor = "mlir::createCIRSimplifyPass()";
+ let dependentDialects = ["cir::CIRDialect"];
+}
+
def HoistAllocas : Pass<"cir-hoist-allocas"> {
let summary = "Hoist allocas to the entry of the function";
let description = [{
diff --git a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
index 99495f4718c5f..b52166b58b882 100644
--- a/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
+++ b/clang/include/clang/CIR/FrontendAction/CIRGenAction.h
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction {
public:
~CIRGenAction() override;
- OutputType Action;
+ OutputType action;
};
class EmitCIRAction : public CIRGenAction {
diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h
index 3db13278261e6..b26144095792d 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -199,7 +199,6 @@ struct MissingFeatures {
static bool labelOp() { return false; }
static bool ptrDiffOp() { return false; }
static bool ptrStrideOp() { return false; }
- static bool selectOp() { return false; }
static bool switchOp() { return false; }
static bool ternaryOp() { return false; }
static bool tryOp() { return false; }
diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index f5d6a424a71f6..5356630ece196 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -79,6 +79,14 @@ void cir::CIRDialect::initialize() {
addInterfaces<CIROpAsmDialectInterface>();
}
+Operation *cir::CIRDialect::materializeConstant(mlir::OpBuilder &builder,
+ mlir::Attribute value,
+ mlir::Type type,
+ mlir::Location loc) {
+ return builder.create<cir::ConstantOp>(loc, type,
+ mlir::cast<mlir::TypedAttr>(value));
+}
+
//===----------------------------------------------------------------------===//
// Helpers
//===----------------------------------------------------------------------===//
@@ -1261,6 +1269,28 @@ void cir::TernaryOp::build(
result.addTypes(TypeRange{yield.getOperandTypes().front()});
}
+//===----------------------------------------------------------------------===//
+// SelectOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult cir::SelectOp::fold(FoldAdaptor adaptor) {
+ mlir::Attribute condition = adaptor.getCondition();
+ if (condition) {
+ bool conditionValue = mlir::cast<cir::BoolAttr>(condition).getValue();
+ return conditionValue ? getTrueValue() : getFalseValue();
+ }
+
+ // cir.select if %0 then x else x -> x
+ mlir::Attribute trueValue = adaptor.getTrueValue();
+ mlir::Attribute falseValue = adaptor.getFalseValue();
+ if (trueValue == falseValue)
+ return trueValue;
+ if (getTrueValue() == getFalseValue())
+ return getTrueValue();
+
+ return {};
+}
+
//===----------------------------------------------------------------------===//
// ShiftOp
//===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
index cdac69e66dba3..3b4c7bc613133 100644
--- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp
@@ -121,14 +121,13 @@ void CIRCanonicalizePass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
assert(!cir::MissingFeatures::switchOp());
assert(!cir::MissingFeatures::tryOp());
- assert(!cir::MissingFeatures::selectOp());
assert(!cir::MissingFeatures::complexCreateOp());
assert(!cir::MissingFeatures::complexRealOp());
assert(!cir::MissingFeatures::complexImagOp());
assert(!cir::MissingFeatures::callOp());
// CastOp and UnaryOp are here to perform a manual `fold` in
// applyOpPatternsGreedily.
- if (isa<BrOp, BrCondOp, ScopeOp, CastOp, UnaryOp>(op))
+ if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op))
ops.push_back(op);
});
diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
new file mode 100644
index 0000000000000..442801d062638
--- /dev/null
+++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
@@ -0,0 +1,184 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "PassDetail.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/Region.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "clang/CIR/Dialect/IR/CIRDialect.h"
+#include "clang/CIR/Dialect/Passes.h"
+#include "llvm/ADT/SmallVector.h"
+
+using namespace mlir;
+using namespace cir;
+
+//===----------------------------------------------------------------------===//
+// Rewrite patterns
+//===----------------------------------------------------------------------===//
+
+namespace {
+
+/// Simplify suitable ternary operations into select operations.
+///
+/// For now we only simplify those ternary operations whose true and false
+/// branches directly yield a value or a constant. That is, both of the true and
+/// the false branch must either contain a cir.yield operation as the only
+/// operation in the branch, or contain a cir.const operation followed by a
+/// cir.yield operation that yields the constant value.
+///
+/// For example, we will simplify the following ternary operation:
+///
+/// %0 = cir.ternary (%condition, true {
+/// %1 = cir.const ...
+/// cir.yield %1
+/// } false {
+/// cir.yield %2
+/// })
+///
+/// into the following sequence of operations:
+///
+/// %1 = cir.const ...
+/// %0 = cir.select if %condition then %1 else %2
+struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
+ using OpRewritePattern<TernaryOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(TernaryOp op,
+ PatternRewriter &rewriter) const override {
+ if (op->getNumResults() != 1)
+ return mlir::failure();
+
+ if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
+ !isSimpleTernaryBranch(op.getFalseRegion()))
+ return mlir::failure();
+
+ cir::YieldOp trueBranchYieldOp =
+ mlir::cast<cir::YieldOp>(op.getTrueRegion().front().getTerminator());
+ cir::YieldOp falseBranchYieldOp =
+ mlir::cast<cir::YieldOp>(op.getFalseRegion().front().getTerminator());
+ mlir::Value trueValue = trueBranchYieldOp.getArgs()[0];
+ mlir::Value falseValue = falseBranchYieldOp.getArgs()[0];
+
+ rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
+ rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
+ rewriter.eraseOp(trueBranchYieldOp);
+ rewriter.eraseOp(falseBranchYieldOp);
+ rewriter.replaceOpWithNewOp<cir::SelectOp>(op, op.getCond(), trueValue,
+ falseValue);
+
+ return mlir::success();
+ }
+
+private:
+ bool isSimpleTernaryBranch(mlir::Region ®ion) const {
+ if (!region.hasOneBlock())
+ return false;
+
+ mlir::Block &onlyBlock = region.front();
+ mlir::Block::OpListType &ops = onlyBlock.getOperations();
+
+ // The region/block could only contain at most 2 operations.
+ if (ops.size() > 2)
+ return false;
+
+ if (ops.size() == 1) {
+ // The region/block only contain a cir.yield operation.
+ return true;
+ }
+
+ // Check whether the region/block contains a cir.const followed by a
+ // cir.yield that yields the value.
+ auto yieldOp = mlir::cast<cir::YieldOp>(onlyBlock.getTerminator());
+ auto yieldValueDefOp = mlir::dyn_cast_if_present<cir::ConstantOp>(
+ yieldOp.getArgs()[0].getDefiningOp());
+ return yieldValueDefOp && yieldValueDefOp->getBlock() == &onlyBlock;
+ }
+};
+
+struct SimplifySelect : public OpRewritePattern<SelectOp> {
+ using OpRewritePattern<SelectOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(SelectOp op,
+ PatternRewriter &rewriter) const final {
+ mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
+ mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
+ auto trueValueConstOp =
+ mlir::dyn_cast_if_present<cir::ConstantOp>(trueValueOp);
+ auto falseValueConstOp =
+ mlir::dyn_cast_if_present<cir::ConstantOp>(falseValueOp);
+ if (!trueValueConstOp || !falseValueConstOp)
+ return mlir::failure();
+
+ auto trueValue = mlir::dyn_cast<cir::BoolAttr>(trueValueConstOp.getValue());
+ auto falseValue =
+ mlir::dyn_cast<cir::BoolAttr>(falseValueConstOp.getValue());
+ if (!trueValue || !falseValue)
+ return mlir::failure();
+
+ // cir.select if %0 then #true else #false -> %0
+ if (trueValue.getValue() && !falseValue.getValue()) {
+ rewriter.replaceAllUsesWith(op, op.getCondition());
+ rewriter.eraseOp(op);
+ return mlir::success();
+ }
+
+ // cir.select if %0 then #false else #true -> cir.unary not %0
+ if (!trueValue.getValue() && falseValue.getValue()) {
+ rewriter.replaceOpWithNewOp<cir::UnaryOp>(op, cir::UnaryOpKind::Not,
+ op.getCondition());
+ return mlir::success();
+ }
+
+ return mlir::failure();
+ }
+};
+
+//===----------------------------------------------------------------------===//
+// CIRSimplifyPass
+//===----------------------------------------------------------------------===//
+
+struct CIRSimplifyPass : public CIRSimplifyBase<CIRSimplifyPass> {
+ using CIRSimplifyBase::CIRSimplifyBase;
+
+ void runOnOperation() override;
+};
+
+void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
+ // clang-format off
+ patterns.add<
+ SimplifyTernary,
+ SimplifySelect
+ >(patterns.getContext());
+ // clang-format on
+}
+
+void CIRSimplifyPass::runOnOperation() {
+ // Collect rewrite patterns.
+ RewritePatternSet patterns(&getContext());
+ populateMergeCleanupPatterns(patterns);
+
+ // Collect operations to apply patterns.
+ llvm::SmallVector<Operation *, 16> ops;
+ getOperation()->walk([&](Operation *op) {
+ if (isa<TernaryOp, SelectOp>(op))
+ ops.push_back(op);
+ });
+
+ // Apply patterns.
+ if (applyOpPatternsGreedily(ops, std::move(patterns)).failed())
+ signalPassFailure();
+}
+
+} // namespace
+
+std::unique_ptr<Pass> mlir::createCIRSimplifyPass() {
+ return std::make_unique<CIRSimplifyPass>();
+}
diff --git a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
index 4678435b54c79..4dece5b57e450 100644
--- a/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
+++ b/clang/lib/CIR/Dialect/Transforms/CMakeLists.txt
@@ -1,5 +1,6 @@
add_clang_library(MLIRCIRTransforms
CIRCanonicalize.cpp
+ CIRSimplify.cpp
FlattenCFG.cpp
HoistAllocas.cpp
diff --git a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
index a32e6a7584774..570403dda9d9f 100644
--- a/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
+++ b/clang/lib/CIR/FrontendAction/CIRGenAction.cpp
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
std::unique_ptr<CIRGenerator> Gen;
const FrontendOptions &FEOptions;
+ CodeGenOptions &codeGenOptions;
public:
CIRGenConsumer(CIRGenAction::OutputType Action, CompilerInstance &CI,
+ CodeGenOptions &codeGenOptions,
std::unique_ptr<raw_pwrite_stream> OS)
: Action(Action), CI(CI), OutputStream(std::move(OS)),
FS(&CI.getVirtualFileSystem()),
Gen(std::make_unique<CIRGenerator>(CI.getDiagnostics(), std::move(FS),
CI.getCodeGenOpts())),
- FEOptions(CI.getFrontendOpts()) {}
+ FEOptions(CI.getFrontendOpts()), codeGenOptions(codeGenOptions) {}
void Initialize(ASTContext &Ctx) override {
assert(!Context && "initialized multiple times");
@@ -102,7 +104,8 @@ class CIRGenConsumer : public clang::ASTConsumer {
if (!FEOptions.ClangIRDisablePasses) {
// Setup and run CIR pipeline.
if (runCIRToCIRPasses(MlirModule, MlirCtx, C,
- !FEOptions.ClangIRDisableCIRVerifier)
+ !FEOptions.ClangIRDisableCIRVerifier,
+ codeGenOptions.OptimizationLevel > 0)
.failed()) {
CI.getDiagnostics().Report(diag::err_cir_to_cir_transform_failed);
return;
@@ -139,7 +142,7 @@ class CIRGenConsumer : public clang::ASTConsumer {
void CIRGenConsumer::anchor() {}
CIRGenAction::CIRGenAction(OutputType Act, mlir::MLIRContext *MLIRCtx)
- : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), Action(Act) {}
+ : MLIRCtx(MLIRCtx ? MLIRCtx : new mlir::MLIRContext), action(Act) {}
CIRGenAction::~CIRGenAction() { MLIRMod.release(); }
@@ -162,14 +165,14 @@ getOutputStream(CompilerInstance &CI, StringRef InFile,
}
std::unique_ptr<ASTConsumer>
-CIRGenAction::CreateASTConsumer(CompilerInstance &CI, StringRef InFile) {
- std::unique_ptr<llvm::raw_pwrite_stream> Out = CI.takeOutputStream();
+CIRGenAction::CreateASTConsumer(CompilerInstance &ci, StringRef inFile) {
+ std::unique_ptr<llvm::raw_pwrite_stream> out = ci.takeOutputStream();
- if (!Out)
- Out = getOutputStream(CI, InFile, Action);
+ if (!out)
+ out = getOutputStream(ci, inFile, action);
- auto Result =
- std::make_unique<cir::CIRGenConsumer>(Action, CI, std::move(Out));
+ auto Result = std::make_unique<cir::CIRGenConsumer>(
+ action, ci, ci.getCodeGenOpts(), std::move(out));
return Result;
}
diff --git a/clang/lib/CIR/Lowering/CIRPasses.cpp b/clang/lib/CIR/Lowering/CIRPasses.cpp
index a37a0480a56ac..7a581939580a9 100644
--- a/clang/lib/CIR/Lowering/CIRPasses.cpp
+++ b/clang/lib/CIR/Lowering/CIRPasses.cpp
@@ -20,13 +20,17 @@ namespace cir {
mlir::LogicalResult runCIRToCIRPasses(mlir::ModuleOp theModule,
mlir::MLIRContext &mlirContext,
clang::ASTContext &astContext,
- bool enableVerifier) {
+ bool enableVerifier,
+ bool enableCIRSimplify) {
llvm::TimeTraceScope scope("CIR To CIR Passes");
mlir::PassManager pm(&mlirContext);
pm.addPass(mlir::createCIRCanonicalizePass());
+ if (enableCIRSimplify)
+ pm.addPass(mlir::createCIRSimplifyPass());
+
pm.enableVerifier(enableVerifier);
(void)mlir::applyPassManagerCLOptions(pm);
return pm.run(theModule);
diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir
new file mode 100644
index 0000000000000..29a5d1ed1ddeb
--- /dev/null
+++ b/clang/test/CIR/Transforms/select.cir
@@ -0,0 +1,60 @@
+// RUN: cir-opt -cir-canonicalize -cir-simplify -o %t.cir %s
+// RUN: FileCheck --input-file=%t.cir %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+ cir.func @fold_true(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_true(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_false(%arg0 : !s32i, %arg1 : !s32i) -> !s32i {
+ %0 = cir.const #cir.bool<false> : !cir.bool
+ %1 = cir.select if %0 then %arg0 else %arg1 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
+ // CHECK-NEXT: cir.return %[[ARG1]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
+ %0 = cir.const #cir.int<42> : !s32i
+ %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
+ cir.return %1 : !s32i
+ }
+
+ // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
+ // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
+ // CHECK-NEXT: cir.return %[[#A]] : !s32i
+ // CHECK-NEXT: }
+
+ cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
+ %0 = cir.const #cir.bool<true> : !cir.bool
+ %1 = cir.const #cir.bool<false> : !cir.bool
+ %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
+ cir.return %2 : !cir.bool
+ }
+
+ // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
+ // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
+ // CHECK-NEXT: }
+
+ cir.func @simplify_2(%ar...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not the best reviewer for this, but found 1 thing I'd like a comment on. Else LGTM pending others being ok with it.
@@ -27,6 +27,8 @@ def CIR_Dialect : Dialect { | |||
let useDefaultAttributePrinterParser = 0; | |||
let useDefaultTypePrinterParser = 0; | |||
|
|||
let hasConstantMaterializer = 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh boy would i love a comment explaining this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's some documentation here: https://mlir.llvm.org/docs/Canonicalization/#generating-constants-from-attributes
But yes I can add a small comment explaining that we need this for canonicalization.
The pass performs code simplification and optimization on CIR. | ||
|
||
Unlike the `cir-canonicalize` pass, this pass contains more aggresive code | ||
transformations that could significantly affect CIR-to-source fidelity. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The phrase "could significantly affect CIR-to-source fidelity" is somewhat alarming. I don't think we're doing anything that changes semantics are we? I think this comment is about transformations that could affect ease of debugging. Perhaps we can word it more gently?
@@ -49,7 +49,7 @@ class CIRGenAction : public clang::ASTFrontendAction { | |||
public: | |||
~CIRGenAction() override; | |||
|
|||
OutputType Action; | |||
OutputType action; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file uses the normal clang coding style, so variable names should be capitalized.
@@ -62,15 +62,17 @@ class CIRGenConsumer : public clang::ASTConsumer { | |||
IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS; | |||
std::unique_ptr<CIRGenerator> Gen; | |||
const FrontendOptions &FEOptions; | |||
CodeGenOptions &codeGenOptions; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CodeGenOptions &codeGenOptions; | |
CodeGenOptions &CodeGenOptions; |
Keep clang-style identifier naming throughout this file.
/// %1 = cir.const ... | ||
/// cir.yield %1 | ||
/// } false { | ||
/// cir.yield %2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This example is confusing. The identifier %2
gives the impression that it was defined sometime after %0
but I don't think that's the intention. This will only happen if the false case returns a value that exists prior to the ternary, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would probably look a bit better:
/// %0 = ...
/// %1 = cir.ternary (%condition, true {
/// %2 = cir.const ...
/// cir.yield %2
/// } false {
/// cir.yield %0
} | ||
}; | ||
|
||
struct SimplifySelect : public OpRewritePattern<SelectOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be great if this had an explanatory comment like SimplifyTernary
does. If there are going to be more cases added later, a general comment is fine. Otherwise, the two cases handled here can be explained pretty easily here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with minor tip: can you also add a source to CIR test that guarantee this kicks in from the driver?
/// %1 = cir.const ... | ||
/// cir.yield %1 | ||
/// } false { | ||
/// cir.yield %2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This would probably look a bit better:
/// %0 = ...
/// %1 = cir.ternary (%condition, true {
/// %2 = cir.const ...
/// cir.yield %2
/// } false {
/// cir.yield %0
@bcardosolopes Adding a test for C++ -> CIR depends on #138003 which adds lowering for |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
I think you're referring to this PR: #138156 |
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/174/builds/17384 Here is the relevant piece of the build log for the reference
|
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/64/builds/3451 Here is the relevant piece of the build log for the reference
|
This patch adds the cir-simplify pass for SelectOp and TernaryOp. It also adds the SelectOp folder and adds the constant materializer for the CIR dialect.