Skip to content

Commit 2669cde

Browse files
committed
Use FP8_e5m2 automatically when using quantized KV cache FP8 on Trillium
Signed-off-by: zixi-qi <qizixi@meta.com>
1 parent 5fe5fad commit 2669cde

File tree

2 files changed

+50
-0
lines changed

2 files changed

+50
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
from unittest.mock import MagicMock, patch
2+
import pytest
3+
from vllm.config import CacheConfig, VllmConfig
4+
5+
from tpu_inference.platforms.tpu_platform import TpuPlatform
6+
7+
class TestTpuPlatform:
    """Tests for TpuPlatform.check_and_update_config KV-cache dtype selection."""

    @pytest.fixture
    def vllm_config(self):
        """Build a mocked VllmConfig whose cache_config requests an fp8 KV cache."""
        cache_cfg = CacheConfig(block_size=16,
                                gpu_memory_utilization=0.9,
                                swap_space=4,
                                cache_dtype="fp8")

        cfg = MagicMock(spec=VllmConfig)
        cfg.cache_config = cache_cfg
        cfg.model_config = MagicMock(dtype='bfloat16')
        cfg.scheduler_config = MagicMock(is_multimodal_model=False)
        cfg.parallel_config = MagicMock()
        cfg.compilation_config = MagicMock(mode="dynamo_trace_once",
                                           backend="openxla")
        cfg.kv_transfer_config = None
        return cfg

    @pytest.mark.parametrize("chip_name,expected_dtype", [
        ("v6e", "fp8_e5m2"),
        ("v5e", "fp8"),
    ])
    def test_check_and_update_config_fp8(self, chip_name, expected_dtype,
                                         vllm_config):
        """An fp8 KV cache is rewritten to fp8_e5m2 on v6e; other chips keep fp8."""
        fake_chip = MagicMock()
        fake_chip.name = chip_name

        # Same patch targets as production code touches; applied explicitly so
        # cleanup is guaranteed even if check_and_update_config raises.
        patchers = [
            patch('tpu_inference.platforms.tpu_platform.init_logger'),
            patch('tpu_inference.platforms.tpu_platform.device.get_local_chips',
                  return_value=(fake_chip, None)),
            patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False),
            patch('tpu_inference.platforms.tpu_platform.ShardingConfigManager.from_vllm_config'),
            patch('tpu_inference.platforms.tpu_platform.envs.MODEL_IMPL_TYPE', "vllm"),
            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size',
                  return_value=16),
            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size',
                  return_value=16),
            patch('tpu_inference.models.jax.utils.quantization.quantization_utils.update_vllm_config_for_qwix_quantization'),
            patch('tpu_inference.core.sched.dp_scheduler.update_vllm_config_for_dp_scheduler'),
        ]

        started = []
        try:
            for p in patchers:
                p.start()
                started.append(p)

            TpuPlatform.check_and_update_config(vllm_config)

            assert vllm_config.cache_config.cache_dtype == expected_dtype
        finally:
            # Unwind in reverse order, mirroring nested `with` semantics.
            for p in reversed(started):
                p.stop()

tpu_inference/platforms/tpu_platform.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
132132
# For v0, the default block size is 16.
133133
if cache_config and cache_config.block_size is None:
134134
cache_config.block_size = cast(BlockSize, 16)
135+
136+
if cache_config and cache_config.cache_dtype == "fp8" and cls.get_device_name(
137+
) == "TPU v6e":
138+
logger.info(
139+
"Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
140+
cache_config.cache_dtype = "fp8_e5m2"
141+
135142
compilation_config = vllm_config.compilation_config
136143

137144
# TPU only supports DYNAMO_TRACE_ONCE compilation level

0 commit comments

Comments
 (0)