
Commit 08abefa

Add PagedAttention support (experimental, CUDA only)
Implement the PagedAttention algorithm for memory-efficient KV cache management. The feature reduces memory fragmentation by storing the KV cache in fixed-size blocks (similar to virtual memory paging) and enables efficient memory sharing between sequences through copy-on-write semantics.

The implementation is experimental and disabled by default; enable it with the --pagedattention flag.

Signed-off-by: Eric Curtin <eric.curtin@docker.com>
1 parent 7f8ef50 commit 08abefa

19 files changed (+2274, -3 lines)
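
The commit message describes the paging scheme only at a high level. As a rough, standalone illustration of that idea (not code from this commit), the C++ sketch below shows how fixed-size blocks, per-sequence block tables, and copy-on-write reference counting might fit together; every type and function name here is hypothetical.

// Illustrative sketch only: a paged KV cache keeps fixed-size blocks, each sequence
// holds a block table (logical block index -> physical block id), and blocks are
// shared between sequences via reference counts (copy-on-write). Not the commit's code.
#include <cstdint>
#include <vector>

struct BlockAllocator {
    std::vector<int32_t> ref_count;   // references per physical block
    std::vector<int32_t> free_list;   // ids of currently unused blocks

    explicit BlockAllocator(int32_t n_blocks) : ref_count(n_blocks, 0) {
        for (int32_t i = n_blocks - 1; i >= 0; --i) {
            free_list.push_back(i);
        }
    }

    // Grab a free physical block (sketch: assumes one is available).
    int32_t alloc() {
        const int32_t b = free_list.back();
        free_list.pop_back();
        ref_count[b] = 1;
        return b;
    }

    // Drop a reference; recycle the block once nobody uses it.
    void release(int32_t b) {
        if (--ref_count[b] == 0) {
            free_list.push_back(b);
        }
    }
};

struct Sequence {
    std::vector<int32_t> block_table; // logical block index -> physical block id
    int32_t n_tokens = 0;             // tokens cached so far
};

// Append one token's KV entry: a new block is needed only when the last one is full.
static void append_token(Sequence & seq, BlockAllocator & alloc, int32_t block_size) {
    if (seq.n_tokens % block_size == 0) {
        seq.block_table.push_back(alloc.alloc());
    }
    seq.n_tokens++;
}

// Fork a sequence (e.g. for parallel sampling): blocks are shared, refcounts bumped.
// A full implementation would copy a shared block before writing to it (copy-on-write).
static Sequence fork_sequence(const Sequence & parent, BlockAllocator & alloc) {
    Sequence child = parent;
    for (const int32_t b : child.block_table) {
        alloc.ref_count[b]++;
    }
    return child;
}

Forking only bumps reference counts, so sequences with identical prefixes share physical blocks until one of them writes to a shared block.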

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -1017,6 +1017,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
             }
         }).set_env("LLAMA_ARG_FLASH_ATTN"));
+    add_opt(common_arg(
+        {"--pagedattention"},
+        "enable PagedAttention for KV cache (experimental, requires CUDA)",
+        [](common_params & params) {
+            params.use_paged_attention = true;
+        }
+    ));
     add_opt(common_arg(
         {"-p", "--prompt"}, "PROMPT",
         "prompt to start generation with; for system message, use -sys",

common/common.cpp

Lines changed: 1 addition & 0 deletions
@@ -1275,6 +1275,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
     cparams.op_offload          = !params.no_op_offload;
     cparams.swa_full            = params.swa_full;
     cparams.kv_unified          = params.kv_unified;
+    cparams.use_paged_attention = params.use_paged_attention;
 
     cparams.type_k = params.cache_type_k;
     cparams.type_v = params.cache_type_v;

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -406,6 +406,7 @@ struct common_params {
     bool ctx_shift           = false; // context shift on infinite text generation
     bool swa_full            = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
     bool kv_unified          = false; // enable unified KV cache
+    bool use_paged_attention = false; // enable PagedAttention (experimental, requires CUDA)
 
     bool input_prefix_bos    = false; // prefix BOS to user inputs, preceding input_prefix
     bool use_mmap            = true;  // use mmap for faster loads

ggml/include/ggml.h

Lines changed: 17 additions & 0 deletions
@@ -537,6 +537,7 @@ extern "C" {
 
         GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_ATTN_BACK,
+        GGML_OP_PAGED_ATTENTION,
         GGML_OP_SSM_CONV,
         GGML_OP_SSM_SCAN,
         GGML_OP_WIN_PART,
@@ -2312,6 +2313,22 @@ extern "C" {
             struct ggml_tensor * a,
             struct ggml_tensor * sinks);
 
+    // PagedAttention (paged KV cache attention)
+    // q: [n_tokens, n_heads, head_size]
+    // k_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
+    // v_cache: [num_blocks, block_size, n_kv_heads, head_size] (paged)
+    // block_tables: [n_seqs, max_blocks_per_seq] (int32)
+    // seq_lens: [n_seqs] (int32)
+    GGML_API struct ggml_tensor * ggml_paged_attention(
+            struct ggml_context * ctx,
+            struct ggml_tensor * q,
+            struct ggml_tensor * k_cache,
+            struct ggml_tensor * v_cache,
+            struct ggml_tensor * block_tables,
+            struct ggml_tensor * seq_lens,
+            int32_t block_size,
+            float scale);
+
     // TODO: needs to be adapted to ggml_flash_attn_ext
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
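
The prototype fixes the operand list, but the ggml axis order and op_params packing are only pinned down by the CUDA backend code further below. A minimal, hedged usage sketch follows; the dimension order is inferred from how that backend reads the tensor shapes, and the sizes, types, and helper name are illustrative assumptions rather than the commit's actual call site.

// Hedged usage sketch (not from this commit): building a GGML_OP_PAGED_ATTENTION node.
// Axis order is an assumption based on how the CUDA backend reads the dimensions
// (q->ne[0] = head_size, k_cache->ne[0] = num_blocks, block_tables->ne[0] = max_blocks_per_seq).
#include <cmath>
#include "ggml.h"

static struct ggml_tensor * build_paged_attention_node(struct ggml_context * ctx) {
    const int64_t head_size          = 128;
    const int64_t n_heads            = 32;
    const int64_t n_kv_heads         = 8;
    const int64_t n_tokens           = 1;    // decode: one query token per sequence
    const int64_t n_seqs             = 4;
    const int64_t num_blocks         = 1024; // physical KV cache blocks
    const int64_t max_blocks_per_seq = 64;
    const int32_t block_size         = 16;   // tokens per block (8/16/32 are accepted)

    // query vectors for each token and head
    struct ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            head_size, n_heads, n_tokens, n_seqs);

    // paged key/value caches, addressed indirectly through the block tables
    struct ggml_tensor * k_cache = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            num_blocks, block_size, n_kv_heads, head_size);
    struct ggml_tensor * v_cache = ggml_new_tensor_4d(ctx, GGML_TYPE_F16,
            num_blocks, block_size, n_kv_heads, head_size);

    // per-sequence mapping from logical block index to physical block id
    struct ggml_tensor * block_tables = ggml_new_tensor_2d(ctx, GGML_TYPE_I32,
            max_blocks_per_seq, n_seqs);

    // number of cached tokens per sequence
    struct ggml_tensor * seq_lens = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs);

    const float scale = 1.0f / std::sqrt((float) head_size);

    return ggml_paged_attention(ctx, q, k_cache, v_cache,
            block_tables, seq_lens, block_size, scale);
}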

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 4 additions & 0 deletions
@@ -2062,6 +2062,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 // nop
             } break;
+        case GGML_OP_PAGED_ATTENTION:
+            {
+                // nop (CUDA-only operation)
+            } break;
         case GGML_OP_COUNT:
             {
                 GGML_ABORT("fatal error");
Lines changed: 214 additions & 0 deletions
@@ -0,0 +1,214 @@
+/**
+ * GGML CUDA Backend for PagedAttention
+ *
+ * This file provides the CUDA backend implementation for the GGML_OP_PAGED_ATTENTION operation.
+ * It bridges GGML's operation framework with the PagedAttention CUDA kernels.
+ *
+ * NOTE: PagedAttention is currently experimental and only supported on CUDA.
+ * MUSA support is disabled due to compiler compatibility issues.
+ */
+
+// PagedAttention is not yet supported on MUSA
+#ifndef GGML_USE_MUSA
+
+#include "common.cuh"
+#include "paged-attention.cuh"
+
+// Extract parameters from GGML tensor
+static void ggml_cuda_op_paged_attention_get_params(
+        const ggml_tensor * dst,
+        float * scale,
+        int32_t * block_size) {
+
+    const float * params = (const float *)dst->op_params;
+    *scale = params[0];
+    *block_size = (int32_t)params[1];
+}
+
+// Main CUDA backend function for PagedAttention
+void ggml_cuda_op_paged_attention(
+        ggml_backend_cuda_context & ctx,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * q            = dst->src[0]; // query
+    const ggml_tensor * k_cache      = dst->src[1]; // key cache (paged)
+    const ggml_tensor * v_cache      = dst->src[2]; // value cache (paged)
+    const ggml_tensor * block_tables = dst->src[3]; // block tables
+    const ggml_tensor * seq_lens     = dst->src[4]; // sequence lengths
+    const ggml_tensor * alibi_slopes = dst->src[5]; // optional ALiBi slopes (can be nullptr)
+
+    // Extract parameters
+    float scale;
+    int32_t block_size;
+    ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);
+
+    // Get tensor dimensions
+    const int64_t head_size = q->ne[0];
+    const int64_t n_heads   = q->ne[1];
+    const int64_t n_tokens  = q->ne[2];
+    const int64_t n_seqs    = q->ne[3];
+
+    const int64_t n_kv_heads = k_cache->ne[2];
+    const int64_t num_blocks = k_cache->ne[0];
+
+    const int64_t max_blocks_per_seq = block_tables->ne[0];
+
+    // Validate tensor dimensions
+    GGML_ASSERT(n_tokens > 0 && "Number of query tokens must be positive");
+    GGML_ASSERT(n_seqs > 0 && "Number of sequences must be positive");
+    GGML_ASSERT(num_blocks > 0 && "Number of KV cache blocks must be positive");
+    GGML_ASSERT(max_blocks_per_seq > 0 && "Max blocks per sequence must be positive");
+
+    // Validate that we have enough blocks available
+    // Note: This is a soft check - actual usage depends on sequence lengths
+    GGML_ASSERT(num_blocks >= max_blocks_per_seq &&
+                "Total number of blocks should be >= max blocks per sequence");
+
+    // For PagedAttention, typically we have one query per sequence (decode mode)
+    // or multiple queries per sequence (prefill mode)
+    GGML_ASSERT(n_tokens <= n_seqs * 1024 &&
+                "Number of tokens seems unusually large relative to batch size");
+
+    // Get pointers
+    void * out_ptr = dst->data;
+    const void * q_ptr = q->data;
+    const void * k_cache_ptr = k_cache->data;
+    const void * v_cache_ptr = v_cache->data;
+    const int32_t * block_tables_ptr = (const int32_t *)block_tables->data;
+    const int32_t * seq_lens_ptr = (const int32_t *)seq_lens->data;
+
+    // Get ALiBi slopes pointer if provided
+    const float * alibi_slopes_ptr = nullptr;
+    if (alibi_slopes != nullptr) {
+        // ALiBi slopes should be a 1D tensor with one slope per attention head
+        GGML_ASSERT(alibi_slopes->type == GGML_TYPE_F32 &&
+                    "ALiBi slopes must be float32");
+        GGML_ASSERT(alibi_slopes->ne[0] == n_heads &&
+                    "ALiBi slopes tensor must have one value per head");
+        alibi_slopes_ptr = (const float *)alibi_slopes->data;
+    }
+
+    // Calculate max sequence length (needed to decide V1 vs V2)
+    int max_seq_len = 0;
+    for (int i = 0; i < n_seqs; i++) {
+        if (seq_lens_ptr[i] > max_seq_len) {
+            max_seq_len = seq_lens_ptr[i];
+        }
+    }
+
+    // Get CUDA stream
+    cudaStream_t stream = ctx.stream();
+
+    // Decide whether to use V1 or V2
+    const bool use_v1 = ggml_cuda_paged_attention::should_use_v1(
+        max_seq_len, n_seqs, n_heads);
+
+    // Launch appropriate kernel
+    if (use_v1) {
+        ggml_cuda_paged_attention::paged_attention_v1_launcher(
+            out_ptr,
+            q_ptr,
+            k_cache_ptr,
+            v_cache_ptr,
+            n_seqs,
+            n_heads,
+            n_kv_heads,
+            head_size,
+            block_size,
+            max_blocks_per_seq,
+            block_tables_ptr,
+            seq_lens_ptr,
+            max_seq_len,
+            scale,
+            alibi_slopes_ptr,
+            q->type,
+            k_cache->type,
+            stream);
+    } else {
+        ggml_cuda_paged_attention::paged_attention_v2_launcher(
+            out_ptr,
+            q_ptr,
+            k_cache_ptr,
+            v_cache_ptr,
+            n_seqs,
+            n_heads,
+            n_kv_heads,
+            head_size,
+            block_size,
+            max_blocks_per_seq,
+            block_tables_ptr,
+            seq_lens_ptr,
+            max_seq_len,
+            scale,
+            alibi_slopes_ptr,
+            q->type,
+            k_cache->type,
+            ctx.pool(),
+            stream);
+    }
+
+    // Check for errors
+    CUDA_CHECK(cudaGetLastError());
+}
+
+// Check if PagedAttention is supported for given configuration
+bool ggml_cuda_can_paged_attention(const ggml_tensor * dst) {
+    const ggml_tensor * q       = dst->src[0];
+    const ggml_tensor * k_cache = dst->src[1];
+
+    // Check data types
+    if (q->type != GGML_TYPE_F16 && q->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    if (k_cache->type != GGML_TYPE_F16 && k_cache->type != GGML_TYPE_F32) {
+        return false;
+    }
+
+    // Check head size is supported
+    const int64_t head_size = q->ne[0];
+    const int supported_head_sizes[] = {32, 64, 80, 96, 112, 120, 128, 192, 256};
+    bool head_size_supported = false;
+
+    for (int hs : supported_head_sizes) {
+        if (head_size == hs) {
+            head_size_supported = true;
+            break;
+        }
+    }
+
+    if (!head_size_supported) {
+        return false;
+    }
+
+    // Extract block size and check it's supported
+    float scale;
+    int32_t block_size;
+    ggml_cuda_op_paged_attention_get_params(dst, &scale, &block_size);
+
+    if (block_size != 8 && block_size != 16 && block_size != 32) {
+        return false;
+    }
+
+    return true;
+}
+
+#else // GGML_USE_MUSA
+
+// Stub implementations for MUSA (PagedAttention not yet supported)
+#include "common.cuh"
+
+void ggml_cuda_op_paged_attention(
+        ggml_backend_cuda_context & ctx,
+        ggml_tensor * dst) {
+    GGML_UNUSED(ctx);
+    GGML_UNUSED(dst);
+    GGML_ABORT("PagedAttention is not yet supported on MUSA");
+}
+
+bool ggml_cuda_supports_paged_attention(const ggml_tensor * dst) {
+    GGML_UNUSED(dst);
+    return false;
+}
+
+#endif // GGML_USE_MUSA
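
The kernel launchers above consume block_tables and seq_lens as flat int32 buffers. As a minimal sketch (not part of the commit), the helper below shows one way a caller might pack per-sequence block lists and token counts into those buffers, assuming each sequence's block-table row is contiguous with max_blocks_per_seq as the fastest-varying dimension, matching how the backend reads block_tables->ne[0]; the struct and helper names are hypothetical.

// Hedged sketch: flatten per-sequence paging metadata into the int32 buffers that a
// GGML_OP_PAGED_ATTENTION node would carry. Layout assumption: one contiguous row of
// max_blocks_per_seq entries per sequence. Not the commit's code.
#include <cstdint>
#include <vector>

struct SeqState {
    std::vector<int32_t> blocks;  // physical block ids, in logical order
    int32_t n_tokens;             // tokens currently cached for this sequence
};

static void pack_paged_metadata(
        const std::vector<SeqState> & seqs,
        int32_t                       max_blocks_per_seq,
        std::vector<int32_t>        & block_tables,   // size: n_seqs * max_blocks_per_seq
        std::vector<int32_t>        & seq_lens) {     // size: n_seqs
    const size_t n_seqs = seqs.size();
    block_tables.assign(n_seqs * (size_t) max_blocks_per_seq, 0);
    seq_lens.resize(n_seqs);

    for (size_t s = 0; s < n_seqs; ++s) {
        seq_lens[s] = seqs[s].n_tokens;
        for (size_t b = 0; b < seqs[s].blocks.size(); ++b) {
            // row s holds the block table of sequence s; unused tail entries stay 0
            block_tables[s * (size_t) max_blocks_per_seq + b] = seqs[s].blocks[b];
        }
    }
}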
