11import argparse
22import os
33import sys
4+
45sys .path .append ('./' )
56import time
67import json
3334 '--seed' ,
3435 type = int , default = 42 , help = 'Seed for sampling the calibration data.'
3536)
36- parser .add_argument ("--approach" , type = str , default = 'static' ,
37+ parser .add_argument ("--approach" , type = str , default = 'static' ,
3738 help = "Select from ['dynamic', 'static', 'weight-only']" )
3839parser .add_argument ("--int8" , action = "store_true" )
3940parser .add_argument ("--ipex" , action = "store_true" , help = "Use intel extension for pytorch." )
5051parser .add_argument ("--calib_iters" , default = 512 , type = int ,
5152 help = "calibration iters." )
5253parser .add_argument ("--tasks" , nargs = '+' , default = ["lambada_openai" ,
53- "hellaswag" ,"winogrande" ,"piqa" ,"wikitext" ],
54- type = str , help = "tasks list for accuracy validation" )
54+ "hellaswag" , "winogrande" , "piqa" , "wikitext" ],
55+ type = str , help = "tasks list for accuracy validation" )
5556parser .add_argument ("--peft_model_id" , type = str , default = None , help = "model_name_or_path of peft model" )
5657# ============SmoothQuant configs==============
5758parser .add_argument ("--sq" , action = "store_true" )
5859parser .add_argument ("--alpha" , default = "auto" , help = "Smooth quant parameter." )
5960# ============WeightOnly configs===============
60- parser .add_argument ("--woq_algo" , default = "RTN" , choices = ['RTN' , 'AWQ' , 'TEQ' , 'GPTQ' ],
61+ parser .add_argument ("--woq_algo" , default = "RTN" , choices = ['RTN' , 'AWQ' , 'TEQ' , 'GPTQ' ],
6162 help = "Weight-only parameter." )
6263parser .add_argument ("--woq_bits" , type = int , default = 8 )
6364parser .add_argument ("--woq_group_size" , type = int , default = - 1 )
6465parser .add_argument ("--woq_scheme" , default = "sym" )
6566parser .add_argument ("--woq_enable_mse_search" , action = "store_true" )
6667parser .add_argument ("--woq_enable_full_range" , action = "store_true" )
6768# =============GPTQ configs====================
68- parser .add_argument ("--gptq_actorder" , action = "store_true" , help = "Whether to apply the activation order GPTQ heuristic." )
69- parser .add_argument ('--gptq_percdamp' , type = float , default = .01 , help = 'Percent of the average Hessian diagonal to use for dampening.' )
69+ parser .add_argument ("--gptq_actorder" , action = "store_true" ,
70+ help = "Whether to apply the activation order GPTQ heuristic." )
71+ parser .add_argument ('--gptq_percdamp' , type = float , default = .01 ,
72+ help = 'Percent of the average Hessian diagonal to use for dampening.' )
7073parser .add_argument ('--gptq_block_size' , type = int , default = 128 , help = 'Block size. sub weight matrix size to run GPTQ.' )
7174parser .add_argument ('--gptq_nsamples' , type = int , default = 128 , help = 'Number of calibration data samples.' )
72- parser .add_argument ('--gptq_use_max_length' , action = "store_true" , help = 'Set all sequence length to be same length of args.gptq_pad_max_length' )
75+ parser .add_argument ('--gptq_use_max_length' , action = "store_true" ,
76+ help = 'Set all sequence length to be same length of args.gptq_pad_max_length' )
7377parser .add_argument ('--gptq_pad_max_length' , type = int , default = 2048 , help = 'Calibration dataset sequence max length, \
7478 this should align with your model config, \
7579 and your dataset builder args: args.pad_max_length' )
7680parser .add_argument ('--gptq_debug' , action = 'store_true' , help = 'Whether to use debug model ' )
77- parser .add_argument ('--gptq_gpu' , action = 'store_true' , help = 'Whether to use gpu' )
7881# =======================================
7982
8083args = parser .parse_args ()
8184if args .ipex :
8285 import intel_extension_for_pytorch as ipex
8386calib_size = 1
8487
88+
8589class Evaluator :
8690 def __init__ (self , dataset , tokenizer , batch_size = 8 , pad_val = 1 , pad_max = 196 , is_calib = False ):
8791 self .dataset = dataset
@@ -149,7 +153,7 @@ def evaluate(self, model):
149153 pred = last_token_logits .argmax (dim = - 1 )
150154 total += label .size (0 )
151155 hit += (pred == label ).sum ().item ()
152- if (i + 1 ) % 50 == 0 :
156+ if (i + 1 ) % 50 == 0 :
153157 print (hit / total )
154158 print ("Processed minibatch:" , i )
155159
@@ -187,6 +191,7 @@ def get_user_model():
187191 user_model .eval ()
188192 return user_model , tokenizer
189193
194+
190195if args .quantize :
191196 # dataset
192197 user_model , tokenizer = get_user_model ()
@@ -201,43 +206,46 @@ def get_user_model():
201206 collate_fn = calib_evaluator .collate_batch ,
202207 )
203208
209+
def calib_func(prepared_model):
    """Run calibration forward passes through ``prepared_model``.

    Iterates over the module-level ``calib_dataloader`` and calls
    ``prepared_model`` on the first element of each batch. Stops once
    ``args.calib_iters`` batches have been fed.

    Args:
        prepared_model: callable invoked as ``prepared_model(batch[0])``.
    """
    for i, calib_input in enumerate(calib_dataloader):
        # ``i`` is zero-based: stop as soon as ``calib_iters`` batches have
        # been processed (``>`` would run one batch too many).
        if i >= args.calib_iters:
            break
        prepared_model(calib_input[0])
