Thanks to visit codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions lib/Transforms/MemoryBanking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ struct MemoryBankingPass
// map from original memory definition to newly allocated banks
DenseMap<Value, SmallVector<Value>> memoryToBanks;
DenseSet<Operation *> opsToErase;
// Track memory references that need to be cleaned up after memory banking is
// complete.
DenseSet<Value> oldMemRefVals;
};
} // namespace

Expand Down Expand Up @@ -134,10 +137,11 @@ struct BankAffineLoadPattern
: public OpRewritePattern<mlir::affine::AffineLoadOp> {
BankAffineLoadPattern(MLIRContext *context, uint64_t bankingFactor,
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks)
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineLoadOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks) {}
memoryToBanks(memoryToBanks), oldMemRefVals(oldMemRefVals) {}

LogicalResult matchAndRewrite(mlir::affine::AffineLoadOp loadOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -187,6 +191,10 @@ struct BankAffineLoadPattern
auto defaultValue = rewriter.create<arith::ConstantOp>(loc, zeroAttr);
rewriter.create<scf::YieldOp>(loc, defaultValue.getResult());

// Record the load's memref for later cleanup only when it is a block
// argument; this is the only case where the reference is not replaced.
if (Value memRef = loadOp.getMemref(); isa<BlockArgument>(memRef))
oldMemRefVals.insert(memRef);
rewriter.replaceOp(loadOp, switchOp.getResult(0));

return success();
Expand All @@ -196,6 +204,7 @@ struct BankAffineLoadPattern
uint64_t bankingFactor;
unsigned bankingDimension;
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Value> &oldMemRefVals;
};

// Replace the original store operations with newly created memory banks
Expand All @@ -205,11 +214,12 @@ struct BankAffineStorePattern
unsigned bankingDimension,
DenseMap<Value, SmallVector<Value>> &memoryToBanks,
DenseSet<Operation *> &opsToErase,
DenseSet<Operation *> &processedOps)
DenseSet<Operation *> &processedOps,
DenseSet<Value> &oldMemRefVals)
: OpRewritePattern<mlir::affine::AffineStoreOp>(context),
bankingFactor(bankingFactor), bankingDimension(bankingDimension),
memoryToBanks(memoryToBanks), opsToErase(opsToErase),
processedOps(processedOps) {}
processedOps(processedOps), oldMemRefVals(oldMemRefVals) {}

LogicalResult matchAndRewrite(mlir::affine::AffineStoreOp storeOp,
PatternRewriter &rewriter) const override {
Expand Down Expand Up @@ -262,6 +272,7 @@ struct BankAffineStorePattern

processedOps.insert(storeOp);
opsToErase.insert(storeOp);
oldMemRefVals.insert(storeOp.getMemref());

return success();
}
Expand All @@ -272,6 +283,7 @@ struct BankAffineStorePattern
DenseMap<Value, SmallVector<Value>> &memoryToBanks;
DenseSet<Operation *> &opsToErase;
DenseSet<Operation *> &processedOps;
DenseSet<Value> &oldMemRefVals;
};

// Replace the original return operation with newly created memory banks
Expand Down Expand Up @@ -388,9 +400,10 @@ void MemoryBankingPass::runOnOperation() {

DenseSet<Operation *> processedOps;
patterns.add<BankAffineLoadPattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks);
memoryToBanks, oldMemRefVals);
patterns.add<BankAffineStorePattern>(ctx, bankingFactor, bankingDimension,
memoryToBanks, opsToErase, processedOps);
memoryToBanks, opsToErase, processedOps,
oldMemRefVals);
patterns.add<BankReturnPattern>(ctx, memoryToBanks);

GreedyRewriteConfig config;
Expand All @@ -401,10 +414,6 @@ void MemoryBankingPass::runOnOperation() {
}

// Clean up the old memref values
DenseSet<Value> oldMemRefVals;
for (const auto &[memory, _] : memoryToBanks)
oldMemRefVals.insert(memory);

if (failed(cleanUpOldMemRefs(oldMemRefVals, opsToErase))) {
signalPassFailure();
}
Expand Down
76 changes: 76 additions & 0 deletions test/Transforms/memory_banking.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix UNROLL-BY-2
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=1" | FileCheck %s --check-prefix UNROLL-BY-1
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=8" | FileCheck %s --check-prefix UNROLL-BY-8
// RUN: circt-opt %s -split-input-file -memory-banking="banking-factor=2" | FileCheck %s --check-prefix ALLOC-UNROLL-2

