Changes from all commits (37 commits):
48cdb61  Int8Tensor migration (jcaip, Dec 1, 2025)
0b73aed  ruff fixes (jcaip, Dec 1, 2025)
1e49945  add init (jcaip, Dec 1, 2025)
669b6ee  fix ruff again (jcaip, Dec 1, 2025)
9071526  update (jcaip, Dec 1, 2025)
1539e0f  wip (jcaip, Dec 2, 2025)
d9a2b1b  Merge branch 'main' into jcaip/int8-tensor (jcaip, Dec 3, 2025)
673f228  undo update tests (jcaip, Dec 3, 2025)
739fd64  fix ruff (jcaip, Dec 3, 2025)
750db1a  fix varname (jcaip, Dec 3, 2025)
9410488  fix typing (jcaip, Dec 3, 2025)
45a3a76  add tests (jcaip, Dec 3, 2025)
4e2f09c  fix dtype (jcaip, Dec 3, 2025)
dd80cca  fix ci (jcaip, Dec 3, 2025)
7f73062  address granularity cr (jcaip, Dec 4, 2025)
ac6a2b6  update _choose_quant_func_and_quantize_tensor (jcaip, Dec 4, 2025)
f28df4a  make block size required attribute (jcaip, Dec 4, 2025)
328585e  made dtype required as well (jcaip, Dec 4, 2025)
ce4d568  address nits (jcaip, Dec 4, 2025)
a665d45  skip per tensor weight only test for now (jcaip, Dec 4, 2025)
0338016  add static quant (jcaip, Dec 3, 2025)
ee39691  add static quant (jcaip, Dec 4, 2025)
9eb0aa9  update (jcaip, Dec 5, 2025)
d4a1514  static quant working eager + compile (jcaip, Dec 6, 2025)
3cdea56  remove file (jcaip, Dec 6, 2025)
fa9022d  added asserts (jcaip, Dec 6, 2025)
8ce5cde  undo smoothquant change (jcaip, Dec 6, 2025)
6f64121  fix return (jcaip, Dec 6, 2025)
8ae921d  Merge branch 'main' into jcaip/static-quant-rebased (jcaip, Dec 7, 2025)
5b9e243  got smoothquant + int8 static working (jcaip, Dec 8, 2025)
7a0e38f  generalized smoothquat code (jcaip, Dec 8, 2025)
3d18edf  free tests (jcaip, Dec 8, 2025)
9e07f8b  fix static scale check (jcaip, Dec 8, 2025)
4274e02  update (jcaip, Dec 8, 2025)
b5309eb  address cr feedback (jcaip, Dec 9, 2025)
a732fee  Merge branch 'jcaip/static-quant-rebased' into jcaip/enable-smoothquant (jcaip, Dec 9, 2025)
0c23589  Merge branch 'main' into jcaip/enable-smoothquant (jcaip, Dec 9, 2025)
17 changes: 15 additions & 2 deletions test/prototype/test_smoothquant.py
@@ -15,11 +15,13 @@
 )
 from torchao.prototype.smoothquant.core import SmoothQuantStep
 from torchao.quantization import quantize_
+from torchao.quantization.granularity import PerRow, PerTensor
 from torchao.quantization.linear_activation_scale import (
     WeightTensorWithLinearActivationScaleMetadata,
 )
 from torchao.quantization.quant_api import (
     Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
 )
 from torchao.quantization.utils import (
     compute_error as SQNR,
@@ -83,7 +85,10 @@ def setUpClass(cls):
     @common_utils.parametrize(
         "base_config",
         [
-            Int8DynamicActivationInt8WeightConfig(),
+            Int8DynamicActivationInt8WeightConfig(version=2),
+            # TODO: not sure if we should allow not passing scales as part of static config?
Contributor comment:
yeah I think it's fine

side note: we may need a separate API/flow for plain static quant without Smoothquant if needed.

+            Int8StaticActivationInt8WeightConfig(granularity=PerRow()),
+            Int8StaticActivationInt8WeightConfig(granularity=PerTensor()),
             # Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
             # TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
         ],
@@ -101,7 +106,15 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):

         # Step 1. Basic quantization
         basic_model = deepcopy(m)
-        quantize_(basic_model, base_config)
+        if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+            quantize_(
+                basic_model,
+                Int8DynamicActivationInt8WeightConfig(
+                    version=2, granularity=base_config.granularity
+                ),
+            )
+        else:
+            quantize_(basic_model, base_config)
         out_basic = basic_model(*x)
         loss_base = torch.nn.functional.mse_loss(out_basic, out_ref).item()
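For orientation, the two-step flow these tests exercise looks roughly like the sketch below. The exact `SmoothQuantConfig` constructor signature is an assumption based on the prototype API, not something this diff shows.

# Sketch only: SmoothQuantConfig's constructor arguments are assumed.
import torch
from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerRow
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

model = torch.nn.Sequential(torch.nn.Linear(512, 512)).eval()
base_config = Int8StaticActivationInt8WeightConfig(granularity=PerRow())

# PREPARE: swap in observed linears and calibrate on exemplar data
quantize_(model, SmoothQuantConfig(base_config, step=SmoothQuantStep.PREPARE))
for _ in range(8):
    model(torch.randn(2, 512))

# CONVERT: fold smoothing factors into the weights; the transform writes the
# calibrated activation scale into base_config.scale before quantizing
quantize_(model, SmoothQuantConfig(base_config, step=SmoothQuantStep.CONVERT))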
20 changes: 19 additions & 1 deletion torchao/prototype/smoothquant/api.py
@@ -15,8 +15,12 @@
 )
 from torchao.quantization.quant_api import (
     _QUANTIZE_CONFIG_HANDLER,
+    Int8StaticActivationInt8WeightConfig,
     _linear_extra_repr,
 )
+from torchao.quantization.quantize_.workflows.int8.int8_tensor import (
+    QuantizeTensorToInt8Kwargs,
+)
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -95,8 +99,18 @@ def _smooth_quant_transform(
     else:
         raise ValueError(f"Unexpected step: {step}")
 
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        quant_kwargs = QuantizeTensorToInt8Kwargs(
+            granularity=base_config.granularity,
+            mapping_type=base_config.act_mapping_type,
+        )
+    else:
+        quant_kwargs = None
+
     # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    smoothing_factor, activation_scale = observed_linear.obs.calculate_qparams(
+        weight_quant_kwargs=quant_kwargs
+    )
     weight = observed_linear.weight * smoothing_factor
 
     # Create new linear layer
@@ -111,6 +125,9 @@
     linear.bias = observed_linear.bias
 
     # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.scale = activation_scale
+
     base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
     dummy_mod = DummyModule(weight)
     quant_mod = base_config_handler(dummy_mod, base_config)
@@ -120,6 +137,7 @@
     qw = to_weight_tensor_with_linear_activation_scale_metadata(
Contributor comment:
we should not be using this, please check awq on how this should be implemented in the new stack:

assert isinstance(qw, SupportsActivationPreScaling), (
"weight must support activation scaling through implementing `SupportsActivationPreScaling`"
)
# since we want to do `act` * `act_pre_scale` during runtime for speed, we'll save the
# reciprocal of the `equalization_scale`
qw.act_pre_scale = 1.0 / equalization_scale
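Adapting that suggestion to this transform might look roughly like the sketch below; it assumes the quantized int8 weight implements `SupportsActivationPreScaling` (not established by this PR) and that the mechanism matches the AWQ flow.

# Hypothetical tail of _smooth_quant_transform using act_pre_scale instead of
# the metadata wrapper; SupportsActivationPreScaling support on the int8
# weight tensor is an assumption.
qw = quant_mod.weight
assert isinstance(qw, SupportsActivationPreScaling), (
    "weight must support activation scaling through implementing "
    "`SupportsActivationPreScaling`"
)
# the runtime computes act * act_pre_scale, so store the reciprocal of the
# smoothing factor (the analogue of equalization_scale here)
qw.act_pre_scale = 1.0 / smoothing_factor
linear.weight = torch.nn.Parameter(qw, requires_grad=False)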

         qw, smoothing_factor.to(qw.dtype)
     )
 
     linear.weight = torch.nn.Parameter(qw, requires_grad=False)
     linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
25 changes: 19 additions & 6 deletions torchao/prototype/smoothquant/core.py
@@ -9,6 +9,10 @@
 import torch
 import torch.nn.functional as F
 
+from torchao.quantization.quantize_.common import (
+    _choose_quant_func_and_quantize_tensor,
+)
+
 
 class SmoothQuantStep(str, Enum):
     PREPARE = "prepare"
@@ -41,13 +45,14 @@ def forward(self, input: torch.Tensor):
             self.inputs.append(input.to("cpu"))
         return input
 
-    def calculate_qparams(self):
+    def calculate_qparams(self, weight_quant_kwargs=None):
         assert self.inputs and len(self.inputs) > 0, (
             "calibrate observer first by running model on exemplar data"
        )
         inputs = [inp.to(self.device) for inp in self.inputs]
         acc = torch.cat(inputs, dim=0)
         # Reshape if needed: [batch, seq, features] -> [batch*seq, features]
+        example_input_for_quantization = acc
         if acc.ndim > 2:
             acc = acc.view(-1, acc.shape[-1])
@@ -57,12 +62,20 @@ def calculate_qparams(self):

         # Calculate smoothing factor
         if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            smoothing_factor = torch.ones_like(x_abs_max)
+        else:
+            eps = torch.finfo(torch.float32).eps
+            smoothing_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
+                w_abs_max + eps, 1 - self.alpha
+            )
 
-        eps = torch.finfo(torch.float32).eps
-        return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
-            w_abs_max + eps, 1 - self.alpha
-        )
+        if weight_quant_kwargs is not None:
+            quant_smooth_activation = _choose_quant_func_and_quantize_tensor(
+                example_input_for_quantization / smoothing_factor, weight_quant_kwargs
+            )
+            return smoothing_factor, quant_smooth_activation.scale
+        else:
+            return smoothing_factor, None


 class SmoothQuantObservedLinear(torch.nn.Linear):
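In equation form, the factor computed by `calculate_qparams` above is the standard SmoothQuant per-input-channel smoothing factor

$$ s_j = \frac{\left(\max_i |X_{ij}| + \epsilon\right)^{\alpha}}{\left(\max_k |W_{kj}| + \epsilon\right)^{1-\alpha}} $$

with $\epsilon$ the float32 machine epsilon. Activations are smoothed as $X / s$ and weights as $W \cdot s$ (see `weight = observed_linear.weight * smoothing_factor` in api.py); when alpha is None the factor collapses to all ones. The static activation scale returned alongside it is calibrated on the smoothed input $X / s$.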
2 changes: 1 addition & 1 deletion torchao/quantization/quant_api.py
@@ -1658,7 +1658,7 @@ class Int8StaticActivationInt8WeightConfig(AOBaseConfig):
         version (int): the version of the config
     """
 
-    scale: torch.Tensor
+    scale: torch.Tensor = None
Contributor comment:
nit: Optional[torch.Tensor]
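Applying the nit, the field declaration would read:

    scale: Optional[torch.Tensor] = None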

     granularity: Granularity = PerRow()
     act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
     set_inductor_config: bool = True
5 changes: 3 additions & 2 deletions torchao/quantization/quantize_/workflows/int8/int8_tensor.py
@@ -3,7 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
-
+import math
 from dataclasses import dataclass
 from typing import List, Optional
@@ -199,12 +199,13 @@ def _(func, types, args, kwargs):
     output_dtype = activation_tensor.dtype
 
     if weight_tensor.act_quant_kwargs is not None:
+        # for int8 dynamic + static quantization path
         activation_tensor = _choose_quant_func_and_quantize_tensor(
             activation_tensor,
             weight_tensor.act_quant_kwargs,
+            scale=weight_tensor.act_scale,
         )
-        # Dynamic activation quantization path
 
     # 1. do the matrix form of dot(X_i, W_j)
     #
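The `scale=weight_tensor.act_scale` argument is what folds static activation quantization into the existing dynamic path: a calibrated scale short-circuits the on-the-fly computation. A minimal sketch of that dispatch, using a hypothetical `quantize_activation` helper rather than torchao's actual `_choose_quant_func_and_quantize_tensor`:

# Hypothetical helper illustrating dynamic-vs-static dispatch on `scale`.
from typing import Optional, Tuple

import torch

def quantize_activation(
    x: torch.Tensor, scale: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    if scale is None:
        # dynamic path: derive a symmetric per-row int8 scale from this batch
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12) / 127.0
    # static path reuses the scale calibrated during SmoothQuant PREPARE
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale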