@@ -62,9 +62,8 @@ class Int8Tensor(TorchAOBaseTensor):
6262
6363 # TODO: Static quantization support using `static_scale`
6464 tensor_data_names = ["qdata" , "scale" ]
65- tensor_attribute_names = []
65+ tensor_attribute_names = ["block_size" ]
6666 optional_tensor_attribute_names = [
67- "block_size" ,
6867 "act_quant_kwargs" ,
6968 "dtype" ,
7069 ]
@@ -73,7 +72,7 @@ def __new__(
7372 cls : type ,
7473 qdata : torch .Tensor ,
7574 scale : torch .Tensor ,
76- block_size : Optional [ List [int ] ] = None ,
75+ block_size : List [int ] = None ,
7776 act_quant_kwargs : Optional [QuantizeTensorToInt8Kwargs ] = None ,
7877 dtype : Optional [torch .dtype ] = None ,
7978 ):
@@ -88,7 +87,7 @@ def __init__(
8887 self ,
8988 qdata : torch .Tensor ,
9089 scale : torch .Tensor ,
91- block_size : Optional [ List [int ]] = None ,
90+ block_size : List [int ],
9291 act_quant_kwargs : Optional [QuantizeTensorToInt8Kwargs ] = None ,
9392 dtype : Optional [torch .dtype ] = None ,
9493 ):
@@ -145,7 +144,7 @@ def from_hp(
145144 return cls (
146145 int_data ,
147146 scale ,
148- block_size = block_size ,
147+ block_size ,
149148 act_quant_kwargs = act_quant_kwargs ,
150149 dtype = hp_tensor .dtype ,
151150 )
@@ -273,7 +272,7 @@ def _(func, types, args, kwargs):
273272 Int8Tensor (
274273 sliced_qdata ,
275274 sliced_scale ,
276- block_size = block_size ,
275+ block_size ,
277276 act_quant_kwargs = self .act_quant_kwargs ,
278277 dtype = self .dtype ,
279278 ),
0 commit comments