Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

javedabsar1
Copy link
Contributor

@javedabsar1 javedabsar1 commented Sep 10, 2025

Converts arith ops that operate on tensors but are not in destination passing style (DPS) to equivalent linalg generic which is in DPS. This new pass linalg-convert-to-dps has general use, but specifically is useful for loewr-quant-ops which operate on tensors and ops like qcast generates arith ops on tensors which without dps cannot bufferize.
e.g.
%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
gets rewritten as:

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
    %0 = tensor.empty(%dim) : tensor<?xf32>
    %1 = linalg.generic
          {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
          ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
           ^bb0(%in: i32, %out: f32):
             %2 = arith.uitofp %in : i32 to f32
             linalg.yield %2 : f32
    } -> tensor<?xf32>

Converts arith ops that operate on tensors but are not in destination
passing style (DPS) to equivalent linalg generic which is in DPS.
This new pass `linalg-convert-to-dps` has general use, but specifically
is useful for loewr-quant-ops which operate on tensors and ops like qcast
generates arith ops on tensors which without dps cannot bufferize.

 e.g.
    `%0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>`
    gets rewritten as:

    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
    %0 = tensor.empty(%dim) : tensor<?xf32>
    %1 = linalg.generic
          {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
          ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
           ^bb0(%in: i32, %out: f32):
             %2 = arith.uitofp %in : i32 to f32
             linalg.yield %2 : f32
    } -> tensor<?xf32>
@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-quant

Author: Javed Absar (javedabsar1)

Changes

Converts arith ops that operate on tensors but are not in destination passing style (DPS) to equivalent linalg generic which is in DPS. This new pass linalg-convert-to-dps has general use, but specifically is useful for loewr-quant-ops which operate on tensors and ops like qcast generates arith ops on tensors which without dps cannot bufferize.

e.g.
%0 = arith.uitofp %arg0 : tensor&lt;?xi32&gt; to tensor&lt;?xf32&gt;
gets rewritten as:

%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor&lt;?xi32&gt;
%0 = tensor.empty(%dim) : tensor&lt;?xf32&gt;
%1 = linalg.generic
      {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%arg0 : tensor&lt;?xi32&gt;) outs(%0 : tensor&lt;?xf32&gt;) {
       ^bb0(%in: i32, %out: f32):
         %2 = arith.uitofp %in : i32 to f32
         linalg.yield %2 : f32
} -&gt; tensor&lt;?xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/157854.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+26)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+17)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+104-1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir (+61)
  • (added) mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir (+26)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..365356d3c7d6b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+  let summary = "Convert ops to destination-passing-style";
+  let description = [{
+    Converts ops that operate on tensors but are not in
+    destination passing style (DPS) to equivalent linalg
+    generic which is in DPS. e.g.
+    ```mlir
+      %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+    ```
+    gets rewritten as:
+    ```mlir
+       %c0 = arith.constant 0 : index
+       %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+       %0 = tensor.empty(%dim) : tensor<?xf32>
+       %1 = linalg.generic
+           {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+           ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+           ^bb0(%in: i32, %out: f32):
+             %2 = arith.uitofp %in : i32 to f32
+             linalg.yield %2 : f32
+       } -> tensor<?xf32>
+     ```
+   }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                         tensor::PadOp padOp);
 
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::SIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::DivFOp op);
+
 /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
 /// and linalg.matmul.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..b150dc084aaa7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
 /// pattern failed to apply. Extra arguments are forwarded to the pattern
 /// constructor.
 template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<
-      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<decltype(
+      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
   rewriter.setInsertionPoint(target);
   FailureOr<Operation *> maybeResult =
       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+                arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+                arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
               [&rewriter](auto op) {
                 return rewriteInDestinationPassingStyle(rewriter, op);
               });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..79f44ff87b3f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-convert-to-dps"
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   RankedTensorType resultType = padOp.getResultType();
 
-  // Examine the yielded value to decide if a linalg.generic is neede or a
+  // Examine the yielded value to decide if a linalg.generic is needed or a
   // linalg.fill is sufficient.
   Value yieldedValue =
       cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+  // reject ops such as `arith.constant` and `arith.select`.
