Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ type Config struct {
LogLevel int
PprofPort int
PollIntervalSeconds int
Phase2Permissions bool
}

var cfg *Config
Expand All @@ -40,6 +41,8 @@ func Get() Config {
_ = viper.BindEnv("pollintervalseconds", "POLL_INTERVAL_SECONDS")
_ = viper.BindEnv("pprofport", "PPROF_PORT")

_ = viper.BindEnv("phase2permissions", "PHASE2_PERMISSIONS")

cfg = &Config{}
if err := viper.Unmarshal(&cfg); err != nil {
panic(fmt.Errorf("parsing configuration: %v", err))
Expand Down
38 changes: 23 additions & 15 deletions handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ type MetadataChecker interface {
}

// SpotHandler holds the clients, node identity and settings used by Run and
// handleInterruption.
//
// NOTE(review): this listing is a rendered diff without +/- markers, so the
// field list appears twice: the pre-change lines first, then the re-aligned
// post-change lines. Presumably only the second run (the one that ends with
// phase2Permissions) is the current definition — confirm against the repository.
type SpotHandler struct {
castClient castai.Client
clientset kubernetes.Interface
metadataChecker MetadataChecker
nodeName string
pollWaitInterval time.Duration
log logrus.FieldLogger
gracePeriod time.Duration
castClient castai.Client
clientset kubernetes.Interface
metadataChecker MetadataChecker
nodeName string
pollWaitInterval time.Duration
log logrus.FieldLogger
gracePeriod time.Duration
// phase2Permissions gates node tainting in handleInterruption: when false,
// the taint step is skipped and an info message is logged instead.
phase2Permissions bool
}

func NewSpotHandler(
Expand All @@ -59,15 +60,17 @@ func NewSpotHandler(
metadataChecker MetadataChecker,
pollWaitInterval time.Duration,
nodeName string,
phase2Permissions bool,
) *SpotHandler {
return &SpotHandler{
castClient: castClient,
clientset: clientset,
metadataChecker: metadataChecker,
log: log,
nodeName: nodeName,
pollWaitInterval: pollWaitInterval,
gracePeriod: 30 * time.Second,
castClient: castClient,
clientset: clientset,
metadataChecker: metadataChecker,
log: log,
nodeName: nodeName,
pollWaitInterval: pollWaitInterval,
gracePeriod: 30 * time.Second,
phase2Permissions: phase2Permissions,
}
}

Expand Down Expand Up @@ -153,7 +156,12 @@ func (g *SpotHandler) handleInterruption(ctx context.Context) error {
return err
}

return g.taintNode(ctx, node)
if g.phase2Permissions {
return g.taintNode(ctx, node)
}

g.log.Info("skipping node tainting, phase2 permissions not enabled")
return nil
}

func (g *SpotHandler) taintNode(ctx context.Context, node *v1.Node) error {
Expand Down
68 changes: 62 additions & 6 deletions handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,13 @@ func TestRunLoop(t *testing.T) {

mockInterrupt := &mockInterruptChecker{interrupted: true}
handler := SpotHandler{
pollWaitInterval: 100 * time.Millisecond,
metadataChecker: mockInterrupt,
castClient: mockCastClient,
nodeName: nodeName,
clientset: fakeApi,
log: log,
pollWaitInterval: 100 * time.Millisecond,
metadataChecker: mockInterrupt,
castClient: mockCastClient,
nodeName: nodeName,
clientset: fakeApi,
log: log,
phase2Permissions: true,
}

ctx, cancel := context.WithTimeout(context.Background(), time.Second)
Expand All @@ -82,6 +83,61 @@ func TestRunLoop(t *testing.T) {
})
})

// With phase2Permissions disabled, an interruption must still produce exactly
// one cloud event request to the mock server, while the node stays schedulable
// and carries neither the draining label value nor the draining taint.
t.Run("do not taint node if not enough permissions", func(t *testing.T) {
	mothershipCalls := 0
	castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) {
		// NOTE(review): these assertions run on the HTTP server goroutine, not
		// the test goroutine. If r is a require.Assertions, a failure calls
		// FailNow off the test goroutine, which the testing package does not
		// support — confirm and consider assert/t.Errorf here.
		mothershipCalls++
		var req castai.CloudEventRequest
		r.NoError(json.NewDecoder(re.Body).Decode(&req))
		// Expected value goes first so failure output labels the values correctly.
		r.Equal(castNodeID, req.NodeID)
		w.WriteHeader(http.StatusOK)
	}))
	defer castS.Close()

	node2 := &v1.Node{
		ObjectMeta: metav1.ObjectMeta{
			Name: nodeName,
			Labels: map[string]string{
				CastNodeIDLabel: castNodeID,
			},
		},
		Spec: v1.NodeSpec{
			Unschedulable: false,
		},
	}
	fakeApi := fake.NewSimpleClientset(node2)
	castHttp, err := castai.NewRestyClient(castS.URL, "test", "", log.Level, 100*time.Millisecond, "0.0.0")
	r.NoError(err)
	mockCastClient := castai.NewClient(log, castHttp, "test2")

	mockInterrupt := &mockInterruptChecker{interrupted: true}
	handler := SpotHandler{
		pollWaitInterval:  100 * time.Millisecond,
		metadataChecker:   mockInterrupt,
		castClient:        mockCastClient,
		nodeName:          nodeName,
		clientset:         fakeApi,
		log:               log,
		phase2Permissions: false,
	}

	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	err = handler.Run(ctx)
	require.NoError(t, err)
	r.Equal(1, mothershipCalls)

	// Check the lookup error explicitly: a failed Get would otherwise leave
	// node2 nil and turn the assertions below into a nil-pointer panic.
	node2, err = fakeApi.CoreV1().Nodes().Get(context.Background(), nodeName, metav1.GetOptions{})
	r.NoError(err)
	r.False(node2.Spec.Unschedulable)
	r.NotEqual(valueNodeDrainingReasonInterrupted, node2.Labels[labelNodeDraining])
	r.NotContains(node2.Spec.Taints, v1.Taint{
		Key:    taintNodeDraining,
		Value:  valueTrue,
		Effect: taintNodeDrainingEffect,
	})
})

t.Run("keep checking interruption on context canceled", func(t *testing.T) {
mothershipCalls := 0
castS := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, re *http.Request) {
Expand Down
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ func main() {
interruptChecker,
time.Duration(cfg.PollIntervalSeconds)*time.Second,
cfg.NodeName,
cfg.Phase2Permissions,
)

if cfg.PprofPort != 0 {
Expand Down
Loading