Skip to content

Commit f28df4a

Browse files
committed
Make `block_size` a required attribute
1 parent ac6a2b6 commit f28df4a

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 5 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -62,9 +62,8 @@ class Int8Tensor(TorchAOBaseTensor):
6262

6363
# TODO: Static quantization support using `static_scale`
6464
tensor_data_names = ["qdata", "scale"]
65-
tensor_attribute_names = []
65+
tensor_attribute_names = ["block_size"]
6666
optional_tensor_attribute_names = [
67-
"block_size",
6867
"act_quant_kwargs",
6968
"dtype",
7069
]
@@ -73,7 +72,7 @@ def __new__(
7372
cls: type,
7473
qdata: torch.Tensor,
7574
scale: torch.Tensor,
76-
block_size: Optional[List[int]] = None,
75+
block_size: List[int] = None,
7776
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
7877
dtype: Optional[torch.dtype] = None,
7978
):
@@ -88,7 +87,7 @@ def __init__(
8887
self,
8988
qdata: torch.Tensor,
9089
scale: torch.Tensor,
91-
block_size: Optional[List[int]] = None,
90+
block_size: List[int],
9291
act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
9392
dtype: Optional[torch.dtype] = None,
9493
):
@@ -145,7 +144,7 @@ def from_hp(
145144
return cls(
146145
int_data,
147146
scale,
148-
block_size=block_size,
147+
block_size,
149148
act_quant_kwargs=act_quant_kwargs,
150149
dtype=hp_tensor.dtype,
151150
)
@@ -273,7 +272,7 @@ def _(func, types, args, kwargs):
273272
Int8Tensor(
274273
sliced_qdata,
275274
sliced_scale,
276-
block_size=block_size,
275+
block_size,
277276
act_quant_kwargs=self.act_quant_kwargs,
278277
dtype=self.dtype,
279278
),

0 commit comments

Comments
 (0)