Thanks for visiting codestin.com
Credit goes to github.com

Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions lib/Dialect/Comb/CombFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
return {};

// mux (c, b, b) -> b
if (getTrueValue() == getFalseValue())
if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
return getTrueValue();
if (auto tv = adaptor.getTrueValue())
if (tv == adaptor.getFalseValue())
Expand Down Expand Up @@ -2183,6 +2183,9 @@ static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
// `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
// `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
if (subMux == op)
return false;

Value otherValue;
Value subCond = subMux.getCond();

Expand Down Expand Up @@ -2514,8 +2517,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
}
}

if (auto falseMux =
dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
falseMux && falseMux != op) {
// mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
if (op.getCond() == falseMux.getCond()) {
replaceOpWithNewOpAndCopyName<MuxOp>(
Expand All @@ -2529,8 +2532,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
return success();
}

if (auto trueMux =
dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
trueMux && trueMux != op) {
// mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
if (op.getCond() == trueMux.getCond()) {
replaceOpWithNewOpAndCopyName<MuxOp>(
Expand All @@ -2548,7 +2551,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getTrueValue() == falseMux.getTrueValue()) {
trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
Expand All @@ -2562,7 +2566,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
trueMux.getFalseValue() == falseMux.getFalseValue()) {
trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
Expand All @@ -2577,7 +2582,8 @@ LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
trueMux && falseMux &&
trueMux.getTrueValue() == falseMux.getTrueValue() &&
trueMux.getFalseValue() == falseMux.getFalseValue()) {
trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
falseMux != op) {
auto subMux = rewriter.create<MuxOp>(
rewriter.getFusedLoc(
{op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
Expand Down
11 changes: 7 additions & 4 deletions test/Dialect/Comb/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ hw.module @muxConstantsFold(in %cond: i1, out o: i25) {
hw.module @muxCommon(in %cond: i1, in %cond2: i1,
in %arg0 : i32, in %arg1 : i32, in %arg2: i32, in %arg3: i32,
out o1: i32, out o2: i32, out o3: i32, out o4: i32,
out o5: i32, out orResult: i32, out o6: i32, out o7: i32) {
out o5: i32, out orResult: i32, out o6: i32, out o7: i32, out o8 : i1) {
%allones = hw.constant -1 : i32
%notArg0 = comb.xor %arg0, %allones : i32

Expand Down Expand Up @@ -1275,10 +1275,13 @@ hw.module @muxCommon(in %cond: i1, in %cond2: i1,
%1 = comb.mux %cond, %arg1, %arg0 : i32
%o7 = comb.mux %cond2, %1, %arg0 : i32

/// CHECK: [[O8:%.+]] = comb.mux [[O8]], [[O8]], [[O8]] : i1
%o8 = comb.mux %o8, %o8, %o8 : i1

// CHECK: hw.output [[O1]], [[O2]], [[O3]], [[O4]], [[O5]], [[ORRESULT]],
// CHECK: [[O6]], [[O7]]
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7
: i32, i32, i32, i32, i32, i32, i32, i32
// CHECK: [[O6]], [[O7]], [[O8]]
hw.output %o1, %o2, %o3, %o4, %o5, %orResult, %o6, %o7, %o8
: i32, i32, i32, i32, i32, i32, i32, i32, i1
}

// CHECK-LABEL: @flatten_multi_use_and
Expand Down
Loading