From 302fd3e6d2f4f58d6181fb62cbd8af979c7c29d1 Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Mon, 9 Feb 2026 21:39:33 +0100 Subject: [PATCH 1/2] Handle device selection with index --- metatomic-torch/src/misc.cpp | 13 ++++++++++--- metatomic-torch/src/register.cpp | 16 ++++++++-------- python/metatomic_torch/tests/test_pick_device.py | 14 ++++++++++++++ 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index b2991fb2..ab4a2498 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -138,10 +138,17 @@ c10::DeviceType pick_device( } // normalize desired and check - std::string wanted = lower(desired_device.value()); + std::string wanted_str = lower(desired_device.value()); + torch::DeviceType wanted_type; + try { + wanted_type = torch::Device(wanted_str).type(); + } catch (const std::exception &) { + C10_THROW_ERROR(ValueError, "invalid device string: " + desired_device.value()); + } + for (auto &a : available) { - if (a == wanted) { - return map_to_devicetype(a); + if (map_to_devicetype(a) == wanted_type) { + return wanted_type; } } diff --git a/metatomic-torch/src/register.cpp b/metatomic-torch/src/register.cpp index fd5a22e5..f965b896 100644 --- a/metatomic-torch/src/register.cpp +++ b/metatomic-torch/src/register.cpp @@ -13,20 +13,20 @@ std::string pick_device_pywrapper( ) { try { torch::optional desired = torch::nullopt; - if (requested_device.has_value()) { + if (requested_device.has_value() && !requested_device->empty()) { desired = requested_device.value(); } c10::DeviceType devtype = metatomic_torch::pick_device(model_devices, desired); - // Convert device type to string, stripping device index - torch::Device dev(devtype); - std::string s = dev.str(); - auto pos = s.find(':'); - if (pos != std::string::npos) { - return s.substr(0, pos); + if (desired.has_value()) { + // User requested a specific device (possibly with an index like "cuda:1"). + // We return it normalized (e.g. "CUDA:1" -> "cuda:1"). + return torch::Device(desired.value()).str(); + } else { + // Automatic selection: return the device type name (e.g. "cuda"). + return torch::Device(devtype).str(); } - return s; } catch (const std::exception &e) { throw std::runtime_error(std::string("pick_device failed: ") + e.what()); diff --git a/python/metatomic_torch/tests/test_pick_device.py b/python/metatomic_torch/tests/test_pick_device.py index fa57cc31..d4d3ca18 100644 --- a/python/metatomic_torch/tests/test_pick_device.py +++ b/python/metatomic_torch/tests/test_pick_device.py @@ -51,3 +51,17 @@ def test_pick_device_error_on_unavailable_requested(): # If CUDA is available, requesting a non-present device should raise with pytest.raises(RuntimeError): mta.pick_device(["cpu"], "cuda") + + +def test_pick_device_indexed(): + # Test that indexed device strings like "cpu:0" or "cuda:1" are accepted + # and preserved. + res = mta.pick_device(["cpu", "cuda"], "cpu:0") + assert res == "cpu:0" + + if torch.cuda.is_available(): + res = mta.pick_device(["cpu", "cuda"], "cuda:0") + assert res == "cuda:0" + + with pytest.raises(RuntimeError, match="invalid device string"): + mta.pick_device(["cpu"], "cpu:invalid") From c1f2de0c4151aa8663da49d86d3b8707e3f3280d Mon Sep 17 00:00:00 2001 From: Michele Ceriotti Date: Mon, 9 Feb 2026 21:43:28 +0100 Subject: [PATCH 2/2] Add documentation --- metatomic-torch/include/metatomic/torch/misc.hpp | 4 ++++ .../metatomic/torch/documentation.py | 16 +++++++++++----- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/metatomic-torch/include/metatomic/torch/misc.hpp b/metatomic-torch/include/metatomic/torch/misc.hpp index ec166a85..84fa4f21 100644 --- a/metatomic-torch/include/metatomic/torch/misc.hpp +++ b/metatomic-torch/include/metatomic/torch/misc.hpp @@ -23,6 +23,10 @@ METATOMIC_TORCH_EXPORT std::string version(); /// model, the user-provided `desired_device` and what's available on the /// current machine. /// +/// If `desired_device` is provided, it is checked against the `model_devices` +/// and the machine availability. If it contains a device index (e.g. "cuda:1"), +/// the base device type ("cuda") is used for these checks. +/// /// This function returns a c10::DeviceType (torch::DeviceType). It does NOT /// decide a device index — callers that need a full torch::Device should /// construct one from the returned DeviceType (and choose an index explicitly). diff --git a/python/metatomic_torch/metatomic/torch/documentation.py b/python/metatomic_torch/metatomic/torch/documentation.py index 83360c12..e0787ab2 100644 --- a/python/metatomic_torch/metatomic/torch/documentation.py +++ b/python/metatomic_torch/metatomic/torch/documentation.py @@ -536,13 +536,19 @@ def unit_conversion_factor(quantity: str, from_unit: str, to_unit: str): def pick_device(model_devices: List[str], desired_device: Optional[str]) -> str: """ - Select the best device according to the list of ``model_devices`` from a - model, the user-provided ``desired_device`` and what's available on the - current machine. + Select the best device according to the list of ``model_devices`` from a model, the + user-provided ``desired_device`` and what's available on the current machine. + + If ``desired_device`` is provided, it is checked against the ``model_devices`` and + the machine availability. If it contains a device index (e.g. ``"cuda:1"``), the + base device type (``"cuda"``) is used for these checks, and the full string is + returned if successful. + + If ``desired_device`` is ``None`` or an empty string, the first available device + from ``model_devices`` will be picked and returned. :param model_devices: list of devices supported by a model in order of preference - :param desired_device: user-provided desired device. If ``None`` or not available, - the first available device from ``model_devices`` will be picked. + :param desired_device: user-provided desired device. """