diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 2819903e69..80147e68ba 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,11 +14,15 @@ from torchao.quantization import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.quantize_.common import (
+    _choose_quant_func_and_quantize_tensor,
+)
 from torchao.quantization.utils import compute_error, get_block_size
 from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import TorchAOIntegrationTestCase
@@ -221,5 +225,66 @@ def test_available_gpu_kernels(self):
         ).check_count("triton_poi_fused", 1).run(code[0])
 
 
+@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+@common_utils.instantiate_parametrized_tests
+class TestInt8StaticQuant(TorchAOIntegrationTestCase):
+    @common_utils.parametrize("granularity", [PerRow(), PerTensor()])
+    @common_utils.parametrize("dtype", [torch.bfloat16])
+    def test_static_activation_int8_weight(self, granularity, dtype):
+        torch.compiler.reset()
+
+        M, N, K = 32, 32, 32
+        input_tensor = torch.randn(M, K, dtype=dtype, device="cuda")
+
+        model = torch.nn.Linear(K, N, bias=False).eval().to(device="cuda", dtype=dtype)
+        model_static_quant = copy.deepcopy(model)
+        model_dynamic_quant = copy.deepcopy(model)
+
+        model_out_baseline = model(input_tensor)
+
+        dynamic_config = Int8DynamicActivationInt8WeightConfig(
+            version=2, granularity=granularity
+        )
+        quantize_(model_dynamic_quant, dynamic_config)
+
+        dynamic_out_eager = model_dynamic_quant(input_tensor)
+        sqnr_dynamic_eager = compute_error(model_out_baseline, dynamic_out_eager)
+
+        model_dynamic_quant = torch.compile(model_dynamic_quant, fullgraph=True)
+
+        dynamic_out_compile = model_dynamic_quant(input_tensor)
+        sqnr_dynamic_compile = compute_error(model_out_baseline, dynamic_out_compile)
+
+        # compute the activation scale in eager mode and reuse it for the static config
+        int8_input = _choose_quant_func_and_quantize_tensor(
+            input_tensor, model_dynamic_quant.weight.act_quant_kwargs
+        )
+
+        static_config = Int8StaticActivationInt8WeightConfig(
+            scale=int8_input.scale.detach().clone(),
+            granularity=granularity,
+        )
+        quantize_(model_static_quant, static_config)
+
+        static_out_eager = model_static_quant(input_tensor)
+        sqnr_static_eager = compute_error(model_out_baseline, static_out_eager)
+
+        model_static_quant = torch.compile(model_static_quant, fullgraph=True)
+
+        static_out_compile = model_static_quant(input_tensor)
+        sqnr_static_compile = compute_error(model_out_baseline, static_out_compile)
+
+        assert (
+            sqnr_static_compile
+            == sqnr_static_eager
+            == sqnr_dynamic_compile
+            == sqnr_dynamic_eager
+        ), "SQNR should be the same for all quantization methods and eager/compile"
+
+        # eager numerics should match exactly
+        # for compile, we can't compare dynamic vs static because we may get slightly different qparams when fused
+        torch.testing.assert_close(dynamic_out_eager, static_out_eager)
+
+
 if __name__ == "__main__":
     common_utils.run_tests()
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 80e11dda5b..f7adfcd6e4 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
@@ -150,6 +151,7 @@
     "Int8DynamicActivationInt4WeightConfig",
     "Int8DynamicActivationInt8WeightConfig",
     "Int8DynamicActivationIntxWeightConfig",
+    "Int8StaticActivationInt8WeightConfig",
     "Int4WeightOnlyConfig",
     "Float8DynamicActivationInt4WeightConfig",
     "Int8WeightOnlyConfig",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 24d6b6676c..e19e35c20c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
     IntxPackingFormat,
     IntxUnpackedToInt8Tensor,
     QuantizeTensorToFloat8Kwargs,
+    QuantizeTensorToInt8Kwargs,
 )
 from torchao.quantization.transform_module import (
     _QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
         )
         quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
     else:
-        from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
-            QuantizeTensorToInt8Kwargs,
-        )
-
         assert config.granularity in {PerRow(), PerTensor()}, (
             "Only PerRow and PerTensor are supported"
         )
@@ -1621,7 +1618,10 @@ def get_weight_block_size(x):
 
 @register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
 def _int8_dynamic_activation_int8_weight_transform(
-    module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
+    module: torch.nn.Module,
+    config: Int8DynamicActivationInt8WeightConfig,
+    *,
+    parameter_name="weight",
 ) -> torch.nn.Module:
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,7 +1634,88 @@ def _int8_dynamic_activation_int8_weight_transform(
         module.weight, config
     )
     module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
-    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
     return module
 
 
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static symmetric quantization to both activations and weights.
+
+    Args:
+        scale (torch.Tensor): The precomputed scale tensor for static activation quantization.
+        granularity (Granularity): The granularity of quantization. PerRow() and PerTensor() are supported currently
+        act_mapping_type (MappingType): The mapping type for activation quantization. Only SYMMETRIC is supported currently
+        set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values.
+        version (int): the version of the config
+    """
+
+    scale: torch.Tensor
+    granularity: Granularity = PerRow()
+    act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
+    set_inductor_config: bool = True
+    version: int = 1
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.quantization.Int8StaticActivationInt8WeightConfig"
+        )
+
+
+@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
+def _int8_static_activation_int8_weight_transform(
+    module: torch.nn.Module,
+    config: Int8StaticActivationInt8WeightConfig,
+    *,
+    parameter_name="weight",
+):
+    assert config.granularity in {PerRow(), PerTensor()}, (
+        "Only PerRow and PerTensor are supported currently"
+    )
+    assert config.act_mapping_type == MappingType.SYMMETRIC, (
+        "Asymmetric static quantization is not supported currently"
+    )
+    assert hasattr(module, parameter_name), (
+        f"Expected module to have attribute `{parameter_name}` but it was not found"
+    )
+
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    activation_granularity = config.granularity
+    weight_granularity = config.granularity
+
+    quantized_tensor = Int8Tensor.from_hp(
+        getattr(module, parameter_name),
+        granularity=weight_granularity,
+        act_quant_kwargs=QuantizeTensorToInt8Kwargs(
+            granularity=activation_granularity,
+            mapping_type=config.act_mapping_type,
+        ),
+        act_scale=config.scale.detach(),
+    )
+
+    setattr(
+        module,
+        parameter_name,
+        torch.nn.Parameter(quantized_tensor, requires_grad=False),
+    )
+    module.extra_repr = types.MethodType(
+        partial(
+            _module_extra_repr,
+            original_extra_repr=module.extra_repr,
+            parameter_name=parameter_name,
+        ),
+        module,
+    )
+    return module
 
 
diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
index e4544a2f0c..06a289ab4a 100644
--- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
+++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import abc
-from typing import ClassVar
+from typing import ClassVar, Optional
 
 import torch
 
@@ -31,7 +31,9 @@ def from_hp(cls, tensor, quant_kwargs: QuantizeTensorKwargs)
 
 
 def _choose_quant_func_and_quantize_tensor(
-    tensor: torch.Tensor, quant_kwargs: QuantizeTensorKwargs
+    tensor: torch.Tensor,
+    quant_kwargs: QuantizeTensorKwargs,
+    scale: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     """Given a tensor and a kwargs container, chooses a derived dtype (float8, int8, etc) to quantize tensor to, based on the type of quant_kwargs
     quantizes tensor to the derived dtype chosen in (1)
@@ -60,6 +62,7 @@ def _choose_quant_func_and_quantize_tensor(
             tensor,
             quant_kwargs.granularity,
             mapping_type=quant_kwargs.mapping_type,
+            scale=scale,
         )
 
     raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
index dd422b90f6..ca16fa6326 100644
--- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -53,15 +53,15 @@ class Int8Tensor(TorchAOBaseTensor):
     Tensor Attributes:
         qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D)
         scale: scale factors for dequantization
-    # TODO: Static quantization support using `static_scale`
+        act_scale: optional precomputed activation scale for static activation quantization
 
     Non-Tensor Attributes:
         granularity: the granularity for quantization (e.g., PerRow(), PerTensor())
         act_quant_kwargs: flags for dynamic activation quantization
     """
 
-    # TODO: Static quantization support using `static_scale`
     tensor_data_names = ["qdata", "scale"]
+    optional_tensor_data_names = ["act_scale"]
     tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
@@ -73,6 +72,7 @@ def __new__(
         cls,
         qdata: torch.Tensor,
         scale: torch.Tensor,
         block_size: List[int],
         dtype: torch.dtype,
+        act_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         kwargs = {
@@ -88,6 +88,7 @@ def __init__(
         qdata: torch.Tensor,
         scale: torch.Tensor,
         block_size: List[int],
         dtype: torch.dtype,
+        act_scale: Optional[torch.Tensor] = None,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
     ):
         super().__init__()
@@ -96,6 +97,7 @@ def __init__(
         self.block_size = block_size
         # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
+        self.act_scale = act_scale
 
     def __repr__(self):
         return (
@@ -103,6 +105,7 @@ def __repr__(self):
             f"act_quant_kwargs={self.act_quant_kwargs}, "
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
+            f"act_scale={self.act_scale}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -114,24 +117,35 @@ def from_hp(
         cls,
         hp_tensor: torch.Tensor,
         granularity: Granularity,
-        act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
         mapping_type=MappingType.SYMMETRIC,
+        scale: Optional[torch.Tensor] = None,
+        act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        act_scale: Optional[torch.Tensor] = None,
     ):
         """Create Int8Tensor from high-precision tensor"""
         block_size = get_block_size(hp_tensor.shape, granularity)
         block_size = list(block_size)
 
-        scale, zero_point = choose_qparams_affine(
-            input=hp_tensor,
-            mapping_type=mapping_type,
-            block_size=block_size,
-            target_dtype=torch.int8,
-            quant_min=-128,
-            quant_max=127,
-            scale_dtype=hp_tensor.dtype,
-            zero_point_dtype=torch.int8,
-            keepdim=True,
-        )
+        if scale is None:
+            scale, zero_point = choose_qparams_affine(
+                input=hp_tensor,
+                mapping_type=mapping_type,
+                block_size=block_size,
+                target_dtype=torch.int8,
+                quant_min=-128,
+                quant_max=127,
+                scale_dtype=hp_tensor.dtype,
+                zero_point_dtype=torch.int8,
+                keepdim=True,
+            )
+        else:
+            # Scale can be provided in the case of static quant
+            assert scale.ndim == hp_tensor.ndim
+            assert all(
+                (hp_tensor.shape[i] // block_size[i]) == scale.shape[i]
+                for i in range(hp_tensor.ndim)
+            )
+            zero_point = torch.zeros_like(scale, dtype=torch.int8)
 
         int_data = quantize_affine(
             hp_tensor,
@@ -146,6 +160,7 @@
             scale,
             block_size,
             hp_tensor.dtype,
+            act_scale=act_scale,
             act_quant_kwargs=act_quant_kwargs,
         )
 
@@ -185,7 +200,9 @@ def _(func, types, args, kwargs):
 
     if weight_tensor.act_quant_kwargs is not None:
         activation_tensor = _choose_quant_func_and_quantize_tensor(
-            activation_tensor, weight_tensor.act_quant_kwargs
+            activation_tensor,
+            weight_tensor.act_quant_kwargs,
+            scale=weight_tensor.act_scale,
         )
         # Dynamic activation quantization path
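Usage sketch (not part of the patch): the calibration flow below mirrors TestInt8StaticQuant, reusing the activation scale observed on a calibration batch to build the static config. The layer sizes and the names `reference` and `calibration_batch` are illustrative placeholders, not APIs added by this diff; PerTensor() is used so the static activation scale does not depend on the batch shape.

import copy

import torch

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    quantize_,
)
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quantize_.common import (
    _choose_quant_func_and_quantize_tensor,
)

# Placeholder model and calibration data (assumes a CUDA device, as in the test).
model = (
    torch.nn.Linear(512, 512, bias=False)
    .to(device="cuda", dtype=torch.bfloat16)
    .eval()
)
calibration_batch = torch.randn(32, 512, device="cuda", dtype=torch.bfloat16)

# Step 1: quantize a throwaway copy dynamically to obtain act_quant_kwargs,
# then quantize the calibration batch once to observe its activation scale.
reference = copy.deepcopy(model)
quantize_(
    reference,
    Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerTensor()),
)
int8_calibration = _choose_quant_func_and_quantize_tensor(
    calibration_batch, reference.weight.act_quant_kwargs
)

# Step 2: bake the observed scale into the static config and apply it to the real model.
quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(
        scale=int8_calibration.scale.detach().clone(),
        granularity=PerTensor(),
    ),
)

output = model(calibration_batch)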