diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp index 79b5087e4da68..8ba2f604df80a 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp @@ -1060,6 +1060,40 @@ bool ClauseProcessor::processIsDevicePtr( }); } +bool ClauseProcessor::processLinear(mlir::omp::LinearClauseOps &result) const { + lower::StatementContext stmtCtx; + return findRepeatableClause< + omp::clause::Linear>([&](const omp::clause::Linear &clause, + const parser::CharBlock &) { + auto &objects = std::get(clause.t); + for (const omp::Object &object : objects) { + semantics::Symbol *sym = object.sym(); + const mlir::Value variable = converter.getSymbolAddress(*sym); + result.linearVars.push_back(variable); + } + if (objects.size()) { + if (auto &mod = + std::get>( + clause.t)) { + mlir::Value operand = + fir::getBase(converter.genExprValue(toEvExpr(*mod), stmtCtx)); + result.linearStepVars.append(objects.size(), operand); + } else if (std::get>( + clause.t)) { + mlir::Location currentLocation = converter.getCurrentLocation(); + TODO(currentLocation, "Linear modifiers not yet implemented"); + } else { + // If nothing is present, add the default step of 1. + fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder(); + mlir::Location currentLocation = converter.getCurrentLocation(); + mlir::Value operand = firOpBuilder.createIntegerConstant( + currentLocation, firOpBuilder.getI32Type(), 1); + result.linearStepVars.append(objects.size(), operand); + } + } + }); +} + bool ClauseProcessor::processLink( llvm::SmallVectorImpl &result) const { return findRepeatableClause( diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.h b/flang/lib/Lower/OpenMP/ClauseProcessor.h index 7857ba3fd0845..0ec41bdd33256 100644 --- a/flang/lib/Lower/OpenMP/ClauseProcessor.h +++ b/flang/lib/Lower/OpenMP/ClauseProcessor.h @@ -122,6 +122,7 @@ class ClauseProcessor { bool processIsDevicePtr( mlir::omp::IsDevicePtrClauseOps &result, llvm::SmallVectorImpl &isDeviceSyms) const; + bool processLinear(mlir::omp::LinearClauseOps &result) const; bool processLink(llvm::SmallVectorImpl &result) const; diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp index 7eec598645eac..2a1c94407e1c8 100644 --- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp +++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp @@ -213,14 +213,15 @@ void DataSharingProcessor::collectSymbolsForPrivatization() { // so, we won't need to explicitely handle block objects (or forget to do // so). for (auto *sym : explicitlyPrivatizedSymbols) - allPrivatizedSymbols.insert(sym); + if (!sym->test(Fortran::semantics::Symbol::Flag::OmpLinear)) + allPrivatizedSymbols.insert(sym); } bool DataSharingProcessor::needBarrier() { // Emit implicit barrier to synchronize threads and avoid data races on // initialization of firstprivate variables and post-update of lastprivate // variables. - // Emit implicit barrier for linear clause. Maybe on somewhere else. + // Emit implicit barrier for linear clause in the OpenMPIRBuilder. for (const semantics::Symbol *sym : allPrivatizedSymbols) { if (sym->test(semantics::Symbol::Flag::OmpLastPrivate) && (sym->test(semantics::Symbol::Flag::OmpFirstPrivate) || diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index 54560729eb4af..6fa915b4364f9 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -1841,13 +1841,13 @@ static void genWsloopClauses( llvm::SmallVectorImpl &reductionSyms) { ClauseProcessor cp(converter, semaCtx, clauses); cp.processNowait(clauseOps); + cp.processLinear(clauseOps); cp.processOrder(clauseOps); cp.processOrdered(clauseOps); cp.processReduction(loc, clauseOps, reductionSyms); cp.processSchedule(stmtCtx, clauseOps); - cp.processTODO( - loc, llvm::omp::Directive::OMPD_do); + cp.processTODO(loc, llvm::omp::Directive::OMPD_do); } //===----------------------------------------------------------------------===// diff --git a/flang/test/Lower/OpenMP/wsloop-linear.f90 b/flang/test/Lower/OpenMP/wsloop-linear.f90 new file mode 100644 index 0000000000000..b99677108be2f --- /dev/null +++ b/flang/test/Lower/OpenMP/wsloop-linear.f90 @@ -0,0 +1,57 @@ +! This test checks lowering of OpenMP DO Directive (Worksharing) +! with linear clause + +! RUN: %flang_fc1 -fopenmp -emit-hlfir %s -o - 2>&1 | FileCheck %s + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFsimple_linearEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFsimple_linearEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[const:.*]] = arith.constant 1 : i32 +subroutine simple_linear + implicit none + integer :: x, y, i + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp do linear(x) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !$omp end do +end subroutine + + +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_stepEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_stepEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_step + implicit none + integer :: x, y, i + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[const]] : !fir.ref) {{.*}} + !$omp do linear(x:4) + !CHECK: %[[LOAD:.*]] = fir.load %[[X]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 2 : i32 + !CHECK: %[[RESULT:.*]] = arith.addi %[[LOAD]], %[[const]] : i32 + do i = 1, 10 + y = x + 2 + end do + !$omp end do +end subroutine + +!CHECK: %[[A_alloca:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFlinear_exprEa"} +!CHECK: %[[A:.*]]:2 = hlfir.declare %[[A_alloca]] {uniq_name = "_QFlinear_exprEa"} : (!fir.ref) -> (!fir.ref, !fir.ref) +!CHECK: %[[X_alloca:.*]] = fir.alloca i32 {bindc_name = "x", uniq_name = "_QFlinear_exprEx"} +!CHECK: %[[X:.*]]:2 = hlfir.declare %[[X_alloca]] {uniq_name = "_QFlinear_exprEx"} : (!fir.ref) -> (!fir.ref, !fir.ref) +subroutine linear_expr + implicit none + integer :: x, y, i, a + !CHECK: %[[LOAD_A:.*]] = fir.load %[[A]]#0 : !fir.ref + !CHECK: %[[const:.*]] = arith.constant 4 : i32 + !CHECK: %[[LINEAR_EXPR:.*]] = arith.addi %[[LOAD_A]], %[[const]] : i32 + !CHECK: omp.wsloop linear(%[[X]]#0 = %[[LINEAR_EXPR]] : !fir.ref) {{.*}} + !$omp do linear(x:a+4) + do i = 1, 10 + y = x + 2 + end do + !$omp end do +end subroutine diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h index ffc0fd0a0bdac..68f15d5c7d41e 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h +++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h @@ -3580,6 +3580,9 @@ class CanonicalLoopInfo { BasicBlock *Latch = nullptr; BasicBlock *Exit = nullptr; + // Hold the MLIR value for the `lastiter` of the canonical loop. + Value *LastIter = nullptr; + /// Add the control blocks of this loop to \p BBs. /// /// This does not include any block from the body, including the one returned @@ -3612,6 +3615,18 @@ class CanonicalLoopInfo { void mapIndVar(llvm::function_ref Updater); public: + /// Sets the last iteration variable for this loop. + void setLastIter(Value *IterVar) { LastIter = std::move(IterVar); } + + /// Returns the last iteration variable for this loop. + /// Certain use-cases (like translation of linear clause) may access + /// this variable even after a loop transformation. Hence, do not guard + /// this getter function by `isValid`. It is the responsibility of the + /// callee to ensure this functionality is not invoked by a non-outlined + /// CanonicalLoopInfo object (in which case, `setLastIter` will never be + /// invoked and `LastIter` will be by default `nullptr`). + Value *getLastIter() { return LastIter; } + /// Returns whether this object currently represents the IR of a loop. If /// returning false, it may have been consumed by a loop transformation or not /// been intialized. Do not use in this case; diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp index a1268ca76b2d5..991cdb7b6b416 100644 --- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp +++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp @@ -4254,6 +4254,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::applyStaticWorkshareLoop( Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound"); Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound"); Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride"); + CLI->setLastIter(PLastIter); // At the end of the preheader, prepare for calling the "init" function by // storing the current loop bounds into the allocated space. A canonical loop @@ -4361,6 +4362,7 @@ OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(DebugLoc DL, Value *PUpperBound = Builder.CreateAlloca(InternalIVTy, nullptr, "p.upperbound"); Value *PStride = Builder.CreateAlloca(InternalIVTy, nullptr, "p.stride"); + CLI->setLastIter(PLastIter); // Set up the source location value for the OpenMP runtime. Builder.restoreIP(CLI->getPreheaderIP()); @@ -4844,6 +4846,7 @@ OpenMPIRBuilder::applyDynamicWorkshareLoop(DebugLoc DL, CanonicalLoopInfo *CLI, Value *PLowerBound = Builder.CreateAlloca(IVTy, nullptr, "p.lowerbound"); Value *PUpperBound = Builder.CreateAlloca(IVTy, nullptr, "p.upperbound"); Value *PStride = Builder.CreateAlloca(IVTy, nullptr, "p.stride"); + CLI->setLastIter(PLastIter); // At the end of the preheader, prepare for calling the "init" function by // storing the current loop bounds into the allocated space. A canonical loop diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9f7b5605556e6..d723eee10636f 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -124,6 +124,146 @@ class PreviouslyReportedError char PreviouslyReportedError::ID = 0; +/* + * Custom class for processing linear clause for omp.wsloop + * and omp.simd. Linear clause translation requires setup, + * initialization, update, and finalization at varying + * basic blocks in the IR. This class helps maintain + * internal state to allow consistent translation in + * each of these stages. + */ + +class LinearClauseProcessor { + +private: + SmallVector linearPreconditionVars; + SmallVector linearLoopBodyTemps; + SmallVector linearOrigVars; + SmallVector linearOrigVal; + SmallVector linearSteps; + llvm::BasicBlock *linearFinalizationBB; + llvm::BasicBlock *linearExitBB; + llvm::BasicBlock *linearLastIterExitBB; + +public: + // Allocate space for linear variabes + void createLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + mlir::Value &linearVar) { + if (llvm::AllocaInst *linearVarAlloca = dyn_cast( + moduleTranslation.lookupValue(linearVar))) { + linearPreconditionVars.push_back(builder.CreateAlloca( + linearVarAlloca->getAllocatedType(), nullptr, ".linear_var")); + llvm::Value *linearLoopBodyTemp = builder.CreateAlloca( + linearVarAlloca->getAllocatedType(), nullptr, ".linear_result"); + linearOrigVal.push_back(moduleTranslation.lookupValue(linearVar)); + linearLoopBodyTemps.push_back(linearLoopBodyTemp); + linearOrigVars.push_back(linearVarAlloca); + } + } + + // Initialize linear step + inline void initLinearStep(LLVM::ModuleTranslation &moduleTranslation, + mlir::Value &linearStep) { + linearSteps.push_back(moduleTranslation.lookupValue(linearStep)); + } + + // Emit IR for initialization of linear variables + llvm::OpenMPIRBuilder::InsertPointOrErrorTy + initLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::BasicBlock *loopPreHeader) { + builder.SetInsertPoint(loopPreHeader->getTerminator()); + for (size_t index = 0; index < linearOrigVars.size(); index++) { + llvm::LoadInst *linearVarLoad = builder.CreateLoad( + linearOrigVars[index]->getAllocatedType(), linearOrigVars[index]); + builder.CreateStore(linearVarLoad, linearPreconditionVars[index]); + } + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = + moduleTranslation.getOpenMPBuilder()->createBarrier( + builder.saveIP(), llvm::omp::OMPD_barrier); + return afterBarrierIP; + } + + // Emit IR for updating Linear variables + void updateLinearVar(llvm::IRBuilderBase &builder, llvm::BasicBlock *loopBody, + llvm::Value *loopInductionVar) { + builder.SetInsertPoint(loopBody->getTerminator()); + for (size_t index = 0; index < linearPreconditionVars.size(); index++) { + // Emit increments for linear vars + llvm::LoadInst *linearVarStart = + builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), + + linearPreconditionVars[index]); + auto mulInst = builder.CreateMul(loopInductionVar, linearSteps[index]); + auto addInst = builder.CreateAdd(linearVarStart, mulInst); + builder.CreateStore(addInst, linearLoopBodyTemps[index]); + } + } + + // Linear variable finalization is conditional on the last logical iteration. + // Create BB splits to manage the same. + void outlineLinearFinalizationBB(llvm::IRBuilderBase &builder, + llvm::BasicBlock *loopExit) { + linearFinalizationBB = loopExit->splitBasicBlock( + loopExit->getTerminator(), "omp_loop.linear_finalization"); + linearExitBB = linearFinalizationBB->splitBasicBlock( + linearFinalizationBB->getTerminator(), "omp_loop.linear_exit"); + linearLastIterExitBB = linearFinalizationBB->splitBasicBlock( + linearFinalizationBB->getTerminator(), "omp_loop.linear_lastiter_exit"); + } + + // Finalize the linear vars + llvm::OpenMPIRBuilder::InsertPointOrErrorTy + finalizeLinearVar(llvm::IRBuilderBase &builder, + LLVM::ModuleTranslation &moduleTranslation, + llvm::Value *lastIter) { + // Emit condition to check whether last logical iteration is being executed + builder.SetInsertPoint(linearFinalizationBB->getTerminator()); + llvm::Value *loopLastIterLoad = builder.CreateLoad( + llvm::Type::getInt32Ty(builder.getContext()), lastIter); + llvm::Value *isLast = + builder.CreateCmp(llvm::CmpInst::ICMP_NE, loopLastIterLoad, + llvm::ConstantInt::get( + llvm::Type::getInt32Ty(builder.getContext()), 0)); + // Store the linear variable values to original variables. + builder.SetInsertPoint(linearLastIterExitBB->getTerminator()); + for (size_t index = 0; index < linearOrigVars.size(); index++) { + llvm::LoadInst *linearVarTemp = + builder.CreateLoad(linearOrigVars[index]->getAllocatedType(), + linearLoopBodyTemps[index]); + builder.CreateStore(linearVarTemp, linearOrigVars[index]); + } + + // Create conditional branch such that the linear variable + // values are stored to original variables only at the + // last logical iteration + builder.SetInsertPoint(linearFinalizationBB->getTerminator()); + builder.CreateCondBr(isLast, linearLastIterExitBB, linearExitBB); + linearFinalizationBB->getTerminator()->eraseFromParent(); + // Emit barrier + builder.SetInsertPoint(linearExitBB->getTerminator()); + return moduleTranslation.getOpenMPBuilder()->createBarrier( + builder.saveIP(), llvm::omp::OMPD_barrier); + } + + // Rewrite all uses of the original variable in `BBName` + // with the linear variable in-place + void rewriteInPlace(llvm::IRBuilderBase &builder, std::string BBName, + size_t varIndex) { + llvm::SmallVector users; + for (llvm::User *user : linearOrigVal[varIndex]->users()) + users.push_back(user); + for (auto *user : users) { + if (auto *userInst = dyn_cast(user)) { + if (userInst->getParent()->getName().str() == BBName) + user->replaceUsesOfWith(linearOrigVal[varIndex], + linearLoopBodyTemps[varIndex]); + } + } + } +}; + } // namespace /// Looks up from the operation from and returns the PrivateClauseOp with @@ -292,7 +432,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::WsloopOp op) { checkAllocate(op, result); - checkLinear(op, result); checkOrder(op, result); checkReduction(op, result); }) @@ -2423,15 +2562,40 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::omp::Directive::OMPD_for); llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + + // Initialize linear variables and linear step + LinearClauseProcessor linearClauseProcessor; + if (wsloopOp.getLinearVars().size()) { + for (mlir::Value linearVar : wsloopOp.getLinearVars()) + linearClauseProcessor.createLinearVar(builder, moduleTranslation, + linearVar); + for (mlir::Value linearStep : wsloopOp.getLinearStepVars()) + linearClauseProcessor.initLinearStep(moduleTranslation, linearStep); + } + llvm::Expected regionBlock = convertOmpOpRegions( wsloopOp.getRegion(), "omp.wsloop.region", builder, moduleTranslation); if (failed(handleError(regionBlock, opInst))) return failure(); - builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); llvm::CanonicalLoopInfo *loopInfo = findCurrentLoopInfo(moduleTranslation); + // Emit Initialization and Update IR for linear variables + if (wsloopOp.getLinearVars().size()) { + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = + linearClauseProcessor.initLinearVar(builder, moduleTranslation, + loopInfo->getPreheader()); + if (failed(handleError(afterBarrierIP, *loopOp))) + return failure(); + builder.restoreIP(*afterBarrierIP); + linearClauseProcessor.updateLinearVar(builder, loopInfo->getBody(), + loopInfo->getIndVar()); + linearClauseProcessor.outlineLinearFinalizationBB(builder, + loopInfo->getExit()); + } + + builder.SetInsertPoint(*regionBlock, (*regionBlock)->begin()); llvm::OpenMPIRBuilder::InsertPointOrErrorTy wsloopIP = ompBuilder->applyWorkshareLoop( ompLoc.DL, loopInfo, allocaIP, loopNeedsBarrier, @@ -2443,6 +2607,22 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, if (failed(handleError(wsloopIP, opInst))) return failure(); + // Emit finalization and in-place rewrites for linear vars. + if (wsloopOp.getLinearVars().size()) { + llvm::OpenMPIRBuilder::InsertPointTy oldIP = builder.saveIP(); + assert(loopInfo->getLastIter() && + "`lastiter` in CanonicalLoopInfo is nullptr"); + llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterBarrierIP = + linearClauseProcessor.finalizeLinearVar(builder, moduleTranslation, + loopInfo->getLastIter()); + if (failed(handleError(afterBarrierIP, *loopOp))) + return failure(); + for (size_t index = 0; index < wsloopOp.getLinearVars().size(); index++) + linearClauseProcessor.rewriteInPlace(builder, "omp.loop_nest.region", + index); + builder.restoreIP(oldIP); + } + // Set the correct branch target for task cancellation popCancelFinalizationCB(cancelTerminators, *ompBuilder, wsloopIP.get()); diff --git a/mlir/test/Target/LLVMIR/openmp-llvm.mlir b/mlir/test/Target/LLVMIR/openmp-llvm.mlir index 32f0ba5b105ff..9ad9e93301239 100644 --- a/mlir/test/Target/LLVMIR/openmp-llvm.mlir +++ b/mlir/test/Target/LLVMIR/openmp-llvm.mlir @@ -358,6 +358,94 @@ llvm.func @wsloop_simple(%arg0: !llvm.ptr) { // ----- +// CHECK-LABEL: wsloop_linear + +// CHECK: {{.*}} = alloca i32, i64 1, align 4 +// CHECK: %[[Y:.*]] = alloca i32, i64 1, align 4 +// CHECK: %[[X:.*]] = alloca i32, i64 1, align 4 + +// CHECK: entry: +// CHECK: %[[LINEAR_VAR:.*]] = alloca i32, align 4 +// CHECK: %[[LINEAR_RESULT:.*]] = alloca i32, align 4 +// CHECK: br label %omp_loop.preheader + +// CHECK: omp_loop.preheader: +// CHECK: %[[LOAD:.*]] = load i32, ptr %[[X]], align 4 +// CHECK: store i32 %[[LOAD]], ptr %[[LINEAR_VAR]], align 4 +// CHECK: %omp_global_thread_num = call i32 @__kmpc_global_thread_num(ptr @2) +// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num) + +// CHECK: omp_loop.body: +// CHECK: %[[LOOP_IV:.*]] = add i32 %omp_loop.iv, {{.*}} +// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_VAR]], align 4 +// CHECK: %[[MUL:.*]] = mul i32 %[[LOOP_IV]], 1 +// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], %[[MUL]] +// CHECK: store i32 %[[ADD]], ptr %[[LINEAR_RESULT]], align 4 +// CHECK: br label %omp.loop_nest.region + +// CHECK: omp.loop_nest.region: +// CHECK: %[[LINEAR_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4 +// CHECK: %[[ADD:.*]] = add i32 %[[LINEAR_LOAD]], 2 +// CHECK: store i32 %[[ADD]], ptr %[[Y]], align 4 + +// CHECK: omp_loop.exit: +// CHECK: call void @__kmpc_for_static_fini(ptr @2, i32 %omp_global_thread_num4) +// CHECK: %omp_global_thread_num5 = call i32 @__kmpc_global_thread_num(ptr @2) +// CHECK: call void @__kmpc_barrier(ptr @3, i32 %omp_global_thread_num5) +// CHECK: br label %omp_loop.linear_finalization + +// CHECK: omp_loop.linear_finalization: +// CHECK: %[[LAST_ITER:.*]] = load i32, ptr %p.lastiter, align 4 +// CHECK: %[[CMP:.*]] = icmp ne i32 %[[LAST_ITER]], 0 +// CHECK: br i1 %[[CMP]], label %omp_loop.linear_lastiter_exit, label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_lastiter_exit: +// CHECK: %[[LINEAR_RESULT_LOAD:.*]] = load i32, ptr %[[LINEAR_RESULT]], align 4 +// CHECK: store i32 %[[LINEAR_RESULT_LOAD]], ptr %[[X]], align 4 +// CHECK: br label %omp_loop.linear_exit + +// CHECK: omp_loop.linear_exit: +// CHECK: %omp_global_thread_num6 = call i32 @__kmpc_global_thread_num(ptr @2) +// CHECK: call void @__kmpc_barrier(ptr @1, i32 %omp_global_thread_num6) +// CHECK: br label %omp_loop.after + +llvm.func @wsloop_linear() { + %0 = llvm.mlir.constant(1 : i64) : i64 + %1 = llvm.alloca %0 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr + %2 = llvm.mlir.constant(1 : i64) : i64 + %3 = llvm.alloca %2 x i32 {bindc_name = "y"} : (i64) -> !llvm.ptr + %4 = llvm.mlir.constant(1 : i64) : i64 + %5 = llvm.alloca %4 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr + %6 = llvm.mlir.constant(1 : i64) : i64 + %7 = llvm.alloca %6 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr + %8 = llvm.mlir.constant(2 : i32) : i32 + %9 = llvm.mlir.constant(10 : i32) : i32 + %10 = llvm.mlir.constant(1 : i32) : i32 + %11 = llvm.mlir.constant(1 : i64) : i64 + %12 = llvm.mlir.constant(1 : i64) : i64 + %13 = llvm.mlir.constant(1 : i64) : i64 + %14 = llvm.mlir.constant(1 : i64) : i64 + omp.wsloop linear(%5 = %10 : !llvm.ptr) { + omp.loop_nest (%arg0) : i32 = (%10) to (%9) inclusive step (%10) { + llvm.store %arg0, %1 : i32, !llvm.ptr + %15 = llvm.load %5 : !llvm.ptr -> i32 + %16 = llvm.add %15, %8 : i32 + llvm.store %16, %3 : i32, !llvm.ptr + %17 = llvm.add %arg0, %10 : i32 + %18 = llvm.icmp "sgt" %17, %9 : i32 + llvm.cond_br %18, ^bb1, ^bb2 + ^bb1: // pred: ^bb0 + llvm.store %17, %1 : i32, !llvm.ptr + llvm.br ^bb2 + ^bb2: // 2 preds: ^bb0, ^bb1 + omp.yield + } + } + llvm.return +} + +// ----- + // CHECK-LABEL: @wsloop_inclusive_1 llvm.func @wsloop_inclusive_1(%arg0: !llvm.ptr) { %0 = llvm.mlir.constant(42 : index) : i64 diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir index 9a83b46efddca..98fccb1a80f67 100644 --- a/mlir/test/Target/LLVMIR/openmp-todo.mlir +++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir @@ -511,19 +511,6 @@ llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { // ----- -llvm.func @wsloop_linear(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) { - // expected-error@below {{not yet implemented: Unhandled clause linear in omp.wsloop operation}} - // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}} - omp.wsloop linear(%x = %step : !llvm.ptr) { - omp.loop_nest (%iv) : i32 = (%lb) to (%ub) step (%step) { - omp.yield - } - } - llvm.return -} - -// ----- - llvm.func @wsloop_order(%lb : i32, %ub : i32, %step : i32) { // expected-error@below {{not yet implemented: Unhandled clause order in omp.wsloop operation}} // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}