
Conversation

@yeyu-nvidia (Contributor) commented Jan 26, 2026

What does this PR do?

New feature

Overview:
This PR implements a context manager that injects attn_mask into TEDotProductAttention as attn_bias, so that EAGLE training with an arbitrary attention mask can be enabled under context parallelism.

Usage

Set CP > 1 in https://github.com/NVIDIA/Megatron-LM/blob/main/examples/post_training/modelopt/finetune.sh

Testing

Tested on DSR1 Llama 8B, CP=1 -> CP=2:
  • Memory: 38854 MB -> 28050 MB
  • MT-Bench AL: 2.26 -> 2.31

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes/No
  • Did you write any new necessary tests?: Yes/No
  • Did you add or update any necessary documentation?: Yes/No
  • Did you update Changelog?: Yes/No

Additional Information

Summary by CodeRabbit

New Features

  • Context parallelism support added for Eagle speculative decoding with HuggingFace and Megatron Core models.
  • Model checkpoint loading now passes trust_remote_code=True so checkpoints that require remote code execution can be loaded.


@yeyu-nvidia yeyu-nvidia requested a review from a team as a code owner January 26, 2026 18:28
@yeyu-nvidia yeyu-nvidia requested a review from ChenhanYu January 26, 2026 18:28
coderabbitai bot commented Jan 26, 2026

📝 Walkthrough

Introduces context parallelism support in Eagle speculative decoding for Megatron Core and HuggingFace models. Changes include a new context manager for parallelism-aware attention handling, the addition of a trust_remote_code parameter in model loading, and distribution-aware input reshaping in the speculative decoding pipeline.
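
For reference while reading the walkthrough: the new context manager works by temporarily swapping out a class method and restoring it on exit. Below is a minimal, self-contained sketch of that monkey-patching pattern; the names FakeAttention and inject_attention_bias are illustrative stand-ins and are not part of the PR.

# Minimal sketch of the monkey-patching pattern the context manager uses:
# temporarily replace a class's forward so every call inside the `with` block
# receives an extra keyword argument, then restore the original on exit.
from contextlib import contextmanager


class FakeAttention:
    def forward(self, query, **kwargs):
        return {"query": query, **kwargs}


@contextmanager
def inject_attention_bias(cls, attention_bias):
    orig_forward = cls.forward

    def _wrapped_forward(self, *args, **kwargs):
        kwargs["attention_bias"] = attention_bias  # injected for every call
        return orig_forward(self, *args, **kwargs)

    cls.forward = _wrapped_forward
    try:
        yield
    finally:
        cls.forward = orig_forward  # always restore the original method


if __name__ == "__main__":
    attn = FakeAttention()
    with inject_attention_bias(FakeAttention, attention_bias="bias_tensor"):
        print(attn.forward("q"))  # {'query': 'q', 'attention_bias': 'bias_tensor'}
    print(attn.forward("q"))      # {'query': 'q'} -- patch removed on exit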

Changes

  • Documentation & Examples — CHANGELOG.rst, examples/speculative_decoding/main.py: changelog entry documenting the context parallelism feature; model loading now passes trust_remote_code=True to AutoModelForCausalLM and AutoTokenizer.
  • Core Implementation — modelopt/torch/speculative/plugins/megatron_eagle.py: new TEDotProductAttentionCP context manager enables context-parallel aware attention computation via attention_bias injection; conditional transformations for rotary_seq_len scaling; distribution-aware input reshaping and scattering for attention masks and input_ids; modified Eagle wiring to set use_arbitrary_attention_mask based on context parallel world size.

Sequence Diagram

sequenceDiagram
    participant Input as Input Preparation
    participant CP as Context Parallel Check
    participant TEAttn as TEDotProductAttentionCP
    participant Eagle as Eagle Forward Pass
    participant Output as Attention Output

    Input->>CP: Reshape & scatter attention_mask<br/>and input_ids across CP regions
    activate CP
    
    alt Context Parallel World Size > 1
        CP->>TEAttn: Activate context manager
        activate TEAttn
        
        Note over TEAttn: Convert attention_mask to attention_bias<br/>Expand to match num_heads<br/>Cast to device/dtype
        
        TEAttn->>Eagle: Inject attention_bias via<br/>monkey-patching
        activate Eagle
        
        Note over Eagle: Forward pass uses<br/>attention_bias instead of<br/>attention_mask
        
        Eagle->>Output: Compute parallel-aware<br/>attention
        deactivate Eagle
        deactivate TEAttn
    else Context Parallel World Size == 1
        CP->>Eagle: Use standard attention_mask<br/>(no TEDotProductAttentionCP)
        activate Eagle
        Eagle->>Output: Compute attention
        deactivate Eagle
    end
    deactivate CP
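The conversion step shown in the diagram ("Convert attention_mask to attention_bias, expand to match num_heads, cast to device/dtype") is the core of the new context manager. Below is a minimal sketch of that conversion; shapes, the helper name, and the example values are illustrative, while the fill values (-1e9, or -40.0 for half precision) mirror the PR's implementation quoted further down this page.

# Sketch of the mask-to-bias conversion step from the diagram above.
# True in the boolean mask means "do not attend"; the bias places a large
# negative number at those positions so softmax suppresses them.
import torch


def mask_to_attention_bias(attention_mask: torch.Tensor,
                           num_attention_heads: int,
                           dtype: torch.dtype) -> torch.Tensor:
    fill = -40.0 if dtype in (torch.float16, torch.bfloat16) else -1e9
    bias = torch.where(
        attention_mask,
        torch.tensor(fill, device=attention_mask.device),
        torch.tensor(0.0, device=attention_mask.device),
    )
    if bias.dim() == 4 and bias.shape[1] == 1:          # [b, 1, sq, sk]
        bias = bias.expand(-1, num_attention_heads, -1, -1)
    return bias.to(dtype=dtype).contiguous()


if __name__ == "__main__":
    mask = torch.triu(torch.ones(4, 4), diagonal=1).bool()[None, None]
    bias = mask_to_attention_bias(mask, num_attention_heads=8, dtype=torch.bfloat16)
    print(bias.shape)  # torch.Size([1, 8, 4, 4])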

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 66.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)
  • Title check ✅ Passed — The title 'Context parallelism for Megatron core models' directly and clearly summarizes the main change: introducing context parallelism support for Megatron core models, which is the primary focus of the PR as evidenced by the majority of changes in megatron_eagle.py.
  • Description check ✅ Passed — Check skipped; CodeRabbit's high-level summary is enabled.


✨ Finishing touches
  • 📝 Generate docstrings


coderabbitai bot left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)

850-907: Guard context parallel gather/scatter operations when CP is disabled.

The _get_eagle_module_inputs method unconditionally calls gather_from_sequence_parallel_region() and scatter_to_sequence_parallel_region() with get_context_parallel_group() at lines 855–856, 871–872, 880–881, and 904–905. When context parallelism is disabled (get_context_parallel_world_size() == 1), this group may be None, causing errors or hangs. Gate these operations behind if get_context_parallel_world_size() > 1: and skip the gather/scatter when disabled, following the pattern already established elsewhere in the file.

🔧 Suggested fix (guard CP ops)
-        input_ids = input_ids.clone().transpose(0, 1).contiguous()
-        input_ids = gather_from_sequence_parallel_region(
-            input_ids, group=get_context_parallel_group()
-        )
-        input_ids = input_ids.transpose(0, 1).contiguous()
+        input_ids = input_ids.clone().transpose(0, 1).contiguous()
+        if get_context_parallel_world_size() > 1:
+            input_ids = gather_from_sequence_parallel_region(
+                input_ids, group=get_context_parallel_group()
+            )
+        input_ids = input_ids.transpose(0, 1).contiguous()
...
-        padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
-        padded_input_ids = scatter_to_sequence_parallel_region(
-            padded_input_ids, group=get_context_parallel_group()
-        )
-        padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
+        padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
+        if get_context_parallel_world_size() > 1:
+            padded_input_ids = scatter_to_sequence_parallel_region(
+                padded_input_ids, group=get_context_parallel_group()
+            )
+        padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
...
-        attn_mask = attn_mask.transpose(0, 2).contiguous()
-        attn_mask = gather_from_sequence_parallel_region(
-            attn_mask, group=get_context_parallel_group()
-        )
-        attn_mask = attn_mask.transpose(0, 2).contiguous()
+        attn_mask = attn_mask.transpose(0, 2).contiguous()
+        if get_context_parallel_world_size() > 1:
+            attn_mask = gather_from_sequence_parallel_region(
+                attn_mask, group=get_context_parallel_group()
+            )
+        attn_mask = attn_mask.transpose(0, 2).contiguous()
...
-        attn_mask = attn_mask.transpose(0, 2).contiguous()
-        attn_mask = scatter_to_sequence_parallel_region(
-            attn_mask, group=get_context_parallel_group()
-        )
-        eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous()
+        attn_mask = attn_mask.transpose(0, 2).contiguous()
+        if get_context_parallel_world_size() > 1:
+            attn_mask = scatter_to_sequence_parallel_region(
+                attn_mask, group=get_context_parallel_group()
+            )
+        eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous()
🤖 Fix all issues with AI agents
In `@examples/speculative_decoding/main.py`:
- Around line 165-168: Add a new boolean field trust_remote_code: bool = False
to the TrainingArguments dataclass and replace all hardcoded
trust_remote_code=True calls to use training_args.trust_remote_code instead;
specifically, update the calls to
transformers.AutoModelForCausalLM.from_pretrained (the model variable),
transformers.AutoTokenizer.from_pretrained (the tokenizer variable),
transformers.AutoConfig.from_pretrained, and the later
transformers.AutoTokenizer.from_pretrained that use model_name_or_path so they
pass training_args.trust_remote_code.

In `@modelopt/torch/speculative/plugins/megatron_eagle.py`:
- Around line 1383-1464: The wrapper in TEDotProductAttentionCP constructs
attention_bias and injects it into kwargs but doesn't remove the original
boolean attention_mask, which can cause both to be passed into
TEDotProductAttention.forward; update the wrapper (_wrapped_forward) to remove
the original mask by calling kwargs.pop("attention_mask", None) immediately
after assigning kwargs["attention_bias"] = attention_bias so only attention_bias
is passed to cls.forward; keep the rest of the logic (cloning args, device/dtype
conversion, and restoring cls.forward) unchanged.

Comment on lines +165 to 168
model = transformers.AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype="auto", trust_remote_code=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, check the exact file and lines in question
cd /repo && head -n 175 examples/speculative_decoding/main.py | tail -n 20

Repository: NVIDIA/Model-Optimizer

Length of output: 121


🏁 Script executed:

# Search for all trust_remote_code usage in the file
rg "trust_remote_code" examples/speculative_decoding/main.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 1068


🏁 Script executed:

# Check if there's a training_args or argument parser structure
rg "argparse|ArgumentParser|TrainingArguments" examples/speculative_decoding/main.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 431


🏁 Script executed:

# Look at the overall structure of the file to understand the context
wc -l examples/speculative_decoding/main.py
head -n 50 examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 2045


🏁 Script executed:

# Get full TrainingArguments definition
rg "@dataclass" examples/speculative_decoding/main.py -A 20 | head -n 80

Repository: NVIDIA/Model-Optimizer

Length of output: 2794


🏁 Script executed:

# Check if trust_remote_code is referenced anywhere in TrainingArguments or passed as argument
rg "trust_remote_code" examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 373


🏁 Script executed:

# Look at the actual context around lines 165-168 to understand the flow
sed -n '155,175p' examples/speculative_decoding/main.py

Repository: NVIDIA/Model-Optimizer

Length of output: 987


🏁 Script executed:

# Check if training_args is available at that point in the code
sed -n '1,170p' examples/speculative_decoding/main.py | grep -E "training_args|def train"

Repository: NVIDIA/Model-Optimizer

Length of output: 706


Gate trust_remote_code behind an explicit flag.
Lines 165, 167, 173, and 186 all hardcode trust_remote_code=True for model/tokenizer/config loading, which automatically executes remote code if the checkpoint points to an untrusted repository. Add a trust_remote_code: bool = False field to the TrainingArguments dataclass and use training_args.trust_remote_code in all four locations.

Affected locations
  • Lines 165-167: checkpoint model and tokenizer loading
  • Lines 173, 186: model config and tokenizer loading with model_name_or_path
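
A minimal sketch of the gating the comment describes. The dataclass below is a simplified stand-in for the script's actual TrainingArguments; only the new field and the way it is threaded into from_pretrained() follow the suggestion, everything else is illustrative.

# Sketch: expose trust_remote_code as an opt-in training argument instead of
# hardcoding True, and thread it through the HF loading calls.
from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Allow execution of code shipped with the checkpoint."},
    )


def load_model_and_tokenizer(checkpoint: str, training_args: TrainingArguments):
    model = transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype="auto", trust_remote_code=training_args.trust_remote_code
    )
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        checkpoint, trust_remote_code=training_args.trust_remote_code
    )
    return model, tokenizer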
🤖 Prompt for AI Agents
In `@examples/speculative_decoding/main.py` around lines 165 - 168, Add a new
boolean field trust_remote_code: bool = False to the TrainingArguments dataclass
and replace all hardcoded trust_remote_code=True calls to use
training_args.trust_remote_code instead; specifically, update the calls to
transformers.AutoModelForCausalLM.from_pretrained (the model variable),
transformers.AutoTokenizer.from_pretrained (the tokenizer variable),
transformers.AutoConfig.from_pretrained, and the later
transformers.AutoTokenizer.from_pretrained that use model_name_or_path so they
pass training_args.trust_remote_code.

Comment on lines 1383 to 1464
@contextmanager
def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int):
    """Context manager for TEDotProductAttention with context parallelism.

    Context manager that temporarily replace `attention_bias`
    with `attention_mask` for `TEDotProductAttention.forward` calls across the process
    if context parallel is used.

    Any call to `TEDotProductAttention.forward` (including calls originating
    from other modules) inside the context will receive `attention_bias=attention_mask`
    if context parallelism is used.

    Example:
        with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads):
            outputs = model(...)

    Note: This monkey-patches the class method and restores it on exit.
    """
    from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls

    orig_forward = cls.forward

    def _wrapped_forward(self, *args, **kwargs):
        # Build attention_bias from the boolean attention_mask and ensure
        # it's a fresh, detached tensor on the query's device/dtype to
        # avoid shared-storage in-place modifications that break autograd.
        query = args[0] if len(args) > 0 else None
        if isinstance(query, torch.Tensor):
            q_device = query.device
            q_dtype = query.dtype
        else:
            q_device = None
            q_dtype = None

        mask_fill = -1e9
        if q_dtype in (torch.float16, torch.bfloat16):
            mask_fill = -40.0
        mask_val = torch.tensor(mask_fill, device=attention_mask.device)
        zero_val = torch.tensor(0.0, device=attention_mask.device)
        attention_bias = torch.where(attention_mask, mask_val, zero_val)

        # Expand head dimension if needed
        try:
            if attention_bias.dim() == 4 and attention_bias.shape[1] == 1:
                attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1)
        except Exception:
            pass

        if q_device is not None and q_dtype is not None:
            attention_bias = attention_bias.to(device=q_device, dtype=q_dtype)

        attention_bias = attention_bias.clone().detach().contiguous()
        kwargs["attention_bias"] = attention_bias

        # Defensive clone of query/key/value positional tensors to avoid
        # passing views into the fused attention kernel that might be
        # modified in-place during backward.
        if len(args) >= 1:
            original_args = args
            new_args = list(original_args)
            try:
                for i in range(min(3, len(new_args))):
                    if isinstance(new_args[i], torch.Tensor):
                        if not new_args[i].is_contiguous():
                            new_args[i] = new_args[i].contiguous()
                        new_args[i] = new_args[i].clone()

                if any(x is None for x in new_args):
                    args = original_args
                else:
                    args = tuple(new_args)
            except Exception:
                args = original_args

        return orig_forward(self, *args, **kwargs)

    if get_context_parallel_world_size() > 1:
        cls.forward = _wrapped_forward
    try:
        yield
    finally:
        cls.forward = orig_forward

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's find where TEDotProductAttentionCP is used
rg "TEDotProductAttentionCP" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 528


🏁 Script executed:

# Let's also check the file structure and find related code
fd "megatron_eagle" --type f

Repository: NVIDIA/Model-Optimizer

Length of output: 118


🏁 Script executed:

# Check what's around the context manager to understand the usage pattern
cat -n "modelopt/torch/speculative/plugins/megatron_eagle.py" | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 5319


🏁 Script executed:

# Search for any imports or usages of attention_mask/attention_bias in the same file
rg "attention_mask|attention_bias" modelopt/torch/speculative/plugins/megatron_eagle.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 7105


🏁 Script executed:

# Let's check if there's documentation or comments about TE's API
rg "TEDotProductAttention" --type py -B 3 -A 3

Repository: NVIDIA/Model-Optimizer

Length of output: 13792


🏁 Script executed:

# Find _eagle_forward definition
rg "_eagle_forward" modelopt/torch/speculative/plugins/megatron_eagle.py -B 2 -A 10

Repository: NVIDIA/Model-Optimizer

Length of output: 1792


🏁 Script executed:

# Search for SelfAttention or transformer layer that calls TEDotProductAttention
rg "SelfAttention|core_attention" modelopt/torch/speculative/plugins/megatron_eagle.py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 625


🏁 Script executed:

# Look for any documentation on TE's TEDotProductAttention in megatron-core or comments
rg "attention_bias|attention_mask" modelopt/torch/speculative/plugins/megatron_eagle.py -B 3 -A 3 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 3098


🏁 Script executed:

# Check what parameters get passed through the decoder/forward calls
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '1200,1250p'

Repository: NVIDIA/Model-Optimizer

Length of output: 2730


🏁 Script executed:

# Check the transformer layer spec and SelfAttention configuration
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '250,350p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4455


🏁 Script executed:

# Search for transformer decoder forward or layer forward to see what parameters are passed
rg "def forward.*attention_mask" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the actual forward signature of transformer layers in megatron
rg "class.*SelfAttention" --type py -A 10 | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 1872


🏁 Script executed:

# Check if there's any handling of attention_mask vs attention_bias in the layer
rg "attention_bias" --type py -B 5 -A 5 | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 9320


🏁 Script executed:

# Search for the actual calling pattern - what params does transformer layer pass?
rg "core_attention" --type py -B 3 -A 3 | grep -A 3 "forward\|__call__" | head -50

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's look at the transformer layer's forward method more directly
rg "hidden_states.*attention_mask" --type py -B 2 -A 5 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look at the get_gpt_modelopt_spec to understand how attention is configured
rg "get_gpt_modelopt_spec" --type py -A 20 | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 5723


🏁 Script executed:

# Search for how SelfAttention layer passes parameters to core_attention
rg "class SelfAttention" --type py -A 50 | grep -A 20 "def forward"

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Look for any documentation on whether TE accepts both attention_mask and attention_bias
web_search
TE TEDotProductAttention forward signature attention_mask attention_bias parameters

Repository: NVIDIA/Model-Optimizer

Length of output: 157


🏁 Script executed:

# Let's check the actual usage - what kwargs are being passed when the context manager is active
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '1100,1180p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4395


🏁 Script executed:

# Look at how the spec is configured when context parallel is used
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '840,870p'

Repository: NVIDIA/Model-Optimizer

Length of output: 1558


🏁 Script executed:

# Search in the codebase for SelfAttention.forward to see what it passes to core_attention
rg "core_attention.*forward|core_attention\(" --type py -B 5 -A 3 | head -100

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if the transformer layer spec shows how the attention is wired
rg "self_attention.*=|core_attention.*=" --type py -B 2 -A 2 | head -80

Repository: NVIDIA/Model-Optimizer

Length of output: 7121


🏁 Script executed:

# Let's look at megatron-core imports to understand the structure
rg "from megatron.core" modelopt/torch/speculative/plugins/megatron_eagle.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 1429


🏁 Script executed:

# Search for SelfAttention class implementation or usage
rg "class SelfAttention" --type py -B 2 -A 80 | head -120

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Check if there's test code that shows how attention parameters flow
rg "attention_mask.*attention_bias|attention_bias.*attention_mask" --type py

Repository: NVIDIA/Model-Optimizer

Length of output: 589


🏁 Script executed:

# Look for any PR comments or TODOs about this
rg "TODO.*attention|FIXME.*attention|XXX.*attention" --type py -B 2 -A 2

Repository: NVIDIA/Model-Optimizer

Length of output: 48


🏁 Script executed:

# Let's check the actual spec configuration to understand what gets passed
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '320,380p'

Repository: NVIDIA/Model-Optimizer

Length of output: 3028


🏁 Script executed:

# Check if there are any error handling or assertions around attention parameters
rg "assert.*attention|raise.*attention" --type py -B 1 -A 1 | head -40

Repository: NVIDIA/Model-Optimizer

Length of output: 2881


🏁 Script executed:

# Look for the actual eagle_module instantiation to understand its configuration
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '950,1050p'

Repository: NVIDIA/Model-Optimizer

Length of output: 4748


🏁 Script executed:

# Search for how SelfAttention calls core_attention to see if it passes attention_mask
rg "def forward" modelopt/torch/speculative/plugins/megatron_eagle.py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 125


🏁 Script executed:

# Look at the comment more carefully - what does "instead" imply?
cat -n modelopt/torch/speculative/plugins/megatron_eagle.py | sed -n '1383,1400p'

Repository: NVIDIA/Model-Optimizer

Length of output: 962


🏁 Script executed:

# Check if there's any existing pattern of removing parameters from kwargs
rg "kwargs.pop|del.*kwargs" --type py | head -20

Repository: NVIDIA/Model-Optimizer

Length of output: 2386


🏁 Script executed:

# Verify the context - is this confirmed as a needed fix by looking at similar patterns
rg "attention_mask.*kwargs|kwargs.*attention" --type py -B 2 -A 2 | head -60

Repository: NVIDIA/Model-Optimizer

Length of output: 5499


Remove attention_mask from kwargs when injecting attention_bias.

The context manager converts the boolean attention_mask to a numeric attention_bias and injects it into kwargs (line 1437). However, the original attention_mask may still be present in kwargs when passed through the transformer layer, causing both parameters to reach TEDotProductAttention.forward simultaneously. The docstring explicitly states the intent is to "replace" the mask (not supplement it), implying TE expects one parameter, not both. If TE prioritizes or rejects duplicate specifications, the injected bias may be ignored or cause errors. Add kwargs.pop("attention_mask", None) after line 1437 to ensure clean parameter passing.

Recommended fix
         attention_bias = attention_bias.clone().detach().contiguous()
         kwargs["attention_bias"] = attention_bias
+        kwargs.pop("attention_mask", None)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change (the full function as it would read after the fix; the only difference from the original is the added kwargs.pop("attention_mask", None) line)

@contextmanager
def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int):
    """Context manager for TEDotProductAttention with context parallelism.

    Context manager that temporarily replace `attention_bias`
    with `attention_mask` for `TEDotProductAttention.forward` calls across the process
    if context parallel is used.

    Any call to `TEDotProductAttention.forward` (including calls originating
    from other modules) inside the context will receive `attention_bias=attention_mask`
    if context parallelism is used.

    Example:
        with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads):
            outputs = model(...)

    Note: This monkey-patches the class method and restores it on exit.
    """
    from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls

    orig_forward = cls.forward

    def _wrapped_forward(self, *args, **kwargs):
        # Build attention_bias from the boolean attention_mask and ensure
        # it's a fresh, detached tensor on the query's device/dtype to
        # avoid shared-storage in-place modifications that break autograd.
        query = args[0] if len(args) > 0 else None
        if isinstance(query, torch.Tensor):
            q_device = query.device
            q_dtype = query.dtype
        else:
            q_device = None
            q_dtype = None

        mask_fill = -1e9
        if q_dtype in (torch.float16, torch.bfloat16):
            mask_fill = -40.0
        mask_val = torch.tensor(mask_fill, device=attention_mask.device)
        zero_val = torch.tensor(0.0, device=attention_mask.device)
        attention_bias = torch.where(attention_mask, mask_val, zero_val)

        # Expand head dimension if needed
        try:
            if attention_bias.dim() == 4 and attention_bias.shape[1] == 1:
                attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1)
        except Exception:
            pass

        if q_device is not None and q_dtype is not None:
            attention_bias = attention_bias.to(device=q_device, dtype=q_dtype)

        attention_bias = attention_bias.clone().detach().contiguous()
        kwargs["attention_bias"] = attention_bias
        kwargs.pop("attention_mask", None)

        # Defensive clone of query/key/value positional tensors to avoid
        # passing views into the fused attention kernel that might be
        # modified in-place during backward.
        if len(args) >= 1:
            original_args = args
            new_args = list(original_args)
            try:
                for i in range(min(3, len(new_args))):
                    if isinstance(new_args[i], torch.Tensor):
                        if not new_args[i].is_contiguous():
                            new_args[i] = new_args[i].contiguous()
                        new_args[i] = new_args[i].clone()

                if any(x is None for x in new_args):
                    args = original_args
                else:
                    args = tuple(new_args)
            except Exception:
                args = original_args

        return orig_forward(self, *args, **kwargs)

    if get_context_parallel_world_size() > 1:
        cls.forward = _wrapped_forward
    try:
        yield
    finally:
        cls.forward = orig_forward
🤖 Prompt for AI Agents
In `@modelopt/torch/speculative/plugins/megatron_eagle.py` around lines 1383 -
1464, The wrapper in TEDotProductAttentionCP constructs attention_bias and
injects it into kwargs but doesn't remove the original boolean attention_mask,
which can cause both to be passed into TEDotProductAttention.forward; update the
wrapper (_wrapped_forward) to remove the original mask by calling
kwargs.pop("attention_mask", None) immediately after assigning
kwargs["attention_bias"] = attention_bias so only attention_bias is passed to
cls.forward; keep the rest of the logic (cloning args, device/dtype conversion,
and restoring cls.forward) unchanged.

codecov bot commented Jan 26, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 73.84%. Comparing base (770962b) to head (4c44c1c).
⚠️ Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #818      +/-   ##
==========================================
+ Coverage   73.82%   73.84%   +0.01%     
==========================================
  Files         193      193              
  Lines       19745    19745              
==========================================
+ Hits        14577    14580       +3     
+ Misses       5168     5165       -3     

☔ View full report in Codecov by Sentry.

attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
# [b, 1, sq, sk] -> [sq, 1, b, sk]
attn_mask = attn_mask.transpose(0, 2).contiguous()
attn_mask = gather_from_sequence_parallel_region(
Collaborator comment:
I assume that attn_mask here has shape [sq*cp, 1, b, sk]. My understanding is that sk = ttt_step*sq*cp? So the attn_mask is O(ttt_step * b * sq^2 * cp^2)?

@yeyu-nvidia (Author) replied Jan 28, 2026:
sk = sq*cp. ttt_step will be introduced later in set_multi_step_attention_mask.
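
To make the scaling in this thread concrete, a small back-of-the-envelope sketch; the concrete numbers are placeholders, not values from the PR.

# After the gather the mask is [sq*cp, 1, b, sk] with sk = sq*cp, and
# set_multi_step_attention_mask later multiplies sk by ttt_step, so the
# element count grows as ttt_step * b * (sq * cp)**2.
def full_mask_elements(sq_per_rank: int, cp: int, b: int, ttt_step: int) -> int:
    s = sq_per_rank * cp           # gathered sequence length
    sk = ttt_step * s              # after set_multi_step_attention_mask
    return s * 1 * b * sk          # [s, 1, b, sk]


if __name__ == "__main__":
    n = full_mask_elements(sq_per_rank=2048, cp=2, b=1, ttt_step=2)
    print(f"{n:,} bool elements (~{n / 2**20:.0f} MiB as torch.bool)")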

attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step)
# [b, 1, sq, sk] -> [sq, 1, b, sk]
attn_mask = attn_mask.transpose(0, 2).contiguous()
attn_mask = scatter_to_sequence_parallel_region(
Collaborator comment:

I assume scatter is from [cp*sq, 1, b, sk] back to [sq, 1, b, sk]?

@yeyu-nvidia (Author) replied:
yes



@contextmanager
def te_dot_product_attention_with_cp(attention_mask: torch.Tensor, num_attention_heads: int):
Collaborator comment:

I think I need a walkthrough here.

… patch

Signed-off-by: Ye Yu <yeyu@nvidia.com>
…_size

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…lel region

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
… gather_from_sequence_parallel_region and change the group

Signed-off-by: Ye Yu <yeyu@nvidia.com>
…mention so we need to transpose tensors first

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…when TP>1

Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…ate as we disable this function when CP>1

Signed-off-by: Ye Yu <yeyu@nvidia.com>
@yeyu-nvidia yeyu-nvidia enabled auto-merge (squash) January 29, 2026 18:15
@yeyu-nvidia yeyu-nvidia merged commit 81b67dd into main Jan 29, 2026
37 checks passed
@yeyu-nvidia yeyu-nvidia deleted the yeyu/megatron_cp branch January 29, 2026 19:47