clang 23.0.0git
LoweringPrepare.cpp
Go to the documentation of this file.
1//===- LoweringPrepare.cpp - preparation work for LLVM lowering -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "PassDetail.h"
10#include "mlir/IR/Attributes.h"
11#include "mlir/IR/BuiltinAttributeInterfaces.h"
12#include "mlir/IR/IRMapping.h"
13#include "mlir/IR/Location.h"
14#include "mlir/IR/Value.h"
16#include "clang/AST/Mangle.h"
17#include "clang/Basic/Cuda.h"
18#include "clang/Basic/Module.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/IR/Instructions.h"
35#include "llvm/Support/ErrorHandling.h"
36#include "llvm/Support/MemoryBuffer.h"
37#include "llvm/Support/Path.h"
38#include "llvm/Support/VirtualFileSystem.h"
39
40#include <memory>
41#include <optional>
42
43using namespace mlir;
44using namespace cir;
45
46namespace mlir {
47#define GEN_PASS_DEF_LOWERINGPREPARE
48#include "clang/CIR/Dialect/Passes.h.inc"
49} // namespace mlir
50
51static SmallString<128> getTransformedFileName(mlir::ModuleOp mlirModule) {
52 SmallString<128> fileName;
53
54 if (mlirModule.getSymName())
55 fileName = llvm::sys::path::filename(mlirModule.getSymName()->str());
56
57 if (fileName.empty())
58 fileName = "<null>";
59
60 for (size_t i = 0; i < fileName.size(); ++i) {
61 // Replace everything that's not [a-zA-Z0-9._] with a _. This set happens
62 // to be the set of C preprocessing numbers.
63 if (!clang::isPreprocessingNumberBody(fileName[i]))
64 fileName[i] = '_';
65 }
66
67 return fileName;
68}
69
70namespace {
71struct LoweringPreparePass
72 : public impl::LoweringPrepareBase<LoweringPreparePass> {
73 LoweringPreparePass() = default;
74
75 // `mlir::SymbolTableCollection` is move-only (it owns lazily-created
76 // `unique_ptr<SymbolTable>` entries), which makes the implicit copy
77 // constructor ill-formed. MLIR's `clonePass()` requires copy
78 // construction, so define one explicitly. Per-run state members
79 // (dynamic initializers, guard maps, symbol-table cache, etc.) all
80 // start fresh in the cloned pass, which matches MLIR convention for
81 // pass clones and is more correct than the previous default-generated
82 // behavior that silently copied them.
83 LoweringPreparePass(const LoweringPreparePass &other)
84 : impl::LoweringPrepareBase<LoweringPreparePass>(other) {}
85
86 void runOnOperation() override;
87
88 void runOnOp(mlir::Operation *op);
89 void lowerCastOp(cir::CastOp op);
90 void lowerComplexDivOp(cir::ComplexDivOp op);
91 void lowerComplexMulOp(cir::ComplexMulOp op);
92 void lowerUnaryOp(cir::UnaryOpInterface op);
93 void lowerGetGlobalOp(cir::GetGlobalOp op);
94 void lowerGlobalOp(cir::GlobalOp op);
95 void lowerThreeWayCmpOp(cir::CmpThreeWayOp op);
96 void lowerArrayDtor(cir::ArrayDtor op);
97 void lowerArrayCtor(cir::ArrayCtor op);
98 void lowerTrivialCopyCall(cir::CallOp op);
99 void lowerStoreOfConstAggregate(cir::StoreOp op);
100 void lowerLocalInitOp(cir::LocalInitOp op);
101
102 /// Return the FuncOp called by `callOp`. Uses the cached `symbolTables`
103 /// member to avoid the O(M) module-wide scan that the static
104 /// `mlir::SymbolTable::lookupNearestSymbolFrom` would do per call.
105 cir::FuncOp getCalledFunction(cir::CallOp callOp);
106
107 /// Return a private constant cir::GlobalOp with the given type and initial
108 /// value, suitable for backing a memcpy-initialized local aggregate.
109 ///
110 /// If a global with `baseName` (or one of its `.<n>` versioned siblings)
111 /// already has a matching type and initial value, that global is reused.
112 /// Otherwise a new global is created with the next available `.<n>` suffix
113 /// (matching CIRGenBuilder::createVersionedGlobal and OGCG behavior).
114 cir::GlobalOp getOrCreateConstAggregateGlobal(CIRBaseBuilderTy &builder,
115 mlir::Location loc,
116 llvm::StringRef baseName,
117 mlir::Type ty,
118 mlir::TypedAttr constant);
119
120 /// Build the function that initializes the specified global
121 cir::FuncOp buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op);
122
123 /// When looking at the 'global' op, create the wrapper function.
124 void defineGlobalThreadLocalWrapper(cir::GlobalOp op, cir::FuncOp initAlias,
125 bool isVarDefinition);
126 /// Create an initialization alias for a thread-local variable.
127 cir::FuncOp defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
128 cir::FuncOp aliasee);
129 /// Get the declaration for the 'wrapper' function for a global-TLS variable.
130 cir::FuncOp getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
131 cir::GlobalOp op);
132 // Function that generates the guard global variable, get-global, and 'if'
133 // condition for global TLS init function generation. This inserts an 'if'
134 // with the store at the beginning of the 'then' region, so inserts into the
135 // body should happen after that.
136 cir::IfOp buildGlobalTlsGuardCheck(CIRBaseBuilderTy &builder,
137 mlir::Location loc, cir::GlobalOp guard);
138 /// Handle the dtor region by registering destructor with __cxa_atexit
139 cir::FuncOp getOrCreateDtorFunc(CIRBaseBuilderTy &builder, cir::GlobalOp op,
140 mlir::Region &dtorRegion,
141 cir::CallOp &dtorCall);
142
143 /// Build a module init function that calls all the dynamic initializers.
144 void buildCXXGlobalInitFunc();
145 // Build an init function for all of the ordered global thread local storage
146 // variables.
147 void buildCXXGlobalTlsFunc();
148
149 /// Materialize global ctor/dtor list
150 void buildGlobalCtorDtorList();
151
152 cir::FuncOp buildRuntimeFunction(
153 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
154 cir::FuncType type,
155 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage);
156
157 cir::GlobalOp getOrCreateRuntimeVariable(
158 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
159 mlir::Type type,
160 cir::GlobalLinkageKind linkage = cir::GlobalLinkageKind::ExternalLinkage,
161 cir::VisibilityKind visibility = cir::VisibilityKind::Default);
162
163 /// ------------
164 /// CUDA registration related
165 /// ------------
166
167 llvm::StringMap<FuncOp> cudaKernelMap;
168 llvm::SmallVector<std::pair<cir::GlobalOp, cir::CUDAVarRegistrationInfoAttr>>
169 cudaDeviceVars;
170
171 /// Build the CUDA module constructor that registers the fat binary
172 /// with the CUDA runtime.
173 void buildCUDAModuleCtor();
174 std::optional<FuncOp> buildCUDAModuleDtor();
175 std::optional<FuncOp> buildHIPModuleDtor();
176 std::optional<FuncOp> buildCUDARegisterGlobals();
177 void buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
178 FuncOp regGlobalFunc);
179 void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
180 FuncOp regGlobalFunc);
181
182 /// Handle static local variable initialization with guard variables.
183 void handleStaticLocal(cir::GlobalOp globalOp, cir::LocalInitOp localInitOp);
184
185 /// Get or create __cxa_guard_acquire function.
186 cir::FuncOp getGuardAcquireFn(cir::PointerType guardPtrTy);
187
188 /// Get or create __cxa_guard_release function.
189 cir::FuncOp getGuardReleaseFn(cir::PointerType guardPtrTy);
190
191 /// Get or create the __init_tls function.
192 cir::FuncOp getTlsInitFn();
193
194 // Create the __tls_guard variable.
195 cir::GlobalOp createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
196 mlir::Location loc);
197
198 /// Create a guard global variable for a static local.
199 cir::GlobalOp createGuardGlobalOp(CIRBaseBuilderTy &builder,
200 mlir::Location loc, llvm::StringRef name,
201 cir::IntType guardTy,
202 cir::GlobalLinkageKind linkage);
203
204 /// Get the guard variable for a static local declaration.
205 cir::GlobalOp getStaticLocalDeclGuardAddress(llvm::StringRef globalSymName) {
206 auto it = staticLocalDeclGuardMap.find(globalSymName);
207 if (it != staticLocalDeclGuardMap.end())
208 return it->second;
209 return nullptr;
210 }
211
212 /// Set the guard variable for a static local declaration.
213 void setStaticLocalDeclGuardAddress(llvm::StringRef globalSymName,
214 cir::GlobalOp guard) {
215 staticLocalDeclGuardMap[globalSymName] = guard;
216 }
217
218 /// Get or create the guard variable for a static local declaration.
219 cir::GlobalOp getOrCreateStaticLocalDeclGuardAddress(
220 CIRBaseBuilderTy &builder, cir::GlobalOp globalOp, StringRef guardName,
221 bool isLocalVarDecl, bool useInt8GuardVariable) {
222
223 cir::CIRDataLayout dataLayout(mlirModule);
224 cir::IntType guardTy;
225 clang::CharUnits guardAlignment;
226 // Guard variables are 64 bits in the generic ABI and size width on ARM
227 // (i.e. 32-bit on AArch32, 64-bit on AArch64).
228 if (useInt8GuardVariable) {
229 guardTy = cir::IntType::get(&getContext(), 8, /*isSigned=*/true);
230 guardAlignment = clang::CharUnits::One();
231 } else if (useARMGuardVarABI()) {
232 // Guard variables are size width on ARM (32-bit AArch32, 64-bit AArch64).
233 const unsigned sizeTypeSize =
234 astCtx->getTypeSize(astCtx->getSignedSizeType());
235 guardTy =
236 cir::IntType::get(&getContext(), sizeTypeSize, /*isSigned=*/true);
237 guardAlignment =
238 clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
239 } else {
240 guardTy = cir::IntType::get(&getContext(), 64, /*isSigned=*/true);
241 guardAlignment =
242 clang::CharUnits::fromQuantity(dataLayout.getABITypeAlign(guardTy));
243 }
244 assert(guardTy && guardAlignment.getQuantity() != 0);
245
246 llvm::StringRef globalSymName = globalOp.getSymName();
247 cir::GlobalOp guard = getStaticLocalDeclGuardAddress(globalSymName);
248 if (!guard) {
249 // Create the guard variable with a zero-initializer.
250 guard = createGuardGlobalOp(builder, globalOp->getLoc(), guardName,
251 guardTy, globalOp.getLinkage());
252 guard.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
253 guard.setDSOLocal(globalOp.getDsoLocal());
254 guard.setAlignment(guardAlignment.getAsAlign().value());
255 guard.setTlsModel(globalOp.getTlsModel());
256
257 // The ABI says: "It is suggested that it be emitted in the same COMDAT
258 // group as the associated data object." In practice, this doesn't work
259 // for non-ELF and non-Wasm object formats, so only do it for ELF and
260 // Wasm.
261 bool hasComdat = globalOp.getComdat();
262 const llvm::Triple &triple = astCtx->getTargetInfo().getTriple();
263 // TODO(cir): for now, we're just setting comdat to true, but it should
264 // contain a comdat reference name here instead.
265 if (!isLocalVarDecl && hasComdat &&
266 (triple.isOSBinFormatELF() || triple.isOSBinFormatWasm())) {
267 // This should be a comdat for the variable.
268 guard.setComdat(true);
269 } else if (hasComdat && globalOp.isWeakForLinker()) {
270 guard.setComdat(true);
271 }
272
273 setStaticLocalDeclGuardAddress(globalSymName, guard);
274 }
275 return guard;
276 }
277
278 ///
279 /// AST related
280 /// -----------
281
282 clang::ASTContext *astCtx;
283
284 /// Tracks current module.
285 mlir::ModuleOp mlirModule;
286
287 /// Cached symbol tables used to avoid repeated O(M) module-wide scans
288 /// during per-call/per-global symbol lookups. Lazily populated on first
289 /// use. Pass methods access this directly rather than threading it
290 /// through helper signatures (see PR feedback on #195919).
291 ///
292 /// Invariant: every site that mutates the module's symbol table either
293 /// (a) keeps `symbolTables` in sync via
294 /// `symbolTables.getSymbolTable(mlirModule).insert(...)` (as
295 /// `getOrCreateConstAggregateGlobal` does), or (b) creates a symbol
296 /// that is never resolved through the cache later. Today
297 /// `buildRuntimeFunction` and `getOrCreateRuntimeVariable` fall in the
298 /// (b) bucket: their callers either use a separate map
299 /// (`cudaKernelMap`, `staticLocalDeclGuardMap`, `dynamicInitializers`)
300 /// or the static `mlir::SymbolTable::lookupNearestSymbolFrom`, never
301 /// the cached path. If a future change adds a cached lookup of a
302 /// freshly created symbol, the corresponding create site MUST move
303 /// to bucket (a) (insert into the cache or call
304 /// `invalidateSymbolTable`).
305 mlir::SymbolTableCollection symbolTables;
306
307 /// Tracks existing dynamic initializers.
308 llvm::StringMap<uint32_t> dynamicInitializerNames;
309 llvm::SmallVector<cir::FuncOp> dynamicInitializers;
310 llvm::SmallVector<cir::FuncOp> globalThreadLocalInitializers;
311 llvm::StringMap<cir::FuncOp> threadLocalWrappers;
312 llvm::StringMap<cir::FuncOp> threadLocalInitAliases;
313
314 /// Tracks guard variables for static locals (keyed by global symbol name).
315 llvm::StringMap<cir::GlobalOp> staticLocalDeclGuardMap;
316
317 llvm::StringMap<llvm::SmallVector<cir::GlobalOp, 1>> constAggregateGlobals;
318
319 /// List of ctors and their priorities to be called before main()
320 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalCtorList;
321 /// List of dtors and their priorities to be called when unloading module.
322 llvm::SmallVector<std::pair<std::string, uint32_t>, 4> globalDtorList;
323
324 /// Returns true if the target uses ARM-style guard variables for static
325 /// local initialization (32-bit guard, check bit 0 only).
326 bool useARMGuardVarABI() const {
327 switch (astCtx->getCXXABIKind()) {
328 case clang::TargetCXXABI::GenericARM:
329 case clang::TargetCXXABI::iOS:
330 case clang::TargetCXXABI::WatchOS:
331 case clang::TargetCXXABI::GenericAArch64:
332 case clang::TargetCXXABI::WebAssembly:
333 return true;
334 default:
335 return false;
336 }
337 }
338
339 void emitGlobalGuardedDtorRegion(CIRBaseBuilderTy &builder,
340 cir::GlobalOp global,
341 mlir::Region &dtorRegion, bool tls,
342 mlir::Block &entryBB) {
343 // Create a variable that binds the atexit to this shared object.
344 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
345 cir::GlobalOp handle = getOrCreateRuntimeVariable(
346 builder, "__dso_handle", global.getLoc(), builder.getI8Type(),
347 cir::GlobalLinkageKind::ExternalLinkage, cir::VisibilityKind::Hidden);
348
349 // If this is a simple call to a destructor, get the called function.
350 // Otherwise, create a helper function for the entire dtor region,
351 // replacing the current dtor region body with a call to the helper
352 // function.
353 cir::CallOp dtorCall;
354 cir::FuncOp dtorFunc =
355 getOrCreateDtorFunc(builder, global, dtorRegion, dtorCall);
356
357 // Create a runtime helper function:
358 // extern "C" int __cxa_atexit(void (*f)(void *), void *p, void *d);
359 cir::PointerType voidPtrTy = builder.getVoidPtrTy();
360 cir::PointerType voidFnPtrTy = builder.getVoidFnPtrTy({voidPtrTy});
361 cir::PointerType handlePtrTy = builder.getPointerTo(handle.getSymType());
362 auto fnAtExitType =
363 builder.getVoidFnTy({voidFnPtrTy, voidPtrTy, handlePtrTy});
364
365 llvm::StringLiteral nameAtExit = "__cxa_atexit";
366 if (tls)
367 nameAtExit = astCtx->getTargetInfo().getTriple().isOSDarwin()
368 ? llvm::StringLiteral("_tlv_atexit")
369 : llvm::StringLiteral("__cxa_thread_atexit");
370
371 cir::FuncOp fnAtExit = buildRuntimeFunction(builder, nameAtExit,
372 global.getLoc(), fnAtExitType);
373
374 // Replace the dtor (or helper) call with a call to
375 // __cxa_atexit(&dtor, &var, &__dso_handle)
376 builder.setInsertionPointAfter(dtorCall);
377 mlir::Value args[3];
378 auto dtorPtrTy = cir::PointerType::get(dtorFunc.getFunctionType());
379 args[0] = cir::GetGlobalOp::create(builder, dtorCall.getLoc(), dtorPtrTy,
380 dtorFunc.getSymName());
381 args[0] = cir::CastOp::create(builder, dtorCall.getLoc(), voidFnPtrTy,
382 cir::CastKind::bitcast, args[0]);
383 args[1] =
384 cir::CastOp::create(builder, dtorCall.getLoc(), voidPtrTy,
385 cir::CastKind::bitcast, dtorCall.getArgOperand(0));
386 args[2] = cir::GetGlobalOp::create(builder, handle.getLoc(), handlePtrTy,
387 handle.getSymName());
388 builder.createCallOp(dtorCall.getLoc(), fnAtExit, args);
389 dtorCall->erase();
390 mlir::Block &dtorBlock = dtorRegion.front();
391 entryBB.getOperations().splice(entryBB.end(), dtorBlock.getOperations(),
392 dtorBlock.begin(),
393 std::prev(dtorBlock.end()));
394 // make sure we leave the insert location after the operations we just
395 // inserted.
396 builder.setInsertionPointToEnd(&entryBB);
397 }
398
399 /// Emit the guarded initialization for a static local variable.
400 /// This handles the if/else structure after the guard byte check,
401 /// following OG's ItaniumCXXABI::EmitGuardedInit skeleton.
402 void emitCXXGuardedInitIf(CIRBaseBuilderTy &builder, cir::GlobalOp globalOp,
403 mlir::Region &ctorRegion, mlir::Region &dtorRegion,
404 cir::ASTVarDeclInterface varDecl,
405 mlir::Value guardPtr, cir::PointerType guardPtrTy,
406 bool threadsafe) {
407 auto loc = globalOp->getLoc();
408
409 // The semantics of dynamic initialization of variables with static or
410 // thread storage duration depends on whether they are declared at
411 // block-scope. The initialization of such variables at block-scope can be
412 // aborted with an exception and later retried (per C++20 [stmt.dcl]p4),
413 // and recursive entry to their initialization has undefined behavior (also
414 // per C++20 [stmt.dcl]p4). For such variables declared at non-block scope,
415 // exceptions lead to termination (per C++20 [except.terminate]p1), and
416 // recursive references to the variables are governed only by the lifetime
417 // rules (per C++20 [class.cdtor]p2), which means such references are
418 // perfectly fine as long as they avoid touching memory. As a result,
419 // block-scope variables must not be marked as initialized until after
420 // initialization completes (unless the mark is reverted following an
421 // exception), but non-block-scope variables must be marked prior to
422 // initialization so that recursive accesses during initialization do not
423 // restart initialization.
424
425 auto emitBody = [&]() {
426 // Emit the initializer and add a global destructor if appropriate.
427 mlir::Block *insertBlock = builder.getInsertionBlock();
428 if (!ctorRegion.empty()) {
429 assert(ctorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");
430
431 mlir::Block &block = ctorRegion.front();
432 insertBlock->getOperations().splice(
433 insertBlock->end(), block.getOperations(), block.begin(),
434 std::prev(block.end()));
435 }
436
437 if (!dtorRegion.empty()) {
438 assert(dtorRegion.hasOneBlock() && "Enforced by MaxSizedRegion<1>");
439
440 emitGlobalGuardedDtorRegion(builder, globalOp, dtorRegion, !threadsafe,
441 *insertBlock);
442 }
443 builder.setInsertionPointToEnd(insertBlock);
444 ctorRegion.getBlocks().clear();
445 };
446
447 // Variables used when coping with thread-safe statics and exceptions.
448 if (threadsafe) {
449 // Call __cxa_guard_acquire.
450 cir::CallOp acquireCall = builder.createCallOp(
451 loc, getGuardAcquireFn(guardPtrTy), mlir::ValueRange{guardPtr});
452 mlir::Value acquireResult = acquireCall.getResult();
453
454 auto acquireZero = builder.getConstantInt(
455 loc, mlir::cast<cir::IntType>(acquireResult.getType()), 0);
456 auto shouldInit = builder.createCompare(loc, cir::CmpOpKind::ne,
457 acquireResult, acquireZero);
458
459 // Create the IfOp for the shouldInit check.
460 // Pass an empty callback to avoid auto-creating a yield terminator.
461 auto ifOp =
462 cir::IfOp::create(builder, loc, shouldInit, /*withElseRegion=*/false,
463 [](mlir::OpBuilder &, mlir::Location) {});
464 mlir::OpBuilder::InsertionGuard insertGuard(builder);
465 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
466
467 // Call __cxa_guard_abort along the exceptional edge.
468 // OG: CGF.EHStack.pushCleanup<CallGuardAbort>(EHCleanup, guard);
470
471 emitBody();
472
473 // Pop the guard-abort cleanup if we pushed one.
474 // OG: CGF.PopCleanupBlock();
476
477 // Call __cxa_guard_release. This cannot throw.
478 builder.createCallOp(loc, getGuardReleaseFn(guardPtrTy),
479 mlir::ValueRange{guardPtr});
480
481 builder.createYield(loc);
482 } else if (!varDecl.isLocalVarDecl()) {
483 // For non-local variables, store 1 into the first byte of the guard
484 // variable before the object initialization begins so that references
485 // to the variable during initialization don't restart initialization.
486 // OG: Builder.CreateStore(llvm::ConstantInt::get(CGM.Int8Ty, 1), ...);
487 // Then: CGF.EmitCXXGlobalVarDeclInit(D, var, shouldPerformInit);
488 globalOp->emitError("NYI: non-threadsafe init for non-local variables");
489 return;
490 } else {
491 emitBody();
492 // For local variables, store 1 into the first byte of the guard variable
493 // after the object initialization completes so that initialization is
494 // retried if initialization is interrupted by an exception.
495 builder.createStore(
496 loc, builder.getConstantInt(loc, guardPtrTy.getPointee(), 1),
497 guardPtr);
498 }
499
500 builder.createYield(loc); // Outermost IfOp
501 }
502
503 void setASTContext(clang::ASTContext *c) { astCtx = c; }
504};
505
506} // namespace
507
508cir::GlobalOp LoweringPreparePass::getOrCreateRuntimeVariable(
509 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
510 mlir::Type type, cir::GlobalLinkageKind linkage,
511 cir::VisibilityKind visibility) {
512 cir::GlobalOp g = dyn_cast_or_null<cir::GlobalOp>(
513 mlir::SymbolTable::lookupNearestSymbolFrom(
514 mlirModule, mlir::StringAttr::get(mlirModule->getContext(), name)));
515 if (!g) {
516 g = cir::GlobalOp::create(builder, loc, name, type);
517 g.setLinkageAttr(
518 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
519 mlir::SymbolTable::setSymbolVisibility(
520 g, mlir::SymbolTable::Visibility::Private);
521 g.setGlobalVisibility(visibility);
522 }
523 return g;
524}
525
526cir::FuncOp LoweringPreparePass::buildRuntimeFunction(
527 mlir::OpBuilder &builder, llvm::StringRef name, mlir::Location loc,
528 cir::FuncType type, cir::GlobalLinkageKind linkage) {
529 cir::FuncOp f = dyn_cast_or_null<FuncOp>(SymbolTable::lookupNearestSymbolFrom(
530 mlirModule, StringAttr::get(mlirModule->getContext(), name)));
531 if (!f) {
532 f = cir::FuncOp::create(builder, loc, name, type);
533 f.setLinkageAttr(
534 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
535 mlir::SymbolTable::setSymbolVisibility(
536 f, mlir::SymbolTable::Visibility::Private);
537
539 }
540 return f;
541}
542
543static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx,
544 cir::CastOp op) {
545 cir::CIRBaseBuilderTy builder(ctx);
546 builder.setInsertionPoint(op);
547
548 mlir::Value src = op.getSrc();
549 mlir::Value imag = builder.getNullValue(src.getType(), op.getLoc());
550 return builder.createComplexCreate(op.getLoc(), src, imag);
551}
552
553static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx,
554 cir::CastOp op,
555 cir::CastKind elemToBoolKind) {
556 cir::CIRBaseBuilderTy builder(ctx);
557 builder.setInsertionPoint(op);
558
559 mlir::Value src = op.getSrc();
560 if (!mlir::isa<cir::BoolType>(op.getType()))
561 return builder.createComplexReal(op.getLoc(), src);
562
563 // Complex cast to bool: (bool)(a+bi) => (bool)a || (bool)b
564 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
565 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
566
567 cir::BoolType boolTy = builder.getBoolTy();
568 mlir::Value srcRealToBool =
569 builder.createCast(op.getLoc(), elemToBoolKind, srcReal, boolTy);
570 mlir::Value srcImagToBool =
571 builder.createCast(op.getLoc(), elemToBoolKind, srcImag, boolTy);
572 return builder.createLogicalOr(op.getLoc(), srcRealToBool, srcImagToBool);
573}
574
575static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx,
576 cir::CastOp op,
577 cir::CastKind scalarCastKind) {
578 CIRBaseBuilderTy builder(ctx);
579 builder.setInsertionPoint(op);
580
581 mlir::Value src = op.getSrc();
582 auto dstComplexElemTy =
583 mlir::cast<cir::ComplexType>(op.getType()).getElementType();
584
585 mlir::Value srcReal = builder.createComplexReal(op.getLoc(), src);
586 mlir::Value srcImag = builder.createComplexImag(op.getLoc(), src);
587
588 mlir::Value dstReal = builder.createCast(op.getLoc(), scalarCastKind, srcReal,
589 dstComplexElemTy);
590 mlir::Value dstImag = builder.createCast(op.getLoc(), scalarCastKind, srcImag,
591 dstComplexElemTy);
592 return builder.createComplexCreate(op.getLoc(), dstReal, dstImag);
593}
594
595void LoweringPreparePass::lowerCastOp(cir::CastOp op) {
596 mlir::MLIRContext &ctx = getContext();
597 mlir::Value loweredValue = [&]() -> mlir::Value {
598 switch (op.getKind()) {
599 case cir::CastKind::float_to_complex:
600 case cir::CastKind::int_to_complex:
601 return lowerScalarToComplexCast(ctx, op);
602 case cir::CastKind::float_complex_to_real:
603 case cir::CastKind::int_complex_to_real:
604 return lowerComplexToScalarCast(ctx, op, op.getKind());
605 case cir::CastKind::float_complex_to_bool:
606 return lowerComplexToScalarCast(ctx, op, cir::CastKind::float_to_bool);
607 case cir::CastKind::int_complex_to_bool:
608 return lowerComplexToScalarCast(ctx, op, cir::CastKind::int_to_bool);
609 case cir::CastKind::float_complex:
610 return lowerComplexToComplexCast(ctx, op, cir::CastKind::floating);
611 case cir::CastKind::float_complex_to_int_complex:
612 return lowerComplexToComplexCast(ctx, op, cir::CastKind::float_to_int);
613 case cir::CastKind::int_complex:
614 return lowerComplexToComplexCast(ctx, op, cir::CastKind::integral);
615 case cir::CastKind::int_complex_to_float_complex:
616 return lowerComplexToComplexCast(ctx, op, cir::CastKind::int_to_float);
617 default:
618 return nullptr;
619 }
620 }();
621
622 if (loweredValue) {
623 op.replaceAllUsesWith(loweredValue);
624 op.erase();
625 }
626}
627
628static mlir::Value buildComplexBinOpLibCall(
629 LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
630 llvm::StringRef (*libFuncNameGetter)(llvm::APFloat::Semantics),
631 mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal,
632 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag) {
633 cir::FPTypeInterface elementTy =
634 mlir::cast<cir::FPTypeInterface>(ty.getElementType());
635
636 llvm::StringRef libFuncName = libFuncNameGetter(
637 llvm::APFloat::SemanticsToEnum(elementTy.getFloatSemantics()));
638 llvm::SmallVector<mlir::Type, 4> libFuncInputTypes(4, elementTy);
639
640 cir::FuncType libFuncTy = cir::FuncType::get(libFuncInputTypes, ty);
641
642 // Insert a declaration for the runtime function to be used in Complex
643 // multiplication and division when needed
644 cir::FuncOp libFunc;
645 {
646 mlir::OpBuilder::InsertionGuard ipGuard{builder};
647 builder.setInsertionPointToStart(pass.mlirModule.getBody());
648 libFunc = pass.buildRuntimeFunction(builder, libFuncName, loc, libFuncTy);
649 }
650
651 cir::CallOp call =
652 builder.createCallOp(loc, libFunc, {lhsReal, lhsImag, rhsReal, rhsImag});
653 return call.getResult();
654}
655
656static llvm::StringRef
657getComplexDivLibCallName(llvm::APFloat::Semantics semantics) {
658 switch (semantics) {
659 case llvm::APFloat::S_IEEEhalf:
660 return "__divhc3";
661 case llvm::APFloat::S_IEEEsingle:
662 return "__divsc3";
663 case llvm::APFloat::S_IEEEdouble:
664 return "__divdc3";
665 case llvm::APFloat::S_PPCDoubleDouble:
666 return "__divtc3";
667 case llvm::APFloat::S_x87DoubleExtended:
668 return "__divxc3";
669 case llvm::APFloat::S_IEEEquad:
670 return "__divtc3";
671 default:
672 llvm_unreachable("unsupported floating point type");
673 }
674}
675
676static mlir::Value
677buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc,
678 mlir::Value lhsReal, mlir::Value lhsImag,
679 mlir::Value rhsReal, mlir::Value rhsImag) {
680 // (a+bi) / (c+di) = ((ac+bd)/(cc+dd)) + ((bc-ad)/(cc+dd))i
681 mlir::Value &a = lhsReal;
682 mlir::Value &b = lhsImag;
683 mlir::Value &c = rhsReal;
684 mlir::Value &d = rhsImag;
685
686 mlir::Value ac = builder.createMul(loc, a, c); // a*c
687 mlir::Value bd = builder.createMul(loc, b, d); // b*d
688 mlir::Value cc = builder.createMul(loc, c, c); // c*c
689 mlir::Value dd = builder.createMul(loc, d, d); // d*d
690 mlir::Value acbd = builder.createAdd(loc, ac, bd); // ac+bd
691 mlir::Value ccdd = builder.createAdd(loc, cc, dd); // cc+dd
692 mlir::Value resultReal = builder.createDiv(loc, acbd, ccdd);
693
694 mlir::Value bc = builder.createMul(loc, b, c); // b*c
695 mlir::Value ad = builder.createMul(loc, a, d); // a*d
696 mlir::Value bcad = builder.createSub(loc, bc, ad); // bc-ad
697 mlir::Value resultImag = builder.createDiv(loc, bcad, ccdd);
698 return builder.createComplexCreate(loc, resultReal, resultImag);
699}
700
701static mlir::Value
703 mlir::Value lhsReal, mlir::Value lhsImag,
704 mlir::Value rhsReal, mlir::Value rhsImag) {
705 // Implements Smith's algorithm for complex division.
706 // SMITH, R. L. Algorithm 116: Complex division. Commun. ACM 5, 8 (1962).
707
708 // Let:
709 // - lhs := a+bi
710 // - rhs := c+di
711 // - result := lhs / rhs = e+fi
712 //
713 // The algorithm pseudocode looks like follows:
714 // if fabs(c) >= fabs(d):
715 // r := d / c
716 // tmp := c + r*d
717 // e = (a + b*r) / tmp
718 // f = (b - a*r) / tmp
719 // else:
720 // r := c / d
721 // tmp := d + r*c
722 // e = (a*r + b) / tmp
723 // f = (b*r - a) / tmp
724
725 mlir::Value &a = lhsReal;
726 mlir::Value &b = lhsImag;
727 mlir::Value &c = rhsReal;
728 mlir::Value &d = rhsImag;
729
730 auto trueBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
731 mlir::Value r = builder.createDiv(loc, d, c); // r := d / c
732 mlir::Value rd = builder.createMul(loc, r, d); // r*d
733 mlir::Value tmp = builder.createAdd(loc, c, rd); // tmp := c + r*d
734
735 mlir::Value br = builder.createMul(loc, b, r); // b*r
736 mlir::Value abr = builder.createAdd(loc, a, br); // a + b*r
737 mlir::Value e = builder.createDiv(loc, abr, tmp);
738
739 mlir::Value ar = builder.createMul(loc, a, r); // a*r
740 mlir::Value bar = builder.createSub(loc, b, ar); // b - a*r
741 mlir::Value f = builder.createDiv(loc, bar, tmp);
742
743 mlir::Value result = builder.createComplexCreate(loc, e, f);
744 builder.createYield(loc, result);
745 };
746
747 auto falseBranchBuilder = [&](mlir::OpBuilder &, mlir::Location) {
748 mlir::Value r = builder.createDiv(loc, c, d); // r := c / d
749 mlir::Value rc = builder.createMul(loc, r, c); // r*c
750 mlir::Value tmp = builder.createAdd(loc, d, rc); // tmp := d + r*c
751
752 mlir::Value ar = builder.createMul(loc, a, r); // a*r
753 mlir::Value arb = builder.createAdd(loc, ar, b); // a*r + b
754 mlir::Value e = builder.createDiv(loc, arb, tmp);
755
756 mlir::Value br = builder.createMul(loc, b, r); // b*r
757 mlir::Value bra = builder.createSub(loc, br, a); // b*r - a
758 mlir::Value f = builder.createDiv(loc, bra, tmp);
759
760 mlir::Value result = builder.createComplexCreate(loc, e, f);
761 builder.createYield(loc, result);
762 };
763
764 auto cFabs = cir::FAbsOp::create(builder, loc, c);
765 auto dFabs = cir::FAbsOp::create(builder, loc, d);
766 cir::CmpOp cmpResult =
767 builder.createCompare(loc, cir::CmpOpKind::ge, cFabs, dFabs);
768 auto ternary = cir::TernaryOp::create(builder, loc, cmpResult,
769 trueBranchBuilder, falseBranchBuilder);
770
771 return ternary.getResult();
772}
773
775 mlir::MLIRContext &context, clang::ASTContext &cc,
776 CIRBaseBuilderTy &builder, mlir::Type elementType) {
777
// NOTE(review): the line carrying this function's return type and name is
// not present in this listing; the parameter list above is its continuation.
//
// Purpose: given the floating-point element type of a complex value, pick a
// wider floating-point type and return it only if the exponent-range check at
// the bottom of this function passes. Otherwise a null mlir::Type is
// returned.
//
// `context` is used to build the candidate types, `cc` supplies the target's
// float formats and the language options. `builder` is not referenced in the
// visible body.

// Maps a float type to the next wider candidate:
//   FP16 -> Single, Single/BF16 -> Double, Double -> LongDouble.
// Any other type is returned unchanged.
778 auto getHigherPrecisionFPType = [&context](mlir::Type type) -> mlir::Type {
779 if (mlir::isa<cir::FP16Type>(type))
780 return cir::SingleType::get(&context);
781
782 if (mlir::isa<cir::SingleType>(type) || mlir::isa<cir::BF16Type>(type))
783 return cir::DoubleType::get(&context);
784
785 if (mlir::isa<cir::DoubleType>(type))
786 return cir::LongDoubleType::get(&context, type);
787
788 return type;
789 };
790
// Looks up the target's llvm::fltSemantics for a CIR float type via
// clang::TargetInfo. LongDouble and FP128 hit llvm_unreachable when both
// OpenMP and OpenMPIsTargetDevice are set; any type not listed below also
// ends in llvm_unreachable.
791 auto getFloatTypeSemantics =
792 [&cc](mlir::Type type) -> const llvm::fltSemantics & {
793 const clang::TargetInfo &info = cc.getTargetInfo();
794 if (mlir::isa<cir::FP16Type>(type))
795 return info.getHalfFormat();
796
797 if (mlir::isa<cir::BF16Type>(type))
798 return info.getBFloat16Format();
799
800 if (mlir::isa<cir::SingleType>(type))
801 return info.getFloatFormat();
802
803 if (mlir::isa<cir::DoubleType>(type))
804 return info.getDoubleFormat();
805
806 if (mlir::isa<cir::LongDoubleType>(type)) {
807 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
808 llvm_unreachable("NYI Float type semantics with OpenMP");
809 return info.getLongDoubleFormat();
810 }
811
812 if (mlir::isa<cir::FP128Type>(type)) {
813 if (cc.getLangOpts().OpenMP && cc.getLangOpts().OpenMPIsTargetDevice)
814 llvm_unreachable("NYI Float type semantics with OpenMP");
815 return info.getFloat128Format();
816 }
817
818 llvm_unreachable("Unsupported float type semantics");
819 };
820
821 const mlir::Type higherElementType = getHigherPrecisionFPType(elementType);
822 const llvm::fltSemantics &elementTypeSemantics =
823 getFloatTypeSemantics(elementType);
824 const llvm::fltSemantics &higherElementTypeSemantics =
825 getFloatTypeSemantics(higherElementType);
826
827 // Check that the promoted type can handle the intermediate values without
828 // overflowing. This can be interpreted as:
829 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal) * 2 <=
830 // LargerType.LargestFiniteVal.
831 // In terms of exponent it gives this formula:
832 // (SmallerType.LargestFiniteVal * SmallerType.LargestFiniteVal
833 // doubles the exponent of SmallerType.LargestFiniteVal)
// When getHigherPrecisionFPType returned the input type unchanged, both
// sides use the same semantics, so the comparison is maxExp * 2 + 1 <= maxExp.
834 if (llvm::APFloat::semanticsMaxExponent(elementTypeSemantics) * 2 + 1 <=
835 llvm::APFloat::semanticsMaxExponent(higherElementTypeSemantics)) {
836 return higherElementType;
837 }
838
839 // The intermediate values can't be represented in the promoted type
840 // without overflowing.
841 return {};
842}
843
// Builds the lowered form of a complex division from the already-extracted
// real/imaginary parts of both operands and returns the resulting complex
// value.
//
// For floating-point element types the strategy is selected by the op's
// ComplexRangeKind:
//   - Improved: buildRangeReductionComplexDiv.
//   - Full:     a call whose leading line is missing from this listing (see
//               the note below).
//   - Promoted: cast the four parts to a higher-precision element type, run
//               the algebraic division there, then cast the result parts back
//               to the original element type. If no higher-precision type is
//               available, fall back to buildRangeReductionComplexDiv.
// Every other case (including non-floating-point element types) uses
// buildAlgebraicComplexDiv.
844static mlir::Value
845lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder,
846 mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal,
847 mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag,
848 mlir::MLIRContext &mlirCx, clang::ASTContext &cc) {
849 cir::ComplexType complexTy = op.getType();
850 if (mlir::isa<cir::FPTypeInterface>(complexTy.getElementType())) {
851 cir::ComplexRangeKind range = op.getRange();
852 if (range == cir::ComplexRangeKind::Improved)
853 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
854 rhsReal, rhsImag);
855
// NOTE(review): the first line of the statement guarded by this `if` is
// absent from this listing; only its trailing arguments are visible.
856 if (range == cir::ComplexRangeKind::Full)
858 loc, complexTy, lhsReal, lhsImag, rhsReal,
859 rhsImag);
860
861 if (range == cir::ComplexRangeKind::Promoted) {
862 mlir::Type originalElementType = complexTy.getElementType();
// NOTE(review): the callee line of this initializer is absent from this
// listing; only its last argument is visible.
863 mlir::Type higherPrecisionElementType =
865 originalElementType);
866
// A null type means no usable promoted type was found.
867 if (!higherPrecisionElementType)
868 return buildRangeReductionComplexDiv(builder, loc, lhsReal, lhsImag,
869 rhsReal, rhsImag);
870
// Widen all four operand parts before dividing.
871 cir::CastKind floatingCastKind = cir::CastKind::floating;
872 lhsReal = builder.createCast(floatingCastKind, lhsReal,
873 higherPrecisionElementType);
874 lhsImag = builder.createCast(floatingCastKind, lhsImag,
875 higherPrecisionElementType);
876 rhsReal = builder.createCast(floatingCastKind, rhsReal,
877 higherPrecisionElementType);
878 rhsImag = builder.createCast(floatingCastKind, rhsImag,
879 higherPrecisionElementType);
880
881 mlir::Value algebraicResult = buildAlgebraicComplexDiv(
882 builder, loc, lhsReal, lhsImag, rhsReal, rhsImag);
883
// Split the wide result and narrow each part back to the original type.
884 mlir::Value resultReal = builder.createComplexReal(loc, algebraicResult);
885 mlir::Value resultImag = builder.createComplexImag(loc, algebraicResult);
886
887 mlir::Value finalReal =
888 builder.createCast(floatingCastKind, resultReal, originalElementType);
889 mlir::Value finalImag =
890 builder.createCast(floatingCastKind, resultImag, originalElementType);
891 return builder.createComplexCreate(loc, finalReal, finalImag);
892 }
893 }
894
895 return buildAlgebraicComplexDiv(builder, loc, lhsReal, lhsImag, rhsReal,
896 rhsImag);
897}
898
899void LoweringPreparePass::lowerComplexDivOp(cir::ComplexDivOp op) {
900 cir::CIRBaseBuilderTy builder(getContext());
901 builder.setInsertionPointAfter(op);
902 mlir::Location loc = op.getLoc();
903 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
904 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
905 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
906 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
907 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
908 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
909
910 mlir::Value loweredResult =
911 lowerComplexDiv(*this, builder, loc, op, lhsReal, lhsImag, rhsReal,
912 rhsImag, getContext(), *astCtx);
913 op.replaceAllUsesWith(loweredResult);
914 op.erase();
915}
916
917static llvm::StringRef
918getComplexMulLibCallName(llvm::APFloat::Semantics semantics) {
919 switch (semantics) {
920 case llvm::APFloat::S_IEEEhalf:
921 return "__mulhc3";
922 case llvm::APFloat::S_IEEEsingle:
923 return "__mulsc3";
924 case llvm::APFloat::S_IEEEdouble:
925 return "__muldc3";
926 case llvm::APFloat::S_PPCDoubleDouble:
927 return "__multc3";
928 case llvm::APFloat::S_x87DoubleExtended:
929 return "__mulxc3";
930 case llvm::APFloat::S_IEEEquad:
931 return "__multc3";
932 default:
933 llvm_unreachable("unsupported floating point type");
934 }
935}
936
// Builds the lowered form of a complex multiplication from the operands'
// real/imaginary parts and returns the resulting complex value.
//
// The algebraic product is always emitted. It is returned directly when the
// element type is an integer type or the op's range kind is Basic, Improved
// or Promoted. Otherwise the result is a cir.ternary that selects a library
// call when both parts of the algebraic product are NaN, and the algebraic
// product when they are not.
937static mlir::Value lowerComplexMul(LoweringPreparePass &pass,
938 CIRBaseBuilderTy &builder,
939 mlir::Location loc, cir::ComplexMulOp op,
940 mlir::Value lhsReal, mlir::Value lhsImag,
941 mlir::Value rhsReal, mlir::Value rhsImag) {
942 // (a+bi) * (c+di) = (ac-bd) + (ad+bc)i
943 mlir::Value resultRealLhs = builder.createMul(loc, lhsReal, rhsReal); // ac
944 mlir::Value resultRealRhs = builder.createMul(loc, lhsImag, rhsImag); // bd
945 mlir::Value resultImagLhs = builder.createMul(loc, lhsReal, rhsImag); // ad
946 mlir::Value resultImagRhs = builder.createMul(loc, lhsImag, rhsReal); // bc
947 mlir::Value resultReal = builder.createSub(loc, resultRealLhs, resultRealRhs);
948 mlir::Value resultImag = builder.createAdd(loc, resultImagLhs, resultImagRhs);
949 mlir::Value algebraicResult =
950 builder.createComplexCreate(loc, resultReal, resultImag);
951
// Cases where the plain algebraic product is the final answer.
952 cir::ComplexType complexTy = op.getType();
953 cir::ComplexRangeKind rangeKind = op.getRange();
954 if (mlir::isa<cir::IntType>(complexTy.getElementType()) ||
955 rangeKind == cir::ComplexRangeKind::Basic ||
956 rangeKind == cir::ComplexRangeKind::Improved ||
957 rangeKind == cir::ComplexRangeKind::Promoted)
958 return algebraicResult;
959
961
962 // Check whether the real part and the imaginary part of the result are both
963 // NaN. If so, emit a library call to compute the multiplication instead.
964 // We check a value against NaN by comparing the value against itself.
965 mlir::Value resultRealIsNaN = builder.createIsNaN(loc, resultReal);
966 mlir::Value resultImagIsNaN = builder.createIsNaN(loc, resultImag);
967 mlir::Value resultRealAndImagAreNaN =
968 builder.createLogicalAnd(loc, resultRealIsNaN, resultImagIsNaN);
969
// True region: library call on the original operand parts.
// False region: yield the algebraic product computed above.
970 return cir::TernaryOp::create(
971 builder, loc, resultRealAndImagAreNaN,
972 [&](mlir::OpBuilder &, mlir::Location) {
973 mlir::Value libCallResult = buildComplexBinOpLibCall(
974 pass, builder, &getComplexMulLibCallName, loc, complexTy,
975 lhsReal, lhsImag, rhsReal, rhsImag);
976 builder.createYield(loc, libCallResult);
977 },
978 [&](mlir::OpBuilder &, mlir::Location) {
979 builder.createYield(loc, algebraicResult);
980 })
981 .getResult();
982}
983
984void LoweringPreparePass::lowerComplexMulOp(cir::ComplexMulOp op) {
985 cir::CIRBaseBuilderTy builder(getContext());
986 builder.setInsertionPointAfter(op);
987 mlir::Location loc = op.getLoc();
988 mlir::TypedValue<cir::ComplexType> lhs = op.getLhs();
989 mlir::TypedValue<cir::ComplexType> rhs = op.getRhs();
990 mlir::Value lhsReal = builder.createComplexReal(loc, lhs);
991 mlir::Value lhsImag = builder.createComplexImag(loc, lhs);
992 mlir::Value rhsReal = builder.createComplexReal(loc, rhs);
993 mlir::Value rhsImag = builder.createComplexImag(loc, rhs);
994 mlir::Value loweredResult = lowerComplexMul(*this, builder, loc, op, lhsReal,
995 lhsImag, rhsReal, rhsImag);
996 op.replaceAllUsesWith(loweredResult);
997 op.erase();
998}
999
1000void LoweringPreparePass::lowerUnaryOp(cir::UnaryOpInterface op) {
1001 if (!mlir::isa<cir::ComplexType>(op.getResult().getType()))
1002 return;
1003
1004 mlir::Location loc = op->getLoc();
1005 CIRBaseBuilderTy builder(getContext());
1006 builder.setInsertionPointAfter(op);
1007
1008 mlir::Value operand = op.getInput();
1009 mlir::Value operandReal = builder.createComplexReal(loc, operand);
1010 mlir::Value operandImag = builder.createComplexImag(loc, operand);
1011
1012 mlir::Value resultReal = operandReal;
1013 mlir::Value resultImag = operandImag;
1014
1015 llvm::TypeSwitch<mlir::Operation *>(op)
1016 .Case<cir::IncOp>(
1017 [&](auto) { resultReal = builder.createInc(loc, operandReal); })
1018 .Case<cir::DecOp>(
1019 [&](auto) { resultReal = builder.createDec(loc, operandReal); })
1020 .Case<cir::MinusOp>([&](auto) {
1021 resultReal = builder.createMinus(loc, operandReal);
1022 resultImag = builder.createMinus(loc, operandImag);
1023 })
1024 .Case<cir::NotOp>(
1025 [&](auto) { resultImag = builder.createMinus(loc, operandImag); })
1026 .Default([](auto) { llvm_unreachable("unhandled unary complex op"); });
1027
1028 mlir::Value result = builder.createComplexCreate(loc, resultReal, resultImag);
1029 op->replaceAllUsesWith(mlir::ValueRange{result});
1030 op->erase();
1031}
1032
// Returns the function to be used as the destructor for global `op`, given
// the global's dtor region.
//
// If the dtor region's first block is exactly
// `get_global; call(get_global); yield`, the called function is returned and
// `dtorCall` is set to that existing call. Otherwise the region's contents
// are moved into a new internal-linkage helper function taking one void*
// argument, the region is rewritten to
// `get_global; call helper(get_global); yield`, `dtorCall` is set to the new
// call, and the helper is returned.
//
// The builder's insertion point is restored on exit by the InsertionGuard.
1033cir::FuncOp LoweringPreparePass::getOrCreateDtorFunc(CIRBaseBuilderTy &builder,
1034 cir::GlobalOp op,
1035 mlir::Region &dtorRegion,
1036 cir::CallOp &dtorCall) {
1037 mlir::OpBuilder::InsertionGuard guard(builder);
1039
1040 cir::VoidType voidTy = builder.getVoidTy();
1041 auto voidPtrTy = cir::PointerType::get(voidTy);
1042
1043 // Look for operations in dtorBlock
1044 mlir::Block &dtorBlock = dtorRegion.front();
1045
1046 // The first operation should be a get_global to retrieve the address
1047 // of the global variable we're destroying.
// mlir::cast asserts if the first operation is anything else.
1048 auto opIt = dtorBlock.getOperations().begin();
1049 cir::GetGlobalOp ggop = mlir::cast<cir::GetGlobalOp>(*opIt);
1050
1051 // The simple case is just a call to a destructor, like this:
1052 //
1053 // %0 = cir.get_global %globalS : !cir.ptr<!rec_S>
1054 // cir.call %_ZN1SD1Ev(%0) : (!cir.ptr<!rec_S>) -> ()
1055 // (implicit cir.yield)
1056 //
1057 // That is, if the second operation is a call that takes the get_global result
1058 // as its only operand, and the only other operation is a yield, then we can
1059 // just return the called function.
1060 if (dtorBlock.getOperations().size() == 3) {
1061 auto callOp = mlir::dyn_cast<cir::CallOp>(&*(++opIt));
1062 auto yieldOp = mlir::dyn_cast<cir::YieldOp>(&*(++opIt));
1063 if (yieldOp && callOp && callOp.getNumOperands() == 1 &&
1064 callOp.getArgOperand(0) == ggop) {
1065 dtorCall = callOp;
1066 return getCalledFunction(callOp);
1067 }
1068 }
1069
1070 // Otherwise, we need to create a helper function to replace the dtor region.
1071 // This name is kind of arbitrary, but it matches the name that classic
1072 // codegen uses, based on the expected case that gets us here.
// The per-name counter makes second and later helpers unique by appending
// ".<n>".
1073 builder.setInsertionPointAfter(op);
1074 SmallString<256> fnName("__cxx_global_array_dtor");
1075 uint32_t cnt = dynamicInitializerNames[fnName]++;
1076 if (cnt)
1077 fnName += "." + std::to_string(cnt);
1078
1079 // Create the helper function.
1080 auto fnType = cir::FuncType::get({voidPtrTy}, voidTy);
1081 cir::FuncOp dtorFunc =
1082 buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1083 cir::GlobalLinkageKind::InternalLinkage);
1084
// Tag the helper's single parameter with the `llvm.noundef` unit attribute.
1085 SmallVector<mlir::NamedAttribute> paramAttrs;
1086 paramAttrs.push_back(
1087 builder.getNamedAttr("llvm.noundef", builder.getUnitAttr()));
1088 SmallVector<mlir::Attribute> argAttrDicts;
1089 argAttrDicts.push_back(
1090 mlir::DictionaryAttr::get(builder.getContext(), paramAttrs));
1091 dtorFunc.setArgAttrsAttr(
1092 mlir::ArrayAttr::get(builder.getContext(), argAttrDicts));
1093
1094 mlir::Block *entryBB = dtorFunc.addEntryBlock();
1095
1096 // Move everything from the dtor region into the helper function.
// Only the operations of the region's first block are spliced here.
1097 entryBB->getOperations().splice(entryBB->begin(), dtorBlock.getOperations(),
1098 dtorBlock.begin(), dtorBlock.end());
1099
1100 // Before erasing this, clone it back into the dtor region
1101 cir::GetGlobalOp dtorGGop =
1102 mlir::cast<cir::GetGlobalOp>(entryBB->getOperations().front());
1103 builder.setInsertionPointToStart(&dtorBlock);
1104 builder.clone(*dtorGGop.getOperation());
1105
1106 // Replace all uses of the helper function's get_global with the function
1107 // argument.
1108 mlir::Value dtorArg = entryBB->getArgument(0);
1109 dtorGGop.replaceAllUsesWith(dtorArg);
1110 dtorGGop.erase();
1111
1112 // Replace the yield in the final block with a return
1113 mlir::Block &finalBlock = dtorFunc.getBody().back();
1114 auto yieldOp = cast<cir::YieldOp>(finalBlock.getTerminator());
1115 builder.setInsertionPoint(yieldOp);
1116 cir::ReturnOp::create(builder, yieldOp->getLoc());
1117 yieldOp->erase();
1118
1119 // Create a call to the helper function, passing the original get_global op
1120 // as the argument.
// The front of dtorBlock is now the get_global that was cloned back above.
1121 cir::GetGlobalOp origGGop =
1122 mlir::cast<cir::GetGlobalOp>(dtorBlock.getOperations().front());
1123 builder.setInsertionPointAfter(origGGop);
1124 mlir::Value ggopResult = origGGop.getResult();
1125 dtorCall = builder.createCallOp(op.getLoc(), dtorFunc, ggopResult);
1126
1127 // Add a yield after the call.
1128 auto finalYield = cir::YieldOp::create(builder, op.getLoc());
1129
1130 // Erase everything after the yield.
// This covers both any trailing operations in the first block and every
// block after the first one in the region.
1131 dtorBlock.getOperations().erase(std::next(mlir::Block::iterator(finalYield)),
1132 dtorBlock.end());
1133 dtorRegion.getBlocks().erase(std::next(dtorRegion.begin()), dtorRegion.end());
1134
1135 return dtorFunc;
1136}
1137
// Creates an internal-linkage `void()` function named `__cxx_global_var_init`
// (suffixed with ".<n>" for the second and later ones) right after global
// `op`, and fills it from the global's ctor/dtor regions:
//   - if the global's DynTlsRefs carry a guard name, the body is wrapped in
//     the `if` produced by buildGlobalTlsGuardCheck;
//   - the ctor region's first block (minus its last operation) is spliced in;
//   - a non-empty dtor region is handed to emitGlobalGuardedDtorRegion;
//   - the function is terminated with a cir.return.
// The regions themselves are not cleared here.
1138cir::FuncOp
1139LoweringPreparePass::buildCXXGlobalVarDeclInitFunc(cir::GlobalOp op) {
1140 // TODO(cir): Store this in the GlobalOp.
1141 // This should come from the MangleContext, but for now I'm hardcoding it.
1142 SmallString<256> fnName("__cxx_global_var_init");
1143 // Get a unique name
1144 uint32_t cnt = dynamicInitializerNames[fnName]++;
1145 if (cnt)
1146 fnName += "." + std::to_string(cnt);
1147
1148 // Create a variable initialization function.
1149 CIRBaseBuilderTy builder(getContext());
1150 builder.setInsertionPointAfter(op);
1151 cir::VoidType voidTy = builder.getVoidTy();
1152 auto fnType = cir::FuncType::get({}, voidTy);
1153 FuncOp f = buildRuntimeFunction(builder, fnName, op.getLoc(), fnType,
1154 cir::GlobalLinkageKind::InternalLinkage);
1155
1156 // Move over the initialization code of the ctor region.
1157 // The ctor region may have multiple blocks when exception handling
1158 // scaffolding creates extra blocks (e.g., unreachable/trap blocks).
1159 // We move all operations from the first block (minus the yield) into
1160 // the function entry, and discard extra blocks (which contain only
1161 // unreachable terminators from EH cleanup paths).
1162 mlir::Block *entryBB = f.addEntryBlock();
1163 builder.setInsertionPointToStart(entryBB);
1164
1165 // If this is a global TLS variable (that is, declared at namespace scope), we
1166 // have to emit the guard variable here.
1167 bool needsTlsGuard = op.getDynTlsRefs() && op.getDynTlsRefs()->getGuardName();
1168 cir::IfOp guardIf;
1169 if (needsTlsGuard) {
1170 guardIf = buildGlobalTlsGuardCheck(
1171 builder, op.getLoc(),
1172 getOrCreateStaticLocalDeclGuardAddress(
1173 builder, op, op.getDynTlsRefs()->getGuardName().getValue(),
1174 /*isLocalVarDecl=*/false,
1175 /*useInt8GuardVariable=*/op.hasInternalLinkage()));
// From here on, code is emitted into the guard's then-region.
1176 builder.setInsertionPointToEnd(&guardIf.getThenRegion().front());
1177 }
1178
// Splice everything except the last operation of the ctor block into the
// current insertion block (the function entry, or the guard's then-block).
1179 if (!op.getCtorRegion().empty()) {
1180 mlir::Block &block = op.getCtorRegion().front();
1181 mlir::Block *insertBlock = builder.getBlock();
1182 insertBlock->getOperations().splice(insertBlock->end(),
1183 block.getOperations(), block.begin(),
1184 std::prev(block.end()));
1185 }
1186
1187 // Register the destructor call with __cxa_atexit
1188 mlir::Region &dtorRegion = op.getDtorRegion();
1189 if (!dtorRegion.empty()) {
1191
1192 emitGlobalGuardedDtorRegion(builder, op, dtorRegion,
1193 op.getTlsModel().has_value(),
1194 *builder.getBlock());
1195 }
1196
1197 // If we're actually in the 'if' above, create a yield.
1198 if (needsTlsGuard) {
1199 builder.setInsertionPointToEnd(&guardIf.getThenRegion().back());
1200 cir::YieldOp::create(builder, op.getLoc());
1201 }
1202
1203 // Replace cir.yield with cir.return
// The yield is only looked up for its location; the return is appended to
// the function's entry block and the yield is left in its region.
1204 builder.setInsertionPointToEnd(entryBB);
1205 mlir::Operation *yieldOp = nullptr;
1206 if (!op.getCtorRegion().empty()) {
1207 mlir::Block &block = op.getCtorRegion().front();
1208 yieldOp = &block.getOperations().back();
1209 } else {
1210 assert(!dtorRegion.empty());
1211 mlir::Block &block = dtorRegion.front();
1212 yieldOp = &block.getOperations().back();
1213 }
1214
1215 assert(isa<cir::YieldOp>(*yieldOp));
1216 cir::ReturnOp::create(builder, yieldOp->getLoc());
1217 return f;
1218}
1219
1220cir::FuncOp
1221LoweringPreparePass::getGuardAcquireFn(cir::PointerType guardPtrTy) {
1222 // int __cxa_guard_acquire(__guard *guard_object);
1223 CIRBaseBuilderTy builder(getContext());
1224 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1225 builder.setInsertionPointToStart(mlirModule.getBody());
1226 mlir::Location loc = mlirModule.getLoc();
1227 cir::IntType intTy = cir::IntType::get(&getContext(), 32, /*isSigned=*/true);
1228 auto fnType = cir::FuncType::get({guardPtrTy}, intTy);
1229 return buildRuntimeFunction(builder, "__cxa_guard_acquire", loc, fnType);
1230}
1231
1232cir::FuncOp
1233LoweringPreparePass::getGuardReleaseFn(cir::PointerType guardPtrTy) {
1234 // void __cxa_guard_release(__guard *guard_object);
1235 CIRBaseBuilderTy builder(getContext());
1236 mlir::OpBuilder::InsertionGuard ipGuard{builder};
1237 builder.setInsertionPointToStart(mlirModule.getBody());
1238 mlir::Location loc = mlirModule.getLoc();
1239 cir::VoidType voidTy = cir::VoidType::get(&getContext());
1240 auto fnType = cir::FuncType::get({guardPtrTy}, voidTy);
1241 return buildRuntimeFunction(builder, "__cxa_guard_release", loc, fnType);
1242}
1243
1244cir::FuncOp LoweringPreparePass::getTlsInitFn() {
1245 // void __tls_init(void);
1246 CIRBaseBuilderTy builder(getContext());
1247 mlir::OpBuilder::InsertionGuard _{builder};
1248 builder.setInsertionPointToStart(mlirModule.getBody());
1249 mlir::Location loc = mlirModule.getLoc();
1250 auto fnType = builder.getVoidFnTy();
1251 return buildRuntimeFunction(builder, "__tls_init", loc, fnType,
1252 cir::GlobalLinkageKind::InternalLinkage);
1253}
1254
1255cir::GlobalOp LoweringPreparePass::createGuardGlobalOp(
1256 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef name,
1257 cir::IntType guardTy, cir::GlobalLinkageKind linkage) {
1258 mlir::OpBuilder::InsertionGuard guard(builder);
1259 builder.setInsertionPointToStart(mlirModule.getBody());
1260 cir::GlobalOp g = cir::GlobalOp::create(builder, loc, name, guardTy);
1261 g.setLinkageAttr(
1262 cir::GlobalLinkageKindAttr::get(builder.getContext(), linkage));
1263 mlir::SymbolTable::setSymbolVisibility(
1264 g, mlir::SymbolTable::Visibility::Private);
1265 return g;
1266}
1267
// Emits the guarded initialization for the variable `globalOp` at the position
// of `localInitOp`.
//
// The code loads the first byte of the guard variable, compares it against
// zero (after masking bit 0 when the ARM guard-variable ABI applies), and
// wraps the initialization emitted by emitCXXGuardedInitIf in a cir.if on
// that comparison. The block terminator is detached while the new code is
// appended and re-attached at the end.
//
// Two configurations are reported as NYI via emitError: non-template inline
// variables, and thread-safe statics on targets where maxInlineWidthInBits is
// zero.
1268void LoweringPreparePass::handleStaticLocal(cir::GlobalOp globalOp,
1269 cir::LocalInitOp localInitOp) {
1270 CIRBaseBuilderTy builder(getContext());
1271
1272 std::optional<cir::ASTVarDeclInterface> astOption = globalOp.getAst();
1273 assert(astOption.has_value());
1274 cir::ASTVarDeclInterface varDecl = astOption.value();
1275
1276 builder.setInsertionPointAfter(localInitOp);
1277 mlir::Block *localInitBlock = builder.getInsertionBlock();
1278
1279 // Remove the terminator temporarily - we'll add it back at the end.
1280 mlir::Operation *ret = localInitBlock->getTerminator();
1281 ret->remove();
1282 // Note: These two insert-point-after sets are necessary, as the 'trailing'
1283 // operation has changed thanks to the terminator removal.
1284 builder.setInsertionPointAfter(localInitOp);
1285
1286 // Inline variables that weren't instantiated from variable templates have
1287 // partially-ordered initialization within their translation unit.
1288 bool nonTemplateInline =
1289 varDecl.isInline() &&
1290 !clang::isTemplateInstantiation(varDecl.getTemplateSpecializationKind());
1291
1292 // Inline namespace-scope variables require guarded initialization in a
1293 // __cxx_global_var_init function. This is not yet implemented.
// NOTE(review): `ret` was detached above and is not re-attached on this
// return path, whereas the `!guard` path below does re-attach it. Confirm
// whether leaving the block without its terminator here is intended.
1294 if (nonTemplateInline) {
1295 globalOp->emitError(
1296 "NYI: guarded initialization for inline namespace-scope variables");
1297 return;
1298 }
1299
1300 // We only need to use thread-safe statics for local non-TLS variables and
1301 // inline variables; other global initialization is always single-threaded
1302 // or (through lazy dynamic loading in multiple threads) unsequenced.
// `nonTemplateInline` is always false here because of the early return above.
1303 bool threadsafe = astCtx->getLangOpts().ThreadsafeStatics &&
1304 (varDecl.isLocalVarDecl() || nonTemplateInline) &&
1305 !varDecl.getTLSKind();
1306
1307 // If we have a global variable with internal linkage and thread-safe statics
1308 // are disabled, we can just let the guard variable be of type i8.
1309 bool useInt8GuardVariable = !threadsafe && globalOp.hasInternalLinkage();
1310
1311 // Create the guard variable if we don't already have it.
1312 cir::GlobalOp guard = getOrCreateStaticLocalDeclGuardAddress(
1313 builder, globalOp, globalOp.getStaticLocalGuard()->getName().getValue(),
1314 varDecl.isLocalVarDecl(), useInt8GuardVariable);
1315 if (!guard) {
1316 // Error was already emitted, just restore the terminator and return.
1317 localInitBlock->push_back(ret);
1318 return;
1319 }
1320
1321 mlir::Value guardPtr = builder.createGetGlobal(guard, localInitOp.getTls());
1322
1323 // Test whether the variable has completed initialization.
1324 //
1325 // Itanium C++ ABI 3.3.2:
1326 // The following is pseudo-code showing how these functions can be used:
1327 // if (obj_guard.first_byte == 0) {
1328 // if ( __cxa_guard_acquire (&obj_guard) ) {
1329 // try {
1330 // ... initialize the object ...;
1331 // } catch (...) {
1332 // __cxa_guard_abort (&obj_guard);
1333 // throw;
1334 // }
1335 // ... queue object destructor with __cxa_atexit() ...;
1336 // __cxa_guard_release (&obj_guard);
1337 // }
1338 // }
1339 //
1340 // If threadsafe statics are enabled, but we don't have inline atomics, just
1341 // call __cxa_guard_acquire unconditionally. The "inline" check isn't
1342 // actually inline, and the user might not expect calls to __atomic libcalls.
// NOTE(review): the initializer expression of this declaration is absent
// from this listing.
1343 unsigned maxInlineWidthInBits =
1345
1346 if (!threadsafe || maxInlineWidthInBits) {
1347 // Load the first byte of the guard variable.
1348 auto bytePtrTy = cir::PointerType::get(builder.getSIntNTy(8));
1349 mlir::Value bytePtr = builder.createBitcast(guardPtr, bytePtrTy);
1350 mlir::Value guardLoad = builder.createAlignedLoad(
1351 localInitOp.getLoc(), bytePtr, *guard.getAlignment());
1352
1353 // Itanium ABI:
1354 // An implementation supporting thread-safety on multiprocessor
1355 // systems must also guarantee that references to the initialized
1356 // object do not occur before the load of the initialization flag.
1357 //
1358 // In LLVM, we do this by marking the load Acquire.
1359 if (threadsafe) {
1360 auto loadOp = mlir::cast<cir::LoadOp>(guardLoad.getDefiningOp());
1361 loadOp.setMemOrder(cir::MemOrder::Acquire);
1362 loadOp.setSyncScope(cir::SyncScopeKind::System);
1363 }
1364
1365 // For ARM, we should only check the first bit, rather than the entire byte:
1366 //
1367 // ARM C++ ABI 3.2.3.1:
1368 // To support the potential use of initialization guard variables
1369 // as semaphores that are the target of ARM SWP and LDREX/STREX
1370 // synchronizing instructions we define a static initialization
1371 // guard variable to be a 4-byte aligned, 4-byte word with the
1372 // following inline access protocol.
1373 // #define INITIALIZED 1
1374 // if ((obj_guard & INITIALIZED) != INITIALIZED) {
1375 // if (__cxa_guard_acquire(&obj_guard))
1376 // ...
1377 // }
1378 //
1379 // and similarly for ARM64:
1380 //
1381 // ARM64 C++ ABI 3.2.2:
1382 // This ABI instead only specifies the value bit 0 of the static guard
1383 // variable; all other bits are platform defined. Bit 0 shall be 0 when
1384 // the variable is not initialized and 1 when it is.
1385 if (useARMGuardVarABI() && !useInt8GuardVariable) {
1386 auto one = builder.getConstantInt(
1387 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()),
1388 1);
1389 guardLoad = builder.createAnd(localInitOp.getLoc(), guardLoad, one);
1390 }
1391
1392 // Check if the first byte of the guard variable is zero.
1393 auto zero = builder.getConstantInt(
1394 localInitOp.getLoc(), mlir::cast<cir::IntType>(guardLoad.getType()), 0);
1395 auto needsInit = builder.createCompare(localInitOp.getLoc(),
1396 cir::CmpOpKind::eq, guardLoad, zero);
1397
1398 // Build the guarded initialization inside an if block.
1399 cir::IfOp::create(
1400 builder, globalOp.getLoc(), needsInit,
1401 /*withElseRegion=*/false, [&](mlir::OpBuilder &, mlir::Location) {
1402 emitCXXGuardedInitIf(builder, globalOp, localInitOp.getCtorRegion(),
1403 localInitOp.getDtorRegion(), varDecl, guardPtr,
1404 builder.getPointerTo(guard.getSymType()),
1405 threadsafe);
1406 });
1407 } else {
1408 // Threadsafe statics without inline atomics - call __cxa_guard_acquire
1409 // unconditionally without the initial guard byte check.
// NOTE(review): as with the inline-variable NYI path above, `ret` is not
// re-attached before this return.
1410 globalOp->emitError("NYI: guarded init without inline atomics support");
1411 return;
1412 }
1413
1414 // Insert the removed terminator back.
1415 builder.getInsertionBlock()->push_back(ret);
1416}
1417
1418void LoweringPreparePass::lowerLocalInitOp(cir::LocalInitOp initOp) {
1419
1420 // If we don't actually need to initialize anything anymore, we're done here.
1421 if (initOp.getCtorRegion().empty() && initOp.getDtorRegion().empty()) {
1422 initOp.erase();
1423 return;
1424 }
1425
1426 cir::GlobalOp globalOp = initOp.getReferencedGlobal(symbolTables);
1427 assert(globalOp && "No global-op found");
1428
1429 handleStaticLocal(globalOp, initOp);
1430
1431 // Remove the init local op, now that we've done everything we need with it.
1432 initOp.erase();
1433}
1434static bool isThreadWrapperReplaceable(cir::TLS_Model tls,
1435 clang::ASTContext &astCtx) {
1436 return tls == cir::TLS_Model::GeneralDynamic &&
1437 astCtx.getTargetInfo().getTriple().isOSDarwin();
1438}
1439
// Chooses the linkage for the thread-local wrapper function of global `op`:
//   - a global with local linkage keeps its own linkage;
//   - when the wrapper is replaceable (see isThreadWrapperReplaceable) and
//     the global is neither linkonce nor weak_odr, the global's linkage is
//     reused;
//   - otherwise linkonce_odr for a declaration and weak_odr for a definition.
// NOTE(review): the line carrying this function's name and parameter list is
// absent from this listing.
1440static cir::GlobalLinkageKind
1442 if (isLocalLinkage(op.getLinkage()))
1443 return op.getLinkage();
1444
// The result of getTlsModel() is dereferenced without a presence check.
// NOTE(review): presumably callers only pass TLS globals; confirm.
1445 if (isThreadWrapperReplaceable(*op.getTlsModel(), astCtx))
1446 if (!isLinkOnceLinkage(op.getLinkage()) &&
1447 !isWeakODRLinkage(op.getLinkage()))
1448 return op.getLinkage();
1449
1450 // If this isn't a TU in which this variable is defined, the thread wrapper is
1451 // discardable.
1452 if (op.isDeclaration())
1453 return cir::GlobalLinkageKind::LinkOnceODRLinkage;
1454 return cir::GlobalLinkageKind::WeakODRLinkage;
1455}
1456
1457cir::FuncOp
1458LoweringPreparePass::getOrCreateThreadLocalWrapper(CIRBaseBuilderTy &builder,
1459 GlobalOp op) {
1460 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1461 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
1462
1463 mlir::StringAttr wrapperName = op.getDynTlsRefs()->getWrapperName();
1464
1465 auto existingWrapperIter = threadLocalWrappers.find(wrapperName.getValue());
1466 if (existingWrapperIter != threadLocalWrappers.end())
1467 return existingWrapperIter->second;
1468
1469 // type is ptr-to-global-type(void);
1470 auto funcType = cir::FuncType::get({}, builder.getPointerTo(op.getSymType()));
1471 cir::FuncOp func =
1472 cir::FuncOp::create(builder, op.getLoc(), wrapperName, funcType);
1473
1474 cir::GlobalLinkageKind linkageKind =
1475 getThreadLocalWrapperLinkage(op, *astCtx);
1476 func.setLinkageAttr(
1477 cir::GlobalLinkageKindAttr::get(&getContext(), linkageKind));
1478
1479 // TODO(cir): This is supposed to refer to the comdat of the global symbol,
1480 // but that isn't in CIR yet.
1481 if (astCtx->getTargetInfo().getTriple().supportsCOMDAT() &&
1482 func.isWeakForLinker())
1483 func.setComdat(true);
1484
1485 mlir::SymbolTable::setSymbolVisibility(
1486 func, mlir::SymbolTable::Visibility::Private);
1487
1488 if (!isLocalLinkage(linkageKind)) {
1489 if (!isThreadWrapperReplaceable(*op.getTlsModel(), *astCtx) ||
1490 isLinkOnceLinkage(linkageKind) || isWeakODRLinkage(linkageKind) ||
1491 op.getGlobalVisibility() == cir::VisibilityKind::Hidden)
1492 func.setGlobalVisibility(cir::VisibilityKind::Hidden);
1493 }
1494 if (isThreadWrapperReplaceable(*op.getTlsModel(), *astCtx))
1495 op->emitError("Unhandled thread wrapper attributes for CC and Nounwind");
1496
1497 threadLocalWrappers.insert({wrapperName.getValue(), func});
1498 return func;
1499}
1500
1501void LoweringPreparePass::defineGlobalThreadLocalWrapper(cir::GlobalOp op,
1502 cir::FuncOp initAlias,
1503 bool isVarDefinition) {
1504 CIRBaseBuilderTy builder(getContext());
1505 cir::FuncOp wrapper = getOrCreateThreadLocalWrapper(builder, op);
1506 mlir::Block *entryBB = wrapper.addEntryBlock();
1507 builder.setInsertionPointToStart(entryBB);
1508 // If we are a situation where we have/need one, emit a call to the init
1509 // function.
1510 if (initAlias) {
1511 mlir::Location aliasLoc = initAlias.getLoc();
1512 if (!isVarDefinition) {
1513 // If this isn't a definition, we have to check that the alias exists.
1514 mlir::Value funcLoad = cir::GetGlobalOp::create(
1515 builder, aliasLoc, cir::PointerType::get(initAlias.getFunctionType()),
1516 initAlias.getSymName());
1517 mlir::Value nullCheck =
1518 builder.getNullValue(funcLoad.getType(), aliasLoc);
1519 mlir::Value cmp = cir::CmpOp::create(
1520 builder, aliasLoc, cir::CmpOpKind::ne, funcLoad, nullCheck);
1521 cir::IfOp::create(builder, aliasLoc, cmp, /*withElseRegion=*/false,
1522 [&](mlir::OpBuilder &, mlir::Location loc) {
1523 builder.createCallOp(aliasLoc, initAlias, {});
1524 cir::YieldOp::create(builder, aliasLoc);
1525 });
1526 } else {
1527 // If this IS a definition, we know the alias exists, so we can just emit
1528 // a call to it.
1529 builder.createCallOp(aliasLoc, initAlias, {});
1530 }
1531 }
1532 cir::GetGlobalOp get = builder.createGetGlobal(op, /*tls=*/true);
1533 cir::ReturnOp::create(builder, op.getLoc(), {get});
1534}
1535
1536cir::FuncOp
1537LoweringPreparePass::defineGlobalThreadLocalInitAlias(cir::GlobalOp op,
1538 cir::FuncOp aliasee) {
1539 CIRBaseBuilderTy builder(getContext());
1540 mlir::OpBuilder::InsertionGuard insertGuard(builder);
1541 builder.setInsertionPointToStart(&mlirModule.getBodyRegion().front());
1542 mlir::StringAttr aliasName = op.getDynTlsRefs()->getInitName();
1543 auto existingAliasIter = threadLocalInitAliases.find(aliasName.getValue());
1544
1545 if (existingAliasIter != threadLocalInitAliases.end())
1546 return existingAliasIter->second;
1547
1548 auto funcType = builder.getVoidFnTy();
1549 cir::FuncOp alias =
1550 cir::FuncOp::create(builder, op.getLoc(), aliasName, funcType);
1551 alias.setLinkage(op.getLinkage());
1552
1553 if (aliasee) {
1554 alias.setAliasee(aliasee.getSymName());
1555 } else {
1556 // If we don't have anything to alias (because this isn't a variable
1557 // definition!), we set this as just a function definition with no alias,
1558 // and extern-weak.
1559 alias.setLinkage(cir::GlobalLinkageKind::ExternalWeakLinkage);
1560 mlir::SymbolTable::setSymbolVisibility(
1561 alias, mlir::SymbolTable::Visibility::Private);
1562 }
1563
1564 threadLocalInitAliases.insert({aliasName.getValue(), alias});
1565 return alias;
1566}
1567
1568void LoweringPreparePass::lowerGlobalOp(GlobalOp op) {
1569 // Static locals are handled separately via guard variables.
1570 if (op.getStaticLocalGuard())
1571 return;
1572
1573 mlir::Region &ctorRegion = op.getCtorRegion();
1574 mlir::Region &dtorRegion = op.getDtorRegion();
1575 cir::FuncOp initAlias;
1576
1577 if (!ctorRegion.empty() || !dtorRegion.empty()) {
1578 // Build a variable initialization function and move the initialzation code
1579 // in the ctor region over.
1580 cir::FuncOp f = buildCXXGlobalVarDeclInitFunc(op);
1581
1582 // Clear the ctor and dtor region
1583 ctorRegion.getBlocks().clear();
1584 dtorRegion.getBlocks().clear();
1585
1587 if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1588 !op.getStaticLocalGuard().has_value()) {
1589 // There are two types of global TLS variables: 'ordered' and 'unordered'.
1590 // 'ordered' are the common case. A call to any of them causes all of the
1591 // initializers for all other 'ordered' ones to be called, via a
1592 // `__tls_init` function. So the 'init alias' that gets called in the
1593 // wrapper for these goes directly to `__tls_init`.
1594
1595 // 'Unordered' values are the case for variable templates. In this case,
1596 // their init alias goes directly to their init function. The FE generates
1597 // a guard variable for them (since they cannot use the global guard), so
1598 // we differentiate them that way.
1599
1600 if (op.getDynTlsRefs()->getGuardName()) {
1601 // Unordered: the alias is the function we just generated.
1602 initAlias = defineGlobalThreadLocalInitAlias(op, f);
1603 } else {
1604 // Ordered: Get the __tls_init, and make the alias to that.
1605 initAlias = defineGlobalThreadLocalInitAlias(op, getTlsInitFn());
1606 // Ordered inits also need to get called from the __tls_init function,
1607 // so we add the init function to the list, so that we can add them to
1608 // it later.
1609 globalThreadLocalInitializers.push_back(f);
1610 }
1611 } else {
1612 dynamicInitializers.push_back(f);
1613 }
1614 } else if (op.getTlsModel() == TLS_Model::GeneralDynamic &&
1615 op.getDynTlsRefs() && op.isDeclaration()) {
1616 // If this is a declaration and has no init function, we probably DO have to
1617 // create an alias that needs checking, so create it as extern-weak.
1618 initAlias = defineGlobalThreadLocalInitAlias(op, {});
1619 }
1620
1621 // We need a wrapper for TLS globals that MIGHT have a non-constant
1622 // initialization. The FE will have generated the DynTlsRefs for any with
1623 // known dynamic init, or unknown (extern) init.
1624 if (op.getTlsModel() == TLS_Model::GeneralDynamic && op.getDynTlsRefs())
1625 defineGlobalThreadLocalWrapper(op, initAlias, !op.isDeclaration());
1626
1628}
1629
1630void LoweringPreparePass::lowerGetGlobalOp(GetGlobalOp op) {
1631 if (!op.getTls())
1632 return;
1633 auto globalOp = mlir::cast<cir::GlobalOp>(
1634 symbolTables.lookupNearestSymbolFrom(op, op.getNameAttr()));
1635
1636 // Only global/namespace scope thread local variables need to have their
1637 // get-global operations rewritten to be calls to a wrapper function. If
1638 // we're not in a dynamic TLS (or one without the TLS markers), we can leave
1639 // this one as a get-global and return early.
1640 if (globalOp.getTlsModel() != TLS_Model::GeneralDynamic ||
1641 !globalOp.getDynTlsRefs())
1642 return;
1643
1644 // If this is a global TLS, we need to replace the call to 'get_global' with a
1645 // call to the wrapper function. Classic codegen figures out some cases where
1646 // we can omit this, but for now we're going to always put it in, as it is
1647 // effectively a no-op.
1648
1649 // The first 'GetGlobalOp' at the beginning of a ctor/dtor region on one of
1650 // these is for the purpose of creating/destroying. We want to skip replacing
1651 // THAT one, but leave all other get-global-ops in place, else
1652 // self-referential ops won't work right.
1653
1654 // Note that ctors/dtors are removed during this pass. We get away with these
1655 // checks because the only time that these situations can actually be true
1656 // (that is, the ctor/dtor region exist) is if we're in the process of
1657 // converting the ctor/dtor for this. If we're NOT doing that, the ctor/dtor
1658 // will have already disappeared.
1659 mlir::Operation *parentOp = op->getParentOp();
1660 if (parentOp == globalOp) {
1661 mlir::Region *ctorRegion = &globalOp.getCtorRegion();
1662 mlir::Region *dtorRegion = &globalOp.getDtorRegion();
1663
1664 if (!ctorRegion->empty() && &*ctorRegion->op_begin() == op.getOperation())
1665 return;
1666 if (!dtorRegion->empty() && &*dtorRegion->op_begin() == op.getOperation())
1667 return;
1668 }
1669
1670 CIRBaseBuilderTy builder(getContext());
1671 cir::FuncOp wrapperFunc = getOrCreateThreadLocalWrapper(builder, globalOp);
1672
1673 builder.setInsertionPoint(op);
1674 cir::CallOp call = builder.createCallOp(
1675 wrapperFunc.getLoc(),
1676 mlir::FlatSymbolRefAttr::get(wrapperFunc.getSymNameAttr()),
1677 wrapperFunc.getFunctionType().getReturnType(), {});
1678 op->replaceAllUsesWith(call);
1679 op.erase();
1680}
1681
1682void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) {
1683 CIRBaseBuilderTy builder(getContext());
1684 builder.setInsertionPointAfter(op);
1685
1686 mlir::Location loc = op->getLoc();
1687 cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo();
1688
1689 mlir::Value ltRes =
1690 builder.getConstantInt(loc, op.getType(), cmpInfo.getLt());
1691 mlir::Value eqRes =
1692 builder.getConstantInt(loc, op.getType(), cmpInfo.getEq());
1693 mlir::Value gtRes =
1694 builder.getConstantInt(loc, op.getType(), cmpInfo.getGt());
1695
1696 mlir::Value transformedResult;
1697 if (cmpInfo.getOrdering() != CmpOrdering::Partial) {
1698 // Total ordering
1699 mlir::Value lt =
1700 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1701 mlir::Value selectOnLt = builder.createSelect(loc, lt, ltRes, gtRes);
1702 mlir::Value eq =
1703 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1704 transformedResult = builder.createSelect(loc, eq, eqRes, selectOnLt);
1705 } else {
1706 // Partial ordering
1707 cir::ConstantOp unorderedRes = builder.getConstantInt(
1708 loc, op.getType(), cmpInfo.getUnordered().value());
1709
1710 mlir::Value eq =
1711 builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs());
1712 mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, unorderedRes);
1713 mlir::Value gt =
1714 builder.createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs());
1715 mlir::Value selectOnGt = builder.createSelect(loc, gt, gtRes, selectOnEq);
1716 mlir::Value lt =
1717 builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs());
1718 transformedResult = builder.createSelect(loc, lt, ltRes, selectOnGt);
1719 }
1720
1721 op.replaceAllUsesWith(transformedResult);
1722 op.erase();
1723}
1724
1725template <typename AttributeTy>
1726static llvm::SmallVector<mlir::Attribute>
1727prepareCtorDtorAttrList(mlir::MLIRContext *context,
1728 llvm::ArrayRef<std::pair<std::string, uint32_t>> list) {
1730 for (const auto &[name, priority] : list)
1731 attrs.push_back(AttributeTy::get(context, name, priority));
1732 return attrs;
1733}
1734
1735void LoweringPreparePass::buildGlobalCtorDtorList() {
1736 if (!globalCtorList.empty()) {
1737 llvm::SmallVector<mlir::Attribute> globalCtors =
1739 globalCtorList);
1740
1741 mlirModule->setAttr(cir::CIRDialect::getGlobalCtorsAttrName(),
1742 mlir::ArrayAttr::get(&getContext(), globalCtors));
1743 }
1744
1745 if (!globalDtorList.empty()) {
1746 llvm::SmallVector<mlir::Attribute> globalDtors =
1748 globalDtorList);
1749 mlirModule->setAttr(cir::CIRDialect::getGlobalDtorsAttrName(),
1750 mlir::ArrayAttr::get(&getContext(), globalDtors));
1751 }
1752}
1753
1754cir::GlobalOp
1755LoweringPreparePass::createGlobalThreadLocalGuard(CIRBaseBuilderTy &builder,
1756 mlir::Location loc) {
1757 mlir::OpBuilder::InsertionGuard guard(builder);
1758 builder.setInsertionPointToStart(mlirModule.getBody());
1759
1760 // The TLS Guard is always an Int8Ty.
1761 cir::IntType guardTy = builder.getSIntNTy(8);
1762 auto g = cir::GlobalOp::create(builder, loc, "__tls_guard", guardTy);
1763 g.setLinkageAttr(cir::GlobalLinkageKindAttr::get(
1764 builder.getContext(), cir::GlobalLinkageKind::InternalLinkage));
1765 g.setAlignment(clang::CharUnits::One().getAsAlign().value());
1766 // At the moment, we only have implementation for this mode, as it is the
1767 // default. At one point we might need to load this mode from the module.
1768 g.setTlsModel(TLS_Model::GeneralDynamic);
1769 g.setInitialValueAttr(cir::IntAttr::get(guardTy, 0));
1770 return g;
1771}
1772
1773cir::IfOp LoweringPreparePass::buildGlobalTlsGuardCheck(
1774 CIRBaseBuilderTy &builder, mlir::Location loc, cir::GlobalOp guard) {
1775 cir::GetGlobalOp getGuard = builder.createGetGlobal(guard, /*tls=*/true);
1776 mlir::Value getGuardValue = getGuard;
1777
1778 // Classic codegen always just loads the first byte of the guard instead of
1779 // the whole thing. __tls_guard is already only 8 bits, but for the case of
1780 // unordered TLS, it gets created as 64 bits.
1781 if (guard.getSymType() != builder.getSIntNTy(8))
1782 getGuardValue = builder.createBitcast(
1783 getGuard, cir::PointerType::get(builder.getSIntNTy(8)));
1784
1785 mlir::Value guardLoad =
1786 builder.createAlignedLoad(loc, getGuardValue, *guard.getAlignment());
1787 auto zero = builder.getConstantInt(loc, builder.getSIntNTy(8), 0);
1788 cir::CmpOp compare =
1789 builder.createCompare(loc, cir::CmpOpKind::eq, guardLoad, zero);
1790 return cir::IfOp::create(
1791 builder, loc, compare,
1792 /*withElseRegion=*/false, [&](mlir::OpBuilder &, mlir::Location loc) {
1793 // Classic codegen still does this store as a i8, but it doesn't seem
1794 // reasonable to do an i8 store into a 64 bit value?
1795 builder.createStore(
1796 loc, builder.getConstantInt(loc, guard.getSymType(), 1), getGuard);
1797 });
1798}
1799
1800void LoweringPreparePass::buildCXXGlobalTlsFunc() {
1801 if (globalThreadLocalInitializers.empty())
1802 return;
1803
1804 // The global-ordered-init function for TLS variables just calls each of the
1805 // init-functions in order after doing a guard.
1806
1807 cir::FuncOp tlsInit = getTlsInitFn();
1808 mlir::Location loc = tlsInit.getLoc();
1809 CIRBaseBuilderTy builder(getContext());
1810 mlir::Block *entryBB = tlsInit.addEntryBlock();
1811 builder.setInsertionPointToStart(entryBB);
1812
1813 cir::IfOp ifOperation = buildGlobalTlsGuardCheck(
1814 builder, loc, createGlobalThreadLocalGuard(builder, loc));
1815
1816 // Emit the body of the guarded spot.
1817 builder.setInsertionPointToEnd(&ifOperation.getThenRegion().front());
1818 for (cir::FuncOp initFunc : globalThreadLocalInitializers)
1819 builder.createCallOp(loc, initFunc, {});
1820 cir::YieldOp::create(builder, loc);
1821
1822 builder.setInsertionPointAfter(ifOperation);
1823 cir::ReturnOp::create(builder, loc);
1824}
1825
1826void LoweringPreparePass::buildCXXGlobalInitFunc() {
1827 if (dynamicInitializers.empty())
1828 return;
1829
1830 // TODO: handle globals with a user-specified initialzation priority.
1831 // TODO: handle default priority more nicely.
1833
1834 SmallString<256> fnName;
1835 // Include the filename in the symbol name. Including "sub_" matches gcc
1836 // and makes sure these symbols appear lexicographically behind the symbols
1837 // with priority (TBD). Module implementation units behave the same
1838 // way as a non-modular TU with imports.
1839 // TODO: check CXX20ModuleInits
1840 if (astCtx->getCurrentNamedModule() &&
1842 llvm::raw_svector_ostream out(fnName);
1843 std::unique_ptr<clang::MangleContext> mangleCtx(
1844 astCtx->createMangleContext());
1845 cast<clang::ItaniumMangleContext>(*mangleCtx)
1846 .mangleModuleInitializer(astCtx->getCurrentNamedModule(), out);
1847 } else {
1848 fnName += "_GLOBAL__sub_I_";
1849 fnName += getTransformedFileName(mlirModule);
1850 }
1851
1852 CIRBaseBuilderTy builder(getContext());
1853 builder.setInsertionPointToEnd(&mlirModule.getBodyRegion().back());
1854 auto fnType = cir::FuncType::get({}, builder.getVoidTy());
1855 cir::FuncOp f =
1856 buildRuntimeFunction(builder, fnName, mlirModule.getLoc(), fnType,
1857 cir::GlobalLinkageKind::ExternalLinkage);
1858 builder.setInsertionPointToStart(f.addEntryBlock());
1859 for (cir::FuncOp &f : dynamicInitializers)
1860 builder.createCallOp(f.getLoc(), f, {});
1861 // Add the global init function (not the individual ctor functions) to the
1862 // global ctor list.
1863 globalCtorList.emplace_back(fnName,
1864 cir::GlobalCtorAttr::getDefaultPriority());
1865
1866 cir::ReturnOp::create(builder, f.getLoc());
1867}
1868
1869/// Lower a cir.array.ctor or cir.array.dtor into a do-while loop that
1870/// iterates over every element. For cir.array.ctor ops whose partial_dtor
1871/// region is non-empty, the ctor loop is wrapped in a cir.cleanup.scope whose
1872/// EH cleanup performs a reverse destruction loop using the partial dtor body.
1874 clang::ASTContext *astCtx,
1875 mlir::Operation *op, mlir::Type eltTy,
1876 mlir::Value addr,
1877 mlir::Value numElements,
1878 uint64_t arrayLen, bool isCtor) {
1879 mlir::Location loc = op->getLoc();
1880 bool isDynamic = numElements != nullptr;
1881
1882 // TODO: instead of getting the size from the AST context, create alias for
1883 // PtrDiffTy and unify with CIRGen stuff.
1884 const unsigned sizeTypeSize =
1885 astCtx->getTypeSize(astCtx->getSignedSizeType());
1886
1887 // Both constructors and destructors use end = begin + numElements.
1888 // Constructors iterate forward [begin, end). Destructors iterate backward
1889 // from end, decrementing before calling the destructor on each element.
1890 mlir::Value begin, end;
1891 if (isDynamic) {
1892 begin = addr;
1893 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, numElements);
1894 } else {
1895 mlir::Value endOffsetVal =
1896 builder.getUnsignedInt(loc, arrayLen, sizeTypeSize);
1897 begin = cir::CastOp::create(builder, loc, eltTy,
1898 cir::CastKind::array_to_ptrdecay, addr);
1899 end = cir::PtrStrideOp::create(builder, loc, eltTy, begin, endOffsetVal);
1900 }
1901
1902 mlir::Value start = isCtor ? begin : end;
1903 mlir::Value stop = isCtor ? end : begin;
1904
1905 // For dynamic destructors, guard against zero elements.
1906 // This places the destructor loop emitted below inside the if block.
1907 cir::IfOp ifOp;
1908 if (isDynamic) {
1909 mlir::Value guardCond;
1910 if (isCtor) {
1911 mlir::Value zero = builder.getUnsignedInt(loc, 0, sizeTypeSize);
1912 guardCond = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1913 numElements, zero);
1914 } else {
1915 // We could check for numElements != 0 in this case too, but this matches
1916 // what classic codegen does.
1917 guardCond =
1918 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, start, stop);
1919 }
1920 ifOp = cir::IfOp::create(builder, loc, guardCond,
1921 /*withElseRegion=*/false,
1922 [&](mlir::OpBuilder &, mlir::Location) {});
1923 builder.setInsertionPointToStart(&ifOp.getThenRegion().front());
1924 }
1925
1926 mlir::Value tmpAddr = builder.createAlloca(
1927 loc, /*addr type*/ builder.getPointerTo(eltTy),
1928 /*var type*/ eltTy, "__array_idx", builder.getAlignmentAttr(1));
1929 builder.createStore(loc, start, tmpAddr);
1930
1931 mlir::Block *bodyBlock = &op->getRegion(0).front();
1932
1933 // Clone the region body (ctor/dtor call and any setup ops like per-element
1934 // zero-init) into the loop, remapping the block argument to the current
1935 // element pointer.
1936 auto cloneRegionBodyInto = [&](mlir::Block *srcBlock,
1937 mlir::Value replacement) {
1938 mlir::IRMapping map;
1939 map.map(srcBlock->getArgument(0), replacement);
1940 for (mlir::Operation &regionOp : *srcBlock) {
1941 if (!mlir::isa<cir::YieldOp>(&regionOp))
1942 builder.clone(regionOp, map);
1943 }
1944 };
1945
1946 mlir::Block *partialDtorBlock = nullptr;
1947 if (auto arrayCtor = mlir::dyn_cast<cir::ArrayCtor>(op)) {
1948 mlir::Region &partialDtor = arrayCtor.getPartialDtor();
1949 if (!partialDtor.empty())
1950 partialDtorBlock = &partialDtor.front();
1951 } else if (auto arrayDtor = mlir::dyn_cast<cir::ArrayDtor>(op)) {
1952 // When the element destructor may throw, reuse the body block as the
1953 // partial-dtor block so that an exception thrown by an element's dtor
1954 // continues the reverse-destruction loop in the EH cleanup region. The
1955 // body block already stores the next element pointer to `tmpAddr`
1956 // before invoking the dtor, so when an exception unwinds from the
1957 // dtor call `tmpAddr` already points at the element that threw, and
1958 // the cleanup loop picks up from `tmpAddr - 1` and walks back to
1959 // `begin`.
1960 if (arrayDtor.getDtorMayThrow())
1961 partialDtorBlock = bodyBlock;
1962 }
1963
1964 auto emitCtorDtorLoop = [&]() {
1965 builder.createDoWhile(
1966 loc,
1967 /*condBuilder=*/
1968 [&](mlir::OpBuilder &b, mlir::Location loc) {
1969 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1970 auto cmp = cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne,
1971 currentElement, stop);
1972 builder.createCondition(cmp);
1973 },
1974 /*bodyBuilder=*/
1975 [&](mlir::OpBuilder &b, mlir::Location loc) {
1976 auto currentElement = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
1977 if (isCtor) {
1978 cloneRegionBodyInto(bodyBlock, currentElement);
1979 mlir::Value stride = builder.getUnsignedInt(loc, 1, sizeTypeSize);
1980 auto nextElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1981 currentElement, stride);
1982 builder.createStore(loc, nextElement, tmpAddr);
1983 } else {
1984 mlir::Value stride = builder.getSignedInt(loc, -1, sizeTypeSize);
1985 auto prevElement = cir::PtrStrideOp::create(builder, loc, eltTy,
1986 currentElement, stride);
1987 builder.createStore(loc, prevElement, tmpAddr);
1988 cloneRegionBodyInto(bodyBlock, prevElement);
1989 }
1990
1991 cir::YieldOp::create(b, loc);
1992 });
1993 };
1994
1995 if (partialDtorBlock) {
1996 cir::CleanupScopeOp::create(
1997 builder, loc, cir::CleanupKind::EH,
1998 /*bodyBuilder=*/
1999 [&](mlir::OpBuilder &b, mlir::Location loc) {
2000 emitCtorDtorLoop();
2001 cir::YieldOp::create(b, loc);
2002 },
2003 /*cleanupBuilder=*/
2004 [&](mlir::OpBuilder &b, mlir::Location loc) {
2005 auto cur = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2006 auto cmp =
2007 cir::CmpOp::create(builder, loc, cir::CmpOpKind::ne, cur, begin);
2008 cir::IfOp::create(
2009 builder, loc, cmp, /*withElseRegion=*/false,
2010 [&](mlir::OpBuilder &b, mlir::Location loc) {
2011 builder.createDoWhile(
2012 loc,
2013 /*condBuilder=*/
2014 [&](mlir::OpBuilder &b, mlir::Location loc) {
2015 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2016 auto neq = cir::CmpOp::create(
2017 builder, loc, cir::CmpOpKind::ne, el, begin);
2018 builder.createCondition(neq);
2019 },
2020 /*bodyBuilder=*/
2021 [&](mlir::OpBuilder &b, mlir::Location loc) {
2022 auto el = cir::LoadOp::create(b, loc, eltTy, tmpAddr);
2023 mlir::Value negOne =
2024 builder.getSignedInt(loc, -1, sizeTypeSize);
2025 auto prev = cir::PtrStrideOp::create(builder, loc, eltTy,
2026 el, negOne);
2027 builder.createStore(loc, prev, tmpAddr);
2028 cloneRegionBodyInto(partialDtorBlock, prev);
2029 builder.createYield(loc);
2030 });
2031 cir::YieldOp::create(builder, loc);
2032 });
2033 cir::YieldOp::create(b, loc);
2034 });
2035 } else {
2036 emitCtorDtorLoop();
2037 }
2038
2039 if (ifOp)
2040 cir::YieldOp::create(builder, loc);
2041
2042 op->erase();
2043}
2044
2045void LoweringPreparePass::lowerArrayDtor(cir::ArrayDtor op) {
2046 CIRBaseBuilderTy builder(getContext());
2047 builder.setInsertionPointAfter(op.getOperation());
2048
2049 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
2050
2051 if (op.getNumElements()) {
2052 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2053 op.getNumElements(), /*arrayLen=*/0,
2054 /*isCtor=*/false);
2055 return;
2056 }
2057
2058 auto arrayLen =
2059 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
2060 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2061 /*numElements=*/nullptr, arrayLen,
2062 /*isCtor=*/false);
2063}
2064
2065void LoweringPreparePass::lowerArrayCtor(cir::ArrayCtor op) {
2066 cir::CIRBaseBuilderTy builder(getContext());
2067 builder.setInsertionPointAfter(op.getOperation());
2068
2069 mlir::Type eltTy = op->getRegion(0).getArgument(0).getType();
2070
2071 if (op.getNumElements()) {
2072 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2073 op.getNumElements(), /*arrayLen=*/0,
2074 /*isCtor=*/true);
2075 return;
2076 }
2077
2078 auto arrayLen =
2079 mlir::cast<cir::ArrayType>(op.getAddr().getType().getPointee()).getSize();
2080 lowerArrayDtorCtorIntoLoop(builder, astCtx, op, eltTy, op.getAddr(),
2081 /*numElements=*/nullptr, arrayLen,
2082 /*isCtor=*/true);
2083}
2084
2085cir::FuncOp LoweringPreparePass::getCalledFunction(cir::CallOp callOp) {
2086 mlir::SymbolRefAttr sym = llvm::dyn_cast_if_present<mlir::SymbolRefAttr>(
2087 callOp.getCallableForCallee());
2088 if (!sym)
2089 return nullptr;
2090 return symbolTables.lookupNearestSymbolFrom<cir::FuncOp>(callOp, sym);
2091}
2092
2093void LoweringPreparePass::lowerTrivialCopyCall(cir::CallOp op) {
2094 cir::FuncOp funcOp = getCalledFunction(op);
2095 if (!funcOp)
2096 return;
2097
2098 std::optional<cir::CtorKind> ctorKind = funcOp.getCxxConstructorKind();
2099 if (ctorKind && *ctorKind == cir::CtorKind::Copy &&
2100 funcOp.isCxxTrivialMemberFunction()) {
2101 // Replace the trivial copy constructor call with a `CopyOp`
2102 CIRBaseBuilderTy builder(getContext());
2103 mlir::ValueRange operands = op.getOperands();
2104 mlir::Value dest = operands[0];
2105 mlir::Value src = operands[1];
2106 builder.setInsertionPoint(op);
2107 builder.createCopy(dest, src);
2108 op.erase();
2109 }
2110}
2111
2112cir::GlobalOp LoweringPreparePass::getOrCreateConstAggregateGlobal(
2113 CIRBaseBuilderTy &builder, mlir::Location loc, llvm::StringRef baseName,
2114 mlir::Type ty, mlir::TypedAttr constant) {
2115 // Look up (and lazily populate) the per-base-name cache.
2116 llvm::SmallVector<cir::GlobalOp, 1> &versions =
2117 constAggregateGlobals[baseName];
2118
2119 // First, check globals we've already discovered for this base name.
2120 for (cir::GlobalOp gv : versions) {
2121 if (gv.getSymType() == ty && gv.getInitialValue() == constant)
2122 return gv;
2123 }
2124
2125 // No cached match. Scan the module's symbol table starting from the next
2126 // unscanned version. In practice this should usually exit on the first
2127 // iteration, but it's possible that some other pass or a previous
2128 // invocation of this pass created globals using this same logic.
2129 llvm::SmallString<128> name(baseName);
2130 size_t baseLen = name.size();
2131 unsigned version = versions.size();
2132 while (true) {
2133 name.resize(baseLen);
2134 if (version != 0) {
2135 name.push_back('.');
2136 llvm::Twine(version).toVector(name);
2137 }
2138 auto existingGv = symbolTables.lookupSymbolIn<cir::GlobalOp>(
2139 mlirModule, mlir::StringAttr::get(&getContext(), name));
2140 if (!existingGv)
2141 break;
2142 versions.push_back(existingGv);
2143 if (existingGv.getSymType() == ty &&
2144 existingGv.getInitialValue() == constant)
2145 return existingGv;
2146 ++version;
2147 }
2148
2149 // No match found, create a new global. The loop above found an unused name.
2150 mlir::OpBuilder::InsertionGuard guard(builder);
2151 builder.setInsertionPointToStart(mlirModule.getBody());
2152 auto gv =
2153 cir::GlobalOp::create(builder, loc, name, ty,
2154 /*isConstant=*/true,
2155 cir::LangAddressSpaceAttr::get(
2156 &getContext(), cir::LangAddressSpace::Default),
2157 cir::GlobalLinkageKind::PrivateLinkage);
2158 mlir::SymbolTable::setSymbolVisibility(
2159 gv, mlir::SymbolTable::Visibility::Private);
2160 gv.setInitialValueAttr(constant);
2161
2162 // Keep the cached symbol table in sync with the new global so subsequent
2163 // lookups for other base names find it.
2164 symbolTables.getSymbolTable(mlirModule).insert(gv);
2165
2166 versions.push_back(gv);
2167 return gv;
2168}
2169
2170void LoweringPreparePass::lowerStoreOfConstAggregate(cir::StoreOp op) {
2171 // Check if the value operand is a cir.const with aggregate type.
2172 auto constOp = op.getValue().getDefiningOp<cir::ConstantOp>();
2173 if (!constOp)
2174 return;
2175
2176 mlir::Type ty = constOp.getType();
2177 if (!mlir::isa<cir::ArrayType, cir::RecordType>(ty))
2178 return;
2179
2180 // Only transform stores to local variables (backed by cir.alloca).
2181 // Stores to other addresses (e.g. base_class_addr) should not be
2182 // transformed as they may be partial initializations.
2183 auto alloca = op.getAddr().getDefiningOp<cir::AllocaOp>();
2184 if (!alloca)
2185 return;
2186
2187 mlir::TypedAttr constant = constOp.getValue();
2188
2189 // OG implements several optimization tiers for constant aggregate
2190 // initialization. For now we always create a global constant + memcpy
2191 // (shouldCreateMemCpyFromGlobal). Future work can add the intermediate
2192 // tiers.
2196
2197 // Get function name from parent cir.func.
2198 auto func = op->getParentOfType<cir::FuncOp>();
2199 if (!func)
2200 return;
2201 llvm::StringRef funcName = func.getSymName();
2202
2203 // Get variable name from the alloca.
2204 llvm::StringRef varName = alloca.getName();
2205
2206 // Build base name: __const.<func>.<var>
2207 std::string baseName = ("__const." + funcName + "." + varName).str();
2208 CIRBaseBuilderTy builder(getContext());
2209
2210 // Check for existing globals and create a new global with a unique name
2211 // if no match is found.
2212 cir::GlobalOp gv = getOrCreateConstAggregateGlobal(builder, op.getLoc(),
2213 baseName, ty, constant);
2214
2215 // Now replace the store with get_global + copy.
2216 builder.setInsertionPoint(op);
2217
2218 auto ptrTy = cir::PointerType::get(ty);
2219 mlir::Value globalPtr =
2220 cir::GetGlobalOp::create(builder, op.getLoc(), ptrTy, gv.getSymName());
2221
2222 // Replace store with copy.
2223 builder.createCopy(op.getAddr(), globalPtr);
2224
2225 // Erase the original store.
2226 op.erase();
2227
2228 // Erase the cir.const if it has no remaining users.
2229 if (constOp.use_empty())
2230 constOp.erase();
2231}
2232
2233void LoweringPreparePass::runOnOp(mlir::Operation *op) {
2234 if (auto arrayCtor = dyn_cast<cir::ArrayCtor>(op)) {
2235 lowerArrayCtor(arrayCtor);
2236 } else if (auto arrayDtor = dyn_cast<cir::ArrayDtor>(op)) {
2237 lowerArrayDtor(arrayDtor);
2238 } else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
2239 lowerCastOp(cast);
2240 } else if (auto complexDiv = mlir::dyn_cast<cir::ComplexDivOp>(op)) {
2241 lowerComplexDivOp(complexDiv);
2242 } else if (auto complexMul = mlir::dyn_cast<cir::ComplexMulOp>(op)) {
2243 lowerComplexMulOp(complexMul);
2244 } else if (auto glob = mlir::dyn_cast<cir::GlobalOp>(op)) {
2245 lowerGlobalOp(glob);
2246 if (auto regAttr = glob->getAttrOfType<CUDAVarRegistrationInfoAttr>(
2247 CUDAVarRegistrationInfoAttr::getMnemonic()))
2248 cudaDeviceVars.emplace_back(glob, regAttr);
2249 } else if (auto getGlob = mlir::dyn_cast<cir::GetGlobalOp>(op)) {
2250 lowerGetGlobalOp(getGlob);
2251 } else if (auto unaryOp = mlir::dyn_cast<cir::UnaryOpInterface>(op)) {
2252 lowerUnaryOp(unaryOp);
2253 } else if (auto callOp = dyn_cast<cir::CallOp>(op)) {
2254 lowerTrivialCopyCall(callOp);
2255 } else if (auto storeOp = dyn_cast<cir::StoreOp>(op)) {
2256 lowerStoreOfConstAggregate(storeOp);
2257 } else if (auto fnOp = dyn_cast<cir::FuncOp>(op)) {
2258 if (auto globalCtor = fnOp.getGlobalCtorPriority())
2259 globalCtorList.emplace_back(fnOp.getName(), globalCtor.value());
2260 else if (auto globalDtor = fnOp.getGlobalDtorPriority())
2261 globalDtorList.emplace_back(fnOp.getName(), globalDtor.value());
2262
2263 if (mlir::Attribute attr =
2264 fnOp->getAttr(cir::CUDAKernelNameAttr::getMnemonic())) {
2265 auto kernelNameAttr = dyn_cast<CUDAKernelNameAttr>(attr);
2266 llvm::StringRef kernelName = kernelNameAttr.getKernelName();
2267 cudaKernelMap[kernelName] = fnOp;
2268 }
2269 } else if (auto threeWayCmp = dyn_cast<cir::CmpThreeWayOp>(op)) {
2270 lowerThreeWayCmpOp(threeWayCmp);
2271 } else if (auto initOp = dyn_cast<cir::LocalInitOp>(op)) {
2272 lowerLocalInitOp(initOp);
2273 }
2274}
2275
2276static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx) {
2277 if (astCtx->getLangOpts().HIP)
2278 return "hip";
2279 return "cuda";
2280}
2281
2282static std::string addUnderscoredPrefix(llvm::StringRef prefix,
2283 llvm::StringRef name) {
2284 return ("__" + prefix + name).str();
2285}
2286
2287/// Creates a global constructor function for the module:
2288///
2289/// For CUDA:
2290/// \code
2291/// void __cuda_module_ctor() {
2292/// Handle = __cudaRegisterFatBinary(GpuBinaryBlob);
2293/// __cuda_register_globals(Handle);
2294/// }
2295/// \endcode
2296///
2297/// For HIP:
2298/// \code
2299/// void __hip_module_ctor() {
2300/// if (__hip_gpubin_handle == 0) {
2301/// __hip_gpubin_handle = __hipRegisterFatBinary(GpuBinaryBlob);
2302/// __hip_register_globals(__hip_gpubin_handle);
2303/// }
2304/// }
2305/// \endcode
2306void LoweringPreparePass::buildCUDAModuleCtor() {
2307 bool isHIP = astCtx->getLangOpts().HIP;
2308
2309 if (astCtx->getLangOpts().GPURelocatableDeviceCode)
2310 llvm_unreachable("GPU RDC NYI");
2311
2312 // For CUDA without -fgpu-rdc, it's safe to stop generating ctor
2313 // if there's nothing to register.
2314 if (cudaKernelMap.empty() && cudaDeviceVars.empty())
2315 return;
2316
2317 // There's no device-side binary, so no need to proceed for CUDA.
2318 // HIP has to create an external symbol in this case, which is NYI.
2319 mlir::Attribute cudaBinaryHandleAttr =
2320 mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName());
2321 if (!cudaBinaryHandleAttr) {
2322 if (isHIP)
2324 return;
2325 }
2326
2327 llvm::StringRef cudaGPUBinaryName =
2328 mlir::cast<CUDABinaryHandleAttr>(cudaBinaryHandleAttr)
2329 .getName()
2330 .getValue();
2331
2332 llvm::vfs::FileSystem &vfs =
2334 llvm::ErrorOr<std::unique_ptr<llvm::MemoryBuffer>> gpuBinaryOrErr =
2335 vfs.getBufferForFile(cudaGPUBinaryName);
2336 if (std::error_code ec = gpuBinaryOrErr.getError()) {
2337 mlirModule->emitError("cannot open GPU binary file: " + cudaGPUBinaryName +
2338 ": " + ec.message());
2339 return;
2340 }
2341 std::unique_ptr<llvm::MemoryBuffer> gpuBinary =
2342 std::move(gpuBinaryOrErr.get());
2343
2344 // Set up common types and builder.
2345 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2346 mlir::Location loc = mlirModule->getLoc();
2347 CIRBaseBuilderTy builder(getContext());
2348 builder.setInsertionPointToStart(mlirModule.getBody());
2349
2350 Type voidTy = builder.getVoidTy();
2351 PointerType voidPtrTy = builder.getVoidPtrTy();
2352 PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
2353 IntType intTy = builder.getSIntNTy(32);
2354 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2355 /*isSigned=*/false);
2356
2357 // --- Create fatbin globals ---
2358
2359 // The section names are different for MAC OS X.
2360 llvm::StringRef fatbinConstName =
2361 astCtx->getLangOpts().HIP ? ".hip_fatbin" : ".nv_fatbin";
2362
2363 llvm::StringRef fatbinSectionName =
2364 astCtx->getLangOpts().HIP ? ".hipFatBinSegment" : ".nvFatBinSegment";
2365
2366 // Create the fatbin string constant with GPU binary contents.
2367 auto fatbinType =
2368 ArrayType::get(&getContext(), charTy, gpuBinary->getBuffer().size());
2369 std::string fatbinStrName = addUnderscoredPrefix(cudaPrefix, "_fatbin_str");
2370 GlobalOp fatbinStr = GlobalOp::create(builder, loc, fatbinStrName, fatbinType,
2371 /*isConstant=*/true, {},
2372 GlobalLinkageKind::PrivateLinkage);
2373 fatbinStr.setAlignment(8);
2374 fatbinStr.setInitialValueAttr(cir::ConstArrayAttr::get(
2375 fatbinType, StringAttr::get(gpuBinary->getBuffer(), fatbinType)));
2376 fatbinStr.setSection(fatbinConstName);
2377 fatbinStr.setPrivate();
2378
2379 // Create the fatbin wrapper struct:
2380 // struct { int magic; int version; void *fatbin; void *unused; };
2381 auto fatbinWrapperType = RecordType::get(
2382 &getContext(), {intTy, intTy, voidPtrTy, voidPtrTy},
2383 /*packed=*/false, /*padded=*/false, RecordType::RecordKind::Struct);
2384 std::string fatbinWrapperName =
2385 addUnderscoredPrefix(cudaPrefix, "_fatbin_wrapper");
2386 GlobalOp fatbinWrapper = GlobalOp::create(
2387 builder, loc, fatbinWrapperName, fatbinWrapperType,
2388 /*isConstant=*/true, {}, GlobalLinkageKind::PrivateLinkage);
2389 fatbinWrapper.setSection(fatbinSectionName);
2390
2391 constexpr unsigned cudaFatMagic = 0x466243b1;
2392 constexpr unsigned hipFatMagic = 0x48495046;
2393 unsigned fatMagic = isHIP ? hipFatMagic : cudaFatMagic;
2394
2395 auto magicInit = IntAttr::get(intTy, fatMagic);
2396 auto versionInit = IntAttr::get(intTy, 1);
2397 auto fatbinStrSymbol =
2398 mlir::FlatSymbolRefAttr::get(fatbinStr.getSymNameAttr());
2399 auto fatbinInit = GlobalViewAttr::get(voidPtrTy, fatbinStrSymbol);
2400 mlir::TypedAttr unusedInit = builder.getConstNullPtrAttr(voidPtrTy);
2401 fatbinWrapper.setInitialValueAttr(cir::ConstRecordAttr::get(
2402 fatbinWrapperType,
2403 mlir::ArrayAttr::get(&getContext(),
2404 {magicInit, versionInit, fatbinInit, unusedInit})));
2405
2406 // Create the GPU binary handle global variable.
2407 std::string gpubinHandleName =
2408 addUnderscoredPrefix(cudaPrefix, "_gpubin_handle");
2409
2410 GlobalOp gpuBinHandle = GlobalOp::create(
2411 builder, loc, gpubinHandleName, voidPtrPtrTy,
2412 /*isConstant=*/false, {}, cir::GlobalLinkageKind::InternalLinkage);
2413 gpuBinHandle.setInitialValueAttr(builder.getConstNullPtrAttr(voidPtrPtrTy));
2414 gpuBinHandle.setPrivate();
2415
2416 // Declare this function:
2417 // void **__{cuda|hip}RegisterFatBinary(void *);
2418
2419 std::string regFuncName =
2420 addUnderscoredPrefix(cudaPrefix, "RegisterFatBinary");
2421 FuncType regFuncType = FuncType::get({voidPtrTy}, voidPtrPtrTy);
2422 cir::FuncOp regFunc =
2423 buildRuntimeFunction(builder, regFuncName, loc, regFuncType);
2424
2425 std::string moduleCtorName = addUnderscoredPrefix(cudaPrefix, "_module_ctor");
2426 cir::FuncOp moduleCtor = buildRuntimeFunction(
2427 builder, moduleCtorName, loc, FuncType::get({}, voidTy),
2428 GlobalLinkageKind::InternalLinkage);
2429
2430 globalCtorList.emplace_back(moduleCtorName,
2431 cir::GlobalCtorAttr::getDefaultPriority());
2432 builder.setInsertionPointToStart(moduleCtor.addEntryBlock());
2434 if (isHIP) {
2435 // --- Create HIP CTOR ---
2436 // if (__hip_gpubin_handle == nullptr)
2437 // __hip_gpubin_handle = __hipRegisterFatBinary(&fatbinWrapper);
2438 // __hip_register_globals(__hip_gpubin_handle);
2439 // atexit(__hip_module_dtor);
2440 mlir::Block *entryBlock = builder.getInsertionBlock();
2441 mlir::Region *parent = entryBlock->getParent();
2442 mlir::Block *ifBlock = builder.createBlock(parent);
2443 mlir::Block *exitBlock = builder.createBlock(parent);
2444 {
2445 mlir::OpBuilder::InsertionGuard guard(builder);
2446 builder.setInsertionPointToEnd(entryBlock);
2447 mlir::Value handle =
2448 builder.createLoad(loc, builder.createGetGlobal(gpuBinHandle));
2449 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2450 mlir::Value nullPtr = builder.getNullPtr(handlePtrTy, loc);
2451 mlir::Value isNull =
2452 builder.createCompare(loc, cir::CmpOpKind::eq, handle, nullPtr);
2453 cir::BrCondOp::create(builder, loc, isNull, ifBlock, exitBlock);
2454 }
2455 {
2456 // Handle is null: load the fatbin and register it.
2457 mlir::OpBuilder::InsertionGuard guard(builder);
2458 builder.setInsertionPointToStart(ifBlock);
2459 mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
2460 mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
2461 cir::CallOp gpuBinaryHandleCall =
2462 builder.createCallOp(loc, regFunc, fatbinVoidPtr);
2463 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2464 // Store the value back to the global `__hip_gpubin_handle`.
2465 mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
2466 builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2467 cir::BrOp::create(builder, loc, exitBlock);
2468 }
2469 {
2470 // Exit block: load the (possibly newly-registered) handle, call
2471 // __hip_register_globals, and register the module dtor with atexit().
2472 mlir::OpBuilder::InsertionGuard guard(builder);
2473 builder.setInsertionPointToStart(exitBlock);
2474 mlir::Value gHandle =
2475 builder.createLoad(loc, builder.createGetGlobal(gpuBinHandle));
2476
2477 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals())
2478 builder.createCallOp(loc, *regGlobal, gHandle);
2479
2480 if (std::optional<FuncOp> dtor = buildHIPModuleDtor()) {
2481 cir::CIRBaseBuilderTy globalBuilder(getContext());
2482 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2483 FuncOp atexit = buildRuntimeFunction(
2484 globalBuilder, "atexit", loc,
2485 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2486 mlir::Value dtorFunc = GetGlobalOp::create(
2487 builder, loc, PointerType::get(dtor->getFunctionType()),
2488 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2489 builder.createCallOp(loc, atexit, dtorFunc);
2490 }
2491 cir::ReturnOp::create(builder, loc);
2492 }
2493 return;
2494 }
2495 if (!astCtx->getLangOpts().GPURelocatableDeviceCode) {
2496
2497 // --- Create CUDA CTOR-DTOR ---
2498 // Register binary with CUDA runtime. This is substantially different in
2499 // default mode vs. separate compilation.
2500 // Corresponding code:
2501 // gpuBinaryHandle = __cudaRegisterFatBinary(&fatbinWrapper);
2502 mlir::Value wrapper = builder.createGetGlobal(fatbinWrapper);
2503 mlir::Value fatbinVoidPtr = builder.createBitcast(wrapper, voidPtrTy);
2504 cir::CallOp gpuBinaryHandleCall =
2505 builder.createCallOp(loc, regFunc, fatbinVoidPtr);
2506 mlir::Value gpuBinaryHandle = gpuBinaryHandleCall.getResult();
2507 // Store the value back to the global `__cuda_gpubin_handle`.
2508 mlir::Value gpuBinaryHandleGlobal = builder.createGetGlobal(gpuBinHandle);
2509 builder.createStore(loc, gpuBinaryHandle, gpuBinaryHandleGlobal);
2510
2511 // --- Generate __cuda_register_globals and call it ---
2512 if (std::optional<FuncOp> regGlobal = buildCUDARegisterGlobals()) {
2513 builder.createCallOp(loc, *regGlobal, gpuBinaryHandle);
2514 }
2515
2516 // From CUDA 10.1 onwards, we must call this function to end registration:
2517 // void __cudaRegisterFatBinaryEnd(void **fatbinHandle);
2518 // This is CUDA-specific, so no need to use `addUnderscoredPrefix`.
2520 astCtx->getTargetInfo().getSDKVersion(),
2522 cir::CIRBaseBuilderTy globalBuilder(getContext());
2523 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2524 FuncOp endFunc =
2525 buildRuntimeFunction(globalBuilder, "__cudaRegisterFatBinaryEnd", loc,
2526 FuncType::get({voidPtrPtrTy}, voidTy));
2527 builder.createCallOp(loc, endFunc, gpuBinaryHandle);
2528 }
2529 } else
2530 llvm_unreachable("GPU RDC NYI");
2531
2532 // Create destructor and register it with atexit() the way NVCC does it. Doing
2533 // it during regular destructor phase worked in CUDA before 9.2 but results in
2534 // double-free in 9.2.
2535 if (std::optional<FuncOp> dtor = buildCUDAModuleDtor()) {
2536
2537 // extern "C" int atexit(void (*f)(void));
2538 cir::CIRBaseBuilderTy globalBuilder(getContext());
2539 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2540 FuncOp atexit = buildRuntimeFunction(
2541 globalBuilder, "atexit", loc,
2542 FuncType::get(PointerType::get(dtor->getFunctionType()), intTy));
2543 mlir::Value dtorFunc = GetGlobalOp::create(
2544 builder, loc, PointerType::get(dtor->getFunctionType()),
2545 mlir::FlatSymbolRefAttr::get(dtor->getSymNameAttr()));
2546 builder.createCallOp(loc, atexit, dtorFunc);
2547 }
2548 cir::ReturnOp::create(builder, loc);
2549}
2550
2551std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
2552 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2553 return {};
2554
2555 llvm::StringRef prefix = getCUDAPrefix(astCtx);
2556
2557 VoidType voidTy = VoidType::get(&getContext());
2558 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2559
2560 mlir::Location loc = mlirModule.getLoc();
2561
2562 cir::CIRBaseBuilderTy builder(getContext());
2563 builder.setInsertionPointToStart(mlirModule.getBody());
2564
2565 // define: void __cudaUnregisterFatBinary(void ** handle);
2566 std::string unregisterFuncName =
2567 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
2568 FuncOp unregisterFunc = buildRuntimeFunction(
2569 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2570
2571 // void __cuda_module_dtor();
2572 // Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
2573 // put into globalDtorList. If it were a real dtor, then it would cause
2574 // double free above CUDA 9.2. The way to use it is to manually call
2575 // atexit() at end of module ctor.
2576 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2577 FuncOp dtor =
2578 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2579 GlobalLinkageKind::InternalLinkage);
2580
2581 builder.setInsertionPointToStart(dtor.addEntryBlock());
2582
2583 // For dtor, we only need to call:
2584 // __cudaUnregisterFatBinary(__cuda_gpubin_handle);
2585
2586 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2587 GlobalOp gpubinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2588 mlir::Value gpubinAddress = builder.createGetGlobal(gpubinGlobal);
2589 mlir::Value gpubin = builder.createLoad(loc, gpubinAddress);
2590 builder.createCallOp(loc, unregisterFunc, gpubin);
2591 ReturnOp::create(builder, loc);
2592
2593 return dtor;
2594}
2595
2596/// Build the HIP module dtor:
2597///
2598/// void __hip_module_dtor() {
2599/// if (__hip_gpubin_handle != nullptr) {
2600/// __hipUnregisterFatBinary(__hip_gpubin_handle);
2601/// __hip_gpubin_handle = nullptr;
2602/// }
2603/// }
2604///
2605/// Despite the name, OG doesn't treat this as a real destructor: putting it on
2606/// the dtor list would cause a double-free. It is meant to be registered via
2607/// atexit() at the end of the module ctor.
2608std::optional<FuncOp> LoweringPreparePass::buildHIPModuleDtor() {
2609 if (!mlirModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
2610 return {};
2611
2612 llvm::StringRef prefix = getCUDAPrefix(astCtx);
2613
2614 VoidType voidTy = VoidType::get(&getContext());
2615 PointerType voidPtrPtrTy = PointerType::get(PointerType::get(voidTy));
2616
2617 mlir::Location loc = mlirModule.getLoc();
2618
2619 cir::CIRBaseBuilderTy builder(getContext());
2620 builder.setInsertionPointToStart(mlirModule.getBody());
2621
2622 // void __hipUnregisterFatBinary(void ** handle);
2623 std::string unregisterFuncName =
2624 addUnderscoredPrefix(prefix, "UnregisterFatBinary");
2625 FuncOp unregisterFunc = buildRuntimeFunction(
2626 builder, unregisterFuncName, loc, FuncType::get({voidPtrPtrTy}, voidTy));
2627
2628 std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
2629 FuncOp dtor =
2630 buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
2631 GlobalLinkageKind::InternalLinkage);
2632
2633 std::string gpubinName = addUnderscoredPrefix(prefix, "_gpubin_handle");
2634 GlobalOp gpuBinGlobal = cast<GlobalOp>(mlirModule.lookupSymbol(gpubinName));
2635
2636 mlir::Block *entryBlock = dtor.addEntryBlock();
2637 mlir::Block *ifBlock = builder.createBlock(&dtor.getBody());
2638 mlir::Block *exitBlock = builder.createBlock(&dtor.getBody());
2639
2640 mlir::OpBuilder::InsertionGuard guard(builder);
2641 builder.setInsertionPointToEnd(entryBlock);
2642 mlir::Value handle =
2643 builder.createLoad(loc, builder.createGetGlobal(gpuBinGlobal));
2644 auto handlePtrTy = mlir::cast<cir::PointerType>(handle.getType());
2645 mlir::Value nullPtr = builder.getNullPtr(handlePtrTy, loc);
2646 mlir::Value isNotNull =
2647 builder.createCompare(loc, cir::CmpOpKind::ne, handle, nullPtr);
2648 cir::BrCondOp::create(builder, loc, isNotNull, ifBlock, exitBlock);
2649
2650 {
2651 // Handle is non-null: unregister and clear it.
2652 mlir::OpBuilder::InsertionGuard ifGuard(builder);
2653 builder.setInsertionPointToStart(ifBlock);
2654 builder.createCallOp(loc, unregisterFunc, handle);
2655 builder.createStore(loc, nullPtr, builder.createGetGlobal(gpuBinGlobal));
2656 cir::BrOp::create(builder, loc, exitBlock);
2657 }
2658 {
2659 mlir::OpBuilder::InsertionGuard exitGuard(builder);
2660 builder.setInsertionPointToStart(exitBlock);
2661 cir::ReturnOp::create(builder, loc);
2662 }
2663
2664 return dtor;
2665}
2666
2667std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
2668 if (cudaKernelMap.empty() && cudaDeviceVars.empty())
2669 return {};
2670
2671 cir::CIRBaseBuilderTy builder(getContext());
2672 builder.setInsertionPointToStart(mlirModule.getBody());
2673
2674 mlir::Location loc = mlirModule.getLoc();
2675 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2676
2677 auto voidTy = VoidType::get(&getContext());
2678 auto voidPtrTy = PointerType::get(voidTy);
2679 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2680
2681 // Create the function:
2682 // void __cuda_register_globals(void **fatbinHandle)
2683 std::string regGlobalFuncName =
2684 addUnderscoredPrefix(cudaPrefix, "_register_globals");
2685 auto regGlobalFuncTy = FuncType::get({voidPtrPtrTy}, voidTy);
2686 FuncOp regGlobalFunc =
2687 buildRuntimeFunction(builder, regGlobalFuncName, loc, regGlobalFuncTy,
2688 /*linkage=*/GlobalLinkageKind::InternalLinkage);
2689 builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
2690
2691 buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
2692 buildCUDARegisterVars(builder, regGlobalFunc);
2693
2694 ReturnOp::create(builder, loc);
2695 return regGlobalFunc;
2696}
2697
2698void LoweringPreparePass::buildCUDARegisterGlobalFunctions(
2699 cir::CIRBaseBuilderTy &builder, FuncOp regGlobalFunc) {
2700 mlir::Location loc = mlirModule.getLoc();
2701 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2702 cir::CIRDataLayout dataLayout(mlirModule);
2703
2704 auto voidTy = VoidType::get(&getContext());
2705 auto voidPtrTy = PointerType::get(voidTy);
2706 auto voidPtrPtrTy = PointerType::get(voidPtrTy);
2707 IntType intTy = builder.getSIntNTy(32);
2708 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2709 /*isSigned=*/false);
2710
2711 // Extract the GPU binary handle argument.
2712 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2713
2714 cir::CIRBaseBuilderTy globalBuilder(getContext());
2715 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2716
2717 // Declare CUDA internal functions:
2718 // int __cudaRegisterFunction(
2719 // void **fatbinHandle,
2720 // const char *hostFunc,
2721 // char *deviceFunc,
2722 // const char *deviceName,
2723 // int threadLimit,
2724 // uint3 *tid, uint3 *bid, dim3 *bDim, dim3 *gDim,
2725 // int *wsize
2726 // )
2727 // OG doesn't care about the types at all. They're treated as void*.
2728
2729 FuncOp cudaRegisterFunction = buildRuntimeFunction(
2730 globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterFunction"), loc,
2731 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2732 voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, voidPtrTy},
2733 intTy));
2734
2735 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2736 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2737 auto tmpString = cir::GlobalOp::create(
2738 globalBuilder, loc, (".str" + str).str(), strType,
2739 /*isConstant=*/true, {},
2740 /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
2741
2742 // We must make the string zero-terminated.
2743 tmpString.setInitialValueAttr(
2744 ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
2745 tmpString.setPrivate();
2746 return tmpString;
2747 };
2748
2749 cir::ConstantOp cirNullPtr = builder.getNullPtr(voidPtrTy, loc);
2750 bool isHIP = astCtx->getLangOpts().HIP;
2751 for (auto kernelName : cudaKernelMap.keys()) {
2752 FuncOp deviceStub = cudaKernelMap[kernelName];
2753 GlobalOp deviceFuncStr = makeConstantString(kernelName);
2754 mlir::Value deviceFunc = builder.createBitcast(
2755 builder.createGetGlobal(deviceFuncStr), voidPtrTy);
2756
2757 mlir::Value hostFunc;
2758 if (isHIP) {
2759 // Under HIP, the kernel-handle is a GlobalOp shadow created by CIR
2760 // codegen and named with the kernel-reference mangled name (e.g.
2761 // `@_Z2fnv` pointing at the device-stub function
2762 // `_Z17__device_stub__fnv`). The CUDAKernelNameAttr on the device-stub
2763 // uses the same name, so we can resolve the shadow by symbol lookup.
2764 auto funcHandle = cast<GlobalOp>(mlirModule.lookupSymbol(kernelName));
2765 hostFunc =
2766 builder.createBitcast(builder.createGetGlobal(funcHandle), voidPtrTy);
2767 } else {
2768 hostFunc = builder.createBitcast(
2769 GetGlobalOp::create(
2770 builder, loc, PointerType::get(deviceStub.getFunctionType()),
2771 mlir::FlatSymbolRefAttr::get(deviceStub.getSymNameAttr())),
2772 voidPtrTy);
2773 }
2774 builder.createCallOp(
2775 loc, cudaRegisterFunction,
2776 {fatbinHandle, hostFunc, deviceFunc, deviceFunc,
2777 ConstantOp::create(builder, loc, IntAttr::get(intTy, -1)), cirNullPtr,
2778 cirNullPtr, cirNullPtr, cirNullPtr, cirNullPtr});
2779 }
2780}
2781
2782// Emit `__{cuda|hip}RegisterVar` calls inside `__{cuda|hip}_register_globals`
2783// for every device-side shadow that carries a `cu.var_registration` attribute
2784// (attached by `CIRGenNVCUDARuntime::handleVarRegistration`).
2785void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
2786 FuncOp regGlobalFunc) {
2787 mlir::Location loc = mlirModule.getLoc();
2788 llvm::StringRef cudaPrefix = getCUDAPrefix(astCtx);
2789 cir::CIRDataLayout dataLayout(mlirModule);
2790
2791 PointerType voidPtrTy = builder.getVoidPtrTy();
2792 PointerType voidPtrPtrTy = builder.getPointerTo(voidPtrTy);
2793 IntType intTy = builder.getSIntNTy(32);
2794 IntType sizeTy =
2795 builder.getUIntNTy(astCtx->getTargetInfo().getMaxPointerWidth());
2796 IntType charTy = cir::IntType::get(&getContext(), astCtx->getCharWidth(),
2797 /*isSigned=*/false);
2798
2799 if (cudaDeviceVars.empty())
2800 return;
2801
2802 cir::CIRBaseBuilderTy globalBuilder(getContext());
2803 globalBuilder.setInsertionPointToStart(mlirModule.getBody());
2804
2805 // void __{cuda|hip}RegisterVar(void **fatbinHandle,
2806 // char *hostVar, char *deviceAddress,
2807 // const char *deviceName, int ext,
2808 // size_t size, int constant, int normalized);
2809 // OG ignores parameter types, treating pointers as void*.
2810 cir::VoidType voidTy = builder.getVoidTy();
2811 FuncOp cudaRegisterVar = buildRuntimeFunction(
2812 globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
2813 FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
2814 sizeTy, intTy, intTy},
2815 voidTy));
2816
2817 auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
2818 auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
2819 auto tmpString = cir::GlobalOp::create(
2820 globalBuilder, loc, (".str" + str).str(), strType,
2821 /*isConstant=*/true, {},
2822 /*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
2823 tmpString.setInitialValueAttr(
2824 ConstArrayAttr::get(strType, StringAttr::get(str + "\0", strType)));
2825 tmpString.setPrivate();
2826 return tmpString;
2827 };
2828
2829 mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
2830
2831 for (auto &[global, regAttr] : cudaDeviceVars) {
2832 switch (regAttr.getKind()) {
2833 case cir::CUDADeviceVarKind::Variable:
2834 break;
2835 case cir::CUDADeviceVarKind::Surface:
2836 llvm_unreachable("Surface registration NYI");
2837 case cir::CUDADeviceVarKind::Texture:
2838 llvm_unreachable("Texture registration NYI");
2839 }
2840
2841 if (regAttr.getIsManaged())
2842 llvm_unreachable("Managed variable registration NYI");
2843
2844 GlobalOp deviceNameStr = makeConstantString(regAttr.getDeviceSideName());
2845 mlir::Value deviceName = builder.createBitcast(
2846 builder.createGetGlobal(deviceNameStr), voidPtrTy);
2847 mlir::Value hostVar =
2848 builder.createBitcast(builder.createGetGlobal(global), voidPtrTy);
2849
2850 auto isExtern = ConstantOp::create(
2851 builder, loc, IntAttr::get(intTy, regAttr.getIsExtern() ? 1 : 0));
2852 llvm::TypeSize size = dataLayout.getTypeAllocSize(global.getSymType());
2853 auto varSize = ConstantOp::create(
2854 builder, loc, IntAttr::get(sizeTy, size.getFixedValue()));
2855 auto isConstant = ConstantOp::create(
2856 builder, loc, IntAttr::get(intTy, regAttr.getIsConstant() ? 1 : 0));
2857 auto normalized = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
2858 builder.createCallOp(loc, cudaRegisterVar,
2859 {fatbinHandle, hostVar, deviceName, deviceName,
2860 isExtern, varSize, isConstant, normalized});
2861 }
2862}
2863
2864void LoweringPreparePass::runOnOperation() {
2865 mlir::Operation *op = getOperation();
2866 if (isa<::mlir::ModuleOp>(op))
2867 mlirModule = cast<::mlir::ModuleOp>(op);
2868
2869 llvm::SmallVector<mlir::Operation *> opsToTransform;
2870
2871 op->walk([&](mlir::Operation *op) {
2872 if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp,
2873 cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp,
2874 cir::FuncOp, cir::CallOp, cir::GetGlobalOp, cir::GlobalOp,
2875 cir::StoreOp, cir::CmpThreeWayOp, cir::IncOp, cir::DecOp,
2876 cir::MinusOp, cir::NotOp, cir::LocalInitOp>(op))
2877 opsToTransform.push_back(op);
2878 });
2879
2880 for (mlir::Operation *o : opsToTransform)
2881 runOnOp(o);
2882
2883 buildCXXGlobalInitFunc();
2884 buildCXXGlobalTlsFunc();
2885 if (astCtx->getLangOpts().CUDA && !astCtx->getLangOpts().CUDAIsDevice)
2886 buildCUDAModuleCtor();
2887
2888 buildGlobalCtorDtorList();
2889}
2890
2891std::unique_ptr<Pass> mlir::createLoweringPreparePass() {
2892 return std::make_unique<LoweringPreparePass>();
2893}
2894
2895std::unique_ptr<Pass>
2897 auto pass = std::make_unique<LoweringPreparePass>();
2898 pass->setASTContext(astCtx);
2899 return std::move(pass);
2900}
Defines the clang::ASTContext interface.
static void emitBody(CodeGenFunction &CGF, const Stmt *S, const Stmt *NextLoop, int MaxLevel, int Level=0)
static llvm::FunctionCallee getGuardReleaseFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static llvm::FunctionCallee getGuardAcquireFn(CodeGenModule &CGM, llvm::PointerType *GuardPtrTy)
static mlir::Value buildRangeReductionComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getComplexDivLibCallName(llvm::APFloat::Semantics semantics)
static llvm::SmallVector< mlir::Attribute > prepareCtorDtorAttrList(mlir::MLIRContext *context, llvm::ArrayRef< std::pair< std::string, uint32_t > > list)
static llvm::StringRef getComplexMulLibCallName(llvm::APFloat::Semantics semantics)
static cir::GlobalLinkageKind getThreadLocalWrapperLinkage(GlobalOp op, clang::ASTContext &astCtx)
static mlir::Value buildComplexBinOpLibCall(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, llvm::StringRef(*libFuncNameGetter)(llvm::APFloat::Semantics), mlir::Location loc, cir::ComplexType ty, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static mlir::Value lowerComplexMul(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexMulOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static std::string addUnderscoredPrefix(llvm::StringRef prefix, llvm::StringRef name)
static SmallString< 128 > getTransformedFileName(mlir::ModuleOp mlirModule)
static mlir::Value lowerComplexToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind scalarCastKind)
static void lowerArrayDtorCtorIntoLoop(cir::CIRBaseBuilderTy &builder, clang::ASTContext *astCtx, mlir::Operation *op, mlir::Type eltTy, mlir::Value addr, mlir::Value numElements, uint64_t arrayLen, bool isCtor)
Lower a cir.array.ctor or cir.array.dtor into a do-while loop that iterates over every element.
static bool isThreadWrapperReplaceable(cir::TLS_Model tls, clang::ASTContext &astCtx)
static mlir::Value lowerComplexToScalarCast(mlir::MLIRContext &ctx, cir::CastOp op, cir::CastKind elemToBoolKind)
static mlir::Value buildAlgebraicComplexDiv(CIRBaseBuilderTy &builder, mlir::Location loc, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag)
static llvm::StringRef getCUDAPrefix(clang::ASTContext *astCtx)
static mlir::Type higherPrecisionElementTypeForComplexArithmetic(mlir::MLIRContext &context, clang::ASTContext &cc, CIRBaseBuilderTy &builder, mlir::Type elementType)
static mlir::Value lowerScalarToComplexCast(mlir::MLIRContext &ctx, cir::CastOp op)
static mlir::Value lowerComplexDiv(LoweringPreparePass &pass, CIRBaseBuilderTy &builder, mlir::Location loc, cir::ComplexDivOp op, mlir::Value lhsReal, mlir::Value lhsImag, mlir::Value rhsReal, mlir::Value rhsImag, mlir::MLIRContext &mlirCx, clang::ASTContext &cc)
Defines the clang::Module class, which describes a module in the source code.
static bool compare(const PathDiagnostic &X, const PathDiagnostic &Y)
Defines the SourceManager interface.
Defines various enumerations that describe declaration and type specifiers.
Defines the TargetCXXABI class, which abstracts details of the C++ ABI that we're targeting.
__device__ __2f16 b
__device__ __2f16 float c
mlir::Value createDiv(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::TypedAttr getConstNullPtrAttr(mlir::Type t)
mlir::Value createDec(mlir::Location loc, mlir::Value input, bool nsw=false)
mlir::Value createLogicalOr(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createSub(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::ConditionOp createCondition(mlir::Value condition)
Create a loop condition.
mlir::Value createInc(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::CopyOp createCopy(mlir::Value dst, mlir::Value src, bool isVolatile=false, bool skipTailPadding=false)
Create a copy with inferred length.
cir::VoidType getVoidTy()
cir::ConstantOp getNullValue(mlir::Type ty, mlir::Location loc)
mlir::Value createCast(mlir::Location loc, cir::CastKind kind, mlir::Value src, mlir::Type newTy)
cir::PointerType getVoidFnPtrTy(mlir::TypeRange argTypes={})
Returns void (*)(T...) as a cir::PointerType.
mlir::Value createAdd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::PointerType getPointerTo(mlir::Type ty)
mlir::Value createComplexImag(mlir::Location loc, mlir::Value operand)
cir::ConstantOp getNullPtr(mlir::Type ty, mlir::Location loc)
cir::IntType getUIntNTy(int n)
cir::DoWhileOp createDoWhile(mlir::Location loc, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> condBuilder, llvm::function_ref< void(mlir::OpBuilder &, mlir::Location)> bodyBuilder)
Create a do-while operation.
cir::GetGlobalOp createGetGlobal(mlir::Location loc, cir::GlobalOp global, bool threadLocal=false)
mlir::Value getSignedInt(mlir::Location loc, int64_t val, unsigned numBits)
mlir::Value createAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createBitcast(mlir::Value src, mlir::Type newTy)
cir::FuncType getVoidFnTy(mlir::TypeRange argTypes={})
Returns void (T...) as a cir::FuncType.
cir::CmpOp createCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs)
mlir::IntegerAttr getAlignmentAttr(clang::CharUnits alignment)
mlir::Value createSelect(mlir::Location loc, mlir::Value condition, mlir::Value trueValue, mlir::Value falseValue)
mlir::Value createMul(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, OverflowBehavior ob=OverflowBehavior::None)
cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, bool isVolatile=false, uint64_t alignment=0)
mlir::Value createMinus(mlir::Location loc, mlir::Value input, bool nsw=false)
cir::ConstantOp getConstantInt(mlir::Location loc, mlir::Type ty, int64_t value)
mlir::Value createComplexCreate(mlir::Location loc, mlir::Value real, mlir::Value imag)
cir::PointerType getVoidPtrTy(clang::LangAS langAS=clang::LangAS::Default)
mlir::Value createIsNaN(mlir::Location loc, mlir::Value operand)
cir::IntType getSIntNTy(int n)
mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, uint64_t alignment)
cir::CallOp createCallOp(mlir::Location loc, mlir::SymbolRefAttr callee, mlir::Type returnType, mlir::ValueRange operands, llvm::ArrayRef< mlir::NamedAttribute > attrs={}, llvm::ArrayRef< mlir::NamedAttrList > argAttrs={}, llvm::ArrayRef< mlir::NamedAttribute > resAttrs={})
cir::StoreOp createStore(mlir::Location loc, mlir::Value val, mlir::Value dst, bool isVolatile=false, mlir::IntegerAttr align={}, cir::SyncScopeKindAttr scope={}, cir::MemOrderAttr order={})
cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value={})
Create a yield operation.
mlir::Value createLogicalAnd(mlir::Location loc, mlir::Value lhs, mlir::Value rhs)
mlir::Value createAlloca(mlir::Location loc, cir::PointerType addrType, mlir::Type type, llvm::StringRef name, mlir::IntegerAttr alignment, mlir::Value dynAllocSize)
cir::BoolType getBoolTy()
mlir::Value getUnsignedInt(mlir::Location loc, uint64_t val, unsigned numBits)
mlir::Value createComplexReal(mlir::Location loc, mlir::Value operand)
Holds long-lived AST nodes (such as types and decls) that can be referred to throughout the semantic ...
Definition ASTContext.h:228
SourceManager & getSourceManager()
Definition ASTContext.h:867
MangleContext * createMangleContext(const TargetInfo *T=nullptr)
If T is null pointer, assume the target in ASTContext.
const LangOptions & getLangOpts() const
Definition ASTContext.h:960
uint64_t getTypeSize(QualType T) const
Return the size of the specified (complete) type T, in bits.
const TargetInfo & getTargetInfo() const
Definition ASTContext.h:925
QualType getSignedSizeType() const
Return the unique signed counterpart of the integer type corresponding to size_t.
Module * getCurrentNamedModule() const
Get module under construction, nullptr if this is not a C++20 module.
uint64_t getCharWidth() const
Return the size of the character type, in bits.
llvm::Align getAsAlign() const
getAsAlign - Returns Quantity as a valid llvm::Align, Beware llvm::Align assumes power of two 8-bit b...
Definition CharUnits.h:189
QuantityType getQuantity() const
getQuantity - Get the raw integer representation of this quantity.
Definition CharUnits.h:185
static CharUnits One()
One - Construct a CharUnits quantity of one.
Definition CharUnits.h:58
static CharUnits fromQuantity(QuantityType Quantity)
fromQuantity - Construct a CharUnits quantity from a raw integer type.
Definition CharUnits.h:63
llvm::vfs::FileSystem & getVirtualFileSystem() const
bool isModuleImplementation() const
Is this a module implementation.
Definition Module.h:882
FileManager & getFileManager() const
Exposes information about the current target.
Definition TargetInfo.h:227
const llvm::Triple & getTriple() const
Returns the target triple of the primary target.
unsigned getMaxAtomicInlineWidth() const
Return the maximum width lock-free atomic operation which can be inlined given the supported features...
Definition TargetInfo.h:859
const llvm::fltSemantics & getDoubleFormat() const
Definition TargetInfo.h:804
const llvm::fltSemantics & getHalfFormat() const
Definition TargetInfo.h:789
const llvm::fltSemantics & getBFloat16Format() const
Definition TargetInfo.h:799
const llvm::fltSemantics & getLongDoubleFormat() const
Definition TargetInfo.h:810
const llvm::fltSemantics & getFloatFormat() const
Definition TargetInfo.h:794
virtual uint64_t getMaxPointerWidth() const
Return the maximum width of pointers on this target.
Definition TargetInfo.h:500
const llvm::fltSemantics & getFloat128Format() const
Definition TargetInfo.h:818
const llvm::VersionTuple & getSDKVersion() const
Defines the clang::TargetInfo interface.
static bool isLocalLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:51
static bool isWeakODRLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:39
static bool isLinkOnceLinkage(GlobalLinkageKind linkage)
Definition CIROpsEnums.h:33
const internal::VariadicDynCastAllOfMatcher< Decl, VarDecl > varDecl
Matches variable declarations.
bool isHIP(ID Id)
isHIP - Is this a HIP input.
Definition Types.cpp:291
RangeSelector name(std::string ID)
Given a node with a "name", (like NamedDecl, DeclRefExpr, CxxCtorInitializer, and TypeLoc) selects th...
bool isTemplateInstantiation(TemplateSpecializationKind Kind)
Determine whether this template specialization kind refers to an instantiation of an entity (as oppos...
Definition Specifiers.h:213
bool CudaFeatureEnabled(llvm::VersionTuple, CudaFeature)
Definition Cuda.cpp:172
LLVM_READONLY bool isPreprocessingNumberBody(unsigned char c)
Return true if this is the body character of a C preprocessing number, which is [a-zA-Z0-9_.
Definition CharInfo.h:168
@ CUDA_USES_FATBIN_REGISTER_END
Definition Cuda.h:82
unsigned int uint32_t
std::unique_ptr< Pass > createLoweringPreparePass()
static bool hipModuleCtor()
static bool guardAbortOnException()
static bool opGlobalAnnotations()
static bool opGlobalCtorPriority()
static bool shouldSplitConstantStore()
static bool shouldUseMemSetToInitialize()
static bool opFuncExtraAttrs()
static bool shouldUseBZeroPlusStoresToInitialize()
static bool fastMathFlags()
static bool astVarDeclInterface()