From 83f3b4ed79aa2b1e07f86868ae0a61fc5f782662 Mon Sep 17 00:00:00 2001
From: realAsma
Date: Thu, 29 Jan 2026 15:58:59 -0800
Subject: [PATCH 1/2] Fixes for Megatron Expert Parallel, GroupedMLP and SequentialMLP

Signed-off-by: realAsma
---
 modelopt/torch/quantization/model_calib.py   | 43 +++++++++++++++++--
 .../torch/quantization/plugins/megatron.py   | 33 --------------
 2 files changed, 39 insertions(+), 37 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 0077d8666..361d32bce 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -62,6 +62,33 @@ def weight_only_quantize(model: nn.Module):
         seen_modules.add(module)
 
 
+def _is_moe_submodule(module: nn.Module) -> bool:
+    """Check if a module is an MoE submodule."""
+    parallel_state = getattr(module, "parallel_state", None)
+    if parallel_state is None:
+        return False
+    return parallel_state.expert_model_parallel_group.is_initialized()
+
+
+def _check_moe_calibration_complete(quantizer, parallel_state):
+    """Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
+    if isinstance(quantizer, SequentialQuantizer):
+        for _q in quantizer:
+            _check_moe_calibration_complete(_q, parallel_state)
+        return
+    for group in [parallel_state.data_parallel_group, parallel_state.expert_model_parallel_group,
+                  parallel_state.tensor_parallel_group]:
+        if not group.is_initialized():
+            continue
+        has_amax = getattr(quantizer, "_amax", None) is not None
+        amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
+        if any(amax_states) and not all(amax_states):
+            raise RuntimeError(
+                "MoE calibration incomplete: some experts received no tokens during calibration. "
+                "Increase --calib-size to ensure all experts see calibration data."
+            )
+
+
 @torch.no_grad()
 def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
     """Calibrate the model using max.
@@ -81,9 +108,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         forward_loop(model)
     finish_stats_collection(model)
 
+    # Sync amax across local experts within each rank (for SequentialMLP)
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
     if not distributed_sync:
         return
 
+    # Check MoE calibration completeness before sync
+    for name, module in model.named_modules():
+        if isinstance(module, QuantModule) and _is_moe_submodule(module):
+            for child in module.children():
+                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
+                    _check_moe_calibration_complete(child, module.parallel_state)
+
     def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
@@ -182,10 +221,6 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
         # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 6a2bdc15f..9d31f848b 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -24,7 +24,6 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
-import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
 import torch
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -735,35 +734,3 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         # Affine KVCache Quant bias vector.
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)
-
-
-@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
-class _QuantMoELayer(QuantModule):
-    """Module to support special handling of token dispatching during calibration.
-
-    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
-    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
-    returns.
-
-    When group routing is used (e.g. DSR1 or V3), with group_topk set, topk can never be all experts (otherwise,
-    the total indices will be out of bound). Since DSR1 is the only model that uses group routing and all experts
-    can be calibrated normally, we disable this WAR when group_topk is used.
-
-    If calibration is not enabled, this module behaves as a normal MoELayer.
-
-    Note:
-        There are new arguments (e.g. padding_mask) passed through the forward function. Since we always pass through
-        all the arguments, we use **args and **kwargs here.
- """ - - def _setup(self): - pass - - def forward(self, *args, **kwargs): - if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): - if self.config.moe_router_num_groups is None: - original_top_k = self.router.topk - self.router.topk = self.router.num_experts - super().forward(*args, **kwargs) - self.router.topk = original_top_k - return super().forward(*args, **kwargs) From 8880392f110ac67978cfda958e8f9316b135a7c3 Mon Sep 17 00:00:00 2001 From: realAsma Date: Fri, 30 Jan 2026 11:32:28 -0800 Subject: [PATCH 2/2] minor Signed-off-by: realAsma --- modelopt/torch/quantization/model_calib.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 361d32bce..61559ea41 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -62,12 +62,10 @@ def weight_only_quantize(model: nn.Module): seen_modules.add(module) -def _is_moe_submodule(module: nn.Module) -> bool: - """Check if a module is an MoE submodule.""" - parallel_state = getattr(module, "parallel_state", None) - if parallel_state is None: - return False - return parallel_state.expert_model_parallel_group.is_initialized() +def _has_expert_parallelism(module: nn.Module) -> bool: + """Check if module has expert parallelism enabled.""" + ps = getattr(module, "parallel_state", None) + return ps is not None and ps.expert_model_parallel_group.is_initialized() def _check_moe_calibration_complete(quantizer, parallel_state): @@ -76,8 +74,11 @@ def _check_moe_calibration_complete(quantizer, parallel_state): for _q in quantizer: _check_moe_calibration_complete(_q, parallel_state) return - for group in [parallel_state.data_parallel_group, parallel_state.expert_model_parallel_group, - parallel_state.tensor_parallel_group]: + for group in [ + parallel_state.data_parallel_group, + parallel_state.expert_model_parallel_group, + parallel_state.tensor_parallel_group, + ]: if not group.is_initialized(): continue has_amax = getattr(quantizer, "_amax", None) is not None @@ -118,7 +119,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis # Check MoE calibration completeness before sync for name, module in model.named_modules(): - if isinstance(module, QuantModule) and _is_moe_submodule(module): + if isinstance(module, QuantModule) and _has_expert_parallelism(module): for child in module.children(): if isinstance(child, (TensorQuantizer, SequentialQuantizer)): _check_moe_calibration_complete(child, module.parallel_state)