
Jinliang/fsdp no shard vit #1

Open
wplf wants to merge 4 commits into jinliang/mimo-kimi from jinliang/fsdp-no-shard-vit

Conversation

@wplf
Owner

@wplf wplf commented Apr 9, 2026

What does this PR do?

⚠️ For major changes (either in lines of code or in impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact @megatron-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or mention @megatron-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Summary by Sourcery

Introduce an opt-in mechanism to keep selected parameters/modules replicated under FSDP and integrate it into buffer initialization, hook registration, and example models to avoid unnecessary all-gather/reshard for frozen components.

New Features:

  • Add an fsdp_no_shard helper and flag to mark parameters or modules that should remain replicated instead of being sharded in FSDP (see the sketch after this list).
  • Expose fsdp_no_shard from the megatron_fsdp package and apply it to frozen vision components in the Kimi K25 example model.
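
To make the first bullet concrete, here is a minimal sketch of how such a helper could be written, assuming it simply tags target parameters with a `_fsdp_no_shard` attribute as described elsewhere in this PR; names follow the PR description, and the shipped implementation in param_and_grad_buffer.py may differ.

```python
from torch import nn

def fsdp_no_shard(module_or_param: "nn.Module | nn.Parameter"):
    """Mark a module's parameters (or a single parameter) as replicated under FSDP.

    Sketch only: tags each parameter with ``_fsdp_no_shard = True``; the FSDP
    buffer-initialization code can then keep the corresponding parameter group
    un-sharded (is_data_distributed=False).
    """
    if isinstance(module_or_param, nn.Parameter):
        params = [module_or_param]
    else:
        params = list(module_or_param.parameters())
    for p in params:
        p._fsdp_no_shard = True
    return module_or_param
```

A module-level call like `fsdp_no_shard(model.vision_tower)` then marks every parameter of that sub-tree; accepting a bare `nn.Parameter` as well addresses Sourcery's first review comment below.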

Enhancements:

  • Propagate the fsdp_no_shard attribute through parameter regrouping, cloning, and logging so FSDP bucket metadata and debug output reflect non-sharded groups.
  • Update FSDP buffer initialization and async bucket gather logic to skip data distribution and all-gather for non-sharded parameter groups, and guard hook coalescing for buckets that were never gathered.
  • Skip registration of unshard/all-gather hooks for modules whose parameters are all marked no_shard and log the set of such modules at startup.
  • Decorate key FSDP pre/post-forward and pre-backward hooks with nvtx_decorator for improved profiling.

wplf and others added 3 commits April 7, 2026 07:04
Add a flexible `fsdp_no_shard()` utility that marks modules or
parameters to be kept replicated (not sharded) by Megatron FSDP.
This avoids unnecessary all-gather/reshard communication for small
or frozen sub-modules such as vision encoders in VLMs.

Changes:
- Add `fsdp_no_shard` field to `ParameterGroup` dataclass
- Add public `fsdp_no_shard(module_or_param)` function that sets
  `param._fsdp_no_shard = True` on target parameters
- Wire the flag through parameter grouping (Steps 1-3), bucketing,
  and buffer init so marked groups get `is_data_distributed=False`
- Skip NCCL all-gather in `AllGatherPipeline.async_bucket_gather`
  for non-distributed (replicated) buffers
- Log `no_shard` tag in FSDP unit summary for visibility
- Apply to frozen vision_tower and mm_projector in KimiK25VLModel

Verified on a proxy model (2 GPUs, EP=2, FSDP optim_grads_params):
- Loss: numerically equivalent (diff < 5e-4)
- Peak memory (max_allocated): -454 MB (no temp all-gather buffers)
- CUDA reserved: -908 MB
- Throughput: within noise on 2 GPUs

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Skip registering _pre_forward_param_unshard and
_pre_backward_param_unshard hooks on modules whose parameters
are all marked with _fsdp_no_shard. This eliminates per-module
hook dispatch overhead for replicated (non-sharded) sub-modules
like frozen vision encoders, which have many small layers that
each trigger unnecessary hook calls.

The optimization detects no_shard modules at hook registration
time and skips both the module and all its descendants, avoiding
the cost of hook dispatch, parameter list construction, and
bucket lookups on every forward/backward pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
When --init-model-with-meta-device is used, all model params (including
frozen ViT) start on meta device. During ParamAndGradBuffer init,
m.to_empty() + reset_parameters() creates new Parameter objects.
The _reset_parameters method was copying TP attributes but not
_fsdp_no_shard, causing the flag to be lost.

Add _fsdp_no_shard to the attribute copy list in _reset_parameters
so the flag survives meta-device materialization and is visible at
hook registration time.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
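
As a rough illustration of the fix described in this commit message (the actual `_reset_parameters` in Megatron FSDP copies a longer list of tensor-parallel attributes; the attribute names below are assumptions based on the description):

```python
# Sketch: after to_empty() + reset_parameters() creates new Parameter objects
# during meta-device materialization, copy selected attributes from the old
# Parameter onto the new one so flags like _fsdp_no_shard are not lost.
_ATTRS_TO_PRESERVE = (
    "tensor_model_parallel",  # placeholder for the existing TP attributes
    "_fsdp_no_shard",         # added by this PR
)

def _copy_param_attrs(old_param, new_param):
    for attr in _ATTRS_TO_PRESERVE:
        if hasattr(old_param, attr):
            setattr(new_param, attr, getattr(old_param, attr))
```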
@sourcery-ai

sourcery-ai Bot commented Apr 9, 2026

Reviewer's Guide

Introduce an opt-in fsdp_no_shard flag and helper to keep selected parameters/modules replicated under Megatron FSDP, and wire it through parameter grouping, buffer initialization, hook registration, logging, and an example MIMO model to avoid unnecessary all‑gather/reshard for frozen vision components.

Sequence diagram for async_bucket_gather behavior with fsdp_no_shard

sequenceDiagram
    actor Trainer
    participant MegatronFSDP
    participant DataParallelBuffer as DataParallelBuffer_wbuf
    participant Dist as torch_distributed

    Trainer->>MegatronFSDP: async_bucket_gather(bucket_id, bwd)
    MegatronFSDP->>MegatronFSDP: compute bucket_key
    MegatronFSDP->>DataParallelBuffer: wbuf.fetch_bucket(set_param_data=True)
    DataParallelBuffer-->>MegatronFSDP: bucket

    alt wbuf.is_data_distributed == false (fsdp_no_shard)
        MegatronFSDP->>MegatronFSDP: bucket_status[bucket_key] = READY_TO_USE
        MegatronFSDP-->>Trainer: return (no all_gather)
    else wbuf.is_data_distributed == true
        MegatronFSDP->>Dist: all_gather_into_tensor(bucket, wbuf.shard)
        Dist-->>MegatronFSDP: param_gather_event
        MegatronFSDP->>MegatronFSDP: param_gather_event_map[bucket_key] = event
        MegatronFSDP->>MegatronFSDP: bucket_status[bucket_key] = GATHER_IN_PROGRESS
        MegatronFSDP-->>Trainer: return
    end

Class diagram for fsdp_no_shard integration in ParameterGroup

classDiagram
    class ParameterGroup {
        +torch.dtype dtype
        +bool is_expert_param
        +bool~Optional~ requires_grad
        +int~Optional~ fsdp_unit_id
        +bool fsdp_no_shard
        +int chunk_size_factor
        +DataParallelBuffer~Optional~ model_weight_buffer
        +DataParallelBuffer~Optional~ transpose_weight_buffer
        +DataParallelBuffer~Optional~ grad_buffer
        +DataParallelBuffer~Optional~ hsdp_gbuf
    }

    class DataParallelBuffer {
        +bool is_data_distributed
        +dtype
        +device
        +data_parallel_group
        +fetch_bucket()
    }

    class MegatronFSDP {
        +dict forward_pre_hooks
        +dict backward_pre_hooks
        +dict param_gather_event_map
        +dict bucket_status
        +bool enable_fine_grained_param_gather_hook
        +_init_each_parameter_group_buffers(meta_device_init_fp8_params)
        +async_bucket_gather(bucket_id, bwd)
        +_setup_hooks(root_module)
    }

    class fsdp_no_shard_helper {
        +fsdp_no_shard(module_or_param)
    }

    ParameterGroup o-- DataParallelBuffer : model_weight_buffer
    ParameterGroup o-- DataParallelBuffer : transpose_weight_buffer
    ParameterGroup o-- DataParallelBuffer : grad_buffer
    ParameterGroup o-- DataParallelBuffer : hsdp_gbuf

    MegatronFSDP o-- ParameterGroup : manages
    MegatronFSDP o-- DataParallelBuffer : initializes

    fsdp_no_shard_helper ..> ParameterGroup : sets_fsdp_no_shard_flag
    fsdp_no_shard_helper ..> MegatronFSDP : influences_sharding_behavior

File-Level Changes

Change | Details | Files
Add a fsdp_no_shard flag on parameter groups and plumb it through grouping, cloning, and logging so FSDP can treat selected params as replicated.
  • Extend ParameterGroup with a fsdp_no_shard boolean and populate it from a per-parameter _fsdp_no_shard attribute when building parameter groups.
  • Ensure parameter-group splitting/copying preserves fsdp_no_shard via the basic_attrs dict and when constructing new ParameterGroup instances.
  • Include fsdp_no_shard in parameter cloning/reset paths (_reset_parameters, set_param_attribute) so the attribute survives re-wrapping and TP transformations.
  • Annotate FSDP logging (_bytes_to_mb path) to tag groups with no_shard to aid debugging and inspection.
megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
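
For orientation, a trimmed-down sketch of the dataclass change and one plausible way the flag could be derived when groups are built (the real ParameterGroup carries the additional buffer fields shown in the class diagram above):

```python
from dataclasses import dataclass, field
from typing import List, Optional

import torch

@dataclass
class ParameterGroup:
    params: List[torch.nn.Parameter] = field(default_factory=list)
    dtype: Optional[torch.dtype] = None
    is_expert_param: bool = False
    requires_grad: Optional[bool] = None
    fsdp_no_shard: bool = False  # new in this PR: keep this group replicated
    # model_weight_buffer, transpose_weight_buffer, grad_buffer, ... omitted

def _group_no_shard(params: List[torch.nn.Parameter]) -> bool:
    # A group is treated as no_shard when every parameter carries the flag.
    return bool(params) and all(getattr(p, "_fsdp_no_shard", False) for p in params)
```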
Update buffer init and async gather logic so fsdp_no_shard groups keep weights replicated and skip all-gather/reshard operations.
  • Compute a should_shard_model_weights flag that is false for fsdp_no_shard groups, and pass it as is_data_distributed when constructing model and transpose DataParallelBuffers.
  • In async_bucket_gather, short‑circuit buckets backed by non-distributed buffers by marking them READY_TO_USE without issuing an all-gather and without touching param_gather_event_map.
  • Guard coalescing logic to skip buckets not present in param_gather_event_map, which can occur for replicated/no‑shard parameters.
megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
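
The short-circuit described above (and in the sequence diagram) could look roughly like the following; method, field, and enum names are taken from the diagram and are illustrative rather than verbatim:

```python
from enum import Enum

import torch

class BucketStatus(Enum):  # illustrative; the real enum lives in megatron_fsdp
    READY_TO_USE = 1
    GATHER_IN_PROGRESS = 2

def async_bucket_gather(self, bucket_id: int, bwd: bool = False):
    bucket_key = (bucket_id, bwd)
    wbuf = self.parameter_groups[bucket_id].model_weight_buffer
    bucket = wbuf.fetch_bucket(set_param_data=True)

    if not wbuf.is_data_distributed:
        # fsdp_no_shard group: weights are already replicated in full, so skip
        # the all-gather and leave param_gather_event_map untouched.
        self.bucket_status[bucket_key] = BucketStatus.READY_TO_USE
        return

    # Sharded group: launch the usual asynchronous all-gather from the local shard.
    handle = torch.distributed.all_gather_into_tensor(
        bucket.data, wbuf.shard, group=wbuf.data_parallel_group, async_op=True
    )
    self.param_gather_event_map[bucket_key] = handle
    self.bucket_status[bucket_key] = BucketStatus.GATHER_IN_PROGRESS
```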
Add a public fsdp_no_shard API and integrate it with MegatronFSDP’s hook registration to avoid attaching all-gather/reshard hooks to fully no-shard modules.
  • Introduce fsdp_no_shard(module_or_param) helper that tags parameters with _fsdp_no_shard=True and re-export it from the Megatron FSDP package.
  • Add _module_params_all_no_shard and use it when registering pre-forward and pre-backward unshard hooks to skip modules whose parameters are all marked no_shard.
  • Track no-shard root modules to skip their descendants during hook registration, and log a summary of skipped modules for visibility.
  • Decorate key FSDP internal hooks (_pre_forward_param_unshard, _pre_backward_param_unshard, _post_forward, _release_module_fp8_transpose_cache) with nvtx_decorator() for profiling.
megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
Use the new no-shard mechanism in the Kimi K25 MIMO example to keep frozen vision components replicated under FSDP.
  • In the Kimi K25 example model, when freeze_vision_model or freeze_vision_projection is enabled and the corresponding submodules exist, call fsdp_no_shard on vision_tower and mm_projector to avoid sharding their parameters.
  • Wrap the import of fsdp_no_shard in a try/except to keep the example robust when Megatron FSDP is not available.
examples/mimo/model_providers/kimi_k25/model.py
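
Under the assumptions in this item, the example-model wiring might look roughly like this; the import path and the `vision_tower`/`mm_projector` attribute names are taken from the description above and are illustrative:

```python
try:
    # Re-exported from the Megatron FSDP package by this PR; the exact import
    # path depends on how megatron_fsdp is packaged in your checkout.
    from megatron_fsdp import fsdp_no_shard
except ImportError:
    fsdp_no_shard = None

def mark_frozen_vision_no_shard(model, freeze_vision_model: bool, freeze_vision_projection: bool):
    """Keep frozen vision components replicated under FSDP (no-op if unavailable)."""
    if fsdp_no_shard is None:
        return
    if freeze_vision_model and getattr(model, "vision_tower", None) is not None:
        fsdp_no_shard(model.vision_tower)
    if freeze_vision_projection and getattr(model, "mm_projector", None) is not None:
        fsdp_no_shard(model.mm_projector)
```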

Tips and commands

Interacting with Sourcery

  • Trigger a new review: Comment @sourcery-ai review on the pull request.
  • Continue discussions: Reply directly to Sourcery's review comments.
  • Generate a GitHub issue from a review comment: Ask Sourcery to create an
    issue from a review comment by replying to it. You can also reply to a
    review comment with @sourcery-ai issue to create an issue from it.
  • Generate a pull request title: Write @sourcery-ai anywhere in the pull
    request title to generate a title at any time. You can also comment
    @sourcery-ai title on the pull request to (re-)generate the title at any time.
  • Generate a pull request summary: Write @sourcery-ai summary anywhere in
    the pull request body to generate a PR summary at any time exactly where you
    want it. You can also comment @sourcery-ai summary on the pull request to
    (re-)generate the summary at any time.
  • Generate reviewer's guide: Comment @sourcery-ai guide on the pull
    request to (re-)generate the reviewer's guide at any time.
  • Resolve all Sourcery comments: Comment @sourcery-ai resolve on the
    pull request to resolve all Sourcery comments. Useful if you've already
    addressed all the comments and don't want to see them anymore.
  • Dismiss all Sourcery reviews: Comment @sourcery-ai dismiss on the pull
    request to dismiss all existing Sourcery reviews. Especially useful if you
    want to start fresh with a new review - don't forget to comment
    @sourcery-ai review to trigger a new review!

Customizing Your Experience

Access your dashboard to:

  • Enable or disable review features such as the Sourcery-generated pull request
    summary, the reviewer's guide, and others.
  • Change the review language.
  • Add, remove or edit custom review instructions.
  • Adjust other review settings.


@sourcery-ai sourcery-ai Bot left a comment


Hey - I've found 3 issues and left some high-level feedback:

  • The fsdp_no_shard helper advertises support for both nn.Module and nn.Parameter, but unconditionally calls .parameters() on the argument; consider adding an isinstance(..., nn.Parameter) branch so that passing a single parameter does not raise an AttributeError.
  • In _module_params_all_no_shard, the docstring says it returns False if there are no parameters; currently it returns False in that case but this implicitly treats such modules as shardable—if the intent is to skip hook registration for param-less modules as well, you may want to short-circuit early instead of having them flow through to FSDP hook registration logic.
Prompt for AI Agents
Please address the comments from this code review:

## Overall Comments
- The `fsdp_no_shard` helper advertises support for both `nn.Module` and `nn.Parameter`, but unconditionally calls `.parameters()` on the argument; consider adding an `isinstance(..., nn.Parameter)` branch so that passing a single parameter does not raise an AttributeError.
- In `_module_params_all_no_shard`, the docstring says it returns False if there are no parameters; currently it returns False in that case but this implicitly treats such modules as shardable—if the intent is to skip hook registration for param-less modules as well, you may want to short-circuit early instead of having them flow through to FSDP hook registration logic.

## Individual Comments

### Comment 1
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py" line_range="1272-1281" />
<code_context>
+def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"):
</code_context>
<issue_to_address>
**issue (bug_risk):** fsdp_no_shard breaks when passed a single Parameter since it assumes `.parameters()` exists.

The implementation assumes `module_or_param` has `.parameters()`, which contradicts the signature/docstring allowing a single `nn.Parameter`. Passing a `Parameter` will raise `AttributeError`. Please branch on type/capability so a single parameter is wrapped in a list and iterated, while modules still use `.parameters()`.
</issue_to_address>

### Comment 2
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py" line_range="2007-2016" />
<code_context>
+            # Parameters marked with _fsdp_no_shard=True are kept replicated
+            # (is_data_distributed=False) to avoid unnecessary all-gather/reshard
+            # overhead for small or non-trainable modules (e.g. frozen vision encoders).
+            should_shard_model_weights = (
+                is_model_weight_buffer_distributed
+                and model_wbuf_dp_group.size() > 1
+                and not group.fsdp_no_shard
+            )
             if data_parallel_sharding_strategy != "no_shard":
                 group.model_weight_buffer = DataParallelBuffer(
                     self.ddp_config,
                     group.params,
-                    is_data_distributed=is_model_weight_buffer_distributed
-                    and model_wbuf_dp_group.size() > 1,
+                    is_data_distributed=should_shard_model_weights,
                     dtype=param_dtype,
                     device=self.device,
</code_context>
<issue_to_address>
**issue (bug_risk):** Reusing `should_shard_model_weights` for the transpose buffer may be incorrect if its DP group differs.

Previously, the DP-based `is_data_distributed` flags for the model and transpose buffers were computed from their respective groups. Now `should_shard_model_weights` is derived from `model_wbuf_dp_group` but reused for the transpose buffer, which uses `main_buf_dp_group`. If these groups can differ, the transpose buffer’s sharding state could be wrong. Either assert that the groups are equivalent in this context or introduce a separate `should_shard_transpose_weights` based on `main_buf_dp_group` to avoid hidden coupling.
</issue_to_address>

### Comment 3
<location path="megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py" line_range="936" />
<code_context>
             # on the output tensor(s).
             return module.register_forward_hook(forward_hook)

+        def _module_params_all_no_shard(module):
+            """Check if all parameters of a module (including children) are _fsdp_no_shard.
+
</code_context>
<issue_to_address>
**issue (complexity):** Consider centralizing the no-shard module detection into a cached per-module flag and simplifying hook registration and NVTX usage around that flag.

The complexity mostly comes from scattering the “no‑shard module” decision across multiple passes and data structures. You can keep all behavior and simplify by:

1. **Cache a module‑level “all params no_shard” flag once**

Instead of recalculating `_module_params_all_no_shard(module)` and tracking `no_shard_modules` + `is_submodule(...)` in the main loop, compute and cache a flag on each module during the first traversal, then use that everywhere:

```python
def _module_params_all_no_shard(module: nn.Module) -> bool:
    params = list(module.parameters())
    return len(params) > 0 and all(
        getattr(p, "_fsdp_no_shard", False) for p in params
    )

# Top-down classification: once parent is no_shard, children inherit.
for _, module in root_module.named_modules():
    parent = getattr(module, "_fsdp_parent", None)
    parent_no_shard = getattr(parent, "_fsdp_no_shard_module", False) if parent else False

    if parent_no_shard:
        module._fsdp_no_shard_module = True
    else:
        module._fsdp_no_shard_module = _module_params_all_no_shard(module)
```

You can set `_fsdp_parent` when building the module hierarchy if you don’t already have it, or maintain a parent stack during `named_modules()` traversal.

Then your hook registration becomes linear and easier to reason about, with no `no_shard_modules` list and no `is_submodule` checks:

```python
def _register_pre_forward_param_unshard_hook(module):
    if (
        self.ddp_config.data_parallel_sharding_strategy == "no_shard"
        or getattr(module, "_fsdp_no_shard_module", False)
    ):
        return
    self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = (
        module.register_forward_pre_hook(
            _pre_forward_param_unshard, prepend=True, with_kwargs=True
        )
    )

def _register_pre_backward_param_unshard_hook(module):
    if getattr(module, "_fsdp_no_shard_module", False):
        return
    self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = (
        create_custom_backward_hook(module, _pre_backward_param_unshard)
    )
```

And the main loop simplifies to:

```python
fsdp_modules = []

for name, module in root_module.named_modules():
    if self.enable_fine_grained_param_gather_hook:
        _register_pre_forward_param_unshard_hook(module)
        _register_pre_backward_param_unshard_hook(module)

    if any(is_submodule(module, fm) for fm in fsdp_modules):
        continue

    if not self.enable_fine_grained_param_gather_hook:
        _register_pre_forward_param_unshard_hook(module)
        _register_pre_backward_param_unshard_hook(module)

# Optional: log only root no_shard modules (those with no no_shard parent)
no_shard_roots = [
    m for _, m in root_module.named_modules()
    if getattr(m, "_fsdp_no_shard_module", False)
    and not getattr(getattr(m, "_fsdp_parent", None), "_fsdp_no_shard_module", False)
]
if no_shard_roots:
    logger.info(
        "[fsdp_no_shard] Skipped FSDP hooks for %d no_shard root module(s): %s",
        len(no_shard_roots),
        [m._get_name() for m in no_shard_roots],
    )
```

This preserves:

- “all params no_shard” semantics, including descendants.
- Skipping all‑gather / reshard hooks for those modules and their children.
- Existing `fsdp_modules` behavior.

2. **Remove duplicated `all params no_shard` checks**

Once you have `module._fsdp_no_shard_module`, you can remove multiple `_module_params_all_no_shard(module)` calls and rely on the cached flag as shown above. This reduces both runtime overhead and conceptual duplication.

3. **Localize NVTX tracing if you want to reduce decorator noise**

If the NVTX annotation is mostly for profiling, wrapping just the bodies instead of decorating the whole function keeps the function signature and decoration stack simpler:

```python
def _pre_forward_param_unshard(module: nn.Module, args, kwargs):
    with nvtx_decorator().range("_pre_forward_param_unshard"):
        # existing body...
        ...

def _pre_backward_param_unshard(module: nn.Module, *unused):
    with nvtx_decorator().range("_pre_backward_param_unshard"):
        # existing body...
        ...
```

(or equivalent context manager if `nvtx_decorator` supports it).

This keeps all functionality intact while making the no‑shard logic more centralized and the control flow easier to follow.
</issue_to_address>

Sourcery is free for open source - if you like our reviews please consider sharing them ✨
Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.

@wplf wplf force-pushed the jinliang/fsdp-no-shard-vit branch from 1bdf2fb to f1d93f2 on April 9, 2026 04:11
