12 changes: 12 additions & 0 deletions examples/mimo/model_providers/kimi_k25/model.py
@@ -81,6 +81,18 @@ def __init__(
for p in self.mm_projector.parameters():
p.requires_grad = False

# Mark frozen vision modules as non-shardable so FSDP keeps them
# replicated, avoiding unnecessary all-gather/reshard communication.
try:
from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard

if freeze_vision_model and hasattr(self, "vision_tower"):
fsdp_no_shard(self.vision_tower)
if freeze_vision_projection and hasattr(self, "mm_projector"):
fsdp_no_shard(self.mm_projector)
except ImportError:
pass

# ------------------------------------------------------------------
# Vision init
# ------------------------------------------------------------------
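Aside (reader's note, not part of the diff): a minimal sketch of what the registration block above does at the parameter level. The toy `vision_tower` module is hypothetical; only `fsdp_no_shard`, its import path, and the `_fsdp_no_shard` attribute come from this change.

```python
import torch.nn as nn

from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard

# Hypothetical stand-in for a frozen vision tower.
vision_tower = nn.Sequential(nn.Linear(16, 16), nn.GELU(), nn.Linear(16, 8))
for p in vision_tower.parameters():
    p.requires_grad = False

fsdp_no_shard(vision_tower)

# fsdp_no_shard tags every parameter under the module; Megatron-FSDP later reads
# this flag when it builds parameter groups and weight buffers (see the hunks below).
assert all(getattr(p, "_fsdp_no_shard", False) for p in vision_tower.parameters())
```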
2 changes: 2 additions & 0 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/__init__.py
@@ -15,6 +15,7 @@
from .distributed_data_parallel_config import DistributedDataParallelConfig
from .fully_shard import fully_shard, fully_shard_model, fully_shard_optimizer
from .megatron_fsdp import MegatronFSDP
from .param_and_grad_buffer import fsdp_no_shard
from .package_info import (
__contact_emails__,
__contact_names__,
@@ -34,6 +35,7 @@
"DistributedDataParallelConfig",
"MegatronFSDP",
"FSDPDistributedIndex",
"fsdp_no_shard",
"fully_shard",
"fully_shard_model",
"fully_shard_optimizer",
32 changes: 32 additions & 0 deletions megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py
@@ -929,11 +929,25 @@ def forward_hook(_module, inputs, output):
# on the output tensor(s).
return module.register_forward_hook(forward_hook)

def _module_params_all_no_shard(module):
"""Check if all parameters of a module (including children) are _fsdp_no_shard.

Returns True if every parameter under this module is marked no_shard,
meaning no all-gather/reshard hooks are needed for this module or its
descendants. Returns False if there are no parameters.
"""
params = list(module.parameters())
return len(params) > 0 and all(
getattr(p, "_fsdp_no_shard", False) for p in params
)

def _register_pre_forward_param_unshard_hook(module):
"""
Register the forward pre-hook to unshard parameters before the forward pass.
If we are not sharding anything, we do not have a model weight buffer and thus
have nothing to all-gather / un-shard.
Skip for modules whose parameters are all marked _fsdp_no_shard, since they
are kept replicated and never need all-gathering.
"""
if self.ddp_config.data_parallel_sharding_strategy != "no_shard":
self.forward_pre_hooks[f"{module._get_name()} parameter unshard"] = (
@@ -947,13 +961,24 @@ def _register_pre_backward_param_unshard_hook(module):
Register the backward pre-hook to unshard FSDP unit module parameters
immediately before the backward pass via attaching a gradient-triggered
hook to the output tensor(s) of a module during a post-forward hook.
Skip for modules whose parameters are all marked _fsdp_no_shard.
"""
self.backward_pre_hooks[f"all-gather {module._get_name()} parameters"] = (
create_custom_backward_hook(module, _pre_backward_param_unshard)
)

fsdp_modules = []
no_shard_modules = []
for name, module in root_module.named_modules():
# Skip modules (and their descendants) whose params are all no_shard.
# These are replicated and never need all-gather/reshard hooks.
if any(is_submodule(module, ns_mod) for ns_mod in no_shard_modules):
logger.debug(f"[fsdp_no_shard] skip descendant: {name}")
continue
if _module_params_all_no_shard(module):
no_shard_modules.append(module)
continue

if self.enable_fine_grained_param_gather_hook:
_register_pre_forward_param_unshard_hook(module)
_register_pre_backward_param_unshard_hook(module)
@@ -1013,6 +1038,13 @@ def _register_pre_backward_param_unshard_hook(module):
)
)

if no_shard_modules:
logger.info(
f"[fsdp_no_shard] Skipped FSDP hooks for {len(no_shard_modules)} "
f"no_shard root module(s): "
f"{[m._get_name() for m in no_shard_modules]}"
)

# Register root module pre- and post-backward hooks in cases where the
# forward function of root module is not called, but rather the forward
# function of the root module from named_modules() is called instead.
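Aside (reader's note, not part of the diff): a self-contained sketch of the module walk above, using only plain PyTorch. `_mark_no_shard`, `_all_params_no_shard`, and `_is_descendant` are local stand-ins for `fsdp_no_shard`, `_module_params_all_no_shard`, and `is_submodule`; the toy model is hypothetical.

```python
import torch.nn as nn

def _mark_no_shard(module: nn.Module) -> None:
    # Same effect as fsdp_no_shard() from this change: tag every parameter.
    for p in module.parameters():
        p._fsdp_no_shard = True

def _all_params_no_shard(module: nn.Module) -> bool:
    # Mirrors _module_params_all_no_shard() above.
    params = list(module.parameters())
    return len(params) > 0 and all(getattr(p, "_fsdp_no_shard", False) for p in params)

def _is_descendant(module: nn.Module, ancestor: nn.Module) -> bool:
    # Local stand-in for the is_submodule() helper used in the real code.
    return any(m is module for m in ancestor.modules())

model = nn.ModuleDict({"vision": nn.Sequential(nn.Linear(4, 4)), "lm": nn.Linear(4, 4)})
_mark_no_shard(model["vision"])

hooked, no_shard_roots = [], []
for name, module in model.named_modules():
    if any(_is_descendant(module, root) for root in no_shard_roots):
        continue  # descendant of a no_shard root: replicated, no hooks needed
    if _all_params_no_shard(module):
        no_shard_roots.append(module)  # no_shard root itself: skip hooks too
        continue
    hooked.append(name)  # would receive forward/backward unshard hooks

print(hooked)  # ['', 'lm'] -- the vision sub-tree is skipped entirely
```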
megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py
@@ -1259,6 +1259,7 @@ class ParameterGroup:
is_expert_param: bool = False
requires_grad: Optional[bool] = None
fsdp_unit_id: Optional[int] = None
fsdp_no_shard: bool = False
chunk_size_factor: int = 1
model_weight_buffer: Optional[DataParallelBuffer] = None
transpose_weight_buffer: Optional[DataParallelBuffer] = None
@@ -1268,6 +1269,33 @@
hsdp_gbuf: Optional[DataParallelBuffer] = None


def fsdp_no_shard(module_or_param: "torch.nn.Module | torch.nn.Parameter"):
"""Mark a module or parameter so that FSDP keeps it replicated (not sharded).

This is useful for small or frozen sub-modules (e.g. a frozen vision encoder)
where sharding would add unnecessary all-gather / reshard communication.

Args:
module_or_param: An ``nn.Module`` (all its parameters are marked) or a
single ``nn.Parameter``.

Returns:
The input ``module_or_param`` (for convenient chaining).

Example::

from megatron.core.distributed.fsdp.src.megatron_fsdp import fsdp_no_shard

model = MyModel()
fsdp_no_shard(model.vision_tower) # mark entire sub-module
fsdp_no_shard(model.mm_projector) # mark another sub-module
# … then wrap the whole model with MegatronFSDP / FullyShardedDataParallel
"""
for param in module_or_param.parameters():
param._fsdp_no_shard = True
return module_or_param


def _get_parameter_groups(
module: torch.nn.Module,
policy: BucketingPolicy,
@@ -1339,6 +1367,7 @@ def _does_param_require_new_bucket(param):
is_expert_param=is_expert_parameter(name, param),
requires_grad=param.requires_grad,
fsdp_unit_id=None,
fsdp_no_shard=getattr(param, "_fsdp_no_shard", False),
)

# For all the new FSDP unit parameters collected, assign an ID number
@@ -1375,7 +1404,7 @@ def _does_param_require_new_bucket(param):
basic_attrs = {
key: value
for key, value in group.__dict__.items()
if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id"]
if key in ["dtype", "is_expert_param", "requires_grad", "fsdp_unit_id", "fsdp_no_shard"]
}
for param in group.params:
if _does_param_require_new_bucket(param):
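Aside (reader's note, not part of the diff): why `fsdp_no_shard` is added to the attribute key built from `basic_attrs`. A toy regrouping over the same key list shows that a marked parameter can no longer share a bucket with unmarked parameters, even when every other attribute matches; the parameter names and dtypes below are made up.

```python
from collections import defaultdict

# Hypothetical parameters as (name, dtype, is_expert_param, requires_grad,
# fsdp_unit_id, fsdp_no_shard); only the last attribute is new in this change.
params = [
    ("lm.weight",        "bf16", False, True, 0, False),
    ("projector.weight", "bf16", False, True, 0, True),  # trainable but marked no_shard
]

buckets = defaultdict(list)
for name, *attrs in params:
    buckets[tuple(attrs)].append(name)

# With fsdp_no_shard in the key, projector.weight gets its own group even though
# every other attribute matches lm.weight, so its buffer can stay replicated while
# lm.weight is sharded; drop the last element and the two would share a bucket.
print(dict(buckets))
```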
@@ -1446,6 +1475,7 @@ def _does_param_require_new_bucket(param):
is_expert_param=group.is_expert_param,
requires_grad=group.requires_grad,
fsdp_unit_id=group.fsdp_unit_id,
fsdp_no_shard=group.fsdp_no_shard,
chunk_size_factor=chunk_size_factor,
)
)
@@ -1824,9 +1854,11 @@ def _bytes_to_mb(bytes_val: int) -> str:
total_comm_bytes += group_comm

# One-line summary for the group
no_shard_tag = " no_shard" if group.fsdp_no_shard else ""
log_lines.append(
f"[FSDP_UNIT {group.fsdp_unit_id}] Group {idx}: elems={numel} dtype={group.dtype} "
f"bufs={','.join(buf_flags) or 'None'} pad={_bytes_to_mb(group_padded)}"
f"{no_shard_tag}"
)
# List parameters below
for param in group.params:
@@ -1969,12 +2001,19 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
)

# Initialize the model weight buffer from bucket parameters.
# Parameters marked with _fsdp_no_shard=True are kept replicated
# (is_data_distributed=False) to avoid unnecessary all-gather/reshard
# overhead for small or non-trainable modules (e.g. frozen vision encoders).
should_shard_model_weights = (
is_model_weight_buffer_distributed
and model_wbuf_dp_group.size() > 1
and not group.fsdp_no_shard
)
if data_parallel_sharding_strategy != "no_shard":
group.model_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_model_weight_buffer_distributed
and model_wbuf_dp_group.size() > 1,
is_data_distributed=should_shard_model_weights,
dtype=param_dtype,
device=self.device,
data_parallel_group=model_wbuf_dp_group,
@@ -1989,8 +2028,7 @@ def _init_each_parameter_group_buffers(self, meta_device_init_fp8_params):
group.transpose_weight_buffer = DataParallelBuffer(
self.ddp_config,
group.params,
is_data_distributed=is_model_weight_buffer_distributed
and main_buf_dp_group.size() > 1,
is_data_distributed=should_shard_model_weights,
dtype=param_dtype,
device=self.device,
data_parallel_group=main_buf_dp_group,
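Aside (reader's note, not part of the diff): the sharding decision above, restated as a free-standing predicate for illustration only; the helper name and the scalar arguments are hypothetical.

```python
def should_shard_model_weights(
    is_model_weight_buffer_distributed: bool,
    dp_group_size: int,
    fsdp_no_shard: bool,
) -> bool:
    # A group's model-weight and transpose buffers are data-distributed only when
    # the buffer is configured to be distributed, the data-parallel group has more
    # than one rank, and the group is NOT marked fsdp_no_shard.
    return is_model_weight_buffer_distributed and dp_group_size > 1 and not fsdp_no_shard

assert should_shard_model_weights(True, 8, False)      # ordinary trainable group: sharded
assert not should_shard_model_weights(True, 8, True)   # no_shard group: replicated on every rank
assert not should_shard_model_weights(True, 1, False)  # single-rank DP group: nothing to shard
```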
@@ -2435,7 +2473,9 @@ def _reset_parameters(self, old_params, new_params):

new_param.requires_grad_(old_param.requires_grad)

for tp_attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated"]:
for tp_attr in [
"_mcore_tp", "_tp_partition_dim", "_tp_duplicated", "_fsdp_no_shard",
]:
if getattr(old_param, tp_attr, None) is not None:
setattr(new_param, tp_attr, getattr(old_param, tp_attr))

@@ -2590,6 +2630,7 @@ def set_param_attribute():
"_mcore_tp",
"_tp_duplicated",
"_tp_partition_dim",
"_fsdp_no_shard",
]:
if hasattr(orig_param, attr_name):
setattr(param, attr_name, getattr(orig_param, attr_name))
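Aside (reader's note, not part of the diff): why `_fsdp_no_shard` is added to both attribute-propagation lists above. The flag lives on the parameter object itself, so whenever a parameter is re-created it must be copied over explicitly; the toy parameters below are hypothetical.

```python
import torch
import torch.nn as nn

old = nn.Parameter(torch.zeros(4))
old._fsdp_no_shard = True

# A freshly created parameter (as when MegatronFSDP re-points parameters at its
# flat buffers) does not carry the ad-hoc flag, hence the explicit copy.
new = nn.Parameter(torch.zeros(4))
assert not hasattr(new, "_fsdp_no_shard")

for attr in ["_mcore_tp", "_tp_partition_dim", "_tp_duplicated", "_fsdp_no_shard"]:
    if getattr(old, attr, None) is not None:
        setattr(new, attr, getattr(old, attr))
assert new._fsdp_no_shard is True
```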
@@ -3618,6 +3659,9 @@ def need_skip_prefetch(bucket_id):
# Replace the parameter all-gather event with coalescing event.
for bucket_id in buckets:
bucket_key = self.get_bucket_key(bucket_id, bwd)
if bucket_key not in self.param_gather_event_map:
# Bucket was not all-gathered (e.g. replicated/frozen params).
continue
_, mark_bucket_ready_to_use = self.param_gather_event_map[bucket_key]
self.param_gather_event_map[bucket_key] = (
coalescing_event,
@@ -3735,6 +3779,12 @@ def async_bucket_gather(self, bucket_id, bwd) -> None:
self.recycle_unused_buckets()
# Allocate an empty bucket to store the module weights.
bucket = wbuf.fetch_bucket(set_param_data=True)

if not wbuf.is_data_distributed:
# Buffer is replicated (e.g. frozen params), no all-gather needed.
self.bucket_status[bucket_key] = BucketStatus.READY_TO_USE
return

# All-gather the module weights in each buffer shard into the allocated bucket.
# Now each rank will have a copy of this FSDP unit module's weights.
param_gather_event = torch.distributed.all_gather_into_tensor(
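Aside (reader's note, not part of the diff): a toy sketch of how the two guards added above fit together. `Bucket`, `Status`, and the string event placeholders are hypothetical; the point is that a replicated (no_shard) bucket is marked ready without recording a gather event, so any later loop over recorded gather events has to tolerate missing keys.

```python
from dataclasses import dataclass
from enum import Enum, auto

class Status(Enum):
    EMPTY = auto()
    READY_TO_USE = auto()

@dataclass
class Bucket:
    key: str
    is_data_distributed: bool
    status: Status = Status.EMPTY

def gather(bucket, gather_events):
    if not bucket.is_data_distributed:
        # Replicated (e.g. frozen no_shard) weights are already complete on every
        # rank: mark the bucket ready, issue no all-gather, record no event.
        bucket.status = Status.READY_TO_USE
        return
    gather_events[bucket.key] = "all_gather_handle"  # placeholder for an async handle

def coalesce(bucket_keys, gather_events, coalescing_event):
    for key in bucket_keys:
        if key not in gather_events:
            # Mirrors the `continue` added above: this bucket was never
            # all-gathered, so there is no event to replace.
            continue
        gather_events[key] = coalescing_event

buckets = [Bucket("lm", True), Bucket("vision", False)]
events = {}
for b in buckets:
    gather(b, events)
coalesce([b.key for b in buckets], events, "coalescing_event")
assert "vision" not in events and buckets[1].status is Status.READY_TO_USE
```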