Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 122 additions & 0 deletions internal/wsprotocol/ack_sequence.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package wsprotocol

type AckPayload struct {
AckedMessageID string `json:"acked_message_id"`
AckedType MessageType `json:"acked_type"`
TaskID string `json:"task_id,omitempty"`
ExecutionAttemptID string `json:"execution_attempt_id,omitempty"`
HighestOutputSequence *int `json:"highest_output_sequence,omitempty"`
}

type AckOptions struct {
AckedMessageID string
AckedType MessageType
TaskID string
ExecutionAttemptID string
HighestOutputSequence *int
}

type ErrorPayload struct {
Code string `json:"code"`
Message string `json:"message"`
Retryable bool `json:"retryable"`
RelatedMessageID string `json:"related_message_id,omitempty"`
}

type ErrorOptions struct {
Code string
Message string
Retryable bool
RelatedMessageID string
}

type MessageRecordResult int

const (
MessageAccepted MessageRecordResult = iota
MessageDuplicate
)

type MessageTracker struct {
processed map[string]struct{}
}

type SequenceStatus int

const (
SequenceAccepted SequenceStatus = iota
SequenceDuplicate
SequenceGap
)

type SequenceResult struct {
Status SequenceStatus
HighestOutputSequence int
ExpectedSequence int
}

type OutputSequenceTracker struct {
highest map[sequenceKey]int
}

type sequenceKey struct {
executionAttemptID string
stream Stream
}

func BuildAck(opts AckOptions) AckPayload {
return AckPayload{
AckedMessageID: opts.AckedMessageID,
AckedType: opts.AckedType,
TaskID: opts.TaskID,
ExecutionAttemptID: opts.ExecutionAttemptID,
HighestOutputSequence: opts.HighestOutputSequence,
}
}

func BuildError(opts ErrorOptions) ErrorPayload {
return ErrorPayload{
Code: opts.Code,
Message: opts.Message,
Retryable: opts.Retryable,
RelatedMessageID: opts.RelatedMessageID,
}
}

func NewMessageTracker() *MessageTracker {
return &MessageTracker{processed: make(map[string]struct{})}
}

func (t *MessageTracker) Record(messageID string) MessageRecordResult {
if _, ok := t.processed[messageID]; ok {
return MessageDuplicate
}

t.processed[messageID] = struct{}{}
return MessageAccepted
}

func NewOutputSequenceTracker() *OutputSequenceTracker {
return &OutputSequenceTracker{highest: make(map[sequenceKey]int)}
}

func (t *OutputSequenceTracker) Record(executionAttemptID string, stream Stream, sequence int) SequenceResult {
key := sequenceKey{executionAttemptID: executionAttemptID, stream: stream}
highest := t.highest[key]
expected := highest + 1

if sequence == expected {
t.highest[key] = sequence
return SequenceResult{Status: SequenceAccepted, HighestOutputSequence: sequence, ExpectedSequence: sequence + 1}
}

if sequence <= highest {
return SequenceResult{Status: SequenceDuplicate, HighestOutputSequence: highest, ExpectedSequence: expected}
}

return SequenceResult{Status: SequenceGap, HighestOutputSequence: highest, ExpectedSequence: expected}
}

func (r SequenceResult) Retryable() bool {
return r.Status == SequenceGap
}
141 changes: 141 additions & 0 deletions internal/wsprotocol/ack_sequence_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package wsprotocol

import (
"encoding/json"
"reflect"
"testing"
)

func TestAcknowledgementPayload(t *testing.T) {
payload := BuildAck(AckOptions{
AckedMessageID: "msg_123",
AckedType: TypeTaskOutput,
TaskID: "tsk_123",
ExecutionAttemptID: "attempt_123",
HighestOutputSequence: intPtr(12),
})

if payload.AckedMessageID != "msg_123" {
t.Errorf("acked_message_id = %q, want msg_123", payload.AckedMessageID)
}
if payload.AckedType != TypeTaskOutput {
t.Errorf("acked_type = %q, want %q", payload.AckedType, TypeTaskOutput)
}
if payload.HighestOutputSequence == nil || *payload.HighestOutputSequence != 12 {
t.Fatalf("highest_output_sequence = %v, want 12", payload.HighestOutputSequence)
}
}

