Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pkg/scheduler/event.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
clientgoscheme "k8s.io/client-go/kubernetes/scheme"
v1core "k8s.io/client-go/kubernetes/typed/core/v1"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
)

// Define events for ResourceBinding, ResourceFilter objects and their associated resources.
Expand Down Expand Up @@ -57,6 +58,10 @@ func (s *Scheduler) recordScheduleBindingResultEvent(pod *corev1.Pod, eventReaso
if pod == nil {
return
}
if s.eventRecorder == nil {
klog.Warning("eventRecorder is nil, skipping event creation")
return
}
if schedulerErr == nil {
successMsg := fmt.Sprintf("Successfully binding node %v to %v/%v", nodeResult, pod.Namespace, pod.Name)
s.eventRecorder.Event(pod, corev1.EventTypeNormal, eventReason, successMsg)
Expand Down
28 changes: 28 additions & 0 deletions pkg/scheduler/event_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,31 @@ func TestRecordScheduleFilterResultEvent(t *testing.T) {
})
}
}

func TestRecordScheduleBindingResultEvent_NilRecorder(t *testing.T) {
// Initialize a scheduler with NO event recorder
s := &Scheduler{
kubeClient: fake.NewSimpleClientset(),
eventRecorder: nil,
}

pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"},
}

s.recordScheduleBindingResultEvent(pod, "BindingSucceed", []string{"node-1"}, nil)
}

