Fixes for Megatron Expert Parallel, GroupedMLP and SequentialMLP #831
base: main
```diff
@@ -62,6 +62,34 @@ def weight_only_quantize(model: nn.Module):
         seen_modules.add(module)
 
 
+def _has_expert_parallelism(module: nn.Module) -> bool:
+    """Check if module has expert parallelism enabled."""
+    ps = getattr(module, "parallel_state", None)
+    return ps is not None and ps.expert_model_parallel_group.is_initialized()
+
+
+def _check_moe_calibration_complete(quantizer, parallel_state):
+    """Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
+    if isinstance(quantizer, SequentialQuantizer):
+        for _q in quantizer:
+            _check_moe_calibration_complete(_q, parallel_state)
+        return
+    for group in [
+        parallel_state.data_parallel_group,
+        parallel_state.expert_model_parallel_group,
+        parallel_state.tensor_parallel_group,
+    ]:
+        if not group.is_initialized():
+            continue
+        has_amax = getattr(quantizer, "_amax", None) is not None
+        amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
+        if any(amax_states) and not all(amax_states):
+            raise RuntimeError(
+                "MoE calibration incomplete: some experts received no tokens during calibration. "
+                "Increase --calib-size to ensure all experts see calibration data."
+            )
+
+
 @torch.no_grad()
 def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
     """Calibrate the model using max.
```
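The new completeness check gathers each rank's "do I have an amax" flag across a process group via `DistributedProcessGroup.get_dist_syncd_obj` and errors out if the flags disagree. As a minimal sketch of the same idea using plain `torch.distributed` (the helper name and its arguments below are illustrative, not part of the library's API):

```python
# Illustrative sketch only: gather a per-rank boolean across a process group
# and flag a partially calibrated quantizer, mirroring the idea behind
# _check_moe_calibration_complete, but with plain torch.distributed instead of
# modelopt's DistributedProcessGroup helper.
import torch.distributed as dist


def check_flag_consistent(has_amax: bool, group=None) -> None:
    """Raise if some ranks in `group` have an amax and others do not."""
    if not (dist.is_available() and dist.is_initialized()):
        return  # single-process run: nothing to compare against
    world_size = dist.get_world_size(group=group)
    gathered = [None] * world_size
    # all_gather_object collects the Python object from every rank in the group
    dist.all_gather_object(gathered, has_amax, group=group)
    if any(gathered) and not all(gathered):
        raise RuntimeError(
            "MoE calibration incomplete: some experts received no tokens during calibration."
        )
```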
```diff
@@ -81,9 +109,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
         forward_loop(model)
     finish_stats_collection(model)
 
+    # Sync amax across local experts within each rank (for SequentialMLP)
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
```
Collaborator
I see you removed the all-expert-routing patch above. So will this function make sure those un-calibrated experts all have amax?

Contributor
Since the amax is synced across experts, the incomplete calibration shouldn't happen, right?

Contributor
And another thing: can the sync count the amax of an unseen expert? The current logic seems, the

Contributor
I'm not entirely sure this is true, because the amax is only "synced" between experts in a local layer, and we ran into deadlocks before when

Collaborator
I think it is best to know what
It is doing 2) for sure. The question is whether this function is doing 1). If it is not creating amax for those that didn't have amax before, then likely we still have the problem.

Contributor (Author)
As long as at least one expert on the EP rank sees a token, we should be good, I think.

Contributor (Author)
The problem arises if the EP degree is very high, such as 128 - then it is possible that some ranks have experts without any tokens.

Contributor (Author)
But I guess we never do PTQ with such extreme EP.

Contributor (Author)
Wdyt?

Contributor
Our previous logic was to force using all experts during calibration, which is equivalent to setting an unseen expert's
The same hunk continues after the review thread:

```diff
     if not distributed_sync:
         return
 
+    # Check MoE calibration completeness before sync
+    for name, module in model.named_modules():
+        if isinstance(module, QuantModule) and _has_expert_parallelism(module):
+            for child in module.children():
+                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
+                    _check_moe_calibration_complete(child, module.parallel_state)
+
     def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
```
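The thread above hinges on what `sync_moe_local_experts_amax` actually does; its implementation lives in the Megatron plugin and is not shown in this diff. As a rough, hypothetical sketch of the "take the max across local experts" behavior being discussed (the function name, the quantizer list, and the `amax` attribute layout here are assumptions, not the plugin's API):

```python
# Hypothetical sketch of a local-expert amax sync for a SequentialMLP-style
# module: every per-expert quantizer that already has an amax contributes to a
# running max, and that max is written back to all local experts. An expert
# that saw no tokens inherits the max of its local peers; a rank whose local
# experts all saw no tokens still ends up with no amax at all, which is the
# corner case _check_moe_calibration_complete is meant to catch.
import torch


def sync_local_experts_amax(expert_quantizers: list) -> None:
    amaxes = [q.amax for q in expert_quantizers if getattr(q, "amax", None) is not None]
    if not amaxes:
        return  # no local expert was calibrated; nothing to propagate
    stacked = torch.stack(amaxes)  # assumes all local experts share the same amax shape
    group_amax = stacked.amax(dim=0)
    for q in expert_quantizers:
        q.amax = group_amax.clone()
```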
```diff
@@ -182,10 +222,6 @@ def sync_quantizer_amax_across_tp(
                         parallel_state=module.parallel_state,
                     )
 
-            # MOE Quantization
-            if hasattr(module, "sync_moe_local_experts_amax"):
-                module.sync_moe_local_experts_amax()
-
             # KV Cache Quantization
             if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
                 # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
```
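As a usage note, this calibration path is normally reached through the top-level quantization API rather than by calling `max_calibrate` directly. A hedged sketch of how a PTQ driver might exercise it (the config choice and the dataloader are placeholders, not taken from this PR):

```python
# Hedged usage sketch: drive calibration through the public quantize() API.
# The forward_loop is what feeds tokens to the experts; with expert
# parallelism, too small a calibration set can leave some experts unseen,
# which the new _check_moe_calibration_complete guard turns into a clear error.
import modelopt.torch.quantization as mtq
import torch


def quantize_with_max_calib(model: torch.nn.Module, calib_dataloader) -> torch.nn.Module:
    """Run PTQ with max calibration; calib_dataloader is a placeholder iterable of input batches."""

    def forward_loop(m):
        # Feed enough batches that every expert is routed at least some tokens;
        # otherwise the completeness check raises at the end of calibration.
        for batch in calib_dataloader:
            m(batch)

    # Assumption: this is the usual entry point that ends up in max_calibrate above.
    return mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)
```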