From c3ea097f984b2b764a53599b7e936744a2c52535 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Wed, 30 Jul 2025 13:56:25 +0000 Subject: [PATCH 1/3] Add vLLM support --- models/tt_transformers/tt/generator_vllm.py | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 5125f551053d..213cec99a51c 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -373,3 +373,37 @@ def decode_forward(self, *args, **kwargs): def allocate_kv_cache(self, *args, **kwargs): return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) + + +class Gemma3ForCausalLM(Generator): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def initialize_vllm_model( + cls, hf_config, mesh_device, max_batch_size, max_seq_len=32768, n_layers=None, tt_data_parallel=1 + ): + tt_model, model_args = initialize_vllm_text_transformer( + hf_config, + tt_data_parallel, + mesh_device, + max_batch_size, + max_seq_len=max_seq_len, + n_layers=n_layers, + dtype=ttnn.bfloat8_b, + optimizations=DecodersPrecision.performance, + ) + return cls(tt_model, model_args, mesh_device) + + @property + def cache_path(self): + return self.model_args[0].model_cache_path + + def prefill_forward(self, *args, **kwargs): + return super().prefill_forward_text(*args, **kwargs) + + def decode_forward(self, *args, **kwargs): + return super().decode_forward_text(*args, **kwargs) + + def allocate_kv_cache(self, *args, **kwargs): + return allocate_vllm_kv_cache(*args, **kwargs, dp_model=self.model, tt_cache_path=self.cache_path) From 66ea9d7f125d6637f13248a14e29275b821b0f4b Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Thu, 31 Jul 2025 07:55:42 +0000 Subject: [PATCH 2/3] Change dtype of Gemma3ForCausalLM vLLM from bfloat8_b to bfloat16 --- models/tt_transformers/tt/generator_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 213cec99a51c..61c76f89fd10 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -390,7 +390,7 @@ def initialize_vllm_model( max_batch_size, max_seq_len=max_seq_len, n_layers=n_layers, - dtype=ttnn.bfloat8_b, + dtype=ttnn.bfloat16, optimizations=DecodersPrecision.performance, ) return cls(tt_model, model_args, mesh_device) From e324df80cd8c8f1d15ab06fe6c74c656826a4ec7 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Thu, 31 Jul 2025 07:58:00 +0000 Subject: [PATCH 3/3] Change optimizations for Gemma3ForCausalLM vLLM to accuracy --- models/tt_transformers/tt/generator_vllm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/tt_transformers/tt/generator_vllm.py b/models/tt_transformers/tt/generator_vllm.py index 61c76f89fd10..7b85a79ea9ef 100644 --- a/models/tt_transformers/tt/generator_vllm.py +++ b/models/tt_transformers/tt/generator_vllm.py @@ -391,7 +391,7 @@ def initialize_vllm_model( max_seq_len=max_seq_len, n_layers=n_layers, dtype=ttnn.bfloat16, - optimizations=DecodersPrecision.performance, + optimizations=DecodersPrecision.accuracy, ) return cls(tt_model, model_args, mesh_device)