Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 23 additions & 17 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2870,7 +2870,9 @@ struct FoldRegMems : public mlir::RewritePattern {
if (hasDontTouch(mem) || info.depth != 1)
return failure();

auto memModule = mem->getParentOfType<FModuleOp>();
auto ty = mem.getDataType();
auto loc = mem.getLoc();
auto *block = mem->getBlock();

// Find the clock of the register-to-be, all write ports should share it.
Value clock;
Expand Down Expand Up @@ -2924,14 +2926,24 @@ struct FoldRegMems : public mlir::RewritePattern {
return failure();
clock = portClock;
}

// Create a new register to store the data.
auto ty = mem.getDataType();
rewriter.setInsertionPointAfterValue(clock);
auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
.getResult();
// Create a new wire where the memory used to be. This wire will dominate
// all readers of the memory. Reads should be made through this wire.
rewriter.setInsertionPointAfter(mem);
auto memWire = rewriter.create<WireOp>(loc, ty).getResult();

// The memory is replaced by a register, which we place at the end of the
// block, so that any value driven to the original memory will dominate the
// new register (including the clock). All other ops will be placed
// after the register.
rewriter.setInsertionPointToEnd(block);
auto memReg =
rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();

// Connect the output of the register to the wire.
rewriter.create<MatchingConnectOp>(loc, memWire, memReg);

// Helper to insert a given number of pipeline stages through registers.
// The pipelines are placed at the end of the block.
auto pipeline = [&](Value value, Value clock, const Twine &name,
unsigned latency) {
for (unsigned i = 0; i < latency; ++i) {
Expand All @@ -2940,7 +2952,6 @@ struct FoldRegMems : public mlir::RewritePattern {
llvm::raw_string_ostream os(regName);
os << mem.getName() << "_" << name << "_" << i;
}

auto reg = rewriter
.create<RegOp>(mem.getLoc(), value.getType(), clock,
rewriter.getStringAttr(regName))
Expand All @@ -2964,7 +2975,6 @@ struct FoldRegMems : public mlir::RewritePattern {
auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
Value value = getPortFieldValue(port, field);
assert(value);
rewriter.setInsertionPointAfterValue(value);
return pipeline(value, portClock, name + "_" + field, stages);
};

Expand All @@ -2976,8 +2986,7 @@ struct FoldRegMems : public mlir::RewritePattern {
// address must be 0 for single-address memories and the enable signal
// is ignored, always reading out the register. Under these constraints,
// the read port can be replaced with the value from the register.
rewriter.setInsertionPointAfterValue(reg);
replacePortField(rewriter, port, "data", reg);
replacePortField(rewriter, port, "data", memWire);
break;
}
case MemOp::PortKind::Write: {
Expand All @@ -2989,16 +2998,14 @@ struct FoldRegMems : public mlir::RewritePattern {
}
case MemOp::PortKind::ReadWrite: {
// Always read the register into the read end.
rewriter.setInsertionPointAfterValue(reg);
replacePortField(rewriter, port, "rdata", reg);
replacePortField(rewriter, port, "rdata", memWire);

// Create a write enable and pipeline stages.
auto wdata = portPipeline("wdata", writeStages);
auto wmask = portPipeline("wmask", writeStages);

Value en = getPortFieldValue(port, "en");
Value wmode = getPortFieldValue(port, "wmode");
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
auto wenPipelined =
Expand All @@ -3010,8 +3017,7 @@ struct FoldRegMems : public mlir::RewritePattern {
}

// Regardless of `writeUnderWrite`, always implement PortOrder.
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
Value next = reg;
Value next = memReg;
for (auto &[data, en, mask] : writes) {
Value masked;

Expand All @@ -3037,7 +3043,7 @@ struct FoldRegMems : public mlir::RewritePattern {

next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
}
rewriter.create<MatchingConnectOp>(reg.getLoc(), reg, next);
rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, next);

// Delete the fields and their associated connects.
for (Operation *conn : connects)
Expand Down
150 changes: 111 additions & 39 deletions test/Dialect/FIRRTL/simplify-mems.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -358,9 +358,11 @@ firrtl.circuit "OneAddressMasked" {
} :
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<2>>
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %result_read, [[MemoryWire]] : !firrtl.uint<32>
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory : !firrtl.uint<32>

%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
Expand Down Expand Up @@ -408,31 +410,6 @@ firrtl.circuit "OneAddressNoMask" {
in %in_rwen: !firrtl.uint<1>,
out %result_read: !firrtl.uint<32>,
out %result_rw: !firrtl.uint<32>) {

// Pipeline the inputs.
// TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer.

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>
// CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_1, %Memory_write_en_0 : !firrtl.uint<1>
// CHECK: %Memory_write_en_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_2, %Memory_write_en_1 : !firrtl.uint<1>

// CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_write_data_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_1, %Memory_write_data_0 : !firrtl.uint<32>
// CHECK: %Memory_write_data_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_2, %Memory_write_data_1 : !firrtl.uint<32>

// CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_1, %Memory_rw_wdata_0 : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32>

%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

%Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined
Expand All @@ -447,9 +424,56 @@ firrtl.circuit "OneAddressNoMask" {
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>

// A wire, holding the value of the memory, is created where the memory used to be.
// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>

// The original uses of the memory are replaced with uses of the wire.
// CHECK: firrtl.matchingconnect %result_read, [[MemoryWire]] : !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %result_rw, [[MemoryWire]] : !firrtl.uint<32>

// The memory is replaced by a register at the end of the block
// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
// The register's data is written to the MemoryWire
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory : !firrtl.uint<32>

// Following the register, we pipeline the inputs.
// TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer.

// CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_1, %Memory_rw_wdata_0 : !firrtl.uint<32>
// CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32>

// CHECK: [[WRITING:%.+]] = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_0, [[WRITING]] : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_1, %Memory_rw_wen_0 : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_2, %Memory_rw_wen_1 : !firrtl.uint<1>

// CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32>
// CHECK: %Memory_write_data_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_1, %Memory_write_data_0 : !firrtl.uint<32>
// CHECK: %Memory_write_data_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_2, %Memory_write_data_1 : !firrtl.uint<32>

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>
// CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_1, %Memory_write_en_0 : !firrtl.uint<1>
// CHECK: %Memory_write_en_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_2, %Memory_write_en_1 : !firrtl.uint<1>

// Finally, the pipelined inputs are driven to the register.
// CHECK: [[WRITE_RW:%.+]] = firrtl.mux(%Memory_rw_wen_2, %Memory_rw_wdata_2, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: [[WRITE_W:%.+]] = firrtl.mux(%Memory_write_en_2, %Memory_write_data_2, [[WRITE_RW]]) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory, [[WRITE_W]] : !firrtl.uint<32>

%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%read_en = firrtl.subfield %Memory_read[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
Expand All @@ -459,7 +483,6 @@ firrtl.circuit "OneAddressNoMask" {
%read_data = firrtl.subfield %Memory_read[data] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %result_read, %read_data : !firrtl.uint<32>, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_rw, %Memory : !firrtl.uint<32>
%rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
Expand All @@ -475,16 +498,7 @@ firrtl.circuit "OneAddressNoMask" {
%rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: [[WRITING:%.+]] = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_0, [[WRITING]] : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_1, %Memory_rw_wen_0 : !firrtl.uint<1>
// CHECK: %Memory_rw_wen_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_2, %Memory_rw_wen_1 : !firrtl.uint<1>
// CHECK: [[WRITE_RW:%.+]] = firrtl.mux(%Memory_rw_wen_2, %Memory_rw_wdata_2, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: [[WRITE_W:%.+]] = firrtl.mux(%Memory_write_en_2, %Memory_write_data_2, [[WRITE_RW]]) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory, [[WRITE_W]] : !firrtl.uint<32>

%write_addr = firrtl.subfield %Memory_write[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%write_en = firrtl.subfield %Memory_write[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
Expand All @@ -497,3 +511,61 @@ firrtl.circuit "OneAddressNoMask" {
firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
}
}

// -----

// This test ensures that the FoldRegMems canonicalization correctly
// folds memories under layerblocks.
firrtl.circuit "Rewrite1ElementMemoryToRegisterUnderLayerblock" {
firrtl.layer @A bind {}

firrtl.module public @Rewrite1ElementMemoryToRegisterUnderLayerblock(
in %clock: !firrtl.clock,
in %addr: !firrtl.uint<1>,
in %in_data: !firrtl.uint<32>,
in %wmode_rw: !firrtl.uint<1>,
in %in_wen: !firrtl.uint<1>,
in %in_rwen: !firrtl.uint<1>) {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

// CHECK: firrtl.layerblock @A
firrtl.layerblock @A {
// CHECK: %result_read = firrtl.wire : !firrtl.uint<32>
// CHECK: %result_rw = firrtl.wire : !firrtl.uint<32>
%result_read = firrtl.wire : !firrtl.uint<32>
%result_rw = firrtl.wire : !firrtl.uint<32>

// CHECK: [[MemoryWire:%.+]] = firrtl.wire : !firrtl.uint<32>
%Memory_rw = firrtl.mem Undefined
{
depth = 1 : i64,
name = "Memory",
portNames = ["rw"],
readLatency = 2 : i32,
writeLatency = 2 : i32
} : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>

%rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_en, %in_rwen : !firrtl.uint<1>, !firrtl.uint<1>
%rw_clk = firrtl.subfield %Memory_rw[clk] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_clk, %clock : !firrtl.clock, !firrtl.clock
%rw_rdata = firrtl.subfield %Memory_rw[rdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>

%rw_wmode = firrtl.subfield %Memory_rw[wmode] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmode, %wmode_rw : !firrtl.uint<1>, !firrtl.uint<1>
%rw_wdata = firrtl.subfield %Memory_rw[wdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wdata, %in_data : !firrtl.uint<32>, !firrtl.uint<32>
%rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>

// CHECK: firrtl.matchingconnect %result_rw, [[MemoryWire]] : !firrtl.uint<32>
firrtl.connect %result_rw, %rw_rdata : !firrtl.uint<32>, !firrtl.uint<32>

// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect [[MemoryWire]], %Memory
// CHECK: firrtl.matchingconnect %Memory, {{%.+}} : !firrtl.uint<32>
}
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same review comments as last time for the test. The layer test is large. Can it be made smaller by either: (1) removing connections or (2) reducing the number of ports?