diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index e2b31f640da2f..0a6f2477560a1 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1160,6 +1160,12 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   auto elementTy = resultTy.getElementType();
   Value input = op->getOperand(0);
 
+  // Figure out the accType if needed
+  bool widenAccTy = std::is_same_v<OpTy, tosa::ReduceSumOp> &&
+                    isa<FloatType>(elementTy) &&
+                    cast<FloatType>(elementTy).isBF16();
+  Type accTy = widenAccTy ? rewriter.getF32Type() : elementTy;
+
   SmallVector<int64_t> reduceShape;
   SmallVector<Value> dynDims;
   for (unsigned i = 0; i < inputTy.getRank(); i++) {
@@ -1174,11 +1180,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   inputs.push_back(input);
 
   // First fill the output buffer with the init value.
-  auto emptyTensor = tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                             resultTy.getElementType(), dynDims)
-                         .getResult();
+  auto emptyTensor =
+      tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
+          .getResult();
 
-  auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter);
+  auto fillValueAttr = createInitialValueForReduceOp(op, accTy, rewriter);
   if (!fillValueAttr)
     return rewriter.notifyMatchFailure(
         op, "No initial value found for reduction operation");
@@ -1231,8 +1237,14 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
       [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) {
         std::array<Value, 2> binaryArgs{
             blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]};
-        auto result = createLinalgBodyCalculationForReduceOp(
-            op, binaryArgs, elementTy, rewriter);
+
+        // If the accumulator type differs, extend first (applies to reduce_sum)
+        if (binaryArgs[0].getType() != accTy)
+          binaryArgs[0] = arith::ExtFOp::create(nestedBuilder, nestedLoc, accTy,
+                                                binaryArgs[0]);
+
+        auto result = createLinalgBodyCalculationForReduceOp(op, binaryArgs,
+                                                             accTy, rewriter);
         if (result)
           didEncounterError = true;
 
@@ -1273,12 +1285,11 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
 
     // Create a tensor full of NaNs.
     auto nanValueAttr = rewriter.getFloatAttr(
-        elementTy,
+        accTy,
        APFloat::getNaN(cast<FloatType>(elementTy).getFloatSemantics(), false));
     auto nanValue = arith::ConstantOp::create(rewriter, loc, nanValueAttr);
     auto emptyNanTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
     auto nanFilledTensor =
         linalg::FillOp::create(rewriter, loc, ValueRange{nanValue},
@@ -1288,8 +1299,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     // Create an empty tensor, non need to fill this since it will be
     // overwritten by the select.
     auto finalEmptyTensor =
-        tensor::EmptyOp::create(rewriter, loc, reduceShape,
-                                resultTy.getElementType(), dynDims)
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, accTy, dynDims)
             .getResult();
 
     // Do a selection between the tensors akin to:
@@ -1304,9 +1314,32 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
     linalgOp = linalgSelect;
   }
 
+  // Truncate back to resultTy if needed
+  Value reducedRes = linalgOp->getResult(0);
+  if (widenAccTy) {
+    auto resEmptyOp =
+        tensor::EmptyOp::create(rewriter, loc, reduceShape, elementTy, dynDims)
+            .getResult();
+
+    const unsigned reducedRank =
+        cast<ShapedType>(reducedRes.getType()).getRank();
+    auto identityMap = rewriter.getMultiDimIdentityMap(reducedRank);
+    reducedRes =
+        linalg::GenericOp::create(
+            rewriter, loc, resEmptyOp.getType(), ValueRange{reducedRes},
+            ValueRange{resEmptyOp},
+            ArrayRef<AffineMap>{identityMap, identityMap},
+            getNParallelLoopsAttrs(reducedRank),
+            [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
+              Value truncf = arith::TruncFOp::create(nestedBuilder, nestedLoc,
+                                                     elementTy, args[0]);
+              linalg::YieldOp::create(nestedBuilder, nestedLoc, truncf);
+            })
+            .getResults()[0];
+  }
+
   SmallVector<ReassociationExprs, 4> reassociationMap;
-  uint64_t expandInputRank =
-      cast<ShapedType>(linalgOp->getResults()[0].getType()).getRank();
+  uint64_t expandInputRank = cast<ShapedType>(reducedRes.getType()).getRank();
   reassociationMap.resize(expandInputRank);
 
   for (uint64_t i = 0; i < expandInputRank; i++) {
@@ -1324,8 +1357,8 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis,
   // since here we know which dimension to expand, and `tosa::ReshapeOp` would
   // not have access to such information. This matters when handling dynamically
   // sized tensors.
-  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
-      op, resultTy, linalgOp->getResults()[0], reassociationMap);
+  rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(op, resultTy, reducedRes,
+                                                     reassociationMap);
 
   return success();
 }
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 3fc513f823a1a..37af8b8859852 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -912,6 +912,32 @@ func.func @test_identity(%arg0: tensor<1xf32>, %arg1: tensor<1xi32>) -> (tensor<
 
 // -----
 
+// CHECK: #[[$MAP0:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: @reduce_bf16
+// CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xbf16>
+func.func @reduce_bf16(%arg0: tensor<5x4xbf16>) -> () {
+  // CHECK: [[INIT:%.+]] = tensor.empty() : tensor<4xf32>
+  // CHECK: [[CST0:%.+]] = arith.constant 0.0
+  // CHECK: [[FILL:%.+]] = linalg.fill ins([[CST0]]{{.*}}outs([[INIT]]
+  // CHECK: [[REDUCE:%.+]] = linalg.reduce ins([[ARG0]] : tensor<5x4xbf16>) outs([[FILL]] : tensor<4xf32>) dimensions = [0]
+  // CHECK: (%[[ARG1:.*]]: bf16, %[[ARG2:.*]]: f32) {
+  // CHECK: [[EXTF:%.+]] = arith.extf %[[ARG1]] : bf16 to f32
+  // CHECK: [[ACC:%.+]] = arith.addf [[EXTF]], %[[ARG2]] : f32
+  // CHECK: linalg.yield [[ACC]] : f32
+  // CHECK: }
+  // CHECK: [[INIT_RES:%.+]] = tensor.empty() : tensor<4xbf16>
+  // CHECK: [[RES:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel"]} ins([[REDUCE]] : tensor<4xf32>) outs([[INIT_RES]] : tensor<4xbf16>)
+  // CHECK: ^bb0(%[[IN:.*]]: f32, %[[OUT:.*]]: bf16):
+  // CHECK: [[TRUNCF:%.+]] = arith.truncf %[[IN]] : f32 to bf16
+  // CHECK: linalg.yield [[TRUNCF]] : bf16
+  // CHECK: }
+  // CHECK: tensor.expand_shape [[RES]] {{\[}}[0, 1]] output_shape [1, 4] : tensor<4xbf16> into tensor<1x4xbf16>
+  %0 = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<5x4xbf16>) -> tensor<1x4xbf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @reduce_float
 // CHECK-SAME: [[ARG0:%.+]]: tensor<5x4xf32>
 func.func @reduce_float(%arg0: tensor<5x4xf32>) -> () {
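
For reference, below is a hand-written sketch (not pass output; the function name, value names, and the explicit return are illustrative only) of the kind of IR the new @reduce_bf16 test expects: the reduction accumulates in f32 via arith.extf, a trailing linalg.generic truncates back to bf16, and tensor.expand_shape reinserts the reduced axis.

// Sketch of the expected lowering for a bf16 reduce_sum along axis 0.
func.func @reduce_bf16_sketch(%arg0: tensor<5x4xbf16>) -> tensor<1x4xbf16> {
  // Accumulator is widened to f32 and initialized to 0.0.
  %cst = arith.constant 0.000000e+00 : f32
  %init = tensor.empty() : tensor<4xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>) -> tensor<4xf32>
  // Each bf16 element is extended to f32 before being added to the accumulator.
  %reduced = linalg.reduce ins(%arg0 : tensor<5x4xbf16>) outs(%fill : tensor<4xf32>) dimensions = [0]
    (%in: bf16, %acc: f32) {
      %ext = arith.extf %in : bf16 to f32
      %sum = arith.addf %ext, %acc : f32
      linalg.yield %sum : f32
    }
  // Truncate the widened result back to the bf16 result element type.
  %init_res = tensor.empty() : tensor<4xbf16>
  %trunc = linalg.generic {
      indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
      iterator_types = ["parallel"]}
      ins(%reduced : tensor<4xf32>) outs(%init_res : tensor<4xbf16>) {
    ^bb0(%in: f32, %out: bf16):
      %t = arith.truncf %in : f32 to bf16
      linalg.yield %t : bf16
  } -> tensor<4xbf16>
  // Reinsert the reduced axis to recover the TOSA result shape.
  %expanded = tensor.expand_shape %trunc [[0, 1]] output_shape [1, 4]
      : tensor<4xbf16> into tensor<1x4xbf16>
  return %expanded : tensor<1x4xbf16>
}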