85 changes: 79 additions & 6 deletions examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
from PIL import Image as PIL_Image
from pkg_resources import resource_filename
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoTokenizer, AutoProcessor
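# AutoProcessor provides the chat template used for the Mistral multi-modal prompts below.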

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -61,6 +61,9 @@ def register_tt_models():
ModelRegistry.register_model(
"TTMistralForCausalLM",
"models.tt_transformers.tt.generator_vllm:MistralForCausalLM")
ModelRegistry.register_model(
"TTMistral3ForConditionalGeneration",
"models.tt_transformers.tt.generator_vllm:Mistral3ForConditionalGeneration")


register_tt_models() # Import and register models from tt-metal
@@ -94,6 +97,71 @@ def get_sample_multi_modal_llama_inputs():
inputs.append({"prompt": question})
return inputs

def get_sample_multi_modal_mistral_inputs(model):
    # Prepare a sample multi-modal prompt for Mistral-Small-3.1-24B-Instruct-2503
    import requests
    from io import BytesIO

    img_refs = [
        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
    ]
    questions = ["Describe this image."]
    assert len(img_refs) == len(questions), (
        "Number of image references must match number of questions")

    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)

    imgs = []
    text_prompts = []
    for img_ref, question in zip(img_refs, questions):
        content_list = []

        if img_ref:  # image present
            content_list.append({"type": "image", "image": img_ref})

            # Download the image and convert it to RGB
            response = requests.get(img_ref)
            img = PIL_Image.open(BytesIO(response.content)).convert("RGB")
            imgs.append(img)
        else:  # no image
            imgs.append(None)

        # Always add the text question
        content_list.append({"type": "text", "text": question})

        # Render the final chat-style prompt with the model's chat template
        prompt = [{"role": "user", "content": content_list}]
        chat_prompt = processor.apply_chat_template(
            prompt,
            tokenize=False,
            add_generation_prompt=True
        )
        text_prompts.append(chat_prompt)

    # Pair each rendered prompt with its image (if any) in the input format expected by vLLM
    inputs = []
    for img, text_prompt in zip(imgs, text_prompts):
        if img is not None:
            inputs.append({
                "prompt": text_prompt,
                "multi_modal_data": {
                    "image": img
                }
            })
        else:
            inputs.append({"prompt": text_prompt})

    return inputs
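
# Illustrative usage sketch (not part of this change): the dicts returned above can be
# passed straight to a vLLM engine; the model name and sampling settings here are only examples.
#   model = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
#   llm = LLM(model=model, max_model_len=8192)
#   sampling_params = SamplingParams(temperature=0, max_tokens=128)
#   outputs = llm.generate(get_sample_multi_modal_mistral_inputs(model), sampling_params)
#   for output in outputs:
#       print(output.outputs[0].text)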

def get_sample_multi_modal_qwen_inputs(model):
# Prepare a sample multi-modal prompt for Qwen2.5-VL
@@ -183,6 +251,7 @@ def check_tt_model_supported(model):
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
"mistralai/Mistral-7B-Instruct-v0.3",
"mistralai/Mistral-Small-3.1-24B-Instruct-2503",
]
assert model in supported_models, f"Invalid model: {model}"

@@ -239,9 +308,10 @@ def run_inference(
check_tt_model_supported(model)

if multi_modal:
assert "Llama-3.2" in model or "Qwen2.5-VL" in model, (
"The multi-modal inference test "
"currently only supports Llama-3.2 and Qwen2.5-VL models")
assert any(name in model for name in ["Llama-3.2", "gemma", "Qwen2.5-VL", "mistralai"]), \
    "The multi-modal inference test currently only supports Llama-3.2, Gemma, Qwen2.5-VL, and Mistral models"

# LLM args
engine_kw_args = {
@@ -296,15 +366,18 @@
assert isinstance(prompts,
list), "Prompts must be a list of strings"
else:
print("Ignoring prompts json for multi-modal inference")
if "Llama-3.2" in model:
# print("Ignoring prompts json for multi-modal inference")
if "mistral" in model:
prompts = get_sample_multi_modal_mistral_inputs(model)
elif "Llama-3.2" in model:
prompts = get_sample_multi_modal_llama_inputs()
elif "Qwen2.5-VL" in model:
prompts = get_sample_multi_modal_qwen_inputs(model)
else:
raise ValueError(
f"Unsupported model for multi-modal inference test: {model}"
)

if num_repeat_prompts is not None:
prompts = prompts * num_repeat_prompts
print("Number of prompts:", len(prompts))