 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
                 ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None
 
         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,40 +146,64 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
                                             request_distribution,
                                             sharding=dp_sharding)
 
-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid
 
         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states
 
         with self.runner.maybe_select_dummy_loras(
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
                 self.runner.state,
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
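
To make the intent of the hunk above easier to follow, here is a minimal, self-contained Python sketch of the same pattern: build one AttentionMetadata per KV-cache group, then either fan it out to every layer in that group (hybrid KV cache) or keep a single shared object (uniform case). The dataclasses, field names, and the use_hybrid_kvcache flag below are simplified stand-ins for illustration, not the runner's real types.

# Simplified sketch only: `KVCacheGroup` and `AttentionMetadata` here are
# stand-in dataclasses, not the actual classes touched by the diff above.
from dataclasses import dataclass
from typing import Dict, List, Optional, Union


@dataclass
class AttentionMetadata:
    block_tables: List[int]


@dataclass
class KVCacheGroup:
    layer_names: List[str]
    block_tables: List[int]


def build_attention_metadata(
    groups: List[KVCacheGroup],
    use_hybrid_kvcache: bool,
) -> Union[AttentionMetadata, Dict[str, AttentionMetadata]]:
    # One metadata object per KV-cache group; the hybrid case returns a
    # per-layer dict, the uniform case returns the single shared object.
    per_layer: Dict[str, AttentionMetadata] = {}
    uniform: Optional[AttentionMetadata] = None
    for group in groups:
        metadata = AttentionMetadata(block_tables=group.block_tables)
        if not use_hybrid_kvcache:
            uniform = metadata
        else:
            for layer_name in group.layer_names:
                per_layer[layer_name] = metadata
    return per_layer if use_hybrid_kvcache else uniform


# Example: two groups, e.g. full-attention vs. sliding-window layers.
groups = [
    KVCacheGroup(layer_names=["layers.0", "layers.2"], block_tables=[0, 1]),
    KVCacheGroup(layer_names=["layers.1", "layers.3"], block_tables=[2, 3]),
]
print(build_attention_metadata(groups, use_hybrid_kvcache=True))

In the non-hybrid case the config presumably contains a single KV-cache group, so overwriting the uniform object on each iteration is harmless.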