From e12ab1617bb8c3216ba34232d034d778bb82c96b Mon Sep 17 00:00:00 2001
From: zixi-qi
Date: Thu, 20 Nov 2025 03:52:35 +0000
Subject: [PATCH 1/2] Automatically use fp8_e5m2 for quantized FP8 KV cache on
 Trillium

Signed-off-by: zixi-qi
---
 tests/platforms/test_tpu_platform.py    | 50 +++++++++++++++++++++++++
 tpu_inference/platforms/tpu_platform.py |  7 ++++
 2 files changed, 57 insertions(+)
 create mode 100644 tests/platforms/test_tpu_platform.py

diff --git a/tests/platforms/test_tpu_platform.py b/tests/platforms/test_tpu_platform.py
new file mode 100644
index 000000000..375f4d4b2
--- /dev/null
+++ b/tests/platforms/test_tpu_platform.py
@@ -0,0 +1,50 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import CacheConfig, VllmConfig
+
+from tpu_inference.platforms.tpu_platform import TpuPlatform
+
+
+class TestTpuPlatform:
+
+    @pytest.fixture
+    def vllm_config(self):
+        cache_config = CacheConfig(block_size=16,
+                                   gpu_memory_utilization=0.9,
+                                   swap_space=4,
+                                   cache_dtype="fp8")
+
+        vllm_config = MagicMock(spec=VllmConfig)
+        vllm_config.cache_config = cache_config
+        vllm_config.model_config = MagicMock(dtype='bfloat16')
+        vllm_config.scheduler_config = MagicMock(is_multimodal_model=False)
+        vllm_config.parallel_config = MagicMock()
+        vllm_config.compilation_config = MagicMock(mode="dynamo_trace_once",
+                                                   backend="openxla")
+        vllm_config.kv_transfer_config = None
+        return vllm_config
+
+    @pytest.mark.parametrize("chip_name,expected_dtype", [
+        ("v6e", "fp8_e5m2"),
+        ("v5e", "fp8"),
+    ])
+    def test_check_and_update_config_fp8(self, chip_name, expected_dtype,
+                                         vllm_config):
+        mock_chip_type = MagicMock()
+        mock_chip_type.name = chip_name
+
+        # Common patches
+        with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
+            patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
+            patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False), \
+            patch('tpu_inference.platforms.tpu_platform.ShardingConfigManager.from_vllm_config'), \
+            patch('tpu_inference.platforms.tpu_platform.envs.MODEL_IMPL_TYPE', "vllm"), \
+            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size', return_value=16), \
+            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size', return_value=16), \
+            patch('tpu_inference.models.jax.utils.quantization.quantization_utils.update_vllm_config_for_qwix_quantization'), \
+            patch('tpu_inference.core.sched.dp_scheduler.update_vllm_config_for_dp_scheduler'):
+
+            TpuPlatform.check_and_update_config(vllm_config)
+
+        assert vllm_config.cache_config.cache_dtype == expected_dtype
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index b3a4a7de3..20fd0c385 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -132,6 +132,13 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+
+        if cache_config and cache_config.cache_dtype == "fp8" and cls.get_device_name(
+        ) == "TPU v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            cache_config.cache_dtype = "fp8_e5m2"
+
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level

From ada5c211acf838261a4f378382b3a34e6eef9fde Mon Sep 17 00:00:00 2001
From: zixi-qi
Date: Thu, 20 Nov 2025 21:11:46 +0000
Subject: [PATCH 2/2] Move the FP8 dtype override logic into
 TpuPlatform.fp8_dtype()

Signed-off-by: zixi-qi
---
 tests/platforms/test_tpu_platform.py    | 22 ++++++----------------
 tpu_inference/platforms/tpu_platform.py | 15 +++++++++------
 2 files changed, 15 insertions(+), 22 deletions(-)

diff --git a/tests/platforms/test_tpu_platform.py b/tests/platforms/test_tpu_platform.py
index 375f4d4b2..81dca30db 100644
--- a/tests/platforms/test_tpu_platform.py
+++ b/tests/platforms/test_tpu_platform.py
@@ -1,6 +1,7 @@
 from unittest.mock import MagicMock, patch
 
 import pytest
+import torch
 from vllm.config import CacheConfig, VllmConfig
 
 from tpu_inference.platforms.tpu_platform import TpuPlatform
@@ -26,25 +27,14 @@ def vllm_config(self):
         return vllm_config
 
     @pytest.mark.parametrize("chip_name,expected_dtype", [
-        ("v6e", "fp8_e5m2"),
-        ("v5e", "fp8"),
+        ("v6e", torch.float8_e5m2),
+        ("v5e", torch.float8_e4m3fn),
     ])
-    def test_check_and_update_config_fp8(self, chip_name, expected_dtype,
-                                         vllm_config):
+    def test_fp8_dtype(self, chip_name, expected_dtype):
         mock_chip_type = MagicMock()
         mock_chip_type.name = chip_name
 
-        # Common patches
         with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
             patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
-            patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False), \
-            patch('tpu_inference.platforms.tpu_platform.ShardingConfigManager.from_vllm_config'), \
-            patch('tpu_inference.platforms.tpu_platform.envs.MODEL_IMPL_TYPE', "vllm"), \
-            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size', return_value=16), \
-            patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size', return_value=16), \
-            patch('tpu_inference.models.jax.utils.quantization.quantization_utils.update_vllm_config_for_qwix_quantization'), \
-            patch('tpu_inference.core.sched.dp_scheduler.update_vllm_config_for_dp_scheduler'):
-
-            TpuPlatform.check_and_update_config(vllm_config)
-
-        assert vllm_config.cache_config.cache_dtype == expected_dtype
+            patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False):
+            assert TpuPlatform.fp8_dtype() == expected_dtype
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index 20fd0c385..c5603ee60 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
@@ -82,6 +83,14 @@ def get_device_name(cls, device_id: int = 0) -> str:
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'
 
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls,
                                 device_id: int = 0) -> int:
         raise NotImplementedError
 
@@ -133,12 +142,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
-
-        if cache_config and cache_config.cache_dtype == "fp8" and cls.get_device_name(
-        ) == "TPU v6e":
-            logger.info(
-                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
-            cache_config.cache_dtype = "fp8_e5m2"
 
         compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level
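After this series, TpuPlatform.fp8_dtype() is the single place that maps the local chip generation to an FP8 KV-cache dtype. A minimal sketch of exercising it directly, not part of the patches, assuming a host with tpu-inference and its dependencies installed so that get_device_name() can query tpu_info (or hit its 'TPU' fallback on error):

    # Minimal sketch: query the FP8 KV-cache dtype TpuPlatform now reports
    # for the local chip. Per this series, TPU v6e (Trillium) should give
    # torch.float8_e5m2 and other generations torch.float8_e4m3fn.
    import torch

    from tpu_inference.platforms.tpu_platform import TpuPlatform

    dtype = TpuPlatform.fp8_dtype()
    assert dtype in (torch.float8_e5m2, torch.float8_e4m3fn)
    print(f"FP8 KV-cache dtype on this host: {dtype}")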