|
3 | 3 | # SPDX-License-Identifier: Apache-2.0 |
4 | 4 |
|
5 | 5 | import os |
| 6 | +from types import SimpleNamespace |
6 | 7 | from typing import List, Union |
7 | 8 |
|
8 | 9 | import PIL |
9 | 10 | import torch |
10 | 11 | from llama_models.llama3.api.chat_format import create_vision_mask |
11 | 12 | from tqdm import tqdm |
12 | | -from transformers import AutoProcessor |
13 | 13 | from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs |
14 | 14 | from vllm.model_executor.models.interfaces import SupportsMultiModal |
15 | 15 |
|
@@ -194,110 +194,41 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI |
194 | 194 | return inputs |
195 | 195 |
|
196 | 196 |
|
197 | | -# TODO: Update input processor to inherit from EncDecMultiModalProcessor as is done in vllm.model_executor.models.mllama.py |
198 | | -def input_processor_for_qwen2_5_vl( |
199 | | - ctx: InputContext, |
200 | | - inputs: EncoderDecoderInputs, |
201 | | -) -> EncoderDecoderInputs: |
202 | | - """ |
203 | | - This was based on a previous version of vllm.model_executor.models.mllama.py::input_processor_for_mllama() |
204 | | - without the additional processing for computing num_tiles (here it is fixed). |
205 | | - """ |
206 | | - # Example input to processor: |
207 | | - # { |
208 | | - # 'encoder': { |
209 | | - # 'type': 'token', |
210 | | - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 |
211 | | - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 |
212 | | - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 |
213 | | - # }, |
214 | | - # 'decoder': { |
215 | | - # 'type': 'token', |
216 | | - # 'prompt_token_ids': [128000], |
217 | | - # }, |
218 | | - # } |
219 | | - |
220 | | - # Move encoder_prompt to prompt. If the user does not explicitly provide separate |
221 | | - # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. |
222 | | - # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt |
223 | | - # must contain the full text prompt. |
224 | | - dec_inputs = TokenInputs(**inputs) |
225 | | - |
226 | | - if os.environ.get("MESH_DEVICE") == "N300": |
227 | | - prompt_len = len(dec_inputs.get("prompt_token_ids")) |
228 | | - MAX_PROMPT_LEN = 8192 |
229 | | - if prompt_len > MAX_PROMPT_LEN: |
230 | | - raise ValueError( |
231 | | - f"TT-LLama11B-Vision does not support prompts longer than {MAX_PROMPT_LEN} tokens on N300 (received prompt with {prompt_len} tokens)" |
232 | | - ) |
233 | | - |
234 | | - multi_modal_data = dec_inputs.get("multi_modal_data") |
235 | | - if multi_modal_data is None or "image" not in multi_modal_data: |
236 | | - # text-only |
237 | | - return EncoderDecoderInputs( |
238 | | - encoder=token_inputs([]), |
239 | | - decoder=dec_inputs, |
240 | | - ) |
241 | | - |
242 | | - # Set encoder prompt length based on the number of vision tokens so block manager allocates enough blocks (cross block tables). |
243 | | - # hf_config = ctx.model_config.hf_config |
244 | | - # vision_config = hf_config.vision_config |
245 | | - # assert vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" |
246 | | - # token_per_chunk = nearest_32( |
247 | | - # (vision_config.image_size // 14) ** 2 + 1 |
248 | | - # ) # Note: we use nearest 32 while vLLM does not by default |
249 | | - # num_vision_tokens = ( |
250 | | - # vision_config.max_num_tiles * token_per_chunk |
251 | | - # ) # Note: we use max_num_tiles while vLLM uses num_tiles by default |
252 | | - |
253 | | - hf_config = ctx.model_config.hf_config |
254 | | - vision_config = hf_config.vision_config |
255 | | - |
256 | | - # Infer image size from window_size and spatial_patch_size |
257 | | - # Qwen uses windowed attention, and window_size = image_size // patch_size |
258 | | - # So image_size = window_size * patch_size |
259 | | - image_size = vision_config.window_size * vision_config.spatial_patch_size # e.g., 112 * 14 = 1568 |
| 197 | +def input_processor_for_qwen25_vl(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): |
| 198 | + input_processor = ctx.get_hf_processor() |
| 199 | + if "prompt" in inputs: |
| 200 | + prompt_text = inputs["prompt"] |
| 201 | + else: |
| 202 | + # [INFO] with the current version of vLLM, in server mode, inputs["prompt"] raises a KeyError; only inputs["prompt_token_ids"] is available |
| 203 | + assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode" |
| 204 | + prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False) |
| 205 | + if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]: |
| 206 | + images = inputs["multi_modal_data"]["image"] |
| 207 | + else: |
| 208 | + images = None |
| 209 | + |
| 210 | + processed_inputs = input_processor( |
| 211 | + text=prompt_text, # [INFO] Qwen2VLProcessor handles the case where text is a string or a list of strings |
| 212 | + images=images, |
| 213 | + videos=None, # [INFO] videos are not supported yet |
| 214 | + return_tensors="pt", |
| 215 | + ) |
260 | 216 |
|
261 | | - # Optional: verify it's divisible by 14 if needed |
262 | | - assert image_size % vision_config.spatial_patch_size == 0, "chunk size should be multiple of patch size" |
| 217 | + assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM" |
| 218 | + return { |
| 219 | + "type": inputs["type"], |
| 220 | + "prompt_token_ids": processed_inputs.input_ids[0].tolist(), |
| 221 | + "prompt": prompt_text, |
| 222 | + "multi_modal_data": {"image": processed_inputs}, # [INFO] add processed_inputs |
| 223 | + } |
263 | 224 |
|
264 | | - token_per_chunk = nearest_32((image_size // vision_config.spatial_patch_size) ** 2 + 1) |
265 | 225 |
|
266 | | - # Qwen2.5-VL does not use max_num_tiles, but you can set it manually or derive it from your image splitting strategy |
267 | | - # Example: treat whole image as 1 tile unless your pipeline splits into tiles |
268 | | - num_tiles = getattr(vision_config, "max_num_tiles", 1) # fallback to 1 if not defined |
| 226 | +class CustomNamespace(SimpleNamespace): |
| 227 | + def __contains__(self, key): |
| 228 | + return key in self.__dict__ |
269 | 229 |
|
270 | | - num_vision_tokens = num_tiles * token_per_chunk |
271 | 230 |
|
272 | | - # Example output from processor: |
273 | | - # { |
274 | | - # 'encoder': { |
275 | | - # 'type': 'token', |
276 | | - # 'prompt_token_ids': [128256, 128256, ..., 128256], |
277 | | - # 'prompt': '<|image|><|image|>...<|image|>', |
278 | | - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 |
279 | | - # }, |
280 | | - # 'decoder': { |
281 | | - # 'type': 'token', |
282 | | - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 |
283 | | - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 |
284 | | - # 'multi_modal_data': {'image': <PIL.Image.Image image mode=RGB size=1770x1180 at 0x7FDE2C624880>}, # noqa: E501 |
285 | | - # }, |
286 | | - # } |
287 | | - MLLAMA_IMAGE_TOKEN_ID = hf_config.image_token_id |
288 | | - MLLAMA_IMAGE_TOKEN = "<|image_pad|>" |
289 | | - |
290 | | - return EncoderDecoderInputs( |
291 | | - encoder=token_inputs( |
292 | | - prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_vision_tokens, |
293 | | - prompt=MLLAMA_IMAGE_TOKEN * num_vision_tokens, |
294 | | - multi_modal_data=multi_modal_data, |
295 | | - ), |
296 | | - decoder=dec_inputs, |
297 | | - ) |
298 | | - |
299 | | - |
300 | | -@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_5_vl) |
| 231 | +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen25_vl) |
301 | 232 | class Qwen2_5_VLForConditionalGeneration(Generator, SupportsMultiModal): |
302 | 233 | def __init__(self, *args, **kwargs): |
303 | 234 | super().__init__(*args, **kwargs) |
@@ -336,100 +267,38 @@ def cache_path(self): |
336 | 267 | def max_cross_attn_tokens(self): |
337 | 268 | return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok) |
338 | 269 |
|
339 | | - def encode_input(self, token, image, processor): |
340 | | - print(image) |
341 | | - if image: |
342 | | - print |
343 | | - hf_messages = [ |
344 | | - { |
345 | | - "role": "user", |
346 | | - "content": [ |
347 | | - { |
348 | | - "type": "image", |
349 | | - "image": image, |
350 | | - }, |
351 | | - {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, |
352 | | - ], |
353 | | - } |
354 | | - ] |
355 | | - else: |
356 | | - hf_messages = [ |
357 | | - { |
358 | | - "role": "user", |
359 | | - "content": [ |
360 | | - {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, |
361 | | - ], |
362 | | - } |
363 | | - ] |
364 | | - |
365 | | - encoded = processor.apply_chat_template( |
366 | | - hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" |
367 | | - ).to("cpu", dtype=torch.bfloat16) |
368 | | - |
369 | | - return encoded |
370 | | - |
371 | | - def prefill_forward( |
372 | | - self, |
373 | | - tokens: torch.Tensor, |
374 | | - images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]], |
375 | | - page_table: torch.Tensor, |
376 | | - kv_cache, |
377 | | - prompt_lens, |
378 | | - cross_page_table=None, |
379 | | - ): |
380 | | - """ |
381 | | - Replaces prefill_forward from Generator with a version that supports mask creation. |
382 | | - """ |
383 | | - batch = tokens.shape[0] |
384 | | - |
385 | | - vision_images = [] |
386 | | - tokens_list = [] |
387 | | - image_grid_thw = [] |
388 | | - |
389 | | - processor = AutoProcessor.from_pretrained(self.model_args[0].CKPT_DIR) |
390 | | - |
391 | | - for user_id in range(batch): |
392 | | - image = images[user_id] |
393 | | - if isinstance(image, list): |
394 | | - assert len(image) == 1, "Only one image is supported for each user in the batch" |
395 | | - image = image[0] |
396 | | - |
397 | | - prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])] |
398 | | - encoded_input = self.encode_input(prompt_tokens, image, processor) |
399 | | - vision_images.append(encoded_input["pixel_values"] if image else None) |
400 | | - tokens_list.append(encoded_input["input_ids"].squeeze(0)) |
401 | | - image_grid_thw.append(encoded_input["image_grid_thw"] if image else None) |
402 | | - |
403 | | - prefill_lens = torch.tensor([len(token) for token in tokens_list], dtype=torch.long) |
404 | | - total_lens = prefill_lens + self.max_gen_len |
405 | | - |
406 | | - pad_id = processor.tokenizer.pad_token_id |
407 | | - tokens = torch.full((batch, max(total_lens)), pad_id, dtype=torch.long) |
408 | | - |
409 | | - for i, seq in enumerate(tokens_list): |
410 | | - tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) |
411 | | - |
412 | | - self.prefill_lens = prefill_lens |
413 | | - |
414 | | - return super().prefill_forward( |
415 | | - vision_images, |
416 | | - None, |
417 | | - tokens, |
418 | | - None, |
419 | | - total_lens=total_lens, |
420 | | - prompt_lens=prefill_lens, |
| 270 | + def prefill_forward(self, *args, **kwargs): |
| 271 | + self.tokenizer = self.model_args[0].tokenizer |
| 272 | + pad_token_id = self.tokenizer.pad_token_id |
| 273 | + |
| 274 | + tokens = kwargs["tokens"] |
| 275 | + prompt_lens = kwargs["prompt_lens"] |
| 276 | + inputs = CustomNamespace() |
| 277 | + inputs.input_ids = tokens |
| 278 | + data = kwargs.get("images", None) # full list of processed multimodal inputs (one entry per user), not just pixel values |
| 279 | + for i in range(tokens.shape[0]): # for each user, fix their padding |
| 280 | + tokens[i][prompt_lens[i] :] = pad_token_id |
| 281 | + pixel_values, image_grid_thw = None, None |
| 282 | + |
| 283 | + if data is not None and len(data) > 0 and hasattr(data[0], "pixel_values"): |
| 284 | + # data is a list of per-user processor outputs; gather their pixel_values and image_grid_thw |
| 285 | + pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")] |
| 286 | + image_grid_thw = [im.image_grid_thw for im in data if hasattr(im, "image_grid_thw")] |
| 287 | + |
| 288 | + page_table = kwargs.get("page_table", None) |
| 289 | + kv_cache = kwargs.get("kv_cache", None) |
| 290 | + |
| 291 | + return super().prefill_forward_text( |
| 292 | + tokens=inputs.input_ids, |
421 | 293 | page_table=page_table, |
422 | 294 | kv_cache=kv_cache, |
423 | | - cross_page_table=cross_page_table, |
424 | | - image_grid_thw=image_grid_thw, |
425 | | - )[0] |
| 295 | + prompt_lens=prompt_lens, |
| 296 | + pixel_values=pixel_values if pixel_values else None, |
| 297 | + image_grid_thw=image_grid_thw if image_grid_thw else None, |
| 298 | + ) |
426 | 299 |
|
427 | 300 | def decode_forward(self, *args, **kwargs): |
428 | | - if kwargs.get("start_pos") is not None: |
429 | | - kwargs["start_pos"][: len(self.prefill_lens)] = self.prefill_lens |
430 | | - logits = super().decode_forward_text(*args, **kwargs) |
431 | | - self.prefill_lens += 1 |
432 | | - return logits |
| 301 | + return super().decode_forward_text(*args, **kwargs) |
433 | 302 |
|
434 | 303 | def allocate_kv_cache(self, *args, **kwargs): |
435 | 304 | return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) |
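
For reference, the new `input_processor_for_qwen25_vl` above simply runs the Hugging Face Qwen2.5-VL processor on the prompt (and image, if present) and repackages the result into vLLM-style token inputs. Below is a minimal standalone sketch of that behaviour outside vLLM; the checkpoint name, placeholder image, and prompt string are assumptions for illustration only, and it requires a transformers release that ships the Qwen2.5-VL processor.

```python
# Standalone sketch (not part of the diff): approximate what
# input_processor_for_qwen25_vl produces, using the Hugging Face processor
# directly instead of ctx.get_hf_processor().
from PIL import Image
from transformers import AutoProcessor

# Assumed checkpoint for illustration; the real code loads whatever the vLLM model config points at.
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")

prompt = "<|vision_start|><|image_pad|><|vision_end|>Describe this image."
image = Image.new("RGB", (448, 448))  # placeholder image

# Same call shape as in the diff: text + optional images, no videos.
processed = processor(text=prompt, images=[image], videos=None, return_tensors="pt")

# Mirrors the dict returned by the input processor in the diff above.
vllm_style_inputs = {
    "type": "token",
    "prompt_token_ids": processed.input_ids[0].tolist(),
    "prompt": prompt,
    "multi_modal_data": {"image": processed},
}
print(len(vllm_style_inputs["prompt_token_ids"]), processed.pixel_values.shape)
```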
|
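Similarly, a toy sketch of the per-user padding fix and the `CustomNamespace` helper used by the new `prefill_forward`; the token values and `pad_token_id` here are made up for illustration.

```python
# Toy sketch (not part of the diff): pad each user's tokens past their prompt
# length, and show why CustomNamespace defines __contains__.
import torch
from types import SimpleNamespace


class CustomNamespace(SimpleNamespace):
    def __contains__(self, key):
        return key in self.__dict__


pad_token_id = 0  # assumed pad id for illustration
tokens = torch.tensor([[5, 6, 7, 9, 9],
                       [5, 6, 9, 9, 9]])  # stale values past each prompt
prompt_lens = [3, 2]

# Overwrite everything past each user's prompt with the pad token,
# mirroring the loop in the new prefill_forward.
for i in range(tokens.shape[0]):
    tokens[i][prompt_lens[i]:] = pad_token_id

inputs = CustomNamespace()
inputs.input_ids = tokens

print("input_ids" in inputs)  # True: __contains__ makes `in` behave like a dict lookup
print(tokens)  # tensor([[5, 6, 7, 0, 0],
               #         [5, 6, 0, 0, 0]])
```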