Skip to content

Commit ce4d568

Browse files
committed
address nits
1 parent 328585e commit ce4d568

File tree

3 files changed

+5
-5
lines changed

3 files changed

+5
-5
lines changed

torchao/quantization/quant_api.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1526,7 +1526,7 @@ class Int8DynamicActivationInt8WeightConfig(AOBaseConfig):
15261526
layout: Optional[Layout] = PlainLayout()
15271527
act_mapping_type: Optional[MappingType] = MappingType.SYMMETRIC
15281528
weight_only_decode: bool = False
1529-
granularity: Union[PerRow, PerTensor] = PerRow()
1529+
granularity: Granularity = PerRow()
15301530
set_inductor_config: bool = True
15311531
version: int = 1
15321532

torchao/quantization/quantize_/common/quantize_tensor_kwargs.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ def _choose_quant_func_and_quantize_tensor(
5959
return Int8Tensor.from_hp(
6060
tensor,
6161
quant_kwargs.granularity,
62-
quant_kwargs.mapping_type,
62+
mapping_type=quant_kwargs.mapping_type,
6363
)
6464

6565
raise NotImplementedError(f"Quant kwargs not supported: {quant_kwargs}")

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212

1313
from torchao.float8.inference import _slice_scale_for_dimension
1414
from torchao.kernel import int_scaled_matmul
15-
from torchao.quantization.granularity import Granularity, PerRow
15+
from torchao.quantization.granularity import Granularity
1616
from torchao.quantization.quant_primitives import (
1717
MappingType,
1818
choose_qparams_affine,
@@ -40,7 +40,7 @@ class QuantizeTensorToInt8Kwargs(QuantizeTensorKwargs):
4040
mapping_type: whether to use symmetric or asymmetric quant, only symmetric is supported currently
4141
"""
4242

43-
granularity: Granularity = PerRow()
43+
granularity: Granularity
4444
mapping_type: MappingType = MappingType.SYMMETRIC
4545

4646

@@ -113,7 +113,7 @@ def __repr__(self):
113113
def from_hp(
114114
cls,
115115
hp_tensor: torch.Tensor,
116-
granularity: Granularity = PerRow(),
116+
granularity: Granularity,
117117
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
118118
mapping_type=MappingType.SYMMETRIC,
119119
):

0 commit comments

Comments
 (0)