@@ -924,10 +924,10 @@ def estimate_tuning_block_mem(
924924 # TODO: XPU takes more memory than expected. for llama 8B, it's about 12 GB
925925 xpu_additional_memory = 12 # GB
926926 additional_memory += xpu_additional_memory
927- logger .warning_once (
928- "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
929- + "AR_LOG_LEVEL=debug and raise issue to us."
930- )
927+ # logger.warning_once(
928+ # "[Memory Estimation]: If there is an abnormal memory issue, please collect log with "
929+ # + "AR_LOG_LEVEL=debug and raise issue to us."
930+ # )
931931
932932 return layer_memory_dict , layer_activation_memory , block_input_output_memory , additional_memory
933933
@@ -1028,8 +1028,7 @@ def set_auto_device_map_for_block_with_tuning(
10281028 mem_per_param = total_available_memory / total_params
10291029
10301030 # Initialize device memory tracking
1031- device_memory = {}
1032- device_memory [device_0 ] = card_0_left_memory
1031+ device_memory = {device_0 : card_0_left_memory }
10331032 for i in range (1 , len (gpu_devices )):
10341033 device_idx = device_list [i ] if device_list else i
10351034 device_memory [gpu_devices [i ]] = get_device_memory (device_idx )
@@ -1181,3 +1180,89 @@ def set_avg_auto_device_map(model: torch.nn.Module, device_map):
11811180 groups = partition_dict_numbers (number_dict , 2 )
11821181 for i , group in enumerate (groups ):
11831182 print (f"Group { i + 1 } : { group } , Sum: { sum (group .values ())} " )
1183+
1184+
def parse_all_available_device(device_map: Union[str, torch.device, int, dict, None] = None) -> list:
    """
    Normalize a device specification into a list of device-name strings.

    Supported input formats:
        - None: Pick a single default device, "<type>:0" for the first
          detected accelerator type (cuda, then xpu, then hpu), or "cpu"
          when no accelerator is detected
        - int: A single device index (e.g., 0), bound to the first detected
          accelerator type; "cpu" when no accelerator is detected
        - str: Examples:
            "cpu"
            "cuda:0,cuda:1"
            "0,1"  (numeric device indices)
          Non-numeric entries are passed through unchanged. Repeated entries
          are collapsed, keeping first-seen order. An empty string yields [].
        - dict: Every value is parsed recursively; the union of the results
          is returned, de-duplicated and sorted as strings
        - torch.device: e.g. torch.device("cuda:0"); a missing index is
          treated as 0

    Returns:
        list[str]: Normalized device names, e.g., ["cuda:0", "cuda:1"] or ["cpu"]

    Raises:
        TypeError: If device_map is none of the supported types.
    """

    # === Step 1. Detect available device types (order defines priority) ===
    device_types = []
    if torch.cuda.is_available():
        device_types.append("cuda")
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        device_types.append("xpu")
    if hasattr(torch, "hpu") and is_hpex_available():
        device_types.append("hpu")

    # CPU is used only when no accelerator was detected
    if not device_types:
        device_types = ["cpu"]
    primary_type = device_types[0]

    def _name_for_index(index) -> str:
        # A bare index is meaningless on CPU, so it collapses to plain "cpu".
        return "cpu" if primary_type == "cpu" else f"{primary_type}:{index}"

    # === Step 2. Parse different input formats ===
    if device_map is None:
        # Default to the first device of the highest-priority type
        return [_name_for_index(0)]

    if isinstance(device_map, torch.device):
        if device_map.type == "cpu":
            return ["cpu"]
        index = 0 if device_map.index is None else device_map.index
        return [f"{device_map.type}:{index}"]

    if isinstance(device_map, int):
        return [_name_for_index(device_map)]

    if isinstance(device_map, str):
        device_map = device_map.strip()
        if device_map.lower() == "cpu":
            return ["cpu"]

        # Split on commas, dropping empty fragments such as a trailing ","
        parts = [piece.strip() for piece in device_map.split(",") if piece.strip()]
        names = [_name_for_index(part) if part.isdigit() else part for part in parts]
        # dict.fromkeys drops repeats while keeping first-seen order, so inputs
        # like "cuda:0,cuda:0" (or "0,1" on a CPU-only host) yield one entry each.
        return list(dict.fromkeys(names))

    if isinstance(device_map, dict):
        # Collect the devices named by every value of the mapping
        devices = set()
        for value in device_map.values():
            devices.update(parse_all_available_device(value))
        return sorted(devices)

    raise TypeError(f"Unsupported device_map type: {type(device_map)}")
0 commit comments