@@ -1628,26 +1628,37 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         Returns:
             (dict): quantized model
         """
+        if self.performance_only:
+            tmp_model = model
+        else:
+            try:
+                tmp_model = copy.deepcopy(model)
+            except Exception as e:  # pragma: no cover
+                logger.warning("Fail to deep copy the model due to {}, inplace is used now.".format(repr(e)))
+                tmp_model = model
+
         assert q_func is None, "quantization aware training has not been supported on ONNXRUNTIME"
         for precision in self.query_handler.get_precisions():
             if precision == "weight_only_integer":
                 self.quantizable_op_types += self.query_handler.get_op_types_by_precision(precision=precision)
-        self.quantizable_ops = self._query_quantizable_ops(model.model)
+        self.quantizable_ops = self._query_quantizable_ops(tmp_model.model)
 
+        self._update_tune_cfg(tune_cfg, tmp_model.model)
         quant_config = self._cfg_to_quantize_config(tune_cfg)
         algos = set([item["algorithm"] for key, item in quant_config.items() if isinstance(item, dict)])
         if "GPTQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import gptq_quantize
 
+            assert data_loader is not None, "GPTQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             percdamp = self.recipes.get("gptq_args", {}).get("percdamp", 0.01)
             blocksize = self.recipes.get("gptq_args", {}).get("blocksize", 128)
             actorder = self.recipes.get("gptq_args", {}).get("actorder", False)
             mse = self.recipes.get("gptq_args", {}).get("mse", False)
             perchannel = self.recipes.get("gptq_args", {}).get("perchannel", True)
             accuracy_level = self.recipes.get("gptq_args", {}).get("accuracy_level", 0)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = gptq_quantize(
-                model,
+            tmp_model = gptq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
@@ -1661,12 +1672,13 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
         if "AWQ" in algos:
             from neural_compressor.adaptor.ox_utils.weight_only import awq_quantize
 
+            assert data_loader is not None, "AWQ WOQ algorithm needs to pass 'calib_dataloader' to quantization.fit()"
             enable_auto_scale = self.recipes.get("awq_args", {}).get("enable_auto_scale", True)
             enable_mse_search = self.recipes.get("awq_args", {}).get("enable_mse_search", True)
             accuracy_level = self.recipes.get("awq_args", {}).get("accuracy_level", 0)
             calib_sampling_size = tune_cfg.get("calib_sampling_size", 1)
-            model = awq_quantize(
-                model,
+            tmp_model = awq_quantize(
+                tmp_model,
                 data_loader,
                 quant_config,
                 n_samples=calib_sampling_size,
@@ -1683,6 +1695,7 @@ def quantize(self, tune_cfg, model, data_loader, q_func=None):
                 quant_config,
                 accuracy_level=accuracy_level,
             )
+        tmp_model = rtn_quantize(tmp_model, quant_config)
         tmp_model.q_config = copy.deepcopy(quant_config)
         self._dump_model_op_stats(tmp_model, tune_cfg)
         tmp_model.topological_sort()
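The two new `assert data_loader is not None` checks surface a requirement at the public API: GPTQ and AWQ are calibration-driven algorithms, so `quantization.fit()` must receive a dataloader. A hedged caller-side sketch follows; the exact `PostTrainingQuantConfig` schema varies across neural_compressor releases, and `dataloader` stands for any dataloader `fit()` accepts:

```python
from neural_compressor import PostTrainingQuantConfig, quantization

# Recipe keys mirror the ones this method reads from self.recipes
# ("gptq_args" / "awq_args"); the values shown are the defaults in the diff.
config = PostTrainingQuantConfig(
    approach="weight_only",
    recipes={"gptq_args": {"percdamp": 0.01, "blocksize": 128, "actorder": False}},
)

# dataloader: user-provided calibration dataloader, assumed defined.
# Omitting calib_dataloader would now trip the GPTQ/AWQ assertion.
q_model = quantization.fit("model.onnx", config, calib_dataloader=dataloader)
```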
@@ -1752,6 +1765,31 @@ def _cfg_to_quantize_config(self, tune_cfg):
 
         return quantize_config
 
+    def _update_tune_cfg(self, tune_cfg, model):
+        """Update tune cfg according to woq_tuning_cfg."""
+        if tune_cfg.get("woq_tuning_cfg") is None:
+            return tune_cfg
+
+        from neural_compressor.strategy.utils.constant import WOQ_TUNING_ALGOS
+
+        woq_tuning_cfg = tune_cfg.get("woq_tuning_cfg")
+        new_woq_cfg = WOQ_TUNING_ALGOS.get(woq_tuning_cfg)
+
+        for node_cfg in tune_cfg["op"].values():
+            node_cfg["weight"].update(
+                {cfg_name: cfg_value for cfg_name, cfg_value in new_woq_cfg.items() if cfg_name in node_cfg["weight"]}
+            )
+
+        # find last matmul and set to fp32
+        if "DISABLE_LAST_MATMUL" in woq_tuning_cfg:
+            last_matmul = None
+            fp32_op_cfg = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
+            for node in model.graph.node:
+                if node.op_type in ["MatMul"]:
+                    last_matmul = (node.name, node.op_type)
+            if last_matmul in tune_cfg["op"]:
+                tune_cfg["op"][last_matmul].update(fp32_op_cfg)
+
     def query_fw_capability(self, model):
         """The function is used to query framework capability.
         TODO: will be replaced by framework query API
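To make the `_update_tune_cfg` semantics concrete, here is a toy, self-contained walkthrough with made-up data: the `preset` dict stands in for an entry of `WOQ_TUNING_ALGOS`, the tuning-config key is hypothetical, and the op keys follow the `(node_name, op_type)` tuple convention visible in the method:

```python
# Stand-in for one WOQ_TUNING_ALGOS entry (illustrative values only).
preset = {"algorithm": "GPTQ", "bits": 4}
tune_cfg = {
    "woq_tuning_cfg": "GPTQ_DISABLE_LAST_MATMUL",  # hypothetical key name
    "op": {
        ("fc1", "MatMul"): {"weight": {"algorithm": "RTN", "bits": 8}, "activation": {}},
        ("fc2", "MatMul"): {"weight": {"algorithm": "RTN", "bits": 8}, "activation": {}},
    },
}

# Overwrite only the keys a node's weight config already carries,
# exactly as the dict comprehension in the diff does.
for node_cfg in tune_cfg["op"].values():
    node_cfg["weight"].update({k: v for k, v in preset.items() if k in node_cfg["weight"]})

# "DISABLE_LAST_MATMUL": the real method walks model.graph.node in order,
# so the final MatMul encountered wins; here that walk would find fc2.
if "DISABLE_LAST_MATMUL" in tune_cfg["woq_tuning_cfg"]:
    last_matmul = ("fc2", "MatMul")
    tune_cfg["op"][last_matmul].update(
        {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32", "quant_mode": "fp32"}}
    )
```

Note that the adaptor calls `_update_tune_cfg` for its side effects on `tune_cfg`; the early `return tune_cfg` is just a fast exit when no WOQ tuning config is present.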