Skip to content

Commit c4273fe

Browse files
authored
Int8Tensor migration cleanup (#3407)
* Int8Tensor migration Summary: This PR creates a new Int8Tensor and updates the configs to use the new Int8Tensor flow Test Plan: To ensure BC: ``` pytest test/quantization/test_quant_api.py ``` To test new Int8Tensor: ``` pytest test/quantization/quantize_/workflows/int8/test_int8_tensor.py ``` Reviewers: Subscribers: Tasks: Tags: * ruff fixes * add init * fix ruff again * update * wip * undo update tests * fix ruff * fix varname * fix typing * add tests * fix dtype * fix ci * address granularity cr * update _choose_quant_func_and_quantize_tensor * make block size required attribute * made dtype required as well * address nits * skip per tensor weight only test for now
1 parent 7e0d439 commit c4273fe

File tree

7 files changed

+194
-214
lines changed

7 files changed

+194
-214
lines changed

test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 49 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,23 @@
1818
quantize_,
1919
)
2020
from torchao.quantization.granularity import PerRow, PerTensor
21+
from torchao.quantization.quant_primitives import MappingType
2122
from torchao.quantization.utils import compute_error, get_block_size
2223
from torchao.testing.model_architectures import ToyTwoLinearModel
2324
from torchao.testing.utils import TorchAOIntegrationTestCase
2425
from torchao.utils import torch_version_at_least
2526

27+
INT8_TEST_CONFIGS = [
28+
Int8WeightOnlyConfig(version=2, granularity=PerTensor()),
29+
Int8WeightOnlyConfig(version=2, granularity=PerRow()),
30+
Int8DynamicActivationInt8WeightConfig(
31+
version=2, granularity=PerTensor(), act_mapping_type=MappingType.SYMMETRIC
32+
),
33+
Int8DynamicActivationInt8WeightConfig(
34+
version=2, granularity=PerRow(), act_mapping_type=MappingType.SYMMETRIC
35+
),
36+
]
37+
2638

2739
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2840
@common_utils.instantiate_parametrized_tests
@@ -36,13 +48,7 @@ def setUp(self):
3648

3749
torch.manual_seed(42)
3850

39-
@common_utils.parametrize(
40-
"config",
41-
[
42-
Int8DynamicActivationInt8WeightConfig(version=2),
43-
Int8WeightOnlyConfig(version=2),
44-
],
45-
)
51+
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
4652
def test_creation_and_attributes(self, config):
4753
"""Test tensor creation, dtypes, and ranges"""
4854
linear = torch.nn.Linear(
@@ -60,15 +66,17 @@ def test_creation_and_attributes(self, config):
6066
self.assertEqual(w.qdata.dtype, torch.int8)
6167
self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))
6268

