diff --git a/auto_round/calib_dataset.py b/auto_round/calib_dataset.py
index a219aa787..96d80f7ea 100644
--- a/auto_round/calib_dataset.py
+++ b/auto_round/calib_dataset.py
@@ -632,15 +632,8 @@ def select_dataset(dataset, indices):
     return dataset
 
 
-def get_dataloader(
-    tokenizer,
-    seqlen,
-    dataset_name="NeelNanda/pile-10k",
-    seed=42,
-    bs=8,
-    nsamples=512,
-):
-    """Generate a DataLoader for calibration using specified parameters.
+def get_dataset(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, nsamples=512):
+    """Generate a dataset for calibration.
 
     Args:
         tokenizer (Tokenizer): The tokenizer to use for tokenization.
@@ -650,12 +643,11 @@ def get_dataloader(
             Defaults to "NeelNanda/pile-10k".
         split (str, optional): The data split to use. Defaults to None.
         seed (int, optional): The random seed for reproducibility. Defaults to 42.
-        bs (int, optional): The batch size. Defaults to 4.
         nsamples (int, optional): The total number of samples to include. Defaults to 512.
         apply_chat_template: Whether to apply chat template in tokenization.
 
     Returns:
-        DataLoader: The DataLoader for the calibrated dataset.
+        Dataset: The processed dataset ready for calibration.
     """
 
     dataset_names = dataset_name.split(",")
@@ -823,7 +815,29 @@ def concat_dataset_element(dataset):
     else:
         dataset_final = datasets[0]
 
-    # dataset_final = datasets[0]
+    if len(dataset_final) > nsamples:
+        dataset_final = select_dataset(dataset_final, range(nsamples))
+    return dataset_final
+
+
+def get_dataloader(tokenizer, seqlen, dataset_name="NeelNanda/pile-10k", seed=42, bs=8, nsamples=512):
+    """Generate a DataLoader for calibration using specified parameters.
+
+    Args:
+        tokenizer (Tokenizer): The tokenizer to use for tokenization.
+        seqlen (int): The exact sequence length. Samples shorter than seqlen will be dropped,
+            and samples longer than seqlen will be truncated.
+        dataset_name (str, optional): The name of the dataset or datasets separated by commas.
+            Defaults to "NeelNanda/pile-10k".
+        split (str, optional): The data split to use. Defaults to None.
+        seed (int, optional): The random seed for reproducibility. Defaults to 42.
+        bs (int, optional): The batch size. Defaults to 8.
+        nsamples (int, optional): The total number of samples to include. Defaults to 512.
+        apply_chat_template: Whether to apply chat template in tokenization.
+
+    Returns:
+        DataLoader: The DataLoader for the calibrated dataset.
+    """
 
     @torch.no_grad()
     def collate_batch(batch):
@@ -849,8 +863,6 @@ def collate_batch(batch):
         res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new}
         return res
 
-    if len(dataset_final) > nsamples:
-        dataset_final = select_dataset(dataset_final, range(nsamples))
-
+    dataset_final = get_dataset(tokenizer, seqlen, dataset_name, seed, nsamples)
     calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch)
     return calib_dataloader
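
A minimal usage sketch of the refactored calibration API above. Only get_dataset, get_dataloader, and their signatures are taken from this patch; the tokenizer checkpoint and the sample counts are arbitrary placeholders.

# Sketch (not part of the patch): exercising get_dataset / get_dataloader.
# "facebook/opt-125m" is an arbitrary placeholder checkpoint.
from transformers import AutoTokenizer

from auto_round.calib_dataset import get_dataloader, get_dataset

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

# New: fetch only the processed calibration dataset.
dataset = get_dataset(tokenizer, seqlen=512, nsamples=128)

# Unchanged public behavior: get_dataloader now builds on get_dataset internally.
dataloader = get_dataloader(tokenizer, seqlen=512, bs=8, nsamples=128)
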
+ """ @torch.no_grad() def collate_batch(batch): @@ -849,8 +863,6 @@ def collate_batch(batch): res = {"input_ids": input_ids_new, "attention_mask": attention_mask_new} return res - if len(dataset_final) > nsamples: - dataset_final = select_dataset(dataset_final, range(nsamples)) - + dataset_final = get_dataset(tokenizer, seqlen, dataset_name, seed, nsamples) calib_dataloader = DataLoader(dataset_final, batch_size=bs, shuffle=False, collate_fn=collate_batch) return calib_dataloader diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 8f398e7a1..b4ef88201 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -20,7 +20,7 @@ import traceback from collections import defaultdict from dataclasses import asdict, fields -from typing import Any, Callable, Union +from typing import Any, Callable, Optional, Union import accelerate import torch @@ -88,6 +88,7 @@ is_hpex_available, llm_load_model, mv_module_from_gpu, + normalize_input, set_amax_for_all_moe_layers, set_module, to_device, @@ -1528,6 +1529,21 @@ def _update_inputs(self, inputs: dict, q_inputs: dict) -> tuple[dict, torch.Tens q_inputs = q_inputs.pop(input_id_str[0], None) return inputs, q_inputs + def configure_layer_config(self, enable_gguf_official_mixed: None | bool = True): + self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=enable_gguf_official_mixed, + is_mllm=self.mllm, + ) + def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: """Quantize the model and return the quantized model along with layer configurations.The entry of AutoRound. Returns: @@ -1546,20 +1562,8 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: enable_gguf_official_mixed = True else: enable_gguf_official_mixed = False - self.layer_config, self.has_qlayer_outside_block, self.regex_config = set_layer_config( - self.model, - self.layer_config, - self.scheme, - self.scale_dtype, - self.supported_types, - self.inner_supported_types, - self.quant_block_list, - self.fp_layers, - self.quant_lm_head, - enable_gguf_official_mixed=enable_gguf_official_mixed, - is_mllm=self.mllm, - ) + self.configure_layer_config(enable_gguf_official_mixed=enable_gguf_official_mixed) if not hasattr(self, "formats"): logger.warning("this API is deprecated, please use `quantize_and_save` instead") else: @@ -2471,6 +2475,28 @@ def _get_current_num_elm( current_input_ids = [input_ids[i] for i in indices] return sum(id.numel() for id in current_input_ids) + def quantize_block( + self, + block: torch.nn.Module, + inputs: tuple[Union[list[torch.Tensor], dict, Any], Optional[dict]], + q_input: Union[torch.Tensor, dict, None] = None, + device: Union[str, torch.device] = "cpu", + auto_offload=True, + ): + """ + This function quantizes a specific decoded block of a model. + It is primarily used by LLM-Compressor. For more details, please refer to the following PR: + https://github.com/vllm-project/llm-compressor/pull/1994 + """ + + # TODO: release below assertion after supporting MLLM and diffusion model quantization with quantize_block + assert self.__class__.__name__ not in [ + "DiffusionCompressor", + "MLLMCompressor", + ], f"Currently, {self.__class__.__name__} does not support support quantize block with this function." 
diff --git a/auto_round/compressors/utils.py b/auto_round/compressors/utils.py
index c4cf83395..5aabe3969 100644
--- a/auto_round/compressors/utils.py
+++ b/auto_round/compressors/utils.py
@@ -111,7 +111,7 @@ def block_forward(
         alibi = input_others["alibi"]
         input_others["alibi"] = alibi.reshape(-1, alibi.shape[2], alibi.shape[3])
     if amp:
-        with autocast(device_type=device.split(":")[0], dtype=amp_dtype):  # pragma: no cover
+        with autocast(device_type=str(device).split(":")[0], dtype=amp_dtype):  # pragma: no cover
             output = block(input_ids, *input_tuple, **input_others)
     else:
         output = block(input_ids, *input_tuple, **input_others)
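
The two autocast call sites above (here and in base.py) now wrap the device in str() because a torch.device object has no .split() method, while a plain "cuda:0" string does. The snippet below only illustrates that str() normalizes both forms; it is not part of the patch.

# Illustration: str() makes .split(":") work for strings and torch.device alike.
import torch

for dev in ("cuda:0", torch.device("cuda:0"), torch.device("cpu")):
    device_type = str(dev).split(":")[0]  # "cuda" or "cpu" in every case
    print(dev, "->", device_type)
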
diff --git a/auto_round/utils/common.py b/auto_round/utils/common.py
index dea3f8c81..9d4e4c98a 100644
--- a/auto_round/utils/common.py
+++ b/auto_round/utils/common.py
@@ -297,3 +297,19 @@ def get_reciprocal(tensor):
     else:
         tensor = torch.where(torch.abs(tensor) < 1e-30, 0, tensor)
     return torch.where(tensor != 0, 1 / tensor, torch.zeros_like(tensor))
+
+
+def normalize_input(decoding_layer_inputs: list[tuple[Any]]) -> tuple[list[torch.Tensor], dict[str, Any]]:
+    """Normalize the captured decoder-layer inputs into input_ids and the remaining inputs."""
+    input_ids = []
+    input_others = {}
+    input_others["positional_inputs"] = []
+    for cur_inp in decoding_layer_inputs:
+        input_ids.append(cur_inp[0][0][0])
+        for key, val in cur_inp[0][1].items():
+            input_others[key] = val
+    # Force 'use_cache' to be False
+    if "use_cache" in input_others and input_others["use_cache"] is True:
+        logger.warning_once("Forcing 'use_cache' to be False during calibration.")
+        input_others["use_cache"] = False
+    return input_ids, input_others
diff --git a/auto_round/utils/model.py b/auto_round/utils/model.py
index 30fa2f23c..5f9e1941b 100644
--- a/auto_round/utils/model.py
+++ b/auto_round/utils/model.py
@@ -514,7 +514,6 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module], platform: str = No
         "pre_mm_projector_norm",
         "vision",
     ]
-    model_path = model_or_path if isinstance(model_or_path, str) else model_or_path.name_or_path
 
     if not os.path.isdir(model_path):
         model_path = download_or_get_path(model_path, platform=platform)
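
A hypothetical sketch of the capture format normalize_input appears to expect, inferred from the cur_inp[0][0][0] / cur_inp[0][1] indexing above; the tensor shapes and keyword names are placeholders, only the function itself comes from this patch.

# Sketch (not part of the patch): each captured element looks like ((positional_args, kwargs), ...).
import torch

from auto_round.utils.common import normalize_input

hidden = torch.randn(1, 16, 32)  # placeholder hidden states for one calibration sample
captured = [
    (((hidden,), {"attention_mask": torch.ones(1, 16), "use_cache": True}),),
]

input_ids, input_others = normalize_input(captured)
assert input_others["use_cache"] is False  # forced off for calibration
print(len(input_ids), sorted(input_others.keys()))
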