
Commit 0eff89e

Add vLLM Support for Gemma3 Models
1 parent 359c925 commit 0eff89e

1 file changed

models/tt_transformers/tt/generator_vllm.py

Lines changed: 109 additions & 0 deletions
@@ -373,3 +373,112 @@ def decode_forward(self, *args, **kwargs):
    def allocate_kv_cache(self, *args, **kwargs):
        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)


def input_processor_for_gemma(ctx: InputContext, inputs: Union[DecoderOnlyInputs, EncoderDecoderInputs]):
    input_processor = ctx.get_hf_processor()
    if "prompt" in inputs:
        prompt_text = inputs["prompt"]
    else:
        assert "prompt_token_ids" in inputs, "prompt_token_ids must be available in server mode"
        prompt_text = input_processor.decode(inputs["prompt_token_ids"], skip_special_tokens=False)

    if "multi_modal_data" in inputs and "image" in inputs["multi_modal_data"]:
        images = inputs["multi_modal_data"]["image"]
    else:
        images = None

    processed_inputs = input_processor(
        text=prompt_text,
        images=images,
        return_tensors="pt",
    )

    assert processed_inputs.input_ids.shape[0] == 1, "Only one image is processed at a time by vLLM"
    return {
        "type": inputs["type"],
        "prompt_token_ids": processed_inputs.input_ids[0].tolist(),
        "prompt": prompt_text,
        "multi_modal_data": {"image": processed_inputs},  # [INFO] add processed_inputs
    }

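# Illustrative sketch only, not part of this commit: roughly what the processor call
# above produces for a single-image request, using the Hugging Face processor directly.
# The checkpoint name, prompt text, and image path are assumptions for illustration.
#
#     from transformers import AutoProcessor
#     from PIL import Image
#
#     processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")  # assumed checkpoint
#     prompt = "<start_of_image> Describe this image."  # assumed Gemma 3 image placeholder
#     processed = processor(text=prompt, images=Image.open("example.jpg"), return_tensors="pt")
#     request = {
#         "type": "token",  # assumed; the function above forwards inputs["type"] as-is
#         "prompt_token_ids": processed.input_ids[0].tolist(),
#         "prompt": prompt,
#         "multi_modal_data": {"image": processed},  # processed tensors travel with the request
#     }
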
from types import SimpleNamespace


class CustomNamespace(SimpleNamespace):
    def __contains__(self, key):
        return key in self.__dict__

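# Illustrative sketch only, not part of this commit: __contains__ is what lets the
# dict-style membership checks used elsewhere in this file (e.g. "prompt" in inputs)
# also work on a namespace object; a plain SimpleNamespace raises TypeError for "in".
#
#     ns = CustomNamespace()
#     ns.input_ids = [1, 2, 3]
#     assert "input_ids" in ns          # resolved via __contains__ over ns.__dict__
#     assert "pixel_values" not in ns   # False instead of TypeError
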
@INPUT_REGISTRY.register_input_processor(input_processor_for_gemma)
class Gemma3ForConditionalGeneration(Generator, SupportsMultiModal):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.GEMMA_IMAGE_TOKEN_ID = 262144
        self.max_gen_len = self.model_args[0].max_seq_len - 1  # TODO: double check what this should be

    @classmethod
    def initialize_vllm_model(
        cls, hf_config, mesh_device, max_batch_size, max_seq_len=131072, n_layers=None, tt_data_parallel=1
    ):
        submesh_devices = create_submeshes(mesh_device, tt_data_parallel)

        model_args = []
        model = []
        state_dict = None

        for submesh in submesh_devices:
            model_args_i, model_i, state_dict = create_multimodal_model(
                mesh_device=submesh,
                max_batch_size=max_batch_size // tt_data_parallel,
                max_seq_len=max_seq_len,
                use_paged_kv_cache=True,
                checkpoint=state_dict,
            )
            model_args.append(model_args_i)
            model.append(model_i)

        return cls(model, model_args, mesh_device)

    @property
    def cache_path(self):
        return self.model_args[0].model_cache_path

    def prefill_forward(self, *args, **kwargs):
        self.tokenizer = self.model_args[0].tokenizer
        pad_token_id = self.tokenizer.pad_token_id

        tokens = kwargs["tokens"]
        prompt_lens = kwargs["prompt_lens"]
        inputs = CustomNamespace()
        inputs.input_ids = tokens
        data = kwargs.get("images", None)  # This contains the entire Data list, not just the pixel values
        for i in range(tokens.shape[0]):  # for each user, fix their padding
            tokens[i][prompt_lens[i] :] = pad_token_id
        pixel_values = None

        if hasattr(data[0], "pixel_values"):
            # If inputs is a list of objects with pixel_values, concatenate them
            pixel_values = torch.concat([im.pixel_values for im in data if hasattr(im, "pixel_values")], dim=0)

        page_table = kwargs.get("page_table", None)
        kv_cache = kwargs.get("kv_cache", None)
        vision_images = pixel_values

        vision_images = [vision_images] if vision_images is not None else None

        return super().prefill_forward_text(
            tokens=inputs.input_ids,
            page_table=page_table,
            kv_cache=kv_cache,
            prompt_lens=prompt_lens,
            pixel_values=vision_images,
        )

    def allocate_kv_cache(self, *args, **kwargs):
        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)

    def decode_forward(self, *args, **kwargs):
        return super().decode_forward_text(*args, **kwargs)
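
How this class is surfaced to vLLM is outside the hunk above; in out-of-tree integrations a model like this is typically registered with vLLM's model registry. A minimal sketch, assuming string-based registration and a hypothetical architecture key (neither is taken from this commit):

from vllm import ModelRegistry

# Hypothetical out-of-tree registration; the "module:Class" string lets vLLM import the class lazily.
ModelRegistry.register_model(
    "TTGemma3ForConditionalGeneration",  # assumed architecture name
    "models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration",
)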
