4 changes: 4 additions & 0 deletions common/arg.cpp
@@ -390,6 +390,10 @@ const std::vector<ggml_type> kv_cache_types = {
GGML_TYPE_IQ4_NL,
GGML_TYPE_Q5_0,
GGML_TYPE_Q5_1,
// TurboQuant KV cache types (GPU only)
GGML_TYPE_TURBO3_0,
GGML_TYPE_TURBO4_0,
GGML_TYPE_TURBO2_0,
};

static ggml_type kv_cache_type_from_str(const std::string & s) {
159 changes: 159 additions & 0 deletions docs/turbo-quant.md
@@ -0,0 +1,159 @@
# TurboQuant HIP/CUDA — Implementation Notes

TurboQuant is a family of quantization schemes for the **KV-cache** (and optionally for
weight matrices) that achieves higher compression than existing quantization types while
preserving flash-attention quality. This document describes the upstream port, the
conflict-resolution choices made during the port, and how to enable/test the new code.

---

## What is TurboQuant?

| Type | Bits/value | Description |
|------|-----------|-------------|
| `TURBO3_0` | 3.125 | 3-bit KV cache: 3-bit PolarQuant indices, stored as a 2-bit + 1-bit split |
| `TURBO4_0` | 4.25 | 4-bit KV cache: 4-bit PolarQuant indices (nibble-packed) |
| `TURBO2_0` | 2.125 | 2-bit KV cache: 2-bit PolarQuant indices only (fastest, most aggressive) |
| `TQ3_1S` | 4.0 | 3-bit weight quantization: WHT-rotated 8-level Lloyd-Max, plus dual fp16 scales per 32-value block |
| `TQ4_1S` | 5.0 | 4-bit weight quantization: WHT-rotated 16-level Lloyd-Max, plus dual fp16 scales per 32-value block |

All KV-cache types use a block size of **128** elements. The WHT (Walsh-Hadamard
Transform) rotation mixes correlations within each block to improve quantization quality.

---

## New GGML types

```
GGML_TYPE_TURBO3_0 = 42
GGML_TYPE_TURBO4_0 = 43
GGML_TYPE_TURBO2_0 = 44
GGML_TYPE_TQ3_1S = 45
GGML_TYPE_TQ4_1S = 46
```

A new GGML operator `GGML_OP_TURBO_WHT` applies the (inverse) Walsh-Hadamard Transform
to floating-point vectors.

---

## New source files

| File | Description |
|------|-------------|
| `ggml/src/ggml-turbo-quant.c` | CPU quantization / dequantization reference implementations |
| `ggml/src/ggml-cuda/turbo-quant.cuh` | GPU centroid tables and block dequant helpers |
| `ggml/src/ggml-cuda/turbo-wht.cu/.cuh` | GPU Walsh-Hadamard Transform kernel |
| `ggml/src/ggml-cuda/turbo-innerq.cu/.cuh` | GPU inner quantization / conversion kernels |
| `ggml/src/ggml-cuda/mmvq-tq.cu/.cuh` | GPU matrix-vector multiply for TQ weight types |
| `ggml/src/ggml-cuda/template-instances/fattn-vec-instance-{K}-{V}.cu` | Per-type flash-attention kernel instantiations |
| `ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq640-dv512.cu` | Tile flash-attention instance for D_KQ=640, D_V=512 |

---

## Modified files

| File | Change |
|------|--------|
| `ggml/include/ggml.h` | New `ggml_type` values, `GGML_OP_TURBO_WHT`, `ggml_turbo_wht()`, `ggml_flash_attn_ext_set_kv_indices()` |
| `ggml/src/ggml-common.h` | Block struct definitions for new types |
| `ggml/src/ggml.c` | Type metadata tables, `ggml_turbo_wht()` and `ggml_flash_attn_ext_set_kv_indices()` implementations |
| `ggml/src/ggml-cuda/ggml-cuda.cu` | TQ4_1S load-time conversion, mul_mat dispatch for TQ weights, `GGML_OP_TURBO_WHT` dispatch, `supports_op` entries |
| `ggml/src/ggml-cuda/fattn-common.cuh` | `kv_indices` parameter in kernel typedef; KQ dot and V dequant functions for turbo types |
| `ggml/src/ggml-cuda/fattn-vec.cuh` | `kv_indices` parameter wired through |
| `ggml/src/ggml-cuda/fattn.cu` | Turbo VEC kernel instantiations; D=640 MMA case; mixed-type allow-list |
| `ggml/src/ggml-cuda/dequantize.cuh` | Dequantize helpers for TURBO types |
| `ggml/src/ggml-cuda/convert.cu` | TQ4_1S → Q8_0 conversion kernel used during model load |
| `ggml/src/ggml-cuda/set-rows.cu` | SET_ROWS support for TURBO3_0, TURBO4_0, TURBO2_0 |
| `ggml/src/CMakeLists.txt` | Adds `ggml-turbo-quant.c` to the ggml-base target |
| `ggml/src/ggml-cuda/CMakeLists.txt` | Adds new `.cu` files to the CUDA target |
| `ggml/src/ggml-hip/CMakeLists.txt` | Mirrors CUDA CMake additions for HIP/ROCm |
| `src/llama-kv-cache.cpp` | KV cache quantization type validation accepts turbo types |
| `common/arg.cpp` | `--cache-type-k` / `--cache-type-v` CLI help updated |

---

## Conflict resolution notes

The source branch (`domvox/llama.cpp-turboquant-hip:feature/triattention-scoring`) had
significant divergence from upstream `ggml-org/master`. Key reconciliations:

1. **`fattn_kernel_t` typedef** — The source branch adds a `kv_indices` parameter for the
sparse-attention TriAttention feature. We include this parameter in the typedef and
thread it through `launch_fattn` (passing `nullptr` when unused), but we **omit** the
full TriAttention routing/pruning logic as out-of-scope for this port.

2. **`GGML_API` visibility** — The source branch removes `extern` from the non-Windows
`GGML_API` macro. This is a project-wide API change that requires careful review; we
omit it here to minimise diff size.

3. **`TURBO4_USE_4BIT` guard** — The C implementation in `ggml-turbo-quant.c` contains a
legacy `#else` branch that references a `signs` field absent from `block_turbo4_0`.
We add `#define TURBO4_USE_4BIT 1` to `ggml-common.h` to ensure the correct 4-bit
path is always used.

4. **D=640 MMA kernel** — Added for GLM-4.7 Flash (K head-dim 576 zero-padded to 640)
but depends on `ggml_cuda_flash_attn_ext_mma_f16_case<640, 512, …>`. Because the
corresponding MMA template instances are not yet part of this PR, the runtime will
fall back to the VEC kernel or CPU for that model variant.

5. **TQ weight `mul_mat_id`** — `TQ4_1S` / `TQ3_1S` weight tensors are transparently
converted to `Q8_0` on upload to the GPU; they use the cuBLAS dequant path and bypass
mmvq/mmq kernels.

---

## How to enable TurboQuant KV cache

Pass one of the new types to `--cache-type-k` and/or `--cache-type-v`:

```bash
# 3-bit K, 3-bit V (balanced quality vs. compression)
llama-cli -m model.gguf --cache-type-k turbo3_0 --cache-type-v turbo3_0 ...

# 2-bit K, 3-bit V (aggressive K, quality V)
llama-cli -m model.gguf --cache-type-k turbo2_0 --cache-type-v turbo3_0 ...

# 4-bit K, 8-bit V (near-lossless V, compressed K)
llama-cli -m model.gguf --cache-type-k turbo4_0 --cache-type-v q8_0 ...
```
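To gauge the memory impact, bits per value follow directly from the block byte counts in `ggml-common.h` (all turbo KV blocks hold 128 values):

```c
// Bits per stored value for a KV-cache type, given its block size in bytes.
// Block byte counts from ggml-common.h: turbo2_0 = 34, turbo3_0 = 50,
// turbo4_0 = 68; fp16 would be 256 bytes per 128 values.
static double turbo_bits_per_value(double block_bytes) {
    return block_bytes * 8.0 / 128.0;
}
```

For a hypothetical model shape (32 layers, 8 KV heads, head_dim 128, i.e. 32768 K values per token), the K cache shrinks from 64 KiB/token at fp16 to 12.5 KiB/token at `turbo3_0`.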

### Supported combinations (GPU/VEC kernel)

Any combination of `{turbo2_0, turbo3_0, turbo4_0, q8_0}` for K and V is supported on the VEC
flash-attention kernel. The VEC kernel activates when:

- Head dimension ≤ 256 and is a multiple of 64
- Context length is a multiple of 256 (`FATTN_KQ_STRIDE`)

Larger head dimensions or non-standard context lengths fall back to the MMA tile kernel
(if available) or to the CPU.
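The activation conditions above amount to a simple predicate; this sketch uses hypothetical names (the actual dispatch logic lives in `ggml/src/ggml-cuda/fattn.cu`):

```c
#include <stdbool.h>

// Sketch of the VEC-kernel eligibility check described above.
// FATTN_KQ_STRIDE = 256 as stated in the docs; function name is illustrative.
static bool fattn_vec_eligible(int head_dim, int n_ctx) {
    const int FATTN_KQ_STRIDE = 256;
    return head_dim <= 256
        && head_dim % 64 == 0
        && n_ctx % FATTN_KQ_STRIDE == 0;
}
```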

---

## Building with HIP/ROCm

```bash
cmake -B build -DGGML_HIP=ON -DAMDGPU_TARGETS="gfx1100;gfx942" \
    -DCMAKE_BUILD_TYPE=Release
cmake --build build -j$(nproc)
```

No extra CMake flags are required to enable TurboQuant; it is compiled in unconditionally.

---

## Limitations / caveats

- **MMA kernel not yet supported** for turbo KV types: workloads that would normally
  select the MMA kernel (very long contexts, large batch sizes) fall back to the VEC
  kernel or the CPU.
- **TQ3_1S / TQ4_1S weight types** require the WHT rotation operator
(`GGML_OP_TURBO_WHT`) to be applied by the graph builder. A reference
`llama_model_loader` hook is not yet wired; these types are usable via low-level API
only.
- **TriAttention / kv_indices** sparse-attention routing is partially scaffolded
(`ggml_flash_attn_ext_set_kv_indices` API exists) but not fully implemented on the
VEC kernel side.
- The TURBO2_0 and TURBO3_0 centroid tables live in `turbo-quant.cuh` (GPU) and
`ggml-turbo-quant.c` (CPU). They are Lloyd-Max optimal for a unit-normal distribution
of WHT-rotated activations.
24 changes: 23 additions & 1 deletion ggml/include/ggml.h
@@ -429,7 +429,12 @@ extern "C" {
GGML_TYPE_MXFP4 = 39, // MXFP4 (1 block)
GGML_TYPE_NVFP4 = 40, // NVFP4 (4 blocks, E4M3 scale)
GGML_TYPE_Q1_0 = 41,
GGML_TYPE_COUNT = 42,
GGML_TYPE_TURBO3_0 = 42, // TurboQuant 3-bit KV cache: 2-bit PolarQuant + 1-bit sign
GGML_TYPE_TURBO4_0 = 43, // TurboQuant 4-bit KV cache: 4-bit PolarQuant (nibble packed)
GGML_TYPE_TURBO2_0 = 44, // TurboQuant 2-bit KV cache: 2-bit PolarQuant (no QJL)
GGML_TYPE_TQ3_1S = 45, // TurboQuant 3-bit weight: WHT-rotated 8-level Lloyd-Max, block_size=32
GGML_TYPE_TQ4_1S = 46, // TurboQuant 4-bit weight: WHT-rotated 16-level Lloyd-Max, block_size=32
GGML_TYPE_COUNT = 47,
};

// precision
@@ -561,6 +566,7 @@ extern "C" {
GGML_OP_RWKV_WKV7,
GGML_OP_SOLVE_TRI,
GGML_OP_GATED_DELTA_NET,
GGML_OP_TURBO_WHT,

GGML_OP_UNARY,

@@ -2401,6 +2407,12 @@ extern "C" {
GGML_API enum ggml_prec ggml_flash_attn_ext_get_prec(
const struct ggml_tensor * a);

// Optional KV row indices for sparse cache access.
// indices is a contiguous 1D i32 tensor mapping logical to physical KV rows.
GGML_API void ggml_flash_attn_ext_set_kv_indices(
struct ggml_tensor * a,
struct ggml_tensor * indices);

GGML_API void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
@@ -2539,6 +2551,16 @@ extern "C" {
struct ggml_tensor * beta,
struct ggml_tensor * state);

// TurboQuant Walsh-Hadamard Transform (O(d log d) rotation for KV cache compression)
// Applies WHT rotation to 128-element groups along ne[0]: sign1 -> butterfly -> sign2 -> normalize
// direction: 0 = forward (signs1 -> WHT -> signs2), 1 = inverse (signs2 -> WHT -> signs1)
GGML_API struct ggml_tensor * ggml_turbo_wht(
struct ggml_context * ctx,
struct ggml_tensor * a,
int direction,
int group_size, // 0 = auto (64 or 128 from ne[0])
struct ggml_tensor * scale); // NULL = no InnerQ scaling

// custom operators

typedef void (*ggml_custom1_op_t)(struct ggml_tensor * dst , const struct ggml_tensor * a, int ith, int nth, void * userdata);
1 change: 1 addition & 0 deletions ggml/src/CMakeLists.txt
@@ -206,6 +206,7 @@ add_library(ggml-base
ggml-threading.h
ggml-quants.c
ggml-quants.h
ggml-turbo-quant.c
gguf.cpp)

set_target_properties(ggml-base PROPERTIES
77 changes: 77 additions & 0 deletions ggml/src/ggml-common.h
@@ -277,6 +277,83 @@ typedef struct {
} block_tq2_0;
static_assert(sizeof(block_tq2_0) == sizeof(ggml_half) + QK_K / 4, "wrong tq2_0 block size/padding");

// TurboQuant 3-bit MSE-only: 3-bit PolarQuant indices (no QJL)
// Storage block size = 128 (one block per rotation group / head_dim)
// Per block: norm(fp16) + 2-bit indices (32 bytes) + 1-bit extra (16 bytes) = 50 bytes per 128 values
// = 3.125 bits/value -> ~5.1x compression vs fp16
// The 3-bit index is split: lower 2 bits in qs[], upper 1 bit in signs[]
#define QK_TURBO3 128 // Block size 128: one block per rotation group, eliminates redundant norms
#define QK_TURBO3_GROUP 128 // rotation group size = head_dim
static_assert(QK_TURBO3 % 2 == 0, "QK_TURBO3 must be even for 2D VQ pair quantization");
// Derived: FA template nl parameters (auto-scale with block size)
#define NL_TURBO3 (QK_TURBO3 / 16) // non-vec FA iterations per block
#define NL_TURBO3_VEC (QK_TURBO3 / 4) // vec FA iterations per block
typedef struct {
ggml_half norm; // 2 bytes: vector L2 norm (for rescaling)
uint8_t qs[QK_TURBO3 / 4]; // 32 bytes: lower 2-bit indices (4 per byte)
uint8_t signs[QK_TURBO3 / 8]; // 16 bytes: upper 1-bit of 3-bit index (8 per byte)
} block_turbo3_0; // 50 bytes total
static_assert(sizeof(block_turbo3_0) == sizeof(ggml_half) + QK_TURBO3/4 + QK_TURBO3/8, "wrong turbo3_0 block size/padding");

// TurboQuant 4-bit: 4-bit PolarQuant indices (nibble packed)
// Default: 4-bit PolarQuant on all backends
#define QK_TURBO4 128
// Enable 4-bit path: always use the 4-bit PolarQuant implementation
#ifndef TURBO4_USE_4BIT
#define TURBO4_USE_4BIT 1
#endif

// 4-bit PolarQuant: 16 optimal centroids, nibble packed
// Per block: norm(fp16) + rnorm(fp16, reserved) + 4-bit indices (64 bytes)
// = 68 bytes per 128 values = 4.25 bits/value -> 3.8x compression vs fp16
typedef struct {
ggml_half norm; // 2 bytes
ggml_half rnorm; // 2 bytes (reserved, unused in 4-bit mode)
uint8_t qs[QK_TURBO4 / 2]; // 64 bytes: 4-bit PolarQuant indices (nibble packed)
} block_turbo4_0; // 68 bytes total
static_assert(sizeof(block_turbo4_0) == 68, "wrong turbo4_0 block size");

static_assert(QK_TURBO4 == 128, "turbo4 kernels assume QK_TURBO4 == 128");

// TurboQuant 2-bit: 2-bit PolarQuant indices only (no QJL)
// Per block: norm(fp16) + 2-bit indices (32 bytes) = 34 bytes per 128 values
// = 2.125 bits/value -> ~7.5x compression vs fp16
// 4 centroids (Lloyd-Max for N(0, 1/128)): {-0.133462, -0.039994, 0.039994, 0.133462}
#define QK_TURBO2 128 // Block size 128: one block per rotation group
#define QK_TURBO2_GROUP 128 // rotation group size = head_dim
// Derived: FA template nl parameters (auto-scale with block size)
#define NL_TURBO2 (QK_TURBO2 / 16) // non-vec FA iterations per block
#define NL_TURBO2_VEC (QK_TURBO2 / 4) // vec FA iterations per block
typedef struct {
ggml_half norm; // 2 bytes: corrected L2 norm
uint8_t qs[QK_TURBO2 / 4]; // 32 bytes: 2-bit indices (4 per byte)
} block_turbo2_0; // 34 bytes total
static_assert(sizeof(block_turbo2_0) == sizeof(ggml_half) + QK_TURBO2/4, "wrong turbo2_0 block size/padding");

// TQ3_1S: WHT-rotated 3-bit weight quantization (8-level Lloyd-Max for N(0,1))
// Block size 32, dual half-block scales (d0 for [0..15], d1 for [16..31])
// Per block: d0(fp16) + d1(fp16) + 3-bit indices packed (12 bytes) = 16 bytes per 32 values
// = 4.0 bits/value
#define QK_TQ3_0 32
typedef struct {
ggml_half d0; // 2 bytes: scale for first 16 elements
ggml_half d1; // 2 bytes: scale for last 16 elements
uint8_t qs[QK_TQ3_0 * 3 / 8]; // 12 bytes: 3-bit indices packed (4 groups of 8 in 3 bytes)
} block_tq3_1s; // 16 bytes total
static_assert(sizeof(block_tq3_1s) == 16, "wrong tq3_1s block size");

// TQ4_1S: WHT-rotated 4-bit weight quantization (16-level Lloyd-Max for N(0,1))
// Block size 32, dual half-block scales (d0 for [0..15], d1 for [16..31])
// Per block: d0(fp16) + d1(fp16) + 4-bit indices packed (16 bytes) = 20 bytes per 32 values
// = 5.0 bits/value
#define QK_TQ4_1S 32
typedef struct {
ggml_half d0; // 2 bytes: scale for first 16 elements
ggml_half d1; // 2 bytes: scale for last 16 elements
uint8_t qs[QK_TQ4_1S / 2]; // 16 bytes: 4-bit indices nibble-packed
} block_tq4_1s; // 20 bytes total
static_assert(sizeof(block_tq4_1s) == 20, "wrong tq4_1s block size");

//
// Super-block quantization structures
//
17 changes: 16 additions & 1 deletion ggml/src/ggml-cuda/CMakeLists.txt
@@ -120,7 +120,22 @@ if (CUDAToolkit_FOUND)
template-instances/fattn-vec-instance-f16-f16.cu
template-instances/fattn-vec-instance-q4_0-q4_0.cu
template-instances/fattn-vec-instance-q8_0-q8_0.cu
template-instances/fattn-vec-instance-bf16-bf16.cu)
template-instances/fattn-vec-instance-bf16-bf16.cu
template-instances/fattn-vec-instance-turbo3_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo3_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo3_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-q8_0.cu
template-instances/fattn-vec-instance-q8_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo3_0.cu
template-instances/fattn-vec-instance-turbo3_0-turbo4_0.cu
template-instances/fattn-vec-instance-turbo4_0-turbo2_0.cu
template-instances/fattn-vec-instance-turbo2_0-turbo4_0.cu)
endif()

ggml_add_backend_library(ggml-cuda