Skip to content

Commit ac6a2b6

Browse files
committed
update _choose_quant_func_and_quantize_tensor
1 parent 7f73062 commit ac6a2b6

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

torchao/quantization/quant_api.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1610,6 +1610,7 @@ def get_weight_block_size(x):
16101610
quantized_weight = Int8Tensor.from_hp(
16111611
weight,
16121612
granularity=weight_granularity,
1613+
mapping_type=MappingType.SYMMETRIC,
16131614
act_quant_kwargs=QuantizeTensorToInt8Kwargs(
16141615
granularity=act_granularity,
16151616
mapping_type=config.act_mapping_type,

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,9 @@ def _choose_quant_func_and_quantize_tensor(
3939
"""
4040
from torchao.quantization.quantize_.workflows import (
4141
Float8Tensor,
42+
Int8Tensor,
4243
QuantizeTensorToFloat8Kwargs,
44+
QuantizeTensorToInt8Kwargs,
4345
)
4446

4547
if isinstance(quant_kwargs, QuantizeTensorToFloat8Kwargs):
@@ -53,4 +55,11 @@ def _choose_quant_func_and_quantize_tensor(
5355
quant_kwargs.kernel_preference,
5456
)
5557

58+
if isinstance(quant_kwargs, QuantizeTensorToInt8Kwargs):
59+
return Int8Tensor.from_hp(
60+
tensor,
61+
quant_kwargs.granularity,
62+
quant_kwargs.mapping_type,
63+
)
64+
5665
raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919
dequantize_affine,
2020
quantize_affine,
2121
)
22-
from torchao.quantization.quantize_.common import QuantizeTensorKwargs
22+
from torchao.quantization.quantize_.common import (
23+
QuantizeTensorKwargs,
24+
_choose_quant_func_and_quantize_tensor,
25+
)
2326
from torchao.quantization.utils import get_block_size
2427
from torchao.utils import TorchAOBaseTensor, fill_defaults
2528

@@ -182,9 +185,8 @@ def _(func, types, args, kwargs):
182185
output_dtype = activation_tensor.dtype
183186

184187
if weight_tensor.act_quant_kwargs is not None:
185-
activation_tensor = Int8Tensor.from_hp(
186-
activation_tensor,
187-
granularity=weight_tensor.act_quant_kwargs.granularity,
188+
activation_tensor = _choose_quant_func_and_quantize_tensor(
189+
activation_tensor, weight_tensor.act_quant_kwargs
188190
)
189191
# Dynamic activation quantization path
190192

0 commit comments

Comments (0)