
Commit 5119fcb

ONNX WOQ supports different dtypes (#1490)
Signed-off-by: Mengni Wang <mengni.wang@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
1 parent 08221e1 commit 5119fcb

File tree: 6 files changed, 208 additions & 46 deletions

neural_compressor/adaptor/onnxrt.py

Lines changed: 5 additions & 4 deletions
@@ -979,12 +979,10 @@ def _pre_optimize(self, model, level=1):
         sess_options.register_custom_ops_library(get_library_path())

         if not model.is_large_model:
-            sess = ort.InferenceSession(
-                model.model.SerializeToString(), sess_options, providers=["CPUExecutionProvider"]
-            )
+            sess = ort.InferenceSession(model.model.SerializeToString(), sess_options, providers=[self.backend])
         elif model.model_path is not None:  # pragma: no cover
             model.model = onnx.ModelProto()  # clean memory for large model
-            sess = ort.InferenceSession(model.model_path, sess_options, providers=["CPUExecutionProvider"])
+            sess = ort.InferenceSession(model.model_path, sess_options, providers=[self.backend])
         else:  # pragma: no cover
             logger.warning("Please use model path instead of onnx model object to quantize")
         del sess
@@ -1914,6 +1912,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     mse=mse,
                     perchannel=perchannel,
                     accuracy_level=accuracy_level,
+                    providers=[self.backend],
                 )
             if "AWQ" in algos:
                 from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize
@@ -1931,6 +1930,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                     enable_auto_scale=enable_auto_scale,
                     enable_mse_search=enable_mse_search,
                     accuracy_level=accuracy_level,
+                    providers=[self.backend],
                 )
         elif "RTN" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import rtn_quantize
@@ -1940,6 +1940,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 tmp_model,
                 quant_config,
                 accuracy_level=accuracy_level,
+                providers=[self.backend],
             )
         tmp_model.q_config = copy.deepcopy(quant_config)
         self._dump_model_op_stats(tmp_model, tune_cfg)
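The net effect of the onnxrt.py changes is that the configured execution provider (self.backend) now reaches both the pre-optimization InferenceSession and the RTN/AWQ/GPTQ weight-only helpers, instead of a hard-coded "CPUExecutionProvider". A minimal sketch of the session side of that pattern using plain ONNX Runtime; the model path and provider name below are placeholders for illustration, not part of this commit:

    import onnxruntime as ort

    # Placeholder values; in the adaptor the provider comes from self.backend.
    backend = "CUDAExecutionProvider"
    sess_options = ort.SessionOptions()

    # Same call shape as in _pre_optimize above: the provider list is passed
    # through rather than pinned to ["CPUExecutionProvider"].
    sess = ort.InferenceSession("model.onnx", sess_options, providers=[backend])
    print(sess.get_providers())  # which providers the session actually resolved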

neural_compressor/adaptor/onnxrt_cuda.yaml

Lines changed: 21 additions & 0 deletions
@@ -17,6 +17,20 @@
 -
   version:
     name: '1.6.0'
+  weight_only_integer: &cap_weight_only {
+        'MatMul': &cap_weight_only_matmul {
+          'weight': {
+                      'dtype': ['int'], # no need to care uint
+                      'bits': [4, 3, 8], # [1-8]
+                      'group_size': [32, -1, 1, 16, 64, 128, 256, 512, 1024], # [1-inf]
+                      'scheme': ['sym', 'asym'], # sym, no ZP
+                      'algorithm': ['RTN', 'AWQ', 'GPTQ']
+                      },
+          'activation': {
+                      'dtype': ['fp32']
+                      }
+          },
+        }
   int8: &ref_1_6 {
     'static': &ref_1_6_static {
       'Conv': {
@@ -114,6 +128,7 @@
 -
   version:
     name: '1.7.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -155,6 +170,7 @@
 -
   version:
     name: '1.8.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -224,6 +240,7 @@
 -
   version:
     name: '1.9.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -300,6 +317,7 @@
 -
   version:
     name: '1.10.0'
+  weight_only_integer: *cap_weight_only
   int8: {
     'static': {
       'FusedConv': {
@@ -356,6 +374,7 @@
 -
   version:
     name: '1.11.0'
+  weight_only_integer: *cap_weight_only
   int8: &ref_1_11 {
     'static': {
       'FusedConv': {
@@ -427,6 +446,7 @@
 -
   version:
     name: '1.12.0'
+  weight_only_integer: *cap_weight_only
   int8: *ref_1_11
   fp16: *common_fp16
   bf16: *common_bf16
@@ -436,6 +456,7 @@
 -
   version:
     name: 'default'
+  weight_only_integer: *cap_weight_only
   int8: *ref_1_6
   fp16: *common_fp16
   bf16: *common_bf16
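The new weight_only_integer capability (anchored once for ONNX Runtime 1.6.0 and reused via *cap_weight_only for later versions) declares, per MatMul, which weight dtypes, bit widths, group sizes, schemes, and algorithms the tuner may choose from. A small sketch of what that block amounts to once parsed into Python; the yaml.safe_load call and variable names are illustrative, not the framework's own capability loader:

    import yaml

    # Block-style restatement of the flow-style capability entry above.
    capability_yaml = """
    weight_only_integer:
      MatMul:
        weight:
          dtype: ['int']                # no need to care uint
          bits: [4, 3, 8]               # [1-8]
          group_size: [32, -1, 1, 16, 64, 128, 256, 512, 1024]
          scheme: ['sym', 'asym']       # sym means no zero point
          algorithm: ['RTN', 'AWQ', 'GPTQ']
        activation:
          dtype: ['fp32']
    """

    cap = yaml.safe_load(capability_yaml)["weight_only_integer"]
    assert set(cap["MatMul"]["weight"]["algorithm"]) == {"RTN", "AWQ", "GPTQ"}
    assert cap["MatMul"]["activation"]["dtype"] == ["fp32"]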

neural_compressor/adaptor/ox_utils/util.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,7 @@

 dtype_mapping = {
     "fp32": 1,
+    "float32": 1,
     "uint8": 2,
     "int8": 3,
     "uint16": 4,
@@ -66,12 +67,14 @@
     "string": 8,
     "bool": 9,
     "fp16": 10,
+    "float16": 10,
     "double": 11,
     "uint32": 12,
     "uint64": 13,
     "complex64": 14,
     "complex128": 15,
     "bf16": 16,
+    "bfloat16": 16,
 }

 PROVIDERS = {
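The integer codes in dtype_mapping are ONNX TensorProto data-type enums, so the added long-form aliases ("float32", "float16", "bfloat16") resolve to the same codes as the existing short spellings. A quick check against the onnx package; the enum values are fixed by the ONNX spec:

    from onnx import TensorProto

    # The new aliases map to the same TensorProto codes as the short names.
    assert TensorProto.FLOAT == 1      # "fp32" and "float32"
    assert TensorProto.FLOAT16 == 10   # "fp16" and "float16"
    assert TensorProto.BFLOAT16 == 16  # "bf16" and "bfloat16"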
