diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 3fe6a6dd3e..8033c5572f 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -281,6 +281,55 @@ func (a *FlowableActivity) CreateNormalizedTable( }, nil } +// SyncIndexesAndTriggers syncs indexes and triggers from source to destination +// This is called once during initial setup, not for on-the-fly changes, thats reason its added into flow +func (a *FlowableActivity) SyncIndexesAndTriggers( + ctx context.Context, + config *protos.SetupNormalizedTableBatchInput, +) error { + logger := internal.LoggerFromCtx(ctx) + ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) + + dstConn, dstClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName) + if err != nil { + if errors.Is(err, errors.ErrUnsupported) { + logger.Info("Connector does not support normalized tables, skipping index/trigger sync") + return nil + } + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get destination connector: %w", err)) + } + defer dstClose(ctx) + + pgDstConn, ok := dstConn.(*connpostgres.PostgresConnector) + if !ok { + logger.Info("Destination is not Postgres, skipping index/trigger sync") + return nil + } + if config.SourcePeerName == "" { + logger.Info("Source peer name not provided, skipping index/trigger sync") + return nil + } + + srcConn, srcClose, err := connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.SourcePeerName) + if err != nil { + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get source connector: %w", err)) + } + defer srcClose(ctx) + + pgSrcConn, ok := srcConn.(*connpostgres.PostgresConnector) + if !ok { + logger.Info("Source is not Postgres, skipping index/trigger sync") + return nil + } + a.Alerter.LogFlowInfo(ctx, config.FlowName, "Syncing indexes, triggers, and constraints from source to destination") + if err := pgDstConn.SyncIndexesAndTriggers(ctx, config.TableMappings, pgSrcConn); err != nil { + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to sync indexes, triggers, and constraints: %w", err)) + } + + a.Alerter.LogFlowInfo(ctx, config.FlowName, "Successfully synced indexes, triggers, and constraints") + return nil +} + func (a *FlowableActivity) SyncFlow( ctx context.Context, config *protos.FlowConnectionConfigsCore, diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 2da7e035a0..815af3ffb5 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -108,6 +108,70 @@ func (a *FlowableActivity) applySchemaDeltas( return nil } +func (a *FlowableActivity) updateDestinationSchemaMapping( + ctx context.Context, + config *protos.FlowConnectionConfigsCore, + options *protos.SyncFlowOptions, + schemaDeltas []*protos.TableSchemaDelta, +) error { + logger := internal.LoggerFromCtx(ctx) + + filteredTableMappings := make([]*protos.TableMapping, 0, len(schemaDeltas)) + for _, tableMapping := range options.TableMappings { + if slices.ContainsFunc(schemaDeltas, func(schemaDelta *protos.TableSchemaDelta) bool { + return schemaDelta.SrcTableName == tableMapping.SourceTableIdentifier && + schemaDelta.DstTableName == tableMapping.DestinationTableIdentifier + }) { + filteredTableMappings = append(filteredTableMappings, tableMapping) + } + } + + if len(filteredTableMappings) == 0 { + return nil + } + + dstConn, dstClose, err := 
connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.DestinationName) + if err != nil { + return fmt.Errorf("failed to get destination connector for schema update: %w", err) + } + defer dstClose(ctx) + + tableNameSchemaMapping, err := dstConn.GetTableSchema(ctx, config.Env, config.Version, config.System, filteredTableMappings) + if err != nil { + return fmt.Errorf("failed to get updated schema from destination: %w", err) + } + + processed := internal.BuildProcessedSchemaMapping(filteredTableMappings, tableNameSchemaMapping, logger) + + tx, err := a.CatalogPool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return fmt.Errorf("failed to start transaction for schema mapping update: %w", err) + } + defer shared.RollbackTx(tx, logger) + + for tableName, tableSchema := range processed { + processedBytes, err := proto.Marshal(tableSchema) + if err != nil { + return fmt.Errorf("failed to marshal table schema for %s: %w", tableName, err) + } + if _, err := tx.Exec( + ctx, + "insert into table_schema_mapping(flow_name, table_name, table_schema) values ($1, $2, $3) "+ + "on conflict (flow_name, table_name) do update set table_schema = $3", + config.FlowJobName, + tableName, + processedBytes, + ); err != nil { + return fmt.Errorf("failed to update schema mapping for %s: %w", tableName, err) + } + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("failed to commit schema mapping update: %w", err) + } + return nil +} + func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncConnectorCore, Items model.Items]( ctx context.Context, a *FlowableActivity, @@ -338,6 +402,13 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon return nil, err } + if len(res.TableSchemaDeltas) > 0 { + if err := a.updateDestinationSchemaMapping(ctx, config, options, res.TableSchemaDeltas); err != nil { + logger.Warn("Failed to update destination schema mapping, normalization may use stale schema", + slog.Any("error", err)) + } + } + if recordBatchSync.NeedsNormalize() { syncState.Store(shared.Ptr("normalizing")) normRequests.Update(res.CurrentSyncBatchID) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index f041814a43..ff6901ad58 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -684,7 +684,7 @@ func PullCdcRecords[Items model.Items]( logger.Debug("XLogData", slog.Any("WALStart", xld.WALStart), slog.Any("ServerWALEnd", xld.ServerWALEnd), slog.Time("ServerTime", xld.ServerTime)) - rec, err := processMessage(ctx, p, records, xld, clientXLogPos, processor) + rec, err := processMessage(ctx, p, req, records, xld, clientXLogPos, processor) if err != nil { return exceptions.NewPostgresLogicalMessageProcessingError(err) } @@ -697,6 +697,7 @@ func PullCdcRecords[Items model.Items]( fetchedBytes.Add(int64(len(msg.Data))) totalFetchedBytes.Add(int64(len(msg.Data))) tableName := rec.GetDestinationTableName() + switch r := rec.(type) { case *model.UpdateRecord[Items]: // tableName here is destination tableName. 
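As the comment on SyncIndexesAndTriggers notes, the activity is meant to run once during initial setup rather than from the sync loop. A minimal sketch of how a setup workflow could schedule it, assuming the caller already has a populated SetupNormalizedTableBatchInput with the new SourcePeerName field; the Temporal wiring, activity-name string, and timeout value here are illustrative and not part of this patch:

```go
package example

import (
	"time"

	"go.temporal.io/sdk/workflow"

	"github.com/PeerDB-io/peerdb/flow/generated/protos"
)

// syncIndexesAndTriggers schedules the new activity once, after the
// normalized tables have been created and before CDC starts.
func syncIndexesAndTriggers(ctx workflow.Context, in *protos.SetupNormalizedTableBatchInput) error {
	ctx = workflow.WithActivityOptions(ctx, workflow.ActivityOptions{
		// index builds and constraint validation can be slow on large tables
		StartToCloseTimeout: 15 * time.Minute,
	})
	// The activity returns nil when either peer is not Postgres or when
	// SourcePeerName is empty, so the call is safe to keep unconditional.
	return workflow.ExecuteActivity(ctx, "SyncIndexesAndTriggers", in).Get(ctx, nil)
}
```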
@@ -792,10 +793,17 @@ func PullCdcRecords[Items model.Items]( case *model.RelationRecord[Items]: tableSchemaDelta := r.TableSchemaDelta - if len(tableSchemaDelta.AddedColumns) > 0 { - logger.Info(fmt.Sprintf("Detected schema change for table %s, addedColumns: %v", - tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns)) + if len(tableSchemaDelta.AddedColumns) > 0 || len(tableSchemaDelta.DroppedColumns) > 0 { + addedColNames := make([]string, 0, len(tableSchemaDelta.AddedColumns)) + for _, col := range tableSchemaDelta.AddedColumns { + addedColNames = append(addedColNames, fmt.Sprintf("%s(%s)", col.Name, col.Type)) + } records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + logger.Info("Added schema delta to records stream", + slog.String("srcTableName", tableSchemaDelta.SrcTableName)) + } else { + logger.Debug("RelationRecord with no schema changes, skipping", + slog.String("srcTableName", tableSchemaDelta.SrcTableName)) } case *model.MessageRecord[Items]: @@ -847,9 +855,11 @@ func (p *PostgresCDCSource) baseRecord(lsn pglogrepl.LSN) model.BaseRecord { } } +// req *model.PullRecordsRequest[Items] because schema-change handling now needs access to the request’s table-schema cache func processMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, + req *model.PullRecordsRequest[Items], batch *model.CDCStream[Items], xld pglogrepl.XLogData, currentClientXlogPos pglogrepl.LSN, @@ -890,17 +900,20 @@ func processMessage[Items model.Items]( return nil, err } + currColNames := make([]string, 0, len(msg.Columns)) + for _, col := range msg.Columns { + currColNames = append(currColNames, col.Name) + } + if _, exists := p.srcTableIDNameMapping[msg.RelationID]; !exists { + logger.Warn("RelationMessage received for table not in replication set, skipping", + slog.Uint64("RelationID", uint64(msg.RelationID)), + slog.String("Namespace", msg.Namespace), + slog.String("RelationName", msg.RelationName)) return nil, nil } - logger.Debug("RelationMessage", - slog.Uint64("RelationID", uint64(msg.RelationID)), - slog.String("Namespace", msg.Namespace), - slog.String("RelationName", msg.RelationName), - slog.Any("Columns", msg.Columns)) - - return processRelationMessage[Items](ctx, p, currentClientXlogPos, msg) + return processRelationMessage[Items](ctx, p, req, currentClientXlogPos, msg) case *pglogrepl.LogicalDecodingMessage: logger.Debug("LogicalDecodingMessage", slog.Bool("Transactional", msg.Transactional), @@ -1074,6 +1087,7 @@ func processDeleteMessage[Items model.Items]( func processRelationMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, + req *model.PullRecordsRequest[Items], lsn pglogrepl.LSN, currRel *pglogrepl.RelationMessage, ) (model.Record[Items], error) { @@ -1091,7 +1105,11 @@ func processRelationMessage[Items model.Items]( p.logger.Info("processing RelationMessage", slog.Any("LSN", lsn), - slog.String("RelationName", currRelName)) + slog.String("RelationName", currRelName), + slog.Int("RelationID", int(currRel.RelationID)), + slog.Int("columnCount", len(currRel.Columns)), + slog.String("relationNamespace", currRel.Namespace), + slog.String("relationRelationName", currRel.RelationName)) // retrieve current TableSchema for table changed, mapping uses dst table name as key, need to translate source name currRelDstInfo, ok := p.tableNameMapping[currRelName] if !ok { @@ -1103,13 +1121,31 @@ func processRelationMessage[Items model.Items]( prevSchema, ok := p.tableNameSchemaMapping[currRelDstInfo.Name] if !ok { p.logger.Error("Detected relation 
message for table, but not in table schema mapping", - slog.String("tableName", currRelDstInfo.Name)) + slog.String("tableName", currRelDstInfo.Name), + slog.String("srcTableName", currRelName), + slog.Int("relationID", int(currRel.RelationID))) return nil, fmt.Errorf("cannot find table schema for %s", currRelDstInfo.Name) } + if prevSchema == nil { + p.logger.Error("Previous schema is nil for table", + slog.String("tableName", currRelDstInfo.Name)) + return nil, fmt.Errorf("previous schema is nil for %s", currRelDstInfo.Name) + } + + if len(prevSchema.Columns) == 0 { + p.logger.Warn("Previous schema has no columns, skipping schema comparison", + slog.String("tableName", currRelDstInfo.Name)) + // Still process to update the schema cache, but don't detect changes + p.relationMessageMapping[currRel.RelationID] = currRel + return nil, nil + } + prevRelMap := make(map[string]string, len(prevSchema.Columns)) + prevColNames := make([]string, 0, len(prevSchema.Columns)) for _, column := range prevSchema.Columns { prevRelMap[column.Name] = column.Type + prevColNames = append(prevColNames, column.Name) } currRelMap := make(map[string]string, len(currRel.Columns)) @@ -1139,9 +1175,20 @@ func processRelationMessage[Items model.Items]( SrcTableName: p.srcTableIDNameMapping[currRel.RelationID], DstTableName: p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Name, AddedColumns: nil, + DroppedColumns: nil, System: prevSchema.System, NullableEnabled: prevSchema.NullableEnabled, } + + // Log previous and current column names for debugging + currColNames := make([]string, 0, len(currRel.Columns)) + for _, col := range currRel.Columns { + currColNames = append(currColNames, col.Name) + } + p.logger.Debug("Comparing schemas for schema change detection", + slog.String("tableName", schemaDelta.SrcTableName), + slog.Any("previousColumns", prevColNames), + slog.Any("currentColumns", currColNames)) for _, column := range currRel.Columns { // not present in previous relation message, but in current one, so added. if _, ok := prevRelMap[column.Name]; !ok { @@ -1163,63 +1210,239 @@ func processRelationMessage[Items model.Items]( p.logger.Info("Detected added column", slog.String("columnName", column.Name), slog.String("columnType", currRelMap[column.Name]), - slog.String("relationName", schemaDelta.SrcTableName)) + slog.String("typeModifier", fmt.Sprintf("%d", column.TypeModifier)), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) } else { p.logger.Warn(fmt.Sprintf("Detected added column %s in table %s, but not propagating because excluded", column.Name, schemaDelta.SrcTableName)) } // present in previous and current relation messages, but data types have changed. - // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. 
+ // Detect and record type changes } else if prevRelMap[column.Name] != currRelMap[column.Name] { - p.logger.Warn(fmt.Sprintf("Detected column %s with type changed from %s to %s in table %s, but not propagating", - column.Name, prevRelMap[column.Name], currRelMap[column.Name], schemaDelta.SrcTableName)) + // Get default value for the column if it exists + var newDefaultValue string + for _, currCol := range currRel.Columns { + if currCol.Name == column.Name { + // Fetch default value from pg_catalog + rows, err := p.conn.Query(ctx, + `SELECT pg_get_expr(adbin, adrelid) as column_default + FROM pg_attribute + LEFT JOIN pg_attrdef ON pg_attribute.attrelid = pg_attrdef.adrelid + AND pg_attribute.attnum = pg_attrdef.adnum + WHERE attrelid=$1 AND attname=$2 AND attnum > 0 AND NOT attisdropped`, + currRel.RelationID, column.Name) + if err == nil { + var defaultVal pgtype.Text + if rows.Next() { + if err := rows.Scan(&defaultVal); err == nil && defaultVal.Valid { + newDefaultValue = defaultVal.String + } + } + rows.Close() + } + break + } + } + + typeChange := &protos.ColumnTypeChange{ + ColumnName: column.Name, + OldType: prevRelMap[column.Name], + NewType: currRelMap[column.Name], + NewDefaultValue: newDefaultValue, + } + schemaDelta.TypeChangedColumns = append(schemaDelta.TypeChangedColumns, typeChange) + p.logger.Info("Detected column type change", + slog.String("columnName", column.Name), + slog.String("oldType", prevRelMap[column.Name]), + slog.String("newType", currRelMap[column.Name]), + slog.String("defaultValue", newDefaultValue), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) } } for _, column := range prevSchema.Columns { // present in previous relation message, but not in current one, so dropped. 
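The type-change branch above and the added-column handling further below both read nullability and default expressions from pg_attribute/pg_attrdef. For reference, a self-contained version of that lookup for a single column; the helper name and signature are mine, not part of the patch:

```go
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgtype"
)

// columnDefaultAndNullable fetches the default expression and nullability of
// one column from pg_attribute/pg_attrdef, mirroring the query used above.
func columnDefaultAndNullable(
	ctx context.Context, conn *pgx.Conn, relID uint32, colName string,
) (defaultExpr string, nullable bool, err error) {
	var notNull bool
	var def pgtype.Text
	if err := conn.QueryRow(ctx,
		`SELECT attnotnull, pg_get_expr(adbin, adrelid)
		 FROM pg_attribute
		 LEFT JOIN pg_attrdef ON pg_attribute.attrelid = pg_attrdef.adrelid
		  AND pg_attribute.attnum = pg_attrdef.adnum
		 WHERE attrelid = $1 AND attname = $2 AND attnum > 0 AND NOT attisdropped`,
		relID, colName,
	).Scan(&notNull, &def); err != nil {
		return "", false, err // pgx.ErrNoRows if the column does not exist
	}
	if def.Valid {
		defaultExpr = def.String // e.g. "now()" or "'pending'::text"
	}
	return defaultExpr, !notNull, nil
}
```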
if _, ok := currRelMap[column.Name]; !ok { - p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating", column, - schemaDelta.SrcTableName)) + // Never drop PeerDB system columns - they are managed by PeerDB + // System columns start with _PEERDB_ prefix + if strings.HasPrefix(column.Name, "_PEERDB_") { + p.logger.Warn("Detected dropped PeerDB system column, skipping", + slog.String("columnName", column.Name), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.String("reason", "PeerDB system columns cannot be dropped")) + continue + } + + if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok { + schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name) + p.logger.Info("Detected dropped column", + slog.String("columnName", column.Name), + slog.String("columnType", column.Type), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) + } else { + p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating because excluded", + column.Name, schemaDelta.SrcTableName)) + } } } - if len(potentiallyNullableAddedColumns) > 0 { - p.logger.Info("Checking for potentially nullable columns in table", + + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) + for _, col := range schemaDelta.AddedColumns { + addedColNames = append(addedColNames, fmt.Sprintf("%s(%s)", col.Name, col.Type)) + } + typeChangedColNames := make([]string, 0, len(schemaDelta.TypeChangedColumns)) + for _, tc := range schemaDelta.TypeChangedColumns { + typeChangedColNames = append(typeChangedColNames, fmt.Sprintf("%s(%s->%s)", tc.ColumnName, tc.OldType, tc.NewType)) + } + p.logger.Info("Schema change detection summary", + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Int("addedColumnsCount", len(schemaDelta.AddedColumns)), + slog.Any("addedColumns", addedColNames), + slog.Int("droppedColumnsCount", len(schemaDelta.DroppedColumns)), + slog.Any("droppedColumns", schemaDelta.DroppedColumns), + slog.Int("typeChangedColumnsCount", len(schemaDelta.TypeChangedColumns)), + slog.Any("typeChangedColumns", typeChangedColNames), + slog.Any("LSN", lsn)) + } + + p.relationMessageMapping[currRel.RelationID] = currRel + if len(schemaDelta.AddedColumns) > 0 { + addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) + for _, col := range schemaDelta.AddedColumns { + addedColNames = append(addedColNames, utils.QuoteLiteral(col.Name)) + } + + p.logger.Info("Fetching default values and nullable info for added columns", slog.String("tableName", schemaDelta.SrcTableName), - slog.Any("potentiallyNullable", potentiallyNullableAddedColumns)) + slog.Any("addedColumns", addedColNames)) rows, err := p.conn.Query( ctx, fmt.Sprintf( - "select attname from pg_attribute where attrelid=$1 and attname in (%s) and not attnotnull", - strings.Join(potentiallyNullableAddedColumns, ","), + `SELECT attname, attnotnull, pg_get_expr(adbin, adrelid) as column_default + FROM pg_attribute + LEFT JOIN pg_attrdef ON pg_attribute.attrelid = pg_attrdef.adrelid + AND pg_attribute.attnum = pg_attrdef.adnum + WHERE attrelid=$1 AND attname IN (%s) AND attnum > 0 AND NOT attisdropped`, + 
strings.Join(addedColNames, ","), ), currRel.RelationID, ) if err != nil { - return nil, fmt.Errorf("error looking up column nullable info for schema change: %w", err) + return nil, fmt.Errorf("error looking up column default/nullable info for schema change: %w", err) + } + + type colInfo struct { + attname string + attnotnull bool + columnDefault pgtype.Text } - attnames, err := pgx.CollectRows[string](rows, pgx.RowTo) + colInfos, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (colInfo, error) { + var info colInfo + err := rows.Scan(&info.attname, &info.attnotnull, &info.columnDefault) + return info, err + }) if err != nil { - return nil, fmt.Errorf("error collecting rows for column nullable info for schema change: %w", err) + return nil, fmt.Errorf("error collecting rows for column default/nullable info: %w", err) } + + colInfoMap := make(map[string]colInfo, len(colInfos)) + for _, info := range colInfos { + colInfoMap[info.attname] = info + } + for _, column := range schemaDelta.AddedColumns { - if slices.Contains(attnames, column.Name) { - column.Nullable = true - p.logger.Info(fmt.Sprintf("Detected column %s in table %s as nullable", - column.Name, schemaDelta.SrcTableName)) + if info, ok := colInfoMap[column.Name]; ok { + column.Nullable = !info.attnotnull + if info.columnDefault.Valid && info.columnDefault.String != "" { + // Store default value - we'll use it in the ADD COLUMN statement + column.DefaultValue = info.columnDefault.String + } else { + p.logger.Info("Detected column without default value", + slog.String("columnName", column.Name), + slog.Bool("nullable", column.Nullable), + slog.String("tableName", schemaDelta.SrcTableName)) + } } } } - p.relationMessageMapping[currRel.RelationID] = currRel - // only log audit if there is actionable delta - if len(schemaDelta.AddedColumns) > 0 { + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + updatedSchema := &protos.TableSchema{ + Columns: make([]*protos.FieldDescription, 0, len(prevSchema.Columns)), + System: prevSchema.System, + PrimaryKeyColumns: prevSchema.PrimaryKeyColumns, + NullableEnabled: prevSchema.NullableEnabled, + } + + droppedColsMap := make(map[string]bool, len(schemaDelta.DroppedColumns)) + for _, droppedCol := range schemaDelta.DroppedColumns { + droppedColsMap[droppedCol] = true + } + + typeChangedColsMap := make(map[string]*protos.ColumnTypeChange, len(schemaDelta.TypeChangedColumns)) + for _, typeChange := range schemaDelta.TypeChangedColumns { + typeChangedColsMap[typeChange.ColumnName] = typeChange + } + for _, col := range prevSchema.Columns { + if !droppedColsMap[col.Name] { + if typeChange, ok := typeChangedColsMap[col.Name]; ok { + updatedCol := &protos.FieldDescription{ + Name: col.Name, + Type: typeChange.NewType, + TypeModifier: col.TypeModifier, + Nullable: col.Nullable, + DefaultValue: typeChange.NewDefaultValue, + } + updatedSchema.Columns = append(updatedSchema.Columns, updatedCol) + } else { + updatedSchema.Columns = append(updatedSchema.Columns, col) + } + } + } + + for _, addedCol := range schemaDelta.AddedColumns { + updatedSchema.Columns = append(updatedSchema.Columns, addedCol) + } + + p.tableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema + if req != nil && req.TableNameSchemaMapping != nil { + req.TableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema + p.logger.Info("Updated req.TableNameSchemaMapping for immediate DML processing", + slog.String("tableName", currRelDstInfo.Name)) + } + + 
p.logger.Info("Updated cached schema mapping after detecting changes", + slog.String("tableName", currRelDstInfo.Name), + slog.Int("addedColumns", len(schemaDelta.AddedColumns)), + slog.Int("droppedColumns", len(schemaDelta.DroppedColumns)), + slog.Int("totalColumns", len(updatedSchema.Columns))) + } + + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + p.logger.Info("Returning RelationRecord with schema delta", + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Int("addedColumnsCount", len(schemaDelta.AddedColumns)), + slog.Int("droppedColumnsCount", len(schemaDelta.DroppedColumns)), + slog.Int("typeChangedColumnsCount", len(schemaDelta.TypeChangedColumns)), + slog.Any("LSN", lsn)) return &model.RelationRecord[Items]{ BaseRecord: p.baseRecord(lsn), TableSchemaDelta: schemaDelta, }, monitoring.AuditSchemaDelta(ctx, p.catalogPool.Pool, p.flowJobName, schemaDelta) } + p.logger.Debug("No schema changes detected, returning nil", + slog.String("tableName", schemaDelta.SrcTableName)) return nil, nil } diff --git a/flow/connectors/postgres/index_trigger_sync.go b/flow/connectors/postgres/index_trigger_sync.go new file mode 100644 index 0000000000..153a35baba --- /dev/null +++ b/flow/connectors/postgres/index_trigger_sync.go @@ -0,0 +1,763 @@ +package connpostgres + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/PeerDB-io/peerdb/flow/connectors/utils" + "github.com/PeerDB-io/peerdb/flow/generated/protos" +) + +type IndexInfo struct { + IndexName string + TableSchema string + TableName string + IndexDef string + IsUnique bool + IsPrimary bool + IndexColumns []string +} +type TriggerInfo struct { + TriggerName string + TableSchema string + TableName string + TriggerDef string + EventManipulation string + ActionTiming string + ActionStatement string +} + +type ConstraintInfo struct { + ConstraintName string + TableSchema string + TableName string + ConstraintType string // 'c' for check, 'f' for foreign key, 'u' for unique, 'p' for primary key + ConstraintDef string // Full constraint definition from pg_get_constraintdef + IsDeferrable bool + IsDeferred bool +} + +func (c *PostgresConnector) SyncIndexesAndTriggers( + ctx context.Context, + tableMappings []*protos.TableMapping, + sourceConn *PostgresConnector, +) error { + c.logger.Info("Starting index and trigger synchronization", + slog.Int("tableCount", len(tableMappings))) + + for _, tableMapping := range tableMappings { + srcTable, err := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier) + if err != nil { + return fmt.Errorf("error parsing source table %s: %w", tableMapping.SourceTableIdentifier, err) + } + + dstTable, err := utils.ParseSchemaTable(tableMapping.DestinationTableIdentifier) + if err != nil { + return fmt.Errorf("error parsing destination table %s: %w", tableMapping.DestinationTableIdentifier, err) + } + + // Sync indexes + if err := c.syncIndexesForTable(ctx, srcTable, dstTable, sourceConn); err != nil { + c.logger.Warn("Failed to sync indexes for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + } + + if err := c.syncTriggersForTable(ctx, srcTable, dstTable, sourceConn); err != nil { + c.logger.Warn("Failed to sync triggers for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + } + + if err := 
c.syncConstraintsForTable(ctx, srcTable, dstTable, sourceConn, tableMappings); err != nil { + c.logger.Warn("Failed to sync constraints for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + } + } + + c.logger.Info("Completed index, trigger, and constraint synchronization") + return nil +} + +func (c *PostgresConnector) syncIndexesForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, +) error { + srcIndexes, err := sourceConn.getIndexesForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source indexes: %w", err) + } + + dstIndexes, err := c.getIndexesForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination indexes: %w", err) + } + + dstIndexMap := make(map[string]*IndexInfo, len(dstIndexes)) + for _, idx := range dstIndexes { + dstIndexMap[idx.IndexName] = idx + } + + createdCount := 0 + for _, srcIdx := range srcIndexes { + if srcIdx.IsPrimary { + continue + } + + if _, exists := dstIndexMap[srcIdx.IndexName]; exists { + c.logger.Debug("Index already exists in destination", + slog.String("indexName", srcIdx.IndexName), + slog.String("dstTable", dstTable.String())) + continue + } + indexSQL := c.adaptIndexSQL(srcIdx.IndexDef, srcTable, dstTable) + + c.logger.Info("Creating index on destination", + slog.String("indexName", srcIdx.IndexName), + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.String("indexSQL", indexSQL)) + + if _, err := c.conn.Exec(ctx, indexSQL); err != nil { + c.logger.Error("Failed to create index", + slog.String("indexName", srcIdx.IndexName), + slog.String("indexSQL", indexSQL), + slog.Any("error", err)) + continue + } + + createdCount++ + c.logger.Info("Successfully created index", + slog.String("indexName", srcIdx.IndexName), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created indexes for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", createdCount)) + } + + return nil +} + +// syncTriggersForTable syncs triggers for a specific table +func (c *PostgresConnector) syncTriggersForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, +) error { + srcTriggers, err := sourceConn.getTriggersForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source triggers: %w", err) + } + + dstTriggers, err := c.getTriggersForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination triggers: %w", err) + } + + dstTriggerMap := make(map[string]*TriggerInfo, len(dstTriggers)) + for _, trig := range dstTriggers { + dstTriggerMap[trig.TriggerName] = trig + } + + createdCount := 0 + for _, srcTrig := range srcTriggers { + c.logger.Info("Processing source trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("triggerDef", srcTrig.TriggerDef), + slog.String("srcTable", srcTable.String())) + + if _, exists := dstTriggerMap[srcTrig.TriggerName]; exists { + c.logger.Info("Trigger already exists in destination, skipping", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("dstTable", dstTable.String())) + continue + } + + funcName, funcSchema := c.extractFunctionFromTriggerDef(srcTrig.TriggerDef) + + funcExists := false + if funcName != "" { + schemasToCheck := []string{funcSchema} + if funcSchema == "public" 
|| funcSchema == "" { + schemasToCheck = []string{"public", dstTable.Schema, srcTable.Schema} + } + + for _, schema := range schemasToCheck { + exists, err := c.checkFunctionExists(ctx, schema, funcName) + if err != nil { + c.logger.Warn("Failed to check if function exists", + slog.String("functionName", funcName), + slog.String("functionSchema", schema), + slog.Any("error", err)) + continue + } + if exists { + funcExists = true + funcSchema = schema // Update to the actual schema where function exists + c.logger.Info("Found function on destination", + slog.String("functionName", funcName), + slog.String("functionSchema", schema)) + break + } + } + + if !funcExists { + sourceSchemasToCheck := []string{funcSchema, "public", srcTable.Schema} + if funcSchema == "public" || funcSchema == "" { + sourceSchemasToCheck = []string{"public", srcTable.Schema} + } + + funcSynced := false + for _, sourceSchema := range sourceSchemasToCheck { + c.logger.Info("Attempting to sync trigger function from source", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("targetSchema", funcSchema)) + + if err := c.syncTriggerFunction(ctx, sourceSchema, funcName, funcSchema, sourceConn); err != nil { + continue + } + + funcExists, err = c.checkFunctionExists(ctx, funcSchema, funcName) + if err == nil && funcExists { + funcSynced = true + c.logger.Info("Successfully synced trigger function", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("targetSchema", funcSchema)) + break + } + } + + if !funcSynced { + c.logger.Warn("Failed to sync trigger function from source, skipping trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("functionName", funcName), + slog.String("checkedSchemas", fmt.Sprintf("%v", sourceSchemasToCheck)), + slog.String("hint", "Create the function manually on destination")) + continue + } + } + } + triggerSQL := c.adaptTriggerSQL(srcTrig.TriggerDef, srcTable, dstTable) + + if _, err := c.conn.Exec(ctx, triggerSQL); err != nil { + c.logger.Error("Failed to create trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("triggerSQL", triggerSQL), + slog.Any("error", err)) + continue + } + + createdCount++ + c.logger.Info("Successfully created trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created triggers for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", createdCount)) + } + + return nil +} + +// getIndexesForTable retrieves all indexes for a given table +func (c *PostgresConnector) getIndexesForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*IndexInfo, error) { + query := ` + SELECT + indexname, + schemaname, + tablename, + indexdef + FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 + ORDER BY indexname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying indexes: %w", err) + } + defer rows.Close() + + var indexes []*IndexInfo + for rows.Next() { + var idx IndexInfo + err := rows.Scan( + &idx.IndexName, + &idx.TableSchema, + &idx.TableName, + &idx.IndexDef, + ) + if err != nil { + return nil, fmt.Errorf("error scanning index row: %w", err) + } + + idx.IsUnique = strings.Contains(strings.ToUpper(idx.IndexDef), "UNIQUE") + idx.IsPrimary = strings.HasSuffix(idx.IndexName, "_pkey") || + 
strings.Contains(strings.ToUpper(idx.IndexDef), "PRIMARY KEY") + + idx.IndexColumns = c.extractColumnsFromIndexDef(idx.IndexDef) + + indexes = append(indexes, &idx) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating index rows: %w", err) + } + + return indexes, nil +} + +// extractColumnsFromIndexDef extracts column names from index definition +func (c *PostgresConnector) extractColumnsFromIndexDef(indexDef string) []string { + var columns []string + + start := strings.Index(indexDef, "(") + end := strings.LastIndex(indexDef, ")") + if start >= 0 && end > start { + colPart := indexDef[start+1 : end] + parts := strings.Split(colPart, ",") + for _, part := range parts { + col := strings.TrimSpace(part) + col = strings.Trim(col, `"'`) + if spaceIdx := strings.Index(col, " "); spaceIdx > 0 { + col = col[:spaceIdx] + } + if len(col) > 0 { + columns = append(columns, col) + } + } + } + + return columns +} + +// getTriggersForTable retrieves all triggers for a given table +func (c *PostgresConnector) getTriggersForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*TriggerInfo, error) { + query := ` + SELECT + t.tgname as trigger_name, + n.nspname as schema_name, + c.relname as table_name, + pg_get_triggerdef(t.oid) as trigger_def, + CASE + WHEN t.tgtype & 2 = 2 THEN 'BEFORE' + WHEN t.tgtype & 64 = 64 THEN 'INSTEAD OF' + ELSE 'AFTER' + END as action_timing, + CASE + WHEN t.tgtype & 4 = 4 THEN 'INSERT' + WHEN t.tgtype & 8 = 8 THEN 'DELETE' + WHEN t.tgtype & 16 = 16 THEN 'UPDATE' + ELSE 'UNKNOWN' + END as event_manipulation + FROM pg_trigger t + JOIN pg_class c ON c.oid = t.tgrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND NOT t.tgisinternal + ORDER BY t.tgname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying triggers: %w", err) + } + defer rows.Close() + + var triggers []*TriggerInfo + for rows.Next() { + var trig TriggerInfo + err := rows.Scan( + &trig.TriggerName, + &trig.TableSchema, + &trig.TableName, + &trig.TriggerDef, + &trig.ActionTiming, + &trig.EventManipulation, + ) + if err != nil { + return nil, fmt.Errorf("error scanning trigger row: %w", err) + } + + trig.ActionStatement = c.extractActionStatement(trig.TriggerDef) + + triggers = append(triggers, &trig) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating trigger rows: %w", err) + } + + return triggers, nil +} + +func (c *PostgresConnector) extractActionStatement(triggerDef string) string { + executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE") + if executeIdx >= 0 { + return triggerDef[executeIdx:] + } + return "" +} + +func (c *PostgresConnector) adaptIndexSQL( + indexSQL string, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, +) string { + adapted := strings.ReplaceAll(indexSQL, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) + + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), + fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) + + return adapted +} + +func (c *PostgresConnector) extractFunctionFromTriggerDef(triggerDef string) (funcName, funcSchema string) { + executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE") + if executeIdx < 0 { + return "", "" + } + + 
executePart := triggerDef[executeIdx:] + funcKeywordIdx := strings.Index(strings.ToUpper(executePart), "FUNCTION") + if funcKeywordIdx < 0 { + funcKeywordIdx = strings.Index(strings.ToUpper(executePart), "PROCEDURE") + if funcKeywordIdx < 0 { + return "", "" + } + } + + funcPart := strings.TrimSpace(executePart[funcKeywordIdx+8:]) // 8 = len("FUNCTION") or len("PROCEDURE") + funcPart = strings.TrimSpace(strings.TrimSuffix(funcPart, "()")) + funcPart = strings.TrimSpace(strings.TrimSuffix(funcPart, ")")) + + // Check if it has schema qualification (schema.function) + if dotIdx := strings.LastIndex(funcPart, "."); dotIdx >= 0 { + funcSchema = funcPart[:dotIdx] + funcName = funcPart[dotIdx+1:] + funcSchema = strings.Trim(funcSchema, `"'`) + funcName = strings.Trim(funcName, `"'`) + } else { + funcName = strings.Trim(funcPart, `"'`) + funcSchema = "public" // Default to public schema + } + + return funcName, funcSchema +} + +// checkFunctionExists checks if a function exists in the specified schema +func (c *PostgresConnector) checkFunctionExists(ctx context.Context, schema, funcName string) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = $1 AND p.proname = $2 + ) + ` + + var exists bool + err := c.conn.QueryRow(ctx, query, schema, funcName).Scan(&exists) + if err != nil { + return false, fmt.Errorf("error checking function existence: %w", err) + } + + return exists, nil +} + +// syncTriggerFunction syncs a trigger function from source to destination +func (c *PostgresConnector) syncTriggerFunction( + ctx context.Context, + sourceSchema, funcName, targetSchema string, + sourceConn *PostgresConnector, +) error { + query := ` + SELECT pg_get_functiondef(p.oid) as function_def + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = $1 AND p.proname = $2 + LIMIT 1 + ` + + var funcDef string + err := sourceConn.conn.QueryRow(ctx, query, sourceSchema, funcName).Scan(&funcDef) + if err != nil { + return fmt.Errorf("error getting function definition from source schema %s: %w", sourceSchema, err) + } + + if funcDef == "" { + return fmt.Errorf("function definition is empty") + } + + c.logger.Info("Retrieved function definition from source", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("functionDef", funcDef)) + + if sourceSchema != targetSchema { + funcDef = strings.ReplaceAll(funcDef, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(sourceSchema), utils.QuoteIdentifier(funcName)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(targetSchema), utils.QuoteIdentifier(funcName))) + funcDef = strings.ReplaceAll(funcDef, + fmt.Sprintf("%s.%s", sourceSchema, funcName), + fmt.Sprintf("%s.%s", targetSchema, funcName)) + } + + c.logger.Info("Creating function on destination", + slog.String("functionName", funcName), + slog.String("targetSchema", targetSchema)) + + if _, err := c.conn.Exec(ctx, funcDef); err != nil { + return fmt.Errorf("error creating function on destination: %w", err) + } + + c.logger.Info("Successfully created function on destination", + slog.String("functionName", funcName), + slog.String("targetSchema", targetSchema)) + + return nil +} + +// adaptTriggerSQL adapts trigger SQL from source to destination table +func (c *PostgresConnector) adaptTriggerSQL( + triggerSQL string, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, +) string { + adapted := strings.ReplaceAll(triggerSQL, + fmt.Sprintf("%s.%s", 
utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) + + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), + fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) + + return adapted +} + +// syncConstraintsForTable syncs constraints (check and foreign key) for a specific table +func (c *PostgresConnector) syncConstraintsForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, + tableMappings []*protos.TableMapping, +) error { + srcConstraints, err := sourceConn.getConstraintsForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source constraints: %w", err) + } + + dstConstraints, err := c.getConstraintsForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination constraints: %w", err) + } + + dstConstraintMap := make(map[string]*ConstraintInfo, len(dstConstraints)) + for _, constraint := range dstConstraints { + dstConstraintMap[constraint.ConstraintName] = constraint + } + + tableNameMap := make(map[string]string) + for _, tm := range tableMappings { + src, err := utils.ParseSchemaTable(tm.SourceTableIdentifier) + if err != nil { + continue + } + dst, err := utils.ParseSchemaTable(tm.DestinationTableIdentifier) + if err != nil { + continue + } + tableNameMap[src.String()] = dst.String() + tableNameMap[fmt.Sprintf("%s.%s", src.Schema, src.Table)] = fmt.Sprintf("%s.%s", dst.Schema, dst.Table) + } + + createdCount := 0 + for _, srcConstraint := range srcConstraints { + + if srcConstraint.ConstraintType == "p" { + continue + } + if srcConstraint.ConstraintType == "u" { + continue + } + + if _, exists := dstConstraintMap[srcConstraint.ConstraintName]; exists { + c.logger.Info("Constraint already exists in destination, skipping", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("dstTable", dstTable.String())) + continue + } + constraintSQL := c.adaptConstraintSQL(srcConstraint.ConstraintDef, srcTable, dstTable, tableNameMap, srcConstraint.ConstraintName) + + if _, err := c.conn.Exec(ctx, constraintSQL); err != nil { + c.logger.Error("Failed to create constraint", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("constraintSQL", constraintSQL), + slog.Any("error", err)) + continue + } + + createdCount++ + c.logger.Info("Successfully created constraint", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created constraints for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", createdCount)) + } + + return nil +} + +// getConstraintsForTable retrieves all constraints for a given table +func (c *PostgresConnector) getConstraintsForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*ConstraintInfo, error) { + query := ` + SELECT + con.conname as constraint_name, + n.nspname as schema_name, + c.relname as table_name, + con.contype::text as constraint_type, + pg_get_constraintdef(con.oid) as constraint_def, + con.condeferrable as is_deferrable, + con.condeferred as is_deferred + FROM pg_constraint con + JOIN pg_class c ON c.oid = con.conrelid + JOIN pg_namespace n ON n.oid = 
c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND con.contype IN ('c', 'f') -- 'c' for check, 'f' for foreign key + ORDER BY con.conname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying constraints: %w", err) + } + defer rows.Close() + + var constraints []*ConstraintInfo + for rows.Next() { + var constraint ConstraintInfo + err := rows.Scan( + &constraint.ConstraintName, + &constraint.TableSchema, + &constraint.TableName, + &constraint.ConstraintType, + &constraint.ConstraintDef, + &constraint.IsDeferrable, + &constraint.IsDeferred, + ) + if err != nil { + return nil, fmt.Errorf("error scanning constraint row: %w", err) + } + + constraints = append(constraints, &constraint) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating constraint rows: %w", err) + } + + return constraints, nil +} + +// adaptConstraintSQL adapts constraint SQL from source to destination table +func (c *PostgresConnector) adaptConstraintSQL( + constraintDef string, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + tableNameMap map[string]string, + constraintName string, +) string { + + adapted := constraintDef + if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { + srcTableStr := fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table) + dstTableStr := fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table) + if _, exists := tableNameMap[srcTableStr]; !exists { + tableNameMap[srcTableStr] = dstTableStr + } + if _, exists := tableNameMap[srcTable.Table]; !exists { + tableNameMap[srcTable.Table] = dstTable.Table + } + for srcTableName, dstTableName := range tableNameMap { + if strings.Contains(srcTableName, ".") { + parts := strings.Split(srcTableName, ".") + if len(parts) == 2 { + srcSchema, srcTbl := parts[0], parts[1] + dstParts := strings.Split(dstTableName, ".") + if len(dstParts) == 2 { + dstSchema, dstTbl := dstParts[0], dstParts[1] + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(srcSchema), utils.QuoteIdentifier(srcTbl)), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstSchema), utils.QuoteIdentifier(dstTbl))) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s.%s", srcSchema, srcTbl), + fmt.Sprintf("REFERENCES %s.%s", dstSchema, dstTbl)) + } + } + } else { + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s", utils.QuoteIdentifier(srcTableName)), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s", srcTableName), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) + } + } + } + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) + + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), + fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) + + if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "CHECK") || + strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { + return fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s %s", + utils.QuoteIdentifier(dstTable.Schema), + 
utils.QuoteIdentifier(dstTable.Table), + utils.QuoteIdentifier(constraintName), + adapted) + } + + return adapted +} diff --git a/flow/connectors/postgres/index_trigger_sync_test.go b/flow/connectors/postgres/index_trigger_sync_test.go new file mode 100644 index 0000000000..e7919a2e69 --- /dev/null +++ b/flow/connectors/postgres/index_trigger_sync_test.go @@ -0,0 +1,555 @@ +package connpostgres + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/PeerDB-io/peerdb/flow/connectors/utils" + "github.com/PeerDB-io/peerdb/flow/e2eshared" + "github.com/PeerDB-io/peerdb/flow/generated/protos" + "github.com/PeerDB-io/peerdb/flow/internal" + "github.com/PeerDB-io/peerdb/flow/shared" +) + +type IndexTriggerSyncTestSuite struct { + t *testing.T + sourceConn *PostgresConnector + destConn *PostgresConnector + sourceSchema string + destSchema string +} + +func SetupIndexTriggerSyncSuite(t *testing.T) IndexTriggerSyncTestSuite { + t.Helper() + sourceConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) + require.NoError(t, err) + destConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) + require.NoError(t, err) + sourceSchema := "src_idx_" + strings.ToLower(shared.RandomString(8)) + destSchema := "dst_idx_" + strings.ToLower(shared.RandomString(8)) + + setupTx, err := sourceConn.conn.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err := setupTx.Rollback(t.Context()) + if err != pgx.ErrTxClosed { + require.NoError(t, err) + } + }() + + _, err = setupTx.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", sourceSchema)) + require.NoError(t, err) + _, err = setupTx.Exec(t.Context(), "CREATE SCHEMA "+sourceSchema) + require.NoError(t, err) + require.NoError(t, setupTx.Commit(t.Context())) + + setupTx2, err := destConn.conn.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err := setupTx2.Rollback(t.Context()) + if err != pgx.ErrTxClosed { + require.NoError(t, err) + } + }() + + _, err = setupTx2.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", destSchema)) + require.NoError(t, err) + _, err = setupTx2.Exec(t.Context(), "CREATE SCHEMA "+destSchema) + require.NoError(t, err) + require.NoError(t, setupTx2.Commit(t.Context())) + + return IndexTriggerSyncTestSuite{ + t: t, + sourceConn: sourceConn, + destConn: destConn, + sourceSchema: sourceSchema, + destSchema: destSchema, + } +} + +func (s IndexTriggerSyncTestSuite) Teardown(ctx context.Context) { + _, _ = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.sourceSchema)) + _, _ = s.destConn.conn.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.destSchema)) + require.NoError(s.t, s.sourceConn.Close()) + require.NoError(s.t, s.destConn.Close()) +} + +func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_table" + destTable := s.destSchema + ".test_table" + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, sourceTable)) + require.NoError(s.t, err) + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE UNIQUE INDEX idx_email ON %s(email)", sourceTable)) + 
require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_created_at ON %s(created_at DESC)", sourceTable)) + require.NoError(s.t, err) + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, destTable)) + require.NoError(s.t, err) + + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) + require.NoError(s.t, err) + indexNames := make(map[string]bool) + for _, idx := range destIndexes { + indexNames[idx.IndexName] = true + } + require.True(s.t, indexNames["test_table_pkey"], "Primary key should exist") + require.True(s.t, indexNames["idx_name"], "idx_name should be synced") + require.True(s.t, indexNames["idx_email"], "idx_email should be synced") + require.True(s.t, indexNames["idx_created_at"], "idx_created_at should be synced") +} + +func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_trigger_table" + destTable := s.destSchema + ".test_trigger_table" + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.sourceSchema)) + require.NoError(s.t, err) + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.destSchema)) + require.NoError(s.t, err) + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TRIGGER update_updated_at + BEFORE UPDATE ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.update_timestamp()`, sourceTable, s.sourceSchema)) + require.NoError(s.t, err) + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, destTable)) + require.NoError(s.t, err) + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) + require.NoError(s.t, err) + require.Len(s.t, destTriggers, 1, "Should have 1 trigger") + require.Equal(s.t, "update_updated_at", destTriggers[0].TriggerName) +} + +func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_both_table" + destTable := s.destSchema + ".test_both_table" + + // Create function for trigger + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.log_changes() + RETURNS TRIGGER AS $$ + BEGIN + -- Simple trigger function + RETURN NEW; + END; + $$ LANGUAGE 
plpgsql`, s.sourceSchema)) + require.NoError(s.t, err) + + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.log_changes() + RETURNS TRIGGER AS $$ + BEGIN + -- Simple trigger function + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.destSchema)) + require.NoError(s.t, err) + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + data TEXT + )`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_data ON %s(data)", sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TRIGGER log_trigger + AFTER INSERT ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.log_changes()`, sourceTable, s.sourceSchema)) + require.NoError(s.t, err) + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + data TEXT + )`, destTable)) + require.NoError(s.t, err) + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + indexNames := make(map[string]bool) + for _, idx := range destIndexes { + indexNames[idx.IndexName] = true + } + require.True(s.t, indexNames["idx_data"], "idx_data should be synced") + destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + require.Len(s.t, destTriggers, 1, "Should have 1 trigger") + require.Equal(s.t, "log_trigger", destTriggers[0].TriggerName) +} + +func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_check_table" + destTable := s.destSchema + ".test_check_table" + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + age INT, + email TEXT + )`, sourceTable)) + require.NoError(s.t, err) + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_name_length + CHECK (char_length(name) >= 3)`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_age_positive + CHECK (age > 0 AND age < 150)`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_email_format + CHECK (email IS NULL OR email LIKE '%%@%%')`, sourceTable)) + require.NoError(s.t, err) + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + age INT, + email TEXT + )`, destTable)) + require.NoError(s.t, err) + + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) + require.NoError(s.t, err) + constraintNames := make(map[string]bool) + for _, constraint := range destConstraints { + if constraint.ConstraintType == "c" { + constraintNames[constraint.ConstraintName] 
+		}
+	}
+
+	require.True(s.t, constraintNames["check_name_length"], "check_name_length should be synced")
+	require.True(s.t, constraintNames["check_age_positive"], "check_age_positive should be synced")
+	require.True(s.t, constraintNames["check_email_format"], "check_email_format should be synced")
+}
+
+func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() {
+	ctx := s.t.Context()
+	parentSourceTable := s.sourceSchema + ".parent_table"
+	parentDestTable := s.destSchema + ".parent_table"
+
+	_, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT
+		)`, parentSourceTable))
+	require.NoError(s.t, err)
+
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT
+		)`, parentDestTable))
+	require.NoError(s.t, err)
+	childSourceTable := s.sourceSchema + ".child_table"
+	childDestTable := s.destSchema + ".child_table"
+
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			parent_id INT,
+			name TEXT
+		)`, childSourceTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT fk_parent
+		FOREIGN KEY (parent_id) REFERENCES %s(id)`, childSourceTable, parentSourceTable))
+	require.NoError(s.t, err)
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			parent_id INT,
+			name TEXT
+		)`, childDestTable))
+	require.NoError(s.t, err)
+	tableMappings := []*protos.TableMapping{
+		{
+			SourceTableIdentifier:      parentSourceTable,
+			DestinationTableIdentifier: parentDestTable,
+		},
+		{
+			SourceTableIdentifier:      childSourceTable,
+			DestinationTableIdentifier: childDestTable,
+		},
+	}
+
+	err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn)
+	require.NoError(s.t, err)
+	childDestTableParsed, err := utils.ParseSchemaTable(childDestTable)
+	require.NoError(s.t, err)
+	destConstraints, err := s.destConn.getConstraintsForTable(ctx, childDestTableParsed)
+	require.NoError(s.t, err)
+	fkFound := false
+	for _, constraint := range destConstraints {
+		if constraint.ConstraintType == "f" && constraint.ConstraintName == "fk_parent" {
+			fkFound = true
+			require.Contains(s.t, constraint.ConstraintDef, parentDestTable,
+				"Foreign key should reference destination parent table")
+			break
+		}
+	}
+	require.True(s.t, fkFound, "fk_parent foreign key should be synced")
+}
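The ConstraintDef assertion in the foreign-key test above implies that SyncIndexesAndTriggers re-points the captured constraint definition at the destination schema before replaying it. The connector's actual rewrite logic is not part of this patch; the sketch below only illustrates the assumed idea, and both helper names are invented:

```go
// Illustration only: rewriteConstraintDef and addConstraintSQL are hypothetical
// helpers, not functions from this patch.
package example

import (
	"fmt"
	"strings"
)

// rewriteConstraintDef re-points schema-qualified references in a definition
// obtained from pg_get_constraintdef, e.g.
//   FOREIGN KEY (parent_id) REFERENCES src_schema.parent_table(id)
// becomes
//   FOREIGN KEY (parent_id) REFERENCES dst_schema.parent_table(id)
func rewriteConstraintDef(def, sourceSchema, destSchema string) string {
	return strings.ReplaceAll(def, sourceSchema+".", destSchema+".")
}

// addConstraintSQL replays the (rewritten) definition on the destination table.
func addConstraintSQL(destTable, constraintName, def string) string {
	return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s", destTable, constraintName, def)
}
```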
+
+func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() {
+	ctx := s.t.Context()
+	sourceTable := s.sourceSchema + ".test_constraints_table"
+	destTable := s.destSchema + ".test_constraints_table"
+	refSourceTable := s.sourceSchema + ".ref_table"
+	refDestTable := s.destSchema + ".ref_table"
+
+	_, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			code TEXT UNIQUE
+		)`, refSourceTable))
+	require.NoError(s.t, err)
+
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			code TEXT UNIQUE
+		)`, refDestTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT NOT NULL,
+			ref_id INT,
+			status TEXT
+		)`, sourceTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT check_name_length
+		CHECK (char_length(name) >= 2)`, sourceTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT fk_ref
+		FOREIGN KEY (ref_id) REFERENCES %s(id)`, sourceTable, refSourceTable))
+	require.NoError(s.t, err)
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT NOT NULL,
+			ref_id INT,
+			status TEXT
+		)`, destTable))
+	require.NoError(s.t, err)
+	tableMappings := []*protos.TableMapping{
+		{
+			SourceTableIdentifier:      refSourceTable,
+			DestinationTableIdentifier: refDestTable,
+		},
+		{
+			SourceTableIdentifier:      sourceTable,
+			DestinationTableIdentifier: destTable,
+		},
+	}
+
+	err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn)
+	require.NoError(s.t, err)
+	destTableParsed, err := utils.ParseSchemaTable(destTable)
+	require.NoError(s.t, err)
+	destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed)
+	require.NoError(s.t, err)
+
+	constraintNames := make(map[string]string) // name -> type
+	for _, constraint := range destConstraints {
+		constraintNames[constraint.ConstraintName] = constraint.ConstraintType
+	}
+
+	require.Equal(s.t, "c", constraintNames["check_name_length"], "check_name_length should be synced")
+	require.Equal(s.t, "f", constraintNames["fk_ref"], "fk_ref should be synced")
+}
+
+func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() {
+	ctx := s.t.Context()
+	sourceTable := s.sourceSchema + ".test_all_table"
+	destTable := s.destSchema + ".test_all_table"
+	_, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE OR REPLACE FUNCTION %s.audit_func()
+		RETURNS TRIGGER AS $$
+		BEGIN
+			RETURN NEW;
+		END;
+		$$ LANGUAGE plpgsql`, s.sourceSchema))
+	require.NoError(s.t, err)
+
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE OR REPLACE FUNCTION %s.audit_func()
+		RETURNS TRIGGER AS $$
+		BEGIN
+			RETURN NEW;
+		END;
+		$$ LANGUAGE plpgsql`, s.destSchema))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT NOT NULL,
+			email TEXT,
+			created_at TIMESTAMPTZ DEFAULT NOW()
+		)`, sourceTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable))
+	require.NoError(s.t, err)
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TRIGGER audit_trigger
+		AFTER INSERT ON %s
+		FOR EACH ROW
+		EXECUTE FUNCTION %s.audit_func()`, sourceTable, s.sourceSchema))
+	require.NoError(s.t, err)
+
+	// Add check constraint
+	_, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT check_name_length
+		CHECK (char_length(name) >= 3)`, sourceTable))
+	require.NoError(s.t, err)
+	_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(`
+		CREATE TABLE %s (
+			id INT PRIMARY KEY,
+			name TEXT NOT NULL,
+			email TEXT,
+			created_at TIMESTAMPTZ DEFAULT NOW()
+		)`, destTable))
+	require.NoError(s.t, err)
+	tableMappings := []*protos.TableMapping{
+		{
+			SourceTableIdentifier:      sourceTable,
+			DestinationTableIdentifier: destTable,
+		},
+	}
+
+	err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn)
+	require.NoError(s.t, err)
+	destTableParsed, err := utils.ParseSchemaTable(destTable)
+	require.NoError(s.t, err)
+	destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed)
+	require.NoError(s.t, err)
+
+	indexNames := make(map[string]bool)
+	for _, idx := range destIndexes {
+		indexNames[idx.IndexName] = true
+	}
+	require.True(s.t, indexNames["idx_name"], "idx_name should be synced")
+	destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed)
+	require.NoError(s.t, err)
+	require.Len(s.t, destTriggers, 1, "Should have 1 trigger")
+	require.Equal(s.t, "audit_trigger", destTriggers[0].TriggerName)
+	destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed)
+	require.NoError(s.t, err)
+
+	constraintFound := false
+	for _, constraint := range destConstraints {
+		if constraint.ConstraintName == "check_name_length" && constraint.ConstraintType == "c" {
+			constraintFound = true
+			break
+		}
+	}
+	require.True(s.t, constraintFound, "check_name_length constraint should be synced")
+}
+
+func TestIndexTriggerSync(t *testing.T) {
+	e2eshared.RunSuite(t, SetupIndexTriggerSyncSuite)
+}
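The suite above leans on catalog-introspection helpers (getIndexesForTable, getTriggersForTable, getConstraintsForTable) defined elsewhere in the connector, and on pg_constraint's single-letter contype codes ("c" = CHECK, "f" = FOREIGN KEY, "p" = PRIMARY KEY, "u" = UNIQUE). As a rough, self-contained sketch of the kind of query such a helper presumably wraps — the struct and function names here are invented, not the connector's real API:

```go
// Illustrative only: listConstraints is not part of the patch; it shows a plausible
// pg_catalog query behind a helper like getConstraintsForTable.
package example

import (
	"context"

	"github.com/jackc/pgx/v5"
)

type ConstraintInfo struct {
	ConstraintName string
	ConstraintType string // pg_constraint.contype: "c" CHECK, "f" FOREIGN KEY, "p" PRIMARY KEY, "u" UNIQUE
	ConstraintDef  string // pg_get_constraintdef output, reusable in ALTER TABLE ... ADD CONSTRAINT
}

func listConstraints(ctx context.Context, conn *pgx.Conn, schema, table string) ([]ConstraintInfo, error) {
	rows, err := conn.Query(ctx, `
		SELECT con.conname, con.contype::text, pg_get_constraintdef(con.oid)
		FROM pg_constraint con
		JOIN pg_class c ON c.oid = con.conrelid
		JOIN pg_namespace n ON n.oid = c.relnamespace
		WHERE n.nspname = $1 AND c.relname = $2`, schema, table)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var constraints []ConstraintInfo
	for rows.Next() {
		var ci ConstraintInfo
		if err := rows.Scan(&ci.ConstraintName, &ci.ConstraintType, &ci.ConstraintDef); err != nil {
			return nil, err
		}
		constraints = append(constraints, ci)
	}
	return constraints, rows.Err()
}
```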
diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go
index d090ff01ba..8d49969c06 100644
--- a/flow/connectors/postgres/normalize_stmt_generator.go
+++ b/flow/connectors/postgres/normalize_stmt_generator.go
@@ -138,18 +138,33 @@ func (n *normalizeStmtGenerator) generateMergeStatement(
 	unchangedToastColumns []string,
 ) string {
 	columnCount := len(normalizedTableSchema.Columns)
-	quotedColumnNames := make([]string, columnCount)
-
 	flattenedCastsSQLArray := make([]string, 0, columnCount)
 	parsedDstTable, _ := utils.ParseSchemaTable(dstTableName)
 
 	primaryKeyColumnCasts := make(map[string]string)
 	primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns))
-	for i, column := range normalizedTableSchema.Columns {
+
+	// Filter out PeerDB system columns - they are added separately
+	systemCols := make(map[string]bool)
+	if n.peerdbCols.SyncedAtColName != "" {
+		systemCols[n.peerdbCols.SyncedAtColName] = true
+	}
+	if n.peerdbCols.SoftDeleteColName != "" {
+		systemCols[n.peerdbCols.SoftDeleteColName] = true
+	}
+
+	// Build column lists excluding system columns
+	quotedColumnNamesFiltered := make([]string, 0, columnCount)
+
+	for _, column := range normalizedTableSchema.Columns {
+		if systemCols[column.Name] {
+			continue
+		}
+
 		genericColumnType := column.Type
 		quotedCol := utils.QuoteIdentifier(column.Name)
 		stringCol := utils.QuoteLiteral(column.Name)
-		quotedColumnNames[i] = quotedCol
+		quotedColumnNamesFiltered = append(quotedColumnNamesFiltered, quotedCol)
 
 		pgType := n.columnTypeToPg(normalizedTableSchema, genericColumnType)
 		expr := n.generateExpr(normalizedTableSchema, genericColumnType, stringCol, pgType)
@@ -159,14 +174,16 @@ func (n *normalizeStmtGenerator) generateMergeStatement(
 			primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", quotedCol, quotedCol))
 		}
 	}
+
+	// Use filtered column names (excluding system columns)
+	quotedColumnNames := quotedColumnNamesFiltered
 	flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",")
-	insertValuesSQLArray := make([]string, 0, columnCount+2)
+	insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2)
 	for _, quotedCol := range quotedColumnNames {
 		insertValuesSQLArray = append(insertValuesSQLArray, "src."+quotedCol)
 	}
 
 	updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames, unchangedToastColumns)
-	// append synced_at column
 	if n.peerdbCols.SyncedAtColName != "" {
 		quotedColumnNames = append(quotedColumnNames, utils.QuoteIdentifier(n.peerdbCols.SyncedAtColName))
 		insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP")
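The generateMergeStatement change matters when the table schema used for normalization already lists the PeerDB system columns: without the filter, _PEERDB_SYNCED_AT (and the soft-delete column) would be emitted once by the schema loop and appended once more afterwards, and Postgres rejects an INSERT column list with a duplicate name (`column "..." specified more than once`). A minimal sketch of the filter-then-append idea, with placeholder names rather than the generator's real API:

```go
// Minimal sketch, assuming plain strings stand in for the generator's quoted columns.
package example

func insertColumnList(schemaColumns []string, systemCols map[string]bool, syncedAtCol string) []string {
	cols := make([]string, 0, len(schemaColumns)+1)
	for _, col := range schemaColumns {
		if systemCols[col] {
			// Skip here; the system column is appended exactly once below, so keeping
			// it would duplicate it in the INSERT (...) column list.
			continue
		}
		cols = append(cols, col)
	}
	if syncedAtCol != "" {
		cols = append(cols, syncedAtCol)
	}
	return cols
}
```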
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 03c1c32131..99200cb1e3 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -1134,27 +1134,133 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas(
 	defer shared.RollbackTx(tableSchemaModifyTx, c.logger)
 
 	for _, schemaDelta := range schemaDeltas {
-		if schemaDelta == nil || len(schemaDelta.AddedColumns) == 0 {
+		if schemaDelta == nil {
+			c.logger.Warn("Skipping nil schema delta")
 			continue
 		}
+		// Skip if no schema changes
+		if len(schemaDelta.AddedColumns) == 0 && len(schemaDelta.DroppedColumns) == 0 && len(schemaDelta.TypeChangedColumns) == 0 {
+			continue
+		}
+
+		dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName)
+		if err != nil {
+			return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err)
+		}
+
+		for _, droppedColumn := range schemaDelta.DroppedColumns {
+			// Never drop PeerDB system columns - they are managed by PeerDB
+			// System columns start with _PEERDB_ prefix
+			if strings.HasPrefix(droppedColumn, "_PEERDB_") {
+				c.logger.Warn("Skipping drop of PeerDB system column",
+					slog.String("columnName", droppedColumn),
+					slog.String("srcTableName", schemaDelta.SrcTableName),
+					slog.String("dstTableName", schemaDelta.DstTableName),
+					slog.String("reason", "PeerDB system columns cannot be dropped"))
+				continue
+			}
+
+			dropSQL := fmt.Sprintf(
+				"ALTER TABLE %s.%s DROP COLUMN IF EXISTS %s",
+				utils.QuoteIdentifier(dstSchemaTable.Schema),
+				utils.QuoteIdentifier(dstSchemaTable.Table),
+				utils.QuoteIdentifier(droppedColumn))
+			_, err = c.execWithLoggingTx(ctx, dropSQL, tableSchemaModifyTx)
+			if err != nil {
+				c.logger.Error("Failed to drop column",
+					slog.String("columnName", droppedColumn),
+					slog.String("dstTableName", schemaDelta.DstTableName),
+					slog.Any("error", err))
+				return fmt.Errorf("failed to drop column %s for table %s: %w", droppedColumn,
+					schemaDelta.DstTableName, err)
+			}
+		}
+
+		for _, typeChange := range schemaDelta.TypeChangedColumns {
+			// Never change type of PeerDB system columns
+			if strings.HasPrefix(typeChange.ColumnName, "_PEERDB_") {
+				continue
+			}
+
+			newType := typeChange.NewType
+			if schemaDelta.System == protos.TypeSystem_Q {
+				newType = qValueKindToPostgresType(typeChange.NewType)
+			}
+
+			// Build ALTER COLUMN statement
+			alterSQL := fmt.Sprintf(
+				"ALTER TABLE %s.%s ALTER COLUMN %s TYPE %s",
+				utils.QuoteIdentifier(dstSchemaTable.Schema),
+				utils.QuoteIdentifier(dstSchemaTable.Table),
+				utils.QuoteIdentifier(typeChange.ColumnName),
+				newType)
+
+			_, err = c.execWithLoggingTx(ctx, alterSQL, tableSchemaModifyTx)
+			if err != nil {
+				c.logger.Error("Failed to change column type",
+					slog.String("columnName", typeChange.ColumnName),
+					slog.String("oldType", typeChange.OldType),
+					slog.String("newType", typeChange.NewType),
+					slog.String("dstTableName", schemaDelta.DstTableName),
+					slog.Any("error", err))
+			} else {
+				c.logger.Info("Successfully changed column type",
+					slog.String("columnName", typeChange.ColumnName),
+					slog.String("oldType", typeChange.OldType),
+					slog.String("newType", typeChange.NewType),
+					slog.String("srcTableName", schemaDelta.SrcTableName),
+					slog.String("dstTableName", schemaDelta.DstTableName))
+
+				// Update default value if provided
+				if typeChange.NewDefaultValue != "" {
+					defaultSQL := fmt.Sprintf(
+						"ALTER TABLE %s.%s ALTER COLUMN %s SET DEFAULT %s",
+						utils.QuoteIdentifier(dstSchemaTable.Schema),
+						utils.QuoteIdentifier(dstSchemaTable.Table),
+						utils.QuoteIdentifier(typeChange.ColumnName),
+						typeChange.NewDefaultValue)
+
+					_, err = c.execWithLoggingTx(ctx, defaultSQL, tableSchemaModifyTx)
+					if err != nil {
+						c.logger.Warn("Failed to set default value after type change",
slog.String("columnName", typeChange.ColumnName), + slog.String("defaultValue", typeChange.NewDefaultValue), + slog.Any("error", err)) + } + } + } + } + for _, addedColumn := range schemaDelta.AddedColumns { columnType := addedColumn.Type if schemaDelta.System == protos.TypeSystem_Q { - columnType = qValueKindToPostgresType(columnType) + columnType = qValueKindToPostgresType(addedColumn.Type) } - dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) - if err != nil { - return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + // Build column definition with type and default value + columnDef := fmt.Sprintf("%s %s", utils.QuoteIdentifier(addedColumn.Name), columnType) + if addedColumn.DefaultValue != "" { + // Include DEFAULT value if present + columnDef += fmt.Sprintf(" DEFAULT %s", addedColumn.DefaultValue) + } else if !addedColumn.Nullable { + // If NOT NULL without DEFAULT, PostgreSQL will require it + columnDef += " NOT NULL" } - _, err = c.execWithLoggingTx(ctx, fmt.Sprintf( - "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS %s %s", + addSQL := fmt.Sprintf( + "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS %s", utils.QuoteIdentifier(dstSchemaTable.Schema), utils.QuoteIdentifier(dstSchemaTable.Table), - utils.QuoteIdentifier(addedColumn.Name), columnType), tableSchemaModifyTx) + columnDef) + + _, err = c.execWithLoggingTx(ctx, addSQL, tableSchemaModifyTx) if err != nil { + c.logger.Error("Failed to add column", + slog.String("columnName", addedColumn.Name), + slog.String("columnType", addedColumn.Type), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("error", err)) return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.Name, schemaDelta.DstTableName, err) } @@ -1167,6 +1273,9 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( } if err := tableSchemaModifyTx.Commit(ctx); err != nil { + c.logger.Error("Failed to commit schema modification transaction", + slog.String("flowJobName", flowJobName), + slog.Any("error", err)) return fmt.Errorf("failed to commit transaction for table schema modification: %w", err) } return nil diff --git a/flow/e2e/postgres_test.go b/flow/e2e/postgres_test.go index 702784e018..1490e30c70 100644 --- a/flow/e2e/postgres_test.go +++ b/flow/e2e/postgres_test.go @@ -1286,3 +1286,157 @@ func (s PeerFlowE2ETestSuitePG) TestResync(tableName string) { env.Cancel(s.t.Context()) RequireEnvCanceled(s.t, env) } + +func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { + tc := NewTemporalClient(s.t) + + srcTableName := s.attachSchemaSuffix("test_idx_trig_const") + dstTableName := s.attachSchemaSuffix("test_idx_trig_const_dst") + + _, err := s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, Schema(s))) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + email TEXT, + age INT, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, srcTableName)) + require.NoError(s.t, err) + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE INDEX idx_name ON %s(name)`, srcTableName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE UNIQUE INDEX idx_email ON %s(email) WHERE email IS NOT NULL`, srcTableName)) + require.NoError(s.t, err) + + _, err 
diff --git a/flow/e2e/postgres_test.go b/flow/e2e/postgres_test.go
index 702784e018..1490e30c70 100644
--- a/flow/e2e/postgres_test.go
+++ b/flow/e2e/postgres_test.go
@@ -1286,3 +1286,157 @@ func (s PeerFlowE2ETestSuitePG) TestResync(tableName string) {
 	env.Cancel(s.t.Context())
 	RequireEnvCanceled(s.t, env)
 }
+
+func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() {
+	tc := NewTemporalClient(s.t)
+
+	srcTableName := s.attachSchemaSuffix("test_idx_trig_const")
+	dstTableName := s.attachSchemaSuffix("test_idx_trig_const_dst")
+
+	_, err := s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE OR REPLACE FUNCTION %s.update_timestamp()
+		RETURNS TRIGGER AS $$
+		BEGIN
+			NEW.updated_at = NOW();
+			RETURN NEW;
+		END;
+		$$ LANGUAGE plpgsql`, Schema(s)))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE TABLE IF NOT EXISTS %s (
+			id SERIAL PRIMARY KEY,
+			name TEXT NOT NULL,
+			email TEXT,
+			age INT,
+			created_at TIMESTAMPTZ DEFAULT NOW(),
+			updated_at TIMESTAMPTZ DEFAULT NOW()
+		)`, srcTableName))
+	require.NoError(s.t, err)
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE INDEX idx_name ON %s(name)`, srcTableName))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE UNIQUE INDEX idx_email ON %s(email) WHERE email IS NOT NULL`, srcTableName))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE INDEX idx_created_at ON %s(created_at DESC)`, srcTableName))
+	require.NoError(s.t, err)
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		CREATE TRIGGER update_updated_at
+		BEFORE UPDATE ON %s
+		FOR EACH ROW
+		EXECUTE FUNCTION %s.update_timestamp()`, srcTableName, Schema(s)))
+	require.NoError(s.t, err)
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT check_name_length
+		CHECK (char_length(name) >= 3)`, srcTableName))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT check_age_positive
+		CHECK (age > 0 AND age < 150)`, srcTableName))
+	require.NoError(s.t, err)
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		ALTER TABLE %s ADD CONSTRAINT check_email_format
+		CHECK (email IS NULL OR email LIKE '%%@%%')`, srcTableName))
+	require.NoError(s.t, err)
+
+	connectionGen := FlowConnectionGenerationConfig{
+		FlowJobName:      s.attachSuffix("test_idx_trig_const"),
+		TableNameMapping: map[string]string{srcTableName: dstTableName},
+		Destination:      s.Peer().Name,
+	}
+
+	flowConnConfig := connectionGen.GenerateFlowConnectionConfigs(s)
+	flowConnConfig.MaxBatchSize = 100
+	flowConnConfig.DoInitialSnapshot = true
+
+	env := ExecutePeerflow(s.t, tc, flowConnConfig)
+	SetupCDCFlowStatusQuery(s.t, env, flowConnConfig)
+	EnvWaitFor(s.t, env, 3*time.Minute, "waiting for initial setup", func() bool {
+		var exists bool
+		err := s.Conn().QueryRow(s.t.Context(), `
+			SELECT EXISTS (
+				SELECT 1 FROM pg_catalog.pg_tables
+				WHERE schemaname = $1 AND tablename = $2
+			)`, Schema(s), "test_idx_trig_const_dst").Scan(&exists)
+		return err == nil && exists
+	})
+	s.t.Log("Verifying indexes were synced...")
+	indexQuery := `
+		SELECT indexname
+		FROM pg_indexes
+		WHERE schemaname = $1 AND tablename = $2 AND indexname != $3
+		ORDER BY indexname
+	`
+	rows, err := s.Conn().Query(s.t.Context(), indexQuery, Schema(s), "test_idx_trig_const_dst", "test_idx_trig_const_dst_pkey")
+	require.NoError(s.t, err)
+	defer rows.Close()
+
+	indexNames := make(map[string]bool)
+	for rows.Next() {
+		var indexName string
+		require.NoError(s.t, rows.Scan(&indexName))
+		indexNames[indexName] = true
+	}
+
+	require.True(s.t, indexNames["idx_name"], "idx_name should be synced")
+	require.True(s.t, indexNames["idx_email"], "idx_email should be synced")
+	require.True(s.t, indexNames["idx_created_at"], "idx_created_at should be synced")
+
+	s.t.Log("Verifying trigger was synced...")
+	triggerQuery := `
+		SELECT t.tgname
+		FROM pg_trigger t
+		JOIN pg_class c ON c.oid = t.tgrelid
+		JOIN pg_namespace n ON n.oid = c.relnamespace
+		WHERE n.nspname = $1
+		AND c.relname = $2
+		AND NOT t.tgisinternal
+	`
+	var triggerName string
+	err = s.Conn().QueryRow(s.t.Context(), triggerQuery, Schema(s), "test_idx_trig_const_dst").Scan(&triggerName)
+	require.NoError(s.t, err)
+	require.Equal(s.t, "update_updated_at", triggerName, "update_updated_at trigger should be synced")
+
+	s.t.Log("Verifying constraints were synced...")
+	constraintQuery := `
+		SELECT con.conname
+		FROM pg_constraint con
+		JOIN pg_class c ON c.oid = con.conrelid
+		JOIN pg_namespace n ON n.oid = c.relnamespace
+		WHERE n.nspname = $1
+		AND c.relname = $2
+		AND con.contype::text = 'c'
+		ORDER BY con.conname
+	`
+	rows, err = s.Conn().Query(s.t.Context(), constraintQuery, Schema(s), "test_idx_trig_const_dst")
+	require.NoError(s.t, err)
+	defer rows.Close()
+
+	constraintNames := make(map[string]bool)
+	for rows.Next() {
+		var constraintName string
+		require.NoError(s.t, rows.Scan(&constraintName))
+		constraintNames[constraintName] = true
+	}
+
+	require.True(s.t, constraintNames["check_name_length"], "check_name_length should be synced")
+	require.True(s.t, constraintNames["check_age_positive"], "check_age_positive should be synced")
+	require.True(s.t, constraintNames["check_email_format"], "check_email_format should be synced")
+
+	s.t.Log("All indexes, triggers, and constraints were successfully synced!")
+
+	_, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
+		INSERT INTO %s(name, email, age) VALUES ('test', 'test@example.com', 25)`, srcTableName))
+	EnvNoError(s.t, env, err)
+
+	EnvWaitForEqualTablesWithNames(env, s, "waiting for data sync", srcTableName, dstTableName, "id,name,email,age")
+
+	env.Cancel(s.t.Context())
+	RequireEnvCanceled(s.t, env)
+}
diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go
index f4c79eb67c..be27ede2f4 100644
--- a/flow/workflows/setup_flow.go
+++ b/flow/workflows/setup_flow.go
@@ -213,6 +213,7 @@ func (s *SetupFlowExecution) setupNormalizedTables(
 		FlowName: flowConnectionConfigs.FlowJobName,
 		Env:      flowConnectionConfigs.Env,
 		IsResync: flowConnectionConfigs.Resync,
+		SourcePeerName: flowConnectionConfigs.SourceName,
 	}
 
 	if err := workflow.ExecuteActivity(ctx, flowable.CreateNormalizedTable, setupConfig).Get(ctx, nil); err != nil {
@@ -220,6 +221,11 @@ func (s *SetupFlowExecution) setupNormalizedTables(
 		return fmt.Errorf("failed to create normalized tables: %w", err)
 	}
 
+	// Sync indexes, triggers, and constraints after tables are created
+	if err := workflow.ExecuteActivity(ctx, flowable.SyncIndexesAndTriggers, setupConfig).Get(ctx, nil); err != nil {
+		s.Warn("failed to sync indexes, triggers, and constraints", slog.Any("error", err))
+	}
+
 	s.Info("finished setting up normalized tables for peer flow")
 	return nil
 }
diff --git a/protos/flow.proto b/protos/flow.proto
index ee94edfadb..b412379075 100644
--- a/protos/flow.proto
+++ b/protos/flow.proto
@@ -240,6 +240,7 @@ message FieldDescription {
   string type = 2;
   int32 type_modifier = 3;
   bool nullable = 4;
+  string default_value = 5;
 }
 
 message SetupTableSchemaBatchInput {
@@ -262,6 +263,7 @@ message SetupNormalizedTableBatchInput {
   string flow_name = 6;
   string peer_name = 7;
   bool is_resync = 8;
+  string source_peer_name = 9;
 }
 
 message SetupNormalizedTableOutput {
@@ -425,10 +427,19 @@ message DropFlowInput {
   bool resync = 8;
 }
 
+message ColumnTypeChange {
+  string column_name = 1;
+  string old_type = 2;
+  string new_type = 3;
+  string new_default_value = 4;
+}
+
 message TableSchemaDelta {
   string src_table_name = 1;
   string dst_table_name = 2;
   repeated FieldDescription added_columns = 3;
+  repeated string dropped_columns = 6;
+  repeated ColumnTypeChange type_changed_columns = 7;
   TypeSystem system = 4;
   bool nullable_enabled = 5;
 }
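Taken together, the proto additions let schema-change detection describe column drops, type changes (optionally with a new default), and added columns that carry a default value. A hypothetical delta using the new fields might look like the sketch below; the import path and every field value are assumptions for illustration, not produced by this patch:

```go
// Hypothetical example; the module path and all values are illustrative.
package example

import "github.com/PeerDB-io/peerdb/flow/generated/protos" // adjust to the repo's actual module path

func exampleDelta() *protos.TableSchemaDelta {
	return &protos.TableSchemaDelta{
		SrcTableName: "public.orders",
		DstTableName: "public.orders",
		AddedColumns: []*protos.FieldDescription{
			{Name: "status", Type: "text", Nullable: true, DefaultValue: "'new'"},
		},
		DroppedColumns: []string{"legacy_note"},
		TypeChangedColumns: []*protos.ColumnTypeChange{
			{ColumnName: "amount", OldType: "integer", NewType: "numeric", NewDefaultValue: "0"},
		},
	}
}
```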