26 commits
2f7cfcf
mmf for rdna4
Nov 7, 2025
d564a35
align the padding for rdna4
Nov 7, 2025
0ec241d
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 9, 2025
bbee5fe
forbit mul_mat_f for rdna4
Nov 9, 2025
6b8ceeb
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 11, 2025
fd18344
fix as comment
Nov 11, 2025
7a09e22
remove device kernels
Nov 11, 2025
c65dd59
add constexpr for early return
Nov 11, 2025
48a53b5
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 13, 2025
b7c13ee
update based on review comment
Nov 13, 2025
a0aa491
change based on the review comment
Nov 13, 2025
8c2f9a3
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 14, 2025
7a88d7c
pass compile error
Nov 14, 2025
cfc149a
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 14, 2025
59a012f
keep code consistency
Nov 17, 2025
6802fbf
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 19, 2025
facded5
Merge branch 'ggml-org:master' into mmf_wmma_rdna4
zhang-hui-yulo Nov 21, 2025
28b5e45
Attempt to port RDNA4 WMMA optimizations to RDNA3
Nov 21, 2025
ba25661
Merge branch 'master' into mmf_wmma_rdna3
Nov 25, 2025
edb86ef
WMMA RDNA3 fixes
Nov 25, 2025
5a80fb4
fix RDNA3 not using the fast DP4A-based MMQ path. RDNA4 should still …
Nov 25, 2025
8aed111
more fixes for RDNA3 crashes with quantized models
Nov 25, 2025
9191856
Fix RDNA3 WMMA int8 codepath
Nov 27, 2025
c692629
fix endif commets
Nov 27, 2025
9c9b0ea
More fixes on get_i/get_j, add support for more tile sizes. get_j res…
Nov 29, 2025
816cb1d
Disable WMMA for RDNA3, enable only for RDNA4
Nov 29, 2025
7 changes: 7 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -224,6 +224,8 @@ static const char * cu_get_error_str(CUresult err) {
#define AMD_MFMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA)

// WMMA is only properly supported on RDNA4, not RDNA3
// RDNA3 has incomplete int8 WMMA support and should use fallback path
Collaborator commented on lines +227 to +228:
As I said before, use the term "not implemented" instead of "not supported" when talking about RDNA3 int8 WMMA in the context of llama.cpp/ggml.

#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)
@@ -288,6 +290,11 @@ static bool amd_mfma_available(const int cc) {
}

static bool amd_wmma_available(const int cc) {
return GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc);
}

// Integer WMMA is only available on RDNA4
static bool amd_wmma_int_available(const int cc) {
return GGML_CUDA_CC_IS_RDNA4(cc);
}
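A minimal sketch of how these two checks are meant to compose on the host side (the enum and function below are illustrative only, not ggml API; the real call sites are the mmf.cu and mmq.cu changes further down):

// Hypothetical dispatch sketch built on the helpers above.
enum class amd_mm_path { wmma_int8, wmma_fp, fallback };

static amd_mm_path select_amd_mm_path(const int cc) {
    if (amd_wmma_int_available(cc)) {
        return amd_mm_path::wmma_int8;  // RDNA4: integer WMMA MMQ kernels
    }
    if (amd_wmma_available(cc)) {
        return amd_mm_path::wmma_fp;    // RDNA3/RDNA4: FP16/BF16 WMMA tiles only
    }
    return amd_mm_path::fallback;       // MFMA, DP4A or other existing paths
}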

142 changes: 111 additions & 31 deletions ggml/src/ggml-cuda/mma.cuh
@@ -152,15 +152,26 @@ namespace ggml_cuda_mma {
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
static constexpr int ne = I * J / 32;
#else
// RDNA3: Accumulator uses same ne, but input tiles need more VGPRs
static constexpr int ne = (I == 16 && J == 16) ? (I * J / 32) : (I * J / 16);
#endif
Collaborator:
Suggested change
#endif
#endif // defined(RDNA4)

Please annotate all #endif statements with comments like this to indicate which #if/#ifdef they're closing.
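For reference, a minimal sketch of the requested convention on a nested guard (reusing macros that already appear in this file; purely illustrative):

#if defined(GGML_USE_HIP)
#if defined(RDNA4)
// RDNA4-specific code
#else
// fallback code
#endif // defined(RDNA4)
#endif // defined(GGML_USE_HIP)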

T x[ne] = {0};

static constexpr __device__ bool supported() {
// Integer and FP16 WMMA are supported on RDNA3/RDNA4
Collaborator:
Suggested change (remove this comment):
// Integer and FP16 WMMA are supported on RDNA3/RDNA4

if (I == 16 && J == 4) return true;
if (I == 16 && J == 8) return true;
if (I == 16 && J == 16) return true;
return false;
}

static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 16) {
if constexpr (I == 16 && J == 4) {
return threadIdx.x % 16;
} else if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else if constexpr (I == 16 && J == 16) {
return 8 * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
@@ -169,14 +180,25 @@ namespace ggml_cuda_mma {
}

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 16) {
if constexpr (I == 16 && J == 4) {
#if defined(RDNA4)
return 2 * (threadIdx.x / 16) + l;
#else
return l; // RDNA3: ne=4, l ranges 0-3, all threads load full J dimension
#endif
} else if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#else
return l; // RDNA3: ne=8, l ranges 0-7, all threads load full J dimension
#endif
} else if constexpr (I == 16 && J == 16) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#endif
#else
static constexpr int ne = I * J / 32;
T x[ne] = {0};
@@ -223,7 +245,7 @@ namespace ggml_cuda_mma {
return -1;
}
}
#endif // defined(GGML_USE_HIP)
#endif // defined(AMD_WMMA_AVAILABLE)
};

template <int I_, int J_>
@@ -265,7 +287,11 @@ namespace ggml_cuda_mma {
}
}
#elif defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
#if defined(RDNA4)
static constexpr int ne = I * J / 32; // 4 half2 = 8 FP16 for RDNA4
#else
static constexpr int ne = I * J / 16; // 8 half2 = 16 FP16 for RDNA3 (duplicate layout)
#endif // defined(RDNA4)
half2 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
@@ -284,7 +310,12 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#else
// RDNA3 has duplicate layout, iterate through full J dimension
return l % J;
#endif
} else {
NO_DEVICE_CODE;
return -1;
@@ -341,7 +372,11 @@ namespace ggml_cuda_mma {
static constexpr int J = J_;

#if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32;
#if defined(RDNA4)
static constexpr int ne = I * J / 32; // 4 bfloat162 = 8 BF16 for RDNA4
#else
static constexpr int ne = I * J / 16; // 8 bfloat162 = 16 BF16 for RDNA3 (duplicate layout)
#endif // defined(RDNA4)
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};

static constexpr __device__ bool supported() {
@@ -360,7 +395,12 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
#if defined(RDNA4)
return 4 * (threadIdx.x / 16) + l;
#else
// RDNA3 has duplicate layout, iterate through full J dimension
return l % J;
#endif
} else {
NO_DEVICE_CODE;
return -1;
@@ -437,26 +477,17 @@ namespace ggml_cuda_mma {
xi[0] = xs[0];
}
#elif defined(AMD_WMMA_AVAILABLE)
if constexpr (I == 16 && J == 4) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
xi[0] = xs[0];
}else if constexpr (I == 16 && J == 8) {
int64_t * xi = (int64_t *) t.x;
const int64_t * xs = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
xi[0] = xs[0];

const int64_t * xs1 = (int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
xi[1] = xs1[0];
}else{
NO_DEVICE_CODE;
// Use generic load path for all AMD WMMA to ensure correct register mapping
#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
Collaborator commented on lines -440 to +483:
Don't just remove the preexisting code.
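A possible way to address this, sketched under the assumption that the removed 64-bit loads are still valid for the RDNA4 int tile layout: keep the vectorized path behind the RDNA4 guard and fall back to the generic per-element loop otherwise.

#if defined(RDNA4)
    // Fast path carried over from the removed code: 64-bit vector loads for the
    // RDNA4 16x4 and 16x8 int tiles (2 and 4 ints per lane, respectively).
    if constexpr (I == 16 && J == 4) {
        int64_t       * xi = (int64_t       *) t.x;
        const int64_t * xs = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 2 * (threadIdx.x / t.I));
        xi[0] = xs[0];
        return;
    } else if constexpr (I == 16 && J == 8) {
        int64_t       * xi  = (int64_t       *) t.x;
        const int64_t * xs  = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I));
        const int64_t * xs1 = (const int64_t *) ((const int *) xs0 + (threadIdx.x % t.I) * stride + 4 * (threadIdx.x / t.I) + 2);
        xi[0] = xs[0];
        xi[1] = xs1[0];
        return;
    }
#endif // defined(RDNA4)
    // Generic per-element load, correct for both the RDNA3 and RDNA4 layouts.
#pragma unroll
    for (int l = 0; l < t.ne; ++l) {
        t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
    }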

}
#else
#else
Collaborator:
Suggested change
#else
#else

#pragma unroll
for (int l = 0; l < t.ne; ++l) {
t.x[l] = xs0[t.get_i(l)*stride + t.get_j(l)];
}
#endif // defined(AMD_MFMA_AVAILABLE)
#endif // AMD_MFMA_AVAILABLE
}

template <typename T>
@@ -738,12 +769,21 @@ namespace ggml_cuda_mma {
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[3]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#elif defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using halfx8_t = __attribute__((ext_vector_type(8))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx8_t& a_frag = reinterpret_cast<const halfx8_t&>(A.x[0]);
const halfx8_t& b_frag = reinterpret_cast<const halfx8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
using halfx16_t = __attribute__((ext_vector_type(16))) _Float16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const halfx16_t& a_frag = reinterpret_cast<const halfx16_t&>(A.x[0]);
const halfx16_t& b_frag = reinterpret_cast<const halfx16_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
@@ -753,16 +793,25 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ void mma(
tile<16, 16, float> & D, const tile<16, 8, nv_bfloat162> & A, const tile<16, 8, nv_bfloat162> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(RDNA4)
using bf16x8_t = __attribute__((ext_vector_type(8))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x8_t& a_frag = reinterpret_cast<const bf16x8_t&>(A.x[0]);
const bf16x8_t& b_frag = reinterpret_cast<const bf16x8_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12(a_frag, b_frag, acc_frag);
#else // RDNA3
using bf16x16_t = __attribute__((ext_vector_type(16))) __bf16;
using floatx8_t = __attribute__((ext_vector_type(8))) float;
floatx8_t& acc_frag = reinterpret_cast<floatx8_t&>(D.x[0]);
const bf16x16_t& a_frag = reinterpret_cast<const bf16x16_t&>(A.x[0]);
const bf16x16_t& b_frag = reinterpret_cast<const bf16x16_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(a_frag, b_frag, acc_frag);
#endif // defined(RDNA4)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE
#endif // AMD_WMMA_AVAILABLE
}
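As a quick cross-check of the two operand layouts used above, working from the ne definitions earlier in this file (a standalone arithmetic sketch, not code from the PR):

// tile<16, 8, half2>: I = 16 rows, J = 8 half2 columns (16 FP16 values per row).
constexpr int I = 16, J = 8;
constexpr int ne_rdna4 = I * J / 32;  // 4 half2 ->  8 FP16 per lane (data spread over all 32 lanes)
constexpr int ne_rdna3 = I * J / 16;  // 8 half2 -> 16 FP16 per lane (each 16-lane half holds a full copy)
static_assert(2 * ne_rdna4 == 8,  "matches the 8-wide _Float16 vector taken by the gfx12 intrinsic");
static_assert(2 * ne_rdna3 == 16, "matches the 16-wide _Float16 vector taken by the RDNA3 intrinsic");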

static __device__ __forceinline__ void mma(
Expand All @@ -786,15 +835,15 @@ namespace ggml_cuda_mma {
0, 0, 0);
#endif // defined(CDNA3)

#elif defined(AMD_WMMA_AVAILABLE)
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

#elif defined(AMD_WMMA_INT_AVAILABLE)
Collaborator:
You forgot to add #define AMD_WMMA_INT_AVAILABLE in common.cuh so your PR is essentially disabling int8 WMMA for all GPUs. Also, if your int8 WMMA implementation is broken, please remove it again. This also goes for the tile sizes that you added, please keep only those that are actually being used and confirmed to work correctly.
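A sketch of the guard that appears to be missing from common.cuh, following the existing AMD_WMMA_AVAILABLE pattern shown above (RDNA4-only here, since that is the only architecture this PR wires up for integer WMMA):

// ggml/src/ggml-cuda/common.cuh (sketch)
#if defined(GGML_USE_HIP) && defined(RDNA4)
#define AMD_WMMA_INT_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(RDNA4)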

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
// RDNA4 uses gfx12-specific intrinsic with int32x2_t
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
@@ -813,6 +862,20 @@ namespace ggml_cuda_mma {
acc[0],
true
);
#else // RDNA3
// RDNA3 uses generic intrinsic with int32x4_t
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
true,
a0,
true,
b0,
acc[0],
true
);
#endif // defined(RDNA4)

#else
@@ -889,14 +952,16 @@ namespace ggml_cuda_mma {

static __device__ __forceinline__ void mma(
tile<16, 16, int> & D, const tile<16, 4, int> & A, const tile<16, 4, int> & B) {
#if defined(AMD_WMMA_AVAILABLE)
#if defined(AMD_WMMA_INT_AVAILABLE)
using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

#if defined(RDNA4)
// RDNA4 uses gfx12-specific intrinsic with int32x2_t
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
int32x2_t * a_vec = (int32x2_t *) A.x;
int32x2_t * b_vec = (int32x2_t *) B.x;

using int32x8_t = __attribute__((__vector_size__(8 * sizeof(int)))) int;
int32x8_t * acc = (int32x8_t *) D.x;

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12(
true,
a_vec[0],
@@ -905,6 +970,21 @@ static __device__ __forceinline__ void mma(
acc[0],
false
);
#else // RDNA3
// RDNA3 uses generic intrinsic with int32x4_t
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
const int32x4_t& a0 = reinterpret_cast<const int32x4_t&>(A.x[0]);
const int32x4_t& b0 = reinterpret_cast<const int32x4_t&>(B.x[0]);

acc[0] = __builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
true,
a0,
true,
b0,
acc[0],
false
);
#endif // defined(RDNA4)
#else
GGML_UNUSED(D);
GGML_UNUSED(A);
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/mmf.cu
@@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16 || amd_wmma_available(cc)) {
return false;
}
}
8 changes: 4 additions & 4 deletions ggml/src/ggml-cuda/mmq.cu
@@ -306,11 +306,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}

if (amd_wmma_available(cc)) {
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}
// RDNA4 has integer WMMA support, always use MMQ
if (GGML_CUDA_CC_IS_RDNA4(cc)) {
return true;
}

// RDNA3 uses DP4A path (no integer WMMA), enable for small batches
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
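The same decision, restated as a sketch with one branch per architecture (the helper name is hypothetical; the logic is equivalent to the expression above):

static bool should_use_mmq_amd_sketch(const int cc, const int64_t ne11) {
    if (GGML_CUDA_CC_IS_RDNA4(cc)) {
        return true;                            // integer WMMA MMQ kernels
    }
    if (GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
        return ne11 < MMQ_DP4A_MAX_BATCH_SIZE;  // DP4A MMQ only for small batches
    }
    return true;                                // other AMD architectures
}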
14 changes: 7 additions & 7 deletions ggml/src/ggml-cuda/mmq.cuh
@@ -256,7 +256,7 @@ static constexpr __device__ int mmq_get_granularity_device(const int /*mmq_x*/)

#if defined(GGML_USE_HIP)
static int mmq_get_nwarps_host(const int cc, const int warp_size) {
return amd_mfma_available(cc) ? 8 : 256/warp_size;
return (amd_mfma_available(cc) || amd_wmma_available(cc)) ? 8 : 256/warp_size;
Collaborator:
Why are you changing this?

}
#else
static int mmq_get_nwarps_host(const int /*cc*/, const int warp_size) {
@@ -1178,7 +1178,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
typedef tile<16, 16, int> tile_C;
@@ -1500,7 +1500,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles

typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
@@ -1888,7 +1888,7 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
#else
int i = (i0 + threadIdx.y*rows_per_warp + threadIdx.x/2) % mmq_y;
{
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)
if (need_check) {
i = min(i, i_max);
}
@@ -2313,7 +2313,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
}
#elif defined(AMD_WMMA_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
#elif defined(AMD_WMMA_INT_AVAILABLE) //wmma instructions can handle 16x4 tiles, does not require loading 64x2 tiles
typedef tile<16, 4, int> tile_A;
typedef tile<16, 4, int> tile_B;
typedef tile<16, 16, int> tile_C;
@@ -3018,7 +3018,7 @@ static __device__ __forceinline__ void mmq_write_back_mma(
#else
typedef tile<16, 8, int> tile_C;
constexpr int rows_per_warp = 2 * granularity;
#endif // defined(AMD_MFMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)
constexpr int ntx = rows_per_warp/tile_C::I; // Number of x minitiles per warp.

const int i0 = (threadIdx.y / ntx) * (ntx*tile_C::I);
@@ -3232,7 +3232,7 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
#else
constexpr vec_dot_mmq_t vec_dot = mmq_type_traits<mmq_x, mmq_y, need_check, type>::vec_dot_dp4a;
constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_INT_AVAILABLE)

constexpr int blocks_per_iter = MMQ_ITER_K / qk;
