1616#include " flang/Optimizer/HLFIR/HLFIROps.h"
1717#include " flang/Optimizer/Support/DataLayout.h"
1818#include " flang/Optimizer/Support/InternalNames.h"
19+ #include " flang/Runtime/stop.h"
1920#include " mlir/IR/Matchers.h"
2021#include " mlir/Transforms/DialectConversion.h"
2122#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -74,6 +75,111 @@ static mlir::Value genStatPRIF(fir::FirOpBuilder &builder, mlir::Location loc,
7475 return stat;
7576}
7677
78+ static fir::CallOp genPRIFStopErrorStop (fir::FirOpBuilder &builder,
79+ mlir::Location loc,
80+ mlir::Value stopCode,
81+ bool isError = false ) {
82+ mlir::Type stopCharTy = fir::BoxCharType::get (builder.getContext (), 1 );
83+ mlir::Type i1Ty = builder.getI1Type ();
84+ mlir::Type i32Ty = builder.getI32Type ();
85+
86+ mlir::FunctionType ftype = mlir::FunctionType::get (
87+ builder.getContext (),
88+ /* inputs*/
89+ {builder.getRefType (i1Ty), builder.getRefType (i32Ty), stopCharTy},
90+ /* results*/ {});
91+ mlir::func::FuncOp funcOp =
92+ isError
93+ ? builder.createFunction (loc, getPRIFProcName (" error_stop" ), ftype)
94+ : builder.createFunction (loc, getPRIFProcName (" stop" ), ftype);
95+
96+ // QUIET is managed in flang-rt, so its value is set to TRUE here.
97+ mlir::Value q = builder.createBool (loc, true );
98+ mlir::Value quiet = builder.createTemporary (loc, i1Ty);
99+ fir::StoreOp::create (builder, loc, q, quiet);
100+
101+ mlir::Value stopCodeInt, stopCodeChar;
102+ if (!stopCode) {
103+ stopCodeChar = fir::AbsentOp::create (builder, loc, stopCharTy);
104+ stopCodeInt =
105+ fir::AbsentOp::create (builder, loc, builder.getRefType (i32Ty));
106+ } else if (fir::isa_integer (stopCode.getType ())) {
107+ stopCodeChar = fir::AbsentOp::create (builder, loc, stopCharTy);
108+ stopCodeInt = builder.createTemporary (loc, i32Ty);
109+ if (stopCode.getType () != i32Ty)
110+ stopCode = fir::ConvertOp::create (builder, loc, i32Ty, stopCode);
111+ fir::StoreOp::create (builder, loc, stopCode, stopCodeInt);
112+ } else {
113+ stopCodeChar = stopCode;
114+ if (!mlir::isa<fir::BoxCharType>(stopCodeChar.getType ())) {
115+ auto len =
116+ fir::UndefOp::create (builder, loc, builder.getCharacterLengthType ());
117+ stopCodeChar =
118+ fir::EmboxCharOp::create (builder, loc, stopCharTy, stopCodeChar, len);
119+ }
120+ stopCodeInt =
121+ fir::AbsentOp::create (builder, loc, builder.getRefType (i32Ty));
122+ }
123+
124+ llvm::SmallVector<mlir::Value> args = fir::runtime::createArguments (
125+ builder, loc, ftype, quiet, stopCodeInt, stopCodeChar);
126+ return fir::CallOp::create (builder, loc, funcOp, args);
127+ }
128+
129+ enum class TerminationKind { Normal = 0 , Error = 1 , FailImage = 2 };
130+ // Generates a wrapper function for the different kind of termination in PRIF.
131+ // This function will be used to register wrappers on PRIF runtime termination
132+ // functions into the Fortran runtime.
133+ mlir::Value genTerminationOperationWrapper (fir::FirOpBuilder &builder,
134+ mlir::Location loc,
135+ mlir::ModuleOp module ,
136+ TerminationKind termKind) {
137+ std::string funcName;
138+ mlir::FunctionType funcType =
139+ mlir::FunctionType::get (builder.getContext (), {}, {});
140+ mlir::Type i32Ty = builder.getI32Type ();
141+ if (termKind == TerminationKind::Normal) {
142+ funcName = getPRIFProcName (" stop" );
143+ funcType = mlir::FunctionType::get (builder.getContext (), {i32Ty}, {});
144+ } else if (termKind == TerminationKind::Error) {
145+ funcName = getPRIFProcName (" error_stop" );
146+ funcType = mlir::FunctionType::get (builder.getContext (), {i32Ty}, {});
147+ } else {
148+ funcName = getPRIFProcName (" fail_image" );
149+ }
150+ funcName += " _termination_wrapper" ;
151+ mlir::func::FuncOp funcWrapperOp =
152+ module .lookupSymbol <mlir::func::FuncOp>(funcName);
153+
154+ if (!funcWrapperOp) {
155+ funcWrapperOp = builder.createFunction (loc, funcName, funcType);
156+
157+ // generating the body of the function.
158+ mlir::OpBuilder::InsertPoint saveInsertPoint = builder.saveInsertionPoint ();
159+ builder.setInsertionPointToStart (funcWrapperOp.addEntryBlock ());
160+
161+ if (termKind == TerminationKind::Normal) {
162+ genPRIFStopErrorStop (builder, loc, funcWrapperOp.getArgument (0 ),
163+ /* isError*/ false );
164+ } else if (termKind == TerminationKind::Error) {
165+ genPRIFStopErrorStop (builder, loc, funcWrapperOp.getArgument (0 ),
166+ /* isError*/ true );
167+ } else {
168+ mlir::func::FuncOp fOp = builder.createFunction (
169+ loc, getPRIFProcName (" fail_image" ),
170+ mlir::FunctionType::get (builder.getContext (), {}, {}));
171+ fir::CallOp::create (builder, loc, fOp );
172+ }
173+
174+ mlir::func::ReturnOp::create (builder, loc);
175+ builder.restoreInsertionPoint (saveInsertPoint);
176+ }
177+
178+ mlir::SymbolRefAttr symbolRef = mlir::SymbolRefAttr::get (
179+ builder.getContext (), funcWrapperOp.getSymNameAttr ());
180+ return fir::AddrOfOp::create (builder, loc, funcType, symbolRef);
181+ }
182+
77183// / Convert mif.init operation to runtime call of 'prif_init'
78184struct MIFInitOpConversion : public mlir ::OpRewritePattern<mif::InitOp> {
79185 using OpRewritePattern::OpRewritePattern;
@@ -87,6 +193,39 @@ struct MIFInitOpConversion : public mlir::OpRewritePattern<mif::InitOp> {
87193
88194 mlir::Type i32Ty = builder.getI32Type ();
89195 mlir::Value result = builder.createTemporary (loc, i32Ty);
196+
197+ // Registering PRIF runtime termination to the Fortran runtime
198+ // STOP
199+ mlir::Value funcStopOp = genTerminationOperationWrapper (
200+ builder, loc, mod, TerminationKind::Normal);
201+ mlir::func::FuncOp normalEndFunc =
202+ fir::runtime::getRuntimeFunc<mkRTKey (RegisterImagesNormalEndCallback)>(
203+ loc, builder);
204+ llvm::SmallVector<mlir::Value> args1 = fir::runtime::createArguments (
205+ builder, loc, normalEndFunc.getFunctionType (), funcStopOp);
206+ fir::CallOp::create (builder, loc, normalEndFunc, args1);
207+
208+ // ERROR STOP
209+ mlir::Value funcErrorStopOp = genTerminationOperationWrapper (
210+ builder, loc, mod, TerminationKind::Error);
211+ mlir::func::FuncOp errorFunc =
212+ fir::runtime::getRuntimeFunc<mkRTKey (RegisterImagesErrorCallback)>(
213+ loc, builder);
214+ llvm::SmallVector<mlir::Value> args2 = fir::runtime::createArguments (
215+ builder, loc, errorFunc.getFunctionType (), funcErrorStopOp);
216+ fir::CallOp::create (builder, loc, errorFunc, args2);
217+
218+ // FAIL IMAGE
219+ mlir::Value failImageOp = genTerminationOperationWrapper (
220+ builder, loc, mod, TerminationKind::FailImage);
221+ mlir::func::FuncOp failImageFunc =
222+ fir::runtime::getRuntimeFunc<mkRTKey (RegisterFailImageCallback)>(
223+ loc, builder);
224+ llvm::SmallVector<mlir::Value> args3 = fir::runtime::createArguments (
225+ builder, loc, errorFunc.getFunctionType (), failImageOp);
226+ fir::CallOp::create (builder, loc, failImageFunc, args3);
227+
228+ // Intialize the multi-image parallel environment
90229 mlir::FunctionType ftype = mlir::FunctionType::get (
91230 builder.getContext (),
92231 /* inputs*/ {builder.getRefType (i32Ty)}, /* results*/ {});
0 commit comments