diff --git a/pkg/scheduler/event.go b/pkg/scheduler/event.go index 21b407531..9a7a4b8e5 100644 --- a/pkg/scheduler/event.go +++ b/pkg/scheduler/event.go @@ -27,6 +27,7 @@ import ( clientgoscheme "k8s.io/client-go/kubernetes/scheme" v1core "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/tools/record" + "k8s.io/klog/v2" ) // Define events for ResourceBinding, ResourceFilter objects and their associated resources. @@ -57,6 +58,10 @@ func (s *Scheduler) recordScheduleBindingResultEvent(pod *corev1.Pod, eventReaso if pod == nil { return } + if s.eventRecorder == nil { + klog.Warning("eventRecorder is nil, skipping event creation") + return + } if schedulerErr == nil { successMsg := fmt.Sprintf("Successfully binding node %v to %v/%v", nodeResult, pod.Namespace, pod.Name) s.eventRecorder.Event(pod, corev1.EventTypeNormal, eventReason, successMsg) diff --git a/pkg/scheduler/event_test.go b/pkg/scheduler/event_test.go index 679a652f8..6c09cc678 100644 --- a/pkg/scheduler/event_test.go +++ b/pkg/scheduler/event_test.go @@ -204,3 +204,31 @@ func TestRecordScheduleFilterResultEvent(t *testing.T) { }) } } + +func TestRecordScheduleBindingResultEvent_NilRecorder(t *testing.T) { + // Initialize a scheduler with NO event recorder + s := &Scheduler{ + kubeClient: fake.NewSimpleClientset(), + eventRecorder: nil, + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"}, + } + + s.recordScheduleBindingResultEvent(pod, "BindingSucceed", []string{"node-1"}, nil) +} + +func TestRecordScheduleFilterResultEvent_NilRecorder(t *testing.T) { + // Initialize a scheduler with NO event recorder + s := &Scheduler{ + kubeClient: fake.NewSimpleClientset(), + eventRecorder: nil, + } + + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"}, + } + + s.recordScheduleFilterResultEvent(pod, "FilteringSucceed", "success", nil) +} diff --git a/pkg/scheduler/scheduler.go b/pkg/scheduler/scheduler.go index 
837dc13fd..3abd9c848 100644 --- a/pkg/scheduler/scheduler.go +++ b/pkg/scheduler/scheduler.go @@ -31,6 +31,7 @@ import ( corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes" @@ -650,6 +651,19 @@ func (s *Scheduler) getPodUsage() (map[string]device.PodUseDeviceStat, error) { func (s *Scheduler) Bind(args extenderv1.ExtenderBindingArgs) (*extenderv1.ExtenderBindingResult, error) { klog.InfoS("Attempting to bind pod to node", "pod", args.PodName, "namespace", args.PodNamespace, "node", args.Node) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + //In-Memory Lock (Queues concurrent pods instead of failing them) + release, ok := nodelockutil.AcquireBindLock(ctx, args.Node) + if !ok { + err := fmt.Errorf("timed out waiting for node bind lock: %s", args.Node) + klog.ErrorS(err, "Bind lock timeout") + return &extenderv1.ExtenderBindingResult{Error: err.Error()}, nil + } + defer release() + var res *extenderv1.ExtenderBindingResult binding := &corev1.Binding{ @@ -673,6 +687,34 @@ func (s *Scheduler) Bind(args extenderv1.ExtenderBindingArgs) (*extenderv1.Exten return res, nil } + err = wait.PollUntilContextTimeout(ctx, 200*time.Millisecond, 25*time.Second, true, func(pollCtx context.Context) (bool, error) { + liveNode, getErr := s.kubeClient.CoreV1().Nodes().Get(pollCtx, args.Node, metav1.GetOptions{}) + if getErr != nil { + return false, getErr + } + // If the node still has the lock annotation from a previous pod, keep waiting + if _, locked := liveNode.Annotations[nodelockutil.NodeLockKey]; locked { + return false, nil + } + return true, nil + }) + + if err != nil { + timeoutErr := fmt.Errorf("timed out waiting for device plugin to clear previous lock on node %s: %v", args.Node, err) + klog.ErrorS(timeoutErr, "Device plugin annotation lock 
timeout") + return &extenderv1.ExtenderBindingResult{Error: timeoutErr.Error()}, nil + } + + lockValue := nodelockutil.GenerateNodeLockKeyByPod(current) + patchData := fmt.Appendf(nil, `{"metadata":{"annotations":{"%s":"%s"}}}`, nodelockutil.NodeLockKey, lockValue) + + _, err = s.kubeClient.CoreV1().Nodes().Patch(ctx, args.Node, types.MergePatchType, patchData, metav1.PatchOptions{}) + if err != nil { + klog.ErrorS(err, "Failed to apply blind patch to node for mutex lock", "node", args.Node) + res = &extenderv1.ExtenderBindingResult{Error: err.Error()} + return res, nil // Return soft error so scheduler can retry + } + tmppatch := map[string]string{ util.DeviceBindPhase: "allocating", util.BindTimeAnnotations: strconv.FormatInt(time.Now().Unix(), 10), diff --git a/pkg/scheduler/scheduler_test.go b/pkg/scheduler/scheduler_test.go index cd47be0e8..18aaaf525 100644 --- a/pkg/scheduler/scheduler_test.go +++ b/pkg/scheduler/scheduler_test.go @@ -21,6 +21,7 @@ import ( "fmt" "maps" "strings" + "sync" "testing" "time" @@ -29,10 +30,12 @@ import ( corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/wait" "k8s.io/client-go/informers" "k8s.io/client-go/kubernetes/fake" "k8s.io/client-go/tools/cache" + "k8s.io/client-go/tools/record" "k8s.io/klog/v2" extenderv1 "k8s.io/kube-scheduler/extender/v1" @@ -1314,3 +1317,159 @@ func Test_Scheduler_Issue1368_TerminatingPodRetainsCache(t *testing.T) { _, ok = s.podManager.GetPod(terminatedPod) assert.Equal(t, false, ok, "Pod should be removed from cache after reaching a terminal phase (Succeeded/Failed)") } +func Test_Bind_ConcurrentPods_NoExponentialBackoff(t *testing.T) { + // 40+ pods hitting Bind() simultaneously caused LockNode to return errors + // for 39 of them, pushing them into kube-scheduler's exponential backoff queue. 
+ // The fix serialises Bind() calls per-node with an in-memory mutex so that + // concurrent pods wait (queue) rather than fail immediately. + const parallelism = 40 + const nodeName = "gpu-node-1" + + s := NewScheduler() + s.eventRecorder = record.NewFakeRecorder(100) + + client.KubeClient = fake.NewSimpleClientset() + s.kubeClient = client.KubeClient + + informerFactory := informers.NewSharedInformerFactoryWithOptions(client.KubeClient, time.Hour) + s.podLister = informerFactory.Core().V1().Pods().Lister() + s.nodeLister = informerFactory.Core().V1().Nodes().Lister() + informerFactory.Start(s.stopCh) + informerFactory.WaitForCacheSync(s.stopCh) + + sConfig := &config.Config{ + NvidiaConfig: nvidia.NvidiaConfig{ + ResourceCountName: "hami.io/gpu", + ResourceMemoryName: "hami.io/gpumem", + ResourceCoreName: "hami.io/gpucores", + DefaultGPUNum: 1, + }, + } + require.NoError(t, config.InitDevicesWithConfig(sConfig)) + + node := &corev1.Node{ + ObjectMeta: metav1.ObjectMeta{Name: nodeName}, + } + _, err := client.KubeClient.CoreV1().Nodes().Create( + context.Background(), node, metav1.CreateOptions{}, + ) + require.NoError(t, err) + err = informerFactory.Core().V1().Nodes().Informer().GetIndexer().Add(node) + require.NoError(t, err) + + pods := make([]*corev1.Pod, parallelism) + for i := range pods { + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: fmt.Sprintf("mnist-pod-%d", i), + Namespace: "default", + UID: types.UID(fmt.Sprintf("uid-%d", i)), + }, + Spec: corev1.PodSpec{ + Containers: []corev1.Container{{ + Name: "trainer", + Resources: corev1.ResourceRequirements{ + Limits: corev1.ResourceList{ + "hami.io/gpu": *resource.NewQuantity(1, resource.BinarySI), + "hami.io/gpumem": *resource.NewQuantity(1024, resource.BinarySI), + }, + }, + }}, + }, + } + _, err := client.KubeClient.CoreV1().Pods(pod.Namespace).Create( + context.Background(), pod, metav1.CreateOptions{}, + ) + require.NoError(t, err) + err = 
informerFactory.Core().V1().Pods().Informer().GetIndexer().Add(pod) + require.NoError(t, err) + pods[i] = pod + } + + var ( + wg sync.WaitGroup + mu sync.Mutex + backoffErrors int + bindErrors int + ) + + // Simulates the device plugin clearing the node lock annotation. + // This prevents the polling loop in Bind() from timing out. + ctxFakePlugin, stopFakePlugin := context.WithCancel(context.Background()) + defer stopFakePlugin() + + go func() { + ticker := time.NewTicker(10 * time.Millisecond) + defer ticker.Stop() + for { + select { + case <-ctxFakePlugin.Done(): + return + case <-ticker.C: + n, err := client.KubeClient.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{}) + if err == nil && n.Annotations != nil { + if _, locked := n.Annotations["hami.io/mutex.lock"]; locked { + // Lock found! Simulate device plugin finishing and clearing it. + delete(n.Annotations, "hami.io/mutex.lock") + _, _ = client.KubeClient.CoreV1().Nodes().Update(context.Background(), n, metav1.UpdateOptions{}) + } + } + } + } + }() + + start := make(chan struct{}) + + for i := range parallelism { + wg.Add(1) + go func(pod *corev1.Pod) { + defer wg.Done() + <-start + + result, err := s.Bind(extenderv1.ExtenderBindingArgs{ + PodName: pod.Name, + PodNamespace: pod.Namespace, + PodUID: pod.UID, + Node: nodeName, + }) + + mu.Lock() + defer mu.Unlock() + if err != nil { + backoffErrors++ + t.Logf("BACKOFF-TRIGGERING error for pod %s: %v", pod.Name, err) + } else if result != nil && result.Error != "" { + bindErrors++ + t.Logf("Soft bind error for pod %s: %s", pod.Name, result.Error) + } + }(pods[i]) + } + + close(start) + + // If the in-memory mutex fails to unlock, wg.Wait() will hang forever. + // This ensures the test fails with a clear message instead of timing out the CI pipeline. 
+ done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + + select { + case <-done: + // Allow up to 15s: the 40 pods bind sequentially behind the per-node lock + case <-time.After(15 * time.Second): + t.Fatal("FATAL: Test timed out! A deadlock occurred in the Bind() function's mutex logic.") + } + + // Stop the fake plugin since the test has completed + stopFakePlugin() + + require.Equal(t, 0, backoffErrors, + "BUG #1367: %d pods received non-nil errors from Bind(), "+ + "causing exponential backoff. In-memory mutex should prevent this.", + backoffErrors, + ) + + t.Logf("Concurrent Bind test: %d/%d pods had soft bind errors (acceptable), 0 had backoff-triggering errors", bindErrors, parallelism) +} diff --git a/pkg/util/nodelock/nodelock.go b/pkg/util/nodelock/nodelock.go index 41899fe8d..56a57ea08 100644 --- a/pkg/util/nodelock/nodelock.go +++ b/pkg/util/nodelock/nodelock.go @@ -54,6 +54,11 @@ var ( } ) +var ( + bindMu sync.Mutex + bindLocks = map[string]chan struct{}{} +) + // nodeLockManager manages locks on a per-node basis to allow concurrent // operations on different nodes while maintaining mutual exclusion for // operations on the same node. @@ -69,6 +74,25 @@ func newNodeLockManager() nodeLockManager { } } +// AcquireBindLock tries to acquire the per-node bind lock within ctx. +// Returns ok=false (and a no-op release) if ctx expires before the lock frees. +func AcquireBindLock(ctx context.Context, nodeName string) (release func(), ok bool) { + bindMu.Lock() + ch, exists := bindLocks[nodeName] + if !exists { + ch = make(chan struct{}, 1) + bindLocks[nodeName] = ch + } + bindMu.Unlock() + + select { + case ch <- struct{}{}: + return func() { <-ch }, true + case <-ctx.Done(): + return func() {}, false + } +} + // getLock returns the mutex for a specific node, creating it if necessary. // This method is thread-safe. 
func (m *nodeLockManager) getLock(nodeName string) *sync.Mutex { diff --git a/pkg/util/nodelock/nodelock_test.go b/pkg/util/nodelock/nodelock_test.go index 693f91b38..5357c5c95 100644 --- a/pkg/util/nodelock/nodelock_test.go +++ b/pkg/util/nodelock/nodelock_test.go @@ -18,6 +18,7 @@ package nodelock import ( "context" // Added for the new test + "os" "runtime" "strings" "testing" @@ -555,3 +556,133 @@ func TestSimulateRetryStorm(t *testing.T) { }) } } + +func TestAcquireBindLock(t *testing.T) { + nodeName := "test-node-1" + + ctx1, cancel1 := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel1() + + release, ok := AcquireBindLock(ctx1, nodeName) + if !ok { + t.Errorf("Expected to acquire lock, but failed") + } + release() + + ctx2, cancel2 := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel2() + + release2, ok2 := AcquireBindLock(ctx2, nodeName) + if !ok2 { + t.Fatalf("Expected to acquire lock for timeout test, but failed") + } + defer release2() + + ctx3, cancel3 := context.WithTimeout(context.Background(), 10*time.Millisecond) + defer cancel3() + + release3, ok3 := AcquireBindLock(ctx3, nodeName) + if ok3 { + t.Errorf("Expected lock acquisition to timeout and fail, but it succeeded") + } + release3() +} + +func TestParseNodeLock(t *testing.T) { + now := time.Now().Format(time.RFC3339) + + tests := []struct { + name string + value string + wantErr bool + expectedNs string + expectedPod string + }{ + { + name: "Legacy format without separator", + value: now, + wantErr: false, + expectedNs: "", + expectedPod: "", + }, + { + name: "Valid new format", + value: now + NodeLockSep + "default" + NodeLockSep + "my-pod", + wantErr: false, + expectedNs: "default", + expectedPod: "my-pod", + }, + { + name: "Malformed format with wrong number of parts", + value: now + NodeLockSep + "default", + wantErr: true, + }, + { + name: "Invalid time format", + value: "not-a-timestamp" + NodeLockSep + "default" + NodeLockSep + "my-pod", + 
wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, ns, podName, err := ParseNodeLock(tt.value) + if (err != nil) != tt.wantErr { + t.Errorf("ParseNodeLock() error = %v, wantErr %v", err, tt.wantErr) + } + if !tt.wantErr { + if ns != tt.expectedNs { + t.Errorf("ParseNodeLock() got ns = %v, want %v", ns, tt.expectedNs) + } + if podName != tt.expectedPod { + t.Errorf("ParseNodeLock() got podName = %v, want %v", podName, tt.expectedPod) + } + } + }) + } +} + +func TestSetupNodeLockTimeout(t *testing.T) { + originalTimeout := NodeLockTimeout + defer func() { + NodeLockTimeout = originalTimeout + os.Unsetenv("HAMI_NODELOCK_EXPIRE") + }() + + // Test valid duration + os.Setenv("HAMI_NODELOCK_EXPIRE", "10m") + setupNodeLockTimeout() + if NodeLockTimeout != 10*time.Minute { + t.Errorf("Expected timeout to be 10m, got %v", NodeLockTimeout) + } + + // Test invalid duration (should not crash, should retain previous/default value) + os.Setenv("HAMI_NODELOCK_EXPIRE", "invalid-duration") + setupNodeLockTimeout() + if NodeLockTimeout != 10*time.Minute { + t.Errorf("Expected timeout to remain 10m after invalid env var, got %v", NodeLockTimeout) + } +} + +func TestGenerateNodeLockKeyByPod(t *testing.T) { + // Test with nil pod + keyNil := GenerateNodeLockKeyByPod(nil) + if strings.Contains(keyNil, NodeLockSep) { + t.Errorf("Expected key for nil pod to not contain separator, got %v", keyNil) + } + + // Test with valid pod + pod := &corev1.Pod{ + ObjectMeta: metav1.ObjectMeta{ + Name: "test-pod", + Namespace: "test-ns", + }, + } + keyValid := GenerateNodeLockKeyByPod(pod) + if !strings.Contains(keyValid, NodeLockSep) { + t.Errorf("Expected key for valid pod to contain separator, got %v", keyValid) + } + if !strings.HasSuffix(keyValid, NodeLockSep+"test-ns"+NodeLockSep+"test-pod") { + t.Errorf("Expected key to end with namespace and pod name, got %v", keyValid) + } +}