diff --git a/examples/mimo/model_providers/kimi_k25/model.py b/examples/mimo/model_providers/kimi_k25/model.py
index 2d9c87c436e..11a108eca82 100644
--- a/examples/mimo/model_providers/kimi_k25/model.py
+++ b/examples/mimo/model_providers/kimi_k25/model.py
@@ -81,6 +81,18 @@ def __init__(
             for p in self.mm_projector.parameters():
                 p.requires_grad = False

+        # Mark frozen vision modules as non-shardable so FSDP keeps them
+        # replicated, avoiding unnecessary all-gather/reshard communication.
+        try:
+            from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard
+
+            if freeze_vision_model and hasattr(self, "vision_tower"):
+                fsdp_no_shard(self.vision_tower)
+            if freeze_vision_projection and hasattr(self, "mm_projector"):
+                fsdp_no_shard(self.mm_projector)
+        except ImportError:
+            pass
+
         # ------------------------------------------------------------------
         # Vision init
         # ------------------------------------------------------------------
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
index 1dd7263c4e2..685284a59d3 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
@@ -15,6 +15,7 @@
 from .distributed_data_parallel_config import DistributedDataParallelConfig
 from .fully_shard import fully_shard, fully_shard_model, fully_shard_optimizer
 from .megatron_fsdp import MegatronFSDP
+from .param_and_grad_buffer import fsdp_no_shard
 from .package_info import (
     __contact_emails__,
     __contact_names__,
@@ -34,6 +35,7 @@
     "DistributedDataParallelConfig",
     "MegatronFSDP",
     "FSDPDistributedIndex",
+    "fsdp_no_shard",
     "fully_shard",
     "fully_shard_model",
     "fully_shard_optimizer",
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
index 671487a30eb..601d1372977 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -929,11 +929,25 @@ def forward_hook(_module, inputs, output):
             # on the output tensor(s).
             return module.register_forward_hook(forward_hook)

+        def _module_params_all_no_shard(module):
+            """Check if all parameters of a module (including children) are _fsdp_no_shard.
+
+            Returns True if every parameter under this module is marked no_shard,
+            meaning no all-gather/reshard hooks are needed for this module or its
+            descendants. Returns False if there are no parameters.
+            """
+            params = list(module.parameters())
+            return len(params) > 0 and all(
+                getattr(p, "_fsdp_no_shard", False) for p in params
+            )
+
         def _register_pre_forward_param_unshard_hook(module):
             """
             Register the forward pre-hook to unshard parameters before the forward pass.
             If we are not sharding anything, we do not have a model weight buffer
             and thus have nothing to all-gather / un-shard.
+            Skip for modules whose parameters are all marked _fsdp_no_shard, since they
+            are kept replicated and never need all-gathering.
             """
             if self.ddp_config.data_parallel_sharding_strategy != "no_shard":
                 self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = (
@@ -947,13 +961,24 @@ def _register_pre_backward_param_unshard_hook(module):
             """
             Register the backward pre-hook to unshard FSDP unit module parameters immediately
             before the backward pass via attaching a gradient-triggered hook to the output
             tensor(s) of a module during a post-forward hook.
+            Skip for modules whose parameters are all marked _fsdp_no_shard.
             """
             self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = (
                 create_custom_backward_hook(module, _pre_backward_param_unshard)
             )

         fsdp_modules = []
+        no_shard_modules = []
         for name, module in root_module.named_modules():
+            # Skip modules (and their descendants) whose params are all no_shard.
+            # These are replicated and never need all-gather/reshard hooks.
+            if any(is_submodule(module, ns_mod) for ns_mod in no_shard_modules):
+                logger.debug(f"[fsdp_no_shard] skip descendant: {name}")
+                continue
+            if _module_params_all_no_shard(module):
+                no_shard_modules.append(module)
+                continue
+
             if self.enable_fine_grained_param_gather_hook:
                 _register_pre_forward_param_unshard_hook(module)
                 _register_pre_backward_param_unshard_hook(module)
@@ -1013,6 +1038,13 @@ def _register_pre_backward_param_unshard_hook(module):
                 )
             )

+        if no_shard_modules:
+            logger.info(
+                f"[fsdp_no_shard] Skipped FSDP hooks for {len(no_shard_modules)} "
+                f"no_shard root module(s): "
+                f"{[m._get_name() for m in no_shard_modules]}"
+            )
+
         # Register root module pre- and post-backward hooks in cases where the
         # forward function of root module is not called, but rather the forward
         # function of the root module from named_modules() is called instead.
diff --git a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
index aabdd010ed9..f15feefd30a 100644
--- a/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
+++ b/megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
@@ -1259,6 +1259,7 @@ class ParameterGroup:
     is_expert_param: bool = False
     requires_grad: Optional[bool] = None
     fsdp_unit_id: Optional[int] = None
+    fsdp_no_shard: bool = False
     chunk_size_factor: int = 1
     model_weight_buffer: Optional[DataParallelBuffer] = None
     transpose_weight_buffer: Optional[DataParallelBuffer] = None
@@ -1268,6 +1269,33 @@ class ParameterGroup:
     hsdp_gbuf: Optional[DataParallelBuffer] = None


+def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"):
+    """Mark a module or parameter so that FSDP keeps it replicated (not sharded).
+
+    This is useful for small or frozen sub-modules (e.g. a frozen vision encoder)
+    where sharding would add unnecessary all-gather / reshard communication.
+
+    Args:
+        module_or_param: An ``nn.Module`` (all its parameters are marked) or a
+            single ``nn.Parameter``.
+
+    Returns:
+        The input ``module_or_param`` (for convenient chaining).
+
+    Example::
+
+        from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard
+
+        model = MyModel()
+        fsdp_no_shard(model.vision_tower)   # mark entire sub-module
+        fsdp_no_shard(model.mm_projector)   # mark another sub-module
+        # … then wrap the whole model with MegatronFSDP / FullyShardedDataParallel
+    """
+    for param in module_or_param.parameters():
+        param._fsdp_no_shard = True
+    return module_or_param
+
+
 def _get_parameter_groups(
     module: torch.nn.Module,
     policy: BucketingPolicy,
@@ -1339,6 +1367,7 @@ def _does_param_require_new_bucket(param):
                 is_expert_param=is_expert_parameter(name, param),
                 requires_grad=param.requires_grad,
                 fsdp_unit_id=None,
+                fsdp_no_shard=getattr(param, "_fsdp_no_shard", False),
             )

     # For all the new FSDP unit parameters collected, assign an ID number
@@ -1375,7 +1404,7 @@ def _does_param_require_new_bucket(param):
     basic_attrs = {
         key: value
         for key, value in group.__dict__.items()
-        if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id"]
+        if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id", "fsdp_no_shard"]
     }
     for param in group.params:
         if _does_param_require_new_bucket(param):
@@ -1446,6 +1475,7 @@ def _does_param_require_new_bucket(param):
                 is_expert_param=group.is_expert_param,
                 requires_grad=group.requires_grad,
                 fsdp_unit_id=group.fsdp_unit_id,
+                fsdp_no_shard=group.fsdp_no_shard,
                 chunk_size_factor=chunk_size_factor,
             )
         )
@@ -1824,9 +1854,11 @@ def _bytes_to_mb(bytes_val: int) -> str:
         total_comm_bytes += group_comm

         # One-line summary for the group
+        no_shard_tag = " no_shard" if group.fsdp_no_shard else ""
         log_lines.append(
             f"[FSDP_UNIT {group.fsdp_unit_id}] Group {idx}: elems={numel} dtype={group.dtype} "
             f"bufs={','.join(buf_flags) or 'None'} pad={_bytes_to_mb(group_padded)}"
+            f"{no_shard_tag}"
         )
         # List parameters below
         for param in group.params:
@@ -1969,12 +2001,19 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
             )

             # Initialize the model weight buffer from bucket parameters.
+            # Parameters marked with _fsdp_no_shard=True are kept replicated
+            # (is_data_distributed=False) to avoid unnecessary all-gather/reshard
+            # overhead for small or non-trainable modules (e.g. frozen vision encoders).
+            should_shard_model_weights = (
+                is_model_weight_buffer_distributed
+                and model_wbuf_dp_group.size() > 1
+                and not group.fsdp_no_shard
+            )
             if data_parallel_sharding_strategy != "no_shard":
                 group.model_weight_buffer = DataParallelBuffer(
                     self.ddp_config,
                     group.params,
-                    is_data_distributed=is_model_weight_buffer_distributed
-                    and model_wbuf_dp_group.size() > 1,
+                    is_data_distributed=should_shard_model_weights,
                     dtype=param_dtype,
                     device=self.device,
                     data_parallel_group=model_wbuf_dp_group,
@@ -1989,8 +2028,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
                 group.transpose_weight_buffer = DataParallelBuffer(
                     self.ddp_config,
                     group.params,
-                    is_data_distributed=is_model_weight_buffer_distributed
-                    and main_buf_dp_group.size() > 1,
+                    is_data_distributed=should_shard_model_weights,
                     dtype=param_dtype,
                     device=self.device,
                     data_parallel_group=main_buf_dp_group,
@@ -2435,7 +2473,9 @@ def _reset_parameters(self, old_params, new_params):

             new_param.requires_grad_(old_param.requires_grad)

-            for tp_attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated"]:
+            for tp_attr in [
+                "_mcore_tp", "_tp_partition_dim", "_tp_duplicated", "_fsdp_no_shard",
+            ]:
                 if getattr(old_param, tp_attr, None) is not None:
                     setattr(new_param, tp_attr, getattr(old_param, tp_attr))
@@ -2590,6 +2630,7 @@ def set_param_attribute():
                 "_mcore_tp",
                 "_tp_duplicated",
                 "_tp_partition_dim",
+                "_fsdp_no_shard",
             ]:
                 if hasattr(orig_param, attr_name):
                     setattr(param, attr_name, getattr(orig_param, attr_name))
@@ -3618,6 +3659,9 @@ def need_skip_prefetch(bucket_id):
         # Replace the parameter all-gather event with coalescing event.
         for bucket_id in buckets:
             bucket_key = self.get_bucket_key(bucket_id, bwd)
+            if bucket_key not in self.param_gather_event_map:
+                # Bucket was not all-gathered (e.g. replicated/frozen params).
+                continue
             _, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_key]
             self.param_gather_event_map[bucket_key] = (
                 coalescing_event,
@@ -3735,6 +3779,12 @@ def async_bucket_gather(self, bucket_id, bwd) -> None:
         self.recycle_unused_buckets()
         # Allocate an empty bucket to store the module weights.
         bucket = wbuf.fetch_bucket(set_param_data=True)
+
+        if not wbuf.is_data_distributed:
+            # Buffer is replicated (e.g. frozen params), no all-gather needed.
+            self.bucket_status[bucket_key] = BucketStatus.READY_TO_USE
+            return
+
         # All-gather the module weights in each buffer shard into the allocated bucket.
         # Now each rank will have a copy of this FSDP unit module's weights.
         param_gather_event = torch.distributed.all_gather_into_tensor(
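
For sanity-checking the marking/skip behaviour outside a full Megatron run, the following is a minimal standalone sketch using plain PyTorch. TinyVLM, vision_tower, mm_projector, language_model, mark_no_shard, and all_params_no_shard are illustrative names, not part of this patch; the two helpers merely mirror fsdp_no_shard and _module_params_all_no_shard from the diff above.

    # Standalone sketch of the no_shard marking and hook-skip predicate (PyTorch only).
    import torch.nn as nn


    def mark_no_shard(module: nn.Module) -> nn.Module:
        # Mirrors fsdp_no_shard for the nn.Module case: tag every parameter as replicated.
        # Note: the patched helper iterates module_or_param.parameters(), so a bare
        # nn.Parameter would need the attribute set on it directly instead.
        for p in module.parameters():
            p._fsdp_no_shard = True
        return module


    def all_params_no_shard(module: nn.Module) -> bool:
        # Mirrors _module_params_all_no_shard: True only if the module has parameters
        # and every one of them carries the _fsdp_no_shard tag.
        params = list(module.parameters())
        return len(params) > 0 and all(getattr(p, "_fsdp_no_shard", False) for p in params)


    class TinyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.vision_tower = nn.Linear(16, 16)     # stands in for a frozen vision encoder
            self.mm_projector = nn.Linear(16, 16)     # stands in for a frozen projector
            self.language_model = nn.Linear(16, 16)   # stays trainable and sharded


    model = TinyVLM()
    for p in model.vision_tower.parameters():
        p.requires_grad = False
    for p in model.mm_projector.parameters():
        p.requires_grad = False
    mark_no_shard(model.vision_tower)
    mark_no_shard(model.mm_projector)

    # The hook-registration loop in megatron_fsdp.py would skip the first two modules
    # and only install all-gather/reshard hooks for language_model.
    for name, sub in model.named_modules():
        if name:
            tag = "no_shard (hooks skipped)" if all_params_no_shard(sub) else "sharded"
            print(f"{name}: {tag}")

With this patch applied, the real entry point is simply fsdp_no_shard(model.vision_tower) before wrapping the model with Megatron FSDP, as shown in the docstring example inside the patch.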