
Commit ab6bc89

Move Int4OpaqueTensor to prototype (#3378)

1 parent: 3ad4d0a

11 files changed: +112 -45 lines

test/prototype/test_awq.py

Lines changed: 2 additions & 1 deletion

@@ -15,6 +15,7 @@
 )

 from torchao.prototype.awq import AWQConfig, AWQStep
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
 from torchao.quantization import Int4WeightOnlyConfig, quantize_
 from torchao.utils import _is_fbgemm_gpu_genai_available, torch_version_at_least

@@ -76,7 +77,7 @@ def forward(self, x):
         # Note: the functionality unit test doesn't work for hqq
         Int4WeightOnlyConfig(group_size=128, int4_packing_format="tile_packed_to_4d"),
     ],
-    "cpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")],
+    "cpu": [Int4WeightOnlyOpaqueTensorConfig(group_size=128)],
     "xpu": [Int4WeightOnlyConfig(group_size=128, int4_packing_format="plain_int32")],
 }
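For downstream users, the migration this hunk implies is a one-line swap; a minimal sketch based on the diff above:

# before this commit (removed): opaque packing was a flag on the shared config
#   Int4WeightOnlyConfig(group_size=128, int4_packing_format="opaque")
# after this commit: a dedicated prototype config selects the opaque CPU path
from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig

config = Int4WeightOnlyOpaqueTensorConfig(group_size=128)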

test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py renamed to test/prototype/test_int4_opaque_tensor.py

Lines changed: 5 additions & 8 deletions

@@ -15,10 +15,8 @@
     run_tests,
 )

-from torchao.quantization import (
-    Int4WeightOnlyConfig,
-    quantize_,
-)
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
+from torchao.quantization import quantize_
 from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
@@ -27,9 +25,8 @@


 def get_config(group_size, use_hqq):
-    return Int4WeightOnlyConfig(
+    return Int4WeightOnlyOpaqueTensorConfig(
         group_size=group_size,
-        int4_packing_format="opaque",
         int4_choose_qparams_algorithm="hqq" if use_hqq else "tinygemm",
     )

@@ -68,7 +65,7 @@ def test_module_path(self, dtype, use_hqq):
         quantize_(linear, get_config(group_size=128, use_hqq=use_hqq))
         self.assertEqual(
             str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4OpaqueTensor'>",
+            "<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
         )

         with tempfile.NamedTemporaryFile() as f:
@@ -77,7 +74,7 @@ def test_module_path(self, dtype, use_hqq):
             state_dict = torch.load(f)
             self.assertEqual(
                 str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4OpaqueTensor'>",
+                "<class 'torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor'>",
             )

     @parametrize("use_hqq", [True, False])
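The module-path assertions above mean checkpoints saved after this change round-trip under the new class path. A minimal sketch of that round trip, assuming a CPU linear layer whose in-features are divisible by the group size:

import tempfile

import torch

from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import quantize_

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)  # CPU module
quantize_(linear, Int4WeightOnlyOpaqueTensorConfig(group_size=128))
# the weight subclass now reports the prototype module path
assert "torchao.prototype.int4_opaque_tensor.Int4OpaqueTensor" in str(type(linear.weight))

with tempfile.NamedTemporaryFile() as f:
    torch.save(linear.state_dict(), f)
    f.seek(0)
    # loads even under weights_only, since Int4OpaqueTensor is a safe global
    state_dict = torch.load(f)
    assert "Int4OpaqueTensor" in str(type(state_dict["weight"]))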

torchao/prototype/awq/example.py

Lines changed: 3 additions & 6 deletions

@@ -17,6 +17,7 @@
 from torchao.prototype.awq import (
     AWQConfig,
 )
+from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
 from torchao.quantization import Int4WeightOnlyConfig, quantize_


@@ -259,9 +260,7 @@ def quantize_and_eval(
             group_size=group_size, int4_packing_format="plain_int32"
         )
     elif device == "cpu":
-        base_config = Int4WeightOnlyConfig(
-            group_size=group_size, int4_packing_format="opaque"
-        )
+        base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
     else:
         assert False, "Unsupported device: {}".format(device)
     print(f"running {quant} prepare and calibrate")
@@ -301,9 +300,7 @@ def quantize_and_eval(
     if device == "cuda":
         base_config = Int4WeightOnlyConfig(group_size=group_size)
     elif device == "cpu":
-        base_config = Int4WeightOnlyConfig(
-            group_size=group_size, int4_packing_format="opaque"
-        )
+        base_config = Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
     else:
         assert False, "Unsupported device: {}".format(device)
     quantize_(model, base_config)
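Condensed, the example's device dispatch after this change looks like the following sketch; make_base_config is a hypothetical helper summarizing the branches, not part of the diff:

def make_base_config(device: str, group_size: int):
    # hypothetical helper condensing the per-device branches in example.py
    if device == "cuda":
        return Int4WeightOnlyConfig(group_size=group_size)
    elif device == "xpu":
        return Int4WeightOnlyConfig(
            group_size=group_size, int4_packing_format="plain_int32"
        )
    elif device == "cpu":
        return Int4WeightOnlyOpaqueTensorConfig(group_size=group_size)
    else:
        raise ValueError(f"Unsupported device: {device}")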

torchao/prototype/float8_opaque_tensor/inference_workflow.py

Lines changed: 1 addition & 1 deletion

@@ -55,7 +55,7 @@ class Float8DynamicActivationFloat8WeightOpaqueTensorConfig(AOBaseConfig):

     def __post_init__(self):
         torch._C._log_api_usage_once(
-            "torchao.quantization.Float8DynamicActivationFloat8WeightConfig"
+            "torchao.prototype.float8_opaque_tensor.Float8DynamicActivationFloat8WeightOpaqueTensorConfig"
         )
         activation_granularity, weight_granularity = (
             Float8OpaqueTensor._normalize_and_check_granularity(self.granularity)
torchao/prototype/int4_opaque_tensor/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -0,0 +1,7 @@
+from .inference_workflow import Int4WeightOnlyOpaqueTensorConfig
+from .int4_opaque_tensor import Int4OpaqueTensor
+
+__all__ = [
+    "Int4OpaqueTensor",
+    "Int4WeightOnlyOpaqueTensorConfig",
+]
torchao/prototype/int4_opaque_tensor/inference_workflow.py

Lines changed: 88 additions & 0 deletions

@@ -0,0 +1,88 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass
+
+import torch
+
+import torchao
+from torchao.core.config import AOBaseConfig
+
+logger = logging.getLogger(__name__)
+import types
+
+from torchao.quantization.quant_api import _linear_extra_repr
+from torchao.quantization.quantize_.workflows import (
+    Int4ChooseQParamsAlgorithm,
+)
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
+
+from .int4_opaque_tensor import Int4OpaqueTensor
+
+
+@dataclass
+class Int4WeightOnlyOpaqueTensorConfig(AOBaseConfig):
+    """
+    Configuration for int4 weight only quantization, only groupwise quantization is supported right now.
+
+    Args:
+        `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [256, 128, 64, 32]
+        `int4_choose_qparams_algorithm`: variants of choose qparams algorithm to use for int4, currently support TINYGEMM ("tinygemm") and HQQ ("hqq")
+        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values
+    """
+
+    group_size: int = 128
+    int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
+        Int4ChooseQParamsAlgorithm.TINYGEMM
+    )
+    set_inductor_config: bool = True
+
+    def __post_init__(self):
+        torch._C._log_api_usage_once(
+            "torchao.prototype.int4_opaque_tensor.Int4WeightOnlyOpaqueTensorConfig"
+        )
+
+
+def _int4_weight_only_opaque_tensor_quantize(weight, config):
+    group_size = config.group_size
+    int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
+
+    if weight.shape[-1] % group_size != 0:
+        logger.info(
+            f"Skipping quantizing weight with int4 weight only quantization because the shape of weight {weight.shape} is not compatible with group_size {group_size}"
+        )
+        return weight
+
+    block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size])
+
+    block_size = list(block_size)
+
+    new_weight = Int4OpaqueTensor.from_hp(
+        weight,
+        block_size,
+        int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
+    )
+    return new_weight
+
+
+@register_quantize_module_handler(Int4WeightOnlyOpaqueTensorConfig)
+def _int4_weight_only_transform(
+    module: torch.nn.Module, config: Int4WeightOnlyOpaqueTensorConfig
+) -> torch.nn.Module:
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
+
+    assert hasattr(module, "weight"), (
+        "applying int4 weight only quant requires module to have weight attribute"
+        + " but {module} does not have one"
+    )
+    new_weight = _int4_weight_only_opaque_tensor_quantize(module.weight, config)
+    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    module.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return module
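Putting the new workflow together, a minimal usage sketch (CPU-only; the algorithm strings mirror the test above):

import torch

from torchao.prototype.int4_opaque_tensor import Int4WeightOnlyOpaqueTensorConfig
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16))
# tinygemm is the default qparams algorithm; "hqq" is the other supported choice
quantize_(model, Int4WeightOnlyOpaqueTensorConfig(group_size=128))
# weights whose last dimension is not divisible by group_size are skipped:
# the handler logs a message and leaves the original weight in place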

torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py renamed to torchao/prototype/int4_opaque_tensor/int4_opaque_tensor.py

Lines changed: 4 additions & 3 deletions

@@ -16,13 +16,14 @@
     _choose_qparams_and_quantize_affine_hqq,
     _quantize_affine_tinygemm,
 )
+from torchao.quantization.quantize_.workflows import (
+    Int4ChooseQParamsAlgorithm,
+)
 from torchao.quantization.utils import pack_tinygemm_scales_and_zeros
 from torchao.utils import (
     TorchAOBaseTensor,
 )

-from .int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
-
 __all__ = [
     "Int4OpaqueTensor",
 ]
@@ -241,7 +242,7 @@ def _(func, types, args, kwargs):
     return y.to(orig_dtype)


-Int4OpaqueTensor.__module__ = "torchao.quantization"
+Int4OpaqueTensor.__module__ = "torchao.prototype.int4_opaque_tensor"

 # Allow a model with Int4OpaqueTensor weights to be loaded with `weights_only=True`
 torch.serialization.add_safe_globals([Int4OpaqueTensor])
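Together, the `__module__` reassignment and the safe-globals registration keep repr paths and `weights_only` loading consistent with the new package. A small sketch ("ckpt.pt" is a placeholder path, not from the diff):

import torch

from torchao.prototype.int4_opaque_tensor import Int4OpaqueTensor

# the class now reports the prototype module path
assert Int4OpaqueTensor.__module__ == "torchao.prototype.int4_opaque_tensor"
# checkpoints containing Int4OpaqueTensor weights load under the default
# weights_only=True because the class is registered via add_safe_globals
state_dict = torch.load("ckpt.pt", weights_only=True)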

torchao/quantization/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -94,7 +94,6 @@
 from .quantize_.workflows import (
     Float8Tensor,
     Int4MarlinSparseTensor,
-    Int4OpaqueTensor,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
@@ -173,7 +172,6 @@
     "IntxUnpackedToInt8Tensor",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
-    "Int4OpaqueTensor",
     # smooth quant - subject to change
     "get_scale",
     "SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 2 additions & 13 deletions

@@ -77,7 +77,6 @@
     Float8Tensor,
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
-    Int4OpaqueTensor,
     Int4PackingFormat,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
@@ -1163,12 +1162,9 @@ def _int4_weight_only_quantize_tensor(weight, config):
     block_size = list(block_size)

     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert int4_packing_format in [
-            Int4PackingFormat.TILE_PACKED_TO_4D,
-            Int4PackingFormat.OPAQUE,
-        ], (
+        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
             f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, "
-            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D and Int4PackingFormat.OPAQUE currently"
+            f"it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D currently"
         )

     if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
@@ -1196,13 +1192,6 @@ def _int4_weight_only_quantize_tensor(weight, config):
             block_size,
         )
         return new_weight
-    elif int4_packing_format == Int4PackingFormat.OPAQUE:
-        new_weight = Int4OpaqueTensor.from_hp(
-            weight,
-            block_size,
-            int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
-        )
-        return new_weight
     elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
         new_weight = Int4TilePackedTo4dTensor.from_hp(
             weight,
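For context on the `block_size` these branches share: groupwise quantization uses a block of ones in every dimension except the last, which gets the group size. A small worked example:

# for a 2-D weight of shape (out_features, in_features) = (512, 256)
# and group_size = 128, each scale/zero-point covers 128 contiguous
# elements along the last dimension
weight_shape = (512, 256)
group_size = 128
block_size = [1] * (len(weight_shape) - 1) + [group_size]
assert block_size == [1, 128]  # 256 / 128 = 2 groups per row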

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 0 additions & 4 deletions

@@ -6,9 +6,6 @@
 from .int4.int4_marlin_sparse_tensor import (
     Int4MarlinSparseTensor,
 )
-from .int4.int4_opaque_tensor import (
-    Int4OpaqueTensor,
-)
 from .int4.int4_packing_format import Int4PackingFormat
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
@@ -39,7 +36,6 @@
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
-    "Int4OpaqueTensor",
     "Int4ChooseQParamsAlgorithm",
     "Int4PackingFormat",
     "IntxChooseQParamsAlgorithm",