func TestRecordScheduleFilterResultEvent_NilRecorder(t *testing.T) {
// Initialize a scheduler with NO event recorder
s := &Scheduler{
kubeClient: fake.NewSimpleClientset(),
eventRecorder: nil,
}

pod := &corev1.Pod{
ObjectMeta: metav1.ObjectMeta{Name: "test-pod", Namespace: "default"},
}

s.recordScheduleFilterResultEvent(pod, "FilteringSucceed", "success", nil)
}
42 changes: 42 additions & 0 deletions pkg/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/labels"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes"
Expand Down Expand Up @@ -650,6 +651,19 @@ func (s *Scheduler) getPodUsage() (map[string]device.PodUseDeviceStat, error) {

func (s *Scheduler) Bind(args extenderv1.ExtenderBindingArgs) (*extenderv1.ExtenderBindingResult, error) {
klog.InfoS("Attempting to bind pod to node", "pod", args.PodName, "namespace", args.PodNamespace, "node", args.Node)

ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()

//In-Memory Lock (Queues concurrent pods instead of failing them)
release, ok := nodelockutil.AcquireBindLock(ctx, args.Node)
if !ok {
err := fmt.Errorf("timed out waiting for node bind lock: %s", args.Node)
klog.ErrorS(err, "Bind lock timeout")
return &extenderv1.ExtenderBindingResult{Error: err.Error()}, nil
}
defer release()

var res *extenderv1.ExtenderBindingResult

binding := &corev1.Binding{
Expand All @@ -673,6 +687,34 @@ func (s *Scheduler) Bind(args extenderv1.ExtenderBindingArgs) (*extenderv1.Exten
return res, nil
}

err = wait.PollUntilContextTimeout(ctx, 200*time.Millisecond, 25*time.Second, true, func(pollCtx context.Context) (bool, error) {
liveNode, getErr := s.kubeClient.CoreV1().Nodes().Get(pollCtx, args.Node, metav1.GetOptions{})
if getErr != nil {
return false, getErr
}
// If the node still has the lock annotation from a previous pod, keep waiting
if _, locked := liveNode.Annotations[nodelockutil.NodeLockKey]; locked {
return false, nil
}
return true, nil
})

if err != nil {
timeoutErr := fmt.Errorf("timed out waiting for device plugin to clear previous lock on node %s: %v", args.Node, err)
klog.ErrorS(timeoutErr, "Device plugin annotation lock timeout")
return &extenderv1.ExtenderBindingResult{Error: timeoutErr.Error()}, nil
}

lockValue := nodelockutil.GenerateNodeLockKeyByPod(current)
patchData := fmt.Appendf(nil, `{"metadata":{"annotations":{"%s":"%s"}}}`, nodelockutil.NodeLockKey, lockValue)

_, err = s.kubeClient.CoreV1().Nodes().Patch(ctx, args.Node, types.MergePatchType, patchData, metav1.PatchOptions{})
if err != nil {
klog.ErrorS(err, "Failed to apply blind patch to node for mutex lock", "node", args.Node)
res = &extenderv1.ExtenderBindingResult{Error: err.Error()}
return res, nil // Return soft error so scheduler can retry
}

tmppatch := map[string]string{
util.DeviceBindPhase: "allocating",
util.BindTimeAnnotations: strconv.FormatInt(time.Now().Unix(), 10),
Expand Down
159 changes: 159 additions & 0 deletions pkg/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"fmt"
"maps"
"strings"
"sync"
"testing"
"time"

Expand All @@ -29,10 +30,12 @@ import (
corev1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/api/resource"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
"k8s.io/apimachinery/pkg/util/wait"
"k8s.io/client-go/informers"
"k8s.io/client-go/kubernetes/fake"
"k8s.io/client-go/tools/cache"
"k8s.io/client-go/tools/record"
"k8s.io/klog/v2"
extenderv1 "k8s.io/kube-scheduler/extender/v1"

Expand Down Expand Up @@ -1314,3 +1317,159 @@ func Test_Scheduler_Issue1368_TerminatingPodRetainsCache(t *testing.T) {
_, ok = s.podManager.GetPod(terminatedPod)
assert.Equal(t, false, ok, "Pod should be removed from cache after reaching a terminal phase (Succeeded/Failed)")
}
// Test_Bind_ConcurrentPods_NoExponentialBackoff is a regression test: many
// pods calling Bind() for the same node concurrently must queue on the
// in-memory per-node lock instead of receiving hard errors.
func Test_Bind_ConcurrentPods_NoExponentialBackoff(t *testing.T) {
	// 40+ pods hitting Bind() simultaneously caused LockNode to return errors
	// for 39 of them, pushing them into kube-scheduler's exponential backoff queue.
	// The fix serialises Bind() calls per-node with an in-memory mutex so that
	// concurrent pods wait (queue) rather than fail immediately.
	const parallelism = 40
	const nodeName = "gpu-node-1"

	s := NewScheduler()
	s.eventRecorder = record.NewFakeRecorder(100)

	// Back the scheduler with an in-memory fake clientset so Bind() can issue
	// node/pod API calls without a real cluster.
	client.KubeClient = fake.NewSimpleClientset()
	s.kubeClient = client.KubeClient

	informerFactory := informers.NewSharedInformerFactoryWithOptions(client.KubeClient, time.Hour)
	s.podLister = informerFactory.Core().V1().Pods().Lister()
	s.nodeLister = informerFactory.Core().V1().Nodes().Lister()
	informerFactory.Start(s.stopCh)
	informerFactory.WaitForCacheSync(s.stopCh)

	// Minimal device config so the pods' hami.io/* resource limits are
	// recognised by the scheduler.
	sConfig := &config.Config{
		NvidiaConfig: nvidia.NvidiaConfig{
			ResourceCountName:  "hami.io/gpu",
			ResourceMemoryName: "hami.io/gpumem",
			ResourceCoreName:   "hami.io/gpucores",
			DefaultGPUNum:      1,
		},
	}
	require.NoError(t, config.InitDevicesWithConfig(sConfig))

	node := &corev1.Node{
		ObjectMeta: metav1.ObjectMeta{Name: nodeName},
	}
	_, err := client.KubeClient.CoreV1().Nodes().Create(
		context.Background(), node, metav1.CreateOptions{},
	)
	require.NoError(t, err)
	// Prime the informer cache directly so the lister sees the node without
	// waiting for a watch/resync cycle.
	err = informerFactory.Core().V1().Nodes().Informer().GetIndexer().Add(node)
	require.NoError(t, err)

	pods := make([]*corev1.Pod, parallelism)
	for i := range pods {
		pod := &corev1.Pod{
			ObjectMeta: metav1.ObjectMeta{
				Name:      fmt.Sprintf("mnist-pod-%d", i),
				Namespace: "default",
				UID:       types.UID(fmt.Sprintf("uid-%d", i)),
			},
			Spec: corev1.PodSpec{
				Containers: []corev1.Container{{
					Name: "trainer",
					Resources: corev1.ResourceRequirements{
						Limits: corev1.ResourceList{
							"hami.io/gpu":    *resource.NewQuantity(1, resource.BinarySI),
							"hami.io/gpumem": *resource.NewQuantity(1024, resource.BinarySI),
						},
					},
				}},
			},
		}
		_, err := client.KubeClient.CoreV1().Pods(pod.Namespace).Create(
			context.Background(), pod, metav1.CreateOptions{},
		)
		require.NoError(t, err)
		err = informerFactory.Core().V1().Pods().Informer().GetIndexer().Add(pod)
		require.NoError(t, err)
		pods[i] = pod
	}

	var (
		wg sync.WaitGroup
		mu sync.Mutex
		// backoffErrors counts hard failures (err != nil) — exactly what
		// pushes pods into kube-scheduler's exponential backoff queue.
		backoffErrors int
		// bindErrors counts soft failures (result.Error set) — retried
		// without backoff, acceptable under contention.
		bindErrors int
	)

	// Simulates the device plugin clearing the node lock annotation.
	// This prevents the polling loop in Bind() from timing out.
	// NOTE(review): the annotation key is hard-coded here; presumably it
	// matches nodelockutil.NodeLockKey — confirm they stay in sync.
	ctxFakePlugin, stopFakePlugin := context.WithCancel(context.Background())
	defer stopFakePlugin()

	go func() {
		ticker := time.NewTicker(10 * time.Millisecond)
		defer ticker.Stop()
		for {
			select {
			case <-ctxFakePlugin.Done():
				return
			case <-ticker.C:
				n, err := client.KubeClient.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
				if err == nil && n.Annotations != nil {
					if _, locked := n.Annotations["hami.io/mutex.lock"]; locked {
						// Lock found! Simulate device plugin finishing and clearing it.
						// Update errors are deliberately ignored; the next tick retries.
						delete(n.Annotations, "hami.io/mutex.lock")
						_, _ = client.KubeClient.CoreV1().Nodes().Update(context.Background(), n, metav1.UpdateOptions{})
					}
				}
			}
		}
	}()

	// Start gate: all goroutines block here so Bind() is hit by every pod at
	// (nearly) the same instant when the channel is closed.
	start := make(chan struct{})

	for i := range parallelism {
		wg.Add(1)
		go func(pod *corev1.Pod) {
			defer wg.Done()
			<-start

			result, err := s.Bind(extenderv1.ExtenderBindingArgs{
				PodName:      pod.Name,
				PodNamespace: pod.Namespace,
				PodUID:       pod.UID,
				Node:         nodeName,
			})

			mu.Lock()
			defer mu.Unlock()
			if err != nil {
				backoffErrors++
				t.Logf("BACKOFF-TRIGGERING error for pod %s: %v", pod.Name, err)
			} else if result != nil && result.Error != "" {
				bindErrors++
				t.Logf("Soft bind error for pod %s: %s", pod.Name, result.Error)
			}
		}(pods[i])
	}

	close(start)

	// If the in-memory mutex fails to unlock, wg.Wait() will hang forever.
	// This ensures the test fails with a clear message instead of timing out the CI pipeline.
	done := make(chan struct{})
	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		// Increased timeout from 5s to 15s to account for the sequential wait of 40 pods
	case <-time.After(15 * time.Second):
		t.Fatal("FATAL: Test timed out! A deadlock occurred in the Bind() function's mutex logic.")
	}

	// Stop the fake plugin since the test has completed
	stopFakePlugin()

	require.Equal(t, 0, backoffErrors,
		"BUG #1367: %d pods received non-nil errors from Bind(), "+
			"causing exponential backoff. In-memory mutex should prevent this.",
		backoffErrors,
	)

	t.Logf("Concurrent Bind test: %d/%d pods had soft bind errors (acceptable), 0 had backoff-triggering errors", bindErrors, parallelism)
}
24 changes: 24 additions & 0 deletions pkg/util/nodelock/nodelock.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ var (
}
)

var (
bindMu sync.Mutex
bindLocks = map[string]chan struct{}{}
)

// nodeLockManager manages locks on a per-node basis to allow concurrent
// operations on different nodes while maintaining mutual exclusion for
// operations on the same node.
Expand All @@ -69,6 +74,25 @@ func newNodeLockManager() nodeLockManager {
}
}

// AcquireBindLock tries to acquire the per-node bind lock within ctx.
// It serialises callers that target the same node while letting different
// nodes proceed concurrently. On success it returns ok == true and a release
// function that frees the lock; release is idempotent, so calling it more
// than once is harmless. If ctx expires before the lock becomes available,
// it returns a no-op release and ok == false.
func AcquireBindLock(ctx context.Context, nodeName string) (release func(), ok bool) {
	// Lazily create the per-node semaphore channel under bindMu.
	// NOTE(review): entries are never removed from bindLocks, so the map
	// grows with the number of distinct node names seen; each entry is a
	// tiny buffered channel, so this is acceptable for cluster-sized sets.
	bindMu.Lock()
	ch, exists := bindLocks[nodeName]
	if !exists {
		ch = make(chan struct{}, 1)
		bindLocks[nodeName] = ch
	}
	bindMu.Unlock()

	// A buffered channel of capacity 1 acts as a binary semaphore:
	// sending a token acquires the lock, receiving it releases.
	select {
	case ch <- struct{}{}:
		// Guard release with sync.Once: without it, a second call to
		// release would drain a token now owned by a later acquirer (or
		// block forever on an empty channel).
		var once sync.Once
		return func() { once.Do(func() { <-ch }) }, true
	case <-ctx.Done():
		return func() {}, false
	}
}

// getLock returns the mutex for a specific node, creating it if necessary.
// This method is thread-safe.
func (m *nodeLockManager) getLock(nodeName string) *sync.Mutex {
Expand Down
Loading
Loading