+  auto numOperands = op->getNumOperands();
+  if (numOperands == 0 || numOperands > 2)
+    return failure();
+
+  // destination passing style rewrite is only for ops on tensor types.
+  Type resultType = op->getResult(0).getType();
+  auto tensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!tensorType)
+    return failure();
+
+  auto loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+  // Create tensor.empty.
+  Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+  // Create linalg.generic
+  auto rank = tensorType.getRank();
+  SmallVector<AffineMap> indexingMaps(numOperands + 1,
+                                      rewriter.getMultiDimIdentityMap(rank));
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, tensorType,
+      op->getOperands(), // inputs
+      ValueRange{empty}, // outputs
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value res;
+        if (args.size() == 2) {
+          res =
+              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+                  .getResult();
+        } else if (args.size() == 3) {
+          res = builder.create<OpTy>(loc, args[2].getType(),
+                                     ValueRange{args[0], args[1]});
+        } else
+          llvm_unreachable("did not expect ops other than nary and binary");
+        linalg::YieldOp::create(builder, loc, res);
+      });
+
+  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+  rewriter.eraseOp(op);
+  return genericOp.getOperation();
+}
 
 template <typename OpTy>
 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
 
 } // namespace
 
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
+  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
+      RewriterBase &rewriter, OPTY op) {                                       \
+    return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op);          \
+  }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
 }
+
+namespace {
+struct LinalgConvertToDPSPass
+    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+  using impl::LinalgConvertToDPSPassBase<
+      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+  void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+  RewritePatternSet patterns(&getContext());
+  linalg::populateConvertToDestinationStylePatterns(patterns);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+//  CHECK-SAME:   %[[X:.+]]: tensor<64xi32>) ->  tensor<64xf32> {
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+  return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+//  CHECK-SAME:   %[[X:.+]]: tensor<?xf32>,  %[[Y:.+]]: tensor<?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %z = arith.addf %x, %y : tensor<?xf32>
+  return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN:             -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME:    %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG:     %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG:     %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK:         %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK:         %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME:                             outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK:         %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK:         %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK:         %{{.*}} = linalg.generic
+// CHECK-SAME:       ins(%[[ADD]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+  return %0 : tensor<10x!qalias>
+}

@llvmbot
Copy link
Member

llvmbot commented Sep 10, 2025

@llvm/pr-subscribers-mlir

Author: Javed Absar (javedabsar1)

Changes

Converts arith ops that operate on tensors but are not in destination passing style (DPS) to equivalent linalg generic which is in DPS. This new pass linalg-convert-to-dps has general use, but specifically is useful for loewr-quant-ops which operate on tensors and ops like qcast generates arith ops on tensors which without dps cannot bufferize.

e.g.
%0 = arith.uitofp %arg0 : tensor&lt;?xi32&gt; to tensor&lt;?xf32&gt;
gets rewritten as:

%c0 = arith.constant 0 : index
%dim = tensor.dim %arg0, %c0 : tensor&lt;?xi32&gt;
%0 = tensor.empty(%dim) : tensor&lt;?xf32&gt;
%1 = linalg.generic
      {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%arg0 : tensor&lt;?xi32&gt;) outs(%0 : tensor&lt;?xf32&gt;) {
       ^bb0(%in: i32, %out: f32):
         %2 = arith.uitofp %in : i32 to f32
         linalg.yield %2 : f32
} -&gt; tensor&lt;?xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/157854.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Passes.td (+26)
  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+17)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+6-4)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp (+104-1)
  • (modified) mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir (+61)
  • (added) mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir (+26)
diff --git a/mlir/include/mlir/Dialect/Linalg/Passes.td b/mlir/include/mlir/Dialect/Linalg/Passes.td
index 44da2965e6892..365356d3c7d6b 100644
--- a/mlir/include/mlir/Dialect/Linalg/Passes.td
+++ b/mlir/include/mlir/Dialect/Linalg/Passes.td
@@ -171,6 +171,32 @@ def LinalgFoldIntoElementwisePass : Pass<"linalg-fold-into-elementwise"> {
   let dependentDialects = ["linalg::LinalgDialect"];
 }
 
