Skip to content

Commit 90510f6

Browse files
committed
[Misc] Fix model dtype not being configured correctly
Signed-off-by: Kyuyeun Kim <kyuyeunk@google.com>
1 parent 45edde6 commit 90510f6

File tree

3 files changed

+44
-52
lines changed

3 files changed

+44
-52
lines changed

tpu_inference/platforms/tpu_platform.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
import jax.numpy as jnp
77
import vllm.envs as vllm_envs
8-
from torchax.ops.mappings import j2t_dtype
98
from tpu_info import device
109
from vllm.inputs import ProcessorInputs, PromptType
1110
from vllm.platforms.interface import Platform, PlatformEnum
@@ -14,6 +13,7 @@
1413
from tpu_inference import envs
1514
from tpu_inference.layers.common.sharding import ShardingConfigManager
1615
from tpu_inference.logger import init_logger
16+
from tpu_inference.utils import TpuDtype
1717

1818
if TYPE_CHECKING:
1919
from vllm.attention.backends.registry import _Backend
@@ -151,18 +151,9 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
151151
# For mm model preprocessors, it may need the output dtype to be torch.
152152
# In order to avoid a PR to vLLM, we postpone the dtype checking during tpu_worker initialization
153153
if not vllm_config.scheduler_config.is_multimodal_model or impl == "vllm":
154-
if not isinstance(vllm_config.model_config.dtype, str):
155-
logger.warning(
156-
"The model dtype is not properly set for JAX backend. "
157-
"Overwriting it to jnp.bfloat16")
158-
vllm_config.model_config.dtype = jnp.bfloat16
159-
else:
160-
vllm_config.model_config.dtype = _DTYPE.get(
161-
vllm_config.model_config.dtype, jnp.bfloat16)
162-
163-
if impl == "vllm":
164-
vllm_config.model_config.dtype = j2t_dtype(
165-
vllm_config.model_config.dtype.dtype)
154+
dtype = TpuDtype(vllm_config.model_config.dtype)
155+
model_dtype = dtype.torch if impl == "vllm" else dtype.jax
156+
vllm_config.model_config.dtype = model_dtype
166157

167158
# TODO(cuiq): remove this dependency.
168159
from vllm.v1.attention.backends.pallas import PallasAttentionBackend

tpu_inference/runner/tpu_runner.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,11 @@
1010
import jax.numpy as jnp
1111
import jaxtyping
1212
import numpy as np
13-
import torch
1413
import vllm.envs as envs
1514
from flax import nnx
1615
from jax.experimental import mesh_utils
1716
from jax.sharding import NamedSharding, PartitionSpec
18-
from torchax.ops.mappings import j2t, j2t_dtype
17+
from torchax.ops.mappings import j2t
1918
from vllm.config import VllmConfig
2019
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
2120
has_kv_transfer_group)
@@ -63,7 +62,7 @@
6362
from tpu_inference.runner.structured_decoding_manager import \
6463
StructuredDecodingManager
6564
from tpu_inference.spec_decode.jax.eagle3 import Eagle3Proposer
66-
from tpu_inference.utils import (device_array, make_optimized_mesh,
65+
from tpu_inference.utils import (TpuDtype, device_array, make_optimized_mesh,
6766
time_function)
6867

6968
logger = init_logger(__name__)
@@ -78,17 +77,6 @@
7877
request_distribution=[0, 0, 0],
7978
)
8079

81-
TPU_STR_DTYPE_TO_TORCH_DTYPE = {
82-
"half": torch.half,
83-
"bfloat16": torch.bfloat16,
84-
"float": torch.float,
85-
"fp8": torch.float8_e4m3fn,
86-
"fp8_e4m3": torch.float8_e4m3fn,
87-
"fp8_e5m2": torch.float8_e5m2,
88-
"int8": torch.int8,
89-
"uint8": torch.uint8,
90-
}
91-
9280

