Merged
94 changes: 49 additions & 45 deletions test/quantization/quantize_/workflows/int8/test_int8_tensor.py
@@ -18,11 +18,23 @@
quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import torch_version_at_least

INT8_TEST_CONFIGS = [
Int8WeightOnlyConfig(version=2, granularity=PerTensor()),
Int8WeightOnlyConfig(version=2, granularity=PerRow()),
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
),
Int8DynamicActivationInt8WeightConfig(
version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
),
]


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
@@ -36,13 +48,7 @@ def setUp(self):

torch.manual_seed(42)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
def test_creation_and_attributes(self, config):
"""Test tensor creation, dtypes, and ranges"""
linear = torch.nn.Linear(
@@ -60,15 +66,17 @@ def test_creation_and_attributes(self, config):
self.assertEqual(w.qdata.dtype, torch.int8)
self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))

if isinstance(config.granularity, PerRow):
self.assertEqual(w.scale.shape, (w.shape[0], 1))
elif isinstance(config.granularity, PerTensor):
self.assertEqual(w.scale.shape, (1, 1))

if hasattr(config, "act_mapping_type"):
self.assertEqual(w.act_quant_kwargs.mapping_type, config.act_mapping_type)

@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
@common_utils.parametrize(
"sizes",
[
@@ -84,17 +92,28 @@ def test_int8_linear_variants(
sizes: tuple,
):
"""Test linear operation supports including shape and compile"""
torch.compiler.reset()

M, N, K = sizes
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
model_q = copy.deepcopy(model)

quantize_(model_q, config)

self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
if isinstance(config.granularity, PerRow):
self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
elif isinstance(config.granularity, PerTensor):
self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))

self.assertEqual(model_q.linear2.weight.scale.ndim, 2)

if compile:
if isinstance(config, Int8WeightOnlyConfig) and isinstance(
config.granularity, PerTensor
):
# The inductor lowering for weight-only quant in core does not currently support per-tensor granularity on GPU, so compile fails here. Skipping for now; this will be addressed in core.
return
model_q = torch.compile(model_q, fullgraph=True)

output_fp = model(input_tensor)
@@ -104,13 +123,7 @@ def test_int8_linear_variants(
f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
@common_utils.parametrize("device", ["cpu", "cuda"])
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
def test_slice(self, config, device, dtype):
@@ -128,27 +141,24 @@ def test_slice(self, config, device, dtype):

self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))

if isinstance(config.granularity, PerRow):
self.assertEqual(
weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
)

self.assertEqual(weight2.scale, dummy.weight.scale)
with self.assertRaises(NotImplementedError):
_ = dummy.weight[::2]

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig,
Int8WeightOnlyConfig,
],
)
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_index_select(self, config, granularity):
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
def test_index_select(self, config):
"""test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
N, K = 256, 512
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
linear.weight.data = x

config = config(version=2, granularity=granularity)
quantize_(linear, config)

x_int8 = linear.weight
@@ -160,22 +170,16 @@ def test_index_select(self, config, granularity):
)

# Test block_size granularity
if isinstance(granularity, PerRow):
if isinstance(config.granularity, PerRow):
self.assertEqual(
list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
list(get_block_size(x_int8.shape, config.granularity)), [1, K]
)
elif isinstance(granularity, PerTensor):
elif isinstance(config.granularity, PerTensor):
self.assertEqual(
list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
list(get_block_size(x_int8.shape, config.granularity)), [N, K]
)

@common_utils.parametrize(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
def test_dequantization_accuracy(self, config):
"""Test dequantization accuracy separately"""
linear = torch.nn.Linear(
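The updated tests pin down the scale layout for version-2 int8 weights: a 2-D scale of shape (out_features, 1) under PerRow and (1, 1) under PerTensor. A minimal sketch of the behavior being asserted (assumes a CUDA build of torchao with the version-2 configs; import paths follow the test file above):

```python
import torch
from torchao.quantization import Int8WeightOnlyConfig, quantize_
from torchao.quantization.granularity import PerRow, PerTensor

# Per-row: one scale per output channel, kept 2-D as (out_features, 1).
linear_row = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear_row, Int8WeightOnlyConfig(version=2, granularity=PerRow()))
print(linear_row.weight.qdata.dtype, linear_row.weight.scale.shape)  # torch.int8, (256, 1)

# Per-tensor: a single scale, also kept 2-D as (1, 1).
linear_tensor = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear_tensor, Int8WeightOnlyConfig(version=2, granularity=PerTensor()))
print(linear_tensor.weight.scale.shape)  # (1, 1)
```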
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -98,6 +98,7 @@
Int4PreshuffledTensor,
Int4Tensor,
Int4TilePackedTo4dTensor,
Int8Tensor,
IntxOpaqueTensor,
IntxUnpackedToInt8Tensor,
)
@@ -164,6 +165,7 @@
"FqnToConfig",
"ModuleFqnToConfig",
# tensor subclasses
"Int8Tensor",
"Int4Tensor",
"Int4PlainInt32Tensor",
"Int4PreshuffledTensor",
82 changes: 43 additions & 39 deletions torchao/quantization/quant_api.py
@@ -1341,6 +1341,10 @@ class Int8WeightOnlyConfig(AOBaseConfig):

def __post_init__(self):
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
if self.version == 2:
assert self.group_size is None, (
f"Only support version 2 with group_size=None, got {self.group_size}"
)


# for BC
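With version 2, Int8WeightOnlyConfig only supports whole-row or whole-tensor granularity, so __post_init__ now rejects a group_size. A small sketch of the new constraint (assuming the public import path; exact paths may differ):

```python
from torchao.quantization import Int8WeightOnlyConfig
from torchao.quantization.granularity import PerRow

# Fine: version 2 with group_size left as None.
cfg = Int8WeightOnlyConfig(version=2, granularity=PerRow())

# Rejected by the new __post_init__ assertion.
try:
    Int8WeightOnlyConfig(version=2, group_size=64)
except AssertionError as e:
    print(e)  # "Only support version 2 with group_size=None, got 64"
```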
@@ -1522,9 +1526,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
layout: Optional[Layout] = PlainLayout()
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
weight_only_decode: bool = False
# TODO: Revisit for supported granularitys
# https://github.com/pytorch/ao/pull/3241#discussion_r2551497849
granularity: Optional[Granularity] = PerRow()
granularity: Granularity = PerRow()
set_inductor_config: bool = True
version: int = 1

@@ -1541,37 +1543,30 @@ def __post_init__(self):


def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
layout = config.layout
act_mapping_type = config.act_mapping_type
weight_only_decode = config.weight_only_decode

in_features = weight.shape[-1]
# int8 dynamic quantization only has benefit when in_feature > 16
if in_features <= 16:
logger.info(
f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
f" because `in_feature` is <= 16: {in_features}"
)
return weight
if config.version == 1:
layout = config.layout
act_mapping_type = config.act_mapping_type
weight_only_decode = config.weight_only_decode

in_features = weight.shape[-1]
# int8 dynamic quantization only has benefit when in_feature > 16
if in_features <= 16:
logger.info(
f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
f" because `in_feature` is <= 16: {in_features}"
)
return weight

# weight settings
mapping_type = MappingType.SYMMETRIC
weight_zero_point_domain = ZeroPointDomain.NONE
# weight settings
mapping_type = MappingType.SYMMETRIC
weight_zero_point_domain = ZeroPointDomain.NONE

target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64
def get_weight_block_size(x):
return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]])

if config.version == 1:
warnings.warn(
"Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
)
if isinstance(config.granularity, PerTensor):
block_size = weight.shape
else:
block_size = tuple(
[1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]
)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
zero_point_dtype = torch.int64

if weight_only_decode:
input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
@@ -1582,7 +1577,8 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
else:
input_quant_func = _int8_asymm_per_token_quant

quantized_weight = to_affine_quantized_intx(
block_size = get_weight_block_size(weight)
new_weight = to_affine_quantized_intx(
weight,
mapping_type,
block_size,
@@ -1592,24 +1588,32 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
_layout=layout,
zero_point_domain=weight_zero_point_domain,
)
quantized_weight = to_linear_activation_quantized(
quantized_weight, input_quant_func
)
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
else:
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
QuantizeTensorToInt8Kwargs,
)

assert config.granularity in {PerRow(), PerTensor()}, (
"Only PerRow and PerTensor are supported"
)
weight_granularity = config.granularity
act_granularity = config.granularity

assert config.act_mapping_type == MappingType.SYMMETRIC, (
"asymmetric dynamic quant not supported currently"
)
assert config.version == 2, f"Unexpected version: {config.version}"

# TODO: Symmetric/Asymmetric choice for weight quantization
# https://github.com/pytorch/ao/pull/3241#discussion_r2551515539
# TODO: Add block_size args to return in from_hp
# https://github.com/pytorch/ao/pull/3241#discussion_r2552016429
quantized_weight = Int8Tensor.from_hp(
weight,
granularity=config.granularity,
act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity),
granularity=weight_granularity,
[Review thread on the line above]
Contributor: mapping_type for weight is not passed?
Contributor Author: the config doesn't have an option for the weight_mapping_type, so we just use the default (symmetric).
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
granularity=act_granularity,
mapping_type=config.act_mapping_type,
),
)

return quantized_weight
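For version 2, the quantize path above builds an Int8Tensor directly: the weight is quantized at config.granularity with the default symmetric mapping (the config has no weight_mapping_type, per the review exchange), and the activation settings are carried in act_quant_kwargs. A hedged sketch of the equivalent direct construction (class names and kwargs are taken from this diff; behavior outside the quantize_ flow is an assumption):

```python
import torch
from torchao.quantization import Int8Tensor
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
    QuantizeTensorToInt8Kwargs,
)

weight = torch.randn(256, 128, dtype=torch.bfloat16)

qweight = Int8Tensor.from_hp(
    weight,
    granularity=PerRow(),  # weight mapping type is not configurable; symmetric by default
    act_quant_kwargs=QuantizeTensorToInt8Kwargs(
        granularity=PerRow(),
        mapping_type=MappingType.SYMMETRIC,  # version 2 only accepts symmetric activations
    ),
)

print(qweight.qdata.dtype, qweight.scale.shape)  # torch.int8, (256, 1)
```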
7 changes: 5 additions & 2 deletions torchao/quantization/quant_primitives.py
@@ -1217,6 +1217,7 @@ def choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = torch.int32,
keepdim: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -1247,6 +1248,7 @@
eps,
scale_dtype,
zero_point_dtype,
keepdim,
)


@@ -1521,6 +1523,7 @@ def _choose_qparams_affine(
eps: Optional[float] = None,
scale_dtype: Optional[torch.dtype] = None,
zero_point_dtype: Optional[torch.dtype] = None,
keepdim: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""op definition that has compatible signatures with custom op library

@@ -1550,8 +1553,8 @@ def _choose_qparams_affine(
)
input = input.view(shape_for_reduction)

min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim)
max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim)

min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
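The new keepdim flag is threaded into the amin/amax reduction so callers can get qparams that keep the reduced dimensions (e.g. a (N, 1) per-row scale instead of (N,)). A plain-torch illustration of the shape difference (not the torchao API itself):

```python
import torch

w = torch.randn(4, 8)        # reduce over the last dim, as per-row quantization does
reduction_dims = [1]

min_flat = torch.amin(w, dim=reduction_dims, keepdim=False)  # shape (4,), previous behavior
min_keep = torch.amin(w, dim=reduction_dims, keepdim=True)   # shape (4, 1), with keepdim=True

print(min_flat.shape, min_keep.shape)
```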
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor(
"""
from torchao.quantization.quantize_.workflows import (
Float8Tensor,
Int8Tensor,
QuantizeTensorToFloat8Kwargs,
QuantizeTensorToInt8Kwargs,
)

if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
@@ -53,4 +55,11 @@
quant_kwargs.kernel_preference,
)

if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
return Int8Tensor.from_hp(
tensor,
quant_kwargs.granularity,
mapping_type=quant_kwargs.mapping_type,
)

raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")
2 changes: 2 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -42,6 +42,8 @@
"QuantizeTensorToInt8Kwargs",
"Float8Tensor",
"QuantizeTensorToFloat8Kwargs",
"Int8Tensor",
"QuantizeTensorToInt8Kwargs",
"Int4ChooseQParamsAlgorithm",
"Int4PackingFormat",
"IntxChooseQParamsAlgorithm",