Skip to content
Draft
112 changes: 108 additions & 4 deletions internal/plugin/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ const (
deviceListEnvVar = "NVIDIA_VISIBLE_DEVICES"
deviceListAsVolumeMountsHostPath = "/dev/null"
deviceListAsVolumeMountsContainerPathRoot = "/var/run/nvidia-container-devices"

// healthChannelBufferSize defines the buffer capacity for the health
// channel. This is sized to handle bursts of unhealthy device reports
// without blocking the health check goroutine. With 8 GPUs and
// potential for multiple events per GPU (XID errors, ECC errors, etc.),
// a buffer of 64 provides ample headroom while using a power-of-2 size
// for cache-friendly alignment.
healthChannelBufferSize = 64
)

// nvidiaDevicePlugin implements the Kubernetes device plugin API
Expand All @@ -64,6 +72,10 @@ type nvidiaDevicePlugin struct {
health chan *rm.Device
stop chan interface{}

// deviceListUpdate is used to trigger ListAndWatch to send updated device
// list to kubelet (e.g., when devices recover from unhealthy state)
deviceListUpdate chan struct{}

imexChannels imex.Channels

mps mpsOptions
Expand Down Expand Up @@ -108,15 +120,20 @@ func getPluginSocketPath(resource spec.ResourceName) string {

func (plugin *nvidiaDevicePlugin) initialize() {
plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
plugin.health = make(chan *rm.Device)
plugin.health = make(chan *rm.Device, healthChannelBufferSize)
plugin.stop = make(chan interface{})
plugin.deviceListUpdate = make(chan struct{}, 1)
}

func (plugin *nvidiaDevicePlugin) cleanup() {
close(plugin.stop)
if plugin.deviceListUpdate != nil {
close(plugin.deviceListUpdate)
}
plugin.server = nil
plugin.health = nil
plugin.stop = nil
plugin.deviceListUpdate = nil
}

// Devices returns the full set of devices associated with the plugin.
Expand Down Expand Up @@ -156,6 +173,9 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
}
}()

// Start recovery worker to detect when unhealthy devices become healthy
go plugin.runRecoveryWorker()

return nil
}

Expand Down Expand Up @@ -263,7 +283,9 @@ func (plugin *nvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi
return options, nil
}

// ListAndWatch lists devices and update that list according to the health status
// ListAndWatch lists devices and update that list according to the health
// status. This now supports device recovery: when devices that were marked
// unhealthy recover, they are automatically re-advertised to kubelet.
func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return err
Expand All @@ -274,9 +296,17 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
case <-plugin.stop:
return nil
case d := <-plugin.health:
// FIXME: there is no way to recover from the Unhealthy state.
// Device marked unhealthy by health check
d.Health = pluginapi.Unhealthy
klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
klog.Infof("'%s' device marked unhealthy: %s (reason: %s)",
plugin.rm.Resource(), d.ID, d.UnhealthyReason)
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
case <-plugin.deviceListUpdate:
// Device recovery or other device list change
klog.Infof("'%s' device list updated, notifying kubelet",
plugin.rm.Resource())
if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
return nil
}
Expand Down Expand Up @@ -512,6 +542,80 @@ func (plugin *nvidiaDevicePlugin) updateResponseForDeviceMounts(response *plugin
}
}

// runRecoveryWorker periodically checks if unhealthy devices have recovered
// and notifies kubelet when they do.
func (plugin *nvidiaDevicePlugin) runRecoveryWorker() {
const recoveryInterval = 30 * time.Second

ticker := time.NewTicker(recoveryInterval)
defer ticker.Stop()

klog.V(2).Infof("Recovery worker started for '%s' (interval=%v)",
plugin.rm.Resource(), recoveryInterval)

for {
select {
case <-plugin.stop:
klog.V(2).Info("Recovery worker stopped")
return
case <-ticker.C:
plugin.checkForRecoveredDevices()
}
}
}