69+
if isinstance(config.granularity, PerRow):
70+
self.assertEqual(w.scale.shape, (w.shape[0], 1))
71+
elif isinstance(config.granularity, PerTensor):
72+
self.assertEqual(w.scale.shape, (1, 1))
73+
74+
if hasattr(config, "act_mapping_type"):
75+
self.assertEqual(w.act_quant_kwargs.mapping_type, config.act_mapping_type)
76+
6377
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
6478
@common_utils.parametrize("compile", [True, False])
65-
@common_utils.parametrize(
66-
"config",
67-
[
68-
Int8DynamicActivationInt8WeightConfig(version=2),
69-
Int8WeightOnlyConfig(version=2),
70-
],
71-
)
79+
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
7280
@common_utils.parametrize(
7381
"sizes",
7482
[
@@ -84,17 +92,28 @@ def test_int8_linear_variants(
8492
sizes: tuple,
8593
):
8694
"""Test linear operation supports including shape and compile"""
95+
torch.compiler.reset()
96+
8797
M, N, K = sizes
8898
input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
8999
model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
90100
model_q = copy.deepcopy(model)
91101

92102
quantize_(model_q, config)
93103

94-
self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
95-
self.assertEqual(model_q.linear2.weight.scale.ndim, 1)
104+
if isinstance(config.granularity, PerRow):
105+
self.assertEqual(model_q.linear2.weight.scale.shape, (K, 1))
106+
elif isinstance(config.granularity, PerTensor):
107+
self.assertEqual(model_q.linear2.weight.scale.shape, (1, 1))
108+
109+
self.assertEqual(model_q.linear2.weight.scale.ndim, 2)
96110

97111
if compile:
112+
if isinstance(config, Int8WeightOnlyConfig) and isinstance(
113+
config.granularity, PerTensor
114+
):
115+
# currently the inductor lowering for weight only quant in core does not support per-tensor gpu, so this errors. Skipping for now, but will address this in core
116+
return
98117
model_q = torch.compile(model_q, fullgraph=True)
99118

100119
output_fp = model(input_tensor)
@@ -104,13 +123,7 @@ def test_int8_linear_variants(
104123
f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
105124
)
106125

107-
@common_utils.parametrize(
108-
"config",
109-
[
110-
Int8DynamicActivationInt8WeightConfig(version=2),
111-
Int8WeightOnlyConfig(version=2),
112-
],
113-
)
126+
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
114127
@common_utils.parametrize("device", ["cpu", "cuda"])
115128
@common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
116129
def test_slice(self, config, device, dtype):
@@ -128,27 +141,24 @@ def test_slice(self, config, device, dtype):
128141

129142
self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
130143
self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
131-
self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
144+
145+
if isinstance(config.granularity, PerRow):
146+
self.assertEqual(
147+
weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])
148+
)
149+
132150
self.assertEqual(weight2.scale, dummy.weight.scale)
133151
with self.assertRaises(NotImplementedError):
134152
_ = dummy.weight[::2]
135153

