From 0994db68b3ba37d5bcb76fc0ade561683d185589 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:01 +0530 Subject: [PATCH 01/15] Deprecate replace_modules_for_calibration --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 113fd4364..286d61d27 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +import warnings + import tqdm from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -20,6 +22,14 @@ def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + # This function is deprecated. Use moe_calibration_context instead. + warnings.warn( + "replace_modules_for_calibration is deprecated. \ + Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: From c7a99430d20507fdd50044118ac26b97702e8e39 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:21 +0530 Subject: [PATCH 02/15] Refactor MoECalibrationContext --- src/llmcompressor/modeling/moe_context.py | 155 ++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/llmcompressor/modeling/moe_context.py diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..835f36d13 --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,155 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications +2. PermanentMoECalibration: permanently modifies the model +""" + +import contextlib +from abc import ABC, abstractmethod +from typing import Dict, TypeVar, Union + +from transformers import PreTrainedModel + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationContext(ABC): + """ + Abstract base class for MoE calibration. + This provides a standardized interface for MoE model calibration. + """ + + @abstractmethod + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """ + Apply MoE calibration modifications to the model. + :param model: The model to modify + :param calibrate_all_experts: Whether to calibrate all + experts or only routed ones + """ + pass + + @abstractmethod + def restore(self, model: PreTrainedModel) -> None: + """ + Restore the model to its original state. + :param model: The model to restore + """ + pass + + +class ContextualMoECalibration(MoECalibrationContext): + """ + MoE calibration that uses context managers for temporary modifications. + This is suitable for models that need to be restored after calibration. + """ + + def __init__(self, model_class_name: str, update_function): + """ + Initialize the context manager-based MoE calibration. 
+ :param model_class_name: The class name of the model this context applies to + :param update_function: Function that applies the MoE modifications + """ + self.model_class_name = model_class_name + self.update_function = update_function + self._stack = None + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply MoE calibration modifications using context managers.""" + if self._stack is None: + self._stack = contextlib.ExitStack() + + self.update_function(model, self._stack, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore the model by exiting the context stack.""" + if self._stack is not None: + self._stack.close() + self._stack = None + + +class PermanentMoECalibration(MoECalibrationContext): + """ + MoE calibration context that permanently modifies the model. + This is suitable for models that can be loaded in their modified form + (e.g., Llama4 in vLLM). + """ + + def __init__(self, model_class_name: str, replacement_function): + """ + Initialize the permanent MoE calibration. + :param model_class_name: The class name of the model this context applies to + :param replacement_function: Function that permanently replaces MoE modules + """ + self.model_class_name = model_class_name + self.replacement_function = replacement_function + self._original_modules = {} + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply permanent MoE calibration modifications.""" + # Store original modules for potential restoration + for name, module in model.named_modules(): + if module.__class__.__name__ == self.model_class_name: + self._original_modules[name] = module + + # Apply the replacement + self.replacement_function(model, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore original modules (if needed).""" + # For permanent MoE calibrations, restoration is typically not needed + # as the model is meant to stay in its modified form + pass + + +# Registry for MoE calibrations +_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} + + +def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: + """ + Register a MoE calibration context for a model class. + :param model_class_name: The class name of the model + :param context: The MoE calibration context to register + """ + _MOE_CONTEXTS[model_class_name] = context + + +def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: + """ + Get the registered MoE calibration context for a model class. + :param model_class_name: The class name of the model + :return: The MoE calibration context or None if not found + """ + return _MOE_CONTEXTS.get(model_class_name) + + +def list_supported_models() -> list: + """ + List all model classes that have registered MoE calibration contexts. + :return: List of supported model class names + """ + return list(_MOE_CONTEXTS.keys()) + + +# Convenience function for backward compatibility +def create_context_manager_context(model_class_name: str, update_function): + """ + Create a context manager-based MoE calibration. + :param model_class_name: The class name of the model + :param update_function: Function that applies the MoE modifications + :return: A ContextualMoECalibration instance + """ + return ContextualMoECalibration(model_class_name, update_function) + + +def create_permanent_context(model_class_name: str, replacement_function): + """ + Create a permanent MoE calibration. 
+ :param model_class_name: The class name of the model + :param replacement_function: Function that permanently replaces MoE modules + :return: A PermanentMoECalibration instance + """ + return PermanentMoECalibration(model_class_name, replacement_function) From e374313a87908dcaed8160bff3a2ac06c1f0eb3a Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 20:23:03 +0530 Subject: [PATCH 03/15] Use moe_context in pipeline and by default and add tests --- src/llmcompressor/args/dataset_arguments.py | 35 +++- src/llmcompressor/modeling/moe_context.py | 176 ++++++++++++++++-- src/llmcompressor/modeling/prepare.py | 96 +++++++--- src/llmcompressor/pipelines/basic/pipeline.py | 4 +- .../pipelines/layer_sequential/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 3 +- .../modeling/test_calib_deepseek_v3.py | 57 +++--- .../modeling/test_calib_llama4.py | 57 +++--- .../modeling/test_calib_qwen3.py | 3 +- 9 files changed, 318 insertions(+), 116 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index da816584d..e3e92fdf4 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,6 +7,7 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. 
" + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 835f36d13..27e51c50e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -8,13 +8,72 @@ import contextlib from abc import ABC, abstractmethod -from typing import Dict, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union +import tqdm +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.utils.helpers import patch_attr + T = TypeVar("T", bound="MoECalibrationContext") +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). + + Attributes: + calibration_type: Type of calibration - MoECalibrationType.PERMANENT or + MoECalibrationType.CONTEXTUAL + target_class_name: The class name of the MoE module to replace + replace_function: Function that creates the replacement module + target_attribute: For contextual calibration, the attribute to replace + description: Optional description of the model configuration + """ + + calibration_type: MoECalibrationType + target_class_name: str + replace_function: Callable + target_attribute: Optional[str] = None + description: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if ( + self.calibration_type == MoECalibrationType.CONTEXTUAL + and self.target_attribute is None + ): + raise ValueError("target_attribute is required for contextual calibration") + + if ( + self.calibration_type == MoECalibrationType.PERMANENT + and self.target_attribute is not None + ): + raise ValueError( + "target_attribute should not be set for permanent calibration" + ) + + +# Registry of MoE model configurations +# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY +MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} + + class MoECalibrationContext(ABC): """ Abstract base class for MoE calibration. @@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N """Apply MoE calibration modifications using context managers.""" if self._stack is None: self._stack = contextlib.ExitStack() + self._stack.__enter__() self.update_function(model, self._stack, calibrate_all_experts) def restore(self, model: PreTrainedModel) -> None: """Restore the model by exiting the context stack.""" if self._stack is not None: - self._stack.close() + self._stack.__exit__(None, None, None) self._stack = None @@ -134,22 +194,108 @@ def list_supported_models() -> list: return list(_MOE_CONTEXTS.keys()) -# Convenience function for backward compatibility -def create_context_manager_context(model_class_name: str, update_function): +# Generic factory functions for creating MoE updaters +def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): """ - Create a context manager-based MoE calibration. 
- :param model_class_name: The class name of the model - :param update_function: Function that applies the MoE modifications - :return: A ContextualMoECalibration instance + Create a permanent MoE updater function for the given target class. + + Args: + target_class_name: The class name to look for in the model + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with PermanentMoECalibration """ - return ContextualMoECalibration(model_class_name, update_function) + def update_function(model: PreTrainedModel, calibrate_all_experts: bool): + """Update MoE modules for calibration.""" + for name, module in tqdm.tqdm(list(model.named_modules())): + if module.__class__.__name__ == target_class_name: + new_module = replace_function( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) + replace_module(model, name, new_module) + + return update_function -def create_permanent_context(model_class_name: str, replacement_function): + +def create_contextual_moe_updater( + target_class_name: str, target_attr: str, replace_function: Callable +): """ - Create a permanent MoE calibration. - :param model_class_name: The class name of the model - :param replacement_function: Function that permanently replaces MoE modules - :return: A PermanentMoECalibration instance + Create a contextual MoE updater function for the given target class and attribute. + + Args: + target_class_name: The class name to look for in the model + target_attr: The attribute name to replace within the target class + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with ContextualMoECalibration + """ + + def update_function( + model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool + ): + """Update MoE modules for calibration using context managers.""" + for module in model.modules(): + if module.__class__.__name__ == target_class_name: + stack.enter_context( + patch_attr( + module, + target_attr, + replace_function( + config=model.config, + module=getattr(module, target_attr), + calibrate_all_experts=calibrate_all_experts, + ), + ) + ) + + return update_function + + +def register_moe_model(model_class_name: str, config: MoEModelConfig): """ - return PermanentMoECalibration(model_class_name, replacement_function) + Register a MoE model with its configuration. + + Args: + model_class_name: The model class name + config: MoEModelConfig dataclass instance with calibration parameters + """ + if config.calibration_type == MoECalibrationType.PERMANENT: + updater = create_permanent_moe_updater( + config.target_class_name, config.replace_function + ) + context = PermanentMoECalibration(config.target_class_name, updater) + elif config.calibration_type == MoECalibrationType.CONTEXTUAL: + updater = create_contextual_moe_updater( + config.target_class_name, config.target_attribute, config.replace_function + ) + context = ContextualMoECalibration(model_class_name, updater) + else: + raise ValueError(f"Unknown MoE type: {config.calibration_type}") + + register_moe_context(model_class_name, context) + + +def register_moe_model_from_dict(model_class_name: str, config_dict: dict): + """ + Register a MoE model from a dictionary configuration (backward compatibility). 
+ + Args: + model_class_name: The model class name + config_dict: Dictionary with calibration parameters + """ + # Convert string calibration_type to enum + if "calibration_type" in config_dict and isinstance( + config_dict["calibration_type"], str + ): + config_dict["calibration_type"] = MoECalibrationType( + config_dict["calibration_type"] + ) + + config = MoEModelConfig(**config_dict) + register_moe_model(model_class_name, config) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 286d61d27..d3c985535 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,4 @@ +import contextlib import warnings import tqdm @@ -6,10 +7,15 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.modeling.moe_context import ( + MoECalibrationType, + MoEModelConfig, + get_moe_context, + register_moe_model, +) from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.utils.helpers import patch_attr -__all__ = ["replace_modules_for_calibration"] +__all__ = ["moe_calibration_context"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,8 +30,8 @@ def replace_modules_for_calibration( ) -> PreTrainedModel: # This function is deprecated. Use moe_calibration_context instead. warnings.warn( - "replace_modules_for_calibration is deprecated. \ - Use moe_calibration_context instead.", + "replace_modules_for_calibration is deprecated. " + "Use moe_calibration_context instead.", DeprecationWarning, stacklevel=2, ) @@ -45,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! 
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, - calibrate_all_experts: bool = False, + calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. + + Args: + model: The model to apply MoE calibration to + calibrate_all_experts: Whether to calibrate all experts or only routed ones + + Yields: + The model with MoE calibration applied + """ cls_name = model.__class__.__name__ - if cls_name in moe_context: - moe_context.get(cls_name)(model, stack, calibrate_all_experts) + moe_context = get_moe_context(cls_name) + + if moe_context is None: + # No MoE context registered for this model, yield unchanged + yield model + return + + # Apply MoE calibration + moe_context.apply(model, calibrate_all_experts) + + try: + yield model + finally: + # Restore original state + moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b15305..edcb46b09 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -46,9 +46,7 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb..bab732f48 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,8 +82,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py 
index f00a37e43..9827052b7 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -85,8 +85,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index ca7fb06af..239365fdd 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial import pytest @@ -9,7 +10,7 @@ DeepseekV3MoECalibrate, OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + 
assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 4eb609ca9..d5363d35c 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,3 +1,4 @@ +import contextlib import os from functools import partial @@ -10,7 +11,7 @@ Llama4TextMoe, SequentialLlama4TextMoe, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 822af18db..a20acf8a8 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: 
stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From 7fefaacbda6d461da51adfe517ab7e46d8cf6a38 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 04/15] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. 
Ensures experts are quantized correctly as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 28b57dda9..3a52986ad 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c3..9e5d1ca63 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,9 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# to enable proper expert calibration. # Select calibration dataset. 
 DATASET_ID = "HuggingFaceH4/ultrachat_200k"
stack.enter_context(calibration_forward_context(model)) stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From ba42881c3367a2b04be889ba122a71fd806803ec Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 08/15] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. 
Ensures experts are quantized correctly as not all experts are activated during calibration
 
-Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
+Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration, which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` and changes how the forward pass through the MoE block is handled during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance.
 
diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py
index 28b57dda9..3a52986ad 100644
--- a/examples/quantization_w4a4_fp4/llama4_example.py
+++ b/examples/quantization_w4a4_fp4/llama4_example.py
@@ -3,18 +3,15 @@
 from transformers import Llama4ForConditionalGeneration, Llama4Processor
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import QuantizationModifier
 
 # Select model and load it.
 model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
 model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto")
 processor = Llama4Processor.from_pretrained(model_id)
-# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`.
-# This change allows compatibility with vllm.
-# To apply your own custom module for experimentation, consider updating
-# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `SequentialLlama4TextMoe` modules will be applied during calibration
+# to enable proper expert calibration and vLLM compatibility.
 
 DATASET_ID = "neuralmagic/calibration"
 NUM_CALIBRATION_SAMPLES = 20
diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py
index 9977584c3..9e5d1ca63 100644
--- a/examples/quantizing_moe/deepseek_r1_example.py
+++ b/examples/quantizing_moe/deepseek_r1_example.py
@@ -2,7 +2,6 @@
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from llmcompressor import oneshot
-from llmcompressor.modeling import replace_modules_for_calibration
 from llmcompressor.modifiers.quantization import GPTQModifier
 
 # Select model and load it.
@@ -20,7 +19,9 @@
 model_id, torch_dtype="auto", config=config
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-model = replace_modules_for_calibration(model)
+# MoE calibration is now handled automatically by the pipeline.
+# The `DeepseekV3MoECalibrate` modules will be applied during calibration
+# to enable proper expert calibration.
 
 # Select calibration dataset. 
DATASET_ID = "HuggingFaceH4/ultrachat_200k" From 6a38a0fbccaeda55e901eec3e0a4576707d4b220 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Fri, 26 Sep 2025 18:11:53 +0530 Subject: [PATCH 09/15] Update docstrings to fix review comments --- examples/multimodal_vision/llama4_example.py | 5 +++++ src/llmcompressor/modeling/moe_context.py | 7 ++++--- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 6832ababb..aa88304c3 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -12,6 +12,11 @@ # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules will be applied during calibration # to enable proper expert calibration and vLLM compatibility. +# +# NOTE: This restructuring is specifically required for vLLM compatibility +# Users can customize the calibration behavior as needed by modifying the +# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). +# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your custom function. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 27e51c50e..ff2d9cf14 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -2,8 +2,8 @@ Standardized interface for MoE model calibration. MoE calibration context is used to apply MoE calibration modifications to the model. There are two types of MoE calibration contexts: -1. ContextualMoECalibration: uses context managers for temporary modifications -2. PermanentMoECalibration: permanently modifies the model +1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution +2. 
PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution """ import contextlib @@ -41,7 +41,8 @@ class MoEModelConfig: calibration_type: Type of calibration - MoECalibrationType.PERMANENT or MoECalibrationType.CONTEXTUAL target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module + replace_function: Function that creates the replacement module, + generally defined in modeling/model_name.py target_attribute: For contextual calibration, the attribute to replace description: Optional description of the model configuration """ From 9a131cbeff498a8f160d3cb6f36ece01eb0d94dd Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Fri, 26 Sep 2025 18:39:30 +0530 Subject: [PATCH 10/15] Fix style and quality checks --- examples/multimodal_vision/llama4_example.py | 6 +- src/llmcompressor/modeling/moe_context.py | 297 ------------------- 2 files changed, 4 insertions(+), 299 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index f3738429d..c7838bfbf 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -15,8 +15,10 @@ # # NOTE: This restructuring is specifically required for vLLM compatibility # Users can customize the calibration behavior as needed by modifying the -# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). -# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your custom function. +# To define custom calibration logic, implement your function in +# modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). +# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your +# custom function. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 558f9dcbf..c65a3a2b6 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -283,303 +283,6 @@ def register_moe_model(model_class_name: str, config: MoEModelConfig): register_moe_context(model_class_name, context) -def register_moe_model_from_dict(model_class_name: str, config_dict: dict): - """ - Register a MoE model from a dictionary configuration (backward compatibility). - - Args: - model_class_name: The model class name - config_dict: Dictionary with calibration parameters - """ - # Convert string calibration_type to enum - if "calibration_type" in config_dict and isinstance( - config_dict["calibration_type"], str - ): - config_dict["calibration_type"] = MoECalibrationType( - config_dict["calibration_type"] - ) - - config = MoEModelConfig(**config_dict) - register_moe_model(model_class_name, config) -""" -Standardized interface for MoE model calibration. -MoE calibration context is used to apply MoE calibration modifications to the model. -There are two types of MoE calibration contexts: -1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution -2. 
PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution -""" - -from abc import ABC, abstractmethod -from dataclasses import dataclass -from enum import Enum -from typing import Callable, Dict, Optional, TypeVar, Union - -from transformers import PreTrainedModel - -T = TypeVar("T", bound="MoECalibrationContext") - - -class MoECalibrationType(Enum): - """Enumeration of supported MoE calibration types.""" - - PERMANENT = "permanent" - CONTEXTUAL = "contextual" - - -@dataclass -class MoEModelConfig: - """ - Configuration for MoE model calibration. - - This dataclass defines the parameters needed to configure MoE calibration - for a specific model architecture. It follows the same pattern used by - other model configuration systems in the project (e.g., SmoothQuant, AWQ). - - Attributes: - calibration_type: Type of calibration - MoECalibrationType.PERMANENT or - MoECalibrationType.CONTEXTUAL - target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module, - generally defined in modeling/model_name.py - target_attribute: For contextual calibration, the attribute to replace - description: Optional description of the model configuration - """ - - calibration_type: MoECalibrationType - target_class_name: str - replace_function: Callable - target_attribute: Optional[str] = None - description: Optional[str] = None - - def __post_init__(self): - """Validate configuration after initialization.""" - if ( - self.calibration_type == MoECalibrationType.CONTEXTUAL - and self.target_attribute is None - ): - raise ValueError("target_attribute is required for contextual calibration") - - if ( - self.calibration_type == MoECalibrationType.PERMANENT - and self.target_attribute is not None - ): - raise ValueError( - "target_attribute should not be set for permanent calibration" - ) - - -# Registry of MoE model configurations -# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY -MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} - - -class MoECalibrationContext(ABC): - """ - Abstract base class for MoE calibration. - This provides a standardized interface for MoE model calibration. - """ - - @abstractmethod - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """ - Apply MoE calibration modifications to the model. - :param model: The model to modify - :param calibrate_all_experts: Whether to calibrate all - experts or only routed ones - """ - pass - - @abstractmethod - def restore(self, model: PreTrainedModel) -> None: - """ - Restore the model to its original state. - :param model: The model to restore - """ - pass - - -class ContextualMoECalibration(MoECalibrationContext): - """ - MoE calibration that uses context managers for temporary modifications. - This is suitable for models that need to be restored after calibration. - """ - - def __init__(self, model_class_name: str, update_function): - """ - Initialize the context manager-based MoE calibration. 
- :param model_class_name: The class name of the model this context applies to - :param update_function: Function that applies the MoE modifications - """ - self.model_class_name = model_class_name - self.update_function = update_function - self._stack = None - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply MoE calibration modifications using context managers.""" - if self._stack is None: - self._stack = contextlib.ExitStack() - self._stack.__enter__() - - self.update_function(model, self._stack, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore the model by exiting the context stack.""" - if self._stack is not None: - self._stack.__exit__(None, None, None) - self._stack = None - - -class PermanentMoECalibration(MoECalibrationContext): - """ - MoE calibration context that permanently modifies the model. - This is suitable for models that can be loaded in their modified form - (e.g., Llama4 in vLLM). - """ - - def __init__(self, model_class_name: str, replacement_function): - """ - Initialize the permanent MoE calibration. - :param model_class_name: The class name of the model this context applies to - :param replacement_function: Function that permanently replaces MoE modules - """ - self.model_class_name = model_class_name - self.replacement_function = replacement_function - self._original_modules = {} - - def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: - """Apply permanent MoE calibration modifications.""" - # Store original modules for potential restoration - for name, module in model.named_modules(): - if module.__class__.__name__ == self.model_class_name: - self._original_modules[name] = module - - # Apply the replacement - self.replacement_function(model, calibrate_all_experts) - - def restore(self, model: PreTrainedModel) -> None: - """Restore original modules (if needed).""" - # For permanent MoE calibrations, restoration is typically not needed - # as the model is meant to stay in its modified form - pass - - -# Registry for MoE calibrations -_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} - - -def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: - """ - Register a MoE calibration context for a model class. - :param model_class_name: The class name of the model - :param context: The MoE calibration context to register - """ - _MOE_CONTEXTS[model_class_name] = context - - -def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: - """ - Get the registered MoE calibration context for a model class. - :param model_class_name: The class name of the model - :return: The MoE calibration context or None if not found - """ - return _MOE_CONTEXTS.get(model_class_name) - - -def list_supported_models() -> list: - """ - List all model classes that have registered MoE calibration contexts. - :return: List of supported model class names - """ - return list(_MOE_CONTEXTS.keys()) - - -# Generic factory functions for creating MoE updaters -def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): - """ - Create a permanent MoE updater function for the given target class. 
- - Args: - target_class_name: The class name to look for in the model - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with PermanentMoECalibration - """ - - def update_function(model: PreTrainedModel, calibrate_all_experts: bool): - """Update MoE modules for calibration.""" - for name, module in tqdm.tqdm(list(model.named_modules())): - if module.__class__.__name__ == target_class_name: - new_module = replace_function( - config=model.config, - module=module, - calibrate_all_experts=calibrate_all_experts, - ) - replace_module(model, name, new_module) - - return update_function - - -def create_contextual_moe_updater( - target_class_name: str, target_attr: str, replace_function: Callable -): - """ - Create a contextual MoE updater function for the given target class and attribute. - - Args: - target_class_name: The class name to look for in the model - target_attr: The attribute name to replace within the target class - replace_function: Function that creates the replacement module - - Returns: - A function that can be used with ContextualMoECalibration - """ - - def update_function( - model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool - ): - """Update MoE modules for calibration using context managers.""" - for module in model.modules(): - if module.__class__.__name__ == target_class_name: - stack.enter_context( - patch_attr( - module, - target_attr, - replace_function( - config=model.config, - module=getattr(module, target_attr), - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) - - return update_function - - -def register_moe_model(model_class_name: str, config: MoEModelConfig): - """ - Register a MoE model with its configuration. - - Args: - model_class_name: The model class name - config: MoEModelConfig dataclass instance with calibration parameters - """ - if config.calibration_type == MoECalibrationType.PERMANENT: - updater = create_permanent_moe_updater( - config.target_class_name, config.replace_function - ) - context = PermanentMoECalibration(config.target_class_name, updater) - elif config.calibration_type == MoECalibrationType.CONTEXTUAL: - updater = create_contextual_moe_updater( - config.target_class_name, config.target_attribute, config.replace_function - ) - context = ContextualMoECalibration(model_class_name, updater) - else: - raise ValueError(f"Unknown MoE type: {config.calibration_type}") - - register_moe_context(model_class_name, context) - - def register_moe_model_from_dict(model_class_name: str, config_dict: dict): """ Register a MoE model from a dictionary configuration (backward compatibility). From 452042112d5ca4fe6382f746124b5e6a71abfc39 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:01 +0530 Subject: [PATCH 11/15] Deprecate replace_modules_for_calibration --- src/llmcompressor/modeling/prepare.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index e966761bd..138ad773c 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,5 @@ +import warnings + import tqdm from compressed_tensors.utils import replace_module from transformers import PreTrainedModel @@ -20,6 +22,14 @@ def replace_modules_for_calibration( model: PreTrainedModel, calibrate_all_experts: bool = True, ) -> PreTrainedModel: + # This function is deprecated. Use moe_calibration_context instead. 
+ warnings.warn( + "replace_modules_for_calibration is deprecated. \ + Use moe_calibration_context instead.", + DeprecationWarning, + stacklevel=2, + ) + for name, module in tqdm.tqdm(list(model.named_modules())): cls_name = module.__class__.__name__ if cls_name in replacements: From 1c15741db15d4b0a144775394bd932a64c9f67cf Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Tue, 23 Sep 2025 02:15:21 +0530 Subject: [PATCH 12/15] Refactor MoECalibrationContext --- src/llmcompressor/modeling/moe_context.py | 155 ++++++++++++++++++++++ 1 file changed, 155 insertions(+) create mode 100644 src/llmcompressor/modeling/moe_context.py diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py new file mode 100644 index 000000000..835f36d13 --- /dev/null +++ b/src/llmcompressor/modeling/moe_context.py @@ -0,0 +1,155 @@ +""" +Standardized interface for MoE model calibration. +MoE calibration context is used to apply MoE calibration modifications to the model. +There are two types of MoE calibration contexts: +1. ContextualMoECalibration: uses context managers for temporary modifications +2. PermanentMoECalibration: permanently modifies the model +""" + +import contextlib +from abc import ABC, abstractmethod +from typing import Dict, TypeVar, Union + +from transformers import PreTrainedModel + +T = TypeVar("T", bound="MoECalibrationContext") + + +class MoECalibrationContext(ABC): + """ + Abstract base class for MoE calibration. + This provides a standardized interface for MoE model calibration. + """ + + @abstractmethod + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """ + Apply MoE calibration modifications to the model. + :param model: The model to modify + :param calibrate_all_experts: Whether to calibrate all + experts or only routed ones + """ + pass + + @abstractmethod + def restore(self, model: PreTrainedModel) -> None: + """ + Restore the model to its original state. + :param model: The model to restore + """ + pass + + +class ContextualMoECalibration(MoECalibrationContext): + """ + MoE calibration that uses context managers for temporary modifications. + This is suitable for models that need to be restored after calibration. + """ + + def __init__(self, model_class_name: str, update_function): + """ + Initialize the context manager-based MoE calibration. + :param model_class_name: The class name of the model this context applies to + :param update_function: Function that applies the MoE modifications + """ + self.model_class_name = model_class_name + self.update_function = update_function + self._stack = None + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply MoE calibration modifications using context managers.""" + if self._stack is None: + self._stack = contextlib.ExitStack() + + self.update_function(model, self._stack, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore the model by exiting the context stack.""" + if self._stack is not None: + self._stack.close() + self._stack = None + + +class PermanentMoECalibration(MoECalibrationContext): + """ + MoE calibration context that permanently modifies the model. + This is suitable for models that can be loaded in their modified form + (e.g., Llama4 in vLLM). + """ + + def __init__(self, model_class_name: str, replacement_function): + """ + Initialize the permanent MoE calibration. 
+ :param model_class_name: The class name of the model this context applies to + :param replacement_function: Function that permanently replaces MoE modules + """ + self.model_class_name = model_class_name + self.replacement_function = replacement_function + self._original_modules = {} + + def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> None: + """Apply permanent MoE calibration modifications.""" + # Store original modules for potential restoration + for name, module in model.named_modules(): + if module.__class__.__name__ == self.model_class_name: + self._original_modules[name] = module + + # Apply the replacement + self.replacement_function(model, calibrate_all_experts) + + def restore(self, model: PreTrainedModel) -> None: + """Restore original modules (if needed).""" + # For permanent MoE calibrations, restoration is typically not needed + # as the model is meant to stay in its modified form + pass + + +# Registry for MoE calibrations +_MOE_CONTEXTS: Dict[str, MoECalibrationContext] = {} + + +def register_moe_context(model_class_name: str, context: MoECalibrationContext) -> None: + """ + Register a MoE calibration context for a model class. + :param model_class_name: The class name of the model + :param context: The MoE calibration context to register + """ + _MOE_CONTEXTS[model_class_name] = context + + +def get_moe_context(model_class_name: str) -> Union[MoECalibrationContext, None]: + """ + Get the registered MoE calibration context for a model class. + :param model_class_name: The class name of the model + :return: The MoE calibration context or None if not found + """ + return _MOE_CONTEXTS.get(model_class_name) + + +def list_supported_models() -> list: + """ + List all model classes that have registered MoE calibration contexts. + :return: List of supported model class names + """ + return list(_MOE_CONTEXTS.keys()) + + +# Convenience function for backward compatibility +def create_context_manager_context(model_class_name: str, update_function): + """ + Create a context manager-based MoE calibration. + :param model_class_name: The class name of the model + :param update_function: Function that applies the MoE modifications + :return: A ContextualMoECalibration instance + """ + return ContextualMoECalibration(model_class_name, update_function) + + +def create_permanent_context(model_class_name: str, replacement_function): + """ + Create a permanent MoE calibration. 
+ :param model_class_name: The class name of the model + :param replacement_function: Function that permanently replaces MoE modules + :return: A PermanentMoECalibration instance + """ + return PermanentMoECalibration(model_class_name, replacement_function) From d8fecb9527340b4b38e99a1beb86783f8cd34793 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 20:23:03 +0530 Subject: [PATCH 13/15] Use moe_context in pipeline and by default and add tests --- src/llmcompressor/args/dataset_arguments.py | 35 +++- src/llmcompressor/modeling/moe_context.py | 176 ++++++++++++++++-- src/llmcompressor/modeling/prepare.py | 94 +++++++--- src/llmcompressor/pipelines/basic/pipeline.py | 4 +- .../pipelines/layer_sequential/pipeline.py | 3 +- .../pipelines/sequential/pipeline.py | 3 +- .../modeling/test_calib_deepseek_v3.py | 57 +++--- .../modeling/test_calib_llama4.py | 57 +++--- .../modeling/test_calib_qwen3.py | 3 +- 9 files changed, 317 insertions(+), 115 deletions(-) diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py index da816584d..e3e92fdf4 100644 --- a/src/llmcompressor/args/dataset_arguments.py +++ b/src/llmcompressor/args/dataset_arguments.py @@ -7,6 +7,7 @@ HuggingFace datasets, custom JSON/CSV files, and DVC-managed datasets. """ +import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Union @@ -126,16 +127,6 @@ class DatasetArguments(CustomDatasetArguments): default=512, metadata={"help": "Number of samples to use for one-shot calibration"}, ) - calibrate_moe_context: bool = field( - default=False, - metadata={ - "help": "If during calibration, the MoE context should be enabled " - "for the given model. This usually involves updating all MoE modules " - "in the model for the duration of calibration. See moe_context under " - "modeling/prepare.py for a list of supported MoEs and their updated " - "module definitions" - }, - ) shuffle_calibration_samples: Optional[bool] = field( default=True, metadata={ @@ -181,6 +172,17 @@ class DatasetArguments(CustomDatasetArguments): ), }, ) + calibrate_moe_context: Optional[bool] = field( + default=None, + metadata={ + "help": ( + "DEPRECATED: This parameter is deprecated and will be \ + removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. " + "This parameter is ignored and will not affect the calibration process." + ), + }, + ) # --- pipeline arguments --- # pipeline: Optional[str] = field( default="independent", @@ -229,3 +231,16 @@ class DatasetArguments(CustomDatasetArguments): def is_dataset_provided(self) -> bool: return self.dataset is not None or self.dataset_path is not None + + def __post_init__(self): + """Post-initialization hook to issue deprecation warnings.""" + if self.calibrate_moe_context is not None: + warnings.warn( + "The 'calibrate_moe_context' parameter is deprecated\ + and will be removed in a future version. " + "MoE calibration context is now handled automatically by the pipeline. 
" + "This parameter is ignored and will not affect\ + the calibration process.", + DeprecationWarning, + stacklevel=2, + ) diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 835f36d13..27e51c50e 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -8,13 +8,72 @@ import contextlib from abc import ABC, abstractmethod -from typing import Dict, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Callable, Dict, Optional, TypeVar, Union +import tqdm +from compressed_tensors.utils import replace_module from transformers import PreTrainedModel +from llmcompressor.utils.helpers import patch_attr + T = TypeVar("T", bound="MoECalibrationContext") +class MoECalibrationType(Enum): + """Enumeration of supported MoE calibration types.""" + + PERMANENT = "permanent" + CONTEXTUAL = "contextual" + + +@dataclass +class MoEModelConfig: + """ + Configuration for MoE model calibration. + + This dataclass defines the parameters needed to configure MoE calibration + for a specific model architecture. It follows the same pattern used by + other model configuration systems in the project (e.g., SmoothQuant, AWQ). + + Attributes: + calibration_type: Type of calibration - MoECalibrationType.PERMANENT or + MoECalibrationType.CONTEXTUAL + target_class_name: The class name of the MoE module to replace + replace_function: Function that creates the replacement module + target_attribute: For contextual calibration, the attribute to replace + description: Optional description of the model configuration + """ + + calibration_type: MoECalibrationType + target_class_name: str + replace_function: Callable + target_attribute: Optional[str] = None + description: Optional[str] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if ( + self.calibration_type == MoECalibrationType.CONTEXTUAL + and self.target_attribute is None + ): + raise ValueError("target_attribute is required for contextual calibration") + + if ( + self.calibration_type == MoECalibrationType.PERMANENT + and self.target_attribute is not None + ): + raise ValueError( + "target_attribute should not be set for permanent calibration" + ) + + +# Registry of MoE model configurations +# Add new MoE models here following the same pattern as MAPPINGS_REGISTRY +MOE_MODEL_REGISTRY: Dict[str, MoEModelConfig] = {} + + class MoECalibrationContext(ABC): """ Abstract base class for MoE calibration. @@ -60,13 +119,14 @@ def apply(self, model: PreTrainedModel, calibrate_all_experts: bool = True) -> N """Apply MoE calibration modifications using context managers.""" if self._stack is None: self._stack = contextlib.ExitStack() + self._stack.__enter__() self.update_function(model, self._stack, calibrate_all_experts) def restore(self, model: PreTrainedModel) -> None: """Restore the model by exiting the context stack.""" if self._stack is not None: - self._stack.close() + self._stack.__exit__(None, None, None) self._stack = None @@ -134,22 +194,108 @@ def list_supported_models() -> list: return list(_MOE_CONTEXTS.keys()) -# Convenience function for backward compatibility -def create_context_manager_context(model_class_name: str, update_function): +# Generic factory functions for creating MoE updaters +def create_permanent_moe_updater(target_class_name: str, replace_function: Callable): """ - Create a context manager-based MoE calibration. 
- :param model_class_name: The class name of the model - :param update_function: Function that applies the MoE modifications - :return: A ContextualMoECalibration instance + Create a permanent MoE updater function for the given target class. + + Args: + target_class_name: The class name to look for in the model + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with PermanentMoECalibration """ - return ContextualMoECalibration(model_class_name, update_function) + def update_function(model: PreTrainedModel, calibrate_all_experts: bool): + """Update MoE modules for calibration.""" + for name, module in tqdm.tqdm(list(model.named_modules())): + if module.__class__.__name__ == target_class_name: + new_module = replace_function( + config=model.config, + module=module, + calibrate_all_experts=calibrate_all_experts, + ) + replace_module(model, name, new_module) + + return update_function -def create_permanent_context(model_class_name: str, replacement_function): + +def create_contextual_moe_updater( + target_class_name: str, target_attr: str, replace_function: Callable +): """ - Create a permanent MoE calibration. - :param model_class_name: The class name of the model - :param replacement_function: Function that permanently replaces MoE modules - :return: A PermanentMoECalibration instance + Create a contextual MoE updater function for the given target class and attribute. + + Args: + target_class_name: The class name to look for in the model + target_attr: The attribute name to replace within the target class + replace_function: Function that creates the replacement module + + Returns: + A function that can be used with ContextualMoECalibration + """ + + def update_function( + model: PreTrainedModel, stack: contextlib.ExitStack, calibrate_all_experts: bool + ): + """Update MoE modules for calibration using context managers.""" + for module in model.modules(): + if module.__class__.__name__ == target_class_name: + stack.enter_context( + patch_attr( + module, + target_attr, + replace_function( + config=model.config, + module=getattr(module, target_attr), + calibrate_all_experts=calibrate_all_experts, + ), + ) + ) + + return update_function + + +def register_moe_model(model_class_name: str, config: MoEModelConfig): """ - return PermanentMoECalibration(model_class_name, replacement_function) + Register a MoE model with its configuration. + + Args: + model_class_name: The model class name + config: MoEModelConfig dataclass instance with calibration parameters + """ + if config.calibration_type == MoECalibrationType.PERMANENT: + updater = create_permanent_moe_updater( + config.target_class_name, config.replace_function + ) + context = PermanentMoECalibration(config.target_class_name, updater) + elif config.calibration_type == MoECalibrationType.CONTEXTUAL: + updater = create_contextual_moe_updater( + config.target_class_name, config.target_attribute, config.replace_function + ) + context = ContextualMoECalibration(model_class_name, updater) + else: + raise ValueError(f"Unknown MoE type: {config.calibration_type}") + + register_moe_context(model_class_name, context) + + +def register_moe_model_from_dict(model_class_name: str, config_dict: dict): + """ + Register a MoE model from a dictionary configuration (backward compatibility). 
+ + Args: + model_class_name: The model class name + config_dict: Dictionary with calibration parameters + """ + # Convert string calibration_type to enum + if "calibration_type" in config_dict and isinstance( + config_dict["calibration_type"], str + ): + config_dict["calibration_type"] = MoECalibrationType( + config_dict["calibration_type"] + ) + + config = MoEModelConfig(**config_dict) + register_moe_model(model_class_name, config) diff --git a/src/llmcompressor/modeling/prepare.py b/src/llmcompressor/modeling/prepare.py index 138ad773c..d3c985535 100644 --- a/src/llmcompressor/modeling/prepare.py +++ b/src/llmcompressor/modeling/prepare.py @@ -1,3 +1,4 @@ +import contextlib import warnings import tqdm @@ -6,10 +7,15 @@ from llmcompressor.modeling.deepseek_v3 import replace as replace_deepseekv3 from llmcompressor.modeling.llama4 import replace as replace_llama4 +from llmcompressor.modeling.moe_context import ( + MoECalibrationType, + MoEModelConfig, + get_moe_context, + register_moe_model, +) from llmcompressor.modeling.qwen3_moe import replace as replace_Qwen3MoE -from llmcompressor.utils.helpers import patch_attr -__all__ = ["replace_modules_for_calibration"] +__all__ = ["moe_calibration_context"] # ---------------------- module replacements; permanent ------------------------- replacements = { @@ -24,8 +30,8 @@ def replace_modules_for_calibration( ) -> PreTrainedModel: # This function is deprecated. Use moe_calibration_context instead. warnings.warn( - "replace_modules_for_calibration is deprecated. \ - Use moe_calibration_context instead.", + "replace_modules_for_calibration is deprecated. " + "Use moe_calibration_context instead.", DeprecationWarning, stacklevel=2, ) @@ -45,37 +51,67 @@ def replace_modules_for_calibration( # ------------------- module replacements; during calibration -------------------- - -def update_qwen3_moe(model, stack, calibrate_all_experts): - for module in model.modules(): - cls_name = module.__class__.__name__ - if cls_name == "Qwen3MoeDecoderLayer": - # Optionally update the model.config to pass in other arguments - stack.enter_context( - patch_attr( - module, - "mlp", - replace_Qwen3MoE( - config=model.config, - module=module.mlp, - calibrate_all_experts=calibrate_all_experts, - ), - ) - ) +# MoE model configurations - centralized registry +# Adding a new MoE model is now as simple as adding an entry here! 
+# This follows the same pattern as MAPPINGS_REGISTRY in SmoothQuant and AWQ +MOE_EXPERTS_REPLACEMENTS = { + "Qwen3MoeForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.CONTEXTUAL, + target_class_name="Qwen3MoeDecoderLayer", + target_attribute="mlp", + replace_function=replace_Qwen3MoE, + description="Qwen3 MoE model with contextual calibration for MLP layers", + ), + "DeepseekV3ForCausalLM": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="DeepseekV3MoE", + replace_function=replace_deepseekv3, + description="DeepSeek V3 MoE model with permanent calibration", + ), + "Llama4ForConditionalGeneration": MoEModelConfig( + calibration_type=MoECalibrationType.PERMANENT, + target_class_name="Llama4TextMoe", + replace_function=replace_llama4, + description=( + "Llama4 MoE model with permanent calibration for vLLM compatibility" + ), + ), +} -moe_context = { - "Qwen3MoeForCausalLM": update_qwen3_moe, -} +# Register all MoE models automatically +for model_class_name, config in MOE_EXPERTS_REPLACEMENTS.items(): + register_moe_model(model_class_name, config) +@contextlib.contextmanager def moe_calibration_context( model: PreTrainedModel, - stack, calibrate_all_experts: bool = True, ): - # Temporarily updates the MoE modules within the context - # Once the context exists, parameter updates persist + """ + Context manager for MoE calibration that temporarily updates MoE modules. + + Args: + model: The model to apply MoE calibration to + calibrate_all_experts: Whether to calibrate all experts or only routed ones + + Yields: + The model with MoE calibration applied + """ cls_name = model.__class__.__name__ - if cls_name in moe_context: - moe_context.get(cls_name)(model, stack, calibrate_all_experts) + moe_context = get_moe_context(cls_name) + + if moe_context is None: + # No MoE context registered for this model, yield unchanged + yield model + return + + # Apply MoE calibration + moe_context.apply(model, calibrate_all_experts) + + try: + yield model + finally: + # Restore original state + moe_context.restore(model) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index db4b15305..edcb46b09 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -46,9 +46,7 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - - if dataset_args is not None and dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) diff --git a/src/llmcompressor/pipelines/layer_sequential/pipeline.py b/src/llmcompressor/pipelines/layer_sequential/pipeline.py index 565cb81fb..bab732f48 100644 --- a/src/llmcompressor/pipelines/layer_sequential/pipeline.py +++ b/src/llmcompressor/pipelines/layer_sequential/pipeline.py @@ -82,8 +82,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache intermediates: IntermediatesCache = capture_first_layer_intermediates( diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index f00a37e43..9827052b7 100644 --- 
a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -85,8 +85,7 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - if dataset_args.calibrate_moe_context: - moe_calibration_context(model, stack) + stack.enter_context(moe_calibration_context(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py index ca7fb06af..239365fdd 100644 --- a/tests/llmcompressor/modeling/test_calib_deepseek_v3.py +++ b/tests/llmcompressor/modeling/test_calib_deepseek_v3.py @@ -1,3 +1,4 @@ +import contextlib from functools import partial import pytest @@ -9,7 +10,7 @@ DeepseekV3MoECalibrate, OriginalDeepseekV3MoE, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -21,39 +22,43 @@ def test_calib_replace_deepseekv3moe_all_experts(model_stub): with skip_weights_download(): model = AutoModelForCausalLM.from_pretrained(model_stub) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Deepseek MoE layer - moe_layer = None - for _, module in model.named_modules(): - if isinstance(module, DeepseekV3MoECalibrate): - moe_layer = module - break + # Find a Deepseek MoE layer + moe_layer = None + for _, module in model.named_modules(): + if isinstance(module, DeepseekV3MoECalibrate): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=torch.float32) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not 
all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_llama4.py b/tests/llmcompressor/modeling/test_calib_llama4.py index 4eb609ca9..d5363d35c 100644 --- a/tests/llmcompressor/modeling/test_calib_llama4.py +++ b/tests/llmcompressor/modeling/test_calib_llama4.py @@ -1,3 +1,4 @@ +import contextlib import os from functools import partial @@ -10,7 +11,7 @@ Llama4TextMoe, SequentialLlama4TextMoe, ) -from llmcompressor.modeling.prepare import replace_modules_for_calibration +from llmcompressor.modeling.prepare import moe_calibration_context from llmcompressor.utils.dev import skip_weights_download from llmcompressor.utils.helpers import calibration_forward_context from tests.testing_utils import requires_cadence, requires_gpu @@ -28,39 +29,43 @@ def test_calib_replace_llama4_moe_all_experts(model_stub): model_stub, torch_dtype="auto" ) - replace_modules_for_calibration(model, calibrate_all_experts=True) + with contextlib.ExitStack() as stack: + stack.enter_context(calibration_forward_context(model)) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) - # Find a Llama4 MoE layer - moe_layer = None - for module in model.modules(): - if isinstance(module, SequentialLlama4TextMoe): - moe_layer = module - break + # Find a Llama4 MoE layer + moe_layer = None + for module in model.modules(): + if isinstance(module, SequentialLlama4TextMoe): + moe_layer = module + break - assert moe_layer is not None + assert moe_layer is not None - num_experts = len(moe_layer.experts) - expert_triggered = [False for _ in range(num_experts)] + num_experts = len(moe_layer.experts) + expert_triggered = [False for _ in range(num_experts)] - # Define the hook function - def hook_fn(i, module, input, output): - expert_triggered[i] = True + # Define the hook function + def hook_fn(i, module, input, output): + expert_triggered[i] = True - # Attach hooks using functools.partial to bind each index - for i, expert in enumerate(moe_layer.experts): - expert.register_forward_hook(partial(hook_fn, i)) + # Attach hooks using functools.partial to bind each index + for i, expert in enumerate(moe_layer.experts): + expert.register_forward_hook(partial(hook_fn, i)) - # Create dummy input tensor that simulates hidden_states - hidden_dim = model.config.text_config.hidden_size - batch, seq_len = 4, 32 - sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) + # Create dummy input tensor that simulates hidden_states + hidden_dim = model.config.text_config.hidden_size + batch, seq_len = 4, 32 + sample = torch.randn(batch, seq_len, hidden_dim, dtype=model.dtype) - # Forward through the MoE layer directly - with torch.no_grad(): - _ = moe_layer(sample) + # Forward through the MoE layer directly + with torch.no_grad(): + _ = moe_layer(sample) - # Assert all experts are used - assert all(expert_triggered), f"Not all experts were triggered: {expert_triggered}" + # Assert all experts are used + assert all( + expert_triggered + ), f"Not all experts were triggered: {expert_triggered}" @requires_gpu diff --git a/tests/llmcompressor/modeling/test_calib_qwen3.py b/tests/llmcompressor/modeling/test_calib_qwen3.py index 822af18db..a20acf8a8 100644 --- a/tests/llmcompressor/modeling/test_calib_qwen3.py +++ b/tests/llmcompressor/modeling/test_calib_qwen3.py @@ -26,8 +26,7 @@ def test_calib_replace_qwen3moe_all_experts(model_stub): with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) 
stack.enter_context(DisableQuantization(model)) - - moe_calibration_context(model, stack, calibrate_all_experts=True) + stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True)) # Find one MoE layer moe_layer = None From d099bf3488b14bae882455296084aa1276404fe6 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Wed, 24 Sep 2025 22:35:23 +0530 Subject: [PATCH 14/15] Update documentation --- examples/multimodal_vision/llama4_example.py | 9 +++------ examples/quantization_w4a4_fp4/README.md | 8 ++++---- examples/quantization_w4a4_fp4/llama4_example.py | 9 +++------ examples/quantizing_moe/deepseek_r1_example.py | 5 +++-- 4 files changed, 13 insertions(+), 18 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index 292fa8300..6832ababb 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/examples/quantization_w4a4_fp4/README.md b/examples/quantization_w4a4_fp4/README.md index ab9e3eb37..a0d458722 100644 --- a/examples/quantization_w4a4_fp4/README.md +++ b/examples/quantization_w4a4_fp4/README.md @@ -84,11 +84,11 @@ We have successfully created an `nvfp4` model! # Quantizing MoEs -To quantize MoEs, a few additional steps are required. An example quantizing Llama4 can be found under `llama4_example.py`. Here, we replace all `Llama4TextMoe` modules by calling `replace_modules_for_calibration`. This replacement allows us to: +To quantize MoEs, MoE calibration is now handled automatically by the pipeline. An example quantizing Llama4 can be found under `llama4_example.py`. The pipeline automatically applies the appropriate MoE calibration context which: -1. Linearize the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. -2. Ensure experts are quantized correctly as not all experts are activated during calibration +1. Linearizes the model to enable quantization and execution in vLLM. This is required as the native model definition does not include `torch.nn.Linear` layers in its MoE blocks, a requirement for LLM Compressor to run quantization. +2. Ensures experts are quantized correctly as not all experts are activated during calibration -Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. 
This model does not require additional linearization as required by the Llama4 model. However, similar to Llama4, in order to ensure the experts are quantized correctly, we can pass in `calibrate_moe_context` which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. +Similarly, an example quantizing the Qwen3-30B-A3B model can be found under `qwen_30b_a3b.py`. This model uses contextual MoE calibration which temporarily updates the model definition to use `Qwen3MoeSparseMoeBlock` which updates how the forward pass is handled in the MoE block during calibration. Feel free to update the definition under `llm-compressor/src/llmcompressor/modeling/qwen3_moe.py` to play around with this behavior and evaluate its impact on quantization performance. diff --git a/examples/quantization_w4a4_fp4/llama4_example.py b/examples/quantization_w4a4_fp4/llama4_example.py index 28b57dda9..3a52986ad 100644 --- a/examples/quantization_w4a4_fp4/llama4_example.py +++ b/examples/quantization_w4a4_fp4/llama4_example.py @@ -3,18 +3,15 @@ from transformers import Llama4ForConditionalGeneration, Llama4Processor from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import QuantizationModifier # Select model and load it. model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" model = Llama4ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = Llama4Processor.from_pretrained(model_id) -# We update `Llama4TextMoe` modules with custom `SequentialLlama4TextMoe`. -# This change allows compatibility with vllm. -# To apply your own custom module for experimentation, consider updating -# `SequentialLlama4TextMoe` under llmcompressor/modeling/llama4.py -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `SequentialLlama4TextMoe` modules will be applied during calibration +# to enable proper expert calibration and vLLM compatibility. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 20 diff --git a/examples/quantizing_moe/deepseek_r1_example.py b/examples/quantizing_moe/deepseek_r1_example.py index 9977584c3..9e5d1ca63 100644 --- a/examples/quantizing_moe/deepseek_r1_example.py +++ b/examples/quantizing_moe/deepseek_r1_example.py @@ -2,7 +2,6 @@ from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer from llmcompressor import oneshot -from llmcompressor.modeling import replace_modules_for_calibration from llmcompressor.modifiers.quantization import GPTQModifier # Select model and load it. @@ -20,7 +19,9 @@ model_id, torch_dtype="auto", config=config ) tokenizer = AutoTokenizer.from_pretrained(model_id) -model = replace_modules_for_calibration(model) +# MoE calibration is now handled automatically by the pipeline. +# The `DeepseekV3MoECalibrate` modules will be applied during calibration +# to enable proper expert calibration. # Select calibration dataset. 
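 # ultrachat_200k provides multi-turn chat prompts for calibration; a domain-specific text dataset can be swapped in here if needed.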
DATASET_ID = "HuggingFaceH4/ultrachat_200k" From d4a6a11421b430afa404dc89372be7dd9e70fbf2 Mon Sep 17 00:00:00 2001 From: Sairam Pillai Date: Fri, 26 Sep 2025 18:39:30 +0530 Subject: [PATCH 15/15] Fix style and quality checks --- examples/multimodal_vision/llama4_example.py | 8 +++++--- src/llmcompressor/modeling/moe_context.py | 6 +++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/examples/multimodal_vision/llama4_example.py b/examples/multimodal_vision/llama4_example.py index aa88304c3..c7838bfbf 100644 --- a/examples/multimodal_vision/llama4_example.py +++ b/examples/multimodal_vision/llama4_example.py @@ -12,11 +12,13 @@ # MoE calibration is now handled automatically by the pipeline. # The `SequentialLlama4TextMoe` modules will be applied during calibration # to enable proper expert calibration and vLLM compatibility. -# +# # NOTE: This restructuring is specifically required for vLLM compatibility # Users can customize the calibration behavior as needed by modifying the -# To define custom calibration logic, implement your function in modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). -# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your custom function. +# To define custom calibration logic, implement your function in +# modeling/llama4.py (e.g., `SequentialLlama4TextMoe`). +# Then, update `MOE_EXPERTS_REPLACEMENT` in prepare.py to reference your +# custom function. DATASET_ID = "neuralmagic/calibration" NUM_CALIBRATION_SAMPLES = 512 diff --git a/src/llmcompressor/modeling/moe_context.py b/src/llmcompressor/modeling/moe_context.py index 72e495049..9de74c20a 100644 --- a/src/llmcompressor/modeling/moe_context.py +++ b/src/llmcompressor/modeling/moe_context.py @@ -2,9 +2,9 @@ Standardized interface for MoE model calibration. MoE calibration context is used to apply MoE calibration modifications to the model. There are two types of MoE calibration contexts: -1. ContextualMoECalibration: uses context managers for temporary modifications +1. ContextualMoECalibration: uses context managers for temporary modifications and restores the model to its original state after pipeline execution -2. PermanentMoECalibration: permanently modifies the model and stays in its modified +2. PermanentMoECalibration: permanently modifies the model and stays in its modified form after pipeline execution """ @@ -43,7 +43,7 @@ class MoEModelConfig: calibration_type: Type of calibration - MoECalibrationType.PERMANENT or MoECalibrationType.CONTEXTUAL target_class_name: The class name of the MoE module to replace - replace_function: Function that creates the replacement module, + replace_function: Function that creates the replacement module generally defined in modeling/model_name.py target_attribute: For contextual calibration, the attribute to replace description: Optional description of the model configuration
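# Illustrative usage sketch of the refactored API introduced by this series, mirroring the
# updated tests: the MoE calibration context is entered as a context manager instead of
# calling replace_modules_for_calibration(). The model stub below is a placeholder; any
# architecture with a registered MoE calibration (e.g. DeepseekV3, Llama4, Qwen3-MoE)
# follows the same pattern.
import contextlib

import torch
from transformers import AutoModelForCausalLM

from llmcompressor.modeling.prepare import moe_calibration_context
from llmcompressor.utils.helpers import calibration_forward_context

model = AutoModelForCausalLM.from_pretrained("<moe-model-stub>", torch_dtype="auto")

with contextlib.ExitStack() as stack:
    stack.enter_context(calibration_forward_context(model))
    # Swaps MoE blocks for their calibration counterparts; contextual calibrations are
    # restored automatically when the stack exits, permanent ones stay in place.
    stack.enter_context(moe_calibration_context(model, calibrate_all_experts=True))

    with torch.no_grad():
        model(input_ids=torch.tensor([[1, 2, 3]]))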