// checkForRecoveredDevices checks all unhealthy devices to see if they have
// recovered. If any have recovered, triggers a device list update to
// kubelet.
func (plugin *nvidiaDevicePlugin) checkForRecoveredDevices() {
recoveredDevices := []*rm.Device{}

for _, d := range plugin.rm.Devices() {
if !d.IsUnhealthy() {
continue
}

// Increment recovery attempts
d.RecoveryAttempts++

// Check if device has recovered
healthy, err := plugin.rm.CheckDeviceHealth(d)
if err != nil {
klog.V(4).Infof("Device %s recovery check failed (attempt %d): %v",
d.ID, d.RecoveryAttempts, err)
continue
}

if healthy {
klog.Infof("Device %s has RECOVERED! Was unhealthy for %v (reason: %s)",
d.ID, d.UnhealthyDuration(), d.UnhealthyReason)
d.MarkHealthy()
recoveredDevices = append(recoveredDevices, d)
} else {
klog.V(3).Infof("Device %s still unhealthy (attempt %d, duration %v)",
d.ID, d.RecoveryAttempts, d.UnhealthyDuration())
}
}

// If any devices recovered, notify ListAndWatch
if len(recoveredDevices) > 0 {
klog.Infof("Total recovered devices: %d", len(recoveredDevices))
plugin.triggerDeviceListUpdate()
}
}

// triggerDeviceListUpdate sends a signal to ListAndWatch to send an updated
// device list to kubelet. Uses a buffered channel with non-blocking send to
// avoid blocking the recovery worker.
func (plugin *nvidiaDevicePlugin) triggerDeviceListUpdate() {
select {
case plugin.deviceListUpdate <- struct{}{}:
klog.V(3).Info("Device list update triggered")
default:
klog.V(4).Info("Device list update already pending, skipping")
}
}

func (plugin *nvidiaDevicePlugin) apiDeviceSpecs(devRoot string, ids []string) []*pluginapi.DeviceSpec {
optional := map[string]bool{
"/dev/nvidiactl": true,
Expand Down
95 changes: 95 additions & 0 deletions internal/plugin/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@ package plugin

import (
"context"
"fmt"
"testing"
"time"

"github.com/stretchr/testify/require"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand Down Expand Up @@ -254,3 +256,96 @@ func TestCDIAllocateResponse(t *testing.T) {
func ptr[T any](x T) *T {
return &x
}

func TestTriggerDeviceListUpdate_Phase2(t *testing.T) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a matter of interest, what is Phase2? (Were these tests generated?)

plugin := &nvidiaDevicePlugin{
deviceListUpdate: make(chan struct{}, 1),
}

// First trigger should send signal
plugin.triggerDeviceListUpdate()
select {
case <-plugin.deviceListUpdate:
t.Log("✓ Device list update signal sent")
case <-time.After(100 * time.Millisecond):
t.Fatal("Signal not sent")
}

// Second trigger with pending signal should not block
plugin.triggerDeviceListUpdate()
plugin.triggerDeviceListUpdate() // Should not block
t.Log("✓ triggerDeviceListUpdate doesn't block when signal pending")
}

func TestCheckForRecoveredDevices_Phase2(t *testing.T) {
// Create persistent device map
devices := rm.Devices{
"GPU-0": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-0",
Health: pluginapi.Unhealthy,
},
UnhealthyReason: "XID-79",
},
"GPU-1": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-1",
Health: pluginapi.Unhealthy,
},
UnhealthyReason: "XID-48",
},
"GPU-2": &rm.Device{
Device: pluginapi.Device{
ID: "GPU-2",
Health: pluginapi.Healthy,
},
},
}

// Create mock resource manager with persistent devices
mockRM := &rm.ResourceManagerMock{
DevicesFunc: func() rm.Devices {
return devices
},
CheckDeviceHealthFunc: func(d *rm.Device) (bool, error) {
// GPU-0 recovers, GPU-1 stays unhealthy
if d.ID == "GPU-0" {
return true, nil
}
return false, fmt.Errorf("still unhealthy")
},
}

