-
Couldn't load subscription status.
- Fork 15k
[NVPTX] Add ex2.approx bf16 support and cleanup intrinsic definition #165446
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
|
@llvm/pr-subscribers-clang-codegen @llvm/pr-subscribers-backend-nvptx Author: Alex MacLean (AlexMaclean) ChangesFull diff: https://github.com/llvm/llvm-project/pull/165446.diff 10 Files Affected:
diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
index 6da65b681df1e..5a8dd2153595c 100644
--- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
+++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp
@@ -375,28 +375,28 @@ static Value *MakeCpAsync(unsigned IntrinsicID, unsigned IntrinsicIDS,
CGF.EmitScalarExpr(E->getArg(1))});
}
-static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
- const CallExpr *E, CodeGenFunction &CGF) {
+static bool EnsureNativeHalfSupport(unsigned BuiltinID, const CallExpr *E,
+ CodeGenFunction &CGF) {
auto &C = CGF.CGM.getContext();
if (!(C.getLangOpts().NativeHalfType ||
!C.getTargetInfo().useFP16ConversionIntrinsics())) {
CGF.CGM.Error(E->getExprLoc(), C.BuiltinInfo.getQuotedName(BuiltinID) +
" requires native half type support.");
- return nullptr;
+ return false;
}
+ return true;
+}
- if (BuiltinID == NVPTX::BI__nvvm_ldg_h || BuiltinID == NVPTX::BI__nvvm_ldg_h2)
- return MakeLdg(CGF, E);
-
- if (IntrinsicID == Intrinsic::nvvm_ldu_global_f)
- return MakeLdu(IntrinsicID, CGF, E);
+static Value *MakeHalfType(Function *Intrinsic, unsigned BuiltinID,
+ const CallExpr *E, CodeGenFunction &CGF) {
+ if (!EnsureNativeHalfSupport(BuiltinID, E, CGF))
+ return nullptr;
SmallVector<Value *, 16> Args;
- auto *F = CGF.CGM.getIntrinsic(IntrinsicID);
- auto *FTy = F->getFunctionType();
+ auto *FTy = Intrinsic->getFunctionType();
unsigned ICEArguments = 0;
ASTContext::GetBuiltinTypeError Error;
- C.GetBuiltinType(BuiltinID, Error, &ICEArguments);
+ CGF.CGM.getContext().GetBuiltinType(BuiltinID, Error, &ICEArguments);
assert(Error == ASTContext::GE_None && "Should not codegen an error");
for (unsigned i = 0, e = E->getNumArgs(); i != e; ++i) {
assert((ICEArguments & (1 << i)) == 0);
@@ -407,8 +407,14 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
Args.push_back(ArgValue);
}
- return CGF.Builder.CreateCall(F, Args);
+ return CGF.Builder.CreateCall(Intrinsic, Args);
+}
+
+static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID,
+ const CallExpr *E, CodeGenFunction &CGF) {
+ return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF);
}
+
} // namespace
Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
@@ -913,9 +919,14 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
}
// The following builtins require half type support
case NVPTX::BI__nvvm_ex2_approx_f16:
- return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16, BuiltinID, E, *this);
+ return MakeHalfType(
+ CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx, Builder.getHalfTy()),
+ BuiltinID, E, *this);
case NVPTX::BI__nvvm_ex2_approx_f16x2:
- return MakeHalfType(Intrinsic::nvvm_ex2_approx_f16x2, BuiltinID, E, *this);
+ return MakeHalfType(
+ CGM.getIntrinsic(Intrinsic::nvvm_ex2_approx,
+ FixedVectorType::get(Builder.getHalfTy(), 2)),
+ BuiltinID, E, *this);
case NVPTX::BI__nvvm_ff2f16x2_rn:
return MakeHalfType(Intrinsic::nvvm_ff2f16x2_rn, BuiltinID, E, *this);
case NVPTX::BI__nvvm_ff2f16x2_rn_relu:
@@ -1049,12 +1060,23 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID,
case NVPTX::BI__nvvm_fabs_d:
return Builder.CreateUnaryIntrinsic(Intrinsic::fabs,
EmitScalarExpr(E->getArg(0)));
+ case NVPTX::BI__nvvm_ex2_approx_d:
+ case NVPTX::BI__nvvm_ex2_approx_f:
+ return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx,
+ EmitScalarExpr(E->getArg(0)));
+ case NVPTX::BI__nvvm_ex2_approx_ftz_f:
+ return Builder.CreateUnaryIntrinsic(Intrinsic::nvvm_ex2_approx_ftz,
+ EmitScalarExpr(E->getArg(0)));
case NVPTX::BI__nvvm_ldg_h:
case NVPTX::BI__nvvm_ldg_h2:
- return MakeHalfType(Intrinsic::not_intrinsic, BuiltinID, E, *this);
+ if (!EnsureNativeHalfSupport(BuiltinID, E, *this))
+ return nullptr;
+ return MakeLdg(*this, E);
case NVPTX::BI__nvvm_ldu_h:
case NVPTX::BI__nvvm_ldu_h2:
- return MakeHalfType(Intrinsic::nvvm_ldu_global_f, BuiltinID, E, *this);
+ if (!EnsureNativeHalfSupport(BuiltinID, E, *this))
+ return nullptr;
+ return MakeLdu(Intrinsic::nvvm_ldu_global_f, *this, E);
case NVPTX::BI__nvvm_cp_async_ca_shared_global_4:
return MakeCpAsync(Intrinsic::nvvm_cp_async_ca_shared_global_4,
Intrinsic::nvvm_cp_async_ca_shared_global_4_s, *this, E,
diff --git a/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c b/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c
index 035c4c6066be2..60a35f4fe0c37 100644
--- a/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c
+++ b/clang/test/CodeGen/builtins-nvptx-native-half-type-native.c
@@ -8,7 +8,7 @@
typedef __fp16 __fp16v2 __attribute__((ext_vector_type(2)));
// CHECK: call half @llvm.nvvm.ex2.approx.f16(half {{.*}})
-// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> {{.*}})
+// CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call half @llvm.nvvm.fma.rn.ftz.relu.f16(half {{.*}}, half {{.*}}, half {{.*}})
// CHECK: call <2 x half> @llvm.nvvm.fma.rn.relu.f16x2(<2 x half> {{.*}}, <2 x half> {{.*}}, <2 x half> {{.*}})
diff --git a/clang/test/CodeGen/builtins-nvptx-native-half-type.c b/clang/test/CodeGen/builtins-nvptx-native-half-type.c
index 01a004efd71e4..1f16c7e54b85d 100644
--- a/clang/test/CodeGen/builtins-nvptx-native-half-type.c
+++ b/clang/test/CodeGen/builtins-nvptx-native-half-type.c
@@ -41,7 +41,7 @@ __device__ void nvvm_ex2_sm75() {
#if __CUDA_ARCH__ >= 750
// CHECK_PTX70_SM75: call half @llvm.nvvm.ex2.approx.f16
__nvvm_ex2_approx_f16(0.1f16);
- // CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.f16x2
+ // CHECK_PTX70_SM75: call <2 x half> @llvm.nvvm.ex2.approx.v2f16
__nvvm_ex2_approx_f16x2({0.1f16, 0.7f16});
#endif
// CHECK: ret void
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 719181a09f475..2710853e17688 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1334,15 +1334,8 @@ let TargetPrefix = "nvvm" in {
//
let IntrProperties = [IntrNoMem] in {
foreach ftz = ["", "_ftz"] in
- def int_nvvm_ex2_approx # ftz # _f : NVVMBuiltin,
- DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty]>;
-
- def int_nvvm_ex2_approx_d : NVVMBuiltin,
- DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty]>;
- def int_nvvm_ex2_approx_f16 :
- DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty]>;
- def int_nvvm_ex2_approx_f16x2 :
- DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty]>;
+ def int_nvvm_ex2_approx # ftz :
+ DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;
foreach ftz = ["", "_ftz"] in
def int_nvvm_lg2_approx # ftz # _f : NVVMBuiltin,
diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index b838e36c8824f..4d4e9f9b31fcf 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -1504,6 +1504,10 @@ static bool upgradeIntrinsicFunction1(Function *F, Function *&NewFn,
else if (Name.consume_front("fabs."))
// nvvm.fabs.{f,ftz.f,d}
Expand = Name == "f" || Name == "ftz.f" || Name == "d";
+ else if (Name.consume_front("ex2.approx."))
+ // nvvm.ex2.approx.{f,ftz.f,d,f16x2}
+ Expand =
+ Name == "f" || Name == "ftz.f" || Name == "d" || Name == "f16x2";
else if (Name.consume_front("max.") || Name.consume_front("min."))
// nvvm.{min,max}.{i,ii,ui,ull}
Expand = Name == "s" || Name == "i" || Name == "ll" || Name == "us" ||
@@ -2550,6 +2554,11 @@ static Value *upgradeNVVMIntrinsicCall(StringRef Name, CallBase *CI,
Intrinsic::ID IID = (Name == "fabs.ftz.f") ? Intrinsic::nvvm_fabs_ftz
: Intrinsic::nvvm_fabs;
Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
+ } else if (Name.consume_front("ex2.approx.")) {
+ // nvvm.ex2.approx.{f,ftz.f,d,f16x2}
+ Intrinsic::ID IID = Name.starts_with("ftz") ? Intrinsic::nvvm_ex2_approx_ftz
+ : Intrinsic::nvvm_ex2_approx;
+ Rep = Builder.CreateUnaryIntrinsic(IID, CI->getArgOperand(0));
} else if (Name.starts_with("atomic.load.add.f32.p") ||
Name.starts_with("atomic.load.add.f64.p")) {
Value *Ptr = CI->getArgOperand(0);
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index c923f0ec907e7..8ff6cae94dd4b 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1605,12 +1605,17 @@ def : Pat<(int_nvvm_saturate_d f64:$a), (CVT_f64_f64 $a, CvtSAT)>;
// Exp2 Log2
//
-def : Pat<(int_nvvm_ex2_approx_ftz_f f32:$a), (EX2_APPROX_f32 $a, FTZ)>;
-def : Pat<(int_nvvm_ex2_approx_f f32:$a), (EX2_APPROX_f32 $a, NoFTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx_ftz f32:$a)), (EX2_APPROX_f32 $a, FTZ)>;
+def : Pat<(f32 (int_nvvm_ex2_approx f32:$a)), (EX2_APPROX_f32 $a, NoFTZ)>;
let Predicates = [hasPTX<70>, hasSM<75>] in {
- def : Pat<(int_nvvm_ex2_approx_f16 f16:$a), (EX2_APPROX_f16 $a)>;
- def : Pat<(int_nvvm_ex2_approx_f16x2 v2f16:$a), (EX2_APPROX_f16x2 $a)>;
+ def : Pat<(f16 (int_nvvm_ex2_approx f16:$a)), (EX2_APPROX_f16 $a)>;
+ def : Pat<(v2f16 (int_nvvm_ex2_approx v2f16:$a)), (EX2_APPROX_f16x2 $a)>;
+}
+
+let Predicates = [hasPTX<78>, hasSM<90>] in {
+ def : Pat<(bf16 (int_nvvm_ex2_approx_ftz bf16:$a)), (EX2_APPROX_bf16 $a)>;
+ def : Pat<(v2bf16 (int_nvvm_ex2_approx_ftz v2bf16:$a)), (EX2_APPROX_bf16x2 $a)>;
}
def LG2_APPROX_f32 :
diff --git a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
index 4029e143ae2a4..37f53f5b6f0a2 100644
--- a/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.cpp
@@ -318,7 +318,7 @@ static Instruction *convertNvvmIntrinsicToLlvm(InstCombiner &IC,
// answer. These include:
//
// - nvvm_cos_approx_{f,ftz_f}
- // - nvvm_ex2_approx_{d,f,ftz_f}
+ // - nvvm_ex2_approx(_ftz)
// - nvvm_lg2_approx_{d,f,ftz_f}
// - nvvm_sin_approx_{f,ftz_f}
// - nvvm_sqrt_approx_{f,ftz_f}
diff --git a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll
index 362586af4f9b7..4fc506f1f5edf 100644
--- a/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll
+++ b/llvm/test/Assembler/auto_upgrade_nvvm_intrinsics.ll
@@ -87,6 +87,11 @@ declare void @llvm.nvvm.barrier(i32, i32)
declare void @llvm.nvvm.barrier.sync(i32)
declare void @llvm.nvvm.barrier.sync.cnt(i32, i32)
+declare float @llvm.nvvm.ex2.approx.f(float)
+declare double @llvm.nvvm.ex2.approx.d(double)
+declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
+declare float @llvm.nvvm.ex2.approx.ftz.f(float)
+
; CHECK-LABEL: @simple_upgrade
define void @simple_upgrade(i32 %a, i64 %b, i16 %c) {
; CHECK: call i32 @llvm.bitreverse.i32(i32 %a)
@@ -355,3 +360,15 @@ define void @cta_barriers(i32 %x, i32 %y) {
call void @llvm.nvvm.barrier.sync.cnt(i32 %x, i32 %y)
ret void
}
+
+define void @nvvm_ex2_approx(float %a, double %b, half %c, <2 x half> %d) {
+; CHECK: call float @llvm.nvvm.ex2.approx.f32(float %a)
+; CHECK: call double @llvm.nvvm.ex2.approx.f64(double %b)
+; CHECK: call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %d)
+; CHECK: call float @llvm.nvvm.ex2.approx.ftz.f32(float %a)
+ %r1 = call float @llvm.nvvm.ex2.approx.f(float %a)
+ %r2 = call double @llvm.nvvm.ex2.approx.d(double %b)
+ %r3 = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %d)
+ %r4 = call float @llvm.nvvm.ex2.approx.ftz.f(float %a)
+ ret void
+}
diff --git a/llvm/test/CodeGen/NVPTX/f16-ex2.ll b/llvm/test/CodeGen/NVPTX/f16-ex2.ll
index ee79f9d6d056f..af3fe67269205 100644
--- a/llvm/test/CodeGen/NVPTX/f16-ex2.ll
+++ b/llvm/test/CodeGen/NVPTX/f16-ex2.ll
@@ -1,12 +1,13 @@
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
-; RUN: llc < %s -mcpu=sm_75 -mattr=+ptx70 | FileCheck --check-prefixes=CHECK-FP16 %s
-; RUN: %if ptxas-sm_75 && ptxas-isa-7.0 %{ llc < %s -mcpu=sm_75 -mattr=+ptx70 | %ptxas-verify -arch=sm_75 %}
+; RUN: llc < %s -mcpu=sm_90 -mattr=+ptx78 | FileCheck --check-prefixes=CHECK-FP16 %s
+; RUN: %if ptxas-sm_90 && ptxas-isa-7.8 %{ llc < %s -mcpu=sm_90 -mattr=+ptx78 | %ptxas-verify -arch=sm_90 %}
target triple = "nvptx64-nvidia-cuda"
declare half @llvm.nvvm.ex2.approx.f16(half)
-declare <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half>)
+declare <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half>)
+declare bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat)
+declare <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat>)
-; CHECK-LABEL: ex2_half
define half @ex2_half(half %0) {
; CHECK-FP16-LABEL: ex2_half(
; CHECK-FP16: {
@@ -21,7 +22,6 @@ define half @ex2_half(half %0) {
ret half %res
}
-; CHECK-LABEL: ex2_2xhalf
define <2 x half> @ex2_2xhalf(<2 x half> %0) {
; CHECK-FP16-LABEL: ex2_2xhalf(
; CHECK-FP16: {
@@ -32,6 +32,34 @@ define <2 x half> @ex2_2xhalf(<2 x half> %0) {
; CHECK-FP16-NEXT: ex2.approx.f16x2 %r2, %r1;
; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-FP16-NEXT: ret;
- %res = call <2 x half> @llvm.nvvm.ex2.approx.f16x2(<2 x half> %0)
+ %res = call <2 x half> @llvm.nvvm.ex2.approx.v2f16(<2 x half> %0)
ret <2 x half> %res
}
+
+define bfloat @ex2_bfloat(bfloat %0) {
+; CHECK-FP16-LABEL: ex2_bfloat(
+; CHECK-FP16: {
+; CHECK-FP16-NEXT: .reg .b16 %rs<3>;
+; CHECK-FP16-EMPTY:
+; CHECK-FP16-NEXT: // %bb.0:
+; CHECK-FP16-NEXT: ld.param.b16 %rs1, [ex2_bfloat_param_0];
+; CHECK-FP16-NEXT: ex2.approx.ftz.bf16 %rs2, %rs1;
+; CHECK-FP16-NEXT: st.param.b16 [func_retval0], %rs2;
+; CHECK-FP16-NEXT: ret;
+ %res = call bfloat @llvm.nvvm.ex2.approx.ftz.bf16(bfloat %0)
+ ret bfloat %res
+}
+
+define <2 x bfloat> @ex2_2xbfloat(<2 x bfloat> %0) {
+; CHECK-FP16-LABEL: ex2_2xbfloat(
+; CHECK-FP16: {
+; CHECK-FP16-NEXT: .reg .b32 %r<3>;
+; CHECK-FP16-EMPTY:
+; CHECK-FP16-NEXT: // %bb.0:
+; CHECK-FP16-NEXT: ld.param.b32 %r1, [ex2_2xbfloat_param_0];
+; CHECK-FP16-NEXT: ex2.approx.ftz.bf16x2 %r2, %r1;
+; CHECK-FP16-NEXT: st.param.b32 [func_retval0], %r2;
+; CHECK-FP16-NEXT: ret;
+ %res = call <2 x bfloat> @llvm.nvvm.ex2.approx.ftz.v2bf16(<2 x bfloat> %0)
+ ret <2 x bfloat> %res
+}
diff --git a/llvm/test/CodeGen/NVPTX/f32-ex2.ll b/llvm/test/CodeGen/NVPTX/f32-ex2.ll
index 796d80d3c2c39..97b9d35be371e 100644
--- a/llvm/test/CodeGen/NVPTX/f32-ex2.ll
+++ b/llvm/test/CodeGen/NVPTX/f32-ex2.ll
@@ -3,7 +3,8 @@
; RUN: %if ptxas-sm_50 && ptxas-isa-3.2 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_50 -mattr=+ptx32 | %ptxas-verify -arch=sm_50 %}
target triple = "nvptx-nvidia-cuda"
-declare float @llvm.nvvm.ex2.approx.f(float)
+declare float @llvm.nvvm.ex2.approx.f32(float)
+declare float @llvm.nvvm.ex2.approx.ftz.f32(float)
; CHECK-LABEL: ex2_float
define float @ex2_float(float %0) {
@@ -16,7 +17,7 @@ define float @ex2_float(float %0) {
; CHECK-NEXT: ex2.approx.f32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
- %res = call float @llvm.nvvm.ex2.approx.f(float %0)
+ %res = call float @llvm.nvvm.ex2.approx.f32(float %0)
ret float %res
}
@@ -31,6 +32,6 @@ define float @ex2_float_ftz(float %0) {
; CHECK-NEXT: ex2.approx.ftz.f32 %r2, %r1;
; CHECK-NEXT: st.param.b32 [func_retval0], %r2;
; CHECK-NEXT: ret;
- %res = call float @llvm.nvvm.ex2.approx.ftz.f(float %0)
+ %res = call float @llvm.nvvm.ex2.approx.ftz.f32(float %0)
ret float %res
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM with a question.
No description provided.