209215
216+
210217 recipes = {}
211218 eval_func = None
212219 from neural_compressor import PostTrainingQuantConfig , quantization
220+
213221 # specify the op_type_dict and op_name_dict
214222 if args .approach == 'weight_only' :
215223 op_type_dict = {
216- '.*' :{ # re.match
224+ '.*' : { # re.match
217225 "weight" : {
218- 'bits' : args .woq_bits , # 1-8 bits
226+ 'bits' : args .woq_bits , # 1-8 bits
219227 'group_size' : args .woq_group_size , # -1 (per-channel)
220- 'scheme' : args .woq_scheme , # sym/asym
221- 'algorithm' : args .woq_algo , # RTN/AWQ/TEQ
228+ 'scheme' : args .woq_scheme , # sym/asym
229+ 'algorithm' : args .woq_algo , # RTN/AWQ/TEQ
222230 },
223231 },
224232 }
225- op_name_dict = {
226- 'lm_head' :{"weight" : {'dtype' : 'fp32' },},
227- 'embed_out' :{"weight" : {'dtype' : 'fp32' },}, # for dolly_v2
233+ op_name_dict = {
234+ 'lm_head' : {"weight" : {'dtype' : 'fp32' }, },
235+ 'embed_out' : {"weight" : {'dtype' : 'fp32' }, }, # for dolly_v2
228236 }
229237 recipes ["rtn_args" ] = {
230238 "enable_mse_search" : args .woq_enable_mse_search ,
231239 "enable_full_range" : args .woq_enable_full_range ,
232240 }
233241 recipes ['gptq_args' ] = {
234- 'percdamp' : args .gptq_percdamp ,
235- 'act_order' :args .gptq_actorder ,
236- 'block_size' : args .gptq_block_size ,
237- 'nsamples' : args .gptq_nsamples ,
238- 'use_max_length' : args .gptq_use_max_length ,
239- 'pad_max_length' : args .gptq_pad_max_length
240- }
242+ 'percdamp' : args .gptq_percdamp ,
243+ 'act_order' : args .gptq_actorder ,
244+ 'block_size' : args .gptq_block_size ,
245+ 'nsamples' : args .gptq_nsamples ,
246+ 'use_max_length' : args .gptq_use_max_length ,
247+ 'pad_max_length' : args .gptq_pad_max_length
248+ }
241249 # GPTQ: use assistive functions to modify calib_dataloader and calib_func
242250 # TEQ: set calib_func=None, use default training func as calib_func
243251 if args .woq_algo in ["GPTQ" , "TEQ" ]:
@@ -253,30 +261,32 @@ def calib_func(prepared_model):
253261 # for testing on various models, keep the code path that calls gptq_quantize directly
254262 if args .gptq_debug :
255263 from neural_compressor .adaptor .torch_utils .weight_only import gptq_quantize
264+
256265 conf = {
257- ".*" :{
258- 'wbits' : args .woq_bits , # 1-8 bits
266+ ".*" : {
267+ 'wbits' : args .woq_bits , # 1-8 bits
259268 'group_size' : args .woq_group_size , # -1 (per-channel)
260269 'sym' : (args .woq_scheme == "sym" ),
261270 'act_order' : args .gptq_actorder ,
262271 }
263- }
272+ }
264273 q_model_gptq_debug , gptq_config = gptq_quantize (
265- user_model ,
266- weight_config = conf ,
267- dataloader = calib_dataloader ,
268- nsamples = args .gptq_nsamples ,
269- use_max_length = args .gptq_use_max_length ,
270- pad_max_length = args .gptq_pad_max_length
274+ user_model ,
275+ weight_config = conf ,
276+ dataloader = calib_dataloader ,
277+ nsamples = args .gptq_nsamples ,
278+ use_max_length = args .gptq_use_max_length ,
279+ pad_max_length = args .gptq_pad_max_length
271280 )
272281 from intel_extension_for_transformers .llm .evaluation .lm_eval import evaluate
282+
273283 results = evaluate (
274284 model = "hf-causal" ,
275- model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
285+ model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
276286 user_model = q_model_gptq_debug , tasks = ["lambada_openai" ],
277- device = DEV .type ,
278287 batch_size = 4
279288 )
289+ exit (0 )
280290
281291 else :
282292 if re .search ("gpt" , user_model .config .model_type ):
@@ -306,6 +316,8 @@ def calib_func(prepared_model):
306316 if isinstance (args .alpha , list ):
307317 eval_dataset = load_dataset ('lambada' , split = 'validation' )
308318 evaluator = Evaluator (eval_dataset , tokenizer )
319+
320+
def eval_func(model):
    """Return the result of ``evaluator.evaluate`` for ``model``.

    Delegates to the ``evaluator`` object built in the enclosing scope.
    """
    return evaluator.evaluate(model)
@@ -323,6 +335,7 @@ def eval_func(model):
323335if args .int8 or args .int8_bf16_mixed :
324336 print ("load int8 model" )
325337 from neural_compressor .utils .pytorch import load
338+
326339 if args .ipex :
327340 user_model = load (os .path .abspath (os .path .expanduser (args .output_dir )))
328341 else :
@@ -335,9 +348,10 @@ def eval_func(model):
335348if args .accuracy :
336349 user_model .eval ()
337350 from intel_extension_for_transformers .llm .evaluation .lm_eval import evaluate
351+
338352 results = evaluate (
339353 model = "hf-causal" ,
340- model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
354+ model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
341355 user_model = user_model ,
342356 batch_size = args .batch_size ,
343357 tasks = args .tasks ,
@@ -358,11 +372,12 @@ def eval_func(model):
358372 user_model .eval ()
359373 from intel_extension_for_transformers .llm .evaluation .lm_eval import evaluate
360374 import time
375+
361376 samples = args .iters * args .batch_size
362377 start = time .time ()
363378 results = evaluate (
364379 model = "hf-causal" ,
365- model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
380+ model_args = 'pretrained=' + args .model + ',tokenizer=' + args .model + ',dtype=float32' ,
366381 user_model = user_model ,
367382 batch_size = args .batch_size ,
368383 tasks = args .tasks ,
@@ -376,5 +391,5 @@ def eval_func(model):
376391 acc = results ["results" ][task_name ]["acc" ]
377392 print ("Accuracy: %.5f" % acc )
378393 print ('Throughput: %.3f samples/sec' % (samples / (end - start )))
379- print ('Latency: %.3f ms' % ((end - start )* 1000 / samples ))
394+ print ('Latency: %.3f ms' % ((end - start ) * 1000 / samples ))
380395 print ('Batch size = %d' % args .batch_size )
0 commit comments