diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 22143da28..3724d412d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model. +- Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models. 0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py index f9818e464..a3d1681c4 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py @@ -17,10 +17,10 @@ import argparse import asyncio -import json from pathlib import Path import torch +from datasets import load_dataset from tqdm import tqdm as tqdm from transformers import AutoModel, AutoTokenizer @@ -54,12 +54,10 @@ def parse_args() -> argparse.Namespace: ## I/O Parameters ## parser.add_argument( - "--input-file", + "--input-data", type=Path, required=True, - help="""Path to the input `jsonl` file containing conversations. - Each entry must have a unique `conversation_id` field and a `conversations` field - containing a list of messages.""", + help="""Path to the `jsonl` file or directory containing `jsonl` files.""", ) parser.add_argument( "--output-dir", @@ -75,21 +73,68 @@ def parse_args() -> argparse.Namespace: help="""For debugging purposes, limit the number of conversations processed. Default is None, meaning no limit.""", ) + parser.add_argument( + "--dp-rank", + type=int, + default=0, + help="""Data parallel rank. TASK_ID on SLURM.""", + ) + parser.add_argument( + "--dp-world-size", + type=int, + default=1, + help="""Data parallel world size. 
Number of tasks on SLURM.""", + ) return parser.parse_args() -async def main(args: argparse.Namespace) -> None: - all_conversations = [] - with args.input_file.open("r", encoding="utf-8") as f: - all_conversations.extend([json.loads(line) for line in f if line.strip()]) +def main(args: argparse.Namespace) -> None: + # Load conversations + if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"): + dataset = load_dataset("json", data_files=str(args.input_data), split="train") + elif args.input_data.is_dir(): + dataset = load_dataset( + "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train" + ) + else: + raise ValueError( + f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}" + ) + print(f"Loaded {len(dataset)} conversations from {args.input_data}") - print("Loaded", len(all_conversations), "conversations from", args.input_file) + # Shard data + if args.dp_world_size > 1: + dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank) + print( + f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}" + ) + + # Remove already dumped conversations + def keep_conversation(entry): + conversation_id = entry.get("conversation_id", entry.get("uuid", None)) + assert conversation_id is not None, "conversation_id is required" + output_file = args.output_dir / f"{conversation_id}.pt" + return not output_file.exists() + + original_num = len(dataset) + dataset = dataset.filter(keep_conversation) + print( + "Removed", + original_num - len(dataset), + "conversations due to existing output files", + ) - model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto") + # For debugging + if args.debug_max_num_conversations is not None: + dataset = dataset.select(range(args.debug_max_num_conversations)) + + model = AutoModel.from_pretrained( + args.model, torch_dtype="auto", device_map="auto", trust_remote_code=True + ) num_hidden_layers = getattr(model.config, "num_hidden_layers", None) - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") @@ -99,30 +144,11 @@ async def main(args: argparse.Namespace) -> None: num_skipped_too_long = 0 num_invalid = 0 num_success = 0 - num_total_conversations = min( - len(all_conversations), args.debug_max_num_conversations or len(all_conversations) - ) - for idx, entry in enumerate( - tqdm( - all_conversations[: args.debug_max_num_conversations], - desc="Processing conversations", - total=num_total_conversations, - ) - ): - conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) - conversations = entry["conversations"] - if not conversations or not isinstance(conversations, list): - num_invalid += 1 - continue - - # Tokenize and check length - input_ids = tokenizer.apply_chat_template( - conversations, return_tensors="pt", add_generation_template=False - ) - num_input_tokens = input_ids.shape[1] - if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: - num_skipped_too_long += 1 - continue + pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations") + + async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Tensor): + nonlocal num_success + nonlocal num_hidden_layers # Get hidden states with torch.inference_mode(): 
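
For reference, the sharding and resume logic added above follows a standard pattern with the Hugging Face `datasets` library: shard first so each data-parallel rank sees a disjoint slice, then filter out conversations whose dump file already exists so interrupted runs can resume. A minimal, self-contained sketch of that pattern (the file name, output directory, and rank values are placeholders, not the script's actual arguments):

# Illustrative sketch only (not part of the patch): shard a JSONL dataset across
# data-parallel ranks, then skip entries that already have a dumped .pt file.
from pathlib import Path

from datasets import load_dataset

dp_rank, dp_world_size = 0, 4            # e.g. SLURM task id / number of tasks
output_dir = Path("hidden_states")       # hypothetical dump directory

dataset = load_dataset("json", data_files="conversations.jsonl", split="train")

# Each rank keeps a disjoint 1/dp_world_size slice of the data.
if dp_world_size > 1:
    dataset = dataset.shard(num_shards=dp_world_size, index=dp_rank)

# Resume support: drop conversations whose output file already exists.
def keep(entry):
    conversation_id = entry.get("conversation_id") or entry.get("uuid")
    return not (output_dir / f"{conversation_id}.pt").exists()

dataset = dataset.filter(keep)
print(f"Rank {dp_rank}: {len(dataset)} conversations left to process")
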
@@ -144,9 +170,9 @@ async def main(args: argparse.Namespace) -> None: aux_hidden_states = torch.cat( [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1 ) - output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu() + output_hidden_states = hidden_states[-1].squeeze(0).cpu() output_file = output_dir / f"{conversation_id}.pt" - num_success += 1 + with open(output_file, "wb") as f: torch.save( { @@ -158,19 +184,49 @@ async def main(args: argparse.Namespace) -> None: f, ) + num_success += 1 + pbar.update(1) + + async def submit_generates(): + nonlocal num_skipped_too_long + nonlocal num_invalid + tasks = [] + idx = 0 + for entry in dataset: + conversation_id = entry.get("conversation_id", entry.get("uuid")) + + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + # Tokenize and check length + input_ids = tokenizer.apply_chat_template( + conversations, return_tensors="pt", add_generation_template=False + )["input_ids"] + num_input_tokens = input_ids.shape[1] + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_skipped_too_long += 1 + continue + + tasks.append(dump_hidden_states(idx, conversation_id, input_ids)) + # Increment only for valid conversations to match dump file index + idx += 1 + await asyncio.gather(*tasks) + + asyncio.run(submit_generates()) + if num_skipped_too_long > 0: print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") if num_invalid > 0: print(f"Skipped {num_invalid} invalid conversations without proper fields.") - if num_success == num_total_conversations: + if num_success == len(dataset): print(f"Successfully processed all {num_success} conversations.") else: - print( - f"Successfully processed {num_success} out of {num_total_conversations} conversations." - ) + print(f"Successfully processed {num_success} out of {len(dataset)} conversations.") if __name__ == "__main__": cli_args = parse_args() - asyncio.run(main(cli_args)) + main(cli_args) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index f8452cd90..8706ca049 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -162,7 +162,9 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: - model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto") + model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint, torch_dtype="auto", trust_remote_code=True + ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) else: # To avoid OOM for large models, we load and convert model on CPU first. 
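
The change from `outputs.last_hidden_state` to `hidden_states[-1]` above works because, when a Hugging Face model is called with `output_hidden_states=True`, the returned tuple contains the embedding output plus one entry per layer, and its last entry is taken after the final norm, i.e. the same tensor as `last_hidden_state`. A minimal sketch of the collection step under that assumption (the model name and the auxiliary layer choice are placeholders, not the script's `selected_layer_indices`):

# Illustrative sketch only: collect per-layer hidden states from a HF model and
# save the tensors an EAGLE-style training recipe typically consumes.
import torch
from transformers import AutoModel, AutoTokenizer

model_name = "gpt2"                                    # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, torch_dtype="auto")

input_ids = tokenizer("hello world", return_tensors="pt").input_ids
with torch.inference_mode():
    outputs = model(input_ids, output_hidden_states=True)

hidden_states = outputs.hidden_states                  # tuple: embeddings + one entry per layer
aux_layers = [1, len(hidden_states) // 2, len(hidden_states) - 2]   # example choice only
aux_hidden_states = torch.cat([hidden_states[i].squeeze(0) for i in aux_layers], dim=-1)
last_hidden = hidden_states[-1].squeeze(0)             # equivalent to outputs.last_hidden_state

torch.save(
    {
        "input_ids": input_ids.squeeze(0),
        "hidden_states": last_hidden,
        "aux_hidden_states": aux_hidden_states,
    },
    "example.pt",                                      # placeholder output path
)
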
diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 5435a8efa..e37e8f931 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -17,6 +17,7 @@ import copy import warnings +from contextlib import contextmanager import megatron.core import torch @@ -24,13 +25,15 @@ from megatron.core import InferenceParams, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.extensions.transformer_engine import TELinear, TENorm from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding from megatron.core.models.gpt import GPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_group, + get_context_parallel_world_size, get_data_parallel_rank, get_expert_tensor_parallel_world_size, get_pipeline_model_parallel_world_size, @@ -59,7 +62,6 @@ try: from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec - from megatron.core.post_training.modelopt.layers import Linear except ImportError: warnings.warn("Fail to import megatron.core.post_training! EAGLE feature will be disable!") @@ -388,7 +390,11 @@ def sharded_state_dict( if module is not self.layers: sharded_state_dict.update( sharded_state_dict_default( - module, f"{prefix}{name}.", sharded_offsets, metadata + module, + f"{prefix}{name}.", + sharded_offsets, + metadata, + tp_group=self.tp_group, ) ) @@ -442,16 +448,19 @@ def __init__( self._num_aux_hidden_states if self._num_aux_hidden_states > 0 else 2 ) - # This linear was previously a ColumnParallelLinear. We changed it to a normal linear + # This linear was previously a ColumnParallelLinear. We changed it to a TELinear # since ColumnParallelLinear will have try to gather the input sequence when sequence # parallel is used and does not allow gathering the outputs. with torch.device(device): - self.fc = Linear( + self.fc = TELinear( config.hidden_size * fc_input_size_multiplier, config.hidden_size, + parallel_mode="duplicated", config=config, init_method=(lambda w: None), # not used bias=bias, + skip_bias_add=False, + skip_weight_param_allocation=False, ) self.rotary_pos_emb = rotary_pos_emb @@ -529,11 +538,13 @@ def _get_eagle_transformer_layer_spec(self, config): IMPORTANT: EagleModule must use arbitrary_attention_mask since we need to manipulate the mask to compute the correct loss. The default causal mask will result in leaking. + However, if context parallel is used, we need to switch to causal + mask and inject attention_mask as attention_bias instead. """ transformer_layer_spec = get_gpt_modelopt_spec( config, remap_te_layernorm=True, - use_arbitrary_attention_mask=True, + use_arbitrary_attention_mask=get_context_parallel_world_size() == 1, ) # If heterogenous layers (e.g. DeepSeek), transformer_layer_spec is a # TransformerBlockSubmodules instead. We use the last layer_specs. @@ -583,9 +594,13 @@ def forward( # NOTE: Even if sequence_parallel is used, the rotary_seq_len must be in the original # length. 
Since we get the seq_len from hidden_states.shape[0], we need to # multiply the the tp back. + # Similarly, if context parallel is used, the rotary_seq_len must also be + # multiplied by context parallel size. rotary_seq_len = hidden_states.shape[0] if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size + if get_context_parallel_world_size() > 1: + rotary_seq_len *= get_context_parallel_world_size() if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) @@ -838,16 +853,41 @@ def _get_eagle_module_inputs( ttt_step: int = 0, ): """Getting EAGLE module inputs.""" - # [b, 1] + # gather_from_sequence_parallel_region gathers from the first dimention + # so we need to transpose input_ids first + # [b,s] -> [s,b] + input_ids = input_ids.clone().transpose(0, 1).contiguous() + input_ids = gather_from_sequence_parallel_region( + input_ids, group=get_context_parallel_group() + ) + # [s,b] -> [b,s] + input_ids = input_ids.transpose(0, 1).contiguous() id_padding = torch.zeros( (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) + # RotaryEmbedding's output is already scattered to context parallel region + # No need to scatter again. rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) + # [b,s] -> [s,b] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() + padded_input_ids = scatter_to_sequence_parallel_region( + padded_input_ids, group=get_context_parallel_group() + ) + # [s,b] -> [b,s] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() + attn_mask = attention_mask.clone().detach() - attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = gather_from_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -860,9 +900,17 @@ def _get_eagle_module_inputs( input_ids=eagle_inputs["input_ids"], position_ids=eagle_inputs["position_ids"], ) + eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, ttt_step) + attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step) + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = scatter_to_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() + ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), @@ -1111,14 +1159,17 @@ def forward( ttt_step=ttt_step, ) - _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( - eagle_inputs, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + with te_dot_product_attention_with_cp( + eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + ): + _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( + eagle_inputs, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + 
**(extra_block_kwargs or {}), + ) if self.config.sequence_parallel: eagle_module_input_hidden_states = gather_from_sequence_parallel_region( @@ -1330,3 +1381,80 @@ def get_ground_truth(self, input_ids, osl): if input_id[0, 0] == self.end_token: break return input_ids + + +@contextmanager +def te_dot_product_attention_with_cp(attention_mask: torch.Tensor, num_attention_heads: int): + """Context manager for TEDotProductAttention with context parallelism. + + Context manager that temporarily replace `attention_bias` + with `attention_mask` for `TEDotProductAttention.forward` calls across the process + if context parallel is used. + + Any call to `TEDotProductAttention.forward` (including calls originating + from other modules) inside the context will receive `attention_bias=attention_mask` + if context parallelism is used. + + Example: + with te_dot_product_attention_with_cp(attention_mask_tensor, num_attention_heads): + outputs = model(...) + + Note: This monkey-patches the class method and restores it on exit. + """ + from megatron.core.extensions.transformer_engine import TEDotProductAttention + + orig_forward = TEDotProductAttention.forward + + def _wrapped_forward(self, *args, **kwargs): + # Build attention_bias from the boolean attention_mask and ensure + # it's a fresh, detached tensor on the query's device/dtype to + # avoid shared-storage in-place modifications that break autograd. + query = args[0] if len(args) > 0 else None + if isinstance(query, torch.Tensor): + q_device = query.device + q_dtype = query.dtype + else: + q_device = None + q_dtype = None + + mask_fill = -1e9 + if q_dtype in (torch.float16, torch.bfloat16): + mask_fill = -40.0 + mask_val = torch.tensor(mask_fill, device=attention_mask.device) + zero_val = torch.tensor(0.0, device=attention_mask.device) + attention_bias = torch.where(attention_mask, mask_val, zero_val) + + if q_device is not None and q_dtype is not None: + attention_bias = attention_bias.to(device=q_device, dtype=q_dtype) + + attention_bias = attention_bias.clone().detach().contiguous() + kwargs["attention_bias"] = attention_bias + + # Defensive clone of query/key/value positional tensors to avoid + # passing views into the fused attention kernel that might be + # modified in-place during backward. + if len(args) >= 1: + original_args = args + new_args = list(original_args) + try: + for i in range(min(3, len(new_args))): + if isinstance(new_args[i], torch.Tensor): + if not new_args[i].is_contiguous(): + new_args[i] = new_args[i].contiguous() + new_args[i] = new_args[i].clone() + + if any(x is None for x in new_args): + args = original_args + else: + args = tuple(new_args) + except Exception: + args = original_args + + return orig_forward(self, *args, **kwargs) + + if get_context_parallel_world_size() > 1: + TEDotProductAttention.forward = _wrapped_forward + try: + yield + finally: + TEDotProductAttention.forward = orig_forward
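
The `te_dot_product_attention_with_cp` helper above combines two generic techniques: converting a boolean attention mask into an additive attention bias (a large negative value at masked positions, with -40.0 used for half-precision dtypes), and temporarily monkey-patching a class method inside a context manager so the original forward is always restored. A minimal, framework-free sketch of the same pattern, using a hypothetical `DummyAttention` class in place of `TEDotProductAttention`:

# Illustrative sketch only (independent of Megatron/TE): temporarily monkey-patch a
# class's forward so every call receives an additive `attention_bias` derived from a
# boolean mask (True = masked position -> large negative value), then restore it.
from contextlib import contextmanager

import torch


class DummyAttention:                        # stand-in for TEDotProductAttention
    def forward(self, query, attention_bias=None):
        scores = query @ query.transpose(-2, -1)
        if attention_bias is not None:
            scores = scores + attention_bias
        return torch.softmax(scores, dim=-1)


@contextmanager
def inject_attention_bias(attention_mask: torch.Tensor):
    orig_forward = DummyAttention.forward

    def wrapped_forward(self, *args, **kwargs):
        query = args[0]
        # Large negative fill; clamp to -40.0 for half precision as in the patch above.
        fill = -40.0 if query.dtype in (torch.float16, torch.bfloat16) else -1e9
        bias = torch.where(
            attention_mask,
            torch.tensor(fill, dtype=query.dtype),
            torch.tensor(0.0, dtype=query.dtype),
        )
        kwargs["attention_bias"] = bias.to(query.device)
        return orig_forward(self, *args, **kwargs)

    DummyAttention.forward = wrapped_forward
    try:
        yield
    finally:
        DummyAttention.forward = orig_forward


mask = torch.triu(torch.ones(4, 4, dtype=torch.bool), diagonal=1)   # example causal mask
q = torch.randn(4, 8)
with inject_attention_bias(mask):
    print(DummyAttention().forward(q))

Restoring the original method in the `finally` block keeps the patch scoped to the context even if the wrapped forward raises, which is the same safety property the helper above relies on.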