-
Notifications
You must be signed in to change notification settings - Fork 1.6k
[CUDA] columnwise quantize with tma #3157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
nastya236
wants to merge
18
commits into
ml-explore:main
Choose a base branch
from
nastya236:tma_load
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
4a5ebd8
[wip] tma
nastya236 4d7b8d8
Merge remote-tracking branch 'upstream/main' into tma_load
nastya236 b03cc9d
tma columnwise quantize
nastya236 252c686
refactoring: separate fp_quant kernels from dispatch
nastya236 6248b16
fixed gpu_ptr
nastya236 a2e0f4d
[wip] rowwise tma, columnwise tma
nastya236 5a7668c
fixed batched case
nastya236 145f6e6
[wip] bank conflict
nastya236 66fb9eb
drop rowwise for now
nastya236 4a6e1b8
fix kernel names
nastya236 c2f2a97
Merge branch 'main' into tma_load
nastya236 1c87d0e
refactor
nastya236 08d49af
move common.h in quantized_utils.cuh
nastya236 56f25b3
fixed example
nastya236 3b10cb4
fix qqmm example
nastya236 8882fa6
Merge branch 'main' into tma_load
nastya236 8376a1e
merge main, fix conflicts
nastya236 4c4f3a8
fix old scale convertion
nastya236 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,127 @@ | ||
| #pragma once | ||
|
|
||
| #include <cuda.h> | ||
| #include <cuda_runtime.h> | ||
|
|
||
| namespace mlx::core { | ||
|
|
||
| namespace ptx { | ||
|
|
||
| #if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ | ||
| defined(__CUDA_ARCH_SPECIFIC__) | ||
|
|
||
// Initialize a shared-memory mbarrier; `count` is the number of arrivals
// that complete one phase of the barrier.
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;"
               :
               : "r"(smem_addr), "r"(count)
               : "memory");
}
|
|
||
// Invalidate a shared-memory mbarrier so its storage may be reused.
__device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.inval.shared.b64 [%0];"
               :
               : "r"(smem_addr)
               : "memory");
}
|
|
||
// Arrive at barrier (non-master threads): counts one arrival against the
// barrier's expected arrival count.
__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(smem_addr) : "memory");
}
|
|
||
// Arrive at barrier and set expected transaction count (master thread).
// `tx_count` is the transaction total the barrier phase waits for in
// addition to the arrivals (see mbarrier complete_tx semantics).
__device__ __forceinline__ void mbarrier_arrive_expect_tx(
    uint64_t* mbar,
    uint32_t tx_count) {
  uint32_t smem_addr = static_cast<uint32_t>(__cvta_generic_to_shared(mbar));
  asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;"
               :
               : "r"(smem_addr), "r"(tx_count)
               : "memory");
}
|
|
||
// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms-mbarrier
// Block until the mbarrier's phase completes for the given parity bit.
// `mbarrier.try_wait` may return false before the phase flips (it is allowed
// to time out), so the PTX spins on the predicate until it succeeds. The
// label lives inside `{ }`, which scopes it to this asm block so repeated
// inlining of this __forceinline__ function does not collide labels.
__device__ __forceinline__ void mbarrier_wait_parity(
    uint64_t* mbar,
    uint32_t parity) {
  // PTX addresses shared memory with a 32-bit offset.
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);
  asm volatile(
      "{\n\t"
      ".reg .pred P;\n\t"
      "WAIT_LOOP:\n\t"
      "mbarrier.try_wait.parity.shared.b64 P, [%0], %1;\n\t"
      "@!P bra WAIT_LOOP;\n\t"
      "}\n\t"
      :
      : "r"(mbar_ptr), "r"(parity)
      : "memory");
}
|
|
||
// Async bulk tensor copy: global -> shared (2D)
// Issues a TMA load of the tile at coordinates (tile_x, tile_y) described by
// `tensor_map` into `dst_shmem`. Completion is signalled on `mbar` via
// mbarrier::complete_tx::bytes, so the barrier's expected transaction count
// must include the bytes this copy delivers (see mbarrier_arrive_expect_tx).
// NOTE(review): `dst_shmem` and `mbar` must both reside in shared memory,
// and `dst_shmem` presumably must meet the tensor map's TMA alignment
// requirement — confirm against how the CUtensorMap is encoded.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared(
    void* dst_shmem,
    const CUtensorMap* tensor_map,
    uint32_t tile_x,
    uint32_t tile_y,
    uint64_t* mbar) {
  // Generic -> 32-bit shared-memory addresses for the PTX operands.
  uint32_t dst_ptr = __cvta_generic_to_shared(dst_shmem);
  uint32_t mbar_ptr = __cvta_generic_to_shared(mbar);

  asm volatile(
      "cp.async.bulk.tensor.2d.shared::cluster.global.tile"
      ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];"
      :
      : "r"(dst_ptr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(mbar_ptr)
      : "memory");
}
|
|
||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
// shared::cta -> global
// Issues a TMA store of the tile in `src_shmem` to the global-memory
// coordinates (offset_x, offset_y) described by the tensor map. Completion
// is tracked with bulk_group: pair with cp_async_bulk_commit_group() and
// cp_async_bulk_wait_group_read<N>() before `src_shmem` is reused.
// NOTE(review): the tensor map is taken as `const uint64_t*` here but as
// `const CUtensorMap*` in the global->shared variant — consider unifying the
// two signatures. `src_shmem` is only read by the copy and could be
// const-qualified.
__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global(
    const uint64_t* tensor_map_ptr,
    const uint32_t offset_x,
    const uint32_t offset_y,
    uint64_t* src_shmem) {
  uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem);
  asm volatile(
      "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" ::
      "l"(tensor_map_ptr),
      "r"(offset_x),
      "r"(offset_y),
      "r"(src_shmem_ptr)
      : "memory");
}
|
|
||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group
// Wait until at most N of the most recently committed bulk async-copy
// groups are still pending (the `.read` suffix only requires that the reads
// of the copies have completed, e.g. shared memory may be rewritten).
// N must be a non-negative compile-time constant.
template <int N>
__device__ __forceinline__ void cp_async_bulk_wait_group_read() {
  static_assert(N >= 0, "cp.async.bulk.wait_group count must be >= 0");
  // The "n" constraint materializes N as a PTX immediate, so every constant
  // is supported. The previous if-constexpr chain covered only {0, 1, 2, 4}
  // and silently compiled to a no-op (no wait at all) for any other N.
  // The "memory" clobber stops the compiler from caching memory reads
  // across the wait.
  asm volatile("cp.async.bulk.wait_group.read %0;" : : "n"(N) : "memory");
}
|
|
||
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group
// Folds all previously issued, uncommitted bulk async copies from this
// thread into a new bulk async-copy group; wait on the group with
// cp_async_bulk_wait_group_read<N>().
__device__ __forceinline__ void cp_async_bulk_commit_group() {
  asm volatile("cp.async.bulk.commit_group;");
}
|
|
||
// Creates a memory ordering barrier between generic and async proxies
// details:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#proxies
// Use after writing shared memory with ordinary (generic-proxy) stores and
// before the TMA (async-proxy) unit consumes that data, e.g. ahead of a
// shared -> global bulk tensor store.
__device__ __forceinline__ void fence_proxy_async_shared_cta() {
  asm volatile("fence.proxy.async.shared::cta;");
}
|
|
||
| #endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && | ||
| // defined(__CUDA_ARCH_SPECIFIC__) | ||
| } // namespace ptx | ||
| } // namespace mlx::core | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we use the
cuda::ptxAPIs likecuda::ptx::mbarrier_initAPI instead? They don't have good documentation and you would have to search https://github.com/NVIDIA/cccl to find out API names though.