From 396026dfb83a77b2959a9fd8791bb615eb578a90 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Fri, 27 Mar 2026 13:56:53 +0000 Subject: [PATCH 1/3] Add stream status trace messages for Airbyte 2.x protocol compliance Emit STARTED status when beginning to read each stream and COMPLETE/INCOMPLETE terminal status when finishing. This is required by Airbyte protocol v0.2.0+ and fixes the error: "streams did not receive a terminal stream status message" Also update Dockerfile to use standard golang/alpine base images instead of pscale.dev private images. --- Dockerfile | 5 +++-- cmd/airbyte-source/read.go | 18 +++++++++++++---- cmd/internal/logger.go | 21 ++++++++++++++++++++ cmd/internal/mock_types.go | 4 ++++ cmd/internal/types.go | 40 ++++++++++++++++++++++++++++++++------ 5 files changed, 76 insertions(+), 12 deletions(-) diff --git a/Dockerfile b/Dockerfile index 2954db2..d34eb0f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ # syntax=docker/dockerfile:1 ARG GO_VERSION=1.22.2 -FROM pscale.dev/wolfi-prod/go:${GO_VERSION} AS build +FROM golang:${GO_VERSION} AS build WORKDIR /airbyte-source COPY . . @@ -9,8 +9,9 @@ COPY . . RUN go mod download RUN CGO_ENABLED=0 go build -ldflags="-s -w" -trimpath -o /connect -FROM pscale.dev/wolfi-prod/base:latest +FROM alpine:latest +RUN apk add --no-cache ca-certificates COPY --from=build /connect /usr/local/bin/ ENV AIRBYTE_ENTRYPOINT "/usr/local/bin/connect" ENTRYPOINT ["/usr/local/bin/connect"] diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index 09b8031..a3bc0ed 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -109,9 +109,13 @@ func ReadCommand(ch *Helper) *cobra.Command { streamState, ok := syncState.Streams[streamStateKey] if !ok { ch.Logger.Error(fmt.Sprintf("Unable to read state for stream %v", streamStateKey)) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) os.Exit(1) } + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_STARTED) + + streamFailed := false for shardName, shardState := range streamState.Shards { var tc *psdbconnectv1alpha1.TableCursor @@ -119,22 +123,28 @@ func ReadCommand(ch *Helper) *cobra.Command { ch.Logger.Log(internal.LOGLEVEL_INFO, fmt.Sprintf("Using serialized cursor for stream %s", streamStateKey)) if err != nil { ch.Logger.Error(fmt.Sprintf("Invalid serialized cursor for stream %v, failed with [%v]", streamStateKey, err)) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } sc, err := ch.Database.Read(ctx, cmd.OutOrStdout(), psc, configuredStream, tc) if err != nil { ch.Logger.Error(err.Error()) - os.Exit(1) + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_INCOMPLETE) + streamFailed = true + break } if sc != nil { - // if we get any new state, we assign it here. - // otherwise, the older state is round-tripped back to Airbyte. syncState.Streams[streamStateKey].Shards[shardName] = sc } ch.Logger.State(syncState) } + + if !streamFailed { + ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) + } } }, } diff --git a/cmd/internal/logger.go b/cmd/internal/logger.go index 0ada8c0..c894600 100644 --- a/cmd/internal/logger.go +++ b/cmd/internal/logger.go @@ -16,6 +16,7 @@ type AirbyteLogger interface { Flush() State(syncState SyncState) Error(error string) + StreamStatus(namespace, streamName, status string) } const MaxBatchSize = 10000 @@ -103,6 +104,26 @@ func (a *airbyteLogger) Error(error string) { } } +func (a *airbyteLogger) StreamStatus(namespace, streamName, status string) { + now := time.Now() + if err := a.recordEncoder.Encode(AirbyteMessage{ + Type: TRACE, + Trace: &AirbyteTraceMessage{ + Type: TRACE_TYPE_STREAM_STATUS, + EmittedAt: float64(now.UnixMilli()), + StreamStatus: &AirbyteStreamStatus{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + Status: status, + }, + }, + }); err != nil { + a.Error(fmt.Sprintf("stream status encoding error: %v", err)) + } +} + func (a *airbyteLogger) ConnectionStatus(status ConnectionStatus) { if err := a.recordEncoder.Encode(AirbyteMessage{ Type: CONNECTION_STATUS, diff --git a/cmd/internal/mock_types.go b/cmd/internal/mock_types.go index 742b822..bfa3530 100644 --- a/cmd/internal/mock_types.go +++ b/cmd/internal/mock_types.go @@ -60,6 +60,10 @@ func (testAirbyteLogger) Error(error string) { panic("implement me") } +func (testAirbyteLogger) StreamStatus(namespace, streamName, status string) { + // no-op for tests +} + type vstreamClientMock struct { vstreamFn func(ctx context.Context, in *vtgate.VStreamRequest, opts ...grpc.CallOption) (vtgateservice.Vitess_VStreamClient, error) vstreamFnInvoked bool diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 19b5f64..604ac93 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -21,6 +21,17 @@ const ( LOG = "LOG" CONNECTION_STATUS = "CONNECTION_STATUS" CATALOG = "CATALOG" + TRACE = "TRACE" +) + +const ( + TRACE_TYPE_STREAM_STATUS = "STREAM_STATUS" +) + +const ( + STREAM_STATUS_STARTED = "STARTED" + STREAM_STATUS_COMPLETE = "COMPLETE" + STREAM_STATUS_INCOMPLETE = "INCOMPLETE" ) const ( @@ -389,13 +400,30 @@ type AirbyteState struct { Data SyncState `json:"data"` } +type StreamDescriptor struct { + Name string `json:"name"` + Namespace string `json:"namespace"` +} + +type AirbyteStreamStatus struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + Status string `json:"status"` +} + +type AirbyteTraceMessage struct { + Type string `json:"type"` + EmittedAt float64 `json:"emitted_at"` + StreamStatus *AirbyteStreamStatus `json:"stream_status,omitempty"` +} + type AirbyteMessage struct { - Type string `json:"type"` - Log *AirbyteLogMessage `json:"log,omitempty"` - ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` - Catalog *Catalog `json:"catalog,omitempty"` - Record *AirbyteRecord `json:"record,omitempty"` - State *AirbyteState `json:"state,omitempty"` + Type string `json:"type"` + Log *AirbyteLogMessage `json:"log,omitempty"` + ConnectionStatus *ConnectionStatus `json:"connectionStatus,omitempty"` + Catalog *Catalog `json:"catalog,omitempty"` + Record *AirbyteRecord `json:"record,omitempty"` + State *AirbyteState `json:"state,omitempty"` + Trace *AirbyteTraceMessage `json:"trace,omitempty"` } // A map of starting GTIDs for every keyspace and shard From eefcb529aa1ee6e69d6e100160eae3676252bbda Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Mon, 30 Mar 2026 11:20:09 +0100 Subject: [PATCH 2/3] Replace LEGACY global state with per-stream STREAM state for Airbyte 2.x Airbyte 2.x rejects LEGACY state messages with: IllegalArgumentException: LEGACY states are deprecated This changes the connector to emit one STATE message per stream with type=STREAM and a stream_descriptor, instead of a single global state blob. Each stream's shard cursors are emitted individually after the stream finishes processing. --- cmd/airbyte-source/read.go | 2 +- cmd/internal/logger.go | 17 +++++++++++++---- cmd/internal/mock_types.go | 5 ++--- cmd/internal/types.go | 12 +++++++++++- 4 files changed, 27 insertions(+), 9 deletions(-) diff --git a/cmd/airbyte-source/read.go b/cmd/airbyte-source/read.go index a3bc0ed..5e32e4f 100644 --- a/cmd/airbyte-source/read.go +++ b/cmd/airbyte-source/read.go @@ -139,10 +139,10 @@ func ReadCommand(ch *Helper) *cobra.Command { if sc != nil { syncState.Streams[streamStateKey].Shards[shardName] = sc } - ch.Logger.State(syncState) } if !streamFailed { + ch.Logger.StreamState(keyspaceOrDatabase, configuredStream.Stream.Name, syncState.Streams[streamStateKey]) ch.Logger.StreamStatus(keyspaceOrDatabase, configuredStream.Stream.Name, internal.STREAM_STATUS_COMPLETE) } } diff --git a/cmd/internal/logger.go b/cmd/internal/logger.go index c894600..15aa50a 100644 --- a/cmd/internal/logger.go +++ b/cmd/internal/logger.go @@ -14,7 +14,7 @@ type AirbyteLogger interface { ConnectionStatus(status ConnectionStatus) Record(tableNamespace, tableName string, data map[string]interface{}) Flush() - State(syncState SyncState) + StreamState(namespace, streamName string, shardStates ShardStates) Error(error string) StreamStatus(namespace, streamName, status string) } @@ -83,10 +83,19 @@ func (a *airbyteLogger) Flush() { a.records = a.records[:0] } -func (a *airbyteLogger) State(syncState SyncState) { +func (a *airbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { if err := a.recordEncoder.Encode(AirbyteMessage{ - Type: STATE, - State: &AirbyteState{syncState}, + Type: STATE, + State: &AirbyteState{ + Type: STATE_TYPE_STREAM, + Stream: &AirbyteStreamState{ + StreamDescriptor: StreamDescriptor{ + Name: streamName, + Namespace: namespace, + }, + StreamState: &shardStates, + }, + }, }); err != nil { a.Error(fmt.Sprintf("state encoding error: %v", err)) } diff --git a/cmd/internal/mock_types.go b/cmd/internal/mock_types.go index bfa3530..2e7dc04 100644 --- a/cmd/internal/mock_types.go +++ b/cmd/internal/mock_types.go @@ -50,9 +50,8 @@ func (tal *testAirbyteLogger) Record(tableNamespace, tableName string, data map[ func (testAirbyteLogger) Flush() { } -func (testAirbyteLogger) State(syncState SyncState) { - // TODO implement me - panic("implement me") +func (testAirbyteLogger) StreamState(namespace, streamName string, shardStates ShardStates) { + // no-op for tests } func (testAirbyteLogger) Error(error string) { diff --git a/cmd/internal/types.go b/cmd/internal/types.go index 604ac93..a24a099 100644 --- a/cmd/internal/types.go +++ b/cmd/internal/types.go @@ -396,8 +396,18 @@ func mapEnumValue(value sqltypes.Value, values []string) sqltypes.Value { return value } +const ( + STATE_TYPE_STREAM = "STREAM" +) + +type AirbyteStreamState struct { + StreamDescriptor StreamDescriptor `json:"stream_descriptor"` + StreamState *ShardStates `json:"stream_state"` +} + type AirbyteState struct { - Data SyncState `json:"data"` + Type string `json:"type"` + Stream *AirbyteStreamState `json:"stream,omitempty"` } type StreamDescriptor struct { From e49ed7b6c2f9cf051b1a95ef98b075df67526a68 Mon Sep 17 00:00:00 2001 From: Harry Smaje Date: Mon, 30 Mar 2026 11:27:21 +0100 Subject: [PATCH 3/3] Add tests for per-stream state and stream status protocol compliance MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Logger tests (cmd/internal/logger_test.go): - StreamState emits correct per-stream format with type=STREAM - StreamState handles multiple shards - No LEGACY "data" field present in state output - StreamStatus emits TRACE messages with correct status values - JSON round-trip produces exact Airbyte protocol v2 structure - Multiple stream states are independent Read protocol tests (cmd/airbyte-source/read_protocol_test.go): - Read emits per-stream STATE (not LEGACY global state) - STARTED and COMPLETE emitted for each stream - State descriptors have correct stream name and namespace - Message ordering: STARTED → STATE → COMPLETE - Multi-shard state contains all shard cursors - Read errors emit INCOMPLETE (not COMPLETE) and skip state emission --- cmd/airbyte-source/read_protocol_test.go | 373 +++++++++++++++++++++++ cmd/internal/logger_test.go | 194 ++++++++++++ 2 files changed, 567 insertions(+) create mode 100644 cmd/airbyte-source/read_protocol_test.go create mode 100644 cmd/internal/logger_test.go diff --git a/cmd/airbyte-source/read_protocol_test.go b/cmd/airbyte-source/read_protocol_test.go new file mode 100644 index 0000000..3f9efce --- /dev/null +++ b/cmd/airbyte-source/read_protocol_test.go @@ -0,0 +1,373 @@ +package airbyte_source + +import ( + "bytes" + "context" + "encoding/json" + "io" + "os" + "testing" + + "github.com/planetscale/airbyte-source/cmd/internal" + psdbconnect "github.com/planetscale/airbyte-source/proto/psdbconnect/v1alpha1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockDatabase implements internal.PlanetScaleDatabase for read protocol tests. +type mockDatabase struct { + shards []string + readFunc func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) + readCalls int +} + +func (m *mockDatabase) CanConnect(ctx context.Context, ps internal.PlanetScaleSource) error { + return nil +} + +func (m *mockDatabase) DiscoverSchema(ctx context.Context, ps internal.PlanetScaleSource) (internal.Catalog, error) { + return internal.Catalog{}, nil +} + +func (m *mockDatabase) ListShards(ctx context.Context, ps internal.PlanetScaleSource) ([]string, error) { + return m.shards, nil +} + +func (m *mockDatabase) Read(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + m.readCalls++ + if m.readFunc != nil { + return m.readFunc(ctx, w, ps, s, tc) + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/updated-position", + }) + return newCursor, nil +} + +func (m *mockDatabase) Close() error { + return nil +} + +func newTestConfig() []byte { + return []byte(`{"host":"test.psdb.cloud","database":"testdb","username":"user","password":"pass"}`) +} + +func newTestCatalog(t *testing.T, streams ...string) string { + t.Helper() + catalog := internal.ConfiguredCatalog{} + for _, name := range streams { + catalog.Streams = append(catalog.Streams, internal.ConfiguredStream{ + Stream: internal.Stream{ + Name: name, + Namespace: "testdb", + }, + SyncMode: "full_refresh", + }) + } + b, err := json.Marshal(catalog) + require.NoError(t, err) + return string(b) +} + +func writeTempFile(t *testing.T, content []byte) string { + t.Helper() + f, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + _, err = f.Write(content) + require.NoError(t, err) + require.NoError(t, f.Close()) + return f.Name() +} + +func parseOutputMessages(t *testing.T, buf *bytes.Buffer) []internal.AirbyteMessage { + t.Helper() + var messages []internal.AirbyteMessage + decoder := json.NewDecoder(buf) + for decoder.More() { + var msg internal.AirbyteMessage + if err := decoder.Decode(&msg); err != nil { + break + } + messages = append(messages, msg) + } + return messages +} + +func setupReadCommand(t *testing.T, db *mockDatabase, catalogJSON string) (*bytes.Buffer, *Helper) { + t.Helper() + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + return b, h +} + +func TestRead_EmitsPerStreamStateNotLegacy(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "users") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMessages []internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMessages = append(stateMessages, msg) + } + } + + require.NotEmpty(t, stateMessages, "should emit at least one STATE message") + + for _, msg := range stateMessages { + assert.Equal(t, internal.STATE_TYPE_STREAM, msg.State.Type, + "state.type must be STREAM, not LEGACY") + require.NotNil(t, msg.State.Stream, + "state.stream must be present") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Name, + "stream_descriptor.name must be set") + assert.NotEmpty(t, msg.State.Stream.StreamDescriptor.Namespace, + "stream_descriptor.namespace must be set") + require.NotNil(t, msg.State.Stream.StreamState, + "stream_state must be present") + } +} + +func TestRead_EmitsStartedAndCompletePerStream(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "orders", "products") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + type streamStatusEntry struct { + Name string + Status string + } + var statuses []streamStatusEntry + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.Type == internal.TRACE_TYPE_STREAM_STATUS && + msg.Trace.StreamStatus != nil { + statuses = append(statuses, streamStatusEntry{ + Name: msg.Trace.StreamStatus.StreamDescriptor.Name, + Status: msg.Trace.StreamStatus.Status, + }) + } + } + + expectedStatuses := []streamStatusEntry{ + {"orders", internal.STREAM_STATUS_STARTED}, + {"orders", internal.STREAM_STATUS_COMPLETE}, + {"products", internal.STREAM_STATUS_STARTED}, + {"products", internal.STREAM_STATUS_COMPLETE}, + } + assert.Equal(t, expectedStatuses, statuses) +} + +func TestRead_StatePerStreamContainsCorrectDescriptor(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "accounts", "sessions") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + statesByStream := map[string]internal.AirbyteMessage{} + for _, msg := range messages { + if msg.Type == internal.STATE { + name := msg.State.Stream.StreamDescriptor.Name + statesByStream[name] = msg + } + } + + assert.Contains(t, statesByStream, "accounts") + assert.Contains(t, statesByStream, "sessions") + assert.Equal(t, "testdb", statesByStream["accounts"].State.Stream.StreamDescriptor.Namespace) + assert.Equal(t, "testdb", statesByStream["sessions"].State.Stream.StreamDescriptor.Namespace) +} + +func TestRead_StateEmittedAfterStartedBeforeComplete(t *testing.T) { + db := &mockDatabase{shards: []string{"-"}} + catalogJSON := newTestCatalog(t, "events") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + startedIdx := -1 + stateIdx := -1 + completeIdx := -1 + + for i, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil && + msg.Trace.StreamStatus.StreamDescriptor.Name == "events" { + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_STARTED { + startedIdx = i + } + if msg.Trace.StreamStatus.Status == internal.STREAM_STATUS_COMPLETE { + completeIdx = i + } + } + if msg.Type == internal.STATE && msg.State != nil && + msg.State.Stream != nil && + msg.State.Stream.StreamDescriptor.Name == "events" { + stateIdx = i + } + } + + require.Greater(t, startedIdx, -1, "STARTED should be emitted") + require.Greater(t, stateIdx, -1, "STATE should be emitted") + require.Greater(t, completeIdx, -1, "COMPLETE should be emitted") + + assert.Less(t, startedIdx, stateIdx, "STARTED should come before STATE") + assert.Less(t, stateIdx, completeIdx, "STATE should come before COMPLETE") +} + +func TestRead_MultiShardStateContainsAllShards(t *testing.T) { + db := &mockDatabase{shards: []string{"-80", "80-"}} + catalogJSON := newTestCatalog(t, "data") + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, []byte(catalogJSON)) + + b, h := setupReadCommand(t, db, catalogJSON) + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + var stateMsg *internal.AirbyteMessage + for _, msg := range messages { + if msg.Type == internal.STATE { + stateMsg = &msg + } + } + + require.NotNil(t, stateMsg, "should have a STATE message") + require.NotNil(t, stateMsg.State.Stream.StreamState) + assert.Len(t, stateMsg.State.Stream.StreamState.Shards, 2, + "state should contain both shards") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "-80") + assert.Contains(t, stateMsg.State.Stream.StreamState.Shards, "80-") +} + +func TestRead_ReadErrorEmitsIncompleteNotComplete(t *testing.T) { + db := &mockDatabase{ + shards: []string{"-"}, + readFunc: func(ctx context.Context, w io.Writer, ps internal.PlanetScaleSource, s internal.ConfiguredStream, tc *psdbconnect.TableCursor) (*internal.SerializedCursor, error) { + if s.Stream.Name == "bad_table" { + return nil, assert.AnError + } + newCursor, _ := internal.TableCursorToSerializedCursor(&psdbconnect.TableCursor{ + Shard: tc.Shard, + Keyspace: tc.Keyspace, + Position: "MySQL56/pos", + }) + return newCursor, nil + }, + } + + catalog := internal.ConfiguredCatalog{ + Streams: []internal.ConfiguredStream{ + { + Stream: internal.Stream{Name: "good_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + { + Stream: internal.Stream{Name: "bad_table", Namespace: "testdb"}, + SyncMode: "full_refresh", + }, + }, + } + catalogBytes, _ := json.Marshal(catalog) + + configFile := writeTempFile(t, newTestConfig()) + catalogFile := writeTempFile(t, catalogBytes) + + b := bytes.NewBufferString("") + h := &Helper{ + Database: db, + FileReader: testFileReader{content: newTestConfig()}, + Logger: internal.NewLogger(b), + } + + cmd := ReadCommand(h) + cmd.SetOut(b) + require.NoError(t, cmd.Flag("config").Value.Set(configFile)) + require.NoError(t, cmd.Flag("catalog").Value.Set(catalogFile)) + require.NoError(t, cmd.Execute()) + + messages := parseOutputMessages(t, b) + + streamStatuses := map[string][]string{} + for _, msg := range messages { + if msg.Type == internal.TRACE && msg.Trace != nil && + msg.Trace.StreamStatus != nil { + name := msg.Trace.StreamStatus.StreamDescriptor.Name + streamStatuses[name] = append(streamStatuses[name], msg.Trace.StreamStatus.Status) + } + } + + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_COMPLETE}, + streamStatuses["good_table"]) + assert.Equal(t, []string{internal.STREAM_STATUS_STARTED, internal.STREAM_STATUS_INCOMPLETE}, + streamStatuses["bad_table"]) + + // good_table should have a STATE message, bad_table should NOT + hasGoodState := false + hasBadState := false + for _, msg := range messages { + if msg.Type == internal.STATE && msg.State != nil && msg.State.Stream != nil { + if msg.State.Stream.StreamDescriptor.Name == "good_table" { + hasGoodState = true + } + if msg.State.Stream.StreamDescriptor.Name == "bad_table" { + hasBadState = true + } + } + } + assert.True(t, hasGoodState, "good_table should have a STATE message") + assert.False(t, hasBadState, "bad_table should NOT have a STATE message (it failed)") +} diff --git a/cmd/internal/logger_test.go b/cmd/internal/logger_test.go new file mode 100644 index 0000000..751655e --- /dev/null +++ b/cmd/internal/logger_test.go @@ -0,0 +1,194 @@ +package internal + +import ( + "bytes" + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestStreamState_EmitsPerStreamFormat(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc123"}, + }, + } + + logger.StreamState("my-database", "users", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE, msg.Type) + require.NotNil(t, msg.State) + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + require.NotNil(t, msg.State.Stream) + assert.Equal(t, "users", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "my-database", msg.State.Stream.StreamDescriptor.Namespace) + require.NotNil(t, msg.State.Stream.StreamState) + assert.Equal(t, "abc123", msg.State.Stream.StreamState.Shards["-"].Cursor) +} + +func TestStreamState_MultipleShards(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-80": {Cursor: "cursor1"}, + "80-": {Cursor: "cursor2"}, + }, + } + + logger.StreamState("sharded-db", "orders", shardStates) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STATE_TYPE_STREAM, msg.State.Type) + assert.Equal(t, "orders", msg.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "sharded-db", msg.State.Stream.StreamDescriptor.Namespace) + assert.Len(t, msg.State.Stream.StreamState.Shards, 2) + assert.Equal(t, "cursor1", msg.State.Stream.StreamState.Shards["-80"].Cursor) + assert.Equal(t, "cursor2", msg.State.Stream.StreamState.Shards["80-"].Cursor) +} + +func TestStreamState_NoLegacyDataField(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + shardStates := ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "abc"}, + }, + } + + logger.StreamState("db", "table1", shardStates) + + // Parse as raw JSON to verify no "data" key exists (which would indicate LEGACY format) + var raw map[string]json.RawMessage + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + var stateRaw map[string]json.RawMessage + err = json.Unmarshal(raw["state"], &stateRaw) + require.NoError(t, err) + + _, hasData := stateRaw["data"] + assert.False(t, hasData, "state should not contain 'data' field (LEGACY format)") + + _, hasType := stateRaw["type"] + assert.True(t, hasType, "state must contain 'type' field") + + _, hasStream := stateRaw["stream"] + assert.True(t, hasStream, "state must contain 'stream' field") +} + +func TestStreamStatus_EmitsTraceMessage(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("my-db", "accounts", STREAM_STATUS_STARTED) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, TRACE, msg.Type) + require.NotNil(t, msg.Trace) + assert.Equal(t, TRACE_TYPE_STREAM_STATUS, msg.Trace.Type) + assert.True(t, msg.Trace.EmittedAt > 0) + require.NotNil(t, msg.Trace.StreamStatus) + assert.Equal(t, "accounts", msg.Trace.StreamStatus.StreamDescriptor.Name) + assert.Equal(t, "my-db", msg.Trace.StreamStatus.StreamDescriptor.Namespace) + assert.Equal(t, STREAM_STATUS_STARTED, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Complete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_COMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_COMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamStatus_Incomplete(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamStatus("ns", "tbl", STREAM_STATUS_INCOMPLETE) + + var msg AirbyteMessage + err := json.NewDecoder(b).Decode(&msg) + require.NoError(t, err) + + assert.Equal(t, STREAM_STATUS_INCOMPLETE, msg.Trace.StreamStatus.Status) +} + +func TestStreamState_JSONRoundTrip(t *testing.T) { + // Verify the JSON output can be parsed back into the exact expected Airbyte protocol structure + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("anam-lab", "persona", ShardStates{ + Shards: map[string]*SerializedCursor{ + "-": {Cursor: "encoded-cursor-data"}, + }, + }) + + // Parse into a generic structure to verify exact JSON shape + var raw map[string]interface{} + err := json.NewDecoder(b).Decode(&raw) + require.NoError(t, err) + + assert.Equal(t, "STATE", raw["type"]) + + state := raw["state"].(map[string]interface{}) + assert.Equal(t, "STREAM", state["type"]) + + stream := state["stream"].(map[string]interface{}) + descriptor := stream["stream_descriptor"].(map[string]interface{}) + assert.Equal(t, "persona", descriptor["name"]) + assert.Equal(t, "anam-lab", descriptor["namespace"]) + + streamState := stream["stream_state"].(map[string]interface{}) + shards := streamState["shards"].(map[string]interface{}) + shard := shards["-"].(map[string]interface{}) + assert.Equal(t, "encoded-cursor-data", shard["cursor"]) +} + +func TestMultipleStreamStates_EachIndependent(t *testing.T) { + b := bytes.NewBufferString("") + logger := NewLogger(b) + + logger.StreamState("db", "table1", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c1"}}, + }) + logger.StreamState("db", "table2", ShardStates{ + Shards: map[string]*SerializedCursor{"-": {Cursor: "c2"}}, + }) + + decoder := json.NewDecoder(b) + + var msg1 AirbyteMessage + require.NoError(t, decoder.Decode(&msg1)) + assert.Equal(t, "table1", msg1.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c1", msg1.State.Stream.StreamState.Shards["-"].Cursor) + + var msg2 AirbyteMessage + require.NoError(t, decoder.Decode(&msg2)) + assert.Equal(t, "table2", msg2.State.Stream.StreamDescriptor.Name) + assert.Equal(t, "c2", msg2.State.Stream.StreamState.Shards["-"].Cursor) +}