From d8721d7f17499da9401564be47c8a0f037d0c71d Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Tue, 2 Dec 2025 20:30:12 -0800
Subject: [PATCH] add static quant

---
 .../workflows/int8/test_int8_tensor.py        | 14 +---
 torchao/quantization/__init__.py              |  2 +
 torchao/quantization/quant_api.py             | 71 +++++++++++++++++--
 .../quantize_/workflows/int8/int8_tensor.py   | 51 +++++++++----
 4 files changed, 106 insertions(+), 32 deletions(-)

diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 41583ad743..6f08376038 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -31,13 +31,13 @@
         version=2, granularity=PerTensor(), act_mapping_type=MappingType.ASYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
+        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
+        version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
     ),
     Int8DynamicActivationInt8WeightConfig(
-        version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
+        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
     ),
 ]
 
@@ -77,14 +77,6 @@ def test_creation_and_attributes(self, config):
         elif isinstance(config.granularity, PerTensor):
             self.assertEqual(w.scale.shape, (1, 1))
 
-        if config.act_mapping_type == MappingType.SYMMETRIC:
-            self.assertEqual(w.zero_point, None)
-        elif config.act_mapping_type == MappingType.ASYMMETRIC:
-            if isinstance(config.granularity, PerRow):
-                self.assertEqual(w.zero_point.shape, (w.shape[0], 1))
-            elif isinstance(config.granularity, PerTensor):
-                self.assertEqual(w.zero_point.shape, (1, 1))
-
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
     @common_utils.parametrize("config", INT8_TEST_CONFIGS)
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 80e11dda5b..f7adfcd6e4 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -150,6 +151,7 @@
     "Int8DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
+    "Int8StaticActivationInt8WeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index fe5dfd0c74..72ac4dded7 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
     IntxPackingFormat,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
+    QuantizeTensorToInt8Kwargs,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
         )
         quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
     else:
-        from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-            QuantizeTensorToInt8Kwargs,
-        )
-
         assert config.granularity in {PerRow(), PerTensor()}, (
            "Only PerRow and PerTensor are supported"
         )
supported" ) @@ -1608,7 +1605,7 @@ def get_weight_block_size(x): granularity=config.granularity, act_quant_kwargs=QuantizeTensorToInt8Kwargs( granularity=act_granularity, - act_mapping_type=config.act_mapping_type, + mapping_type=config.act_mapping_type, ), ) @@ -1617,7 +1614,10 @@ def get_weight_block_size(x): @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig) def _int8_dynamic_activation_int8_weight_transform( - module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig + module: torch.nn.Module, + config: Int8DynamicActivationInt8WeightConfig, + *, + parameter_name="weight", ) -> torch.nn.Module: if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1634,6 +1634,65 @@ def _int8_dynamic_activation_int8_weight_transform( return module +@dataclass +class Int8StaticActivationInt8WeightConfig(AOBaseConfig): + """ + Configuration for applying float8 static symmetric quantization to + + Args: + scale (torch.Tensor): The scale tensor for activation quantization. + activation_dtype (torch.dtype): The target data type for activation quantization. Default is torch.float8_e4m + weight_dtype (torch.dtype): The target data type for weight quantization. Default is torch.float8_e4m + mm_config (Float8MMConfig): Configuration for the matrix multiplication. Default uses fast accumulation. + set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. + """ + + scale: torch.Tensor + zero_point: Optional[torch.Tensor] = None + act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC + granularity: Optional[Union[Granularity, List[Granularity]]] = PerRow() + set_inductor_config: bool = True + version: int = 1 + + def __post_init__(self): + torch._C._log_api_usage_once( + "torchao.quantization.Int8StaticActivationInt8WeightConfig" + ) + + +@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig) +def _int8_static_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig +): + assert config.granularity in {PerRow(), PerTensor()}, ( + "Only PerRow and PerTensor are supported" + ) + + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + activation_granularity, weight_granularity = _normalize_granularity( + config.granularity + ) + weight = module.weight + + # TODO: Symmentric/Asymmetric choice for weight quantization + # https://github.com/pytorch/ao/pull/3241#discussion_r2551515539 + quantized_weight = Int8Tensor.from_hp( + weight, + granularity=weight_granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + granularity=activation_granularity, + mapping_type=config.act_mapping_type, + scale=config.scale, + zero_point=config.zero_point, + ), + ) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + def int8_dynamic_activation_int8_semi_sparse_weight(): """ Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 420e7f5b81..9fe85449fd 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -37,7 +37,9 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): """ granularity: Granularity = PerRow() - act_mapping_type: 
+    mapping_type: MappingType = MappingType.SYMMETRIC
+    scale: Optional[torch.Tensor] = None
+    zero_point: Optional[torch.Tensor] = None
 
 
 class Int8Tensor(TorchAOBaseTensor):
@@ -58,6 +60,7 @@ class Int8Tensor(TorchAOBaseTensor):
     tensor_data_names = ["qdata", "scale"]
     tensor_attribute_names = []
     optional_tensor_attribute_names = [
+        "zero_point",
         "block_size",
         "act_quant_kwargs",
         "dtype",
@@ -67,6 +70,7 @@ def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor] = None,
         block_size: Optional[List[int]] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         dtype: Optional[torch.dtype] = None,
@@ -82,6 +86,7 @@ def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
+        zero_point: Optional[torch.Tensor] = None,
         block_size: Optional[List[int]] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         dtype: Optional[torch.dtype] = None,
@@ -89,6 +94,7 @@ def __init__(
         super().__init__()
         self.qdata = qdata
         self.scale = scale
+        self.zero_point = zero_point
         self.block_size = block_size
         self.act_quant_kwargs = act_quant_kwargs
 
@@ -98,6 +104,7 @@ def __repr__(self):
             f"act_quant_kwargs={self.act_quant_kwargs}, "
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
+            f"zero_point={self.zero_point}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -110,23 +117,30 @@ def from_hp(
         hp_tensor: torch.Tensor,
         granularity: Granularity = PerRow(),
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        scale: Optional[torch.Tensor] = None,
+        zero_point: Optional[torch.Tensor] = None,
+        mapping_type: MappingType = MappingType.SYMMETRIC,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        scale, zero_point = choose_qparams_affine(
-            input=hp_tensor,
-            mapping_type=MappingType.SYMMETRIC,
-            block_size=block_size,
-            target_dtype=torch.int8,
-            quant_min=-128,
-            quant_max=127,
-            scale_dtype=hp_tensor.dtype,
-            zero_point_dtype=torch.int8,
-            keepdim=True,
-        )
-
+        # if scale and zero_point are not given, choose them dynamically
+        if scale is None and zero_point is None:
+            scale, zero_point = choose_qparams_affine(
+                input=hp_tensor,
+                mapping_type=mapping_type,
+                block_size=block_size,
+                target_dtype=torch.int8,
+                quant_min=-128,
+                quant_max=127,
+                scale_dtype=hp_tensor.dtype,
+                zero_point_dtype=torch.int8,
+                keepdim=True,
+            )
+
+        # if they are given, use them to quantize directly;
+        # this is how static quantization is supported
         int_data = quantize_affine(
             hp_tensor,
             block_size=block_size,
@@ -145,11 +159,14 @@ def from_hp(
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize int8 tensor to floating point"""
+        zero_point = self.zero_point
+        if zero_point is not None:
+            zero_point = zero_point.squeeze()
         return dequantize_affine(
             input=self.qdata,
             block_size=self.block_size,
             scale=self.scale.squeeze(),
-            zero_point=None,
+            zero_point=zero_point,
             input_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
@@ -179,7 +196,11 @@ def _(func, types, args, kwargs):
 
     if weight_tensor.act_quant_kwargs is not None:
         activation_tensor = Int8Tensor.from_hp(
-            activation_tensor, weight_tensor.act_quant_kwargs.granularity
+            activation_tensor,
+            granularity=weight_tensor.act_quant_kwargs.granularity,
+            mapping_type=weight_tensor.act_quant_kwargs.mapping_type,
+            scale=weight_tensor.act_quant_kwargs.scale,
+            zero_point=weight_tensor.act_quant_kwargs.zero_point,
         )
 
     # Dynamic activation quantization path
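
Reviewer note: below is a minimal usage sketch of the new static path, not part of the patch. It assumes quantize_ routes Int8StaticActivationInt8WeightConfig through the transform added above, and that a PerTensor() activation scale of shape (1, 1) is what Int8Tensor.from_hp expects (matching the scale-shape assertions in the tests). The abs-max calibration below is a stand-in for illustration; this PR does not add an observer API.

    import torch
    from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
    from torchao.quantization.granularity import PerTensor

    # Toy model and calibration data (placeholders for illustration).
    model = torch.nn.Sequential(torch.nn.Linear(64, 32)).eval()
    calibration_batch = torch.randn(16, 64)

    # Stand-in static calibration: symmetric abs-max scale for int8 in [-128, 127].
    act_scale = (calibration_batch.abs().amax() / 127.0).reshape(1, 1)

    # The activation scale is frozen into the config; weights are quantized when
    # the transform runs, and activations reuse act_scale at inference time.
    quantize_(
        model,
        Int8StaticActivationInt8WeightConfig(
            scale=act_scale,
            granularity=PerTensor(),
        ),
    )

    with torch.no_grad():
        out = model(calibration_batch)

With the default MappingType.SYMMETRIC mapping, zero_point can stay None: since a scale is supplied, from_hp skips choose_qparams_affine and passes the precomputed parameters straight to quantize_affine.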