Thanks to visit codestin.com
Credit goes to github.com

Skip to content

[flang][fir] Add locality specifiers modeling to fir.do_concurrent.loop #138506

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented May 5, 2025

Extends fir.do_concurrent.loop ops to model locality specifiers. This follows the same pattern used in OpenMP where an op of type fir.local (in OpenMP it is omp.private) is referenced from the do concurrent locality specifier. This PR adds the MLIR op changes as well as printing and parsing logic.

PR stack:

@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels May 5, 2025
@llvmbot
Copy link
Member

llvmbot commented May 5, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

Extends fir.do_concurrent.loop ops to model locality specifiers. This follows the same pattern used in OpenMP where an op of type fir.local (in OpenMP it is omp.private) is referenced from the do concurrent locality specifier. This PR adds the MLIR op changes as well as printing and parsing logic.


Full diff: https://github.com/llvm/llvm-project/pull/138506.diff

5 Files Affected:

  • (modified) flang/include/flang/Optimizer/Dialect/FIROps.td (+32-1)
  • (modified) flang/lib/Lower/Bridge.cpp (+1-1)
  • (modified) flang/lib/Optimizer/Dialect/FIROps.cpp (+94-18)
  • (modified) flang/test/Fir/do_concurrent.fir (+63-1)
  • (modified) flang/test/Fir/invalid.fir (+5-5)
diff --git a/flang/include/flang/Optimizer/Dialect/FIROps.td b/flang/include/flang/Optimizer/Dialect/FIROps.td
index aea57d2e8dd71..e1d9f877855c4 100644
--- a/flang/include/flang/Optimizer/Dialect/FIROps.td
+++ b/flang/include/flang/Optimizer/Dialect/FIROps.td
@@ -3647,6 +3647,13 @@ def fir_DoConcurrentOp : fir_Op<"do_concurrent",
   let hasVerifier = 1;
 }
 
+def fir_LocalSpecifier {
+  dag arguments = (ins
+    Variadic<AnyType>:$local_vars,
+    OptionalAttr<SymbolRefArrayAttr>:$local_syms
+  );
+}
+
 def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
     [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface,
                                                          ["getLoopInductionVars"]>,
@@ -3700,7 +3707,7 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
       LLVM.
   }];
 
-  let arguments = (ins
+  defvar opArgs = (ins
     Variadic<Index>:$lowerBound,
     Variadic<Index>:$upperBound,
     Variadic<Index>:$step,
@@ -3709,16 +3716,40 @@ def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
     OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
   );
 
+  let arguments = !con(opArgs, fir_LocalSpecifier.arguments);
+
   let regions = (region SizedRegion<1>:$region);
 
   let hasCustomAssemblyFormat = 1;
   let hasVerifier = 1;
 
   let extraClassDeclaration = [{
+    unsigned getNumInductionVars() { return getLowerBound().size(); }
+
+    unsigned getNumLocalOperands() { return getLocalVars().size(); }
+
+    mlir::Block::BlockArgListType getInductionVars() {
+      return getBody()->getArguments().slice(0, getNumInductionVars());
+    }
+
+    mlir::Block::BlockArgListType getRegionLocalArgs() {
+      return getBody()->getArguments().slice(getNumInductionVars(),
+                                             getNumLocalOperands());
+    }
+
+    /// Number of operands controlling the loop
+    unsigned getNumControlOperands() { return getLowerBound().size() * 3; }
+
     // Get Number of reduction operands
     unsigned getNumReduceOperands() {
       return getReduceOperands().size();
     }
+
+    mlir::Operation::operand_range getLocalOperands() {
+      return getOperands()
+          .slice(getNumControlOperands() + getNumReduceOperands(),
+                 getNumLocalOperands());
+    }
   }];
 }
 
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8da05255d5f41..0a61f61ab8f75 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -2460,7 +2460,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
           nestReduceAttrs.empty()
               ? nullptr
               : mlir::ArrayAttr::get(builder->getContext(), nestReduceAttrs),
-          nullptr);
+          nullptr, /*local_vars=*/std::nullopt, /*local_syms=*/nullptr);
 
       llvm::SmallVector<mlir::Type> loopBlockArgTypes(
           incrementLoopNestInfo.size(), builder->getIndexType());
diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 65ec730e134c2..c95655d7dcef6 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -5033,21 +5033,25 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                                                  mlir::OperationState &result) {
   auto &builder = parser.getBuilder();
   // Parse an opening `(` followed by induction variables followed by `)`
-  llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
-  if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
+  llvm::SmallVector<mlir::OpAsmParser::Argument, 4> regionArgs;
+
+  if (parser.parseArgumentList(regionArgs, mlir::OpAsmParser::Delimiter::Paren))
     return mlir::failure();
 
+  llvm::SmallVector<mlir::Type> argTypes(regionArgs.size(),
+                                         builder.getIndexType());
+
   // Parse loop bounds.
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
   if (parser.parseEqual() ||
-      parser.parseOperandList(lower, ivs.size(),
+      parser.parseOperandList(lower, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
     return mlir::failure();
 
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
   if (parser.parseKeyword("to") ||
-      parser.parseOperandList(upper, ivs.size(),
+      parser.parseOperandList(upper, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
     return mlir::failure();
@@ -5055,7 +5059,7 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
   // Parse step values.
   llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
   if (parser.parseKeyword("step") ||
-      parser.parseOperandList(steps, ivs.size(),
+      parser.parseOperandList(steps, regionArgs.size(),
                               mlir::OpAsmParser::Delimiter::Paren) ||
       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
     return mlir::failure();
@@ -5086,12 +5090,55 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                         builder.getArrayAttr(arrayAttr));
   }
 
-  // Now parse the body.
-  mlir::Region *body = result.addRegion();
-  for (auto &iv : ivs)
-    iv.type = builder.getIndexType();
-  if (parser.parseRegion(*body, ivs))
-    return mlir::failure();
+  llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> localOperands;
+  if (succeeded(parser.parseOptionalKeyword("local"))) {
+    std::size_t oldArgTypesSize = argTypes.size();
+    if (failed(parser.parseLParen()))
+      return mlir::failure();
+
+    llvm::SmallVector<mlir::SymbolRefAttr> localSymbolVec;
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseAttribute(localSymbolVec.emplace_back())))
+            return mlir::failure();
+
+          if (parser.parseOperand(localOperands.emplace_back()) ||
+              parser.parseArrow() ||
+              parser.parseArgument(regionArgs.emplace_back()))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (failed(parser.parseColon()))
+      return mlir::failure();
+
+    if (failed(parser.parseCommaSeparatedList([&]() {
+          if (failed(parser.parseType(argTypes.emplace_back())))
+            return mlir::failure();
+
+          return mlir::success();
+        })))
+      return mlir::failure();
+
+    if (regionArgs.size() != argTypes.size())
+      return parser.emitError(parser.getNameLoc(),
+                              "mismatch in number of local arg and types");
+
+    if (failed(parser.parseRParen()))
+      return mlir::failure();
+
+    for (auto operandType : llvm::zip_equal(
+             localOperands, llvm::drop_begin(argTypes, oldArgTypesSize)))
+      if (parser.resolveOperand(std::get<0>(operandType),
+                                std::get<1>(operandType), result.operands))
+        return mlir::failure();
+
+    llvm::SmallVector<mlir::Attribute> symbolAttrs(localSymbolVec.begin(),
+                                                   localSymbolVec.end());
+    result.addAttribute(getLocalSymsAttrName(result.name),
+                        builder.getArrayAttr(symbolAttrs));
+  }
 
   // Set `operandSegmentSizes` attribute.
   result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
@@ -5099,7 +5146,16 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                           {static_cast<int32_t>(lower.size()),
                            static_cast<int32_t>(upper.size()),
                            static_cast<int32_t>(steps.size()),
-                           static_cast<int32_t>(reduceOperands.size())}));
+                           static_cast<int32_t>(reduceOperands.size()),
+                           static_cast<int32_t>(localOperands.size())}));
+
+  // Now parse the body.
+  for (auto [arg, type] : llvm::zip_equal(regionArgs, argTypes))
+    arg.type = type;
+
+  mlir::Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return mlir::failure();
 
   // Parse attributes.
   if (parser.parseOptionalAttrDict(result.attributes))
@@ -5109,8 +5165,9 @@ mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
 }
 
 void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