136-
@common_utils.parametrize(
137-
"config",
138-
[
139-
Int8DynamicActivationInt8WeightConfig,
140-
Int8WeightOnlyConfig,
141-
],
142-
)
143-
@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
144-
def test_index_select(self, config, granularity):
154+
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
155+
def test_index_select(self, config):
145156
"""test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
146157
N, K = 256, 512
147158
x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
148159
linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
149160
linear.weight.data = x
150161

151-
config = config(version=2, granularity=granularity)
152162
quantize_(linear, config)
153163

154164
x_int8 = linear.weight
@@ -160,22 +170,16 @@ def test_index_select(self, config, granularity):
160170
)
161171

162172
# Test block_size granularity
163-
if isinstance(granularity, PerRow):
173+
if isinstance(config.granularity, PerRow):
164174
self.assertEqual(
165-
list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
175+
list(get_block_size(x_int8.shape, config.granularity)), [1, K]
166176
)
167-
elif isinstance(granularity, PerTensor):
177+
elif isinstance(config.granularity, PerTensor):
168178
self.assertEqual(
169-
list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
179+
list(get_block_size(x_int8.shape, config.granularity)), [N, K]
170180
)
171181

172-
@common_utils.parametrize(
173-
"config",
174-
[
175-
Int8DynamicActivationInt8WeightConfig(version=2),
176-
Int8WeightOnlyConfig(version=2),
177-
],
178-
)
182+
@common_utils.parametrize("config", INT8_TEST_CONFIGS)
179183
def test_dequantization_accuracy(self, config):
180184
"""Test dequantization accuracy separately"""
181185
linear = torch.nn.Linear(

torchao/quantization/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@
9898
Int4PreshuffledTensor,
9999
Int4Tensor,
100100
Int4TilePackedTo4dTensor,
101+
Int8Tensor,
101102
IntxOpaqueTensor,
102103
IntxUnpackedToInt8Tensor,
103104
)
@@ -164,6 +165,7 @@
164165
"FqnToConfig",
165166
"ModuleFqnToConfig",
166167
# tensor subclasses
168+
"Int8Tensor",
167169
"Int4Tensor",
168170
"Int4PlainInt32Tensor",
169171
"Int4PreshuffledTensor",

torchao/quantization/quant_api.py

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1341,6 +1341,10 @@ class Int8WeightOnlyConfig(AOBaseConfig):
13411341

13421342
def __post_init__(self):
13431343
torch._C._log_api_usage_once("torchao.quantization.Int8WeightOnlyConfig")
1344+
if self.version == 2:
1345+
assert self.group_size is None, (
1346+
f"Only support version 2 with group_size=None, got {self.group_size}"
1347+
)
13441348

13451349

13461350
# for BC
@@ -1522,9 +1526,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
15221526
layout: Optional[Layout] = PlainLayout()
15231527
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
15241528
weight_only_decode: bool = False
1525-
# TODO: Revisit for supported granularitys
1526-
# https://github.com/pytorch/ao/pull/3241#discussion_r2551497849
1527-
granularity: Optional[Granularity] = PerRow()
1529+
granularity: Granularity = PerRow()
15281530
set_inductor_config: bool = True
15291531
version: int = 1
15301532

@@ -1541,37 +1543,30 @@ def __post_init__(self):
15411543

15421544

15431545
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
1544-
layout = config.layout
1545-
act_mapping_type = config.act_mapping_type
1546-
weight_only_decode = config.weight_only_decode
1547-
1548-
in_features = weight.shape[-1]
1549-
# int8 dynamic quantization only has benefit when in_feature > 16
1550-
if in_features <= 16:
1551-
logger.info(
1552-
f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
1553-
f" because `in_feature` is <= 16: {in_features}"
1554-
)
1555-
return weight
1546+
if config.version == 1:
1547+
layout = config.layout
1548+
act_mapping_type = config.act_mapping_type
1549+
weight_only_decode = config.weight_only_decode
1550+
1551+
in_features = weight.shape[-1]
1552+
# int8 dynamic quantization only has benefit when in_feature > 16
1553+
if in_features <= 16:
1554+
logger.info(
1555+
f"Skipping applying Int8DynamicActivationInt8WeightConfig to weight of shape {weight.shape}"
1556+
f" because `in_feature` is <= 16: {in_features}"
1557+
)
1558+
return weight
15561559

1557-
# weight settings
1558-
mapping_type = MappingType.SYMMETRIC
1559-
weight_zero_point_domain = ZeroPointDomain.NONE
1560+
# weight settings
1561+
mapping_type = MappingType.SYMMETRIC
1562+
weight_zero_point_domain = ZeroPointDomain.NONE
15601563

1561-
target_dtype = torch.int8
1562-
eps = torch.finfo(torch.float32).eps
1563-
zero_point_dtype = torch.int64
1564+
def get_weight_block_size(x):
1565+
return tuple([1 for _ in range(x.dim() - 1)] + [x.shape[-1]])
15641566

1565-
if config.version == 1:
1566-
warnings.warn(
1567-
"Config Deprecation: version 1 of Int8DynamicActivationInt8WeightConfig is deprecated and will no longer be supported in a future release, please use version 2, see https://github.com/pytorch/ao/issues/2752 for more details"
1568-
)
1569-
if isinstance(config.granularity, PerTensor):
1570-
block_size = weight.shape
1571-
else:
1572-
block_size = tuple(
1573-
[1 for _ in range(weight.dim() - 1)] + [weight.shape[-1]]
1574-
)
1567+
target_dtype = torch.int8
1568+
eps = torch.finfo(torch.float32).eps
1569+
zero_point_dtype = torch.int64
15751570

15761571
if weight_only_decode:
15771572
input_quant_func = _int8_symm_per_token_reduced_range_quant_noop_decode
@@ -1582,7 +1577,8 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
15821577
else:
15831578
input_quant_func = _int8_asymm_per_token_quant
15841579

1585-
quantized_weight = to_affine_quantized_intx(
1580+
block_size = get_weight_block_size(weight)
1581+
new_weight = to_affine_quantized_intx(
15861582
weight,
15871583
mapping_type,
15881584
block_size,
@@ -1592,24 +1588,32 @@ def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
15921588
_layout=layout,
15931589
zero_point_domain=weight_zero_point_domain,
15941590
)
1595-
quantized_weight = to_linear_activation_quantized(
1596-
quantized_weight, input_quant_func
1597-
)
1591+
quantized_weight = to_linear_activation_quantized(new_weight, input_quant_func)
15981592
else:
15991593
from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
16001594
QuantizeTensorToInt8Kwargs,
16011595
)
16021596

1597+
assert config.granularity in {PerRow(), PerTensor()}, (
1598+
"Only PerRow and PerTensor are supported"
1599+
)
1600+
weight_granularity = config.granularity
1601+
act_granularity = config.granularity
1602+
1603+
assert config.act_mapping_type == MappingType.SYMMETRIC, (
1604+
"asymmetric dynamic quant not supported currently"
1605+
)
16031606
assert config.version == 2, f"Unexpected version: {config.version}"
16041607

16051608
# TODO: Symmentric/Asymmetric choice for weight quantization
16061609
# https://github.com/pytorch/ao/pull/3241#discussion_r2551515539
1607-
# TODO: Add block_size args to return in from_hp
1608-
# https://github.com/pytorch/ao/pull/3241#discussion_r2552016429
16091610
quantized_weight = Int8Tensor.from_hp(
16101611
weight,
1611-
granularity=config.granularity,
1612-
act_quant_kwargs=QuantizeTensorToInt8Kwargs(granularity=config.granularity),
1612+
granularity=weight_granularity,
1613+
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
1614+
granularity=act_granularity,
1615+
mapping_type=config.act_mapping_type,
1616+
),
16131617
)
16141618

16151619
return quantized_weight

torchao/quantization/quant_primitives.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,6 +1217,7 @@ def choose_qparams_affine(
12171217
eps: Optional[float] = None,
12181218
scale_dtype: Optional[torch.dtype] = None,
12191219
zero_point_dtype: Optional[torch.dtype] = torch.int32,
1220+
keepdim: bool = False,
12201221
) -> Tuple[torch.Tensor, torch.Tensor]:
12211222
"""
12221223
Args:
@@ -1247,6 +1248,7 @@ def choose_qparams_affine(
12471248
eps,
12481249
scale_dtype,
12491250
zero_point_dtype,
1251+
keepdim,
12501252
)
12511253

