From 1ba0d5f7b14e5503d287744b0659dbb3525e22ff Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 20 Jan 2026 12:41:24 -0800 Subject: [PATCH 01/23] implement TEDotProductAttentionCP context manager for megatron CP TTT patch Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 76 ++++++++++++++++--- 1 file changed, 64 insertions(+), 12 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 5435a8efa..f559b0c80 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -17,6 +17,7 @@ import copy import warnings +from contextlib import contextmanager import megatron.core import torch @@ -31,6 +32,7 @@ from megatron.core.models.gpt import GPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_world_size, get_data_parallel_rank, get_expert_tensor_parallel_world_size, get_pipeline_model_parallel_world_size, @@ -1111,14 +1113,17 @@ def forward( ttt_step=ttt_step, ) - _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( - eagle_inputs, - output_weight, - inference_params=inference_params, - packed_seq_params=packed_seq_params, - inference_context=eagle_inference_context, - **(extra_block_kwargs or {}), - ) + with TEDotProductAttentionCP( + eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + ): + _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( + eagle_inputs, + output_weight, + inference_params=inference_params, + packed_seq_params=packed_seq_params, + inference_context=eagle_inference_context, + **(extra_block_kwargs or {}), + ) if self.config.sequence_parallel: eagle_module_input_hidden_states = gather_from_sequence_parallel_region( @@ -1266,10 +1271,13 @@ def pseudo_speculative_generate( # [TODO] (chenhany): let the module compute itself eagle_inputs["rotary_pos_emb"] = None - _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( - eagle_inputs, - output_weight, - ) + with TEDotProductAttentionCP( + eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads + ): + _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( + eagle_inputs, + output_weight, + ) # parallel_logits are only used after the last step if step == steps - 1 and self.eagle_config.parallel_draft_step > 1: @@ -1330,3 +1338,47 @@ def get_ground_truth(self, input_ids, osl): if input_id[0, 0] == self.end_token: break return input_ids + + +@contextmanager +def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int): + """Context manager for TEDotProductAttention with context parallelism. + + Context manager that temporarily replace `attention_bias` + with `attention_mask` for `TEDotProductAttention.forward` calls across the process + if context parallel is used. + + Any call to `TEDotProductAttention.forward` (including calls originating + from other modules) inside the context will receive `attention_bias=attention_mask` + if context parallelism is used. + + Example: + with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads): + outputs = model(...) + + Note: This monkey-patches the class method and restores it on exit. 
+ """ + from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls + + orig_forward = cls.forward + + def _wrapped_forward(self, *args, **kwargs): + # Megatron mask is in shape [b, 1, s, s] + # TEDotProductAttention expects bias in [b, h, s, s] + # Replace the attention_bias argument passed to forward + kwargs["attention_bias"] = attention_mask.repeat( + [ + attention_mask.shape[0], + num_attention_heads, + attention_mask.shape[2], + attention_mask.shape[3], + ] + ) + return orig_forward(self, *args, **kwargs) + + if get_context_parallel_world_size() > 1: + cls.forward = _wrapped_forward + try: + yield + finally: + cls.forward = orig_forward From e0062b2f817654752bc2035ad3ca6b2706575fbe Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 20 Jan 2026 14:21:07 -0800 Subject: [PATCH 02/23] debug: rotary_seq_len in EagleModule forward need to multiply with cp_size Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index f559b0c80..083cac6b0 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -585,9 +585,11 @@ def forward( # NOTE: Even if sequence_parallel is used, the rotary_seq_len must be in the original # length. Since we get the seq_len from hidden_states.shape[0], we need to # multiply the the tp back. + # Similarly, if get_context_parallel_world_size() > 1, we also need to multiply the cp size. rotary_seq_len = hidden_states.shape[0] if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size + rotary_seq_len *= get_context_parallel_world_size() if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) From 39e2ec67435c3de3f0bc8d0459bd73d21640db9f Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Tue, 20 Jan 2026 14:27:12 -0800 Subject: [PATCH 03/23] debug: revert previous change Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 083cac6b0..f559b0c80 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -585,11 +585,9 @@ def forward( # NOTE: Even if sequence_parallel is used, the rotary_seq_len must be in the original # length. Since we get the seq_len from hidden_states.shape[0], we need to # multiply the the tp back. - # Similarly, if get_context_parallel_world_size() > 1, we also need to multiply the cp size. 
rotary_seq_len = hidden_states.shape[0] if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size - rotary_seq_len *= get_context_parallel_world_size() if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) From 739ebff8ea9db3e08ae2bad3889e3282c359487c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 11:45:21 -0800 Subject: [PATCH 04/23] debug: eagle inputs need gather and scatter for cp Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index f559b0c80..d173505fc 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -40,8 +40,10 @@ get_tensor_model_parallel_world_size, ) from megatron.core.tensor_parallel.mappings import ( + gather_from_context_parallel_region, gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, + scatter_to_context_parallel_region, scatter_to_sequence_parallel_region, ) from megatron.core.transformer.attention import SelfAttention @@ -585,9 +587,13 @@ def forward( # NOTE: Even if sequence_parallel is used, the rotary_seq_len must be in the original # length. Since we get the seq_len from hidden_states.shape[0], we need to # multiply the the tp back. + # Similarly, if context parallel is used, the rotary_seq_len must also be + # multiplied by context parallel size. rotary_seq_len = hidden_states.shape[0] if self.config.sequence_parallel: rotary_seq_len *= self.config.tensor_model_parallel_size + if get_context_parallel_world_size() > 1: + rotary_seq_len *= get_context_parallel_world_size() if self.config.use_mtp_layernorm: embeddings = self.enorm(embeddings) @@ -841,6 +847,7 @@ def _get_eagle_module_inputs( ): """Getting EAGLE module inputs.""" # [b, 1] + input_ids = gather_from_context_parallel_region(input_ids) id_padding = torch.zeros( (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device ) @@ -848,8 +855,12 @@ def _get_eagle_module_inputs( rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) - attn_mask = attention_mask.clone().detach() - attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:] + padded_input_ids = scatter_to_context_parallel_region(padded_input_ids) + input_ids = scatter_to_context_parallel_region(input_ids) + rotary_pos_emb = scatter_to_context_parallel_region(rotary_pos_emb) + + attn_mask = gather_from_context_parallel_region(attention_mask.clone().detach()) + attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -862,9 +873,12 @@ def _get_eagle_module_inputs( input_ids=eagle_inputs["input_ids"], position_ids=eagle_inputs["position_ids"], ) + eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = set_multi_step_attention_mask(attn_mask, ttt_step) + eagle_inputs["attention_mask"] = scatter_to_context_parallel_region( + set_multi_step_attention_mask(attn_mask, ttt_step) + ) eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), From 751eef88eb338fea7c8bed3ae60f44dd93a037bd Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 11:56:03 -0800 Subject: [PATCH 05/23] debug: update GPTModel path Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_medusa.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff 
--git a/modelopt/torch/speculative/plugins/megatron_medusa.py b/modelopt/torch/speculative/plugins/megatron_medusa.py index ccc6c7a69..10501ae33 100644 --- a/modelopt/torch/speculative/plugins/megatron_medusa.py +++ b/modelopt/torch/speculative/plugins/megatron_medusa.py @@ -127,7 +127,7 @@ def sharded_state_dict( return sharded_state_dict -@MedusaDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) +@MedusaDMRegistry.register({GPTModel: "megatron.core.models.gpt.gpt_model.GPTModel"}) class _DynamicMedusaGPTModel(MedusaModel): """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" From b8f74c9d1bf7e22b5971da4195898e4c86daf1d1 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 12:20:10 -0800 Subject: [PATCH 06/23] debug: RotaryEmbedding's output is already scattered to context parallel region Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index d173505fc..4ca2f729e 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -853,11 +853,12 @@ def _get_eagle_module_inputs( ) padded_input_ids = torch.cat((input_ids[:, 1:], id_padding), dim=-1) + # RotaryEmbedding's output is already scattered to context parallel region + # No need to scatter again. rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) padded_input_ids = scatter_to_context_parallel_region(padded_input_ids) input_ids = scatter_to_context_parallel_region(input_ids) - rotary_pos_emb = scatter_to_context_parallel_region(rotary_pos_emb) attn_mask = gather_from_context_parallel_region(attention_mask.clone().detach()) attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] From faa8822d82729f74c07dc896ef75b69f30ae47fe Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 12:22:18 -0800 Subject: [PATCH 07/23] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 4ca2f729e..d7178f588 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -638,7 +638,7 @@ def forward( return hidden_states, next_hidden_states_input -@EagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) +@EagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.gpt_model.GPTModel"}) class _DynamicEagleGPTModel(EagleModel): """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" From 1999c37bbbfb35e3f8ec6401a59c4a998238c800 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 12:27:32 -0800 Subject: [PATCH 08/23] revert Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 +- modelopt/torch/speculative/plugins/megatron_medusa.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index d7178f588..4ca2f729e 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -638,7 +638,7 @@ def forward( return hidden_states, next_hidden_states_input -@EagleDMRegistry.register({GPTModel: 
"megatron.core.models.gpt.gpt_model.GPTModel"}) +@EagleDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) class _DynamicEagleGPTModel(EagleModel): """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" diff --git a/modelopt/torch/speculative/plugins/megatron_medusa.py b/modelopt/torch/speculative/plugins/megatron_medusa.py index 10501ae33..ccc6c7a69 100644 --- a/modelopt/torch/speculative/plugins/megatron_medusa.py +++ b/modelopt/torch/speculative/plugins/megatron_medusa.py @@ -127,7 +127,7 @@ def sharded_state_dict( return sharded_state_dict -@MedusaDMRegistry.register({GPTModel: "megatron.core.models.gpt.gpt_model.GPTModel"}) +@MedusaDMRegistry.register({GPTModel: "megatron.core.models.gpt.GPTModel"}) class _DynamicMedusaGPTModel(MedusaModel): """A ``megatron.core.models.gpt.GPTModel`` model with dynamic hyperparams.""" From c7b185345cd7e0e29fa9b5a40e5044d3563ad2f4 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 13:06:58 -0800 Subject: [PATCH 09/23] debug: megatron doesn't have gather_from_context_parallel_region; use gather_from_sequence_parallel_region and change the group Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 4ca2f729e..2cf6b2f61 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -32,6 +32,7 @@ from megatron.core.models.gpt import GPTModel from megatron.core.packed_seq_params import PackedSeqParams from megatron.core.parallel_state import ( + get_context_parallel_group, get_context_parallel_world_size, get_data_parallel_rank, get_expert_tensor_parallel_world_size, @@ -40,10 +41,8 @@ get_tensor_model_parallel_world_size, ) from megatron.core.tensor_parallel.mappings import ( - gather_from_context_parallel_region, gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, - scatter_to_context_parallel_region, scatter_to_sequence_parallel_region, ) from megatron.core.transformer.attention import SelfAttention @@ -847,7 +846,9 @@ def _get_eagle_module_inputs( ): """Getting EAGLE module inputs.""" # [b, 1] - input_ids = gather_from_context_parallel_region(input_ids) + input_ids = gather_from_sequence_parallel_region( + input_ids, group=get_context_parallel_group() + ) id_padding = torch.zeros( (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device ) @@ -857,10 +858,17 @@ def _get_eagle_module_inputs( # No need to scatter again. 
rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) - padded_input_ids = scatter_to_context_parallel_region(padded_input_ids) - input_ids = scatter_to_context_parallel_region(input_ids) + padded_input_ids = scatter_to_sequence_parallel_region( + padded_input_ids, group=get_context_parallel_group() + ) + # Not sure this is needed + input_ids = scatter_to_sequence_parallel_region( + input_ids, group=get_context_parallel_group() + ) - attn_mask = gather_from_context_parallel_region(attention_mask.clone().detach()) + attn_mask = gather_from_sequence_parallel_region( + attention_mask.clone().detach(), group=get_context_parallel_group() + ) attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -877,8 +885,8 @@ def _get_eagle_module_inputs( eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = scatter_to_context_parallel_region( - set_multi_step_attention_mask(attn_mask, ttt_step) + eagle_inputs["attention_mask"] = scatter_to_sequence_parallel_region( + set_multi_step_attention_mask(attn_mask, ttt_step), group=get_context_parallel_group() ) eagle_inputs["rotary_pos_emb"] = torch.cat( From fe86792f8a5efb22630f625c42ed7060a35dd08c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 13:55:34 -0800 Subject: [PATCH 10/23] debug: gather_from_sequence_parallel_region gathers from the first dimention so we need to transpose tensors first Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2cf6b2f61..5a30f8398 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -845,10 +845,15 @@ def _get_eagle_module_inputs( ttt_step: int = 0, ): """Getting EAGLE module inputs.""" - # [b, 1] + # gather_from_sequence_parallel_region gathers from the first dimention + # so we need to transpose input_ids first + # [b,s] -> [s,b] + input_ids = input_ids.clone().transpose(0, 1).contiguous() input_ids = gather_from_sequence_parallel_region( input_ids, group=get_context_parallel_group() ) + # [s,b] -> [b,s] + input_ids = input_ids.transpose(0, 1).contiguous() id_padding = torch.zeros( (input_ids.shape[0], 1), dtype=input_ids.dtype, device=input_ids.device ) @@ -858,17 +863,22 @@ def _get_eagle_module_inputs( # No need to scatter again. 
rotary_pos_emb = self.eagle_module.rotary_pos_emb(padded_input_ids.shape[-1]) + # [b,s] -> [s,b] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() padded_input_ids = scatter_to_sequence_parallel_region( padded_input_ids, group=get_context_parallel_group() ) - # Not sure this is needed - input_ids = scatter_to_sequence_parallel_region( - input_ids, group=get_context_parallel_group() - ) + # [s,b] -> [b,s] + padded_input_ids = padded_input_ids.transpose(0, 1).contiguous() + attn_mask = attention_mask.clone().detach() + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() attn_mask = gather_from_sequence_parallel_region( - attention_mask.clone().detach(), group=get_context_parallel_group() + attn_mask, group=get_context_parallel_group() ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() attn_mask[:, :, :-1, :-1] = attn_mask[:, :, 1:, 1:] attn_mask[:, :, -1, :] = True attn_mask[:, :, :, -1] = True @@ -885,9 +895,14 @@ def _get_eagle_module_inputs( eagle_inputs["hidden_states"] = hidden_states - eagle_inputs["attention_mask"] = scatter_to_sequence_parallel_region( - set_multi_step_attention_mask(attn_mask, ttt_step), group=get_context_parallel_group() + attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step) + # [b, 1, sq, sk] -> [sq, 1, b, sk] + attn_mask = attn_mask.transpose(0, 2).contiguous() + attn_mask = scatter_to_sequence_parallel_region( + attn_mask, group=get_context_parallel_group() ) + # [sq, 1, b, sk] -> [b, 1, sq, sk] + eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous() eagle_inputs["rotary_pos_emb"] = torch.cat( [rotary_pos_emb] * (ttt_step + 1), From 5215713a98aa85e8b41708fa7397368b636ae822 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 14:15:07 -0800 Subject: [PATCH 11/23] attention_mask needs to convert to 0/-inf for attention_bias Signed-off-by: Ye Yu --- .../torch/speculative/plugins/megatron_eagle.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 5a30f8398..6a6ff894c 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1401,17 +1401,8 @@ def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: i orig_forward = cls.forward def _wrapped_forward(self, *args, **kwargs): - # Megatron mask is in shape [b, 1, s, s] - # TEDotProductAttention expects bias in [b, h, s, s] - # Replace the attention_bias argument passed to forward - kwargs["attention_bias"] = attention_mask.repeat( - [ - attention_mask.shape[0], - num_attention_heads, - attention_mask.shape[2], - attention_mask.shape[3], - ] - ) + attention_bias = torch.where(attention_mask, torch.tensor(-1e9), torch.tensor(0.0)) + kwargs["attention_bias"] = attention_bias return orig_forward(self, *args, **kwargs) if get_context_parallel_world_size() > 1: From 92d772e707897d53fda5f6b7c81ad7fba671eca3 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 17:44:54 -0800 Subject: [PATCH 12/23] debug: when CP is enabled, we need to switch to causal mask for eagle Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 6a6ff894c..b7a649a12 100644 --- 
a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -532,11 +532,13 @@ def _get_eagle_transformer_layer_spec(self, config): IMPORTANT: EagleModule must use arbitrary_attention_mask since we need to manipulate the mask to compute the correct loss. The default causal mask will result in leaking. + However, if context parallel is used, we need to switch to causal + mask and inject attention_mask as attention_bias instead. """ transformer_layer_spec = get_gpt_modelopt_spec( config, remap_te_layernorm=True, - use_arbitrary_attention_mask=True, + use_arbitrary_attention_mask=get_context_parallel_world_size() == 1, ) # If heterogenous layers (e.g. DeepSeek), transformer_layer_spec is a # TransformerBlockSubmodules instead. We use the last layer_specs. From efa389731431b500d08fd2ca6ae9d404410d438c Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 17:47:04 -0800 Subject: [PATCH 13/23] make attention_bias the same dtype as query Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index b7a649a12..2684a25ec 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1404,7 +1404,7 @@ def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: i def _wrapped_forward(self, *args, **kwargs): attention_bias = torch.where(attention_mask, torch.tensor(-1e9), torch.tensor(0.0)) - kwargs["attention_bias"] = attention_bias + kwargs["attention_bias"] = attention_bias.to(args[0].dtype) return orig_forward(self, *args, **kwargs) if get_context_parallel_world_size() > 1: From 3ac312580a3197634cb032e26f748d0ffbe7779a Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 21 Jan 2026 20:21:12 -0800 Subject: [PATCH 14/23] fix the bug; runnable Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 81 ++++++++++++++++++- 1 file changed, 79 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 2684a25ec..bb17fd346 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1403,8 +1403,85 @@ def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: i orig_forward = cls.forward def _wrapped_forward(self, *args, **kwargs): - attention_bias = torch.where(attention_mask, torch.tensor(-1e9), torch.tensor(0.0)) - kwargs["attention_bias"] = attention_bias.to(args[0].dtype) + # Build attention_bias from the boolean attention_mask and ensure + # it's a fresh, detached tensor on the query's device/dtype to + # avoid shared-storage in-place modifications that break autograd. 
+ query = args[0] if len(args) > 0 else None + if isinstance(query, torch.Tensor): + q_device = query.device + q_dtype = query.dtype + else: + q_device = None + q_dtype = None + + mask_fill = -1e9 + if q_dtype in (torch.float16, torch.bfloat16): + mask_fill = -40.0 + mask_val = torch.tensor(mask_fill, device=attention_mask.device) + zero_val = torch.tensor(0.0, device=attention_mask.device) + attention_bias = torch.where(attention_mask, mask_val, zero_val) + + # Expand head dimension if needed + try: + if attention_bias.dim() == 4 and attention_bias.shape[1] == 1: + attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1) + except Exception: + pass + + if q_device is not None and q_dtype is not None: + attention_bias = attention_bias.to(device=q_device, dtype=q_dtype) + + attention_bias = attention_bias.clone().detach().contiguous() + if q_dtype in (torch.float16, torch.bfloat16): + attention_bias = attention_bias.clamp(min=-40.0) + kwargs["attention_bias"] = attention_bias + + # Defensive clone of query/key/value positional tensors to avoid + # passing views into the fused attention kernel that might be + # modified in-place during backward. + if len(args) >= 1: + original_args = args + new_args = list(original_args) + try: + for i in range(min(3, len(new_args))): + if isinstance(new_args[i], torch.Tensor): + if not new_args[i].is_contiguous(): + new_args[i] = new_args[i].contiguous() + new_args[i] = new_args[i].clone() + + if any(x is None for x in new_args): + args = original_args + else: + args = tuple(new_args) + except Exception: + args = original_args + + # Ensure any provided attention_bias matches query dtype/device + if "attention_bias" in kwargs and isinstance(kwargs["attention_bias"], torch.Tensor): + if q_dtype is not None and q_device is not None: + try: + if ( + kwargs["attention_bias"].dtype != q_dtype + or kwargs["attention_bias"].device != q_device + ): + kwargs["attention_bias"] = ( + kwargs["attention_bias"] + .to(device=q_device, dtype=q_dtype) + .clone() + .detach() + .contiguous() + ) + else: + kwargs["attention_bias"] = ( + kwargs["attention_bias"].clone().detach().contiguous() + ) + except Exception: + kwargs["attention_bias"] = ( + kwargs["attention_bias"].clone().detach().contiguous() + ) + else: + kwargs["attention_bias"] = kwargs["attention_bias"].clone().detach().contiguous() + return orig_forward(self, *args, **kwargs) if get_context_parallel_world_size() > 1: From 19e77fd3e032a9986ee80cc4cccebe459ab19150 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 22 Jan 2026 10:09:46 -0800 Subject: [PATCH 15/23] remove unnecessary code Signed-off-by: Ye Yu --- .../speculative/plugins/megatron_eagle.py | 28 ------------------- 1 file changed, 28 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index bb17fd346..08579a660 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1432,8 +1432,6 @@ def _wrapped_forward(self, *args, **kwargs): attention_bias = attention_bias.to(device=q_device, dtype=q_dtype) attention_bias = attention_bias.clone().detach().contiguous() - if q_dtype in (torch.float16, torch.bfloat16): - attention_bias = attention_bias.clamp(min=-40.0) kwargs["attention_bias"] = attention_bias # Defensive clone of query/key/value positional tensors to avoid @@ -1456,32 +1454,6 @@ def _wrapped_forward(self, *args, **kwargs): except Exception: args = original_args - # Ensure any provided 
attention_bias matches query dtype/device - if "attention_bias" in kwargs and isinstance(kwargs["attention_bias"], torch.Tensor): - if q_dtype is not None and q_device is not None: - try: - if ( - kwargs["attention_bias"].dtype != q_dtype - or kwargs["attention_bias"].device != q_device - ): - kwargs["attention_bias"] = ( - kwargs["attention_bias"] - .to(device=q_device, dtype=q_dtype) - .clone() - .detach() - .contiguous() - ) - else: - kwargs["attention_bias"] = ( - kwargs["attention_bias"].clone().detach().contiguous() - ) - except Exception: - kwargs["attention_bias"] = ( - kwargs["attention_bias"].clone().detach().contiguous() - ) - else: - kwargs["attention_bias"] = kwargs["attention_bias"].clone().detach().contiguous() - return orig_forward(self, *args, **kwargs) if get_context_parallel_world_size() > 1: From cc88044056b0ce4032d9eb197c340889620e63df Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 26 Jan 2026 09:27:46 -0800 Subject: [PATCH 16/23] fix: HF main needs trust_remote_code=True for resuming ckpt Signed-off-by: Ye Yu --- examples/speculative_decoding/main.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index f8452cd90..8706ca049 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -162,7 +162,9 @@ def train(): use_offline_training = data_args.offline_data_path is not None if checkpoint: - model = transformers.AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype="auto") + model = transformers.AutoModelForCausalLM.from_pretrained( + checkpoint, torch_dtype="auto", trust_remote_code=True + ) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) else: # To avoid OOM for large models, we load and convert model on CPU first. From 2fb7a571f71b68096075e68204098655e6e63a1d Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 26 Jan 2026 10:22:56 -0800 Subject: [PATCH 17/23] update changelog Signed-off-by: Ye Yu --- CHANGELOG.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 22143da28..3724d412d 100755 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -15,6 +15,7 @@ NVIDIA Model Optimizer Changelog (Linux) - Add support for ``params`` constraint based automatic neural architecture search in Minitron pruning (``mcore_minitron``) as an alternative to manual pruning (using ``export_config``). See `examples/pruning/README.md `_ for more details on its usage. - Add support for calibration data with multiple samples in ``npz`` format in the ONNX Autocast workflow. - Add ``--opset`` option to ONNX quantization CLI to specify the target opset version for the quantized model. +- Add support for context parallelism in Eagle speculative decoding for huggingface and megatron core models. 
0.41 (2026-01-19) ^^^^^^^^^^^^^^^^^ From 8234c682c316220ba150c6f670e0a2beaa475f32 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Mon, 26 Jan 2026 10:41:45 -0800 Subject: [PATCH 18/23] formatting Signed-off-by: Ye Yu --- .../torch/speculative/plugins/megatron_eagle.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 08579a660..1a2a127d2 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1153,7 +1153,7 @@ def forward( ttt_step=ttt_step, ) - with TEDotProductAttentionCP( + with te_dot_product_attention_with_cp( eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads ): _, eagle_logits, eagle_module_input_hidden_states = self._eagle_forward( @@ -1311,7 +1311,7 @@ def pseudo_speculative_generate( # [TODO] (chenhany): let the module compute itself eagle_inputs["rotary_pos_emb"] = None - with TEDotProductAttentionCP( + with te_dot_product_attention_with_cp( eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads ): _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( @@ -1381,7 +1381,7 @@ def get_ground_truth(self, input_ids, osl): @contextmanager -def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int): +def te_dot_product_attention_with_cp(attention_mask: torch.Tensor, num_attention_heads: int): """Context manager for TEDotProductAttention with context parallelism. Context manager that temporarily replace `attention_bias` @@ -1393,14 +1393,14 @@ def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: i if context parallelism is used. Example: - with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads): + with te_dot_product_attention_with_cp(attention_mask_tensor, num_attention_heads): outputs = model(...) Note: This monkey-patches the class method and restores it on exit. 
""" - from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls + from megatron.core.extensions.transformer_engine import TEDotProductAttention - orig_forward = cls.forward + orig_forward = TEDotProductAttention.forward def _wrapped_forward(self, *args, **kwargs): # Build attention_bias from the boolean attention_mask and ensure @@ -1457,8 +1457,8 @@ def _wrapped_forward(self, *args, **kwargs): return orig_forward(self, *args, **kwargs) if get_context_parallel_world_size() > 1: - cls.forward = _wrapped_forward + TEDotProductAttention.forward = _wrapped_forward try: yield finally: - cls.forward = orig_forward + TEDotProductAttention.forward = orig_forward From a2e2489614d97215483c178daca59f1fdc5a8285 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 28 Jan 2026 14:50:41 -0800 Subject: [PATCH 19/23] debug: fix the sharded state dict issue Signed-off-by: Ye Yu --- .../compute_hidden_states_hf.py | 140 ++++++++++++------ .../speculative/plugins/megatron_eagle.py | 14 +- 2 files changed, 107 insertions(+), 47 deletions(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py index f9818e464..8c06e1162 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py @@ -17,10 +17,10 @@ import argparse import asyncio -import json from pathlib import Path import torch +from datasets import load_dataset from tqdm import tqdm as tqdm from transformers import AutoModel, AutoTokenizer @@ -54,12 +54,10 @@ def parse_args() -> argparse.Namespace: ## I/O Parameters ## parser.add_argument( - "--input-file", + "--input-data", type=Path, required=True, - help="""Path to the input `jsonl` file containing conversations. - Each entry must have a unique `conversation_id` field and a `conversations` field - containing a list of messages.""", + help="""Path to the `jsonl` file or directory containing `jsonl` files.""", ) parser.add_argument( "--output-dir", @@ -75,21 +73,68 @@ def parse_args() -> argparse.Namespace: help="""For debugging purposes, limit the number of conversations processed. Default is None, meaning no limit.""", ) + parser.add_argument( + "--dp-rank", + type=int, + default=0, + help="""Data parallel rank. TASK_ID on SLURM.""", + ) + parser.add_argument( + "--dp-world-size", + type=int, + default=1, + help="""Data parallel world size. 
Number of tasks on SLURM.""", + ) return parser.parse_args() -async def main(args: argparse.Namespace) -> None: - all_conversations = [] - with args.input_file.open("r", encoding="utf-8") as f: - all_conversations.extend([json.loads(line) for line in f if line.strip()]) +def main(args: argparse.Namespace) -> None: + # Load conversations + if args.input_data.is_file() and str(args.input_data).endswith(".jsonl"): + dataset = load_dataset("json", data_files=str(args.input_data), split="train") + elif args.input_data.is_dir(): + dataset = load_dataset( + "json", data_files={"train": f"{args.input_data}/*.jsonl"}, split="train" + ) + else: + raise ValueError( + f"input_data must be a .jsonl file or directory containing .jsonl files, got: {args.input_data}" + ) + print(f"Loaded {len(dataset)} conversations from {args.input_data}") - print("Loaded", len(all_conversations), "conversations from", args.input_file) + # Shard data + if args.dp_world_size > 1: + dataset = dataset.shard(num_shards=args.dp_world_size, index=args.dp_rank) + print( + f"Sharded dataset to {len(dataset)} conversations for DP#{args.dp_rank}/{args.dp_world_size}" + ) + + # Remove already dumped conversations + def keep_conversation(entry): + conversation_id = entry.get("conversation_id", entry.get("uuid", None)) + assert conversation_id is not None, "conversation_id is required" + output_file = args.output_dir / f"{conversation_id}.pt" + return not output_file.exists() + + original_num = len(dataset) + dataset = dataset.filter(keep_conversation) + print( + "Removed", + original_num - len(dataset), + "conversations due to existing output files", + ) - model = AutoModel.from_pretrained(args.model, torch_dtype="auto", device_map="auto") + # For debugging + if args.debug_max_num_conversations is not None: + dataset = dataset.select(range(args.debug_max_num_conversations)) + + model = AutoModel.from_pretrained( + args.model, torch_dtype="auto", device_map="auto", trust_remote_code=True + ) num_hidden_layers = getattr(model.config, "num_hidden_layers", None) - tokenizer = AutoTokenizer.from_pretrained(args.model) + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") @@ -99,30 +144,11 @@ async def main(args: argparse.Namespace) -> None: num_skipped_too_long = 0 num_invalid = 0 num_success = 0 - num_total_conversations = min( - len(all_conversations), args.debug_max_num_conversations or len(all_conversations) - ) - for idx, entry in enumerate( - tqdm( - all_conversations[: args.debug_max_num_conversations], - desc="Processing conversations", - total=num_total_conversations, - ) - ): - conversation_id = entry.get("conversation_id", "{:08d}".format(idx)) - conversations = entry["conversations"] - if not conversations or not isinstance(conversations, list): - num_invalid += 1 - continue - - # Tokenize and check length - input_ids = tokenizer.apply_chat_template( - conversations, return_tensors="pt", add_generation_template=False - ) - num_input_tokens = input_ids.shape[1] - if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: - num_skipped_too_long += 1 - continue + pbar = tqdm(total=len(dataset), desc=f"DP#{args.dp_rank} Processing conversations") + + async def dump_hidden_states(idx: int, conversation_id: int, input_ids: torch.Tensor): + nonlocal num_success + nonlocal num_hidden_layers # Get hidden states with torch.inference_mode(): 
@@ -144,9 +170,9 @@ async def main(args: argparse.Namespace) -> None: aux_hidden_states = torch.cat( [hidden_states[i].squeeze(0).cpu() for i in selected_layer_indices], dim=-1 ) - output_hidden_states = outputs.last_hidden_state.squeeze(0).cpu() + output_hidden_states = hidden_states[-1].squeeze(0).cpu() output_file = output_dir / f"{conversation_id}.pt" - num_success += 1 + with open(output_file, "wb") as f: torch.save( { @@ -158,17 +184,47 @@ async def main(args: argparse.Namespace) -> None: f, ) + num_success += 1 + pbar.update(1) + + async def submit_generates(): + nonlocal num_skipped_too_long + nonlocal num_invalid + tasks = [] + idx = 0 + for entry in dataset: + conversation_id = entry.get("conversation_id", entry.get("uuid")) + + conversations = entry["conversations"] + if not conversations or not isinstance(conversations, list): + num_invalid += 1 + continue + + # Tokenize and check length + input_ids = tokenizer.apply_chat_template( + conversations, return_tensors="pt", add_generation_template=False + )["input_ids"] + num_input_tokens = input_ids.shape[1] + if num_input_tokens <= 10 or num_input_tokens > args.max_seq_len: + num_skipped_too_long += 1 + continue + + tasks.append(dump_hidden_states(idx, conversation_id, input_ids)) + # Increment only for valid conversations to match dump file index + idx += 1 + await asyncio.gather(*tasks) + + asyncio.run(submit_generates()) + if num_skipped_too_long > 0: print(f"Skipped {num_skipped_too_long} conversations due to length constraints.") if num_invalid > 0: print(f"Skipped {num_invalid} invalid conversations without proper fields.") - if num_success == num_total_conversations: + if num_success == len(dataset): print(f"Successfully processed all {num_success} conversations.") else: - print( - f"Successfully processed {num_success} out of {num_total_conversations} conversations." - ) + print(f"Successfully processed {num_success} out of {len(dataset)} conversations.") if __name__ == "__main__": diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 1a2a127d2..4c88562b4 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -25,7 +25,7 @@ from megatron.core import InferenceParams, tensor_parallel from megatron.core.dist_checkpointing.mapping import ShardedStateDict from megatron.core.dist_checkpointing.utils import replace_prefix_for_sharding -from megatron.core.extensions.transformer_engine import TENorm +from megatron.core.extensions.transformer_engine import TELinear, TENorm from megatron.core.inference.contexts import StaticInferenceContext from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding @@ -62,7 +62,6 @@ try: from megatron.core.post_training.modelopt.gpt.model_specs import get_gpt_modelopt_spec - from megatron.core.post_training.modelopt.layers import Linear except ImportError: warnings.warn("Fail to import megatron.core.post_training! 
EAGLE feature will be disable!") @@ -391,7 +390,11 @@ def sharded_state_dict( if module is not self.layers: sharded_state_dict.update( sharded_state_dict_default( - module, f"{prefix}{name}.", sharded_offsets, metadata + module, + f"{prefix}{name}.", + sharded_offsets, + metadata, + tp_group=self.tp_group, ) ) @@ -445,13 +448,14 @@ def __init__( self._num_aux_hidden_states if self._num_aux_hidden_states > 0 else 2 ) - # This linear was previously a ColumnParallelLinear. We changed it to a normal linear + # This linear was previously a ColumnParallelLinear. We changed it to a TELinear # since ColumnParallelLinear will have try to gather the input sequence when sequence # parallel is used and does not allow gathering the outputs. with torch.device(device): - self.fc = Linear( + self.fc = TELinear( config.hidden_size * fc_input_size_multiplier, config.hidden_size, + parallel_mode="duplicated", config=config, init_method=(lambda w: None), # not used bias=bias, From f7ce0c3b03f8425f916a37b16f549b21564939e7 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 28 Jan 2026 14:54:53 -0800 Subject: [PATCH 20/23] debug Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 4c88562b4..5438596b7 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -459,6 +459,8 @@ def __init__( config=config, init_method=(lambda w: None), # not used bias=bias, + skip_bias_add=False, + skip_weight_param_allocation=False, ) self.rotary_pos_emb = rotary_pos_emb From b36807a1899b2048bdaea66ac3883652749db19a Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 28 Jan 2026 17:56:36 -0800 Subject: [PATCH 21/23] remove attn mask expansion as it is unnecessary and will cause error when TP>1 Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index 5438596b7..ee7d4ad4f 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1427,13 +1427,6 @@ def _wrapped_forward(self, *args, **kwargs): zero_val = torch.tensor(0.0, device=attention_mask.device) attention_bias = torch.where(attention_mask, mask_val, zero_val) - # Expand head dimension if needed - try: - if attention_bias.dim() == 4 and attention_bias.shape[1] == 1: - attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1) - except Exception: - pass - if q_device is not None and q_dtype is not None: attention_bias = attention_bias.to(device=q_device, dtype=q_dtype) From 8bfd98b7a253a3723a0dafb2e94f659089f49994 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Wed, 28 Jan 2026 18:12:09 -0800 Subject: [PATCH 22/23] formatting Signed-off-by: Ye Yu --- .../collect_hidden_states/compute_hidden_states_hf.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py index 8c06e1162..a3d1681c4 100644 --- a/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py +++ b/examples/speculative_decoding/collect_hidden_states/compute_hidden_states_hf.py @@ -229,4 +229,4 @@ async def 
submit_generates(): if __name__ == "__main__": cli_args = parse_args() - asyncio.run(main(cli_args)) + main(cli_args) From 4c44c1c85787ce177dbd488a2724e4523f6b2959 Mon Sep 17 00:00:00 2001 From: Ye Yu Date: Thu, 29 Jan 2026 09:07:10 -0800 Subject: [PATCH 23/23] remove te_dot_product_attention_with_cp from pseudo_speculative_generate as we disable this function when CP>1 Signed-off-by: Ye Yu --- modelopt/torch/speculative/plugins/megatron_eagle.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index ee7d4ad4f..e37e8f931 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -1317,13 +1317,10 @@ def pseudo_speculative_generate( # [TODO] (chenhany): let the module compute itself eagle_inputs["rotary_pos_emb"] = None - with te_dot_product_attention_with_cp( - eagle_inputs["attention_mask"], self.eagle_config.num_attention_heads - ): - _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( - eagle_inputs, - output_weight, - ) + _, eagle_logits, eagle_next_hidden_states_input = self._eagle_forward( + eagle_inputs, + output_weight, + ) # parallel_logits are only used after the last step if step == steps - 1 and self.eagle_config.parallel_draft_step > 1:
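
The recurring mechanism in this series — temporarily patching `TEDotProductAttention.forward` so that, when context parallelism is active, an arbitrary boolean mask is delivered as an additive `attention_bias` — can be summarized in a short standalone sketch. The helper name `patched_attention_bias`, the `attn_cls` and `cp_size` parameters, and the exact dtype handling below are illustrative assumptions; the committed version lives in `megatron_eagle.py` as `te_dot_product_attention_with_cp`.

from contextlib import contextmanager

import torch


@contextmanager
def patched_attention_bias(attn_cls, attention_mask, cp_size):
    """Deliver a boolean mask ([b, 1, sq, sk], True = blocked) as an additive bias."""
    orig_forward = attn_cls.forward

    def wrapped_forward(self, *args, **kwargs):
        query = args[0] if args else None
        dtype = query.dtype if isinstance(query, torch.Tensor) else torch.float32
        device = query.device if isinstance(query, torch.Tensor) else attention_mask.device
        # fp16/bf16 cannot represent -1e9; -40 is already enough to zero the
        # blocked positions after softmax without producing inf/NaN.
        fill = -40.0 if dtype in (torch.float16, torch.bfloat16) else -1e9
        mask = attention_mask.to(device=device)
        bias = torch.where(
            mask,
            torch.tensor(fill, dtype=dtype, device=device),
            torch.tensor(0.0, dtype=dtype, device=device),
        )
        kwargs["attention_bias"] = bias.contiguous()
        return orig_forward(self, *args, **kwargs)

    if cp_size > 1:
        attn_cls.forward = wrapped_forward
    try:
        yield
    finally:
        attn_cls.forward = orig_forward

Because the patch is applied at class level, every attention layer that runs inside the `with` block receives the bias; that is why patch 12 also switches the EAGLE layer spec back to a causal mask when CP > 1 and carries the multi-step TTT mask only through the bias.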
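
Patches 04, 09 and 10 all follow a single gather → edit → re-shard pattern for building the EAGLE inputs when the sequence is sharded across context-parallel ranks; the transposes exist only because Megatron's `gather_from_sequence_parallel_region` / `scatter_to_sequence_parallel_region` concatenate and split on dim 0. A conceptual plain `torch.distributed` version is sketched below; the contiguous chunking is a simplification (Megatron's causal CP uses load-balanced, non-contiguous shards), and `cp_group` is assumed to be the context-parallel process group.

import torch
import torch.distributed as dist


def shift_ids_across_cp(local_ids, cp_group):
    """Shift input ids left by one over the *global* sequence; local_ids is [b, s_local]."""
    cp_size = dist.get_world_size(group=cp_group)
    cp_rank = dist.get_rank(group=cp_group)

    # 1) Gather the full sequence from all CP ranks.
    shards = [torch.empty_like(local_ids) for _ in range(cp_size)]
    dist.all_gather(shards, local_ids.contiguous(), group=cp_group)
    full_ids = torch.cat(shards, dim=1)  # [b, s_local * cp_size]

    # 2) Edit on the full sequence: drop the first token, pad the tail.
    padding = torch.zeros(
        (full_ids.shape[0], 1), dtype=full_ids.dtype, device=full_ids.device
    )
    shifted = torch.cat((full_ids[:, 1:], padding), dim=-1)

    # 3) Re-shard: keep only this rank's slice of the sequence.
    return shifted.chunk(cp_size, dim=1)[cp_rank].contiguous()

The same pattern is applied to the attention mask, which is why patch 10 transposes [b, 1, sq, sk] so the query-sequence dimension comes first before gathering and scattering.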
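
The hidden-states collection script reworked in patch 19 boils down to a shard-then-filter pattern so that independent SLURM tasks can split the corpus and safely resume. A condensed sketch with `datasets` follows; the field names (`conversation_id`, `uuid`) mirror that script, while the helper name and everything else here are illustrative assumptions.

from pathlib import Path

from datasets import load_dataset


def build_work_list(input_data: Path, output_dir: Path, dp_rank: int, dp_world_size: int):
    # Accept a single JSONL file or a directory of JSONL files.
    data_files = str(input_data) if input_data.is_file() else f"{input_data}/*.jsonl"
    dataset = load_dataset("json", data_files=data_files, split="train")

    # Each data-parallel rank (SLURM task) owns a disjoint shard of the conversations.
    if dp_world_size > 1:
        dataset = dataset.shard(num_shards=dp_world_size, index=dp_rank)

    # Resume support: skip conversations whose hidden states were already dumped.
    def keep(entry):
        conversation_id = entry.get("conversation_id", entry.get("uuid"))
        return not (output_dir / f"{conversation_id}.pt").exists()

    return dataset.filter(keep)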