From d97f2887403483fa6eacfe155a0cac829159de91 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 17:55:40 +0000 Subject: [PATCH 01/15] generalize embeddings utils Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/autoround/base.py | 15 +-- .../quantization/quantization/mixin.py | 12 +- .../modifiers/transform/quip/base.py | 29 +++-- .../modifiers/transform/spinquant/base.py | 14 +-- .../compression/compressed_tensors_utils.py | 119 +----------------- src/llmcompressor/typing.py | 14 ++- src/llmcompressor/utils/__init__.py | 1 + src/llmcompressor/utils/transformers.py | 68 ++++++++++ 8 files changed, 112 insertions(+), 160 deletions(-) create mode 100644 src/llmcompressor/utils/transformers.py diff --git a/src/llmcompressor/modifiers/autoround/base.py b/src/llmcompressor/modifiers/autoround/base.py index 2480751a9b..6b6d4fe3ed 100644 --- a/src/llmcompressor/modifiers/autoround/base.py +++ b/src/llmcompressor/modifiers/autoround/base.py @@ -20,10 +20,8 @@ from llmcompressor.modifiers import Modifier from llmcompressor.modifiers.quantization.calibration import apply_calibration_status from llmcompressor.modifiers.quantization.quantization import QuantizationMixin -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) -from llmcompressor.utils.pytorch.module import get_no_split_params +from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from llmcompressor.utils.pytorch import get_no_split_params __all__ = ["AutoRoundModifier"] @@ -109,7 +107,6 @@ class AutoRoundModifier(Modifier, QuantizationMixin): enable_torch_compile: bool = True # private variables - _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict) _all_module_input: Dict[str, List[Tuple]] = PrivateAttr(default_factory=dict) _q_input: Optional[torch.Tensor] = PrivateAttr(default=None) @@ -124,10 +121,6 @@ def on_initialize(self, state: State, **kwargs) -> bool: QuantizationMixin.initialize_quantization(self, state.model) # prepare module names - self._module_names = { - m: name - for name, m in match_named_modules(state.model, self.targets, self.ignore) - } self._add_temporary_names(state.model) # freeze all model parameters for _, param in state.model.named_parameters(): @@ -142,7 +135,9 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ - untie_if_target_shared_embedding(model, self._module_names.values()) + targets = match_named_modules(model, self.targets, self.ignore) + if targets_embeddings(model, targets): + untie_word_embeddings(model) for _, module in match_named_modules(model, self.targets, self.ignore): # Note: No need to register observers for auto-round diff --git a/src/llmcompressor/modifiers/quantization/quantization/mixin.py b/src/llmcompressor/modifiers/quantization/quantization/mixin.py index 42264af22e..4a63b3922b 100644 --- a/src/llmcompressor/modifiers/quantization/quantization/mixin.py +++ b/src/llmcompressor/modifiers/quantization/quantization/mixin.py @@ -34,9 +34,7 @@ reset_quantization_status, ) from llmcompressor.modifiers.utils.hooks import HooksMixin -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) +from llmcompressor.utils import targets_embeddings, untie_word_embeddings __all__ = ["QuantizationMixin"] @@ -182,11 +180,9 @@ def start_calibration(self, model: torch.nn.Module): :param model: model to prepare for calibration """ - - 
matched_module_generator = ( - x[1] for x in match_named_modules(model, self.resolved_targets, self.ignore) - ) - untie_if_target_shared_embedding(model, matched_module_generator) + targets = match_named_modules(model, self.resolved_targets, self.ignore) + if targets_embeddings(model, targets): + untie_word_embeddings(model) for _, module in match_named_modules(model, self.resolved_targets, self.ignore): self._initialize_observers(module) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index ace8d64fd4..f69260b273 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -12,9 +12,8 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modifiers import Modifier -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_if_target_shared_embedding, -) +from llmcompressor.typing import NamedModules +from llmcompressor.utils import targets_embeddings, untie_word_embeddings __all__ = ["QuIPModifier"] @@ -102,18 +101,13 @@ def on_initialize(self, state: State, **kwargs) -> bool: def on_start(self, state: State, event: Event, **kwargs): self.started_ = True - - def matched_module_generator(): - for scheme in self.transform_config.config_groups.values(): - for arg in scheme.apply: - gen = match_named_modules(state.model, arg.targets, arg.ignore) - for _, module in gen: - yield module + model = state.model # Untie embeddings if they will be targeted by transforms - untie_if_target_shared_embedding(state.model, matched_module_generator()) + if targets_embeddings(model, self.get_targets()): + untie_word_embeddings(model) - apply_transform_config(state.model, self.transform_config) + apply_transform_config(model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: @@ -136,6 +130,17 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def get_targets(self, model: torch.nn.Module) -> NamedModules: + if not self.initialized_: + raise ValueError("Cannot get targets before modifier has been initialized") + + return [ + (name, module) + for scheme in self.transform_config.config_groups.values() + for arg in scheme.apply + for name, module in match_named_modules(model, arg.targets, arg.ignore) + ] + def _create_config(self) -> TransformConfig: config_groups = dict() if "v" in self.rotations: diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 8d84e860f2..a4f5bd04bc 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -16,9 +16,7 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modeling import center_embeddings, fuse_norm_linears from llmcompressor.modifiers import Modifier -from llmcompressor.transformers.compression.compressed_tensors_utils import ( - untie_word_embeddings, -) +from llmcompressor.utils import untie_word_embeddings from .mappings import SpinQuantMapping, infer_mapping_from_model from .norm_mappings import NormMapping, infer_norm_mapping_from_model @@ -151,14 +149,16 @@ def on_initialize(self, state: State, **kwargs) -> bool: @torch.no_grad() def on_start(self, state: State, event: Event, **kwargs): self.started_ = True + model = state.model # needed any time embeddings/lm_head is modified - 
untie_word_embeddings(state.model) + untie_word_embeddings(model) + # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU - self._center_embeddings(state.model) - self._fuse_norms(state.model) - apply_transform_config(state.model, self.transform_config) + self._center_embeddings(model) + self._fuse_norms(model) + apply_transform_config(model, self.transform_config) def on_event(self, state: State, event: Event, **kwargs): if event.type_ == EventType.CALIBRATION_EPOCH_START: diff --git a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py index ddb5b97f43..8393c6ac7b 100644 --- a/src/llmcompressor/transformers/compression/compressed_tensors_utils.py +++ b/src/llmcompressor/transformers/compression/compressed_tensors_utils.py @@ -1,6 +1,5 @@ import os import weakref -from collections.abc import Generator from functools import wraps from typing import Optional @@ -9,9 +8,6 @@ from compressed_tensors import ( ModelCompressor, SparsityCompressionConfig, - delete_offload_parameter, - has_offloaded_params, - register_offload_parameter, ) from compressed_tensors.config import CompressionFormat from loguru import logger @@ -25,7 +21,7 @@ from llmcompressor.transformers.utils import RECIPE_FILE_NAME from llmcompressor.transformers.utils.helpers import infer_recipe_from_model_path -__all__ = ["modify_save_pretrained", "untie_word_embeddings"] +__all__ = ["modify_save_pretrained"] def modify_save_pretrained(model: PreTrainedModel): @@ -118,119 +114,6 @@ def save_pretrained_wrapper( model.save_pretrained = save_pretrained_compressed(model.save_pretrained) -def untie_word_embeddings(model: PreTrainedModel): - """ - Patches bug where HF transformers will fail to untie weights under specific - circumstances (https://github.com/huggingface/transformers/issues/33689). - - This function detects those cases and unties the tensors if applicable - - :param model: model to fix - """ - try: - input_embed = model.get_input_embeddings() - output_embed = model.get_output_embeddings() - except NotImplementedError as e: - logger.warning( - f"cannot untie model of type {model.__class__} which doesn't have " - f"get_input_embeddings and get_output_embeddings implmented\n{e}" - ) - return - - for module in (input_embed, output_embed): - if module is None or not hasattr(module, "weight"): - logger.warning(f"Cannot untie {module} which does not have weight param") - continue - - # this could be replaced by a `get_offloaded_parameter` util - if not has_offloaded_params(module): - untied_data = module.weight.data.clone() - else: - untied_data = module._hf_hook.weights_map["weight"].clone() - - requires_grad = module.weight.requires_grad - new_parameter = torch.nn.Parameter(untied_data, requires_grad=requires_grad) - delete_offload_parameter(module, "weight") - register_offload_parameter(module, "weight", new_parameter) - - if hasattr(model.config, "tie_word_embeddings"): - model.config.tie_word_embeddings = False - - -def _get_embeddings_or_warn( - model: torch.nn.Module, -) -> tuple[torch.nn.Module | None, torch.nn.Module | None]: - if not ( - hasattr(model, "get_input_embeddings") - and hasattr(model, "get_output_embeddings") - ): - logger.warning( - f"{model.__class__} doesn't have attribute get_input_embeddings and" - " get_output_embeddings implemented." 
- "\nThis can cause" - " problems when quantizing layers with shared weights" - ) - return None, None - - try: - input_embeddings, output_embeddings = ( - model.get_input_embeddings(), - model.get_output_embeddings(), - ) - except NotImplementedError as e: - logger.warning( - f"{model.__class__} doesn't have get_input_embeddings and " - "get_output_embeddings implemented." - "\nThis can cause" - " problems when quantizing layers with shared weights" - f"\n{e}" - ) - return None, None - - if not ( - isinstance(input_embeddings, torch.nn.Module) - and isinstance(output_embeddings, torch.nn.Module) - ): - logger.warning( - f"expected modules from {model.__class__} get_input_embeddings and" - f" get_output_embeddings but got {type(input_embeddings)}" - f" and {type(output_embeddings)}." - "\nThis can cause" - " problems when quantizing layers with shared weights" - ) - return None, None - return input_embeddings, output_embeddings - - -def untie_if_target_shared_embedding( - model: torch.nn.Module, matched_module_generator: Generator[torch.nn.Module] -): - """ - Helper method that checks for shared input/output embedding and unties them - if either shows up in the matched_module_generator - - :param model: model to untie if embeddings are shared and targeted by - matched_module_generator - :param matched_module_generator: Generator of all modules (not names) which - will be modified by quantization or transformation - """ - input_embeddings, output_embeddings = _get_embeddings_or_warn(model) - - if None in (input_embeddings, output_embeddings): # if couldn't find embeddings - return - - if ( - input_embeddings.weight is not output_embeddings.weight - ): # if not shared, can ignore - return - - # if shared, check if either is targeted - for module in matched_module_generator: - if module in (input_embeddings, output_embeddings): - untie_word_embeddings(model) - return - - def get_model_compressor( model: torch.nn.Module, sparsity_config: Optional[SparsityCompressionConfig] = None, diff --git a/src/llmcompressor/typing.py b/src/llmcompressor/typing.py index 1d2d195158..233f4df56a 100644 --- a/src/llmcompressor/typing.py +++ b/src/llmcompressor/typing.py @@ -2,8 +2,9 @@ Defines type aliases for the llm-compressor library. """ -from typing import Union +from typing import Iterable +import torch from datasets import Dataset, DatasetDict, IterableDataset from transformers import ( BaseImageProcessor, @@ -13,9 +14,12 @@ ) # Tokenizer or Processor. 
Processors do not inherit from a unified base class -Processor = Union[ - PreTrainedTokenizer, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin -] +Processor = ( + PreTrainedTokenizer | BaseImageProcessor | FeatureExtractionMixin | ProcessorMixin +) # Supported dataset types, IterableDataset is a streamed dataset -DatasetType = Union[Dataset, DatasetDict, IterableDataset] +DatasetType = Dataset | DatasetDict | IterableDataset + +# Torch types +NamedModules = Iterable[tuple[str, torch.nn.Module]] diff --git a/src/llmcompressor/utils/__init__.py b/src/llmcompressor/utils/__init__.py index 42daa5bce6..4da185855b 100644 --- a/src/llmcompressor/utils/__init__.py +++ b/src/llmcompressor/utils/__init__.py @@ -4,5 +4,6 @@ # ruff: noqa +from .transformers import * from .dev import * from .helpers import * diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py new file mode 100644 index 0000000000..bb20000d0f --- /dev/null +++ b/src/llmcompressor/utils/transformers.py @@ -0,0 +1,68 @@ +import torch +from compressed_tensors import has_offloaded_params +from loguru import logger +from transformers import PreTrainedModel + +from llmcompressor.typing import NamedModules + +__all__ = ["untie_word_embeddings", "targets_embeddings", "get_embeddings"] + + +def untie_word_embeddings(model: PreTrainedModel): + """Untie word embeddings, if possible.""" + input_embed, output_embed = get_embeddings(model) + if input_embed is None or output_embed is None: + logger.warning( + "Cannot untie embeddings. If this model has word embeddings, please " + "implement `get_input_embeddings` and `get_output_embeddings`" + ) + return + + # clone data to untie + for module in (input_embed, output_embed): + if not has_offloaded_params(module): + module.weight.data = module.weight.data.clone() + else: + weights_map = module._hf_hook.weights_map + weights_map["weight"] = weights_map["weight"].clone() + + # modify model config + if hasattr(model.config, "tie_word_embeddings"): + model.config.tie_word_embeddings = False + + +def targets_embeddings( + model: PreTrainedModel, + targets: NamedModules, + check_input: bool = True, + check_output: bool = True, +) -> bool: + input_embed, output_embed = get_embeddings(model) + if check_input and input_embed is None or check_output and output_embed is None: + logger.warning( + "Cannot check embeddings. 
If this model has word embeddings, please " + "implement `get_input_embeddings` and `get_output_embeddings`" + ) + return False + + targets = set(module for _, module in targets) + return (not check_input or input_embed in targets) and ( + not check_output or output_embed in targets + ) + + +def get_embeddings( + model: PreTrainedModel, +) -> tuple[torch.nn.Module | None, torch.nn.Module | None]: + try: + input_embed = model.get_input_embeddings() + + except (AttributeError, NotImplementedError): + input_embed = None + + try: + output_embed = model.get_output_embeddings() + except (AttributeError, NotImplementedError): + output_embed = None + + return input_embed, output_embed From f30f4652532e7042a24d4614607dc07fdac16afb Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 18:43:10 +0000 Subject: [PATCH 02/15] add tests Signed-off-by: Kyle Sayers --- src/llmcompressor/entrypoints/utils.py | 2 +- src/llmcompressor/utils/transformers.py | 34 +++++++- .../compression/test_compress_tensor_utils.py | 57 +------------ .../llmcompressor/utils/test_transformers.py | 85 +++++++++++++++++++ 4 files changed, 117 insertions(+), 61 deletions(-) create mode 100644 tests/llmcompressor/utils/test_transformers.py diff --git a/src/llmcompressor/entrypoints/utils.py b/src/llmcompressor/entrypoints/utils.py index 3c1b354ce7..62f7e86de4 100644 --- a/src/llmcompressor/entrypoints/utils.py +++ b/src/llmcompressor/entrypoints/utils.py @@ -29,12 +29,12 @@ from llmcompressor.pytorch.model_load.helpers import parse_dtype from llmcompressor.transformers.compression.compressed_tensors_utils import ( modify_save_pretrained, - untie_word_embeddings, ) from llmcompressor.transformers.utils.helpers import ( is_model_ct_quantized_from_path, ) from llmcompressor.typing import Processor +from llmcompressor.utils import untie_word_embeddings from llmcompressor.utils.fsdp.helpers import is_fsdp_model diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py index bb20000d0f..2b5690b313 100644 --- a/src/llmcompressor/utils/transformers.py +++ b/src/llmcompressor/utils/transformers.py @@ -1,6 +1,7 @@ import torch from compressed_tensors import has_offloaded_params from loguru import logger +from torch.nn import Parameter from transformers import PreTrainedModel from llmcompressor.typing import NamedModules @@ -9,7 +10,14 @@ def untie_word_embeddings(model: PreTrainedModel): - """Untie word embeddings, if possible.""" + """ + Untie word embeddings, if possible. This function raises a warning if + embeddings cannot be found in the model definition. 
+ + The model config will be updated to reflect that embeddings are now untied + + :param model: transformers model containing word embeddings + """ input_embed, output_embed = get_embeddings(model) if input_embed is None or output_embed is None: logger.warning( @@ -21,7 +29,8 @@ def untie_word_embeddings(model: PreTrainedModel): # clone data to untie for module in (input_embed, output_embed): if not has_offloaded_params(module): - module.weight.data = module.weight.data.clone() + requires_grad = module.weight.requires_grad + module.weight = Parameter(module.weight.data, requires_grad=requires_grad) else: weights_map = module._hf_hook.weights_map weights_map["weight"] = weights_map["weight"].clone() @@ -37,6 +46,15 @@ def targets_embeddings( check_input: bool = True, check_output: bool = True, ) -> bool: + """ + Returns True if the given targets target the word embeddings of the model + + :param model: containing word embeddings + :param targets: named modules to check + :param check_input: whether to check if input embeddings are targeted + :param check_output: whether to check if output embeddings are targeted + :return: True if embeddings are targeted, False otherwise + """ input_embed, output_embed = get_embeddings(model) if check_input and input_embed is None or check_output and output_embed is None: logger.warning( @@ -46,14 +64,22 @@ def targets_embeddings( return False targets = set(module for _, module in targets) - return (not check_input or input_embed in targets) and ( - not check_output or output_embed in targets + return (check_input and input_embed in targets) or ( + check_output and output_embed in targets ) def get_embeddings( model: PreTrainedModel, ) -> tuple[torch.nn.Module | None, torch.nn.Module | None]: + """ + Returns input and output embeddings of a model. If `get_input_embeddings`/ + `get_output_embeddings` is not implemented on the model, then None will be returned + instead. 
+ + :param model: model to get embeddings from + :return: tuple of containing embedding modules or none + """ try: input_embed = model.get_input_embeddings() diff --git a/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py b/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py index 96b5d6fce6..9963ee4c4a 100644 --- a/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py +++ b/tests/llmcompressor/transformers/compression/test_compress_tensor_utils.py @@ -14,7 +14,6 @@ QuantizationStatus, quantize, ) -from compressed_tensors.utils import align_module_device, update_offload_parameter from torch import nn from transformers import AutoConfig, AutoModelForCausalLM from transformers.utils.quantization_config import CompressedTensorsConfig @@ -25,11 +24,11 @@ from llmcompressor.transformers.compression.compressed_tensors_utils import ( get_model_compressor, modify_save_pretrained, - untie_word_embeddings, ) from llmcompressor.transformers.compression.sparsity_metadata_config import ( SparsityConfigMetadata, ) +from llmcompressor.utils import untie_word_embeddings from tests.testing_utils import requires_gpu @@ -283,60 +282,6 @@ def test_model_reload_gpu(offload, torch_dtype, tie_word_embeddings, device, tmp test_model_reload(offload, torch_dtype, tie_word_embeddings, device, tmp_path) -@pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device", - [ - (False, torch.float16, False, "cpu"), - (False, torch.float32, False, "cpu"), - (True, torch.float32, False, "cpu"), - (False, torch.float16, True, "cpu"), - (False, torch.float32, True, "cpu"), - (True, torch.float16, True, "cpu"), - (True, torch.float32, True, "cpu"), - ], -) -def test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device): - """ - Test whether model offloading breaks tied/untied embeddings - """ - # load model - model_path = "nm-testing/tinysmokellama-3.2" - model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) - if offload: - model = dispatch_model(model, {"": device}, force_hooks=True) - else: - model = model.to(device) - - if not tie_word_embeddings: - untie_word_embeddings(model) - - # modify lm head - with torch.no_grad(), align_module_device(model.lm_head): - update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) - - with align_module_device(model.lm_head), align_module_device( - model.model.embed_tokens - ): - if tie_word_embeddings: - assert model.lm_head.weight is model.model.embed_tokens.weight - assert model.config.tie_word_embeddings - else: - assert model.lm_head.weight is not model.model.embed_tokens.weight - assert not model.config.tie_word_embeddings - - -@requires_gpu -@pytest.mark.parametrize( - "offload,torch_dtype,tie_word_embeddings,device", - [ - (False, torch.float32, False, "cuda:0"), - (False, torch.float32, True, "cuda:0"), - ], -) -def test_model_shared_tensors_gpu(offload, torch_dtype, tie_word_embeddings, device): - test_model_shared_tensors(offload, torch_dtype, tie_word_embeddings, device) - - @requires_gpu @pytest.mark.parametrize( "model_stub, recipe, sparse_format, quant_format", diff --git a/tests/llmcompressor/utils/test_transformers.py b/tests/llmcompressor/utils/test_transformers.py new file mode 100644 index 0000000000..deb9c4a074 --- /dev/null +++ b/tests/llmcompressor/utils/test_transformers.py @@ -0,0 +1,85 @@ +import pytest +import torch +from accelerate import dispatch_model +from compressed_tensors import align_module_device, 
update_offload_parameter +from transformers import AutoModelForCausalLM + +from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from tests.testing_utils import requires_gpu + + +@pytest.mark.parametrize( + "offload,torch_dtype,tie_word_embeddings,device", + [ + (False, torch.float16, False, "cpu"), + (False, torch.float32, False, "cpu"), + (True, torch.float32, False, "cpu"), + (False, torch.float16, True, "cpu"), + (False, torch.float32, True, "cpu"), + (True, torch.float16, True, "cpu"), + (True, torch.float32, True, "cpu"), + ], +) +def test_untie_word_embeddings(offload, torch_dtype, tie_word_embeddings, device): + """ + Test whether model offloading breaks tied/untied embeddings + """ + # load model + model_path = "nm-testing/tinysmokellama-3.2" + model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch_dtype) + if offload: + model = dispatch_model(model, {"": device}, force_hooks=True) + else: + model = model.to(device) + + if not tie_word_embeddings: + untie_word_embeddings(model) + + # modify lm head + with torch.no_grad(), align_module_device(model.lm_head): + update_offload_parameter(model.lm_head, "weight", model.lm_head.weight + 1) + + with align_module_device(model.lm_head), align_module_device( + model.model.embed_tokens + ): + if tie_word_embeddings: + assert model.lm_head.weight is model.model.embed_tokens.weight + assert model.config.tie_word_embeddings + else: + assert model.lm_head.weight is not model.model.embed_tokens.weight + assert not model.config.tie_word_embeddings + + +@requires_gpu +@pytest.mark.parametrize( + "offload,torch_dtype,tie_word_embeddings,device", + [ + (False, torch.float32, False, "cuda:0"), + (False, torch.float32, True, "cuda:0"), + ], +) +def test_untie_word_embeddings_gpu(offload, torch_dtype, tie_word_embeddings, device): + test_untie_word_embeddings(offload, torch_dtype, tie_word_embeddings, device) + + +def test_targets_embeddings(): + model_path = "nm-testing/tinysmokellama-3.2" + model = AutoModelForCausalLM.from_pretrained(model_path) + + targets = {"embed_tokens": model.model.embed_tokens}.items() + assert targets_embeddings(model, targets, check_input=True, check_output=True) + assert targets_embeddings(model, targets, check_input=True, check_output=False) + assert not targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False) + + targets = {"lm_head": model.lm_head}.items() + assert targets_embeddings(model, targets, check_input=True, check_output=True) + assert not targets_embeddings(model, targets, check_input=True, check_output=False) + assert targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False) + + targets = {}.items() + assert not targets_embeddings(model, targets, check_input=True, check_output=True) + assert not targets_embeddings(model, targets, check_input=True, check_output=False) + assert not targets_embeddings(model, targets, check_input=False, check_output=True) + assert not targets_embeddings(model, targets, check_input=False, check_output=False) From d671faf68b181ae507c11ef55e1cadd51830793a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 19:23:08 +0000 Subject: [PATCH 03/15] fix assignment Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/transformers.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git 
a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py index 2b5690b313..2cc04578d7 100644 --- a/src/llmcompressor/utils/transformers.py +++ b/src/llmcompressor/utils/transformers.py @@ -1,5 +1,5 @@ import torch -from compressed_tensors import has_offloaded_params +from compressed_tensors import has_offloaded_params, register_offload_parameter from loguru import logger from torch.nn import Parameter from transformers import PreTrainedModel @@ -29,11 +29,13 @@ def untie_word_embeddings(model: PreTrainedModel): # clone data to untie for module in (input_embed, output_embed): if not has_offloaded_params(module): - requires_grad = module.weight.requires_grad - module.weight = Parameter(module.weight.data, requires_grad=requires_grad) + data = module.weight.data else: - weights_map = module._hf_hook.weights_map - weights_map["weight"] = weights_map["weight"].clone() + data = module._hf_hook.weights_map["weight"] + + requires_grad = module.weight.requires_grad + untied_param = Parameter(data.clone(), requires_grad=requires_grad) + register_offload_parameter(module, "weight", untied_param) # modify model config if hasattr(model.config, "tie_word_embeddings"): From 9f4f89f457c81c9d222007964b25c9a3b6d9146a Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 20:10:47 +0000 Subject: [PATCH 04/15] more precision for spinquant, fix typo in quip Signed-off-by: Kyle Sayers --- .../modifiers/transform/quip/base.py | 4 ++-- .../modifiers/transform/spinquant/base.py | 20 ++++++++++++++++--- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/quip/base.py b/src/llmcompressor/modifiers/transform/quip/base.py index f69260b273..4ac2e9c881 100644 --- a/src/llmcompressor/modifiers/transform/quip/base.py +++ b/src/llmcompressor/modifiers/transform/quip/base.py @@ -104,7 +104,7 @@ def on_start(self, state: State, event: Event, **kwargs): model = state.model # Untie embeddings if they will be targeted by transforms - if targets_embeddings(model, self.get_targets()): + if targets_embeddings(model, self._get_targets(model)): untie_word_embeddings(model) apply_transform_config(model, self.transform_config) @@ -130,7 +130,7 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True - def get_targets(self, model: torch.nn.Module) -> NamedModules: + def _get_targets(self, model: torch.nn.Module) -> NamedModules: if not self.initialized_: raise ValueError("Cannot get targets before modifier has been initialized") diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index a4f5bd04bc..1fcb7f0680 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -16,7 +16,8 @@ from llmcompressor.core import Event, EventType, State from llmcompressor.modeling import center_embeddings, fuse_norm_linears from llmcompressor.modifiers import Modifier -from llmcompressor.utils import untie_word_embeddings +from llmcompressor.typing import NamedModules +from llmcompressor.utils import targets_embeddings, untie_word_embeddings from .mappings import SpinQuantMapping, infer_mapping_from_model from .norm_mappings import NormMapping, infer_norm_mapping_from_model @@ -151,8 +152,10 @@ def on_start(self, state: State, event: Event, **kwargs): self.started_ = True model = state.model - # needed any time embeddings/lm_head is modified - untie_word_embeddings(model) + # untie embeddings and lm_head so 
they can be rotated separately + # in the future, these can remain tied to save memory + if targets_embeddings(model, self._get_targets(model)): + untie_word_embeddings(model) # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU @@ -181,6 +184,17 @@ def on_finalize(self, state: State, **kwargs) -> bool: return True + def _get_targets(self, model: torch.nn.Module) -> NamedModules: + if not self.initialized_: + raise ValueError("Cannot get targets before modifier has been initialized") + + return [ + (name, module) + for scheme in self.transform_config.config_groups.values() + for arg in scheme.apply + for name, module in match_named_modules(model, arg.targets, arg.ignore) + ] + def _center_embeddings(self, model: PreTrainedModel): for _, embedding in match_named_modules( model, [self.mappings.embedding], warn_on_fail=True From 90d1cace621f425cb669e51e26cada8e8e341177 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 20:14:00 +0000 Subject: [PATCH 05/15] parenthesis for clarity Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/transformers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/utils/transformers.py b/src/llmcompressor/utils/transformers.py index 2cc04578d7..9b4831a290 100644 --- a/src/llmcompressor/utils/transformers.py +++ b/src/llmcompressor/utils/transformers.py @@ -58,7 +58,7 @@ def targets_embeddings( :return: True if embeddings are targeted, False otherwise """ input_embed, output_embed = get_embeddings(model) - if check_input and input_embed is None or check_output and output_embed is None: + if (check_input and input_embed) is None or (check_output and output_embed is None): logger.warning( "Cannot check embeddings. 
If this model has word embeddings, please " "implement `get_input_embeddings` and `get_output_embeddings`" From 1615a1abc00944cde546aaac59d2daead95fafc2 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 23:42:28 +0000 Subject: [PATCH 06/15] always untie for spinquant Signed-off-by: Kyle Sayers --- src/llmcompressor/modifiers/transform/spinquant/base.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/llmcompressor/modifiers/transform/spinquant/base.py b/src/llmcompressor/modifiers/transform/spinquant/base.py index 1fcb7f0680..aac21d29c4 100644 --- a/src/llmcompressor/modifiers/transform/spinquant/base.py +++ b/src/llmcompressor/modifiers/transform/spinquant/base.py @@ -17,7 +17,7 @@ from llmcompressor.modeling import center_embeddings, fuse_norm_linears from llmcompressor.modifiers import Modifier from llmcompressor.typing import NamedModules -from llmcompressor.utils import targets_embeddings, untie_word_embeddings +from llmcompressor.utils import untie_word_embeddings from .mappings import SpinQuantMapping, infer_mapping_from_model from .norm_mappings import NormMapping, infer_norm_mapping_from_model @@ -152,10 +152,8 @@ def on_start(self, state: State, event: Event, **kwargs): self.started_ = True model = state.model - # untie embeddings and lm_head so they can be rotated separately - # in the future, these can remain tied to save memory - if targets_embeddings(model, self._get_targets(model)): - untie_word_embeddings(model) + # untie embeddings to avoid unintended effects of `_center_embeddings` + untie_word_embeddings(model) # needs to happen after the model has been hooked to execute on the GPU # otherwise we're applying weight transforms on CPU From 314a930a60efd153ec114d2842e7abe8e80f5401 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 19:53:36 +0000 Subject: [PATCH 07/15] implement requires_lm_head_calibration, disable_lm_head Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/basic/pipeline.py | 16 ++++- .../pipelines/sequential/pipeline.py | 5 ++ src/llmcompressor/utils/helpers.py | 58 +++++++++++++++++-- 3 files changed, 73 insertions(+), 6 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index e6494fc5e0..0a877880ff 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -6,11 +6,16 @@ from compressed_tensors.utils import get_execution_device from torch.utils.data.dataloader import DataLoader -from llmcompressor.core import LifecycleCallbacks +from llmcompressor.core import LifecycleCallbacks, active_session from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device -from llmcompressor.utils import calibration_forward_context, dispatch_for_generation +from llmcompressor.utils import ( + calibration_forward_context, + disable_lm_head, + dispatch_for_generation, + requires_lm_head_calibration, +) if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments @@ -38,6 +43,9 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ + session = active_session() + modifiers = session.lifecycle.recipe.modifiers + dispatch_for_generation(model) # basic dispatch is identical to generation model_device = get_execution_device(model) @@ -45,6 +53,10 @@ 
def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) + # Optionally disable lm_head + if not requires_lm_head_calibration(model, modifiers): + stack.enter_context(disable_lm_head(model)) + for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = tensors_to_device(batch, model_device) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 511a693b95..658f360311 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -19,6 +19,8 @@ DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, + disable_lm_head, + requires_lm_head_calibration, ) if TYPE_CHECKING: @@ -88,6 +90,9 @@ def __call__( # Optionally disable quantization if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) + # Optionally disable lm_head + if not requires_lm_head_calibration(model, modifiers): + stack.enter_context(disable_lm_head(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 0be09bd062..4cd5a5a4e3 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -18,15 +18,21 @@ from collections import OrderedDict from io import BytesIO from pathlib import Path -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Tuple, Union from urllib.parse import urlparse import numpy import torch +from compressed_tensors import has_offloaded_params, match_named_modules from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger from transformers import PreTrainedModel +from llmcompressor.utils import get_embeddings, targets_embeddings + +if TYPE_CHECKING: + from llmcompressor.modifiers import Modifier + __all__ = [ "ALL_TOKEN", "ALL_PRUNABLE_TOKEN", @@ -65,6 +71,8 @@ "DisableQuantization", "eval_context", "calibration_forward_context", + "disable_lm_head", + "requires_lm_head_calibration", "patch_attr", "disable_hf_kernels", "DISABLE_QAC_MODIFIERS", @@ -1050,12 +1058,54 @@ def calibration_forward_context(model: torch.nn.Module): - Disable train mode and enable eval mode - Disable hf kernels which could bypass hooks """ - with torch.no_grad(), disable_cache(model), eval_context(model), disable_hf_kernels( - model - ): + with contextlib.ExitStack() as stack: + stack.enter_context(torch.no_grad()) + stack.enter_context(disable_cache(model)) + stack.enter_context(eval_context(model)) + stack.enter_context(disable_hf_kernels(model)) yield +@contextlib.contextmanager +def disable_lm_head(model: torch.nn.Module): + """ + Disable the lm_head of a model by moving it to the meta device. This function + does not untie parameters and restores the model proper loading upon exit + """ + _, lm_head = get_embeddings(model) + if lm_head is not None: + if has_offloaded_params(lm_head): + # keep weight on meta device + with patch_attr(lm_head._hf_hook, "offload", False): + yield + else: + with patch_attr(lm_head, "weight", lm_head.weight.to("meta")): + yield + + else: + logger.warning( + f"Attempted to disable lm_head of instance {model.__class__.__name__}, " + "but was unable to to find lm_head. 
This may lead to unexpected OOM." + ) + yield + + +def requires_lm_head_calibration( + model: PreTrainedModel, modifiers: Iterable["Modifier"] +) -> bool: + """Returns True if any of the quantization modifers target the lm_head""" + from llmcompressor.modifiers.quantization.quantization.mixin import ( + QuantizationMixin, + ) + + targets = set() + for mod in modifiers: + if isinstance(mod, QuantizationMixin): + targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore)) + + return targets_embeddings(model, targets, check_input=True, check_output=False) + + @contextlib.contextmanager def patch_attr(base: object, attr: str, value: Any): """ From 6690012a2519cdc4eed129c3f07cd0deee8bfa05 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 20:14:59 +0000 Subject: [PATCH 08/15] fix typo Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index 4cd5a5a4e3..d6369d74fb 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1103,7 +1103,7 @@ def requires_lm_head_calibration( if isinstance(mod, QuantizationMixin): targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore)) - return targets_embeddings(model, targets, check_input=True, check_output=False) + return targets_embeddings(model, targets, check_input=False, check_output=True) @contextlib.contextmanager From 73fc7f17ffee77f1ecadf001b0eae60728438283 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 23:36:37 +0000 Subject: [PATCH 09/15] always disable Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/basic/pipeline.py | 11 +---- .../pipelines/sequential/pipeline.py | 4 -- src/llmcompressor/utils/helpers.py | 43 +++++++------------ tests/llmcompressor/utils/test_helpers.py | 41 +++++++++++++++--- 4 files changed, 51 insertions(+), 48 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 0a877880ff..90f8914da0 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -6,15 +6,13 @@ from compressed_tensors.utils import get_execution_device from torch.utils.data.dataloader import DataLoader -from llmcompressor.core import LifecycleCallbacks, active_session +from llmcompressor.core import LifecycleCallbacks from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device from llmcompressor.utils import ( calibration_forward_context, - disable_lm_head, dispatch_for_generation, - requires_lm_head_calibration, ) if TYPE_CHECKING: @@ -43,9 +41,6 @@ def __call__( :param dataloader: loads data for calibration :param dataset_args: dataset arguments relevant to pipelines """ - session = active_session() - modifiers = session.lifecycle.recipe.modifiers - dispatch_for_generation(model) # basic dispatch is identical to generation model_device = get_execution_device(model) @@ -53,10 +48,6 @@ def __call__( with contextlib.ExitStack() as stack: stack.enter_context(calibration_forward_context(model)) - # Optionally disable lm_head - if not requires_lm_head_calibration(model, modifiers): - stack.enter_context(disable_lm_head(model)) - for batch in tqdm.tqdm(dataloader, desc="Calibrating"): batch = apply_pad_mask_to_batch(batch) batch = 
tensors_to_device(batch, model_device) diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 658f360311..8dcda749fd 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -19,8 +19,6 @@ DISABLE_QAC_MODIFIERS, DisableQuantization, calibration_forward_context, - disable_lm_head, - requires_lm_head_calibration, ) if TYPE_CHECKING: @@ -91,8 +89,6 @@ def __call__( if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) # Optionally disable lm_head - if not requires_lm_head_calibration(model, modifiers): - stack.enter_context(disable_lm_head(model)) # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index d6369d74fb..fdc2377175 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -23,15 +23,14 @@ import numpy import torch -from compressed_tensors import has_offloaded_params, match_named_modules from compressed_tensors.quantization import disable_quantization, enable_quantization from loguru import logger from transformers import PreTrainedModel -from llmcompressor.utils import get_embeddings, targets_embeddings +from llmcompressor.utils import get_embeddings if TYPE_CHECKING: - from llmcompressor.modifiers import Modifier + pass __all__ = [ "ALL_TOKEN", @@ -72,7 +71,6 @@ "eval_context", "calibration_forward_context", "disable_lm_head", - "requires_lm_head_calibration", "patch_attr", "disable_hf_kernels", "DISABLE_QAC_MODIFIERS", @@ -1057,12 +1055,14 @@ def calibration_forward_context(model: torch.nn.Module): - Disable the KV cache - Disable train mode and enable eval mode - Disable hf kernels which could bypass hooks + - Disable lm head (input and weights can still be calibrated, output will be meta) """ with contextlib.ExitStack() as stack: stack.enter_context(torch.no_grad()) stack.enter_context(disable_cache(model)) stack.enter_context(eval_context(model)) stack.enter_context(disable_hf_kernels(model)) + stack.enter_context(disable_lm_head(model)) yield @@ -1074,13 +1074,18 @@ def disable_lm_head(model: torch.nn.Module): """ _, lm_head = get_embeddings(model) if lm_head is not None: - if has_offloaded_params(lm_head): - # keep weight on meta device - with patch_attr(lm_head._hf_hook, "offload", False): - yield - else: - with patch_attr(lm_head, "weight", lm_head.weight.to("meta")): - yield + if not isinstance(lm_head, torch.nn.Linear): + raise NotImplementedError( + f"Cannot disable LM head of type {lm_head.__class__.__name__}" + ) + + dummy_weight = lm_head.weight.to("meta") + + def dummy_forward(self, input: torch.Tensor) -> torch.Tensor: + return input.to("meta") @ dummy_weight.T + + with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)): + yield else: logger.warning( @@ -1090,22 +1095,6 @@ def disable_lm_head(model: torch.nn.Module): yield -def requires_lm_head_calibration( - model: PreTrainedModel, modifiers: Iterable["Modifier"] -) -> bool: - """Returns True if any of the quantization modifers target the lm_head""" - from llmcompressor.modifiers.quantization.quantization.mixin import ( - QuantizationMixin, - ) - - targets = set() - for mod in modifiers: - if isinstance(mod, QuantizationMixin): - targets |= set(match_named_modules(model, mod.resolved_targets, mod.ignore)) - - return 
targets_embeddings(model, targets, check_input=False, check_output=True) - - @contextlib.contextmanager def patch_attr(base: object, attr: str, value: Any): """ diff --git a/tests/llmcompressor/utils/test_helpers.py b/tests/llmcompressor/utils/test_helpers.py index cb474c1f55..5e47bebe31 100644 --- a/tests/llmcompressor/utils/test_helpers.py +++ b/tests/llmcompressor/utils/test_helpers.py @@ -5,23 +5,23 @@ from transformers import ( AutoModelForCausalLM, MllamaForConditionalGeneration, - PretrainedConfig, - PreTrainedModel, ) +from llmcompressor.pipelines.sequential.helpers import dispatch_for_sequential from llmcompressor.utils import ( ALL_TOKEN, DisableQuantization, calibration_forward_context, convert_to_bool, disable_cache, + disable_lm_head, flatten_iterable, getattr_chain, interpolate, patch_attr, validate_str_iterable, ) -from llmcompressor.utils.dev import skip_weights_download +from llmcompressor.utils.dev import dispatch_for_generation, skip_weights_download from tests.testing_utils import requires_gpu @@ -149,10 +149,8 @@ def test_DisableQuantization(): @pytest.mark.unit def test_calibration_forward_context(): - class DummyModel(PreTrainedModel): - config_class = PretrainedConfig - - model = DummyModel(PretrainedConfig()) + with skip_weights_download(): + model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") model.config.use_cache = True model.train() @@ -160,9 +158,12 @@ class DummyModel(PreTrainedModel): assert not torch.is_grad_enabled() assert not model.config.use_cache assert not model.training + assert model.lm_head.forward.__name__ == "dummy_forward" + assert torch.is_grad_enabled() assert model.config.use_cache assert model.training + assert model.lm_head.forward.__name__ == "forward" @pytest.mark.unit @@ -203,3 +204,29 @@ def test_disable_cache(model_cls, model_stub): output = model(**inputs) assert output.past_key_values is not None + + +@requires_gpu +@pytest.mark.parametrize("offload", ["sequential", "basic", "none"]) +def test_disable_lm_head(offload): + model = AutoModelForCausalLM.from_pretrained("nm-testing/tinysmokellama-3.2") + if offload == "sequential": + dispatch_for_sequential(model) + if offload == "basic": + dispatch_for_generation(model) + if offload == "none": + model = model.to("cuda") + + lm_input_device = None + + def hook(module, args): + nonlocal lm_input_device + lm_input_device = args[0].device + + model.lm_head.register_forward_pre_hook(hook) + + with disable_lm_head(model): + input = {key: value.to("cuda") for key, value in model.dummy_inputs.items()} + output = model(**input) + assert lm_input_device == torch.device("cuda:0") + assert output.logits.device == torch.device("meta") From d4b3e7ae780d12ea15124fea79e23ab682e16bb0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 23:38:01 +0000 Subject: [PATCH 10/15] clean up dreggs Signed-off-by: Kyle Sayers --- src/llmcompressor/pipelines/basic/pipeline.py | 5 +---- src/llmcompressor/pipelines/sequential/pipeline.py | 1 - 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/src/llmcompressor/pipelines/basic/pipeline.py b/src/llmcompressor/pipelines/basic/pipeline.py index 90f8914da0..e6494fc5e0 100644 --- a/src/llmcompressor/pipelines/basic/pipeline.py +++ b/src/llmcompressor/pipelines/basic/pipeline.py @@ -10,10 +10,7 @@ from llmcompressor.modifiers.utils.pytorch_helpers import apply_pad_mask_to_batch from llmcompressor.pipelines.registry import CalibrationPipeline from llmcompressor.pytorch.utils.helpers import tensors_to_device -from 
llmcompressor.utils import ( - calibration_forward_context, - dispatch_for_generation, -) +from llmcompressor.utils import calibration_forward_context, dispatch_for_generation if TYPE_CHECKING: from llmcompressor.args.dataset_arguments import DatasetArguments diff --git a/src/llmcompressor/pipelines/sequential/pipeline.py b/src/llmcompressor/pipelines/sequential/pipeline.py index 8dcda749fd..511a693b95 100644 --- a/src/llmcompressor/pipelines/sequential/pipeline.py +++ b/src/llmcompressor/pipelines/sequential/pipeline.py @@ -88,7 +88,6 @@ def __call__( # Optionally disable quantization if not dataset_args.quantization_aware_calibration or disable_qac: stack.enter_context(DisableQuantization(model)) - # Optionally disable lm_head # prepare intermediates cache activations = IntermediatesCache.from_dataloader(dataloader, model_device) From 34814c700f6a448c94c085be34af0953fbf7fd6c Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Mon, 1 Dec 2025 23:47:25 +0000 Subject: [PATCH 11/15] clean up implementation Signed-off-by: Kyle Sayers --- src/llmcompressor/utils/helpers.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index fdc2377175..fddcafd829 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1074,11 +1074,19 @@ def disable_lm_head(model: torch.nn.Module): """ _, lm_head = get_embeddings(model) if lm_head is not None: - if not isinstance(lm_head, torch.nn.Linear): - raise NotImplementedError( - f"Cannot disable LM head of type {lm_head.__class__.__name__}" - ) + logger.warning( + f"Attempted to disable lm_head of instance {model.__class__.__name__}, " + "but was unable to to find lm_head. This may lead to unexpected OOM." + ) + yield + return + + elif not isinstance(lm_head, torch.nn.Linear): + logger.warning(f"Cannot disable LM head of type {lm_head.__class__.__name__}") + yield + return + else: dummy_weight = lm_head.weight.to("meta") def dummy_forward(self, input: torch.Tensor) -> torch.Tensor: @@ -1087,13 +1095,6 @@ def dummy_forward(self, input: torch.Tensor) -> torch.Tensor: with patch_attr(lm_head, "forward", dummy_forward.__get__(lm_head)): yield - else: - logger.warning( - f"Attempted to disable lm_head of instance {model.__class__.__name__}, " - "but was unable to to find lm_head. This may lead to unexpected OOM." 
- ) - yield - @contextlib.contextmanager def patch_attr(base: object, attr: str, value: Any): From 35a0507151126b855cc500a3ba9c177bb83469b0 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 2 Dec 2025 01:09:48 +0000 Subject: [PATCH 12/15] add batch size args Signed-off-by: Kyle Sayers --- examples/multimodal_vision/gemma3_example.py | 25 ++++++---- .../multimodal_vision/idefics3_example.py | 26 ++++++---- examples/quantization_w4a16/llama3_example.py | 8 ++-- src/llmcompressor/args/dataset_arguments.py | 26 +++++++--- src/llmcompressor/datasets/utils.py | 48 ++++++++++++------- src/llmcompressor/entrypoints/oneshot.py | 4 +- 6 files changed, 90 insertions(+), 47 deletions(-) diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py index dce35b7b83..6fe8be4db7 100644 --- a/examples/multimodal_vision/gemma3_example.py +++ b/examples/multimodal_vision/gemma3_example.py @@ -1,7 +1,10 @@ import requests -import torch from PIL import Image -from transformers import AutoProcessor, Gemma3ForConditionalGeneration +from transformers import ( + AutoProcessor, + DataCollatorWithPadding, + Gemma3ForConditionalGeneration, +) from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -11,18 +14,21 @@ model_id = "google/gemma-3-4b-it" model = Gemma3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) +collator = DataCollatorWithPadding(processor.tokenizer) # Oneshot arguments -DATASET_ID = "flickr30k" -DATASET_SPLIT = {"calibration": "test[:512]"} NUM_CALIBRATION_SAMPLES = 512 MAX_SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 512 +DATASET_ID = "flickr30k" +DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"} -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} +# Define a oneshot data collator for multimodal processors +# remove extra dim added by vision processor +def data_collator(features: list[dict[str, object]]): + features = [{key: feature[key][0] for key in feature} for feature in features] + return collator(features) # Recipe @@ -45,10 +51,11 @@ def data_collator(batch): dataset=DATASET_ID, splits=DATASET_SPLIT, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, num_calibration_samples=NUM_CALIBRATION_SAMPLES, - trust_remote_code_model=True, data_collator=data_collator, + trust_remote_code_model=True, ) # Confirm generations of the quantized model look sane. 
diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index 2fdaeb1a4a..b49678888f 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -1,8 +1,11 @@ import requests -import torch from datasets import load_dataset from PIL import Image -from transformers import AutoProcessor, Idefics3ForConditionalGeneration +from transformers import ( + AutoProcessor, + DataCollatorWithPadding, + Idefics3ForConditionalGeneration, +) from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -12,18 +15,21 @@ model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct" model = Idefics3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) +collator = DataCollatorWithPadding(processor.tokenizer) # Oneshot arguments -DATASET_ID = "lmms-lab/flickr30k" -DATASET_SPLIT = "test[:512]" NUM_CALIBRATION_SAMPLES = 512 -MAX_SEQUENCE_LENGTH = 4096 # Seems to be required here +MAX_SEQUENCE_LENGTH = 4096 +BATCH_SIZE = 512 +DATASET_ID = "lmms-lab/flickr30k" +DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]" -# Define a oneshot data collator for multimodal inputs. -def data_collator(batch): - assert len(batch) == 1 - return {key: torch.tensor(value) for key, value in batch[0].items()} +# Define a oneshot data collator for multimodal processors +# remove extra dim added by vision processor +def data_collator(features: list[dict[str, object]]): + features = [{key: feature[key][0] for key in feature} for feature in features] + return collator(features) # Recipe @@ -86,8 +92,8 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, trust_remote_code_model=True, data_collator=data_collator, sequential_targets=["LlamaDecoderLayer"], diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index b03aacee35..239e8e15df 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,7 +6,8 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. -model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +model_id = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -16,8 +17,9 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 512 +NUM_CALIBRATION_SAMPLES = 16 MAX_SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 4 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") @@ -58,8 +60,8 @@ def tokenize(sample): model=model, dataset=ds, recipe=recipe, + batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, ) # Confirm generations of the quantized model look sane. 
diff --git a/src/llmcompressor/args/dataset_arguments.py b/src/llmcompressor/args/dataset_arguments.py
index 2618b90197..b66fcfa275 100644
--- a/src/llmcompressor/args/dataset_arguments.py
+++ b/src/llmcompressor/args/dataset_arguments.py
@@ -8,9 +8,7 @@
 """

 from dataclasses import dataclass, field
-from typing import Any, Callable
-
-from transformers import DefaultDataCollator
+from typing import Callable, Optional


 @dataclass
@@ -69,9 +67,25 @@ class CustomDatasetArguments(DVCDatasetArguments):
         },
     )

-    data_collator: Callable[[Any], Any] = field(
-        default_factory=lambda: DefaultDataCollator(),
-        metadata={"help": "The function to used to form a batch from the dataset"},
+    data_collator: Optional[Callable] = field(
+        default=None,
+        metadata={
+            "help": (
+                "The function used to form a batch from the dataset. Defaults to "
+                "`DataCollatorWithPadding(processor)`."
+            )
+        },
+    )
+
+    batch_size: int = field(
+        default=1,
+        metadata={
+            "help": (
+                "Calibration batch size. During calibration, LLM Compressor disables "
+                "lm_head output computations to reduce memory usage from large "
+                "calibration batches."
+            )
+        },
     )
diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py
index 2b80b1ed9a..5bd1138a63 100644
--- a/src/llmcompressor/datasets/utils.py
+++ b/src/llmcompressor/datasets/utils.py
@@ -7,6 +7,7 @@
 one-shot calibration workflows.
 """

+import math
 import multiprocessing
 import re
 from typing import Any, Callable
@@ -15,7 +16,7 @@
 from datasets import Dataset
 from loguru import logger
 from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
-from transformers.data import default_data_collator
+from transformers.data import DataCollatorWithPadding

 from llmcompressor.args import DatasetArguments
 from llmcompressor.transformers.data import TextGenerationDataset
@@ -115,44 +116,56 @@ def get_calibration_dataloader(
     )

     calibration_dataset = datasets.get("calibration")
+    tokenizer = getattr(processor, "tokenizer", processor)
+    collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer)
+    if dataset_args.batch_size > 1 and (
+        tokenizer.pad_token is None or tokenizer.pad_token_id < 0
+    ):
+        logger.warning("Could not find padding token. 
Setting PAD token to EOS token") + tokenizer.pad_token = tokenizer.eos_token return format_calibration_data( tokenized_dataset=calibration_dataset, + collate_fn=collate_fn, + batch_size=dataset_args.batch_size, num_calibration_samples=dataset_args.num_calibration_samples, do_shuffle=dataset_args.shuffle_calibration_samples, - collate_fn=dataset_args.data_collator, ) def format_calibration_data( tokenized_dataset: Dataset, + collate_fn: Callable, + batch_size: int = 1, num_calibration_samples: int | None = None, do_shuffle: bool = True, - collate_fn: Callable = default_data_collator, ) -> list[torch.Tensor]: """ Creates a dataloader out of the calibration dataset split, trimming it to the desired number of calibration samples :param tokenized_dataset: dataset to convert to dataloader - :param num_calibration_samples: number of data samples to convert + :param num_calibration_samples: number of batches to convert :param do_shuffle: whether to shuffle the dataset before selecting calibration samples, true by default :param collate_fn: optional custom collate function, or use default :return: list of trimmed calibration data tensors """ - safe_calibration_samples = len(tokenized_dataset) + # (1) shuffle dataset + if do_shuffle: + tokenized_dataset = tokenized_dataset.shuffle() + + # (2) truncate dataset if num_calibration_samples is not None: - safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) - if safe_calibration_samples != num_calibration_samples: + num_batches = math.ceil(num_calibration_samples / batch_size) + if num_batches > len(tokenized_dataset): logger.warning( - f"Requested {num_calibration_samples} calibration samples but " - f"the provided dataset only has {safe_calibration_samples}. " + f"Requested {num_calibration_samples} calibration samples but the " + f"provided dataset only has {len(tokenized_dataset) * batch_size}. " ) + num_batches = len(tokenized_dataset) + tokenized_calibration = tokenized_dataset.select(num_batches) - if do_shuffle: - tokenized_dataset = tokenized_dataset.shuffle() - tokenized_calibration = tokenized_dataset.select(range(safe_calibration_samples)) - + # (3) infer number of workers MAX_DATALOADER_WORKERS = 8 try: num_workers = min(MAX_DATALOADER_WORKERS, multiprocessing.cpu_count() // 2) @@ -161,8 +174,10 @@ def format_calibration_data( "Could not determine number of CPUs, defaulting to 0 dataloader workers." 
) num_workers = 0 + + # (4) create dataloader dataloader_params = { - "batch_size": 1, + "batch_size": batch_size, "sampler": RandomSampler(tokenized_calibration) if do_shuffle else SequentialSampler(tokenized_calibration), @@ -170,10 +185,7 @@ def format_calibration_data( "pin_memory": True, "num_workers": num_workers, } - - calibration_dataloader = DataLoader(tokenized_calibration, **dataloader_params) - - return calibration_dataloader + return DataLoader(tokenized_calibration, **dataloader_params) def make_dataset_splits( diff --git a/src/llmcompressor/entrypoints/oneshot.py b/src/llmcompressor/entrypoints/oneshot.py index c2b29aa97c..fba7dc6688 100644 --- a/src/llmcompressor/entrypoints/oneshot.py +++ b/src/llmcompressor/entrypoints/oneshot.py @@ -12,7 +12,7 @@ import os from datetime import datetime from pathlib import Path -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable, Optional from loguru import logger from torch.utils.data import DataLoader @@ -248,6 +248,8 @@ def oneshot( dataset_config_name: str | None = None, dataset_path: str | None = None, splits: str | list[str] | dict[str, str] | None = None, + batch_size: int = 1, + data_collator: Optional[Callable] = None, num_calibration_samples: int = 512, shuffle_calibration_samples: bool = True, max_seq_length: int = 384, From 5fe709dafc75d67c6214f3565e1a8b3c108d2112 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 2 Dec 2025 01:33:55 +0000 Subject: [PATCH 13/15] fix typos Signed-off-by: Kyle Sayers --- examples/multimodal_vision/gemma3_example.py | 20 +++++++++---------- .../multimodal_vision/idefics3_example.py | 20 +++++-------------- src/llmcompressor/datasets/utils.py | 8 ++++---- src/llmcompressor/utils/helpers.py | 2 +- 4 files changed, 20 insertions(+), 30 deletions(-) diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py index 6fe8be4db7..509f48ad58 100644 --- a/examples/multimodal_vision/gemma3_example.py +++ b/examples/multimodal_vision/gemma3_example.py @@ -7,7 +7,6 @@ ) from llmcompressor import oneshot -from llmcompressor.modifiers.quantization import GPTQModifier from llmcompressor.utils import dispatch_for_generation # Load model. @@ -33,15 +32,15 @@ def data_collator(features: list[dict[str, object]]): # Recipe recipe = [ - GPTQModifier( - targets="Linear", - scheme="W4A16", - ignore=[ - "lm_head", - r"re:model\.vision_tower.*", - r"re:model\.multi_modal_projector.*", - ], - ), + # GPTQModifier( + # targets="Linear", + # scheme="W4A16", + # ignore=[ + # "lm_head", + # r"re:model\.vision_tower.*", + # r"re:model\.multi_modal_projector.*", + # ], + # ), ] # Perform oneshot @@ -56,6 +55,7 @@ def data_collator(features: list[dict[str, object]]): num_calibration_samples=NUM_CALIBRATION_SAMPLES, data_collator=data_collator, trust_remote_code_model=True, + pipeline="sequential", ) # Confirm generations of the quantized model look sane. 
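Taken together, the dataloader changes above have get_calibration_dataloader resolve a collator (DataCollatorWithPadding over the processor's tokenizer unless one is supplied), fall back to the EOS token when no pad token exists, and have format_calibration_data build a batched DataLoader whose sampler depends on do_shuffle. The following is a standalone sketch of that flow as it stands at this point in the series, using a plain Python list in place of a Hugging Face Dataset; DummyTokenizer, toy_collate, and build_calibration_dataloader are illustrative names, not code from the patches.

import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


class DummyTokenizer:
    # stand-in for a tokenizer that ships without a pad token (Llama-style)
    eos_token = "</s>"
    pad_token = None
    pad_token_id = -1


def ensure_pad_token(tokenizer):
    # mirrors the EOS fallback applied before batched collation
    # (a real tokenizer keeps pad_token_id in sync when pad_token is assigned)
    if tokenizer.pad_token is None or tokenizer.pad_token_id < 0:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


def toy_collate(features):
    # fixed-length toy features, so plain stacking is enough here
    return {"input_ids": torch.stack([torch.as_tensor(f["input_ids"]) for f in features])}


def build_calibration_dataloader(samples, batch_size=4, do_shuffle=True):
    sampler = RandomSampler(samples) if do_shuffle else SequentialSampler(samples)
    return DataLoader(
        samples,
        batch_size=batch_size,
        sampler=sampler,
        collate_fn=toy_collate,
        pin_memory=True,
    )


tokenizer = ensure_pad_token(DummyTokenizer())
data = [{"input_ids": [i, i + 1]} for i in range(16)]
loader = build_calibration_dataloader(data)
print(tokenizer.pad_token, len(loader))  # </s> 4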
diff --git a/examples/multimodal_vision/idefics3_example.py b/examples/multimodal_vision/idefics3_example.py index b49678888f..a3c75722d6 100644 --- a/examples/multimodal_vision/idefics3_example.py +++ b/examples/multimodal_vision/idefics3_example.py @@ -1,11 +1,7 @@ import requests from datasets import load_dataset from PIL import Image -from transformers import ( - AutoProcessor, - DataCollatorWithPadding, - Idefics3ForConditionalGeneration, -) +from transformers import AutoProcessor, Idefics3ForConditionalGeneration from llmcompressor import oneshot from llmcompressor.modifiers.quantization import GPTQModifier @@ -15,7 +11,6 @@ model_id = "HuggingFaceM4/Idefics3-8B-Llama3" # or "HuggingFaceTB/SmolVLM-Instruct" model = Idefics3ForConditionalGeneration.from_pretrained(model_id, torch_dtype="auto") processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) -collator = DataCollatorWithPadding(processor.tokenizer) # Oneshot arguments NUM_CALIBRATION_SAMPLES = 512 @@ -25,13 +20,6 @@ DATASET_SPLIT = f"test[:{NUM_CALIBRATION_SAMPLES}]" -# Define a oneshot data collator for multimodal processors -# remove extra dim added by vision processor -def data_collator(features: list[dict[str, object]]): - features = [{key: feature[key][0] for key in feature} for feature in features] - return collator(features) - - # Recipe recipe = [ GPTQModifier( @@ -75,7 +63,7 @@ def preprocess(example): # Tokenize inputs. def tokenize(sample): - return processor( + features = processor( text=sample["text"], images=sample["images"], padding=False, @@ -83,6 +71,9 @@ def tokenize(sample): truncation=True, ) + # remove extra dim added by vision processor + return [{key: feature[key][0] for key in feature} for feature in features] + # avoid errors with writer_batch_size ds = ds.map(tokenize, writer_batch_size=1, remove_columns=ds.column_names) @@ -95,7 +86,6 @@ def tokenize(sample): batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, trust_remote_code_model=True, - data_collator=data_collator, sequential_targets=["LlamaDecoderLayer"], ) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 5bd1138a63..879e857d3c 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -163,7 +163,7 @@ def format_calibration_data( f"provided dataset only has {len(tokenized_dataset) * batch_size}. 
" ) num_batches = len(tokenized_dataset) - tokenized_calibration = tokenized_dataset.select(num_batches) + tokenized_dataset = tokenized_dataset.select(range(num_batches)) # (3) infer number of workers MAX_DATALOADER_WORKERS = 8 @@ -178,14 +178,14 @@ def format_calibration_data( # (4) create dataloader dataloader_params = { "batch_size": batch_size, - "sampler": RandomSampler(tokenized_calibration) + "sampler": RandomSampler(tokenized_dataset) if do_shuffle - else SequentialSampler(tokenized_calibration), + else SequentialSampler(tokenized_dataset), "collate_fn": collate_fn, "pin_memory": True, "num_workers": num_workers, } - return DataLoader(tokenized_calibration, **dataloader_params) + return DataLoader(tokenized_dataset, **dataloader_params) def make_dataset_splits( diff --git a/src/llmcompressor/utils/helpers.py b/src/llmcompressor/utils/helpers.py index fddcafd829..ecd63d4239 100644 --- a/src/llmcompressor/utils/helpers.py +++ b/src/llmcompressor/utils/helpers.py @@ -1073,7 +1073,7 @@ def disable_lm_head(model: torch.nn.Module): does not untie parameters and restores the model proper loading upon exit """ _, lm_head = get_embeddings(model) - if lm_head is not None: + if lm_head is None: logger.warning( f"Attempted to disable lm_head of instance {model.__class__.__name__}, " "but was unable to to find lm_head. This may lead to unexpected OOM." From 5bc0166c26215c990c57b06cafd01a3cac49ec02 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 2 Dec 2025 03:01:29 +0000 Subject: [PATCH 14/15] clean up format_calibration_data Signed-off-by: Kyle Sayers --- examples/multimodal_vision/gemma3_example.py | 35 +++++++++++-------- examples/quantization_w4a16/llama3_example.py | 11 +++--- src/llmcompressor/datasets/utils.py | 19 ++++------ 3 files changed, 35 insertions(+), 30 deletions(-) diff --git a/examples/multimodal_vision/gemma3_example.py b/examples/multimodal_vision/gemma3_example.py index 509f48ad58..7f9b7cfabe 100644 --- a/examples/multimodal_vision/gemma3_example.py +++ b/examples/multimodal_vision/gemma3_example.py @@ -43,20 +43,27 @@ def data_collator(features: list[dict[str, object]]): # ), ] -# Perform oneshot -oneshot( - model=model, - tokenizer=model_id, - dataset=DATASET_ID, - splits=DATASET_SPLIT, - recipe=recipe, - batch_size=BATCH_SIZE, - max_seq_length=MAX_SEQUENCE_LENGTH, - num_calibration_samples=NUM_CALIBRATION_SAMPLES, - data_collator=data_collator, - trust_remote_code_model=True, - pipeline="sequential", -) +from pttp import TensorProfiler + +with TensorProfiler() as prof: + # Perform oneshot + oneshot( + model=model, + tokenizer=model_id, + dataset=DATASET_ID, + splits=DATASET_SPLIT, + recipe=recipe, + batch_size=BATCH_SIZE, + max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + data_collator=data_collator, + trust_remote_code_model=True, + pipeline="sequential", + ) +import torch +del prof._memory.timeline[torch.device("cpu")] +prof.save_memory_timeline("with_disable.png") +exit(0) # Confirm generations of the quantized model look sane. print("========== SAMPLE GENERATION ==============") diff --git a/examples/quantization_w4a16/llama3_example.py b/examples/quantization_w4a16/llama3_example.py index 239e8e15df..e970d41eb8 100644 --- a/examples/quantization_w4a16/llama3_example.py +++ b/examples/quantization_w4a16/llama3_example.py @@ -6,8 +6,8 @@ from llmcompressor.utils import dispatch_for_generation # Select model and load it. 
-# model_id = "meta-llama/Meta-Llama-3-8B-Instruct" -model_id = "meta-llama/Llama-3.2-1B-Instruct" +model_id = "meta-llama/Meta-Llama-3-8B-Instruct" +#model_id = "meta-llama/Llama-3.2-1B-Instruct" model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto") tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -17,9 +17,9 @@ # Select number of samples. 512 samples is a good place to start. # Increasing the number of samples can improve accuracy. -NUM_CALIBRATION_SAMPLES = 16 +NUM_CALIBRATION_SAMPLES = 32 MAX_SEQUENCE_LENGTH = 2048 -BATCH_SIZE = 4 +BATCH_SIZE = 16 # Load dataset and preprocess. ds = load_dataset(DATASET_ID, split=f"{DATASET_SPLIT}[:{NUM_CALIBRATION_SAMPLES}]") @@ -62,7 +62,10 @@ def tokenize(sample): recipe=recipe, batch_size=BATCH_SIZE, max_seq_length=MAX_SEQUENCE_LENGTH, + num_calibration_samples=NUM_CALIBRATION_SAMPLES, + pipeline="sequential", ) +exit(0) # Confirm generations of the quantized model look sane. print("\n\n") diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index 879e857d3c..c2fa96ffde 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -15,7 +15,7 @@ import torch from datasets import Dataset from loguru import logger -from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, SequentialSampler from transformers.data import DataCollatorWithPadding from llmcompressor.args import DatasetArguments @@ -118,9 +118,7 @@ def get_calibration_dataloader( calibration_dataset = datasets.get("calibration") tokenizer = getattr(processor, "tokenizer", processor) collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer) - if dataset_args.batch_size > 1 and ( - tokenizer.pad_token is None or tokenizer.pad_token_id < 0 - ): + if tokenizer.pad_token is None or tokenizer.pad_token_id < 0: logger.warning("Could not find padding token. Setting PAD token to EOS token") tokenizer.pad_token = tokenizer.eos_token @@ -156,14 +154,13 @@ def format_calibration_data( # (2) truncate dataset if num_calibration_samples is not None: - num_batches = math.ceil(num_calibration_samples / batch_size) - if num_batches > len(tokenized_dataset): + if num_calibration_samples > len(tokenized_dataset): logger.warning( f"Requested {num_calibration_samples} calibration samples but the " - f"provided dataset only has {len(tokenized_dataset) * batch_size}. " + f"provided dataset only has {len(tokenized_dataset)} samples." 
) - num_batches = len(tokenized_dataset) - tokenized_dataset = tokenized_dataset.select(range(num_batches)) + num_calibration_samples = len(tokenized_dataset) + tokenized_dataset = tokenized_dataset.select(range(num_calibration_samples)) # (3) infer number of workers MAX_DATALOADER_WORKERS = 8 @@ -178,9 +175,7 @@ def format_calibration_data( # (4) create dataloader dataloader_params = { "batch_size": batch_size, - "sampler": RandomSampler(tokenized_dataset) - if do_shuffle - else SequentialSampler(tokenized_dataset), + "sampler": SequentialSampler(tokenized_dataset), "collate_fn": collate_fn, "pin_memory": True, "num_workers": num_workers, From 44d4b468cc27c0991f515efbaf05bae2df2f9612 Mon Sep 17 00:00:00 2001 From: Kyle Sayers Date: Tue, 2 Dec 2025 03:44:57 +0000 Subject: [PATCH 15/15] debug message instead Signed-off-by: Kyle Sayers --- src/llmcompressor/datasets/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llmcompressor/datasets/utils.py b/src/llmcompressor/datasets/utils.py index c2fa96ffde..dc11088cd4 100644 --- a/src/llmcompressor/datasets/utils.py +++ b/src/llmcompressor/datasets/utils.py @@ -119,7 +119,7 @@ def get_calibration_dataloader( tokenizer = getattr(processor, "tokenizer", processor) collate_fn = dataset_args.data_collator or DataCollatorWithPadding(tokenizer) if tokenizer.pad_token is None or tokenizer.pad_token_id < 0: - logger.warning("Could not find padding token. Setting PAD token to EOS token") + logger.debug("Could not find padding token. Setting PAD token to EOS token") tokenizer.pad_token = tokenizer.eos_token return format_calibration_data( @@ -177,7 +177,7 @@ def format_calibration_data( "batch_size": batch_size, "sampler": SequentialSampler(tokenized_dataset), "collate_fn": collate_fn, - "pin_memory": True, + "pin_memory": False, "num_workers": num_workers, } return DataLoader(tokenized_dataset, **dataloader_params)