diff --git a/tests/platforms/test_tpu_platform.py b/tests/platforms/test_tpu_platform.py
new file mode 100644
index 000000000..81dca30db
--- /dev/null
+++ b/tests/platforms/test_tpu_platform.py
@@ -0,0 +1,40 @@
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+from vllm.config import CacheConfig, VllmConfig
+
+from tpu_inference.platforms.tpu_platform import TpuPlatform
+
+
+class TestTpuPlatform:
+
+    @pytest.fixture
+    def vllm_config(self):
+        cache_config = CacheConfig(block_size=16,
+                                   gpu_memory_utilization=0.9,
+                                   swap_space=4,
+                                   cache_dtype="fp8")
+
+        vllm_config = MagicMock(spec=VllmConfig)
+        vllm_config.cache_config = cache_config
+        vllm_config.model_config = MagicMock(dtype='bfloat16')
+        vllm_config.scheduler_config = MagicMock(is_multimodal_model=False)
+        vllm_config.parallel_config = MagicMock()
+        vllm_config.compilation_config = MagicMock(mode="dynamo_trace_once",
+                                                   backend="openxla")
+        vllm_config.kv_transfer_config = None
+        return vllm_config
+
+    @pytest.mark.parametrize("chip_name,expected_dtype", [
+        ("v6e", torch.float8_e5m2),
+        ("v5e", torch.float8_e4m3fn),
+    ])
+    def test_fp8_dtype(self, chip_name, expected_dtype):
+        mock_chip_type = MagicMock()
+        mock_chip_type.name = chip_name
+
+        with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
+             patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
+             patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False):
+            assert TpuPlatform.fp8_dtype() == expected_dtype
diff --git a/tpu_inference/platforms/tpu_platform.py b/tpu_inference/platforms/tpu_platform.py
index b3a4a7de3..c5603ee60 100644
--- a/tpu_inference/platforms/tpu_platform.py
+++ b/tpu_inference/platforms/tpu_platform.py
@@ -3,6 +3,7 @@
 from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast
 
 import jax.numpy as jnp
+import torch
 import vllm.envs as vllm_envs
 from torchax.ops.mappings import j2t_dtype
 from tpu_info import device
@@ -82,6 +83,14 @@ def get_device_name(cls, device_id: int = 0) -> str:
             logger.warning(f"Error getting device name: {e}")
             return 'TPU'
 
+    @classmethod
+    def fp8_dtype(cls) -> torch.dtype:
+        if cls.get_device_name().lower() == "tpu v6e":
+            logger.info(
+                "Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
+            return torch.float8_e5m2
+        return torch.float8_e4m3fn
+
     @classmethod
     def get_device_total_memory(cls, device_id: int = 0) -> int:
         raise NotImplementedError
@@ -132,6 +141,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
         # For v0, the default block size is 16.
         if cache_config and cache_config.block_size is None:
             cache_config.block_size = cast(BlockSize, 16)
+        compilation_config = vllm_config.compilation_config
 
         # TPU only supports DYNAMO_TRACE_ONCE compilation level