diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py
index 70da622c73..dbe0606eda 100644
--- a/test/integration/test_integration.py
+++ b/test/integration/test_integration.py
@@ -43,11 +43,13 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
+    Int8StaticActivationInt8WeightConfig,
     _replace_with_custom_fn_if_matches_filter,
     quantize_,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
+    choose_qparams_affine,
     dequantize_affine,
 )
 from torchao.quantization.smoothquant import (
@@ -1004,6 +1006,25 @@ def test_dynamic_quant(self):
         sqnr = compute_error(y_ref, y_test)
         self.assertGreater(sqnr, 40.0)
+class TestStaticQuant(unittest.TestCase):
+    def test_static_quant(self):
+        M, K, N = 8, 16, 8
+        x = torch.randn(M, K)
+        m = nn.Sequential(nn.Linear(K, N))
+        block_size = [M, K]  # per-tensor quantization
+        scale, _ = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=block_size,
+            target_dtype=torch.int8,
+        )
+
+        y_ref = m(x)
+        quantize_(m, Int8StaticActivationInt8WeightConfig(act_quant_scale=scale))
+        y_test = m(x)
+
+        sqnr = compute_error(y_ref, y_test)
+        self.assertGreater(sqnr, 40.0)
 
 
 class TestWeightOnlyInt8Quant(unittest.TestCase):
     def test_weight_only_quant(self):
diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py
index faf8231342..6a485f4e27 100644
--- a/test/prototype/test_smoothquant.py
+++ b/test/prototype/test_smoothquant.py
@@ -20,6 +20,7 @@
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig
 )
 from torchao.quantization.utils import (
     compute_error as SQNR,
 )
@@ -84,6 +85,7 @@ def setUpClass(cls):
     "base_config",
     [
         Int8DynamicActivationInt8WeightConfig(),
+        Int8StaticActivationInt8WeightConfig(),
         # Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
         # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
     ],
@@ -139,6 +141,7 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
     "base_config",
     [
         Int8DynamicActivationInt8WeightConfig(),
+        Int8StaticActivationInt8WeightConfig(),
         # TODO: Check more quantization APIs
     ],
 )
@@ -178,6 +181,7 @@ def test_observer_insertion(self, base_config):
     "base_config",
     [
         Int8DynamicActivationInt8WeightConfig(),
+        Int8StaticActivationInt8WeightConfig(),
         # TODO: Check more quantization APIs
     ],
 )
diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 2acdff2b84..df32b3230a 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,6 +14,7 @@
 from torchao.quantization import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
 )
@@ -66,6 +67,7 @@ def test_creation_and_attributes(self, config):
     "config",
     [
         Int8DynamicActivationInt8WeightConfig(version=2),
+        Int8StaticActivationInt8WeightConfig(version=2),
         Int8WeightOnlyConfig(version=2),
     ],
 )
@@ -108,6 +110,7 @@ def test_int8_linear_variants(
     "config",
     [
         Int8DynamicActivationInt8WeightConfig(version=2),
+        Int8StaticActivationInt8WeightConfig(version=2),
         Int8WeightOnlyConfig(version=2),
     ],
 )
@@ -173,6 +176,7 @@ def test_index_select(self, config, granularity):
     "config",
     [
         Int8DynamicActivationInt8WeightConfig(version=2),
+        Int8StaticActivationInt8WeightConfig(version=2),
         Int8WeightOnlyConfig(version=2),
     ],
 )
diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py
index 292b67380d..aca95eb403 100644
--- a/torchao/kernel/intmm.py
+++ b/torchao/kernel/intmm.py
@@ -127,7 +127,8 @@ def int_scaled_matmul(
     assert M == scales1.size(0) or scales1.numel() == 1
     assert 1 == scales1.size(1)
     assert scales1.is_contiguous()
-    scales1 = scales1.expand((M, N))
+    if scales1.device.type != "cpu":
+        scales1 = scales1.expand((M, N))
     assert scales1.dim() == 2
 
     if check_cpu_version(scales1.device):
diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md
index 00e819c438..c5aeaea78a 100644
--- a/torchao/prototype/smoothquant/README.md
+++ b/torchao/prototype/smoothquant/README.md
@@ -50,6 +50,7 @@ for data in calibration_dataset:
 
 quant_config.step = SmoothQuantStep.CONVERT
 quantize_(model, quant_config)
 ```
+For static activation quantization, use `Int8StaticActivationInt8WeightConfig` instead of `Int8DynamicActivationInt8WeightConfig`. Static quantization generally gives higher throughput at the cost of some accuracy (higher perplexity).
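+
+The flow mirrors the dynamic example above. A minimal sketch (assuming `model` and `calibration_dataset` are defined as in the earlier example, and that `SmoothQuantConfig` is importable as in `example.py`):
+
+```python
+from torchao.prototype.smoothquant import SmoothQuantConfig
+from torchao.prototype.smoothquant.core import SmoothQuantStep
+from torchao.quantization import quantize_
+from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig
+
+# Prepare: insert observers that record activation statistics
+quant_config = SmoothQuantConfig(
+    base_config=Int8StaticActivationInt8WeightConfig(),
+    step=SmoothQuantStep.PREPARE,
+)
+quantize_(model, quant_config)
+
+# Calibrate on exemplar data
+for data in calibration_dataset:
+    model(data)
+
+# Convert: fold smoothing factors into the weights and fix a static per-tensor activation scale
+quant_config.step = SmoothQuantStep.CONVERT
+quantize_(model, quant_config)
+```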
 
 ## Benchmarks
diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py
index 9f78c49fb8..0917792265 100644
--- a/torchao/prototype/smoothquant/api.py
+++ b/torchao/prototype/smoothquant/api.py
@@ -15,6 +15,7 @@
 )
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
 from torchao.quantization.transform_module import (
@@ -96,21 +97,17 @@ def _smooth_quant_transform(
         raise ValueError(f"Unexpected step: {step}")
 
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    act_quant_min, act_quant_max = None, None
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        act_quant_min, act_quant_max = -127, 127
+    smoothing_factor, act_quant_scale = observed_linear.obs.calculate_qparams(
+        act_quant_min, act_quant_max
+    )
     weight = observed_linear.weight * smoothing_factor
 
-    # Create new linear layer
-    with torch.device("meta"):
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-    linear.bias = observed_linear.bias
-
     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.act_quant_scale = act_quant_scale
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
@@ -120,7 +117,18 @@
         qw = to_weight_tensor_with_linear_activation_scale_metadata(
             qw, smoothing_factor.to(qw.dtype)
         )
+
+    # Create new linear layer
+    with torch.device("meta"):
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            observed_linear.bias is not None,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
+    linear.bias = observed_linear.bias
 
     return linear
diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py
index 83f1e78275..ebaae134e0 100644
--- a/torchao/prototype/smoothquant/core.py
+++ b/torchao/prototype/smoothquant/core.py
@@ -41,7 +41,7 @@ def forward(self, input: torch.Tensor):
         self.inputs.append(input.to("cpu"))
         return input
 
-    def calculate_qparams(self):
+    def calculate_qparams(self, act_quant_min=None, act_quant_max=None):
         assert self.inputs and len(self.inputs) > 0, (
             "calibrate observer first by running model on exemplar data"
         )
@@ -57,12 +57,20 @@ def calculate_qparams(self):
 
         # Calculate smoothing factor
         if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            smooth_factor = torch.ones_like(x_abs_max)
+        else:
+            eps = torch.finfo(torch.float32).eps
+            smooth_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
+                w_abs_max + eps, 1 - self.alpha
+            )
 
-        eps = torch.finfo(torch.float32).eps
-        return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
-            w_abs_max + eps, 1 - self.alpha
-        )
+        # Calculate per-tensor act_quant_scale (max abs value over the half range)
+        act_quant_scale = None
+        if act_quant_min is not None and act_quant_max is not None:
+            x_abs_max_t = acc.abs().max()
+            act_quant_scale = (x_abs_max_t / ((act_quant_max - act_quant_min) / 2)).item()
+
+        return smooth_factor, act_quant_scale
 
 
 class SmoothQuantObservedLinear(torch.nn.Linear):
diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py
index 8602b57e20..475f942ea0 100644
--- a/torchao/prototype/smoothquant/example.py
+++ b/torchao/prototype/smoothquant/example.py
@@ -16,8 +16,10 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
-from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
-
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
+)
 # TODO: Build benchmark within vLLM ecosystem with more quantization APIs
 # See https://github.com/pytorch/ao/issues/2815 for more details
 
@@ -82,6 +84,8 @@ def quantize_and_eval(
     device: str,
     model_save_path: str,
     model_save_hf_hub_path: str,
+    static_quant_act: bool,
+    compile: bool,
 ):
     print(f"Loading model on {device}...")
     torch.manual_seed(34)
@@ -96,9 +100,14 @@
     # Step 1: Prepare - insert observers
     print("running SmoothQuant prepare and calibrate")
+    base_config = (
+        Int8StaticActivationInt8WeightConfig()
+        if static_quant_act
+        else Int8DynamicActivationInt8WeightConfig()
+    )
     t0 = time.time()
     quant_config = SmoothQuantConfig(
-        base_config=Int8DynamicActivationInt8WeightConfig(),
+        base_config=base_config,
         step=SmoothQuantStep.PREPARE,
         alpha=alpha,
     )
@@ -133,6 +142,8 @@
         print("pushing model to hub:", model_save_hf_hub_path)
         model.push_to_hub(model_save_hf_hub_path, safe_serialization=False)
         tokenizer.push_to_hub(model_save_hf_hub_path)
+    if compile:
+        model.forward = torch.compile(model.forward, dynamic=True)
 
     print("Benchmarking SmoothQuant model...")
     return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
@@ -147,6 +158,8 @@ def compare_models(
     device: str,
     model_save_path: str,
     model_save_hf_hub_path: str,
+    static_quant_act: bool,
+    compile: bool,
 ):
     """Compare perplexity and speed for benchmarking SmoothQuant"""
 
@@ -159,6 +172,8 @@
         .eval()
         .to(device)
     )
+    if compile:
+        model.forward = torch.compile(model.forward, dynamic=True)
     base_results = benchmark(
         model, tokenizer, max_seq_length, tasks=tasks, device=device
     )
@@ -172,6 +187,8 @@
         .to(device)
     )
     quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig())
+    if compile:
+        w8a8_model.forward = torch.compile(w8a8_model.forward, dynamic=True)
     w8a8_results = benchmark(
         w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device
     )
@@ -187,6 +204,8 @@
         device,
         model_save_path,
         model_save_hf_hub_path,
+        static_quant_act,
+        compile,
     )
 
     # Calculate changes and display results
@@ -289,6 +308,16 @@ def create_parser() -> argparse.ArgumentParser:
         default=None,
         help="Huggingface hub path to store the quantized model and tokenizer.",
     )
+    parser.add_argument(
+        "--static-quant-act",
+        action="store_true",
+        help="Use static quantization of activation instead of dynamic quantization.",
+    )
+    parser.add_argument(
+        "--compile",
+        action="store_true",
+        help="Use torch.compile to compile the model for potentially better performance.",
+    )
 
     return parser
 
@@ -306,4 +335,6 @@ def create_parser() -> argparse.ArgumentParser:
         args.device,
         args.model_save_path,
         args.model_save_hf_hub_path,
+        args.static_quant_act,
+        args.compile,
     )
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index a1ca6b0b94..5c607c4a79 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
     Int8DynamicActivationInt4WeightConfig,
     Int8DynamicActivationInt8WeightConfig,
     Int8DynamicActivationIntxWeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     IntxWeightOnlyConfig,
     ModuleFqnToConfig,
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9a1bfeb0a5..ff371ad9fb 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -50,6 +50,7 @@
     to_affine_quantized_floatx,
     to_affine_quantized_floatx_static,
     to_affine_quantized_intx,
+    to_affine_quantized_intx_static,
     to_marlinqqq_quantized_intx,
 )
 from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import (
@@ -476,6 +477,7 @@ def quantize_(
         # currently options are
         # Int8DynamicActivationInt4WeightConfig (for executorch)
         # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
+        # Int8StaticActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
         # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
         # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
         from torchao.quantization.quant_api import int4_weight_only
@@ -1649,6 +1651,132 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
 
 
+def _activation_static_sym_quant_func_int8(
+    x: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    assert zero_point is None, "Zero point must be None"
+    quant_min = -127
+    quant_max = 127
+    target_dtype = torch.int8
+    zero_point_domain = ZeroPointDomain.NONE
+
+    return to_affine_quantized_intx_static(
+        x,
+        scale=scale,
+        zero_point=zero_point,
+        block_size=get_block_size(x.shape, PerTensor()),
+        target_dtype=target_dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        zero_point_domain=zero_point_domain,
+    )
+
+
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static activation and int8 per-channel weight
+    quantization to linear layers. The activation is always quantized with symmetric per-tensor quantization.
+    Args:
+        layout: Optional[Layout] = PlainLayout() - Tensor layout for the quantized weights. Controls how the
+            quantized data is stored and accessed. Only PlainLayout is supported now.
+        granularity: Optional[Granularity] = PerRow() - Quantization granularity for the weight.
+        set_inductor_config: bool = False - If True, adjusts `torchinductor` settings to recommended values
+            for better performance with this quantization scheme.
+        version (int): the version of the config; version 1 uses AffineQuantizedTensor (planned for deprecation/splitting), version 2 uses Int8Tensor.
+        act_quant_scale: Optional[float] = None - Pre-computed scale for static activation quantization.
+        act_quant_zero_point: Optional[float] = None - Pre-computed zero point for static activation quantization.
+ """ + + layout: Optional[Layout] = PlainLayout() + granularity: Optional[Granularity] = PerRow() + set_inductor_config: bool = False + version: int = 2 + act_quant_scale: Optional[float] = None + act_quant_zero_point: Optional[float] = None + + def __post_init__(self): + assert isinstance(self.layout, PlainLayout), ( + f"Only support PlainLayout for layout, got {self.layout}" + ) + torch._C._log_api_usage_once( + "torchao.quantization.Int8StaticActivationInt8WeightConfig" + ) + + +def _int8_static_activation_int8_weight_quantize_tensor(weight, config): + act_quant_scale = config.act_quant_scale + act_quant_zero_point = config.act_quant_zero_point + layout = config.layout + + # weight settings + mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE + + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + block_size = get_block_size(weight.shape, PerRow()) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8StaticActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + new_weight = to_weight_tensor_with_linear_activation_quantization_metadata( + new_weight, + _activation_static_sym_quant_func_int8, + scale=act_quant_scale, + zero_point=act_quant_zero_point, + ) + else: + from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + QuantizeTensorToInt8Kwargs, + ) + + assert config.version == 2, f"Unexpected version: {config.version}" + # Compute block_size from granularity for activation quantization kwargs + block_size = get_block_size(weight.shape, config.granularity) + new_weight = Int8Tensor.from_hp( + weight, + granularity=config.granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + granularity=config.granularity, + is_act=False, + act_quant_scale=act_quant_scale, + act_quant_zero_point=act_quant_zero_point, + ), + ) + return new_weight + + +@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig) +def _int8_static_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 static activation int8 weight quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + new_weight = _int8_static_activation_int8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + @dataclass class Float8WeightOnlyConfig(AOBaseConfig): """ diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 6ca31326cd..d2161a85cb 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -11,7 +11,11 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.kernel import int_scaled_matmul -from torchao.quantization.granularity import 
Granularity, PerRow +from torchao.quantization.granularity import ( + Granularity, + PerRow, + PerTensor, +) from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, @@ -33,9 +37,15 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() + is_act (bool): whether the tensor is activation tensor + act_quant_scale (Optional[float]): pre-computed scale for static activation quantization + act_quant_zero_point (Optional[float]): pre-computed zero point for static activation quantization """ granularity: Granularity = PerRow() + is_act: bool = False + act_quant_scale: Optional[float] = None + act_quant_zero_point: Optional[float] = None class Int8Tensor(TorchAOBaseTensor): @@ -49,7 +59,7 @@ class Int8Tensor(TorchAOBaseTensor): Non-Tensor Attributes: granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) - act_quant_kwargs: flags for dynamic activation quantization + act_quant_kwargs: flags for activation quantization """ # TODO: Static quantization support using `static_scale` @@ -117,25 +127,46 @@ def from_hp( f"Expected 2D or 3D tensor with matching block_size dimensions, " f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" ) + if act_quant_kwargs is not None and act_quant_kwargs.act_quant_scale is not None and act_quant_kwargs.is_act: + # Static activation quantization + scale = torch.amin(w_hp, keepdim=False) + scale = scale.fill_(act_quant_kwargs.act_quant_scale) + scale = scale.to(device=w_hp.device, dtype=w_hp.dtype) + if act_quant_kwargs.act_quant_zero_point is not None: + zero_point = scale.fill_(act_quant_kwargs.act_quant_zero_point) + zero_point = zero_point.to(device=w_hp.device, dtype=torch.int) + else: + zero_point = torch.zeros_like(scale, dtype=torch.int) + + int_data = quantize_affine( + w_hp, + block_size=w_hp.shape, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + quant_min=-127, + quant_max=127, + ) + else: + # Dynamic quantization + scale, zero_point = choose_qparams_affine( + input=w_hp, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w_hp.dtype, + zero_point_dtype=torch.int8, + ) - scale, zero_point = choose_qparams_affine( - input=w_hp, - mapping_type=MappingType.SYMMETRIC, - block_size=block_size, - target_dtype=torch.int8, - quant_min=-128, - quant_max=127, - scale_dtype=w_hp.dtype, - zero_point_dtype=torch.int8, - ) - - int_data = quantize_affine( - w_hp, - block_size=block_size, - scale=scale, - zero_point=zero_point, - output_dtype=torch.int8, - ) + int_data = quantize_affine( + w_hp, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) return cls( int_data, @@ -230,7 +261,7 @@ def _slice_scale( @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): - """INT8 quantization: dynamic activation or weight-only""" + """INT8 quantization: static/dynamic activation or weight-only""" activation_tensor, weight_tensor, bias = ( args[0], args[1], @@ -243,10 +274,15 @@ def _(func, types, args, kwargs): output_dtype = activation_tensor.dtype if weight_tensor.act_quant_kwargs is not None: - activation_tensor = Int8Tensor.from_hp( - activation_tensor, weight_tensor.act_quant_kwargs.granularity - ) - # Dynamic activation quantization path + if not isinstance(activation_tensor, Int8Tensor): + act_kwargs = 
weight_tensor.act_quant_kwargs + act_kwargs.is_act = True + if act_kwargs.act_quant_scale is not None: + # Static activation quantization path + act_kwargs.granularity = PerTensor() + activation_tensor = Int8Tensor.from_hp( + activation_tensor, act_kwargs.granularity, act_kwargs + ) # 1. do the matrix form of dot(X_i, W_j) #
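
End-to-end usage of the new config outside SmoothQuant is exercised by the `TestStaticQuant` test added above. A minimal sketch of that flow (shapes and data are illustrative, not part of the patch):

```python
import torch
from torch import nn

from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig
from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine

M, K, N = 8, 16, 8
calibration_batch = torch.randn(M, K)
model = nn.Sequential(nn.Linear(K, N))

# Derive a static per-tensor activation scale from calibration data.
scale, _ = choose_qparams_affine(
    calibration_batch,
    mapping_type=MappingType.SYMMETRIC,
    block_size=[M, K],  # one block covering the whole tensor -> per-tensor scale
    target_dtype=torch.int8,
)

# The pre-computed scale is baked into the config, so activations are quantized
# with this fixed scale at inference time instead of being re-measured per batch.
quantize_(model, Int8StaticActivationInt8WeightConfig(act_quant_scale=scale))
out = model(torch.randn(M, K))
```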