diff --git a/config/config.go b/config/config.go
index d5e6818..4aa27b3 100644
--- a/config/config.go
+++ b/config/config.go
@@ -17,6 +17,7 @@ type Config struct {
 	LogLevel            int
 	PprofPort           int
 	PollIntervalSeconds int
+	Phase2Permissions   bool
 }
 
 var cfg *Config
@@ -40,6 +41,8 @@ func Get() Config {
 	_ = viper.BindEnv("pollintervalseconds", "POLL_INTERVAL_SECONDS")
 	_ = viper.BindEnv("pprofport", "PPROF_PORT")
 
+	_ = viper.BindEnv("phase2permissions", "PHASE2_PERMISSIONS")
+
 	cfg = &Config{}
 	if err := viper.Unmarshal(&cfg); err != nil {
 		panic(fmt.Errorf("parsing configuration: %v", err))
diff --git a/handler/handler.go b/handler/handler.go
index cc6233c..8d9aaeb 100644
--- a/handler/handler.go
+++ b/handler/handler.go
@@ -43,13 +43,14 @@ type MetadataChecker interface {
 }
 
 type SpotHandler struct {
-	castClient       castai.Client
-	clientset        kubernetes.Interface
-	metadataChecker  MetadataChecker
-	nodeName         string
-	pollWaitInterval time.Duration
-	log              logrus.FieldLogger
-	gracePeriod      time.Duration
+	castClient        castai.Client
+	clientset         kubernetes.Interface
+	metadataChecker   MetadataChecker
+	nodeName          string
+	pollWaitInterval  time.Duration
+	log               logrus.FieldLogger
+	gracePeriod       time.Duration
+	phase2Permissions bool
 }
 
 func NewSpotHandler(
@@ -59,15 +60,17 @@ func NewSpotHandler(
 	metadataChecker MetadataChecker,
 	pollWaitInterval time.Duration,
 	nodeName string,
+	phase2Permissions bool,
 ) *SpotHandler {
 	return &SpotHandler{
-		castClient:       castClient,
-		clientset:        clientset,
-		metadataChecker:  metadataChecker,
-		log:              log,
-		nodeName:         nodeName,
-		pollWaitInterval: pollWaitInterval,
-		gracePeriod:      30 * time.Second,
+		castClient:        castClient,
+		clientset:         clientset,
+		metadataChecker:   metadataChecker,
+		log:               log,
+		nodeName:          nodeName,
+		pollWaitInterval:  pollWaitInterval,
+		gracePeriod:       30 * time.Second,
+		phase2Permissions: phase2Permissions,
 	}
 }
 
@@ -153,7 +156,12 @@ func (g *SpotHandler) handleInterruption(ctx context.Context) error {
 		return err
 	}
 
-	return g.taintNode(ctx, node)
+	if g.phase2Permissions {
+		return g.taintNode(ctx, node)
+	}
+
+	g.log.Info("skipping node tainting, phase2 permissions not enabled")
+	return nil
 }
 
 func (g *SpotHandler) taintNode(ctx context.Context, node *v1.Node) error {
diff --git a/handler/handler_test.go b/handler/handler_test.go
index 1b31846..21a670a 100644
--- a/handler/handler_test.go
+++ b/handler/handler_test.go
@@ -57,12 +57,13 @@ func TestRunLoop(t *testing.T) {
 
 		mockInterrupt := &mockInterruptChecker{interrupted: true}
 		handler := SpotHandler{
-			pollWaitInterval: 100 * time.Millisecond,
-			metadataChecker:  mockInterrupt,
-			castClient:       mockCastClient,
-			nodeName:         nodeName,
-			clientset:        fakeApi,
-			log:              log,
+			pollWaitInterval:  100 * time.Millisecond,
+			metadataChecker:   mockInterrupt,
+			castClient:        mockCastClient,
+			nodeName:          nodeName,
+			clientset:         fakeApi,
+			log:               log,
+			phase2Permissions: true,
 		}
 
 		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
@@ -82,6 +83,61 @@ func TestRunLoop(t *testing.T) {
 		})
 	})
 
+	t.Run("do not taint node if not enough permissions", func(t *testing.T) {
+		mothershipCalls := 0
+		castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) {
+			mothershipCalls++
+			var req castai.CloudEventRequest
+			r.NoError(json.NewDecoder(re.Body).Decode(&req))
+			r.Equal(req.NodeID, castNodeID)
+			w.WriteHeader(http.StatusOK)
+		}))
+		defer castS.Close()
+
+		node2 := &v1.Node{
+			ObjectMeta: metav1.ObjectMeta{
+				Name: nodeName,
+				Labels: map[string]string{
+					CastNodeIDLabel: castNodeID,
+				},
+			},
+			Spec: v1.NodeSpec{
+				Unschedulable: false,
+			},
+		}
+		fakeApi := fake.NewSimpleClientset(node2)
+		castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0")
+		r.NoError(err)
+		mockCastClient := castai.NewClient(log, castHttp, "test2")
+
+		mockInterrupt := &mockInterruptChecker{interrupted: true}
+		handler := SpotHandler{
+			pollWaitInterval:  100 * time.Millisecond,
+			metadataChecker:   mockInterrupt,
+			castClient:        mockCastClient,
+			nodeName:          nodeName,
+			clientset:         fakeApi,
+			log:               log,
+			phase2Permissions: false,
+		}
+
+		ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+		defer cancel()
+
+		err = handler.Run(ctx)
+		require.NoError(t, err)
+		r.Equal(1, mothershipCalls)
+
+		node2, _ = fakeApi.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
+		r.Equal(false, node2.Spec.Unschedulable)
+		r.NotEqual(valueNodeDrainingReasonInterrupted, node2.Labels[labelNodeDraining])
+		r.NotContains(node2.Spec.Taints, v1.Taint{
+			Key:    taintNodeDraining,
+			Value:  valueTrue,
+			Effect: taintNodeDrainingEffect,
+		})
+	})
+
 	t.Run("keep checking interruption on context canceled", func(t *testing.T) {
 		mothershipCalls := 0
 		castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) {
diff --git a/main.go b/main.go
index ed02055..fb55993 100644
--- a/main.go
+++ b/main.go
@@ -92,6 +92,7 @@ func main() {
 		interruptChecker,
 		time.Duration(cfg.PollIntervalSeconds)*time.Second,
 		cfg.NodeName,
+		cfg.Phase2Permissions,
 	)
 
 	if cfg.PprofPort != 0 {
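
Note: with this patch the new behavior is driven entirely by the PHASE2_PERMISSIONS environment variable bound in config/config.go. Below is a minimal sketch of how the flag flows from the environment into the handler; the config import path is an assumption for illustration and may differ from the repository's actual module path.

    package main

    import (
        "fmt"
        "os"

        "github.com/castai/spot-handler/config" // assumed import path
    )

    func main() {
        // In a real deployment this is set on the pod spec; set here only to
        // illustrate the viper.BindEnv binding added in this patch.
        os.Setenv("PHASE2_PERMISSIONS", "true")

        // Get() unmarshals bound env vars into the Config struct, so the
        // string "true" is decoded into the Phase2Permissions bool field.
        cfg := config.Get()
        fmt.Println(cfg.Phase2Permissions) // true
    }

When the flag is false, handleInterruption still reports the interruption event (the new test asserts mothershipCalls equals 1) but skips taintNode, leaving the node schedulable and untainted.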