Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions tests/platforms/test_tpu_platform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from unittest.mock import MagicMock, patch

import pytest
import torch
from vllm.config import CacheConfig, VllmConfig

from tpu_inference.platforms.tpu_platform import TpuPlatform


class TestTpuPlatform:

@pytest.fixture
def vllm_config(self):
cache_config = CacheConfig(block_size=16,
gpu_memory_utilization=0.9,
swap_space=4,
cache_dtype="fp8")

vllm_config = MagicMock(spec=VllmConfig)
vllm_config.cache_config = cache_config
vllm_config.model_config = MagicMock(dtype='bfloat16')
vllm_config.scheduler_config = MagicMock(is_multimodal_model=False)
vllm_config.parallel_config = MagicMock()
vllm_config.compilation_config = MagicMock(mode="dynamo_trace_once",
backend="openxla")
vllm_config.kv_transfer_config = None
return vllm_config

@pytest.mark.parametrize("chip_name,expected_dtype", [
("v6e", torch.float8_e5m2),
("v5e", torch.float8_e4m3fn),
])
def test_fp8_dtype(self, chip_name, expected_dtype):
mock_chip_type = MagicMock()
mock_chip_type.name = chip_name

with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False):
assert TpuPlatform.fp8_dtype() == expected_dtype
10 changes: 10 additions & 0 deletions tpu_inference/platforms/tpu_platform.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union, cast

import jax.numpy as jnp
import torch
import vllm.envs as vllm_envs
from torchax.ops.mappings import j2t_dtype
from tpu_info import device
Expand Down Expand Up @@ -82,6 +83,14 @@ def get_device_name(cls, device_id: int = 0) -> str:
logger.warning(f"Error getting device name: {e}")
return 'TPU'

@classmethod
def fp8_dtype(cls) -> torch.dtype:
if cls.get_device_name().lower() == "tpu v6e":
logger.info(
"Automatically using fp8_e5m2 for FP8 KV cache on TPU v6e.")
return torch.float8_e5m2
return torch.float8_e4m3fn

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
raise NotImplementedError
Expand Down Expand Up @@ -132,6 +141,7 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
# For v0, the default block size is 16.
if cache_config and cache_config.block_size is None:
cache_config.block_size = cast(BlockSize, 16)

compilation_config = vllm_config.compilation_config

# TPU only supports DYNAMO_TRACE_ONCE compilation level
Expand Down