Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[mlir][spirv] Allow yielding values from loop regions #135344

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 30, 2025

Conversation

IgWod-IMG
Copy link
Contributor

This change extends spirv.mlir.loop so it can yield values, the same as spirv.mlir.selection.

@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir-spirv

@llvm/pr-subscribers-mlir

Author: Igor Wodiany (IgWod-IMG)

Changes

This change extends spirv.mlir.loop so it can yield values, the same as spirv.mlir.selection.


Full diff: https://github.com/llvm/llvm-project/pull/135344.diff

7 Files Affected:

  • (modified) mlir/docs/Dialects/SPIR-V.md (+12)
  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td (+12-2)
  • (modified) mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp (+9)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+46-33)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+7)
  • (modified) mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir (+41)
  • (modified) mlir/test/Target/SPIRV/loop.mlir (+47)
diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md
index ae9afbd9fdfe5..1e8c1c7be9f6a 100644
--- a/mlir/docs/Dialects/SPIR-V.md
+++ b/mlir/docs/Dialects/SPIR-V.md
@@ -734,6 +734,18 @@ func.func @loop(%count : i32) -> () {
 }
 ```
 
+Similarly to selection, loops can also yield values using `spirv.mlir.merge`. This
+mechanism allows values defined within the loop region to be used outside of it.
+
+For example
+
+```mlir
+%yielded = spirv.mlir.loop -> i32 {
+  // ...
+  spirv.mlir.merge %to_yield : i32
+}
+```
+
 ### Block argument for Phi
 
 There are no direct Phi operations in the SPIR-V dialect; SPIR-V `OpPhi`
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
index 039af03871411..ef6682ab3630c 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td
@@ -311,17 +311,27 @@ def SPIRV_LoopOp : SPIRV_Op<"mlir.loop", [InFunctionScope]> {
     The continue block should be the second to last block and it should have a
     branch to the loop header block. The loop continue block should be the only
     block, except the entry block, branching to the header block.
+
+    Values defined inside the loop regions cannot be directly used
+    outside of them; however, the loop region can yield values. These values are
+    yielded using a `spirv.mlir.merge` op and returned as a result of the loop op.
   }];
 
   let arguments = (ins
     SPIRV_LoopControlAttr:$loop_control
   );
 
-  let results = (outs);
+  let results = (outs Variadic<AnyType>:$results);
 
   let regions = (region AnyRegion:$body);
 
-  let builders = [OpBuilder<(ins)>];
+  let builders = [
+    OpBuilder<(ins)>,
+    OpBuilder<(ins "spirv::LoopControl":$loopControl),
+    [{
+      build($_builder, $_state, TypeRange(), loopControl);
+    }]>
+  ];
 
   let extraClassDeclaration = [{
     // Returns the entry block.
diff --git a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
index ed9a30086deca..cf983af6a07ac 100644
--- a/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/ControlFlowOps.cpp
@@ -230,6 +230,11 @@ ParseResult LoopOp::parse(OpAsmParser &parser, OperationState &result) {
   if (parseControlAttribute<spirv::LoopControlAttr, spirv::LoopControl>(parser,
                                                                         result))
     return failure();
+
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(result.types))
+      return failure();
+
   return parser.parseRegion(*result.addRegion(), /*arguments=*/{});
 }
 
@@ -237,6 +242,10 @@ void LoopOp::print(OpAsmPrinter &printer) {
   auto control = getLoopControl();
   if (control != spirv::LoopControl::None)
     printer << " control(" << spirv::stringifyLoopControl(control) << ")";
+  if (getNumResults() > 0) {
+    printer << " -> ";
+    printer << getResultTypes();
+  }
   printer << ' ';
   printer.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
                       /*printBlockTerminators=*/true);
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 25749ec598f00..2c7a93949307c 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() {
   // block inside the selection (`body.back()`). Values produced by block
   // arguments will be yielded by the selection region. We do not update uses or
   // erase original block arguments yet. It will be done later in the code.
-  if (!isLoop) {
+  // We do not currently support block arguments in loop merge blocks.
+  if (!isLoop)
     for (BlockArgument blockArg : mergeBlock->getArguments()) {
       // Create new block arguments in the last block ("merge block") of the
       // selection region. We create one argument for each argument in
@@ -2013,7 +2014,6 @@ LogicalResult ControlFlowStructurizer::structurize() {
       valuesToYield.push_back(body.back().getArguments().back());
       outsideUses.push_back(blockArg);
     }
-  }
 
   // All the blocks cloned into the SelectionOp/LoopOp's region can now be
   // cleaned up.
@@ -2025,32 +2025,30 @@ LogicalResult ControlFlowStructurizer::structurize() {
 
   // All internal uses should be removed from original blocks by now, so
   // whatever is left is an outside use and will need to be yielded from
-  // the newly created selection region.
-  if (!isLoop) {
-    for (Block *block : constructBlocks) {
-      for (Operation &op : *block) {
-        if (!op.use_empty())
-          for (Value result : op.getResults()) {
-            valuesToYield.push_back(mapper.lookupOrNull(result));
-            outsideUses.push_back(result);
-          }
-      }
-      for (BlockArgument &arg : block->getArguments()) {
-        if (!arg.use_empty()) {
-          valuesToYield.push_back(mapper.lookupOrNull(arg));
-          outsideUses.push_back(arg);
+  // the newly created selection / loop region.
+  for (Block *block : constructBlocks) {
+    for (Operation &op : *block) {
+      if (!op.use_empty())
+        for (Value result : op.getResults()) {
+          valuesToYield.push_back(mapper.lookupOrNull(result));
+          outsideUses.push_back(result);
         }
+    }
+    for (BlockArgument &arg : block->getArguments()) {
+      if (!arg.use_empty()) {
+        valuesToYield.push_back(mapper.lookupOrNull(arg));
+        outsideUses.push_back(arg);
       }
     }
   }
 
   assert(valuesToYield.size() == outsideUses.size());
 
-  // If we need to yield any values from the selection region we will take
-  // care of it here.
-  if (!isLoop && !valuesToYield.empty()) {
+  // If we need to yield any values from the selection / loop region we will
+  // take care of it here.
+  if (!valuesToYield.empty()) {
     LLVM_DEBUG(logger.startLine()
-               << "[cf] yielding values from the selection region\n");
+               << "[cf] yielding values from the selection / loop region\n");
 
     // Update `mlir.merge` with values to be yield.
     auto mergeOps = body.back().getOps<spirv::MergeOp>();
@@ -2059,25 +2057,40 @@ LogicalResult ControlFlowStructurizer::structurize() {
     merge->setOperands(valuesToYield);
 
     // MLIR does not allow changing the number of results of an operation, so
-    // we create a new SelectionOp with required list of results and move
-    // the region from the initial SelectionOp. The initial operation is then
-    // removed. Since we move the region to the new op all links between blocks
-    // and remapping we have previously done should be preserved.
+    // we create a new SelectionOp / LoopOp with required list of results and
+    // move the region from the initial SelectionOp / LoopOp. The initial
+    // operation is then removed. Since we move the region to the new op all
+    // links between blocks and remapping we have previously done should be
+    // preserved.
     builder.setInsertionPoint(&mergeBlock->front());
-    auto selectionOp = builder.create<spirv::SelectionOp>(
-        location, TypeRange(ValueRange(outsideUses)),
-        static_cast<spirv::SelectionControl>(control));
-    selectionOp->getRegion(0).takeBody(body);
+
+    Operation *newOp = nullptr;
+
+    if (isLoop)
+      newOp = builder.create<spirv::LoopOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::LoopControl>(control));
+    else
+      newOp = builder.create<spirv::SelectionOp>(
+          location, TypeRange(ValueRange(outsideUses)),
+          static_cast<spirv::SelectionControl>(control));
+
+    newOp->getRegion(0).takeBody(body);
 
     // Remove initial op and swap the pointer to the newly created one.
     op->erase();
-    op = selectionOp;
+    op = newOp;
 
-    // Update all outside uses to use results of the SelectionOp and remove
-    // block arguments from the original merge block.
+    // Update all outside uses to use results of the SelectionOp / LoopOp and
+    // remove block arguments from the original merge block.
     for (unsigned i = 0, e = outsideUses.size(); i != e; ++i)
-      outsideUses[i].replaceAllUsesWith(selectionOp.getResult(i));
-    mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
+      outsideUses[i].replaceAllUsesWith(op->getResult(i));
+
+    // We do not support block arguments in loop merge block. Also running this
+    // function with loop would break some of the loop specific code above
+    // dealing with block arguments.
+    if (!isLoop)
+      mergeBlock->eraseArguments(0, mergeBlock->getNumArguments());
   }
 
   // Check that whether some op in the to-be-erased blocks still has uses. Those
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 5ed59a4134d37..ff3cc92ee8078 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -520,6 +520,13 @@ LogicalResult Serializer::processLoopOp(spirv::LoopOp loopOp) {
   auto mergeID = getBlockID(mergeBlock);
   auto loc = loopOp.getLoc();
 
+  // Before we do anything replace results of the selection operation with
+  // values yielded (with `mlir.merge`) from inside the region.
+  auto mergeOp = cast<spirv::MergeOp>(mergeBlock->back());
+  assert(loopOp.getNumResults() == mergeOp.getNumOperands());
+  for (unsigned i = 0, e = loopOp.getNumResults(); i != e; ++i)
+    loopOp.getResult(i).replaceAllUsesWith(mergeOp.getOperand(i));
+
   // This LoopOp is in some MLIR block with preceding and following ops. In the
   // binary format, it should reside in separate SPIR-V blocks from its
   // preceding and following ops. So we need to emit unconditional branches to
diff --git a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
index 107c8a3207b02..8ec0bf5bbaacf 100644
--- a/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/control-flow-ops.mlir
@@ -426,6 +426,47 @@ func.func @only_entry_and_continue_branch_to_header() -> () {
 
 // -----
 
+func.func @loop_yield(%count : i32) -> () {
+  %zero = spirv.Constant 0: i32
+  %one = spirv.Constant 1: i32
+  %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+  // CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+  %final_i = spirv.mlir.loop -> i32 {
+    // CHECK-NEXT: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%zero: i32)
+
+  // CHECK-NEXT: ^bb1({{%.*}}: i32):
+  ^header(%i : i32):
+    %cmp = spirv.SLessThan %i, %count : i32
+    // CHECK: spirv.BranchConditional %{{.*}}, ^bb2, ^bb4
+    spirv.BranchConditional %cmp, ^body, ^merge
+
+  // CHECK-NEXT: ^bb2:
+  ^body:
+    // CHECK-NEXT: spirv.Branch ^bb3
+    spirv.Branch ^continue
+
+  // CHECK-NEXT: ^bb3:
+  ^continue:
+    %new_i = spirv.IAdd %i, %one : i32
+    // CHECK: spirv.Branch ^bb1({{%.*}}: i32)
+    spirv.Branch ^header(%new_i: i32)
+
+  // CHECK-NEXT: ^bb4:
+  ^merge:
+    // CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+    spirv.mlir.merge %i : i32
+  }
+
+  // CHECK: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+  spirv.Store "Function" %var, %final_i : i32
+
+  return
+}
+
+// -----
+
 //===----------------------------------------------------------------------===//
 // spirv.mlir.merge
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Target/SPIRV/loop.mlir b/mlir/test/Target/SPIRV/loop.mlir
index d89600558f56d..95b87b319ac2d 100644
--- a/mlir/test/Target/SPIRV/loop.mlir
+++ b/mlir/test/Target/SPIRV/loop.mlir
@@ -288,3 +288,50 @@ spirv.module Physical64 OpenCL requires #spirv.vce<v1.0, [Kernel, Linkage, Addre
     spirv.Return
   }
 }
+
+// -----
+
+// Loop yielding values
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader, Linkage], []> {
+  spirv.func @loop_yield(%count : i32) -> () "None" {
+    %zero = spirv.Constant 0: i32
+    %one = spirv.Constant 1: i32
+    %var = spirv.Variable init(%zero) : !spirv.ptr<i32, Function>
+
+// CHECK: {{%.*}} = spirv.mlir.loop -> i32 {
+    %final_i = spirv.mlir.loop -> i32 {
+// CHECK-NEXT: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%zero: i32)
+
+// CHECK-NEXT: ^[[HEADER:.+]]({{%.*}}: i32):
+    ^header(%i : i32):
+      %cmp = spirv.SLessThan %i, %count : i32
+// CHECK: spirv.BranchConditional %{{.*}}, ^[[BODY:.+]], ^[[MERGE:.+]]
+      spirv.BranchConditional %cmp, ^body, ^merge
+
+// CHECK-NEXT: ^[[BODY:.+]]:
+    ^body:
+// CHECK-NEXT: spirv.Branch ^[[CONTINUE:.+]]
+      spirv.Branch ^continue
+
+// CHECK-NEXT: ^[[CONTINUE:.+]]:
+    ^continue:
+      %new_i = spirv.IAdd %i, %one : i32
+// CHECK: spirv.Branch ^[[HEADER:.+]]({{%.*}}: i32)
+      spirv.Branch ^header(%new_i: i32)
+
+// CHECK-NEXT: ^[[MERGE:.+]]:
+    ^merge:
+// CHECK-NEXT: spirv.mlir.merge {{%.*}} : i32
+      spirv.mlir.merge %i : i32
+// CHECK-NEXT: }
+    }
+
+// CHECK-NEXT: spirv.Store "Function" {{%.*}}, {{%.*}} : i32
+    spirv.Store "Function" %var, %final_i : i32
+
+// CHECK-NEXT: spirv.Return
+    spirv.Return
+  }
+}

@kuhar kuhar requested a review from andfau-amd April 11, 2025 14:11
Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@IgWod-IMG
Copy link
Contributor Author

Ping @andfau-amd (but no pressure, I just came back from holiday, so I am pinging open PRs :))

@andfau-amd
Copy link
Contributor

Oh, thanks for the ping, I either missed this or forgot about it. Will take a look now!

@@ -2003,7 +2003,8 @@ LogicalResult ControlFlowStructurizer::structurize() {
// block inside the selection (`body.back()`). Values produced by block
// arguments will be yielded by the selection region. We do not update uses or
// erase original block arguments yet. It will be done later in the code.
if (!isLoop) {
// We do not currently support block arguments in loop merge blocks.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does this mean in practice?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you look above (lines 1971-1974) there is a piece of code specific to loops that fails if there are block arguments in the merge block, then the next statements (lines 1979-1980) copy arguments from the header block to the merge block. So, running this code for loops would yield arguments that were copied from the header block into the merge block. I'm not sure it's a desired behaviour or at least I don't have / can't think of an example where this is a desired behaviour. So, I don't run this code for loops. Worst case scenario, there will be some uses that escape the loop but that should be caught later. Maybe the comment in the code should make it clearer.

Hope that explains the rational and somehow answers the question.

This change extends `spirv.mlir.loop` so it can yield values, the
same as `spirv.mlir.selection`.
@IgWod-IMG
Copy link
Contributor Author

I have updated the comment and rebased on top of main to include other SPIR-V patches I merged today - I wanted to make sure it all still builds and passes tests. I'll merge it once all checks passes.

@IgWod-IMG IgWod-IMG merged commit 721c5cc into llvm:main Apr 30, 2025
11 checks passed
@IgWod-IMG IgWod-IMG deleted the img_loop-yield branch April 30, 2025 15:07
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This change extends `spirv.mlir.loop` so it can yield values, the same
as `spirv.mlir.selection`.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This change extends `spirv.mlir.loop` so it can yield values, the same
as `spirv.mlir.selection`.
IanWood1 pushed a commit to IanWood1/llvm-project that referenced this pull request May 6, 2025
This change extends `spirv.mlir.loop` so it can yield values, the same
as `spirv.mlir.selection`.
GeorgeARM pushed a commit to GeorgeARM/llvm-project that referenced this pull request May 7, 2025
This change extends `spirv.mlir.loop` so it can yield values, the same
as `spirv.mlir.selection`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants