diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 99ec96869a7..60742a8e696 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,8 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMD_MFMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

+// WMMA is only properly supported on RDNA4, not RDNA3
+// RDNA3 has incomplete int8 WMMA support and should use fallback path
 #if defined(GGML_USE_HIP) && defined(RDNA4)
 #define AMD_WMMA_AVAILABLE
 #endif // defined(GGML_USE_HIP) && defined(RDNA4)
@@ -288,6 +290,11 @@ static bool amd_mfma_available(const int cc) {
 }

 static bool amd_wmma_available(const int cc) {
+    return GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
+}
+
+// Integer WMMA is only available on RDNA4
+static bool amd_wmma_int_available(const int cc) {
     return GGML_CUDA_CC_IS_RDNA4(cc);
 }

diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh
index caa08b360b5..a24a317c67c 100644
--- a/ggml/src/ggml-cuda/mma.cuh
+++ b/ggml/src/ggml-cuda/mma.cuh
@@ -152,15 +152,26 @@ namespace ggml_cuda_mma {
 #elif defined(AMD_WMMA_AVAILABLE)
 #if defined(RDNA4)
         static constexpr int ne = I * J / 32;
+#else
+        // RDNA3: Accumulator uses same ne, but input tiles need more VGPRs
+        static constexpr int ne = (I == 16 && J == 16) ? (I * J / 32) : (I * J / 16);
+#endif
         T x[ne] = {0};

         static constexpr __device__ bool supported() {
+            // Integer and FP16 WMMA are supported on RDNA3/RDNA4
+            if (I == 16 && J == 4) return true;
+            if (I == 16 && J == 8) return true;
             if (I == 16 && J == 16) return true;
             return false;
         }

         static __device__ __forceinline__ int get_i(const int l) {
-            if constexpr (I == 16 && J == 16) {
+            if constexpr (I == 16 && J == 4) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 8) {
+                return threadIdx.x % 16;
+            } else if constexpr (I == 16 && J == 16) {
                 return 8 * (threadIdx.x / 16) + l;
             } else {
                 NO_DEVICE_CODE;
@@ -169,14 +180,25 @@ namespace ggml_cuda_mma {
         }

         static __device__ __forceinline__ int get_j(const int l) {
-            if constexpr (I == 16 && J == 16) {
+            if constexpr (I == 16 && J == 4) {
+#if defined(RDNA4)
+                return 2 * (threadIdx.x / 16) + l;
+#else
+                return l; // RDNA3: ne=4, l ranges 0-3, all threads load full J dimension
+#endif
+            } else if constexpr (I == 16 && J == 8) {
+#if defined(RDNA4)
+                return 4 * (threadIdx.x / 16) + l;
+#else
+                return l; // RDNA3: ne=8, l ranges 0-7, all threads load full J dimension
+#endif
+            } else if constexpr (I == 16 && J == 16) {
                 return threadIdx.x % 16;
             } else {
                 NO_DEVICE_CODE;
                 return -1;
             }
         }
-#endif
 #else
         static constexpr int ne = I * J / 32;
         T x[ne] = {0};
@@ -223,7 +245,7 @@ namespace ggml_cuda_mma {
                 return -1;
             }
         }
-#endif // defined(GGML_USE_HIP)
+#endif // defined(AMD_WMMA_AVAILABLE)
     };

     template <int I_, int J_>
@@ -265,7 +287,11 @@ namespace ggml_cuda_mma {
             }
         }
 #elif defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32; // 4 half2 = 8 FP16 for RDNA4
+#else
+        static constexpr int ne = I * J / 16; // 8 half2 = 16 FP16 for RDNA3 (duplicate layout)
+#endif // defined(RDNA4)
         half2 x[ne] = {{0.0f, 0.0f}};

         static constexpr __device__ bool supported() {
@@ -284,7 +310,12 @@ static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
+#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
+#else
+                // RDNA3 has duplicate layout, iterate through full J dimension
+                return l % J;
+#endif
             } else {
                 NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -341,7 +372,11 @@ namespace ggml_cuda_mma {
         static constexpr int J = J_;

 #if defined(AMD_WMMA_AVAILABLE)
-        static constexpr int ne = I * J / 32;
+#if defined(RDNA4)
+        static constexpr int ne = I * J / 32; // 4 bfloat162 = 8 BF16 for RDNA4
+#else
+        static constexpr int ne = I * J / 16; // 8 bfloat162 = 16 BF16 for RDNA3 (duplicate layout)
+#endif // defined(RDNA4)
         nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

         static constexpr __device__ bool supported() {
@@ -360,7 +395,12 @@ static __device__ __forceinline__ int get_j(const int l) {
             if constexpr (I == 16 && J == 8) {
+#if defined(RDNA4)
                 return 4 * (threadIdx.x / 16) + l;
+#else
+                // RDNA3 has duplicate layout, iterate through full J dimension
+                return l % J;
+#endif
             } else {
                 NO_DEVICE_CODE;
                 return -1;
             }
         }
@@ -437,26 +477,17 @@ namespace ggml_cuda_mma {
         xi[0] = xs[0];
     }
 #elif defined(AMD_WMMA_AVAILABLE)
-        if constexpr (I == 16 && J == 4) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
-        }else if constexpr (I == 16 && J == 8) {
-            int64_t * xi = (int64_t *) t.x;
-            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
-            xi[0] = xs[0];
-
-            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
-            xi[1] = xs1[0];
-        }else{
-            NO_DEVICE_CODE;
+        // Use generic load path for all AMD WMMA to ensure correct register mapping
+#pragma unroll
+        for (int l = 0; l < t.ne; ++l) {
+            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
-#else
+#else
 #pragma unroll
         for (int l = 0; l < t.ne; ++l) {
             t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
         }
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // AMD_MFMA_AVAILABLE
     }

     template <int I, int J, typename T>
@@ -738,12 +769,21 @@ namespace ggml_cuda_mma {
                      : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
 #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #elif defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
         const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
         const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else // RDNA3
+        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
+        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
+#endif // defined(RDNA4)
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
@@ -753,16 +793,25 @@ namespace ggml_cuda_mma {
     static __device__ __forceinline__ void mma(
             tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
 #if defined(AMD_WMMA_AVAILABLE)
+#if defined(RDNA4)
         using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
         using floatx8_t = __attribute__((ext_vector_type(8))) float;
         floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
         const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
         const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
         acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
+#else // RDNA3
+        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
+        using floatx8_t = __attribute__((ext_vector_type(8))) float;
+        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
+        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
+        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
+        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
+#endif // defined(RDNA4)
 #else
         GGML_UNUSED_VARS(D, A, B);
         NO_DEVICE_CODE;
-#endif // AMPERE_MMA_AVAILABLE
+#endif // AMD_WMMA_AVAILABLE
     }

     static __device__ __forceinline__ void mma(
@@ -786,15 +835,15 @@ namespace ggml_cuda_mma {
             0, 0, 0);
 #endif // defined(CDNA3)

-#elif defined(AMD_WMMA_AVAILABLE)
-        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
-        int32x2_t * a_vec = (int32x2_t *) A.x;
-        int32x2_t * b_vec = (int32x2_t *) B.x;
-
+#elif defined(AMD_WMMA_INT_AVAILABLE)
         using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
         int32x8_t * acc = (int32x8_t *) D.x;

 #if defined(RDNA4)
+        // RDNA4 uses gfx12-specific intrinsic with int32x2_t
+        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
+        int32x2_t * a_vec = (int32x2_t *) A.x;
+        int32x2_t * b_vec = (int32x2_t *) B.x;
         acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
                 true,
                 a_vec[0],
@@ -813,6 +862,20 @@
                 acc[0],
                 true
         );
+#else // RDNA3
+        // RDNA3 uses generic intrinsic with int32x4_t
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
+        const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+                true,
+                a0,
+                true,
+                b0,
+                acc[0],
+                true
+        );
 #endif // defined(RDNA4)

 #else
@@ -889,14 +952,16 @@

     static __device__ __forceinline__ void mma(
             tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
-#if defined(AMD_WMMA_AVAILABLE)
+#if defined(AMD_WMMA_INT_AVAILABLE)
+        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
+        int32x8_t * acc = (int32x8_t *) D.x;
+
+#if defined(RDNA4)
+        // RDNA4 uses gfx12-specific intrinsic with int32x2_t
         using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
         int32x2_t * a_vec = (int32x2_t *) A.x;
         int32x2_t * b_vec = (int32x2_t *) B.x;

-        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
-        int32x8_t * acc = (int32x8_t *) D.x;
-
         acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
                 true,
                 a_vec[0],
@@ -905,6 +970,21 @@ static __device__ __forceinline__ void mma(
                 acc[0],
                 false
         );
+#else // RDNA3
+        // RDNA3 uses generic intrinsic with int32x4_t
+        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
+        const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
+        const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);
+
+        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
+                true,
+                a0,
+                true,
+                b0,
+                acc[0],
+                false
+        );
+#endif
 #else
     GGML_UNUSED(D);
     GGML_UNUSED(A);
diff --git a/ggml/src/ggml-cuda/mmf.cu b/ggml/src/ggml-cuda/mmf.cu
index 5c51a22256a..74c63b19f23 100644
--- a/ggml/src/ggml-cuda/mmf.cu
+++ b/ggml/src/ggml-cuda/mmf.cu
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
             return false;
         }
     } else {
-        if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
+        if (src1_ncols > 16 || amd_wmma_available(cc)) {
             return false;
         }
     }
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 03ceba874d8..46ab507853e 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -306,11 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return false;
     }

-    if (amd_wmma_available(cc)) {
-        if (GGML_CUDA_CC_IS_RDNA4(cc)) {
-            return true;
-        }
+    // RDNA4 has integer WMMA support, always use MMQ
+    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
+        return true;
     }
+    // RDNA3 uses DP4A path (no integer WMMA), enable for small batches

     return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 99760d56c72..36965fde405 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -256,7 +256,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)

 #if defined(GGML_USE_HIP)
 static int mmq_get_nwarps_host(const int cc, const int warp_size) {
-    return amd_mfma_available(cc) ? 8 : 256/warp_size;
+    return (amd_mfma_available(cc) || amd_wmma_available(cc)) ? 8 : 256/warp_size;
 }
 #else
 static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
             }
         }
     }
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 4, int> tile_B;
     typedef tile<16, 16, int> tile_C;
@@ -1500,7 +1500,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
             }
         }
     }
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 4, int> tile_B;

@@ -1888,7 +1888,7 @@ template static __device__ __forceinline__ void loa
 #else
         int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
         {
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)
             if (need_check) {
                 i = min(i, i_max);
             }
@@ -2313,7 +2313,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
             }
         }
     }
-#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
+#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
     typedef tile<16, 4, int> tile_A;
     typedef tile<16, 4, int> tile_B;
     typedef tile<16, 16, int> tile_C;
@@ -3018,7 +3018,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
 #else
     typedef tile<16, 8, int> tile_C;
     constexpr int rows_per_warp = 2 * granularity;
-#endif // defined(AMD_MFMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)

     constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.
     const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
@@ -3232,7 +3232,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
 #else
     constexpr vec_dot_mmq_t vec_dot = mmq_type_traits::vec_dot_dp4a;
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a;
-#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)

     constexpr int blocks_per_iter = MMQ_ITER_K / qk;
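For reference, a minimal standalone sketch (not part of the patch) of the operand widths the RDNA4/RDNA3 branches above select between: the gfx12 (RDNA4) FP16 WMMA builtin consumes 8 halves per lane for A and B, while the gfx11 (RDNA3) builtin consumes 16 halves per lane with the data replicated across both halves of the wave32; the accumulator is 8 floats per lane in either case. The wrapper name, the vector typedefs and the __gfx*__ target guards below are illustrative assumptions only; ggml keys the same split off its own RDNA3/RDNA4 defines.

// Sketch: FP16 WMMA with FP32 accumulate, wave32 (hipcc, gfx11xx / gfx12xx targets).
#include <hip/hip_runtime.h>

typedef _Float16 half8_t  __attribute__((ext_vector_type(8)));
typedef _Float16 half16_t __attribute__((ext_vector_type(16)));
typedef float    float8_t __attribute__((ext_vector_type(8)));

#if defined(__gfx1200__) || defined(__gfx1201__)
// RDNA4: 8 FP16 per lane per input fragment.
__device__ float8_t wmma_f16_f32(half8_t a, half8_t b, float8_t acc) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a, b, acc);
}
#elif defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
// RDNA3: 16 FP16 per lane per input fragment (duplicated layout), same 8-float accumulator.
__device__ float8_t wmma_f16_f32(half16_t a, half16_t b, float8_t acc) {
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, acc);
}
#endif

The int8 builtins used by the MMQ path differ in the same way (__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32 takes 4 packed int32 per lane for A/B on gfx11 versus 2 for the _gfx12 form), which is the distinction the new AMD_WMMA_INT_AVAILABLE / amd_wmma_int_available() gating draws.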