diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py
index e213a7f25..033cf6e67 100755
--- a/examples/llm_ptq/example_utils.py
+++ b/examples/llm_ptq/example_utils.py
@@ -279,6 +279,161 @@ def get_dtype(dtype):
     return dtype
 
 
+def _patch_compressed_linear_init():
+    """Patch CompressedLinear to prevent transformers weight initialization errors.
+
+    When loading pack-quantized models, CompressedLinear modules don't have a 'weight'
+    attribute (they have weight_packed instead). Transformers tries to initialize the
+    missing weights, which fails. This patch adds a dummy weight property.
+    """
+    try:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+    except ImportError:
+        return  # compressed_tensors not installed
+
+    if hasattr(CompressedLinear, "_modelopt_init_patched"):
+        return  # Already patched
+
+    # Patch __getattr__ to return a dummy object for weight access
+    if not hasattr(CompressedLinear, "_modelopt_original_getattr"):
+        CompressedLinear._modelopt_original_getattr = getattr(CompressedLinear, "__getattr__", None)
+    original_getattr = CompressedLinear._modelopt_original_getattr
+
+    class DummyWeightData:
+        """Dummy tensor data that accepts initialization calls like .normal_(), .zero_()."""
+
+        def __getattr__(self, name):
+            # Return self for any method call to allow chaining
+            return lambda *args, **kwargs: self
+
+    class DummyWeight:
+        """Dummy weight with .data that accepts any initialization."""
+
+        def __init__(self):
+            self.data = DummyWeightData()
+
+        def __getattr__(self, name):
+            return lambda *args, **kwargs: self
+
+    def patched_getattr(self, name):
+        if name == "weight":
+            # Return the real weight if it exists
+            if "_parameters" in self.__dict__ and "weight" in self._parameters:
+                return self._parameters["weight"]
+            if "weight" in self.__dict__:
+                return self.__dict__["weight"]
+            # Return a dummy weight for initialization purposes (don't store it)
+            return DummyWeight()
+        if original_getattr is not None:
+            return original_getattr(self, name)
+        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
+
+    CompressedLinear.__getattr__ = patched_getattr
+    CompressedLinear._modelopt_init_patched = True
+    print("Patched CompressedLinear for transformers compatibility")
+
+
+def _restore_compressed_linear():
+    """Restore original CompressedLinear behavior after loading."""
+    try:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+
+        if hasattr(CompressedLinear, "_modelopt_original_getattr"):
+            CompressedLinear.__getattr__ = CompressedLinear._modelopt_original_getattr
+            delattr(CompressedLinear, "_modelopt_original_getattr")
+        elif hasattr(CompressedLinear, "__getattr__"):
+            # If it didn't have one before, delete the patched one
+            del CompressedLinear.__getattr__
+        CompressedLinear._modelopt_init_patched = False
+        print("Restored CompressedLinear original state")
+    except Exception:
+        pass
+
+
+def _unpack_compressed_linear_weights(model, ckpt_path=None):
+    """Hybrid restoration: restores BF16 layers and fixes expert metadata.
+
+    1. BF16 layers (vision, lm_head) are restored from the checkpoint and marked non-compressed.
+    2. INT4 experts stay compressed in HBM to save memory (decompressed on the fly).
+    3. Metadata (weight_shape) is fixed to avoid decompression errors.
+    """
+    try:
+        from compressed_tensors.linear.compressed_linear import CompressedLinear
+        from compressed_tensors.quantization import QuantizationStatus
+    except ImportError:
+        return
+
+    if ckpt_path is None:
+        ckpt_path = getattr(model.config, "_name_or_path", None)
+    if not ckpt_path:
+        return
+
+    import os, json, torch
+    from safetensors import safe_open
+
+    # 1. Load weights from safetensors
+    checkpoint_weights = {}
+    index_path = os.path.join(ckpt_path, "model.safetensors.index.json")
+    st_files = [os.path.join(ckpt_path, "model.safetensors")]
+    if os.path.exists(index_path):
+        with open(index_path) as f:
+            index = json.load(f)
+        st_files = [os.path.join(ckpt_path, f) for f in set(index.get("weight_map", {}).values())]
+
+    # We only need the non-expert weights and the expert metadata
+    for sf_path in st_files:
+        if not os.path.exists(sf_path):
+            continue
+        with safe_open(sf_path, framework="pt") as f:
+            for key in f.keys():
+                # Load everything except the massive packed expert weights
+                if ".mlp.experts." not in key or "weight_shape" in key:
+                    checkpoint_weights[key] = f.get_tensor(key)
+
+    # 2. Hybrid restoration
+    for name, module in model.named_modules():
+        if not isinstance(module, CompressedLinear):
+            continue
+
+        with torch.no_grad():
+            target_device = next(module.parameters()).device
+
+            # CASE A: a real BF16 weight exists (vision, lm_head)
+            if f"{name}.weight" in checkpoint_weights:
+                w = checkpoint_weights[f"{name}.weight"].to(target_device)
+                module._parameters.pop("weight", None)
+                module._buffers.pop("weight", None)
+                if "weight" in module.__dict__:
+                    del module.__dict__["weight"]
+                param = torch.nn.Parameter(w, requires_grad=False)
+                module._parameters["weight"] = param
+                module.__dict__["weight"] = param
+                module.quantization_status = QuantizationStatus.FROZEN  # Mark non-compressed
+                print(f"  Restored BF16 layer: {name}")
+
+            # CASE B: expert (stays compressed, fix metadata)
+            elif f"{name}.weight_shape" in checkpoint_weights:
+                ws = checkpoint_weights[f"{name}.weight_shape"]
+                # Restore int32 packed weights if present
+                if f"{name}.weight_packed" in checkpoint_weights:
+                    module.weight_packed = checkpoint_weights[f"{name}.weight_packed"].to(torch.int32)
+                # Ensure no stale BF16 weight is registered for compressed experts
+                module._parameters.pop("weight", None)
+                module._buffers.pop("weight", None)
+                module.__dict__.pop("weight", None)
+                # Register weight_shape as an int32 parameter for the compressed_tensors forward
+                shape_param = torch.nn.Parameter(ws.to(torch.int32), requires_grad=False)
+                module._parameters.pop("weight_shape", None)
+                module.__dict__.pop("weight_shape", None)
+                module._parameters["weight_shape"] = shape_param
+                module.__dict__["weight_shape"] = shape_param
+                # Keep status as COMPRESSED for on-the-fly decompression
+
+    # Ensure compressed experts do not carry a stale weight attribute
+    for name, module in model.named_modules():
+        if not isinstance(module, CompressedLinear):
+            continue
+        if getattr(module, "quantization_status", None) != QuantizationStatus.COMPRESSED:
+            continue
+        module._parameters.pop("weight", None)
+        module._buffers.pop("weight", None)
+        module.__dict__.pop("weight", None)
+
+
 def get_model(
     ckpt_path,
     device="cuda",
@@ -289,6 +444,8 @@ def get_model(
 ):
     print(f"Initializing model from {ckpt_path}")
 
+    # Note: CompressedLinear weights will be unpacked after model loading
+
     device_map = "auto"
     if device == "cpu":
         device_map = "cpu"
@@ -345,23 +502,39 @@ def get_model(
         # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors
         device_map = None
 
+    # Helper function to check if the model has a pack-quantized config
+    def has_pack_quantized_config(config):
+        # Check the top-level quantization_config
+        if hasattr(config, "quantization_config"):
+            if config.quantization_config.get("format", None) == "pack-quantized":
+                return True
+        # Check nested text_config.quantization_config (for multi-modal models like Kimi K2.5)
+        if hasattr(config, "text_config") and hasattr(
+            config.text_config, "quantization_config"
+        ):
+            if config.text_config.quantization_config.get("format", None) == "pack-quantized":
+                return True
+        return False
+
     if is_speculative(hf_config):
         model = AutoModelForCausalLM.from_pretrained(
             ckpt_path,
             device_map=device_map,
             **model_kwargs,
         )
-    elif (
-        hasattr(hf_config, "quantization_config")
-        and hf_config.quantization_config.get("format", None) == "pack-quantized"
-    ):
-        torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
+    elif has_pack_quantized_config(hf_config):
+        # Patch CompressedLinear before loading to handle the missing weight attribute
+        _patch_compressed_linear_init()
+        # Pass torch_dtype="auto" to preserve the original dtypes from safetensors;
+        # this prevents int32 packed weights from being converted to float.
         model = AutoModelForCausalLM.from_pretrained(
             ckpt_path,
             device_map="auto",
             trust_remote_code=trust_remote_code,
-            torch_dtype=torch_dtype,
+            torch_dtype="auto",
         )
+        # Restore the original CompressedLinear behavior after loading
+        _restore_compressed_linear()
     else:
         architecture = hf_config.architectures[0]
@@ -416,6 +589,10 @@ def get_model(
         **model_kwargs,
     )
     model.eval()
+    _unpack_compressed_linear_weights(model, ckpt_path)
+
+
+    # Experts will be decompressed on the fly during calibration to save memory
 
     # If device_map was disabled (None), manually move model to target device
     if device_map is None and device != "cpu":
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index a29d7c754..346f3abf1 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -674,23 +674,101 @@ def _setup(self):
 
     def forward(self, input: Tensor) -> Tensor:
         from compressed_tensors.quantization import QuantizationStatus
+        import torch
 
         if self.quantization_status == QuantizationStatus.COMPRESSED:
-            weight_data = self.compressor.decompress_module(self)
+            # Decide between decompress_module and a manual decompress_weight call.
+            # Real packed weights are int32; if it's float, it's not actually compressed.
+            if self.weight_packed.dtype == torch.int32:
+                compressed_data = {"weight_packed": self.weight_packed}
+                if hasattr(self, "weight_scale"):
+                    compressed_data["weight_scale"] = self.weight_scale
+                if hasattr(self, "weight_shape"):
+                    ws = self.weight_shape
+                    if isinstance(ws, torch.Tensor):
+                        compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
+                    else:
+                        compressed_data["weight_shape"] = [int(x) for x in ws]
+                if hasattr(self, "weight_zero_point"):
+                    compressed_data["weight_zero_point"] = self.weight_zero_point
+
+                quant_args = None
+                if hasattr(self, "quantization_scheme") and self.quantization_scheme:
+                    if hasattr(self.quantization_scheme, "weights"):
+                        quant_args = self.quantization_scheme.weights
+
+                if not hasattr(self, "_logged_on_the_fly"):
+                    print(f"[on-the-fly-decompress] {self.__class__.__name__}")
+                    self._logged_on_the_fly = True
+                weight_data = self.compressor.decompress_weight(
+                    compressed_data=compressed_data,
+                    quantization_args=quant_args,
+                )
+            else:
+                # If it's not int32, just use it as-is
+                weight_data = self.weight_packed
         else:
+            # Standard path for non-compressed layers
             weight_data = self.weight
         return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)
 
     def unpack_weight(self):
+        import torch
         from compressed_tensors.quantization import QuantizationStatus
 
         if self.quantization_status == QuantizationStatus.COMPRESSED:
-            self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False)
+            # Build the compressed_data dict manually to handle the weight_shape tensor issue
+            compressed_data = {}
+            compressed_data["weight_packed"] = self.weight_packed
+            if hasattr(self, "weight_scale"):
+                compressed_data["weight_scale"] = self.weight_scale
+            if hasattr(self, "weight_shape"):
+                ws = self.weight_shape
+                if isinstance(ws, torch.Tensor):
+                    compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
+                elif isinstance(ws, (list, tuple)):
+                    compressed_data["weight_shape"] = [int(x) for x in ws]
+                else:
+                    compressed_data["weight_shape"] = ws
+            if hasattr(self, "weight_zero_point"):
+                compressed_data["weight_zero_point"] = self.weight_zero_point
+
+            # Skip non-pack-quantized weights (e.g., vision modules use BF16)
+            if isinstance(compressed_data["weight_packed"], torch.Tensor):
+                if compressed_data["weight_packed"].dtype != torch.int32:
+                    return
+
+            # Get quantization args
+            quant_args = None
+            if hasattr(self, "quantization_scheme") and self.quantization_scheme:
+                if hasattr(self.quantization_scheme, "weights"):
+                    quant_args = self.quantization_scheme.weights
+
+            # Decompress
+            decompressed = self.compressor.decompress_weight(
+                compressed_data=compressed_data,
+                quantization_args=quant_args,
+            )
+            # Avoid register_parameter errors if a placeholder already exists
+            self._parameters.pop("weight", None)
+            self._buffers.pop("weight", None)
+            if "weight" in self.__dict__:
+                del self.__dict__["weight"]
+            param = nn.Parameter(decompressed, requires_grad=False)
+            self._parameters["weight"] = param
+            self.__dict__["weight"] = param
             if hasattr(self, "weight_packed"):
                 del self.weight_packed
             if hasattr(self, "weight_scale"):
                 del self.weight_scale
+            if hasattr(self, "weight_shape"):
+                if "weight_shape" in self._parameters:
+                    del self._parameters["weight_shape"]
+                else:
+                    delattr(self, "weight_shape")
+        if self.quantization_status == QuantizationStatus.COMPRESSED:
+            self.quantization_status = QuantizationStatus.FROZEN
 
 
 try:
diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 7908ec514..042e74ba5 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -73,7 +73,7 @@
         "preprocess": lambda sample: "\n".join(turn["value"] for turn in sample["conversations"]),
     },
     "cnn_dailymail": {
-        "config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]},
+        "config": {"path": "abisee/cnn_dailymail", "name": "3.0.0", "split": ["train"]},
         "preprocess": lambda sample: sample["article"],
     },
     "pile": {
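
Usage sketch (illustrative, not part of the patch): assuming it is run from examples/llm_ptq against a local pack-quantized (compressed-tensors) checkpoint at a hypothetical path, the snippet below exercises the loading path added above and lists the CompressedLinear modules that stay compressed for on-the-fly decompression.

    from compressed_tensors.linear.compressed_linear import CompressedLinear
    from compressed_tensors.quantization import QuantizationStatus
    from example_utils import get_model

    # Hypothetical path; replace with a real pack-quantized checkpoint directory.
    model = get_model("/path/to/pack-quantized-checkpoint", device="cuda")

    # BF16 layers (vision, lm_head) now hold real weights; INT4 experts keep their
    # int32 weight_packed tensors and are decompressed on the fly in forward().
    still_compressed = [
        name
        for name, m in model.named_modules()
        if isinstance(m, CompressedLinear)
        and getattr(m, "quantization_status", None) == QuantizationStatus.COMPRESSED
    ]
    print(f"{len(still_compressed)} expert linears remain compressed in HBM")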
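
For reference, a minimal standalone sketch of the detection logic mirrored by has_pack_quantized_config(); the checkpoint path is hypothetical and quantization_config is assumed to be a dict, as in the helper above. It checks the top-level config first and falls back to the nested text_config used by multi-modal models.

    from transformers import AutoConfig

    # Hypothetical checkpoint path.
    cfg = AutoConfig.from_pretrained("/path/to/pack-quantized-checkpoint", trust_remote_code=True)
    qc = getattr(cfg, "quantization_config", None) or getattr(
        getattr(cfg, "text_config", None), "quantization_config", None
    )
    print(bool(qc) and qc.get("format", None) == "pack-quantized")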