76 changes: 70 additions & 6 deletions examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
from PIL import Image as PIL_Image
from pkg_resources import resource_filename
from tqdm import tqdm
from transformers import AutoTokenizer
from transformers import AutoProcessor, AutoTokenizer

from vllm import LLM, ModelRegistry, SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -48,8 +48,10 @@ def register_tt_models():

# Qwen2.5 - Text
path_qwen_text = "models.tt_transformers.tt.generator_vllm:QwenForCausalLM"
path_qwen_vision = "models.tt_transformers.tt.generator_vllm:Qwen2_5_VLForConditionalGeneration"
ModelRegistry.register_model("TTQwen2ForCausalLM", path_qwen_text)
ModelRegistry.register_model("TTQwen3ForCausalLM", path_qwen_text)
ModelRegistry.register_model("TTQwen2_5_VLForConditionalGeneration", path_qwen_vision)

# Mistral
ModelRegistry.register_model(
@@ -88,6 +90,57 @@ def get_sample_multi_modal_llama_inputs():
inputs.append({"prompt": question})
return inputs

def get_sample_multi_modal_qwen_inputs(model):
# Prepare a sample multi-modal prompt for Qwen2.5-VL
text_prompts = []
imgs = []
questions = ["Describe this image."]
img_refs = [
"https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
]
prompts = [[{
"role":
"user",
"content": [{
"type": "image",
"image": img_ref,
"resized_height": 224,
"resized_width": 224,
}, {
"type": "text",
"text": question
}]
}] for img_ref, question in zip(img_refs, questions)]
tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
for prompt in prompts:
chat_prompt = tokenizer.apply_chat_template(prompt,
tokenize=False,
add_generation_prompt=True)
if any(ctnt["type"] == "image" for entry in prompt
for ctnt in entry['content']):
from qwen_vl_utils import (
process_vision_info)  # Imported here so text-only models don't require qwen_vl_utils
image_inputs, video_inputs = process_vision_info(prompt)
assert video_inputs is None, "Video inputs not supported yet"
assert len(
image_inputs) == 1, "Multi-image inputs not supported yet"
imgs.append(image_inputs[0])
else:
imgs.append(None)
text_prompts.append(chat_prompt)

inputs = []
for img, text_prompt in zip(imgs, text_prompts):
if img is not None:
inputs.append({
"prompt": text_prompt,
"multi_modal_data": {
"image": img
}
})
else:
inputs.append({"prompt": text_prompt})
return inputs
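
For context, a minimal usage sketch of the new helper (not part of the diff; the engine arguments are illustrative and assume the TT backend is already set up as the rest of this example expects):

    register_tt_models()  # registers the TT model classes defined above
    llm = LLM(model="Qwen/Qwen2.5-VL-7B-Instruct", max_model_len=4096)
    sampling_params = SamplingParams(temperature=0, max_tokens=128)
    inputs = get_sample_multi_modal_qwen_inputs("Qwen/Qwen2.5-VL-7B-Instruct")
    outputs = llm.generate(inputs, sampling_params)
    for output in outputs:
        print(output.outputs[0].text)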

def check_tt_model_supported(model):
supported_models = [
@@ -118,6 +171,7 @@ def check_tt_model_supported(model):
"Qwen/Qwen3-8B",
"Qwen/Qwen3-14B",
"Qwen/Qwen3-32B",
"Qwen/Qwen2.5-VL-7B-Instruct",
"deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
"deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
"mistralai/Mistral-7B-Instruct-v0.3",
@@ -177,8 +231,8 @@ def run_inference(
check_tt_model_supported(model)

if multi_modal:
assert "Llama-3.2" in model, "The multi-modal inference test " + \
"currently only supports Llama-3.2 models"
assert "Llama-3.2" in model or "Qwen2.5-VL" in model, "The multi-modal inference test " + \
"currently only supports Llama-3.2 and Qwen2.5 models"

# LLM args
engine_kw_args = {
@@ -234,7 +288,10 @@ def run_inference(
list), "Prompts must be a list of strings"
else:
print("Ignoring prompts json for multi-modal inference")
prompts = get_sample_multi_modal_llama_inputs()
if "Qwen2.5-VL" in model:
prompts = get_sample_multi_modal_qwen_inputs(model)
else:
prompts = get_sample_multi_modal_llama_inputs()
if num_repeat_prompts is not None:
prompts = prompts * num_repeat_prompts
print("Number of prompts:", len(prompts))
@@ -252,8 +309,15 @@ def run_inference(
"prompt_token_ids": prompt_token_ids_user
} for _ in range(max_seqs_in_batch)]
else:
MLLAMA_IMAGE_TOKEN_ID = 128256 # Specific to multi-modal llama
prompt_token_ids_user.insert(0, MLLAMA_IMAGE_TOKEN_ID)
if "Llama-3.2" in model:
IMAGE_TOKEN_ID = 128256 # Specific to multi-modal llama
elif "Qwen2.5-VL" in model:
IMAGE_TOKEN_ID = 151655 # Specific to multi-modal qwen
else:
raise ValueError(
f"Unsupported model for multi-modal inference test in perf "
f"mode: {model}")
prompt_token_ids_user.insert(0, IMAGE_TOKEN_ID)
random_pixels = np.random.randint(0,
256, (512, 512, 3),
dtype=np.uint8)
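
The hard-coded image placeholder ids above can be sanity-checked against the tokenizers; a small stand-alone sketch (model names are illustrative, not taken from this diff):

    from transformers import AutoTokenizer

    # Llama 3.2 vision prompts start with a dedicated "<|image|>" token;
    # Qwen2.5-VL uses "<|image_pad|>" as its image placeholder.
    llama_tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-11B-Vision-Instruct")
    qwen_tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
    print(llama_tok.convert_tokens_to_ids("<|image|>"))     # expected 128256
    print(qwen_tok.convert_tokens_to_ids("<|image_pad|>"))  # expected 151655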