
Commit 328585e

made dtype required as well
1 parent f28df4a commit 328585e

File tree

2 files changed: +10 -11 lines changed


torchao/quantization/quant_api.py

Lines changed: 0 additions & 1 deletion

@@ -1610,7 +1610,6 @@ def get_weight_block_size(x):
     quantized_weight = Int8Tensor.from_hp(
         weight,
         granularity=weight_granularity,
-        mapping_type=MappingType.SYMMETRIC,
         act_quant_kwargs=QuantizeTensorToInt8Kwargs(
             granularity=act_granularity,
             mapping_type=config.act_mapping_type,

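For context, MappingType.SYMMETRIC selects a scale-only int8 mapping with no zero point; dropping the argument here suggests Int8Tensor.from_hp now assumes symmetric mapping for weights. A minimal illustrative sketch of what symmetric per-tensor int8 quantization computes (not torchao's actual implementation):

import torch

def symmetric_int8_quantize(x: torch.Tensor):
    # Symmetric mapping: a single scale, zero_point fixed at 0,
    # so the real value 0.0 maps exactly to int8 0.
    scale = x.abs().amax() / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale

x = torch.randn(4, 8)
q, scale = symmetric_int8_quantize(x)
x_hat = q.to(torch.float32) * scale  # dequantize for comparison against x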
torchao/quantization/quantize_/workflows/int8/int8_tensor.py

Lines changed: 10 additions & 10 deletions

@@ -62,23 +62,22 @@ class Int8Tensor(TorchAOBaseTensor):
 
     # TODO: Static quantization support using `static_scale`
     tensor_data_names = ["qdata", "scale"]
-    tensor_attribute_names = ["block_size"]
+    tensor_attribute_names = ["block_size", "dtype"]
     optional_tensor_attribute_names = [
         "act_quant_kwargs",
-        "dtype",
     ]
 
     def __new__(
         cls: type,
         qdata: torch.Tensor,
         scale: torch.Tensor,
-        block_size: List[int] = None,
+        block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         kwargs = {
             "device": qdata.device,
-            "dtype": dtype or scale.dtype,
+            "dtype": dtype,
             "requires_grad": False,
         }
         return torch.Tensor._make_wrapper_subclass(cls, qdata.shape, **kwargs)
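The reason dtype must now reach __new__: for a wrapper tensor subclass, the externally visible dtype is fixed by torch.Tensor._make_wrapper_subclass at construction time and cannot be assigned afterwards in __init__. A stripped-down sketch of the pattern (hypothetical Wrapper class, not the torchao code):

import torch

class Wrapper(torch.Tensor):
    # Stores an int8 payload but advertises a high-precision dtype.
    @staticmethod
    def __new__(cls, qdata, dtype):
        # The dtype passed here is what tensor.dtype reports externally.
        return torch.Tensor._make_wrapper_subclass(
            cls, qdata.shape, dtype=dtype, device=qdata.device, requires_grad=False
        )

    def __init__(self, qdata, dtype):
        self.qdata = qdata  # dtype was already baked in by __new__

w = Wrapper(torch.zeros(2, 3, dtype=torch.int8), torch.bfloat16)
assert w.dtype == torch.bfloat16 and w.qdata.dtype == torch.int8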
@@ -88,13 +87,14 @@ def __init__(
         qdata: torch.Tensor,
         scale: torch.Tensor,
         block_size: List[int],
+        dtype: torch.dtype,
         act_quant_kwargs: Optional[QuantizeTensorToInt8Kwargs] = None,
-        dtype: Optional[torch.dtype] = None,
     ):
         super().__init__()
         self.qdata = qdata
         self.scale = scale
         self.block_size = block_size
+        # don't set dtype because this gets done in __new__
         self.act_quant_kwargs = act_quant_kwargs
 
     def __repr__(self):
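On the attribute-list change in the first hunk: as I understand it, TorchAOBaseTensor generates subclass plumbing such as __tensor_flatten__/__tensor_unflatten__ from tensor_data_names, tensor_attribute_names, and optional_tensor_attribute_names, rebuilding instances by passing the attributes positionally. That would be why "dtype" sits in the same position (right after "block_size") in the attribute list and in both signatures. A hypothetical sketch of that assumed reconstruction, not the actual base-class code:

def rebuild(cls, inner_tensors, attributes):
    # inner_tensors: {"qdata": ..., "scale": ...} per tensor_data_names;
    # attributes: (block_size, dtype, act_quant_kwargs) in declaration
    # order -- this must line up with the __init__ signature above.
    return cls(inner_tensors["qdata"], inner_tensors["scale"], *attributes)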
@@ -145,8 +145,8 @@ def from_hp(
             int_data,
             scale,
             block_size,
+            hp_tensor.dtype,
             act_quant_kwargs=act_quant_kwargs,
-            dtype=hp_tensor.dtype,
         )
 
     def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
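A hedged usage sketch of the net effect of from_hp threading hp_tensor.dtype through: the wrapper's public dtype now always mirrors the high-precision input instead of falling back to the scale's dtype (PerRow and its import path are assumed from torchao's public API):

import torch
from torchao.quantization.granularity import PerRow  # assumed import path

w = torch.randn(64, 128, dtype=torch.bfloat16)
qw = Int8Tensor.from_hp(w, granularity=PerRow())

assert qw.dtype == torch.bfloat16    # outer dtype mirrors the hp input
assert qw.qdata.dtype == torch.int8  # payload stays int8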
@@ -159,7 +159,7 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor
             input_dtype=torch.int8,
             quant_min=-128,
             quant_max=127,
-            output_dtype=output_dtype or self.dtype,
+            output_dtype=output_dtype if output_dtype is not None else self.dtype,
         )
 
 
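The dequantize change is defensive rather than behavioral: x or y substitutes y for any falsy x, while the rewritten form only substitutes for None. Since torch.dtype objects are always truthy, both spellings behave identically here; the explicit None check just states the intent. A quick illustration of the general difference:

val = 0
print(val or 5)                       # 5 -- any falsy value is replaced
print(val if val is not None else 5)  # 0 -- only None triggers the fallback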
@@ -273,8 +273,8 @@ def _(func, types, args, kwargs):
                 sliced_qdata,
                 sliced_scale,
                 block_size,
+                self.dtype,
                 act_quant_kwargs=self.act_quant_kwargs,
-                dtype=self.dtype,
             ),
         )
 
@@ -304,8 +304,8 @@ def _(func, types, args, kwargs):
         old_int8_tensor.qdata[index],
         old_int8_tensor.scale[index],
         old_int8_tensor.block_size[1:],
-        old_int8_tensor.act_quant_kwargs,
         old_int8_tensor.dtype,
+        old_int8_tensor.act_quant_kwargs,
     )
     return return_and_correct_aliasing(func, args, kwargs, new_int8_tensor)
 
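The reorder in this last hunk follows from the new signature: dtype is now a required positional parameter placed before act_quant_kwargs, so keeping the old positional order would bind act_quant_kwargs where dtype is expected. Restated with comments, per the signatures above:

# Positional order after this commit:
#   Int8Tensor(qdata, scale, block_size, dtype, act_quant_kwargs=None)
new_int8_tensor = Int8Tensor(
    old_int8_tensor.qdata[index],
    old_int8_tensor.scale[index],
    old_int8_tensor.block_size[1:],    # block_size minus the indexed dim
    old_int8_tensor.dtype,             # 4th positional: required dtype
    old_int8_tensor.act_quant_kwargs,  # 5th positional: optional
)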