-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR #137841
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: users/tblah/omp-cancel-codegen-3
Are you sure you want to change the base?
[mlir][OpenMP] cancel(lation point) taskgroup LLVMIR #137841
Conversation
A cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup. I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path.
@llvm/pr-subscribers-mlir-llvm @llvm/pr-subscribers-mlir Author: Tom Eccles (tblah) ChangesA cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup. I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path. Full diff: https://github.com/llvm/llvm-project/pull/137841.diff 5 Files Affected:
diff --git a/flang/docs/OpenMPSupport.md b/flang/docs/OpenMPSupport.md
index 2d4b9dd737777..a535866d665ac 100644
--- a/flang/docs/OpenMPSupport.md
+++ b/flang/docs/OpenMPSupport.md
@@ -51,8 +51,8 @@ Note : No distinction is made between the support in Parser/Semantics, MLIR, Low
| depend clause | P | depend clause with array sections are not supported |
| declare reduction construct | N | |
| atomic construct extensions | Y | |
-| cancel construct | N | |
-| cancellation point construct | N | |
+| cancel construct | Y | |
+| cancellation point construct | Y | |
| parallel do simd construct | P | linear clause is not supported |
| target teams construct | P | device and reduction clauses are not supported |
| teams distribute construct | P | reduction and dist_schedule clauses not supported |
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 11ac5e5bf1e7f..c42397de1125b 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -158,12 +158,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.getBare())
result = todo("ompx_bare");
};
- auto checkCancelDirective = [&todo](auto op, LogicalResult &result) {
- omp::ClauseCancellationConstructType cancelledDirective =
- op.getCancelDirective();
- if (cancelledDirective == omp::ClauseCancellationConstructType::Taskgroup)
- result = todo("cancel directive construct type not yet supported");
- };
auto checkDepend = [&todo](auto op, LogicalResult &result) {
if (!op.getDependVars().empty() || op.getDependKinds())
result = todo("depend");
@@ -254,10 +248,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
- .Case([&](omp::CancelOp op) { checkCancelDirective(op, result); })
- .Case([&](omp::CancellationPointOp op) {
- checkCancelDirective(op, result);
- })
.Case([&](omp::DistributeOp op) {
checkAllocate(op, result);
checkDistSchedule(op, result);
@@ -1902,6 +1892,55 @@ buildDependData(std::optional<ArrayAttr> dependKinds, OperandRange dependVars,
}
}
+/// Shared implementation of a callback which adds a termiator for the new block
+/// created for the branch taken when an openmp construct is cancelled. The
+/// terminator is saved in \p cancelTerminators. This callback is invoked only
+/// if there is cancellation inside of the taskgroup body.
+/// The terminator will need to be fixed to branch to the correct block to
+/// cleanup the construct.
+static void
+pushCancelFinalizationCB(SmallVectorImpl<llvm::BranchInst *> &cancelTerminators,
+ llvm::IRBuilderBase &llvmBuilder,
+ llvm::OpenMPIRBuilder &ompBuilder, mlir::Operation *op,
+ llvm::omp::Directive cancelDirective) {
+ auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
+ llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
+
+ // ip is currently in the block branched to if cancellation occured.
+ // We need to create a branch to terminate that block.
+ llvmBuilder.restoreIP(ip);
+
+ // We must still clean up the construct after cancelling it, so we need to
+ // branch to the block that finalizes the taskgroup.
+ // That block has not been created yet so use this block as a dummy for now
+ // and fix this after creating the operation.
+ cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
+ return llvm::Error::success();
+ };
+ // We have to add the cleanup to the OpenMPIRBuilder before the body gets
+ // created in case the body contains omp.cancel (which will then expect to be
+ // able to find this cleanup callback).
+ ompBuilder.pushFinalizationCB(
+ {finiCB, cancelDirective, constructIsCancellable(op)});
+}
+
+/// If we cancelled the construct, we should branch to the finalization block of
+/// that construct. OMPIRBuilder structures the CFG such that the cleanup block
+/// is immediately before the continuation block. Now this finalization has
+/// been created we can fix the branch.
+static void
+popCancelFinalizationCB(const ArrayRef<llvm::BranchInst *> cancelTerminators,
+ llvm::OpenMPIRBuilder &ompBuilder,
+ const llvm::OpenMPIRBuilder::InsertPointTy &afterIP) {
+ ompBuilder.popFinalizationCB();
+ llvm::BasicBlock *constructFini = afterIP.getBlock()->getSinglePredecessor();
+ for (llvm::BranchInst *cancelBranch : cancelTerminators) {
+ assert(cancelBranch->getNumSuccessors() == 1 &&
+ "cancel branch should have one target");
+ cancelBranch->setSuccessor(0, constructFini);
+ }
+}
+
namespace {
/// TaskContextStructManager takes care of creating and freeing a structure
/// containing information needed by the task body to execute.
@@ -2215,6 +2254,14 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
return llvm::Error::success();
};
+ llvm::OpenMPIRBuilder &ompBuilder = *moduleTranslation.getOpenMPBuilder();
+ SmallVector<llvm::BranchInst *> cancelTerminators;
+ // The directive to match here is OMPD_taskgroup because it is the taskgroup
+ // which is canceled. This is handled here because it is the task's cleanup
+ // block which should be branched to.
+ pushCancelFinalizationCB(cancelTerminators, builder, ompBuilder, taskOp,
+ llvm::omp::Directive::OMPD_taskgroup);
+
SmallVector<llvm::OpenMPIRBuilder::DependData> dds;
buildDependData(taskOp.getDependKinds(), taskOp.getDependVars(),
moduleTranslation, dds);
@@ -2232,6 +2279,9 @@ convertOmpTaskOp(omp::TaskOp taskOp, llvm::IRBuilderBase &builder,
if (failed(handleError(afterIP, *taskOp)))
return failure();
+ // Set the correct branch target for task cancellation
+ popCancelFinalizationCB(cancelTerminators, ompBuilder, afterIP.get());
+
builder.restoreIP(*afterIP);
return success();
}
@@ -2362,28 +2412,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
: llvm::omp::WorksharingLoopType::ForStaticLoop;
SmallVector<llvm::BranchInst *> cancelTerminators;
- // This callback is invoked only if there is cancellation inside of the wsloop
- // body.
- auto finiCB = [&](llvm::OpenMPIRBuilder::InsertPointTy ip) -> llvm::Error {
- llvm::IRBuilderBase &llvmBuilder = ompBuilder->Builder;
- llvm::IRBuilderBase::InsertPointGuard guard(llvmBuilder);
-
- // ip is currently in the block branched to if cancellation occured.
- // We need to create a branch to terminate that block.
- llvmBuilder.restoreIP(ip);
-
- // We must still clean up the wsloop after cancelling it, so we need to
- // branch to the block that finalizes the wsloop.
- // That block has not been created yet so use this block as a dummy for now
- // and fix this after creating the wsloop.
- cancelTerminators.push_back(llvmBuilder.CreateBr(ip.getBlock()));
- return llvm::Error::success();
- };
- // We have to add the cleanup to the OpenMPIRBuilder before the body gets
- // created in case the body contains omp.cancel (which will then expect to be
- // able to find this cleanup callback).
- ompBuilder->pushFinalizationCB({finiCB, llvm::omp::Directive::OMPD_for,
- constructIsCancellable(wsloopOp)});
+ pushCancelFinalizationCB(cancelTerminators, builder, *ompBuilder, wsloopOp,
+ llvm::omp::Directive::OMPD_for);
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::Expected<llvm::BasicBlock *> regionBlock = convertOmpOpRegions(
@@ -2406,18 +2436,8 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
if (failed(handleError(wsloopIP, opInst)))
return failure();
- ompBuilder->popFinalizationCB();
- if (!cancelTerminators.empty()) {
- // If we cancelled the loop, we should branch to the finalization block of
- // the wsloop (which is always immediately before the loop continuation
- // block). Now the finalization has been created, we can fix the branch.
- llvm::BasicBlock *wsloopFini = wsloopIP->getBlock()->getSinglePredecessor();
- for (llvm::BranchInst *cancelBranch : cancelTerminators) {
- assert(cancelBranch->getNumSuccessors() == 1 &&
- "cancel branch should have one target");
- cancelBranch->setSuccessor(0, wsloopFini);
- }
- }
+ // Set the correct branch target for task cancellation
+ popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get());
// Process the reductions if required.
if (failed(createReductionsAndCleanup(
@@ -3075,9 +3095,6 @@ convertOmpCancel(omp::CancelOp op, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- if (failed(checkImplementationStatus(*op.getOperation())))
- return failure();
-
llvm::Value *ifCond = nullptr;
if (Value ifVar = op.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
@@ -3103,9 +3120,6 @@ convertOmpCancellationPoint(omp::CancellationPointOp op,
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
- if (failed(checkImplementationStatus(*op.getOperation())))
- return failure();
-
llvm::omp::Directive cancelledDirective =
convertCancellationConstructType(op.getCancelDirective());
diff --git a/mlir/test/Target/LLVMIR/openmp-cancel.mlir b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
index 3c195a98d1000..21241702ad569 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancel.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancel.mlir
@@ -243,3 +243,51 @@ llvm.func @cancel_wsloop_if(%lb : i32, %ub : i32, %step : i32, %cond : i1) {
// CHECK: ret void
// CHECK: .cncl: ; preds = %[[VAL_44]]
// CHECK: br label %[[VAL_38]]
+
+omp.private {type = firstprivate} @i32_priv : i32 copy {
+^bb0(%arg0: !llvm.ptr, %arg1: !llvm.ptr):
+ %0 = llvm.load %arg0 : !llvm.ptr -> i32
+ llvm.store %0, %arg1 : i32, !llvm.ptr
+ omp.yield(%arg1 : !llvm.ptr)
+}
+
+llvm.func @do_something(!llvm.ptr)
+
+llvm.func @cancel_taskgroup(%arg0: !llvm.ptr) {
+ omp.taskgroup {
+// Using firstprivate clause so we have some end of task cleanup to branch to
+// after the cancellation.
+ omp.task private(@i32_priv %arg0 -> %arg1 : !llvm.ptr) {
+ omp.cancel cancellation_construct_type(taskgroup)
+ llvm.call @do_something(%arg1) : (!llvm.ptr) -> ()
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK-LABEL: define internal void @cancel_taskgroup..omp_par(
+// CHECK: task.alloca:
+// CHECK: %[[VAL_21:.*]] = load ptr, ptr %[[VAL_22:.*]], align 8
+// CHECK: %[[VAL_23:.*]] = getelementptr { ptr }, ptr %[[VAL_21]], i32 0, i32 0
+// CHECK: %[[VAL_24:.*]] = load ptr, ptr %[[VAL_23]], align 8, !align !1
+// CHECK: br label %[[VAL_25:.*]]
+// CHECK: task.body: ; preds = %[[VAL_26:.*]]
+// CHECK: %[[VAL_27:.*]] = getelementptr { i32 }, ptr %[[VAL_24]], i32 0, i32 0
+// CHECK: br label %[[VAL_28:.*]]
+// CHECK: omp.task.region: ; preds = %[[VAL_25]]
+// CHECK: %[[VAL_29:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[VAL_30:.*]] = call i32 @__kmpc_cancel(ptr @1, i32 %[[VAL_29]], i32 4)
+// CHECK: %[[VAL_31:.*]] = icmp eq i32 %[[VAL_30]], 0
+// CHECK: br i1 %[[VAL_31]], label %omp.task.region.split, label %omp.task.region.cncl
+// CHECK: omp.task.region.cncl:
+// CHECK: br label %omp.region.cont2
+// CHECK: omp.region.cont2:
+// Both cancellation and normal paths reach the end-of-task cleanup:
+// CHECK: tail call void @free(ptr %[[VAL_24]])
+// CHECK: br label %task.exit.exitStub
+// CHECK: omp.task.region.split:
+// CHECK: call void @do_something(ptr %[[VAL_27]])
+// CHECK: br label %omp.region.cont2
+// CHECK: task.exit.exitStub:
+// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
index bbb313c113567..5e0d3f9f7e293 100644
--- a/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-cancellation-point.mlir
@@ -186,3 +186,33 @@ llvm.func @cancellation_point_wsloop(%lb : i32, %ub : i32, %step : i32) {
// CHECK: ret void
// CHECK: omp.loop_nest.region.cncl: ; preds = %[[VAL_100]]
// CHECK: br label %[[VAL_96]]
+
+
+llvm.func @cancellation_point_taskgroup() {
+ omp.taskgroup {
+ omp.task {
+ omp.cancellation_point cancellation_construct_type(taskgroup)
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+// CHECK-LABEL: define internal void @cancellation_point_taskgroup..omp_par(
+// CHECK: task.alloca:
+// CHECK: br label %[[VAL_50:.*]]
+// CHECK: task.body: ; preds = %[[VAL_51:.*]]
+// CHECK: br label %[[VAL_52:.*]]
+// CHECK: omp.task.region: ; preds = %[[VAL_50]]
+// CHECK: %[[VAL_53:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
+// CHECK: %[[VAL_54:.*]] = call i32 @__kmpc_cancellationpoint(ptr @1, i32 %[[VAL_53]], i32 4)
+// CHECK: %[[VAL_55:.*]] = icmp eq i32 %[[VAL_54]], 0
+// CHECK: br i1 %[[VAL_55]], label %omp.task.region.split, label %omp.task.region.cncl
+// CHECK: omp.task.region.cncl:
+// CHECK: br label %omp.region.cont1
+// CHECK: omp.region.cont1:
+// CHECK: br label %task.exit.exitStub
+// CHECK: omp.task.region.split:
+// CHECK: br label %omp.region.cont1
+// CHECK: task.exit.exitStub:
+// CHECK: ret void
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 789c0ad9ebb48..8e5b5610aa28a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -26,40 +26,6 @@ llvm.func @atomic_hint(%v : !llvm.ptr, %x : !llvm.ptr, %expr : i32) {
// -----
-llvm.func @cancel_taskgroup() {
- // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
- omp.taskgroup {
- // expected-error@below {{LLVM Translation failed for operation: omp.task}}
- omp.task {
- // expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancel operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.cancel}}
- omp.cancel cancellation_construct_type(taskgroup)
- omp.terminator
- }
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
-llvm.func @cancellation_point_taskgroup() {
- // expected-error@below {{LLVM Translation failed for operation: omp.taskgroup}}
- omp.taskgroup {
- // expected-error@below {{LLVM Translation failed for operation: omp.task}}
- omp.task {
- // expected-error@below {{not yet implemented: Unhandled clause cancel directive construct type not yet supported in omp.cancellation_point operation}}
- // expected-error@below {{LLVM Translation failed for operation: omp.cancellation_point}}
- omp.cancellation_point cancellation_construct_type(taskgroup)
- omp.terminator
- }
- omp.terminator
- }
- llvm.return
-}
-
-// -----
-
llvm.func @do_simd(%lb : i32, %ub : i32, %step : i32) {
omp.wsloop {
// expected-warning@below {{simd information on composite construct discarded}}
|
PR Stack:
|
No unit test can be written for this todo because there is no support at all for lowering taskloop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you, this LGTM. I just have a non-blocking suggestion.
A cancel or cancellation point for taskgroup is always nested inside of a task inside of the taskgroup. For the task which is cancelled, it is that task which needs to be cleaned up: not the owning taskgroup. Therefore the cancellation branch handler is done in the conversion of the task not in conversion of taskgroup.
I added a firstprivate clause to the test for cancel taskgroup to demonstrate that the block being branched to is the same block where mandatory cleanup code is added. Cancellation point follows exactly the same code path.