From 04c784d0521ebf2484fd8cb590e0e71c4fcf7126 Mon Sep 17 00:00:00 2001
From: mcw
Date: Thu, 14 Aug 2025 18:56:06 +0530
Subject: [PATCH] Add vLLM multi-modal support for Mistral-Small-3.1-24B

---
 examples/offline_inference_tt.py | 72 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 67 insertions(+), 5 deletions(-)

diff --git a/examples/offline_inference_tt.py b/examples/offline_inference_tt.py
index 2f1d6bd78fe0..5ecb5debb30b 100644
--- a/examples/offline_inference_tt.py
+++ b/examples/offline_inference_tt.py
@@ -10,7 +10,7 @@
 from PIL import Image as PIL_Image
 from pkg_resources import resource_filename
 from tqdm import tqdm
-from transformers import AutoTokenizer
+from transformers import AutoTokenizer, AutoProcessor
 
 from vllm import LLM, ModelRegistry, SamplingParams
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -61,6 +61,9 @@ def register_tt_models():
     ModelRegistry.register_model(
         "TTMistralForCausalLM",
         "models.tt_transformers.tt.generator_vllm:MistralForCausalLM")
+    ModelRegistry.register_model(
+        "TTMistral3ForConditionalGeneration",
+        "models.tt_transformers.tt.generator_vllm:Mistral3ForConditionalGeneration")
 
 register_tt_models()  # Import and register models from tt-metal
 
@@ -94,6 +97,60 @@ def get_sample_multi_modal_llama_inputs():
             inputs.append({"prompt": question})
     return inputs
 
+def get_sample_multi_modal_mistral_inputs(model):
+    # Prepare a sample multi-modal prompt for Mistral-Small-3.1-24B-Instruct-2503
+    import requests
+    from io import BytesIO
+
+    img_refs = [
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
+    ]
+    questions = ["Describe this image."]
+    assert len(img_refs) == len(questions), (
+        "Number of image references must match number of questions")
+
+    processor = AutoProcessor.from_pretrained(model, trust_remote_code=True)
+
+    imgs = []
+    text_prompts = []
+    for img_ref, question in zip(img_refs, questions):
+        content_list = []
+
+        if img_ref:  # image present
+            content_list.append({"type": "image", "image": img_ref})
+
+            # Download the image and convert it to RGB
+            response = requests.get(img_ref)
+            img = PIL_Image.open(BytesIO(response.content)).convert("RGB")
+            imgs.append(img)
+        else:  # no image for this prompt
+            imgs.append(None)
+
+        # Always add the text question
+        content_list.append({"type": "text", "text": question})
+
+        # Render the chat-formatted prompt with the processor's chat template
+        prompt = [{"role": "user", "content": content_list}]
+        chat_prompt = processor.apply_chat_template(
+            prompt,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        text_prompts.append(chat_prompt)
+
+    inputs = []
+    for img, text_prompt in zip(imgs, text_prompts):
+        if img is not None:
+            inputs.append({
+                "prompt": text_prompt,
+                "multi_modal_data": {
+                    "image": img
+                }
+            })
+        else:
+            inputs.append({"prompt": text_prompt})
+
+    return inputs
 
 def get_sample_multi_modal_qwen_inputs(model):
     # Prepare a sample multi-modal prompt for Qwen2.5-VL
@@ -183,6 +240,7 @@ def check_tt_model_supported(model):
         "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
         "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
         "mistralai/Mistral-7B-Instruct-v0.3",
+        "mistralai/Mistral-Small-3.1-24B-Instruct-2503",
     ]
     assert model in supported_models, f"Invalid model: {model}"
 
@@ -239,9 +297,10 @@ def run_inference(
     check_tt_model_supported(model)
 
     if multi_modal:
-        assert "Llama-3.2" in model or "Qwen2.5-VL" in model, (
-            "The multi-modal inference test "
-            "currently only supports Llama-3.2 and Qwen2.5-VL models")
+        assert any(name in model for name in
+                   ["Llama-3.2", "Qwen2.5-VL", "Mistral-Small-3.1"]), (
+            "The multi-modal inference test currently only supports "
+            "Llama-3.2, Qwen2.5-VL, and Mistral-Small-3.1 models")
 
     # LLM args
     engine_kw_args = {
@@ -296,8 +355,10 @@ def run_inference(
         assert isinstance(prompts, list), "Prompts must be a list of strings"
 
     else:
         print("Ignoring prompts json for multi-modal inference")
-        if "Llama-3.2" in model:
+        if "Mistral-Small-3.1" in model:
+            prompts = get_sample_multi_modal_mistral_inputs(model)
+        elif "Llama-3.2" in model:
             prompts = get_sample_multi_modal_llama_inputs()
         elif "Qwen2.5-VL" in model:
             prompts = get_sample_multi_modal_qwen_inputs(model)
@@ -305,6 +366,7 @@ def run_inference(
             raise ValueError(
                 f"Unsupported model for multi-modal inference test: {model}"
             )
+
     if num_repeat_prompts is not None:
         prompts = prompts * num_repeat_prompts
         print("Number of prompts:", len(prompts))