From 43e298ec49193ea4e80b4ddd49cc34c1db6fb6ce Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Wed, 26 Nov 2025 17:42:55 -0800 Subject: [PATCH 1/6] Static smooth quant Signed-off-by: Cui, Lily --- test/integration/test_integration.py | 23 +++- test/prototype/test_smoothquant.py | 4 + .../workflows/int8/test_int8_tensor.py | 4 + torchao/prototype/smoothquant/README.md | 1 + torchao/prototype/smoothquant/api.py | 35 +++-- torchao/prototype/smoothquant/core.py | 20 ++- torchao/prototype/smoothquant/example.py | 37 ++++- torchao/quantization/__init__.py | 1 + .../quantization/linear_activation_scale.py | 20 ++- torchao/quantization/quant_api.py | 128 ++++++++++++++++++ .../quantize_/workflows/int8/int8_tensor.py | 90 ++++++++---- 11 files changed, 312 insertions(+), 51 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 70da622c73..2aac286cf6 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -43,11 +43,13 @@ Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig, + Int8StaticActivationInt8WeightConfig, _replace_with_custom_fn_if_matches_filter, quantize_, ) from torchao.quantization.quant_primitives import ( MappingType, + choose_qparams_affine, dequantize_affine, ) from torchao.quantization.smoothquant import ( @@ -1004,6 +1006,25 @@ def test_dynamic_quant(self): sqnr = compute_error(y_ref, y_test) self.assertGreater(sqnr, 40.0) +class TestStaticQuant(unittest.TestCase): + def test_static_quant(self): + M, K, N = 8, 16, 8 + x = torch.randn(M, K) + m = nn.Sequential(nn.Linear(K, N)) + block_size = [M, K] # per-tensor quantization + scale, _ = choose_qparams_affine( + x, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + ) + + y_ref = m(x) + quantize_(m, Int8StaticActivationInt8WeightConfig(act_quant_scale=scale)) + y_test = m(x) + + sqnr = compute_error(y_ref, y_test) + self.assertGreater(sqnr, 40.0) class TestWeightOnlyInt8Quant(unittest.TestCase): def test_weight_only_quant(self): @@ -1037,7 +1058,7 @@ def test_weight_only_groupwise_embedding_quant(self): quantize_( m, - Int8WeightOnlyConfig(group_size=group_size), + Int8WeightOnlyConfig(group_size=group_size,version=2), filter_fn=lambda x, *args: isinstance(x, nn.Embedding), ) y_q = m(input) diff --git a/test/prototype/test_smoothquant.py b/test/prototype/test_smoothquant.py index faf8231342..6a485f4e27 100644 --- a/test/prototype/test_smoothquant.py +++ b/test/prototype/test_smoothquant.py @@ -20,6 +20,7 @@ ) from torchao.quantization.quant_api import ( Int8DynamicActivationInt8WeightConfig, + Int8StaticActivationInt8WeightConfig ) from torchao.quantization.utils import ( compute_error as SQNR, @@ -84,6 +85,7 @@ def setUpClass(cls): "base_config", [ Int8DynamicActivationInt8WeightConfig(), + Int8StaticActivationInt8WeightConfig(), # Note: float8_static_activation_float8_weight is broken after recent PyTorch update. 
# TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
        ],
    )
@@ -139,6 +141,7 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
        "base_config",
        [
            Int8DynamicActivationInt8WeightConfig(),
+            Int8StaticActivationInt8WeightConfig(),
            # TODO: Check more quantization APIs
        ],
    )
@@ -178,6 +181,7 @@ def test_observer_insertion(self, base_config):
        "base_config",
        [
            Int8DynamicActivationInt8WeightConfig(),
+            Int8StaticActivationInt8WeightConfig(),
            # TODO: Check more quantization APIs
        ],
    )
diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 2acdff2b84..df32b3230a 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -14,6 +14,7 @@
 from torchao.quantization import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
     Int8WeightOnlyConfig,
     quantize_,
 )
@@ -66,6 +67,7 @@ def test_creation_and_attributes(self, config):
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8StaticActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
@@ -108,6 +110,7 @@ def test_int8_linear_variants(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8StaticActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
@@ -173,6 +176,7 @@ def test_index_select(self, config, granularity):
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
+            Int8StaticActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
diff --git a/torchao/prototype/smoothquant/README.md b/torchao/prototype/smoothquant/README.md
index 00e819c438..c5aeaea78a 100644
--- a/torchao/prototype/smoothquant/README.md
+++ b/torchao/prototype/smoothquant/README.md
@@ -50,6 +50,7 @@ for data in calibration_dataset:
 quant_config.step = SmoothQuantStep.CONVERT
 quantize_(model, quant_config)
 ```
+For static activation quantization, use `Int8StaticActivationInt8WeightConfig` instead of `Int8DynamicActivationInt8WeightConfig`. Static quantization generally gives better throughput at the cost of some accuracy (higher perplexity).
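+
+A minimal end-to-end sketch of the static flow (it reuses the `model` and `calibration_dataset` from the snippet above; import paths and the `SmoothQuantConfig` arguments follow the example script in this patch, and `alpha=0.5` is only an illustrative value):
+
+```python
+from torchao.prototype.smoothquant import SmoothQuantConfig
+from torchao.prototype.smoothquant.core import SmoothQuantStep
+from torchao.quantization import quantize_
+from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig
+
+# Prepare: wrap linear layers with observers that record activation and weight ranges
+quant_config = SmoothQuantConfig(
+    base_config=Int8StaticActivationInt8WeightConfig(),
+    step=SmoothQuantStep.PREPARE,
+    alpha=0.5,  # smoothing strength; illustrative value
+)
+quantize_(model, quant_config)
+
+# Calibrate so the observers can derive the smoothing factor and the static activation scale
+for data in calibration_dataset:
+    model(data)
+
+# Convert: fold the smoothing factor into the weights and quantize with the frozen scale
+quant_config.step = SmoothQuantStep.CONVERT
+quantize_(model, quant_config)
+```
+The example script exposes the same path via `--static-quant-act` (optionally combined with `--compile`).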
## Benchmarks diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index 9f78c49fb8..b7992da7a5 100644 --- a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -15,6 +15,7 @@ ) from torchao.quantization.quant_api import ( _QUANTIZE_CONFIG_HANDLER, + Int8StaticActivationInt8WeightConfig, _linear_extra_repr, ) from torchao.quantization.transform_module import ( @@ -96,31 +97,39 @@ def _smooth_quant_transform( raise ValueError(f"Unexpected step: {step}") # Compute smoothed weight parameters - smoothing_factor = observed_linear.obs.calculate_qparams() + act_quant_min, act_quant_max = None, None + if isinstance(base_config, Int8StaticActivationInt8WeightConfig): + act_quant_min, act_quant_max = -127, 127 + smoothing_factor, act_quant_scale = observed_linear.obs.calculate_qparams( + act_quant_min, act_quant_max + ) weight = observed_linear.weight * smoothing_factor - # Create new linear layer - with torch.device("meta"): - linear = torch.nn.Linear( - observed_linear.in_features, - observed_linear.out_features, - observed_linear.bias is not None, - device=observed_linear.weight.device, - dtype=observed_linear.weight.dtype, - ) - linear.bias = observed_linear.bias - # Quantize weights + if isinstance(base_config, Int8StaticActivationInt8WeightConfig): + base_config.act_quant_scale = act_quant_scale base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)] dummy_mod = DummyModule(weight) quant_mod = base_config_handler(dummy_mod, base_config) qw = quant_mod.weight # Add smoothing factor metadata + use_inv_scale = qw.device.type == "cpu" qw = to_weight_tensor_with_linear_activation_scale_metadata( - qw, smoothing_factor.to(qw.dtype) + qw, smoothing_factor.to(qw.dtype), use_inv_scale ) + + # Create new linear layer + with torch.device("meta"): + linear = torch.nn.Linear( + observed_linear.in_features, + observed_linear.out_features, + observed_linear.bias is not None, + device=observed_linear.weight.device, + dtype=observed_linear.weight.dtype, + ) linear.weight = torch.nn.Parameter(qw, requires_grad=False) linear.extra_repr = types.MethodType(_linear_extra_repr, linear) + linear.bias = observed_linear.bias return linear diff --git a/torchao/prototype/smoothquant/core.py b/torchao/prototype/smoothquant/core.py index 83f1e78275..ebaae134e0 100644 --- a/torchao/prototype/smoothquant/core.py +++ b/torchao/prototype/smoothquant/core.py @@ -41,7 +41,7 @@ def forward(self, input: torch.Tensor): self.inputs.append(input.to("cpu")) return input - def calculate_qparams(self): + def calculate_qparams(self, act_quant_min=None, act_quant_max=None): assert self.inputs and len(self.inputs) > 0, ( "calibrate observer first by running model on exemplar data" ) @@ -57,12 +57,20 @@ def calculate_qparams(self): # Calculate smoothing factor if self.alpha is None: - return torch.ones_like(x_abs_max) + smooth_factor = torch.ones_like(x_abs_max) + else: + eps = torch.finfo(torch.float32).eps + smooth_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow( + w_abs_max + eps, 1 - self.alpha + ) - eps = torch.finfo(torch.float32).eps - return torch.pow(x_abs_max + eps, self.alpha) / torch.pow( - w_abs_max + eps, 1 - self.alpha - ) + # Calculate per-tensor act_quant_scale + act_quant_scale = None + if act_quant_min is not None and act_quant_max is not None: + x_abs_max_t = acc.abs().max() + act_quant_scale = (x_abs_max_t / (act_quant_max - act_quant_min) / 2).item() + + return smooth_factor, act_quant_scale class 
SmoothQuantObservedLinear(torch.nn.Linear): diff --git a/torchao/prototype/smoothquant/example.py b/torchao/prototype/smoothquant/example.py index 8602b57e20..475f942ea0 100644 --- a/torchao/prototype/smoothquant/example.py +++ b/torchao/prototype/smoothquant/example.py @@ -16,8 +16,10 @@ ) from torchao.prototype.smoothquant.core import SmoothQuantStep from torchao.quantization import quantize_ -from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig - +from torchao.quantization.quant_api import ( + Int8DynamicActivationInt8WeightConfig, + Int8StaticActivationInt8WeightConfig, +) # TODO: Build benchmark within vLLM ecosystem with more quantization APIs # See https://github.com/pytorch/ao/issues/2815 for more details @@ -82,6 +84,8 @@ def quantize_and_eval( device: str, model_save_path: str, model_save_hf_hub_path: str, + static_quant_act: bool, + compile: bool, ): print(f"Loading model on {device}...") torch.manual_seed(34) @@ -96,9 +100,14 @@ def quantize_and_eval( # Step 1: Prepare - insert observers print("running SmoothQuant prepare and calibrate") + base_config = ( + Int8StaticActivationInt8WeightConfig() + if static_quant_act + else Int8DynamicActivationInt8WeightConfig() + ) t0 = time.time() quant_config = SmoothQuantConfig( - base_config=Int8DynamicActivationInt8WeightConfig(), + base_config=base_config, step=SmoothQuantStep.PREPARE, alpha=alpha, ) @@ -133,6 +142,8 @@ def quantize_and_eval( print("pushing model to hub:", model_save_hf_hub_path) model.push_to_hub(model_save_hf_hub_path, safe_serialization=False) tokenizer.push_to_hub(model_save_hf_hub_path) + if compile: + model.forward = torch.compile(model.forward, dynamic=True) print("Benchmarking SmoothQuant model...") return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device) @@ -147,6 +158,8 @@ def compare_models( device: str, model_save_path: str, model_save_hf_hub_path: str, + static_quant_act: bool, + compile: bool, ): """Compare perplexity and speed for behchmarking SmoothQuant""" @@ -159,6 +172,8 @@ def compare_models( .eval() .to(device) ) + if compile: + model.forward = torch.compile(model.forward, dynamic=True) base_results = benchmark( model, tokenizer, max_seq_length, tasks=tasks, device=device ) @@ -172,6 +187,8 @@ def compare_models( .to(device) ) quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig()) + if compile: + w8a8_model.forward = torch.compile(w8a8_model.forward, dynamic=True) w8a8_results = benchmark( w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device ) @@ -187,6 +204,8 @@ def compare_models( device, model_save_path, model_save_hf_hub_path, + static_quant_act, + compile, ) # Calculate changes and display results @@ -289,6 +308,16 @@ def create_parser() -> argparse.ArgumentParser: default=None, help="Huggingface hub path to store the quantized model and tokenizer.", ) + parser.add_argument( + "--static-quant-act", + action="store_true", + help="Use static quantization of activation instead of dynamic quantization.", + ) + parser.add_argument( + "--compile", + action="store_true", + help="Use torch.compile to compile the model for potentially better performance.", + ) return parser @@ -306,4 +335,6 @@ def create_parser() -> argparse.ArgumentParser: args.device, args.model_save_path, args.model_save_hf_hub_path, + args.static_quant_act, + args.compile, ) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index a1ca6b0b94..5c607c4a79 100644 --- a/torchao/quantization/__init__.py +++ 
b/torchao/quantization/__init__.py @@ -59,6 +59,7 @@ Int8DynamicActivationInt4WeightConfig, Int8DynamicActivationInt8WeightConfig, Int8DynamicActivationIntxWeightConfig, + Int8StaticActivationInt8WeightConfig, Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index fd61d07d9a..4b327f7de7 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -21,6 +21,7 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): Tensor subclass that wraps a weight tensor and provides metadata for linear activation scaling. Right now we hardcode how we apply the scale: scaled_linear_act = input_act / scale + # or scaled_linear_act = input_act * inv_scale out = F.linear(scaled_linear_act, weight, ...) We can generalize this to accept a function as well if needed. @@ -31,12 +32,13 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): """ tensor_data_names = ["original_weight_tensor", "scale"] - tensor_attribute_names = [] + tensor_attribute_names = ["use_inv_scale"] def __new__( cls, original_weight_tensor: torch.Tensor, scale: torch.Tensor, + use_inv_scale: bool = False, ): kwargs = {} dtype = original_weight_tensor.dtype @@ -50,9 +52,12 @@ def __init__( self, original_weight_tensor: torch.Tensor, scale: torch.Tensor, + use_inv_scale: bool = False, ): self.original_weight_tensor = original_weight_tensor self.scale = scale + self.use_inv_scale = use_inv_scale + self.inv_scale = 1.0 / scale if use_inv_scale else None def _quantization_type(self): return f"{self.__class__}" @@ -63,8 +68,12 @@ def _quantized_linear_op( ): original_weight_tensor = weight_tensor.original_weight_tensor scale = weight_tensor.scale + inv_scale = weight_tensor.inv_scale + use_inv_scale = weight_tensor.use_inv_scale # Note: we can make this function configurable as well - scaled_input_act = input_tensor / scale + scaled_input_act = ( + input_tensor * inv_scale if use_inv_scale else input_tensor / scale + ) return torch.nn.functional.linear( scaled_input_act, original_weight_tensor, bias ) @@ -74,8 +83,9 @@ def from_float( cls, input_float: torch.Tensor, scale: torch.Tensor, + use_inv_scale: bool = False, ): - return cls(input_float, scale) + return cls(input_float, scale, use_inv_scale) implements = WeightTensorWithLinearActivationScaleMetadata.implements @@ -103,7 +113,9 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): self = args[0] new = self.__class__( - func(self.original_weight_tensor, *args[1:], **kwargs), self.scale + func(self.original_weight_tensor, *args[1:], **kwargs), + self.scale, + self.use_inv_scale, ) return return_and_correct_aliasing(func, args, kwargs, new) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9a1bfeb0a5..ff371ad9fb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -50,6 +50,7 @@ to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, + to_affine_quantized_intx_static, to_marlinqqq_quantized_intx, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( @@ -476,6 +477,7 @@ def quantize_( # currently options are # Int8DynamicActivationInt4WeightConfig (for executorch) # Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile) + # Int8StaticActivationInt8WeightConfig (optimized with int8 mm 
op and torch.compile)
        # Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
        # Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
        from torchao.quantization.quant_api import int4_weight_only
@@ -1649,6 +1651,132 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
     return Int8DynamicActivationInt8WeightConfig(layout=SemiSparseLayout())
 
 
+def _activation_static_sym_quant_func_int8(
+    x: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    assert zero_point is None, "Zero point must be None"
+    quant_min = -127
+    quant_max = 127
+    target_dtype = torch.int8
+    zero_point_domain = ZeroPointDomain.NONE
+
+    return to_affine_quantized_intx_static(
+        x,
+        scale=scale,
+        zero_point=zero_point,
+        block_size=get_block_size(x.shape, PerTensor()),
+        target_dtype=target_dtype,
+        quant_min=quant_min,
+        quant_max=quant_max,
+        zero_point_domain=zero_point_domain,
+    )
+
+
+@dataclass
+class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
+    """
+    Configuration for applying int8 static activation and int8 per-channel weight
+    quantization to linear layers. The activation is always quantized with symmetric per-tensor quantization.
+
+    Args:
+        layout: Optional[Layout] = PlainLayout() - Tensor layout for the quantized weights. Controls how the
+            quantized data is stored and accessed. Only PlainLayout is supported for now.
+        granularity: Optional[Granularity] = PerRow() - Granularity used for weight quantization.
+        set_inductor_config: bool = False - If True, adjusts `torchinductor` settings to recommended values
+            for better performance with this quantization scheme.
+        version (int): the version of the config; version 1 uses AffineQuantizedTensor, which we plan to deprecate/split, version 2 uses Int8Tensor.
+        act_quant_scale: Optional[float] = None - Pre-computed scale for static activation quantization.
+        act_quant_zero_point: Optional[float] = None - Pre-computed zero point for static activation quantization.
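+
+    Example (illustrative sketch mirroring the static-quant test added in this patch; in practice
+    the per-tensor activation scale comes from calibration, e.g. the SmoothQuant observers, and
+    `x`/`model` below stand for a representative activation batch and the module to quantize)::
+
+        scale, _ = choose_qparams_affine(
+            x,
+            mapping_type=MappingType.SYMMETRIC,
+            block_size=list(x.shape),  # one block over the whole sample -> per-tensor scale
+            target_dtype=torch.int8,
+        )
+        quantize_(model, Int8StaticActivationInt8WeightConfig(act_quant_scale=scale))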
+ """ + + layout: Optional[Layout] = PlainLayout() + granularity: Optional[Granularity] = PerRow() + set_inductor_config: bool = False + version: int = 2 + act_quant_scale: Optional[float] = None + act_quant_zero_point: Optional[float] = None + + def __post_init__(self): + assert isinstance(self.layout, PlainLayout), ( + f"Only support PlainLayout for layout, got {self.layout}" + ) + torch._C._log_api_usage_once( + "torchao.quantization.Int8StaticActivationInt8WeightConfig" + ) + + +def _int8_static_activation_int8_weight_quantize_tensor(weight, config): + act_quant_scale = config.act_quant_scale + act_quant_zero_point = config.act_quant_zero_point + layout = config.layout + + # weight settings + mapping_type = MappingType.SYMMETRIC + weight_zero_point_domain = ZeroPointDomain.NONE + + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + block_size = get_block_size(weight.shape, PerRow()) + if config.version == 1: + warnings.warn( + "Config Deprecation: version 1 of Int8StaticActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details" + ) + new_weight = to_affine_quantized_intx( + weight, + mapping_type, + block_size, + target_dtype, + eps=eps, + zero_point_dtype=zero_point_dtype, + _layout=layout, + zero_point_domain=weight_zero_point_domain, + ) + new_weight = to_weight_tensor_with_linear_activation_quantization_metadata( + new_weight, + _activation_static_sym_quant_func_int8, + scale=act_quant_scale, + zero_point=act_quant_zero_point, + ) + else: + from torchao.quantization.quantize_.workflows.int8.int8_tensor import ( + QuantizeTensorToInt8Kwargs, + ) + + assert config.version == 2, f"Unexpected version: {config.version}" + # Compute block_size from granularity for activation quantization kwargs + block_size = get_block_size(weight.shape, config.granularity) + new_weight = Int8Tensor.from_hp( + weight, + granularity=config.granularity, + act_quant_kwargs=QuantizeTensorToInt8Kwargs( + granularity=config.granularity, + is_act=False, + act_quant_scale=act_quant_scale, + act_quant_zero_point=act_quant_zero_point, + ), + ) + return new_weight + + +@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig) +def _int8_static_activation_int8_weight_transform( + module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig +) -> torch.nn.Module: + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + + assert hasattr(module, "weight"), ( + "applying int8 static activation int8 weight quant requires module to have weight attribute" + + f"but {module} does not have one" + ) + new_weight = _int8_static_activation_int8_weight_quantize_tensor( + module.weight, config + ) + module.weight = torch.nn.Parameter(new_weight, requires_grad=False) + module.extra_repr = types.MethodType(_linear_extra_repr, module) + return module + + @dataclass class Float8WeightOnlyConfig(AOBaseConfig): """ diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 6ca31326cd..445ff8fdb8 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -11,12 +11,17 @@ from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.kernel import int_scaled_matmul -from torchao.quantization.granularity import 
Granularity, PerRow +from torchao.quantization.granularity import ( + Granularity, + PerRow, + PerTensor, +) from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, dequantize_affine, quantize_affine, + _get_reduction_params, ) from torchao.quantization.quantize_.common import QuantizeTensorKwargs from torchao.quantization.utils import get_block_size @@ -33,9 +38,15 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): Args: granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() + is_act (bool): whether the tensor is activation tensor + act_quant_scale (Optional[float]): pre-computed scale for static activation quantization + act_quant_zero_point (Optional[float]): pre-computed zero point for static activation quantization """ granularity: Granularity = PerRow() + is_act: bool = False + act_quant_scale: Optional[float] = None + act_quant_zero_point: Optional[float] = None class Int8Tensor(TorchAOBaseTensor): @@ -49,7 +60,7 @@ class Int8Tensor(TorchAOBaseTensor): Non-Tensor Attributes: granularity: the granularity for quantization (e.g., PerRow(), PerTensor()) - act_quant_kwargs: flags for dynamic activation quantization + act_quant_kwargs: flags for activation quantization """ # TODO: Static quantization support using `static_scale` @@ -118,24 +129,50 @@ def from_hp( f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" ) - scale, zero_point = choose_qparams_affine( - input=w_hp, - mapping_type=MappingType.SYMMETRIC, - block_size=block_size, - target_dtype=torch.int8, - quant_min=-128, - quant_max=127, - scale_dtype=w_hp.dtype, - zero_point_dtype=torch.int8, - ) + if act_quant_kwargs is not None and act_quant_kwargs.act_quant_scale is not None and act_quant_kwargs.is_act: + # Static quantization + shape_for_reduction, reduction_dims = _get_reduction_params( + block_size, w_hp.size() + ) + scale = w_hp.view(shape_for_reduction) + scale = torch.amin(scale, dim=reduction_dims, keepdim=False) + scale = scale.fill_(act_quant_kwargs.act_quant_scale) + scale = scale.to(device=w_hp.device, dtype=w_hp.dtype) + if act_quant_kwargs.act_quant_zero_point is not None: + zero_point = scale.fill_(act_quant_kwargs.act_quant_zero_point) + zero_point = zero_point.to(device=w_hp.device, dtype=torch.int) + else: + zero_point = torch.zeros_like(scale, dtype=torch.int) + + int_data = quantize_affine( + w_hp, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + quant_min=-127, + quant_max=127, + ) + else: + # Dynamic quantization + scale, zero_point = choose_qparams_affine( + input=w_hp, + mapping_type=MappingType.SYMMETRIC, + block_size=block_size, + target_dtype=torch.int8, + quant_min=-128, + quant_max=127, + scale_dtype=w_hp.dtype, + zero_point_dtype=torch.int8, + ) - int_data = quantize_affine( - w_hp, - block_size=block_size, - scale=scale, - zero_point=zero_point, - output_dtype=torch.int8, - ) + int_data = quantize_affine( + w_hp, + block_size=block_size, + scale=scale, + zero_point=zero_point, + output_dtype=torch.int8, + ) return cls( int_data, @@ -230,7 +267,7 @@ def _slice_scale( @implements(aten.linear.default) @implements_torch_function(torch.nn.functional.linear) def _(func, types, args, kwargs): - """INT8 quantization: dynamic activation or weight-only""" + """INT8 quantization: static/dynamic activation or weight-only""" activation_tensor, weight_tensor, bias = ( args[0], args[1], @@ -243,10 +280,15 @@ def _(func, types, args, kwargs): output_dtype = activation_tensor.dtype 
if weight_tensor.act_quant_kwargs is not None: - activation_tensor = Int8Tensor.from_hp( - activation_tensor, weight_tensor.act_quant_kwargs.granularity - ) - # Dynamic activation quantization path + if not isinstance(activation_tensor, Int8Tensor): + act_kwargs = weight_tensor.act_quant_kwargs + act_kwargs.is_act = True + if act_kwargs.act_quant_scale is not None: + # Static activation quantization path + act_kwargs.granularity = PerTensor() + activation_tensor = Int8Tensor.from_hp( + activation_tensor, act_kwargs.granularity, act_kwargs + ) # 1. do the matrix form of dot(X_i, W_j) # From ae17dc204db0dc4eac3d6f296ac976abf19b89aa Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Thu, 13 Nov 2025 19:32:50 -0800 Subject: [PATCH 2/6] Remove cpu expand Signed-off-by: Cui, Lily --- torchao/kernel/intmm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index 292b67380d..aca95eb403 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -127,7 +127,8 @@ def int_scaled_matmul( assert M == scales1.size(0) or scales1.numel() == 1 assert 1 == scales1.size(1) assert scales1.is_contiguous() - scales1 = scales1.expand((M, N)) + if scales1.device.type != "cpu": + scales1 = scales1.expand((M, N)) assert scales1.dim() == 2 if check_cpu_version(scales1.device): From 14114f3344ab61b9b5cc5232fbe42f607b56e5a1 Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Wed, 26 Nov 2025 17:57:14 -0800 Subject: [PATCH 3/6] Change shapes Signed-off-by: Cui, Lily --- .../quantize_/workflows/int8/int8_tensor.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 445ff8fdb8..d2161a85cb 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -21,7 +21,6 @@ choose_qparams_affine, dequantize_affine, quantize_affine, - _get_reduction_params, ) from torchao.quantization.quantize_.common import QuantizeTensorKwargs from torchao.quantization.utils import get_block_size @@ -128,14 +127,9 @@ def from_hp( f"Expected 2D or 3D tensor with matching block_size dimensions, " f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}" ) - if act_quant_kwargs is not None and act_quant_kwargs.act_quant_scale is not None and act_quant_kwargs.is_act: - # Static quantization - shape_for_reduction, reduction_dims = _get_reduction_params( - block_size, w_hp.size() - ) - scale = w_hp.view(shape_for_reduction) - scale = torch.amin(scale, dim=reduction_dims, keepdim=False) + # Static activation quantization + scale = torch.amin(w_hp, keepdim=False) scale = scale.fill_(act_quant_kwargs.act_quant_scale) scale = scale.to(device=w_hp.device, dtype=w_hp.dtype) if act_quant_kwargs.act_quant_zero_point is not None: @@ -146,7 +140,7 @@ def from_hp( int_data = quantize_affine( w_hp, - block_size=block_size, + block_size=w_hp.shape, scale=scale, zero_point=zero_point, output_dtype=torch.int8, From c76c5526464ab61e6d896361a844d14ffe395fd0 Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Thu, 20 Nov 2025 16:14:46 +0800 Subject: [PATCH 4/6] use uint8 when only support avx512-vnni Signed-off-by: Cui, Lily --- torchao/kernel/intmm.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index aca95eb403..d5e5731d17 100644 --- a/torchao/kernel/intmm.py +++ 
b/torchao/kernel/intmm.py @@ -134,8 +134,17 @@ def int_scaled_matmul( if check_cpu_version(scales1.device): # CPU prefers decomposed version of int_scaled_matmul # to leverage the fusion capability of Inductor - c = torch._int_mm(a, b) - return c.to(scales1.dtype) * scales1 + + if not torch.cpu._is_amx_tile_supported() and torch.cpu._is_vnni_supported():# uint8 path + a = (a.to(torch.int32) + 128).to(torch.uint8) + c = torch._int_mm(a, b) + zp = a.fill_(128) + zpb = torch._int_mm(zp, b) + c = c - zpb + return c.to(scales1.dtype) * scales1 + else: # int8 path + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_scaled_matmul(a, b, scales1) From 932bd63930b8d137f66b9057da7102130d08c5c2 Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Tue, 2 Dec 2025 11:52:55 +0800 Subject: [PATCH 5/6] Optimize zero point term compensation Signed-off-by: Cui, Lily --- torchao/kernel/intmm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index d5e5731d17..a5dc98bac7 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -138,9 +138,8 @@ def int_scaled_matmul( if not torch.cpu._is_amx_tile_supported() and torch.cpu._is_vnni_supported():# uint8 path a = (a.to(torch.int32) + 128).to(torch.uint8) c = torch._int_mm(a, b) - zp = a.fill_(128) - zpb = torch._int_mm(zp, b) - c = c - zpb + comp = b.sum(dim=0,keepdim=True, dtype=torch.int32) * 128 + c.sub_(comp) return c.to(scales1.dtype) * scales1 else: # int8 path c = torch._int_mm(a, b) From f84dfbac97ff2b9a8d0dc38aaf1b6797b077224f Mon Sep 17 00:00:00 2001 From: "Cui, Lily" Date: Wed, 3 Dec 2025 14:43:59 +0800 Subject: [PATCH 6/6] Clean up the PR Signed-off-by: Cui, Lily --- test/integration/test_integration.py | 2 +- torchao/kernel/intmm.py | 12 ++--------- torchao/prototype/smoothquant/api.py | 3 +-- .../quantization/linear_activation_scale.py | 20 ++++--------------- 4 files changed, 8 insertions(+), 29 deletions(-) diff --git a/test/integration/test_integration.py b/test/integration/test_integration.py index 2aac286cf6..dbe0606eda 100644 --- a/test/integration/test_integration.py +++ b/test/integration/test_integration.py @@ -1058,7 +1058,7 @@ def test_weight_only_groupwise_embedding_quant(self): quantize_( m, - Int8WeightOnlyConfig(group_size=group_size,version=2), + Int8WeightOnlyConfig(group_size=group_size), filter_fn=lambda x, *args: isinstance(x, nn.Embedding), ) y_q = m(input) diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index a5dc98bac7..aca95eb403 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -134,16 +134,8 @@ def int_scaled_matmul( if check_cpu_version(scales1.device): # CPU prefers decomposed version of int_scaled_matmul # to leverage the fusion capability of Inductor - - if not torch.cpu._is_amx_tile_supported() and torch.cpu._is_vnni_supported():# uint8 path - a = (a.to(torch.int32) + 128).to(torch.uint8) - c = torch._int_mm(a, b) - comp = b.sum(dim=0,keepdim=True, dtype=torch.int32) * 128 - c.sub_(comp) - return c.to(scales1.dtype) * scales1 - else: # int8 path - c = torch._int_mm(a, b) - return c.to(scales1.dtype) * scales1 + c = torch._int_mm(a, b) + return c.to(scales1.dtype) * scales1 if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_scaled_matmul(a, b, scales1) diff --git a/torchao/prototype/smoothquant/api.py b/torchao/prototype/smoothquant/api.py index b7992da7a5..0917792265 100644 --- 
a/torchao/prototype/smoothquant/api.py +++ b/torchao/prototype/smoothquant/api.py @@ -114,9 +114,8 @@ def _smooth_quant_transform( qw = quant_mod.weight # Add smoothing factor metadata - use_inv_scale = qw.device.type == "cpu" qw = to_weight_tensor_with_linear_activation_scale_metadata( - qw, smoothing_factor.to(qw.dtype), use_inv_scale + qw, smoothing_factor.to(qw.dtype) ) # Create new linear layer diff --git a/torchao/quantization/linear_activation_scale.py b/torchao/quantization/linear_activation_scale.py index 4b327f7de7..fd61d07d9a 100644 --- a/torchao/quantization/linear_activation_scale.py +++ b/torchao/quantization/linear_activation_scale.py @@ -21,7 +21,6 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): Tensor subclass that wraps a weight tensor and provides metadata for linear activation scaling. Right now we hardcode how we apply the scale: scaled_linear_act = input_act / scale - # or scaled_linear_act = input_act * inv_scale out = F.linear(scaled_linear_act, weight, ...) We can generalize this to accept a function as well if needed. @@ -32,13 +31,12 @@ class WeightTensorWithLinearActivationScaleMetadata(TorchAOBaseTensor): """ tensor_data_names = ["original_weight_tensor", "scale"] - tensor_attribute_names = ["use_inv_scale"] + tensor_attribute_names = [] def __new__( cls, original_weight_tensor: torch.Tensor, scale: torch.Tensor, - use_inv_scale: bool = False, ): kwargs = {} dtype = original_weight_tensor.dtype @@ -52,12 +50,9 @@ def __init__( self, original_weight_tensor: torch.Tensor, scale: torch.Tensor, - use_inv_scale: bool = False, ): self.original_weight_tensor = original_weight_tensor self.scale = scale - self.use_inv_scale = use_inv_scale - self.inv_scale = 1.0 / scale if use_inv_scale else None def _quantization_type(self): return f"{self.__class__}" @@ -68,12 +63,8 @@ def _quantized_linear_op( ): original_weight_tensor = weight_tensor.original_weight_tensor scale = weight_tensor.scale - inv_scale = weight_tensor.inv_scale - use_inv_scale = weight_tensor.use_inv_scale # Note: we can make this function configurable as well - scaled_input_act = ( - input_tensor * inv_scale if use_inv_scale else input_tensor / scale - ) + scaled_input_act = input_tensor / scale return torch.nn.functional.linear( scaled_input_act, original_weight_tensor, bias ) @@ -83,9 +74,8 @@ def from_float( cls, input_float: torch.Tensor, scale: torch.Tensor, - use_inv_scale: bool = False, ): - return cls(input_float, scale, use_inv_scale) + return cls(input_float, scale) implements = WeightTensorWithLinearActivationScaleMetadata.implements @@ -113,9 +103,7 @@ def _(func, types, args, kwargs): def _(func, types, args, kwargs): self = args[0] new = self.__class__( - func(self.original_weight_tensor, *args[1:], **kwargs), - self.scale, - self.use_inv_scale, + func(self.original_weight_tensor, *args[1:], **kwargs), self.scale ) return return_and_correct_aliasing(func, args, kwargs, new)
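
A note on the uint8 fallback that patches 4 and 5 add for CPUs with AVX512-VNNI but without AMX (and that the cleanup patch removes again): it shifts the signed int8 activations by +128 so the matmul runs on uint8 inputs, then subtracts a compensation term, using the identity (A + 128) @ B = A @ B + 128 * colsum(B). Below is a small self-contained check of that identity; it uses plain int64 matmul instead of `torch._int_mm` so it runs on any build, and the shapes are arbitrary.

```python
import torch

# Signed int8 operands, as int_scaled_matmul would receive them.
torch.manual_seed(0)
A = torch.randint(-128, 128, (4, 8), dtype=torch.int8)
B = torch.randint(-128, 128, (8, 5), dtype=torch.int8)

# Reference: the ordinary signed int8 matmul (accumulated in int64 here).
ref = A.to(torch.int64) @ B.to(torch.int64)

# uint8 path: shift activations into [0, 255], multiply, then subtract the
# zero-point compensation 128 * column_sums(B) introduced by the shift.
shifted = (A.to(torch.int64) + 128) @ B.to(torch.int64)
comp = B.to(torch.int64).sum(dim=0, keepdim=True) * 128
assert torch.equal(shifted - comp, ref)
```

Patch 4 computes the compensation with a second `_int_mm` against a tensor filled with 128; patch 5 replaces that with the cheaper column-sum shown above.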