|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | +import copy |
| 8 | +import unittest |
| 9 | + |
| 10 | +import torch |
| 11 | +from torch._inductor.utils import run_and_get_code |
| 12 | +from torch.testing import FileCheck |
| 13 | +from torch.testing._internal import common_utils |
| 14 | + |
| 15 | +from torchao.quantization import ( |
| 16 | + Int8DynamicActivationInt8WeightConfig, |
| 17 | + Int8WeightOnlyConfig, |
| 18 | + quantize_, |
| 19 | +) |
| 20 | +from torchao.quantization.granularity import PerRow, PerTensor |
| 21 | +from torchao.quantization.utils import compute_error, get_block_size |
| 22 | +from torchao.testing.model_architectures import ToyTwoLinearModel |
| 23 | +from torchao.testing.utils import TorchAOIntegrationTestCase |
| 24 | +from torchao.utils import torch_version_at_least |
| 25 | + |
| 26 | + |
| 27 | +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") |
| 28 | +@common_utils.instantiate_parametrized_tests |
| 29 | +class TestInt8Tensor(TorchAOIntegrationTestCase): |
| 30 | + def setUp(self): |
| 31 | + super().setUp() |
| 32 | + |
| 33 | + self.test_shape = (32, 20) |
| 34 | + self.dtype = torch.bfloat16 |
| 35 | + self.batch_size = 32 |
| 36 | + |
| 37 | + torch.manual_seed(42) |
| 38 | + |
| 39 | + @common_utils.parametrize( |
| 40 | + "config", |
| 41 | + [ |
| 42 | + Int8DynamicActivationInt8WeightConfig(version=2), |
| 43 | + Int8WeightOnlyConfig(version=2), |
| 44 | + ], |
| 45 | + ) |
| 46 | + def test_creation_and_attributes(self, config): |
| 47 | + """Test tensor creation, dtypes, and ranges""" |
| 48 | + linear = torch.nn.Linear( |
| 49 | + self.test_shape[1], |
| 50 | + self.test_shape[0], |
| 51 | + bias=False, |
| 52 | + dtype=self.dtype, |
| 53 | + device="cuda", |
| 54 | + ) |
| 55 | + quantize_(linear, config) |
| 56 | + |
| 57 | + w = linear.weight |
| 58 | + |
| 59 | + self.assertEqual(w.shape, self.test_shape) |
| 60 | + self.assertEqual(w.qdata.dtype, torch.int8) |
| 61 | + self.assertTrue(torch.all(w.qdata >= -128) and torch.all(w.qdata <= 127)) |
| 62 | + |
| 63 | + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float32]) |
| 64 | + @common_utils.parametrize("compile", [True, False]) |
| 65 | + @common_utils.parametrize( |
| 66 | + "config", |
| 67 | + [ |
| 68 | + Int8DynamicActivationInt8WeightConfig(version=2), |
| 69 | + Int8WeightOnlyConfig(version=2), |
| 70 | + ], |
| 71 | + ) |
| 72 | + @common_utils.parametrize( |
| 73 | + "sizes", |
| 74 | + [ |
| 75 | + ((128,), 256, 128), # 2D |
| 76 | + ((32, 128), 64, 256), # 3D |
| 77 | + ], |
| 78 | + ) |
| 79 | + def test_int8_linear_variants( |
| 80 | + self, |
| 81 | + dtype: torch.dtype, |
| 82 | + config, |
| 83 | + compile: bool, |
| 84 | + sizes: tuple, |
| 85 | + ): |
| 86 | + """Test linear operation supports including shape and compile""" |
| 87 | + M, N, K = sizes |
| 88 | + input_tensor = torch.randn(*M, K, dtype=dtype, device="cuda") |
| 89 | + model = ToyTwoLinearModel(K, N, K, dtype=dtype, device="cuda").eval() |
| 90 | + model_q = copy.deepcopy(model) |
| 91 | + |
| 92 | + quantize_(model_q, config) |
| 93 | + |
| 94 | + self.assertEqual(model_q.linear2.weight.scale.shape, (K,)) |
| 95 | + self.assertEqual(model_q.linear2.weight.scale.ndim, 1) |
| 96 | + |
| 97 | + if compile: |
| 98 | + model_q = torch.compile(model_q, fullgraph=True) |
| 99 | + |
| 100 | + output_fp = model(input_tensor) |
| 101 | + output_quantized = model_q(input_tensor) |
| 102 | + |
| 103 | + assert compute_error(output_fp, output_quantized) > 20, ( |
| 104 | + f"Quantization error is too high got a SQNR of {compute_error(output_fp, output_quantized)}" |
| 105 | + ) |
| 106 | + |
| 107 | + @common_utils.parametrize( |
| 108 | + "config", |
| 109 | + [ |
| 110 | + Int8DynamicActivationInt8WeightConfig(version=2), |
| 111 | + Int8WeightOnlyConfig(version=2), |
| 112 | + ], |
| 113 | + ) |
| 114 | + @common_utils.parametrize("device", ["cpu", "cuda"]) |
| 115 | + @common_utils.parametrize("dtype", [torch.bfloat16, torch.float16]) |
| 116 | + def test_slice(self, config, device, dtype): |
| 117 | + """Test tensor slicing with per-row quantization""" |
| 118 | + tensor_size = 256 |
| 119 | + slice_sizes = (64, 128) |
| 120 | + |
| 121 | + dummy = torch.nn.Linear( |
| 122 | + tensor_size, tensor_size, bias=False, dtype=dtype, device=device |
| 123 | + ) |
| 124 | + quantize_(dummy, config) |
| 125 | + |
| 126 | + weight1 = dummy.weight.clone().narrow(0, 0, slice_sizes[0]) |
| 127 | + weight2 = dummy.weight.clone().narrow(1, 0, slice_sizes[1]) |
| 128 | + |
| 129 | + self.assertEqual(weight1.qdata, dummy.weight.qdata.narrow(0, 0, slice_sizes[0])) |
| 130 | + self.assertEqual(weight2.qdata, dummy.weight.qdata.narrow(1, 0, slice_sizes[1])) |
| 131 | + self.assertEqual(weight1.scale, dummy.weight.scale.narrow(0, 0, slice_sizes[0])) |
| 132 | + self.assertEqual(weight2.scale, dummy.weight.scale) |
| 133 | + with self.assertRaises(NotImplementedError): |
| 134 | + _ = dummy.weight[::2] |
| 135 | + |
| 136 | + @common_utils.parametrize( |
| 137 | + "config", |
| 138 | + [ |
| 139 | + Int8DynamicActivationInt8WeightConfig, |
| 140 | + Int8WeightOnlyConfig, |
| 141 | + ], |
| 142 | + ) |
| 143 | + @common_utils.parametrize("granularity", [PerTensor(), PerRow()]) |
| 144 | + def test_index_select(self, config, granularity): |
| 145 | + """test that `x_0 = x[0]` works when `x` is a 2D quantized tensor.""" |
| 146 | + N, K = 256, 512 |
| 147 | + x = torch.randn(N, K, device="cuda", dtype=torch.bfloat16) |
| 148 | + linear = torch.nn.Linear(K, N, bias=False, dtype=torch.bfloat16, device="cuda") |
| 149 | + linear.weight.data = x |
| 150 | + |
| 151 | + config = config(version=2, granularity=granularity) |
| 152 | + quantize_(linear, config) |
| 153 | + |
| 154 | + x_int8 = linear.weight |
| 155 | + x_int8_0 = x_int8[0] |
| 156 | + |
| 157 | + # Test dequantization consistency |
| 158 | + torch.testing.assert_close( |
| 159 | + x_int8.dequantize()[0], x_int8_0.dequantize(), atol=0, rtol=0 |
| 160 | + ) |
| 161 | + |
| 162 | + # Test block_size granularity |
| 163 | + if isinstance(granularity, PerRow): |
| 164 | + self.assertEqual( |
| 165 | + list(get_block_size(x_int8.shape, x_int8.granularity)), [1, K] |
| 166 | + ) |
| 167 | + elif isinstance(granularity, PerTensor): |
| 168 | + self.assertEqual( |
| 169 | + list(get_block_size(x_int8.shape, x_int8.granularity)), [N, K] |
| 170 | + ) |
| 171 | + |
| 172 | + @common_utils.parametrize( |
| 173 | + "config", |
| 174 | + [ |
| 175 | + Int8DynamicActivationInt8WeightConfig(version=2), |
| 176 | + Int8WeightOnlyConfig(version=2), |
| 177 | + ], |
| 178 | + ) |
| 179 | + def test_dequantization_accuracy(self, config): |
| 180 | + """Test dequantization accuracy separately""" |
| 181 | + linear = torch.nn.Linear( |
| 182 | + 256, 512, bias=False, dtype=torch.bfloat16, device="cuda" |
| 183 | + ) |
| 184 | + weight_fp = copy.deepcopy(linear.weight) |
| 185 | + quantize_(linear, config) |
| 186 | + |
| 187 | + tensor = linear.weight |
| 188 | + dequantized = tensor.dequantize() |
| 189 | + self.assertEqual(dequantized.shape, weight_fp.shape) |
| 190 | + assert compute_error(dequantized, weight_fp) > 20, ( |
| 191 | + f"Dequantization error is too high to get a SQNR of {compute_error(dequantized, weight_fp)}" |
| 192 | + ) |
| 193 | + |
| 194 | + @unittest.skipIf( |
| 195 | + not torch_version_at_least("2.7.0"), "torch 2.6.0 and below has custom fx pass" |
| 196 | + ) |
| 197 | + def test_available_gpu_kernels(self): |
| 198 | + """Check which GPU kernels are used""" |
| 199 | + torch.compiler.reset() |
| 200 | + |
| 201 | + M, K, N = 128, 256, 512 |
| 202 | + m = torch.nn.Sequential( |
| 203 | + torch.nn.Linear(K, N, device="cuda", dtype=torch.bfloat16) |
| 204 | + ) |
| 205 | + |
| 206 | + config = Int8DynamicActivationInt8WeightConfig(version=2) |
| 207 | + quantize_(m, config) |
| 208 | + |
| 209 | + m = torch.compile(m) |
| 210 | + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) |
| 211 | + |
| 212 | + out, code = run_and_get_code(m, x) |
| 213 | + |
| 214 | + # Check expected kernels are present |
| 215 | + FileCheck().check_count("triton_per_fused", 1).check_count( |
| 216 | + "extern_kernels._int_mm", 1 |
| 217 | + ).check_count("triton_poi_fused", 1).run(code[0]) |
| 218 | + |
| 219 | + |
| 220 | +if __name__ == "__main__": |
| 221 | + common_utils.run_tests() |
0 commit comments