Merged
Commits
23 commits
1ba0d5f
implement TEDotProductAttentionCP context manager for megatron CP TTT…
yeyu-nvidia Jan 20, 2026
e0062b2
debug: rotary_seq_len in EagleModule forward need to multiply with cp…
yeyu-nvidia Jan 20, 2026
39e2ec6
debug: revert previous change
yeyu-nvidia Jan 20, 2026
739ebff
debug: eagle inputs need gather and scatter for cp
yeyu-nvidia Jan 21, 2026
751eef8
debug: update GPTModel path
yeyu-nvidia Jan 21, 2026
b8f74c9
debug: RotaryEmbedding's output is already scattered to context paral…
yeyu-nvidia Jan 21, 2026
faa8822
debug
yeyu-nvidia Jan 21, 2026
1999c37
revert
yeyu-nvidia Jan 21, 2026
c7b1853
debug: megatron doesn't have gather_from_context_parallel_region; use…
yeyu-nvidia Jan 21, 2026
fe86792
debug: gather_from_sequence_parallel_region gathers from the first di…
yeyu-nvidia Jan 21, 2026
5215713
attention_mask needs to convert to 0/-inf for attention_bias
yeyu-nvidia Jan 21, 2026
92d772e
debug: when CP is enabled, we need to switch to causal mask for eagle
yeyu-nvidia Jan 22, 2026
efa3897
make attention_bias the same dtype as query
yeyu-nvidia Jan 22, 2026
3ac3125
fix the bug; runnable
yeyu-nvidia Jan 22, 2026
19e77fd
remove unnecessary code
yeyu-nvidia Jan 22, 2026
cc88044
fix: HF main needs trust_remote_code=True for resuming ckpt
yeyu-nvidia Jan 26, 2026
2fb7a57
update changelog
yeyu-nvidia Jan 26, 2026
8234c68
formatting
yeyu-nvidia Jan 26, 2026
a2e2489
debug: fix the sharded state dict issue
yeyu-nvidia Jan 28, 2026
f7ce0c3
debug
yeyu-nvidia Jan 28, 2026
b36807a
remove attn mask expansion as it is unnecessary and will cause error …
yeyu-nvidia Jan 29, 2026
8bfd98b
formatting
yeyu-nvidia Jan 29, 2026
4c44c1c
remove te_dot_product_attention_with_cp from pseudo_speculative_gener…
yeyu-nvidia Jan 29, 2026
1 change: 1 addition & 0 deletions CHANGELOG.rst
@@ -15,6 +15,7 @@ NVIDIA Model Optimizer Changelog (Linux)
- Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md <https://github.com/NVIDIA/Model-Optimizer/tree/main/examples/pruning>`_ for more details on its usage.
- Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow.
- Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model.
- Add support for context parallelism in Eagle speculative decoding for Hugging Face and Megatron Core models.

0.41 (2026-01-19)
^^^^^^^^^^^^^^^^^
@@ -17,10 +17,10 @@

import argparse
import asyncio
import json
from pathlib import Path

import torch
from datasets import load_dataset
from tqdm import tqdm as tqdm
from transformers import AutoModel, AutoTokenizer

@@ -54,12 +54,10 @@ def parse_args() -> argparse.Namespace:

## I/O Parameters ##
parser.add_argument(
"--input-file",
"--input-data",
type=Path,
required=True,
help="""Path to the input `jsonl` file containing conversations.
Each entry must have a unique `conversation_id` field and a `conversations` field
containing a list of messages.""",
help="""Path to the `jsonl` file or directory containing `jsonl` files.""",
)
parser.add_argument(
"--output-dir",
@@ -75,21 +73,68 @@ def parse_args() -> argparse.Namespace:
help="""For debugging purposes, limit the number of conversations processed.
Default is None, meaning no limit.""",
)
parser.add_argument(
"--dp-rank",
type=int,
default=0,
help="""Data parallel rank. TASK_ID on SLURM.""",
)
parser.add_argument(
"--dp-world-size",
type=int,
default=1,
help="""Data parallel world size. Number of tasks on SLURM.""",
)

return parser.parse_args()


async def main(args: argparse.Namespace) -> None:
all_conversations = []
with args.input_file.open("r", encoding="utf-8") as f:
all_conversations.extend([json.loads(line) for line in f if line.strip()])
def main(args: argparse.Namespace) -> None:
# Load conversations
if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"):
dataset = load_dataset("json", data_files=str(args.input_data), split="train")
elif args.input_data.is_dir():
dataset = load_dataset(
"json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train"
)
else:
raise ValueError(
f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}"
)
print(f"Loaded {len(dataset)} conversations from {args.input_data}")

print("Loaded", len(all_conversations), "conversations from", args.input_file)
# Shard data
if args.dp_world_size > 1:
dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank)
print(
f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}"
)

# Remove already dumped conversations
def keep_conversation(entry):
conversation_id = entry.get("conversation_id", entry.get("uuid", None))
assert conversation_id is not None, "conversation_id is required"
output_file = args.output_dir / f"{conversation_id}.pt"
return not output_file.exists()

original_num = len(dataset)
dataset = dataset.filter(keep_conversation)
print(
"Removed",
original_num - len(dataset),
"conversations due to existing output files",
)

model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto")
# For debugging
if args.debug_max_num_conversations is not None:
dataset = dataset.select(range(args.debug_max_num_conversations))

model = AutoModel.from_pretrained(
args.model, torch_dtype="auto", device_map="auto", trust_remote_code=True
)
num_hidden_layers = getattr(model.config, "num_hidden_layers", None)

tokenizer = AutoTokenizer.from_pretrained(args.model)
tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "")
@@ -99,30 +144,11 @@ async def main(args: argparse.Namespace) -> None:
num_skipped_too_long = 0
num_invalid = 0
num_success = 0
num_total_conversations = min(
len(all_conversations), args.debug_max_num_conversations or len(all_conversations)
)
for idx, entry in enumerate(
tqdm(
all_conversations[: args.debug_max_num_conversations],
desc="Processing conversations",
total=num_total_conversations,
)
):
conversation_id = entry.get("conversation_id", "{:08d}".format(idx))
conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

# Tokenize and check length
input_ids = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_template=False
)
num_input_tokens = input_ids.shape[1]
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
continue
pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations")

async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Tensor):
nonlocal num_success
nonlocal num_hidden_layers

# Get hidden states
with torch.inference_mode():
@@ -144,9 +170,9 @@ async def main(args: argparse.Namespace) -> None:
aux_hidden_states = torch.cat(
[hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1
)
output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu()
output_hidden_states = hidden_states[-1].squeeze(0).cpu()
output_file = output_dir / f"{conversation_id}.pt"
num_success += 1

with open(output_file, "wb") as f:
torch.save(
{
@@ -158,19 +184,49 @@ async def main(args: argparse.Namespace) -> None:
f,
)

num_success += 1
pbar.update(1)

async def submit_generates():
nonlocal num_skipped_too_long
nonlocal num_invalid
tasks = []
idx = 0
for entry in dataset:
conversation_id = entry.get("conversation_id", entry.get("uuid"))

conversations = entry["conversations"]
if not conversations or not isinstance(conversations, list):
num_invalid += 1
continue

# Tokenize and check length
input_ids = tokenizer.apply_chat_template(
conversations, return_tensors="pt", add_generation_template=False
)["input_ids"]
num_input_tokens = input_ids.shape[1]
if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len:
num_skipped_too_long += 1
continue

tasks.append(dump_hidden_states(idx, conversation_id, input_ids))
# Increment only for valid conversations to match dump file index
idx += 1
await asyncio.gather(*tasks)

asyncio.run(submit_generates())

if num_skipped_too_long > 0:
print(f"Skipped {num_skipped_too_long} conversations due to length constraints.")
if num_invalid > 0:
print(f"Skipped {num_invalid} invalid conversations without proper fields.")

if num_success == num_total_conversations:
if num_success == len(dataset):
print(f"Successfully processed all {num_success} conversations.")
else:
print(
f"Successfully processed {num_success} out of {num_total_conversations} conversations."
)
print(f"Successfully processed {num_success} out of {len(dataset)} conversations.")


if __name__ == "__main__":
cli_args = parse_args()
asyncio.run(main(cli_args))
main(cli_args)
4 changes: 3 additions & 1 deletion examples/speculative_decoding/main.py
@@ -162,7 +162,9 @@ def train():
use_offline_training = data_args.offline_data_path is not None

if checkpoint:
model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto")
model = transformers.AutoModelForCausalLM.from_pretrained(
checkpoint, torch_dtype="auto", trust_remote_code=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
Comment on lines +165 to 168
⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the exact file and lines in question
cd /repo && head -n 175 examples/speculative_decoding/main.py | tail -n 20

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Search for all trust_remote_code usage in the file
rg "trust_remote_code" examples/speculative_decoding/main.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1068


🏁 Script executed:

# Check if there's a training_args or argument parser structure
rg "argparse|ArgumentParser|TrainingArguments" examples/speculative_decoding/main.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 431


🏁 Script executed:

# Look at the overall structure of the file to understand the context
wc -l examples/speculative_decoding/main.py
head -n 50 examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2045


🏁 Script executed:

# Get full TrainingArguments definition
rg "@dataclass" examples/speculative_decoding/main.py -A 20 | head -n 80

Repository: NVIDIA/Model-Optimizer

Length of output: 2794


🏁 Script executed:

# Check if trust_remote_code is referenced anywhere in TrainingArguments or passed as argument
rg "trust_remote_code" examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 373


🏁 Script executed:

# Look at the actual context around lines 165-168 to understand the flow
sed -n '155,175p' examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 987


🏁 Script executed:

# Check if training_args is available at that point in the code
sed -n '1,170p' examples/speculative_decoding/main.py | grep -E "training_args|def train"

Repository: NVIDIA/Model-Optimizer

Length of output: 706


Gate trust_remote_code behind an explicit flag.
Lines 165, 167, 173, and 186 all hardcode trust_remote_code=True for model/tokenizer/config loading, which automatically executes remote code if the checkpoint points to an untrusted repository. Add a trust_remote_code: bool = False field to the TrainingArguments dataclass and use training_args.trust_remote_code in all four locations.

Affected locations
  • Line 165-167: checkpoint model and tokenizer loading
  • Line 173, 186: model config and tokenizer loading with model_name_or_path
🤖 Prompt for AI Agents
In `@examples/speculative_decoding/main.py` around lines 165 - 168, Add a new
boolean field trust_remote_code: bool = False to the TrainingArguments dataclass
and replace all hardcoded trust_remote_code=True calls to use
training_args.trust_remote_code instead; specifically, update the calls to
transformers.AutoModelForCausalLM.from_pretrained (the model variable),
transformers.AutoTokenizer.from_pretrained (the tokenizer variable),
transformers.AutoConfig.from_pretrained, and the later
transformers.AutoTokenizer.from_pretrained that use model_name_or_path so they
pass training_args.trust_remote_code.
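
A minimal sketch of the suggested fix, assuming the script's existing TrainingArguments dataclass can take a new field; the field name, default, and loader helper below are illustrative, not the merged code.

from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Hypothetical opt-in flag; defaults to False so remote code only runs
    # when the user explicitly allows it.
    trust_remote_code: bool = field(default=False)


def load_from_checkpoint(checkpoint: str, training_args: TrainingArguments):
    # Thread the user-controlled flag through instead of hardcoding True.
    model = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype="auto", trust_remote_code=training_args.trust_remote_code
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        checkpoint, trust_remote_code=training_args.trust_remote_code
    )
    return model, tokenizer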

else:
# To avoid OOM for large models, we load and convert model on CPU first.