diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 0077d8666..61559ea41 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -62,6 +62,34 @@ def weight_only_quantize(model: nn.Module):
         seen_modules.add(module)
 
 
+def _has_expert_parallelism(module: nn.Module) -> bool:
+    """Check if module has expert parallelism enabled."""
+    ps = getattr(module, "parallel_state", None)
+    return ps is not None and ps.expert_model_parallel_group.is_initialized()
+
+
+def _check_moe_calibration_complete(quantizer, parallel_state):
+    """Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
+    if isinstance(quantizer, SequentialQuantizer):
+        for _q in quantizer:
+            _check_moe_calibration_complete(_q, parallel_state)
+        return
+    for group in [
+        parallel_state.data_parallel_group,
+        parallel_state.expert_model_parallel_group,
+        parallel_state.tensor_parallel_group,
+    ]:
+        if not group.is_initialized():
+            continue
+        has_amax = getattr(quantizer, "_amax", None) is not None
+        amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
+        if any(amax_states) and not all(amax_states):
+            raise RuntimeError(
+                "MoE calibration incomplete: some experts received no tokens during calibration. "
+                "Increase --calib-size to ensure all experts see calibration data."
+            )
+
+
 @torch.no_grad()
 def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
     """Calibrate the model using max.
@@ -81,9 +109,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
         forward_loop(model)
     finish_stats_collection(model)
 
+    # Sync amax across local experts within each rank (for SequentialMLP)
+    for name, module in model.named_modules():
+        if hasattr(module, "sync_moe_local_experts_amax"):
+            module.sync_moe_local_experts_amax()
+
     if not distributed_sync:
         return
 
+    # Check MoE calibration completeness before sync
+    for name, module in model.named_modules():
+        if isinstance(module, QuantModule) and _has_expert_parallelism(module):
+            for child in module.children():
+                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
+                    _check_moe_calibration_complete(child, module.parallel_state)
+
     def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
         """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
         if isinstance(quantizer, SequentialQuantizer):
@@ -182,10 +222,6 @@ def sync_quantizer_amax_across_tp(
                 parallel_state=module.parallel_state,
             )
 
-        # MOE Quantization
-        if hasattr(module, "sync_moe_local_experts_amax"):
-            module.sync_moe_local_experts_amax()
-
         # KV Cache Quantization
         if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
             # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 6a2bdc15f..9d31f848b 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -24,7 +24,6 @@
 import megatron.core.tensor_parallel.layers as megatron_parallel
 import megatron.core.transformer.mlp as megatron_mlp
 import megatron.core.transformer.moe.experts as megatron_moe
-import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
 import torch
 from megatron.core.parallel_state import get_data_parallel_group
 from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -735,35 +734,3 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         # Affine KVCache Quant bias vector.
         state_dict = self.state_dict(prefix="", keep_vars=True)
         return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)
-
-
-@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
-class _QuantMoELayer(QuantModule):
-    """Module to support special handling of token dispatching during calibration.
-
-    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
-    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
-    returns.
-
-    When group routing is used (e.g. DSR1 or V3), with group_topk set, topk can never be all experts (otherwise,
-    the total indices will be out of bound). Since DSR1 is the only model that uses group routing and all experts
-    can be calibrated normally, we disable this WAR when group_topk is used.
-
-    If calibration is not enabled, this module behaves as a normal MoELayer.
-
-    Note:
-        There are new arguments (e.g. padding_mask) passed through the forward function. Since we always pass through
-        all the arguments, we use **args and **kwargs here.
-    """
-
-    def _setup(self):
-        pass
-
-    def forward(self, *args, **kwargs):
-        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
-            if self.config.moe_router_num_groups is None:
-                original_top_k = self.router.topk
-                self.router.topk = self.router.num_experts
-                super().forward(*args, **kwargs)
-                self.router.topk = original_top_k
-        return super().forward(*args, **kwargs)
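
For context, a minimal standalone sketch of the rule that the new _check_moe_calibration_complete helper enforces after gathering the per-rank amax flags within a process group: calibration is treated as complete only when every rank in the group has an amax or none does. The helper name is_calibration_complete below is hypothetical and used only for illustration; it is not part of the patch or the modelopt API.

# Hypothetical illustration (not from the patch) of the completeness rule applied
# to the gathered per-rank amax flags in _check_moe_calibration_complete.
def is_calibration_complete(amax_states: list[bool]) -> bool:
    """amax_states holds one flag per rank: True if that rank's quantizer collected an amax."""
    return all(amax_states) or not any(amax_states)


if __name__ == "__main__":
    assert is_calibration_complete([True, True, True])        # every expert saw calibration tokens
    assert is_calibration_complete([False, False, False])     # quantizer disabled on all ranks
    assert not is_calibration_complete([True, False, True])   # starved expert -> RuntimeError in the patch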