plugin := &nvidiaDevicePlugin{
rm: mockRM,
deviceListUpdate: make(chan struct{}, 1),
}

plugin.checkForRecoveredDevices()

// Verify GPU-0 recovered
gpu0 := devices["GPU-0"]
require.Equal(t, pluginapi.Healthy, gpu0.Health, "GPU-0 should be healthy")
require.Equal(t, "", gpu0.UnhealthyReason)
t.Logf("✓ GPU-0 recovered: Health=%s, Reason=%s", gpu0.Health, gpu0.UnhealthyReason)

// Verify GPU-1 still unhealthy
gpu1 := devices["GPU-1"]
require.Equal(t, pluginapi.Unhealthy, gpu1.Health, "GPU-1 should still be unhealthy")
require.Equal(t, 1, gpu1.RecoveryAttempts, "GPU-1 recovery attempts should increment")
t.Logf("✓ GPU-1 still unhealthy: attempts=%d", gpu1.RecoveryAttempts)

// Verify GPU-2 unchanged
gpu2 := devices["GPU-2"]
require.Equal(t, pluginapi.Healthy, gpu2.Health)
require.Equal(t, 0, gpu2.RecoveryAttempts, "Healthy device shouldn't be probed")
t.Log("✓ GPU-2 unchanged (was already healthy)")

// Verify deviceListUpdate was triggered
select {
case <-plugin.deviceListUpdate:
t.Log("✓ Device list update triggered for recovery")
case <-time.After(100 * time.Millisecond):
t.Fatal("Device list update not triggered")
}
}
41 changes: 41 additions & 0 deletions internal/rm/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"fmt"
"strconv"
"strings"
"time"

"k8s.io/klog/v2"
pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
Expand All @@ -35,6 +36,12 @@ type Device struct {
// Replicas stores the total number of times this device is replicated.
// If this is 0 or 1 then the device is not shared.
Replicas int

// Health tracking fields for recovery detection
LastHealthyTime time.Time // Last time device was confirmed healthy
LastUnhealthyTime time.Time // When device became unhealthy
UnhealthyReason string // Human-readable reason (e.g., "XID-79")
RecoveryAttempts int // Number of recovery probes attempted
}

// deviceInfo defines the information the required to construct a Device
Expand Down Expand Up @@ -239,6 +246,40 @@ func (d *Device) GetUUID() string {
return AnnotatedID(d.ID).GetID()
}

// MarkUnhealthy marks the device as unhealthy and records the reason and
// timestamp. This should be called when a health check detects a device
// failure (e.g., XID error).
func (d *Device) MarkUnhealthy(reason string) {
d.Health = pluginapi.Unhealthy
d.LastUnhealthyTime = time.Now()
d.UnhealthyReason = reason
d.RecoveryAttempts = 0
}

// MarkHealthy marks the device as healthy and clears unhealthy state. This
// should be called when recovery detection confirms the device is working
// again.
func (d *Device) MarkHealthy() {
d.Health = pluginapi.Healthy
d.LastHealthyTime = time.Now()
d.UnhealthyReason = ""
d.RecoveryAttempts = 0
}

// IsUnhealthy returns true if the device is currently marked as unhealthy.
func (d *Device) IsUnhealthy() bool {
return d.Health == pluginapi.Unhealthy
}

// UnhealthyDuration returns how long the device has been unhealthy. Returns
// zero duration if the device is healthy.
func (d *Device) UnhealthyDuration() time.Duration {
if !d.IsUnhealthy() {
return 0
}
return time.Since(d.LastUnhealthyTime)
}

// NewAnnotatedID creates a new AnnotatedID from an ID and a replica number.
func NewAnnotatedID(id string, replica int) AnnotatedID {
return AnnotatedID(fmt.Sprintf("%s::%d", id, replica))
Expand Down
Loading
Loading