diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index dc35a5c02f..3d3254b9b1 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -7,12 +7,14 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    match_modules_set,
     match_named_modules,
     update_offload_parameter,
 )
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
+from torch.utils._pytree import tree_leaves
 from tqdm import tqdm
 
 from llmcompressor.core import Event, EventType, State
@@ -28,7 +30,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import (
+    get_layer_by_name,
+    get_module_to_name_dict,
+)
 
 __all__ = ["AWQModifier"]
@@ -319,64 +324,48 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        for mapping_idx, mapping in enumerate(self.mappings):
-            num_skipped_mappings = 0
-
-            for smooth_name, smooth_layer in (
-                pbar := tqdm(
-                    match_named_modules(model, [mapping.smooth_layer], self.ignore)
-                )
+        module_to_name = get_module_to_name_dict(model)
+        for mapping in self.mappings:
+            for smooth_layers, *nested_balance_layers in match_modules_set(
+                model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
             ):
-                pbar.set_description(
-                    f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}"
-                    f" ({num_skipped_mappings} skipped)"
+                if len(smooth_layers) > 1:
+                    raise ValueError(
+                        "AWQ needs to match a single smooth layer for each mapping but "
+                        f"got {[module_to_name.get(s) for s in smooth_layers]}"
+                        f" for mapping: {mapping}"
+                    )
+                smooth_layer = smooth_layers[0]
+                smooth_name = module_to_name.get(smooth_layer)
+
+                # flatten the per-pattern match lists into one list:
+                # [[b00, b01, ...], [b10, b11, ...]] -> [b00, b01, ..., b10, b11, ...]
+                balance_layers = tree_leaves(nested_balance_layers)
+                balance_names = [
+                    module_to_name.get(balance_layer)
+                    for balance_layer in balance_layers
+                ]
+
+                all_compatible = _check_layers_are_compatible(
+                    smooth_layer, smooth_name, balance_layers, balance_names
                 )
-                smooth_parent_name = ".".join(smooth_name.split(".")[:-1])
-                smooth_parent = get_layer_by_name(smooth_parent_name, model)
-
-                balance_layers, balance_names = [], []
-                for balance_regex in mapping.balance_layers:
-                    # find the submodules that match the activation layer
-                    for balance_suffix, balance_layer in match_named_modules(
-                        smooth_parent, [balance_regex], self.ignore
-                    ):
-                        balance_name = f"{smooth_parent_name}.{balance_suffix}"
-
-                        # exclude v_proj->o_proj mappings whose shapes are incompatible
-                        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
-                        if (
-                            isinstance(smooth_layer, torch.nn.Linear)
-                            and isinstance(balance_layer, torch.nn.Linear)
-                            and balance_name.endswith(".o_proj")
-                            and (
-                                (
-                                    smooth_name.endswith(".v_proj")
-                                    and smooth_layer.out_features
-                                    != balance_layer.in_features
-                                )
-                                or (
-                                    smooth_name.endswith(".qkv_proj")
-                                    and smooth_layer.out_features
-                                    != 3 * balance_layer.in_features
-                                )
-                            )
-                        ):
-                            num_skipped_mappings += 1
-                            continue
-
-                        balance_layers.append(balance_layer)
-                        balance_names.append(balance_name)
+                # skip mapping if any of the balance layers are incompatible
+                if not all_compatible or len(balance_layers) == 0:
+                    logger.warning(
+                        f"Skipping AWQ for {smooth_name} for mapping {mapping}"
+                        + (
+                            " because incompatible balance layers were found"
+                            if not all_compatible
+                            else " because no balance layers were found"
+                        )
+                    )
 
-                if len(balance_layers) == 0:
                     continue
-                elif len(balance_layers) == 1:
-                    # for single balance layer, parent is the balance layer
-                    parent_name, parent = balance_name, balance_layer
-                else:
-                    # for multiple balance layers, find lowest common parent
-                    parent_name, parent = get_lowest_common_parent(balance_names, model)
+
+                ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
+                    balance_names, model, torch.nn.ModuleList
+                )
 
                 resolved_mappings.append(
                     ResolvedMapping(
@@ -384,8 +373,8 @@ def _set_resolved_mappings(self, model: Module) -> None:
                         smooth_layer,
                         balance_layers,
                         balance_names=balance_names,
-                        parent=parent,
-                        parent_name=parent_name,
+                        parent=ancestor,
+                        parent_name=ancestor_name,
                     )
                 )
         self._resolved_mappings = resolved_mappings
@@ -721,6 +710,69 @@ def _assert_all_activations_consumed(self):
         raise RuntimeError("Some cached activations were not used")
 
 
+def _check_layers_are_compatible(
+    smooth_layer, smooth_name, balance_layers, balance_names
+):
+    """
+    Returns True if the smooth layer is shape-compatible with every balance
+    layer, False if any smooth & balance pair is incompatible.
+    """
+    for balance_layer, balance_name in zip(balance_layers, balance_names):
+        # exclude v_proj->o_proj mappings whose shapes are incompatible
+        # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777
+        if (
+            isinstance(smooth_layer, torch.nn.Linear)
+            and isinstance(balance_layer, torch.nn.Linear)
+            and balance_name.endswith(".o_proj")
+            and (
+                (
+                    smooth_name.endswith(".v_proj")
+                    and smooth_layer.out_features != balance_layer.in_features
+                )
+                or (
+                    smooth_name.endswith(".qkv_proj")
+                    and smooth_layer.out_features != 3 * balance_layer.in_features
+                )
+            )
+        ):
+            return False
+    return True
+
+
+def get_lowest_ancestor_with_avoid(
+    names: list[str], model: Module, avoid=torch.nn.Module
+) -> tuple[str, Module]:
+    """
+    Given a list of module names, returns the name of and pointer to their
+    lowest common ancestor that is not an instance of the avoided class/type.
+
+    NOTE: primarily used to exclude ancestors of type ModuleList, which don't
+    play nicely with hooks because their forward method is never directly
+    called for MoE models. See Qwen3MoeSparseMoeBlock for example: experts
+    are selected based on the router output, and their forward methods are
+    called directly.
+    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
+
+    Implementation is a small alteration of os.path.commonprefix
+    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
+    """
+    # longest common prefix of the names, trimmed at a "." boundary
+    s1, s2 = min(names), max(names)
+    name = s1
+    for i, c in enumerate(s1):
+        if c != s2[i]:
+            name = s1[:i].rstrip(".")
+            break
+
+    while True:
+        if name == "":
+            return "", model
+        ancestor = get_layer_by_name(name, model)
+        if not isinstance(ancestor, avoid):
+            return name, ancestor
+        name = ".".join(name.split(".")[:-1])
+
+
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
@@ -779,35 +831,3 @@ def _accumulate_mean(
     new_count = prev_count + num_added
 
     return (prev_sum + sum_added) / new_count, new_count
-
-
-def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]:
-    """
-    Given a list of names, returns the lowest-scope common parent.
-
-    NOTE: function excludes parents of type ModuleList, which don't play
-    nicely with hooks because their forward method is never directly
-    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
-    are selected based on router output and their forward method is called.
-    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
-
-    Returns name of parent and pointer to parent module
-
-    Implementation is a small alteration of os.path.commonprefix
-    https://docs.python.org/3/library/os.path.html#os.path.commonprefix
-    """
-    s1 = min(names)
-    s2 = max(names)
-    parent_name = ""
-    for i, c in enumerate(s1):
-        if c != s2[i]:
-            parent_name = s1[:i].rstrip(".")
-            break
-
-    while True:
-        if parent_name == "":
-            return "", module
-        parent = get_layer_by_name(parent_name, module)
-        if not isinstance(parent, torch.nn.ModuleList):
-            return parent_name, parent
-        parent_name = ".".join(parent_name.split(".")[:-1])
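A note on the data shape driving the `tree_leaves` call above: `match_modules_set` yields one list of matched modules per pattern in the set. A minimal sketch, with placeholder name strings standing in for the actual module objects (names are illustrative, not real layers):

```python
from torch.utils._pytree import tree_leaves

# One yielded set for AWQMapping("re:.*input_layernorm",
# ["re:.*gate_proj", "re:.*up_proj"]) over a 2-expert MoE block:
# a list of matches per pattern, shown here as names instead of modules.
smooth_layers = ["mlp.input_layernorm"]
nested_balance_layers = [
    ["mlp.experts.0.gate_proj", "mlp.experts.1.gate_proj"],  # pattern 0 matches
    ["mlp.experts.0.up_proj", "mlp.experts.1.up_proj"],      # pattern 1 matches
]

# tree_leaves flattens the per-pattern lists into one flat balance list
assert tree_leaves(nested_balance_layers) == [
    "mlp.experts.0.gate_proj",
    "mlp.experts.1.gate_proj",
    "mlp.experts.0.up_proj",
    "mlp.experts.1.up_proj",
]
```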
""" resolved_mappings = [] - for to_balance, to_smooth in self.mappings: - to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth - - for smooth_name, smooth_layer in match_named_modules( - model, to_smooth_list, self.ignore + module_to_name = get_module_to_name_dict(model) + for mapping in self.mappings: + for *nested_balance_layers, smooth_layers in match_modules_set( + model, tree_leaves(mapping), self.ignore ): - # Search for balance layers within the parent scope - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) - - balance_layers = [ - balance_layer - for _, balance_layer in match_named_modules( - smooth_parent, to_balance, self.ignore - ) - ] - - if balance_layers: - resolved_mappings.append( - SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) + if len(smooth_layers) > 1: + raise ValueError( + "SmoothQuant must match a single smooth layer for each mapping" + f" but got {[module_to_name.get(s) for s in smooth_layers]}" + f" for mapping: {mapping}" ) + smooth_layer = smooth_layers[0] + smooth_name = module_to_name.get(smooth_layers[0]) + balance_layers = tree_leaves(nested_balance_layers) + resolved_mappings.append( + SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) + ) return resolved_mappings diff --git a/src/llmcompressor/modifiers/smoothquant/utils.py b/src/llmcompressor/modifiers/smoothquant/utils.py index 8ab38f633a..62f544daeb 100644 --- a/src/llmcompressor/modifiers/smoothquant/utils.py +++ b/src/llmcompressor/modifiers/smoothquant/utils.py @@ -1,6 +1,6 @@ import functools from collections import namedtuple -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple from loguru import logger @@ -10,7 +10,7 @@ "DEFAULT_SMOOTHQUANT_MAPPINGS", ] -LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]] +LayerMapType = Tuple[List[str], str] LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"]) DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [ diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 58f2e2977a..d61fffec6d 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -11,6 +11,7 @@ ) from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator +from torch.utils._pytree import tree_leaves from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State @@ -205,7 +206,9 @@ def _fuse_norms(self, model: PreTrainedModel): for norm, *linears in match_modules_set( model, (mapping.norm, *mapping.linears) ): - fuse_norm_linears(norm, linears) + # match_modules_set returns a list of lists + assert len(norm) == 1 + fuse_norm_linears(norm[0], tree_leaves(linears)) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 6d2152fe47..19326c3a94 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -10,6 +10,7 @@ import torch from compressed_tensors import InternalModule from compressed_tensors.quantization.utils import is_module_quantized +from loguru import logger from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd from transformers import PreTrainedModel @@ 
diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py
index 58f2e2977a..d61fffec6d 100644
--- a/src/llmcompressor/modifiers/transform/spinquant/base.py
+++ b/src/llmcompressor/modifiers/transform/spinquant/base.py
@@ -11,6 +11,7 @@
 )
 from compressed_tensors.utils import TorchDtype, get_head_dim
 from pydantic import Field, ValidationInfo, field_validator
+from torch.utils._pytree import tree_leaves
 from transformers import PreTrainedModel
 
 from llmcompressor.core import Event, EventType, State
@@ -205,7 +206,9 @@ def _fuse_norms(self, model: PreTrainedModel):
         for norm, *linears in match_modules_set(
             model, (mapping.norm, *mapping.linears)
         ):
-            fuse_norm_linears(norm, linears)
+            # match_modules_set yields one list of matched modules per pattern
+            assert len(norm) == 1
+            fuse_norm_linears(norm[0], tree_leaves(linears))
 
     def _create_r1_scheme(self) -> TransformScheme:
         return TransformScheme(
diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py
index 6d2152fe47..19326c3a94 100644
--- a/src/llmcompressor/utils/pytorch/module.py
+++ b/src/llmcompressor/utils/pytorch/module.py
@@ -10,6 +10,7 @@
 import torch
 from compressed_tensors import InternalModule
 from compressed_tensors.quantization.utils import is_module_quantized
+from loguru import logger
 from torch.nn import Linear, Module, Parameter
 from torch.nn.modules.conv import _ConvNd
 from transformers import PreTrainedModel
@@ -369,3 +370,16 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     if not layer_name:
         return module
     return attrgetter(layer_name)(module)
+
+
+def get_module_to_name_dict(model: Module) -> dict[Module, str]:
+    """Map each module instance in the model to its dotted name."""
+    module_to_name = {}
+    for name, module in model.named_modules():
+        if module in module_to_name:
+            logger.warning(
+                f"{name} and {module_to_name[module]} refer to the same "
+                "module, which can result in unexpected behavior"
+            )
+        module_to_name[module] = name
+    return module_to_name
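For reference, the helper amounts to inverting `named_modules()`, plus the duplicate warning. A minimal sketch on a toy model; note that `named_modules()` skips duplicate module instances by default, so a module registered under several names is reported under the first name it appears with:

```python
import torch
from torch.nn import Linear, ModuleDict

# hypothetical toy model; "decoder.fc" is the only Linear
model = ModuleDict({"decoder": ModuleDict({"fc": Linear(4, 4)})})

module_to_name = {module: name for name, module in model.named_modules()}
assert module_to_name[model["decoder"]["fc"]] == "decoder.fc"
assert module_to_name[model] == ""  # the root module maps to the empty name
```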
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 950ab0f51a..5d82da29a1 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -2,9 +2,9 @@
 import torch
 from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme
 from pydantic import ValidationError
+from torch.nn import Linear
 
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
-from llmcompressor.modifiers.awq.base import get_lowest_common_parent
 from llmcompressor.modifiers.factory import ModifierFactory
@@ -40,16 +40,26 @@ def test_set_resolved_mappings():
     )
     self_attn = torch.nn.ModuleDict(
         {
-            "q_proj": torch.nn.Linear(4, 4),
-            "k_proj": torch.nn.Linear(4, 4),
-            "v_proj": torch.nn.Linear(4, 4),
-            "o_proj": torch.nn.Linear(4, 4),
+            "q_proj": Linear(4, 4),
+            "k_proj": Linear(4, 4),
+            "v_proj": Linear(4, 4),
+            "o_proj": Linear(4, 4),
         }
     )
     mlp = torch.nn.ModuleDict(
         {
-            "up_proj": torch.nn.Linear(4, 10),
-            "down_proj": torch.nn.Linear(10, 4),
+            "experts": torch.nn.ModuleList(
+                [
+                    torch.nn.ModuleDict(
+                        {
+                            "gate_proj": Linear(4, 2),
+                            "up_proj": Linear(4, 2),
+                            "down_proj": Linear(2, 4),
+                        }
+                    )
+                    for _ in range(3)
+                ]
+            )
         }
     )
     model = torch.nn.ModuleDict(
@@ -81,14 +91,19 @@ def test_set_resolved_mappings():
         if "self_attn.v_proj" in mapping.smooth_name:
             assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"}
             assert mapping.parent_name == "decoder.self_attn.o_proj"
-        if "mlp.up_proj" in mapping.smooth_name:
-            assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
-            assert mapping.parent_name == "decoder.mlp.down_proj"
+        if "mlp.experts" in mapping.smooth_name and "up_proj" in mapping.smooth_name:
+            expert_idx = mapping.smooth_name.split(".")[-2]
+            expected_down_proj = f"decoder.mlp.experts.{expert_idx}.down_proj"
+            assert set(mapping.balance_names) == {expected_down_proj}
+            assert mapping.parent_name == expected_down_proj
+            assert mapping.parent == mlp["experts"][int(expert_idx)]["down_proj"]
 
-    # make sure we exclude case where o_proj/v_proj shapes are mismatched
     awq = AWQModifier(
         mappings=[
+            # make sure we exclude case where o_proj/v_proj shapes are mismatched
             AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+            # make sure we exclude mapping if any balance layers are skipped
+            AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]),
         ],
         scheme="W4A16_ASYM",
     )
@@ -98,10 +113,11 @@ def test_set_resolved_mappings():
             {
                 "self_attn": torch.nn.ModuleDict(
                     {
-                        "q_proj": torch.nn.Linear(4, 2),
-                        "k_proj": torch.nn.Linear(4, 2),
-                        "v_proj": torch.nn.Linear(4, 2),
-                        "o_proj": torch.nn.Linear(4, 4),
+                        "q_proj": Linear(4, 2),
+                        "k_proj": Linear(4, 2),
+                        "v_proj": Linear(4, 2),
+                        "z_proj": Linear(2, 4),
+                        "o_proj": Linear(4, 4),
                     }
                 )
             }
@@ -109,6 +125,16 @@ def test_set_resolved_mappings():
         }
     )
     awq._set_resolved_mappings(model)
+    if len(awq._resolved_mappings) > 0:
+        assert all(
+            "o_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), "should have skipped v->o mapping because o is incompatible"
+        assert all(
+            "z_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), (
+            "should have skipped v->[z,o] mapping because o is incompatible "
+            "even though z is compatible"
+        )
     assert len(awq._resolved_mappings) == 0
@@ -179,58 +205,61 @@ def test_validate():
 
 
 @pytest.mark.unit
-def test_get_lowest_common_parent():
+def test_moe_multiple_balance_layers():
+    """Test AWQ mapping with multiple balance layers in MoE architecture"""
+    awq = AWQModifier(
+        mappings=[
+            # Map input_layernorm to multiple experts' gate_proj and up_proj
+            AWQMapping(
+                "re:.*input_layernorm",
+                ["re:.*gate_proj", "re:.*up_proj"],
+            ),
+        ],
+        scheme="W4A16_ASYM",
+    )
+
+    # Create a simplified MoE model structure
     mlp = torch.nn.ModuleDict(
         {
             "experts": torch.nn.ModuleList(
                 [
                     torch.nn.ModuleDict(
                         {
-                            "gate_proj": torch.nn.Linear(4, 2),
-                            "down_proj": torch.nn.Linear(4, 2),
+                            "gate_proj": Linear(4, 4),
+                            "up_proj": Linear(4, 4),
+                            "down_proj": Linear(4, 4),
                         }
                     )
-                    for _ in range(10)
+                    for _ in range(2)
                 ]
             )
         }
     )
-    self_attn = torch.nn.ModuleDict(
-        {
-            "q_proj": torch.nn.Linear(4, 2),
-            "k_proj": torch.nn.Linear(4, 2),
-            "v_proj": torch.nn.Linear(4, 2),
-            "o_proj": torch.nn.Linear(4, 4),
-        }
-    )
     model = torch.nn.ModuleDict(
         {
-            "embed_tokens": torch.nn.Linear(4, 2),
-            "decoder": torch.nn.ModuleDict(
+            "layer": torch.nn.ModuleDict(
                 {
-                    "self_attn": self_attn,
+                    "input_layernorm": torch.nn.LayerNorm(4),
                     "mlp": mlp,
                 }
-            ),
+            )
         }
     )
-    parent_name, parent = get_lowest_common_parent(
-        ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model
-    )
-    assert parent_name == "decoder.mlp" and parent == mlp
+    awq._set_resolved_mappings(model)
 
-    parent_name, parent = get_lowest_common_parent(
-        ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model
-    )
-    assert parent_name == "decoder.self_attn" and parent == self_attn
+    # Should have one mapping for input_layernorm
+    assert len(awq._resolved_mappings) == 1
+    mapping = awq._resolved_mappings[0]
 
-    parent_name, parent = get_lowest_common_parent(
-        ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model
-    )
-    assert parent_name == "decoder" and parent == model["decoder"]
+    # Should map to all gate_proj and up_proj across all experts
+    expected_balance_names = {
+        "layer.mlp.experts.0.gate_proj",
+        "layer.mlp.experts.0.up_proj",
+        "layer.mlp.experts.1.gate_proj",
+        "layer.mlp.experts.1.up_proj",
+    }
+    assert set(mapping.balance_names) == expected_balance_names
 
-    parent_name, parent = get_lowest_common_parent(
-        ["embed_tokens", "decoder.self_attn.v_proj"], model
-    )
-    assert parent_name == "" and parent == model
+    assert mapping.parent_name == "layer.mlp"
+    assert mapping.parent == mlp
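Assuming the list-of-names signature fixed above, the MoE test exercises the ancestor walk like this; a standalone sketch of the same lookup (toy modules, names as in the test):

```python
import torch
from llmcompressor.modifiers.awq.base import get_lowest_ancestor_with_avoid

mlp = torch.nn.ModuleDict(
    {
        "experts": torch.nn.ModuleList(
            [
                torch.nn.ModuleDict(
                    {
                        "gate_proj": torch.nn.Linear(4, 4),
                        "up_proj": torch.nn.Linear(4, 4),
                    }
                )
                for _ in range(2)
            ]
        )
    }
)
model = torch.nn.ModuleDict({"layer": torch.nn.ModuleDict({"mlp": mlp})})

# the longest common prefix "layer.mlp.experts" is a ModuleList, so the walk
# steps up one level and settles on "layer.mlp"
ancestor_name, ancestor = get_lowest_ancestor_with_avoid(
    ["layer.mlp.experts.0.gate_proj", "layer.mlp.experts.1.up_proj"],
    model,
    torch.nn.ModuleList,
)
assert ancestor_name == "layer.mlp" and ancestor is mlp
```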
diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
index f3e948469e..49d0cdda3a 100644
--- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
@@ -8,18 +8,15 @@
 
 @pytest.mark.unit
 def test_log_equalization_mapping(state):
-    # Use regex patterns with parent-scoped search
-    # Searches for balance layers within the parent of smooth layer
-    mappings = [(["re:^fc2$"], "re:.*block1\\.fc1$")]
+    mappings = [(["seq.fc2"], "seq.block1.fc1")]
     modifier = LogarithmicEqualizationModifier(mappings=mappings)
 
     modifier.ignore = []
     modifier.resolved_mappings_ = modifier._resolve_mappings(state.model)
 
-    assert len(modifier.resolved_mappings_) == 1
+    assert len(modifier.resolved_mappings_) == len(mappings)
 
     mapping = modifier.resolved_mappings_[0]
-    assert mapping.smooth_name == "seq.block1.fc1"
+    assert mapping.smooth_name == mappings[0][1]
     assert isinstance(mapping.smooth_layer, Linear)
-    assert len(mapping.balance_layers) == 1
     assert isinstance(mapping.balance_layers[0], Linear)
diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py
index aefa6b9579..ee8844041e 100644
--- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py
@@ -6,21 +6,15 @@
 
 @pytest.mark.unit
 def test_smooth_quant_mapping(state):
-    # Use regex patterns with parent-scoped search
-    # ^fc1$ matches only direct child "fc1", not nested "block1.fc1"
-    mappings = [(["re:^fc1$"], "re:.*fc2$")]
+    mappings = [(["seq.fc1"], "seq.fc2")]
     modifier = SmoothQuantModifier(mappings=mappings)
 
     modifier.ignore = []
     modifier.resolved_mappings_ = modifier._resolve_mappings(state.model)
 
-    # Should match seq.fc2 and block1.fc2 (both end with fc2)
-    assert len(modifier.resolved_mappings_) == 2
+    assert len(modifier.resolved_mappings_) == len(mappings)
 
-    # Verify seq.fc2 mapping - should find only seq.fc1 (direct child)
-    seq_mapping = [
-        m for m in modifier.resolved_mappings_ if m.smooth_name == "seq.fc2"
-    ][0]
-    assert isinstance(seq_mapping.smooth_layer, Linear)
-    assert len(seq_mapping.balance_layers) == 1
-    assert isinstance(seq_mapping.balance_layers[0], Linear)
+    mapping = modifier.resolved_mappings_[0]
+    assert mapping.smooth_name == mappings[0][1]
+    assert isinstance(mapping.smooth_layer, Linear)
+    assert isinstance(mapping.balance_layers[0], Linear)