diff --git a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp index 43fccf52dc8ab..ef82e400bea14 100644 --- a/flang/lib/Optimizer/Transforms/AffinePromotion.cpp +++ b/flang/lib/Optimizer/Transforms/AffinePromotion.cpp @@ -49,8 +49,9 @@ struct AffineIfAnalysis; /// second when doing rewrite. struct AffineFunctionAnalysis { explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { - for (fir::DoLoopOp op : funcOp.getOps()) - loopAnalysisMap.try_emplace(op, op, *this); + funcOp->walk([&](fir::DoLoopOp doloop) { + loopAnalysisMap.try_emplace(doloop, doloop, *this); + }); } AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; @@ -102,10 +103,23 @@ struct AffineLoopAnalysis { return true; } + bool analysisResults(fir::DoLoopOp loopOperation) { + if (loopOperation.getFinalValue() && + !loopOperation.getResult(0).use_empty()) { + LLVM_DEBUG( + llvm::dbgs() + << "AffineLoopAnalysis: cannot promote loop final value\n";); + return false; + } + + return true; + } + bool analyzeLoop(fir::DoLoopOp loopOperation, AffineFunctionAnalysis &functionAnalysis) { LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n"; loopOperation.dump();); return analyzeMemoryAccess(loopOperation) && + analysisResults(loopOperation) && analyzeBody(loopOperation, functionAnalysis); } @@ -461,14 +475,28 @@ class AffineLoopConversion : public mlir::OpRewritePattern { LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = functionAnalysis.getChildLoopAnalysis(loop); auto &loopOps = loop.getBody()->getOperations(); + auto resultOp = cast(loop.getBody()->getTerminator()); + auto results = resultOp.getOperands(); + auto loopResults = loop->getResults(); auto loopAndIndex = createAffineFor(loop, rewriter); auto affineFor = loopAndIndex.first; auto inductionVar = loopAndIndex.second; + if (loop.getFinalValue()) { + results = results.drop_front(); + loopResults = loopResults.drop_front(); + } + rewriter.startOpModification(affineFor.getOperation()); affineFor.getBody()->getOperations().splice( std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), std::prev(loopOps.end())); + rewriter.replaceAllUsesWith(loop.getRegionIterArgs(), + affineFor.getRegionIterArgs()); + if (!results.empty()) { + rewriter.setInsertionPointToEnd(affineFor.getBody()); + rewriter.create(resultOp->getLoc(), results); + } rewriter.finalizeOpModification(affineFor.getOperation()); rewriter.startOpModification(loop.getOperation()); @@ -479,7 +507,8 @@ class AffineLoopConversion : public mlir::OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n"; affineFor.dump();); - rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); + rewriter.replaceAllUsesWith(loopResults, affineFor->getResults()); + rewriter.eraseOp(loop); return success(); } @@ -503,7 +532,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern { ValueRange(op.getUpperBound()), mlir::AffineMap::get(0, 1, 1 + mlir::getAffineSymbolExpr(0, op.getContext())), - step); + step, op.getIterOperands()); return std::make_pair(affineFor, affineFor.getInductionVar()); } @@ -528,7 +557,7 @@ class AffineLoopConversion : public mlir::OpRewritePattern { genericUpperBound.getResult(), mlir::AffineMap::get(0, 1, 1 + mlir::getAffineSymbolExpr(0, op.getContext())), - 1); + 1, op.getIterOperands()); rewriter.setInsertionPointToStart(affineFor.getBody()); auto actualIndex = rewriter.create( op.getLoc(), actualIndexMap, diff --git a/flang/test/Fir/affine-promotion.fir b/flang/test/Fir/affine-promotion.fir index aae35c6ef5659..46467ab4a292a 100644 --- a/flang/test/Fir/affine-promotion.fir +++ b/flang/test/Fir/affine-promotion.fir @@ -131,3 +131,89 @@ func.func @loop_with_if(%a: !arr_d1, %v: f32) { // CHECK: } // CHECK: return // CHECK: } + +func.func @loop_with_result(%arg0: !fir.ref>, %arg1: !fir.ref>, %arg2: !fir.ref>) -> f32 { + %c1 = arith.constant 1 : index + %cst = arith.constant 0.000000e+00 : f32 + %c100 = arith.constant 100 : index + %0 = fir.shape %c100 : (index) -> !fir.shape<1> + %1 = fir.shape %c100, %c100 : (index, index) -> !fir.shape<2> + %2 = fir.alloca i32 + %3:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %cst) -> (index, f32) { + %8 = fir.array_coor %arg0(%0) %arg3 : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref + %9 = fir.load %8 : !fir.ref + %10 = arith.addf %arg4, %9 fastmath : f32 + %11 = arith.addi %arg3, %c1 overflow : index + fir.result %11, %10 : index, f32 + } + %4:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %3#1) -> (index, f32) { + %8 = fir.array_coor %arg1(%1) %c1, %arg3 : (!fir.ref>, !fir.shape<2>, index, index) -> !fir.ref + %9 = fir.convert %8 : (!fir.ref) -> !fir.ref> + %10 = fir.do_loop %arg5 = %c1 to %c100 step %c1 iter_args(%arg6 = %arg4) -> (f32) { + %12 = fir.array_coor %9(%0) %arg5 : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref + %13 = fir.load %12 : !fir.ref + %14 = arith.addf %arg6, %13 fastmath : f32 + fir.result %14 : f32 + } + %11 = arith.addi %arg3, %c1 overflow : index + fir.result %11, %10 : index, f32 + } + %5:2 = fir.do_loop %arg3 = %c1 to %c100 step %c1 iter_args(%arg4 = %4#1, %arg5 = %cst) -> (f32, f32) { + %8 = fir.array_coor %arg0(%0) %arg3 : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref + %9 = fir.load %8 : !fir.ref + %10 = arith.addf %arg4, %9 fastmath : f32 + %11 = fir.array_coor %arg2(%0) %arg3 : (!fir.ref>, !fir.shape<1>, index) -> !fir.ref + %12 = fir.load %11 : !fir.ref + %13 = arith.addf %arg5, %12 fastmath : f32 + fir.result %10, %13 : f32, f32 + } + %6 = arith.addf %5#0, %5#1 fastmath : f32 + %7 = fir.convert %4#0 : (index) -> i32 + fir.store %7 to %2 : !fir.ref + return %6 : f32 +} + +// CHECK-LABEL: func.func @loop_with_result( +// CHECK-SAME: %[[ARG0:.*]]: !fir.ref>, +// CHECK-SAME: %[[ARG1:.*]]: !fir.ref>, +// CHECK-SAME: %[[ARG2:.*]]: !fir.ref>) -> f32 { +// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index +// CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[VAL_2:.*]] = arith.constant 100 : index +// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1> +// CHECK: %[[VAL_4:.*]] = fir.shape %[[VAL_2]], %[[VAL_2]] : (index, index) -> !fir.shape<2> +// CHECK: %[[VAL_5:.*]] = fir.alloca i32 +// CHECK: %[[VAL_6:.*]] = fir.convert %[[ARG0]] : (!fir.ref>) -> memref +// CHECK: %[[VAL_7:.*]] = affine.for %[[VAL_8:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_9:.*]] = %[[VAL_1]]) -> (f32) { +// CHECK: %[[VAL_10:.*]] = affine.apply #{{.*}}(%[[VAL_8]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]] +// CHECK: %[[VAL_11:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_10]]] : memref +// CHECK: %[[VAL_12:.*]] = arith.addf %[[VAL_9]], %[[VAL_11]] fastmath : f32 +// CHECK: affine.yield %[[VAL_12]] : f32 +// CHECK: } +// CHECK: %[[VAL_13:.*]]:2 = fir.do_loop %[[VAL_14:.*]] = %[[VAL_0]] to %[[VAL_2]] step %[[VAL_0]] iter_args(%[[VAL_15:.*]] = %[[VAL_7]]) -> (index, f32) { +// CHECK: %[[VAL_16:.*]] = fir.array_coor %[[ARG1]](%[[VAL_4]]) %[[VAL_0]], %[[VAL_14]] : (!fir.ref>, !fir.shape<2>, index, index) -> !fir.ref +// CHECK: %[[VAL_17:.*]] = fir.convert %[[VAL_16]] : (!fir.ref) -> !fir.ref> +// CHECK: %[[VAL_18:.*]] = fir.convert %[[VAL_17]] : (!fir.ref>) -> memref +// CHECK: %[[VAL_19:.*]] = affine.for %[[VAL_20:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_21:.*]] = %[[VAL_15]]) -> (f32) { +// CHECK: %[[VAL_22:.*]] = affine.apply #{{.*}}(%[[VAL_20]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]] +// CHECK: %[[VAL_23:.*]] = affine.load %[[VAL_18]]{{\[}}%[[VAL_22]]] : memref +// CHECK: %[[VAL_24:.*]] = arith.addf %[[VAL_21]], %[[VAL_23]] fastmath : f32 +// CHECK: affine.yield %[[VAL_24]] : f32 +// CHECK: } +// CHECK: %[[VAL_25:.*]] = arith.addi %[[VAL_14]], %[[VAL_0]] overflow : index +// CHECK: fir.result %[[VAL_25]], %[[VAL_19]] : index, f32 +// CHECK: } +// CHECK: %[[VAL_26:.*]] = fir.convert %[[ARG2]] : (!fir.ref>) -> memref +// CHECK: %[[VAL_27:.*]]:2 = affine.for %[[VAL_28:.*]] = %[[VAL_0]] to #{{.*}}(){{\[}}%[[VAL_2]]] iter_args(%[[VAL_29:.*]] = %[[VAL_30:.*]]#1, %[[VAL_31:.*]] = %[[VAL_1]]) -> (f32, f32) { +// CHECK: %[[VAL_32:.*]] = affine.apply #{{.*}}(%[[VAL_28]]){{\[}}%[[VAL_0]], %[[VAL_2]], %[[VAL_0]]] +// CHECK: %[[VAL_33:.*]] = affine.load %[[VAL_6]]{{\[}}%[[VAL_32]]] : memref +// CHECK: %[[VAL_34:.*]] = arith.addf %[[VAL_29]], %[[VAL_33]] fastmath : f32 +// CHECK: %[[VAL_35:.*]] = affine.load %[[VAL_26]]{{\[}}%[[VAL_32]]] : memref +// CHECK: %[[VAL_36:.*]] = arith.addf %[[VAL_31]], %[[VAL_35]] fastmath : f32 +// CHECK: affine.yield %[[VAL_34]], %[[VAL_36]] : f32, f32 +// CHECK: } +// CHECK: %[[VAL_37:.*]] = arith.addf %[[VAL_38:.*]]#0, %[[VAL_38]]#1 fastmath : f32 +// CHECK: %[[VAL_39:.*]] = fir.convert %[[VAL_40:.*]]#0 : (index) -> i32 +// CHECK: fir.store %[[VAL_39]] to %[[VAL_5]] : !fir.ref +// CHECK: return %[[VAL_37]] : f32 +// CHECK: }