From d07f98b041d346fb8a9c0e1f63ea655f08b80566 Mon Sep 17 00:00:00 2001
From: MohammedTaherMcW
Date: Sun, 24 Aug 2025 18:13:32 +0000
Subject: [PATCH 1/2] Add vLLM Support for Gemma3 Models

---
 examples/offline_inference_tt.py | 84 +++++++++++++++++++++++++++++++++++++++++++++++++++---
 vllm/worker/tt_worker.py         |  1 +
 2 files changed, 81 insertions(+), 4 deletions(-)

diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index 2f1d6bd78fe0..dceda7c09302 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
 from PIL import Image as PIL_Image
 from pkg_resources import resource_filename
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -62,6 +62,10 @@ def register_tt_models():
         "TTMistralForCausalLM",
         "models.tt_transformers.tt.generator_vllm:MistralForCausalLM")
 
+    ModelRegistry.register_model(
+        "TTGemma3ForConditionalGeneration",
+        "models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration"
+    )
 
 register_tt_models()
 
@@ -94,6 +98,66 @@ def get_sample_multi_modal_llama_inputs():
         inputs.append({"prompt": question})
     return inputs
 
+def get_sample_multi_modal_gemma_inputs(model):
+    # Prepare sample multi-modal prompts for Gemma3
+    text_prompts = []
+    imgs = []
+    img_refs = [
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", None]
+
+    questions = ["Describe this image.", "What is Global Warming?"]
+    assert len(img_refs) == len(questions), (
+        "Number of image references must match number of questions")
+
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
+
+    import requests
+    from io import BytesIO
+
+    prompts = []
+
+    for img_ref, question in zip(img_refs, questions):
+        content_list = []
+
+        if img_ref:  # image present
+            content_list.append({"type": "image", "image": img_ref})
+
+            # Load and convert image
+            response = requests.get(img_ref)
+            img = PIL_Image.open(BytesIO(response.content)).convert("RGB")
+            imgs.append(img)
+        else:  # no image
+            imgs.append(None)
+
+        # Always add text
+        content_list.append({"type": "text", "text": question})
+
+        # Append final prompt
+        prompt = [{"role": "user", "content": content_list}]
+        prompts.append(prompt)
+
+        # Create text-only template for tokenizer
+        chat_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        text_prompts.append(chat_prompt)
+
+    inputs = []
+    for img, text_prompt in zip(imgs, text_prompts):
+        if img is not None:
+            inputs.append({
+                "prompt": text_prompt,
+                "multi_modal_data": {
+                    "image": img
+                }
+            })
+        else:
+            inputs.append({"prompt": text_prompt})
+    return inputs
+
 
 def get_sample_multi_modal_qwen_inputs(model):
     # Prepare a sample multi-modal prompt for Qwen2.5-VL
@@ -183,6 +247,8 @@ def check_tt_model_supported(model):
         "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
         "mistralai/Mistral-7B-Instruct-v0.3",
+        "google/gemma-3-4b-it",
+        "google/gemma-3-27b-it",
     ]
 
     assert model in supported_models, f"Invalid model: {model}"
@@ -239,9 +305,15 @@ def run_inference(
     check_tt_model_supported(model)
 
     if multi_modal:
-        assert "Llama-3.2" in model or "Qwen2.5-VL" in model, (
"Qwen2.5-VL" in model, ( - "The multi-modal inference test " - "currently only supports Llama-3.2 and Qwen2.5-VL models") + assert ( + "Llama-3.2" in model + or "Qwen2.5-VL" in model + or "gemma" in model + ), ( + "The multi-modal inference test currently only supports " + "Llama-3.2, Qwen2.5-VL, and Gemma models" + ) + # LLM args engine_kw_args = { @@ -301,6 +378,8 @@ def run_inference( prompts = get_sample_multi_modal_llama_inputs() elif "Qwen2.5-VL" in model: prompts = get_sample_multi_modal_qwen_inputs(model) + elif "gemma" in model: + prompts = get_sample_multi_modal_gemma_inputs(model) else: raise ValueError( f"Unsupported model for multi-modal inference test: {model}" @@ -326,6 +405,8 @@ def run_inference( IMAGE_TOKEN_ID = 128256 # Specific to multi-modal llama elif "Qwen2.5-VL" in model: IMAGE_TOKEN_ID = 151655 # Specific to multi-modal qwen + elif "gemma" in model: + IMAGE_TOKEN_ID = 262144 else: raise ValueError( f"Unsupported model for multi-modal inference test in perf " diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py index 45c4cbb34c79..07b29edc02ba 100644 --- a/vllm/worker/tt_worker.py +++ b/vllm/worker/tt_worker.py @@ -503,6 +503,7 @@ def device_params_from_override_tt_config(override_tt_config, trace_mode): if override_tt_config and "worker_l1_size" in override_tt_config: device_params["worker_l1_size"] = override_tt_config["worker_l1_size"] + device_params["l1_small_size"]=79104 return device_params From 5562a824c46916ec76350e2e6db58dab6782b8a7 Mon Sep 17 00:00:00 2001 From: MohammedTaherMcW Date: Mon, 1 Sep 2025 17:08:58 +0000 Subject: [PATCH 2/2] Add Gemma 1B vllm support --- examples/offline_inference_tt.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py index dceda7c09302..56dc5ab6bb7a 100644 --- a/examples/offline_inference_tt.py +++ b/examples/offline_inference_tt.py @@ -62,11 +62,17 @@ def register_tt_models(): "TTMistralForCausalLM", "models.tt_transformers.tt.generator_vllm:MistralForCausalLM") + # Gemma ModelRegistry.register_model( "TTGemma3ForConditionalGeneration", "models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration" ) + ModelRegistry.register_model( + "TTGemma3ForCausalLM", + "models.tt_transformers.tt.generator_vllm:Gemma3ForCausalLM" + ) + register_tt_models() # Import and register models from tt-metal @@ -103,9 +109,9 @@ def get_sample_multi_modal_gemma_inputs(model): text_prompts = [] imgs = [] img_refs = [ - "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", None] + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg", "https://algodocs.com/assets/img/articles/2021-10-14/handwritten-text-1.jpg",None] - questions = ["Describe this image.", "What is Global Warming?"] + questions = ["Describe this image.", "Do OCR For this image","What is Global Warming"] assert len(img_refs) == len(questions), ( "Number of image references must match number of questions") @@ -252,6 +258,7 @@ def check_tt_model_supported(model): "deepseek-ai/DeepSeek-R1-Distill-Llama-70B", "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", "mistralai/Mistral-7B-Instruct-v0.3", + "google/gemma-3-1b-it", "google/gemma-3-4b-it", "google/gemma-3-27b-it", ]