Commit cfc7610

Fix numerical issue on hybrid kv cache allocation (#1139)
Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent f88d7a9 commit cfc7610

File tree

7 files changed: +163 -76 lines changed

tests/runner/test_tpu_runner_dp.py

Lines changed: 12 additions & 5 deletions
@@ -46,9 +46,16 @@ def setup_method(self):
         self.runner.query_start_loc_cpu = np.zeros(10, dtype=np.int32)
         self.runner.seq_lens_cpu = np.zeros(8, dtype=np.int32)
         self.runner.logits_indices_cpu = np.zeros(8, dtype=np.int32)
-        self.runner.block_table_cpu = np.zeros((8, 8), dtype=np.int32)
+        self.runner.block_tables_cpu = [np.zeros((8, 8), dtype=np.int32)]
         self.runner.arange_cpu = np.arange(64, dtype=np.int64)

+        # mock kv cache group
+        mock_kv_cache_config = MagicMock()
+        mock_kv_cache_group = MagicMock()
+        mock_kv_cache_config.kv_cache_groups = [mock_kv_cache_group]
+        self.runner.kv_cache_config = mock_kv_cache_config
+        self.runner.use_hybrid_kvcache = False
+
         # Mock scheduler config for async scheduling
         self.runner.scheduler_config = MagicMock()
         self.runner.scheduler_config.async_scheduling = False  # Default to False for most tests
@@ -102,8 +109,8 @@ def test_prepare_inputs_dp_basic_functionality(self,
         result = self.runner._prepare_inputs_dp(scheduler_output)

         # Basic assertions
-        assert len(result) == 7
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+        assert len(result) == 8
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

         # Verify utility functions were called
         mock_runner_utils.get_padded_token_len.assert_called()
@@ -380,7 +387,7 @@ def mock_get_padded_token_len(paddings_list, val):

         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
         # 1. Verify input_ids content
         expected_input_ids = np.zeros(16, dtype=np.int32)
         expected_input_ids[:2] = [1006, 1007]
@@ -494,7 +501,7 @@ def mock_get_padded_token_len(paddings_list, val):

         # Execute the method
         result = self.runner._prepare_inputs_dp(scheduler_output)
-        input_ids, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result
+        input_ids, positions, attention_metadata, sampling_metadata, logits_indices, spec_decode_metadata, logits_indices_selector, padded_num_reqs = result

         # 1. Verify input_ids
         expected_input_ids = np.zeros(16, dtype=np.int32)

tpu_inference/models/common/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -217,7 +217,7 @@ def get_flax_model(
             hidden_states_sharding,  # aux hidden states
         ),
         donate_argnums=2,  # 0 is graphdef, 1 is state, 2 is kv_cache
-        static_argnums=6,  #6 is layer_name_to_kvcache_index
+        static_argnums=7,  #7 is layer_name_to_kvcache_index
     )
     def run_model(graphdef, state, *args):
         model = nnx.merge(graphdef, state)
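
Note on the static_argnums bump: jax.jit's static_argnums indexes positional arguments, so inserting the new positions argument ahead of layer_name_to_kvcache_index shifts that static slot from 6 to 7. A minimal toy sketch of the behavior (illustrative function, not the repo's run_model):

    import jax
    import jax.numpy as jnp

    def f(x, positions, num_layers):
        # num_layers is a plain Python value, so it is marked static.
        return x + positions + num_layers

    # static_argnums counts positional arguments; adding a new argument in
    # front of the static one means this index has to be bumped as well.
    f_jit = jax.jit(f, static_argnums=2)
    print(f_jit(jnp.ones(4), jnp.arange(4), 3))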

tpu_inference/models/vllm/vllm_model_wrapper.py

Lines changed: 3 additions & 2 deletions
@@ -161,6 +161,7 @@ def step_fun(
         input_ids: jax.Array,
         attn_metadata: AttentionMetadata,
         input_embeds: jax.Array,
+        input_positions: jax.Array,
         layer_name_to_kvcache_index: Sequence[Tuple[str, int]],
         lora_metadata,
         intermediate_tensors: JaxIntermediateTensors = None,
@@ -187,8 +188,8 @@ def step_fun(
             torch_view(params_and_buffers),
             kwargs={
                 "input_ids": torch_view(input_ids),
-                "positions": torch_view(attn_metadata.input_positions),
-                "intermediate_tensors": intermediate_tensors,
+                "positions": torch_view(input_positions),
+                "intermediate_tensors": None,
                 "inputs_embeds": None,
             },
             tie_weights=False,
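
The wrapper now receives the positions array explicitly instead of reading attn_metadata.input_positions: with hybrid KV cache the attention metadata may become a per-layer mapping rather than a single object, so there is no longer one metadata field to pull positions from. A rough sketch of that distinction (stand-in types, not the wrapper's real classes):

    from typing import Dict, Union

    MetadataLike = object  # stand-in for AttentionMetadata

    def metadata_for_layer(attn_metadata: Union[MetadataLike, Dict[str, MetadataLike]],
                           layer_name: str) -> MetadataLike:
        # Hybrid KV cache: metadata keyed by layer name; otherwise one shared
        # object, which is why positions now travels as its own argument.
        if isinstance(attn_metadata, dict):
            return attn_metadata[layer_name]
        return attn_metadata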

tpu_inference/platforms/tpu_platform.py

Lines changed: 4 additions & 0 deletions
@@ -266,3 +266,7 @@ def use_sync_weight_loader(cls) -> bool:
         Returns if the current platform needs to sync weight loader.
         """
         return True
+
+    @classmethod
+    def support_hybrid_kv_cache(cls) -> bool:
+        return True
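
This capability hook lets KV cache setup branch on whether the platform can handle hybrid layouts. A hedged sketch of a possible call site (hypothetical helper, not code from this commit):

    # Hypothetical, for illustration only: `platform` is any class exposing
    # the classmethod added above.
    def choose_kv_cache_layout(platform) -> str:
        # A platform supporting hybrid KV cache can keep one cache group per
        # attention type (e.g. full vs. sliding-window) instead of one uniform group.
        return "hybrid" if platform.support_hybrid_kv_cache() else "uniform"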

tpu_inference/runner/compilation_manager.py

Lines changed: 33 additions & 15 deletions
@@ -1,6 +1,6 @@
 import os
 import time
-from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import jax
 import jax.numpy as jnp
@@ -135,12 +135,6 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
             ShardingAxisName.ATTN_DATA, )) if dp_size > 1 else None

         # Keep existing pattern for complex array operations
-        block_tables = self.runner.block_table_cpu[:self.runner.max_num_reqs]
-        block_tables = block_tables.reshape(-1)
-        block_tables = device_array(self.runner.mesh,
-                                    block_tables,
-                                    sharding=dp_sharding)
-
         seq_lens = self._create_dummy_tensor((self.runner.max_num_reqs, ),
                                              jnp.int32, dp_sharding)
         query_start_loc = self._create_dummy_tensor(
@@ -152,40 +146,64 @@ def _precompile_backbone_helper(self, name, *, input_ids, positions,
                                                  request_distribution,
                                                  sharding=dp_sharding)

-        attention_metadata = AttentionMetadata(
-            input_positions=positions,
-            block_tables=block_tables,
-            seq_lens=seq_lens,
-            query_start_loc=query_start_loc,
-            request_distribution=request_distribution,
-        )
+        attention_metadata_per_layer: Dict[str, AttentionMetadata] = {}
+        uniform_attention_metadata: AttentionMetadata = None
+        for kv_cache_gid, kv_cache_group in enumerate(
+                self.runner.kv_cache_config.kv_cache_groups):
+            block_tables = self.runner.block_tables_cpu[
+                kv_cache_gid][:self.runner.max_num_reqs]
+            block_tables = block_tables.reshape(-1)
+            block_tables = device_array(self.runner.mesh,
+                                        block_tables,
+                                        sharding=dp_sharding)
+
+            attention_metadata_gid = AttentionMetadata(
+                input_positions=positions,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                query_start_loc=query_start_loc,
+                request_distribution=request_distribution,
+            )
+            if not self.runner.use_hybrid_kvcache:
+                # all layers share the same attention metadata
+                uniform_attention_metadata = attention_metadata_gid
+            else:
+                for layer_name in kv_cache_group.layer_names:
+                    attention_metadata_per_layer[
+                        layer_name] = attention_metadata_gid

         def model_fn_wrapper(
             state,
             kv_caches,
             input_ids,
             attention_metadata,
+            positions,
             inputs_embeds,
             layer_name_to_kvcache_index,
             lora_metadata,
         ):
             kv_caches, hidden_states, _ = self.runner.model_fn(
                 state, kv_caches, input_ids, attention_metadata, inputs_embeds,
-                layer_name_to_kvcache_index, lora_metadata)
+                positions, layer_name_to_kvcache_index, lora_metadata)
             self.runner.kv_caches = kv_caches
             return hidden_states

         with self.runner.maybe_select_dummy_loras(
                 self.runner.lora_config, np.array([num_tokens],
                                                   dtype=np.int32)):
             lora_metadata = self.runner.lora_utils.extract_lora_metadata()
+            if self.runner.use_hybrid_kvcache:
+                attention_metadata = attention_metadata_per_layer
+            else:
+                attention_metadata = uniform_attention_metadata
             self._run_compilation(
                 name,
                 model_fn_wrapper,
                 self.runner.state,
                 self.runner.kv_caches,
                 input_ids,
                 attention_metadata,
+                positions,
                 inputs_embeds,
                 tuple(self.runner.layer_name_to_kvcache_index.items()),
                 lora_metadata,
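
In short, precompilation now builds one AttentionMetadata per KV cache group: with hybrid KV cache enabled the metadata is fanned out to every layer in that group, otherwise a single shared object is kept. A stripped-down sketch of the fan-out (simplified group and metadata stand-ins, not the runner's real classes):

    from typing import Callable, Dict, List, NamedTuple

    class Group(NamedTuple):
        layer_names: List[str]

    def build_metadata(groups: List[Group], use_hybrid_kvcache: bool,
                       make_metadata: Callable[[int], object]):
        # make_metadata(group_id) stands in for constructing AttentionMetadata
        # from that group's block table.
        per_layer: Dict[str, object] = {}
        uniform = None
        for gid, group in enumerate(groups):
            meta = make_metadata(gid)
            if not use_hybrid_kvcache:
                uniform = meta  # all layers share the same metadata
            else:
                for name in group.layer_names:
                    per_layer[name] = meta
        return per_layer if use_hybrid_kvcache else uniform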

tpu_inference/runner/kv_cache_manager.py

Lines changed: 7 additions & 0 deletions
@@ -3,12 +3,14 @@

 import jax
 import jax.numpy as jnp
+import numpy as np
 import vllm.envs as envs
 from jax.sharding import NamedSharding, PartitionSpec
 from torchax.ops.mappings import t2j_dtype
 from vllm.attention import Attention
 from vllm.attention.backends.abstract import AttentionType
 from vllm.config import get_layers_from_vllm_config
+from vllm.utils.math_utils import cdiv
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MLAAttentionSpec,
                                         SlidingWindowSpec)
@@ -174,6 +176,11 @@ def maybe_reinitialize_input_batch(self,
         )
         self.runner.input_batch = new_input_batch
         self.runner.persistent_batch_manager.input_batch = new_input_batch
+        self.runner.block_tables_cpu = [
+            np.zeros((self.runner.max_num_reqs,
+                      cdiv(self.runner.max_model_len, block_size)),
+                     dtype=np.int32) for block_size in block_sizes
+        ]

     def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         self.maybe_reinitialize_input_batch(kv_cache_config)
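
This is where the sizing fix lands: each KV cache group now gets its own CPU block table whose width is cdiv(max_model_len, block_size) for that group's block size, instead of every group sharing one table shape. A small worked example of the sizing (illustrative numbers; cdiv defined locally to stay self-contained):

    import numpy as np

    def cdiv(a: int, b: int) -> int:
        # Ceiling division, as used when sizing the block tables.
        return -(-a // b)

    max_num_reqs, max_model_len = 8, 1024
    block_sizes = [16, 128]  # e.g. one block size per KV cache group (illustrative)

    block_tables_cpu = [
        np.zeros((max_num_reqs, cdiv(max_model_len, bs)), dtype=np.int32)
        for bs in block_sizes
    ]
    print([t.shape for t in block_tables_cpu])  # [(8, 64), (8, 8)]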
