diff --git a/src/compressed_tensors/compressors/base.py b/src/compressed_tensors/compressors/base.py index 3a8a97eb..be607f4f 100644 --- a/src/compressed_tensors/compressors/base.py +++ b/src/compressed_tensors/compressors/base.py @@ -20,6 +20,11 @@ from compressed_tensors.quantization import QuantizationArgs, QuantizationConfig from compressed_tensors.registry import RegistryMixin from compressed_tensors.utils import has_offloaded_params +from compressed_tensors.utils.offload import ( + delete_offload_parameter, + get_offloaded_device, + register_offload_parameter, +) from torch import Tensor from torch.nn import Module @@ -185,10 +190,23 @@ def decompress_module(self, module: Module): for name, parameter in module.named_parameters(): compressed_data[name] = parameter - return self.decompress_weight( + result = self.decompress_weight( compressed_data=compressed_data, quantization_args=quantization_args ).to(device) + # Update module's parameters if they were unpacked/upcast during decompression + for param_name in ["weight_zero_point", "weight_scale"]: + if param_name in compressed_data and hasattr(module, param_name): + # Delete the old parameter and register the updated one + delete_offload_parameter(module, param_name) + offload_device = get_offloaded_device(module) + param = torch.nn.Parameter( + compressed_data[param_name], requires_grad=False + ) + register_offload_parameter(module, param_name, param, offload_device) + + return result + def decompress_weight( self, compressed_data: Dict[str, Tensor], **kwargs ) -> torch.Tensor: diff --git a/src/compressed_tensors/compressors/quantized_compressors/base.py b/src/compressed_tensors/compressors/quantized_compressors/base.py index 19f6c9c0..71913e6f 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/base.py +++ b/src/compressed_tensors/compressors/quantized_compressors/base.py @@ -18,7 +18,7 @@ import torch from compressed_tensors.compressors.base import BaseCompressor -from compressed_tensors.quantization import QuantizationScheme, QuantizationStrategy +from compressed_tensors.quantization import QuantizationScheme from compressed_tensors.utils import ( get_nested_mappings_from_state_dict, get_nested_weight_mappings, @@ -124,9 +124,21 @@ def compress( compressed_dict[prefix + key] = value.to(compression_device) else: - # omit saving zero points for symmetric or packed quantization - if name.endswith("zero_point") and self._skip_zp(name, names_to_scheme): - continue + # omit saving zero points for symmetric quantization + if name.endswith("weight_zero_point"): + module_path = name.rsplit(".", 1)[0] + if ( + module_path in names_to_scheme + and names_to_scheme[module_path].weights.symmetric + ): + continue + # Call compress_zp if available (for PackedQuantizationCompressor) + if module_path in names_to_scheme and hasattr(self, "compress_zp"): + value = self.compress_zp( + value, names_to_scheme[module_path].weights + ) + if value is None: + continue if name.endswith("weight_scale") and self._skip_scale(): continue @@ -140,33 +152,6 @@ def _skip_scale(self): return isinstance(self, NVFP4PackedCompressor) - def _skip_zp( - self, name: str, names_to_scheme: Dict[str, QuantizationScheme] - ) -> bool: - from compressed_tensors.compressors import PackedQuantizationCompressor - - module_name, zp_name = name.rsplit(".", 1) if "." in name else ("", name) - scheme = names_to_scheme[module_name] - - if zp_name == "weight_zero_point": - args = scheme.weights - if zp_name == "input_zero_point": - args = scheme.input_activations - if zp_name == "output_zero_point": - args = scheme.output_activations - - symmetric = args.symmetric - packable_strategies = [ - QuantizationStrategy.GROUP.value, - QuantizationStrategy.CHANNEL.value, - ] - packed = ( - isinstance(self, PackedQuantizationCompressor) - and args.strategy in packable_strategies - ) - - return symmetric or packed - def decompress( self, path_to_model_or_tensors: Union[str, Path, Dict[str, Any]], diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py index b9bd8ede..7cdcc981 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py @@ -117,6 +117,12 @@ def decompress_weight( m, n = weight.shape # TODO: use a user provided dequant dtype unpacked = unpack_fp4_from_uint8(weight, m, n * 2) + + # cast scale dtype to match unpacked dtype for dequantization + if scale.dtype != unpacked.dtype: + scale = scale.to(unpacked.dtype) + compressed_data["weight_scale"] = scale + decompressed_weight = dequantize( x_q=unpacked, scale=scale, global_scale=global_scale, dtype=unpacked.dtype ) diff --git a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py index 07da50c7..d8560f3a 100644 --- a/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py +++ b/src/compressed_tensors/compressors/quantized_compressors/pack_quantized.py @@ -175,6 +175,8 @@ def decompress_weight( zero_point = unpack_from_int32( zero_point, num_bits, original_zp_shape, packed_dim=0 ) + # Update the compressed_data dict with the unpacked zero_point + compressed_data["weight_zero_point"] = zero_point decompressed_weight = dequantize( x_q=unpacked, scale=scale, zero_point=zero_point, g_idx=g_idx @@ -182,6 +184,22 @@ def decompress_weight( return decompressed_weight + def compress_zp( + self, zero_point: Tensor, quantization_args: Optional[QuantizationArgs] = None + ) -> Optional[Tensor]: + if zero_point is None or quantization_args.symmetric: + return None + if zero_point.dtype == torch.int32: + return zero_point + if quantization_args.strategy in [ + QuantizationStrategy.GROUP.value, + QuantizationStrategy.CHANNEL.value, + ]: + return pack_to_int32( + zero_point, quantization_args.num_bits, packed_dim=0 + ).contiguous() + return zero_point + def pack_to_int32( value: torch.Tensor,