Thanks for visiting codestin.com
Credit goes to clang.llvm.org

clang 22.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
11#include "clang/Basic/Module.h"
18#include "llvm/Support/Path.h"
19
20#include <memory>
21
22using namespace mlir;
23using namespace cir;
24
25static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
26 SmallString<128> fileName;
27
28 if (mlirModule.getSymName())
29 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
30
31 if (fileName.empty())
32 fileName = "<null>";
33
34 for (size_t i = 0; i < fileName.size(); ++i) {
35 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
36 // to be the set of C preprocessing numbers.
37 if (!clang::isPreprocessingNumberBody(fileName[i]))
38 fileName[i] = '_';
39 }
40
41 return fileName;
42}
43
44namespace {
45struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
46 LoweringPreparePass() = default;
47 void runOnOperation() override;
48
49 void runOnOp(mlir::Operation *op);
50 void lowerCastOp(cir::CastOp op);
51 void lowerComplexDivOp(cir::ComplexDivOp op);
52 void lowerComplexMulOp(cir::ComplexMulOp op);
53 void lowerUnaryOp(cir::UnaryOp op);
54 void lowerGlobalOp(cir::GlobalOp op);
55 void lowerArrayDtor(cir::ArrayDtor op);
56 void lowerArrayCtor(cir::ArrayCtor op);
57
58 /// Build the function that initializes the specified global
59 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
60
61 /// Build a module init function that calls all the dynamic initializers.
62 void buildCXXGlobalInitFunc();
63
64 /// Materialize global ctor/dtor list
65 void buildGlobalCtorDtorList();
66
67 cir::FuncOp buildRuntimeFunction(
68 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
69 cir::FuncType type,
70 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
71
72 ///
73 /// AST related
74 /// -----------
75
76 clang::ASTContext *astCtx;
77
78 /// Tracks current module.
79 mlir::ModuleOp mlirModule;
80
81 /// Tracks existing dynamic initializers.
82 llvm::StringMap<uint32_t> dynamicInitializerNames;
83 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
84
85 /// List of ctors and their priorities to be called before main()
86 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
87
88 void setASTContext(clang::ASTContext *c) { astCtx = c; }
89};
90
91} // namespace
92
93cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
94 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
95 cir::FuncType type, cir::GlobalLinkageKind linkage) {
96 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
97 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
98 if (!f) {
99 f = builder.create<cir::FuncOp>(loc, name, type);
100 f.setLinkageAttr(
101 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
102 mlir::SymbolTable::setSymbolVisibility(
103 f, mlir::SymbolTable::Visibility::Private);
104
106 }
107 return f;
108}
109
110static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
111 cir::CastOp op) {
112 cir::CIRBaseBuilderTy builder(ctx);
113 builder.setInsertionPoint(op);
114
115 mlir::Value src = op.getSrc();
116 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
117 return builder.createComplexCreate(op.getLoc(), src, imag);
118}
119
120static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
121 cir::CastOp op,
122 cir::CastKind elemToBoolKind) {
123 cir::CIRBaseBuilderTy builder(ctx);
124 builder.setInsertionPoint(op);
125
126 mlir::Value src = op.getSrc();
127 if (!mlir::isa<cir::BoolType>(op.getType()))
128 return builder.createComplexReal(op.getLoc(), src);
129
130 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
131 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
132 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
133
134 cir::BoolType boolTy = builder.getBoolTy();
135 mlir::Value srcRealToBool =
136 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
137 mlir::Value srcImagToBool =
138 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
139 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
140}
141
142static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
143 cir::CastOp op,
144 cir::CastKind scalarCastKind) {
145 CIRBaseBuilderTy builder(ctx);
146 builder.setInsertionPoint(op);
147
148 mlir::Value src = op.getSrc();
149 auto dstComplexElemTy =
150 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
151
152 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
153 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
154
155 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
156 dstComplexElemTy);
157 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
158 dstComplexElemTy);
159 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
160}
161
162void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
163 mlir::MLIRContext &ctx = getContext();
164 mlir::Value loweredValue = [&]() -> mlir::Value {
165 switch (op.getKind()) {
166 case cir::CastKind::float_to_complex:
167 case cir::CastKind::int_to_complex:
168 return lowerScalarToComplexCast(ctx, op);
169 case cir::CastKind::float_complex_to_real:
170 case cir::CastKind::int_complex_to_real:
171 return lowerComplexToScalarCast(ctx, op, op.getKind());
172 case cir::CastKind::float_complex_to_bool:
173 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
174 case cir::CastKind::int_complex_to_bool:
175 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
176 case cir::CastKind::float_complex:
177 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
178 case cir::CastKind::float_complex_to_int_complex:
179 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
180 case cir::CastKind::int_complex:
181 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
182 case cir::CastKind::int_complex_to_float_complex:
183 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
184 default:
185 return nullptr;
186 }
187 }();
188
189 if (loweredValue) {
190 op.replaceAllUsesWith(loweredValue);
191 op.erase();
192 }
193}
194
195static mlir::Value buildComplexBinOpLibCall(
196 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
197 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
198 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
199 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
200 cir::FPTypeInterface elementTy =
201 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
202
203 llvm::StringRef libFuncName = libFuncNameGetter(
204 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
205 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
206
207 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
208
209 // Insert a declaration for the runtime function to be used in Complex
210 // multiplication and division when needed
211 cir::FuncOp libFunc;
212 {
213 mlir::OpBuilder::InsertionGuard ipGuard{builder};
214 builder.setInsertionPointToStart(pass.mlirModule.getBody());
215 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
216 }
217
218 cir::CallOp call =
219 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
220 return call.getResult();
221}
222
223static llvm::StringRef
224getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
225 switch (semantics) {
226 case llvm::APFloat::S_IEEEhalf:
227 return "__divhc3";
228 case llvm::APFloat::S_IEEEsingle:
229 return "__divsc3";
230 case llvm::APFloat::S_IEEEdouble:
231 return "__divdc3";
232 case llvm::APFloat::S_PPCDoubleDouble:
233 return "__divtc3";
234 case llvm::APFloat::S_x87DoubleExtended:
235 return "__divxc3";
236 case llvm::APFloat::S_IEEEquad:
237 return "__divtc3";
238 default:
239 llvm_unreachable("unsupported floating point type");
240 }
241}
242
243static mlir::Value
244buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
245 mlir::Value lhsReal, mlir::Value lhsImag,
246 mlir::Value rhsReal, mlir::Value rhsImag) {
247 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
248 mlir::Value &a = lhsReal;
249 mlir::Value &b = lhsImag;
250 mlir::Value &c = rhsReal;
251 mlir::Value &d = rhsImag;
252
253 mlir::Value ac = builder.createBinop(loc, a, cir::BinOpKind::Mul, c); // a*c
254 mlir::Value bd = builder.createBinop(loc, b, cir::BinOpKind::Mul, d); // b*d
255 mlir::Value cc = builder.createBinop(loc, c, cir::BinOpKind::Mul, c); // c*c
256 mlir::Value dd = builder.createBinop(loc, d, cir::BinOpKind::Mul, d); // d*d
257 mlir::Value acbd =
258 builder.createBinop(loc, ac, cir::BinOpKind::Add, bd); // ac+bd
259 mlir::Value ccdd =
260 builder.createBinop(loc, cc, cir::BinOpKind::Add, dd); // cc+dd
261 mlir::Value resultReal =
262 builder.createBinop(loc, acbd, cir::BinOpKind::Div, ccdd);
263
264 mlir::Value bc = builder.createBinop(loc, b, cir::BinOpKind::Mul, c); // b*c
265 mlir::Value ad = builder.createBinop(loc, a, cir::BinOpKind::Mul, d); // a*d
266 mlir::Value bcad =
267 builder.createBinop(loc, bc, cir::BinOpKind::Sub, ad); // bc-ad
268 mlir::Value resultImag =
269 builder.createBinop(loc, bcad, cir::BinOpKind::Div, ccdd);
270 return builder.createComplexCreate(loc, resultReal, resultImag);
271}
272
273static mlir::Value
275 mlir::Value lhsReal, mlir::Value lhsImag,
276 mlir::Value rhsReal, mlir::Value rhsImag) {
277 // Implements Smith's algorithm for complex division.
278 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
279
280 // Let:
281 // - lhs := a+bi
282 // - rhs := c+di
283 // - result := lhs / rhs = e+fi
284 //
285 // The algorithm pseudocode looks like follows:
286 // if fabs(c) >= fabs(d):
287 // r := d / c
288 // tmp := c + r*d
289 // e = (a + b*r) / tmp
290 // f = (b - a*r) / tmp
291 // else:
292 // r := c / d
293 // tmp := d + r*c
294 // e = (a*r + b) / tmp
295 // f = (b*r - a) / tmp
296
297 mlir::Value &a = lhsReal;
298 mlir::Value &b = lhsImag;
299 mlir::Value &c = rhsReal;
300 mlir::Value &d = rhsImag;
301
302 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
303 mlir::Value r = builder.createBinop(loc, d, cir::BinOpKind::Div,
304 c); // r := d / c
305 mlir::Value rd = builder.createBinop(loc, r, cir::BinOpKind::Mul, d); // r*d
306 mlir::Value tmp = builder.createBinop(loc, c, cir::BinOpKind::Add,
307 rd); // tmp := c + r*d
308
309 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
310 mlir::Value abr =
311 builder.createBinop(loc, a, cir::BinOpKind::Add, br); // a + b*r
312 mlir::Value e = builder.createBinop(loc, abr, cir::BinOpKind::Div, tmp);
313
314 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
315 mlir::Value bar =
316 builder.createBinop(loc, b, cir::BinOpKind::Sub, ar); // b - a*r
317 mlir::Value f = builder.createBinop(loc, bar, cir::BinOpKind::Div, tmp);
318
319 mlir::Value result = builder.createComplexCreate(loc, e, f);
320 builder.createYield(loc, result);
321 };
322
323 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
324 mlir::Value r = builder.createBinop(loc, c, cir::BinOpKind::Div,
325 d); // r := c / d
326 mlir::Value rc = builder.createBinop(loc, r, cir::BinOpKind::Mul, c); // r*c
327 mlir::Value tmp = builder.createBinop(loc, d, cir::BinOpKind::Add,
328 rc); // tmp := d + r*c
329
330 mlir::Value ar = builder.createBinop(loc, a, cir::BinOpKind::Mul, r); // a*r
331 mlir::Value arb =
332 builder.createBinop(loc, ar, cir::BinOpKind::Add, b); // a*r + b
333 mlir::Value e = builder.createBinop(loc, arb, cir::BinOpKind::Div, tmp);
334
335 mlir::Value br = builder.createBinop(loc, b, cir::BinOpKind::Mul, r); // b*r
336 mlir::Value bra =
337 builder.createBinop(loc, br, cir::BinOpKind::Sub, a); // b*r - a
338 mlir::Value f = builder.createBinop(loc, bra, cir::BinOpKind::Div, tmp);
339
340 mlir::Value result = builder.createComplexCreate(loc, e, f);
341 builder.createYield(loc, result);
342 };
343
344 auto cFabs = builder.create<cir::FAbsOp>(loc, c);
345 auto dFabs = builder.create<cir::FAbsOp>(loc, d);
346 cir::CmpOp cmpResult =
347 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
348 auto ternary = builder.create<cir::TernaryOp>(
349 loc, cmpResult, trueBranchBuilder, falseBranchBuilder);
350
351 return ternary.getResult();
352}
353
355 mlir::MLIRContext &context, clang::ASTContext &cc,
356 CIRBaseBuilderTy &builder, mlir::Type elementType) {
357
358 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
359 if (mlir::isa<cir::FP16Type>(type))
360 return cir::SingleType::get(&context);
361
362 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
363 return cir::DoubleType::get(&context);
364
365 if (mlir::isa<cir::DoubleType>(type))
366 return cir::LongDoubleType::get(&context, type);
367
368 return type;
369 };
370
371 auto getFloatTypeSemantics =
372 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
373 const clang::TargetInfo &info = cc.getTargetInfo();
374 if (mlir::isa<cir::FP16Type>(type))
375 return info.getHalfFormat();
376
377 if (mlir::isa<cir::BF16Type>(type))
378 return info.getBFloat16Format();
379
380 if (mlir::isa<cir::SingleType>(type))
381 return info.getFloatFormat();
382
383 if (mlir::isa<cir::DoubleType>(type))
384 return info.getDoubleFormat();
385
386 if (mlir::isa<cir::LongDoubleType>(type)) {
387 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
388 llvm_unreachable("NYI Float type semantics with OpenMP");
389 return info.getLongDoubleFormat();
390 }
391
392 if (mlir::isa<cir::FP128Type>(type)) {
393 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
394 llvm_unreachable("NYI Float type semantics with OpenMP");
395 return info.getFloat128Format();
396 }
397
398 assert(false && "Unsupported float type semantics");
399 };
400
401 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
402 const llvm::fltSemantics &elementTypeSemantics =
403 getFloatTypeSemantics(elementType);
404 const llvm::fltSemantics &higherElementTypeSemantics =
405 getFloatTypeSemantics(higherElementType);
406
407 // Check that the promoted type can handle the intermediate values without
408 // overflowing. This can be interpreted as:
409 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
410 // LargerType.LargestFiniteVal.
411 // In terms of exponent it gives this formula:
412 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
413 // doubles the exponent of SmallerType.LargestFiniteVal)
414 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
415 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
416 return higherElementType;
417 }
418
419 // The intermediate values can't be represented in the promoted type
420 // without overflowing.
421 return {};
422}
423
424static mlir::Value
425lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
426 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
427 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
428 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
429 cir::ComplexType complexTy = op.getType();
430 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
431 cir::ComplexRangeKind range = op.getRange();
432 if (range == cir::ComplexRangeKind::Improved)
433 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
434 rhsReal, rhsImag);
435
436 if (range == cir::ComplexRangeKind::Full)
438 loc, complexTy, lhsReal, lhsImag, rhsReal,
439 rhsImag);
440
441 if (range == cir::ComplexRangeKind::Promoted) {
442 mlir::Type originalElementType = complexTy.getElementType();
443 mlir::Type higherPrecisionElementType =
445 originalElementType);
446
447 if (!higherPrecisionElementType)
448 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
449 rhsReal, rhsImag);
450
451 cir::CastKind floatingCastKind = cir::CastKind::floating;
452 lhsReal = builder.createCast(floatingCastKind, lhsReal,
453 higherPrecisionElementType);
454 lhsImag = builder.createCast(floatingCastKind, lhsImag,
455 higherPrecisionElementType);
456 rhsReal = builder.createCast(floatingCastKind, rhsReal,
457 higherPrecisionElementType);
458 rhsImag = builder.createCast(floatingCastKind, rhsImag,
459 higherPrecisionElementType);
460
461 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
462 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
463
464 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
465 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
466
467 mlir::Value finalReal =
468 builder.createCast(floatingCastKind, resultReal, originalElementType);
469 mlir::Value finalImag =
470 builder.createCast(floatingCastKind, resultImag, originalElementType);
471 return builder.createComplexCreate(loc, finalReal, finalImag);
472 }
473 }
474
475 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
476 rhsImag);
477}
478
479void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
480 cir::CIRBaseBuilderTy builder(getContext());
481 builder.setInsertionPointAfter(op);
482 mlir::Location loc = op.getLoc();
483 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
484 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
485 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
486 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
487 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
488 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
489
490 mlir::Value loweredResult =
491 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
492 rhsImag, getContext(), *astCtx);
493 op.replaceAllUsesWith(loweredResult);
494 op.erase();
495}
496
497static llvm::StringRef
498getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
499 switch (semantics) {
500 case llvm::APFloat::S_IEEEhalf:
501 return "__mulhc3";
502 case llvm::APFloat::S_IEEEsingle:
503 return "__mulsc3";
504 case llvm::APFloat::S_IEEEdouble:
505 return "__muldc3";
506 case llvm::APFloat::S_PPCDoubleDouble:
507 return "__multc3";
508 case llvm::APFloat::S_x87DoubleExtended:
509 return "__mulxc3";
510 case llvm::APFloat::S_IEEEquad:
511 return "__multc3";
512 default:
513 llvm_unreachable("unsupported floating point type");
514 }
515}
516
517static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
518 CIRBaseBuilderTy &builder,
519 mlir::Location loc, cir::ComplexMulOp op,
520 mlir::Value lhsReal, mlir::Value lhsImag,
521 mlir::Value rhsReal, mlir::Value rhsImag) {
522 // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
523 mlir::Value resultRealLhs =
524 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsReal);
525 mlir::Value resultRealRhs =
526 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsImag);
527 mlir::Value resultImagLhs =
528 builder.createBinop(loc, lhsReal, cir::BinOpKind::Mul, rhsImag);
529 mlir::Value resultImagRhs =
530 builder.createBinop(loc, lhsImag, cir::BinOpKind::Mul, rhsReal);
531 mlir::Value resultReal = builder.createBinop(
532 loc, resultRealLhs, cir::BinOpKind::Sub, resultRealRhs);
533 mlir::Value resultImag = builder.createBinop(
534 loc, resultImagLhs, cir::BinOpKind::Add, resultImagRhs);
535 mlir::Value algebraicResult =
536 builder.createComplexCreate(loc, resultReal, resultImag);
537
538 cir::ComplexType complexTy = op.getType();
539 cir::ComplexRangeKind rangeKind = op.getRange();
540 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
541 rangeKind == cir::ComplexRangeKind::Basic ||
542 rangeKind == cir::ComplexRangeKind::Improved ||
543 rangeKind == cir::ComplexRangeKind::Promoted)
544 return algebraicResult;
545
547
548 // Check whether the real part and the imaginary part of the result are both
549 // NaN. If so, emit a library call to compute the multiplication instead.
550 // We check a value against NaN by comparing the value against itself.
551 mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
552 mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
553 mlir::Value resultRealAndImagAreNaN =
554 builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
555
556 return builder
557 .create<cir::TernaryOp>(
558 loc, resultRealAndImagAreNaN,
559 [&](mlir::OpBuilder &, mlir::Location) {
560 mlir::Value libCallResult = buildComplexBinOpLibCall(
561 pass, builder, &getComplexMulLibCallName, loc, complexTy,
562 lhsReal, lhsImag, rhsReal, rhsImag);
563 builder.createYield(loc, libCallResult);
564 },
565 [&](mlir::OpBuilder &, mlir::Location) {
566 builder.createYield(loc, algebraicResult);
567 })
568 .getResult();
569}
570
571void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
572 cir::CIRBaseBuilderTy builder(getContext());
573 builder.setInsertionPointAfter(op);
574 mlir::Location loc = op.getLoc();
575 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
576 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
577 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
578 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
579 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
580 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
581 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
582 lhsImag, rhsReal, rhsImag);
583 op.replaceAllUsesWith(loweredResult);
584 op.erase();
585}
586
587void LoweringPreparePass::lowerUnaryOp(cir::UnaryOp op) {
588 mlir::Type ty = op.getType();
589 if (!mlir::isa<cir::ComplexType>(ty))
590 return;
591
592 mlir::Location loc = op.getLoc();
593 cir::UnaryOpKind opKind = op.getKind();
594
595 CIRBaseBuilderTy builder(getContext());
596 builder.setInsertionPointAfter(op);
597
598 mlir::Value operand = op.getInput();
599 mlir::Value operandReal = builder.createComplexReal(loc, operand);
600 mlir::Value operandImag = builder.createComplexImag(loc, operand);
601
602 mlir::Value resultReal;
603 mlir::Value resultImag;
604
605 switch (opKind) {
606 case cir::UnaryOpKind::Inc:
607 case cir::UnaryOpKind::Dec:
608 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
609 resultImag = operandImag;
610 break;
611
612 case cir::UnaryOpKind::Plus:
613 case cir::UnaryOpKind::Minus:
614 resultReal = builder.createUnaryOp(loc, opKind, operandReal);
615 resultImag = builder.createUnaryOp(loc, opKind, operandImag);
616 break;
617
618 case cir::UnaryOpKind::Not:
619 resultReal = operandReal;
620 resultImag =
621 builder.createUnaryOp(loc, cir::UnaryOpKind::Minus, operandImag);
622 break;
623 }
624
625 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
626 op.replaceAllUsesWith(result);
627 op.erase();
628}
629
630cir::FuncOp
631LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
632 // TODO(cir): Store this in the GlobalOp.
633 // This should come from the MangleContext, but for now I'm hardcoding it.
634 SmallString<256> fnName("__cxx_global_var_init");
635 // Get a unique name
636 uint32_t cnt = dynamicInitializerNames[fnName]++;
637 if (cnt)
638 fnName += "." + llvm::Twine(cnt).str();
639
640 // Create a variable initialization function.
641 CIRBaseBuilderTy builder(getContext());
642 builder.setInsertionPointAfter(op);
643 auto fnType = cir::FuncType::get({}, builder.getVoidTy());
644 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
645 cir::GlobalLinkageKind::InternalLinkage);
646
647 // Move over the initialzation code of the ctor region.
648 mlir::Block *entryBB = f.addEntryBlock();
649 if (!op.getCtorRegion().empty()) {
650 mlir::Block &block = op.getCtorRegion().front();
651 entryBB->getOperations().splice(entryBB->begin(), block.getOperations(),
652 block.begin(), std::prev(block.end()));
653 }
654
655 // Register the destructor call with __cxa_atexit
656 mlir::Region &dtorRegion = op.getDtorRegion();
657 if (!dtorRegion.empty()) {
659 llvm_unreachable("dtor region lowering is NYI");
660 }
661
662 // Replace cir.yield with cir.return
663 builder.setInsertionPointToEnd(entryBB);
664 mlir::Operation *yieldOp = nullptr;
665 if (!op.getCtorRegion().empty()) {
666 mlir::Block &block = op.getCtorRegion().front();
667 yieldOp = &block.getOperations().back();
668 } else {
670 llvm_unreachable("dtor region lowering is NYI");
671 }
672
673 assert(isa<YieldOp>(*yieldOp));
674 cir::ReturnOp::create(builder, yieldOp->getLoc());
675 return f;
676}
677
678void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
679 mlir::Region &ctorRegion = op.getCtorRegion();
680 mlir::Region &dtorRegion = op.getDtorRegion();
681
682 if (!ctorRegion.empty() || !dtorRegion.empty()) {
683 // Build a variable initialization function and move the initialzation code
684 // in the ctor region over.
685 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
686
687 // Clear the ctor and dtor region
688 ctorRegion.getBlocks().clear();
689 dtorRegion.getBlocks().clear();
690
692 dynamicInitializers.push_back(f);
693 }
694
696}
697
698template <typename AttributeTy>
699static llvm::SmallVector<mlir::Attribute>
700prepareCtorDtorAttrList(mlir::MLIRContext *context,
701 llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
703 for (const auto &[name, priority] : list)
704 attrs.push_back(AttributeTy::get(context, name, priority));
705 return attrs;
706}
707
708void LoweringPreparePass::buildGlobalCtorDtorList() {
709 if (!globalCtorList.empty()) {
710 llvm::SmallVector<mlir::Attribute> globalCtors =
712 globalCtorList);
713
714 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
715 mlir::ArrayAttr::get(&getContext(), globalCtors));
716 }
717
719}
720
721void LoweringPreparePass::buildCXXGlobalInitFunc() {
722 if (dynamicInitializers.empty())
723 return;
724
725 // TODO: handle globals with a user-specified initialzation priority.
726 // TODO: handle default priority more nicely.
728
729 SmallString<256> fnName;
730 // Include the filename in the symbol name. Including "sub_" matches gcc
731 // and makes sure these symbols appear lexicographically behind the symbols
732 // with priority (TBD). Module implementation units behave the same
733 // way as a non-modular TU with imports.
734 // TODO: check CXX20ModuleInits
735 if (astCtx->getCurrentNamedModule() &&
737 llvm::raw_svector_ostream out(fnName);
738 std::unique_ptr<clang::MangleContext> mangleCtx(
739 astCtx->createMangleContext());
740 cast<clang::ItaniumMangleContext>(*mangleCtx)
741 .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
742 } else {
743 fnName += "_GLOBAL__sub_I_";
744 fnName += getTransformedFileName(mlirModule);
745 }
746
747 CIRBaseBuilderTy builder(getContext());
748 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
749 auto fnType = cir::FuncType::get({}, builder.getVoidTy());
750 cir::FuncOp f =
751 buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
752 cir::GlobalLinkageKind::ExternalLinkage);
753 builder.setInsertionPointToStart(f.addEntryBlock());
754 for (cir::FuncOp &f : dynamicInitializers)
755 builder.createCallOp(f.getLoc(), f, {});
756 // Add the global init function (not the individual ctor functions) to the
757 // global ctor list.
758 globalCtorList.emplace_back(fnName,
759 cir::GlobalCtorAttr::getDefaultPriority());
760
761 cir::ReturnOp::create(builder, f.getLoc());
762}
763
765 clang::ASTContext *astCtx,
766 mlir::Operation *op, mlir::Type eltTy,
767 mlir::Value arrayAddr, uint64_t arrayLen,
768 bool isCtor) {
769 // Generate loop to call into ctor/dtor for every element.
770 mlir::Location loc = op->getLoc();
771
772 // TODO: instead of getting the size from the AST context, create alias for
773 // PtrDiffTy and unify with CIRGen stuff.
774 const unsigned sizeTypeSize =
775 astCtx->getTypeSize(astCtx->getSignedSizeType());
776 uint64_t endOffset = isCtor ? arrayLen : arrayLen - 1;
777 mlir::Value endOffsetVal =
778 builder.getUnsignedInt(loc, endOffset, sizeTypeSize);
779
780 auto begin = cir::CastOp::create(builder, loc, eltTy,
781 cir::CastKind::array_to_ptrdecay, arrayAddr);
782 mlir::Value end =
783 cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
784 mlir::Value start = isCtor ? begin : end;
785 mlir::Value stop = isCtor ? end : begin;
786
787 mlir::Value tmpAddr = builder.createAlloca(
788 loc, /*addr type*/ builder.getPointerTo(eltTy),
789 /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
790 builder.createStore(loc, start, tmpAddr);
791
792 cir::DoWhileOp loop = builder.createDoWhile(
793 loc,
794 /*condBuilder=*/
795 [&](mlir::OpBuilder &b, mlir::Location loc) {
796 auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
797 mlir::Type boolTy = cir::BoolType::get(b.getContext());
798 auto cmp = builder.create<cir::CmpOp>(loc, boolTy, cir::CmpOpKind::ne,
799 currentElement, stop);
800 builder.createCondition(cmp);
801 },
802 /*bodyBuilder=*/
803 [&](mlir::OpBuilder &b, mlir::Location loc) {
804 auto currentElement = b.create<cir::LoadOp>(loc, eltTy, tmpAddr);
805
806 cir::CallOp ctorCall;
807 op->walk([&](cir::CallOp c) { ctorCall = c; });
808 assert(ctorCall && "expected ctor call");
809
810 // Array elements get constructed in order but destructed in reverse.
811 mlir::Value stride;
812 if (isCtor)
813 stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
814 else
815 stride = builder.getSignedInt(loc, -1, sizeTypeSize);
816
817 ctorCall->moveBefore(stride.getDefiningOp());
818 ctorCall->setOperand(0, currentElement);
819 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
820 currentElement, stride);
821
822 // Store the element pointer to the temporary variable
823 builder.createStore(loc, nextElement, tmpAddr);
824 builder.createYield(loc);
825 });
826
827 op->replaceAllUsesWith(loop);
828 op->erase();
829}
830
831void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
832 CIRBaseBuilderTy builder(getContext());
833 builder.setInsertionPointAfter(op.getOperation());
834
835 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
837 auto arrayLen =
838 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
839 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
840 false);
841}
842
843void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
844 cir::CIRBaseBuilderTy builder(getContext());
845 builder.setInsertionPointAfter(op.getOperation());
846
847 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
849 auto arrayLen =
850 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
851 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(), arrayLen,
852 true);
853}
854
855void LoweringPreparePass::runOnOp(mlir::Operation *op) {
856 if (auto arrayCtor = dyn_cast<ArrayCtor>(op))
857 lowerArrayCtor(arrayCtor);
858 else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op))
859 lowerArrayDtor(arrayDtor);
860 else if (auto cast = mlir::dyn_cast<cir::CastOp>(op))
861 lowerCastOp(cast);
862 else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op))
863 lowerComplexDivOp(complexDiv);
864 else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op))
865 lowerComplexMulOp(complexMul);
866 else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op))
867 lowerGlobalOp(glob);
868 else if (auto unary = mlir::dyn_cast<cir::UnaryOp>(op))
869 lowerUnaryOp(unary);
870}
871
872void LoweringPreparePass::runOnOperation() {
873 mlir::Operation *op = getOperation();
874 if (isa<::mlir::ModuleOp>(op))
875 mlirModule = cast<::mlir::ModuleOp>(op);
876
877 llvm::SmallVector<mlir::Operation *> opsToTransform;
878
879 op->walk([&](mlir::Operation *op) {
880 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
881 cir::ComplexMulOp, cir::ComplexDivOp, cir::GlobalOp,
882 cir::UnaryOp>(op))
883 opsToTransform.push_back(op);
884 });
885
886 for (mlir::Operation *o : opsToTransform)
887 runOnOp(o);
888
889 buildCXXGlobalInitFunc();
890 buildGlobalCtorDtorList();
891}
892
893std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
894 return std::make_unique<LoweringPreparePass>();
895}
896
897std::unique_ptr<Pass>
899 auto pass = std::make_unique<LoweringPreparePass>();
900 pass->setASTContext(astCtx);
901 return std::move(pass);
902}
Defines the clang::ASTContext interface.
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value arrayAddr, uint64_t arrayLen, bool isCtor)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={})
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::MemOrderAttr order={})
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createBinop(mlir::Location loc, mlir::Value lhs, cir::BinOpKind kind, mlir::Value rhs)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createUnaryOp(mlir::Location loc, cir::UnaryOpKind kind, mlir::Value operand)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic analysis of a file.
Definition ASTContext.h:220
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is a null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:926
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:891
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:664
Exposes information about the current target.
Definition TargetInfo.h:226
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:798
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:783
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:793
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:804
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:788
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:812
Defines the clang::TargetInfo interface.
const internal::VariadicAllOfMatcher< Type > type
Matches Types in the clang AST.
RangeSelector name(std::string ID)
Given a node with a "name" (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc), selects the name's token.
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.].
Definition CharInfo.h:168
unsigned int uint32_t
std::unique_ptr< Pass > createLoweringPreparePass()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool opFuncExtraAttrs()
static bool fastMathFlags()
static bool astVarDeclInterface()
static bool opGlobalDtorLowering()