From 333777a55803e6a26a00119550f3f34bfea33d88 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Thu, 17 Jul 2025 15:00:14 +0100 Subject: [PATCH 01/12] Initiate datapath to comb pass --- include/circt/Conversion/DatapathToComb.h | 25 ++ include/circt/Conversion/Passes.h | 1 + include/circt/Conversion/Passes.td | 17 + lib/Conversion/CMakeLists.txt | 1 + lib/Conversion/DatapathToComb/CMakeLists.txt | 18 + .../DatapathToComb/DatapathToComb.cpp | 348 ++++++++++++++++++ 6 files changed, 410 insertions(+) create mode 100644 include/circt/Conversion/DatapathToComb.h create mode 100644 lib/Conversion/DatapathToComb/CMakeLists.txt create mode 100644 lib/Conversion/DatapathToComb/DatapathToComb.cpp diff --git a/include/circt/Conversion/DatapathToComb.h b/include/circt/Conversion/DatapathToComb.h new file mode 100644 index 000000000000..7368041cb39e --- /dev/null +++ b/include/circt/Conversion/DatapathToComb.h @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef CIRCT_CONVERSION_DATAPATHTOCOMB_H +#define CIRCT_CONVERSION_DATAPATHTOCOMB_H + +#include "circt/Support/LLVM.h" + +namespace circt { + +void populateDatapathToCombConversionPatterns(TypeConverter &converter, + RewritePatternSet &patterns, + bool lowerCompressToAdd); + +#define GEN_PASS_DECL_CONVERTDATAPATHTOCOMB +#include "circt/Conversion/Passes.h.inc" + +} // namespace circt + +#endif // CIRCT_CONVERSION_DATAPATHTOCOMB_H diff --git a/include/circt/Conversion/Passes.h b/include/circt/Conversion/Passes.h index d773f8d04296..f3aa5e985b7f 100644 --- a/include/circt/Conversion/Passes.h +++ b/include/circt/Conversion/Passes.h @@ -26,6 +26,7 @@ #include "circt/Conversion/CombToSMT.h" #include "circt/Conversion/ConvertToArcs.h" #include "circt/Conversion/DCToHW.h" +#include "circt/Conversion/DatapathToComb.h" #include "circt/Conversion/DatapathToSMT.h" #include "circt/Conversion/ExportChiselInterface.h" #include "circt/Conversion/ExportVerilog.h" diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 70f3928a73fd..fb432dffa07d 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -871,4 +871,21 @@ def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> { ]; } +//===----------------------------------------------------------------------===// +// ConvertDatapathToComb +//===----------------------------------------------------------------------===// + +def ConvertDatapathToComb : Pass<"convert-datapath-to-comb"> { + let summary = "Convert Datapath ops to Comb ops"; + let dependentDialects = [ + "circt::comb::CombDialect", + "circt::datapath::DatapathDialect", + "circt::hw::HWDialect" + ]; + let options = [ + Option<"lowerCompressToAdd", "lower-compress-to-add", "bool", "false", + "Lower compress operators to variadic add."> + ]; +} + #endif // CIRCT_CONVERSION_PASSES_TD diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 9e8e4d37e40c..281f1a804121 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -9,6 +9,7 @@ add_subdirectory(CombToDatapath) add_subdirectory(CombToLLVM) add_subdirectory(CombToSMT) add_subdirectory(ConvertToArcs) +add_subdirectory(DatapathToComb) add_subdirectory(DatapathToSMT) add_subdirectory(DCToHW) add_subdirectory(ExportAIGER) diff --git a/lib/Conversion/DatapathToComb/CMakeLists.txt b/lib/Conversion/DatapathToComb/CMakeLists.txt new file mode 100644 index 000000000000..21c5dbf893b2 --- /dev/null +++ b/lib/Conversion/DatapathToComb/CMakeLists.txt @@ -0,0 +1,18 @@ +add_circt_conversion_library(CIRCTDatapathToComb + DatapathToComb.cpp + + DEPENDS + CIRCTConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + CIRCTDatapath + CIRCTComb + CIRCTHW + MLIRIR + MLIRPass + MLIRSupport + MLIRTransforms +) diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp new file mode 100644 index 000000000000..a03e03f02a7c --- /dev/null +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -0,0 +1,348 @@ +//===- DatapathToComb.cpp--------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "circt/Conversion/DatapathToComb.h" +#include "circt/Dialect/Comb/CombOps.h" +#include "circt/Dialect/Datapath/DatapathOps.h" +#include "circt/Dialect/HW/HWOps.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/PointerUnion.h" +#include "llvm/Support/Debug.h" + +#define DEBUG_TYPE "datapath-to-comb" + +namespace circt { +#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB +#include "circt/Conversion/Passes.h.inc" +} // namespace circt + +// using namespace mlir; +using namespace circt; +using namespace datapath; + +// A wrapper for comb::extractBits that returns a SmallVector. +static SmallVector extractBits(OpBuilder &builder, Value val) { + SmallVector bits; + comb::extractBits(builder, val, bits); + return bits; +} + +//===----------------------------------------------------------------------===// +// Conversion patterns +//===----------------------------------------------------------------------===// + +namespace { +// Construct a full adder for three 1-bit inputs. +std::pair fullAdder(ConversionPatternRewriter &rewriter, + Location loc, Value a, Value b, Value c) { + auto aXorB = rewriter.createOrFold(loc, a, b, true); + // a ^ b ^ c + Value sum = rewriter.createOrFold(loc, aXorB, c, true); + // (a & b) | ((a ^ b) & c) + auto carry = rewriter.createOrFold( + loc, + ArrayRef{rewriter.createOrFold(loc, a, b, true), + rewriter.createOrFold(loc, aXorB, c, true)}, + true); + + return {sum, carry}; +} + +struct DatapathCompressOpAddConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(CompressOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto inputs = op.getOperands(); + unsigned width = inputs[0].getType().getIntOrFloatBitWidth(); + auto addOp = rewriter.create(loc, inputs, true); + auto zeroOp = rewriter.create(loc, APInt(width, 0)); + + rewriter.replaceOp(op, {addOp, zeroOp}); + return success(); + } +}; + +struct DatapathCompressOpConversion : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(CompressOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto inputs = op.getOperands(); + unsigned width = inputs[0].getType().getIntOrFloatBitWidth(); + // TODO - implement a more efficient compression algorithm to compete with + // yosys's `compress` pass. + + auto falseValue = rewriter.create(loc, APInt(1, 0)); + SmallVector> partialProducts; + for (auto input : inputs) { + partialProducts.push_back( + extractBits(rewriter, input)); // Extract bits from each input + } + + // Wallace tree reduction + rewriter.replaceOp(op, wallaceReduction(falseValue, width, rewriter, loc, + partialProducts)); + return success(); + } + +private: + // Perform Wallace tree reduction on partial products. + // See https://en.wikipedia.org/wiki/Wallace_tree + static SmallVector + wallaceReduction(Value falseValue, size_t width, + ConversionPatternRewriter &rewriter, Location loc, + SmallVector> &partialProducts) { + SmallVector> newPartialProducts; + newPartialProducts.reserve(partialProducts.size()); + // Continue reduction until we have only two rows. The length of + // `partialProducts` is reduced by 1/3 in each iteration. + auto numReductionStages = 0; + while (partialProducts.size() > 2) { + newPartialProducts.clear(); + ++numReductionStages; + // Take three rows at a time and reduce to two rows(sum and carry). + for (unsigned i = 0; i < partialProducts.size(); i += 3) { + if (i + 2 < partialProducts.size()) { + // We have three rows to reduce + auto &row1 = partialProducts[i]; + auto &row2 = partialProducts[i + 1]; + auto &row3 = partialProducts[i + 2]; + + assert(row1.size() == width && row2.size() == width && + row3.size() == width); + + SmallVector sumRow, carryRow; + sumRow.reserve(width); + carryRow.reserve(width); + carryRow.push_back(falseValue); + + // Process each bit position + for (unsigned j = 0; j < width; ++j) { + // Full adder logic + auto [sum, carry] = + fullAdder(rewriter, loc, row1[j], row2[j], row3[j]); + sumRow.push_back(sum); + if (j + 1 < width) + carryRow.push_back(carry); + } + + newPartialProducts.push_back(std::move(sumRow)); + newPartialProducts.push_back(std::move(carryRow)); + } else { + // Add remaining rows as is + newPartialProducts.append(partialProducts.begin() + i, + partialProducts.end()); + } + } + + std::swap(newPartialProducts, partialProducts); + } + + LLVM_DEBUG(llvm::dbgs() << "Wallace tree reduction completed in " + << numReductionStages << " stages\n"); + + assert(partialProducts.size() == 2); + SmallVector carrySave; + for (auto partialProduct : partialProducts) { + // Reverse the order of the bits + std::reverse(partialProduct.begin(), partialProduct.end()); + carrySave.push_back(rewriter.create(loc, partialProduct)); + } + // Use comb.add for the final addition. + return carrySave; + } +}; + +struct DatapathPartialProductOpConversion + : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(PartialProductOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + auto inputs = op.getOperands(); + Value a = inputs[0]; + Value b = inputs[1]; + unsigned width = a.getType().getIntOrFloatBitWidth(); + + // Skip a zero width value. + if (width == 0) { + rewriter.replaceOpWithNewOp(op, op.getType(0), 0); + return success(); + } + + // Extract individual bits from operands + SmallVector aBits = extractBits(rewriter, a); + SmallVector bBits = extractBits(rewriter, b); + + SmallVector partialProducts; + partialProducts.reserve(width); + + // Implement a basic and array + if (width <= 16) + lowerAndArray(rewriter, op, partialProducts, aBits, bBits, width); + else + lowerBoothArray(rewriter, op, partialProducts, aBits, bBits, width); + + return success(); + } + + void lowerAndArray(ConversionPatternRewriter &rewriter, PartialProductOp op, + SmallVector partialProducts, + SmallVector &aBits, SmallVector &bBits, + unsigned width) const { + + Location loc = op.getLoc(); + Value a = op.getOperand(0); + + // Generate partial products + + for (unsigned i = 0; i < width; ++i) { + auto repl = rewriter.create(loc, bBits[i], width); + auto ppRow = rewriter.create(loc, repl, a); + Value shiftBy = rewriter.create(loc, APInt(width, i)); + Value ppAlign = rewriter.create(loc, ppRow, shiftBy); + partialProducts.push_back(ppAlign); + if (partialProducts.size() == op.getNumResults()) + break; + } + if (partialProducts.size() != op.getNumResults()) { + llvm::errs() << "Expected " << op.getNumResults() + << " partial products, but got " << partialProducts.size() + << " width " << width << "\n"; + assert(false && "Expected width number of booth partial products"); + } + rewriter.replaceOp(op, partialProducts); + } + + void lowerBoothArray(ConversionPatternRewriter &rewriter, PartialProductOp op, + SmallVector &partialProducts, + SmallVector &aBits, SmallVector &bBits, + unsigned width) const { + Location loc = op.getLoc(); + auto zeroFalse = rewriter.create(loc, APInt(1, 0)); + auto zeroWidth = rewriter.create(loc, APInt(width, 0)); + auto oneWidth = rewriter.create(loc, APInt(width, 1)); + Value a = op.getOperand(0); + Value twoA = rewriter.create(loc, a, oneWidth); + + // Booth encoding: examine b[i+1:i-1] for each i + // We need width/2 partial products for radix-2 Booth + Value cplPrev; + for (unsigned i = 0; i < width; i += 2) { + // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0) + Value b_i = (i < width) ? bBits[i] : zeroFalse; + Value b_ip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse; + Value b_im1 = (i == 0) ? zeroFalse : bBits[i - 1]; + + // Is the encoding zero or negative (an approximation) + Value cpl = b_ip1; + // Is the encoding one = b_i xor b_im1 + Value one = rewriter.create(loc, b_i, b_im1, true); + // Is the encoding two = (b_ip1 & ~b_i & ~b_im1) | (~b_ip1 & b_i & b_im1) + Value const_one = rewriter.create(loc, APInt(1, 1)); + Value b_i_inv = rewriter.create(loc, b_i, const_one, true); + Value b_ip1_inv = + rewriter.create(loc, b_ip1, const_one, true); + Value b_im1_inv = + rewriter.create(loc, b_im1, const_one, true); + Value andLeft = rewriter.create( + loc, ValueRange{b_ip1_inv, b_i, b_im1}, true); + Value andRight = rewriter.create( + loc, ValueRange{b_ip1, b_i_inv, b_im1_inv}, true); + Value two = rewriter.create(loc, andLeft, andRight, true); + + Value cpl_repl = rewriter.create(loc, cpl, width); + Value one_repl = rewriter.create(loc, one, width); + Value two_repl = rewriter.create(loc, two, width); + + // Select between 2*a or 1*a or 0*a + Value selTwoA = rewriter.create(loc, two_repl, twoA); + Value selOneA = rewriter.create(loc, one_repl, a); + Value magA = rewriter.create(loc, selTwoA, selOneA, true); + + // Conditionally invert the row + Value ppRow = rewriter.create(loc, magA, cpl_repl, true); + if (i == 0) { + partialProducts.push_back(ppRow); + cplPrev = cpl; + continue; + } + assert(i >= 2 && "Expected i to be at least 2 for sign correction"); + Value withSignCorrection = rewriter.create( + loc, ValueRange{ppRow, zeroFalse, cplPrev}); + Value ppAlignPre = + rewriter.create(loc, withSignCorrection, 0, width); + Value shiftBy = rewriter.create(loc, APInt(width, i - 2)); + Value ppAlign = rewriter.create(loc, ppAlignPre, shiftBy); + partialProducts.push_back(ppAlign); + cplPrev = cpl; + + if (partialProducts.size() == op.getNumResults()) + break; + } + + while (partialProducts.size() < op.getNumResults()) + partialProducts.push_back(zeroWidth); + + if (partialProducts.size() != op.getNumResults()) { + llvm::errs() << "Expected " << op.getNumResults() + << " partial products, but got " << partialProducts.size() + << "width " << width << "\n"; + assert(false && "Expected width number of booth partial products"); + } + assert(partialProducts.size() == op.getNumResults() && + "Expected width number of booth partial products"); + + rewriter.replaceOp(op, partialProducts); + } +}; +} // namespace + +//===----------------------------------------------------------------------===// +// Convert Comb to Datapath pass +//===----------------------------------------------------------------------===// + +namespace { +struct ConvertDatapathToCombPass + : public impl::ConvertDatapathToCombBase { + void runOnOperation() override; + using ConvertDatapathToCombBase< + ConvertDatapathToCombPass>::ConvertDatapathToCombBase; +}; +} // namespace + +static void +populateDatapathToCombConversionPatterns(RewritePatternSet &patterns, + bool lowerCompressToAdd) { + patterns.add(patterns.getContext()); + + if (lowerCompressToAdd) + patterns.add(patterns.getContext()); + else + patterns.add(patterns.getContext()); +} + +void ConvertDatapathToCombPass::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + populateDatapathToCombConversionPatterns(patterns, lowerCompressToAdd); + + if (failed(mlir::applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); +} From c8fdf05a1e914fa9573fffa1c028c0d58ce5e97b Mon Sep 17 00:00:00 2001 From: cowardsa Date: Fri, 18 Jul 2025 11:59:47 +0100 Subject: [PATCH 02/12] Add tests and tidy datapath to compress implementation --- lib/Conversion/DatapathToComb/CMakeLists.txt | 5 +- .../DatapathToComb/DatapathToComb.cpp | 212 +++++++++--------- .../DatapathToComb/datapath-to-comb.mlir | 47 ++++ 3 files changed, 157 insertions(+), 107 deletions(-) create mode 100644 test/Conversion/DatapathToComb/datapath-to-comb.mlir diff --git a/lib/Conversion/DatapathToComb/CMakeLists.txt b/lib/Conversion/DatapathToComb/CMakeLists.txt index 21c5dbf893b2..2b85166968a7 100644 --- a/lib/Conversion/DatapathToComb/CMakeLists.txt +++ b/lib/Conversion/DatapathToComb/CMakeLists.txt @@ -4,12 +4,9 @@ add_circt_conversion_library(CIRCTDatapathToComb DEPENDS CIRCTConversionPassIncGen - LINK_COMPONENTS - Core - LINK_LIBS PUBLIC - CIRCTDatapath CIRCTComb + CIRCTDatapath CIRCTHW MLIRIR MLIRPass diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index a03e03f02a7c..b2c3bfdaf109 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -13,7 +13,6 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" -#include "llvm/ADT/PointerUnion.h" #include "llvm/Support/Debug.h" #define DEBUG_TYPE "datapath-to-comb" @@ -23,7 +22,6 @@ namespace circt { #include "circt/Conversion/Passes.h.inc" } // namespace circt -// using namespace mlir; using namespace circt; using namespace datapath; @@ -55,6 +53,9 @@ std::pair fullAdder(ConversionPatternRewriter &rewriter, return {sum, carry}; } +// Replace compressor by an adder of the inputs and zero for the other results: +// compress(a,b,c,d) -> {a+b+c+d, 0} +// Facilitates use of downstream compression algorithms e.g. Yosys struct DatapathCompressOpAddConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -63,14 +64,18 @@ struct DatapathCompressOpAddConversion : OpConversionPattern { Location loc = op.getLoc(); auto inputs = op.getOperands(); unsigned width = inputs[0].getType().getIntOrFloatBitWidth(); + // Sum all the inputs - set that to result value 0 auto addOp = rewriter.create(loc, inputs, true); + // Replace remaining results with zeros auto zeroOp = rewriter.create(loc, APInt(width, 0)); - - rewriter.replaceOp(op, {addOp, zeroOp}); + SmallVector results(op.getNumResults() - 1, zeroOp); + results.push_back(addOp); + rewriter.replaceOp(op, results); return success(); } }; +// Replace compressor by a wallace tree of full-adders struct DatapathCompressOpConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult @@ -79,19 +84,21 @@ struct DatapathCompressOpConversion : OpConversionPattern { Location loc = op.getLoc(); auto inputs = op.getOperands(); unsigned width = inputs[0].getType().getIntOrFloatBitWidth(); - // TODO - implement a more efficient compression algorithm to compete with - // yosys's `compress` pass. - auto falseValue = rewriter.create(loc, APInt(1, 0)); - SmallVector> partialProducts; + SmallVector> addends; for (auto input : inputs) { - partialProducts.push_back( + addends.push_back( extractBits(rewriter, input)); // Extract bits from each input } // Wallace tree reduction - rewriter.replaceOp(op, wallaceReduction(falseValue, width, rewriter, loc, - partialProducts)); + // TODO - implement a more efficient compression algorithm to compete with + // yosys's `alumacc` lowering - a coarse grained timing model would help to + // sort the inputs according to arrival time. + auto falseValue = rewriter.create(loc, APInt(1, 0)); + auto targetAddends = op.getNumResults(); + rewriter.replaceOp(op, wallaceReduction(falseValue, width, targetAddends, + rewriter, loc, addends)); return success(); } @@ -99,24 +106,24 @@ struct DatapathCompressOpConversion : OpConversionPattern { // Perform Wallace tree reduction on partial products. // See https://en.wikipedia.org/wiki/Wallace_tree static SmallVector - wallaceReduction(Value falseValue, size_t width, + wallaceReduction(Value falseValue, size_t width, size_t targetAddends, ConversionPatternRewriter &rewriter, Location loc, - SmallVector> &partialProducts) { - SmallVector> newPartialProducts; - newPartialProducts.reserve(partialProducts.size()); + SmallVector> &addends) { + SmallVector> newAddends; + newAddends.reserve(addends.size()); // Continue reduction until we have only two rows. The length of - // `partialProducts` is reduced by 1/3 in each iteration. + // `addends` is reduced by 1/3 in each iteration. auto numReductionStages = 0; - while (partialProducts.size() > 2) { - newPartialProducts.clear(); + while (addends.size() > targetAddends) { + newAddends.clear(); ++numReductionStages; // Take three rows at a time and reduce to two rows(sum and carry). - for (unsigned i = 0; i < partialProducts.size(); i += 3) { - if (i + 2 < partialProducts.size()) { + for (unsigned i = 0; i < addends.size(); i += 3) { + if (i + 2 < addends.size()) { // We have three rows to reduce - auto &row1 = partialProducts[i]; - auto &row2 = partialProducts[i + 1]; - auto &row3 = partialProducts[i + 2]; + auto &row1 = addends[i]; + auto &row2 = addends[i + 1]; + auto &row3 = addends[i + 2]; assert(row1.size() == width && row2.size() == width && row3.size() == width); @@ -136,29 +143,32 @@ struct DatapathCompressOpConversion : OpConversionPattern { carryRow.push_back(carry); } - newPartialProducts.push_back(std::move(sumRow)); - newPartialProducts.push_back(std::move(carryRow)); + newAddends.push_back(std::move(sumRow)); + newAddends.push_back(std::move(carryRow)); } else { // Add remaining rows as is - newPartialProducts.append(partialProducts.begin() + i, - partialProducts.end()); + newAddends.append(addends.begin() + i, addends.end()); } } - - std::swap(newPartialProducts, partialProducts); + std::swap(newAddends, addends); } LLVM_DEBUG(llvm::dbgs() << "Wallace tree reduction completed in " << numReductionStages << " stages\n"); - assert(partialProducts.size() == 2); + assert(addends.size() <= targetAddends); SmallVector carrySave; - for (auto partialProduct : partialProducts) { + for (auto addend : addends) { // Reverse the order of the bits - std::reverse(partialProduct.begin(), partialProduct.end()); - carrySave.push_back(rewriter.create(loc, partialProduct)); + std::reverse(addend.begin(), addend.end()); + carrySave.push_back(rewriter.create(loc, addend)); } - // Use comb.add for the final addition. + + // Pad with zeros + auto zero = rewriter.create(loc, APInt(width, 0)); + while (carrySave.size() < targetAddends) + carrySave.push_back(zero); + return carrySave; } }; @@ -181,130 +191,124 @@ struct DatapathPartialProductOpConversion return success(); } - // Extract individual bits from operands - SmallVector aBits = extractBits(rewriter, a); - SmallVector bBits = extractBits(rewriter, b); - - SmallVector partialProducts; - partialProducts.reserve(width); - - // Implement a basic and array + // Use width as a heuristic to guide partial product implementation if (width <= 16) - lowerAndArray(rewriter, op, partialProducts, aBits, bBits, width); + return lowerAndArray(rewriter, a, b, op, width); else - lowerBoothArray(rewriter, op, partialProducts, aBits, bBits, width); - - return success(); + return lowerBoothArray(rewriter, a, b, op, width); } - void lowerAndArray(ConversionPatternRewriter &rewriter, PartialProductOp op, - SmallVector partialProducts, - SmallVector &aBits, SmallVector &bBits, - unsigned width) const { +private: + static LogicalResult lowerAndArray(ConversionPatternRewriter &rewriter, + Value a, Value b, PartialProductOp op, + unsigned width) { Location loc = op.getLoc(); - Value a = op.getOperand(0); + // Keep a as a bitvector - multiply by each digit of b + SmallVector bBits = extractBits(rewriter, b); - // Generate partial products + SmallVector partialProducts; + partialProducts.reserve(width); + // AND Array Construction: + // partialProducts[i] = ({b[i],..., b[i]} & a) << i + assert(op.getNumResults() <= width && + "Cannot return more results than the operator width"); - for (unsigned i = 0; i < width; ++i) { + for (unsigned i = 0; i < op.getNumResults(); ++i) { auto repl = rewriter.create(loc, bBits[i], width); auto ppRow = rewriter.create(loc, repl, a); - Value shiftBy = rewriter.create(loc, APInt(width, i)); - Value ppAlign = rewriter.create(loc, ppRow, shiftBy); + auto shiftBy = rewriter.create(loc, APInt(width, i)); + auto ppAlign = rewriter.create(loc, ppRow, shiftBy); partialProducts.push_back(ppAlign); - if (partialProducts.size() == op.getNumResults()) - break; - } - if (partialProducts.size() != op.getNumResults()) { - llvm::errs() << "Expected " << op.getNumResults() - << " partial products, but got " << partialProducts.size() - << " width " << width << "\n"; - assert(false && "Expected width number of booth partial products"); } + rewriter.replaceOp(op, partialProducts); + return success(); } - void lowerBoothArray(ConversionPatternRewriter &rewriter, PartialProductOp op, - SmallVector &partialProducts, - SmallVector &aBits, SmallVector &bBits, - unsigned width) const { + static LogicalResult lowerBoothArray(ConversionPatternRewriter &rewriter, + Value a, Value b, PartialProductOp op, + unsigned width) { Location loc = op.getLoc(); auto zeroFalse = rewriter.create(loc, APInt(1, 0)); auto zeroWidth = rewriter.create(loc, APInt(width, 0)); auto oneWidth = rewriter.create(loc, APInt(width, 1)); - Value a = op.getOperand(0); Value twoA = rewriter.create(loc, a, oneWidth); - // Booth encoding: examine b[i+1:i-1] for each i - // We need width/2 partial products for radix-2 Booth - Value cplPrev; + SmallVector bBits = extractBits(rewriter, b); + + SmallVector partialProducts; + partialProducts.reserve(width); + + // Booth encoding halves array height by grouping three bits at a time: + // // partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1] << 2*i + Value encNegPrev; for (unsigned i = 0; i < width; i += 2) { // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0) - Value b_i = (i < width) ? bBits[i] : zeroFalse; - Value b_ip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse; - Value b_im1 = (i == 0) ? zeroFalse : bBits[i - 1]; + Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1]; + Value bi = (i < width) ? bBits[i] : zeroFalse; + Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse; // Is the encoding zero or negative (an approximation) - Value cpl = b_ip1; - // Is the encoding one = b_i xor b_im1 - Value one = rewriter.create(loc, b_i, b_im1, true); + Value encNeg = bip1; + // Is the encoding one = b[i] xor b[i-1] + Value encOne = rewriter.create(loc, bi, bim1, true); // Is the encoding two = (b_ip1 & ~b_i & ~b_im1) | (~b_ip1 & b_i & b_im1) - Value const_one = rewriter.create(loc, APInt(1, 1)); - Value b_i_inv = rewriter.create(loc, b_i, const_one, true); - Value b_ip1_inv = - rewriter.create(loc, b_ip1, const_one, true); - Value b_im1_inv = - rewriter.create(loc, b_im1, const_one, true); + Value constOne = rewriter.create(loc, APInt(1, 1)); + Value biInv = rewriter.create(loc, bi, constOne, true); + Value bip1Inv = rewriter.create(loc, bip1, constOne, true); + Value bim1Inv = rewriter.create(loc, bim1, constOne, true); + Value andLeft = rewriter.create( - loc, ValueRange{b_ip1_inv, b_i, b_im1}, true); + loc, ValueRange{bip1Inv, bi, bim1}, true); Value andRight = rewriter.create( - loc, ValueRange{b_ip1, b_i_inv, b_im1_inv}, true); - Value two = rewriter.create(loc, andLeft, andRight, true); + loc, ValueRange{bip1, biInv, bim1Inv}, true); + Value encTwo = rewriter.create(loc, andLeft, andRight, true); - Value cpl_repl = rewriter.create(loc, cpl, width); - Value one_repl = rewriter.create(loc, one, width); - Value two_repl = rewriter.create(loc, two, width); + Value encNegRepl = rewriter.create(loc, encNeg, width); + Value encOneRepl = rewriter.create(loc, encOne, width); + Value encTwoRepl = rewriter.create(loc, encTwo, width); // Select between 2*a or 1*a or 0*a - Value selTwoA = rewriter.create(loc, two_repl, twoA); - Value selOneA = rewriter.create(loc, one_repl, a); + Value selTwoA = rewriter.create(loc, encTwoRepl, twoA); + Value selOneA = rewriter.create(loc, encOneRepl, a); Value magA = rewriter.create(loc, selTwoA, selOneA, true); // Conditionally invert the row - Value ppRow = rewriter.create(loc, magA, cpl_repl, true); + Value ppRow = rewriter.create(loc, magA, encNegRepl, true); + + // No sign-correction in the first row if (i == 0) { partialProducts.push_back(ppRow); - cplPrev = cpl; + encNegPrev = encNeg; continue; } + + // Insert a sign-correction from the previous row assert(i >= 2 && "Expected i to be at least 2 for sign correction"); + // {ppRow, 0, encNegPrev} << 2*(i-1) Value withSignCorrection = rewriter.create( - loc, ValueRange{ppRow, zeroFalse, cplPrev}); + loc, ValueRange{ppRow, zeroFalse, encNegPrev}); Value ppAlignPre = rewriter.create(loc, withSignCorrection, 0, width); Value shiftBy = rewriter.create(loc, APInt(width, i - 2)); Value ppAlign = rewriter.create(loc, ppAlignPre, shiftBy); partialProducts.push_back(ppAlign); - cplPrev = cpl; + encNegPrev = encNeg; if (partialProducts.size() == op.getNumResults()) break; } + // Zero-pad to match the required output width while (partialProducts.size() < op.getNumResults()) partialProducts.push_back(zeroWidth); - if (partialProducts.size() != op.getNumResults()) { - llvm::errs() << "Expected " << op.getNumResults() - << " partial products, but got " << partialProducts.size() - << "width " << width << "\n"; - assert(false && "Expected width number of booth partial products"); - } assert(partialProducts.size() == op.getNumResults() && - "Expected width number of booth partial products"); + "Expected number of booth partial products to match results"); rewriter.replaceOp(op, partialProducts); + return success(); } }; } // namespace @@ -328,8 +332,10 @@ populateDatapathToCombConversionPatterns(RewritePatternSet &patterns, patterns.add(patterns.getContext()); if (lowerCompressToAdd) + // Lower compressors to simple add operations for downstream optimisations patterns.add(patterns.getContext()); else + // Lower compressors to a complete gate-level implementation patterns.add(patterns.getContext()); } @@ -337,7 +343,7 @@ void ConvertDatapathToCombPass::runOnOperation() { ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalDialect(); + target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); populateDatapathToCombConversionPatterns(patterns, lowerCompressToAdd); diff --git a/test/Conversion/DatapathToComb/datapath-to-comb.mlir b/test/Conversion/DatapathToComb/datapath-to-comb.mlir new file mode 100644 index 000000000000..d6981904b55a --- /dev/null +++ b/test/Conversion/DatapathToComb/datapath-to-comb.mlir @@ -0,0 +1,47 @@ +// RUN: circt-opt %s --convert-datapath-to-smt | FileCheck %s + +// CHECK-LABEL: @compressor +hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) { + //CHECK-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i2) -> i1 + //CHECK-NEXT: %[[A1:.+]] = comb.extract %a from 1 : (i2) -> i1 + //CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i2) -> i1 + //CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i2) -> i1 + //CHECK-NEXT: %[[C0:.+]] = comb.extract %c from 0 : (i2) -> i1 + //CHECK-NEXT: %[[C1:.+]] = comb.extract %c from 1 : (i2) -> i1 + //CHECK-NEXT: %false = hw.constant false + //CHECK-NEXT: %[[AxB0:.+]] = comb.xor bin %[[A0]], %[[B0]] : i1 + //CHECK-NEXT: %[[AxBxC0:.+]] = comb.xor bin %[[AxB0]], %[[C0]] : i1 + //CHECK-NEXT: %[[AB0:.+]] = comb.and bin %[[A0]], %[[B0]] : i1 + //CHECK-NEXT: %[[AxBC0:.+]] = comb.and bin %[[AxB0]], %[[C0]] : i1 + //CHECK-NEXT: %[[AB0oAxBC0:.+]] = comb.or bin %[[AB0]], %[[AxBC0]] : i1 + //CHECK-NEXT: %[[AxB1:.+]] = comb.xor bin %[[A1]], %[[B1]] : i1 + //CHECK-NEXT: %[[AxBxC1:.+]] = comb.xor bin %[[AxB1]], %[[C1]] : i1 + //CHECK-NEXT: %[[AB1:.+]] = comb.and bin %[[A1]], %[[B1]] : i1 + //CHECK-NEXT: %[[AxBC1:.+]] = comb.and bin %[[AxB1]], %[[C1]] : i1 + //CHECK-NEXT: comb.or bin %[[AB1]], %[[AxBC1]] : i1 + //CHECK-NEXT: comb.concat %[[AxBxC1]], %[[AxBxC0]] : i1, i1 + //CHECK-NEXT: comb.concat %[[AB0oAxBC0]], %false : i1, i1 + %0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2] + hw.output %0#0, %0#1 : i2, i2 +} + +// CHECK-LABEL: @partial_product +hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) { + // CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1 + // CHECK-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1 + // CHECK-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1 + // CHECK-NEXT: %[[B0R:.+]] = comb.replicate %[[B0]] : (i1) -> i3 + // CHECK-NEXT: %[[PP0:.+]] = comb.and %[[B0R]], %a : i3 + // CHECK-NEXT: %c0_i3 = hw.constant 0 : i3 + // CHECK-NEXT: comb.shl %[[PP0]], %c0_i3 : i3 + // CHECK-NEXT: %[[B1R:.+]] = comb.replicate %[[B1]] : (i1) -> i3 + // CHECK-NEXT: %[[PP1:.+]] = comb.and %[[B1R]], %a : i3 + // CHECK-NEXT: %c1_i3 = hw.constant 1 : i3 + // CHECK-NEXT: comb.shl %[[PP1]], %c1_i3 : i3 + // CHECK-NEXT: %[[B2R:.+]] = comb.replicate %[[B2]] : (i1) -> i3 + // CHECK-NEXT: %[[PP2:.+]] = comb.and %[[B2R]], %a : i3 + // CHECK-NEXT: %c2_i3 = hw.constant 2 : i3 + // CHECK-NEXT: comb.shl %[[PP2]], %c2_i3 : i3 + %0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3) + hw.output %0#0, %0#1, %0#2 : i3, i3, i3 +} From baa30c866ab6c7962dee31bcd55e148b67a74064 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Fri, 18 Jul 2025 12:03:42 +0100 Subject: [PATCH 03/12] Improve comments --- lib/Conversion/DatapathToComb/DatapathToComb.cpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index b2c3bfdaf109..641178ce3c56 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -241,7 +241,10 @@ struct DatapathPartialProductOpConversion partialProducts.reserve(width); // Booth encoding halves array height by grouping three bits at a time: - // // partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1] << 2*i + // partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1]) << 2*i + // encNeg \approx (-2*b[2*i+1] + b[2*i] + b[2*i-1]) <= 0 + // encOne = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1 + // encTwo = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2 Value encNegPrev; for (unsigned i = 0; i < width; i += 2) { // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0) @@ -253,7 +256,7 @@ struct DatapathPartialProductOpConversion Value encNeg = bip1; // Is the encoding one = b[i] xor b[i-1] Value encOne = rewriter.create(loc, bi, bim1, true); - // Is the encoding two = (b_ip1 & ~b_i & ~b_im1) | (~b_ip1 & b_i & b_im1) + // Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1) Value constOne = rewriter.create(loc, APInt(1, 1)); Value biInv = rewriter.create(loc, bi, constOne, true); Value bip1Inv = rewriter.create(loc, bip1, constOne, true); From 3b7586111e6a555fd3a3507fdc3bfeb9f27bf512 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Fri, 18 Jul 2025 12:09:05 +0100 Subject: [PATCH 04/12] Formatting and test corrections --- include/circt/Conversion/Passes.td | 3 +-- lib/Conversion/DatapathToComb/DatapathToComb.cpp | 4 ++-- test/Conversion/DatapathToComb/datapath-to-comb.mlir | 2 +- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index fb432dffa07d..9c78f41fcb01 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -878,8 +878,7 @@ def ConvertDatapathToSMT : Pass<"convert-datapath-to-smt"> { def ConvertDatapathToComb : Pass<"convert-datapath-to-comb"> { let summary = "Convert Datapath ops to Comb ops"; let dependentDialects = [ - "circt::comb::CombDialect", - "circt::datapath::DatapathDialect", + "circt::comb::CombDialect", "circt::datapath::DatapathDialect", "circt::hw::HWDialect" ]; let options = [ diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 641178ce3c56..2b276f897f71 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -1,4 +1,4 @@ -//===- DatapathToComb.cpp--------------------------------------------------===// +//===----------------------------------------------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -317,7 +317,7 @@ struct DatapathPartialProductOpConversion } // namespace //===----------------------------------------------------------------------===// -// Convert Comb to Datapath pass +// Convert Datapath to Comb pass //===----------------------------------------------------------------------===// namespace { diff --git a/test/Conversion/DatapathToComb/datapath-to-comb.mlir b/test/Conversion/DatapathToComb/datapath-to-comb.mlir index d6981904b55a..8ace72a66924 100644 --- a/test/Conversion/DatapathToComb/datapath-to-comb.mlir +++ b/test/Conversion/DatapathToComb/datapath-to-comb.mlir @@ -1,4 +1,4 @@ -// RUN: circt-opt %s --convert-datapath-to-smt | FileCheck %s +// RUN: circt-opt %s --convert-datapath-to-comb | FileCheck %s // CHECK-LABEL: @compressor hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) { From e68207aba76c11fe7ef935205f21cead26de94c8 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Fri, 18 Jul 2025 14:07:06 +0100 Subject: [PATCH 05/12] Correct CAPI --- lib/CAPI/Conversion/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/lib/CAPI/Conversion/CMakeLists.txt b/lib/CAPI/Conversion/CMakeLists.txt index 0784bee53316..41ef6dc5d303 100644 --- a/lib/CAPI/Conversion/CMakeLists.txt +++ b/lib/CAPI/Conversion/CMakeLists.txt @@ -14,6 +14,7 @@ add_circt_public_c_api_library(CIRCTCAPIConversion CIRCTCombToLLVM CIRCTCombToSMT CIRCTConvertToArcs + CIRCTDatapathToComb CIRCTDatapathToSMT CIRCTDCToHW CIRCTExportChiselInterface From fa70d10e302dd29f1be27156640370ed505b37b4 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Mon, 21 Jul 2025 15:25:19 +0100 Subject: [PATCH 06/12] Move wallace tree reduction and full-adder to comb ops --- include/circt/Dialect/Comb/CombOps.h | 12 +++ lib/Conversion/CombToAIG/CombToAIG.cpp | 87 ++---------------- .../DatapathToComb/DatapathToComb.cpp | 91 +------------------ lib/Dialect/Comb/CombOps.cpp | 80 ++++++++++++++++ 4 files changed, 101 insertions(+), 169 deletions(-) diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index 5c1c4cbdca3d..d4358e939ca3 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -23,6 +23,7 @@ #include "mlir/Interfaces/InferIntRangeInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Transforms/DialectConversion.h" namespace llvm { struct KnownBits; @@ -85,6 +86,17 @@ Value createDynamicInject(OpBuilder &builder, Location loc, Value value, Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement); +/// Construct a full adder for three 1-bit inputs. +std::pair fullAdder(ConversionPatternRewriter &rewriter, + Location loc, Value a, Value b, Value c); + +/// Perform Wallace tree reduction on partial products. +/// See https://en.wikipedia.org/wiki/Wallace_tree +SmallVector +wallaceReduction(Value falseValue, size_t width, size_t targetAddends, + ConversionPatternRewriter &rewriter, Location loc, + SmallVector> &partialProducts); + } // namespace comb } // namespace circt diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index f4b2262022af..c271f82d38a8 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -668,21 +668,6 @@ struct CombSubOpConversion : OpConversionPattern { } }; -// Construct a full adder for three 1-bit inputs. -std::pair fullAdder(ConversionPatternRewriter &rewriter, - Location loc, Value a, Value b, Value c) { - auto aXorB = rewriter.createOrFold(loc, a, b, true); - Value sum = rewriter.createOrFold(loc, aXorB, c, true); - - auto carry = rewriter.createOrFold( - loc, - ArrayRef{rewriter.createOrFold(loc, a, b, true), - rewriter.createOrFold(loc, aXorB, c, true)}, - true); - - return {sum, carry}; -} - struct CombMulOpConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; using OpAdaptor = typename OpConversionPattern::OpAdaptor; @@ -729,74 +714,14 @@ struct CombMulOpConversion : OpConversionPattern { return success(); } - // Wallace tree reduction - replaceOpAndCopyNamehint( - rewriter, op, - wallaceReduction(falseValue, width, rewriter, loc, partialProducts)); + // Wallace tree reduction - reduce to two addends. + auto addends = comb::wallaceReduction(falseValue, width, 2, rewriter, loc, + partialProducts); + // Sum the two addends using a carry-propagate adder + auto newAdd = rewriter.create(loc, addends, true); + replaceOpAndCopyNamehint(rewriter, op, newAdd); return success(); } - -private: - // Perform Wallace tree reduction on partial products. - // See https://en.wikipedia.org/wiki/Wallace_tree - static Value - wallaceReduction(Value falseValue, size_t width, - ConversionPatternRewriter &rewriter, Location loc, - SmallVector> &partialProducts) { - SmallVector> newPartialProducts; - newPartialProducts.reserve(partialProducts.size()); - // Continue reduction until we have only two rows. The length of - // `partialProducts` is reduced by 1/3 in each iteration. - while (partialProducts.size() > 2) { - newPartialProducts.clear(); - // Take three rows at a time and reduce to two rows(sum and carry). - for (unsigned i = 0; i < partialProducts.size(); i += 3) { - if (i + 2 < partialProducts.size()) { - // We have three rows to reduce - auto &row1 = partialProducts[i]; - auto &row2 = partialProducts[i + 1]; - auto &row3 = partialProducts[i + 2]; - - assert(row1.size() == width && row2.size() == width && - row3.size() == width); - - SmallVector sumRow, carryRow; - sumRow.reserve(width); - carryRow.reserve(width); - carryRow.push_back(falseValue); - - // Process each bit position - for (unsigned j = 0; j < width; ++j) { - // Full adder logic - auto [sum, carry] = - fullAdder(rewriter, loc, row1[j], row2[j], row3[j]); - sumRow.push_back(sum); - if (j + 1 < width) - carryRow.push_back(carry); - } - - newPartialProducts.push_back(std::move(sumRow)); - newPartialProducts.push_back(std::move(carryRow)); - } else { - // Add remaining rows as is - newPartialProducts.append(partialProducts.begin() + i, - partialProducts.end()); - } - } - - std::swap(newPartialProducts, partialProducts); - } - - assert(partialProducts.size() == 2); - // Reverse the order of the bits - std::reverse(partialProducts[0].begin(), partialProducts[0].end()); - std::reverse(partialProducts[1].begin(), partialProducts[1].end()); - auto lhs = rewriter.create(loc, partialProducts[0]); - auto rhs = rewriter.create(loc, partialProducts[1]); - - // Use comb.add for the final addition. - return rewriter.create(loc, ArrayRef{lhs, rhs}, true); - } }; template diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 2b276f897f71..6556f6503963 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -37,22 +37,6 @@ static SmallVector extractBits(OpBuilder &builder, Value val) { //===----------------------------------------------------------------------===// namespace { -// Construct a full adder for three 1-bit inputs. -std::pair fullAdder(ConversionPatternRewriter &rewriter, - Location loc, Value a, Value b, Value c) { - auto aXorB = rewriter.createOrFold(loc, a, b, true); - // a ^ b ^ c - Value sum = rewriter.createOrFold(loc, aXorB, c, true); - // (a & b) | ((a ^ b) & c) - auto carry = rewriter.createOrFold( - loc, - ArrayRef{rewriter.createOrFold(loc, a, b, true), - rewriter.createOrFold(loc, aXorB, c, true)}, - true); - - return {sum, carry}; -} - // Replace compressor by an adder of the inputs and zero for the other results: // compress(a,b,c,d) -> {a+b+c+d, 0} // Facilitates use of downstream compression algorithms e.g. Yosys @@ -97,80 +81,11 @@ struct DatapathCompressOpConversion : OpConversionPattern { // sort the inputs according to arrival time. auto falseValue = rewriter.create(loc, APInt(1, 0)); auto targetAddends = op.getNumResults(); - rewriter.replaceOp(op, wallaceReduction(falseValue, width, targetAddends, - rewriter, loc, addends)); + rewriter.replaceOp(op, + comb::wallaceReduction(falseValue, width, targetAddends, + rewriter, loc, addends)); return success(); } - -private: - // Perform Wallace tree reduction on partial products. - // See https://en.wikipedia.org/wiki/Wallace_tree - static SmallVector - wallaceReduction(Value falseValue, size_t width, size_t targetAddends, - ConversionPatternRewriter &rewriter, Location loc, - SmallVector> &addends) { - SmallVector> newAddends; - newAddends.reserve(addends.size()); - // Continue reduction until we have only two rows. The length of - // `addends` is reduced by 1/3 in each iteration. - auto numReductionStages = 0; - while (addends.size() > targetAddends) { - newAddends.clear(); - ++numReductionStages; - // Take three rows at a time and reduce to two rows(sum and carry). - for (unsigned i = 0; i < addends.size(); i += 3) { - if (i + 2 < addends.size()) { - // We have three rows to reduce - auto &row1 = addends[i]; - auto &row2 = addends[i + 1]; - auto &row3 = addends[i + 2]; - - assert(row1.size() == width && row2.size() == width && - row3.size() == width); - - SmallVector sumRow, carryRow; - sumRow.reserve(width); - carryRow.reserve(width); - carryRow.push_back(falseValue); - - // Process each bit position - for (unsigned j = 0; j < width; ++j) { - // Full adder logic - auto [sum, carry] = - fullAdder(rewriter, loc, row1[j], row2[j], row3[j]); - sumRow.push_back(sum); - if (j + 1 < width) - carryRow.push_back(carry); - } - - newAddends.push_back(std::move(sumRow)); - newAddends.push_back(std::move(carryRow)); - } else { - // Add remaining rows as is - newAddends.append(addends.begin() + i, addends.end()); - } - } - std::swap(newAddends, addends); - } - - LLVM_DEBUG(llvm::dbgs() << "Wallace tree reduction completed in " - << numReductionStages << " stages\n"); - - assert(addends.size() <= targetAddends); - SmallVector carrySave; - for (auto addend : addends) { - // Reverse the order of the bits - std::reverse(addend.begin(), addend.end()); - carrySave.push_back(rewriter.create(loc, addend)); - } - - // Pad with zeros - auto zero = rewriter.create(loc, APInt(width, 0)); - while (carrySave.size() < targetAddends) - carrySave.push_back(zero); - - return carrySave; - } }; struct DatapathPartialProductOpConversion diff --git a/lib/Dialect/Comb/CombOps.cpp b/lib/Dialect/Comb/CombOps.cpp index 403ed3a02ac1..6538a8897346 100644 --- a/lib/Dialect/Comb/CombOps.cpp +++ b/lib/Dialect/Comb/CombOps.cpp @@ -216,6 +216,86 @@ Value comb::createInject(OpBuilder &builder, Location loc, Value value, return builder.createOrFold(loc, fragments); } +// Construct a full adder for three 1-bit inputs. +std::pair comb::fullAdder(ConversionPatternRewriter &rewriter, + Location loc, Value a, Value b, + Value c) { + auto aXorB = rewriter.createOrFold(loc, a, b, true); + Value sum = rewriter.createOrFold(loc, aXorB, c, true); + + auto carry = rewriter.createOrFold( + loc, + ArrayRef{rewriter.createOrFold(loc, a, b, true), + rewriter.createOrFold(loc, aXorB, c, true)}, + true); + + return {sum, carry}; +} + +// Perform Wallace tree reduction on partial products. +// See https://en.wikipedia.org/wiki/Wallace_tree +SmallVector +comb::wallaceReduction(Value falseValue, size_t width, size_t targetAddends, + ConversionPatternRewriter &rewriter, Location loc, + SmallVector> &addends) { + SmallVector> newAddends; + newAddends.reserve(addends.size()); + // Continue reduction until we have only two rows. The length of + // `addends` is reduced by 1/3 in each iteration. + while (addends.size() > targetAddends) { + newAddends.clear(); + // Take three rows at a time and reduce to two rows(sum and carry). + for (unsigned i = 0; i < addends.size(); i += 3) { + if (i + 2 < addends.size()) { + // We have three rows to reduce + auto &row1 = addends[i]; + auto &row2 = addends[i + 1]; + auto &row3 = addends[i + 2]; + + assert(row1.size() == width && row2.size() == width && + row3.size() == width); + + SmallVector sumRow, carryRow; + sumRow.reserve(width); + carryRow.reserve(width); + carryRow.push_back(falseValue); + + // Process each bit position + for (unsigned j = 0; j < width; ++j) { + // Full adder logic + auto [sum, carry] = + comb::fullAdder(rewriter, loc, row1[j], row2[j], row3[j]); + sumRow.push_back(sum); + if (j + 1 < width) + carryRow.push_back(carry); + } + + newAddends.push_back(std::move(sumRow)); + newAddends.push_back(std::move(carryRow)); + } else { + // Add remaining rows as is + newAddends.append(addends.begin() + i, addends.end()); + } + } + std::swap(newAddends, addends); + } + + assert(addends.size() <= targetAddends); + SmallVector carrySave; + for (auto addend : addends) { + // Reverse the order of the bits + std::reverse(addend.begin(), addend.end()); + carrySave.push_back(rewriter.create(loc, addend)); + } + + // Pad with zeros + auto zero = rewriter.create(loc, APInt(width, 0)); + while (carrySave.size() < targetAddends) + carrySave.push_back(zero); + + return carrySave; +} + //===----------------------------------------------------------------------===// // ICmpOp //===----------------------------------------------------------------------===// From 87482062db914489b5e9f11090d966f2a5f089a4 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Tue, 22 Jul 2025 10:50:48 +0100 Subject: [PATCH 07/12] Adding integration tests using circt-lec and correcting review comments --- include/circt/Dialect/Comb/CombOps.h | 13 +- .../circt-synth/datapath-lowering-lec.mlir | 27 ++ integration_test/circt-synth/test.mlir | 313 ++++++++++++++++++ lib/Conversion/CombToAIG/CombToAIG.cpp | 4 +- .../DatapathToComb/DatapathToComb.cpp | 13 +- lib/Dialect/Comb/CombOps.cpp | 28 +- .../DatapathToComb/datapath-to-comb.mlir | 25 ++ 7 files changed, 393 insertions(+), 30 deletions(-) create mode 100644 integration_test/circt-synth/datapath-lowering-lec.mlir create mode 100644 integration_test/circt-synth/test.mlir diff --git a/include/circt/Dialect/Comb/CombOps.h b/include/circt/Dialect/Comb/CombOps.h index d4358e939ca3..78b7f890d6e2 100644 --- a/include/circt/Dialect/Comb/CombOps.h +++ b/include/circt/Dialect/Comb/CombOps.h @@ -87,15 +87,16 @@ Value createInject(OpBuilder &builder, Location loc, Value value, unsigned offset, Value replacement); /// Construct a full adder for three 1-bit inputs. -std::pair fullAdder(ConversionPatternRewriter &rewriter, - Location loc, Value a, Value b, Value c); +std::pair fullAdder(OpBuilder &builder, Location loc, Value a, + Value b, Value c); /// Perform Wallace tree reduction on partial products. /// See https://en.wikipedia.org/wiki/Wallace_tree -SmallVector -wallaceReduction(Value falseValue, size_t width, size_t targetAddends, - ConversionPatternRewriter &rewriter, Location loc, - SmallVector> &partialProducts); +/// \param targetAddends The number of addends to reduce to (2 for carry-save). +/// \param inputAddends The rows of bits to be summed. +SmallVector wallaceReduction(OpBuilder &builder, Location loc, + size_t width, size_t targetAddends, + SmallVector> &addends); } // namespace comb } // namespace circt diff --git a/integration_test/circt-synth/datapath-lowering-lec.mlir b/integration_test/circt-synth/datapath-lowering-lec.mlir new file mode 100644 index 000000000000..bedb83ebf162 --- /dev/null +++ b/integration_test/circt-synth/datapath-lowering-lec.mlir @@ -0,0 +1,27 @@ +// REQUIRES: libz3 +// REQUIRES: circt-lec-jit + +// RUN: circt-opt %s --convert-datapath-to-comb -o %t.mlir +// RUN: circt-lec %t.mlir %s -c1=partial_product_6 -c2=partial_product_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_6 +// PP_6: c1 == c2 +hw.module @partial_product_6(in %a : i6, in %b : i6, out sum : i6) { + %0:6 = datapath.partial_product %a, %b : (i6, i6) -> (i6, i6, i6, i6, i6, i6) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : i6 + hw.output %1 : i6 +} + +// RUN: circt-lec %t.mlir %s -c1=partial_product_16 -c2=partial_product_16 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_16 +// PP_16: c1 == c2 +hw.module @partial_product_16(in %a : i16, in %b : i16, out sum : i16) { + %0:16 = datapath.partial_product %a, %b : (i16, i16) -> (i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15 : i16 + hw.output %1 : i16 +} + +// RUN: circt-lec %t.mlir %s -c1=partial_product_17 -c2=partial_product_17 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_17 +// PP_17: c1 == c2 +hw.module @partial_product_17(in %a : i17, in %b : i17, out sum : i17) { + %0:17 = datapath.partial_product %a, %b : (i17, i17) -> (i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16 : i17 + hw.output %1 : i17 +} diff --git a/integration_test/circt-synth/test.mlir b/integration_test/circt-synth/test.mlir new file mode 100644 index 000000000000..24b8d14ee549 --- /dev/null +++ b/integration_test/circt-synth/test.mlir @@ -0,0 +1,313 @@ +module { + hw.module @partial_product_6(in %a : i6, in %b : i6, out sum : i6) { + %0 = comb.extract %b from 0 : (i6) -> i1 + %1 = comb.extract %b from 1 : (i6) -> i1 + %2 = comb.extract %b from 2 : (i6) -> i1 + %3 = comb.extract %b from 3 : (i6) -> i1 + %4 = comb.extract %b from 4 : (i6) -> i1 + %5 = comb.extract %b from 5 : (i6) -> i1 + %6 = comb.replicate %0 : (i1) -> i6 + %7 = comb.and %6, %a : i6 + %c0_i6 = hw.constant 0 : i6 + %8 = comb.shl %7, %c0_i6 : i6 + %9 = comb.replicate %1 : (i1) -> i6 + %10 = comb.and %9, %a : i6 + %c1_i6 = hw.constant 1 : i6 + %11 = comb.shl %10, %c1_i6 : i6 + %12 = comb.replicate %2 : (i1) -> i6 + %13 = comb.and %12, %a : i6 + %c2_i6 = hw.constant 2 : i6 + %14 = comb.shl %13, %c2_i6 : i6 + %15 = comb.replicate %3 : (i1) -> i6 + %16 = comb.and %15, %a : i6 + %c3_i6 = hw.constant 3 : i6 + %17 = comb.shl %16, %c3_i6 : i6 + %18 = comb.replicate %4 : (i1) -> i6 + %19 = comb.and %18, %a : i6 + %c4_i6 = hw.constant 4 : i6 + %20 = comb.shl %19, %c4_i6 : i6 + %21 = comb.replicate %5 : (i1) -> i6 + %22 = comb.and %21, %a : i6 + %c5_i6 = hw.constant 5 : i6 + %23 = comb.shl %22, %c5_i6 : i6 + %24 = comb.add bin %8, %11, %14, %17, %20, %23 : i6 + hw.output %24 : i6 + } + hw.module @partial_product_16(in %a : i16, in %b : i16, out sum : i16) { + %0 = comb.extract %b from 0 : (i16) -> i1 + %1 = comb.extract %b from 1 : (i16) -> i1 + %2 = comb.extract %b from 2 : (i16) -> i1 + %3 = comb.extract %b from 3 : (i16) -> i1 + %4 = comb.extract %b from 4 : (i16) -> i1 + %5 = comb.extract %b from 5 : (i16) -> i1 + %6 = comb.extract %b from 6 : (i16) -> i1 + %7 = comb.extract %b from 7 : (i16) -> i1 + %8 = comb.extract %b from 8 : (i16) -> i1 + %9 = comb.extract %b from 9 : (i16) -> i1 + %10 = comb.extract %b from 10 : (i16) -> i1 + %11 = comb.extract %b from 11 : (i16) -> i1 + %12 = comb.extract %b from 12 : (i16) -> i1 + %13 = comb.extract %b from 13 : (i16) -> i1 + %14 = comb.extract %b from 14 : (i16) -> i1 + %15 = comb.extract %b from 15 : (i16) -> i1 + %16 = comb.replicate %0 : (i1) -> i16 + %17 = comb.and %16, %a : i16 + %c0_i16 = hw.constant 0 : i16 + %18 = comb.shl %17, %c0_i16 : i16 + %19 = comb.replicate %1 : (i1) -> i16 + %20 = comb.and %19, %a : i16 + %c1_i16 = hw.constant 1 : i16 + %21 = comb.shl %20, %c1_i16 : i16 + %22 = comb.replicate %2 : (i1) -> i16 + %23 = comb.and %22, %a : i16 + %c2_i16 = hw.constant 2 : i16 + %24 = comb.shl %23, %c2_i16 : i16 + %25 = comb.replicate %3 : (i1) -> i16 + %26 = comb.and %25, %a : i16 + %c3_i16 = hw.constant 3 : i16 + %27 = comb.shl %26, %c3_i16 : i16 + %28 = comb.replicate %4 : (i1) -> i16 + %29 = comb.and %28, %a : i16 + %c4_i16 = hw.constant 4 : i16 + %30 = comb.shl %29, %c4_i16 : i16 + %31 = comb.replicate %5 : (i1) -> i16 + %32 = comb.and %31, %a : i16 + %c5_i16 = hw.constant 5 : i16 + %33 = comb.shl %32, %c5_i16 : i16 + %34 = comb.replicate %6 : (i1) -> i16 + %35 = comb.and %34, %a : i16 + %c6_i16 = hw.constant 6 : i16 + %36 = comb.shl %35, %c6_i16 : i16 + %37 = comb.replicate %7 : (i1) -> i16 + %38 = comb.and %37, %a : i16 + %c7_i16 = hw.constant 7 : i16 + %39 = comb.shl %38, %c7_i16 : i16 + %40 = comb.replicate %8 : (i1) -> i16 + %41 = comb.and %40, %a : i16 + %c8_i16 = hw.constant 8 : i16 + %42 = comb.shl %41, %c8_i16 : i16 + %43 = comb.replicate %9 : (i1) -> i16 + %44 = comb.and %43, %a : i16 + %c9_i16 = hw.constant 9 : i16 + %45 = comb.shl %44, %c9_i16 : i16 + %46 = comb.replicate %10 : (i1) -> i16 + %47 = comb.and %46, %a : i16 + %c10_i16 = hw.constant 10 : i16 + %48 = comb.shl %47, %c10_i16 : i16 + %49 = comb.replicate %11 : (i1) -> i16 + %50 = comb.and %49, %a : i16 + %c11_i16 = hw.constant 11 : i16 + %51 = comb.shl %50, %c11_i16 : i16 + %52 = comb.replicate %12 : (i1) -> i16 + %53 = comb.and %52, %a : i16 + %c12_i16 = hw.constant 12 : i16 + %54 = comb.shl %53, %c12_i16 : i16 + %55 = comb.replicate %13 : (i1) -> i16 + %56 = comb.and %55, %a : i16 + %c13_i16 = hw.constant 13 : i16 + %57 = comb.shl %56, %c13_i16 : i16 + %58 = comb.replicate %14 : (i1) -> i16 + %59 = comb.and %58, %a : i16 + %c14_i16 = hw.constant 14 : i16 + %60 = comb.shl %59, %c14_i16 : i16 + %61 = comb.replicate %15 : (i1) -> i16 + %62 = comb.and %61, %a : i16 + %c15_i16 = hw.constant 15 : i16 + %63 = comb.shl %62, %c15_i16 : i16 + %64 = comb.add bin %18, %21, %24, %27, %30, %33, %36, %39, %42, %45, %48, %51, %54, %57, %60, %63 : i16 + hw.output %64 : i16 + } + hw.module @partial_product_17(in %a : i17, in %b : i17, out sum : i17) { + %false = hw.constant false + %c0_i17 = hw.constant 0 : i17 + %c1_i17 = hw.constant 1 : i17 + %0 = comb.shl %a, %c1_i17 : i17 + %1 = comb.extract %b from 0 : (i17) -> i1 + %2 = comb.extract %b from 1 : (i17) -> i1 + %3 = comb.extract %b from 2 : (i17) -> i1 + %4 = comb.extract %b from 3 : (i17) -> i1 + %5 = comb.extract %b from 4 : (i17) -> i1 + %6 = comb.extract %b from 5 : (i17) -> i1 + %7 = comb.extract %b from 6 : (i17) -> i1 + %8 = comb.extract %b from 7 : (i17) -> i1 + %9 = comb.extract %b from 8 : (i17) -> i1 + %10 = comb.extract %b from 9 : (i17) -> i1 + %11 = comb.extract %b from 10 : (i17) -> i1 + %12 = comb.extract %b from 11 : (i17) -> i1 + %13 = comb.extract %b from 12 : (i17) -> i1 + %14 = comb.extract %b from 13 : (i17) -> i1 + %15 = comb.extract %b from 14 : (i17) -> i1 + %16 = comb.extract %b from 15 : (i17) -> i1 + %17 = comb.extract %b from 16 : (i17) -> i1 + %18 = comb.xor bin %1, %false : i1 + %true = hw.constant true + %19 = comb.xor bin %1, %true : i1 + %20 = comb.xor bin %2, %true : i1 + %21 = comb.xor bin %false, %true : i1 + %22 = comb.and bin %20, %1, %false : i1 + %23 = comb.and bin %2, %19, %21 : i1 + %24 = comb.or bin %22, %23 : i1 + %25 = comb.replicate %2 : (i1) -> i17 + %26 = comb.replicate %18 : (i1) -> i17 + %27 = comb.replicate %24 : (i1) -> i17 + %28 = comb.and %27, %0 : i17 + %29 = comb.and %26, %a : i17 + %30 = comb.or bin %28, %29 : i17 + %31 = comb.xor bin %30, %25 : i17 + %32 = comb.xor bin %3, %2 : i1 + %true_0 = hw.constant true + %33 = comb.xor bin %3, %true_0 : i1 + %34 = comb.xor bin %4, %true_0 : i1 + %35 = comb.xor bin %2, %true_0 : i1 + %36 = comb.and bin %34, %3, %2 : i1 + %37 = comb.and bin %4, %33, %35 : i1 + %38 = comb.or bin %36, %37 : i1 + %39 = comb.replicate %4 : (i1) -> i17 + %40 = comb.replicate %32 : (i1) -> i17 + %41 = comb.replicate %38 : (i1) -> i17 + %42 = comb.and %41, %0 : i17 + %43 = comb.and %40, %a : i17 + %44 = comb.or bin %42, %43 : i17 + %45 = comb.xor bin %44, %39 : i17 + %46 = comb.concat %45, %false, %2 : i17, i1, i1 + %47 = comb.extract %46 from 0 : (i19) -> i17 + %c0_i17_1 = hw.constant 0 : i17 + %48 = comb.shl %47, %c0_i17_1 : i17 + %49 = comb.xor bin %5, %4 : i1 + %true_2 = hw.constant true + %50 = comb.xor bin %5, %true_2 : i1 + %51 = comb.xor bin %6, %true_2 : i1 + %52 = comb.xor bin %4, %true_2 : i1 + %53 = comb.and bin %51, %5, %4 : i1 + %54 = comb.and bin %6, %50, %52 : i1 + %55 = comb.or bin %53, %54 : i1 + %56 = comb.replicate %6 : (i1) -> i17 + %57 = comb.replicate %49 : (i1) -> i17 + %58 = comb.replicate %55 : (i1) -> i17 + %59 = comb.and %58, %0 : i17 + %60 = comb.and %57, %a : i17 + %61 = comb.or bin %59, %60 : i17 + %62 = comb.xor bin %61, %56 : i17 + %63 = comb.concat %62, %false, %4 : i17, i1, i1 + %64 = comb.extract %63 from 0 : (i19) -> i17 + %c2_i17 = hw.constant 2 : i17 + %65 = comb.shl %64, %c2_i17 : i17 + %66 = comb.xor bin %7, %6 : i1 + %true_3 = hw.constant true + %67 = comb.xor bin %7, %true_3 : i1 + %68 = comb.xor bin %8, %true_3 : i1 + %69 = comb.xor bin %6, %true_3 : i1 + %70 = comb.and bin %68, %7, %6 : i1 + %71 = comb.and bin %8, %67, %69 : i1 + %72 = comb.or bin %70, %71 : i1 + %73 = comb.replicate %8 : (i1) -> i17 + %74 = comb.replicate %66 : (i1) -> i17 + %75 = comb.replicate %72 : (i1) -> i17 + %76 = comb.and %75, %0 : i17 + %77 = comb.and %74, %a : i17 + %78 = comb.or bin %76, %77 : i17 + %79 = comb.xor bin %78, %73 : i17 + %80 = comb.concat %79, %false, %6 : i17, i1, i1 + %81 = comb.extract %80 from 0 : (i19) -> i17 + %c4_i17 = hw.constant 4 : i17 + %82 = comb.shl %81, %c4_i17 : i17 + %83 = comb.xor bin %9, %8 : i1 + %true_4 = hw.constant true + %84 = comb.xor bin %9, %true_4 : i1 + %85 = comb.xor bin %10, %true_4 : i1 + %86 = comb.xor bin %8, %true_4 : i1 + %87 = comb.and bin %85, %9, %8 : i1 + %88 = comb.and bin %10, %84, %86 : i1 + %89 = comb.or bin %87, %88 : i1 + %90 = comb.replicate %10 : (i1) -> i17 + %91 = comb.replicate %83 : (i1) -> i17 + %92 = comb.replicate %89 : (i1) -> i17 + %93 = comb.and %92, %0 : i17 + %94 = comb.and %91, %a : i17 + %95 = comb.or bin %93, %94 : i17 + %96 = comb.xor bin %95, %90 : i17 + %97 = comb.concat %96, %false, %8 : i17, i1, i1 + %98 = comb.extract %97 from 0 : (i19) -> i17 + %c6_i17 = hw.constant 6 : i17 + %99 = comb.shl %98, %c6_i17 : i17 + %100 = comb.xor bin %11, %10 : i1 + %true_5 = hw.constant true + %101 = comb.xor bin %11, %true_5 : i1 + %102 = comb.xor bin %12, %true_5 : i1 + %103 = comb.xor bin %10, %true_5 : i1 + %104 = comb.and bin %102, %11, %10 : i1 + %105 = comb.and bin %12, %101, %103 : i1 + %106 = comb.or bin %104, %105 : i1 + %107 = comb.replicate %12 : (i1) -> i17 + %108 = comb.replicate %100 : (i1) -> i17 + %109 = comb.replicate %106 : (i1) -> i17 + %110 = comb.and %109, %0 : i17 + %111 = comb.and %108, %a : i17 + %112 = comb.or bin %110, %111 : i17 + %113 = comb.xor bin %112, %107 : i17 + %114 = comb.concat %113, %false, %10 : i17, i1, i1 + %115 = comb.extract %114 from 0 : (i19) -> i17 + %c8_i17 = hw.constant 8 : i17 + %116 = comb.shl %115, %c8_i17 : i17 + %117 = comb.xor bin %13, %12 : i1 + %true_6 = hw.constant true + %118 = comb.xor bin %13, %true_6 : i1 + %119 = comb.xor bin %14, %true_6 : i1 + %120 = comb.xor bin %12, %true_6 : i1 + %121 = comb.and bin %119, %13, %12 : i1 + %122 = comb.and bin %14, %118, %120 : i1 + %123 = comb.or bin %121, %122 : i1 + %124 = comb.replicate %14 : (i1) -> i17 + %125 = comb.replicate %117 : (i1) -> i17 + %126 = comb.replicate %123 : (i1) -> i17 + %127 = comb.and %126, %0 : i17 + %128 = comb.and %125, %a : i17 + %129 = comb.or bin %127, %128 : i17 + %130 = comb.xor bin %129, %124 : i17 + %131 = comb.concat %130, %false, %12 : i17, i1, i1 + %132 = comb.extract %131 from 0 : (i19) -> i17 + %c10_i17 = hw.constant 10 : i17 + %133 = comb.shl %132, %c10_i17 : i17 + %134 = comb.xor bin %15, %14 : i1 + %true_7 = hw.constant true + %135 = comb.xor bin %15, %true_7 : i1 + %136 = comb.xor bin %16, %true_7 : i1 + %137 = comb.xor bin %14, %true_7 : i1 + %138 = comb.and bin %136, %15, %14 : i1 + %139 = comb.and bin %16, %135, %137 : i1 + %140 = comb.or bin %138, %139 : i1 + %141 = comb.replicate %16 : (i1) -> i17 + %142 = comb.replicate %134 : (i1) -> i17 + %143 = comb.replicate %140 : (i1) -> i17 + %144 = comb.and %143, %0 : i17 + %145 = comb.and %142, %a : i17 + %146 = comb.or bin %144, %145 : i17 + %147 = comb.xor bin %146, %141 : i17 + %148 = comb.concat %147, %false, %14 : i17, i1, i1 + %149 = comb.extract %148 from 0 : (i19) -> i17 + %c12_i17 = hw.constant 12 : i17 + %150 = comb.shl %149, %c12_i17 : i17 + %151 = comb.xor bin %17, %16 : i1 + %true_8 = hw.constant true + %152 = comb.xor bin %17, %true_8 : i1 + %153 = comb.xor bin %false, %true_8 : i1 + %154 = comb.xor bin %16, %true_8 : i1 + %155 = comb.and bin %153, %17, %16 : i1 + %156 = comb.and bin %false, %152, %154 : i1 + %157 = comb.or bin %155, %156 : i1 + %158 = comb.replicate %false : (i1) -> i17 + %159 = comb.replicate %151 : (i1) -> i17 + %160 = comb.replicate %157 : (i1) -> i17 + %161 = comb.and %160, %0 : i17 + %162 = comb.and %159, %a : i17 + %163 = comb.or bin %161, %162 : i17 + %164 = comb.xor bin %163, %158 : i17 + %165 = comb.concat %164, %false, %16 : i17, i1, i1 + %166 = comb.extract %165 from 0 : (i19) -> i17 + %c14_i17 = hw.constant 14 : i17 + %167 = comb.shl %166, %c14_i17 : i17 + %168 = comb.add bin %31, %48, %65, %82, %99, %116, %133, %150, %167, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17 : i17 + hw.output %168 : i17 + } +} + diff --git a/lib/Conversion/CombToAIG/CombToAIG.cpp b/lib/Conversion/CombToAIG/CombToAIG.cpp index c271f82d38a8..1c59d8df57e9 100644 --- a/lib/Conversion/CombToAIG/CombToAIG.cpp +++ b/lib/Conversion/CombToAIG/CombToAIG.cpp @@ -715,8 +715,8 @@ struct CombMulOpConversion : OpConversionPattern { } // Wallace tree reduction - reduce to two addends. - auto addends = comb::wallaceReduction(falseValue, width, 2, rewriter, loc, - partialProducts); + auto addends = + comb::wallaceReduction(rewriter, loc, width, 2, partialProducts); // Sum the two addends using a carry-propagate adder auto newAdd = rewriter.create(loc, addends, true); replaceOpAndCopyNamehint(rewriter, op, newAdd); diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 6556f6503963..8a633f4596b9 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -79,11 +79,9 @@ struct DatapathCompressOpConversion : OpConversionPattern { // TODO - implement a more efficient compression algorithm to compete with // yosys's `alumacc` lowering - a coarse grained timing model would help to // sort the inputs according to arrival time. - auto falseValue = rewriter.create(loc, APInt(1, 0)); auto targetAddends = op.getNumResults(); - rewriter.replaceOp(op, - comb::wallaceReduction(falseValue, width, targetAddends, - rewriter, loc, addends)); + rewriter.replaceOp(op, comb::wallaceReduction(rewriter, loc, width, + targetAddends, addends)); return success(); } }; @@ -95,9 +93,8 @@ struct DatapathPartialProductOpConversion matchAndRewrite(PartialProductOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto inputs = op.getOperands(); - Value a = inputs[0]; - Value b = inputs[1]; + Value a = op.getLhs(); + Value b = op.getRhs(); unsigned width = a.getType().getIntOrFloatBitWidth(); // Skip a zero width value. @@ -164,7 +161,7 @@ struct DatapathPartialProductOpConversion for (unsigned i = 0; i < width; i += 2) { // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0) Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1]; - Value bi = (i < width) ? bBits[i] : zeroFalse; + Value bi = bBits[i]; Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse; // Is the encoding zero or negative (an approximation) diff --git a/lib/Dialect/Comb/CombOps.cpp b/lib/Dialect/Comb/CombOps.cpp index 6538a8897346..c68888bbbbb4 100644 --- a/lib/Dialect/Comb/CombOps.cpp +++ b/lib/Dialect/Comb/CombOps.cpp @@ -217,16 +217,15 @@ Value comb::createInject(OpBuilder &builder, Location loc, Value value, } // Construct a full adder for three 1-bit inputs. -std::pair comb::fullAdder(ConversionPatternRewriter &rewriter, - Location loc, Value a, Value b, - Value c) { - auto aXorB = rewriter.createOrFold(loc, a, b, true); - Value sum = rewriter.createOrFold(loc, aXorB, c, true); +std::pair comb::fullAdder(OpBuilder &builder, Location loc, + Value a, Value b, Value c) { + auto aXorB = builder.createOrFold(loc, a, b, true); + Value sum = builder.createOrFold(loc, aXorB, c, true); - auto carry = rewriter.createOrFold( + auto carry = builder.createOrFold( loc, - ArrayRef{rewriter.createOrFold(loc, a, b, true), - rewriter.createOrFold(loc, aXorB, c, true)}, + ArrayRef{builder.createOrFold(loc, a, b, true), + builder.createOrFold(loc, aXorB, c, true)}, true); return {sum, carry}; @@ -235,9 +234,10 @@ std::pair comb::fullAdder(ConversionPatternRewriter &rewriter, // Perform Wallace tree reduction on partial products. // See https://en.wikipedia.org/wiki/Wallace_tree SmallVector -comb::wallaceReduction(Value falseValue, size_t width, size_t targetAddends, - ConversionPatternRewriter &rewriter, Location loc, +comb::wallaceReduction(OpBuilder &builder, Location loc, size_t width, + size_t targetAddends, SmallVector> &addends) { + auto falseValue = builder.create(loc, APInt(1, 0)); SmallVector> newAddends; newAddends.reserve(addends.size()); // Continue reduction until we have only two rows. The length of @@ -264,7 +264,7 @@ comb::wallaceReduction(Value falseValue, size_t width, size_t targetAddends, for (unsigned j = 0; j < width; ++j) { // Full adder logic auto [sum, carry] = - comb::fullAdder(rewriter, loc, row1[j], row2[j], row3[j]); + comb::fullAdder(builder, loc, row1[j], row2[j], row3[j]); sumRow.push_back(sum); if (j + 1 < width) carryRow.push_back(carry); @@ -282,14 +282,14 @@ comb::wallaceReduction(Value falseValue, size_t width, size_t targetAddends, assert(addends.size() <= targetAddends); SmallVector carrySave; - for (auto addend : addends) { + for (auto &addend : addends) { // Reverse the order of the bits std::reverse(addend.begin(), addend.end()); - carrySave.push_back(rewriter.create(loc, addend)); + carrySave.push_back(builder.create(loc, addend)); } // Pad with zeros - auto zero = rewriter.create(loc, APInt(width, 0)); + auto zero = builder.create(loc, APInt(width, 0)); while (carrySave.size() < targetAddends) carrySave.push_back(zero); diff --git a/test/Conversion/DatapathToComb/datapath-to-comb.mlir b/test/Conversion/DatapathToComb/datapath-to-comb.mlir index 8ace72a66924..659cd67939d3 100644 --- a/test/Conversion/DatapathToComb/datapath-to-comb.mlir +++ b/test/Conversion/DatapathToComb/datapath-to-comb.mlir @@ -1,4 +1,5 @@ // RUN: circt-opt %s --convert-datapath-to-comb | FileCheck %s +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-compress-to-add=true}))" | FileCheck %s --check-prefix=ALLOW_ADD // CHECK-LABEL: @compressor hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) { @@ -25,6 +26,16 @@ hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out sa hw.output %0#0, %0#1 : i2, i2 } +// CHECK-LABEL: @compressor_add +// TO-ADD-LABEL: @compressor_add +// TO-ADD-NEXT: %[[ADD:.+]] = comb.add bin %a, %b, %c : i2 +// TO-ADD-NEXT: %c0_i2 = hw.constant 0 : i2 +// TO-ADD-NEXT: hw.output %c0_i2, %[[ADD]] : i2, i2 +hw.module @compressor_add(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) { + %0:2 = datapath.compress %a, %b, %c : i2 [3 -> 2] + hw.output %0#0, %0#1 : i2, i2 +} + // CHECK-LABEL: @partial_product hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) { // CHECK-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1 @@ -45,3 +56,17 @@ hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, o %0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3) hw.output %0#0, %0#1, %0#2 : i3, i3, i3 } + +// CHECK-LABEL: @partial_product_24 +hw.module @partial_product_24(in %a : i24, in %b : i24, out sum : i24) { + %0:24 = datapath.partial_product %a, %b : (i24, i24) -> (i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23 : i24 + hw.output %1 : i24 +} + +// CHECK-LABEL: @partial_product_25 +hw.module @partial_product_25(in %a : i25, in %b : i25, out sum : i25) { + %0:25 = datapath.partial_product %a, %b : (i25, i25) -> (i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25, i25) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16, %0#17, %0#18, %0#19, %0#20, %0#21, %0#22, %0#23 : i25 + hw.output %1 : i25 +} From a1b34a8a2de8da0cab90d15111cf9049ddd6609b Mon Sep 17 00:00:00 2001 From: cowardsa Date: Wed, 23 Jul 2025 11:36:04 +0100 Subject: [PATCH 08/12] Fix bug in Booth code for final sign correction row and add testing using lec. Add a forceBooth option largely for testing purposes --- include/circt/Conversion/DatapathToComb.h | 3 +- include/circt/Conversion/Passes.td | 4 +- .../circt-synth/datapath-lowering-lec.mlir | 57 +++- integration_test/circt-synth/test.mlir | 313 ------------------ .../DatapathToComb/DatapathToComb.cpp | 30 +- .../CombToAIG/comb-to-aig-arith.mlir | 2 +- .../DatapathToComb/datapath-to-comb.mlir | 36 +- 7 files changed, 100 insertions(+), 345 deletions(-) delete mode 100644 integration_test/circt-synth/test.mlir diff --git a/include/circt/Conversion/DatapathToComb.h b/include/circt/Conversion/DatapathToComb.h index 7368041cb39e..d3b95b2d07bd 100644 --- a/include/circt/Conversion/DatapathToComb.h +++ b/include/circt/Conversion/DatapathToComb.h @@ -15,7 +15,8 @@ namespace circt { void populateDatapathToCombConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, - bool lowerCompressToAdd); + bool lowerCompressToAdd, + bool forceBooth); #define GEN_PASS_DECL_CONVERTDATAPATHTOCOMB #include "circt/Conversion/Passes.h.inc" diff --git a/include/circt/Conversion/Passes.td b/include/circt/Conversion/Passes.td index 9c78f41fcb01..178a66aaecef 100644 --- a/include/circt/Conversion/Passes.td +++ b/include/circt/Conversion/Passes.td @@ -883,7 +883,9 @@ def ConvertDatapathToComb : Pass<"convert-datapath-to-comb"> { ]; let options = [ Option<"lowerCompressToAdd", "lower-compress-to-add", "bool", "false", - "Lower compress operators to variadic add."> + "Lower compress operators to variadic add.">, + Option<"forceBooth", "lower-partial-product-to-booth", "bool", "false", + "Force all partial products to be lowered to Booth arrays."> ]; } diff --git a/integration_test/circt-synth/datapath-lowering-lec.mlir b/integration_test/circt-synth/datapath-lowering-lec.mlir index bedb83ebf162..1cfd20d6d3bd 100644 --- a/integration_test/circt-synth/datapath-lowering-lec.mlir +++ b/integration_test/circt-synth/datapath-lowering-lec.mlir @@ -2,26 +2,47 @@ // REQUIRES: circt-lec-jit // RUN: circt-opt %s --convert-datapath-to-comb -o %t.mlir -// RUN: circt-lec %t.mlir %s -c1=partial_product_6 -c2=partial_product_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_6 -// PP_6: c1 == c2 -hw.module @partial_product_6(in %a : i6, in %b : i6, out sum : i6) { - %0:6 = datapath.partial_product %a, %b : (i6, i6) -> (i6, i6, i6, i6, i6, i6) - %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5 : i6 - hw.output %1 : i6 +// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND5 +// AND5: c1 == c2 +hw.module @partial_product_5(in %a : i5, in %b : i5, out sum : i5) { + %0:5 = datapath.partial_product %a, %b : (i5, i5) -> (i5, i5, i5, i5, i5) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4 : i5 + hw.output %1 : i5 } -// RUN: circt-lec %t.mlir %s -c1=partial_product_16 -c2=partial_product_16 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_16 -// PP_16: c1 == c2 -hw.module @partial_product_16(in %a : i16, in %b : i16, out sum : i16) { - %0:16 = datapath.partial_product %a, %b : (i16, i16) -> (i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16, i16) - %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15 : i16 - hw.output %1 : i16 +// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=AND4 +// AND4: c1 == c2 +hw.module @partial_product_4(in %a : i4, in %b : i4, out sum : i4) { + %0:4 = datapath.partial_product %a, %b : (i4, i4) -> (i4, i4, i4, i4) + %1 = comb.add bin %0#0, %0#1, %0#2, %0#3 : i4 + hw.output %1 : i4 } -// RUN: circt-lec %t.mlir %s -c1=partial_product_17 -c2=partial_product_17 --shared-libs=%libz3 | FileCheck %s --check-prefix=PP_17 -// PP_17: c1 == c2 -hw.module @partial_product_17(in %a : i17, in %b : i17, out sum : i17) { - %0:17 = datapath.partial_product %a, %b : (i17, i17) -> (i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17, i17) - %1 = comb.add bin %0#0, %0#1, %0#2, %0#3, %0#4, %0#5, %0#6, %0#7, %0#8, %0#9, %0#10, %0#11, %0#12, %0#13, %0#14, %0#15, %0#16 : i17 - hw.output %1 : i17 +// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP3 +// COMP3: c1 == c2 +hw.module @compress_3(in %a : i4, in %b : i4, in %c : i4, out sum : i4) { + %0:2 = datapath.compress %a, %b, %c : i4 [3 -> 2] + %1 = comb.add bin %0#0, %0#1 : i4 + hw.output %1 : i4 } + +// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMP6 +// COMP6: c1 == c2 +hw.module @compress_6(in %a : i4, in %b : i4, in %c : i4, in %d : i4, in %e : i4, in %f : i4, out sum : i4) { + %0:3 = datapath.compress %a, %b, %c, %d, %e, %f : i4 [6 -> 3] + %1 = comb.add bin %0#0, %0#1, %0#2 : i4 + hw.output %1 : i4 +} + +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true lower-compress-to-add=true}))" -o %t.mlir +// RUN: circt-lec %t.mlir %s -c1=partial_product_5 -c2=partial_product_5 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH5 +// BOOTH5: c1 == c2 + +// RUN: circt-lec %t.mlir %s -c1=partial_product_4 -c2=partial_product_4 --shared-libs=%libz3 | FileCheck %s --check-prefix=BOOTH4 +// BOOTH4: c1 == c2 + +// RUN: circt-lec %t.mlir %s -c1=compress_3 -c2=compress_3 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD3 +// COMPADD3: c1 == c2 + +// RUN: circt-lec %t.mlir %s -c1=compress_6 -c2=compress_6 --shared-libs=%libz3 | FileCheck %s --check-prefix=COMPADD6 +// COMPADD6: c1 == c2 diff --git a/integration_test/circt-synth/test.mlir b/integration_test/circt-synth/test.mlir deleted file mode 100644 index 24b8d14ee549..000000000000 --- a/integration_test/circt-synth/test.mlir +++ /dev/null @@ -1,313 +0,0 @@ -module { - hw.module @partial_product_6(in %a : i6, in %b : i6, out sum : i6) { - %0 = comb.extract %b from 0 : (i6) -> i1 - %1 = comb.extract %b from 1 : (i6) -> i1 - %2 = comb.extract %b from 2 : (i6) -> i1 - %3 = comb.extract %b from 3 : (i6) -> i1 - %4 = comb.extract %b from 4 : (i6) -> i1 - %5 = comb.extract %b from 5 : (i6) -> i1 - %6 = comb.replicate %0 : (i1) -> i6 - %7 = comb.and %6, %a : i6 - %c0_i6 = hw.constant 0 : i6 - %8 = comb.shl %7, %c0_i6 : i6 - %9 = comb.replicate %1 : (i1) -> i6 - %10 = comb.and %9, %a : i6 - %c1_i6 = hw.constant 1 : i6 - %11 = comb.shl %10, %c1_i6 : i6 - %12 = comb.replicate %2 : (i1) -> i6 - %13 = comb.and %12, %a : i6 - %c2_i6 = hw.constant 2 : i6 - %14 = comb.shl %13, %c2_i6 : i6 - %15 = comb.replicate %3 : (i1) -> i6 - %16 = comb.and %15, %a : i6 - %c3_i6 = hw.constant 3 : i6 - %17 = comb.shl %16, %c3_i6 : i6 - %18 = comb.replicate %4 : (i1) -> i6 - %19 = comb.and %18, %a : i6 - %c4_i6 = hw.constant 4 : i6 - %20 = comb.shl %19, %c4_i6 : i6 - %21 = comb.replicate %5 : (i1) -> i6 - %22 = comb.and %21, %a : i6 - %c5_i6 = hw.constant 5 : i6 - %23 = comb.shl %22, %c5_i6 : i6 - %24 = comb.add bin %8, %11, %14, %17, %20, %23 : i6 - hw.output %24 : i6 - } - hw.module @partial_product_16(in %a : i16, in %b : i16, out sum : i16) { - %0 = comb.extract %b from 0 : (i16) -> i1 - %1 = comb.extract %b from 1 : (i16) -> i1 - %2 = comb.extract %b from 2 : (i16) -> i1 - %3 = comb.extract %b from 3 : (i16) -> i1 - %4 = comb.extract %b from 4 : (i16) -> i1 - %5 = comb.extract %b from 5 : (i16) -> i1 - %6 = comb.extract %b from 6 : (i16) -> i1 - %7 = comb.extract %b from 7 : (i16) -> i1 - %8 = comb.extract %b from 8 : (i16) -> i1 - %9 = comb.extract %b from 9 : (i16) -> i1 - %10 = comb.extract %b from 10 : (i16) -> i1 - %11 = comb.extract %b from 11 : (i16) -> i1 - %12 = comb.extract %b from 12 : (i16) -> i1 - %13 = comb.extract %b from 13 : (i16) -> i1 - %14 = comb.extract %b from 14 : (i16) -> i1 - %15 = comb.extract %b from 15 : (i16) -> i1 - %16 = comb.replicate %0 : (i1) -> i16 - %17 = comb.and %16, %a : i16 - %c0_i16 = hw.constant 0 : i16 - %18 = comb.shl %17, %c0_i16 : i16 - %19 = comb.replicate %1 : (i1) -> i16 - %20 = comb.and %19, %a : i16 - %c1_i16 = hw.constant 1 : i16 - %21 = comb.shl %20, %c1_i16 : i16 - %22 = comb.replicate %2 : (i1) -> i16 - %23 = comb.and %22, %a : i16 - %c2_i16 = hw.constant 2 : i16 - %24 = comb.shl %23, %c2_i16 : i16 - %25 = comb.replicate %3 : (i1) -> i16 - %26 = comb.and %25, %a : i16 - %c3_i16 = hw.constant 3 : i16 - %27 = comb.shl %26, %c3_i16 : i16 - %28 = comb.replicate %4 : (i1) -> i16 - %29 = comb.and %28, %a : i16 - %c4_i16 = hw.constant 4 : i16 - %30 = comb.shl %29, %c4_i16 : i16 - %31 = comb.replicate %5 : (i1) -> i16 - %32 = comb.and %31, %a : i16 - %c5_i16 = hw.constant 5 : i16 - %33 = comb.shl %32, %c5_i16 : i16 - %34 = comb.replicate %6 : (i1) -> i16 - %35 = comb.and %34, %a : i16 - %c6_i16 = hw.constant 6 : i16 - %36 = comb.shl %35, %c6_i16 : i16 - %37 = comb.replicate %7 : (i1) -> i16 - %38 = comb.and %37, %a : i16 - %c7_i16 = hw.constant 7 : i16 - %39 = comb.shl %38, %c7_i16 : i16 - %40 = comb.replicate %8 : (i1) -> i16 - %41 = comb.and %40, %a : i16 - %c8_i16 = hw.constant 8 : i16 - %42 = comb.shl %41, %c8_i16 : i16 - %43 = comb.replicate %9 : (i1) -> i16 - %44 = comb.and %43, %a : i16 - %c9_i16 = hw.constant 9 : i16 - %45 = comb.shl %44, %c9_i16 : i16 - %46 = comb.replicate %10 : (i1) -> i16 - %47 = comb.and %46, %a : i16 - %c10_i16 = hw.constant 10 : i16 - %48 = comb.shl %47, %c10_i16 : i16 - %49 = comb.replicate %11 : (i1) -> i16 - %50 = comb.and %49, %a : i16 - %c11_i16 = hw.constant 11 : i16 - %51 = comb.shl %50, %c11_i16 : i16 - %52 = comb.replicate %12 : (i1) -> i16 - %53 = comb.and %52, %a : i16 - %c12_i16 = hw.constant 12 : i16 - %54 = comb.shl %53, %c12_i16 : i16 - %55 = comb.replicate %13 : (i1) -> i16 - %56 = comb.and %55, %a : i16 - %c13_i16 = hw.constant 13 : i16 - %57 = comb.shl %56, %c13_i16 : i16 - %58 = comb.replicate %14 : (i1) -> i16 - %59 = comb.and %58, %a : i16 - %c14_i16 = hw.constant 14 : i16 - %60 = comb.shl %59, %c14_i16 : i16 - %61 = comb.replicate %15 : (i1) -> i16 - %62 = comb.and %61, %a : i16 - %c15_i16 = hw.constant 15 : i16 - %63 = comb.shl %62, %c15_i16 : i16 - %64 = comb.add bin %18, %21, %24, %27, %30, %33, %36, %39, %42, %45, %48, %51, %54, %57, %60, %63 : i16 - hw.output %64 : i16 - } - hw.module @partial_product_17(in %a : i17, in %b : i17, out sum : i17) { - %false = hw.constant false - %c0_i17 = hw.constant 0 : i17 - %c1_i17 = hw.constant 1 : i17 - %0 = comb.shl %a, %c1_i17 : i17 - %1 = comb.extract %b from 0 : (i17) -> i1 - %2 = comb.extract %b from 1 : (i17) -> i1 - %3 = comb.extract %b from 2 : (i17) -> i1 - %4 = comb.extract %b from 3 : (i17) -> i1 - %5 = comb.extract %b from 4 : (i17) -> i1 - %6 = comb.extract %b from 5 : (i17) -> i1 - %7 = comb.extract %b from 6 : (i17) -> i1 - %8 = comb.extract %b from 7 : (i17) -> i1 - %9 = comb.extract %b from 8 : (i17) -> i1 - %10 = comb.extract %b from 9 : (i17) -> i1 - %11 = comb.extract %b from 10 : (i17) -> i1 - %12 = comb.extract %b from 11 : (i17) -> i1 - %13 = comb.extract %b from 12 : (i17) -> i1 - %14 = comb.extract %b from 13 : (i17) -> i1 - %15 = comb.extract %b from 14 : (i17) -> i1 - %16 = comb.extract %b from 15 : (i17) -> i1 - %17 = comb.extract %b from 16 : (i17) -> i1 - %18 = comb.xor bin %1, %false : i1 - %true = hw.constant true - %19 = comb.xor bin %1, %true : i1 - %20 = comb.xor bin %2, %true : i1 - %21 = comb.xor bin %false, %true : i1 - %22 = comb.and bin %20, %1, %false : i1 - %23 = comb.and bin %2, %19, %21 : i1 - %24 = comb.or bin %22, %23 : i1 - %25 = comb.replicate %2 : (i1) -> i17 - %26 = comb.replicate %18 : (i1) -> i17 - %27 = comb.replicate %24 : (i1) -> i17 - %28 = comb.and %27, %0 : i17 - %29 = comb.and %26, %a : i17 - %30 = comb.or bin %28, %29 : i17 - %31 = comb.xor bin %30, %25 : i17 - %32 = comb.xor bin %3, %2 : i1 - %true_0 = hw.constant true - %33 = comb.xor bin %3, %true_0 : i1 - %34 = comb.xor bin %4, %true_0 : i1 - %35 = comb.xor bin %2, %true_0 : i1 - %36 = comb.and bin %34, %3, %2 : i1 - %37 = comb.and bin %4, %33, %35 : i1 - %38 = comb.or bin %36, %37 : i1 - %39 = comb.replicate %4 : (i1) -> i17 - %40 = comb.replicate %32 : (i1) -> i17 - %41 = comb.replicate %38 : (i1) -> i17 - %42 = comb.and %41, %0 : i17 - %43 = comb.and %40, %a : i17 - %44 = comb.or bin %42, %43 : i17 - %45 = comb.xor bin %44, %39 : i17 - %46 = comb.concat %45, %false, %2 : i17, i1, i1 - %47 = comb.extract %46 from 0 : (i19) -> i17 - %c0_i17_1 = hw.constant 0 : i17 - %48 = comb.shl %47, %c0_i17_1 : i17 - %49 = comb.xor bin %5, %4 : i1 - %true_2 = hw.constant true - %50 = comb.xor bin %5, %true_2 : i1 - %51 = comb.xor bin %6, %true_2 : i1 - %52 = comb.xor bin %4, %true_2 : i1 - %53 = comb.and bin %51, %5, %4 : i1 - %54 = comb.and bin %6, %50, %52 : i1 - %55 = comb.or bin %53, %54 : i1 - %56 = comb.replicate %6 : (i1) -> i17 - %57 = comb.replicate %49 : (i1) -> i17 - %58 = comb.replicate %55 : (i1) -> i17 - %59 = comb.and %58, %0 : i17 - %60 = comb.and %57, %a : i17 - %61 = comb.or bin %59, %60 : i17 - %62 = comb.xor bin %61, %56 : i17 - %63 = comb.concat %62, %false, %4 : i17, i1, i1 - %64 = comb.extract %63 from 0 : (i19) -> i17 - %c2_i17 = hw.constant 2 : i17 - %65 = comb.shl %64, %c2_i17 : i17 - %66 = comb.xor bin %7, %6 : i1 - %true_3 = hw.constant true - %67 = comb.xor bin %7, %true_3 : i1 - %68 = comb.xor bin %8, %true_3 : i1 - %69 = comb.xor bin %6, %true_3 : i1 - %70 = comb.and bin %68, %7, %6 : i1 - %71 = comb.and bin %8, %67, %69 : i1 - %72 = comb.or bin %70, %71 : i1 - %73 = comb.replicate %8 : (i1) -> i17 - %74 = comb.replicate %66 : (i1) -> i17 - %75 = comb.replicate %72 : (i1) -> i17 - %76 = comb.and %75, %0 : i17 - %77 = comb.and %74, %a : i17 - %78 = comb.or bin %76, %77 : i17 - %79 = comb.xor bin %78, %73 : i17 - %80 = comb.concat %79, %false, %6 : i17, i1, i1 - %81 = comb.extract %80 from 0 : (i19) -> i17 - %c4_i17 = hw.constant 4 : i17 - %82 = comb.shl %81, %c4_i17 : i17 - %83 = comb.xor bin %9, %8 : i1 - %true_4 = hw.constant true - %84 = comb.xor bin %9, %true_4 : i1 - %85 = comb.xor bin %10, %true_4 : i1 - %86 = comb.xor bin %8, %true_4 : i1 - %87 = comb.and bin %85, %9, %8 : i1 - %88 = comb.and bin %10, %84, %86 : i1 - %89 = comb.or bin %87, %88 : i1 - %90 = comb.replicate %10 : (i1) -> i17 - %91 = comb.replicate %83 : (i1) -> i17 - %92 = comb.replicate %89 : (i1) -> i17 - %93 = comb.and %92, %0 : i17 - %94 = comb.and %91, %a : i17 - %95 = comb.or bin %93, %94 : i17 - %96 = comb.xor bin %95, %90 : i17 - %97 = comb.concat %96, %false, %8 : i17, i1, i1 - %98 = comb.extract %97 from 0 : (i19) -> i17 - %c6_i17 = hw.constant 6 : i17 - %99 = comb.shl %98, %c6_i17 : i17 - %100 = comb.xor bin %11, %10 : i1 - %true_5 = hw.constant true - %101 = comb.xor bin %11, %true_5 : i1 - %102 = comb.xor bin %12, %true_5 : i1 - %103 = comb.xor bin %10, %true_5 : i1 - %104 = comb.and bin %102, %11, %10 : i1 - %105 = comb.and bin %12, %101, %103 : i1 - %106 = comb.or bin %104, %105 : i1 - %107 = comb.replicate %12 : (i1) -> i17 - %108 = comb.replicate %100 : (i1) -> i17 - %109 = comb.replicate %106 : (i1) -> i17 - %110 = comb.and %109, %0 : i17 - %111 = comb.and %108, %a : i17 - %112 = comb.or bin %110, %111 : i17 - %113 = comb.xor bin %112, %107 : i17 - %114 = comb.concat %113, %false, %10 : i17, i1, i1 - %115 = comb.extract %114 from 0 : (i19) -> i17 - %c8_i17 = hw.constant 8 : i17 - %116 = comb.shl %115, %c8_i17 : i17 - %117 = comb.xor bin %13, %12 : i1 - %true_6 = hw.constant true - %118 = comb.xor bin %13, %true_6 : i1 - %119 = comb.xor bin %14, %true_6 : i1 - %120 = comb.xor bin %12, %true_6 : i1 - %121 = comb.and bin %119, %13, %12 : i1 - %122 = comb.and bin %14, %118, %120 : i1 - %123 = comb.or bin %121, %122 : i1 - %124 = comb.replicate %14 : (i1) -> i17 - %125 = comb.replicate %117 : (i1) -> i17 - %126 = comb.replicate %123 : (i1) -> i17 - %127 = comb.and %126, %0 : i17 - %128 = comb.and %125, %a : i17 - %129 = comb.or bin %127, %128 : i17 - %130 = comb.xor bin %129, %124 : i17 - %131 = comb.concat %130, %false, %12 : i17, i1, i1 - %132 = comb.extract %131 from 0 : (i19) -> i17 - %c10_i17 = hw.constant 10 : i17 - %133 = comb.shl %132, %c10_i17 : i17 - %134 = comb.xor bin %15, %14 : i1 - %true_7 = hw.constant true - %135 = comb.xor bin %15, %true_7 : i1 - %136 = comb.xor bin %16, %true_7 : i1 - %137 = comb.xor bin %14, %true_7 : i1 - %138 = comb.and bin %136, %15, %14 : i1 - %139 = comb.and bin %16, %135, %137 : i1 - %140 = comb.or bin %138, %139 : i1 - %141 = comb.replicate %16 : (i1) -> i17 - %142 = comb.replicate %134 : (i1) -> i17 - %143 = comb.replicate %140 : (i1) -> i17 - %144 = comb.and %143, %0 : i17 - %145 = comb.and %142, %a : i17 - %146 = comb.or bin %144, %145 : i17 - %147 = comb.xor bin %146, %141 : i17 - %148 = comb.concat %147, %false, %14 : i17, i1, i1 - %149 = comb.extract %148 from 0 : (i19) -> i17 - %c12_i17 = hw.constant 12 : i17 - %150 = comb.shl %149, %c12_i17 : i17 - %151 = comb.xor bin %17, %16 : i1 - %true_8 = hw.constant true - %152 = comb.xor bin %17, %true_8 : i1 - %153 = comb.xor bin %false, %true_8 : i1 - %154 = comb.xor bin %16, %true_8 : i1 - %155 = comb.and bin %153, %17, %16 : i1 - %156 = comb.and bin %false, %152, %154 : i1 - %157 = comb.or bin %155, %156 : i1 - %158 = comb.replicate %false : (i1) -> i17 - %159 = comb.replicate %151 : (i1) -> i17 - %160 = comb.replicate %157 : (i1) -> i17 - %161 = comb.and %160, %0 : i17 - %162 = comb.and %159, %a : i17 - %163 = comb.or bin %161, %162 : i17 - %164 = comb.xor bin %163, %158 : i17 - %165 = comb.concat %164, %false, %16 : i17, i1, i1 - %166 = comb.extract %165 from 0 : (i19) -> i17 - %c14_i17 = hw.constant 14 : i17 - %167 = comb.shl %166, %c14_i17 : i17 - %168 = comb.add bin %31, %48, %65, %82, %99, %116, %133, %150, %167, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17, %c0_i17 : i17 - hw.output %168 : i17 - } -} - diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 8a633f4596b9..f79d78052727 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -89,6 +89,13 @@ struct DatapathCompressOpConversion : OpConversionPattern { struct DatapathPartialProductOpConversion : OpConversionPattern { using OpConversionPattern::OpConversionPattern; + + DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth) + : OpConversionPattern(context), + forceBooth(forceBooth) {}; + + const bool forceBooth; + LogicalResult matchAndRewrite(PartialProductOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { @@ -104,10 +111,10 @@ struct DatapathPartialProductOpConversion } // Use width as a heuristic to guide partial product implementation - if (width <= 16) - return lowerAndArray(rewriter, a, b, op, width); - else + if (width > 16 || forceBooth) return lowerBoothArray(rewriter, a, b, op, width); + else + return lowerAndArray(rewriter, a, b, op, width); } private: @@ -158,10 +165,12 @@ struct DatapathPartialProductOpConversion // encOne = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1 // encTwo = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2 Value encNegPrev; - for (unsigned i = 0; i < width; i += 2) { + + // For even width - additional row contains the final sign correction + for (unsigned i = 0; i <= width; i += 2) { // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0) Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1]; - Value bi = bBits[i]; + Value bi = (i < width) ? bBits[i] : zeroFalse; Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse; // Is the encoding zero or negative (an approximation) @@ -241,10 +250,10 @@ struct ConvertDatapathToCombPass }; } // namespace -static void -populateDatapathToCombConversionPatterns(RewritePatternSet &patterns, - bool lowerCompressToAdd) { - patterns.add(patterns.getContext()); +static void populateDatapathToCombConversionPatterns( + RewritePatternSet &patterns, bool lowerCompressToAdd, bool forceBooth) { + patterns.add(patterns.getContext(), + forceBooth); if (lowerCompressToAdd) // Lower compressors to simple add operations for downstream optimisations @@ -261,7 +270,8 @@ void ConvertDatapathToCombPass::runOnOperation() { target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); - populateDatapathToCombConversionPatterns(patterns, lowerCompressToAdd); + populateDatapathToCombConversionPatterns(patterns, lowerCompressToAdd, + forceBooth); if (failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir index aee3e6f56861..93323547450c 100644 --- a/test/Conversion/CombToAIG/comb-to-aig-arith.mlir +++ b/test/Conversion/CombToAIG/comb-to-aig-arith.mlir @@ -67,7 +67,6 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { // ALLOW_ADD-NEXT: %[[RHS0:.+]] = comb.extract %rhs from 0 : (i3) -> i1 // ALLOW_ADD-NEXT: %[[RHS1:.+]] = comb.extract %rhs from 1 : (i3) -> i1 // ALLOW_ADD-NEXT: %[[RHS2:.+]] = comb.extract %rhs from 2 : (i3) -> i1 -// ALLOW_ADD-NEXT: %false = hw.constant false // Partial Products // ALLOW_ADD-NEXT: %[[P_0_0:.+]] = comb.and %[[LHS0]], %[[RHS0]] : i1 // ALLOW_ADD-NEXT: %[[P_1_0:.+]] = comb.and %[[LHS1]], %[[RHS0]] : i1 @@ -76,6 +75,7 @@ hw.module @sub(in %lhs: i4, in %rhs: i4, out out: i4) { // ALLOW_ADD-NEXT: %[[P_1_1:.+]] = comb.and %[[LHS1]], %[[RHS1]] : i1 // ALLOW_ADD-NEXT: %[[P_2_1:.+]] = comb.and %[[LHS0]], %[[RHS2]] : i1 // Wallace Tree Reduction +// ALLOW_ADD-NEXT: %false = hw.constant false // ALLOW_ADD-NEXT: %[[XOR0:.+]] = comb.xor bin %[[P_1_0]], %[[P_0_1]] : i1 // ALLOW_ADD-NEXT: %[[AND0:.+]] = comb.and bin %[[P_1_0]], %[[P_0_1]] : i1 // ALLOW_ADD-NEXT: %[[XOR1:.+]] = comb.xor bin %[[P_2_0]], %[[P_1_1]] : i1 diff --git a/test/Conversion/DatapathToComb/datapath-to-comb.mlir b/test/Conversion/DatapathToComb/datapath-to-comb.mlir index 659cd67939d3..07479dc367c9 100644 --- a/test/Conversion/DatapathToComb/datapath-to-comb.mlir +++ b/test/Conversion/DatapathToComb/datapath-to-comb.mlir @@ -1,5 +1,6 @@ // RUN: circt-opt %s --convert-datapath-to-comb | FileCheck %s -// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-compress-to-add=true}))" | FileCheck %s --check-prefix=ALLOW_ADD +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-compress-to-add=true}))" | FileCheck %s --check-prefix=TO-ADD +// RUN: circt-opt %s --pass-pipeline="builtin.module(hw.module(convert-datapath-to-comb{lower-partial-product-to-booth=true}, canonicalize))" | FileCheck %s --check-prefix=FORCE-BOOTH // CHECK-LABEL: @compressor hw.module @compressor(in %a : i2, in %b : i2, in %c : i2, out carry : i2, out save : i2) { @@ -57,6 +58,39 @@ hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, o hw.output %0#0, %0#1, %0#2 : i3, i3, i3 } +// CHECK-LABEL: @partial_product +// FORCE-BOOTH-LABEL: @partial_product_booth +// Constants +// FORCE-BOOTH-NEXT: %true = hw.constant true +// FORCE-BOOTH-NEXT: %false = hw.constant false +// FORCE-BOOTH-NEXT: %c0_i3 = hw.constant 0 : i3 +// 2*a +// FORCE-BOOTH-NEXT: %0 = comb.extract %a from 0 : (i3) -> i2 +// FORCE-BOOTH-NEXT: %[[TWOA:.+]] = comb.concat %0, %false : i2, i1 +// FORCE-BOOTH-NEXT: %[[B0:.+]] = comb.extract %b from 0 : (i3) -> i1 +// FORCE-BOOTH-NEXT: %[[B1:.+]] = comb.extract %b from 1 : (i3) -> i1 +// FORCE-BOOTH-NEXT: %[[B2:.+]] = comb.extract %b from 2 : (i3) -> i1 +// PP0 +// FORCE-BOOTH-NEXT: %[[NB0:.+]] = comb.xor bin %[[B0]], %true : i1 +// FORCE-BOOTH-NEXT: %[[TWO0:.+]] = comb.and %[[B1]], %[[NB0]] : i1 +// FORCE-BOOTH-NEXT: %[[PPOSGN:.+]] = comb.replicate %[[B1]] : (i1) -> i3 +// FORCE-BOOTH-NEXT: %[[ONER:.+]] = comb.replicate %[[B0]] : (i1) -> i3 +// FORCE-BOOTH-NEXT: %[[TWO0R:.+]] = comb.replicate %[[TWO0]] : (i1) -> i3 +// FORCE-BOOTH-NEXT: %[[PP0TWOA:.+]] = comb.and %[[TWO0R]], %[[TWOA]] : i3 +// FORCE-BOOTH-NEXT: %[[PP0ONEA:.+]] = comb.and %[[ONER]], %a : i3 +// FORCE-BOOTH-NEXT: %[[PP0MAG:.+]] = comb.or bin %[[PP0TWOA]], %[[PP0ONEA]] : i3 +// FORCE-BOOTH-NEXT: %[[PP0:.+]] = comb.xor bin %[[PP0MAG]], %[[PPOSGN]] : i3 +// PP1 +// FORCE-BOOTH-NEXT: %[[B2XORB1:.+]] = comb.xor bin %4, %3 : i1 +// FORCE-BOOTH-NEXT: %[[A0:.+]] = comb.extract %a from 0 : (i3) -> i1 +// FORCE-BOOTH-NEXT: %[[PP1MSB:.+]] = comb.and %[[B2XORB1]], %[[A0]] : i1 +// FORCE-BOOTH-NEXT: %[[PP1:.+]] = comb.concat %[[PP1MSB]], %false, %[[B1]] : i1, i1, i1 +// FORCE-BOOTH-NEXT: hw.output %[[PP0]], %[[PP1]], %c0_i3 : i3, i3, i3 +hw.module @partial_product_booth(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, out pp2 : i3) { + %0:3 = datapath.partial_product %a, %b : (i3, i3) -> (i3, i3, i3) + hw.output %0#0, %0#1, %0#2 : i3, i3, i3 +} + // CHECK-LABEL: @partial_product_24 hw.module @partial_product_24(in %a : i24, in %b : i24, out sum : i24) { %0:24 = datapath.partial_product %a, %b : (i24, i24) -> (i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24, i24) From 826234e3966faf7145cdcf3e5167737a1d585be1 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Wed, 23 Jul 2025 11:36:33 +0100 Subject: [PATCH 09/12] Minor fix --- test/Conversion/DatapathToComb/datapath-to-comb.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Conversion/DatapathToComb/datapath-to-comb.mlir b/test/Conversion/DatapathToComb/datapath-to-comb.mlir index 07479dc367c9..211573619f00 100644 --- a/test/Conversion/DatapathToComb/datapath-to-comb.mlir +++ b/test/Conversion/DatapathToComb/datapath-to-comb.mlir @@ -58,7 +58,7 @@ hw.module @partial_product(in %a : i3, in %b : i3, out pp0 : i3, out pp1 : i3, o hw.output %0#0, %0#1, %0#2 : i3, i3, i3 } -// CHECK-LABEL: @partial_product +// CHECK-LABEL: @partial_product_booth // FORCE-BOOTH-LABEL: @partial_product_booth // Constants // FORCE-BOOTH-NEXT: %true = hw.constant true From dddc4d1f0438a11ccb497f948eea2f5b6bbb655b Mon Sep 17 00:00:00 2001 From: cowardsa Date: Wed, 23 Jul 2025 11:41:41 +0100 Subject: [PATCH 10/12] Formatting --- lib/Conversion/DatapathToComb/DatapathToComb.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index f79d78052727..8c3c2b5fb43e 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -92,7 +92,7 @@ struct DatapathPartialProductOpConversion DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth) : OpConversionPattern(context), - forceBooth(forceBooth) {}; + forceBooth(forceBooth){}; const bool forceBooth; From 71324fd7d8cb4e76b6cdafb9296762b1508919a6 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Thu, 24 Jul 2025 10:05:04 +0100 Subject: [PATCH 11/12] Removing populate patterns function --- include/circt/Conversion/DatapathToComb.h | 5 ---- .../DatapathToComb/DatapathToComb.cpp | 23 ++++++++----------- 2 files changed, 9 insertions(+), 19 deletions(-) diff --git a/include/circt/Conversion/DatapathToComb.h b/include/circt/Conversion/DatapathToComb.h index d3b95b2d07bd..c3cf22d0ddb0 100644 --- a/include/circt/Conversion/DatapathToComb.h +++ b/include/circt/Conversion/DatapathToComb.h @@ -13,11 +13,6 @@ namespace circt { -void populateDatapathToCombConversionPatterns(TypeConverter &converter, - RewritePatternSet &patterns, - bool lowerCompressToAdd, - bool forceBooth); - #define GEN_PASS_DECL_CONVERTDATAPATHTOCOMB #include "circt/Conversion/Passes.h.inc" diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 8c3c2b5fb43e..5c40b4fe747f 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -92,7 +92,7 @@ struct DatapathPartialProductOpConversion DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth) : OpConversionPattern(context), - forceBooth(forceBooth){}; + forceBooth(forceBooth) {}; const bool forceBooth; @@ -250,8 +250,14 @@ struct ConvertDatapathToCombPass }; } // namespace -static void populateDatapathToCombConversionPatterns( - RewritePatternSet &patterns, bool lowerCompressToAdd, bool forceBooth) { +void ConvertDatapathToCombPass::runOnOperation() { + ConversionTarget target(getContext()); + + target.addLegalDialect(); + target.addIllegalDialect(); + + RewritePatternSet patterns(&getContext()); + patterns.add(patterns.getContext(), forceBooth); @@ -261,17 +267,6 @@ static void populateDatapathToCombConversionPatterns( else // Lower compressors to a complete gate-level implementation patterns.add(patterns.getContext()); -} - -void ConvertDatapathToCombPass::runOnOperation() { - ConversionTarget target(getContext()); - - target.addLegalDialect(); - target.addIllegalDialect(); - - RewritePatternSet patterns(&getContext()); - populateDatapathToCombConversionPatterns(patterns, lowerCompressToAdd, - forceBooth); if (failed(mlir::applyPartialConversion(getOperation(), target, std::move(patterns)))) From 40ddd48c3c31fdf6a2c9209f4f23dc5537893a62 Mon Sep 17 00:00:00 2001 From: cowardsa Date: Thu, 24 Jul 2025 10:07:55 +0100 Subject: [PATCH 12/12] Formatting --- lib/Conversion/DatapathToComb/DatapathToComb.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/DatapathToComb/DatapathToComb.cpp b/lib/Conversion/DatapathToComb/DatapathToComb.cpp index 5c40b4fe747f..9d09fc5df099 100644 --- a/lib/Conversion/DatapathToComb/DatapathToComb.cpp +++ b/lib/Conversion/DatapathToComb/DatapathToComb.cpp @@ -92,7 +92,7 @@ struct DatapathPartialProductOpConversion DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth) : OpConversionPattern(context), - forceBooth(forceBooth) {}; + forceBooth(forceBooth){}; const bool forceBooth;