Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[MLIR][SCF] Fix normalizeForallOp helper function #138615

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

Cubevoid
Copy link

@Cubevoid Cubevoid commented May 5, 2025

Previously the normalizeForallOp function did not work properly, since the newly created op was not being returned in addition to the op failing verification.

This patch fixes the helper function and adds a unit test for it.

Copy link

github-actions bot commented May 5, 2025

Thank you for submitting a Pull Request (PR) to the LLVM Project!

This PR will be automatically labeled and the relevant teams will be notified.

If you wish to, you can add reviewers by using the "Reviewers" section on this page.

If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using @ followed by their GitHub username.

If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers.

If you have further questions, they may be answered by the LLVM GitHub User Guide.

You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums.

@llvmbot
Copy link
Member

llvmbot commented May 5, 2025

@llvm/pr-subscribers-mlir-scf

@llvm/pr-subscribers-mlir

Author: Colin De Vlieghere (Cubevoid)

Changes

Previously the normalizeForallOp function did not work properly, since the newly created op was not being returned in addition to the op failing verification.

This patch fixes the helper function and adds a unit test for it.


Full diff: https://github.com/llvm/llvm-project/pull/138615.diff

3 Files Affected:

  • (modified) mlir/lib/Dialect/SCF/Utils/Utils.cpp (+13-12)
  • (modified) mlir/unittests/Dialect/SCF/CMakeLists.txt (+1)
  • (modified) mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp (+43-1)
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e9471c1dbd0b7..d6bed551ec8fa 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1482,30 +1482,31 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
   SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
   SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
 
-  if (llvm::all_of(
-          lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
-      llvm::all_of(
-          steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+  if (forallOp.isNormalized())
     return forallOp;
-  }
 
-  SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(forallOp);
+  SmallVector<OpFoldResult> newUbs;
   for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
     Range normalizedLoopParams =
         emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
-    newLbs.push_back(normalizedLoopParams.offset);
     newUbs.push_back(normalizedLoopParams.size);
-    newSteps.push_back(normalizedLoopParams.stride);
   }
+  (void)foldDynamicIndexList(newUbs);
 
+  // Use the normalized builder since the lower bounds are always 0 and the
+  // steps are always 1.
   auto normalizedForallOp = rewriter.create<scf::ForallOp>(
-      forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
-      forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+      forallOp.getLoc(), newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+      [](OpBuilder &, Location, ValueRange) {});
 
   rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
                               normalizedForallOp.getBodyRegion(),
                               normalizedForallOp.getBodyRegion().begin());
+  // Remove the original empty block in the new loop.
+  rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
 
-  rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
-  return success();
+  rewriter.replaceOp(forallOp, normalizedForallOp);
+  return normalizedForallOp;
 }
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
index c0c1757b80fb5..83cefbcabf4d9 100644
--- a/mlir/unittests/Dialect/SCF/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
   PRIVATE
   MLIRIR
   MLIRSCFDialect
+  MLIRSCFUtils
 )
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 53a4af14d119a..e4a3a857a747e 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -6,11 +6,15 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
 #include "mlir/IR/Diagnostics.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
 #include "gtest/gtest.h"
 
 using namespace mlir;
@@ -23,7 +27,7 @@ using namespace mlir::scf;
 class SCFLoopLikeTest : public ::testing::Test {
 protected:
   SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
-    context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
+    context.loadDialect<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
   }
 
   void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +92,24 @@ class SCFLoopLikeTest : public ::testing::Test {
     EXPECT_EQ((*maybeInductionVars).size(), 2u);
   }
 
+  void checkNormalized(LoopLikeOpInterface loopLikeOp) {
+    std::optional<SmallVector<OpFoldResult>> maybeLb =
+        loopLikeOp.getLoopLowerBounds();
+    ASSERT_TRUE(maybeLb.has_value());
+    std::optional<SmallVector<OpFoldResult>> maybeStep =
+        loopLikeOp.getLoopSteps();
+    ASSERT_TRUE(maybeStep.has_value());
+
+    auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+      return llvm::all_of(results, [&](OpFoldResult ofr) {
+        auto intValue = getConstantIntValue(ofr);
+        return intValue.has_value() && intValue == val;
+      });
+    };
+    EXPECT_TRUE(allEqual(*maybeLb, 0));
+    EXPECT_TRUE(allEqual(*maybeStep, 1));
+  }
+
   MLIRContext context;
   OpBuilder b;
   Location loc;
@@ -138,3 +160,23 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
       ValueRange({step->getResult(), step->getResult()}), ValueRange());
   checkMultidimensional(parallelOp.get());
 }
+
+TEST_F(SCFLoopLikeTest, testForallNormalize) {
+  OwningOpRef<arith::ConstantIndexOp> lb =
+      b.create<arith::ConstantIndexOp>(loc, 1);
+  OwningOpRef<arith::ConstantIndexOp> ub =
+      b.create<arith::ConstantIndexOp>(loc, 10);
+  OwningOpRef<arith::ConstantIndexOp> step =
+      b.create<arith::ConstantIndexOp>(loc, 3);
+
+  scf::ForallOp forallOp = b.create<scf::ForallOp>(
+      loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
+      ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
+      ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
+      ValueRange(), std::nullopt);
+  IRRewriter rewriter(b);
+  FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp);
+  EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
+  OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
+  checkNormalized(normalizedForallOp.get());
+}

@Cubevoid Cubevoid force-pushed the mlir/scf/fix_forall_normalize branch from cd9fa1a to 58c9fc2 Compare May 6, 2025 02:10
Copy link
Contributor

@mfrancio mfrancio left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

Previously the `normalizeForallOp` function did not work properly, since the
newly created op was not being returned in addition to the op failing
verification.

This patch fixes the helper function and adds a unit test for it.
@Cubevoid Cubevoid force-pushed the mlir/scf/fix_forall_normalize branch from 58c9fc2 to e1acacc Compare May 7, 2025 17:47
@Cubevoid
Copy link
Author

Cubevoid commented May 7, 2025

Rebased, and check-mlir passes.

@Cubevoid
Copy link
Author

Cubevoid commented May 7, 2025

I realized that the previous code was also incorrect because I forgot to update the induction variable users in the loops. Now I added a call to denormalizeInductionVariable to update the users, and the unit test checks that the IV user no longer points to the normalized IV.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants