Skip to content

Commit 7af924a

Browse files
koparasy and AdUhTkJm authored
[CIR][CUDA|HIP] Register global variables (#1980)
Extends and Closes #1978 --------- Co-authored-by: Yue Huang <yue.huang@terapines.com>
1 parent ca31760 commit 7af924a

File tree

5 files changed

+249
-59
lines changed

5 files changed

+249
-59
lines changed

clang/include/clang/CIR/Dialect/IR/CIRDataLayout.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,10 @@ class CIRDataLayout {
129129
mlir::Type getCharType(mlir::MLIRContext *ctx) const {
130130
return typeSizeInfo.getCharType(ctx);
131131
}
132+
133+
/// Returns the size type for the given MLIR context by forwarding to
/// typeSizeInfo.getSizeType(), mirroring getCharType() above.
mlir::Type getSizeType(mlir::MLIRContext *ctx) const {
134+
return typeSizeInfo.getSizeType(ctx);
135+
}
132136
};
133137

134138
/// Used to lazily calculate structure layout information for a target machine,

clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
127127

128128
// Maps CUDA kernel name to device stub function.
129129
llvm::StringMap<FuncOp> cudaKernelMap;
130+
// Maps CUDA device-side variable name to host-side (shadow) GlobalOp.
131+
llvm::StringMap<GlobalOp> cudaVarMap;
130132

131133
void buildCUDAModuleCtor();
132134
std::optional<FuncOp> buildCUDAModuleDtor();
@@ -135,6 +137,8 @@ struct LoweringPreparePass : public LoweringPrepareBase<LoweringPreparePass> {
135137

136138
void buildCUDARegisterGlobalFunctions(cir::CIRBaseBuilderTy &builder,
137139
FuncOp regGlobalFunc);
140+
void buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
141+
FuncOp regGlobalFunc);
138142

139143
///
140144
/// AST related
@@ -1261,8 +1265,7 @@ std::optional<FuncOp> LoweringPreparePass::buildCUDARegisterGlobals() {
12611265
builder.setInsertionPointToStart(regGlobalFunc.addEntryBlock());
12621266

12631267
buildCUDARegisterGlobalFunctions(builder, regGlobalFunc);
1264-
1265-
// TODO(cir): registration for global variables.
1268+
buildCUDARegisterVars(builder, regGlobalFunc);
12661269

12671270
ReturnOp::create(builder, loc);
12681271
return regGlobalFunc;
@@ -1409,6 +1412,81 @@ std::optional<FuncOp> LoweringPreparePass::buildHIPModuleDtor() {
14091412
return dtor;
14101413
}
14111414

1415+
/// Emits, through `builder`, one call to the runtime's "RegisterVar" function
/// for every entry in cudaVarMap (device-side variable name -> host-side
/// shadow GlobalOp).
///
/// \param builder        Builder used to create the per-variable constants,
///                       bitcasts and the registration call itself.
/// \param regGlobalFunc  The registration function being populated; its first
///                       argument is used as the GPU binary handle.
///
/// The runtime function declaration and the private string globals holding
/// the variable names are created with a separate builder positioned at the
/// start of the module body.
void LoweringPreparePass::buildCUDARegisterVars(cir::CIRBaseBuilderTy &builder,
1416+
FuncOp regGlobalFunc) {
1417+
auto loc = theModule.getLoc();
1418+
auto cudaPrefix = getCUDAPrefix(astCtx);
1419+
1420+
auto voidTy = VoidType::get(&getContext());
1421+
auto voidPtrTy = PointerType::get(voidTy);
1422+
auto voidPtrPtrTy = PointerType::get(voidPtrTy);
1423+
auto intTy = datalayout->getIntType(&getContext());
1424+
auto charTy = datalayout->getCharType(&getContext());
1425+
auto sizeTy = datalayout->getSizeType(&getContext());
1426+
1427+
// Extract the GPU binary handle argument.
1428+
mlir::Value fatbinHandle = *regGlobalFunc.args_begin();
1429+
1430+
// Module-level builder: used for the runtime function declaration and the
// name-string globals, which are inserted at the start of the module body.
cir::CIRBaseBuilderTy globalBuilder(getContext());
1431+
globalBuilder.setInsertionPointToStart(theModule.getBody());
1432+
1433+
// Declare CUDA internal function:
1434+
// void __cudaRegisterVar(
1435+
// void **fatbinHandle,
1436+
// char *hostVar, // address of the host-side shadow variable
1437+
// char *deviceAddress, // receives the device-side name string below
1438+
// const char *deviceName, // receives the same device-side name string
1439+
// int isExtern, size_t varSize,
1440+
// int isConstant, int zero
1441+
// );
1442+
// Similar to the registration of global functions, OG does not care about
1443+
// pointer types. They will generate the same IR anyway.
1444+
// The actual symbol name is built from cudaPrefix, so it is presumably the
// HIP counterpart when compiling for HIP -- confirm against getCUDAPrefix.
1445+
FuncOp cudaRegisterVar = buildRuntimeFunction(
1446+
globalBuilder, addUnderscoredPrefix(cudaPrefix, "RegisterVar"), loc,
1447+
FuncType::get({voidPtrPtrTy, voidPtrTy, voidPtrTy, voidPtrTy, intTy,
1448+
sizeTy, intTy, intTy},
1449+
voidTy));
1450+
1451+
// Running counter appended to each generated string global's symbol name.
unsigned int count = 0;
1452+
// Creates a private, constant char-array global holding `str`; the array
// type is one element longer than `str` to leave room for a terminator.
auto makeConstantString = [&](llvm::StringRef str) -> GlobalOp {
1453+
auto strType = ArrayType::get(&getContext(), charTy, 1 + str.size());
1454+
1455+
auto tmpString = GlobalOp::create(
1456+
globalBuilder, loc, (".str" + str + std::to_string(count++)).str(),
1457+
strType, /*isConstant=*/true,
1458+
/*linkage=*/cir::GlobalLinkageKind::PrivateLinkage);
1459+
1460+
// We must make the string zero-terminated.
// NOTE(review): the literal "\0" is an empty C string, so the concatenation
// below appends no characters. The terminator presumably comes from strType
// being one element longer than `str` (zero padding by ConstArrayAttr) --
// TODO confirm, and consider dropping the misleading `+ "\0"`.
1461+
tmpString.setInitialValueAttr(ConstArrayAttr::get(
1462+
strType, StringAttr::get(&getContext(), str + "\0")));
1463+
tmpString.setPrivate();
1464+
return tmpString;
1465+
};
1466+
1467+
// Emit one registration call per recorded (device name, shadow global) pair.
for (auto &[deviceSideName, global] : cudaVarMap) {
1468+
GlobalOp varNameStr = makeConstantString(deviceSideName);
1469+
mlir::Value varNameValue =
1470+
builder.createBitcast(builder.createGetGlobal(varNameStr), voidPtrTy);
1471+
1472+
// Address of the host-side shadow variable, passed as an untyped pointer.
auto globalVarValue =
1473+
builder.createBitcast(builder.createGetGlobal(global), voidPtrTy);
1474+
1475+
// Every device variable that has a shadow on host will not be extern.
1476+
// See CIRGenModule::emitGlobalVarDefinition.
1477+
auto isExtern = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
1478+
// The data layout reports the size in bits; the runtime call takes bytes.
llvm::TypeSize size = datalayout->getTypeSizeInBits(global.getSymType());
1479+
auto varSize = ConstantOp::create(
1480+
builder, loc, IntAttr::get(sizeTy, size.getFixedValue() / 8));
1481+
auto isConstant = ConstantOp::create(
1482+
builder, loc, IntAttr::get(intTy, global.getConstant()));
1483+
auto zero = ConstantOp::create(builder, loc, IntAttr::get(intTy, 0));
1484+
// The device-side name string is passed for both the third and fourth
// arguments.
builder.createCallOp(loc, cudaRegisterVar,
1485+
{fatbinHandle, globalVarValue, varNameValue,
1486+
varNameValue, isExtern, varSize, isConstant, zero});
1487+
}
1488+
}
1489+
14121490
std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
14131491
if (!theModule->getAttr(CIRDialect::getCUDABinaryHandleAttrName()))
14141492
return {};
@@ -1431,9 +1509,9 @@ std::optional<FuncOp> LoweringPreparePass::buildCUDAModuleDtor() {
14311509

14321510
// void __cuda_module_dtor();
14331511
// Despite the name, OG doesn't treat it as a destructor, so it shouldn't be
1434-
// put into globalDtorList. If it were a real dtor, then it would cause double
1435-
// free above CUDA 9.2. The way to use it is to manually call atexit() at end
1436-
// of module ctor.
1512+
// put into globalDtorList. If it were a real dtor, then it would cause
1513+
// double free above CUDA 9.2. The way to use it is to manually call
1514+
// atexit() at end of module ctor.
14371515
std::string dtorName = addUnderscoredPrefix(prefix, "_module_dtor");
14381516
FuncOp dtor =
14391517
buildRuntimeFunction(builder, dtorName, loc, FuncType::get({}, voidTy),
@@ -1721,8 +1799,13 @@ void LoweringPreparePass::runOnOp(Operation *op) {
17211799
lowerVAArgOp(vaArgOp);
17221800
} else if (auto deleteArrayOp = dyn_cast<DeleteArrayOp>(op)) {
17231801
lowerDeleteArrayOp(deleteArrayOp);
1724-
} else if (auto getGlobal = dyn_cast<GlobalOp>(op)) {
1725-
lowerGlobalOp(getGlobal);
1802+
} else if (auto global = dyn_cast<GlobalOp>(op)) {
1803+
lowerGlobalOp(global);
1804+
if (auto attr = op->getAttr(cir::CUDAShadowNameAttr::getMnemonic())) {
1805+
auto shadowNameAttr = dyn_cast<CUDAShadowNameAttr>(attr);
1806+
std::string deviceSideName = shadowNameAttr.getDeviceSideName();
1807+
cudaVarMap[deviceSideName] = global;
1808+
}
17261809
} else if (auto dynamicCast = dyn_cast<DynamicCastOp>(op)) {
17271810
lowerDynamicCastOp(dynamicCast);
17281811
} else if (auto stdFind = dyn_cast<StdFindOp>(op)) {

clang/lib/CIR/FrontendAction/CIRGenAction.cpp

Lines changed: 41 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include "clang/Frontend/FrontendDiagnostic.h"
3636
#include "clang/Frontend/MultiplexConsumer.h"
3737
#include "clang/Lex/Preprocessor.h"
38+
#include "llvm/ADT/SmallString.h"
3839
#include "llvm/Bitcode/BitcodeReader.h"
3940
#include "llvm/IR/DebugInfo.h"
4041
#include "llvm/IR/DiagnosticInfo.h"
@@ -47,11 +48,10 @@
4748
#include "llvm/LTO/LTOBackend.h"
4849
#include "llvm/Linker/Linker.h"
4950
#include "llvm/Pass.h"
50-
#include "llvm/ADT/SmallString.h"
5151
#include "llvm/Support/MemoryBuffer.h"
52+
#include "llvm/Support/Path.h"
5253
#include "llvm/Support/Signals.h"
5354
#include "llvm/Support/SourceMgr.h"
54-
#include "llvm/Support/Path.h"
5555
#include "llvm/Support/TimeProfiler.h"
5656
#include "llvm/Support/Timer.h"
5757
#include "llvm/Support/ToolOutputFile.h"
@@ -112,12 +112,12 @@ class CIRGenConsumer : public clang::ASTConsumer {
112112

113113
CIRGenAction::OutputType Action;
114114

115-
CompilerInstance &CompilerInstance;
116-
DiagnosticsEngine &DiagnosticsEngine;
117-
[[maybe_unused]] const HeaderSearchOptions &HeaderSearchOptions;
118-
CodeGenOptions &CodeGenOptions;
119-
[[maybe_unused]] const TargetOptions &TargetOptions;
120-
[[maybe_unused]] const LangOptions &LangOptions;
115+
CompilerInstance &CI;
116+
DiagnosticsEngine &Diags;
117+
[[maybe_unused]] const HeaderSearchOptions &HeaderSearchOpts;
118+
CodeGenOptions &CodeGenOpts;
119+
[[maybe_unused]] const TargetOptions &TargetOpts;
120+
[[maybe_unused]] const LangOptions &LangOpts;
121121
const FrontendOptions &FeOptions;
122122

123123
std::string InputFileName;
@@ -128,25 +128,21 @@ class CIRGenConsumer : public clang::ASTConsumer {
128128
std::unique_ptr<CIRGenerator> Gen;
129129

130130
public:
131-
CIRGenConsumer(CIRGenAction::OutputType Action,
132-
class CompilerInstance &CompilerInstance,
133-
class DiagnosticsEngine &DiagnosticsEngine,
131+
CIRGenConsumer(CIRGenAction::OutputType Action, class CompilerInstance &CI,
132+
class DiagnosticsEngine &Diags,
134133
IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS,
135-
const class HeaderSearchOptions &HeaderSearchOptions,
136-
class CodeGenOptions &CodeGenOptions,
137-
const class TargetOptions &TargetOptions,
138-
const class LangOptions &LangOptions,
134+
const class HeaderSearchOptions &HeaderSearchOpts,
135+
class CodeGenOptions &CodeGenOpts,
136+
const class TargetOptions &TargetOpts,
137+
const class LangOptions &LangOpts,
139138
const FrontendOptions &FeOptions, StringRef InputFile,
140139
std::unique_ptr<raw_pwrite_stream> Os)
141-
: Action(Action), CompilerInstance(CompilerInstance),
142-
DiagnosticsEngine(DiagnosticsEngine),
143-
HeaderSearchOptions(HeaderSearchOptions),
144-
CodeGenOptions(CodeGenOptions), TargetOptions(TargetOptions),
145-
LangOptions(LangOptions), FeOptions(FeOptions),
146-
InputFileName(InputFile.str()),
147-
OutputStream(std::move(Os)), FS(VFS),
148-
Gen(std::make_unique<CIRGenerator>(DiagnosticsEngine, std::move(VFS),
149-
CodeGenOptions)) {}
140+
: Action(Action), CI(CI), Diags(Diags),
141+
HeaderSearchOpts(HeaderSearchOpts), CodeGenOpts(CodeGenOpts),
142+
TargetOpts(TargetOpts), LangOpts(LangOpts), FeOptions(FeOptions),
143+
InputFileName(InputFile.str()), OutputStream(std::move(Os)), FS(VFS),
144+
Gen(std::make_unique<CIRGenerator>(Diags, std::move(VFS),
145+
CodeGenOpts)) {}
150146

151147
void Initialize(ASTContext &Ctx) override {
152148
assert(!AstContext && "initialized multiple times");
@@ -221,12 +217,12 @@ class CIRGenConsumer : public clang::ASTConsumer {
221217
FeOptions.ClangIRLifetimeCheck, LifetimeOpts,
222218
FeOptions.ClangIRIdiomRecognizer, IdiomRecognizerOpts,
223219
FeOptions.ClangIRLibOpt, LibOptOpts, PassOptParsingFailure,
224-
CodeGenOptions.OptimizationLevel > 0, FlattenCir,
220+
CodeGenOpts.OptimizationLevel > 0, FlattenCir,
225221
!FeOptions.ClangIRDirectLowering, EnableCcLowering,
226222
FeOptions.ClangIREnableMem2Reg)
227223
.failed()) {
228224
if (!PassOptParsingFailure.empty()) {
229-
auto D = DiagnosticsEngine.Report(diag::err_drv_cir_pass_opt_parsing);
225+
auto D = Diags.Report(diag::err_drv_cir_pass_opt_parsing);
230226
D << PassOptParsingFailure;
231227
} else
232228
llvm::report_fatal_error("CIR codegen: MLIR pass manager fails "
@@ -269,24 +265,21 @@ class CIRGenConsumer : public clang::ASTConsumer {
269265
}
270266
}
271267

272-
bool EmitCIR = LangOptions.EmitCIRToFile || FeOptions.EmitClangIRFile ||
273-
!LangOptions.CIRFile.empty() ||
274-
!FeOptions.ClangIRFile.empty();
268+
bool EmitCIR = LangOpts.EmitCIRToFile || FeOptions.EmitClangIRFile ||
269+
!LangOpts.CIRFile.empty() || !FeOptions.ClangIRFile.empty();
275270
if (EmitCIR) {
276271
std::unique_ptr<raw_pwrite_stream> CIRStream;
277272
llvm::SmallString<128> DefaultPath;
278273
if (!FeOptions.ClangIRFile.empty()) {
279-
CIRStream = CompilerInstance.createOutputFile(
280-
FeOptions.ClangIRFile,
281-
/*Binary=*/false,
282-
/*RemoveFileOnSignal=*/true,
283-
/*UseTemporary=*/true);
284-
} else if (!LangOptions.CIRFile.empty()) {
285-
CIRStream = CompilerInstance.createOutputFile(
286-
LangOptions.CIRFile,
287-
/*Binary=*/false,
288-
/*RemoveFileOnSignal=*/true,
289-
/*UseTemporary=*/true);
274+
CIRStream = CI.createOutputFile(FeOptions.ClangIRFile,
275+
/*Binary=*/false,
276+
/*RemoveFileOnSignal=*/true,
277+
/*UseTemporary=*/true);
278+
} else if (!LangOpts.CIRFile.empty()) {
279+
CIRStream = CI.createOutputFile(LangOpts.CIRFile,
280+
/*Binary=*/false,
281+
/*RemoveFileOnSignal=*/true,
282+
/*UseTemporary=*/true);
290283
} else {
291284
if (!FeOptions.OutputFile.empty() && FeOptions.OutputFile != "-") {
292285
DefaultPath = FeOptions.OutputFile;
@@ -299,11 +292,10 @@ class CIRGenConsumer : public clang::ASTConsumer {
299292
DefaultPath = "clangir-output";
300293
}
301294
llvm::sys::path::replace_extension(DefaultPath, "cir");
302-
CIRStream = CompilerInstance.createOutputFile(
303-
DefaultPath,
304-
/*Binary=*/false,
305-
/*RemoveFileOnSignal=*/true,
306-
/*UseTemporary=*/true);
295+
CIRStream = CI.createOutputFile(DefaultPath,
296+
/*Binary=*/false,
297+
/*RemoveFileOnSignal=*/true,
298+
/*UseTemporary=*/true);
307299
}
308300

309301
if (CIRStream) {
@@ -354,18 +346,17 @@ class CIRGenConsumer : public clang::ASTConsumer {
354346
case CIRGenAction::OutputType::EmitAssembly: {
355347
llvm::LLVMContext LlvmCtx;
356348
bool DisableDebugInfo =
357-
CodeGenOptions.getDebugInfo() == llvm::codegenoptions::NoDebugInfo;
349+
CodeGenOpts.getDebugInfo() == llvm::codegenoptions::NoDebugInfo;
358350
auto LlvmModule = lowerFromCIRToLLVMIR(
359351
FeOptions, MlirMod, std::move(MlirCtx), LlvmCtx,
360352
FeOptions.ClangIRDisableCIRVerifier,
361353
!FeOptions.ClangIRCallConvLowering, DisableDebugInfo);
362354

363355
BackendAction BackendAction = getBackendActionFromOutputType(Action);
364356

365-
emitBackendOutput(CompilerInstance, CodeGenOptions,
366-
C.getTargetInfo().getDataLayoutString(),
367-
LlvmModule.get(), BackendAction, FS,
368-
std::move(OutputStream));
357+
emitBackendOutput(
358+
CI, CodeGenOpts, C.getTargetInfo().getDataLayoutString(),
359+
LlvmModule.get(), BackendAction, FS, std::move(OutputStream));
369360
break;
370361
}
371362
case CIRGenAction::OutputType::None:

0 commit comments

Comments
 (0)