13 changes: 6 additions & 7 deletions .buildkite/test-amd.yaml
@@ -61,8 +61,8 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_

- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
timeout_in_minutes: 10
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
timeout_in_minutes: 20
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
@@ -72,6 +72,7 @@ steps:
- tests/test_outputs.py
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -80,6 +81,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s transformers_utils
- pytest -v -s config

@@ -308,23 +310,20 @@ steps:
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional

- label: Engine Test # 25min
timeout_in_minutes: 40
- label: Engine Test # 9min
timeout_in_minutes: 15
mirror_hardwares: [amdexperimental, amdproduction]
agent_pool: mi325_1
# grade: Blocking
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
- tests/test_vllm_port
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenizers_

- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
13 changes: 6 additions & 7 deletions .buildkite/test-pipeline.yaml
@@ -57,14 +57,15 @@ steps:
- pytest -v -s -m 'not cpu_test' multimodal
- pytest -v -s utils_

- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 4 mins
timeout_in_minutes: 10
- label: Async Engine, Inputs, Utils, Worker, Config Test (CPU) # 15min
timeout_in_minutes: 20
source_file_dependencies:
- vllm/
- tests/test_inputs.py
- tests/test_outputs.py
- tests/multimodal
- tests/standalone_tests/lazy_imports.py
- tests/tokenizers_
- tests/transformers_utils
- tests/config
no_gpu: true
@@ -73,6 +74,7 @@ steps:
- pytest -v -s test_inputs.py
- pytest -v -s test_outputs.py
- pytest -v -s -m 'cpu_test' multimodal
- pytest -v -s tokenizers_
- pytest -v -s transformers_utils
- pytest -v -s config

@@ -276,21 +278,18 @@ steps:
- pytest -v -s test_regression.py
working_dir: "/vllm-workspace/tests" # optional

- label: Engine Test # 25min
timeout_in_minutes: 40
- label: Engine Test # 9min
timeout_in_minutes: 15
mirror_hardwares: [amdexperimental]
source_file_dependencies:
- vllm/
- tests/engine
- tests/tokenizers_
- tests/test_sequence
- tests/test_config
- tests/test_logger
- tests/test_vllm_port
commands:
- pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py
# OOM in the CI unless we run this separately
- pytest -v -s tokenizers_

- label: V1 Test e2e + engine # 30min
timeout_in_minutes: 45
2 changes: 1 addition & 1 deletion docs/design/huggingface_integration.md
@@ -21,7 +21,7 @@ Let's say we want to serve the popular Qwen model by running `vllm serve Qwen/Qw

Beyond that, there are two more things vLLM depends on Hugging Face for.

1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [get_cached_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L24).
1. **Tokenizer**: vLLM uses the tokenizer from Hugging Face to tokenize the input text. The tokenizer is loaded using [AutoTokenizer.from_pretrained](https://huggingface.co/docs/transformers/en/model_doc/auto#transformers.AutoTokenizer.from_pretrained) with the `model` argument as the model name and the `--revision` argument as the revision. It is also possible to use a tokenizer from another model by specifying the `--tokenizer` argument in the `vllm serve` command. Other relevant arguments are `--tokenizer-revision` and `--tokenizer-mode`. Please check Hugging Face's documentation for the meaning of these arguments. This part of the logic can be found in the [get_tokenizer](https://github.com/vllm-project/vllm/blob/127c07480ecea15e4c2990820c457807ff78a057/vllm/transformers_utils/tokenizer.py#L87) function. After obtaining the tokenizer, notably, vLLM will cache some expensive attributes of the tokenizer in [vllm.tokenizers.hf.get_cached_tokenizer][].

2. **Model weight**: vLLM downloads the model weight from the Hugging Face model hub using the `model` argument as the model name and the `--revision` argument as the revision. vLLM provides the argument `--load-format` to control what files to download from the model hub. By default, it will try to load the weights in the safetensors format and fall back to the PyTorch bin format if the safetensors format is not available. We can also pass `--load-format dummy` to skip downloading the weights.
- It is recommended to use the safetensors format, as it is efficient for loading in distributed inference and also safe from arbitrary code execution. See the [documentation](https://huggingface.co/docs/safetensors/en/index) for more information on the safetensors format. This part of the logic can be found [here](https://github.com/vllm-project/vllm/blob/10b67d865d92e376956345becafc249d4c3c0ab7/vllm/model_executor/model_loader/loader.py#L385). Please note that:
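To make the documentation change above concrete, here is a minimal sketch of the flow that paragraph describes, outside vLLM's internals: the `model`/`--revision` arguments map onto `AutoTokenizer.from_pretrained`, and the result is wrapped by the caching helper the doc now links as `vllm.tokenizers.hf.get_cached_tokenizer`. The repo id below is illustrative; vLLM's real call sites live in `get_tokenizer`.

```python
# Sketch of the documented tokenizer flow; "gpt2" stands in for the served model.
from transformers import AutoTokenizer

from vllm.tokenizers.hf import get_cached_tokenizer

hf_tokenizer = AutoTokenizer.from_pretrained("gpt2", revision="main")  # --revision
tokenizer = get_cached_tokenizer(hf_tokenizer)  # caches expensive attributes

ids = tokenizer("Hello from the vLLM docs")["input_ids"]
print(len(tokenizer), ids)
```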
@@ -7,7 +7,7 @@
from transformers import AutoTokenizer

from vllm.tokenizers import TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.tokenizers.hf import get_cached_tokenizer


@pytest.mark.parametrize("model_id", ["gpt2", "zai-org/chatglm3-6b"])
6 changes: 3 additions & 3 deletions tests/tokenizers_/test_mistral.py
@@ -356,8 +356,8 @@ def test_call(self, mistral_tokenizer: MistralTokenizer):
)
attn_mask = [1 for _ in range(len(token_ids))]

# Test 1: default
assert mistral_tokenizer("Hello world !") == {
# Test 1: no special tokens
assert mistral_tokenizer("Hello world !", add_special_tokens=False) == {
"attention_mask": attn_mask[1:],
"input_ids": token_ids[1:],
}
@@ -381,7 +381,7 @@ def test_call(self, mistral_tokenizer: MistralTokenizer):
"input_ids": token_ids,
}
# Test 5: empty string
assert mistral_tokenizer("") == {
assert mistral_tokenizer("", add_special_tokens=False) == {
"attention_mask": [],
"input_ids": [],
}
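The updated assertions above imply a behavioral change: calling the Mistral tokenizer with default arguments now adds special tokens, so the tests pass `add_special_tokens=False` to reproduce the old output. A hedged sketch of that difference; the repo id and the exact `from_pretrained` call are assumptions based on the surrounding tests, not taken from this diff.

```python
# Illustrative only: repo id and from_pretrained usage are assumptions.
from vllm.tokenizers import MistralTokenizer

tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

default_ids = tok("Hello world !")["input_ids"]
plain_ids = tok("Hello world !", add_special_tokens=False)["input_ids"]

# With defaults, the encoding is expected to start with the BOS token;
# disabling special tokens should drop it (mirroring Test 1 above).
assert default_ids[0] == tok.bos_token_id
assert default_ids[1:] == plain_ids
```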
16 changes: 11 additions & 5 deletions tests/tokenizers_/test_registry.py
@@ -17,20 +17,26 @@ def bos_token_id(self) -> int:
def eos_token_id(self) -> int:
return 1

@property
def pad_token_id(self) -> int:
return 2

@property
def is_fast(self) -> bool:
return True


def test_customized_tokenizer():
TokenizerRegistry.register(
"test_tokenizer",
__name__,
TestTokenizer.__name__,
)
TokenizerRegistry.register("test_tokenizer", __name__, TestTokenizer.__name__)

tokenizer = TokenizerRegistry.get_tokenizer("test_tokenizer")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
assert tokenizer.pad_token_id == 2

tokenizer = get_tokenizer("test_tokenizer", tokenizer_mode="custom")
assert isinstance(tokenizer, TestTokenizer)
assert tokenizer.bos_token_id == 0
assert tokenizer.eos_token_id == 1
assert tokenizer.pad_token_id == 2
2 changes: 1 addition & 1 deletion tools/pre_commit/check_pickle_imports.py
@@ -27,7 +27,7 @@
"vllm/distributed/device_communicators/shm_broadcast.py",
"vllm/distributed/device_communicators/shm_object_storage.py",
"vllm/utils/hashing.py",
"tests/tokenizers_/test_cached_tokenizer.py",
"tests/tokenizers_/test_hf.py",
"tests/utils_/test_hashing.py",
"benchmarks/kernels/graph_machete_bench.py",
"benchmarks/kernels/benchmark_lora.py",
2 changes: 1 addition & 1 deletion vllm/entrypoints/llm.py
@@ -72,7 +72,7 @@
from vllm.sampling_params import BeamSearchParams, RequestOutputKind, SamplingParams
from vllm.tasks import PoolingTask
from vllm.tokenizers import MistralTokenizer, TokenizerLike
from vllm.transformers_utils.tokenizer import get_cached_tokenizer
from vllm.tokenizers.hf import get_cached_tokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils.collection_utils import as_iter, is_list_of
from vllm.utils.counter import Counter
4 changes: 2 additions & 2 deletions vllm/entrypoints/score_utils.py
@@ -51,8 +51,8 @@ def _cosine_similarity(
for emb_1, emb_2 in zip(embed_1, embed_2):
pair_score = scorer(emb_1.outputs.data, emb_2.outputs.data)

padding = []
if (pad_token_id := getattr(tokenizer, "pad_token_id", None)) is not None:
padding: list[int] = []
if (pad_token_id := tokenizer.pad_token_id) is not None:
padding = [pad_token_id]

tokens = emb_1.prompt_token_ids + padding + emb_2.prompt_token_ids
3 changes: 2 additions & 1 deletion vllm/tokenizers/__init__.py
@@ -1,8 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from .hf import HfTokenizer
from .mistral import MistralTokenizer
from .protocol import TokenizerLike
from .registry import TokenizerRegistry

__all__ = ["TokenizerLike", "MistralTokenizer", "TokenizerRegistry"]
__all__ = ["TokenizerLike", "HfTokenizer", "MistralTokenizer", "TokenizerRegistry"]
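Since `HfTokenizer` is now part of the package's public surface, a short usage sketch may help. It assumes only what the export list and the `HfTokenizer.from_pretrained` signature in this PR show; the repo id is illustrative.

```python
# Sketch using only what this PR adds; "gpt2" is an illustrative repo id.
from vllm.tokenizers import HfTokenizer

tokenizer = HfTokenizer.from_pretrained(
    "gpt2",
    trust_remote_code=False,  # flip to True only for trusted custom tokenizer code
    revision=None,            # or a specific branch / commit hash
)

# The returned object is a cached Hugging Face tokenizer (see hf.py below),
# so repeated vocab lookups reuse the same precomputed dict.
assert tokenizer.get_vocab() is tokenizer.get_vocab()
print(len(tokenizer), tokenizer.all_special_tokens)
```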
122 changes: 122 additions & 0 deletions vllm/tokenizers/hf.py
@@ -0,0 +1,122 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import copy
from pathlib import Path
from typing import TYPE_CHECKING

from transformers import AutoTokenizer

from vllm.transformers_utils.config import get_sentence_transformer_tokenizer_config

from .protocol import TokenizerLike

if TYPE_CHECKING:
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast


def get_cached_tokenizer(
tokenizer: "PreTrainedTokenizer | PreTrainedTokenizerFast",
) -> TokenizerLike:
"""
By default, transformers will recompute multiple tokenizer properties
each time they are called, leading to a significant slowdown.
This proxy caches these properties for faster access.
"""
cached_tokenizer = copy.copy(tokenizer)

tokenizer_all_special_ids = tokenizer.all_special_ids
tokenizer_all_special_tokens = tokenizer.all_special_tokens
tokenizer_vocab = tokenizer.get_vocab()
tokenizer_len = len(tokenizer)

max_token_id = max(tokenizer_vocab.values())
# Some tokenizers (e.g., QwenTokenizer) have special tokens that
# are added and included in the implementation of the vocab_size
# property, but not in get_vocab(); if there is an implementation
# of vocab size, we should take the greater value.
if hasattr(tokenizer, "vocab_size"):
with contextlib.suppress(NotImplementedError):
max_token_id = max(max_token_id, tokenizer.vocab_size)

class CachedTokenizer(tokenizer.__class__): # type: ignore
@property
def all_special_ids(self) -> list[int]:
return tokenizer_all_special_ids

@property
def all_special_tokens(self) -> list[str]:
return tokenizer_all_special_tokens

@property
def max_token_id(self) -> int:
return max_token_id

def get_vocab(self) -> dict[str, int]:
return tokenizer_vocab

def __len__(self) -> int:
return tokenizer_len

def __reduce__(self):
return get_cached_tokenizer, (tokenizer,)

CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}"

cached_tokenizer.__class__ = CachedTokenizer
return cached_tokenizer # type: ignore


class HfTokenizer(TokenizerLike):
@classmethod
def from_pretrained(
cls,
path_or_repo_id: str | Path,
*args,
trust_remote_code: bool = False,
revision: str | None = None,
download_dir: str | None = None,
**kwargs,
) -> "TokenizerLike":
try:
tokenizer = AutoTokenizer.from_pretrained(
path_or_repo_id,
*args,
trust_remote_code=trust_remote_code,
revision=revision,
cache_dir=download_dir,
**kwargs,
)
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported,
# suggest using the --trust-remote-code flag.
if not trust_remote_code and (
"does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e)
):
err_msg = (
"Failed to load the tokenizer. If the tokenizer "
"is a custom tokenizer not yet available in the "
"HuggingFace transformers library, consider "
"setting `trust_remote_code=True` in LLM or using "
"the `--trust-remote-code` flag in the CLI."
)
raise RuntimeError(err_msg) from e
else:
raise e

# The special_tokens in tokenizer should also be
# controlled by do_lower_case in encoder_config
encoder_config = get_sentence_transformer_tokenizer_config(
path_or_repo_id, revision
)
if isinstance(encoder_config, dict) and encoder_config.get(
"do_lower_case", False
):
special_tokens_map = {
k: v.lower() for k, v in tokenizer.special_tokens_map.items()
}
tokenizer.add_special_tokens(special_tokens_map)

return get_cached_tokenizer(tokenizer)
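A quick behavior sketch of the `CachedTokenizer` proxy defined above, based only on the code in this file: hot properties are precomputed once, the wrapper keeps the original tokenizer class for `isinstance` checks, and `__reduce__` lets it pickle back into a cached tokenizer. The repo id is illustrative.

```python
# Behavior sketch for get_cached_tokenizer; "gpt2" is illustrative only.
import pickle

from transformers import AutoTokenizer

from vllm.tokenizers.hf import get_cached_tokenizer

base = AutoTokenizer.from_pretrained("gpt2")
cached = get_cached_tokenizer(base)

# The proxy subclasses the original tokenizer class and renames itself.
assert isinstance(cached, type(base))
assert type(cached).__name__ == f"Cached{type(base).__name__}"

# Cached properties return the precomputed values.
assert cached.get_vocab() is cached.get_vocab()
assert cached.max_token_id >= max(cached.get_vocab().values())

# __reduce__ round-trips through get_cached_tokenizer during pickling.
restored = pickle.loads(pickle.dumps(cached))
assert restored.get_vocab() == cached.get_vocab()
```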