Skip to content

Commit 10ad5bd

Browse files
authored
[Reland] Remove config functions like int4_weight_only (#3145) (#3308)
Remove config functions like `int4_weight_only` (#3145) (#3308)

**Summary:** As a follow-up to #2994, this commit removes all quantization functions that were used as configs. These functions were deprecated in 0.14.0 and will be removed in the next release, 0.15.0.

**Test Plan:** CI

Differential Revision: D86530816
Pulled By: andrewor14
1 parent 31192f2 commit 10ad5bd

File tree

5 files changed

+21
-151
lines changed

5 files changed

+21
-151
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,7 @@ Our framework makes it straightforward to add tensor parallel support to your cu
270270
271271
We've added support for authoring and releasing [custom ops](./torchao/csrc/) that do not graph break with `torch.compile()`. We have a few examples you can follow
272272
273-
1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, fpx_weight_only(3, 2))`
273+
1. [fp6](torchao/dtypes/floatx/README.md) for 2x faster inference over fp16 with an easy to use API `quantize_(model, FPXWeightOnlyConfig(3, 2))`
274274
2. [2:4 Sparse Marlin GEMM](https://github.com/pytorch/ao/pull/733) 2x speedups for FP16xINT4 kernels even at batch sizes up to 256
275275
3. [int4 tinygemm unpacker](https://github.com/pytorch/ao/pull/415) which makes it easier to switch quantized backends for inference
276276

test/quantization/test_quant_api.py

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -801,38 +801,28 @@ def test_int4wo_cuda_serialization(self):
801801

802802
def test_config_deprecation(self):
803803
"""
804-
Test that old config functions like `int4_weight_only` trigger deprecation warnings.
804+
Test that old config functions like `Int8DynamicActivationInt4WeightConfig` trigger deprecation warnings.
805805
"""
806806
from torchao.quantization import (
807-
float8_dynamic_activation_float8_weight,
808-
float8_static_activation_float8_weight,
809-
float8_weight_only,
810-
fpx_weight_only,
811-
gemlite_uintx_weight_only,
812-
int4_dynamic_activation_int4_weight,
813-
int4_weight_only,
814-
int8_dynamic_activation_int4_weight,
815-
int8_dynamic_activation_int8_weight,
816-
int8_weight_only,
817-
uintx_weight_only,
807+
Float8StaticActivationFloat8WeightConfig,
808+
FPXWeightOnlyConfig,
809+
GemliteUIntXWeightOnlyConfig,
810+
Int4DynamicActivationInt4WeightConfig,
811+
Int8DynamicActivationInt4WeightConfig,
812+
UIntXWeightOnlyConfig,
818813
)
819814

820815
# Reset deprecation warning state, otherwise we won't log warnings here
821816
warnings.resetwarnings()
822817

823818
# Map from deprecated API to the args needed to instantiate it
824819
deprecated_apis_to_args = {
825-
float8_dynamic_activation_float8_weight: (),
826-
float8_static_activation_float8_weight: (torch.randn(3)),
827-
float8_weight_only: (),
828-
fpx_weight_only: (3, 2),
829-
gemlite_uintx_weight_only: (),
830-
int4_dynamic_activation_int4_weight: (),
831-
int4_weight_only: (),
832-
int8_dynamic_activation_int4_weight: (),
833-
int8_dynamic_activation_int8_weight: (),
834-
int8_weight_only: (),
835-
uintx_weight_only: (torch.uint4,),
820+
Float8StaticActivationFloat8WeightConfig: (torch.randn(3),),
821+
FPXWeightOnlyConfig: (3, 2),
822+
GemliteUIntXWeightOnlyConfig: (),
823+
Int4DynamicActivationInt4WeightConfig: (),
824+
Int8DynamicActivationInt4WeightConfig: (),
825+
UIntXWeightOnlyConfig: (torch.uint4,),
836826
}
837827

838828
# Call each deprecated API twice
@@ -841,19 +831,16 @@ def test_config_deprecation(self):
841831
cls(*args)
842832
cls(*args)
843833

844-
# Each call should have at least one warning.
845-
# Some of them can have two warnings - one for deprecation,
846-
# one for moving to prototype
847-
# 1 warning - just deprecation
848-
# 2 warnings - deprecation and prototype warnings
849-
self.assertTrue(len(_warnings) in (1, 2))
834+
self.assertTrue(len(_warnings) == 1)
850835
found_deprecated = False
851836
for w in _warnings:
852-
if "is deprecated and will be removed in a future release" in str(
837+
if "will be moving to prototype in a future release" in str(
853838
w.message
854839
):
855840
found_deprecated = True
856-
self.assertTrue(found_deprecated)
841+
self.assertTrue(
842+
found_deprecated, f"did not find deprecated warning for {cls}"
843+
)
857844

858845

859846
common_utils.instantiate_parametrized_tests(TestQuantFlow)

torchao/quantization/__init__.py

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -66,22 +66,10 @@
6666
PlainLayout,
6767
TensorCoreTiledLayout,
6868
UIntXWeightOnlyConfig,
69-
float8_dynamic_activation_float8_weight,
70-
float8_static_activation_float8_weight,
71-
float8_weight_only,
72-
fpx_weight_only,
7369
fqn_matches_fqn_config,
74-
gemlite_uintx_weight_only,
75-
int4_dynamic_activation_int4_weight,
76-
int4_weight_only,
77-
int8_dynamic_activation_int4_weight,
78-
int8_dynamic_activation_int8_semi_sparse_weight,
79-
int8_dynamic_activation_int8_weight,
80-
int8_weight_only,
8170
intx_quantization_aware_training,
8271
quantize_,
8372
swap_conv2d_1x1_to_linear,
84-
uintx_weight_only,
8573
)
8674
from .quant_primitives import (
8775
MappingType,
@@ -132,20 +120,8 @@
132120
"ALL_AUTOQUANT_CLASS_LIST",
133121
# top level API - manual
134122
"quantize_",
135-
"int4_dynamic_activation_int4_weight",
136-
"int8_dynamic_activation_int4_weight",
137-
"int8_dynamic_activation_int8_weight",
138-
"int8_dynamic_activation_int8_semi_sparse_weight",
139-
"int4_weight_only",
140-
"int8_weight_only",
141123
"intx_quantization_aware_training",
142-
"float8_weight_only",
143-
"float8_dynamic_activation_float8_weight",
144-
"float8_static_activation_float8_weight",
145-
"uintx_weight_only",
146-
"fpx_weight_only",
147124
"fqn_matches_fqn_config",
148-
"gemlite_uintx_weight_only",
149125
"swap_conv2d_1x1_to_linear",
150126
"Int4DynamicActivationInt4WeightConfig",
151127
"Int8DynamicActivationInt4WeightConfig",

torchao/quantization/quant_api.py

Lines changed: 1 addition & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,6 @@
9999
to_weight_tensor_with_linear_activation_quantization_metadata,
100100
)
101101
from torchao.utils import (
102-
_ConfigDeprecationWrapper,
103102
is_MI300,
104103
is_sm_at_least_89,
105104
is_sm_at_least_90,
@@ -148,18 +147,7 @@
148147
"autoquant",
149148
"_get_subclass_inserter",
150149
"quantize_",
151-
"int8_dynamic_activation_int4_weight",
152-
"int8_dynamic_activation_int8_weight",
153-
"int8_dynamic_activation_int8_semi_sparse_weight",
154-
"int4_weight_only",
155-
"int8_weight_only",
156150
"intx_quantization_aware_training",
157-
"float8_weight_only",
158-
"uintx_weight_only",
159-
"fpx_weight_only",
160-
"gemlite_uintx_weight_only",
161-
"float8_dynamic_activation_float8_weight",
162-
"float8_static_activation_float8_weight",
163151
"Int8DynActInt4WeightQuantizer",
164152
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
165153
"ModuleFqnToConfig",
@@ -479,7 +467,7 @@ def quantize_(
479467
# Int8DynamicActivationInt8WeightConfig (optimized with int8 mm op and torch.compile)
480468
# Int4WeightOnlyConfig (optimized with int4 tinygemm kernel and torch.compile)
481469
# Int8WeightOnlyConfig (optimized with int8 mm op and torch.compile
482-
from torchao.quantization.quant_api import int4_weight_only
470+
from torchao.quantization.quant_api import Int4WeightOnlyConfig
483471
484472
m = nn.Sequential(nn.Linear(32, 1024), nn.Linear(1024, 32))
485473
quantize_(m, Int4WeightOnlyConfig(group_size=32, version=1))
@@ -611,12 +599,6 @@ def __post_init__(self):
611599
)
612600

613601

614-
# for BC
615-
int8_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
616-
"int8_dynamic_activation_int4_weight", Int8DynamicActivationInt4WeightConfig
617-
)
618-
619-
620602
@register_quantize_module_handler(Int8DynamicActivationInt4WeightConfig)
621603
def _int8_dynamic_activation_int4_weight_transform(
622604
module: torch.nn.Module,
@@ -985,12 +967,6 @@ def __post_init__(self):
985967
)
986968

987969

988-
# for bc
989-
int4_dynamic_activation_int4_weight = _ConfigDeprecationWrapper(
990-
"int4_dynamic_activation_int4_weight", Int4DynamicActivationInt4WeightConfig
991-
)
992-
993-
994970
@register_quantize_module_handler(Int4DynamicActivationInt4WeightConfig)
995971
def _int4_dynamic_activation_int4_weight_transform(
996972
module: torch.nn.Module, config: Int4DynamicActivationInt4WeightConfig
@@ -1051,12 +1027,6 @@ def __post_init__(self):
10511027
)
10521028

10531029

1054-
# for BC
1055-
gemlite_uintx_weight_only = _ConfigDeprecationWrapper(
1056-
"gemlite_uintx_weight_only", GemliteUIntXWeightOnlyConfig
1057-
)
1058-
1059-
10601030
@register_quantize_module_handler(GemliteUIntXWeightOnlyConfig)
10611031
def _gemlite_uintx_weight_only_transform(
10621032
module: torch.nn.Module, config: GemliteUIntXWeightOnlyConfig
@@ -1134,11 +1104,6 @@ def __post_init__(self):
11341104
torch._C._log_api_usage_once("torchao.quantization.Int4WeightOnlyConfig")
11351105

11361106

1137-
# for BC
1138-
# TODO maybe change other callsites
1139-
int4_weight_only = _ConfigDeprecationWrapper("int4_weight_only", Int4WeightOnlyConfig)
1140-
1141-
11421107
def _int4_weight_only_quantize_tensor(weight, config):
11431108
# TODO(future PR): perhaps move this logic to a different file, to keep the API
11441109
# file clean of implementation details
@@ -1348,10 +1313,6 @@ def __post_init__(self):
13481313
)
13491314

13501315

1351-
# for BC
1352-
int8_weight_only = _ConfigDeprecationWrapper("int8_weight_only", Int8WeightOnlyConfig)
1353-
1354-
13551316
def _int8_weight_only_quantize_tensor(weight, config):
13561317
if config.version == 1:
13571318
warnings.warn(
@@ -1537,12 +1498,6 @@ def __post_init__(self):
15371498
)
15381499

15391500

1540-
# for BC
1541-
int8_dynamic_activation_int8_weight = _ConfigDeprecationWrapper(
1542-
"int8_dynamic_activation_int8_weight", Int8DynamicActivationInt8WeightConfig
1543-
)
1544-
1545-
15461501
def _int8_dynamic_activation_int8_weight_quantize_tensor(weight, config):
15471502
if config.version == 1:
15481503
layout = config.layout
@@ -1756,12 +1711,6 @@ def __post_init__(self):
17561711
torch._C._log_api_usage_once("torchao.quantization.Float8WeightOnlyConfig")
17571712

17581713

1759-
# for BC
1760-
float8_weight_only = _ConfigDeprecationWrapper(
1761-
"float8_weight_only", Float8WeightOnlyConfig
1762-
)
1763-
1764-
17651714
def _float8_weight_only_quant_tensor(weight, config):
17661715
if config.version == 1:
17671716
warnings.warn(
@@ -1946,12 +1895,6 @@ def __post_init__(self):
19461895
self.mm_config = Float8MMConfig(use_fast_accum=default_use_fast_accum)
19471896

19481897

1949-
# for bc
1950-
float8_dynamic_activation_float8_weight = _ConfigDeprecationWrapper(
1951-
"float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig
1952-
)
1953-
1954-
19551898
def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config):
19561899
activation_dtype = config.activation_dtype
19571900
weight_dtype = config.weight_dtype
@@ -2160,12 +2103,6 @@ def __post_init__(self):
21602103
)
21612104

21622105

2163-
# for bc
2164-
float8_static_activation_float8_weight = _ConfigDeprecationWrapper(
2165-
"float8_static_activation_float8_weight", Float8StaticActivationFloat8WeightConfig
2166-
)
2167-
2168-
21692106
@register_quantize_module_handler(Float8StaticActivationFloat8WeightConfig)
21702107
def _float8_static_activation_float8_weight_transform(
21712108
module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig
@@ -2251,12 +2188,6 @@ def __post_init__(self):
22512188
)
22522189

22532190

2254-
# for BC
2255-
uintx_weight_only = _ConfigDeprecationWrapper(
2256-
"uintx_weight_only", UIntXWeightOnlyConfig
2257-
)
2258-
2259-
22602191
@register_quantize_module_handler(UIntXWeightOnlyConfig)
22612192
def _uintx_weight_only_transform(
22622193
module: torch.nn.Module, config: UIntXWeightOnlyConfig
@@ -2550,10 +2481,6 @@ def __post_init__(self):
25502481
)
25512482

25522483

2553-
# for BC
2554-
fpx_weight_only = _ConfigDeprecationWrapper("fpx_weight_only", FPXWeightOnlyConfig)
2555-
2556-
25572484
@register_quantize_module_handler(FPXWeightOnlyConfig)
25582485
def _fpx_weight_only_transform(
25592486
module: torch.nn.Module, config: FPXWeightOnlyConfig

torchao/utils.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,10 @@
88
import itertools
99
import re
1010
import time
11-
import warnings
1211
from functools import reduce
1312
from importlib.metadata import version
1413
from math import gcd
15-
from typing import Any, Callable, Optional, Type
14+
from typing import Any, Callable, Optional
1615

1716
import torch
1817
import torch.nn.utils.parametrize as parametrize
@@ -376,25 +375,6 @@ def torch_version_at_least(min_version):
376375
return parse_version(torch.__version__) >= parse_version(min_version)
377376

378377

379-
class _ConfigDeprecationWrapper:
380-
"""
381-
A deprecation wrapper that directs users from a deprecated "config function"
382-
(e.g. `int4_weight_only`) to the replacement config class.
383-
"""
384-
385-
def __init__(self, deprecated_name: str, config_cls: Type):
386-
self.deprecated_name = deprecated_name
387-
self.config_cls = config_cls
388-
389-
def __call__(self, *args, **kwargs):
390-
warnings.warn(
391-
f"`{self.deprecated_name}` is deprecated and will be removed in a future release. "
392-
f"Please use `{self.config_cls.__name__}` instead. Example usage:\n"
393-
f" quantize_(model, {self.config_cls.__name__}(...))"
394-
)
395-
return self.config_cls(*args, **kwargs)
396-
397-
398378
"""
399379
Helper function for implementing aten op or torch function dispatch
400380
and dispatching to these implementations.

0 commit comments

Comments (0)