From 39189f5eb14266a3138fd7d90821dabdf1347efb Mon Sep 17 00:00:00 2001 From: Francesco Torta <62566275+fra98@users.noreply.github.com> Date: Thu, 13 Nov 2025 12:44:59 +0100 Subject: [PATCH 1/2] Add option to override providerID through node annotation --- handler/handler.go | 12 ++++- handler/handler_test.go | 108 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 119 insertions(+), 1 deletion(-) diff --git a/handler/handler.go b/handler/handler.go index 5ebeaba..1f63257 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -14,11 +14,15 @@ import ( apitypes "k8s.io/apimachinery/pkg/types" "k8s.io/apimachinery/pkg/util/strategicpatch" "k8s.io/client-go/kubernetes" + "k8s.io/utils/ptr" "github.com/castai/spot-handler/castai" ) -const CastNodeIDLabel = "provisioner.cast.ai/node-id" +const ( + CastNodeIDLabel = "provisioner.cast.ai/node-id" + OverrideProviderIDAnnot = "provisioner.cast.ai/override-provider-id" +) const ( taintNodeDraining = "autoscaling.cast.ai/draining" @@ -141,6 +145,9 @@ func (g *SpotHandler) handleInterruption(ctx context.Context) error { if node.Spec.ProviderID != "" { req.ProviderID = &node.Spec.ProviderID } + if node.Annotations != nil && node.Annotations[OverrideProviderIDAnnot] != "" { + req.ProviderID = ptr.To(node.Annotations[OverrideProviderIDAnnot]) + } if err = g.castClient.SendCloudEvent(ctx, req); err != nil { return err } @@ -217,6 +224,9 @@ func (g *SpotHandler) handleRebalanceRecommendation(ctx context.Context) error { if node.Spec.ProviderID != "" { req.ProviderID = &node.Spec.ProviderID } + if node.Annotations != nil && node.Annotations[OverrideProviderIDAnnot] != "" { + req.ProviderID = ptr.To(node.Annotations[OverrideProviderIDAnnot]) + } return g.castClient.SendCloudEvent(ctx, req) } diff --git a/handler/handler_test.go b/handler/handler_test.go index f649b18..1b31846 100644 --- a/handler/handler_test.go +++ b/handler/handler_test.go @@ -309,6 +309,114 @@ func TestRunLoop(t *testing.T) { require.NoError(t, err) r.Equal(1, mothershipCalls) }) + + t.Run("override providerID in interruption event", func(t *testing.T) { + originalProviderID := "aws:///us-east-1a/i-1234567890abcdef0" + overrideProviderID := "aws:///us-east-1b/i-0987654321fedcba0" + nodeWithOverride := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + CastNodeIDLabel: castNodeID, + }, + Annotations: map[string]string{ + OverrideProviderIDAnnot: overrideProviderID, + }, + }, + Spec: v1.NodeSpec{ + Unschedulable: false, + ProviderID: originalProviderID, + }, + } + + mothershipCalls := 0 + castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) { + mothershipCalls++ + var req castai.CloudEventRequest + r.NoError(json.NewDecoder(re.Body).Decode(&req)) + r.Equal(castNodeID, req.NodeID) + r.NotNil(req.ProviderID, "ProviderID should not be nil") + r.Equal(overrideProviderID, *req.ProviderID, "ProviderID should use override annotation value") + w.WriteHeader(http.StatusOK) + })) + defer castS.Close() + + fakeApi := fake.NewSimpleClientset(nodeWithOverride) + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0") + r.NoError(err) + mockCastClient := castai.NewClient(log, castHttp, "test1") + + mockInterrupt := &mockInterruptChecker{interrupted: true} + handler := SpotHandler{ + pollWaitInterval: 100 * time.Millisecond, + metadataChecker: mockInterrupt, + castClient: mockCastClient, + nodeName: nodeName, + clientset: fakeApi, + log: log, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err = handler.Run(ctx) + require.NoError(t, err) + r.Equal(1, mothershipCalls) + }) + + t.Run("override providerID in rebalance recommendation event", func(t *testing.T) { + originalProviderID := "gce://my-project/us-central1-a/instance-123" + overrideProviderID := "gce://my-project/us-central1-b/instance-456" + nodeWithOverride := &v1.Node{ + ObjectMeta: metav1.ObjectMeta{ + Name: nodeName, + Labels: map[string]string{ + CastNodeIDLabel: castNodeID, + }, + Annotations: map[string]string{ + OverrideProviderIDAnnot: overrideProviderID, + }, + }, + Spec: v1.NodeSpec{ + Unschedulable: false, + ProviderID: originalProviderID, + }, + } + + mothershipCalls := 0 + castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) { + mothershipCalls++ + var req castai.CloudEventRequest + r.NoError(json.NewDecoder(re.Body).Decode(&req)) + r.Equal(castNodeID, req.NodeID) + r.NotNil(req.ProviderID, "ProviderID should not be nil") + r.Equal(overrideProviderID, *req.ProviderID, "ProviderID should use override annotation value") + w.WriteHeader(http.StatusOK) + })) + defer castS.Close() + + fakeApi := fake.NewSimpleClientset(nodeWithOverride) + castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0") + r.NoError(err) + mockCastClient := castai.NewClient(log, castHttp, "test1") + + mockRecommendation := &mockInterruptChecker{rebalanceRecommendation: true} + handler := SpotHandler{ + pollWaitInterval: 100 * time.Millisecond, + metadataChecker: mockRecommendation, + castClient: mockCastClient, + nodeName: nodeName, + clientset: fakeApi, + log: log, + } + + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + err = handler.Run(ctx) + require.NoError(t, err) + r.Equal(1, mothershipCalls) + }) } type mockInterruptChecker struct { From f817a5c13f8c5393fb0bd7e1aa18a64bef9d91c1 Mon Sep 17 00:00:00 2001 From: Francesco Torta <62566275+fra98@users.noreply.github.com> Date: Thu, 13 Nov 2025 15:51:03 +0100 Subject: [PATCH 2/2] Log values before sending to mothership --- handler/handler.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/handler/handler.go b/handler/handler.go index 1f63257..cc6233c 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -148,6 +148,7 @@ func (g *SpotHandler) handleInterruption(ctx context.Context) error { if node.Annotations != nil && node.Annotations[OverrideProviderIDAnnot] != "" { req.ProviderID = ptr.To(node.Annotations[OverrideProviderIDAnnot]) } + g.log.Infof("sending interruption cloud event to mothership: nodeID: %s, providerID: %s", req.NodeID, ptr.Deref(req.ProviderID, "")) if err = g.castClient.SendCloudEvent(ctx, req); err != nil { return err } @@ -228,5 +229,6 @@ func (g *SpotHandler) handleRebalanceRecommendation(ctx context.Context) error { req.ProviderID = ptr.To(node.Annotations[OverrideProviderIDAnnot]) } + g.log.Infof("sending rebalance recommendation cloud event to mothership: nodeID: %s, providerID: %s", req.NodeID, ptr.Deref(req.ProviderID, "")) return g.castClient.SendCloudEvent(ctx, req) }