Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Conversation

lialan
Copy link
Member

@lialan lialan commented Sep 9, 2025

This commit:

  • Introduces a new InParallelOpInterface, along with the ParallelCombiningOpInterface, represent the parallel updating operations we have in a parallel loop of scf.forall.
  • Change the name of ParallelCombiningOpInterface to InParallelOpInterface as the naming was quite confusing.
  • ParallelCombiningOpInterface now is used to generalize operations that insert into shared tensors within parallel combining regions. Previously, only tensor.parallel_insert_slice was supported directly in scf.InParallelOp regions.
  • tensor.parallel_insert_slice now implements ParallelCombiningOpInterface.

This change enables future extensions to support additional parallel combining operations beyond tensor.parallel_insert_slice, which have different update semantics, so the in_parallel region can correctly and safely represent these kinds of operation without potential mistakes such as races.

Author credits: @qedawkins

@llvmbot
Copy link
Member

llvmbot commented Sep 9, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir-tensor

Author: Alan Li (lialan)

Changes

This commit introduces the InParallelOpInterface to generalize operations that insert into shared tensors within parallel combining regions. Previously, only tensor.parallel_insert_slice was supported directly in scf.InParallelOp regions.

  • Add InParallelOpInterface with getUpdatedDestinations() and getIteratingParent() methods to identify operations that update shared tensors in parallel regions
  • Update ParallelCombiningOpInterface verification to require only ops implementing InParallelOpInterface in its region
  • Move verification logic from hardcoded tensor.parallel_insert_slice checks to interface-based validation, also in places such as DropUnitDims and bufferizer to depend on interface rather than specific op.
  • tensor.ParallelInsertSliceOp now implements InParallelOpInterface

This change enables future extensions to support additional parallel combining operations beyond tensor.parallel_insert_slice, which have different update semantics, so the in_parallel region can correctly and safely represent these kinds of operation without potential mistakes such as races.

author credits: @qedawkins


Patch is 29.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157736.diff

15 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+2)
  • (modified) mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h (+1)
  • (modified) mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td (+29)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+9-5)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp (+6-2)
  • (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+26-29)
  • (modified) mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp (+3-3)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+16-6)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+3-1)
  • (modified) mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp (+4-2)
  • (modified) mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp (+33-4)
  • (modified) mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir (+25-1)
  • (modified) mlir/test/Dialect/SCF/invalid.mlir (+4-6)
  • (modified) mlir/test/Dialect/SCF/one-shot-bufferize-analysis.mlir (+108)
  • (modified) mlir/test/Dialect/SCF/one-shot-bufferize-tensor-copy-insertion.mlir (+29-1)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7d396e5c64c28..842a76e8fe90f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,6 +1474,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
 def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
        AttrSizedOperandSegments,
        OffsetSizeAndStrideOpInterface,
+       DeclareOpInterfaceMethods<InParallelOpInterface,
+          ["getUpdatedDestinations", "getIteratingParent"]>,
        // TODO: Cannot use an interface here atm, verify this manually for now.
        // HasParent<"ParallelCombiningOpInterface">
   ]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 72db06163df37..e3441b8322d96 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -20,6 +20,7 @@ namespace mlir {
 namespace detail {
 // TODO: Single region single block interface on interfaces ?
 LogicalResult verifyParallelCombiningOpInterface(Operation *op);
+LogicalResult verifyInParallelOpInterface(Operation *op);
 } // namespace detail
 } // namespace mlir
 
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 424b4cf0a0a58..86eaf2c95462c 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -56,4 +56,33 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
   }];
 }
 
+def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
+  let description = [{
+    An in_parallel op is an operation that inserts into a shared tensor in
+    conjunction with a parent combining and iterating op.
+  }];
+  let cppNamespace = "::mlir";
+
+  let methods = [
+    InterfaceMethod<[{
+        Returns the list of values updated by this op.
+      }],
+      /*retTy=*/"::mlir::MutableOperandRange",
+      /*methodName=*/"getUpdatedDestinations",
+      /*args=*/(ins)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Returns the iterating parent for this op.
+      }],
+      /*retTy=*/"::mlir::Operation*",
+      /*methodName=*/"getIteratingParent",
+      /*args=*/(ins)
+    >,
+  ];
+  let verify = [{
+    return ::mlir::detail::verifyInParallelOpInterface($_op);
+  }];
+}
+
 #endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..840737fdb836b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
 #include "mlir/IR/BuiltinTypeInterfaces.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
 #include "mlir/Interfaces/TilingInterface.h"
 #include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4140,11 +4141,14 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
     return DiagnosedSilenceableFailure::success();
   }
 
