diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md index 1e8c1c7be9f6a..526c85febb0bd 100644 --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -746,6 +746,61 @@ For example } ``` +#### Early Exit + +In the current form loops do support an early exit as any block can branch to +the merge block of the loop. However, the problem arises when such early exit +is conditional and the branch is sunk into a `spirv.mlir.selection` region. +In such structure the branch inside the selection region cannot reference block +of the loop enclosing the selection. At the same time such pattern is not unusual. +To support early loop exit within nested structured control flow, SPIR-V dialect +introduces `spirv.mlir.break` operation. The semantic of this operation is to branch +to the merge block of the first enclosing loop. + +For example + +```mlir +spirv.mlir.loop { + spirv.Branch ^header(%zero: i32) + +^header(%i : i32): + %cmp = spirv.SLessThan %i, %count : i32 + spirv.BranchConditional %cmp, ^body, ^merge_loop + +^body: + %cond = spirv.SGreaterThan %i, %five : i32 + spirv.Branch ^selection + +^selection: + spirv.mlir.selection { + spirv.BranchConditional %cond, ^true, ^merge_sel + ^true: + spirv.mlir.break // Jump to ^merge_loop. Regular branch cannot reference ^merge_loop, as it is outside the region. + ^merge_sel: + spirv.mlir.merge + } + + spirv.Branch ^continue + +^continue: + %new_i = spirv.IAdd %i, %one : i32 + spirv.Branch ^header(%new_i: i32) + +^merge_loop: + spirv.mlir.merge +} +``` + +The equivalent GLSL or C code would be + +```c +for (int i = 0; i < 10; ++i) { + x += 1; + if(x > 5) + break; +} +``` + ### Block argument for Phi There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi` diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index ef6682ab3630c..a6fc454d2fb34 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -21,6 +21,15 @@ include "mlir/Interfaces/SideEffectInterfaces.td" // ----- +// TODO: This is not only specific to control flow ops, so it could be moved +// somewhere else. +class SPIRV_HasParentOfType : PredOpTrait< + "op expects to be nested in " # op, + CPred<"getOperation()->getParentOfType<::mlir::spirv::" # op # ">() != nullptr"> +>; + +// ----- + def SPIRV_BranchOp : SPIRV_Op<"Branch", [ DeclareOpInterfaceMethods, InFunctionScope, Pure, Terminator]> { @@ -535,4 +544,30 @@ def SPIRV_SelectionOp : SPIRV_Op<"mlir.selection", [InFunctionScope]> { let hasRegionVerifier = 1; } +// ----- + +def SPIRV_BreakOp : SPIRV_Op<"mlir.break", [ + Pure, Terminator, SPIRV_HasParentOfType<"LoopOp">, ReturnLike]> { + let summary = "Early exit from a structured loop."; + + let description = [{ + Since the SPIR-V dialect relies on structured control flow, early exit using + branches is not possible. Since branch cannot reference blocks outside a region + a `spirv.mlir.selection` cannot arbitrarily branch to the merge block of the + enclosing loop. + + To provide support for early exits dialect implements a `spirv.mlir.break` + operation. The semantic of the operation is like that in GLSL / C / C++. + The break operation should be treated as a branch to the merge block of the + enclosing loop. + }]; + + let arguments = (ins); + let results = (outs); + let assemblyFormat = "attr-dict"; + let hasOpcode = 0; + let autogenSerialization = 0; + let hasVerifier = 0; +} + #endif // MLIR_DIALECT_SPIRV_IR_CONTROLFLOW_OPS diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index 1e867dde51001..7b1b9607e032f 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2345,6 +2345,46 @@ LogicalResult spirv::Deserializer::splitConditionalBlocks() { return success(); } +LogicalResult spirv::Deserializer::handleEarlyExits() { + SmallVector loopMergeBlocks; + + // Find all blocks that are loops' merge blocks. + for (auto &[_, mergeInfo] : blockMergeInfo) + if (mergeInfo.continueBlock) + loopMergeBlocks.push_back(mergeInfo.mergeBlock); + + for (auto &[header, mergeInfo] : blockMergeInfo) { + // We look for something like `if(x) break; ...` so we only process + // selection for now. + if (!mergeInfo.continueBlock) { + SetVector constructBlocks; + constructBlocks.insert(header); + + // Iterate over all blocks in the selection. This is similar to + // `collectBlocksInConstruct()` but with extra logic inserting + // `spirv.mlir.break`. We look for any block inside the selection region + // that jumps directly to the loop merge and does not go through the merge + // block of the selection. This indicates the unstructured jump so the + // branch is replaced with break. + for (unsigned i = 0; i < constructBlocks.size(); ++i) { + for (Block *successor : constructBlocks[i]->getSuccessors()) { + Block *block = constructBlocks[i]; + if (llvm::is_contained(loopMergeBlocks, successor)) { + assert(!block->empty() && block->getNumSuccessors() == 1); + block->back().erase(); + OpBuilder builder(block, block->end()); + builder.create(mergeInfo.loc); + } + if (successor != mergeInfo.mergeBlock) + constructBlocks.insert(successor); + } + } + } + } + + return success(); +} + LogicalResult spirv::Deserializer::structurizeControlFlow() { LLVM_DEBUG({ logger.startLine() @@ -2361,6 +2401,10 @@ LogicalResult spirv::Deserializer::structurizeControlFlow() { return failure(); } + if (failed(handleEarlyExits())) { + return failure(); + } + // TODO: This loop is non-deterministic. Iteration order may vary between runs // for the same shader as the key to the map is a pointer. See: // https://github.com/llvm/llvm-project/issues/128547 diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h index bcc78e3e6508d..c17b4f5f1f860 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.h @@ -252,6 +252,10 @@ class Deserializer { /// selection construct and the merge block of another. LogicalResult splitConditionalBlocks(); + /// Detect unstructured early exits from loops and replaces those arbitrary + /// branches with `spirv.mlir.break` statements. + LogicalResult handleEarlyExits(); + //===--------------------------------------------------------------------===// // Type //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index ff3cc92ee8078..aedfd05701177 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -576,6 +576,20 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { return success(); } +LogicalResult Serializer::processBreakOp(spirv::BreakOp breakOp) { + auto parentLoopOp = breakOp.getOperation()->getParentOfType(); + + if (!parentLoopOp) + return failure(); + + auto *mergeBlock = parentLoopOp.getMergeBlock(); + auto mergeID = getBlockID(mergeBlock); + + encodeInstructionInto(functionBody, spirv::Opcode::OpBranch, {mergeID}); + + return success(); +} + LogicalResult Serializer::processBranchConditionalOp( spirv::BranchConditionalOp condBranchOp) { auto conditionID = getValueID(condBranchOp.getCondition()); diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp index 1f4f5d7f764db..a8043a2f65086 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp @@ -1255,6 +1255,7 @@ LogicalResult Serializer::processOperation(Operation *opInst) { return processGlobalVariableOp(op); }) .Case([&](spirv::LoopOp op) { return processLoopOp(op); }) + .Case([&](spirv::BreakOp op) { return processBreakOp(op); }) .Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); }) .Case([&](spirv::SelectionOp op) { return processSelectionOp(op); }) .Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); }) diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.h b/mlir/lib/Target/SPIRV/Serialization/Serializer.h index 9edb0f4af008d..b758dd810cbea 100644 --- a/mlir/lib/Target/SPIRV/Serialization/Serializer.h +++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.h @@ -261,6 +261,8 @@ class Serializer { LogicalResult processLoopOp(spirv::LoopOp loopOp); + LogicalResult processBreakOp(spirv::BreakOp breakOp); + LogicalResult processBranchConditionalOp(spirv::BranchConditionalOp); LogicalResult processBranchOp(spirv::BranchOp branchOp); diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 8ec0bf5bbaacf..23fac9a140333 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -467,6 +467,76 @@ func.func @loop_yield(%count : i32) -> () { // ----- +func.func @loop_break(%count : i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %five = spirv.Constant 5: i32 + + // CHECK: spirv.mlir.loop { + spirv.mlir.loop { + // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32) + spirv.Branch ^header(%zero: i32) + + // CHECK-NEXT: ^bb1({{%.*}}: i32): + ^header(%i : i32): + %cmp = spirv.SLessThan %i, %count : i32 + // CHECK: spirv.BranchConditional {{%.*}}, ^bb2, ^bb5 + spirv.BranchConditional %cmp, ^body, ^merge + + // CHECK-NEXT: ^bb2: + ^body: + %cond = spirv.SGreaterThan %i, %five : i32 + + // CHECK: spirv.Branch ^bb3 + spirv.Branch ^selection + + // CHECK-NEXT: ^bb3: + ^selection: + // CHECK-NEXT: spirv.mlir.selection { + spirv.mlir.selection { + // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^bb1, ^bb2 + spirv.BranchConditional %cond, ^true, ^merge + // CHECK-NEXT: ^bb1: + ^true: + // CHECK-NEXT: spirv.mlir.break + spirv.mlir.break + // CHECK-NEXT: ^bb2: + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + // CHECK: spirv.Branch ^bb4 + spirv.Branch ^continue + + // CHECK-NEXT: ^bb4: + ^continue: + %new_i = spirv.IAdd %i, %one : i32 + // CHECK: spirv.Branch ^bb1({{%.*}}: i32) + spirv.Branch ^header(%new_i: i32) + + // CHECK-NEXT: ^bb5: + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spirv.mlir.break +//===----------------------------------------------------------------------===// + +func.func @break() -> () { + // expected-error @+1 {{op expects to be nested in LoopOp}} + spirv.mlir.break +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.mlir.merge //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir index 95b87b319ac2d..f6b5d44aa9c74 100644 --- a/mlir/test/Target/SPIRV/loop.mlir +++ b/mlir/test/Target/SPIRV/loop.mlir @@ -335,3 +335,71 @@ spirv.module Logical GLSL450 requires #spirv.vce { spirv.Return } } + +// ----- + +// Loop with break statement + +spirv.module Logical GLSL450 requires #spirv.vce { + spirv.func @loop_break(%count : i32) -> () "None" { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %five = spirv.Constant 5: i32 + + // CHECK: spirv.mlir.loop { + spirv.mlir.loop { + // CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32) + spirv.Branch ^header(%zero: i32) + + // CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32): + ^header(%i : i32): + %cmp = spirv.SLessThan %i, %count : i32 + // CHECK: spirv.BranchConditional {{%.*}}, ^[[BODY:.+]], ^[[MERGE:.+]] + spirv.BranchConditional %cmp, ^body, ^merge + + // CHECK-NEXT: ^[[BODY:.+]]: + ^body: + %cond = spirv.SGreaterThan %i, %five : i32 + + // CHECK: spirv.Branch ^[[LINK:.+]] + spirv.Branch ^selection + + // COM: Artificial block introduced by block splitting in the deserializer. + // CHECK-NEXT: ^[[LINK:.+]]: + // CHECK-NEXT: spirv.Branch ^[[SELECTION:.+]] + + // CHECK-NEXT: ^[[SELECTION:.+]]: + ^selection: + // CHECK-NEXT: spirv.mlir.selection { + spirv.mlir.selection { + // CHECK-NEXT: spirv.BranchConditional {{%.*}}, ^[[TRUE:.+]], ^[[FALSE:.+]] + spirv.BranchConditional %cond, ^true, ^merge + // CHECK-NEXT: ^[[TRUE:.+]]: + ^true: + // CHECK-NEXT: spirv.mlir.break + spirv.mlir.break + // CHECK-NEXT: ^[[MERGE:.+]]: + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + // CHECK: spirv.Branch ^[[CONTINUE:.+]] + spirv.Branch ^continue + + // CHECK-NEXT: ^[[CONTINUE:.+]]: + ^continue: + %new_i = spirv.IAdd %i, %one : i32 + // CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32) + spirv.Branch ^header(%new_i: i32) + + // CHECK-NEXT: ^[[MERGE:.+]]: + ^merge: + // CHECK-NEXT: spirv.mlir.merge + spirv.mlir.merge + } + + // CHECK: spirv.Return + spirv.Return + } +}