-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[mlir][spirv] Allow yielding values from loop regions #135344
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Igor Wodiany (IgWod-IMG) ChangesThis change extends Full diff: https://github.com/llvm/llvm-project/pull/135344.diff 7 Files Affected:
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index ae9afbd9fdfe5..1e8c1c7be9f6a 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
}
```
+Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
+mechanism allows values defined within the loop region to be used outside of it.
+
+For example
+
+```mlir
+%yielded = spirv.mlir.loop -> i32 {
+ // ...
+ spirv.mlir.merge %to_yield : i32
+}
+```
+
### Block argument for Phi
There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 039af03871411..ef6682ab3630c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
The continue block should be the second to last block and it should have a
branch to the loop header block. The loop continue block should be the only
block, except the entry block, branching to the header block.
+
+ Values defined inside the loop regions cannot be directly used
+ outside of them; however, the loop region can yield values. These values are
+ yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
}];
let arguments = (ins
SPIRV_LoopControlAttr:$loop_control
);
- let results = (outs);
+ let results = (outs Variadic<AnyType>:$results);
let regions = (region AnyRegion:$body);
- let builders = [OpBuilder<(ins)>];
+ let builders = [
+ OpBuilder<(ins)>,
+ OpBuilder<(ins "spirv::LoopControl":$loopControl),
+ [{
+ build($_builder, $_state, TypeRange(), loopControl);
+ }]>
+ ];
let extraClassDeclaration = [{
// Returns the entry block.
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index ed9a30086deca..cf983af6a07ac 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -230,6 +230,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
result))
return failure();
+
+ if (succeeded(parser.parseOptionalArrow()))
+ if (parser.parseTypeList(result.types))
+ return failure();
+
return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
}
@@ -237,6 +242,10 @@ void LoopOp::print(OpAsmPrinter &printer) {
auto control = getLoopControl();
if (control != spirv::LoopControl::None)
printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+ if (getNumResults() > 0) {
+ printer << " -> ";
+ printer << getResultTypes();
+ }
printer << ' ';
printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/true);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 25749ec598f00..2c7a93949307c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() {
// block inside the selection (`body.back()`). Values produced by block
// arguments will be yielded by the selection region. We do not update uses or
// erase original block arguments yet. It will be done later in the code.
- if (!isLoop) {
+ // We do not currently support block arguments in loop merge blocks.
+ if (!isLoop)
for (BlockArgument blockArg : mergeBlock->getArguments()) {
// Create new block arguments in the last block ("merge block") of the
// selection region. We create one argument for each argument in
@@ -2013,7 +2014,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
valuesToYield.push_back(body.back().getArguments().back());
outsideUses.push_back(blockArg);
}
- }
// All the blocks cloned into the SelectionOp/LoopOp's region can now be
// cleaned up.
@@ -2025,32 +2025,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
// All internal uses should be removed from original blocks by now, so
// whatever is left is an outside use and will need to be yielded from
- // the newly created selection region.
- if (!isLoop) {
- for (Block *block : constructBlocks) {
- for (Operation &op : *block) {
- if (!op.use_empty())
- for (Value result : op.getResults()) {
- valuesToYield.push_back(mapper.lookupOrNull(result));
- outsideUses.push_back(result);
- }
- }
- for (BlockArgument &arg : block->getArguments()) {
- if (!arg.use_empty()) {
- valuesToYield.push_back(mapper.lookupOrNull(arg));
- outsideUses.push_back(arg);
+ // the newly created selection / loop region.
+ for (Block *block : constructBlocks) {
+ for (Operation &op : *block) {
+ if (!op.use_empty())
+ for (Value result : op.getResults()) {
+ valuesToYield.push_back(mapper.lookupOrNull(result));
+ outsideUses.push_back(result);
}
+ }
+ for (BlockArgument &arg : block->getArguments()) {
+ if (!arg.use_empty()) {
+ valuesToYield.push_back(mapper.lookupOrNull(arg));
+ outsideUses.push_back(arg);
}
}
}
assert(valuesToYield.size() == outsideUses.size());
- // If we need to yield any values from the selection region we will take
- // care of it here.
- if (!isLoop && !valuesToYield.empty()) {
+ // If we need to yield any values from the selection / loop region we will
+ // take care of it here.
+ if (!valuesToYield.empty()) {
LLVM_DEBUG(logger.startLine()
- << "[cf] yielding values from the selection region\n");
+ << "[cf] yielding values from the selection / loop region\n");
// Update `mlir.merge` with values to be yield.
auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2057,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
merge->setOperands(valuesToYield);
// MLIR does not allow changing the number of results of an operation, so
- // we create a new SelectionOp with required list of results and move
- // the region from the initial SelectionOp. The initial operation is then
- // removed. Since we move the region to the new op all links between blocks
- // and remapping we have previously done should be preserved.
+ // we create a new SelectionOp / LoopOp with required list of results and
+ // move the region from the initial SelectionOp / LoopOp. The initial
+ // operation is then removed. Since we move the region to the new op all
+ // links between blocks and remapping we have previously done should be
+ // preserved.
builder.setInsertionPoint(&mergeBlock->front());
- auto selectionOp = builder.create<spirv::SelectionOp>(
- location, TypeRange(ValueRange(outsideUses)),
- static_cast<spirv::SelectionControl>(control));
- selectionOp->getRegion(0).takeBody(body);
+
+ Operation *newOp = nullptr;
+
+ if (isLoop)
+ newOp = builder.create<spirv::LoopOp>(
+ location, TypeRange(ValueRange(outsideUses)),
+ static_cast<spirv::LoopControl>(control));
+ else
+ newOp = builder.create<spirv::SelectionOp>(
+ location, TypeRange(ValueRange(outsideUses)),
+ static_cast<spirv::SelectionControl>(control));
+
+ newOp->getRegion(0).takeBody(body);
// Remove initial op and swap the pointer to the newly created one.
op->erase();
- op = selectionOp;
+ op = newOp;
- // Update all outside uses to use results of the SelectionOp and remove
- // block arguments from the original merge block.
+ // Update all outside uses to use results of the SelectionOp / LoopOp and
+ // remove block arguments from the original merge block.
for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
- outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
- mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+ outsideUses[i].replaceAllUsesWith(op->getResult(i));
+
+ // We do not support block arguments in loop merge block. Also running this
+ // function with loop would break some of the loop specific code above
+ // dealing with block arguments.
+ if (!isLoop)
+ mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
}
// Check that whether some op in the to-be-erased blocks still has uses. Those
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 5ed59a4134d37..ff3cc92ee8078 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
auto mergeID = getBlockID(mergeBlock);
auto loc = loopOp.getLoc();
+ // Before we do anything replace results of the selection operation with
+ // values yielded (with `mlir.merge`) from inside the region.
+ auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+ assert(loopOp.getNumResults() == mergeOp.getNumOperands());
+ for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
+ loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
// This LoopOp is in some MLIR block with preceding and following ops. In the
// binary format, it should reside in separate SPIR-V blocks from its
// preceding and following ops. So we need to emit unconditional branches to
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 107c8a3207b02..8ec0bf5bbaacf 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
// -----
+func.func @loop_yield(%count : i32) -> () {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+ // CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+ %final_i = spirv.mlir.loop -> i32 {
+ // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+ // CHECK-NEXT: ^bb1({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+ // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+ // CHECK-NEXT: ^bb2:
+ ^body:
+ // CHECK-NEXT: spirv.Branch ^bb3
+ spirv.Branch ^continue
+
+ // CHECK-NEXT: ^bb3:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+ // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+ // CHECK-NEXT: ^bb4:
+ ^merge:
+ // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %i : i32
+ }
+
+ // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %final_i : i32
+
+ return
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// spirv.mlir.merge
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..95b87b319ac2d 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
spirv.Return
}
}
+
+// -----
+
+// Loop yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+ spirv.func @loop_yield(%count : i32) -> () "None" {
+ %zero = spirv.Constant 0: i32
+ %one = spirv.Constant 1: i32
+ %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+ %final_i = spirv.mlir.loop -> i32 {
+// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%zero: i32)
+
+// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+ ^header(%i : i32):
+ %cmp = spirv.SLessThan %i, %count : i32
+// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+ spirv.BranchConditional %cmp, ^body, ^merge
+
+// CHECK-NEXT: ^[[BODY:.+]]:
+ ^body:
+// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
+ spirv.Branch ^continue
+
+// CHECK-NEXT: ^[[CONTINUE:.+]]:
+ ^continue:
+ %new_i = spirv.IAdd %i, %one : i32
+// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+ spirv.Branch ^header(%new_i: i32)
+
+// CHECK-NEXT: ^[[MERGE:.+]]:
+ ^merge:
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+ spirv.mlir.merge %i : i32
+// CHECK-NEXT: }
+ }
+
+// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+ spirv.Store "Function" %var, %final_i : i32
+
+// CHECK-NEXT: spirv.Return
+ spirv.Return
+ }
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Ping @andfau-amd (but no pressure, I just came back from holiday, so I am pinging open PRs :)) |
Oh, thanks for the ping, I either missed this or forgot about it. Will take a look now! |
@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() { | |||
// block inside the selection (`body.back()`). Values produced by block | |||
// arguments will be yielded by the selection region. We do not update uses or | |||
// erase original block arguments yet. It will be done later in the code. | |||
if (!isLoop) { | |||
// We do not currently support block arguments in loop merge blocks. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this mean in practice?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you look above (lines 1971-1974) there is a piece of code specific to loops that fails if there are block arguments in the merge block, then the next statements (lines 1979-1980) copy arguments from the header block to the merge block. So, running this code for loops would yield arguments that were copied from the header block into the merge block. I'm not sure it's a desired behaviour or at least I don't have / can't think of an example where this is a desired behaviour. So, I don't run this code for loops. Worst case scenario, there will be some uses that escape the loop but that should be caught later. Maybe the comment in the code should make it clearer.
Hope that explains the rational and somehow answers the question.
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
2ecea55
to
1a5851a
Compare
I have updated the comment and rebased on top of main to include other SPIR-V patches I merged today - I wanted to make sure it all still builds and passes tests. I'll merge it once all checks passes. |
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
This change extends `spirv.mlir.loop` so it can yield values, the same as `spirv.mlir.selection`.
This change extends
spirv.mlir.loop
so it can yield values, the same asspirv.mlir.selection
.