-
Notifications
You must be signed in to change notification settings - Fork 13.4k
[flang][openmp]Add UserReductionDetails and use in DECLARE REDUCTION #131628
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
[flang][openmp]Add UserReductionDetails and use in DECLARE REDUCTION #131628
Conversation
@llvm/pr-subscribers-flang-semantics @llvm/pr-subscribers-flang-parser Author: Mats Petersson (Leporacanthicus) ChangesThis adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality. This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator. A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration. Tests for parsing and symbol generation added. Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error. Patch is 31.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131628.diff 12 Files Affected:
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 715811885c219..12867a5f8ec6f 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -701,6 +701,25 @@ class GenericDetails {
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &);
+class UserReductionDetails : public WithBindName {
+public:
+ using TypeVector = std::vector<const DeclTypeSpec *>;
+ UserReductionDetails() = default;
+
+ void AddType(const DeclTypeSpec *type) { typeList_.push_back(type); }
+ const TypeVector &GetTypeList() const { return typeList_; }
+
+ bool SupportsType(const DeclTypeSpec *type) const {
+ for (auto t : typeList_)
+ if (t == type)
+ return true;
+ return false;
+ }
+
+private:
+ TypeVector typeList_;
+};
+
class UnknownDetails {};
using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
@@ -708,7 +727,7 @@ using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
- TypeParamDetails, MiscDetails>;
+ TypeParamDetails, MiscDetails, UserReductionDetails>;
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
std::string DetailsToString(const Details &);
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 5fcebdca0bc5f..fc8b6f1021b02 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -8,6 +8,7 @@
#include "check-omp-structure.h"
#include "definable.h"
+#include "resolve-names-utils.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/type.h"
@@ -3361,8 +3362,8 @@ bool OmpStructureChecker::CheckReductionOperator(
valid =
llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName);
if (!valid) {
- auto *misc{name->symbol->detailsIf<MiscDetails>()};
- valid = misc && misc->kind() == MiscDetails::Kind::ConstructName;
+ auto *reductionDetails{name->symbol->detailsIf<UserReductionDetails>()};
+ valid = reductionDetails != nullptr;
}
}
if (!valid) {
@@ -3444,7 +3445,8 @@ void OmpStructureChecker::CheckReductionObjects(
}
static bool IsReductionAllowedForType(
- const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type) {
+ const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
+ const Scope &scope) {
auto isLogical{[](const DeclTypeSpec &type) -> bool {
return type.category() == DeclTypeSpec::Logical;
}};
@@ -3464,9 +3466,11 @@ static bool IsReductionAllowedForType(
case parser::DefinedOperator::IntrinsicOperator::Multiply:
case parser::DefinedOperator::IntrinsicOperator::Add:
case parser::DefinedOperator::IntrinsicOperator::Subtract:
- return type.IsNumeric(TypeCategory::Integer) ||
+ if (type.IsNumeric(TypeCategory::Integer) ||
type.IsNumeric(TypeCategory::Real) ||
- type.IsNumeric(TypeCategory::Complex);
+ type.IsNumeric(TypeCategory::Complex))
+ return true;
+ break;
case parser::DefinedOperator::IntrinsicOperator::AND:
case parser::DefinedOperator::IntrinsicOperator::OR:
@@ -3479,8 +3483,18 @@ static bool IsReductionAllowedForType(
DIE("This should have been caught in CheckIntrinsicOperator");
return false;
}
+ parser::CharBlock name{MakeNameFromOperator(*intrinsicOp)};
+ Symbol *symbol{scope.FindSymbol(name)};
+ if (symbol) {
+ const auto *reductionDetails{symbol->detailsIf<UserReductionDetails>()};
+ assert(reductionDetails && "Expected to find reductiondetails");
+
+ return reductionDetails->SupportsType(&type);
+ }
+ return false;
}
- return true;
+ assert(0 && "Intrinsic Operator not found - parsing gone wrong?");
+ return false; // Reject everything else.
}};
auto checkDesignator{[&](const parser::ProcedureDesignator &procD) {
@@ -3493,18 +3507,42 @@ static bool IsReductionAllowedForType(
// IAND: arguments must be integers: F2023 16.9.100
// IEOR: arguments must be integers: F2023 16.9.106
// IOR: arguments must be integers: F2023 16.9.111
- return type.IsNumeric(TypeCategory::Integer);
+ if (type.IsNumeric(TypeCategory::Integer)) {
+ return true;
+ }
} else if (realName == "max" || realName == "min") {
// MAX: arguments must be integer, real, or character:
// F2023 16.9.135
// MIN: arguments must be integer, real, or character:
// F2023 16.9.141
- return type.IsNumeric(TypeCategory::Integer) ||
- type.IsNumeric(TypeCategory::Real) || isCharacter(type);
+ if (type.IsNumeric(TypeCategory::Integer) ||
+ type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
+ return true;
+ }
}
+
+ // If we get here, it may be a user declared reduction, so check
+ // if the symbol has UserReductionDetails, and if so, the type is
+ // supported.
+ if (const auto *reductionDetails{
+ name->symbol->detailsIf<UserReductionDetails>()}) {
+ return reductionDetails->SupportsType(&type);
+ }
+
+ // We also need to check for mangled names (max, min, iand, ieor and ior)
+ // and then check if the type is there.
+ parser::CharBlock mangledName = MangleSpecialFunctions(name->source);
+ if (const auto &symbol{scope.FindSymbol(mangledName)}) {
+ if (const auto *reductionDetails{
+ symbol->detailsIf<UserReductionDetails>()}) {
+ return reductionDetails->SupportsType(&type);
+ }
+ }
+ // Everything else is "not matching type".
+ return false;
}
- // TODO: user defined reduction operators. Just allow everything for now.
- return true;
+ assert(0 && "name and name->symbol should be set here...");
+ return false;
}};
return common::visit(
@@ -3519,7 +3557,8 @@ void OmpStructureChecker::CheckReductionObjectTypes(
for (auto &[symbol, source] : symbols) {
if (auto *type{symbol->GetType()}) {
- if (!IsReductionAllowedForType(ident, *type)) {
+ const auto &scope{context_.FindScope(symbol->name())};
+ if (!IsReductionAllowedForType(ident, *type, scope)) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
diff --git a/flang/lib/Semantics/resolve-names-utils.h b/flang/lib/Semantics/resolve-names-utils.h
index 64784722ff4f8..de0991d69b61b 100644
--- a/flang/lib/Semantics/resolve-names-utils.h
+++ b/flang/lib/Semantics/resolve-names-utils.h
@@ -146,5 +146,9 @@ struct SymbolAndTypeMappings;
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
Scope &newScope, SymbolAndTypeMappings * = nullptr);
+parser::CharBlock MakeNameFromOperator(
+ const parser::DefinedOperator::IntrinsicOperator &op);
+parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name);
+
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_RESOLVE_NAMES_H_
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index fcd4ba6a51907..825ab36d2e800 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1748,15 +1748,75 @@ void OmpVisitor::ProcessMapperSpecifier(const parser::OmpMapperSpecifier &spec,
PopScope();
}
+parser::CharBlock MakeNameFromOperator(
+ const parser::DefinedOperator::IntrinsicOperator &op) {
+ switch (op) {
+ case parser::DefinedOperator::IntrinsicOperator::Multiply:
+ return parser::CharBlock{"op.*", 4};
+ case parser::DefinedOperator::IntrinsicOperator::Add:
+ return parser::CharBlock{"op.+", 4};
+ case parser::DefinedOperator::IntrinsicOperator::Subtract:
+ return parser::CharBlock{"op.-", 4};
+
+ case parser::DefinedOperator::IntrinsicOperator::AND:
+ return parser::CharBlock{"op.AND", 6};
+ case parser::DefinedOperator::IntrinsicOperator::OR:
+ return parser::CharBlock{"op.OR", 6};
+ case parser::DefinedOperator::IntrinsicOperator::EQV:
+ return parser::CharBlock{"op.EQV", 7};
+ case parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return parser::CharBlock{"op.NEQV", 8};
+
+ default:
+ assert(0 && "Unsupported operator...");
+ return parser::CharBlock{"op.?", 4};
+ }
+}
+
+parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name) {
+ if (name == "max") {
+ return parser::CharBlock{"op.max", 6};
+ }
+ if (name == "min") {
+ return parser::CharBlock{"op.min", 6};
+ }
+ if (name == "iand") {
+ return parser::CharBlock{"op.iand", 7};
+ }
+ if (name == "ior") {
+ return parser::CharBlock{"op.ior", 6};
+ }
+ if (name == "ieor") {
+ return parser::CharBlock{"op.ieor", 7};
+ }
+ // All other names: return as is.
+ return name;
+}
+
void OmpVisitor::ProcessReductionSpecifier(
const parser::OmpReductionSpecifier &spec,
const std::optional<parser::OmpClauseList> &clauses) {
+ const parser::Name *name{nullptr};
+ parser::Name mangledName{};
+ UserReductionDetails reductionDetailsTemp{};
const auto &id{std::get<parser::OmpReductionIdentifier>(spec.t)};
if (auto procDes{std::get_if<parser::ProcedureDesignator>(&id.u)}) {
- if (auto *name{std::get_if<parser::Name>(&procDes->u)}) {
- name->symbol =
- &MakeSymbol(*name, MiscDetails{MiscDetails::Kind::ConstructName});
+ name = std::get_if<parser::Name>(&procDes->u);
+ if (name) {
+ mangledName.source = MangleSpecialFunctions(name->source);
}
+ } else {
+ const auto &defOp{std::get<parser::DefinedOperator>(id.u)};
+ mangledName.source = MakeNameFromOperator(
+ std::get<parser::DefinedOperator::IntrinsicOperator>(defOp.u));
+ name = &mangledName;
+ }
+
+ UserReductionDetails *reductionDetails{&reductionDetailsTemp};
+ Symbol *symbol{name ? name->symbol : nullptr};
+ symbol = FindSymbol(mangledName);
+ if (symbol) {
+ reductionDetails = symbol->detailsIf<UserReductionDetails>();
}
auto &typeList{std::get<parser::OmpTypeNameList>(spec.t)};
@@ -1788,6 +1848,10 @@ void OmpVisitor::ProcessReductionSpecifier(
const DeclTypeSpec *typeSpec{GetDeclTypeSpec()};
assert(typeSpec && "We should have a type here");
+ if (reductionDetails) {
+ reductionDetails->AddType(typeSpec);
+ }
+
for (auto &nm : ompVarNames) {
ObjectEntityDetails details{};
details.set_type(*typeSpec);
@@ -1798,6 +1862,13 @@ void OmpVisitor::ProcessReductionSpecifier(
Walk(clauses);
PopScope();
}
+
+ if (name) {
+ if (!symbol) {
+ symbol = &MakeSymbol(mangledName, Attrs{}, std::move(*reductionDetails));
+ }
+ name->symbol = symbol;
+ }
}
bool OmpVisitor::Pre(const parser::OmpDirectiveSpecification &x) {
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index 32eb6c2c5a188..e627dd293ba7c 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -246,7 +246,7 @@ void GenericDetails::CopyFrom(const GenericDetails &from) {
// This is primarily for debugging.
std::string DetailsToString(const Details &details) {
return common::visit(
- common::visitors{
+ common::visitors{//
[](const UnknownDetails &) { return "Unknown"; },
[](const MainProgramDetails &) { return "MainProgram"; },
[](const ModuleDetails &) { return "Module"; },
@@ -266,7 +266,7 @@ std::string DetailsToString(const Details &details) {
[](const TypeParamDetails &) { return "TypeParam"; },
[](const MiscDetails &) { return "Misc"; },
[](const AssocEntityDetails &) { return "AssocEntity"; },
- },
+ [](const UserReductionDetails &) { return "UserReductionDetails"; }},
details);
}
@@ -300,6 +300,9 @@ bool Symbol::CanReplaceDetails(const Details &details) const {
[&](const HostAssocDetails &) {
return this->has<HostAssocDetails>();
},
+ [&](const UserReductionDetails &) {
+ return this->has<UserReductionDetails>();
+ },
[](const auto &) { return false; },
},
details);
@@ -598,6 +601,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) {
[&](const MiscDetails &x) {
os << ' ' << MiscDetails::EnumToString(x.kind());
},
+ [&](const UserReductionDetails &x) {
+ for (auto &type : x.GetTypeList()) {
+ DumpType(os, type);
+ }
+ },
[&](const auto &x) { os << x; },
},
details);
diff --git a/flang/test/Parser/OpenMP/declare-reduction-multi.f90 b/flang/test/Parser/OpenMP/declare-reduction-multi.f90
new file mode 100644
index 0000000000000..0e1adcc9958d7
--- /dev/null
+++ b/flang/test/Parser/OpenMP/declare-reduction-multi.f90
@@ -0,0 +1,134 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s
+
+!! Test multiple declarations for the same type, with different operations.
+module mymod
+ type :: tt
+ real r
+ end type tt
+contains
+ function mymax(a, b)
+ type(tt) :: a, b, mymax
+ if (a%r > b%r) then
+ mymax = a
+ else
+ mymax = b
+ end if
+ end function mymax
+end module mymod
+
+program omp_examples
+!CHECK-LABEL: PROGRAM omp_examples
+ use mymod
+ implicit none
+ integer, parameter :: n = 100
+ integer :: i
+ type(tt) :: values(n), sum, prod, big, small
+
+ !$omp declare reduction(+:tt:omp_out%r = omp_out%r + omp_in%r) initializer(omp_priv%r = 0)
+!CHECK: !$OMP DECLARE REDUCTION (+:tt: omp_out%r=omp_out%r+omp_in%r
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=0_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE-NEXT: OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE-NEXT: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=omp_out%r+omp_in%r'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=0._4
+ !$omp declare reduction(*:tt:omp_out%r = omp_out%r * omp_in%r) initializer(omp_priv%r = 1)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (*:tt: omp_out%r=omp_out%r*omp_in%r
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=1_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Multiply
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE-NEXT: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=omp_out%r*omp_in%r'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=1._4'
+ !$omp declare reduction(max:tt:omp_out = mymax(omp_out, omp_in)) initializer(omp_priv%r = 0)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (max:tt: omp_out=mymax(omp_out,omp_in)
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=0_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> ProcedureDesignator -> Name = 'max'
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out=mymax(omp_out,omp_in)'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=0._4'
+ !$omp declare reduction(min:tt:omp_out%r = min(omp_out%r, omp_in%r)) initializer(omp_priv%r = 1)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (min:tt: omp_out%r=min(omp_out%r,omp_in%r)
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=1_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> ProcedureDesignator -> Name = 'min'
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=min(omp_out%r,omp_in%r)'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=1._4'
+ call random_number(values%r)
+
+ sum%r = 0
+ !$omp parallel do reduction(+:sum)
+!CHECK: !$OMP PARALLEL DO REDUCTION(+: sum)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'sum
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ sum%r = sum%r + values(i)%r
+ end do
+
+ prod%r = 1
+ !$omp parallel do reduction(*:prod)
+!CHECK: !$OMP PARALLEL DO REDUCTION(*: prod)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Multiply
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'prod'
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ prod%r = prod%r * (values(i)%r+0.6)
+ end do
+
+ big%r = 0
+ !$omp parallel do reduction(max:big)
+!CHECK: $OMP PARALLEL DO REDUCTION(max: big)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> ProcedureDesignator -> Name = 'max'
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'big'
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ big = mymax(values(i), big)
+ end do
+
+ small%r = 1
+ !$omp parallel do reduction(min:small)
+!CHECK: !$OMP PARALLEL DO REDUCTION(min: small)
+!CHECK-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!CHECK-TREE: OmpBeginLoopDirective
+!CHECK-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!CHECK-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!CHECK-TREE: Modifier -> OmpReductionIdentifier -> ProcedureDesignator -> Name = 'min'
+!CHECK-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'small'
+!CHECK-TREE: DoConstruct
+ do i = 1, n
+ small%r = min(values(i)%r, small%r)
+ end do
+
+ print *, values%r
+ print *, "sum=", sum%r
+ print *, "prod=", prod%r
+ print *, "small=", small%r, " big=", big%r
+end program omp_examples
diff --git a/flang/test/Parser/OpenMP/declare-reduction-operator.f90 b/flang/test/Parser/OpenMP/declare-reduction-operator.f90
new file mode 100644
index 0000000000000..7bfb78115b10d
--- /dev/null
+++ b/flang/test/Parser/OpenMP/declare-reduction-operator.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE...
[truncated]
|
@llvm/pr-subscribers-flang-openmp Author: Mats Petersson (Leporacanthicus) ChangesThis adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality. This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator. A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration. Tests for parsing and symbol generation added. Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error. Patch is 31.68 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/131628.diff 12 Files Affected:
diff --git a/flang/include/flang/Semantics/symbol.h b/flang/include/flang/Semantics/symbol.h
index 715811885c219..12867a5f8ec6f 100644
--- a/flang/include/flang/Semantics/symbol.h
+++ b/flang/include/flang/Semantics/symbol.h
@@ -701,6 +701,25 @@ class GenericDetails {
};
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &);
+class UserReductionDetails : public WithBindName {
+public:
+ using TypeVector = std::vector<const DeclTypeSpec *>;
+ UserReductionDetails() = default;
+
+ void AddType(const DeclTypeSpec *type) { typeList_.push_back(type); }
+ const TypeVector &GetTypeList() const { return typeList_; }
+
+ bool SupportsType(const DeclTypeSpec *type) const {
+ for (auto t : typeList_)
+ if (t == type)
+ return true;
+ return false;
+ }
+
+private:
+ TypeVector typeList_;
+};
+
class UnknownDetails {};
using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
@@ -708,7 +727,7 @@ using Details = std::variant<UnknownDetails, MainProgramDetails, ModuleDetails,
ObjectEntityDetails, ProcEntityDetails, AssocEntityDetails,
DerivedTypeDetails, UseDetails, UseErrorDetails, HostAssocDetails,
GenericDetails, ProcBindingDetails, NamelistDetails, CommonBlockDetails,
- TypeParamDetails, MiscDetails>;
+ TypeParamDetails, MiscDetails, UserReductionDetails>;
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const Details &);
std::string DetailsToString(const Details &);
diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp
index 5fcebdca0bc5f..fc8b6f1021b02 100644
--- a/flang/lib/Semantics/check-omp-structure.cpp
+++ b/flang/lib/Semantics/check-omp-structure.cpp
@@ -8,6 +8,7 @@
#include "check-omp-structure.h"
#include "definable.h"
+#include "resolve-names-utils.h"
#include "flang/Evaluate/check-expression.h"
#include "flang/Evaluate/expression.h"
#include "flang/Evaluate/type.h"
@@ -3361,8 +3362,8 @@ bool OmpStructureChecker::CheckReductionOperator(
valid =
llvm::is_contained({"max", "min", "iand", "ior", "ieor"}, realName);
if (!valid) {
- auto *misc{name->symbol->detailsIf<MiscDetails>()};
- valid = misc && misc->kind() == MiscDetails::Kind::ConstructName;
+ auto *reductionDetails{name->symbol->detailsIf<UserReductionDetails>()};
+ valid = reductionDetails != nullptr;
}
}
if (!valid) {
@@ -3444,7 +3445,8 @@ void OmpStructureChecker::CheckReductionObjects(
}
static bool IsReductionAllowedForType(
- const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type) {
+ const parser::OmpReductionIdentifier &ident, const DeclTypeSpec &type,
+ const Scope &scope) {
auto isLogical{[](const DeclTypeSpec &type) -> bool {
return type.category() == DeclTypeSpec::Logical;
}};
@@ -3464,9 +3466,11 @@ static bool IsReductionAllowedForType(
case parser::DefinedOperator::IntrinsicOperator::Multiply:
case parser::DefinedOperator::IntrinsicOperator::Add:
case parser::DefinedOperator::IntrinsicOperator::Subtract:
- return type.IsNumeric(TypeCategory::Integer) ||
+ if (type.IsNumeric(TypeCategory::Integer) ||
type.IsNumeric(TypeCategory::Real) ||
- type.IsNumeric(TypeCategory::Complex);
+ type.IsNumeric(TypeCategory::Complex))
+ return true;
+ break;
case parser::DefinedOperator::IntrinsicOperator::AND:
case parser::DefinedOperator::IntrinsicOperator::OR:
@@ -3479,8 +3483,18 @@ static bool IsReductionAllowedForType(
DIE("This should have been caught in CheckIntrinsicOperator");
return false;
}
+ parser::CharBlock name{MakeNameFromOperator(*intrinsicOp)};
+ Symbol *symbol{scope.FindSymbol(name)};
+ if (symbol) {
+ const auto *reductionDetails{symbol->detailsIf<UserReductionDetails>()};
+ assert(reductionDetails && "Expected to find reductiondetails");
+
+ return reductionDetails->SupportsType(&type);
+ }
+ return false;
}
- return true;
+ assert(0 && "Intrinsic Operator not found - parsing gone wrong?");
+ return false; // Reject everything else.
}};
auto checkDesignator{[&](const parser::ProcedureDesignator &procD) {
@@ -3493,18 +3507,42 @@ static bool IsReductionAllowedForType(
// IAND: arguments must be integers: F2023 16.9.100
// IEOR: arguments must be integers: F2023 16.9.106
// IOR: arguments must be integers: F2023 16.9.111
- return type.IsNumeric(TypeCategory::Integer);
+ if (type.IsNumeric(TypeCategory::Integer)) {
+ return true;
+ }
} else if (realName == "max" || realName == "min") {
// MAX: arguments must be integer, real, or character:
// F2023 16.9.135
// MIN: arguments must be integer, real, or character:
// F2023 16.9.141
- return type.IsNumeric(TypeCategory::Integer) ||
- type.IsNumeric(TypeCategory::Real) || isCharacter(type);
+ if (type.IsNumeric(TypeCategory::Integer) ||
+ type.IsNumeric(TypeCategory::Real) || isCharacter(type)) {
+ return true;
+ }
}
+
+ // If we get here, it may be a user declared reduction, so check
+ // if the symbol has UserReductionDetails, and if so, the type is
+ // supported.
+ if (const auto *reductionDetails{
+ name->symbol->detailsIf<UserReductionDetails>()}) {
+ return reductionDetails->SupportsType(&type);
+ }
+
+ // We also need to check for mangled names (max, min, iand, ieor and ior)
+ // and then check if the type is there.
+ parser::CharBlock mangledName = MangleSpecialFunctions(name->source);
+ if (const auto &symbol{scope.FindSymbol(mangledName)}) {
+ if (const auto *reductionDetails{
+ symbol->detailsIf<UserReductionDetails>()}) {
+ return reductionDetails->SupportsType(&type);
+ }
+ }
+ // Everything else is "not matching type".
+ return false;
}
- // TODO: user defined reduction operators. Just allow everything for now.
- return true;
+ assert(0 && "name and name->symbol should be set here...");
+ return false;
}};
return common::visit(
@@ -3519,7 +3557,8 @@ void OmpStructureChecker::CheckReductionObjectTypes(
for (auto &[symbol, source] : symbols) {
if (auto *type{symbol->GetType()}) {
- if (!IsReductionAllowedForType(ident, *type)) {
+ const auto &scope{context_.FindScope(symbol->name())};
+ if (!IsReductionAllowedForType(ident, *type, scope)) {
context_.Say(source,
"The type of '%s' is incompatible with the reduction operator."_err_en_US,
symbol->name());
diff --git a/flang/lib/Semantics/resolve-names-utils.h b/flang/lib/Semantics/resolve-names-utils.h
index 64784722ff4f8..de0991d69b61b 100644
--- a/flang/lib/Semantics/resolve-names-utils.h
+++ b/flang/lib/Semantics/resolve-names-utils.h
@@ -146,5 +146,9 @@ struct SymbolAndTypeMappings;
void MapSubprogramToNewSymbols(const Symbol &oldSymbol, Symbol &newSymbol,
Scope &newScope, SymbolAndTypeMappings * = nullptr);
+parser::CharBlock MakeNameFromOperator(
+ const parser::DefinedOperator::IntrinsicOperator &op);
+parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name);
+
} // namespace Fortran::semantics
#endif // FORTRAN_SEMANTICS_RESOLVE_NAMES_H_
diff --git a/flang/lib/Semantics/resolve-names.cpp b/flang/lib/Semantics/resolve-names.cpp
index fcd4ba6a51907..825ab36d2e800 100644
--- a/flang/lib/Semantics/resolve-names.cpp
+++ b/flang/lib/Semantics/resolve-names.cpp
@@ -1748,15 +1748,75 @@ void OmpVisitor::ProcessMapperSpecifier(const parser::OmpMapperSpecifier &spec,
PopScope();
}
+parser::CharBlock MakeNameFromOperator(
+ const parser::DefinedOperator::IntrinsicOperator &op) {
+ switch (op) {
+ case parser::DefinedOperator::IntrinsicOperator::Multiply:
+ return parser::CharBlock{"op.*", 4};
+ case parser::DefinedOperator::IntrinsicOperator::Add:
+ return parser::CharBlock{"op.+", 4};
+ case parser::DefinedOperator::IntrinsicOperator::Subtract:
+ return parser::CharBlock{"op.-", 4};
+
+ case parser::DefinedOperator::IntrinsicOperator::AND:
+ return parser::CharBlock{"op.AND", 6};
+ case parser::DefinedOperator::IntrinsicOperator::OR:
+ return parser::CharBlock{"op.OR", 6};
+ case parser::DefinedOperator::IntrinsicOperator::EQV:
+ return parser::CharBlock{"op.EQV", 7};
+ case parser::DefinedOperator::IntrinsicOperator::NEQV:
+ return parser::CharBlock{"op.NEQV", 8};
+
+ default:
+ assert(0 && "Unsupported operator...");
+ return parser::CharBlock{"op.?", 4};
+ }
+}
+
+parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name) {
+ if (name == "max") {
+ return parser::CharBlock{"op.max", 6};
+ }
+ if (name == "min") {
+ return parser::CharBlock{"op.min", 6};
+ }
+ if (name == "iand") {
+ return parser::CharBlock{"op.iand", 7};
+ }
+ if (name == "ior") {
+ return parser::CharBlock{"op.ior", 6};
+ }
+ if (name == "ieor") {
+ return parser::CharBlock{"op.ieor", 7};
+ }
+ // All other names: return as is.
+ return name;
+}
+
void OmpVisitor::ProcessReductionSpecifier(
const parser::OmpReductionSpecifier &spec,
const std::optional<parser::OmpClauseList> &clauses) {
+ const parser::Name *name{nullptr};
+ parser::Name mangledName{};
+ UserReductionDetails reductionDetailsTemp{};
const auto &id{std::get<parser::OmpReductionIdentifier>(spec.t)};
if (auto procDes{std::get_if<parser::ProcedureDesignator>(&id.u)}) {
- if (auto *name{std::get_if<parser::Name>(&procDes->u)}) {
- name->symbol =
- &MakeSymbol(*name, MiscDetails{MiscDetails::Kind::ConstructName});
+ name = std::get_if<parser::Name>(&procDes->u);
+ if (name) {
+ mangledName.source = MangleSpecialFunctions(name->source);
}
+ } else {
+ const auto &defOp{std::get<parser::DefinedOperator>(id.u)};
+ mangledName.source = MakeNameFromOperator(
+ std::get<parser::DefinedOperator::IntrinsicOperator>(defOp.u));
+ name = &mangledName;
+ }
+
+ UserReductionDetails *reductionDetails{&reductionDetailsTemp};
+ Symbol *symbol{name ? name->symbol : nullptr};
+ symbol = FindSymbol(mangledName);
+ if (symbol) {
+ reductionDetails = symbol->detailsIf<UserReductionDetails>();
}
auto &typeList{std::get<parser::OmpTypeNameList>(spec.t)};
@@ -1788,6 +1848,10 @@ void OmpVisitor::ProcessReductionSpecifier(
const DeclTypeSpec *typeSpec{GetDeclTypeSpec()};
assert(typeSpec && "We should have a type here");
+ if (reductionDetails) {
+ reductionDetails->AddType(typeSpec);
+ }
+
for (auto &nm : ompVarNames) {
ObjectEntityDetails details{};
details.set_type(*typeSpec);
@@ -1798,6 +1862,13 @@ void OmpVisitor::ProcessReductionSpecifier(
Walk(clauses);
PopScope();
}
+
+ if (name) {
+ if (!symbol) {
+ symbol = &MakeSymbol(mangledName, Attrs{}, std::move(*reductionDetails));
+ }
+ name->symbol = symbol;
+ }
}
bool OmpVisitor::Pre(const parser::OmpDirectiveSpecification &x) {
diff --git a/flang/lib/Semantics/symbol.cpp b/flang/lib/Semantics/symbol.cpp
index 32eb6c2c5a188..e627dd293ba7c 100644
--- a/flang/lib/Semantics/symbol.cpp
+++ b/flang/lib/Semantics/symbol.cpp
@@ -246,7 +246,7 @@ void GenericDetails::CopyFrom(const GenericDetails &from) {
// This is primarily for debugging.
std::string DetailsToString(const Details &details) {
return common::visit(
- common::visitors{
+ common::visitors{//
[](const UnknownDetails &) { return "Unknown"; },
[](const MainProgramDetails &) { return "MainProgram"; },
[](const ModuleDetails &) { return "Module"; },
@@ -266,7 +266,7 @@ std::string DetailsToString(const Details &details) {
[](const TypeParamDetails &) { return "TypeParam"; },
[](const MiscDetails &) { return "Misc"; },
[](const AssocEntityDetails &) { return "AssocEntity"; },
- },
+ [](const UserReductionDetails &) { return "UserReductionDetails"; }},
details);
}
@@ -300,6 +300,9 @@ bool Symbol::CanReplaceDetails(const Details &details) const {
[&](const HostAssocDetails &) {
return this->has<HostAssocDetails>();
},
+ [&](const UserReductionDetails &) {
+ return this->has<UserReductionDetails>();
+ },
[](const auto &) { return false; },
},
details);
@@ -598,6 +601,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Details &details) {
[&](const MiscDetails &x) {
os << ' ' << MiscDetails::EnumToString(x.kind());
},
+ [&](const UserReductionDetails &x) {
+ for (auto &type : x.GetTypeList()) {
+ DumpType(os, type);
+ }
+ },
[&](const auto &x) { os << x; },
},
details);
diff --git a/flang/test/Parser/OpenMP/declare-reduction-multi.f90 b/flang/test/Parser/OpenMP/declare-reduction-multi.f90
new file mode 100644
index 0000000000000..0e1adcc9958d7
--- /dev/null
+++ b/flang/test/Parser/OpenMP/declare-reduction-multi.f90
@@ -0,0 +1,134 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE" %s
+
+!! Test multiple declarations for the same type, with different operations.
+module mymod
+ type :: tt
+ real r
+ end type tt
+contains
+ function mymax(a, b)
+ type(tt) :: a, b, mymax
+ if (a%r > b%r) then
+ mymax = a
+ else
+ mymax = b
+ end if
+ end function mymax
+end module mymod
+
+program omp_examples
+!CHECK-LABEL: PROGRAM omp_examples
+ use mymod
+ implicit none
+ integer, parameter :: n = 100
+ integer :: i
+ type(tt) :: values(n), sum, prod, big, small
+
+ !$omp declare reduction(+:tt:omp_out%r = omp_out%r + omp_in%r) initializer(omp_priv%r = 0)
+!CHECK: !$OMP DECLARE REDUCTION (+:tt: omp_out%r=omp_out%r+omp_in%r
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=0_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE-NEXT: OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE-NEXT: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=omp_out%r+omp_in%r'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=0._4
+ !$omp declare reduction(*:tt:omp_out%r = omp_out%r * omp_in%r) initializer(omp_priv%r = 1)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (*:tt: omp_out%r=omp_out%r*omp_in%r
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=1_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Multiply
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE-NEXT: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=omp_out%r*omp_in%r'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=1._4'
+ !$omp declare reduction(max:tt:omp_out = mymax(omp_out, omp_in)) initializer(omp_priv%r = 0)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (max:tt: omp_out=mymax(omp_out,omp_in)
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=0_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> ProcedureDesignator -> Name = 'max'
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out=mymax(omp_out,omp_in)'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=0._4'
+ !$omp declare reduction(min:tt:omp_out%r = min(omp_out%r, omp_in%r)) initializer(omp_priv%r = 1)
+!CHECK-NEXT: !$OMP DECLARE REDUCTION (min:tt: omp_out%r=min(omp_out%r,omp_in%r)
+!CHECK-NEXT: ) INITIALIZER(omp_priv%r=1_4)
+!PARSE-TREE: DeclarationConstruct -> SpecificationConstruct -> OpenMPDeclarativeConstruct -> OpenMPDeclareReductionConstruct
+!PARSE-TREE: Verbatim
+!PARSE-TREE: OmpReductionSpecifier
+!PARSE-TREE: OmpReductionIdentifier -> ProcedureDesignator -> Name = 'min'
+!PARSE-TREE: OmpTypeNameList -> OmpTypeSpecifier -> TypeSpec -> DerivedTypeSpec
+!PARSE-TREE: Name = 'tt'
+!PARSE-TREE: OmpReductionCombiner -> AssignmentStmt = 'omp_out%r=min(omp_out%r,omp_in%r)'
+!PARSE-TREE: OmpClauseList -> OmpClause -> Initializer -> OmpInitializerClause -> AssignmentStmt = 'omp_priv%r=1._4'
+ call random_number(values%r)
+
+ sum%r = 0
+ !$omp parallel do reduction(+:sum)
+!CHECK: !$OMP PARALLEL DO REDUCTION(+: sum)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Add
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'sum
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ sum%r = sum%r + values(i)%r
+ end do
+
+ prod%r = 1
+ !$omp parallel do reduction(*:prod)
+!CHECK: !$OMP PARALLEL DO REDUCTION(*: prod)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> DefinedOperator -> IntrinsicOperator = Multiply
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'prod'
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ prod%r = prod%r * (values(i)%r+0.6)
+ end do
+
+ big%r = 0
+ !$omp parallel do reduction(max:big)
+!CHECK: $OMP PARALLEL DO REDUCTION(max: big)
+!PARSE-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!PARSE-TREE: OmpBeginLoopDirective
+!PARSE-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!PARSE-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!PARSE-TREE: Modifier -> OmpReductionIdentifier -> ProcedureDesignator -> Name = 'max'
+!PARSE-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'big'
+!PARSE-TREE: DoConstruct
+ do i = 1, n
+ big = mymax(values(i), big)
+ end do
+
+ small%r = 1
+ !$omp parallel do reduction(min:small)
+!CHECK: !$OMP PARALLEL DO REDUCTION(min: small)
+!CHECK-TREE: ExecutionPartConstruct -> ExecutableConstruct -> OpenMPConstruct -> OpenMPLoopConstruct
+!CHECK-TREE: OmpBeginLoopDirective
+!CHECK-TREE: OmpLoopDirective -> llvm::omp::Directive = parallel do
+!CHECK-TREE: OmpClauseList -> OmpClause -> Reduction -> OmpReductionClause
+!CHECK-TREE: Modifier -> OmpReductionIdentifier -> ProcedureDesignator -> Name = 'min'
+!CHECK-TREE: OmpObjectList -> OmpObject -> Designator -> DataRef -> Name = 'small'
+!CHECK-TREE: DoConstruct
+ do i = 1, n
+ small%r = min(values(i)%r, small%r)
+ end do
+
+ print *, values%r
+ print *, "sum=", sum%r
+ print *, "prod=", prod%r
+ print *, "small=", small%r, " big=", big%r
+end program omp_examples
diff --git a/flang/test/Parser/OpenMP/declare-reduction-operator.f90 b/flang/test/Parser/OpenMP/declare-reduction-operator.f90
new file mode 100644
index 0000000000000..7bfb78115b10d
--- /dev/null
+++ b/flang/test/Parser/OpenMP/declare-reduction-operator.f90
@@ -0,0 +1,59 @@
+! RUN: %flang_fc1 -fdebug-unparse -fopenmp %s | FileCheck --ignore-case %s
+! RUN: %flang_fc1 -fdebug-dump-parse-tree -fopenmp %s | FileCheck --check-prefix="PARSE-TREE...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, just a few minor comments. Good job thinking of all the edge cases to test.
for (auto t : typeList_) | ||
if (t == type) | ||
return true; | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: I think semantics style requires braces here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would the following be better?
bool SupportsType(const DeclTypeSpec *type) const {
return llvm::is_contained(typeList_, type);
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
if (name == "max") { | ||
return parser::CharBlock{"op.max", 6}; | ||
} | ||
if (name == "min") { | ||
return parser::CharBlock{"op.min", 6}; | ||
} | ||
if (name == "iand") { | ||
return parser::CharBlock{"op.iand", 7}; | ||
} | ||
if (name == "ior") { | ||
return parser::CharBlock{"op.ior", 6}; | ||
} | ||
if (name == "ieor") { | ||
return parser::CharBlock{"op.ieor", 7}; | ||
} | ||
// All other names: return as is. | ||
return name; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: llvm::StringSwitch might be more efficient here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
void OmpVisitor::ProcessReductionSpecifier( | ||
const parser::OmpReductionSpecifier &spec, | ||
const std::optional<parser::OmpClauseList> &clauses) { | ||
const parser::Name *name{nullptr}; | ||
parser::Name mangledName{}; | ||
UserReductionDetails reductionDetailsTemp{}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
UserReductionDetails reductionDetailsTemp{}; | |
UserReductionDetails reductionDetailsTemp; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done. [and the line above too].
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A few initial comments.
} | ||
return true; | ||
assert(0 && "Intrinsic Operator not found - parsing gone wrong?"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think Semantics uses the DIE macro.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
} | ||
// TODO: user defined reduction operators. Just allow everything for now. | ||
return true; | ||
assert(0 && "name and name->symbol should be set here..."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment regarding similar usage above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
return parser::CharBlock{"op.NEQV", 8}; | ||
|
||
default: | ||
assert(0 && "Unsupported operator..."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See comment regarding similar usage above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@@ -701,14 +701,33 @@ class GenericDetails { | |||
}; | |||
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &); | |||
|
|||
class UserReductionDetails : public WithBindName { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the relation between UserReductionDetail and BindName? Why did you decide to inherit from WithBindName?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Probably shouldn't inherit anything... :)
for (auto t : typeList_) | ||
if (t == type) | ||
return true; | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would the following be better?
bool SupportsType(const DeclTypeSpec *type) const {
return llvm::is_contained(typeList_, type);
}
|
||
// We also need to check for mangled names (max, min, iand, ieor and ior) | ||
// and then check if the type is there. | ||
parser::CharBlock mangledName = MangleSpecialFunctions(name->source); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: braced initialization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done
@@ -701,14 +701,33 @@ class GenericDetails { | |||
}; | |||
llvm::raw_ostream &operator<<(llvm::raw_ostream &, const GenericDetails &); | |||
|
|||
class UserReductionDetails : public WithBindName { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a comment above detailing the class and its usage.
if (type.IsNumeric(TypeCategory::Integer) || | ||
type.IsNumeric(TypeCategory::Real) || | ||
type.IsNumeric(TypeCategory::Complex); | ||
type.IsNumeric(TypeCategory::Complex)) | ||
return true; | ||
break; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the same change required for the isLogical check below?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes.
if (!symbol) { | ||
symbol = &MakeSymbol(mangledName, Attrs{}, std::move(*reductionDetails)); | ||
} | ||
name->symbol = symbol; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It will be good to see a test for the multiple case as well. Does it override the symbol?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm adding a new test for dupliocate symbols. There is one, but I think we need another different one.
And yes, we need to set the symbol on the original name, even if the actual name is not matching, so updating the symbol is required. If not, another portion finds the name and says it's not been resolved.
There is a case for UserReductionDetails that has to be added while writing a module file. The following test crashes in that code.
|
} else { | ||
const auto &defOp{std::get<parser::DefinedOperator>(id.u)}; | ||
mangledName.source = MakeNameFromOperator( | ||
std::get<parser::DefinedOperator::IntrinsicOperator>(defOp.u)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For cases where there is operator overloading this std::get
will crash.
program m1
interface operator(.mul.)
procedure my_mul
end interface
type t1
integer :: val = 1
end type
!$omp declare reduction(.mul.:t1:omp_out=omp_out.mul.omp_in)
contains
function my_mul(x, y)
type (t1), intent (in) :: x, y
type (t1) :: my_mul
my_mul%val = x%val * y%val
end function
end program
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now fixed, and additional tests added.
|
||
if (!reductionDetails) { | ||
context().Say(name->source, | ||
"Duplicate defineition of '%s' in !$OMP DECLARE REDUCTION"_err_en_US, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"Duplicate defineition of '%s' in !$OMP DECLARE REDUCTION"_err_en_US, | |
"Duplicate definition of '%s' in DECLARE REDUCTION"_err_en_US, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the patch. A few nits.
flang/lib/Semantics/symbol.cpp
Outdated
@@ -246,7 +246,7 @@ void GenericDetails::CopyFrom(const GenericDetails &from) { | |||
// This is primarily for debugging. | |||
std::string DetailsToString(const Details &details) { | |||
return common::visit( | |||
common::visitors{ | |||
common::visitors{// |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Stray "//" ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No idea how that happened. I don't use VI. :)
|
||
default: | ||
DIE("Unsupported operator..."); | ||
return parser::CharBlock{"op.?", 4}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will return parser::CharBlock{"op.?", 4};
be executed after the DIE
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As below (I'm working "up" the list), this was an assert, which of course gets removed in some builds. DIE, as I understand it, always kills the compiler process, regardless of build.
// TODO: user defined reduction operators. Just allow everything for now. | ||
return true; | ||
DIE("name and name->symbol should be set here..."); | ||
return false; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Will return false;
be executed after the DIE
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, this used to be an assert, and when changing to die, I didn't remove the line below.
} | ||
// TODO: user defined reduction operators. Just allow everything for now. | ||
return true; | ||
DIE("name and name->symbol should be set here..."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When we say should be set here
, do we want to give a location for the error? Can name->symbol
be rather wrapped as an assertion in this function rather than a DIE
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, moved up to just where name is fetched (and return false moved around to always return something).
return parser::CharBlock{"op.NEQV", 8}; | ||
|
||
default: | ||
DIE("Unsupported operator..."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we emit a semantic error for unsupported operator here? Is there a specific reason to prefer DIE
over an error?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good spot, I thought there was some other code that stops us from using "bad" combitnation here, but it turns out that you can do !$omp declare reduction(/:integer:...)
and make it die. Error message printing implemented, and another test added.
Now works, and additional test added. |
Adding @klausler for the changes in writing mod-file, and other areas like symbol.{cpp,h}. |
flang/lib/Semantics/mod-file.cpp
Outdated
@@ -1035,6 +1037,25 @@ void ModFileWriter::PutTypeParam(llvm::raw_ostream &os, const Symbol &symbol) { | |||
os << '\n'; | |||
} | |||
|
|||
void ModFileWriter::PutUserReduction( | |||
llvm::raw_ostream &os, const Symbol &symbol) { | |||
auto &details{symbol.get<UserReductionDetails>()}; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
flang/lib/Semantics/mod-file.cpp
Outdated
// declaration. There may be multiple declarations. | ||
// Decls are pointers, so do not use a referene. | ||
for (const auto decl : details.GetDeclList()) { | ||
if (auto d = std::get_if<const parser::OpenMPDeclareReductionConstruct *>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use common::visit
on a lambda with a const auto &
argument instead!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
auto *misc{name->symbol->detailsIf<MiscDetails>()}; | ||
valid = misc && misc->kind() == MiscDetails::Kind::ConstructName; | ||
auto *reductionDetails{name->symbol->detailsIf<UserReductionDetails>()}; | ||
valid = reductionDetails != nullptr; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why didn't you use valid = symbol->has<UserReductionDetails>();
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
parser::CharBlock MakeNameFromOperator( | ||
const parser::DefinedOperator::IntrinsicOperator &op, | ||
SemanticsContext &context); | ||
parser::CharBlock MangleSpecialFunctions(const parser::CharBlock name); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if it's not a reference or pointer, const
is superfluous in the prototype.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
|
||
UserReductionDetails() = default; | ||
|
||
void AddType(const DeclTypeSpec *type) { typeList_.push_back(type); } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argument should be a reference so you don't have to worry about null pointers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
void AddType(const DeclTypeSpec *type) { typeList_.push_back(type); } | ||
const TypeVector &GetTypeList() const { return typeList_; } | ||
|
||
bool SupportsType(const DeclTypeSpec *type) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use a reference argument, not a pointer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍
This adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality. This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator. A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration. Tests for parsing and symbol generation added. Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error.
* Add two more tests (multiple operator-based declarations and re-using symbol already declared. * Add a few comments. * Fix up logical results.
Also print the reduction declaration in the module file. Fix trivial typo. Add/modify tests to cover all the new things, including fixing the duplicated typo in the test...
Also rebase, as the branch was quite a way behind. Small conflict was resolved.
36bc8be
to
93f8179
Compare
Can you check whether the declare reduction works with renamed operators. I am giving an example below for reference, but please think through in general about the possibility of renamed operators.
|
There is a crash in
The following test also crashes in some checks for pure functions.
|
Add code to better handle operators in parsing and semantics. Add a function to set the the scope when processign assignments, which caused a crash in "check for pure functions". Add three new tests and amend existing tests to cover a pure function.
[snip big chunk of code]
Fixed all three of the code snippets that fail, and added tests for checking that it works. |
@@ -3485,8 +3496,20 @@ void OmpStructureChecker::CheckReductionObjects( | |||
} | |||
} | |||
|
|||
static bool CheckSymbolSupportsType(const Scope &scope, | |||
const parser::CharBlock &name, const DeclTypeSpec &type) { | |||
if (const auto &symbol{scope.FindSymbol(name)}) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
const auto &
-> const auto *
@@ -8,6 +8,7 @@ | |||
|
|||
#include "mod-file.h" | |||
#include "resolve-names.h" | |||
#include "flang/Common/indirection.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems unnecessary...
reductionDetails = symbol->detailsIf<UserReductionDetails>(); | ||
|
||
if (!reductionDetails) { | ||
context().Say(name->source, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The "source" argument to Say must point to somewhere within the cooked sources. It's used to locate the line number and for underlining of the problematic fragment. I'm not sure if that's guaranteed here, since some names are created within this function.
@@ -343,6 +347,11 @@ class SemanticsContext { | |||
std::map<const Symbol *, SourceName> moduleFileOutputRenamings_; | |||
UnorderedSymbolSet isDefined_; | |||
std::list<ProgramTree> programTrees_; | |||
|
|||
// storage for mangled names used in OMP DECLARE REDUCTION. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo: "storage" -> "Storage".
This adds another puzzle piece for the support of OpenMP DECLARE REDUCTION functionality.
This adds support for operators with derived types, as well as declaring multiple different types with the same name or operator.
A new detail class for UserReductionDetials is introduced to hold the list of types supported for a given reduction declaration.
Tests for parsing and symbol generation added.
Declare reduction is still not supported to lowering, it will generate a "Not yet implemented" fatal error.