
Commit 3c3515a

Introduce int8 quantization api (version 2) (#3391)
Summary: Introduce a new tensor subclass API. The main features are:

- Int8Tensor: the main API, which handles quantization and dequantization operations
- Utility operation functions: tensor slicing and index selection

This API is wired into the existing configs (Int8WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig) via the version argument (version=2) and is not the default.

Related Issue/PR: #3241 (reland)

Test plan: pytest -sv test/quantization/quantize_/workflows/int8/test_int8_tensor.py

PERF Test: https://github.com/pytorch/ao/blob/main/tutorials/quantize_vit/run_vit_b_quant.py with a batch size of 32:

    API   With torch.compile   Without torch.compile
    Old   65.47 ms             234.39 ms
    New   63.30 ms             239.30 ms

Future Plan: #3241 (review)
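As a minimal usage sketch of the new opt-in path (not taken from the commit itself; it simply exercises the quantize_ entry point and Int8WeightOnlyConfig(version=2) the way the new tests do, with an illustrative layer shape and CUDA assumed):

    import torch
    from torchao.quantization import Int8WeightOnlyConfig, quantize_

    # Opt in to the new Int8Tensor subclass via version=2 (not the default path).
    linear = torch.nn.Linear(256, 512, bias=False, dtype=torch.bfloat16, device="cuda")
    quantize_(linear, Int8WeightOnlyConfig(version=2))

    x = torch.randn(32, 256, dtype=torch.bfloat16, device="cuda")
    y = linear(x)                      # linear now runs with the int8-quantized weight
    w_dq = linear.weight.dequantize()  # recover a high-precision copy for accuracy checks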
1 parent 119b4d1 commit 3c3515a

File tree: 6 files changed (+672 / -42 lines)

docs/source/quantization_overview.rst

Lines changed: 3 additions & 1 deletion
@@ -5,7 +5,7 @@ First we want to lay out the torchao stack::
 
   Quantization Algorithms/Flows: weight only/dynamic/static quantization, hqq, awq, gptq etc.
   ---------------------------------------------------------------------------------------------
-  Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Float8Tensor
+  Quantized Tensors (derived dtypes): Int4Tensor, Int4PreshuffledTensor, Int8Tensor, Float8Tensor
   ---------------------------------------------------------------------------------------------
   Quantization Primitive Ops/Efficient Kernels: matmul, quantize, dequantize
   ---------------------------------------------------------------------------------------------

@@ -88,6 +88,8 @@ So in general we structure Tensor subclasses by dervied dtpype and packing forma
      - scaled int4
      - preshuffled (special format to optimize for loading)
      - float8 act + int4 weight dynamic quantization and int4 weight only quantization
+   * - Int8Tensor
+     - plain (no packing needed)
 
 .. note::
    We don't have granularity specific tensor subclasses, i.e. no Float8RowwiseTensor or Float8BlockwiseTensor, all granularities are implemented in the same Tensor, we typically use a general `block_size` attribute to distinguish between different granularities, and each Tensor is allowed to support only a subset of all possible granularity options.
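For illustration, a small sketch of that block_size convention, mirroring the granularity checks in the new test file below (get_block_size, PerRow, and PerTensor are the helpers imported there; the weight shape is arbitrary):

    from torchao.quantization.granularity import PerRow, PerTensor
    from torchao.quantization.utils import get_block_size

    N, K = 256, 512
    weight_shape = (N, K)

    # PerRow: one scale per output row, so each block covers a single row of K elements.
    assert list(get_block_size(weight_shape, PerRow())) == [1, K]

    # PerTensor: one scale for the whole weight, so the block spans the full tensor.
    assert list(get_block_size(weight_shape, PerTensor())) == [N, K]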
test/quantization/quantize_/workflows/int8/test_int8_tensor.py

Lines changed: 221 additions & 0 deletions
@@ -0,0 +1,221 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import copy
import unittest

import torch
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal import common_utils

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.utils import compute_error, get_block_size
from torchao.testing.model_architectures import ToyTwoLinearModel
from torchao.testing.utils import TorchAOIntegrationTestCase
from torchao.utils import torch_version_at_least


@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.instantiate_parametrized_tests
class TestInt8Tensor(TorchAOIntegrationTestCase):
    def setUp(self):
        super().setUp()

        self.test_shape = (32, 20)
        self.dtype = torch.bfloat16
        self.batch_size = 32

        torch.manual_seed(42)

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_creation_and_attributes(self, config):
        """Test tensor creation, dtypes, and ranges"""
        linear = torch.nn.Linear(
            self.test_shape[1],
            self.test_shape[0],
            bias=False,
            dtype=self.dtype,
            device="cuda",
        )
        quantize_(linear, config)

        w = linear.weight

        self.assertEqual(w.shape, self.test_shape)
        self.assertEqual(w.qdata.dtype, torch.int8)
        self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127))

    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32])
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize(
        "sizes",
        [
            ((128,), 256, 128),  # 2D
            ((32, 128), 64, 256),  # 3D
        ],
    )
    def test_int8_linear_variants(
        self,
        dtype: torch.dtype,
        config,
        compile: bool,
        sizes: tuple,
    ):
        """Test linear operation supports including shape and compile"""
        M, N, K = sizes
        input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda")
        model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval()
        model_q = copy.deepcopy(model)

        quantize_(model_q, config)

        self.assertEqual(model_q.linear2.weight.scale.shape, (K,))
        self.assertEqual(model_q.linear2.weight.scale.ndim, 1)

        if compile:
            model_q = torch.compile(model_q, fullgraph=True)

        output_fp = model(input_tensor)
        output_quantized = model_q(input_tensor)

        assert compute_error(output_fp, output_quantized) > 20, (
            f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}"
        )

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    @common_utils.parametrize("device", ["cpu", "cuda"])
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16])
    def test_slice(self, config, device, dtype):
        """Test tensor slicing with per-row quantization"""
        tensor_size = 256
        slice_sizes = (64, 128)

        dummy = torch.nn.Linear(
            tensor_size, tensor_size, bias=False, dtype=dtype, device=device
        )
        quantize_(dummy, config)

        weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0])
        weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1])

        self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0]))
        self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1]))
        self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0]))
        self.assertEqual(weight2.scale, dummy.weight.scale)
        with self.assertRaises(NotImplementedError):
            _ = dummy.weight[::2]

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig,
            Int8WeightOnlyConfig,
        ],
    )
    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
    def test_index_select(self, config, granularity):
        """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor."""
        N, K = 256, 512
        x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)
        linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda")
        linear.weight.data = x

        config = config(version=2, granularity=granularity)
        quantize_(linear, config)

        x_int8 = linear.weight
        x_int8_0 = x_int8[0]

        # Test dequantization consistency
        torch.testing.assert_close(
            x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0
        )

        # Test block_size granularity
        if isinstance(granularity, PerRow):
            self.assertEqual(
                list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K]
            )
        elif isinstance(granularity, PerTensor):
            self.assertEqual(
                list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K]
            )

    @common_utils.parametrize(
        "config",
        [
            Int8DynamicActivationInt8WeightConfig(version=2),
            Int8WeightOnlyConfig(version=2),
        ],
    )
    def test_dequantization_accuracy(self, config):
        """Test dequantization accuracy separately"""
        linear = torch.nn.Linear(
            256, 512, bias=False, dtype=torch.bfloat16, device="cuda"
        )
        weight_fp = copy.deepcopy(linear.weight)
        quantize_(linear, config)

        tensor = linear.weight
        dequantized = tensor.dequantize()
        self.assertEqual(dequantized.shape, weight_fp.shape)
        assert compute_error(dequantized, weight_fp) > 20, (
            f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}"
        )

    @unittest.skipIf(
        not torch_version_at_least("2.7.0"), "torch 2.6.0 and below has custom fx pass"
    )
    def test_available_gpu_kernels(self):
        """Check which GPU kernels are used"""
        torch.compiler.reset()

        M, K, N = 128, 256, 512
        m = torch.nn.Sequential(
            torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16)
        )

        config = Int8DynamicActivationInt8WeightConfig(version=2)
        quantize_(m, config)

        m = torch.compile(m)
        x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)

        out, code = run_and_get_code(m, x)

        # Check expected kernels are present
        FileCheck().check_count("triton_per_fused", 1).check_count(
            "extern_kernels._int_mm", 1
        ).check_count("triton_poi_fused", 1).run(code[0])


if __name__ == "__main__":
    common_utils.run_tests()
