From 6d84b04bba3b8afc62eb2bef87d910056785f4f8 Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Tue, 28 Apr 2026 13:50:44 +0530 Subject: [PATCH] feat(ws): add reconnect reconciliation with agent snapshot and hello-ack directives --- app/services/localtaskstore/execution_test.go | 16 ++ app/services/localtaskstore/recovery_test.go | 54 ++++ app/services/localtaskstore/store.go | 234 +++++++++++++++++- app/services/wsclient/client.go | 140 ++++++++++- app/services/wsclient/client_test.go | 186 +++++++++++++- internal/wsprotocol/message.go | 111 +++++++++ internal/wsprotocol/message_test.go | 63 +++++ main.go | 1 + 8 files changed, 791 insertions(+), 14 deletions(-) diff --git a/app/services/localtaskstore/execution_test.go b/app/services/localtaskstore/execution_test.go index 64a593a..8208969 100644 --- a/app/services/localtaskstore/execution_test.go +++ b/app/services/localtaskstore/execution_test.go @@ -51,6 +51,22 @@ func TestTaskStateTreatsNewAttemptAsDistinct(t *testing.T) { require.False(t, state.Exists) } +func TestDiscardReceivedRemovesOnlyReceivedAttempt(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-received"}) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-1", "attempt-running")) + + require.NoError(t, store.DiscardReceived("task-1", "attempt-received")) + received, err := store.TaskState("task-1", "attempt-received") + require.NoError(t, err) + require.False(t, received.Exists) + running, err := store.TaskState("task-1", "attempt-running") + require.NoError(t, err) + require.Equal(t, TaskStatusRunning, running.Status) +} + func openTestStore(t *testing.T, path string, spoolCapBytes, terminalReserveBytes int64) *Store { t.Helper() diff --git a/app/services/localtaskstore/recovery_test.go b/app/services/localtaskstore/recovery_test.go index e1bb0b3..c8aeda5 100644 --- a/app/services/localtaskstore/recovery_test.go +++ b/app/services/localtaskstore/recovery_test.go @@ -20,6 +20,47 @@ func TestSnapshotIncludesLocalTruncationState(t *testing.T) { require.True(t, snapshot.Tasks[0].LocalOutputTruncated) } +func TestSnapshotIncludesReconnectSafeState(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-received", ExecutionAttemptID: "attempt-received"}) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-running", "attempt-running")) + require.NoError(t, store.AppendOutputChunk(OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 4, + Payload: "hello", + ByteCount: 5, + })) + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + + snapshot, err := store.Snapshot() + require.NoError(t, err) + require.NotNil(t, snapshot.RunningTask) + require.Equal(t, "task-running", snapshot.RunningTask.TaskID) + require.Equal(t, int64(4), snapshot.RunningTask.LastOutputSequence["stdout"]) + require.Len(t, snapshot.ReceivedNotStarted, 1) + require.Equal(t, "task-received", snapshot.ReceivedNotStarted[0].TaskID) + require.Len(t, snapshot.UnackedFinals, 1) + require.Equal(t, "msg-final-1", snapshot.UnackedFinals[0].MessageID) + require.Equal(t, "completed", snapshot.UnackedFinals[0].Status) + require.Len(t, snapshot.UnackedOutput, 1) + require.Equal(t, int64(4), snapshot.UnackedOutput[0].FirstSequence) + require.Equal(t, int64(4), snapshot.UnackedOutput[0].LastSequence) + require.Positive(t, snapshot.SpoolStatus.BytesUsed) + require.Equal(t, int64(1024*1024), snapshot.SpoolStatus.ByteCap) +} + func TestMarkInterruptedRunningTasksQueuesTerminalRecordAcrossRestart(t *testing.T) { storePath := filepath.Join(t.TempDir(), "task_store.db") store := openTestStore(t, storePath, 1024*1024, 1024) @@ -41,6 +82,19 @@ func TestMarkInterruptedRunningTasksQueuesTerminalRecordAcrossRestart(t *testing require.Contains(t, messages[0].Payload, "interrupted") } +func TestSnapshotReportsInterruptedStartupRecoveryAsUnackedFinal(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.MarkInterruptedRunningTasks()) + + snapshot, err := store.Snapshot() + require.NoError(t, err) + require.Nil(t, snapshot.RunningTask) + require.Len(t, snapshot.UnackedFinals, 1) + require.Equal(t, "interrupted", snapshot.UnackedFinals[0].Status) +} + func TestMarkInterruptedRunningTasksIsIdempotent(t *testing.T) { store := newTestStore(t, 1024*1024, 1024) diff --git a/app/services/localtaskstore/store.go b/app/services/localtaskstore/store.go index b117a22..2d91e53 100644 --- a/app/services/localtaskstore/store.go +++ b/app/services/localtaskstore/store.go @@ -1,6 +1,7 @@ package localtaskstore import ( + "encoding/json" "errors" "fmt" "os" @@ -41,6 +42,8 @@ type TaskState struct { OutputTruncated bool ErrorTruncated bool LocalOutputTruncated bool + ReceivedAt time.Time + StartedAt time.Time } type OutputChunk struct { @@ -76,7 +79,50 @@ type OutboxMessage struct { } type Snapshot struct { - Tasks []TaskState + Tasks []TaskState + RunningTask *RunningTaskSnapshot + ReceivedNotStarted []ReceivedAttemptSnapshot + UnackedFinals []UnackedFinalSnapshot + UnackedOutput []UnackedOutputRange + SpoolStatus SpoolStatus +} + +type RunningTaskSnapshot struct { + TaskID string + ExecutionAttemptID string + StartedAt time.Time + LastOutputSequence map[string]int64 +} + +type ReceivedAttemptSnapshot struct { + TaskID string + ExecutionAttemptID string + ReceivedAt time.Time +} + +type UnackedFinalSnapshot struct { + MessageID string + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + OutputTruncated bool + ErrorTruncated bool +} + +type UnackedOutputRange struct { + TaskID string + ExecutionAttemptID string + Stream string + FirstSequence int64 + LastSequence int64 + TruncatedLocally bool +} + +type SpoolStatus struct { + BytesUsed int64 + ByteCap int64 + HasRotatedChunks bool } type Config struct { @@ -96,11 +142,13 @@ type ResultOutbox interface { RecordFinal(FinalResult) error UnackedMessages() ([]OutboxMessage, error) AckMessage(messageID string) error + UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) } type RecoveryStore interface { Snapshot() (Snapshot, error) MarkInterruptedRunningTasks() error + DiscardReceived(taskID, executionAttemptID string) error } type Store struct { @@ -118,6 +166,8 @@ type taskExecutionRecord struct { OutputTruncated bool ErrorTruncated bool LocalOutputTruncated bool + ReceivedAt *time.Time + StartedAt *time.Time CreatedAt time.Time UpdatedAt time.Time } @@ -223,6 +273,7 @@ func (s *Store) RecordReceived(receipt TaskReceipt) (TaskState, error) { var state TaskState err := s.db.Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() if err := s.ensureTerminalReserveAvailable(tx); err != nil { return err } @@ -241,6 +292,7 @@ func (s *Store) RecordReceived(receipt TaskReceipt) (TaskState, error) { TaskID: receipt.TaskID, ExecutionAttemptID: receipt.ExecutionAttemptID, Status: TaskStatusReceived, + ReceivedAt: &now, } if err := tx.Create(&record).Error; err != nil { return err @@ -280,10 +332,12 @@ func (s *Store) RecordStarted(taskID, executionAttemptID string) error { } return s.db.Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() return s.upsertExecutionState(tx, taskExecutionRecord{ TaskID: taskID, ExecutionAttemptID: executionAttemptID, Status: TaskStatusRunning, + StartedAt: &now, }) }) } @@ -464,6 +518,39 @@ func (s *Store) UnackedMessages() ([]OutboxMessage, error) { return messages, nil } +func (s *Store) UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) { + if taskID == "" { + return nil, fmt.Errorf("task ID is required") + } + if executionAttemptID == "" { + return nil, fmt.Errorf("execution attempt ID is required") + } + if stream == "" { + return nil, fmt.Errorf("stream is required") + } + if nextSequence <= 0 { + return nil, fmt.Errorf("next sequence must be positive") + } + + var records []outboxMessageRecord + if err := s.db.Where( + "acked_at IS NULL AND type = ? AND task_id = ? AND execution_attempt_id = ? AND stream = ? AND sequence >= ?", + OutboxMessageTypeOutput, + taskID, + executionAttemptID, + stream, + nextSequence, + ).Order("sequence ASC, created_at ASC, id ASC").Find(&records).Error; err != nil { + return nil, fmt.Errorf("load unacked messages from sequence: %w", err) + } + + messages := make([]OutboxMessage, 0, len(records)) + for _, record := range records { + messages = append(messages, outboxMessageFromRecord(record)) + } + return messages, nil +} + func (s *Store) AckMessage(messageID string) error { if messageID == "" { return fmt.Errorf("message ID is required") @@ -476,16 +563,105 @@ func (s *Store) AckMessage(messageID string) error { } func (s *Store) Snapshot() (Snapshot, error) { - var records []taskExecutionRecord - if err := s.db.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { + var snapshot Snapshot + err := s.db.Transaction(func(tx *gorm.DB) error { + var records []taskExecutionRecord + if err := tx.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { + return err + } + + snapshot = Snapshot{ + Tasks: make([]TaskState, 0, len(records)), + ReceivedNotStarted: []ReceivedAttemptSnapshot{}, + UnackedFinals: []UnackedFinalSnapshot{}, + UnackedOutput: []UnackedOutputRange{}, + SpoolStatus: SpoolStatus{ByteCap: s.spoolCapBytes}, + } + stateByAttempt := make(map[string]TaskState, len(records)) + for _, record := range records { + state := taskStateFromRecord(record) + snapshot.Tasks = append(snapshot.Tasks, state) + stateByAttempt[attemptKey(state.TaskID, state.ExecutionAttemptID)] = state + switch state.Status { + case TaskStatusRunning: + running := RunningTaskSnapshot{ + TaskID: state.TaskID, + ExecutionAttemptID: state.ExecutionAttemptID, + StartedAt: state.StartedAt, + LastOutputSequence: map[string]int64{"stdout": 0, "stderr": 0}, + } + snapshot.RunningTask = &running + case TaskStatusReceived: + snapshot.ReceivedNotStarted = append(snapshot.ReceivedNotStarted, ReceivedAttemptSnapshot{ + TaskID: state.TaskID, + ExecutionAttemptID: state.ExecutionAttemptID, + ReceivedAt: state.ReceivedAt, + }) + } + } + + var outbox []outboxMessageRecord + if err := tx.Where("acked_at IS NULL").Order("created_at ASC, id ASC").Find(&outbox).Error; err != nil { + return err + } + outputRanges := map[string]*UnackedOutputRange{} + for _, message := range outbox { + snapshot.SpoolStatus.BytesUsed += message.ByteCount + state := stateByAttempt[attemptKey(message.TaskID, message.ExecutionAttemptID)] + if state.LocalOutputTruncated { + snapshot.SpoolStatus.HasRotatedChunks = true + } + switch message.Type { + case OutboxMessageTypeFinal: + snapshot.UnackedFinals = append(snapshot.UnackedFinals, finalSnapshotFromRecord(message)) + case OutboxMessageTypeOutput: + key := attemptKey(message.TaskID, message.ExecutionAttemptID) + "|" + message.Stream + rangeSnapshot := outputRanges[key] + if rangeSnapshot == nil { + rangeSnapshot = &UnackedOutputRange{ + TaskID: message.TaskID, + ExecutionAttemptID: message.ExecutionAttemptID, + Stream: message.Stream, + FirstSequence: message.Sequence, + LastSequence: message.Sequence, + TruncatedLocally: state.LocalOutputTruncated, + } + outputRanges[key] = rangeSnapshot + } else { + if message.Sequence < rangeSnapshot.FirstSequence { + rangeSnapshot.FirstSequence = message.Sequence + } + if message.Sequence > rangeSnapshot.LastSequence { + rangeSnapshot.LastSequence = message.Sequence + } + rangeSnapshot.TruncatedLocally = rangeSnapshot.TruncatedLocally || state.LocalOutputTruncated + } + if snapshot.RunningTask != nil && snapshot.RunningTask.TaskID == message.TaskID && snapshot.RunningTask.ExecutionAttemptID == message.ExecutionAttemptID { + if message.Sequence > snapshot.RunningTask.LastOutputSequence[message.Stream] { + snapshot.RunningTask.LastOutputSequence[message.Stream] = message.Sequence + } + } + } + } + for _, outputRange := range outputRanges { + snapshot.UnackedOutput = append(snapshot.UnackedOutput, *outputRange) + } + return nil + }) + if err != nil { return Snapshot{}, fmt.Errorf("load task snapshot: %w", err) } + return snapshot, nil +} - snapshot := Snapshot{Tasks: make([]TaskState, 0, len(records))} - for _, record := range records { - snapshot.Tasks = append(snapshot.Tasks, taskStateFromRecord(record)) +func (s *Store) DiscardReceived(taskID, executionAttemptID string) error { + if taskID == "" { + return fmt.Errorf("task ID is required") } - return snapshot, nil + if executionAttemptID == "" { + return fmt.Errorf("execution attempt ID is required") + } + return s.db.Where("task_id = ? AND execution_attempt_id = ? AND status = ?", taskID, executionAttemptID, TaskStatusReceived).Delete(&taskExecutionRecord{}).Error } func (s *Store) MarkInterruptedRunningTasks() error { @@ -553,6 +729,12 @@ func (s *Store) upsertExecutionState(tx *gorm.DB, next taskExecutionRecord) erro if next.LocalOutputTruncated { updates["local_output_truncated"] = true } + if next.ReceivedAt != nil { + updates["received_at"] = next.ReceivedAt + } + if next.StartedAt != nil { + updates["started_at"] = next.StartedAt + } return tx.Model(&existing).Updates(updates).Error } @@ -595,6 +777,14 @@ func validateFinalResult(result FinalResult) error { } func taskStateFromRecord(record taskExecutionRecord) TaskState { + var receivedAt time.Time + if record.ReceivedAt != nil { + receivedAt = *record.ReceivedAt + } + var startedAt time.Time + if record.StartedAt != nil { + startedAt = *record.StartedAt + } return TaskState{ ID: record.ID, Exists: true, @@ -605,6 +795,8 @@ func taskStateFromRecord(record taskExecutionRecord) TaskState { OutputTruncated: record.OutputTruncated, ErrorTruncated: record.ErrorTruncated, LocalOutputTruncated: record.LocalOutputTruncated, + ReceivedAt: receivedAt, + StartedAt: startedAt, } } @@ -624,3 +816,31 @@ func outboxMessageFromRecord(record outboxMessageRecord) OutboxMessage { func interruptedMessageID(taskID, executionAttemptID string) string { return "local-interrupted-" + strings.NewReplacer("/", "-", " ", "-", "|", "-").Replace(taskID+"-"+executionAttemptID) } + +func attemptKey(taskID, executionAttemptID string) string { + return taskID + "|" + executionAttemptID +} + +func finalSnapshotFromRecord(record outboxMessageRecord) UnackedFinalSnapshot { + snapshot := UnackedFinalSnapshot{ + MessageID: record.MessageID, + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + } + if record.Payload == "" { + return snapshot + } + var payload struct { + Status string `json:"status"` + ExitCode int `json:"exit_code"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` + } + if err := json.Unmarshal([]byte(record.Payload), &payload); err == nil { + snapshot.Status = payload.Status + snapshot.ExitCode = payload.ExitCode + snapshot.OutputTruncated = payload.OutputTruncated + snapshot.ErrorTruncated = payload.ErrorTruncated + } + return snapshot +} diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go index 4cee80a..1f2897f 100644 --- a/app/services/wsclient/client.go +++ b/app/services/wsclient/client.go @@ -15,6 +15,7 @@ import ( "hostlink/app/services/requestsigner" "hostlink/domain/task" "hostlink/internal/wsprotocol" + "hostlink/version" ) var ErrAgentNotRegistered = errors.New("agent not registered: missing agent ID") @@ -47,6 +48,7 @@ type Config struct { SleepFunc SleepFunc ResultOutbox localtaskstore.ResultOutbox ReceiptStore localtaskstore.ReceiptStore + RecoveryStore localtaskstore.RecoveryStore TaskEnqueuer TaskEnqueuer } @@ -67,6 +69,7 @@ type Client struct { conn Conn outbox localtaskstore.ResultOutbox receipts localtaskstore.ReceiptStore + recovery localtaskstore.RecoveryStore enqueuer TaskEnqueuer } @@ -109,6 +112,7 @@ func New(cfg Config) (*Client, error) { sleep: cfg.SleepFunc, outbox: cfg.ResultOutbox, receipts: cfg.ReceiptStore, + recovery: cfg.RecoveryStore, enqueuer: cfg.TaskEnqueuer, }, nil } @@ -214,16 +218,26 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) switch env.Type { case wsprotocol.TypeAgentHelloAck: - ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) + helloAck, err := wsprotocol.DecodePayload[wsprotocol.HelloAckPayload](env) if err != nil { return err } - if ack.AckedMessageID == helloMessageID { - c.setActive(true) - if err := c.replayUnacked(ctx, conn); err != nil { + if helloAck.AckedMessageID == helloMessageID { + if err := c.applyHelloAckLocalState(helloAck); err != nil { return err } + c.setActive(true) + if helloAck.HasReconciliationDirectives() { + if err := c.replayRequestedOutput(ctx, conn, helloAck.OutputReplay); err != nil { + return err + } + } else { + if err := c.replayUnacked(ctx, conn); err != nil { + return err + } + } } + ack := wsprotocol.AckPayload{AckedMessageID: helloAck.AckedMessageID, AckedType: helloAck.AckedType} c.setLastAck(&ack) case wsprotocol.TypeAck: ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) @@ -376,16 +390,115 @@ func (c *Client) RecordStarted(ctx context.Context, receipt localtaskstore.TaskR } func (c *Client) buildHello() wsprotocol.Envelope { + payload := c.buildHelloPayload() return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, MessageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), Type: wsprotocol.TypeAgentHello, AgentID: c.agentID, SentAt: time.Now().UTC().Format(time.RFC3339), - Payload: map[string]any{}, + Payload: payloadFromValue(payload), } } +func (c *Client) buildHelloPayload() wsprotocol.HelloPayload { + payload := wsprotocol.HelloPayload{ + ReceivedNotStarted: []wsprotocol.ReceivedNotStartedAttempt{}, + UnackedFinals: []wsprotocol.UnackedFinalSnapshot{}, + UnackedOutput: []wsprotocol.UnackedOutputRange{}, + SpoolStatus: wsprotocol.SpoolStatus{}, + ClientVersion: version.Version, + } + if c.recovery == nil { + return payload + } + snapshot, err := c.recovery.Snapshot() + if err != nil { + return payload + } + if snapshot.RunningTask != nil { + payload.RunningTask = &wsprotocol.RunningTaskSnapshot{ + TaskID: snapshot.RunningTask.TaskID, + ExecutionAttemptID: snapshot.RunningTask.ExecutionAttemptID, + StartedAt: formatTime(snapshot.RunningTask.StartedAt), + LastOutputSequence: map[string]int{ + "stdout": int(snapshot.RunningTask.LastOutputSequence["stdout"]), + "stderr": int(snapshot.RunningTask.LastOutputSequence["stderr"]), + }, + } + } + for _, attempt := range snapshot.ReceivedNotStarted { + payload.ReceivedNotStarted = append(payload.ReceivedNotStarted, wsprotocol.ReceivedNotStartedAttempt{ + TaskID: attempt.TaskID, + ExecutionAttemptID: attempt.ExecutionAttemptID, + ReceivedAt: formatTime(attempt.ReceivedAt), + }) + } + for _, final := range snapshot.UnackedFinals { + payload.UnackedFinals = append(payload.UnackedFinals, wsprotocol.UnackedFinalSnapshot{ + MessageID: final.MessageID, + TaskID: final.TaskID, + ExecutionAttemptID: final.ExecutionAttemptID, + Status: wsprotocol.FinalStatus(final.Status), + ExitCode: final.ExitCode, + OutputTruncated: final.OutputTruncated, + ErrorTruncated: final.ErrorTruncated, + }) + } + for _, outputRange := range snapshot.UnackedOutput { + payload.UnackedOutput = append(payload.UnackedOutput, wsprotocol.UnackedOutputRange{ + TaskID: outputRange.TaskID, + ExecutionAttemptID: outputRange.ExecutionAttemptID, + Stream: wsprotocol.Stream(outputRange.Stream), + FirstSequence: int(outputRange.FirstSequence), + LastSequence: int(outputRange.LastSequence), + TruncatedLocally: outputRange.TruncatedLocally, + }) + } + payload.SpoolStatus = wsprotocol.SpoolStatus{ + BytesUsed: snapshot.SpoolStatus.BytesUsed, + ByteCap: snapshot.SpoolStatus.ByteCap, + HasRotatedChunks: snapshot.SpoolStatus.HasRotatedChunks, + } + return payload +} + +func (c *Client) applyHelloAckLocalState(ack wsprotocol.HelloAckPayload) error { + if c.outbox != nil { + for _, messageID := range ack.AcknowledgedFinalMessageIDs { + if err := c.outbox.AckMessage(messageID); err != nil { + return err + } + } + } + if c.recovery != nil { + for _, attempt := range ack.DiscardedAttempts { + if err := c.recovery.DiscardReceived(attempt.TaskID, attempt.ExecutionAttemptID); err != nil { + return err + } + } + } + return nil +} + +func (c *Client) replayRequestedOutput(ctx context.Context, conn Conn, replayRequests []wsprotocol.OutputReplayDirective) error { + if c.outbox == nil { + return nil + } + for _, replay := range replayRequests { + messages, err := c.outbox.UnackedMessagesFrom(replay.TaskID, replay.ExecutionAttemptID, string(replay.Stream), int64(replay.NextSequence)) + if err != nil { + return err + } + for _, message := range messages { + if err := c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message)); err != nil { + return err + } + } + } + return nil +} + func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, taskID, executionAttemptID string) wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -493,6 +606,23 @@ func envelopeFromOutboxMessage(agentID string, message localtaskstore.OutboxMess } } +func payloadFromValue(value any) map[string]any { + data, _ := json.Marshal(value) + var payload map[string]any + _ = json.Unmarshal(data, &payload) + if payload == nil { + return map[string]any{} + } + return payload +} + +func formatTime(value time.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339) +} + func sleepContext(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) defer timer.Stop() diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go index b45e580..e326c5f 100644 --- a/app/services/wsclient/client_test.go +++ b/app/services/wsclient/client_test.go @@ -38,8 +38,10 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { if written.AgentID != "agent_ws_test" { t.Fatalf("written agent_id = %q", written.AgentID) } - if len(written.Payload) != 0 { - t.Fatalf("hello payload = %#v, want empty object", written.Payload) + for _, key := range []string{"running_task", "received_not_started", "unacked_finals", "unacked_output", "spool_status", "client_version"} { + if _, ok := written.Payload[key]; !ok { + t.Fatalf("hello payload missing %q: %#v", key, written.Payload) + } } conn.readCh <- wsprotocol.Envelope{ @@ -61,6 +63,56 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { } } +func TestClientHelloPayloadIncludesLocalReconnectSnapshot(t *testing.T) { + store := newClientTestStore(t) + _, err := store.RecordReceived(localtaskstore.TaskReceipt{TaskID: "task-received", ExecutionAttemptID: "attempt-received"}) + requireNoError(t, err) + requireNoError(t, store.RecordStarted("task-running", "attempt-running")) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 3, + Payload: "hello", + ByteCount: 5, + })) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + written := conn.waitForWrite(t) + running, ok := written.Payload["running_task"].(map[string]any) + if !ok || running["task_id"] != "task-running" { + t.Fatalf("running_task = %#v", written.Payload["running_task"]) + } + if len(written.Payload["received_not_started"].([]any)) != 1 { + t.Fatalf("received_not_started = %#v", written.Payload["received_not_started"]) + } + if len(written.Payload["unacked_finals"].([]any)) != 1 { + t.Fatalf("unacked_finals = %#v", written.Payload["unacked_finals"]) + } + if len(written.Payload["unacked_output"].([]any)) != 1 { + t.Fatalf("unacked_output = %#v", written.Payload["unacked_output"]) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientDialUsesSignedUpgradeHeaders(t *testing.T) { conn := newFakeConn() dialer := &fakeDialer{conn: conn} @@ -380,6 +432,101 @@ func TestClientReplaysUnackedMessagesAfterHelloAck(t *testing.T) { } } +func TestClientHelloAckAppliesFinalDiscardAndTargetedReplayDirectives(t *testing.T) { + store := newClientTestStore(t) + _, err := store.RecordReceived(localtaskstore.TaskReceipt{TaskID: "task-stale", ExecutionAttemptID: "attempt-stale"}) + requireNoError(t, err) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 1, + Payload: "old", + ByteCount: 3, + })) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-2", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 2, + Payload: "needed", + ByteCount: 6, + })) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultOutbox(store), WithReceiptStore(store), WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{ + AcknowledgedFinalMessageIDs: []string{"msg-final-1"}, + DiscardedAttempts: []wsprotocol.DiscardedAttempt{{TaskID: "task-stale", ExecutionAttemptID: "attempt-stale", Reason: "stale_attempt"}}, + OutputReplay: []wsprotocol.OutputReplayDirective{{TaskID: "task-running", ExecutionAttemptID: "attempt-running", Stream: wsprotocol.StreamStdout, NextSequence: 2}}, + }) + + replayed := conn.waitForWrite(t) + if replayed.MessageID != "msg-output-2" { + t.Fatalf("replayed message = %#v", replayed) + } + waitFor(t, func() bool { return client.IsActive() }, "client active after directives") + state, err := store.TaskState("task-stale", "attempt-stale") + requireNoError(t, err) + if state.Exists { + t.Fatalf("discarded state still exists: %#v", state) + } + messages, err := store.UnackedMessages() + requireNoError(t, err) + if containsMessage(messages, "msg-final-1") { + t.Fatalf("acked final still unacked: %#v", messages) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientHelloAckDiscardDoesNotEraseRunningAttempt(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.RecordStarted("task-running", "attempt-running")) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{ + DiscardedAttempts: []wsprotocol.DiscardedAttempt{{TaskID: "task-running", ExecutionAttemptID: "attempt-running", Reason: "stale_attempt"}}, + }) + + waitFor(t, func() bool { return client.IsActive() }, "client active after discard directive") + state, err := store.TaskState("task-running", "attempt-running") + requireNoError(t, err) + if !state.Exists || state.Status != localtaskstore.TaskStatusRunning { + t.Fatalf("running state = %#v", state) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientRetryableErrorKeepsConnectionAndOutboxMessage(t *testing.T) { store := newClientTestStore(t) requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ @@ -597,6 +744,10 @@ func WithReceiptStore(store localtaskstore.ReceiptStore) clientOption { return func(cfg *Config) { cfg.ReceiptStore = store } } +func WithRecoveryStore(store localtaskstore.RecoveryStore) clientOption { + return func(cfg *Config) { cfg.RecoveryStore = store } +} + func WithTaskEnqueuer(enqueuer TaskEnqueuer) clientOption { return func(cfg *Config) { cfg.TaskEnqueuer = enqueuer } } @@ -740,6 +891,28 @@ func helloAckEnvelope(ackedMessageID string) wsprotocol.Envelope { } } +func helloAckEnvelopeWithDirectives(ackedMessageID string, directives wsprotocol.HelloAckPayload) wsprotocol.Envelope { + directives.AckedMessageID = ackedMessageID + directives.AckedType = wsprotocol.TypeAgentHello + if directives.AcknowledgedFinalMessageIDs == nil { + directives.AcknowledgedFinalMessageIDs = []string{} + } + if directives.DiscardedAttempts == nil { + directives.DiscardedAttempts = []wsprotocol.DiscardedAttempt{} + } + if directives.OutputReplay == nil { + directives.OutputReplay = []wsprotocol.OutputReplayDirective{} + } + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_hello_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMapForTest(directives), + } +} + func deliverEnvelope(messageID, taskID, attemptID, command string, priority int) wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -809,6 +982,15 @@ func intValuePtr(value int) *int { return &value } +func containsMessage(messages []localtaskstore.OutboxMessage, messageID string) bool { + for _, message := range messages { + if message.MessageID == messageID { + return true + } + } + return false +} + func saveTestPrivateKey(t *testing.T, dir string) string { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/internal/wsprotocol/message.go b/internal/wsprotocol/message.go index c1878bc..db1c2b4 100644 --- a/internal/wsprotocol/message.go +++ b/internal/wsprotocol/message.go @@ -70,6 +70,74 @@ type TaskDeliverPayload struct { Priority int `json:"priority"` } +type HelloPayload struct { + RunningTask *RunningTaskSnapshot `json:"running_task"` + ReceivedNotStarted []ReceivedNotStartedAttempt `json:"received_not_started"` + UnackedFinals []UnackedFinalSnapshot `json:"unacked_finals"` + UnackedOutput []UnackedOutputRange `json:"unacked_output"` + SpoolStatus SpoolStatus `json:"spool_status"` + ClientVersion string `json:"client_version"` +} + +type RunningTaskSnapshot struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + StartedAt string `json:"started_at"` + LastOutputSequence map[string]int `json:"last_output_sequence"` +} + +type ReceivedNotStartedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + ReceivedAt string `json:"received_at"` +} + +type UnackedFinalSnapshot struct { + MessageID string `json:"message_id"` + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Status FinalStatus `json:"status"` + ExitCode int `json:"exit_code"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` +} + +type UnackedOutputRange struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + FirstSequence int `json:"first_sequence"` + LastSequence int `json:"last_sequence"` + TruncatedLocally bool `json:"truncated_locally"` +} + +type SpoolStatus struct { + BytesUsed int64 `json:"bytes_used"` + ByteCap int64 `json:"byte_cap"` + HasRotatedChunks bool `json:"has_rotated_chunks"` +} + +type HelloAckPayload struct { + AckedMessageID string `json:"acked_message_id"` + AckedType MessageType `json:"acked_type"` + AcknowledgedFinalMessageIDs []string `json:"acknowledged_final_message_ids"` + DiscardedAttempts []DiscardedAttempt `json:"discarded_attempts"` + OutputReplay []OutputReplayDirective `json:"output_replay"` +} + +type DiscardedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Reason string `json:"reason"` +} + +type OutputReplayDirective struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + NextSequence int `json:"next_sequence"` +} + func (e Envelope) Validate(authenticatedAgentID string) error { if e.ProtocolVersion != ProtocolVersion { return fmt.Errorf("unsupported protocol_version: %d", e.ProtocolVersion) @@ -132,6 +200,49 @@ func (p TaskDeliverPayload) Validate() error { return nil } +func (p HelloPayload) Validate() error { + if p.ClientVersion == "" { + return fmt.Errorf("client_version is required") + } + if p.RunningTask != nil { + if p.RunningTask.TaskID == "" { + return fmt.Errorf("running_task.task_id is required") + } + if p.RunningTask.ExecutionAttemptID == "" { + return fmt.Errorf("running_task.execution_attempt_id is required") + } + if p.RunningTask.LastOutputSequence == nil { + return fmt.Errorf("running_task.last_output_sequence is required") + } + } + for _, attempt := range p.ReceivedNotStarted { + if attempt.TaskID == "" || attempt.ExecutionAttemptID == "" { + return fmt.Errorf("received_not_started entries require task_id and execution_attempt_id") + } + } + for _, final := range p.UnackedFinals { + if final.MessageID == "" || final.TaskID == "" || final.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_finals entries require message_id, task_id, and execution_attempt_id") + } + } + for _, output := range p.UnackedOutput { + if output.TaskID == "" || output.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_output entries require task_id and execution_attempt_id") + } + if output.Stream != StreamStdout && output.Stream != StreamStderr { + return fmt.Errorf("unacked_output.stream must be stdout or stderr") + } + if output.FirstSequence <= 0 || output.LastSequence < output.FirstSequence { + return fmt.Errorf("unacked_output sequence range is invalid") + } + } + return nil +} + +func (p HelloAckPayload) HasReconciliationDirectives() bool { + return p.AcknowledgedFinalMessageIDs != nil || p.DiscardedAttempts != nil || p.OutputReplay != nil +} + func DecodePayload[T any](e Envelope) (T, error) { var payload T diff --git a/internal/wsprotocol/message_test.go b/internal/wsprotocol/message_test.go index 543a3b2..6b4a337 100644 --- a/internal/wsprotocol/message_test.go +++ b/internal/wsprotocol/message_test.go @@ -189,6 +189,69 @@ func TestPayloadValidation(t *testing.T) { }) } +func TestHelloPayloadUsesReconnectSnapshotShape(t *testing.T) { + payload := HelloPayload{ + RunningTask: &RunningTaskSnapshot{ + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + StartedAt: "2026-04-28T00:00:00Z", + LastOutputSequence: map[string]int{"stdout": 2, "stderr": 1}, + }, + ReceivedNotStarted: []ReceivedNotStartedAttempt{{TaskID: "task-2", ExecutionAttemptID: "attempt-2", ReceivedAt: "2026-04-28T00:00:01Z"}}, + UnackedFinals: []UnackedFinalSnapshot{{ + MessageID: "msg-final-1", + TaskID: "task-3", + ExecutionAttemptID: "attempt-3", + Status: FinalStatusCompleted, + }}, + UnackedOutput: []UnackedOutputRange{{TaskID: "task-1", ExecutionAttemptID: "attempt-1", Stream: StreamStdout, FirstSequence: 2, LastSequence: 4}}, + SpoolStatus: SpoolStatus{BytesUsed: 128, ByteCap: 1024, HasRotatedChunks: true}, + ClientVersion: "test-version", + } + + if err := payload.Validate(); err != nil { + t.Fatalf("expected hello payload to validate, got %v", err) + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal hello payload: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal hello payload: %v", err) + } + if !reflect.DeepEqual(sortedKeys(decoded), []string{"client_version", "received_not_started", "running_task", "spool_status", "unacked_finals", "unacked_output"}) { + t.Fatalf("hello payload fields = %v", sortedKeys(decoded)) + } + lastOutput, ok := decoded["running_task"].(map[string]any)["last_output_sequence"].(map[string]any) + if !ok || lastOutput["stdout"] != float64(2) || lastOutput["stderr"] != float64(1) { + t.Fatalf("last_output_sequence = %#v", decoded["running_task"]) + } +} + +func TestHelloAckPayloadUsesReconciliationDirectiveShape(t *testing.T) { + payload := HelloAckPayload{ + AckedMessageID: "msg-hello", + AckedType: TypeAgentHello, + AcknowledgedFinalMessageIDs: []string{"msg-final-1"}, + DiscardedAttempts: []DiscardedAttempt{{TaskID: "task-2", ExecutionAttemptID: "attempt-2", Reason: "stale_attempt"}}, + OutputReplay: []OutputReplayDirective{{TaskID: "task-1", ExecutionAttemptID: "attempt-1", Stream: StreamStderr, NextSequence: 14}}, + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal hello ack: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal hello ack: %v", err) + } + if !reflect.DeepEqual(sortedKeys(decoded), []string{"acked_message_id", "acked_type", "acknowledged_final_message_ids", "discarded_attempts", "output_replay"}) { + t.Fatalf("hello ack fields = %v", sortedKeys(decoded)) + } +} + func intPtr(value int) *int { return &value } diff --git a/main.go b/main.go index 039fe9c..2e061ae 100644 --- a/main.go +++ b/main.go @@ -352,6 +352,7 @@ func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wscli PingInterval: appconf.WebSocketPingInterval(), ResultOutbox: localStore, ReceiptStore: localStore, + RecoveryStore: localStore, TaskEnqueuer: enqueuer, }) }