20 changes: 19 additions & 1 deletion src/compressed_tensors/compressors/base.py
@@ -20,6 +20,11 @@
from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig
from compressed_tensors.registry import RegistryMixin
from compressed_tensors.utils import has_offloaded_params
from compressed_tensors.utils.offload import (
delete_offload_parameter,
get_offloaded_device,
register_offload_parameter,
)
from torch import Tensor
from torch.nn import Module

@@ -185,10 +190,23 @@ def decompress_module(self, module: Module):
for name, parameter in module.named_parameters():
compressed_data[name] = parameter

return self.decompress_weight(
result = self.decompress_weight(
compressed_data=compressed_data, quantization_args=quantization_args
).to(device)

# Update module's parameters if they were unpacked/upcast during decompression
for param_name in ["weight_zero_point", "weight_scale"]:
if param_name in compressed_data and hasattr(module, param_name):
# Delete the old parameter and register the updated one
delete_offload_parameter(module, param_name)
offload_device = get_offloaded_device(module)
param = torch.nn.Parameter(
compressed_data[param_name], requires_grad=False
)
register_offload_parameter(module, param_name, param, offload_device)

return result

def decompress_weight(
self, compressed_data: Dict[str, Tensor], **kwargs
) -> torch.Tensor:
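The decompress_module change above swaps out a module's registered parameters after decompression so that an unpacked zero point or upcast scale is what later forward and save paths actually see. A minimal sketch of that swap pattern, assuming a plain (non-offloaded) module; replace_module_parameter is a hypothetical helper, not part of compressed_tensors:

import torch
from torch.nn import Linear, Parameter

def replace_module_parameter(module: torch.nn.Module, name: str, data: torch.Tensor) -> None:
    # Drop the old parameter (if any) and register a fresh one so that shape/dtype
    # changes made during decompression are visible on the module afterwards.
    if hasattr(module, name):
        delattr(module, name)
    module.register_parameter(name, Parameter(data, requires_grad=False))

layer = Linear(16, 16)
upcast_scale = torch.ones(16, 1, dtype=torch.float32)  # hypothetical upcast "weight_scale"
replace_module_parameter(layer, "weight_scale", upcast_scale)

The offload-aware version in the diff performs the same two steps through delete_offload_parameter and register_offload_parameter, so the replacement tensor lands on the module's offload device rather than on the execution device.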
47 changes: 16 additions & 31 deletions src/compressed_tensors/compressors/quantized_compressors/base.py
@@ -18,7 +18,7 @@

import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy
from compressed_tensors.quantization import QuantizationScheme
from compressed_tensors.utils import (
get_nested_mappings_from_state_dict,
get_nested_weight_mappings,
@@ -124,9 +124,21 @@ def compress(
compressed_dict[prefix + key] = value.to(compression_device)

else:
# omit saving zero points for symmetric or packed quantization
if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme):
continue
# omit saving zero points for symmetric quantization
if name.endswith("weight_zero_point"):
module_path = name.rsplit(".", 1)[0]
if (
module_path in names_to_scheme
and names_to_scheme[module_path].weights.symmetric
):
continue
# Call compress_zp if available (for PackedQuantizationCompressor)
if module_path in names_to_scheme and hasattr(self, "compress_zp"):
value = self.compress_zp(
value, names_to_scheme[module_path].weights
)
if value is None:
continue

if name.endswith("weight_scale") and self._skip_scale():
continue
@@ -140,33 +152,6 @@ def _skip_scale(self):

return isinstance(self, NVFP4PackedCompressor)

def _skip_zp(
self, name: str, names_to_scheme: Dict[str, QuantizationScheme]
) -> bool:
from compressed_tensors.compressors import PackedQuantizationCompressor

module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name)
scheme = names_to_scheme[module_name]

if zp_name == "weight_zero_point":
args = scheme.weights
if zp_name == "input_zero_point":
args = scheme.input_activations
if zp_name == "output_zero_point":
args = scheme.output_activations

symmetric = args.symmetric
packable_strategies = [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]
packed = (
isinstance(self, PackedQuantizationCompressor)
and args.strategy in packable_strategies
)

return symmetric or packed

def decompress(
self,
path_to_model_or_tensors: Union[str, Path, Dict[str, Any]],
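In the compress path above, a weight_zero_point is dropped entirely when the owning module's scheme is symmetric (the offset is identically zero), and otherwise handed to a compress_zp hook when the compressor provides one. A minimal sketch of the module-path lookup and the symmetric skip, using stand-in dataclasses rather than the real QuantizationScheme/QuantizationArgs:

from dataclasses import dataclass

@dataclass
class _Args:            # stand-in for QuantizationArgs
    symmetric: bool

@dataclass
class _Scheme:          # stand-in for QuantizationScheme
    weights: _Args

names_to_scheme = {"model.layers.0.self_attn.q_proj": _Scheme(weights=_Args(symmetric=True))}

name = "model.layers.0.self_attn.q_proj.weight_zero_point"
module_path = name.rsplit(".", 1)[0]  # strip the parameter name, keep the module path
if module_path in names_to_scheme and names_to_scheme[module_path].weights.symmetric:
    print("zero point omitted: symmetric quantization stores no offset")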
@@ -117,6 +117,12 @@ def decompress_weight(
m, n = weight.shape
# TODO: use a user provided dequant dtype
unpacked = unpack_fp4_from_uint8(weight, m, n * 2)

# cast scale dtype to match unpacked dtype for dequantization
if scale.dtype != unpacked.dtype:
scale = scale.to(unpacked.dtype)
compressed_data["weight_scale"] = scale

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype
)
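The FP4 decompression hunk above casts the stored scale to the dtype of the unpacked weights before dequantizing, and writes the cast tensor back into compressed_data so the module-level parameter refresh picks it up. A minimal sketch of why the cast matters; the fp8 scale dtype is an illustrative assumption and requires a torch build with float8 support:

import torch

unpacked = torch.randn(4, 8, dtype=torch.bfloat16)        # values as returned by unpack_fp4_from_uint8
scale = torch.full((4, 1), 0.5).to(torch.float8_e4m3fn)   # scale as it might be stored on disk (assumption)

if scale.dtype != unpacked.dtype:
    scale = scale.to(unpacked.dtype)  # align dtypes so the multiply inside dequantize is well-defined

dequantized = unpacked * scale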
@@ -175,13 +175,31 @@ def decompress_weight(
zero_point = unpack_from_int32(
zero_point, num_bits, original_zp_shape, packed_dim=0
)
# Update the compressed_data dict with the unpacked zero_point
compressed_data["weight_zero_point"] = zero_point

decompressed_weight = dequantize(
x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx
)

return decompressed_weight

def compress_zp(
self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None
) -> Optional[Tensor]:
if zero_point is None or quantization_args.symmetric:
return None
if zero_point.dtype == torch.int32:
return zero_point
if quantization_args.strategy in [
QuantizationStrategy.GROUP.value,
QuantizationStrategy.CHANNEL.value,
]:
return pack_to_int32(
zero_point, quantization_args.num_bits, packed_dim=0
).contiguous()
return zero_point


def pack_to_int32(
value: torch.Tensor,
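compress_zp gives the packed compressor an explicit hook for zero points: symmetric schemes yield None (so the caller drops the tensor), an already-int32 tensor passes through untouched, and asymmetric group/channel zero points are packed to int32 along dim 0. A rough usage sketch, assuming default construction of the compressor is valid and using made-up shapes:

import torch
from compressed_tensors.compressors import PackedQuantizationCompressor
from compressed_tensors.quantization import QuantizationArgs

compressor = PackedQuantizationCompressor()
args = QuantizationArgs(num_bits=4, symmetric=False, strategy="group", group_size=128)

zp = torch.zeros(4096, 4096 // 128, dtype=torch.int8)  # hypothetical per-group zero points
packed = compressor.compress_zp(zp, quantization_args=args)
# With symmetric args the call would return None instead of a packed tensor.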