HIP: Add RDNA3 WMMA support to MMF #17495
base: master
Changes from all commits
2f7cfcf
d564a35
0ec241d
bbee5fe
6b8ceeb
fd18344
7a09e22
c65dd59
48a53b5
b7c13ee
a0aa491
8c2f9a3
7a88d7c
cfc149a
59a012f
6802fbf
facded5
28b5e45
ba25661
edb86ef
5a80fb4
8aed111
9191856
c692629
9c9b0ea
816cb1d
@@ -152,15 +152,26 @@ namespace ggml_cuda_mma {
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        static constexpr int ne = I * J / 32;
#else
        // RDNA3: Accumulator uses same ne, but input tiles need more VGPRs
        static constexpr int ne = (I == 16 && J == 16) ? (I * J / 32) : (I * J / 16);
#endif
Collaborator: Please annotate all `#endif` directives.
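The attached suggestion is presumably to annotate the new `#endif` with its opening condition, along these lines:

```cpp
#if defined(RDNA4)
        static constexpr int ne = I * J / 32;
#else
        // RDNA3: Accumulator uses same ne, but input tiles need more VGPRs
        static constexpr int ne = (I == 16 && J == 16) ? (I * J / 32) : (I * J / 16);
#endif // defined(RDNA4)
```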
        T x[ne] = {0};
        static constexpr __device__ bool supported() {
            // Integer and FP16 WMMA are supported on RDNA3/RDNA4
            if (I == 16 && J == 4) return true;
            if (I == 16 && J == 8) return true;
            if (I == 16 && J == 16) return true;
            return false;
        }
        static __device__ __forceinline__ int get_i(const int l) {
            if constexpr (I == 16 && J == 16) {
            if constexpr (I == 16 && J == 4) {
                return threadIdx.x % 16;
            } else if constexpr (I == 16 && J == 8) {
                return threadIdx.x % 16;
            } else if constexpr (I == 16 && J == 16) {
                return 8 * (threadIdx.x / 16) + l;
            } else {
                NO_DEVICE_CODE;
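For reference, a quick host-side sketch of how the per-lane element count works out for the tile shapes involved (helper names are made up; the expressions mirror the ones above):

```cpp
// Elements per lane for a 32-wide wave; values follow the expressions in the diff.
constexpr int ne_rdna4(int I, int J) { return I * J / 32; }                       // distributed layout
constexpr int ne_rdna3(int I, int J) {
    return (I == 16 && J == 16) ? I * J / 32                                      // accumulator: same as RDNA4
                                : I * J / 16;                                     // input tiles: duplicated layout
}
static_assert(ne_rdna4(16, 16) == 8 && ne_rdna3(16, 16) == 8, "16x16 accumulator layout matches");
static_assert(ne_rdna4(16,  8) == 4 && ne_rdna3(16,  8) == 8, "16x8 input tile needs 2x the VGPRs on RDNA3");
static_assert(ne_rdna4(16,  4) == 2 && ne_rdna3(16,  4) == 4, "16x4 input tile needs 2x the VGPRs on RDNA3");
```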
@@ -169,14 +180,25 @@ namespace ggml_cuda_mma {
            }
        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 16) {
            if constexpr (I == 16 && J == 4) {
#if defined(RDNA4)
                return 2 * (threadIdx.x / 16) + l;
#else
                return l; // RDNA3: ne=4, l ranges 0-3, all threads load full J dimension
#endif
            } else if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
                return 4 * (threadIdx.x / 16) + l;
#else
                return l; // RDNA3: ne=8, l ranges 0-7, all threads load full J dimension
#endif
            } else if constexpr (I == 16 && J == 16) {
                return threadIdx.x % 16;
            } else {
                NO_DEVICE_CODE;
                return -1;
            }
        }
#endif
#else
        static constexpr int ne = I * J / 32;
        T x[ne] = {0};
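A rough way to read the new mapping: on RDNA4 each 16-lane half of the wave owns half of the J columns, while on RDNA3 every lane walks the full J range (duplicated layout). A host-side sketch for the 16x8 input tile, with `lane` standing in for `threadIdx.x` (illustrative only):

```cpp
#include <cstdio>

// 16x8 input tile on a 32-wide wave.
int get_j_rdna4(int lane, int l) { return 4 * (lane / 16) + l; } // l in [0,4): each half-wave covers columns 0-3 or 4-7
int get_j_rdna3(int lane, int l) { return l; }                   // l in [0,8): every lane covers columns 0-7

int main() {
    printf("RDNA4: lane 0 -> j=%d..%d, lane 16 -> j=%d..%d\n",
           get_j_rdna4(0, 0), get_j_rdna4(0, 3), get_j_rdna4(16, 0), get_j_rdna4(16, 3));
    printf("RDNA3: lane 0 -> j=%d..%d, lane 16 -> j=%d..%d\n",
           get_j_rdna3(0, 0), get_j_rdna3(0, 7), get_j_rdna3(16, 0), get_j_rdna3(16, 7));
    return 0;
}
```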
@@ -223,7 +245,7 @@ namespace ggml_cuda_mma {
                return -1;
            }
        }
#endif // defined(GGML_USE_HIP)
#endif // defined(AMD_WMMA_AVAILABLE)
    };
    template <int I_, int J_>
@@ -265,7 +287,11 @@ namespace ggml_cuda_mma {
            }
        }
#elif defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
#if defined(RDNA4)
        static constexpr int ne = I * J / 32; // 4 half2 = 8 FP16 for RDNA4
#else
        static constexpr int ne = I * J / 16; // 8 half2 = 16 FP16 for RDNA3 (duplicate layout)
#endif // defined(RDNA4)
        half2 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
@@ -284,7 +310,12 @@ namespace ggml_cuda_mma {

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
                return 4 * (threadIdx.x / 16) + l;
#else
                // RDNA3 has duplicate layout, iterate through full J dimension
                return l % J;
#endif
            } else {
                NO_DEVICE_CODE;
                return -1;
@@ -341,7 +372,11 @@ namespace ggml_cuda_mma {
        static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
        static constexpr int ne = I * J / 32;
#if defined(RDNA4)
        static constexpr int ne = I * J / 32; // 4 bfloat162 = 8 BF16 for RDNA4
#else
        static constexpr int ne = I * J / 16; // 8 bfloat162 = 16 BF16 for RDNA3 (duplicate layout)
#endif // defined(RDNA4)
        nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

        static constexpr __device__ bool supported() {
@@ -360,7 +395,12 @@ namespace ggml_cuda_mma {

        static __device__ __forceinline__ int get_j(const int l) {
            if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
                return 4 * (threadIdx.x / 16) + l;
#else
                // RDNA3 has duplicate layout, iterate through full J dimension
                return l % J;
#endif
            } else {
                NO_DEVICE_CODE;
                return -1;
@@ -437,26 +477,17 @@ namespace ggml_cuda_mma {
            xi[0] = xs[0];
        }
#elif defined(AMD_WMMA_AVAILABLE)
        if constexpr (I == 16 && J == 4) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
            xi[0] = xs[0];
        }else if constexpr (I == 16 && J == 8) {
            int64_t * xi = (int64_t *) t.x;
            const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
            xi[0] = xs[0];

            const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
            xi[1] = xs1[0];
        }else{
            NO_DEVICE_CODE;
        // Use generic load path for all AMD WMMA to ensure correct register mapping
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
Collaborator (on lines -440 to +483): Don't just remove the preexisting code.
        }
#else
#else
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
#endif // defined(AMD_MFMA_AVAILABLE)
#endif // AMD_MFMA_AVAILABLE
    }
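One way to take the review comment above into account without losing the RDNA4 fast path would be to keep the vectorized 64-bit loads under `#if defined(RDNA4)` and use the element-wise loop only where the RDNA3 duplicated layout requires it. A sketch, not the code in this PR (the function name is made up; `tile<>` and `get_i()`/`get_j()` are the ones from this file):

```cpp
template <int I, int J, typename T>
static __device__ __forceinline__ void load_wmma_sketch(
        tile<I, J, T> & t, const T * __restrict__ xs0, const int stride) {
#if defined(RDNA4)
    // RDNA4: keep the preexisting vectorized 64-bit load for the 16x4 int tile ...
    if constexpr (I == 16 && J == 4) {
        int64_t       * xi = (int64_t *) t.x;
        const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
    } else {
        // ... and fall back to the per-element loop for the other shapes.
#pragma unroll
        for (int l = 0; l < t.ne; ++l) {
            t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
        }
    }
#else
    // RDNA3: the duplicated input layout makes the generic per-element loop the simple, correct choice.
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }
#endif // defined(RDNA4)
}
```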
    template <typename T>
@@ -738,12 +769,21 @@ namespace ggml_cuda_mma {
                 : "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
        const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
        using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
        const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
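The two branches differ only in how many scalars each lane feeds to the builtin: 8 halves per lane on RDNA4 versus 16 on RDNA3, where the A/B operands are duplicated across the two wave halves. A condensed sketch of the same dispatch (wrapper name is made up; the BF16 overload below is analogous with `__bf16` vectors):

```cpp
using halfx8_t  = __attribute__((ext_vector_type(8)))  _Float16;
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8)))  float;

// 16x16x16 FP16 WMMA with FP32 accumulation, per 32-wide wave.
static __device__ __forceinline__ floatx8_t wmma_f16_sketch(const void * A, const void * B, floatx8_t acc) {
#if defined(RDNA4)
    const halfx8_t  a = *(const halfx8_t  *) A;  // 8 halves per lane
    const halfx8_t  b = *(const halfx8_t  *) B;
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a, b, acc);
#else
    const halfx16_t a = *(const halfx16_t *) A;  // 16 halves per lane (duplicated layout)
    const halfx16_t b = *(const halfx16_t *) B;
    return __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a, b, acc);
#endif // defined(RDNA4)
}
```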
@@ -753,16 +793,25 @@ namespace ggml_cuda_mma {
    static __device__ __forceinline__ void mma(
            tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
        using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
        const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
        using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
        using floatx8_t = __attribute__((ext_vector_type(8))) float;
        floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
        const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
        const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
        acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
        GGML_UNUSED_VARS(D, A, B);
        NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // AMD_WMMA_AVAILABLE
    }

    static __device__ __forceinline__ void mma(
@@ -786,15 +835,15 @@ namespace ggml_cuda_mma {
            0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)
        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
        int32x2_t * a_vec = (int32x2_t *) A.x;
        int32x2_t * b_vec = (int32x2_t *) B.x;

#elif defined(AMD_WMMA_INT_AVAILABLE)
Collaborator: You forgot to add …
        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
        // RDNA4 uses gfx12-specific intrinsic with int32x2_t
        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
        int32x2_t * a_vec = (int32x2_t *) A.x;
        int32x2_t * b_vec = (int32x2_t *) B.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            true,
@@ -813,6 +862,20 @@ namespace ggml_cuda_mma {
            acc[0],
            true
        );
#else // RDNA3
        // RDNA3 uses generic intrinsic with int32x4_t
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
        const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            true,
            a0,
            true,
            b0,
            acc[0],
            true
        );
#endif // defined(RDNA4)

#else
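The int8 path follows the same pattern: the gfx12 builtin takes two packed ints (8 int8 values) per lane, while the gfx11 builtin takes four packed ints (16 int8 values) because of the duplicated input layout. A condensed sketch (wrapper name is made up; the sign/clamp flags follow the calls above):

```cpp
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;

// 16x16x16 iu8 WMMA with int32 accumulation, per 32-wide wave.
static __device__ __forceinline__ int32x8_t wmma_iu8_sketch(const int * A, const int * B, int32x8_t acc) {
#if defined(RDNA4)
    const int32x2_t a = *(const int32x2_t *) A;  // 8 int8 values per lane
    const int32x2_t b = *(const int32x2_t *) B;
    return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(true, a, true, b, acc, true);
#else
    const int32x4_t a = *(const int32x4_t *) A;  // 16 int8 values per lane (duplicated layout)
    const int32x4_t b = *(const int32x4_t *) B;
    return __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(true, a, true, b, acc, true);
#endif // defined(RDNA4)
}
```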
@@ -889,14 +952,16 @@ namespace ggml_cuda_mma {

    static __device__ __forceinline__ void mma(
            tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_INT_AVAILABLE)
        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
        // RDNA4 uses gfx12-specific intrinsic with int32x2_t
        using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
        int32x2_t * a_vec = (int32x2_t *) A.x;
        int32x2_t * b_vec = (int32x2_t *) B.x;

        using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
        int32x8_t * acc = (int32x8_t *) D.x;

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
            true,
            a_vec[0],
@@ -905,6 +970,21 @@ static __device__ __forceinline__ void mma(
            acc[0],
            false
        );
#else // RDNA3
        // RDNA3 uses generic intrinsic with int32x4_t
        using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
        const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
        const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);

        acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
            true,
            a0,
            true,
            b0,
            acc[0],
            false
        );
#endif
#else
        GGML_UNUSED(D);
        GGML_UNUSED(A);
@@ -256,7 +256,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)

#if defined(GGML_USE_HIP)
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
    return amd_mfma_available(cc) ? 8 : 256/warp_size;
    return (amd_mfma_available(cc) || amd_wmma_available(cc)) ? 8 : 256/warp_size;
Collaborator: Why are you changing this?
}
#else
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
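For context on the review question: assuming the 32-wide waves ggml uses on RDNA, `256/warp_size` already evaluates to 8, so adding `amd_wmma_available(cc)` does not change the result there; it only matters if the warp count is meant to be pinned to 8 independently of wave size. A quick check of the arithmetic (the bools stand in for `amd_mfma_available(cc)` / `amd_wmma_available(cc)`; the warp sizes are assumptions):

```cpp
#include <cassert>

// Mirrors the changed helper above, with availability checks replaced by bools.
static int nwarps_sketch(bool mfma, bool wmma, int warp_size) {
    return (mfma || wmma) ? 8 : 256 / warp_size;
}

int main() {
    assert(nwarps_sketch(false, true,  32) == 8); // RDNA3/RDNA4 with WMMA
    assert(nwarps_sketch(false, false, 32) == 8); // RDNA without WMMA: same value either way
    assert(nwarps_sketch(true,  false, 64) == 8); // CDNA with MFMA
    assert(nwarps_sketch(false, false, 64) == 4); // GCN/CDNA fallback
    return 0;
}
```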
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 4, int> tile_B;
    typedef tile<16, 16, int> tile_C;
@@ -1500,7 +1500,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles

    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 4, int> tile_B;
@@ -1888,7 +1888,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#else
        int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
        {
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)
            if (need_check) {
                i = min(i, i_max);
            }
@@ -2313,7 +2313,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
            }
        }
    }
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
    typedef tile<16, 4, int> tile_A;
    typedef tile<16, 4, int> tile_B;
    typedef tile<16, 16, int> tile_C;
@@ -3018,7 +3018,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
#else
    typedef tile<16, 8, int> tile_C;
    constexpr int rows_per_warp = 2 * granularity;
#endif // defined(AMD_MFMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)
    constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

    const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
@@ -3232,7 +3232,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
#else
    constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
    constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)

    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
As I said before, use the term "not implemented" instead of "not supported" when talking about RDNA3 int8 WMMA in the context of llama.cpp/ggml.