-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[flang][fir] Support promoting fir.do_loop with results to affine.for.
#137790
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-flang-fir-hlfir
Author: MingYan (NexMing)
Changes
Full diff: https://github.com/llvm/llvm-project/pull/137790.diff
2 Files Affected:
diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
index 43fccf52dc8ab..ef82e400bea14 100644
--- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
+++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp
@@ -49,8 +49,9 @@ struct AffineIfAnalysis;
/// second when doing rewrite.
struct AffineFunctionAnalysis {
explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) {
- for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>())
- loopAnalysisMap.try_emplace(op, op, *this);
+ funcOp->walk([&](fir::DoLoopOp doloop) {
+ loopAnalysisMap.try_emplace(doloop, doloop, *this);
+ });
}
AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const;
@@ -102,10 +103,23 @@ struct AffineLoopAnalysis {
return true;
}
+ bool analysisResults(fir::DoLoopOp loopOperation) {
+ if (loopOperation.getFinalValue() &&
+ !loopOperation.getResult(0).use_empty()) {
+ LLVM_DEBUG(
+ llvm::dbgs()
+ << "AffineLoopAnalysis: cannot promote loop final value\n";);
+ return false;
+ }
+
+ return true;
+ }
+
bool analyzeLoop(fir::DoLoopOp loopOperation,
AffineFunctionAnalysis &functionAnalysis) {
LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump(););
return analyzeMemoryAccess(loopOperation) &&
+ analysisResults(loopOperation) &&
analyzeBody(loopOperation, functionAnalysis);
}
@@ -461,14 +475,28 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
LLVM_ATTRIBUTE_UNUSED auto loopAnalysis =
functionAnalysis.getChildLoopAnalysis(loop);
auto &loopOps = loop.getBody()->getOperations();
+ auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator());
+ auto results = resultOp.getOperands();
+ auto loopResults = loop->getResults();
auto loopAndIndex = createAffineFor(loop, rewriter);
auto affineFor = loopAndIndex.first;
auto inductionVar = loopAndIndex.second;
+ if (loop.getFinalValue()) {
+ results = results.drop_front();
+ loopResults = loopResults.drop_front();
+ }
+
rewriter.startOpModification(affineFor.getOperation());
affineFor.getBody()->getOperations().splice(
std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(),
std::prev(loopOps.end()));
+ rewriter.replaceAllUsesWith(loop.getRegionIterArgs(),
+ affineFor.getRegionIterArgs());
+ if (!results.empty()) {
+ rewriter.setInsertionPointToEnd(affineFor.getBody());
+ rewriter.create<affine::AffineYieldOp>(resultOp->getLoc(), results);
+ }
rewriter.finalizeOpModification(affineFor.getOperation());
rewriter.startOpModification(loop.getOperation());
@@ -479,7 +507,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n";
affineFor.dump(););
- rewriter.replaceOp(loop, affineFor.getOperation()->getResults());
+ rewriter.replaceAllUsesWith(loopResults, affineFor->getResults());
+ rewriter.eraseOp(loop);
return success();
}
@@ -503,7 +532,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
ValueRange(op.getUpperBound()),
mlir::AffineMap::get(0, 1,
1 + mlir::getAffineSymbolExpr(0, op.getContext())),
- step);
+ step, op.getIterOperands());
return std::make_pair(affineFor, affineFor.getInductionVar());
}
@@ -528,7 +557,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
genericUpperBound.getResult(),
mlir::AffineMap::get(0, 1,
1 + mlir::getAffineSymbolExpr(0, op.getContext())),
- 1);
+ 1, op.getIterOperands());
rewriter.setInsertionPointToStart(affineFor.getBody());
auto actualIndex = rewriter.create<affine::AffineApplyOp>(
op.getLoc(), actualIndexMap,
diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir
index aae35c6ef5659..f50f851a89eae 100644
--- a/flang/test/Fir/affine-promotion.fir
+++ b/flang/test/Fir/affine-promotion.fir
@@ -131,3 +131,68 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) {
// CHECK: }
// CHECK: return
// CHECK: }
+
+func.func @loop_with_result(%arg0: !fir.ref<!fir.array<100xf32>>, %arg1: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+ %c1 = arith.constant 1 : index
+ %cst = arith.constant 0.000000e+00 : f32
+ %c100 = arith.constant 100 : index
+ %0 = fir.shape %c100 : (index) -> !fir.shape<1>
+ %1 = fir.shape %c100, %c100 : (index, index) -> !fir.shape<2>
+ %2 = fir.alloca i32
+ %3:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %cst) -> (index, f32) {
+ %6 = fir.array_coor %arg0(%0) %arg2 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+ %7 = fir.load %6 : !fir.ref<f32>
+ %8 = arith.addf %arg3, %7 fastmath<contract> : f32
+ %9 = arith.addi %arg2, %c1 overflow<nsw> : index
+ fir.result %9, %8 : index, f32
+ }
+ %4:2 = fir.do_loop %arg2 = %c1 to %c100 step %c1 iter_args(%arg3 = %3#1) -> (index, f32) {
+ %6 = fir.array_coor %arg1(%1) %c1, %arg2 : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+ %7 = fir.convert %6 : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
+ %8 = fir.do_loop %arg4 = %c1 to %c100 step %c1 iter_args(%arg5 = %arg3) -> (f32) {
+ %10 = fir.array_coor %7(%0) %arg4 : (!fir.ref<!fir.array<100xf32>>, !fir.shape<1>, index) -> !fir.ref<f32>
+ %11 = fir.load %10 : !fir.ref<f32>
+ %12 = arith.addf %arg5, %11 fastmath<contract> : f32
+ fir.result %12 : f32
+ }
+ %9 = arith.addi %arg2, %c1 overflow<nsw> : index
+ fir.result %9, %8 : index, f32
+ }
+ %5 = fir.convert %4#0 : (index) -> i32
+ fir.store %5 to %2 : !fir.ref<i32>
+ return %4#1 : f32
+}
+
+// CHECK-LABEL: func.func @loop_with_result(
+// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.array<100xf32>>,
+// CHECK-SAME: %[[ARG1:.*]]: !fir.ref<!fir.array<100x100xf32>>) -> f32 {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shape<2>
+// CHECK: %[[VAL_5:.*]] = fir.alloca i32
+// CHECK: %[[VAL_6:.*]] = fir.convert %[[ARG0]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
+// CHECK: %[[VAL_7:.*]] = affine.for %[[VAL_8:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_9:.*]] = %[[VAL_1]]) -> (f32) {
+// CHECK: %[[VAL_10:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
+// CHECK: %[[VAL_11:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref<?xf32>
+// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] fastmath<contract> : f32
+// CHECK: affine.yield %[[VAL_12]] : f32
+// CHECK: }
+// CHECK: %[[VAL_13:.*]]:2 = fir.do_loop %[[VAL_14:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] iter_args(%[[VAL_15:.*]] = %[[VAL_7]]) -> (index, f32) {
+// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[ARG1]](%[[VAL_4]]) %[[VAL_0]], %[[VAL_14]] : (!fir.ref<!fir.array<100x100xf32>>, !fir.shape<2>, index, index) -> !fir.ref<f32>
+// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.ref<f32>) -> !fir.ref<!fir.array<100xf32>>
+// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (!fir.ref<!fir.array<100xf32>>) -> memref<?xf32>
+// CHECK: %[[VAL_19:.*]] = affine.for %[[VAL_20:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f32) {
+// CHECK: %[[VAL_22:.*]] = affine.apply #{{.*}}(%[[VAL_20]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]]
+// CHECK: %[[VAL_23:.*]] = affine.load %[[VAL_18]]{{\[}}%[[VAL_22]]] : memref<?xf32>
+// CHECK: %[[VAL_24:.*]] = arith.addf %[[VAL_21]], %[[VAL_23]] fastmath<contract> : f32
+// CHECK: affine.yield %[[VAL_24]] : f32
+// CHECK: }
+// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_0]] overflow<nsw> : index
+// CHECK: fir.result %[[VAL_25]], %[[VAL_19]] : index, f32
+// CHECK: }
+// CHECK: %[[VAL_26:.*]] = fir.convert %[[VAL_27:.*]]#0 : (index) -> i32
+// CHECK: fir.store %[[VAL_26]] to %[[VAL_5]] : !fir.ref<i32>
+// CHECK: return %[[VAL_27]]#1 : f32
+// CHECK: }
|
Thank you for contributing this. After this patch, do you know of any remaining limitations of the affine promotion pass? And if you don't mind me asking, what are you using this for?
I haven't done detailed testing, so I'm not sure how many features are still missing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for upstreaming this. Just a minor comment from me.
This pass was always experimental, so I don't think this needs to work in every case before it can be merged.
If you plan to upstream your whole optimization pipeline (very welcome!) then please create an RFC at https://discourse.llvm.org/c/subprojects/flang/33
fir.do_loop with results to affine.for.
fir.do_loop with results to affine.for.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
No description provided.