-  p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
-    << ") to (" << getUpperBound() << ") step (" << getStep() << ")";
+  p << " (" << getBody()->getArguments().slice(0, getNumInductionVars())
+    << ") = (" << getLowerBound() << ") to (" << getUpperBound() << ") step ("
+    << getStep() << ")";
 
   if (!getReduceOperands().empty()) {
     p << " reduce(";
@@ -5123,12 +5180,27 @@ void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
     p << ')';
   }
 
+  if (!getLocalVars().empty()) {
+    p << " local(";
+    llvm::interleaveComma(llvm::zip_equal(getLocalSymsAttr(), getLocalVars(),
+                                          getRegionLocalArgs()),
+                          p, [&](auto it) {
+                            p << std::get<0>(it) << " " << std::get<1>(it)
+                              << " -> " << std::get<2>(it);
+                          });
+    p << " : ";
+    llvm::interleaveComma(getLocalVars(), p,
+                          [&](auto it) { p << it.getType(); });
+    p << ")";
+  }
+
   p << ' ';
   p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
   p.printOptionalAttrDict(
       (*this)->getAttrs(),
       /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
-                       DoConcurrentLoopOp::getReduceAttrsAttrName()});
+                       DoConcurrentLoopOp::getReduceAttrsAttrName(),
+                       DoConcurrentLoopOp::getLocalSymsAttrName()});
 }
 
 llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
@@ -5139,6 +5211,7 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
   mlir::Operation::operand_range lbValues = getLowerBound();
   mlir::Operation::operand_range ubValues = getUpperBound();
   mlir::Operation::operand_range stepValues = getStep();
