From 4e9f7cebb23a5c7c91bb90729c24cdbad25702cf Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 20:57:13 +0000 Subject: [PATCH 01/12] [AWQ] small refactor to use match_modules_set Summary: modified _set_resolved_mappings to get smoothing and balance layers at same time. Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 93 ++++++++++++++----------- 1 file changed, 52 insertions(+), 41 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 98e53b4e0..3396b5cde 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -7,6 +7,7 @@ from compressed_tensors.utils import ( align_modules, get_execution_device, + match_modules_set, match_named_modules, update_offload_parameter, ) @@ -312,19 +313,22 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - For each activation in the mapping list, we find the corresponding weight to - balance by searching for the longest substring. For instance, if our balance - weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we - would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and - repeat for model.layer.1 and so on + Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) + that belong together in the model architecture. """ + # Build a module-to-name mapping for efficient lookups + module_to_name = {module: name for name, module in model.named_modules()} + resolved_mappings: list[ResolvedMapping] = [] for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 - for smooth_name, smooth_layer in ( + # Use match_modules_set to find coherent sets of modules + target_patterns = (mapping.smooth_layer, *mapping.balance_layers) + + for modules_set in ( pbar := tqdm( - match_named_modules(model, [mapping.smooth_layer], self.ignore) + match_modules_set(model, target_patterns, self.ignore) ) ): pbar.set_description( @@ -332,48 +336,55 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) + # Unpack the matched set: first is smooth_layer, rest are balance_layers + smooth_layer = modules_set[0] + all_balance_layers = list(modules_set[1:]) - balance_layers, balance_names = [], [] - for balance_regex in mapping.balance_layers: - # find the submodules that match the activation layer - for balance_suffix, balance_layer in match_named_modules( - smooth_parent, [balance_regex], self.ignore - ): - balance_name = f"{smooth_parent_name}.{balance_suffix}" - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) + # Get names using the pre-built mapping + smooth_name = module_to_name.get(smooth_layer) + if smooth_name is None: + continue + + # Filter balance layers, skipping incompatible ones + balance_layers = [] + balance_names = [] + + for 
balance_layer in all_balance_layers: + balance_name = module_to_name.get(balance_layer) + if balance_name is None: + continue + + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features + != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and smooth_layer.out_features + != 3 * balance_layer.in_features ) - ): - num_skipped_mappings += 1 - continue + ) + ): + num_skipped_mappings += 1 + continue - balance_layers.append(balance_layer) - balance_names.append(balance_name) + balance_layers.append(balance_layer) + balance_names.append(balance_name) if len(balance_layers) == 0: continue - elif len(balance_layers) == 1: + if len(balance_layers) == 1: # for single balance layer, parent is the balance layer - parent_name, parent = balance_name, balance_layer + parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent parent_name, parent = get_lowest_common_parent(balance_names, model) From 39c1da066f0bf7a36e0abebaefa993727aca727b Mon Sep 17 00:00:00 2001 From: HDCharles Date: Tue, 25 Nov 2025 21:07:36 +0000 Subject: [PATCH 02/12] formatting Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 3396b5cde..e52ae54b6 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -327,9 +327,7 @@ def _set_resolved_mappings(self, model: Module) -> None: target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for modules_set in ( - pbar := tqdm( - match_modules_set(model, target_patterns, self.ignore) - ) + pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" From cd2798ca7537fa068604a2ef83ad2e991433b1cc Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 26 Nov 2025 18:09:44 +0000 Subject: [PATCH 03/12] fixing logic and test update Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 96 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 15 ++- 2 files changed, 61 insertions(+), 50 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e52ae54b6..b5643d240 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -313,20 +313,21 @@ def _set_resolved_mappings(self, model: Module) -> None: into ResolvedMapping objects, resolving regular expressions. Result is stored in _resolved_mappings. - Uses match_modules_set to find coherent sets of (smooth_layer, *balance_layers) - that belong together in the model architecture. + For each activation in the mapping list, we find the corresponding weight to + balance by searching for the longest substring. 
For instance, if our balance + weight is ".*re:.*q_proj" and the activation is "re:.*self_attn_layer_norm" we + would match model.layer.0.p_proj to model.layer.0.self_attn_layer_norm and + repeat for model.layer.1 and so on """ - # Build a module-to-name mapping for efficient lookups - module_to_name = {module: name for name, module in model.named_modules()} - resolved_mappings: list[ResolvedMapping] = [] + module_to_name = {module: name for name, module in model.named_modules()} for mapping_idx, mapping in enumerate(self.mappings): num_skipped_mappings = 0 # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for modules_set in ( + for smooth_layer, *balance_layers in ( pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) ): pbar.set_description( @@ -334,53 +335,21 @@ def _set_resolved_mappings(self, model: Module) -> None: f" ({num_skipped_mappings} skipped)" ) - # Unpack the matched set: first is smooth_layer, rest are balance_layers - smooth_layer = modules_set[0] - all_balance_layers = list(modules_set[1:]) - - # Get names using the pre-built mapping smooth_name = module_to_name.get(smooth_layer) - if smooth_name is None: - continue + balance_names = [ + module_to_name.get(balance_layer) + for balance_layer in balance_layers + ] - # Filter balance layers, skipping incompatible ones - balance_layers = [] - balance_names = [] - - for balance_layer in all_balance_layers: - balance_name = module_to_name.get(balance_layer) - if balance_name is None: - continue - - # exclude v_proj->o_proj mappings whose shapes are incompatible - # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 - if ( - isinstance(smooth_layer, torch.nn.Linear) - and isinstance(balance_layer, torch.nn.Linear) - and balance_name.endswith(".o_proj") - and ( - ( - smooth_name.endswith(".v_proj") - and smooth_layer.out_features - != balance_layer.in_features - ) - or ( - smooth_name.endswith(".qkv_proj") - and smooth_layer.out_features - != 3 * balance_layer.in_features - ) - ) - ): - num_skipped_mappings += 1 - continue - - balance_layers.append(balance_layer) - balance_names.append(balance_name) + all_compatible = _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names + ) - if len(balance_layers) == 0: + # skip mapping if any of the balance layers are incompatible + if not all_compatible or len(balance_layers) == 0: + num_skipped_mappings += 1 continue - - if len(balance_layers) == 1: + elif len(balance_layers) == 1: # for single balance layer, parent is the balance layer parent_name, parent = balance_names[0], balance_layers[0] else: @@ -730,6 +699,35 @@ def _assert_all_activations_consumed(self): raise RuntimeError("Some cached activations were not used") +def _check_layers_are_compatible( + smooth_layer, smooth_name, balance_layers, balance_names +): + """ + returns True if they are all compatible + returns False if any smooth & balance layers are incompatible + """ + for balance_layer, balance_name in zip(balance_layers, balance_names): + # exclude v_proj->o_proj mappings whose shapes are incompatible + # https://github.com/mit-han-lab/llm-awq/pull/67#issuecomment-1681632777 + if ( + isinstance(smooth_layer, torch.nn.Linear) + and isinstance(balance_layer, torch.nn.Linear) + and balance_name.endswith(".o_proj") + and ( + ( + smooth_name.endswith(".v_proj") + and smooth_layer.out_features != balance_layer.in_features + ) + or ( + smooth_name.endswith(".qkv_proj") + and 
smooth_layer.out_features != 3 * balance_layer.in_features
+                )
+            )
+        ):
+            return False
+    return True
+
+
 def _pseudo_quantize_tensor(
     w: torch.Tensor, symmetric: bool = False, bit_width: int = 8, group_size: int = -1
 ):
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 950ab0f51..83eaaa970 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -85,10 +85,12 @@ def test_set_resolved_mappings():
         assert set(mapping.balance_names) == {"decoder.mlp.down_proj"}
         assert mapping.parent_name == "decoder.mlp.down_proj"
 
-    # make sure we exclude case where o_proj/v_proj shapes are mismatched
     awq = AWQModifier(
         mappings=[
+            # make sure we exclude case where o_proj/v_proj shapes are mismatched
             AWQMapping("re:.*v_proj", ["re:.*o_proj"]),
+            # make sure we exclude mapping if any balance layers are skipped
+            AWQMapping("re:.*v_proj", ["re:.*z_proj", "re:.*o_proj"]),
         ],
         scheme="W4A16_ASYM",
     )
@@ -101,6 +103,7 @@ def test_set_resolved_mappings():
                 "q_proj": torch.nn.Linear(4, 2),
                 "k_proj": torch.nn.Linear(4, 2),
                 "v_proj": torch.nn.Linear(4, 2),
+                "z_proj": torch.nn.Linear(2, 4),
                 "o_proj": torch.nn.Linear(4, 4),
             }
         )
@@ -109,6 +112,16 @@ def test_set_resolved_mappings():
         }
     )
     awq._set_resolved_mappings(model)
+    if len(awq._resolved_mappings) > 0:
+        assert all(
+            "o_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), "should have skipped v->o mapping because o is incompatible"
+        assert all(
+            "z_proj" not in name for name in awq._resolved_mappings[0].balance_names
+        ), (
+            "should have skipped v->[z,o] mapping because o is incompatible even though "
+            "z is compatible"
+        )
     assert len(awq._resolved_mappings) == 0
 
 
From 4553c588ee1b72b8249193265e69f8e658ad8e48 Mon Sep 17 00:00:00 2001
From: HDCharles 
Date: Thu, 27 Nov 2025 03:08:53 +0000
Subject: [PATCH 04/12] updates to get_lowest_common_x

Summary

Signed-off-by: HDCharles 
---
 src/llmcompressor/modifiers/awq/base.py       | 67 ++++++++++++-------
 .../llmcompressor/modifiers/awq/test_base.py  | 65 +++++++++++-------
 2 files changed, 84 insertions(+), 48 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index b5643d240..8ae2c41e9 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -320,21 +320,26 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-        module_to_name = {module: name for name, module in model.named_modules()}
-        for mapping_idx, mapping in enumerate(self.mappings):
-            num_skipped_mappings = 0
+
+        module_to_name = {}
+        for name, module in model.named_modules():
+            if module in module_to_name:
+                logger.info(
+                    f"Warning, {name} and {module_to_name[module]} both "
+                    "share the same module the same module, "
+                    "may have trouble resolving mappings." 
+ ) + module_to_name[module] = name + + + + for mapping in self.mappings: - # Use match_modules_set to find coherent sets of modules target_patterns = (mapping.smooth_layer, *mapping.balance_layers) for smooth_layer, *balance_layers in ( - pbar := tqdm(match_modules_set(model, target_patterns, self.ignore)) + match_modules_set(model, target_patterns, self.ignore) ): - pbar.set_description( - f"Resolving mapping {mapping_idx+1}/{len(self.mappings)}" - f" ({num_skipped_mappings} skipped)" - ) - smooth_name = module_to_name.get(smooth_layer) balance_names = [ module_to_name.get(balance_layer) @@ -347,14 +352,18 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: - num_skipped_mappings += 1 + logger.info( + f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( + " because found incompatible balance layers" + if not all_compatible else + f" because no balance layers were found" + ) + ) + continue - elif len(balance_layers) == 1: - # for single balance layer, parent is the balance layer - parent_name, parent = balance_names[0], balance_layers[0] else: # for multiple balance layers, find lowest common parent - parent_name, parent = get_lowest_common_parent(balance_names, model) + parent_name, parent = get_lowest_common_module(balance_names, model) resolved_mappings.append( ResolvedMapping( @@ -788,29 +797,41 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_parent(names: list[str], module: Module) -> tuple[str, Module]: +def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]: """ - Given a list of names, returns the lowest-scope common parent. + Given a list of names, returns the lowest-scope common module. - NOTE: function excludes parents of type ModuleList, which don't play + NOTE: function excludes modules of type ModuleList, which don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are selected based on router output and their forward method is called. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 - Returns name of parent and pointer to parent module + Returns name of module and pointer to module Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - s1 = min(names) - s2 = max(names) - parent_name = "" + # adding "." before and after allows for handling a lot of corner + # cases which were previously mishandled ([case]->prefix->result) + # case 0: single module: [.abc.] -> .abc. -> abc + # case 1: substring modules: [.abc., .ab.] -> .ab -> "" + # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab + s1 = min(names) + "." + s2 = max(names) + "." + + # 1) find longest shared prefix + parent_name = "." 
for i, c in enumerate(s1): if c != s2[i]: - parent_name = s1[:i].rstrip(".") break + parent_name += c + + # 2) throw away module name fragment and leading dot + # ".keep.thro" -> "keep" + parent_name = parent_name[1:parent_name.rfind(".")] + # 3) return first parent that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 83eaaa970..2cb78fb65 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -2,9 +2,9 @@ import torch from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError - +from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_common_parent +from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -40,16 +40,16 @@ def test_set_resolved_mappings(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 4), - "k_proj": torch.nn.Linear(4, 4), - "v_proj": torch.nn.Linear(4, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 4), + "k_proj": Linear(4, 4), + "v_proj": Linear(4, 4), + "o_proj": Linear(4, 4), } ) mlp = torch.nn.ModuleDict( { - "up_proj": torch.nn.Linear(4, 10), - "down_proj": torch.nn.Linear(10, 4), + "up_proj": Linear(4, 10), + "down_proj": Linear(10, 4), } ) model = torch.nn.ModuleDict( @@ -100,11 +100,11 @@ def test_set_resolved_mappings(): { "self_attn": torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "z_proj": torch.nn.Linear(2, 4), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "z_proj": Linear(2, 4), + "o_proj": Linear(4, 4), } ) } @@ -192,15 +192,15 @@ def test_validate(): @pytest.mark.unit -def test_get_lowest_common_parent(): +def test_get_lowest_common_module(): mlp = torch.nn.ModuleDict( { "experts": torch.nn.ModuleList( [ torch.nn.ModuleDict( { - "gate_proj": torch.nn.Linear(4, 2), - "down_proj": torch.nn.Linear(4, 2), + "gate_proj": Linear(4, 2), + "down_proj": Linear(4, 2), } ) for _ in range(10) @@ -210,15 +210,15 @@ def test_get_lowest_common_parent(): ) self_attn = torch.nn.ModuleDict( { - "q_proj": torch.nn.Linear(4, 2), - "k_proj": torch.nn.Linear(4, 2), - "v_proj": torch.nn.Linear(4, 2), - "o_proj": torch.nn.Linear(4, 4), + "q_proj": Linear(4, 2), + "k_proj": Linear(4, 2), + "v_proj": Linear(4, 2), + "o_proj": Linear(4, 4), } ) model = torch.nn.ModuleDict( { - "embed_tokens": torch.nn.Linear(4, 2), + "embed_tokens": Linear(4, 2), "decoder": torch.nn.ModuleDict( { "self_attn": self_attn, @@ -228,22 +228,37 @@ def test_get_lowest_common_parent(): } ) - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model ) assert parent_name == "decoder.mlp" and parent == mlp - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model ) assert parent_name == "decoder.self_attn" and parent == self_attn - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["decoder.mlp.experts.1.gate_proj", 
"decoder.self_attn.v_proj"], model ) assert parent_name == "decoder" and parent == model["decoder"] - parent_name, parent = get_lowest_common_parent( + parent_name, parent = get_lowest_common_module( ["embed_tokens", "decoder.self_attn.v_proj"], model ) assert parent_name == "" and parent == model + + m = torch.nn.ModuleDict( + { + "abc": Linear(3,3), + "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), + "z": Linear(3,3) + } + ) + parent_name, parent = get_lowest_common_module(["abc", "ab"], m) + assert parent_name == "" + parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m) + assert parent_name == "ab" + parent_name, parent = get_lowest_common_module(["z"], m) + assert parent_name == "z" + From d6dc2898c6a1e47a3f280c5f6ef45640a4b7219a Mon Sep 17 00:00:00 2001 From: HDCharles Date: Thu, 27 Nov 2025 03:26:33 +0000 Subject: [PATCH 05/12] format Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 24 +++++++++---------- .../llmcompressor/modifiers/awq/test_base.py | 8 +++---- 2 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 8ae2c41e9..420dadf2b 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -320,7 +320,7 @@ def _set_resolved_mappings(self, model: Module) -> None: repeat for model.layer.1 and so on """ resolved_mappings: list[ResolvedMapping] = [] - + module_to_name = {} for name, module in model.named_modules(): if module in module_to_name: @@ -331,14 +331,11 @@ def _set_resolved_mappings(self, model: Module) -> None: ) module_to_name[module] = name - - for mapping in self.mappings: - target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - for smooth_layer, *balance_layers in ( - match_modules_set(model, target_patterns, self.ignore) + for smooth_layer, *balance_layers in match_modules_set( + model, target_patterns, self.ignore ): smooth_name = module_to_name.get(smooth_layer) balance_names = [ @@ -353,10 +350,11 @@ def _set_resolved_mappings(self, model: Module) -> None: # skip mapping if any of the balance layers are incompatible if not all_compatible or len(balance_layers) == 0: logger.info( - f"skipping AWQ for {smooth_name} for mapping {mapping}" + ( - " because found incompatible balance layers" - if not all_compatible else - f" because no balance layers were found" + f"skipping AWQ for {smooth_name} for mapping {mapping}" + + ( + " because found incompatible balance layers" + if not all_compatible + else " because no balance layers were found" ) ) @@ -812,7 +810,7 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod Implementation is a small alteration of os.path.commonprefix https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - # adding "." before and after allows for handling a lot of corner + # adding "." before and after allows for handling a lot of corner # cases which were previously mishandled ([case]->prefix->result) # case 0: single module: [.abc.] -> .abc. -> abc # case 1: substring modules: [.abc., .ab.] 
-> .ab -> "" @@ -829,9 +827,9 @@ def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Mod # 2) throw away module name fragment and leading dot # ".keep.thro" -> "keep" - parent_name = parent_name[1:parent_name.rfind(".")] + parent_name = parent_name[1 : parent_name.rfind(".")] - # 3) return first parent that is not a module list + # 3) return first common module that is not a module list while True: if parent_name == "": return "", module diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 2cb78fb65..e8103f9e3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -3,6 +3,7 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationScheme from pydantic import ValidationError from torch.nn import Linear + from llmcompressor.modifiers.awq import AWQMapping, AWQModifier from llmcompressor.modifiers.awq.base import get_lowest_common_module from llmcompressor.modifiers.factory import ModifierFactory @@ -250,9 +251,9 @@ def test_get_lowest_common_module(): m = torch.nn.ModuleDict( { - "abc": Linear(3,3), - "ab": torch.nn.ModuleDict({"a": Linear(3,3)}), - "z": Linear(3,3) + "abc": Linear(3, 3), + "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}), + "z": Linear(3, 3), } ) parent_name, parent = get_lowest_common_module(["abc", "ab"], m) @@ -261,4 +262,3 @@ def test_get_lowest_common_module(): assert parent_name == "ab" parent_name, parent = get_lowest_common_module(["z"], m) assert parent_name == "z" - From 0f265f9b7b6eca6bab3295f929be08b1ce873bd4 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 16:04:41 +0000 Subject: [PATCH 06/12] fixes Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 68 +++++++-------- .../modifiers/transform/spinquant/base.py | 7 +- .../llmcompressor/modifiers/awq/test_base.py | 86 ++++++------------- 3 files changed, 62 insertions(+), 99 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 420dadf2b..c963a1149 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -10,12 +10,13 @@ match_modules_set, match_named_modules, update_offload_parameter, + get_lowest_common_ancestor_name, ) from loguru import logger from pydantic import ConfigDict, PrivateAttr, model_validator from torch.nn import Module from tqdm import tqdm - +from torch.utils._pytree import tree_flatten from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.awq.mappings import ( @@ -332,12 +333,20 @@ def _set_resolved_mappings(self, model: Module) -> None: module_to_name[module] = name for mapping in self.mappings: - target_patterns = (mapping.smooth_layer, *mapping.balance_layers) - - for smooth_layer, *balance_layers in match_modules_set( - model, target_patterns, self.ignore + for smooth_layers, *nested_balance_layers in match_modules_set( + model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore ): + assert len(smooth_layers)==1, ( + "AWQ mappings need to match a single smoothlayer for each mapping but got " + f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} " + f"when matching {mapping.smooth_layer}" + ) + smooth_layer = smooth_layers[0] smooth_name = module_to_name.get(smooth_layer) + + #[[b00, b01, b02...], [b10, b11, b12,...], ...] v + # [b00, b01, b02, ..., b10, b11, b12, ...] 
+ balance_layers = tree_flatten(nested_balance_layers)[0] balance_names = [ module_to_name.get(balance_layer) for balance_layer in balance_layers @@ -361,7 +370,8 @@ def _set_resolved_mappings(self, model: Module) -> None: continue else: # for multiple balance layers, find lowest common parent - parent_name, parent = get_lowest_common_module(balance_names, model) + ancestor_name = get_lowest_common_ancestor_name(balance_names) + ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, ) resolved_mappings.append( ResolvedMapping( @@ -369,8 +379,8 @@ def _set_resolved_mappings(self, model: Module) -> None: smooth_layer, balance_layers, balance_names=balance_names, - parent=parent, - parent_name=parent_name, + parent=ancestor, + parent_name=ancestor_name, ) ) self._resolved_mappings = resolved_mappings @@ -795,45 +805,25 @@ def _accumulate_mean( return (prev_sum + sum_added) / new_count, new_count -def get_lowest_common_module(names: list[str], module: Module) -> tuple[str, Module]: +def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a list of names, returns the lowest-scope common module. + Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList + i.e. module_list.module_dict.module_list -> module_list.module_dict + i.e. module_list.module_dict -> module_list.module_dict + (self is an ancestor of self) - NOTE: function excludes modules of type ModuleList, which don't play + NOTE: This is needed because ModuleLists don't play nicely with hooks because their forward method is never directly called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts are selected based on router output and their forward method is called. https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233 Returns name of module and pointer to module - - Implementation is a small alteration of os.path.commonprefix - https://docs.python.org/3/library/os.path.html#os.path.commonprefix """ - # adding "." before and after allows for handling a lot of corner - # cases which were previously mishandled ([case]->prefix->result) - # case 0: single module: [.abc.] -> .abc. -> abc - # case 1: substring modules: [.abc., .ab.] -> .ab -> "" - # case 2: parent & child: [.ab., .ab.a.] -> .ab. -> ab - s1 = min(names) + "." - s2 = max(names) + "." - - # 1) find longest shared prefix - parent_name = "." 
- for i, c in enumerate(s1): - if c != s2[i]: - break - parent_name += c - - # 2) throw away module name fragment and leading dot - # ".keep.thro" -> "keep" - parent_name = parent_name[1 : parent_name.rfind(".")] - - # 3) return first common module that is not a module list while True: - if parent_name == "": + if name == "": return "", module - parent = get_layer_by_name(parent_name, module) - if not isinstance(parent, torch.nn.ModuleList): - return parent_name, parent - parent_name = ".".join(parent_name.split(".")[:-1]) + module = get_layer_by_name(name, module) + if not isinstance(module, torch.nn.ModuleList): + return name, module + name = ".".join(parent_name.split(".")[:-1]) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index aac21d29c..740fe3922 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -9,6 +9,7 @@ TransformScheme, apply_transform_config, ) +from torch.utils._pytree import tree_flatten from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator from transformers import PreTrainedModel @@ -203,8 +204,10 @@ def _fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: for norm, *linears in match_modules_set( model, (mapping.norm, *mapping.linears) - ): - fuse_norm_linears(norm, linears) + ): + # match_modules_set returns a list of lists + assert len(norm) == 1 + fuse_norm_linears(norm[0], tree_flatten(linears)[0]) def _create_r1_scheme(self) -> TransformScheme: return TransformScheme( diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index e8103f9e3..119dfc11a 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -5,7 +5,7 @@ from torch.nn import Linear from llmcompressor.modifiers.awq import AWQMapping, AWQModifier -from llmcompressor.modifiers.awq.base import get_lowest_common_module +from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor from llmcompressor.modifiers.factory import ModifierFactory @@ -47,11 +47,18 @@ def test_set_resolved_mappings(): "o_proj": Linear(4, 4), } ) - mlp = torch.nn.ModuleDict( - { - "up_proj": Linear(4, 10), - "down_proj": Linear(10, 4), - } + mlp = torch.nn.ModuleList( + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 2), + "down_proj": Linear(4, 2), + } + ) + for _ in range(3) + ] + ) ) model = torch.nn.ModuleDict( { @@ -83,8 +90,8 @@ def test_set_resolved_mappings(): assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"} assert mapping.parent_name == "decoder.self_attn.o_proj" if "mlp.up_proj" in mapping.smooth_name: - assert set(mapping.balance_names) == {"decoder.mlp.down_proj"} - assert mapping.parent_name == "decoder.mlp.down_proj" + assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"} + assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO awq = AWQModifier( mappings=[ @@ -193,15 +200,15 @@ def test_validate(): @pytest.mark.unit -def test_get_lowest_common_module(): - mlp = torch.nn.ModuleDict( +def test_get_lowest_non_module_list_ancestor(): + model = torch.nn.ModuleDict( { "experts": torch.nn.ModuleList( [ torch.nn.ModuleDict( { "gate_proj": Linear(4, 2), - "down_proj": Linear(4, 2), + "down_proj": Linear(2, 4), } ) for _ 
in range(10) @@ -209,56 +216,19 @@ def test_get_lowest_common_module(): ) } ) - self_attn = torch.nn.ModuleDict( - { - "q_proj": Linear(4, 2), - "k_proj": Linear(4, 2), - "v_proj": Linear(4, 2), - "o_proj": Linear(4, 4), - } + + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + "", model ) - model = torch.nn.ModuleDict( - { - "embed_tokens": Linear(4, 2), - "decoder": torch.nn.ModuleDict( - { - "self_attn": self_attn, - "mlp": mlp, - } - ), - } - ) - - parent_name, parent = get_lowest_common_module( - ["decoder.mlp.experts.1.gate_proj", "decoder.mlp.experts.4.down_proj"], model - ) - assert parent_name == "decoder.mlp" and parent == mlp - - parent_name, parent = get_lowest_common_module( - ["decoder.self_attn.q_proj", "decoder.self_attn.v_proj"], model - ) - assert parent_name == "decoder.self_attn" and parent == self_attn + assert ancestor_name == "" and ancestor == model - parent_name, parent = get_lowest_common_module( - ["decoder.mlp.experts.1.gate_proj", "decoder.self_attn.v_proj"], model + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ["experts"], model ) - assert parent_name == "decoder" and parent == model["decoder"] + assert ancestor_name == "" and ancestor == model - parent_name, parent = get_lowest_common_module( - ["embed_tokens", "decoder.self_attn.v_proj"], model + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + "experts.1.gate_proj", model ) - assert parent_name == "" and parent == model + assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"] - m = torch.nn.ModuleDict( - { - "abc": Linear(3, 3), - "ab": torch.nn.ModuleDict({"a": Linear(3, 3)}), - "z": Linear(3, 3), - } - ) - parent_name, parent = get_lowest_common_module(["abc", "ab"], m) - assert parent_name == "" - parent_name, parent = get_lowest_common_module(["ab", "ab.a"], m) - assert parent_name == "ab" - parent_name, parent = get_lowest_common_module(["z"], m) - assert parent_name == "z" From 2b138a774ec79fb760a73549413f08533df4fea9 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 18:44:36 +0000 Subject: [PATCH 07/12] tests Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 13 +-- .../llmcompressor/modifiers/awq/test_base.py | 99 ++++++++++++++++--- 2 files changed, 90 insertions(+), 22 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index c963a1149..5d83a2420 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -371,7 +371,7 @@ def _set_resolved_mappings(self, model: Module) -> None: else: # for multiple balance layers, find lowest common parent ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor, ancestor_name = get_lowest_non_module_list_ancestor(ancestor_name, ) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model) resolved_mappings.append( ResolvedMapping( @@ -807,7 +807,8 @@ def _accumulate_mean( def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a name: foo.bar.baz, finds lowest ancestor that's not a ModuleList + Given a name and a model, finds lowest ancestor of + named module that's not a ModuleList i.e. module_list.module_dict.module_list -> module_list.module_dict i.e. 
module_list.module_dict -> module_list.module_dict (self is an ancestor of self) @@ -823,7 +824,7 @@ def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Modu while True: if name == "": return "", module - module = get_layer_by_name(name, module) - if not isinstance(module, torch.nn.ModuleList): - return name, module - name = ".".join(parent_name.split(".")[:-1]) + current_module = get_layer_by_name(name, module) + if not isinstance(current_module, torch.nn.ModuleList): + return name, current_module + name = ".".join(name.split(".")[:-1]) diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py index 119dfc11a..b4161e7b3 100644 --- a/tests/llmcompressor/modifiers/awq/test_base.py +++ b/tests/llmcompressor/modifiers/awq/test_base.py @@ -47,18 +47,21 @@ def test_set_resolved_mappings(): "o_proj": Linear(4, 4), } ) - mlp = torch.nn.ModuleList( - "experts": torch.nn.ModuleList( - [ - torch.nn.ModuleDict( - { - "gate_proj": Linear(4, 2), - "down_proj": Linear(4, 2), - } - ) - for _ in range(3) - ] - ) + mlp = torch.nn.ModuleDict( + { + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 2), + "up_proj": Linear(4, 2), + "down_proj": Linear(2, 4), + } + ) + for _ in range(3) + ] + ) + } ) model = torch.nn.ModuleDict( { @@ -89,9 +92,12 @@ def test_set_resolved_mappings(): if "self_attn.v_proj" in mapping.smooth_name: assert set(mapping.balance_names) == {"decoder.self_attn.o_proj"} assert mapping.parent_name == "decoder.self_attn.o_proj" - if "mlp.up_proj" in mapping.smooth_name: - assert set(mapping.balance_names) == {"decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj", "decoder.mlp.0.down_proj"} - assert mapping.parent_name == "decoder.mlp.down_proj" # TODODODO + if "mlp.experts" in mapping.smooth_name and "up_proj" in mapping.smooth_name: + expert_idx = mapping.smooth_name.split(".")[-2] + expected_down_proj = f"decoder.mlp.experts.{expert_idx}.down_proj" + assert set(mapping.balance_names) == {expected_down_proj} + assert mapping.parent_name == expected_down_proj + assert mapping.parent == mlp["experts"][int(expert_idx)]["down_proj"] awq = AWQModifier( mappings=[ @@ -223,7 +229,7 @@ def test_get_lowest_non_module_list_ancestor(): assert ancestor_name == "" and ancestor == model ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ["experts"], model + "experts", model ) assert ancestor_name == "" and ancestor == model @@ -232,3 +238,64 @@ def test_get_lowest_non_module_list_ancestor(): ) assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"] + +@pytest.mark.unit +def test_moe_multiple_balance_layers(): + """Test AWQ mapping with multiple balance layers in MoE architecture""" + awq = AWQModifier( + mappings=[ + # Map input_layernorm to multiple experts' gate_proj and up_proj + AWQMapping( + "re:.*input_layernorm", + ["re:.*gate_proj", "re:.*up_proj"], + ), + ], + scheme="W4A16_ASYM", + ) + + # Create a simplified MoE model structure + mlp = torch.nn.ModuleDict( + { + "experts": torch.nn.ModuleList( + [ + torch.nn.ModuleDict( + { + "gate_proj": Linear(4, 4), + "up_proj": Linear(4, 4), + "down_proj": Linear(4, 4), + } + ) + for _ in range(2) + ] + ) + } + ) + model = torch.nn.ModuleDict( + { + "layer": torch.nn.ModuleDict( + { + "input_layernorm": torch.nn.LayerNorm(4), + "mlp": mlp, + } + ) + } + ) + + awq._set_resolved_mappings(model) + + # Should have one mapping for input_layernorm + assert len(awq._resolved_mappings) == 1 + mapping 
= awq._resolved_mappings[0]
+
+    # Should map to all gate_proj and up_proj across all experts
+    expected_balance_names = {
+        "layer.mlp.experts.0.gate_proj",
+        "layer.mlp.experts.0.up_proj",
+        "layer.mlp.experts.1.gate_proj",
+        "layer.mlp.experts.1.up_proj",
+    }
+    assert set(mapping.balance_names) == expected_balance_names
+
+    assert mapping.parent_name == "layer.mlp"
+    assert mapping.parent == mlp
+

From 07c657e2f9bfdc794fd29429b90d04d32096cec2 Mon Sep 17 00:00:00 2001
From: HDCharles 
Date: Wed, 3 Dec 2025 20:27:15 +0000
Subject: [PATCH 08/12] fixes and formatting

Summary
fix smoothquant logic to align with AWQ

Signed-off-by: HDCharles 
---
 src/llmcompressor/modifiers/awq/base.py      | 38 +++++++---------
 .../modifiers/smoothquant/base.py            | 44 +++++++++----------
 .../modifiers/smoothquant/utils.py           |  4 +-
 .../modifiers/transform/spinquant/base.py    |  4 +-
 src/llmcompressor/utils/pytorch/module.py    | 14 ++++++
 .../llmcompressor/modifiers/awq/test_base.py | 16 +++----
 6 files changed, 62 insertions(+), 58 deletions(-)

diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py
index 5d83a2420..e136bef0c 100644
--- a/src/llmcompressor/modifiers/awq/base.py
+++ b/src/llmcompressor/modifiers/awq/base.py
@@ -7,16 +7,17 @@
 from compressed_tensors.utils import (
     align_modules,
     get_execution_device,
+    get_lowest_common_ancestor_name,
     match_modules_set,
     match_named_modules,
     update_offload_parameter,
-    get_lowest_common_ancestor_name,
 )
 from loguru import logger
 from pydantic import ConfigDict, PrivateAttr, model_validator
 from torch.nn import Module
-from tqdm import tqdm
 from torch.utils._pytree import tree_flatten
+from tqdm import tqdm
+
 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers import Modifier
 from llmcompressor.modifiers.awq.mappings import (
@@ -30,7 +31,10 @@
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
-from llmcompressor.utils.pytorch.module import get_layer_by_name
+from llmcompressor.utils.pytorch.module import (
+    get_layer_by_name,
+    get_module_to_name_dict,
+)
 
 __all__ = ["AWQModifier"]
 
@@ -321,30 +325,20 @@ def _set_resolved_mappings(self, model: Module) -> None:
         repeat for model.layer.1 and so on
         """
         resolved_mappings: list[ResolvedMapping] = []
-
-        module_to_name = {}
-        for name, module in model.named_modules():
-            if module in module_to_name:
-                logger.info(
-                    f"Warning, {name} and {module_to_name[module]} both "
-                    "share the same module the same module, "
-                    "may have trouble resolving mappings."
-                )
-            module_to_name[module] = name
-
+        module_to_name = get_module_to_name_dict(model)
         for mapping in self.mappings:
-            for smooth_layers, *nested_balance_layers in match_modules_set(
-                model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
-            ):
-                assert len(smooth_layers)==1, (
-                    "AWQ mappings need to match a single smoothlayer for each mapping but got "
-                    f"{[module_to_name.get(smooth_layer) for smooth_layer in smooth_layers]} "
-                    f"when matching {mapping.smooth_layer}"
-                )
-                smooth_layer = smooth_layers[0]
+            for smooth_layers, *nested_balance_layers in match_modules_set(
+                model, (mapping.smooth_layer, *mapping.balance_layers), self.ignore
+            ):
+                assert len(smooth_layers) == 1, (
+                    "AWQ mappings need to match a single smooth layer for each "
+                    f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}"
+                    f" for mapping: {mapping}"
+                )
+                smooth_layer = smooth_layers[0]
                 smooth_name = module_to_name.get(smooth_layer)
 
-                # [[b00, b01, b02...], [b10, b11, b12,...], ...] 
v # [b00, b01, b02, ..., b10, b11, b12, ...] balance_layers = tree_flatten(nested_balance_layers)[0] balance_names = [ @@ -371,7 +365,9 @@ def _set_resolved_mappings(self, model: Module) -> None: else: # for multiple balance layers, find lowest common parent ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor(ancestor_name, model) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ancestor_name, model + ) resolved_mappings.append( ResolvedMapping( @@ -807,7 +803,7 @@ def _accumulate_mean( def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]: """ - Given a name and a model, finds lowest ancestor of + Given a name and a model, finds lowest ancestor of named module that's not a ModuleList i.e. module_list.module_dict.module_list -> module_list.module_dict i.e. module_list.module_dict -> module_list.module_dict diff --git a/src/llmcompressor/modifiers/smoothquant/base.py b/src/llmcompressor/modifiers/smoothquant/base.py index dcefa2fa4..ea5939c57 100644 --- a/src/llmcompressor/modifiers/smoothquant/base.py +++ b/src/llmcompressor/modifiers/smoothquant/base.py @@ -1,11 +1,12 @@ from dataclasses import dataclass -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple import torch -from compressed_tensors.utils import align_module_device, match_named_modules +from compressed_tensors.utils import align_module_device, match_modules_set from loguru import logger from pydantic import ConfigDict, Field from torch.nn import Module +from torch.utils._pytree import tree_flatten from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier @@ -14,7 +15,7 @@ handle_mapping_resolution_errors, ) from llmcompressor.utils.fsdp.helpers import get_fsdp_parent -from llmcompressor.utils.pytorch.module import get_layer_by_name +from llmcompressor.utils.pytorch.module import get_module_to_name_dict MINIMUM_SMOOTHING_SCALE = 1e-5 @@ -95,7 +96,7 @@ class SmoothQuantModifier(Modifier): """ smoothing_strength: float = 0.5 - mappings: Optional[List[Union[Tuple, List]]] = None + mappings: Optional[List[Tuple[List[str], str]]] = None ignore: Optional[List[str]] = None num_calibration_steps: Optional[int] = None calibration_function: Optional[Callable] = None @@ -198,27 +199,22 @@ def _resolve_mappings(self, model: Module) -> List[SmoothQuantMapping]: be balanced. 
""" resolved_mappings = [] - for to_balance, to_smooth in self.mappings: - to_smooth_list = [to_smooth] if isinstance(to_smooth, str) else to_smooth - - for smooth_name, smooth_layer in match_named_modules( - model, to_smooth_list, self.ignore + module_to_name = get_module_to_name_dict(model) + for mapping in self.mappings: + for *nested_balance_layers, smooth_layers in match_modules_set( + model, tree_flatten(mapping)[0], self.ignore ): - # Search for balance layers within the parent scope - smooth_parent_name = ".".join(smooth_name.split(".")[:-1]) - smooth_parent = get_layer_by_name(smooth_parent_name, model) - - balance_layers = [ - balance_layer - for _, balance_layer in match_named_modules( - smooth_parent, to_balance, self.ignore - ) - ] - - if balance_layers: - resolved_mappings.append( - SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) - ) + assert len(smooth_layers) == 1, ( + "SmoothQuant mappings must match a single smooth layer for each " + f"mapping but got {[module_to_name.get(s) for s in smooth_layers]}" + f" for mapping: {mapping}" + ) + smooth_layer = smooth_layers[0] + smooth_name = module_to_name.get(smooth_layers[0]) + balance_layers = tree_flatten(nested_balance_layers)[0] + resolved_mappings.append( + SmoothQuantMapping(smooth_name, smooth_layer, balance_layers) + ) return resolved_mappings diff --git a/src/llmcompressor/modifiers/smoothquant/utils.py b/src/llmcompressor/modifiers/smoothquant/utils.py index 8ab38f633..62f544dae 100644 --- a/src/llmcompressor/modifiers/smoothquant/utils.py +++ b/src/llmcompressor/modifiers/smoothquant/utils.py @@ -1,6 +1,6 @@ import functools from collections import namedtuple -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Tuple from loguru import logger @@ -10,7 +10,7 @@ "DEFAULT_SMOOTHQUANT_MAPPINGS", ] -LayerMapType = Tuple[Union[List[str], str], Union[List[str], str]] +LayerMapType = Tuple[List[str], str] LayerMap: LayerMapType = namedtuple("LayerMap", ["balance_layers", "smooth_layers"]) DEFAULT_SMOOTHQUANT_MAPPINGS: List[LayerMap] = [ diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 740fe3922..b035fe242 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -9,9 +9,9 @@ TransformScheme, apply_transform_config, ) -from torch.utils._pytree import tree_flatten from compressed_tensors.utils import TorchDtype, get_head_dim from pydantic import Field, ValidationInfo, field_validator +from torch.utils._pytree import tree_flatten from transformers import PreTrainedModel from llmcompressor.core import Event, EventType, State @@ -204,7 +204,7 @@ def _fuse_norms(self, model: PreTrainedModel): for mapping in self.norm_mappings: for norm, *linears in match_modules_set( model, (mapping.norm, *mapping.linears) - ): + ): # match_modules_set returns a list of lists assert len(norm) == 1 fuse_norm_linears(norm[0], tree_flatten(linears)[0]) diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index 6d2152fe4..fc4da3d04 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -10,6 +10,7 @@ import torch from compressed_tensors import InternalModule from compressed_tensors.quantization.utils import is_module_quantized +from loguru import logger from torch.nn import Linear, Module, Parameter from torch.nn.modules.conv import _ConvNd from transformers 
import PreTrainedModel
@@ -369,3 +370,16 @@ def get_layer_by_name(layer_name: str, module: Module) -> Module:
     if not layer_name:
         return module
     return attrgetter(layer_name)(module)
+
+
+def get_module_to_name_dict(model: Module) -> dict[Module, str]:
+    module_to_name = {}
+    for name, module in model.named_modules():
+        if module in module_to_name:
+            logger.info(
+                f"Warning, {name} and {module_to_name[module]} both "
+                "share the same module the same module, "
+                "may have trouble resolving mappings."
+            )
+        module_to_name[module] = name
+    return module_to_name
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index b4161e7b3..5ed6594ee 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -222,21 +222,20 @@ def test_get_lowest_non_module_list_ancestor():
         )
     }
     )
-    
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
-        "", model
-    )
+
+    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model)
     assert ancestor_name == "" and ancestor == model
 
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
-        "experts", model
-    )
+    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model)
     assert ancestor_name == "" and ancestor == model
 
     ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
         "experts.1.gate_proj", model
     )
-    assert ancestor_name == "experts.1.gate_proj" and ancestor == model["experts"][1]["gate_proj"]
+    assert (
+        ancestor_name == "experts.1.gate_proj"
+        and ancestor == model["experts"][1]["gate_proj"]
+    )
 
 
From 4a45eca03147dd237f7996c759b9d4737787786d Mon Sep 17 00:00:00 2001
From: HDCharles 
Date: Wed, 3 Dec 2025 20:55:30 +0000
Subject: [PATCH 09/12] fixing tests to old versions

Summary

Signed-off-by: HDCharles 
---
 .../logarithmic_equalization/test_pytorch.py  | 11 ++++-----
 .../modifiers/smoothquant/test_pytorch.py     | 24 +++++++++----------
 2 files changed, 16 insertions(+), 19 deletions(-)

diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
index f3e948469..3f111aa74 100644
--- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
+++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py
@@ -8,18 +8,15 @@
 
 @pytest.mark.unit
 def test_log_equalization_mapping(state):
-    # Use regex patterns with parent-scoped search
-    # Searches for balance layers within the parent of smooth layer
-    mappings = [(["re:^fc2$"], "re:.*block1\\.fc1$")]
+    mappings = [(["seq.fc2"], "seq.block1.fc1")]
     modifier = LogarithmicEqualizationModifier(mappings=mappings)
 
     modifier.ignore = []
     modifier.resolved_mappings_ = modifier._resolve_mappings(state.model)
 
-    assert len(modifier.resolved_mappings_) == 1
+    assert len(modifier.resolved_mappings_) == len(mappings)
 
     mapping = modifier.resolved_mappings_[0]
-    assert mapping.smooth_name == "seq.block1.fc1"
+    assert mapping.smooth_name == mappings[0][1]
     assert isinstance(mapping.smooth_layer, Linear)
-    assert len(mapping.balance_layers) == 1
-    assert isinstance(mapping.balance_layers[0], Linear)
+    assert isinstance(mapping.balance_layers[0], Linear)
\ No newline at end of file
diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py 
b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index aefa6b957..0abe22523 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -4,23 +4,23 @@ from llmcompressor.modifiers.smoothquant import SmoothQuantModifier +import pytest +from torch.nn import Linear + +from llmcompressor.modifiers.smoothquant import SmoothQuantModifier + + @pytest.mark.unit def test_smooth_quant_mapping(state): - # Use regex patterns with parent-scoped search - # ^fc1$ matches only direct child "fc1", not nested "block1.fc1" - mappings = [(["re:^fc1$"], "re:.*fc2$")] + mappings = [(["seq.fc1"], "seq.fc2")] modifier = SmoothQuantModifier(mappings=mappings) modifier.ignore = [] modifier.resolved_mappings_ = modifier._resolve_mappings(state.model) - # Should match seq.fc2 and block1.fc2 (both end with fc2) - assert len(modifier.resolved_mappings_) == 2 + assert len(modifier.resolved_mappings_) == len(mappings) - # Verify seq.fc2 mapping - should find only seq.fc1 (direct child) - seq_mapping = [ - m for m in modifier.resolved_mappings_ if m.smooth_name == "seq.fc2" - ][0] - assert isinstance(seq_mapping.smooth_layer, Linear) - assert len(seq_mapping.balance_layers) == 1 - assert isinstance(seq_mapping.balance_layers[0], Linear) + mapping = modifier.resolved_mappings_[0] + assert mapping.smooth_name == mappings[0][1] + assert isinstance(mapping.smooth_layer, Linear) + assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file From c7826270bbc3c2adde1a0301d71d05b52a590920 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:11:42 +0000 Subject: [PATCH 10/12] formatting Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 11 +++++------ src/llmcompressor/utils/pytorch/module.py | 3 +-- .../logarithmic_equalization/test_pytorch.py | 2 +- .../pytorch/modifiers/smoothquant/test_pytorch.py | 8 +------- 4 files changed, 8 insertions(+), 16 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index a28e5c326..e136bef0c 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -362,12 +362,11 @@ def _set_resolved_mappings(self, model: Module) -> None: ) continue - else: - # for multiple balance layers, find lowest common parent - ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ancestor_name, model - ) + + ancestor_name = get_lowest_common_ancestor_name(balance_names) + ancestor_name, ancestor = get_lowest_non_module_list_ancestor( + ancestor_name, model + ) resolved_mappings.append( ResolvedMapping( diff --git a/src/llmcompressor/utils/pytorch/module.py b/src/llmcompressor/utils/pytorch/module.py index fc4da3d04..19326c3a9 100644 --- a/src/llmcompressor/utils/pytorch/module.py +++ b/src/llmcompressor/utils/pytorch/module.py @@ -378,8 +378,7 @@ def get_module_to_name_dict(model: Module) -> dict[Module:str]: if module in module_to_name: logger.info( f"Warning, {name} and {module_to_name[module]} both " - "share the same module the same module, " - "may have trouble resolving mappings." 
+ "share the same module, which can result in unexpected behavior" ) module_to_name[module] = name return module_to_name diff --git a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py index 3f111aa74..49d0cdda3 100644 --- a/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/logarithmic_equalization/test_pytorch.py @@ -19,4 +19,4 @@ def test_log_equalization_mapping(state): mapping = modifier.resolved_mappings_[0] assert mapping.smooth_name == mappings[0][1] assert isinstance(mapping.smooth_layer, Linear) - assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file + assert isinstance(mapping.balance_layers[0], Linear) diff --git a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py index 0abe22523..ee8844041 100644 --- a/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py +++ b/tests/llmcompressor/pytorch/modifiers/smoothquant/test_pytorch.py @@ -4,12 +4,6 @@ from llmcompressor.modifiers.smoothquant import SmoothQuantModifier -import pytest -from torch.nn import Linear - -from llmcompressor.modifiers.smoothquant import SmoothQuantModifier - - @pytest.mark.unit def test_smooth_quant_mapping(state): mappings = [(["seq.fc1"], "seq.fc2")] @@ -23,4 +17,4 @@ def test_smooth_quant_mapping(state): mapping = modifier.resolved_mappings_[0] assert mapping.smooth_name == mappings[0][1] assert isinstance(mapping.smooth_layer, Linear) - assert isinstance(mapping.balance_layers[0], Linear) \ No newline at end of file + assert isinstance(mapping.balance_layers[0], Linear) From 7d28a11d1ed7284c471372f4e8026aab4eb6a82f Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:54:05 +0000 Subject: [PATCH 11/12] comments Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index e136bef0c..4f80f58aa 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -364,9 +364,9 @@ def _set_resolved_mappings(self, model: Module) -> None: continue ancestor_name = get_lowest_common_ancestor_name(balance_names) - ancestor_name, ancestor = get_lowest_non_module_list_ancestor( - ancestor_name, model - ) + # no ModuleList ancestors + while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList): + ancestor_name = ancestor_name.rsplit(".", 1)[0] resolved_mappings.append( ResolvedMapping( From f51afb1863e03d137a5778767f9f00e6ca88bca9 Mon Sep 17 00:00:00 2001 From: HDCharles Date: Wed, 3 Dec 2025 21:56:12 +0000 Subject: [PATCH 12/12] format Summary Signed-off-by: HDCharles --- src/llmcompressor/modifiers/awq/base.py | 31 +++-------------- .../llmcompressor/modifiers/awq/test_base.py | 34 ------------------- 2 files changed, 4 insertions(+), 61 deletions(-) diff --git a/src/llmcompressor/modifiers/awq/base.py b/src/llmcompressor/modifiers/awq/base.py index 4f80f58aa..f1cb3930c 100644 --- a/src/llmcompressor/modifiers/awq/base.py +++ b/src/llmcompressor/modifiers/awq/base.py @@ -32,7 +32,6 @@ from llmcompressor.utils.fsdp.helpers import get_fsdp_parent from llmcompressor.utils.helpers import calibration_forward_context from llmcompressor.utils.pytorch.module import ( - 
get_layer_by_name,
     get_module_to_name_dict,
 )
 
@@ -364,7 +364,10 @@ def _set_resolved_mappings(self, model: Module) -> None:
 
                 ancestor_name = get_lowest_common_ancestor_name(balance_names)
                 # no ModuleList ancestors
-                while not isinstance((ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList):
+                while isinstance(
+                    (ancestor := model.get_submodule(ancestor_name)),
+                    torch.nn.ModuleList,
+                ):
                     ancestor_name = ancestor_name.rsplit(".", 1)[0]
 
                 resolved_mappings.append(
@@ -798,28 +800,3 @@ def _accumulate_mean(
     new_count = prev_count + num_added
 
     return (prev_sum + sum_added) / new_count, new_count
-
-
-def get_lowest_non_module_list_ancestor(name, module: Module) -> tuple[str, Module]:
-    """
-    Given a name and a model, finds lowest ancestor of
-    named module that's not a ModuleList
-    i.e. module_list.module_dict.module_list -> module_list.module_dict
-    i.e. module_list.module_dict -> module_list.module_dict
-    (self is an ancestor of self)
-
-    NOTE: This is needed because ModuleLists don't play
-    nicely with hooks because their forward method is never directly
-    called for MoE models. See Qwen3MoeSparseMoeBlock for example, experts
-    are selected based on router output and their forward method is called.
-    https://github.com/huggingface/transformers/blob/v4.52.4/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py#L233
-
-    Returns name of module and pointer to module
-    """
-    while True:
-        if name == "":
-            return "", module
-        current_module = get_layer_by_name(name, module)
-        if not isinstance(current_module, torch.nn.ModuleList):
-            return name, current_module
-        name = ".".join(name.split(".")[:-1])
diff --git a/tests/llmcompressor/modifiers/awq/test_base.py b/tests/llmcompressor/modifiers/awq/test_base.py
index 5ed6594ee..5d82da29a 100644
--- a/tests/llmcompressor/modifiers/awq/test_base.py
+++ b/tests/llmcompressor/modifiers/awq/test_base.py
@@ -5,7 +5,6 @@
 from torch.nn import Linear
 
 from llmcompressor.modifiers.awq import AWQMapping, AWQModifier
-from llmcompressor.modifiers.awq.base import get_lowest_non_module_list_ancestor
 from llmcompressor.modifiers.factory import ModifierFactory
 
 
@@ -205,39 +204,6 @@ def test_validate():
         AWQModifier(scheme="W4A16", duo_scaling="x")
 
 
-@pytest.mark.unit
-def test_get_lowest_non_module_list_ancestor():
-    model = torch.nn.ModuleDict(
-        {
-            "experts": torch.nn.ModuleList(
-                [
-                    torch.nn.ModuleDict(
-                        {
-                            "gate_proj": Linear(4, 2),
-                            "down_proj": Linear(2, 4),
-                        }
-                    )
-                    for _ in range(10)
-                ]
-            )
-        }
-    )
-
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("", model)
-    assert ancestor_name == "" and ancestor == model
-
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor("experts", model)
-    assert ancestor_name == "" and ancestor == model
-
-    ancestor_name, ancestor = get_lowest_non_module_list_ancestor(
-        "experts.1.gate_proj", model
-    )
-    assert (
-        ancestor_name == "experts.1.gate_proj"
-        and ancestor == model["experts"][1]["gate_proj"]
-    )
-
-
 @pytest.mark.unit
 def test_moe_multiple_balance_layers():
     """Test AWQ mapping with multiple balance layers in MoE architecture"""
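
Notes on the mechanics this series relies on, sketched outside the diffs.

match_modules_set yields one list of matched modules per pattern, so the
balance side arrives as a list of lists; patches 06 and 08 flatten it with
tree_flatten before resolving names. A minimal sketch of just that step,
with plain strings standing in for the matched torch modules:

    from torch.utils._pytree import tree_flatten

    # one inner list per balance pattern, as match_modules_set returns them
    nested_balance_layers = [["b00", "b01"], ["b10", "b11"]]
    balance_layers = tree_flatten(nested_balance_layers)[0]
    print(balance_layers)  # ['b00', 'b01', 'b10', 'b11']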
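The shape test in _check_layers_are_compatible (patch 03) keeps the upstream
llm-awq exclusion: under grouped-query attention, v_proj.out_features no
longer equals o_proj.in_features, so the pair cannot share a smoothing scale
and the mapping is skipped. A sketch of the arithmetic with hypothetical GQA
sizes, not taken from any real model config:

    import torch

    hidden_size, num_kv_heads, head_dim = 64, 2, 8  # hypothetical GQA shapes
    v_proj = torch.nn.Linear(hidden_size, num_kv_heads * head_dim)  # out: 16
    o_proj = torch.nn.Linear(hidden_size, hidden_size)  # in: 64

    # same comparison the series carries through every revision
    print(v_proj.out_features != o_proj.in_features)  # True -> skip mapping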
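Parent resolution converges, across patches 04 through 12, on: take the
lowest common ancestor of all balance layer names, then climb past any
ModuleList, because ModuleList.forward is never invoked (MoE experts are
called individually after routing), so hooks attached there would never
fire. A self-contained sketch; lowest_common_ancestor_name below is a local
stand-in for the compressed-tensors helper, not its actual implementation:

    import torch

    def lowest_common_ancestor_name(names: list[str]) -> str:
        # shared dotted prefix, compared one whole component at a time
        split = [name.split(".") for name in names]
        prefix = []
        for level in zip(*split):
            if len(set(level)) != 1:
                break
            prefix.append(level[0])
        return ".".join(prefix)

    model = torch.nn.ModuleDict(
        {
            "mlp": torch.nn.ModuleDict(
                {
                    "experts": torch.nn.ModuleList(
                        [
                            torch.nn.ModuleDict(
                                {
                                    "gate_proj": torch.nn.Linear(4, 4),
                                    "up_proj": torch.nn.Linear(4, 4),
                                }
                            )
                            for _ in range(2)
                        ]
                    )
                }
            )
        }
    )

    balance_names = ["mlp.experts.0.gate_proj", "mlp.experts.1.up_proj"]
    ancestor_name = lowest_common_ancestor_name(balance_names)  # "mlp.experts"

    # climb while the candidate is a ModuleList (the corrected patch-12 loop)
    while isinstance(
        (ancestor := model.get_submodule(ancestor_name)), torch.nn.ModuleList
    ):
        ancestor_name = ancestor_name.rsplit(".", 1)[0]

    print(ancestor_name)  # "mlp" -- hooks can attach to this ModuleDict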