From 3fe67892f173b8aceec6893d60bb6a8e96e5ca8a Mon Sep 17 00:00:00 2001 From: Mohammad Aziz Date: Mon, 27 Apr 2026 12:13:30 +0530 Subject: [PATCH] feat(wsprotocol): implement shared phase-1 WebSocket protocol structs and validation Go protocol package with: - Envelope, OutputPayload, FinalPayload structs and JSON tags - Envelope.Validate(), OutputPayload.Validate(), FinalPayload.Validate() - AckPayload/ErrorPayload builders and MessageTracker/OutputSequenceTracker for dedup and gap handling - DecodePayload generic helper for typed payload extraction - IsSupportedType, isTaskType, isExecutionType helpers Test: message_test.go and ack_sequence_test.go --- internal/wsprotocol/ack_sequence.go | 122 +++++++++++++ internal/wsprotocol/ack_sequence_test.go | 141 +++++++++++++++ internal/wsprotocol/message.go | 166 ++++++++++++++++++ internal/wsprotocol/message_test.go | 209 +++++++++++++++++++++++ 4 files changed, 638 insertions(+) create mode 100644 internal/wsprotocol/ack_sequence.go create mode 100644 internal/wsprotocol/ack_sequence_test.go create mode 100644 internal/wsprotocol/message.go create mode 100644 internal/wsprotocol/message_test.go diff --git a/internal/wsprotocol/ack_sequence.go b/internal/wsprotocol/ack_sequence.go new file mode 100644 index 0000000..0582f8d --- /dev/null +++ b/internal/wsprotocol/ack_sequence.go @@ -0,0 +1,122 @@ +package wsprotocol + +type AckPayload struct { + AckedMessageID string `json:"acked_message_id"` + AckedType MessageType `json:"acked_type"` + TaskID string `json:"task_id,omitempty"` + ExecutionAttemptID string `json:"execution_attempt_id,omitempty"` + HighestOutputSequence *int `json:"highest_output_sequence,omitempty"` +} + +type AckOptions struct { + AckedMessageID string + AckedType MessageType + TaskID string + ExecutionAttemptID string + HighestOutputSequence *int +} + +type ErrorPayload struct { + Code string `json:"code"` + Message string `json:"message"` + Retryable bool `json:"retryable"` + RelatedMessageID string `json:"related_message_id,omitempty"` +} + +type ErrorOptions struct { + Code string + Message string + Retryable bool + RelatedMessageID string +} + +type MessageRecordResult int + +const ( + MessageAccepted MessageRecordResult = iota + MessageDuplicate +) + +type MessageTracker struct { + processed map[string]struct{} +} + +type SequenceStatus int + +const ( + SequenceAccepted SequenceStatus = iota + SequenceDuplicate + SequenceGap +) + +type SequenceResult struct { + Status SequenceStatus + HighestOutputSequence int + ExpectedSequence int +} + +type OutputSequenceTracker struct { + highest map[sequenceKey]int +} + +type sequenceKey struct { + executionAttemptID string + stream Stream +} + +func BuildAck(opts AckOptions) AckPayload { + return AckPayload{ + AckedMessageID: opts.AckedMessageID, + AckedType: opts.AckedType, + TaskID: opts.TaskID, + ExecutionAttemptID: opts.ExecutionAttemptID, + HighestOutputSequence: opts.HighestOutputSequence, + } +} + +func BuildError(opts ErrorOptions) ErrorPayload { + return ErrorPayload{ + Code: opts.Code, + Message: opts.Message, + Retryable: opts.Retryable, + RelatedMessageID: opts.RelatedMessageID, + } +} + +func NewMessageTracker() *MessageTracker { + return &MessageTracker{processed: make(map[string]struct{})} +} + +func (t *MessageTracker) Record(messageID string) MessageRecordResult { + if _, ok := t.processed[messageID]; ok { + return MessageDuplicate + } + + t.processed[messageID] = struct{}{} + return MessageAccepted +} + +func NewOutputSequenceTracker() *OutputSequenceTracker { + return &OutputSequenceTracker{highest: make(map[sequenceKey]int)} +} + +func (t *OutputSequenceTracker) Record(executionAttemptID string, stream Stream, sequence int) SequenceResult { + key := sequenceKey{executionAttemptID: executionAttemptID, stream: stream} + highest := t.highest[key] + expected := highest + 1 + + if sequence == expected { + t.highest[key] = sequence + return SequenceResult{Status: SequenceAccepted, HighestOutputSequence: sequence, ExpectedSequence: sequence + 1} + } + + if sequence <= highest { + return SequenceResult{Status: SequenceDuplicate, HighestOutputSequence: highest, ExpectedSequence: expected} + } + + return SequenceResult{Status: SequenceGap, HighestOutputSequence: highest, ExpectedSequence: expected} +} + +func (r SequenceResult) Retryable() bool { + return r.Status == SequenceGap +} diff --git a/internal/wsprotocol/ack_sequence_test.go b/internal/wsprotocol/ack_sequence_test.go new file mode 100644 index 0000000..6b49ac9 --- /dev/null +++ b/internal/wsprotocol/ack_sequence_test.go @@ -0,0 +1,141 @@ +package wsprotocol + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestAcknowledgementPayload(t *testing.T) { + payload := BuildAck(AckOptions{ + AckedMessageID: "msg_123", + AckedType: TypeTaskOutput, + TaskID: "tsk_123", + ExecutionAttemptID: "attempt_123", + HighestOutputSequence: intPtr(12), + }) + + if payload.AckedMessageID != "msg_123" { + t.Errorf("acked_message_id = %q, want msg_123", payload.AckedMessageID) + } + if payload.AckedType != TypeTaskOutput { + t.Errorf("acked_type = %q, want %q", payload.AckedType, TypeTaskOutput) + } + if payload.HighestOutputSequence == nil || *payload.HighestOutputSequence != 12 { + t.Fatalf("highest_output_sequence = %v, want 12", payload.HighestOutputSequence) + } +} + +func TestSampleAckPayloadUsesCanonicalFieldNames(t *testing.T) { + payload := BuildAck(AckOptions{ + AckedMessageID: "msg_123", + AckedType: TypeTaskOutput, + TaskID: "tsk_123", + ExecutionAttemptID: "attempt_123", + HighestOutputSequence: intPtr(12), + }) + + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("marshal ack: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal ack: %v", err) + } + + if !reflect.DeepEqual(sortedKeys(decoded), []string{"acked_message_id", "acked_type", "execution_attempt_id", "highest_output_sequence", "task_id"}) { + t.Errorf("ack fields = %v", sortedKeys(decoded)) + } +} + +func TestErrorPayload(t *testing.T) { + payload := BuildError(ErrorOptions{ + Code: "output_sequence_gap", + Message: "expected sequence 8", + Retryable: true, + RelatedMessageID: "msg_123", + }) + + if payload.Code != "output_sequence_gap" { + t.Errorf("code = %q, want output_sequence_gap", payload.Code) + } + if !payload.Retryable { + t.Error("expected retryable error") + } + if payload.RelatedMessageID != "msg_123" { + t.Errorf("related_message_id = %q, want msg_123", payload.RelatedMessageID) + } +} + +func TestMessageTracker(t *testing.T) { + tracker := NewMessageTracker() + + if result := tracker.Record("msg_123"); result != MessageAccepted { + t.Errorf("first record = %v, want %v", result, MessageAccepted) + } + if result := tracker.Record("msg_123"); result != MessageDuplicate { + t.Errorf("second record = %v, want %v", result, MessageDuplicate) + } +} + +func TestOutputSequenceTracker(t *testing.T) { + t.Run("accepts contiguous sequence", func(t *testing.T) { + tracker := NewOutputSequenceTracker() + + result := tracker.Record("attempt_123", StreamStdout, 1) + + if result.Status != SequenceAccepted { + t.Errorf("status = %v, want %v", result.Status, SequenceAccepted) + } + if result.HighestOutputSequence != 1 { + t.Errorf("highest = %d, want 1", result.HighestOutputSequence) + } + }) + + t.Run("re-acks duplicate sequence", func(t *testing.T) { + tracker := NewOutputSequenceTracker() + tracker.Record("attempt_123", StreamStdout, 1) + + result := tracker.Record("attempt_123", StreamStdout, 1) + + if result.Status != SequenceDuplicate { + t.Errorf("status = %v, want %v", result.Status, SequenceDuplicate) + } + if result.HighestOutputSequence != 1 { + t.Errorf("highest = %d, want 1", result.HighestOutputSequence) + } + }) + + t.Run("rejects future gaps as retryable", func(t *testing.T) { + tracker := NewOutputSequenceTracker() + tracker.Record("attempt_123", StreamStdout, 1) + + result := tracker.Record("attempt_123", StreamStdout, 3) + + if result.Status != SequenceGap { + t.Errorf("status = %v, want %v", result.Status, SequenceGap) + } + if result.HighestOutputSequence != 1 { + t.Errorf("highest = %d, want 1", result.HighestOutputSequence) + } + if result.ExpectedSequence != 2 { + t.Errorf("expected = %d, want 2", result.ExpectedSequence) + } + if !result.Retryable() { + t.Error("expected gap to be retryable") + } + }) + + t.Run("tracks streams independently", func(t *testing.T) { + tracker := NewOutputSequenceTracker() + tracker.Record("attempt_123", StreamStdout, 1) + + result := tracker.Record("attempt_123", StreamStderr, 1) + + if result.Status != SequenceAccepted { + t.Errorf("status = %v, want %v", result.Status, SequenceAccepted) + } + }) +} diff --git a/internal/wsprotocol/message.go b/internal/wsprotocol/message.go new file mode 100644 index 0000000..80f375d --- /dev/null +++ b/internal/wsprotocol/message.go @@ -0,0 +1,166 @@ +package wsprotocol + +import ( + "encoding/json" + "fmt" +) + +const ProtocolVersion = 1 + +type MessageType string + +const ( + TypeAgentHello MessageType = "agent.hello" + TypeAgentHelloAck MessageType = "agent.hello_ack" + TypeTaskDeliver MessageType = "task.deliver" + TypeTaskReceived MessageType = "task.received" + TypeTaskStarted MessageType = "task.started" + TypeTaskLeaseHeartbeat MessageType = "task.lease_heartbeat" + TypeTaskOutput MessageType = "task.output" + TypeTaskFinal MessageType = "task.final" + TypeAck MessageType = "ack" + TypeError MessageType = "error" +) + +type Stream string + +const ( + StreamStdout Stream = "stdout" + StreamStderr Stream = "stderr" +) + +type FinalStatus string + +const ( + FinalStatusCompleted FinalStatus = "completed" + FinalStatusFailed FinalStatus = "failed" + FinalStatusInterrupted FinalStatus = "interrupted" +) + +type Envelope struct { + ProtocolVersion int `json:"protocol_version"` + MessageID string `json:"message_id"` + Type MessageType `json:"type"` + AgentID string `json:"agent_id"` + TaskID string `json:"task_id,omitempty"` + ExecutionAttemptID string `json:"execution_attempt_id,omitempty"` + Sequence *int `json:"sequence,omitempty"` + SentAt string `json:"sent_at"` + Payload map[string]any `json:"payload"` +} + +type OutputPayload struct { + Stream Stream `json:"stream"` + Data string `json:"data"` + ByteCount int `json:"byte_count"` + TruncatedLocally bool `json:"truncated_locally,omitempty"` +} + +type FinalPayload struct { + Status FinalStatus `json:"status"` + ExitCode int `json:"exit_code"` + Output string `json:"output,omitempty"` + Error string `json:"error,omitempty"` + OutputTruncated bool `json:"output_truncated"` + ErrorTruncated bool `json:"error_truncated"` +} + +func (e Envelope) Validate(authenticatedAgentID string) error { + if e.ProtocolVersion != ProtocolVersion { + return fmt.Errorf("unsupported protocol_version: %d", e.ProtocolVersion) + } + if e.MessageID == "" { + return fmt.Errorf("message_id is required") + } + if !IsSupportedType(e.Type) { + return fmt.Errorf("unsupported type: %s", e.Type) + } + if e.AgentID == "" { + return fmt.Errorf("agent_id is required") + } + if e.AgentID != authenticatedAgentID { + return fmt.Errorf("agent_id does not match authenticated agent") + } + if e.SentAt == "" { + return fmt.Errorf("sent_at is required") + } + if e.Payload == nil { + return fmt.Errorf("payload must be an object") + } + if isTaskType(e.Type) && e.TaskID == "" { + return fmt.Errorf("task_id is required for task messages") + } + if isExecutionType(e.Type) && e.ExecutionAttemptID == "" { + return fmt.Errorf("execution_attempt_id is required for execution messages") + } + if e.Type == TypeTaskOutput && e.Sequence == nil { + return fmt.Errorf("sequence is required for output messages") + } + + return nil +} + +func (p OutputPayload) Validate() error { + if p.Stream != StreamStdout && p.Stream != StreamStderr { + return fmt.Errorf("stream must be stdout or stderr") + } + if p.ByteCount < 0 { + return fmt.Errorf("byte_count must be non-negative") + } + + return nil +} + +func (p FinalPayload) Validate() error { + switch p.Status { + case FinalStatusCompleted, FinalStatusFailed, FinalStatusInterrupted: + return nil + default: + return fmt.Errorf("status must be completed, failed, or interrupted") + } +} + +func DecodePayload[T any](e Envelope) (T, error) { + var payload T + + data, err := json.Marshal(e.Payload) + if err != nil { + return payload, fmt.Errorf("marshal payload: %w", err) + } + if err := json.Unmarshal(data, &payload); err != nil { + return payload, fmt.Errorf("unmarshal payload: %w", err) + } + + return payload, nil +} + +func IsSupportedType(messageType MessageType) bool { + switch messageType { + case TypeAgentHello, + TypeAgentHelloAck, + TypeTaskDeliver, + TypeTaskReceived, + TypeTaskStarted, + TypeTaskLeaseHeartbeat, + TypeTaskOutput, + TypeTaskFinal, + TypeAck, + TypeError: + return true + default: + return false + } +} + +func isTaskType(messageType MessageType) bool { + return messageType == TypeTaskDeliver || + messageType == TypeTaskReceived || + messageType == TypeTaskStarted || + messageType == TypeTaskLeaseHeartbeat || + messageType == TypeTaskOutput || + messageType == TypeTaskFinal +} + +func isExecutionType(messageType MessageType) bool { + return isTaskType(messageType) +} diff --git a/internal/wsprotocol/message_test.go b/internal/wsprotocol/message_test.go new file mode 100644 index 0000000..543a3b2 --- /dev/null +++ b/internal/wsprotocol/message_test.go @@ -0,0 +1,209 @@ +package wsprotocol + +import ( + "encoding/json" + "reflect" + "testing" +) + +func validOutputEnvelope() Envelope { + return Envelope{ + ProtocolVersion: ProtocolVersion, + MessageID: "msg_123", + Type: TypeTaskOutput, + AgentID: "agt_123", + TaskID: "tsk_123", + ExecutionAttemptID: "attempt_123", + Sequence: intPtr(1), + SentAt: "2026-04-27T00:00:00Z", + Payload: map[string]any{ + "stream": "stdout", + "data": "hello\n", + "byte_count": float64(6), + }, + } +} + +func TestEnvelopeJSONRoundTrip(t *testing.T) { + original := validOutputEnvelope() + + data, err := json.Marshal(original) + if err != nil { + t.Fatalf("marshal envelope: %v", err) + } + + var decoded Envelope + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + + if decoded.ProtocolVersion != ProtocolVersion { + t.Errorf("protocol_version = %d, want %d", decoded.ProtocolVersion, ProtocolVersion) + } + if decoded.MessageID != original.MessageID { + t.Errorf("message_id = %q, want %q", decoded.MessageID, original.MessageID) + } + if decoded.Type != TypeTaskOutput { + t.Errorf("type = %q, want %q", decoded.Type, TypeTaskOutput) + } + if decoded.Sequence == nil || *decoded.Sequence != 1 { + t.Fatalf("sequence = %v, want 1", decoded.Sequence) + } + + payload, err := DecodePayload[OutputPayload](decoded) + if err != nil { + t.Fatalf("decode payload: %v", err) + } + if payload.Stream != StreamStdout { + t.Errorf("stream = %q, want %q", payload.Stream, StreamStdout) + } +} + +func TestSampleOutputEnvelopeUsesCanonicalFieldNames(t *testing.T) { + data, err := json.Marshal(validOutputEnvelope()) + if err != nil { + t.Fatalf("marshal envelope: %v", err) + } + + var decoded map[string]any + if err := json.Unmarshal(data, &decoded); err != nil { + t.Fatalf("unmarshal envelope: %v", err) + } + + for _, field := range []string{"protocol_version", "message_id", "type", "agent_id", "task_id", "execution_attempt_id", "sequence", "sent_at", "payload"} { + if _, ok := decoded[field]; !ok { + t.Errorf("expected field %q", field) + } + } + + payload, ok := decoded["payload"].(map[string]any) + if !ok { + t.Fatalf("payload is %T, want object", decoded["payload"]) + } + if !reflect.DeepEqual(sortedKeys(payload), []string{"byte_count", "data", "stream"}) { + t.Errorf("payload fields = %v", sortedKeys(payload)) + } +} + +func TestEnvelopeValidate(t *testing.T) { + t.Run("accepts complete task output envelope", func(t *testing.T) { + env := validOutputEnvelope() + + if err := env.Validate("agt_123"); err != nil { + t.Fatalf("expected envelope to validate, got %v", err) + } + }) + + t.Run("rejects unsupported protocol version", func(t *testing.T) { + env := validOutputEnvelope() + env.ProtocolVersion = 2 + + if err := env.Validate("agt_123"); err == nil { + t.Fatal("expected unsupported protocol version error") + } + }) + + t.Run("rejects mismatched agent", func(t *testing.T) { + env := validOutputEnvelope() + + if err := env.Validate("agt_other"); err == nil { + t.Fatal("expected mismatched agent error") + } + }) + + t.Run("rejects task messages without task ID", func(t *testing.T) { + env := validOutputEnvelope() + env.TaskID = "" + + if err := env.Validate("agt_123"); err == nil { + t.Fatal("expected missing task ID error") + } + }) + + t.Run("accepts empty payload for non-task message", func(t *testing.T) { + env := validOutputEnvelope() + env.Type = TypeAgentHello + env.TaskID = "" + env.ExecutionAttemptID = "" + env.Sequence = nil + env.Payload = map[string]any{} + + if err := env.Validate("agt_123"); err != nil { + t.Fatalf("expected empty payload to validate, got %v", err) + } + }) +} + +func TestAcceptedMessageTypes(t *testing.T) { + expected := []MessageType{ + TypeAgentHello, + TypeAgentHelloAck, + TypeTaskDeliver, + TypeTaskReceived, + TypeTaskStarted, + TypeTaskLeaseHeartbeat, + TypeTaskOutput, + TypeTaskFinal, + TypeAck, + TypeError, + } + + for _, messageType := range expected { + if !IsSupportedType(messageType) { + t.Errorf("expected %q to be supported", messageType) + } + } +} + +func TestPayloadValidation(t *testing.T) { + t.Run("valid output payload", func(t *testing.T) { + payload := OutputPayload{Stream: StreamStderr, Data: "oops", ByteCount: 4} + + if err := payload.Validate(); err != nil { + t.Fatalf("expected output payload to validate, got %v", err) + } + }) + + t.Run("invalid output stream", func(t *testing.T) { + payload := OutputPayload{Stream: "combined", Data: "oops", ByteCount: 4} + + if err := payload.Validate(); err == nil { + t.Fatal("expected invalid stream error") + } + }) + + t.Run("valid final payload", func(t *testing.T) { + payload := FinalPayload{Status: FinalStatusCompleted, ExitCode: 0} + + if err := payload.Validate(); err != nil { + t.Fatalf("expected final payload to validate, got %v", err) + } + }) + + t.Run("invalid final status", func(t *testing.T) { + payload := FinalPayload{Status: "timed_out", ExitCode: 1} + + if err := payload.Validate(); err == nil { + t.Fatal("expected invalid status error") + } + }) +} + +func intPtr(value int) *int { + return &value +} + +func sortedKeys(values map[string]any) []string { + keys := make([]string, 0, len(values)) + for key := range values { + keys = append(keys, key) + } + for i := 0; i < len(keys); i++ { + for j := i + 1; j < len(keys); j++ { + if keys[j] < keys[i] { + keys[i], keys[j] = keys[j], keys[i] + } + } + } + return keys +}