diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py
index 5125f551053d..7b85a79ea9ef 100644
--- a/models/tt_transformers/tt/generator_vllm.py
+++ b/models/tt_transformers/tt/generator_vllm.py
@@ -373,3 +373,37 @@ def decode_forward(self, *args, **kwargs):
 
     def allocate_kv_cache(self, *args, **kwargs):
         return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
+
+
+class Gemma3ForCausalLM(Generator):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    @classmethod
+    def initialize_vllm_model(
+        cls, hf_config, mesh_device, max_batch_size, max_seq_len=32768, n_layers=None, tt_data_parallel=1
+    ):
+        tt_model, model_args = initialize_vllm_text_transformer(
+            hf_config,
+            tt_data_parallel,
+            mesh_device,
+            max_batch_size,
+            max_seq_len=max_seq_len,
+            n_layers=n_layers,
+            dtype=ttnn.bfloat16,
+            optimizations=DecodersPrecision.accuracy,
+        )
+        return cls(tt_model, model_args, mesh_device)
+
+    @property
+    def cache_path(self):
+        return self.model_args[0].model_cache_path
+
+    def prefill_forward(self, *args, **kwargs):
+        return super().prefill_forward_text(*args, **kwargs)
+
+    def decode_forward(self, *args, **kwargs):
+        return super().decode_forward_text(*args, **kwargs)
+
+    def allocate_kv_cache(self, *args, **kwargs):
+        return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path)
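
For reference, a minimal sketch of how the new Gemma3ForCausalLM wrapper might be driven by hand, outside of vLLM's model registry. Only the initialize_vllm_model signature comes from the diff itself; the checkpoint name, mesh shape, and batch size below are illustrative assumptions, not values from this patch.

    # Hypothetical usage sketch; not part of the patch.
    import ttnn
    from transformers import AutoConfig

    from models.tt_transformers.tt.generator_vllm import Gemma3ForCausalLM

    # Load the Hugging Face config for the target checkpoint (name is illustrative).
    hf_config = AutoConfig.from_pretrained("google/gemma-3-1b-it")

    # Open a Tenstorrent mesh device; the 1x1 shape is an assumption for a single chip.
    mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(1, 1))

    # Build the TT-backed model via the classmethod added in this diff.
    model = Gemma3ForCausalLM.initialize_vllm_model(
        hf_config,
        mesh_device,
        max_batch_size=32,    # illustrative; bounds how many sequences run at once
        max_seq_len=32768,    # matches the default in the diff
    )

    # prefill_forward / decode_forward delegate to the Generator base class's
    # *_text implementations, which is what vLLM invokes while serving.

Note that the wrapper itself adds no new logic: it pins the text-transformer path (prefill_forward_text / decode_forward_text), bfloat16 weights, and the accuracy-oriented decoder precision profile, and exposes the cache path of the first data-parallel model instance for KV-cache allocation.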