+from unittest.mock import MagicMock, patch
+
+import pytest
+from vllm.config import CacheConfig, VllmConfig
+
+from tpu_inference.platforms.tpu_platform import TpuPlatform
+
+
+class TestTpuPlatform:
+
+    @pytest.fixture
+    def vllm_config(self):
+        # Mocked VllmConfig whose cache_config requests the generic "fp8" KV-cache dtype.
+        cache_config = CacheConfig(block_size=16, gpu_memory_utilization=0.9, swap_space=4, cache_dtype="fp8")
+
+        vllm_config = MagicMock(spec=VllmConfig)
+        vllm_config.cache_config = cache_config
+        vllm_config.model_config = MagicMock(dtype='bfloat16')
+        vllm_config.scheduler_config = MagicMock(is_multimodal_model=False)
+        vllm_config.parallel_config = MagicMock()
+        vllm_config.compilation_config = MagicMock(mode="dynamo_trace_once", backend="openxla")
+        vllm_config.kv_transfer_config = None
+        return vllm_config
+
+    @pytest.mark.parametrize("chip_name,expected_dtype", [
+        ("v6e", "fp8_e5m2"),
+        ("v5e", "fp8"),
+    ])
+    def test_check_and_update_config_fp8(self, chip_name, expected_dtype, vllm_config):
+        # Fake the detected TPU chip so the platform resolves the dtype for this generation.
+        mock_chip_type = MagicMock()
+        mock_chip_type.name = chip_name
+
+        # Patch TPU and vLLM dependencies so check_and_update_config runs in isolation.
+        with patch('tpu_inference.platforms.tpu_platform.init_logger'), \
+             patch('tpu_inference.platforms.tpu_platform.device.get_local_chips', return_value=(mock_chip_type, None)), \
+             patch('vllm.envs.VLLM_TPU_USING_PATHWAYS', False), \
+             patch('tpu_inference.platforms.tpu_platform.ShardingConfigManager.from_vllm_config'), \
+             patch('tpu_inference.platforms.tpu_platform.envs.MODEL_IMPL_TYPE', "vllm"), \
+             patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_page_size', return_value=16), \
+             patch('vllm.v1.attention.backends.pallas.PallasAttentionBackend.get_min_page_size', return_value=16), \
+             patch('tpu_inference.models.jax.utils.quantization.quantization_utils.update_vllm_config_for_qwix_quantization'), \
+             patch('tpu_inference.core.sched.dp_scheduler.update_vllm_config_for_dp_scheduler'):
+            TpuPlatform.check_and_update_config(vllm_config)
+
+        # The generic "fp8" request should have been resolved to the chip-specific dtype.
+        assert vllm_config.cache_config.cache_dtype == expected_dtype
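
For local verification, one way to run just these parametrized cases is pytest's programmatic entry point, sketched below; the module path is an assumption about where the new test file lives, so adjust it to its actual location in the repo.

    import pytest

    # Run only the fp8 dtype-selection cases from the new test module (path is assumed).
    raise SystemExit(pytest.main(["-v", "-k", "check_and_update_config_fp8", "tests/platforms/test_tpu_platform.py"]))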