diff --git a/common/arg.cpp b/common/arg.cpp index 3d0183ed702..52e77ddf5a8 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -390,6 +390,10 @@ const std::vector<ggml_type> kv_cache_types = { GGML_TYPE_IQ4_NL, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + // TurboQuant KV cache types (GPU only) + GGML_TYPE_TURBO3_0, + GGML_TYPE_TURBO4_0, + GGML_TYPE_TURBO2_0, }; static ggml_type kv_cache_type_from_str(const std::string & s) { diff --git a/docs/turbo-quant.md b/docs/turbo-quant.md new file mode 100644 index 00000000000..d35a16a57b4 --- /dev/null +++ b/docs/turbo-quant.md @@ -0,0 +1,159 @@ +# TurboQuant HIP/CUDA — Implementation Notes + +TurboQuant is a family of quantization schemes for the **KV-cache** (and optionally for +weight matrices) that achieves higher compression than existing quantization types while +preserving flash-attention quality. This document describes the upstream port, the +conflict-resolution choices made during the port, and how to enable/test the new code. + +--- + +## What is TurboQuant? + +| Type | Bits/value | Description | +|------|-----------|-------------| +| `TURBO3_0` | 3.13 | 3-bit KV cache: 2-bit PolarQuant indices + 1-bit QJL sign correction | +| `TURBO4_0` | 4.25 | 4-bit KV cache: 4-bit PolarQuant indices (nibble-packed) | +| `TURBO2_0` | 2.13 | 2-bit KV cache: 2-bit PolarQuant indices only (fastest, most aggressive) | +| `TQ3_1S` | 4.0 | 3-bit weight quantization: WHT-rotated 8-level Lloyd-Max | +| `TQ4_1S` | 5.0 | 4-bit weight quantization: WHT-rotated 16-level Lloyd-Max | + +All KV-cache types use a block size of **128** elements. The WHT (Walsh-Hadamard +Transform) rotation mixes correlations within each block to improve quantization quality. + +--- + +## New GGML types + +``` +GGML_TYPE_TURBO3_0 = 42 +GGML_TYPE_TURBO4_0 = 43 +GGML_TYPE_TURBO2_0 = 44 +GGML_TYPE_TQ3_1S = 45 +GGML_TYPE_TQ4_1S = 46 +``` + +A new GGML operator `GGML_OP_TURBO_WHT` applies the (inverse) Walsh-Hadamard Transform +to floating-point vectors. 
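+
+As a rough illustration of how the operator is meant to be used (not itself part of this
+port), a graph builder could apply the forward rotation to an F32 activation tensor before
+it is written into a TurboQuant KV cache. The helper name and tensor shape below are
+hypothetical; only the `ggml_turbo_wht()` signature (declared in `ggml.h` later in this
+diff) is taken from the port:
+
+```c
+#include "ggml.h"
+
+// Sketch only: rotate a [128 x n_rows] F32 activation with the forward WHT.
+//   direction  = 0    -> forward transform (1 selects the inverse)
+//   group_size = 0    -> auto (64 or 128, derived from ne[0])
+//   scale      = NULL -> no InnerQ scaling
+static struct ggml_tensor * build_wht_rotation(struct ggml_context * ctx,
+                                               struct ggml_tensor  * act) {
+    return ggml_turbo_wht(ctx, act, /*direction=*/0, /*group_size=*/0, /*scale=*/NULL);
+}
+```
+
+The inverse transform (`direction = 1`) undoes the rotation, e.g. after dequantization.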
+ +--- + +## New source files + +| File | Description | +|------|-------------| +| `ggml/src/ggml-turbo-quant.c` | CPU quantization / dequantization reference implementations | +| `ggml/src/ggml-cuda/turbo-quant.cuh` | GPU centroid tables and block dequant helpers | +| `ggml/src/ggml-cuda/turbo-wht.cu/.cuh` | GPU Walsh-Hadamard Transform kernel | +| `ggml/src/ggml-cuda/turbo-innerq.cu/.cuh` | GPU inner quantization / conversion kernels | +| `ggml/src/ggml-cuda/mmvq-tq.cu/.cuh` | GPU matrix-vector multiply for TQ weight types | +| `ggml/src/ggml-cuda/template-instances/fattn-vec-instance-{K}-{V}.cu` | Per-type flash-attention kernel instantiations | +| `ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu` | Tile flash-attention instance for D_KQ=640, D_V=512 | + +--- + +## Modified files + +| File | Change | +|------|--------| +| `ggml/include/ggml.h` | New `ggml_type` values, `GGML_OP_TURBO_WHT`, `ggml_turbo_wht()`, `ggml_flash_attn_ext_set_kv_indices()` | +| `ggml/src/ggml-common.h` | Block struct definitions for new types | +| `ggml/src/ggml.c` | Type metadata tables, `ggml_turbo_wht()` and `ggml_flash_attn_ext_set_kv_indices()` implementations | +| `ggml/src/ggml-cuda/ggml-cuda.cu` | TQ4_1S load-time conversion, mul_mat dispatch for TQ weights, `GGML_OP_TURBO_WHT` dispatch, `supports_op` entries | +| `ggml/src/ggml-cuda/fattn-common.cuh` | `kv_indices` parameter in kernel typedef; KQ dot and V dequant functions for turbo types | +| `ggml/src/ggml-cuda/fattn-vec.cuh` | `kv_indices` parameter wired through | +| `ggml/src/ggml-cuda/fattn.cu` | Turbo VEC kernel instantiations; D=640 MMA case; mixed-type allow-list | +| `ggml/src/ggml-cuda/dequantize.cuh` | Dequantize helpers for TURBO types | +| `ggml/src/ggml-cuda/convert.cu` | TQ4_1S → Q8_0 conversion kernel used during model load | +| `ggml/src/ggml-cuda/set-rows.cu` | SET_ROWS support for TURBO3_0, TURBO4_0, TURBO2_0 | +| `ggml/src/CMakeLists.txt` | Adds `ggml-turbo-quant.c` to the ggml-base target | +| `ggml/src/ggml-cuda/CMakeLists.txt` | Adds new `.cu` files to the CUDA target | +| `ggml/src/ggml-hip/CMakeLists.txt` | Mirrors CUDA CMake additions for HIP/ROCm | +| `src/llama-kv-cache.cpp` | KV cache quantization type validation accepts turbo types | +| `common/arg.cpp` | `--cache-type-k` / `--cache-type-v` CLI help updated | + +--- + +## Conflict resolution notes + +The source branch (`domvox/llama.cpp-turboquant-hip:feature/triattention-scoring`) had +significant divergence from upstream `ggml-org/master`. Key reconciliations: + +1. **`fattn_kernel_t` typedef** — The source branch adds a `kv_indices` parameter for the + sparse-attention TriAttention feature. We include this parameter in the typedef and + thread it through `launch_fattn` (passing `nullptr` when unused), but we **omit** the + full TriAttention routing/pruning logic as out-of-scope for this port. + +2. **`GGML_API` visibility** — The source branch removes `extern` from the non-Windows + `GGML_API` macro. This is a project-wide API change that requires careful review; we + omit it here to minimise diff size. + +3. **`TURBO4_USE_4BIT` guard** — The C implementation in `ggml-turbo-quant.c` contains a + legacy `#else` branch that references a `signs` field absent from `block_turbo4_0`. + We add `#define TURBO4_USE_4BIT 1` to `ggml-common.h` to ensure the correct 4-bit + path is always used. + +4. **D=640 MMA kernel** — Added for GLM-4.7 Flash (K head-dim 576 zero-padded to 640) + but depends on `ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, …>`. 
Because the + corresponding MMA template instances are not yet part of this PR, the runtime will + fall back to the VEC kernel or CPU for that model variant. + +5. **TQ weight `mul_mat_id`** — `TQ4_1S` / `TQ3_1S` weight tensors are transparently + converted to `Q8_0` on upload to the GPU; they use the cuBLAS dequant path and bypass + mmvq/mmq kernels. + +--- + +## How to enable TurboQuant KV cache + +Pass one of the new types to `--cache-type-k` and/or `--cache-type-v`: + +```bash +# 3-bit K, 3-bit V (highest quality turbo) +llama-cli -m model.gguf --cache-type-k turbo3_0 --cache-type-v turbo3_0 ... + +# 2-bit K, 3-bit V (aggressive K, quality V) +llama-cli -m model.gguf --cache-type-k turbo2_0 --cache-type-v turbo3_0 ... + +# 4-bit K, 8-bit V (near-lossless V, compressed K) +llama-cli -m model.gguf --cache-type-k turbo4_0 --cache-type-v q8_0 ... +``` + +### Supported combinations (GPU/VEC kernel) + +Any combination of `{turbo2_0, turbo3_0, turbo4_0, q8_0}` for K and V is supported on the VEC +flash-attention kernel. The VEC kernel activates when: + +- Head dimension ≤ 256 and is a multiple of 64 +- Context length is a multiple of 256 (`FATTN_KQ_STRIDE`) + +Larger head dimensions or non-standard context lengths fall back to the MMA tile kernel +(if available) or to the CPU. + +--- + +## Building with HIP/ROCm + +```bash +cmake -B build -DGGML_HIP=ON -DAMDGPU_TARGETS="gfx1100;gfx942" \ + -DCMAKE_BUILD_TYPE=Release +cmake --build build -j$(nproc) +``` + +No extra CMake flags are required to enable TurboQuant; it is compiled in unconditionally. + +--- + +## Limitations / caveats + +- **MMA kernel not yet supported** for turbo KV types. Very long contexts or large + batch sizes fall back to the VEC kernel or CPU. +- **TQ3_1S / TQ4_1S weight types** require the WHT rotation operator + (`GGML_OP_TURBO_WHT`) to be applied by the graph builder. A reference + `llama_model_loader` hook is not yet wired; these types are usable via low-level API + only. +- **TriAttention / kv_indices** sparse-attention routing is partially scaffolded + (`ggml_flash_attn_ext_set_kv_indices` API exists) but not fully implemented on the + VEC kernel side. +- The TURBO2_0 and TURBO3_0 centroid tables live in `turbo-quant.cuh` (GPU) and + `ggml-turbo-quant.c` (CPU). They are Lloyd-Max optimal for a unit-normal distribution + of WHT-rotated activations. diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 703e3783136..169c3872fbe 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -429,7 +429,12 @@ extern "C" { GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block) GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale) GGML_TYPE_Q1_0 = 41, - GGML_TYPE_COUNT = 42, + GGML_TYPE_TURBO3_0 = 42, // TurboQuant 3-bit KV cache: 2-bit PolarQuant + 1-bit sign + GGML_TYPE_TURBO4_0 = 43, // TurboQuant 4-bit KV cache: 4-bit PolarQuant (nibble packed) + GGML_TYPE_TURBO2_0 = 44, // TurboQuant 2-bit KV cache: 2-bit PolarQuant (no QJL) + GGML_TYPE_TQ3_1S = 45, // TurboQuant 3-bit weight: WHT-rotated 8-level Lloyd-Max, block_size=32 + GGML_TYPE_TQ4_1S = 46, // TurboQuant 4-bit weight: WHT-rotated 16-level Lloyd-Max, block_size=32 + GGML_TYPE_COUNT = 47, }; // precision @@ -561,6 +566,7 @@ extern "C" { GGML_OP_RWKV_WKV7, GGML_OP_SOLVE_TRI, GGML_OP_GATED_DELTA_NET, + GGML_OP_TURBO_WHT, GGML_OP_UNARY, @@ -2401,6 +2407,12 @@ extern "C" { GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec( const struct ggml_tensor * a); + // Optional KV row indices for sparse cache access. 
+ // indices is a contiguous 1D i32 tensor mapping logical to physical KV rows. + GGML_API void ggml_flash_attn_ext_set_kv_indices( + struct ggml_tensor * a, + struct ggml_tensor * indices); + GGML_API void ggml_flash_attn_ext_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks); @@ -2539,6 +2551,16 @@ extern "C" { struct ggml_tensor * beta, struct ggml_tensor * state); + // TurboQuant Walsh-Hadamard Transform (O(d log d) rotation for KV cache compression) + // Applies WHT rotation to 128-element groups along ne[0]: sign1 -> butterfly -> sign2 -> normalize + // direction: 0 = forward (signs1 -> WHT -> signs2), 1 = inverse (signs2 -> WHT -> signs1) + GGML_API struct ggml_tensor * ggml_turbo_wht( + struct ggml_context * ctx, + struct ggml_tensor * a, + int direction, + int group_size, // 0 = auto (64 or 128 from ne[0]) + struct ggml_tensor * scale); // NULL = no InnerQ scaling + // custom operators typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 48fbe208d90..af53c4c892b 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -206,6 +206,7 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h + ggml-turbo-quant.c gguf.cpp) set_target_properties(ggml-base PROPERTIES diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f05683b44cd..5687be2d87b 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -277,6 +277,83 @@ typedef struct { } block_tq2_0; static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding"); +// TurboQuant 3-bit MSE-only: 3-bit PolarQuant indices (no QJL) +// Storage block size = 128 (one block per rotation group / head_dim) +// Per block: norm(fp16) + 2-bit indices (32 bytes) + 1-bit extra (16 bytes) = 50 bytes per 128 values +// = 3.125 bits/value -> ~5.1x compression vs fp16 +// The 3-bit index is split: lower 2 bits in qs[], upper 1 bit in signs[] +#define QK_TURBO3 128 // Block size 128: one block per rotation group, eliminates redundant norms +#define QK_TURBO3_GROUP 128 // rotation group size = head_dim +static_assert(QK_TURBO3 % 2 == 0, "QK_TURBO3 must be even for 2D VQ pair quantization"); +// Derived: FA template nl parameters (auto-scale with block size) +#define NL_TURBO3 (QK_TURBO3 / 16) // non-vec FA iterations per block +#define NL_TURBO3_VEC (QK_TURBO3 / 4) // vec FA iterations per block +typedef struct { + ggml_half norm; // 2 bytes: vector L2 norm (for rescaling) + uint8_t qs[QK_TURBO3 / 4]; // 32 bytes: lower 2-bit indices (4 per byte) + uint8_t signs[QK_TURBO3 / 8]; // 16 bytes: upper 1-bit of 3-bit index (8 per byte) +} block_turbo3_0; // 50 bytes total +static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3/4 + QK_TURBO3/8, "wrong turbo3_0 block size/padding"); + +// TurboQuant 4-bit: 4-bit PolarQuant indices (nibble packed) +// Default: 4-bit PolarQuant on all backends +#define QK_TURBO4 128 +// Enable 4-bit path: always use the 4-bit PolarQuant implementation +#ifndef TURBO4_USE_4BIT +#define TURBO4_USE_4BIT 1 +#endif + +// 4-bit PolarQuant: 16 optimal centroids, nibble packed +// Per block: norm(fp16) + rnorm(fp16, reserved) + 4-bit indices (64 bytes) +// = 68 bytes per 128 values = 4.25 bits/value -> 3.8x compression vs fp16 +typedef struct { + ggml_half norm; // 2 bytes + ggml_half rnorm; // 2 bytes (reserved, unused in 4-bit mode) + uint8_t qs[QK_TURBO4 / 2]; // 64 bytes: 4-bit 
PolarQuant indices (nibble packed) +} block_turbo4_0; // 68 bytes total +static_assert(sizeof(block_turbo4_0) == 68, "wrong turbo4_0 block size"); + +static_assert(QK_TURBO4 == 128, "turbo4 kernels assume QK_TURBO4 == 128"); + +// TurboQuant 2-bit: 2-bit PolarQuant indices only (no QJL) +// Per block: norm(fp16) + 2-bit indices (32 bytes) = 34 bytes per 128 values +// = 2.125 bits/value -> ~7.5x compression vs fp16 +// 4 centroids (Lloyd-Max for N(0, 1/128)): {-0.133462, -0.039994, 0.039994, 0.133462} +#define QK_TURBO2 128 // Block size 128: one block per rotation group +#define QK_TURBO2_GROUP 128 // rotation group size = head_dim +// Derived: FA template nl parameters (auto-scale with block size) +#define NL_TURBO2 (QK_TURBO2 / 16) // non-vec FA iterations per block +#define NL_TURBO2_VEC (QK_TURBO2 / 4) // vec FA iterations per block +typedef struct { + ggml_half norm; // 2 bytes: corrected L2 norm + uint8_t qs[QK_TURBO2 / 4]; // 32 bytes: 2-bit indices (4 per byte) +} block_turbo2_0; // 34 bytes total +static_assert(sizeof(block_turbo2_0) == sizeof(ggml_half) + QK_TURBO2/4, "wrong turbo2_0 block size/padding"); + +// TQ3_1S: WHT-rotated 3-bit weight quantization (8-level Lloyd-Max for N(0,1)) +// Block size 32, dual half-block scales (d0 for [0..15], d1 for [16..31]) +// Per block: d0(fp16) + d1(fp16) + 3-bit indices packed (12 bytes) = 16 bytes per 32 values +// = 4.0 bits/value +#define QK_TQ3_0 32 +typedef struct { + ggml_half d0; // 2 bytes: scale for first 16 elements + ggml_half d1; // 2 bytes: scale for last 16 elements + uint8_t qs[QK_TQ3_0 * 3 / 8]; // 12 bytes: 3-bit indices packed (4 groups of 8 in 3 bytes) +} block_tq3_1s; // 16 bytes total +static_assert(sizeof(block_tq3_1s) == 16, "wrong tq3_1s block size"); + +// TQ4_1S: WHT-rotated 4-bit weight quantization (16-level Lloyd-Max for N(0,1)) +// Block size 32, dual half-block scales (d0 for [0..15], d1 for [16..31]) +// Per block: d0(fp16) + d1(fp16) + 4-bit indices packed (16 bytes) = 20 bytes per 32 values +// = 5.0 bits/value +#define QK_TQ4_1S 32 +typedef struct { + ggml_half d0; // 2 bytes: scale for first 16 elements + ggml_half d1; // 2 bytes: scale for last 16 elements + uint8_t qs[QK_TQ4_1S / 2]; // 16 bytes: 4-bit indices nibble-packed +} block_tq4_1s; // 20 bytes total +static_assert(sizeof(block_tq4_1s) == 20, "wrong tq4_1s block size"); + // // Super-block quantization structures // diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index b54d4a6b107..a8a7af24736 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -120,7 +120,22 @@ if (CUDAToolkit_FOUND) template-instances/fattn-vec-instance-f16-f16.cu template-instances/fattn-vec-instance-q4_0-q4_0.cu template-instances/fattn-vec-instance-q8_0-q8_0.cu - template-instances/fattn-vec-instance-bf16-bf16.cu) + template-instances/fattn-vec-instance-bf16-bf16.cu + template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo3_0-q8_0.cu + template-instances/fattn-vec-instance-q8_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo2_0-q8_0.cu + template-instances/fattn-vec-instance-q8_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-q8_0.cu + 
template-instances/fattn-vec-instance-q8_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu + template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu + template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu + template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu) endif() ggml_add_backend_library(ggml-cuda diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 61630a35a29..c19e19e1a46 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -1,5 +1,6 @@ #include "convert.cuh" #include "dequantize.cuh" +#include "turbo-quant.cuh" #include @@ -758,6 +759,16 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { return dequantize_row_mxfp4_cuda; case GGML_TYPE_NVFP4: return dequantize_row_nvfp4_cuda; + case GGML_TYPE_TURBO3_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO2_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TQ4_1S: + return dequantize_block_cont_cuda; + case GGML_TYPE_TQ3_1S: + return dequantize_block_cont_cuda; case GGML_TYPE_F32: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -813,6 +824,16 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { return dequantize_row_mxfp4_cuda; case GGML_TYPE_NVFP4: return dequantize_row_nvfp4_cuda; + case GGML_TYPE_TURBO3_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO2_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cont_cuda; + case GGML_TYPE_TQ4_1S: + return dequantize_block_cont_cuda; + case GGML_TYPE_TQ3_1S: + return dequantize_block_cont_cuda; case GGML_TYPE_F16: return convert_unary_cont_cuda; case GGML_TYPE_BF16: @@ -838,6 +859,16 @@ to_fp16_nc_cuda_t ggml_get_to_fp16_nc_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q8_0: return dequantize_block_cuda; + case GGML_TYPE_TURBO3_0: + return dequantize_block_cuda; + case GGML_TYPE_TURBO2_0: + return dequantize_block_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cuda; + case GGML_TYPE_TQ4_1S: + return dequantize_block_cuda; + case GGML_TYPE_TQ3_1S: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: @@ -884,6 +915,16 @@ to_fp32_nc_cuda_t ggml_get_to_fp32_nc_cuda(ggml_type type) { return dequantize_block_cuda; case GGML_TYPE_Q8_0: return dequantize_block_cuda; + case GGML_TYPE_TURBO3_0: + return dequantize_block_cuda; + case GGML_TYPE_TURBO2_0: + return dequantize_block_cuda; + case GGML_TYPE_TURBO4_0: + return dequantize_block_cuda; + case GGML_TYPE_TQ4_1S: + return dequantize_block_cuda; + case GGML_TYPE_TQ3_1S: + return dequantize_block_cuda; case GGML_TYPE_BF16: return convert_unary_cuda; default: diff --git a/ggml/src/ggml-cuda/dequantize.cuh b/ggml/src/ggml-cuda/dequantize.cuh index 9ae1342fc0e..c004015adc6 100644 --- a/ggml/src/ggml-cuda/dequantize.cuh +++ b/ggml/src/ggml-cuda/dequantize.cuh @@ -1,4 +1,5 @@ #include "common.cuh" +#include "turbo-quant.cuh" static __device__ __forceinline__ void dequantize_q1_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ const block_q1_0 * x = (const block_q1_0 *) vx; @@ -97,3 +98,97 @@ static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const in v.x *= d; v.y *= d; } + +// Turbo4: 4-bit PolarQuant (nibble packed), block size 128 +// iqs is the element index within the block (even), produces elements iqs and iqs+1 +static __device__ __forceinline__ void dequantize_turbo4_0(const void * vx, const 
int64_t ib, const int iqs, float2 & v){ + const block_turbo4_0 * x = (const block_turbo4_0 *) vx; + const float norm = __half2float(x[ib].norm); + v.x = turbo4_dequant_element(&x[ib], iqs + 0, norm); + v.y = turbo4_dequant_element(&x[ib], iqs + 1, norm); +} + +// Turbo3: 3-bit PolarQuant (2-bit qs + 1-bit sign), block size 128 +// iqs is the element index within the block (even), produces elements iqs and iqs+1 +static __device__ __forceinline__ void dequantize_turbo3_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_turbo3_0 * x = (const block_turbo3_0 *) vx; + const float norm = __half2float(x[ib].norm); + { float2 _vp = turbo3_dequant_pair(&x[ib], iqs, norm); v.x = _vp.x; v.y = _vp.y; } +} + +// Turbo2: 2-bit PolarQuant (2-bit qs only, no sign), block size 128 +static __device__ __forceinline__ void dequantize_turbo2_0(const void * vx, const int64_t ib, const int iqs, float2 & v){ + const block_turbo2_0 * x = (const block_turbo2_0 *) vx; + const float norm = __half2float(x[ib].norm); + v.x = turbo2_dequant_element(&x[ib], iqs + 0, norm); + v.y = turbo2_dequant_element(&x[ib], iqs + 1, norm); +} + +// TQ4_1S: 4-bit weight type with inverse WHT, block size 32, dual half-block scales +static __device__ __forceinline__ void dequantize_tq4_1s(const void * vx, const int64_t ib, const int iqs, float2 & v) { + const block_tq4_1s * x = (const block_tq4_1s *) vx; + const float d0 = __half2float(x[ib].d0); + const float d1 = __half2float(x[ib].d1); + + float buf[32]; + for (int j = 0; j < 32; j++) { + uint8_t idx = (x[ib].qs[j / 2] >> ((j & 1) * 4)) & 0xF; + float d = (j < 16) ? d0 : d1; + buf[j] = TQ4_CENTROIDS_WEIGHT[idx] * d; + } + + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step << 1) { + for (int j = i; j < i + step; j++) { + float a = buf[j], b = buf[j + step]; + buf[j] = a + b; buf[j + step] = a - b; + } + } + } + const float inv_sqrt32 = 0.17677669529663688f; + for (int j = 0; j < 32; j++) buf[j] *= inv_sqrt32 * TQ_WEIGHT_SIGNS[j]; + + v.x = buf[iqs]; + v.y = buf[iqs + 1]; +} + +// TQ3_1S: 3-bit weight type with inverse WHT, block size 32, dual half-block scales +static __device__ __forceinline__ void dequantize_tq3_1s(const void * vx, const int64_t ib, const int iqs, float2 & v) { + const block_tq3_1s * x = (const block_tq3_1s *) vx; + const float d0 = __half2float(x[ib].d0); + const float d1 = __half2float(x[ib].d1); + + float buf[32]; + for (int g = 0; g < 4; g++) { + const uint8_t * qp = x[ib].qs + g * 3; + uint8_t idx[8]; + idx[0] = qp[0] & 7; + idx[1] = (qp[0] >> 3) & 7; + idx[2] = ((qp[0] >> 6) | (qp[1] << 2)) & 7; + idx[3] = (qp[1] >> 1) & 7; + idx[4] = (qp[1] >> 4) & 7; + idx[5] = ((qp[1] >> 7) | (qp[2] << 1)) & 7; + idx[6] = (qp[2] >> 2) & 7; + idx[7] = (qp[2] >> 5) & 7; + + for (int i = 0; i < 8; i++) { + int j = g * 8 + i; + float d = (j < 16) ? 
d0 : d1; + buf[j] = TQ3_CENTROIDS_WEIGHT[idx[i]] * d; + } + } + + for (int step = 1; step < 32; step <<= 1) { + for (int i = 0; i < 32; i += step << 1) { + for (int j = i; j < i + step; j++) { + float a = buf[j], b = buf[j + step]; + buf[j] = a + b; buf[j + step] = a - b; + } + } + } + const float inv_sqrt32 = 0.17677669529663688f; + for (int j = 0; j < 32; j++) buf[j] *= inv_sqrt32 * TQ_WEIGHT_SIGNS[j]; + + v.x = buf[iqs]; + v.y = buf[iqs + 1]; +} diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index beeb5238946..9a3811ca59f 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -3,6 +3,9 @@ #include "common.cuh" #include "convert.cuh" #include "vecdotq.cuh" +#ifdef GGML_FATTN_TURBO +#include "turbo-quant.cuh" +#endif #include @@ -39,7 +42,8 @@ typedef void (* fattn_kernel_t)( const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33); + const int32_t nb31, const int32_t nb32, const int64_t nb33, + const int32_t * __restrict__ kv_indices); typedef float (*vec_dot_KQ_t)( const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8 , const void * __restrict__ Q_ds); @@ -288,6 +292,151 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_q8_0( return sum; } +#ifdef GGML_FATTN_TURBO +// Turbo3 KQ dot product: dequantize K from turbo3 blocks, dot with Q +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo3_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_turbo3_0 * K_turbo = (const block_turbo3_0 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { + const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1; + + const int elem0 = k_KQ * 2; + const int ib = elem0 / QK_TURBO3; + const int j0 = elem0 % QK_TURBO3; + + const float norm = __half2float(K_turbo[ib].norm); + const uint8_t qs_byte = K_turbo[ib].qs[j0 / 4]; + const uint8_t sgn_byte = K_turbo[ib].signs[j0 / 8]; + + const int shift = (j0 % 4) * 2; + const uint8_t idx0 = ((qs_byte >> shift) & 0x3) | (((sgn_byte >> (j0 % 8)) & 0x1) << 2); + const uint8_t idx1 = ((qs_byte >> (shift+2)) & 0x3) | (((sgn_byte >> (j0 % 8 + 1)) & 0x1) << 2); + + float2 kv; + { uint8_t vq_ = (idx0 << 3) | idx1; kv.x = TURBO_VQ2D_X[vq_] * norm; kv.y = TURBO_VQ2D_Y[vq_] * norm; } + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv)); +#else + const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + sum += kv.x * qv.x + kv.y * qv.y; +#endif + } + } + + return sum; +} + +// Turbo2 KQ dot product: dequantize K from turbo2 blocks, dot with Q +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo2_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_turbo2_0 * K_turbo = (const block_turbo2_0 *) K_c; + GGML_UNUSED(Q_q8); + 
GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { + const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1; + + const int elem0 = k_KQ * 2; + const int ib = elem0 / QK_TURBO2; + const int j0 = elem0 % QK_TURBO2; + + const float norm = __half2float(K_turbo[ib].norm); + const uint8_t qs_byte = K_turbo[ib].qs[j0 / 4]; + + const int shift = (j0 % 4) * 2; + const uint8_t idx0 = (qs_byte >> shift) & 0x3; + const uint8_t idx1 = (qs_byte >> (shift+2)) & 0x3; + + float2 kv; + kv.x = TURBO_CENTROIDS_2BIT[idx0] * norm; + kv.y = TURBO_CENTROIDS_2BIT[idx1] * norm; + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv)); +#else + const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + sum += kv.x * qv.x + kv.y * qv.y; +#endif + } + } + + return sum; +} + +// Turbo4 KQ dot product: dequantize K from turbo4 blocks, dot with Q +template +static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_turbo4_0( + const char * __restrict__ K_c, const void * __restrict__ Q_v, const int * __restrict__ Q_q8, const void * __restrict__ Q_ds_v) { + + const block_turbo4_0 * K_turbo = (const block_turbo4_0 *) K_c; + GGML_UNUSED(Q_q8); + GGML_UNUSED(Q_ds_v); + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + float sum = 0.0f; + +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += nthreads*cpy_ne) { +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) { + const int k_KQ = k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne + k_KQ_1; + + const int elem0 = k_KQ * 2; + const int ib = elem0 / QK_TURBO4; + const int j0 = elem0 % QK_TURBO4; + + const float norm = __half2float(K_turbo[ib].norm); + const uint8_t qs_byte = K_turbo[ib].qs[j0 / 2]; + + const uint8_t idx0 = (qs_byte >> 0) & 0xF; + const uint8_t idx1 = (qs_byte >> 4) & 0xF; + + float2 kv; + kv.x = TURBO_CENTROIDS_4BIT[idx0] * norm; + kv.y = TURBO_CENTROIDS_4BIT[idx1] * norm; + +#ifdef V_DOT2_F32_F16_AVAILABLE + const half2 qv = ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + ggml_cuda_mad(sum, make_float2(kv.x, kv.y), __half22float2(qv)); +#else + const float2 qv = ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]; + sum += kv.x * qv.x + kv.y * qv.y; +#endif + } + } + + return sum; +} +#endif // GGML_FATTN_TURBO + template static __device__ __forceinline__ void quantize_q8_1_to_shared( const float * __restrict__ x, const float scale, int * __restrict__ yq32, void * __restrict__ yds) { @@ -577,6 +726,171 @@ static __device__ __forceinline__ void dequantize_V_q8_0(const void * __restrict } } +#ifdef GGML_FATTN_TURBO +// Turbo3 V dequantize: extract ne float/half values at position i0 +template +static __device__ __forceinline__ void dequantize_V_turbo3_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_turbo3_0 * x = (const block_turbo3_0 *) vx; + + const int64_t ib = i0 / QK_TURBO3; + const int j0 = i0 % QK_TURBO3; + const float norm = __half2float(x[ib].norm); + + static_assert(ne == 2 || ne == 4, "bad ne"); + + if constexpr (ne == 4) { + const uint8_t qs_byte = x[ib].qs[j0 / 4]; + const uint8_t sgn_byte = x[ib].signs[j0 / 8]; + const int shift_s = j0 % 8; + + const uint8_t idx0 = ((qs_byte >> 
0) & 0x3) | (((sgn_byte >> (shift_s+0)) & 0x1) << 2); + const uint8_t idx1 = ((qs_byte >> 2) & 0x3) | (((sgn_byte >> (shift_s+1)) & 0x1) << 2); + const uint8_t idx2 = ((qs_byte >> 4) & 0x3) | (((sgn_byte >> (shift_s+2)) & 0x1) << 2); + const uint8_t idx3 = ((qs_byte >> 6) & 0x3) | (((sgn_byte >> (shift_s+3)) & 0x1) << 2); + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((half2 *) dst)[0] = make_half2( + __float2half(TURBO_VQ2D_X[(idx0<<3)|idx1] * norm), + __float2half(TURBO_VQ2D_Y[(idx0<<3)|idx1] * norm)); + ((half2 *) dst)[1] = make_half2( + __float2half(TURBO_VQ2D_X[(idx2<<3)|idx3] * norm), + __float2half(TURBO_VQ2D_Y[(idx2<<3)|idx3] * norm)); + } else +#endif + if constexpr (std::is_same_v) { + ((float2 *) dst)[0] = make_float2( + TURBO_VQ2D_X[(idx0<<3)|idx1] * norm, + TURBO_VQ2D_Y[(idx0<<3)|idx1] * norm); + ((float2 *) dst)[1] = make_float2( + TURBO_VQ2D_X[(idx2<<3)|idx3] * norm, + TURBO_VQ2D_Y[(idx2<<3)|idx3] * norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } else { // ne == 2 +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + float2 _vp = turbo3_dequant_pair(&x[ib], j0, norm); + ((half2 *) dst)[0] = make_half2(__float2half(_vp.x), __float2half(_vp.y)); + } else +#endif + if constexpr (std::is_same_v) { + float2 _vp = turbo3_dequant_pair(&x[ib], j0, norm); + ((float *) dst)[0] = _vp.x; + ((float *) dst)[1] = _vp.y; + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } +} + +// Turbo2 V dequantize +template +static __device__ __forceinline__ void dequantize_V_turbo2_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_turbo2_0 * x = (const block_turbo2_0 *) vx; + + const int64_t ib = i0 / QK_TURBO2; + const int j0 = i0 % QK_TURBO2; + const float norm = __half2float(x[ib].norm); + + static_assert(ne == 2 || ne == 4, "bad ne"); + + if constexpr (ne == 4) { + const uint8_t qs_byte = x[ib].qs[j0 / 4]; + + const uint8_t idx0 = (qs_byte >> 0) & 0x3; + const uint8_t idx1 = (qs_byte >> 2) & 0x3; + const uint8_t idx2 = (qs_byte >> 4) & 0x3; + const uint8_t idx3 = (qs_byte >> 6) & 0x3; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((half2 *) dst)[0] = make_half2( + __float2half(TURBO_CENTROIDS_2BIT[idx0] * norm), + __float2half(TURBO_CENTROIDS_2BIT[idx1] * norm)); + ((half2 *) dst)[1] = make_half2( + __float2half(TURBO_CENTROIDS_2BIT[idx2] * norm), + __float2half(TURBO_CENTROIDS_2BIT[idx3] * norm)); + } else +#endif + if constexpr (std::is_same_v) { + ((float2 *) dst)[0] = make_float2(TURBO_CENTROIDS_2BIT[idx0] * norm, TURBO_CENTROIDS_2BIT[idx1] * norm); + ((float2 *) dst)[1] = make_float2(TURBO_CENTROIDS_2BIT[idx2] * norm, TURBO_CENTROIDS_2BIT[idx3] * norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } else { // ne == 2 +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + float v0 = turbo2_dequant_element(&x[ib], j0, norm); + float v1 = turbo2_dequant_element(&x[ib], j0+1, norm); + ((half2 *) dst)[0] = make_half2(__float2half(v0), __float2half(v1)); + } else +#endif + if constexpr (std::is_same_v) { + ((float *) dst)[0] = turbo2_dequant_element(&x[ib], j0, norm); + ((float *) dst)[1] = turbo2_dequant_element(&x[ib], j0+1, norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } +} + +// Turbo4 V dequantize +template +static __device__ __forceinline__ void dequantize_V_turbo4_0(const void * __restrict__ vx, void * __restrict__ dst, const int64_t i0) { + const block_turbo4_0 * x = (const block_turbo4_0 *) vx; + + const 
int64_t ib = i0 / QK_TURBO4; + const int j0 = i0 % QK_TURBO4; + const float norm = __half2float(x[ib].norm); + + static_assert(ne == 2 || ne == 4, "bad ne"); + + if constexpr (ne == 4) { + const uint8_t qs_byte0 = x[ib].qs[j0 / 2]; + const uint8_t qs_byte1 = x[ib].qs[j0 / 2 + 1]; + + const uint8_t idx0 = (qs_byte0 >> 0) & 0xF; + const uint8_t idx1 = (qs_byte0 >> 4) & 0xF; + const uint8_t idx2 = (qs_byte1 >> 0) & 0xF; + const uint8_t idx3 = (qs_byte1 >> 4) & 0xF; + +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + ((half2 *) dst)[0] = make_half2( + __float2half(TURBO_CENTROIDS_4BIT[idx0] * norm), + __float2half(TURBO_CENTROIDS_4BIT[idx1] * norm)); + ((half2 *) dst)[1] = make_half2( + __float2half(TURBO_CENTROIDS_4BIT[idx2] * norm), + __float2half(TURBO_CENTROIDS_4BIT[idx3] * norm)); + } else +#endif + if constexpr (std::is_same_v) { + ((float2 *) dst)[0] = make_float2(TURBO_CENTROIDS_4BIT[idx0] * norm, TURBO_CENTROIDS_4BIT[idx1] * norm); + ((float2 *) dst)[1] = make_float2(TURBO_CENTROIDS_4BIT[idx2] * norm, TURBO_CENTROIDS_4BIT[idx3] * norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } else { // ne == 2 +#ifdef FP16_AVAILABLE + if constexpr (std::is_same_v) { + float v0 = turbo4_dequant_element(&x[ib], j0, norm); + float v1 = turbo4_dequant_element(&x[ib], j0+1, norm); + ((half2 *) dst)[0] = make_half2(__float2half(v0), __float2half(v1)); + } else +#endif + if constexpr (std::is_same_v) { + ((float *) dst)[0] = turbo4_dequant_element(&x[ib], j0, norm); + ((float *) dst)[1] = turbo4_dequant_element(&x[ib], j0+1, norm); + } else { + static_assert(std::is_same_v, "unsupported type"); + } + } +} +#endif // GGML_FATTN_TURBO + template constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { if constexpr (type_K == GGML_TYPE_F16) { @@ -593,6 +907,14 @@ constexpr __device__ vec_dot_KQ_t get_vec_dot_KQ() { return vec_dot_fattn_vec_KQ_q8_0; } else if constexpr (type_K == GGML_TYPE_BF16) { return vec_dot_fattn_vec_KQ_bf16; +#ifdef GGML_FATTN_TURBO + } else if constexpr (type_K == GGML_TYPE_TURBO3_0) { + return vec_dot_fattn_vec_KQ_turbo3_0; + } else if constexpr (type_K == GGML_TYPE_TURBO2_0) { + return vec_dot_fattn_vec_KQ_turbo2_0; + } else if constexpr (type_K == GGML_TYPE_TURBO4_0) { + return vec_dot_fattn_vec_KQ_turbo4_0; +#endif } else { static_assert(type_K == -1, "bad type"); return nullptr; @@ -615,6 +937,14 @@ constexpr __device__ dequantize_V_t get_dequantize_V() { return dequantize_V_q8_0; } else if constexpr (type_V == GGML_TYPE_BF16) { return dequantize_V_bf16; +#ifdef GGML_FATTN_TURBO + } else if constexpr (type_V == GGML_TYPE_TURBO3_0) { + return dequantize_V_turbo3_0; + } else if constexpr (type_V == GGML_TYPE_TURBO2_0) { + return dequantize_V_turbo2_0; + } else if constexpr (type_V == GGML_TYPE_TURBO4_0) { + return dequantize_V_turbo4_0; +#endif } else { static_assert(type_V == -1, "bad type"); return nullptr; @@ -928,6 +1258,10 @@ void launch_fattn( const ggml_tensor * mask = dst->src[3]; const ggml_tensor * sinks = dst->src[4]; + const ggml_tensor * kv_idx = dst->src[5]; + + // Logical KV length: kv_indices->ne[0] when sparse, K->ne[1] when dense + const int32_t ne11_logical = kv_idx ? (int32_t)kv_idx->ne[0] : (int32_t)K->ne[1]; ggml_tensor * KQV = dst; @@ -1032,7 +1366,7 @@ void launch_fattn( // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. 
// Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. - if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + if (mask && ne11_logical % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); @@ -1040,7 +1374,7 @@ void launch_fattn( const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; - const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; + const int iter_k = ne11_logical / FATTN_KQ_STRIDE; KV_max.alloc(ne_KV_max); flash_attn_mask_to_KV_max<<>> @@ -1054,7 +1388,7 @@ void launch_fattn( GGML_ASSERT(max_blocks_per_sm > 0); int parallel_blocks = max_blocks_per_sm; - const int ntiles_KV = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length. + const int ntiles_KV = (ne11_logical + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by KV cache length. dim3 blocks_num; if (stream_k) { @@ -1156,10 +1490,11 @@ void launch_fattn( !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, scale, max_bias, m0, m1, n_head_log2, logit_softcap, Q->ne[0], ne01, Q->ne[2], Q->ne[3], Q->nb[1], Q->nb[2], Q->nb[3], - K->ne[0], K->ne[1], K->ne[2], K->ne[3], nb11, nb12, nb13, + K->ne[0], ne11_logical, K->ne[2], K->ne[3], nb11, nb12, nb13, nb21, nb22, nb23, mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0, - mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0 + mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0, + kv_idx ? (const int32_t *) kv_idx->data : nullptr ); CUDA_CHECK(cudaGetLastError()); diff --git a/ggml/src/ggml-cuda/fattn-vec.cuh b/ggml/src/ggml-cuda/fattn-vec.cuh index f0bd42a5761..33e845a1819 100644 --- a/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ggml/src/ggml-cuda/fattn-vec.cuh @@ -10,6 +10,16 @@ static constexpr __device__ int ggml_cuda_fattn_vec_get_nthreads_device() { return 128; } +static __device__ __forceinline__ const char * ggml_cuda_fattn_vec_kv_row( + const char * tile_base, + const char * base, + const int32_t * __restrict__ kv_indices, + const int logical_pos, + const int local_pos, + const int32_t row_stride) { + return kv_indices ? base + int64_t(kv_indices[logical_pos])*row_stride : tile_base + local_pos*row_stride; +} + // Currently llvm with the amdgcn target does not support unrolling loops // that contain a break that can not be resolved at compile time. 
#ifdef __clang__ @@ -39,12 +49,13 @@ static __global__ void flash_attn_ext_vec( const int32_t nb11, const int32_t nb12, const int64_t nb13, const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { + const int32_t nb31, const int32_t nb32, const int64_t nb33, + const int32_t * __restrict__ kv_indices) { #ifdef FLASH_ATTN_AVAILABLE // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, kv_indices, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, @@ -503,7 +514,7 @@ static __global__ void flash_attn_ext_vec( dst_meta[((sequence*int(ne01.z) + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(KQ_max[tid], KQ_sum[tid]); } #else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, kv_indices, scale, max_bias, m0, m1, n_head_log2, logit_softcap, ne00, ne01, ne02, ne03, nb01, nb02, nb03, diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index ea6607cd337..3f7ccfe8724 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -203,6 +203,17 @@ static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, gg ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 4>(ctx, dst); } } break; + case 640: { + // Padded turbo KV cache for GLM-4.7 Flash (K head_dim=576 zero-padded to 640). + // D=640 shared memory (Q storage = ncols*(DKQ/2+4)*4) exceeds hardware limit at ncols1>=4. + // Cap at ncols1=2 (ncols=32): Q=32*324*4=41KB + KV≈37KB = ~78KB total. 
+ GGML_ASSERT(V->ne[0] == 512); + if (Q->ne[1] <= 1) { + ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 1, 16>(ctx, dst); + } else { + ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, 2, 16>(ctx, dst); + } + } break; default: GGML_ABORT("fatal error"); break; @@ -292,6 +303,23 @@ static void ggml_cuda_flash_attn_ext_vec(ggml_backend_cuda_context & ctx, ggml_t FATTN_VEC_CASES_ALL_D(GGML_TYPE_BF16, GGML_TYPE_BF16) #endif // GGML_CUDA_FA_ALL_QUANTS + // TurboQuant KV cache types (always enabled) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0) + FATTN_VEC_CASES_ALL_D(GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0) + GGML_ABORT("fatal error"); } @@ -315,6 +343,10 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; const ggml_tensor * mask = dst->src[3]; + const ggml_tensor * kv_idx = dst->src[5]; + + // Logical KV length for kernel selection + const int32_t ne11_logical = kv_idx ? (int32_t)kv_idx->ne[0] : (int32_t)K->ne[1]; const int gqa_ratio = Q->ne[2] / K->ne[2]; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); @@ -324,7 +356,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const // The effective batch size for the kernel can be increased by gqa_ratio. 
// The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, - bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + bool gqa_opt_applies = gqa_ratio >= 2 && mask && max_bias == 0.0f && ne11_logical % FATTN_KQ_STRIDE == 0; for (const ggml_tensor * t : {Q, K, V, mask}) { if (t == nullptr || ggml_is_quantized(t->type)) { continue; @@ -361,6 +393,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } break; case 576: + case 640: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } @@ -374,7 +407,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const #ifndef GGML_CUDA_FA_ALL_QUANTS if (K->type != V->type) { - return BEST_FATTN_KERNEL_NONE; + // Allow mixed turbo KV types (any combination of turbo2, turbo3, turbo4, q8_0) + auto is_turbo = [](ggml_type t) { + return t == GGML_TYPE_TURBO2_0 || t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0 || t == GGML_TYPE_Q8_0; + }; + if (!is_turbo(K->type) || !is_turbo(V->type)) { + return BEST_FATTN_KERNEL_NONE; + } } #endif // GGML_CUDA_FA_ALL_QUANTS @@ -392,6 +431,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const case GGML_TYPE_Q8_0: case GGML_TYPE_BF16: break; + case GGML_TYPE_TURBO3_0: + case GGML_TYPE_TURBO2_0: + case GGML_TYPE_TURBO4_0: + if (K->ne[0] % 64 != 0) { + return BEST_FATTN_KERNEL_NONE; + } + break; default: return BEST_FATTN_KERNEL_NONE; } @@ -401,13 +447,13 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && ne11_logical % FATTN_KQ_STRIDE == 0; // If Turing tensor cores are available, use them: if (turing_mma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72) { if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { - if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { + if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && ne11_logical >= 8192)) { return BEST_FATTN_KERNEL_VEC; } } else { @@ -444,7 +490,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const } // Use the WMMA kernel if possible: - if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { + if (ggml_cuda_should_use_wmma_fattn(cc) && ne11_logical % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 512 && Q->ne[0] != 576) { if (can_use_vector_kernel && Q->ne[1] <= 2) { return BEST_FATTN_KERNEL_VEC; } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index de579d2ed50..688ded1f1ce 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -56,6 +56,8 @@ #include "ggml-cuda/gated_delta_net.cuh" #include "ggml-cuda/set.cuh" #include "ggml-cuda/set-rows.cuh" +#include "ggml-cuda/turbo-wht.cuh" +#include "ggml-cuda/mmvq-tq.cuh" #include "ggml-cuda/pad_reflect_1d.cuh" #include "ggml-cuda/solve_tri.cuh" #include "ggml-cuda/tri.cuh" @@ -659,7 +661,34 @@ static void 
ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, gg ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *) buffer->context; ggml_cuda_set_device(ctx->device); - CUDA_CHECK(cudaMemcpyAsync((char *) tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + + // TQ4_1S → q8_0 load-time conversion + if (tensor->type == GGML_TYPE_TQ4_1S && offset == 0 && size == ggml_nbytes(tensor)) { + const int64_t n_elements = ggml_nelements(tensor); + + // Upload TQ4_1S to a temp GPU buffer + void * tmp_tq4; + CUDA_CHECK(cudaMalloc(&tmp_tq4, size)); + CUDA_CHECK(cudaMemcpyAsync(tmp_tq4, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); + + // Convert TQ4_1S (tmp) → q8_0 (tensor->data, which has q8_0-sized allocation) + ggml_cuda_convert_tq4_1s_to_q8_0(tmp_tq4, tensor->data, n_elements, cudaStreamPerThread); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); + + CUDA_CHECK(cudaFree(tmp_tq4)); + + // Update tensor metadata to q8_0 + tensor->type = GGML_TYPE_Q8_0; + tensor->nb[0] = ggml_type_size(GGML_TYPE_Q8_0); + tensor->nb[1] = tensor->nb[0] * (tensor->ne[0] / ggml_blck_size(GGML_TYPE_Q8_0)); + for (int i = 2; i < GGML_MAX_DIMS; i++) { + tensor->nb[i] = tensor->nb[i-1] * tensor->ne[i-1]; + } + + return; + } + + CUDA_CHECK(cudaMemcpyAsync((char *)tensor->data + offset, data, size, cudaMemcpyHostToDevice, cudaStreamPerThread)); CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); } @@ -779,6 +808,16 @@ static size_t ggml_backend_cuda_buffer_type_get_alloc_size(ggml_backend_buffer_t size_t size = ggml_nbytes(tensor); int64_t ne0 = tensor->ne[0]; + // TQ4_1S → q8_0 load-time conversion: allocate q8_0-sized space in VRAM + if (tensor->type == GGML_TYPE_TQ4_1S) { + const int64_t n_blocks = ggml_nelements(tensor) / QK_TQ4_1S; + size = n_blocks * sizeof(block_q8_0); + if (ne0 % MATRIX_ROW_PADDING != 0) { + size += ggml_row_size(GGML_TYPE_Q8_0, MATRIX_ROW_PADDING - ne0 % MATRIX_ROW_PADDING); + } + return size; + } + if (ggml_is_quantized(tensor->type)) { if (ne0 % MATRIX_ROW_PADDING != 0) { GGML_ASSERT(tensor->nb[0] == ggml_element_size(tensor)); @@ -2323,7 +2362,9 @@ static bool ggml_cuda_should_fuse_mul_mat_vec_q(const ggml_tensor * tensor) { ggml_nbytes(src0) != ggml_backend_buffer_get_alloc_size(src0->buffer, src0) && src0->view_src; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && src1->type == GGML_TYPE_F32 && + const bool is_tq_weight = (src0->type == GGML_TYPE_TQ4_1S || src0->type == GGML_TYPE_TQ3_1S); + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && !is_tq_weight && + src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; // fusion is not universally faster on Pascal @@ -2365,10 +2406,12 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool use_mul_mat_f = !ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear + // TQ weight types use dequant-to-f16 cuBLAS path only (no mmvq/mmq kernels) + const bool is_tq_weight = (src0->type == GGML_TYPE_TQ4_1S || src0->type == GGML_TYPE_TQ3_1S); + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && !bad_padding_clear && !is_tq_weight && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = 
ggml_is_quantized(src0->type) && !bad_padding_clear + bool use_mul_mat_q = ggml_is_quantized(src0->type) && !bad_padding_clear && !is_tq_weight && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; bool any_gpus_with_slow_fp16 = false; @@ -2432,6 +2475,9 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_q, quantize_mmq_q8_1_cuda); + } else if (!split && is_tq_weight && src1->ne[1] == 1) { + // Fused TQ weight mul_mat_vec with pre-rotated activations via warp shuffle WHT + ggml_cuda_mul_mat_vec_tq(ctx, src0, src1, dst); } else { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_cublas, nullptr); } @@ -2451,16 +2497,18 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; // [TAG_MUL_MAT_ID_CUDA_GRAPHS] + // TQ weight types use dequant-to-f16 cuBLAS path only (no mmvq/mmq kernels) + const bool is_tq_weight_id = (src0->type == GGML_TYPE_TQ4_1S || src0->type == GGML_TYPE_TQ3_1S); if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { - if (ggml_is_quantized(src0->type)) { + if (ggml_is_quantized(src0->type) && !is_tq_weight_id) { const int mmvq_mmid_max = get_mmvq_mmid_max_batch(src0->type, cc); if (ne2 <= mmvq_mmid_max) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } - } else { + } else if (!ggml_is_quantized(src0->type)) { if (GGML_CUDA_CC_IS_AMD(cc)) { ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); return; @@ -2618,6 +2666,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SET_ROWS: ggml_cuda_op_set_rows(ctx, dst); break; + case GGML_OP_TURBO_WHT: + ggml_cuda_turbo_wht(ctx, dst); + break; case GGML_OP_SET: ggml_cuda_op_set(ctx, dst); break; @@ -3076,10 +3127,11 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { } // [TAG_MUL_MAT_ID_CUDA_GRAPHS] + const bool is_tq_w = (node->src[0]->type == GGML_TYPE_TQ4_1S || node->src[0]->type == GGML_TYPE_TQ3_1S); if (node->op == GGML_OP_MUL_MAT_ID) { const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int mmvq_mmid_max = get_mmvq_mmid_max_batch(node->src[0]->type, cc); - if (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > mmvq_mmid_max) { + if (!ggml_is_quantized(node->src[0]->type) || is_tq_w || node->ne[2] > mmvq_mmid_max) { // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs // TODO: figure out a way to enable for larger batch sizes, without hurting performance // ref: https://github.com/ggml-org/llama.cpp/pull/18958 @@ -4864,6 +4916,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_XS: case GGML_TYPE_BF16: + case GGML_TYPE_TQ4_1S: + case GGML_TYPE_TQ3_1S: return true; default: return false; @@ -4895,9 +4949,17 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } break; case GGML_OP_SET_ROWS: { + // turbo KV types require block-aligned head dim + if ((op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO2_0) && op->src[0]->ne[0] % 128 != 0) { + return false; + } + if (op->type == GGML_TYPE_TURBO4_0 && 
op->src[0]->ne[0] % 128 != 0) { + return false; + } return (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 || - op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) && + op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL || + op->type == GGML_TYPE_TURBO3_0 || op->type == GGML_TYPE_TURBO2_0 || op->type == GGML_TYPE_TURBO4_0) && op->src[0]->type == GGML_TYPE_F32 && (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32); } break; @@ -5024,6 +5086,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CLAMP: case GGML_OP_LOG: return true; + case GGML_OP_TURBO_WHT: + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && + op->src[0]->ne[0] % 32 == 0; // supports 32, 64, and 128 WHT groups case GGML_OP_SSM_SCAN: { if (op->src[3]->ne[0] == 1) { // Mamba2 diff --git a/ggml/src/ggml-cuda/mmvq-tq.cu b/ggml/src/ggml-cuda/mmvq-tq.cu new file mode 100644 index 00000000000..25bd84f10bd --- /dev/null +++ b/ggml/src/ggml-cuda/mmvq-tq.cu @@ -0,0 +1,357 @@ +/* + * Fused mul_mat_vec for TQ4_1S / TQ3_1S weight types. + * + * V12: Single-phase fused kernel with shmem activation sharing. + * All warps cooperatively rotate activation into shared memory, + * then each warp processes one row reading from shmem (broadcast). + * + * Eliminates: + * - Global memory scratch buffer (no CUDA graph incompatibility) + * - Separate pre-rotation kernel launch + * - 2x activation bandwidth (was: write global + read global per row) + * + * V12 avoids the NR0 regression that killed V3/V6/V11 — the single + * __syncthreads is OUTSIDE the dot product loop (between rotation and + * mmvq phases), not inside it. + * + * Falls back to V8 two-phase if shmem exceeds 48 KB (ncols > 12288). + * + * Based on signalnine's V8 two-phase kernel (commit b107175). + * Optimization by TheTom. + */ + +#include "mmvq-tq.cuh" +#include "turbo-quant.cuh" + +#define MMVQ_TQ_NWARPS 8 + +// ============================================================================ +// V8 two-phase kernels (fallback for very large ncols that exceed shmem) +// ============================================================================ + +static __global__ void tq_prerotate_activation_v8( + const float * __restrict__ src, + float * __restrict__ dst, + const int n_elements) { + + const int block_idx = blockIdx.x * blockDim.y + threadIdx.y; + const int lane = threadIdx.x; + const int offset = block_idx * 32 + lane; + if (offset >= n_elements) return; + + float val = src[offset]; + val *= TQ_WEIGHT_SIGNS[lane]; + + #pragma unroll + for (int h = 1; h < 32; h <<= 1) { + float o = __shfl_xor_sync(0xffffffff, val, h, WARP_SIZE); + val = (lane & h) ? 
(o - val) : (val + o); + } + val *= 0.17677669529663688f; + dst[offset] = val; +} + +static __global__ void mul_mat_vec_tq4_1s_v8( + const void * __restrict__ vx, + const float * __restrict__ vy_rot, + float * __restrict__ dst, + const int ncols_x, + const int nrows_x) { + + const int row = blockIdx.x * MMVQ_TQ_NWARPS + threadIdx.y; + if (row >= nrows_x) return; + + const int lane = threadIdx.x; + const int blocks_per_row = ncols_x / QK_TQ4_1S; + const block_tq4_1s * x_row = ((const block_tq4_1s *) vx) + (int64_t)row * blocks_per_row; + + float sum = 0.0f; + + for (int ib = 0; ib < blocks_per_row; ib++) { + const float act = vy_rot[ib * QK_TQ4_1S + lane]; + const float d = (lane < 16) ? __half2float(x_row[ib].d0) : __half2float(x_row[ib].d1); + const uint8_t idx = (x_row[ib].qs[lane / 2] >> ((lane & 1) * 4)) & 0xF; + + sum += act * TQ4_CENTROIDS_WEIGHT[idx] * d; + } + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset, WARP_SIZE); + + if (lane == 0) dst[row] = sum; +} + +static __device__ __forceinline__ uint8_t tq3_extract_index(const uint8_t * __restrict__ qs, int lane) { + const int group = lane / 8; + const int lane_in_group = lane % 8; + const uint8_t * qp = qs + group * 3; + const uint32_t packed = (uint32_t)qp[0] | ((uint32_t)qp[1] << 8) | ((uint32_t)qp[2] << 16); + return (packed >> (lane_in_group * 3)) & 7; +} + +static __global__ void mul_mat_vec_tq3_1s_v8( + const void * __restrict__ vx, + const float * __restrict__ vy_rot, + float * __restrict__ dst, + const int ncols_x, + const int nrows_x) { + + const int row = blockIdx.x * MMVQ_TQ_NWARPS + threadIdx.y; + if (row >= nrows_x) return; + + const int lane = threadIdx.x; + const int blocks_per_row = ncols_x / QK_TQ3_0; + const block_tq3_1s * x_row = ((const block_tq3_1s *) vx) + (int64_t)row * blocks_per_row; + + float sum = 0.0f; + + for (int ib = 0; ib < blocks_per_row; ib++) { + const float act = vy_rot[ib * QK_TQ3_0 + lane]; + const float d = (lane < 16) ? __half2float(x_row[ib].d0) : __half2float(x_row[ib].d1); + const uint8_t idx = tq3_extract_index(x_row[ib].qs, lane); + + sum += act * TQ3_CENTROIDS_WEIGHT[idx] * d; + } + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset, WARP_SIZE); + + if (lane == 0) dst[row] = sum; +} + +// ============================================================================ +// V12: Single-phase fused kernel — rotate in shmem, no global scratch +// +// All 8 warps cooperatively WHT-rotate activation into shared memory. +// Then each warp processes one row doing centroid×scale dot product +// reading activation from shmem (broadcast reads from L1). +// +// The key insight: the single __syncthreads is between the two phases +// (rotation vs dot product), NOT inside the inner dot product loop. +// This is why V3/V11 regressed (sync per block) but V12 should not. +// ============================================================================ + +static __global__ void mul_mat_vec_tq4_1s_v12( + const void * __restrict__ vx, + const float * __restrict__ vy, // UNROTATED activation (raw src1) + float * __restrict__ dst, + const int ncols_x, + const int nrows_x) { + + extern __shared__ float s_act[]; // ncols_x floats + + const int lane = threadIdx.x; // 0-31 + const int warp_id = threadIdx.y; // 0 to MMVQ_TQ_NWARPS-1 + const int blocks_per_row = ncols_x / QK_TQ4_1S; + + // Phase 1: ALL warps cooperatively pre-rotate activation into shmem. 
+ // Each warp handles a strided subset of 32-element blocks. + // 8 warps × 32 threads = 256 threads rotating in parallel. + for (int ib = warp_id; ib < blocks_per_row; ib += MMVQ_TQ_NWARPS) { + float val = vy[ib * 32 + lane]; + val *= TQ_WEIGHT_SIGNS[lane]; + + #pragma unroll + for (int h = 1; h < 32; h <<= 1) { + float o = __shfl_xor_sync(0xffffffff, val, h, WARP_SIZE); + val = (lane & h) ? (o - val) : (val + o); + } + val *= 0.17677669529663688f; // 1/sqrt(32) + s_act[ib * 32 + lane] = val; + } + __syncthreads(); // ONE sync — between rotation and dot product, NOT in inner loop + + // Phase 2: Each warp processes one row using shmem activation (broadcast reads). + const int row = blockIdx.x * MMVQ_TQ_NWARPS + warp_id; + if (row >= nrows_x) return; + + const block_tq4_1s * x_row = ((const block_tq4_1s *) vx) + (int64_t)row * blocks_per_row; + float sum = 0.0f; + + for (int ib = 0; ib < blocks_per_row; ib++) { + const float act = s_act[ib * 32 + lane]; + const float d = (lane < 16) ? __half2float(x_row[ib].d0) : __half2float(x_row[ib].d1); + const uint8_t idx = (x_row[ib].qs[lane / 2] >> ((lane & 1) * 4)) & 0xF; + sum += act * TQ4_CENTROIDS_WEIGHT[idx] * d; + } + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset, WARP_SIZE); + + if (lane == 0) dst[row] = sum; +} + +static __global__ void mul_mat_vec_tq3_1s_v12( + const void * __restrict__ vx, + const float * __restrict__ vy, // UNROTATED activation (raw src1) + float * __restrict__ dst, + const int ncols_x, + const int nrows_x) { + + extern __shared__ float s_act[]; + + const int lane = threadIdx.x; + const int warp_id = threadIdx.y; + const int blocks_per_row = ncols_x / QK_TQ3_0; + + // Phase 1: cooperative rotation into shmem + for (int ib = warp_id; ib < blocks_per_row; ib += MMVQ_TQ_NWARPS) { + float val = vy[ib * 32 + lane]; + val *= TQ_WEIGHT_SIGNS[lane]; + + #pragma unroll + for (int h = 1; h < 32; h <<= 1) { + float o = __shfl_xor_sync(0xffffffff, val, h, WARP_SIZE); + val = (lane & h) ? (o - val) : (val + o); + } + val *= 0.17677669529663688f; + s_act[ib * 32 + lane] = val; + } + __syncthreads(); + + // Phase 2: mmvq from shmem + const int row = blockIdx.x * MMVQ_TQ_NWARPS + warp_id; + if (row >= nrows_x) return; + + const block_tq3_1s * x_row = ((const block_tq3_1s *) vx) + (int64_t)row * blocks_per_row; + float sum = 0.0f; + + for (int ib = 0; ib < blocks_per_row; ib++) { + const float act = s_act[ib * 32 + lane]; + const float d = (lane < 16) ? 
__half2float(x_row[ib].d0) : __half2float(x_row[ib].d1); + const uint8_t idx = tq3_extract_index(x_row[ib].qs, lane); + sum += act * TQ3_CENTROIDS_WEIGHT[idx] * d; + } + + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) + sum += __shfl_xor_sync(0xffffffff, sum, offset, WARP_SIZE); + + if (lane == 0) dst[row] = sum; +} + +// ============================================================================ +// Dispatch — V12 shmem when it fits, V8 two-phase fallback +// ============================================================================ + +void ggml_cuda_mul_mat_vec_tq(ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_TQ4_1S || src0->type == GGML_TYPE_TQ3_1S); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(src1->ne[1] == 1); + + const int ncols_x = src0->ne[0]; + const int nrows_x = src0->ne[1]; + GGML_ASSERT(ncols_x % 32 == 0); + + const void * src0_d = src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + cudaStream_t stream = ctx.stream(); + + const size_t shmem_needed = (size_t)ncols_x * sizeof(float); + + // V12: single kernel, activation in shmem (fits for all models up to ncols=12288) + // V8 fallback: two-phase with global scratch (for hypothetical future huge models) + if (shmem_needed <= 48 * 1024) { + const dim3 block(WARP_SIZE, MMVQ_TQ_NWARPS); + const dim3 grid((nrows_x + MMVQ_TQ_NWARPS - 1) / MMVQ_TQ_NWARPS); + + if (src0->type == GGML_TYPE_TQ4_1S) { + mul_mat_vec_tq4_1s_v12<<>>(src0_d, src1_d, dst_d, ncols_x, nrows_x); + } else { + mul_mat_vec_tq3_1s_v12<<>>(src0_d, src1_d, dst_d, ncols_x, nrows_x); + } + } else { + // V8 fallback: two-phase with pool-allocated scratch (Codex P1: was global static) + ggml_cuda_pool & pool = ctx.pool(); + ggml_cuda_pool_alloc d_act_alloc(pool, (size_t)ncols_x); + + { + const int n_blocks = ncols_x / 32; + const dim3 rot_block(32, 4); + const dim3 rot_grid((n_blocks + 3) / 4); + tq_prerotate_activation_v8<<>>(src1_d, d_act_alloc.get(), ncols_x); + } + + { + const dim3 block(WARP_SIZE, MMVQ_TQ_NWARPS); + const dim3 grid((nrows_x + MMVQ_TQ_NWARPS - 1) / MMVQ_TQ_NWARPS); + + if (src0->type == GGML_TYPE_TQ4_1S) { + mul_mat_vec_tq4_1s_v8<<>>(src0_d, d_act_alloc.get(), dst_d, ncols_x, nrows_x); + } else { + mul_mat_vec_tq3_1s_v8<<>>(src0_d, d_act_alloc.get(), dst_d, ncols_x, nrows_x); + } + } + } +} + +// ============================================================================ +// Load-time conversion: TQ4_1S → q8_0 +// +// Fused kernel: dequant TQ4_1S (centroid lookup + inverse WHT) → quantize q8_0. +// One warp (32 threads) per block of 32 elements. +// Used at model load to convert TQ4_1S weights to q8_0 in VRAM for dp4a decode. +// ============================================================================ + +static __global__ void k_convert_tq4_1s_to_q8_0( + const block_tq4_1s * __restrict__ src, + block_q8_0 * __restrict__ dst, + const int n_blocks) { + + const int block_idx = blockIdx.x * blockDim.y + threadIdx.y; + if (block_idx >= n_blocks) return; + + const int lane = threadIdx.x; + const block_tq4_1s * blk = &src[block_idx]; + + // Step 1: Dequant — centroid lookup × half-block scale + const float d_scale = (lane < 16) ? 
__half2float(blk->d0) : __half2float(blk->d1); + const uint8_t idx = (blk->qs[lane / 2] >> ((lane & 1) * 4)) & 0xF; + float val = TQ4_CENTROIDS_WEIGHT[idx] * d_scale; + + // Step 2: Inverse WHT via warp shuffle (same as dequant path) + #pragma unroll + for (int h = 1; h < 32; h <<= 1) { + float o = __shfl_xor_sync(0xffffffff, val, h, WARP_SIZE); + val = (lane & h) ? (o - val) : (val + o); + } + val *= 0.17677669529663688f; // 1/sqrt(32) + val *= TQ_WEIGHT_SIGNS[lane]; + + // Step 3: Quantize to q8_0 — find block amax, compute scale, round + float amax = fabsf(val); + #pragma unroll + for (int off = 16; off > 0; off >>= 1) + amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, off, WARP_SIZE)); + + const float d = amax / 127.0f; + const float id = (d > 0.0f) ? 127.0f / amax : 0.0f; + + // Step 4: Write q8_0 block + dst[block_idx].qs[lane] = (int8_t)roundf(val * id); + if (lane == 0) { + dst[block_idx].d = __float2half(d); + } +} + +void ggml_cuda_convert_tq4_1s_to_q8_0(const void * src_tq4, void * dst_q8, int64_t n_elements, cudaStream_t stream) { + GGML_ASSERT(n_elements % QK_TQ4_1S == 0); + const int n_blocks = n_elements / QK_TQ4_1S; + + const int wpb = 4; // warps per CUDA block + const dim3 block(32, wpb); + const dim3 grid((n_blocks + wpb - 1) / wpb); + + k_convert_tq4_1s_to_q8_0<<>>( + (const block_tq4_1s *)src_tq4, + (block_q8_0 *)dst_q8, + n_blocks); +} diff --git a/ggml/src/ggml-cuda/mmvq-tq.cuh b/ggml/src/ggml-cuda/mmvq-tq.cuh new file mode 100644 index 00000000000..3315a062c3e --- /dev/null +++ b/ggml/src/ggml-cuda/mmvq-tq.cuh @@ -0,0 +1,8 @@ +#pragma once + +#include "common.cuh" + +void ggml_cuda_mul_mat_vec_tq(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); + +// Load-time conversion: TQ4_1S → q8_0 in VRAM (dequant + requantize) +void ggml_cuda_convert_tq4_1s_to_q8_0(const void * src_tq4, void * dst_q8, int64_t n_elements, cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/set-rows.cu b/ggml/src/ggml-cuda/set-rows.cu index 631de7e8fa5..d0e38a28ae1 100644 --- a/ggml/src/ggml-cuda/set-rows.cu +++ b/ggml/src/ggml-cuda/set-rows.cu @@ -1,5 +1,6 @@ #include "set-rows.cuh" #include "cpy-utils.cuh" +#include "turbo-quant.cuh" typedef void (*set_rows_kernel_t)(const char * src, char * dst); @@ -209,6 +210,966 @@ static void set_rows_cuda( } } +// ---- TurboQuant3 set_rows: GROUP_SIZE-element groups with WHT rotation + norm correction ---- +// +// Templated on GROUP_SIZE (128 or 64). +// Parallel kernel: one CUDA block per group, GROUP_SIZE threads per block. +// Thread j handles element j within the group. +// +// Steps (all parallel): +// 1. Load element j from global memory +// 2. Parallel L2 norm (warp reduce + inter-warp via shared memory) +// 3. Normalize +// 4. Forward WHT (log2(GROUP_SIZE) butterfly stages, shared memory) +// 5. Quantize element j to 3-bit centroid index +// 6. Pack qs (warp shuffle) and signs (__ballot_sync) into turbo3 block, no atomics +// 7. Parallel reconstruction norm (same pattern as step 2) +// 8. 
Write corrected norm (one thread per sub-block) + +template +__launch_bounds__(256) // max of 256, 128, or 64 +static __global__ void k_set_rows_turbo3( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo3_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3) { + + static_assert(GROUP_SIZE == 64 || GROUP_SIZE == 128 || GROUP_SIZE == 256, "GROUP_SIZE must be 64, 128, or 256"); + + // blockIdx.x = flat group index; threadIdx.x = element within group (0..GROUP_SIZE-1) + const int j = threadIdx.x; + + // Decode blockIdx.x → (i_grp, i01, i02, i03) + constexpr int blocks_per_group = GROUP_SIZE / QK_TURBO3; + const int64_t n_groups_per_row = ne00 / GROUP_SIZE; + const int64_t g = blockIdx.x; + const int64_t i_grp = g % n_groups_per_row; + int64_t tmp = g / n_groups_per_row; + const int64_t i01 = tmp % ne01; + tmp = tmp / ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i12 = i02; + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo3_0 * dst_row_ptr = (block_turbo3_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + block_turbo3_0 * blk_base = dst_row_ptr + i_grp * blocks_per_group; + + // ---- Step 1: Load element j (coalesced) ---- + __shared__ float x[GROUP_SIZE]; + x[j] = src_row[i_grp * GROUP_SIZE + j]; + __syncthreads(); + + // ---- InnerQ: calibrate on original (unscaled) values ---- + if (d_innerq_calibrating) { + atomicAdd(&d_innerq_sq_accum[j], x[j] * x[j]); + if (j == 0) atomicAdd(&d_innerq_count, 1); + } + + // ---- InnerQ: apply channel scale (only when active) ---- + if (d_innerq_active) { + x[j] *= d_innerq_scale[j]; + } + __syncthreads(); + + // ---- Step 2: Parallel L2 norm ---- + constexpr int n_warps = GROUP_SIZE / WARP_SIZE; + __shared__ float warp_accum[n_warps]; + float v = x[j]; + float v2 = v * v; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // ---- Step 3: Normalize ---- + x[j] *= inv_norm; + __syncthreads(); + + // ---- Step 4: Forward WHT (signs1 → butterfly → signs2, normalized) ---- + if (GROUP_SIZE == 256) { + x[j] *= TURBO_WHT_SIGNS1_256[j]; + } else if (GROUP_SIZE == 128) { + x[j] *= TURBO_WHT_SIGNS1[j]; + } else { + x[j] *= TURBO_WHT_SIGNS1_64[j]; + } + __syncthreads(); + +#define WHT_STAGE_SHARED(h) \ + if (j % (2*(h)) < (h)) { float a = x[j], b = x[j+(h)]; x[j] = a+b; x[j+(h)] = a-b; } \ + __syncthreads(); + + // Butterfly stages: loop from h=1 to h= 128) { WHT_STAGE_SHARED(64) } + if (GROUP_SIZE >= 256) { WHT_STAGE_SHARED(128) } +#undef WHT_STAGE_SHARED + + constexpr float inv_sqrt_group = (GROUP_SIZE == 256) ? 0.0625f : + (GROUP_SIZE == 128) ? 
0.08838834764831845f : 0.125f; + if (GROUP_SIZE == 256) { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2_256[j]; + } else if (GROUP_SIZE == 128) { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2[j]; + } else { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2_64[j]; + } + __syncthreads(); + + // ---- Step 5: 2D VQ quantize pairs ---- + const float my_val = x[j]; + const float pair_val = __shfl_xor_sync(0xffffffff, my_val, 1); + const float vx = (j & 1) ? pair_val : my_val; + const float vy = (j & 1) ? my_val : pair_val; + uint8_t best_vq = 0; + if ((j & 1) == 0) { + float best_dist = 1e30f; + for (int c = 0; c < 64; c++) { + float dx = vx - TURBO_VQ2D_X[c]; + float dy = vy - TURBO_VQ2D_Y[c]; + float d = dx*dx + dy*dy; + if (d < best_dist) { best_dist = d; best_vq = (uint8_t)c; } + } + } + best_vq = __shfl_sync(0xffffffff, best_vq, (j % WARP_SIZE) & ~1); + const uint8_t idx = (j & 1) ? (best_vq & 0x7) : ((best_vq >> 3) & 0x7); + + // ---- Step 6: Pack qs and signs (warp-cooperative, no atomics) ---- + // Each warp handles 32 elements. With QK_TURBO3 > WARP_SIZE, multiple warps + // share one block and write to different byte offsets within it. + const int warp_id = j / WARP_SIZE; + const int lane = j % WARP_SIZE; + const int elem_in_block = j % QK_TURBO3; + block_turbo3_0 * blk = blk_base + (j / QK_TURBO3); + + // Pack qs: 4 elements per byte, 2 bits each. + // All 4 threads in a qs-group gather their low2 bits via shuffle. + const int qs_byte_idx = elem_in_block / 4; + const uint8_t my_low2 = idx & 0x3; + uint8_t qs_byte = 0; +#pragma unroll + for (int k = 0; k < 4; k++) { + uint8_t contrib = __shfl_sync(0xffffffff, my_low2, (lane & ~3) + k); + qs_byte |= contrib << (k * 2); + } + if (lane % 4 == 0) blk->qs[qs_byte_idx] = qs_byte; + + // Pack signs: 8 elements per byte, 1 bit each. __ballot_sync across warp. + // Ballot is per-warp (32 bits); extract local byte, write to global position in block. + const uint32_t ballot = __ballot_sync(0xffffffff, (idx >> 2) & 1); + const int local_signs_byte = lane / 8; // byte within 32-bit ballot (0..3) + const int global_signs_byte = elem_in_block / 8; // byte within block's signs array + const uint8_t signs_byte = (uint8_t)((ballot >> (local_signs_byte * 8)) & 0xFF); + if (lane % 8 == 0) blk->signs[global_signs_byte] = signs_byte; + + // ---- Step 7: Reconstruction norm (parallel, same pattern as step 2) ---- + const float c = (j & 1) ? TURBO_VQ2D_Y[best_vq] : TURBO_VQ2D_X[best_vq]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + + // ---- Step 8: Write corrected norm (one per turbo3 block) ---- + if (elem_in_block == 0) blk->norm = __float2half(corrected_norm); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +// ---- TurboQuant3 tail kernel: straight 3-bit quantize without WHT rotation ---- +// +// For head dims not divisible by 128 (e.g. 576 = 4*128 + 64), the remainder +// elements can't use the 128-element WHT. They are quantised directly into +// standard turbo3 blocks. 
Q is also NOT rotated for these positions (the graph +// guards on ne[0] % 128), so stays in the original space. +// +// One CUDA block per row, with tail_size threads (must be multiple of 32). + +template +static __global__ void k_set_rows_turbo3_tail( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo3_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const int tail_size) { + + const int j = threadIdx.x; // 0 .. tail_size-1 + + // Decode blockIdx.x → (i01, i02, i03) + int64_t tmp = blockIdx.x; + const int64_t i01 = tmp % ne01; tmp /= ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + const int64_t i12 = i02; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo3_0 * dst_row_ptr = (block_turbo3_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + + // Tail starts after all full 128-element groups + const int64_t n_full = ne00 / QK_TURBO3_GROUP; + const int64_t tail_start = n_full * QK_TURBO3_GROUP; + block_turbo3_0 * blk_base = dst_row_ptr + n_full * (QK_TURBO3_GROUP / QK_TURBO3); + + // ---- Load ---- + const float val = src_row[tail_start + j]; + + // ---- L2 norm over the tail group (warp reduce + inter-warp) ---- + const int n_warps = tail_size / WARP_SIZE; + const int warp_id = j / WARP_SIZE; + const int lane = j % WARP_SIZE; + + __shared__ float warp_accum[4]; // max 3 warps (tail ≤ 96) + float v2 = val * val; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (lane == 0) warp_accum[warp_id] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // ---- Normalize (no WHT!) 
---- + const float rv = val * inv_norm; + + // ---- Quantize ---- + const uint8_t idx = turbo_nearest_centroid_3bit(rv); + + // ---- Pack qs and signs (same warp-cooperative logic) ---- + block_turbo3_0 * blk = blk_base + warp_id; + + const uint8_t my_low2 = idx & 0x3; + uint8_t qs_byte = 0; +#pragma unroll + for (int k = 0; k < 4; k++) { + uint8_t contrib = __shfl_sync(0xffffffff, my_low2, (lane & ~3) + k); + qs_byte |= contrib << (k * 2); + } + if (lane % 4 == 0) blk->qs[lane / 4] = qs_byte; + + const uint32_t ballot = __ballot_sync(0xffffffff, (idx >> 2) & 1); + const int signs_byte_idx = lane / 8; + const uint8_t signs_byte = (uint8_t)((ballot >> (signs_byte_idx * 8)) & 0xFF); + if (lane % 8 == 0) blk->signs[signs_byte_idx] = signs_byte; + + // ---- Reconstruction norm ---- + const float c = TURBO_CENTROIDS_3BIT[idx]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (lane == 0) warp_accum[warp_id] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + + if (lane == 0) blk->norm = __float2half(corrected_norm); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +template +static void set_rows_cuda_turbo3( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + + const float * src0_d = (const float *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne00 % QK_TURBO3 == 0); // must be block-aligned (32) + + cudaStream_t stream = ctx.stream(); + + // Read WHT group size from op_params (set by llama-kv-cache.cpp based on head_dim). + // Default to 128 if not set (backward compat with head_dim=128 models). 
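+    // The producing graph is expected to stash the group size as a 32-bit int at op_params[0].
+    // A hypothetical producer-side sketch (names illustrative, not the actual llama-kv-cache.cpp code):
+    //     int32_t wht_group_size = 128;                                   // chosen per head_dim
+    //     memcpy(k_set_rows->op_params, &wht_group_size, sizeof(wht_group_size));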
+ int group_size = 128; + memcpy(&group_size, dst->op_params, sizeof(int)); + if (group_size != 64 && group_size != 128 && group_size != 256) group_size = 128; + GGML_ASSERT(ne00 % group_size == 0); + + const int64_t n_full_groups = ne00 / group_size; + const int tail_size = (int)(ne00 % group_size); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + + // InnerQ: check/finalize calibration before kernel launch + turbo_innerq_check_finalize(group_size, ne00); + + // Launch 1: full groups with WHT rotation + if (n_full_groups > 0) { + const int64_t ne_total = n_full_groups * ne01 * ne02 * ne03; + if (group_size == 256) { + k_set_rows_turbo3<<<(int)ne_total, 256, 0, stream>>>( + src0_d, src1_d, (block_turbo3_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } else if (group_size == 128) { + k_set_rows_turbo3<<<(int)ne_total, 128, 0, stream>>>( + src0_d, src1_d, (block_turbo3_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } else { + k_set_rows_turbo3<<<(int)ne_total, 64, 0, stream>>>( + src0_d, src1_d, (block_turbo3_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } + } + + // Launch 2: tail elements (no WHT, straight quantize) + // Not needed for 64-aligned dims but kept for potential future use + if (tail_size > 0) { + GGML_ASSERT(tail_size % QK_TURBO3 == 0); // tail must be block-aligned + const int64_t n_rows = ne01 * ne02 * ne03; + k_set_rows_turbo3_tail<<<(int)n_rows, tail_size, 0, stream>>>( + src0_d, src1_d, (block_turbo3_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3, tail_size); + } +} + +// ---- TurboQuant2 set_rows: GROUP_SIZE-element groups with WHT rotation + norm correction ---- +// +// Same structure as turbo3 but 2-bit quantization only (no signs byte). 
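+// For reference, a block_turbo2_0 holds QK_TURBO2 2-bit indices packed four per byte plus one
+// half-precision corrected norm. Ignoring the WHT rotation (full groups live in rotated space),
+// a minimal per-element unpack is (sketch only — the real helpers live in dequantize.cuh):
+//     const uint8_t idx = (blk->qs[j / 4] >> ((j % 4) * 2)) & 0x3;
+//     const float   v   = TURBO_CENTROIDS_2BIT[idx] * __half2float(blk->norm);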
+ +template +__launch_bounds__(256) // max of 256, 128, or 64 +static __global__ void k_set_rows_turbo2( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo2_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3) { + + static_assert(GROUP_SIZE == 64 || GROUP_SIZE == 128 || GROUP_SIZE == 256, "GROUP_SIZE must be 64, 128, or 256"); + + const int j = threadIdx.x; + + constexpr int blocks_per_group = GROUP_SIZE / QK_TURBO2; + const int64_t n_groups_per_row = ne00 / GROUP_SIZE; + const int64_t g = blockIdx.x; + const int64_t i_grp = g % n_groups_per_row; + int64_t tmp = g / n_groups_per_row; + const int64_t i01 = tmp % ne01; + tmp = tmp / ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i12 = i02; + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo2_0 * dst_row_ptr = (block_turbo2_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + block_turbo2_0 * blk_base = dst_row_ptr + i_grp * blocks_per_group; + + // ---- Step 1: Load element j (coalesced) ---- + __shared__ float x[GROUP_SIZE]; + x[j] = src_row[i_grp * GROUP_SIZE + j]; + __syncthreads(); + + // ---- InnerQ: calibrate on original (unscaled) values ---- + if (d_innerq_calibrating) { + atomicAdd(&d_innerq_sq_accum[j], x[j] * x[j]); + if (j == 0) atomicAdd(&d_innerq_count, 1); + } + + // ---- InnerQ: apply channel scale (only when active) ---- + if (d_innerq_active) { + x[j] *= d_innerq_scale[j]; + } + __syncthreads(); + + // ---- Step 2: Parallel L2 norm ---- + constexpr int n_warps = GROUP_SIZE / WARP_SIZE; + __shared__ float warp_accum[n_warps]; + float v = x[j]; + float v2 = v * v; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // ---- Step 3: Normalize ---- + x[j] *= inv_norm; + __syncthreads(); + + // ---- Step 4: Forward WHT ---- + if (GROUP_SIZE == 256) { + x[j] *= TURBO_WHT_SIGNS1_256[j]; + } else if (GROUP_SIZE == 128) { + x[j] *= TURBO_WHT_SIGNS1[j]; + } else { + x[j] *= TURBO_WHT_SIGNS1_64[j]; + } + __syncthreads(); + +#define WHT_STAGE_SHARED_T2(h) \ + if (j % (2*(h)) < (h)) { float a = x[j], b = x[j+(h)]; x[j] = a+b; x[j+(h)] = a-b; } \ + __syncthreads(); + + WHT_STAGE_SHARED_T2(1) + WHT_STAGE_SHARED_T2(2) + WHT_STAGE_SHARED_T2(4) + WHT_STAGE_SHARED_T2(8) + WHT_STAGE_SHARED_T2(16) + WHT_STAGE_SHARED_T2(32) + if (GROUP_SIZE >= 128) { WHT_STAGE_SHARED_T2(64) } + if (GROUP_SIZE >= 256) { WHT_STAGE_SHARED_T2(128) } +#undef WHT_STAGE_SHARED_T2 + + constexpr float inv_sqrt_group = (GROUP_SIZE == 256) ? 0.0625f : + (GROUP_SIZE == 128) ? 
0.08838834764831845f : 0.125f; + if (GROUP_SIZE == 256) { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2_256[j]; + } else if (GROUP_SIZE == 128) { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2[j]; + } else { + x[j] = x[j] * inv_sqrt_group * TURBO_WHT_SIGNS2_64[j]; + } + __syncthreads(); + + // ---- Step 5: Quantize element j to 2-bit centroid ---- + const float rv = x[j]; + const uint8_t idx = turbo_nearest_centroid_2bit(rv); + + // ---- Step 6: Pack qs (warp-cooperative, no atomics) ---- + // Each warp handles 32 elements. With QK_TURBO2 > WARP_SIZE, multiple warps + // share one block and write to different byte offsets within it. + const int warp_id = j / WARP_SIZE; + const int lane = j % WARP_SIZE; + const int elem_in_block = j % QK_TURBO2; + block_turbo2_0 * blk = blk_base + (j / QK_TURBO2); + + // Pack qs: 4 elements per byte, 2 bits each. + const uint8_t my_bits = idx & 0x3; + uint8_t qs_byte = 0; +#pragma unroll + for (int k = 0; k < 4; k++) { + uint8_t contrib = __shfl_sync(0xffffffff, my_bits, (lane & ~3) + k); + qs_byte |= contrib << (k * 2); + } + if (lane % 4 == 0) blk->qs[elem_in_block / 4] = qs_byte; + + // No signs packing needed for turbo2 + + // ---- Step 7: Reconstruction norm ---- + const float c = TURBO_CENTROIDS_2BIT[idx]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + + // ---- Step 8: Write corrected norm (one per turbo2 block) ---- + if (elem_in_block == 0) blk->norm = __float2half(corrected_norm); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +// ---- TurboQuant2 tail kernel: straight 2-bit quantize without WHT rotation ---- + +template +static __global__ void k_set_rows_turbo2_tail( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo2_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3, + const int tail_size) { + + const int j = threadIdx.x; + + int64_t tmp = blockIdx.x; + const int64_t i01 = tmp % ne01; tmp /= ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + const int64_t i12 = i02; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo2_0 * dst_row_ptr = (block_turbo2_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + + const int64_t n_full = ne00 / QK_TURBO2_GROUP; + const int64_t tail_start = n_full * QK_TURBO2_GROUP; + block_turbo2_0 * blk_base = dst_row_ptr + n_full * (QK_TURBO2_GROUP / QK_TURBO2); + + // ---- Load ---- + const float val = src_row[tail_start + j]; + + // ---- L2 norm ---- + const int n_warps = tail_size / WARP_SIZE; + const int warp_id = j / WARP_SIZE; + const int lane = j % WARP_SIZE; + + __shared__ float warp_accum[4]; + float v2 = val * val; + for (int offset = WARP_SIZE / 
2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (lane == 0) warp_accum[warp_id] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // ---- Normalize (no WHT!) ---- + const float rv = val * inv_norm; + + // ---- Quantize ---- + const uint8_t idx = turbo_nearest_centroid_2bit(rv); + + // ---- Pack qs ---- + block_turbo2_0 * blk = blk_base + warp_id; + + const uint8_t my_bits = idx & 0x3; + uint8_t qs_byte = 0; +#pragma unroll + for (int k = 0; k < 4; k++) { + uint8_t contrib = __shfl_sync(0xffffffff, my_bits, (lane & ~3) + k); + qs_byte |= contrib << (k * 2); + } + if (lane % 4 == 0) blk->qs[lane / 4] = qs_byte; + + // ---- Reconstruction norm ---- + const float c = TURBO_CENTROIDS_2BIT[idx]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (lane == 0) warp_accum[warp_id] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? grp_norm / recon_norm : grp_norm; + + if (lane == 0) blk->norm = __float2half(corrected_norm); + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); + GGML_UNUSED(ne00); +} + +template +static void set_rows_cuda_turbo2( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + + const float * src0_d = (const float *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne00 % QK_TURBO2 == 0); + + cudaStream_t stream = ctx.stream(); + + int group_size = 128; + memcpy(&group_size, dst->op_params, sizeof(int)); + if (group_size != 64 && group_size != 128 && group_size != 256) group_size = 128; + GGML_ASSERT(ne00 % group_size == 0); + + const int64_t n_full_groups = ne00 / group_size; + const int tail_size = (int)(ne00 % group_size); + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + + // InnerQ: check/finalize calibration before kernel launch + turbo_innerq_check_finalize(group_size, ne00); + + if (n_full_groups > 0) { + const int64_t ne_total = n_full_groups * ne01 * ne02 * ne03; + if (group_size == 256) { + k_set_rows_turbo2<<<(int)ne_total, 256, 0, stream>>>( + src0_d, src1_d, (block_turbo2_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } else if (group_size == 128) { + k_set_rows_turbo2<<<(int)ne_total, 128, 0, stream>>>( + src0_d, src1_d, (block_turbo2_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } else { + k_set_rows_turbo2<<<(int)ne_total, 64, 0, stream>>>( + src0_d, src1_d, (block_turbo2_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } + } + + if (tail_size > 0) { + GGML_ASSERT(tail_size % QK_TURBO2 == 0); + const int64_t n_rows = ne01 * ne02 * ne03; + 
k_set_rows_turbo2_tail<<<(int)n_rows, tail_size, 0, stream>>>( + src0_d, src1_d, (block_turbo2_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3, tail_size); + } +} + +// ---- TurboQuant4 set_rows: 128-element groups with WHT rotation + 4-bit quantization ---- +// +// turbo4 block size IS the WHT group size (128), so 1 CUDA block = 1 turbo4 block. +// 128 threads per block, thread j handles element j. +// 4-bit centroids (16 values), nibble packed: qs[j/2] |= (idx & 0xF) << ((j%2)*4) + +template +__launch_bounds__(128) +static __global__ void k_set_rows_turbo4( + const float * __restrict__ src0, + const idx_t * __restrict__ src1, + block_turbo4_0 * __restrict__ dst, + const int64_t ne00, + const int64_t ne01, + const int64_t ne10, + const int64_t ne11, + const int64_t ne12, + const int64_t ne13, + const int64_t s01, + const int64_t s02, + const int64_t s03, + const int64_t s10, + const int64_t s11, + const int64_t s12, + const int64_t s1, + const int64_t s2, + const int64_t s3) { + + // blockIdx.x = flat block index; threadIdx.x = element within block (0..127) + const int j = threadIdx.x; + + // Decode blockIdx.x → (i_blk, i01, i02, i03) + const int64_t n_blocks_per_row = ne00 / QK_TURBO4; + const int64_t g = blockIdx.x; + const int64_t i_blk = g % n_blocks_per_row; + int64_t tmp = g / n_blocks_per_row; + const int64_t i01 = tmp % ne01; + tmp = tmp / ne01; + const int64_t i02 = tmp % ne12; + const int64_t i03 = tmp / ne12; + + const int64_t i12 = i02; + const int64_t i11 = i01 % ne11; + const int64_t i10 = i01; + + const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12); + const float * src_row = src0 + i01*s01 + i02*s02 + i03*s03; + block_turbo4_0 * dst_row_ptr = (block_turbo4_0 *)((char *)dst + dst_row*s1 + i02*s2 + i03*s3); + block_turbo4_0 * blk = dst_row_ptr + i_blk; + + // ---- Step 1: Load element j (coalesced) ---- + __shared__ float x[128]; + x[j] = src_row[i_blk * QK_TURBO4 + j]; + __syncthreads(); + + // ---- InnerQ: calibrate on original (unscaled) values ---- + if (d_innerq_calibrating) { + atomicAdd(&d_innerq_sq_accum[j], x[j] * x[j]); + if (j == 0) atomicAdd(&d_innerq_count, 1); + } + + // ---- InnerQ: apply channel scale (only when active) ---- + if (d_innerq_active) { + x[j] *= d_innerq_scale[j]; + } + __syncthreads(); + + // ---- Step 2: Parallel L2 norm ---- + constexpr int n_warps = 128 / WARP_SIZE; // = 4 + __shared__ float warp_accum[n_warps]; + float v = x[j]; + float v2 = v * v; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + v2 += __shfl_xor_sync(0xffffffff, v2, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = v2; + __syncthreads(); + + __shared__ float s_norm_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_norm_sq = total; + } + __syncthreads(); + const float grp_norm = sqrtf(s_norm_sq); + const float inv_norm = (grp_norm > 1e-10f) ? 
1.0f / grp_norm : 0.0f; + + // ---- Step 3: Normalize ---- + x[j] *= inv_norm; + __syncthreads(); + + // ---- Step 4: Forward WHT (signs1 → butterfly → signs2, normalized) ---- + x[j] *= TURBO_WHT_SIGNS1[j]; + __syncthreads(); + +#define WHT_STAGE_SHARED_T4(h) \ + if (j % (2*(h)) < (h)) { float a = x[j], b = x[j+(h)]; x[j] = a+b; x[j+(h)] = a-b; } \ + __syncthreads(); + + WHT_STAGE_SHARED_T4(1) + WHT_STAGE_SHARED_T4(2) + WHT_STAGE_SHARED_T4(4) + WHT_STAGE_SHARED_T4(8) + WHT_STAGE_SHARED_T4(16) + WHT_STAGE_SHARED_T4(32) + WHT_STAGE_SHARED_T4(64) +#undef WHT_STAGE_SHARED_T4 + + constexpr float inv_sqrt_128 = 0.08838834764831845f; + x[j] = x[j] * inv_sqrt_128 * TURBO_WHT_SIGNS2[j]; + __syncthreads(); + + // ---- Step 5: Quantize element j to 4-bit centroid ---- + const float rv = x[j]; + const uint8_t idx = turbo_nearest_centroid_4bit(rv); + + // ---- Step 6: Pack qs (nibble packed, warp-cooperative) ---- + // 2 elements per byte, 4 bits each. + // Thread pairs (j, j+1) share a qs byte. + const int lane = j % WARP_SIZE; + const uint8_t my_nibble = idx & 0xF; + uint8_t qs_byte = 0; + // Gather nibble from partner thread + uint8_t partner_nibble = __shfl_sync(0xffffffff, my_nibble, lane ^ 1); + if (j % 2 == 0) { + qs_byte = my_nibble | (partner_nibble << 4); + blk->qs[j / 2] = qs_byte; + } + + // ---- Step 7: Reconstruction norm (parallel) ---- + const float c = TURBO_CENTROIDS_4BIT[idx]; + float rc = c * c; + for (int offset = WARP_SIZE / 2; offset > 0; offset >>= 1) + rc += __shfl_xor_sync(0xffffffff, rc, offset); + if (j % WARP_SIZE == 0) + warp_accum[j / WARP_SIZE] = rc; + __syncthreads(); + + __shared__ float s_recon_sq; + if (j == 0) { + float total = 0.0f; + for (int w = 0; w < n_warps; w++) total += warp_accum[w]; + s_recon_sq = total; + } + __syncthreads(); + const float recon_norm = sqrtf(s_recon_sq); + const float corrected_norm = (recon_norm > 1e-10f) ? 
grp_norm / recon_norm : grp_norm; + + // ---- Step 8: Write corrected norm and zero rnorm (one thread) ---- + if (j == 0) { + blk->norm = __float2half(corrected_norm); + blk->rnorm = __float2half(0.0f); + } + + GGML_UNUSED(ne10); + GGML_UNUSED(ne13); +} + +template +static void set_rows_cuda_turbo4( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, + const ggml_tensor * src1, + ggml_tensor * dst) { + + const float * src0_d = (const float *)src0->data; + const idx_t * src1_d = (const idx_t *)src1->data; + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(ne00 % QK_TURBO4 == 0); // must be block-aligned (128) + + cudaStream_t stream = ctx.stream(); + + // turbo4 block size = WHT group size = 128, always + const int64_t n_blocks = ne00 / QK_TURBO4; + + const int64_t s01 = nb01/sizeof(float); + const int64_t s02 = nb02/sizeof(float); + const int64_t s03 = nb03/sizeof(float); + const int64_t s10 = nb10/sizeof(idx_t); + const int64_t s11 = nb11/sizeof(idx_t); + const int64_t s12 = nb12/sizeof(idx_t); + + // InnerQ: check/finalize calibration before kernel launch + turbo_innerq_check_finalize(QK_TURBO4, ne00); + + if (n_blocks > 0) { + const int64_t ne_total = n_blocks * ne01 * ne02 * ne03; + k_set_rows_turbo4<<<(int)ne_total, 128, 0, stream>>>( + src0_d, src1_d, (block_turbo4_0 *)dst->data, + ne00, ne01, ne10, ne11, ne12, ne13, + s01, s02, s03, s10, s11, s12, + nb1, nb2, nb3); + } +} + template static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const src_t * src0_d = (const src_t *)src0->data; @@ -309,6 +1270,12 @@ static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * s nb1, nb2, nb3, stream ); + } else if (dst->type == GGML_TYPE_TURBO3_0) { + set_rows_cuda_turbo3(ctx, src0, src1, dst); + } else if (dst->type == GGML_TYPE_TURBO2_0) { + set_rows_cuda_turbo2(ctx, src0, src1, dst); + } else if (dst->type == GGML_TYPE_TURBO4_0) { + set_rows_cuda_turbo4(ctx, src0, src1, dst); } else { GGML_ABORT("unsupported type %s", ggml_type_name(dst->type)); } diff --git a/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu new file mode 100644 index 00000000000..c68a841ad41 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. 
+ +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(640, 512); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu new file mode 100644 index 00000000000..0aaf9b153c1 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: q8_0 K + turbo2 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu new file mode 100644 index 00000000000..a5aba051faa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: q8_0 K + turbo3 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu new file mode 100644 index 00000000000..9217ca9702f --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: q8_0 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_Q8_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu new file mode 100644 index 00000000000..93d70cf48f4 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo2 K + q8_0 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu new file mode 100644 index 00000000000..f0d182957ad --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// TurboQuant2 CUDA flash attention vec kernel instantiation + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu new file mode 100644 index 00000000000..1553089d220 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo2 K + turbo3 V + +#include 
"../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu new file mode 100644 index 00000000000..fc32d8634af --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo2 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO2_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu new file mode 100644 index 00000000000..b6f1b1ab7aa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo3 K + q8_0 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu new file mode 100644 index 00000000000..87cee19406d --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo3 K + turbo2 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu new file mode 100644 index 00000000000..1ae51643256 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// TurboQuant3 CUDA flash attention vec kernel instantiation + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu new file mode 100644 index 00000000000..4da33c5e4fe --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo3 K + turbo4 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO3_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu new file 
mode 100644 index 00000000000..c3813c17efa --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo4 K + q8_0 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_Q8_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu new file mode 100644 index 00000000000..0470c74dcfb --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo4 K + turbo2 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO2_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu new file mode 100644 index 00000000000..b103c431ee6 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// Mixed KV: turbo4 K + turbo3 V + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO3_0); diff --git a/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu new file mode 100644 index 00000000000..344ebf800a7 --- /dev/null +++ b/ggml/src/ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu @@ -0,0 +1,8 @@ +#define GGML_FATTN_TURBO +// TurboQuant4 CUDA flash attention vec kernel instantiation + +#include "../fattn-vec.cuh" + +DECL_FATTN_VEC_CASE( 64, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(128, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); +DECL_FATTN_VEC_CASE(256, GGML_TYPE_TURBO4_0, GGML_TYPE_TURBO4_0); diff --git a/ggml/src/ggml-cuda/turbo-innerq.cu b/ggml/src/ggml-cuda/turbo-innerq.cu new file mode 100644 index 00000000000..51e2130176a --- /dev/null +++ b/ggml/src/ggml-cuda/turbo-innerq.cu @@ -0,0 +1,40 @@ +#include "turbo-innerq.cuh" +#include + +// Host-side shared state for InnerQ cross-TU communication +TURBO_IQ_API bool g_innerq_finalized = false; +TURBO_IQ_API float g_innerq_scale_inv_host[INNERQ_MAX_CHANNELS] = { + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1, + 1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1 +}; + +static bool g_innerq_tensor_needs_update = false; + +TURBO_IQ_API turbo_rpn_config_t g_rpn_config = {0, 0, true}; + +TURBO_IQ_API void turbo_rpn_set_config(int rope_type, int n_rot, bool is_key) { + g_rpn_config.rope_type = rope_type; + g_rpn_config.n_rot = n_rot; + g_rpn_config.is_key = is_key; +} + +void turbo_innerq_publish(const float * scale_inv, int group_size) { + for (int i = 0; i < group_size && i < INNERQ_MAX_CHANNELS; i++) { + g_innerq_scale_inv_host[i] = scale_inv[i]; + } + for (int 
i = group_size; i < INNERQ_MAX_CHANNELS; i++) { + g_innerq_scale_inv_host[i] = 1.0f; + } + g_innerq_finalized = true; + g_innerq_tensor_needs_update = true; +} + +TURBO_IQ_API bool turbo_innerq_needs_tensor_update(void) { + return g_innerq_tensor_needs_update; +} + +TURBO_IQ_API void turbo_innerq_mark_tensor_updated(void) { + g_innerq_tensor_needs_update = false; +} diff --git a/ggml/src/ggml-cuda/turbo-innerq.cuh b/ggml/src/ggml-cuda/turbo-innerq.cuh new file mode 100644 index 00000000000..4335c9f31d3 --- /dev/null +++ b/ggml/src/ggml-cuda/turbo-innerq.cuh @@ -0,0 +1,41 @@ +#pragma once + +// TurboQuant InnerQ per-channel equalization — cross-TU shared state +// The host-side state lives in turbo-innerq.cu; device-side state is per-TU +// in turbo-quant.cuh (only set-rows.cu needs device access). + +#define INNERQ_MAX_CHANNELS 128 + +// RPN (RoPE Pair Normalization) config — set by host before finalization +typedef struct { + int rope_type; // LLAMA_ROPE_TYPE_NORM=2, NEOX=20, MROPE=21, etc. 0=unknown + int n_rot; // number of rotated dims (head_dim for 100% RoPE) + bool is_key; // true=K cache, false=V cache +} turbo_rpn_config_t; + +#if defined(_WIN32) && !defined(__MINGW32__) +# ifdef GGML_BACKEND_BUILD +# define TURBO_IQ_API __declspec(dllexport) +# else +# define TURBO_IQ_API __declspec(dllimport) +# endif +#else +# define TURBO_IQ_API __attribute__((visibility("default"))) +#endif + +// Host-side shared state (defined in turbo-innerq.cu) +TURBO_IQ_API extern bool g_innerq_finalized; +TURBO_IQ_API extern float g_innerq_scale_inv_host[INNERQ_MAX_CHANNELS]; + +// Called from set-rows.cu after InnerQ finalization to publish scale_inv +void turbo_innerq_publish(const float * scale_inv, int group_size); + +// Called from llama-kv-cache.cpp (or equivalent) to check if tensor needs update +TURBO_IQ_API bool turbo_innerq_needs_tensor_update(void); + +// Called after tensor update to clear the flag +TURBO_IQ_API void turbo_innerq_mark_tensor_updated(void); + +// RPN config — set before InnerQ calibration starts +TURBO_IQ_API extern turbo_rpn_config_t g_rpn_config; +TURBO_IQ_API void turbo_rpn_set_config(int rope_type, int n_rot, bool is_key); diff --git a/ggml/src/ggml-cuda/turbo-quant.cuh b/ggml/src/ggml-cuda/turbo-quant.cuh new file mode 100644 index 00000000000..5e7a2174ffc --- /dev/null +++ b/ggml/src/ggml-cuda/turbo-quant.cuh @@ -0,0 +1,579 @@ +/* + * TurboQuant CUDA kernels for KV cache compression + * Based on: arXiv 2504.19874 (ICLR 2026) + * + * Implements GGML_TYPE_TURBO3_0 (3-bit PolarQuant, block size 32) + * Constants, WHT rotation, quantize/dequantize device functions. 
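+ * Also covers TURBO2_0 / TURBO4_0: centroid tables, the 64- and 256-element WHT sign
+ * arrays, and the InnerQ per-channel equalization state used by set-rows.cu.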
+ */ + +#pragma once + +#include "common.cuh" +#include "turbo-innerq.cuh" +#include +#include + +// ---- Quantization ratios for dequantize_block template ---- +#define QR_TURBO3 1 // Each dequantize call produces 2 consecutive elements (like q8_0) +#define QR_TURBO2 1 // Each dequantize call produces 2 consecutive elements (like q8_0) +#define QR_TURBO4 1 // Each dequantize call produces 2 consecutive elements (like q8_0) + +// ---- 2-bit centroids (Lloyd-Max for N(0, 1/128)) ---- + +static __constant__ float TURBO_CENTROIDS_2BIT[4] = { + -0.133462f, -0.039994f, 0.039994f, 0.133462f +}; + +static __constant__ float TURBO_MID_2BIT[3] = { + -0.086728f, 0.0f, 0.086728f +}; + +// ---- 3-bit centroids (Lloyd-Max for N(0, 1/128)) ---- + +static __constant__ float TURBO_CENTROIDS_3BIT[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f +}; + +// ---- Midpoints for nearest centroid lookup ---- + +static __constant__ float TURBO_MID_3BIT[7] = { + -0.154259f, -0.091775f, -0.043589f, 0.0f, + 0.043589f, 0.091775f, 0.154259f +}; + +// ---- WHT sign arrays (seed=42) ---- + +static __constant__ float TURBO_WHT_SIGNS1[128] = { + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, + -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f +}; + +static __constant__ float TURBO_WHT_SIGNS2[128] = { + 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, + -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f +}; + +// ---- 64-element WHT sign arrays (first 64 of the 128-element arrays) ---- + +static __constant__ float TURBO_WHT_SIGNS1_64[64] = { + -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, + 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f +}; + +static __constant__ float TURBO_WHT_SIGNS2_64[64] = 
{ + 1.0f, 1.0f, 1.0f, 1.0f, -1.0f, 1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, + 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, + 1.0f, -1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, -1.0f, 1.0f, 1.0f +}; + +// ---- Sign arrays for GROUP_SIZE=256 (Gemma 4 D=256 SWA, D=512 as 2×256 blocks) ---- +// Generated with deterministic seeds (256001, 256002) +static __constant__ float TURBO_WHT_SIGNS1_256[256] = { + -1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, 1.f, + -1.f, -1.f, 1.f, 1.f, 1.f, -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, + 1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, + -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, + 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, 1.f, + 1.f, 1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, -1.f, -1.f, + 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, 1.f, -1.f, 1.f, + -1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, + -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, + -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, + 1.f, -1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, + 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, +}; +static __constant__ float TURBO_WHT_SIGNS2_256[256] = { + 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, + 1.f, 1.f, -1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, -1.f, -1.f, + 1.f, -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, -1.f, 1.f, 1.f, -1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, 1.f, -1.f, + -1.f, 1.f, -1.f, -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, + -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, + -1.f, 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, -1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, + 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, + -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, + 1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, -1.f, -1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 1.f, -1.f, 1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, + 1.f, -1.f, -1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, + 1.f, -1.f, 1.f, -1.f, 1.f, 1.f, -1.f, -1.f, 1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, -1.f, + -1.f, -1.f, -1.f, -1.f, -1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, 1.f, -1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, -1.f, -1.f, 1.f, -1.f, 1.f, 1.f, 1.f, 1.f, 1.f, -1.f, 
+};
+
+// ---- Fast Walsh-Hadamard Transform (in-place, normalized) ----
+// O(n log n) = 896 ops for n=128
+
+static __device__ __forceinline__ void turbo_fwht_128(float * x) {
+    for (int h = 1; h < 128; h *= 2) {
+        for (int i = 0; i < 128; i += h * 2) {
+            for (int j = i; j < i + h; j++) {
+                float a = x[j];
+                float b = x[j + h];
+                x[j]     = a + b;
+                x[j + h] = a - b;
+            }
+        }
+    }
+    const float inv_sqrt_128 = 0.08838834764831845f;
+    for (int i = 0; i < 128; i++) {
+        x[i] *= inv_sqrt_128;
+    }
+}
+
+// ---- Fast Walsh-Hadamard Transform for 64-element groups ----
+// O(n log n) = 384 ops for n=64
+
+static __device__ __forceinline__ void turbo_fwht_64(float * x) {
+    for (int h = 1; h < 64; h *= 2) {
+        for (int i = 0; i < 64; i += h * 2) {
+            for (int j = i; j < i + h; j++) {
+                float a = x[j];
+                float b = x[j + h];
+                x[j]     = a + b;
+                x[j + h] = a - b;
+            }
+        }
+    }
+    const float inv_sqrt_64 = 0.125f;
+    for (int i = 0; i < 64; i++) {
+        x[i] *= inv_sqrt_64;
+    }
+}
+
+// ---- Forward rotation: signs1 → FWHT → signs2 ----
+
+static __device__ __forceinline__ void turbo_rotate_forward(float * x) {
+    for (int i = 0; i < 128; i++) x[i] *= TURBO_WHT_SIGNS1[i];
+    turbo_fwht_128(x);
+    for (int i = 0; i < 128; i++) x[i] *= TURBO_WHT_SIGNS2[i];
+}
+
+// ---- Forward rotation for 64-element groups ----
+
+static __device__ __forceinline__ void turbo_rotate_forward_64(float * x) {
+    for (int i = 0; i < 64; i++) x[i] *= TURBO_WHT_SIGNS1_64[i];
+    turbo_fwht_64(x);
+    for (int i = 0; i < 64; i++) x[i] *= TURBO_WHT_SIGNS2_64[i];
+}
+
+// ---- InnerQ per-channel equalization ----
+// Equalizes K channel variances before WHT rotation to reduce quantization error.
+// Enabled via TURBO_INNERQ=N env var (N = calibration token count).
+// Math: <K * s, Q / s> = <K, Q>, so scaling K channels by s and Q by 1/s leaves attention dot products unchanged.
+// INNERQ_MAX_CHANNELS is defined in turbo-innerq.cuh
+
+static __device__ float d_innerq_scale[INNERQ_MAX_CHANNELS];
+static __device__ float d_innerq_scale_inv[INNERQ_MAX_CHANNELS];
+static __device__ float d_innerq_sq_accum[INNERQ_MAX_CHANNELS];
+static __device__ int   d_innerq_count;
+static __device__ int   d_innerq_active;      // 0 = scales are identity, 1 = scales applied
+static __device__ int   d_innerq_calibrating; // 1 = accumulating K² stats
+
+static int   innerq_enabled = 0; // host: 0=off, 1=calibrating, 2=active
+static int   innerq_target_tokens = 0;
+static float innerq_strength = 0.15f;
+static bool  innerq_initialized = false;
+
+// Host: read TURBO_INNERQ env, start calibration if enabled
+static void turbo_innerq_init(void) {
+    if (innerq_initialized) return;
+    innerq_initialized = true;
+
+    const char * env = getenv("TURBO_INNERQ");
+    if (!env || atoi(env) <= 0) {
+        innerq_enabled = 0;
+        return;
+    }
+    innerq_target_tokens = atoi(env);
+    innerq_enabled = 1; // calibrating
+
+    const char * env_str = getenv("TURBO_INNERQ_STRENGTH");
+    if (env_str) innerq_strength = atof(env_str);
+    if (innerq_strength <= 0.0f || innerq_strength > 1.0f) innerq_strength = 0.15f;
+
+    // Zero accumulators and set calibrating flag on device
+    float zeros[INNERQ_MAX_CHANNELS] = {0};
+    int zero = 0, one = 1;
+    (void)cudaMemcpyToSymbol(d_innerq_sq_accum, zeros, sizeof(zeros));
+    (void)cudaMemcpyToSymbol(d_innerq_count, &zero, sizeof(int));
+    (void)cudaMemcpyToSymbol(d_innerq_active, &zero, sizeof(int));
+    (void)cudaMemcpyToSymbol(d_innerq_calibrating, &one, sizeof(int));
+
+    GGML_LOG_INFO("%s: InnerQ calibration started (target=%d tokens, strength=%.2f)\n",
+                  __func__, innerq_target_tokens, innerq_strength);
+}
+
+// Host: finalize calibration — compute scales, upload,
activate +static void turbo_innerq_finalize(int group_size) { + // Read accumulators from device + float sq_accum[INNERQ_MAX_CHANNELS]; + int count = 0; + (void)cudaMemcpyFromSymbol(sq_accum, d_innerq_sq_accum, group_size * sizeof(float)); + (void)cudaMemcpyFromSymbol(&count, d_innerq_count, sizeof(int)); + + if (count <= 0) { + GGML_LOG_WARN("%s: InnerQ calibration got 0 tokens, disabling\n", __func__); + innerq_enabled = 0; + int zero = 0; + (void)cudaMemcpyToSymbol(d_innerq_calibrating, &zero, sizeof(int)); + return; + } + + // Compute per-channel RMS + float rms[INNERQ_MAX_CHANNELS]; + float mean_rms = 0.0f; + float max_ratio = 0.0f, min_ratio = 1e30f; + for (int i = 0; i < group_size; i++) { + rms[i] = sqrtf(sq_accum[i] / (float)count); + mean_rms += rms[i]; + } + mean_rms /= (float)group_size; + + // RPN: merge RoPE pairs to shared RMS before computing scales. + // RoPE rotates pairs together — independent scales would deform + // the pair circle into an ellipse, causing phase-dependent distortion. + // Pair mapping depends on rope_type: + // NORM: (2i, 2i+1) + // NEOX/MROPE/IMROPE: (i, i + n_rot/2) for i < n_rot/2 + // Only applied to K cache (V is not RoPE'd). + if (g_rpn_config.is_key && g_rpn_config.n_rot >= 2) { + const int n_rot = (g_rpn_config.n_rot <= group_size) ? g_rpn_config.n_rot : group_size; + + // Auto-disable InnerQ for partial-RoPE models (n_rot < group_size). + // InnerQ adds noise to non-RoPE dimensions without benefit. + if (n_rot < group_size) { + GGML_LOG_INFO("%s: partial RoPE detected (n_rot=%d < group=%d), disabling InnerQ\n", + __func__, n_rot, group_size); + innerq_enabled = 0; + int zero = 0; + (void)cudaMemcpyToSymbol(d_innerq_calibrating, &zero, sizeof(int)); + // Set scale_inv to 1.0 (no equalization) + for (int i = 0; i < group_size; i++) { + g_innerq_scale_inv_host[i] = 1.0f; + } + g_innerq_finalized = true; + return; + } + + const int rtype = g_rpn_config.rope_type; + // LLAMA_ROPE_TYPE_NORM == 2 + const bool is_norm = (rtype == 2); + if (is_norm) { + // NORM: pairs are (2i, 2i+1) + for (int i = 0; i + 1 < n_rot; i += 2) { + float pr = sqrtf(0.5f * (rms[i]*rms[i] + rms[i+1]*rms[i+1])); + rms[i] = pr; rms[i+1] = pr; + } + } else { + // NEOX/MROPE/IMROPE: pairs are (i, i + n_rot/2) + const int half = n_rot / 2; + for (int i = 0; i < half; i++) { + float pr = sqrtf(0.5f * (rms[i]*rms[i] + rms[i+half]*rms[i+half])); + rms[i] = pr; rms[i+half] = pr; + } + } + GGML_LOG_INFO("%s: RPN pair merge applied (rope_type=%d, n_rot=%d, mode=%s)\n", + __func__, rtype, n_rot, is_norm ? "norm" : "neox"); + } + + // Compute scale[i] = (mean_rms / channel_rms[i])^strength, clamp to [0.5, 2.0] + float scale[INNERQ_MAX_CHANNELS]; + float scale_inv[INNERQ_MAX_CHANNELS]; + for (int i = 0; i < group_size; i++) { + float ratio = (rms[i] > 1e-10f) ? 
(mean_rms / rms[i]) : 1.0f; + float s = powf(ratio, innerq_strength); + if (s < 0.5f) s = 0.5f; + if (s > 2.0f) s = 2.0f; + scale[i] = s; + scale_inv[i] = 1.0f / s; + if (ratio > max_ratio) max_ratio = ratio; + if (ratio < min_ratio) min_ratio = ratio; + } + + // Auto-skip if max channel ratio < 1.2 (already balanced) + if (max_ratio < 1.2f && min_ratio > (1.0f / 1.2f)) { + GGML_LOG_INFO("%s: InnerQ auto-disabled (channels already balanced, max_ratio=%.3f)\n", + __func__, max_ratio); + innerq_enabled = 0; + int zero = 0; + (void)cudaMemcpyToSymbol(d_innerq_calibrating, &zero, sizeof(int)); + return; + } + + // Stop calibrating, upload scales, activate + int zero = 0, one = 1; + (void)cudaMemcpyToSymbol(d_innerq_calibrating, &zero, sizeof(int)); + (void)cudaMemcpyToSymbol(d_innerq_scale, scale, group_size * sizeof(float)); + (void)cudaMemcpyToSymbol(d_innerq_scale_inv, scale_inv, group_size * sizeof(float)); + cudaDeviceSynchronize(); // ensure scales are visible before activating + (void)cudaMemcpyToSymbol(d_innerq_active, &one, sizeof(int)); + + innerq_enabled = 2; // active + + // Publish scale_inv to shared host state for cross-TU tensor update + turbo_innerq_publish(scale_inv, group_size); + + GGML_LOG_INFO("%s: InnerQ finalized (%d tokens, max_ratio=%.3f, min_ratio=%.3f)\n", + __func__, count, max_ratio, min_ratio); +} + +// Host: called before each set_rows kernel launch +static void turbo_innerq_check_finalize(int group_size, int64_t ne00) { + if (!innerq_initialized) { + turbo_innerq_init(); + } + if (innerq_enabled == 0) return; + + // InnerQ only works when each WHT group = one head (group_size == head_dim). + // For standard models: ne00 = n_heads * head_dim, group_size = head_dim → ne00 % group_size == 0, fine. + // For non-standard models (head_dim > group_size, e.g. GLM 576 → 64-group): + // ne00 = head_dim (single head), group_size = 64, ne00/group_size = 9 groups per head → WRONG. + // Detect: if ne00 / group_size doesn't divide evenly into standard head counts (1,2,4,8,16,32,64,128), + // it's likely multi-group-per-head. Simpler check: group_size < 128 means head_dim > 128. + // InnerQ only works when group_size == INNERQ_MAX_CHANNELS (128). + // Reject group_size > 128 (Gemma 4 D=256) and < 128 (multi-group-per-head). + // Codex review P0: group_size=256 overflows 128-entry device symbols and stack arrays. 
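+    // Illustrative outcomes of the rules above:
+    //   head_dim = 128 (standard)                    -> group_size = 128 -> InnerQ allowed
+    //   head_dim = 576 (GLM, split into 64-groups)   -> group_size = 64  -> disabled below
+    //   head_dim = 256 (Gemma 4 SWA, 256-groups)     -> group_size = 256 -> disabled below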
+ const bool incompatible_group = (group_size != INNERQ_MAX_CHANNELS); + if (incompatible_group) { + if (innerq_enabled >= 1) { + GGML_LOG_WARN("%s: InnerQ disabled (group_size=%d != %d, incompatible)\n", + __func__, group_size, INNERQ_MAX_CHANNELS); + innerq_enabled = 0; + int zero = 0; + (void)cudaMemcpyToSymbol(d_innerq_calibrating, &zero, sizeof(int)); + (void)cudaMemcpyToSymbol(d_innerq_active, &zero, sizeof(int)); + } + return; + } + + // Check if calibration is complete + if (innerq_enabled == 1) { + int count = 0; + (void)cudaMemcpyFromSymbol(&count, d_innerq_count, sizeof(int)); + if (count >= innerq_target_tokens) { + turbo_innerq_finalize(group_size); + } + } +} + +// Host: check if InnerQ is currently active (finalized) +static bool turbo_innerq_is_active(void) { + return innerq_enabled == 2; +} + +// ---- 4-bit centroids (Lloyd-Max for N(0, 1/128)) ---- + +static __constant__ float TURBO_CENTROIDS_4BIT[16] = { + -0.173926f, -0.117195f, -0.089527f, -0.068756f, + -0.051262f, -0.035597f, -0.020989f, -0.006938f, + 0.006938f, 0.020989f, 0.035597f, 0.051262f, + 0.068756f, 0.089527f, 0.117195f, 0.173926f +}; + +// ---- Midpoints for nearest 4-bit centroid lookup ---- + +static __constant__ float TURBO_MID_4BIT[15] = { + -0.145561f, -0.103361f, -0.079142f, -0.060009f, + -0.043430f, -0.028293f, -0.013964f, 0.000000f, + 0.013964f, 0.028293f, 0.043430f, 0.060009f, + 0.079142f, 0.103361f, 0.145561f +}; + +// ---- Nearest 4-bit centroid index ---- + +static __device__ __forceinline__ uint8_t turbo_nearest_centroid_4bit(float val) { + if (val < TURBO_MID_4BIT[ 0]) return 0; + else if (val < TURBO_MID_4BIT[ 1]) return 1; + else if (val < TURBO_MID_4BIT[ 2]) return 2; + else if (val < TURBO_MID_4BIT[ 3]) return 3; + else if (val < TURBO_MID_4BIT[ 4]) return 4; + else if (val < TURBO_MID_4BIT[ 5]) return 5; + else if (val < TURBO_MID_4BIT[ 6]) return 6; + else if (val < TURBO_MID_4BIT[ 7]) return 7; + else if (val < TURBO_MID_4BIT[ 8]) return 8; + else if (val < TURBO_MID_4BIT[ 9]) return 9; + else if (val < TURBO_MID_4BIT[10]) return 10; + else if (val < TURBO_MID_4BIT[11]) return 11; + else if (val < TURBO_MID_4BIT[12]) return 12; + else if (val < TURBO_MID_4BIT[13]) return 13; + else if (val < TURBO_MID_4BIT[14]) return 14; + else return 15; +} + +// ---- Per-block quantize for turbo4 (128 elements, expects already-rotated input) ---- + +static __device__ void quantize_f32_turbo4_0_block(const float * __restrict__ src, + block_turbo4_0 * __restrict__ dst) { + for (int j = 0; j < QK_TURBO4 / 2; j++) dst->qs[j] = 0; + + for (int j = 0; j < QK_TURBO4; j++) { + uint8_t idx = turbo_nearest_centroid_4bit(src[j]); + dst->qs[j / 2] |= (idx & 0xF) << ((j % 2) * 4); + } +} + +// ---- Inline dequant helper: extract one float from turbo4 block ---- + +static __device__ __forceinline__ float turbo4_dequant_element( + const block_turbo4_0 * __restrict__ x, int j, float norm) { + uint8_t idx = (x->qs[j / 2] >> ((j % 2) * 4)) & 0xF; + return TURBO_CENTROIDS_4BIT[idx] * norm; +} + +// ---- Nearest 3-bit centroid index ---- + +static __device__ __forceinline__ uint8_t turbo_nearest_centroid_3bit(float val) { + if (val < TURBO_MID_3BIT[0]) return 0; + else if (val < TURBO_MID_3BIT[1]) return 1; + else if (val < TURBO_MID_3BIT[2]) return 2; + else if (val < TURBO_MID_3BIT[3]) return 3; + else if (val < TURBO_MID_3BIT[4]) return 4; + else if (val < TURBO_MID_3BIT[5]) return 5; + else if (val < TURBO_MID_3BIT[6]) return 6; + else return 7; +} + +// ---- Per-block quantize (32 elements, expects already-rotated 
input) ---- +// Used by set_rows after group-level WHT rotation + +static __device__ void quantize_f32_turbo3_0_block(const float * __restrict__ src, + block_turbo3_0 * __restrict__ dst) { + for (int j = 0; j < QK_TURBO3 / 4; j++) dst->qs[j] = 0; + for (int j = 0; j < QK_TURBO3 / 8; j++) dst->signs[j] = 0; + + for (int j = 0; j < QK_TURBO3; j++) { + uint8_t idx = turbo_nearest_centroid_3bit(src[j]); + dst->qs[j / 4] |= (idx & 0x3) << ((j % 4) * 2); + if (idx & 0x4) { + dst->signs[j / 8] |= (1 << (j % 8)); + } + } +} + +// ---- Inline dequant helper: extract one float from turbo3 block ---- + +static __device__ __forceinline__ float turbo3_dequant_element( + const block_turbo3_0 * __restrict__ x, int j, float norm) { + uint8_t low2 = (x->qs[j / 4] >> ((j % 4) * 2)) & 0x3; + uint8_t hi1 = (x->signs[j / 8] >> (j % 8)) & 0x1; + uint8_t idx = low2 | (hi1 << 2); + return TURBO_CENTROIDS_3BIT[idx] * norm; +} + +// ---- Nearest 2-bit centroid index ---- + +static __device__ __forceinline__ uint8_t turbo_nearest_centroid_2bit(float val) { + if (val < TURBO_MID_2BIT[0]) return 0; + else if (val < TURBO_MID_2BIT[1]) return 1; + else if (val < TURBO_MID_2BIT[2]) return 2; + else return 3; +} + +// ---- Per-block quantize for turbo2 (32 elements, expects already-rotated input) ---- + +static __device__ void quantize_f32_turbo2_0_block(const float * __restrict__ src, + block_turbo2_0 * __restrict__ dst) { + for (int j = 0; j < QK_TURBO2 / 4; j++) dst->qs[j] = 0; + + for (int j = 0; j < QK_TURBO2; j++) { + uint8_t idx = turbo_nearest_centroid_2bit(src[j]); + dst->qs[j / 4] |= (idx & 0x3) << ((j % 4) * 2); + } +} + +// ---- Inline dequant helper: extract one float from turbo2 block ---- + +static __device__ __forceinline__ float turbo2_dequant_element( + const block_turbo2_0 * __restrict__ x, int j, float norm) { + uint8_t idx = (x->qs[j / 4] >> ((j % 4) * 2)) & 0x3; + return TURBO_CENTROIDS_2BIT[idx] * norm; +} + +// ============================================================================ +// Weight compression types (TQ3_1S, TQ4_1S) +// These use N(0,1) centroids (NOT N(0,1/128) like KV cache types) +// and require inverse WHT (RHT) after centroid lookup. 
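+// Block layout (see quantize_row_tq3_1s_ref / quantize_row_tq4_1s_ref in ggml-turbo-quant.c):
+// each 32-element block stores two fp16 scales (d0 for elements 0..15, d1 for 16..31)
+// plus packed centroid indices (3-bit fields for TQ3_1S, nibbles for TQ4_1S).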
+// ============================================================================ + +#define QR_TQ4_1S 1 // dequantize produces 2 consecutive elements +#define QR_TQ3_1S 1 + +// ---- Weight centroids: Lloyd-Max for N(0,1) ---- + +static __constant__ float TQ4_CENTROIDS_WEIGHT[16] = { + -2.732590f, -2.069017f, -1.618046f, -1.256231f, + -0.942340f, -0.656759f, -0.388048f, -0.128395f, + 0.128395f, 0.388048f, 0.656759f, 0.942340f, + 1.256231f, 1.618046f, 2.069017f, 2.732590f +}; + +static __constant__ float TQ3_CENTROIDS_WEIGHT[8] = { + -1.996684f, -1.291398f, -0.740341f, -0.247508f, + 0.230106f, 0.725222f, 1.277503f, 1.988943f +}; + +// ---- Sign array for weight WHT (golden ratio hash, 32 elements) ---- + +static __constant__ float TQ_WEIGHT_SIGNS[32] = { + +1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f +}; + +// 2D VQ codebook: 64 entries, trained on actual Qwen3-8B WHT output pairs +static __constant__ float TURBO_VQ2D_X[64] = { + 0.0279071f, -0.1041781f, -0.0497183f, 0.0836585f, 0.0755566f, -0.1593080f, -0.0472192f, 0.1499346f, + -0.0259202f, -0.0749334f, -0.1060147f, -0.1302685f, 0.0510575f, 0.0321239f, 0.0427720f, 0.2017132f, + -0.0174130f, 0.0938271f, 0.1514418f, -0.1524931f, -0.0659325f, -0.1347785f, 0.1569419f, 0.0335782f, + 0.2139767f, 0.0298571f, 0.1024047f, -0.1463255f, -0.0380896f, -0.1880937f, 0.1287539f, -0.0810642f, + -0.0230893f, -0.0325119f, -0.0495625f, 0.0664514f, 0.1864402f, 0.0794077f, -0.2225531f, 0.0198063f, + -0.0478895f, 0.1485750f, 0.0846328f, 0.0470138f, 0.0562434f, -0.1950971f, 0.0961574f, -0.0095595f, + -0.0900242f, -0.0080224f, -0.0094565f, -0.1106773f, -0.0637866f, -0.1312685f, 0.0118203f, 0.0150917f, + 0.1209811f, -0.0833506f, -0.1212273f, 0.0995258f, -0.0725997f, 0.1161496f, 0.0609390f, -0.0160979f, +}; +static __constant__ float TURBO_VQ2D_Y[64] = { + -0.0263300f, 0.0685406f, -0.1090837f, 0.1035094f, -0.0896168f, -0.0125089f, -0.0671406f, -0.0187005f, + 0.0717508f, 0.1467829f, -0.0184862f, -0.1144251f, 0.0793044f, -0.1656622f, 0.1358503f, 0.0923961f, + -0.0055588f, 0.0639664f, -0.1557487f, 0.0507863f, 0.0079050f, 0.1759942f, 0.1642957f, -0.1138678f, + 0.0008668f, -0.0694018f, 0.0207315f, -0.1742128f, -0.2115104f, 0.1064816f, 0.1005936f, -0.1476948f, + 0.0305579f, 0.1157162f, -0.0274784f, -0.0479926f, -0.0830491f, -0.2145961f, 0.0121470f, 0.0135110f, + 0.2169799f, 0.0451792f, -0.1392938f, 0.2095770f, 0.0380025f, -0.0849720f, 0.1537713f, -0.0449162f, + -0.0923609f, 0.1603929f, -0.0886196f, 0.0234268f, 0.0493861f, -0.0623466f, 0.1004377f, 0.0549886f, + -0.1051118f, -0.0522899f, 0.1113740f, -0.0216251f, 0.0940127f, -0.0629645f, -0.0015928f, -0.1436934f, +}; + +// 2D VQ dequant pair helper +static __device__ __forceinline__ float2 turbo3_dequant_pair( + const block_turbo3_0 * __restrict__ x, int j_even, float norm) { + uint8_t low2_0 = (x->qs[j_even / 4] >> ((j_even % 4) * 2)) & 0x3; + uint8_t hi1_0 = (x->signs[j_even / 8] >> (j_even % 8)) & 0x1; + uint8_t idx0 = low2_0 | (hi1_0 << 2); + int j_odd = j_even + 1; + uint8_t low2_1 = (x->qs[j_odd / 4] >> ((j_odd % 4) * 2)) & 0x3; + uint8_t hi1_1 = (x->signs[j_odd / 8] >> (j_odd % 8)) & 0x1; + uint8_t idx1 = low2_1 | (hi1_1 << 2); + uint8_t vq = (idx0 << 3) | idx1; + return make_float2(TURBO_VQ2D_X[vq] * norm, TURBO_VQ2D_Y[vq] * norm); +} diff --git a/ggml/src/ggml-cuda/turbo-wht.cu b/ggml/src/ggml-cuda/turbo-wht.cu new file mode 100644 
index 00000000000..6359ae76ba7
--- /dev/null
+++ b/ggml/src/ggml-cuda/turbo-wht.cu
@@ -0,0 +1,203 @@
+#include "turbo-quant.cuh"
+#include "turbo-wht.cuh"
+
+// ─── CUDA kernel ──────────────────────────────────────────────────────────────
+//
+// Templated on direction and group_size (32, 64, 128, or 256).
+// One block per group, group_size threads per block.
+// direction: 0 = forward (signs1 → WHT → signs2), 1 = inverse (signs2 → WHT → signs1)
+//
+// When head_dim is not a multiple of group_size, only the full groups
+// within each head are processed. Tail elements are left unchanged (identity).
+//
+// Algorithm mirrors the CPU implementation in ggml-cpu/ops.cpp:
+//   1. Apply s_first elementwise
+//   2. Radix-2 Hadamard butterfly (log2(group_size) stages, in-place)
+//   3. Normalize by 1/sqrt(group_size) and apply s_second elementwise
+//
+// InnerQ scale_inv: when non-null, applies per-channel inverse scaling for
+// Q/V equalization. For forward (Q rotation): multiply BEFORE signs+WHT.
+// For inverse (V un-rotation): multiply AFTER WHT+signs.
+
+template <int direction, int group_size>
+static __global__ void k_turbo_wht_f32(const float * __restrict__ src,
+                                       float * __restrict__ dst,
+                                       const float * __restrict__ scale_inv,
+                                       int64_t n_groups,
+                                       int64_t head_dim,
+                                       int64_t groups_per_head) {
+    static_assert(group_size == 256 || group_size == 128 || group_size == 64 || group_size == 32, "group_size must be 32, 64, 128, or 256");
+
+    const int64_t g = blockIdx.x;
+    if (g >= n_groups) return;
+
+    const int t = threadIdx.x; // 0 .. group_size-1
+
+    // Map group index to position in the tensor:
+    // each head has groups_per_head full groups, then a gap of tail elements.
+    const int64_t head_idx    = g / groups_per_head;
+    const int64_t grp_in_head = g % groups_per_head;
+    const int64_t base        = head_idx * head_dim + grp_in_head * group_size;
+
+    __shared__ float x[group_size];
+
+    // Load from global memory
+    x[t] = src[base + t];
+    __syncthreads();
+
+    // InnerQ forward: apply scale_inv BEFORE signs+WHT (for Q pre-rotation)
+    if (direction == 0 && scale_inv != nullptr) {
+        x[t] *= scale_inv[t % group_size];
+        __syncthreads();
+    }
+
+    // Apply first sign array
+    if (group_size == 256) {
+        x[t] *= (direction == 0) ? TURBO_WHT_SIGNS1_256[t] : TURBO_WHT_SIGNS2_256[t];
+    } else if (group_size == 128) {
+        x[t] *= (direction == 0) ? TURBO_WHT_SIGNS1[t] : TURBO_WHT_SIGNS2[t];
+    } else if (group_size == 64) {
+        x[t] *= (direction == 0) ? TURBO_WHT_SIGNS1_64[t] : TURBO_WHT_SIGNS2_64[t];
+    } else {
+        // group_size == 32: TQ weight signs (same for forward and inverse)
+        x[t] *= TQ_WEIGHT_SIGNS[t];
+    }
+    __syncthreads();
+
+    // WHT butterfly — log2(group_size) stages.
+    // In stage h, threads where (t % (2h)) < h read x[t] and x[t+h],
+    // then write x[t] = a+b and x[t+h] = a-b. Each active thread
+    // owns a disjoint pair, so no intra-stage conflicts exist.
+#define WHT_STAGE(h) \
+    if (t % (2*(h)) < (h)) { float a = x[t], b = x[t+(h)]; x[t] = a+b; x[t+(h)] = a-b; } \
+    __syncthreads();
+
+    WHT_STAGE(1)
+    WHT_STAGE(2)
+    WHT_STAGE(4)
+    WHT_STAGE(8)
+    WHT_STAGE(16)
+    if (group_size >= 64)  { WHT_STAGE(32) }
+    if (group_size >= 128) { WHT_STAGE(64) }
+    if (group_size >= 256) { WHT_STAGE(128) }
+#undef WHT_STAGE
+
+    // Normalize and apply second sign array, write to output
+    constexpr float inv_sqrt = (group_size == 256) ? 0.0625f :
+                               (group_size == 128) ? 0.08838834764831845f :
+                               (group_size == 64)  ? 0.125f :
+                                                     0.17677669529663688f; // 1/sqrt(32)
+    float result;
+    if (group_size == 256) {
+        result = x[t] * inv_sqrt *
+                 ((direction == 0) ?
TURBO_WHT_SIGNS2_256[t] : TURBO_WHT_SIGNS1_256[t]); + } else if (group_size == 128) { + result = x[t] * inv_sqrt * + ((direction == 0) ? TURBO_WHT_SIGNS2[t] : TURBO_WHT_SIGNS1[t]); + } else if (group_size == 64) { + result = x[t] * inv_sqrt * + ((direction == 0) ? TURBO_WHT_SIGNS2_64[t] : TURBO_WHT_SIGNS1_64[t]); + } else { + // group_size == 32: normalize only (signs already applied before butterfly) + result = x[t] * inv_sqrt; + } + + // InnerQ inverse: apply scale_inv AFTER WHT+signs (for V un-rotation) + if (direction == 1 && scale_inv != nullptr) { + result *= scale_inv[t % group_size]; + } + + dst[base + t] = result; +} + +// ─── Simple copy kernel for tail elements (identity pass-through) ──────────── + +static __global__ void k_turbo_wht_copy_tail(const float * __restrict__ src, + float * __restrict__ dst, + int64_t n_heads, + int64_t head_dim, + int64_t tail_offset, + int tail_size) { + const int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_heads * tail_size) return; + + const int64_t head_idx = i / tail_size; + const int64_t tail_elem = i % tail_size; + const int64_t offset = head_idx * head_dim + tail_offset + tail_elem; + dst[offset] = src[offset]; +} + +// ─── Dispatch ───────────────────────────────────────────────────────────────── + +void ggml_cuda_turbo_wht(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src = dst->src[0]; + const ggml_tensor * scale_tensor = dst->src[1]; // InnerQ scale_inv (may be NULL) + + GGML_ASSERT(src->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + int direction; + int group_size; + memcpy(&direction, dst->op_params + 0, sizeof(int)); + memcpy(&group_size, dst->op_params + sizeof(int), sizeof(int)); + + const int64_t head_dim = src->ne[0]; + const int64_t n_heads = ggml_nelements(src) / head_dim; + + GGML_ASSERT(group_size == 32 || group_size == 64 || group_size == 128 || group_size == 256); + const int64_t groups_per_head = head_dim / group_size; + const int tail_size = (int)(head_dim % group_size); + const int64_t n_groups = groups_per_head * n_heads; + + const float * src_ptr = (const float *) src->data; + float * dst_ptr = (float *) dst->data; + const float * scale_inv_ptr = scale_tensor ? 
(const float *) scale_tensor->data : nullptr;
+
+    cudaStream_t stream = ctx.stream();
+
+    // Process full groups
+    if (n_groups > 0) {
+        dim3 blocks(n_groups);
+        if (group_size == 256) {
+            dim3 threads(256);
+            if (direction == 0) {
+                k_turbo_wht_f32<0, 256><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            } else {
+                k_turbo_wht_f32<1, 256><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            }
+        } else if (group_size == 128) {
+            dim3 threads(128);
+            if (direction == 0) {
+                k_turbo_wht_f32<0, 128><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            } else {
+                k_turbo_wht_f32<1, 128><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            }
+        } else if (group_size == 64) {
+            dim3 threads(64);
+            if (direction == 0) {
+                k_turbo_wht_f32<0, 64><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            } else {
+                k_turbo_wht_f32<1, 64><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            }
+        } else {
+            dim3 threads(32);
+            if (direction == 0) {
+                k_turbo_wht_f32<0, 32><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            } else {
+                k_turbo_wht_f32<1, 32><<<blocks, threads, 0, stream>>>(src_ptr, dst_ptr, scale_inv_ptr, n_groups, head_dim, groups_per_head);
+            }
+        }
+    }
+
+    // Pass through tail elements unchanged (no rotation)
+    // Not needed for 64-aligned dims but kept for completeness
+    if (tail_size > 0) {
+        const int64_t total_tail = n_heads * tail_size;
+        const int block_sz = 256;
+        const int n_blocks = (int)((total_tail + block_sz - 1) / block_sz);
+        k_turbo_wht_copy_tail<<<n_blocks, block_sz, 0, stream>>>(
+            src_ptr, dst_ptr, n_heads, head_dim, groups_per_head * group_size, tail_size);
+    }
+}
diff --git a/ggml/src/ggml-cuda/turbo-wht.cuh b/ggml/src/ggml-cuda/turbo-wht.cuh
new file mode 100644
index 00000000000..3038a1ab082
--- /dev/null
+++ b/ggml/src/ggml-cuda/turbo-wht.cuh
@@ -0,0 +1,5 @@
+#pragma once
+
+#include "common.cuh"
+
+void ggml_cuda_turbo_wht(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt
index a7d4e0ea2b5..7ce95ab8a02 100644
--- a/ggml/src/ggml-hip/CMakeLists.txt
+++ b/ggml/src/ggml-hip/CMakeLists.txt
@@ -62,6 +62,8 @@ list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
 file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
 file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
+# Exclude D>=576 tile kernels: exceed HIP local memory limit (67584 > 65536 bytes)
+list(FILTER SRCS EXCLUDE REGEX "dkq(576|640)")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
 file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 list(APPEND GGML_SOURCES_ROCM ${SRCS})
@@ -79,7 +81,22 @@ else()
     ../ggml-cuda/template-instances/fattn-vec-instance-f16-f16.cu
     ../ggml-cuda/template-instances/fattn-vec-instance-q4_0-q4_0.cu
     ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-q8_0.cu
-    ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu)
+    ../ggml-cuda/template-instances/fattn-vec-instance-bf16-bf16.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
+    ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
+    
../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-q8_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-q8_0-turbo4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu + ../ggml-cuda/template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu) endif() ggml_add_backend_library(ggml-hip diff --git a/ggml/src/ggml-turbo-quant.c b/ggml/src/ggml-turbo-quant.c new file mode 100644 index 00000000000..a4972f58cd0 --- /dev/null +++ b/ggml/src/ggml-turbo-quant.c @@ -0,0 +1,1035 @@ +/* + * TurboQuant: KV cache compression via PolarQuant + QJL + * Based on: arXiv 2504.19874 (ICLR 2026) + * + * Implements GGML_TYPE_TURBO2_0 (2-bit), GGML_TYPE_TURBO3_0 (3-bit) and + * GGML_TYPE_TURBO4_0 (4-bit) for use as --cache-type-k turboN in llama-server. + */ + +#include "ggml-quants.h" +#include "ggml-common.h" +#include "ggml-impl.h" + +#define _USE_MATH_DEFINES +#include +#include +#include +#include + +#ifndef M_PI +#define M_PI 3.14159265358979323846 +#endif + +/* Global: WHT group size for CPU quantize path (set by CPU SET_ROWS handler) */ +GGML_API int turbo3_cpu_wht_group_size = 0; + +/* ---------- constants ---------- */ + +#define TURBO_SEED_ROTATION 42 +#define TURBO_SEED_QJL 1042 +#define TURBO_D 128 /* rotation group size = head_dim (independent of block size) */ +#define TURBO_QJL_CONST 1.2533141373155003f /* sqrt(pi/2) */ + +/* Optimal centroids from paper (scaled by 1/sqrt(d)) */ +/* 2-bit: {±0.453, ±1.51} / sqrt(d) */ +static const float CENTROIDS_2BIT[4] = { -0.133462f, -0.039994f, 0.039994f, 0.133462f }; + +/* 3-bit: Lloyd-Max for N(0, 1/128), pre-computed (legacy scalar, used by turbo4) */ +static const float CENTROIDS_3BIT[8] = { + -0.190685f, -0.117832f, -0.065717f, -0.021460f, + 0.021460f, 0.065717f, 0.117832f, 0.190685f +}; + +/* 2D VQ codebook (64 entries, K-means trained on 74K WHT output pairs) */ +static const float TURBO_VQ2D_X[64] = { + 0.0279071f, -0.1041781f, -0.0497183f, 0.0836585f, 0.0755566f, -0.1593080f, -0.0472192f, 0.1499346f, + -0.0259202f, -0.0749334f, -0.1060147f, -0.1302685f, 0.0510575f, 0.0321239f, 0.0427720f, 0.2017132f, + -0.0174130f, 0.0938271f, 0.1514418f, -0.1524931f, -0.0659325f, -0.1347785f, 0.1569419f, 0.0335782f, + 0.2139767f, 0.0298571f, 0.1024047f, -0.1463255f, -0.0380896f, -0.1880937f, 0.1287539f, -0.0810642f, + -0.0230893f, -0.0325119f, -0.0495625f, 0.0664514f, 0.1864402f, 0.0794077f, -0.2225531f, 0.0198063f, + -0.0478895f, 0.1485750f, 0.0846328f, 0.0470138f, 0.0562434f, -0.1950971f, 0.0961574f, -0.0095595f, + -0.0900242f, -0.0080224f, -0.0094565f, -0.1106773f, -0.0637866f, -0.1312685f, 0.0118203f, 0.0150917f, + 0.1209811f, -0.0833506f, -0.1212273f, 0.0995258f, -0.0725997f, 0.1161496f, 0.0609390f, -0.0160979f, +}; +static const float TURBO_VQ2D_Y[64] = { + -0.0263300f, 0.0685406f, -0.1090837f, 0.1035094f, -0.0896168f, -0.0125089f, -0.0671406f, -0.0187005f, + 0.0717508f, 0.1467829f, -0.0184862f, -0.1144251f, 0.0793044f, -0.1656622f, 0.1358503f, 0.0923961f, + -0.0055588f, 0.0639664f, -0.1557487f, 0.0507863f, 0.0079050f, 0.1759942f, 0.1642957f, -0.1138678f, + 0.0008668f, -0.0694018f, 0.0207315f, -0.1742128f, -0.2115104f, 0.1064816f, 0.1005936f, -0.1476948f, 
+ 0.0305579f, 0.1157162f, -0.0274784f, -0.0479926f, -0.0830491f, -0.2145961f, 0.0121470f, 0.0135110f, + 0.2169799f, 0.0451792f, -0.1392938f, 0.2095770f, 0.0380025f, -0.0849720f, 0.1537713f, -0.0449162f, + -0.0923609f, 0.1603929f, -0.0886196f, 0.0234268f, 0.0493861f, -0.0623466f, 0.1004377f, 0.0549886f, + -0.1051118f, -0.0522899f, 0.1113740f, -0.0216251f, 0.0940127f, -0.0629645f, -0.0015928f, -0.1436934f, +}; + +/* ---------- rotation matrix (lazy init) ---------- */ + +static float turbo_rotation[TURBO_D * TURBO_D]; +static float turbo_rotation_t[TURBO_D * TURBO_D]; /* transpose */ +static int turbo_rotation_initialized = 0; + +/* Simple LCG PRNG for deterministic rotation generation */ +static uint64_t turbo_prng_state; + +static void turbo_prng_seed(uint64_t seed) { + turbo_prng_state = seed; +} + +static double turbo_prng_normal(void) { + /* Box-Muller transform from uniform LCG */ + turbo_prng_state = turbo_prng_state * 6364136223846793005ULL + 1442695040888963407ULL; + double u1 = (double)(turbo_prng_state >> 11) / (double)(1ULL << 53); + if (u1 < 1e-15) u1 = 1e-15; + turbo_prng_state = turbo_prng_state * 6364136223846793005ULL + 1442695040888963407ULL; + double u2 = (double)(turbo_prng_state >> 11) / (double)(1ULL << 53); + return sqrt(-2.0 * log(u1)) * cos(2.0 * M_PI * u2); +} + +static void turbo_init_rotation(void) { + if (turbo_rotation_initialized) return; + + const int d = TURBO_D; + + /* Generate random Gaussian matrix */ + turbo_prng_seed(TURBO_SEED_ROTATION); + float G[TURBO_D * TURBO_D]; + for (int i = 0; i < d * d; i++) { + G[i] = (float)turbo_prng_normal(); + } + + /* QR decomposition via modified Gram-Schmidt */ + /* Q stored column-major in turbo_rotation */ + memcpy(turbo_rotation, G, d * d * sizeof(float)); + + for (int j = 0; j < d; j++) { + /* Normalize column j */ + float norm = 0.0f; + for (int i = 0; i < d; i++) { + norm += turbo_rotation[i * d + j] * turbo_rotation[i * d + j]; + } + norm = sqrtf(norm); + if (norm > 1e-10f) { + for (int i = 0; i < d; i++) { + turbo_rotation[i * d + j] /= norm; + } + } + + /* Orthogonalize remaining columns against j */ + for (int k = j + 1; k < d; k++) { + float dot = 0.0f; + for (int i = 0; i < d; i++) { + dot += turbo_rotation[i * d + j] * turbo_rotation[i * d + k]; + } + for (int i = 0; i < d; i++) { + turbo_rotation[i * d + k] -= dot * turbo_rotation[i * d + j]; + } + } + } + + /* Compute transpose */ + for (int i = 0; i < d; i++) { + for (int j = 0; j < d; j++) { + turbo_rotation_t[i * d + j] = turbo_rotation[j * d + i]; + } + } + + turbo_rotation_initialized = 1; +} + +/* ---------- QJL projection matrix (lazy init, seed-based) ---------- */ + +static float turbo_qjl_matrix[TURBO_D * TURBO_D]; +static float turbo_qjl_matrix_t[TURBO_D * TURBO_D]; +static int turbo_qjl_initialized = 0; + +static void turbo_init_qjl(void) { + if (turbo_qjl_initialized) return; + + const int d = TURBO_D; + turbo_prng_seed(TURBO_SEED_QJL); + + for (int i = 0; i < d * d; i++) { + turbo_qjl_matrix[i] = (float)turbo_prng_normal(); + } + + /* Transpose */ + for (int i = 0; i < d; i++) { + for (int j = 0; j < d; j++) { + turbo_qjl_matrix_t[i * d + j] = turbo_qjl_matrix[j * d + i]; + } + } + + turbo_qjl_initialized = 1; +} + +/* ---------- helper: matrix-vector multiply ---------- */ + +static void matvec(const float * M, const float * x, float * y, int d) { + /* y = M @ x, M is row-major d×d */ + for (int i = 0; i < d; i++) { + float sum = 0.0f; + for (int j = 0; j < d; j++) { + sum += M[i * d + j] * x[j]; + } + y[i] = sum; + } +} + +/* ---------- 
nearest centroid ---------- */ + +static int nearest_centroid_2bit(float val) { + /* Binary search on midpoints: {-0.133, -0.040, 0.040, 0.133} */ + if (val < -0.086728f) return 0; /* midpoint(-0.133, -0.040) */ + if (val < 0.000000f) return 1; /* midpoint(-0.040, 0.040) */ + if (val < 0.086728f) return 2; /* midpoint(0.040, 0.133) */ + return 3; +} + +static int nearest_centroid_3bit(float val) { + /* 8 centroids, find nearest via midpoints */ + if (val < -0.154259f) return 0; + if (val < -0.091775f) return 1; + if (val < -0.043589f) return 2; + if (val < 0.000000f) return 3; + if (val < 0.043589f) return 4; + if (val < 0.091775f) return 5; + if (val < 0.154259f) return 6; + return 7; +} + +static int nearest_centroid_4bit(float val) { + /* 16 centroids, optimal for N(0, 1/sqrt(128)), find nearest via midpoints */ + if (val < -0.145560f) return 0; + if (val < -0.103361f) return 1; + if (val < -0.079142f) return 2; + if (val < -0.060009f) return 3; + if (val < -0.043430f) return 4; + if (val < -0.028293f) return 5; + if (val < -0.013963f) return 6; + if (val < 0.000000f) return 7; + if (val < 0.013963f) return 8; + if (val < 0.028293f) return 9; + if (val < 0.043430f) return 10; + if (val < 0.060009f) return 11; + if (val < 0.079142f) return 12; + if (val < 0.103361f) return 13; + if (val < 0.145560f) return 14; + return 15; +} + +/* ---------- WHT sign arrays (must match CUDA/Metal, seed=42) ---------- */ + +static const float turbo_cpu_s1[128] = { + -1,1,1,-1,-1,1,-1,1,-1,-1,1,1,1,1,1,1,1,-1,1,-1,1,-1,-1,1,1,1,-1,1,1,-1,-1,-1, + -1,1,1,-1,1,1,-1,1,-1,1,1,-1,-1,1,-1,1,1,1,1,-1,-1,-1,-1,-1,1,-1,1,1,1,1,-1,1, + -1,-1,1,-1,-1,-1,1,-1,-1,-1,1,-1,-1,-1,1,1,1,-1,-1,1,1,1,-1,-1,1,1,-1,1,1,-1,1,-1, + -1,1,1,-1,1,-1,1,-1,1,1,1,1,-1,1,-1,1,1,-1,1,1,-1,-1,-1,-1,-1,1,1,-1,1,1,-1,1 +}; + +static const float turbo_cpu_s2[128] = { + 1,1,1,1,-1,1,1,-1,1,-1,-1,-1,1,-1,-1,-1,1,1,-1,-1,1,-1,1,-1,1,-1,-1,1,-1,1,1,1, + 1,1,-1,-1,-1,1,-1,-1,-1,-1,-1,-1,1,1,1,-1,1,-1,1,1,1,-1,-1,1,-1,-1,-1,-1,-1,-1,1,1, + 1,-1,1,-1,-1,-1,-1,1,-1,1,-1,1,-1,-1,1,1,-1,1,-1,1,1,-1,1,-1,-1,-1,-1,1,-1,-1,1,-1, + 1,-1,1,1,1,-1,-1,1,-1,1,-1,1,1,-1,-1,1,-1,1,-1,1,1,-1,1,-1,1,-1,-1,-1,-1,-1,1,-1 +}; + +/* ---------- CPU forward WHT (in-place, group_size elements) ---------- */ + +static void turbo_cpu_fwht(float * x, int group_size) { + const float * s1 = turbo_cpu_s1; + const float * s2 = turbo_cpu_s2; + const float inv_sqrt = (group_size == 128) ? 0.08838834764831845f : 0.125f; + + // signs1 + for (int i = 0; i < group_size; i++) x[i] *= s1[i]; + + // butterfly stages + for (int h = 1; h < group_size; h *= 2) { + for (int i = 0; i < group_size; i += h * 2) { + for (int j = i; j < i + h; j++) { + float a = x[j], b = x[j + h]; + x[j] = a + b; + x[j + h] = a - b; + } + } + } + + // normalize + signs2 + for (int i = 0; i < group_size; i++) x[i] *= inv_sqrt * s2[i]; +} + +/* ---------- TURBO3_0: 3-bit PolarQuant with WHT rotation ---------- */ + +void quantize_row_turbo3_0_ref(const float * GGML_RESTRICT x, block_turbo3_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TURBO3 == 0); + + // Read WHT group size from global (set by CPU SET_ROWS handler before each call). + // Fallback: 128 if row is 128-aligned, else 64. + extern int turbo3_cpu_wht_group_size; + int group_size = turbo3_cpu_wht_group_size; + if (group_size != 64 && group_size != 128) { + group_size = (k % 128 == 0) ? 128 : 64; + } + if (k % group_size != 0) group_size = (group_size == 128) ? 
64 : 128; + assert(k % group_size == 0); + + const int n_groups = k / group_size; + const int blocks_per_group = group_size / QK_TURBO3; + + for (int g = 0; g < n_groups; g++) { + const float * grp_src = x + g * group_size; + block_turbo3_0 * grp_dst = y + g * blocks_per_group; + + // 1. L2 norm over the group + float norm_sq = 0.0f; + float buf[128]; // max group_size + for (int j = 0; j < group_size; j++) { + buf[j] = grp_src[j]; + norm_sq += buf[j] * buf[j]; + } + float grp_norm = sqrtf(norm_sq); + float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + // 2. Normalize + for (int j = 0; j < group_size; j++) buf[j] *= inv_norm; + + // 3. Forward WHT rotation + turbo_cpu_fwht(buf, group_size); + + // 4. 2D VQ quantize pairs + pack into sub-blocks + float recon_sq = 0.0f; + for (int b = 0; b < blocks_per_group; b++) { + block_turbo3_0 * blk = &grp_dst[b]; + const int off = b * QK_TURBO3; + + memset(blk->qs, 0, QK_TURBO3 / 4); + memset(blk->signs, 0, QK_TURBO3 / 8); + + for (int j = 0; j < QK_TURBO3; j += 2) { + float vx = buf[off + j]; + float vy = buf[off + j + 1]; + // Brute-force 64-entry 2D VQ search + uint8_t best_vq = 0; + float best_dist = 1e30f; + for (int c = 0; c < 64; c++) { + float dx = vx - TURBO_VQ2D_X[c]; + float dy = vy - TURBO_VQ2D_Y[c]; + float d = dx*dx + dy*dy; + if (d < best_dist) { best_dist = d; best_vq = (uint8_t)c; } + } + // Even element: high 3 bits, odd element: low 3 bits + uint8_t idx_even = (best_vq >> 3) & 0x7; + uint8_t idx_odd = best_vq & 0x7; + blk->qs[j / 4] |= (idx_even & 0x3) << ((j % 4) * 2); + blk->qs[(j+1) / 4] |= (idx_odd & 0x3) << (((j+1) % 4) * 2); + if (idx_even & 0x4) blk->signs[j / 8] |= (1 << (j % 8)); + if (idx_odd & 0x4) blk->signs[(j+1) / 8] |= (1 << ((j+1) % 8)); + recon_sq += TURBO_VQ2D_X[best_vq] * TURBO_VQ2D_X[best_vq] + + TURBO_VQ2D_Y[best_vq] * TURBO_VQ2D_Y[best_vq]; + } + } + + // 5. Corrected norm: grp_norm / recon_norm (matching CUDA kernel) + float recon_norm = sqrtf(recon_sq); + float corrected = (recon_norm > 1e-10f) ? 
grp_norm / recon_norm : grp_norm; + for (int b = 0; b < blocks_per_group; b++) { + grp_dst[b].norm = GGML_FP32_TO_FP16(corrected); + } + } +} + +void dequantize_row_turbo3_0(const block_turbo3_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TURBO3 == 0); + const int nb = k / QK_TURBO3; + for (int block = 0; block < nb; block++) { + float norm = GGML_FP16_TO_FP32(x[block].norm); + for (int j = 0; j < QK_TURBO3; j += 2) { + uint8_t low2_e = (x[block].qs[j/4] >> ((j%4)*2)) & 0x3; + uint8_t hi1_e = (x[block].signs[j/8] >> (j%8)) & 0x1; + uint8_t low2_o = (x[block].qs[(j+1)/4] >> (((j+1)%4)*2)) & 0x3; + uint8_t hi1_o = (x[block].signs[(j+1)/8] >> ((j+1)%8)) & 0x1; + uint8_t idx_even = low2_e | (hi1_e << 2); + uint8_t idx_odd = low2_o | (hi1_o << 2); + uint8_t vq = (idx_even << 3) | idx_odd; + y[block * QK_TURBO3 + j] = TURBO_VQ2D_X[vq] * norm; + y[block * QK_TURBO3 + j + 1] = TURBO_VQ2D_Y[vq] * norm; + } + } +} + +size_t quantize_turbo3_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, + int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_UNUSED(imatrix); + assert(n_per_row % QK_TURBO3 == 0); + + size_t row_size = (n_per_row / QK_TURBO3) * sizeof(block_turbo3_0); + for (int64_t row = 0; row < nrows; row++) { + quantize_row_turbo3_0_ref( + src + row * n_per_row, + (block_turbo3_0 *)((char *)dst + row * row_size), + n_per_row + ); + } + return nrows * row_size; +} + +/* ---------- TURBO2_0: 2-bit PolarQuant (no QJL) ---------- */ + +void quantize_row_turbo2_0_ref(const float * GGML_RESTRICT x, block_turbo2_0 * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TURBO2 == 0); + + extern int turbo3_cpu_wht_group_size; + int group_size = turbo3_cpu_wht_group_size; + if (group_size != 64 && group_size != 128) { + group_size = (k % 128 == 0) ? 128 : 64; + } + if (k % group_size != 0) group_size = (group_size == 128) ? 64 : 128; + assert(k % group_size == 0); + + const int n_groups = k / group_size; + const int blocks_per_group = group_size / QK_TURBO2; + + for (int g = 0; g < n_groups; g++) { + const float * grp_src = x + g * group_size; + block_turbo2_0 * grp_dst = y + g * blocks_per_group; + + /* 1. L2 norm over the group */ + float norm_sq = 0.0f; + float buf[128]; + for (int j = 0; j < group_size; j++) { + buf[j] = grp_src[j]; + norm_sq += buf[j] * buf[j]; + } + float grp_norm = sqrtf(norm_sq); + float inv_norm = (grp_norm > 1e-10f) ? 1.0f / grp_norm : 0.0f; + + /* 2. Normalize */ + for (int j = 0; j < group_size; j++) buf[j] *= inv_norm; + + /* 3. Forward WHT rotation */ + turbo_cpu_fwht(buf, group_size); + + /* 4. Quantize + pack into sub-blocks */ + float recon_sq = 0.0f; + for (int b = 0; b < blocks_per_group; b++) { + block_turbo2_0 * blk = &grp_dst[b]; + const int off = b * QK_TURBO2; + + memset(blk->qs, 0, QK_TURBO2 / 4); + + for (int j = 0; j < QK_TURBO2; j++) { + int idx = nearest_centroid_2bit(buf[off + j]); + blk->qs[j / 4] |= (idx & 0x3) << ((j % 4) * 2); + recon_sq += CENTROIDS_2BIT[idx] * CENTROIDS_2BIT[idx]; + } + } + + /* 5. Corrected norm */ + float recon_norm = sqrtf(recon_sq); + float corrected = (recon_norm > 1e-10f) ? 
grp_norm / recon_norm : grp_norm; + for (int b = 0; b < blocks_per_group; b++) { + grp_dst[b].norm = GGML_FP32_TO_FP16(corrected); + } + } +} + +void dequantize_row_turbo2_0(const block_turbo2_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TURBO2 == 0); + const int nb = k / QK_TURBO2; + for (int block = 0; block < nb; block++) { + float norm = GGML_FP16_TO_FP32(x[block].norm); + for (int j = 0; j < QK_TURBO2; j++) { + uint8_t idx = (x[block].qs[j/4] >> ((j%4)*2)) & 0x3; + y[block * QK_TURBO2 + j] = CENTROIDS_2BIT[idx] * norm; + } + } +} + +size_t quantize_turbo2_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, + int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_UNUSED(imatrix); + assert(n_per_row % QK_TURBO2 == 0); + + size_t row_size = (n_per_row / QK_TURBO2) * sizeof(block_turbo2_0); + for (int64_t row = 0; row < nrows; row++) { + quantize_row_turbo2_0_ref( + src + row * n_per_row, + (block_turbo2_0 *)((char *)dst + row * row_size), + n_per_row + ); + } + return nrows * row_size; +} + +/* ---------- TURBO4_0: 3-bit PolarQuant + 1-bit QJL ---------- */ + +void quantize_row_turbo4_0_ref(const float * GGML_RESTRICT x, block_turbo4_0 * GGML_RESTRICT y, int64_t k) { + turbo_init_rotation(); + turbo_init_qjl(); + + assert(k % QK_TURBO4 == 0); + const int nb = k / QK_TURBO4; + const int d = QK_TURBO4; + + for (int block = 0; block < nb; block++) { + const float * src = x + block * d; + + /* Step 1: Extract norm */ + float norm_sq = 0.0f; + for (int i = 0; i < d; i++) norm_sq += src[i] * src[i]; + float norm = sqrtf(norm_sq); + + /* Normalize */ + float normalized[TURBO_D]; + if (norm > 1e-10f) { + const float inv = 1.0f / norm; + for (int i = 0; i < d; i++) normalized[i] = src[i] * inv; + } else { + memset(normalized, 0, d * sizeof(float)); + } + + /* Step 2: Forward WHT rotation (matches CUDA set_rows) */ + float rotated[TURBO_D]; + memcpy(rotated, normalized, d * sizeof(float)); + turbo_cpu_fwht(rotated, d); + +#if TURBO4_USE_4BIT + /* Step 3: 4-bit quantization (16 centroids) */ + static const float CENTROIDS_4BIT[16] = { + -0.173926f, -0.117195f, -0.089527f, -0.068756f, + -0.051262f, -0.035597f, -0.020989f, -0.006938f, + 0.006938f, 0.020989f, 0.035597f, 0.051262f, + 0.068756f, 0.089527f, 0.117195f, 0.173926f + }; + uint8_t indices[TURBO_D]; + for (int i = 0; i < d; i++) { + indices[i] = (uint8_t)nearest_centroid_4bit(rotated[i]); + } + + /* Norm correction */ + float recon_norm_sq = 0.0f; + for (int i = 0; i < d; i++) { + recon_norm_sq += CENTROIDS_4BIT[indices[i]] * CENTROIDS_4BIT[indices[i]]; + } + float recon_norm = sqrtf(recon_norm_sq); + float corrected_norm = (recon_norm > 1e-10f) ? 
norm / recon_norm : norm; + y[block].norm = GGML_FP32_TO_FP16(corrected_norm); +#else + /* Step 3: 3-bit quantization (8 centroids) */ + uint8_t indices[TURBO_D]; + for (int i = 0; i < d; i++) { + indices[i] = (uint8_t)nearest_centroid_3bit(rotated[i]); + } + + /* Step 4: Residual */ + float reconstructed[TURBO_D]; + for (int i = 0; i < d; i++) { + reconstructed[i] = CENTROIDS_3BIT[indices[i]]; + } + float mse_recon[TURBO_D]; + matvec(turbo_rotation_t, reconstructed, mse_recon, d); + + float residual[TURBO_D]; + for (int i = 0; i < d; i++) { + residual[i] = normalized[i] - mse_recon[i]; + } + + /* Step 5: QJL */ + float projected[TURBO_D]; + matvec(turbo_qjl_matrix, residual, projected, d); +#endif + + /* Pack */ +#if !TURBO4_USE_4BIT + y[block].norm = GGML_FP32_TO_FP16(norm); +#endif + +#if TURBO4_USE_4BIT + /* 4-bit PolarQuant: nibble pack into qs[64] */ + memset(y[block].qs, 0, d / 2); + for (int i = 0; i < d; i++) { + y[block].qs[i / 2] |= (uint8_t)((indices[i] & 0xF) << ((i % 2) * 4)); + } + y[block].rnorm = GGML_FP32_TO_FP16(0.0f); +#else + /* Legacy 3-bit + QJL: pack 3-bit indices + QJL signs */ + memset(y[block].qs, 0, d * 3 / 8); + for (int i = 0; i < d; i++) { + int bit_offset = i * 3; + int byte_idx = bit_offset / 8; + int bit_pos = bit_offset % 8; + uint16_t val = (uint16_t)(indices[i] & 0x7); + y[block].qs[byte_idx] |= (uint8_t)(val << bit_pos); + if (bit_pos > 5 && byte_idx + 1 < d * 3 / 8) { + y[block].qs[byte_idx + 1] |= (uint8_t)(val >> (8 - bit_pos)); + } + } + memset(y[block].signs, 0, d / 8); + for (int i = 0; i < d; i++) { + if (projected[i] >= 0.0f) { + y[block].signs[i / 8] |= (1 << (i % 8)); + } + } +#endif + } +} + +void dequantize_row_turbo4_0(const block_turbo4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) { + turbo_init_rotation(); + + assert(k % QK_TURBO4 == 0); + const int nb = k / QK_TURBO4; + const int d = QK_TURBO4; + +#if TURBO4_USE_4BIT + /* 4-bit PolarQuant: nibble unpack → centroid → inverse rotate → scale */ + /* TODO: add proper 4-bit centroid table to C code (currently only in Metal) */ + static const float CENTROIDS_4BIT[16] = { + -0.173926f, -0.117195f, -0.089527f, -0.068756f, + -0.051262f, -0.035597f, -0.020989f, -0.006938f, + 0.006938f, 0.020989f, 0.035597f, 0.051262f, + 0.068756f, 0.089527f, 0.117195f, 0.173926f + }; + for (int block = 0; block < nb; block++) { + float norm = GGML_FP16_TO_FP32(x[block].norm); + float * dst = y + block * d; + for (int i = 0; i < d; i++) { + uint8_t idx = (x[block].qs[i / 2] >> ((i % 2) * 4)) & 0xF; + dst[i] = CENTROIDS_4BIT[idx] * norm; + } + /* No inverse WHT, dequant stays in the rotated domain. + * Q is WHT-rotated by the graph, so gives correct attention scores. + * The inverse WHT is applied to the attention output via GGML_OP_TURBO_WHT (direction=1) in the graph. + */ + } +#else + /* Legacy 3-bit + QJL dequant */ + turbo_init_qjl(); + for (int block = 0; block < nb; block++) { + float norm = GGML_FP16_TO_FP32(x[block].norm); + + uint8_t indices[TURBO_D]; + for (int i = 0; i < d; i++) { + int bit_offset = i * 3; + int byte_idx = bit_offset / 8; + int bit_pos = bit_offset % 8; + uint16_t raw = (uint16_t)x[block].qs[byte_idx]; + if (byte_idx + 1 < d * 3 / 8) { + raw |= (uint16_t)x[block].qs[byte_idx + 1] << 8; + } + indices[i] = (uint8_t)((raw >> bit_pos) & 0x7); + } + + float signs[TURBO_D]; + for (int i = 0; i < d; i++) { + signs[i] = (x[block].signs[i / 8] & (1 << (i % 8))) ? 
1.0f : -1.0f; + } + + float rnorm = GGML_FP16_TO_FP32(x[block].rnorm); + const float qjl_scale = TURBO_QJL_CONST / (float)d * rnorm; + + float rotated_recon[TURBO_D]; + for (int i = 0; i < d; i++) { + rotated_recon[i] = CENTROIDS_3BIT[indices[i]]; + } + float mse_recon[TURBO_D]; + matvec(turbo_rotation_t, rotated_recon, mse_recon, d); + + float qjl_recon[TURBO_D]; + matvec(turbo_qjl_matrix_t, signs, qjl_recon, d); + for (int i = 0; i < d; i++) { + qjl_recon[i] *= qjl_scale; + } + + float * dst = y + block * d; + for (int i = 0; i < d; i++) { + dst[i] = (mse_recon[i] + qjl_recon[i]) * norm; + } + } +#endif +} + +size_t quantize_turbo4_0(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, + int64_t nrows, int64_t n_per_row, const float * imatrix) { + GGML_UNUSED(imatrix); + assert(n_per_row % QK_TURBO4 == 0); + + size_t row_size = (n_per_row / QK_TURBO4) * sizeof(block_turbo4_0); + for (int64_t row = 0; row < nrows; row++) { + quantize_row_turbo4_0_ref( + src + row * n_per_row, + (block_turbo4_0 *)((char *)dst + row * row_size), + n_per_row + ); + } + return nrows * row_size; +} + +/* ================================================================== */ +/* TQ3_1S / TQ4_1S: WHT-rotated weight quantization */ +/* ================================================================== */ + +/* Lloyd-Max centroids for N(0,1) — shared with Metal shaders */ +static const float TQ3_0_CENTROIDS[8] = { + -1.996684f, -1.291398f, -0.740341f, -0.247508f, + 0.230106f, 0.725222f, 1.277503f, 1.988943f +}; + +static const float TQ4_0_CENTROIDS[16] = { + -2.732590f, -2.069017f, -1.618046f, -1.256231f, + -0.942340f, -0.656759f, -0.388048f, -0.128395f, + 0.128395f, 0.388048f, 0.656759f, 0.942340f, + 1.256231f, 1.618046f, 2.069017f, 2.732590f, +}; + +/* WHT sign pattern (golden ratio hash, 32-element blocks) — shared by TQ3 and TQ4 */ +static const float TQ3_0_SIGNS[32] = { + +1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, + -1.0f, -1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, + -1.0f, +1.0f, +1.0f, -1.0f, +1.0f, -1.0f, -1.0f, +1.0f, +}; + +#define TQ_BLOCK_SIZE 32 +#define TQ_INV_SQRT32 0.17677669529663688f /* 1/sqrt(32) */ + +/* Forward RHT: sign flips -> WHT butterfly -> normalize */ +static void tq3_0_rht_forward(float * buf) { + for (int i = 0; i < TQ_BLOCK_SIZE; i++) buf[i] *= TQ3_0_SIGNS[i]; + for (int step = 1; step < TQ_BLOCK_SIZE; step <<= 1) { + for (int i = 0; i < TQ_BLOCK_SIZE; i += step << 1) { + for (int j = i; j < i + step; j++) { + float a = buf[j], b = buf[j + step]; + buf[j] = a + b; + buf[j + step] = a - b; + } + } + } + for (int i = 0; i < TQ_BLOCK_SIZE; i++) buf[i] *= TQ_INV_SQRT32; +} + +/* Inverse RHT: WHT butterfly -> normalize + unsign */ +static void tq3_0_rht_inverse(float * buf) { + for (int step = 1; step < TQ_BLOCK_SIZE; step <<= 1) { + for (int i = 0; i < TQ_BLOCK_SIZE; i += step << 1) { + for (int j = i; j < i + step; j++) { + float a = buf[j], b = buf[j + step]; + buf[j] = a + b; + buf[j + step] = a - b; + } + } + } + for (int i = 0; i < TQ_BLOCK_SIZE; i++) buf[i] *= TQ_INV_SQRT32 * TQ3_0_SIGNS[i]; +} + +/* Nearest centroid for TQ3 (8 centroids) */ +static int tq3_0_choose_index(float val) { + /* Binary search on midpoints of TQ3_0_CENTROIDS */ + if (val < -1.644041f) return 0; + if (val < -1.015870f) return 1; + if (val < -0.493925f) return 2; + if (val < -0.008701f) return 3; + if (val < 0.477664f) return 4; + if (val < 1.001363f) return 5; + if (val < 1.633223f) return 6; + return 7; +} + +/* Nearest 
centroid for TQ4 (16 centroids) */ +static int tq4_0_choose_index(float val) { + /* Binary search on midpoints of TQ4_0_CENTROIDS */ + if (val < -2.400804f) return 0; + if (val < -1.843532f) return 1; + if (val < -1.437139f) return 2; + if (val < -1.099286f) return 3; + if (val < -0.799550f) return 4; + if (val < -0.522404f) return 5; + if (val < -0.258222f) return 6; + if (val < 0.000000f) return 7; + if (val < 0.258222f) return 8; + if (val < 0.522404f) return 9; + if (val < 0.799550f) return 10; + if (val < 1.099286f) return 11; + if (val < 1.437139f) return 12; + if (val < 1.843532f) return 13; + if (val < 2.400804f) return 14; + return 15; +} + +/* ---------- TQ3_1S quantization ---------- */ + +void quantize_row_tq3_1s_ref(const float * GGML_RESTRICT x, block_tq3_1s * GGML_RESTRICT y, int64_t k) { + assert(k % QK_TQ3_0 == 0); + const int nb = k / QK_TQ3_0; + + for (int block = 0; block < nb; block++) { + const float * src_blk = x + block * QK_TQ3_0; + block_tq3_1s * blk = &y[block]; + + /* 1. Forward RHT */ + float buf[TQ_BLOCK_SIZE]; + memcpy(buf, src_blk, TQ_BLOCK_SIZE * sizeof(float)); + tq3_0_rht_forward(buf); + + /* 2. Split into two halves, compute RMS per half */ + float rms0 = 0.0f, rms1 = 0.0f; + for (int j = 0; j < 16; j++) rms0 += buf[j] * buf[j]; + for (int j = 16; j < 32; j++) rms1 += buf[j] * buf[j]; + rms0 = sqrtf(rms0 / 16.0f); + rms1 = sqrtf(rms1 / 16.0f); + + /* 3. Scale search (9 points) */ + static const float scales[] = { 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.35f, 1.5f }; + float best_d0 = rms0, best_d1 = rms1; + float best_err = 1e30f; + + for (int si = 0; si < 9; si++) { + float d0 = rms0 * scales[si]; + float d1 = rms1 * scales[si]; + float inv0 = (d0 > 1e-10f) ? 1.0f / d0 : 0.0f; + float inv1 = (d1 > 1e-10f) ? 1.0f / d1 : 0.0f; + + float err = 0.0f; + for (int j = 0; j < 16; j++) { + int idx = tq3_0_choose_index(buf[j] * inv0); + float diff = buf[j] - TQ3_0_CENTROIDS[idx] * d0; + err += diff * diff; + } + for (int j = 16; j < 32; j++) { + int idx = tq3_0_choose_index(buf[j] * inv1); + float diff = buf[j] - TQ3_0_CENTROIDS[idx] * d1; + err += diff * diff; + } + if (err < best_err) { + best_err = err; + best_d0 = d0; + best_d1 = d1; + } + } + + /* 4. Iterative refinement (6 iterations) */ + for (int iter = 0; iter < 6; iter++) { + float inv0 = (best_d0 > 1e-10f) ? 1.0f / best_d0 : 0.0f; + float inv1 = (best_d1 > 1e-10f) ? 1.0f / best_d1 : 0.0f; + + float num0 = 0.0f, den0 = 0.0f; + float num1 = 0.0f, den1 = 0.0f; + for (int j = 0; j < 16; j++) { + int idx = tq3_0_choose_index(buf[j] * inv0); + float c = TQ3_0_CENTROIDS[idx]; + num0 += buf[j] * c; + den0 += c * c; + } + for (int j = 16; j < 32; j++) { + int idx = tq3_0_choose_index(buf[j] * inv1); + float c = TQ3_0_CENTROIDS[idx]; + num1 += buf[j] * c; + den1 += c * c; + } + if (den0 > 1e-10f) best_d0 = num0 / den0; + if (den1 > 1e-10f) best_d1 = num1 / den1; + } + + /* 5. Final quantize + pack */ + float inv0 = (best_d0 > 1e-10f) ? 1.0f / best_d0 : 0.0f; + float inv1 = (best_d1 > 1e-10f) ? 1.0f / best_d1 : 0.0f; + + blk->d0 = GGML_FP32_TO_FP16(best_d0); + blk->d1 = GGML_FP32_TO_FP16(best_d1); + memset(blk->qs, 0, QK_TQ3_0 * 3 / 8); + + /* TQ3 packing: 4 groups of 8 indices packed into 3 bytes each */ + for (int g = 0; g < 4; g++) { + uint8_t indices[8]; + for (int i = 0; i < 8; i++) { + int j = g * 8 + i; + float inv = (j < 16) ? 
+void dequantize_row_tq3_1s(const block_tq3_1s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    assert(k % QK_TQ3_0 == 0);
+    const int nb = k / QK_TQ3_0;
+
+    for (int blk_i = 0; blk_i < nb; blk_i++) {
+        float d0 = GGML_FP16_TO_FP32(x[blk_i].d0);
+        float d1 = GGML_FP16_TO_FP32(x[blk_i].d1);
+
+        /* Unpack 3-bit indices */
+        float buf[32];
+        for (int g = 0; g < 4; g++) {
+            const uint8_t * qp = x[blk_i].qs + g * 3;
+            uint8_t idx[8];
+            idx[0] = qp[0] & 7;
+            idx[1] = (qp[0] >> 3) & 7;
+            idx[2] = ((qp[0] >> 6) | (qp[1] << 2)) & 7;
+            idx[3] = (qp[1] >> 1) & 7;
+            idx[4] = (qp[1] >> 4) & 7;
+            idx[5] = ((qp[1] >> 7) | (qp[2] << 1)) & 7;
+            idx[6] = (qp[2] >> 2) & 7;
+            idx[7] = (qp[2] >> 5) & 7;
+
+            for (int i = 0; i < 8; i++) {
+                int j = g * 8 + i;
+                float d = (j < 16) ? d0 : d1;
+                buf[j] = TQ3_0_CENTROIDS[idx[i]] * d;
+            }
+        }
+
+        /* Inverse RHT */
+        tq3_0_rht_inverse(buf);
+
+        memcpy(y + blk_i * QK_TQ3_0, buf, QK_TQ3_0 * sizeof(float));
+    }
+}
+
+size_t quantize_tq3_1s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
+                       int64_t nrows, int64_t n_per_row, const float * imatrix) {
+    GGML_UNUSED(imatrix);
+    assert(n_per_row % QK_TQ3_0 == 0);
+
+    size_t row_size = (n_per_row / QK_TQ3_0) * sizeof(block_tq3_1s);
+    for (int64_t row = 0; row < nrows; row++) {
+        quantize_row_tq3_1s_ref(
+            src + row * n_per_row,
+            (block_tq3_1s *)((char *)dst + row * row_size),
+            n_per_row
+        );
+    }
+    return nrows * row_size;
+}
+
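+/*
+ * Minimal round-trip sketch (illustrative only; sizes assume QK_TQ3_0 == 32):
+ *
+ *   float src[64], out[64];
+ *   block_tq3_1s packed[2];                       // 64 / QK_TQ3_0 blocks
+ *   quantize_tq3_1s(src, packed, 1, 64, NULL);    // one row of 64 values
+ *   dequantize_row_tq3_1s(packed, out, 64);       // lossy reconstruction
+ */
+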
+/* ---------- TQ4_1S quantization ---------- */
+
+void quantize_row_tq4_1s_ref(const float * GGML_RESTRICT x, block_tq4_1s * GGML_RESTRICT y, int64_t k) {
+    assert(k % QK_TQ4_1S == 0);
+    const int nb = k / QK_TQ4_1S;
+
+    for (int block = 0; block < nb; block++) {
+        const float * src_blk = x + block * QK_TQ4_1S;
+        block_tq4_1s * blk = &y[block];
+
+        /* 1. Forward RHT */
+        float buf[TQ_BLOCK_SIZE];
+        memcpy(buf, src_blk, TQ_BLOCK_SIZE * sizeof(float));
+        tq3_0_rht_forward(buf);
+
+        /* 2. Split into two halves, compute RMS per half */
+        float rms0 = 0.0f, rms1 = 0.0f;
+        for (int j = 0; j < 16; j++) rms0 += buf[j] * buf[j];
+        for (int j = 16; j < 32; j++) rms1 += buf[j] * buf[j];
+        rms0 = sqrtf(rms0 / 16.0f);
+        rms1 = sqrtf(rms1 / 16.0f);
+
+        /* 3. Scale search (9 points) */
+        static const float scales[] = { 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.35f, 1.5f };
+        float best_d0 = rms0, best_d1 = rms1;
+        float best_err = 1e30f;
+
+        for (int si = 0; si < 9; si++) {
+            float d0 = rms0 * scales[si];
+            float d1 = rms1 * scales[si];
+            float inv0 = (d0 > 1e-10f) ? 1.0f / d0 : 0.0f;
+            float inv1 = (d1 > 1e-10f) ? 1.0f / d1 : 0.0f;
+
+            float err = 0.0f;
+            for (int j = 0; j < 16; j++) {
+                int idx = tq4_0_choose_index(buf[j] * inv0);
+                float diff = buf[j] - TQ4_0_CENTROIDS[idx] * d0;
+                err += diff * diff;
+            }
+            for (int j = 16; j < 32; j++) {
+                int idx = tq4_0_choose_index(buf[j] * inv1);
+                float diff = buf[j] - TQ4_0_CENTROIDS[idx] * d1;
+                err += diff * diff;
+            }
+            if (err < best_err) {
+                best_err = err;
+                best_d0 = d0;
+                best_d1 = d1;
+            }
+        }
+
+        /* 4. Iterative refinement (6 iterations) */
+        for (int iter = 0; iter < 6; iter++) {
+            float inv0 = (best_d0 > 1e-10f) ? 1.0f / best_d0 : 0.0f;
+            float inv1 = (best_d1 > 1e-10f) ? 1.0f / best_d1 : 0.0f;
+
+            float num0 = 0.0f, den0 = 0.0f;
+            float num1 = 0.0f, den1 = 0.0f;
+            for (int j = 0; j < 16; j++) {
+                int idx = tq4_0_choose_index(buf[j] * inv0);
+                float c = TQ4_0_CENTROIDS[idx];
+                num0 += buf[j] * c;
+                den0 += c * c;
+            }
+            for (int j = 16; j < 32; j++) {
+                int idx = tq4_0_choose_index(buf[j] * inv1);
+                float c = TQ4_0_CENTROIDS[idx];
+                num1 += buf[j] * c;
+                den1 += c * c;
+            }
+            if (den0 > 1e-10f) best_d0 = num0 / den0;
+            if (den1 > 1e-10f) best_d1 = num1 / den1;
+        }
+
+        /* 5. Final quantize + pack (nibble packing) */
+        float inv0 = (best_d0 > 1e-10f) ? 1.0f / best_d0 : 0.0f;
+        float inv1 = (best_d1 > 1e-10f) ? 1.0f / best_d1 : 0.0f;
+
+        blk->d0 = GGML_FP32_TO_FP16(best_d0);
+        blk->d1 = GGML_FP32_TO_FP16(best_d1);
+        memset(blk->qs, 0, QK_TQ4_1S / 2);
+
+        for (int j = 0; j < QK_TQ4_1S; j++) {
+            float inv = (j < 16) ? inv0 : inv1;
+            int idx = tq4_0_choose_index(buf[j] * inv);
+            blk->qs[j / 2] |= (uint8_t)((idx & 0xF) << ((j & 1) * 4));
+        }
+    }
+}
+
+void dequantize_row_tq4_1s(const block_tq4_1s * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k) {
+    assert(k % QK_TQ4_1S == 0);
+    const int nb = k / QK_TQ4_1S;
+
+    for (int blk_i = 0; blk_i < nb; blk_i++) {
+        float d0 = GGML_FP16_TO_FP32(x[blk_i].d0);
+        float d1 = GGML_FP16_TO_FP32(x[blk_i].d1);
+
+        float buf[32];
+        for (int j = 0; j < 32; j++) {
+            uint8_t idx = (x[blk_i].qs[j / 2] >> ((j & 1) * 4)) & 0xF;
+            float d = (j < 16) ? d0 : d1;
+            buf[j] = TQ4_0_CENTROIDS[idx] * d;
+        }
+
+        /* Inverse RHT */
+        tq3_0_rht_inverse(buf);
+
+        memcpy(y + blk_i * QK_TQ4_1S, buf, QK_TQ4_1S * sizeof(float));
+    }
+}
+
+size_t quantize_tq4_1s(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst,
+                       int64_t nrows, int64_t n_per_row, const float * imatrix) {
+    GGML_UNUSED(imatrix);
+    assert(n_per_row % QK_TQ4_1S == 0);
+
+    size_t row_size = (n_per_row / QK_TQ4_1S) * sizeof(block_tq4_1s);
+    for (int64_t row = 0; row < nrows; row++) {
+        quantize_row_tq4_1s_ref(
+            src + row * n_per_row,
+            (block_tq4_1s *)((char *)dst + row * row_size),
+            n_per_row
+        );
+    }
+    return nrows * row_size;
+}
"TURBO_WHT", "UNARY", @@ -1075,7 +1116,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GLU", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1168,6 +1209,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "rwkv_wkv7(r, w, k, v, a, b, s)", "A X = B, A triangular, solve X", "gated_delta_net(q, k, v, g, beta, s)", + "turbo_wht(x)", "unary(x)", @@ -1185,7 +1227,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "glu(x)", }; -static_assert(GGML_OP_COUNT == 96, "GGML_OP_COUNT != 96"); +static_assert(GGML_OP_COUNT == 97, "GGML_OP_COUNT != 97"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -5390,6 +5432,19 @@ void ggml_flash_attn_ext_add_sinks( a->src[4] = sinks; } +void ggml_flash_attn_ext_set_kv_indices( + struct ggml_tensor * a, + struct ggml_tensor * indices) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + // indices is an optional 1D i32 tensor of logical KV row indices + if (indices) { + GGML_ASSERT(indices->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(indices)); + GGML_ASSERT(indices->ne[1] == 1 && indices->ne[2] == 1 && indices->ne[3] == 1); + } + a->src[5] = indices; +} + // ggml_flash_attn_back struct ggml_tensor * ggml_flash_attn_back( @@ -6213,6 +6268,26 @@ struct ggml_tensor * ggml_gated_delta_net( return result; } +// ggml_turbo_wht + +struct ggml_tensor * ggml_turbo_wht( + struct ggml_context * ctx, + struct ggml_tensor * a, + int direction, + int group_size, + struct ggml_tensor * scale) { + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + int32_t params[2] = { direction, group_size }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_TURBO_WHT; + result->src[0] = a; + result->src[1] = scale; + + return result; +} + //////////////////////////////////////////////////////////////////////////////// struct ggml_hash_set ggml_hash_set_new(size_t size) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 09102f549c8..25107e3763f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -286,16 +286,24 @@ llama_kv_cache::llama_kv_cache( LLAMA_LOG_WARN("%s: attention rotation force disabled (LLAMA_ATTN_ROT_DISABLE)\n", __func__); } + // Turbo KV types have WHT baked into the SET_ROWS kernel; exclude them + // from the generic attn_rot path to avoid double-rotating K and V. + auto is_turbo_kv = [](ggml_type t) { + return t == GGML_TYPE_TURBO3_0 || t == GGML_TYPE_TURBO4_0 || t == GGML_TYPE_TURBO2_0; + }; + attn_rot_k = !attn_rot_disable && n_embd_head_k_all > 0 && ggml_is_quantized(type_k) && + !is_turbo_kv(type_k) && hparams.n_embd_head_k() % 64 == 0; attn_rot_v = !attn_rot_disable && n_embd_head_v_all > 0 && ggml_is_quantized(type_v) && + !is_turbo_kv(type_v) && hparams.n_embd_head_v() % 64 == 0; LLAMA_LOG_INFO("%s: attn_rot_k = %d, n_embd_head_k_all = %d\n", __func__, attn_rot_k, n_embd_head_k_all);