14 changes: 3 additions & 11 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -31,13 +31,13 @@
version=2, granularity=PerTensor(), act_mapping_type=MappingType.ASYMMETRIC
),
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
),
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
version=2, granularity=PerRow(), act_mapping_type=MappingType.ASYMMETRIC
),
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
),
]

@@ -77,14 +77,6 @@ def test_creation_and_attributes(self, config):
elif isinstance(config.granularity, PerTensor):
self.assertEqual(w.scale.shape, (1, 1))

if config.act_mapping_type == MappingType.SYMMETRIC:
self.assertEqual(w.zero_point, None)
elif config.act_mapping_type == MappingType.ASYMMETRIC:
if isinstance(config.granularity, PerRow):
self.assertEqual(w.zero_point.shape, (w.shape[0], 1))
elif isinstance(config.granularity, PerTensor):
self.assertEqual(w.zero_point.shape, (1, 1))

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
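For context on the scale-shape assertions kept above, a minimal illustrative sketch under the version=2 path; the config and expected shapes mirror the test, while the layer sizes are arbitrary:

import torch
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerRow, PerTensor

# PerRow keeps one scale per output channel: shape (out_features, 1)
row_linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16)
quantize_(row_linear, Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerRow()))
assert row_linear.weight.scale.shape == (256, 1)

# PerTensor keeps a single scale for the whole weight: shape (1, 1)
tensor_linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16)
quantize_(tensor_linear, Int8DynamicActivationInt8WeightConfig(version=2, granularity=PerTensor()))
assert tensor_linear.weight.scale.shape == (1, 1)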
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
Int8DynamicActivationInt4WeightConfig,
Int8DynamicActivationInt8WeightConfig,
Int8DynamicActivationIntxWeightConfig,
Int8StaticActivationInt8WeightConfig,
Int8WeightOnlyConfig,
IntxWeightOnlyConfig,
ModuleFqnToConfig,
@@ -150,6 +151,7 @@
"Int8DynamicActivationInt4WeightConfig",
"Int8DynamicActivationInt8WeightConfig",
"Int8DynamicActivationIntxWeightConfig",
"Int8StaticActivationInt8WeightConfig",
"Int4WeightOnlyConfig",
"Float8DynamicActivationInt4WeightConfig",
"Int8WeightOnlyConfig",
71 changes: 65 additions & 6 deletions torchao/quantization/quant_api.py
@@ -88,6 +88,7 @@
IntxPackingFormat,
IntxUnpackedToInt8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
)
from torchao.quantization.transform_module import (
_QUANTIZE_CONFIG_HANDLER,
@@ -1590,10 +1591,6 @@ def get_weight_block_size(x):
)
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
else:
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)

assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)
@@ -1608,7 +1605,7 @@ def get_weight_block_size(x):
granularity=config.granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
act_mapping_type=config.act_mapping_type,
mapping_type=config.act_mapping_type,
),
)

@@ -1617,7 +1614,10 @@

@register_quantize_module_handler(Int8DynamicActivationInt8WeightConfig)
def _int8_dynamic_activation_int8_weight_transform(
module: torch.nn.Module, config: Int8DynamicActivationInt8WeightConfig
module: torch.nn.Module,
config: Int8DynamicActivationInt8WeightConfig,
*,
parameter_name="weight",
) -> torch.nn.Module:
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
@@ -1634,6 +1634,65 @@ def _int8_dynamic_activation_int8_weight_transform(
return module


@dataclass
class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
"""
Configuration for applying int8 static quantization to activations and int8 quantization to weights.

Args:
scale (torch.Tensor): Pre-computed activation scale, typically obtained from a calibration pass.
zero_point (Optional[torch.Tensor]): Pre-computed activation zero point; only used for asymmetric activation quantization. Default is None.
act_mapping_type (Optional[MappingType]): Mapping type for activation quantization. Default is MappingType.SYMMETRIC.
granularity (Optional[Union[Granularity, List[Granularity]]]): Quantization granularity for activations and weights. Default is PerRow().
set_inductor_config (bool): if True, adjusts `torchinductor` settings to recommended values. Default is True.
version (int): Version of the config. Default is 1.
"""

scale: torch.Tensor
zero_point: Optional[torch.Tensor] = None
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
granularity: Optional[Union[Granularity, List[Granularity]]] = PerRow()
set_inductor_config: bool = True
version: int = 1

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Int8StaticActivationInt8WeightConfig"
)


@register_quantize_module_handler(Int8StaticActivationInt8WeightConfig)
def _int8_static_activation_int8_weight_transform(
module: torch.nn.Module, config: Int8StaticActivationInt8WeightConfig
):
assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)

if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()

activation_granularity, weight_granularity = _normalize_granularity(
config.granularity
)
weight = module.weight

# TODO: Symmetric/Asymmetric choice for weight quantization
# https://github.com/pytorch/ao/pull/3241#discussion_r2551515539
quantized_weight = Int8Tensor.from_hp(
weight,
granularity=weight_granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=activation_granularity,
mapping_type=config.act_mapping_type,
scale=config.scale,
zero_point=config.zero_point,
),
)
module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module


def int8_dynamic_activation_int8_semi_sparse_weight():
"""
Applies int8 dynamic symmetric per-token activation and int8 per-channel weight
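A hedged usage sketch of the new static config (not part of the diff): the activation scale is assumed to come from a separate calibration step, which this PR does not include, so the value below is a placeholder; PerTensor granularity is used so a single (1, 1) scale covers all activations.

import torch
from torchao.quantization import Int8StaticActivationInt8WeightConfig, quantize_
from torchao.quantization.granularity import PerTensor

model = torch.nn.Sequential(torch.nn.Linear(64, 64, bias=False)).to(torch.bfloat16)

# Placeholder activation scale, e.g. max(|x|) / 127 over a calibration set.
act_scale = torch.tensor([[0.02]], dtype=torch.bfloat16)

quantize_(
    model,
    Int8StaticActivationInt8WeightConfig(scale=act_scale, granularity=PerTensor()),
)

x = torch.randn(8, 64, dtype=torch.bfloat16)
y = model(x)  # activations reuse the fixed scale instead of per-batch qparams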
51 changes: 36 additions & 15 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -37,7 +37,9 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
"""

granularity: Granularity = PerRow()
act_mapping_type: MappingType = MappingType.SYMMETRIC
mapping_type: MappingType = MappingType.SYMMETRIC
scale: Optional[torch.Tensor] = None
zero_point: Optional[torch.Tensor] = None


class Int8Tensor(TorchAOBaseTensor):
@@ -58,6 +60,7 @@ class Int8Tensor(TorchAOBaseTensor):
tensor_data_names = ["qdata", "scale"]
tensor_attribute_names = []
optional_tensor_attribute_names = [
"zero_point",
"block_size",
"act_quant_kwargs",
"dtype",
@@ -67,6 +70,7 @@ def __new__(
cls: type,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
block_size: Optional[List[int]] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
dtype: Optional[torch.dtype] = None,
@@ -82,13 +86,15 @@ def __init__(
self,
qdata: torch.Tensor,
scale: torch.Tensor,
zero_point: Optional[torch.Tensor] = None,
block_size: Optional[List[int]] = None,
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.qdata = qdata
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size
self.act_quant_kwargs = act_quant_kwargs

@@ -98,6 +104,7 @@ def __repr__(self):
f"act_quant_kwargs={self.act_quant_kwargs}, "
f"qdata={self.qdata}, "
f"scale={self.scale}, "
f"zero_point={self.scale}, "
f"block_size={self.block_size}, "
f"shape={self.shape}, "
f"device={self.device}, "
@@ -110,23 +117,30 @@ def from_hp(
hp_tensor: torch.Tensor,
granularity: Granularity = PerRow(),
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
scale: Optional[torch.Tensor] = None,
zero_point: Optional[torch.Tensor] = None,
mapping_type: MappingType = MappingType.SYMMETRIC,
):
"""Create Int8Tensor from high-precision tensor"""
block_size = get_block_size(hp_tensor.shape, granularity)
block_size = list(block_size)

scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=MappingType.SYMMETRIC,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
)

# if scale and zero_point are not given, choose them dynamically
if scale is None and zero_point is None:
scale, zero_point = choose_qparams_affine(
input=hp_tensor,
mapping_type=mapping_type,
block_size=block_size,
target_dtype=torch.int8,
quant_min=-128,
quant_max=127,
scale_dtype=hp_tensor.dtype,
zero_point_dtype=torch.int8,
keepdim=True,
)

# if they are given, then use them to quantize
# this is how we support static quantization
int_data = quantize_affine(
hp_tensor,
block_size=block_size,
@@ -145,11 +159,14 @@

def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
"""Dequantize int8 tensor to floating point"""
zero_point = self.zero_point
if zero_point is not None:
zero_point = zero_point.squeeze()
return dequantize_affine(
input=self.qdata,
block_size=self.block_size,
scale=self.scale.squeeze(),
zero_point=None,
zero_point=zero_point,
input_dtype=torch.int8,
quant_min=-128,
quant_max=127,
@@ -179,7 +196,11 @@ def _(func, types, args, kwargs):

if weight_tensor.act_quant_kwargs is not None:
activation_tensor = Int8Tensor.from_hp(
activation_tensor, weight_tensor.act_quant_kwargs.granularity
activation_tensor,
granularity=weight_tensor.act_quant_kwargs.granularity,
mapping_type=weight_tensor.act_quant_kwargs.mapping_type,
scale=weight_tensor.act_quant_kwargs.scale,
zero_point=weight_tensor.act_quant_kwargs.zero_point,
)
# Dynamic activation quantization path

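To make the dynamic-vs-static split in `from_hp` concrete, a small illustrative sketch (assumptions: `from_hp` is a classmethod as in the existing code, and supplying a precomputed scale/zero_point skips `choose_qparams_affine`, as the diff shows):

import torch
from torchao.quantization.granularity import PerTensor
from torchao.quantization.quantize_.workflows.int8.int8_tensor import Int8Tensor

w = torch.randn(32, 16, dtype=torch.bfloat16)

# Dynamic path: no qparams supplied, choose_qparams_affine picks them from the data.
dyn = Int8Tensor.from_hp(w, granularity=PerTensor())

# Static path: reuse those qparams as if they had come from calibration.
static = Int8Tensor.from_hp(
    w,
    granularity=PerTensor(),
    scale=dyn.scale,
    zero_point=dyn.zero_point,
)

# Same qparams, so both routes dequantize to the same values.
torch.testing.assert_close(dyn.dequantize(), static.dequantize())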