12521254

@@ -1521,6 +1523,7 @@ def _choose_qparams_affine(
15211523
eps: Optional[float] = None,
15221524
scale_dtype: Optional[torch.dtype] = None,
15231525
zero_point_dtype: Optional[torch.dtype] = None,
1526+
keepdim: bool = False,
15241527
) -> Tuple[torch.Tensor, torch.Tensor]:
15251528
"""op definition that has compatible signatures with custom op library
15261529
@@ -1550,8 +1553,8 @@ def _choose_qparams_affine(
15501553
)
15511554
input = input.view(shape_for_reduction)
15521555

1553-
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
1554-
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
1556+
min_val = torch.amin(input, dim=reduction_dims, keepdim=keepdim)
1557+
max_val = torch.amax(input, dim=reduction_dims, keepdim=keepdim)
15551558

15561559
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
15571560
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor(
3939
"""
4040
from torchao.quantization.quantize_.workflows import (
4141
Float8Tensor,
42+
Int8Tensor,
4243
QuantizeTensorToFloat8Kwargs,
44+
QuantizeTensorToInt8Kwargs,
4345
)
4446

4547
if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
@@ -53,4 +55,11 @@ def _choose_quant_func_and_quantize_tensor(
5355
quant_kwargs.kernel_preference,
5456
)
5557

58+
if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
59+
return Int8Tensor.from_hp(
60+
tensor,
61+
quant_kwargs.granularity,
62+
mapping_type=quant_kwargs.mapping_type,
63+
)
64+
5665
raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242
"QuantizeTensorToInt8Kwargs",
4343
"Float8Tensor",
4444
"QuantizeTensorToFloat8Kwargs",
45+
"Int8Tensor",
46+
"QuantizeTensorToInt8Kwargs",
4547
"Int4ChooseQParamsAlgorithm",
4648
"Int4PackingFormat",
4749
"IntxChooseQParamsAlgorithm",

0 commit comments

Comments
 (0)