diff --git a/mlx/backend/metal/kernels/CMakeLists.txt b/mlx/backend/metal/kernels/CMakeLists.txt index 8d3d8a1953..c0126871c9 100644 --- a/mlx/backend/metal/kernels/CMakeLists.txt +++ b/mlx/backend/metal/kernels/CMakeLists.txt @@ -29,8 +29,8 @@ function(build_kernel_base TARGET SRCFILE DEPS) "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}") endif() add_custom_command( - COMMAND xcrun -sdk macosx metal ${METAL_FLAGS} -c ${SRCFILE} - -I${PROJECT_SOURCE_DIR} -o ${TARGET}.air + COMMAND xcrun metal ${METAL_FLAGS} -c ${SRCFILE} -I${PROJECT_SOURCE_DIR} -o + ${TARGET}.air DEPENDS ${SRCFILE} ${DEPS} ${BASE_HEADERS} OUTPUT ${TARGET}.air COMMENT "Building ${TARGET}.air" @@ -176,8 +176,7 @@ endif() add_custom_command( OUTPUT ${MLX_METAL_PATH}/mlx.metallib - COMMAND xcrun -sdk macosx metallib ${KERNEL_AIR} -o - ${MLX_METAL_PATH}/mlx.metallib + COMMAND xcrun metallib ${KERNEL_AIR} -o ${MLX_METAL_PATH}/mlx.metallib DEPENDS ${KERNEL_AIR} COMMENT "Building mlx.metallib" VERBATIM) diff --git a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal index c668d9d8c5..1a0a2b09b9 100644 --- a/mlx/backend/metal/kernels/scaled_dot_product_attention.metal +++ b/mlx/backend/metal/kernels/scaled_dot_product_attention.metal @@ -3,6 +3,7 @@ // clang-format off #include "mlx/backend/metal/kernels/utils.h" #include "mlx/backend/metal/kernels/sdpa_vector.h" +#include "mlx/backend/metal/kernels/sdpa_vector_turbo.h" using namespace metal; @@ -41,4 +42,25 @@ using namespace metal; instantiate_sdpa_vector_heads(float) instantiate_sdpa_vector_heads(bfloat16_t) instantiate_sdpa_vector_heads(float16_t) + +// TurboQuant SDPA: 3-bit packed K with codebook dequant +#define instantiate_sdpa_vector_turbo(type, qk_dim, value_dim, bits, vpw) \ + instantiate_kernel( \ + "sdpa_vector_turbo_" #type "_" #qk_dim "_" #value_dim \ + "_b" #bits "_vpw" #vpw, \ + sdpa_vector_turbo, \ + type, \ + qk_dim, \ + value_dim, \ + bits, \ + vpw) + +#define instantiate_sdpa_vector_turbo_heads(type) \ + instantiate_sdpa_vector_turbo(type, 64, 64, 3, 10) \ + instantiate_sdpa_vector_turbo(type, 128, 128, 3, 10) \ + instantiate_sdpa_vector_turbo(type, 64, 64, 4, 8) \ + instantiate_sdpa_vector_turbo(type, 128, 128, 4, 8) + +instantiate_sdpa_vector_turbo_heads(float16_t) +instantiate_sdpa_vector_turbo_heads(bfloat16_t) // clang-format on diff --git a/mlx/backend/metal/kernels/sdpa_vector_turbo.h b/mlx/backend/metal/kernels/sdpa_vector_turbo.h new file mode 100644 index 0000000000..76d0553106 --- /dev/null +++ b/mlx/backend/metal/kernels/sdpa_vector_turbo.h @@ -0,0 +1,204 @@ +// TurboQuant SDPA vector kernel: decode with pre-rotated queries and +// bit-packed KV cache. Reads 3-bit packed indices + norms + codebook, +// computes attention without materializing dequantized KV vectors. +// +// Pre-rotated query: Q_rot = WHT(signs * Q), computed once per head. +// Score: dot(Q_rot, codebook[K_indices]) * norm / sqrt(dim) +// No WHT butterfly in the inner loop. + +// NOTE: function_constants and metal includes are provided by the +// parent .metal file that includes this header. + +// TurboQuant SDPA: packed K/V with codebook dequantization +// K is stored as bit-packed uint32 indices + float32 norms +// V is stored as pre-dequantized fp16 (via incremental decode buffer) +// +// Template params: +// T: output type (float16/bfloat16) +// D: head dimension (64, 128) +// V_DIM: value dimension (usually == D) +// BITS: quantization bits (2, 3, 4) +// VPW: values per uint32 word (16, 10, 8 for 2, 3, 4 bits) +template +[[kernel]] void sdpa_vector_turbo( + const device T* queries [[buffer(0)]], // pre-rotated queries + const device uint32_t* k_packed [[buffer(1)]], // packed K indices + const device T* values [[buffer(2)]], // dequantized V (from decode buffer) + device T* out [[buffer(3)]], + const constant int& gqa_factor [[buffer(4)]], + const constant int& N [[buffer(5)]], + const constant size_t& k_head_stride [[buffer(6)]], // in uint32 words + const constant size_t& k_seq_stride [[buffer(7)]], // in uint32 words + const constant size_t& v_head_stride [[buffer(8)]], + const constant size_t& v_seq_stride [[buffer(9)]], + const constant float& scale [[buffer(10)]], + const device bool* bmask [[buffer(11), function_constant(bool_mask)]], + const device T* fmask [[buffer(12), function_constant(float_mask)]], + const constant int& mask_kv_seq_stride + [[buffer(13), function_constant(has_mask)]], + const constant int& mask_q_seq_stride + [[buffer(14), function_constant(has_mask)]], + const constant int& mask_head_stride + [[buffer(15), function_constant(has_mask)]], + const device T* sinks [[buffer(16), function_constant(has_sinks)]], + const constant int& num_q_heads + [[buffer(17), function_constant(has_sinks)]], + const device float* k_norms [[buffer(18)]], // per-vector norms + const constant size_t& k_norm_head_stride [[buffer(19)]], + const device float* codebook [[buffer(20)]], // 2^BITS centroids + const constant float& inv_sqrt_dim [[buffer(21)]], // 1/sqrt(dim) + uint3 tid [[threadgroup_position_in_grid]], + uint3 tpg [[threadgroups_per_grid]], + uint simd_gid [[simdgroup_index_in_threadgroup]], + uint simd_lid [[thread_index_in_simdgroup]]) { + constexpr int BN = 32; + constexpr int BD = 32; + constexpr int qk_per_thread = D / BD; + constexpr int v_per_thread = V_DIM / BD; + constexpr int PACKED_DIM = (D + VPW - 1) / VPW; + constexpr uint BIT_MASK = (1u << BITS) - 1u; + + int inner_v_stride = BN * int(v_seq_stride); + + typedef float U; + + thread U q[qk_per_thread]; + thread U o[v_per_thread]; + + threadgroup U outputs[BN * BD]; + threadgroup U max_scores[BN]; + threadgroup U sum_exp_scores[BN]; + + // Adjust positions + const int q_batch_head_idx = tid.x; + const int q_seq_idx = tid.y; + const int kv_head_idx = q_batch_head_idx / gqa_factor; + const int o_offset = q_batch_head_idx * tpg.y + q_seq_idx; + const int q_offset = + query_transposed ? tpg.x * q_seq_idx + q_batch_head_idx : o_offset; + + // Query pointer (pre-rotated) + queries += q_offset * D + simd_lid * qk_per_thread; + + // K packed pointer: navigate to correct head, then stride by simd_gid + k_packed += kv_head_idx * k_head_stride + simd_gid * k_seq_stride; + k_norms += kv_head_idx * k_norm_head_stride + simd_gid; + + // V pointer (dequantized fp16 from decode buffer) + values += kv_head_idx * v_head_stride + simd_gid * v_seq_stride + + simd_lid * v_per_thread; + + if (bool_mask) { + bmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + if (float_mask) { + fmask += q_batch_head_idx * mask_head_stride + + simd_gid * mask_kv_seq_stride + q_seq_idx * mask_q_seq_stride; + } + + out += o_offset * V_DIM + simd_gid * v_per_thread; + + // Read pre-rotated query (already scaled) + for (int i = 0; i < qk_per_thread; i++) { + q[i] = static_cast(scale) * queries[i]; + } + for (int i = 0; i < v_per_thread; i++) { + o[i] = 0; + } + + U max_score = -1e9f; + U sum_exp_score = 0; + if (has_sinks && simd_gid == 0) { + max_score = static_cast(sinks[q_batch_head_idx % num_q_heads]); + sum_exp_score = 1; + } + + // For each key position + for (int i = simd_gid; i < N; i += BN) { + bool use_key = true; + if (do_causal) { + use_key = i <= (N - int(tpg.y) + int(q_seq_idx)); + } else if (bool_mask) { + use_key = bmask[0]; + } else if (float_mask) { + use_key = (fmask[0] >= -1e9f); + } + if (use_key) { + // --- TurboQuant: read packed K indices, codebook lookup --- + // Each thread handles qk_per_thread = D/32 elements + // Thread simd_lid handles elements [simd_lid*qk_per_thread, + // (simd_lid+1)*qk_per_thread) + U score = 0; + int elem_start = simd_lid * qk_per_thread; + for (int j = 0; j < qk_per_thread; j++) { + int elem = elem_start + j; + int word_idx = elem / VPW; + int pos_in_word = elem % VPW; + uint word = k_packed[word_idx]; + uint idx = (word >> (pos_in_word * BITS)) & BIT_MASK; + U k_val = codebook[idx]; + score += q[j] * k_val; + } + + // Apply norm and scale: score = dot(q_rot, codebook[indices]) * norm * + // inv_sqrt_dim + U norm_val = k_norms[0]; + score = simd_sum(score) * norm_val * inv_sqrt_dim; + + if (float_mask) { + score += static_cast(fmask[0]); + } + + // Update the accumulators (same as standard sdpa_vector) + U new_max = max(max_score, score); + U factor = fast::exp(max_score - new_max); + U exp_score = fast::exp(score - new_max); + + max_score = new_max; + sum_exp_score = sum_exp_score * factor + exp_score; + + // Update output with dequantized V (from decode buffer, already fp16) + for (int j = 0; j < v_per_thread; j++) { + o[j] = o[j] * factor + exp_score * static_cast(values[j]); + } + } + + // Advance K packed pointer by BN positions + k_packed += BN * k_seq_stride; + k_norms += BN; + values += inner_v_stride; + if (bool_mask) { + bmask += BN * mask_kv_seq_stride; + } + if (float_mask) { + fmask += BN * mask_kv_seq_stride; + } + } + + // Reduction across SIMD groups (same as standard sdpa_vector) + if (simd_lid == 0) { + max_scores[simd_gid] = max_score; + sum_exp_scores[simd_gid] = sum_exp_score; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_score = max_scores[simd_lid]; + U new_max = simd_max(max_score); + U factor = fast::exp(max_score - new_max); + sum_exp_score = simd_sum(sum_exp_scores[simd_lid] * factor); + + for (int i = 0; i < v_per_thread; i++) { + outputs[simd_lid * BD + simd_gid] = o[i]; + threadgroup_barrier(mem_flags::mem_threadgroup); + o[i] = simd_sum(outputs[simd_gid * BD + simd_lid] * factor); + o[i] = sum_exp_score == 0 ? o[i] : (o[i] / sum_exp_score); + threadgroup_barrier(mem_flags::mem_threadgroup); + } + + // Write output + if (simd_lid == 0) { + for (int i = 0; i < v_per_thread; i++) { + out[i] = static_cast(o[i]); + } + } +} diff --git a/mlx/backend/metal/scaled_dot_product_attention.cpp b/mlx/backend/metal/scaled_dot_product_attention.cpp index 37e554f183..edd3ae804b 100644 --- a/mlx/backend/metal/scaled_dot_product_attention.cpp +++ b/mlx/backend/metal/scaled_dot_product_attention.cpp @@ -415,6 +415,106 @@ void sdpa_vector( compute_encoder.dispatch_threadgroups(grid_dims, group_dims); } +void sdpa_vector_turbo( + const Stream& s, + metal::Device& d, + const array& q, // pre-rotated queries (B*H_q, q_seq, D) + const array& + k_packed, // packed uint32 K indices (B*H_kv, kv_seq, packed_dim) + const array& v, // dequantized V from decode buffer + array& out, + float scale, + bool do_causal, + const std::optional& mask, + const array& k_norms, // per-vector norms (B*H_kv, kv_seq) + const array& codebook, // (n_centroids,) float32 + int bits, + float inv_sqrt_dim) { + int vpw = bits == 3 ? 10 : (bits == 4 ? 8 : (bits == 2 ? 16 : 32)); + + // Kernel name: sdpa_vector_turbo____b_vpw + std::string kname; + kname.reserve(64); + kname += "sdpa_vector_turbo_"; + kname += get_type_string(q.dtype()); + kname += "_"; + kname += std::to_string(q.shape(-1)); + kname += "_"; + kname += std::to_string(v.shape(-1)); + kname += "_b"; + kname += std::to_string(bits); + kname += "_vpw"; + kname += std::to_string(vpw); + + int gqa_factor = q.shape(1) / k_packed.shape(1); + int N = k_norms.shape(1); // kv sequence length + size_t k_head_stride = + k_packed.shape(1) == 1 ? k_packed.strides(0) : k_packed.strides(1); + size_t k_seq_stride = k_packed.strides()[2]; + size_t v_head_stride = v.shape(1) == 1 ? v.strides(0) : v.strides(1); + size_t v_seq_stride = v.strides()[2]; + size_t k_norm_head_stride = + k_norms.shape(1) == 1 ? k_norms.strides(0) : k_norms.strides(1); + + MTL::Size group_dims(1024, 1, 1); + MTL::Size grid_dims(q.shape(0) * q.shape(1), q.shape(2), 1); + + bool has_mask = mask.has_value(); + bool bool_mask = has_mask && (*mask).dtype() == bool_; + bool float_mask = has_mask && !bool_mask; + bool query_transposed = !q.flags().row_contiguous; + bool has_sinks = false; + metal::MTLFCList func_consts = { + {&has_mask, MTL::DataType::DataTypeBool, 20}, + {&query_transposed, MTL::DataType::DataTypeBool, 21}, + {&do_causal, MTL::DataType::DataTypeBool, 22}, + {&bool_mask, MTL::DataType::DataTypeBool, 23}, + {&float_mask, MTL::DataType::DataTypeBool, 24}, + {&has_sinks, MTL::DataType::DataTypeBool, 25}, + }; + std::string hash_name = kname; + hash_name += has_mask ? (bool_mask ? "_boolmask" : "_floatmask") : "_nomask"; + hash_name += query_transposed ? "_qt" : "_qnt"; + hash_name += do_causal ? "_c" : "_nc"; + + auto& compute_encoder = d.get_command_encoder(s.index); + auto kernel = d.get_kernel(kname, hash_name, func_consts); + compute_encoder.set_compute_pipeline_state(kernel); + + // Set arguments matching sdpa_vector_turbo kernel signature + compute_encoder.set_input_array(q, 0); // queries (pre-rotated) + compute_encoder.set_input_array(k_packed, 1); // packed K indices + compute_encoder.set_input_array(v, 2); // dequantized V + compute_encoder.set_output_array(out, 3); + compute_encoder.set_bytes(gqa_factor, 4); + compute_encoder.set_bytes(N, 5); + compute_encoder.set_bytes(k_head_stride, 6); + compute_encoder.set_bytes(k_seq_stride, 7); + compute_encoder.set_bytes(v_head_stride, 8); + compute_encoder.set_bytes(v_seq_stride, 9); + compute_encoder.set_bytes(scale, 10); + + if (has_mask) { + auto& m = *mask; + compute_encoder.set_input_array(m, 11 + float_mask); + int32_t kv_seq_stride = m.shape(3) > 1 ? m.strides(3) : 0; + int32_t q_seq_stride = m.shape(2) > 1 ? m.strides(2) : 0; + int32_t head_stride = + m.shape(1) > 1 ? m.strides(1) : (m.shape(0) > 1 ? m.strides(0) : 0); + compute_encoder.set_bytes(kv_seq_stride, 13); + compute_encoder.set_bytes(q_seq_stride, 14); + compute_encoder.set_bytes(head_stride, 15); + } + + // TurboQuant-specific buffers + compute_encoder.set_input_array(k_norms, 18); + compute_encoder.set_bytes(k_norm_head_stride, 19); + compute_encoder.set_input_array(codebook, 20); + compute_encoder.set_bytes(inv_sqrt_dim, 21); + + compute_encoder.dispatch_threadgroups(grid_dims, group_dims); +} + void sdpa_vector_2pass( const Stream& s, metal::Device& d, @@ -785,6 +885,64 @@ void ScaledDotProductAttention::eval_gpu( d.add_temporaries(std::move(copies), s.index); } +void TurboQuantSDPA::eval_gpu( + const std::vector& inputs, + std::vector& outputs) { + // inputs: [queries, k_packed, values, k_norms, codebook] + auto& q_pre = inputs[0]; + auto& k_packed = inputs[1]; + auto& v_pre = inputs[2]; + auto& k_norms = inputs[3]; + auto& codebook = inputs[4]; + auto& o = outputs[0]; + + auto& s = stream(); + auto& d = metal::device(s.device); + + std::vector copies; + copies.reserve(4); + auto ensure_contiguous = [&copies, &s](const array& arr) -> const array& { + if (arr.flags().row_contiguous || arr.strides(-1) == 1) { + return arr; + } + copies.push_back(contiguous_copy_gpu(arr, s)); + return copies.back(); + }; + + const auto& q = ensure_contiguous(q_pre); + const auto& v = ensure_contiguous(v_pre); + + // Try to donate query buffer for output + if (q.is_donatable() && q.flags().row_contiguous && q.size() == o.size()) { + o.copy_shared_buffer(q); + } else { + o.set_data(allocator::malloc(o.nbytes())); + } + + bool do_causal = do_causal_ && q.shape(2) > 1; + + sdpa_vector_turbo( + s, + d, + q, + k_packed, + v, + o, + scale_, + do_causal, + std::nullopt, + k_norms, + codebook, + bits_, + inv_sqrt_dim_); +} + +bool TurboQuantSDPA::is_equivalent(const Primitive& other) const { + const TurboQuantSDPA& a_other = static_cast(other); + return scale_ == a_other.scale_ && do_causal_ == a_other.do_causal_ && + bits_ == a_other.bits_; +} + bool ScaledDotProductAttentionVJP::use_fallback(const array& q, Stream s) { return true; } diff --git a/mlx/fast.cpp b/mlx/fast.cpp index a668fe9abd..5effeb611c 100644 --- a/mlx/fast.cpp +++ b/mlx/fast.cpp @@ -955,4 +955,59 @@ bool ConvertFP8::is_equivalent(const Primitive& other) const { return to_fp8_ == a_other.to_fp8_; } +array turboquant_sdpa( + const array& queries, + const array& k_packed, + const array& values, + const array& k_norms, + const array& codebook, + const float scale, + const int bits /* = 3 */, + const float inv_sqrt_dim_in /* = 0.0f */, + const std::string& mask_mode /* = "" */, + std::optional mask_arr /* = {} */, + StreamOrDevice s /* = {} */) { + if (queries.ndim() != 4 || values.ndim() != 4) { + throw std::invalid_argument( + "[turboquant_sdpa] queries and values expected to be rank 4"); + } + + int D = queries.shape(-1); + float inv_sqrt_dim = inv_sqrt_dim_in > 0 + ? inv_sqrt_dim_in + : (1.0f / std::sqrt(static_cast(D))); + bool do_causal = mask_mode == "causal"; + + auto final_type = queries.dtype(); + + // CPU fallback: dequantize K from packed, then standard attention + auto fallback = [scale, bits, inv_sqrt_dim, do_causal, s]( + const std::vector& inputs) { + // For CPU: just use V as both K and V (placeholder) + // Real CPU support would need dequant implementation + return std::vector{scaled_dot_product_attention( + inputs[0], + inputs[2], + inputs[2], + scale, + do_causal ? "causal" : "", + {}, + {}, + s)}; + }; + + auto out = array( + queries.shape(), + final_type, + std::make_shared( + to_stream(s), fallback, scale, do_causal, bits, inv_sqrt_dim), + {astype(queries, final_type, s), + astype(k_packed, uint32, s), + astype(values, final_type, s), + astype(k_norms, float32, s), + astype(codebook, float32, s)}); + + return out; +} + } // namespace mlx::core::fast diff --git a/mlx/fast.h b/mlx/fast.h index 1183aba8fe..942484f01c 100644 --- a/mlx/fast.h +++ b/mlx/fast.h @@ -54,6 +54,23 @@ MLX_API array scaled_dot_product_attention( const std::optional& sinks = {}, StreamOrDevice s = {}); +/** TurboQuant SDPA: attention with bit-packed KV cache. + * K is stored as packed uint32 indices with per-vector norms. + * Queries must be pre-rotated: Q_rot = WHT(signs * Q). + * V is passed as dequantized fp16 (from decode buffer). **/ +MLX_API array turboquant_sdpa( + const array& queries, // pre-rotated (B, H_q, T_q, D) + const array& k_packed, // packed uint32 (B, H_kv, T_kv, packed_dim) + const array& values, // dequantized V (B, H_kv, T_kv, D) + const array& k_norms, // per-vector norms (B, H_kv, T_kv) + const array& codebook, // centroids (n_levels,) + const float scale, + const int bits = 3, + const float inv_sqrt_dim = 0.0f, + const std::string& mask_mode = "", + std::optional mask_arr = {}, + StreamOrDevice s = {}); + using TemplateArg = std::variant; using ScalarArg = std::variant; diff --git a/mlx/fast_primitives.h b/mlx/fast_primitives.h index 4434830875..8c00c69532 100644 --- a/mlx/fast_primitives.h +++ b/mlx/fast_primitives.h @@ -260,6 +260,41 @@ class ScaledDotProductAttention : public Custom { bool output_logsumexp_; }; +class TurboQuantSDPA : public Custom { + public: + TurboQuantSDPA( + Stream stream, + std::function(std::vector)> fallback, + float scale, + bool do_causal, + int bits, + float inv_sqrt_dim) + : Custom(stream, std::move(fallback)), + scale_(scale), + do_causal_(do_causal), + bits_(bits), + inv_sqrt_dim_(inv_sqrt_dim) {} + + void eval_cpu(const std::vector& inputs, std::vector& outputs) + override { + throw std::runtime_error("[TurboQuantSDPA] CPU not supported"); + } + + void eval_gpu(const std::vector& inputs, std::vector& outputs) + override; + + bool is_equivalent(const Primitive& other) const override; + + DEFINE_NAME(TurboQuantSDPA); + DEFINE_INPUT_OUTPUT_SHAPE() + + private: + float scale_; + bool do_causal_; + int bits_; + float inv_sqrt_dim_; +}; + class ScaledDotProductAttentionVJP : public Custom { public: ScaledDotProductAttentionVJP( diff --git a/mlx/ops.cpp b/mlx/ops.cpp index ef792cd6f4..f94df8935c 100644 --- a/mlx/ops.cpp +++ b/mlx/ops.cpp @@ -4394,6 +4394,10 @@ std::pair quantization_params_from_mode( default_group_size = 32; default_bits = 8; break; + case QuantizationMode::TurboQuant: + default_group_size = 128; + default_bits = 3; + break; } return { group_size_.has_value() ? *group_size_ : default_group_size, diff --git a/mlx/primitives.cpp b/mlx/primitives.cpp index 220b8bcf55..baaaf780af 100644 --- a/mlx/primitives.cpp +++ b/mlx/primitives.cpp @@ -3377,8 +3377,10 @@ std::string quantization_mode_to_string(QuantizationMode mode) { case QuantizationMode::Mxfp8: return "mxfp8"; case QuantizationMode::Nvfp4: - default: return "nvfp4"; + case QuantizationMode::TurboQuant: + default: + return "turboquant"; } } @@ -3393,6 +3395,8 @@ QuantizationMode string_to_quantization_mode( return QuantizationMode::Mxfp8; } else if (mode == "nvfp4") { return QuantizationMode::Nvfp4; + } else if (mode == "turboquant") { + return QuantizationMode::TurboQuant; } std::string msg; if (!tag.empty()) { diff --git a/mlx/primitives.h b/mlx/primitives.h index ed580d6a4a..5bb6d94338 100644 --- a/mlx/primitives.h +++ b/mlx/primitives.h @@ -152,7 +152,7 @@ class MLX_API UnaryPrimitive : public Primitive { UnaryPrimitive& operator=(UnaryPrimitive&& other) = delete; }; -enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4 }; +enum class QuantizationMode { Affine, Mxfp4, Mxfp8, Nvfp4, TurboQuant }; std::string quantization_mode_to_string(QuantizationMode mode); QuantizationMode string_to_quantization_mode( diff --git a/python/src/fast.cpp b/python/src/fast.cpp index 1a43d89d9b..25d2bc216d 100644 --- a/python/src/fast.cpp +++ b/python/src/fast.cpp @@ -296,6 +296,65 @@ void init_fast(nb::module_& parent_module) { out = mx.fast.scaled_dot_product_attention(q, k, v, scale=scale, mask="causal") )pbdoc"); + m.def( + "turboquant_sdpa", + [](const mx::array& queries, + const mx::array& k_packed, + const mx::array& values, + const mx::array& k_norms, + const mx::array& codebook, + const float scale, + const int bits, + const float inv_sqrt_dim, + const std::optional& mask, + mx::StreamOrDevice s) { + std::string mask_mode = mask.value_or(""); + return mx::fast::turboquant_sdpa( + queries, + k_packed, + values, + k_norms, + codebook, + scale, + bits, + inv_sqrt_dim, + mask_mode, + {}, + s); + }, + "queries"_a, + "k_packed"_a, + "values"_a, + "k_norms"_a, + "codebook"_a, + nb::kw_only(), + "scale"_a, + "bits"_a = 3, + "inv_sqrt_dim"_a = 0.0f, + "mask"_a = nb::none(), + "stream"_a = nb::none(), + R"pbdoc( + TurboQuant SDPA: attention with bit-packed KV cache. + + Computes attention using pre-rotated queries and compressed K cache. + K is stored as bit-packed uint32 codebook indices + per-vector norms. + V is passed as dequantized fp16 (from incremental decode buffer). + + Args: + queries (array): Pre-rotated queries ``[B, H_q, T_q, D]``. + k_packed (array): Packed K indices ``[B, H_kv, T_kv, packed_dim]`` uint32. + values (array): Dequantized V ``[B, H_kv, T_kv, D]``. + k_norms (array): Per-vector K norms ``[B, H_kv, T_kv]`` float32. + codebook (array): Quantization centroids ``[n_levels]`` float32. + scale (float): Attention scale (typically ``1/sqrt(D)``). + bits (int): Quantization bits (2, 3, or 4). Default: ``3``. + inv_sqrt_dim (float): ``1/sqrt(D)``. Default: auto-computed. + mask (str, optional): ``"causal"`` or ``None``. Default: ``None``. + + Returns: + array: Attention output ``[B, H_q, T_q, D]``. + )pbdoc"); + m.def( "metal_kernel", [](const std::string& name, diff --git a/test_turbo_sdpa.py b/test_turbo_sdpa.py new file mode 100644 index 0000000000..3011819060 --- /dev/null +++ b/test_turbo_sdpa.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +"""Test mx.fast.turboquant_sdpa with the native Metal kernel.""" + +import os +import sys + +# Use our built MLX +build_path = os.path.join(os.path.dirname(__file__), "build", "python") +sys.path.insert(0, build_path) + +# mx.fast is accessed via mx.core.fast +import math +import time + +import mlx.core as mx + + +def test_basic(): + """Basic turboquant_sdpa call.""" + B, H_q, H_kv, T, D = 1, 4, 4, 32, 128 + bits = 3 + vpw = 10 + packed_dim = (D + vpw - 1) // vpw # 13 + + q = mx.random.normal(shape=(B, H_q, 1, D)).astype(mx.float16) + k_packed = mx.random.randint(0, 8, shape=(B, H_kv, T, packed_dim)).astype(mx.uint32) + v = mx.random.normal(shape=(B, H_kv, T, D)).astype(mx.float16) + k_norms = mx.random.normal(shape=(B, H_kv, T)).astype(mx.float32) + codebook = mx.array( + [-2.15, -1.34, -0.756, -0.245, 0.245, 0.756, 1.34, 2.15], dtype=mx.float32 + ) + scale = 1.0 / math.sqrt(D) + + print("Calling mx.fast.turboquant_sdpa...") + out = mx.fast.turboquant_sdpa( + q, + k_packed, + v, + k_norms, + codebook, + scale=scale, + bits=bits, + ) + mx.eval(out) + print(f" Output shape: {out.shape}") + print(f" Output dtype: {out.dtype}") + print(f" Output sample: {out[0, 0, 0, :5].tolist()}") + print(f" Has NaN: {mx.any(mx.isnan(out)).item()}") + print(" PASSED!") + + +def test_speed(): + """Benchmark native turbo SDPA vs standard SDPA.""" + B, H_q, H_kv, D = 1, 28, 4, 128 + bits = 3 + vpw = 10 + packed_dim = (D + vpw - 1) // vpw + scale = 1.0 / math.sqrt(D) + + for T in [256, 1024, 4096]: + q = mx.random.normal(shape=(B, H_q, 1, D)).astype(mx.float16) + v = mx.random.normal(shape=(B, H_kv, T, D)).astype(mx.float16) + + # Standard SDPA with float K + k_float = mx.random.normal(shape=(B, H_kv, T, D)).astype(mx.float16) + + # TurboQuant + k_packed = mx.random.randint(0, 8, shape=(B, H_kv, T, packed_dim)).astype( + mx.uint32 + ) + k_norms = mx.abs(mx.random.normal(shape=(B, H_kv, T))).astype(mx.float32) + codebook = mx.array( + [-2.15, -1.34, -0.756, -0.245, 0.245, 0.756, 1.34, 2.15], dtype=mx.float32 + ) + + mx.eval(q, k_float, v, k_packed, k_norms) + + # Warmup + for _ in range(5): + mx.eval(mx.fast.scaled_dot_product_attention(q, k_float, v, scale=scale)) + mx.eval( + mx.fast.turboquant_sdpa( + q, k_packed, v, k_norms, codebook, scale=scale, bits=bits + ) + ) + + # Standard SDPA + t0 = time.perf_counter() + for _ in range(50): + mx.eval(mx.fast.scaled_dot_product_attention(q, k_float, v, scale=scale)) + std_ms = (time.perf_counter() - t0) / 50 * 1000 + + # TurboQuant SDPA + t0 = time.perf_counter() + for _ in range(50): + mx.eval( + mx.fast.turboquant_sdpa( + q, k_packed, v, k_norms, codebook, scale=scale, bits=bits + ) + ) + tq_ms = (time.perf_counter() - t0) / 50 * 1000 + + print( + f" T={T:>5}: std={std_ms:.3f}ms turbo={tq_ms:.3f}ms ratio={tq_ms/std_ms:.2f}x" + ) + + +if __name__ == "__main__": + print("=" * 50) + print("TurboQuant Native Metal SDPA Test") + print("=" * 50) + + print("\n[Basic test]") + try: + test_basic() + except Exception as e: + print(f" FAILED: {e}") + import traceback + + traceback.print_exc() + + print("\n[Speed benchmark]") + try: + test_speed() + except Exception as e: + print(f" FAILED: {e}") + import traceback + + traceback.print_exc()