Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/check-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
run: |
wget -O - https://apt.llvm.org/llvm-snapshot.gpg.key|sudo apt-key add -
sudo apt-add-repository "deb http://apt.llvm.org/`lsb_release -c | cut -f2`/ llvm-toolchain-`lsb_release -c | cut -f2`-${{ matrix.llvm }} main" || true
sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libclang-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} lld-${{ matrix.llvm }} mlir-${{ matrix.llvm }}-tools libmlir-${{ matrix.llvm }} libmlir-${{ matrix.llvm }}-dev libflang-${{ matrix.llvm }}-dev flang-${{ matrix.llvm }} libzstd-dev libmpfr-dev
sudo apt-get install -y cmake gcc g++ llvm-${{ matrix.llvm }}-dev libclang-${{ matrix.llvm }}-dev clang-${{ matrix.llvm }} lld-${{ matrix.llvm }} mlir-${{ matrix.llvm }}-tools libmlir-${{ matrix.llvm }} libmlir-${{ matrix.llvm }}-dev libflang-${{ matrix.llvm }}-dev flang-${{ matrix.llvm }} libzstd-dev libmpfr-dev libomp-${{ matrix.llvm }}-dev
sudo python3 -m pip install --upgrade pip lit
- uses: actions/checkout@v4
- name: mkdir
Expand Down
5 changes: 5 additions & 0 deletions Readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ To install the dependencies on Debian and Ubuntu, [this repository](https://apt.
sudo apt-get install -y cmake gcc g++ llvm-20-dev libclang-20-dev clang-20 lld-20 mlir-20-tools libmlir-20 libmlir-20-dev libflang-20-dev flang-20 libmpfr-dev
```

LLVM can also be installed using spack as such:
```
spack install llvm+clang+flang+lld+mlir@20
```

## Building

``` shell
Expand Down
9 changes: 6 additions & 3 deletions pass/Raptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,8 @@ class RaptorBase {
auto [Truncation, NumArgsParsed] = parseTruncation(CI, Mode, 1);

RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncateFunc(context, F, Truncation, Mode);
llvm::Value *res = Logic.CreateTruncateFunc(
context, F, TruncationConfiguration::getInitial(Truncation, Mode));
if (!res)
return false;
res = Builder.CreatePointerCast(res, CI->getType());
Expand Down Expand Up @@ -696,8 +697,10 @@ class RaptorBase {
for (auto Truncation : FullModuleTruncs) {
IRBuilder<> Builder(F.getContext());
RequestContext context(&*F.getEntryBlock().begin(), &Builder);
Function *TruncatedFunc = Logic.CreateTruncateFunc(
context, &F, Truncation, TruncOpFullModuleMode);
Function *TruncatedFunc =
Logic.CreateTruncateFunc(context, &F,
TruncationConfiguration::getInitial(
Truncation, TruncOpFullModuleMode));

ValueToValueMapTy Mapping;
for (auto &&[Arg, TArg] : llvm::zip(F.args(), TruncatedFunc->args()))
Expand Down
224 changes: 129 additions & 95 deletions pass/RaptorLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Transforms/Utils/Instrumentation.h"
#include <array>
#include <cmath>
#include <tuple>

Expand Down Expand Up @@ -419,11 +420,11 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
LLVMContext &Ctx;

public:
TruncateGenerator(ValueToValueMapTy &originalToNewFn,
FloatTruncation Truncation, Function *oldFunc,
Function *newFunc, RaptorLogic &Logic, bool Root)
: TruncateUtils(Truncation, newFunc->getParent(), Logic),
OriginalToNewFn(originalToNewFn), Truncation(Truncation),
TruncateGenerator(ValueToValueMapTy &originalToNewFn, Function *oldFunc,
Function *newFunc, RaptorLogic &Logic,
TruncationConfiguration TC)
: TruncateUtils(TC.Truncation, newFunc->getParent(), Logic),
OriginalToNewFn(originalToNewFn), Truncation(TC.Truncation),
Mode(Truncation.getMode()), Logic(Logic), Ctx(newFunc->getContext()) {

auto AllocScratch = [&]() {
Expand All @@ -440,28 +441,36 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
// TODO should be the callsite or the function location itself
Value *Loc = getUniquedLocStr(
&*newFunc->getEntryBlock().getFirstNonPHIOrDbgOrAlloca());
createFPRTGeneric(B, TruncChangeName, changePushArgs, B.getVoidTy(), Loc);
scratch = createFPRTGeneric(B, GetName, scratchArgs, B.getPtrTy(), Loc);
if (TC.NeedTruncChange)
createFPRTGeneric(B, TruncChangeName, changePushArgs, B.getVoidTy(),
Loc);
if (TC.NeedNewScratch)
scratch = createFPRTGeneric(B, GetName, scratchArgs, B.getPtrTy(), Loc);
for (auto &BB : *newFunc) {
if (ReturnInst *ret = dyn_cast<ReturnInst>(BB.getTerminator())) {
B.SetInsertPoint(ret);
createFPRTGeneric(B, FreeName, scratchArgs, B.getPtrTy(), Loc);
createFPRTGeneric(B, "trunc_change", changePopArgs, B.getVoidTy(),
Loc);
if (TC.NeedNewScratch)
createFPRTGeneric(B, FreeName, scratchArgs, B.getPtrTy(), Loc);
if (TC.NeedTruncChange)
createFPRTGeneric(B, "trunc_change", changePopArgs, B.getVoidTy(),
Loc);
}
}
};
if (Truncation.isToFPRT()) {
if (Mode == TruncOpMode) {
if (Root) {
if (TC.NeedTruncChange || TC.NeedNewScratch)
AllocScratch();
} else {
if (!TC.NeedNewScratch) {
// make sure we passed in `void *scratch` as the final parameter
assert(newFunc->arg_size() == oldFunc->arg_size() + 1);
scratch = newFunc->getArg(newFunc->arg_size() - 1);
assert(scratch->getType()->isPointerTy());
}
} else if (Mode == TruncOpFullModuleMode) {
assert(TC.NeedNewScratch);
assert(!TC.NeedTruncChange);
// TODO we need to do a call to trunc_change in the module constructor
AllocScratch();
}
}
Expand Down Expand Up @@ -833,94 +842,44 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
return cast<Instruction>(getNewFromOriginal((llvm::Value *)v));
}

bool handleKnownCalls(llvm::CallBase &call, llvm::Function *called,
llvm::StringRef funcName,
llvm::CallBase *const newCall) {
return false;
}

Value *GetShadow(RequestContext &ctx, Value *v, bool root) {
Value *GetShadow(RequestContext &ctx, Value *v, bool WillPassScratch) {
if (auto F = dyn_cast<Function>(v))
return Logic.CreateTruncateFunc(ctx, F, Truncation, Mode, root);
return Logic.CreateTruncateFunc(
ctx, F,
TruncationConfiguration{Truncation, Mode, !WillPassScratch, false,
WillPassScratch});
llvm::errs() << " unknown get truncated func: " << *v << "\n";
llvm_unreachable("unknown get truncated func");
return v;
}
// void visitInvokeInst(llvm::InvokeInst &CI) {
// // fprintf(stderr, "Won't handle invoke instruction.\n");
// EmitWarning("FPNoInvoke", CI,
// "Will not handle invoke instruction.", CI);
// }

// Return
void visitCallBase(llvm::CallBase &CI) {
Intrinsic::ID ID;
StringRef funcName = getFuncNameFromCall(const_cast<CallBase *>(&CI));
if (isMemFreeLibMFunction(funcName, &ID))
if (handleIntrinsic(CI, ID))
return;

using namespace llvm;

CallBase *const newCall = cast<CallBase>(getNewFromOriginal(&CI));
IRBuilder<> BuilderZ(newCall);

if (auto called = CI.getCalledFunction())
if (handleKnownCalls(CI, called, getFuncNameFromCall(&CI), newCall))
return;

// if (!newCall->getDebugLoc()) {
// Function *ContainingF = newCall->getFunction();
// newCall->setDebugLoc(DILocation::get(ContainingF->getContext(), 0, 0,
// ContainingF->getSubprogram()));
// }
struct FunctionToTrunc {
Function *Func;
bool IsCallback;
unsigned ArgNo;
unsigned getCallbackArgNo() {
assert(isCallbackFunc());
return ArgNo;
}
bool isCallbackFunc() { return IsCallback; }
};

if (Mode == TruncOpMode || Mode == TruncMemMode) {
RequestContext ctx(&CI, &BuilderZ);
Function *Func = CI.getCalledFunction();
if (Func && !Func->empty()) {
bool truncOpIgnore = Func->getName().contains("raptor_trunc_op_ignore");
bool truncMemIgnore =
Func->getName().contains("raptor_trunc_mem_ignore");
bool truncIgnore = Func->getName().contains("raptor_trunc_ignore");
truncIgnore |= truncOpIgnore && Mode == TruncOpMode;
truncIgnore |= truncMemIgnore && Mode == TruncMemMode;
if (!truncIgnore) {
if (scratch && Mode == TruncOpMode && isa<CallInst>(&CI)) {
auto val = GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand()),
false);
Function *F = cast<Function>(val);
IRBuilder<> B(newCall);
SmallVector<Value *> args(newCall->args());
args.push_back(scratch);
CallInst *newNewCall = B.CreateCall(F, args);
newNewCall->copyMetadata(*newCall);
newNewCall->copyIRFlags(newCall);
newNewCall->setAttributes(newCall->getAttributes());
newNewCall->setCallingConv(newCall->getCallingConv());
// newNewCall->setTailCallKind(newCall->getTailCallKind());
newNewCall->setDebugLoc(newCall->getDebugLoc());
newCall->replaceAllUsesWith(newNewCall);
newCall->eraseFromParent();
// TODO not sure if we need to change the originalToNewFn mapping.
} else {
auto val =
GetShadow(ctx, getNewFromOriginal(CI.getCalledOperand()), true);
newCall->setCalledOperand(val);
}
}
} else if (!Func) {
SmallVector<FunctionToTrunc, 1> getFunctionToTruncate(llvm::CallBase &CI) {
SmallVector<FunctionToTrunc, 1> ToTrunc;
auto MaybeInsert = [&](Function *F, bool IsCallback, unsigned ArgNo = 0) {
if (!F) {
switch (Mode) {
case TruncMemMode:
case TruncOpMode:
// fprintf(stderr, "Won't follow indirect call.\n");
EmitWarning("FPNoFollow", CI,
"Will not follow FP through this indirect call.", CI);
break;
default:
llvm_unreachable("Unknown trunc mode");
}
} else {
return;
}
if (F->isDeclaration()) {
switch (Mode) {
case TruncMemMode:
EmitWarning("FPNoFollow", CI,
Expand All @@ -937,9 +896,87 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
default:
llvm_unreachable("Unknown trunc mode");
}
return;
}
ToTrunc.push_back(FunctionToTrunc{F, IsCallback, ArgNo});
};

Function *Callee = CI.getCalledFunction();
MaybeInsert(Callee, false);

if (!Callee)
return ToTrunc;
if (!Callee->isDeclaration())
return ToTrunc;

MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
if (CallbackMD) {
for (const MDOperand &Op : CallbackMD->operands()) {
MDNode *OpMD = cast<MDNode>(Op.get());
auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
uint64_t CBCalleeIdx =
cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
MaybeInsert(dyn_cast<Function>(CI.getArgOperand(CBCalleeIdx)), true,
CBCalleeIdx);
}
}

return ToTrunc;
}

// Return
void visitCallBase(llvm::CallBase &CI) {
Intrinsic::ID ID;
StringRef funcName = getFuncNameFromCall(const_cast<CallBase *>(&CI));
if (isMemFreeLibMFunction(funcName, &ID))
if (handleIntrinsic(CI, ID))
return;

using namespace llvm;

CallBase *const newCall = cast<CallBase>(getNewFromOriginal(&CI));
IRBuilder<> BuilderZ(newCall);

if (Mode != TruncOpMode && Mode != TruncMemMode)
return;

RequestContext ctx(&CI, &BuilderZ);
auto FTTs = getFunctionToTruncate(CI);
auto NeedDirectCall = [&](auto FTT) {
return scratch && Mode == TruncOpMode && isa<CallInst>(&CI) &&
!FTT.isCallbackFunc();
};
for (auto &FTT : FTTs) {
assert(FTT.Func && !FTT.Func->empty());
if (!NeedDirectCall(FTT)) {
auto val = GetShadow(ctx, getNewFromOriginal(FTT.Func), false);
if (FTT.isCallbackFunc()) {
newCall->setArgOperand(FTT.getCallbackArgNo(), val);
} else {
newCall->setCalledOperand(val);
}
}
}
for (auto &FTT : FTTs) {
assert(FTT.Func && !FTT.Func->empty());
if (NeedDirectCall(FTT)) {
auto val = GetShadow(ctx, getNewFromOriginal(FTT.Func), true);
Function *F = cast<Function>(val);
IRBuilder<> B(newCall);
SmallVector<Value *> args(newCall->args());
args.push_back(scratch);
CallInst *newNewCall = B.CreateCall(F, args);
newNewCall->copyMetadata(*newCall);
newNewCall->copyIRFlags(newCall);
newNewCall->setAttributes(newCall->getAttributes());
newNewCall->setCallingConv(newCall->getCallingConv());
// newNewCall->setTailCallKind(newCall->getTailCallKind());
newNewCall->setDebugLoc(newCall->getDebugLoc());
newCall->replaceAllUsesWith(newNewCall);
newCall->eraseFromParent();
// TODO not sure if we need to change the originalToNewFn mapping.
}
}
return;
}
void visitPHINode(llvm::PHINode &PN) {
switch (Mode) {
Expand Down Expand Up @@ -1011,9 +1048,8 @@ bool RaptorLogic::CountInFunc(llvm::Function *F, FloatRepresentation FR) {

llvm::Function *RaptorLogic::CreateTruncateFunc(RequestContext Context,
llvm::Function *ToTrunc,
FloatTruncation Truncation,
TruncateMode Mode, bool Root) {
TruncateCacheKey tup(ToTrunc, Truncation, Mode, Root);
TruncationConfiguration TC) {
TruncateCacheKey tup(ToTrunc, TC);
if (TruncateCachedFunctions.find(tup) != TruncateCachedFunctions.end()) {
return TruncateCachedFunctions.find(tup)->second;
}
Expand All @@ -1027,21 +1063,20 @@ llvm::Function *RaptorLogic::CreateTruncateFunc(RequestContext Context,
Params.push_back(OrigFTy->getParamType(i));
}

if (Mode == TruncOpMode && !Root) {
if (TC.ScratchFromArgs) {
// void *scratch
Params.push_back(B.getPtrTy());
}

Type *NewTy = ToTrunc->getReturnType();

FunctionType *FTy = FunctionType::get(NewTy, Params, ToTrunc->isVarArg());
std::string truncName =
std::string("__raptor_done_truncate_") + truncateModeStr(Mode) +
"_func_" + Truncation.mangleTruncation() + "_" + ToTrunc->getName().str();
std::string truncName = std::string("__raptor_done_truncate_") + TC.mangle() +
"_" + ToTrunc->getName().str();
Function *NewF = Function::Create(FTy, ToTrunc->getLinkage(), truncName,
ToTrunc->getParent());

if (Mode != TruncOpFullModuleMode)
if (TC.Mode != TruncOpFullModuleMode)
NewF->setLinkage(Function::LinkageTypes::InternalLinkage);

TruncateCachedFunctions[tup] = NewF;
Expand Down Expand Up @@ -1090,8 +1125,7 @@ llvm::Function *RaptorLogic::CreateTruncateFunc(RequestContext Context,

NewF->setLinkage(Function::LinkageTypes::InternalLinkage);

TruncateGenerator Handle(originalToNewFn, Truncation, ToTrunc, NewF, *this,
Root);
TruncateGenerator Handle(originalToNewFn, ToTrunc, NewF, *this, TC);
for (auto &BB : *ToTrunc)
for (auto &I : BB)
Handle.visit(&I);
Expand Down
Loading