-
Notifications
You must be signed in to change notification settings - Fork 15k
[MLIR] Add InParallelOpInterface
for parallel combining operations
#157736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tensor
Author: Alan Li (lialan)
Changes: This commit introduces the `InParallelOpInterface` for parallel combining operations. This change enables future extensions to support additional parallel combining operations beyond `tensor.parallel_insert_slice`.
Author credits: @qedawkins
Patch is 29.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/157736.diff
15 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index 7d396e5c64c28..842a76e8fe90f 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1474,6 +1474,8 @@ def Tensor_PadOp : Tensor_Op<"pad", [
def Tensor_ParallelInsertSliceOp : Tensor_Op<"parallel_insert_slice", [
AttrSizedOperandSegments,
OffsetSizeAndStrideOpInterface,
+ DeclareOpInterfaceMethods<InParallelOpInterface,
+ ["getUpdatedDestinations", "getIteratingParent"]>,
// TODO: Cannot use an interface here atm, verify this manually for now.
// HasParent<"ParallelCombiningOpInterface">
]> {
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
index 72db06163df37..e3441b8322d96 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.h
@@ -20,6 +20,7 @@ namespace mlir {
namespace detail {
// TODO: Single region single block interface on interfaces ?
LogicalResult verifyParallelCombiningOpInterface(Operation *op);
+LogicalResult verifyInParallelOpInterface(Operation *op);
} // namespace detail
} // namespace mlir
diff --git a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
index 424b4cf0a0a58..86eaf2c95462c 100644
--- a/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
+++ b/mlir/include/mlir/Interfaces/ParallelCombiningOpInterface.td
@@ -56,4 +56,33 @@ def ParallelCombiningOpInterface : OpInterface<"ParallelCombiningOpInterface"> {
}];
}
+def InParallelOpInterface : OpInterface<"InParallelOpInterface"> {
+ let description = [{
+ An in_parallel op is an operation that inserts into a shared tensor in
+ conjunction with a parent combining and iterating op.
+ }];
+ let cppNamespace = "::mlir";
+
+ let methods = [
+ InterfaceMethod<[{
+ Returns the list of values updated by this op.
+ }],
+ /*retTy=*/"::mlir::MutableOperandRange",
+ /*methodName=*/"getUpdatedDestinations",
+ /*args=*/(ins)
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Returns the iterating parent for this op.
+ }],
+ /*retTy=*/"::mlir::Operation*",
+ /*methodName=*/"getIteratingParent",
+ /*args=*/(ins)
+ >,
+ ];
+ let verify = [{
+ return ::mlir::detail::verifyInParallelOpInterface($_op);
+ }];
+}
+
#endif // MLIR_INTERFACES_PARALLELCOMBININGOPINTERFACE
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f0c1f4485b054..840737fdb836b 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -36,6 +36,7 @@
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ParallelCombiningOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -4140,11 +4141,14 @@ DiagnosedSilenceableFailure doit(RewriterBase &rewriter, OpTy target,
return DiagnosedSilenceableFailure::success();
}
- // If we are inside an InParallel region, temporarily set the insertion point
- // outside: only tensor.parallel_insert_slice ops are allowed in there.
- if constexpr (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
- rewriter.setInsertionPoint(
- target->template getParentOfType<scf::InParallelOp>());
+ // If we are inside an ParallelCombiningOp region, temporarily set the
+ // insertion point outside: only ops implementing InParallelOpInterface are
+ // allowed in there.
+ if (isa<mlir::InParallelOpInterface>(target.getOperation())) {
+ if (auto combiningParent =
+ dyn_cast<ParallelCombiningOpInterface>(target->getParentOp())) {
+ rewriter.setInsertionPoint(target->getParentOp());
+ }
}
Value extracted = tensor::ExtractSliceOp::create(
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
index 22690daa4f9e1..9eea88fb5a837 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -784,8 +784,12 @@ struct RankReducedInsertSliceOp : public OpRewritePattern<InsertOpTy> {
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is the insertion point is just before the ParallelCombiningOp in the
// parallel case.
- if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value)
- rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ if (std::is_same<InsertOpTy, tensor::ParallelInsertSliceOp>::value) {
+ if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
+ insertSliceOp->getParentOp())) {
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ }
+ }
reshapedSource = tensor::CollapseShapeOp::create(
rewriter, loc, insertSliceOp.getSource(), *reassociation);
}
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 84f9777a443fd..873dbbde48b37 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -680,8 +680,11 @@ void mlir::scf::promote(RewriterBase &rewriter, scf::ForallOp forallOp) {
SmallVector<Value> results;
results.reserve(forallOp.getResults().size());
for (auto &yieldingOp : terminator.getYieldingOps()) {
+ // Skip non-ParallelInsertSliceOp operations
auto parallelInsertSliceOp =
- cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ if (!parallelInsertSliceOp)
+ continue;
Value dst = parallelInsertSliceOp.getDest();
Value src = parallelInsertSliceOp.getSource();
@@ -1437,14 +1440,12 @@ InParallelOp ForallOp::getTerminator() {
return cast<InParallelOp>(getBody()->getTerminator());
}
+
SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) {
SmallVector<Operation *> storeOps;
- InParallelOp inParallelOp = getTerminator();
- for (Operation &yieldOp : inParallelOp.getYieldingOps()) {
- if (auto parallelInsertSliceOp =
- dyn_cast<tensor::ParallelInsertSliceOp>(yieldOp);
- parallelInsertSliceOp && parallelInsertSliceOp.getDest() == bbArg) {
- storeOps.push_back(parallelInsertSliceOp);
+ for (Operation *user : bbArg.getUsers()) {
+ if (auto parallelOp = dyn_cast<InParallelOpInterface>(user)) {
+ storeOps.push_back(parallelOp);
}
}
return storeOps;
@@ -1673,7 +1674,12 @@ struct ForallOpIterArgsFolder : public OpRewritePattern<ForallOp> {
for (OpResult result : forallOp.getResults()) {
OpOperand *opOperand = forallOp.getTiedOpOperand(result);
BlockArgument blockArg = forallOp.getTiedBlockArgument(opOperand);
- if (result.use_empty() || forallOp.getCombiningOps(blockArg).empty()) {
+ SmallVector<Operation *> combiningOps =
+ forallOp.getCombiningOps(blockArg);
+ if ((result.use_empty() &&
+ llvm::all_of(combiningOps,
+ [](Operation *op) { return op->use_empty(); })) ||
+ combiningOps.empty()) {
resultToDelete.insert(result);
} else {
resultToReplace.push_back(result);
@@ -1911,8 +1917,9 @@ struct FoldTensorCastOfOutputIntoForallOp
auto terminator = newForallOp.getTerminator();
for (auto [yieldingOp, outputBlockArg] : llvm::zip(
terminator.getYieldingOps(), newForallOp.getRegionIterArgs())) {
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(yieldingOp);
- insertSliceOp.getDestMutable().assign(outputBlockArg);
+ auto insertSliceOp = dyn_cast<tensor::ParallelInsertSliceOp>(yieldingOp);
+ if (insertSliceOp)
+ insertSliceOp.getDestMutable().assign(outputBlockArg);
}
// Cast results back to the original types.
@@ -1971,19 +1978,6 @@ LogicalResult InParallelOp::verify() {
if (!forallOp)
return this->emitOpError("expected forall op parent");
- // TODO: InParallelOpInterface.
- for (Operation &op : getRegion().front().getOperations()) {
- if (!isa<tensor::ParallelInsertSliceOp>(op)) {
- return this->emitOpError("expected only ")
- << tensor::ParallelInsertSliceOp::getOperationName() << " ops";
- }
-
- // Verify that inserts are into out block arguments.
- Value dest = cast<tensor::ParallelInsertSliceOp>(op).getDest();
- ArrayRef<BlockArgument> regionOutArgs = forallOp.getRegionOutArgs();
- if (!llvm::is_contained(regionOutArgs, dest))
- return op.emitOpError("may only insert into an output block argument");
- }
return success();
}
@@ -2018,12 +2012,15 @@ OpResult InParallelOp::getParentResult(int64_t idx) {
}
SmallVector<BlockArgument> InParallelOp::getDests() {
- return llvm::to_vector<4>(
- llvm::map_range(getYieldingOps(), [](Operation &op) {
- // Add new ops here as needed.
- auto insertSliceOp = cast<tensor::ParallelInsertSliceOp>(&op);
- return llvm::cast<BlockArgument>(insertSliceOp.getDest());
- }));
+ SmallVector<BlockArgument> updatedDests;
+ for (auto &yieldingOp : getYieldingOps()) {
+ auto inParallelOp = dyn_cast<InParallelOpInterface>(&yieldingOp);
+ if (!inParallelOp)
+ continue;
+ for (auto &updatedOperand : inParallelOp.getUpdatedDestinations())
+ updatedDests.push_back(cast<BlockArgument>(updatedOperand.get()));
+ }
+ return updatedDests;
}
llvm::iterator_range<Block::iterator> InParallelOp::getYieldingOps() {
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
index a44612410bdee..d70392131df51 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferDeallocationOpInterfaceImpl.cpp
@@ -40,8 +40,8 @@ namespace {
/// <implicit in_parallel terminator here>
/// }
/// ```
-struct InParallelOpInterface
- : public BufferDeallocationOpInterface::ExternalModel<InParallelOpInterface,
+struct InParallelDeallocOpInterface
+ : public BufferDeallocationOpInterface::ExternalModel<InParallelDeallocOpInterface,
scf::InParallelOp> {
FailureOr<Operation *> process(Operation *op, DeallocationState &state,
const DeallocationOptions &options) const {
@@ -75,7 +75,7 @@ struct ReduceReturnOpInterface
void mlir::scf::registerBufferDeallocationOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, SCFDialect *dialect) {
- InParallelOp::attachInterface<InParallelOpInterface>(*ctx);
+ InParallelOp::attachInterface<InParallelDeallocOpInterface>(*ctx);
ReduceReturnOp::attachInterface<ReduceReturnOpInterface>(*ctx);
});
}
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 68584ec4fd814..3770690c21a03 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -2978,7 +2978,7 @@ class InsertSliceOpConstantArgumentFolder final
// The only difference between InsertSliceOp and ParallelInsertSliceOp
// is that the insertion point is just before the ParallelCombiningOp in
// the parallel case.
- if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
toInsert = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
sourceType, toInsert);
@@ -3155,7 +3155,7 @@ struct InsertSliceOpSourceCastInserter final
// The only difference between InsertSliceOp and ParallelInsertSliceOp is
// that the insertion point is just before the ParallelCombiningOp in the
// parallel case.
- if (std::is_same<InsertOpTy, ParallelInsertSliceOp>::value)
+ if (isa<ParallelCombiningOpInterface>(insertSliceOp->getParentOp()))
rewriter.setInsertionPoint(insertSliceOp->getParentOp());
Value cast = tensor::CastOp::create(rewriter, insertSliceOp.getLoc(),
newSrcType, insertSliceOp.getSource());
@@ -3901,10 +3901,6 @@ void ParallelInsertSliceOp::build(OpBuilder &b, OperationState &result,
}
LogicalResult ParallelInsertSliceOp::verify() {
- if (!isa<ParallelCombiningOpInterface>(getOperation()->getParentOp()))
- return this->emitError("expected ParallelCombiningOpInterface parent, got:")
- << *(getOperation()->getParentOp());
-
// Verify result type against inferred type.
RankedTensorType expectedType;
SliceVerificationResult result =
@@ -3935,6 +3931,20 @@ llvm::SmallBitVector ParallelInsertSliceOp::getDroppedDims() {
return ::getDroppedDims(getSourceType().getShape(), getMixedSizes());
}
+// InParallelOpInterface implementation
+MutableOperandRange ParallelInsertSliceOp::getUpdatedDestinations() {
+ return getDestMutable();
+}
+
+Operation *ParallelInsertSliceOp::getIteratingParent() {
+ // Return the parent ParallelCombiningOpInterface's parent
+ if (auto combiningOp = dyn_cast<ParallelCombiningOpInterface>(
+ getOperation()->getParentOp())) {
+ return combiningOp->getParentOp();
+ }
+ return nullptr;
+}
+
//===----------------------------------------------------------------------===//
// ScatterOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index c3356c1e4b9d8..def56687477db 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -974,7 +974,9 @@ struct ParallelInsertSliceOpInterface
parallelInsertSliceOp.getParallelCombiningParent();
// Bufferize the op outside of the parallel combining terminator.
- rewriter.setInsertionPoint(parallelCombiningParent);
+ if (parallelCombiningParent) {
+ rewriter.setInsertionPoint(parallelCombiningParent);
+ }
// Get source and destination buffers.
FailureOr<Value> destBuffer =
diff --git a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
index d76c02af7ab16..0c0380a370d56 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/FoldTensorSubsetOps.cpp
@@ -219,8 +219,10 @@ struct InsertSliceOfInsertSliceFolder : public OpRewritePattern<OpTy> {
// point outside: only tensor.parallel_insert_slice ops are allowed in
// there.
if (std::is_same_v<OpTy, tensor::ParallelInsertSliceOp>) {
- rewriter.setInsertionPoint(
- insertSliceOp->template getParentOfType<scf::InParallelOp>());
+ if (auto combiningParent = dyn_cast<ParallelCombiningOpInterface>(
+ insertSliceOp->getParentOp())) {
+ rewriter.setInsertionPoint(insertSliceOp->getParentOp());
+ }
}
// Resolve offsets according to source offsets and strides.
diff --git a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
index 2b6703543bbd3..30fcbf0ab3be6 100644
--- a/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
+++ b/mlir/lib/Interfaces/ParallelCombiningOpInterface.cpp
@@ -10,18 +10,47 @@
using namespace mlir;
+/// Include the definitions of the interface.
+#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
+
//===----------------------------------------------------------------------===//
-// ParallelCombiningOpInterface
+// InParallelOpInterface
//===----------------------------------------------------------------------===//
+// TODO: Catch-22 with interface methods used to verify means methods can't
+// assume the impl is valid.
+LogicalResult mlir::detail::verifyInParallelOpInterface(Operation *op) {
+ auto inParallel = cast<InParallelOpInterface>(op);
+ auto parent = inParallel.getIteratingParent();
+ if (!parent) {
+ return op->emitError(
+ "in_parallel interface op must have an iterating parent");
+ }
+
+ // Simple verification without requiring ParallelIterationOpInterface
+ // Just check that updated destinations are block arguments
+ for (OpOperand &updatedValue : inParallel.getUpdatedDestinations()) {
+ auto bbArg = dyn_cast<BlockArgument>(updatedValue.get());
+ if (!bbArg) {
+ return op->emitError("updating a non block argument");
+ }
+ }
+ return success();
+}
+
+
+//===----------------------------------------------------------------------===//
+// ParallelCombiningOpInterface
+//===----------------------------------------------------------------------===//
// TODO: Single region single block interface on interfaces ?
LogicalResult mlir::detail::verifyParallelCombiningOpInterface(Operation *op) {
if (op->getNumRegions() != 1)
return op->emitError("expected single region op");
if (!op->getRegion(0).hasOneBlock())
return op->emitError("expected single block op region");
+ for (Operation &child : *op->getRegion(0).getBlocks().begin()) {
+ if (!isa<InParallelOpInterface>(&child))
+ return op->emitError("expected only in_parallel interface ops");
+ }
return success();
}
-
-/// Include the definitions of the interface.
-#include "mlir/Interfaces/ParallelCombiningOpInterface.cpp.inc"
diff --git a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
index 5f42938244db6..d498f30289fa4 100644
--- a/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
+++ b/mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -915,7 +915,7 @@ func.func @sparse_case(%arg0: tensor<8x8xf32, #CSR>, %arg1: tensor<8xf32>) -> te
// -----
-func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
+func.func @parallel_insert_slice() -> tensor<4x2xf32> {
%c2 = arith.constant 2 : index
%c4 = arith.constant 4 : index
%cst = arith.constant 0.000000e+00 : f32
@@ -923,6 +923,7 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
%res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
%1 = tensor.empty() : tensor<1x1xf32>
%2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ // CHECK: scf.forall.in_parallel
scf.forall.in_parallel {
// CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
// CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
@@ -935,6 +936,29 @@ func.func @reduce_dispatch_0() -> tensor<4x2xf32> {
// -----
+// CHECK-LABEL: func @parallel_insert_slice_no_terminator
+func.func @parallel_insert_slice_no_terminator() -> tensor<4x2xf32> {
+ %c2 = arith.constant 2 : index
+ %c4 = arith.constant 4 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<4x2xf32>
+ // CHECK: scf.forall
+ %res = scf.forall (%arg0, %arg1) in (%c4, %c2) shared_outs(%o = %0) -> (tensor<4x2xf32>) {
+ %1 = tensor.empty() : tensor<1x1xf32>
+ %2 = linalg.fill ins(%cst : f32) outs(%1 : tensor<1x1xf32>) -> tensor<1x1xf32>
+ // CHECK: scf.forall.in_parallel
+ // CHECK: tensor.parallel_insert_slice %{{[0-9a-z]*}} into %{{[0-9a-z]*}}
+ // CHECK-SAME: [%{{.*}}, %{{.*}}] [1, 1] [1, 1] : tensor<f32> into tensor<4x2xf32>
+ scf.forall.in_parallel {
+ tensor.parallel_insert_slice %2 into %o[%arg0, %arg1] [1, 1] [1, 1] :
+ tensor<1x1xf32> into tensor<4x2xf32>
+ }
+ }
+ return %res: tensor<4x2xf32>
+}
+
+// -----
+
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index bb7958083e55c..d8455b47f6b1d 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Diale...
[truncated]
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Top-level naming comment: I think the `InParallelOpInterface` and `ParallelCombiningOpInterface` names are confusingly assigned to the opposite of what we'd expect. `scf.forall.in_parallel` not being an `InParallelOpInterface` is surprising, so we should swap the names.
mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
Outdated
Show resolved
Hide resolved
05d8701
to
7634fac
Compare
2701105
to
feb5939
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That LGTM, but someone else should likely also have a look at these interface changes, so I'd be more comfortable with another approval.
Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some nits.
8462f2c
to
4eaa746
Compare
4eaa746
to
abb402f
Compare
This commit:
- Introduces `InParallelOpInterface`, which, along with the `ParallelCombiningOpInterface`, represents the parallel updating operations we have in a parallel loop of `scf.forall`.
- Renames `ParallelCombiningOpInterface` to `InParallelOpInterface`, as the naming was quite confusing.
- `ParallelCombiningOpInterface` is now used to generalize operations that insert into shared tensors within parallel combining regions. Previously, only `tensor.parallel_insert_slice` was supported directly in `scf.InParallelOp` regions.
- `tensor.parallel_insert_slice` now implements `ParallelCombiningOpInterface`.

This change enables future extensions to support additional parallel combining operations beyond `tensor.parallel_insert_slice`, which have different update semantics, so the `in_parallel` region can correctly and safely represent these kinds of operations without potential mistakes such as races.

Author credits: @qedawkins