@@ -39,6 +39,7 @@
     collect_best_params,
     get_shared_keys,
     gguf_args_check,
+    immediate_saving,
     infer_bits_by_data_type,
     init_cache,
     is_mx_fp,
@@ -148,6 +149,7 @@ def __init__(
         enable_alg_ext: bool = False,
         disable_opt_rtn: bool = False,
         seed: int = 42,
+        low_cpu_mem_usage: bool = False,
         **kwargs,
     ):
         """Initialize AutoRound with quantization and tuning configuration.
@@ -168,6 +170,7 @@ def __init__(
             lr (float, optional): Learning rate; if None, set to 1.0 / iters except when iters==0.
             minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
             low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
+            low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
             iters (int, optional): Optimization iterations. Defaults to 200.
             seqlen (int, optional): Calibration sequence length. Defaults to 2048.
             nsamples (int, optional): Number of calibration samples. Defaults to 128.
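For orientation, a minimal sketch of how the new option is passed through the public constructor. The model/tokenizer objects and the other arguments shown are illustrative assumptions, not part of this diff:

    from auto_round import AutoRound

    autoround = AutoRound(
        model,                   # a loaded transformers model (assumed)
        tokenizer,
        bits=4,
        data_type="int",         # immediate saving is later gated to int types
        low_cpu_mem_usage=True,  # new flag: stream packed layers out instead of holding them in RAM
    )
    autoround.quantize()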
@@ -244,6 +247,7 @@ def __init__(
         self.supported_types = SUPPORTED_LAYER_TYPES
         self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
         self.scale_dtype = convert_dtype_str2torch(scale_dtype)
+        self.low_cpu_mem_usage = low_cpu_mem_usage

         if kwargs:
             logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
@@ -336,7 +340,10 @@ def __init__(
         self.lr_scheduler = lr_scheduler
         self.optimizer = self._get_optimizer(None)
         self.disable_opt_rtn = disable_opt_rtn
-        self.is_packing_immediate = False  # whether to pack the layer immediately after tuning
+
+        # Whether to pack the layer immediately after tuning
+        self.immediate_packing = False
+        self.immediate_saving = False

         # KV cache, this one does not affect tuning but will collect some infos during tuning
         self.static_kv_dtype = static_kv_dtype
@@ -1205,7 +1212,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         `rtn_*` version if supported, then wraps and unwraps the module to apply
         quantization. If GPU memory is insufficient, it falls back to CPU.

-        If packing is enabled (`is_packing_immediate`), the function will also export
+        If packing is enabled (`immediate_packing`), the function will also export
         the quantized layer to the appropriate backend format.

         Args:
@@ -1222,7 +1229,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:

         # Step 1: Try quantization on GPU first, fall back to CPU if OOM
         # if only export gguf, using gguf-packing instead of rtn
-        if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
+        if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
             m.scale = None
             m.zp = None
         else:
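A note on the branch above: for an RTN-only run (iters == 0) whose sole target is GGUF, the code skips computing scale/zero-point and leaves them as None, which reads as deferring quantization parameters to the GGUF packer. A hypothetical illustration of when that fast path fires (all names and values below are made up to mirror the condition):

    immediate_packing, iters, disable_opt_rtn = True, 0, False
    formats = ["gguf:q4_k_s"]  # example format string
    if immediate_packing and iters == 0 and "gguf" in formats[0] and not disable_opt_rtn:
        scale = zp = None  # deferred: pack_gguf_layer computes these during packing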
@@ -1259,34 +1266,45 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
             raise

         # Step 2: Optional immediate packing/export
-        if self.is_packing_immediate:
-            from auto_round.export import PACKING_LAYER_WITH_FORMAT
-
-            if check_to_quantized(m):
-                target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
-                has_gguf = any("gguf" in fmt for fmt in self.formats)
-
-                if has_gguf:
-                    from auto_round.export.export_to_gguf.export import pack_gguf_layer
-
-                    output_dir = self._get_save_folder_name(self.formats[0])
-                    model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
-                    pack_gguf_layer(
-                        name,
-                        self.model,
-                        self.formats[0],
-                        output_dir,
-                        self.layer_config,
-                        self.tokenizer,
-                        processor=self.processor if hasattr(self, "processor") else None,
-                        image_processor=self.image_processor if hasattr(self, "image_processor") else None,
-                        model_type=model_type,
-                    )
-                else:
-                    PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device)
+        if self.immediate_packing:
+            self._immediate_pack(name)
         else:
             set_module(self.model, name, m)

+        if self.immediate_saving:
+            all_to_quantized_module_names = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
+            last_module = (len(all_to_quantized_module_names) == 0) or (name == all_to_quantized_module_names[-1])
+            m = get_module(self.model, name)
+            immediate_saving(self, m, name, last_module)
+
+    def _immediate_pack(self, name: str):
+        m = get_module(self.model, name)
+        if not check_to_quantized(m):
+            return
+        from auto_round.export import PACKING_LAYER_WITH_FORMAT
+
+        target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
+        has_gguf = any("gguf" in fmt for fmt in self.formats)
+
+        if has_gguf:
+            from auto_round.export.export_to_gguf.export import pack_gguf_layer
+
+            output_dir = self._get_save_folder_name(self.formats[0])
+            model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
+            pack_gguf_layer(
+                name,
+                self.model,
+                self.formats[0],
+                output_dir,
+                self.layer_config,
+                self.tokenizer,
+                processor=self.processor if hasattr(self, "processor") else None,
+                image_processor=self.image_processor if hasattr(self, "image_processor") else None,
+                model_type=model_type,
+            )
+        else:
+            PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device)
+
     @torch.inference_mode()
     def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         """Quantize all modules in the model using RTN (Round-To-Nearest) strategy.
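Two details of the extracted _immediate_pack helper are worth spelling out. First, target_backend strips an optional variant suffix from the format string before indexing PACKING_LAYER_WITH_FORMAT. Second, GGUF formats bypass that table entirely and go through pack_gguf_layer. A quick illustration of the suffix split (format strings are examples, not an exhaustive list of supported values):

    # Illustrative: how the packing backend key is derived from a format string.
    for fmt in ["auto_round:auto_gptq", "auto_awq", "gguf:q4_k_m"]:
        target_backend = fmt.split(":")[0] if ":" in fmt else fmt
        print(f"{fmt!r} -> {target_backend!r}")
    # 'auto_round:auto_gptq' -> 'auto_round'
    # 'auto_awq'             -> 'auto_awq'
    # 'gguf:q4_k_m'          -> 'gguf' (never looked up: the has_gguf branch wins first)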
@@ -1301,7 +1319,6 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         self.model.to(self.amp_dtype)

         all_to_quantized_module_names: list[str] = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
-
         if is_nv_fp(self.data_type):
             from auto_round.data_type.nvfp import calculate_gparam
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -1477,8 +1494,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
                     self._quantize_layer_via_rtn(m.tmp_name)
                     all_to_quantized_module_names.remove(m.tmp_name)
-
-            mv_module_from_gpu(block)
+            if not self.immediate_saving:
+                mv_module_from_gpu(block)
             pbar.update(1)

         pbar.close()
@@ -1556,7 +1573,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             )
             and self.inplace
         ):
-            self.is_packing_immediate = True
+            self.immediate_packing = True
+            if "gguf" not in formats[0] and self.low_cpu_mem_usage:
+                self.immediate_saving = True
+        if self.immediate_saving and "int" not in self.data_type:
+            logger.warning("immediate_saving is only supported for int quantization, set to False")
+            self.immediate_saving = False
         if self.iters == 0:
             return self._quantize_rtn()
@@ -1628,15 +1650,14 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
                 device=self.device,
                 pbar=pbar,
             )
-        if self.is_packing_immediate and len(self.formats) != 1:
+        if self.immediate_packing and len(self.formats) != 1:
             raise ValueError(
-                f"Expected exactly one packing format when 'is_packing_immediate' is True, "
+                f"Expected exactly one packing format when 'immediate_packing' is True, "
                 f"but got {len(self.formats)} formats."
             )
         pbar.set_description("Quantizing done")
         pbar.close()
-
-        self._quantize_layers(layer_names, all_inputs)  ##TODO pack layer immediately
+        self._quantize_layers(layer_names, all_inputs)

         if is_fp8_model(self.model):
             for n, m in self.model.named_modules():
@@ -1714,7 +1735,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         has_gguf = False
         if hasattr(self, "formats"):
             has_gguf = any("gguf" in format_ for format_ in self.formats)
-        if has_gguf and self.is_packing_immediate:
+        if has_gguf and self.immediate_packing:
             enable_quanted_input = False

         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input:
@@ -1727,8 +1748,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             accelerate.hooks.remove_hook_from_submodules(
                 self.model
             )  ##self.model.hf_device_map has not been changed
-
-        self.model = mv_module_from_gpu(self.model)
+        if not self.immediate_saving:
+            self.model = mv_module_from_gpu(self.model)
         clear_memory()
         quant_layer = self._quantize_layer
         for layer_name in layer_names:
@@ -1737,6 +1758,12 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             q_layer_input = q_layer_inputs.get(layer_name, None) if q_layer_inputs is not None else None
             q_layer_input = to_device(q_layer_input, self.cache_device)
             quant_layer(layer_name, layer_input, q_layer_input, device=self.device)
+            if self.immediate_packing:
+                self._immediate_pack(layer_name)
+
+            if self.immediate_saving:
+                m = get_module(self.model, layer_name)
+                immediate_saving(self, m, name=layer_name, last_group=True)
             del layer_input
             clear_memory(q_layer_input)

@@ -1937,7 +1964,6 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             layer_names = []
         if layer_names is None:
             layer_names = []
-
         if self.low_gpu_mem_usage or (
             len(block_names) == 1
             and len(layer_names) == 0
@@ -2732,6 +2758,7 @@ def _quantize_blocks(
         input_ids = to_device(input_ids, self.cache_device)
         input_others = to_device(input_others, self.cache_device)
         # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
+
         tmp_dtype = self.amp_dtype if self.amp else torch.float32
         input_ids = to_dtype(input_ids, tmp_dtype)

@@ -2808,38 +2835,20 @@ def _quantize_blocks(
             )
             if hasattr(model, "config"):
                 del m.config
-            if self.is_packing_immediate:
-                from auto_round.export import PACKING_LAYER_WITH_FORMAT
-
+            if self.immediate_packing:
                 for _, tmp_m in m.named_modules():
                     if not (hasattr(tmp_m, "bits") and check_to_quantized(tmp_m)):
                         continue
-                    target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
-                    has_gguf = any("gguf" in format_ for format_ in self.formats)
-                    if has_gguf:
-                        from auto_round.export.export_to_gguf.export import pack_gguf_layer
-
-                        output_dir = self._get_save_folder_name(self.formats[0])
-                        model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
-                        pack_gguf_layer(
-                            tmp_m.tmp_name,
-                            self.model,
-                            self.formats[0],
-                            output_dir,
-                            self.layer_config,
-                            self.tokenizer,
-                            processor=self.processor if hasattr(self, "processor") else None,
-                            image_processor=self.image_processor if hasattr(self, "image_processor") else None,
-                            model_type=model_type,
-                        )
-                    else:
-                        PACKING_LAYER_WITH_FORMAT[target_backend](
-                            tmp_m.tmp_name, self.model, self.formats[0], device=self.device
-                        )
+                    self._immediate_pack(tmp_m.tmp_name)
+
+            if self.immediate_saving:
+                last_group = (i + nblocks) >= len(block_names)
+                immediate_saving(self, m, last_group=last_group)
             if pbar is not None:
                 pbar.update(1)

-        self.model = mv_module_from_gpu(self.model)
+        if not self.immediate_saving:
+            self.model = mv_module_from_gpu(self.model)
         for n, m in self.model.named_modules():
             if hasattr(m, "name"):
                 delattr(m, "name")
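Finally, the last_group arithmetic in _quantize_blocks: assuming i steps through block_names in strides of nblocks (as the surrounding loop suggests), the flag first turns True on the final chunk, so only the last save call finalizes the output. A quick check with made-up sizes:

    block_names = [f"block.{j}" for j in range(5)]  # 5 blocks, hypothetical names
    nblocks = 2
    for i in range(0, len(block_names), nblocks):
        last_group = (i + nblocks) >= len(block_names)
        print(i, last_group)
    # 0 False
    # 2 False
    # 4 True   <- only the final chunk marks last_group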