From 478f9b8fadeeb8f3b1a2a17c1eb803cbdf5c23bf Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Tue, 28 Apr 2026 11:54:14 +0530 Subject: [PATCH 1/2] feat(hostlink): split TaskJob into shared runner and executor, route polling through enqueue, separate durable start from transport start --- app/jobs/taskjob/result_channel_test.go | 76 ++++++- app/jobs/taskjob/runner_test.go | 138 +++++++++++++ app/jobs/taskjob/taskjob.go | 199 ++++++++++++------ app/services/localtaskstore/store.go | 1 + app/services/wsclient/client.go | 117 ++++++++++- app/services/wsclient/client_test.go | 240 ++++++++++++++++++++++ internal/wsprotocol/message.go | 12 ++ main.go | 8 +- test/integration/taskjob_reporter_test.go | 39 ++-- 9 files changed, 740 insertions(+), 90 deletions(-) create mode 100644 app/jobs/taskjob/runner_test.go diff --git a/app/jobs/taskjob/result_channel_test.go b/app/jobs/taskjob/result_channel_test.go index 477fc51..305240d 100644 --- a/app/jobs/taskjob/result_channel_test.go +++ b/app/jobs/taskjob/result_channel_test.go @@ -8,6 +8,8 @@ import ( "hostlink/app/services/taskreporter" "hostlink/domain/task" "io" + "os" + "path/filepath" "sync" "testing" "time" @@ -26,6 +28,12 @@ func TestTaskJobStreamsOutputAndFinalOverResultChannel(t *testing.T) { job.processTask(context.Background(), fetcher.tasks[0], reporter, channel) + if len(channel.started) != 1 { + t.Fatalf("started len = %d, want 1", len(channel.started)) + } + if channel.started[0].TaskID != "task-1" || channel.started[0].ExecutionAttemptID != "attempt-1" { + t.Fatalf("started = %#v", channel.started[0]) + } if len(channel.outputs) != 2 { t.Fatalf("outputs len = %d, want 2", len(channel.outputs)) } @@ -84,6 +92,34 @@ func TestTaskJobFallsBackToHTTPReporterWhenFinalPersistenceFails(t *testing.T) { } } +func TestTaskJobRecordsStartedBeforeProcessLaunch(t *testing.T) { + marker := filepath.Join(t.TempDir(), "process-started") + reporter := &fakeTaskReporter{} + checked := false + channel := &fakeResultChannel{recordStartedHook: func() error { + checked = true + if _, err := os.Stat(marker); err == nil { + return errors.New("process launched before durable started state") + } + return nil + }} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger}) + + job.processTask(context.Background(), task.Task{ + ID: "task-1", + ExecutionAttemptID: "attempt-1", + Command: "printf launched > " + marker, + Status: "pending", + }, reporter, channel) + + if !checked { + t.Fatal("RecordStarted was not called") + } + if len(channel.started) != 1 { + t.Fatalf("started len = %d, want 1", len(channel.started)) + } +} + func TestCaptureStreamFlushesOnByteThreshold(t *testing.T) { reader, writer := io.Pipe() channel := &fakeResultChannel{} @@ -176,23 +212,49 @@ func (f *fakeTaskFetcher) Fetch() ([]task.Task, error) { } type fakeTaskReporter struct { - mu sync.Mutex - results []*taskreporter.TaskResult + mu sync.Mutex + taskIDsReported []string + results []*taskreporter.TaskResult } func (f *fakeTaskReporter) Report(taskID string, result *taskreporter.TaskResult) error { f.mu.Lock() defer f.mu.Unlock() + f.taskIDsReported = append(f.taskIDsReported, taskID) f.results = append(f.results, result) return nil } type fakeResultChannel struct { - mu sync.Mutex - outputs []localtaskstore.OutputChunk - finals []localtaskstore.FinalResult - outputErrs []error - finalErr error + mu sync.Mutex + recordedStarted []localtaskstore.TaskReceipt + started []localtaskstore.TaskReceipt + outputs []localtaskstore.OutputChunk + finals []localtaskstore.FinalResult + recordStartedErr error + recordStartedHook func() error + startedErr error + outputErrs []error + finalErr error +} + +func (f *fakeResultChannel) RecordStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.recordStartedHook != nil { + if err := f.recordStartedHook(); err != nil { + return err + } + } + f.recordedStarted = append(f.recordedStarted, receipt) + return f.recordStartedErr +} + +func (f *fakeResultChannel) SendStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + f.mu.Lock() + defer f.mu.Unlock() + f.started = append(f.started, receipt) + return f.startedErr } func (f *fakeResultChannel) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error { diff --git a/app/jobs/taskjob/runner_test.go b/app/jobs/taskjob/runner_test.go new file mode 100644 index 0000000..9303876 --- /dev/null +++ b/app/jobs/taskjob/runner_test.go @@ -0,0 +1,138 @@ +package taskjob + +import ( + "context" + "fmt" + "hostlink/app/services/taskreporter" + "hostlink/domain/task" + "os" + "path/filepath" + "testing" + "time" +) + +func TestTaskJobQueuesPollingAndEnqueuedWorkSequentially(t *testing.T) { + marker := filepath.Join(t.TempDir(), "polling-started") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + fetcher := &fakeTaskFetcher{tasks: []task.Task{{ + ID: "poll-task", + Command: fmt.Sprintf("printf started > %q; sleep 0.2; printf first", marker), + Status: "pending", + }}} + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: runOnceTrigger}) + + cancelJob := job.Register(ctx, fetcher, reporter) + defer func() { + cancelJob() + job.Shutdown() + }() + waitForFile(t, marker) + if err := job.Enqueue(ctx, task.Task{ID: "ws-task", Command: "printf second", Status: "pending"}); err != nil { + t.Fatalf("enqueue websocket task: %v", err) + } + waitForReports(t, reporter, 2) + + got := reporter.taskIDs() + want := []string{"poll-task", "ws-task"} + if fmt.Sprint(got) != fmt.Sprint(want) { + t.Fatalf("report order = %v, want %v", got, want) + } +} + +func TestTaskJobRunsEnqueuedTasksSequentially(t *testing.T) { + marker := filepath.Join(t.TempDir(), "first-started") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: noOpTrigger}) + + cancelJob := job.Register(ctx, &fakeTaskFetcher{}, reporter) + defer func() { + cancelJob() + job.Shutdown() + }() + if err := job.Enqueue(ctx, task.Task{ID: "task-1", Command: fmt.Sprintf("printf started > %q; sleep 0.2; printf first", marker), Status: "pending"}); err != nil { + t.Fatalf("enqueue first task: %v", err) + } + waitForFile(t, marker) + if err := job.Enqueue(ctx, task.Task{ID: "task-2", Command: "printf second", Status: "pending"}); err != nil { + t.Fatalf("enqueue second task: %v", err) + } + waitForReports(t, reporter, 2) + + got := reporter.taskIDs() + want := []string{"task-1", "task-2"} + if fmt.Sprint(got) != fmt.Sprint(want) { + t.Fatalf("report order = %v, want %v", got, want) + } +} + +func TestTaskJobSuppressesDuplicateQueuedAttempt(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + reporter := &fakeTaskReporter{} + job := NewJobWithConf(TaskJobConfig{Trigger: noOpTrigger}) + taskAttempt := task.Task{ID: "task-1", ExecutionAttemptID: "attempt-1", Command: "printf run", Status: "pending"} + + cancelJob := job.Register(ctx, &fakeTaskFetcher{}, reporter) + defer func() { + cancelJob() + job.Shutdown() + }() + if err := job.Enqueue(ctx, taskAttempt); err != nil { + t.Fatalf("enqueue first attempt: %v", err) + } + if err := job.Enqueue(ctx, taskAttempt); err != nil { + t.Fatalf("enqueue duplicate attempt: %v", err) + } + waitForReports(t, reporter, 1) + time.Sleep(100 * time.Millisecond) + + if got := len(reporter.resultsSnapshot()); got != 1 { + t.Fatalf("report count = %d, want 1", got) + } +} + +func noOpTrigger(ctx context.Context, fn func() error) {} + +func waitForFile(t *testing.T, path string) { + t.Helper() + deadline := time.Now().Add(time.Second) + for time.Now().Before(deadline) { + if _, err := os.Stat(path); err == nil { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timed out waiting for %s", path) +} + +func waitForReports(t *testing.T, reporter *fakeTaskReporter, count int) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if len(reporter.resultsSnapshot()) >= count { + return + } + time.Sleep(time.Millisecond) + } + t.Fatalf("timed out waiting for %d reports", count) +} + +func (f *fakeTaskReporter) taskIDs() []string { + f.mu.Lock() + defer f.mu.Unlock() + ids := make([]string, 0, len(f.taskIDsReported)) + ids = append(ids, f.taskIDsReported...) + return ids +} + +func (f *fakeTaskReporter) resultsSnapshot() []*taskreporter.TaskResult { + f.mu.Lock() + defer f.mu.Unlock() + results := make([]*taskreporter.TaskResult, 0, len(f.results)) + results = append(results, f.results...) + return results +} diff --git a/app/jobs/taskjob/taskjob.go b/app/jobs/taskjob/taskjob.go index 6df56c2..fff4cec 100644 --- a/app/jobs/taskjob/taskjob.go +++ b/app/jobs/taskjob/taskjob.go @@ -31,14 +31,29 @@ type TaskJobConfig struct { } type ResultChannel interface { + RecordStarted(context.Context, localtaskstore.TaskReceipt) error + SendStarted(context.Context, localtaskstore.TaskReceipt) error SendOutput(context.Context, localtaskstore.OutputChunk) error SendFinal(context.Context, localtaskstore.FinalResult) error } +type TaskExecutor interface { + Execute(context.Context, task.Task) error +} + +type taskExecutor struct { + config TaskJobConfig + reporter taskreporter.TaskReporter + channel ResultChannel +} + type TaskJob struct { - config TaskJobConfig - cancel context.CancelFunc - wg sync.WaitGroup + config TaskJobConfig + enqueueCh chan task.Task + cancel context.CancelFunc + wg sync.WaitGroup + mu sync.Mutex + knownAttempt map[string]struct{} } func New() *TaskJob { @@ -59,7 +74,9 @@ func NewJobWithConf(cfg TaskJobConfig) *TaskJob { } return &TaskJob{ - config: cfg, + config: cfg, + enqueueCh: make(chan task.Task, 16), + knownAttempt: make(map[string]struct{}), } } @@ -70,6 +87,12 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr if len(channels) > 0 { channel = channels[0] } + executor := &taskExecutor{config: tj.config, reporter: tr, channel: channel} + tj.wg.Add(1) + go func() { + defer tj.wg.Done() + tj.run(ctx, executor) + }() tj.wg.Add(1) go func() { defer tj.wg.Done() @@ -78,14 +101,13 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr if err != nil { return err } - incompleteTasks := []task.Task{} - for _, task := range allTasks { - if task.Status != "completed" { - incompleteTasks = append(incompleteTasks, task) + for _, t := range allTasks { + if t.Status == "completed" { + continue + } + if err := tj.Enqueue(ctx, t); err != nil { + return err } - } - for _, t := range incompleteTasks { - tj.processTask(ctx, t, tr, channel) } return nil }) @@ -93,19 +115,83 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr return cancel } +func (tj *TaskJob) run(ctx context.Context, executor TaskExecutor) { + for { + select { + case queued := <-tj.enqueueCh: + executor.Execute(ctx, queued) + tj.finishAttempt(queued) + case <-ctx.Done(): + return + } + } +} + +func (tj *TaskJob) Enqueue(ctx context.Context, t task.Task) error { + if !tj.reserveAttempt(t) { + return nil + } + select { + case tj.enqueueCh <- t: + return nil + case <-ctx.Done(): + tj.finishAttempt(t) + return ctx.Err() + } +} + +func (tj *TaskJob) reserveAttempt(t task.Task) bool { + key := executionKey(t) + if key == "" { + return true + } + tj.mu.Lock() + defer tj.mu.Unlock() + if _, ok := tj.knownAttempt[key]; ok { + return false + } + tj.knownAttempt[key] = struct{}{} + return true +} + +func (tj *TaskJob) finishAttempt(t task.Task) { + key := executionKey(t) + if key == "" { + return + } + tj.mu.Lock() + defer tj.mu.Unlock() + delete(tj.knownAttempt, key) +} + +func executionKey(t task.Task) string { + if t.ID == "" || t.ExecutionAttemptID == "" { + return "" + } + return t.ID + "\x00" + t.ExecutionAttemptID +} + +func (e *taskExecutor) Execute(ctx context.Context, t task.Task) error { + e.processTask(ctx, t) + return nil +} + func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter.TaskReporter, channel ResultChannel) { + executor := &taskExecutor{config: tj.config, reporter: tr, channel: channel} + executor.processTask(ctx, t) +} + +func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string, reader io.Reader, sink *bytes.Buffer, channel ResultChannel) { + executor := &taskExecutor{config: tj.config, channel: channel} + executor.captureStream(ctx, t, stream, reader, sink) +} + +func (e *taskExecutor) processTask(ctx context.Context, t task.Task) { tempFile, err := os.CreateTemp("", "*_script.sh") if err != nil { t.Error = fmt.Sprintf("failed to create temp file: %v", err) t.Status = "failed" - if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ - Status: t.Status, - Output: t.Output, - Error: t.Error, - ExitCode: t.ExitCode, - }); reportErr != nil { - log.Errorf("failed to report task %s: %v", t.ID, reportErr) - } + e.reportHTTPResult(t, "failed", t.Output, t.Error, t.ExitCode) return } defer os.Remove(tempFile.Name()) @@ -114,14 +200,7 @@ func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter tempFile.Close() t.Error = fmt.Sprintf("failed to write script: %v", err) t.Status = "failed" - if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ - Status: t.Status, - Output: t.Output, - Error: t.Error, - ExitCode: t.ExitCode, - }); reportErr != nil { - log.Errorf("failed to report task %s: %v", t.ID, reportErr) - } + e.reportHTTPResult(t, "failed", t.Output, t.Error, t.ExitCode) return } tempFile.Close() @@ -129,19 +208,12 @@ func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter if err := os.Chmod(tempFile.Name(), 0755); err != nil { t.Error = fmt.Sprintf("failed to chmod: %v", err) t.Status = "failed" - if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ - Status: t.Status, - Output: t.Output, - Error: t.Error, - ExitCode: t.ExitCode, - }); reportErr != nil { - log.Errorf("failed to report task %s: %v", t.ID, reportErr) - } + e.reportHTTPResult(t, "failed", t.Output, t.Error, t.ExitCode) return } execCmd := exec.Command("/bin/sh", "-c", tempFile.Name()) - if channel != nil && t.ExecutionAttemptID != "" { - tj.processTaskWithResultChannel(ctx, t, execCmd, tr, channel) + if e.channel != nil && t.ExecutionAttemptID != "" { + e.processTaskWithResultChannel(ctx, t, execCmd) return } @@ -158,30 +230,33 @@ func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter t.Error = errMsg t.Output = string(output) t.Status = "completed" - if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ - Status: t.Status, - Output: t.Output, - Error: t.Error, - ExitCode: t.ExitCode, - }); reportErr != nil { - log.Errorf("failed to report task %s: %v", t.ID, reportErr) - } + e.reportHTTPResult(t, t.Status, t.Output, t.Error, t.ExitCode) } -func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task, execCmd *exec.Cmd, tr taskreporter.TaskReporter, channel ResultChannel) { +func (e *taskExecutor) processTaskWithResultChannel(ctx context.Context, t task.Task, execCmd *exec.Cmd) { stdout, err := execCmd.StdoutPipe() if err != nil { - tj.reportHTTPResult(t, tr, "failed", "", fmt.Sprintf("failed to capture stdout: %v", err), 1) + e.reportHTTPResult(t, "failed", "", fmt.Sprintf("failed to capture stdout: %v", err), 1) return } stderr, err := execCmd.StderrPipe() if err != nil { - tj.reportHTTPResult(t, tr, "failed", "", fmt.Sprintf("failed to capture stderr: %v", err), 1) + e.reportHTTPResult(t, "failed", "", fmt.Sprintf("failed to capture stderr: %v", err), 1) return } + receipt := localtaskstore.TaskReceipt{TaskID: t.ID, ExecutionAttemptID: t.ExecutionAttemptID} + if err := e.channel.RecordStarted(ctx, receipt); err != nil { + e.reportHTTPResult(t, "failed", "", fmt.Sprintf("failed to record task start: %v", err), 1) + return + } if err := execCmd.Start(); err != nil { - tj.reportHTTPResult(t, tr, "failed", "", err.Error(), 1) + e.reportHTTPResult(t, "failed", "", err.Error(), 1) + return + } + if err := e.channel.SendStarted(ctx, receipt); err != nil { + _ = execCmd.Process.Kill() + e.reportHTTPResult(t, "failed", "", fmt.Sprintf("failed to report task start: %v", err), 1) return } @@ -191,11 +266,11 @@ func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task wg.Add(2) go func() { defer wg.Done() - tj.captureStream(ctx, t, "stdout", stdout, &stdoutBuf, channel) + e.captureStream(ctx, t, "stdout", stdout, &stdoutBuf) }() go func() { defer wg.Done() - tj.captureStream(ctx, t, "stderr", stderr, &stderrBuf, channel) + e.captureStream(ctx, t, "stderr", stderr, &stderrBuf) }() wg.Wait() @@ -219,7 +294,7 @@ func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task resultPayload := taskreporter.TaskResult{Status: status, Output: output, Error: errMsg, ExitCode: exitCode} finalPayload, err := json.Marshal(resultPayload) if err != nil { - tj.reportHTTPResult(t, tr, status, output, errMsg, exitCode) + e.reportHTTPResult(t, status, output, errMsg, exitCode) return } @@ -231,19 +306,19 @@ func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task ExitCode: exitCode, Payload: string(finalPayload), } - if err := channel.SendFinal(ctx, final); err != nil { - tj.reportHTTPResult(t, tr, status, output, errMsg, exitCode) + if err := e.channel.SendFinal(ctx, final); err != nil { + e.reportHTTPResult(t, status, output, errMsg, exitCode) } } -func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string, reader io.Reader, sink *bytes.Buffer, channel ResultChannel) { +func (e *taskExecutor) captureStream(ctx context.Context, t task.Task, stream string, reader io.Reader, sink *bytes.Buffer) { sequence := int64(1) chunks := make(chan string, 1) go func() { defer close(chunks) - buffered := bufio.NewReaderSize(reader, tj.config.OutputFlushThreshold) + buffered := bufio.NewReaderSize(reader, e.config.OutputFlushThreshold) for { - buf := make([]byte, max(tj.config.OutputFlushThreshold, 1)) + buf := make([]byte, max(e.config.OutputFlushThreshold, 1)) n, err := buffered.Read(buf) if n > 0 { chunks <- string(buf[:n]) @@ -255,7 +330,7 @@ func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string }() var pending bytes.Buffer - ticker := time.NewTicker(tj.config.OutputFlushInterval) + ticker := time.NewTicker(e.config.OutputFlushInterval) defer ticker.Stop() flush := func() bool { @@ -263,7 +338,7 @@ func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string return true } chunk := pending.String() - err := channel.SendOutput(ctx, localtaskstore.OutputChunk{ + err := e.channel.SendOutput(ctx, localtaskstore.OutputChunk{ MessageID: messageID(t.ID, t.ExecutionAttemptID, stream, sequence), TaskID: t.ID, ExecutionAttemptID: t.ExecutionAttemptID, @@ -289,7 +364,7 @@ func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string } sink.WriteString(chunk) pending.WriteString(chunk) - if pending.Len() >= tj.config.OutputFlushThreshold { + if pending.Len() >= e.config.OutputFlushThreshold { flush() } case <-ticker.C: @@ -300,8 +375,8 @@ func (tj *TaskJob) captureStream(ctx context.Context, t task.Task, stream string } } -func (tj *TaskJob) reportHTTPResult(t task.Task, tr taskreporter.TaskReporter, status, output, errMsg string, exitCode int) { - if reportErr := tr.Report(t.ID, &taskreporter.TaskResult{ +func (e *taskExecutor) reportHTTPResult(t task.Task, status, output, errMsg string, exitCode int) { + if reportErr := e.reporter.Report(t.ID, &taskreporter.TaskResult{ Status: status, Output: output, Error: errMsg, diff --git a/app/services/localtaskstore/store.go b/app/services/localtaskstore/store.go index 460206e..b117a22 100644 --- a/app/services/localtaskstore/store.go +++ b/app/services/localtaskstore/store.go @@ -87,6 +87,7 @@ type Config struct { type ReceiptStore interface { RecordReceived(TaskReceipt) (TaskState, error) + RecordStarted(taskID, executionAttemptID string) error TaskState(taskID, executionAttemptID string) (TaskState, error) } diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go index 5b9a16a..4cee80a 100644 --- a/app/services/wsclient/client.go +++ b/app/services/wsclient/client.go @@ -13,6 +13,7 @@ import ( "hostlink/app/services/agentstate" "hostlink/app/services/requestsigner" + "hostlink/domain/task" "hostlink/internal/wsprotocol" ) @@ -31,6 +32,10 @@ type Conn interface { type SleepFunc func(context.Context, time.Duration) error +type TaskEnqueuer interface { + Enqueue(context.Context, task.Task) error +} + type Config struct { URL string AgentState *agentstate.AgentState @@ -41,6 +46,8 @@ type Config struct { PingInterval time.Duration SleepFunc SleepFunc ResultOutbox localtaskstore.ResultOutbox + ReceiptStore localtaskstore.ReceiptStore + TaskEnqueuer TaskEnqueuer } type Client struct { @@ -53,12 +60,14 @@ type Client struct { pingInterval time.Duration sleep SleepFunc - mu sync.RWMutex - writeMu sync.Mutex - active bool - lastAck *wsprotocol.AckPayload - conn Conn - outbox localtaskstore.ResultOutbox + mu sync.RWMutex + writeMu sync.Mutex + active bool + lastAck *wsprotocol.AckPayload + conn Conn + outbox localtaskstore.ResultOutbox + receipts localtaskstore.ReceiptStore + enqueuer TaskEnqueuer } func New(cfg Config) (*Client, error) { @@ -99,6 +108,8 @@ func New(cfg Config) (*Client, error) { pingInterval: cfg.PingInterval, sleep: cfg.SleepFunc, outbox: cfg.ResultOutbox, + receipts: cfg.ReceiptStore, + enqueuer: cfg.TaskEnqueuer, }, nil } @@ -234,12 +245,82 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) continue } return fmt.Errorf("websocket protocol error: %s", env.MessageID) + case wsprotocol.TypeTaskDeliver: + if err := c.receiveTaskDeliver(ctx, conn, env); err != nil { + return err + } default: return fmt.Errorf("unsupported inbound websocket message type: %s", env.Type) } } } +func (c *Client) receiveTaskDeliver(ctx context.Context, conn Conn, env wsprotocol.Envelope) error { + if c.receipts == nil { + return fmt.Errorf("receipt store is not configured") + } + payload, err := wsprotocol.DecodePayload[wsprotocol.TaskDeliverPayload](env) + if err != nil { + return err + } + if err := payload.Validate(); err != nil { + return err + } + + previous, err := c.receipts.TaskState(env.TaskID, env.ExecutionAttemptID) + if err != nil { + return err + } + state, err := c.receipts.RecordReceived(localtaskstore.TaskReceipt{ + TaskID: env.TaskID, + ExecutionAttemptID: env.ExecutionAttemptID, + }) + if err != nil { + return err + } + if previous.Exists && previous.Status == localtaskstore.TaskStatusRunning { + return c.writeEnvelope(ctx, conn, c.buildTaskStateEnvelope(wsprotocol.TypeTaskStarted, env.TaskID, env.ExecutionAttemptID)) + } + if previous.Exists && (previous.Status == localtaskstore.TaskStatusFinal || previous.Status == localtaskstore.TaskStatusInterrupted) { + replayed, err := c.replayTaskFinal(ctx, conn, env.TaskID, env.ExecutionAttemptID) + if err != nil { + return err + } + if replayed { + return nil + } + } + if err := c.writeEnvelope(ctx, conn, c.buildTaskStateEnvelope(wsprotocol.TypeTaskReceived, env.TaskID, env.ExecutionAttemptID)); err != nil { + return err + } + if !previous.Exists && state.Status == localtaskstore.TaskStatusReceived && c.enqueuer != nil { + return c.enqueuer.Enqueue(ctx, task.Task{ + ID: env.TaskID, + ExecutionAttemptID: env.ExecutionAttemptID, + Command: payload.Command, + Status: "pending", + Priority: payload.Priority, + }) + } + return nil +} + +func (c *Client) replayTaskFinal(ctx context.Context, conn Conn, taskID, executionAttemptID string) (bool, error) { + if c.outbox == nil { + return false, nil + } + messages, err := c.outbox.UnackedMessages() + if err != nil { + return false, err + } + for _, message := range messages { + if message.TaskID == taskID && message.ExecutionAttemptID == executionAttemptID && message.Type == localtaskstore.OutboxMessageTypeFinal { + return true, c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message)) + } + } + return false, nil +} + func (c *Client) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error { if c.outbox == nil { return fmt.Errorf("result outbox is not configured") @@ -283,6 +364,17 @@ func (c *Client) SendFinal(ctx context.Context, result localtaskstore.FinalResul return nil } +func (c *Client) SendStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + return c.sendIfActive(ctx, c.buildTaskStateEnvelope(wsprotocol.TypeTaskStarted, receipt.TaskID, receipt.ExecutionAttemptID)) +} + +func (c *Client) RecordStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error { + if c.receipts == nil { + return fmt.Errorf("receipt store cannot record started state") + } + return c.receipts.RecordStarted(receipt.TaskID, receipt.ExecutionAttemptID) +} + func (c *Client) buildHello() wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -294,6 +386,19 @@ func (c *Client) buildHello() wsprotocol.Envelope { } } +func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, taskID, executionAttemptID string) wsprotocol.Envelope { + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: fmt.Sprintf("msg_%s_%s_%d", taskID, messageType, time.Now().UnixNano()), + Type: messageType, + AgentID: c.agentID, + TaskID: taskID, + ExecutionAttemptID: executionAttemptID, + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: map[string]any{}, + } +} + func (c *Client) setActive(active bool) { c.mu.Lock() defer c.mu.Unlock() diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go index 62bc4ea..b45e580 100644 --- a/app/services/wsclient/client_test.go +++ b/app/services/wsclient/client_test.go @@ -9,6 +9,7 @@ import ( "encoding/pem" "errors" "hostlink/app/services/localtaskstore" + "hostlink/domain/task" "net/http" "os" "path/filepath" @@ -118,6 +119,179 @@ func TestClientHandlesAckWithoutTaskSideEffects(t *testing.T) { } } +func TestClientReceivesTaskDeliverStoresAcksAndQueues(t *testing.T) { + store := newClientTestStore(t) + enqueuer := &fakeTaskEnqueuer{} + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithTaskEnqueuer(enqueuer)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelope(hello.MessageID) + conn.readCh <- deliverEnvelope("msg_deliver", "task-1", "attempt-1", "printf hi", 2) + + received := conn.waitForWrite(t) + if received.Type != wsprotocol.TypeTaskReceived { + t.Fatalf("received type = %q, want %q", received.Type, wsprotocol.TypeTaskReceived) + } + if received.TaskID != "task-1" || received.ExecutionAttemptID != "attempt-1" { + t.Fatalf("received envelope = %#v", received) + } + state, err := store.TaskState("task-1", "attempt-1") + requireNoError(t, err) + if !state.Exists || state.Status != localtaskstore.TaskStatusReceived { + t.Fatalf("state = %#v", state) + } + waitFor(t, func() bool { return len(enqueuer.tasks()) == 1 }, "task to be queued") + queued := enqueuer.tasks()[0] + if queued.ID != "task-1" || queued.ExecutionAttemptID != "attempt-1" || queued.Command != "printf hi" || queued.Priority != 2 { + t.Fatalf("queued task = %#v", queued) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientRecordStartedPersistsRunningStateAndSendStartedSendsStarted(t *testing.T) { + store := newClientTestStore(t) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelope(hello.MessageID) + waitFor(t, func() bool { return client.IsActive() }, "client to become active") + requireNoError(t, client.RecordStarted(context.Background(), localtaskstore.TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"})) + requireNoError(t, client.SendStarted(context.Background(), localtaskstore.TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"})) + + started := conn.waitForWrite(t) + if started.Type != wsprotocol.TypeTaskStarted { + t.Fatalf("started type = %q", started.Type) + } + state, err := store.TaskState("task-1", "attempt-1") + requireNoError(t, err) + if state.Status != localtaskstore.TaskStatusRunning { + t.Fatalf("state = %#v", state) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientDuplicateTaskDeliverReacksWithoutDuplicateQueue(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.RecordStarted("task-1", "attempt-1")) + enqueuer := &fakeTaskEnqueuer{} + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithTaskEnqueuer(enqueuer)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelope(hello.MessageID) + conn.readCh <- deliverEnvelope("msg_deliver", "task-1", "attempt-1", "printf hi", 2) + + started := conn.waitForWrite(t) + if started.Type != wsprotocol.TypeTaskStarted { + t.Fatalf("started type = %q", started.Type) + } + if len(enqueuer.tasks()) != 0 { + t.Fatalf("queued tasks = %#v, want none", enqueuer.tasks()) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientReceivedDuplicateTaskDeliverReacksWithoutDuplicateQueue(t *testing.T) { + store := newClientTestStore(t) + _, err := store.RecordReceived(localtaskstore.TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"}) + requireNoError(t, err) + enqueuer := &fakeTaskEnqueuer{} + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithTaskEnqueuer(enqueuer)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelope(hello.MessageID) + conn.readCh <- deliverEnvelope("msg_deliver", "task-1", "attempt-1", "printf hi", 2) + + received := conn.waitForWrite(t) + if received.Type != wsprotocol.TypeTaskReceived { + t.Fatalf("received type = %q", received.Type) + } + if len(enqueuer.tasks()) != 0 { + t.Fatalf("queued tasks = %#v, want none", enqueuer.tasks()) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientFinalDuplicateTaskDeliverResendsUnackedFinalWithoutQueue(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + enqueuer := &fakeTaskEnqueuer{} + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithResultOutbox(store), WithTaskEnqueuer(enqueuer)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelope(hello.MessageID) + replayed := conn.waitForWrite(t) + if replayed.Type != wsprotocol.TypeTaskFinal { + t.Fatalf("hello replay type = %q", replayed.Type) + } + conn.readCh <- deliverEnvelope("msg_deliver", "task-1", "attempt-1", "printf hi", 2) + + final := conn.waitForWrite(t) + if final.Type != wsprotocol.TypeTaskFinal || final.MessageID != "msg-final-1" { + t.Fatalf("final duplicate response = %#v", final) + } + if len(enqueuer.tasks()) != 0 { + t.Fatalf("queued tasks = %#v, want none", enqueuer.tasks()) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientAckRemovesResultMessageFromOutbox(t *testing.T) { store := newClientTestStore(t) requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ @@ -419,6 +593,14 @@ func WithResultOutbox(outbox localtaskstore.ResultOutbox) clientOption { return func(cfg *Config) { cfg.ResultOutbox = outbox } } +func WithReceiptStore(store localtaskstore.ReceiptStore) clientOption { + return func(cfg *Config) { cfg.ReceiptStore = store } +} + +func WithTaskEnqueuer(enqueuer TaskEnqueuer) clientOption { + return func(cfg *Config) { cfg.TaskEnqueuer = enqueuer } +} + type fakeDialer struct { mu sync.Mutex conn *fakeConn @@ -544,6 +726,64 @@ func ackEnvelope(messageID, ackedMessageID string, ackedType wsprotocol.MessageT } } +func helloAckEnvelope(ackedMessageID string) wsprotocol.Envelope { + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_hello_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMapForTest(wsprotocol.BuildAck(wsprotocol.AckOptions{ + AckedMessageID: ackedMessageID, + AckedType: wsprotocol.TypeAgentHello, + })), + } +} + +func deliverEnvelope(messageID, taskID, attemptID, command string, priority int) wsprotocol.Envelope { + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: messageID, + Type: wsprotocol.TypeTaskDeliver, + AgentID: "agent_ws_test", + TaskID: taskID, + ExecutionAttemptID: attemptID, + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: map[string]any{ + "command": command, + "priority": priority, + }, + } +} + +func payloadMapForTest(value any) map[string]any { + data, _ := json.Marshal(value) + var payload map[string]any + _ = json.Unmarshal(data, &payload) + return payload +} + +type fakeTaskEnqueuer struct { + mu sync.Mutex + queued []task.Task + enqueue error +} + +func (f *fakeTaskEnqueuer) Enqueue(ctx context.Context, t task.Task) error { + f.mu.Lock() + defer f.mu.Unlock() + f.queued = append(f.queued, t) + return f.enqueue +} + +func (f *fakeTaskEnqueuer) tasks() []task.Task { + f.mu.Lock() + defer f.mu.Unlock() + tasks := make([]task.Task, len(f.queued)) + copy(tasks, f.queued) + return tasks +} + func newClientTestStore(t *testing.T) *localtaskstore.Store { t.Helper() store, err := localtaskstore.New(localtaskstore.Config{ diff --git a/internal/wsprotocol/message.go b/internal/wsprotocol/message.go index 80f375d..c1878bc 100644 --- a/internal/wsprotocol/message.go +++ b/internal/wsprotocol/message.go @@ -65,6 +65,11 @@ type FinalPayload struct { ErrorTruncated bool `json:"error_truncated"` } +type TaskDeliverPayload struct { + Command string `json:"command"` + Priority int `json:"priority"` +} + func (e Envelope) Validate(authenticatedAgentID string) error { if e.ProtocolVersion != ProtocolVersion { return fmt.Errorf("unsupported protocol_version: %d", e.ProtocolVersion) @@ -120,6 +125,13 @@ func (p FinalPayload) Validate() error { } } +func (p TaskDeliverPayload) Validate() error { + if p.Command == "" { + return fmt.Errorf("command is required") + } + return nil +} + func DecodePayload[T any](e Envelope) (T, error) { var payload T diff --git a/main.go b/main.go index 7208b8a..039fe9c 100644 --- a/main.go +++ b/main.go @@ -268,8 +268,9 @@ func runServer(ctx context.Context, cmd *cli.Command) error { <-registeredChan log.Println("Agent registered, starting task job...") var resultChannel taskjob.ResultChannel + taskJob := taskjob.New() startWebSocketClientIfEnabled(ctx, func() (webSocketRuntime, error) { - runtime, err := newDefaultWebSocketRuntime(localStore) + runtime, err := newDefaultWebSocketRuntime(localStore, taskJob) if err == nil { resultChannel = runtime.(taskjob.ResultChannel) } @@ -286,7 +287,6 @@ func runServer(ctx context.Context, cmd *cli.Command) error { log.Printf("failed to initialize task reporter: %v", err) return } - taskJob := taskjob.New() taskJob.Register(ctx, fetcher, reporter, resultChannel) metricsReporter, err := metrics.New() @@ -335,7 +335,7 @@ func startWebSocketClientIfEnabled(ctx context.Context, constructor func() (webS return true } -func newDefaultWebSocketRuntime(localStore *localtaskstore.Store) (webSocketRuntime, error) { +func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wsclient.TaskEnqueuer) (webSocketRuntime, error) { state := agentstate.New(appconf.AgentStatePath()) if err := state.Load(); err != nil { return nil, fmt.Errorf("failed to load agent state: %w", err) @@ -351,6 +351,8 @@ func newDefaultWebSocketRuntime(localStore *localtaskstore.Store) (webSocketRunt ReconnectMax: appconf.WebSocketReconnectMax(), PingInterval: appconf.WebSocketPingInterval(), ResultOutbox: localStore, + ReceiptStore: localStore, + TaskEnqueuer: enqueuer, }) } diff --git a/test/integration/taskjob_reporter_test.go b/test/integration/taskjob_reporter_test.go index 6538207..a3d9cc0 100644 --- a/test/integration/taskjob_reporter_test.go +++ b/test/integration/taskjob_reporter_test.go @@ -22,6 +22,7 @@ import ( "os" "sync" "testing" + "time" "github.com/glebarez/sqlite" "github.com/labstack/echo/v4" @@ -48,6 +49,7 @@ func TestTaskJobReporter_SendsUpdateViaAPI(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() mu.Lock() updateReceived = true mu.Unlock() @@ -57,7 +59,6 @@ func TestTaskJobReporter_SendsUpdateViaAPI(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -65,7 +66,7 @@ func TestTaskJobReporter_SendsUpdateViaAPI(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() @@ -94,6 +95,7 @@ func TestTaskJobReporter_CapturesTaskOutput(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() var req map[string]any if err := c.Bind(&req); err != nil { return err @@ -109,7 +111,6 @@ func TestTaskJobReporter_CapturesTaskOutput(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -117,7 +118,7 @@ func TestTaskJobReporter_CapturesTaskOutput(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() @@ -142,6 +143,7 @@ func TestTaskJobReporter_SendsExitCode(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() var req map[string]any if err := c.Bind(&req); err != nil { return err @@ -157,7 +159,6 @@ func TestTaskJobReporter_SendsExitCode(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -165,7 +166,7 @@ func TestTaskJobReporter_SendsExitCode(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() @@ -191,6 +192,7 @@ func TestTaskJobReporter_SendsErrorMessages(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() var req map[string]any if err := c.Bind(&req); err != nil { return err @@ -209,7 +211,6 @@ func TestTaskJobReporter_SendsErrorMessages(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -217,7 +218,7 @@ func TestTaskJobReporter_SendsErrorMessages(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() @@ -243,6 +244,7 @@ func TestTaskJobReporter_IncludesAuthHeaders(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() mu.Lock() hasAgentID = c.Request().Header.Get("X-Agent-ID") != "" hasTimestamp = c.Request().Header.Get("X-Timestamp") != "" @@ -255,7 +257,6 @@ func TestTaskJobReporter_IncludesAuthHeaders(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -263,7 +264,7 @@ func TestTaskJobReporter_IncludesAuthHeaders(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() @@ -291,6 +292,7 @@ func TestTaskJobReporter_FailedUpdateIsLogged(t *testing.T) { wg.Add(1) env.echo.PUT("/api/v1/tasks/:id", func(c echo.Context) error { + defer wg.Done() mu.Lock() updateAttempted = true mu.Unlock() @@ -300,7 +302,6 @@ func TestTaskJobReporter_FailedUpdateIsLogged(t *testing.T) { job := taskjob.NewJobWithConf(taskjob.TaskJobConfig{ Trigger: func(ctx context.Context, fn func() error) { fn() - wg.Done() }, }) defer job.Shutdown() @@ -308,13 +309,27 @@ func TestTaskJobReporter_FailedUpdateIsLogged(t *testing.T) { ctx := context.Background() job.Register(ctx, env.fetcher, env.reporter) - wg.Wait() + waitForTaskJobReporterUpdate(t, &wg) mu.Lock() defer mu.Unlock() assert.True(t, updateAttempted, "Update should be attempted even if it fails") } +func waitForTaskJobReporterUpdate(t *testing.T, wg *sync.WaitGroup) { + t.Helper() + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for task update") + } +} + type taskJobTestEnv struct { db *gorm.DB echo *echo.Echo From d5bc777407310d4338d7887a54283a79b03eb325 Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Tue, 28 Apr 2026 14:10:43 +0530 Subject: [PATCH 2/2] feat(ws): add reconnect reconciliation with agent snapshot and hello-ack directives --- app/services/localtaskstore/execution_test.go | 16 ++ app/services/localtaskstore/recovery_test.go | 54 ++++ app/services/localtaskstore/store.go | 234 +++++++++++++++++- app/services/wsclient/client.go | 140 ++++++++++- app/services/wsclient/client_test.go | 186 +++++++++++++- internal/wsprotocol/message.go | 111 +++++++++ internal/wsprotocol/message_test.go | 63 +++++ main.go | 1 + 8 files changed, 791 insertions(+), 14 deletions(-) diff --git a/app/services/localtaskstore/execution_test.go b/app/services/localtaskstore/execution_test.go index 64a593a..8208969 100644 --- a/app/services/localtaskstore/execution_test.go +++ b/app/services/localtaskstore/execution_test.go @@ -51,6 +51,22 @@ func TestTaskStateTreatsNewAttemptAsDistinct(t *testing.T) { require.False(t, state.Exists) } +func TestDiscardReceivedRemovesOnlyReceivedAttempt(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-received"}) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-1", "attempt-running")) + + require.NoError(t, store.DiscardReceived("task-1", "attempt-received")) + received, err := store.TaskState("task-1", "attempt-received") + require.NoError(t, err) + require.False(t, received.Exists) + running, err := store.TaskState("task-1", "attempt-running") + require.NoError(t, err) + require.Equal(t, TaskStatusRunning, running.Status) +} + func openTestStore(t *testing.T, path string, spoolCapBytes, terminalReserveBytes int64) *Store { t.Helper() diff --git a/app/services/localtaskstore/recovery_test.go b/app/services/localtaskstore/recovery_test.go index e1bb0b3..c8aeda5 100644 --- a/app/services/localtaskstore/recovery_test.go +++ b/app/services/localtaskstore/recovery_test.go @@ -20,6 +20,47 @@ func TestSnapshotIncludesLocalTruncationState(t *testing.T) { require.True(t, snapshot.Tasks[0].LocalOutputTruncated) } +func TestSnapshotIncludesReconnectSafeState(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-received", ExecutionAttemptID: "attempt-received"}) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-running", "attempt-running")) + require.NoError(t, store.AppendOutputChunk(OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 4, + Payload: "hello", + ByteCount: 5, + })) + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + + snapshot, err := store.Snapshot() + require.NoError(t, err) + require.NotNil(t, snapshot.RunningTask) + require.Equal(t, "task-running", snapshot.RunningTask.TaskID) + require.Equal(t, int64(4), snapshot.RunningTask.LastOutputSequence["stdout"]) + require.Len(t, snapshot.ReceivedNotStarted, 1) + require.Equal(t, "task-received", snapshot.ReceivedNotStarted[0].TaskID) + require.Len(t, snapshot.UnackedFinals, 1) + require.Equal(t, "msg-final-1", snapshot.UnackedFinals[0].MessageID) + require.Equal(t, "completed", snapshot.UnackedFinals[0].Status) + require.Len(t, snapshot.UnackedOutput, 1) + require.Equal(t, int64(4), snapshot.UnackedOutput[0].FirstSequence) + require.Equal(t, int64(4), snapshot.UnackedOutput[0].LastSequence) + require.Positive(t, snapshot.SpoolStatus.BytesUsed) + require.Equal(t, int64(1024*1024), snapshot.SpoolStatus.ByteCap) +} + func TestMarkInterruptedRunningTasksQueuesTerminalRecordAcrossRestart(t *testing.T) { storePath := filepath.Join(t.TempDir(), "task_store.db") store := openTestStore(t, storePath, 1024*1024, 1024) @@ -41,6 +82,19 @@ func TestMarkInterruptedRunningTasksQueuesTerminalRecordAcrossRestart(t *testing require.Contains(t, messages[0].Payload, "interrupted") } +func TestSnapshotReportsInterruptedStartupRecoveryAsUnackedFinal(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.MarkInterruptedRunningTasks()) + + snapshot, err := store.Snapshot() + require.NoError(t, err) + require.Nil(t, snapshot.RunningTask) + require.Len(t, snapshot.UnackedFinals, 1) + require.Equal(t, "interrupted", snapshot.UnackedFinals[0].Status) +} + func TestMarkInterruptedRunningTasksIsIdempotent(t *testing.T) { store := newTestStore(t, 1024*1024, 1024) diff --git a/app/services/localtaskstore/store.go b/app/services/localtaskstore/store.go index b117a22..2d91e53 100644 --- a/app/services/localtaskstore/store.go +++ b/app/services/localtaskstore/store.go @@ -1,6 +1,7 @@ package localtaskstore import ( + "encoding/json" "errors" "fmt" "os" @@ -41,6 +42,8 @@ type TaskState struct { OutputTruncated bool ErrorTruncated bool LocalOutputTruncated bool + ReceivedAt time.Time + StartedAt time.Time } type OutputChunk struct { @@ -76,7 +79,50 @@ type OutboxMessage struct { } type Snapshot struct { - Tasks []TaskState + Tasks []TaskState + RunningTask *RunningTaskSnapshot + ReceivedNotStarted []ReceivedAttemptSnapshot + UnackedFinals []UnackedFinalSnapshot + UnackedOutput []UnackedOutputRange + SpoolStatus SpoolStatus +} + +type RunningTaskSnapshot struct { + TaskID string + ExecutionAttemptID string + StartedAt time.Time + LastOutputSequence map[string]int64 +} + +type ReceivedAttemptSnapshot struct { + TaskID string + ExecutionAttemptID string + ReceivedAt time.Time +} + +type UnackedFinalSnapshot struct { + MessageID string + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + OutputTruncated bool + ErrorTruncated bool +} + +type UnackedOutputRange struct { + TaskID string + ExecutionAttemptID string + Stream string + FirstSequence int64 + LastSequence int64 + TruncatedLocally bool +} + +type SpoolStatus struct { + BytesUsed int64 + ByteCap int64 + HasRotatedChunks bool } type Config struct { @@ -96,11 +142,13 @@ type ResultOutbox interface { RecordFinal(FinalResult) error UnackedMessages() ([]OutboxMessage, error) AckMessage(messageID string) error + UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) } type RecoveryStore interface { Snapshot() (Snapshot, error) MarkInterruptedRunningTasks() error + DiscardReceived(taskID, executionAttemptID string) error } type Store struct { @@ -118,6 +166,8 @@ type taskExecutionRecord struct { OutputTruncated bool ErrorTruncated bool LocalOutputTruncated bool + ReceivedAt *time.Time + StartedAt *time.Time CreatedAt time.Time UpdatedAt time.Time } @@ -223,6 +273,7 @@ func (s *Store) RecordReceived(receipt TaskReceipt) (TaskState, error) { var state TaskState err := s.db.Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() if err := s.ensureTerminalReserveAvailable(tx); err != nil { return err } @@ -241,6 +292,7 @@ func (s *Store) RecordReceived(receipt TaskReceipt) (TaskState, error) { TaskID: receipt.TaskID, ExecutionAttemptID: receipt.ExecutionAttemptID, Status: TaskStatusReceived, + ReceivedAt: &now, } if err := tx.Create(&record).Error; err != nil { return err @@ -280,10 +332,12 @@ func (s *Store) RecordStarted(taskID, executionAttemptID string) error { } return s.db.Transaction(func(tx *gorm.DB) error { + now := time.Now().UTC() return s.upsertExecutionState(tx, taskExecutionRecord{ TaskID: taskID, ExecutionAttemptID: executionAttemptID, Status: TaskStatusRunning, + StartedAt: &now, }) }) } @@ -464,6 +518,39 @@ func (s *Store) UnackedMessages() ([]OutboxMessage, error) { return messages, nil } +func (s *Store) UnackedMessagesFrom(taskID, executionAttemptID, stream string, nextSequence int64) ([]OutboxMessage, error) { + if taskID == "" { + return nil, fmt.Errorf("task ID is required") + } + if executionAttemptID == "" { + return nil, fmt.Errorf("execution attempt ID is required") + } + if stream == "" { + return nil, fmt.Errorf("stream is required") + } + if nextSequence <= 0 { + return nil, fmt.Errorf("next sequence must be positive") + } + + var records []outboxMessageRecord + if err := s.db.Where( + "acked_at IS NULL AND type = ? AND task_id = ? AND execution_attempt_id = ? AND stream = ? AND sequence >= ?", + OutboxMessageTypeOutput, + taskID, + executionAttemptID, + stream, + nextSequence, + ).Order("sequence ASC, created_at ASC, id ASC").Find(&records).Error; err != nil { + return nil, fmt.Errorf("load unacked messages from sequence: %w", err) + } + + messages := make([]OutboxMessage, 0, len(records)) + for _, record := range records { + messages = append(messages, outboxMessageFromRecord(record)) + } + return messages, nil +} + func (s *Store) AckMessage(messageID string) error { if messageID == "" { return fmt.Errorf("message ID is required") @@ -476,16 +563,105 @@ func (s *Store) AckMessage(messageID string) error { } func (s *Store) Snapshot() (Snapshot, error) { - var records []taskExecutionRecord - if err := s.db.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { + var snapshot Snapshot + err := s.db.Transaction(func(tx *gorm.DB) error { + var records []taskExecutionRecord + if err := tx.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { + return err + } + + snapshot = Snapshot{ + Tasks: make([]TaskState, 0, len(records)), + ReceivedNotStarted: []ReceivedAttemptSnapshot{}, + UnackedFinals: []UnackedFinalSnapshot{}, + UnackedOutput: []UnackedOutputRange{}, + SpoolStatus: SpoolStatus{ByteCap: s.spoolCapBytes}, + } + stateByAttempt := make(map[string]TaskState, len(records)) + for _, record := range records { + state := taskStateFromRecord(record) + snapshot.Tasks = append(snapshot.Tasks, state) + stateByAttempt[attemptKey(state.TaskID, state.ExecutionAttemptID)] = state + switch state.Status { + case TaskStatusRunning: + running := RunningTaskSnapshot{ + TaskID: state.TaskID, + ExecutionAttemptID: state.ExecutionAttemptID, + StartedAt: state.StartedAt, + LastOutputSequence: map[string]int64{"stdout": 0, "stderr": 0}, + } + snapshot.RunningTask = &running + case TaskStatusReceived: + snapshot.ReceivedNotStarted = append(snapshot.ReceivedNotStarted, ReceivedAttemptSnapshot{ + TaskID: state.TaskID, + ExecutionAttemptID: state.ExecutionAttemptID, + ReceivedAt: state.ReceivedAt, + }) + } + } + + var outbox []outboxMessageRecord + if err := tx.Where("acked_at IS NULL").Order("created_at ASC, id ASC").Find(&outbox).Error; err != nil { + return err + } + outputRanges := map[string]*UnackedOutputRange{} + for _, message := range outbox { + snapshot.SpoolStatus.BytesUsed += message.ByteCount + state := stateByAttempt[attemptKey(message.TaskID, message.ExecutionAttemptID)] + if state.LocalOutputTruncated { + snapshot.SpoolStatus.HasRotatedChunks = true + } + switch message.Type { + case OutboxMessageTypeFinal: + snapshot.UnackedFinals = append(snapshot.UnackedFinals, finalSnapshotFromRecord(message)) + case OutboxMessageTypeOutput: + key := attemptKey(message.TaskID, message.ExecutionAttemptID) + "|" + message.Stream + rangeSnapshot := outputRanges[key] + if rangeSnapshot == nil { + rangeSnapshot = &UnackedOutputRange{ + TaskID: message.TaskID, + ExecutionAttemptID: message.ExecutionAttemptID, + Stream: message.Stream, + FirstSequence: message.Sequence, + LastSequence: message.Sequence, + TruncatedLocally: state.LocalOutputTruncated, + } + outputRanges[key] = rangeSnapshot + } else { + if message.Sequence < rangeSnapshot.FirstSequence { + rangeSnapshot.FirstSequence = message.Sequence + } + if message.Sequence > rangeSnapshot.LastSequence { + rangeSnapshot.LastSequence = message.Sequence + } + rangeSnapshot.TruncatedLocally = rangeSnapshot.TruncatedLocally || state.LocalOutputTruncated + } + if snapshot.RunningTask != nil && snapshot.RunningTask.TaskID == message.TaskID && snapshot.RunningTask.ExecutionAttemptID == message.ExecutionAttemptID { + if message.Sequence > snapshot.RunningTask.LastOutputSequence[message.Stream] { + snapshot.RunningTask.LastOutputSequence[message.Stream] = message.Sequence + } + } + } + } + for _, outputRange := range outputRanges { + snapshot.UnackedOutput = append(snapshot.UnackedOutput, *outputRange) + } + return nil + }) + if err != nil { return Snapshot{}, fmt.Errorf("load task snapshot: %w", err) } + return snapshot, nil +} - snapshot := Snapshot{Tasks: make([]TaskState, 0, len(records))} - for _, record := range records { - snapshot.Tasks = append(snapshot.Tasks, taskStateFromRecord(record)) +func (s *Store) DiscardReceived(taskID, executionAttemptID string) error { + if taskID == "" { + return fmt.Errorf("task ID is required") } - return snapshot, nil + if executionAttemptID == "" { + return fmt.Errorf("execution attempt ID is required") + } + return s.db.Where("task_id = ? AND execution_attempt_id = ? AND status = ?", taskID, executionAttemptID, TaskStatusReceived).Delete(&taskExecutionRecord{}).Error } func (s *Store) MarkInterruptedRunningTasks() error { @@ -553,6 +729,12 @@ func (s *Store) upsertExecutionState(tx *gorm.DB, next taskExecutionRecord) erro if next.LocalOutputTruncated { updates["local_output_truncated"] = true } + if next.ReceivedAt != nil { + updates["received_at"] = next.ReceivedAt + } + if next.StartedAt != nil { + updates["started_at"] = next.StartedAt + } return tx.Model(&existing).Updates(updates).Error } @@ -595,6 +777,14 @@ func validateFinalResult(result FinalResult) error { } func taskStateFromRecord(record taskExecutionRecord) TaskState { + var receivedAt time.Time + if record.ReceivedAt != nil { + receivedAt = *record.ReceivedAt + } + var startedAt time.Time + if record.StartedAt != nil { + startedAt = *record.StartedAt + } return TaskState{ ID: record.ID, Exists: true, @@ -605,6 +795,8 @@ func taskStateFromRecord(record taskExecutionRecord) TaskState { OutputTruncated: record.OutputTruncated, ErrorTruncated: record.ErrorTruncated, LocalOutputTruncated: record.LocalOutputTruncated, + ReceivedAt: receivedAt, + StartedAt: startedAt, } } @@ -624,3 +816,31 @@ func outboxMessageFromRecord(record outboxMessageRecord) OutboxMessage { func interruptedMessageID(taskID, executionAttemptID string) string { return "local-interrupted-" + strings.NewReplacer("/", "-", " ", "-", "|", "-").Replace(taskID+"-"+executionAttemptID) } + +func attemptKey(taskID, executionAttemptID string) string { + return taskID + "|" + executionAttemptID +} + +func finalSnapshotFromRecord(record outboxMessageRecord) UnackedFinalSnapshot { + snapshot := UnackedFinalSnapshot{ + MessageID: record.MessageID, + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + } + if record.Payload == "" { + return snapshot + } + var payload struct { + Status string `json:"status"` + ExitCode int `json:"exit_code"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` + } + if err := json.Unmarshal([]byte(record.Payload), &payload); err == nil { + snapshot.Status = payload.Status + snapshot.ExitCode = payload.ExitCode + snapshot.OutputTruncated = payload.OutputTruncated + snapshot.ErrorTruncated = payload.ErrorTruncated + } + return snapshot +} diff --git a/app/services/wsclient/client.go b/app/services/wsclient/client.go index 4cee80a..1f2897f 100644 --- a/app/services/wsclient/client.go +++ b/app/services/wsclient/client.go @@ -15,6 +15,7 @@ import ( "hostlink/app/services/requestsigner" "hostlink/domain/task" "hostlink/internal/wsprotocol" + "hostlink/version" ) var ErrAgentNotRegistered = errors.New("agent not registered: missing agent ID") @@ -47,6 +48,7 @@ type Config struct { SleepFunc SleepFunc ResultOutbox localtaskstore.ResultOutbox ReceiptStore localtaskstore.ReceiptStore + RecoveryStore localtaskstore.RecoveryStore TaskEnqueuer TaskEnqueuer } @@ -67,6 +69,7 @@ type Client struct { conn Conn outbox localtaskstore.ResultOutbox receipts localtaskstore.ReceiptStore + recovery localtaskstore.RecoveryStore enqueuer TaskEnqueuer } @@ -109,6 +112,7 @@ func New(cfg Config) (*Client, error) { sleep: cfg.SleepFunc, outbox: cfg.ResultOutbox, receipts: cfg.ReceiptStore, + recovery: cfg.RecoveryStore, enqueuer: cfg.TaskEnqueuer, }, nil } @@ -214,16 +218,26 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string) switch env.Type { case wsprotocol.TypeAgentHelloAck: - ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) + helloAck, err := wsprotocol.DecodePayload[wsprotocol.HelloAckPayload](env) if err != nil { return err } - if ack.AckedMessageID == helloMessageID { - c.setActive(true) - if err := c.replayUnacked(ctx, conn); err != nil { + if helloAck.AckedMessageID == helloMessageID { + if err := c.applyHelloAckLocalState(helloAck); err != nil { return err } + c.setActive(true) + if helloAck.HasReconciliationDirectives() { + if err := c.replayRequestedOutput(ctx, conn, helloAck.OutputReplay); err != nil { + return err + } + } else { + if err := c.replayUnacked(ctx, conn); err != nil { + return err + } + } } + ack := wsprotocol.AckPayload{AckedMessageID: helloAck.AckedMessageID, AckedType: helloAck.AckedType} c.setLastAck(&ack) case wsprotocol.TypeAck: ack, err := wsprotocol.DecodePayload[wsprotocol.AckPayload](env) @@ -376,16 +390,115 @@ func (c *Client) RecordStarted(ctx context.Context, receipt localtaskstore.TaskR } func (c *Client) buildHello() wsprotocol.Envelope { + payload := c.buildHelloPayload() return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, MessageID: fmt.Sprintf("msg_%d", time.Now().UnixNano()), Type: wsprotocol.TypeAgentHello, AgentID: c.agentID, SentAt: time.Now().UTC().Format(time.RFC3339), - Payload: map[string]any{}, + Payload: payloadFromValue(payload), } } +func (c *Client) buildHelloPayload() wsprotocol.HelloPayload { + payload := wsprotocol.HelloPayload{ + ReceivedNotStarted: []wsprotocol.ReceivedNotStartedAttempt{}, + UnackedFinals: []wsprotocol.UnackedFinalSnapshot{}, + UnackedOutput: []wsprotocol.UnackedOutputRange{}, + SpoolStatus: wsprotocol.SpoolStatus{}, + ClientVersion: version.Version, + } + if c.recovery == nil { + return payload + } + snapshot, err := c.recovery.Snapshot() + if err != nil { + return payload + } + if snapshot.RunningTask != nil { + payload.RunningTask = &wsprotocol.RunningTaskSnapshot{ + TaskID: snapshot.RunningTask.TaskID, + ExecutionAttemptID: snapshot.RunningTask.ExecutionAttemptID, + StartedAt: formatTime(snapshot.RunningTask.StartedAt), + LastOutputSequence: map[string]int{ + "stdout": int(snapshot.RunningTask.LastOutputSequence["stdout"]), + "stderr": int(snapshot.RunningTask.LastOutputSequence["stderr"]), + }, + } + } + for _, attempt := range snapshot.ReceivedNotStarted { + payload.ReceivedNotStarted = append(payload.ReceivedNotStarted, wsprotocol.ReceivedNotStartedAttempt{ + TaskID: attempt.TaskID, + ExecutionAttemptID: attempt.ExecutionAttemptID, + ReceivedAt: formatTime(attempt.ReceivedAt), + }) + } + for _, final := range snapshot.UnackedFinals { + payload.UnackedFinals = append(payload.UnackedFinals, wsprotocol.UnackedFinalSnapshot{ + MessageID: final.MessageID, + TaskID: final.TaskID, + ExecutionAttemptID: final.ExecutionAttemptID, + Status: wsprotocol.FinalStatus(final.Status), + ExitCode: final.ExitCode, + OutputTruncated: final.OutputTruncated, + ErrorTruncated: final.ErrorTruncated, + }) + } + for _, outputRange := range snapshot.UnackedOutput { + payload.UnackedOutput = append(payload.UnackedOutput, wsprotocol.UnackedOutputRange{ + TaskID: outputRange.TaskID, + ExecutionAttemptID: outputRange.ExecutionAttemptID, + Stream: wsprotocol.Stream(outputRange.Stream), + FirstSequence: int(outputRange.FirstSequence), + LastSequence: int(outputRange.LastSequence), + TruncatedLocally: outputRange.TruncatedLocally, + }) + } + payload.SpoolStatus = wsprotocol.SpoolStatus{ + BytesUsed: snapshot.SpoolStatus.BytesUsed, + ByteCap: snapshot.SpoolStatus.ByteCap, + HasRotatedChunks: snapshot.SpoolStatus.HasRotatedChunks, + } + return payload +} + +func (c *Client) applyHelloAckLocalState(ack wsprotocol.HelloAckPayload) error { + if c.outbox != nil { + for _, messageID := range ack.AcknowledgedFinalMessageIDs { + if err := c.outbox.AckMessage(messageID); err != nil { + return err + } + } + } + if c.recovery != nil { + for _, attempt := range ack.DiscardedAttempts { + if err := c.recovery.DiscardReceived(attempt.TaskID, attempt.ExecutionAttemptID); err != nil { + return err + } + } + } + return nil +} + +func (c *Client) replayRequestedOutput(ctx context.Context, conn Conn, replayRequests []wsprotocol.OutputReplayDirective) error { + if c.outbox == nil { + return nil + } + for _, replay := range replayRequests { + messages, err := c.outbox.UnackedMessagesFrom(replay.TaskID, replay.ExecutionAttemptID, string(replay.Stream), int64(replay.NextSequence)) + if err != nil { + return err + } + for _, message := range messages { + if err := c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message)); err != nil { + return err + } + } + } + return nil +} + func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, taskID, executionAttemptID string) wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -493,6 +606,23 @@ func envelopeFromOutboxMessage(agentID string, message localtaskstore.OutboxMess } } +func payloadFromValue(value any) map[string]any { + data, _ := json.Marshal(value) + var payload map[string]any + _ = json.Unmarshal(data, &payload) + if payload == nil { + return map[string]any{} + } + return payload +} + +func formatTime(value time.Time) string { + if value.IsZero() { + return "" + } + return value.UTC().Format(time.RFC3339) +} + func sleepContext(ctx context.Context, d time.Duration) error { timer := time.NewTimer(d) defer timer.Stop() diff --git a/app/services/wsclient/client_test.go b/app/services/wsclient/client_test.go index b45e580..e326c5f 100644 --- a/app/services/wsclient/client_test.go +++ b/app/services/wsclient/client_test.go @@ -38,8 +38,10 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { if written.AgentID != "agent_ws_test" { t.Fatalf("written agent_id = %q", written.AgentID) } - if len(written.Payload) != 0 { - t.Fatalf("hello payload = %#v, want empty object", written.Payload) + for _, key := range []string{"running_task", "received_not_started", "unacked_finals", "unacked_output", "spool_status", "client_version"} { + if _, ok := written.Payload[key]; !ok { + t.Fatalf("hello payload missing %q: %#v", key, written.Payload) + } } conn.readCh <- wsprotocol.Envelope{ @@ -61,6 +63,56 @@ func TestClientSendsHelloAndMarksActiveAfterHelloAck(t *testing.T) { } } +func TestClientHelloPayloadIncludesLocalReconnectSnapshot(t *testing.T) { + store := newClientTestStore(t) + _, err := store.RecordReceived(localtaskstore.TaskReceipt{TaskID: "task-received", ExecutionAttemptID: "attempt-received"}) + requireNoError(t, err) + requireNoError(t, store.RecordStarted("task-running", "attempt-running")) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 3, + Payload: "hello", + ByteCount: 5, + })) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + written := conn.waitForWrite(t) + running, ok := written.Payload["running_task"].(map[string]any) + if !ok || running["task_id"] != "task-running" { + t.Fatalf("running_task = %#v", written.Payload["running_task"]) + } + if len(written.Payload["received_not_started"].([]any)) != 1 { + t.Fatalf("received_not_started = %#v", written.Payload["received_not_started"]) + } + if len(written.Payload["unacked_finals"].([]any)) != 1 { + t.Fatalf("unacked_finals = %#v", written.Payload["unacked_finals"]) + } + if len(written.Payload["unacked_output"].([]any)) != 1 { + t.Fatalf("unacked_output = %#v", written.Payload["unacked_output"]) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientDialUsesSignedUpgradeHeaders(t *testing.T) { conn := newFakeConn() dialer := &fakeDialer{conn: conn} @@ -380,6 +432,101 @@ func TestClientReplaysUnackedMessagesAfterHelloAck(t *testing.T) { } } +func TestClientHelloAckAppliesFinalDiscardAndTargetedReplayDirectives(t *testing.T) { + store := newClientTestStore(t) + _, err := store.RecordReceived(localtaskstore.TaskReceipt{TaskID: "task-stale", ExecutionAttemptID: "attempt-stale"}) + requireNoError(t, err) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 1, + Payload: "old", + ByteCount: 3, + })) + requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ + MessageID: "msg-output-2", + TaskID: "task-running", + ExecutionAttemptID: "attempt-running", + Stream: "stdout", + Sequence: 2, + Payload: "needed", + ByteCount: 6, + })) + requireNoError(t, store.RecordFinal(localtaskstore.FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-final", + ExecutionAttemptID: "attempt-final", + Status: "completed", + Payload: `{"status":"completed","exit_code":0,"output_truncated":false,"error_truncated":false}`, + })) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithResultOutbox(store), WithReceiptStore(store), WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{ + AcknowledgedFinalMessageIDs: []string{"msg-final-1"}, + DiscardedAttempts: []wsprotocol.DiscardedAttempt{{TaskID: "task-stale", ExecutionAttemptID: "attempt-stale", Reason: "stale_attempt"}}, + OutputReplay: []wsprotocol.OutputReplayDirective{{TaskID: "task-running", ExecutionAttemptID: "attempt-running", Stream: wsprotocol.StreamStdout, NextSequence: 2}}, + }) + + replayed := conn.waitForWrite(t) + if replayed.MessageID != "msg-output-2" { + t.Fatalf("replayed message = %#v", replayed) + } + waitFor(t, func() bool { return client.IsActive() }, "client active after directives") + state, err := store.TaskState("task-stale", "attempt-stale") + requireNoError(t, err) + if state.Exists { + t.Fatalf("discarded state still exists: %#v", state) + } + messages, err := store.UnackedMessages() + requireNoError(t, err) + if containsMessage(messages, "msg-final-1") { + t.Fatalf("acked final still unacked: %#v", messages) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + +func TestClientHelloAckDiscardDoesNotEraseRunningAttempt(t *testing.T) { + store := newClientTestStore(t) + requireNoError(t, store.RecordStarted("task-running", "attempt-running")) + conn := newFakeConn() + dialer := &fakeDialer{conn: conn} + client := newTestClient(t, dialer, WithReceiptStore(store), WithRecoveryStore(store)) + + runCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + done := make(chan error, 1) + go func() { done <- client.Start(runCtx) }() + + hello := conn.waitForWrite(t) + conn.readCh <- helloAckEnvelopeWithDirectives(hello.MessageID, wsprotocol.HelloAckPayload{ + DiscardedAttempts: []wsprotocol.DiscardedAttempt{{TaskID: "task-running", ExecutionAttemptID: "attempt-running", Reason: "stale_attempt"}}, + }) + + waitFor(t, func() bool { return client.IsActive() }, "client active after discard directive") + state, err := store.TaskState("task-running", "attempt-running") + requireNoError(t, err) + if !state.Exists || state.Status != localtaskstore.TaskStatusRunning { + t.Fatalf("running state = %#v", state) + } + cancel() + if err := <-done; err != nil { + t.Fatalf("Start returned error: %v", err) + } +} + func TestClientRetryableErrorKeepsConnectionAndOutboxMessage(t *testing.T) { store := newClientTestStore(t) requireNoError(t, store.AppendOutputChunk(localtaskstore.OutputChunk{ @@ -597,6 +744,10 @@ func WithReceiptStore(store localtaskstore.ReceiptStore) clientOption { return func(cfg *Config) { cfg.ReceiptStore = store } } +func WithRecoveryStore(store localtaskstore.RecoveryStore) clientOption { + return func(cfg *Config) { cfg.RecoveryStore = store } +} + func WithTaskEnqueuer(enqueuer TaskEnqueuer) clientOption { return func(cfg *Config) { cfg.TaskEnqueuer = enqueuer } } @@ -740,6 +891,28 @@ func helloAckEnvelope(ackedMessageID string) wsprotocol.Envelope { } } +func helloAckEnvelopeWithDirectives(ackedMessageID string, directives wsprotocol.HelloAckPayload) wsprotocol.Envelope { + directives.AckedMessageID = ackedMessageID + directives.AckedType = wsprotocol.TypeAgentHello + if directives.AcknowledgedFinalMessageIDs == nil { + directives.AcknowledgedFinalMessageIDs = []string{} + } + if directives.DiscardedAttempts == nil { + directives.DiscardedAttempts = []wsprotocol.DiscardedAttempt{} + } + if directives.OutputReplay == nil { + directives.OutputReplay = []wsprotocol.OutputReplayDirective{} + } + return wsprotocol.Envelope{ + ProtocolVersion: wsprotocol.ProtocolVersion, + MessageID: "msg_hello_ack", + Type: wsprotocol.TypeAgentHelloAck, + AgentID: "agent_ws_test", + SentAt: time.Now().UTC().Format(time.RFC3339), + Payload: payloadMapForTest(directives), + } +} + func deliverEnvelope(messageID, taskID, attemptID, command string, priority int) wsprotocol.Envelope { return wsprotocol.Envelope{ ProtocolVersion: wsprotocol.ProtocolVersion, @@ -809,6 +982,15 @@ func intValuePtr(value int) *int { return &value } +func containsMessage(messages []localtaskstore.OutboxMessage, messageID string) bool { + for _, message := range messages { + if message.MessageID == messageID { + return true + } + } + return false +} + func saveTestPrivateKey(t *testing.T, dir string) string { t.Helper() privateKey, err := rsa.GenerateKey(rand.Reader, 2048) diff --git a/internal/wsprotocol/message.go b/internal/wsprotocol/message.go index c1878bc..db1c2b4 100644 --- a/internal/wsprotocol/message.go +++ b/internal/wsprotocol/message.go @@ -70,6 +70,74 @@ type TaskDeliverPayload struct { Priority int `json:"priority"` } +type HelloPayload struct { + RunningTask *RunningTaskSnapshot `json:"running_task"` + ReceivedNotStarted []ReceivedNotStartedAttempt `json:"received_not_started"` + UnackedFinals []UnackedFinalSnapshot `json:"unacked_finals"` + UnackedOutput []UnackedOutputRange `json:"unacked_output"` + SpoolStatus SpoolStatus `json:"spool_status"` + ClientVersion string `json:"client_version"` +} + +type RunningTaskSnapshot struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + StartedAt string `json:"started_at"` + LastOutputSequence map[string]int `json:"last_output_sequence"` +} + +type ReceivedNotStartedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + ReceivedAt string `json:"received_at"` +} + +type UnackedFinalSnapshot struct { + MessageID string `json:"message_id"` + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Status FinalStatus `json:"status"` + ExitCode int `json:"exit_code"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` +} + +type UnackedOutputRange struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + FirstSequence int `json:"first_sequence"` + LastSequence int `json:"last_sequence"` + TruncatedLocally bool `json:"truncated_locally"` +} + +type SpoolStatus struct { + BytesUsed int64 `json:"bytes_used"` + ByteCap int64 `json:"byte_cap"` + HasRotatedChunks bool `json:"has_rotated_chunks"` +} + +type HelloAckPayload struct { + AckedMessageID string `json:"acked_message_id"` + AckedType MessageType `json:"acked_type"` + AcknowledgedFinalMessageIDs []string `json:"acknowledged_final_message_ids"` + DiscardedAttempts []DiscardedAttempt `json:"discarded_attempts"` + OutputReplay []OutputReplayDirective `json:"output_replay"` +} + +type DiscardedAttempt struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Reason string `json:"reason"` +} + +type OutputReplayDirective struct { + TaskID string `json:"task_id"` + ExecutionAttemptID string `json:"execution_attempt_id"` + Stream Stream `json:"stream"` + NextSequence int `json:"next_sequence"` +} + func (e Envelope) Validate(authenticatedAgentID string) error { if e.ProtocolVersion != ProtocolVersion { return fmt.Errorf("unsupported protocol_version: %d", e.ProtocolVersion) @@ -132,6 +200,49 @@ func (p TaskDeliverPayload) Validate() error { return nil } +func (p HelloPayload) Validate() error { + if p.ClientVersion == "" { + return fmt.Errorf("client_version is required") + } + if p.RunningTask != nil { + if p.RunningTask.TaskID == "" { + return fmt.Errorf("running_task.task_id is required") + } + if p.RunningTask.ExecutionAttemptID == "" { + return fmt.Errorf("running_task.execution_attempt_id is required") + } + if p.RunningTask.LastOutputSequence == nil { + return fmt.Errorf("running_task.last_output_sequence is required") + } + } + for _, attempt := range p.ReceivedNotStarted { + if attempt.TaskID == "" || attempt.ExecutionAttemptID == "" { + return fmt.Errorf("received_not_started entries require task_id and execution_attempt_id") + } + } + for _, final := range p.UnackedFinals { + if final.MessageID == "" || final.TaskID == "" || final.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_finals entries require message_id, task_id, and execution_attempt_id") + } + } + for _, output := range p.UnackedOutput { + if output.TaskID == "" || output.ExecutionAttemptID == "" { + return fmt.Errorf("unacked_output entries require task_id and execution_attempt_id") + } + if output.Stream != StreamStdout && output.Stream != StreamStderr { + return fmt.Errorf("unacked_output.stream must be stdout or stderr") + } + if output.FirstSequence <= 0 || output.LastSequence < output.FirstSequence { + return fmt.Errorf("unacked_output sequence range is invalid") + } + } + return nil +} + +func (p HelloAckPayload) HasReconciliationDirectives() bool { + return p.AcknowledgedFinalMessageIDs != nil || p.DiscardedAttempts != nil || p.OutputReplay != nil +} + func DecodePayload[T any](e Envelope) (T, error) { var payload T diff --git a/internal/wsprotocol/message_test.go b/internal/wsprotocol/message_test.go index 543a3b2..6b4a337 100644 --- a/internal/wsprotocol/message_test.go +++ b/internal/wsprotocol/message_test.go @@ -189,6 +189,69 @@ func TestPayloadValidation(t *testing.T) { }) } +func TestHelloPayloadUsesReconnectSnapshotShape(t *testing.T) { + payload := HelloPayload{ + RunningTask: &RunningTaskSnapshot{ + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + StartedAt: "2026-04-28T00:00:00Z", + LastOutputSequence: map[string]int{"stdout": 2, "stderr": 1}, + }, + ReceivedNotStarted: []ReceivedNotStartedAttempt{{TaskID: "task-2", ExecutionAttemptID: "attempt-2", ReceivedAt: "2026-04-28T00:00:01Z"}}, + UnackedFinals: []UnackedFinalSnapshot{{ + MessageID: "msg-final-1", + TaskID: "task-3", + ExecutionAttemptID: "attempt-3", + Status: FinalStatusCompleted, + }}, + UnackedOutput: []UnackedOutputRange{{TaskID: "task-1", ExecutionAttemptID: "attempt-1", Stream: StreamStdout, FirstSequence: 2, LastSequence: 4}}, + SpoolStatus: SpoolStatus{BytesUsed: 128, ByteCap: 1024, HasRotatedChunks: true}, + ClientVersion: "test-version", + } + + if err := payload.Validate(); err != nil { + t.Fatalf("expected hello payload to validate, got %v", err) + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal hello payload: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal hello payload: %v", err) + } + if !reflect.DeepEqual(sortedKeys(decoded), []string{"client_version", "received_not_started", "running_task", "spool_status", "unacked_finals", "unacked_output"}) { + t.Fatalf("hello payload fields = %v", sortedKeys(decoded)) + } + lastOutput, ok := decoded["running_task"].(map[string]any)["last_output_sequence"].(map[string]any) + if !ok || lastOutput["stdout"] != float64(2) || lastOutput["stderr"] != float64(1) { + t.Fatalf("last_output_sequence = %#v", decoded["running_task"]) + } +} + +func TestHelloAckPayloadUsesReconciliationDirectiveShape(t *testing.T) { + payload := HelloAckPayload{ + AckedMessageID: "msg-hello", + AckedType: TypeAgentHello, + AcknowledgedFinalMessageIDs: []string{"msg-final-1"}, + DiscardedAttempts: []DiscardedAttempt{{TaskID: "task-2", ExecutionAttemptID: "attempt-2", Reason: "stale_attempt"}}, + OutputReplay: []OutputReplayDirective{{TaskID: "task-1", ExecutionAttemptID: "attempt-1", Stream: StreamStderr, NextSequence: 14}}, + } + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal hello ack: %v", err) + } + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal hello ack: %v", err) + } + if !reflect.DeepEqual(sortedKeys(decoded), []string{"acked_message_id", "acked_type", "acknowledged_final_message_ids", "discarded_attempts", "output_replay"}) { + t.Fatalf("hello ack fields = %v", sortedKeys(decoded)) + } +} + func intPtr(value int) *int { return &value } diff --git a/main.go b/main.go index 039fe9c..2e061ae 100644 --- a/main.go +++ b/main.go @@ -352,6 +352,7 @@ func newDefaultWebSocketRuntime(localStore *localtaskstore.Store, enqueuer wscli PingInterval: appconf.WebSocketPingInterval(), ResultOutbox: localStore, ReceiptStore: localStore, + RecoveryStore: localStore, TaskEnqueuer: enqueuer, }) }