+  mlir::Operation::operand_range localVars = getLocalVars();
 
   if (lbValues.empty())
     return emitOpError(
@@ -5152,11 +5225,13 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
   // Check that the body defines the same number of block arguments as the
   // number of tuple elements in step.
   mlir::Block *body = getBody();
-  if (body->getNumArguments() != stepValues.size())
+  unsigned numIndVarArgs = body->getNumArguments() - localVars.size();
+
+  if (numIndVarArgs != stepValues.size())
     return emitOpError() << "expects the same number of induction variables: "
                          << body->getNumArguments()
                          << " as bound and step values: " << stepValues.size();
-  for (auto arg : body->getArguments())
+  for (auto arg : body->getArguments().slice(0, numIndVarArgs))
     if (!arg.getType().isIndex())
       return emitOpError(
           "expects arguments for the induction variable to be of index type");
@@ -5171,7 +5246,8 @@ llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
 
 std::optional<llvm::SmallVector<mlir::Value>>
 fir::DoConcurrentLoopOp::getLoopInductionVars() {
-  return llvm::SmallVector<mlir::Value>{getBody()->getArguments()};
+  return llvm::SmallVector<mlir::Value>{
+      getBody()->getArguments().slice(0, getLowerBound().size())};
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/flang/test/Fir/do_concurrent.fir b/flang/test/Fir/do_concurrent.fir
index 4e55777402428..cfb9a7abac15b 100644
--- a/flang/test/Fir/do_concurrent.fir
+++ b/flang/test/Fir/do_concurrent.fir
@@ -91,7 +91,6 @@ func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
 // CHECK:           }
 // CHECK:         }
 
-
 fir.local {type = local} @local_privatizer : i32
 
 // CHECK:   fir.local {type = local} @[[LOCAL_PRIV_SYM:local_privatizer]] : i32
@@ -109,3 +108,66 @@ fir.local {type = local_init} @local_init_privatizer : i32 copy {
 // CHECK:      fir.store %[[ORIG_VAL_LD]] to %[[LOCAL_VAL]] : !fir.ref<i32>
 // CHECK:      fir.yield(%[[LOCAL_VAL]] : !fir.ref<i32>)
 // CHECK:   }
+
+func.func @_QPdo_concurrent() {
+  %3 = fir.alloca i32 {bindc_name = "local_init_var", uniq_name = "_QFdo_concurrentElocal_init_var"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %5 = fir.alloca i32 {bindc_name = "local_var", uniq_name = "_QFdo_concurrentElocal_var"}
+  %6:2 = hlfir.declare %5 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+  %c1 = arith.constant 1 : index
+  %c10 = arith.constant 1 : index
+  fir.do_concurrent {
+    %9 = fir.alloca i32 {bindc_name = "i"}
+    %10:2 = hlfir.declare %9 {uniq_name = "_QFdo_concurrentEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) local(@local_privatizer %6#0 -> %arg1, @local_init_privatizer %4#0 -> %arg2 : !fir.ref<i32>, !fir.ref<i32>) {
+      %11 = fir.convert %arg0 : (index) -> i32
+      fir.store %11 to %10#0 : !fir.ref<i32>
+      %13:2 = hlfir.declare %arg1 {uniq_name = "_QFdo_concurrentElocal_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %15:2 = hlfir.declare %arg2 {uniq_name = "_QFdo_concurrentElocal_init_var"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+      %17 = fir.load %10#0 : !fir.ref<i32>
+      %c5_i32 = arith.constant 5 : i32
+      %18 = arith.cmpi slt, %17, %c5_i32 : i32
+      fir.if %18 {
+        %c42_i32 = arith.constant 42 : i32
+        hlfir.assign %c42_i32 to %13#0 : i32, !fir.ref<i32>
+      } else {
+        %c84_i32 = arith.constant 84 : i32
+        hlfir.assign %c84_i32 to %15#0 : i32, !fir.ref<i32>
+      }
+    }
+  }
+  return
+}
+
+// CHECK-LABEL:   func.func @_QPdo_concurrent() {
+// CHECK:           %[[LOC_INIT_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_init_var", {{.*}}}
+// CHECK:           %[[LOC_INIT_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ALLOC]]
+
+// CHECK:           %[[LOC_ALLOC:.*]] = fir.alloca i32 {bindc_name = "local_var", {{.*}}}
+// CHECK:           %[[LOC_DECL:.*]]:2 = hlfir.declare %[[LOC_ALLOC]]
+
+// CHECK:           %[[C1:.*]] = arith.constant 1 : index
+// CHECK:           %[[C10:.*]] = arith.constant 1 : index
+
+// CHECK:           fir.do_concurrent {
+// CHECK:             %[[DC_I_ALLOC:.*]] = fir.alloca i32 {bindc_name = "i"}
+// CHECK:             %[[DC_I_DECL:.*]]:2 = hlfir.declare %[[DC_I_ALLOC]]
+
+// CHECK:             fir.do_concurrent.loop (%[[IV:.*]]) = (%[[C1]]) to (%[[C10]]) step (%[[C1]]) local(@[[LOCAL_PRIV_SYM]] %[[LOC_DECL]]#0 -> %[[LOC_ARG:.*]], @[[LOCAL_INIT_PRIV_SYM]] %[[LOC_INIT_DECL]]#0 -> %[[LOC_INIT_ARG:.*]] : !fir.ref<i32>, !fir.ref<i32>) {
+// CHECK:               %[[IV_CVT:.*]] = fir.convert %[[IV]] : (index) -> i32
+// CHECK:               fir.store %[[IV_CVT]] to %[[DC_I_DECL]]#0 : !fir.ref<i32>
+
+// CHECK:               %[[LOC_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_ARG]]
+// CHECK:               %[[LOC_INIT_PRIV_DECL:.*]]:2 = hlfir.declare %[[LOC_INIT_ARG]]
+
+// CHECK:               fir.if %{{.*}} {
+// CHECK:                 %[[C42:.*]] = arith.constant 42 : i32
+// CHECK:                 hlfir.assign %[[C42]] to %[[LOC_PRIV_DECL]]#0 : i32, !fir.ref<i32>
+// CHECK:               } else {
+// CHECK:                 %[[C84:.*]] = arith.constant 84 : i32
+// CHECK:                 hlfir.assign %[[C84]] to %[[LOC_INIT_PRIV_DECL]]#0 : i32, !fir.ref<i32>
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:           return
+// CHECK:         }
diff --git a/flang/test/Fir/invalid.fir b/flang/test/Fir/invalid.fir
index f9f5e267dd9bc..3cd3ab439b0e9 100644
--- a/flang/test/Fir/invalid.fir
+++ b/flang/test/Fir/invalid.fir
@@ -1198,7 +1198,7 @@ func.func @dc_0d() {
 
 func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
   // expected-error@+1 {{'fir.do_concurrent.loop' op expects parent op 'fir.do_concurrent'}}
-  "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
+  "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
   ^bb0(%arg2: index):
      %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
   }) : (index, index) -> ()
@@ -1210,7 +1210,7 @@ func.func @dc_invalid_parent(%arg0: index, %arg1: index) {
 func.func @dc_invalid_control(%arg0: index, %arg1: index) {
   // expected-error@+2 {{'fir.do_concurrent.loop' op different number of tuple elements for lowerBound, upperBound or step}}
   fir.do_concurrent {
-    "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0>}> ({
+    "fir.do_concurrent.loop"(%arg0, %arg1) <{operandSegmentSizes = array<i32: 1, 1, 0, 0, 0>}> ({
     ^bb0(%arg2: index):
       %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
     }) : (index, index) -> ()
@@ -1223,7 +1223,7 @@ func.func @dc_invalid_control(%arg0: index, %arg1: index) {
 func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
   // expected-error@+2 {{'fir.do_concurrent.loop' op expects the same number of induction variables: 2 as bound and step values: 1}}
   fir.do_concurrent {
-    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
+    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
     ^bb0(%arg3: index, %arg4: index):
       %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
     }) : (index, index, index) -> ()
@@ -1236,7 +1236,7 @@ func.func @dc_invalid_ind_var(%arg0: index, %arg1: index) {
 func.func @dc_invalid_ind_var_type(%arg0: index, %arg1: index) {
   // expected-error@+2 {{'fir.do_concurrent.loop' op expects arguments for the induction variable to be of index type}}
   fir.do_concurrent {
-    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0>}> ({
+    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0) <{operandSegmentSizes = array<i32: 1, 1, 1, 0, 0>}> ({
     ^bb0(%arg3: i32):
       %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
     }) : (index, index, index) -> ()
@@ -1250,7 +1250,7 @@ func.func @dc_invalid_reduction(%arg0: index, %arg1: index) {
   %sum = fir.alloca i32
   // expected-error@+2 {{'fir.do_concurrent.loop' op mismatch in number of reduction variables and reduction attributes}}
   fir.do_concurrent {
-    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1>}> ({
+    "fir.do_concurrent.loop"(%arg0, %arg1, %arg0, %sum) <{operandSegmentSizes = array<i32: 1, 1, 1, 1, 0>}> ({
     ^bb0(%arg3: index):
       %tmp = "fir.alloca"() <{in_type = i32, operandSegmentSizes = array<i32: 0, 0>}> : () -> !fir.ref<i32>
     }) : (index, index, index, !fir.ref<i32>) -> ()

Copy link
Contributor

@tblah tblah left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks

@ergawy ergawy force-pushed the users/ergawy/fir-dc-local-spec-1 branch from e673a6f to 3562e43 Compare May 6, 2025 19:10
ergawy added a commit that referenced this pull request May 7, 2025
Adds a new `fir.local` op to model `local` and `local_init` locality
specifiers. This op is a clone of `omp.private`. In particular, this new
op also models the privatization/localization logic of an SSA value in
the `fir` dialect just like `omp.private` does for OpenMP.

PR stack:
- #137928
- #138505 (this PR)
- #138506
- #138512
- #138534
- #138816
Base automatically changed from users/ergawy/fir-dc-local-spec-0 to main May 7, 2025 12:00
…oop`

Extends `fir.do_concurrent.loop` ops to model locality specifiers. This
follows the same pattern used in OpenMP where an op of type `fir.local`
(in OpenMP it is `omp.private`) is referenced from the `do concurrent`
locality specifier. This PR adds the MLIR op changes as well as printing
and parsing logic.
@ergawy ergawy force-pushed the users/ergawy/fir-dc-local-spec-1 branch from 3562e43 to bf2bd22 Compare May 7, 2025 12:01
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request May 7, 2025
…138505)

Adds a new `fir.local` op to model `local` and `local_init` locality
specifiers. This op is a clone of `omp.private`. In particular, this new
op also models the privatization/localization logic of an SSA value in
the `fir` dialect just like `omp.private` does for OpenMP.

PR stack:
- llvm/llvm-project#137928
- llvm/llvm-project#138505 (this PR)
- llvm/llvm-project#138506
- llvm/llvm-project#138512
- llvm/llvm-project#138534
- llvm/llvm-project#138816
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants