diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md index ae9afbd9fdfe5..1e8c1c7be9f6a 100644 --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () { } ``` +Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This +mechanism allows values defined within the loop region to be used outside of it. + +For example: + +```mlir +%yielded = spirv.mlir.loop -> i32 { + // ... + spirv.mlir.merge %to_yield : i32 +} +``` + ### Block argument for Phi There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi` diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 039af03871411..ef6682ab3630c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> { The continue block should be the second to last block and it should have a branch to the loop header block. The loop continue block should be the only block, except the entry block, branching to the header block. + + Values defined inside the loop region cannot be directly used + outside of it; however, the loop region can yield values. These values are + yielded using a `spirv.mlir.merge` op and returned as a result of the loop op. }]; let arguments = (ins SPIRV_LoopControlAttr:$loop_control ); - let results = (outs); + let results = (outs Variadic:$results); let regions = (region AnyRegion:$body); - let builders = [OpBuilder<(ins)>]; + let builders = [ + OpBuilder<(ins)>, + OpBuilder<(ins "spirv::LoopControl":$loopControl), + [{ + build($_builder, $_state, TypeRange(), loopControl); + }]> + ]; let extraClassDeclaration = [{ // Returns the entry block. 
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp index 577959bbdbeaa..371456552b5b5 100644 --- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp @@ -229,6 +229,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) { if (parseControlAttribute(parser, result)) return failure(); + + if (succeeded(parser.parseOptionalArrow())) + if (parser.parseTypeList(result.types)) + return failure(); + return parser.parseRegion(*result.addRegion(), /*arguments=*/{}); } @@ -236,6 +241,10 @@ void LoopOp::print(OpAsmPrinter &printer) { auto control = getLoopControl(); if (control != spirv::LoopControl::None) printer << " control(" << spirv::stringifyLoopControl(control) << ")"; + if (getNumResults() > 0) { + printer << " -> "; + printer << getResultTypes(); + } printer << ' '; printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/true); diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp index b0220bc16e15e..1e867dde51001 100644 --- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp +++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp @@ -2003,7 +2003,14 @@ LogicalResult ControlFlowStructurizer::structurize() { // block inside the selection (`body.back()`). Values produced by block // arguments will be yielded by the selection region. We do not update uses or // erase original block arguments yet. It will be done later in the code. - if (!isLoop) { + // + // Code below is not executed for loops as it would interfere with the logic + // above. Currently block arguments in the merge block are not supported, but + // instead, the code above copies those arguments from the header block into + // the merge block. As such, running the code would yield those copied + // arguments, which is most likely not desired behaviour. 
This may need to be + // revisited in the future. + if (!isLoop) for (BlockArgument blockArg : mergeBlock->getArguments()) { // Create new block arguments in the last block ("merge block") of the // selection region. We create one argument for each argument in @@ -2013,7 +2020,6 @@ LogicalResult ControlFlowStructurizer::structurize() { valuesToYield.push_back(body.back().getArguments().back()); outsideUses.push_back(blockArg); } - } // All the blocks cloned into the SelectionOp/LoopOp's region can now be // cleaned up. @@ -2025,32 +2031,30 @@ LogicalResult ControlFlowStructurizer::structurize() { // All internal uses should be removed from original blocks by now, so // whatever is left is an outside use and will need to be yielded from - // the newly created selection region. - if (!isLoop) { - for (Block *block : constructBlocks) { - for (Operation &op : *block) { - if (!op.use_empty()) - for (Value result : op.getResults()) { - valuesToYield.push_back(mapper.lookupOrNull(result)); - outsideUses.push_back(result); - } - } - for (BlockArgument &arg : block->getArguments()) { - if (!arg.use_empty()) { - valuesToYield.push_back(mapper.lookupOrNull(arg)); - outsideUses.push_back(arg); + // the newly created selection / loop region. + for (Block *block : constructBlocks) { + for (Operation &op : *block) { + if (!op.use_empty()) + for (Value result : op.getResults()) { + valuesToYield.push_back(mapper.lookupOrNull(result)); + outsideUses.push_back(result); } + } + for (BlockArgument &arg : block->getArguments()) { + if (!arg.use_empty()) { + valuesToYield.push_back(mapper.lookupOrNull(arg)); + outsideUses.push_back(arg); } } } assert(valuesToYield.size() == outsideUses.size()); - // If we need to yield any values from the selection region we will take - // care of it here. - if (!isLoop && !valuesToYield.empty()) { + // If we need to yield any values from the selection / loop region we will + // take care of it here. 
+ if (!valuesToYield.empty()) { LLVM_DEBUG(logger.startLine() - << "[cf] yielding values from the selection region\n"); + << "[cf] yielding values from the selection / loop region\n"); // Update `mlir.merge` with values to be yield. auto mergeOps = body.back().getOps(); @@ -2059,25 +2063,40 @@ LogicalResult ControlFlowStructurizer::structurize() { merge->setOperands(valuesToYield); // MLIR does not allow changing the number of results of an operation, so - // we create a new SelectionOp with required list of results and move - // the region from the initial SelectionOp. The initial operation is then - // removed. Since we move the region to the new op all links between blocks - // and remapping we have previously done should be preserved. + // we create a new SelectionOp / LoopOp with required list of results and + // move the region from the initial SelectionOp / LoopOp. The initial + // operation is then removed. Since we move the region to the new op all + // links between blocks and remapping we have previously done should be + // preserved. builder.setInsertionPoint(&mergeBlock->front()); - auto selectionOp = builder.create( - location, TypeRange(ValueRange(outsideUses)), - static_cast(control)); - selectionOp->getRegion(0).takeBody(body); + + Operation *newOp = nullptr; + + if (isLoop) + newOp = builder.create( + location, TypeRange(ValueRange(outsideUses)), + static_cast(control)); + else + newOp = builder.create( + location, TypeRange(ValueRange(outsideUses)), + static_cast(control)); + + newOp->getRegion(0).takeBody(body); // Remove initial op and swap the pointer to the newly created one. op->erase(); - op = selectionOp; + op = newOp; - // Update all outside uses to use results of the SelectionOp and remove - // block arguments from the original merge block. + // Update all outside uses to use results of the SelectionOp / LoopOp and + // remove block arguments from the original merge block. 
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i) - outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i)); - mergeBlock->eraseArguments(0, mergeBlock->getNumArguments()); + outsideUses[i].replaceAllUsesWith(op->getResult(i)); + + // We do not support block arguments in the loop merge block. Also, running + // this function for a loop would break some of the loop-specific code above + // dealing with block arguments. + if (!isLoop) + mergeBlock->eraseArguments(0, mergeBlock->getNumArguments()); } // Check that whether some op in the to-be-erased blocks still has uses. Those diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp index 5ed59a4134d37..ff3cc92ee8078 100644 --- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp +++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp @@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) { auto mergeID = getBlockID(mergeBlock); auto loc = loopOp.getLoc(); + // Before we do anything replace results of the loop operation with + // values yielded (with `mlir.merge`) from inside the region. + auto mergeOp = cast(mergeBlock->back()); + assert(loopOp.getNumResults() == mergeOp.getNumOperands()); + for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i) + loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i)); + // This LoopOp is in some MLIR block with preceding and following ops. In the // binary format, it should reside in separate SPIR-V blocks from its // preceding and following ops. 
So we need to emit unconditional branches to diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir index 107c8a3207b02..8ec0bf5bbaacf 100644 --- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir @@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () { // ----- +func.func @loop_yield(%count : i32) -> () { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 1: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr + + // CHECK: {{%.*}} = spirv.mlir.loop -> i32 { + %final_i = spirv.mlir.loop -> i32 { + // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32) + spirv.Branch ^header(%zero: i32) + + // CHECK-NEXT: ^bb1({{%.*}}: i32): + ^header(%i : i32): + %cmp = spirv.SLessThan %i, %count : i32 + // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4 + spirv.BranchConditional %cmp, ^body, ^merge + + // CHECK-NEXT: ^bb2: + ^body: + // CHECK-NEXT: spirv.Branch ^bb3 + spirv.Branch ^continue + + // CHECK-NEXT: ^bb3: + ^continue: + %new_i = spirv.IAdd %i, %one : i32 + // CHECK: spirv.Branch ^bb1({{%.*}}: i32) + spirv.Branch ^header(%new_i: i32) + + // CHECK-NEXT: ^bb4: + ^merge: + // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32 + spirv.mlir.merge %i : i32 + } + + // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %final_i : i32 + + return +} + +// ----- + //===----------------------------------------------------------------------===// // spirv.mlir.merge //===----------------------------------------------------------------------===// diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir index d89600558f56d..95b87b319ac2d 100644 --- a/mlir/test/Target/SPIRV/loop.mlir +++ b/mlir/test/Target/SPIRV/loop.mlir @@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce { + spirv.func @loop_yield(%count : i32) -> () "None" { + %zero = spirv.Constant 0: i32 + %one = spirv.Constant 
1: i32 + %var = spirv.Variable init(%zero) : !spirv.ptr + +// CHECK: {{%.*}} = spirv.mlir.loop -> i32 { + %final_i = spirv.mlir.loop -> i32 { +// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32) + spirv.Branch ^header(%zero: i32) + +// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32): + ^header(%i : i32): + %cmp = spirv.SLessThan %i, %count : i32 +// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]] + spirv.BranchConditional %cmp, ^body, ^merge + +// CHECK-NEXT: ^[[BODY:.+]]: + ^body: +// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]] + spirv.Branch ^continue + +// CHECK-NEXT: ^[[CONTINUE:.+]]: + ^continue: + %new_i = spirv.IAdd %i, %one : i32 +// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32) + spirv.Branch ^header(%new_i: i32) + +// CHECK-NEXT: ^[[MERGE:.+]]: + ^merge: +// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32 + spirv.mlir.merge %i : i32 +// CHECK-NEXT: } + } + +// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32 + spirv.Store "Function" %var, %final_i : i32 + +// CHECK-NEXT: spirv.Return + spirv.Return + } +}