96 changes: 92 additions & 4 deletions examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
from PIL import Image as PIL_Image
from pkg_resources import resource_filename
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -62,7 +62,17 @@ def register_tt_models():
"TTMistralForCausalLM",
"models.tt_transformers.tt.generator_vllm:MistralForCausalLM")

# Gemma
ModelRegistry.register_model(
"TTGemma3ForConditionalGeneration",
"models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration"
)

ModelRegistry.register_model(
"TTGemma3ForCausalLM",
"models.tt_transformers.tt.generator_vllm:Gemma3ForCausalLM"
)

register_tt_models() # Import and register models from tt-metal
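
# Illustrative check (not part of this diff): after register_tt_models() runs,
# the new Gemma3 architectures should be visible in vLLM's model registry,
# assuming the registry exposes get_supported_archs() as in upstream vLLM:
#
#   assert "TTGemma3ForConditionalGeneration" in ModelRegistry.get_supported_archs()
#   assert "TTGemma3ForCausalLM" in ModelRegistry.get_supported_archs()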


@@ -94,6 +104,71 @@ def get_sample_multi_modal_llama_inputs():
inputs.append({"prompt": question})
return inputs

def get_sample_multi_modal_gemma_inputs(model):
    # Prepare sample multi-modal prompts for Gemma3 (e.g. Gemma3-4b-it)
    import requests
    from io import BytesIO

    img_refs = [
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
        "https://algodocs.com/assets/img/articles/2021-10-14/handwritten-text-1.jpg",
        None,
    ]
    questions = [
        "Describe this image.",
        "Do OCR for this image.",
        "What is global warming?",
    ]
    assert len(img_refs) == len(questions), (
        "Number of image references must match number of questions")

    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)

    prompts = []
    imgs = []
    text_prompts = []

    for img_ref, question in zip(img_refs, questions):
        content_list = []

        if img_ref:  # image present
            content_list.append({"type": "image", "image": img_ref})

            # Download and convert the image
            response = requests.get(img_ref)
            img = PIL_Image.open(BytesIO(response.content)).convert("RGB")
            imgs.append(img)
        else:  # no image
            imgs.append(None)

        # Always add the text question
        content_list.append({"type": "text", "text": question})

        # Build the chat-formatted prompt
        prompt = [{"role": "user", "content": content_list}]
        prompts.append(prompt)

        # Create the text-only template for the tokenizer
        chat_prompt = tokenizer.apply_chat_template(
            prompt,
            tokenize=False,
            add_generation_prompt=True,
        )
        text_prompts.append(chat_prompt)

    inputs = []
    for img, text_prompt in zip(imgs, text_prompts):
        if img is not None:
            inputs.append({
                "prompt": text_prompt,
                "multi_modal_data": {
                    "image": img
                }
            })
        else:
            inputs.append({"prompt": text_prompt})
    return inputs
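
# Illustrative usage sketch (not part of this diff): how the inputs built by
# get_sample_multi_modal_gemma_inputs() are consumed by vLLM's offline API.
# The model name and sampling settings below are assumptions for demonstration
# only; the script itself wires the TT-specific engine arguments through
# run_inference() further down.
def _demo_gemma_multi_modal(model="google/gemma-3-4b-it"):
    llm = LLM(model=model)
    sampling_params = SamplingParams(max_tokens=128, temperature=0.0)
    outputs = llm.generate(get_sample_multi_modal_gemma_inputs(model),
                           sampling_params)
    for output in outputs:
        print(output.outputs[0].text)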


def get_sample_multi_modal_qwen_inputs(model):
# Prepare a sample multi-modal prompt for Qwen2.5-VL
@@ -183,6 +258,9 @@ def check_tt_model_supported(model):
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
"mistralai/Mistral-7B-Instruct-v0.3",
"google/gemma-3-1b-it",
"google/gemma-3-4b-it",
"google/gemma-3-27b-it",
]
assert model in supported_models, f"Invalid model: {model}"

@@ -239,9 +317,15 @@ def run_inference(
check_tt_model_supported(model)

if multi_modal:
assert "Llama-3.2" in model or "Qwen2.5-VL" in model, (
"The multi-modal inference test "
"currently only supports Llama-3.2 and Qwen2.5-VL models")
assert (
"Llama-3.2" in model
or "Qwen2.5-VL" in model
or "gemma" in model
), (
"The multi-modal inference test currently only supports "
"Llama-3.2, Qwen2.5-VL, and Gemma models"
)


# LLM args
engine_kw_args = {
@@ -301,6 +385,8 @@ def run_inference(
prompts = get_sample_multi_modal_llama_inputs()
elif "Qwen2.5-VL" in model:
prompts = get_sample_multi_modal_qwen_inputs(model)
elif "gemma" in model:
prompts = get_sample_multi_modal_gemma_inputs(model)
else:
raise ValueError(
f"Unsupported model for multi-modal inference test: {model}"
@@ -326,6 +412,8 @@ def run_inference(
IMAGE_TOKEN_ID = 128256 # Specific to multi-modal llama
elif "Qwen2.5-VL" in model:
IMAGE_TOKEN_ID = 151655 # Specific to multi-modal qwen
elif "gemma" in model:
IMAGE_TOKEN_ID = 262144  # Specific to multi-modal gemma
else:
raise ValueError(
f"Unsupported model for multi-modal inference test in perf "
1 change: 1 addition & 0 deletions vllm/worker/tt_worker.py
@@ -503,6 +503,7 @@ def device_params_from_override_tt_config(override_tt_config, trace_mode):

if override_tt_config and "worker_l1_size" in override_tt_config:
device_params["worker_l1_size"] = override_tt_config["worker_l1_size"]
device_params["l1_small_size"]=79104

return device_params
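
# Illustrative sketch (not part of this diff), assuming override_tt_config is a
# dict such as {"worker_l1_size": <bytes>}: the returned device_params then
# carries both the user-supplied worker_l1_size and the fixed l1_small_size
# added above, e.g.
#
#   params = device_params_from_override_tt_config(
#       {"worker_l1_size": worker_l1_bytes}, trace_mode=False)
#   assert params["l1_small_size"] == 79104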
