From 68de9194e69e2314f32e5c50774f4e6ac69cf027 Mon Sep 17 00:00:00 2001 From: jinliangl Date: Tue, 7 Apr 2026 07:04:35 -0700 Subject: [PATCH 1/4] feat: add fsdp_no_shard API to keep params replicated in FSDP Add a flexible `fsdp_no_shard()` utility that marks modules or parameters to be kept replicated (not sharded) by Megatron FSDP. This avoids unnecessary all-gather/reshard communication for small or frozen sub-modules such as vision encoders in VLMs. Changes: - Add `fsdp_no_shard` field to `ParameterGroup` dataclass - Add public `fsdp_no_shard(module_or_param)` function that sets `param._fsdp_no_shard = True` on target parameters - Wire the flag through parameter grouping (Steps 1-3), bucketing, and buffer init so marked groups get `is_data_distributed=False` - Skip NCCL all-gather in `AllGatherPipeline.async_bucket_gather` for non-distributed (replicated) buffers - Log `no_shard` tag in FSDP unit summary for visibility - Apply to frozen vision_tower and mm_projector in KimiK25VLModel Verified on proxy model (2 GPU, EP=2, FSDP optim_grads_params): - Loss: numerically equivalent (diff < 5e-4) - Peak memory (max_allocated): -454 MB (no temp all-gather buffers) - CUDA reserved: -908 MB - Throughput: within noise on 2 GPUs Co-Authored-By: Claude Opus 4.6 (1M context) --- .../mimo/model_providers/kimi_k25/model.py | 12 ++++ .../fsdp/src/megatron_fsdp/__init__.py | 2 + .../megatron_fsdp/param_and_grad_buffer.py | 64 +++++++++++++++++-- 3 files changed, 73 insertions(+), 5 deletions(-) diff --git a/examples/mimo/model_providers/kimi_k25/model.py b/examples/mimo/model_providers/kimi_k25/model.py index 2d9c87c436e..11a108eca82 100644 --- a/examples/mimo/model_providers/kimi_k25/model.py +++ b/examples/mimo/model_providers/kimi_k25/model.py @@ -81,6 +81,18 @@ def __init__( for p in self.mm_projector.parameters(): p.requires_grad = False + # Mark frozen vision modules as non-shardable so FSDP keeps them + # replicated, avoiding unnecessary all-gather/reshard communication. + try: + from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard + + if freeze_vision_model and hasattr(self, "vision_tower"): + fsdp_no_shard(self.vision_tower) + if freeze_vision_projection and hasattr(self, "mm_projector"): + fsdp_no_shard(self.mm_projector) + except ImportError: + pass + # ------------------------------------------------------------------ # Vision init # ------------------------------------------------------------------ diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py index 1dd7263c4e2..685284a59d3 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py @@ -15,6 +15,7 @@ from .distributed_data_parallel_config import DistributedDataParallelConfig from .fully_shard import fully_shard, fully_shard_model, fully_shard_optimizer from .megatron_fsdp import MegatronFSDP +from .param_and_grad_buffer import fsdp_no_shard from .package_info import ( __contact_emails__, __contact_names__, @@ -34,6 +35,7 @@ "DistributedDataParallelConfig", "MegatronFSDP", "FSDPDistributedIndex", + "fsdp_no_shard", "fully_shard", "fully_shard_model", "fully_shard_optimizer", diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index aabdd010ed9..b230fb07bc0 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -1259,6 +1259,7 @@ class ParameterGroup: is_expert_param: bool = False requires_grad: Optional[bool] = None fsdp_unit_id: Optional[int] = None + fsdp_no_shard: bool = False chunk_size_factor: int = 1 model_weight_buffer: Optional[DataParallelBuffer] = None transpose_weight_buffer: Optional[DataParallelBuffer] = None @@ -1268,6 +1269,40 @@ class ParameterGroup: hsdp_gbuf: Optional[DataParallelBuffer] = None +def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"): + """Mark a module or parameter so that FSDP keeps it replicated (not sharded). + + This is useful for small or frozen sub-modules (e.g. a frozen vision encoder) + where sharding would add unnecessary all-gather / reshard communication. + + Args: + module_or_param: An ``nn.Module`` (all its parameters are marked) or a + single ``nn.Parameter``. + + Returns: + The input ``module_or_param`` (for convenient chaining). + + Example:: + + from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard + + model = MyModel() + fsdp_no_shard(model.vision_tower) # mark entire sub-module + fsdp_no_shard(model.mm_projector) # mark another sub-module + # … then wrap the whole model with MegatronFSDP / FullyShardedDataParallel + """ + if isinstance(module_or_param, torch.nn.Module): + for param in module_or_param.parameters(): + param._fsdp_no_shard = True + elif isinstance(module_or_param, torch.nn.Parameter): + module_or_param._fsdp_no_shard = True + else: + raise TypeError( + f"Expected nn.Module or nn.Parameter, got {type(module_or_param)}" + ) + return module_or_param + + def _get_parameter_groups( module: torch.nn.Module, policy: BucketingPolicy, @@ -1339,6 +1374,7 @@ def _does_param_require_new_bucket(param): is_expert_param=is_expert_parameter(name, param), requires_grad=param.requires_grad, fsdp_unit_id=None, + fsdp_no_shard=getattr(param, "_fsdp_no_shard", False), ) # For all the new FSDP unit parameters collected, assign an ID number @@ -1375,7 +1411,7 @@ def _does_param_require_new_bucket(param): basic_attrs = { key: value for key, value in group.__dict__.items() - if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id"] + if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id", "fsdp_no_shard"] } for param in group.params: if _does_param_require_new_bucket(param): @@ -1446,6 +1482,7 @@ def _does_param_require_new_bucket(param): is_expert_param=group.is_expert_param, requires_grad=group.requires_grad, fsdp_unit_id=group.fsdp_unit_id, + fsdp_no_shard=group.fsdp_no_shard, chunk_size_factor=chunk_size_factor, ) ) @@ -1824,9 +1861,11 @@ def _bytes_to_mb(bytes_val: int) -> str: total_comm_bytes += group_comm # One-line summary for the group + no_shard_tag = " no_shard" if group.fsdp_no_shard else "" log_lines.append( f"[FSDP_UNIT {group.fsdp_unit_id}] Group {idx}: elems={numel} dtype={group.dtype} " f"bufs={','.join(buf_flags) or 'None'} pad={_bytes_to_mb(group_padded)}" + f"{no_shard_tag}" ) # List parameters below for param in group.params: @@ -1969,12 +2008,19 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): ) # Initialize the model weight buffer from bucket parameters. + # Parameters marked with _fsdp_no_shard=True are kept replicated + # (is_data_distributed=False) to avoid unnecessary all-gather/reshard + # overhead for small or non-trainable modules (e.g. frozen vision encoders). + should_shard_model_weights = ( + is_model_weight_buffer_distributed + and model_wbuf_dp_group.size() > 1 + and not group.fsdp_no_shard + ) if data_parallel_sharding_strategy != "no_shard": group.model_weight_buffer = DataParallelBuffer( self.ddp_config, group.params, - is_data_distributed=is_model_weight_buffer_distributed - and model_wbuf_dp_group.size() > 1, + is_data_distributed=should_shard_model_weights, dtype=param_dtype, device=self.device, data_parallel_group=model_wbuf_dp_group, @@ -1989,8 +2035,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params): group.transpose_weight_buffer = DataParallelBuffer( self.ddp_config, group.params, - is_data_distributed=is_model_weight_buffer_distributed - and main_buf_dp_group.size() > 1, + is_data_distributed=should_shard_model_weights, dtype=param_dtype, device=self.device, data_parallel_group=main_buf_dp_group, @@ -3618,6 +3663,9 @@ def need_skip_prefetch(bucket_id): # Replace the parameter all-gather event with coalescing event. for bucket_id in buckets: bucket_key = self.get_bucket_key(bucket_id, bwd) + if bucket_key not in self.param_gather_event_map: + # Bucket was not all-gathered (e.g. replicated/frozen params). + continue _, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_key] self.param_gather_event_map[bucket_key] = ( coalescing_event, @@ -3735,6 +3783,12 @@ def async_bucket_gather(self, bucket_id, bwd) -> None: self.recycle_unused_buckets() # Allocate an empty bucket to store the module weights. bucket = wbuf.fetch_bucket(set_param_data=True) + + if not wbuf.is_data_distributed: + # Buffer is replicated (e.g. frozen params), no all-gather needed. + self.bucket_status[bucket_key] = BucketStatus.READY_TO_USE + return + # All-gather the module weights in each buffer shard into the allocated bucket. # Now each rank will have a copy of this FSDP unit module's weights. param_gather_event = torch.distributed.all_gather_into_tensor( From 155f2bf83e1c7e4fc09ce54560cb93ba614ce583 Mon Sep 17 00:00:00 2001 From: jinliangl Date: Wed, 8 Apr 2026 01:46:16 -0700 Subject: [PATCH 2/4] perf: skip FSDP hooks for fsdp_no_shard modules Skip registering _pre_forward_param_unshard and _pre_backward_param_unshard hooks on modules whose parameters are all marked with _fsdp_no_shard. This eliminates per-module hook dispatch overhead for replicated (non-sharded) sub-modules like frozen vision encoders, which have many small layers that each trigger unnecessary hook calls. The optimization detects no_shard modules at hook registration time and skips both the module and all its descendants, avoiding the cost of hook dispatch, parameter list construction, and bucket lookups on every forward/backward pass. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../fsdp/src/megatron_fsdp/megatron_fsdp.py | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 671487a30eb..84a1d831629 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from enum import Enum, auto from typing import Any, Dict, List, Optional, Tuple - +from megatron.core.utils import nvtx_decorator import torch import torch.nn as nn from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -677,6 +677,7 @@ def _process_post_backward_gradients(param_list): self._params_require_handle_grad.discard(param) @torch.compiler.disable + @nvtx_decorator() def _pre_forward_param_unshard( module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): @@ -805,6 +806,7 @@ def _root_post_backward(*unused): self.finish_grad_sync() @torch.compiler.disable + @nvtx_decorator() def _pre_backward_param_unshard(module: nn.Module, *unused): """ Sub-module pre-backward hook to all-gather the module parameters @@ -868,6 +870,7 @@ def _root_pre_backward(module: nn.Module, *unused): torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward) @torch.compiler.disable + @nvtx_decorator() def _post_forward(module: nn.Module, input: Any, output: Any): # When composed with module-hook-based activation recomputation, the # post-backward hook is responsible for resharding the module parameters @@ -892,6 +895,7 @@ def _post_forward(module: nn.Module, input: Any, output: Any): return output @torch.compiler.disable + @nvtx_decorator() def _release_module_fp8_transpose_cache(module: nn.Module, *unused): release_params_fp8_transpose_cache(module.parameters(recurse=False)) @@ -929,13 +933,29 @@ def forward_hook(_module, inputs, output): # on the output tensor(s). return module.register_forward_hook(forward_hook) + def _module_params_all_no_shard(module): + """Check if all parameters of a module (including children) are _fsdp_no_shard. + + Returns True if every parameter under this module is marked no_shard, + meaning no all-gather/reshard hooks are needed for this module or its + descendants. Returns False if there are no parameters. + """ + params = list(module.parameters()) + return len(params) > 0 and all( + getattr(p, "_fsdp_no_shard", False) for p in params + ) + def _register_pre_forward_param_unshard_hook(module): """ Register the forward pre-hook to unshard parameters before the forward pass. If we are not sharding anything, we do not have a model weight buffer and thus have nothing to all-gather / un-shard. + Skip for modules whose parameters are all marked _fsdp_no_shard, since they + are kept replicated and never need all-gathering. """ if self.ddp_config.data_parallel_sharding_strategy != "no_shard": + if _module_params_all_no_shard(module): + return self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = ( module.register_forward_pre_hook( _pre_forward_param_unshard, prepend=True, with_kwargs=True @@ -947,13 +967,25 @@ def _register_pre_backward_param_unshard_hook(module): Register the backward pre-hook to unshard FSDP unit module parameters immediately before the backward pass via attaching a gradient-triggered hook to the output tensor(s) of a module during a post-forward hook. + Skip for modules whose parameters are all marked _fsdp_no_shard. """ + if _module_params_all_no_shard(module): + return self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = ( create_custom_backward_hook(module, _pre_backward_param_unshard) ) fsdp_modules = [] + no_shard_modules = [] for name, module in root_module.named_modules(): + # Skip modules (and their descendants) whose params are all no_shard. + # These are replicated and never need all-gather/reshard hooks. + if any(is_submodule(module, ns_mod) for ns_mod in no_shard_modules): + continue + if _module_params_all_no_shard(module): + no_shard_modules.append(module) + continue + if self.enable_fine_grained_param_gather_hook: _register_pre_forward_param_unshard_hook(module) _register_pre_backward_param_unshard_hook(module) From 70510daa77885cdb1609843da3c3085568700bd1 Mon Sep 17 00:00:00 2001 From: jinliangl Date: Wed, 8 Apr 2026 04:32:48 -0700 Subject: [PATCH 3/4] fix: preserve _fsdp_no_shard through meta-device param materialization When --init-model-with-meta-device is used, all model params (including frozen ViT) start on meta device. During ParamAndGradBuffer init, m.to_empty() + reset_parameters() creates new Parameter objects. The _reset_parameters method was copying TP attributes but not _fsdp_no_shard, causing the flag to be lost. Add _fsdp_no_shard to the attribute copy list in _reset_parameters so the flag survives meta-device materialization and is visible at hook registration time. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../fsdp/src/megatron_fsdp/megatron_fsdp.py | 8 ++++++++ .../src/megatron_fsdp/param_and_grad_buffer.py | 16 ++++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 84a1d831629..624e69e54fc 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -981,6 +981,7 @@ def _register_pre_backward_param_unshard_hook(module): # Skip modules (and their descendants) whose params are all no_shard. # These are replicated and never need all-gather/reshard hooks. if any(is_submodule(module, ns_mod) for ns_mod in no_shard_modules): + logger.debug(f"[fsdp_no_shard] skip descendant: {name}") continue if _module_params_all_no_shard(module): no_shard_modules.append(module) @@ -1045,6 +1046,13 @@ def _register_pre_backward_param_unshard_hook(module): ) ) + if no_shard_modules: + logger.info( + f"[fsdp_no_shard] Skipped FSDP hooks for {len(no_shard_modules)} " + f"no_shard root module(s): " + f"{[m._get_name() for m in no_shard_modules]}" + ) + # Register root module pre- and post-backward hooks in cases where the # forward function of root module is not called, but rather the forward # function of the root module from named_modules() is called instead. diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py index b230fb07bc0..f15feefd30a 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py @@ -1291,15 +1291,8 @@ def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"): fsdp_no_shard(model.mm_projector) # mark another sub-module # … then wrap the whole model with MegatronFSDP / FullyShardedDataParallel """ - if isinstance(module_or_param, torch.nn.Module): - for param in module_or_param.parameters(): - param._fsdp_no_shard = True - elif isinstance(module_or_param, torch.nn.Parameter): - module_or_param._fsdp_no_shard = True - else: - raise TypeError( - f"Expected nn.Module or nn.Parameter, got {type(module_or_param)}" - ) + for param in module_or_param.parameters(): + param._fsdp_no_shard = True return module_or_param @@ -2480,7 +2473,9 @@ def _reset_parameters(self, old_params, new_params): new_param.requires_grad_(old_param.requires_grad) - for tp_attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated"]: + for tp_attr in [ + "_mcore_tp", "_tp_partition_dim", "_tp_duplicated", "_fsdp_no_shard", + ]: if getattr(old_param, tp_attr, None) is not None: setattr(new_param, tp_attr, getattr(old_param, tp_attr)) @@ -2635,6 +2630,7 @@ def set_param_attribute(): "_mcore_tp", "_tp_duplicated", "_tp_partition_dim", + "_fsdp_no_shard", ]: if hasattr(orig_param, attr_name): setattr(param, attr_name, getattr(orig_param, attr_name)) From f1d93f26a6b7c986906f0099ac632eae4cb626f0 Mon Sep 17 00:00:00 2001 From: jinliangl Date: Wed, 8 Apr 2026 21:06:31 -0700 Subject: [PATCH 4/4] clean useless code --- .../fsdp/src/megatron_fsdp/megatron_fsdp.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py index 624e69e54fc..601d1372977 100644 --- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py +++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from enum import Enum, auto from typing import Any, Dict, List, Optional, Tuple -from megatron.core.utils import nvtx_decorator + import torch import torch.nn as nn from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten @@ -677,7 +677,6 @@ def _process_post_backward_gradients(param_list): self._params_require_handle_grad.discard(param) @torch.compiler.disable - @nvtx_decorator() def _pre_forward_param_unshard( module: nn.Module, args: Tuple[Any, ...], kwargs: Dict[str, Any] ): @@ -806,7 +805,6 @@ def _root_post_backward(*unused): self.finish_grad_sync() @torch.compiler.disable - @nvtx_decorator() def _pre_backward_param_unshard(module: nn.Module, *unused): """ Sub-module pre-backward hook to all-gather the module parameters @@ -870,7 +868,6 @@ def _root_pre_backward(module: nn.Module, *unused): torch.autograd.Variable._execution_engine.queue_callback(_root_post_backward) @torch.compiler.disable - @nvtx_decorator() def _post_forward(module: nn.Module, input: Any, output: Any): # When composed with module-hook-based activation recomputation, the # post-backward hook is responsible for resharding the module parameters @@ -895,7 +892,6 @@ def _post_forward(module: nn.Module, input: Any, output: Any): return output @torch.compiler.disable - @nvtx_decorator() def _release_module_fp8_transpose_cache(module: nn.Module, *unused): release_params_fp8_transpose_cache(module.parameters(recurse=False)) @@ -954,8 +950,6 @@ def _register_pre_forward_param_unshard_hook(module): are kept replicated and never need all-gathering. """ if self.ddp_config.data_parallel_sharding_strategy != "no_shard": - if _module_params_all_no_shard(module): - return self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = ( module.register_forward_pre_hook( _pre_forward_param_unshard, prepend=True, with_kwargs=True @@ -969,8 +963,6 @@ def _register_pre_backward_param_unshard_hook(module): hook to the output tensor(s) of a module during a post-forward hook. Skip for modules whose parameters are all marked _fsdp_no_shard. """ - if _module_params_all_no_shard(module): - return self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = ( create_custom_backward_hook(module, _pre_backward_param_unshard) )