From a5649221d53e06714c9e05626e33eb91e3d080d2 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:27 +0100 Subject: [PATCH 1/3] Add stream status trace messages for Airbyte protocol v2 Airbyte 2.x requires sources to emit STREAM_STATUS trace messages (STARTED, COMPLETE, INCOMPLETE) for each stream. Without these, every sync fails with: "streams did not receive a terminal stream status message" Changes: - Add TRACE message type and stream status constants to types.go - Add StreamDescriptor, AirbyteStreamStatus, AirbyteTraceMessage types - Replace legacy global State() with per-stream StreamState() that emits state.type=STREAM (required by Airbyte 2.x, which rejects the LEGACY format with IllegalArgumentException) - Add StreamStatus() method to emit STARTED/COMPLETE/INCOMPLETE traces - Update AirbyteLogger interface and test mock accordingly --- cmd/internal/logger.go | 38 +++++++++++++++++++++++++--- cmd/internal/mock_types.go | 9 ++++--- cmd/internal/types.go | 52 +++++++++++++++++++++++++++++++++----- 3 files changed, 85 insertions(+), 14 deletions(-) diff --git a/cmd/internal/logger.go b/cmd/internal/logger.go index 0ada8c0..15aa50a 100644 --- a/cmd/internal/logger.go +++ b/cmd/internal/logger.go @@ -14,8 +14,9 @@ type AirbyteLogger interface { ConnectionStatus(status ConnectionStatus) Record(tableNamespace, tableName string, data map[string]interface{}) Flush() - State(syncState SyncState) + StreamState(namespace, streamName string, shardStates ShardStates) Error(error string) + StreamStatus(namespace, streamName, status string) } const MaxBatchSize = 10000 @@ -82,10 +83,19 @@ func (a *airbyteLogger) Flush() { a.records = a.records[:0] } -func (a *airbyteLogger) State(syncState SyncState) { +func (a *airbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { if err := a.recordEncoder.Encode(AirbyteMessage{ - Type: STATE, - State: &AirbyteState{syncState}, + Type: STATE, + State: &AirbyteState{ + Type: STATE_TYPE_STREAM, + Stream: &AirbyteStreamState{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + StreamState: &shardStates, + }, + }, }); err != nil { a.Error(fmt.Sprintf("state encoding error: %v", err)) } @@ -103,6 +113,26 @@ func (a *airbyteLogger) Error(error string) { } } +func (a *airbyteLogger) StreamStatus(namespace, streamName, status string) { + now := time.Now() + if err := a.recordEncoder.Encode(AirbyteMessage{ + Type: TRACE, + Trace: &AirbyteTraceMessage{ + Type: TRACE_TYPE_STREAM_STATUS, + EmittedAt: float64(now.UnixMilli()), + StreamStatus: &AirbyteStreamStatus{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + Status: status, + }, + }, + }); err != nil { + a.Error(fmt.Sprintf("stream status encoding error: %v", err)) + } +} + func (a *airbyteLogger) ConnectionStatus(status ConnectionStatus) { if err := a.recordEncoder.Encode(AirbyteMessage{ Type: CONNECTION_STATUS, diff --git a/cmd/internal/mock_types.go b/cmd/internal/mock_types.go index 742b822..2e7dc04 100644 --- a/cmd/internal/mock_types.go +++ b/cmd/internal/mock_types.go @@ -50,9 +50,8 @@ func (tal *testAirbyteLogger) Record(tableNamespace, tableName string, data map[ func (testAirbyteLogger) Flush() { } -func (testAirbyteLogger) State(syncState SyncState) { - // TODO implement me - panic("implement me") +func (testAirbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { + // no-op for tests } func (testAirbyteLogger) Error(error string) { @@ -60,6 +59,10 @@ func (testAirbyteLogger) Error(error string) { panic("implement me") } +func (testAirbyteLogger) StreamStatus(namespace, streamName, status string) { + // no-op for tests +} + type vstreamClientMock struct { vstreamFn func(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (vtgateservice.Vitess_VStreamClient, error) vstreamFnInvoked bool diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 19b5f64..a24a099 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -21,6 +21,17 @@ const ( LOG = "LOG" CONNECTION_STATUS = "CONNECTION_STATUS" CATALOG = "CATALOG" + TRACE = "TRACE" +) + +const ( + TRACE_TYPE_STREAM_STATUS = "STREAM_STATUS" +) + +const ( + STREAM_STATUS_STARTED = "STARTED" + STREAM_STATUS_COMPLETE = "COMPLETE" + STREAM_STATUS_INCOMPLETE = "INCOMPLETE" ) const ( @@ -385,17 +396,44 @@ func mapEnumValue(value sqltypes.Value, values []string) sqltypes.Value { return value } +const ( + STATE_TYPE_STREAM = "STREAM" +) + +type AirbyteStreamState struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + StreamState *ShardStates `json:"stream_state"` +} + type AirbyteState struct { - Data SyncState `json:"data"` + Type string `json:"type"` + Stream *AirbyteStreamState `json:"stream,omitempty"` +} + +type StreamDescriptor struct { + Name string `json:"name"` + Namespace string `json:"namespace"` +} + +type AirbyteStreamStatus struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + Status string `json:"status"` +} + +type AirbyteTraceMessage struct { + Type string `json:"type"` + EmittedAt float64 `json:"emitted_at"` + StreamStatus *AirbyteStreamStatus `json:"stream_status,omitempty"` } type AirbyteMessage struct { - Type string `json:"type"` - Log *AirbyteLogMessage `json:"log,omitempty"` - ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` - Catalog *Catalog `json:"catalog,omitempty"` - Record *AirbyteRecord `json:"record,omitempty"` - State *AirbyteState `json:"state,omitempty"` + Type string `json:"type"` + Log *AirbyteLogMessage `json:"log,omitempty"` + ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` + Catalog *Catalog `json:"catalog,omitempty"` + Record *AirbyteRecord `json:"record,omitempty"` + State *AirbyteState `json:"state,omitempty"` + Trace *AirbyteTraceMessage `json:"trace,omitempty"` } // A map of starting GTIDs for every keyspace and shard From 03b5915f46d2656b49d6fdeffa5168b78079709b Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:39 +0100 Subject: [PATCH 2/3] Emit per-stream status and state in read loop, handle v2 state input Update the read command to be fully compatible with Airbyte 2.x: Read loop changes: - Emit STARTED before reading each stream - Emit COMPLETE after successful read, INCOMPLETE on error - Replace os.Exit(1) with break on per-stream errors so remaining streams still get status messages - Emit per-stream STATE (type=STREAM) after each stream completes instead of one global state blob at the end State parsing changes: - Handle Airbyte v2 per-stream state format on incremental syncs. Airbyte 2.x passes state back as a JSON array of per-stream state objects, not the legacy global SyncState blob. Without this, the second sync always fails because json.Unmarshal fails on the array format, causing os.Exit(1) before any streams are processed. - Fall back to legacy format for backwards compatibility - Default empty namespace to source database name to prevent state key mismatches --- cmd/airbyte-source/read.go | 43 +++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 8 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index 09b8031..c56e04b 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -109,9 +109,13 @@ func ReadCommand(ch *Helper) *cobra.Command { streamState, ok := syncState.Streams[streamStateKey] if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) os.Exit(1) } + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) + + streamFailed := false for shardName, shardState := range streamState.Shards { var tc *psdbconnectv1alpha1.TableCursor @@ -119,21 +123,27 @@ func ReadCommand(ch *Helper) *cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } if sc != nil { - // if we get any new state, we assign it here. - // otherwise, the older state is round-tripped back to Airbyte. syncState.Streams[streamStateKey].Shards[shardName] = sc } - ch.Logger.State(syncState) + } + + if !streamFailed { + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } }, @@ -153,9 +163,26 @@ func readState(state string, psc internal.PlanetScaleSource, streams []internal. Streams: map[string]internal.ShardStates{}, } if state != "" { - err := json.Unmarshal([]byte(state), &syncState) - if err != nil { - return syncState, err + // Try parsing as Airbyte v2 per-stream state array first + var perStreamStates []internal.AirbyteState + if err := json.Unmarshal([]byte(state), &perStreamStates); err == nil && len(perStreamStates) > 0 && perStreamStates[0].Type == internal.STATE_TYPE_STREAM { + logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Parsing Airbyte v2 per-stream state (%d streams)", len(perStreamStates))) + for _, s := range perStreamStates { + if s.Stream != nil && s.Stream.StreamState != nil { + ns := s.Stream.StreamDescriptor.Namespace + if ns == "" { + ns = psc.Database + } + key := ns + ":" + s.Stream.StreamDescriptor.Name + syncState.Streams[key] = *s.Stream.StreamState + } + } + } else { + // Fall back to legacy global state format + err := json.Unmarshal([]byte(state), &syncState) + if err != nil { + return syncState, err + } } } From de00c59bff8021c5d14a202d5bdd0329206742ce Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Tue, 31 Mar 2026 10:57:47 +0100 Subject: [PATCH 3/3] Add tests for Airbyte protocol v2 compliance Logger tests: - StreamState emits correct per-stream format with type=STREAM - Multiple shards included in state output - No legacy "data" field present (would cause LEGACY rejection) - StreamStatus emits TRACE messages with correct status values - JSON round-trip matches exact Airbyte protocol v2 structure Read protocol tests: - Read emits per-stream STATE, not legacy global state - STARTED and COMPLETE emitted for each configured stream - Correct message ordering: STARTED -> STATE -> COMPLETE - Multi-shard state contains all shard cursors - Read errors emit INCOMPLETE and skip state emission --- cmd/airbyte-source/read_protocol_test.go | 373 +++++++++++++++++++++++ cmd/internal/logger_test.go | 194 ++++++++++++ 2 files changed, 567 insertions(+) create mode 100644 cmd/airbyte-source/read_protocol_test.go create mode 100644 cmd/internal/logger_test.go diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go new file mode 100644 index 0000000..3f9efce --- /dev/null +++ b/cmd/airbyte-source/read_protocol_test.go @@ -0,0 +1,373 @@ +package airbyte_source + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "testing" + + "github.com/planetscale/airbyte-source/cmd/internal" + psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDatabase implements internal.PlanetScaleDatabase for read protocol tests. +type mockDatabase struct { + shards []string + readFunc func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) + readCalls int +} + +func (m *mockDatabase) CanConnect(ctx context.Context, ps internal.PlanetScaleSource) error { + return nil +} + +func (m *mockDatabase) DiscoverSchema(ctx context.Context, ps internal.PlanetScaleSource) (internal.Catalog, error) { + return internal.Catalog{}, nil +} + +func (m *mockDatabase) ListShards(ctx context.Context, ps internal.PlanetScaleSource) ([]string, error) { + return m.shards, nil +} + +func (m *mockDatabase) Read(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + m.readCalls++ + if m.readFunc != nil { + return m.readFunc(ctx, w, ps, s, tc) + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/updated-position", + }) + return newCursor, nil +} + +func (m *mockDatabase) Close() error { + return nil +} + +func newTestConfig() []byte { + return []byte(`{"host":"test.psdb.cloud","database":"testdb","username":"user","password":"pass"}`) +} + +func newTestCatalog(t *testing.T, streams ...string) string { + t.Helper() + catalog := internal.ConfiguredCatalog{} + for _, name := range streams { + catalog.Streams = append(catalog.Streams, internal.ConfiguredStream{ + Stream: internal.Stream{ + Name: name, + Namespace: "testdb", + }, + SyncMode: "full_refresh", + }) + } + b, err := json.Marshal(catalog) + require.NoError(t, err) + return string(b) +} + +func writeTempFile(t *testing.T, content []byte) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() +} + +func parseOutputMessages(t *testing.T, buf *bytes.Buffer) []internal.AirbyteMessage { + t.Helper() + var messages []internal.AirbyteMessage + decoder := json.NewDecoder(buf) + for decoder.More() { + var msg internal.AirbyteMessage + if err := decoder.Decode(&msg); err != nil { + break + } + messages = append(messages, msg) + } + return messages +} + +func setupReadCommand(t *testing.T, db *mockDatabase, catalogJSON string) (*bytes.Buffer, *Helper) { + t.Helper() + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + return b, h +} + +func TestRead_EmitsPerStreamStateNotLegacy(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "users") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMessages []internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMessages = append(stateMessages, msg) + } + } + + require.NotEmpty(t, stateMessages, "should emit at least one STATE message") + + for _, msg := range stateMessages { + assert.Equal(t, internal.STATE_TYPE_STREAM, msg.State.Type, + "state.type must be STREAM, not LEGACY") + require.NotNil(t, msg.State.Stream, + "state.stream must be present") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Name, + "stream_descriptor.name must be set") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Namespace, + "stream_descriptor.namespace must be set") + require.NotNil(t, msg.State.Stream.StreamState, + "stream_state must be present") + } +} + +func TestRead_EmitsStartedAndCompletePerStream(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "orders", "products") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + type streamStatusEntry struct { + Name string + Status string + } + var statuses []streamStatusEntry + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.Type == internal.TRACE_TYPE_STREAM_STATUS && + msg.Trace.StreamStatus != nil { + statuses = append(statuses, streamStatusEntry{ + Name: msg.Trace.StreamStatus.StreamDescriptor.Name, + Status: msg.Trace.StreamStatus.Status, + }) + } + } + + expectedStatuses := []streamStatusEntry{ + {"orders", internal.STREAM_STATUS_STARTED}, + {"orders", internal.STREAM_STATUS_COMPLETE}, + {"products", internal.STREAM_STATUS_STARTED}, + {"products", internal.STREAM_STATUS_COMPLETE}, + } + assert.Equal(t, expectedStatuses, statuses) +} + +func TestRead_StatePerStreamContainsCorrectDescriptor(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "accounts", "sessions") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + statesByStream := map[string]internal.AirbyteMessage{} + for _, msg := range messages { + if msg.Type == internal.STATE { + name := msg.State.Stream.StreamDescriptor.Name + statesByStream[name] = msg + } + } + + assert.Contains(t, statesByStream, "accounts") + assert.Contains(t, statesByStream, "sessions") + assert.Equal(t, "testdb", statesByStream["accounts"].State.Stream.StreamDescriptor.Namespace) + assert.Equal(t, "testdb", statesByStream["sessions"].State.Stream.StreamDescriptor.Namespace) +} + +func TestRead_StateEmittedAfterStartedBeforeComplete(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + startedIdx := -1 + stateIdx := -1 + completeIdx := -1 + + for i, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_STARTED { + startedIdx = i + } + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_COMPLETE { + completeIdx = i + } + } + if msg.Type == internal.STATE && msg.State != nil && + msg.State.Stream != nil && + msg.State.Stream.StreamDescriptor.Name == "events" { + stateIdx = i + } + } + + require.Greater(t, startedIdx, -1, "STARTED should be emitted") + require.Greater(t, stateIdx, -1, "STATE should be emitted") + require.Greater(t, completeIdx, -1, "COMPLETE should be emitted") + + assert.Less(t, startedIdx, stateIdx, "STARTED should come before STATE") + assert.Less(t, stateIdx, completeIdx, "STATE should come before COMPLETE") +} + +func TestRead_MultiShardStateContainsAllShards(t *testing.T) { + db := &mockDatabase{shards: []string{"-80", "80-"}} + catalogJSON := newTestCatalog(t, "data") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMsg = &msg + } + } + + require.NotNil(t, stateMsg, "should have a STATE message") + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "-80") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "80-") +} + +func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + if s.Stream.Name == "bad_table" { + return nil, assert.AnError + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/pos", + }) + return newCursor, nil + }, + } + + catalog := internal.ConfiguredCatalog{ + Streams: []internal.ConfiguredStream{ + { + Stream: internal.Stream{Name: "good_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + { + Stream: internal.Stream{Name: "bad_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + }, + } + catalogBytes, _ := json.Marshal(catalog) + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, catalogBytes) + + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + streamStatuses := map[string][]string{} + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil { + name := msg.Trace.StreamStatus.StreamDescriptor.Name + streamStatuses[name] = append(streamStatuses[name], msg.Trace.StreamStatus.Status) + } + } + + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_COMPLETE}, + streamStatuses["good_table"]) + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, + streamStatuses["bad_table"]) + + // good_table should have a STATE message, bad_table should NOT + hasGoodState := false + hasBadState := false + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil && msg.State.Stream != nil { + if msg.State.Stream.StreamDescriptor.Name == "good_table" { + hasGoodState = true + } + if msg.State.Stream.StreamDescriptor.Name == "bad_table" { + hasBadState = true + } + } + } + assert.True(t, hasGoodState, "good_table should have a STATE message") + assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") +} diff --git a/cmd/internal/logger_test.go b/cmd/internal/logger_test.go new file mode 100644 index 0000000..751655e --- /dev/null +++ b/cmd/internal/logger_test.go @@ -0,0 +1,194 @@ +package internal + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamState_EmitsPerStreamFormat(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc123"}, + }, + } + + logger.StreamState("my-database", "users", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE, msg.Type) + require.NotNil(t, msg.State) + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + require.NotNil(t, msg.State.Stream) + assert.Equal(t, "users", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "my-database", msg.State.Stream.StreamDescriptor.Namespace) + require.NotNil(t, msg.State.Stream.StreamState) + assert.Equal(t, "abc123", msg.State.Stream.StreamState.Shards["-"].Cursor) +} + +func TestStreamState_MultipleShards(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-80": {Cursor: "cursor1"}, + "80-": {Cursor: "cursor2"}, + }, + } + + logger.StreamState("sharded-db", "orders", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + assert.Equal(t, "orders", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "sharded-db", msg.State.Stream.StreamDescriptor.Namespace) + assert.Len(t, msg.State.Stream.StreamState.Shards, 2) + assert.Equal(t, "cursor1", msg.State.Stream.StreamState.Shards["-80"].Cursor) + assert.Equal(t, "cursor2", msg.State.Stream.StreamState.Shards["80-"].Cursor) +} + +func TestStreamState_NoLegacyDataField(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc"}, + }, + } + + logger.StreamState("db", "table1", shardStates) + + // Parse as raw JSON to verify no "data" key exists (which would indicate LEGACY format) + var raw map[string]json.RawMessage + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + var stateRaw map[string]json.RawMessage + err = json.Unmarshal(raw["state"], &stateRaw) + require.NoError(t, err) + + _, hasData := stateRaw["data"] + assert.False(t, hasData, "state should not contain 'data' field (LEGACY format)") + + _, hasType := stateRaw["type"] + assert.True(t, hasType, "state must contain 'type' field") + + _, hasStream := stateRaw["stream"] + assert.True(t, hasStream, "state must contain 'stream' field") +} + +func TestStreamStatus_EmitsTraceMessage(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("my-db", "accounts", STREAM_STATUS_STARTED) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, TRACE, msg.Type) + require.NotNil(t, msg.Trace) + assert.Equal(t, TRACE_TYPE_STREAM_STATUS, msg.Trace.Type) + assert.True(t, msg.Trace.EmittedAt > 0) + require.NotNil(t, msg.Trace.StreamStatus) + assert.Equal(t, "accounts", msg.Trace.StreamStatus.StreamDescriptor.Name) + assert.Equal(t, "my-db", msg.Trace.StreamStatus.StreamDescriptor.Namespace) + assert.Equal(t, STREAM_STATUS_STARTED, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Complete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_COMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_COMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Incomplete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_INCOMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_INCOMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamState_JSONRoundTrip(t *testing.T) { + // Verify the JSON output can be parsed back into the exact expected Airbyte protocol structure + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("anam-lab", "persona", ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "encoded-cursor-data"}, + }, + }) + + // Parse into a generic structure to verify exact JSON shape + var raw map[string]interface{} + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + assert.Equal(t, "STATE", raw["type"]) + + state := raw["state"].(map[string]interface{}) + assert.Equal(t, "STREAM", state["type"]) + + stream := state["stream"].(map[string]interface{}) + descriptor := stream["stream_descriptor"].(map[string]interface{}) + assert.Equal(t, "persona", descriptor["name"]) + assert.Equal(t, "anam-lab", descriptor["namespace"]) + + streamState := stream["stream_state"].(map[string]interface{}) + shards := streamState["shards"].(map[string]interface{}) + shard := shards["-"].(map[string]interface{}) + assert.Equal(t, "encoded-cursor-data", shard["cursor"]) +} + +func TestMultipleStreamStates_EachIndependent(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("db", "table1", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c1"}}, + }) + logger.StreamState("db", "table2", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c2"}}, + }) + + decoder := json.NewDecoder(b) + + var msg1 AirbyteMessage + require.NoError(t, decoder.Decode(&msg1)) + assert.Equal(t, "table1", msg1.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c1", msg1.State.Stream.StreamState.Shards["-"].Cursor) + + var msg2 AirbyteMessage + require.NoError(t, decoder.Decode(&msg2)) + assert.Equal(t, "table2", msg2.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c2", msg2.State.Stream.StreamState.Shards["-"].Cursor) +}