-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[MLIR][SCF] Fix normalizeForallOp helper function #138615
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir-scf @llvm/pr-subscribers-mlir Author: Colin De Vlieghere (Cubevoid) ChangesPreviously the This patch fixes the helper function and adds a unit test for it. Full diff: https://github.com/llvm/llvm-project/pull/138615.diff 3 Files Affected:
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index e9471c1dbd0b7..d6bed551ec8fa 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1482,30 +1482,31 @@ FailureOr<scf::ForallOp> mlir::normalizeForallOp(RewriterBase &rewriter,
SmallVector<OpFoldResult> ubs = forallOp.getMixedUpperBound();
SmallVector<OpFoldResult> steps = forallOp.getMixedStep();
- if (llvm::all_of(
- lbs, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }) &&
- llvm::all_of(
- steps, [](OpFoldResult ofr) { return isConstantIntValue(ofr, 1); })) {
+ if (forallOp.isNormalized())
return forallOp;
- }
- SmallVector<OpFoldResult> newLbs, newUbs, newSteps;
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(forallOp);
+ SmallVector<OpFoldResult> newUbs;
for (auto [lb, ub, step] : llvm::zip_equal(lbs, ubs, steps)) {
Range normalizedLoopParams =
emitNormalizedLoopBounds(rewriter, forallOp.getLoc(), lb, ub, step);
- newLbs.push_back(normalizedLoopParams.offset);
newUbs.push_back(normalizedLoopParams.size);
- newSteps.push_back(normalizedLoopParams.stride);
}
+ (void)foldDynamicIndexList(newUbs);
+ // Use the normalized builder since the lower bounds are always 0 and the
+ // steps are always 1.
auto normalizedForallOp = rewriter.create<scf::ForallOp>(
- forallOp.getLoc(), newLbs, newUbs, newSteps, forallOp.getOutputs(),
- forallOp.getMapping(), [](OpBuilder &, Location, ValueRange) {});
+ forallOp.getLoc(), newUbs, forallOp.getOutputs(), forallOp.getMapping(),
+ [](OpBuilder &, Location, ValueRange) {});
rewriter.inlineRegionBefore(forallOp.getBodyRegion(),
normalizedForallOp.getBodyRegion(),
normalizedForallOp.getBodyRegion().begin());
+ // Remove the original empty block in the new loop.
+ rewriter.eraseBlock(&normalizedForallOp.getBodyRegion().back());
- rewriter.replaceAllOpUsesWith(forallOp, normalizedForallOp);
- return success();
+ rewriter.replaceOp(forallOp, normalizedForallOp);
+ return normalizedForallOp;
}
diff --git a/mlir/unittests/Dialect/SCF/CMakeLists.txt b/mlir/unittests/Dialect/SCF/CMakeLists.txt
index c0c1757b80fb5..83cefbcabf4d9 100644
--- a/mlir/unittests/Dialect/SCF/CMakeLists.txt
+++ b/mlir/unittests/Dialect/SCF/CMakeLists.txt
@@ -5,4 +5,5 @@ mlir_target_link_libraries(MLIRSCFTests
PRIVATE
MLIRIR
MLIRSCFDialect
+ MLIRSCFUtils
)
diff --git a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
index 53a4af14d119a..e4a3a857a747e 100644
--- a/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
+++ b/mlir/unittests/Dialect/SCF/LoopLikeSCFOpsTest.cpp
@@ -6,11 +6,15 @@
//
//===----------------------------------------------------------------------===//
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
+#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Interfaces/LoopLikeInterface.h"
#include "gtest/gtest.h"
using namespace mlir;
@@ -23,7 +27,7 @@ using namespace mlir::scf;
class SCFLoopLikeTest : public ::testing::Test {
protected:
SCFLoopLikeTest() : b(&context), loc(UnknownLoc::get(&context)) {
- context.loadDialect<arith::ArithDialect, scf::SCFDialect>();
+ context.loadDialect<affine::AffineDialect, arith::ArithDialect, scf::SCFDialect>();
}
void checkUnidimensional(LoopLikeOpInterface loopLikeOp) {
@@ -88,6 +92,24 @@ class SCFLoopLikeTest : public ::testing::Test {
EXPECT_EQ((*maybeInductionVars).size(), 2u);
}
+ void checkNormalized(LoopLikeOpInterface loopLikeOp) {
+ std::optional<SmallVector<OpFoldResult>> maybeLb =
+ loopLikeOp.getLoopLowerBounds();
+ ASSERT_TRUE(maybeLb.has_value());
+ std::optional<SmallVector<OpFoldResult>> maybeStep =
+ loopLikeOp.getLoopSteps();
+ ASSERT_TRUE(maybeStep.has_value());
+
+ auto allEqual = [](ArrayRef<OpFoldResult> results, int64_t val) {
+ return llvm::all_of(results, [&](OpFoldResult ofr) {
+ auto intValue = getConstantIntValue(ofr);
+ return intValue.has_value() && intValue == val;
+ });
+ };
+ EXPECT_TRUE(allEqual(*maybeLb, 0));
+ EXPECT_TRUE(allEqual(*maybeStep, 1));
+ }
+
MLIRContext context;
OpBuilder b;
Location loc;
@@ -138,3 +160,23 @@ TEST_F(SCFLoopLikeTest, queryMultidimensionalLooplikes) {
ValueRange({step->getResult(), step->getResult()}), ValueRange());
checkMultidimensional(parallelOp.get());
}
+
+TEST_F(SCFLoopLikeTest, testForallNormalize) {
+ OwningOpRef<arith::ConstantIndexOp> lb =
+ b.create<arith::ConstantIndexOp>(loc, 1);
+ OwningOpRef<arith::ConstantIndexOp> ub =
+ b.create<arith::ConstantIndexOp>(loc, 10);
+ OwningOpRef<arith::ConstantIndexOp> step =
+ b.create<arith::ConstantIndexOp>(loc, 3);
+
+ scf::ForallOp forallOp = b.create<scf::ForallOp>(
+ loc, ArrayRef<OpFoldResult>({lb->getResult(), lb->getResult()}),
+ ArrayRef<OpFoldResult>({ub->getResult(), ub->getResult()}),
+ ArrayRef<OpFoldResult>({step->getResult(), step->getResult()}),
+ ValueRange(), std::nullopt);
+ IRRewriter rewriter(b);
+ FailureOr<scf::ForallOp> maybeNormalizedForallOp = normalizeForallOp(rewriter, forallOp);
+ EXPECT_TRUE(succeeded(maybeNormalizedForallOp));
+ OwningOpRef<scf::ForallOp> normalizedForallOp(*maybeNormalizedForallOp);
+ checkNormalized(normalizedForallOp.get());
+}
|
cd9fa1a
to
58c9fc2
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the fix!
Previously the `normalizeForallOp` function did not work properly, since the newly created op was not being returned in addition to the op failing verification. This patch fixes the helper function and adds a unit test for it.
58c9fc2
to
e1acacc
Compare
Rebased, and check-mlir passes. |
I realized that the previous code was also incorrect because I forgot to update the induction variable users in the loops. Now I added a call to |
Previously the
normalizeForallOp
function did not work properly, since the newly created op was not being returned in addition to the op failing verification.This patch fixes the helper function and adds a unit test for it.