
Commit a98a8d6

wip
1 parent 90f21ad commit a98a8d6

File tree

2 files changed: 7 additions & 0 deletions

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+{"total_num_scheduled_tokens": 8192, "num_prefill_tokens": 16349, "num_decode_tokens": 35, "padded_total_num_scheduled_tokens": 16384, "num_reqs": 45}

tpu_inference/layers/common/sharding.py

Lines changed: 6 additions & 0 deletions
@@ -121,13 +121,19 @@ def from_vllm_config(cls,
         if enable_dp_attention:
             # Replicate attention layer when num_kv_heads < TP
             num_kv_heads = vllm_config.model_config.get_total_num_kv_heads()
+
             kv_dtype = utils.get_jax_dtype_from_str_dtype(
                 vllm_config.cache_config.cache_dtype) or jnp.bfloat16
             packing = 4 // jnp.dtype(kv_dtype).itemsize
             # When num_kv_heads * 2 / packing < TP, tensor parallelism would
             # duplicate KV heads across devices, wasting kv cache memory.
             # Use attention DP instead to reduce per-device num_kv_heads and
             # eliminate this waste.
+
+            # if head_dim is 64, multiply packing by 2
+            if vllm_config.model_config.get_head_size() == 64:
+                packing *= 2
+
             num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing
             attn_dp = max(
                 int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache),
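
As a quick sanity check on the arithmetic in this hunk, here is a standalone sketch (not code from this repository) of the packing / attention-DP computation, using hypothetical values: a bfloat16 KV cache, head_dim of 64, 8 total KV heads, and tensor parallelism of 16. The second argument to max() is cut off in the diff above, so 1 is assumed here.

```python
# Hypothetical worked example of the packing / attn_dp formula in the diff.
# The concrete values (bfloat16 cache, head_dim=64, 8 KV heads, TP=16) are
# assumptions chosen only to illustrate the arithmetic.
import jax.numpy as jnp

kv_dtype = jnp.bfloat16
tensor_parallelism = 16
num_kv_heads = 8
head_dim = 64

# 4 bytes of packing budget divided by the element size: bfloat16 has
# itemsize 2, so packing starts at 2.
packing = 4 // jnp.dtype(kv_dtype).itemsize

# The added branch: with head_dim == 64 the packing doubles (2 -> 4).
if head_dim == 64:
    packing *= 2

# K and V together contribute 2 * num_kv_heads heads; dividing by packing
# gives the per-device KV-head footprint in the cache: (8 * 2) / 4 = 4.0.
num_kv_heads_per_device_in_kv_cache = (num_kv_heads * 2) / packing

# Attention DP factor: how many times TP exceeds that footprint.
# max(int(16 // 4.0), 1) = 4, i.e. 4-way attention data parallelism.
attn_dp = max(int(tensor_parallelism // num_kv_heads_per_device_in_kv_cache), 1)

print(packing, num_kv_heads_per_device_in_kv_cache, attn_dp)  # 4 4.0 4
```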
