enable smoothquant for int8 static tensor #3468
base: main
Changes from all commits
```diff
@@ -15,8 +15,12 @@
 )
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
+from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
+    QuantizeTensorToInt8Kwargs,
+)
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
```
```diff
@@ -95,8 +99,18 @@ def _smooth_quant_transform(
     else:
         raise ValueError(f"Unexpected step: {step}")

+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        quant_kwargs = QuantizeTensorToInt8Kwargs(
+            granularity=base_config.granularity,
+            mapping_type=base_config.act_mapping_type,
+        )
+    else:
+        quant_kwargs = None
+
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
+        weight_quant_kwargs=quant_kwargs
+    )
     weight = observed_linear.weight * smoothing_factor

     # Create new linear layer
```
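For context on the smoothing step: SmoothQuant computes a per-input-channel factor that migrates quantization difficulty from activations into weights, then folds that factor into the weight (the `weight * smoothing_factor` line above). A hedged sketch of the standard formulation — `smoothing_factor` below is illustrative and not the observer's `calculate_qparams`:

```python
# Sketch of the SmoothQuant smoothing factor (illustrative, not the torchao
# implementation): s_j = max|X_j|^alpha / max|W_j|^(1-alpha) per input channel.
import torch

def smoothing_factor(act_absmax: torch.Tensor, weight: torch.Tensor,
                     alpha: float = 0.5) -> torch.Tensor:
    """act_absmax: per-input-channel max |activation|, shape [in_features].
    weight: [out_features, in_features]."""
    w_absmax = weight.abs().amax(dim=0)  # per-input-channel weight absmax
    return (act_absmax.clamp(min=1e-5).pow(alpha)
            / w_absmax.clamp(min=1e-5).pow(1.0 - alpha))

# Folding s into the weight (W * s) while dividing activations by s leaves
# the linear layer's output mathematically unchanged:
x = torch.randn(4, 8)
w = torch.randn(16, 8)
s = smoothing_factor(x.abs().amax(dim=0), w)
y_ref = x @ w.t()
y_smooth = (x / s) @ (w * s).t()
assert torch.allclose(y_ref, y_smooth, atol=1e-4)
```

The `alpha` hyperparameter balances how much difficulty moves to the weight side; 0.5 is the common default in the SmoothQuant paper.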
```diff
@@ -111,6 +125,9 @@ def _smooth_quant_transform(
     linear.bias = observed_linear.bias

     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.scale = activation_scale
+
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
```
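Storing `activation_scale` on the config is what makes the activation quantization *static*: the scale is fixed from calibration statistics and reused at inference rather than recomputed per batch. A minimal sketch of symmetric static int8 quantization under that scheme — `int8_static_scale` and `quantize_static` are illustrative names, not torchao APIs:

```python
# Sketch of symmetric static int8 activation quantization (assumed scheme):
# one scale derived from the calibrated absmax, reused for every batch.
import torch

def int8_static_scale(observed_absmax: torch.Tensor) -> torch.Tensor:
    # Symmetric mapping: [-absmax, absmax] onto the int8 range [-127, 127]
    return observed_absmax / 127.0

def quantize_static(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Round-to-nearest, then clamp to the int8 representable range
    return torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)

scale = int8_static_scale(torch.tensor(6.35))  # e.g. calibrated absmax = 6.35
xq = quantize_static(torch.tensor([1.0, -2.5, 6.35]), scale)
```

Values outside the calibrated range would saturate at ±127/−128, which is the usual trade-off of static versus dynamic activation quantization.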
```diff
@@ -120,6 +137,7 @@ def _smooth_quant_transform(
     qw = to_weight_tensor_with_linear_activation_scale_metadata(
         qw, smoothing_factor.to(qw.dtype)
     )
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
```

Contributor (on `to_weight_tensor_with_linear_activation_scale_metadata`): we should not be using this; please check AWQ for how this should be implemented in the new stack: `ao/torchao/prototype/awq/api.py`, lines 108 to 113 in 08e5e20.
```diff
@@ -1658,7 +1658,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
         version (int): the version of the config
     """

-    scale: torch.Tensor
+    scale: torch.Tensor = None
     granularity: Granularity = PerRow()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     set_inductor_config: bool = True
```

Contributor (on `scale: torch.Tensor = None`): nit: `Optional[torch.Tensor]`
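The `Optional[torch.Tensor]` nit above is about type correctness: annotating a field as `torch.Tensor` while defaulting it to `None` makes the annotation wrong for the default value. A small sketch of the pattern the reviewer suggests, using a hypothetical `ExampleConfig` rather than the real config class:

```python
# Hypothetical config illustrating the Optional-field pattern: the scale is
# unknown until the transform fills it in from calibration, so it starts None.
from dataclasses import dataclass
from typing import Optional
import torch

@dataclass
class ExampleConfig:
    scale: Optional[torch.Tensor] = None  # populated later by the transform
    set_inductor_config: bool = True

cfg = ExampleConfig()
assert cfg.scale is None          # before calibration
cfg.scale = torch.tensor(0.05)    # transform assigns the calibrated scale
```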
Review comment: yeah, I think it's fine.

Side note: we may need a separate API/flow for plain static quant without SmoothQuant, if needed.