Context parallelism for Megatron core models #818
Conversation
📝 Walkthrough
Introduces context parallelism support in Eagle speculative decoding for Megatron Core and HuggingFace models. Changes include a new context manager for parallel-aware attention handling.
Sequence Diagram
sequenceDiagram
participant Input as Input Preparation
participant CP as Context Parallel Check
participant TEAttn as TEDotProductAttentionCP
participant Eagle as Eagle Forward Pass
participant Output as Attention Output
Input->>CP: Reshape & scatter attention_mask<br/>and input_ids across CP regions
activate CP
alt Context Parallel World Size > 1
CP->>TEAttn: Activate context manager
activate TEAttn
Note over TEAttn: Convert attention_mask to attention_bias<br/>Expand to match num_heads<br/>Cast to device/dtype
TEAttn->>Eagle: Inject attention_bias via<br/>monkey-patching
activate Eagle
Note over Eagle: Forward pass uses<br/>attention_bias instead of<br/>attention_mask
Eagle->>Output: Compute parallel-aware<br/>attention
deactivate Eagle
deactivate TEAttn
else Context Parallel World Size == 1
CP->>Eagle: Use standard attention_mask<br/>(no TEDotProductAttentionCP)
activate Eagle
Eagle->>Output: Compute attention
deactivate Eagle
end
deactivate CP
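The conversion the diagram's note describes can be sketched roughly as below; this is a simplification (not the PR's exact code, which appears further down in the review) and assumes a boolean [b, 1, sq, sk] mask marking blocked positions.

```python
import torch

# Rough sketch of the mask -> bias conversion: blocked positions get a large
# negative additive bias, then the head dimension is broadcast and the tensor
# is cast to the query's device/dtype.
def mask_to_bias(attention_mask: torch.Tensor, num_heads: int,
                 dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    fill = -40.0 if dtype in (torch.float16, torch.bfloat16) else -1e9
    bias = torch.where(
        attention_mask,
        torch.tensor(fill, device=attention_mask.device),
        torch.tensor(0.0, device=attention_mask.device),
    )
    if bias.dim() == 4 and bias.shape[1] == 1:
        bias = bias.expand(-1, num_heads, -1, -1)  # broadcast across attention heads
    return bias.to(device=device, dtype=dtype).contiguous()
```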
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/speculative/plugins/megatron_eagle.py (1)
850-907: Guard context parallel gather/scatter operations when CP is disabled.
The `_get_eagle_module_inputs` method unconditionally calls `gather_from_sequence_parallel_region()` and `scatter_to_sequence_parallel_region()` with `get_context_parallel_group()` at lines 855–856, 871–872, 880–881, and 904–905. When context parallelism is disabled (`get_context_parallel_world_size() == 1`), this group may be `None`, causing errors or hangs. Gate these operations behind `if get_context_parallel_world_size() > 1:` and skip the gather/scatter when disabled, following the pattern already established elsewhere in the file.
🔧 Suggested fix (guard CP ops)
- input_ids = input_ids.clone().transpose(0, 1).contiguous()
- input_ids = gather_from_sequence_parallel_region(
-     input_ids, group=get_context_parallel_group()
- )
- input_ids = input_ids.transpose(0, 1).contiguous()
+ input_ids = input_ids.clone().transpose(0, 1).contiguous()
+ if get_context_parallel_world_size() > 1:
+     input_ids = gather_from_sequence_parallel_region(
+         input_ids, group=get_context_parallel_group()
+     )
+ input_ids = input_ids.transpose(0, 1).contiguous()
  ...
- padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
- padded_input_ids = scatter_to_sequence_parallel_region(
-     padded_input_ids, group=get_context_parallel_group()
- )
- padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
+ padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
+ if get_context_parallel_world_size() > 1:
+     padded_input_ids = scatter_to_sequence_parallel_region(
+         padded_input_ids, group=get_context_parallel_group()
+     )
+ padded_input_ids = padded_input_ids.transpose(0, 1).contiguous()
  ...
- attn_mask = attn_mask.transpose(0, 2).contiguous()
- attn_mask = gather_from_sequence_parallel_region(
-     attn_mask, group=get_context_parallel_group()
- )
- attn_mask = attn_mask.transpose(0, 2).contiguous()
+ attn_mask = attn_mask.transpose(0, 2).contiguous()
+ if get_context_parallel_world_size() > 1:
+     attn_mask = gather_from_sequence_parallel_region(
+         attn_mask, group=get_context_parallel_group()
+     )
+ attn_mask = attn_mask.transpose(0, 2).contiguous()
  ...
- attn_mask = attn_mask.transpose(0, 2).contiguous()
- attn_mask = scatter_to_sequence_parallel_region(
-     attn_mask, group=get_context_parallel_group()
- )
- eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous()
+ attn_mask = attn_mask.transpose(0, 2).contiguous()
+ if get_context_parallel_world_size() > 1:
+     attn_mask = scatter_to_sequence_parallel_region(
+         attn_mask, group=get_context_parallel_group()
+     )
+ eagle_inputs["attention_mask"] = attn_mask.transpose(0, 2).contiguous()
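The same guard can be factored into a small wrapper. The sketch below is hypothetical (not part of the diff) and assumes the import paths and the group-aware signature that megatron_eagle.py already uses:

```python
# Hypothetical helper illustrating the guard pattern suggested above.
from megatron.core.parallel_state import (
    get_context_parallel_group,
    get_context_parallel_world_size,
)
from megatron.core.tensor_parallel import gather_from_sequence_parallel_region


def gather_seq_if_cp(tensor):
    """Gather `tensor` along dim 0 across the CP group, or no-op when CP is disabled."""
    if get_context_parallel_world_size() > 1:
        tensor = gather_from_sequence_parallel_region(
            tensor, group=get_context_parallel_group()
        )
    return tensor
```

A symmetric wrapper around scatter_to_sequence_parallel_region would cover the reverse direction.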
model = transformers.AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype="auto", trust_remote_code=True
)
tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
Gate trust_remote_code behind an explicit flag.
Lines 165, 167, 173, and 186 all hardcode trust_remote_code=True for model/tokenizer/config loading, which automatically executes remote code if the checkpoint points to an untrusted repository. Add a trust_remote_code: bool = False field to the TrainingArguments dataclass and use training_args.trust_remote_code in all four locations.
Affected locations
- Line 165-167: checkpoint model and tokenizer loading
- Line 173, 186: model config and tokenizer loading with model_name_or_path
🤖 Prompt for AI Agents
In `@examples/speculative_decoding/main.py` around lines 165 - 168, Add a new
boolean field trust_remote_code: bool = False to the TrainingArguments dataclass
and replace all hardcoded trust_remote_code=True calls to use
training_args.trust_remote_code instead; specifically, update the calls to
transformers.AutoModelForCausalLM.from_pretrained (the model variable),
transformers.AutoTokenizer.from_pretrained (the tokenizer variable),
transformers.AutoConfig.from_pretrained, and the later
transformers.AutoTokenizer.from_pretrained that use model_name_or_path so they
pass training_args.trust_remote_code.
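A minimal sketch of the suggested change, assuming a dataclass-based TrainingArguments like the one in main.py (the field name and loader shown here are illustrative, not the file's actual code):

```python
from dataclasses import dataclass, field

import transformers


@dataclass
class TrainingArguments(transformers.TrainingArguments):
    # Field mirroring the review suggestion; defaults to the safe value.
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Allow execution of custom code shipped with the checkpoint."},
    )


def load_model(checkpoint: str, training_args: TrainingArguments):
    # Pass the flag through instead of hardcoding trust_remote_code=True.
    return transformers.AutoModelForCausalLM.from_pretrained(
        checkpoint, torch_dtype="auto", trust_remote_code=training_args.trust_remote_code
    )
```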
@contextmanager
def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int):
    """Context manager for TEDotProductAttention with context parallelism.

    Context manager that temporarily replace `attention_bias`
    with `attention_mask` for `TEDotProductAttention.forward` calls across the process
    if context parallel is used.

    Any call to `TEDotProductAttention.forward` (including calls originating
    from other modules) inside the context will receive `attention_bias=attention_mask`
    if context parallelism is used.

    Example:
        with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads):
            outputs = model(...)

    Note: This monkey-patches the class method and restores it on exit.
    """
    from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls

    orig_forward = cls.forward

    def _wrapped_forward(self, *args, **kwargs):
        # Build attention_bias from the boolean attention_mask and ensure
        # it's a fresh, detached tensor on the query's device/dtype to
        # avoid shared-storage in-place modifications that break autograd.
        query = args[0] if len(args) > 0 else None
        if isinstance(query, torch.Tensor):
            q_device = query.device
            q_dtype = query.dtype
        else:
            q_device = None
            q_dtype = None

        mask_fill = -1e9
        if q_dtype in (torch.float16, torch.bfloat16):
            mask_fill = -40.0
        mask_val = torch.tensor(mask_fill, device=attention_mask.device)
        zero_val = torch.tensor(0.0, device=attention_mask.device)
        attention_bias = torch.where(attention_mask, mask_val, zero_val)

        # Expand head dimension if needed
        try:
            if attention_bias.dim() == 4 and attention_bias.shape[1] == 1:
                attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1)
        except Exception:
            pass

        if q_device is not None and q_dtype is not None:
            attention_bias = attention_bias.to(device=q_device, dtype=q_dtype)

        attention_bias = attention_bias.clone().detach().contiguous()
        kwargs["attention_bias"] = attention_bias

        # Defensive clone of query/key/value positional tensors to avoid
        # passing views into the fused attention kernel that might be
        # modified in-place during backward.
        if len(args) >= 1:
            original_args = args
            new_args = list(original_args)
            try:
                for i in range(min(3, len(new_args))):
                    if isinstance(new_args[i], torch.Tensor):
                        if not new_args[i].is_contiguous():
                            new_args[i] = new_args[i].contiguous()
                        new_args[i] = new_args[i].clone()

                if any(x is None for x in new_args):
                    args = original_args
                else:
                    args = tuple(new_args)
            except Exception:
                args = original_args

        return orig_forward(self, *args, **kwargs)

    if get_context_parallel_world_size() > 1:
        cls.forward = _wrapped_forward
    try:
        yield
    finally:
        cls.forward = orig_forward
Remove attention_mask from kwargs when injecting attention_bias.
The context manager converts the boolean attention_mask to a numeric attention_bias and injects it into kwargs (line 1437). However, the original attention_mask may still be present in kwargs when passed through the transformer layer, causing both parameters to reach TEDotProductAttention.forward simultaneously. The docstring explicitly states the intent is to "replace" the mask (not supplement it), implying TE expects one parameter, not both. If TE prioritizes or rejects duplicate specifications, the injected bias may be ignored or cause errors. Add kwargs.pop("attention_mask", None) after line 1437 to ensure clean parameter passing.
Recommended fix
  attention_bias = attention_bias.clone().detach().contiguous()
  kwargs["attention_bias"] = attention_bias
+ kwargs.pop("attention_mask", None)

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
@contextmanager
def TEDotProductAttentionCP(attention_mask: torch.Tensor, num_attention_heads: int):
    """Context manager for TEDotProductAttention with context parallelism.

    Context manager that temporarily replace `attention_bias`
    with `attention_mask` for `TEDotProductAttention.forward` calls across the process
    if context parallel is used.

    Any call to `TEDotProductAttention.forward` (including calls originating
    from other modules) inside the context will receive `attention_bias=attention_mask`
    if context parallelism is used.

    Example:
        with TEDotProductAttentionCP(attention_mask_tensor, num_attention_heads):
            outputs = model(...)

    Note: This monkey-patches the class method and restores it on exit.
    """
    from megatron.core.extensions.transformer_engine import TEDotProductAttention as cls

    orig_forward = cls.forward

    def _wrapped_forward(self, *args, **kwargs):
        # Build attention_bias from the boolean attention_mask and ensure
        # it's a fresh, detached tensor on the query's device/dtype to
        # avoid shared-storage in-place modifications that break autograd.
        query = args[0] if len(args) > 0 else None
        if isinstance(query, torch.Tensor):
            q_device = query.device
            q_dtype = query.dtype
        else:
            q_device = None
            q_dtype = None

        mask_fill = -1e9
        if q_dtype in (torch.float16, torch.bfloat16):
            mask_fill = -40.0
        mask_val = torch.tensor(mask_fill, device=attention_mask.device)
        zero_val = torch.tensor(0.0, device=attention_mask.device)
        attention_bias = torch.where(attention_mask, mask_val, zero_val)

        # Expand head dimension if needed
        try:
            if attention_bias.dim() == 4 and attention_bias.shape[1] == 1:
                attention_bias = attention_bias.expand(-1, num_attention_heads, -1, -1)
        except Exception:
            pass

        if q_device is not None and q_dtype is not None:
            attention_bias = attention_bias.to(device=q_device, dtype=q_dtype)

        attention_bias = attention_bias.clone().detach().contiguous()
        kwargs["attention_bias"] = attention_bias
        kwargs.pop("attention_mask", None)

        # Defensive clone of query/key/value positional tensors to avoid
        # passing views into the fused attention kernel that might be
        # modified in-place during backward.
        if len(args) >= 1:
            original_args = args
            new_args = list(original_args)
            try:
                for i in range(min(3, len(new_args))):
                    if isinstance(new_args[i], torch.Tensor):
                        if not new_args[i].is_contiguous():
                            new_args[i] = new_args[i].contiguous()
                        new_args[i] = new_args[i].clone()

                if any(x is None for x in new_args):
                    args = original_args
                else:
                    args = tuple(new_args)
            except Exception:
                args = original_args

        return orig_forward(self, *args, **kwargs)

    if get_context_parallel_world_size() > 1:
        cls.forward = _wrapped_forward
    try:
        yield
    finally:
        cls.forward = orig_forward
🤖 Prompt for AI Agents
In `@modelopt/torch/speculative/plugins/megatron_eagle.py` around lines 1383 -
1464, The wrapper in TEDotProductAttentionCP constructs attention_bias and
injects it into kwargs but doesn't remove the original boolean attention_mask,
which can cause both to be passed into TEDotProductAttention.forward; update the
wrapper (_wrapped_forward) to remove the original mask by calling
kwargs.pop("attention_mask", None) immediately after assigning
kwargs["attention_bias"] = attention_bias so only attention_bias is passed to
cls.forward; keep the rest of the logic (cloning args, device/dtype conversion,
and restoring cls.forward) unchanged.
Codecov Report
✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #818      +/-   ##
==========================================
+ Coverage   73.82%   73.84%   +0.01%
==========================================
  Files         193      193
  Lines       19745    19745
==========================================
+ Hits        14577    14580       +3
+ Misses       5168     5165       -3

☔ View full report in Codecov by Sentry.
(force-pushed from 162d3a2 to 83c4c3b)
attn_mask[:, :, :-1, :-1] = attention_mask[:, :, 1:, 1:]
# [b, 1, sq, sk] -> [sq, 1, b, sk]
attn_mask = attn_mask.transpose(0, 2).contiguous()
attn_mask = gather_from_sequence_parallel_region(
I assume that here attn_mask has shape [sq*cp, 1, b, sk]. My understanding is that sk = ttt_step*sq*cp? So the attn_mask is O(ttt_step * b * sq^2 * cp^2)?
sk=sq*cp. ttt_step will be introduced later at set_multi_step_attention_mask
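A small shape walk-through of that gather step (illustrative sizes only; torch.cat stands in for gather_from_sequence_parallel_region, which concatenates the per-rank shards along dim 0):

```python
import torch

b, sq_local, cp = 2, 4, 2          # batch, per-rank sequence length, CP world size
sk = sq_local * cp                 # per the reply above: sk = sq * cp
attn_mask = torch.zeros(b, 1, sq_local, sk, dtype=torch.bool)   # [b, 1, sq, sk]

x = attn_mask.transpose(0, 2).contiguous()    # [sq, 1, b, sk]
gathered = torch.cat([x] * cp, dim=0)         # stand-in for the CP gather -> [sq*cp, 1, b, sk]
full = gathered.transpose(0, 2).contiguous()  # [b, 1, sq*cp, sk]
print(full.shape)                             # torch.Size([2, 1, 8, 8])
```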
attn_mask = set_multi_step_attention_mask(attn_mask, ttt_step)
# [b, 1, sq, sk] -> [sq, 1, b, sk]
attn_mask = attn_mask.transpose(0, 2).contiguous()
attn_mask = scatter_to_sequence_parallel_region(
I assume scatter is from [cp*sq, 1, b, sk] back to [sq, 1, b, sk]?
yes
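And the reverse direction, confirming the shapes discussed (torch.chunk stands in for scatter_to_sequence_parallel_region and ignores any load-balanced reordering the real collective may apply):

```python
import torch

b, full_sq, cp, sk = 2, 8, 2, 8
full_mask = torch.zeros(b, 1, full_sq, sk, dtype=torch.bool)     # [b, 1, cp*sq, sk]

x = full_mask.transpose(0, 2).contiguous()       # [cp*sq, 1, b, sk]
local = torch.chunk(x, cp, dim=0)[0]             # this rank's shard -> [sq, 1, b, sk]
local_mask = local.transpose(0, 2).contiguous()  # [b, 1, sq, sk]
print(local_mask.shape)                          # torch.Size([2, 1, 4, 8])
```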
@contextmanager
def te_dot_product_attention_with_cp(attention_mask: torch.Tensor, num_attention_heads: int):
I think I need a walkthrough here.
… patch Signed-off-by: Ye Yu <yeyu@nvidia.com>
…_size Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…lel region Signed-off-by: Ye Yu <yeyu@nvidia.com>
… gather_from_sequence_parallel_region and change the group Signed-off-by: Ye Yu <yeyu@nvidia.com>
…mention so we need to transpose tensors first Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…when TP>1 Signed-off-by: Ye Yu <yeyu@nvidia.com>
Signed-off-by: Ye Yu <yeyu@nvidia.com>
…ate as we disable this function when CP>1 Signed-off-by: Ye Yu <yeyu@nvidia.com>
(force-pushed from 4c78855 to 4c44c1c)
What does this PR do?
New feature
Overview:
This PR implements a context manager that injects attn_mask as attn_bias into TEDotProductAttention so that we can enable EAGLE training with an arbitrary mask.
Usage
set CP>1 in https://github.com/NVIDIA/Megatron-LM/blob/main/examples/post_training/modelopt/finetune.sh
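For illustration only (names follow the diff above; the real call sites are inside megatron_eagle.py), the context manager wraps the EAGLE forward pass roughly like this:

```python
def eagle_forward_with_cp(eagle_module, eagle_inputs, config):
    """Hypothetical wrapper showing where the context manager sits (illustrative only)."""
    with te_dot_product_attention_with_cp(
        eagle_inputs["attention_mask"], config.num_attention_heads
    ):
        # Inside the block, TEDotProductAttention.forward receives the mask as an
        # additive attention_bias, which the CP-aware fused kernel can consume.
        return eagle_module(**eagle_inputs)
```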
Testing
Tested on DSR1 Llama 8B.
CP1 -> CP2:
- Memory: 38854 MB -> 28050 MB
- MT-Bench AL: 2.26 -> 2.31
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features