9381
class AsyncTPUModelRunnerOutput(AsyncModelRunnerOutput):
9482
"""Holds asynchronous model output specifically from a TPU runner.
@@ -250,22 +238,10 @@ def __init__(
250238
self.uses_mrope, self.model_config)
251239
self.lora_utils = LoraUtils(self)
252240

253-
cache_config = self.cache_config
254-
if cache_config.cache_dtype == "auto":
255-
model_dtype = self.dtype
256-
if isinstance(model_dtype, str):
257-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[model_dtype]
258-
elif isinstance(getattr(model_dtype, 'dtype', None), jnp.dtype):
259-
self.kv_cache_dtype = j2t_dtype(model_dtype.dtype)
260-
elif isinstance(model_dtype, torch.dtype):
261-
self.kv_cache_dtype = model_dtype
262-
else:
263-
raise ValueError(
264-
"KV cache is unsupported for model_dtype of %s",
265-
model_dtype)
266-
else:
267-
self.kv_cache_dtype = TPU_STR_DTYPE_TO_TORCH_DTYPE[
268-
cache_config.cache_dtype]
241+
cache_dtype = self.cache_config.cache_dtype
242+
if cache_dtype == "auto":
243+
cache_dtype = self.dtype
244+
self.kv_cache_dtype = TpuDtype(cache_dtype).torch
269245

270246
self._pre_async_results: AsyncPreResults | None = None
271247
self._substitute_placeholder_token_fn = _substitute_placeholder_token

tpu_inference/utils.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,14 @@
99
import jax
1010
import jax.numpy as jnp
1111
import numpy as np
12+
import torch
1213
from jax._src import dtypes
1314
from jax._src import mesh as mesh_lib
1415
from jax._src import xla_bridge as xb
1516
from jax._src.lib import xla_client as xc
17+
from jax._src.numpy.scalar_types import _ScalarMeta
1618
from jax.sharding import Mesh, NamedSharding, PartitionSpec
19+
from torchax.ops.mappings import j2t_dtype, t2j_dtype
1720
from vllm import envs, utils
1821

1922
from tpu_inference.logger import init_logger
@@ -25,13 +28,35 @@
2528
# This is used to translate from a string name for a dtype
2629
# to formal jax.numpy DType. One use case for this is
2730
# converting the `--kv_cache_dtype` flag to a dtype.
28-
TPU_STR_DTYPE_TO_JAX_DTYPE = {
29-
"bfloat16": jnp.bfloat16,
30-
"fp8": jnp.float8_e4m3fn,
31-
"fp8_e4m3": jnp.float8_e4m3,
32-
"fp8_e5m2": jnp.float8_e5m2,
33-
"int8": jnp.int8,
34-
}
31+
32+
33+
class TpuDtype:
34+
dtype: jnp.dtype = None
35+
36+
def __init__(self, dtype: str | jnp.dtype | torch.dtype):
37+
if isinstance(dtype, str):
38+
self.dtype = jnp.dtype(dtype)
39+
elif isinstance(dtype, torch.dtype):
40+
self.dtype = t2j_dtype(dtype)
41+
elif isinstance(dtype, jnp.dtype):
42+
self.dtype = dtype
43+
elif isinstance(dtype, _ScalarMeta):
44+
self.dtype = dtype.dtype
45+
else:
46+
raise ValueError(f'Unkonw type of dtype {type(dtype)}')
47+
48+
@property
49+
def jax(self):
50+
return self.dtype
51+
52+
@property
53+
def torch(self):
54+
return j2t_dtype(self.dtype)
55+
56+
@property
57+
def str(self):
58+
return self.dtype.name
59+
3560

3661
_megacore = False
3762
logger = init_logger(__name__)
@@ -294,8 +319,8 @@ def get_jax_dtype_from_str_dtype(str_dtype: str) -> jnp.dtype:
294319
Returns:
295320
jnp.dtype: The JAX dtype.
296321
"""
297-
str_dtype = str_dtype.lower().strip()
298-
return TPU_STR_DTYPE_TO_JAX_DTYPE.get(str_dtype)
322+
# TODO(kyuyeunk): Replace all reference of this function into TpuDtype.
323+
return TpuDtype(str_dtype).to_jax()
299324

300325

301326
def time_function(func):

0 commit comments

Comments
 (0)