diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index af97de193df7..f8ee9815d982 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
 from PIL import Image as PIL_Image
 from pkg_resources import resource_filename
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoProcessor, AutoTokenizer
 
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -48,8 +48,10 @@ def register_tt_models():
 
     # Qwen2.5 - Text
     path_qwen_text = "models.tt_transformers.tt.generator_vllm:QwenForCausalLM"
+    path_qwen_vision = "models.tt_transformers.tt.generator_vllm:Qwen2_5_VLForConditionalGeneration"
     ModelRegistry.register_model("TTQwen2ForCausalLM", path_qwen_text)
     ModelRegistry.register_model("TTQwen3ForCausalLM", path_qwen_text)
+    ModelRegistry.register_model("TTQwen2_5_VLForConditionalGeneration", path_qwen_vision)
 
     # Mistral
     ModelRegistry.register_model(
@@ -88,6 +90,57 @@ def get_sample_multi_modal_llama_inputs():
             inputs.append({"prompt": question})
     return inputs
 
 
+def get_sample_multi_modal_qwen_inputs(model):
+    # Prepare a sample multi-modal prompt for Qwen2.5-VL
+    text_prompts = []
+    imgs = []
+    questions = ["Describe this image."]
+    img_refs = [
+        "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
+    ]
+    prompts = [[{
+        "role":
+        "user",
+        "content": [{
+            "type": "image",
+            "image": img_ref,
+            "resized_height": 224,
+            "resized_width": 224,
+        }, {
+            "type": "text",
+            "text": question
+        }]
+    }] for img_ref, question in zip(img_refs, questions)]
+    tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)
+    for prompt in prompts:
+        chat_prompt = tokenizer.apply_chat_template(prompt,
+                                                    tokenize=False,
+                                                    add_generation_prompt=True)
+        if any(ctnt["type"] == "image" for entry in prompt
+               for ctnt in entry['content']):
+            # Import here so other models don't require qwen_vl_utils
+            from qwen_vl_utils import process_vision_info
+            image_inputs, video_inputs = process_vision_info(prompt)
+            assert video_inputs is None, "Video inputs not supported yet"
+            assert len(
+                image_inputs) == 1, "Multi-image inputs not supported yet"
+            imgs.append(image_inputs[0])
+        else:
+            imgs.append(None)
+        text_prompts.append(chat_prompt)
+
+    inputs = []
+    for img, text_prompt in zip(imgs, text_prompts):
+        if img is not None:
+            inputs.append({
+                "prompt": text_prompt,
+                "multi_modal_data": {
+                    "image": img
+                }
+            })
+        else:
+            inputs.append({"prompt": text_prompt})
+    return inputs
 def check_tt_model_supported(model):
     supported_models = [
@@ -118,6 +171,7 @@ def check_tt_model_supported(model):
         "Qwen/Qwen3-8B",
         "Qwen/Qwen3-14B",
         "Qwen/Qwen3-32B",
+        "Qwen/Qwen2.5-VL-7B-Instruct",
         "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
         "mistralai/Mistral-7B-Instruct-v0.3",
@@ -177,8 +231,8 @@ def run_inference(
     check_tt_model_supported(model)
 
     if multi_modal:
-        assert "Llama-3.2" in model, "The multi-modal inference test " + \
-            "currently only supports Llama-3.2 models"
+        assert "Llama-3.2" in model or "Qwen2.5-VL" in model, "The multi-modal inference test " + \
+            "currently only supports Llama-3.2 and Qwen2.5-VL models"
 
     # LLM args
     engine_kw_args = {
@@ -234,7 +288,10 @@ def run_inference(
                              list), "Prompts must be a list of strings"
         else:
            print("Ignoring prompts json for multi-modal inference")
-            prompts = get_sample_multi_modal_llama_inputs()
+            if "Qwen2.5-VL" in model:
+                prompts = get_sample_multi_modal_qwen_inputs(model)
+            else:
+                prompts = get_sample_multi_modal_llama_inputs()
        if num_repeat_prompts is not None:
            prompts = prompts * num_repeat_prompts
        print("Number of prompts:", len(prompts))
@@ -252,8 +309,15 @@ def run_inference(
                "prompt_token_ids": prompt_token_ids_user
            } for _ in range(max_seqs_in_batch)]
        else:
-            MLLAMA_IMAGE_TOKEN_ID = 128256  # Specific to multi-modal llama
-            prompt_token_ids_user.insert(0, MLLAMA_IMAGE_TOKEN_ID)
+            if "Llama-3.2" in model:
+                IMAGE_TOKEN_ID = 128256  # Specific to multi-modal llama
+            elif "Qwen2.5-VL" in model:
+                IMAGE_TOKEN_ID = 151655  # Specific to multi-modal qwen
+            else:
+                raise ValueError(
+                    f"Unsupported model for multi-modal inference test in perf "
+                    f"mode: {model}")
+            prompt_token_ids_user.insert(0, IMAGE_TOKEN_ID)
            random_pixels = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
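For reference, the sketch below shows how the input dicts built by get_sample_multi_modal_qwen_inputs() plug into vLLM's offline generate() API. It is illustrative only and not part of the patch: it assumes it runs inside examples/offline_inference_tt.py (so the helper and register_tt_models() are in scope), and the LLM() arguments and sampling settings are placeholders rather than the engine_kw_args the script actually builds.

from vllm import LLM, SamplingParams

register_tt_models()  # register the TT model wrappers with vLLM's ModelRegistry

model = "Qwen/Qwen2.5-VL-7B-Instruct"
# Each entry is {"prompt": <chat-templated text>,
#                "multi_modal_data": {"image": <PIL.Image>}} (or text-only).
inputs = get_sample_multi_modal_qwen_inputs(model)

llm = LLM(model=model, max_model_len=4096)  # placeholder engine args
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

outputs = llm.generate(inputs, sampling_params)
for output in outputs:
    print(output.outputs[0].text)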