+def LinalgConvertToDPSPass : Pass<"linalg-convert-to-dps"> {
+  let summary = "Convert ops to destination-passing-style";
+  let description = [{
+    Converts ops that operate on tensors but are not in
+    destination passing style (DPS) to equivalent linalg
+    generic which is in DPS. e.g.
+    ```mlir
+      %0 = arith.uitofp %arg0 : tensor<?xi32> to tensor<?xf32>
+    ```
+    gets rewritten as:
+    ```mlir
+       %c0 = arith.constant 0 : index
+       %dim = tensor.dim %arg0, %c0 : tensor<?xi32>
+       %0 = tensor.empty(%dim) : tensor<?xf32>
+       %1 = linalg.generic
+           {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
+           ins(%arg0 : tensor<?xi32>) outs(%0 : tensor<?xf32>) {
+           ^bb0(%in: i32, %out: f32):
+             %2 = arith.uitofp %in : i32 to f32
+             linalg.yield %2 : f32
+       } -> tensor<?xf32>
+     ```
+   }];
+  let dependentDialects = ["linalg::LinalgDialect"];
+}
+
 def LinalgDetensorizePass : InterfacePass<"linalg-detensorize", "FunctionOpInterface"> {
   let summary = "Detensorize linalg ops";
   let dependentDialects = [];
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 0cfc8821c0add..c0d492cf69492 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1377,6 +1377,23 @@ rewriteInDestinationPassingStyle(RewriterBase &rewriter,
 FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
                                                         tensor::PadOp padOp);
 
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::UIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::SIToFPOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToUIOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::FPToSIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddIOp op);
+
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::AddFOp op);
+FailureOr<Operation *> rewriteInDestinationPassingStyle(RewriterBase &rewriter,
+                                                        arith::DivFOp op);
+
 /// Convert linalg.conv_2d_nhwc_hwcf into linalg.generic (for img2col packing)
 /// and linalg.matmul.
 ///
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..b150dc084aaa7 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -58,10 +58,10 @@ using namespace mlir::transform;
 /// pattern failed to apply. Extra arguments are forwarded to the pattern
 /// constructor.
 template <typename PatternTy, typename... Args>
