-
Notifications
You must be signed in to change notification settings - Fork 15k
[mlir][tosa] Introduce accumulator type for reduce_sum
on bf16
#158389
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
TOSA requires that `reduce_sum` operations on bf16 accumulate into fp32. This change updates the `linalg` legalization by introducing an explicit accumulator type to ensure compliance with the specification. Signed-off-by: Georgios Pinitas <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Georgios Pinitas (GeorgeARM) ChangesTOSA requires that Full diff: https://github.com/llvm/llvm-project/pull/158389.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e2b31f640da2f..96eab7197a585 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
auto elementTy = resultTy.getElementType();
Value input = op->getOperand(0);
+ // Figure out the accType if needed
+ bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+ isa<FloatType>(elementTy) &&
+ cast<FloatType>(elementTy).isBF16();
+ Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
SmallVector<int64_t> reduceShape;
SmallVector<Value> dynDims;
for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
inputs.push_back(input);
// First fill the output buffer with the init value.
- auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
- .getResult();
+ auto emptyTensor =
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+ .getResult();
- auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+ auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
if (!fillValueAttr)
return rewriter.notifyMatchFailure(
op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
std::array<Value, 2> binaryArgs{
blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
- auto result = createLinalgBodyCalculationForReduceOp(
- op, binaryArgs, elementTy, rewriter);
+
+ // If reduction type differs then extend (applicable to reduce_sum)
+ if (binaryArgs[0].getType() != accTy)
+ binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+ binaryArgs[0]);
+
+ auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+ accTy, rewriter);
if (result)
didEncounterError = true;
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Create a tensor full of NaNs.
auto nanValueAttr = rewriter.getFloatAttr(
- elementTy,
+ accTy,
APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
auto emptyNanTensor =
- tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();
auto nanFilledTensor =
linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// Create an empty tensor, non need to fill this since it will be
// overwritten by the select.
auto finalEmptyTensor =
- tensor::EmptyOp::create(rewriter, loc, reduceShape,
- resultTy.getElementType(), dynDims)
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
.getResult();
// Do a selection between the tensors akin to:
@@ -1304,9 +1314,24 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
linalgOp = linalgSelect;
}
+ // Truncate back to resultTy if needed
+ Value reducedRes = linalgOp->getResult(0);
+ if (widenAccTy) {
+ auto resEmptyOp =
+ tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+ .getResult();
+ reducedRes = linalg::MapOp::create(
+ rewriter, loc, ValueRange{reducedRes}, resEmptyOp,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ Value val = arith::TruncFOp::create(builder, loc,
+ elementTy, args[0]);
+ linalg::YieldOp::create(builder, loc, ValueRange{val});
+ })
+ .getResult()[0];
+ }
+
SmallVector<ReassociationExprs, 4> reassociationMap;
- uint64_t expandInputRank =
- cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+ uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
reassociationMap.resize(expandInputRank);
for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1349,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
// since here we know which dimension to expand, and `tosa::ReshapeOp` would
// not have access to such information. This matters when handling dynamically
// sized tensors.
- rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
- op, resultTy, linalgOp->getResults()[0], reassociationMap);
+ rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+ reassociationMap);
return success();
}
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..3b63bdf4f7219 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,27 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
// -----
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+ // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+ // CHECK: [[CST0:%.+]] = arith.constant 0.0
+ // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+ // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+ // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+ // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+ // CHECK: [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+ // CHECK: linalg.yield [[ACC]] : f32
+ // CHECK: }
+ // CHECK: [[TRUNCF:%.+]] = tensor.empty() : tensor<4xbf16>
+ // CHECK: [[RES:%.+]] = linalg.map { arith.truncf } ins([[REDUCE]]{{.*}}outs([[TRUNCF]]
+ // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+ %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @reduce_float
// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
|
TOSA requires that
reduce_sum
operations on bf16 accumulate into fp32. This change updates thelinalg
legalization by introducing an explicit accumulator type to ensure compliance with the specification.