diff --git a/internal/plugin/server.go b/internal/plugin/server.go
index a9d4c8868..27bd97a6e 100644
--- a/internal/plugin/server.go
+++ b/internal/plugin/server.go
@@ -46,6 +46,14 @@ const (
 	deviceListEnvVar                          = "NVIDIA_VISIBLE_DEVICES"
 	deviceListAsVolumeMountsHostPath          = "/dev/null"
 	deviceListAsVolumeMountsContainerPathRoot = "/var/run/nvidia-container-devices"
+
+	// healthChannelBufferSize defines the buffer capacity for the health
+	// channel. It is sized to absorb bursts of unhealthy-device reports
+	// without blocking the health check goroutine: with 8 GPUs and
+	// potentially multiple events per GPU (XID errors, ECC errors, etc.),
+	// a buffer of 64 leaves ample headroom between the health checker
+	// producing events and ListAndWatch consuming them.
+	healthChannelBufferSize = 64
 )
 
 // nvidiaDevicePlugin implements the Kubernetes device plugin API
@@ -64,6 +72,10 @@ type nvidiaDevicePlugin struct {
 	health chan *rm.Device
 	stop   chan interface{}
 
+	// deviceListUpdate is used to trigger ListAndWatch to send an updated device
+	// list to kubelet (e.g., when devices recover from an unhealthy state)
+	deviceListUpdate chan struct{}
+
 	imexChannels imex.Channels
 
 	mps mpsOptions
@@ -108,15 +120,20 @@ func getPluginSocketPath(resource spec.ResourceName) string {
 
 func (plugin *nvidiaDevicePlugin) initialize() {
 	plugin.server = grpc.NewServer([]grpc.ServerOption{}...)
-	plugin.health = make(chan *rm.Device)
+	plugin.health = make(chan *rm.Device, healthChannelBufferSize)
 	plugin.stop = make(chan interface{})
+	plugin.deviceListUpdate = make(chan struct{}, 1)
 }
 
 func (plugin *nvidiaDevicePlugin) cleanup() {
 	close(plugin.stop)
+	if plugin.deviceListUpdate != nil {
+		close(plugin.deviceListUpdate)
+	}
 	plugin.server = nil
 	plugin.health = nil
 	plugin.stop = nil
+	plugin.deviceListUpdate = nil
 }
 
 // Devices returns the full set of devices associated with the plugin.
@@ -156,6 +173,9 @@ func (plugin *nvidiaDevicePlugin) Start(kubeletSocket string) error {
 		}
 	}()
 
+	// Start recovery worker to detect when unhealthy devices become healthy
+	go plugin.runRecoveryWorker()
+
 	return nil
 }
 
@@ -263,7 +283,9 @@ func (plugin *nvidiaDevicePlugin) GetDevicePluginOptions(context.Context, *plugi
 	return options, nil
 }
 
-// ListAndWatch lists devices and update that list according to the health status
+// ListAndWatch lists devices and updates that list according to the health
+// status. This now supports device recovery: when devices that were marked
+// unhealthy recover, they are automatically re-advertised to kubelet.
 func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.DevicePlugin_ListAndWatchServer) error {
 	if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
 		return err
@@ -274,9 +296,17 @@ func (plugin *nvidiaDevicePlugin) ListAndWatch(e *pluginapi.Empty, s pluginapi.D
 		case <-plugin.stop:
 			return nil
 		case d := <-plugin.health:
-			// FIXME: there is no way to recover from the Unhealthy state.
+			// Device marked unhealthy by health check
 			d.Health = pluginapi.Unhealthy
-			klog.Infof("'%s' device marked unhealthy: %s", plugin.rm.Resource(), d.ID)
+			klog.Infof("'%s' device marked unhealthy: %s (reason: %s)",
+				plugin.rm.Resource(), d.ID, d.UnhealthyReason)
+			if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
+				return nil
+			}
+		case <-plugin.deviceListUpdate:
+			// Device recovery or other device list change
+			klog.Infof("'%s' device list updated, notifying kubelet",
+				plugin.rm.Resource())
 			if err := s.Send(&pluginapi.ListAndWatchResponse{Devices: plugin.apiDevices()}); err != nil {
 				return nil
 			}
@@ -512,6 +542,80 @@ func (plugin *nvidiaDevicePlugin) updateResponseForDeviceMounts(response *plugin
 	}
 }
 
+// runRecoveryWorker periodically checks if unhealthy devices have recovered
+// and notifies kubelet when they do.
+func (plugin *nvidiaDevicePlugin) runRecoveryWorker() {
+	const recoveryInterval = 30 * time.Second
+
+	ticker := time.NewTicker(recoveryInterval)
+	defer ticker.Stop()
+
+	klog.V(2).Infof("Recovery worker started for '%s' (interval=%v)",
+		plugin.rm.Resource(), recoveryInterval)
+
+	for {
+		select {
+		case <-plugin.stop:
+			klog.V(2).Info("Recovery worker stopped")
+			return
+		case <-ticker.C:
+			plugin.checkForRecoveredDevices()
+		}
+	}
+}
+
+// checkForRecoveredDevices checks all unhealthy devices to see if they have
+// recovered. If any have recovered, it triggers a device list update to
+// kubelet.
+func (plugin *nvidiaDevicePlugin) checkForRecoveredDevices() {
+	recoveredDevices := []*rm.Device{}
+
+	for _, d := range plugin.rm.Devices() {
+		if !d.IsUnhealthy() {
+			continue
+		}
+
+		// Increment recovery attempts
+		d.RecoveryAttempts++
+
+		// Check if device has recovered
+		healthy, err := plugin.rm.CheckDeviceHealth(d)
+		if err != nil {
+			klog.V(4).Infof("Device %s recovery check failed (attempt %d): %v",
+				d.ID, d.RecoveryAttempts, err)
+			continue
+		}
+
+		if healthy {
+			klog.Infof("Device %s has recovered; was unhealthy for %v (reason: %s)",
+				d.ID, d.UnhealthyDuration(), d.UnhealthyReason)
+			d.MarkHealthy()
+			recoveredDevices = append(recoveredDevices, d)
+		} else {
+			klog.V(3).Infof("Device %s still unhealthy (attempt %d, duration %v)",
+				d.ID, d.RecoveryAttempts, d.UnhealthyDuration())
+		}
+	}
+
+	// If any devices recovered, notify ListAndWatch
+	if len(recoveredDevices) > 0 {
+		klog.Infof("Total recovered devices: %d", len(recoveredDevices))
+		plugin.triggerDeviceListUpdate()
+	}
+}
+
+// triggerDeviceListUpdate sends a signal to ListAndWatch to send an updated
+// device list to kubelet. It uses a buffered channel with a non-blocking send
+// to avoid blocking the recovery worker.
+func (plugin *nvidiaDevicePlugin) triggerDeviceListUpdate() {
+	select {
+	case plugin.deviceListUpdate <- struct{}{}:
+		klog.V(3).Info("Device list update triggered")
+	default:
+		klog.V(4).Info("Device list update already pending, skipping")
+	}
+}
+
 func (plugin *nvidiaDevicePlugin) apiDeviceSpecs(devRoot string, ids []string) []*pluginapi.DeviceSpec {
 	optional := map[string]bool{
 		"/dev/nvidiactl":        true,
diff --git a/internal/plugin/server_test.go b/internal/plugin/server_test.go
index b9bddbb6a..dd257f8e4 100644
--- a/internal/plugin/server_test.go
+++ b/internal/plugin/server_test.go
@@ -18,7 +18,9 @@ package plugin
 
 import (
 	"context"
+	"fmt"
 	"testing"
+	"time"
 
 	"github.com/stretchr/testify/require"
 	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
@@ -254,3 +256,96 @@ func TestCDIAllocateResponse(t *testing.T) {
 func ptr[T any](x T) *T {
 	return &x
 }
+
+func TestTriggerDeviceListUpdate_Phase2(t *testing.T) {
+	plugin := &nvidiaDevicePlugin{
+		deviceListUpdate: make(chan struct{}, 1),
+	}
+
+	// First trigger should send signal
+	plugin.triggerDeviceListUpdate()
+	select {
+	case <-plugin.deviceListUpdate:
+		t.Log("✓ Device list update signal sent")
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Signal not sent")
+	}
+
+	// Second trigger with pending signal should not block
+	plugin.triggerDeviceListUpdate()
+	plugin.triggerDeviceListUpdate() // Should not block
+	t.Log("✓ triggerDeviceListUpdate doesn't block when signal pending")
+}
+
+func TestCheckForRecoveredDevices_Phase2(t *testing.T) {
+	// Create persistent device map
+	devices := rm.Devices{
+		"GPU-0": &rm.Device{
+			Device: pluginapi.Device{
+				ID:     "GPU-0",
+				Health: pluginapi.Unhealthy,
+			},
+			UnhealthyReason: "XID-79",
+		},
+		"GPU-1": &rm.Device{
+			Device: pluginapi.Device{
+				ID:     "GPU-1",
+				Health: pluginapi.Unhealthy,
+			},
+			UnhealthyReason: "XID-48",
+		},
+		"GPU-2": &rm.Device{
+			Device: pluginapi.Device{
+				ID:     "GPU-2",
+				Health: pluginapi.Healthy,
+			},
+		},
+	}
+
+	// Create mock resource manager with persistent devices
+	mockRM := &rm.ResourceManagerMock{
+		DevicesFunc: func() rm.Devices {
+			return devices
+		},
+		CheckDeviceHealthFunc: func(d *rm.Device) (bool, error) {
+			// GPU-0 recovers, GPU-1 stays unhealthy
+			if d.ID == "GPU-0" {
+				return true, nil
+			}
+			return false, fmt.Errorf("still unhealthy")
+		},
+	}
+
+	plugin := &nvidiaDevicePlugin{
+		rm:               mockRM,
+		deviceListUpdate: make(chan struct{}, 1),
+	}
+
+	plugin.checkForRecoveredDevices()
+
+	// Verify GPU-0 recovered
+	gpu0 := devices["GPU-0"]
+	require.Equal(t, pluginapi.Healthy, gpu0.Health, "GPU-0 should be healthy")
+	require.Equal(t, "", gpu0.UnhealthyReason)
+	t.Logf("✓ GPU-0 recovered: Health=%s, Reason=%s", gpu0.Health, gpu0.UnhealthyReason)
+
+	// Verify GPU-1 still unhealthy
+	gpu1 := devices["GPU-1"]
+	require.Equal(t, pluginapi.Unhealthy, gpu1.Health, "GPU-1 should still be unhealthy")
+	require.Equal(t, 1, gpu1.RecoveryAttempts, "GPU-1 recovery attempts should increment")
+	t.Logf("✓ GPU-1 still unhealthy: attempts=%d", gpu1.RecoveryAttempts)
+
+	// Verify GPU-2 unchanged
+	gpu2 := devices["GPU-2"]
+	require.Equal(t, pluginapi.Healthy, gpu2.Health)
+	require.Equal(t, 0, gpu2.RecoveryAttempts, "Healthy device shouldn't be probed")
+	t.Log("✓ GPU-2 unchanged (was already healthy)")
+
+	// Verify deviceListUpdate was triggered
+	select {
+	case <-plugin.deviceListUpdate:
+		t.Log("✓ Device list update triggered for recovery")
+	case <-time.After(100 * time.Millisecond):
+		t.Fatal("Device list update not triggered")
+	}
+}
diff --git a/internal/rm/devices.go b/internal/rm/devices.go
index 1049820e8..c05451477 100644
--- a/internal/rm/devices.go
+++ b/internal/rm/devices.go
@@ -20,6 +20,7 @@ import (
 	"fmt"
 	"strconv"
 	"strings"
+	"time"
 
 	"k8s.io/klog/v2"
 	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
@@ -35,6 +36,12 @@ type Device struct {
 	// Replicas stores the total number of times this device is replicated.
 	// If this is 0 or 1 then the device is not shared.
 	Replicas int
+
+	// Health tracking fields for recovery detection
+	LastHealthyTime   time.Time // Last time device was confirmed healthy
+	LastUnhealthyTime time.Time // When device became unhealthy
+	UnhealthyReason   string    // Human-readable reason (e.g., "XID-79")
+	RecoveryAttempts  int       // Number of recovery probes attempted
 }
 
 // deviceInfo defines the information the required to construct a Device
@@ -239,6 +246,40 @@ func (d *Device) GetUUID() string {
 	return AnnotatedID(d.ID).GetID()
 }
 
+// MarkUnhealthy marks the device as unhealthy and records the reason and
+// timestamp. This should be called when a health check detects a device
+// failure (e.g., XID error).
+func (d *Device) MarkUnhealthy(reason string) {
+	d.Health = pluginapi.Unhealthy
+	d.LastUnhealthyTime = time.Now()
+	d.UnhealthyReason = reason
+	d.RecoveryAttempts = 0
+}
+
+// MarkHealthy marks the device as healthy and clears unhealthy state. This
+// should be called when recovery detection confirms the device is working
+// again.
+func (d *Device) MarkHealthy() {
+	d.Health = pluginapi.Healthy
+	d.LastHealthyTime = time.Now()
+	d.UnhealthyReason = ""
+	d.RecoveryAttempts = 0
+}
+
+// IsUnhealthy returns true if the device is currently marked as unhealthy.
+func (d *Device) IsUnhealthy() bool {
+	return d.Health == pluginapi.Unhealthy
+}
+
+// UnhealthyDuration returns how long the device has been unhealthy. Returns
+// zero duration if the device is healthy.
+func (d *Device) UnhealthyDuration() time.Duration {
+	if !d.IsUnhealthy() {
+		return 0
+	}
+	return time.Since(d.LastUnhealthyTime)
+}
+
 // NewAnnotatedID creates a new AnnotatedID from an ID and a replica number.
 func NewAnnotatedID(id string, replica int) AnnotatedID {
 	return AnnotatedID(fmt.Sprintf("%s::%d", id, replica))
diff --git a/internal/rm/health.go b/internal/rm/health.go
index 1f0fc5c41..84c443caa 100644
--- a/internal/rm/health.go
+++ b/internal/rm/health.go
@@ -17,13 +17,17 @@ package rm
 
 import (
+	"context"
 	"fmt"
 	"os"
 	"strconv"
 	"strings"
+	"sync"
+	"time"
 
 	"github.com/NVIDIA/go-nvml/pkg/nvml"
 	"k8s.io/klog/v2"
+	pluginapi "k8s.io/kubelet/pkg/apis/deviceplugin/v1beta1"
 )
 
 const (
@@ -40,8 +44,324 @@ const (
 	envEnableHealthChecks = "DP_ENABLE_HEALTHCHECKS"
 )
 
-// CheckHealth performs health checks on a set of devices, writing to the 'unhealthy' channel with any unhealthy devices
+// eventResult packages an NVML event with its return code for passing
+// between the event receiver goroutine and the main processing loop.
+type eventResult struct {
+	event nvml.EventData
+	ret   nvml.Return
+}
+
+// sendUnhealthyDevice sends a device to the unhealthy channel without
+// blocking. If the channel is full, it logs an error and updates the device
+// state directly. This prevents the health check goroutine from being blocked
+// indefinitely if ListAndWatch is stalled.
+func sendUnhealthyDevice(unhealthy chan<- *Device, d *Device) {
+	select {
+	case unhealthy <- d:
+		klog.V(2).Infof("Device %s sent to unhealthy channel", d.ID)
+	default:
+		// Channel is full - this indicates ListAndWatch is not consuming
+		// or the channel buffer is insufficient for the event rate
+		klog.Errorf("Health channel full (capacity=%d)! "+
+			"Unable to report device %s as unhealthy. "+
+			"ListAndWatch may be stalled or event rate is too high.",
+			cap(unhealthy), d.ID)
+		// Update device state directly as fallback
+		d.Health = pluginapi.Unhealthy
+	}
+}
+
+// healthCheckStats tracks statistics about health check operations for
+// observability and debugging.
+type healthCheckStats struct {
+	startTime              time.Time
+	eventsProcessed        uint64
+	devicesMarkedUnhealthy uint64
+	errorCount             uint64
+	xidByType              map[uint64]uint64 // XID code -> count
+	mu                     sync.Mutex
+}
+
+// recordEvent increments the events processed counter and tracks XID
+// distribution.
+func (s *healthCheckStats) recordEvent(xid uint64) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.eventsProcessed++
+	if s.xidByType == nil {
+		s.xidByType = make(map[uint64]uint64)
+	}
+	s.xidByType[xid]++
+}
+
+// recordUnhealthy increments the devices marked unhealthy counter.
+func (s *healthCheckStats) recordUnhealthy() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.devicesMarkedUnhealthy++
+}
+
+// recordError increments the error counter.
+func (s *healthCheckStats) recordError() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.errorCount++
+}
+
+// report logs a summary of health check statistics.
+func (s *healthCheckStats) report() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	uptime := time.Since(s.startTime)
+	klog.Infof("HealthCheck Stats: uptime=%v, events=%d, unhealthy=%d, errors=%d",
+		uptime.Round(time.Second), s.eventsProcessed,
+		s.devicesMarkedUnhealthy, s.errorCount)
+
+	if len(s.xidByType) > 0 {
+		klog.Infof("HealthCheck XID distribution: %v", s.xidByType)
+	}
+}
+
+// nvmlHealthProvider encapsulates the state and logic for NVML-based GPU
+// health monitoring. This struct groups related data and provides focused
+// methods for device registration and event monitoring.
+type nvmlHealthProvider struct {
+	// Configuration
+	nvmllib nvml.Interface
+	devices Devices
+
+	// Device placement maps (for MIG support)
+	parentToDeviceMap map[string]*Device
+	deviceIDToGiMap   map[string]uint32
+	deviceIDToCiMap   map[string]uint32
+
+	// XID filtering
+	xidsDisabled disabledXIDs
+
+	// Communication
+	unhealthy chan<- *Device
+
+	// Observability
+	stats *healthCheckStats
+}
+
+// runEventMonitor runs the main event monitoring loop with context-based
+// shutdown coordination and granular error handling. This method preserves
+// all robustness features from the original implementation while being
+// testable independently.
+func (p *nvmlHealthProvider) runEventMonitor(
+	ctx context.Context,
+	eventSet nvml.EventSet,
+	handleError func(nvml.Return, Devices, chan<- *Device) bool,
+) error {
+	// Event receive channel with buffer
+	eventChan := make(chan eventResult, 10)
+
+	// Start goroutine to receive NVML events
+	go func() {
+		defer close(eventChan)
+		for {
+			// Check if we should stop
+			select {
+			case <-ctx.Done():
+				return
+			default:
+			}
+
+			// Wait for NVML event with timeout
+			e, ret := eventSet.Wait(5000)
+
+			// Try to send event result, but respect context cancellation
+			select {
+			case <-ctx.Done():
+				return
+			case eventChan <- eventResult{event: e, ret: ret}:
+			}
+		}
+	}()
+
+	// Main event processing loop
+	for {
+		select {
+		case <-ctx.Done():
+			klog.V(2).Info("Health check stopped cleanly")
+			return nil
+
+		case result, ok := <-eventChan:
+			if !ok {
+				// Event channel closed, exit
+				return nil
+			}
+
+			// Handle timeout - just continue
+			if result.ret == nvml.ERROR_TIMEOUT {
+				continue
+			}
+
+			// Handle NVML errors with granular error handling
+			if result.ret != nvml.SUCCESS {
+				p.stats.recordError()
+				shouldContinue := handleError(result.ret, p.devices, p.unhealthy)
+				if !shouldContinue {
+					return fmt.Errorf("fatal NVML error: %v", result.ret)
+				}
+				continue
+			}
+
+			e := result.event
+
+			// Filter non-critical events
+			if e.EventType != nvml.EventTypeXidCriticalError {
+				klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e)
+				continue
+			}
+
+			// Check if this XID is disabled
+			if p.xidsDisabled.IsDisabled(e.EventData) {
+				klog.Infof("Skipping event %+v", e)
+				continue
+			}
+
+			klog.Infof("Processing event %+v", e)
+
+			// Record event stats
+			p.stats.recordEvent(e.EventData)
+
+			// Get device UUID from event
+			eventUUID, ret := e.Device.GetUUID()
+			if ret != nvml.SUCCESS {
+				// If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
+				klog.Infof("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy.", e, ret)
+				p.stats.recordError()
+				for _, d := range p.devices {
+					p.stats.recordUnhealthy()
+					sendUnhealthyDevice(p.unhealthy, d)
+				}
+				continue
+			}
+
+			// Find the device that matches this event
+			d, exists := p.parentToDeviceMap[eventUUID]
+			if !exists {
+				klog.Infof("Ignoring event for unexpected device: %v", eventUUID)
+				continue
+			}
+
+			// For MIG devices, verify the GI/CI matches
+			if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF {
+				gi := p.deviceIDToGiMap[d.ID]
+				ci := p.deviceIDToCiMap[d.ID]
+				if gi != e.GpuInstanceId || ci != e.ComputeInstanceId {
+					continue
+				}
+				klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci)
+			}
+
+			klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID)
+			p.stats.recordUnhealthy()
+			d.MarkUnhealthy(fmt.Sprintf("XID-%d", e.EventData))
+			sendUnhealthyDevice(p.unhealthy, d)
+		}
+	}
+}
+
+// registerDeviceEvents registers NVML event handlers for all devices in the
+// provider. Devices that fail registration are sent to the unhealthy channel.
+// This method is separated for testability and clarity.
+func (p *nvmlHealthProvider) registerDeviceEvents(eventSet nvml.EventSet) {
+	eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
+
+	for uuid, d := range p.parentToDeviceMap {
+		gpu, ret := p.nvmllib.DeviceGetHandleByUUID(uuid)
+		if ret != nvml.SUCCESS {
+			klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret)
+			sendUnhealthyDevice(p.unhealthy, d)
+			continue
+		}
+
+		supportedEvents, ret := gpu.GetSupportedEventTypes()
+		if ret != nvml.SUCCESS {
+			klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
+			sendUnhealthyDevice(p.unhealthy, d)
+			continue
+		}
+
+		ret = gpu.RegisterEvents(eventMask&supportedEvents, eventSet)
+		if ret == nvml.ERROR_NOT_SUPPORTED {
+			klog.Warningf("Device %v is too old to support healthchecking.", d.ID)
+		}
+		if ret != nvml.SUCCESS {
+			klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret)
+			sendUnhealthyDevice(p.unhealthy, d)
+		}
+	}
+}
+
+// handleEventWaitError categorizes NVML errors and determines the
+// appropriate action. Returns true if health checking should continue,
+// false if it should terminate.
+func (r *nvmlResourceManager) handleEventWaitError(
+	ret nvml.Return,
+	devices Devices,
+	unhealthy chan<- *Device,
+) bool {
+	klog.Errorf("Error waiting for NVML event: %v (code: %d)", ret, ret)
+
+	switch ret {
+	case nvml.ERROR_GPU_IS_LOST:
+		// Definitive hardware failure - mark all devices unhealthy
+		klog.Error("GPU_IS_LOST error: Marking all devices as unhealthy")
+		for _, d := range devices {
+			sendUnhealthyDevice(unhealthy, d)
+		}
+		return true // Continue checking - devices may recover
+
+	case nvml.ERROR_UNINITIALIZED:
+		// NVML state corrupted - this shouldn't happen in event loop
+		klog.Error("NVML uninitialized error: This is unexpected, terminating health check")
+		return false // Fatal, exit health check
+
+	case nvml.ERROR_UNKNOWN, nvml.ERROR_NOT_SUPPORTED:
+		// Potentially transient or driver issue
+		klog.Warningf("Transient NVML error (%v): Will retry on next iteration", ret)
+		return true // Continue checking
+
+	default:
+		// Unknown error - be conservative and mark devices unhealthy
+		klog.Errorf("Unexpected NVML error %v: Marking all devices unhealthy conservatively", ret)
+		for _, d := range devices {
+			sendUnhealthyDevice(unhealthy, d)
+		}
+		return true // Continue checking
+	}
+}
+
+// checkHealth orchestrates GPU health monitoring by coordinating NVML
+// initialization, device registration, and event monitoring. This function
+// acts as the main entry point and delegates specific responsibilities to
+// focused methods on nvmlHealthProvider.
+//
+// The orchestration flow:
+// 1. Initialize stats tracking and XID filtering
+// 2. Initialize NVML and create event set
+// 3. Build device placement maps (for MIG support)
+// 4. Create nvmlHealthProvider with configuration
+// 5. Register device events
+// 6. Start context-based shutdown coordination
+// 7. Start periodic stats reporting
+// 8. Run event monitoring loop
+//
+// All robustness features are preserved: stats tracking, granular error
+// handling, context-based shutdown, and non-blocking device reporting.
 func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devices, unhealthy chan<- *Device) error {
+	// Initialize stats tracking
+	stats := &healthCheckStats{
+		startTime: time.Now(),
+		xidByType: make(map[uint64]uint64),
+	}
+	defer stats.report() // Log stats summary on exit
+
 	xids := getDisabledHealthCheckXids()
 	if xids.IsAllDisabled() {
 		return nil
@@ -62,6 +382,7 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
 	}()
 
 	klog.Infof("Ignoring the following XIDs for health checks: %v", xids)
+	klog.V(2).Infof("CheckHealth: Starting for %d devices", len(devices))
 
 	eventSet, ret := r.nvml.EventSetCreate()
 	if ret != nvml.SUCCESS {
@@ -71,104 +392,64 @@ func (r *nvmlResourceManager) checkHealth(stop <-chan interface{}, devices Devic
 		_ = eventSet.Free()
 	}()
 
+	// Build device placement maps for MIG support
 	parentToDeviceMap := make(map[string]*Device)
 	deviceIDToGiMap := make(map[string]uint32)
 	deviceIDToCiMap := make(map[string]uint32)
 
-	eventMask := uint64(nvml.EventTypeXidCriticalError | nvml.EventTypeDoubleBitEccError | nvml.EventTypeSingleBitEccError)
-
 	for _, d := range devices {
 		uuid, gi, ci, err := r.getDevicePlacement(d)
 		if err != nil {
 			klog.Warningf("Could not determine device placement for %v: %v; Marking it unhealthy.", d.ID, err)
-			unhealthy <- d
+			sendUnhealthyDevice(unhealthy, d)
 			continue
 		}
 		deviceIDToGiMap[d.ID] = gi
 		deviceIDToCiMap[d.ID] = ci
 		parentToDeviceMap[uuid] = d
-
-		gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
-		if ret != nvml.SUCCESS {
-			klog.Infof("unable to get device handle from UUID: %v; marking it as unhealthy", ret)
-			unhealthy <- d
-			continue
-		}
-
-		supportedEvents, ret := gpu.GetSupportedEventTypes()
-		if ret != nvml.SUCCESS {
-			klog.Infof("unable to determine the supported events for %v: %v; marking it as unhealthy", d.ID, ret)
-			unhealthy <- d
-			continue
-		}
-
-		ret = gpu.RegisterEvents(eventMask&supportedEvents, eventSet)
-		if ret == nvml.ERROR_NOT_SUPPORTED {
-			klog.Warningf("Device %v is too old to support healthchecking.", d.ID)
-		}
-		if ret != nvml.SUCCESS {
-			klog.Infof("Marking device %v as unhealthy: %v", d.ID, ret)
-			unhealthy <- d
-		}
 	}
 
-	for {
-		select {
-		case <-stop:
-			return nil
-		default:
-		}
-
-		e, ret := eventSet.Wait(5000)
-		if ret == nvml.ERROR_TIMEOUT {
-			continue
-		}
-		if ret != nvml.SUCCESS {
-			klog.Infof("Error waiting for event: %v; Marking all devices as unhealthy", ret)
-			for _, d := range devices {
-				unhealthy <- d
-			}
-			continue
-		}
+	// Create health provider with device maps
+	provider := &nvmlHealthProvider{
+		nvmllib:           r.nvml,
+		devices:           devices,
+		parentToDeviceMap: parentToDeviceMap,
+		deviceIDToGiMap:   deviceIDToGiMap,
+		deviceIDToCiMap:   deviceIDToCiMap,
+		xidsDisabled:      xids,
+		unhealthy:         unhealthy,
+		stats:             stats,
+	}
 
-		if e.EventType != nvml.EventTypeXidCriticalError {
-			klog.Infof("Skipping non-nvmlEventTypeXidCriticalError event: %+v", e)
-			continue
-		}
+	// Register device events
+	provider.registerDeviceEvents(eventSet)
 
-		if xids.IsDisabled(e.EventData) {
-			klog.Infof("Skipping event %+v", e)
-			continue
-		}
+	// Create context for coordinating shutdown
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
 
-		klog.Infof("Processing event %+v", e)
-		eventUUID, ret := e.Device.GetUUID()
-		if ret != nvml.SUCCESS {
-			// If we cannot reliably determine the device UUID, we mark all devices as unhealthy.
-			klog.Infof("Failed to determine uuid for event %v: %v; Marking all devices as unhealthy.", e, ret)
-			for _, d := range devices {
-				unhealthy <- d
-			}
-			continue
-		}
-
-		d, exists := parentToDeviceMap[eventUUID]
-		if !exists {
-			klog.Infof("Ignoring event for unexpected device: %v", eventUUID)
-			continue
-		}
+	// Goroutine to watch for stop signal and cancel context
+	go func() {
+		<-stop
+		cancel()
+	}()
 
-		if d.IsMigDevice() && e.GpuInstanceId != 0xFFFFFFFF && e.ComputeInstanceId != 0xFFFFFFFF {
-			gi := deviceIDToGiMap[d.ID]
-			ci := deviceIDToCiMap[d.ID]
-			if gi != e.GpuInstanceId || ci != e.ComputeInstanceId {
-				continue
+	// Start periodic stats reporting goroutine
+	go func() {
+		ticker := time.NewTicker(5 * time.Minute)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ctx.Done():
+				return
+			case <-ticker.C:
+				stats.report()
 			}
-			klog.Infof("Event for mig device %v (gi=%v, ci=%v)", d.ID, gi, ci)
 		}
+	}()
 
-		klog.Infof("XidCriticalError: Xid=%d on Device=%s; marking device as unhealthy.", e.EventData, d.ID)
-		unhealthy <- d
-	}
+	// Run event monitor with error handler
+	return provider.runEventMonitor(ctx, eventSet, r.handleEventWaitError)
 }
 
 const allXIDs = 0
diff --git a/internal/rm/nvml_manager.go b/internal/rm/nvml_manager.go
index fac923429..8dcce20c9 100644
--- a/internal/rm/nvml_manager.go
+++ b/internal/rm/nvml_manager.go
@@ -95,6 +95,48 @@ func (r *nvmlResourceManager) CheckHealth(stop <-chan interface{}, unhealthy cha
 	return r.checkHealth(stop, r.devices, unhealthy)
 }
 
+// CheckDeviceHealth performs a simple health check on a single device by
+// verifying it can be accessed via NVML and responds to basic queries.
+// This is used for recovery detection - if a previously unhealthy device
+// passes this check, it's considered recovered. We intentionally keep this
+// simple and don't try to classify XIDs as recoverable vs permanent - that's
+// controlled via DP_DISABLE_HEALTHCHECKS / DP_ENABLE_HEALTHCHECKS env vars.
+func (r *nvmlResourceManager) CheckDeviceHealth(d *Device) (bool, error) {
+	// Initialize NVML for this health check
+	ret := r.nvml.Init()
+	if ret != nvml.SUCCESS {
+		return false, fmt.Errorf("NVML init failed: %v", ret)
+	}
+	defer func() {
+		_ = r.nvml.Shutdown()
+	}()
+
+	uuid := d.GetUUID()
+
+	// For MIG devices, extract parent UUID
+	if d.IsMigDevice() {
+		parentUUID, _, _, err := r.getMigDeviceParts(d)
+		if err != nil {
+			return false, fmt.Errorf("cannot determine MIG device parts: %w", err)
+		}
+		uuid = parentUUID
+	}
+
+	// Get device handle
+	gpu, ret := r.nvml.DeviceGetHandleByUUID(uuid)
+	if ret != nvml.SUCCESS {
+		return false, fmt.Errorf("cannot get device handle: %v", ret)
+	}
+
+	// Perform basic health check - if device responds, consider it healthy
+	_, ret = gpu.GetName()
+	if ret != nvml.SUCCESS {
+		return false, fmt.Errorf("device not responsive (GetName failed): %v", ret)
+	}
+
+	return true, nil
+}
+
 // getPreferredAllocation runs an allocation algorithm over the inputs.
 // The algorithm chosen is based both on the incoming set of available devices and various config settings.
 func (r *nvmlResourceManager) getPreferredAllocation(available, required []string, size int) ([]string, error) {
diff --git a/internal/rm/rm.go b/internal/rm/rm.go
index 33f44b9d8..a9ce5eb4f 100644
--- a/internal/rm/rm.go
+++ b/internal/rm/rm.go
@@ -45,6 +45,7 @@ type ResourceManager interface {
 	GetDevicePaths([]string) []string
 	GetPreferredAllocation(available, required []string, size int) ([]string, error)
 	CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error
+	CheckDeviceHealth(d *Device) (bool, error)
 	ValidateRequest(AnnotatedIDs) error
 }
 
diff --git a/internal/rm/rm_mock.go b/internal/rm/rm_mock.go
index 4efee5fd9..a9337c7cf 100644
--- a/internal/rm/rm_mock.go
+++ b/internal/rm/rm_mock.go
@@ -47,6 +47,9 @@ type ResourceManagerMock struct {
 	// CheckHealthFunc mocks the CheckHealth method.
 	CheckHealthFunc func(stop <-chan interface{}, unhealthy chan<- *Device) error
 
+	// CheckDeviceHealthFunc mocks the CheckDeviceHealth method.
+	CheckDeviceHealthFunc func(d *Device) (bool, error)
+
 	// DevicesFunc mocks the Devices method.
 	DevicesFunc func() Devices
 
@@ -71,6 +74,11 @@ type ResourceManagerMock struct {
 			// Unhealthy is the unhealthy argument value.
 			Unhealthy chan<- *Device
 		}
+		// CheckDeviceHealth holds details about calls to the CheckDeviceHealth method.
+		CheckDeviceHealth []struct {
+			// D is the d argument value.
+			D *Device
+		}
 		// Devices holds details about calls to the Devices method.
 		Devices []struct {
 		}
@@ -98,6 +106,7 @@ type ResourceManagerMock struct {
 		}
 	}
 	lockCheckHealth            sync.RWMutex
+	lockCheckDeviceHealth      sync.RWMutex
 	lockDevices                sync.RWMutex
 	lockGetDevicePaths         sync.RWMutex
 	lockGetPreferredAllocation sync.RWMutex
@@ -144,6 +153,42 @@ func (mock *ResourceManagerMock) CheckHealthCalls() []struct {
 	return calls
 }
 
+// CheckDeviceHealth calls CheckDeviceHealthFunc.
+func (mock *ResourceManagerMock) CheckDeviceHealth(d *Device) (bool, error) {
+	callInfo := struct {
+		D *Device
+	}{
+		D: d,
+	}
+	mock.lockCheckDeviceHealth.Lock()
+	mock.calls.CheckDeviceHealth = append(mock.calls.CheckDeviceHealth, callInfo)
+	mock.lockCheckDeviceHealth.Unlock()
+	if mock.CheckDeviceHealthFunc == nil {
+		var (
+			bOut   bool
+			errOut error
+		)
+		return bOut, errOut
+	}
+	return mock.CheckDeviceHealthFunc(d)
+}
+
+// CheckDeviceHealthCalls gets all the calls that were made to
+// CheckDeviceHealth. Check the length with:
+//
+//	len(mockedResourceManager.CheckDeviceHealthCalls())
+func (mock *ResourceManagerMock) CheckDeviceHealthCalls() []struct {
+	D *Device
+} {
+	var calls []struct {
+		D *Device
+	}
+	mock.lockCheckDeviceHealth.RLock()
+	calls = mock.calls.CheckDeviceHealth
+	mock.lockCheckDeviceHealth.RUnlock()
+	return calls
+}
+
 // Devices calls DevicesFunc.
 func (mock *ResourceManagerMock) Devices() Devices {
 	callInfo := struct {
diff --git a/internal/rm/tegra_manager.go b/internal/rm/tegra_manager.go
index 65ca2022f..c716cab1a 100644
--- a/internal/rm/tegra_manager.go
+++ b/internal/rm/tegra_manager.go
@@ -74,3 +74,11 @@ func (r *tegraResourceManager) GetDevicePaths(ids []string) []string {
 func (r *tegraResourceManager) CheckHealth(stop <-chan interface{}, unhealthy chan<- *Device) error {
 	return nil
 }
+
+// CheckDeviceHealth always reports healthy for the tegraResourceManager.
+// Tegra devices don't support the same health checking mechanisms as
+// NVML-based devices, so no recovery probing is performed for them.
+func (r *tegraResourceManager) CheckDeviceHealth(d *Device) (bool, error) {
+	// Always return healthy for Tegra devices (no health checking)
+	return true, nil
+}
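
Note (illustrative, not part of the patch): the deviceListUpdate channel introduced in server.go relies on a common Go coalescing idiom, a buffered channel of capacity 1 plus a non-blocking send, so that any number of recovery events between two ListAndWatch wake-ups collapse into a single notification. A minimal standalone sketch of that idiom follows; the names notify and updates are hypothetical and not from the patch.

```go
package main

import "fmt"

// notify coalesces triggers into at most one pending signal
// (hypothetical standalone sketch of the deviceListUpdate idiom).
func notify(ch chan struct{}) {
	select {
	case ch <- struct{}{}: // signal queued
	default: // a signal is already pending; coalesce
	}
}

func main() {
	updates := make(chan struct{}, 1)

	// Many triggers in a burst...
	for i := 0; i < 5; i++ {
		notify(updates)
	}

	// ...are observed by the consumer as a single pending signal.
	pending := 0
	for {
		select {
		case <-updates:
			pending++
		default:
			fmt.Println("pending signals:", pending) // prints 1
			return
		}
	}
}
```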
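
Also illustrative, not part of the patch: sendUnhealthyDevice in health.go pairs a non-blocking send with a fallback path so the health-check goroutine can never be wedged by a stalled consumer. The sketch below shows only that shape with self-contained, hypothetical types and names (device, reportUnhealthy); it is not the plugin's implementation.

```go
package main

import "fmt"

type device struct {
	id      string
	healthy bool
}

// reportUnhealthy tries to hand the device to the consumer; if the buffered
// channel is full it falls back to mutating the device state directly so
// that the producer never blocks (hypothetical sketch of the
// sendUnhealthyDevice shape).
func reportUnhealthy(ch chan *device, d *device) {
	select {
	case ch <- d:
		// consumer will mark it unhealthy
	default:
		fmt.Printf("channel full (cap=%d); marking %s unhealthy directly\n", cap(ch), d.id)
		d.healthy = false
	}
}

func main() {
	ch := make(chan *device, 1)
	a := &device{id: "gpu-0", healthy: true}
	b := &device{id: "gpu-1", healthy: true}

	reportUnhealthy(ch, a) // fits in the buffer
	reportUnhealthy(ch, b) // buffer full: falls back to direct update
	fmt.Println("gpu-1 healthy:", b.healthy)
}
```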