-static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&... args) {
   // Check if the given operation has the type expected by the pattern.
-  using OpTy = typename llvm::function_traits<
-      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  using OpTy = typename llvm::function_traits<decltype(
+      &PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
   auto op = dyn_cast<OpTy>(operation);
   if (!op)
     return failure();
@@ -2611,7 +2611,9 @@ transform::RewriteInDestinationPassingStyleOp::applyToOne(
   rewriter.setInsertionPoint(target);
   FailureOr<Operation *> maybeResult =
       TypeSwitch<Operation *, FailureOr<Operation *>>(target)
-          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp>(
+          .Case<tensor::FromElementsOp, tensor::GenerateOp, tensor::PadOp,
+                arith::UIToFPOp, arith::SIToFPOp, arith::FPToUIOp,
+                arith::FPToSIOp, arith::AddIOp, arith::AddFOp, arith::DivFOp>(
               [&rewriter](auto op) {
                 return rewriteInDestinationPassingStyle(rewriter, op);
               });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 76ddee4f0e9cf..79f44ff87b3f6 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -17,13 +17,22 @@
 #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
 #include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
+#include "mlir/Dialect/Linalg/Passes.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
 #include "mlir/IR/Matchers.h"
 #include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/ADT/STLExtras.h"
 
+namespace mlir {
+#define GEN_PASS_DEF_LINALGCONVERTTODPSPASS
+#include "mlir/Dialect/Linalg/Passes.h.inc"
+} // namespace mlir
+
+#define DEBUG_TYPE "linalg-convert-to-dps"
+
 using namespace mlir;
 using namespace mlir::tensor;
 
@@ -96,7 +105,7 @@ static Operation *movePaddingToFillOrGenericOp(RewriterBase &rewriter,
   OpBuilder::InsertionGuard g(rewriter);
   RankedTensorType resultType = padOp.getResultType();
 
-  // Examine the yielded value to decide if a linalg.generic is neede or a
+  // Examine the yielded value to decide if a linalg.generic is needed or a
   // linalg.fill is sufficient.
   Value yieldedValue =
       cast<tensor::YieldOp>(padOp.getBody()->getTerminator()).getValue();
@@ -603,6 +612,56 @@ Value linalg::bufferizeToAllocation(
 }
 
 namespace {
+template <typename OpTy>
+FailureOr<Operation *>
+rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
+  // reject ops such as `arith.constant` and `arith.select`.
+  auto numOperands = op->getNumOperands();
+  if (numOperands == 0 || numOperands > 2)
+    return failure();
+
+  // destination passing style rewrite is only for ops on tensor types.
+  Type resultType = op->getResult(0).getType();
+  auto tensorType = dyn_cast<RankedTensorType>(resultType);
+  if (!tensorType)
+    return failure();
+
+  auto loc = op.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  auto dynSizes = reifyOrComputeDynamicSizes(rewriter, op->getOperand(0));
+
+  // Create tensor.empty.
+  Value empty = tensor::EmptyOp::create(rewriter, loc, resultType, dynSizes);
+
+  // Create linalg.generic
+  auto rank = tensorType.getRank();
+  SmallVector<AffineMap> indexingMaps(numOperands + 1,
+                                      rewriter.getMultiDimIdentityMap(rank));
+  SmallVector<utils::IteratorType> iteratorTypes(rank,
+                                                 utils::IteratorType::parallel);
+  auto genericOp = linalg::GenericOp::create(
+      rewriter, loc, tensorType,
+      op->getOperands(), // inputs
+      ValueRange{empty}, // outputs
+      indexingMaps, iteratorTypes,
+      [&](OpBuilder &builder, Location loc, ValueRange args) {
+        Value res;
+        if (args.size() == 2) {
+          res =
+              builder.create<OpTy>(loc, args[1].getType(), ValueRange{args[0]})
+                  .getResult();
+        } else if (args.size() == 3) {
+          res = builder.create<OpTy>(loc, args[2].getType(),
+                                     ValueRange{args[0], args[1]});
+        } else
+          llvm_unreachable("did not expect ops other than nary and binary");
+        linalg::YieldOp::create(builder, loc, res);
+      });
+
+  rewriter.replaceAllUsesWith(op, genericOp.getResult(0));
+  rewriter.eraseOp(op);
+  return genericOp.getOperation();
+}
 
 template <typename OpTy>
 LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
@@ -612,9 +671,53 @@ LogicalResult rewriteOpInDestinationPassingStyle(OpTy op,
 
 } // namespace
 
+#define STAMP_OUT_ARITH_DPS_FUNCS(OPTY)                                        \
+  FailureOr<Operation *> linalg::rewriteInDestinationPassingStyle(             \
+      RewriterBase &rewriter, OPTY op) {                                       \
+    return rewriteArithInDestinationPassingStyle<OPTY>(rewriter, op);          \
+  }
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::UIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::SIToFPOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToUIOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::FPToSIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddIOp)
+
+STAMP_OUT_ARITH_DPS_FUNCS(arith::AddFOp)
+STAMP_OUT_ARITH_DPS_FUNCS(arith::DivFOp)
+
 void linalg::populateConvertToDestinationStylePatterns(
     RewritePatternSet &patterns) {
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::FromElementsOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::GenerateOp>);
   patterns.add(rewriteOpInDestinationPassingStyle<tensor::PadOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::UIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::SIToFPOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToUIOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::FPToSIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddIOp>);
+
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::AddFOp>);
+  patterns.add(rewriteOpInDestinationPassingStyle<arith::DivFOp>);
 }
+
+namespace {
+struct LinalgConvertToDPSPass
+    : public impl::LinalgConvertToDPSPassBase<LinalgConvertToDPSPass> {
+  using impl::LinalgConvertToDPSPassBase<
+      LinalgConvertToDPSPass>::LinalgConvertToDPSPassBase;
+
+  void runOnOperation() override;
+};
+
+void LinalgConvertToDPSPass::runOnOperation() {
+
+  RewritePatternSet patterns(&getContext());
+  linalg::populateConvertToDestinationStylePatterns(patterns);
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
+    signalPassFailure();
+}
+} // namespace
diff --git a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
index 63c9f1f27517b..a1df34c6555f2 100644
--- a/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
+++ b/mlir/test/Dialect/Linalg/transform-op-rewrite-in-destination-passing-style.mlir
@@ -252,3 +252,64 @@ module attributes {transform.with_named_sequence} {
       transform.yield
   }
 }
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_unary_op(
+//  CHECK-SAME:   %[[X:.+]]: tensor<64xi32>) ->  tensor<64xf32> {
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty() : tensor<64xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X:.+]] : tensor<64xi32>) outs(%[[EMPTY]] : tensor<64xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: i32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.uitofp %[[x]] : i32 to f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<64xf32>
+
+func.func @arith_unary_op(%x : tensor<64xi32>) -> tensor<64xf32> {
+  %z = arith.uitofp %x : tensor<64xi32> to tensor<64xf32>
+  return %z : tensor<64xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.uitofp"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
+
+// -----
+
+// CHECK: #[[$map:.*]] = affine_map<(d0) -> (d0)>
+// CHECK-LABEL: func @arith_binop(
+//  CHECK-SAME:   %[[X:.+]]: tensor<?xf32>,  %[[Y:.+]]: tensor<?xf32>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[DIM:.+]] = tensor.dim %[[X]], %[[C0]] : tensor<?xf32>
+//       CHECK:   %[[EMPTY:.+]] = tensor.empty(%[[DIM]]) : tensor<?xf32>
+//       CHECK:   %[[GENERIC:.+]] = linalg.generic
+//  CHECK-SAME:       {indexing_maps = [#[[$map]], #[[$map]], #[[$map]]], iterator_types = ["parallel"]}
+//  CHECK-SAME:       ins(%[[X:.+]], %[[Y:.+]] : tensor<?xf32>, tensor<?xf32>) outs(%[[EMPTY]] : tensor<?xf32>) {
+//       CHECK:     ^bb0(%[[x:.+]]: f32, %[[y:.+]]: f32, %[[Out:.+]]: f32):
+//       CHECK:        %[[z:.+]] = arith.addf %[[x]], %[[y]] : f32
+//       CHECK:        linalg.yield %[[z]] : f32
+//       CHECK:  return %[[GENERIC]] : tensor<?xf32>
+
+func.func @arith_binop(%x : tensor<?xf32>, %y : tensor<?xf32>)
+    -> tensor<?xf32> {
+  %z = arith.addf %x, %y : tensor<?xf32>
+  return %z : tensor<?xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["arith.addf"]} in %arg1
+      : (!transform.any_op) -> !transform.any_op
+    transform.structured.rewrite_in_destination_passing_style %0
+      : (!transform.any_op) -> !transform.any_op
+      transform.yield
+  }
+}
diff --git a/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
new file mode 100644
index 0000000000000..0fc9f1e3ed9be
--- /dev/null
+++ b/mlir/test/Dialect/Quant/lower-quant-ops-to-dps.mlir
@@ -0,0 +1,26 @@
+// RUN: mlir-opt %s -lower-quant-ops -linalg-convert-to-dps \
+// RUN:             -linalg-specialize-generic-ops -cse | FileCheck %s
+
+// CHECK-LABEL: func.func @lower_qcast_to_dps(
+// CHECK-SAME:    %[[X:.+]]: tensor<10xf32>) -> tensor<10x!quant.uniform<i8:f32, 2.000000e+00:10>>
+// CHECK-DAG:     %[[CST_10I:.+]] = arith.constant dense<10> : tensor<10xi8>
+// CHECK-DAG:     %[[CST_2F:.+]] = arith.constant dense<2.000000e+00> : tensor<10xf32>
+// CHECK:         %[[E:.+]] = tensor.empty() : tensor<10xf32>
+// CHECK:         %[[DIV:.+]] = linalg.div ins(%[[X]], %[[CST_2F]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK-SAME:                             outs(%[[E]] : tensor<10xf32>) -> tensor<10xf32>
+//
+// CHECK:         %[[SITOFP:.+]] = linalg.generic
+// CHECK-SAME:       ins(%[[CST_10I]] : tensor<10xi8>) outs(%[[E]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.sitofp %{{.*}} : i8 to f32
+//
+// CHECK:         %[[ADD:.+]] = linalg.add ins(%[[DIV]], %[[SITOFP]] : tensor<10xf32>, tensor<10xf32>)
+// CHECK:         %{{.*}} = linalg.generic
+// CHECK-SAME:       ins(%[[ADD]] : tensor<10xf32>)
+// CHECK:            %{{.*}} = arith.fptosi %{{.*}} : f32 to i8
+
+
+!qalias = !quant.uniform<i8:f32, 2.0:10>
+func.func @lower_qcast_to_dps(%arg0: tensor<10xf32>) -> tensor<10x!qalias> {
+  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
+  return %0 : tensor<10x!qalias>
+}

