89 changes: 82 additions & 7 deletions examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
from PIL import Image as PIL_Image
from pkg_resources import resource_filename
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -26,16 +26,16 @@ def register_tt_models():
if llama_text_version == "tt_transformers":
path_llama_text = \
"models.tt_transformers.tt.generator_vllm:LlamaForCausalLM"
elif llama_text_version == "llama3_subdevices":
elif llama_text_version == "llama3_70b_galaxy":
path_llama_text = \
"models.demos.llama3_subdevices.tt.generator_vllm:LlamaForCausalLM"
"models.demos.llama3_70b_galaxy.tt.generator_vllm:LlamaForCausalLM"
elif llama_text_version == "llama2_70b":
path_llama_text = \
"models.demos.t3000.llama2_70b.tt.generator_vllm:TtLlamaForCausalLM"
else:
raise ValueError(
f"Unsupported TT Llama version: {llama_text_version}, "
"pick one of [tt_transformers, llama3_subdevices, llama2_70b]")
"pick one of [tt_transformers, llama3_70b_galaxy, llama2_70b]")

# Llama3.1/3.2 - Text
ModelRegistry.register_model("TTLlamaForCausalLM", path_llama_text)
@@ -56,6 +56,10 @@ def register_tt_models():
"TTMistralForCausalLM",
"models.tt_transformers.tt.generator_vllm:MistralForCausalLM")

ModelRegistry.register_model(
"TTGemma3ForConditionalGeneration",
"models.tt_transformers.tt.generator_vllm:Gemma3ForConditionalGeneration"
)

register_tt_models() # Import and register models from tt-metal

@@ -89,6 +93,72 @@ def get_sample_multi_modal_llama_inputs():
return inputs


def get_sample_multi_modal_gemma_inputs(model):
    # Prepare sample multi-modal prompts for Gemma3-4b-it: one image+text
    # prompt and one text-only prompt.
    img_refs = [
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg",
        None
    ]
    questions = ["Describe this image.", "What is Global Warming?"]
    assert len(img_refs) == len(questions), (
        "Number of image references must match number of questions")

    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)

    # Local imports used only by this helper (PIL_Image is imported at module level)
    import requests
    from io import BytesIO

    imgs = []
    text_prompts = []

for img_ref, question in zip(img_refs, questions):
content_list = []

if img_ref: # image present
content_list.append({"type": "image", "image": img_ref})

# Load and convert image
response = requests.get(img_ref)
img = PIL_Image.open(BytesIO(response.content)).convert("RGB")
imgs.append(img)
else: # no image
imgs.append(None)

# Always add text
content_list.append({"type": "text", "text": question})

        # Build the chat-format prompt for this sample
        prompt = [{"role": "user", "content": content_list}]

# Create text-only template for tokenizer
chat_prompt = tokenizer.apply_chat_template(
prompt,
tokenize=False,
add_generation_prompt=True
)
text_prompts.append(chat_prompt)

inputs = []
for img, text_prompt in zip(imgs, text_prompts):
if img is not None:
inputs.append({
"prompt": text_prompt,
"multi_modal_data": {
"image": img
}
})
else:
inputs.append({"prompt": text_prompt})

return inputs

def check_tt_model_supported(model):
supported_models = [
"meta-llama/Llama-3.1-70B",
@@ -121,6 +191,7 @@ def check_tt_model_supported(model):
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
"mistralai/Mistral-7B-Instruct-v0.3",
"google/gemma-3-4b-it",
]
assert model in supported_models, f"Invalid model: {model}"

@@ -177,8 +248,9 @@ def run_inference(
check_tt_model_supported(model)

if multi_modal:
assert "Llama-3.2" in model, "The multi-modal inference test " + \
"currently only supports Llama-3.2 models"
assert any(name in model for name in ["Llama-3.2", "gemma"]), \
"The multi-modal inference test requires Llama-3.2 or Gemma models"


# LLM args
engine_kw_args = {
@@ -234,7 +306,10 @@ def run_inference(
list), "Prompts must be a list of strings"
else:
print("Ignoring prompts json for multi-modal inference")
prompts = get_sample_multi_modal_llama_inputs()
if "gemma" in model:
prompts = get_sample_multi_modal_gemma_inputs(model)
else:
prompts = get_sample_multi_modal_llama_inputs()
if num_repeat_prompts is not None:
prompts = prompts * num_repeat_prompts
print("Number of prompts:", len(prompts))
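For reference, below is a minimal sketch of how the input dicts built by `get_sample_multi_modal_gemma_inputs` would be consumed by an offline vLLM engine. The model name, engine arguments, and sampling parameters are illustrative assumptions; the actual script constructs its engine from its own `engine_kw_args`.

```python
# Illustrative sketch only (not part of this PR): assumes register_tt_models()
# has already been called and that get_sample_multi_modal_gemma_inputs from
# this example file is in scope. Engine and sampling arguments are placeholders.
from vllm import LLM, SamplingParams

model = "google/gemma-3-4b-it"
llm = LLM(model=model, max_model_len=8192)  # simplified vs. the script's engine_kw_args
sampling_params = SamplingParams(temperature=0.0, max_tokens=128)

# Each entry is {"prompt": <chat-templated text>}, with an optional
# "multi_modal_data": {"image": <PIL.Image>} for the image-bearing prompt.
inputs = get_sample_multi_modal_gemma_inputs(model)
outputs = llm.generate(inputs, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
```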
2 changes: 1 addition & 1 deletion vllm/worker/tt_worker.py
@@ -476,7 +476,7 @@ def _device_params_from_override_tt_config(self):
if override_tt_config and "worker_l1_size" in override_tt_config:
device_params["worker_l1_size"] = override_tt_config[
"worker_l1_size"]

        device_params["l1_small_size"] = 79104
return device_params

def _open_mesh_device(self):
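On the `tt_worker.py` side, the new `l1_small_size` value is set unconditionally. As a hedged sketch (not what the PR implements), the same value could be made overridable through `override_tt_config`, mirroring the existing `worker_l1_size` handling; the key name and the 79104 default are taken from the diff above.

```python
# Sketch only: the PR hardcodes l1_small_size; treating it as an override key
# is an assumption, not the PR's behavior.
from typing import Optional

DEFAULT_L1_SMALL_SIZE = 79104  # default taken from this change


def resolve_l1_small_size(override_tt_config: Optional[dict]) -> int:
    """Return the L1 small region size, preferring an explicit override."""
    if override_tt_config and "l1_small_size" in override_tt_config:
        return override_tt_config["l1_small_size"]
    return DEFAULT_L1_SMALL_SIZE


# _device_params_from_override_tt_config could then do:
# device_params["l1_small_size"] = resolve_l1_small_size(override_tt_config)
```

Keeping the hardcoded value as the default while allowing an override would preserve current behavior for existing configurations.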