From 263655fad5ce96210f05b4b44ddfa908d4ba7ddd Mon Sep 17 00:00:00 2001 From: aditjadh Date: Tue, 26 Aug 2025 18:23:18 -0700 Subject: [PATCH] Added Multiframe Inference for llama4+internvl --- .../transformers/models/modeling_auto.py | 215 +++++++++--- examples/intern_example/internvl_inference.py | 120 +++++-- .../llama4_multi_image_example.py | 27 +- .../llama4_multi_image_inference.py | 330 ++++++++++++++++++ examples/llama4_example/run_llama4.sh | 25 ++ 5 files changed, 620 insertions(+), 97 deletions(-) rename examples/{ => llama4_example}/llama4_multi_image_example.py (69%) create mode 100644 examples/llama4_example/llama4_multi_image_inference.py create mode 100644 examples/llama4_example/run_llama4.sh diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b3d27f3a5..e91f0b2cc 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -4,7 +4,9 @@ # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- - +import os +import json +import hashlib import warnings from pathlib import Path from time import perf_counter @@ -55,6 +57,7 @@ constants, get_padding_shape_from_config, ) +from QEfficient.utils.cache import to_hashable from QEfficient.utils.logging_utils import logger @@ -73,7 +76,7 @@ def __init__(self, model: nn.Module, **kwargs) -> None: ): raise AssertionError("Please use `from_pretrained` method to load quantized models") - super().__init__(model, **kwargs) + super().__init__(model) def __repr__(self) -> str: return self.__class__.__name__ + "\n" + self.model.__repr__() @@ -162,7 +165,7 @@ class QEFFAutoModel(QEFFTransformersBase): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] def __init__(self, model: nn.Module, pooling=None, **kwargs): - super().__init__(model, **kwargs) + super().__init__(model) # Make Embedding specific transforms like appending pooling if pooling: @@ -170,7 +173,7 @@ def __init__(self, model: nn.Module, pooling=None, **kwargs): self.model.base_model.config.use_cache = True - self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) @classmethod @with_replaced_quantizers @@ -223,11 +226,29 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k kv_offload = kwargs.pop("kv_offload", None) if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, **kwargs + model, kv_offload=kv_offload ) return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, pooling=pooling, **kwargs) + @property + def model_hash(self) -> str: + # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path. + # Using same card name will result in same hash. But, using a relative path for one run and + # absolute path for another run will result in different hash. + # The added complexity to resolve different paths to same location is not worth pursuing. + # Instead, advise the user to always provide same relative paths or absolute paths for local models. 
+ + # Compute the hash with: model_config, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + + mhash.update(to_hashable(self.pretrained_model_name_or_path)) + + mhash = mhash.hexdigest()[:16] + return mhash + @property def get_model_config(self) -> dict: return self.model.config.__dict__ @@ -305,8 +326,8 @@ def compile( ] return self._compile( - onnx_path=onnx_path, - compile_dir=compile_dir, + onnx_path, + compile_dir, compile_only=True, specializations=specializations, convert_to_fp16=True, @@ -428,15 +449,12 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model: nn.modules, **kwargs): - super().__init__(model, **kwargs) + def __init__(self, model: nn.modules): + super().__init__(model) self.model = model.get_qeff_vision_encoder() - self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): - return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights - ) + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) def compile( self, @@ -462,6 +480,20 @@ def compile( **compiler_options, ) + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable({"QEffVisionEncoderForTextImageToTextModel": True})) + if hasattr(self.model, "model"): + mhash.update(to_hashable(self.model.model.pretrained_model_name_or_path)) + else: + mhash.update(to_hashable(self.model.pretrained_model_name_or_path)) + mhash = mhash.hexdigest()[:16] + return mhash + @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -485,15 +517,12 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): ] _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def __init__(self, model, **kwargs): - super().__init__(model, **kwargs) + def __init__(self, model): + super().__init__(model) self.model = model.get_qeff_language_decoder() - self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True): - return self._export( - inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights - ) + def export(self, inputs, output_names, dynamic_axes, export_dir=None): + return self._export(inputs, output_names, dynamic_axes, export_dir) def compile( self, @@ -519,6 +548,20 @@ def compile( **compiler_options, ) + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable({"QEffCausalLMForTextImageToTextModel": True})) + if hasattr(self.model, "model"): + mhash.update(to_hashable(self.model.model.pretrained_model_name_or_path)) + else: + mhash.update(to_hashable(self.model.pretrained_model_name_or_path)) + mhash = mhash.hexdigest()[:16] + return mhash + @property def model_name(self) -> str: 
mname = self.model.__class__.__name__ @@ -543,8 +586,9 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") self.model = model self.config = model.config - self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.model.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model) + self.lang_model = QEffCausalLMForTextImageToTextModel(model) self.input_shapes, self.output_names = None, None @property @@ -587,18 +631,14 @@ def export( inputs = self.model.get_dummy_inputs(kv_offload=True) dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) - self.vision_model.export( inputs["vision"], output_names["vision"], dynamic_axes["vision"], - export_dir=export_dir, - offload_pt_weights=False, - ) - self.lang_model.export( - inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir=export_dir, offload_pt_weights=True + export_dir, ) + self.lang_model.export(inputs["lang"], output_names["lang"], dynamic_axes["lang"], export_dir) return self.onnx_path def compile( @@ -663,7 +703,7 @@ def compile( if not skip_vision: self.vision_model._compile( - compile_dir=compile_dir, + compile_dir, compile_only=True, specializations=specializations["vision"], convert_to_fp16=True, @@ -690,7 +730,7 @@ def compile( custom_io_lang[output_name] = "float16" if "vision_embeds" in output_name else kv_cache_dtype self.lang_model._compile( - compile_dir=compile_dir, + compile_dir, compile_only=True, retained_state=True, specializations=specializations["lang"], @@ -708,7 +748,8 @@ def generate( self, inputs: torch.Tensor, streamer: Optional[TextStreamer] = None, - device_ids: List[int] = None, + device_id_lang: List[int] = None, + device_id_vision: List[int] = None, runtime_ai100: bool = True, generation_len: Optional[int] = None, ) -> Union[torch.Tensor, np.ndarray]: @@ -726,24 +767,24 @@ def generate( raise NotImplementedError("PyTorch execution is not supported yet for this model!") return self.kv_offload_generate( - inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len + inputs=inputs, device_id_lang=device_id_lang, device_id_vision=device_id_vision, streamer=streamer, generation_len=generation_len ) def kv_offload_generate( self, inputs: List[str] = None, streamer: Optional[TextStreamer] = None, - device_ids: List[int] = None, + device_id_lang: List[int] = None, + device_id_vision: List[int] = None, generation_len: int = None, ): if not self.lang_model.qpc_path: raise TypeError("Please run compile API for language model first!") - lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_ids, activate=False) + lang_session = QAICInferenceSession(self.lang_model.qpc_path, device_id_lang) if self.vision_model.qpc_path: - vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_ids) - + vision_session = QAICInferenceSession(self.vision_model.qpc_path, device_id_vision) batch_size, ctx_len, fbs = get_compilation_dims(self.lang_model.qpc_path) pad_token_id = 1 @@ -782,7 +823,7 @@ def kv_offload_generate( inputs["input_ids"], (0, padded_len - input_ids_length), "constant", - pad_token_id, + 1, ) inputs["attention_mask"] = torch.nn.functional.pad( inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 @@ -799,6 
+840,25 @@ def kv_offload_generate( k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} } + llama4 = hasattr(self.model.config, "model_type") and self.model.config.model_type == "llama4" + #How do I get specializations from text qpc + if llama4: + qpc_base_path = os.path.dirname(os.path.normpath(self.lang_model.qpc_path)) + specialization_file_path = os.path.join(qpc_base_path, "specializations.json") + logger.info(f"specialization_file_path : {specialization_file_path}") + + if os.path.exists(specialization_file_path): + with open(specialization_file_path, "r") as file: + data = json.load(file) + else: + raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}") + num_patches = int(data["specializations"][0]["max_num_tiles"]) + if vision_inputs['pixel_values'].shape[0] != num_patches: + single_patch = np.expand_dims(vision_inputs['pixel_values'][0], axis=0) + while vision_inputs['pixel_values'].shape[0] < num_patches: + vision_inputs['pixel_values'] = np.concatenate((vision_inputs['pixel_values'], single_patch), axis=0) + + if vision_inputs: vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") vision_start = perf_counter() @@ -817,6 +877,11 @@ def kv_offload_generate( if not_mllama: lang_inputs["image_idx"] = np.array([[0]]) + internvl = hasattr(self.model.config, "model_type") and self.model.config.model_type == "internvl_chat" + if internvl: + vision_shape = vision_outputs["vision_embeds"].shape + vision_outputs["vision_embeds"] = vision_outputs["vision_embeds"].reshape(1,vision_shape[0] * vision_shape[1],vision_shape[2]) + if self.vision_model.qpc_path: vision_session.deactivate() lang_session.activate() @@ -827,6 +892,10 @@ def kv_offload_generate( chunk_inputs = lang_inputs.copy() prefill_start = perf_counter() + # Prepare inputs for prefill + chunk_inputs = lang_inputs.copy() + prefill_start = perf_counter() + # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): @@ -907,7 +976,7 @@ def __init__( ): if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") - super().__init__(model, **kwargs) + super().__init__(model) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -916,7 +985,7 @@ def __init__( self.model.config.vision_config.use_flash_attn = "false" else: self.model.config.text_config.use_cache = True - self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) @classmethod def from_pretrained( @@ -1002,8 +1071,8 @@ def compile( custom_io[output_name] = "float16" if "pixel_values" in output_name else kv_cache_dtype self._compile( - onnx_path=onnx_path, - compile_dir=compile_dir, + onnx_path, + compile_dir, compile_only=True, retained_state=True, specializations=specializations, @@ -1096,7 +1165,7 @@ def cloud_ai_100_generate( inputs["input_ids"], (0, padded_len - input_ids_length), "constant", - pad_token_id, + 1, ) inputs["attention_mask"] = torch.nn.functional.pad( inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 @@ -1169,6 +1238,16 @@ def cloud_ai_100_generate( ), ) + @property + def model_hash(self) -> str: + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + 
mhash.update(to_hashable({"QEFFAutoModelForImageTextToText1QPC": True})) + mhash.update(to_hashable(self.pretrained_model_name_or_path)) + mhash = mhash.hexdigest()[:16] + return mhash + @property def model_name(self) -> str: mname = self.model.__class__.__name__ @@ -1355,16 +1434,17 @@ def __init__( logger.warning( "Please use `from_pretrained` method to load quantized models, might give unexpected results" ) + + super().__init__(model) # Set use_cache=True to get KV values as output during ONNX export - model.config.use_cache = True - super().__init__(model, **kwargs) + self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching self.model.qaic_config = qaic_config + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed - - self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -1453,7 +1533,7 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, kv_offload=kv_offload ) return cls( model, @@ -1463,6 +1543,19 @@ def from_pretrained( **kwargs, ) + @property + def model_hash(self) -> str: + # Compute the hash with: model_config, continuous_batching, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable({"continuous_batching": self.continuous_batching})) + mhash.update(to_hashable({"is_tlm": self.is_tlm})) + mhash.update(to_hashable({"qaic_config": self.model.qaic_config})) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable(self.pretrained_model_name_or_path)) + mhash = mhash.hexdigest()[:16] + return mhash + @property def get_model_config(self) -> dict: return self.model.config.__dict__ @@ -1903,10 +1996,26 @@ def __init__(self, model: nn.Module, **kwargs): if not (model_class_name.endswith("ForConditionalGeneration")): raise TypeError(f"Required pytorch module with ForConditionalGeneration, got {model_class_name}") - model.config.use_cache = True - super().__init__(model, **kwargs) + super().__init__(model) + self.model.config.use_cache = True self.num_layers = model.config.num_hidden_layers - self.hash_params["qeff_auto_class"] = self.__class__.__name__ + self.pretrained_model_name_or_path = kwargs.get("pretrained_model_name_or_path", None) + + @property + def model_hash(self) -> str: + # NOTE: model_config.to_diff_dict() has "_name_or_path" attribute which is the model card name or path. + # Using same card name will result in same hash. But, using a relative path for one run and + # absolute path for another run will result in different hash. + # The added complexity to resolve different paths to same location is not worth pursuing. + # Instead, advise the user to always provide same relative paths or absolute paths for local models. 
+ + # Compute the hash with: model_config, transforms + mhash = hashlib.sha256() + mhash.update(to_hashable(self.model.config.to_diff_dict())) + mhash.update(to_hashable(self._transform_names())) + mhash.update(to_hashable(self.pretrained_model_name_or_path)) + mhash = mhash.hexdigest()[:16] + return mhash @property def get_model_config(self) -> dict: @@ -2002,8 +2111,8 @@ def compile( custom_io[output_name] = kv_cache_dtype return self._compile( - onnx_path=onnx_path, - compile_dir=compile_dir, + onnx_path, + compile_dir, compile_only=True, retained_state=True, specializations=specializations, diff --git a/examples/intern_example/internvl_inference.py b/examples/intern_example/internvl_inference.py index eba8c10d5..1e3030742 100644 --- a/examples/intern_example/internvl_inference.py +++ b/examples/intern_example/internvl_inference.py @@ -16,6 +16,10 @@ from torchvision.transforms.functional import InterpolationMode from transformers import AutoTokenizer, TextStreamer +import decord +import numpy as np +from decord import VideoReader, cpu + from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils.logging_utils import logger @@ -130,6 +134,34 @@ def load_image(self, image, input_size=448, max_num=12): pixel_values = torch.stack(pixel_values) return pixel_values + def get_index(self, bound, fps, max_frame, first_idx=0, num_segments=13): + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_segments + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_segments) + ]) + return frame_indices + + def load_video(self, video_path:str, bound=None, input_size=448, max_num=1, num_segments=13): + vr = VideoReader(video_path, ctx=cpu(0)) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + pixel_values_list, num_patches_list = [], [] + transform = self.build_transform(input_size=input_size) + frame_indices = self.get_index(bound, fps, max_frame, first_idx=0, num_segments=num_segments) + for frame_index in frame_indices: + img = Image.fromarray(vr[frame_index].asnumpy()).convert('RGB') + img = self.dynamic_preprocess(img, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(tile) for tile in img] + pixel_values = torch.stack(pixel_values) + num_patches_list.append(pixel_values.shape[0]) + pixel_values_list.append(pixel_values) + pixel_values = torch.cat(pixel_values_list) + return pixel_values, num_patches_list + def __call__( self, pixel_values, @@ -167,13 +199,15 @@ def __call__( def run_intern_on_aic( model_name, prompt, - image_url, messages, roles, kv_offload=False, prefill_seq_len=3840, num_devices=1, num_cores=16, + multi_frame_inference=False, + image_url=None, + video_path=None, ): ## STEP 1 -- LOAD THE MODEL @@ -188,7 +222,10 @@ def run_intern_on_aic( num_cores=num_cores, num_devices=num_devices, prefill_seq_len=prefill_seq_len, - mxfp6_matmul=False, + mxfp6_matmul=True, + mxint8_kv_cache=True, + allow_mxint8_mdp_io=True, + aic_enable_depth_first=True, ) ## STEP 3 -- SETUP THE PROCESSOR @@ -198,16 +235,20 @@ def run_intern_on_aic( internProcessor = InternProcessor(model.model, tokenizer) ## STEP 4 -- PREPROCESS THE INPUTS + if multi_frame_inference: + pixel_values, num_patches_list = internProcessor.load_video(video_path) + video_prefix = ''.join([f'Frame{i+1}: \n' for i in range(len(num_patches_list))]) + question = video_prefix + prompt + else: + response = 
requests.get(image_url, stream=True)
+        img = Image.open(BytesIO(response.content)).convert("RGB")
+        # img = Image.open(image_url).convert("RGB")
+        # Images are resized to (1000, 747) for inference
+        image = img.resize((1000, 747))
+        # preprocess the resized image
+        pixel_values = internProcessor.load_image(image, max_num=12)
+        question = "\n" + prompt
-    img = requests.get(image_url, stream=True)
-    image = Image.open(BytesIO(img.content)).convert("RGB")
-
-    # Images are resized to (1000, 747) for inference
-    image = image.resize((1000, 747))
-
-    # preprocess the resized image
-    pixel_values = internProcessor.load_image(image, max_num=12)
-    question = "\n" + prompt
     query = internProcessor(pixel_values, question, messages, roles)
     inputs = tokenizer(
         query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right"
     )
@@ -217,27 +258,28 @@
     ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION
     streamer = TextStreamer(tokenizer)
-    model.generate(inputs=inputs, streamer=streamer, generation_len=128)
+    if kv_offload:
+        outputs = model.generate(inputs=inputs, streamer=streamer, device_id_lang=[16, 17, 18, 19], device_id_vision=[20, 21, 22, 23], generation_len=128)
+    else:
+        outputs = model.generate(inputs=inputs, streamer=streamer, device_ids=[24, 25, 26, 27], generation_len=128)
+    print(outputs)

 if __name__ == "__main__":
-    model_name = "OpenGVLab/InternVL2_5-1B"
+    model_name = "OpenGVLab/InternVL3-8B"

     # Chat Template information for prompt preprocessing
     messages: List[List[str]] = []
     roles = ("<|im_start|>user\n", "<|im_start|>assistant\n")

-    # Inputs for the model
-    prompt = "Please describe the image in detail."
-    image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg"

     ## Compilation parameters
     # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs.
     # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
     # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
-    kv_offload = False
+    kv_offload = True
+    multi_frame_inference = True

     # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with
     # Image embeddings and generate final input_embeds for output generation.
Hence we need very large prefill_seq_len (3840 in this case) to @@ -247,17 +289,37 @@ def run_intern_on_aic( num_devices = 4 num_cores = 16 - run_intern_on_aic( - model_name=model_name, - prompt=prompt, - image_url=image_url, - messages=messages, - roles=roles, - kv_offload=kv_offload, - prefill_seq_len=prefill_seq_len, - num_devices=num_devices, - num_cores=num_cores, - ) + # Inputs for the model + if multi_frame_inference: + video_path = "/local/mnt/workspace/aditjadh/aisyssol/red-panda.mp4" + prompt="What is happening in this video" + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + multi_frame_inference=multi_frame_inference, + video_path=video_path, + ) + else: + image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg" + prompt="Describe the image" + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + image_url=image_url, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + multi_frame_inference=multi_frame_inference, + ) """ diff --git a/examples/llama4_multi_image_example.py b/examples/llama4_example/llama4_multi_image_example.py similarity index 69% rename from examples/llama4_multi_image_example.py rename to examples/llama4_example/llama4_multi_image_example.py index 868f2b4c8..a8c5ee25b 100644 --- a/examples/llama4_multi_image_example.py +++ b/examples/llama4_example/llama4_multi_image_example.py @@ -7,30 +7,29 @@ import torch import transformers -from transformers import AutoConfig, AutoProcessor, TextStreamer +from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText -model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +model_id = "/local/mnt/workspace/aditjadh/aisyssol/models--meta-llama--Llama-4-Scout-17B-16E-Instruct/snapshots/7dab2f5f854fe665b6b2f1eccbd3c48e5f627ad8" config = AutoConfig.from_pretrained(model_id) -# For Testing Purpose Only -config.text_config.num_hidden_layers = 4 -config.vision_config.num_hidden_layers = 2 -qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, attn_implementation="eager", kv_offload=True, config=config -) -tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config) +model.eval() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_id) +### For running the model in single QPC approach use kv_offload=False. 
For Dual QPC approach use kv_offload=True ### +qeff_model = QEFFAutoModelForImageTextToText(model, kv_offload=True) + ### For multi-image, the value of max_num_tiles should be the sum of the num_tiles values across all the images ### qeff_model.compile( prefill_seq_len=128, - ctx_len=5376, + ctx_len=8192, img_size=336, num_cores=16, - num_devices=8, - max_num_tiles=34, + num_devices=4, + max_num_tiles=45, mxfp6_matmul=True, mxint8_kv_cache=True, aic_enable_depth_first=True, @@ -69,7 +68,5 @@ inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) streamer = TextStreamer(tokenizer) -output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3, 4, 5, 6, 7], generation_len=100) -print(output.generated_ids) -print(tokenizer.batch_decode(output.generated_ids)) +output = qeff_model.generate(inputs=inputs, device_id_vision=[32,33,34,35], device_id_lang=[36,37,38,39], generation_len=100) print(output) diff --git a/examples/llama4_example/llama4_multi_image_inference.py b/examples/llama4_example/llama4_multi_image_inference.py new file mode 100644 index 000000000..922923f45 --- /dev/null +++ b/examples/llama4_example/llama4_multi_image_inference.py @@ -0,0 +1,330 @@ +from io import BytesIO +from typing import List + +from time import perf_counter +import transformers +import numpy as np +import torch +import torch.nn as nn +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode +from typing import Dict, List, Optional, Tuple, Union +from transformers import ( + AutoConfig, + AutoProcessor, + AutoModelForImageTextToText, + TextStreamer, +) + +import os +import json +import torch +import decord +from decord import VideoReader, cpu +from QEfficient import QEFFAutoModelForImageTextToText +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.text_generation_inference import ( + get_compilation_dims, +) +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfoNew, + PerfMetrics, + calculate_latency, + get_compilation_dims, +) + +from PIL import Image +import requests + + +def get_index(fps, max_frame, first_idx=0,num_frames=8): + start, end = -100000, 100000 + start_idx = max(first_idx, round(start * fps)) + end_idx = min(round(end * fps), max_frame) + seg_size = float(end_idx - start_idx) / num_frames + frame_indices = np.array([ + int(start_idx + (seg_size / 2) + np.round(seg_size * idx)) + for idx in range(num_frames) + ]) + return frame_indices + +def load_video(video_path:str, output_dir:str, input_size=336, num_frames=13): + vr = VideoReader(video_path, ctx=cpu(0)) + max_frame = len(vr) - 1 + fps = float(vr.get_avg_fps()) + # transform = build_transform(input_size=input_size) + frame_indices = get_index(fps, max_frame, first_idx=0, num_frames=num_frames) + for i, frame_index in enumerate(frame_indices): + frame = vr[frame_index].asnumpy() + if frame.ndim == 4: + frame = frame[0] + image = Image.fromarray(frame) + # image = transform(image) + image = image.resize((672,672)) + path = f"{output_dir}/{i}.jpg" + image.save(path) + +def main( + model: str, + hf_token: str, + qpc_vision: str, + qpc_text: str, + prompt: str, + output_dir: str, + video_path: str, + device_id_vision: List[int] = [0,1], + device_id_text: List[int] = [0,1], + generation_len: Optional[int] = None, +): + config = AutoConfig.from_pretrained(model, token=hf_token) + tokenizer = transformers.AutoTokenizer.from_pretrained(model, token=hf_token, trust_remote_code=True) + processor = 
AutoProcessor.from_pretrained(model, token=hf_token) + streamer = TextStreamer(tokenizer) + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + load_video(video_path, output_dir, input_size=672, num_frames=8) + messages=[] + os.makedirs(output_dir, exist_ok=True) + + # Build the content list + content = [] + for filename in sorted(os.listdir(output_dir)): + if filename.lower().endswith((".png", ".jpg", ".jpeg", ".gif", ".bmp")): + image_path = os.path.join(output_dir, filename) + content.append({"type": "image", "url": image_path}) + + # Add the analysis prompt + content.append({ + "type": "text", + "text": ( + "You are a video analysis expert. Given a continuous set of frames from the video, your task is to generate a concise and informative summary. Focus on identifying key events, important dialogues, visual highlights, and emotional tone. Structure the summary to reflect the overall narrative or progression of the video. If the video contains multiple scenes or segments, break the summary into logical parts. Ensure the summary is clear, coherent, and suitable for someone who hasn’t watched the video." + ) + }) + + # Final messages structure + messages = [ + { + "role": "user", + "content": content, + } + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + if not qpc_text: + raise TypeError("Please run compile API for language model first!") + + lang_session = QAICInferenceSession(qpc_text, device_id_text) + + if qpc_vision: + vision_session = QAICInferenceSession(qpc_vision, device_id_vision) + batch_size, ctx_len, fbs = get_compilation_dims(qpc_text) + + pad_token_id = 1 + + # Skip inputs/outputs + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Read prompt and ctx len from session + batch_size = max( + [x[lang_session.binding_index_map["input_ids"]][1][0] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[0]] + ) + + prefill_seq_len = max( + [x[lang_session.binding_index_map["input_ids"]][1][1] for x in lang_session.allowed_shapes] + + [lang_session.bindings[lang_session.binding_index_map["input_ids"]].dims[1]] + ) + + input_len = inputs["attention_mask"].sum(1, keepdims=True) + input_ids_length = inputs["input_ids"].shape[1] + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + if generation_len is None: + generation_len = ctx_len - input_len.max() + assert generation_len > 0, "generation length should be greater than zero" + generated_ids = np.full((batch_size, generation_len + 1), pad_token_id) + + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (0, padded_len - input_ids_length), + "constant", + 1, + ) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (0, padded_len - input_ids_length), "constant", 0 + ) + if "cross_attention_mask" in inputs: + inputs["cross_attention_mask"] = torch.nn.functional.pad( + inputs["cross_attention_mask"], (0, 0, 0, 0, 0, padded_len - input_ids_length) + ) + + for k, v in inputs.items(): + inputs[k] = np.array(v) + + vision_inputs = { + k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", 
"aspect_ratio_mask"} + } + + llama4 = hasattr(config, "model_type") and config.model_type == "llama4" + if llama4: + qpc_base_path = os.path.dirname(os.path.normpath(qpc_text)) + specialization_file_path = os.path.join(qpc_base_path, "specializations.json") + if os.path.exists(specialization_file_path): + with open(specialization_file_path, "r") as file: + data = json.load(file) + else: + raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}") + num_patches = int(data["specializations"][0]["max_num_tiles"]) + if vision_inputs['pixel_values'].shape[0] != num_patches: + single_patch = np.expand_dims(vision_inputs['pixel_values'][0], axis=0) + while vision_inputs['pixel_values'].shape[0] < num_patches: + vision_inputs['pixel_values'] = np.concatenate((vision_inputs['pixel_values'], single_patch), axis=0) + + + if vision_inputs: + vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + vision_start = perf_counter() + + vision_outputs = {} + if vision_inputs: + vision_outputs = vision_session.run(vision_inputs) + vision_end = perf_counter() + + lang_inputs = {k: v for k, v in inputs.items() if k not in vision_inputs} + lang_inputs["position_ids"] = np.where( + lang_inputs.pop("attention_mask"), np.arange(padded_len), -1 + ) # Need to use -1 as position_ids for invalid tokens + + if qpc_vision: + vision_session.deactivate() + lang_session.activate() + + lang_session.set_buffers(vision_outputs) + + # Prepare inputs for prefill + chunk_inputs = lang_inputs.copy() + prefill_start = perf_counter() + + # Prepare inputs for prefill + chunk_inputs = lang_inputs.copy() + prefill_start = perf_counter() + + # Run prefill + chunk_inputs = lang_inputs.copy() + for i in range(num_chunks): + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] + chunk_inputs["position_ids"] = lang_inputs["position_ids"][ + :, i * prefill_seq_len : (i + 1) * prefill_seq_len + ] + outputs = lang_session.run(chunk_inputs) + chunk_inputs["image_idx"] = outputs["image_idx_output"] + + prefill_time = perf_counter() - prefill_start + vision_end - vision_start + # Skip inputs/outputs again + lang_session.skip_buffers( + [ + x + for x in lang_session.input_names + lang_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + ) + + # Get first token + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] = input_len.numpy() + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + lang_inputs["cross_attention_mask"] = torch.ones((bs, 1, num_images, img_tiles), dtype=torch.int64).numpy() + generated_ids[:, 0] = lang_inputs["input_ids"].squeeze(1) + + if streamer: + streamer.put(lang_inputs["input_ids"][0]) + + # Decode loop + decode_start = perf_counter() + for num_token in range(1, generation_len): + outputs = lang_session.run(lang_inputs) + + # Prepare inputs for next iteration + lang_inputs["input_ids"] = outputs["logits"].argmax(2) + lang_inputs["position_ids"] += 1 + generated_ids[:, num_token] = lang_inputs["input_ids"].squeeze(1) + if streamer: + streamer.put(lang_inputs["input_ids"][0]) + + decode_end = perf_counter() + if streamer: + streamer.end() + + decode_perf = (num_token - 1) / (decode_end - decode_start) + total_time = decode_end - decode_start + prefill_time + total_perf = num_token / total_time + + print(CloudAI100ExecInfoNew( + batch_size=batch_size, + generated_ids=generated_ids, + 
perf_metrics=PerfMetrics(
+            prefill_time=prefill_time, decode_perf=decode_perf, total_perf=total_perf, total_time=total_time
+        ),
+    )
+    )
+
+if __name__ == "__main__":
+    import argparse
+
+    argp = argparse.ArgumentParser()
+    argp.add_argument("--model", required=True, help="Model name to run")
+    argp.add_argument("--hf-token", required=True, help="Hugging Face Token")
+    argp.add_argument("--qpc-vision", required=True, help="Compiled QPC of the vision encoder")
+    argp.add_argument("--qpc-text", required=True, help="Compiled QPC of the language (text) model")
+    argp.add_argument(
+        "--prompt",
+        type=lambda prompt: prompt.split("|"),
+        default="Please describe the image in detail.",
+        help="Input prompt(s) to generate for (pipe-separated)",
+    )
+    # argp.add_argument("--image", required=True, help="Image to be passed as input")
+    argp.add_argument(
+        "--device-id-vision",
+        type=lambda device_ids: [int(x) for x in device_ids.split(",")],
+        help="QAIC device ids for the vision QPC (comma-separated)",
+        default=[0, 1, 2, 3],
+    )
+    argp.add_argument(
+        "--device-id-text",
+        type=lambda device_ids: [int(x) for x in device_ids.split(",")],
+        help="QAIC device ids for the text QPC (comma-separated)",
+        default=[0, 1, 2, 3],
+    )
+    argp.add_argument(
+        "--generation-len",
+        type=int,
+        help="Number of tokens to generate. Note: for models without a rolling buffer, "
+        "(generation length + input length) should be less than the model context length",
+    )
+
+    argp.add_argument("--video-path", required=True, help="Path to the input video file")
+    argp.add_argument("--output-dir", required=True, help="Directory where the sampled video frames are saved")
+
+    args = argp.parse_args()
+    main(**vars(args))
diff --git a/examples/llama4_example/run_llama4.sh b/examples/llama4_example/run_llama4.sh
new file mode 100644
index 000000000..7c819b87c
--- /dev/null
+++ b/examples/llama4_example/run_llama4.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+PROMPT="What is happening in the video"
+MODEL_NAME=""
+HF_TOKEN=""
+QPC_VISION=""
+QPC_TEXT=""
+DEVICE_ID_VISION="24,25,26,27"
+DEVICE_ID_TEXT="28,29,30,31"
+GENERATION_LEN=256
+VIDEO_PATH=""
+OUTPUT_DIR=""
+
+# Run the Python script
+python ./llama4_multi_image_inference.py \
+    --model "${MODEL_NAME}" \
+    --hf-token "${HF_TOKEN}" \
+    --qpc-vision "${QPC_VISION}" \
+    --qpc-text "${QPC_TEXT}" \
+    --prompt "${PROMPT}" \
+    --device-id-vision ${DEVICE_ID_VISION} \
+    --device-id-text ${DEVICE_ID_TEXT} \
+    --generation-len ${GENERATION_LEN} \
+    --video-path "${VIDEO_PATH}" \
+    --output-dir "${OUTPUT_DIR}"
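
The llama4 branch added to kv_offload_generate (and repeated in the standalone example) reads specializations.json next to the compiled language QPC and pads the tile dimension of pixel_values up to the compiled max_num_tiles by repeating the first tile. A minimal standalone sketch of that step is below; the helper name pad_vision_tiles is illustrative only and not part of the QEfficient API.

    # Sketch of the tile-padding step, assuming `lang_qpc_path` points at the compiled language QPC.
    import json
    import os

    import numpy as np


    def pad_vision_tiles(pixel_values: np.ndarray, lang_qpc_path: str) -> np.ndarray:
        # specializations.json is written one directory above the language QPC.
        qpc_base_path = os.path.dirname(os.path.normpath(lang_qpc_path))
        spec_file = os.path.join(qpc_base_path, "specializations.json")
        if not os.path.exists(spec_file):
            raise FileNotFoundError(f"expected specializations.json file at path, {qpc_base_path}")
        with open(spec_file, "r") as file:
            num_patches = int(json.load(file)["specializations"][0]["max_num_tiles"])

        # The vision QPC expects exactly `max_num_tiles` tiles; repeat the first tile to fill the gap.
        if pixel_values.shape[0] < num_patches:
            repeats = num_patches - pixel_values.shape[0]
            pixel_values = np.concatenate([pixel_values, np.repeat(pixel_values[:1], repeats, axis=0)], axis=0)
        return pixel_values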
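
The multiframe path samples a fixed number of frames uniformly across the clip with decord: the clip is split into equal segments and one frame is taken near each segment midpoint (get_index/load_video in both examples). A simplified sketch of that sampling, assuming the whole clip is used (no bound) and with illustrative function names:

    import numpy as np
    from decord import VideoReader, cpu


    def sample_frame_indices(num_frames: int, max_frame: int) -> np.ndarray:
        # Split [0, max_frame] into `num_frames` equal segments and take the midpoint of each.
        seg_size = float(max_frame) / num_frames
        return np.array([int(seg_size / 2 + np.round(seg_size * i)) for i in range(num_frames)])


    def sample_frames(video_path: str, num_frames: int = 8):
        # Decode the selected frames as RGB numpy arrays of shape (H, W, 3).
        vr = VideoReader(video_path, ctx=cpu(0))
        indices = sample_frame_indices(num_frames, len(vr) - 1)
        return [vr[int(i)].asnumpy() for i in indices]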
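
The model_hash properties reintroduced throughout this patch all follow the same scheme: sha256 over the hashable config diff, the applied transform names, and the pretrained model card name or path, truncated to 16 hex characters. This is why identical card names map to the same cache entry while a relative and an absolute path to the same local model do not. A distilled sketch of that scheme, using the to_hashable helper the patch imports (compute_model_hash is an illustrative name, not a QEfficient API):

    import hashlib
    from typing import List, Optional

    from QEfficient.utils.cache import to_hashable


    def compute_model_hash(config_diff: dict, transform_names: List[str], name_or_path: Optional[str]) -> str:
        mhash = hashlib.sha256()
        mhash.update(to_hashable(config_diff))      # model config, e.g. config.to_diff_dict()
        mhash.update(to_hashable(transform_names))  # applied pytorch/onnx transforms
        mhash.update(to_hashable(name_or_path))     # model card name or local path
        return mhash.hexdigest()[:16]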