diff --git a/auto_round/auto_scheme/utils.py b/auto_round/auto_scheme/utils.py
index cc96455d6..67b7fd428 100644
--- a/auto_round/auto_scheme/utils.py
+++ b/auto_round/auto_scheme/utils.py
@@ -27,6 +27,7 @@
     get_layer_features,
     get_module,
     is_hpex_available,
+    parse_all_available_device,
 )
 
 
@@ -204,92 +205,6 @@ def compute_layer_bits(
     return total_bits, avg_bits
 
 
-def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
-    """
-    Parse the device map and return a list of all available devices.
-
-    Supported input formats:
-    - None: Automatically detect all available devices
-    - int: A single device index (e.g., 0)
-    - str: Examples:
-        "cpu"
-        "cuda:0,cuda:1"
-        "0,1" (numeric device indices)
-    - dict: Extract all device values from the dictionary
-    - torch.device: e.g. torch.device("cuda:0")
-
-    Returns:
-        list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
-    """
-
-    # === Step 1. Detect available device types ===
-    device_types = []
-    if torch.cuda.is_available():
-        device_types.append("cuda")
-    if hasattr(torch, "xpu") and torch.xpu.is_available():
-        device_types.append("xpu")
-    if hasattr(torch, "hpu") and is_hpex_available():
-        device_types.append("hpu")
-
-    # Always include CPU as a fallback
-    if not device_types:
-        device_types = ["cpu"]
-
-    # === Step 2. Parse different input formats ===
-    if device_map is None:
-        # Automatically detect one available device
-        if "cuda" in device_types:
-            return ["cuda:0"]
-        elif "xpu" in device_types:
-            return ["xpu:0"]
-        elif "hpu" in device_types:
-            return ["hpu:0"]
-        else:
-            return ["cpu"]
-
-    if isinstance(device_map, torch.device):
-        # Handle torch.device objects
-        dev_type = device_map.type
-        index = device_map.index
-        if dev_type == "cpu":
-            return ["cpu"]
-        if index is None:
-            index = 0
-        return [f"{dev_type}:{index}"]
-
-    if isinstance(device_map, int):
-        # Integer input → use primary available device type
-        device_type = device_types[0]
-        return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
-
-    if isinstance(device_map, str):
-        # Remove whitespace
-        device_map = device_map.strip()
-        if device_map.lower() == "cpu":
-            return ["cpu"]
-
-        # Split by commas
-        parts = [x.strip() for x in device_map.split(",") if x.strip()]
-        parsed = []
-        for p in parts:
-            if p.isdigit():
-                # Numeric → assign to first available device type
-                device_type = device_types[0]
-                parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
-            else:
-                parsed.append(p)
-        return parsed
-
-    if isinstance(device_map, dict):
-        # Extract all devices recursively from dict values
-        devices = set()
-        for v in device_map.values():
-            devices.update(parse_all_available_device(v))
-        return sorted(devices)
-
-    raise TypeError(f"Unsupported device_map type: {type(device_map)}")
-
-
 # Important Notice This dispatch does not follow dict device_map, just extract all available devices and use them
 def dispatch_model_by_all_available_devices(
     model: torch.nn.Module, device_map: Union[str, int, dict, None]
diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py
index 0a682079a..e7ce77404 100644
--- a/auto_round/compressors/base.py
+++ b/auto_round/compressors/base.py
@@ -96,6 +96,7 @@
 from auto_round.utils.device import (
     clear_memory_if_reached_threshold,
     get_major_device,
+    parse_all_available_device,
     set_auto_device_map_for_block_with_tuning,
     set_non_auto_device_map,
 )
@@ -1954,17 +1955,26 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
                     self.model = dispatch_model(self.model, device_map=self.model.hf_device_map)
                 else:
                     # Change this if new device is supported
-                    if str(self.model.device) == "cpu" and (
-                        self.device.startswith("xpu") or self.device.startswith("cuda")
-                    ):
+                    if str(self.model.device) == "cpu" and not self.device.startswith("hpu"):
                         no_split_modules = getattr(self.model, "_no_split_modules", [])
+                        devices = parse_all_available_device(self.device_map)
                         max_memory = get_balanced_memory(
                             self.model,
                             max_memory=None,
                             no_split_module_classes=no_split_modules,
                         )
+                        # get_balanced_memory keys GPUs by integer index and the host by "cpu"
+                        new_max_memory = {}
+                        for device in devices:
+                            if ":" in device:
+                                key = int(device.split(":")[-1])
+                            elif device == "cpu":
+                                key = "cpu"
+                            else:
+                                raise ValueError(f"Unsupported device {device} in device_map: {self.device_map}")
+                            new_max_memory[key] = max_memory[key]
                         device_map = infer_auto_device_map(
-                            self.model, max_memory=max_memory, no_split_module_classes=no_split_modules
+                            self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
                         )
                         self.model = dispatch_model(self.model, device_map=device_map)
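Note: to make the max-memory filtering above easier to review, here is a minimal standalone sketch of the same key-normalization step. It assumes, per accelerate's get_balanced_memory convention, a dict keyed by integer GPU indices plus "cpu"; the memory values and device list below are made up:

    # Hypothetical inputs: what get_balanced_memory and parse_all_available_device might return.
    max_memory = {0: 20e9, 1: 20e9, "cpu": 64e9}
    devices = ["cuda:0", "cpu"]

    # Keep only the entries for devices the user actually requested.
    new_max_memory = {}
    for device in devices:
        if ":" in device:
            key = int(device.split(":")[-1])  # "cuda:0" -> 0
        elif device == "cpu":
            key = "cpu"
        else:
            raise ValueError(f"Unsupported device {device}")
        new_max_memory[key] = max_memory[key]

    assert new_max_memory == {0: 20e9, "cpu": 64e9}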
diff --git a/auto_round/utils/device.py b/auto_round/utils/device.py
index b0ecf9019..cfafab7d3 100644
--- a/auto_round/utils/device.py
+++ b/auto_round/utils/device.py
@@ -924,10 +924,10 @@ def estimate_tuning_block_mem(
         # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
         xpu_additional_memory = 12  # GB
         additional_memory += xpu_additional_memory
-    logger.warning_once(
-        "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
-        + "AR_LOG_LEVEL=debug and raise issue to us."
-    )
+    # logger.warning_once(
+    #     "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
+    #     + "AR_LOG_LEVEL=debug and raise issue to us."
+    # )
 
     return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory
 
@@ -1028,8 +1028,7 @@ def set_auto_device_map_for_block_with_tuning(
     mem_per_param = total_available_memory / total_params
 
     # Initialize device memory tracking
-    device_memory = {}
-    device_memory[device_0] = card_0_left_memory
+    device_memory = {device_0: card_0_left_memory}
     for i in range(1, len(gpu_devices)):
         device_idx = device_list[i] if device_list else i
         device_memory[gpu_devices[i]] = get_device_memory(device_idx)
@@ -1181,3 +1180,89 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
     groups = partition_dict_numbers(number_dict, 2)
     for i, group in enumerate(groups):
         print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}")
+
+
+def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
+    """
+    Parse the device map and return a list of all available devices.
+
+    Supported input formats:
+    - None: Automatically select the first available device
+    - int: A single device index (e.g., 0)
+    - str: Examples:
+        "cpu"
+        "cuda:0,cuda:1"
+        "0,1" (numeric device indices)
+    - dict: Extract all device values from the dictionary
+    - torch.device: e.g. torch.device("cuda:0")
+
+    Returns:
+        list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
+    """
+
+    # === Step 1. Detect available device types ===
+    device_types = []
+    if torch.cuda.is_available():
+        device_types.append("cuda")
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        device_types.append("xpu")
+    if hasattr(torch, "hpu") and is_hpex_available():
+        device_types.append("hpu")
+
+    # Always include CPU as a fallback
+    if not device_types:
+        device_types = ["cpu"]
+
+    # === Step 2. Parse different input formats ===
+    if device_map is None:
+        # Automatically detect one available device
+        if "cuda" in device_types:
+            return ["cuda:0"]
+        elif "xpu" in device_types:
+            return ["xpu:0"]
+        elif "hpu" in device_types:
+            return ["hpu:0"]
+        else:
+            return ["cpu"]
+
+    if isinstance(device_map, torch.device):
+        # Handle torch.device objects
+        dev_type = device_map.type
+        index = device_map.index
+        if dev_type == "cpu":
+            return ["cpu"]
+        if index is None:
+            index = 0
+        return [f"{dev_type}:{index}"]
+
+    if isinstance(device_map, int):
+        # Integer input → use primary available device type
+        device_type = device_types[0]
+        return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
+
+    if isinstance(device_map, str):
+        # Remove whitespace
+        device_map = device_map.strip()
+        if device_map.lower() == "cpu":
+            return ["cpu"]
+
+        # Split by commas
+        parts = [x.strip() for x in device_map.split(",") if x.strip()]
+        parsed = []
+        for p in parts:
+            if p.isdigit():
+                # Numeric → assign to first available device type
+                device_type = device_types[0]
+                parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
+            else:
+                parsed.append(p)
+        return parsed
+
+    if isinstance(device_map, dict):
+        # Extract all devices recursively from dict values
+        devices = set()
+        for v in device_map.values():
+            devices.update(parse_all_available_device(v))
+        return sorted(devices)
+
+    raise TypeError(f"Unsupported device_map type: {type(device_map)}")
diff --git a/auto_round_extension/triton/qlinear_tritonv2_zp.py b/auto_round_extension/triton/qlinear_tritonv2_zp.py
index 36ecd6601..58ff9de0a 100644
--- a/auto_round_extension/triton/qlinear_tritonv2_zp.py
+++ b/auto_round_extension/triton/qlinear_tritonv2_zp.py
@@ -15,10 +15,8 @@
 import math
 from logging import getLogger
 
-import numpy as np
 import torch
 import torch.nn as nn
-import transformers
 
 from auto_round_extension.triton.triton_utils_zp.mixin import TritonModuleMixin
 
diff --git a/auto_round_extension/triton/triton_utils/dequant.py b/auto_round_extension/triton/triton_utils/dequant.py
index 33eb8d8e7..404819159 100644
--- a/auto_round_extension/triton/triton_utils/dequant.py
+++ b/auto_round_extension/triton/triton_utils/dequant.py
@@ -38,7 +38,6 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd
 
 
 def make_dequant_configs(block_sizes, num_warps):
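Note: a short usage sketch of parse_all_available_device as relocated above, illustrative only. The expected outputs assume a machine where torch.cuda.is_available() is true and no XPU/HPU is present, since digit-only entries are mapped onto the first detected device type:

    import torch

    from auto_round.utils.device import parse_all_available_device

    print(parse_all_available_device(None))                  # ["cuda:0"]  (first available device)
    print(parse_all_available_device("cuda:0,cuda:1"))       # ["cuda:0", "cuda:1"]
    print(parse_all_available_device("0,1"))                 # ["cuda:0", "cuda:1"]
    print(parse_all_available_device(torch.device("cuda")))  # ["cuda:0"]  (missing index defaults to 0)
    print(parse_all_available_device({"lm_head": "cpu", "model.layers.0": 0}))  # ["cpu", "cuda:0"]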