
Conversation

GeorgeARM
Contributor

TOSA requires that `reduce_sum` operations on bf16 accumulate into fp32.
This change updates the `linalg` legalization by introducing an explicit
accumulator type to ensure compliance with the specification.
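To make the intent concrete, here is a minimal sketch of the IR this legalization is expected to produce for a bf16 `reduce_sum` along axis 0. The function name, SSA value names, and the 5x4 shape are illustrative only; the structure mirrors the `@reduce_bf16` test added in the diff below.

```mlir
// Hypothetical illustration: bf16 reduce_sum accumulating into f32.
func.func @reduce_bf16_example(%arg0: tensor<5x4xbf16>) -> tensor<1x4xbf16> {
  // Accumulator buffer is f32, not bf16.
  %init = tensor.empty() : tensor<4xf32>
  %zero = arith.constant 0.0 : f32
  %fill = linalg.fill ins(%zero : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  // Reduce: extend each bf16 element to f32 before adding.
  %red = linalg.reduce ins(%arg0 : tensor<5x4xbf16>) outs(%fill : tensor<4xf32>) dimensions = [0]
    (%in: bf16, %acc: f32) {
      %ext = arith.extf %in : bf16 to f32
      %sum = arith.addf %ext, %acc : f32
      linalg.yield %sum : f32
    }
  // Truncate the f32 accumulator back to the bf16 result element type.
  %empty = tensor.empty() : tensor<4xbf16>
  %trunc = linalg.map { arith.truncf } ins(%red : tensor<4xf32>) outs(%empty : tensor<4xbf16>)
  // Re-expand the reduced dimension to match the TOSA result shape.
  %res = tensor.expand_shape %trunc [[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
  return %res : tensor<1x4xbf16>
}
```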

Signed-off-by: Georgios Pinitas <[email protected]>
@llvmbot
Member

llvmbot commented Sep 13, 2025

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Georgios Pinitas (GeorgeARM)

Changes

TOSA requires that `reduce_sum` operations on bf16 accumulate into fp32. This change updates the `linalg` legalization by introducing an explicit accumulator type to ensure compliance with the specification.


Full diff: https://github.com/llvm/llvm-project/pull/158389.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+40-15)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+21)
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e2b31f640da2f..96eab7197a585 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  // Figure out the accType if needed
+  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+                    isa<FloatType>(elementTy) &&
+                    cast<FloatType>(elementTy).isBF16();
+  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
   SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   inputs.push_back(input);
 
   // First fill the output buffer with the init value.
-  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                             resultTy.getElementType(), dynDims)
-                         .getResult();
+  auto emptyTensor =
+      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+          .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
   if (!fillValueAttr)
     return rewriter.notifyMatchFailure(
         op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         std::array<Value, 2> binaryArgs{
             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
-        auto result = createLinalgBodyCalculationForReduceOp(
-            op, binaryArgs, elementTy, rewriter);
+
+        // If reduction type differs then extend (applicable to reduce_sum)
+        if (binaryArgs[0].getType() != accTy)
+          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+                                                binaryArgs[0]);
+
+        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+                                                             accTy, rewriter);
         if (result)
           didEncounterError = true;
 
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
 
     // Create a tensor full of NaNs.
     auto nanValueAttr = rewriter.getFloatAttr(
-        elementTy,
+        accTy,
         APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
     auto nanFilledTensor =
         linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
 
     // Do a selection between the tensors akin to:
@@ -1304,9 +1314,24 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     linalgOp = linalgSelect;
   }
 
+  // Truncate back to resultTy if needed
+  Value reducedRes = linalgOp->getResult(0);
+  if (widenAccTy) {
+    auto resEmptyOp =
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+            .getResult();
+    reducedRes = linalg::MapOp::create(
+                     rewriter, loc, ValueRange{reducedRes}, resEmptyOp,
+                     [&](OpBuilder &builder, Location loc, ValueRange args) {
+                       Value val = arith::TruncFOp::create(builder, loc,
+                                                           elementTy, args[0]);
+                       linalg::YieldOp::create(builder, loc, ValueRange{val});
+                     })
+                     .getResult()[0];
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1349,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp->getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+                                                     reassociationMap);
   return success();
 }
 
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..3b63bdf4f7219 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,27 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+  // CHECK: [[CST0:%.+]] = arith.constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+  // CHECK:  (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK:   [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK:   [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+  // CHECK:   linalg.yield [[ACC]] : f32
+  // CHECK:  }
+  // CHECK:  [[TRUNCF:%.+]] = tensor.empty() : tensor<4xbf16>
+  // CHECK:  [[RES:%.+]] = linalg.map { arith.truncf } ins([[REDUCE]]{{.*}}outs([[TRUNCF]]
+  // CHECK:  tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
