From ef2c2c499e9b559f3cc8601d77bac7c3518c9266 Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Tue, 27 Jan 2026 07:46:52 +0000 Subject: [PATCH 1/6] support kimi k2.5 ptq Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 115 ++++++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index e213a7f25..86297636a 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -279,6 +279,94 @@ def get_dtype(dtype): return dtype +def _patch_compressed_linear_init(): + """Patch CompressedLinear to prevent transformers weight initialization errors. + + When loading pack-quantized models, CompressedLinear modules don't have a 'weight' + attribute (they have weight_packed instead). Transformers tries to initialize + missing weights which fails. This patch adds a dummy weight property. + """ + try: + from compressed_tensors.linear.compressed_linear import CompressedLinear + except ImportError: + return # compressed_tensors not installed + + if hasattr(CompressedLinear, "_modelopt_init_patched"): + return # Already patched + + # Patch __getattr__ to return dummy for weight access + original_getattr = getattr(CompressedLinear, "__getattr__", None) + + def patched_getattr(self, name): + if name == "weight": + # Check if real weight exists + if "_parameters" in self.__dict__ and "weight" in self._parameters: + return self._parameters["weight"] + if "weight" in self.__dict__: + return self.__dict__["weight"] + # Return dummy weight for initialization purposes + if hasattr(self, "weight_packed"): + if "_dummy_weight" not in self.__dict__: + object.__setattr__( + self, + "_dummy_weight", + torch.nn.Parameter( + torch.empty(0, device=self.weight_packed.device), + requires_grad=False, + ), + ) + return self._dummy_weight + if original_getattr is not None: + return original_getattr(self, name) + raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") + + CompressedLinear.__getattr__ = patched_getattr + CompressedLinear._modelopt_init_patched = True + print("Patched CompressedLinear for transformers compatibility") + + +def _unpack_compressed_linear_weights(model): + """Unpack all CompressedLinear weights from INT4 to BF16. + + This decompresses the packed INT4 weights to BF16 format so they can be + processed by Model-Optimizer's PTQ flow. This must be called after model + loading but before quantization/calibration. 
+ """ + try: + from compressed_tensors.linear.compressed_linear import CompressedLinear + except ImportError: + return # compressed_tensors not installed, nothing to unpack + + unpacked_count = 0 + for name, module in model.named_modules(): + if isinstance(module, CompressedLinear) and hasattr(module, "weight_packed"): + # Decompress the INT4 weights to BF16 + with torch.no_grad(): + # Fix weight_shape if it's a tensor (should be a list of ints) + if hasattr(module, "weight_shape") and isinstance(module.weight_shape, torch.Tensor): + module.weight_shape = module.weight_shape.tolist() + + # Get the decompressed weight from the compressor + decompressed_weight = module.compressor.decompress_module(module) + # Register as a proper parameter (this replaces any dummy weight) + module.weight = torch.nn.Parameter(decompressed_weight, requires_grad=False) + # Clean up compressed attributes + if hasattr(module, "weight_packed"): + delattr(module, "weight_packed") + if hasattr(module, "weight_scale"): + delattr(module, "weight_scale") + if hasattr(module, "_dummy_weight"): + delattr(module, "_dummy_weight") + # Update quantization status to not compressed + from compressed_tensors.quantization import QuantizationStatus + + module.quantization_status = QuantizationStatus.FROZEN + unpacked_count += 1 + + if unpacked_count > 0: + print(f"Unpacked {unpacked_count} CompressedLinear modules from INT4 to BF16") + + def get_model( ckpt_path, device="cuda", @@ -289,6 +377,8 @@ def get_model( ): print(f"Initializing model from {ckpt_path}") + # Note: CompressedLinear weights will be unpacked after model loading + device_map = "auto" if device == "cpu": device_map = "cpu" @@ -345,16 +435,29 @@ def get_model( # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors device_map = None + # Helper function to check if model has pack-quantized config + def has_pack_quantized_config(config): + # Check top-level quantization_config + if hasattr(config, "quantization_config"): + if config.quantization_config.get("format", None) == "pack-quantized": + return True + # Check nested text_config.quantization_config (for multi-modal models like kimi k2.5) + if hasattr(config, "text_config") and hasattr( + config.text_config, "quantization_config" + ): + if config.text_config.quantization_config.get("format", None) == "pack-quantized": + return True + return False + if is_speculative(hf_config): model = AutoModelForCausalLM.from_pretrained( ckpt_path, device_map=device_map, **model_kwargs, ) - elif ( - hasattr(hf_config, "quantization_config") - and hf_config.quantization_config.get("format", None) == "pack-quantized" - ): + elif has_pack_quantized_config(hf_config): + # Patch CompressedLinear before loading to handle missing weight attribute + _patch_compressed_linear_init() torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) model = AutoModelForCausalLM.from_pretrained( ckpt_path, @@ -417,6 +520,10 @@ def get_model( ) model.eval() + # Unpack CompressedLinear weights (INT4 -> BF16) if present + # This is needed for models with compressed-tensors quantization (e.g., kimi k2.5) + _unpack_compressed_linear_weights(model) + # If device_map was disabled (None), manually move model to target device if device_map is None and device != "cpu": print(f"Moving model to {device} device...") From 1fa9847eb57324bbff443729a238c48983e590c5 Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Tue, 27 Jan 2026 18:06:51 +0000 Subject: [PATCH 2/6] update Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 
9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 86297636a..7b52ef091 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -343,8 +343,13 @@ def _unpack_compressed_linear_weights(model): # Decompress the INT4 weights to BF16 with torch.no_grad(): # Fix weight_shape if it's a tensor (should be a list of ints) - if hasattr(module, "weight_shape") and isinstance(module.weight_shape, torch.Tensor): - module.weight_shape = module.weight_shape.tolist() + if hasattr(module, "weight_shape") and isinstance( + module.weight_shape, torch.Tensor + ): + shape_list = module.weight_shape.tolist() + # Remove the parameter and set as regular attribute + del module._parameters["weight_shape"] + module.weight_shape = shape_list # Get the decompressed weight from the compressor decompressed_weight = module.compressor.decompress_module(module) From 338b22b79d542f220bea91e47adbb84d22ff79b1 Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Wed, 28 Jan 2026 04:07:32 +0000 Subject: [PATCH 3/6] updates, fix weight_shape tensor issue Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 180 ++++++++++++++---- .../torch/quantization/plugins/huggingface.py | 47 ++++- modelopt/torch/utils/dataset_utils.py | 2 +- 3 files changed, 192 insertions(+), 37 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 7b52ef091..3a1a11366 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -297,6 +297,22 @@ def _patch_compressed_linear_init(): # Patch __getattr__ to return dummy for weight access original_getattr = getattr(CompressedLinear, "__getattr__", None) + class DummyWeightData: + """Dummy tensor data that accepts initialization calls like .normal_(), .zero_().""" + + def __getattr__(self, name): + # Return self for any method call to allow chaining + return lambda *args, **kwargs: self + + class DummyWeight: + """Dummy weight with .data that accepts any initialization.""" + + def __init__(self): + self.data = DummyWeightData() + + def __getattr__(self, name): + return lambda *args, **kwargs: self + def patched_getattr(self, name): if name == "weight": # Check if real weight exists @@ -304,18 +320,8 @@ def patched_getattr(self, name): return self._parameters["weight"] if "weight" in self.__dict__: return self.__dict__["weight"] - # Return dummy weight for initialization purposes - if hasattr(self, "weight_packed"): - if "_dummy_weight" not in self.__dict__: - object.__setattr__( - self, - "_dummy_weight", - torch.nn.Parameter( - torch.empty(0, device=self.weight_packed.device), - requires_grad=False, - ), - ) - return self._dummy_weight + # Return dummy weight for initialization purposes (don't store it) + return DummyWeight() if original_getattr is not None: return original_getattr(self, name) raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'") @@ -325,44 +331,149 @@ def patched_getattr(self, name): print("Patched CompressedLinear for transformers compatibility") -def _unpack_compressed_linear_weights(model): +def _unpack_compressed_linear_weights(model, ckpt_path=None): """Unpack all CompressedLinear weights from INT4 to BF16. This decompresses the packed INT4 weights to BF16 format so they can be processed by Model-Optimizer's PTQ flow. This must be called after model loading but before quantization/calibration. 
+ + Args: + model: The loaded model with CompressedLinear modules + ckpt_path: Path to the checkpoint to reload int32 weights from safetensors """ try: from compressed_tensors.linear.compressed_linear import CompressedLinear except ImportError: return # compressed_tensors not installed, nothing to unpack + # Get checkpoint path from model config if not provided + if ckpt_path is None: + if hasattr(model, "config") and hasattr(model.config, "_name_or_path"): + ckpt_path = model.config._name_or_path + + # Load original int32 weights from safetensors + original_weights = {} + if ckpt_path: + import json + import os + + from safetensors import safe_open + + # Find safetensors files + index_path = os.path.join(ckpt_path, "model.safetensors.index.json") + single_path = os.path.join(ckpt_path, "model.safetensors") + + safetensor_files = [] + if os.path.exists(index_path): + with open(index_path) as f: + index = json.load(f) + safetensor_files = list(set(index.get("weight_map", {}).values())) + safetensor_files = [os.path.join(ckpt_path, f) for f in safetensor_files] + elif os.path.exists(single_path): + safetensor_files = [single_path] + + # Load int32 weights (weight_packed, weight_shape) + for sf_path in safetensor_files: + with safe_open(sf_path, framework="pt") as f: + for key in f.keys(): + if "weight_packed" in key or "weight_shape" in key: + original_weights[key] = f.get_tensor(key) + + print(f"Loaded {len(original_weights)} packed weight tensors from safetensors") + + def _lookup_original(key: str): + if not original_weights: + return None + if key in original_weights: + return original_weights[key] + # Fallback: match by suffix when prefixes differ + matches = [k for k in original_weights if k.endswith(key)] + if len(matches) == 1: + return original_weights[matches[0]] + return None + unpacked_count = 0 for name, module in model.named_modules(): if isinstance(module, CompressedLinear) and hasattr(module, "weight_packed"): - # Decompress the INT4 weights to BF16 with torch.no_grad(): - # Fix weight_shape if it's a tensor (should be a list of ints) - if hasattr(module, "weight_shape") and isinstance( - module.weight_shape, torch.Tensor - ): - shape_list = module.weight_shape.tolist() - # Remove the parameter and set as regular attribute - del module._parameters["weight_shape"] - module.weight_shape = shape_list - - # Get the decompressed weight from the compressor - decompressed_weight = module.compressor.decompress_module(module) - # Register as a proper parameter (this replaces any dummy weight) - module.weight = torch.nn.Parameter(decompressed_weight, requires_grad=False) - # Clean up compressed attributes + # Build compressed_data dict + compressed_data = {} + + # Use original int32 weights from safetensors if available + packed_key = f"{name}.weight_packed" + shape_key = f"{name}.weight_shape" + + packed_tensor = _lookup_original(packed_key) + if packed_tensor is None: + packed_tensor = module.weight_packed + elif isinstance(packed_tensor, torch.Tensor): + packed_tensor = packed_tensor.to(module.weight_packed.device) + compressed_data["weight_packed"] = packed_tensor + + if isinstance(compressed_data["weight_packed"], torch.Tensor): + if compressed_data["weight_packed"].dtype != torch.int32: + # Some modules (e.g., vision tower) are not pack-quantized. + # Skip decompression if packed weights are not int32. 
+ continue + + shape_tensor = _lookup_original(shape_key) + if shape_tensor is not None: + if isinstance(shape_tensor, torch.Tensor): + shape_tensor = shape_tensor.to(module.weight_packed.device) + compressed_data["weight_shape"] = [int(x) for x in shape_tensor.tolist()] + elif hasattr(module, "weight_shape"): + ws = module.weight_shape + if isinstance(ws, torch.Tensor): + compressed_data["weight_shape"] = [int(x) for x in ws.tolist()] + elif isinstance(ws, (list, tuple)): + compressed_data["weight_shape"] = [int(x) for x in ws] + else: + compressed_data["weight_shape"] = ws + + # Get weight_scale (this one is float, so conversion is OK) + if hasattr(module, "weight_scale"): + compressed_data["weight_scale"] = module.weight_scale.to( + module.weight_packed.device + ) + + if hasattr(module, "weight_zero_point"): + compressed_data["weight_zero_point"] = module.weight_zero_point.to( + module.weight_packed.device + ) + + # Get quantization args + quant_args = None + if hasattr(module, "quantization_scheme") and module.quantization_scheme: + if hasattr(module.quantization_scheme, "weights"): + quant_args = module.quantization_scheme.weights + + # Decompress + decompressed_weight = module.compressor.decompress_weight( + compressed_data=compressed_data, + quantization_args=quant_args, + ) + + # Register as parameter (avoid register_parameter to bypass __getattr__ hook) + module._parameters.pop("weight", None) + module._buffers.pop("weight", None) + if "weight" in module.__dict__: + del module.__dict__["weight"] + param = torch.nn.Parameter(decompressed_weight, requires_grad=False) + module._parameters["weight"] = param + module.__dict__["weight"] = param + + # Clean up if hasattr(module, "weight_packed"): delattr(module, "weight_packed") if hasattr(module, "weight_scale"): delattr(module, "weight_scale") - if hasattr(module, "_dummy_weight"): - delattr(module, "_dummy_weight") - # Update quantization status to not compressed + if hasattr(module, "weight_shape"): + if "weight_shape" in module._parameters: + del module._parameters["weight_shape"] + elif hasattr(module, "weight_shape"): + delattr(module, "weight_shape") + from compressed_tensors.quantization import QuantizationStatus module.quantization_status = QuantizationStatus.FROZEN @@ -463,12 +574,13 @@ def has_pack_quantized_config(config): elif has_pack_quantized_config(hf_config): # Patch CompressedLinear before loading to handle missing weight attribute _patch_compressed_linear_init() - torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16) + # Pass torch_dtype="auto" to preserve original dtypes from safetensors + # This prevents int32 packed weights from being converted to float model = AutoModelForCausalLM.from_pretrained( ckpt_path, device_map="auto", trust_remote_code=trust_remote_code, - torch_dtype=torch_dtype, + torch_dtype="auto", ) else: architecture = hf_config.architectures[0] @@ -527,7 +639,7 @@ def has_pack_quantized_config(config): # Unpack CompressedLinear weights (INT4 -> BF16) if present # This is needed for models with compressed-tensors quantization (e.g., kimi k2.5) - _unpack_compressed_linear_weights(model) + _unpack_compressed_linear_weights(model, ckpt_path) # If device_map was disabled (None), manually move model to target device if device_map is None and device != "cpu": diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a29d7c754..354015cbd 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ 
b/modelopt/torch/quantization/plugins/huggingface.py @@ -676,21 +676,64 @@ def forward(self, input: Tensor) -> Tensor: from compressed_tensors.quantization import QuantizationStatus if self.quantization_status == QuantizationStatus.COMPRESSED: - weight_data = self.compressor.decompress_module(self) + # Decompress once to avoid weight_shape tensor issues in decompress_module + if "weight" not in self._parameters: + self.unpack_weight() + weight_data = self._parameters.get("weight", self.weight) else: weight_data = self.weight return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) def unpack_weight(self): + import torch from compressed_tensors.quantization import QuantizationStatus if self.quantization_status == QuantizationStatus.COMPRESSED: - self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False) + # Build compressed_data dict manually to handle weight_shape tensor issue + compressed_data = {} + compressed_data["weight_packed"] = self.weight_packed + if hasattr(self, "weight_scale"): + compressed_data["weight_scale"] = self.weight_scale + if hasattr(self, "weight_shape"): + ws = self.weight_shape + if isinstance(ws, torch.Tensor): + compressed_data["weight_shape"] = [int(x) for x in ws.tolist()] + elif isinstance(ws, (list, tuple)): + compressed_data["weight_shape"] = [int(x) for x in ws] + else: + compressed_data["weight_shape"] = ws + if hasattr(self, "weight_zero_point"): + compressed_data["weight_zero_point"] = self.weight_zero_point + + # Skip non-pack-quantized weights (e.g., vision modules use BF16) + if isinstance(compressed_data["weight_packed"], torch.Tensor): + if compressed_data["weight_packed"].dtype != torch.int32: + return + + # Get quantization args + quant_args = None + if hasattr(self, "quantization_scheme") and self.quantization_scheme: + if hasattr(self.quantization_scheme, "weights"): + quant_args = self.quantization_scheme.weights + + # Decompress + decompressed = self.compressor.decompress_weight( + compressed_data=compressed_data, + quantization_args=quant_args, + ) + self.weight = nn.Parameter(decompressed, requires_grad=False) if hasattr(self, "weight_packed"): del self.weight_packed if hasattr(self, "weight_scale"): del self.weight_scale + if hasattr(self, "weight_shape"): + if "weight_shape" in self._parameters: + del self._parameters["weight_shape"] + else: + delattr(self, "weight_shape") + if self.quantization_status == QuantizationStatus.COMPRESSED: + self.quantization_status = QuantizationStatus.FROZEN try: diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 7908ec514..042e74ba5 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -73,7 +73,7 @@ "preprocess": lambda sample: "\n".join(turn["value"] for turn in sample["conversations"]), }, "cnn_dailymail": { - "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]}, + "config": {"path": "abisee/cnn_dailymail", "name": "3.0.0", "split": ["train"]}, "preprocess": lambda sample: sample["article"], }, "pile": { From efb39485477fcf7745272ed73e53f69a79887aed Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Wed, 28 Jan 2026 06:50:53 +0000 Subject: [PATCH 4/6] decompress on CPU Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 3a1a11366..e3c469061 100755 --- 
a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -448,11 +448,29 @@ def _lookup_original(key: str): if hasattr(module.quantization_scheme, "weights"): quant_args = module.quantization_scheme.weights - # Decompress + # Decompress on CPU to avoid GPU OOM, then move to target device. + decompress_device = torch.device("cpu") + if isinstance(compressed_data["weight_packed"], torch.Tensor): + compressed_data["weight_packed"] = compressed_data["weight_packed"].to( + decompress_device + ) + if "weight_scale" in compressed_data and isinstance( + compressed_data["weight_scale"], torch.Tensor + ): + compressed_data["weight_scale"] = compressed_data["weight_scale"].to( + decompress_device + ) + if "weight_zero_point" in compressed_data and isinstance( + compressed_data["weight_zero_point"], torch.Tensor + ): + compressed_data["weight_zero_point"] = compressed_data["weight_zero_point"].to( + decompress_device + ) + decompressed_weight = module.compressor.decompress_weight( compressed_data=compressed_data, quantization_args=quant_args, - ) + ).to(module.weight_packed.device, non_blocking=True) # Register as parameter (avoid register_parameter to bypass __getattr__ hook) module._parameters.pop("weight", None) From e4070282f850d081609e57730a3683a35a039048 Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Thu, 29 Jan 2026 18:47:44 +0000 Subject: [PATCH 5/6] decompress experts on-the-fly to avoid OOTM Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 212 +++++------------- .../torch/quantization/plugins/huggingface.py | 42 +++- 2 files changed, 91 insertions(+), 163 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index e3c469061..2533cea55 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -332,174 +332,70 @@ def patched_getattr(self, name): def _unpack_compressed_linear_weights(model, ckpt_path=None): - """Unpack all CompressedLinear weights from INT4 to BF16. - - This decompresses the packed INT4 weights to BF16 format so they can be - processed by Model-Optimizer's PTQ flow. This must be called after model - loading but before quantization/calibration. - - Args: - model: The loaded model with CompressedLinear modules - ckpt_path: Path to the checkpoint to reload int32 weights from safetensors + """Hybrid restoration: restores BF16 layers and fixes expert metadata. + + 1. BF16 layers (vision, lm_head) are restored from checkpoint and marked non-compressed. + 2. INT4 experts stay compressed in HBM to save memory (decompressed on-the-fly). + 3. Metadata (weight_shape) is fixed to avoid decompression errors. 
""" try: from compressed_tensors.linear.compressed_linear import CompressedLinear + from compressed_tensors.quantization import QuantizationStatus except ImportError: - return # compressed_tensors not installed, nothing to unpack + return - # Get checkpoint path from model config if not provided if ckpt_path is None: - if hasattr(model, "config") and hasattr(model.config, "_name_or_path"): - ckpt_path = model.config._name_or_path - - # Load original int32 weights from safetensors - original_weights = {} - if ckpt_path: - import json - import os - - from safetensors import safe_open - - # Find safetensors files - index_path = os.path.join(ckpt_path, "model.safetensors.index.json") - single_path = os.path.join(ckpt_path, "model.safetensors") - - safetensor_files = [] - if os.path.exists(index_path): - with open(index_path) as f: - index = json.load(f) - safetensor_files = list(set(index.get("weight_map", {}).values())) - safetensor_files = [os.path.join(ckpt_path, f) for f in safetensor_files] - elif os.path.exists(single_path): - safetensor_files = [single_path] - - # Load int32 weights (weight_packed, weight_shape) - for sf_path in safetensor_files: - with safe_open(sf_path, framework="pt") as f: - for key in f.keys(): - if "weight_packed" in key or "weight_shape" in key: - original_weights[key] = f.get_tensor(key) - - print(f"Loaded {len(original_weights)} packed weight tensors from safetensors") - - def _lookup_original(key: str): - if not original_weights: - return None - if key in original_weights: - return original_weights[key] - # Fallback: match by suffix when prefixes differ - matches = [k for k in original_weights if k.endswith(key)] - if len(matches) == 1: - return original_weights[matches[0]] - return None - - unpacked_count = 0 + ckpt_path = getattr(model.config, "_name_or_path", None) + if not ckpt_path: return + + import os, json, torch + from safetensors import safe_open + + # 1. Load weights from safetensors + checkpoint_weights = {} + index_path = os.path.join(ckpt_path, "model.safetensors.index.json") + st_files = [os.path.join(ckpt_path, "model.safetensors")] + if os.path.exists(index_path): + with open(index_path) as f: + index = json.load(f) + st_files = [os.path.join(ckpt_path, f) for f in set(index.get("weight_map", {}).values())] + + # We only need to load non-expert weights or metadata + for sf_path in st_files: + if not os.path.exists(sf_path): continue + with safe_open(sf_path, framework="pt") as f: + for key in f.keys(): + # Load everything except the massive packed expert weights + if ".mlp.experts." not in key or "weight_shape" in key: + checkpoint_weights[key] = f.get_tensor(key) + + # 2. Hybrid Restoration for name, module in model.named_modules(): - if isinstance(module, CompressedLinear) and hasattr(module, "weight_packed"): - with torch.no_grad(): - # Build compressed_data dict - compressed_data = {} - - # Use original int32 weights from safetensors if available - packed_key = f"{name}.weight_packed" - shape_key = f"{name}.weight_shape" - - packed_tensor = _lookup_original(packed_key) - if packed_tensor is None: - packed_tensor = module.weight_packed - elif isinstance(packed_tensor, torch.Tensor): - packed_tensor = packed_tensor.to(module.weight_packed.device) - compressed_data["weight_packed"] = packed_tensor - - if isinstance(compressed_data["weight_packed"], torch.Tensor): - if compressed_data["weight_packed"].dtype != torch.int32: - # Some modules (e.g., vision tower) are not pack-quantized. - # Skip decompression if packed weights are not int32. 
- continue - - shape_tensor = _lookup_original(shape_key) - if shape_tensor is not None: - if isinstance(shape_tensor, torch.Tensor): - shape_tensor = shape_tensor.to(module.weight_packed.device) - compressed_data["weight_shape"] = [int(x) for x in shape_tensor.tolist()] - elif hasattr(module, "weight_shape"): - ws = module.weight_shape - if isinstance(ws, torch.Tensor): - compressed_data["weight_shape"] = [int(x) for x in ws.tolist()] - elif isinstance(ws, (list, tuple)): - compressed_data["weight_shape"] = [int(x) for x in ws] - else: - compressed_data["weight_shape"] = ws - - # Get weight_scale (this one is float, so conversion is OK) - if hasattr(module, "weight_scale"): - compressed_data["weight_scale"] = module.weight_scale.to( - module.weight_packed.device - ) - - if hasattr(module, "weight_zero_point"): - compressed_data["weight_zero_point"] = module.weight_zero_point.to( - module.weight_packed.device - ) - - # Get quantization args - quant_args = None - if hasattr(module, "quantization_scheme") and module.quantization_scheme: - if hasattr(module.quantization_scheme, "weights"): - quant_args = module.quantization_scheme.weights - - # Decompress on CPU to avoid GPU OOM, then move to target device. - decompress_device = torch.device("cpu") - if isinstance(compressed_data["weight_packed"], torch.Tensor): - compressed_data["weight_packed"] = compressed_data["weight_packed"].to( - decompress_device - ) - if "weight_scale" in compressed_data and isinstance( - compressed_data["weight_scale"], torch.Tensor - ): - compressed_data["weight_scale"] = compressed_data["weight_scale"].to( - decompress_device - ) - if "weight_zero_point" in compressed_data and isinstance( - compressed_data["weight_zero_point"], torch.Tensor - ): - compressed_data["weight_zero_point"] = compressed_data["weight_zero_point"].to( - decompress_device - ) - - decompressed_weight = module.compressor.decompress_weight( - compressed_data=compressed_data, - quantization_args=quant_args, - ).to(module.weight_packed.device, non_blocking=True) - - # Register as parameter (avoid register_parameter to bypass __getattr__ hook) + if not isinstance(module, CompressedLinear): continue + + with torch.no_grad(): + target_device = next(module.parameters()).device + + # CASE A: Real BF16 weight exists (vision, lm_head) + if f"{name}.weight" in checkpoint_weights: + w = checkpoint_weights[f"{name}.weight"].to(target_device) module._parameters.pop("weight", None) module._buffers.pop("weight", None) if "weight" in module.__dict__: del module.__dict__["weight"] - param = torch.nn.Parameter(decompressed_weight, requires_grad=False) + param = torch.nn.Parameter(w, requires_grad=False) module._parameters["weight"] = param module.__dict__["weight"] = param - - # Clean up - if hasattr(module, "weight_packed"): - delattr(module, "weight_packed") - if hasattr(module, "weight_scale"): - delattr(module, "weight_scale") - if hasattr(module, "weight_shape"): - if "weight_shape" in module._parameters: - del module._parameters["weight_shape"] - elif hasattr(module, "weight_shape"): - delattr(module, "weight_shape") - - from compressed_tensors.quantization import QuantizationStatus - - module.quantization_status = QuantizationStatus.FROZEN - unpacked_count += 1 - - if unpacked_count > 0: - print(f"Unpacked {unpacked_count} CompressedLinear modules from INT4 to BF16") - + module.quantization_status = QuantizationStatus.FROZEN # Mark non-compressed + print(f" Restored BF16 layer: {name}") + + # CASE B: Expert (stay compressed, fix metadata) + elif 
f"{name}.weight_shape" in checkpoint_weights: + ws = checkpoint_weights[f"{name}.weight_shape"] + if "weight_shape" in module._parameters: + del module._parameters["weight_shape"] + module.weight_shape = [int(x) for x in ws.tolist()] + # Keep status as COMPRESSED for on-the-fly decompression def get_model( ckpt_path, @@ -654,11 +550,11 @@ def has_pack_quantized_config(config): **model_kwargs, ) model.eval() - - # Unpack CompressedLinear weights (INT4 -> BF16) if present - # This is needed for models with compressed-tensors quantization (e.g., kimi k2.5) _unpack_compressed_linear_weights(model, ckpt_path) + + # Experts will be decompressed on-the-fly during calibration to save memory + # If device_map was disabled (None), manually move model to target device if device_map is None and device != "cpu": print(f"Moving model to {device} device...") diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 354015cbd..b519545df 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -674,13 +674,38 @@ def _setup(self): def forward(self, input: Tensor) -> Tensor: from compressed_tensors.quantization import QuantizationStatus + import torch if self.quantization_status == QuantizationStatus.COMPRESSED: - # Decompress once to avoid weight_shape tensor issues in decompress_module - if "weight" not in self._parameters: - self.unpack_weight() - weight_data = self._parameters.get("weight", self.weight) + # Check if we should use decompress_module or manual decompress_weight + # Real packed weights are int32. If it's float, it's not actually compressed. + if self.weight_packed.dtype == torch.int32: + compressed_data = {"weight_packed": self.weight_packed} + if hasattr(self, "weight_scale"): + compressed_data["weight_scale"] = self.weight_scale + if hasattr(self, "weight_shape"): + ws = self.weight_shape + if isinstance(ws, torch.Tensor): + compressed_data["weight_shape"] = [int(x) for x in ws.tolist()] + else: + compressed_data["weight_shape"] = [int(x) for x in ws] + if hasattr(self, "weight_zero_point"): + compressed_data["weight_zero_point"] = self.weight_zero_point + + quant_args = None + if hasattr(self, "quantization_scheme") and self.quantization_scheme: + if hasattr(self.quantization_scheme, "weights"): + quant_args = self.quantization_scheme.weights + + weight_data = self.compressor.decompress_weight( + compressed_data=compressed_data, + quantization_args=quant_args, + ) + else: + # If it's not int32, just use it as-is + weight_data = self.weight_packed else: + # Standard path for non-compressed layers weight_data = self.weight return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias) @@ -722,7 +747,14 @@ def unpack_weight(self): compressed_data=compressed_data, quantization_args=quant_args, ) - self.weight = nn.Parameter(decompressed, requires_grad=False) + # Avoid register_parameter errors if a placeholder already exists + self._parameters.pop("weight", None) + self._buffers.pop("weight", None) + if "weight" in self.__dict__: + del self.__dict__["weight"] + param = nn.Parameter(decompressed, requires_grad=False) + self._parameters["weight"] = param + self.__dict__["weight"] = param if hasattr(self, "weight_packed"): del self.weight_packed if hasattr(self, "weight_scale"): From 99912fbdf4eed9be48383330ff2ecaf42a242d1b Mon Sep 17 00:00:00 2001 From: Zhiyu Date: Fri, 30 Jan 2026 20:31:31 +0000 Subject: [PATCH 6/6] resolve issues when 
use large calib size Signed-off-by: Zhiyu --- examples/llm_ptq/example_utils.py | 47 +++++++++++++++++-- .../torch/quantization/plugins/huggingface.py | 3 ++ 2 files changed, 46 insertions(+), 4 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 2533cea55..033cf6e67 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -295,7 +295,9 @@ def _patch_compressed_linear_init(): return # Already patched # Patch __getattr__ to return dummy for weight access - original_getattr = getattr(CompressedLinear, "__getattr__", None) + if not hasattr(CompressedLinear, "_modelopt_original_getattr"): + CompressedLinear._modelopt_original_getattr = getattr(CompressedLinear, "__getattr__", None) + original_getattr = CompressedLinear._modelopt_original_getattr class DummyWeightData: """Dummy tensor data that accepts initialization calls like .normal_(), .zero_().""" @@ -329,6 +331,21 @@ def patched_getattr(self, name): CompressedLinear.__getattr__ = patched_getattr CompressedLinear._modelopt_init_patched = True print("Patched CompressedLinear for transformers compatibility") +def _restore_compressed_linear(): + """Restore original CompressedLinear behavior after loading.""" + try: + from compressed_tensors.linear.compressed_linear import CompressedLinear + if hasattr(CompressedLinear, "_modelopt_original_getattr"): + CompressedLinear.__getattr__ = CompressedLinear._modelopt_original_getattr + delattr(CompressedLinear, "_modelopt_original_getattr") + elif hasattr(CompressedLinear, "__getattr__"): + # If it didn't have one before, delete the patched one + del CompressedLinear.__getattr__ + CompressedLinear._modelopt_init_patched = False + print("Restored CompressedLinear original state") + except Exception: + pass + def _unpack_compressed_linear_weights(model, ckpt_path=None): @@ -392,11 +409,31 @@ def _unpack_compressed_linear_weights(model, ckpt_path=None): # CASE B: Expert (stay compressed, fix metadata) elif f"{name}.weight_shape" in checkpoint_weights: ws = checkpoint_weights[f"{name}.weight_shape"] - if "weight_shape" in module._parameters: - del module._parameters["weight_shape"] - module.weight_shape = [int(x) for x in ws.tolist()] + # Restore int32 packed weights if present + if f"{name}.weight_packed" in checkpoint_weights: + module.weight_packed = checkpoint_weights[f"{name}.weight_packed"].to(torch.int32) + # Ensure no stale BF16 weight is registered for compressed experts + module._parameters.pop("weight", None) + module._buffers.pop("weight", None) + module.__dict__.pop("weight", None) + # Register weight_shape as int32 parameter for compressed_tensors forward + shape_param = torch.nn.Parameter(ws.to(torch.int32), requires_grad=False) + module._parameters.pop("weight_shape", None) + module.__dict__.pop("weight_shape", None) + module._parameters["weight_shape"] = shape_param + module.__dict__["weight_shape"] = shape_param # Keep status as COMPRESSED for on-the-fly decompression + # Ensure compressed experts do not carry a stale weight attribute + for name, module in model.named_modules(): + if not isinstance(module, CompressedLinear): + continue + if getattr(module, "quantization_status", None) != QuantizationStatus.COMPRESSED: + continue + module._parameters.pop("weight", None) + module._buffers.pop("weight", None) + module.__dict__.pop("weight", None) + def get_model( ckpt_path, device="cuda", @@ -496,6 +533,8 @@ def has_pack_quantized_config(config): trust_remote_code=trust_remote_code, torch_dtype="auto", ) + # 
Restore original CompressedLinear behavior after loading + _restore_compressed_linear() else: architecture = hf_config.architectures[0] diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index b519545df..346f3abf1 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -697,6 +697,9 @@ def forward(self, input: Tensor) -> Tensor: if hasattr(self.quantization_scheme, "weights"): quant_args = self.quantization_scheme.weights + if not hasattr(self, "_logged_on_the_fly"): + print(f"[on-the-fly-decompress] {self.__class__.__name__}") + self._logged_on_the_fly = True weight_data = self.compressor.decompress_weight( compressed_data=compressed_data, quantization_args=quant_args,
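
Reviewer note (not part of the patches): a minimal usage sketch of how the helpers added to examples/llm_ptq/example_utils.py in this series are intended to compose for a pack-quantized checkpoint such as Kimi K2.5. The checkpoint path, the import style, and the follow-up mtq.quantize call are illustrative assumptions, not taken from this diff.

# Illustrative sketch only; assumes it runs alongside examples/llm_ptq/example_utils.py.
from example_utils import get_model

ckpt_path = "/models/Kimi-K2.5"  # hypothetical path to a pack-quantized checkpoint

# get_model() now:
#  1. detects the compressed-tensors "pack-quantized" format, including the nested
#     text_config.quantization_config used by multi-modal models;
#  2. patches CompressedLinear so transformers weight init does not fail on the
#     missing `weight` attribute, loads with torch_dtype="auto" to keep int32
#     packed tensors intact, then restores the original CompressedLinear behavior;
#  3. restores BF16 layers (vision tower, lm_head) from the checkpoint and fixes
#     expert metadata (weight_shape), leaving INT4 experts compressed in HBM for
#     on-the-fly decompression.
model = get_model(ckpt_path, device="cuda")

# Downstream PTQ then calibrates as usual; expert weights are decompressed inside
# CompressedLinear.forward() during calibration. A typical follow-up (config name
# and calibration loop are assumptions, not from this series) would be:
#   import modelopt.torch.quantization as mtq
#   model = mtq.quantize(model, mtq.NVFP4_DEFAULT_CFG, forward_loop=calibrate_fn)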