Skip to content

Commit ead6f29

Browse files
authored
fix multiple devices map issue in calibration (#1003)
1 parent f6745fd commit ead6f29

File tree

5 files changed

+105
-99
lines changed

5 files changed

+105
-99
lines changed

auto_round/auto_scheme/utils.py

Lines changed: 1 addition & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
get_layer_features,
2828
get_module,
2929
is_hpex_available,
30+
parse_all_available_device,
3031
)
3132

3233

@@ -204,92 +205,6 @@ def compute_layer_bits(
204205
return total_bits, avg_bits
205206

206207

207-
def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
208-
"""
209-
Parse the device map and return a list of all available devices.
210-
211-
Supported input formats:
212-
- None: Automatically detect all available devices
213-
- int: A single device index (e.g., 0)
214-
- str: Examples:
215-
"cpu"
216-
"cuda:0,cuda:1"
217-
"0,1" (numeric device indices)
218-
- dict: Extract all device values from the dictionary
219-
- torch.device: e.g. torch.device("cuda:0")
220-
221-
Returns:
222-
list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
223-
"""
224-
225-
# === Step 1. Detect available device types ===
226-
device_types = []
227-
if torch.cuda.is_available():
228-
device_types.append("cuda")
229-
if hasattr(torch, "xpu") and torch.xpu.is_available():
230-
device_types.append("xpu")
231-
if hasattr(torch, "hpu") and is_hpex_available():
232-
device_types.append("hpu")
233-
234-
# Always include CPU as a fallback
235-
if not device_types:
236-
device_types = ["cpu"]
237-
238-
# === Step 2. Parse different input formats ===
239-
if device_map is None:
240-
# Automatically detect one available device
241-
if "cuda" in device_types:
242-
return ["cuda:0"]
243-
elif "xpu" in device_types:
244-
return ["xpu:0"]
245-
elif "hpu" in device_types:
246-
return ["hpu:0"]
247-
else:
248-
return ["cpu"]
249-
250-
if isinstance(device_map, torch.device):
251-
# Handle torch.device objects
252-
dev_type = device_map.type
253-
index = device_map.index
254-
if dev_type == "cpu":
255-
return ["cpu"]
256-
if index is None:
257-
index = 0
258-
return [f"{dev_type}:{index}"]
259-
260-
if isinstance(device_map, int):
261-
# Integer input → use primary available device type
262-
device_type = device_types[0]
263-
return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
264-
265-
if isinstance(device_map, str):
266-
# Remove whitespace
267-
device_map = device_map.strip()
268-
if device_map.lower() == "cpu":
269-
return ["cpu"]
270-
271-
# Split by commas
272-
parts = [x.strip() for x in device_map.split(",") if x.strip()]
273-
parsed = []
274-
for p in parts:
275-
if p.isdigit():
276-
# Numeric → assign to first available device type
277-
device_type = device_types[0]
278-
parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
279-
else:
280-
parsed.append(p)
281-
return parsed
282-
283-
if isinstance(device_map, dict):
284-
# Extract all devices recursively from dict values
285-
devices = set()
286-
for v in device_map.values():
287-
devices.update(parse_all_available_device(v))
288-
return sorted(devices)
289-
290-
raise TypeError(f"Unsupported device_map type: {type(device_map)}")
291-
292-
293208
# Important Notice This dispatch does not follow dict device_map, just extract all available devices and use them
294209
def dispatch_model_by_all_available_devices(
295210
model: torch.nn.Module, device_map: Union[str, int, dict, None]

auto_round/compressors/base.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@
9797
from auto_round.utils.device import (
9898
clear_memory_if_reached_threshold,
9999
get_major_device,
100+
parse_all_available_device,
100101
set_auto_device_map_for_block_with_tuning,
101102
set_non_auto_device_map,
102103
)
@@ -1980,17 +1981,25 @@ def try_cache_inter_data_gpucpu(self, block_names, nsamples, layer_names=None, l
19801981
self.model = dispatch_model(self.model, device_map=self.model.hf_device_map)
19811982
else:
19821983
# Change this if new device is supported
1983-
if str(self.model.device) == "cpu" and (
1984-
self.device.startswith("xpu") or self.device.startswith("cuda")
1985-
):
1984+
if str(self.model.device) == "cpu" and (not self.device.startswith("hpu")):
19861985
no_split_modules = getattr(self.model, "_no_split_modules", [])
1986+
devices = parse_all_available_device(self.device_map)
19871987
max_memory = get_balanced_memory(
19881988
self.model,
19891989
max_memory=None,
19901990
no_split_module_classes=no_split_modules,
19911991
)
1992+
new_max_memory = {}
1993+
for device in devices:
1994+
if ":" in device:
1995+
device = int(device.split(":")[-1])
1996+
elif device == "cpu":
1997+
device = "cpu"
1998+
else:
1999+
raise ValueError(f"Unsupported device {device} in device_map: {self.device_map}")
2000+
new_max_memory[device] = max_memory[device]
19922001
device_map = infer_auto_device_map(
1993-
self.model, max_memory=max_memory, no_split_module_classes=no_split_modules
2002+
self.model, max_memory=new_max_memory, no_split_module_classes=no_split_modules
19942003
)
19952004

19962005
self.model = dispatch_model(self.model, device_map=device_map)

auto_round/utils/device.py

Lines changed: 91 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -924,10 +924,10 @@ def estimate_tuning_block_mem(
924924
# TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
925925
xpu_additional_memory = 12 # GB
926926
additional_memory += xpu_additional_memory
927-
logger.warning_once(
928-
"[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
929-
+ "AR_LOG_LEVEL=debug and raise issue to us."
930-
)
927+
# logger.warning_once(
928+
# "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
929+
# + "AR_LOG_LEVEL=debug and raise issue to us."
930+
# )
931931

932932
return layer_memory_dict, layer_activation_memory, block_input_output_memory, additional_memory
933933

@@ -1028,8 +1028,7 @@ def set_auto_device_map_for_block_with_tuning(
10281028
mem_per_param = total_available_memory / total_params
10291029

10301030
# Initialize device memory tracking
1031-
device_memory = {}
1032-
device_memory[device_0] = card_0_left_memory
1031+
device_memory = {device_0: card_0_left_memory}
10331032
for i in range(1, len(gpu_devices)):
10341033
device_idx = device_list[i] if device_list else i
10351034
device_memory[gpu_devices[i]] = get_device_memory(device_idx)
@@ -1181,3 +1180,89 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
11811180
groups = partition_dict_numbers(number_dict, 2)
11821181
for i, group in enumerate(groups):
11831182
print(f"Group {i + 1}: {group}, Sum: {sum(group.values())}")
1183+
1184+
1185+
def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
1186+
"""
1187+
Parse the device map and return a list of all available devices.
1188+
1189+
Supported input formats:
1190+
- None: Automatically detect all available devices
1191+
- int: A single device index (e.g., 0)
1192+
- str: Examples:
1193+
"cpu"
1194+
"cuda:0,cuda:1"
1195+
"0,1" (numeric device indices)
1196+
- dict: Extract all device values from the dictionary
1197+
- torch.device: e.g. torch.device("cuda:0")
1198+
1199+
Returns:
1200+
list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]
1201+
"""
1202+
1203+
# === Step 1. Detect available device types ===
1204+
device_types = []
1205+
if torch.cuda.is_available():
1206+
device_types.append("cuda")
1207+
if hasattr(torch, "xpu") and torch.xpu.is_available():
1208+
device_types.append("xpu")
1209+
if hasattr(torch, "hpu") and is_hpex_available():
1210+
device_types.append("hpu")
1211+
1212+
# Always include CPU as a fallback
1213+
if not device_types:
1214+
device_types = ["cpu"]
1215+
1216+
# === Step 2. Parse different input formats ===
1217+
if device_map is None:
1218+
# Automatically detect one available device
1219+
if "cuda" in device_types:
1220+
return ["cuda:0"]
1221+
elif "xpu" in device_types:
1222+
return ["xpu:0"]
1223+
elif "hpu" in device_types:
1224+
return ["hpu:0"]
1225+
else:
1226+
return ["cpu"]
1227+
1228+
if isinstance(device_map, torch.device):
1229+
# Handle torch.device objects
1230+
dev_type = device_map.type
1231+
index = device_map.index
1232+
if dev_type == "cpu":
1233+
return ["cpu"]
1234+
if index is None:
1235+
index = 0
1236+
return [f"{dev_type}:{index}"]
1237+
1238+
if isinstance(device_map, int):
1239+
# Integer input → use primary available device type
1240+
device_type = device_types[0]
1241+
return [f"{device_type}:{device_map}"] if device_type != "cpu" else ["cpu"]
1242+
1243+
if isinstance(device_map, str):
1244+
# Remove whitespace
1245+
device_map = device_map.strip()
1246+
if device_map.lower() == "cpu":
1247+
return ["cpu"]
1248+
1249+
# Split by commas
1250+
parts = [x.strip() for x in device_map.split(",") if x.strip()]
1251+
parsed = []
1252+
for p in parts:
1253+
if p.isdigit():
1254+
# Numeric → assign to first available device type
1255+
device_type = device_types[0]
1256+
parsed.append(f"{device_type}:{p}" if device_type != "cpu" else "cpu")
1257+
else:
1258+
parsed.append(p)
1259+
return parsed
1260+
1261+
if isinstance(device_map, dict):
1262+
# Extract all devices recursively from dict values
1263+
devices = set()
1264+
for v in device_map.values():
1265+
devices.update(parse_all_available_device(v))
1266+
return sorted(devices)
1267+
1268+
raise TypeError(f"Unsupported device_map type: {type(device_map)}")

auto_round_extension/triton/qlinear_tritonv2_zp.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,8 @@
1515
import math
1616
from logging import getLogger
1717

18-
import numpy as np
1918
import torch
2019
import torch.nn as nn
21-
import transformers
2220

2321
from auto_round_extension.triton.triton_utils_zp.mixin import TritonModuleMixin
2422

auto_round_extension/triton/triton_utils/dequant.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
import torch
3939
import triton
4040
import triton.language as tl
41-
from torch.cuda.amp import custom_bwd, custom_fwd
4241

4342

4443
def make_dequant_configs(block_sizes, num_warps):

0 commit comments

Comments (0)