189 changes: 183 additions & 6 deletions examples/llm_ptq/example_utils.py
@@ -279,6 +279,161 @@ def get_dtype(dtype):
return dtype


def _patch_compressed_linear_init():
Collaborator:
Can it be a transformers version issue? I was able to load kimi k2 thinking int4 without an issue. Is this specific to kimi k2.5?

"""Patch CompressedLinear to prevent transformers weight initialization errors.

When loading pack-quantized models, CompressedLinear modules don't have a 'weight'
attribute (they expose 'weight_packed' instead), so transformers fails when it tries to
initialize the seemingly missing weights. This patch serves a dummy weight object via __getattr__ during loading.
"""
try:
from compressed_tensors.linear.compressed_linear import CompressedLinear
except ImportError:
return # compressed_tensors not installed

if hasattr(CompressedLinear, "_modelopt_init_patched"):
return # Already patched

# Patch __getattr__ to return dummy for weight access
if not hasattr(CompressedLinear, "_modelopt_original_getattr"):
CompressedLinear._modelopt_original_getattr = getattr(CompressedLinear, "__getattr__", None)
original_getattr = CompressedLinear._modelopt_original_getattr

class DummyWeightData:
"""Dummy tensor data that accepts initialization calls like .normal_(), .zero_()."""

def __getattr__(self, name):
# Return self for any method call to allow chaining
return lambda *args, **kwargs: self

class DummyWeight:
"""Dummy weight with .data that accepts any initialization."""

def __init__(self):
self.data = DummyWeightData()

def __getattr__(self, name):
return lambda *args, **kwargs: self

def patched_getattr(self, name):
if name == "weight":
# Check if real weight exists
if "_parameters" in self.__dict__ and "weight" in self._parameters:
return self._parameters["weight"]
if "weight" in self.__dict__:
return self.__dict__["weight"]
# Return dummy weight for initialization purposes (don't store it)
return DummyWeight()
if original_getattr is not None:
return original_getattr(self, name)
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

CompressedLinear.__getattr__ = patched_getattr
CompressedLinear._modelopt_init_patched = True
print("Patched CompressedLinear for transformers compatibility")
def _restore_compressed_linear():
"""Restore original CompressedLinear behavior after loading."""
try:
from compressed_tensors.linear.compressed_linear import CompressedLinear
if hasattr(CompressedLinear, "_modelopt_original_getattr"):
CompressedLinear.__getattr__ = CompressedLinear._modelopt_original_getattr
delattr(CompressedLinear, "_modelopt_original_getattr")
elif hasattr(CompressedLinear, "__getattr__"):
# If it didn't have one before, delete the patched one
del CompressedLinear.__getattr__
CompressedLinear._modelopt_init_patched = False
print("Restored CompressedLinear original state")
except Exception:
pass



def _unpack_compressed_linear_weights(model, ckpt_path=None):
Collaborator:
we do not need it. We should be able to unpack on the fly with logics in the quantization plugins

"""Hybrid restoration: restores BF16 layers and fixes expert metadata.

1. BF16 layers (vision, lm_head) are restored from checkpoint and marked non-compressed.
2. INT4 experts stay compressed in HBM to save memory (decompressed on-the-fly).
3. Metadata (weight_shape) is fixed to avoid decompression errors.
"""
try:
from compressed_tensors.linear.compressed_linear import CompressedLinear
from compressed_tensors.quantization import QuantizationStatus
except ImportError:
return

if ckpt_path is None:
ckpt_path = getattr(model.config, "_name_or_path", None)
if not ckpt_path: return

import os, json, torch
from safetensors import safe_open

# 1. Load weights from safetensors
checkpoint_weights = {}
index_path = os.path.join(ckpt_path, "model.safetensors.index.json")
st_files = [os.path.join(ckpt_path, "model.safetensors")]
if os.path.exists(index_path):
with open(index_path) as f:
index = json.load(f)
st_files = [os.path.join(ckpt_path, f) for f in set(index.get("weight_map", {}).values())]

# We only need to load non-expert weights or metadata
for sf_path in st_files:
if not os.path.exists(sf_path): continue
with safe_open(sf_path, framework="pt") as f:
for key in f.keys():
# Load everything except the massive packed expert weights
if ".mlp.experts." not in key or "weight_shape" in key:
checkpoint_weights[key] = f.get_tensor(key)

# 2. Hybrid Restoration
for name, module in model.named_modules():
if not isinstance(module, CompressedLinear): continue

with torch.no_grad():
target_device = next(module.parameters()).device

# CASE A: Real BF16 weight exists (vision, lm_head)
if f"{name}.weight" in checkpoint_weights:
w = checkpoint_weights[f"{name}.weight"].to(target_device)
module._parameters.pop("weight", None)
module._buffers.pop("weight", None)
if "weight" in module.__dict__:
del module.__dict__["weight"]
param = torch.nn.Parameter(w, requires_grad=False)
module._parameters["weight"] = param
module.__dict__["weight"] = param
module.quantization_status = QuantizationStatus.FROZEN # Mark non-compressed
print(f" Restored BF16 layer: {name}")

# CASE B: Expert (stay compressed, fix metadata)
elif f"{name}.weight_shape" in checkpoint_weights:
ws = checkpoint_weights[f"{name}.weight_shape"]
# Restore int32 packed weights if present
if f"{name}.weight_packed" in checkpoint_weights:
module.weight_packed = checkpoint_weights[f"{name}.weight_packed"].to(torch.int32)
# Ensure no stale BF16 weight is registered for compressed experts
module._parameters.pop("weight", None)
module._buffers.pop("weight", None)
module.__dict__.pop("weight", None)
# Register weight_shape as int32 parameter for compressed_tensors forward
shape_param = torch.nn.Parameter(ws.to(torch.int32), requires_grad=False)
module._parameters.pop("weight_shape", None)
module.__dict__.pop("weight_shape", None)
module._parameters["weight_shape"] = shape_param
module.__dict__["weight_shape"] = shape_param
# Keep status as COMPRESSED for on-the-fly decompression

# Ensure compressed experts do not carry a stale weight attribute
for name, module in model.named_modules():
if not isinstance(module, CompressedLinear):
continue
if getattr(module, "quantization_status", None) != QuantizationStatus.COMPRESSED:
continue
module._parameters.pop("weight", None)
module._buffers.pop("weight", None)
module.__dict__.pop("weight", None)

def get_model(
ckpt_path,
device="cuda",
@@ -289,6 +444,8 @@ def get_model(
):
print(f"Initializing model from {ckpt_path}")

# Note: CompressedLinear weights will be unpacked after model loading

device_map = "auto"
if device == "cpu":
device_map = "cpu"
@@ -345,23 +502,39 @@ def get_model(
# device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors
device_map = None

# Helper function to check if model has pack-quantized config
def has_pack_quantized_config(config):
# Check top-level quantization_config
if hasattr(config, "quantization_config"):
if config.quantization_config.get("format", None) == "pack-quantized":
return True
# Check nested text_config.quantization_config (for multi-modal models like kimi k2.5)
if hasattr(config, "text_config") and hasattr(
config.text_config, "quantization_config"
):
if config.text_config.quantization_config.get("format", None) == "pack-quantized":
return True
return False
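
For context, has_pack_quantized_config only reads the "format" field; a pack-quantized checkpoint's config is expected to carry something like the following (a minimal sketch — keys other than "format" are illustrative assumptions, not taken from this PR):

# Rough shape of the relevant part of config.json for a pack-quantized model.
# Only "format" is read by the helper above; the other key is an assumption.
example_quantization_config = {
    "quant_method": "compressed-tensors",  # assumed, typical for compressed-tensors checkpoints
    "format": "pack-quantized",            # the value the helper checks for
}
# For multi-modal checkpoints the same dict may live under text_config rather than
# at the top level, which is why the helper checks both locations.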

if is_speculative(hf_config):
model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
device_map=device_map,
**model_kwargs,
)
elif (
hasattr(hf_config, "quantization_config")
and hf_config.quantization_config.get("format", None) == "pack-quantized"
):
torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
elif has_pack_quantized_config(hf_config):
# Patch CompressedLinear before loading to handle missing weight attribute
Collaborator:
I don't think you need this

_patch_compressed_linear_init()
# Pass torch_dtype="auto" to preserve original dtypes from safetensors
# This prevents int32 packed weights from being converted to float
model = AutoModelForCausalLM.from_pretrained(
ckpt_path,
device_map="auto",
trust_remote_code=trust_remote_code,
torch_dtype=torch_dtype,
torch_dtype="auto",
)
# Restore original CompressedLinear behavior after loading
_restore_compressed_linear()
else:
architecture = hf_config.architectures[0]

@@ -416,6 +589,10 @@ def get_model(
**model_kwargs,
)
model.eval()
_unpack_compressed_linear_weights(model, ckpt_path)


# Experts will be decompressed on-the-fly during calibration to save memory

# If device_map was disabled (None), manually move model to target device
if device_map is None and device != "cpu":
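Taken together, the example_utils.py changes amount to roughly this loading sequence for pack-quantized checkpoints (a condensed sketch of the hunks above, not an exact excerpt; ckpt_path and the trust_remote_code flag are illustrative):

# Condensed outline of the pack-quantized path in get_model (see hunks above).
_patch_compressed_linear_init()          # tolerate the missing .weight during HF init
model = AutoModelForCausalLM.from_pretrained(
    ckpt_path,                           # illustrative checkpoint path
    device_map="auto",
    trust_remote_code=True,              # illustrative; the example passes its own flag
    torch_dtype="auto",                  # keep int32 packed weights from being cast to float
)
_restore_compressed_linear()             # undo the __getattr__ patch once loading is done
model.eval()
_unpack_compressed_linear_weights(model, ckpt_path)  # restore BF16 layers, fix expert metadata
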
82 changes: 80 additions & 2 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -674,23 +674,101 @@ def _setup(self):

def forward(self, input: Tensor) -> Tensor:
from compressed_tensors.quantization import QuantizationStatus
import torch

if self.quantization_status == QuantizationStatus.COMPRESSED:
weight_data = self.compressor.decompress_module(self)
# Check if we should use decompress_module or manual decompress_weight
Collaborator:
is this specific to kimi k2.5?

# Real packed weights are int32. If it's float, it's not actually compressed.
if self.weight_packed.dtype == torch.int32:
compressed_data = {"weight_packed": self.weight_packed}
if hasattr(self, "weight_scale"):
compressed_data["weight_scale"] = self.weight_scale
if hasattr(self, "weight_shape"):
ws = self.weight_shape
if isinstance(ws, torch.Tensor):
compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
else:
compressed_data["weight_shape"] = [int(x) for x in ws]
if hasattr(self, "weight_zero_point"):
compressed_data["weight_zero_point"] = self.weight_zero_point

quant_args = None
if hasattr(self, "quantization_scheme") and self.quantization_scheme:
if hasattr(self.quantization_scheme, "weights"):
quant_args = self.quantization_scheme.weights

if not hasattr(self, "_logged_on_the_fly"):
print(f"[on-the-fly-decompress] {self.__class__.__name__}")
self._logged_on_the_fly = True
weight_data = self.compressor.decompress_weight(
compressed_data=compressed_data,
quantization_args=quant_args,
)
else:
# If it's not int32, just use it as-is
weight_data = self.weight_packed
else:
# Standard path for non-compressed layers
weight_data = self.weight

return linear(self.input_quantizer(input), self.weight_quantizer(weight_data), self.bias)

def unpack_weight(self):
import torch
from compressed_tensors.quantization import QuantizationStatus

if self.quantization_status == QuantizationStatus.COMPRESSED:
self.weight = nn.Parameter(self.compressor.decompress_module(self), requires_grad=False)
# Build compressed_data dict manually to handle weight_shape tensor issue
compressed_data = {}
compressed_data["weight_packed"] = self.weight_packed
if hasattr(self, "weight_scale"):
compressed_data["weight_scale"] = self.weight_scale
if hasattr(self, "weight_shape"):
ws = self.weight_shape
if isinstance(ws, torch.Tensor):
compressed_data["weight_shape"] = [int(x) for x in ws.tolist()]
elif isinstance(ws, (list, tuple)):
compressed_data["weight_shape"] = [int(x) for x in ws]
else:
compressed_data["weight_shape"] = ws
if hasattr(self, "weight_zero_point"):
compressed_data["weight_zero_point"] = self.weight_zero_point

# Skip non-pack-quantized weights (e.g., vision modules use BF16)
if isinstance(compressed_data["weight_packed"], torch.Tensor):
if compressed_data["weight_packed"].dtype != torch.int32:
return

# Get quantization args
quant_args = None
if hasattr(self, "quantization_scheme") and self.quantization_scheme:
if hasattr(self.quantization_scheme, "weights"):
quant_args = self.quantization_scheme.weights

# Decompress
decompressed = self.compressor.decompress_weight(
compressed_data=compressed_data,
quantization_args=quant_args,
)
# Avoid register_parameter errors if a placeholder already exists
self._parameters.pop("weight", None)
self._buffers.pop("weight", None)
if "weight" in self.__dict__:
del self.__dict__["weight"]
param = nn.Parameter(decompressed, requires_grad=False)
self._parameters["weight"] = param
self.__dict__["weight"] = param
if hasattr(self, "weight_packed"):
del self.weight_packed
if hasattr(self, "weight_scale"):
del self.weight_scale
if hasattr(self, "weight_shape"):
if "weight_shape" in self._parameters:
del self._parameters["weight_shape"]
else:
delattr(self, "weight_shape")
if self.quantization_status == QuantizationStatus.COMPRESSED:
self.quantization_status = QuantizationStatus.FROZEN
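
Nothing in this diff calls unpack_weight yet; one plausible caller (an assumption, not something this PR does, and assuming model is an already-loaded quantized model) would eagerly materialize dense weights for every still-compressed linear, for example before export:

# Hypothetical usage of unpack_weight (not part of this PR): decompress every
# still-compressed linear so downstream code sees a plain dense .weight.
from compressed_tensors.quantization import QuantizationStatus

for name, module in model.named_modules():
    if (
        hasattr(module, "unpack_weight")
        and getattr(module, "quantization_status", None) == QuantizationStatus.COMPRESSED
    ):
        module.unpack_weight()  # replaces weight_packed/scale/shape with a dense weight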


try:
2 changes: 1 addition & 1 deletion modelopt/torch/utils/dataset_utils.py
@@ -73,7 +73,7 @@
"preprocess": lambda sample: "\n".join(turn["value"] for turn in sample["conversations"]),
},
"cnn_dailymail": {
"config": {"path": "cnn_dailymail", "name": "3.0.0", "split": ["train"]},
"config": {"path": "abisee/cnn_dailymail", "name": "3.0.0", "split": ["train"]},
"preprocess": lambda sample: sample["article"],
},
"pile": {
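
The dataset_utils.py change points the calibration config at the dataset's current Hub location under the abisee namespace; a quick way to verify the new path resolves (assuming the datasets library is installed and the Hub is reachable):

# Sanity check for the renamed calibration dataset (needs network access).
from datasets import load_dataset

ds = load_dataset("abisee/cnn_dailymail", name="3.0.0", split="train")
print(ds[0]["article"][:200])  # the preprocess above reads the "article" field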