@rengolin rengolin requested a review from kuhar September 10, 2025 13:16
Copy link

github-actions bot commented Sep 10, 2025

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff origin/main HEAD --extensions cpp,h -- mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp

⚠️
The reproduction instructions above might return results for more than one PR
in a stack if you are using a stacked PR workflow. You can limit the results by
changing origin/main to the base branch/commit you want to compare against.
⚠️

View the diff from clang-format here.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
index 2fc864594..58de12a4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ConvertToDestinationStyle.cpp
@@ -653,7 +653,7 @@ rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
   auto genericOp = linalg::GenericOp::create(
       rewriter, loc, tensorType,
       op->getOperands(), // inputs
-      ValueRange{outs}, // outputs
+      ValueRange{outs},  // outputs
       indexingMaps, iteratorTypes,
       [&](OpBuilder &builder, Location loc, ValueRange args) {
         Value res;

@banach-space
Copy link
Contributor

banach-space commented Sep 10, 2025

Thanks Javed!

specifically is useful for loewr-quant-ops which operate on tensors and ops like qcast generates arith ops on tensors which without dps cannot bufferize

This sounds like a very good justification, but could you share a repro so that we can see what the issue is? In particular, it would be good to make sure that we are not missing some other, less involved, solution to this. This could be either here or as a GitHub issue if it's more involved.

Converts arith ops that operate on tensors but are not in destination passing style (DPS) to equivalent linalg generic which is in DPS.

Hm, isn't this saying "rewrite arith Ops on tensors as a linalg.generic"? As in, DPS is just a by-product/justification for this rewrite (as opposed to being the end goal)?

EDIT OK, sorry, I've just realised that you are adding this in ConvertToDestinationStyle.cpp, which explains the naming. A bit counter-intuitive to me, but oh well.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we go to linalg.generic instead of linalg.elementwise (and potentially generalize those later)? Are you worried about support for fast math flags etc.?

rewriteArithInDestinationPassingStyle(RewriterBase &rewriter, OpTy op) {
// reject ops such as `arith.constant` and `arith.select`.
auto numOperands = op->getNumOperands();
if (numOperands == 0 || numOperands > 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

only unary and binary we care about.

@javedabsar1
Copy link
Contributor Author

but could you share a repro so that we can see what the issue is?

Thanks @banach-space for the review.
Reproducer, yes here it is (also @matthias-springer is aware of this) -

$ cat repro.mlir
!qalias = !quant.uniform<i8:f32, 2.0:10>
func.func @reproducer(%arg0: tensor<10xf32>) -> tensor<10xf32> {
  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>
  %1 = quant.dcast %0 :  tensor<10x!qalias> to  tensor<10xf32>
  return %1 :  tensor<10xf32>
}

Not limited to lower-quant but here is an example. When we run

$ mlir-opt -lower-quant-ops -one-shot-bufferize repro.mlir
repro.mlir:3:8: error: op was not bufferized
  %0 = quant.qcast %arg0 : tensor<10xf32> to tensor<10x!qalias>                                                                                                     ^
repro.mlir:3:8: note: see current operation: %4 = "arith.divf"(%arg0, %3) <{fastmath = #arith.fastmath<none>}> : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>

@javedabsar1
Copy link
Contributor Author

Why do we go to linalg.generic instead of linalg.elementwise (and potentially generalize those later)? Are you worried about support for fast math flags etc.?

We could just use linalg-morph-ops=generic-to-named etc.

@javedabsar1
Copy link
Contributor Author

Why do we go to linalg.generic instead of linalg.elementwise (and potentially generalize those later)? Are you worried about support for fast math flags etc.?

@javedabsar1
Copy link
Contributor Author

Why do we go to linalg.generic instead of linalg.elementwise (and potentially generalize those later)? Are you worried about support for fast math flags etc.?

answered above.

@javedabsar1 javedabsar1 reopened this Sep 13, 2025
@javedabsar1
Copy link
Contributor Author

addressed all review comments.

: (!transform.any_op) -> !transform.any_op
transform.yield
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case that contains fast math flags?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants