diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go
index 95ba8c8819..7210ef880a 100644
--- a/flow/activities/flowable.go
+++ b/flow/activities/flowable.go
@@ -274,6 +274,36 @@ func (a *FlowableActivity) CreateNormalizedTable(
 		return nil, fmt.Errorf("failed to commit normalized tables tx: %w", err)
 	}
 
+	// For Postgres-to-Postgres flows, migrate triggers and indexes after tables are created
+	if dstPgConn, ok := conn.(*connpostgres.PostgresConnector); ok {
+		// Get source peer name from catalog
+		var sourcePeerName string
+		var sourcePeerType protos.DBType
+		err := a.CatalogPool.QueryRow(ctx,
+			`SELECT COALESCE(sp.name, ''), COALESCE(sp.type, 0)
+			FROM flows f
+			LEFT JOIN peers sp ON f.source_peer = sp.id
+			WHERE f.name = $1`,
+			config.FlowName).Scan(&sourcePeerName, &sourcePeerType)
+		if err == nil && sourcePeerName != "" && sourcePeerType == protos.DBType_POSTGRES {
+			// Get source connector
+			if srcPgConn, srcPgClose, err := connectors.GetPostgresConnectorByName(ctx, config.Env, a.CatalogPool, sourcePeerName); err == nil {
+				defer srcPgClose(ctx)
+				logger.Info("Migrating triggers and indexes for Postgres-to-Postgres flow")
+				// Migrate full schema first (in case there are differences)
+				if err := connpostgres.MigrateSchemaFromSource(ctx, srcPgConn, dstPgConn, config.TableMappings); err != nil {
+					logger.Warn("failed to migrate schema during setup", slog.Any("error", err))
+					// Don't fail setup if schema migration fails, tables are already created
+				}
+				// Migrate triggers and indexes
+				if err := dstPgConn.MigrateTriggersAndIndexesForPostgresToPostgres(ctx, srcPgConn, config.TableMappings); err != nil {
+					logger.Warn("failed to migrate triggers and indexes during setup", slog.Any("error", err))
+					// Don't fail setup if migration fails
+				}
+			}
+		}
+	}
+
 	a.Alerter.LogFlowInfo(ctx, config.FlowName, "All destination tables have been setup")
 
 	return &protos.SetupNormalizedTableBatchOutput{
diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go
index 2da7e035a0..99a15a25a4 100644
--- a/flow/activities/flowable_core.go
+++ b/flow/activities/flowable_core.go
@@ -238,6 +238,19 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
 		return nil, fmt.Errorf("failed to sync schema: %w", err)
 	}
 
+	// For Postgres-to-Postgres flows, migrate triggers and indexes after schema changes
+	// Use type switch to check if destination is Postgres
+	switch dstPgConn := any(dstConn).(type) {
+	case *connpostgres.PostgresConnector:
+		if srcPgConn, srcPgClose, err := connectors.GetPostgresConnectorByName(ctx, config.Env, a.CatalogPool, config.SourceName); err == nil {
+			defer srcPgClose(ctx)
+			if err := dstPgConn.MigrateTriggersAndIndexesForPostgresToPostgres(ctx, srcPgConn, options.TableMappings); err != nil {
+				logger.Warn("failed to migrate triggers and indexes", slog.Any("error", err))
+				// Don't fail the flow, just log a warning
+			}
+		}
+	}
+
 	return nil, a.applySchemaDeltas(ctx, config, options, recordBatchSync.SchemaDeltas)
 }
 
@@ -338,6 +351,26 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon
 		return nil, err
 	}
 
+	// For Postgres-to-Postgres flows, migrate triggers and indexes after schema changes
+	// Schema deltas are already replayed in syncRecordsCore, but we migrate triggers/indexes here
+	// where we have access to both source and destination connectors
+	if len(res.TableSchemaDeltas) > 0 {
+		// Get destination connector again to check if it's Postgres
+		dstConnForMigration, dstCloseForMigration, err :=
+			connectors.GetByNameAs[connectors.CDCSyncConnectorCore](ctx, config.Env, a.CatalogPool, config.DestinationName)
+		if err == nil {
+			defer dstCloseForMigration(ctx)
+			if dstPgConn, ok := dstConnForMigration.(*connpostgres.PostgresConnector); ok {
+				if srcPgConn, srcPgClose, err := connectors.GetPostgresConnectorByName(ctx, config.Env, a.CatalogPool, config.SourceName); err == nil {
+					defer srcPgClose(ctx)
+					if err := dstPgConn.MigrateTriggersAndIndexesForPostgresToPostgres(ctx, srcPgConn, options.TableMappings); err != nil {
+						logger.Warn("failed to migrate triggers and indexes after sync", slog.Any("error", err))
+						// Don't fail the flow, just log a warning
+					}
+				}
+			}
+		}
+	}
+
 	if recordBatchSync.NeedsNormalize() {
 		syncState.Store(shared.Ptr("normalizing"))
 		normRequests.Update(res.CurrentSyncBatchID)
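A note on the destination check in the syncCore hunks above: the single-case type switch on `any(dstConn)` exists only to discover whether the generic sync connector is concretely a Postgres connector. A comma-ok type assertion is equivalent; the following is an illustrative sketch, not part of the patch.

    // Illustrative only: equivalent to the one-case type switch used in syncCore.
    if dstPgConn, ok := any(dstConn).(*connpostgres.PostgresConnector); ok {
        // Postgres destination: triggers and indexes can be mirrored from the source.
        _ = dstPgConn
    }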
diff --git a/flow/connectors/postgres/migration.go b/flow/connectors/postgres/migration.go
new file mode 100644
index 0000000000..dded15e5d0
--- /dev/null
+++ b/flow/connectors/postgres/migration.go
@@ -0,0 +1,684 @@
+package connpostgres
+
+import (
+	"context"
+	"fmt"
+	"log/slog"
+	"strings"
+
+	"github.com/jackc/pgx/v5"
+	"github.com/jackc/pgx/v5/pgtype"
+	"go.temporal.io/sdk/log"
+
+	"github.com/PeerDB-io/peerdb/flow/connectors/utils"
+	"github.com/PeerDB-io/peerdb/flow/generated/protos"
+	"github.com/PeerDB-io/peerdb/flow/shared"
+)
+
+// TableSchema represents the full schema of a table
+type TableSchema struct {
+	SchemaName  string
+	TableName   string
+	Columns     []ColumnDefinition
+	PrimaryKey  []string
+	Constraints []ConstraintDefinition
+}
+
+// ColumnDefinition represents a column in a table
+type ColumnDefinition struct {
+	Name         string
+	DataType     string
+	IsNullable   bool
+	DefaultValue *string
+	IsPrimaryKey bool
+}
+
+// ConstraintDefinition represents a constraint on a table
+type ConstraintDefinition struct {
+	Name       string
+	Type       string // PRIMARY KEY, FOREIGN KEY, UNIQUE, CHECK
+	Definition string
+}
+
+// TriggerDefinition represents a trigger
+type TriggerDefinition struct {
+	Name       string
+	Event      string // INSERT, UPDATE, DELETE, TRUNCATE
+	Timing     string // BEFORE, AFTER, INSTEAD OF
+	Function   string
+	Condition  *string
+	Definition string // Full CREATE TRIGGER statement
+}
+
+// IndexDefinition represents an index
+type IndexDefinition struct {
+	Name        string
+	TableSchema string
+	TableName   string
+	Columns     []string
+	IsUnique    bool
+	Method      string // btree, hash, gist, etc.
+	Definition  string // Full CREATE INDEX statement
+}
+
+// MigrateSchemaFromSource migrates the full schema from source to target Postgres database.
+// sourceConnector is the connector for the source database;
+// targetConnector is the connector for the target database.
+func MigrateSchemaFromSource(
+	ctx context.Context,
+	sourceConnector *PostgresConnector,
+	targetConnector *PostgresConnector,
+	tableMappings []*protos.TableMapping,
+) error {
+	sourceConn := sourceConnector.conn
+	targetConn := targetConnector.conn
+	logger := sourceConnector.logger
+	logger.Info("[migration] Starting schema migration")
+
+	for _, mapping := range tableMappings {
+		sourceTable := mapping.SourceTableIdentifier
+		targetTable := mapping.DestinationTableIdentifier
+
+		if targetTable == "" {
+			targetTable = sourceTable
+		}
+
+		logger.Info("[migration] Migrating schema",
+			slog.String("sourceTable", sourceTable),
+			slog.String("targetTable", targetTable))
+
+		// Get source schema
+		sourceSchema, err := getTableSchema(ctx, sourceConn, sourceTable)
+		if err != nil {
+			return fmt.Errorf("failed to get source schema for %s: %w", sourceTable, err)
+		}
+
+		// Check if target table exists
+		targetSchemaTable, err := utils.ParseSchemaTable(targetTable)
+		if err != nil {
+			return fmt.Errorf("error parsing target table %s: %w", targetTable, err)
+		}
+
+		exists, err := tableExists(ctx, targetConn, targetSchemaTable)
+		if err != nil {
+			return fmt.Errorf("failed to check if target table exists: %w", err)
+		}
+
+		if !exists {
+			// Create table with full schema
+			if err := createTableFromSchema(ctx, targetConn, sourceSchema, targetSchemaTable, logger); err != nil {
+				return fmt.Errorf("failed to create table %s: %w", targetTable, err)
+			}
+		} else {
+			// Alter table to match source schema
+			if err := alterTableToMatchSchema(ctx, targetConn, sourceSchema, targetSchemaTable, logger); err != nil {
+				return fmt.Errorf("failed to alter table %s: %w", targetTable, err)
+			}
+		}
+	}
+
+	logger.Info("[migration] Schema migration completed")
+	return nil
+}
+
+// MigrateTriggersFromSource migrates triggers from source to target Postgres database
+func MigrateTriggersFromSource(
+	ctx context.Context,
+	sourceConnector *PostgresConnector,
+	targetConnector *PostgresConnector,
+	tableMappings []*protos.TableMapping,
+) error {
+	sourceConn := sourceConnector.conn
+	targetConn := targetConnector.conn
+	logger := sourceConnector.logger
+	logger.Info("[migration] Starting trigger migration")
+
+	for _, mapping := range tableMappings {
+		sourceTable := mapping.SourceTableIdentifier
+		targetTable := mapping.DestinationTableIdentifier
+
+		if targetTable == "" {
+			targetTable = sourceTable
+		}
+
+		logger.Info("[migration] Migrating triggers",
+			slog.String("sourceTable", sourceTable),
+			slog.String("targetTable", targetTable))
+
+		// Get source triggers
+		sourceTriggers, err := getTableTriggers(ctx, sourceConn, sourceTable)
+		if err != nil {
+			return fmt.Errorf("failed to get source triggers for %s: %w", sourceTable, err)
+		}
+
+		// Get target triggers
+		targetTriggers, err := getTableTriggers(ctx, targetConn, targetTable)
+		if err != nil {
+			return fmt.Errorf("failed to get target triggers for %s: %w", targetTable, err)
+		}
+
+		// Create a map of existing target triggers
+		targetTriggerMap := make(map[string]*TriggerDefinition)
+		for _, trigger := range targetTriggers {
+			targetTriggerMap[trigger.Name] = trigger
+		}
+
+		// Quote the target table as schema and table identifiers for DROP TRIGGER
+		targetSchemaTable, err := utils.ParseSchemaTable(targetTable)
+		if err != nil {
+			return fmt.Errorf("error parsing target table %s: %w", targetTable, err)
+		}
+
+		// Migrate each source trigger
+		tx, err := targetConn.Begin(ctx)
+		if err != nil {
+			return fmt.Errorf("failed to begin transaction: %w", err)
+		}
+		defer shared.RollbackTx(tx, logger)
+
+		for _, sourceTrigger := range sourceTriggers {
+			if _, exists := targetTriggerMap[sourceTrigger.Name]; exists {
+				// Drop existing trigger if it exists
+				_, err = tx.Exec(ctx, fmt.Sprintf(
+					"DROP TRIGGER IF EXISTS %s ON %s.%s",
+					utils.QuoteIdentifier(sourceTrigger.Name),
+					utils.QuoteIdentifier(targetSchemaTable.Schema),
+					utils.QuoteIdentifier(targetSchemaTable.Table)))
+				if err != nil {
+					return fmt.Errorf("failed to drop existing trigger %s: %w", sourceTrigger.Name, err)
+				}
+			}
+
+			// Create trigger
+			triggerSQL := buildCreateTriggerSQL(sourceTrigger, targetTable)
+			_, err = tx.Exec(ctx, triggerSQL)
+			if err != nil {
+				return fmt.Errorf("failed to create trigger %s: %w", sourceTrigger.Name, err)
+			}
+
+			logger.Info("[migration] Created trigger",
+				slog.String("trigger", sourceTrigger.Name),
+				slog.String("table", targetTable))
+		}
+
+		if err := tx.Commit(ctx); err != nil {
+			return fmt.Errorf("failed to commit trigger migration: %w", err)
+		}
+	}
+
+	logger.Info("[migration] Trigger migration completed")
+	return nil
+}
+
+// MigrateIndexesFromSource migrates indexes from source to target Postgres database
+func MigrateIndexesFromSource(
+	ctx context.Context,
+	sourceConnector *PostgresConnector,
+	targetConnector *PostgresConnector,
+	tableMappings []*protos.TableMapping,
+) error {
+	sourceConn := sourceConnector.conn
+	targetConn := targetConnector.conn
+	logger := sourceConnector.logger
+	logger.Info("[migration] Starting index migration")
+
+	for _, mapping := range tableMappings {
+		sourceTable := mapping.SourceTableIdentifier
+		targetTable := mapping.DestinationTableIdentifier
+
+		if targetTable == "" {
+			targetTable = sourceTable
+		}
+
+		logger.Info("[migration] Migrating indexes",
+			slog.String("sourceTable", sourceTable),
+			slog.String("targetTable", targetTable))
+
+		// Get source indexes
+		sourceIndexes, err := getTableIndexes(ctx, sourceConn, sourceTable)
+		if err != nil {
+			return fmt.Errorf("failed to get source indexes for %s: %w", sourceTable, err)
+		}
+
+		// Get target indexes
+		targetIndexes, err := getTableIndexes(ctx, targetConn, targetTable)
+		if err != nil {
+			return fmt.Errorf("failed to get target indexes for %s: %w", targetTable, err)
+		}
+
+		// Create a map of existing target indexes
+		targetIndexMap := make(map[string]*IndexDefinition)
+		for _, index := range targetIndexes {
+			targetIndexMap[index.Name] = index
+		}
+
+		// Existing target indexes live in the target schema, so drop them there
+		targetSchemaTable, err := utils.ParseSchemaTable(targetTable)
+		if err != nil {
+			return fmt.Errorf("error parsing target table %s: %w", targetTable, err)
+		}
+
+		// Migrate each source index
+		tx, err := targetConn.Begin(ctx)
+		if err != nil {
+			return fmt.Errorf("failed to begin transaction: %w", err)
+		}
+		defer shared.RollbackTx(tx, logger)
+
+		for _, sourceIndex := range sourceIndexes {
+			// Skip primary key indexes as they're handled by schema migration
+			if strings.HasSuffix(sourceIndex.Name, "_pkey") {
+				continue
+			}
+
+			if _, exists := targetIndexMap[sourceIndex.Name]; exists {
+				// Drop existing index if it exists
+				_, err = tx.Exec(ctx, fmt.Sprintf(
+					"DROP INDEX IF EXISTS %s.%s",
+					utils.QuoteIdentifier(targetSchemaTable.Schema),
+					utils.QuoteIdentifier(sourceIndex.Name)))
+				if err != nil {
+					return fmt.Errorf("failed to drop existing index %s: %w", sourceIndex.Name, err)
+				}
+			}
+
+			// Create index (use the definition from source, but replace table name)
+			indexSQL := buildCreateIndexSQL(sourceIndex, targetTable)
+			_, err = tx.Exec(ctx, indexSQL)
+			if err != nil {
+				return fmt.Errorf("failed to create index %s: %w", sourceIndex.Name, err)
+			}
+
+			logger.Info("[migration] Created index",
+				slog.String("index", sourceIndex.Name),
+				slog.String("table", targetTable))
+		}
+
+		if err := tx.Commit(ctx); err != nil {
+			return fmt.Errorf("failed to commit index migration: %w", err)
+		}
+	}
+
+	logger.Info("[migration] Index migration completed")
+	return nil
+}
+
+// Helper functions
+
+func getTableSchema(ctx context.Context, conn *pgx.Conn, tableName string) (*TableSchema, error) {
+	schemaTable, err := utils.ParseSchemaTable(tableName)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing table name: %w", err)
+	}
+
+	// Get columns
+	columnsQuery := `
+		SELECT
+			a.attname AS column_name,
+			pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
+			a.attnotnull AS not_null,
+			pg_get_expr(adbin, adrelid) AS default_value
+		FROM pg_attribute a
+		LEFT JOIN pg_attrdef ad ON a.attrelid = ad.adrelid AND a.attnum = ad.adnum
+		JOIN pg_class c ON a.attrelid = c.oid
+		JOIN pg_namespace n ON c.relnamespace = n.oid
+		WHERE n.nspname = $1 AND c.relname = $2
+			AND a.attnum > 0 AND NOT a.attisdropped
+		ORDER BY a.attnum`
+
+	rows, err := conn.Query(ctx, columnsQuery, schemaTable.Schema, schemaTable.Table)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query columns: %w", err)
+	}
+	defer rows.Close()
+
+	var columns []ColumnDefinition
+	for rows.Next() {
+		var colName, dataType pgtype.Text
+		var notNull pgtype.Bool
+		var defaultValue pgtype.Text
+
+		if err := rows.Scan(&colName, &dataType, &notNull, &defaultValue); err != nil {
+			return nil, fmt.Errorf("failed to scan column: %w", err)
+		}
+
+		var defValue *string
+		if defaultValue.Valid {
+			defValue = &defaultValue.String
+		}
+
+		columns = append(columns, ColumnDefinition{
+			Name:         colName.String,
+			DataType:     dataType.String,
+			IsNullable:   !notNull.Bool,
+			DefaultValue: defValue,
+		})
+	}
+
+	// Get primary key
+	pkQuery := `
+		SELECT a.attname
+		FROM pg_index i
+		JOIN pg_attribute a ON a.attrelid = i.indrelid AND a.attnum = ANY(i.indkey)
+		JOIN pg_class c ON i.indrelid = c.oid
+		JOIN pg_namespace n ON c.relnamespace = n.oid
+		WHERE n.nspname = $1 AND c.relname = $2 AND i.indisprimary
+		ORDER BY array_position(i.indkey, a.attnum)`
+
+	pkRows, err := conn.Query(ctx, pkQuery, schemaTable.Schema, schemaTable.Table)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query primary key: %w", err)
+	}
+	defer pkRows.Close()
+
+	var primaryKey []string
+	for pkRows.Next() {
+		var pkCol pgtype.Text
+		if err := pkRows.Scan(&pkCol); err != nil {
+			return nil, fmt.Errorf("failed to scan primary key: %w", err)
+		}
+		primaryKey = append(primaryKey, pkCol.String)
+	}
+
+	// Mark primary key columns
+	for i := range columns {
+		for _, pkCol := range primaryKey {
+			if columns[i].Name == pkCol {
+				columns[i].IsPrimaryKey = true
+				break
+			}
+		}
+	}
+
+	return &TableSchema{
+		SchemaName:  schemaTable.Schema,
+		TableName:   schemaTable.Table,
+		Columns:     columns,
+		PrimaryKey:  primaryKey,
+		Constraints: []ConstraintDefinition{}, // Can be extended later
+	}, nil
+}
+
+func tableExists(ctx context.Context, conn *pgx.Conn, schemaTable *utils.SchemaTable) (bool, error) {
+	var exists bool
+	err := conn.QueryRow(ctx, `
+		SELECT EXISTS (
+			SELECT 1 FROM pg_class c
+			JOIN pg_namespace n ON c.relnamespace = n.oid
+			WHERE n.nspname = $1 AND c.relname = $2
+		)`, schemaTable.Schema, schemaTable.Table).Scan(&exists)
+	return exists, err
+}
+
+func createTableFromSchema(
+	ctx context.Context,
+	conn *pgx.Conn,
+	schema *TableSchema,
+	targetSchemaTable *utils.SchemaTable,
+	logger log.Logger,
+) error {
+	var columnDefs []string
+	for _, col := range schema.Columns {
+		colDef := fmt.Sprintf("%s %s", utils.QuoteIdentifier(col.Name), col.DataType)
+		if !col.IsNullable {
+			colDef += " NOT NULL"
+		}
+		if col.DefaultValue != nil {
+			colDef += " DEFAULT " + *col.DefaultValue
+		}
+		columnDefs = append(columnDefs, colDef)
+	}
+
+	var pkDef string
+	if len(schema.PrimaryKey) > 0 {
+		pkCols := make([]string, len(schema.PrimaryKey))
+		for i, pk := range schema.PrimaryKey {
+			pkCols[i] = utils.QuoteIdentifier(pk)
+		}
+		pkDef = fmt.Sprintf(", PRIMARY KEY (%s)", strings.Join(pkCols, ", "))
+	}
+
+	createTableSQL := fmt.Sprintf(
+		"CREATE TABLE IF NOT EXISTS %s.%s (%s%s)",
+		utils.QuoteIdentifier(targetSchemaTable.Schema),
+		utils.QuoteIdentifier(targetSchemaTable.Table),
+		strings.Join(columnDefs, ", "),
+		pkDef,
+	)
+
+	_, err := conn.Exec(ctx, createTableSQL)
+	if err != nil {
+		return fmt.Errorf("failed to create table: %w", err)
+	}
+
+	logger.Info("[migration] Created table",
+		slog.String("table", targetSchemaTable.String()))
+	return nil
+}
+
+func alterTableToMatchSchema(
+	ctx context.Context,
+	conn *pgx.Conn,
+	sourceSchema *TableSchema,
+	targetSchemaTable *utils.SchemaTable,
+	logger log.Logger,
+) error {
+	// Get target schema
+	targetSchema, err := getTableSchema(ctx, conn, targetSchemaTable.String())
+	if err != nil {
+		return fmt.Errorf("failed to get target schema: %w", err)
+	}
+
+	// Create maps for comparison
+	sourceColMap := make(map[string]ColumnDefinition)
+	for _, col := range sourceSchema.Columns {
+		sourceColMap[col.Name] = col
+	}
+
+	targetColMap := make(map[string]ColumnDefinition)
+	for _, col := range targetSchema.Columns {
+		targetColMap[col.Name] = col
+	}
+
+	tx, err := conn.Begin(ctx)
+	if err != nil {
+		return fmt.Errorf("failed to begin transaction: %w", err)
+	}
+	defer shared.RollbackTx(tx, logger)
+
+	// Add missing columns
+	for _, sourceCol := range sourceSchema.Columns {
+		if _, exists := targetColMap[sourceCol.Name]; !exists {
+			colDef := fmt.Sprintf("%s %s", utils.QuoteIdentifier(sourceCol.Name), sourceCol.DataType)
+			if !sourceCol.IsNullable {
+				colDef += " NOT NULL"
+			}
+			if sourceCol.DefaultValue != nil {
+				colDef += " DEFAULT " + *sourceCol.DefaultValue
+			}
+
+			alterSQL := fmt.Sprintf(
+				"ALTER TABLE %s.%s ADD COLUMN %s",
+				utils.QuoteIdentifier(targetSchemaTable.Schema),
+				utils.QuoteIdentifier(targetSchemaTable.Table),
+				colDef,
+			)
+
+			_, err = tx.Exec(ctx, alterSQL)
+			if err != nil {
+				return fmt.Errorf("failed to add column %s: %w", sourceCol.Name, err)
+			}
+
+			logger.Info("[migration] Added column",
+				slog.String("column", sourceCol.Name),
+				slog.String("table", targetSchemaTable.String()))
+		}
+	}
+
+	// Note: We don't drop columns or change column types to avoid data loss.
+	// This can be extended based on requirements.
+
+	if err := tx.Commit(ctx); err != nil {
+		return fmt.Errorf("failed to commit schema changes: %w", err)
+	}
+
+	return nil
+}
+
+func getTableTriggers(ctx context.Context, conn *pgx.Conn, tableName string) ([]*TriggerDefinition, error) {
+	schemaTable, err := utils.ParseSchemaTable(tableName)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing table name: %w", err)
+	}
+
+	query := `
+		SELECT
+			t.tgname AS trigger_name,
+			pg_get_triggerdef(t.oid) AS trigger_definition,
+			CASE
+				WHEN t.tgtype & 2 = 2 THEN 'BEFORE'
+				WHEN t.tgtype & 64 = 64 THEN 'INSTEAD OF'
+				ELSE 'AFTER'
+			END AS timing,
+			CASE
+				WHEN t.tgtype & 4 = 4 THEN 'INSERT'
+				WHEN t.tgtype & 8 = 8 THEN 'DELETE'
+				WHEN t.tgtype & 16 = 16 THEN 'UPDATE'
+				WHEN t.tgtype & 32 = 32 THEN 'TRUNCATE'
+				ELSE 'UNKNOWN'
+			END AS event,
+			pg_get_function_identity_arguments(t.tgfoid) AS function_args
+		FROM pg_trigger t
+		JOIN pg_class c ON t.tgrelid = c.oid
+		JOIN pg_namespace n ON c.relnamespace = n.oid
+		JOIN pg_proc p ON t.tgfoid = p.oid
+		JOIN pg_namespace pn ON p.pronamespace = pn.oid
+		WHERE n.nspname = $1 AND c.relname = $2
+			AND NOT t.tgisinternal
+		ORDER BY t.tgname`
+
+	rows, err := conn.Query(ctx, query, schemaTable.Schema, schemaTable.Table)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query triggers: %w", err)
+	}
+	defer rows.Close()
+
+	var triggers []*TriggerDefinition
+	for rows.Next() {
+		var triggerName, triggerDef, timing, event, functionArgs pgtype.Text
+		if err := rows.Scan(&triggerName, &triggerDef, &timing, &event, &functionArgs); err != nil {
+			return nil, fmt.Errorf("failed to scan trigger: %w", err)
+		}
+
+		// Keep the trigger function's identity arguments for reference; the full
+		// CREATE TRIGGER statement from pg_get_triggerdef is what gets replayed
+		functionName := ""
+		if functionArgs.Valid {
+			functionName = functionArgs.String
+		}
+
+		triggers = append(triggers, &TriggerDefinition{
+			Name:       triggerName.String,
+			Event:      event.String,
+			Timing:     timing.String,
+			Function:   functionName,
+			Definition: triggerDef.String,
+		})
+	}
+
+	return triggers, nil
+}
+
+func buildCreateTriggerSQL(trigger *TriggerDefinition, targetTable string) string {
+	// Use the full trigger definition from pg_get_triggerdef.
+	// It already contains the complete CREATE TRIGGER statement;
+	// we just need to replace the table reference.
+	schemaTable, err := utils.ParseSchemaTable(targetTable)
+	if err != nil {
+		// Fallback to using the definition as-is
+		return trigger.Definition
+	}
+
+	// Extract the table name from the definition and replace it.
+	// pg_get_triggerdef returns: CREATE TRIGGER name ... ON table_name ...
+	// We need to find and replace the table reference
+	def := trigger.Definition
+
+	// Try to find "ON schema.table" pattern and replace.
+	// This is a simplified approach - a full SQL parser would be better
+	parts := strings.Split(def, " ON ")
+	if len(parts) >= 2 {
+		// Find the table part (might be schema.table or just table)
+		tablePart := strings.Fields(parts[1])[0]
+		newTableRef := fmt.Sprintf("%s.%s", utils.QuoteIdentifier(schemaTable.Schema), utils.QuoteIdentifier(schemaTable.Table))
+		def = strings.Replace(def, tablePart, newTableRef, 1)
+	}
+
+	return def
+}
+
+func getTableIndexes(ctx context.Context, conn *pgx.Conn, tableName string) ([]*IndexDefinition, error) {
+	schemaTable, err := utils.ParseSchemaTable(tableName)
+	if err != nil {
+		return nil, fmt.Errorf("error parsing table name: %w", err)
+	}
+
+	query := `
+		SELECT
+			i.indexname AS index_name,
+			i.indexdef AS index_definition,
+			i.indexdef LIKE '%UNIQUE%' AS is_unique
+		FROM pg_indexes i
+		WHERE i.schemaname = $1 AND i.tablename = $2
+			AND i.indexname NOT LIKE '%_pkey'
+		ORDER BY i.indexname`
+
+	rows, err := conn.Query(ctx, query, schemaTable.Schema, schemaTable.Table)
+	if err != nil {
+		return nil, fmt.Errorf("failed to query indexes: %w", err)
+	}
+	defer rows.Close()
+
+	var indexes []*IndexDefinition
+	for rows.Next() {
+		var indexName, indexDef pgtype.Text
+		var isUnique pgtype.Bool
+		if err := rows.Scan(&indexName, &indexDef, &isUnique); err != nil {
+			return nil, fmt.Errorf("failed to scan index: %w", err)
+		}
+
+		// Extract columns from index definition (simplified)
+		columns := extractIndexColumns(indexDef.String)
+
+		indexes = append(indexes, &IndexDefinition{
+			Name:        indexName.String,
+			TableSchema: schemaTable.Schema,
+			TableName:   schemaTable.Table,
+			Columns:     columns,
+			IsUnique:    isUnique.Bool,
+			Definition:  indexDef.String,
+		})
+	}
+
+	return indexes, nil
+}
+
+func extractIndexColumns(indexDef string) []string {
+	// Simple extraction - find columns in parentheses.
+	// This is simplified - in production, use proper SQL parsing
+	start := strings.Index(indexDef, "(")
+	end := strings.Index(indexDef, ")")
+	if start == -1 || end == -1 {
+		return []string{}
+	}
+
+	colsStr := indexDef[start+1 : end]
+	cols := strings.Split(colsStr, ",")
+	var result []string
+	for _, col := range cols {
+		col = strings.TrimSpace(col)
+		col = strings.Trim(col, "\"")
+		if col != "" {
+			result = append(result, col)
+		}
+	}
+	return result
+}
+
+func buildCreateIndexSQL(index *IndexDefinition, targetTable string) string {
+	// Use the definition from source, but replace the table name
+	schemaTable, _ := utils.ParseSchemaTable(targetTable)
+	oldTableRef := fmt.Sprintf("%s.%s", index.TableSchema, index.TableName)
+	newTableRef := fmt.Sprintf("%s.%s", utils.QuoteIdentifier(schemaTable.Schema), utils.QuoteIdentifier(schemaTable.Table))
+	return strings.Replace(index.Definition, oldTableRef, newTableRef, 1)
+}
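Background on the tgtype decoding in getTableTriggers above: pg_trigger.tgtype is a bit mask, and a single trigger can fire on several events at once (for example BEFORE INSERT OR UPDATE). The SQL CASE expression reports only the first matching event, which is harmless here because the statement actually replayed comes from pg_get_triggerdef. A fuller decoding, shown purely as an illustrative Go sketch and not part of this change, would look like this:

    // decodeTriggerEvents is a hypothetical helper: it expands the
    // INSERT/DELETE/UPDATE/TRUNCATE bits (4/8/16/32) of pg_trigger.tgtype.
    func decodeTriggerEvents(tgtype int16) []string {
        var events []string
        if tgtype&4 != 0 {
            events = append(events, "INSERT")
        }
        if tgtype&8 != 0 {
            events = append(events, "DELETE")
        }
        if tgtype&16 != 0 {
            events = append(events, "UPDATE")
        }
        if tgtype&32 != 0 {
            events = append(events, "TRUNCATE")
        }
        return events
    }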
diff --git a/flow/connectors/postgres/migration_test.go b/flow/connectors/postgres/migration_test.go
new file mode 100644
index 0000000000..ffad6c91b1
--- /dev/null
+++ b/flow/connectors/postgres/migration_test.go
@@ -0,0 +1,407 @@
+package connpostgres
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"testing"
+
+	"github.com/jackc/pgx/v5"
+	"github.com/stretchr/testify/require"
+
+	"github.com/PeerDB-io/peerdb/flow/e2eshared"
+	"github.com/PeerDB-io/peerdb/flow/generated/protos"
+	"github.com/PeerDB-io/peerdb/flow/internal"
+	"github.com/PeerDB-io/peerdb/flow/shared"
+)
+
+type MigrationTestSuite struct {
+	t             *testing.T
+	sourceConn    *PostgresConnector
+	targetConn    *PostgresConnector
+	sourceSchema  string
+	targetSchema  string
+	sourceConnRaw *pgx.Conn
+	targetConnRaw *pgx.Conn
+}
+
+func SetupMigrationSuite(t *testing.T) MigrationTestSuite {
+	t.Helper()
+
+	// Create source connector
+	sourceConnector, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context()))
+	require.NoError(t, err)
+
+	// Create target connector (can be same DB, different schema for testing)
+	targetConnector, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context()))
+	require.NoError(t, err)
+
+	// Get raw connections for direct queries
+	sourceConnRaw := sourceConnector.conn
+	targetConnRaw := targetConnector.conn
+
+	// Create test schemas
+	sourceSchema := "migrate_src_" + strings.ToLower(shared.RandomString(8))
+	targetSchema := "migrate_dst_" + strings.ToLower(shared.RandomString(8))
+
+	// Setup source schema
+	setupTx, err := sourceConnRaw.Begin(t.Context())
+	require.NoError(t, err)
+	_, err = setupTx.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", sourceSchema))
+	require.NoError(t, err)
+	_, err = setupTx.Exec(t.Context(), "CREATE SCHEMA "+sourceSchema)
+	require.NoError(t, err)
+	require.NoError(t, setupTx.Commit(t.Context()))
+
+	// Setup target schema
+	setupTx2, err := targetConnRaw.Begin(t.Context())
+	require.NoError(t, err)
+	_, err = setupTx2.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", targetSchema))
+	require.NoError(t, err)
+	_, err = setupTx2.Exec(t.Context(), "CREATE SCHEMA "+targetSchema)
+	require.NoError(t, err)
+	require.NoError(t, setupTx2.Commit(t.Context()))
+
+	return MigrationTestSuite{
+		t:             t,
+		sourceConn:    sourceConnector,
+		targetConn:    targetConnector,
+		sourceSchema:  sourceSchema,
+		targetSchema:  targetSchema,
+		sourceConnRaw: sourceConnRaw,
+		targetConnRaw: targetConnRaw,
+	}
+}
+
+func (s MigrationTestSuite) Teardown(ctx context.Context) {
+	// Cleanup schemas
+	teardownTx, err := s.sourceConnRaw.Begin(ctx)
+	require.NoError(s.t, err)
+	_, err = teardownTx.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.sourceSchema))
+	require.NoError(s.t, err)
+	_, err = teardownTx.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.targetSchema))
+	require.NoError(s.t, err)
+	require.NoError(s.t, teardownTx.Commit(ctx))
+
+	require.NoError(s.t, s.sourceConn.Close())
+	require.NoError(s.t, s.targetConn.Close())
+}
+
+func TestSchemaMigration(t *testing.T) {
+	e2eshared.RunSuite(t, func(t *testing.T) MigrationTestSuite {
+		suite := SetupMigrationSuite(t)
+		defer suite.Teardown(t.Context())
+
+		// Create source table with schema
+		tableName := "test_table"
+		sourceTable := fmt.Sprintf("%s.%s", suite.sourceSchema, tableName)
+		targetTable := fmt.Sprintf("%s.%s", suite.targetSchema, tableName)
+
+		// Create source table
+		_, err := suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TABLE %s (
+				id INT PRIMARY KEY,
+				name VARCHAR(100) NOT NULL,
+				email VARCHAR(255),
+				created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+				age INT
+			)`, sourceTable))
+		require.NoError(t, err)
+
+		// Insert some test data
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			INSERT INTO %s (id, name, email, age) VALUES
+			(1, 'Alice', 'alice@example.com', 30),
+			(2, 'Bob', 'bob@example.com', 25)`, sourceTable))
+		require.NoError(t, err)
+
+		// Migrate schema
+		tableMappings := []*protos.TableMapping{
+			{
+				SourceTableIdentifier:      sourceTable,
+				DestinationTableIdentifier: targetTable,
+			},
+		}
+
+		err = MigrateSchemaFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Verify target table exists and has correct schema
+		var columnCount int
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT COUNT(*) FROM information_schema.columns
+			WHERE table_schema = '%s' AND table_name = '%s'`,
+			suite.targetSchema, tableName)).Scan(&columnCount)
+		require.NoError(t, err)
+		require.Equal(t, 5, columnCount, "Target table should have 5 columns")
+
+		// Verify primary key
+		var pkCount int
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT COUNT(*) FROM information_schema.table_constraints
+			WHERE table_schema = '%s' AND table_name = '%s' AND constraint_type = 'PRIMARY KEY'`,
+			suite.targetSchema, tableName)).Scan(&pkCount)
+		require.NoError(t, err)
+		require.Equal(t, 1, pkCount, "Target table should have a primary key")
+
+		return suite
+	})
+}
+
+func TestTriggerMigration(t *testing.T) {
+	e2eshared.RunSuite(t, func(t *testing.T) MigrationTestSuite {
+		suite := SetupMigrationSuite(t)
+		defer suite.Teardown(t.Context())
+
+		tableName := "test_table"
+		sourceTable := fmt.Sprintf("%s.%s", suite.sourceSchema, tableName)
+		targetTable := fmt.Sprintf("%s.%s", suite.targetSchema, tableName)
+
+		// Create source table
+		_, err := suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TABLE %s (
+				id INT PRIMARY KEY,
+				name VARCHAR(100),
+				updated_at TIMESTAMP
+			)`, sourceTable))
+		require.NoError(t, err)
+
+		// Create trigger function
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE OR REPLACE FUNCTION %s.update_updated_at()
+			RETURNS TRIGGER AS $$
+			BEGIN
+				NEW.updated_at = CURRENT_TIMESTAMP;
+				RETURN NEW;
+			END;
+			$$ LANGUAGE plpgsql`, suite.sourceSchema))
+		require.NoError(t, err)
+
+		// Create trigger on source table
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TRIGGER update_timestamp
+			BEFORE UPDATE ON %s
+			FOR EACH ROW
+			EXECUTE FUNCTION %s.update_updated_at()`, sourceTable, suite.sourceSchema))
+		require.NoError(t, err)
+
+		// First migrate schema
+		tableMappings := []*protos.TableMapping{
+			{
+				SourceTableIdentifier:      sourceTable,
+				DestinationTableIdentifier: targetTable,
+			},
+		}
+
+		err = MigrateSchemaFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Create trigger function in target schema
+		_, err = suite.targetConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE OR REPLACE FUNCTION %s.update_updated_at()
+			RETURNS TRIGGER AS $$
+			BEGIN
+				NEW.updated_at = CURRENT_TIMESTAMP;
+				RETURN NEW;
+			END;
+			$$ LANGUAGE plpgsql`, suite.targetSchema))
+		require.NoError(t, err)
+
+		// Migrate triggers
+		err = MigrateTriggersFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Verify trigger exists on target table
+		var triggerCount int
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT COUNT(*) FROM information_schema.triggers
+			WHERE trigger_schema = '%s' AND event_object_table = '%s'`,
+			suite.targetSchema, tableName)).Scan(&triggerCount)
+		require.NoError(t, err)
+		require.GreaterOrEqual(t, triggerCount, 1, "Target table should have at least one trigger")
+
+		return suite
+	})
+}
+
+func TestIndexMigration(t *testing.T) {
+	e2eshared.RunSuite(t, func(t *testing.T) MigrationTestSuite {
+		suite := SetupMigrationSuite(t)
+		defer suite.Teardown(t.Context())
+
+		tableName := "test_table"
+		sourceTable := fmt.Sprintf("%s.%s", suite.sourceSchema, tableName)
+		targetTable := fmt.Sprintf("%s.%s", suite.targetSchema, tableName)
+
+		// Create source table
+		_, err := suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TABLE %s (
+				id INT PRIMARY KEY,
+				name VARCHAR(100),
+				email VARCHAR(255),
+				created_at TIMESTAMP
+			)`, sourceTable))
+		require.NoError(t, err)
+
+		// Create indexes on source table
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE INDEX idx_name ON %s (name)`, sourceTable))
+		require.NoError(t, err)
+
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE UNIQUE INDEX idx_email ON %s (email)`, sourceTable))
+		require.NoError(t, err)
+
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE INDEX idx_created_at ON %s (created_at DESC)`, sourceTable))
+		require.NoError(t, err)
+
+		// First migrate schema
+		tableMappings := []*protos.TableMapping{
+			{
+				SourceTableIdentifier:      sourceTable,
+				DestinationTableIdentifier: targetTable,
+			},
+		}
+
+		err = MigrateSchemaFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Migrate indexes
+		err = MigrateIndexesFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Verify indexes exist on target table
+		var indexCount int
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT COUNT(*) FROM pg_indexes
+			WHERE schemaname = '%s' AND tablename = '%s' AND indexname NOT LIKE '%%_pkey'`,
+			suite.targetSchema, tableName)).Scan(&indexCount)
+		require.NoError(t, err)
+		require.GreaterOrEqual(t, indexCount, 2, "Target table should have at least 2 indexes (excluding primary key)")
+
+		// Verify specific indexes
+		var idxNameExists bool
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT EXISTS (
+				SELECT 1 FROM pg_indexes
+				WHERE schemaname = '%s' AND tablename = '%s' AND indexname = 'idx_name'
+			)`, suite.targetSchema, tableName)).Scan(&idxNameExists)
+		require.NoError(t, err)
+		require.True(t, idxNameExists, "idx_name index should exist")
+
+		return suite
+	})
+}
+
+func TestFullMigration(t *testing.T) {
+	e2eshared.RunSuite(t, func(t *testing.T) MigrationTestSuite {
+		suite := SetupMigrationSuite(t)
+		defer suite.Teardown(t.Context())
+
+		tableName := "full_test_table"
+		sourceTable := fmt.Sprintf("%s.%s", suite.sourceSchema, tableName)
+		targetTable := fmt.Sprintf("%s.%s", suite.targetSchema, tableName)
+
+		// Create comprehensive source table
+		_, err := suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TABLE %s (
+				id INT PRIMARY KEY,
+				name VARCHAR(100) NOT NULL,
+				email VARCHAR(255) UNIQUE,
+				created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+				updated_at TIMESTAMP
+			)`, sourceTable))
+		require.NoError(t, err)
+
+		// Create trigger function
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE OR REPLACE FUNCTION %s.update_updated_at()
+			RETURNS TRIGGER AS $$
+			BEGIN
+				NEW.updated_at = CURRENT_TIMESTAMP;
+				RETURN NEW;
+			END;
+			$$ LANGUAGE plpgsql`, suite.sourceSchema))
+		require.NoError(t, err)
+
+		// Create trigger
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE TRIGGER update_timestamp
+			BEFORE UPDATE ON %s
+			FOR EACH ROW
+			EXECUTE FUNCTION %s.update_updated_at()`, sourceTable, suite.sourceSchema))
+		require.NoError(t, err)
+
+		// Create indexes
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE INDEX idx_name ON %s (name)`, sourceTable))
+		require.NoError(t, err)
+
+		_, err = suite.sourceConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE INDEX idx_created_at ON %s (created_at DESC)`, sourceTable))
+		require.NoError(t, err)
+
+		tableMappings := []*protos.TableMapping{
+			{
+				SourceTableIdentifier:      sourceTable,
+				DestinationTableIdentifier: targetTable,
+			},
+		}
+
+		// Migrate schema
+		err = MigrateSchemaFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Create trigger function in target
+		_, err = suite.targetConnRaw.Exec(t.Context(), fmt.Sprintf(`
+			CREATE OR REPLACE FUNCTION %s.update_updated_at()
+			RETURNS TRIGGER AS $$
+			BEGIN
+				NEW.updated_at = CURRENT_TIMESTAMP;
+				RETURN NEW;
+			END;
+			$$ LANGUAGE plpgsql`, suite.targetSchema))
+		require.NoError(t, err)
+
+		// Migrate triggers
+		err = MigrateTriggersFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Migrate indexes
+		err = MigrateIndexesFromSource(t.Context(), suite.sourceConn, suite.targetConn, tableMappings)
+		require.NoError(t, err)
+
+		// Verify everything
+		// Check table exists
+		var tableExists bool
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT EXISTS (
+				SELECT 1 FROM information_schema.tables
+				WHERE table_schema = '%s' AND table_name = '%s'
+			)`, suite.targetSchema, tableName)).Scan(&tableExists)
+		require.NoError(t, err)
+		require.True(t, tableExists, "Target table should exist")
+
+		// Check trigger exists
+		var triggerExists bool
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT EXISTS (
+				SELECT 1 FROM information_schema.triggers
+				WHERE trigger_schema = '%s' AND event_object_table = '%s'
+			)`, suite.targetSchema, tableName)).Scan(&triggerExists)
+		require.NoError(t, err)
+		require.True(t, triggerExists, "Target table should have triggers")
+
+		// Check indexes exist
+		var indexCount int
+		err = suite.targetConnRaw.QueryRow(t.Context(), fmt.Sprintf(`
+			SELECT COUNT(*) FROM pg_indexes
+			WHERE schemaname = '%s' AND tablename = '%s' AND indexname NOT LIKE '%%_pkey'`,
+			suite.targetSchema, tableName)).Scan(&indexCount)
+		require.NoError(t, err)
+		require.GreaterOrEqual(t, indexCount, 2, "Target table should have at least 2 indexes")
+
+		return suite
+	})
+}
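The tests above pin down the intended call order: schema first so the tables exist, then triggers, then indexes, and (as in the tests) the trigger functions themselves must already exist on the destination, since only the CREATE TRIGGER statements are replayed. A minimal caller-side sketch of that ordering, with the hypothetical helper name migrateAll, would be:

    // migrateAll is an illustrative wrapper around the three exported
    // migration entry points; it is not part of this change.
    func migrateAll(ctx context.Context, src, dst *connpostgres.PostgresConnector, mappings []*protos.TableMapping) error {
        if err := connpostgres.MigrateSchemaFromSource(ctx, src, dst, mappings); err != nil {
            return err
        }
        if err := connpostgres.MigrateTriggersFromSource(ctx, src, dst, mappings); err != nil {
            return err
        }
        return connpostgres.MigrateIndexesFromSource(ctx, src, dst, mappings)
    }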
diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go
index 9037b9272a..bd3f7379a5 100644
--- a/flow/connectors/postgres/postgres.go
+++ b/flow/connectors/postgres/postgres.go
@@ -1161,6 +1161,31 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas(
 	return nil
 }
 
+// MigrateTriggersAndIndexesForPostgresToPostgres migrates triggers and indexes
+// from source to target for Postgres-to-Postgres flows.
+// This should be called after schema deltas are replayed or during initial setup.
+func (c *PostgresConnector) MigrateTriggersAndIndexesForPostgresToPostgres(
+	ctx context.Context,
+	sourceConnector *PostgresConnector,
+	tableMappings []*protos.TableMapping,
+) error {
+	// Migrate triggers
+	if err := MigrateTriggersFromSource(ctx, sourceConnector, c, tableMappings); err != nil {
+		c.logger.Warn("failed to migrate triggers", slog.Any("error", err))
+		// Don't fail the entire flow if trigger migration fails;
+		// triggers are less critical than schema
+	}
+
+	// Migrate indexes
+	if err := MigrateIndexesFromSource(ctx, sourceConnector, c, tableMappings); err != nil {
+		c.logger.Warn("failed to migrate indexes", slog.Any("error", err))
+		// Don't fail the entire flow if index migration fails;
+		// indexes can be recreated manually if needed
+	}
+
+	return nil
+}
+
 // EnsurePullability ensures that a table is pullable, implementing the Connector interface.
 func (c *PostgresConnector) EnsurePullability(
 	ctx context.Context,