Skip to content

Commit 5c57184

Browse files
committed
HIP: Cleanup hipification header
Switch over to hip_bf16 from legacy hip_bfloat16 Simplify RDNA3 define Reduce swap over of new hipblas api to rocm 6.5 as this version is used for rocm 7.0 previews
1 parent 44c8947 commit 5c57184

File tree

1 file changed

+6
-7
lines changed
  • ggml/src/ggml-cuda/vendors

1 file changed

+6
-7
lines changed

ggml/src/ggml-cuda/vendors/hip.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <hip/hip_runtime.h>
55
#include <hipblas/hipblas.h>
66
#include <hip/hip_fp16.h>
7-
#include <hip/hip_bfloat16.h>
7+
#include <hip/hip_bf16.h>
88

99
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
1010
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
@@ -135,7 +135,7 @@
135135
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
136136
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
137137

138-
#if HIP_VERSION >= 70000000
138+
#if HIP_VERSION >= 60500000
139139
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
140140
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
141141
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
@@ -147,7 +147,7 @@
147147
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
148148
#define cublasComputeType_t hipblasDatatype_t
149149
#define cudaDataType_t hipblasDatatype_t
150-
#endif // HIP_VERSION >= 7000000
150+
#endif // HIP_VERSION >= 6050000
151151

152152
#if !defined(__HIP_PLATFORM_AMD__)
153153
#error "The HIP backend supports only AMD targets"
@@ -179,8 +179,7 @@
179179
#define RDNA4
180180
#endif
181181

182-
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
183-
defined(__gfx1150__) || defined(__gfx1151__)
182+
#if defined(__GFX11__)
184183
#define RDNA3
185184
#endif
186185

@@ -197,8 +196,8 @@
197196
#define __has_builtin(x) 0
198197
#endif
199198

200-
typedef hip_bfloat16 nv_bfloat16;
201-
typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix
199+
typedef __hip_bfloat16 nv_bfloat16;
200+
typedef __hip_bfloat162 nv_bfloat162;
202201

203202
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
204203
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));

0 commit comments

Comments
 (0)