Support Kimi-K2.5 PTQ #820

base: main

Changes from all commits
ef2c2c4
1fa9847
338b22b
efb3948
e407028
99912fb
@@ -279,6 +279,161 @@ def get_dtype(dtype):
    return dtype

def _patch_compressed_linear_init():
    """Patch CompressedLinear to prevent transformers weight initialization errors.

    When loading pack-quantized models, CompressedLinear modules don't have a 'weight'
    attribute (they have weight_packed instead). Transformers tries to initialize the
    missing weights, which fails. This patch adds a dummy weight property.
    """
    try:
        from compressed_tensors.linear.compressed_linear import CompressedLinear
    except ImportError:
        return  # compressed_tensors not installed

    if hasattr(CompressedLinear, "_modelopt_init_patched"):
        return  # Already patched

    # Patch __getattr__ to return a dummy object for weight access
    if not hasattr(CompressedLinear, "_modelopt_original_getattr"):
        CompressedLinear._modelopt_original_getattr = getattr(CompressedLinear, "__getattr__", None)
    original_getattr = CompressedLinear._modelopt_original_getattr

    class DummyWeightData:
        """Dummy tensor data that accepts initialization calls like .normal_(), .zero_()."""

        def __getattr__(self, name):
            # Return self for any method call to allow chaining
            return lambda *args, **kwargs: self

    class DummyWeight:
        """Dummy weight with .data that accepts any initialization."""

        def __init__(self):
            self.data = DummyWeightData()

        def __getattr__(self, name):
            return lambda *args, **kwargs: self

    def patched_getattr(self, name):
        if name == "weight":
            # Check if a real weight exists
            if "_parameters" in self.__dict__ and "weight" in self._parameters:
                return self._parameters["weight"]
            if "weight" in self.__dict__:
                return self.__dict__["weight"]
            # Return a dummy weight for initialization purposes (don't store it)
            return DummyWeight()
        if original_getattr is not None:
            return original_getattr(self, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

    CompressedLinear.__getattr__ = patched_getattr
    CompressedLinear._modelopt_init_patched = True
    print("Patched CompressedLinear for transformers compatibility")

def _restore_compressed_linear():
    """Restore original CompressedLinear behavior after loading."""
    try:
        from compressed_tensors.linear.compressed_linear import CompressedLinear

        if hasattr(CompressedLinear, "_modelopt_original_getattr"):
            CompressedLinear.__getattr__ = CompressedLinear._modelopt_original_getattr
            delattr(CompressedLinear, "_modelopt_original_getattr")
        elif hasattr(CompressedLinear, "__getattr__"):
            # If it didn't have one before, delete the patched one
            del CompressedLinear.__getattr__
        CompressedLinear._modelopt_init_patched = False
        print("Restored CompressedLinear original state")
    except Exception:
        pass

Collaborator (inline comment on `_unpack_compressed_linear_weights`):
we do not need it. We should be able to unpack on the fly with the logic in the quantization plugins.

def _unpack_compressed_linear_weights(model, ckpt_path=None):
| """Hybrid restoration: restores BF16 layers and fixes expert metadata. | ||
|
|
||
| 1. BF16 layers (vision, lm_head) are restored from checkpoint and marked non-compressed. | ||
| 2. INT4 experts stay compressed in HBM to save memory (decompressed on-the-fly). | ||
| 3. Metadata (weight_shape) is fixed to avoid decompression errors. | ||
| """ | ||
| try: | ||
| from compressed_tensors.linear.compressed_linear import CompressedLinear | ||
| from compressed_tensors.quantization import QuantizationStatus | ||
| except ImportError: | ||
| return | ||
|
|
||
| if ckpt_path is None: | ||
| ckpt_path = getattr(model.config, "_name_or_path", None) | ||
| if not ckpt_path: return | ||
|
|
||
| import os, json, torch | ||
| from safetensors import safe_open | ||
|
|
||
| # 1. Load weights from safetensors | ||
| checkpoint_weights = {} | ||
| index_path = os.path.join(ckpt_path, "model.safetensors.index.json") | ||
| st_files = [os.path.join(ckpt_path, "model.safetensors")] | ||
| if os.path.exists(index_path): | ||
| with open(index_path) as f: | ||
| index = json.load(f) | ||
| st_files = [os.path.join(ckpt_path, f) for f in set(index.get("weight_map", {}).values())] | ||
|
|
||
| # We only need to load non-expert weights or metadata | ||
| for sf_path in st_files: | ||
| if not os.path.exists(sf_path): continue | ||
| with safe_open(sf_path, framework="pt") as f: | ||
| for key in f.keys(): | ||
| # Load everything except the massive packed expert weights | ||
| if ".mlp.experts." not in key or "weight_shape" in key: | ||
| checkpoint_weights[key] = f.get_tensor(key) | ||
|
|
||
| # 2. Hybrid Restoration | ||
| for name, module in model.named_modules(): | ||
| if not isinstance(module, CompressedLinear): continue | ||
|
|
||
| with torch.no_grad(): | ||
| target_device = next(module.parameters()).device | ||
|
|
||
| # CASE A: Real BF16 weight exists (vision, lm_head) | ||
| if f"{name}.weight" in checkpoint_weights: | ||
| w = checkpoint_weights[f"{name}.weight"].to(target_device) | ||
| module._parameters.pop("weight", None) | ||
| module._buffers.pop("weight", None) | ||
| if "weight" in module.__dict__: | ||
| del module.__dict__["weight"] | ||
| param = torch.nn.Parameter(w, requires_grad=False) | ||
| module._parameters["weight"] = param | ||
| module.__dict__["weight"] = param | ||
| module.quantization_status = QuantizationStatus.FROZEN # Mark non-compressed | ||
| print(f" Restored BF16 layer: {name}") | ||
|
|
||
| # CASE B: Expert (stay compressed, fix metadata) | ||
| elif f"{name}.weight_shape" in checkpoint_weights: | ||
| ws = checkpoint_weights[f"{name}.weight_shape"] | ||
| # Restore int32 packed weights if present | ||
| if f"{name}.weight_packed" in checkpoint_weights: | ||
| module.weight_packed = checkpoint_weights[f"{name}.weight_packed"].to(torch.int32) | ||
| # Ensure no stale BF16 weight is registered for compressed experts | ||
| module._parameters.pop("weight", None) | ||
| module._buffers.pop("weight", None) | ||
| module.__dict__.pop("weight", None) | ||
| # Register weight_shape as int32 parameter for compressed_tensors forward | ||
| shape_param = torch.nn.Parameter(ws.to(torch.int32), requires_grad=False) | ||
| module._parameters.pop("weight_shape", None) | ||
| module.__dict__.pop("weight_shape", None) | ||
| module._parameters["weight_shape"] = shape_param | ||
| module.__dict__["weight_shape"] = shape_param | ||
| # Keep status as COMPRESSED for on-the-fly decompression | ||
|
|
||
| # Ensure compressed experts do not carry a stale weight attribute | ||
| for name, module in model.named_modules(): | ||
| if not isinstance(module, CompressedLinear): | ||
| continue | ||
| if getattr(module, "quantization_status", None) != QuantizationStatus.COMPRESSED: | ||
| continue | ||
| module._parameters.pop("weight", None) | ||
| module._buffers.pop("weight", None) | ||
| module.__dict__.pop("weight", None) | ||

def get_model(
    ckpt_path,
    device="cuda",

@@ -289,6 +444,8 @@ def get_model(
):
    print(f"Initializing model from {ckpt_path}")

    # Note: CompressedLinear weights will be unpacked after model loading

    device_map = "auto"
    if device == "cpu":
        device_map = "cpu"

@@ -345,23 +502,39 @@ def get_model(
    # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors
    device_map = None

    # Helper function to check if the model has a pack-quantized config
    def has_pack_quantized_config(config):
        # Check top-level quantization_config
        if hasattr(config, "quantization_config"):
            if config.quantization_config.get("format", None) == "pack-quantized":
                return True
        # Check nested text_config.quantization_config (for multi-modal models like Kimi K2.5)
        if hasattr(config, "text_config") and hasattr(
            config.text_config, "quantization_config"
        ):
            if config.text_config.quantization_config.get("format", None) == "pack-quantized":
                return True
        return False

    if is_speculative(hf_config):
        model = AutoModelForCausalLM.from_pretrained(
            ckpt_path,
            device_map=device_map,
            **model_kwargs,
        )
-   elif (
-       hasattr(hf_config, "quantization_config")
-       and hf_config.quantization_config.get("format", None) == "pack-quantized"
-   ):
-       torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
+   elif has_pack_quantized_config(hf_config):
        # Patch CompressedLinear before loading to handle the missing weight attribute

Collaborator (inline comment on the patch call below):
I don't think you need this.

        _patch_compressed_linear_init()
        # Pass torch_dtype="auto" to preserve original dtypes from safetensors
        # This prevents int32 packed weights from being converted to float
        model = AutoModelForCausalLM.from_pretrained(
            ckpt_path,
            device_map="auto",
            trust_remote_code=trust_remote_code,
-           torch_dtype=torch_dtype,
+           torch_dtype="auto",
        )
        # Restore original CompressedLinear behavior after loading
        _restore_compressed_linear()
    else:
        architecture = hf_config.architectures[0]

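For reference, a minimal hypothetical stand-in for the config shape the new helper handles, where the quantization config lives under text_config rather than at the top level:

# Hypothetical config object for illustration; real HF configs are PretrainedConfig instances.
from types import SimpleNamespace

cfg = SimpleNamespace(
    text_config=SimpleNamespace(quantization_config={"format": "pack-quantized"})
)
# has_pack_quantized_config(cfg) would return True: the top-level check misses it,
# the nested text_config check finds it.
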
@@ -416,6 +589,10 @@ def get_model(
        **model_kwargs,
    )
    model.eval()
    _unpack_compressed_linear_weights(model, ckpt_path)

    # Experts will be decompressed on-the-fly during calibration to save memory

    # If device_map was disabled (None), manually move model to target device
    if device_map is None and device != "cpu":

@@ -674,23 +674,101 @@ def _setup(self):

    def forward(self, input: Tensor) -> Tensor:
        from compressed_tensors.quantization import QuantizationStatus
        import torch

        if self.quantization_status == QuantizationStatus.COMPRESSED:
-           weight_data = self.compressor.decompress_module(self)
            # Check if we should use decompress_module or manual decompress_weight

Collaborator (inline comment here):
Is this specific to Kimi K2.5?

            # Real packed weights are int32. If it's float, it's not actually compressed.
            if self.weight_packed.dtype == torch.int32:
                compressed_data = {"weight_packed": self.weight_packed}
                if hasattr(self, "weight_scale"):
                    compressed_data["weight_scale"] = self.weight_scale
                if hasattr(self, "weight_shape"):
                    ws = self.weight_shape
                    if isinstance(ws, torch.Tensor):
                        compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
                    else:
                        compressed_data["weight_shape"] = [int(x) for x in ws]
                if hasattr(self, "weight_zero_point"):
                    compressed_data["weight_zero_point"] = self.weight_zero_point

                quant_args = None
                if hasattr(self, "quantization_scheme") and self.quantization_scheme:
                    if hasattr(self.quantization_scheme, "weights"):
                        quant_args = self.quantization_scheme.weights

                if not hasattr(self, "_logged_on_the_fly"):
                    print(f"[on-the-fly-decompress] {self.__class__.__name__}")
                    self._logged_on_the_fly = True
                weight_data = self.compressor.decompress_weight(
                    compressed_data=compressed_data,
                    quantization_args=quant_args,
                )
            else:
                # If it's not int32, just use it as-is
                weight_data = self.weight_packed
        else:
            # Standard path for non-compressed layers
            weight_data = self.weight

        return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)

    def unpack_weight(self):
        import torch
        from compressed_tensors.quantization import QuantizationStatus

        if self.quantization_status == QuantizationStatus.COMPRESSED:
-           self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False)
            # Build the compressed_data dict manually to handle the weight_shape tensor issue
            compressed_data = {}
            compressed_data["weight_packed"] = self.weight_packed
            if hasattr(self, "weight_scale"):
                compressed_data["weight_scale"] = self.weight_scale
            if hasattr(self, "weight_shape"):
                ws = self.weight_shape
                if isinstance(ws, torch.Tensor):
                    compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
                elif isinstance(ws, (list, tuple)):
                    compressed_data["weight_shape"] = [int(x) for x in ws]
                else:
                    compressed_data["weight_shape"] = ws
            if hasattr(self, "weight_zero_point"):
                compressed_data["weight_zero_point"] = self.weight_zero_point

            # Skip non-pack-quantized weights (e.g., vision modules use BF16)
            if isinstance(compressed_data["weight_packed"], torch.Tensor):
                if compressed_data["weight_packed"].dtype != torch.int32:
                    return

            # Get quantization args
            quant_args = None
            if hasattr(self, "quantization_scheme") and self.quantization_scheme:
                if hasattr(self.quantization_scheme, "weights"):
                    quant_args = self.quantization_scheme.weights

            # Decompress
            decompressed = self.compressor.decompress_weight(
                compressed_data=compressed_data,
                quantization_args=quant_args,
            )
            # Avoid register_parameter errors if a placeholder already exists
            self._parameters.pop("weight", None)
            self._buffers.pop("weight", None)
            if "weight" in self.__dict__:
                del self.__dict__["weight"]
            param = nn.Parameter(decompressed, requires_grad=False)
            self._parameters["weight"] = param
            self.__dict__["weight"] = param
            if hasattr(self, "weight_packed"):
                del self.weight_packed
            if hasattr(self, "weight_scale"):
                del self.weight_scale
            if hasattr(self, "weight_shape"):
                if "weight_shape" in self._parameters:
                    del self._parameters["weight_shape"]
                else:
                    delattr(self, "weight_shape")
        if self.quantization_status == QuantizationStatus.COMPRESSED:
            self.quantization_status = QuantizationStatus.FROZEN

try:

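Background for the int32 checks in forward() and unpack_weight() above (a sketch of the assumed packing convention, not the compressed_tensors implementation): pack-quantized weights store eight 4-bit values per 32-bit word, so a genuinely compressed weight_packed tensor is int32 and roughly one-eighth the width of the logical weight, while a floating-point tensor in that slot means the layer was never packed.

# Plain-Python illustration of packing eight 4-bit values into one 32-bit word.
# Real checkpoints also carry weight_scale / weight_zero_point for dequantization;
# this only shows the bit layout behind the dtype and shape checks.
def pack8(nibbles):
    assert len(nibbles) == 8 and all(0 <= v <= 15 for v in nibbles)
    word = 0
    for i, v in enumerate(nibbles):
        word |= v << (4 * i)  # value i occupies bits 4*i .. 4*i+3
    return word

def unpack8(word):
    return [(word >> (4 * i)) & 0xF for i in range(8)]

vals = [3, 0, 15, 7, 1, 9, 12, 5]
assert unpack8(pack8(vals)) == vals
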
Can it be a transformers version issue? I was able to load Kimi K2 Thinking INT4 without an issue. Is this specific to Kimi K2.5?
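
One way to narrow this down (an illustrative diagnostic, not part of the PR): print the installed transformers version and check whether the loaded CompressedLinear modules still carry int32 weight_packed tensors or were already materialized with a dense weight. The `model` variable is assumed to come from get_model() above.

# Illustrative diagnostic only; assumes compressed_tensors is installed and
# `model` is the object returned by get_model().
import torch
import transformers
from compressed_tensors.linear.compressed_linear import CompressedLinear

print("transformers version:", transformers.__version__)
packed, dense = 0, 0
for name, module in model.named_modules():
    if not isinstance(module, CompressedLinear):
        continue
    wp = getattr(module, "weight_packed", None)
    if isinstance(wp, torch.Tensor) and wp.dtype == torch.int32:
        packed += 1  # still holding packed int4-in-int32 weights
    else:
        dense += 1   # already materialized as a regular weight
print(f"CompressedLinear modules: {packed} packed, {dense} dense")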