111 changes: 111 additions & 0 deletions models/tt_transformers/tt/generator_vllm.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from types import SimpleNamespace
from typing import List, Union

import PIL
@@ -191,6 +192,116 @@ def input_processor_for_llama_text(ctx: InputContext, inputs: Union[DecoderOnlyI
return inputs


def input_processor_for_qwen25_vl(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
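    """Run the HF Qwen2.5-VL processor on the prompt (raw text or decoded token ids) and attach the processed image tensors to multi_modal_data."""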
input_processor = ctx.get_hf_processor()
if "prompt" in inputs:
prompt_text = inputs["prompt"]
else:
        # [INFO] with the current version of vLLM in server mode, inputs["prompt"] raises a KeyError; only inputs["prompt_token_ids"] is available
assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
images = inputs["multi_modal_data"]["image"]
else:
images = None

processed_inputs = input_processor(
text=prompt_text, # [INFO] Qwen2VLProcessor handles the case where text is a string or a list of strings
images=images,
videos=None, # [INFO] videos are not supported yet
return_tensors="pt",
)

    assert processed_inputs.input_ids.shape[0] == 1, "Only one prompt is processed at a time by vLLM"
return {
"type": inputs["type"],
"prompt_token_ids": processed_inputs.input_ids[0].tolist(),
"prompt": prompt_text,
"multi_modal_data": {"image": processed_inputs}, # [INFO] add processed_inputs
}


class CustomNamespace(SimpleNamespace):
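    """SimpleNamespace that also supports membership tests ("key" in ns), matching the dict-style access vLLM uses for model inputs."""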
def __contains__(self, key):
return key in self.__dict__


@INPUT_REGISTRY.register_input_processor(input_processor_for_qwen25_vl)
class Qwen2_5_VLForConditionalGeneration(Generator, SupportsMultiModal):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

        self.QWEN_IMAGE_TOKEN_ID = 151655  # <|image_pad|> placeholder token id in the Qwen2.5-VL tokenizer
self.max_gen_len = self.model_args[0].max_seq_len - 1

@classmethod
def initialize_vllm_model(cls, hf_config, mesh_device, max_batch_size, tt_data_parallel=1):
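        # Build one model replica per data-parallel submesh; the checkpoint is loaded once and its state_dict is reused across submeshes.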
max_seq_len = 1024 * 128

submesh_devices = create_submeshes(mesh_device, tt_data_parallel)

model_args = []
model = []
state_dict = None

for submesh in submesh_devices:
model_args_i, model_i, state_dict = create_multimodal_model(
mesh_device=submesh,
max_batch_size=max_batch_size // tt_data_parallel,
max_seq_len=max_seq_len,
use_paged_kv_cache=True,
checkpoint=state_dict,
)
model_args.append(model_args_i)
model.append(model_i)

return cls(model, model_args, mesh_device)

@property
def cache_path(self):
return self.model_args[0].model_cache_path

@property
def max_cross_attn_tokens(self):
return self.model_args[0].vision_max_num_chunks * nearest_32(self.model_args[0].vision_chunk_ntok)

def prefill_forward(self, *args, **kwargs):
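        # Pad each user's tokens beyond its prompt length, collect any processed image tensors, then run text prefill.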
self.tokenizer = self.model_args[0].tokenizer
pad_token_id = self.tokenizer.pad_token_id

tokens = kwargs["tokens"]
prompt_lens = kwargs["prompt_lens"]
inputs = CustomNamespace()
inputs.input_ids = tokens
        data = kwargs.get("images", None)  # this is the entire list of per-user image data, not just the pixel values
for i in range(tokens.shape[0]): # for each user, fix their padding
tokens[i][prompt_lens[i] :] = pad_token_id
pixel_values, image_grid_thw = None, None

        if data is not None and hasattr(data[0], "pixel_values"):
            # If data is a list of processed inputs with .pixel_values, collect them per user
pixel_values = [im.pixel_values for im in data if hasattr(im, "pixel_values")]
image_grid_thw = [im.image_grid_thw for im in data if hasattr(im, "image_grid_thw")]

page_table = kwargs.get("page_table", None)
kv_cache = kwargs.get("kv_cache", None)

return super().prefill_forward_text(
tokens=inputs.input_ids,
page_table=page_table,
kv_cache=kv_cache,
prompt_lens=prompt_lens,
pixel_values=pixel_values if pixel_values else None,
image_grid_thw=image_grid_thw if image_grid_thw else None,
)

def decode_forward(self, *args, **kwargs):
return super().decode_forward_text(*args, **kwargs)

def allocate_kv_cache(self, *args, **kwargs):
return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)


# @MULTIMODAL_REGISTRY.register_image_input_mapper() # TODO: Add once model can accept inputs from multi_modal_input_mapper (raw pixel values)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(Generator, SupportsMultiModal):
27 changes: 15 additions & 12 deletions models/tt_transformers/tt/multimodal/qwen_vl/qwen_e2e_model.py
@@ -76,15 +76,17 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_
vision_output = self.compute_vision_token(**kwargs)

tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1))
-        comp_vision_output = ttnn.to_torch(
-            vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)
-        )[: vision_output.shape[0], :]
-
-        image_features = comp_vision_output.squeeze(0)
-        special_image_mask = (pt_tokens == 151655).unsqueeze(-1)
-        special_image_mask = special_image_mask.expand_as(tokens_embd)
-        image_features = image_features.to(tokens_embd.device, tokens_embd.dtype)
-        tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features)
+        if vision_output is not None:
+            comp_vision_output = ttnn.to_torch(
+                vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)
+            )[: vision_output.shape[0], :]
+
+            image_features = comp_vision_output.squeeze(0)
+            special_image_mask = (pt_tokens == 151655).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(tokens_embd)
+            image_features = image_features.to(tokens_embd.device, tokens_embd.dtype)
+            tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features)

tokens_embd = self.args.prepare_residual_tensor_prefill(
tokens_embd,
@@ -126,7 +128,8 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_
return tokens_embd, tt_rot_mats_prefill_global, tt_page_table, tt_chunk_page_table

def compute_vision_token(self, pixel_values, image_grid_thw):
-        pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True)
-
-        vision_output = self.vision_model(pixel_values, image_grid_thw)
-        return vision_output
+        if pixel_values is not None:
+            pixel_values = self.args.prepare_residual_tensor_prefill(pixel_values.unsqueeze(0), force_replicated=True)
+            vision_output = self.vision_model(pixel_values, image_grid_thw)
+            return vision_output
+        return None
@@ -109,7 +109,7 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings):
) # shape [batch, seq_len, hidden_size*3]

if self.configuration.num_devices > 1:
-            qkv = ttnn.all_gather(qkv, dim=-1, num_links=1)
+            qkv = ttnn.all_gather(qkv, dim=3, num_links=1)

(q, k, v) = ttnn.permute(ttnn.reshape(qkv, [seq_len, 3, self.num_heads, -1]), [1, 0, 2, 3])
ttnn.deallocate(qkv)
@@ -155,6 +155,7 @@ def forward(self, hidden_states, cu_seqlens, position_embeddings):
ttnn.deallocate(attn_output)

if self.configuration.num_devices > 1:
-            output = ttnn.all_gather(output, dim=1, num_links=1)
+            output = ttnn.all_gather(ttnn.reshape(output, (1, 1, output.shape[0], -1)), dim=3, num_links=1)
+            output = ttnn.reshape(output, (output.shape[2], -1))

return output
@@ -63,6 +63,6 @@ def __call__(self, x: ttnn.Tensor) -> ttnn.Tensor:
output = ttnn.matmul(x_flattened, self.weight, compute_kernel_config=self.compute_kernel_config)

if self.args.num_devices > 1:
-            output = ttnn.all_gather(output, dim=1, num_links=1)
-
+            output = ttnn.all_gather(ttnn.reshape(output, (1, 1, output.shape[0], -1)), dim=3, num_links=1)
+            output = ttnn.reshape(output, (output.shape[2], -1))
return output