diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..8deb208573203 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+  let summary = "Convert ops to destination-passing-style";
+  let description = [{
+    Converts ops that operate on tensors but are not in
+    destination passing style (DPS) to equivalent linalg
+    generic which is in DPS. e.g.
+    ```mlir
+    %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+    ```
+    gets rewritten as:
+    ```mlir
+    %c0 = arith.constant 0 : index
+    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+    %0 = tensor.empty(%dim) : tensor<?xf32>
+    %1 = linalg.generic
+        {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+        ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+      ^bb0(%in: i32, %out: f32):
+        %2 = arith.uitofp %in : i32 to f32
+        linalg.yield %2 : f32
+    } -> tensor<?xf32>
+    ```
+  }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                         tensor::PadOp padOp);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::SIToFPOp op);
+FailureOr<Operation *>
+rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                 arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::DivFOp op);
 
 /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
 /// and linalg.matmul.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..94531ff854593 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
   rewriter.setInsertionPoint(target);
   FailureOr<Operation *> maybeResult =
       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+                arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+                arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
               [&rewriter](auto op) {
                 return rewriteInDestinationPassingStyle(rewriter, op);
               });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..3fec9b8e62cf3 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,20 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -96,7 +103,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   RankedTensorType resultType = padOp.getResultType();
 
-  // Examine the yielded value to decide if a linalg.generic is neede or a
+  // Examine the yielded value to decide if a linalg.generic is needed or a
   // linalg.fill is sufficient.
   Value yieldedValue =
       cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +610,69 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+/// Rewrites an arith op operating on tensors, e.g.
+/// `%z = arith.addf %x, %y : tensor<5xf32>`
+/// into an equivalent linalg.generic in destination-passing-style.
+/// ```mlir
+/// %0 = tensor.empty() : tensor<5xf32>
+/// %1 = linalg.generic ...
+///        ins(%x, %y : tensor<5xf32>, tensor<5xf32>)
+///        outs(%0 : tensor<5xf32>) {
+///   ^bb0(%in: f32, %in_0: f32, %out: f32):
+///     %2 = arith.addf %in, %in_0 : f32
+///     linalg.yield %2 : f32
+///   } -> tensor<5xf32>
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+  // Reject ops such as `arith.constant` and `arith.select`.
+  // constants don't need dps conversion and select is a `todo`.
+  auto numOperands = op->getNumOperands();
+  if (numOperands == 0 || numOperands > 2)
+    return failure();
+
+  // destination passing style rewrite is only for ops on tensor types.
+  Type resultType = op->getResult(0).getType();
+  auto tensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!tensorType)
+    return failure();
+
+  auto loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+  // Create tensor.empty for `outs` of destination-passing-style.
+  Value outs = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+  // Create linalg.generic
+  auto rank = tensorType.getRank();
+  SmallVector<AffineMap> indexingMaps(numOperands + 1,
+                                      rewriter.getMultiDimIdentityMap(rank));
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, tensorType,
+      op->getOperands(), // inputs
+      ValueRange{outs},  // outputs
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value res;
+        if (args.size() == 2) {
+          res =
+              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+                  .getResult();
+        } else if (args.size() == 3) {
+          res = builder.create<OpTy>(loc, args[2].getType(),
+                                     ValueRange{args[0], args[1]});
+        } else
+          llvm_unreachable("did not expect ops other than nary and binary");
+        linalg::YieldOp::create(builder, loc, res);
+      });
+
+  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+  rewriter.eraseOp(op);
+  return genericOp.getOperation();
+}
 
 template <typename OpTy>
 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +682,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
 }
 } // namespace
 
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
+  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
+      RewriterBase &rewriter, OPTY op) {                                       \
+    return rewriteArithInDestinationPassingStyle(rewriter, op);                \
+  }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
 }
+
+namespace {
+struct LinalgConvertToDPSPass
+    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+  using impl::LinalgConvertToDPSPassBase<
+      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+  void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+  RewritePatternSet patterns(&getContext());
+  linalg::populateConvertToDestinationStylePatterns(patterns);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
     transform.yield
   }
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+  return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+  -> tensor<?xf32> {
+  %z = arith.addf %x, %y : tensor<?xf32>
+  return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+    transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN:   -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK: %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK: %{{.*}} = linalg.generic
+// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+  return %0 : tensor<10x!qalias>
+}