Commit 1994cf9

Add vLLM support for Gemma-3-4b-it
1 parent: 4c4201b

File tree (4 files changed: +141 -10 lines):

- models/tt_transformers/tt/generator.py
- models/tt_transformers/tt/generator_vllm.py
- models/tt_transformers/tt/multimodal/gemma/gemma_e2e_model.py
- models/tt_transformers/tt/multimodal/gemma/gemma_image_mlp.py

models/tt_transformers/tt/generator.py

Lines changed: 13 additions & 1 deletion
```diff
@@ -97,7 +97,7 @@ def prefill_forward_text(
             model_kv_cache = kv_cache[model_id] if kv_cache is not None else None

             # Check if 'pixel_values' exists and index it safely
-            if "pixel_values" in local_kwargs:
+            if "pixel_values" in local_kwargs and local_kwargs["pixel_values"] is not None:
                 local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx]
             if "image_grid_thw" in local_kwargs:
                 local_kwargs["image_grid_thw"] = local_kwargs["image_grid_thw"][idx]
@@ -413,6 +413,7 @@ def _prefill_forward_single_user(
         kv_cache=None,
         cross_page_table=None,
         model_id=-1,
+        **kwargs,
     ):
         """
         Performs vision encode step then text prefill.
@@ -434,6 +435,7 @@ def _prefill_forward_single_user(
             batch_masks=[vision_mask],
             total_len=total_len,
             prefill_len=prefill_len,
+            **kwargs,
         )

         if cross_page_table is not None:
@@ -467,6 +469,8 @@ def _prefill_forward_single_user(
             page_table=page_table,
             cross_page_table=cross_page_table,
             text_only_inference=text_only_inference,
+            vision_tokens=vision_tokens,
+            **kwargs,
         )

         tt_logits = self.model[model_id].ttnn_prefill_forward(
@@ -565,6 +569,7 @@ def prefill_forward_llama_vision(
         kv_cache=None,
         cross_page_table=None,
         empty_slots=None,
+        **kwargs,
     ):
         """
         Batched version of _prefill_forward_single_user for vision model.
@@ -600,6 +605,11 @@ def prefill_forward_llama_vision(
             model_kv_cache = kv_cache[model_id] if kv_cache is not None else None
             model_xattn_cache = xattn_caches[model_id] if xattn_caches is not None else None

+            # prefill_seq_len = get_padded_prefill_len(seq_len)
+            # tokens = torch.cat(
+            #     [tokens[idx : idx + 1, :seq_len], torch.zeros(1, prefill_seq_len - seq_len).long()], dim=-1
+            # )
+
             (
                 model_xattn_cache,
                 prefill_cross_attention_masks,
@@ -619,6 +629,8 @@ def prefill_forward_llama_vision(
                 kv_cache=model_kv_cache,
                 cross_page_table=user_cross_page_table,
                 model_id=model_id,
+                image_grid_thw=kwargs["image_grid_thw"][idx] if kwargs.get("image_grid_thw") else None,
+                input_ids=kwargs["input_ids"][idx] if kwargs.get("input_ids") else None,
            )

            if xattn_caches is not None:
```
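Why the guard in the first hunk matters: a text-only request can reach `prefill_forward_text` with a `pixel_values` key that is present but set to `None`, and indexing `None[idx]` raises a `TypeError`. A minimal sketch of the per-user slicing, with `local_kwargs` and `idx` as hypothetical stand-ins rather than the repo's exact loop:

```python
# Hypothetical per-user kwargs for a text-only request: the key exists
# but holds None, which is exactly the case the new guard handles.
local_kwargs = {"pixel_values": None, "image_grid_thw": None}
idx = 0

# The old check crashed on None[idx]; the new one passes None through.
if "pixel_values" in local_kwargs and local_kwargs["pixel_values"] is not None:
    local_kwargs["pixel_values"] = local_kwargs["pixel_values"][idx]
```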

models/tt_transformers/tt/generator_vllm.py

Lines changed: 109 additions & 0 deletions
```diff
@@ -373,3 +373,112 @@ def decode_forward(self, *args, **kwargs):

     def allocate_kv_cache(self, *args, **kwargs):
         return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+
+def input_processor_for_gemma(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
+    input_processor = ctx.get_hf_processor()
+    if "prompt" in inputs:
+        prompt_text = inputs["prompt"]
+    else:
+        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
+        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)
+
+    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
+        images = inputs["multi_modal_data"]["image"]
+    else:
+        images = None
+
+    processed_inputs = input_processor(
+        text=prompt_text,
+        images=images,
+        return_tensors="pt",
+    )
+
+    assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM"
+    return {
+        "type": inputs["type"],
+        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
+        "prompt": prompt_text,
+        "multi_modal_data": {"image": processed_inputs},  # [INFO] add processed_inputs
+    }
+
+
+from types import SimpleNamespace
+
+
+class CustomNamespace(SimpleNamespace):
+    def __contains__(self, key):
+        return key in self.__dict__
+
+
+@INPUT_REGISTRY.register_input_processor(input_processor_for_gemma)
+class Gemma3ForConditionalGeneration(Generator, SupportsMultiModal):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        self.GEMMA_IMAGE_TOKEN_ID = 262144
+        self.max_gen_len = self.model_args[0].max_seq_len - 1  # TODO: double check what this should be
+
+    @classmethod
+    def initialize_vllm_model(
+        cls, hf_config, mesh_device, max_batch_size, max_seq_len=131072, n_layers=None, tt_data_parallel=1
+    ):
+        submesh_devices = create_submeshes(mesh_device, tt_data_parallel)
+
+        model_args = []
+        model = []
+        state_dict = None
+
+        for submesh in submesh_devices:
+            model_args_i, model_i, state_dict = create_multimodal_model(
+                mesh_device=submesh,
+                max_batch_size=max_batch_size // tt_data_parallel,
+                max_seq_len=max_seq_len,
+                use_paged_kv_cache=True,
+                checkpoint=state_dict,
+            )
+            model_args.append(model_args_i)
+            model.append(model_i)
+
+        return cls(model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args[0].model_cache_path
+
+    def prefill_forward(self, *args, **kwargs):
+        self.tokenizer = self.model_args[0].tokenizer
+        pad_token_id = self.tokenizer.pad_token_id
+
+        tokens = kwargs["tokens"]
+        prompt_lens = kwargs["prompt_lens"]
+        inputs = CustomNamespace()
+        inputs.input_ids = tokens
+        data = kwargs.get("images", None)  # This contains the entire data list, not just the pixel values
+        for i in range(tokens.shape[0]):  # for each user, fix their padding
+            tokens[i][prompt_lens[i] :] = pad_token_id
+        pixel_values = None
+
+        if hasattr(data[0], "pixel_values"):
+            # If inputs is a list of objects with .pixel_values, concatenate them
+            pixel_values = torch.concat([im.pixel_values for im in data if hasattr(im, "pixel_values")], dim=0)
+
+        page_table = kwargs.get("page_table", None)
+        kv_cache = kwargs.get("kv_cache", None)
+        vision_images = pixel_values
+
+        vision_images = [vision_images] if vision_images is not None else None
+
+        return super().prefill_forward_text(
+            tokens=inputs.input_ids,
+            page_table=page_table,
+            kv_cache=kv_cache,
+            prompt_lens=prompt_lens,
+            pixel_values=vision_images,
+        )
+
+    def allocate_kv_cache(self, *args, **kwargs):
+        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
```
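One detail worth illustrating: `SimpleNamespace` does not support the `in` operator, which is why `CustomNamespace` overrides `__contains__`, so the namespace can be probed the same way as the dict-style `inputs` that vLLM otherwise passes around. A standalone sketch of the behavior (plain Python, no vLLM required):

```python
from types import SimpleNamespace


class CustomNamespace(SimpleNamespace):
    def __contains__(self, key):
        return key in self.__dict__


ns = CustomNamespace()
ns.input_ids = [[1, 2, 3]]

print("input_ids" in ns)     # True: membership is checked against __dict__
print("pixel_values" in ns)  # False: attribute was never set
# By contrast, `"input_ids" in SimpleNamespace()` raises TypeError,
# because SimpleNamespace defines neither __contains__ nor __iter__.
```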

models/tt_transformers/tt/multimodal/gemma/gemma_e2e_model.py

Lines changed: 12 additions & 9 deletions
```diff
@@ -73,15 +73,16 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_

         vision_output = self.compute_vision_token(**kwargs)
         tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1))
-        comp_vision_output = ttnn.to_torch(
-            vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)
-        )[: vision_output.shape[0], :]
+        if vision_output is not None:
+            comp_vision_output = ttnn.to_torch(
+                vision_output, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=0)
+            )[: vision_output.shape[0], :]

-        image_features = comp_vision_output.squeeze(0)
-        special_image_mask = (pt_tokens == self.args.image_token_index).unsqueeze(-1)
-        special_image_mask = special_image_mask.expand_as(tokens_embd)
-        image_features = image_features.to(tokens_embd.device, tokens_embd.dtype)
-        tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features)
+            image_features = comp_vision_output.squeeze(0)
+            special_image_mask = (pt_tokens == self.args.image_token_index).unsqueeze(-1)
+            special_image_mask = special_image_mask.expand_as(tokens_embd)
+            image_features = image_features.to(tokens_embd.device, tokens_embd.dtype)
+            tokens_embd = tokens_embd.masked_scatter(special_image_mask, image_features)

         tokens_embd = self.args.prepare_residual_tensor_prefill(
             tokens_embd,
@@ -127,6 +128,8 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_

         return tokens_embd, [tt_rot_mats_prefill_global, tt_rot_mats_prefill_local], tt_page_table, tt_chunk_page_table

-    def compute_vision_token(self, pixel_values):
+    def compute_vision_token(self, pixel_values=None):
+        if pixel_values is None:
+            return None
         vision_output = self.vision_model(pixel_values)
         return vision_output
```
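For reference, the masked-scatter splice guarded by the new `if vision_output is not None:` branch follows the usual Hugging Face multimodal recipe: embed every token, then overwrite the rows at image-token positions with vision features. A self-contained PyTorch toy (shapes and the token id are made up for illustration; the real code uses `self.args.image_token_index`):

```python
import torch

IMAGE_TOKEN_ID = 7                        # hypothetical image token id
pt_tokens = torch.tensor([[1, 7, 7, 2]])  # (batch, seq)
tokens_embd = torch.zeros(1, 4, 8)        # (batch, seq, hidden) text embeddings
image_features = torch.ones(2, 8)         # one feature row per image token

# Mask the image-token positions, broadcast the mask over the hidden dim,
# and scatter the vision rows into the embedding tensor in order.
mask = (pt_tokens == IMAGE_TOKEN_ID).unsqueeze(-1).expand_as(tokens_embd)
tokens_embd = tokens_embd.masked_scatter(mask, image_features.to(tokens_embd.dtype))

assert tokens_embd[0, 1].eq(1).all() and tokens_embd[0, 3].eq(0).all()
```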

models/tt_transformers/tt/multimodal/gemma/gemma_image_mlp.py

Lines changed: 7 additions & 0 deletions
```diff
@@ -117,4 +117,11 @@ def forward(self, x: ttnn.Tensor) -> ttnn.Tensor:
         pre_bias_output = c_proj_out

         output = ttnn.add(pre_bias_output, self.c_proj_bias)
+
+        ttnn.deallocate(c_fc_out)
+        ttnn.deallocate(c_proj_out)
+        ttnn.deallocate(pre_bias_output)
+        # Deallocate input tensor to free memory
+        ttnn.deallocate(x_in)
+        # Reshape output back to original shape
         return output
```