[CUDA] columnwise quantize with tma#3157

Open
nastya236 wants to merge 18 commits intoml-explore:mainfrom
nastya236:tma_load

Conversation


@nastya236 nastya236 commented Feb 23, 2026

Columnwise quantization with TMA (mxfp8), bfloat16:

| Size | With TMA (ms) | Without TMA (ms) |
| --- | --- | --- |
| 4096×4096 | 68.48 | 77.21 |
| 4096×8192 | 78.74 | 102.58 |
| 8192×4096 | 80.73 | 100.61 |
| 8192×8192 | 100.67 | 145.16 |
| 4096×16384 | 97.08 | 144.45 |
| 16384×4096 | 99.50 | 137.13 |

This PR:

  • Adds PTX instructions for asynchronous copy with TMA
  • Adds fp_quantize_columnwise_mxfp8 kernel for columnwise MXFP8 quantization using TMA on SM100+
  • Splits fp_quantize.cu into fp_quantize.cu (dispatch) and fp_quantize.cuh (kernels) to reduce file size
  • Moves swizzle_scales constants and get_swizzle_launch_args into cu:: namespace for consistency

TODO:
nvfp4 requires a separate columnwise kernel due to TMA tile size constraints. In the proposed kernel, each thread processes a tile of size (N, M) and stores a transposed result. M equals group_size: 32 bytes for mxfp8, but only 8 bytes for nvfp4. Since TMA requires the innermost tile dimension to be at least 128 bits (16 bytes), the nvfp4 kernel would need to load a larger tile and iterate over multiple groups.

@nastya236 nastya236 requested a review from zcbenz March 17, 2026 15:42
@nastya236 nastya236 marked this pull request as ready for review March 17, 2026 15:43
@nastya236 nastya236 changed the title [WIP] columnwise quantize with tma [CUDA] columnwise quantize with tma Mar 17, 2026
#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \
defined(__CUDA_ARCH_SPECIFIC__)

__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) {
Should we use the cuda::ptx APIs like cuda::ptx::mbarrier_init API instead? They don't have good documentation and you would have to search https://github.com/NVIDIA/cccl to find out API names though.
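A sketch of what that substitution might look like, assuming the `cuda::ptx::mbarrier_init` wrapper from CCCL's `<cuda/ptx>` header; the exact overload set should be verified against the CCCL sources as noted above:

```cuda
#include <cuda/ptx>

// Hypothetical rewrite of the hand-written wrapper: the CCCL helper emits
// the same `mbarrier.init.shared.b64` instruction as the inline PTX version.
__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) {
  cuda::ptx::mbarrier_init(mbar, count);
}
```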

@@ -10,11 +10,6 @@ namespace mlx::core {

namespace cg = cooperative_groups;

This should be moved to namespace cu too.

auto tidy = block_idx.y * block_size.y + idx_in_block.y;
auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x;

size_t thread_idx = tidx + grid_dim_x * size_t(tidy);

I think we can just use cg::this_grid().thread_rank()?
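A minimal sketch of the suggested simplification (illustrative, not from the PR). One caveat: `thread_rank()` linearizes block-by-block, so the ordering differs from the manual x-major computation; the two are interchangeable only where a unique linear index is all that matters, e.g. grid-stride partitioning.

```cuda
#include <cooperative_groups.h>
namespace cg = cooperative_groups;

__device__ size_t global_thread_index() {
  // One call replaces the manual tidx/tidy/grid_dim_x arithmetic.
  return cg::this_grid().thread_rank();
}
```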

int in_size_bytes, // itemsize
int bits) {
dim3 grid;
grid.x = (grid_dim_x_size + block_size_x - 1) / block_size_x;

Can you use cuda::ceil_div when possible?

Suggested change
grid.x = (grid_dim_x_size + block_size_x - 1) / block_size_x;
grid.x = cuda::ceil_div(grid_dim_x_size, block_size_x);


constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte;
constexpr size_t out_tile_size = out_tile_elems;
constexpr size_t out_buff_size_aligned =

This is not used anywhere.

(reinterpret_cast<uintptr_t>(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) &
~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

T* in_sh = reinterpret_cast<T*>(aligned_shared);
@zcbenz zcbenz Mar 18, 2026

You can make sure you get necessary alignment with dynamic allocated shared memory with this:

  extern __shared__ uint128_t shared_mem[];
  T* in_sh = reinterpret_cast<T*>(shared_mem);

or:

extern __shared__ alignas(128) char shared_mem[];

Also, I think in_smem would be an easier to understand name.

((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) /
TMA_SHMEM_ALIGNMENT) *
TMA_SHMEM_ALIGNMENT;
const size_t smem_size =
@zcbenz zcbenz Mar 18, 2026

It appears that the size of shared memory is static? I don't think you need to use dynamic shared memory in this case, you can ensure alignment with:

__shared__ alignas(128) T smem[SIZE];
