From c9b7aeeda669cacf1e0d6bead4a40d33d7dad9bb Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 10 Aug 2025 17:09:22 +0000
Subject: [PATCH] Add vLLM support for Gemma-3-4b-it

---
 examples/offline_inference_tt.py | 89 +++++++++++++++++++++++++++++---
 vllm/worker/tt_worker.py         |  2 +-
 2 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index a4fb4fa6b8b6..595b1a8ecdc0 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
 from PIL import Image as PIL_Image
 from pkg_resources import resource_filename
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -26,16 +26,16 @@ def register_tt_models():
     if llama_text_version == "tt_transformers":
         path_llama_text = \
             "models.tt_transformers.tt.generator_vllm:LlamaForCausalLM"
-    elif llama_text_version == "llama3_subdevices":
+    elif llama_text_version == "llama3_70b_galaxy":
         path_llama_text = \
-            "models.demos.llama3_subdevices.tt.generator_vllm:LlamaForCausalLM"
+            "models.demos.llama3_70b_galaxy.tt.generator_vllm:LlamaForCausalLM"
     elif llama_text_version == "llama2_70b":
         path_llama_text = \
             "models.demos.t3000.llama2_70b.tt.generator_vllm:TtLlamaForCausalLM"
     else:
         raise ValueError(
             f"Unsupported TT Llama version: {llama_text_version}, "
-            "pick one of [tt_transformers, llama3_subdevices, llama2_70b]")
+            "pick one of [tt_transformers, llama3_70b_galaxy, llama2_70b]")
 
     # Llama3.1/3.2 - Text
     ModelRegistry.register_model("TTLlamaForCausalLM", path_llama_text)
@@ -56,6 +56,10 @@ def register_tt_models():
         "TTMistralForCausalLM",
         "models.tt_transformers.tt.generator_vllm:MistralForCausalLM")
 
+    ModelRegistry.register_model(
+        "TTGemma3ForConditionalGeneration",
+        "models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration"
+    )
 
 register_tt_models()  # Import and register models from tt-metal
 
@@ -89,6 +93,72 @@ def get_sample_multi_modal_llama_inputs():
     return inputs
 
 
+def get_sample_multi_modal_gemma_inputs(model):
+    # Prepare sample multi-modal prompts for Gemma-3-4b-it (one image+text
+    # question, one text-only question).
+    # Local imports used only by this sample helper
+    import requests
+    from io import BytesIO
+
+    img_refs = [
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
+        None,
+    ]
+    questions = ["Describe this image.", "What is Global Warming?"]
+    assert len(img_refs) == len(questions), (
+        "Number of image references must match number of questions")
+
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    # The processor is loaded alongside the tokenizer; only the tokenizer's
+    # chat template is used below.
+    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
+
+    imgs = []
+    text_prompts = []
+    for img_ref, question in zip(img_refs, questions):
+        content_list = []
+
+        if img_ref:  # image present
+            content_list.append({"type": "image", "image": img_ref})
+
+            # Download the image and convert it to RGB
+            response = requests.get(img_ref)
+            imgs.append(
+                PIL_Image.open(BytesIO(response.content)).convert("RGB"))
+        else:  # text-only prompt
+            imgs.append(None)
+
+        # Always add the text question
+        content_list.append({"type": "text", "text": question})
+
+        # Build the chat message and render it as a text prompt
+        prompt = [{"role": "user", "content": content_list}]
+        chat_prompt = tokenizer.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        text_prompts.append(chat_prompt)
+
+    inputs = []
+    for img, text_prompt in zip(imgs, text_prompts):
+        if img is not None:
+            inputs.append({
+                "prompt": text_prompt,
+                "multi_modal_data": {
+                    "image": img
+                }
+            })
+        else:
+            inputs.append({"prompt": text_prompt})
+
+    return inputs
+
+
 def check_tt_model_supported(model):
     supported_models = [
         "meta-llama/Llama-3.1-70B",
@@ -121,6 +191,7 @@ def check_tt_model_supported(model):
         "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
         "mistralai/Mistral-7B-Instruct-v0.3",
+        "google/gemma-3-4b-it",
     ]
     assert model in supported_models, f"Invalid model: {model}"
 
@@ -177,8 +248,9 @@ def run_inference(
     check_tt_model_supported(model)
 
     if multi_modal:
-        assert "Llama-3.2" in model, "The multi-modal inference test " + \
-            "currently only supports Llama-3.2 models"
+        assert any(name in model for name in ["Llama-3.2", "gemma"]), \
+            "The multi-modal inference test requires Llama-3.2 or Gemma models"
+
 
     # LLM args
     engine_kw_args = {
@@ -234,7 +306,10 @@ def run_inference(
                 list), "Prompts must be a list of strings"
     else:
         print("Ignoring prompts json for multi-modal inference")
-        prompts = get_sample_multi_modal_llama_inputs()
+        if "gemma" in model:
+            prompts = get_sample_multi_modal_gemma_inputs(model)
+        else:
+            prompts = get_sample_multi_modal_llama_inputs()
     if num_repeat_prompts is not None:
         prompts = prompts * num_repeat_prompts
     print("Number of prompts:", len(prompts))
diff --git a/vllm/worker/tt_worker.py b/vllm/worker/tt_worker.py
index e51d53bef5b6..8d6e72895db3 100644
--- a/vllm/worker/tt_worker.py
+++ b/vllm/worker/tt_worker.py
@@ -476,7 +476,7 @@ def _device_params_from_override_tt_config(self):
         if override_tt_config and "worker_l1_size" in override_tt_config:
             device_params["worker_l1_size"] = override_tt_config[
                 "worker_l1_size"]
-
+        device_params["l1_small_size"] = 79104  # set for Gemma-3-4b-it support
         return device_params
 
     def _open_mesh_device(self):
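
For reference, a minimal sketch of exercising the new Gemma-3-4b-it multi-modal path with vLLM's offline API, outside the example script. The direct import of the example module and the engine arguments shown here are illustrative assumptions, not part of this patch.

# Sketch only: assumes a working TT backend and that
# examples/offline_inference_tt.py is importable on PYTHONPATH
# (importing it registers the TT models with vLLM).
from vllm import LLM, SamplingParams
from offline_inference_tt import get_sample_multi_modal_gemma_inputs

model = "google/gemma-3-4b-it"
inputs = get_sample_multi_modal_gemma_inputs(model)  # image + text-only prompts

llm = LLM(model=model, max_model_len=8192)  # engine args are illustrative
outputs = llm.generate(inputs, SamplingParams(temperature=0, max_tokens=128))
for out in outputs:
    print(out.outputs[0].text)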