44 changes: 40 additions & 4 deletions modelopt/torch/quantization/model_calib.py
@@ -62,6 +62,34 @@ def weight_only_quantize(model: nn.Module):
seen_modules.add(module)


def _has_expert_parallelism(module: nn.Module) -> bool:
    """Check if module has expert parallelism enabled."""
    ps = getattr(module, "parallel_state", None)
    return ps is not None and ps.expert_model_parallel_group.is_initialized()


def _check_moe_calibration_complete(quantizer, parallel_state):
    """Raise error if MoE calibration is incomplete (some ranks have amax, others don't)."""
    if isinstance(quantizer, SequentialQuantizer):
        for _q in quantizer:
            _check_moe_calibration_complete(_q, parallel_state)
        return
    for group in [
        parallel_state.data_parallel_group,
        parallel_state.expert_model_parallel_group,
        parallel_state.tensor_parallel_group,
    ]:
        if not group.is_initialized():
            continue
        has_amax = getattr(quantizer, "_amax", None) is not None
        amax_states = DistributedProcessGroup.get_dist_syncd_obj(has_amax, group, lambda objs: objs)
        if any(amax_states) and not all(amax_states):
            raise RuntimeError(
                "MoE calibration incomplete: some experts received no tokens during calibration. "
                "Increase --calib-size to ensure all experts see calibration data."
            )


@torch.no_grad()
def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
"""Calibrate the model using max.
Expand All @@ -81,9 +109,21 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
    forward_loop(model)
    finish_stats_collection(model)

    # Sync amax across local experts within each rank (for SequentialMLP)
    for name, module in model.named_modules():
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()
Collaborator

I see you removed the all-expert routing patch above. So this function will make sure those un-calibrated experts all have amax?

Contributor

Since the amax is synced across experts, the incomplete calibration shouldn't happen, right?

Contributor

And another thing: can the sync account for the amax of an unseen expert? The current logic seems to set weight_quantizer.amax to the max over all seen experts.
And for GroupedMLP, we don't have this issue, right?

Contributor
@jenchen13 jenchen13 Jan 30, 2026

> Since the amax is synced across experts, the incomplete calibration shouldn't happen,

I'm not entirely sure this is true, because the amax is only "synced" between experts within a local layer, and we ran into deadlocks before _QuantMoELayer.forward was introduced.
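
For context, a minimal sketch (hypothetical, not code from this PR) of the kind of deadlock being referred to: if participation in a collective is gated on a per-rank condition such as having an amax, ranks that skip the call leave the participating ranks blocked.

import torch.distributed as dist

def sync_amax_unsafe(quantizer, group):
    # BAD: only ranks whose quantizer has an amax enter the all_reduce, so the
    # ranks that did calibrate block forever waiting for peers that skipped it.
    if getattr(quantizer, "_amax", None) is not None:
        dist.all_reduce(quantizer._amax, op=dist.ReduceOp.MAX, group=group)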

Collaborator

I think it is best to know what sync_moe_local_experts_amax does and doesn't do:

  1. Does it create amax for input quantizers that didn't have an amax before, i.e. were not calibrated?
  2. It sets all quantizers that have an amax to the same max value.

It is doing 2) for sure. The question is whether this function is also doing 1). If it is not creating amax for quantizers that didn't have one before, then we likely still have the problem.
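
For concreteness, a rough sketch of what behaviour 2) alone would look like (hypothetical; attribute names such as weight_quantizer are assumptions, and the real sync_moe_local_experts_amax may differ):

import torch

def sync_local_experts_amax_sketch(local_experts):
    # Collect amax only from experts whose weight quantizer was actually calibrated.
    amaxes = [
        e.weight_quantizer.amax
        for e in local_experts
        if getattr(e.weight_quantizer, "amax", None) is not None
    ]
    if not amaxes:
        return
    shared = torch.stack(amaxes).amax(dim=0)
    for e in local_experts:
        # Behaviour 2): overwrite existing amax values with the shared max.
        # Behaviour 1) would additionally assign `shared` when amax is None.
        if getattr(e.weight_quantizer, "amax", None) is not None:
            e.weight_quantizer.amax = shared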

Contributor Author

As long as at least one expert on the EP rank sees a token, we should be good, I think.

Contributor Author

The problem arises if EP parallelism is very high, such as 128: then it is possible that some ranks have experts that receive no tokens at all.

Contributor Author

But I guess we never do PTQ with such extreme EP.

Contributor Author

Wdyt?

Contributor

Our previous logic was to force using all experts during calibration, which is equivalent to setting an unseen expert's weight_quantizer.amax to its weight's amax and its input_quantizer.amax to the amax of its input. If the previous result is good, we can still give users a force_using_all_experts option to create amax for unseen experts. And for GroupedMLP, is this silently set to True? We should give a warning if so.
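
A hypothetical sketch of such a fallback (the option name force_using_all_experts and the module layout, e.g. linear_fc1, are illustrative assumptions, not ModelOpt API):

def seed_unseen_expert_amax(local_experts):
    # For any expert whose weight quantizer never saw data, fall back to the
    # weight's own max magnitude so calibration never ends up incomplete.
    for expert in local_experts:
        wq = expert.linear_fc1.weight_quantizer
        if getattr(wq, "amax", None) is None:
            wq.amax = expert.linear_fc1.weight.abs().amax()
    # Note: the input_quantizer would still need real routed activations (or a
    # forced all-expert pass) to obtain a meaningful amax.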


    if not distributed_sync:
        return

    # Check MoE calibration completeness before sync
    for name, module in model.named_modules():
        if isinstance(module, QuantModule) and _has_expert_parallelism(module):
            for child in module.children():
                if isinstance(child, (TensorQuantizer, SequentialQuantizer)):
                    _check_moe_calibration_complete(child, module.parallel_state)

    def sync_quantizer_amax_across_dp_ep(quantizer, parallel_state):
        """Synchronize the amax across all ranks in the data parallel and expert parallel groups."""
        if isinstance(quantizer, SequentialQuantizer):
@@ -182,10 +222,6 @@ def sync_quantizer_amax_across_tp(
                parallel_state=module.parallel_state,
            )

        # MOE Quantization
        if hasattr(module, "sync_moe_local_experts_amax"):
            module.sync_moe_local_experts_amax()

        # KV Cache Quantization
        if hasattr(module, "k_bmm_quantizer") and hasattr(module, "v_bmm_quantizer"):
            # We only support KVCache quantization with scalar per-tensor states for now (NVFP4 & FP8 KV cache)
33 changes: 0 additions & 33 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -24,7 +24,6 @@
import megatron.core.tensor_parallel.layers as megatron_parallel
import megatron.core.transformer.mlp as megatron_mlp
import megatron.core.transformer.moe.experts as megatron_moe
import megatron.core.transformer.moe.moe_layer as megatron_moe_layer
import torch
from megatron.core.parallel_state import get_data_parallel_group
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
@@ -735,35 +734,3 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
        # Affine KVCache Quant bias vector.
        state_dict = self.state_dict(prefix="", keep_vars=True)
        return make_sharded_tensors_for_checkpoint(state_dict, prefix, {}, sharded_offsets)


@QuantModuleRegistry.register({megatron_moe_layer.MoELayer: "megatron_moe_MoELayer"})
class _QuantMoELayer(QuantModule):
    """Module to support special handling of token dispatching during calibration.

    During calibration, we forward all tokens to all experts so that all experts see sufficient tokens to calibrate.
    However, even in calibration mode, the actual top_k routing is used to calculate the actual outputs this instance
    returns.

    When group routing is used (e.g. DSR1 or V3), with group_topk set, topk can never be all experts (otherwise,
    the total indices will be out of bound). Since DSR1 is the only model that uses group routing and all experts
    can be calibrated normally, we disable this WAR when group_topk is used.

    If calibration is not enabled, this module behaves as a normal MoELayer.

    Note:
        There are new arguments (e.g. padding_mask) passed through the forward function. Since we always pass through
        all the arguments, we use *args and **kwargs here.
    """

    def _setup(self):
        pass

    def forward(self, *args, **kwargs):
        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
            if self.config.moe_router_num_groups is None:
                original_top_k = self.router.topk
                self.router.topk = self.router.num_experts
                super().forward(*args, **kwargs)
                self.router.topk = original_top_k
        return super().forward(*args, **kwargs)
return super().forward(*args, **kwargs)