55 changes: 40 additions & 15 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  // Figure out the accType if needed
+  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+                    isa<FloatType>(elementTy) &&
+                    cast<FloatType>(elementTy).isBF16();
+  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
   SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   inputs.push_back(input);
 
   // First fill the output buffer with the init value.
-  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                             resultTy.getElementType(), dynDims)
-                         .getResult();
+  auto emptyTensor =
+      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+          .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
   if (!fillValueAttr)
     return rewriter.notifyMatchFailure(
         op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         std::array<Value, 2> binaryArgs{
             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
-        auto result = createLinalgBodyCalculationForReduceOp(
-            op, binaryArgs, elementTy, rewriter);
+
+        // If reduction type differs then extend (applicable to reduce_sum)
+        if (binaryArgs[0].getType() != accTy)
+          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+                                                binaryArgs[0]);
+
+        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+                                                              accTy, rewriter);
         if (result)
           didEncounterError = true;
 
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
 
     // Create a tensor full of NaNs.
     auto nanValueAttr = rewriter.getFloatAttr(
-        elementTy,
+        accTy,
         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
     auto nanFilledTensor =
         linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
 
     // Do a selection between the tensors akin to:
@@ -1304,9 +1314,24 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     linalgOp = linalgSelect;
   }
 
+  // Truncate back to resultTy if needed
+  Value reducedRes = linalgOp->getResult(0);
+  if (widenAccTy) {
+    auto resEmptyOp =
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+            .getResult();
+    reducedRes = linalg::MapOp::create(
+                     rewriter, loc, ValueRange{reducedRes}, resEmptyOp,
+                     [&](OpBuilder &builder, Location loc, ValueRange args) {
+                       Value val = arith::TruncFOp::create(builder, loc,
+                                                           elementTy, args[0]);
+                       linalg::YieldOp::create(builder, loc, ValueRange{val});
+                     })
+                     .getResult()[0];
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1349,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp->getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+                                                     reassociationMap);
   return success();
 }

21 changes: 21 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,27 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+  // CHECK: [[CST0:%.+]] = arith.constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+  // CHECK:   (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK:     [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK:     [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+  // CHECK:     linalg.yield [[ACC]] : f32
+  // CHECK:   }
+  // CHECK: [[TRUNCF:%.+]] = tensor.empty() : tensor<4xbf16>
+  // CHECK: [[RES:%.+]] = linalg.map { arith.truncf } ins([[REDUCE]]{{.*}}outs([[TRUNCF]]
+  // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
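For reference, the end-to-end lowering this patch produces for a bf16 tosa.reduce_sum can be sketched as the IR below. This is an illustrative reconstruction based on the CHECK lines in @reduce_bf16 above, not compiler output: the function name @reduce_bf16_sketch and all SSA value names are made up, and linalg.map is written in its long region form rather than the short `linalg.map { arith.truncf }` form the test expects the printer to emit.

func.func @reduce_bf16_sketch(%arg0: tensor<5x4xbf16>) -> tensor<1x4xbf16> {
  // Accumulate in f32: create an f32 init tensor and fill it with 0.0.
  %init = tensor.empty() : tensor<4xf32>
  %zero = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  // Reduce along axis 0; each bf16 element is extended to f32 before the add.
  %sum = linalg.reduce ins(%arg0 : tensor<5x4xbf16>) outs(%fill : tensor<4xf32>) dimensions = [0]
    (%in: bf16, %acc: f32) {
      %ext = arith.extf %in : bf16 to f32
      %add = arith.addf %ext, %acc : f32
      linalg.yield %add : f32
    }
  // Truncate the f32 result back to bf16, then restore the reduced axis.
  %out_init = tensor.empty() : tensor<4xbf16>
  %trunc = linalg.map ins(%sum : tensor<4xf32>) outs(%out_init : tensor<4xbf16>)
    (%elem: f32) {
      %t = arith.truncf %elem : f32 to bf16
      linalg.yield %t : bf16
    }
  %res = tensor.expand_shape %trunc [[0, 1]] output_shape [1, 4]
      : tensor<4xbf16> into tensor<1x4xbf16>
  return %res : tensor<1x4xbf16>
}

Accumulating in f32 and rounding only once at the end avoids the error that builds up when every partial sum is rounded back to bf16's narrow significand.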