     AttentionType,
     MultipleOf,
 )
+from vllm.attention.layer import Attention
 from vllm.attention.ops.merge_attn_states import merge_attn_states
-from vllm.config import VllmConfig
+from vllm.config import VllmConfig, get_layers_from_vllm_config
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils.math_utils import cdiv
@@ -57,58 +58,55 @@ def cp_mha_gather_cache_kernel(
     head_size,
     x,
     max_block_num,
-    num_tokens,
-    num_programs,
     DEQUANT: tl.constexpr,
     PAGE_SIZE: tl.constexpr,
     CACHE_FORMAT: tl.constexpr,
     BLOCK_SIZE: tl.constexpr,
 ):
-    bid = tl.program_id(0)
+    token_id = tl.program_id(0)
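+    # Each program handles exactly one gathered token (the launch grid is
+    # sized to the total token count) and copies that token's K/V row from
+    # the paged cache into the contiguous key/value output buffers.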
     col_offsets = tl.arange(0, BLOCK_SIZE)
     if DEQUANT:
         k_scale = tl.load(k_scale_ptr)
         v_scale = tl.load(v_scale_ptr)
 
-    for token_id in tl.range(bid, num_tokens, num_programs):
-        key_ptr_offset = key_ptr + token_id * head_size * num_heads
-        value_ptr_offset = value_ptr + token_id * head_size * num_heads
-        batch_idx = tl.load(token_to_batch_ptr + token_id)
-        batch_start = tl.load(seq_start_ptr + batch_idx)
-        token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
-        batch_offset = token_id - token_start + batch_start
-        block_offset = batch_offset // PAGE_SIZE
-        block_id = tl.load(
-            block_table_ptr + max_block_num * batch_idx + block_offset
+    key_ptr_offset = key_ptr + token_id * head_size * num_heads
+    value_ptr_offset = value_ptr + token_id * head_size * num_heads
+    batch_idx = tl.load(token_to_batch_ptr + token_id)
+    batch_start = tl.load(seq_start_ptr + batch_idx)
+    token_start = tl.load(cu_seqlens_kv_ptr + batch_idx)
+    batch_offset = token_id - token_start + batch_start
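+    # Map the token's position within its sequence to a physical page (via
+    # the block table) and a slot inside that page.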
+    block_offset = batch_offset // PAGE_SIZE
+    block_id = tl.load(
+        block_table_ptr + max_block_num * batch_idx + block_offset
+    ).to(tl.int64)
+    slot_id = batch_offset % PAGE_SIZE
+
+    if CACHE_FORMAT == "NHD":
+        # for kv cache layout as
+        # K: [num_blocks, page_size, num_head, head_dim]
+        # V: [num_blocks, page_size, num_head, head_dim]
+        key_cache_ptr_offset = (
+            key_cache_ptr
+            + block_id * num_heads * head_size * PAGE_SIZE
+            + slot_id * num_heads * head_size
+        )
+        value_cache_ptr_offset = (
+            value_cache_ptr
+            + block_id * num_heads * head_size * PAGE_SIZE
+            + slot_id * num_heads * head_size
         )
-        slot_id = batch_offset % PAGE_SIZE
-
-        if CACHE_FORMAT == "NHD":
-            # for kv cache layout as
-            # K: [num_blocks, page_size, num_head, head_dim]
-            # V: [num_blocks, page_size, num_head, head_dim]
-            key_cache_ptr_offset = (
-                key_cache_ptr
-                + block_id * num_heads * head_size * PAGE_SIZE
-                + slot_id * num_heads * head_size
-            )
-            value_cache_ptr_offset = (
-                value_cache_ptr
-                + block_id * num_heads * head_size * PAGE_SIZE
-                + slot_id * num_heads * head_size
-            )
 
-        for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
-            mask = (col_offsets + i) < head_size * num_heads
-            k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
-            v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
-            if DEQUANT:
-                k_dtype = k_reg.dtype
-                v_dtype = v_reg.dtype
-                k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
-                v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
-            tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
-            tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
+    for i in tl.range(0, head_size * num_heads, BLOCK_SIZE):
+        mask = (col_offsets + i) < head_size * num_heads
+        k_reg = tl.load(key_cache_ptr_offset + col_offsets + i, mask=mask)
+        v_reg = tl.load(value_cache_ptr_offset + col_offsets + i, mask=mask)
+        if DEQUANT:
+            k_dtype = k_reg.dtype
+            v_dtype = v_reg.dtype
+            k_reg = (k_reg.to(tl.float32) * k_scale).to(k_dtype)
+            v_reg = (v_reg.to(tl.float32) * v_scale).to(v_dtype)
+        tl.store(key_ptr_offset + col_offsets + i, k_reg, mask=mask)
+        tl.store(value_ptr_offset + col_offsets + i, v_reg, mask=mask)
 
 def cp_mha_gather_cache(
     key_cache: torch.Tensor,
@@ -143,9 +141,7 @@ def cp_mha_gather_cache(
     page_size = key_cache.shape[1]
     num_heads = key_cache.shape[2]
 
-    NUM_PRGMS = num_programs(total_tokens)
-    BLOCK_SIZE = block_size(key_cache, head_dim)
-    grid = lambda meta: (NUM_PRGMS,)
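+    # One kernel program per gathered token; BLOCK_SIZE (= head_dim) makes
+    # each inner loop iteration of the kernel copy one head's worth of values.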
+    grid = lambda meta: (total_tokens,)
     cp_mha_gather_cache_kernel[grid](
         key_cache,
         value_cache,
@@ -161,12 +157,10 @@ def cp_mha_gather_cache(
         head_dim,
         x,
         block_tables.size(1),
-        total_tokens,
-        NUM_PRGMS,
         DEQUANT=dequant,
         PAGE_SIZE=page_size,
         CACHE_FORMAT=kv_cache_layout,
-        BLOCK_SIZE=BLOCK_SIZE,
+        BLOCK_SIZE=head_dim,
     )
 
@@ -189,6 +183,17 @@ class AiterFlashAttentionPrefillMetadata:
     query_start_loc: torch.Tensor
 
 
+@dataclass
+class AiterChunkSlidingWindowMetadata:
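+    # Sliding-window view of the context for extend (chunked-prefill) requests:
+    # per-request KV lengths clipped to the window, their cumulative offsets,
+    # where each window starts inside its sequence, a token-to-request map,
+    # and a gather workspace sized to the total number of windowed tokens.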
+    swa_seqlens: torch.Tensor
+    swa_cu_seqlens: torch.Tensor
+    swa_seq_starts: torch.Tensor
+    swa_token_to_batch: torch.Tensor
+    swa_max_seqlens: int
+    swa_total_tokens: int
+    swa_workspace: torch.Tensor
+
+
 @dataclass
 class AiterChunkContextMetadata:
     workspace: torch.Tensor
@@ -200,6 +205,7 @@ class AiterChunkContextMetadata:
     seq_lens: torch.Tensor
     num_chunks: int
     total_token_per_batch: list[int]
+    swa_metadata: AiterChunkSlidingWindowMetadata | None
 
 
 @dataclass
@@ -278,6 +284,20 @@ def __init__(
         self.aot_sliding_window: tuple[int, int] | None = None
         self.total_tokens: int = 0
 
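+        # Collect the sliding-window configuration of every attention layer;
+        # the ahead-of-time (AOT) sliding-window path supports only a single
+        # non-trivial window size across the model.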
+        sliding_window_configs: set[tuple[int, int] | None] = set()
+        layers = get_layers_from_vllm_config(self.vllm_config, Attention)
+        for layer in layers.values():
+            assert isinstance(layer.impl, AiterFlashAttentionImpl)
+            sliding_window_configs.add(layer.impl.sliding_window)
+
+        while len(sliding_window_configs) > 0:
+            sliding_window_config = sliding_window_configs.pop()
+            if sliding_window_config is not None and sliding_window_config[0] != -1:
+                assert self.aot_sliding_window is None, (
+                    "Aiter Flash ATTENTION can only support one valid sliding window!"
+                )
+                self.aot_sliding_window = sliding_window_config
+
         self.extend_workspace = torch.empty(
             [2, _CP_TOKENS_PER_ITER_ROCM, self.num_heads_kv, self.headdim],
             dtype=self.model_config.dtype,
@@ -349,6 +369,55 @@ def build(
         query_lens_for_extend = query_lens_cpu[num_extends_slice]
         seq_lens_for_extend = common_attn_metadata.seq_lens_cpu[num_extends_slice]
         computed_kv_lens = seq_lens_for_extend - query_lens_for_extend
+        swa_metadata = None
+        if self.aot_sliding_window is not None:
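+            # Each extend request only needs the KV that falls inside the
+            # sliding window: the new query tokens plus at most one window of
+            # history, clamped by the full sequence length.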
+            swa_seqlen_for_extend = torch.minimum(
+                seq_lens_for_extend,
+                query_lens_for_extend + self.aot_sliding_window[0] + 1,
+            )
+            cu_seq_lens = torch.zeros(
+                num_extends + 1,
+                dtype=torch.int32,
+                device=seq_lens_for_extend.device,
+            )
+            torch.cumsum(
+                swa_seqlen_for_extend,
+                dim=0,
+                dtype=cu_seq_lens.dtype,
+                out=cu_seq_lens[1:],
+            )
+            token_to_seq = torch.arange(
+                0,
+                num_extends,
+                dtype=torch.int32,
+                device=seq_lens_for_extend.device,
+            )
+            token_to_seq = torch.repeat_interleave(
+                token_to_seq, swa_seqlen_for_extend
+            )
+            fetched_shape = cu_seq_lens[-1].item()
+            # TODO(ganyi): Maybe reuse these 2 buffers from extend_workspace
+            swa_workspace = torch.empty(
+                (2, fetched_shape, self.num_heads_kv, self.headdim),
+                dtype=self.vllm_config.model_config.dtype,
+                device=self.device,
+            )
+
+            seq_starts = seq_lens_for_extend - swa_seqlen_for_extend
+            max_seqlen_k = swa_seqlen_for_extend.max().item()
+            total_tokens = cu_seq_lens[-1].item()
+
+            swa_metadata = AiterChunkSlidingWindowMetadata(
+                swa_seqlens=swa_seqlen_for_extend.to(
+                    self.device, non_blocking=True
+                ),
+                swa_cu_seqlens=cu_seq_lens.to(self.device, non_blocking=True),
+                swa_seq_starts=seq_starts.to(self.device, non_blocking=True),
+                swa_token_to_batch=token_to_seq.to(self.device, non_blocking=True),
+                swa_max_seqlens=max_seqlen_k,
+                swa_total_tokens=total_tokens,
+                swa_workspace=swa_workspace,
+            )
 
         # allocate the equal amount of workspace for
         # each chunk prefill request
@@ -392,6 +461,7 @@ def build(
             token_to_batch=token_to_batch_tensor.to(self.device, non_blocking=True),
             num_chunks=num_chunks,
             total_token_per_batch=cu_seq_lens_cpu[:, -1].tolist(),
+            swa_metadata=swa_metadata,
         )
 
         query_start_loc_device = common_attn_metadata.query_start_loc[
@@ -504,9 +574,9 @@ def __init__(
             alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
-            self.sliding_window = [-1, -1]
+            self.sliding_window = (-1, -1)
         else:
-            self.sliding_window = [sliding_window - 1, 0]
+            self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
         if logits_soft_cap is None:
             # In flash-attn, setting logits_soft_cap as 0 means no soft cap.
@@ -522,6 +592,67 @@ def __init__(
522592 "Encoder self-attention is not implemented for FlashAttentionImpl"
523593 )
524594
595+ def extend_for_sliding_window (
596+ self ,
597+ attn_metadata : AiterFlashAttentionMetadata ,
598+ query : torch .Tensor ,
599+ key_cache ,
600+ value_cache ,
601+ output : torch .Tensor ,
602+ cu_seqlens_q : torch .Tensor ,
603+ max_seqlen_q : int ,
604+ block_table : torch .Tensor ,
605+ k_scale : float ,
606+ v_scale : float ,
607+ ):
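+        # For sliding-window layers, gather only the in-window K/V from the
+        # paged cache into a contiguous workspace, then run AITER varlen flash
+        # attention over that window.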
+        assert attn_metadata.extend_metadata is not None
+        assert attn_metadata.extend_metadata.chunk_context_metadata is not None
+        chunked_metadata = attn_metadata.extend_metadata.chunk_context_metadata
+        swa_metadata = chunked_metadata.swa_metadata
+        assert swa_metadata is not None
+        swa_cu_seqlens = swa_metadata.swa_cu_seqlens
+        swa_seq_starts = swa_metadata.swa_seq_starts
+        swa_token_to_batch = swa_metadata.swa_token_to_batch
+        swa_max_seqlens = swa_metadata.swa_max_seqlens
+        swa_total_tokens = swa_metadata.swa_total_tokens
+        key_fetched, value_fetched = (
+            swa_metadata.swa_workspace[0],
+            swa_metadata.swa_workspace[1],
+        )
+        cp_mha_gather_cache(
+            key_cache=key_cache,
+            value_cache=value_cache,
+            key=key_fetched,
+            value=value_fetched,
+            block_tables=block_table,
+            k_scales=k_scale,
+            v_scales=v_scale,
+            cu_seqlens_kv=swa_cu_seqlens,
+            token_to_batch=swa_token_to_batch,
+            seq_starts=swa_seq_starts,
+            dequant=False,
+            kv_cache_layout="NHD",
+            total_tokens=swa_total_tokens,
+        )
+
+        aiter.flash_attn_varlen_func(
+            q=query,
+            k=key_fetched,
+            v=value_fetched,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=swa_cu_seqlens,
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=swa_max_seqlens,
+            min_seqlen_q=1,
+            dropout_p=0.0,
+            softmax_scale=self.scale,
+            causal=True,
+            window_size=self.sliding_window,
+            alibi_slopes=self.alibi_slopes,
+            return_lse=False,
+            out=output,
+        )
+
     def extend_forward(
         self,
         attn_metadata: AiterFlashAttentionMetadata,
@@ -540,6 +671,20 @@ def extend_forward(
         k_scale: float,
         v_scale: float,
     ):
+        if self.sliding_window[0] != -1:
+            self.extend_for_sliding_window(
+                attn_metadata,
+                query,
+                key_cache,
+                value_cache,
+                output,
+                cu_seqlens_q,
+                max_seqlen_q,
+                block_table,
+                k_scale,
+                v_scale,
+            )
+            return
         out, lse = aiter.flash_attn_varlen_func(
             q=query,
             k=key,
@@ -782,6 +927,36 @@ def forward(
         # calculate for decodes
         if num_decodes > 0:
             assert attn_metadata.decode_metadata is not None
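+            # Sliding-window decode goes through the Triton unified-attention
+            # kernel, which reads K/V directly from the paged cache and
+            # applies the window inside the kernel.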
+            if self.sliding_window[0] != -1:
+                from aiter.ops.triton.unified_attention import (
+                    unified_attention,
+                )
+
+                descale_shape = (
+                    attn_metadata.query_start_loc[: num_decodes + 1].shape[0] - 1,
+                    key_cache.shape[2],
+                )
+                unified_attention(
+                    q=query[:num_decode_tokens],
+                    k=key_cache,
+                    v=value_cache,
+                    out=output[:num_decode_tokens],
+                    cu_seqlens_q=attn_metadata.query_start_loc[: num_decodes + 1],
+                    max_seqlen_q=1,  # optimize this
+                    seqused_k=attn_metadata.seq_lens[:num_decodes],
+                    max_seqlen_k=attn_metadata.max_seq_len,
+                    softmax_scale=self.scale,
+                    causal=True,
+                    alibi_slopes=self.alibi_slopes,
+                    window_size=self.sliding_window,
+                    block_table=attn_metadata.block_table[:num_decodes],
+                    softcap=self.logits_soft_cap,
+                    q_descale=None,
+                    k_descale=layer._k_scale.expand(descale_shape),
+                    v_descale=layer._v_scale.expand(descale_shape),
+                )
+                return
+            assert attn_metadata.decode_metadata is not None
             _, num_heads, head_size = query.shape
             nbytes_per_qo_elem = torch.finfo(query.dtype).bits // 8
             num_seqs = attn_metadata.seq_lens.shape[0]