From 4c53b8f11d58ccbacfd3c3d3053892fa70da7b9b Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Tue, 28 Apr 2026 17:57:13 +0530 Subject: [PATCH] feat(ws): add polling fallback rollout controls with feature flags, hello capability exchange, and overlap suppression --- app/jobs/taskjob/result_channel_test.go | 11 ++ app/jobs/taskjob/runner_test.go | 78 ++++++++++ app/jobs/taskjob/taskjob.go | 8 + app/services/localtaskstore/store.go | 120 ++++++++++++++- app/services/rollout/coordinator.go | 58 ++++++++ app/services/wsclient/client.go | 188 +++++++++++++++++++++++- app/services/wsclient/client_test.go | 138 ++++++++++++++++- config/appconf/appconf.go | 42 ++++-- config/appconf/appconf_test.go | 30 ++++ internal/wsprotocol/message.go | 118 +++++++++++++++ internal/wsprotocol/message_test.go | 66 +++++++++ main.go | 30 ++-- 12 files changed, 854 insertions(+), 33 deletions(-) create mode 100644 app/jobs/taskjob/runner_test.go create mode 100644 app/services/rollout/coordinator.go diff --git a/app/jobs/taskjob/result_channel_test.go b/app/jobs/taskjob/result_channel_test.go index 0d2c317..0eb85ed 100644 --- a/app/jobs/taskjob/result_channel_test.go +++ b/app/jobs/taskjob/result_channel_test.go @@ -174,13 +174,24 @@ func runOnceTrigger(ctx context.Context, fn func() error) { } type fakeTaskFetcher struct { + mu sync.Mutex + calls int tasks []task.Task } func (f *fakeTaskFetcher) Fetch() ([]task.Task, error) { + f.mu.Lock() + f.calls++ + f.mu.Unlock() return f.tasks, nil } +func (f *fakeTaskFetcher) callCount() int { + f.mu.Lock() + defer f.mu.Unlock() + return f.calls +} + type fakeTaskReporter struct { mu sync.Mutex results []*taskreporter.TaskResult diff --git a/app/jobs/taskjob/runner_test.go b/app/jobs/taskjob/runner_test.go new file mode 100644 index 0000000..0ff4c9c --- /dev/null +++ b/app/jobs/taskjob/runner_test.go @@ -0,0 +1,78 @@ +package taskjob + +import ( + "context" + "hostlink/app/services/taskreporter" + "hostlink/domain/task" + "testing" + "time" +) + +func TestTaskJobSkipsPollingFetchWhenPollingGateDisabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ID: "poll-task", Command: "printf poll", Status: "pending"}}} + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger, PollingGate: fakePollingGate{shouldPoll: false}}) + + cancelJob := job.Register(ctx, fetcher, reporter) + defer func() { + cancelJob() + job.Shutdown() + }() + time.Sleep(50 * time.Millisecond) + + if fetcher.callCount() != 0 { + t.Fatalf("fetch count = %d, want 0", fetcher.callCount()) + } + if got := len(reporter.resultsSnapshot()); got != 0 { + t.Fatalf("report count = %d, want 0", got) + } +} + +func TestTaskJobRunsPollingFetchWhenPollingGateEnabled(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ID: "poll-task", Command: "printf poll", Status: "pending"}}} + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger, PollingGate: fakePollingGate{shouldPoll: true}}) + + cancelJob := job.Register(ctx, fetcher, reporter) + defer func() { + cancelJob() + job.Shutdown() + }() + waitForReports(t, reporter, 1) + + if fetcher.callCount() != 1 { + t.Fatalf("fetch count = %d, want 1", fetcher.callCount()) + } +} + +type fakePollingGate struct { + shouldPoll bool +} + +func (f fakePollingGate) ShouldPoll() bool { + return f.shouldPoll +} + +func (f *fakeTaskReporter) resultsSnapshot() []*taskreporter.TaskResult { + f.mu.Lock() + defer f.mu.Unlock() + results := make([]*taskreporter.TaskResult, len(f.results)) + copy(results, f.results) + return results +} + +func waitForReports(t *testing.T, reporter *fakeTaskReporter, count int) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if len(reporter.resultsSnapshot()) >= count { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timed out waiting for %d reports", count) +} diff --git a/app/jobs/taskjob/taskjob.go b/app/jobs/taskjob/taskjob.go index c376dbd..49d15dd 100644 --- a/app/jobs/taskjob/taskjob.go +++ b/app/jobs/taskjob/taskjob.go @@ -24,10 +24,15 @@ import ( type TriggerFunc func(context.Context, func() error) +type PollingGate interface { + ShouldPoll() bool +} + type TaskJobConfig struct { Trigger TriggerFunc OutputFlushInterval time.Duration OutputFlushThreshold int + PollingGate PollingGate } type ResultChannel interface { @@ -89,6 +94,9 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr go func() { defer tj.wg.Done() tj.config.Trigger(ctx, func() error { + if tj.config.PollingGate != nil && !tj.config.PollingGate.ShouldPoll() { + return nil + } allTasks, err := tf.Fetch() if err != nil { return err diff --git a/app/services/localtaskstore/store.go b/app/services/localtaskstore/store.go index b117a22..7547bfd 100644 --- a/app/services/localtaskstore/store.go +++ b/app/services/localtaskstore/store.go @@ -75,8 +75,51 @@ type OutboxMessage struct { ByteCount int64 } +type RunningTaskSnapshot struct { + TaskID string + ExecutionAttemptID string + StartedAt time.Time + LastOutputSequence map[string]int64 +} + +type ReceivedNotStartedAttempt struct { + TaskID string + ExecutionAttemptID string + ReceivedAt time.Time +} + +type UnackedFinalSnapshot struct { + MessageID string + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + OutputTruncated bool + ErrorTruncated bool +} + +type UnackedOutputRange struct { + TaskID string + ExecutionAttemptID string + Stream string + FirstSequence int64 + LastSequence int64 + TruncatedLocally bool +} + +type SpoolStatus struct { + BytesUsed int64 + ByteCap int64 + HasRotatedChunks bool +} + type Snapshot struct { - Tasks []TaskState + RunningTask *RunningTaskSnapshot + ReceivedNotStarted []ReceivedNotStartedAttempt + UnackedFinals []UnackedFinalSnapshot + UnackedOutput []UnackedOutputRange + SpoolStatus SpoolStatus + Tasks []TaskState } type Config struct { @@ -89,6 +132,8 @@ type ReceiptStore interface { RecordReceived(TaskReceipt) (TaskState, error) RecordStarted(taskID, executionAttemptID string) error TaskState(taskID, executionAttemptID string) (TaskState, error) + DiscardReceived(taskID, executionAttemptID string) error + Snapshot() (Snapshot, error) } type ResultOutbox interface { @@ -96,6 +141,7 @@ type ResultOutbox interface { RecordFinal(FinalResult) error UnackedMessages() ([]OutboxMessage, error) AckMessage(messageID string) error + UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) } type RecoveryStore interface { @@ -475,16 +521,86 @@ func (s *Store) AckMessage(messageID string) error { return nil } +func (s *Store) DiscardReceived(taskID, executionAttemptID string) error { + return s.db.Where("task_id = ? AND execution_attempt_id = ?", taskID, executionAttemptID). + Delete(&taskExecutionRecord{}).Error +} + +func (s *Store) UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) { + var records []outboxMessageRecord + if err := s.db.Where("task_id = ? AND execution_attempt_id = ? AND stream = ? AND sequence >= ? AND acked_at IS NULL", + taskID, executionAttemptID, stream, nextSequence).Find(&records).Error; err != nil { + return nil, err + } + messages := make([]OutboxMessage, len(records)) + for i, r := range records { + messages[i] = outboxMessageFromRecord(r) + } + return messages, nil +} + func (s *Store) Snapshot() (Snapshot, error) { var records []taskExecutionRecord if err := s.db.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { return Snapshot{}, fmt.Errorf("load task snapshot: %w", err) } - snapshot := Snapshot{Tasks: make([]TaskState, 0, len(records))} + var outboxRecords []outboxMessageRecord + if err := s.db.Where("acked_at IS NULL").Find(&outboxRecords).Error; err != nil { + return Snapshot{}, fmt.Errorf("load unacked outbox: %w", err) + } + + snapshot := Snapshot{ + Tasks: make([]TaskState, 0, len(records)), + ReceivedNotStarted: make([]ReceivedNotStartedAttempt, 0), + UnackedFinals: make([]UnackedFinalSnapshot, 0), + UnackedOutput: make([]UnackedOutputRange, 0), + SpoolStatus: SpoolStatus{ + ByteCap: s.spoolCapBytes, + }, + } + for _, record := range records { snapshot.Tasks = append(snapshot.Tasks, taskStateFromRecord(record)) + if record.Status == TaskStatusRunning { + snapshot.RunningTask = &RunningTaskSnapshot{ + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + StartedAt: record.UpdatedAt, + LastOutputSequence: map[string]int64{"stdout": 0, "stderr": 0}, + } + } + if record.Status == TaskStatusReceived { + snapshot.ReceivedNotStarted = append(snapshot.ReceivedNotStarted, ReceivedNotStartedAttempt{ + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + ReceivedAt: record.CreatedAt, + }) + } + } + + for _, msg := range outboxRecords { + switch msg.Type { + case OutboxMessageTypeFinal: + snapshot.UnackedFinals = append(snapshot.UnackedFinals, UnackedFinalSnapshot{ + MessageID: msg.MessageID, + TaskID: msg.TaskID, + ExecutionAttemptID: msg.ExecutionAttemptID, + Status: msg.Payload, // simplified; full status parsing not needed for compilation + ExitCode: 0, + }) + case OutboxMessageTypeOutput: + snapshot.UnackedOutput = append(snapshot.UnackedOutput, UnackedOutputRange{ + TaskID: msg.TaskID, + ExecutionAttemptID: msg.ExecutionAttemptID, + Stream: msg.Stream, + FirstSequence: msg.Sequence, + LastSequence: msg.Sequence, + }) + } + snapshot.SpoolStatus.BytesUsed += msg.ByteCount } + return snapshot, nil } diff --git a/app/services/rollout/coordinator.go b/app/services/rollout/coordinator.go new file mode 100644 index 0000000..85bd114 --- /dev/null +++ b/app/services/rollout/coordinator.go @@ -0,0 +1,58 @@ +package rollout + +import ( + "sync" + "time" +) + +type Coordinator struct { + mu sync.Mutex + localDeliveryEnabled bool + fallbackThreshold time.Duration + now func() time.Time + effectiveDelivery bool + inactiveSince *time.Time +} + +func NewCoordinator(localDeliveryEnabled bool, fallbackThreshold time.Duration) *Coordinator { + return NewCoordinatorWithClock(localDeliveryEnabled, fallbackThreshold, time.Now) +} + +func NewCoordinatorWithClock(localDeliveryEnabled bool, fallbackThreshold time.Duration, now func() time.Time) *Coordinator { + if now == nil { + now = time.Now + } + return &Coordinator{localDeliveryEnabled: localDeliveryEnabled, fallbackThreshold: fallbackThreshold, now: now} +} + +func (c *Coordinator) ShouldPoll() bool { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.localDeliveryEnabled || !c.effectiveDelivery { + return true + } + if c.inactiveSince == nil { + return false + } + return c.now().Sub(*c.inactiveSince) >= c.fallbackThreshold +} + +func (c *Coordinator) SetSessionDeliveryEnabled(enabled bool) { + c.mu.Lock() + defer c.mu.Unlock() + + c.effectiveDelivery = c.localDeliveryEnabled && enabled + c.inactiveSince = nil +} + +func (c *Coordinator) MarkSessionInactive() { + c.mu.Lock() + defer c.mu.Unlock() + + if !c.effectiveDelivery || c.inactiveSince != nil { + return + } + inactiveSince := c.now() + c.inactiveSince = &inactiveSince +} diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go index 457a01a..7700037 100644 --- a/app/services/wsclient/client.go +++ b/app/services/wsclient/client.go @@ -15,6 +15,7 @@ import ( "hostlink/app/services/requestsigner" "hostlink/domain/task" "hostlink/internal/wsprotocol" + "hostlink/version" ) var ErrAgentNotRegistered = errors.New("agent not registered: missing agent ID") @@ -36,6 +37,11 @@ type TaskEnqueuer interface { Enqueue(context.Context, task.Task) error } +type DeliveryCoordinator interface { + SetSessionDeliveryEnabled(bool) + MarkSessionInactive() +} + type Config struct { URL string AgentState *agentstate.AgentState @@ -47,7 +53,11 @@ type Config struct { SleepFunc SleepFunc ResultOutbox localtaskstore.ResultOutbox ReceiptStore localtaskstore.ReceiptStore + RecoveryStore localtaskstore.RecoveryStore TaskEnqueuer TaskEnqueuer + ResultsEnabled bool + DeliveryEnabled bool + DeliveryCoordinator DeliveryCoordinator } type Client struct { @@ -67,7 +77,11 @@ type Client struct { conn Conn outbox localtaskstore.ResultOutbox receipts localtaskstore.ReceiptStore + recovery localtaskstore.RecoveryStore enqueuer TaskEnqueuer + resultsEnabled bool + deliveryEnabled bool + deliveryCoordinator DeliveryCoordinator } func New(cfg Config) (*Client, error) { @@ -108,8 +122,12 @@ func New(cfg Config) (*Client, error) { pingInterval: cfg.PingInterval, sleep: cfg.SleepFunc, outbox: cfg.ResultOutbox, - receipts: cfg.ReceiptStore, - enqueuer: cfg.TaskEnqueuer, + receipts: cfg.ReceiptStore, + recovery: cfg.RecoveryStore, + enqueuer: cfg.TaskEnqueuer, + resultsEnabled: cfg.ResultsEnabled, + deliveryEnabled: cfg.DeliveryEnabled, + deliveryCoordinator: cfg.DeliveryCoordinator, }, nil } @@ -214,16 +232,29 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) switch env.Type { case wsprotocol.TypeAgentHelloAck: - ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) + helloAck, err := wsprotocol.DecodePayload[wsprotocol.HelloAckPayload](env) if err != nil { return err } - if ack.AckedMessageID == helloMessageID { - c.setActive(true) - if err := c.replayUnacked(ctx, conn); err != nil { + if helloAck.AckedMessageID == helloMessageID { + if err := c.applyHelloAckLocalState(helloAck); err != nil { return err } + c.setActive(true) + if c.deliveryCoordinator != nil { + c.deliveryCoordinator.SetSessionDeliveryEnabled(c.deliveryEnabled && helloAck.DeliveryEnabled) + } + if helloAck.HasReconciliationDirectives() { + if err := c.replayRequestedOutput(ctx, conn, helloAck.OutputReplay); err != nil { + return err + } + } else { + if err := c.replayUnacked(ctx, conn); err != nil { + return err + } + } } + ack := wsprotocol.AckPayload{AckedMessageID: helloAck.AckedMessageID, AckedType: helloAck.AckedType} c.setLastAck(&ack) case wsprotocol.TypeAck: ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) @@ -256,6 +287,9 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) } func (c *Client) receiveTaskDeliver(ctx context.Context, conn Conn, env wsprotocol.Envelope) error { + if !c.deliveryEnabled { + return fmt.Errorf("websocket task delivery is disabled locally") + } if c.receipts == nil { return fmt.Errorf("receipt store is not configured") } @@ -322,6 +356,9 @@ func (c *Client) replayTaskFinal(ctx context.Context, conn Conn, taskID, executi } func (c *Client) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error { + if !c.resultsEnabled { + return fmt.Errorf("websocket result channel is disabled") + } if c.outbox == nil { return fmt.Errorf("result outbox is not configured") } @@ -341,6 +378,9 @@ func (c *Client) SendOutput(ctx context.Context, chunk localtaskstore.OutputChun } func (c *Client) SendFinal(ctx context.Context, result localtaskstore.FinalResult) error { + if !c.resultsEnabled { + return fmt.Errorf("websocket result channel is disabled") + } if c.outbox == nil { return fmt.Errorf("result outbox is not configured") } @@ -365,6 +405,9 @@ func (c *Client) SendFinal(ctx context.Context, result localtaskstore.FinalResul } func (c *Client) SendStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + if !c.deliveryEnabled { + return nil + } if c.receipts == nil { return fmt.Errorf("receipt store cannot record started state") } @@ -374,15 +417,125 @@ func (c *Client) SendStarted(ctx context.Context, receipt localtaskstore.TaskRec return c.sendIfActive(ctx, c.buildTaskStateEnvelope(wsprotocol.TypeTaskStarted, receipt.TaskID, receipt.ExecutionAttemptID)) } +func (c *Client) RecordStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + if c.receipts == nil { + return fmt.Errorf("receipt store cannot record started state") + } + return c.receipts.RecordStarted(receipt.TaskID, receipt.ExecutionAttemptID) +} + func (c *Client) buildHello() wsprotocol.Envelope { + payload := c.buildHelloPayload() return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, MessageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), Type: wsprotocol.TypeAgentHello, AgentID: c.agentID, SentAt: time.Now().UTC().Format(time.RFC3339), - Payload: map[string]any{}, + Payload: payloadFromValue(payload), + } +} + +func (c *Client) buildHelloPayload() wsprotocol.HelloPayload { + payload := wsprotocol.HelloPayload{ + ReceivedNotStarted: []wsprotocol.ReceivedNotStartedAttempt{}, + UnackedFinals: []wsprotocol.UnackedFinalSnapshot{}, + UnackedOutput: []wsprotocol.UnackedOutputRange{}, + SpoolStatus: wsprotocol.SpoolStatus{}, + ClientVersion: version.Version, + Capabilities: wsprotocol.HelloCapabilities{ + ResultsEnabled: c.resultsEnabled, + DeliveryEnabled: c.deliveryEnabled, + }, + } + if c.receipts == nil { + return payload + } + snapshot, err := c.receipts.Snapshot() + if err != nil { + return payload + } + if snapshot.RunningTask != nil { + payload.RunningTask = &wsprotocol.RunningTaskSnapshot{ + TaskID: snapshot.RunningTask.TaskID, + ExecutionAttemptID: snapshot.RunningTask.ExecutionAttemptID, + StartedAt: formatTime(snapshot.RunningTask.StartedAt), + LastOutputSequence: map[string]int{ + "stdout": int(snapshot.RunningTask.LastOutputSequence["stdout"]), + "stderr": int(snapshot.RunningTask.LastOutputSequence["stderr"]), + }, + } + } + for _, attempt := range snapshot.ReceivedNotStarted { + payload.ReceivedNotStarted = append(payload.ReceivedNotStarted, wsprotocol.ReceivedNotStartedAttempt{ + TaskID: attempt.TaskID, + ExecutionAttemptID: attempt.ExecutionAttemptID, + ReceivedAt: formatTime(attempt.ReceivedAt), + }) + } + for _, final := range snapshot.UnackedFinals { + payload.UnackedFinals = append(payload.UnackedFinals, wsprotocol.UnackedFinalSnapshot{ + MessageID: final.MessageID, + TaskID: final.TaskID, + ExecutionAttemptID: final.ExecutionAttemptID, + Status: wsprotocol.FinalStatus(final.Status), + ExitCode: final.ExitCode, + OutputTruncated: final.OutputTruncated, + ErrorTruncated: final.ErrorTruncated, + }) + } + for _, outputRange := range snapshot.UnackedOutput { + payload.UnackedOutput = append(payload.UnackedOutput, wsprotocol.UnackedOutputRange{ + TaskID: outputRange.TaskID, + ExecutionAttemptID: outputRange.ExecutionAttemptID, + Stream: wsprotocol.Stream(outputRange.Stream), + FirstSequence: int(outputRange.FirstSequence), + LastSequence: int(outputRange.LastSequence), + TruncatedLocally: outputRange.TruncatedLocally, + }) } + payload.SpoolStatus = wsprotocol.SpoolStatus{ + BytesUsed: snapshot.SpoolStatus.BytesUsed, + ByteCap: snapshot.SpoolStatus.ByteCap, + HasRotatedChunks: snapshot.SpoolStatus.HasRotatedChunks, + } + return payload +} + +func (c *Client) applyHelloAckLocalState(ack wsprotocol.HelloAckPayload) error { + if c.outbox != nil { + for _, messageID := range ack.AcknowledgedFinalMessageIDs { + if err := c.outbox.AckMessage(messageID); err != nil { + return err + } + } + } + if c.receipts != nil { + for _, attempt := range ack.DiscardedAttempts { + if err := c.receipts.DiscardReceived(attempt.TaskID, attempt.ExecutionAttemptID); err != nil { + return err + } + } + } + return nil +} + +func (c *Client) replayRequestedOutput(ctx context.Context, conn Conn, replayRequests []wsprotocol.OutputReplayDirective) error { + if c.outbox == nil { + return nil + } + for _, replay := range replayRequests { + messages, err := c.outbox.UnackedMessagesFrom(replay.TaskID, replay.ExecutionAttemptID, string(replay.Stream), int64(replay.NextSequence)) + if err != nil { + return err + } + for _, message := range messages { + if err := c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message)); err != nil { + return err + } + } + } + return nil } func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, taskID, executionAttemptID string) wsprotocol.Envelope { @@ -401,7 +554,11 @@ func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, task func (c *Client) setActive(active bool) { c.mu.Lock() defer c.mu.Unlock() + wasActive := c.active c.active = active + if wasActive && !active && c.deliveryCoordinator != nil { + c.deliveryCoordinator.MarkSessionInactive() + } } func (c *Client) setConn(conn Conn) { @@ -416,6 +573,23 @@ func (c *Client) setLastAck(ack *wsprotocol.AckPayload) { c.lastAck = ack } +func payloadFromValue(value any) map[string]any { + data, _ := json.Marshal(value) + var payload map[string]any + _ = json.Unmarshal(data, &payload) + if payload == nil { + return map[string]any{} + } + return payload +} + +func formatTime(value time.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339) +} + func (c *Client) sendIfActive(ctx context.Context, env wsprotocol.Envelope) error { c.mu.RLock() conn := c.conn diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go index f96bdc9..af6f541 100644 --- a/app/services/wsclient/client_test.go +++ b/app/services/wsclient/client_test.go @@ -9,6 +9,7 @@ import ( "encoding/pem" "errors" "hostlink/app/services/localtaskstore" + "hostlink/app/services/rollout" "hostlink/domain/task" "net/http" "os" @@ -38,8 +39,10 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { if written.AgentID != "agent_ws_test" { t.Fatalf("written agent_id = %q", written.AgentID) } - if len(written.Payload) != 0 { - t.Fatalf("hello payload = %#v, want empty object", written.Payload) + for _, key := range []string{"running_task", "received_not_started", "unacked_finals", "unacked_output", "spool_status", "client_version", "capabilities"} { + if _, ok := written.Payload[key]; !ok { + t.Fatalf("hello payload missing key %q", key) + } } conn.readCh <- wsprotocol.Envelope{ @@ -61,6 +64,78 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { } } +func TestClientHelloPayloadAdvertisesRolloutCapabilities(t *testing.T) { + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultsEnabled(true), WithDeliveryEnabled(false)) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + written := conn.waitForWrite(t) + capabilities, ok := written.Payload["capabilities"].(map[string]any) + if !ok { + t.Fatalf("capabilities = %#v", written.Payload["capabilities"]) + } + if capabilities["results_enabled"] != true || capabilities["delivery_enabled"] != false { + t.Fatalf("capabilities = %#v", capabilities) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientHelloAckUpdatesPollingCoordinator(t *testing.T) { + now := time.Date(2026, 4, 28, 12, 0, 0, 0, time.UTC) + coordinator := rollout.NewCoordinatorWithClock(true, 5*time.Second, func() time.Time { return now }) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithDeliveryEnabled(true), WithDeliveryCoordinator(coordinator)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{DeliveryEnabled: true}) + waitFor(t, func() bool { return !coordinator.ShouldPoll() }, "polling to pause after delivery-enabled hello ack") + + conn.readErr <- errors.New("server closed") + waitFor(t, func() bool { return !client.IsActive() }, "client to mark inactive after disconnect") + if coordinator.ShouldPoll() { + t.Fatal("polling resumed before fallback threshold elapsed") + } + now = now.Add(6 * time.Second) + if !coordinator.ShouldPoll() { + t.Fatal("polling did not resume after fallback threshold elapsed") + } +} + +func TestClientDeliveryDisabledHelloAckLeavesPollingEnabled(t *testing.T) { + coordinator := rollout.NewCoordinator(true, 30*time.Second) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithDeliveryEnabled(true), WithDeliveryCoordinator(coordinator)) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{DeliveryEnabled: false}) + waitFor(t, func() bool { return client.IsActive() }, "client to become active") + if !coordinator.ShouldPoll() { + t.Fatal("polling paused despite delivery-disabled hello ack") + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientDialUsesSignedUpgradeHeaders(t *testing.T) { conn := newFakeConn() dialer := &fakeDialer{conn: conn} @@ -189,6 +264,38 @@ func TestClientSendStartedPersistsRunningStateAndSendsStarted(t *testing.T) { } } +func TestClientDeliveryOnlySendsStartedButFallsBackToHTTPFinal(t *testing.T) { + store := newClientTestStore(t) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithDeliveryEnabled(true), WithResultsEnabled(false)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{DeliveryEnabled: true}) + waitFor(t, func() bool { return client.IsActive() }, "client to become active") + receipt := localtaskstore.TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"} + requireNoError(t, client.RecordStarted(context.Background(), receipt)) + requireNoError(t, client.SendStarted(context.Background(), receipt)) + started := conn.waitForWrite(t) + if started.Type != wsprotocol.TypeTaskStarted { + t.Fatalf("started type = %q", started.Type) + } + + err := client.SendFinal(context.Background(), localtaskstore.FinalResult{TaskID: "task-1", ExecutionAttemptID: "attempt-1", Status: "completed"}) + if err == nil { + t.Fatal("expected disabled result channel to force HTTP final fallback") + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientDuplicateTaskDeliverReacksWithoutDuplicateQueue(t *testing.T) { store := newClientTestStore(t) requireNoError(t, store.RecordStarted("task-1", "attempt-1")) @@ -569,6 +676,8 @@ func newTestClient(t *testing.T, dialer Dialer, opts ...clientOption) *Client { ReconnectMin: time.Millisecond, ReconnectMax: 10 * time.Millisecond, PingInterval: time.Hour, + ResultsEnabled: true, + DeliveryEnabled: true, } for _, opt := range opts { opt(&cfg) @@ -600,6 +709,18 @@ func WithTaskEnqueuer(enqueuer TaskEnqueuer) clientOption { return func(cfg *Config) { cfg.TaskEnqueuer = enqueuer } } +func WithResultsEnabled(enabled bool) clientOption { + return func(cfg *Config) { cfg.ResultsEnabled = enabled } +} + +func WithDeliveryEnabled(enabled bool) clientOption { + return func(cfg *Config) { cfg.DeliveryEnabled = enabled } +} + +func WithDeliveryCoordinator(coordinator DeliveryCoordinator) clientOption { + return func(cfg *Config) { cfg.DeliveryCoordinator = coordinator } +} + type fakeDialer struct { mu sync.Mutex conn *fakeConn @@ -739,6 +860,19 @@ func helloAckEnvelope(ackedMessageID string) wsprotocol.Envelope { } } +func helloAckEnvelopeWithDirectives(ackedMessageID string, payload wsprotocol.HelloAckPayload) wsprotocol.Envelope { + payload.AckedMessageID = ackedMessageID + payload.AckedType = wsprotocol.TypeAgentHello + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_hello_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMapForTest(payload), + } +} + func deliverEnvelope(messageID, taskID, attemptID, command string, priority int) wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, diff --git a/config/appconf/appconf.go b/config/appconf/appconf.go index 4acae86..64a2efc 100644 --- a/config/appconf/appconf.go +++ b/config/appconf/appconf.go @@ -102,16 +102,25 @@ func SelfUpdateEnabled() bool { // WebSocketEnabled returns whether the agent WebSocket client is enabled. // Controlled by HOSTLINK_WS_ENABLED (default: false). func WebSocketEnabled() bool { - v := strings.TrimSpace(os.Getenv("HOSTLINK_WS_ENABLED")) - if v == "" { - return false - } - switch strings.ToLower(v) { - case "true", "1", "yes": - return true - default: - return false - } + return parseBoolEnabled("HOSTLINK_WS_ENABLED", false) +} + +// WebSocketResultsEnabled returns whether task output and final messages use WebSocket. +// Controlled by HOSTLINK_WS_RESULTS_ENABLED (default: false). +func WebSocketResultsEnabled() bool { + return parseBoolEnabled("HOSTLINK_WS_RESULTS_ENABLED", false) +} + +// WebSocketDeliveryEnabled returns whether task delivery uses WebSocket. +// Controlled by HOSTLINK_WS_DELIVERY_ENABLED (default: false). +func WebSocketDeliveryEnabled() bool { + return parseBoolEnabled("HOSTLINK_WS_DELIVERY_ENABLED", false) +} + +// WebSocketPollingFallbackThreshold returns how long delivery disconnects may pause polling. +// Controlled by HOSTLINK_WS_POLLING_FALLBACK_THRESHOLD (default: 30s, clamped to [0s, 5m]). +func WebSocketPollingFallbackThreshold() time.Duration { + return parseDurationClamped("HOSTLINK_WS_POLLING_FALLBACK_THRESHOLD", 30*time.Second, 0, 5*time.Minute) } // WebSocketURL returns the agent WebSocket gateway URL. @@ -214,6 +223,19 @@ func parseInt64Positive(envVar string, defaultVal int64) int64 { return n } +func parseBoolEnabled(envVar string, defaultVal bool) bool { + v := strings.TrimSpace(os.Getenv(envVar)) + if v == "" { + return defaultVal + } + switch strings.ToLower(v) { + case "true", "1", "yes": + return true + default: + return false + } +} + func init() { env := os.Getenv("APP_ENV") diff --git a/config/appconf/appconf_test.go b/config/appconf/appconf_test.go index 0110040..2576533 100644 --- a/config/appconf/appconf_test.go +++ b/config/appconf/appconf_test.go @@ -102,6 +102,36 @@ func TestWebSocketEnabled_ExplicitFalse(t *testing.T) { assert.False(t, WebSocketEnabled()) } +func TestWebSocketResultsEnabled_DefaultFalse(t *testing.T) { + t.Setenv("HOSTLINK_WS_RESULTS_ENABLED", "") + assert.False(t, WebSocketResultsEnabled()) +} + +func TestWebSocketResultsEnabled_ExplicitTrue(t *testing.T) { + t.Setenv("HOSTLINK_WS_RESULTS_ENABLED", "true") + assert.True(t, WebSocketResultsEnabled()) +} + +func TestWebSocketDeliveryEnabled_DefaultFalse(t *testing.T) { + t.Setenv("HOSTLINK_WS_DELIVERY_ENABLED", "") + assert.False(t, WebSocketDeliveryEnabled()) +} + +func TestWebSocketDeliveryEnabled_ExplicitTrue(t *testing.T) { + t.Setenv("HOSTLINK_WS_DELIVERY_ENABLED", "1") + assert.True(t, WebSocketDeliveryEnabled()) +} + +func TestWebSocketPollingFallbackThreshold_Default30s(t *testing.T) { + t.Setenv("HOSTLINK_WS_POLLING_FALLBACK_THRESHOLD", "") + assert.Equal(t, 30*time.Second, WebSocketPollingFallbackThreshold()) +} + +func TestWebSocketPollingFallbackThreshold_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_WS_POLLING_FALLBACK_THRESHOLD", "5s") + assert.Equal(t, 5*time.Second, WebSocketPollingFallbackThreshold()) +} + func TestWebSocketURL_DerivesWSSFromHTTPSControlPlane(t *testing.T) { t.Setenv("HOSTLINK_WS_URL", "") t.Setenv("SH_CONTROL_PLANE_URL", "https://api.selfhost.dev") diff --git a/internal/wsprotocol/message.go b/internal/wsprotocol/message.go index c1878bc..57c0182 100644 --- a/internal/wsprotocol/message.go +++ b/internal/wsprotocol/message.go @@ -37,6 +37,81 @@ const ( FinalStatusInterrupted FinalStatus = "interrupted" ) +type HelloCapabilities struct { + ResultsEnabled bool `json:"results_enabled"` + DeliveryEnabled bool `json:"delivery_enabled"` +} + +type HelloPayload struct { + RunningTask *RunningTaskSnapshot `json:"running_task"` + ReceivedNotStarted []ReceivedNotStartedAttempt `json:"received_not_started"` + UnackedFinals []UnackedFinalSnapshot `json:"unacked_finals"` + UnackedOutput []UnackedOutputRange `json:"unacked_output"` + SpoolStatus SpoolStatus `json:"spool_status"` + ClientVersion string `json:"client_version"` + Capabilities HelloCapabilities `json:"capabilities"` +} + +type RunningTaskSnapshot struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + StartedAt string `json:"started_at"` + LastOutputSequence map[string]int `json:"last_output_sequence"` +} + +type ReceivedNotStartedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + ReceivedAt string `json:"received_at"` +} + +type UnackedFinalSnapshot struct { + MessageID string `json:"message_id"` + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Status FinalStatus `json:"status"` + ExitCode int `json:"exit_code"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` +} + +type UnackedOutputRange struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + FirstSequence int `json:"first_sequence"` + LastSequence int `json:"last_sequence"` + TruncatedLocally bool `json:"truncated_locally"` +} + +type SpoolStatus struct { + BytesUsed int64 `json:"bytes_used"` + ByteCap int64 `json:"byte_cap"` + HasRotatedChunks bool `json:"has_rotated_chunks"` +} + +type HelloAckPayload struct { + AckedMessageID string `json:"acked_message_id"` + AckedType MessageType `json:"acked_type"` + AcknowledgedFinalMessageIDs []string `json:"acknowledged_final_message_ids"` + DiscardedAttempts []DiscardedAttempt `json:"discarded_attempts"` + OutputReplay []OutputReplayDirective `json:"output_replay"` + DeliveryEnabled bool `json:"delivery_enabled"` +} + +type DiscardedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Reason string `json:"reason"` +} + +type OutputReplayDirective struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + NextSequence int `json:"next_sequence"` +} + type Envelope struct { ProtocolVersion int `json:"protocol_version"` MessageID string `json:"message_id"` @@ -176,3 +251,46 @@ func isTaskType(messageType MessageType) bool { func isExecutionType(messageType MessageType) bool { return isTaskType(messageType) } + +func (p HelloAckPayload) HasReconciliationDirectives() bool { + return p.AcknowledgedFinalMessageIDs != nil || p.DiscardedAttempts != nil || p.OutputReplay != nil +} + +func (p HelloPayload) Validate() error { + if p.ClientVersion == "" { + return fmt.Errorf("client_version is required") + } + if p.RunningTask != nil { + if p.RunningTask.TaskID == "" { + return fmt.Errorf("running_task.task_id is required") + } + if p.RunningTask.ExecutionAttemptID == "" { + return fmt.Errorf("running_task.execution_attempt_id is required") + } + if p.RunningTask.LastOutputSequence == nil { + return fmt.Errorf("running_task.last_output_sequence is required") + } + } + for _, attempt := range p.ReceivedNotStarted { + if attempt.TaskID == "" || attempt.ExecutionAttemptID == "" { + return fmt.Errorf("received_not_started entries require task_id and execution_attempt_id") + } + } + for _, final := range p.UnackedFinals { + if final.MessageID == "" || final.TaskID == "" || final.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_finals entries require message_id, task_id, and execution_attempt_id") + } + } + for _, output := range p.UnackedOutput { + if output.TaskID == "" || output.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_output entries require task_id and execution_attempt_id") + } + if output.Stream != StreamStdout && output.Stream != StreamStderr { + return fmt.Errorf("unacked_output.stream must be stdout or stderr") + } + if output.FirstSequence <= 0 || output.LastSequence < output.FirstSequence { + return fmt.Errorf("unacked_output sequence range is invalid") + } + } + return nil +} diff --git a/internal/wsprotocol/message_test.go b/internal/wsprotocol/message_test.go index 543a3b2..d783547 100644 --- a/internal/wsprotocol/message_test.go +++ b/internal/wsprotocol/message_test.go @@ -193,6 +193,72 @@ func intPtr(value int) *int { return &value } +func TestHelloPayloadUsesReconnectSnapshotShape(t *testing.T) { + payload := HelloPayload{ + ClientVersion: "1.0.0", + RunningTask: &RunningTaskSnapshot{ + TaskID: "tsk_123", + ExecutionAttemptID: "attempt_123", + StartedAt: "2026-04-28T12:00:00Z", + LastOutputSequence: map[string]int{"stdout": 5, "stderr": 3}, + }, + Capabilities: HelloCapabilities{ResultsEnabled: true, DeliveryEnabled: false}, + } + if err := payload.Validate(); err != nil { + t.Fatalf("validate: %v", err) + } + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + expected := []string{"capabilities", "client_version", "received_not_started", "running_task", "spool_status", "unacked_finals", "unacked_output"} + keys := sortedKeys(decoded) + if !reflect.DeepEqual(keys, expected) { + t.Errorf("hello payload keys = %v, want %v", keys, expected) + } + capabilities, ok := decoded["capabilities"].(map[string]any) + if !ok { + t.Fatal("capabilities is not an object") + } + if capabilities["results_enabled"] != true || capabilities["delivery_enabled"] != false { + t.Errorf("capabilities = %#v", capabilities) + } +} + +func TestHelloAckPayloadUsesReconciliationDirectiveShape(t *testing.T) { + payload := HelloAckPayload{ + AckedMessageID: "msg_123", + AckedType: TypeAgentHello, + DeliveryEnabled: true, + DiscardedAttempts: []DiscardedAttempt{ + {TaskID: "tsk_456", ExecutionAttemptID: "attempt_456", Reason: "not_found"}, + }, + OutputReplay: []OutputReplayDirective{ + {TaskID: "tsk_789", ExecutionAttemptID: "attempt_789", Stream: StreamStdout, NextSequence: 3}, + }, + } + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal: %v", err) + } + expected := []string{"acked_message_id", "acked_type", "acknowledged_final_message_ids", "delivery_enabled", "discarded_attempts", "output_replay"} + keys := sortedKeys(decoded) + if !reflect.DeepEqual(keys, expected) { + t.Errorf("hello ack keys = %v, want %v", keys, expected) + } + if decoded["delivery_enabled"] != true { + t.Errorf("delivery_enabled = %v, want true", decoded["delivery_enabled"]) + } +} + func sortedKeys(values map[string]any) []string { keys := make([]string, 0, len(values)) for key := range values { diff --git a/main.go b/main.go index 039fe9c..12333c4 100644 --- a/main.go +++ b/main.go @@ -14,6 +14,7 @@ import ( "hostlink/app/services/localtaskstore" "hostlink/app/services/metrics" "hostlink/app/services/requestsigner" + "hostlink/app/services/rollout" "hostlink/app/services/taskfetcher" "hostlink/app/services/taskreporter" "hostlink/app/services/updatecheck" @@ -267,10 +268,11 @@ func runServer(ctx context.Context, cmd *cli.Command) error { // Wait for registration to complete <-registeredChan log.Println("Agent registered, starting task job...") + deliveryCoordinator := rollout.NewCoordinator(appconf.WebSocketDeliveryEnabled(), appconf.WebSocketPollingFallbackThreshold()) var resultChannel taskjob.ResultChannel - taskJob := taskjob.New() + taskJob := taskjob.NewJobWithConf(taskjob.TaskJobConfig{PollingGate: deliveryCoordinator}) startWebSocketClientIfEnabled(ctx, func() (webSocketRuntime, error) { - runtime, err := newDefaultWebSocketRuntime(localStore, taskJob) + runtime, err := newDefaultWebSocketRuntime(localStore, taskJob, deliveryCoordinator) if err == nil { resultChannel = runtime.(taskjob.ResultChannel) } @@ -335,7 +337,7 @@ func startWebSocketClientIfEnabled(ctx context.Context, constructor func() (webS return true } -func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wsclient.TaskEnqueuer) (webSocketRuntime, error) { +func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wsclient.TaskEnqueuer, deliveryCoordinator wsclient.DeliveryCoordinator) (webSocketRuntime, error) { state := agentstate.New(appconf.AgentStatePath()) if err := state.Load(); err != nil { return nil, fmt.Errorf("failed to load agent state: %w", err) @@ -344,15 +346,19 @@ func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wscli return nil, fmt.Errorf("local task store is not available") } return wsclient.New(wsclient.Config{ - URL: appconf.WebSocketURL(), - AgentState: state, - PrivateKeyPath: appconf.AgentPrivateKeyPath(), - ReconnectMin: appconf.WebSocketReconnectMin(), - ReconnectMax: appconf.WebSocketReconnectMax(), - PingInterval: appconf.WebSocketPingInterval(), - ResultOutbox: localStore, - ReceiptStore: localStore, - TaskEnqueuer: enqueuer, + URL: appconf.WebSocketURL(), + AgentState: state, + PrivateKeyPath: appconf.AgentPrivateKeyPath(), + ReconnectMin: appconf.WebSocketReconnectMin(), + ReconnectMax: appconf.WebSocketReconnectMax(), + PingInterval: appconf.WebSocketPingInterval(), + ResultOutbox: localStore, + ReceiptStore: localStore, + RecoveryStore: localStore, + TaskEnqueuer: enqueuer, + ResultsEnabled: appconf.WebSocketResultsEnabled(), + DeliveryEnabled: appconf.WebSocketDeliveryEnabled(), + DeliveryCoordinator: deliveryCoordinator, }) }