Skip to content

Commit 1a5e007

Browse files
xin3he authored and XuehaoSun committed
fix torchao renaming issue (#2277)
Signed-off-by: xinhe3 <xinhe3@habana.ai>
1 parent 0d202cb commit 1a5e007

File tree

1 file changed: +8 additions, -5 deletions

neural_compressor/torch/algorithms/fp8_quant/_core/quantized_func_wrappers/cpu/cpu_quantized_func_wrapper.py

Lines changed: 8 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,10 @@
1515
from ..quantized_func_wrapper import QuantizedFuncWrapperBase, OP_TYPE, QuantizedFuncWrapperFactory
1616

1717
import torch
18-
import torchao
18+
from torchao.quantization.quant_primitives import (
19+
_quantize_affine_float8,
20+
_dequantize_affine_float8,
21+
)
1922

2023
from abc import ABCMeta
2124

@@ -32,7 +35,7 @@ def __init__(self, scale_format, is_dynamic=False):
3235
class QuantizedCPUQuant(QuantizedCPUFuncWrapperBase):
3336

3437
def get_default_quantized_func(self):
35-
return torch.ops.torchao.quantize_affine_float8
38+
return _quantize_affine_float8
3639

3740
def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
3841
return self._quantized_func_(tensor=input, scale=torch.tensor(scale), float8_dtype=dtype)
@@ -41,7 +44,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
4144
class QuantizedCPUQuantPC(QuantizedCPUFuncWrapperBase):
4245

4346
def get_default_quantized_func(self):
44-
return torch.ops.torchao.quantize_affine_float8
47+
return _quantize_affine_float8
4548

4649
def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn):
4750
return self._quantized_func_(tensor=input, scale=scale.view((-1, 1)), float8_dtype=dtype)
@@ -50,7 +53,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
5053
class QuantizedCPUDeQuant(QuantizedCPUFuncWrapperBase):
5154

5255
def get_default_quantized_func(self):
53-
return torch.ops.torchao.dequantize_affine_float8
56+
return _dequantize_affine_float8
5457

5558
def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
5659
return self._quantized_func_(tensor=input, scale=torch.tensor(scale), output_dtype=out_dtype)
@@ -59,7 +62,7 @@ def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_
5962
class QuantizedCPUDeQuantPC(QuantizedCPUFuncWrapperBase):
6063

6164
def get_default_quantized_func(self):
62-
return torch.ops.torchao.dequantize_affine_float8
65+
return _dequantize_affine_float8
6366

6467
def __call__(self, input, scale, zero_point=None, axis=0, quant_min=None, quant_max=None, dtype=torch.float8_e4m3fn, out_dtype=torch.bfloat16):
6568
return self._quantized_func_(tensor=input, scale=scale.view((1, -1)), output_dtype=out_dtype)

0 commit comments

Comments (0)