From e6a125b61f11ae34027f275838d37bd81fc70355 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Fri, 8 Aug 2025 17:57:17 +0000 Subject: [PATCH 1/2] WIP vLLM support for Qwen VL 7B --- models/tt_transformers/tt/generator_vllm.py | 244 ++++++++++++++++++ .../tt/multimodal/qwen_vl/qwen_e2e_model.py | 27 +- .../qwen_vl/qwen_image_attention.py | 5 +- .../qwen_vl/qwen_image_patch_embed.py | 4 +- 4 files changed, 264 insertions(+), 16 deletions(-) diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 5125f551053d..51b3d528b7ab 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -9,6 +9,7 @@ import torch from llama_models.llama3.api.chat_format import create_vision_mask from tqdm import tqdm +from transformers import AutoProcessor from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs from vllm.model_executor.models.interfaces import SupportsMultiModal @@ -123,6 +124,8 @@ def input_processor_for_mllama( # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt # must contain the full text prompt. + # print() + dec_inputs = TokenInputs(**inputs["encoder"]) if os.environ.get("MESH_DEVICE") == "N300": @@ -191,6 +194,247 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI return inputs +# TODO: Update input processor to inherit from EncDecMultiModalProcessor as is done in vllm.model_executor.models.mllama.py +def input_processor_for_qwen2_5_vl( + ctx: InputContext, + inputs: EncoderDecoderInputs, +) -> EncoderDecoderInputs: + """ + This was based on a previous version of vllm.model_executor.models.mllama.py::input_processor_for_mllama() + without the additional processing for computing num_tiles (here it is fixed). + """ + # Example input to processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000], + # }, + # } + + # Move encoder_prompt to prompt. If the user does not explicitly provide separate + # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. + # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt + # must contain the full text prompt. + dec_inputs = TokenInputs(**inputs) + + if os.environ.get("MESH_DEVICE") == "N300": + prompt_len = len(dec_inputs.get("prompt_token_ids")) + MAX_PROMPT_LEN = 8192 + if prompt_len > MAX_PROMPT_LEN: + raise ValueError( + f"TT-LLama11B-Vision does not support prompts longer than {MAX_PROMPT_LEN} tokens on N300 (received prompt with {prompt_len} tokens)" + ) + + multi_modal_data = dec_inputs.get("multi_modal_data") + if multi_modal_data is None or "image" not in multi_modal_data: + # text-only + return EncoderDecoderInputs( + encoder=token_inputs([]), + decoder=dec_inputs, + ) + + # Set encoder prompt length based on the number of vision tokens so block manager allocates enough blocks (cross block tables). 
+ # hf_config = ctx.model_config.hf_config + # vision_config = hf_config.vision_config + # assert vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" + # token_per_chunk = nearest_32( + # (vision_config.image_size // 14) ** 2 + 1 + # ) # Note: we use nearest 32 while vLLM does not by default + # num_vision_tokens = ( + # vision_config.max_num_tiles * token_per_chunk + # ) # Note: we use max_num_tiles while vLLM uses num_tiles by default + + hf_config = ctx.model_config.hf_config + vision_config = hf_config.vision_config + + # Infer image size from window_size and spatial_patch_size + # Qwen uses windowed attention, and window_size = image_size // patch_size + # So image_size = window_size * patch_size + image_size = vision_config.window_size * vision_config.spatial_patch_size # e.g., 112 * 14 = 1568 + + # Optional: verify it's divisible by 14 if needed + assert image_size % vision_config.spatial_patch_size == 0, "chunk size should be multiple of patch size" + + token_per_chunk = nearest_32((image_size // vision_config.spatial_patch_size) ** 2 + 1) + + # Qwen2.5-VL does not use max_num_tiles, but you can set it manually or derive it from your image splitting strategy + # Example: treat whole image as 1 tile unless your pipeline splits into tiles + num_tiles = getattr(vision_config, "max_num_tiles", 1) # fallback to 1 if not defined + + num_vision_tokens = num_tiles * token_per_chunk + + # Example output from processor: + # { + # 'encoder': { + # 'type': 'token', + # 'prompt_token_ids': [128256, 128256, ..., 128256], + # 'prompt': '<|image|><|image|>...<|image|>', + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # 'decoder': { + # 'type': 'token', + # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 + # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 + # 'multi_modal_data': {'image': }, # noqa: E501 + # }, + # } + MLLAMA_IMAGE_TOKEN_ID = hf_config.image_token_id + MLLAMA_IMAGE_TOKEN = "<|image_pad|>" + + return EncoderDecoderInputs( + encoder=token_inputs( + prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_vision_tokens, + prompt=MLLAMA_IMAGE_TOKEN * num_vision_tokens, + multi_modal_data=multi_modal_data, + ), + decoder=dec_inputs, + ) + + +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_5_vl) +class Qwen2_5_VLForConditionalGeneration(Generator, SupportsMultiModal): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + self.QWEN_IMAGE_TOKEN_ID = 151655 + self.max_gen_len = self.model_args[0].max_seq_len - 1 # TODO: double check what this should be + + @classmethod + def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, tt_data_parallel=1): + max_seq_len = 1024 * 128 + + submesh_devices = create_submeshes(mesh_device, tt_data_parallel) + + model_args = [] + model = [] + state_dict = None + + for submesh in submesh_devices: + model_args_i, model_i, state_dict = create_multimodal_model( + mesh_device=submesh, + max_batch_size=max_batch_size // tt_data_parallel, + max_seq_len=max_seq_len, + use_paged_kv_cache=True, + checkpoint=state_dict, + ) + model_args.append(model_args_i) + model.append(model_i) + + return cls(model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args[0].model_cache_path + + @property + def max_cross_attn_tokens(self): + return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok) + + def encode_input(self, token, image, 
processor): + print(image) + if image: + print + hf_messages = [ + { + "role": "user", + "content": [ + { + "type": "image", + "image": image, + }, + {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, + ], + } + ] + else: + hf_messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, + ], + } + ] + + encoded = processor.apply_chat_template( + hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" + ).to("cpu", dtype=torch.bfloat16) + + return encoded + + def prefill_forward( + self, + tokens: torch.Tensor, + images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]], + page_table: torch.Tensor, + kv_cache, + prompt_lens, + cross_page_table=None, + ): + """ + Replaces prefill_forward from Generator with a version that supports mask creation. + """ + batch = tokens.shape[0] + + vision_images = [] + tokens_list = [] + image_grid_thw = [] + + processor = AutoProcessor.from_pretrained(self.model_args[0].CKPT_DIR) + + for user_id in range(batch): + image = images[user_id] + if isinstance(image, list): + assert len(image) == 1, "Only one image is supported for each user in the batch" + image = image[0] + + prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])] + encoded_input = self.encode_input(prompt_tokens, image, processor) + vision_images.append(encoded_input["pixel_values"] if image else None) + tokens_list.append(encoded_input["input_ids"].squeeze(0)) + image_grid_thw.append(encoded_input["image_grid_thw"] if image else None) + + prefill_lens = torch.tensor([len(token) for token in tokens_list], dtype=torch.long) + total_lens = prefill_lens + self.max_gen_len + + pad_id = processor.tokenizer.pad_token_id + tokens = torch.full((batch, max(total_lens)), pad_id, dtype=torch.long) + + for i, seq in enumerate(tokens_list): + tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) + + self.prefill_lens = prefill_lens + + return super().prefill_forward( + vision_images, + None, + tokens, + None, + total_lens=total_lens, + prompt_lens=prefill_lens, + page_table=page_table, + kv_cache=kv_cache, + cross_page_table=cross_page_table, + image_grid_thw=image_grid_thw, + )[0] + + def decode_forward(self, *args, **kwargs): + if kwargs.get("start_pos") is not None: + kwargs["start_pos"][: len(self.prefill_lens)] = self.prefill_lens + logits = super().decode_forward_text(*args, **kwargs) + self.prefill_lens += 1 + return logits + + def allocate_kv_cache(self, *args, **kwargs): + return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) + + # @MULTIMODAL_REGISTRY.register_image_input_mapper() # TODO: Add once model can accept inputs from multi_modal_input_mapper (raw pixel values) @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama) class MllamaForConditionalGeneration(Generator, SupportsMultiModal): diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py index e8cc53f7c3a6..f5b4998ed2d6 100644 --- a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py @@ -76,15 +76,17 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_ vision_output = self.compute_vision_token(**kwargs) tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) - comp_vision_output = 
ttnn.to_torch( - vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) - )[: vision_output.shape[0], :] - image_features = comp_vision_output.squeeze(0) - special_image_mask = (pt_tokens == 151655).unsqueeze(-1) - special_image_mask = special_image_mask.expand_as(tokens_embd) - image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) - tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) + if vision_output is not None: + comp_vision_output = ttnn.to_torch( + vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0) + )[: vision_output.shape[0], :] + + image_features = comp_vision_output.squeeze(0) + special_image_mask = (pt_tokens == 151655).unsqueeze(-1) + special_image_mask = special_image_mask.expand_as(tokens_embd) + image_features = image_features.to(tokens_embd.device, tokens_embd.dtype) + tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features) tokens_embd = self.args.prepare_residual_tensor_prefill( tokens_embd, @@ -126,7 +128,8 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_ return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table def compute_vision_token(self, pixel_values, image_grid_thw): - pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) - - vision_output = self.vision_model(pixel_values, image_grid_thw) - return vision_output + if pixel_values is not None: + pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True) + vision_output = self.vision_model(pixel_values, image_grid_thw) + return vision_output + return None diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py index bd8dabf68676..90124e3ae635 100644 --- a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_attention.py @@ -109,7 +109,7 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings): ) # shape [batch, seq_len, hidden_size*3] if self.configuration.num_devices > 1: - qkv = ttnn.all_gather(qkv, dim=-1, num_links=1) + qkv = ttnn.all_gather(qkv, dim=3, num_links=1) (q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 2, 3]) ttnn.deallocate(qkv) @@ -155,6 +155,7 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings): ttnn.deallocate(attn_output) if self.configuration.num_devices > 1: - output = ttnn.all_gather(output, dim=1, num_links=1) + output = ttnn.all_gather(ttnn.reshape(output, (1, 1, output.shape[0], -1)), dim=3, num_links=1) + output = ttnn.reshape(output, (output.shape[2], -1)) return output diff --git a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py index 2cdc3e679103..9f0b29c781d0 100644 --- a/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py +++ b/models/tt_transformers/tt/multimodal/qwen_vl/qwen_image_patch_embed.py @@ -63,6 +63,6 @@ def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor: output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config) if self.args.num_devices > 1: - output = ttnn.all_gather(output, dim=1, num_links=1) - + output = ttnn.all_gather(ttnn.reshape(output, (1, 1, output.shape[0], -1)), dim=3, num_links=1) + output = 
ttnn.reshape(output, (output.shape[2], -1)) return output From 6a3034f6323c42f26c47fbbb3b91403f3493d379 Mon Sep 17 00:00:00 2001 From: mcw Date: Tue, 12 Aug 2025 13:59:19 +0530 Subject: [PATCH 2/2] Qwen VL 7B vLLM support --- models/tt_transformers/tt/generator_vllm.py | 253 +++++--------------- 1 file changed, 60 insertions(+), 193 deletions(-) diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 51b3d528b7ab..a90612851fb1 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -3,13 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 import os +from types import SimpleNamespace from typing import List, Union import PIL import torch from llama_models.llama3.api.chat_format import create_vision_mask from tqdm import tqdm -from transformers import AutoProcessor from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, EncoderDecoderInputs, InputContext, TokenInputs, token_inputs from vllm.model_executor.models.interfaces import SupportsMultiModal @@ -124,8 +124,6 @@ def input_processor_for_mllama( # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt # must contain the full text prompt. - # print() - dec_inputs = TokenInputs(**inputs["encoder"]) if os.environ.get("MESH_DEVICE") == "N300": @@ -194,116 +192,47 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI return inputs -# TODO: Update input processor to inherit from EncDecMultiModalProcessor as is done in vllm.model_executor.models.mllama.py -def input_processor_for_qwen2_5_vl( - ctx: InputContext, - inputs: EncoderDecoderInputs, -) -> EncoderDecoderInputs: - """ - This was based on a previous version of vllm.model_executor.models.mllama.py::input_processor_for_mllama() - without the additional processing for computing num_tiles (here it is fixed). - """ - # Example input to processor: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000], - # }, - # } - - # Move encoder_prompt to prompt. If the user does not explicitly provide separate - # encoder and decoder prompts, vLLM by default will treat the prompt as the encoder prompt. - # For the block manager to allocate enough blocks and add them to the block table, the decoder prompt - # must contain the full text prompt. - dec_inputs = TokenInputs(**inputs) - - if os.environ.get("MESH_DEVICE") == "N300": - prompt_len = len(dec_inputs.get("prompt_token_ids")) - MAX_PROMPT_LEN = 8192 - if prompt_len > MAX_PROMPT_LEN: - raise ValueError( - f"TT-LLama11B-Vision does not support prompts longer than {MAX_PROMPT_LEN} tokens on N300 (received prompt with {prompt_len} tokens)" - ) - - multi_modal_data = dec_inputs.get("multi_modal_data") - if multi_modal_data is None or "image" not in multi_modal_data: - # text-only - return EncoderDecoderInputs( - encoder=token_inputs([]), - decoder=dec_inputs, - ) - - # Set encoder prompt length based on the number of vision tokens so block manager allocates enough blocks (cross block tables). 
- # hf_config = ctx.model_config.hf_config - # vision_config = hf_config.vision_config - # assert vision_config.image_size % 14 == 0, "chunk size should be multiple of 14" - # token_per_chunk = nearest_32( - # (vision_config.image_size // 14) ** 2 + 1 - # ) # Note: we use nearest 32 while vLLM does not by default - # num_vision_tokens = ( - # vision_config.max_num_tiles * token_per_chunk - # ) # Note: we use max_num_tiles while vLLM uses num_tiles by default - - hf_config = ctx.model_config.hf_config - vision_config = hf_config.vision_config - - # Infer image size from window_size and spatial_patch_size - # Qwen uses windowed attention, and window_size = image_size // patch_size - # So image_size = window_size * patch_size - image_size = vision_config.window_size * vision_config.spatial_patch_size # e.g., 112 * 14 = 1568 +def input_processor_for_qwen25_vl(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]): + input_processor = ctx.get_hf_processor() + if "prompt" in inputs: + prompt_text = inputs["prompt"] + else: + # [INFO] with current version of vLLM, in server mode, inputs["prompt"] gives KeyError; only inputs['prompt_token_ids'] is available + assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode" + prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False) + if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]: + images = inputs["multi_modal_data"]["image"] + else: + images = None + + processed_inputs = input_processor( + text=prompt_text, # [INFO] Qwen2VLProcessor handles the case where text is a string or a list of strings + images=images, + videos=None, # [INFO] videos are not supported yet + return_tensors="pt", + ) - # Optional: verify it's divisible by 14 if needed - assert image_size % vision_config.spatial_patch_size == 0, "chunk size should be multiple of patch size" + assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM" + return { + "type": inputs["type"], + "prompt_token_ids": processed_inputs.input_ids[0].tolist(), + "prompt": prompt_text, + "multi_modal_data": {"image": processed_inputs}, # [INFO] add processed_inputs + } - token_per_chunk = nearest_32((image_size // vision_config.spatial_patch_size) ** 2 + 1) - # Qwen2.5-VL does not use max_num_tiles, but you can set it manually or derive it from your image splitting strategy - # Example: treat whole image as 1 tile unless your pipeline splits into tiles - num_tiles = getattr(vision_config, "max_num_tiles", 1) # fallback to 1 if not defined +class CustomNamespace(SimpleNamespace): + def __contains__(self, key): + return key in self.__dict__ - num_vision_tokens = num_tiles * token_per_chunk - # Example output from processor: - # { - # 'encoder': { - # 'type': 'token', - # 'prompt_token_ids': [128256, 128256, ..., 128256], - # 'prompt': '<|image|><|image|>...<|image|>', - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # 'decoder': { - # 'type': 'token', - # 'prompt_token_ids': [128000, 128256, 128000, 3923, 374, 279, 2262, 315, 420, 2217, 30], # noqa: E501 - # 'prompt': '<|image|><|begin_of_text|>What is the content of this image?', # noqa: E501 - # 'multi_modal_data': {'image': }, # noqa: E501 - # }, - # } - MLLAMA_IMAGE_TOKEN_ID = hf_config.image_token_id - MLLAMA_IMAGE_TOKEN = "<|image_pad|>" - - return EncoderDecoderInputs( - encoder=token_inputs( - prompt_token_ids=[MLLAMA_IMAGE_TOKEN_ID] * num_vision_tokens, - prompt=MLLAMA_IMAGE_TOKEN * 
num_vision_tokens, - multi_modal_data=multi_modal_data, - ), - decoder=dec_inputs, - ) - - -@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_5_vl) +@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen25_vl) class Qwen2_5_VLForConditionalGeneration(Generator, SupportsMultiModal): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.QWEN_IMAGE_TOKEN_ID = 151655 - self.max_gen_len = self.model_args[0].max_seq_len - 1 # TODO: double check what this should be + self.max_gen_len = self.model_args[0].max_seq_len - 1 @classmethod def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, tt_data_parallel=1): @@ -336,100 +265,38 @@ def cache_path(self): def max_cross_attn_tokens(self): return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok) - def encode_input(self, token, image, processor): - print(image) - if image: - print - hf_messages = [ - { - "role": "user", - "content": [ - { - "type": "image", - "image": image, - }, - {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, - ], - } - ] - else: - hf_messages = [ - { - "role": "user", - "content": [ - {"type": "text", "text": self.model_args[0].tokenizer.decode(token)}, - ], - } - ] - - encoded = processor.apply_chat_template( - hf_messages, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt" - ).to("cpu", dtype=torch.bfloat16) - - return encoded - - def prefill_forward( - self, - tokens: torch.Tensor, - images: Union[List[PIL.Image.Image], List[List[PIL.Image.Image]]], - page_table: torch.Tensor, - kv_cache, - prompt_lens, - cross_page_table=None, - ): - """ - Replaces prefill_forward from Generator with a version that supports mask creation. - """ - batch = tokens.shape[0] - - vision_images = [] - tokens_list = [] - image_grid_thw = [] - - processor = AutoProcessor.from_pretrained(self.model_args[0].CKPT_DIR) - - for user_id in range(batch): - image = images[user_id] - if isinstance(image, list): - assert len(image) == 1, "Only one image is supported for each user in the batch" - image = image[0] - - prompt_tokens = [int(tokens[user_id, i]) for i in range(prompt_lens[user_id])] - encoded_input = self.encode_input(prompt_tokens, image, processor) - vision_images.append(encoded_input["pixel_values"] if image else None) - tokens_list.append(encoded_input["input_ids"].squeeze(0)) - image_grid_thw.append(encoded_input["image_grid_thw"] if image else None) - - prefill_lens = torch.tensor([len(token) for token in tokens_list], dtype=torch.long) - total_lens = prefill_lens + self.max_gen_len - - pad_id = processor.tokenizer.pad_token_id - tokens = torch.full((batch, max(total_lens)), pad_id, dtype=torch.long) - - for i, seq in enumerate(tokens_list): - tokens[i, : len(seq)] = torch.tensor(seq, dtype=torch.long) - - self.prefill_lens = prefill_lens - - return super().prefill_forward( - vision_images, - None, - tokens, - None, - total_lens=total_lens, - prompt_lens=prefill_lens, + def prefill_forward(self, *args, **kwargs): + self.tokenizer = self.model_args[0].tokenizer + pad_token_id = self.tokenizer.pad_token_id + + tokens = kwargs["tokens"] + prompt_lens = kwargs["prompt_lens"] + inputs = CustomNamespace() + inputs.input_ids = tokens + data = kwargs.get("images", None) # This contains the entire Data list, not just the pixel values + for i in range(tokens.shape[0]): # for each user, fix their padding + tokens[i][prompt_lens[i] :] = pad_token_id + pixel_values, image_grid_thw = None, None + + if 
data is not None and any(hasattr(im, "pixel_values") for im in data):
+            # Collect per-user pixel_values and image_grid_thw from the processed inputs
+            # (guarded so text-only users in the batch do not drop a later user's image data)
+            pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
+            image_grid_thw = [im.image_grid_thw for im in data if hasattr(im, "image_grid_thw")]
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
             page_table=page_table,
             kv_cache=kv_cache,
-            cross_page_table=cross_page_table,
-            image_grid_thw=image_grid_thw,
-        )[0]
+            prompt_lens=prompt_lens,
+            pixel_values=pixel_values if pixel_values else None,
+            image_grid_thw=image_grid_thw if image_grid_thw else None,
+        )
 
     def decode_forward(self, *args, **kwargs):
-        if kwargs.get("start_pos") is not None:
-            kwargs["start_pos"][: len(self.prefill_lens)] = self.prefill_lens
-        logits = super().decode_forward_text(*args, **kwargs)
-        self.prefill_lens += 1
-        return logits
+        return super().decode_forward_text(*args, **kwargs)
 
     def allocate_kv_cache(self, *args, **kwargs):
         return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
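
Usage sketch (illustrative, not part of the diff above): a minimal offline-inference example of how the Qwen2.5-VL path registered by input_processor_for_qwen25_vl might be driven through vLLM's standard LLM API. The model path, chat-template string, image filename, and MESH_DEVICE value are assumptions; the tt-metal vLLM fork may require additional environment setup.

    import os

    from PIL import Image
    from vllm import LLM, SamplingParams

    # Assumption: target a single N300 mesh device; adjust for your hardware.
    os.environ.setdefault("MESH_DEVICE", "N300")

    # Qwen2.5-VL chat-style prompt; the registered input processor expands
    # <|image_pad|> into the correct number of vision tokens via the HF processor.
    prompt = (
        "<|im_start|>user\n"
        "<|vision_start|><|image_pad|><|vision_end|>What is shown in this image?<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    # max_model_len mirrors max_seq_len = 1024 * 128 used in initialize_vllm_model.
    llm = LLM(model="Qwen/Qwen2.5-VL-7B-Instruct", max_model_len=131072)

    image = Image.open("example.jpg")  # hypothetical local image

    outputs = llm.generate(
        {"prompt": prompt, "multi_modal_data": {"image": image}},
        SamplingParams(temperature=0.0, max_tokens=64),
    )
    print(outputs[0].outputs[0].text)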