Commit daeb3bb

Support for immediate saving to reduce ram usage (#965)
1 parent 758b239 commit daeb3bb

11 files changed: +441 -133 lines changed

auto_round/autoround.py

Lines changed: 3 additions & 0 deletions
@@ -85,6 +85,7 @@ def __new__(
         enable_adam: bool = False,
         # for MLLM and Diffusion
         extra_config: ExtraConfig = None,
+        low_cpu_mem_usage: bool = False,
         **kwargs,
     ) -> BaseCompressor:
         """Initialize AutoRound with quantization and tuning configuration.
@@ -105,6 +106,7 @@ def __new__(
             lr (float, optional): Learning rate; if None, set to 1.0 / iters except when iters==0.
             minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
             low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
+            low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
             iters (int, optional): Optimization iterations. Defaults to 200.
             seqlen (int, optional): Calibration sequence length. Defaults to 2048.
             nsamples (int, optional): Number of calibration samples. Defaults to 128.
@@ -186,6 +188,7 @@ def __new__(
             device_map=device_map,
             enable_torch_compile=enable_torch_compile,
             seed=seed,
+            low_cpu_mem_usage=low_cpu_mem_usage,
             **kwargs,
         )
         return ar
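For reference, a minimal usage sketch of the new flag. Only `low_cpu_mem_usage` comes from this diff; the model id and the `quantize_and_save()` call are assumptions about the surrounding AutoRound API, not part of this commit.

```python
from auto_round import AutoRound

# Hypothetical model id; keyword arguments other than low_cpu_mem_usage are illustrative only.
ar = AutoRound("facebook/opt-125m", low_cpu_mem_usage=True)  # new flag, defaults to False
ar.quantize_and_save("./opt-125m-w4", format="auto_round")   # assumed export entry point
```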

auto_round/compressors/base.py

Lines changed: 75 additions & 66 deletions
@@ -39,6 +39,7 @@
     collect_best_params,
     get_shared_keys,
     gguf_args_check,
+    immediate_saving,
     infer_bits_by_data_type,
     init_cache,
     is_mx_fp,
@@ -148,6 +149,7 @@ def __init__(
         enable_alg_ext: bool = False,
         disable_opt_rtn: bool = False,
         seed: int = 42,
+        low_cpu_mem_usage: bool = False,
         **kwargs,
     ):
         """Initialize AutoRound with quantization and tuning configuration.
@@ -168,6 +170,7 @@ def __init__(
             lr (float, optional): Learning rate; if None, set to 1.0 / iters except when iters==0.
             minmax_lr (float, optional): Learning rate for min-max tuning; defaults to `lr`.
             low_gpu_mem_usage (bool, optional): Lower GPU memory mode. Defaults to False.
+            low_cpu_mem_usage (bool, optional): Lower CPU memory mode. Defaults to False.
             iters (int, optional): Optimization iterations. Defaults to 200.
             seqlen (int, optional): Calibration sequence length. Defaults to 2048.
             nsamples (int, optional): Number of calibration samples. Defaults to 128.
@@ -244,6 +247,7 @@ def __init__(
         self.supported_types = SUPPORTED_LAYER_TYPES
         self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
         self.scale_dtype = convert_dtype_str2torch(scale_dtype)
+        self.low_cpu_mem_usage = low_cpu_mem_usage

         if kwargs:
             logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.")
@@ -336,7 +340,10 @@ def __init__(
         self.lr_scheduler = lr_scheduler
         self.optimizer = self._get_optimizer(None)
         self.disable_opt_rtn = disable_opt_rtn
-        self.is_packing_immediate = False  # whether to pack the layer immediately after tuning
+
+        # Whether to pack the layer immediately after tuning
+        self.immediate_packing = False
+        self.immediate_saving = False

         # KV cache, this one does not affect tuning but will collect some infos during tuning
         self.static_kv_dtype = static_kv_dtype
@@ -1205,7 +1212,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
         `rtn_*` version if supported, then wraps and unwraps the module to apply
         quantization. If GPU memory is insufficient, it falls back to CPU.

-        If packing is enabled (`is_packing_immediate`), the function will also export
+        If packing is enabled (`immediate_packing`), the function will also export
         the quantized layer to the appropriate backend format.

         Args:
@@ -1222,7 +1229,7 @@ def _quantize_layer_via_rtn(self, name: str) -> None:

         # Step 1: Try quantization on GPU first, fall back to CPU if OOM
         # if only export gguf, using gguf-packing instead of rtn
-        if self.is_packing_immediate and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
+        if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn:
             m.scale = None
             m.zp = None
         else:
@@ -1259,34 +1266,45 @@ def _quantize_layer_via_rtn(self, name: str) -> None:
                 raise

         # Step 2: Optional immediate packing/export
-        if self.is_packing_immediate:
-            from auto_round.export import PACKING_LAYER_WITH_FORMAT
-
-            if check_to_quantized(m):
-                target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
-                has_gguf = any("gguf" in fmt for fmt in self.formats)
-
-                if has_gguf:
-                    from auto_round.export.export_to_gguf.export import pack_gguf_layer
-
-                    output_dir = self._get_save_folder_name(self.formats[0])
-                    model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
-                    pack_gguf_layer(
-                        name,
-                        self.model,
-                        self.formats[0],
-                        output_dir,
-                        self.layer_config,
-                        self.tokenizer,
-                        processor=self.processor if hasattr(self, "processor") else None,
-                        image_processor=self.image_processor if hasattr(self, "image_processor") else None,
-                        model_type=model_type,
-                    )
-                else:
-                    PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device)
+        if self.immediate_packing:
+            self._immediate_pack(name)
         else:
             set_module(self.model, name, m)

+        if self.immediate_saving:
+            all_to_quantized_module_names = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
+            last_module = (len(all_to_quantized_module_names) == 0) or (name == all_to_quantized_module_names[-1])
+            m = get_module(self.model, name)
+            immediate_saving(self, m, name, last_module)
+
+    def _immediate_pack(self, name: str):
+        m = get_module(self.model, name)
+        if not check_to_quantized(m):
+            return
+        from auto_round.export import PACKING_LAYER_WITH_FORMAT
+
+        target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
+        has_gguf = any("gguf" in fmt for fmt in self.formats)
+
+        if has_gguf:
+            from auto_round.export.export_to_gguf.export import pack_gguf_layer
+
+            output_dir = self._get_save_folder_name(self.formats[0])
+            model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
+            pack_gguf_layer(
+                name,
+                self.model,
+                self.formats[0],
+                output_dir,
+                self.layer_config,
+                self.tokenizer,
+                processor=self.processor if hasattr(self, "processor") else None,
+                image_processor=self.image_processor if hasattr(self, "image_processor") else None,
+                model_type=model_type,
+            )
+        else:
+            PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device)
+
     @torch.inference_mode()
     def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
         """Quantize all modules in the model using RTN (Round-To-Nearest) strategy.
@@ -1301,7 +1319,6 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             self.model.to(self.amp_dtype)

         all_to_quantized_module_names: list[str] = [n for n, m in self.model.named_modules() if check_to_quantized(m)]
-
         if is_nv_fp(self.data_type):
             from auto_round.data_type.nvfp import calculate_gparam
             from auto_round.data_type.utils import update_fused_layer_global_scales
@@ -1477,8 +1494,8 @@ def _quantize_via_rtn_blockwise(self, all_to_quantized_module_names: list[str])
                 if hasattr(m, "tmp_name") and m.tmp_name in all_to_quantized_module_names:
                     self._quantize_layer_via_rtn(m.tmp_name)
                     all_to_quantized_module_names.remove(m.tmp_name)
-
-            mv_module_from_gpu(block)
+            if not self.immediate_saving:
+                mv_module_from_gpu(block)
             pbar.update(1)

         pbar.close()
@@ -1556,7 +1573,12 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
             )
             and self.inplace
         ):
-            self.is_packing_immediate = True
+            self.immediate_packing = True
+            if "gguf" not in formats[0] and self.low_cpu_mem_usage:
+                self.immediate_saving = True
+        if self.immediate_saving and "int" not in self.data_type:
+            logger.warning("immediate_saving is only supported for int quantization, set to False")
+            self.immediate_saving = False
         if self.iters == 0:
             return self._quantize_rtn()

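In short, immediate saving is only switched on for a single in-place, non-GGUF target format when `low_cpu_mem_usage` is requested and the data type is an integer one. A condensed restatement of that gate, using plain variables instead of `self.*` and simplifying the immediate-packing precondition to "exactly one format":

```python
# Condensed, simplified restatement of the gating logic in the hunk above.
def should_save_immediately(formats: list[str], low_cpu_mem_usage: bool, data_type: str) -> bool:
    if len(formats) != 1:                                # immediate packing needs a single format
        return False
    if "gguf" in formats[0] or not low_cpu_mem_usage:    # GGUF keeps its own packing path
        return False
    return "int" in data_type                            # only int quantization is supported


print(should_save_immediately(["auto_round"], True, "int"))   # True
print(should_save_immediately(["gguf:q4_k_m"], True, "int"))  # False
```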
@@ -1628,15 +1650,14 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]:
                 device=self.device,
                 pbar=pbar,
             )
-        if self.is_packing_immediate and len(self.formats) != 1:
+        if self.immediate_packing and len(self.formats) != 1:
             raise ValueError(
-                f"Expected exactly one packing format when 'is_packing_immediate' is True, "
+                f"Expected exactly one packing format when 'immediate_packing' is True, "
                 f"but got {len(self.formats)} formats."
             )
         pbar.set_description("Quantizing done")
         pbar.close()
-
-        self._quantize_layers(layer_names, all_inputs)  ##TODO pack layer immediately
+        self._quantize_layers(layer_names, all_inputs)

         if is_fp8_model(self.model):
             for n, m in self.model.named_modules():
@@ -1714,7 +1735,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
         has_gguf = False
         if hasattr(self, "formats"):
             has_gguf = any("gguf" in format_ for format_ in self.formats)
-        if has_gguf and self.is_packing_immediate:
+        if has_gguf and self.immediate_packing:
             enable_quanted_input = False

         if hasattr(self.model, "hf_device_map") and len(self.model.hf_device_map) > 1 and enable_quanted_input:
@@ -1727,8 +1748,8 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             accelerate.hooks.remove_hook_from_submodules(
                 self.model
             )  ##self.model.hf_device_map has not been changed
-
-        self.model = mv_module_from_gpu(self.model)
+        if not self.immediate_saving:
+            self.model = mv_module_from_gpu(self.model)
         clear_memory()
         quant_layer = self._quantize_layer
         for layer_name in layer_names:
@@ -1737,6 +1758,12 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None:
             q_layer_input = q_layer_inputs.get(layer_name, None) if q_layer_inputs is not None else None
             q_layer_input = to_device(q_layer_input, self.cache_device)
             quant_layer(layer_name, layer_input, q_layer_input, device=self.device)
+            if self.immediate_packing:
+                self._immediate_pack(layer_name)
+
+            if self.immediate_saving:
+                m = get_module(self.model, layer_name)
+                immediate_saving(self, m, name=layer_name, last_group=True)
             del layer_input
             clear_memory(q_layer_input)

@@ -1937,7 +1964,6 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             layer_names = []
         if layer_names is None:
             layer_names = []
-
         if self.low_gpu_mem_usage or (
             len(block_names) == 1
             and len(layer_names) == 0
@@ -2732,6 +2758,7 @@ def _quantize_blocks(
         input_ids = to_device(input_ids, self.cache_device)
         input_others = to_device(input_others, self.cache_device)
         # As in calibration phase, we may use bf16 for calibration due to low_gpu_memory usage
+
         tmp_dtype = self.amp_dtype if self.amp else torch.float32
         input_ids = to_dtype(input_ids, tmp_dtype)

@@ -2808,38 +2835,20 @@ def _quantize_blocks(
             )
             if hasattr(model, "config"):
                 del m.config
-            if self.is_packing_immediate:
-                from auto_round.export import PACKING_LAYER_WITH_FORMAT
-
+            if self.immediate_packing:
                 for _, tmp_m in m.named_modules():
                     if not (hasattr(tmp_m, "bits") and check_to_quantized(tmp_m)):
                         continue
-                    target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0]
-                    has_gguf = any("gguf" in format_ for format_ in self.formats)
-                    if has_gguf:
-                        from auto_round.export.export_to_gguf.export import pack_gguf_layer
-
-                        output_dir = self._get_save_folder_name(self.formats[0])
-                        model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT
-                        pack_gguf_layer(
-                            tmp_m.tmp_name,
-                            self.model,
-                            self.formats[0],
-                            output_dir,
-                            self.layer_config,
-                            self.tokenizer,
-                            processor=self.processor if hasattr(self, "processor") else None,
-                            image_processor=self.image_processor if hasattr(self, "image_processor") else None,
-                            model_type=model_type,
-                        )
-                    else:
-                        PACKING_LAYER_WITH_FORMAT[target_backend](
-                            tmp_m.tmp_name, self.model, self.formats[0], device=self.device
-                        )
+                    self._immediate_pack(tmp_m.tmp_name)
+
+            if self.immediate_saving:
+                last_group = (i + nblocks) >= len(block_names)
+                immediate_saving(self, m, last_group=last_group)
             if pbar is not None:
                 pbar.update(1)

-        self.model = mv_module_from_gpu(self.model)
+        if not self.immediate_saving:
+            self.model = mv_module_from_gpu(self.model)
         for n, m in self.model.named_modules():
             if hasattr(m, "name"):
                 delattr(m, "name")