-  // If we are inside an InParallel region, temporarily set the insertion point
-  // outside: only tensor.parallel_insert_slice ops are allowed in there.
-  if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
-    rewriter.setInsertionPoint(
-        target->template getParentOfType<scf::InParallelOp>());
+  // If we are inside an ParallelCombiningOp region, temporarily set the
+  // insertion point outside: only ops implementing InParallelOpInterface are
+  // allowed in there.
+  if (isa<mlir::InParallelOpInterface>(target.getOperation())) {
+    if (auto combiningParent =
+            dyn_cast<ParallelCombiningOpInterface>(target->getParentOp())) {
+      rewriter.setInsertionPoint(target->getParentOp());
+    }
   }
 
   Value extracted = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9eea88fb5a837 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -784,8 +784,12 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
       // The only difference between InsertSliceOp and ParallelInsertSliceOp
       // is the insertion point is just before the ParallelCombiningOp in the
       // parallel case.
-      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
-        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value) {
+        if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
+                insertSliceOp->getParentOp())) {
+          rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+        }
+      }
       reshapedSource = tensor::CollapseShapeOp::create(
           rewriter, loc, insertSliceOp.getSource(), *reassociation);
     }
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 84f9777a443fd..873dbbde48b37 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -680,8 +680,11 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
   SmallVector<Value> results;
   results.reserve(forallOp.getResults().size());
   for (auto &yieldingOp : terminator.getYieldingOps()) {
+    // Skip non-ParallelInsertSliceOp operations
     auto parallelInsertSliceOp =
-        cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+        dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+    if (!parallelInsertSliceOp)
+      continue;
 
     Value dst = parallelInsertSliceOp.getDest();
     Value src = parallelInsertSliceOp.getSource();
@@ -1437,14 +1440,12 @@ InParallelOp ForallOp::getTerminator() {
   return cast<InParallelOp>(getBody()->getTerminator());
 }
 
+
 SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
   SmallVector<Operation *> storeOps;
-  InParallelOp inParallelOp = getTerminator();
-  for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
-    if (auto parallelInsertSliceOp =
-            dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
-        parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
-      storeOps.push_back(parallelInsertSliceOp);
+  for (Operation *user : bbArg.getUsers()) {
+    if (auto parallelOp = dyn_cast<InParallelOpInterface>(user)) {
+      storeOps.push_back(parallelOp);
     }
   }
   return storeOps;
@@ -1673,7 +1674,12 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
     for (OpResult result : forallOp.getResults()) {
       OpOperand *opOperand = forallOp.getTiedOpOperand(result);
       BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
-      if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
+      SmallVector<Operation *> combiningOps =
+          forallOp.getCombiningOps(blockArg);
+      if ((result.use_empty() &&
+           llvm::all_of(combiningOps,
+                        [](Operation *op) { return op->use_empty(); })) ||
+          combiningOps.empty()) {
         resultToDelete.insert(result);
       } else {
         resultToReplace.push_back(result);
@@ -1911,8 +1917,9 @@ struct FoldTensorCastOfOutputIntoForallOp
     auto terminator = newForallOp.getTerminator();
     for (auto [yieldingOp, outputBlockArg] : llvm::zip(
              terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
-      auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
-      insertSliceOp.getDestMutable().assign(outputBlockArg);
+      auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+      if (insertSliceOp)
+        insertSliceOp.getDestMutable().assign(outputBlockArg);
     }
 
     // Cast results back to the original types.
@@ -1971,19 +1978,6 @@ LogicalResult InParallelOp::verify() {
   if (!forallOp)
     return this->emitOpError("expected forall op parent");
 
-  // TODO: InParallelOpInterface.
-  for (Operation &op : getRegion().front().getOperations()) {
-    if (!isa<tensor::ParallelInsertSliceOp>(op)) {
-      return this->emitOpError("expected only ")
-             << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
-    }
-
-    // Verify that inserts are into out block arguments.
-    Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
-    ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
-    if (!llvm::is_contained(regionOutArgs, dest))
-      return op.emitOpError("may only insert into an output block argument");
-  }
   return success();
 }
 
@@ -2018,12 +2012,15 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
 }
 
 SmallVector<BlockArgument> InParallelOp::getDests() {
-  return llvm::to_vector<4>(
-      llvm::map_range(getYieldingOps(), [](Operation &op) {
-        // Add new ops here as needed.
-        auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
-        return llvm::cast<BlockArgument>(insertSliceOp.getDest());
-      }));
+  SmallVector<BlockArgument> updatedDests;
+  for (auto &yieldingOp : getYieldingOps()) {
+    auto inParallelOp = dyn_cast<InParallelOpInterface>(&yieldingOp);
+    if (!inParallelOp)
+      continue;
+    for (auto &updatedOperand : inParallelOp.getUpdatedDestinations())
+      updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
+  }
+  return updatedDests;
 }
 
 llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index a44612410bdee..d70392131df51 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -40,8 +40,8 @@ namespace {
 ///   <implicit in_parallel terminator here>
 /// }
 /// ```
-struct InParallelOpInterface
-    : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
+struct InParallelDeallocOpInterface
+    : public BufferDeallocationOpInterface::ExternalModel<InParallelDeallocOpInterface,
                                                           scf::InParallelOp> {
   FailureOr<Operation *> process(Operation *op, DeallocationState &state,
                                  const DeallocationOptions &options) const {
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
 void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
     DialectRegistry &registry) {
   registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
-    InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+    InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
     ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
   });
 }
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 68584ec4fd814..3770690c21a03 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2978,7 +2978,7 @@ class InsertSliceOpConstantArgumentFolder final
       // The only difference between InsertSliceOp and ParallelInsertSliceOp
       // is that the insertion point is just before the ParallelCombiningOp in
       // the parallel case.
-      if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+      if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
         rewriter.setInsertionPoint(insertSliceOp->getParentOp());
       toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                         sourceType, toInsert);
@@ -3155,7 +3155,7 @@ struct InsertSliceOpSourceCastInserter final
     // The only difference between InsertSliceOp and ParallelInsertSliceOp is
     // that the insertion point is just before the ParallelCombiningOp in the
     // parallel case.
-    if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+    if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
       rewriter.setInsertionPoint(insertSliceOp->getParentOp());
     Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
                                         newSrcType, insertSliceOp.getSource());
@@ -3901,10 +3901,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
 }
 
 LogicalResult ParallelInsertSliceOp::verify() {
-  if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
-    return this->emitError("expected ParallelCombiningOpInterface parent, got:")
-           << *(getOperation()->getParentOp());
-
   // Verify result type against inferred type.
   RankedTensorType expectedType;
   SliceVerificationResult result =
@@ -3935,6 +3931,20 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
   return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
 }
 
+// InParallelOpInterface implementation
+MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
+  return getDestMutable();
+}
+
+Operation *ParallelInsertSliceOp::getIteratingParent() {
+  // Return the parent ParallelCombiningOpInterface's parent
+  if (auto combiningOp = dyn_cast<ParallelCombiningOpInterface>(
+          getOperation()->getParentOp())) {
+    return combiningOp->getParentOp();
+  }
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // ScatterOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c3356c1e4b9d8..def56687477db 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -974,7 +974,9 @@ struct ParallelInsertSliceOpInterface
         parallelInsertSliceOp.getParallelCombiningParent();
 
     // Bufferize the op outside of the parallel combining terminator.
-    rewriter.setInsertionPoint(parallelCombiningParent);
+    if (parallelCombiningParent) {
+      rewriter.setInsertionPoint(parallelCombiningParent);
+    }
 
     // Get source and destination buffers.
     FailureOr<Value> destBuffer =
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index d76c02af7ab16..0c0380a370d56 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -219,8 +219,10 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
     // point outside: only tensor.parallel_insert_slice ops are allowed in
     // there.
     if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
-      rewriter.setInsertionPoint(
-          insertSliceOp->template getParentOfType<scf::InParallelOp>());
+      if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
+              insertSliceOp->getParentOp())) {
+        rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+      }  
     }
 
     // Resolve offsets according to source offsets and strides.
diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
index 2b6703543bbd3..30fcbf0ab3be6 100644
--- a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
+++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
@@ -10,18 +10,47 @@
 
 using namespace mlir;
 
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
+
 //===----------------------------------------------------------------------===//
-// ParallelCombiningOpInterface
+// InParallelOpInterface
 //===----------------------------------------------------------------------===//
 
+// TODO: Catch-22 with interface methods used to verify means methods can't
+// assume the impl is valid.
+LogicalResult mlir::detail::verifyInParallelOpInterface(Operation *op) {
+  auto inParallel = cast<InParallelOpInterface>(op);
+  auto parent = inParallel.getIteratingParent();
+  if (!parent) {
+    return op->emitError(
+        "in_parallel interface op must have an iterating parent");
+  }
+
+  // Simple verification without requiring ParallelIterationOpInterface
+  // Just check that updated destinations are block arguments
+  for (OpOperand &updatedValue : inParallel.getUpdatedDestinations()) {
+    auto bbArg = dyn_cast<BlockArgument>(updatedValue.get());
+    if (!bbArg) {
+      return op->emitError("updating a non block argument");
+    }
+  }
+  return success();
+}
+
+
+//===----------------------------------------------------------------------===//
+// ParallelCombiningOpInterface
+//===----------------------------------------------------------------------===//
 // TODO: Single region single block interface on interfaces ?
 LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
   if (op->getNumRegions() != 1)
     return op->emitError("expected single region op");
   if (!op->getRegion(0).hasOneBlock())
     return op->emitError("expected single block op region");
