From 35fb1d5da23bcdcdfd66a8d76a0ecf315f9ddc66 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Tue, 2 Sep 2025 15:36:06 -0700 Subject: [PATCH 1/2] [Synth] Support MIG lowering in CombToAIG conversion pass This patch extends the synthesis pipeline to support lowering combinational operations to Majority-Inverter Graphs (MIG) in addition to And-Inverter Graphs (AIG). The implementation adds a CombToAIGTargetIR enum (slated to be renamed to CombToSynthTargetIR) with AIG/MIG options, implements MIG-specific conversion patterns for OR operations using majority functions, updates adder carry logic to use majority functions in MIG mode, and renames the pipeline from AIGLowering to CombLowering to reflect the generalized capability. The MIG representation uses 3-input majority functions (MAJ(a,b,c) = a&b | a&c | b&c), which are more efficient than AIG for arithmetic operations. Integration tests verify functional equivalence using circt-lec. --- include/circt/Conversion/CombToAIG.h | 8 + include/circt/Conversion/Passes.td | 13 +- include/circt/Dialect/AIG/AIG.td | 1 + include/circt/Dialect/AIG/AIGOps.h | 13 ++ .../Synth/Transforms/SynthesisPipeline.h | 18 ++- .../Bindings/Python/dialects/synth.py | 2 +- integration_test/circt-synth/mig.mlir | 20 +++ lib/Conversion/AIGToComb/AIGToComb.cpp | 56 ++++++- lib/Conversion/AIGToComb/CMakeLists.txt | 1 + lib/Conversion/CombToAIG/CMakeLists.txt | 1 + lib/Conversion/CombToAIG/CombToAIG.cpp | 142 +++++++++++++++--- lib/Dialect/AIG/AIGDialect.cpp | 1 + lib/Dialect/AIG/AIGOps.cpp | 44 ++++++ lib/Dialect/AIG/CMakeLists.txt | 1 + lib/Dialect/AIG/Transforms/CMakeLists.txt | 1 + .../AIG/Transforms/LowerWordToBits.cpp | 19 ++- .../Synth/Transforms/SynthesisPipeline.cpp | 29 ++-- test/Conversion/AIGToComb/aig-to-comb.mlir | 12 ++ test/Conversion/CombToAIG/comb-to-mig.mlir | 26 ++++ test/Dialect/AIG/lower-word-to-bits.mlir | 16 ++ test/circt-synth/basic.mlir | 7 +- tools/circt-synth/circt-synth.cpp | 25 +-- 22 files changed, 390 insertions(+), 66 deletions(-) create mode 100644 
integration_test/circt-synth/mig.mlir create mode 100644 test/Conversion/CombToAIG/comb-to-mig.mlir diff --git a/include/circt/Conversion/CombToAIG.h b/include/circt/Conversion/CombToAIG.h index 58bd329e99ca..009aa006c623 100644 --- a/include/circt/Conversion/CombToAIG.h +++ b/include/circt/Conversion/CombToAIG.h @@ -16,6 +16,14 @@ namespace circt { +// FIXME: Rename to CombToSynthTargetIR +enum CombToAIGTargetIR { + // Lower to And-Inverter + AIG, + // Lower to Majority-Inverter + MIG +}; + #define GEN_PASS_DECL_CONVERTCOMBTOAIG #include "circt/Conversion/Passes.h.inc" diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index f39eab817f7b..65ff56f690e9 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -779,16 +779,25 @@ def LowerSimToSV: Pass<"lower-sim-to-sv", "mlir::ModuleOp"> { // ConvertCombToAIG //===----------------------------------------------------------------------===// +// FIXME: Rename to ConvertCombToSynth def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> { let summary = "Lower Comb ops to AIG ops."; let dependentDialects = [ "circt::comb::CombDialect", "circt::aig::AIGDialect", + "circt::synth::SynthDialect" ]; let options = [ ListOption<"additionalLegalOps", "additional-legal-ops", "std::string", - "Specify additional legal ops to partially legalize Comb to AIG">, + "Specify additional legal ops to partially legalize Comb">, + Option<"targetIR", "target-ir", "circt::CombToAIGTargetIR", + /*default=*/"circt::CombToAIGTargetIR::AIG", + "Target IR kind", + [{::llvm::cl::values( + clEnumValN(circt::CombToAIGTargetIR::AIG, "aig", "Lower to AIG"), + clEnumValN(circt::CombToAIGTargetIR::MIG, "mig", "Lower to MIG") + )}]>, Option<"maxEmulationUnknownBits", "max-emulation-unknown-bits", "uint32_t", "10", "Maximum number of unknown bits to emulate in a table lookup"> ]; @@ -797,7 +806,7 @@ def ConvertCombToAIG: Pass<"convert-comb-to-aig", "hw::HWModuleOp"> { 
//===----------------------------------------------------------------------===// // ConvertAIGToComb //===----------------------------------------------------------------------===// - +// FIXME: Rename to ConvertSynthToComb def ConvertAIGToComb: Pass<"convert-aig-to-comb", "hw::HWModuleOp"> { let summary = "Lower AIG ops to Comb ops"; let description = [{ diff --git a/include/circt/Dialect/AIG/AIG.td b/include/circt/Dialect/AIG/AIG.td index ff04617316a7..15469a69c07a 100644 --- a/include/circt/Dialect/AIG/AIG.td +++ b/include/circt/Dialect/AIG/AIG.td @@ -14,6 +14,7 @@ include "mlir/IR/DialectBase.td" def AIG_Dialect : Dialect { let name = "aig"; let cppNamespace = "::circt::aig"; + let dependentDialects = ["::circt::synth::SynthDialect"]; let summary = "Representation of AIGs"; } diff --git a/include/circt/Dialect/AIG/AIGOps.h b/include/circt/Dialect/AIG/AIGOps.h index 0a298cd71a7c..5a0440b49f26 100644 --- a/include/circt/Dialect/AIG/AIGOps.h +++ b/include/circt/Dialect/AIG/AIGOps.h @@ -22,10 +22,23 @@ #include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Rewrite/PatternApplicator.h" #include "circt/Dialect/AIG/AIGDialect.h" #define GET_OP_CLASSES #include "circt/Dialect/AIG/AIG.h.inc" +namespace circt { +namespace aig { +struct AndInverterVariadicOpConversion + : mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + mlir::LogicalResult + matchAndRewrite(circt::aig::AndInverterOp op, + mlir::PatternRewriter &rewriter) const override; +}; +} // namespace aig +} // namespace circt + #endif // CIRCT_DIALECT_AIG_AIGOPS_H diff --git a/include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h b/include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h index 4f9632f084c1..19ac57a88df0 100644 --- a/include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h +++ b/include/circt/Dialect/Synth/Transforms/SynthesisPipeline.h @@ -24,9 +24,16 @@ 
namespace circt { namespace synth { +enum TargetIR { + // Lower to And-Inverter Graph + AIG, + // Lower to Majority-Inverter Graph + MIG +}; + /// Options for the aig lowering pipeline. -struct AIGLoweringPipelineOptions - : public mlir::PassPipelineOptions { +struct CombLoweringPipelineOptions + : public mlir::PassPipelineOptions { PassOptions::Option disableDatapath{ *this, "disable-datapath", llvm::cl::desc("Disable datapath optimization passes"), @@ -35,6 +42,9 @@ struct AIGLoweringPipelineOptions *this, "timing-aware", llvm::cl::desc("Lower operators in a timing-aware fashion"), llvm::cl::init(false)}; + PassOptions::Option targetIR{ + *this, "lowering-target", llvm::cl::desc("Target IR to lower to"), + llvm::cl::init(TargetIR::AIG)}; }; /// Options for the aig optimization pipeline. @@ -61,8 +71,8 @@ struct AIGOptimizationPipelineOptions //===----------------------------------------------------------------------===// /// Populate the synthesis pipelines. -void buildAIGLoweringPipeline(mlir::OpPassManager &pm, - const AIGLoweringPipelineOptions &options); +void buildCombLoweringPipeline(mlir::OpPassManager &pm, + const CombLoweringPipelineOptions &options); void buildAIGOptimizationPipeline( mlir::OpPassManager &pm, const AIGOptimizationPipelineOptions &options); diff --git a/integration_test/Bindings/Python/dialects/synth.py b/integration_test/Bindings/Python/dialects/synth.py index 8929ac1a7037..467c12c7d25c 100644 --- a/integration_test/Bindings/Python/dialects/synth.py +++ b/integration_test/Bindings/Python/dialects/synth.py @@ -27,7 +27,7 @@ def build_module(module): # Check that the synthesis pipeline is registered. 
pm = PassManager.parse( - "builtin.module(hw.module(synth-aig-lowering-pipeline, " + "builtin.module(hw.module(synth-comb-lowering-pipeline, " "synth-aig-optimization-pipeline))") pm.run(m.operation) # CHECK: hw.module @foo( diff --git a/integration_test/circt-synth/mig.mlir b/integration_test/circt-synth/mig.mlir new file mode 100644 index 000000000000..06405c436945 --- /dev/null +++ b/integration_test/circt-synth/mig.mlir @@ -0,0 +1,20 @@ +// RUN: circt-opt %s --pass-pipeline='builtin.module(hw.module(hw-aggregate-to-comb,convert-comb-to-aig{target-ir=mig},convert-aig-to-comb))' -o %t.mlir +// RUN: circt-lec %t.mlir %s -c1=bit_logical -c2=bit_logical --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_BIT_LOGICAL +// COMB_BIT_LOGICAL: c1 == c2 +hw.module @bit_logical(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, + in %cond: i1, out out0: i32, out out1: i32, out out2: i32, out out3: i32) { + %0 = comb.or %arg0, %arg1, %arg2, %arg3 : i32 + %1 = comb.and %arg0, %arg1, %arg2, %arg3 : i32 + %2 = comb.xor %arg0, %arg1, %arg2, %arg3 : i32 + %3 = comb.mux %cond, %arg0, %arg1 : i32 + + hw.output %0, %1, %2, %3 : i32, i32, i32, i32 +} + +// RUN: circt-lec %t.mlir %s -c1=add -c2=add --shared-libs=%libz3 | FileCheck %s --check-prefix=COMB_ADD +// COMB_ADD: c1 == c2 +hw.module @add(in %arg0: i4, in %arg1: i4, in %arg2: i4, out add: i4) { + %0 = comb.add %arg0, %arg1, %arg2 : i4 + hw.output %0 : i4 +} + diff --git a/lib/Conversion/AIGToComb/AIGToComb.cpp b/lib/Conversion/AIGToComb/AIGToComb.cpp index 27d9aea4ff3c..4fddc38c8113 100644 --- a/lib/Conversion/AIGToComb/AIGToComb.cpp +++ b/lib/Conversion/AIGToComb/AIGToComb.cpp @@ -14,6 +14,8 @@ #include "circt/Dialect/AIG/AIGOps.h" #include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/Synth/SynthDialect.h" +#include "circt/Dialect/Synth/SynthOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" @@ -53,6 +55,55 @@ struct 
AIGAndInverterOpConversion : OpConversionPattern { } }; +struct SynthMajorityInverterOpConversion + : OpConversionPattern { + using OpConversionPattern< + synth::mig::MajorityInverterOp>::OpConversionPattern; + LogicalResult + matchAndRewrite(synth::mig::MajorityInverterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // Only handle 1 or 3-input majority inverter for now. + if (op.getNumOperands() > 3) + return failure(); + + auto getOperand = [&](unsigned idx) { + auto input = op.getInputs()[idx]; + if (!op.getInverted()[idx]) + return input; + auto width = input.getType().getIntOrFloatBitWidth(); + auto allOnes = hw::ConstantOp::create(rewriter, op.getLoc(), + APInt::getAllOnes(width)); + return rewriter.createOrFold(op.getLoc(), input, allOnes, + true); + }; + + if (op.getNumOperands() == 1) { + rewriter.replaceOp(op, getOperand(0)); + return success(); + } + + assert(op.getNumOperands() == 3 && "Expected 3 operands for majority op"); + SmallVector inputs; + for (size_t i = 0; i < 3; ++i) + inputs.push_back(getOperand(i)); + + // MAJ(x, y, z) = x & y | x & z | y & z + auto getProduct = [&](unsigned idx1, unsigned idx2) { + return rewriter.createOrFold( + op.getLoc(), ValueRange{inputs[idx1], inputs[idx2]}, true); + }; + + SmallVector operands; + operands.push_back(getProduct(0, 1)); + operands.push_back(getProduct(0, 2)); + operands.push_back(getProduct(1, 2)); + + rewriter.replaceOp( + op, rewriter.createOrFold(op.getLoc(), operands, true)); + return success(); + } +}; + } // namespace //===----------------------------------------------------------------------===// @@ -69,13 +120,14 @@ struct ConvertAIGToCombPass } // namespace static void populateAIGToCombConversionPatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); + patterns.add( + patterns.getContext()); } void ConvertAIGToCombPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); + 
target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); populateAIGToCombConversionPatterns(patterns); diff --git a/lib/Conversion/AIGToComb/CMakeLists.txt b/lib/Conversion/AIGToComb/CMakeLists.txt index 6e025ba4fd55..8796c7d4879c 100644 --- a/lib/Conversion/AIGToComb/CMakeLists.txt +++ b/lib/Conversion/AIGToComb/CMakeLists.txt @@ -8,6 +8,7 @@ add_circt_conversion_library(CIRCTAIGToComb CIRCTAIG CIRCTHW CIRCTComb + CIRCTSynth MLIRIR MLIRPass MLIRSupport diff --git a/lib/Conversion/CombToAIG/CMakeLists.txt b/lib/Conversion/CombToAIG/CMakeLists.txt index 5add2ed5198a..d2c2d9a3996f 100644 --- a/lib/Conversion/CombToAIG/CMakeLists.txt +++ b/lib/Conversion/CombToAIG/CMakeLists.txt @@ -11,6 +11,7 @@ add_circt_conversion_library(CIRCTCombToAIG CIRCTHW CIRCTComb CIRCTAIG + CIRCTSynth MLIRIR MLIRPass MLIRSupport diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index e94a4e7fe8f4..14c8d1d3ccfd 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -1,4 +1,4 @@ -//===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -8,15 +8,32 @@ // // This is the main Comb to AIG Conversion Pass Implementation. 
// +// High-level Comb Operations +// | | +// v | +// +-------------------+ | +// | and, or, xor, mux | | +// +---------+---------+ | +// | | +// +-------+--------+ | +// v v v +// +-----+ +-----+ +// | AIG |-------->| MIG | +// +-----+ +-----+ +// //===----------------------------------------------------------------------===// #include "circt/Conversion/CombToAIG.h" +#include "circt/Dialect/AIG/AIGDialect.h" #include "circt/Dialect/AIG/AIGOps.h" #include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/Synth/SynthDialect.h" +#include "circt/Dialect/Synth/SynthOps.h" #include "circt/Support/Naming.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/APInt.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/Debug.h" @@ -95,6 +112,27 @@ static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, outOfBoundsValue); } +// Return a majority operation if MIG is enabled, otherwise return a majority +// function implemented with Comb operations. In that case `carry` has slightly +// smaller depth than the other inputs. +static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, + Value b, Value carry, + bool useMajorityInverterOp) { + if (useMajorityInverterOp) { + SmallVector inputs = {a, b, carry}; + SmallVector inverts = {false, false, false}; + return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs, + inverts); + } + + // maj(a, b, c) = (c & (a ^ b)) | (a & b) + auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true); + auto andOp = + comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB}, true); + auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true); + return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true); +} + namespace { // A union of Value and IntegerAttr to cleanly handle constant values. 
using ConstantOrValue = llvm::PointerUnion; @@ -253,7 +291,7 @@ struct CombAndOpConversion : OpConversionPattern { }; /// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags -struct CombOrOpConversion : OpConversionPattern { +struct CombOrToAIGConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -269,6 +307,50 @@ struct CombOrOpConversion : OpConversionPattern { } }; +struct CombOrToMIGConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OrOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumOperands() != 2) + return failure(); + SmallVector inputs(adaptor.getInputs()); + auto one = hw::ConstantOp::create( + rewriter, op.getLoc(), + APInt::getAllOnes(hw::getBitWidth(op.getType()))); + inputs.push_back(one); + SmallVector inverts(inputs.size(), false); + replaceOpWithNewOpAndCopyNamehint( + rewriter, op, inputs, inverts); + return success(); + } +}; + +struct AndInverterToMIGConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(aig::AndInverterOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (op.getNumOperands() > 2) + return failure(); + if (op.getNumOperands() == 1) { + SmallVector inverts{op.getInverted()[0]}; + replaceOpWithNewOpAndCopyNamehint( + rewriter, op, adaptor.getInputs(), inverts); + return success(); + } + SmallVector inputs(adaptor.getInputs()); + auto one = hw::ConstantOp::create( + rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType()))); + inputs.push_back(one); + SmallVector inverts(adaptor.getInverted()); + inverts.push_back(false); + replaceOpWithNewOpAndCopyNamehint( + rewriter, op, inputs, inverts); + return success(); + } +}; + /// Lower a comb::XorOp operation to AIG operations struct CombXorOpConversion : OpConversionPattern { using 
OpConversionPattern::OpConversionPattern; @@ -342,8 +424,6 @@ struct CombMuxOpConversion : OpConversionPattern { LogicalResult matchAndRewrite(MuxOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b) - Value cond = op.getCond(); auto trueVal = op.getTrueValue(); auto falseVal = op.getFalseValue(); @@ -377,6 +457,7 @@ struct CombMuxOpConversion : OpConversionPattern { } }; +template struct CombAddOpConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -430,20 +511,15 @@ struct CombAddOpConversion : OpConversionPattern { break; // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i]) - Value nextCarry = comb::AndOp::create( - rewriter, op.getLoc(), ValueRange{aBits[i], bBits[i]}, true); if (!carry) { // This is the first bit, so the carry is the next carry. - carry = nextCarry; + carry = comb::AndOp::create(rewriter, op.getLoc(), + ValueRange{aBits[i], bBits[i]}, true); continue; } - auto aXnorB = comb::XorOp::create(rewriter, op.getLoc(), - ValueRange{aBits[i], bBits[i]}, true); - auto andOp = comb::AndOp::create(rewriter, op.getLoc(), - ValueRange{carry, aXnorB}, true); - carry = comb::OrOp::create(rewriter, op.getLoc(), - ValueRange{andOp, nextCarry}, true); + carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i], + carry, lowerToMIG); } LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width " << width << "\n"); @@ -530,6 +606,7 @@ struct CombAddOpConversion : OpConversionPattern { for (int64_t stride = 1; stride < width; stride *= 2) { for (int64_t i = stride; i < width; ++i) { int64_t j = i - stride; + // Group generate: g_i OR (p_i AND g_j) Value andPG = comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]); @@ -604,7 +681,8 @@ struct CombAddOpConversion : OpConversionPattern { // Group generate: g_i OR (p_i AND g_j) Value andPG = comb::AndOp::create(rewriter, 
op.getLoc(), pPrefix[i], gPrefix[j]); - gPrefixNew[i] = OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG); + gPrefixNew[i] = + comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG); // Group propagate: p_i AND p_j pPrefixNew[i] = @@ -1087,27 +1165,35 @@ struct ConvertCombToAIGPass : public impl::ConvertCombToAIGBase { void runOnOperation() override; using ConvertCombToAIGBase::ConvertCombToAIGBase; - using ConvertCombToAIGBase::additionalLegalOps; - using ConvertCombToAIGBase::maxEmulationUnknownBits; }; } // namespace static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, - uint32_t maxEmulationUnknownBits) { + uint32_t maxEmulationUnknownBits, + bool lowerToMIG) { patterns.add< // Bitwise Logical Ops - CombAndOpConversion, CombOrOpConversion, CombXorOpConversion, - CombMuxOpConversion, CombParityOpConversion, + CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion, + CombParityOpConversion, // Arithmetic Ops - CombAddOpConversion, CombSubOpConversion, CombMulOpConversion, - CombICmpOpConversion, + CombSubOpConversion, CombMulOpConversion, CombICmpOpConversion, // Shift Ops CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion, // Variadic ops that must be lowered to binary operations CombLowerVariadicOp, CombLowerVariadicOp, CombLowerVariadicOp>(patterns.getContext()); + if (lowerToMIG) { + patterns.add, + AndInverterToMIGConversion, + circt::aig::AndInverterVariadicOpConversion, + CombAddOpConversion>(patterns.getContext()); + } else { + patterns.add>( + patterns.getContext()); + } + // Add div/mod patterns with a threshold given by the pass option. patterns.add(patterns.getContext(), @@ -1130,8 +1216,15 @@ void ConvertCombToAIGPass::runOnOperation() { target.addIllegalOp(); - // AIG is target dialect. - target.addLegalDialect(); + if (targetIR == CombToAIGTargetIR::AIG) { + // AIG is target dialect. 
+ target.addLegalDialect(); + target.addIllegalOp(); + } else if (targetIR == CombToAIGTargetIR::MIG) { + target.addLegalDialect(); + target.addLegalOp(); + target.addIllegalOp(); + } // If additional legal ops are specified, add them to the target. if (!additionalLegalOps.empty()) @@ -1139,7 +1232,8 @@ void ConvertCombToAIGPass::runOnOperation() { target.addLegalOp(OperationName(opName, &getContext())); RewritePatternSet patterns(&getContext()); - populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits); + populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits, + targetIR == CombToAIGTargetIR::MIG); if (failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/AIG/AIGDialect.cpp b/lib/Dialect/AIG/AIGDialect.cpp index ff5b0df17a7c..37f7afaab79c 100644 --- a/lib/Dialect/AIG/AIGDialect.cpp +++ b/lib/Dialect/AIG/AIGDialect.cpp @@ -12,6 +12,7 @@ #include "circt/Dialect/AIG/AIGDialect.h" #include "circt/Dialect/AIG/AIGOps.h" +#include "circt/Dialect/Synth/SynthDialect.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectImplementation.h" diff --git a/lib/Dialect/AIG/AIGOps.cpp b/lib/Dialect/AIG/AIGOps.cpp index 39aabf566ac0..cfa2ae195026 100644 --- a/lib/Dialect/AIG/AIGOps.cpp +++ b/lib/Dialect/AIG/AIGOps.cpp @@ -118,3 +118,47 @@ APInt AndInverterOp::evaluate(ArrayRef inputs) { } return result; } + +static Value lowerVariadicAndInverterOp(aig::AndInverterOp op, + OperandRange operands, + ArrayRef inverts, + PatternRewriter &rewriter) { + using namespace aig; + switch (operands.size()) { + case 0: + assert(0 && "cannot be called with empty operand range"); + break; + case 1: + if (inverts[0]) + return AndInverterOp::create(rewriter, op.getLoc(), operands[0], true); + else + return operands[0]; + case 2: + return AndInverterOp::create(rewriter, op.getLoc(), operands[0], + operands[1], inverts[0], inverts[1]); + default: + auto firstHalf = operands.size() / 
2; + auto lhs = + lowerVariadicAndInverterOp(op, operands.take_front(firstHalf), + inverts.take_front(firstHalf), rewriter); + auto rhs = + lowerVariadicAndInverterOp(op, operands.drop_front(firstHalf), + inverts.drop_front(firstHalf), rewriter); + return AndInverterOp::create(rewriter, op.getLoc(), lhs, rhs); + } + + return Value(); +} + +LogicalResult circt::aig::AndInverterVariadicOpConversion::matchAndRewrite( + aig::AndInverterOp op, PatternRewriter &rewriter) const { + if (op.getInputs().size() <= 2) + return failure(); + + // TODO: This is a naive implementation that creates a balanced binary tree. + // We can improve by analyzing the dataflow and creating a tree that + // improves the critical path or area. + rewriter.replaceOp(op, lowerVariadicAndInverterOp( + op, op.getOperands(), op.getInverted(), rewriter)); + return success(); +} diff --git a/lib/Dialect/AIG/CMakeLists.txt b/lib/Dialect/AIG/CMakeLists.txt index 162e871d15d8..37a21b0befdb 100644 --- a/lib/Dialect/AIG/CMakeLists.txt +++ b/lib/Dialect/AIG/CMakeLists.txt @@ -9,6 +9,7 @@ add_circt_dialect_library(CIRCTAIG MLIRIR CIRCTHW CIRCTSupport + CIRCTSynth DEPENDS MLIRAIGIncGen diff --git a/lib/Dialect/AIG/Transforms/CMakeLists.txt b/lib/Dialect/AIG/Transforms/CMakeLists.txt index 1a9e519cc798..c8e6acff2fab 100644 --- a/lib/Dialect/AIG/Transforms/CMakeLists.txt +++ b/lib/Dialect/AIG/Transforms/CMakeLists.txt @@ -13,4 +13,5 @@ add_circt_dialect_library(CIRCTAIGTransforms CIRCTHW CIRCTImportAIGER CIRCTSeq + CIRCTSynth ) diff --git a/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp index ded5827d61c4..89091f0f2908 100644 --- a/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp +++ b/lib/Dialect/AIG/Transforms/LowerWordToBits.cpp @@ -14,6 +14,7 @@ #include "circt/Dialect/AIG/AIGPasses.h" #include "circt/Dialect/Comb/CombOps.h" #include "circt/Dialect/HW/HWOps.h" +#include "circt/Dialect/Synth/SynthOps.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" 
#define DEBUG_TYPE "aig-lower-word-to-bits" @@ -33,11 +34,11 @@ using namespace aig; //===----------------------------------------------------------------------===// namespace { +template +struct WordRewritePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; -struct WordRewritePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AndInverterOp op, + LogicalResult matchAndRewrite(OpTy op, PatternRewriter &rewriter) const override { auto width = op.getType().getIntOrFloatBitWidth(); if (width <= 1) @@ -50,7 +51,7 @@ struct WordRewritePattern : public OpRewritePattern { SmallVector operands; for (auto operand : op.getOperands()) { // Reuse bits if we can extract from `comb.concat` operands. - if (auto concat = operand.getDefiningOp()) { + if (auto concat = operand.template getDefiningOp()) { // For the simplicity, we only handle the case where all the // `comb.concat` operands are single-bit. if (concat.getNumOperands() == width && @@ -66,8 +67,8 @@ struct WordRewritePattern : public OpRewritePattern { operands.push_back( comb::ExtractOp::create(rewriter, op.getLoc(), operand, i, 1)); } - results.push_back(AndInverterOp::create(rewriter, op.getLoc(), operands, - op.getInvertedAttr())); + results.push_back( + OpTy::create(rewriter, op.getLoc(), operands, op.getInvertedAttr())); } rewriter.replaceOpWithNewOp(op, results); @@ -90,7 +91,9 @@ struct LowerWordToBitsPass void LowerWordToBitsPass::runOnOperation() { RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); + patterns.add, + WordRewritePattern>( + &getContext()); mlir::FrozenRewritePatternSet frozenPatterns(std::move(patterns)); mlir::GreedyRewriteConfig config; diff --git a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp index ac800195e7a6..f340b0a9da5e 100644 --- a/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp +++ 
b/lib/Dialect/Synth/Transforms/SynthesisPipeline.cpp @@ -34,12 +34,12 @@ using namespace circt::synth; /// Helper function to populate additional legal ops for partial legalization. template -static void partiallyLegalizeCombToAIG(SmallVectorImpl &ops) { +static void partiallyLegalizeCombToSynth(SmallVectorImpl &ops) { (ops.push_back(AllowedOpTy::getOperationName().str()), ...); } -void circt::synth::buildAIGLoweringPipeline( - OpPassManager &pm, const AIGLoweringPipelineOptions &options) { +void circt::synth::buildCombLoweringPipeline( + OpPassManager &pm, const CombLoweringPipelineOptions &options) { { if (!options.disableDatapath) { pm.addPass(createConvertCombToDatapath()); @@ -49,12 +49,12 @@ void circt::synth::buildAIGLoweringPipeline( pm.addPass(createConvertDatapathToComb(datapathOptions)); pm.addPass(createSimpleCanonicalizerPass()); } - // Partially legalize Comb to AIG, run CSE and canonicalization. + // Partially legalize Comb, then run CSE and canonicalization. circt::ConvertCombToAIGOptions convOptions; - partiallyLegalizeCombToAIG( + partiallyLegalizeCombToSynth( convOptions.additionalLegalOps); pm.addPass(circt::createConvertCombToAIG(convOptions)); } @@ -62,7 +62,11 @@ void circt::synth::buildAIGLoweringPipeline( pm.addPass(createSimpleCanonicalizerPass()); pm.addPass(circt::hw::createHWAggregateToComb()); - pm.addPass(circt::createConvertCombToAIG()); + circt::ConvertCombToAIGOptions convOptions; + convOptions.targetIR = options.targetIR.getValue() == TargetIR::AIG + ? 
CombToAIGTargetIR::AIG + : CombToAIGTargetIR::MIG; + pm.addPass(circt::createConvertCombToAIG(convOptions)); pm.addPass(createCSEPass()); pm.addPass(createSimpleCanonicalizerPass()); pm.addPass(createCSEPass()); @@ -97,9 +101,10 @@ void circt::synth::buildAIGOptimizationPipeline( //===----------------------------------------------------------------------===// void circt::synth::registerSynthesisPipeline() { - PassPipelineRegistration( - "synth-aig-lowering-pipeline", - "The default pipeline for until AIG lowering", buildAIGLoweringPipeline); + PassPipelineRegistration( + "synth-comb-lowering-pipeline", + "The default pipeline for until Comb lowering", + buildCombLoweringPipeline); PassPipelineRegistration( "synth-aig-optimization-pipeline", "The default pipeline for AIG optimization pipeline", diff --git a/test/Conversion/AIGToComb/aig-to-comb.mlir b/test/Conversion/AIGToComb/aig-to-comb.mlir index 3f4c0bc3de65..573503b9878c 100644 --- a/test/Conversion/AIGToComb/aig-to-comb.mlir +++ b/test/Conversion/AIGToComb/aig-to-comb.mlir @@ -10,3 +10,15 @@ hw.module @test(in %a: i32, in %b: i32, in %c: i32, in %d: i32, out out0: i32) { %0 = aig.and_inv not %a, %b, not %c, %d : i32 hw.output %0 : i32 } + +// CHECK-LABEL: @test_maj +hw.module @test_maj(in %a: i32, in %b: i32, in %c: i32, out out0: i32) { + // CHECK: %c-1_i32 = hw.constant -1 : i32 + // CHECK: %[[NOT_B:.+]] = comb.xor bin %b, %c-1_i32 : i32 + // CHECK: %[[AND1:.+]] = comb.and bin %a, %[[NOT_B]] : i32 + // CHECK: %[[AND2:.+]] = comb.and bin %a, %c : i32 + // CHECK: %[[AND3:.+]] = comb.and bin %[[NOT_B]], %c : i32 + // CHECK: %[[RESULT:.+]] = comb.or bin %[[AND1]], %[[AND2]], %[[AND3]] : i32 + %0 = synth.mig.maj_inv %a, not %b, %c : i32 + hw.output %0 : i32 +} diff --git a/test/Conversion/CombToAIG/comb-to-mig.mlir b/test/Conversion/CombToAIG/comb-to-mig.mlir new file mode 100644 index 000000000000..230b54079e7f --- /dev/null +++ b/test/Conversion/CombToAIG/comb-to-mig.mlir @@ -0,0 +1,26 @@ +// RUN: circt-opt %s 
--pass-pipeline='builtin.module(hw.module(convert-comb-to-aig{target-ir=mig},cse))' | FileCheck %s + +hw.module @logic(in %arg0: i32, in %arg1: i32, in %arg2: i32, in %arg3: i32, out out0: i32, out out1: i32) { + // CHECK: %[[MIG_0:.+]] = synth.mig.maj_inv %arg1, %arg2, %{{c-1_i32.*}} : i32 + // CHECK: %[[MIG_1:.+]] = synth.mig.maj_inv %arg0, %[[MIG_0]], %{{c-1_i32.*}} : i32 + // CHECK: %[[MIG_2:.+]] = synth.mig.maj_inv %arg1, %arg2, %{{c0_i32.*}} : i32 + // CHECK: %[[MIG_3:.+]] = synth.mig.maj_inv %arg0, %[[MIG_2]], %{{c0_i32.*}} : i32 + // CHECK: hw.output %[[MIG_1]], %[[MIG_3]] : i32, i32 + %0 = comb.or %arg0, %arg1, %arg2 : i32 + %1 = comb.and %arg0, %arg1, %arg2 : i32 + hw.output %0, %1 : i32, i32 +} + + +// CHECK-LABEL: @add +hw.module @add(in %lhs: i3, in %rhs: i3, out out: i3) { + // Check majority function is used for carry logic + // CHECK: %[[LHS_0:.+]] = comb.extract %lhs from 0 : (i3) -> i1 + // CHECK: %[[LHS_1:.+]] = comb.extract %lhs from 1 : (i3) -> i1 + // CHECK: %[[RHS_0:.+]] = comb.extract %rhs from 0 : (i3) -> i1 + // CHECK: %[[RHS_1:.+]] = comb.extract %rhs from 1 : (i3) -> i1 + // CHECK: %[[LHS_0_AND_RHS_0:.+]] = synth.mig.maj_inv %[[LHS_0]], %[[RHS_0]], %{{false.*}} : i1 + // CHECK: %[[CARRY_1:.+]] = synth.mig.maj_inv %[[LHS_1]], %[[RHS_1]], %[[LHS_0_AND_RHS_0]] : i1 + %0 = comb.add %lhs, %rhs : i3 + hw.output %0 : i3 +} diff --git a/test/Dialect/AIG/lower-word-to-bits.mlir b/test/Dialect/AIG/lower-word-to-bits.mlir index 3b7654f2bcc9..e4fa6fd4300d 100644 --- a/test/Dialect/AIG/lower-word-to-bits.mlir +++ b/test/Dialect/AIG/lower-word-to-bits.mlir @@ -14,3 +14,19 @@ hw.module @Basic(in %a: i2, in %b: i2, out f: i2) { // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[AND_INV_2]], %[[AND_INV_3]] hw.output %1 : i2 } + +// CHECK-LABEL: hw.module @Basic_MIG +hw.module @Basic_MIG(in %a: i2, in %b: i2, in %c: i2, out f: i2) { + %0 = synth.mig.maj_inv not %a, %b, %c : i2 + // CHECK-NEXT: %[[EXTRACT_A_1:.+]] = comb.extract %a from 1 : (i2) -> i1 + // 
CHECK-NEXT: %[[EXTRACT_B_1:.+]] = comb.extract %b from 1 : (i2) -> i1 + // CHECK-NEXT: %[[EXTRACT_C_1:.+]] = comb.extract %c from 1 : (i2) -> i1 + // CHECK-NEXT: %[[MAJ_INV_0:.+]] = synth.mig.maj_inv not %[[EXTRACT_A_1]], %[[EXTRACT_B_1]], %[[EXTRACT_C_1]] : i1 + // CHECK-NEXT: %[[EXTRACT_A_0:.+]] = comb.extract %a from 0 : (i2) -> i1 + // CHECK-NEXT: %[[EXTRACT_B_0:.+]] = comb.extract %b from 0 : (i2) -> i1 + // CHECK-NEXT: %[[EXTRACT_C_0:.+]] = comb.extract %c from 0 : (i2) -> i1 + // CHECK-NEXT: %[[MAJ_INV_1:.+]] = synth.mig.maj_inv not %[[EXTRACT_A_0]], %[[EXTRACT_B_0]], %[[EXTRACT_C_0]] : i1 + // CHECK-NEXT: %[[CONCAT:.+]] = comb.concat %[[MAJ_INV_0]], %[[MAJ_INV_1]] : i1, i1 + // CHECK-NEXT: hw.output %[[CONCAT]] : i2 + hw.output %0 : i2 +} diff --git a/test/circt-synth/basic.mlir b/test/circt-synth/basic.mlir index 2a47156618e2..a5af878747e5 100644 --- a/test/circt-synth/basic.mlir +++ b/test/circt-synth/basic.mlir @@ -1,15 +1,17 @@ // RUN: circt-synth %s | FileCheck %s // RUN: circt-synth %s --top and | FileCheck %s --check-prefixes=TOP,CHECK // RUN: circt-synth %s --top and --emit-bytecode -f | circt-opt | FileCheck %s --check-prefix=CHECK -// RUN: circt-synth %s --until-before aig-lowering | FileCheck %s --check-prefix=AIG -// RUN: circt-synth %s --until-before aig-lowering --convert-to-comb | FileCheck %s --check-prefix=COMB +// RUN: circt-synth %s --until-before comb-lowering | FileCheck %s --check-prefix=AIG +// RUN: circt-synth %s --until-before comb-lowering --convert-to-comb | FileCheck %s --check-prefix=COMB // RUN: circt-synth %s --top and --disable-word-to-bits | FileCheck %s --check-prefix=DISABLE_WORD +// RUN: circt-synth %s --target-ir mig | FileCheck %s --check-prefix=MIG // TOP-LABEL: module attributes {"aig.longest-path-analysis-top" = @and} // AIG-LABEL: @and( // CHECK-LABEL: @and( // COMB-LABEL: @and( // DISABLE_WORD-LABEL: @and( +// MIG-LABEL: @and( hw.module @and(in %a: i2, in %b: i2, in %c: i2, in %d: i1, out and: i2) { // AIG-NEXT: 
%[[AND_INV:.+]] = aig.and_inv %a, %b, %c : i2 // AIG-NEXT: dbg.variable @@ -30,6 +32,7 @@ hw.module @and(in %a: i2, in %b: i2, in %c: i2, in %d: i1, out and: i2) { // COMB-NOT: aig.and_inv // DISABLE_WORD-NOT: comb.extract // DISABLE_WORD-NOT: comb.concat + // MIG: synth.mig.maj_inv %0 = comb.and %a, %b, %c : i2 dbg.variable "test", %0 : i2 hw.output %0 : i2 diff --git a/tools/circt-synth/circt-synth.cpp b/tools/circt-synth/circt-synth.cpp index 5e04d3d2dd8e..bce8e53dd783 100644 --- a/tools/circt-synth/circt-synth.cpp +++ b/tools/circt-synth/circt-synth.cpp @@ -28,6 +28,7 @@ #include "circt/Dialect/SV/SVPasses.h" #include "circt/Dialect/Seq/SeqDialect.h" #include "circt/Dialect/Sim/SimDialect.h" +#include "circt/Dialect/Synth/SynthDialect.h" #include "circt/Dialect/Synth/Transforms/SynthPasses.h" #include "circt/Dialect/Synth/Transforms/SynthesisPipeline.h" #include "circt/Dialect/Verif/VerifDialect.h" @@ -94,10 +95,10 @@ static cl::opt cl::desc("Allow unknown dialects in the input"), cl::init(false), cl::cat(mainCategory)); -enum Until { UntilAIGLowering, UntilMapping, UntilEnd }; +enum Until { UntilCombLowering, UntilMapping, UntilEnd }; static auto runUntilValues = llvm::cl::values( - clEnumValN(UntilAIGLowering, "aig-lowering", "Lowering of AIG"), + clEnumValN(UntilCombLowering, "comb-lowering", "Lowering Comb to AIG/MIG"), clEnumValN(UntilMapping, "mapping", "Run technology/lut mapping"), clEnumValN(UntilEnd, "all", "Run entire pipeline (default)")); @@ -177,6 +178,12 @@ static cl::opt cl::desc("Lower to generic a truth table op with K inputs"), cl::init(0), cl::cat(mainCategory)); +static cl::opt + targetIR("target-ir", cl::desc("Target IR to lower to"), + cl::values(clEnumValN(TargetIR::AIG, "aig", "AIG operation"), + clEnumValN(TargetIR::MIG, "mig", "MIG operation")), + cl::init(TargetIR::AIG), cl::cat(mainCategory)); + //===----------------------------------------------------------------------===// // Main Tool Logic 
//===----------------------------------------------------------------------===// @@ -200,11 +207,6 @@ nestOrAddToHierarchicalRunner(OpPassManager &pm, // Tool implementation //===----------------------------------------------------------------------===// -template -static void partiallyLegalizeCombToAIG(SmallVectorImpl &ops) { - (ops.push_back(AllowedOpTy::getOperationName().str()), ...); -} - // Add a default synthesis pipeline and analysis. static void populateCIRCTSynthPipeline(PassManager &pm) { // ExtractTestCode is used to move verification code from design to @@ -213,11 +215,12 @@ static void populateCIRCTSynthPipeline(PassManager &pm) { /*disableInstanceExtraction=*/false, /*disableRegisterExtraction=*/false, /*disableModuleInlining=*/false)); auto pipeline = [](OpPassManager &pm) { - circt::synth::AIGLoweringPipelineOptions loweringOptions; + circt::synth::CombLoweringPipelineOptions loweringOptions; loweringOptions.disableDatapath = disableDatapath; loweringOptions.timingAware = !disableTimingAware; - circt::synth::buildAIGLoweringPipeline(pm, loweringOptions); - if (untilReached(UntilAIGLowering)) + loweringOptions.targetIR = targetIR; + circt::synth::buildCombLoweringPipeline(pm, loweringOptions); + if (untilReached(UntilCombLowering)) return; circt::synth::AIGOptimizationPipelineOptions optimizationOptions; @@ -378,7 +381,7 @@ int main(int argc, char **argv) { registry.insert(); + synth::SynthDialect, sv::SVDialect, verif::VerifDialect>(); MLIRContext context(registry); if (allowUnregisteredDialects) context.allowUnregisteredDialects(); From 7be6e45dc30cfab313a4532bf5c7349ff59df0e9 Mon Sep 17 00:00:00 2001 From: Hideto Ueno Date: Mon, 8 Sep 2025 10:48:49 -0700 Subject: [PATCH 2/2] Address comments --- lib/Conversion/CombToAIG/CombToAIG.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index 14c8d1d3ccfd..30ba4713e5d2 100644 --- 
a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -36,6 +36,7 @@ #include "llvm/ADT/APInt.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/Support/Debug.h" +#include #define DEBUG_TYPE "comb-to-aig" @@ -119,8 +120,8 @@ static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry, bool useMajorityInverterOp) { if (useMajorityInverterOp) { - SmallVector inputs = {a, b, carry}; - SmallVector inverts = {false, false, false}; + std::array inputs = {a, b, carry}; + std::array inverts = {false, false, false}; return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs, inverts); } @@ -319,7 +320,7 @@ struct CombOrToMIGConversion : OpConversionPattern { rewriter, op.getLoc(), APInt::getAllOnes(hw::getBitWidth(op.getType()))); inputs.push_back(one); - SmallVector inverts(inputs.size(), false); + std::array inverts = {false, false, false}; replaceOpWithNewOpAndCopyNamehint( rewriter, op, inputs, inverts); return success();