func TestSampleAckPayloadUsesCanonicalFieldNames(t *testing.T) {
payload := BuildAck(AckOptions{
AckedMessageID: "msg_123",
AckedType: TypeTaskOutput,
TaskID: "tsk_123",
ExecutionAttemptID: "attempt_123",
HighestOutputSequence: intPtr(12),
})

data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal ack: %v", err)
}

var decoded map[string]any
if err := json.Unmarshal(data, &decoded); err != nil {
t.Fatalf("unmarshal ack: %v", err)
}

if !reflect.DeepEqual(sortedKeys(decoded), []string{"acked_message_id", "acked_type", "execution_attempt_id", "highest_output_sequence", "task_id"}) {
t.Errorf("ack fields = %v", sortedKeys(decoded))
}
}

func TestErrorPayload(t *testing.T) {
payload := BuildError(ErrorOptions{
Code: "output_sequence_gap",
Message: "expected sequence 8",
Retryable: true,
RelatedMessageID: "msg_123",
})

if payload.Code != "output_sequence_gap" {
t.Errorf("code = %q, want output_sequence_gap", payload.Code)
}
if !payload.Retryable {
t.Error("expected retryable error")
}
if payload.RelatedMessageID != "msg_123" {
t.Errorf("related_message_id = %q, want msg_123", payload.RelatedMessageID)
}
}

func TestMessageTracker(t *testing.T) {
tracker := NewMessageTracker()

if result := tracker.Record("msg_123"); result != MessageAccepted {
t.Errorf("first record = %v, want %v", result, MessageAccepted)
}
if result := tracker.Record("msg_123"); result != MessageDuplicate {
t.Errorf("second record = %v, want %v", result, MessageDuplicate)
}
}

func TestOutputSequenceTracker(t *testing.T) {
t.Run("accepts contiguous sequence", func(t *testing.T) {
tracker := NewOutputSequenceTracker()

result := tracker.Record("attempt_123", StreamStdout, 1)

if result.Status != SequenceAccepted {
t.Errorf("status = %v, want %v", result.Status, SequenceAccepted)
}
if result.HighestOutputSequence != 1 {
t.Errorf("highest = %d, want 1", result.HighestOutputSequence)
}
})

t.Run("re-acks duplicate sequence", func(t *testing.T) {
tracker := NewOutputSequenceTracker()
tracker.Record("attempt_123", StreamStdout, 1)

result := tracker.Record("attempt_123", StreamStdout, 1)

if result.Status != SequenceDuplicate {
t.Errorf("status = %v, want %v", result.Status, SequenceDuplicate)
}
if result.HighestOutputSequence != 1 {
t.Errorf("highest = %d, want 1", result.HighestOutputSequence)
}
})

t.Run("rejects future gaps as retryable", func(t *testing.T) {
tracker := NewOutputSequenceTracker()
tracker.Record("attempt_123", StreamStdout, 1)

result := tracker.Record("attempt_123", StreamStdout, 3)

if result.Status != SequenceGap {
t.Errorf("status = %v, want %v", result.Status, SequenceGap)
}
if result.HighestOutputSequence != 1 {
t.Errorf("highest = %d, want 1", result.HighestOutputSequence)
}
if result.ExpectedSequence != 2 {
t.Errorf("expected = %d, want 2", result.ExpectedSequence)
}
if !result.Retryable() {
t.Error("expected gap to be retryable")
}
})

t.Run("tracks streams independently", func(t *testing.T) {
tracker := NewOutputSequenceTracker()
tracker.Record("attempt_123", StreamStdout, 1)

result := tracker.Record("attempt_123", StreamStderr, 1)

if result.Status != SequenceAccepted {
t.Errorf("status = %v, want %v", result.Status, SequenceAccepted)
}
})
}
Loading
Loading