-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][spirv] Add support for spirv.mlir.break
#138688
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Since the SPIR-V dialect uses structured control flow arbitrary branching, which includes loops' early exit, is not supported. This patch introduces new `break` operation that allows to support early exits in loop within the structured control flow.
@llvm/pr-subscribers-mlir Author: Igor Wodiany (IgWod-IMG) ChangesSince the SPIR-V dialect uses structured control flow arbitrary branching, which includes conditional loops' early exit, is not supported. This patch introduces new The main problem this PR tries to solve is the case where a branch to the loop merge block is wrapped in a selection op. Since the selection op cannot reference blocks outside it, a different approach is needed. I am open to feedback whether a better approach exists that does not introduce a new op. Full diff: https://github.com/llvm/llvm-project/pull/138688.diff 9 Files Affected:
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 1e8c1c7be9f6a..526c85febb0bd 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -746,6 +746,61 @@ For example
}
```
+#### Early Exit
+
+In the current form loops do support an early exit as any block can branch to
+the merge block of the loop. However, the problem arises when such early exit
+is conditional and the branch is sunk into a `spirv.mlir.selection` region.
+In such structure the branch inside the selection region cannot reference block
+of the loop enclosing the selection. At the same time such pattern is not unusual.
+To support early loop exit within nested structured control flow, SPIR-V dialect
+introduces `spirv.mlir.break` operation. The semantic of this operation is to branch
+to the merge block of the first enclosing loop.
+
+For example
+
+```mlir
+spirv.mlir.loop {
+ spirv.Branch ^header(%zero: i32)
+
+^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ spirv.BranchConditional %cmp, ^body, ^merge_loop
+
+^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+ spirv.Branch ^selection
+
+^selection:
+ spirv.mlir.selection {
+ spirv.BranchConditional %cond, ^true, ^merge_sel
+ ^true:
+ spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region.
+ ^merge_sel:
+ spirv.mlir.merge
+ }
+
+ spirv.Branch ^continue
+
+^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ spirv.Branch ^header(%new_i: i32)
+
+^merge_loop:
+ spirv.mlir.merge
+}
+```
+
+The equivalent GLSL or C code would be
+
+```c
+for (int i = 0; i < 10; ++i) {
+ x += 1;
+ if(x > 5)
+ break;
+}
+```
+
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index ef6682ab3630c..a6fc454d2fb34 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
+// TODO: This is not only specific to control flow ops, so it could be moved
+// somewhere else.
+class SPIRV_HasParentOfType<string op> : PredOpTrait<
+ "op expects to be nested in " # op,
+ CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr">
+>;
+
+// -----
+
def SPIRV_BranchOp : SPIRV_Op<"Branch", [
DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope, Pure,
Terminator]> {
@@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
let hasRegionVerifier = 1;
}
+// -----
+
+def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [
+ Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> {
+ let summary = "Early exit from a structured loop.";
+
+ let description = [{
+ Since the SPIR-V dialect relies on structured control flow, early exit using
+ branches is not possible. Since branch cannot reference blocks outside a region
+ a `spirv.mlir.selection` cannot arbitrarily branch to the merge block of the
+ enclosing loop.
+
+ To provide support for early exits dialect implements a `spirv.mlir.break`
+ operation. The semantic of the operation is like that in GLSL / C / C++.
+ The break operation should be treated as a branch to the merge block of the
+ enclosing loop.
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let hasOpcode = 0;
+ let autogenSerialization = 0;
+ let hasVerifier = 0;
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 1e867dde51001..7b1b9607e032f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
return success();
}
+LogicalResult spirv::Deserializer::handleEarlyExits() {
+ SmallVector<Block *> loopMergeBlocks;
+
+ // Find all blocks that are loops' merge blocks.
+ for (auto &[_, mergeInfo] : blockMergeInfo)
+ if (mergeInfo.continueBlock)
+ loopMergeBlocks.push_back(mergeInfo.mergeBlock);
+
+ for (auto &[header, mergeInfo] : blockMergeInfo) {
+ // We look for something like `if(x) break; ...` so we only process
+ // selection for now.
+ if (!mergeInfo.continueBlock) {
+ SetVector<Block *> constructBlocks;
+ constructBlocks.insert(header);
+
+ // Iterate over all blocks in the selection. This is similar to
+ // `collectBlocksInConstruct()` but with extra logic inserting
+ // `spirv.mlir.break`. We look for any block inside the selection region
+ // that jumps directly to the loop merge and does not go through the merge
+ // block of the selection. This indicates the unstructured jump so the
+ // branch is replaced with break.
+ for (unsigned i = 0; i < constructBlocks.size(); ++i) {
+ for (Block *successor : constructBlocks[i]->getSuccessors()) {
+ Block *block = constructBlocks[i];
+ if (llvm::is_contained(loopMergeBlocks, successor)) {
+ assert(!block->empty() && block->getNumSuccessors() == 1);
+ block->back().erase();
+ OpBuilder builder(block, block->end());
+ builder.create<spirv::BreakOp>(mergeInfo.loc);
+ }
+ if (successor != mergeInfo.mergeBlock)
+ constructBlocks.insert(successor);
+ }
+ }
+ }
+ }
+
+ return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
return failure();
}
+ if (failed(handleEarlyExits())) {
+ return failure();
+ }
+
// TODO: This loop is non-deterministic. Iteration order may vary between runs
// for the same shader as the key to the map is a pointer. See:
// https://github.com/llvm/llvm-project/issues/128547
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index bcc78e3e6508d..c17b4f5f1f860 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -252,6 +252,10 @@ class Deserializer {
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();
+ /// Detect unstructured early exits from loops and replaces those arbitrary
+ /// branches with `spirv.mlir.break` statements.
+ LogicalResult handleEarlyExits();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..aedfd05701177 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
return success();
}
+LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) {
+ auto parentLoopOp = breakOp.getOperation()->getParentOfType<spirv::LoopOp>();
+
+ if (!parentLoopOp)
+ return failure();
+
+ auto *mergeBlock = parentLoopOp.getMergeBlock();
+ auto mergeID = getBlockID(mergeBlock);
+
+ encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID});
+
+ return success();
+}
+
LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.getCondition());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db..a8043a2f65086 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
+ .Case([&](spirv::BreakOp op) { return processBreakOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..b758dd810cbea 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -261,6 +261,8 @@ class Serializer {
LogicalResult processLoopOp(spirv::LoopOp loopOp);
+ LogicalResult processBreakOp(spirv::BreakOp breakOp);
+
LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
LogicalResult processBranchOp(spirv::BranchOp branchOp);
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..23fac9a140333 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () {
// -----
+func.func @loop_break(%count : i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^bb1({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^bb2:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^bb3
+ spirv.Branch ^selection
+
+ // CHECK-NEXT: ^bb3:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^bb1:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^bb2:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^bb4:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^bb5:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.break
+//===----------------------------------------------------------------------===//
+
+func.func @break() -> () {
+ // expected-error @+1 {{op expects to be nested in LoopOp}}
+ spirv.mlir.break
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index 95b87b319ac2d..f6b5d44aa9c74 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.Return
}
}
+
+// -----
+
+// Loop with break statement
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @loop_break(%count : i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^[[BODY:.+]]:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^[[LINK:.+]]
+ spirv.Branch ^selection
+
+ // COM: Artificial block introduced by block splitting in the deserializer.
+ // CHECK-NEXT: ^[[LINK:.+]]:
+ // CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]]
+
+ // CHECK-NEXT: ^[[SELECTION:.+]]:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]]
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^[[TRUE:.+]]:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^[[CONTINUE:.+]]
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^[[CONTINUE:.+]]:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Return
+ spirv.Return
+ }
+}
|
@llvm/pr-subscribers-mlir-spirv Author: Igor Wodiany (IgWod-IMG) ChangesSince the SPIR-V dialect uses structured control flow arbitrary branching, which includes conditional loops' early exit, is not supported. This patch introduces new The main problem this PR tries to solve is the case where a branch to the loop merge block is wrapped in a selection op. Since the selection op cannot reference blocks outside it, a different approach is needed. I am open to feedback whether a better approach exists that does not introduce a new op. Full diff: https://github.com/llvm/llvm-project/pull/138688.diff 9 Files Affected:
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index 1e8c1c7be9f6a..526c85febb0bd 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -746,6 +746,61 @@ For example
}
```
+#### Early Exit
+
+In the current form loops do support an early exit as any block can branch to
+the merge block of the loop. However, the problem arises when such early exit
+is conditional and the branch is sunk into a `spirv.mlir.selection` region.
+In such structure the branch inside the selection region cannot reference block
+of the loop enclosing the selection. At the same time such pattern is not unusual.
+To support early loop exit within nested structured control flow, SPIR-V dialect
+introduces `spirv.mlir.break` operation. The semantic of this operation is to branch
+to the merge block of the first enclosing loop.
+
+For example
+
+```mlir
+spirv.mlir.loop {
+ spirv.Branch ^header(%zero: i32)
+
+^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ spirv.BranchConditional %cmp, ^body, ^merge_loop
+
+^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+ spirv.Branch ^selection
+
+^selection:
+ spirv.mlir.selection {
+ spirv.BranchConditional %cond, ^true, ^merge_sel
+ ^true:
+ spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region.
+ ^merge_sel:
+ spirv.mlir.merge
+ }
+
+ spirv.Branch ^continue
+
+^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ spirv.Branch ^header(%new_i: i32)
+
+^merge_loop:
+ spirv.mlir.merge
+}
+```
+
+The equivalent GLSL or C code would be
+
+```c
+for (int i = 0; i < 10; ++i) {
+ x += 1;
+ if(x > 5)
+ break;
+}
+```
+
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index ef6682ab3630c..a6fc454d2fb34 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
// -----
+// TODO: This is not only specific to control flow ops, so it could be moved
+// somewhere else.
+class SPIRV_HasParentOfType<string op> : PredOpTrait<
+ "op expects to be nested in " # op,
+ CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr">
+>;
+
+// -----
+
def SPIRV_BranchOp : SPIRV_Op<"Branch", [
DeclareOpInterfaceMethods<BranchOpInterface>, InFunctionScope, Pure,
Terminator]> {
@@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> {
let hasRegionVerifier = 1;
}
+// -----
+
+def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [
+ Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> {
+ let summary = "Early exit from a structured loop.";
+
+ let description = [{
+ Since the SPIR-V dialect relies on structured control flow, early exit using
+ branches is not possible. Since branch cannot reference blocks outside a region
+ a `spirv.mlir.selection` cannot arbitrarily branch to the merge block of the
+ enclosing loop.
+
+ To provide support for early exits dialect implements a `spirv.mlir.break`
+ operation. The semantic of the operation is like that in GLSL / C / C++.
+ The break operation should be treated as a branch to the merge block of the
+ enclosing loop.
+ }];
+
+ let arguments = (ins);
+ let results = (outs);
+ let assemblyFormat = "attr-dict";
+ let hasOpcode = 0;
+ let autogenSerialization = 0;
+ let hasVerifier = 0;
+}
+
#endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 1e867dde51001..7b1b9607e032f 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() {
return success();
}
+LogicalResult spirv::Deserializer::handleEarlyExits() {
+ SmallVector<Block *> loopMergeBlocks;
+
+ // Find all blocks that are loops' merge blocks.
+ for (auto &[_, mergeInfo] : blockMergeInfo)
+ if (mergeInfo.continueBlock)
+ loopMergeBlocks.push_back(mergeInfo.mergeBlock);
+
+ for (auto &[header, mergeInfo] : blockMergeInfo) {
+ // We look for something like `if(x) break; ...` so we only process
+ // selection for now.
+ if (!mergeInfo.continueBlock) {
+ SetVector<Block *> constructBlocks;
+ constructBlocks.insert(header);
+
+ // Iterate over all blocks in the selection. This is similar to
+ // `collectBlocksInConstruct()` but with extra logic inserting
+ // `spirv.mlir.break`. We look for any block inside the selection region
+ // that jumps directly to the loop merge and does not go through the merge
+ // block of the selection. This indicates the unstructured jump so the
+ // branch is replaced with break.
+ for (unsigned i = 0; i < constructBlocks.size(); ++i) {
+ for (Block *successor : constructBlocks[i]->getSuccessors()) {
+ Block *block = constructBlocks[i];
+ if (llvm::is_contained(loopMergeBlocks, successor)) {
+ assert(!block->empty() && block->getNumSuccessors() == 1);
+ block->back().erase();
+ OpBuilder builder(block, block->end());
+ builder.create<spirv::BreakOp>(mergeInfo.loc);
+ }
+ if (successor != mergeInfo.mergeBlock)
+ constructBlocks.insert(successor);
+ }
+ }
+ }
+ }
+
+ return success();
+}
+
LogicalResult spirv::Deserializer::structurizeControlFlow() {
LLVM_DEBUG({
logger.startLine()
@@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() {
return failure();
}
+ if (failed(handleEarlyExits())) {
+ return failure();
+ }
+
// TODO: This loop is non-deterministic. Iteration order may vary between runs
// for the same shader as the key to the map is a pointer. See:
// https://github.com/llvm/llvm-project/issues/128547
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
index bcc78e3e6508d..c17b4f5f1f860 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h
@@ -252,6 +252,10 @@ class Deserializer {
/// selection construct and the merge block of another.
LogicalResult splitConditionalBlocks();
+ /// Detect unstructured early exits from loops and replaces those arbitrary
+ /// branches with `spirv.mlir.break` statements.
+ LogicalResult handleEarlyExits();
+
//===--------------------------------------------------------------------===//
// Type
//===--------------------------------------------------------------------===//
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index ff3cc92ee8078..aedfd05701177 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
return success();
}
+LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) {
+ auto parentLoopOp = breakOp.getOperation()->getParentOfType<spirv::LoopOp>();
+
+ if (!parentLoopOp)
+ return failure();
+
+ auto *mergeBlock = parentLoopOp.getMergeBlock();
+ auto mergeID = getBlockID(mergeBlock);
+
+ encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID});
+
+ return success();
+}
+
LogicalResult Serializer::processBranchConditionalOp(
spirv::BranchConditionalOp condBranchOp) {
auto conditionID = getValueID(condBranchOp.getCondition());
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 1f4f5d7f764db..a8043a2f65086 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
+ .Case([&](spirv::BreakOp op) { return processBreakOp(op); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
index 9edb0f4af008d..b758dd810cbea 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h
@@ -261,6 +261,8 @@ class Serializer {
LogicalResult processLoopOp(spirv::LoopOp loopOp);
+ LogicalResult processBreakOp(spirv::BreakOp breakOp);
+
LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp);
LogicalResult processBranchOp(spirv::BranchOp branchOp);
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 8ec0bf5bbaacf..23fac9a140333 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () {
// -----
+func.func @loop_break(%count : i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^bb1({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^bb2:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^bb3
+ spirv.Branch ^selection
+
+ // CHECK-NEXT: ^bb3:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^bb1:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^bb2:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^bb4
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^bb4:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^bb5:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spirv.mlir.break
+//===----------------------------------------------------------------------===//
+
+func.func @break() -> () {
+ // expected-error @+1 {{op expects to be nested in LoopOp}}
+ spirv.mlir.break
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index 95b87b319ac2d..f6b5d44aa9c74 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
spirv.Return
}
}
+
+// -----
+
+// Loop with break statement
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @loop_break(%count : i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %five = spirv.Constant 5: i32
+
+ // CHECK: spirv.mlir.loop {
+ spirv.mlir.loop {
+ // CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^[[BODY:.+]]:
+ ^body:
+ %cond = spirv.SGreaterThan %i, %five : i32
+
+ // CHECK: spirv.Branch ^[[LINK:.+]]
+ spirv.Branch ^selection
+
+ // COM: Artificial block introduced by block splitting in the deserializer.
+ // CHECK-NEXT: ^[[LINK:.+]]:
+ // CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]]
+
+ // CHECK-NEXT: ^[[SELECTION:.+]]:
+ ^selection:
+ // CHECK-NEXT: spirv.mlir.selection {
+ spirv.mlir.selection {
+ // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]]
+ spirv.BranchConditional %cond, ^true, ^merge
+ // CHECK-NEXT: ^[[TRUE:.+]]:
+ ^true:
+ // CHECK-NEXT: spirv.mlir.break
+ spirv.mlir.break
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Branch ^[[CONTINUE:.+]]
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^[[CONTINUE:.+]]:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge
+ spirv.mlir.merge
+ }
+
+ // CHECK: spirv.Return
+ spirv.Return
+ }
+}
|
Since the SPIR-V dialect uses structured control flow arbitrary branching, which includes conditional loops' early exit, is not supported. This patch introduces new
break
operation that allows to support early exits in loop within the structured control flow.The main problem this PR tries to solve is the case where a branch to the loop merge block is wrapped in a selection op. Since the selection op cannot reference blocks outside it, a different approach is needed. I am open to feedback whether a better approach exists that does not introduce a new op.