diff --git a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
index 2acdff2b84..2819903e69 100644
--- a/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
+++ b/test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -18,11 +18,23 @@
     quantize_,
 )
 from torchao.quantization.granularity import PerRow, PerTensor
+from torchao.quantization.quant_primitives import MappingType
 from torchao.quantization.utils import compute_error, get_block_size
 from torchao.testing.model_architectures import ToyTwoLinearModel
 from torchao.testing.utils import TorchAOIntegrationTestCase
 from torchao.utils import torch_version_at_least
 
+INT8_TEST_CONFIGS = [
+    Int8WeightOnlyConfig(version=2, granularity=PerTensor()),
+    Int8WeightOnlyConfig(version=2, granularity=PerRow()),
+    Int8DynamicActivationInt8WeightConfig(
+        version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
+    ),
+    Int8DynamicActivationInt8WeightConfig(
+        version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
+    ),
+]
+
 
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @common_utils.instantiate_parametrized_tests
@@ -36,13 +48,7 @@
     def setUp(self):
         torch.manual_seed(42)
 
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
+    @common_utils.parametrize("config", INT8_TEST_CONFIGS)
     def test_creation_and_attributes(self, config):
         """Test tensor creation, dtypes, and ranges"""
         linear = torch.nn.Linear(
@@ -60,15 +66,17 @@ def test_creation_and_attributes(self, config):
         self.assertEqual(w.qdata.dtype, torch.int8)
         self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))
 
+        if isinstance(config.granularity, PerRow):
+            self.assertEqual(w.scale.shape, (w.shape[0], 1))
+        elif isinstance(config.granularity, PerTensor):
+            self.assertEqual(w.scale.shape, (1, 1))
+
+        if hasattr(config, "act_mapping_type"):
+            self.assertEqual(w.act_quant_kwargs.mapping_type, config.act_mapping_type)
+
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
     @common_utils.parametrize("compile", [True, False])
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
+    @common_utils.parametrize("config", INT8_TEST_CONFIGS)
     @common_utils.parametrize(
         "sizes",
         [
@@ -84,6 +92,8 @@ def test_int8_linear_variants(
         sizes: tuple,
     ):
         """Test linear operation supports including shape and compile"""
+        torch.compiler.reset()
+
         M, N, K = sizes
         input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
         model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
 
         quantize_(model_q, config)
 
-        self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
-        self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
+        if isinstance(config.granularity, PerRow):
+            self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
+        elif isinstance(config.granularity, PerTensor):
+            self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))
+
+        self.assertEqual(model_q.linear2.weight.scale.ndim, 2)
 
         if compile:
+            if isinstance(config, Int8WeightOnlyConfig) and isinstance(
+                config.granularity, PerTensor
+            ):
+                # currently the inductor lowering for weight only quant in core
+                # does not support per-tensor gpu, so this errors.
+                # Skipping for now, but will address this in core.
+                return
             model_q = torch.compile(model_q, fullgraph=True)
 
         output_fp = model(input_tensor)
@@ -104,13 +123,7 @@
             f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
         )
 
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
+    @common_utils.parametrize("config", INT8_TEST_CONFIGS)
    @common_utils.parametrize("device", ["cpu", "cuda"])
     @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
     def test_slice(self, config, device, dtype):
@@ -128,27 +141,24 @@
         self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
         self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
-        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
+
+        if isinstance(config.granularity, PerRow):
+            self.assertEqual(
+                weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
+            )
+            self.assertEqual(weight2.scale, dummy.weight.scale)
 
         with self.assertRaises(NotImplementedError):
             _ = dummy.weight[::2]
 
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig,
-            Int8WeightOnlyConfig,
-        ],
-    )
-    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
-    def test_index_select(self, config, granularity):
+    @common_utils.parametrize("config", INT8_TEST_CONFIGS)
+    def test_index_select(self, config):
         """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
         N, K = 256, 512
         x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
         linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
         linear.weight.data = x
-        config = config(version=2, granularity=granularity)
         quantize_(linear, config)
 
         x_int8 = linear.weight
@@ -160,22 +170,16 @@ def test_index_select(self, config):
         )
 
         # Test block_size granularity
-        if isinstance(granularity, PerRow):
+        if isinstance(config.granularity, PerRow):
             self.assertEqual(
-                list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
+                list(get_block_size(x_int8.shape, config.granularity)), [1, K]
             )
-        elif isinstance(granularity, PerTensor):
+        elif isinstance(config.granularity, PerTensor):
             self.assertEqual(
-                list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
+                list(get_block_size(x_int8.shape, config.granularity)), [N, K]
             )
 
-    @common_utils.parametrize(
-        "config",
-        [
-            Int8DynamicActivationInt8WeightConfig(version=2),
-            Int8WeightOnlyConfig(version=2),
-        ],
-    )
+    @common_utils.parametrize("config", INT8_TEST_CONFIGS)
     def test_dequantization_accuracy(self, config):
         """Test dequantization accuracy separately"""
         linear = torch.nn.Linear(
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index a1ca6b0b94..80e11dda5b 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -98,6 +98,7 @@
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
+    Int8Tensor,
     IntxOpaqueTensor,
     IntxUnpackedToInt8Tensor,
 )
@@ -164,6 +165,7 @@
     "FqnToConfig",
     "ModuleFqnToConfig",
     # tensor subclasses
+    "Int8Tensor",
     "Int4Tensor",
     "Int4PlainInt32Tensor",
     "Int4PreshuffledTensor",
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 9a1bfeb0a5..24d6b6676c 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -1341,6 +1341,10 @@ class Int8WeightOnlyConfig(AOBaseConfig):
     def __post_init__(self):
         torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
+        if self.version == 2:
+            assert self.group_size is None, (
+                f"Only support version 2 with group_size=None, got {self.group_size}"
+            )
 
 
 # for BC
@@ -1522,9 +1526,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
     layout: Optional[Layout] = PlainLayout()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     weight_only_decode: bool = False
-    # TODO: Revisit for supported granularitys
-    # https://github.com/pytorch/ao/pull/3241#discussion_r2551497849
-    granularity: Optional[Granularity] = PerRow()
+    granularity: Granularity = PerRow()
     set_inductor_config: bool = True
     version: int = 1
 
@@ -1541,37 +1543,30 @@ def __post_init__(self):
 
 
 def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
-    layout = config.layout
-    act_mapping_type = config.act_mapping_type
-    weight_only_decode = config.weight_only_decode
-
-    in_features = weight.shape[-1]
-    # int8 dynamic quantization only has benefit when in_feature > 16
-    if in_features <= 16:
-        logger.info(
-            f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
-            f" because `in_feature` is <= 16: {in_features}"
-        )
-        return weight
+    if config.version == 1:
+        layout = config.layout
+        act_mapping_type = config.act_mapping_type
+        weight_only_decode = config.weight_only_decode
+
+        in_features = weight.shape[-1]
+        # int8 dynamic quantization only has benefit when in_feature > 16
+        if in_features <= 16:
+            logger.info(
+                f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
+                f" because `in_feature` is <= 16: {in_features}"
+            )
+            return weight
 
-    # weight settings
-    mapping_type = MappingType.SYMMETRIC
-    weight_zero_point_domain = ZeroPointDomain.NONE
+        # weight settings
+        mapping_type = MappingType.SYMMETRIC
+        weight_zero_point_domain = ZeroPointDomain.NONE
 
-    target_dtype = torch.int8
-    eps = torch.finfo(torch.float32).eps
-    zero_point_dtype = torch.int64
+        def get_weight_block_size(x):
+            return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]])
 
-    if config.version == 1:
-        warnings.warn(
-            "Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
-        )
-        if isinstance(config.granularity, PerTensor):
-            block_size = weight.shape
-        else:
-            block_size = tuple(
-                [1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]
-            )
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        zero_point_dtype = torch.int64
 
         if weight_only_decode:
             input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
@@ -1582,7 +1577,8 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
         else:
             input_quant_func = _int8_asymm_per_token_quant
 
-        quantized_weight = to_affine_quantized_intx(
+        block_size = get_weight_block_size(weight)
+        new_weight = to_affine_quantized_intx(
             weight,
             mapping_type,
             block_size,
@@ -1592,24 +1588,32 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
             _layout=layout,
             zero_point_domain=weight_zero_point_domain,
         )
-        quantized_weight = to_linear_activation_quantized(
-            quantized_weight, input_quant_func
-        )
+        quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
     else:
         from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
             QuantizeTensorToInt8Kwargs,
         )
 
+        assert config.granularity in {PerRow(), PerTensor()}, (
+            "Only PerRow and PerTensor are supported"
+        )
+        weight_granularity = config.granularity
+        act_granularity = config.granularity
+
+        assert config.act_mapping_type == MappingType.SYMMETRIC, (
+            "asymmetric dynamic quant not supported currently"
+        )
         assert config.version == 2, f"Unexpected version: {config.version}"
         # TODO: Symmetric/Asymmetric choice for weight quantization
         # https://github.com/pytorch/ao/pull/3241#discussion_r2551515539
-        # TODO: Add block_size args to return in from_hp
-        # https://github.com/pytorch/ao/pull/3241#discussion_r2552016429
         quantized_weight = Int8Tensor.from_hp(
             weight,
-            granularity=config.granularity,
-            act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity),
+            granularity=weight_granularity,
+            act_quant_kwargs=QuantizeTensorToInt8Kwargs(
+                granularity=act_granularity,
+                mapping_type=config.act_mapping_type,
+            ),
         )
 
     return quantized_weight
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index ee1da11c50..9bdb3871a2 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -1217,6 +1217,7 @@ def choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = torch.int32,
+    keepdim: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Args:
@@ -1247,6 +1248,7 @@ def choose_qparams_affine(
         eps,
         scale_dtype,
         zero_point_dtype,
+        keepdim,
     )
 
 
@@ -1521,6 +1523,7 @@ def _choose_qparams_affine(
     eps: Optional[float] = None,
     scale_dtype: Optional[torch.dtype] = None,
     zero_point_dtype: Optional[torch.dtype] = None,
+    keepdim: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """op definition that has compatible signatures with custom op library
 
@@ -1550,8 +1553,8 @@ def _choose_qparams_affine(
     )
     input = input.view(shape_for_reduction)
 
-    min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
-    max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
+    min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim)
+    max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim)
 
     min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
     max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
diff --git a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
index 0adc8c786d..e4544a2f0c 100644
--- a/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
+++ b/torchao/quantization/quantize_/common/quantize_tensor_kwargs.py
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor(
     """
     from torchao.quantization.quantize_.workflows import (
         Float8Tensor,
+        Int8Tensor,
         QuantizeTensorToFloat8Kwargs,
+        QuantizeTensorToInt8Kwargs,
     )
 
     if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
@@ -53,4 +55,11 @@ def _choose_quant_func_and_quantize_tensor(
             quant_kwargs.kernel_preference,
         )
 
+    if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
+        return Int8Tensor.from_hp(
+            tensor,
+            quant_kwargs.granularity,
+            mapping_type=quant_kwargs.mapping_type,
+        )
+
     raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py
index 962f95157f..17cb15d4f7 100644
--- a/torchao/quantization/quantize_/workflows/__init__.py
+++ b/torchao/quantization/quantize_/workflows/__init__.py
@@ -42,6 +42,8 @@
     "QuantizeTensorToInt8Kwargs",
"Float8Tensor", "QuantizeTensorToFloat8Kwargs", + "Int8Tensor", + "QuantizeTensorToInt8Kwargs", "Int4ChooseQParamsAlgorithm", "Int4PackingFormat", "IntxChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py index 6ca31326cd..dd422b90f6 100644 --- a/torchao/quantization/quantize_/workflows/int8/int8_tensor.py +++ b/torchao/quantization/quantize_/workflows/int8/int8_tensor.py @@ -5,20 +5,24 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass -from typing import Optional +from typing import List, Optional import torch from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.float8.inference import _slice_scale_for_dimension from torchao.kernel import int_scaled_matmul -from torchao.quantization.granularity import Granularity, PerRow +from torchao.quantization.granularity import Granularity from torchao.quantization.quant_primitives import ( MappingType, choose_qparams_affine, dequantize_affine, quantize_affine, ) -from torchao.quantization.quantize_.common import QuantizeTensorKwargs +from torchao.quantization.quantize_.common import ( + QuantizeTensorKwargs, + _choose_quant_func_and_quantize_tensor, +) from torchao.quantization.utils import get_block_size from torchao.utils import TorchAOBaseTensor, fill_defaults @@ -29,18 +33,22 @@ @dataclass class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs): - """Tensor kwargs for creating int8 tensor (either activation or weight) + """Tensor kwargs for creating int8 tensor for activation. Args: granularity: the granularity for the Tensor, currently either PerRow() or PerTensor() + mapping_type: whether to use symmetric or asymmetric quant, only symmetric is supported currently """ - granularity: Granularity = PerRow() + granularity: Granularity + mapping_type: MappingType = MappingType.SYMMETRIC class Int8Tensor(TorchAOBaseTensor): """ - int8 quantized tensor with plain layout + int8 quantized tensor with plain layout. + + Currently only Symmetric quantization is supported. 
 
     Tensor Attributes:
         qdata: (N, K) or (B, N, K) int8 quantized weight data (2D or 3D)
 
@@ -54,21 +62,22 @@ class Int8Tensor(TorchAOBaseTensor):
     # TODO: Static quantization support using `static_scale`
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = ["granularity"]
-    optional_tensor_attribute_names = ["act_quant_kwargs", "block_size", "dtype"]
+    tensor_attribute_names = ["block_size", "dtype"]
+    optional_tensor_attribute_names = [
+        "act_quant_kwargs",
+    ]
 
     def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        granularity: Optional[Granularity] = None,
-        block_size: Optional[torch.Size] = None,
+        block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         kwargs = {
             "device": qdata.device,
-            "dtype": dtype or scale.dtype,
+            "dtype": dtype,
             "requires_grad": False,
         }
         return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
@@ -77,16 +86,15 @@ def __init__(
         self,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        granularity: Granularity,
-        block_size: Optional[torch.Size] = None,
+        block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
-        self.granularity = granularity
-        self.block_size = block_size or get_block_size(qdata.shape, granularity)
+        self.block_size = block_size
+        # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
 
     def __repr__(self):
@@ -95,7 +103,6 @@ def __repr__(self):
             f"act_quant_kwargs={self.act_quant_kwargs}, "
             f"qdata={self.qdata}, "
             f"scale={self.scale}, "
-            f"granularity={self.granularity}, "
             f"block_size={self.block_size}, "
             f"shape={self.shape}, "
             f"device={self.device}, "
@@ -105,32 +112,29 @@ def __repr__(self):
     @classmethod
     def from_hp(
         cls,
-        w_hp: torch.Tensor,
-        granularity: Granularity = PerRow(),
+        hp_tensor: torch.Tensor,
+        granularity: Granularity,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
+        mapping_type=MappingType.SYMMETRIC,
     ):
         """Create Int8Tensor from high-precision tensor"""
-        block_size = get_block_size(w_hp.shape, granularity)
-
-        if w_hp.dim() not in [2, 3] or len(block_size) != w_hp.dim():
-            raise ValueError(
-                f"Expected 2D or 3D tensor with matching block_size dimensions, "
-                f"got tensor dim={w_hp.dim()}, block_size length={len(block_size)}"
-            )
+        block_size = get_block_size(hp_tensor.shape, granularity)
+        block_size = list(block_size)
 
         scale, zero_point = choose_qparams_affine(
-            input=w_hp,
-            mapping_type=MappingType.SYMMETRIC,
+            input=hp_tensor,
+            mapping_type=mapping_type,
             block_size=block_size,
             target_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
-            scale_dtype=w_hp.dtype,
+            scale_dtype=hp_tensor.dtype,
             zero_point_dtype=torch.int8,
+            keepdim=True,
         )
 
         int_data = quantize_affine(
-            w_hp,
+            hp_tensor,
             block_size=block_size,
             scale=scale,
             zero_point=zero_point,
@@ -140,28 +144,22 @@ def from_hp(
         return cls(
             int_data,
             scale,
-            granularity,
-            block_size=block_size,
+            block_size,
+            hp_tensor.dtype,
             act_quant_kwargs=act_quant_kwargs,
-            dtype=w_hp.dtype,
         )
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         """Dequantize int8 tensor to floating point"""
-        if output_dtype is None:
-            output_dtype = self.dtype
-
-        block_size = get_block_size(self.qdata.shape, self.granularity)
-
         return dequantize_affine(
             input=self.qdata,
-            block_size=block_size,
+            block_size=self.block_size,
             scale=self.scale,
             zero_point=None,
             input_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
-            output_dtype=output_dtype,
+            output_dtype=output_dtype if output_dtype is not None else self.dtype,
         )
 
 
@@ -169,64 +167,6 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
 implements_torch_function = Int8Tensor.implements_torch_function
 
 
-def _slice_scale(
-    scale: torch.Tensor,
-    data_shape: list[int],
-    dim: int,
-    start: int,
-    end: int,
-    step: int,
-) -> torch.Tensor:
-    """
-    Slice the scale tensor appropriately based on the data tensor slicing.
-    This function calculates how the scale should be sliced when the data tensor
-    is sliced along a given dimension, taking into account the block structure.
-
-    Example:
-        If data_shape is [256, 128] and scale shape is [1] (indicating per-tensor scaling),
-        slicing along any dimension should return the same scale tensor.
-
-        If data_shape is [256, 128] and scale shape is [256] (indicating per-row scaling),
-        and we slice data along dim=0 from 64 to 192, the corresponding scale
-    """
-    aten = torch.ops.aten
-
-    # Case 1: Per-tensor quantization (scalar scale)
-    if scale.numel() <= 1:
-        return scale
-
-    # Case 2: Per-row quantization (1D scale)
-    # Scale is per-element along this dimension
-    if scale.ndim == 1:
-        if dim == 0:
-            return aten.slice.Tensor(scale, 0, start, end, step)
-        else:
-            return scale
-
-    # Case 3: Per-block quantization (2D scale)
-    block_sizes = tuple(
-        data_shape[i] // scale.shape[i] for i in range(len(scale.shape))
-    )
-
-    block_size_for_dim = block_sizes[dim]
-
-    if step > 1:
-        raise NotImplementedError(
-            "Slicing with step > 1 is not implemented for scale tensors."
-        )
-
-    # There is blocking in this dimension
-    # Calculate which scale elements correspond to the sliced data
-    scale_start = start // block_size_for_dim if start is not None else None
-    scale_end = (
-        (end + block_size_for_dim - 1) // block_size_for_dim
-        if end is not None
-        else None
-    )
-
-    return aten.slice.Tensor(scale, dim, scale_start, scale_end, 1)
-
-
 @implements(aten.linear.default)
 @implements_torch_function(torch.nn.functional.linear)
 def _(func, types, args, kwargs):
@@ -237,14 +177,15 @@ def _(func, types, args, kwargs):
         args[2] if len(args) > 2 else None,
     )
 
-    if not isinstance(weight_tensor, Int8Tensor):
-        raise TypeError(f"Expected weight to be Int8Tensor, got {type(weight_tensor)}")
+    assert isinstance(weight_tensor, Int8Tensor), (
+        f"Expected weight to be Int8Tensor, got {type(weight_tensor)}"
+    )
 
     output_dtype = activation_tensor.dtype
 
     if weight_tensor.act_quant_kwargs is not None:
-        activation_tensor = Int8Tensor.from_hp(
-            activation_tensor, weight_tensor.act_quant_kwargs.granularity
+        activation_tensor = _choose_quant_func_and_quantize_tensor(
+            activation_tensor, weight_tensor.act_quant_kwargs
         )
 
     # Dynamic activation quantization path
@@ -270,7 +211,7 @@ def _(func, types, args, kwargs):
         y_dot_scaled = int_scaled_matmul(
             tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)
         ).to(output_dtype)
-        y = (y_dot_scaled * w_scales).reshape(
+        y = (y_dot_scaled * w_scales.flatten()).reshape(
             *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1]
         )
 
@@ -281,7 +222,7 @@ def _(func, types, args, kwargs):
             activation_tensor.reshape(-1, activation_tensor.shape[-1]),
             w_vals_int8_t.to(output_dtype),
         )
-        y = m * weight_tensor.scale.to(m.dtype)
+        y = m * weight_tensor.scale.to(m.dtype).flatten()
         y = y.reshape(*activation_tensor.shape[:-1], weight_tensor.qdata.shape[0])
 
     if bias is not None:
@@ -310,7 +251,19 @@ def _(func, types, args, kwargs):
         end = self.shape[dim]
 
     sliced_qdata = aten.slice.Tensor(self.qdata, dim, start, end, step)
-    sliced_scale = _slice_scale(self.scale, self.qdata.shape, dim, start, end, step)
+    if self.scale.numel() == 1:
+        # Per-tensor quantization - scale doesn't change
+        sliced_scale = self.scale
+    else:
+        # Block-wise quantization - need to slice the scale appropriately
+        sliced_scale = _slice_scale_for_dimension(
+            self.scale, self.qdata.shape, dim, start, end, step
+        )
+
+    # Adjust block_size since the shape has changed; block_size[i] should not
+    # be greater than shape[i]
+    block_size = self.block_size.copy()
+    for i in range(len(self.block_size)):
+        block_size[i] = min(block_size[i], sliced_qdata.shape[i])
 
     return return_and_correct_aliasing(
         func,
         args,
         kwargs,
@@ -319,39 +272,42 @@ def _(func, types, args, kwargs):
         Int8Tensor(
             sliced_qdata,
             sliced_scale,
-            self.granularity,
-            block_size=get_block_size(sliced_qdata.shape, self.granularity),
+            block_size,
+            self.dtype,
             act_quant_kwargs=self.act_quant_kwargs,
-            dtype=self.dtype,
         ),
     )
 
 
-@implements(aten.select.int)
+@implements(aten.index.Tensor)
 def _(func, types, args, kwargs):
-    """Select operation for Int8Tensor"""
-    self, dim, index = args
-    if dim != 0:
-        raise NotImplementedError(f"Only dim=0 supported, got dim={dim}")
-
-    selected_qdata = self.qdata[index]
-    selected_scale = _slice_scale(
-        self.scale, self.qdata.shape, dim, index, index + 1, step=1
-    ).squeeze(0)
-
     return return_and_correct_aliasing(
         func,
         args,
         kwargs,
-        Int8Tensor(
-            selected_qdata,
-            selected_scale,
-            self.granularity,
-            block_size=get_block_size(selected_qdata.shape, self.granularity),
-            act_quant_kwargs=self.act_quant_kwargs,
-            dtype=self.dtype,
-        ),
+        args[0]._apply_fn_to_data(lambda x: func(x, *args[1:], **kwargs)),
     )
+
+
+@implements(aten.select.int)
+def _(func, types, args, kwargs):
+    """Select operation for Int8Tensor"""
+    old_int8_tensor, dim, index = args
+    assert dim == 0, f"Int8Tensor aten.select.int with {dim=} is not yet supported"
+    assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.scale.shape), (
+        "aten.select.int requires scale to have the same number of dimensions as qdata"
+    )
+    assert len(old_int8_tensor.qdata.shape) == len(old_int8_tensor.block_size), (
+        "aten.select.int requires block_size to match qdata's number of dimensions"
+    )
+    new_int8_tensor = Int8Tensor(
+        old_int8_tensor.qdata[index],
+        old_int8_tensor.scale[index],
+        old_int8_tensor.block_size[1:],
+        old_int8_tensor.dtype,
+        old_int8_tensor.act_quant_kwargs,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
 
 
 Int8Tensor.__module__ = "torchao.quantization"
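
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the patch): how the version-2 int8
# configs exercised in the tests above are expected to be applied. The
# identifiers come from the diff; the layer sizes are assumptions chosen only
# for demonstration.
import torch

from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import MappingType

linear = torch.nn.Linear(512, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(
    linear,
    Int8DynamicActivationInt8WeightConfig(
        version=2,
        granularity=PerRow(),
        act_mapping_type=MappingType.SYMMETRIC,
    ),
)

# The weight is now an Int8Tensor. With PerRow granularity the qparams are
# computed with keepdim=True, so the scale keeps a 2D (out_features, 1) shape,
# and activations are quantized dynamically via QuantizeTensorToInt8Kwargs.
assert linear.weight.qdata.dtype == torch.int8
assert linear.weight.scale.shape == (256, 1)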