Skip to content

Commit d08eccc

Browse files
Fix generator vllm after rebase
1 parent a5be789 commit d08eccc

File tree

1 file changed

+37
-5
lines changed

1 file changed

+37
-5
lines changed

models/tt_transformers/tt/generator_vllm.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -459,16 +459,14 @@ def prefill_forward(self, *args, **kwargs):
459459
tokens[i][prompt_lens[i] :] = pad_token_id
460460
pixel_values = None
461461

462-
if hasattr(data[0], "pixel_values"):
463-
# If inputs is a list of objects with pixel_values, concatenate them
464-
pixel_values = torch.concat([im.pixel_values for im in data if hasattr(im, "pixel_values")], dim=0)
462+
if any(hasattr(d, "pixel_values") for d in data):
463+
# If inputs is a list of objects with .pixel_values, concatenate them
464+
pixel_values = [im.pixel_values if hasattr(im, "pixel_values") else None for im in data]
465465

466466
page_table = kwargs.get("page_table", None)
467467
kv_cache = kwargs.get("kv_cache", None)
468468
vision_images = pixel_values
469469

470-
vision_images = [vision_images] if vision_images is not None else None
471-
472470
return super().prefill_forward_text(
473471
tokens=inputs.input_ids,
474472
page_table=page_table,
@@ -482,3 +480,37 @@ def allocate_kv_cache(self, *args, **kwargs):
482480

483481
def decode_forward(self, *args, **kwargs):
484482
return super().decode_forward_text(*args, **kwargs)
483+
484+
485+
class Gemma3ForCausalLM(Generator):
486+
def __init__(self, *args, **kwargs):
487+
super().__init__(*args, **kwargs)
488+
489+
@classmethod
490+
def initialize_vllm_model(
491+
cls, hf_config, mesh_device, max_batch_size, max_seq_len=32768, n_layers=None, tt_data_parallel=1
492+
):
493+
tt_model, model_args = initialize_vllm_text_transformer(
494+
hf_config,
495+
tt_data_parallel,
496+
mesh_device,
497+
max_batch_size,
498+
max_seq_len=max_seq_len,
499+
n_layers=n_layers,
500+
dtype=ttnn.bfloat8_b,
501+
optimizations=DecodersPrecision.performance,
502+
)
503+
return cls(tt_model, model_args, mesh_device)
504+
505+
@property
506+
def cache_path(self):
507+
return self.model_args[0].model_cache_path
508+
509+
def prefill_forward(self, *args, **kwargs):
510+
return super().prefill_forward_text(*args, **kwargs)
511+
512+
def decode_forward(self, *args, **kwargs):
513+
return super().decode_forward_text(*args, **kwargs)
514+
515+
def allocate_kv_cache(self, *args, **kwargs):
516+
return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)

0 commit comments

Comments
 (0)