Thanks to visit codestin.com
Credit goes to github.com

Skip to content

Commit 6f8cea3

Browse files
JDPailleuxbonachea
andauthored
[flang][MIF] Adding Stop and ErrorStop PRIF call procedures (llvm#166787)
This PR proposes to add `Stop` and `ErrorStop` PRIF call procedures to the MIF dialect. If the `-fcoarray` flag is passed, then all calls to `STOP` and `ERROR STOP` will use those of PRIF in flang-rt. Thes procedure has been registered during the initialization (mif::InitOp). --------- Co-authored-by: Dan Bonachea <[email protected]>
1 parent 8a6be40 commit 6f8cea3

11 files changed

Lines changed: 262 additions & 20 deletions

File tree

flang-rt/include/flang-rt/runtime/terminator.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#define FLANG_RT_RUNTIME_TERMINATOR_H_
1313

1414
#include "flang/Common/api-attrs.h"
15+
#include "flang/Runtime/stop.h"
1516
#include <cstdarg>
1617
#include <cstdio>
1718
#include <cstdlib>
@@ -112,9 +113,16 @@ class Terminator {
112113
else \
113114
Terminator{__FILE__, __LINE__}.CheckFailed(#pred)
114115

115-
RT_API_ATTRS void NotifyOtherImagesOfNormalEnd();
116+
void SetNormalEndCallback(void (*callback)(int));
117+
void SetFailImageCallback(void (*callback)(void));
118+
void SetErrorCallback(void (*callback)(int));
119+
120+
[[noreturn]] void NormalExit(int exitCode);
121+
[[noreturn]] void ErrorExit(int exitCode);
122+
123+
RT_API_ATTRS void SynchronizeImagesOfNormalEnd(int);
116124
RT_API_ATTRS void NotifyOtherImagesOfFailImageStatement();
117-
RT_API_ATTRS void NotifyOtherImagesOfErrorTermination();
125+
RT_API_ATTRS void NotifyOtherImagesOfErrorTermination(int);
118126

119127
#if defined(RT_DEVICE_COMPILATION)
120128
/// Trap the execution on the device.

flang-rt/lib/runtime/main.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ static void ConfigureFloatingPoint() {
2929
extern "C" {
3030
void RTNAME(ProgramStart)(int argc, const char *argv[], const char *envp[],
3131
const EnvironmentDefaultList *envDefaults) {
32-
std::atexit(Fortran::runtime::NotifyOtherImagesOfNormalEnd);
3332
Fortran::runtime::executionEnvironment.Configure(
3433
argc, argv, envp, envDefaults);
3534
ConfigureFloatingPoint();

flang-rt/lib/runtime/stop.cpp

Lines changed: 23 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,10 @@ static void CloseAllExternalUnits(const char *why) {
9696
std::fputc('\n', stderr);
9797
DescribeIEEESignaledExceptions();
9898
}
99-
std::exit(code);
99+
if (isErrorStop)
100+
Fortran::runtime::ErrorExit(code);
101+
else
102+
Fortran::runtime::NormalExit(code);
100103
#endif
101104
}
102105

@@ -124,9 +127,9 @@ static void CloseAllExternalUnits(const char *why) {
124127
DescribeIEEESignaledExceptions();
125128
}
126129
if (isErrorStop) {
127-
std::exit(EXIT_FAILURE);
130+
Fortran::runtime::ErrorExit(EXIT_FAILURE);
128131
} else {
129-
std::exit(EXIT_SUCCESS);
132+
Fortran::runtime::NormalExit(EXIT_SUCCESS);
130133
}
131134
#endif
132135
}
@@ -144,7 +147,7 @@ static void EndPause() {
144147
std::fflush(nullptr);
145148
if (std::fgetc(stdin) == EOF) {
146149
CloseAllExternalUnits("PAUSE statement");
147-
std::exit(EXIT_SUCCESS);
150+
Fortran::runtime::ErrorExit(EXIT_SUCCESS);
148151
}
149152
}
150153

@@ -172,19 +175,31 @@ void RTNAME(PauseStatementText)(const char *code, std::size_t length) {
172175
}
173176

174177
[[noreturn]] void RTNAME(FailImageStatement)() {
175-
Fortran::runtime::NotifyOtherImagesOfFailImageStatement();
176178
CloseAllExternalUnits("FAIL IMAGE statement");
177-
std::exit(EXIT_FAILURE);
179+
Fortran::runtime::NotifyOtherImagesOfFailImageStatement();
180+
Fortran::runtime::NormalExit(EXIT_FAILURE);
178181
}
179182

180183
[[noreturn]] void RTNAME(ProgramEndStatement)() {
181184
CloseAllExternalUnits("END statement");
182-
std::exit(EXIT_SUCCESS);
185+
Fortran::runtime::NormalExit(EXIT_SUCCESS);
186+
}
187+
188+
void RTNAME(RegisterImagesNormalEndCallback)(void (*callback)(int)) {
189+
Fortran::runtime::SetNormalEndCallback(callback);
190+
}
191+
192+
void RTNAME(RegisterImagesErrorCallback)(void (*callback)(int)) {
193+
Fortran::runtime::SetErrorCallback(callback);
194+
}
195+
196+
void RTNAME(RegisterFailImageCallback)(void (*callback)(void)) {
197+
Fortran::runtime::SetFailImageCallback(callback);
183198
}
184199

185200
[[noreturn]] void RTNAME(Exit)(int status) {
186201
CloseAllExternalUnits("CALL EXIT()");
187-
std::exit(status);
202+
Fortran::runtime::NormalExit(status);
188203
}
189204

190205
static RT_NOINLINE_ATTR void PrintBacktrace() {

flang-rt/lib/runtime/terminator.cpp

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ RT_API_ATTRS void Terminator::CrashHeader() const {
7373
// FIXME: re-enable the flush along with the IO enabling.
7474
io::FlushOutputOnCrash(*this);
7575
#endif
76-
NotifyOtherImagesOfErrorTermination();
76+
NotifyOtherImagesOfErrorTermination(EXIT_FAILURE);
7777
#if defined(RT_DEVICE_COMPILATION)
7878
DeviceTrap();
7979
#else
@@ -93,11 +93,48 @@ RT_API_ATTRS void Terminator::CrashHeader() const {
9393
sourceFileName_, sourceLine_);
9494
}
9595

96-
// TODO: These will be defined in the coarray runtime library
97-
RT_API_ATTRS void NotifyOtherImagesOfNormalEnd() {}
98-
RT_API_ATTRS void NotifyOtherImagesOfFailImageStatement() {}
99-
RT_API_ATTRS void NotifyOtherImagesOfErrorTermination() {}
96+
static RT_VAR_ATTRS void (*normalEndCallback)(int) = nullptr;
97+
static RT_VAR_ATTRS void (*failImageCallback)(void) = nullptr;
98+
static RT_VAR_ATTRS void (*errorCallback)(int) = nullptr;
10099

100+
void SetNormalEndCallback(void (*callback)(int)) {
101+
normalEndCallback = callback;
102+
}
103+
104+
void SetFailImageCallback(void (*callback)(void)) {
105+
failImageCallback = callback;
106+
}
107+
108+
void SetErrorCallback(void (*callback)(int)) { errorCallback = callback; }
109+
110+
[[noreturn]]
111+
void NormalExit(int exitCode) {
112+
SynchronizeImagesOfNormalEnd(exitCode); // might never return
113+
114+
std::exit(exitCode);
115+
}
116+
117+
[[noreturn]]
118+
void ErrorExit(int exitCode) {
119+
NotifyOtherImagesOfErrorTermination(exitCode); // might never return
120+
121+
std::exit(exitCode);
122+
}
123+
124+
RT_API_ATTRS void SynchronizeImagesOfNormalEnd(int code) {
125+
if (normalEndCallback)
126+
(*normalEndCallback)(code);
127+
}
128+
129+
RT_API_ATTRS void NotifyOtherImagesOfFailImageStatement() {
130+
if (failImageCallback)
131+
(*failImageCallback)();
132+
}
133+
134+
RT_API_ATTRS void NotifyOtherImagesOfErrorTermination(int code) {
135+
if (errorCallback)
136+
(*errorCallback)(code);
137+
}
101138
RT_OFFLOAD_API_GROUP_END
102139

103140
} // namespace Fortran::runtime

flang/include/flang/Optimizer/Builder/Runtime/RTBuilder.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,14 @@ getModel<void *(*)(void *, const void *, unsigned __int64)>() {
279279
}
280280
#endif
281281
template <>
282+
constexpr TypeBuilderFunc getModel<void (*)(void)>() {
283+
return [](mlir::MLIRContext *context) -> mlir::Type {
284+
return fir::LLVMPointerType::get(
285+
context,
286+
mlir::FunctionType::get(context, /*inputs=*/{}, /*results*/ {}));
287+
};
288+
}
289+
template <>
282290
constexpr TypeBuilderFunc getModel<void **>() {
283291
return [](mlir::MLIRContext *context) -> mlir::Type {
284292
return fir::ReferenceType::get(

flang/include/flang/Optimizer/Dialect/FIRTypes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ def AnyLogicalLike : TypeConstraint<Or<[BoolLike.predicate,
597597
fir_LogicalType.predicate]>, "any logical">;
598598
def AnyRealLike : TypeConstraint<FloatLike.predicate, "any real">;
599599
def AnyIntegerType : Type<AnyIntegerLike.predicate, "any integer">;
600+
def AnyLogicalType : Type<AnyLogicalLike.predicate, "any logical">;
600601

601602
def AnyFirComplexLike : TypeConstraint<CPred<"::fir::isa_complex($_self)">,
602603
"any floating point complex type">;

flang/include/flang/Runtime/stop.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ void RTNAME(PauseStatementText)(const char *, size_t);
2828
NORETURN void RTNAME(FailImageStatement)(NO_ARGUMENTS);
2929
NORETURN void RTNAME(ProgramEndStatement)(NO_ARGUMENTS);
3030

31+
void RTNAME(RegisterImagesNormalEndCallback)(void (*)(int));
32+
void RTNAME(RegisterImagesErrorCallback)(void (*)(int));
33+
void RTNAME(RegisterFailImageCallback)(void (*)(void));
34+
3135
// Extensions
3236
NORETURN void RTNAME(Exit)(int status DEFAULT_VALUE(EXIT_SUCCESS));
3337
RT_OFFLOAD_API_GROUP_BEGIN

flang/lib/Lower/Runtime.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ void Fortran::lower::genStopStatement(
118118
}
119119

120120
fir::CallOp::create(builder, loc, callee, operands);
121+
121122
auto blockIsUnterminated = [&builder]() {
122123
mlir::Block *currentBlock = builder.getBlock();
123124
return currentBlock->empty() ||

flang/lib/Optimizer/Builder/Runtime/Main.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,8 @@ void fir::runtime::genMain(
7474
mif::InitOp::create(builder, loc);
7575

7676
fir::CallOp::create(builder, loc, qqMainFn);
77-
fir::CallOp::create(builder, loc, stopFn);
7877

7978
mlir::Value ret = builder.createIntegerConstant(loc, argcTy, 0);
79+
fir::CallOp::create(builder, loc, stopFn);
8080
mlir::func::ReturnOp::create(builder, loc, ret);
8181
}

flang/lib/Optimizer/Transforms/MIFOpConversion.cpp

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1717
#include "flang/Optimizer/Support/DataLayout.h"
1818
#include "flang/Optimizer/Support/InternalNames.h"
19+
#include "flang/Runtime/stop.h"
1920
#include "mlir/IR/Matchers.h"
2021
#include "mlir/Transforms/DialectConversion.h"
2122
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -74,6 +75,111 @@ static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
7475
return stat;
7576
}
7677

78+
static fir::CallOp genPRIFStopErrorStop(fir::FirOpBuilder &builder,
79+
mlir::Location loc,
80+
mlir::Value stopCode,
81+
bool isError = false) {
82+
mlir::Type stopCharTy = fir::BoxCharType::get(builder.getContext(), 1);
83+
mlir::Type i1Ty = builder.getI1Type();
84+
mlir::Type i32Ty = builder.getI32Type();
85+
86+
mlir::FunctionType ftype = mlir::FunctionType::get(
87+
builder.getContext(),
88+
/*inputs*/
89+
{builder.getRefType(i1Ty), builder.getRefType(i32Ty), stopCharTy},
90+
/*results*/ {});
91+
mlir::func::FuncOp funcOp =
92+
isError
93+
? builder.createFunction(loc, getPRIFProcName("error_stop"), ftype)
94+
: builder.createFunction(loc, getPRIFProcName("stop"), ftype);
95+
96+
// QUIET is managed in flang-rt, so its value is set to TRUE here.
97+
mlir::Value q = builder.createBool(loc, true);
98+
mlir::Value quiet = builder.createTemporary(loc, i1Ty);
99+
fir::StoreOp::create(builder, loc, q, quiet);
100+
101+
mlir::Value stopCodeInt, stopCodeChar;
102+
if (!stopCode) {
103+
stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
104+
stopCodeInt =
105+
fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
106+
} else if (fir::isa_integer(stopCode.getType())) {
107+
stopCodeChar = fir::AbsentOp::create(builder, loc, stopCharTy);
108+
stopCodeInt = builder.createTemporary(loc, i32Ty);
109+
if (stopCode.getType() != i32Ty)
110+
stopCode = fir::ConvertOp::create(builder, loc, i32Ty, stopCode);
111+
fir::StoreOp::create(builder, loc, stopCode, stopCodeInt);
112+
} else {
113+
stopCodeChar = stopCode;
114+
if (!mlir::isa<fir::BoxCharType>(stopCodeChar.getType())) {
115+
auto len =
116+
fir::UndefOp::create(builder, loc, builder.getCharacterLengthType());
117+
stopCodeChar =
118+
fir::EmboxCharOp::create(builder, loc, stopCharTy, stopCodeChar, len);
119+
}
120+
stopCodeInt =
121+
fir::AbsentOp::create(builder, loc, builder.getRefType(i32Ty));
122+
}
123+
124+
llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments(
125+
builder, loc, ftype, quiet, stopCodeInt, stopCodeChar);
126+
return fir::CallOp::create(builder, loc, funcOp, args);
127+
}
128+
129+
enum class TerminationKind { Normal = 0, Error = 1, FailImage = 2 };
130+
// Generates a wrapper function for the different kind of termination in PRIF.
131+
// This function will be used to register wrappers on PRIF runtime termination
132+
// functions into the Fortran runtime.
133+
mlir::Value genTerminationOperationWrapper(fir::FirOpBuilder &builder,
134+
mlir::Location loc,
135+
mlir::ModuleOp module,
136+
TerminationKind termKind) {
137+
std::string funcName;
138+
mlir::FunctionType funcType =
139+
mlir::FunctionType::get(builder.getContext(), {}, {});
140+
mlir::Type i32Ty = builder.getI32Type();
141+
if (termKind == TerminationKind::Normal) {
142+
funcName = getPRIFProcName("stop");
143+
funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
144+
} else if (termKind == TerminationKind::Error) {
145+
funcName = getPRIFProcName("error_stop");
146+
funcType = mlir::FunctionType::get(builder.getContext(), {i32Ty}, {});
147+
} else {
148+
funcName = getPRIFProcName("fail_image");
149+
}
150+
funcName += "_termination_wrapper";
151+
mlir::func::FuncOp funcWrapperOp =
152+
module.lookupSymbol<mlir::func::FuncOp>(funcName);
153+
154+
if (!funcWrapperOp) {
155+
funcWrapperOp = builder.createFunction(loc, funcName, funcType);
156+
157+
// generating the body of the function.
158+
mlir::OpBuilder::InsertPoint saveInsertPoint = builder.saveInsertionPoint();
159+
builder.setInsertionPointToStart(funcWrapperOp.addEntryBlock());
160+
161+
if (termKind == TerminationKind::Normal) {
162+
genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
163+
/*isError*/ false);
164+
} else if (termKind == TerminationKind::Error) {
165+
genPRIFStopErrorStop(builder, loc, funcWrapperOp.getArgument(0),
166+
/*isError*/ true);
167+
} else {
168+
mlir::func::FuncOp fOp = builder.createFunction(
169+
loc, getPRIFProcName("fail_image"),
170+
mlir::FunctionType::get(builder.getContext(), {}, {}));
171+
fir::CallOp::create(builder, loc, fOp);
172+
}
173+
174+
mlir::func::ReturnOp::create(builder, loc);
175+
builder.restoreInsertionPoint(saveInsertPoint);
176+
}
177+
178+
mlir::SymbolRefAttr symbolRef = mlir::SymbolRefAttr::get(
179+
builder.getContext(), funcWrapperOp.getSymNameAttr());
180+
return fir::AddrOfOp::create(builder, loc, funcType, symbolRef);
181+
}
182+
77183
/// Convert mif.init operation to runtime call of 'prif_init'
78184
struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
79185
using OpRewritePattern::OpRewritePattern;
@@ -87,6 +193,39 @@ struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
87193

88194
mlir::Type i32Ty = builder.getI32Type();
89195
mlir::Value result = builder.createTemporary(loc, i32Ty);
196+
197+
// Registering PRIF runtime termination to the Fortran runtime
198+
// STOP
199+
mlir::Value funcStopOp = genTerminationOperationWrapper(
200+
builder, loc, mod, TerminationKind::Normal);
201+
mlir::func::FuncOp normalEndFunc =
202+
fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesNormalEndCallback)>(
203+
loc, builder);
204+
llvm::SmallVector<mlir::Value> args1 = fir::runtime::createArguments(
205+
builder, loc, normalEndFunc.getFunctionType(), funcStopOp);
206+
fir::CallOp::create(builder, loc, normalEndFunc, args1);
207+
208+
// ERROR STOP
209+
mlir::Value funcErrorStopOp = genTerminationOperationWrapper(
210+
builder, loc, mod, TerminationKind::Error);
211+
mlir::func::FuncOp errorFunc =
212+
fir::runtime::getRuntimeFunc<mkRTKey(RegisterImagesErrorCallback)>(
213+
loc, builder);
214+
llvm::SmallVector<mlir::Value> args2 = fir::runtime::createArguments(
215+
builder, loc, errorFunc.getFunctionType(), funcErrorStopOp);
216+
fir::CallOp::create(builder, loc, errorFunc, args2);
217+
218+
// FAIL IMAGE
219+
mlir::Value failImageOp = genTerminationOperationWrapper(
220+
builder, loc, mod, TerminationKind::FailImage);
221+
mlir::func::FuncOp failImageFunc =
222+
fir::runtime::getRuntimeFunc<mkRTKey(RegisterFailImageCallback)>(
223+
loc, builder);
224+
llvm::SmallVector<mlir::Value> args3 = fir::runtime::createArguments(
225+
builder, loc, errorFunc.getFunctionType(), failImageOp);
226+
fir::CallOp::create(builder, loc, failImageFunc, args3);
227+
228+
// Intialize the multi-image parallel environment
90229
mlir::FunctionType ftype = mlir::FunctionType::get(
91230
builder.getContext(),
92231
/*inputs*/ {builder.getRefType(i32Ty)}, /*results*/ {});

0 commit comments

Comments
 (0)