diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py
index faf8231342..30f6e6f285 100644
--- a/test/prototype/test_smoothquant.py
+++ b/test/prototype/test_smoothquant.py
@@ -15,11 +15,13 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.linear_activation_scale import (
     WeightTensorWithLinearActivationScaleMetadata,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
 )
 from torchao.quantization.utils import (
     compute_error as SQNR,
@@ -83,7 +85,10 @@ def setUpClass(cls):
     @common_utils.parametrize(
         "base_config",
         [
-            Int8DynamicActivationInt8WeightConfig(),
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            # TODO: not sure if we should allow not passing scales as part of static config?
+            Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
+            Int8StaticActivationInt8WeightConfig(granularity=PerTensor()),
             # Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
@@ -101,7 +106,15 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

         # Step 1. Basic quantization
         basic_model = deepcopy(m)
-        quantize_(basic_model, base_config)
+        if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+            quantize_(
+                basic_model,
+                Int8DynamicActivationInt8WeightConfig(
+                    version=2, granularity=base_config.granularity
+                ),
+            )
+        else:
+            quantize_(basic_model, base_config)
         out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()

diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9f78c49fb8..855340fb14 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -15,8 +15,12 @@
 )
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
+from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
+    QuantizeTensorToInt8Kwargs,
+)
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -95,8 +99,18 @@ def _smooth_quant_transform(
     else:
         raise ValueError(f"Unexpected step: {step}")

+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        quant_kwargs = QuantizeTensorToInt8Kwargs(
+            granularity=base_config.granularity,
+            mapping_type=base_config.act_mapping_type,
+        )
+    else:
+        quant_kwargs = None
+
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
+        weight_quant_kwargs=quant_kwargs
+    )
     weight = observed_linear.weight * smoothing_factor

     # Create new linear layer
@@ -111,6 +125,9 @@ def _smooth_quant_transform(
         linear.bias = observed_linear.bias

     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.scale = activation_scale
+
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
@@ -120,6 +137,7 @@ def _smooth_quant_transform(
     qw = to_weight_tensor_with_linear_activation_scale_metadata(
         qw, smoothing_factor.to(qw.dtype)
     )
+
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)

diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py
index 83f1e78275..9974bf3719 100644
--- a/torchao/prototype/smoothquant/core.py
+++ b/torchao/prototype/smoothquant/core.py
@@ -9,6 +9,10 @@
 import torch
 import torch.nn.functional as F

+from torchao.quantization.quantize_.common import (
+    _choose_quant_func_and_quantize_tensor,
+)
+

 class SmoothQuantStep(str, Enum):
     PREPARE = "prepare"
@@ -41,13 +45,14 @@ def forward(self, input: torch.Tensor):
         self.inputs.append(input.to("cpu"))
         return input

-    def calculate_qparams(self):
+    def calculate_qparams(self, weight_quant_kwargs=None):
         assert self.inputs and len(self.inputs) > 0, (
             "calibrate observer first by running model on exemplar data"
         )
         inputs = [inp.to(self.device) for inp in self.inputs]
         acc = torch.cat(inputs, dim=0)
         # Reshape if needed: [batch, seq, features] -> [batch*seq, features]
+        example_input_for_quantization = acc
         if acc.ndim > 2:
             acc = acc.view(-1, acc.shape[-1])

@@ -57,12 +62,20 @@ def calculate_qparams(self):

         # Calculate smoothing factor
         if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            smoothing_factor = torch.ones_like(x_abs_max)
+        else:
+            eps = torch.finfo(torch.float32).eps
+            smoothing_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
+                w_abs_max + eps, 1 - self.alpha
+            )

-        eps = torch.finfo(torch.float32).eps
-        return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
-            w_abs_max + eps, 1 - self.alpha
-        )
+        if weight_quant_kwargs is not None:
+            quant_smooth_activation = _choose_quant_func_and_quantize_tensor(
+                example_input_for_quantization / smoothing_factor, weight_quant_kwargs
+            )
+            return smoothing_factor, quant_smooth_activation.scale
+        else:
+            return smoothing_factor, None


 class SmoothQuantObservedLinear(torch.nn.Linear):
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e19e35c20c..c7ca2d34cc 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1658,7 +1658,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
         version (int): the version of the config
     """

-    scale: torch.Tensor
+    scale: torch.Tensor = None
     granularity: Granularity = PerRow()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     set_inductor_config: bool = True
diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
index ca16fa6326..40e0ecae56 100644
--- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-
+import math
 from dataclasses import dataclass
 from typing import List, Optional

@@ -199,12 +199,13 @@ def _(func, types, args, kwargs):
     output_dtype = activation_tensor.dtype

     if weight_tensor.act_quant_kwargs is not None:
+        # for int8 dynamic + static quantization path
+
         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor,
             weight_tensor.act_quant_kwargs,
             scale=weight_tensor.act_scale,
         )
-    # Dynamic activation quantization path

     # 1. do the matrix form of dot(X_i, W_j)
     #
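
For reviewers, here is a minimal usage sketch of the prepare → calibrate → convert flow this patch enables. It assumes the prototype's `SmoothQuantConfig(base_config, step, alpha)` entry point and a `SmoothQuantStep.CONVERT` step (only `PREPARE` is visible in the hunks above); the model and calibration data are placeholders:

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

# Toy model and calibration data, purely illustrative.
model = torch.nn.Sequential(torch.nn.Linear(64, 64)).eval()
calibration_batches = [torch.randn(8, 64) for _ in range(4)]

# `scale` is left unset; with this patch the convert step computes the static
# activation scale from the observer and writes it into the config.
base_config = Int8StaticActivationInt8WeightConfig(granularity=PerRow())

# PREPARE: swap Linear modules for SmoothQuantObservedLinear observers.
quantize_(
    model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE)
)

# Calibrate: observers record activation statistics on exemplar data.
with torch.no_grad():
    for batch in calibration_batches:
        model(batch)

# CONVERT: calculate_qparams() now returns (smoothing_factor, activation_scale),
# and the activation scale is installed on base_config before weight quantization.
quantize_(
    model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.CONVERT)
)
```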