|
4 | 4 | #include <hip/hip_runtime.h>
|
5 | 5 | #include <hipblas/hipblas.h>
|
6 | 6 | #include <hip/hip_fp16.h>
|
7 |
| -#include <hip/hip_bfloat16.h> |
| 7 | +#include <hip/hip_bf16.h> |
8 | 8 |
|
9 | 9 | #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
10 | 10 | #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
|
135 | 135 | #define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
136 | 136 | #define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
137 | 137 |
|
138 |
| -#if HIP_VERSION >= 70000000 |
| 138 | +#if HIP_VERSION >= 60500000 |
139 | 139 | #define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
140 | 140 | #define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
141 | 141 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
|
147 | 147 | #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
148 | 148 | #define cublasComputeType_t hipblasDatatype_t
|
149 | 149 | #define cudaDataType_t hipblasDatatype_t
|
150 |
| -#endif // HIP_VERSION >= 7000000 |
| 150 | +#endif // HIP_VERSION >= 60500000 |
151 | 151 |
|
152 | 152 | #if !defined(__HIP_PLATFORM_AMD__)
|
153 | 153 | #error "The HIP backend supports only AMD targets"
|
|
179 | 179 | #define RDNA4
|
180 | 180 | #endif
|
181 | 181 |
|
182 |
| -#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \ |
183 |
| - defined(__gfx1150__) || defined(__gfx1151__) |
| 182 | +#if defined(__GFX11__) |
184 | 183 | #define RDNA3
|
185 | 184 | #endif
|
186 | 185 |
|
|
197 | 196 | #define __has_builtin(x) 0
|
198 | 197 | #endif
|
199 | 198 |
|
200 |
| -typedef hip_bfloat16 nv_bfloat16; |
201 |
| -typedef short2 nv_bfloat162; // FIXME there is no 2x BF16 type being defined in bfloat16.h, ad-hoc compilation fix |
| 199 | +typedef __hip_bfloat16 nv_bfloat16; |
| 200 | +typedef __hip_bfloat162 nv_bfloat162; |
202 | 201 |
|
203 | 202 | typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
204 | 203 | typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
|
0 commit comments