Commit 8c363ed

[ROCm][Attention] Sliding window support for AiterFlashAttentionBackend (#29234)
Signed-off-by: ganyi <ygan@amd.com>
1 parent 64bc09b commit 8c363ed
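This commit adds sliding-window attention (SWA) support to the ROCm AITER flash-attention backend: the metadata builder records a single backend-wide sliding-window config, extend (chunked-prefill) requests gather only the in-window slice of the paged KV cache into a contiguous workspace before calling aiter.flash_attn_varlen_func with a window_size, and decode requests with a window are routed to AITER's Triton unified_attention kernel. As background (not part of the diff, and the helper name below is illustrative), the (left, right) window convention the backend uses can be sketched as:

```python
# Minimal sketch of the window_size convention adopted in this commit:
# a sliding window of W tokens means each query attends to itself plus the
# previous W - 1 tokens, so left = W - 1 and right = 0 for causal attention,
# while -1 on either side disables the window entirely.
def to_window_size(sliding_window: int | None) -> tuple[int, int]:
    if sliding_window is None:
        return (-1, -1)
    return (sliding_window - 1, 0)
```

The metadata builder asserts that at most one distinct, non-trivial sliding-window config exists across all attention layers before enabling this path.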

1 file changed: +224 −49 lines
vllm/v1/attention/backends/rocm_aiter_fa.py

Lines changed: 224 additions & 49 deletions
@@ -13,8 +13,9 @@
     AttentionType,
     MultipleOf,
 )
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
@@ -57,58 +58,55 @@ def cp_mha_gather_cache_kernel(
     head_size,
     x,
     max_block_num,
-    num_tokens,
-    num_programs,
     DEQUANT: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     CACHE_FORMAT: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    bid = tl.program_id(0)
+    token_id = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     if DEQUANT:
         k_scale = tl.load(k_scale_ptr)
         v_scale = tl.load(v_scale_ptr)
 
-    for token_id in tl.range(bid, num_tokens, num_programs):
-        key_ptr_offset = key_ptr + token_id * head_size * num_heads
-        value_ptr_offset = value_ptr + token_id * head_size * num_heads
-        batch_idx = tl.load(token_to_batch_ptr + token_id)
-        batch_start = tl.load(seq_start_ptr + batch_idx)
-        token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
-        batch_offset = token_id - token_start + batch_start
-        block_offset = batch_offset // PAGE_SIZE
-        block_id = tl.load(
-            block_table_ptr + max_block_num * batch_idx + block_offset
+    key_ptr_offset = key_ptr + token_id * head_size * num_heads
+    value_ptr_offset = value_ptr + token_id * head_size * num_heads
+    batch_idx = tl.load(token_to_batch_ptr + token_id)
+    batch_start = tl.load(seq_start_ptr + batch_idx)
+    token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
+    batch_offset = token_id - token_start + batch_start
+    block_offset = batch_offset // PAGE_SIZE
+    block_id = tl.load(
+        block_table_ptr + max_block_num * batch_idx + block_offset
+    ).to(tl.int64)
+    slot_id = batch_offset % PAGE_SIZE
+
+    if CACHE_FORMAT == "NHD":
+        # for kv cache layout as
+        # K: [num_blocks, page_size, num_head, head_dim]
+        # V: [num_blocks, page_size, num_head, head_dim]
+        key_cache_ptr_offset = (
+            key_cache_ptr
+            + block_id * num_heads * head_size * PAGE_SIZE
+            + slot_id * num_heads * head_size
+        )
+        value_cache_ptr_offset = (
+            value_cache_ptr
+            + block_id * num_heads * head_size * PAGE_SIZE
+            + slot_id * num_heads * head_size
         )
-        slot_id = batch_offset % PAGE_SIZE
-
-        if CACHE_FORMAT == "NHD":
-            # for kv cache layout as
-            # K: [num_blocks, page_size, num_head, head_dim]
-            # V: [num_blocks, page_size, num_head, head_dim]
-            key_cache_ptr_offset = (
-                key_cache_ptr
-                + block_id * num_heads * head_size * PAGE_SIZE
-                + slot_id * num_heads * head_size
-            )
-            value_cache_ptr_offset = (
-                value_cache_ptr
-                + block_id * num_heads * head_size * PAGE_SIZE
-                + slot_id * num_heads * head_size
-            )
 
-        for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
-            mask = (col_offsets + i) < head_size * num_heads
-            k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
-            v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
-            if DEQUANT:
-                k_dtype = k_reg.dtype
-                v_dtype = v_reg.dtype
-                k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-                v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
-            tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
-            tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+    for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
+        mask = (col_offsets + i) < head_size * num_heads
+        k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
+        v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
+        if DEQUANT:
+            k_dtype = k_reg.dtype
+            v_dtype = v_reg.dtype
+            k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
+            v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+        tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
+        tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
 
 def cp_mha_gather_cache(
     key_cache: torch.Tensor,
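The kernel above is restructured to launch one Triton program per gathered token instead of a fixed grid striding over all tokens. As a readability aid only (not part of the commit), the following is a rough pure-PyTorch reference of what the NHD-layout gather computes, assuming dequantization is disabled and caches shaped [num_blocks, page_size, num_heads, head_dim]; the function name and the Python loop are illustrative, the real work happens in the Triton kernel:

```python
import torch


def gather_cache_reference(
    key_cache: torch.Tensor,       # [num_blocks, page_size, num_heads, head_dim]
    value_cache: torch.Tensor,     # same layout as key_cache
    block_tables: torch.Tensor,    # [num_batches, max_blocks_per_seq]
    cu_seqlens_kv: torch.Tensor,   # [num_batches + 1] cumulative gathered lengths
    seq_starts: torch.Tensor,      # [num_batches] first cached position to gather
    token_to_batch: torch.Tensor,  # [total_tokens] batch index per gathered token
):
    page_size = key_cache.shape[1]
    total_tokens = int(cu_seqlens_kv[-1].item())
    key = torch.empty(
        (total_tokens, key_cache.shape[2], key_cache.shape[3]),
        dtype=key_cache.dtype,
        device=key_cache.device,
    )
    value = torch.empty_like(key)
    for token_id in range(total_tokens):
        batch_idx = int(token_to_batch[token_id])
        # Position of this token inside its request's cached sequence:
        # offset within the gathered window plus the window's starting position.
        batch_offset = (
            token_id - int(cu_seqlens_kv[batch_idx]) + int(seq_starts[batch_idx])
        )
        block_id = int(block_tables[batch_idx, batch_offset // page_size])
        slot_id = batch_offset % page_size
        key[token_id] = key_cache[block_id, slot_id]
        value[token_id] = value_cache[block_id, slot_id]
    return key, value
```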
@@ -143,9 +141,7 @@ def cp_mha_gather_cache(
     page_size = key_cache.shape[1]
     num_heads = key_cache.shape[2]
 
-    NUM_PRGMS = num_programs(total_tokens)
-    BLOCK_SIZE = block_size(key_cache, head_dim)
-    grid = lambda meta: (NUM_PRGMS,)
+    grid = lambda meta: (total_tokens,)
    cp_mha_gather_cache_kernel[grid](
         key_cache,
         value_cache,
@@ -161,12 +157,10 @@ def cp_mha_gather_cache(
         head_dim,
         x,
         block_tables.size(1),
-        total_tokens,
-        NUM_PRGMS,
         DEQUANT=dequant,
         PAGE_SIZE=page_size,
         CACHE_FORMAT=kv_cache_layout,
-        BLOCK_SIZE=BLOCK_SIZE,
+        BLOCK_SIZE=head_dim,
     )
 
 
@@ -189,6 +183,17 @@ class AiterFlashAttentionPrefillMetadata:
     query_start_loc: torch.Tensor
 
 
+@dataclass
+class AiterChunkSlidingWindowMetadata:
+    swa_seqlens: torch.Tensor
+    swa_cu_seqlens: torch.Tensor
+    swa_seq_starts: torch.Tensor
+    swa_token_to_batch: torch.Tensor
+    swa_max_seqlens: int
+    swa_total_tokens: int
+    swa_workspace: torch.Tensor
+
+
 @dataclass
 class AiterChunkContextMetadata:
     workspace: torch.Tensor
@@ -200,6 +205,7 @@ class AiterChunkContextMetadata:
     seq_lens: torch.Tensor
     num_chunks: int
     total_token_per_batch: list[int]
+    swa_metadata: AiterChunkSlidingWindowMetadata | None
 
 
 @dataclass
@@ -278,6 +284,20 @@ def __init__(
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
 
+        sliding_window_configs: set[tuple[int, int] | None] = set()
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer in layers.values():
+            assert isinstance(layer.impl, AiterFlashAttentionImpl)
+            sliding_window_configs.add(layer.impl.sliding_window)
+
+        while len(sliding_window_configs) > 0:
+            sliding_window_config = sliding_window_configs.pop()
+            if sliding_window_config is not None and sliding_window_config[0] != -1:
+                assert self.aot_sliding_window is None, (
+                    "Aiter Flash ATTENTION can only support one valid sliding window!"
+                )
+                self.aot_sliding_window = sliding_window_config
+
         self.extend_workspace = torch.empty(
             [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
             dtype=self.model_config.dtype,
@@ -349,6 +369,55 @@ def build(
             query_lens_for_extend = query_lens_cpu[num_extends_slice]
             seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
             computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
+            swa_metadata = None
+            if self.aot_sliding_window is not None:
+                swa_seqlen_for_extend = torch.minimum(
+                    seq_lens_for_extend,
+                    query_lens_for_extend + self.aot_sliding_window[0] + 1,
+                )
+                cu_seq_lens = torch.zeros(
+                    num_extends + 1,
+                    dtype=torch.int32,
+                    device=seq_lens_for_extend.device,
+                )
+                torch.cumsum(
+                    swa_seqlen_for_extend,
+                    dim=0,
+                    dtype=cu_seq_lens.dtype,
+                    out=cu_seq_lens[1:],
+                )
+                token_to_seq = torch.arange(
+                    0,
+                    num_extends,
+                    dtype=torch.int32,
+                    device=seq_lens_for_extend.device,
+                )
+                token_to_seq = torch.repeat_interleave(
+                    token_to_seq, swa_seqlen_for_extend
+                )
+                fetched_shape = cu_seq_lens[-1].item()
+                # TODO(ganyi): Maybe reuse these 2 buffer from extend_workspace
+                swa_workspace = torch.empty(
+                    (2, fetched_shape, self.num_heads_kv, self.headdim),
+                    dtype=self.vllm_config.model_config.dtype,
+                    device=self.device,
+                )
+
+                seq_starts = seq_lens_for_extend - swa_seqlen_for_extend
+                max_seqlen_k = swa_seqlen_for_extend.max().item()
+                total_tokens = cu_seq_lens[-1].item()
+
+                swa_metadata = AiterChunkSlidingWindowMetadata(
+                    swa_seqlens=swa_seqlen_for_extend.to(
+                        self.device, non_blocking=True
+                    ),
+                    swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True),
+                    swa_seq_starts=seq_starts.to(self.device, non_blocking=True),
+                    swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True),
+                    swa_max_seqlens=max_seqlen_k,
+                    swa_total_tokens=total_tokens,
+                    swa_workspace=swa_workspace,
+                )
 
             # allocate the equal amount of workspace for
             # each chunk prefill request
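To make the build-time arithmetic above concrete, here is a small worked example with made-up numbers (two extend requests and a 1024-token sliding window); it mirrors the tensor math on CPU only:

```python
import torch

# Illustrative values only: two extend requests, sliding window of 1024.
seq_lens_for_extend = torch.tensor([4096, 600])
query_lens_for_extend = torch.tensor([128, 200])
window_left = 1024 - 1  # corresponds to aot_sliding_window[0]

swa_seqlens = torch.minimum(
    seq_lens_for_extend, query_lens_for_extend + window_left + 1
)
# tensor([1152, 600]): request 0 is clipped to the window, request 1 is not.

cu_seqlens = torch.zeros(3, dtype=torch.int32)
torch.cumsum(swa_seqlens, dim=0, dtype=cu_seqlens.dtype, out=cu_seqlens[1:])
# tensor([0, 1152, 1752]): offsets of each request in the gathered workspace.

token_to_batch = torch.repeat_interleave(
    torch.arange(2, dtype=torch.int32), swa_seqlens
)
# 1152 zeros followed by 600 ones: maps every gathered token to its request.

seq_starts = seq_lens_for_extend - swa_seqlens
# tensor([2944, 0]): the first cached position each request needs to gather.
```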
@@ -392,6 +461,7 @@ def build(
                 token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
                 num_chunks=num_chunks,
                 total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
+                swa_metadata=swa_metadata,
             )
 
         query_start_loc_device = common_attn_metadata.query_start_loc[
@@ -504,9 +574,9 @@ def __init__(
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
-            self.sliding_window = [-1, -1]
+            self.sliding_window = (-1, -1)
         else:
-            self.sliding_window = [sliding_window - 1, 0]
+            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -522,6 +592,67 @@ def __init__(
                 "Encoder self-attention is not implemented for FlashAttentionImpl"
             )
 
+    def extend_for_sliding_window(
+        self,
+        attn_metadata: AiterFlashAttentionMetadata,
+        query: torch.Tensor,
+        key_cache,
+        value_cache,
+        output: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        max_seqlen_q: int,
+        block_table: torch.Tensor,
+        k_scale: float,
+        v_scale: float,
+    ):
+        assert attn_metadata.extend_metadata is not None
+        assert attn_metadata.extend_metadata.chunk_context_metadata is not None
+        chunked_metadata = attn_metadata.extend_metadata.chunk_context_metadata
+        swa_metadata = chunked_metadata.swa_metadata
+        assert swa_metadata is not None
+        swa_cu_seqlens = swa_metadata.swa_cu_seqlens
+        swa_seq_starts = swa_metadata.swa_seq_starts
+        swa_token_to_batch = swa_metadata.swa_token_to_batch
+        swa_max_seqlens = swa_metadata.swa_max_seqlens
+        swa_total_tokens = swa_metadata.swa_total_tokens
+        key_fetched, value_fetched = (
+            swa_metadata.swa_workspace[0],
+            swa_metadata.swa_workspace[1],
+        )
+        cp_mha_gather_cache(
+            key_cache=key_cache,
+            value_cache=value_cache,
+            key=key_fetched,
+            value=value_fetched,
+            block_tables=block_table,
+            k_scales=k_scale,
+            v_scales=v_scale,
+            cu_seqlens_kv=swa_cu_seqlens,
+            token_to_batch=swa_token_to_batch,
+            seq_starts=swa_seq_starts,
+            dequant=False,
+            kv_cache_layout="NHD",
+            total_tokens=swa_total_tokens,
+        )
+
+        aiter.flash_attn_varlen_func(
+            q=query,
+            k=key_fetched,
+            v=value_fetched,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=swa_cu_seqlens,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=swa_max_seqlens,
+            min_seqlen_q=1,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            return_lse=False,
+            out=output,
+        )
+
     def extend_forward(
         self,
         attn_metadata: AiterFlashAttentionMetadata,
@@ -540,6 +671,20 @@ def extend_forward(
         k_scale: float,
         v_scale: float,
     ):
+        if self.sliding_window[0] != -1:
+            self.extend_for_sliding_window(
+                attn_metadata,
+                query,
+                key_cache,
+                value_cache,
+                output,
+                cu_seqlens_q,
+                max_seqlen_q,
+                block_table,
+                k_scale,
+                v_scale,
+            )
+            return
         out, lse = aiter.flash_attn_varlen_func(
             q=query,
             k=key,
@@ -782,6 +927,36 @@ def forward(
         # calculate for decodes
         if num_decodes > 0:
             assert attn_metadata.decode_metadata is not None
+            if self.sliding_window[0] != -1:
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention,
+                )
+
+                descale_shape = (
+                    attn_metadata.query_start_loc[:num_decodes].shape[0] - 1,
+                    key_cache.shape[2],
+                )
+                unified_attention(
+                    q=query[:num_decode_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_decode_tokens],
+                    cu_seqlens_q=attn_metadata.query_start_loc[:num_decodes],
+                    max_seqlen_q=1,  # optimize this
+                    seqused_k=attn_metadata.seq_lens[:num_decodes],
+                    max_seqlen_k=attn_metadata.max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=attn_metadata.block_table[:num_decodes],
+                    softcap=self.logits_soft_cap,
+                    q_descale=None,
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return
+            assert attn_metadata.decode_metadata is not None
             _, num_heads, head_size = query.shape
             nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
             num_seqs = attn_metadata.seq_lens.shape[0]
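For decodes with a sliding window, the backend now calls AITER's Triton unified_attention kernel directly instead of the paged-decode path, passing the same window_size tuple plus the per-layer KV descale factors expanded to a 2-D view. A minimal sketch of that expand, with made-up sizes and assuming the layer scale is a scalar tensor as it is for per-tensor KV-cache scales:

```python
import torch

# Hypothetical sizes: 8 decode sequences, 4 KV heads.
descale_shape = (8, 4)
k_scale = torch.tensor(0.025)  # scalar per-layer scale stored on the attention layer

# expand() creates a broadcast view without copying data, so the kernel sees
# one descale value per (sequence, kv_head) pair at zero extra memory cost.
k_descale = k_scale.expand(descale_shape)
assert k_descale.shape == (8, 4)
```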
