Commit c6f7a42
[MUSA] enable fp16/fast_fp16/bf16_mma on PH1 (#17551)
* [MUSA] enable fp16/fast_fp16/bf16_mma on PH1

  Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-vec.cuh

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Update ggml/src/ggml-cuda/fattn-tile.cuh

  Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

* Address review comments

  Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>

---------

Signed-off-by: Xiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
1 parent 2e7ef98 commit c6f7a42

File tree

5 files changed: +21 -14 lines changed


ggml/src/ggml-cuda/common.cuh

Lines changed: 12 additions & 8 deletions
@@ -84,12 +84,12 @@
 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
-#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
+#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000

 #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
 #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
-#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
-#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
+#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
+#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)

 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
 # define GGML_CUDA_USE_CUB

@@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
 #define GGML_USE_VMM
 #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))

-#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
-#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
+#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL

 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE

@@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
 #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)

 static bool fp16_available(const int cc) {
-    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
 }

 static bool fast_fp16_available(const int cc) {
     return GGML_CUDA_CC_IS_AMD(cc) ||
-        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
+        (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
 }

 // To be used for feature selection of external libraries, e.g. cuBLAS.

@@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
 }

 static bool bf16_mma_hardware_available(const int cc) {
-    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
+    return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
+        GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
+        (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
 }

 static bool fp32_mma_hardware_available(const int cc) {
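For context, a minimal sketch (not part of this commit) of how a host-side caller might consume the updated helpers, with the MTHREADS PH1 (MTT S5000) paths now enabled. The launch_* functions are hypothetical placeholders; only the availability helpers come from common.cuh above.

    // Hedged sketch: hypothetical dispatch over the updated availability helpers.
    void launch_matmul_bf16_mma(); // placeholder, not a ggml function
    void launch_matmul_fp16();     // placeholder
    void launch_matmul_fp32();     // placeholder

    static void dispatch_matmul(const int cc) {
        if (bf16_mma_hardware_available(cc)) {
            launch_matmul_bf16_mma(); // Ampere+, CDNA, RDNA3+, or MTHREADS PH1
        } else if (fast_fp16_available(cc)) {
            launch_matmul_fp16();     // MTHREADS devices now qualify when fp16_available(cc) holds
        } else {
            launch_matmul_fp32();
        }
    }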

ggml/src/ggml-cuda/cpy.cu

Lines changed: 4 additions & 1 deletion
@@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
         }
     }
 }
+
+    GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
+                     nb12, nb13);
 }

 static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {

@@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
         ne00n = ne00;
         ne01n = ne01;
         ne02n = ne02;
-    } else if (nb00 > nb02) {
+    } else {
         ne00n = ne00;
         ne01n = ne01*ne02;
         ne02n = 1;
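The GGML_UNUSED_VARS addition presumably silences unused-parameter diagnostics for kernel arguments that are not referenced on every path, on toolchains (such as MUSA's) that warn about them. A hedged sketch of the general idiom, with a stand-in macro rather than ggml's actual definition:

    // Hedged sketch of the usual unused-variable suppression idiom.
    // MY_UNUSED is a placeholder, not ggml's GGML_UNUSED_VARS.
    #define MY_UNUSED(x) (void)(x)

    __global__ void copy_kernel(const char * src, char * dst, const int nb12, const int nb13) {
        // nb12/nb13 stay in the signature for interface symmetry but are unused here.
        MY_UNUSED(nb12);
        MY_UNUSED(nb13);
        dst[threadIdx.x] = src[threadIdx.x];
    }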

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
609609
float KQ_sum_add = 0.0f;
610610
#pragma unroll
611611
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
612-
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
612+
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
613613
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
614614
KQ_sum_add += val;
615615
tmp[i0/(np*warp_size)][jc1] = val;
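The only change here is the static_cast<uint32_t>: threadIdx.x is unsigned, so the left-hand side of the comparison is unsigned while k_VKQ_sup is a signed int, which some compilers flag as a sign-compare issue. A minimal illustration of the pattern, assuming (as the kernel does) that the bound is non-negative; the function and names are placeholders:

    // Hedged illustration of the sign-compare fix.
    #include <cstdint>

    __device__ bool lane_in_bounds(const int i0, const int bound /* assumed >= 0 */) {
        // threadIdx.x is unsigned, so the sum on the left is unsigned too;
        // casting the signed bound keeps the comparison in one signedness.
        return i0 + threadIdx.x < static_cast<uint32_t>(bound);
    }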

ggml/src/ggml-cuda/fattn-vec.cuh

Lines changed: 2 additions & 2 deletions
@@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
     for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
         const int i = i0 + threadIdx.x;

-        if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
+        if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
             tmp_q_i32[i] = 0;
         }
     }

@@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
             KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);

-            if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
+            if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
                 KQ_reg[j] = sum;
             }
         }
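Both hunks address the same warning class as in fattn-tile.cuh: sizeof(int) has type size_t, so D/sizeof(int) is an unsigned expression (and threadIdx.x is unsigned), while the loop indices are signed ints. A short sketch of the distinction, using a hypothetical head size D:

    // Hedged sketch of why the int(...) wrappers are needed; D is hypothetical.
    #include <cstddef>

    constexpr int  D   = 128;
    constexpr auto n_u = D / sizeof(int);       // size_t (unsigned): sizeof drives the result type
    constexpr int  n_s = int(D / sizeof(int));  // signed, matches the int loop counter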

ggml/src/ggml-cuda/mma.cuh

Lines changed: 2 additions & 2 deletions
@@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
             : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
             : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
 #else
-        tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
-        tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
+        tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
+        const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
         mma(D16[0], A16[0], B);
         mma(D16[1], A16[1], B);
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
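The fallback path now uses reinterpret_cast and keeps A const-qualified; the earlier C-style cast dropped the const silently. A hedged illustration of that difference, independent of ggml's tile types:

    // Hedged illustration: C-style casts discard const silently, reinterpret_cast does not.
    const float a[2] = {1.0f, 2.0f};
    float * dropped = (float *) a;                              // compiles, const lost silently
    // float * rejected = reinterpret_cast<float *>(a);         // error: cast away constness
    const float * kept = reinterpret_cast<const float *>(a);    // const preserved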
