30 changes: 12 additions & 18 deletions examples/speculative_decoding/eagle_utils.py
@@ -15,7 +15,6 @@

import inspect
import json
import os
from collections.abc import Callable
from pathlib import Path
from typing import TYPE_CHECKING
@@ -29,7 +28,6 @@
import transformers
from datasets import load_dataset
from packaging.version import Version
from PIL import Image
from scripts.ar_validate import validate_ar
from torch.distributed.tensor.experimental._attention import _SDPAMerger
from torch.utils.data import Dataset
@@ -102,14 +100,14 @@ def get_role_content(item):


def preprocess_vlm(examples, tokenizer, processor, img_dir):
# NOTE: This function is hard-coded to support Qwen3-VL. Consolidate before merging.

tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
new_examples = {
"input_ids": [],
"attention_mask": [],
"loss_mask": [],
"labels": [],
"pixel_values": [],
"image_flags": [],
}
for i in range(len(examples)):
messages = []
@@ -134,31 +132,27 @@ def convert_role(role):

for sentence in source:
role, content = get_role_content(sentence)
content = [{"type": "text", "text": content}]
new_role = convert_role(role)
messages.append({"role": new_role, "content": content})
conversation = tokenizer.apply_chat_template(

inputs = processor.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=False,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
fps=4,
)

img_filename = os.path.join(img_dir, examples[i]["image"])
img = Image.open(img_filename)
output = processor(images=img, text=conversation, return_tensors="pt")
input_ids = output.input_ids[0]
attention_mask = output.attention_mask[0]
input_ids = inputs.input_ids[0]
attention_mask = inputs.attention_mask[0]
loss_mask = torch.ones_like(input_ids)
labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
# TODO: add labels and answer-only loss masking?

new_examples["input_ids"].append(input_ids)
new_examples["attention_mask"].append(attention_mask)
new_examples["loss_mask"].append(loss_mask)
new_examples["labels"].append(labels)
new_examples["pixel_values"].append(output.pixel_values)
new_examples["image_flags"].append(
torch.ones((output.pixel_values.shape[0],), dtype=torch.int64)
)
return new_examples


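For reference, the updated preprocessing leans on the processor's apply_chat_template to tokenize the conversation and prepare the multimodal tensors in a single call, replacing the earlier two-step tokenizer.apply_chat_template plus processor(images=..., text=...) flow. A minimal sketch of that call pattern, assuming a Qwen3-VL-style processor; the model id and image path below are placeholders:

from transformers import AutoProcessor

# Placeholder checkpoint; any Qwen3-VL-style processor with a chat template works here.
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "path/to/example.jpg"},  # placeholder image path
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# A single call renders the template, tokenizes it, and returns the vision tensors alongside.
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
input_ids = inputs.input_ids[0]
attention_mask = inputs.attention_mask[0]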
6 changes: 4 additions & 2 deletions examples/speculative_decoding/main.py
@@ -162,13 +162,15 @@ def train():
use_offline_training = data_args.offline_data_path is not None

if checkpoint:
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
checkpoint, torch_dtype="auto"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
else:
# To avoid OOM for large models, we load and convert model on CPU first.
# Model will be moved to GPU during HF trainer.init().
offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
model = transformers.AutoModelForCausalLM.from_pretrained(
model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
torch_dtype="auto",
device_map="cpu",
@@ -18,7 +18,7 @@
import argparse

import torch
from transformers import AutoModelForCausalLM
from transformers import Qwen3VLForConditionalGeneration

import modelopt.torch.opt as mto
from modelopt.torch.export import export_hf_checkpoint
@@ -38,7 +38,7 @@ def parse_args():
mto.enable_huggingface_checkpointing()

args = parse_args()
model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
model = Qwen3VLForConditionalGeneration.from_pretrained(args.model_path, torch_dtype="auto")
model.eval()
with torch.inference_mode():
export_hf_checkpoint(
11 changes: 9 additions & 2 deletions modelopt/torch/speculative/plugins/transformers.py
@@ -425,16 +425,23 @@ def _base_model_lm_head(self):
@property
def _base_llm_config(self):
"""Return the llm config for the base model, from LLM or VLM."""
return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
# return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
return self.config.text_config

def _find_base_model_parts(self):
"""Find model parts from different models and set base_{part}_path attributes."""
base_model_parts_mapping = {
"base_model_path": ["model", "backbone", "language_model.backbone"],
"base_model_path": [
"model.language_model",
"model",
"backbone",
"language_model.backbone",
],
"base_model_embeddings_path": [
"model.embed_tokens",
"backbone.embeddings",
"language_model.backbone.embeddings",
"model.language_model.embed_tokens",
],
"base_model_lm_head_path": ["lm_head", "language_model.lm_head"],
}
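As a rough illustration of how such dotted candidate paths can be probed against a wrapped model, here is a hypothetical helper (not the plugin's actual lookup code). Ordering matters: a VLM wrapper that exposes model.language_model typically also exposes a bare model attribute, so the more specific path must come first.

from functools import reduce

from torch import nn


def resolve_attr_path(root: nn.Module, path: str) -> nn.Module:
    """Follow a dotted attribute path such as "model.language_model" from a root module."""
    return reduce(getattr, path.split("."), root)


def find_first_existing(root: nn.Module, candidates: list[str]) -> str | None:
    """Return the first candidate path that resolves on the model, or None if none do."""
    for path in candidates:
        try:
            resolve_attr_path(root, path)
            return path
        except AttributeError:
            continue
    return None


# Hypothetical usage: a Qwen3-VL-style wrapper resolves "model.language_model" first,
# while a plain causal LM falls back to "model".
# find_first_existing(model, ["model.language_model", "model", "backbone", "language_model.backbone"])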
3 changes: 3 additions & 0 deletions modelopt/torch/speculative/utils.py
@@ -42,6 +42,9 @@
def calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None):
"""Given a calibration text, find the most common vocabs and return the mapping."""
conversations = tokenizer.apply_chat_template(text)
# Transformers 5.x returns a BatchEncoding from apply_chat_template
if hasattr(conversations, "input_ids"):
conversations = conversations.input_ids
counter = Counter(conversations)
vocab = counter.most_common(target_vocab_size)
mapping = torch.zeros(target_vocab_size, dtype=torch.int64)
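A small sketch of why the new guard is needed (the tokenizer checkpoint below is a placeholder): older transformers releases return a plain list of token ids from apply_chat_template, while 5.x may hand back a BatchEncoding, so the ids are unwrapped before counting.

from collections import Counter

from transformers import AutoTokenizer

# Placeholder checkpoint; any tokenizer with a chat template behaves the same way.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")

text = [{"role": "user", "content": "Hello, how are you today?"}]
conversations = tokenizer.apply_chat_template(text)

# Mirror the check added above: unwrap a BatchEncoding to its token ids before counting.
if hasattr(conversations, "input_ids"):
    conversations = conversations.input_ids

counter = Counter(conversations)
print(counter.most_common(10))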