diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index e4188c530dd5..7a6e04c73d25 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -525,7 +525,7 @@ def __init__(self, *args, **kwargs): def initialize_vllm_model( cls, hf_config, mesh_device, max_batch_size, max_seq_len=131072, n_layers=None, tt_data_parallel=1 ): - from models.demos.gemma3.demo.vision_demo import create_multimodal_model + from models.tt_transformers.demo.simple_vision_demo import create_multimodal_model submesh_devices = create_submeshes(mesh_device, tt_data_parallel) @@ -564,3 +564,37 @@ def allocate_kv_cache(self, *args, **kwargs): def decode_forward(self, *args, **kwargs): return super().decode_forward_text(*args, **kwargs) + + +class Gemma3ForCausalLM(Generator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def initialize_vllm_model( + cls, hf_config, mesh_device, max_batch_size, max_seq_len=32768, n_layers=None, tt_data_parallel=1 + ): + tt_model, model_args = initialize_vllm_text_transformer( + hf_config, + tt_data_parallel, + mesh_device, + max_batch_size, + max_seq_len=max_seq_len, + n_layers=n_layers, + dtype=ttnn.bfloat8_b, + optimizations=DecodersPrecision.performance, + ) + return cls(tt_model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args[0].model_cache_path + + def prefill_forward(self, *args, **kwargs): + return super().prefill_forward_text(*args, **kwargs) + + def decode_forward(self, *args, **kwargs): + return super().decode_forward_text(*args, **kwargs) + + def allocate_kv_cache(self, *args, **kwargs): + return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) diff --git a/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py b/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py index 41047dd3f92f..55e05e4fd10b 100644 --- a/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py +++ b/models/tt_transformers/tt/multimodal/gemma3/gemma_e2e_model.py @@ -51,7 +51,7 @@ def prepare_inputs_prefill(self, pt_tokens, start_pos=0, page_table=None, chunk_ mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), ) tokens_embd = self.embd(tokens) - vision_output = self.compute_vision_token(**kwargs) + vision_output = self.compute_vision_token(kwargs.get("pixel_values", None)) if vision_output is not None: tokens_embd = ttnn.to_torch(tokens_embd, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1))