From 1ad479bdd3c40a22632de6e599c950de639b8ae9 Mon Sep 17 00:00:00 2001 From: Mohit Soni Date: Mon, 8 Sep 2025 16:38:18 +0000 Subject: [PATCH] Onboarding Molmo Model Signed-off-by: Mohit Soni --- .../transformers/models/modeling_auto.py | 31 +- .../transformers/models/molmo/__init__.py | 6 + .../models/molmo/modeling_molmo.py | 802 ++++++++++++++++++ .../transformers/models/pytorch_transforms.py | 29 + examples/molmo_example.py | 86 ++ 5 files changed, 949 insertions(+), 5 deletions(-) create mode 100644 QEfficient/transformers/models/molmo/__init__.py create mode 100644 QEfficient/transformers/models/molmo/modeling_molmo.py create mode 100644 examples/molmo_example.py diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 42898381d..3cf136869 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -479,7 +479,9 @@ def model_name(self) -> str: @property def get_model_config(self) -> dict: - return self.model.model.vision_model.config.__dict__ + if hasattr(self.model.model, "vision_model"): + return self.model.model.vision_model.config.__dict__ + return self.model.model.config.__dict__ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel): @@ -536,7 +538,9 @@ def model_name(self) -> str: @property def get_model_config(self) -> dict: - return self.model.language_model.config.__dict__ + if hasattr(self.model, "language_model"): + return self.model.language_model.config.__dict__ + return self.model.config.__dict__ class _QEffAutoModelForImageTextToTextDualQPC: @@ -652,7 +656,11 @@ def compile( custom_io_vision = {} kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" + molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo" + if molmo: + custom_io_vision["image_masks"] = "float16" custom_io_vision["pixel_values"] = "float16" + for output_name in output_names["vision"]: if output_name.startswith("past_"): custom_io_vision[output_name] = kv_cache_dtype @@ -804,11 +812,18 @@ def kv_offload_generate( inputs[k] = np.array(v) vision_inputs = { - k: v for k, v in inputs.items() if k in {"pixel_values", "aspect_ratio_ids", "aspect_ratio_mask"} + k: v + for k, v in inputs.items() + if k + in {"pixel_values", "image_masks", "image_input_idx", "valid_idx", "aspect_ratio_ids", "aspect_ratio_mask"} } + molmo = hasattr(self.model.config, "model_type") and self.model.config.model_type == "molmo" + if vision_inputs: vision_inputs["pixel_values"] = vision_inputs["pixel_values"].astype("float16") + if molmo: + vision_inputs["image_masks"] = vision_inputs["image_masks"].astype("float16") vision_start = perf_counter() vision_outputs = {} @@ -923,7 +938,10 @@ def __init__( self.model.config.llm_config._attn_implementation = "eager" self.model.config.vision_config.use_flash_attn = "false" else: - self.model.config.text_config.use_cache = True + if hasattr(self.model.config, "text_config"): + self.model.config.text_config.use_cache = True + else: + self.model.config.use_cache = True self.hash_params["qeff_auto_class"] = self.__class__.__name__ @classmethod @@ -1292,7 +1310,10 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) -MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = {"InternVLChatModel": QEFFAutoModelForImageTextToText} +MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { + "InternVLChatModel": 
QEFFAutoModelForImageTextToText, + "MolmoForCausalLM": QEFFAutoModelForImageTextToText, +} class QEFFAutoModelForCausalLM(QEFFBaseModel): diff --git a/QEfficient/transformers/models/molmo/__init__.py b/QEfficient/transformers/models/molmo/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/QEfficient/transformers/models/molmo/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py new file mode 100644 index 000000000..6d81f007b --- /dev/null +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -0,0 +1,802 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import math +from typing import Callable, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ModelOutput + +from QEfficient.transformers.cache_utils import QEffDynamicCache +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils._utils import IOInfo, get_padding_shape_from_config + + +def _non_meta_init_device(config) -> torch.device: + if config.init_device is not None and config.init_device != "meta": + return torch.device(config.init_device) + else: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +def eager_attention_forward( + module: nn.Module, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_mask: Optional[torch.Tensor], + dropout_p: float = 0.0, + **kwargs, +): + scale_factor = 1 / math.sqrt(q.size(-1)) + num_kv_heads = k.size(1) + num_q_heads = q.size(1) + + if num_q_heads != num_kv_heads: + assert num_q_heads % num_kv_heads == 0 + repeat_factor = num_q_heads // num_kv_heads + _, _, S, D = k.shape + k = k.unsqueeze(2) + k = k.expand(-1, -1, repeat_factor, -1, -1) + k = k.reshape(1, num_q_heads, S, D) + + v = v.unsqueeze(2) + v = v.expand(-1, -1, repeat_factor, -1, -1) + v = v.reshape(1, num_q_heads, S, D) + + attn_weights = torch.matmul(q, k.transpose(2, 3)) * scale_factor + + if attention_mask is not None: + attn_weights = torch.where(attention_mask, torch.tensor(-10000.0, dtype=torch.float32), attn_weights) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) + attn_output = torch.matmul(attn_weights, v) + + return attn_output, attn_weights + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, 2, hs // 2) + x1, x2 = x.unbind(dim=-2) + return torch.cat((-x2, x1), dim=-1) + + +def rotate_every_two(x: torch.Tensor) -> torch.Tensor: + B, nh, T, hs = x.size() + x = x.view(B, nh, T, hs // 2, 2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return x.view(B, nh, T, hs) + + +def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, config, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key 
tensors. + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + # Apply rotation + if config.rope_impl == "interleave": + q_embed = (q * cos) + (rotate_every_two(q) * sin) + k_embed = (k * cos) + (rotate_every_two(k) * sin) + else: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + # Cast back to original dtype + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +class QEffMolmoRotaryEmbedding(nn.Module): + """ + Copied from Olmo2RotaryEmbedding: https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmo2/modeling_olmo2.py + The only differences are: + - Add static sin/cos computations. 
+ """ + + def __init__(self, config, device=None): + super().__init__() + dim = config.d_model // config.n_heads + self.inv_freq = 1.0 / (config.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim)) + self.original_max_seq_len = config.max_position_embeddings or config.max_sequence_length + self._set_cos_sin_cache( + seq_len=self.original_max_seq_len, device=_non_meta_init_device(config), dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +class QEffMolmoBlock(nn.Module): + def __qeff_init__(self): + self.rotary_emb = QEffMolmoRotaryEmbedding(config=self.config) + + def attention( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + B, T, C = q.size() # batch size, sequence length, d_model + dtype = k.dtype + + # Optionally apply layer norm to keys and queries. + if self.q_norm is not None and self.k_norm is not None: + q = self.q_norm(q).to(dtype=dtype) + k = self.k_norm(k).to(dtype=dtype) + + # Move head forward to be next to the batch dim. 
+ # shape: (B, nh, T, hs) + q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + # shape: (B, n_kv_h, T, hs) + v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2) + + if self.config.use_position_ids and self.config.rope: + # Apply rotary embeddings + kv_seq_len = k.shape[-2] + # Apply rotary embeddings + kv_seq_len = layer_past.get_usable_length(kv_seq_len, self.layer_id) + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) + + if not self.config.use_position_ids and self.config.rope: + kv_seq_len = k.shape[-2] + # Apply rotary embeddings + kv_seq_len = layer_past.get_usable_length(kv_seq_len, self.layer_id) + cos, sin = self.rotary_emb(v, seq_len=kv_seq_len) + q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) + + if layer_past is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + k, v = layer_past.update(k, v, self.layer_id, cache_kwargs) + + attention_interface: Callable = eager_attention_forward + + attn_output, attn_weights = attention_interface( + self, + q, + k, + v, + attention_bias, + dropout_p=0.0 if not self.training else self.config.attention_dropout, + **kwargs, + ) + + # Re-assemble all head outputs side-by-side. + att = attn_output.transpose(1, 2).contiguous().view(B, T, C) + + # Apply output projection. + return self.attn_out(att), layer_past + + +class QEffMolmoSequentialBlock(nn.Module): + def __qeff_init__(self): + self.rotary_emb = QEffMolmoRotaryEmbedding(config=self.config) + + def forward( + self, + x: torch.Tensor, + attention_bias: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + layer_past: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]: + if not self.config.norm_after: + if self._activation_checkpoint_fn is not None: + atten_in = self._activation_checkpoint_fn(self.attn_norm, x) + else: + atten_in = self.attn_norm(x) + else: + atten_in = x + qkv = self.att_proj(atten_in) + + if self.config.clip_qkv is not None: + qkv.clamp_(min=-self.config.clip_qkv, max=self.config.clip_qkv) + + q, k, v = qkv.split(self.fused_dims, dim=-1) + + # Get attention scores. + if self._activation_checkpoint_fn is not None: + att, cache = self._activation_checkpoint_fn( # type: ignore + self.attention, + q, + k, + v, + attention_bias, + position_ids=position_ids, + layer_past=layer_past, + use_cache=use_cache, + ) + else: + att, cache = self.attention( + q, + k, + v, + attention_bias, + position_ids=position_ids, + layer_past=layer_past, + batch_index=batch_index, + use_cache=use_cache, + ) + + if self.config.norm_after: + if self._activation_checkpoint_fn is not None: + att = self._activation_checkpoint_fn(self.attn_norm, att) + else: + att = self.attn_norm(att) + + # Add attention scores. + # shape: (B, T, C) + x = x + self.dropout(att) + + # Add feed-forward projection. 
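+ # MLP path: ff_norm is applied before ff_proj or after ff_out depending on config.norm_after,
+ # then ff_proj -> activation -> ff_out, followed by dropout and a residual add with og_x.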
+ # shape: (batch_size, seq_len, d_model) + og_x = x + + if not self.config.norm_after: + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + + x = self.ff_proj(x) + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.act, x) # type: ignore + else: + x = self.act(x) + x = self.ff_out(x) + + if self.config.norm_after: + if self._activation_checkpoint_fn is not None: + x = self._activation_checkpoint_fn(self.ff_norm, x) # type: ignore + else: + x = self.ff_norm(x) + + x = self.dropout(x) + x = og_x + x + + return x, cache + + +class QEffMolmo(nn.Module): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + input_embeddings: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + attention_bias: Optional[torch.Tensor] = None, + response_mask: Optional[torch.Tensor] = None, + images: Optional[torch.Tensor] = None, + image_masks: Optional[torch.Tensor] = None, + image_input_idx: Optional[torch.Tensor] = None, + subsegment_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + batch_index: Optional[torch.LongTensor] = None, + use_cache: bool = False, + last_logits_only: bool = False, + output_hidden_states: Optional[bool] = None, + append_last_valid_logits: Optional[torch.Tensor] = None, + **kwargs, + ) -> ModelOutput: + """ + :param input_ids: A tensor of shape `(batch_size, seq_len)`. + :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input + embeddings. When provided, it is treated as the output of the input embedding layer. + :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates + which input IDs are masked. A `1` value in the mask means that + the corresponding input ID should *not* be ignored. A `0` means + that the corresponding input ID is masked. + + This has the same meaning as the `attention_mask` in HuggingFace's `transformers` + library. + :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`, + `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used + to introduce causal or other biases. + + If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]` + indicates that the i-th element in the sequence is allowed to attend to the j-th + element in the sequence. + + If the tensor is a float tensor, it will just be added to the attention + scores before the softmax. + + The default is causal, which corresponds to a lower-diagonal byte matrix of ones. + :param response_mask: A tensor of shape `(batch_size, seq_len)` that indicates + the response mask. A `1` value in the mask means that the corresponding token + is a response token. A `0` means that the corresponding token is not + a response token. + :param past_key_values: Pre-computed keys and values for each attention block. + Can be used to speed up sequential decoding. The `input_ids` which have + their past given to this model should not be passed as `input_ids` as they have already been computed. + :param use_cache: If `True`, return key and value tensors for each block. + :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. + This can speed up decoding when you only care about the next token. 
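+ :param position_ids: Explicit position indices of shape `(batch_size, seq_len)`. These are
+ used to build the causal mask against the cached sequence length and to index the
+ static rotary sin/cos tables, so they must be offset by the number of cached tokens.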
+ """ + + output_hidden_states = output_hidden_states if output_hidden_states is not None else False + + if past_key_values: + assert len(past_key_values) == self.config.n_layers + + has_image = images is not None + + assert not (has_image and input_embeddings is not None), "Cannot provide both images and input embeddings." + # assert not (has_image and past_key_values is not None), "Cached key and values should not be used with images." + + batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2] + if past_key_values is None: + past_length = 0 + else: + past_length = past_key_values[0][0].size(-2) + + if use_cache and not isinstance(past_key_values, Cache): + past_key_values = QEffDynamicCache.from_legacy_cache(past_key_values) + + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_seen_tokens) + + if self.config.use_position_ids and attention_mask is None: + attention_mask = input_ids != -1 + + if subsegment_ids is not None: + assert not use_cache, "Subsegment_ids cannot be used with cache." + subsegment_mask = subsegment_ids.unsqueeze(2) <= subsegment_ids.unsqueeze(1) + attention_mask = ( + subsegment_mask.to(attention_mask.dtype) * attention_mask.unsqueeze(2) * attention_mask.unsqueeze(1) + ) + if position_ids is None: + raise ValueError("Positioned ids must be given if using subsegment_ids") + else: + if self.config.use_position_ids and position_ids is None: + position_ids = torch.clamp( + torch.cumsum(attention_mask.to(torch.int32), dim=-1) - 1, + min=0, + ).broadcast_to((batch_size, attention_mask.shape[-1])) + + x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings # type: ignore + + if not self.config.rope: + # Get positional embeddings. + # shape: (1, seq_len) + pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0) + # shape: (1, seq_len, d_model) + pos_emb = self.transformer.wpe(pos) # type: ignore + x = pos_emb + x + + # Add input + positional embeddings and apply dropout. + # shape: (batch_size, seq_len, d_model) + x = self.transformer.emb_drop(x) # type: ignore + + # normalized + if self.config.normalize_input_embeds: + x = x * (self.config.d_model**0.5) + + # decoder layers + all_hidden_states = [] + + # Apply blocks one-by-one. 
+ if self.config.block_group_size == 1: + for block_idx, block in enumerate(self.transformer.blocks): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layer_past = None if past_key_values is None else past_key_values + x, past_key_values = block( + x, + attention_bias=causal_mask, + position_ids=position_ids, + layer_past=layer_past, + batch_index=batch_index, + use_cache=use_cache, + ) + + else: + for group_idx, block_group in enumerate(self.transformer.block_groups): + if output_hidden_states: + # add hidden states + all_hidden_states.append(x) + + layers_past = ( + None + if past_key_values is None + else past_key_values[ + group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size + ] + ) + x, past_key_values = block_group( + x, + attention_bias=causal_mask, + position_ids=position_ids, + layers_past=layers_past, + use_cache=use_cache, + ) + + # Cast to INT32 to avoid issue while running in ONNXRT + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = x[torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + + x = self.transformer.ln_f(hidden_states) # type: ignore + if output_hidden_states: + # add final hidden state post-final-layernorm, following HuggingFace's convention + all_hidden_states.append(x) + + # Get logits. + # shape: (batch_size, seq_len or 1, vocab_size) + if self.config.weight_tying: + logits = F.linear(x, self.transformer.wte.weight, None) # type: ignore + else: + logits = self.transformer.ff_out(x) # type: ignore + if self.config.scale_logits: + logits.mul_(1 / math.sqrt(self.config.d_model)) + + if use_cache: + next_cache = past_key_values.to_legacy_cache() + + return ModelOutput( + logits=logits, + past_key_values=next_cache, + hidden_states=tuple(all_hidden_states) if output_hidden_states else None, + ) # type: ignore[arg-type] + + +class QEffMolmoEncoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, pixel_values, image_masks, image_input_idx, valid_idx): + image_features, _ = self.model.model.vision_backbone(pixel_values, image_masks) + num_image, num_patch = image_features.shape[1:3] + batch_size = image_input_idx.shape[0] + image_features = image_features.view(batch_size, num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) + + image_input_idx = image_input_idx[0, valid_idx] + sorted_indices = torch.argsort(image_input_idx) + + return image_features[0, valid_idx][0, sorted_indices] + + +class QEffMolmoDecoderWrapper(nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + # self.language_model = self.model.language_model + self.config = self.model.config + + def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + if input_ids is not None: + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + inputs_embeds = self.model.model.transformer.wte(input_ids) + selected = input_ids == 152066 + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds[indices0, indices1] + image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded + inputs_embeds, inputs_embeds) + # + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + outputs = 
self.model.model.forward( + input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + ) + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + return outputs.logits, vision_embeds, image_idx, outputs.past_key_values + + +class QEffMolmoModel(nn.Module): + def get_qeff_vision_encoder(self): + return QEffMolmoEncoderWrapper(self) + + def get_qeff_language_decoder(self): + return QEffMolmoDecoderWrapper(self) + + """ + Copied from Llama4ForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama4/modeling_llama.py + The only differences are: + - add new args cache idx for the kv retention + """ + + def forward( + self, pixel_values, image_masks, image_input_idx, valid_idx, input_ids, position_ids, image_idx, past_key_values + ): + image_features, _ = self.model.vision_backbone(pixel_values, image_masks) + num_image, num_patch = image_features.shape[1:3] + batch_size = image_input_idx.shape[0] + image_features = image_features.view(batch_size, num_image * num_patch, -1) + image_input_idx = image_input_idx.view(batch_size, num_image * num_patch) + + valid = image_input_idx >= 0 + indices0 = torch.arange(valid.unsqueeze(0).shape[0]).view(-1, 1) + + image_input_idx = image_input_idx[0, valid_idx] + sorted_indices = torch.argsort(image_input_idx) + + vision_embeds = image_features[0, valid_idx][0, sorted_indices] + + if input_ids is not None: + input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) + + inputs_embeds = self.model.transformer.wte(input_ids) + selected = input_ids == 152066 + indices1 = selected.to(torch.int64).cumsum(1) - 1 + indices1 = torch.where(indices1 != -1, indices1 + image_idx, indices1) + indices0 = torch.arange(selected.unsqueeze(0).shape[0]).view(-1, 1) + image_features_expanded = vision_embeds[indices0, indices1] + image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded + inputs_embeds, inputs_embeds) + + inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) + outputs = self.model.forward( + input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + ) + next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) + image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) + + return outputs.logits, pixel_values, image_idx, outputs.past_key_values + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + img_size: int, + kv_offload: bool = False, + **compiler_options, + ): + prefill_seq_len = prefill_seq_len if prefill_seq_len else 1024 + ctx_len = ctx_len if ctx_len else constants.INTERN_CTX_LEN + + img_size = 588 + img_tile = 576 + num_images = 5 + num_patch = 144 + valid_size = 544 + vision = [ + { + "batch_size": batch_size, + "img_size": img_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + "valid_size": valid_size, + } + ] + + lang_prefill = { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } + + lang_decode = { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + } + + if not kv_offload: + lang_prefill["img_size"] = img_size + lang_prefill["img_tile"] = img_tile + lang_prefill["num_images"] = num_images + lang_prefill["num_patch"] = num_patch + lang_prefill["valid_size"] = valid_size + 
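+ # Single-QPC path: the decode specialization needs the same vision dimensions as prefill.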
lang_decode["img_size"] = img_size + lang_decode["img_tile"] = img_tile + lang_decode["num_images"] = num_images + lang_decode["num_patch"] = num_patch + lang_decode["valid_size"] = valid_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) + specializations = {} + + if kv_offload: + specializations["vision"] = vision + specializations["lang"] = lang + return specializations, compiler_options + else: + return lang, compiler_options + + def get_onnx_dynamic_axes(self, kv_offload: bool = False): + # Define dynamic axes + vision_dynamic_axes = {} + lang_dynamic_axes = {} + lang_dynamic_axes["input_ids"] = {0: "batch_size", 1: "seq_len"} + lang_dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} + + vision_dynamic_axes["pixel_values"] = {0: "batch_size", 1: "num_images", 2: "img_tile", 3: "img_size"} + vision_dynamic_axes["image_input_idx"] = {0: "batch_size", 1: "num_images", 2: "num_patch"} + vision_dynamic_axes["image_masks"] = {0: "batch_size", 1: "num_images", 2: "img_tile"} + vision_dynamic_axes["valid_idx"] = {0: "batch_size", 1: "valid_size"} + + num_layers = self.model.config.n_layers + + for i in range(num_layers): + lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} + lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + dynamic_axes = {} + if kv_offload: + dynamic_axes["vision"] = vision_dynamic_axes + dynamic_axes["lang"] = lang_dynamic_axes + else: + dynamic_axes = {**vision_dynamic_axes, **lang_dynamic_axes} + return dynamic_axes + + def get_output_names(self, kv_offload: bool = False): + vision_output_names = ["vision_embeds"] + lang_output_names = ["logits"] + + # + for i in range(self.model.config.n_layers): + for kv in ["key", "value"]: + lang_output_names.append(f"past_{kv}.{i}_RetainedState") + + output_names = {} + if kv_offload: + lang_output_names.insert(1, "vision_embeds_RetainedState") + lang_output_names.insert(2, "image_idx_output") + output_names["vision"] = vision_output_names + output_names["lang"] = lang_output_names + else: + lang_output_names.insert(1, "pixel_values_RetainedState") + lang_output_names.insert(2, "image_idx_output") + return lang_output_names + return output_names + + def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + inputs_shapes = {} + inputs_shapes_lang = {} + inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + + inputs_shapes["vision_embeds"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + 544, + self.config.hidden_size, + ) + inputs_shapes["position_ids"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + inputs_shapes["pixel_values"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + 5, + 576, + 588, + ) + + inputs_shapes["image_masks"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 5, 576) + + inputs_shapes["image_input_idx"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 5, 144) + + inputs_shapes_lang["image_input_idx"] = ( + constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + 720, + ) + + inputs_shapes["valid_idx"] = (1, 544) + + inputs_shapes["image_idx"] = (1, 1) + inputs_shapes["image_sizes"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 2) + # Define inputs + vision_inputs = {} + lang_inputs = {} + vision_inputs["pixel_values"] = torch.zeros((inputs_shapes["pixel_values"]), dtype=torch.float32) + vision_inputs["image_masks"] = torch.zeros((inputs_shapes["image_masks"]), dtype=torch.float32) + vision_inputs["image_input_idx"] = 
torch.zeros((inputs_shapes["image_input_idx"]), dtype=torch.int32) + + vision_inputs["valid_idx"] = torch.zeros((inputs_shapes["valid_idx"]), dtype=torch.int64) + + lang_inputs["input_ids"] = torch.zeros((inputs_shapes["input_ids"]), dtype=torch.int64) + lang_inputs["vision_embeds"] = torch.zeros((inputs_shapes["vision_embeds"]), dtype=torch.float32) + lang_inputs["position_ids"] = ( + torch.arange(constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, dtype=torch.int64) + .view(1, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) + .repeat(constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, 1) + ) + lang_inputs["image_idx"] = torch.zeros((inputs_shapes["image_idx"]), dtype=torch.int64) + # Add data for KV + kv_cache_shape = get_padding_shape_from_config( + config=self.config, + batch_size=constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, + seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, + ) + + lang_inputs["past_key_values"] = [[] for _ in range(self.model.config.n_layers)] + for i in range(self.model.config.n_layers): + for kv in ["key", "value"]: + lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + + inputs = {} + if kv_offload: + inputs["vision"] = vision_inputs + inputs["lang"] = lang_inputs + else: + lang_inputs.pop("vision_embeds") + inputs = {**vision_inputs, **lang_inputs} + + return inputs + + def get_inputs_info(self): + return [ + IOInfo(name="input_ids", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo(name="attention_mask", datatype=torch.int64, shape=("batch_size", "seq_len")), + IOInfo( + name="pixel_values", + datatype=torch.float32, + shape=("batch_size", "num_images", "img_tile", "img_size"), + ), + IOInfo( + name="image_masks", + datatype=torch.float32, + shape=("batch_size", "num_images", "img_tile"), + ), + IOInfo( + name="image_input_idx", + datatype=torch.int32, + shape=("batch_size", "num_images", "num_patches"), + ), + IOInfo( + name="valid_idx", + datatype=torch.int64, + shape=("batch_size", "valid_size"), + ), + ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index ca74c0ddd..af36388c7 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -279,6 +279,12 @@ QEffMllamaTextSelfAttention, QEffMllamaVisionModel, ) +from QEfficient.transformers.models.molmo.modeling_molmo import ( + QEffMolmo, + QEffMolmoBlock, + QEffMolmoModel, + QEffMolmoSequentialBlock, +) from QEfficient.transformers.models.mpt.modeling_mpt import ( QEffMptAttention, QEffMptBlock, @@ -595,6 +601,29 @@ class KVCacheExternalModuleMapperTransform(ExternalModuleMapperTransform): "get_qeff_language_decoder": QEffInternVLModel.get_qeff_language_decoder, }, "InternVisionEmbeddings": {"forward": QEffInternVisionEmbeddings.forward}, + # Mapping for Molmo + "MolmoForCausalLM": { + "forward": QEffMolmoModel.forward, + "get_qeff_vision_encoder": QEffMolmoModel.get_qeff_vision_encoder, + "get_qeff_language_decoder": QEffMolmoModel.get_qeff_language_decoder, + "get_specializations": QEffMolmoModel.get_specializations, + "get_onnx_dynamic_axes": QEffMolmoModel.get_onnx_dynamic_axes, + "get_output_names": QEffMolmoModel.get_output_names, + "get_dummy_inputs": QEffMolmoModel.get_dummy_inputs, + "get_inputs_info": QEffMolmoModel.get_inputs_info, + }, + "RMSLayerNorm": {"forward": CustomRMSNormAIC.forward}, + # "MolmoForCausalLM": {"forward": QEffMolmoForCausalLM.forward}, + "Molmo": {"forward": QEffMolmo.forward}, + "MolmoSequentialBlock": { + "forward": 
QEffMolmoSequentialBlock.forward,
+        "attention": QEffMolmoBlock.attention,
+        "__qeff_init__": QEffMolmoBlock.__qeff_init__,
+    },
+    "MolmoBlock": {
+        "attention": QEffMolmoBlock.attention,
+        "__qeff_init__": QEffMolmoBlock.__qeff_init__,
+    },
     # Mapping for grok1 model
     "Grok1ModelForCausalLM": {"forward": QEffGrok1ModelForCausalLM.forward},
     "Grok1Model": {
diff --git a/examples/molmo_example.py b/examples/molmo_example.py
new file mode 100644
index 000000000..db05e1670
--- /dev/null
+++ b/examples/molmo_example.py
@@ -0,0 +1,86 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import requests
+import torch
+import transformers
+from PIL import Image
+from transformers import AutoConfig, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+model_id = "allenai/Molmo-7B-D-0924"
+config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+
+config.num_hidden_layers = 2
+
+# Load the model
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, kv_offload=True, trust_remote_code=True, config=config)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+### Set skip_vision=True to run text-only inference; set it to False to run vision + text. ###
+skip_vision = True
+
+if skip_vision:
+    ## Text only ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=4096,
+        num_cores=16,
+        num_devices=4,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        skip_vision=True,
+        mos=1,
+    )
+
+    inputs = processor.process(text="Tell me about yourself")
+    inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+    inputs["input_ids"] = inputs["input_ids"].to(torch.int64)
+    inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+
+else:
+    ## Vision + Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=4096,
+        num_cores=16,
+        num_devices=4,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        mos=1,
+    )
+
+    ### IMAGE + TEXT ###
+    image_url = "https://picsum.photos/id/237/536/354"
+
+    image = Image.open(requests.get(image_url, stream=True).raw)
+    image = image.resize((536, 354))
+
+    inputs = processor.process(images=[image], text="Can you describe the image in detail.")
+
+    inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+    inputs["pixel_values"] = inputs.pop("images")
+    inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+
+    valid = inputs["image_input_idx"] > 0
+    valid = valid.reshape(1, -1)
+    inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
+
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+    print()
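The index bookkeeping shared by QEffMolmoEncoderWrapper.forward and the valid_idx computation in the example above is easier to follow in isolation. Below is a minimal, self-contained sketch with toy shapes and values (the tensor sizes, the sample indices, and the kept_positions name are illustrative assumptions, not taken from the model): image_input_idx maps each image patch to a token position in the prompt, valid_idx selects the patches that actually land in the prompt, and the final argsort orders the gathered features by their target token position.

import torch

# Toy stand-ins for the real tensors (illustrative shapes only).
batch_size, num_images, num_patch, dim = 1, 2, 3, 4
image_features = torch.randn(batch_size, num_images, num_patch, dim)
image_input_idx = torch.tensor([[[5, 6, -1], [7, 8, -1]]])  # -1 marks padding patches

# Flatten the per-image patches, as the encoder wrapper does.
image_features = image_features.view(batch_size, num_images * num_patch, -1)
image_input_idx = image_input_idx.view(batch_size, num_images * num_patch)

# valid_idx: flat positions of patches that map into the prompt.
valid = (image_input_idx >= 0).reshape(1, -1)
valid_idx = torch.nonzero(valid)[:, 1].unsqueeze(0)

# Gather the kept patches and order them by their target token position.
kept_positions = image_input_idx[0, valid_idx]
sorted_indices = torch.argsort(kept_positions)
vision_embeds = image_features[0, valid_idx][0, sorted_indices]
print(vision_embeds.shape)  # torch.Size([1, 4, 4]) -> one row per valid patch

In the dual-QPC flow these ordered features are returned as vision_embeds, and QEffMolmoDecoderWrapper scatters them back into the text embeddings at the positions marked by the image token id.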