87 changes: 1 addition & 86 deletions auto_round/auto_scheme/utils.py
@@ -27,6 +27,7 @@
     get_layer_features,
     get_module,
     is_hpex_available,
+    parse_all_available_device,
 )


@@ -204,92 +205,6 @@ def compute_layer_bits(
     return total_bits, avg_bits


-def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
-    """
-    Parse the device map and return a list of all available devices.
-
-    Supported input formats:
-        - None: Automatically select one available device
-        - int: A single device index (e.g., 0)
-        - str: Examples:
-            "cpu"
-            "cuda:0,cuda:1"
-            "0,1" (numeric device indices)
-        - dict: Extract all device values from the dictionary
-        - torch.device: e.g. torch.device("cuda:0")
-
-    Returns:
-        list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
-    """
-
-    # === Step 1. Detect available device types ===
-    device_types = []
-    if torch.cuda.is_available():
-        device_types.append("cuda")
-    if hasattr(torch, "xpu") and torch.xpu.is_available():
-        device_types.append("xpu")
-    if hasattr(torch, "hpu") and is_hpex_available():
-        device_types.append("hpu")
-
-    # Always include CPU as a fallback
-    if not device_types:
-        device_types = ["cpu"]
-
-    # === Step 2. Parse different input formats ===
-    if device_map is None:
-        # Automatically detect one available device
-        if "cuda" in device_types:
-            return ["cuda:0"]
-        elif "xpu" in device_types:
-            return ["xpu:0"]
-        elif "hpu" in device_types:
-            return ["hpu:0"]
-        else:
-            return ["cpu"]
-
-    if isinstance(device_map, torch.device):
-        # Handle torch.device objects
-        dev_type = device_map.type
-        index = device_map.index
-        if dev_type == "cpu":
-            return ["cpu"]
-        if index is None:
-            index = 0
-        return [f"{dev_type}:{index}"]
-
-    if isinstance(device_map, int):
-        # Integer input → use primary available device type
-        device_type = device_types[0]
-        return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
-
-    if isinstance(device_map, str):
-        # Remove whitespace
-        device_map = device_map.strip()
-        if device_map.lower() == "cpu":
-            return ["cpu"]
-
-        # Split by commas
-        parts = [x.strip() for x in device_map.split(",") if x.strip()]
-        parsed = []
-        for p in parts:
-            if p.isdigit():
-                # Numeric → assign to first available device type
-                device_type = device_types[0]
-                parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
-            else:
-                parsed.append(p)
-        return parsed
-
-    if isinstance(device_map, dict):
-        # Extract all devices recursively from dict values
-        devices = set()
-        for v in device_map.values():
-            devices.update(parse_all_available_device(v))
-        return sorted(devices)
-
-    raise TypeError(f"Unsupported device_map type: {type(device_map)}")
-
-
 # Important notice: this dispatch does not follow a dict device_map; it just extracts all available devices and uses them
 def dispatch_model_by_all_available_devices(
     model: torch.nn.Module, device_map: Union[str, int, dict, None]
17 changes: 13 additions & 4 deletions auto_round/compressors/base.py
@@ -96,6 +96,7 @@
 from auto_round.utils.device import (
     clear_memory_if_reached_threshold,
     get_major_device,
+    parse_all_available_device,
     set_auto_device_map_for_block_with_tuning,
     set_non_auto_device_map,
 )
@@ -1954,17 +1955,25 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
             self.model = dispatch_model(self.model, device_map=self.model.hf_device_map)
         else:
             # Change this if new device is supported
-            if str(self.model.device) == "cpu" and (
-                self.device.startswith("xpu") or self.device.startswith("cuda")
-            ):
+            if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
                 no_split_modules = getattr(self.model, "_no_split_modules", [])
+                devices = parse_all_available_device(self.device_map)
                 max_memory = get_balanced_memory(
                     self.model,
                     max_memory=None,
                     no_split_module_classes=no_split_modules,
                 )
+                new_max_memory = {}
+                for device in devices:
+                    if ":" in device:
+                        device = int(device.split(":")[-1])
+                    elif device == "cpu":
+                        device = "cpu"
+                    else:
+                        raise ValueError(f"Unsupported device {device} in device_map: {self.device_map}")
+                    new_max_memory[device] = max_memory[device]
                 device_map = infer_auto_device_map(
-                    self.model, max_memory=max_memory, no_split_module_classes=no_split_modules
+                    self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
                 )

                 self.model = dispatch_model(self.model, device_map=device_map)
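The added loop narrows the balanced-memory dict to only the devices the user requested, converting the normalized device strings into the key convention accelerate's max_memory dicts use (an integer index for accelerators, the string "cpu" for host RAM). A minimal standalone sketch of that normalization; the byte values and device list below are invented for illustration:

    # Hypothetical inputs, not the PR's code: shows only the key-normalization step.
    max_memory = {0: 20 * 2**30, 1: 20 * 2**30, "cpu": 64 * 2**30}  # e.g. what get_balanced_memory might return
    devices = ["cuda:0", "cpu"]  # e.g. parse_all_available_device("cuda:0,cpu")

    new_max_memory = {}
    for device in devices:
        # "cuda:0" -> 0, "cpu" -> "cpu"; budgets for unrequested devices are dropped
        key = int(device.split(":")[-1]) if ":" in device else "cpu"
        new_max_memory[key] = max_memory[key]

    assert new_max_memory == {0: 20 * 2**30, "cpu": 64 * 2**30}  # cuda:1's budget was filtered out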
97 changes: 91 additions & 6 deletions auto_round/utils/device.py
@@ -924,10 +924,10 @@ def estimate_tuning_block_mem(
         # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
         xpu_additional_memory = 12  # GB
         additional_memory += xpu_additional_memory
-    logger.warning_once(
-        "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
-        + "AR_LOG_LEVEL=debug and raise issue to us."
-    )
+    # logger.warning_once(
+    #     "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
+    #     + "AR_LOG_LEVEL=debug and raise issue to us."
+    # )

     return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory

@@ -1028,8 +1028,7 @@ def set_auto_device_map_for_block_with_tuning(
     mem_per_param = total_available_memory / total_params

     # Initialize device memory tracking
-    device_memory = {}
-    device_memory[device_0] = card_0_left_memory
+    device_memory = {device_0: card_0_left_memory}
     for i in range(1, len(gpu_devices)):
         device_idx = device_list[i] if device_list else i
         device_memory[gpu_devices[i]] = get_device_memory(device_idx)
@@ -1181,3 +1180,89 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
     groups = partition_dict_numbers(number_dict, 2)
     for i, group in enumerate(groups):
         print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}")
+
+
+def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
+    """
+    Parse the device map and return a list of all available devices.
+
+    Supported input formats:
+        - None: Automatically select one available device
+        - int: A single device index (e.g., 0)
+        - str: Examples:
+            "cpu"
+            "cuda:0,cuda:1"
+            "0,1" (numeric device indices)
+        - dict: Extract all device values from the dictionary
+        - torch.device: e.g. torch.device("cuda:0")
+
+    Returns:
+        list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
+    """
+
+    # === Step 1. Detect available device types ===
+    device_types = []
+    if torch.cuda.is_available():
+        device_types.append("cuda")
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        device_types.append("xpu")
+    if hasattr(torch, "hpu") and is_hpex_available():
+        device_types.append("hpu")
+
+    # Always include CPU as a fallback
+    if not device_types:
+        device_types = ["cpu"]
+
+    # === Step 2. Parse different input formats ===
+    if device_map is None:
+        # Automatically detect one available device
+        if "cuda" in device_types:
+            return ["cuda:0"]
+        elif "xpu" in device_types:
+            return ["xpu:0"]
+        elif "hpu" in device_types:
+            return ["hpu:0"]
+        else:
+            return ["cpu"]
+
+    if isinstance(device_map, torch.device):
+        # Handle torch.device objects
+        dev_type = device_map.type
+        index = device_map.index
+        if dev_type == "cpu":
+            return ["cpu"]
+        if index is None:
+            index = 0
+        return [f"{dev_type}:{index}"]
+
+    if isinstance(device_map, int):
+        # Integer input → use primary available device type
+        device_type = device_types[0]
+        return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
+
+    if isinstance(device_map, str):
+        # Remove whitespace
+        device_map = device_map.strip()
+        if device_map.lower() == "cpu":
+            return ["cpu"]
+
+        # Split by commas
+        parts = [x.strip() for x in device_map.split(",") if x.strip()]
+        parsed = []
+        for p in parts:
+            if p.isdigit():
+                # Numeric → assign to first available device type
+                device_type = device_types[0]
+                parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
+            else:
+                parsed.append(p)
+        return parsed
+
+    if isinstance(device_map, dict):
+        # Extract all devices recursively from dict values
+        devices = set()
+        for v in device_map.values():
+            devices.update(parse_all_available_device(v))
+        return sorted(devices)
+
+    raise TypeError(f"Unsupported device_map type: {type(device_map)}")
2 changes: 0 additions & 2 deletions auto_round_extension/triton/qlinear_tritonv2_zp.py
@@ -15,10 +15,8 @@
 import math
 from logging import getLogger

-import numpy as np
 import torch
 import torch.nn as nn
-import transformers

 from auto_round_extension.triton.triton_utils_zp.mixin import TritonModuleMixin
1 change: 0 additions & 1 deletion auto_round_extension/triton/triton_utils/dequant.py
@@ -38,7 +38,6 @@
 import torch
 import triton
 import triton.language as tl
-from torch.cuda.amp import custom_bwd, custom_fwd


 def make_dequant_configs(block_sizes, num_warps):