diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py
index 05ec3f864..98bc4eb25 100644
--- a/examples/speculative_decoding/eagle_utils.py
+++ b/examples/speculative_decoding/eagle_utils.py
@@ -15,7 +15,6 @@
 
 import inspect
 import json
-import os
 from collections.abc import Callable
 from pathlib import Path
 from typing import TYPE_CHECKING
@@ -29,7 +28,6 @@
 import transformers
 from datasets import load_dataset
 from packaging.version import Version
-from PIL import Image
 from scripts.ar_validate import validate_ar
 from torch.distributed.tensor.experimental._attention import _SDPAMerger
 from torch.utils.data import Dataset
@@ -102,14 +100,14 @@ def get_role_content(item):
 
 
 def preprocess_vlm(examples, tokenizer, processor, img_dir):
+    # NOTE: This function is hard-coded to support Qwen3-VL. Consolidate before merging.
+    tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
     new_examples = {
         "input_ids": [],
         "attention_mask": [],
         "loss_mask": [],
         "labels": [],
-        "pixel_values": [],
-        "image_flags": [],
     }
     for i in range(len(examples)):
         messages = []
@@ -134,31 +132,27 @@ def convert_role(role):
 
         for sentence in source:
             role, content = get_role_content(sentence)
+            content = [{"type": "text", "text": content}]
             new_role = convert_role(role)
             messages.append({"role": new_role, "content": content})
 
-        conversation = tokenizer.apply_chat_template(
+
+        inputs = processor.apply_chat_template(
             messages,
-            tokenize=False,
-            add_generation_prompt=False,
+            tokenize=True,
+            add_generation_prompt=True,
+            return_dict=True,
+            return_tensors="pt",
+            fps=4,
         )
-
-        img_filename = os.path.join(img_dir, examples[i]["image"])
-        img = Image.open(img_filename)
-        output = processor(images=img, text=conversation, return_tensors="pt")
-        input_ids = output.input_ids[0]
-        attention_mask = output.attention_mask[0]
+        input_ids = inputs.input_ids[0]
+        attention_mask = inputs.attention_mask[0]
         loss_mask = torch.ones_like(input_ids)
         labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)])
-        # TODO: add labels and answer-only loss masking?
         new_examples["input_ids"].append(input_ids)
         new_examples["attention_mask"].append(attention_mask)
         new_examples["loss_mask"].append(loss_mask)
         new_examples["labels"].append(labels)
-        new_examples["pixel_values"].append(output.pixel_values)
-        new_examples["image_flags"].append(
-            torch.ones((output.pixel_values.shape[0],), dtype=torch.int64)
-        )
 
     return new_examples
 
diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py
index f8452cd90..15c67ada3 100644
--- a/examples/speculative_decoding/main.py
+++ b/examples/speculative_decoding/main.py
@@ -162,13 +162,15 @@ def train():
     use_offline_training = data_args.offline_data_path is not None
 
     if checkpoint:
-        model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
+        model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
+            checkpoint, torch_dtype="auto"
+        )
         tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
     else:
         # To avoid OOM for large models, we load and convert model on CPU first.
         # Model will be moved to GPU during HF trainer.init().
         offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {}
-        model = transformers.AutoModelForCausalLM.from_pretrained(
+        model = transformers.Qwen3VLForConditionalGeneration.from_pretrained(
             model_args.model_name_or_path,
             torch_dtype="auto",
             device_map="cpu",
diff --git a/examples/speculative_decoding/scripts/export_hf_checkpoint.py b/examples/speculative_decoding/scripts/export_hf_checkpoint.py
index dfc293ee9..486b51e6c 100644
--- a/examples/speculative_decoding/scripts/export_hf_checkpoint.py
+++ b/examples/speculative_decoding/scripts/export_hf_checkpoint.py
@@ -18,7 +18,7 @@
 import argparse
 
 import torch
-from transformers import AutoModelForCausalLM
+from transformers import Qwen3VLForConditionalGeneration
 
 import modelopt.torch.opt as mto
 from modelopt.torch.export import export_hf_checkpoint
@@ -38,7 +38,7 @@ def parse_args():
 mto.enable_huggingface_checkpointing()
 
 args = parse_args()
-model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model = Qwen3VLForConditionalGeneration.from_pretrained(args.model_path, torch_dtype="auto")
 model.eval()
 with torch.inference_mode():
     export_hf_checkpoint(
diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py
index 3090297aa..630712998 100644
--- a/modelopt/torch/speculative/plugins/transformers.py
+++ b/modelopt/torch/speculative/plugins/transformers.py
@@ -425,16 +425,23 @@ def _base_model_lm_head(self):
     @property
     def _base_llm_config(self):
         """Return the llm config for the base model, from LLM or VLM."""
-        return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
+        # return self.config.llm_config if hasattr(self.config, "llm_config") else self.config
+        return self.config.text_config
 
     def _find_base_model_parts(self):
         """Find model parts from different models and set base_{part}_path attributes."""
         base_model_parts_mapping = {
-            "base_model_path": ["model", "backbone", "language_model.backbone"],
+            "base_model_path": [
+                "model.language_model",
+                "model",
+                "backbone",
+                "language_model.backbone",
+            ],
             "base_model_embeddings_path": [
                 "model.embed_tokens",
                 "backbone.embeddings",
                 "language_model.backbone.embeddings",
+                "model.language_model.embed_tokens",
             ],
             "base_model_lm_head_path": ["lm_head", "language_model.lm_head"],
         }
diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py
index d259a1fce..fc30b1f1c 100644
--- a/modelopt/torch/speculative/utils.py
+++ b/modelopt/torch/speculative/utils.py
@@ -42,6 +42,9 @@
 def calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None):
     """Given a calibration text, find the most common vocabs and return the mapping."""
     conversations = tokenizer.apply_chat_template(text)
+    # Transformers 5.x returns a BatchEncoding from apply_chat_template
+    if hasattr(conversations, "input_ids"):
+        conversations = conversations.input_ids
    counter = Counter(conversations)
    vocab = counter.most_common(target_vocab_size)
    mapping = torch.zeros(target_vocab_size, dtype=torch.int64)
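
Reviewer note, not part of the diff: a minimal standalone sketch of the processor-driven tokenization pattern that the updated preprocess_vlm relies on, so the call signature can be checked in isolation. The checkpoint id and the example messages are placeholders (assumptions), and a transformers build with native Qwen3-VL support is assumed; only the apply_chat_template call mirrors the change above.

import torch
from transformers import AutoProcessor

# Placeholder repo id; swap in the actual Qwen3-VL checkpoint under test.
processor = AutoProcessor.from_pretrained("Qwen/Qwen3-VL-8B-Instruct")

# Text-only stand-in for one converted conversation; preprocess_vlm builds this
# from examples[i] after wrapping each turn as [{"type": "text", "text": ...}].
messages = [
    {"role": "user", "content": [{"type": "text", "text": "Describe the dataset sample."}]},
    {"role": "assistant", "content": [{"type": "text", "text": "It is a short caption."}]},
]

# Same call pattern as the new preprocess_vlm. The fps=4 argument in the diff only
# matters for video inputs, so it is omitted in this text-only sketch.
inputs = processor.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="pt",
)
input_ids = inputs.input_ids[0]            # (seq_len,) token ids
attention_mask = inputs.attention_mask[0]
loss_mask = torch.ones_like(input_ids)     # as in the diff: loss on every position for now

Two behavioral points the sketch makes visible for review: add_generation_prompt flips from False to True, and the old TODO about answer-only loss masking is dropped; both are worth confirming as intentional before merging.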