@@ -62,23 +62,22 @@ class Int8Tensor(TorchAOBaseTensor):
 
     # TODO: Static quantization support using `static_scale`
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = ["block_size"]
+    tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
-        "dtype",
     ]
 
     def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: List[int] = None,
+        block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         kwargs = {
             "device": qdata.device,
-            "dtype": dtype or scale.dtype,
+            "dtype": dtype,
             "requires_grad": False,
         }
         return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
@@ -88,13 +87,14 @@ def __init__(
         qdata: torch.Tensor,
         scale: torch.Tensor,
         block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
         self.block_size = block_size
+        # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
 
     def __repr__(self):
@@ -145,8 +145,8 @@ def from_hp(
             int_data,
             scale,
             block_size,
+            hp_tensor.dtype,
             act_quant_kwargs=act_quant_kwargs,
-            dtype=hp_tensor.dtype,
         )
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
@@ -159,7 +159,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             input_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
-            output_dtype=output_dtype or self.dtype,
+            output_dtype=output_dtype if output_dtype is not None else self.dtype,
         )
 
 
@@ -273,8 +273,8 @@ def _(func, types, args, kwargs):
             sliced_qdata,
             sliced_scale,
             block_size,
+            self.dtype,
             act_quant_kwargs=self.act_quant_kwargs,
-            dtype=self.dtype,
         ),
     )
 
@@ -304,8 +304,8 @@ def _(func, types, args, kwargs):
         old_int8_tensor.qdata[index],
         old_int8_tensor.scale[index],
         old_int8_tensor.block_size[1:],
-        old_int8_tensor.act_quant_kwargs,
         old_int8_tensor.dtype,
+        old_int8_tensor.act_quant_kwargs,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
 
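Net effect of the diff: `dtype` moves from an optional keyword attribute to a required positional tensor attribute, declared in `tensor_attribute_names` and placed before `act_quant_kwargs`, so every call site that constructs `Int8Tensor` positionally must pass the high-precision dtype explicitly. Below is a minimal sketch of a call site after this change; the example shapes and values are illustrative only, and `Int8Tensor` is assumed to be in scope from the module this diff touches.

import torch

# Illustrative values only (not from the PR): per-row int8 quantization of a 4x8 weight.
qdata = torch.randint(-128, 128, (4, 8), dtype=torch.int8)  # quantized payload
scale = torch.rand(4, 1)                                     # per-row scales
block_size = [1, 8]                                          # one block per row

# dtype is now a required positional argument, passed before act_quant_kwargs.
t = Int8Tensor(qdata, scale, block_size, torch.bfloat16)

# dequantize() with no output_dtype falls back to self.dtype explicitly
# (an `is not None` check instead of `or`).
hp = t.dequantize()
assert hp.dtype == torch.bfloat16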