-
Notifications
You must be signed in to change notification settings - Fork 13.9k
HIP: enable WMMA-MMQ INT kernels for RDNA 3 #17576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -307,10 +307,11 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) { | |||||||||
| } | ||||||||||
|
|
||||||||||
| if (amd_wmma_available(cc)) { | ||||||||||
| if (GGML_CUDA_CC_IS_RDNA4(cc)) { | ||||||||||
| if (GGML_CUDA_CC_IS_RDNA4(cc) || GGML_CUDA_CC_IS_RDNA3(cc)) { | ||||||||||
| return true; | ||||||||||
| } | ||||||||||
|
Comment on lines
+310
to
312
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||
| } | ||||||||||
|
|
||||||||||
| return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; | ||||||||||
| return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE; | ||||||||||
|
|
||||||||||
| } | ||||||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -1544,6 +1544,8 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( | |||||||||||||||||
| tile_A A1; | ||||||||||||||||||
| A1.x[0] = 0x01010101; | ||||||||||||||||||
| A1.x[1] = 0x01010101; | ||||||||||||||||||
| A1.x[2] = 0x01010101; | ||||||||||||||||||
| A1.x[3] = 0x01010101; | ||||||||||||||||||
|
Comment on lines
1545
to
+1548
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Suggested change
To my understanding |
||||||||||||||||||
| mma(Cm, A1, B); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
@@ -3701,7 +3703,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int | |||||||||||||||||
| const tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(type, mmq_y); | ||||||||||||||||||
| const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type); | ||||||||||||||||||
| const size_t nbs_ids = mmq_x*sizeof(int); | ||||||||||||||||||
| const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); | ||||||||||||||||||
| const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int); | ||||||||||||||||||
| const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq); | ||||||||||||||||||
| return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int)); | ||||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add comments to indicate which `#if`/`#ifdef` each `#endif` is closing.