Jinliang/fsdp no shard vit #1
Open
wplf wants to merge 4 commits into jinliang/mimo-kimi from
Conversation
Add a flexible `fsdp_no_shard()` utility that marks modules or parameters to be kept replicated (not sharded) by Megatron FSDP. This avoids unnecessary all-gather/reshard communication for small or frozen sub-modules such as vision encoders in VLMs.

Changes:
- Add `fsdp_no_shard` field to `ParameterGroup` dataclass
- Add public `fsdp_no_shard(module_or_param)` function that sets `param._fsdp_no_shard = True` on target parameters
- Wire the flag through parameter grouping (Steps 1-3), bucketing, and buffer init so marked groups get `is_data_distributed=False`
- Skip NCCL all-gather in `AllGatherPipeline.async_bucket_gather` for non-distributed (replicated) buffers
- Log `no_shard` tag in FSDP unit summary for visibility
- Apply to frozen vision_tower and mm_projector in KimiK25VLModel

Verified on proxy model (2 GPU, EP=2, FSDP optim_grads_params):
- Loss: numerically equivalent (diff < 5e-4)
- Peak memory (max_allocated): -454 MB (no temp all-gather buffers)
- CUDA reserved: -908 MB
- Throughput: within noise on 2 GPUs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
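For orientation, here is a minimal sketch of what such a marker utility could look like, based only on the commit description above. The actual wiring into parameter grouping, bucketing, and buffer init lives in Megatron FSDP and is not reproduced here; the commented usage lines and attribute names (`vision_tower`, `mm_projector`) follow the commit message, and the single-`nn.Parameter` case is discussed separately in the review below.

```python
import torch.nn as nn


def fsdp_no_shard(module: nn.Module) -> nn.Module:
    """Sketch: mark every parameter of `module` so Megatron FSDP keeps it replicated."""
    for param in module.parameters():
        # The flag is later read during parameter grouping / buffer init,
        # where marked groups get is_data_distributed=False.
        param._fsdp_no_shard = True
    return module


# Illustrative usage for a VLM with a frozen vision encoder and projector,
# mirroring the KimiK25VLModel change described above:
# fsdp_no_shard(model.vision_tower)
# fsdp_no_shard(model.mm_projector)
```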
Skip registering `_pre_forward_param_unshard` and `_pre_backward_param_unshard` hooks on modules whose parameters are all marked with `_fsdp_no_shard`. This eliminates per-module hook dispatch overhead for replicated (non-sharded) sub-modules like frozen vision encoders, which have many small layers that each trigger unnecessary hook calls.

The optimization detects no_shard modules at hook registration time and skips both the module and all its descendants, avoiding the cost of hook dispatch, parameter list construction, and bucket lookups on every forward/backward pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When `--init-model-with-meta-device` is used, all model params (including the frozen ViT) start on the meta device. During `ParamAndGradBuffer` init, `m.to_empty()` + `reset_parameters()` creates new Parameter objects. The `_reset_parameters` method was copying TP attributes but not `_fsdp_no_shard`, causing the flag to be lost.

Add `_fsdp_no_shard` to the attribute copy list in `_reset_parameters` so the flag survives meta-device materialization and is visible at hook registration time.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
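A sketch of the kind of attribute copy this fix describes. The attribute names other than `_fsdp_no_shard` are illustrative placeholders for the TP attributes mentioned above, not the exact list used by `_reset_parameters`.

```python
import torch


def copy_param_flags(old_param: torch.nn.Parameter, new_param: torch.nn.Parameter) -> None:
    """Carry custom per-parameter flags across meta-device materialization."""
    # to_empty() + reset_parameters() creates fresh Parameter objects, so any
    # flags set on the originals must be copied over explicitly.
    for attr in ("tensor_model_parallel", "partition_dim", "_fsdp_no_shard"):
        if hasattr(old_param, attr):
            setattr(new_param, attr, getattr(old_param, attr))
```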
Reviewer's Guide

Introduce an opt-in mechanism to keep selected parameters/modules replicated under FSDP and integrate it into buffer initialization, hook registration, and example models to avoid unnecessary all-gather/reshard for frozen components.

Sequence diagram for async_bucket_gather behavior with fsdp_no_shard

```mermaid
sequenceDiagram
    actor Trainer
    participant MegatronFSDP
    participant DataParallelBuffer as DataParallelBuffer_wbuf
    participant Dist as torch_distributed
    Trainer->>MegatronFSDP: async_bucket_gather(bucket_id, bwd)
    MegatronFSDP->>MegatronFSDP: compute bucket_key
    MegatronFSDP->>DataParallelBuffer: wbuf.fetch_bucket(set_param_data=True)
    DataParallelBuffer-->>MegatronFSDP: bucket
    alt wbuf.is_data_distributed == false (fsdp_no_shard)
        MegatronFSDP->>MegatronFSDP: bucket_status[bucket_key] = READY_TO_USE
        MegatronFSDP-->>Trainer: return (no all_gather)
    else wbuf.is_data_distributed == true
        MegatronFSDP->>Dist: all_gather_into_tensor(bucket, wbuf.shard)
        Dist-->>MegatronFSDP: param_gather_event
        MegatronFSDP->>MegatronFSDP: param_gather_event_map[bucket_key] = event
        MegatronFSDP->>MegatronFSDP: bucket_status[bucket_key] = GATHER_IN_PROGRESS
        MegatronFSDP-->>Trainer: return
    end
```
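In code form, the branch above looks roughly like the sketch below. Identifiers (`wbuf`, `BucketStatus`, `param_gather_event_map`, the `_weight_buffer` lookup) mirror the diagram rather than the exact Megatron FSDP source, and `self` stands in for the `AllGatherPipeline` instance.

```python
import enum

import torch


class BucketStatus(enum.Enum):
    # Status names taken from the diagram; the real enum may differ.
    READY_TO_USE = 1
    GATHER_IN_PROGRESS = 2


def async_bucket_gather(self, bucket_id: int, bwd: bool = False):
    """Sketch of the all-gather skip for replicated (no-shard) buffers."""
    bucket_key = (bucket_id, bwd)
    wbuf = self._weight_buffer(bucket_id)  # hypothetical buffer lookup
    bucket = wbuf.fetch_bucket(set_param_data=True)
    if not wbuf.is_data_distributed:
        # fsdp_no_shard path: weights already live on every rank, so the NCCL
        # all-gather (and its temporary gather buffer) is skipped entirely.
        self.bucket_status[bucket_key] = BucketStatus.READY_TO_USE
        return
    event = torch.distributed.all_gather_into_tensor(
        bucket, wbuf.shard, group=wbuf.data_parallel_group, async_op=True
    )
    self.param_gather_event_map[bucket_key] = event
    self.bucket_status[bucket_key] = BucketStatus.GATHER_IN_PROGRESS
```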
Class diagram for fsdp_no_shard integration in ParameterGroup

```mermaid
classDiagram
    class ParameterGroup {
        +torch.dtype dtype
        +bool is_expert_param
        +Optional~bool~ requires_grad
        +Optional~int~ fsdp_unit_id
        +bool fsdp_no_shard
        +int chunk_size_factor
        +Optional~DataParallelBuffer~ model_weight_buffer
        +Optional~DataParallelBuffer~ transpose_weight_buffer
        +Optional~DataParallelBuffer~ grad_buffer
        +Optional~DataParallelBuffer~ hsdp_gbuf
    }
    class DataParallelBuffer {
        +bool is_data_distributed
        +dtype
        +device
        +data_parallel_group
        +fetch_bucket()
    }
    class MegatronFSDP {
        +dict forward_pre_hooks
        +dict backward_pre_hooks
        +dict param_gather_event_map
        +dict bucket_status
        +bool enable_fine_grained_param_gather_hook
        +_init_each_parameter_group_buffers(meta_device_init_fp8_params)
        +async_bucket_gather(bucket_id, bwd)
        +_setup_hooks(root_module)
    }
    class fsdp_no_shard_helper {
        +fsdp_no_shard(module_or_param)
    }
    ParameterGroup o-- DataParallelBuffer : model_weight_buffer
    ParameterGroup o-- DataParallelBuffer : transpose_weight_buffer
    ParameterGroup o-- DataParallelBuffer : grad_buffer
    ParameterGroup o-- DataParallelBuffer : hsdp_gbuf
    MegatronFSDP o-- ParameterGroup : manages
    MegatronFSDP o-- DataParallelBuffer : initializes
    fsdp_no_shard_helper ..> ParameterGroup : sets_fsdp_no_shard_flag
    fsdp_no_shard_helper ..> MegatronFSDP : influences_sharding_behavior
```
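For readers who prefer code to diagrams, the same group written out as a Python dataclass. Field order and defaults are reconstructed from the class diagram above, not copied from the source, and the `params` list and other fields not shown in the diagram are omitted.

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ParameterGroup:
    # Reconstructed from the class diagram above; defaults are illustrative.
    dtype: torch.dtype
    is_expert_param: bool = False
    requires_grad: Optional[bool] = None
    fsdp_unit_id: Optional[int] = None
    fsdp_no_shard: bool = False  # new flag: keep this group replicated under FSDP
    chunk_size_factor: int = 1
    model_weight_buffer: Optional["DataParallelBuffer"] = None
    transpose_weight_buffer: Optional["DataParallelBuffer"] = None
    grad_buffer: Optional["DataParallelBuffer"] = None
    hsdp_gbuf: Optional["DataParallelBuffer"] = None
```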
Hey - I've found 3 issues, and left some high level feedback:
- The `fsdp_no_shard` helper advertises support for both `nn.Module` and `nn.Parameter`, but unconditionally calls `.parameters()` on the argument; consider adding an `isinstance(..., nn.Parameter)` branch so that passing a single parameter does not raise an AttributeError.
- In `_module_params_all_no_shard`, the docstring says it returns False if there are no parameters; currently it returns False in that case but this implicitly treats such modules as shardable—if the intent is to skip hook registration for param-less modules as well, you may want to short-circuit early instead of having them flow through to FSDP hook registration logic.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- The `fsdp_no_shard` helper advertises support for both `nn.Module` and `nn.Parameter`, but unconditionally calls `.parameters()` on the argument; consider adding an `isinstance(..., nn.Parameter)` branch so that passing a single parameter does not raise an AttributeError.
- In `_module_params_all_no_shard`, the docstring says it returns False if there are no parameters; currently it returns False in that case but this implicitly treats such modules as shardable—if the intent is to skip hook registration for param-less modules as well, you may want to short-circuit early instead of having them flow through to FSDP hook registration logic.
## Individual Comments
### Comment 1
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py" line_range="1272-1281" />
<code_context>
+def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"):
</code_context>
<issue_to_address>
**issue (bug_risk):** fsdp_no_shard breaks when passed a single Parameter since it assumes `.parameters()` exists.
The implementation assumes `module_or_param` has `.parameters()`, which contradicts the signature/docstring allowing a single `nn.Parameter`. Passing a `Parameter` will raise `AttributeError`. Please branch on type/capability so a single parameter is wrapped in a list and iterated, while modules still use `.parameters()`.
</issue_to_address>
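One possible shape of the fix this comment asks for, shown as a sketch rather than the merged change: branch on the argument type so a single `nn.Parameter` is wrapped in a list while modules keep using `.parameters()`.

```python
import torch.nn as nn


def fsdp_no_shard(module_or_param: "nn.Module | nn.Parameter"):
    """Accept either a module or a single parameter, per the docstring/signature."""
    if isinstance(module_or_param, nn.Parameter):
        params = [module_or_param]  # wrap the single-parameter case
    else:
        params = list(module_or_param.parameters())
    for param in params:
        param._fsdp_no_shard = True
    return module_or_param
```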
### Comment 2
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py" line_range="2007-2016" />
<code_context>
+        # Parameters marked with _fsdp_no_shard=True are kept replicated
+        # (is_data_distributed=False) to avoid unnecessary all-gather/reshard
+        # overhead for small or non-trainable modules (e.g. frozen vision encoders).
+        should_shard_model_weights = (
+            is_model_weight_buffer_distributed
+            and model_wbuf_dp_group.size() > 1
+            and not group.fsdp_no_shard
+        )
         if data_parallel_sharding_strategy != "no_shard":
             group.model_weight_buffer = DataParallelBuffer(
                 self.ddp_config,
                 group.params,
-                is_data_distributed=is_model_weight_buffer_distributed
-                and model_wbuf_dp_group.size() > 1,
+                is_data_distributed=should_shard_model_weights,
                 dtype=param_dtype,
                 device=self.device,
</code_context>
<issue_to_address>
**issue (bug_risk):** Reusing `should_shard_model_weights` for the transpose buffer may be incorrect if its DP group differs.
Previously, the DP-based `is_data_distributed` flags for the model and transpose buffers were computed from their respective groups. Now `should_shard_model_weights` is derived from `model_wbuf_dp_group` but reused for the transpose buffer, which uses `main_buf_dp_group`. If these groups can differ, the transpose buffer’s sharding state could be wrong. Either assert that the groups are equivalent in this context or introduce a separate `should_shard_transpose_weights` based on `main_buf_dp_group` to avoid hidden coupling.
</issue_to_address>
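One way to make the per-buffer decision explicit, as a hedged sketch: the helper name `should_shard` is hypothetical, and the commented usage reuses the variable names from the diff and review text above (`model_wbuf_dp_group`, `main_buf_dp_group`).

```python
def should_shard(buffer_is_distributed: bool, dp_group_size: int, fsdp_no_shard: bool) -> bool:
    """Per-buffer sharding decision: each buffer passes in its own DP group size."""
    return buffer_is_distributed and dp_group_size > 1 and not fsdp_no_shard


# Each buffer then derives its own flag from its own DP group (names from the diff above):
# should_shard_model_weights = should_shard(
#     is_model_weight_buffer_distributed, model_wbuf_dp_group.size(), group.fsdp_no_shard)
# should_shard_transpose_weights = should_shard(
#     is_model_weight_buffer_distributed, main_buf_dp_group.size(), group.fsdp_no_shard)
# Alternatively, assert model_wbuf_dp_group == main_buf_dp_group if they must match.
```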
### Comment 3
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py" line_range="936" />
<code_context>
# on the output tensor(s).
return module.register_forward_hook(forward_hook)
+ def _module_params_all_no_shard(module):
+ """Check if all parameters of a module (including children) are _fsdp_no_shard.
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider centralizing the no-shard module detection into a cached per-module flag and simplifying hook registration and NVTX usage around that flag.
The complexity mostly comes from scattering the “no‑shard module” decision across multiple passes and data structures. You can keep all behavior and simplify by:
1. **Cache a module‑level “all params no_shard” flag once**
Instead of recalculating `_module_params_all_no_shard(module)` and tracking `no_shard_modules` + `is_submodule(...)` in the main loop, compute and cache a flag on each module during the first traversal, then use that everywhere:
```python
def _module_params_all_no_shard(module: nn.Module) -> bool:
    params = list(module.parameters())
    return len(params) > 0 and all(
        getattr(p, "_fsdp_no_shard", False) for p in params
    )

# Top-down classification: once parent is no_shard, children inherit.
for _, module in root_module.named_modules():
    parent = getattr(module, "_fsdp_parent", None)
    parent_no_shard = getattr(parent, "_fsdp_no_shard_module", False) if parent else False
    if parent_no_shard:
        module._fsdp_no_shard_module = True
    else:
        module._fsdp_no_shard_module = _module_params_all_no_shard(module)
```
You can set `_fsdp_parent` when building the module hierarchy if you don’t already have it, or maintain a parent stack during `named_modules()` traversal.
Then your hook registration becomes linear and easier to reason about, with no `no_shard_modules` list and no `is_submodule` checks:
```python
def _register_pre_forward_param_unshard_hook(module):
    if (
        self.ddp_config.data_parallel_sharding_strategy == "no_shard"
        or getattr(module, "_fsdp_no_shard_module", False)
    ):
        return
    self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = (
        module.register_forward_pre_hook(
            _pre_forward_param_unshard, prepend=True, with_kwargs=True
        )
    )

def _register_pre_backward_param_unshard_hook(module):
    if getattr(module, "_fsdp_no_shard_module", False):
        return
    self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = (
        create_custom_backward_hook(module, _pre_backward_param_unshard)
    )
```
And the main loop simplifies to:
```python
fsdp_modules = []
for name, module in root_module.named_modules():
    if self.enable_fine_grained_param_gather_hook:
        _register_pre_forward_param_unshard_hook(module)
        _register_pre_backward_param_unshard_hook(module)
    if any(is_submodule(module, fm) for fm in fsdp_modules):
        continue
    if not self.enable_fine_grained_param_gather_hook:
        _register_pre_forward_param_unshard_hook(module)
        _register_pre_backward_param_unshard_hook(module)

# Optional: log only root no_shard modules (those with no no_shard parent)
no_shard_roots = [
    m for _, m in root_module.named_modules()
    if getattr(m, "_fsdp_no_shard_module", False)
    and not getattr(getattr(m, "_fsdp_parent", None), "_fsdp_no_shard_module", False)
]
if no_shard_roots:
    logger.info(
        "[fsdp_no_shard] Skipped FSDP hooks for %d no_shard root module(s): %s",
        len(no_shard_roots),
        [m._get_name() for m in no_shard_roots],
    )
```
This preserves:
- “all params no_shard” semantics, including descendants.
- Skipping all‑gather / reshard hooks for those modules and their children.
- Existing `fsdp_modules` behavior.
2. **Remove duplicated `all params no_shard` checks**
Once you have `module._fsdp_no_shard_module`, you can remove multiple `_module_params_all_no_shard(module)` calls and rely on the cached flag as shown above. This reduces both runtime overhead and conceptual duplication.
3. **Localize NVTX tracing if you want to reduce decorator noise**
If the NVTX annotation is mostly for profiling, wrapping just the bodies instead of decorating the whole function keeps the function signature and decoration stack simpler:
```python
def _pre_forward_param_unshard(module: nn.Module, args, kwargs):
    with nvtx_decorator().range("_pre_forward_param_unshard"):
        # existing body...
        ...

def _pre_backward_param_unshard(module: nn.Module, *unused):
    with nvtx_decorator().range("_pre_backward_param_unshard"):
        # existing body...
        ...
```
(or equivalent context manager if `nvtx_decorator` supports it).
This keeps all functionality intact while making the no‑shard logic more centralized and the control flow easier to follow.
</issue_to_address>
What does this PR do?
Contribution process
```mermaid
flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]
```

Pre-checks
Code review
The following process is enforced via the CODEOWNERS file for changes into `megatron/core`. For changes outside of `megatron/core`, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch
Feel free to message or comment the @megatron-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
(Step 1): Add the `Expert Review` PR label

(Step 2): Collect the expert reviewers' reviews

Add the `Expert Review` label when your PR is ready for review. Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

Add the `Final Review` label.

(Optional Step 4): Cherry-pick into release branch
If this PR also needs to be merged into `core_r*` release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Summary by Sourcery
Introduce an opt-in mechanism to keep selected parameters/modules replicated under FSDP and integrate it into buffer initialization, hook registration, and example models to avoid unnecessary all-gather/reshard for frozen components.
New Features:
Enhancements: