Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions rapids_cli/doctor/checks/nvlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,43 @@
import pynvml


def check_nvlink_status(verbose=True):
"""Check the system for NVLink with 2 or more GPUs."""
def check_nvlink_status(verbose=True, **kwargs):
"""Check NVLink status across all GPUs."""
try:
pynvml.nvmlInit()
except pynvml.NVMLError as e:
raise ValueError("GPU not found. Please ensure GPUs are installed.") from e

device_count = pynvml.nvmlDeviceGetCount()

# NVLink requires at least 2 GPUs to be meaningful. A single GPU has nothing
# to link to, so there is nothing to check.
if device_count < 2:
return False

for i in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(i)
for nvlink_id in range(pynvml.NVML_NVLINK_MAX_LINKS):
# Note: this check assumes a homogeneous GPU environment (all GPUs of the same
# model). Mixed configurations — e.g. some NVLink-capable GPUs alongside some
# that are not — are not handled and may produce misleading results.

failed_links: list[tuple[int, int]] = []

for gpu_idx in range(device_count):
handle = pynvml.nvmlDeviceGetHandleByIndex(gpu_idx)
for link_id in range(pynvml.NVML_NVLINK_MAX_LINKS):
try:
pynvml.nvmlDeviceGetNvLinkState(handle, 0)
return True
except pynvml.NVMLError as e:
raise ValueError(f"NVLink {nvlink_id} Status Check Failed") from e
# nvmlDeviceGetNvLinkState(device, link) returns NVML_FEATURE_ENABLED
# if the link is active, or NVML_FEATURE_DISABLED if it is not.
state = pynvml.nvmlDeviceGetNvLinkState(handle, link_id)
if state == pynvml.NVML_FEATURE_DISABLED:
failed_links.append((gpu_idx, link_id))
except pynvml.NVMLError_NotSupported:
# The driver reports NVLink is not supported on this system.
# There is nothing to check — skip like the single-GPU case above.
return False

if failed_links:
details = ", ".join(f"GPU {gpu} link {link}" for gpu, link in failed_links)
raise ValueError(f"NVLink inactive on: {details}")

if verbose:
return f"All NVLinks active across {device_count} GPUs"
69 changes: 63 additions & 6 deletions rapids_cli/tests/test_nvlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,32 @@
from rapids_cli.doctor.checks.nvlink import check_nvlink_status


def test_check_nvlink_status_success():
@pytest.mark.parametrize(
"verbose, expected",
[
(True, "All NVLinks active across 2 GPUs"),
(False, None),
],
)
def test_check_nvlink_status_success(verbose, expected):
"""2 GPUs, all NVLinks active — verbose controls whether a summary string is returned."""
import pynvml

mock_handle = MagicMock()
with (
patch("pynvml.nvmlInit"),
patch("pynvml.nvmlDeviceGetCount", return_value=2),
patch("pynvml.nvmlDeviceGetHandleByIndex", return_value=mock_handle),
patch("pynvml.nvmlDeviceGetNvLinkState", return_value=1),
patch(
"pynvml.nvmlDeviceGetNvLinkState", return_value=pynvml.NVML_FEATURE_ENABLED
),
):
result = check_nvlink_status(verbose=True)
assert result is True
result = check_nvlink_status(verbose=verbose)
assert result == expected


def test_check_nvlink_status_single_gpu():
"""Single GPU — NVLink is not applicable, check skips early."""
with (
patch("pynvml.nvmlInit"),
patch("pynvml.nvmlDeviceGetCount", return_value=1),
Expand All @@ -29,6 +42,7 @@ def test_check_nvlink_status_single_gpu():


def test_check_nvlink_status_no_gpu():
"""nvmlInit fails — no GPUs installed."""
import pynvml

with patch("pynvml.nvmlInit", side_effect=pynvml.NVMLError(1)):
Expand All @@ -38,7 +52,8 @@ def test_check_nvlink_status_no_gpu():
check_nvlink_status(verbose=False)


def test_check_nvlink_status_nvml_error():
def test_check_nvlink_status_not_supported():
"""NVLink is not supported on this system — check skips silently like single-GPU case."""
import pynvml

mock_handle = MagicMock()
Expand All @@ -50,5 +65,47 @@ def test_check_nvlink_status_nvml_error():
"pynvml.nvmlDeviceGetNvLinkState", side_effect=pynvml.NVMLError_NotSupported
),
):
with pytest.raises(ValueError, match="NVLink 0 Status Check Failed"):
result = check_nvlink_status(verbose=False)
assert result is False


def test_check_nvlink_status_link_inactive():
"""A supported link is inactive — check fails and reports which GPU and link."""
import pynvml

mock_handle = MagicMock()
with (
patch("pynvml.nvmlInit"),
patch("pynvml.nvmlDeviceGetCount", return_value=2),
patch("pynvml.nvmlDeviceGetHandleByIndex", return_value=mock_handle),
patch(
"pynvml.nvmlDeviceGetNvLinkState", return_value=pynvml.NVML_FEATURE_DISABLED
),
):
with pytest.raises(ValueError, match="NVLink inactive on:"):
check_nvlink_status(verbose=False)


def test_check_nvlink_status_partial_failure():
"""Some links active, some inactive — all failures are reported in a single error."""
import pynvml

mock_handle = MagicMock()

# Simulate: link 0 active, link 1 inactive, rest active
def mock_link_state(handle, link_id):
if link_id == 1:
return pynvml.NVML_FEATURE_DISABLED
return pynvml.NVML_FEATURE_ENABLED

with (
patch("pynvml.nvmlInit"),
patch("pynvml.nvmlDeviceGetCount", return_value=2),
patch("pynvml.nvmlDeviceGetHandleByIndex", return_value=mock_handle),
patch("pynvml.nvmlDeviceGetNvLinkState", side_effect=mock_link_state),
):
with pytest.raises(ValueError, match="NVLink inactive on:") as exc_info:
check_nvlink_status(verbose=False)
# Both GPUs should have link 1 reported as failed
assert "GPU 0 link 1" in str(exc_info.value)
assert "GPU 1 link 1" in str(exc_info.value)
Loading