From 4a5ebd87b60b92774bb34babd301ad76089f996e Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sun, 15 Feb 2026 23:26:04 +0100 Subject: [PATCH 01/15] [wip] tma --- mlx/backend/cuda/ptx.cuh | 213 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 213 insertions(+) create mode 100644 mlx/backend/cuda/ptx.cuh diff --git a/mlx/backend/cuda/ptx.cuh b/mlx/backend/cuda/ptx.cuh new file mode 100644 index 0000000000..ee2b8f351b --- /dev/null +++ b/mlx/backend/cuda/ptx.cuh @@ -0,0 +1,213 @@ +#include +#include + +namespace mlx::core { + +namespace ptx { +// For pipelining with TMA, we use memory barries, the mental model is the +// folowing: +// 1. master thread inits a barrier with expected number of threads to arrive, +// while the rest skip and __synchthreads() to wait for the barrier to be +// ready. +// 2. pipelining: before the loop, master thread launch an asynch copy and +// signals the barrier with the expected number of bytes to arrive, +// while the rest just signal arrival (first buffer). +// 3. So we have 2 buffers, and n pipeline stages. We launch the fist asynch +// copy before the begining of the loop. Then iterate over pipeline stages and +// copy to freed up buffer (buff = stage % BUFFS_NUM), but before the copy we +// want to check that the result was copied to global memory. at the firt loop +// nothing was commited so the check is succeseful so we launch a seconf asynch +// copy. Then we wait for the first one, do operations then lauch an asynch +// copy from shared -> global and return to the begining of the loop, now we +// check if the previous read from shared to global is done so we can reuse the +// buffer and launch async copy to fill buffer 0 + +__device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) { +#if __CUDA_ARCH__ >= 900 + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.init.shared.b64 [%0], %1;" + : + : "r"(mbar_ptr), "r"(count) + : "memory"); +#endif +} + +// Invalidate a memory barrier (cleanup) +__device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) { +#if __CUDA_ARCH__ >= 900 + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.inval.shared.b64 [%0];" : : "r"(mbar_ptr) : "memory"); +#endif +} + +// Arrive at barrier (non-master threads) +__device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) { +#if __CUDA_ARCH__ >= 900 + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.shared.b64 _, [%0];" + : + : "r"(mbar_ptr) + : "memory"); +#endif +} + +// Arrive at barrier and set expected transaction count (master thread) +__device__ __forceinline__ void mbarrier_arrive_expect_tx( + uint64_t* mbar, + uint32_t tx_count) { +#if __CUDA_ARCH__ >= 900 + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" + : + : "r"(mbar_ptr), "r"(tx_count) + : "memory"); +#endif +} + +// Wait for barrier to be satisfied +__device__ __forceinline__ void mbarrier_wait_parity( + uint64_t* mbar, + uint32_t parity) { +#if __CUDA_ARCH__ >= 900 + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + asm volatile( + "{\n\t" + ".reg .pred P;\n\t" + "WAIT_LOOP:\n\t" + "mbarrier.try_wait.parity.shared.b64 P, [%0], %1;\n\t" + "@!P bra WAIT_LOOP;\n\t" + "}\n\t" + : + : "r"(mbar_ptr), "r"(parity) + : "memory"); +#endif +} + +// Async bulk tensor copy: global -> shared (2D) +__device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( + void* dst_shmem, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint64_t* mbar) { +#if __CUDA_ARCH__ >= 900 + uint32_t dst_ptr = __cvta_generic_to_shared(dst_shmem); + uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); + + asm volatile( + "cp.async.bulk.tensor.2d.shared::cluster.global.tile" + ".mbarrier::complete_tx::bytes [%0], [%1, {%2, %3}], [%4];" + : + : "r"(dst_ptr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(mbar_ptr) + : "memory"); +#endif +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor +// shared::cta -> global +__device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( + const uint64_t* tensor_map_ptr, + const uint32_t offset_x, + const uint32_t offset_y, + uint64_t* src_shmem) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); + asm volatile( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" :: + "l"(tensor_map_ptr), + "r"(offset_x), + "r"(offset_y), + "r"(src_shmem_ptr) + : "memory"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +__device__ __forceinline__ void cp_async_bulk_wait_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group 0;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group +template +__device__ __forceinline__ void cp_async_bulk_wait_group_read() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group.read 0;"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group.read 1;"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group.read 2;"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} +template <> +__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.wait_group.read 4;"); +#else + NVTE_DEVICE_ERROR( + "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group +__device__ __forceinline__ void cp_async_bulk_commit_group() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("cp.async.bulk.commit_group;"); +#else + NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +// Proxy fence (bi-directional): +__device__ __forceinline__ void fence_proxy_async() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.async;"); +#else + NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +__device__ __forceinline__ void fence_proxy_async_shared_cta() { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + asm volatile("fence.proxy.async.shared::cta;"); +#else + NVTE_DEVICE_ERROR( + "fence_proxy_async_shared_cta is only supported on SM 9.0+."); +#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +} + +} // namespace ptx +} // namespace mlx::core \ No newline at end of file From b03cc9d1faf9a6a7088969fc0b17dc1df382f8e1 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 18 Feb 2026 14:43:19 +0100 Subject: [PATCH 02/15] tma columnwise quantize --- mlx/backend/cuda/common.h | 69 +++ mlx/backend/cuda/ptx.cuh | 111 +---- mlx/backend/cuda/quantized/fp_quantize.cu | 517 +++++++++++++++++++--- 3 files changed, 546 insertions(+), 151 deletions(-) create mode 100644 mlx/backend/cuda/common.h diff --git a/mlx/backend/cuda/common.h b/mlx/backend/cuda/common.h new file mode 100644 index 0000000000..e6d6d9c756 --- /dev/null +++ b/mlx/backend/cuda/common.h @@ -0,0 +1,69 @@ +// Copyright © 2025 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/backend/cuda/cuda_utils.h" +#include "mlx/backend/cuda/ptx.cuh" + +namespace mlx::core { + +__forceinline__ __device__ void copy_2d_to_shared( + void* dst, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint32_t num_bytes, + uint64_t* barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Arrive and tell how many bytes are expected + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, tensor_map, tile_x, tile_y, barrier); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + +// Host function to create a 2D TMA tensor map descriptor +// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html +inline void create_2D_tensor_map( + CUtensorMap* tensorMap, + void* input_ptr, + CUtensorMapDataType dtype, + uint64_t rows, + uint64_t cols, + uint32_t tile_y, + uint32_t tile_x, + uint64_t stride_bytes, + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE) { + constexpr uint32_t rank = 2; // 2D + uint64_t global_dim[rank] = {cols, rows}; + // For row-major layout + uint64_t strides[rank - 1] = {stride_bytes}; + uint32_t tile_dim[rank] = {tile_x, tile_y}; + uint32_t elem_stride[rank] = {1, 1}; + + CHECK_CUDA_ERROR(cuTensorMapEncodeTiled( + tensorMap, + dtype, + rank, + input_ptr, + global_dim, + strides, + tile_dim, + elem_stride, + CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + +} // namespace mlx::core diff --git a/mlx/backend/cuda/ptx.cuh b/mlx/backend/cuda/ptx.cuh index ee2b8f351b..831600ee6c 100644 --- a/mlx/backend/cuda/ptx.cuh +++ b/mlx/backend/cuda/ptx.cuh @@ -1,3 +1,5 @@ +#pragma once + #include #include @@ -22,53 +24,46 @@ namespace ptx { // check if the previous read from shared to global is done so we can reuse the // buffer and launch async copy to fill buffer 0 +#if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ + defined(__CUDA_ARCH_SPECIFIC__) + __device__ __forceinline__ void mbarrier_init(uint64_t* mbar, uint32_t count) { -#if __CUDA_ARCH__ >= 900 uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.init.shared.b64 [%0], %1;" : : "r"(mbar_ptr), "r"(count) : "memory"); -#endif } -// Invalidate a memory barrier (cleanup) __device__ __forceinline__ void mbarrier_invalidate(uint64_t* mbar) { -#if __CUDA_ARCH__ >= 900 uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.inval.shared.b64 [%0];" : : "r"(mbar_ptr) : "memory"); -#endif } // Arrive at barrier (non-master threads) __device__ __forceinline__ void mbarrier_arrive(uint64_t* mbar) { -#if __CUDA_ARCH__ >= 900 uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.shared.b64 _, [%0];" : : "r"(mbar_ptr) : "memory"); -#endif } // Arrive at barrier and set expected transaction count (master thread) __device__ __forceinline__ void mbarrier_arrive_expect_tx( uint64_t* mbar, uint32_t tx_count) { -#if __CUDA_ARCH__ >= 900 uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile("mbarrier.arrive.expect_tx.shared.b64 _, [%0], %1;" : : "r"(mbar_ptr), "r"(tx_count) : "memory"); -#endif } -// Wait for barrier to be satisfied +// https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-asynchronous-copy-completion-mechanisms-mbarrier __device__ __forceinline__ void mbarrier_wait_parity( uint64_t* mbar, uint32_t parity) { -#if __CUDA_ARCH__ >= 900 uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); asm volatile( "{\n\t" @@ -80,7 +75,6 @@ __device__ __forceinline__ void mbarrier_wait_parity( : : "r"(mbar_ptr), "r"(parity) : "memory"); -#endif } // Async bulk tensor copy: global -> shared (2D) @@ -90,7 +84,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( uint32_t tile_x, uint32_t tile_y, uint64_t* mbar) { -#if __CUDA_ARCH__ >= 900 uint32_t dst_ptr = __cvta_generic_to_shared(dst_shmem); uint32_t mbar_ptr = __cvta_generic_to_shared(mbar); @@ -100,7 +93,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_global_to_shared( : : "r"(dst_ptr), "l"(tensor_map), "r"(tile_x), "r"(tile_y), "r"(mbar_ptr) : "memory"); -#endif } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor @@ -110,7 +102,6 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( const uint32_t offset_x, const uint32_t offset_y, uint64_t* src_shmem) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) uint32_t src_shmem_ptr = __cvta_generic_to_shared(src_shmem); asm volatile( "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%1, %2}], [%3];" :: @@ -119,95 +110,35 @@ __device__ __forceinline__ void cp_async_bulk_tensor_2d_shared_to_global( "r"(offset_y), "r"(src_shmem_ptr) : "memory"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_tensor_2d_shared_to_global is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} - -// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group -__device__ __forceinline__ void cp_async_bulk_wait_group() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group 0;"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_wait_group is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-wait-group -template +template __device__ __forceinline__ void cp_async_bulk_wait_group_read() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group.read 0;"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} - -template <> -__device__ __forceinline__ void cp_async_bulk_wait_group_read<0>() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group.read 0;"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} -template <> -__device__ __forceinline__ void cp_async_bulk_wait_group_read<1>() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group.read 1;"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} -template <> -__device__ __forceinline__ void cp_async_bulk_wait_group_read<2>() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group.read 2;"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} -template <> -__device__ __forceinline__ void cp_async_bulk_wait_group_read<4>() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("cp.async.bulk.wait_group.read 4;"); -#else - NVTE_DEVICE_ERROR( - "cp_async_bulk_wait_group_read is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if constexpr (N == 0) { + asm volatile("cp.async.bulk.wait_group.read 0;"); + } else if constexpr (N == 1) { + asm volatile("cp.async.bulk.wait_group.read 1;"); + } else if constexpr (N == 2) { + asm volatile("cp.async.bulk.wait_group.read 2;"); + } else if constexpr (N == 4) { + asm volatile("cp.async.bulk.wait_group.read 4;"); + } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-commit-group __device__ __forceinline__ void cp_async_bulk_commit_group() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("cp.async.bulk.commit_group;"); -#else - NVTE_DEVICE_ERROR("cp_async_bulk_commit_group is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) -} - -// Proxy fence (bi-directional): -__device__ __forceinline__ void fence_proxy_async() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - asm volatile("fence.proxy.async;"); -#else - NVTE_DEVICE_ERROR("fence_proxy_async is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +// Ccreates a memory ordering barrier between generic and async proxies +// details: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#proxies __device__ __forceinline__ void fence_proxy_async_shared_cta() { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) asm volatile("fence.proxy.async.shared::cta;"); -#else - NVTE_DEVICE_ERROR( - "fence_proxy_async_shared_cta is only supported on SM 9.0+."); -#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) } +#endif // (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && + // (__CUDA_ARCH_FAMILY_SPECIFIC__ >= 1000) } // namespace ptx } // namespace mlx::core \ No newline at end of file diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index d36ae6581e..e09ff21b7a 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -1,7 +1,9 @@ // Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/common.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" +#include "mlx/backend/cuda/ptx.cuh" #include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" #include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" #include "mlx/backend/cuda/quantized/quantized.h" @@ -13,6 +15,7 @@ #include #include #include +#include constexpr float F8E4M3_MAX = 448.0f; constexpr float F4E2M1_MAX = 6.0f; @@ -184,7 +187,7 @@ __global__ void fp_quantize_rowwise( } template -__global__ void fp_quantize_columnwise( +__global__ void fp_quantize_columnwise_fallback( T* w, uint8_t* out, uint8_t* scales, @@ -316,6 +319,216 @@ __global__ void fp_quantize_columnwise( } } +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; +constexpr size_t ROWS_PER_BLOCK = 128; +constexpr size_t COLS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = 128; +constexpr size_t TILE_M = 32; +constexpr size_t TILE_K = 128; +constexpr size_t BUFFS_NUM = 2; + +template +__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = cu::Vector2_t; + using Tx4 = cu::Vector4_t; + + constexpr size_t TILE_M = group_size; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; // 4 for gs=32, 8 for gs=16 + + constexpr int elem_per_byte = (bits == 8) ? 1 : 2; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles column tidx + const bool is_master = (tidx == 0); + + const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; + const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + // Transposed output tile: TILE_K rows x TILE_M cols (128 x group_size) + // For FP8: 128 * group_size bytes, for FP4: 128 * (group_size/2) bytes + constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + constexpr size_t out_tile_size = out_tile_elems; + constexpr size_t out_buff_size_aligned = + ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t scales_per_tile = TILE_K; // One scale per column + constexpr size_t scales_tile_size = scales_per_tile * sizeof(uint8_t); + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + uint8_t* scales_sh = reinterpret_cast( + aligned_shared + in_buff_size_aligned + out_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STAGES]; + + T thread_data[group_size]; + uint32_t rbits = 0; // Reserved for stochastic rounding + const size_t scale_stride = rows / group_size; + + // Master thread init memory barriers for all stages + // fence for tma, synchronize threads so all see mbarrier + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + // Launch first async copy before entering the loop + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + // buffer memory offset in shared memory (we use double buffering for + // pipelining) + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_row_offset = stage * TILE_M; + + if (next_stage < STAGES) { + // before launching another async copy, check that there is less than 2 + // (to ensure that shared -> global synch is finished and buffer can be + // reused) + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_stage * TILE_M; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(next_row_offset), + tile_bytes, + &mbar[next_stage], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + // Wait until the data is ready, parity is always 0 because for simplicity + // we dont reuse barriers between stages + ptx::mbarrier_wait_parity(&mbar[stage], 0); + const size_t buff_offset = buff * BUFF_ELEMS; + // Read the data from shared to registers +#pragma unroll + for (int row = 0; row < group_size; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + // Compute scale: find max absolute value in the column group + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int row = 0; row < group_size; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + cu::absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= (bits == 4) ? 6.0f : 448.0f; + + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale); + scale = float(s); + // Store scale to shared memory buffer + // TODO: currently it results in 4-way bank conflict + scales_sh[buff * TILE_K + tidx] = s.__x; + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize and write output to shared memory +#pragma unroll + for (int j = 0; j < group_size / 4; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; + } else { + uint16_t quantized_val = + cu::scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * (TILE_M / 2) + j * 2]) = + quantized_val; + } + } + __syncthreads(); + // Thread tidx computes scale for input column (block_offset_col + tidx) + // This scale goes to output row (block_offset_col + tidx), column + // (global_row_group) + const size_t global_row_group = + (block_offset_row + stage_row_offset) / group_size; + const size_t global_col = block_offset_col + tidx; + // TODO: scale writing is not good + if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = + scales_sh[buff * TILE_K + tidx]; + } + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (is_master) { + const size_t global_row = block_offset_row + stage_row_offset; + const uint32_t out_x = (bits == 8) + ? static_cast(global_row) + : static_cast(global_row / 2); + const uint32_t out_y = static_cast(block_offset_col); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + // Wait for all TMA stores to complete + ptx::cp_async_bulk_wait_group_read<0>(); + + __syncthreads(); + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + template __global__ void fp_dequantize( const uint8_t* w, @@ -364,6 +577,60 @@ __global__ void fp_dequantize( } } +inline std::tuple get_tma_columnwise_launch_args( + size_t rows, + size_t cols, + int group_size, + int bits, + int bytes_per_element) { + dim3 grid; + grid.x = (cols + COLS_PER_BLOCK - 1) / COLS_PER_BLOCK; + grid.y = (rows + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK; + grid.z = 1; + + dim3 block(THREADS_PER_BLOCK, 1, 1); + + const int elem_per_byte = (bits == 8) ? 1 : 2; + + const size_t BUFF_ELEMS = TILE_M * TILE_K; + const size_t in_tile_size = BUFF_ELEMS * bytes_per_element; + const size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + const size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + const size_t out_buff_size_aligned = + ((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + const size_t scales_tile_size = TILE_K * sizeof(uint8_t); + const size_t scales_buff_size_aligned = + ((scales_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + const size_t smem_size = in_buff_size_aligned + out_buff_size_aligned + + scales_buff_size_aligned + TMA_SHMEM_ALIGNMENT; + + return std::make_tuple(grid, block, smem_size); +} + +inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { + switch (dtype) { + case float16: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; + case bfloat16: + return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; + case float32: + return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; + default: + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Unsupported dtype for TMA"); + } +} + inline std::tuple get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { constexpr int BLOCK_X = 16; @@ -424,7 +691,7 @@ void fp_quantize_dequantize( }); } -void fp_quantize( +void fp_quantize_rowwise( const array& w, array& wq, array& scales, @@ -439,68 +706,196 @@ void fp_quantize( } enc.set_output_array(wq); enc.set_output_array(scales); - if (w.strides().back() != 1) { - dispatch_float_types(w.dtype(), "fp_quantize_columnwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto M = w.shape(-2); - auto K = w.shape(-1); - auto kernel = cu::fp_quantize_columnwise; - if (bits == 8) { - kernel = cu::fp_quantize_columnwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_columnwise; - } - auto [num_blocks, block_dims] = - cu::get_columnwise_quantize_launch_args(w.size(), group_size, M, K); - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - M, - K, - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); + dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_rowwise; + if (bits == 8) { + kernel = cu::fp_quantize_rowwise; + } else if (group_size == 16) { + kernel = cu::fp_quantize_rowwise; } - }); - } else { - dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize_rowwise; - if (bits == 8) { - kernel = cu::fp_quantize_rowwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_rowwise; + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); + + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); +} + +void fp_quantize_columnwise_fallback( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + if (global_scale.has_value()) { + enc.set_input_array(global_scale.value()); + } + enc.set_output_array(wq); + enc.set_output_array(scales); + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_fallback", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto M = w.shape(-2); + auto K = w.shape(-1); + auto kernel = + cu::fp_quantize_columnwise_fallback; + if (bits == 8) { + kernel = cu::fp_quantize_columnwise_fallback; + } else if (group_size == 16) { + kernel = + cu::fp_quantize_columnwise_fallback; + } + auto [num_blocks, block_dims] = + cu::get_columnwise_quantize_launch_args( + w.size(), group_size, M, K); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + M, + K, + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); } - bool large = w.size() > UINT_MAX; - auto [num_blocks, block_dims] = get_launch_args( - w.size(), w.shape(), w.strides(), large, group_size); - - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); + }); +} + +void fp_quantize_columnwise_tma( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + + size_t rows = w.shape(-1); + size_t cols = w.shape(-2); + size_t stride_bytes = w.strides(-1) * w.itemsize(); + + auto [grid, block, smem_size] = cu::get_tma_columnwise_launch_args( + rows, cols, group_size, bits, w.itemsize()); + + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + create_2D_tensor_map( + &tensor_map_input, + const_cast(static_cast(w.data())), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + static_cast(cu::TILE_M), + static_cast(cu::TILE_K), + stride_bytes); + + const size_t output_rows = cols; + const size_t output_cols = (bits == 8) ? rows : rows / 2; + const uint32_t out_tile_x = (bits == 8) ? cu::TILE_M : cu::TILE_M / 2; + const uint32_t out_tile_y = cu::TILE_K; + + create_2D_tensor_map( + &tensor_map_output, + static_cast(wq.data()), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + output_rows, + output_cols, + out_tile_y, + out_tile_x, + output_cols); + + dispatch_float_types(w.dtype(), "fp_quantize_columnwise_tma", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_columnwise_tma; + if (bits == 8) { + kernel = cu::fp_quantize_columnwise_tma; + } else if (group_size == 16) { + kernel = cu::fp_quantize_columnwise_tma; } - }); + enc.add_kernel_node( + kernel, + grid, + block, + static_cast(smem_size), + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); + } + }); +} + +void fp_quantize_columnwise( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + if (enc.device().compute_capability_major() >= 10) { + std::cout << "Using TMA kernel for column-wise quantization" << std::endl; + fp_quantize_columnwise_tma( + w, wq, scales, group_size, bits, global_scale, enc, s); + } else { + fp_quantize_columnwise_fallback( + w, wq, scales, group_size, bits, global_scale, enc, s); + } +} + +void fp_quantize( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + if (w.strides(-1) == 1) { + fp_quantize_rowwise(w, wq, scales, group_size, bits, global_scale, enc, s); + } else { + fp_quantize_columnwise( + w, wq, scales, group_size, bits, global_scale, enc, s); } } From 252c686a3547c909e80e0e3242560ed9eb8c4309 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 18 Feb 2026 16:03:13 +0100 Subject: [PATCH 03/15] refactoring: separate fp_quant kernels from dispatch --- mlx/backend/cuda/quantized/fp_quantize.cu | 617 +----------------- mlx/backend/cuda/quantized/fp_quantize.cuh | 576 ++++++++++++++++ mlx/backend/cuda/quantized/qqmm_utils.cu | 21 +- .../cuda/quantized/quantized_utils.cuh | 1 + 4 files changed, 614 insertions(+), 601 deletions(-) create mode 100644 mlx/backend/cuda/quantized/fp_quantize.cuh diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index e09ff21b7a..359c1b78bc 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -1,11 +1,8 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/common.h" #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" -#include "mlx/backend/cuda/ptx.cuh" -#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" -#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/fp_quantize.cuh" #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh" #include "mlx/backend/cuda/vector_types.cuh" @@ -17,567 +14,10 @@ #include #include -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - namespace mlx::core { namespace cu { -template -struct Dequantize { - __device__ float operator()(uint8_t x) { - if constexpr (bits == 8) { - return float(*(__nv_fp8_e4m3*)(&x)); - } else { - return float(*(__nv_fp4_e2m1*)(&x)); - } - } -}; - -namespace cg = cooperative_groups; - -template -__global__ void fp_quantize_dequantize( - T* w, - T* out, - size_t size, - float* global_scale = nullptr) { - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; // reserved bits for future use - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - size_t thread_idx = tidx + grid_dim_x * size_t(tidy); - size_t base_idx = thread_idx * group_size; - - if (base_idx >= size) { - return; - } - - auto w_tile = load_vector(w, thread_idx); - float scale_dec_b = 0.0f; - - Tx2 amax_2x = Tx2{0.0f, 0.0f}; - -#pragma unroll - for (int i = 0; i < group_size; i += 2) { - auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - scale_dec_b = static_cast( - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y)))); - - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - float scale_enc_b = scale_enc / float(s); - float scale_dec = float(s) * inv_scale_enc; - AlignedVector w_hat; - -#pragma unroll - for (int i = 0; i < group_size / 4; i++) { - Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); - float4 dq; - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp8(quantized_val); - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp4(quantized_val); - } - w_hat[i * 4] = static_cast(dq.x * scale_dec); - w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); - w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); - w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); - } - store_vector(out, thread_idx, w_hat); -} - -template -__global__ void fp_quantize_rowwise( - T* w, - uint8_t* out, - uint8_t* scales, - size_t size, - float* global_scale = nullptr) { - // NVFP4 conversion: - // Global encode scale: (448 × 6) / *global_scale - // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 - // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; // reserved bits for future use - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - size_t thread_idx = tidx + grid_dim_x * size_t(tidy); - size_t base_idx = thread_idx * group_size; - - if (base_idx >= size) { - return; - } - - auto w_tile = load_vector(w, thread_idx); - float scale_dec_b = 0.0f; - - Tx2 amax_2x = Tx2{0.0f, 0.0f}; - -#pragma unroll - for (int i = 0; i < group_size; i += 2) { - auto pair = Tx2{w_tile[i], w_tile[i + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - scale_dec_b = static_cast( - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y)))); - - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - uint8_t q_scale = s.__x; - float scale_enc_b = scale_enc / float(s); - - scales[thread_idx] = q_scale; - constexpr int elem_per_byte = bits == 8 ? 1 : 2; - AlignedVector quantized; - -#pragma unroll - for (int i = 0; i < group_size / 4; i++) { - Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized[i * 4]) = quantized_val; - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized[i * 2]) = quantized_val; - } - } - store_vector(out, thread_idx, quantized); -} - -template -__global__ void fp_quantize_columnwise_fallback( - T* w, - uint8_t* out, - uint8_t* scales, - size_t size, - int M, - int K, - float* global_scale = nullptr) { - // Input: [M, K] with strides [1, M] (M-major) - // Quantized output: [M, K/elem_per_byte] row-major (K-major) - // Scales: [M, K/group_size] row-major (K-major) - // Quantize along K (last dimension, groups of group_size elements) - const bool use_global_scale = global_scale != nullptr; - const float scale_enc = - use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; - - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - uint32_t rbits = 0; - - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - - constexpr int BLOCK_X = 16; - constexpr int BLOCK_Y = 32; - constexpr int elem_per_byte = (bits == 8) ? 1 : 2; - constexpr int bytes_per_group = group_size / elem_per_byte; - - constexpr int rows_per_block = BLOCK_X; - constexpr int cols_per_block = BLOCK_Y * group_size; - constexpr int local_cols = cols_per_block / elem_per_byte; - constexpr int bytes_per_block = rows_per_block * local_cols; - - constexpr int SMEM_PAD = 4; - constexpr int padded_local_cols = local_cols + SMEM_PAD; - - auto tidx = idx_in_block.x; - auto tidy = idx_in_block.y; - - int num_col_blocks = (K + cols_per_block - 1) / cols_per_block; - auto bidx = block_idx.x % num_col_blocks; - auto bidy = block_idx.x / num_col_blocks; - - T thread_data[group_size]; - - __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols]; - __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD]; - - int row_base = bidy * rows_per_block + tidx; - int col_base = bidx * cols_per_block + tidy * group_size; - - bool valid = (row_base < M) && (col_base + group_size <= K); - if (valid) { -#pragma unroll - for (int i = 0; i < group_size; i++) { - auto index = row_base + (col_base + i) * M; - thread_data[i] = w[index]; - } - - // Compute scale - Tx2 amax_2x = Tx2{0.0f, 0.0f}; -#pragma unroll - for (int r = 0; r < group_size; r += 2) { - auto pair = Tx2{thread_data[r], thread_data[r + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - float scale_dec_b = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; - scale_dec_b *= scale_enc; - // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale_dec_b); - float scale_enc_b = scale_enc / float(s); - scales_smem[tidx][tidy] = s.__x; - - int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; - -#pragma unroll - for (int j = 0; j < group_size / 4; j++) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = - quantized_val; - } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = - quantized_val; - } - } - } - __syncthreads(); - - int output_cols = K / elem_per_byte; - int num_groups_per_row = K / group_size; - int linear_tid = tidx + tidy * BLOCK_X; - // Write back quantized values -#pragma unroll - for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) { - int local_row = i / local_cols; - int local_col = i % local_cols; - - int global_row = bidy * rows_per_block + local_row; - int global_col = bidx * local_cols + local_col; - - if (global_row < M && global_col < output_cols) { - int physical_idx = local_row * padded_local_cols + local_col; - out[global_row * output_cols + global_col] = quantized_smem[physical_idx]; - } - } - // Write back scales - constexpr int num_scales = BLOCK_X * BLOCK_Y; -#pragma unroll - for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) { - int local_row = i / BLOCK_Y; - int local_col = i % BLOCK_Y; - - int global_row = bidy * BLOCK_X + local_row; - int global_col = bidx * BLOCK_Y + local_col; - - if (global_row < M && global_col < num_groups_per_row) { - scales[global_row * num_groups_per_row + global_col] = - scales_smem[local_row][local_col]; - } - } -} - -constexpr size_t TMA_SHMEM_ALIGNMENT = 128; -constexpr size_t ROWS_PER_BLOCK = 128; -constexpr size_t COLS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = 128; -constexpr size_t TILE_M = 32; -constexpr size_t TILE_K = 128; -constexpr size_t BUFFS_NUM = 2; - -template -__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - using Tx2 = cu::Vector2_t; - using Tx4 = cu::Vector4_t; - - constexpr size_t TILE_M = group_size; - constexpr size_t TILE_K = COLS_PER_BLOCK; - constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; // 4 for gs=32, 8 for gs=16 - - constexpr int elem_per_byte = (bits == 8) ? 1 : 2; - - const auto block_idx = cg::this_thread_block().group_index(); - const auto idx_in_block = cg::this_thread_block().thread_index(); - const int tidx = idx_in_block.x; // Thread handles column tidx - const bool is_master = (tidx == 0); - - const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; - const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; - - constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; - constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); - constexpr size_t in_buff_size_aligned = - ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - // Transposed output tile: TILE_K rows x TILE_M cols (128 x group_size) - // For FP8: 128 * group_size bytes, for FP4: 128 * (group_size/2) bytes - constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; - constexpr size_t out_tile_size = out_tile_elems; - constexpr size_t out_buff_size_aligned = - ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - constexpr size_t scales_per_tile = TILE_K; // One scale per column - constexpr size_t scales_tile_size = scales_per_tile * sizeof(uint8_t); - - extern __shared__ char shared_mem[]; - uintptr_t aligned_shared = - (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - T* in_sh = reinterpret_cast(aligned_shared); - uint8_t* out_sh = - reinterpret_cast(aligned_shared + in_buff_size_aligned); - uint8_t* scales_sh = reinterpret_cast( - aligned_shared + in_buff_size_aligned + out_buff_size_aligned); - - constexpr uint32_t tile_bytes = static_cast(in_tile_size); - - __shared__ alignas(8) uint64_t mbar[STAGES]; - - T thread_data[group_size]; - uint32_t rbits = 0; // Reserved for stochastic rounding - const size_t scale_stride = rows / group_size; - - // Master thread init memory barriers for all stages - // fence for tma, synchronize threads so all see mbarrier - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); - } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - // Launch first async copy before entering the loop - copy_2d_to_shared( - &in_sh[0], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(block_offset_row), - tile_bytes, - &mbar[0], - is_master); - -#pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { - // buffer memory offset in shared memory (we use double buffering for - // pipelining) - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_row_offset = stage * TILE_M; - - if (next_stage < STAGES) { - // before launching another async copy, check that there is less than 2 - // (to ensure that shared -> global synch is finished and buffer can be - // reused) - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_row_offset = block_offset_row + next_stage * TILE_M; - const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; - - copy_2d_to_shared( - &in_sh[next_buff_elem_offset], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(next_row_offset), - tile_bytes, - &mbar[next_stage], - is_master); - } - - ptx::fence_proxy_async_shared_cta(); - // Wait until the data is ready, parity is always 0 because for simplicity - // we dont reuse barriers between stages - ptx::mbarrier_wait_parity(&mbar[stage], 0); - const size_t buff_offset = buff * BUFF_ELEMS; - // Read the data from shared to registers -#pragma unroll - for (int row = 0; row < group_size; ++row) { - thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; - } - // Compute scale: find max absolute value in the column group - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; -#pragma unroll - for (int row = 0; row < group_size; row += 2) { - auto pair = Tx2{thread_data[row], thread_data[row + 1]}; - cu::absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - - scale /= (bits == 4) ? 6.0f : 448.0f; - - using ScaleType = - std::conditional_t; - auto s = ScaleType(scale); - scale = float(s); - // Store scale to shared memory buffer - // TODO: currently it results in 4-way bank conflict - scales_sh[buff * TILE_K + tidx] = s.__x; - const size_t out_buff_offset = buff * out_tile_elems; - // Quantize and write output to shared memory -#pragma unroll - for (int j = 0; j < group_size / 4; ++j) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - if constexpr (bits == 8) { - uint32_t quantized_val = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; - } else { - uint16_t quantized_val = - cu::scale_cvt_Tx4_to_fp4x4(w_Tx4, 1.0f / scale, rbits); - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * (TILE_M / 2) + j * 2]) = - quantized_val; - } - } - __syncthreads(); - // Thread tidx computes scale for input column (block_offset_col + tidx) - // This scale goes to output row (block_offset_col + tidx), column - // (global_row_group) - const size_t global_row_group = - (block_offset_row + stage_row_offset) / group_size; - const size_t global_col = block_offset_col + tidx; - // TODO: scale writing is not good - if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { - scales[global_col * scale_stride + global_row_group] = - scales_sh[buff * TILE_K + tidx]; - } - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - - if (is_master) { - const size_t global_row = block_offset_row + stage_row_offset; - const uint32_t out_x = (bits == 8) - ? static_cast(global_row) - : static_cast(global_row / 2); - const uint32_t out_y = static_cast(block_offset_col); - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - out_x, - out_y, - reinterpret_cast(&out_sh[out_buff_offset])); - ptx::cp_async_bulk_commit_group(); - } - } - // Wait for all TMA stores to complete - ptx::cp_async_bulk_wait_group_read<0>(); - - __syncthreads(); - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_invalidate(&mbar[iter]); - } - } -#endif // __CUDA_ARCH__ >= 1000 -} - -template -__global__ void fp_dequantize( - const uint8_t* w, - const uint8_t* scales, - T* out, - size_t size, - float* global_scale = nullptr) { - auto block_size = cg::this_thread_block().dim_threads(); - auto block_idx = cg::this_thread_block().group_index(); - auto idx_in_block = cg::this_thread_block().thread_index(); - - auto tidx = block_idx.x * block_size.x + idx_in_block.x; - auto tidy = block_idx.y * block_size.y + idx_in_block.y; - - auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; - - constexpr int pack_factor = bits == 8 ? 1 : 2; - const bool use_global_scale = global_scale != nullptr; - const float inv_scale_enc = use_mx_scale - ? 1.0f - : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); - size_t offset = tidx + grid_dim_x * size_t(tidy); - size_t oindex = offset * pack_factor; - - if (oindex >= size) { - return; - } - - size_t gindex = oindex / group_size; - using ScaleType = - std::conditional_t; - auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; - - out += oindex; - - uint32_t val = w[offset]; -#pragma clang loop unroll(full) - for (int i = 0; i < pack_factor; i++) { - uint8_t d; - if (bits == 4) { - d = (val >> (bits * i)) & 0x0f; - } else if (bits == 8) { - d = val; - } - out[i] = static_cast(scale * Dequantize{}(d)); - } -} - -inline std::tuple get_tma_columnwise_launch_args( +inline std::tuple get_columnwise_tma_launch_args( size_t rows, size_t cols, int group_size, @@ -632,7 +72,7 @@ inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { } inline std::tuple -get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { +get_columnwise_fallback_launch_args(size_t size, int group_size, int M, int K) { constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; @@ -767,7 +207,7 @@ void fp_quantize_columnwise_fallback( cu::fp_quantize_columnwise_fallback; } auto [num_blocks, block_dims] = - cu::get_columnwise_quantize_launch_args( + cu::get_columnwise_fallback_launch_args( w.size(), group_size, M, K); enc.add_kernel_node( kernel, @@ -806,7 +246,7 @@ void fp_quantize_columnwise_tma( size_t cols = w.shape(-2); size_t stride_bytes = w.strides(-1) * w.itemsize(); - auto [grid, block, smem_size] = cu::get_tma_columnwise_launch_args( + auto [grid, block, smem_size] = cu::get_columnwise_tma_launch_args( rows, cols, group_size, bits, w.itemsize()); CUtensorMap tensor_map_input; @@ -837,30 +277,28 @@ void fp_quantize_columnwise_tma( out_tile_x, output_cols); - dispatch_float_types(w.dtype(), "fp_quantize_columnwise_tma", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize_columnwise_tma; - if (bits == 8) { - kernel = cu::fp_quantize_columnwise_tma; - } else if (group_size == 16) { - kernel = cu::fp_quantize_columnwise_tma; - } - enc.add_kernel_node( - kernel, - grid, - block, - static_cast(smem_size), - tensor_map_input, - tensor_map_output, - gpu_ptr(scales), - rows, - cols); - } else { - throw std::runtime_error( - "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); - } - }); + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + // Currently only MXFP8 (bits=8, group_size=32) is implemented + // TODO: Add NVFP4 support + auto kernel = cu::fp_quantize_columnwise_tma_mxfp8; + enc.add_kernel_node( + kernel, + grid, + block, + static_cast(smem_size), + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); + } + }); } void fp_quantize_columnwise( @@ -873,7 +311,6 @@ void fp_quantize_columnwise( cu::CommandEncoder& enc, const Stream& s) { if (enc.device().compute_capability_major() >= 10) { - std::cout << "Using TMA kernel for column-wise quantization" << std::endl; fp_quantize_columnwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); } else { diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh new file mode 100644 index 0000000000..3ce7a0a2a3 --- /dev/null +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -0,0 +1,576 @@ +// Copyright © 2025 Apple Inc. +#include "mlx/backend/cuda/common.h" +#include "mlx/backend/cuda/ptx.cuh" +#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" +#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/quantized_utils.cuh" +#include "mlx/backend/cuda/vector_types.cuh" +#include "mlx/dtype_utils.h" + +#include +#include +#include +#include + +namespace mlx::core { +namespace cu { + +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + +// constants for tma kernels +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; +constexpr size_t ROWS_PER_BLOCK = 128; +constexpr size_t COLS_PER_BLOCK = 128; +constexpr size_t THREADS_PER_BLOCK = 128; +constexpr size_t TILE_M = 32; +constexpr size_t TILE_K = 128; +constexpr size_t BUFFS_NUM = 2; + +template +struct Dequantize { + __device__ float operator()(uint8_t x) { + if constexpr (bits == 8) { + return float(*(__nv_fp8_e4m3*)(&x)); + } else { + return float(*(__nv_fp4_e2m1*)(&x)); + } + } +}; + +namespace cg = cooperative_groups; + +template +__global__ void fp_quantize_dequantize( + T* w, + T* out, + size_t size, + float* global_scale = nullptr) { + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { + return; + } + + auto w_tile = load_vector(w, thread_idx); + float scale_dec_b = 0.0f; + + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + scale_dec_b = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); + + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + float scale_dec = float(s) * inv_scale_enc; + AlignedVector w_hat; + +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + float4 dq; + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + dq = dequant_fp8(quantized_val); + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + dq = dequant_fp4(quantized_val); + } + w_hat[i * 4] = static_cast(dq.x * scale_dec); + w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); + w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); + w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); + } + store_vector(out, thread_idx, w_hat); +} + +template +__global__ void fp_quantize_rowwise( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + float* global_scale = nullptr) { + // NVFP4 conversion: + // Global encode scale: (448 × 6) / *global_scale + // Per-block decode scale: S_dec_b = (block_amax / 6) × S_enc → stored as FP8 + // E4M3 Per-block encode scale: S_enc_b = S_enc / S_dec_b + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; // reserved bits for future use + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + size_t thread_idx = tidx + grid_dim_x * size_t(tidy); + size_t base_idx = thread_idx * group_size; + + if (base_idx >= size) { + return; + } + + auto w_tile = load_vector(w, thread_idx); + float scale_dec_b = 0.0f; + + Tx2 amax_2x = Tx2{0.0f, 0.0f}; + +#pragma unroll + for (int i = 0; i < group_size; i += 2) { + auto pair = Tx2{w_tile[i], w_tile[i + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + scale_dec_b = static_cast( + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y)))); + + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + uint8_t q_scale = s.__x; + float scale_enc_b = scale_enc / float(s); + + scales[thread_idx] = q_scale; + constexpr int elem_per_byte = bits == 8 ? 1 : 2; + AlignedVector quantized; + +#pragma unroll + for (int i = 0; i < group_size / 4; i++) { + Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized[i * 4]) = quantized_val; + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized[i * 2]) = quantized_val; + } + } + store_vector(out, thread_idx, quantized); +} + +template +__global__ void fp_quantize_columnwise_fallback( + T* w, + uint8_t* out, + uint8_t* scales, + size_t size, + int M, + int K, + float* global_scale = nullptr) { + // Input: [M, K] with strides [1, M] (M-major) + // Quantized output: [M, K/elem_per_byte] row-major (K-major) + // Scales: [M, K/group_size] row-major (K-major) + // Quantize along K (last dimension, groups of group_size elements) + const bool use_global_scale = global_scale != nullptr; + const float scale_enc = + use_global_scale ? (F8E4M3_MAX * F4E2M1_MAX) / *global_scale : 1.0f; + + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + uint32_t rbits = 0; + + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + constexpr int BLOCK_X = 16; + constexpr int BLOCK_Y = 32; + constexpr int elem_per_byte = (bits == 8) ? 1 : 2; + constexpr int bytes_per_group = group_size / elem_per_byte; + + constexpr int rows_per_block = BLOCK_X; + constexpr int cols_per_block = BLOCK_Y * group_size; + constexpr int local_cols = cols_per_block / elem_per_byte; + constexpr int bytes_per_block = rows_per_block * local_cols; + + constexpr int SMEM_PAD = 4; + constexpr int padded_local_cols = local_cols + SMEM_PAD; + + auto tidx = idx_in_block.x; + auto tidy = idx_in_block.y; + + int num_col_blocks = (K + cols_per_block - 1) / cols_per_block; + auto bidx = block_idx.x % num_col_blocks; + auto bidy = block_idx.x / num_col_blocks; + + T thread_data[group_size]; + + __shared__ uint8_t quantized_smem[rows_per_block * padded_local_cols]; + __shared__ uint8_t scales_smem[BLOCK_X][BLOCK_Y + SMEM_PAD]; + + int row_base = bidy * rows_per_block + tidx; + int col_base = bidx * cols_per_block + tidy * group_size; + + bool valid = (row_base < M) && (col_base + group_size <= K); + if (valid) { +#pragma unroll + for (int i = 0; i < group_size; i++) { + auto index = row_base + (col_base + i) * M; + thread_data[i] = w[index]; + } + + // Compute scale + Tx2 amax_2x = Tx2{0.0f, 0.0f}; +#pragma unroll + for (int r = 0; r < group_size; r += 2) { + auto pair = Tx2{thread_data[r], thread_data[r + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + float scale_dec_b = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; + scale_dec_b *= scale_enc; + // Convert to mx scale or nv scale + using ScaleType = + std::conditional_t; + auto s = ScaleType(scale_dec_b); + float scale_enc_b = scale_enc / float(s); + scales_smem[tidx][tidy] = s.__x; + + int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; + +#pragma unroll + for (int j = 0; j < group_size / 4; j++) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + if constexpr (bits == 8) { + uint32_t quantized_val = + scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized_smem[shared_idx + j * 4]) = + quantized_val; + } else { + uint16_t quantized_val = + scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); + *reinterpret_cast(&quantized_smem[shared_idx + j * 2]) = + quantized_val; + } + } + } + __syncthreads(); + + int output_cols = K / elem_per_byte; + int num_groups_per_row = K / group_size; + int linear_tid = tidx + tidy * BLOCK_X; + // Write back quantized values +#pragma unroll + for (int i = linear_tid; i < bytes_per_block; i += BLOCK_X * BLOCK_Y) { + int local_row = i / local_cols; + int local_col = i % local_cols; + + int global_row = bidy * rows_per_block + local_row; + int global_col = bidx * local_cols + local_col; + + if (global_row < M && global_col < output_cols) { + int physical_idx = local_row * padded_local_cols + local_col; + out[global_row * output_cols + global_col] = quantized_smem[physical_idx]; + } + } + // Write back scales + constexpr int num_scales = BLOCK_X * BLOCK_Y; +#pragma unroll + for (int i = linear_tid; i < num_scales; i += BLOCK_X * BLOCK_Y) { + int local_row = i / BLOCK_Y; + int local_col = i % BLOCK_Y; + + int global_row = bidy * BLOCK_X + local_row; + int global_col = bidx * BLOCK_Y + local_col; + + if (global_row < M && global_col < num_groups_per_row) { + scales[global_row * num_groups_per_row + global_col] = + scales_smem[local_row][local_col]; + } + } +} + +template +__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_mxfp8( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; // 4 for gs=32, 8 for gs=16 + constexpr int elem_per_byte = 1; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles column tidx + const bool is_master = (tidx == 0); + + const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; + const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + // Transposed output tile: TILE_K rows x TILE_M cols (128 x group_size) + // For FP8: 128 * group_size bytes, for FP4: 128 * (group_size/2) bytes + constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + constexpr size_t out_tile_size = out_tile_elems; + constexpr size_t out_buff_size_aligned = + ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t scales_per_tile = TILE_K; // One scale per column + constexpr size_t scales_tile_size = scales_per_tile * sizeof(uint8_t); + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + uint8_t* scales_sh = reinterpret_cast( + aligned_shared + in_buff_size_aligned + out_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STAGES]; + + T thread_data[TILE_M]; + uint32_t rbits = 0; // Reserved for stochastic rounding + const size_t scale_stride = rows / TILE_M; + + // Master thread init memory barriers for all stages + // fence for tma, synchronize threads so all see mbarrier + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + // Launch first async copy before entering the loop + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + // buffer memory offset in shared memory (we use double buffering for + // pipelining) + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_row_offset = stage * TILE_M; + + if (next_stage < STAGES) { + // before launching another async copy, check that there is less than 2 + // (to ensure that shared -> global synch is finished and buffer can be + // reused) + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_stage * TILE_M; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(next_row_offset), + tile_bytes, + &mbar[next_stage], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + // Wait until the data is ready, parity is always 0 because for simplicity + // we dont reuse barriers between stages + ptx::mbarrier_wait_parity(&mbar[stage], 0); + const size_t buff_offset = buff * BUFF_ELEMS; + // Read the data from shared to registers +#pragma unroll + for (int row = 0; row < TILE_M; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + // Compute scale: find max absolute value in the column group + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int row = 0; row < TILE_M; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + // Store scale to shared memory buffer + // TODO: currently it results in 4-way bank conflict + scales_sh[buff * TILE_K + tidx] = s.__x; + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize and write output to shared memory +#pragma unroll + for (int j = 0; j < TILE_M / 4; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + uint32_t quantized_val = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; + } + __syncthreads(); + // Thread tidx computes scale for input column (block_offset_col + tidx) + // This scale goes to output row (block_offset_col + tidx), column + // (global_row_group) + const size_t global_row_group = + (block_offset_row + stage_row_offset) / TILE_M; + const size_t global_col = block_offset_col + tidx; + // TODO: scale writing is not good + if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = + scales_sh[buff * TILE_K + tidx]; + } + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (is_master) { + const size_t global_row = block_offset_row + stage_row_offset; + const uint32_t out_x = static_cast(global_row); + const uint32_t out_y = static_cast(block_offset_col); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + // Wait for all TMA stores to complete + ptx::cp_async_bulk_wait_group_read<0>(); + + __syncthreads(); + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +template +__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_nvfp4( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols, + float* global_scale) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + // placeholder TODO - NVFP4 TMA kernel not yet implemented +#endif // __CUDA_ARCH__ >= 1000 +} + +template +__global__ void fp_dequantize( + const uint8_t* w, + const uint8_t* scales, + T* out, + size_t size, + float* global_scale = nullptr) { + auto block_size = cg::this_thread_block().dim_threads(); + auto block_idx = cg::this_thread_block().group_index(); + auto idx_in_block = cg::this_thread_block().thread_index(); + + auto tidx = block_idx.x * block_size.x + idx_in_block.x; + auto tidy = block_idx.y * block_size.y + idx_in_block.y; + + auto grid_dim_x = cg::this_grid().dim_blocks().x * block_size.x; + + constexpr int pack_factor = bits == 8 ? 1 : 2; + const bool use_global_scale = global_scale != nullptr; + const float inv_scale_enc = use_mx_scale + ? 1.0f + : (use_global_scale ? (*global_scale) / (F8E4M3_MAX * F4E2M1_MAX) : 1.0f); + size_t offset = tidx + grid_dim_x * size_t(tidy); + size_t oindex = offset * pack_factor; + + if (oindex >= size) { + return; + } + + size_t gindex = oindex / group_size; + using ScaleType = + std::conditional_t; + auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; + + out += oindex; + + uint32_t val = w[offset]; +#pragma clang loop unroll(full) + for (int i = 0; i < pack_factor; i++) { + uint8_t d; + if (bits == 4) { + d = (val >> (bits * i)) & 0x0f; + } else if (bits == 8) { + d = val; + } + out[i] = static_cast(scale * Dequantize{}(d)); + } +} + +} // namespace cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/qqmm_utils.cu b/mlx/backend/cuda/quantized/qqmm_utils.cu index d19865a3b3..831f69620c 100644 --- a/mlx/backend/cuda/quantized/qqmm_utils.cu +++ b/mlx/backend/cuda/quantized/qqmm_utils.cu @@ -10,11 +10,6 @@ namespace mlx::core { namespace cg = cooperative_groups; -constexpr int TILE_ROWS = 128; -constexpr int TILE_COLS = 4; -constexpr int TILES_PER_LANE = 1; -constexpr int LANES_PER_BLOCK = 32; - // To pass scales to tensor cores, they need to be repacked into a tiled layout // https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout // Tiled layout for scale factors is very well described in CUTLASS @@ -48,6 +43,15 @@ constexpr int LANES_PER_BLOCK = 32; // [252, 253, 254, 255], // [380, 381, 382, 383], // [508, 509, 510, 511]]]]], +namespace cu { + +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + +constexpr int TILE_ROWS = 128; +constexpr int TILE_COLS = 4; +constexpr int TILES_PER_LANE = 1; +constexpr int LANES_PER_BLOCK = 32; inline std::tuple get_swizzle_launch_args( size_t M_swizzled, @@ -68,11 +72,6 @@ inline std::tuple get_swizzle_launch_args( return std::make_tuple(grid, block); } -namespace cu { - -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - __global__ void compute_qqmm_pointers( float* alpha_out, float* beta_out, @@ -225,7 +224,7 @@ void swizzle_scales( size_t output_cols = scales_tiled.shape(-1); auto [num_blocks, block_dims] = - get_swizzle_launch_args(output_rows, output_cols); + cu::get_swizzle_launch_args(output_rows, output_cols); enc.add_kernel_node( cu::swizzle_scales, num_blocks, diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index 8cbe9b297d..a5a918d1da 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -1,4 +1,5 @@ // Copyright © 2025 Apple Inc. +#pragma once #include #include From 6248b16955c6a796c97e302518419d1c48bc2730 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Wed, 18 Feb 2026 17:30:22 +0100 Subject: [PATCH 04/15] fixed gpu_ptr --- mlx/backend/cuda/quantized/fp_quantize.cu | 46 +++++++++++----------- mlx/backend/cuda/quantized/fp_quantize.cuh | 3 -- 2 files changed, 23 insertions(+), 26 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 359c1b78bc..cfa2899ae9 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -249,38 +249,38 @@ void fp_quantize_columnwise_tma( auto [grid, block, smem_size] = cu::get_columnwise_tma_launch_args( rows, cols, group_size, bits, w.itemsize()); - CUtensorMap tensor_map_input; - CUtensorMap tensor_map_output; - - create_2D_tensor_map( - &tensor_map_input, - const_cast(static_cast(w.data())), - cu::get_tma_dtype(w.dtype()), - rows, - cols, - static_cast(cu::TILE_M), - static_cast(cu::TILE_K), - stride_bytes); - const size_t output_rows = cols; const size_t output_cols = (bits == 8) ? rows : rows / 2; const uint32_t out_tile_x = (bits == 8) ? cu::TILE_M : cu::TILE_M / 2; const uint32_t out_tile_y = cu::TILE_K; - create_2D_tensor_map( - &tensor_map_output, - static_cast(wq.data()), - CU_TENSOR_MAP_DATA_TYPE_UINT8, - output_rows, - output_cols, - out_tile_y, - out_tile_x, - output_cols); - dispatch_float_types( w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { using T = cuda_type_t; if constexpr (!std::is_same_v) { + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + create_2D_tensor_map( + &tensor_map_input, + const_cast(static_cast(gpu_ptr(w))), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + static_cast(cu::TILE_M), + static_cast(cu::TILE_K), + stride_bytes); + + create_2D_tensor_map( + &tensor_map_output, + gpu_ptr(wq), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + output_rows, + output_cols, + out_tile_y, + out_tile_x, + output_cols); + // Currently only MXFP8 (bits=8, group_size=32) is implemented // TODO: Add NVFP4 support auto kernel = cu::fp_quantize_columnwise_tma_mxfp8; diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index 3ce7a0a2a3..97e2fa4e89 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -361,9 +361,6 @@ __global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_mxfp8( TMA_SHMEM_ALIGNMENT) * TMA_SHMEM_ALIGNMENT; - constexpr size_t scales_per_tile = TILE_K; // One scale per column - constexpr size_t scales_tile_size = scales_per_tile * sizeof(uint8_t); - extern __shared__ char shared_mem[]; uintptr_t aligned_shared = (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & From a2e0f4d3e2e95093cbc4dc85de2a5799eabce52f Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 20 Feb 2026 05:48:02 +0100 Subject: [PATCH 05/15] [wip] rowwise tma, columnwise tma --- mlx/backend/cuda/quantized/fp_quantize.cu | 333 +++++++++---- mlx/backend/cuda/quantized/fp_quantize.cuh | 216 +-------- .../cuda/quantized/fp_quantize_tma.cuh | 439 ++++++++++++++++++ .../cuda/quantized/quantized_utils.cuh | 3 + 4 files changed, 681 insertions(+), 310 deletions(-) create mode 100644 mlx/backend/cuda/quantized/fp_quantize_tma.cuh diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index cfa2899ae9..fc567146da 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -3,6 +3,7 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/fp_quantize.cuh" +#include "mlx/backend/cuda/quantized/fp_quantize_tma.cuh" #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/quantized/quantized_utils.cuh" #include "mlx/backend/cuda/vector_types.cuh" @@ -17,23 +18,29 @@ namespace mlx::core { namespace cu { -inline std::tuple get_columnwise_tma_launch_args( - size_t rows, - size_t cols, - int group_size, - int bits, - int bytes_per_element) { +template < + int TILE_M, + int TILE_K, + int THREADS_PER_BLOCK, + int STAGES, + int SCALES_PER_STAGE> +inline std::tuple get_tma_launch_args( + size_t grid_dim_x_size, + size_t grid_dim_y_size, + size_t block_size_x, + size_t block_size_y, + int in_size_bytes, + int bits) { dim3 grid; - grid.x = (cols + COLS_PER_BLOCK - 1) / COLS_PER_BLOCK; - grid.y = (rows + ROWS_PER_BLOCK - 1) / ROWS_PER_BLOCK; + grid.x = (grid_dim_x_size + block_size_x - 1) / block_size_x; + grid.y = (grid_dim_y_size + block_size_y - 1) / block_size_y; grid.z = 1; dim3 block(THREADS_PER_BLOCK, 1, 1); - const int elem_per_byte = (bits == 8) ? 1 : 2; - - const size_t BUFF_ELEMS = TILE_M * TILE_K; - const size_t in_tile_size = BUFF_ELEMS * bytes_per_element; + const int elem_per_byte = bits == 8 ? 1 : 2; + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + const size_t in_tile_size = BUFF_ELEMS * in_size_bytes; const size_t in_buff_size_aligned = ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * @@ -45,10 +52,9 @@ inline std::tuple get_columnwise_tma_launch_args( TMA_SHMEM_ALIGNMENT) * TMA_SHMEM_ALIGNMENT; - const size_t scales_tile_size = TILE_K * sizeof(uint8_t); + const size_t scales_tile_size = STAGES * SCALES_PER_STAGE * sizeof(uint8_t); const size_t scales_buff_size_aligned = - ((scales_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * + ((scales_tile_size + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * TMA_SHMEM_ALIGNMENT; const size_t smem_size = in_buff_size_aligned + out_buff_size_aligned + @@ -131,7 +137,7 @@ void fp_quantize_dequantize( }); } -void fp_quantize_rowwise( +void fp_quantize_rowwise_fallback( const array& w, array& wq, array& scales, @@ -146,35 +152,142 @@ void fp_quantize_rowwise( } enc.set_output_array(wq); enc.set_output_array(scales); - dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize_rowwise; - if (bits == 8) { - kernel = cu::fp_quantize_rowwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_rowwise; + dispatch_float_types( + w.dtype(), "fp_quantize_rowwise_fallback", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_rowwise_fallback; + if (bits == 8) { + kernel = cu::fp_quantize_rowwise_fallback; + } else if (group_size == 16) { + kernel = cu::fp_quantize_rowwise_fallback; + } + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = get_launch_args( + w.size(), w.shape(), w.strides(), large, group_size); + + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); +} + +void fp_quantize_rowwise_tma( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + enc.set_input_array(w); + enc.set_output_array(wq); + enc.set_output_array(scales); + + size_t rows = w.shape(-2); + size_t cols = w.shape(-1); + size_t stride_bytes = cols * w.itemsize(); + + if (bits == 8 && group_size == 32) { + dispatch_float_types(w.dtype(), "fp_quantize_rowwise_mxfp8", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + constexpr int THREADS_PER_BLOCK = 64; + constexpr int ROWS_PER_BLOCK = 64; + constexpr int COLS_PER_BLOCK = 64; + constexpr size_t TILE_M = ROWS_PER_BLOCK; + constexpr size_t TILE_K = 32; // group_size + constexpr size_t STAGES = COLS_PER_BLOCK / TILE_K; + + auto [grid, block, smem_size] = cu::get_tma_launch_args< + TILE_M, + TILE_K, + THREADS_PER_BLOCK, + STAGES, + TILE_M>( + rows, cols, ROWS_PER_BLOCK, COLS_PER_BLOCK, w.itemsize(), bits); + + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + create_2D_tensor_map( + &tensor_map_input, + const_cast(static_cast(gpu_ptr(w))), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + TILE_M, + TILE_K, + stride_bytes); + + create_2D_tensor_map( + &tensor_map_output, + gpu_ptr(wq), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + rows, + cols, + TILE_M, + TILE_K, + cols); + + auto kernel = cu::fp_quantize_rowwise_tma_mxfp8< + T, + TILE_K, + false, + THREADS_PER_BLOCK, + ROWS_PER_BLOCK, + COLS_PER_BLOCK>; + enc.add_kernel_node( + kernel, + grid, + block, + static_cast(smem_size), + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_rowwise_tma] Cannot quantize input with type float64."); } - bool large = w.size() > UINT_MAX; - auto [num_blocks, block_dims] = - get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); + }); + } else { + throw std::runtime_error( + "[fp_quantize_rowwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); + } +} - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); +void fp_quantize_rowwise( + const array& w, + array& wq, + array& scales, + int group_size, + int bits, + const std::optional& global_scale /* = std::nullopt */, + cu::CommandEncoder& enc, + const Stream& s) { + if (enc.device().compute_capability_major() >= 10 && bits == 8 && + group_size == 32) { + fp_quantize_rowwise_tma( + w, wq, scales, group_size, bits, global_scale, enc, s); + } else { + fp_quantize_rowwise_fallback( + w, wq, scales, group_size, bits, global_scale, enc, s); + } } void fp_quantize_columnwise_fallback( @@ -246,59 +359,80 @@ void fp_quantize_columnwise_tma( size_t cols = w.shape(-2); size_t stride_bytes = w.strides(-1) * w.itemsize(); - auto [grid, block, smem_size] = cu::get_columnwise_tma_launch_args( - rows, cols, group_size, bits, w.itemsize()); - - const size_t output_rows = cols; - const size_t output_cols = (bits == 8) ? rows : rows / 2; - const uint32_t out_tile_x = (bits == 8) ? cu::TILE_M : cu::TILE_M / 2; - const uint32_t out_tile_y = cu::TILE_K; - - dispatch_float_types( - w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - CUtensorMap tensor_map_input; - CUtensorMap tensor_map_output; - - create_2D_tensor_map( - &tensor_map_input, - const_cast(static_cast(gpu_ptr(w))), - cu::get_tma_dtype(w.dtype()), - rows, - cols, - static_cast(cu::TILE_M), - static_cast(cu::TILE_K), - stride_bytes); - - create_2D_tensor_map( - &tensor_map_output, - gpu_ptr(wq), - CU_TENSOR_MAP_DATA_TYPE_UINT8, - output_rows, - output_cols, - out_tile_y, - out_tile_x, - output_cols); - - // Currently only MXFP8 (bits=8, group_size=32) is implemented - // TODO: Add NVFP4 support - auto kernel = cu::fp_quantize_columnwise_tma_mxfp8; - enc.add_kernel_node( - kernel, - grid, - block, - static_cast(smem_size), - tensor_map_input, - tensor_map_output, - gpu_ptr(scales), - rows, - cols); - } else { - throw std::runtime_error( - "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); - } - }); + // 16 bytes is minimum for TMA tile size along contigious dimension + // For nvfp4 (group_size = 16), we can't have a tile = group_size because we + // write transposed output so we will do 2 passes on the input tile + if (bits == 8 && group_size == 32) { + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + constexpr int THREADS_PER_BLOCK = 64; + constexpr int ROWS_PER_BLOCK = 64; + constexpr int COLS_PER_BLOCK = 64; + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; + + // For columnwise: grid.x = cols, grid.y = rows + // scales_per_stage = TILE_K (one scale per column per stage) + auto [grid, block, smem_size] = cu::get_tma_launch_args< + TILE_M, + TILE_K, + THREADS_PER_BLOCK, + STAGES, + TILE_K>( + cols, rows, COLS_PER_BLOCK, ROWS_PER_BLOCK, w.itemsize(), bits); + + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + create_2D_tensor_map( + &tensor_map_input, + const_cast(static_cast(gpu_ptr(w))), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + TILE_M, + TILE_K, + stride_bytes); + + create_2D_tensor_map( + &tensor_map_output, + gpu_ptr(wq), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + cols, + rows, + TILE_K, + TILE_M, + rows); + + auto kernel = cu::fp_quantize_columnwise_tma_mxfp8< + T, + false, + THREADS_PER_BLOCK, + COLS_PER_BLOCK, + ROWS_PER_BLOCK>; + enc.add_kernel_node( + kernel, + grid, + block, + static_cast(smem_size), + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); + } + }); + } else { + // NVFP4 or other configurations - TMA kernel not implemented yet + throw std::runtime_error( + "[fp_quantize_columnwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); + } } void fp_quantize_columnwise( @@ -310,7 +444,12 @@ void fp_quantize_columnwise( const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { - if (enc.device().compute_capability_major() >= 10) { + // Use TMA version for SM100+ with MXFP8 (bits=8, group_size=32) + // NVFP4 todo + bool use_tma = + (enc.device().compute_capability_major() >= 10 && bits == 8 && + group_size == 32 && !global_scale.has_value()); + if (use_tma) { fp_quantize_columnwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); } else { diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index 97e2fa4e89..bf89061b00 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -1,6 +1,6 @@ // Copyright © 2025 Apple Inc. -#include "mlx/backend/cuda/common.h" -#include "mlx/backend/cuda/ptx.cuh" +#pragma once + #include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" #include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" #include "mlx/backend/cuda/quantized/quantized_utils.cuh" @@ -15,18 +15,6 @@ namespace mlx::core { namespace cu { -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - -// constants for tma kernels -constexpr size_t TMA_SHMEM_ALIGNMENT = 128; -constexpr size_t ROWS_PER_BLOCK = 128; -constexpr size_t COLS_PER_BLOCK = 128; -constexpr size_t THREADS_PER_BLOCK = 128; -constexpr size_t TILE_M = 32; -constexpr size_t TILE_K = 128; -constexpr size_t BUFFS_NUM = 2; - template struct Dequantize { __device__ float operator()(uint8_t x) { @@ -115,7 +103,7 @@ __global__ void fp_quantize_dequantize( } template -__global__ void fp_quantize_rowwise( +__global__ void fp_quantize_rowwise_fallback( T* w, uint8_t* out, uint8_t* scales, @@ -323,204 +311,6 @@ __global__ void fp_quantize_columnwise_fallback( } } -template -__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_mxfp8( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - - constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; // 4 for gs=32, 8 for gs=16 - constexpr int elem_per_byte = 1; - - const auto block_idx = cg::this_thread_block().group_index(); - const auto idx_in_block = cg::this_thread_block().thread_index(); - const int tidx = idx_in_block.x; // Thread handles column tidx - const bool is_master = (tidx == 0); - - const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; - const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; - - constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; - constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); - constexpr size_t in_buff_size_aligned = - ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - // Transposed output tile: TILE_K rows x TILE_M cols (128 x group_size) - // For FP8: 128 * group_size bytes, for FP4: 128 * (group_size/2) bytes - constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; - constexpr size_t out_tile_size = out_tile_elems; - constexpr size_t out_buff_size_aligned = - ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - extern __shared__ char shared_mem[]; - uintptr_t aligned_shared = - (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - T* in_sh = reinterpret_cast(aligned_shared); - uint8_t* out_sh = - reinterpret_cast(aligned_shared + in_buff_size_aligned); - uint8_t* scales_sh = reinterpret_cast( - aligned_shared + in_buff_size_aligned + out_buff_size_aligned); - - constexpr uint32_t tile_bytes = static_cast(in_tile_size); - - __shared__ alignas(8) uint64_t mbar[STAGES]; - - T thread_data[TILE_M]; - uint32_t rbits = 0; // Reserved for stochastic rounding - const size_t scale_stride = rows / TILE_M; - - // Master thread init memory barriers for all stages - // fence for tma, synchronize threads so all see mbarrier - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); - } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - // Launch first async copy before entering the loop - copy_2d_to_shared( - &in_sh[0], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(block_offset_row), - tile_bytes, - &mbar[0], - is_master); - -#pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { - // buffer memory offset in shared memory (we use double buffering for - // pipelining) - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_row_offset = stage * TILE_M; - - if (next_stage < STAGES) { - // before launching another async copy, check that there is less than 2 - // (to ensure that shared -> global synch is finished and buffer can be - // reused) - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_row_offset = block_offset_row + next_stage * TILE_M; - const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; - - copy_2d_to_shared( - &in_sh[next_buff_elem_offset], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(next_row_offset), - tile_bytes, - &mbar[next_stage], - is_master); - } - - ptx::fence_proxy_async_shared_cta(); - // Wait until the data is ready, parity is always 0 because for simplicity - // we dont reuse barriers between stages - ptx::mbarrier_wait_parity(&mbar[stage], 0); - const size_t buff_offset = buff * BUFF_ELEMS; - // Read the data from shared to registers -#pragma unroll - for (int row = 0; row < TILE_M; ++row) { - thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; - } - // Compute scale: find max absolute value in the column group - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; -#pragma unroll - for (int row = 0; row < TILE_M; row += 2) { - auto pair = Tx2{thread_data[row], thread_data[row + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - - scale /= F8E4M3_MAX; - - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); - scale = float(s); - // Store scale to shared memory buffer - // TODO: currently it results in 4-way bank conflict - scales_sh[buff * TILE_K + tidx] = s.__x; - const size_t out_buff_offset = buff * out_tile_elems; - // Quantize and write output to shared memory -#pragma unroll - for (int j = 0; j < TILE_M / 4; ++j) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - uint32_t quantized_val = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; - } - __syncthreads(); - // Thread tidx computes scale for input column (block_offset_col + tidx) - // This scale goes to output row (block_offset_col + tidx), column - // (global_row_group) - const size_t global_row_group = - (block_offset_row + stage_row_offset) / TILE_M; - const size_t global_col = block_offset_col + tidx; - // TODO: scale writing is not good - if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { - scales[global_col * scale_stride + global_row_group] = - scales_sh[buff * TILE_K + tidx]; - } - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - - if (is_master) { - const size_t global_row = block_offset_row + stage_row_offset; - const uint32_t out_x = static_cast(global_row); - const uint32_t out_y = static_cast(block_offset_col); - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - out_x, - out_y, - reinterpret_cast(&out_sh[out_buff_offset])); - ptx::cp_async_bulk_commit_group(); - } - } - // Wait for all TMA stores to complete - ptx::cp_async_bulk_wait_group_read<0>(); - - __syncthreads(); - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_invalidate(&mbar[iter]); - } - } -#endif // __CUDA_ARCH__ >= 1000 -} - -template -__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_nvfp4( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols, - float* global_scale) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - // placeholder TODO - NVFP4 TMA kernel not yet implemented -#endif // __CUDA_ARCH__ >= 1000 -} - template __global__ void fp_dequantize( const uint8_t* w, diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh new file mode 100644 index 0000000000..0589f8d122 --- /dev/null +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -0,0 +1,439 @@ +// Copyright © 2026 Apple Inc. +#pragma once + +#include "mlx/backend/cuda/common.h" +#include "mlx/backend/cuda/ptx.cuh" +#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" +#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" +#include "mlx/backend/cuda/quantized/quantized_utils.cuh" +#include "mlx/backend/cuda/vector_types.cuh" +#include "mlx/dtype_utils.h" + +#include +#include +#include +#include + +namespace mlx::core { +namespace cu { + +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; +constexpr size_t BUFFS_NUM = 2; +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4); // 32 banks, 4 bytes + +namespace cg = cooperative_groups; + +// TMA-based rowwise quantization for MXFP8 +// Input: [rows, cols] row-major +// Output: [rows, cols / 32 / group_size] row-major +// Scales: [rows, cols / group_size] - one scale per group +// Each thread processes one row of group_size elements per stage, +// Each block processes TILE_M x TILE_K elements per stage. +template < + typename T, + int group_size, + bool USE_SR, + int THREADS_PER_BLOCK, + int ROWS_PER_BLOCK, + int COLS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_rowwise_tma_mxfp8( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + + constexpr size_t TILE_M = ROWS_PER_BLOCK; + constexpr size_t TILE_K = group_size; + constexpr size_t STAGES = COLS_PER_BLOCK / group_size; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles row tidx + const bool is_master = (tidx == 0); + const bool is_active = (tidx < ROWS_PER_BLOCK); + + const size_t block_offset_row = block_idx.x * ROWS_PER_BLOCK; + const size_t block_offset_col = block_idx.y * COLS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t out_tile_elems = BUFF_ELEMS; + constexpr size_t out_buff_size_aligned = + ((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + uint8_t* scales_sh = reinterpret_cast( + aligned_shared + in_buff_size_aligned + out_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STAGES]; // 8 bytes alignment for mbarrier + + T thread_data[TILE_K]; // register storage for one row of TILE_K elements + uint32_t rbits = 0; + + const size_t groups_per_row = cols / TILE_K; + + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_col_offset = stage * TILE_K; + + if (next_stage < STAGES) { + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_col_offset = block_offset_col + next_stage * TILE_K; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(next_col_offset), + static_cast(block_offset_row), + tile_bytes, + &mbar[next_stage], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + ptx::mbarrier_wait_parity(&mbar[stage], 0); + + const size_t buff_offset = buff * BUFF_ELEMS; + const size_t out_buff_offset = buff * out_tile_elems; + const int lane = tidx % 32; + constexpr int GROUPS = TILE_K / 4; + if (is_active) { + // First made 2 naive mistakes were made: + // 1. Stored without swizzling that results in 32 / 16 bank conflicts + // (depending of the input type) + // 2. More naive mistake was indexing registes with swizzling index that + // is known only at compile time this results in massive slow down of + // course +#pragma unroll + for (int j = 0; j < GROUPS; ++j) { + int swizzled_j = (j + lane) % GROUPS; + *reinterpret_cast(&thread_data[j * 4]) = *reinterpret_cast( + &in_sh[buff_offset + tidx * TILE_K + swizzled_j * 4]); + } + // Compute scale: find max absolute value in the row (order-independent) + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int col = 0; col < TILE_K; col += 2) { + auto pair = Tx2{thread_data[col], thread_data[col + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + scales_sh[stage * TILE_M + tidx] = s.__x; + +#pragma unroll + for (int j = 0; j < GROUPS; ++j) { + int swizzled_j = (j + lane) % GROUPS; + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + uint32_t quantized_val = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * TILE_K + swizzled_j * 4]) = + quantized_val; + } + } + __syncthreads(); + + if (is_master) { + const size_t global_col = block_offset_col + stage_col_offset; + const uint32_t out_x = static_cast(global_col); + const uint32_t out_y = static_cast(block_offset_row); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + + ptx::cp_async_bulk_wait_group_read<0>(); + __syncthreads(); + + const size_t global_row = block_offset_row + tidx; + const size_t base_group = block_offset_col / TILE_K; + + if (is_active && global_row < rows) { +#pragma unroll + for (size_t s = 0; s < STAGES; ++s) { + const size_t group_idx = base_group + s; + if ((block_offset_col + s * TILE_K) < cols) { + scales[global_row * groups_per_row + group_idx] = + scales_sh[s * TILE_M + tidx]; + } + } + } + + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +template < + typename T, + bool USE_SR, + int THREADS_PER_BLOCK, + int COLS_PER_BLOCK, + int ROWS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_columnwise_tma_mxfp8( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; + constexpr int elem_per_byte = 1; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles column tidx + const bool is_master = (tidx == 0); + const bool is_active = (tidx < COLS_PER_BLOCK); + + const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; + const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + constexpr size_t out_tile_size = out_tile_elems; + constexpr size_t out_buff_size_aligned = + ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + uint8_t* scales_sh = reinterpret_cast( + aligned_shared + in_buff_size_aligned + out_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STAGES]; + + T thread_data[TILE_M]; + uint32_t rbits = 0; // Reserved for stochastic rounding + const size_t scale_stride = rows / TILE_M; + + // Master thread init memory barriers for all stages + // fence for tma, synchronize threads so all see mbarrier + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + // Launch first async copy before entering the loop + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t stage = 0; stage < STAGES; ++stage) { + // buffer memory offset in shared memory (we use double buffering for + // pipelining) + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_row_offset = stage * TILE_M; + + if (next_stage < STAGES) { + // before launching another async copy, check that there is less than 2 + // (to ensure that shared -> global synch is finished and buffer can be + // reused) + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_stage * TILE_M; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(next_row_offset), + tile_bytes, + &mbar[next_stage], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + // Wait until the data is ready, parity is always 0 because for simplicity + // we dont reuse barriers between stages + ptx::mbarrier_wait_parity(&mbar[stage], 0); + const size_t buff_offset = buff * BUFF_ELEMS; + const size_t out_buff_offset = buff * out_tile_elems; + if (is_active) { + // Read the data from shared to registers +#pragma unroll + for (int row = 0; row < TILE_M; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int row = 0; row < TILE_M; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + // Store scale to shared memory buffer + // TODO: this results in 4 way bank conflicts + // (tidx+0,tidx+1,tidx+2,tidx+3 writes to the same bank) + scales_sh[stage * TILE_K + tidx] = s.__x; + // Quantize and write output to shared memory +#pragma unroll + for (int j = 0; j < TILE_M / 4; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + uint32_t quantized_val = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; + } + } + __syncthreads(); + + if (is_master) { + const size_t global_row = block_offset_row + stage_row_offset; + const uint32_t out_x = static_cast(global_row); + const uint32_t out_y = static_cast(block_offset_col); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + // Wait for all TMA stores to complete + ptx::cp_async_bulk_wait_group_read<0>(); + + __syncthreads(); + + const size_t base_row_group = block_offset_row / TILE_M; + const size_t global_col = block_offset_col + tidx; + + // This also results in 4way bank conflicts + // This can be done better + if (is_active && global_col < cols) { +#pragma unroll + for (size_t s = 0; s < STAGES; ++s) { + if ((block_offset_row + s * TILE_M) < rows) { + scales[global_col * scale_stride + base_row_group + s] = + scales_sh[s * TILE_K + tidx]; + } + } + } + + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STAGES; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +template +__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_nvfp4( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols, + float* global_scale) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + // placeholder TODO - NVFP4 TMA kernel not yet implemented +#endif // __CUDA_ARCH__ >= 1000 +} + +} // namespace cu +} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index a5a918d1da..be93362fe6 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -8,6 +8,9 @@ namespace mlx::core { namespace cu { +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; + inline __device__ float4 dequant_fp8(uint32_t bits) { auto out = *(__nv_fp8x4_e4m3*)(&bits); return out.operator float4(); From 5a7668cfb1655318d92a1444b709d99208a47240 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Fri, 20 Feb 2026 16:29:02 +0100 Subject: [PATCH 06/15] fixed batched case --- mlx/backend/cuda/quantized/fp_quantize.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index fc567146da..1e1d337b87 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -197,8 +197,8 @@ void fp_quantize_rowwise_tma( enc.set_output_array(wq); enc.set_output_array(scales); - size_t rows = w.shape(-2); size_t cols = w.shape(-1); + size_t rows = w.size() / cols; size_t stride_bytes = cols * w.itemsize(); if (bits == 8 && group_size == 32) { @@ -356,7 +356,7 @@ void fp_quantize_columnwise_tma( enc.set_output_array(scales); size_t rows = w.shape(-1); - size_t cols = w.shape(-2); + size_t cols = w.size() / rows; size_t stride_bytes = w.strides(-1) * w.itemsize(); // 16 bytes is minimum for TMA tile size along contigious dimension From 145f6e679e80661675ad972a37920da2094d8358 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Sun, 22 Feb 2026 22:02:57 +0100 Subject: [PATCH 07/15] [wip] bank conflict --- mlx/backend/cuda/quantized/fp_quantize.cu | 23 +- .../cuda/quantized/fp_quantize_tma.cuh | 224 ++++++++++-------- 2 files changed, 136 insertions(+), 111 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 1e1d337b87..fab0dccf3b 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -25,11 +25,11 @@ template < int STAGES, int SCALES_PER_STAGE> inline std::tuple get_tma_launch_args( - size_t grid_dim_x_size, - size_t grid_dim_y_size, - size_t block_size_x, - size_t block_size_y, - int in_size_bytes, + size_t grid_dim_x_size, // rows + size_t grid_dim_y_size, // cols + size_t block_size_x, // ROWS_PER_BLOCK + size_t block_size_y, // COL_PER_BLOCK + int in_size_bytes, // itemsize int bits) { dim3 grid; grid.x = (grid_dim_x_size + block_size_x - 1) / block_size_x; @@ -54,7 +54,8 @@ inline std::tuple get_tma_launch_args( const size_t scales_tile_size = STAGES * SCALES_PER_STAGE * sizeof(uint8_t); const size_t scales_buff_size_aligned = - ((scales_tile_size + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * + ((scales_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * TMA_SHMEM_ALIGNMENT; const size_t smem_size = in_buff_size_aligned + out_buff_size_aligned + @@ -280,8 +281,11 @@ void fp_quantize_rowwise( const std::optional& global_scale /* = std::nullopt */, cu::CommandEncoder& enc, const Stream& s) { + size_t cols = w.shape(-1); + size_t rows = w.size() / cols; + const bool has_full_tma_tiles = ((rows % 128) == 0) && ((cols % 128) == 0); if (enc.device().compute_capability_major() >= 10 && bits == 8 && - group_size == 32) { + group_size == 32 && !global_scale.has_value() && has_full_tma_tiles) { fp_quantize_rowwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); } else { @@ -446,9 +450,12 @@ void fp_quantize_columnwise( const Stream& s) { // Use TMA version for SM100+ with MXFP8 (bits=8, group_size=32) // NVFP4 todo + const size_t rows = w.shape(-1); + const size_t cols = w.size() / rows; + const bool has_full_tma_tiles = ((rows % 128) == 0) && ((cols % 128) == 0); bool use_tma = (enc.device().compute_capability_major() >= 10 && bits == 8 && - group_size == 32 && !global_scale.has_value()); + group_size == 32 && !global_scale.has_value() && has_full_tma_tiles); if (use_tma) { fp_quantize_columnwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh index 0589f8d122..46d0ea9306 100644 --- a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -19,7 +19,6 @@ namespace cu { constexpr size_t TMA_SHMEM_ALIGNMENT = 128; constexpr size_t BUFFS_NUM = 2; -constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4); // 32 banks, 4 bytes namespace cg = cooperative_groups; @@ -55,7 +54,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const auto idx_in_block = cg::this_thread_block().thread_index(); const int tidx = idx_in_block.x; // Thread handles row tidx const bool is_master = (tidx == 0); - const bool is_active = (tidx < ROWS_PER_BLOCK); const size_t block_offset_row = block_idx.x * ROWS_PER_BLOCK; const size_t block_offset_col = block_idx.y * COLS_PER_BLOCK; @@ -137,53 +135,64 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ptx::mbarrier_wait_parity(&mbar[stage], 0); const size_t buff_offset = buff * BUFF_ELEMS; - const size_t out_buff_offset = buff * out_tile_elems; const int lane = tidx % 32; constexpr int GROUPS = TILE_K / 4; - if (is_active) { - // First made 2 naive mistakes were made: - // 1. Stored without swizzling that results in 32 / 16 bank conflicts - // (depending of the input type) - // 2. More naive mistake was indexing registes with swizzling index that - // is known only at compile time this results in massive slow down of - // course + // First made 2 naive mistakes were made: + // 1. Stored without swizzling that results in 32 / 16 bank conflicts + // (depending of the input type) + // 2. More naive mistake was indexing registes with swizzling index that + // is known only at compile time this results in massive slow down of + // course #pragma unroll - for (int j = 0; j < GROUPS; ++j) { - int swizzled_j = (j + lane) % GROUPS; - *reinterpret_cast(&thread_data[j * 4]) = *reinterpret_cast( - &in_sh[buff_offset + tidx * TILE_K + swizzled_j * 4]); - } - // Compute scale: find max absolute value in the row (order-independent) - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; + for (int j = 0; j < GROUPS; ++j) { + int swizzled_j = (j + lane) % GROUPS; + *reinterpret_cast(&thread_data[j * 4]) = *reinterpret_cast( + &in_sh[buff_offset + tidx * TILE_K + swizzled_j * 4]); + } + // Compute scale: find max absolute value in the row (order-independent) + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; #pragma unroll - for (int col = 0; col < TILE_K; col += 2) { - auto pair = Tx2{thread_data[col], thread_data[col + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); + for (int col = 0; col < TILE_K; col += 2) { + auto pair = Tx2{thread_data[col], thread_data[col + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } - scale /= F8E4M3_MAX; + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); - scale = float(s); - scales_sh[stage * TILE_M + tidx] = s.__x; + scale /= F8E4M3_MAX; + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + scales_sh[buff * TILE_K + tidx] = s.__x; + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize to registers first + uint32_t quantized_regs[GROUPS]; #pragma unroll - for (int j = 0; j < GROUPS; ++j) { - int swizzled_j = (j + lane) % GROUPS; - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - uint32_t quantized_val = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * TILE_K + swizzled_j * 4]) = - quantized_val; - } + for (int j = 0; j < GROUPS; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + quantized_regs[j] = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + } + const int group = lane / 4; // 8 groups of 4 threads +#pragma unroll + for (int j = 0; j < GROUPS; ++j) { + int rotated_j = (j + group) % GROUPS; + *reinterpret_cast( + &out_sh[out_buff_offset + tidx * TILE_K + rotated_j * 4]) = + quantized_regs[rotated_j]; } __syncthreads(); + const size_t global_row = block_offset_row + tidx; + const size_t global_group = (block_offset_col + stage_col_offset) / TILE_K; + if (global_row < rows) { + scales[global_row * groups_per_row + global_group] = + scales_sh[buff * TILE_K + tidx]; + } + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); if (is_master) { const size_t global_col = block_offset_col + stage_col_offset; @@ -202,20 +211,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ptx::cp_async_bulk_wait_group_read<0>(); __syncthreads(); - const size_t global_row = block_offset_row + tidx; - const size_t base_group = block_offset_col / TILE_K; - - if (is_active && global_row < rows) { -#pragma unroll - for (size_t s = 0; s < STAGES; ++s) { - const size_t group_idx = base_group + s; - if ((block_offset_col + s * TILE_K) < cols) { - scales[global_row * groups_per_row + group_idx] = - scales_sh[s * TILE_M + tidx]; - } - } - } - if (is_master) { #pragma unroll for (int iter = 0; iter < STAGES; ++iter) { @@ -251,7 +246,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const auto idx_in_block = cg::this_thread_block().thread_index(); const int tidx = idx_in_block.x; // Thread handles column tidx const bool is_master = (tidx == 0); - const bool is_active = (tidx < COLS_PER_BLOCK); const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; @@ -341,43 +335,83 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) // we dont reuse barriers between stages ptx::mbarrier_wait_parity(&mbar[stage], 0); const size_t buff_offset = buff * BUFF_ELEMS; - const size_t out_buff_offset = buff * out_tile_elems; - if (is_active) { - // Read the data from shared to registers + // Read the data from shared to registers #pragma unroll - for (int row = 0; row < TILE_M; ++row) { - thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; - } - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; + for (int row = 0; row < TILE_M; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; #pragma unroll - for (int row = 0; row < TILE_M; row += 2) { - auto pair = Tx2{thread_data[row], thread_data[row + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - - scale /= F8E4M3_MAX; - - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); - scale = float(s); - // Store scale to shared memory buffer - // TODO: this results in 4 way bank conflicts - // (tidx+0,tidx+1,tidx+2,tidx+3 writes to the same bank) - scales_sh[stage * TILE_K + tidx] = s.__x; - // Quantize and write output to shared memory + for (int row = 0; row < TILE_M; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + scales_sh[buff * TILE_K + tidx] = s.__x; + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize to registers first + constexpr int GROUPS = TILE_M / 4; + uint32_t quantized_regs[GROUPS]; #pragma unroll - for (int j = 0; j < TILE_M / 4; ++j) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - uint32_t quantized_val = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * TILE_M + j * 4]) = quantized_val; - } + for (int j = 0; j < GROUPS; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + quantized_regs[j] = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + } + // Write output to shared memory with swapped store order to reduce bank + // conflicts. Without swap: stride between threads is TILE_M=32 bytes + // = 8 banks, every 4th thread hits the same bank -> 8-way conflict. + const int lane = tidx % 32; + const int group = (lane / 4) % 2; + const size_t base = out_buff_offset + tidx * TILE_M; + switch (group) { + case 0: + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + break; + case 1: + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + break; + } + __syncthreads(); + // Thread tidx computes scale for input column (block_offset_col + tidx) + // This scale goes to output row (block_offset_col + tidx), column + // (global_row_group) + const size_t global_row_group = + (block_offset_row + stage_row_offset) / TILE_M; + const size_t global_col = block_offset_col + tidx; + // TODO: scale writing is not good + if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = + scales_sh[buff * TILE_K + tidx]; } + ptx::fence_proxy_async_shared_cta(); __syncthreads(); if (is_master) { @@ -397,22 +431,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ptx::cp_async_bulk_wait_group_read<0>(); __syncthreads(); - - const size_t base_row_group = block_offset_row / TILE_M; - const size_t global_col = block_offset_col + tidx; - - // This also results in 4way bank conflicts - // This can be done better - if (is_active && global_col < cols) { -#pragma unroll - for (size_t s = 0; s < STAGES; ++s) { - if ((block_offset_row + s * TILE_M) < rows) { - scales[global_col * scale_stride + base_row_group + s] = - scales_sh[s * TILE_K + tidx]; - } - } - } - if (is_master) { #pragma unroll for (int iter = 0; iter < STAGES; ++iter) { From 66fb9eb532285d7bdafa03c77faa20efd09152bc Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 23 Feb 2026 01:32:37 +0100 Subject: [PATCH 08/15] drop rowwise for now --- mlx/backend/cuda/quantized/fp_quantize.cu | 252 +++++------------- mlx/backend/cuda/quantized/fp_quantize.cuh | 4 +- .../cuda/quantized/fp_quantize_tma.cuh | 219 +-------------- 3 files changed, 74 insertions(+), 401 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index fab0dccf3b..20b6244dbe 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -51,16 +51,8 @@ inline std::tuple get_tma_launch_args( ((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / TMA_SHMEM_ALIGNMENT) * TMA_SHMEM_ALIGNMENT; - - const size_t scales_tile_size = STAGES * SCALES_PER_STAGE * sizeof(uint8_t); - const size_t scales_buff_size_aligned = - ((scales_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - const size_t smem_size = in_buff_size_aligned + out_buff_size_aligned + - scales_buff_size_aligned + TMA_SHMEM_ALIGNMENT; - + const size_t smem_size = + in_buff_size_aligned + out_buff_size_aligned + TMA_SHMEM_ALIGNMENT; return std::make_tuple(grid, block, smem_size); } @@ -79,7 +71,7 @@ inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { } inline std::tuple -get_columnwise_fallback_launch_args(size_t size, int group_size, int M, int K) { +get_columnwise_launch_args(size_t size, int group_size, int M, int K) { constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; @@ -138,7 +130,7 @@ void fp_quantize_dequantize( }); } -void fp_quantize_rowwise_fallback( +void fp_quantize_rowwise( const array& w, array& wq, array& scales, @@ -153,148 +145,38 @@ void fp_quantize_rowwise_fallback( } enc.set_output_array(wq); enc.set_output_array(scales); - dispatch_float_types( - w.dtype(), "fp_quantize_rowwise_fallback", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto kernel = cu::fp_quantize_rowwise_fallback; - if (bits == 8) { - kernel = cu::fp_quantize_rowwise_fallback; - } else if (group_size == 16) { - kernel = cu::fp_quantize_rowwise_fallback; - } - bool large = w.size() > UINT_MAX; - auto [num_blocks, block_dims] = get_launch_args( - w.size(), w.shape(), w.strides(), large, group_size); - - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); -} - -void fp_quantize_rowwise_tma( - const array& w, - array& wq, - array& scales, - int group_size, - int bits, - const std::optional& global_scale /* = std::nullopt */, - cu::CommandEncoder& enc, - const Stream& s) { - enc.set_input_array(w); - enc.set_output_array(wq); - enc.set_output_array(scales); - - size_t cols = w.shape(-1); - size_t rows = w.size() / cols; - size_t stride_bytes = cols * w.itemsize(); - - if (bits == 8 && group_size == 32) { - dispatch_float_types(w.dtype(), "fp_quantize_rowwise_mxfp8", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - constexpr int THREADS_PER_BLOCK = 64; - constexpr int ROWS_PER_BLOCK = 64; - constexpr int COLS_PER_BLOCK = 64; - constexpr size_t TILE_M = ROWS_PER_BLOCK; - constexpr size_t TILE_K = 32; // group_size - constexpr size_t STAGES = COLS_PER_BLOCK / TILE_K; - - auto [grid, block, smem_size] = cu::get_tma_launch_args< - TILE_M, - TILE_K, - THREADS_PER_BLOCK, - STAGES, - TILE_M>( - rows, cols, ROWS_PER_BLOCK, COLS_PER_BLOCK, w.itemsize(), bits); - - CUtensorMap tensor_map_input; - CUtensorMap tensor_map_output; - - create_2D_tensor_map( - &tensor_map_input, - const_cast(static_cast(gpu_ptr(w))), - cu::get_tma_dtype(w.dtype()), - rows, - cols, - TILE_M, - TILE_K, - stride_bytes); - - create_2D_tensor_map( - &tensor_map_output, - gpu_ptr(wq), - CU_TENSOR_MAP_DATA_TYPE_UINT8, - rows, - cols, - TILE_M, - TILE_K, - cols); - - auto kernel = cu::fp_quantize_rowwise_tma_mxfp8< - T, - TILE_K, - false, - THREADS_PER_BLOCK, - ROWS_PER_BLOCK, - COLS_PER_BLOCK>; - enc.add_kernel_node( - kernel, - grid, - block, - static_cast(smem_size), - tensor_map_input, - tensor_map_output, - gpu_ptr(scales), - rows, - cols); - } else { - throw std::runtime_error( - "[fp_quantize_rowwise_tma] Cannot quantize input with type float64."); + dispatch_float_types(w.dtype(), "fp_quantize_rowwise", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto kernel = cu::fp_quantize_rowwise; + if (bits == 8) { + kernel = cu::fp_quantize_rowwise; + } else if (group_size == 16) { + kernel = cu::fp_quantize_rowwise; } - }); - } else { - throw std::runtime_error( - "[fp_quantize_rowwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); - } -} + bool large = w.size() > UINT_MAX; + auto [num_blocks, block_dims] = + get_launch_args(w.size(), w.shape(), w.strides(), large, group_size); -void fp_quantize_rowwise( - const array& w, - array& wq, - array& scales, - int group_size, - int bits, - const std::optional& global_scale /* = std::nullopt */, - cu::CommandEncoder& enc, - const Stream& s) { - size_t cols = w.shape(-1); - size_t rows = w.size() / cols; - const bool has_full_tma_tiles = ((rows % 128) == 0) && ((cols % 128) == 0); - if (enc.device().compute_capability_major() >= 10 && bits == 8 && - group_size == 32 && !global_scale.has_value() && has_full_tma_tiles) { - fp_quantize_rowwise_tma( - w, wq, scales, group_size, bits, global_scale, enc, s); - } else { - fp_quantize_rowwise_fallback( - w, wq, scales, group_size, bits, global_scale, enc, s); - } + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); } -void fp_quantize_columnwise_fallback( +void fp_quantize_columnwise( const array& w, array& wq, array& scales, @@ -309,41 +191,37 @@ void fp_quantize_columnwise_fallback( } enc.set_output_array(wq); enc.set_output_array(scales); - dispatch_float_types( - w.dtype(), "fp_quantize_columnwise_fallback", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto M = w.shape(-2); - auto K = w.shape(-1); - auto kernel = - cu::fp_quantize_columnwise_fallback; - if (bits == 8) { - kernel = cu::fp_quantize_columnwise_fallback; - } else if (group_size == 16) { - kernel = - cu::fp_quantize_columnwise_fallback; - } - auto [num_blocks, block_dims] = - cu::get_columnwise_fallback_launch_args( - w.size(), group_size, M, K); - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - M, - K, - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); + dispatch_float_types(w.dtype(), "fp_quantize_columnwise", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto M = w.shape(-2); + auto K = w.shape(-1); + auto kernel = cu::fp_quantize_columnwise; + if (bits == 8) { + kernel = cu::fp_quantize_columnwise; + } else if (group_size == 16) { + kernel = cu::fp_quantize_columnwise; + } + auto [num_blocks, block_dims] = + cu::get_columnwise_launch_args(w.size(), group_size, M, K); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + M, + K, + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); } void fp_quantize_columnwise_tma( @@ -455,12 +333,12 @@ void fp_quantize_columnwise( const bool has_full_tma_tiles = ((rows % 128) == 0) && ((cols % 128) == 0); bool use_tma = (enc.device().compute_capability_major() >= 10 && bits == 8 && - group_size == 32 && !global_scale.has_value() && has_full_tma_tiles); + group_size == 32 && has_full_tma_tiles); if (use_tma) { fp_quantize_columnwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); } else { - fp_quantize_columnwise_fallback( + fp_quantize_columnwise( w, wq, scales, group_size, bits, global_scale, enc, s); } } diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index bf89061b00..54b345cc4c 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -103,7 +103,7 @@ __global__ void fp_quantize_dequantize( } template -__global__ void fp_quantize_rowwise_fallback( +__global__ void fp_quantize_rowwise( T* w, uint8_t* out, uint8_t* scales, @@ -179,7 +179,7 @@ __global__ void fp_quantize_rowwise_fallback( } template -__global__ void fp_quantize_columnwise_fallback( +__global__ void fp_quantize_columnwise( T* w, uint8_t* out, uint8_t* scales, diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh index 46d0ea9306..cabeef2c07 100644 --- a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -22,204 +22,6 @@ constexpr size_t BUFFS_NUM = 2; namespace cg = cooperative_groups; -// TMA-based rowwise quantization for MXFP8 -// Input: [rows, cols] row-major -// Output: [rows, cols / 32 / group_size] row-major -// Scales: [rows, cols / group_size] - one scale per group -// Each thread processes one row of group_size elements per stage, -// Each block processes TILE_M x TILE_K elements per stage. -template < - typename T, - int group_size, - bool USE_SR, - int THREADS_PER_BLOCK, - int ROWS_PER_BLOCK, - int COLS_PER_BLOCK> -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - fp_quantize_rowwise_tma_mxfp8( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - - constexpr size_t TILE_M = ROWS_PER_BLOCK; - constexpr size_t TILE_K = group_size; - constexpr size_t STAGES = COLS_PER_BLOCK / group_size; - - const auto block_idx = cg::this_thread_block().group_index(); - const auto idx_in_block = cg::this_thread_block().thread_index(); - const int tidx = idx_in_block.x; // Thread handles row tidx - const bool is_master = (tidx == 0); - - const size_t block_offset_row = block_idx.x * ROWS_PER_BLOCK; - const size_t block_offset_col = block_idx.y * COLS_PER_BLOCK; - - constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; - constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); - constexpr size_t in_buff_size_aligned = - ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - constexpr size_t out_tile_elems = BUFF_ELEMS; - constexpr size_t out_buff_size_aligned = - ((out_tile_elems * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - extern __shared__ char shared_mem[]; - uintptr_t aligned_shared = - (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - T* in_sh = reinterpret_cast(aligned_shared); - uint8_t* out_sh = - reinterpret_cast(aligned_shared + in_buff_size_aligned); - uint8_t* scales_sh = reinterpret_cast( - aligned_shared + in_buff_size_aligned + out_buff_size_aligned); - - constexpr uint32_t tile_bytes = static_cast(in_tile_size); - - __shared__ alignas(8) uint64_t mbar[STAGES]; // 8 bytes alignment for mbarrier - - T thread_data[TILE_K]; // register storage for one row of TILE_K elements - uint32_t rbits = 0; - - const size_t groups_per_row = cols / TILE_K; - - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); - } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - - copy_2d_to_shared( - &in_sh[0], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(block_offset_row), - tile_bytes, - &mbar[0], - is_master); - -#pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_col_offset = stage * TILE_K; - - if (next_stage < STAGES) { - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_col_offset = block_offset_col + next_stage * TILE_K; - const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; - - copy_2d_to_shared( - &in_sh[next_buff_elem_offset], - &tensor_map_input, - static_cast(next_col_offset), - static_cast(block_offset_row), - tile_bytes, - &mbar[next_stage], - is_master); - } - - ptx::fence_proxy_async_shared_cta(); - ptx::mbarrier_wait_parity(&mbar[stage], 0); - - const size_t buff_offset = buff * BUFF_ELEMS; - const int lane = tidx % 32; - constexpr int GROUPS = TILE_K / 4; - // First made 2 naive mistakes were made: - // 1. Stored without swizzling that results in 32 / 16 bank conflicts - // (depending of the input type) - // 2. More naive mistake was indexing registes with swizzling index that - // is known only at compile time this results in massive slow down of - // course -#pragma unroll - for (int j = 0; j < GROUPS; ++j) { - int swizzled_j = (j + lane) % GROUPS; - *reinterpret_cast(&thread_data[j * 4]) = *reinterpret_cast( - &in_sh[buff_offset + tidx * TILE_K + swizzled_j * 4]); - } - // Compute scale: find max absolute value in the row (order-independent) - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; -#pragma unroll - for (int col = 0; col < TILE_K; col += 2) { - auto pair = Tx2{thread_data[col], thread_data[col + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - - scale /= F8E4M3_MAX; - - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); - scale = float(s); - scales_sh[buff * TILE_K + tidx] = s.__x; - const size_t out_buff_offset = buff * out_tile_elems; - // Quantize to registers first - uint32_t quantized_regs[GROUPS]; -#pragma unroll - for (int j = 0; j < GROUPS; ++j) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - quantized_regs[j] = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - } - const int group = lane / 4; // 8 groups of 4 threads -#pragma unroll - for (int j = 0; j < GROUPS; ++j) { - int rotated_j = (j + group) % GROUPS; - *reinterpret_cast( - &out_sh[out_buff_offset + tidx * TILE_K + rotated_j * 4]) = - quantized_regs[rotated_j]; - } - __syncthreads(); - const size_t global_row = block_offset_row + tidx; - const size_t global_group = (block_offset_col + stage_col_offset) / TILE_K; - if (global_row < rows) { - scales[global_row * groups_per_row + global_group] = - scales_sh[buff * TILE_K + tidx]; - } - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - - if (is_master) { - const size_t global_col = block_offset_col + stage_col_offset; - const uint32_t out_x = static_cast(global_col); - const uint32_t out_y = static_cast(block_offset_row); - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - out_x, - out_y, - reinterpret_cast(&out_sh[out_buff_offset])); - ptx::cp_async_bulk_commit_group(); - } - } - - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { - ptx::mbarrier_invalidate(&mbar[iter]); - } - } -#endif // __CUDA_ARCH__ >= 1000 -} - template < typename T, bool USE_SR, @@ -272,8 +74,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) T* in_sh = reinterpret_cast(aligned_shared); uint8_t* out_sh = reinterpret_cast(aligned_shared + in_buff_size_aligned); - uint8_t* scales_sh = reinterpret_cast( - aligned_shared + in_buff_size_aligned + out_buff_size_aligned); constexpr uint32_t tile_bytes = static_cast(in_tile_size); @@ -356,7 +156,13 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) using ScaleType = __nv_fp8_e8m0; auto s = ScaleType(scale); scale = float(s); - scales_sh[buff * TILE_K + tidx] = s.__x; + // Write scale directly to global memory + const size_t global_col = block_offset_col + tidx; + const size_t global_row_group = + (block_offset_row + stage_row_offset) / TILE_M; + if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = s.__x; + } const size_t out_buff_offset = buff * out_tile_elems; // Quantize to registers first constexpr int GROUPS = TILE_M / 4; @@ -400,17 +206,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) break; } __syncthreads(); - // Thread tidx computes scale for input column (block_offset_col + tidx) - // This scale goes to output row (block_offset_col + tidx), column - // (global_row_group) - const size_t global_row_group = - (block_offset_row + stage_row_offset) / TILE_M; - const size_t global_col = block_offset_col + tidx; - // TODO: scale writing is not good - if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { - scales[global_col * scale_stride + global_row_group] = - scales_sh[buff * TILE_K + tidx]; - } ptx::fence_proxy_async_shared_cta(); __syncthreads(); From 4a6e1b8d8c55c0cf8adc8d496b1eac8ea76c10c8 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 23 Feb 2026 01:38:42 +0100 Subject: [PATCH 09/15] fix kernel names --- mlx/backend/cuda/quantized/fp_quantize.cu | 69 +++++++++++----------- mlx/backend/cuda/quantized/fp_quantize.cuh | 2 +- 2 files changed, 37 insertions(+), 34 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 20b6244dbe..4ee246c499 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -176,7 +176,7 @@ void fp_quantize_rowwise( }); } -void fp_quantize_columnwise( +void fp_quantize_columnwise_fallback( const array& w, array& wq, array& scales, @@ -191,37 +191,40 @@ void fp_quantize_columnwise( } enc.set_output_array(wq); enc.set_output_array(scales); - dispatch_float_types(w.dtype(), "fp_quantize_columnwise", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - auto M = w.shape(-2); - auto K = w.shape(-1); - auto kernel = cu::fp_quantize_columnwise; - if (bits == 8) { - kernel = cu::fp_quantize_columnwise; - } else if (group_size == 16) { - kernel = cu::fp_quantize_columnwise; - } - auto [num_blocks, block_dims] = - cu::get_columnwise_launch_args(w.size(), group_size, M, K); - enc.add_kernel_node( - kernel, - num_blocks, - block_dims, - 0, - gpu_ptr(w), - gpu_ptr(wq), - gpu_ptr(scales), - w.size(), - M, - K, - global_scale.has_value() ? gpu_ptr(global_scale.value()) - : nullptr); - } else { - throw std::runtime_error( - "[Quantize::eval_gpu] Can not quantize input with type float64."); - } - }); + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_fallback", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + auto M = w.shape(-2); + auto K = w.shape(-1); + auto kernel = + cu::fp_quantize_columnwise_fallback; + if (bits == 8) { + kernel = cu::fp_quantize_columnwise_fallback; + } else if (group_size == 16) { + kernel = + cu::fp_quantize_columnwise_fallback; + } + auto [num_blocks, block_dims] = + cu::get_columnwise_launch_args(w.size(), group_size, M, K); + enc.add_kernel_node( + kernel, + num_blocks, + block_dims, + 0, + gpu_ptr(w), + gpu_ptr(wq), + gpu_ptr(scales), + w.size(), + M, + K, + global_scale.has_value() ? gpu_ptr(global_scale.value()) + : nullptr); + } else { + throw std::runtime_error( + "[Quantize::eval_gpu] Can not quantize input with type float64."); + } + }); } void fp_quantize_columnwise_tma( @@ -338,7 +341,7 @@ void fp_quantize_columnwise( fp_quantize_columnwise_tma( w, wq, scales, group_size, bits, global_scale, enc, s); } else { - fp_quantize_columnwise( + fp_quantize_columnwise_fallback( w, wq, scales, group_size, bits, global_scale, enc, s); } } diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index 54b345cc4c..83ff912868 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -179,7 +179,7 @@ __global__ void fp_quantize_rowwise( } template -__global__ void fp_quantize_columnwise( +__global__ void fp_quantize_columnwise_fallback( T* w, uint8_t* out, uint8_t* scales, From 1c87d0e5246633ed37c1cdcd96c2d7a815581cfc Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 23 Feb 2026 11:52:43 +0100 Subject: [PATCH 10/15] refactor --- .gitignore | 1 + examples/python/qqmm.py | 102 +++++++++++------- mlx/backend/cuda/ptx.cuh | 17 --- mlx/backend/cuda/quantized/fp_quantize.cu | 9 +- .../cuda/quantized/fp_quantize_tma.cuh | 36 +++---- 5 files changed, 85 insertions(+), 80 deletions(-) diff --git a/.gitignore b/.gitignore index 1daaa46d12..e8dbe864ab 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,4 @@ uv.lock .cache/ # vim *.swp +tmp diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index 5be7eae2f3..c68c243be2 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -33,48 +33,74 @@ def test_qqmm(): [64, 128, 256, 1024, 1024 * 8], # N [64, 128, 256, 1024, 1024 * 8], # K ) + layouts = ["TN", "NT", "TT", "NN"] for group_size, mode, bits in tests: for M, N, K in product(*shapes): for dtype in dtypes: - x = mx.random.normal(shape=(M, K), key=k1, dtype=dtype) - w = mx.random.normal(shape=(N, K), key=k2, dtype=dtype) - w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) - w_dq = mx.dequantize( - w_q, - scales_w, - group_size=group_size, - bits=bits, - mode=mode, - dtype=dtype, - ) - y_q = mx.qqmm( - x, - w_q, - scales_w, - group_size=group_size, - bits=bits, - mode=mode, - ) - x_q, scales_x = mx.quantize( - x, group_size=group_size, bits=bits, mode=mode - ) - x_dq = mx.dequantize( - x_q, - scales_x, - group_size=group_size, - bits=bits, - mode=mode, - dtype=dtype, - ) - y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) - ulp = ulp_bf16_at(y_hat) - error = (y_q - y_hat).abs() - if not (mx.logical_or(error < 1e-3, error <= ulp).all()): - raise AssertionError( - f"qqmm test failed for shape {(M, N, K)}, " - f"group_size={group_size}, bits={bits}, " - f"mode={mode}, dtype={dtype}" + for layout in layouts: + if layout == "NT": + x_shape = (M, K) + w_shape = (N, K) + elif layout == "TN": + x_shape = (K, M) + w_shape = (K, N) + elif layout == "TT": + x_shape = (K, M) + w_shape = (N, K) + else: # "NN" + x_shape = (M, K) + w_shape = (K, N) + x = mx.random.normal(shape=x_shape, key=k1, dtype=dtype) + w = mx.random.normal(shape=w_shape, key=k2, dtype=dtype) + + if layout == "TT": + x = mx.transpose(x) + elif layout == "TN": + w = mx.transpose(w) + x = mx.transpose(x) + else: # "NN" + w = mx.transpose(w) + + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + w_dq = mx.dequantize( + w_q, + scales_w, + group_size=group_size, + bits=bits, + mode=mode, + dtype=dtype, + ) + y_q = mx.qqmm( + x, + w_q, + scales_w, + group_size=group_size, + bits=bits, + mode=mode, ) + x_q, scales_x = mx.quantize( + x, group_size=group_size, bits=bits, mode=mode + ) + x_dq = mx.dequantize( + x_q, + scales_x, + group_size=group_size, + bits=bits, + mode=mode, + dtype=dtype, + ) + y_hat = mx.matmul(x_dq, mx.transpose(w_dq)) + ulp = ulp_bf16_at(y_hat) + error = (y_q - y_hat).abs() + if not (mx.logical_or(error < 1e-3, error <= ulp).all()): + import pdb + + pdb.set_trace() + raise AssertionError( + f"qqmm test failed for shape {(M, N, K)}, " + f"group_size={group_size}, bits={bits}, " + f"mode={mode}, dtype={dtype}, layout={layout}" + ) def test_qqmm_vjp(): diff --git a/mlx/backend/cuda/ptx.cuh b/mlx/backend/cuda/ptx.cuh index 831600ee6c..6ec7caedd6 100644 --- a/mlx/backend/cuda/ptx.cuh +++ b/mlx/backend/cuda/ptx.cuh @@ -6,23 +6,6 @@ namespace mlx::core { namespace ptx { -// For pipelining with TMA, we use memory barries, the mental model is the -// folowing: -// 1. master thread inits a barrier with expected number of threads to arrive, -// while the rest skip and __synchthreads() to wait for the barrier to be -// ready. -// 2. pipelining: before the loop, master thread launch an asynch copy and -// signals the barrier with the expected number of bytes to arrive, -// while the rest just signal arrival (first buffer). -// 3. So we have 2 buffers, and n pipeline stages. We launch the fist asynch -// copy before the begining of the loop. Then iterate over pipeline stages and -// copy to freed up buffer (buff = stage % BUFFS_NUM), but before the copy we -// want to check that the result was copied to global memory. at the firt loop -// nothing was commited so the check is succeseful so we launch a seconf asynch -// copy. Then we wait for the first one, do operations then lauch an asynch -// copy from shared -> global and return to the begining of the loop, now we -// check if the previous read from shared to global is done so we can reuse the -// buffer and launch async copy to fill buffer 0 #if (CUDART_VERSION >= 12080) && (__CUDA_ARCH__ >= 1000) && \ defined(__CUDA_ARCH_SPECIFIC__) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 4ee246c499..760b6754bb 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -13,7 +13,6 @@ #include #include #include -#include namespace mlx::core { namespace cu { @@ -244,9 +243,6 @@ void fp_quantize_columnwise_tma( size_t cols = w.size() / rows; size_t stride_bytes = w.strides(-1) * w.itemsize(); - // 16 bytes is minimum for TMA tile size along contigious dimension - // For nvfp4 (group_size = 16), we can't have a tile = group_size because we - // write transposed output so we will do 2 passes on the input tile if (bits == 8 && group_size == 32) { dispatch_float_types( w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { @@ -274,7 +270,7 @@ void fp_quantize_columnwise_tma( create_2D_tensor_map( &tensor_map_input, - const_cast(static_cast(gpu_ptr(w))), + gpu_ptr(w), cu::get_tma_dtype(w.dtype()), rows, cols, @@ -302,7 +298,7 @@ void fp_quantize_columnwise_tma( kernel, grid, block, - static_cast(smem_size), + smem_size, tensor_map_input, tensor_map_output, gpu_ptr(scales), @@ -314,7 +310,6 @@ void fp_quantize_columnwise_tma( } }); } else { - // NVFP4 or other configurations - TMA kernel not implemented yet throw std::runtime_error( "[fp_quantize_columnwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); } diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh index cabeef2c07..e496a6b61f 100644 --- a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -41,7 +41,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) constexpr size_t TILE_M = 32; constexpr size_t TILE_K = COLS_PER_BLOCK; - constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; + constexpr size_t STEPS = ROWS_PER_BLOCK / TILE_M; constexpr int elem_per_byte = 1; const auto block_idx = cg::this_thread_block().group_index(); @@ -77,17 +77,17 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) constexpr uint32_t tile_bytes = static_cast(in_tile_size); - __shared__ alignas(8) uint64_t mbar[STAGES]; + __shared__ alignas(8) uint64_t mbar[STEPS]; T thread_data[TILE_M]; uint32_t rbits = 0; // Reserved for stochastic rounding const size_t scale_stride = rows / TILE_M; - // Master thread init memory barriers for all stages + // Master thread init memory barriers for all steps // fence for tma, synchronize threads so all see mbarrier if (is_master) { #pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { + for (int iter = 0; iter < STEPS; ++iter) { ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); } ptx::fence_proxy_async_shared_cta(); @@ -104,20 +104,20 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) is_master); #pragma unroll - for (size_t stage = 0; stage < STAGES; ++stage) { + for (size_t step = 0; step < STEPS; ++step) { // buffer memory offset in shared memory (we use double buffering for // pipelining) - const size_t buff = stage % BUFFS_NUM; - const size_t next_stage = stage + 1; - const size_t stage_row_offset = stage * TILE_M; + const size_t buff = step % BUFFS_NUM; + const size_t next_step = step + 1; + const size_t step_row_offset = step * TILE_M; - if (next_stage < STAGES) { + if (next_step < STEPS) { // before launching another async copy, check that there is less than 2 // (to ensure that shared -> global synch is finished and buffer can be // reused) ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_stage % BUFFS_NUM; - const size_t next_row_offset = block_offset_row + next_stage * TILE_M; + const size_t next_buff = next_step % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_step * TILE_M; const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; copy_2d_to_shared( @@ -126,14 +126,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) static_cast(block_offset_col), static_cast(next_row_offset), tile_bytes, - &mbar[next_stage], + &mbar[next_step], is_master); } ptx::fence_proxy_async_shared_cta(); // Wait until the data is ready, parity is always 0 because for simplicity - // we dont reuse barriers between stages - ptx::mbarrier_wait_parity(&mbar[stage], 0); + // we dont reuse barriers between steps + ptx::mbarrier_wait_parity(&mbar[step], 0); const size_t buff_offset = buff * BUFF_ELEMS; // Read the data from shared to registers #pragma unroll @@ -159,8 +159,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) // Write scale directly to global memory const size_t global_col = block_offset_col + tidx; const size_t global_row_group = - (block_offset_row + stage_row_offset) / TILE_M; - if (global_col < cols && (block_offset_row + stage_row_offset) < rows) { + (block_offset_row + step_row_offset) / TILE_M; + if (global_col < cols && (block_offset_row + step_row_offset) < rows) { scales[global_col * scale_stride + global_row_group] = s.__x; } const size_t out_buff_offset = buff * out_tile_elems; @@ -210,7 +210,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) __syncthreads(); if (is_master) { - const size_t global_row = block_offset_row + stage_row_offset; + const size_t global_row = block_offset_row + step_row_offset; const uint32_t out_x = static_cast(global_row); const uint32_t out_y = static_cast(block_offset_col); @@ -228,7 +228,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) __syncthreads(); if (is_master) { #pragma unroll - for (int iter = 0; iter < STAGES; ++iter) { + for (int iter = 0; iter < STEPS; ++iter) { ptx::mbarrier_invalidate(&mbar[iter]); } } From 08d49afe996d7bcf0fc1ce7a3dc0ee3e05541002 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 23 Feb 2026 12:24:19 +0100 Subject: [PATCH 11/15] move common.h in quantized_utils.cuh --- mlx/backend/cuda/common.h | 69 ------------------- mlx/backend/cuda/quantized/fp_quantize.cu | 5 +- .../cuda/quantized/fp_quantize_tma.cuh | 23 ++++--- .../cuda/quantized/quantized_utils.cuh | 56 +++++++++++++++ 4 files changed, 73 insertions(+), 80 deletions(-) delete mode 100644 mlx/backend/cuda/common.h diff --git a/mlx/backend/cuda/common.h b/mlx/backend/cuda/common.h deleted file mode 100644 index e6d6d9c756..0000000000 --- a/mlx/backend/cuda/common.h +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright © 2025 Apple Inc. - -#pragma once - -#include -#include - -#include "mlx/backend/cuda/cuda_utils.h" -#include "mlx/backend/cuda/ptx.cuh" - -namespace mlx::core { - -__forceinline__ __device__ void copy_2d_to_shared( - void* dst, - const CUtensorMap* tensor_map, - uint32_t tile_x, - uint32_t tile_y, - uint32_t num_bytes, - uint64_t* barrier, - const bool is_master_thread) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (is_master_thread) { - // Arrive and tell how many bytes are expected - ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); - // Initiate bulk tensor copy - ptx::cp_async_bulk_tensor_2d_global_to_shared( - dst, tensor_map, tile_x, tile_y, barrier); - } else { - // Other threads just arrive - ptx::mbarrier_arrive(barrier); - } -#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) -} - -// Host function to create a 2D TMA tensor map descriptor -// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html -inline void create_2D_tensor_map( - CUtensorMap* tensorMap, - void* input_ptr, - CUtensorMapDataType dtype, - uint64_t rows, - uint64_t cols, - uint32_t tile_y, - uint32_t tile_x, - uint64_t stride_bytes, - CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE) { - constexpr uint32_t rank = 2; // 2D - uint64_t global_dim[rank] = {cols, rows}; - // For row-major layout - uint64_t strides[rank - 1] = {stride_bytes}; - uint32_t tile_dim[rank] = {tile_x, tile_y}; - uint32_t elem_stride[rank] = {1, 1}; - - CHECK_CUDA_ERROR(cuTensorMapEncodeTiled( - tensorMap, - dtype, - rank, - input_ptr, - global_dim, - strides, - tile_dim, - elem_stride, - CU_TENSOR_MAP_INTERLEAVE_NONE, - swizzle, - CU_TENSOR_MAP_L2_PROMOTION_NONE, - CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); -} - -} // namespace mlx::core diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 760b6754bb..301e0660a2 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -70,7 +70,7 @@ inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { } inline std::tuple -get_columnwise_launch_args(size_t size, int group_size, int M, int K) { +get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; @@ -205,7 +205,8 @@ void fp_quantize_columnwise_fallback( cu::fp_quantize_columnwise_fallback; } auto [num_blocks, block_dims] = - cu::get_columnwise_launch_args(w.size(), group_size, M, K); + cu::get_columnwise_quantize_launch_args( + w.size(), group_size, M, K); enc.add_kernel_node( kernel, num_blocks, diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh index e496a6b61f..459bbc4a82 100644 --- a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh @@ -1,7 +1,6 @@ // Copyright © 2026 Apple Inc. #pragma once -#include "mlx/backend/cuda/common.h" #include "mlx/backend/cuda/ptx.cuh" #include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" #include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" @@ -235,14 +234,20 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #endif // __CUDA_ARCH__ >= 1000 } -template -__global__ void __launch_bounds__(128) fp_quantize_columnwise_tma_nvfp4( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols, - float* global_scale) { +template < + typename T, + bool USE_SR, + int THREADS_PER_BLOCK, + int COLS_PER_BLOCK, + int ROWS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_columnwise_tma_nvfp4( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols, + float* global_scale) { #if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) // placeholder TODO - NVFP4 TMA kernel not yet implemented #endif // __CUDA_ARCH__ >= 1000 diff --git a/mlx/backend/cuda/quantized/quantized_utils.cuh b/mlx/backend/cuda/quantized/quantized_utils.cuh index be93362fe6..92bc315d4a 100644 --- a/mlx/backend/cuda/quantized/quantized_utils.cuh +++ b/mlx/backend/cuda/quantized/quantized_utils.cuh @@ -3,6 +3,7 @@ #include #include +#include "mlx/backend/cuda/ptx.cuh" namespace mlx::core { @@ -48,6 +49,28 @@ __device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) { } } +__device__ __forceinline__ void copy_2d_to_shared( + void* dst, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint32_t num_bytes, + uint64_t* barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Arrive and tell how many bytes are expected + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, tensor_map, tile_x, tile_y, barrier); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + } // namespace cu template @@ -89,4 +112,37 @@ void dispatch_bits(int bits, F&& f) { } } +// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TENSOR__MEMORY.html +inline void create_2D_tensor_map( + CUtensorMap* tensorMap, + void* input_ptr, + CUtensorMapDataType dtype, + uint64_t rows, + uint64_t cols, + uint32_t tile_y, + uint32_t tile_x, + uint64_t stride_bytes, + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE) { + constexpr uint32_t rank = 2; // 2D + uint64_t global_dim[rank] = {cols, rows}; + // For row-major layout + uint64_t strides[rank - 1] = {stride_bytes}; + uint32_t tile_dim[rank] = {tile_x, tile_y}; + uint32_t elem_stride[rank] = {1, 1}; + + CHECK_CUDA_ERROR(cuTensorMapEncodeTiled( + tensorMap, + dtype, + rank, + input_ptr, + global_dim, + strides, + tile_dim, + elem_stride, + CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + } // namespace mlx::core From 56f25b38c09091406274411d6715d5a9800ed8b7 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Mon, 23 Feb 2026 12:56:43 +0100 Subject: [PATCH 12/15] fixed example --- examples/python/qqmm.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index c68c243be2..d1f68b444d 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -50,6 +50,7 @@ def test_qqmm(): else: # "NN" x_shape = (M, K) w_shape = (K, N) + x = mx.random.normal(shape=x_shape, key=k1, dtype=dtype) w = mx.random.normal(shape=w_shape, key=k2, dtype=dtype) @@ -58,25 +59,24 @@ def test_qqmm(): elif layout == "TN": w = mx.transpose(w) x = mx.transpose(x) - else: # "NN" + elif layout == "NN": w = mx.transpose(w) - w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) - w_dq = mx.dequantize( - w_q, - scales_w, + y_q = mx.qqmm( + x, + w, group_size=group_size, bits=bits, mode=mode, - dtype=dtype, ) - y_q = mx.qqmm( - x, + w_q, scales_w = mx.quantize(w, group_size, bits, mode=mode) + w_dq = mx.dequantize( w_q, scales_w, group_size=group_size, bits=bits, mode=mode, + dtype=dtype, ) x_q, scales_x = mx.quantize( x, group_size=group_size, bits=bits, mode=mode @@ -95,6 +95,8 @@ def test_qqmm(): if not (mx.logical_or(error < 1e-3, error <= ulp).all()): import pdb + pdb.set_trace() + pdb.set_trace() raise AssertionError( f"qqmm test failed for shape {(M, N, K)}, " From 3b10cb4bda0bc2bcb06fe869415ac5d64a4ff9f3 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 24 Feb 2026 12:50:35 +0100 Subject: [PATCH 13/15] fix qqmm example --- .gitignore | 1 - examples/python/qqmm.py | 5 ----- 2 files changed, 6 deletions(-) diff --git a/.gitignore b/.gitignore index e8dbe864ab..1daaa46d12 100644 --- a/.gitignore +++ b/.gitignore @@ -79,4 +79,3 @@ uv.lock .cache/ # vim *.swp -tmp diff --git a/examples/python/qqmm.py b/examples/python/qqmm.py index d1f68b444d..dc46fa29a0 100644 --- a/examples/python/qqmm.py +++ b/examples/python/qqmm.py @@ -93,11 +93,6 @@ def test_qqmm(): ulp = ulp_bf16_at(y_hat) error = (y_q - y_hat).abs() if not (mx.logical_or(error < 1e-3, error <= ulp).all()): - import pdb - - pdb.set_trace() - - pdb.set_trace() raise AssertionError( f"qqmm test failed for shape {(M, N, K)}, " f"group_size={group_size}, bits={bits}, " From 8376a1e614878d35d7611e16caae6f2ba05a4ec3 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 17 Mar 2026 15:53:58 +0100 Subject: [PATCH 14/15] merge main, fix conflicts --- benchmarks/python/concat_bench.py | 25 -- mlx/backend/cuda/quantized/fp_quantize.cu | 195 ++++++----- mlx/backend/cuda/quantized/fp_quantize.cuh | 319 ++++++++++++++++-- .../cuda/quantized/fp_quantize_tma.cuh | 257 -------------- 4 files changed, 405 insertions(+), 391 deletions(-) delete mode 100644 benchmarks/python/concat_bench.py delete mode 100644 mlx/backend/cuda/quantized/fp_quantize_tma.cuh diff --git a/benchmarks/python/concat_bench.py b/benchmarks/python/concat_bench.py deleted file mode 100644 index f3486acf0b..0000000000 --- a/benchmarks/python/concat_bench.py +++ /dev/null @@ -1,25 +0,0 @@ -# Copyright © 2025 Apple Inc. -import mlx.core as mx -from time_utils import measure_runtime - - -def benchmark_concat(arrays, axis): - def concat(arrays): - mx.eval(mx.concatenate(arrays, axis=axis)) - - runtime = measure_runtime(concat, arrays=arrays) - return runtime - - -if __name__ == "__main__": - - for n_arrays in [32]: - for shape in [(4096, 4096)]: - arrays = [mx.random.normal(shape) for _ in range(n_arrays)] - mx.eval(arrays) - runtime = benchmark_concat(arrays, axis=0) - total_elems = n_arrays * shape[0] * shape[1] - print( - f" {n_arrays}x {shape} axis=0 " - f"({total_elems / 1e6:.1f}M elems): {runtime:.3f}ms" - ) \ No newline at end of file diff --git a/mlx/backend/cuda/quantized/fp_quantize.cu b/mlx/backend/cuda/quantized/fp_quantize.cu index 3506974c58..66ded268ab 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cu +++ b/mlx/backend/cuda/quantized/fp_quantize.cu @@ -4,7 +4,6 @@ #include "mlx/backend/cuda/device.h" #include "mlx/backend/cuda/kernel_utils.cuh" #include "mlx/backend/cuda/quantized/fp_quantize.cuh" -#include "mlx/backend/cuda/quantized/fp_quantize_tma.cuh" #include "mlx/backend/cuda/quantized/quantized.h" #include "mlx/backend/cuda/vector_types.cuh" #include "mlx/dtype_utils.h" @@ -14,19 +13,48 @@ #include #include -constexpr float F8E4M3_MAX = 448.0f; -constexpr float F4E2M1_MAX = 6.0f; - namespace mlx::core { namespace cu { +inline void create_2D_tensor_map( + CUtensorMap* tensorMap, + void* input_ptr, + CUtensorMapDataType dtype, + uint64_t rows, + uint64_t cols, + uint32_t tile_y, + uint32_t tile_x, + uint64_t stride_bytes, + CUtensorMapSwizzle swizzle = CU_TENSOR_MAP_SWIZZLE_NONE) { + constexpr uint32_t rank = 2; // 2D + uint64_t global_dim[rank] = {cols, rows}; + // For row-major layout + uint64_t strides[rank - 1] = {stride_bytes}; + uint32_t tile_dim[rank] = {tile_x, tile_y}; + uint32_t elem_stride[rank] = {1, 1}; + + CHECK_CUDA_ERROR(cuTensorMapEncodeTiled( + tensorMap, + dtype, + rank, + input_ptr, + global_dim, + strides, + tile_dim, + elem_stride, + CU_TENSOR_MAP_INTERLEAVE_NONE, + swizzle, + CU_TENSOR_MAP_L2_PROMOTION_NONE, + CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE)); +} + template < int TILE_M, int TILE_K, int THREADS_PER_BLOCK, int STAGES, int SCALES_PER_STAGE> -inline std::tuple get_tma_launch_args( +inline std::tuple get_columnwise_quantize_mxfp8_launch_args( size_t grid_dim_x_size, // rows size_t grid_dim_y_size, // cols size_t block_size_x, // ROWS_PER_BLOCK @@ -72,8 +100,11 @@ inline CUtensorMapDataType get_tma_dtype(Dtype dtype) { } } -inline std::tuple -get_columnwise_quantize_launch_args(size_t size, int group_size, int M, int K) { +inline std::tuple get_columnwise_quantize_fallback_launch_args( + size_t size, + int group_size, + int M, + int K) { constexpr int BLOCK_X = 16; constexpr int BLOCK_Y = 32; int rows_per_block = BLOCK_X; @@ -163,7 +194,6 @@ void fp_quantize_rowwise( kernel, num_blocks, block_dims, - 0, gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), @@ -207,13 +237,12 @@ void fp_quantize_columnwise_fallback( cu::fp_quantize_columnwise_fallback; } auto [num_blocks, block_dims] = - cu::get_columnwise_quantize_launch_args( + cu::get_columnwise_quantize_fallback_launch_args( w.size(), group_size, M, K); enc.add_kernel_node( kernel, num_blocks, block_dims, - 0, gpu_ptr(w), gpu_ptr(wq), gpu_ptr(scales), @@ -229,7 +258,7 @@ void fp_quantize_columnwise_fallback( }); } -void fp_quantize_columnwise_tma( +void fp_quantize_columnwise_mxfp8( const array& w, array& wq, array& scales, @@ -246,76 +275,78 @@ void fp_quantize_columnwise_tma( size_t cols = w.size() / rows; size_t stride_bytes = w.strides(-1) * w.itemsize(); - if (bits == 8 && group_size == 32) { - dispatch_float_types( - w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { - using T = cuda_type_t; - if constexpr (!std::is_same_v) { - constexpr int THREADS_PER_BLOCK = 64; - constexpr int ROWS_PER_BLOCK = 64; - constexpr int COLS_PER_BLOCK = 64; - constexpr size_t TILE_M = 32; - constexpr size_t TILE_K = COLS_PER_BLOCK; - constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; - - // For columnwise: grid.x = cols, grid.y = rows - // scales_per_stage = TILE_K (one scale per column per stage) - auto [grid, block, smem_size] = cu::get_tma_launch_args< - TILE_M, - TILE_K, - THREADS_PER_BLOCK, - STAGES, - TILE_K>( - cols, rows, COLS_PER_BLOCK, ROWS_PER_BLOCK, w.itemsize(), bits); - - CUtensorMap tensor_map_input; - CUtensorMap tensor_map_output; - - create_2D_tensor_map( - &tensor_map_input, - gpu_ptr(w), - cu::get_tma_dtype(w.dtype()), - rows, - cols, - TILE_M, - TILE_K, - stride_bytes); - - create_2D_tensor_map( - &tensor_map_output, - gpu_ptr(wq), - CU_TENSOR_MAP_DATA_TYPE_UINT8, - cols, - rows, - TILE_K, - TILE_M, - rows); - - auto kernel = cu::fp_quantize_columnwise_tma_mxfp8< - T, - false, - THREADS_PER_BLOCK, - COLS_PER_BLOCK, - ROWS_PER_BLOCK>; - enc.add_kernel_node( - kernel, - grid, - block, - smem_size, - tensor_map_input, - tensor_map_output, - gpu_ptr(scales), - rows, - cols); - } else { - throw std::runtime_error( - "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); - } - }); - } else { - throw std::runtime_error( - "[fp_quantize_columnwise_tma] TMA quantization only implemented for bits=8 and group_size=32."); - } + dispatch_float_types( + w.dtype(), "fp_quantize_columnwise_mxfp8", [&](auto type_tag) { + using T = cuda_type_t; + if constexpr (!std::is_same_v) { + constexpr int THREADS_PER_BLOCK = 64; + constexpr int ROWS_PER_BLOCK = 64; + constexpr int COLS_PER_BLOCK = 64; + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STAGES = ROWS_PER_BLOCK / TILE_M; + + // For columnwise: grid.x = cols, grid.y = rows + // scales_per_stage = TILE_K (one scale per column per stage) + auto [grid, block, smem_size] = + cu::get_columnwise_quantize_mxfp8_launch_args< + TILE_M, + TILE_K, + THREADS_PER_BLOCK, + STAGES, + TILE_K>( + cols, + rows, + COLS_PER_BLOCK, + ROWS_PER_BLOCK, + w.itemsize(), + bits); + + CUtensorMap tensor_map_input; + CUtensorMap tensor_map_output; + + cu::create_2D_tensor_map( + &tensor_map_input, + gpu_ptr(w), + cu::get_tma_dtype(w.dtype()), + rows, + cols, + TILE_M, + TILE_K, + stride_bytes); + + cu::create_2D_tensor_map( + &tensor_map_output, + gpu_ptr(wq), + CU_TENSOR_MAP_DATA_TYPE_UINT8, + cols, + rows, + TILE_K, + TILE_M, + rows); + + auto kernel = cu::fp_quantize_columnwise_mxfp8< + T, + false, + THREADS_PER_BLOCK, + COLS_PER_BLOCK, + ROWS_PER_BLOCK>; + enc.add_kernel_node_ex( + kernel, + grid, + block, + {}, + smem_size, + tensor_map_input, + tensor_map_output, + gpu_ptr(scales), + rows, + cols); + } else { + throw std::runtime_error( + "[fp_quantize_columnwise_tma] Cannot quantize input with type float64."); + } + }); } void fp_quantize_columnwise( @@ -336,7 +367,7 @@ void fp_quantize_columnwise( (enc.device().compute_capability_major() >= 10 && bits == 8 && group_size == 32 && has_full_tma_tiles); if (use_tma) { - fp_quantize_columnwise_tma( + fp_quantize_columnwise_mxfp8( w, wq, scales, group_size, bits, global_scale, enc, s); } else { fp_quantize_columnwise_fallback( diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index 83ff912868..a09234adb9 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -1,16 +1,21 @@ // Copyright © 2025 Apple Inc. #pragma once +#include "mlx/backend/cuda/ptx.cuh" #include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" #include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" -#include "mlx/backend/cuda/quantized/quantized_utils.cuh" #include "mlx/backend/cuda/vector_types.cuh" #include "mlx/dtype_utils.h" #include #include -#include -#include +#include +#include + +constexpr float F8E4M3_MAX = 448.0f; +constexpr float F4E2M1_MAX = 6.0f; +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; +constexpr size_t BUFFS_NUM = 2; namespace mlx::core { namespace cu { @@ -19,13 +24,51 @@ template struct Dequantize { __device__ float operator()(uint8_t x) { if constexpr (bits == 8) { - return float(*(__nv_fp8_e4m3*)(&x)); + return float(*(cutlass::float_e4m3_t*)(&x)); } else { - return float(*(__nv_fp4_e2m1*)(&x)); + return float(*(cutlass::float_e2m1_t*)(&x)); } } }; +template +__device__ __forceinline__ void absmax_x2(T& out, const T& x1, const T& x2) { + if constexpr ( + (std::is_same::value) || + (std::is_same::value)) { + T a = x1; + T b = x2; + out = __hmax2(__habs2(a), __habs2(b)); + } else if constexpr (std::is_same::value) { + float2 a = x1; + float2 b = x2; + out.x = fmaxf(fabsf(a.x), fabsf(b.x)); + out.y = fmaxf(fabsf(a.y), fabsf(b.y)); + } +} + +__device__ __forceinline__ void copy_2d_to_shared( + void* dst, + const CUtensorMap* tensor_map, + uint32_t tile_x, + uint32_t tile_y, + uint32_t num_bytes, + uint64_t* barrier, + const bool is_master_thread) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (is_master_thread) { + // Arrive and tell how many bytes are expected + ptx::mbarrier_arrive_expect_tx(barrier, num_bytes); + // Initiate bulk tensor copy + ptx::cp_async_bulk_tensor_2d_global_to_shared( + dst, tensor_map, tile_x, tile_y, barrier); + } else { + // Other threads just arrive + ptx::mbarrier_arrive(barrier); + } +#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +} + namespace cg = cooperative_groups; template @@ -40,7 +83,6 @@ __global__ void fp_quantize_dequantize( const float inv_scale_enc = use_global_scale ? 1.0f / scale_enc : 1.0f; using Tx2 = Vector2_t; - using Tx4 = Vector4_t; uint32_t rbits = 0; // reserved bits for future use auto block_size = cg::this_thread_block().dim_threads(); auto block_idx = cg::this_thread_block().group_index(); @@ -74,30 +116,35 @@ __global__ void fp_quantize_dequantize( scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; + using ScaleType = std::conditional_t< + use_mx_scale, + cutlass::float_ue8m0_t, + cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); float scale_enc_b = scale_enc / float(s); float scale_dec = float(s) * inv_scale_enc; AlignedVector w_hat; #pragma unroll - for (int i = 0; i < group_size / 4; i++) { - Tx4 w_Tx4 = *reinterpret_cast(&w_tile[i * 4]); - float4 dq; + for (int i = 0; i < group_size / 8; i++) { + auto& w = *reinterpret_cast*>(&w_tile[i * 8]); + cutlass::NumericArrayConverter fp32_t; + auto scaled = fp32_t(w) * scale_enc_b; + cutlass::Array dq; if constexpr (bits == 8) { - uint32_t quantized_val = - scale_cvt_Tx4_to_fp8x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp8(quantized_val); + cutlass::NumericArrayConverter fp8_fp32; + auto quant = fp8_fp32(scaled); + cutlass::NumericArrayConverter fp32_fp8; + dq = fp32_fp8(quant); } else { - uint16_t quantized_val = - scale_cvt_Tx4_to_fp4x4(w_Tx4, scale_enc_b, rbits); - dq = dequant_fp4(quantized_val); + cutlass::NumericArrayConverter fp4_fp32; + auto quant = fp4_fp32(scaled); + cutlass::NumericArrayConverter fp32_fp4; + dq = fp32_fp4(quant); } - w_hat[i * 4] = static_cast(dq.x * scale_dec); - w_hat[i * 4 + 1] = static_cast(dq.y * scale_dec); - w_hat[i * 4 + 2] = static_cast(dq.z * scale_dec); - w_hat[i * 4 + 3] = static_cast(dq.w * scale_dec); + cutlass::NumericArrayConverter t_fp32; + *reinterpret_cast*>(&w_hat[i * 8]) = + t_fp32(dq * scale_dec); } store_vector(out, thread_idx, w_hat); } @@ -152,10 +199,12 @@ __global__ void fp_quantize_rowwise( scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; + using ScaleType = std::conditional_t< + use_mx_scale, + cutlass::float_ue8m0_t, + cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); - uint8_t q_scale = s.__x; + uint8_t q_scale = s.storage; float scale_enc_b = scale_enc / float(s); scales[thread_idx] = q_scale; @@ -178,6 +227,220 @@ __global__ void fp_quantize_rowwise( store_vector(out, thread_idx, quantized); } +template < + typename T, + bool USE_SR, + int THREADS_PER_BLOCK, + int COLS_PER_BLOCK, + int ROWS_PER_BLOCK> +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + fp_quantize_columnwise_mxfp8( + const __grid_constant__ CUtensorMap tensor_map_input, + const __grid_constant__ CUtensorMap tensor_map_output, + uint8_t* __restrict__ scales, + const size_t rows, + const size_t cols) { +#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) + using Tx2 = Vector2_t; + using Tx4 = Vector4_t; + + constexpr size_t TILE_M = 32; + constexpr size_t TILE_K = COLS_PER_BLOCK; + constexpr size_t STEPS = ROWS_PER_BLOCK / TILE_M; + constexpr int elem_per_byte = 1; + + const auto block_idx = cg::this_thread_block().group_index(); + const auto idx_in_block = cg::this_thread_block().thread_index(); + const int tidx = idx_in_block.x; // Thread handles column tidx + const bool is_master = (tidx == 0); + + const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; + const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; + + constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; + constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); + constexpr size_t in_buff_size_aligned = + ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; + constexpr size_t out_tile_size = out_tile_elems; + constexpr size_t out_buff_size_aligned = + ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / + TMA_SHMEM_ALIGNMENT) * + TMA_SHMEM_ALIGNMENT; + + extern __shared__ char shared_mem[]; + uintptr_t aligned_shared = + (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + T* in_sh = reinterpret_cast(aligned_shared); + uint8_t* out_sh = + reinterpret_cast(aligned_shared + in_buff_size_aligned); + + constexpr uint32_t tile_bytes = static_cast(in_tile_size); + + __shared__ alignas(8) uint64_t mbar[STEPS]; + + T thread_data[TILE_M]; + uint32_t rbits = 0; // Reserved for stochastic rounding + const size_t scale_stride = rows / TILE_M; + + // Master thread init memory barriers for all steps + // fence for tma, synchronize threads so all see mbarrier + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STEPS; ++iter) { + ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); + } + ptx::fence_proxy_async_shared_cta(); + } + __syncthreads(); + // Launch first async copy before entering the loop + copy_2d_to_shared( + &in_sh[0], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(block_offset_row), + tile_bytes, + &mbar[0], + is_master); + +#pragma unroll + for (size_t step = 0; step < STEPS; ++step) { + // buffer memory offset in shared memory (we use double buffering for + // pipelining) + const size_t buff = step % BUFFS_NUM; + const size_t next_step = step + 1; + const size_t step_row_offset = step * TILE_M; + + if (next_step < STEPS) { + // before launching another async copy, check that there is less than 2 + // (to ensure that shared -> global synch is finished and buffer can be + // reused) + ptx::cp_async_bulk_wait_group_read<1>(); + const size_t next_buff = next_step % BUFFS_NUM; + const size_t next_row_offset = block_offset_row + next_step * TILE_M; + const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; + + copy_2d_to_shared( + &in_sh[next_buff_elem_offset], + &tensor_map_input, + static_cast(block_offset_col), + static_cast(next_row_offset), + tile_bytes, + &mbar[next_step], + is_master); + } + + ptx::fence_proxy_async_shared_cta(); + // Wait until the data is ready, parity is always 0 because for simplicity + // we dont reuse barriers between steps + ptx::mbarrier_wait_parity(&mbar[step], 0); + const size_t buff_offset = buff * BUFF_ELEMS; + // Read the data from shared to registers +#pragma unroll + for (int row = 0; row < TILE_M; ++row) { + thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; + } + Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; +#pragma unroll + for (int row = 0; row < TILE_M; row += 2) { + auto pair = Tx2{thread_data[row], thread_data[row + 1]}; + absmax_x2(amax_2x, amax_2x, pair); + } + + float scale = + max(fabsf(static_cast(amax_2x.x)), + fabsf(static_cast(amax_2x.y))); + + scale /= F8E4M3_MAX; + + using ScaleType = __nv_fp8_e8m0; + auto s = ScaleType(scale); + scale = float(s); + // Write scale directly to global memory + const size_t global_col = block_offset_col + tidx; + const size_t global_row_group = + (block_offset_row + step_row_offset) / TILE_M; + if (global_col < cols && (block_offset_row + step_row_offset) < rows) { + scales[global_col * scale_stride + global_row_group] = s.__x; + } + const size_t out_buff_offset = buff * out_tile_elems; + // Quantize to registers first + constexpr int GROUPS = TILE_M / 4; + uint32_t quantized_regs[GROUPS]; +#pragma unroll + for (int j = 0; j < GROUPS; ++j) { + Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); + quantized_regs[j] = + cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); + } + // Write output to shared memory with swapped store order to reduce bank + // conflicts. Without swap: stride between threads is TILE_M=32 bytes + // = 8 banks, every 4th thread hits the same bank -> 8-way conflict. + const int lane = tidx % 32; + const int group = (lane / 4) % 2; + const size_t base = out_buff_offset + tidx * TILE_M; + switch (group) { + case 0: + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + break; + case 1: + *reinterpret_cast(&out_sh[base + 16]) = { + quantized_regs[4], + quantized_regs[5], + quantized_regs[6], + quantized_regs[7]}; + *reinterpret_cast(&out_sh[base + 0]) = { + quantized_regs[0], + quantized_regs[1], + quantized_regs[2], + quantized_regs[3]}; + break; + } + __syncthreads(); + ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + + if (is_master) { + const size_t global_row = block_offset_row + step_row_offset; + const uint32_t out_x = static_cast(global_row); + const uint32_t out_y = static_cast(block_offset_col); + + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output), + out_x, + out_y, + reinterpret_cast(&out_sh[out_buff_offset])); + ptx::cp_async_bulk_commit_group(); + } + } + // Wait for all TMA stores to complete + ptx::cp_async_bulk_wait_group_read<0>(); + + __syncthreads(); + if (is_master) { +#pragma unroll + for (int iter = 0; iter < STEPS; ++iter) { + ptx::mbarrier_invalidate(&mbar[iter]); + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +// TODO: add kernel with tma instructions template __global__ void fp_quantize_columnwise_fallback( T* w, @@ -251,11 +514,13 @@ __global__ void fp_quantize_columnwise_fallback( scale_dec_b /= bits == 4 ? F4E2M1_MAX : F8E4M3_MAX; scale_dec_b *= scale_enc; // Convert to mx scale or nv scale - using ScaleType = - std::conditional_t; + using ScaleType = std::conditional_t< + use_mx_scale, + cutlass::float_ue8m0_t, + cutlass::float_e4m3_t>; auto s = ScaleType(scale_dec_b); float scale_enc_b = scale_enc / float(s); - scales_smem[tidx][tidy] = s.__x; + scales_smem[tidx][tidy] = s.storage; int shared_idx = tidx * padded_local_cols + tidy * bytes_per_group; diff --git a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh b/mlx/backend/cuda/quantized/fp_quantize_tma.cuh deleted file mode 100644 index 459bbc4a82..0000000000 --- a/mlx/backend/cuda/quantized/fp_quantize_tma.cuh +++ /dev/null @@ -1,257 +0,0 @@ -// Copyright © 2026 Apple Inc. -#pragma once - -#include "mlx/backend/cuda/ptx.cuh" -#include "mlx/backend/cuda/quantized/mxfp8_quantize.cuh" -#include "mlx/backend/cuda/quantized/nvfp4_quantize.cuh" -#include "mlx/backend/cuda/quantized/quantized_utils.cuh" -#include "mlx/backend/cuda/vector_types.cuh" -#include "mlx/dtype_utils.h" - -#include -#include -#include -#include - -namespace mlx::core { -namespace cu { - -constexpr size_t TMA_SHMEM_ALIGNMENT = 128; -constexpr size_t BUFFS_NUM = 2; - -namespace cg = cooperative_groups; - -template < - typename T, - bool USE_SR, - int THREADS_PER_BLOCK, - int COLS_PER_BLOCK, - int ROWS_PER_BLOCK> -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - fp_quantize_columnwise_tma_mxfp8( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - using Tx2 = Vector2_t; - using Tx4 = Vector4_t; - - constexpr size_t TILE_M = 32; - constexpr size_t TILE_K = COLS_PER_BLOCK; - constexpr size_t STEPS = ROWS_PER_BLOCK / TILE_M; - constexpr int elem_per_byte = 1; - - const auto block_idx = cg::this_thread_block().group_index(); - const auto idx_in_block = cg::this_thread_block().thread_index(); - const int tidx = idx_in_block.x; // Thread handles column tidx - const bool is_master = (tidx == 0); - - const size_t block_offset_col = block_idx.x * COLS_PER_BLOCK; - const size_t block_offset_row = block_idx.y * ROWS_PER_BLOCK; - - constexpr size_t BUFF_ELEMS = TILE_M * TILE_K; - constexpr size_t in_tile_size = BUFF_ELEMS * sizeof(T); - constexpr size_t in_buff_size_aligned = - ((in_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - constexpr size_t out_tile_elems = BUFF_ELEMS / elem_per_byte; - constexpr size_t out_tile_size = out_tile_elems; - constexpr size_t out_buff_size_aligned = - ((out_tile_size * BUFFS_NUM + TMA_SHMEM_ALIGNMENT - 1) / - TMA_SHMEM_ALIGNMENT) * - TMA_SHMEM_ALIGNMENT; - - extern __shared__ char shared_mem[]; - uintptr_t aligned_shared = - (reinterpret_cast(shared_mem) + TMA_SHMEM_ALIGNMENT - 1) & - ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - - T* in_sh = reinterpret_cast(aligned_shared); - uint8_t* out_sh = - reinterpret_cast(aligned_shared + in_buff_size_aligned); - - constexpr uint32_t tile_bytes = static_cast(in_tile_size); - - __shared__ alignas(8) uint64_t mbar[STEPS]; - - T thread_data[TILE_M]; - uint32_t rbits = 0; // Reserved for stochastic rounding - const size_t scale_stride = rows / TILE_M; - - // Master thread init memory barriers for all steps - // fence for tma, synchronize threads so all see mbarrier - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STEPS; ++iter) { - ptx::mbarrier_init(&mbar[iter], THREADS_PER_BLOCK); - } - ptx::fence_proxy_async_shared_cta(); - } - __syncthreads(); - // Launch first async copy before entering the loop - copy_2d_to_shared( - &in_sh[0], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(block_offset_row), - tile_bytes, - &mbar[0], - is_master); - -#pragma unroll - for (size_t step = 0; step < STEPS; ++step) { - // buffer memory offset in shared memory (we use double buffering for - // pipelining) - const size_t buff = step % BUFFS_NUM; - const size_t next_step = step + 1; - const size_t step_row_offset = step * TILE_M; - - if (next_step < STEPS) { - // before launching another async copy, check that there is less than 2 - // (to ensure that shared -> global synch is finished and buffer can be - // reused) - ptx::cp_async_bulk_wait_group_read<1>(); - const size_t next_buff = next_step % BUFFS_NUM; - const size_t next_row_offset = block_offset_row + next_step * TILE_M; - const size_t next_buff_elem_offset = next_buff * BUFF_ELEMS; - - copy_2d_to_shared( - &in_sh[next_buff_elem_offset], - &tensor_map_input, - static_cast(block_offset_col), - static_cast(next_row_offset), - tile_bytes, - &mbar[next_step], - is_master); - } - - ptx::fence_proxy_async_shared_cta(); - // Wait until the data is ready, parity is always 0 because for simplicity - // we dont reuse barriers between steps - ptx::mbarrier_wait_parity(&mbar[step], 0); - const size_t buff_offset = buff * BUFF_ELEMS; - // Read the data from shared to registers -#pragma unroll - for (int row = 0; row < TILE_M; ++row) { - thread_data[row] = in_sh[buff_offset + row * TILE_K + tidx]; - } - Tx2 amax_2x = Tx2{T(0.0f), T(0.0f)}; -#pragma unroll - for (int row = 0; row < TILE_M; row += 2) { - auto pair = Tx2{thread_data[row], thread_data[row + 1]}; - absmax_x2(amax_2x, amax_2x, pair); - } - - float scale = - max(fabsf(static_cast(amax_2x.x)), - fabsf(static_cast(amax_2x.y))); - - scale /= F8E4M3_MAX; - - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); - scale = float(s); - // Write scale directly to global memory - const size_t global_col = block_offset_col + tidx; - const size_t global_row_group = - (block_offset_row + step_row_offset) / TILE_M; - if (global_col < cols && (block_offset_row + step_row_offset) < rows) { - scales[global_col * scale_stride + global_row_group] = s.__x; - } - const size_t out_buff_offset = buff * out_tile_elems; - // Quantize to registers first - constexpr int GROUPS = TILE_M / 4; - uint32_t quantized_regs[GROUPS]; -#pragma unroll - for (int j = 0; j < GROUPS; ++j) { - Tx4 w_Tx4 = *reinterpret_cast(&thread_data[j * 4]); - quantized_regs[j] = - cu::scale_cvt_Tx4_to_fp8x4(w_Tx4, 1.0f / scale, rbits); - } - // Write output to shared memory with swapped store order to reduce bank - // conflicts. Without swap: stride between threads is TILE_M=32 bytes - // = 8 banks, every 4th thread hits the same bank -> 8-way conflict. - const int lane = tidx % 32; - const int group = (lane / 4) % 2; - const size_t base = out_buff_offset + tidx * TILE_M; - switch (group) { - case 0: - *reinterpret_cast(&out_sh[base + 0]) = { - quantized_regs[0], - quantized_regs[1], - quantized_regs[2], - quantized_regs[3]}; - *reinterpret_cast(&out_sh[base + 16]) = { - quantized_regs[4], - quantized_regs[5], - quantized_regs[6], - quantized_regs[7]}; - break; - case 1: - *reinterpret_cast(&out_sh[base + 16]) = { - quantized_regs[4], - quantized_regs[5], - quantized_regs[6], - quantized_regs[7]}; - *reinterpret_cast(&out_sh[base + 0]) = { - quantized_regs[0], - quantized_regs[1], - quantized_regs[2], - quantized_regs[3]}; - break; - } - __syncthreads(); - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - - if (is_master) { - const size_t global_row = block_offset_row + step_row_offset; - const uint32_t out_x = static_cast(global_row); - const uint32_t out_y = static_cast(block_offset_col); - - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output), - out_x, - out_y, - reinterpret_cast(&out_sh[out_buff_offset])); - ptx::cp_async_bulk_commit_group(); - } - } - // Wait for all TMA stores to complete - ptx::cp_async_bulk_wait_group_read<0>(); - - __syncthreads(); - if (is_master) { -#pragma unroll - for (int iter = 0; iter < STEPS; ++iter) { - ptx::mbarrier_invalidate(&mbar[iter]); - } - } -#endif // __CUDA_ARCH__ >= 1000 -} - -template < - typename T, - bool USE_SR, - int THREADS_PER_BLOCK, - int COLS_PER_BLOCK, - int ROWS_PER_BLOCK> -__global__ void __launch_bounds__(THREADS_PER_BLOCK) - fp_quantize_columnwise_tma_nvfp4( - const __grid_constant__ CUtensorMap tensor_map_input, - const __grid_constant__ CUtensorMap tensor_map_output, - uint8_t* __restrict__ scales, - const size_t rows, - const size_t cols, - float* global_scale) { -#if (defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000) - // placeholder TODO - NVFP4 TMA kernel not yet implemented -#endif // __CUDA_ARCH__ >= 1000 -} - -} // namespace cu -} // namespace mlx::core From 4c4f3a8419f90ac2fe75532abf7faa7772916064 Mon Sep 17 00:00:00 2001 From: Anastasiia Filippova Date: Tue, 17 Mar 2026 16:09:33 +0100 Subject: [PATCH 15/15] fix old scale convertion --- mlx/backend/cuda/quantized/fp_quantize.cuh | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mlx/backend/cuda/quantized/fp_quantize.cuh b/mlx/backend/cuda/quantized/fp_quantize.cuh index a09234adb9..c5895d5eec 100644 --- a/mlx/backend/cuda/quantized/fp_quantize.cuh +++ b/mlx/backend/cuda/quantized/fp_quantize.cuh @@ -358,15 +358,14 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) scale /= F8E4M3_MAX; - using ScaleType = __nv_fp8_e8m0; - auto s = ScaleType(scale); + auto s = cutlass::float_ue8m0_t(scale); scale = float(s); // Write scale directly to global memory const size_t global_col = block_offset_col + tidx; const size_t global_row_group = (block_offset_row + step_row_offset) / TILE_M; if (global_col < cols && (block_offset_row + step_row_offset) < rows) { - scales[global_col * scale_stride + global_row_group] = s.__x; + scales[global_col * scale_stride + global_row_group] = s.storage; } const size_t out_buff_offset = buff * out_tile_elems; // Quantize to registers first @@ -605,8 +604,10 @@ __global__ void fp_dequantize( } size_t gindex = oindex / group_size; - using ScaleType = - std::conditional_t; + using ScaleType = std::conditional_t< + use_mx_scale, + cutlass::float_ue8m0_t, + cutlass::float_e4m3_t>; auto scale = float(((ScaleType*)(scales))[gindex]) * inv_scale_enc; out += oindex;