From f747616b062bcb1555b98e7e12c2fe5418487c44 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 29 Nov 2025 14:55:12 +0000 Subject: [PATCH 1/4] [Misc] Update `TokenizerLike` interface and move `get_cached_tokenizer` Signed-off-by: DarkLight1337 --- .buildkite/test-amd.yaml | 9 +- .buildkite/test-pipeline.yaml | 9 +- docs/design/huggingface_integration.md | 2 +- .../{test_cached_tokenizer.py => test_hf.py} | 2 +- tests/tokenizers_/test_registry.py | 12 +- tools/pre_commit/check_pickle_imports.py | 2 +- vllm/entrypoints/llm.py | 2 +- vllm/entrypoints/score_utils.py | 4 +- vllm/tokenizers/__init__.py | 3 +- vllm/tokenizers/hf.py | 122 +++++++++++++++ vllm/tokenizers/mistral.py | 67 ++++---- vllm/tokenizers/protocol.py | 32 ++-- vllm/transformers_utils/tokenizer.py | 146 +++++------------- vllm/v1/engine/detokenizer.py | 2 +- 14 files changed, 247 insertions(+), 167 deletions(-) rename tests/tokenizers_/{test_cached_tokenizer.py => test_hf.py} (95%) create mode 100644 vllm/tokenizers/hf.py diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 4d98ee40a4bb..7cb0a9bf761e 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -61,8 +61,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 10min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -72,6 +72,7 @@ steps: - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ - tests/transformers_utils - tests/config no_gpu: true @@ -80,6 +81,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ - pytest -v -s transformers_utils - pytest -v -s config @@ -316,15 +318,12 @@ steps: source_file_dependencies: - vllm/ - tests/engine - - tests/tokenizers_ - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenizers_ - label: V1 Test e2e + engine # 30min timeout_in_minutes: 45 diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 16d490754958..4ad2fe3cc3a2 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,14 +57,15 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins - timeout_in_minutes: 10 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 10min + timeout_in_minutes: 15 source_file_dependencies: - vllm/ - tests/test_inputs.py - tests/test_outputs.py - tests/multimodal - tests/standalone_tests/lazy_imports.py + - tests/tokenizers_ - tests/transformers_utils - tests/config no_gpu: true @@ -73,6 +74,7 @@ steps: - pytest -v -s test_inputs.py - pytest -v -s test_outputs.py - pytest -v -s -m 'cpu_test' multimodal + - pytest -v -s tokenizers_ - pytest -v -s transformers_utils - pytest -v -s config @@ -282,15 +284,12 @@ steps: source_file_dependencies: - vllm/ - tests/engine - - tests/tokenizers_ - tests/test_sequence - tests/test_config - tests/test_logger - tests/test_vllm_port commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py 
test_vllm_port.py - # OOM in the CI unless we run this separately - - pytest -v -s tokenizers_ - label: V1 Test e2e + engine # 30min timeout_in_minutes: 45 diff --git a/docs/design/huggingface_integration.md b/docs/design/huggingface_integration.md index 412ce658b92a..1109abf6cb93 100644 --- a/docs/design/huggingface_integration.md +++ b/docs/design/huggingface_integration.md @@ -21,7 +21,7 @@ Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qw Beyond that, there are two more things vLLM depends on Hugging Face for. -1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24). +1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [vllm.tokenizers.hf.get_cached_tokenizer][]. 2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights. - It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). 
Please note that: diff --git a/tests/tokenizers_/test_cached_tokenizer.py b/tests/tokenizers_/test_hf.py similarity index 95% rename from tests/tokenizers_/test_cached_tokenizer.py rename to tests/tokenizers_/test_hf.py index 48234687ea1e..c1238900ce0d 100644 --- a/tests/tokenizers_/test_cached_tokenizer.py +++ b/tests/tokenizers_/test_hf.py @@ -7,7 +7,7 @@ from transformers import AutoTokenizer from vllm.tokenizers import TokenizerLike -from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.tokenizers.hf import get_cached_tokenizer @pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"]) diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 1eb19a0996dd..46918e5fa3eb 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -17,20 +17,22 @@ def bos_token_id(self) -> int: def eos_token_id(self) -> int: return 1 + @property + def pad_token_id(self) -> int: + return 2 + def test_customized_tokenizer(): - TokenizerRegistry.register( - "test_tokenizer", - __name__, - TestTokenizer.__name__, - ) + TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__) tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.bos_token_id == 0 assert tokenizer.eos_token_id == 1 + assert tokenizer.pad_token_id == 2 tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom") assert isinstance(tokenizer, TestTokenizer) assert tokenizer.bos_token_id == 0 assert tokenizer.eos_token_id == 1 + assert tokenizer.pad_token_id == 2 diff --git a/tools/pre_commit/check_pickle_imports.py b/tools/pre_commit/check_pickle_imports.py index 2bb468da68c2..13e5a0eda751 100644 --- a/tools/pre_commit/check_pickle_imports.py +++ b/tools/pre_commit/check_pickle_imports.py @@ -27,7 +27,7 @@ "vllm/distributed/device_communicators/shm_broadcast.py", "vllm/distributed/device_communicators/shm_object_storage.py", "vllm/utils/hashing.py", - "tests/tokenizers_/test_cached_tokenizer.py", + "tests/tokenizers_/test_hf.py", "tests/utils_/test_hashing.py", "benchmarks/kernels/graph_machete_bench.py", "benchmarks/kernels/benchmark_lora.py", diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 4ea213752e39..acdf28501cbb 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -72,7 +72,7 @@ from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams from vllm.tasks import PoolingTask from vllm.tokenizers import MistralTokenizer, TokenizerLike -from vllm.transformers_utils.tokenizer import get_cached_tokenizer +from vllm.tokenizers.hf import get_cached_tokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils.collection_utils import as_iter, is_list_of from vllm.utils.counter import Counter diff --git a/vllm/entrypoints/score_utils.py b/vllm/entrypoints/score_utils.py index 04d5a192918d..602f59ac09f5 100644 --- a/vllm/entrypoints/score_utils.py +++ b/vllm/entrypoints/score_utils.py @@ -51,8 +51,8 @@ def _cosine_similarity( for emb_1, emb_2 in zip(embed_1, embed_2): pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data) - padding = [] - if (pad_token_id := getattr(tokenizer, "pad_token_id", None)) is not None: + padding: list[int] = [] + if (pad_token_id := tokenizer.pad_token_id) is not None: padding = [pad_token_id] tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids diff --git a/vllm/tokenizers/__init__.py b/vllm/tokenizers/__init__.py index 
e26b4e8797ec..03174872146a 100644 --- a/vllm/tokenizers/__init__.py +++ b/vllm/tokenizers/__init__.py @@ -1,8 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from .hf import HfTokenizer from .mistral import MistralTokenizer from .protocol import TokenizerLike from .registry import TokenizerRegistry -__all__ = ["TokenizerLike", "MistralTokenizer", "TokenizerRegistry"] +__all__ = ["TokenizerLike", "HfTokenizer", "MistralTokenizer", "TokenizerRegistry"] diff --git a/vllm/tokenizers/hf.py b/vllm/tokenizers/hf.py new file mode 100644 index 000000000000..64672fdbb120 --- /dev/null +++ b/vllm/tokenizers/hf.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import contextlib +import copy +from pathlib import Path +from typing import TYPE_CHECKING + +from transformers import AutoTokenizer + +from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config + +from .protocol import TokenizerLike + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + + +def get_cached_tokenizer( + tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast", +) -> TokenizerLike: + """ + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. + This proxy caches these properties for faster access. + """ + cached_tokenizer = copy.copy(tokenizer) + + tokenizer_all_special_ids = tokenizer.all_special_ids + tokenizer_all_special_tokens = tokenizer.all_special_tokens + tokenizer_vocab = tokenizer.get_vocab() + tokenizer_len = len(tokenizer) + + max_token_id = max(tokenizer_vocab.values()) + # Some tokenizers (e.g., QwenTokenizer) have special tokens that + # are added and included in the implementation of the vocab_size + # property, but not in get_vocab(); if there is an implementation + # of vocab size, we should take the greater value. + if hasattr(tokenizer, "vocab_size"): + with contextlib.suppress(NotImplementedError): + max_token_id = max(max_token_id, tokenizer.vocab_size) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + @property + def all_special_ids(self) -> list[int]: + return tokenizer_all_special_ids + + @property + def all_special_tokens(self) -> list[str]: + return tokenizer_all_special_tokens + + @property + def max_token_id(self) -> int: + return max_token_id + + def get_vocab(self) -> dict[str, int]: + return tokenizer_vocab + + def __len__(self) -> int: + return tokenizer_len + + def __reduce__(self): + return get_cached_tokenizer, (tokenizer,) + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + cached_tokenizer.__class__ = CachedTokenizer + return cached_tokenizer # type: ignore + + +class HfTokenizer(TokenizerLike): + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "TokenizerLike": + try: + tokenizer = AutoTokenizer.from_pretrained( + path_or_repo_id, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + cache_dir=download_dir, + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, + # suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." 
in str(e) + or "requires you to execute the tokenizer file" in str(e) + ): + err_msg = ( + "Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI." + ) + raise RuntimeError(err_msg) from e + else: + raise e + + # The special_tokens in tokenizer should also be + # controlled by do_lower_case in encoder_config + encoder_config = get_sentence_transformer_tokenizer_config( + path_or_repo_id, revision + ) + if isinstance(encoder_config, dict) and encoder_config.get( + "do_lower_case", False + ): + special_tokens_map = { + k: v.lower() for k, v in tokenizer.special_tokens_map.items() + } + tokenizer.add_special_tokens(special_tokens_map) + + return get_cached_tokenizer(tokenizer) diff --git a/vllm/tokenizers/mistral.py b/vllm/tokenizers/mistral.py index a42fb0e1e5f1..de3e5ec43854 100644 --- a/vllm/tokenizers/mistral.py +++ b/vllm/tokenizers/mistral.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from pathlib import Path from typing import TYPE_CHECKING, Any, cast from vllm.logger import init_logger @@ -12,6 +12,7 @@ ChatCompletionRequest as MistralChatCompletionRequest, ) from mistral_common.tokens.tokenizers.tekken import Tekkenizer + from transformers import BatchEncoding from transformers.tokenization_mistral_common import ( MistralCommonTokenizer as TransformersMistralTokenizer, ) @@ -165,7 +166,35 @@ def _tekken_token_to_id(tokenizer: "Tekkenizer", t: str | bytes) -> int: class MistralTokenizer(TokenizerLike): + @classmethod + def from_pretrained( + cls, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, + revision: str | None = None, + download_dir: str | None = None, + **kwargs, + ) -> "MistralTokenizer": + from mistral_common.protocol.instruct.validator import ValidationMode + from transformers.tokenization_mistral_common import ( + MistralCommonTokenizer as TransformersMistralTokenizer, + ) + + tokenizer = TransformersMistralTokenizer.from_pretrained( + path_or_repo_id, + *args, + mode=ValidationMode.test, + cache_dir=download_dir, + revision="main" if revision is None else revision, + **kwargs, + ) + + return cls(tokenizer) + def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: + super().__init__() + from mistral_common.protocol.instruct.validator import ValidationMode from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, @@ -211,22 +240,6 @@ def __init__(self, tokenizer: "TransformersMistralTokenizer") -> None: self._vocab = self.tokenizer._vocab self._max_token_id = self.vocab_size - 1 - @classmethod - def from_pretrained( - cls, path_or_repo_id: str, *, revision: str | None = None - ) -> "MistralTokenizer": - from mistral_common.protocol.instruct.validator import ValidationMode - from transformers.tokenization_mistral_common import ( - MistralCommonTokenizer as TransformersMistralTokenizer, - ) - - str_revision = "main" if revision is None else revision - return cls( - TransformersMistralTokenizer.from_pretrained( - path_or_repo_id, revision=str_revision, mode=ValidationMode.test - ) - ) - def _get_special_token_ids(self) -> list[int]: from mistral_common.tokens.tokenizers.sentencepiece import ( SentencePieceTokenizer, @@ -271,6 +284,10 @@ def bos_token_id(self) -> int: def eos_token_id(self) -> int: return self.tokenizer.eos_id + @property 
+ def pad_token_id(self) -> int: + return self.tokenizer.pad_id + @property def is_fast(self) -> bool: return True @@ -298,12 +315,12 @@ def __len__(self) -> int: def __call__( self, - text: str | list[str] | list[int], + text: str | list[str], text_pair: str | None = None, - add_special_tokens: bool = False, + add_special_tokens: bool = True, truncation: bool = False, max_length: int | None = None, - ): + ) -> "BatchEncoding": if text_pair is not None: raise ValueError( "`text_pair` is not supported by `MistralTokenizer.__call__`." @@ -342,13 +359,11 @@ def encode( text: str, truncation: bool | None = None, max_length: int | None = None, - add_special_tokens: bool | None = None, + add_special_tokens: bool = True, ) -> list[int]: # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962 # is in, directly call self.transformers_tokenizer.encode(...). - encoded = self.tokenizer.encode( - text, bos=add_special_tokens is not False, eos=False - ) + encoded = self.tokenizer.encode(text, bos=add_special_tokens, eos=False) if truncation is not False and max_length is not None: return encoded[:max_length] @@ -383,7 +398,7 @@ def apply_chat_template( return_dict=False, ) - def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: # TODO(juliendenize): once https://github.com/huggingface/transformers/pull/41962 # is in, directly call self.transformers_tokenizer.decode(...). if isinstance(ids, int): @@ -455,7 +470,7 @@ def convert_tokens_to_string(self, tokens: list[str]) -> str: def convert_ids_to_tokens( self, ids: list[int], - skip_special_tokens: bool = True, + skip_special_tokens: bool = False, ) -> list[str]: from mistral_common.tokens.tokenizers.base import ( SpecialTokenPolicy, diff --git a/vllm/tokenizers/protocol.py b/vllm/tokenizers/protocol.py index 58a1a7c23f21..6c807bd99878 100644 --- a/vllm/tokenizers/protocol.py +++ b/vllm/tokenizers/protocol.py @@ -1,11 +1,11 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project - +from pathlib import Path from typing import TYPE_CHECKING, Any, Protocol -from typing_extensions import Self - if TYPE_CHECKING: + from transformers import BatchEncoding + from vllm.entrypoints.chat_utils import ChatCompletionMessageParam @@ -13,11 +13,13 @@ class TokenizerLike(Protocol): @classmethod def from_pretrained( cls, - pretrained_model_name_or_path: str, - /, - *, + path_or_repo_id: str | Path, + *args, + trust_remote_code: bool = False, revision: str | None = None, - ) -> Self: + download_dir: str | None = None, + **kwargs, + ) -> "TokenizerLike": raise NotImplementedError @property @@ -36,6 +38,10 @@ def bos_token_id(self) -> int: def eos_token_id(self) -> int: raise NotImplementedError + @property + def pad_token_id(self) -> int: + raise NotImplementedError + @property def is_fast(self) -> bool: raise NotImplementedError @@ -60,12 +66,12 @@ def __len__(self) -> int: def __call__( self, - text: str | list[str] | list[int], + text: str | list[str], text_pair: str | None = None, - add_special_tokens: bool = False, + add_special_tokens: bool = True, truncation: bool = False, max_length: int | None = None, - ): + ) -> "BatchEncoding": raise NotImplementedError def get_vocab(self) -> dict[str, int]: @@ -79,7 +85,7 @@ def encode( text: str, truncation: bool | None = None, max_length: int | None = None, - add_special_tokens: bool | None = None, + add_special_tokens: bool = True, ) -> 
list[int]: raise NotImplementedError @@ -94,12 +100,12 @@ def apply_chat_template( def convert_tokens_to_string(self, tokens: list[str]) -> str: raise NotImplementedError - def decode(self, ids: list[int] | int, skip_special_tokens: bool = True) -> str: + def decode(self, ids: list[int] | int, skip_special_tokens: bool = False) -> str: raise NotImplementedError def convert_ids_to_tokens( self, ids: list[int], - skip_special_tokens: bool = True, + skip_special_tokens: bool = False, ) -> list[str]: raise NotImplementedError diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 87d5cc2b483f..3bddfcd8a071 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import contextlib -import copy import importlib.util import os import warnings @@ -11,14 +9,17 @@ from typing import TYPE_CHECKING, Any import huggingface_hub -from transformers import AutoTokenizer, PreTrainedTokenizerBase from typing_extensions import assert_never from vllm import envs from vllm.logger import init_logger -from vllm.tokenizers import MistralTokenizer, TokenizerLike, TokenizerRegistry +from vllm.tokenizers import ( + HfTokenizer, + MistralTokenizer, + TokenizerLike, + TokenizerRegistry, +) -from .config import get_sentence_transformer_tokenizer_config from .gguf_utils import get_gguf_file_path_from_hf from .repo_utils import list_filtered_repo_files from .utils import check_gguf_file, is_gguf, is_remote_gguf, split_remote_gguf @@ -41,6 +42,18 @@ def __getattr__(name: str): ) return TokenizerLike + if name == "get_cached_tokenizer": + from vllm.tokenizers.hf import get_cached_tokenizer + + warnings.warn( + "`vllm.transformers_utils.tokenizer.get_cached_tokenizer` " + "has been moved to `vllm.tokenizers.hf.get_cached_tokenizer`. " + "The old name will be removed in v0.13.", + DeprecationWarning, + stacklevel=2, + ) + + return get_cached_tokenizer raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @@ -58,10 +71,12 @@ def decode_tokens( `skip_special_tokens=None` means to use the backend's default settings. """ + kw_args: dict[str, Any] = {} + if skip_special_tokens is not None: - return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) + kw_args["skip_special_tokens"] = skip_special_tokens - return tokenizer.decode(token_ids) + return tokenizer.decode(token_ids, **kw_args) def encode_tokens( @@ -93,56 +108,6 @@ def encode_tokens( return tokenizer.encode(text, **kw_args) -def get_cached_tokenizer(tokenizer: TokenizerLike) -> TokenizerLike: - """ - By default, transformers will recompute multiple tokenizer properties - each time they are called, leading to a significant slowdown. - This proxy caches these properties for faster access. - """ - cached_tokenizer = copy.copy(tokenizer) - - tokenizer_all_special_ids = tokenizer.all_special_ids - tokenizer_all_special_tokens = tokenizer.all_special_tokens - tokenizer_vocab = tokenizer.get_vocab() - tokenizer_len = len(tokenizer) - - max_token_id = max(tokenizer_vocab.values()) - # Some tokenizers (e.g., QwenTokenizer) have special tokens that - # are added and included in the implementation of the vocab_size - # property, but not in get_vocab(); if there is an implementation - # of vocab size, we should take the greater value. 
- if hasattr(tokenizer, "vocab_size"): - with contextlib.suppress(NotImplementedError): - max_token_id = max(max_token_id, tokenizer.vocab_size) - - class CachedTokenizer(tokenizer.__class__): # type: ignore - @property - def all_special_ids(self) -> list[int]: - return tokenizer_all_special_ids - - @property - def all_special_tokens(self) -> list[str]: - return tokenizer_all_special_tokens - - @property - def max_token_id(self) -> int: - return max_token_id - - def get_vocab(self) -> dict[str, int]: - return tokenizer_vocab - - def __len__(self) -> int: - return tokenizer_len - - def __reduce__(self): - return get_cached_tokenizer, (tokenizer,) - - CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" - - cached_tokenizer.__class__ = CachedTokenizer - return cached_tokenizer - - def get_tokenizer( tokenizer_name: str | Path, *args, @@ -217,66 +182,37 @@ def get_tokenizer( if tokenizer_mode == "mistral": logger.debug_once(f"Loading MistralTokenizer from {tokenizer_name}") tokenizer = MistralTokenizer.from_pretrained( - str(tokenizer_name), revision=revision + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, ) elif tokenizer_mode == "custom": logger.debug_once(f"Loading CustomTokenizer from {tokenizer_name}") tokenizer = TokenizerRegistry.get_tokenizer( str(tokenizer_name), *args, + trust_remote_code=trust_remote_code, revision=revision, download_dir=download_dir, **kwargs, ) else: - try: - logger.debug_once(f"Loading AutoTokenizer from {tokenizer_name}") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - revision=revision, - **kwargs, - ) - except ValueError as e: - # If the error pertains to the tokenizer class not existing or not - # currently being imported, - # suggest using the --trust-remote-code flag. - if not trust_remote_code and ( - "does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e) - ): - err_msg = ( - "Failed to load the tokenizer. If the tokenizer " - "is a custom tokenizer not yet available in the " - "HuggingFace transformers library, consider " - "setting `trust_remote_code=True` in LLM or using " - "the `--trust-remote-code` flag in the CLI." - ) - raise RuntimeError(err_msg) from e - else: - raise e - - # The special_tokens in tokenizer should also be - # controlled by do_lower_case in encoder_config - encoder_config = get_sentence_transformer_tokenizer_config( - tokenizer_name, revision + logger.debug_once(f"Loading HfTokenizer from {tokenizer_name}") + tokenizer = HfTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + + if not tokenizer.is_fast: + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead." ) - if isinstance(encoder_config, dict) and encoder_config.get( - "do_lower_case", False - ): - assert isinstance(tokenizer, PreTrainedTokenizerBase) - special_tokens_map = { - k: v.lower() for k, v in tokenizer.special_tokens_map.items() - } - tokenizer.add_special_tokens(special_tokens_map) - - if not tokenizer.is_fast: - logger.warning( - "Using a slow tokenizer. This might cause a significant " - "slowdown. Consider using a fast tokenizer instead." 
- ) - tokenizer = get_cached_tokenizer(tokenizer) return tokenizer diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 6c0acd9a9f59..dce8765fcf6b 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -9,8 +9,8 @@ from transformers import PreTrainedTokenizerFast from vllm.logger import init_logger +from vllm.tokenizers import TokenizerLike from vllm.tokenizers.detokenizer_utils import ( - TokenizerLike, convert_prompt_ids_to_tokens, detokenize_incrementally, ) From dca47edcb382d220cf2c9adc5407aad971a462bb Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sat, 29 Nov 2025 14:58:21 +0000 Subject: [PATCH 2/4] Pass args Signed-off-by: DarkLight1337 --- vllm/transformers_utils/tokenizer.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 3bddfcd8a071..622d5c7fe993 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -186,6 +186,7 @@ def get_tokenizer( *args, trust_remote_code=trust_remote_code, revision=revision, + download_dir=download_dir, **kwargs, ) elif tokenizer_mode == "custom": @@ -205,6 +206,7 @@ def get_tokenizer( *args, trust_remote_code=trust_remote_code, revision=revision, + download_dir=download_dir, **kwargs, ) From 7045c0205587b8063258faf2d7aea2a720a44ccc Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 30 Nov 2025 04:13:41 +0000 Subject: [PATCH 3/4] Update tests Signed-off-by: DarkLight1337 --- .buildkite/test-amd.yaml | 8 ++++---- .buildkite/test-pipeline.yaml | 8 ++++---- tests/tokenizers_/test_mistral.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 7cb0a9bf761e..687b6b08507c 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -61,8 +61,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 10min - timeout_in_minutes: 15 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min + timeout_in_minutes: 20 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking @@ -310,8 +310,8 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 4ad2fe3cc3a2..9f2107fb1e5a 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -57,8 +57,8 @@ steps: - pytest -v -s -m 'not cpu_test' multimodal - pytest -v -s utils_ -- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 10min - timeout_in_minutes: 15 +- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min + timeout_in_minutes: 20 source_file_dependencies: - vllm/ - tests/test_inputs.py @@ -278,8 +278,8 @@ steps: - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional -- label: Engine Test # 25min - timeout_in_minutes: 40 +- label: Engine Test # 9min + timeout_in_minutes: 15 mirror_hardwares: [amdexperimental] source_file_dependencies: - vllm/ diff --git a/tests/tokenizers_/test_mistral.py b/tests/tokenizers_/test_mistral.py index 0706a94791dc..92efac86dff2 100644 --- a/tests/tokenizers_/test_mistral.py +++ 
b/tests/tokenizers_/test_mistral.py @@ -356,8 +356,8 @@ def test_call(self, mistral_tokenizer: MistralTokenizer): ) attn_mask = [1 for _ in range(len(token_ids))] - # Test 1: default - assert mistral_tokenizer("Hello world !") == { + # Test 1: no special tokens + assert mistral_tokenizer("Hello world !", add_special_tokens=False) == { "attention_mask": attn_mask[1:], "input_ids": token_ids[1:], } @@ -381,7 +381,7 @@ def test_call(self, mistral_tokenizer: MistralTokenizer): "input_ids": token_ids, } # Test 5: empty string - assert mistral_tokenizer("") == { + assert mistral_tokenizer("", add_special_tokens=False) == { "attention_mask": [], "input_ids": [], } From ef50006c46b22d5737b93694fcac99f24337cd32 Mon Sep 17 00:00:00 2001 From: DarkLight1337 Date: Sun, 30 Nov 2025 04:52:42 +0000 Subject: [PATCH 4/4] Fix Signed-off-by: DarkLight1337 --- tests/tokenizers_/test_registry.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/tokenizers_/test_registry.py b/tests/tokenizers_/test_registry.py index 46918e5fa3eb..b357669f8378 100644 --- a/tests/tokenizers_/test_registry.py +++ b/tests/tokenizers_/test_registry.py @@ -21,6 +21,10 @@ def eos_token_id(self) -> int: def pad_token_id(self) -> int: return 2 + @property + def is_fast(self) -> bool: + return True + def test_customized_tokenizer(): TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)
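
For reference, a minimal sketch of how the cached-tokenizer wrapper that this series moves into vllm/tokenizers/hf.py is exercised, mirroring the renamed tests/tokenizers_/test_hf.py; the "gpt2" checkpoint is only an illustrative example, and the snippet is not part of the patch itself:

    from transformers import AutoTokenizer

    from vllm.tokenizers.hf import get_cached_tokenizer

    # Load a plain Hugging Face tokenizer, then wrap it so the expensive
    # properties (all_special_ids, all_special_tokens, get_vocab(), len())
    # are computed once up front instead of on every access.
    raw = AutoTokenizer.from_pretrained("gpt2")
    cached = get_cached_tokenizer(raw)

    # The cached values match what the underlying tokenizer would compute.
    assert cached.all_special_ids == raw.all_special_ids
    assert cached.get_vocab() == raw.get_vocab()

    # The proxy subclasses the original tokenizer class, so isinstance checks
    # and the rest of the tokenizer API keep working unchanged.
    assert isinstance(cached, type(raw))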