Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions metatomic-torch/include/metatomic/torch/misc.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ METATOMIC_TORCH_EXPORT std::string version();
/// model, the user-provided `desired_device` and what's available on the
/// current machine.
///
/// If `desired_device` is provided, it is checked against the `model_devices`
/// and the machine availability. If it contains a device index (e.g. "cuda:1"),
/// the base device type ("cuda") is used for these checks.
///
/// This function returns a c10::DeviceType (torch::DeviceType). It does NOT
/// decide a device index — callers that need a full torch::Device should
/// construct one from the returned DeviceType (and choose an index explicitly).
Expand Down
13 changes: 10 additions & 3 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,17 @@ c10::DeviceType pick_device(
}

// normalize desired and check
std::string wanted = lower(desired_device.value());
std::string wanted_str = lower(desired_device.value());
torch::DeviceType wanted_type;
try {
wanted_type = torch::Device(wanted_str).type();
} catch (const std::exception &) {
C10_THROW_ERROR(ValueError, "invalid device string: " + desired_device.value());
}

for (auto &a : available) {
if (a == wanted) {
return map_to_devicetype(a);
if (map_to_devicetype(a) == wanted_type) {
return wanted_type;
}
}

Expand Down
16 changes: 8 additions & 8 deletions metatomic-torch/src/register.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,20 @@ std::string pick_device_pywrapper(
) {
try {
torch::optional<std::string> desired = torch::nullopt;
if (requested_device.has_value()) {
if (requested_device.has_value() && !requested_device->empty()) {
desired = requested_device.value();
}

c10::DeviceType devtype = metatomic_torch::pick_device(model_devices, desired);

// Convert device type to string, stripping device index
torch::Device dev(devtype);
std::string s = dev.str();
auto pos = s.find(':');
if (pos != std::string::npos) {
return s.substr(0, pos);
if (desired.has_value()) {
// User requested a specific device (possibly with an index like "cuda:1").
// We return it normalized (e.g. "CUDA:1" -> "cuda:1").
return torch::Device(desired.value()).str();
} else {
// Automatic selection: return the device type name (e.g. "cuda").
return torch::Device(devtype).str();
}
return s;

} catch (const std::exception &e) {
throw std::runtime_error(std::string("pick_device failed: ") + e.what());
Expand Down
16 changes: 11 additions & 5 deletions python/metatomic_torch/metatomic/torch/documentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,13 +536,19 @@ def unit_conversion_factor(quantity: str, from_unit: str, to_unit: str):

def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str:
    """
    Select the best device according to the list of ``model_devices`` from a model, the
    user-provided ``desired_device`` and what's available on the current machine.

    If ``desired_device`` is provided, it is checked against the ``model_devices`` and
    the machine availability. If it contains a device index (e.g. ``"cuda:1"``), the
    base device type (``"cuda"``) is used for these checks, and the full string is
    returned if successful.

    If ``desired_device`` is ``None`` or an empty string, the first available device
    from ``model_devices`` will be picked and returned.

    :param model_devices: list of devices supported by a model in order of preference
    :param desired_device: user-provided desired device.
    :return: the selected device as a string; it includes the device index when
        ``desired_device`` contained one
    :raises RuntimeError: if ``desired_device`` is not a valid device string, or if it
        is not supported by the model
    """


Expand Down
14 changes: 14 additions & 0 deletions python/metatomic_torch/tests/test_pick_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,17 @@ def test_pick_device_error_on_unavailable_requested():
# If CUDA is available, requesting a non-present device should raise
with pytest.raises(RuntimeError):
mta.pick_device(["cpu"], "cuda")


def test_pick_device_indexed():
    """Indexed device strings such as "cpu:0" are accepted and preserved."""
    supported = ["cpu", "cuda"]

    # The index given in the request is expected back in the result.
    assert mta.pick_device(supported, "cpu:0") == "cpu:0"

    if torch.cuda.is_available():
        assert mta.pick_device(supported, "cuda:0") == "cuda:0"

    # A non-numeric index is rejected as an invalid device string.
    with pytest.raises(RuntimeError, match="invalid device string"):
        mta.pick_device(["cpu"], "cpu:invalid")
Loading