// -----

Expand Down Expand Up @@ -259,3 +260,78 @@ func.func @bank_one_dim_unroll8(%arg0: memref<8xf32>, %arg1: memref<8xf32>) -> (
}
return %mem : memref<8xf32>
}

// -----

// ALLOC-UNROLL-2: #[[$ATTR_0:.+]] = affine_map<(d0) -> (d0 mod 2)>
// ALLOC-UNROLL-2: #[[$ATTR_1:.+]] = affine_map<(d0) -> (d0 floordiv 2)>


// ALLOC-UNROLL-2-LABEL: func.func @alloc_unroll2() -> (memref<4xf32>, memref<4xf32>) {
// ALLOC-UNROLL-2: %[[VAL_0:.*]] = arith.constant 0.000000e+00 : f32
// ALLOC-UNROLL-2: %[[VAL_1:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_2:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_3:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_4:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_5:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: %[[VAL_6:.*]] = memref.alloc() : memref<4xf32>
// ALLOC-UNROLL-2: affine.parallel (%[[VAL_7:.*]]) = (0) to (8) {
// ALLOC-UNROLL-2: %[[VAL_8:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_9:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_10:.*]] = scf.index_switch %[[VAL_8]] -> f32
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: %[[VAL_11:.*]] = affine.load %[[VAL_1]]{{\[}}%[[VAL_9]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_11]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: %[[VAL_12:.*]] = affine.load %[[VAL_2]]{{\[}}%[[VAL_9]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_12]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: %[[VAL_13:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_14:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_15:.*]] = scf.index_switch %[[VAL_13]] -> f32
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: %[[VAL_16:.*]] = affine.load %[[VAL_3]]{{\[}}%[[VAL_14]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_16]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: %[[VAL_17:.*]] = affine.load %[[VAL_4]]{{\[}}%[[VAL_14]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield %[[VAL_17]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: scf.yield %[[VAL_0]] : f32
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: %[[VAL_18:.*]] = arith.mulf %[[VAL_10]], %[[VAL_15]] : f32
// ALLOC-UNROLL-2: %[[VAL_19:.*]] = affine.apply #[[$ATTR_0]](%[[VAL_7]])
// ALLOC-UNROLL-2: %[[VAL_20:.*]] = affine.apply #[[$ATTR_1]](%[[VAL_7]])
// ALLOC-UNROLL-2: scf.index_switch %[[VAL_19]]
// ALLOC-UNROLL-2: case 0 {
// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_5]]{{\[}}%[[VAL_20]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: case 1 {
// ALLOC-UNROLL-2: affine.store %[[VAL_18]], %[[VAL_6]]{{\[}}%[[VAL_20]]] : memref<4xf32>
// ALLOC-UNROLL-2: scf.yield
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: default {
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: }
// ALLOC-UNROLL-2: return %[[VAL_5]], %[[VAL_6]] : memref<4xf32>, memref<4xf32>
// ALLOC-UNROLL-2: }

// Test input for banking memories that are allocated inside the function:
// every memref used below comes from a memref.alloc rather than from a
// function argument. The expected output above is produced with
// banking-factor=2 and shows two memref<4xf32> allocations for each
// memref<8xf32> allocated here, with both banks of %mem being returned.
func.func @alloc_unroll2() -> (memref<8xf32>) {
// Two locally allocated operand buffers and one locally allocated result.
%arg0 = memref.alloc() : memref<8xf32>
%arg1 = memref.alloc() : memref<8xf32>
%mem = memref.alloc() : memref<8xf32>
// Element-wise product over indices 0..7: load one element from each operand
// buffer, multiply them, and store the product at the same index in %mem.
affine.parallel (%i) = (0) to (8) {
%1 = affine.load %arg0[%i] : memref<8xf32>
%2 = affine.load %arg1[%i] : memref<8xf32>
%3 = arith.mulf %1, %2 : f32
affine.store %3, %mem[%i] : memref<8xf32>
}
// The result buffer is the function's only return value.
return %mem : memref<8xf32>
}

Loading