[mlir][linalg] convert arith ops to destination-passing-style. #157854
base: main
Converts arith ops that operate on tensors but are not in destination-passing style (DPS) to an equivalent `linalg.generic`, which is in DPS. The new pass `linalg-convert-to-dps` is of general use, but it is specifically useful for `lower-quant-ops`: that pass operates on tensors, and ops like `qcast` generate arith ops on tensors, which cannot bufferize without DPS.

For example,

    %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>

gets rewritten as:

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
    %0 = tensor.empty(%dim) : tensor<?xf32>
    %1 = linalg.generic
        {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
        ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
      ^bb0(%in: i32, %out: f32):
        %2 = arith.uitofp %in : i32 to f32
        linalg.yield %2 : f32
    } -> tensor<?xf32>
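As a rough illustration of what destination-passing style means for this rewrite, here is a minimal sketch in plain C++ (not the MLIR API). The names `makeEmptyLike` and `uitofpDps` are hypothetical and do not appear in the patch. The caller first allocates a destination sized from the input, much as `tensor.dim` plus `tensor.empty` do above, and the elementwise body then writes each result into that destination, mirroring the scalar `arith.uitofp` inside the `linalg.generic` region.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for `tensor.dim` + `tensor.empty`: allocate a
// destination with the same (dynamic) size as the input.
std::vector<float> makeEmptyLike(const std::vector<std::uint32_t> &in) {
  return std::vector<float>(in.size());
}

// Hypothetical stand-in for the `linalg.generic`: a single parallel loop with
// identity indexing that applies the scalar cast and writes into `out`.
void uitofpDps(const std::vector<std::uint32_t> &in, std::vector<float> &out) {
  assert(in.size() == out.size() && "destination must match input shape");
  for (std::size_t i = 0; i < in.size(); ++i)
    out[i] = static_cast<float>(in[i]); // plays the role of scalar arith.uitofp
}
```

The point of the sketch is only that the result buffer is an explicit operand of the computation rather than something the op produces implicitly, which is the property the description says bufferization needs.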
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-quant

Author: Javed Absar (javedabsar1)

Changes

Converts arith ops that operate on tensors but are not in destination-passing style (DPS) to an equivalent linalg.generic, which is in DPS.

Full diff: https://github.com/llvm/llvm-project/pull/157854.diff

6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..365356d3c7d6b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
let dependentDialects = ["linalg::LinalgDialect"];
}
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+ let summary = "Convert ops to destination-passing-style";
+ let description = [{
+ Converts ops that operate on tensors but are not in
+ destination passing style (DPS) to equivalent linalg
+ generic which is in DPS. e.g.
+ ```mlir
+ %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+ ```
+ gets rewritten as:
+ ```mlir
+ %c0 = arith.constant 0 : index
+ %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+ %0 = tensor.empty(%dim) : tensor<?xf32>
+ %1 = linalg.generic
+ {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+ ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+ ^bb0(%in: i32, %out: f32):
+ %2 = arith.uitofp %in : i32 to f32
+ linalg.yield %2 : f32
+ } -> tensor<?xf32>
+ ```
+ }];
+ let dependentDialects = ["linalg::LinalgDialect"];
+}
+
def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
let summary = "Detensorize linalg ops";
let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
tensor::PadOp padOp);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::SIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+ arith::DivFOp op);
+
/// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
/// and linalg.matmul.
///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..b150dc084aaa7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
/// pattern failed to apply. Extra arguments are forwarded to the pattern
/// constructor.
template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
// Check if the given operation has the type expected by the pattern.
- using OpTy = typename llvm::function_traits<
- decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+ using OpTy = typename llvm::function_traits<decltype(
+ &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
auto op = dyn_cast<OpTy>(operation);
if (!op)
return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
rewriter.setInsertionPoint(target);
FailureOr<Operation *> maybeResult =
TypeSwitch<Operation *, FailureOr<Operation *>>(target)
- .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+ .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+ arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+ arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
[&rewriter](auto op) {
return rewriteInDestinationPassingStyle(rewriter, op);
});
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..79f44ff87b3f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-convert-to-dps"
+
using namespace mlir;
using namespace mlir::tensor;
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
OpBuilder::InsertionGuard g(rewriter);
RankedTensorType resultType = padOp.getResultType();
- // Examine the yielded value to decide if a linalg.generic is neede or a
+ // Examine the yielded value to decide if a linalg.generic is needed or a
// linalg.fill is sufficient.
Value yieldedValue =
cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
}
namespace {
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+ // reject ops such as `arith.constant` and `arith.select`.
+ auto numOperands = op->getNumOperands();
+ if (numOperands == 0 || numOperands > 2)
+ return failure();
+
+ // destination passing style rewrite is only for ops on tensor types.
+ Type resultType = op->getResult(0).getType();
+ auto tensorType = dyn_cast<RankedTensorType>(resultType);
+ if (!tensorType)
+ return failure();
+
+ auto loc = op.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+ // Create tensor.empty.
+ Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+ // Create linalg.generic
+ auto rank = tensorType.getRank();
+ SmallVector<AffineMap> indexingMaps(numOperands + 1,
+ rewriter.getMultiDimIdentityMap(rank));
+ SmallVector<utils::IteratorType> iteratorTypes(rank,
+ utils::IteratorType::parallel);
+ auto genericOp = linalg::GenericOp::create(
+ rewriter, loc, tensorType,
+ op->getOperands(), // inputs
+ ValueRange{empty}, // outputs
+ indexingMaps, iteratorTypes,
+ [&](OpBuilder &builder, Location loc, ValueRange args) {
+ Value res;
+ if (args.size() == 2) {
+ res =
+ builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+ .getResult();
+ } else if (args.size() == 3) {
+ res = builder.create<OpTy>(loc, args[2].getType(),
+ ValueRange{args[0], args[1]});
+ } else
+ llvm_unreachable("did not expect ops other than nary and binary");
+ linalg::YieldOp::create(builder, loc, res);
+ });
+
+ rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+ rewriter.eraseOp(op);
+ return genericOp.getOperation();
+}
template <typename OpTy>
LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
} // namespace
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY) \
+ FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle( \
+ RewriterBase &rewriter, OPTY op) { \
+ return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op); \
+ }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
void linalg::populateConvertToDestinationStylePatterns(
RewritePatternSet &patterns) {
patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+ patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
}
+
+namespace {
+struct LinalgConvertToDPSPass
+ : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+ using impl::LinalgConvertToDPSPassBase<
+ LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+ void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+ RewritePatternSet patterns(&getContext());
+ linalg::populateConvertToDestinationStylePatterns(patterns);
+ if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+ signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+// CHECK-SAME: %[[X:.+]]: tensor<64xi32>) -> tensor<64xf32> {
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+ %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+ return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+// CHECK-SAME: %[[X:.+]]: tensor<?xf32>, %[[Y:.+]]: tensor<?xf32>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+// CHECK: %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+// CHECK: %[[GENERIC:.+]] = linalg.generic
+// CHECK-SAME: {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+// CHECK-SAME: ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+// CHECK: ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+// CHECK: %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+// CHECK: linalg.yield %[[z]] : f32
+// CHECK: return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+ -> tensor<?xf32> {
+ %z = arith.addf %x, %y : tensor<?xf32>
+ return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+ : (!transform.any_op) -> !transform.any_op
+ transform.structured.rewrite_in_destination_passing_style %0
+ : (!transform.any_op) -> !transform.any_op
+ transform.yield
+ }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN: -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME: %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG: %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG: %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK: %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK: %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME: outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK: %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME: ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK: %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK: %{{.*}} = linalg.generic
+// CHECK-SAME: ins(%[[ADD]] : tensor<10xf32>)
+// CHECK: %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+ %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+ return %0 : tensor<10x!qalias>
+}
You can test this locally with the following command:

    git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

View the diff from clang-format here.

diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 2fc864594..58de12a4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -653,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
auto genericOp = linalg::GenericOp::create(
rewriter, loc, tensorType,
op->getOperands(), // inputs
- ValueRange{outs}, // outputs
+ ValueRange{outs}, // outputs
indexingMaps, iteratorTypes,
[&](OpBuilder &builder, Location loc, ValueRange args) {
Value res;
Thanks Javed!
This sounds like a very good justification, but could you share a repro so that we can see what the issue is? In particular, it would be good to make sure that we are not missing some other, less involved, solution to this. This could be either here or as a GitHub issue if it's more involved.
Hm, isn't this saying "rewrite arith Ops on tensors as a linalg.generic"? As in, DPS is just a by-product/justification for this rewrite (as opposed to being the end goal)?

EDIT: OK, sorry, I've just realised that you are adding this in
Why do we go to `linalg.generic` instead of `linalg.elementwise` (and potentially generalize those later)? Are you worried about support for fast math flags etc.?
rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
// reject ops such as `arith.constant` and `arith.select`.
auto numOperands = op->getNumOperands();
if (numOperands == 0 || numOperands > 2)
Why?
only unary and binary we care about.
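The guard under discussion can be summarized in a small, self-contained sketch. This is plain C++ with a hypothetical helper `indexingMapCount` that is not part of the patch: ops with zero operands or more than two are rejected, and an accepted op gets one identity indexing map per operand plus one for the destination, which matches the two maps in the unary test and the three maps in the binary test.

```cpp
#include <cstddef>
#include <optional>

// Hypothetical helper mirroring the operand-count check in
// rewriteArithInDestinationPassingStyle: only unary and binary ops are
// accepted. On success it returns how many identity indexing maps the
// resulting linalg.generic carries (one per input plus one for the output).
std::optional<std::size_t> indexingMapCount(std::size_t numOperands) {
  if (numOperands == 0 || numOperands > 2)
    return std::nullopt; // e.g. arith.constant (0) or arith.select (3)
  return numOperands + 1;
}
```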
Thanks @banach-space for the review.
Not limited to lower-quant but here is an example. When we run
We could just use
answered above.
addressed all review comments.
: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
Can you add a test case that contains fast math flags?