21 changes: 21 additions & 0 deletions test/integration/test_integration.py
@@ -43,11 +43,13 @@
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    _replace_with_custom_fn_if_matches_filter,
    quantize_,
)
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    dequantize_affine,
)
from torchao.quantization.smoothquant import (
@@ -1004,6 +1006,25 @@ def test_dynamic_quant(self):
        sqnr = compute_error(y_ref, y_test)
        self.assertGreater(sqnr, 40.0)

class TestStaticQuant(unittest.TestCase):
    def test_static_quant(self):
        M, K, N = 8, 16, 8
        x = torch.randn(M, K)
        m = nn.Sequential(nn.Linear(K, N))
        block_size = [M, K]  # per-tensor quantization
        # Precompute a static activation scale from the calibration input.
        scale, _ = choose_qparams_affine(
            x,
            mapping_type=MappingType.SYMMETRIC,
            block_size=block_size,
            target_dtype=torch.int8,
        )

        y_ref = m(x)
        quantize_(m, Int8StaticActivationInt8WeightConfig(act_quant_scale=scale))
        y_test = m(x)

        sqnr = compute_error(y_ref, y_test)
        self.assertGreater(sqnr, 40.0)

class TestWeightOnlyInt8Quant(unittest.TestCase):
    def test_weight_only_quant(self):
4 changes: 4 additions & 0 deletions test/prototype/test_smoothquant.py
@@ -20,6 +20,7 @@
)
from torchao.quantization.quant_api import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
)
from torchao.quantization.utils import (
    compute_error as SQNR,
@@ -84,6 +85,7 @@ def setUpClass(cls):
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8StaticActivationInt8WeightConfig(),
# Note: float8_static_activation_float8_weight is broken after recent PyTorch update.
# TODO(#1639): Fix for supporting more API in torchao/quantization/quant_api.py
],
@@ -139,6 +141,7 @@ def test_smoothquant_accuracy(self, alpha, base_config, device, input_dtype):
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8StaticActivationInt8WeightConfig(),
# TODO: Check more quantization APIs
],
)
@@ -178,6 +181,7 @@ def test_observer_insertion(self, base_config):
"base_config",
[
Int8DynamicActivationInt8WeightConfig(),
Int8StaticActivationInt8WeightConfig(),
# TODO: Check more quantization APIs
],
)
@@ -14,6 +14,7 @@

from torchao.quantization import (
    Int8DynamicActivationInt8WeightConfig,
    Int8StaticActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    quantize_,
)
@@ -66,6 +67,7 @@ def test_creation_and_attributes(self, config):
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8StaticActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@@ -108,6 +110,7 @@ def test_int8_linear_variants(
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8StaticActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
@@ -173,6 +176,7 @@ def test_index_select(self, config, granularity):
"config",
[
Int8DynamicActivationInt8WeightConfig(version=2),
Int8StaticActivationInt8WeightConfig(version=2),
Int8WeightOnlyConfig(version=2),
],
)
3 changes: 2 additions & 1 deletion torchao/kernel/intmm.py
@@ -127,7 +127,8 @@ def int_scaled_matmul(
    assert M == scales1.size(0) or scales1.numel() == 1
    assert 1 == scales1.size(1)
    assert scales1.is_contiguous()
-    scales1 = scales1.expand((M, N))
+    if scales1.device.type != "cpu":
+        # The CPU path below consumes the per-row (M, 1) scales as-is;
+        # other backends expect them broadcast to the (M, N) output shape.
+        scales1 = scales1.expand((M, N))
    assert scales1.dim() == 2

    if check_cpu_version(scales1.device):
1 change: 1 addition & 0 deletions torchao/prototype/smoothquant/README.md
@@ -50,6 +50,7 @@ for data in calibration_dataset:
quant_config.step = SmoothQuantStep.CONVERT
quantize_(model, quant_config)
```
For static activation quantization, use `Int8StaticActivationInt8WeightConfig` instead of `Int8DynamicActivationInt8WeightConfig`, as sketched below. Static quantization generally yields better throughput than dynamic quantization, at some cost in accuracy (higher perplexity).
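A minimal sketch of the static-activation flow, assuming the same `model` and `calibration_dataset` as in the snippet above (the imports mirror those used by `example.py`):

```python
from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8StaticActivationInt8WeightConfig

# Prepare: insert observers that record activation ranges during calibration.
quant_config = SmoothQuantConfig(
    base_config=Int8StaticActivationInt8WeightConfig(),
    step=SmoothQuantStep.PREPARE,
)
quantize_(model, quant_config)

# Calibrate: run exemplar data through the model so the observers can derive
# both the smoothing factors and a per-tensor activation scale.
for data in calibration_dataset:
    model(data)

# Convert: quantize weights and bake the static activation scale into the model.
quant_config.step = SmoothQuantStep.CONVERT
quantize_(model, quant_config)
```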

## Benchmarks

32 changes: 20 additions & 12 deletions torchao/prototype/smoothquant/api.py
@@ -15,6 +15,7 @@
)
from torchao.quantization.quant_api import (
    _QUANTIZE_CONFIG_HANDLER,
    Int8StaticActivationInt8WeightConfig,
    _linear_extra_repr,
)
from torchao.quantization.transform_module import (
@@ -96,21 +97,17 @@ def _smooth_quant_transform(
        raise ValueError(f"Unexpected step: {step}")

    # Compute smoothed weight parameters
-    smoothing_factor = observed_linear.obs.calculate_qparams()
+    act_quant_min, act_quant_max = None, None
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        act_quant_min, act_quant_max = -127, 127
+    smoothing_factor, act_quant_scale = observed_linear.obs.calculate_qparams(
+        act_quant_min, act_quant_max
+    )
    weight = observed_linear.weight * smoothing_factor

-    # Create new linear layer
-    with torch.device("meta"):
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-    linear.bias = observed_linear.bias
-
    # Quantize weights
+    if isinstance(base_config, Int8StaticActivationInt8WeightConfig):
+        base_config.act_quant_scale = act_quant_scale
    base_config_handler = _QUANTIZE_CONFIG_HANDLER[type(base_config)]
    dummy_mod = DummyModule(weight)
    quant_mod = base_config_handler(dummy_mod, base_config)
@@ -120,7 +117,18 @@
    qw = to_weight_tensor_with_linear_activation_scale_metadata(
        qw, smoothing_factor.to(qw.dtype)
    )
+
+    # Create new linear layer
+    with torch.device("meta"):
+        linear = torch.nn.Linear(
+            observed_linear.in_features,
+            observed_linear.out_features,
+            observed_linear.bias is not None,
+            device=observed_linear.weight.device,
+            dtype=observed_linear.weight.dtype,
+        )
    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
    linear.extra_repr = types.MethodType(_linear_extra_repr, linear)
+    linear.bias = observed_linear.bias

    return linear
20 changes: 14 additions & 6 deletions torchao/prototype/smoothquant/core.py
@@ -41,7 +41,7 @@ def forward(self, input: torch.Tensor):
        self.inputs.append(input.to("cpu"))
        return input

-    def calculate_qparams(self):
+    def calculate_qparams(self, act_quant_min=None, act_quant_max=None):
        assert self.inputs and len(self.inputs) > 0, (
            "calibrate observer first by running model on exemplar data"
        )
@@ -57,12 +57,20 @@ def calculate_qparams(self):

        # Calculate smoothing factor
        if self.alpha is None:
-            return torch.ones_like(x_abs_max)
+            smooth_factor = torch.ones_like(x_abs_max)
+        else:
+            eps = torch.finfo(torch.float32).eps
+            smooth_factor = torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
+                w_abs_max + eps, 1 - self.alpha
+            )

-        eps = torch.finfo(torch.float32).eps
-        return torch.pow(x_abs_max + eps, self.alpha) / torch.pow(
-            w_abs_max + eps, 1 - self.alpha
-        )
+        # Calculate per-tensor act_quant_scale: map the global max |x| onto half
+        # the symmetric range, i.e. amax / 127 for act_quant_min/max = -127/127
+        act_quant_scale = None
+        if act_quant_min is not None and act_quant_max is not None:
+            x_abs_max_t = acc.abs().max()
+            act_quant_scale = (
+                x_abs_max_t / ((act_quant_max - act_quant_min) / 2)
+            ).item()
+
+        return smooth_factor, act_quant_scale


class SmoothQuantObservedLinear(torch.nn.Linear):
37 changes: 34 additions & 3 deletions torchao/prototype/smoothquant/example.py
@@ -16,8 +16,10 @@
)
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.quantization import quantize_
-from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
-
+from torchao.quantization.quant_api import (
+    Int8DynamicActivationInt8WeightConfig,
+    Int8StaticActivationInt8WeightConfig,
+)

# TODO: Build benchmark within vLLM ecosystem with more quantization APIs
# See https://github.com/pytorch/ao/issues/2815 for more details
@@ -82,6 +84,8 @@ def quantize_and_eval(
    device: str,
    model_save_path: str,
    model_save_hf_hub_path: str,
    static_quant_act: bool,
    compile: bool,
):
    print(f"Loading model on {device}...")
    torch.manual_seed(34)
@@ -96,9 +100,14 @@

    # Step 1: Prepare - insert observers
    print("running SmoothQuant prepare and calibrate")
+    base_config = (
+        Int8StaticActivationInt8WeightConfig()
+        if static_quant_act
+        else Int8DynamicActivationInt8WeightConfig()
+    )
    t0 = time.time()
    quant_config = SmoothQuantConfig(
-        base_config=Int8DynamicActivationInt8WeightConfig(),
+        base_config=base_config,
        step=SmoothQuantStep.PREPARE,
        alpha=alpha,
    )
@@ -133,6 +142,8 @@ def quantize_and_eval(
print("pushing model to hub:", model_save_hf_hub_path)
model.push_to_hub(model_save_hf_hub_path, safe_serialization=False)
tokenizer.push_to_hub(model_save_hf_hub_path)
if compile:
model.forward = torch.compile(model.forward, dynamic=True)

print("Benchmarking SmoothQuant model...")
return benchmark(model, tokenizer, max_seq_length, tasks=tasks, device=device)
@@ -147,6 +158,8 @@
    device: str,
    model_save_path: str,
    model_save_hf_hub_path: str,
    static_quant_act: bool,
    compile: bool,
):
    """Compare perplexity and speed for benchmarking SmoothQuant"""

@@ -159,6 +172,8 @@
        .eval()
        .to(device)
    )
    if compile:
        model.forward = torch.compile(model.forward, dynamic=True)
    base_results = benchmark(
        model, tokenizer, max_seq_length, tasks=tasks, device=device
    )
@@ -172,6 +187,8 @@
        .to(device)
    )
    quantize_(w8a8_model, Int8DynamicActivationInt8WeightConfig())
    if compile:
        w8a8_model.forward = torch.compile(w8a8_model.forward, dynamic=True)
    w8a8_results = benchmark(
        w8a8_model, tokenizer, max_seq_length, tasks=tasks, device=device
    )
@@ -187,6 +204,8 @@
        device,
        model_save_path,
        model_save_hf_hub_path,
        static_quant_act,
        compile,
    )

    # Calculate changes and display results
@@ -289,6 +308,16 @@ def create_parser() -> argparse.ArgumentParser:
        default=None,
        help="Huggingface hub path to store the quantized model and tokenizer.",
    )
    parser.add_argument(
        "--static-quant-act",
        action="store_true",
        help="Use static activation quantization instead of dynamic activation quantization.",
    )
    parser.add_argument(
        "--compile",
        action="store_true",
        help="Use torch.compile to compile the model for potentially better performance.",
    )

    return parser

@@ -306,4 +335,6 @@
        args.device,
        args.model_save_path,
        args.model_save_hf_hub_path,
        args.static_quant_act,
        args.compile,
    )
1 change: 1 addition & 0 deletions torchao/quantization/__init__.py
@@ -59,6 +59,7 @@
    Int8DynamicActivationInt4WeightConfig,
    Int8DynamicActivationInt8WeightConfig,
    Int8DynamicActivationIntxWeightConfig,
    Int8StaticActivationInt8WeightConfig,
    Int8WeightOnlyConfig,
    IntxWeightOnlyConfig,
    ModuleFqnToConfig,