+  for (Operation &child : *op->getRegion(0).getBlocks().begin()) {
+    if (!isa<InParallelOpInterface>(&child))
+      return op->emitError("expected only in_parallel interface ops");
+  }
   return success();
 }
-
-/// Include the definitions of the interface.
-#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5f42938244db6..d498f30289fa4 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -915,7 +915,7 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te
 
 // -----
 
-func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+func.func @parallel_insert_slice() -> tensor<4x2xf32> {
   %c2 = arith.constant 2 : index
   %c4 = arith.constant 4 : index
   %cst = arith.constant 0.000000e+00 : f32
@@ -923,6 +923,7 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
   %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
     %1 = tensor.empty() : tensor<1x1xf32>
     %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    // CHECK: scf.forall.in_parallel
     scf.forall.in_parallel {
       //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
       // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
@@ -935,6 +936,29 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
 
 // -----
 
+// CHECK-LABEL: func @parallel_insert_slice_no_terminator
+func.func @parallel_insert_slice_no_terminator() -> tensor<4x2xf32> {
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+  %cst = arith.constant 0.000000e+00 : f32
+  %0 = tensor.empty() : tensor<4x2xf32>
+  // CHECK: scf.forall
+  %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+    %1 = tensor.empty() : tensor<1x1xf32>
+    %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+    //      CHECK: scf.forall.in_parallel
+    //      CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
+    // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
+    scf.forall.in_parallel {
+      tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+        tensor<1x1xf32> into tensor<4x2xf32>
+    }
+  }
+  return %res: tensor<4x2xf32>
+}
+
+// -----
+
 #map0 = affine_map<(i, j) -> (i, j)>
 #access = [#map0, #map0]
 #trait = {
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index bb7958083e55c..d8455b47f6b1d 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Diale...
[truncated]

Copy link

github-actions bot commented Sep 9, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Contributor

@qedawkins qedawkins left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Top level naming comment, I think the InParallelOpInterface and ParallelCombiningOpInterface are confusingly assigned to the opposite of what we'd expect. scf.forall.in_parallel not being an InParallelOpInterface is surprising so we should swap the names.

@lialan lialan force-pushed the rebase-inparallel-interface branch from 05d8701 to 7634fac Compare September 10, 2025 15:10
@lialan lialan force-pushed the rebase-inparallel-interface branch from 2701105 to feb5939 Compare September 10, 2025 19:29
Copy link
Collaborator

@joker-eph joker-eph left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That LGTM, but someone else should like also have a look at these interfaces changes, so I'd be more comfortable with another approval.

Thanks!

@llvmbot llvmbot added the bazel "Peripheral" support tier build system: utils/bazel label Sep 11, 2025
Copy link
Contributor

@hanhanW hanhanW left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just some nits.

@lialan lialan force-pushed the rebase-inparallel-interface branch from 8462f2c to 4eaa746 Compare September 12, 2025 20:57
@lialan lialan force-pushed the rebase-inparallel-interface branch from 4eaa746 to abb402f Compare September 12, 2025 21:12
@lialan lialan merged commit b87f1b2 into llvm:main Sep 12, 2025
10 checks passed
@lialan lialan deleted the rebase-inparallel-interface branch September 12, 2025 21:23
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bazel "Peripheral" support tier build system: utils/bazel mlir:linalg mlir:scf mlir:tensor mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants