Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 38 additions & 16 deletions clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -375,28 +375,28 @@ static Value *MakeCpAsync(unsigned IntrinsicID, unsigned IntrinsicIDS,
CGF.EmitScalarExpr(E->getArg(1))});
}

static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
const CallExpr *E, CodeGenFunction &CGF) {
static bool EnsureNativeHalfSupport(unsigned BuiltinID, const CallExpr *E,
CodeGenFunction &CGF) {
auto &C = CGF.CGM.getContext();
if (!(C.getLangOpts().NativeHalfType ||
!C.getTargetInfo().useFP16ConversionIntrinsics())) {
CGF.CGM.Error(E->getExprLoc(), C.BuiltinInfo.getQuotedName(BuiltinID) +
" requires native half type support.");
return nullptr;
return false;
}
return true;
}

if (BuiltinID == NVPTX::BI__nvvm_ldg_h || BuiltinID == NVPTX::BI__nvvm_ldg_h2)
return MakeLdg(CGF, E);

if (IntrinsicID == Intrinsic::nvvm_ldu_global_f)
return MakeLdu(IntrinsicID, CGF, E);
static Value *MakeHalfType(Function *Intrinsic, unsigned BuiltinID,
const CallExpr *E, CodeGenFunction &CGF) {
if (!EnsureNativeHalfSupport(BuiltinID, E, CGF))
return nullptr;

SmallVector<Value *, 16> Args;
auto *F = CGF.CGM.getIntrinsic(IntrinsicID);
auto *FTy = F->getFunctionType();
auto *FTy = Intrinsic->getFunctionType();
unsigned ICEArguments = 0;
ASTContext::GetBuiltinTypeError Error;
C.GetBuiltinType(BuiltinID, Error, &ICEArguments);
CGF.CGM.getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments);
assert(Error == ASTContext::GE_None && "Should not codegen an error");
for (unsigned i = 0, e = E->getNumArgs(); i != e; ++i) {
assert((ICEArguments & (1 << i)) == 0);
Expand All @@ -407,8 +407,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
Args.push_back(ArgValue);
}

return CGF.Builder.CreateCall(F, Args);
return CGF.Builder.CreateCall(Intrinsic, Args);
}

static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
const CallExpr *E, CodeGenFunction &CGF) {
return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}

} // namespace

Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
Expand Down Expand Up @@ -913,9 +919,14 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
}
// The following builtins require half type support
case NVPTX::BI__nvvm_ex2_approx_f16:
return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16, BuiltinID, E, *this);
return MakeHalfType(
CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx, Builder.getHalfTy()),
BuiltinID, E, *this);
case NVPTX::BI__nvvm_ex2_approx_f16x2:
return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16x2, BuiltinID, E, *this);
return MakeHalfType(
CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx,
FixedVectorType::get(Builder.getHalfTy(), 2)),
BuiltinID, E, *this);
case NVPTX::BI__nvvm_ff2f16x2_rn:
return MakeHalfType(Intrinsic::nvvm_ff2f16x2_rn, BuiltinID, E, *this);
case NVPTX::BI__nvvm_ff2f16x2_rn_relu:
Expand Down Expand Up @@ -1049,12 +1060,23 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__nvvm_fabs_d:
return Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__nvvm_ex2_approx_d:
case NVPTX::BI__nvvm_ex2_approx_f:
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx,
EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__nvvm_ex2_approx_ftz_f:
return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx_ftz,
EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__nvvm_ldg_h:
case NVPTX::BI__nvvm_ldg_h2:
return MakeHalfType(Intrinsic::not_intrinsic, BuiltinID, E, *this);
if (!EnsureNativeHalfSupport(BuiltinID, E, *this))
return nullptr;
return MakeLdg(*this, E);
case NVPTX::BI__nvvm_ldu_h:
case NVPTX::BI__nvvm_ldu_h2:
return MakeHalfType(Intrinsic::nvvm_ldu_global_f, BuiltinID, E, *this);
if (!EnsureNativeHalfSupport(BuiltinID, E, *this))
return nullptr;
return MakeLdu(Intrinsic::nvvm_ldu_global_f, *this, E);
case NVPTX::BI__nvvm_cp_async_ca_shared_global_4:
return MakeCpAsync(Intrinsic::nvvm_cp_async_ca_shared_global_4,
Intrinsic::nvvm_cp_async_ca_shared_global_4_s, *this, E,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2)));

// CHECK: call half @llvm.nvvm.ex2.approx.f16(half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.ftz.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
Expand Down
2 changes: 1 addition & 1 deletion clang/test/CodeGen/builtins-nvptx-native-half-type.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ __device__ void nvvm_ex2_sm75() {
#if __CUDA_ARCH__ >= 750
// CHECK_PTX70_SM75: call half @llvm.nvvm.ex2.approx.f16
__nvvm_ex2_approx_f16(0.1f16);
// CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.f16x2
// CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.v2f16
__nvvm_ex2_approx_f16x2({0.1f16, 0.7f16});
#endif
// CHECK: ret void
Expand Down
11 changes: 2 additions & 9 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -1334,15 +1334,8 @@ let TargetPrefix = "nvvm" in {
//
let IntrProperties = [IntrNoMem] in {
foreach ftz = ["", "_ftz"] in
def int_nvvm_ex2_approx # ftz # _f : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty]>;

def int_nvvm_ex2_approx_d : NVVMBuiltin,
DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty]>;
def int_nvvm_ex2_approx_f16 :
DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty]>;
def int_nvvm_ex2_approx_f16x2 :
DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty]>;
def int_nvvm_ex2_approx # ftz :
DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;

foreach ftz = ["", "_ftz"] in
def int_nvvm_lg2_approx # ftz # _f : NVVMBuiltin,
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1504,6 +1504,10 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
else if (Name.consume_front("fabs."))
// nvvm.fabs.{f,ftz.f,d}
Expand = Name == "f" || Name == "ftz.f" || Name == "d";
else if (Name.consume_front("ex2.approx."))
// nvvm.ex2.approx.{f,ftz.f,d,f16x2}
Expand =
Name == "f" || Name == "ftz.f" || Name == "d" || Name == "f16x2";
else if (Name.consume_front("max.") || Name.consume_front("min."))
// nvvm.{min,max}.{i,ii,ui,ull}
Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
Expand Down Expand Up @@ -2550,6 +2554,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
Intrinsic::ID IID = (Name == "fabs.ftz.f") ? Intrinsic::nvvm_fabs_ftz
: Intrinsic::nvvm_fabs;
Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
} else if (Name.consume_front("ex2.approx.")) {
// nvvm.ex2.approx.{f,ftz.f,d,f16x2}
Intrinsic::ID IID = Name.starts_with("ftz") ? Intrinsic::nvvm_ex2_approx_ftz
: Intrinsic::nvvm_ex2_approx;
Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
} else if (Name.starts_with("atomic.load.add.f32.p") ||
Name.starts_with("atomic.load.add.f64.p")) {
Value *Ptr = CI->getArgOperand(0);
Expand Down
13 changes: 9 additions & 4 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -1605,12 +1605,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>;
// Exp2 Log2
//

def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>;
def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>;
def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>;
def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>;

let Predicates = [hasPTX<70>, hasSM<75>] in {
def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>;
def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>;
def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>;
def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>;
}

let Predicates = [hasPTX<78>, hasSM<90>] in {
def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>;
def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>;
}

def LG2_APPROX_f32 :
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
// answer. These include:
//
// - nvvm_cos_approx_{f,ftz_f}
// - nvvm_ex2_approx_{d,f,ftz_f}
// - nvvm_ex2_approx(_ftz)
// - nvvm_lg2_approx_{d,f,ftz_f}
// - nvvm_sin_approx_{f,ftz_f}
// - nvvm_sqrt_approx_{f,ftz_f}
Expand Down
17 changes: 17 additions & 0 deletions llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ declare void @llvm.nvvm.barrier(i32, i32)
declare void @llvm.nvvm.barrier.sync(i32)
declare void @llvm.nvvm.barrier.sync.cnt(i32, i32)

declare float @llvm.nvvm.ex2.approx.f(float)
declare double @llvm.nvvm.ex2.approx.d(double)
declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
declare float @llvm.nvvm.ex2.approx.ftz.f(float)

; CHECK-LABEL: @simple_upgrade
define void @simple_upgrade(i32 %a, i64 %b, i16 %c) {
; CHECK: call i32 @llvm.bitreverse.i32(i32 %a)
Expand Down Expand Up @@ -355,3 +360,15 @@ define void @cta_barriers(i32 %x, i32 %y) {
call void @llvm.nvvm.barrier.sync.cnt(i32 %x, i32 %y)
ret void
}

define void @nvvm_ex2_approx(float %a, double %b, half %c, <2 x half> %d) {
; CHECK: call float @llvm.nvvm.ex2.approx.f32(float %a)
; CHECK: call double @llvm.nvvm.ex2.approx.f64(double %b)
; CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %d)
; CHECK: call float @llvm.nvvm.ex2.approx.ftz.f32(float %a)
%r1 = call float @llvm.nvvm.ex2.approx.f(float %a)
%r2 = call double @llvm.nvvm.ex2.approx.d(double %b)
%r3 = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %d)
%r4 = call float @llvm.nvvm.ex2.approx.ftz.f(float %a)
ret void
}
40 changes: 34 additions & 6 deletions llvm/test/CodeGen/NVPTX/f16-ex2.ll
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK-FP16 %s
; RUN: %if ptxas-sm_75 && ptxas-isa-7.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK-FP16 %s
; RUN: %if ptxas-sm_90 && ptxas-isa-7.8 %{ llc < %s -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
target triple = "nvptx64-nvidia-cuda"

declare half @llvm.nvvm.ex2.approx.f16(half)
declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
declare <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half>)
declare bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat)
declare <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat>)

; CHECK-LABEL: ex2_half
define half @ex2_half(half %0) {
; CHECK-FP16-LABEL: ex2_half(
; CHECK-FP16: {
Expand All @@ -21,7 +22,6 @@ define half @ex2_half(half %0) {
ret half %res
}

; CHECK-LABEL: ex2_2xhalf
define <2 x half> @ex2_2xhalf(<2 x half> %0) {
; CHECK-FP16-LABEL: ex2_2xhalf(
; CHECK-FP16: {
Expand All @@ -32,6 +32,34 @@ define <2 x half> @ex2_2xhalf(<2 x half> %0) {
; CHECK-FP16-NEXT: ex2.approx.f16x2 %r2, %r1;
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-FP16-NEXT: ret;
%res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0)
%res = call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %0)
ret <2 x half> %res
}

define bfloat @ex2_bfloat(bfloat %0) {
; CHECK-FP16-LABEL: ex2_bfloat(
; CHECK-FP16: {
; CHECK-FP16-NEXT: .reg .b16 %rs<3>;
; CHECK-FP16-EMPTY:
; CHECK-FP16-NEXT: // %bb.0:
; CHECK-FP16-NEXT: ld.param.b16 %rs1, [ex2_bfloat_param_0];
; CHECK-FP16-NEXT: ex2.approx.ftz.bf16 %rs2, %rs1;
; CHECK-FP16-NEXT: st.param.b16 [func_retval0], %rs2;
; CHECK-FP16-NEXT: ret;
%res = call bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat %0)
ret bfloat %res
}

define <2 x bfloat> @ex2_2xbfloat(<2 x bfloat> %0) {
; CHECK-FP16-LABEL: ex2_2xbfloat(
; CHECK-FP16: {
; CHECK-FP16-NEXT: .reg .b32 %r<3>;
; CHECK-FP16-EMPTY:
; CHECK-FP16-NEXT: // %bb.0:
; CHECK-FP16-NEXT: ld.param.b32 %r1, [ex2_2xbfloat_param_0];
; CHECK-FP16-NEXT: ex2.approx.ftz.bf16x2 %r2, %r1;
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-FP16-NEXT: ret;
%res = call <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat> %0)
ret <2 x bfloat> %res
}
7 changes: 4 additions & 3 deletions llvm/test/CodeGen/NVPTX/f32-ex2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
; RUN: %if ptxas-sm_50 && ptxas-isa-3.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 -mattr=+ptx32 | %ptxas-verify -arch=sm_50 %}
target triple = "nvptx-nvidia-cuda"

declare float @llvm.nvvm.ex2.approx.f(float)
declare float @llvm.nvvm.ex2.approx.f32(float)
declare float @llvm.nvvm.ex2.approx.ftz.f32(float)

; CHECK-LABEL: ex2_float
define float @ex2_float(float %0) {
Expand All @@ -16,7 +17,7 @@ define float @ex2_float(float %0) {
; CHECK-NEXT: ex2.approx.f32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%res = call float @llvm.nvvm.ex2.approx.f(float %0)
%res = call float @llvm.nvvm.ex2.approx.f32(float %0)
ret float %res
}

Expand All @@ -31,6 +32,6 @@ define float @ex2_float_ftz(float %0) {
; CHECK-NEXT: ex2.approx.ftz.f32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
%res = call float @llvm.nvvm.ex2.approx.ftz.f(float %0)
%res = call float @llvm.nvvm.ex2.approx.ftz.f32(float %0)
ret float %res
}
Loading