Commit e47d90b

Qwen VL 7B vLLM support
1 parent 67f5893 commit e47d90b

File tree

1 file changed: +59 -190 lines changed

models/tt_transformers/tt/generator_vllm.py

Lines changed: 59 additions & 190 deletions
@@ -3,13 +3,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import os
+from types import SimpleNamespace
 from typing import List, Union
 
 import PIL
 import torch
 from llama_models.llama3.api.chat_format import create_vision_mask
 from tqdm import tqdm
-from transformers import AutoProcessor
 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs
 from vllm.model_executor.models.interfaces import SupportsMultiModal
 
@@ -194,110 +194,41 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI
     return inputs
 
 
-# TODO: Update input processor to inherit from EncDecMultiModalProcessor as is done in vllm.model_executor.models.mllama.py
-def input_processor_for_qwen2_5_vl(
-    ctx: InputContext,
-    inputs: EncoderDecoderInputs,
-) -> EncoderDecoderInputs:
-    """
-    This was based on a previous version of vllm.model_executor.models.mllama.py::input_processor_for_mllama()
-    without the additional processing for computing num_tiles (here it is fixed).
-    """
-    # Example input to processor:
-    # {
-    #     'encoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-    #     },
-    #     'decoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000],
-    #     },
-    # }
-
-    # Move encoder_prompt to prompt. If the user does not explicitly provide separate
-    # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt.
-    # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt
-    # must contain the full text prompt.
-    dec_inputs = TokenInputs(**inputs)
-
-    if os.environ.get("MESH_DEVICE") == "N300":
-        prompt_len = len(dec_inputs.get("prompt_token_ids"))
-        MAX_PROMPT_LEN = 8192
-        if prompt_len > MAX_PROMPT_LEN:
-            raise ValueError(
-                f"TT-LLama11B-Vision does not support prompts longer than {MAX_PROMPT_LEN} tokens on N300 (received prompt with {prompt_len} tokens)"
-            )
-
-    multi_modal_data = dec_inputs.get("multi_modal_data")
-    if multi_modal_data is None or "image" not in multi_modal_data:
-        # text-only
-        return EncoderDecoderInputs(
-            encoder=token_inputs([]),
-            decoder=dec_inputs,
-        )
-
-    # Set encoder prompt length based on the number of vision tokens so block manager allocates enough blocks (cross block tables).
-    # hf_config = ctx.model_config.hf_config
-    # vision_config = hf_config.vision_config
-    # assert vision_config.image_size % 14 == 0, "chunk size should be multiple of 14"
-    # token_per_chunk = nearest_32(
-    #     (vision_config.image_size // 14) ** 2 + 1
-    # )  # Note: we use nearest 32 while vLLM does not by default
-    # num_vision_tokens = (
-    #     vision_config.max_num_tiles * token_per_chunk
-    # )  # Note: we use max_num_tiles while vLLM uses num_tiles by default
-
-    hf_config = ctx.model_config.hf_config
-    vision_config = hf_config.vision_config
-
-    # Infer image size from window_size and spatial_patch_size
-    # Qwen uses windowed attention, and window_size = image_size // patch_size
-    # So image_size = window_size * patch_size
-    image_size = vision_config.window_size * vision_config.spatial_patch_size  # e.g., 112 * 14 = 1568
+def input_processor_for_qwen25_vl(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
+    input_processor = ctx.get_hf_processor()
+    if "prompt" in inputs:
+        prompt_text = inputs["prompt"]
+    else:
+        # [INFO] with current version of vLLM, in server mode, inputs["prompt"] gives KeyError; only inputs['prompt_token_ids'] is available
+        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
+        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
+    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
+        images = inputs["multi_modal_data"]["image"]
+    else:
+        images = None
+
+    processed_inputs = input_processor(
+        text=prompt_text,  # [INFO] Qwen2VLProcessor handles the case where text is a string or a list of strings
+        images=images,
+        videos=None,  # [INFO] videos are not supported yet
+        return_tensors="pt",
+    )
 
-    # Optional: verify it's divisible by 14 if needed
-    assert image_size % vision_config.spatial_patch_size == 0, "chunk size should be multiple of patch size"
+    assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM"
+    return {
+        "type": inputs["type"],
+        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
+        "prompt": prompt_text,
+        "multi_modal_data": {"image": processed_inputs},  # [INFO] add processed_inputs
+    }
 
-    token_per_chunk = nearest_32((image_size // vision_config.spatial_patch_size) ** 2 + 1)
 
-    # Qwen2.5-VL does not use max_num_tiles, but you can set it manually or derive it from your image splitting strategy
-    # Example: treat whole image as 1 tile unless your pipeline splits into tiles
-    num_tiles = getattr(vision_config, "max_num_tiles", 1)  # fallback to 1 if not defined
+class CustomNamespace(SimpleNamespace):
+    def __contains__(self, key):
+        return key in self.__dict__
 
-    num_vision_tokens = num_tiles * token_per_chunk
 
-    # Example output from processor:
-    # {
-    #     'encoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128256, 128256, ..., 128256],
-    #         'prompt': '<|image|><|image|>...<|image|>',
-    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-    #     },
-    #     'decoder': {
-    #         'type': 'token',
-    #         'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30],  # noqa: E501
-    #         'prompt': '<|image|><|begin_of_text|>What is the content of this image?',  # noqa: E501
-    #         'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>},  # noqa: E501
-    #     },
-    # }
-    MLLAMA_IMAGE_TOKEN_ID = hf_config.image_token_id
-    MLLAMA_IMAGE_TOKEN = "<|image_pad|>"
-
-    return EncoderDecoderInputs(
-        encoder=token_inputs(
-            prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_vision_tokens,
-            prompt=MLLAMA_IMAGE_TOKEN * num_vision_tokens,
-            multi_modal_data=multi_modal_data,
-        ),
-        decoder=dec_inputs,
-    )
-
-
-@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_5_vl)
+@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen25_vl)
 class Qwen2_5_VLForConditionalGeneration(Generator, SupportsMultiModal):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
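
For context on what the new input processor produces, here is a minimal standalone sketch (not part of the commit). It calls the Hugging Face Qwen2.5-VL processor directly, which should be essentially the same processor that ctx.get_hf_processor() returns inside vLLM; the checkpoint name, prompt template, and image size below are illustrative assumptions.

from PIL import Image
from transformers import AutoProcessor

# Illustrative sketch only (assumes a recent transformers build with Qwen2.5-VL
# support and access to the Qwen/Qwen2.5-VL-7B-Instruct checkpoint).
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

# A chat-formatted prompt with one image placeholder, as vLLM would pass it
# (or as reconstructed from prompt_token_ids via processor.decode in server mode).
prompt_text = (
    "<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
    "What is shown in this image?<|im_end|>\n<|im_start|>assistant\n"
)
image = Image.new("RGB", (448, 448))  # placeholder image, just for illustration

processed_inputs = processor(text=prompt_text, images=[image], videos=None, return_tensors="pt")

# The single <|image_pad|> placeholder is expanded to one token per merged vision
# patch, so input_ids is much longer than the raw prompt; image_grid_thw records
# the (temporal, height, width) patch grid, e.g. [[1, 32, 32]] for a 448x448 input.
print(processed_inputs.input_ids.shape)
print(processed_inputs.image_grid_thw)

This expansion is why the diff above can return processed_inputs.input_ids as the new prompt_token_ids and simply pass the whole processed batch along under multi_modal_data.
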
@@ -336,100 +267,38 @@ def cache_path(self):
     def max_cross_attn_tokens(self):
         return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok)
 
-    def encode_input(self, token, image, processor):
-        print(image)
-        if image:
-            print
-            hf_messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "image",
-                            "image": image,
-                        },
-                        {"type": "text", "text": self.model_args[0].tokenizer.decode(token)},
-                    ],
-                }
-            ]
-        else:
-            hf_messages = [
-                {
-                    "role": "user",
-                    "content": [
-                        {"type": "text", "text": self.model_args[0].tokenizer.decode(token)},
-                    ],
-                }
-            ]
-
-        encoded = processor.apply_chat_template(
-            hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt"
-        ).to("cpu", dtype=torch.bfloat16)
-
-        return encoded
-
-    def prefill_forward(
-        self,
-        tokens: torch.Tensor,
-        images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]],
-        page_table: torch.Tensor,
-        kv_cache,
-        prompt_lens,
-        cross_page_table=None,
-    ):
-        """
-        Replaces prefill_forward from Generator with a version that supports mask creation.
-        """
-        batch = tokens.shape[0]
-
-        vision_images = []
-        tokens_list = []
-        image_grid_thw = []
-
-        processor = AutoProcessor.from_pretrained(self.model_args[0].CKPT_DIR)
-
-        for user_id in range(batch):
-            image = images[user_id]
-            if isinstance(image, list):
-                assert len(image) == 1, "Only one image is supported for each user in the batch"
-                image = image[0]
-
-            prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])]
-            encoded_input = self.encode_input(prompt_tokens, image, processor)
-            vision_images.append(encoded_input["pixel_values"] if image else None)
-            tokens_list.append(encoded_input["input_ids"].squeeze(0))
-            image_grid_thw.append(encoded_input["image_grid_thw"] if image else None)
-
-        prefill_lens = torch.tensor([len(token) for token in tokens_list], dtype=torch.long)
-        total_lens = prefill_lens + self.max_gen_len
-
-        pad_id = processor.tokenizer.pad_token_id
-        tokens = torch.full((batch, max(total_lens)), pad_id, dtype=torch.long)
-
-        for i, seq in enumerate(tokens_list):
-            tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long)
-
-        self.prefill_lens = prefill_lens
-
-        return super().prefill_forward(
-            vision_images,
-            None,
-            tokens,
-            None,
-            total_lens=total_lens,
-            prompt_lens=prefill_lens,
+    def prefill_forward(self, *args, **kwargs):
+        self.tokenizer = self.model_args[0].tokenizer
+        pad_token_id = self.tokenizer.pad_token_id
+
+        tokens = kwargs["tokens"]
+        prompt_lens = kwargs["prompt_lens"]
+        inputs = CustomNamespace()
+        inputs.input_ids = tokens
+        data = kwargs.get("images", None)  # This contains the entire Data list, not just the pixel values
+        for i in range(tokens.shape[0]):  # for each user, fix their padding
+            tokens[i][prompt_lens[i] :] = pad_token_id
+        pixel_values, image_grid_thw = None, None
+
+        if hasattr(data[0], "pixel_values"):
+            # If inputs is a list of objects with .pixel_values, concatenate them
+            pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
+            image_grid_thw = [im.image_grid_thw for im in data if hasattr(im, "image_grid_thw")]
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
             page_table=page_table,
             kv_cache=kv_cache,
-            cross_page_table=cross_page_table,
-            image_grid_thw=image_grid_thw,
-        )[0]
+            prompt_lens=prompt_lens,
+            pixel_values=pixel_values if pixel_values else None,
+            image_grid_thw=image_grid_thw if image_grid_thw else None,
+        )
 
     def decode_forward(self, *args, **kwargs):
-        if kwargs.get("start_pos") is not None:
-            kwargs["start_pos"][: len(self.prefill_lens)] = self.prefill_lens
-        logits = super().decode_forward_text(*args, **kwargs)
-        self.prefill_lens += 1
-        return logits
+        return super().decode_forward_text(*args, **kwargs)
 
     def allocate_kv_cache(self, *args, **kwargs):
         return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
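
A small self-contained sketch of how the rewritten prefill_forward consumes those processed inputs (again illustrative, not part of the commit). The _ProcessedImage class below is a hypothetical stand-in for the processed batch that the input processor stores under multi_modal_data; the list comprehensions mirror the hasattr checks in the diff.

from types import SimpleNamespace

import torch


class CustomNamespace(SimpleNamespace):
    def __contains__(self, key):
        return key in self.__dict__


# Hypothetical stand-in for the processed inputs handed to prefill_forward via
# kwargs["images"]; only the two attributes the diff actually reads are modeled.
class _ProcessedImage(SimpleNamespace):
    pass


# Two users in a batch: one with an image, one text-only (None).
data = [
    _ProcessedImage(pixel_values=torch.zeros(256, 1176), image_grid_thw=torch.tensor([[1, 32, 32]])),
    None,
]

inputs = CustomNamespace()
inputs.input_ids = torch.zeros(2, 32, dtype=torch.long)
print("input_ids" in inputs)  # True: __contains__ makes the namespace dict-like

# Mirrors the hasattr checks in the new prefill_forward: only entries that carry
# vision tensors contribute to pixel_values / image_grid_thw.
pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
image_grid_thw = [im.image_grid_thw for im in data if hasattr(im, "image_grid_thw")]
print(len(pixel_values), len(image_grid_thw))  # 1 1

Keeping pixel_values and image_grid_thw as per-user lists, and passing None when no user has an image, leaves the batching of vision inputs to the tt_transformers Generator.prefill_forward_text that this class delegates to.
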
