From 0950ea258ab0e2d4a2d081aaddae9612172e73dc Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 18:56:48 +0530 Subject: [PATCH 1/6] add support for basic read schema --- .../connectors/postgres/index_trigger_sync.go | 957 ++++++++++++++++++ .../postgres/index_trigger_sync_test.go | 672 ++++++++++++ 2 files changed, 1629 insertions(+) create mode 100644 flow/connectors/postgres/index_trigger_sync.go create mode 100644 flow/connectors/postgres/index_trigger_sync_test.go diff --git a/flow/connectors/postgres/index_trigger_sync.go b/flow/connectors/postgres/index_trigger_sync.go new file mode 100644 index 0000000000..5f8b7aa65c --- /dev/null +++ b/flow/connectors/postgres/index_trigger_sync.go @@ -0,0 +1,957 @@ +package connpostgres + +import ( + "context" + "fmt" + "log/slog" + "strings" + + "github.com/PeerDB-io/peerdb/flow/connectors/utils" + "github.com/PeerDB-io/peerdb/flow/generated/protos" +) + +// IndexInfo represents information about a PostgreSQL index +type IndexInfo struct { + IndexName string + TableSchema string + TableName string + IndexDef string + IsUnique bool + IsPrimary bool + IndexColumns []string +} + +// TriggerInfo represents information about a PostgreSQL trigger +type TriggerInfo struct { + TriggerName string + TableSchema string + TableName string + TriggerDef string + EventManipulation string + ActionTiming string + ActionStatement string +} + +// ConstraintInfo represents information about a PostgreSQL constraint +type ConstraintInfo struct { + ConstraintName string + TableSchema string + TableName string + ConstraintType string // 'c' for check, 'f' for foreign key, 'u' for unique, 'p' for primary key + ConstraintDef string // Full constraint definition from pg_get_constraintdef + IsDeferrable bool + IsDeferred bool +} + +// SyncIndexesAndTriggers syncs indexes, triggers, and constraints from source to destination. +// This is called once during initial setup, not for on-the-fly changes. 
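+// It is expected to run after CreateNormalizedTable has already created the destination
+// tables, so it only adds the secondary objects (indexes, triggers, constraints) on top of them.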
+// +// Features: +// - Syncs all non-primary-key indexes from source to destination +// - Syncs all triggers from source to destination +// - Syncs check constraints and foreign key constraints +// - Automatically syncs trigger functions if they don't exist on destination +// - Skips indexes/triggers/constraints that already exist on destination +// +// Limitations: +// - Only runs during initial setup (not for on-the-fly changes) +// - Requires trigger functions to exist on source (or will attempt to sync them) +// - Primary key indexes and constraints are skipped (already exist) +// - Foreign key constraints referencing tables not in the sync are skipped +func (c *PostgresConnector) SyncIndexesAndTriggers( + ctx context.Context, + tableMappings []*protos.TableMapping, + sourceConn *PostgresConnector, +) error { + c.logger.Info("Starting index and trigger synchronization", + slog.Int("tableCount", len(tableMappings))) + + for _, tableMapping := range tableMappings { + srcTable, err := utils.ParseSchemaTable(tableMapping.SourceTableIdentifier) + if err != nil { + return fmt.Errorf("error parsing source table %s: %w", tableMapping.SourceTableIdentifier, err) + } + + dstTable, err := utils.ParseSchemaTable(tableMapping.DestinationTableIdentifier) + if err != nil { + return fmt.Errorf("error parsing destination table %s: %w", tableMapping.DestinationTableIdentifier, err) + } + + // Sync indexes + if err := c.syncIndexesForTable(ctx, srcTable, dstTable, sourceConn); err != nil { + c.logger.Warn("Failed to sync indexes for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + // Continue with other tables even if one fails + } + + // Sync triggers + if err := c.syncTriggersForTable(ctx, srcTable, dstTable, sourceConn); err != nil { + c.logger.Warn("Failed to sync triggers for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + // Continue with other tables even if one fails + } + + // Sync constraints (check constraints and foreign keys) + if err := c.syncConstraintsForTable(ctx, srcTable, dstTable, sourceConn, tableMappings); err != nil { + c.logger.Warn("Failed to sync constraints for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.Any("error", err)) + // Continue with other tables even if one fails + } + } + + c.logger.Info("Completed index, trigger, and constraint synchronization") + return nil +} + +// syncIndexesForTable syncs indexes for a specific table +func (c *PostgresConnector) syncIndexesForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, +) error { + // Get indexes from source + srcIndexes, err := sourceConn.getIndexesForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source indexes: %w", err) + } + + // Get indexes from destination + dstIndexes, err := c.getIndexesForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination indexes: %w", err) + } + + // Create a map of destination indexes by name for quick lookup + dstIndexMap := make(map[string]*IndexInfo, len(dstIndexes)) + for _, idx := range dstIndexes { + dstIndexMap[idx.IndexName] = idx + } + + // Find missing indexes and create them + createdCount := 0 + for _, srcIdx := range srcIndexes { + // Skip primary key indexes - they should already exist + if srcIdx.IsPrimary { + continue + } + + 
// Check if index already exists in destination + if _, exists := dstIndexMap[srcIdx.IndexName]; exists { + c.logger.Debug("Index already exists in destination", + slog.String("indexName", srcIdx.IndexName), + slog.String("dstTable", dstTable.String())) + continue + } + + // Create the index + // Replace source schema/table names with destination + indexSQL := c.adaptIndexSQL(srcIdx.IndexDef, srcTable, dstTable) + + c.logger.Info("Creating index on destination", + slog.String("indexName", srcIdx.IndexName), + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.String("indexSQL", indexSQL)) + + if _, err := c.conn.Exec(ctx, indexSQL); err != nil { + c.logger.Error("Failed to create index", + slog.String("indexName", srcIdx.IndexName), + slog.String("indexSQL", indexSQL), + slog.Any("error", err)) + // Continue with other indexes even if one fails + continue + } + + createdCount++ + c.logger.Info("Successfully created index", + slog.String("indexName", srcIdx.IndexName), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created indexes for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", createdCount)) + } + + return nil +} + +// syncTriggersForTable syncs triggers for a specific table +func (c *PostgresConnector) syncTriggersForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, +) error { + c.logger.Info("Starting trigger sync for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String())) + + // Get triggers from source + srcTriggers, err := sourceConn.getTriggersForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source triggers: %w", err) + } + + c.logger.Info("Retrieved source triggers", + slog.String("srcTable", srcTable.String()), + slog.Int("triggerCount", len(srcTriggers))) + + // Get triggers from destination + dstTriggers, err := c.getTriggersForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination triggers: %w", err) + } + + c.logger.Info("Retrieved destination triggers", + slog.String("dstTable", dstTable.String()), + slog.Int("triggerCount", len(dstTriggers))) + + // Create a map of destination triggers by name for quick lookup + dstTriggerMap := make(map[string]*TriggerInfo, len(dstTriggers)) + for _, trig := range dstTriggers { + dstTriggerMap[trig.TriggerName] = trig + } + + // Find missing triggers and create them + createdCount := 0 + for _, srcTrig := range srcTriggers { + c.logger.Info("Processing source trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("triggerDef", srcTrig.TriggerDef), + slog.String("srcTable", srcTable.String())) + + // Check if trigger already exists in destination + if _, exists := dstTriggerMap[srcTrig.TriggerName]; exists { + c.logger.Info("Trigger already exists in destination, skipping", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("dstTable", dstTable.String())) + continue + } + + // Extract function name from trigger definition + funcName, funcSchema := c.extractFunctionFromTriggerDef(srcTrig.TriggerDef) + c.logger.Info("Extracted function from trigger definition", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("functionName", funcName), + slog.String("functionSchema", funcSchema)) + + // Check if function exists on destination + // Try multiple schemas if function name doesn't have schema 
qualification + funcExists := false + if funcName != "" { + schemasToCheck := []string{funcSchema} + // If no schema was specified, try public schema and table's schema + if funcSchema == "public" || funcSchema == "" { + schemasToCheck = []string{"public", dstTable.Schema, srcTable.Schema} + } + + for _, schema := range schemasToCheck { + exists, err := c.checkFunctionExists(ctx, schema, funcName) + if err != nil { + c.logger.Warn("Failed to check if function exists", + slog.String("functionName", funcName), + slog.String("functionSchema", schema), + slog.Any("error", err)) + continue + } + if exists { + funcExists = true + funcSchema = schema // Update to the actual schema where function exists + c.logger.Info("Found function on destination", + slog.String("functionName", funcName), + slog.String("functionSchema", schema)) + break + } + } + + if !funcExists { + // Try to sync the function from source + // Try multiple schemas on source to find the function + sourceSchemasToCheck := []string{funcSchema, "public", srcTable.Schema} + if funcSchema == "public" || funcSchema == "" { + sourceSchemasToCheck = []string{"public", srcTable.Schema} + } + + funcSynced := false + for _, sourceSchema := range sourceSchemasToCheck { + c.logger.Info("Attempting to sync trigger function from source", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("targetSchema", funcSchema)) + + if err := c.syncTriggerFunction(ctx, sourceSchema, funcName, funcSchema, sourceConn); err != nil { + c.logger.Debug("Failed to sync function from this schema, trying next", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.Any("error", err)) + continue + } + + // Verify function was created + funcExists, err = c.checkFunctionExists(ctx, funcSchema, funcName) + if err == nil && funcExists { + funcSynced = true + c.logger.Info("Successfully synced trigger function", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("targetSchema", funcSchema)) + break + } + } + + if !funcSynced { + c.logger.Warn("Failed to sync trigger function from source, skipping trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("functionName", funcName), + slog.String("checkedSchemas", fmt.Sprintf("%v", sourceSchemasToCheck)), + slog.String("hint", "Create the function manually on destination")) + continue + } + } + } + + // Create the trigger + // pg_get_triggerdef already gives us the full CREATE TRIGGER statement + // We just need to replace source schema/table names with destination + triggerSQL := c.adaptTriggerSQL(srcTrig.TriggerDef, srcTable, dstTable) + + c.logger.Info("Creating trigger on destination", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.String("triggerSQL", triggerSQL)) + + if _, err := c.conn.Exec(ctx, triggerSQL); err != nil { + c.logger.Error("Failed to create trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("triggerSQL", triggerSQL), + slog.Any("error", err)) + // Continue with other triggers even if one fails + continue + } + + createdCount++ + c.logger.Info("Successfully created trigger", + slog.String("triggerName", srcTrig.TriggerName), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created triggers for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", 
createdCount)) + } + + return nil +} + +// getIndexesForTable retrieves all indexes for a given table +func (c *PostgresConnector) getIndexesForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*IndexInfo, error) { + // Use pg_indexes view which is simpler and more reliable + query := ` + SELECT + indexname, + schemaname, + tablename, + indexdef + FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 + ORDER BY indexname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying indexes: %w", err) + } + defer rows.Close() + + var indexes []*IndexInfo + for rows.Next() { + var idx IndexInfo + err := rows.Scan( + &idx.IndexName, + &idx.TableSchema, + &idx.TableName, + &idx.IndexDef, + ) + if err != nil { + return nil, fmt.Errorf("error scanning index row: %w", err) + } + + // Determine if index is unique or primary key + idx.IsUnique = strings.Contains(strings.ToUpper(idx.IndexDef), "UNIQUE") + + // Check if it's a primary key constraint + // Primary keys are typically named like tablename_pkey + idx.IsPrimary = strings.HasSuffix(idx.IndexName, "_pkey") || + strings.Contains(strings.ToUpper(idx.IndexDef), "PRIMARY KEY") + + // Extract column names from index definition + // This is a simple extraction - may need refinement for complex cases + idx.IndexColumns = c.extractColumnsFromIndexDef(idx.IndexDef) + + indexes = append(indexes, &idx) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating index rows: %w", err) + } + + return indexes, nil +} + +// extractColumnsFromIndexDef extracts column names from index definition +func (c *PostgresConnector) extractColumnsFromIndexDef(indexDef string) []string { + // This is a simplified extraction - looks for patterns like (col1, col2) + // For more complex cases, we might need to parse the SQL properly + var columns []string + + // Find the part between parentheses + start := strings.Index(indexDef, "(") + end := strings.LastIndex(indexDef, ")") + if start >= 0 && end > start { + colPart := indexDef[start+1 : end] + // Split by comma and clean up + parts := strings.Split(colPart, ",") + for _, part := range parts { + col := strings.TrimSpace(part) + // Remove function calls, operators, etc. 
- just get column name + // Remove quotes if present + col = strings.Trim(col, `"'`) + // Take only the column name part (before any operators or functions) + if spaceIdx := strings.Index(col, " "); spaceIdx > 0 { + col = col[:spaceIdx] + } + if len(col) > 0 { + columns = append(columns, col) + } + } + } + + return columns +} + +// getTriggersForTable retrieves all triggers for a given table +func (c *PostgresConnector) getTriggersForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*TriggerInfo, error) { + // Use pg_trigger and pg_proc to get full trigger definition + query := ` + SELECT + t.tgname as trigger_name, + n.nspname as schema_name, + c.relname as table_name, + pg_get_triggerdef(t.oid) as trigger_def, + CASE + WHEN t.tgtype & 2 = 2 THEN 'BEFORE' + WHEN t.tgtype & 64 = 64 THEN 'INSTEAD OF' + ELSE 'AFTER' + END as action_timing, + CASE + WHEN t.tgtype & 4 = 4 THEN 'INSERT' + WHEN t.tgtype & 8 = 8 THEN 'DELETE' + WHEN t.tgtype & 16 = 16 THEN 'UPDATE' + ELSE 'UNKNOWN' + END as event_manipulation + FROM pg_trigger t + JOIN pg_class c ON c.oid = t.tgrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND NOT t.tgisinternal + ORDER BY t.tgname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying triggers: %w", err) + } + defer rows.Close() + + var triggers []*TriggerInfo + for rows.Next() { + var trig TriggerInfo + err := rows.Scan( + &trig.TriggerName, + &trig.TableSchema, + &trig.TableName, + &trig.TriggerDef, + &trig.ActionTiming, + &trig.EventManipulation, + ) + if err != nil { + return nil, fmt.Errorf("error scanning trigger row: %w", err) + } + + // Extract action statement from trigger definition + // The trigger_def from pg_get_triggerdef already contains the full CREATE TRIGGER statement + // We just need to extract the EXECUTE FUNCTION part + trig.ActionStatement = c.extractActionStatement(trig.TriggerDef) + + triggers = append(triggers, &trig) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating trigger rows: %w", err) + } + + return triggers, nil +} + +// extractActionStatement extracts the action statement (EXECUTE FUNCTION ...) 
from trigger definition
+func (c *PostgresConnector) extractActionStatement(triggerDef string) string {
+	// pg_get_triggerdef returns something like:
+	// CREATE TRIGGER trigger_name BEFORE INSERT ON schema.table FOR EACH ROW EXECUTE FUNCTION function_name()
+	// We want to extract the EXECUTE FUNCTION part
+	executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE")
+	if executeIdx >= 0 {
+		return triggerDef[executeIdx:]
+	}
+	return ""
+}
+
+// adaptIndexSQL adapts index SQL from source to destination table
+func (c *PostgresConnector) adaptIndexSQL(
+	indexSQL string,
+	srcTable *utils.SchemaTable,
+	dstTable *utils.SchemaTable,
+) string {
+	// Replace source schema.table with destination schema.table
+	adapted := strings.ReplaceAll(indexSQL,
+		fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)),
+		fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table)))
+
+	// Also handle unquoted versions
+	adapted = strings.ReplaceAll(adapted,
+		fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table),
+		fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table))
+
+	return adapted
+}
+
+// extractFunctionFromTriggerDef extracts function name and schema from trigger definition
+func (c *PostgresConnector) extractFunctionFromTriggerDef(triggerDef string) (funcName, funcSchema string) {
+	// pg_get_triggerdef returns something like:
+	// CREATE TRIGGER trigger_name BEFORE INSERT ON schema.table FOR EACH ROW EXECUTE FUNCTION schema.function_name()
+	// We need to extract the function name and schema
+
+	// Find "EXECUTE FUNCTION" or "EXECUTE PROCEDURE"
+	executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE")
+	if executeIdx < 0 {
+		return "", ""
+	}
+
+	// Get the part after EXECUTE
+	executePart := triggerDef[executeIdx:]
+
+	// Look for the FUNCTION or PROCEDURE keyword, remembering which one matched
+	// so that the correct keyword length is skipped below
+	keyword := "FUNCTION"
+	funcKeywordIdx := strings.Index(strings.ToUpper(executePart), keyword)
+	if funcKeywordIdx < 0 {
+		keyword = "PROCEDURE"
+		funcKeywordIdx = strings.Index(strings.ToUpper(executePart), keyword)
+		if funcKeywordIdx < 0 {
+			return "", ""
+		}
+	}
+
+	// Get the function part (after the FUNCTION/PROCEDURE keyword)
+	funcPart := strings.TrimSpace(executePart[funcKeywordIdx+len(keyword):])
+
+	// Drop the argument list and trailing whitespace, keeping only the
+	// (possibly schema-qualified) function name
+	if parenIdx := strings.Index(funcPart, "("); parenIdx >= 0 {
+		funcPart = strings.TrimSpace(funcPart[:parenIdx])
+	}
+
+	// Check if it has schema qualification (schema.function)
+	if dotIdx := strings.LastIndex(funcPart, "."); dotIdx >= 0 {
+		funcSchema = funcPart[:dotIdx]
+		funcName = funcPart[dotIdx+1:]
+		// Remove quotes if present
+		funcSchema = strings.Trim(funcSchema, `"'`)
+		funcName = strings.Trim(funcName, `"'`)
+	} else {
+		// No schema, function is in current schema or public
+		funcName = strings.Trim(funcPart, `"'`)
+		funcSchema = "public" // Default to public schema
+	}
+
+	return funcName, funcSchema
+}
+
+// checkFunctionExists checks if a function exists in the specified schema
+func (c *PostgresConnector) checkFunctionExists(ctx context.Context, schema, funcName string) (bool, error) {
+	query := `
+		SELECT EXISTS (
+			SELECT 1
+			FROM pg_proc p
+			JOIN pg_namespace n ON n.oid = p.pronamespace
+			WHERE n.nspname = $1 AND p.proname = $2
+		)
+	`
+
+	var exists bool
+	err := c.conn.QueryRow(ctx, query, schema, funcName).Scan(&exists)
+	if err != nil {
+		return false, fmt.Errorf("error checking function existence: %w", err)
+	}
+
+	return exists, nil
+}
+
+//
syncTriggerFunction syncs a trigger function from source to destination +func (c *PostgresConnector) syncTriggerFunction( + ctx context.Context, + sourceSchema, funcName, targetSchema string, + sourceConn *PostgresConnector, +) error { + // Get function definition from source + query := ` + SELECT pg_get_functiondef(p.oid) as function_def + FROM pg_proc p + JOIN pg_namespace n ON n.oid = p.pronamespace + WHERE n.nspname = $1 AND p.proname = $2 + LIMIT 1 + ` + + var funcDef string + err := sourceConn.conn.QueryRow(ctx, query, sourceSchema, funcName).Scan(&funcDef) + if err != nil { + return fmt.Errorf("error getting function definition from source schema %s: %w", sourceSchema, err) + } + + if funcDef == "" { + return fmt.Errorf("function definition is empty") + } + + c.logger.Info("Retrieved function definition from source", + slog.String("functionName", funcName), + slog.String("sourceSchema", sourceSchema), + slog.String("functionDef", funcDef)) + + // Adapt function definition to use target schema if different + // Replace source schema with target schema in the function definition + if sourceSchema != targetSchema { + funcDef = strings.ReplaceAll(funcDef, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(sourceSchema), utils.QuoteIdentifier(funcName)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(targetSchema), utils.QuoteIdentifier(funcName))) + funcDef = strings.ReplaceAll(funcDef, + fmt.Sprintf("%s.%s", sourceSchema, funcName), + fmt.Sprintf("%s.%s", targetSchema, funcName)) + } + + // Create function on destination + // The function definition from pg_get_functiondef already includes CREATE OR REPLACE FUNCTION + // We just need to execute it + c.logger.Info("Creating function on destination", + slog.String("functionName", funcName), + slog.String("targetSchema", targetSchema)) + + if _, err := c.conn.Exec(ctx, funcDef); err != nil { + return fmt.Errorf("error creating function on destination: %w", err) + } + + c.logger.Info("Successfully created function on destination", + slog.String("functionName", funcName), + slog.String("targetSchema", targetSchema)) + + return nil +} + +// adaptTriggerSQL adapts trigger SQL from source to destination table +func (c *PostgresConnector) adaptTriggerSQL( + triggerSQL string, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, +) string { + // Replace source schema.table with destination schema.table + adapted := strings.ReplaceAll(triggerSQL, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) + + // Also handle unquoted versions + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), + fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) + + return adapted +} + +// syncConstraintsForTable syncs constraints (check and foreign key) for a specific table +func (c *PostgresConnector) syncConstraintsForTable( + ctx context.Context, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + sourceConn *PostgresConnector, + tableMappings []*protos.TableMapping, +) error { + c.logger.Info("Starting constraint sync for table", + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String())) + + // Get constraints from source + srcConstraints, err := sourceConn.getConstraintsForTable(ctx, srcTable) + if err != nil { + return fmt.Errorf("error getting source constraints: %w", err) + } + + c.logger.Info("Retrieved source 
constraints", + slog.String("srcTable", srcTable.String()), + slog.Int("constraintCount", len(srcConstraints))) + + // Get constraints from destination + dstConstraints, err := c.getConstraintsForTable(ctx, dstTable) + if err != nil { + return fmt.Errorf("error getting destination constraints: %w", err) + } + + c.logger.Info("Retrieved destination constraints", + slog.String("dstTable", dstTable.String()), + slog.Int("constraintCount", len(dstConstraints))) + + // Create a map of destination constraints by name for quick lookup + dstConstraintMap := make(map[string]*ConstraintInfo, len(dstConstraints)) + for _, constraint := range dstConstraints { + dstConstraintMap[constraint.ConstraintName] = constraint + } + + // Build a mapping of source table names to destination table names for FK resolution + tableNameMap := make(map[string]string) + for _, tm := range tableMappings { + src, err := utils.ParseSchemaTable(tm.SourceTableIdentifier) + if err != nil { + continue + } + dst, err := utils.ParseSchemaTable(tm.DestinationTableIdentifier) + if err != nil { + continue + } + // Map both qualified and unqualified names + tableNameMap[src.String()] = dst.String() + tableNameMap[fmt.Sprintf("%s.%s", src.Schema, src.Table)] = fmt.Sprintf("%s.%s", dst.Schema, dst.Table) + } + + // Find missing constraints and create them + createdCount := 0 + for _, srcConstraint := range srcConstraints { + c.logger.Info("Processing source constraint", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("constraintDef", srcConstraint.ConstraintDef), + slog.String("srcTable", srcTable.String())) + + // Skip primary key constraints - they should already exist + if srcConstraint.ConstraintType == "p" { + c.logger.Debug("Skipping primary key constraint", + slog.String("constraintName", srcConstraint.ConstraintName)) + continue + } + + // Skip unique constraints that are already covered by unique indexes + if srcConstraint.ConstraintType == "u" { + c.logger.Debug("Skipping unique constraint (handled by unique index)", + slog.String("constraintName", srcConstraint.ConstraintName)) + continue + } + + // Check if constraint already exists in destination + if _, exists := dstConstraintMap[srcConstraint.ConstraintName]; exists { + c.logger.Info("Constraint already exists in destination, skipping", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("dstTable", dstTable.String())) + continue + } + + // Adapt constraint definition for destination + constraintSQL := c.adaptConstraintSQL(srcConstraint.ConstraintDef, srcTable, dstTable, tableNameMap, srcConstraint.ConstraintName) + + c.logger.Info("Creating constraint on destination", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("srcTable", srcTable.String()), + slog.String("dstTable", dstTable.String()), + slog.String("constraintSQL", constraintSQL)) + + if _, err := c.conn.Exec(ctx, constraintSQL); err != nil { + c.logger.Error("Failed to create constraint", + slog.String("constraintName", srcConstraint.ConstraintName), + slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("constraintSQL", constraintSQL), + slog.Any("error", err)) + // Continue with other constraints even if one fails + continue + } + + createdCount++ + c.logger.Info("Successfully created constraint", + slog.String("constraintName", srcConstraint.ConstraintName), + 
slog.String("constraintType", srcConstraint.ConstraintType), + slog.String("dstTable", dstTable.String())) + } + + if createdCount > 0 { + c.logger.Info("Created constraints for table", + slog.String("dstTable", dstTable.String()), + slog.Int("createdCount", createdCount)) + } + + return nil +} + +// getConstraintsForTable retrieves all constraints for a given table +func (c *PostgresConnector) getConstraintsForTable( + ctx context.Context, + table *utils.SchemaTable, +) ([]*ConstraintInfo, error) { + // Query constraints from pg_constraint + query := ` + SELECT + con.conname as constraint_name, + n.nspname as schema_name, + c.relname as table_name, + con.contype::text as constraint_type, + pg_get_constraintdef(con.oid) as constraint_def, + con.condeferrable as is_deferrable, + con.condeferred as is_deferred + FROM pg_constraint con + JOIN pg_class c ON c.oid = con.conrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND con.contype IN ('c', 'f') -- 'c' for check, 'f' for foreign key + ORDER BY con.conname + ` + + rows, err := c.conn.Query(ctx, query, table.Schema, table.Table) + if err != nil { + return nil, fmt.Errorf("error querying constraints: %w", err) + } + defer rows.Close() + + var constraints []*ConstraintInfo + for rows.Next() { + var constraint ConstraintInfo + err := rows.Scan( + &constraint.ConstraintName, + &constraint.TableSchema, + &constraint.TableName, + &constraint.ConstraintType, + &constraint.ConstraintDef, + &constraint.IsDeferrable, + &constraint.IsDeferred, + ) + if err != nil { + return nil, fmt.Errorf("error scanning constraint row: %w", err) + } + + constraints = append(constraints, &constraint) + } + + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("error iterating constraint rows: %w", err) + } + + return constraints, nil +} + +// adaptConstraintSQL adapts constraint SQL from source to destination table +func (c *PostgresConnector) adaptConstraintSQL( + constraintDef string, + srcTable *utils.SchemaTable, + dstTable *utils.SchemaTable, + tableNameMap map[string]string, + constraintName string, +) string { + // The constraint definition from pg_get_constraintdef is already in the format: + // For check: CHECK (expression) + // For foreign key: FOREIGN KEY (columns) REFERENCES table(columns) + // We need to: + // 1. For foreign keys, replace referenced table names using tableNameMap FIRST + // 2. 
Then replace source table name with destination table name (for self-referencing FKs) + + adapted := constraintDef + + // For foreign key constraints, replace referenced table names + // Handle both cross-table and self-referencing foreign keys + if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { + // First, add the current table to the tableNameMap if not already present + // This ensures self-referencing FKs are handled + srcTableStr := fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table) + dstTableStr := fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table) + if _, exists := tableNameMap[srcTableStr]; !exists { + tableNameMap[srcTableStr] = dstTableStr + } + // Also add unqualified table name + if _, exists := tableNameMap[srcTable.Table]; !exists { + tableNameMap[srcTable.Table] = dstTable.Table + } + + // Look for REFERENCES clause and replace table names + for srcTableName, dstTableName := range tableNameMap { + // Handle schema-qualified table names in REFERENCES + if strings.Contains(srcTableName, ".") { + parts := strings.Split(srcTableName, ".") + if len(parts) == 2 { + srcSchema, srcTbl := parts[0], parts[1] + dstParts := strings.Split(dstTableName, ".") + if len(dstParts) == 2 { + dstSchema, dstTbl := dstParts[0], dstParts[1] + // Replace schema.table references (quoted) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(srcSchema), utils.QuoteIdentifier(srcTbl)), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstSchema), utils.QuoteIdentifier(dstTbl))) + // Replace schema.table references (unquoted) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s.%s", srcSchema, srcTbl), + fmt.Sprintf("REFERENCES %s.%s", dstSchema, dstTbl)) + } + } + } else { + // Handle unqualified table names in REFERENCES + // Replace unqualified table name (quoted) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s", utils.QuoteIdentifier(srcTableName)), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) + // Replace unqualified table name (unquoted) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("REFERENCES %s", srcTableName), + fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) + } + } + } + + // Replace source schema.table with destination schema.table in CHECK constraints + // (For FKs, we've already handled the REFERENCES clause above) + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), + fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) + + // Also handle unquoted versions + adapted = strings.ReplaceAll(adapted, + fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), + fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) + + // Build the full ALTER TABLE statement + // For check constraints: ALTER TABLE ... ADD CONSTRAINT ... CHECK ... + // For foreign keys: ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY ... 
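+	// Illustrative sketch with hypothetical identifiers (assuming utils.QuoteIdentifier
+	// wraps names in double quotes): a source definition
+	//   FOREIGN KEY (parent_id) REFERENCES src_schema.parent_table(id)
+	// for constraint fk_parent on src_schema.child_table mapped to dst_schema.child_table
+	// would be rewritten to
+	//   ALTER TABLE "dst_schema"."child_table" ADD CONSTRAINT "fk_parent"
+	//   FOREIGN KEY (parent_id) REFERENCES dst_schema.parent_table(id)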
+ if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "CHECK") || + strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { + return fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s %s", + utils.QuoteIdentifier(dstTable.Schema), + utils.QuoteIdentifier(dstTable.Table), + utils.QuoteIdentifier(constraintName), + adapted) + } + + return adapted +} diff --git a/flow/connectors/postgres/index_trigger_sync_test.go b/flow/connectors/postgres/index_trigger_sync_test.go new file mode 100644 index 0000000000..4e74ed1b4f --- /dev/null +++ b/flow/connectors/postgres/index_trigger_sync_test.go @@ -0,0 +1,672 @@ +package connpostgres + +import ( + "context" + "fmt" + "strings" + "testing" + + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/require" + + "github.com/PeerDB-io/peerdb/flow/connectors/utils" + "github.com/PeerDB-io/peerdb/flow/e2eshared" + "github.com/PeerDB-io/peerdb/flow/generated/protos" + "github.com/PeerDB-io/peerdb/flow/internal" + "github.com/PeerDB-io/peerdb/flow/shared" +) + +type IndexTriggerSyncTestSuite struct { + t *testing.T + sourceConn *PostgresConnector + destConn *PostgresConnector + sourceSchema string + destSchema string +} + +func SetupIndexTriggerSyncSuite(t *testing.T) IndexTriggerSyncTestSuite { + t.Helper() + + // Create source connector + sourceConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) + require.NoError(t, err) + + // Create destination connector (can be same DB for testing) + destConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) + require.NoError(t, err) + + // Create test schemas + sourceSchema := "src_idx_" + strings.ToLower(shared.RandomString(8)) + destSchema := "dst_idx_" + strings.ToLower(shared.RandomString(8)) + + setupTx, err := sourceConn.conn.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err := setupTx.Rollback(t.Context()) + if err != pgx.ErrTxClosed { + require.NoError(t, err) + } + }() + + _, err = setupTx.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", sourceSchema)) + require.NoError(t, err) + _, err = setupTx.Exec(t.Context(), "CREATE SCHEMA "+sourceSchema) + require.NoError(t, err) + require.NoError(t, setupTx.Commit(t.Context())) + + setupTx2, err := destConn.conn.Begin(t.Context()) + require.NoError(t, err) + defer func() { + err := setupTx2.Rollback(t.Context()) + if err != pgx.ErrTxClosed { + require.NoError(t, err) + } + }() + + _, err = setupTx2.Exec(t.Context(), fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", destSchema)) + require.NoError(t, err) + _, err = setupTx2.Exec(t.Context(), "CREATE SCHEMA "+destSchema) + require.NoError(t, err) + require.NoError(t, setupTx2.Commit(t.Context())) + + return IndexTriggerSyncTestSuite{ + t: t, + sourceConn: sourceConn, + destConn: destConn, + sourceSchema: sourceSchema, + destSchema: destSchema, + } +} + +func (s IndexTriggerSyncTestSuite) Teardown(ctx context.Context) { + _, _ = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.sourceSchema)) + _, _ = s.destConn.conn.Exec(ctx, fmt.Sprintf("DROP SCHEMA IF EXISTS %s CASCADE", s.destSchema)) + require.NoError(s.t, s.sourceConn.Close()) + require.NoError(s.t, s.destConn.Close()) +} + +func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_table" + destTable := s.destSchema + ".test_table" + + // Create source table with indexes + _, err := 
s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, sourceTable)) + require.NoError(s.t, err) + + // Create indexes on source + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE UNIQUE INDEX idx_email ON %s(email)", sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_created_at ON %s(created_at DESC)", sourceTable)) + require.NoError(s.t, err) + + // Create destination table (without indexes) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state + // for unit testing SyncIndexesAndTriggers in isolation. + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, destTable)) + require.NoError(s.t, err) + + // Sync indexes + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify indexes were created on destination + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + // Should have primary key + 3 indexes = 4 total (excluding primary key from sync) + indexNames := make(map[string]bool) + for _, idx := range destIndexes { + indexNames[idx.IndexName] = true + } + + // Check that our indexes exist (primary key is created automatically) + require.True(s.t, indexNames["test_table_pkey"], "Primary key should exist") + require.True(s.t, indexNames["idx_name"], "idx_name should be synced") + require.True(s.t, indexNames["idx_email"], "idx_email should be synced") + require.True(s.t, indexNames["idx_created_at"], "idx_created_at should be synced") +} + +func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_trigger_table" + destTable := s.destSchema + ".test_trigger_table" + + // Create a function for the trigger + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.sourceSchema)) + require.NoError(s.t, err) + + // Create the same function in destination schema + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.destSchema)) + require.NoError(s.t, err) + + // Create source table with trigger + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TRIGGER update_updated_at + BEFORE UPDATE ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.update_timestamp()`, sourceTable, s.sourceSchema)) + require.NoError(s.t, err) + + // Create destination table (without trigger) + 
// Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT, + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, destTable)) + require.NoError(s.t, err) + + // Sync triggers + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify trigger was created on destination + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + // Should have 1 trigger + require.Len(s.t, destTriggers, 1, "Should have 1 trigger") + require.Equal(s.t, "update_updated_at", destTriggers[0].TriggerName) +} + +func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_both_table" + destTable := s.destSchema + ".test_both_table" + + // Create function for trigger + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.log_changes() + RETURNS TRIGGER AS $$ + BEGIN + -- Simple trigger function + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.sourceSchema)) + require.NoError(s.t, err) + + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.log_changes() + RETURNS TRIGGER AS $$ + BEGIN + -- Simple trigger function + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.destSchema)) + require.NoError(s.t, err) + + // Create source table with index and trigger + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + data TEXT + )`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_data ON %s(data)", sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TRIGGER log_trigger + AFTER INSERT ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.log_changes()`, sourceTable, s.sourceSchema)) + require.NoError(s.t, err) + + // Create destination table (without index and trigger) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
+ _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + data TEXT + )`, destTable)) + require.NoError(s.t, err) + + // Sync both indexes and triggers + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify index was created + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + indexNames := make(map[string]bool) + for _, idx := range destIndexes { + indexNames[idx.IndexName] = true + } + require.True(s.t, indexNames["idx_data"], "idx_data should be synced") + + // Verify trigger was created + destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + require.Len(s.t, destTriggers, 1, "Should have 1 trigger") + require.Equal(s.t, "log_trigger", destTriggers[0].TriggerName) +} + +func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_check_table" + destTable := s.destSchema + ".test_check_table" + + // Create source table with check constraints + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + age INT, + email TEXT + )`, sourceTable)) + require.NoError(s.t, err) + + // Add check constraints + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_name_length + CHECK (char_length(name) >= 3)`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_age_positive + CHECK (age > 0 AND age < 150)`, sourceTable)) + require.NoError(s.t, err) + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_email_format + CHECK (email IS NULL OR email LIKE '%%@%%')`, sourceTable)) + require.NoError(s.t, err) + + // Create destination table (without constraints) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
+ _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + age INT, + email TEXT + )`, destTable)) + require.NoError(s.t, err) + + // Sync constraints + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify constraints were created on destination + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + // Should have 3 check constraints + constraintNames := make(map[string]bool) + for _, constraint := range destConstraints { + if constraint.ConstraintType == "c" { + constraintNames[constraint.ConstraintName] = true + } + } + + require.True(s.t, constraintNames["check_name_length"], "check_name_length should be synced") + require.True(s.t, constraintNames["check_age_positive"], "check_age_positive should be synced") + require.True(s.t, constraintNames["check_email_format"], "check_email_format should be synced") +} + +func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { + ctx := s.t.Context() + + // Create parent table on both source and destination + parentSourceTable := s.sourceSchema + ".parent_table" + parentDestTable := s.destSchema + ".parent_table" + + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT + )`, parentSourceTable)) + require.NoError(s.t, err) + + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT + )`, parentDestTable)) + require.NoError(s.t, err) + + // Create child table on source with foreign key + childSourceTable := s.sourceSchema + ".child_table" + childDestTable := s.destSchema + ".child_table" + + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + parent_id INT, + name TEXT + )`, childSourceTable)) + require.NoError(s.t, err) + + // Add foreign key constraint + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT fk_parent + FOREIGN KEY (parent_id) REFERENCES %s(id)`, childSourceTable, parentSourceTable)) + require.NoError(s.t, err) + + // Create child table on destination (without foreign key) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
+ _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + parent_id INT, + name TEXT + )`, childDestTable)) + require.NoError(s.t, err) + + // Sync constraints (both tables need to be in the mapping for FK to work) + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: parentSourceTable, + DestinationTableIdentifier: parentDestTable, + }, + { + SourceTableIdentifier: childSourceTable, + DestinationTableIdentifier: childDestTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify foreign key constraint was created on destination + childDestTableParsed, err := utils.ParseSchemaTable(childDestTable) + require.NoError(s.t, err) + destConstraints, err := s.destConn.getConstraintsForTable(ctx, childDestTableParsed) + require.NoError(s.t, err) + + // Should have 1 foreign key constraint + fkFound := false + for _, constraint := range destConstraints { + if constraint.ConstraintType == "f" && constraint.ConstraintName == "fk_parent" { + fkFound = true + // Verify the constraint definition references the correct destination table + require.Contains(s.t, constraint.ConstraintDef, parentDestTable, + "Foreign key should reference destination parent table") + break + } + } + require.True(s.t, fkFound, "fk_parent foreign key should be synced") +} + +func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_constraints_table" + destTable := s.destSchema + ".test_constraints_table" + + // Create source table with both check and foreign key constraints + // First create a referenced table + refSourceTable := s.sourceSchema + ".ref_table" + refDestTable := s.destSchema + ".ref_table" + + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + code TEXT UNIQUE + )`, refSourceTable)) + require.NoError(s.t, err) + + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + code TEXT UNIQUE + )`, refDestTable)) + require.NoError(s.t, err) + + // Create main table with constraints + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + ref_id INT, + status TEXT + )`, sourceTable)) + require.NoError(s.t, err) + + // Add check constraint + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_name_length + CHECK (char_length(name) >= 2)`, sourceTable)) + require.NoError(s.t, err) + + // Add foreign key constraint + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT fk_ref + FOREIGN KEY (ref_id) REFERENCES %s(id)`, sourceTable, refSourceTable)) + require.NoError(s.t, err) + + // Create destination table (without constraints) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
+ _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + ref_id INT, + status TEXT + )`, destTable)) + require.NoError(s.t, err) + + // Sync constraints (both tables in mapping) + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: refSourceTable, + DestinationTableIdentifier: refDestTable, + }, + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify constraints were created + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + constraintNames := make(map[string]string) // name -> type + for _, constraint := range destConstraints { + constraintNames[constraint.ConstraintName] = constraint.ConstraintType + } + + require.Equal(s.t, "c", constraintNames["check_name_length"], "check_name_length should be synced") + require.Equal(s.t, "f", constraintNames["fk_ref"], "fk_ref should be synced") +} + +func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { + ctx := s.t.Context() + sourceTable := s.sourceSchema + ".test_all_table" + destTable := s.destSchema + ".test_all_table" + + // Create function for trigger + _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.audit_func() + RETURNS TRIGGER AS $$ + BEGIN + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.sourceSchema)) + require.NoError(s.t, err) + + _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.audit_func() + RETURNS TRIGGER AS $$ + BEGIN + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, s.destSchema)) + require.NoError(s.t, err) + + // Create source table with index, trigger, and constraints + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, sourceTable)) + require.NoError(s.t, err) + + // Add index + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable)) + require.NoError(s.t, err) + + // Add trigger + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TRIGGER audit_trigger + AFTER INSERT ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.audit_func()`, sourceTable, s.sourceSchema)) + require.NoError(s.t, err) + + // Add check constraint + _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_name_length + CHECK (char_length(name) >= 3)`, sourceTable)) + require.NoError(s.t, err) + + // Create destination table (without index, trigger, or constraints) + // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically + // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
+ _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` + CREATE TABLE %s ( + id INT PRIMARY KEY, + name TEXT NOT NULL, + email TEXT, + created_at TIMESTAMPTZ DEFAULT NOW() + )`, destTable)) + require.NoError(s.t, err) + + // Sync everything + tableMappings := []*protos.TableMapping{ + { + SourceTableIdentifier: sourceTable, + DestinationTableIdentifier: destTable, + }, + } + + err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) + require.NoError(s.t, err) + + // Verify index was created + destTableParsed, err := utils.ParseSchemaTable(destTable) + require.NoError(s.t, err) + destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + indexNames := make(map[string]bool) + for _, idx := range destIndexes { + indexNames[idx.IndexName] = true + } + require.True(s.t, indexNames["idx_name"], "idx_name should be synced") + + // Verify trigger was created + destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) + require.NoError(s.t, err) + require.Len(s.t, destTriggers, 1, "Should have 1 trigger") + require.Equal(s.t, "audit_trigger", destTriggers[0].TriggerName) + + // Verify constraint was created + destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) + require.NoError(s.t, err) + + constraintFound := false + for _, constraint := range destConstraints { + if constraint.ConstraintName == "check_name_length" && constraint.ConstraintType == "c" { + constraintFound = true + break + } + } + require.True(s.t, constraintFound, "check_name_length constraint should be synced") +} + +func TestIndexTriggerSync(t *testing.T) { + e2eshared.RunSuite(t, SetupIndexTriggerSyncSuite) +} From ac6834e2383b163ea3b1dd777161be17c7cd1d4c Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 18:59:16 +0530 Subject: [PATCH 2/6] add support for all cases --- flow/activities/flowable.go | 57 +++ flow/activities/flowable_core.go | 92 +++++ flow/connectors/postgres/cdc.go | 356 ++++++++++++++++-- .../postgres/normalize_stmt_generator.go | 31 +- flow/connectors/postgres/postgres.go | 130 ++++++- flow/e2e/postgres_test.go | 173 +++++++++ flow/workflows/setup_flow.go | 7 + protos/flow.proto | 11 + 8 files changed, 809 insertions(+), 48 deletions(-) diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 3fe6a6dd3e..40bc6e384b 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -281,6 +281,63 @@ func (a *FlowableActivity) CreateNormalizedTable( }, nil } +// SyncIndexesAndTriggers syncs indexes and triggers from source to destination +// This is called once during initial setup, not for on-the-fly changes +func (a *FlowableActivity) SyncIndexesAndTriggers( + ctx context.Context, + config *protos.SetupNormalizedTableBatchInput, +) error { + logger := internal.LoggerFromCtx(ctx) + ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) + + // Only sync for Postgres to Postgres + // Check if destination is Postgres + dstConn, dstClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName) + if err != nil { + if errors.Is(err, errors.ErrUnsupported) { + logger.Info("Connector does not support normalized tables, skipping index/trigger sync") + return nil + } + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get destination connector: %w", err)) + } + defer dstClose(ctx) + + // Check if destination connector is Postgres + pgDstConn, ok := 
dstConn.(*connpostgres.PostgresConnector) + if !ok { + logger.Info("Destination is not Postgres, skipping index/trigger sync") + return nil + } + + // Get source connector (use SourcePeerName if available, otherwise skip) + if config.SourcePeerName == "" { + logger.Info("Source peer name not provided, skipping index/trigger sync") + return nil + } + + srcConn, srcClose, err := connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.SourcePeerName) + if err != nil { + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get source connector: %w", err)) + } + defer srcClose(ctx) + + // Check if source connector is Postgres + pgSrcConn, ok := srcConn.(*connpostgres.PostgresConnector) + if !ok { + logger.Info("Source is not Postgres, skipping index/trigger sync") + return nil + } + + // Sync indexes, triggers, and constraints + a.Alerter.LogFlowInfo(ctx, config.FlowName, "Syncing indexes, triggers, and constraints from source to destination") + if err := pgDstConn.SyncIndexesAndTriggers(ctx, config.TableMappings, pgSrcConn); err != nil { + return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to sync indexes, triggers, and constraints: %w", err)) + } + + a.Alerter.LogFlowInfo(ctx, config.FlowName, "Successfully synced indexes, triggers, and constraints") + return nil +} + func (a *FlowableActivity) SyncFlow( ctx context.Context, config *protos.FlowConnectionConfigsCore, diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 2da7e035a0..75605ccb75 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -108,6 +108,89 @@ func (a *FlowableActivity) applySchemaDeltas( return nil } +// updateDestinationSchemaMapping updates the destination schema mapping in the catalog +// after schema deltas have been applied to the destination. This ensures normalization +// uses the correct updated schema. 
+func (a *FlowableActivity) updateDestinationSchemaMapping( + ctx context.Context, + config *protos.FlowConnectionConfigsCore, + options *protos.SyncFlowOptions, + schemaDeltas []*protos.TableSchemaDelta, +) error { + logger := internal.LoggerFromCtx(ctx) + + // Filter table mappings for tables that had schema changes + filteredTableMappings := make([]*protos.TableMapping, 0, len(schemaDeltas)) + for _, tableMapping := range options.TableMappings { + if slices.ContainsFunc(schemaDeltas, func(schemaDelta *protos.TableSchemaDelta) bool { + return schemaDelta.SrcTableName == tableMapping.SourceTableIdentifier && + schemaDelta.DstTableName == tableMapping.DestinationTableIdentifier + }) { + filteredTableMappings = append(filteredTableMappings, tableMapping) + } + } + + if len(filteredTableMappings) == 0 { + return nil + } + + // Get destination connector to fetch updated schema + dstConn, dstClose, err := connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.DestinationName) + if err != nil { + return fmt.Errorf("failed to get destination connector for schema update: %w", err) + } + defer dstClose(ctx) + + logger.Info("Updating destination schema mapping after schema deltas", + slog.String("flowName", config.FlowJobName), + slog.Int("tablesAffected", len(filteredTableMappings))) + + // Fetch updated schema from destination + tableNameSchemaMapping, err := dstConn.GetTableSchema(ctx, config.Env, config.Version, config.System, filteredTableMappings) + if err != nil { + return fmt.Errorf("failed to get updated schema from destination: %w", err) + } + + // Build processed schema mapping (maps destination table names to schemas) + processed := internal.BuildProcessedSchemaMapping(filteredTableMappings, tableNameSchemaMapping, logger) + + // Update catalog with new destination schemas + tx, err := a.CatalogPool.BeginTx(ctx, pgx.TxOptions{}) + if err != nil { + return fmt.Errorf("failed to start transaction for schema mapping update: %w", err) + } + defer shared.RollbackTx(tx, logger) + + for tableName, tableSchema := range processed { + processedBytes, err := proto.Marshal(tableSchema) + if err != nil { + return fmt.Errorf("failed to marshal table schema for %s: %w", tableName, err) + } + if _, err := tx.Exec( + ctx, + "insert into table_schema_mapping(flow_name, table_name, table_schema) values ($1, $2, $3) "+ + "on conflict (flow_name, table_name) do update set table_schema = $3", + config.FlowJobName, + tableName, + processedBytes, + ); err != nil { + return fmt.Errorf("failed to update schema mapping for %s: %w", tableName, err) + } + logger.Info("Updated destination schema mapping in catalog", + slog.String("tableName", tableName), + slog.Int("columnCount", len(tableSchema.Columns))) + } + + if err := tx.Commit(ctx); err != nil { + return fmt.Errorf("failed to commit schema mapping update: %w", err) + } + + logger.Info("Successfully updated destination schema mapping", + slog.String("flowName", config.FlowJobName), + slog.Int("tablesUpdated", len(processed))) + return nil +} + func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncConnectorCore, Items model.Items]( ctx context.Context, a *FlowableActivity, @@ -338,6 +421,15 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon return nil, err } + // For Postgres to Postgres, update destination schema mapping after applying schema deltas + if len(res.TableSchemaDeltas) > 0 { + if err := a.updateDestinationSchemaMapping(ctx, config, options, 
res.TableSchemaDeltas); err != nil { + logger.Warn("Failed to update destination schema mapping, normalization may use stale schema", + slog.Any("error", err)) + // Don't fail the sync if schema mapping update fails, but log it + } + } + if recordBatchSync.NeedsNormalize() { syncState.Store(shared.Ptr("normalizing")) normRequests.Update(res.CurrentSyncBatchID) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index f041814a43..d4ed309b26 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -187,6 +187,7 @@ func (pgProcessor) NewItems(size int) model.PgItems { return model.NewPgItems(size) } +// How then integer / timestamp type got synced for our test case func (pgProcessor) Process( items model.PgItems, p *PostgresCDCSource, @@ -222,6 +223,7 @@ func (qProcessor) NewItems(size int) model.RecordItems { return model.NewRecordItems(size) } +// what is diff b/w pgProcessor and qProcessor func (qProcessor) Process( items model.RecordItems, p *PostgresCDCSource, @@ -684,7 +686,7 @@ func PullCdcRecords[Items model.Items]( logger.Debug("XLogData", slog.Any("WALStart", xld.WALStart), slog.Any("ServerWALEnd", xld.ServerWALEnd), slog.Time("ServerTime", xld.ServerTime)) - rec, err := processMessage(ctx, p, records, xld, clientXLogPos, processor) + rec, err := processMessage(ctx, p, req, records, xld, clientXLogPos, processor) if err != nil { return exceptions.NewPostgresLogicalMessageProcessingError(err) } @@ -697,6 +699,18 @@ func PullCdcRecords[Items model.Items]( fetchedBytes.Add(int64(len(msg.Data))) totalFetchedBytes.Add(int64(len(msg.Data))) tableName := rec.GetDestinationTableName() + + // Log if this is a DML operation following a schema change + switch rec.(type) { + case *model.InsertRecord[Items], *model.UpdateRecord[Items], *model.DeleteRecord[Items]: + if schema, ok := req.TableNameSchemaMapping[tableName]; ok { + logger.Debug("Processing DML operation", + slog.String("tableName", tableName), + slog.Int("columnCount", len(schema.Columns)), + slog.Any("LSN", xld.WALStart)) + } + } + switch r := rec.(type) { case *model.UpdateRecord[Items]: // tableName here is destination tableName. 
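The hunk that follows widens RelationRecord handling from added columns only to added, dropped, and type-changed columns. For orientation, a delta produced by that detection would be shaped like the following (a sketch built from the TableSchemaDelta and ColumnTypeChange fields this patch adds to flow.proto; the table and column values are invented):

package connpostgres

import "github.com/PeerDB-io/peerdb/flow/generated/protos"

// exampleDelta shows, with made-up values, how one ALTER TABLE that adds a
// column with a default, drops another column, and widens a type would be
// represented after this change.
var exampleDelta = &protos.TableSchemaDelta{
	SrcTableName: "public.users",
	DstTableName: "public.users_dst",
	AddedColumns: []*protos.FieldDescription{
		{Name: "age", Type: "integer", Nullable: true, DefaultValue: "0"},
	},
	DroppedColumns: []string{"legacy_flag"},
	TypeChangedColumns: []*protos.ColumnTypeChange{
		{ColumnName: "id", OldType: "integer", NewType: "bigint"},
	},
}

ReplayTableSchemaDeltas later turns such a delta into DROP COLUMN, ALTER COLUMN ... TYPE, and ADD COLUMN ... DEFAULT statements on the destination, as shown further down in this commit.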
@@ -792,10 +806,26 @@ func PullCdcRecords[Items model.Items]( case *model.RelationRecord[Items]: tableSchemaDelta := r.TableSchemaDelta - if len(tableSchemaDelta.AddedColumns) > 0 { - logger.Info(fmt.Sprintf("Detected schema change for table %s, addedColumns: %v", - tableSchemaDelta.SrcTableName, tableSchemaDelta.AddedColumns)) + if len(tableSchemaDelta.AddedColumns) > 0 || len(tableSchemaDelta.DroppedColumns) > 0 { + addedColNames := make([]string, 0, len(tableSchemaDelta.AddedColumns)) + for _, col := range tableSchemaDelta.AddedColumns { + addedColNames = append(addedColNames, fmt.Sprintf("%s(%s)", col.Name, col.Type)) + } + logger.Info("Processing RelationRecord with schema changes", + slog.String("srcTableName", tableSchemaDelta.SrcTableName), + slog.String("dstTableName", tableSchemaDelta.DstTableName), + slog.Int("addedColumnsCount", len(tableSchemaDelta.AddedColumns)), + slog.Any("addedColumns", addedColNames), + slog.Int("droppedColumnsCount", len(tableSchemaDelta.DroppedColumns)), + slog.Any("droppedColumns", tableSchemaDelta.DroppedColumns), + slog.Int64("checkpointID", r.CheckpointID), + slog.Uint64("transactionID", r.TransactionID)) records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) + logger.Info("Added schema delta to records stream", + slog.String("srcTableName", tableSchemaDelta.SrcTableName)) + } else { + logger.Debug("RelationRecord with no schema changes, skipping", + slog.String("srcTableName", tableSchemaDelta.SrcTableName)) } case *model.MessageRecord[Items]: @@ -850,6 +880,7 @@ func (p *PostgresCDCSource) baseRecord(lsn pglogrepl.LSN) model.BaseRecord { func processMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, + req *model.PullRecordsRequest[Items], batch *model.CDCStream[Items], xld pglogrepl.XLogData, currentClientXlogPos pglogrepl.LSN, @@ -890,17 +921,34 @@ func processMessage[Items model.Items]( return nil, err } + currColNames := make([]string, 0, len(msg.Columns)) + for _, col := range msg.Columns { + currColNames = append(currColNames, col.Name) + } + + logger.Info("Received RelationMessage from WAL", + slog.Uint64("RelationID", uint64(msg.RelationID)), + slog.String("Namespace", msg.Namespace), + slog.String("RelationName", msg.RelationName), + slog.Int("columnCount", len(msg.Columns)), + slog.Any("columnNames", currColNames), + slog.Any("LSN", currentClientXlogPos)) + if _, exists := p.srcTableIDNameMapping[msg.RelationID]; !exists { + logger.Warn("RelationMessage received for table not in replication set, skipping", + slog.Uint64("RelationID", uint64(msg.RelationID)), + slog.String("Namespace", msg.Namespace), + slog.String("RelationName", msg.RelationName)) return nil, nil } - logger.Debug("RelationMessage", + logger.Debug("Processing RelationMessage for replicated table", slog.Uint64("RelationID", uint64(msg.RelationID)), slog.String("Namespace", msg.Namespace), slog.String("RelationName", msg.RelationName), slog.Any("Columns", msg.Columns)) - return processRelationMessage[Items](ctx, p, currentClientXlogPos, msg) + return processRelationMessage[Items](ctx, p, req, currentClientXlogPos, msg) case *pglogrepl.LogicalDecodingMessage: logger.Debug("LogicalDecodingMessage", slog.Bool("Transactional", msg.Transactional), @@ -1071,9 +1119,19 @@ func processDeleteMessage[Items model.Items]( } // processRelationMessage processes a RelationMessage and returns a TableSchemaDelta +// Currently supported DDL operations: +// - ADD COLUMN (with default values) +// - DROP COLUMN (excluding PeerDB system columns) +// - ALTER COLUMN 
TYPE (column type changes) +// +// Not currently supported: +// - CREATE/DROP INDEX (indexes are not replicated via logical replication RelationMessages) +// - CREATE/DROP TRIGGER (triggers are not replicated via logical replication RelationMessages) +// - Other DDL operations not captured in RelationMessages func processRelationMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, + req *model.PullRecordsRequest[Items], lsn pglogrepl.LSN, currRel *pglogrepl.RelationMessage, ) (model.Record[Items], error) { @@ -1091,7 +1149,11 @@ func processRelationMessage[Items model.Items]( p.logger.Info("processing RelationMessage", slog.Any("LSN", lsn), - slog.String("RelationName", currRelName)) + slog.String("RelationName", currRelName), + slog.Int("RelationID", int(currRel.RelationID)), + slog.Int("columnCount", len(currRel.Columns)), + slog.String("relationNamespace", currRel.Namespace), + slog.String("relationRelationName", currRel.RelationName)) // retrieve current TableSchema for table changed, mapping uses dst table name as key, need to translate source name currRelDstInfo, ok := p.tableNameMapping[currRelName] if !ok { @@ -1103,14 +1165,36 @@ func processRelationMessage[Items model.Items]( prevSchema, ok := p.tableNameSchemaMapping[currRelDstInfo.Name] if !ok { p.logger.Error("Detected relation message for table, but not in table schema mapping", - slog.String("tableName", currRelDstInfo.Name)) + slog.String("tableName", currRelDstInfo.Name), + slog.String("srcTableName", currRelName), + slog.Int("relationID", int(currRel.RelationID))) return nil, fmt.Errorf("cannot find table schema for %s", currRelDstInfo.Name) } + if prevSchema == nil { + p.logger.Error("Previous schema is nil for table", + slog.String("tableName", currRelDstInfo.Name)) + return nil, fmt.Errorf("previous schema is nil for %s", currRelDstInfo.Name) + } + + if len(prevSchema.Columns) == 0 { + p.logger.Warn("Previous schema has no columns, skipping schema comparison", + slog.String("tableName", currRelDstInfo.Name)) + // Still process to update the schema cache, but don't detect changes + p.relationMessageMapping[currRel.RelationID] = currRel + return nil, nil + } + prevRelMap := make(map[string]string, len(prevSchema.Columns)) + prevColNames := make([]string, 0, len(prevSchema.Columns)) for _, column := range prevSchema.Columns { prevRelMap[column.Name] = column.Type + prevColNames = append(prevColNames, column.Name) } + p.logger.Info("Retrieved previous schema from cache", + slog.String("tableName", currRelDstInfo.Name), + slog.Int("previousColumnCount", len(prevSchema.Columns)), + slog.Any("previousColumns", prevColNames)) currRelMap := make(map[string]string, len(currRel.Columns)) for _, column := range currRel.Columns { @@ -1139,9 +1223,20 @@ func processRelationMessage[Items model.Items]( SrcTableName: p.srcTableIDNameMapping[currRel.RelationID], DstTableName: p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Name, AddedColumns: nil, + DroppedColumns: nil, System: prevSchema.System, NullableEnabled: prevSchema.NullableEnabled, } + + // Log previous and current column names for debugging + currColNames := make([]string, 0, len(currRel.Columns)) + for _, col := range currRel.Columns { + currColNames = append(currColNames, col.Name) + } + p.logger.Debug("Comparing schemas for schema change detection", + slog.String("tableName", schemaDelta.SrcTableName), + slog.Any("previousColumns", prevColNames), + slog.Any("currentColumns", currColNames)) for _, column := range currRel.Columns { // not 
present in previous relation message, but in current one, so added. if _, ok := prevRelMap[column.Name]; !ok { @@ -1163,63 +1258,266 @@ func processRelationMessage[Items model.Items]( p.logger.Info("Detected added column", slog.String("columnName", column.Name), slog.String("columnType", currRelMap[column.Name]), - slog.String("relationName", schemaDelta.SrcTableName)) + slog.String("typeModifier", fmt.Sprintf("%d", column.TypeModifier)), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) } else { p.logger.Warn(fmt.Sprintf("Detected added column %s in table %s, but not propagating because excluded", column.Name, schemaDelta.SrcTableName)) } // present in previous and current relation messages, but data types have changed. - // so we add it to AddedColumns and DroppedColumns, knowing that we process DroppedColumns first. + // Detect and record type changes } else if prevRelMap[column.Name] != currRelMap[column.Name] { - p.logger.Warn(fmt.Sprintf("Detected column %s with type changed from %s to %s in table %s, but not propagating", - column.Name, prevRelMap[column.Name], currRelMap[column.Name], schemaDelta.SrcTableName)) + // Get default value for the column if it exists + var newDefaultValue string + for _, currCol := range currRel.Columns { + if currCol.Name == column.Name { + // Fetch default value from pg_catalog + rows, err := p.conn.Query(ctx, + `SELECT pg_get_expr(adbin, adrelid) as column_default + FROM pg_attribute + LEFT JOIN pg_attrdef ON pg_attribute.attrelid = pg_attrdef.adrelid + AND pg_attribute.attnum = pg_attrdef.adnum + WHERE attrelid=$1 AND attname=$2 AND attnum > 0 AND NOT attisdropped`, + currRel.RelationID, column.Name) + if err == nil { + var defaultVal pgtype.Text + if rows.Next() { + if err := rows.Scan(&defaultVal); err == nil && defaultVal.Valid { + newDefaultValue = defaultVal.String + } + } + rows.Close() + } + break + } + } + + typeChange := &protos.ColumnTypeChange{ + ColumnName: column.Name, + OldType: prevRelMap[column.Name], + NewType: currRelMap[column.Name], + NewDefaultValue: newDefaultValue, + } + schemaDelta.TypeChangedColumns = append(schemaDelta.TypeChangedColumns, typeChange) + p.logger.Info("Detected column type change", + slog.String("columnName", column.Name), + slog.String("oldType", prevRelMap[column.Name]), + slog.String("newType", currRelMap[column.Name]), + slog.String("defaultValue", newDefaultValue), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) } } for _, column := range prevSchema.Columns { // present in previous relation message, but not in current one, so dropped. 
if _, ok := currRelMap[column.Name]; !ok { - p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating", column, - schemaDelta.SrcTableName)) + // Never drop PeerDB system columns - they are managed by PeerDB + // System columns start with _PEERDB_ prefix + if strings.HasPrefix(column.Name, "_PEERDB_") { + p.logger.Warn("Detected dropped PeerDB system column, skipping", + slog.String("columnName", column.Name), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.String("reason", "PeerDB system columns cannot be dropped")) + continue + } + + // only add to delta if not excluded + if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok { + schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name) + p.logger.Info("Detected dropped column", + slog.String("columnName", column.Name), + slog.String("columnType", column.Type), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("LSN", lsn)) + } else { + p.logger.Warn(fmt.Sprintf("Detected dropped column %s in table %s, but not propagating because excluded", + column.Name, schemaDelta.SrcTableName)) + } + } + } + + // Log summary of detected schema changes + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) + for _, col := range schemaDelta.AddedColumns { + addedColNames = append(addedColNames, fmt.Sprintf("%s(%s)", col.Name, col.Type)) + } + typeChangedColNames := make([]string, 0, len(schemaDelta.TypeChangedColumns)) + for _, tc := range schemaDelta.TypeChangedColumns { + typeChangedColNames = append(typeChangedColNames, fmt.Sprintf("%s(%s->%s)", tc.ColumnName, tc.OldType, tc.NewType)) } + p.logger.Info("Schema change detection summary", + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Int("addedColumnsCount", len(schemaDelta.AddedColumns)), + slog.Any("addedColumns", addedColNames), + slog.Int("droppedColumnsCount", len(schemaDelta.DroppedColumns)), + slog.Any("droppedColumns", schemaDelta.DroppedColumns), + slog.Int("typeChangedColumnsCount", len(schemaDelta.TypeChangedColumns)), + slog.Any("typeChangedColumns", typeChangedColNames), + slog.Any("LSN", lsn)) } - if len(potentiallyNullableAddedColumns) > 0 { - p.logger.Info("Checking for potentially nullable columns in table", + + // Update relationMessageMapping IMMEDIATELY so DML operations that follow + // in the same WAL stream use the updated schema + p.relationMessageMapping[currRel.RelationID] = currRel + p.logger.Info("Updated relationMessageMapping with new schema", + slog.String("tableName", currRelName), + slog.Int("columnCount", len(currRel.Columns)), + slog.Any("LSN", lsn)) + + // Fetch default values and nullable info for added columns from pg_catalog + if len(schemaDelta.AddedColumns) > 0 { + addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) + for _, col := range schemaDelta.AddedColumns { + addedColNames = append(addedColNames, utils.QuoteLiteral(col.Name)) + } + + p.logger.Info("Fetching default values and nullable info for added columns", slog.String("tableName", schemaDelta.SrcTableName), - slog.Any("potentiallyNullable", potentiallyNullableAddedColumns)) + slog.Any("addedColumns", addedColNames)) rows, err := p.conn.Query( 
ctx, fmt.Sprintf( - "select attname from pg_attribute where attrelid=$1 and attname in (%s) and not attnotnull", - strings.Join(potentiallyNullableAddedColumns, ","), + `SELECT attname, attnotnull, pg_get_expr(adbin, adrelid) as column_default + FROM pg_attribute + LEFT JOIN pg_attrdef ON pg_attribute.attrelid = pg_attrdef.adrelid + AND pg_attribute.attnum = pg_attrdef.adnum + WHERE attrelid=$1 AND attname IN (%s) AND attnum > 0 AND NOT attisdropped`, + strings.Join(addedColNames, ","), ), currRel.RelationID, ) if err != nil { - return nil, fmt.Errorf("error looking up column nullable info for schema change: %w", err) + return nil, fmt.Errorf("error looking up column default/nullable info for schema change: %w", err) + } + + type colInfo struct { + attname string + attnotnull bool + columnDefault pgtype.Text } - attnames, err := pgx.CollectRows[string](rows, pgx.RowTo) + colInfos, err := pgx.CollectRows(rows, func(row pgx.CollectableRow) (colInfo, error) { + var info colInfo + err := rows.Scan(&info.attname, &info.attnotnull, &info.columnDefault) + return info, err + }) if err != nil { - return nil, fmt.Errorf("error collecting rows for column nullable info for schema change: %w", err) + return nil, fmt.Errorf("error collecting rows for column default/nullable info: %w", err) + } + + colInfoMap := make(map[string]colInfo, len(colInfos)) + for _, info := range colInfos { + colInfoMap[info.attname] = info } + for _, column := range schemaDelta.AddedColumns { - if slices.Contains(attnames, column.Name) { - column.Nullable = true - p.logger.Info(fmt.Sprintf("Detected column %s in table %s as nullable", - column.Name, schemaDelta.SrcTableName)) + if info, ok := colInfoMap[column.Name]; ok { + column.Nullable = !info.attnotnull + if info.columnDefault.Valid && info.columnDefault.String != "" { + // Store default value - we'll use it in the ADD COLUMN statement + column.DefaultValue = info.columnDefault.String + p.logger.Info("Detected column with default value", + slog.String("columnName", column.Name), + slog.String("defaultValue", info.columnDefault.String), + slog.Bool("nullable", column.Nullable), + slog.String("tableName", schemaDelta.SrcTableName)) + } else { + p.logger.Info("Detected column without default value", + slog.String("columnName", column.Name), + slog.Bool("nullable", column.Nullable), + slog.String("tableName", schemaDelta.SrcTableName)) + } } } } - p.relationMessageMapping[currRel.RelationID] = currRel - // only log audit if there is actionable delta - if len(schemaDelta.AddedColumns) > 0 { + // Update the cached schema mapping after detecting changes + // This ensures the next comparison uses the updated schema + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + // Create updated schema with new columns added, dropped columns removed, and type changes applied + updatedSchema := &protos.TableSchema{ + Columns: make([]*protos.FieldDescription, 0, len(prevSchema.Columns)), + System: prevSchema.System, + PrimaryKeyColumns: prevSchema.PrimaryKeyColumns, + NullableEnabled: prevSchema.NullableEnabled, + } + + // Build maps for efficient lookup + droppedColsMap := make(map[string]bool, len(schemaDelta.DroppedColumns)) + for _, droppedCol := range schemaDelta.DroppedColumns { + droppedColsMap[droppedCol] = true + } + + typeChangedColsMap := make(map[string]*protos.ColumnTypeChange, len(schemaDelta.TypeChangedColumns)) + for _, typeChange := range schemaDelta.TypeChangedColumns { + 
typeChangedColsMap[typeChange.ColumnName] = typeChange + } + + // Add existing columns that weren't dropped, updating types if changed + for _, col := range prevSchema.Columns { + if !droppedColsMap[col.Name] { + // Check if this column's type changed + if typeChange, ok := typeChangedColsMap[col.Name]; ok { + // Update the column with new type + updatedCol := &protos.FieldDescription{ + Name: col.Name, + Type: typeChange.NewType, + TypeModifier: col.TypeModifier, + Nullable: col.Nullable, + DefaultValue: typeChange.NewDefaultValue, + } + updatedSchema.Columns = append(updatedSchema.Columns, updatedCol) + } else { + updatedSchema.Columns = append(updatedSchema.Columns, col) + } + } + } + + // Add newly added columns + for _, addedCol := range schemaDelta.AddedColumns { + updatedSchema.Columns = append(updatedSchema.Columns, addedCol) + } + + // Update the cached schema mapping in both places + // This ensures DML operations that follow in the same WAL stream use updated schema + p.tableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema + if req != nil && req.TableNameSchemaMapping != nil { + req.TableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema + p.logger.Info("Updated req.TableNameSchemaMapping for immediate DML processing", + slog.String("tableName", currRelDstInfo.Name)) + } + + p.logger.Info("Updated cached schema mapping after detecting changes", + slog.String("tableName", currRelDstInfo.Name), + slog.Int("addedColumns", len(schemaDelta.AddedColumns)), + slog.Int("droppedColumns", len(schemaDelta.DroppedColumns)), + slog.Int("totalColumns", len(updatedSchema.Columns))) + } + + // Return RelationRecord if there are any schema changes (added, dropped, or type changed) + if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { + p.logger.Info("Returning RelationRecord with schema delta", + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Int("addedColumnsCount", len(schemaDelta.AddedColumns)), + slog.Int("droppedColumnsCount", len(schemaDelta.DroppedColumns)), + slog.Int("typeChangedColumnsCount", len(schemaDelta.TypeChangedColumns)), + slog.Any("LSN", lsn)) return &model.RelationRecord[Items]{ BaseRecord: p.baseRecord(lsn), TableSchemaDelta: schemaDelta, }, monitoring.AuditSchemaDelta(ctx, p.catalogPool.Pool, p.flowJobName, schemaDelta) } + p.logger.Debug("No schema changes detected, returning nil", + slog.String("tableName", schemaDelta.SrcTableName)) return nil, nil } diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go index d090ff01ba..699796d593 100644 --- a/flow/connectors/postgres/normalize_stmt_generator.go +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -138,18 +138,34 @@ func (n *normalizeStmtGenerator) generateMergeStatement( unchangedToastColumns []string, ) string { columnCount := len(normalizedTableSchema.Columns) - quotedColumnNames := make([]string, columnCount) - flattenedCastsSQLArray := make([]string, 0, columnCount) parsedDstTable, _ := utils.ParseSchemaTable(dstTableName) primaryKeyColumnCasts := make(map[string]string) primaryKeySelectSQLArray := make([]string, 0, len(normalizedTableSchema.PrimaryKeyColumns)) - for i, column := range normalizedTableSchema.Columns { + + // Filter out PeerDB system columns - they are added separately + systemCols := make(map[string]bool) + if n.peerdbCols.SyncedAtColName != "" { + 
systemCols[n.peerdbCols.SyncedAtColName] = true + } + if n.peerdbCols.SoftDeleteColName != "" { + systemCols[n.peerdbCols.SoftDeleteColName] = true + } + + // Build column lists excluding system columns + quotedColumnNamesFiltered := make([]string, 0, columnCount) + + for _, column := range normalizedTableSchema.Columns { + // Skip PeerDB system columns - they are handled separately + if systemCols[column.Name] { + continue + } + genericColumnType := column.Type quotedCol := utils.QuoteIdentifier(column.Name) stringCol := utils.QuoteLiteral(column.Name) - quotedColumnNames[i] = quotedCol + quotedColumnNamesFiltered = append(quotedColumnNamesFiltered, quotedCol) pgType := n.columnTypeToPg(normalizedTableSchema, genericColumnType) expr := n.generateExpr(normalizedTableSchema, genericColumnType, stringCol, pgType) @@ -159,14 +175,17 @@ func (n *normalizeStmtGenerator) generateMergeStatement( primaryKeySelectSQLArray = append(primaryKeySelectSQLArray, fmt.Sprintf("src.%s=dst.%s", quotedCol, quotedCol)) } } + + // Use filtered column names (excluding system columns) + quotedColumnNames := quotedColumnNamesFiltered flattenedCastsSQL := strings.Join(flattenedCastsSQLArray, ",") - insertValuesSQLArray := make([]string, 0, columnCount+2) + insertValuesSQLArray := make([]string, 0, len(quotedColumnNames)+2) for _, quotedCol := range quotedColumnNames { insertValuesSQLArray = append(insertValuesSQLArray, "src."+quotedCol) } updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames, unchangedToastColumns) - // append synced_at column + // append synced_at column (system column, added separately) if n.peerdbCols.SyncedAtColName != "" { quotedColumnNames = append(quotedColumnNames, utils.QuoteIdentifier(n.peerdbCols.SyncedAtColName)) insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 03c1c32131..0548a00f6d 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -1134,39 +1134,143 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( defer shared.RollbackTx(tableSchemaModifyTx, c.logger) for _, schemaDelta := range schemaDeltas { - if schemaDelta == nil || len(schemaDelta.AddedColumns) == 0 { + if schemaDelta == nil { + c.logger.Warn("Skipping nil schema delta") continue } + // Skip if no schema changes + if len(schemaDelta.AddedColumns) == 0 && len(schemaDelta.DroppedColumns) == 0 && len(schemaDelta.TypeChangedColumns) == 0 { + continue + } + + dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) + if err != nil { + return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + } + + for _, droppedColumn := range schemaDelta.DroppedColumns { + // Never drop PeerDB system columns - they are managed by PeerDB + // System columns start with _PEERDB_ prefix + if strings.HasPrefix(droppedColumn, "_PEERDB_") { + c.logger.Warn("Skipping drop of PeerDB system column", + slog.String("columnName", droppedColumn), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.String("reason", "PeerDB system columns cannot be dropped")) + continue + } + + dropSQL := fmt.Sprintf( + "ALTER TABLE %s.%s DROP COLUMN IF EXISTS %s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(dstSchemaTable.Table), + utils.QuoteIdentifier(droppedColumn)) + _, err = c.execWithLoggingTx(ctx, dropSQL, tableSchemaModifyTx) + if err 
!= nil { + c.logger.Error("Failed to drop column", + slog.String("columnName", droppedColumn), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("error", err)) + return fmt.Errorf("failed to drop column %s for table %s: %w", droppedColumn, + schemaDelta.DstTableName, err) + } + } + + for _, typeChange := range schemaDelta.TypeChangedColumns { + // Never change type of PeerDB system columns + if strings.HasPrefix(typeChange.ColumnName, "_PEERDB_") { + continue + } + + newType := typeChange.NewType + if schemaDelta.System == protos.TypeSystem_Q { + newType = qValueKindToPostgresType(typeChange.NewType) + } + + // Build ALTER COLUMN statement + alterSQL := fmt.Sprintf( + "ALTER TABLE %s.%s ALTER COLUMN %s TYPE %s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(dstSchemaTable.Table), + utils.QuoteIdentifier(typeChange.ColumnName), + newType) + + _, err = c.execWithLoggingTx(ctx, alterSQL, tableSchemaModifyTx) + if err != nil { + c.logger.Error("Failed to change column type", + slog.String("columnName", typeChange.ColumnName), + slog.String("oldType", typeChange.OldType), + slog.String("newType", typeChange.NewType), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("error", err)) + } else { + c.logger.Info("Successfully changed column type", + slog.String("columnName", typeChange.ColumnName), + slog.String("oldType", typeChange.OldType), + slog.String("newType", typeChange.NewType), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName)) + + // Update default value if provided + if typeChange.NewDefaultValue != "" { + defaultSQL := fmt.Sprintf( + "ALTER TABLE %s.%s ALTER COLUMN %s SET DEFAULT %s", + utils.QuoteIdentifier(dstSchemaTable.Schema), + utils.QuoteIdentifier(dstSchemaTable.Table), + utils.QuoteIdentifier(typeChange.ColumnName), + typeChange.NewDefaultValue) + + _, err = c.execWithLoggingTx(ctx, defaultSQL, tableSchemaModifyTx) + if err != nil { + c.logger.Warn("Failed to set default value after type change", + slog.String("columnName", typeChange.ColumnName), + slog.String("defaultValue", typeChange.NewDefaultValue), + slog.Any("error", err)) + } + } + } + } + for _, addedColumn := range schemaDelta.AddedColumns { columnType := addedColumn.Type if schemaDelta.System == protos.TypeSystem_Q { - columnType = qValueKindToPostgresType(columnType) + columnType = qValueKindToPostgresType(addedColumn.Type) } - dstSchemaTable, err := utils.ParseSchemaTable(schemaDelta.DstTableName) - if err != nil { - return fmt.Errorf("error parsing schema and table for %s: %w", schemaDelta.DstTableName, err) + // Build column definition with type and default value + columnDef := fmt.Sprintf("%s %s", utils.QuoteIdentifier(addedColumn.Name), columnType) + if addedColumn.DefaultValue != "" { + // Include DEFAULT value if present + columnDef += fmt.Sprintf(" DEFAULT %s", addedColumn.DefaultValue) + } else if !addedColumn.Nullable { + // If NOT NULL without DEFAULT, PostgreSQL will require it + columnDef += " NOT NULL" } - _, err = c.execWithLoggingTx(ctx, fmt.Sprintf( - "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS %s %s", + addSQL := fmt.Sprintf( + "ALTER TABLE %s.%s ADD COLUMN IF NOT EXISTS %s", utils.QuoteIdentifier(dstSchemaTable.Schema), utils.QuoteIdentifier(dstSchemaTable.Table), - utils.QuoteIdentifier(addedColumn.Name), columnType), tableSchemaModifyTx) + columnDef) + + _, err = c.execWithLoggingTx(ctx, addSQL, tableSchemaModifyTx) if err != nil { + c.logger.Error("Failed to add 
column", + slog.String("columnName", addedColumn.Name), + slog.String("columnType", addedColumn.Type), + slog.String("dstTableName", schemaDelta.DstTableName), + slog.Any("error", err)) return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.Name, schemaDelta.DstTableName, err) } - c.logger.Info(fmt.Sprintf("[schema delta replay] added column %s with data type %s", - addedColumn.Name, addedColumn.Type), - slog.String("srcTableName", schemaDelta.SrcTableName), - slog.String("dstTableName", schemaDelta.DstTableName), - ) } } if err := tableSchemaModifyTx.Commit(ctx); err != nil { + c.logger.Error("Failed to commit schema modification transaction", + slog.String("flowJobName", flowJobName), + slog.Any("error", err)) return fmt.Errorf("failed to commit transaction for table schema modification: %w", err) } return nil diff --git a/flow/e2e/postgres_test.go b/flow/e2e/postgres_test.go index 702784e018..16b1b952ea 100644 --- a/flow/e2e/postgres_test.go +++ b/flow/e2e/postgres_test.go @@ -1286,3 +1286,176 @@ func (s PeerFlowE2ETestSuitePG) TestResync(tableName string) { env.Cancel(s.t.Context()) RequireEnvCanceled(s.t, env) } + +// Test_Indexes_Triggers_Constraints_PG tests that indexes, triggers, and constraints +// are synced from source to destination during initial setup +func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { + tc := NewTemporalClient(s.t) + + srcTableName := s.attachSchemaSuffix("test_idx_trig_const") + dstTableName := s.attachSchemaSuffix("test_idx_trig_const_dst") + + // Create a trigger function on source + _, err := s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE OR REPLACE FUNCTION %s.update_timestamp() + RETURNS TRIGGER AS $$ + BEGIN + NEW.updated_at = NOW(); + RETURN NEW; + END; + $$ LANGUAGE plpgsql`, Schema(s))) + require.NoError(s.t, err) + + // Create source table with indexes, triggers, and constraints + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE TABLE IF NOT EXISTS %s ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + email TEXT, + age INT, + created_at TIMESTAMPTZ DEFAULT NOW(), + updated_at TIMESTAMPTZ DEFAULT NOW() + )`, srcTableName)) + require.NoError(s.t, err) + + // Create indexes on source + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE INDEX idx_name ON %s(name)`, srcTableName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE UNIQUE INDEX idx_email ON %s(email) WHERE email IS NOT NULL`, srcTableName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE INDEX idx_created_at ON %s(created_at DESC)`, srcTableName)) + require.NoError(s.t, err) + + // Create trigger on source + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + CREATE TRIGGER update_updated_at + BEFORE UPDATE ON %s + FOR EACH ROW + EXECUTE FUNCTION %s.update_timestamp()`, srcTableName, Schema(s))) + require.NoError(s.t, err) + + // Create check constraints on source + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_name_length + CHECK (char_length(name) >= 3)`, srcTableName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_age_positive + CHECK (age > 0 AND age < 150)`, srcTableName)) + require.NoError(s.t, err) + + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + ALTER TABLE %s ADD CONSTRAINT check_email_format + CHECK (email IS NULL OR email LIKE '%%@%%')`, srcTableName)) + require.NoError(s.t, err) + + 
connectionGen := FlowConnectionGenerationConfig{ + FlowJobName: s.attachSuffix("test_idx_trig_const"), + TableNameMapping: map[string]string{srcTableName: dstTableName}, + Destination: s.Peer().Name, + } + + flowConnConfig := connectionGen.GenerateFlowConnectionConfigs(s) + flowConnConfig.MaxBatchSize = 100 + flowConnConfig.DoInitialSnapshot = true + + env := ExecutePeerflow(s.t, tc, flowConnConfig) + SetupCDCFlowStatusQuery(s.t, env, flowConnConfig) + + // Wait for initial setup to complete (this includes index/trigger/constraint sync) + EnvWaitFor(s.t, env, 3*time.Minute, "waiting for initial setup", func() bool { + // Check if destination table exists + var exists bool + err := s.Conn().QueryRow(s.t.Context(), ` + SELECT EXISTS ( + SELECT 1 FROM pg_catalog.pg_tables + WHERE schemaname = $1 AND tablename = $2 + )`, Schema(s), "test_idx_trig_const_dst").Scan(&exists) + return err == nil && exists + }) + + // Verify indexes were synced + s.t.Log("Verifying indexes were synced...") + indexQuery := ` + SELECT indexname + FROM pg_indexes + WHERE schemaname = $1 AND tablename = $2 AND indexname != $3 + ORDER BY indexname + ` + rows, err := s.Conn().Query(s.t.Context(), indexQuery, Schema(s), "test_idx_trig_const_dst", "test_idx_trig_const_dst_pkey") + require.NoError(s.t, err) + defer rows.Close() + + indexNames := make(map[string]bool) + for rows.Next() { + var indexName string + require.NoError(s.t, rows.Scan(&indexName)) + indexNames[indexName] = true + } + + require.True(s.t, indexNames["idx_name"], "idx_name should be synced") + require.True(s.t, indexNames["idx_email"], "idx_email should be synced") + require.True(s.t, indexNames["idx_created_at"], "idx_created_at should be synced") + + // Verify trigger was synced + s.t.Log("Verifying trigger was synced...") + triggerQuery := ` + SELECT t.tgname + FROM pg_trigger t + JOIN pg_class c ON c.oid = t.tgrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND NOT t.tgisinternal + ` + var triggerName string + err = s.Conn().QueryRow(s.t.Context(), triggerQuery, Schema(s), "test_idx_trig_const_dst").Scan(&triggerName) + require.NoError(s.t, err) + require.Equal(s.t, "update_updated_at", triggerName, "update_updated_at trigger should be synced") + + // Verify constraints were synced + s.t.Log("Verifying constraints were synced...") + constraintQuery := ` + SELECT con.conname + FROM pg_constraint con + JOIN pg_class c ON c.oid = con.conrelid + JOIN pg_namespace n ON n.oid = c.relnamespace + WHERE n.nspname = $1 + AND c.relname = $2 + AND con.contype::text = 'c' + ORDER BY con.conname + ` + rows, err = s.Conn().Query(s.t.Context(), constraintQuery, Schema(s), "test_idx_trig_const_dst") + require.NoError(s.t, err) + defer rows.Close() + + constraintNames := make(map[string]bool) + for rows.Next() { + var constraintName string + require.NoError(s.t, rows.Scan(&constraintName)) + constraintNames[constraintName] = true + } + + require.True(s.t, constraintNames["check_name_length"], "check_name_length should be synced") + require.True(s.t, constraintNames["check_age_positive"], "check_age_positive should be synced") + require.True(s.t, constraintNames["check_email_format"], "check_email_format should be synced") + + s.t.Log("All indexes, triggers, and constraints were successfully synced!") + + // Insert some data to verify the trigger works + _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` + INSERT INTO %s(name, email, age) VALUES ('test', 'test@example.com', 25)`, srcTableName)) + EnvNoError(s.t, 
env, err) + + // Wait for data to sync + EnvWaitForEqualTablesWithNames(env, s, "waiting for data sync", srcTableName, dstTableName, "id,name,email,age") + + env.Cancel(s.t.Context()) + RequireEnvCanceled(s.t, env) +} diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index f4c79eb67c..8029631bb7 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -213,6 +213,7 @@ func (s *SetupFlowExecution) setupNormalizedTables( FlowName: flowConnectionConfigs.FlowJobName, Env: flowConnectionConfigs.Env, IsResync: flowConnectionConfigs.Resync, + SourcePeerName: flowConnectionConfigs.SourceName, } if err := workflow.ExecuteActivity(ctx, flowable.CreateNormalizedTable, setupConfig).Get(ctx, nil); err != nil { @@ -220,6 +221,12 @@ func (s *SetupFlowExecution) setupNormalizedTables( return fmt.Errorf("failed to create normalized tables: %w", err) } + // Sync indexes, triggers, and constraints after tables are created + if err := workflow.ExecuteActivity(ctx, flowable.SyncIndexesAndTriggers, setupConfig).Get(ctx, nil); err != nil { + s.Warn("failed to sync indexes, triggers, and constraints", slog.Any("error", err)) + // Don't fail the setup if sync fails - log warning and continue + } + s.Info("finished setting up normalized tables for peer flow") return nil } diff --git a/protos/flow.proto b/protos/flow.proto index ee94edfadb..b412379075 100644 --- a/protos/flow.proto +++ b/protos/flow.proto @@ -240,6 +240,7 @@ message FieldDescription { string type = 2; int32 type_modifier = 3; bool nullable = 4; + string default_value = 5; } message SetupTableSchemaBatchInput { @@ -262,6 +263,7 @@ message SetupNormalizedTableBatchInput { string flow_name = 6; string peer_name = 7; bool is_resync = 8; + string source_peer_name = 9; } message SetupNormalizedTableOutput { @@ -425,10 +427,19 @@ message DropFlowInput { bool resync = 8; } +message ColumnTypeChange { + string column_name = 1; + string old_type = 2; + string new_type = 3; + string new_default_value = 4; +} + message TableSchemaDelta { string src_table_name = 1; string dst_table_name = 2; repeated FieldDescription added_columns = 3; + repeated string dropped_columns = 6; + repeated ColumnTypeChange type_changed_columns = 7; TypeSystem system = 4; bool nullable_enabled = 5; } From 150125d697952a32e1b4ef0e49ef4ccb449e602f Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 19:18:00 +0530 Subject: [PATCH 3/6] refactor --- .../connectors/postgres/index_trigger_sync.go | 194 ------------------ .../postgres/normalize_stmt_generator.go | 2 - flow/connectors/postgres/postgres.go | 5 + flow/workflows/setup_flow.go | 1 - 4 files changed, 5 insertions(+), 197 deletions(-) diff --git a/flow/connectors/postgres/index_trigger_sync.go b/flow/connectors/postgres/index_trigger_sync.go index 5f8b7aa65c..153a35baba 100644 --- a/flow/connectors/postgres/index_trigger_sync.go +++ b/flow/connectors/postgres/index_trigger_sync.go @@ -10,7 +10,6 @@ import ( "github.com/PeerDB-io/peerdb/flow/generated/protos" ) -// IndexInfo represents information about a PostgreSQL index type IndexInfo struct { IndexName string TableSchema string @@ -20,8 +19,6 @@ type IndexInfo struct { IsPrimary bool IndexColumns []string } - -// TriggerInfo represents information about a PostgreSQL trigger type TriggerInfo struct { TriggerName string TableSchema string @@ -32,7 +29,6 @@ type TriggerInfo struct { ActionStatement string } -// ConstraintInfo represents information about a PostgreSQL constraint type ConstraintInfo struct { ConstraintName 
string TableSchema string @@ -43,21 +39,6 @@ type ConstraintInfo struct { IsDeferred bool } -// SyncIndexesAndTriggers syncs indexes, triggers, and constraints from source to destination. -// This is called once during initial setup, not for on-the-fly changes. -// -// Features: -// - Syncs all non-primary-key indexes from source to destination -// - Syncs all triggers from source to destination -// - Syncs check constraints and foreign key constraints -// - Automatically syncs trigger functions if they don't exist on destination -// - Skips indexes/triggers/constraints that already exist on destination -// -// Limitations: -// - Only runs during initial setup (not for on-the-fly changes) -// - Requires trigger functions to exist on source (or will attempt to sync them) -// - Primary key indexes and constraints are skipped (already exist) -// - Foreign key constraints referencing tables not in the sync are skipped func (c *PostgresConnector) SyncIndexesAndTriggers( ctx context.Context, tableMappings []*protos.TableMapping, @@ -83,25 +64,20 @@ func (c *PostgresConnector) SyncIndexesAndTriggers( slog.String("srcTable", srcTable.String()), slog.String("dstTable", dstTable.String()), slog.Any("error", err)) - // Continue with other tables even if one fails } - // Sync triggers if err := c.syncTriggersForTable(ctx, srcTable, dstTable, sourceConn); err != nil { c.logger.Warn("Failed to sync triggers for table", slog.String("srcTable", srcTable.String()), slog.String("dstTable", dstTable.String()), slog.Any("error", err)) - // Continue with other tables even if one fails } - // Sync constraints (check constraints and foreign keys) if err := c.syncConstraintsForTable(ctx, srcTable, dstTable, sourceConn, tableMappings); err != nil { c.logger.Warn("Failed to sync constraints for table", slog.String("srcTable", srcTable.String()), slog.String("dstTable", dstTable.String()), slog.Any("error", err)) - // Continue with other tables even if one fails } } @@ -109,49 +85,39 @@ func (c *PostgresConnector) SyncIndexesAndTriggers( return nil } -// syncIndexesForTable syncs indexes for a specific table func (c *PostgresConnector) syncIndexesForTable( ctx context.Context, srcTable *utils.SchemaTable, dstTable *utils.SchemaTable, sourceConn *PostgresConnector, ) error { - // Get indexes from source srcIndexes, err := sourceConn.getIndexesForTable(ctx, srcTable) if err != nil { return fmt.Errorf("error getting source indexes: %w", err) } - // Get indexes from destination dstIndexes, err := c.getIndexesForTable(ctx, dstTable) if err != nil { return fmt.Errorf("error getting destination indexes: %w", err) } - // Create a map of destination indexes by name for quick lookup dstIndexMap := make(map[string]*IndexInfo, len(dstIndexes)) for _, idx := range dstIndexes { dstIndexMap[idx.IndexName] = idx } - // Find missing indexes and create them createdCount := 0 for _, srcIdx := range srcIndexes { - // Skip primary key indexes - they should already exist if srcIdx.IsPrimary { continue } - // Check if index already exists in destination if _, exists := dstIndexMap[srcIdx.IndexName]; exists { c.logger.Debug("Index already exists in destination", slog.String("indexName", srcIdx.IndexName), slog.String("dstTable", dstTable.String())) continue } - - // Create the index - // Replace source schema/table names with destination indexSQL := c.adaptIndexSQL(srcIdx.IndexDef, srcTable, dstTable) c.logger.Info("Creating index on destination", @@ -165,7 +131,6 @@ func (c *PostgresConnector) syncIndexesForTable( slog.String("indexName", 
srcIdx.IndexName), slog.String("indexSQL", indexSQL), slog.Any("error", err)) - // Continue with other indexes even if one fails continue } @@ -191,37 +156,21 @@ func (c *PostgresConnector) syncTriggersForTable( dstTable *utils.SchemaTable, sourceConn *PostgresConnector, ) error { - c.logger.Info("Starting trigger sync for table", - slog.String("srcTable", srcTable.String()), - slog.String("dstTable", dstTable.String())) - - // Get triggers from source srcTriggers, err := sourceConn.getTriggersForTable(ctx, srcTable) if err != nil { return fmt.Errorf("error getting source triggers: %w", err) } - c.logger.Info("Retrieved source triggers", - slog.String("srcTable", srcTable.String()), - slog.Int("triggerCount", len(srcTriggers))) - - // Get triggers from destination dstTriggers, err := c.getTriggersForTable(ctx, dstTable) if err != nil { return fmt.Errorf("error getting destination triggers: %w", err) } - c.logger.Info("Retrieved destination triggers", - slog.String("dstTable", dstTable.String()), - slog.Int("triggerCount", len(dstTriggers))) - - // Create a map of destination triggers by name for quick lookup dstTriggerMap := make(map[string]*TriggerInfo, len(dstTriggers)) for _, trig := range dstTriggers { dstTriggerMap[trig.TriggerName] = trig } - // Find missing triggers and create them createdCount := 0 for _, srcTrig := range srcTriggers { c.logger.Info("Processing source trigger", @@ -229,7 +178,6 @@ func (c *PostgresConnector) syncTriggersForTable( slog.String("triggerDef", srcTrig.TriggerDef), slog.String("srcTable", srcTable.String())) - // Check if trigger already exists in destination if _, exists := dstTriggerMap[srcTrig.TriggerName]; exists { c.logger.Info("Trigger already exists in destination, skipping", slog.String("triggerName", srcTrig.TriggerName), @@ -237,19 +185,11 @@ func (c *PostgresConnector) syncTriggersForTable( continue } - // Extract function name from trigger definition funcName, funcSchema := c.extractFunctionFromTriggerDef(srcTrig.TriggerDef) - c.logger.Info("Extracted function from trigger definition", - slog.String("triggerName", srcTrig.TriggerName), - slog.String("functionName", funcName), - slog.String("functionSchema", funcSchema)) - // Check if function exists on destination - // Try multiple schemas if function name doesn't have schema qualification funcExists := false if funcName != "" { schemasToCheck := []string{funcSchema} - // If no schema was specified, try public schema and table's schema if funcSchema == "public" || funcSchema == "" { schemasToCheck = []string{"public", dstTable.Schema, srcTable.Schema} } @@ -274,8 +214,6 @@ func (c *PostgresConnector) syncTriggersForTable( } if !funcExists { - // Try to sync the function from source - // Try multiple schemas on source to find the function sourceSchemasToCheck := []string{funcSchema, "public", srcTable.Schema} if funcSchema == "public" || funcSchema == "" { sourceSchemasToCheck = []string{"public", srcTable.Schema} @@ -289,14 +227,9 @@ func (c *PostgresConnector) syncTriggersForTable( slog.String("targetSchema", funcSchema)) if err := c.syncTriggerFunction(ctx, sourceSchema, funcName, funcSchema, sourceConn); err != nil { - c.logger.Debug("Failed to sync function from this schema, trying next", - slog.String("functionName", funcName), - slog.String("sourceSchema", sourceSchema), - slog.Any("error", err)) continue } - // Verify function was created funcExists, err = c.checkFunctionExists(ctx, funcSchema, funcName) if err == nil && funcExists { funcSynced = true @@ -318,24 +251,13 @@ func (c 
*PostgresConnector) syncTriggersForTable( } } } - - // Create the trigger - // pg_get_triggerdef already gives us the full CREATE TRIGGER statement - // We just need to replace source schema/table names with destination triggerSQL := c.adaptTriggerSQL(srcTrig.TriggerDef, srcTable, dstTable) - c.logger.Info("Creating trigger on destination", - slog.String("triggerName", srcTrig.TriggerName), - slog.String("srcTable", srcTable.String()), - slog.String("dstTable", dstTable.String()), - slog.String("triggerSQL", triggerSQL)) - if _, err := c.conn.Exec(ctx, triggerSQL); err != nil { c.logger.Error("Failed to create trigger", slog.String("triggerName", srcTrig.TriggerName), slog.String("triggerSQL", triggerSQL), slog.Any("error", err)) - // Continue with other triggers even if one fails continue } @@ -359,7 +281,6 @@ func (c *PostgresConnector) getIndexesForTable( ctx context.Context, table *utils.SchemaTable, ) ([]*IndexInfo, error) { - // Use pg_indexes view which is simpler and more reliable query := ` SELECT indexname, @@ -390,16 +311,10 @@ func (c *PostgresConnector) getIndexesForTable( return nil, fmt.Errorf("error scanning index row: %w", err) } - // Determine if index is unique or primary key idx.IsUnique = strings.Contains(strings.ToUpper(idx.IndexDef), "UNIQUE") - - // Check if it's a primary key constraint - // Primary keys are typically named like tablename_pkey idx.IsPrimary = strings.HasSuffix(idx.IndexName, "_pkey") || strings.Contains(strings.ToUpper(idx.IndexDef), "PRIMARY KEY") - // Extract column names from index definition - // This is a simple extraction - may need refinement for complex cases idx.IndexColumns = c.extractColumnsFromIndexDef(idx.IndexDef) indexes = append(indexes, &idx) @@ -414,23 +329,16 @@ func (c *PostgresConnector) getIndexesForTable( // extractColumnsFromIndexDef extracts column names from index definition func (c *PostgresConnector) extractColumnsFromIndexDef(indexDef string) []string { - // This is a simplified extraction - looks for patterns like (col1, col2) - // For more complex cases, we might need to parse the SQL properly var columns []string - // Find the part between parentheses start := strings.Index(indexDef, "(") end := strings.LastIndex(indexDef, ")") if start >= 0 && end > start { colPart := indexDef[start+1 : end] - // Split by comma and clean up parts := strings.Split(colPart, ",") for _, part := range parts { col := strings.TrimSpace(part) - // Remove function calls, operators, etc. 
- just get column name - // Remove quotes if present col = strings.Trim(col, `"'`) - // Take only the column name part (before any operators or functions) if spaceIdx := strings.Index(col, " "); spaceIdx > 0 { col = col[:spaceIdx] } @@ -448,7 +356,6 @@ func (c *PostgresConnector) getTriggersForTable( ctx context.Context, table *utils.SchemaTable, ) ([]*TriggerInfo, error) { - // Use pg_trigger and pg_proc to get full trigger definition query := ` SELECT t.tgname as trigger_name, @@ -496,9 +403,6 @@ func (c *PostgresConnector) getTriggersForTable( return nil, fmt.Errorf("error scanning trigger row: %w", err) } - // Extract action statement from trigger definition - // The trigger_def from pg_get_triggerdef already contains the full CREATE TRIGGER statement - // We just need to extract the EXECUTE FUNCTION part trig.ActionStatement = c.extractActionStatement(trig.TriggerDef) triggers = append(triggers, &trig) @@ -511,11 +415,7 @@ func (c *PostgresConnector) getTriggersForTable( return triggers, nil } -// extractActionStatement extracts the action statement (EXECUTE FUNCTION ...) from trigger definition func (c *PostgresConnector) extractActionStatement(triggerDef string) string { - // pg_get_triggerdef returns something like: - // CREATE TRIGGER trigger_name BEFORE INSERT ON schema.table FOR EACH ROW EXECUTE FUNCTION function_name() - // We want to extract the EXECUTE FUNCTION part executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE") if executeIdx >= 0 { return triggerDef[executeIdx:] @@ -523,18 +423,15 @@ func (c *PostgresConnector) extractActionStatement(triggerDef string) string { return "" } -// adaptIndexSQL adapts index SQL from source to destination table func (c *PostgresConnector) adaptIndexSQL( indexSQL string, srcTable *utils.SchemaTable, dstTable *utils.SchemaTable, ) string { - // Replace source schema.table with destination schema.table adapted := strings.ReplaceAll(indexSQL, fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) - // Also handle unquoted versions adapted = strings.ReplaceAll(adapted, fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) @@ -542,22 +439,13 @@ func (c *PostgresConnector) adaptIndexSQL( return adapted } -// extractFunctionFromTriggerDef extracts function name and schema from trigger definition func (c *PostgresConnector) extractFunctionFromTriggerDef(triggerDef string) (funcName, funcSchema string) { - // pg_get_triggerdef returns something like: - // CREATE TRIGGER trigger_name BEFORE INSERT ON schema.table FOR EACH ROW EXECUTE FUNCTION schema.function_name() - // We need to extract the function name and schema - - // Find "EXECUTE FUNCTION" or "EXECUTE PROCEDURE" executeIdx := strings.Index(strings.ToUpper(triggerDef), "EXECUTE") if executeIdx < 0 { return "", "" } - // Get the part after EXECUTE executePart := triggerDef[executeIdx:] - - // Look for FUNCTION or PROCEDURE keyword funcKeywordIdx := strings.Index(strings.ToUpper(executePart), "FUNCTION") if funcKeywordIdx < 0 { funcKeywordIdx = strings.Index(strings.ToUpper(executePart), "PROCEDURE") @@ -566,10 +454,7 @@ func (c *PostgresConnector) extractFunctionFromTriggerDef(triggerDef string) (fu } } - // Get the function part (after FUNCTION/PROCEDURE keyword) funcPart := strings.TrimSpace(executePart[funcKeywordIdx+8:]) // 8 = len("FUNCTION") or len("PROCEDURE") - - // Remove 
trailing parentheses and whitespace funcPart = strings.TrimSpace(strings.TrimSuffix(funcPart, "()")) funcPart = strings.TrimSpace(strings.TrimSuffix(funcPart, ")")) @@ -577,11 +462,9 @@ func (c *PostgresConnector) extractFunctionFromTriggerDef(triggerDef string) (fu if dotIdx := strings.LastIndex(funcPart, "."); dotIdx >= 0 { funcSchema = funcPart[:dotIdx] funcName = funcPart[dotIdx+1:] - // Remove quotes if present funcSchema = strings.Trim(funcSchema, `"'`) funcName = strings.Trim(funcName, `"'`) } else { - // No schema, function is in current schema or public funcName = strings.Trim(funcPart, `"'`) funcSchema = "public" // Default to public schema } @@ -615,7 +498,6 @@ func (c *PostgresConnector) syncTriggerFunction( sourceSchema, funcName, targetSchema string, sourceConn *PostgresConnector, ) error { - // Get function definition from source query := ` SELECT pg_get_functiondef(p.oid) as function_def FROM pg_proc p @@ -639,8 +521,6 @@ func (c *PostgresConnector) syncTriggerFunction( slog.String("sourceSchema", sourceSchema), slog.String("functionDef", funcDef)) - // Adapt function definition to use target schema if different - // Replace source schema with target schema in the function definition if sourceSchema != targetSchema { funcDef = strings.ReplaceAll(funcDef, fmt.Sprintf("%s.%s", utils.QuoteIdentifier(sourceSchema), utils.QuoteIdentifier(funcName)), @@ -650,9 +530,6 @@ func (c *PostgresConnector) syncTriggerFunction( fmt.Sprintf("%s.%s", targetSchema, funcName)) } - // Create function on destination - // The function definition from pg_get_functiondef already includes CREATE OR REPLACE FUNCTION - // We just need to execute it c.logger.Info("Creating function on destination", slog.String("functionName", funcName), slog.String("targetSchema", targetSchema)) @@ -674,12 +551,10 @@ func (c *PostgresConnector) adaptTriggerSQL( srcTable *utils.SchemaTable, dstTable *utils.SchemaTable, ) string { - // Replace source schema.table with destination schema.table adapted := strings.ReplaceAll(triggerSQL, fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) - // Also handle unquoted versions adapted = strings.ReplaceAll(adapted, fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) @@ -695,37 +570,21 @@ func (c *PostgresConnector) syncConstraintsForTable( sourceConn *PostgresConnector, tableMappings []*protos.TableMapping, ) error { - c.logger.Info("Starting constraint sync for table", - slog.String("srcTable", srcTable.String()), - slog.String("dstTable", dstTable.String())) - - // Get constraints from source srcConstraints, err := sourceConn.getConstraintsForTable(ctx, srcTable) if err != nil { return fmt.Errorf("error getting source constraints: %w", err) } - c.logger.Info("Retrieved source constraints", - slog.String("srcTable", srcTable.String()), - slog.Int("constraintCount", len(srcConstraints))) - - // Get constraints from destination dstConstraints, err := c.getConstraintsForTable(ctx, dstTable) if err != nil { return fmt.Errorf("error getting destination constraints: %w", err) } - c.logger.Info("Retrieved destination constraints", - slog.String("dstTable", dstTable.String()), - slog.Int("constraintCount", len(dstConstraints))) - - // Create a map of destination constraints by name for quick lookup dstConstraintMap := make(map[string]*ConstraintInfo, len(dstConstraints)) for _, 
constraint := range dstConstraints { dstConstraintMap[constraint.ConstraintName] = constraint } - // Build a mapping of source table names to destination table names for FK resolution tableNameMap := make(map[string]string) for _, tm := range tableMappings { src, err := utils.ParseSchemaTable(tm.SourceTableIdentifier) @@ -736,59 +595,34 @@ func (c *PostgresConnector) syncConstraintsForTable( if err != nil { continue } - // Map both qualified and unqualified names tableNameMap[src.String()] = dst.String() tableNameMap[fmt.Sprintf("%s.%s", src.Schema, src.Table)] = fmt.Sprintf("%s.%s", dst.Schema, dst.Table) } - // Find missing constraints and create them createdCount := 0 for _, srcConstraint := range srcConstraints { - c.logger.Info("Processing source constraint", - slog.String("constraintName", srcConstraint.ConstraintName), - slog.String("constraintType", srcConstraint.ConstraintType), - slog.String("constraintDef", srcConstraint.ConstraintDef), - slog.String("srcTable", srcTable.String())) - // Skip primary key constraints - they should already exist if srcConstraint.ConstraintType == "p" { - c.logger.Debug("Skipping primary key constraint", - slog.String("constraintName", srcConstraint.ConstraintName)) continue } - - // Skip unique constraints that are already covered by unique indexes if srcConstraint.ConstraintType == "u" { - c.logger.Debug("Skipping unique constraint (handled by unique index)", - slog.String("constraintName", srcConstraint.ConstraintName)) continue } - // Check if constraint already exists in destination if _, exists := dstConstraintMap[srcConstraint.ConstraintName]; exists { c.logger.Info("Constraint already exists in destination, skipping", slog.String("constraintName", srcConstraint.ConstraintName), slog.String("dstTable", dstTable.String())) continue } - - // Adapt constraint definition for destination constraintSQL := c.adaptConstraintSQL(srcConstraint.ConstraintDef, srcTable, dstTable, tableNameMap, srcConstraint.ConstraintName) - c.logger.Info("Creating constraint on destination", - slog.String("constraintName", srcConstraint.ConstraintName), - slog.String("constraintType", srcConstraint.ConstraintType), - slog.String("srcTable", srcTable.String()), - slog.String("dstTable", dstTable.String()), - slog.String("constraintSQL", constraintSQL)) - if _, err := c.conn.Exec(ctx, constraintSQL); err != nil { c.logger.Error("Failed to create constraint", slog.String("constraintName", srcConstraint.ConstraintName), slog.String("constraintType", srcConstraint.ConstraintType), slog.String("constraintSQL", constraintSQL), slog.Any("error", err)) - // Continue with other constraints even if one fails continue } @@ -813,7 +647,6 @@ func (c *PostgresConnector) getConstraintsForTable( ctx context.Context, table *utils.SchemaTable, ) ([]*ConstraintInfo, error) { - // Query constraints from pg_constraint query := ` SELECT con.conname as constraint_name, @@ -872,33 +705,18 @@ func (c *PostgresConnector) adaptConstraintSQL( tableNameMap map[string]string, constraintName string, ) string { - // The constraint definition from pg_get_constraintdef is already in the format: - // For check: CHECK (expression) - // For foreign key: FOREIGN KEY (columns) REFERENCES table(columns) - // We need to: - // 1. For foreign keys, replace referenced table names using tableNameMap FIRST - // 2. 
Then replace source table name with destination table name (for self-referencing FKs) adapted := constraintDef - - // For foreign key constraints, replace referenced table names - // Handle both cross-table and self-referencing foreign keys if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { - // First, add the current table to the tableNameMap if not already present - // This ensures self-referencing FKs are handled srcTableStr := fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table) dstTableStr := fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table) if _, exists := tableNameMap[srcTableStr]; !exists { tableNameMap[srcTableStr] = dstTableStr } - // Also add unqualified table name if _, exists := tableNameMap[srcTable.Table]; !exists { tableNameMap[srcTable.Table] = dstTable.Table } - - // Look for REFERENCES clause and replace table names for srcTableName, dstTableName := range tableNameMap { - // Handle schema-qualified table names in REFERENCES if strings.Contains(srcTableName, ".") { parts := strings.Split(srcTableName, ".") if len(parts) == 2 { @@ -906,44 +724,32 @@ func (c *PostgresConnector) adaptConstraintSQL( dstParts := strings.Split(dstTableName, ".") if len(dstParts) == 2 { dstSchema, dstTbl := dstParts[0], dstParts[1] - // Replace schema.table references (quoted) adapted = strings.ReplaceAll(adapted, fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(srcSchema), utils.QuoteIdentifier(srcTbl)), fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstSchema), utils.QuoteIdentifier(dstTbl))) - // Replace schema.table references (unquoted) adapted = strings.ReplaceAll(adapted, fmt.Sprintf("REFERENCES %s.%s", srcSchema, srcTbl), fmt.Sprintf("REFERENCES %s.%s", dstSchema, dstTbl)) } } } else { - // Handle unqualified table names in REFERENCES - // Replace unqualified table name (quoted) adapted = strings.ReplaceAll(adapted, fmt.Sprintf("REFERENCES %s", utils.QuoteIdentifier(srcTableName)), fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) - // Replace unqualified table name (unquoted) adapted = strings.ReplaceAll(adapted, fmt.Sprintf("REFERENCES %s", srcTableName), fmt.Sprintf("REFERENCES %s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTableName))) } } } - - // Replace source schema.table with destination schema.table in CHECK constraints - // (For FKs, we've already handled the REFERENCES clause above) adapted = strings.ReplaceAll(adapted, fmt.Sprintf("%s.%s", utils.QuoteIdentifier(srcTable.Schema), utils.QuoteIdentifier(srcTable.Table)), fmt.Sprintf("%s.%s", utils.QuoteIdentifier(dstTable.Schema), utils.QuoteIdentifier(dstTable.Table))) - // Also handle unquoted versions adapted = strings.ReplaceAll(adapted, fmt.Sprintf("%s.%s", srcTable.Schema, srcTable.Table), fmt.Sprintf("%s.%s", dstTable.Schema, dstTable.Table)) - // Build the full ALTER TABLE statement - // For check constraints: ALTER TABLE ... ADD CONSTRAINT ... CHECK ... - // For foreign keys: ALTER TABLE ... ADD CONSTRAINT ... FOREIGN KEY ... 
if strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "CHECK") || strings.HasPrefix(strings.ToUpper(strings.TrimSpace(adapted)), "FOREIGN KEY") { return fmt.Sprintf("ALTER TABLE %s.%s ADD CONSTRAINT %s %s", diff --git a/flow/connectors/postgres/normalize_stmt_generator.go b/flow/connectors/postgres/normalize_stmt_generator.go index 699796d593..8d49969c06 100644 --- a/flow/connectors/postgres/normalize_stmt_generator.go +++ b/flow/connectors/postgres/normalize_stmt_generator.go @@ -157,7 +157,6 @@ func (n *normalizeStmtGenerator) generateMergeStatement( quotedColumnNamesFiltered := make([]string, 0, columnCount) for _, column := range normalizedTableSchema.Columns { - // Skip PeerDB system columns - they are handled separately if systemCols[column.Name] { continue } @@ -185,7 +184,6 @@ func (n *normalizeStmtGenerator) generateMergeStatement( } updateStatementsforToastCols := n.generateUpdateStatements(quotedColumnNames, unchangedToastColumns) - // append synced_at column (system column, added separately) if n.peerdbCols.SyncedAtColName != "" { quotedColumnNames = append(quotedColumnNames, utils.QuoteIdentifier(n.peerdbCols.SyncedAtColName)) insertValuesSQLArray = append(insertValuesSQLArray, "CURRENT_TIMESTAMP") diff --git a/flow/connectors/postgres/postgres.go b/flow/connectors/postgres/postgres.go index 0548a00f6d..99200cb1e3 100644 --- a/flow/connectors/postgres/postgres.go +++ b/flow/connectors/postgres/postgres.go @@ -1264,6 +1264,11 @@ func (c *PostgresConnector) ReplayTableSchemaDeltas( return fmt.Errorf("failed to add column %s for table %s: %w", addedColumn.Name, schemaDelta.DstTableName, err) } + c.logger.Info(fmt.Sprintf("[schema delta replay] added column %s with data type %s", + addedColumn.Name, addedColumn.Type), + slog.String("srcTableName", schemaDelta.SrcTableName), + slog.String("dstTableName", schemaDelta.DstTableName), + ) } } diff --git a/flow/workflows/setup_flow.go b/flow/workflows/setup_flow.go index 8029631bb7..be27ede2f4 100644 --- a/flow/workflows/setup_flow.go +++ b/flow/workflows/setup_flow.go @@ -224,7 +224,6 @@ func (s *SetupFlowExecution) setupNormalizedTables( // Sync indexes, triggers, and constraints after tables are created if err := workflow.ExecuteActivity(ctx, flowable.SyncIndexesAndTriggers, setupConfig).Get(ctx, nil); err != nil { s.Warn("failed to sync indexes, triggers, and constraints", slog.Any("error", err)) - // Don't fail the setup if sync fails - log warning and continue } s.Info("finished setting up normalized tables for peer flow") From 8aba6ef470b9f9d0578d25540cf44e11aace2cc1 Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 19:20:02 +0530 Subject: [PATCH 4/6] remove extra comment --- flow/activities/flowable_core.go | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/flow/activities/flowable_core.go b/flow/activities/flowable_core.go index 75605ccb75..815af3ffb5 100644 --- a/flow/activities/flowable_core.go +++ b/flow/activities/flowable_core.go @@ -108,9 +108,6 @@ func (a *FlowableActivity) applySchemaDeltas( return nil } -// updateDestinationSchemaMapping updates the destination schema mapping in the catalog -// after schema deltas have been applied to the destination. This ensures normalization -// uses the correct updated schema. 
func (a *FlowableActivity) updateDestinationSchemaMapping( ctx context.Context, config *protos.FlowConnectionConfigsCore, @@ -119,7 +116,6 @@ func (a *FlowableActivity) updateDestinationSchemaMapping( ) error { logger := internal.LoggerFromCtx(ctx) - // Filter table mappings for tables that had schema changes filteredTableMappings := make([]*protos.TableMapping, 0, len(schemaDeltas)) for _, tableMapping := range options.TableMappings { if slices.ContainsFunc(schemaDeltas, func(schemaDelta *protos.TableSchemaDelta) bool { @@ -134,27 +130,19 @@ func (a *FlowableActivity) updateDestinationSchemaMapping( return nil } - // Get destination connector to fetch updated schema dstConn, dstClose, err := connectors.GetByNameAs[connectors.GetTableSchemaConnector](ctx, config.Env, a.CatalogPool, config.DestinationName) if err != nil { return fmt.Errorf("failed to get destination connector for schema update: %w", err) } defer dstClose(ctx) - logger.Info("Updating destination schema mapping after schema deltas", - slog.String("flowName", config.FlowJobName), - slog.Int("tablesAffected", len(filteredTableMappings))) - - // Fetch updated schema from destination tableNameSchemaMapping, err := dstConn.GetTableSchema(ctx, config.Env, config.Version, config.System, filteredTableMappings) if err != nil { return fmt.Errorf("failed to get updated schema from destination: %w", err) } - // Build processed schema mapping (maps destination table names to schemas) processed := internal.BuildProcessedSchemaMapping(filteredTableMappings, tableNameSchemaMapping, logger) - // Update catalog with new destination schemas tx, err := a.CatalogPool.BeginTx(ctx, pgx.TxOptions{}) if err != nil { return fmt.Errorf("failed to start transaction for schema mapping update: %w", err) @@ -176,18 +164,11 @@ func (a *FlowableActivity) updateDestinationSchemaMapping( ); err != nil { return fmt.Errorf("failed to update schema mapping for %s: %w", tableName, err) } - logger.Info("Updated destination schema mapping in catalog", - slog.String("tableName", tableName), - slog.Int("columnCount", len(tableSchema.Columns))) } if err := tx.Commit(ctx); err != nil { return fmt.Errorf("failed to commit schema mapping update: %w", err) } - - logger.Info("Successfully updated destination schema mapping", - slog.String("flowName", config.FlowJobName), - slog.Int("tablesUpdated", len(processed))) return nil } @@ -421,12 +402,10 @@ func syncCore[TPull connectors.CDCPullConnectorCore, TSync connectors.CDCSyncCon return nil, err } - // For Postgres to Postgres, update destination schema mapping after applying schema deltas if len(res.TableSchemaDeltas) > 0 { if err := a.updateDestinationSchemaMapping(ctx, config, options, res.TableSchemaDeltas); err != nil { logger.Warn("Failed to update destination schema mapping, normalization may use stale schema", slog.Any("error", err)) - // Don't fail the sync if schema mapping update fails, but log it } } From 1c69caff9d10737a67979848068484de43082069 Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 19:23:56 +0530 Subject: [PATCH 5/6] remove extra code --- flow/connectors/postgres/cdc.go | 77 +-------------------------------- 1 file changed, 1 insertion(+), 76 deletions(-) diff --git a/flow/connectors/postgres/cdc.go b/flow/connectors/postgres/cdc.go index d4ed309b26..ff6901ad58 100644 --- a/flow/connectors/postgres/cdc.go +++ b/flow/connectors/postgres/cdc.go @@ -187,7 +187,6 @@ func (pgProcessor) NewItems(size int) model.PgItems { return model.NewPgItems(size) } -// How then integer / timestamp 
type got synced for our test case func (pgProcessor) Process( items model.PgItems, p *PostgresCDCSource, @@ -223,7 +222,6 @@ func (qProcessor) NewItems(size int) model.RecordItems { return model.NewRecordItems(size) } -// what is diff b/w pgProcessor and qProcessor func (qProcessor) Process( items model.RecordItems, p *PostgresCDCSource, @@ -700,17 +698,6 @@ func PullCdcRecords[Items model.Items]( totalFetchedBytes.Add(int64(len(msg.Data))) tableName := rec.GetDestinationTableName() - // Log if this is a DML operation following a schema change - switch rec.(type) { - case *model.InsertRecord[Items], *model.UpdateRecord[Items], *model.DeleteRecord[Items]: - if schema, ok := req.TableNameSchemaMapping[tableName]; ok { - logger.Debug("Processing DML operation", - slog.String("tableName", tableName), - slog.Int("columnCount", len(schema.Columns)), - slog.Any("LSN", xld.WALStart)) - } - } - switch r := rec.(type) { case *model.UpdateRecord[Items]: // tableName here is destination tableName. @@ -811,15 +798,6 @@ func PullCdcRecords[Items model.Items]( for _, col := range tableSchemaDelta.AddedColumns { addedColNames = append(addedColNames, fmt.Sprintf("%s(%s)", col.Name, col.Type)) } - logger.Info("Processing RelationRecord with schema changes", - slog.String("srcTableName", tableSchemaDelta.SrcTableName), - slog.String("dstTableName", tableSchemaDelta.DstTableName), - slog.Int("addedColumnsCount", len(tableSchemaDelta.AddedColumns)), - slog.Any("addedColumns", addedColNames), - slog.Int("droppedColumnsCount", len(tableSchemaDelta.DroppedColumns)), - slog.Any("droppedColumns", tableSchemaDelta.DroppedColumns), - slog.Int64("checkpointID", r.CheckpointID), - slog.Uint64("transactionID", r.TransactionID)) records.AddSchemaDelta(req.TableNameMapping, tableSchemaDelta) logger.Info("Added schema delta to records stream", slog.String("srcTableName", tableSchemaDelta.SrcTableName)) @@ -877,6 +855,7 @@ func (p *PostgresCDCSource) baseRecord(lsn pglogrepl.LSN) model.BaseRecord { } } +// req *model.PullRecordsRequest[Items] because schema-change handling now needs access to the request’s table-schema cache func processMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, @@ -926,14 +905,6 @@ func processMessage[Items model.Items]( currColNames = append(currColNames, col.Name) } - logger.Info("Received RelationMessage from WAL", - slog.Uint64("RelationID", uint64(msg.RelationID)), - slog.String("Namespace", msg.Namespace), - slog.String("RelationName", msg.RelationName), - slog.Int("columnCount", len(msg.Columns)), - slog.Any("columnNames", currColNames), - slog.Any("LSN", currentClientXlogPos)) - if _, exists := p.srcTableIDNameMapping[msg.RelationID]; !exists { logger.Warn("RelationMessage received for table not in replication set, skipping", slog.Uint64("RelationID", uint64(msg.RelationID)), @@ -942,12 +913,6 @@ func processMessage[Items model.Items]( return nil, nil } - logger.Debug("Processing RelationMessage for replicated table", - slog.Uint64("RelationID", uint64(msg.RelationID)), - slog.String("Namespace", msg.Namespace), - slog.String("RelationName", msg.RelationName), - slog.Any("Columns", msg.Columns)) - return processRelationMessage[Items](ctx, p, req, currentClientXlogPos, msg) case *pglogrepl.LogicalDecodingMessage: logger.Debug("LogicalDecodingMessage", @@ -1119,15 +1084,6 @@ func processDeleteMessage[Items model.Items]( } // processRelationMessage processes a RelationMessage and returns a TableSchemaDelta -// Currently supported DDL operations: -// - ADD COLUMN (with 
default values) -// - DROP COLUMN (excluding PeerDB system columns) -// - ALTER COLUMN TYPE (column type changes) -// -// Not currently supported: -// - CREATE/DROP INDEX (indexes are not replicated via logical replication RelationMessages) -// - CREATE/DROP TRIGGER (triggers are not replicated via logical replication RelationMessages) -// - Other DDL operations not captured in RelationMessages func processRelationMessage[Items model.Items]( ctx context.Context, p *PostgresCDCSource, @@ -1191,10 +1147,6 @@ func processRelationMessage[Items model.Items]( prevRelMap[column.Name] = column.Type prevColNames = append(prevColNames, column.Name) } - p.logger.Info("Retrieved previous schema from cache", - slog.String("tableName", currRelDstInfo.Name), - slog.Int("previousColumnCount", len(prevSchema.Columns)), - slog.Any("previousColumns", prevColNames)) currRelMap := make(map[string]string, len(currRel.Columns)) for _, column := range currRel.Columns { @@ -1325,7 +1277,6 @@ func processRelationMessage[Items model.Items]( continue } - // only add to delta if not excluded if _, ok := p.tableNameMapping[p.srcTableIDNameMapping[currRel.RelationID]].Exclude[column.Name]; !ok { schemaDelta.DroppedColumns = append(schemaDelta.DroppedColumns, column.Name) p.logger.Info("Detected dropped column", @@ -1341,7 +1292,6 @@ func processRelationMessage[Items model.Items]( } } - // Log summary of detected schema changes if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) for _, col := range schemaDelta.AddedColumns { @@ -1363,15 +1313,7 @@ func processRelationMessage[Items model.Items]( slog.Any("LSN", lsn)) } - // Update relationMessageMapping IMMEDIATELY so DML operations that follow - // in the same WAL stream use the updated schema p.relationMessageMapping[currRel.RelationID] = currRel - p.logger.Info("Updated relationMessageMapping with new schema", - slog.String("tableName", currRelName), - slog.Int("columnCount", len(currRel.Columns)), - slog.Any("LSN", lsn)) - - // Fetch default values and nullable info for added columns from pg_catalog if len(schemaDelta.AddedColumns) > 0 { addedColNames := make([]string, 0, len(schemaDelta.AddedColumns)) for _, col := range schemaDelta.AddedColumns { @@ -1424,11 +1366,6 @@ func processRelationMessage[Items model.Items]( if info.columnDefault.Valid && info.columnDefault.String != "" { // Store default value - we'll use it in the ADD COLUMN statement column.DefaultValue = info.columnDefault.String - p.logger.Info("Detected column with default value", - slog.String("columnName", column.Name), - slog.String("defaultValue", info.columnDefault.String), - slog.Bool("nullable", column.Nullable), - slog.String("tableName", schemaDelta.SrcTableName)) } else { p.logger.Info("Detected column without default value", slog.String("columnName", column.Name), @@ -1439,10 +1376,7 @@ func processRelationMessage[Items model.Items]( } } - // Update the cached schema mapping after detecting changes - // This ensures the next comparison uses the updated schema if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { - // Create updated schema with new columns added, dropped columns removed, and type changes applied updatedSchema := &protos.TableSchema{ Columns: make([]*protos.FieldDescription, 0, len(prevSchema.Columns)), System: prevSchema.System, @@ -1450,7 +1384,6 @@ func 
processRelationMessage[Items model.Items]( NullableEnabled: prevSchema.NullableEnabled, } - // Build maps for efficient lookup droppedColsMap := make(map[string]bool, len(schemaDelta.DroppedColumns)) for _, droppedCol := range schemaDelta.DroppedColumns { droppedColsMap[droppedCol] = true @@ -1460,13 +1393,9 @@ func processRelationMessage[Items model.Items]( for _, typeChange := range schemaDelta.TypeChangedColumns { typeChangedColsMap[typeChange.ColumnName] = typeChange } - - // Add existing columns that weren't dropped, updating types if changed for _, col := range prevSchema.Columns { if !droppedColsMap[col.Name] { - // Check if this column's type changed if typeChange, ok := typeChangedColsMap[col.Name]; ok { - // Update the column with new type updatedCol := &protos.FieldDescription{ Name: col.Name, Type: typeChange.NewType, @@ -1481,13 +1410,10 @@ func processRelationMessage[Items model.Items]( } } - // Add newly added columns for _, addedCol := range schemaDelta.AddedColumns { updatedSchema.Columns = append(updatedSchema.Columns, addedCol) } - // Update the cached schema mapping in both places - // This ensures DML operations that follow in the same WAL stream use updated schema p.tableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema if req != nil && req.TableNameSchemaMapping != nil { req.TableNameSchemaMapping[currRelDstInfo.Name] = updatedSchema @@ -1502,7 +1428,6 @@ func processRelationMessage[Items model.Items]( slog.Int("totalColumns", len(updatedSchema.Columns))) } - // Return RelationRecord if there are any schema changes (added, dropped, or type changed) if len(schemaDelta.AddedColumns) > 0 || len(schemaDelta.DroppedColumns) > 0 || len(schemaDelta.TypeChangedColumns) > 0 { p.logger.Info("Returning RelationRecord with schema delta", slog.String("srcTableName", schemaDelta.SrcTableName), From 10eba32d91b7db908b0682d5b4bf8c83d367ce7c Mon Sep 17 00:00:00 2001 From: ankitsheoran1 Date: Mon, 17 Nov 2025 20:08:56 +0530 Subject: [PATCH 6/6] linter fix --- flow/activities/flowable.go | 10 +- .../postgres/index_trigger_sync_test.go | 117 ------------------ flow/e2e/postgres_test.go | 19 --- 3 files changed, 1 insertion(+), 145 deletions(-) diff --git a/flow/activities/flowable.go b/flow/activities/flowable.go index 40bc6e384b..8033c5572f 100644 --- a/flow/activities/flowable.go +++ b/flow/activities/flowable.go @@ -282,7 +282,7 @@ func (a *FlowableActivity) CreateNormalizedTable( } // SyncIndexesAndTriggers syncs indexes and triggers from source to destination -// This is called once during initial setup, not for on-the-fly changes +// This is called once during initial setup, not for on-the-fly changes; that's why it is added into the setup flow func (a *FlowableActivity) SyncIndexesAndTriggers( ctx context.Context, config *protos.SetupNormalizedTableBatchInput, @@ -290,8 +290,6 @@ func (a *FlowableActivity) SyncIndexesAndTriggers( logger := internal.LoggerFromCtx(ctx) ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName) - // Only sync for Postgres to Postgres - // Check if destination is Postgres dstConn, dstClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName) if err != nil { if errors.Is(err, errors.ErrUnsupported) { @@ -302,14 +300,11 @@ func (a *FlowableActivity) SyncIndexesAndTriggers( } defer dstClose(ctx) - // Check if destination connector is Postgres pgDstConn, ok := dstConn.(*connpostgres.PostgresConnector) if !ok { logger.Info("Destination is not Postgres, skipping index/trigger sync") return
nil } - - // Get source connector (use SourcePeerName if available, otherwise skip) if config.SourcePeerName == "" { logger.Info("Source peer name not provided, skipping index/trigger sync") return nil @@ -321,14 +316,11 @@ func (a *FlowableActivity) SyncIndexesAndTriggers( } defer srcClose(ctx) - // Check if source connector is Postgres pgSrcConn, ok := srcConn.(*connpostgres.PostgresConnector) if !ok { logger.Info("Source is not Postgres, skipping index/trigger sync") return nil } - - // Sync indexes, triggers, and constraints a.Alerter.LogFlowInfo(ctx, config.FlowName, "Syncing indexes, triggers, and constraints from source to destination") if err := pgDstConn.SyncIndexesAndTriggers(ctx, config.TableMappings, pgSrcConn); err != nil { return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to sync indexes, triggers, and constraints: %w", err)) diff --git a/flow/connectors/postgres/index_trigger_sync_test.go b/flow/connectors/postgres/index_trigger_sync_test.go index 4e74ed1b4f..e7919a2e69 100644 --- a/flow/connectors/postgres/index_trigger_sync_test.go +++ b/flow/connectors/postgres/index_trigger_sync_test.go @@ -26,16 +26,10 @@ type IndexTriggerSyncTestSuite struct { func SetupIndexTriggerSyncSuite(t *testing.T) IndexTriggerSyncTestSuite { t.Helper() - - // Create source connector sourceConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) require.NoError(t, err) - - // Create destination connector (can be same DB for testing) destConn, err := NewPostgresConnector(t.Context(), nil, internal.GetCatalogPostgresConfigFromEnv(t.Context())) require.NoError(t, err) - - // Create test schemas sourceSchema := "src_idx_" + strings.ToLower(shared.RandomString(8)) destSchema := "dst_idx_" + strings.ToLower(shared.RandomString(8)) @@ -89,8 +83,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { ctx := s.t.Context() sourceTable := s.sourceSchema + ".test_table" destTable := s.destSchema + ".test_table" - - // Create source table with indexes _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -99,8 +91,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { created_at TIMESTAMPTZ DEFAULT NOW() )`, sourceTable)) require.NoError(s.t, err) - - // Create indexes on source _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable)) require.NoError(s.t, err) @@ -109,11 +99,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_created_at ON %s(created_at DESC)", sourceTable)) require.NoError(s.t, err) - - // Create destination table (without indexes) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state - // for unit testing SyncIndexesAndTriggers in isolation. 
_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -123,7 +108,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { )`, destTable)) require.NoError(s.t, err) - // Sync indexes tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: sourceTable, @@ -133,20 +117,14 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexes() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify indexes were created on destination destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) require.NoError(s.t, err) - - // Should have primary key + 3 indexes = 4 total (excluding primary key from sync) indexNames := make(map[string]bool) for _, idx := range destIndexes { indexNames[idx.IndexName] = true } - - // Check that our indexes exist (primary key is created automatically) require.True(s.t, indexNames["test_table_pkey"], "Primary key should exist") require.True(s.t, indexNames["idx_name"], "idx_name should be synced") require.True(s.t, indexNames["idx_email"], "idx_email should be synced") @@ -157,8 +135,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { ctx := s.t.Context() sourceTable := s.sourceSchema + ".test_trigger_table" destTable := s.destSchema + ".test_trigger_table" - - // Create a function for the trigger _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE OR REPLACE FUNCTION %s.update_timestamp() RETURNS TRIGGER AS $$ @@ -168,8 +144,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { END; $$ LANGUAGE plpgsql`, s.sourceSchema)) require.NoError(s.t, err) - - // Create the same function in destination schema _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE OR REPLACE FUNCTION %s.update_timestamp() RETURNS TRIGGER AS $$ @@ -179,8 +153,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { END; $$ LANGUAGE plpgsql`, s.destSchema)) require.NoError(s.t, err) - - // Create source table with trigger _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -195,10 +167,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { FOR EACH ROW EXECUTE FUNCTION %s.update_timestamp()`, sourceTable, s.sourceSchema)) require.NoError(s.t, err) - - // Create destination table (without trigger) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -206,8 +174,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { updated_at TIMESTAMPTZ DEFAULT NOW() )`, destTable)) require.NoError(s.t, err) - - // Sync triggers tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: sourceTable, @@ -217,14 +183,10 @@ func (s IndexTriggerSyncTestSuite) TestSyncTriggers() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify trigger was created on destination destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) require.NoError(s.t, err) - - // Should have 1 trigger require.Len(s.t, destTriggers, 1, "Should have 1 trigger") require.Equal(s.t, "update_updated_at", destTriggers[0].TriggerName) } @@ -254,8 +216,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { END; $$ LANGUAGE plpgsql`, s.destSchema)) require.NoError(s.t, err) - - // Create source table with index and trigger _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -272,18 +232,12 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { FOR EACH ROW EXECUTE FUNCTION %s.log_changes()`, sourceTable, s.sourceSchema)) require.NoError(s.t, err) - - // Create destination table (without index and trigger) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, data TEXT )`, destTable)) require.NoError(s.t, err) - - // Sync both indexes and triggers tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: sourceTable, @@ -293,8 +247,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify index was created destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) @@ -305,8 +257,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncIndexesAndTriggersTogether() { indexNames[idx.IndexName] = true } require.True(s.t, indexNames["idx_data"], "idx_data should be synced") - - // Verify trigger was created destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) require.NoError(s.t, err) @@ -318,8 +268,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { ctx := s.t.Context() sourceTable := s.sourceSchema + ".test_check_table" destTable := s.destSchema + ".test_check_table" - - // Create source table with check constraints _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -328,8 +276,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { email TEXT )`, sourceTable)) require.NoError(s.t, err) - - // Add check constraints _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` ALTER TABLE %s ADD CONSTRAINT check_name_length CHECK (char_length(name) >= 3)`, sourceTable)) @@ -344,10 +290,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { ALTER TABLE %s ADD CONSTRAINT check_email_format CHECK (email IS NULL OR email LIKE '%%@%%')`, sourceTable)) require.NoError(s.t, err) - - // Create destination 
table (without constraints) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. _, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -357,7 +299,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { )`, destTable)) require.NoError(s.t, err) - // Sync constraints tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: sourceTable, @@ -367,14 +308,10 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify constraints were created on destination destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) require.NoError(s.t, err) - - // Should have 3 check constraints constraintNames := make(map[string]bool) for _, constraint := range destConstraints { if constraint.ConstraintType == "c" { @@ -389,8 +326,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncCheckConstraints() { func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { ctx := s.t.Context() - - // Create parent table on both source and destination parentSourceTable := s.sourceSchema + ".parent_table" parentDestTable := s.destSchema + ".parent_table" @@ -407,8 +342,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { name TEXT )`, parentDestTable)) require.NoError(s.t, err) - - // Create child table on source with foreign key childSourceTable := s.sourceSchema + ".child_table" childDestTable := s.destSchema + ".child_table" @@ -419,16 +352,10 @@ func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { name TEXT )`, childSourceTable)) require.NoError(s.t, err) - - // Add foreign key constraint _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` ALTER TABLE %s ADD CONSTRAINT fk_parent FOREIGN KEY (parent_id) REFERENCES %s(id)`, childSourceTable, parentSourceTable)) require.NoError(s.t, err) - - // Create child table on destination (without foreign key) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -436,8 +363,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { name TEXT )`, childDestTable)) require.NoError(s.t, err) - - // Sync constraints (both tables need to be in the mapping for FK to work) tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: parentSourceTable, @@ -451,19 +376,14 @@ func (s IndexTriggerSyncTestSuite) TestSyncForeignKeyConstraints() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify foreign key constraint was created on destination childDestTableParsed, err := utils.ParseSchemaTable(childDestTable) require.NoError(s.t, err) destConstraints, err := s.destConn.getConstraintsForTable(ctx, childDestTableParsed) require.NoError(s.t, err) - - // Should have 1 foreign key constraint fkFound := false for _, constraint := range destConstraints { if constraint.ConstraintType == "f" && constraint.ConstraintName == "fk_parent" { fkFound = true - // Verify the constraint definition references the correct destination table require.Contains(s.t, constraint.ConstraintDef, parentDestTable, "Foreign key should reference destination parent table") break @@ -476,9 +396,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { ctx := s.t.Context() sourceTable := s.sourceSchema + ".test_constraints_table" destTable := s.destSchema + ".test_constraints_table" - - // Create source table with both check and foreign key constraints - // First create a referenced table refSourceTable := s.sourceSchema + ".ref_table" refDestTable := s.destSchema + ".ref_table" @@ -495,8 +412,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { code TEXT UNIQUE )`, refDestTable)) require.NoError(s.t, err) - - // Create main table with constraints _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -505,22 +420,14 @@ func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { status TEXT )`, sourceTable)) require.NoError(s.t, err) - - // Add check constraint _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` ALTER TABLE %s ADD CONSTRAINT check_name_length CHECK (char_length(name) >= 2)`, sourceTable)) require.NoError(s.t, err) - - // Add foreign key constraint _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` ALTER TABLE %s ADD CONSTRAINT fk_ref FOREIGN KEY (ref_id) REFERENCES %s(id)`, sourceTable, refSourceTable)) require.NoError(s.t, err) - - // Create destination table (without constraints) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -529,8 +436,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { status TEXT )`, destTable)) require.NoError(s.t, err) - - // Sync constraints (both tables in mapping) tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: refSourceTable, @@ -544,8 +449,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncConstraintsTogether() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify constraints were created destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) @@ -564,8 +467,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { ctx := s.t.Context() sourceTable := s.sourceSchema + ".test_all_table" destTable := s.destSchema + ".test_all_table" - - // Create function for trigger _, err := s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE OR REPLACE FUNCTION %s.audit_func() RETURNS TRIGGER AS $$ @@ -583,8 +484,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { END; $$ LANGUAGE plpgsql`, s.destSchema)) require.NoError(s.t, err) - - // Create source table with index, trigger, and constraints _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -593,12 +492,8 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { created_at TIMESTAMPTZ DEFAULT NOW() )`, sourceTable)) require.NoError(s.t, err) - - // Add index _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf("CREATE INDEX idx_name ON %s(name)", sourceTable)) require.NoError(s.t, err) - - // Add trigger _, err = s.sourceConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TRIGGER audit_trigger AFTER INSERT ON %s @@ -611,10 +506,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { ALTER TABLE %s ADD CONSTRAINT check_name_length CHECK (char_length(name) >= 3)`, sourceTable)) require.NoError(s.t, err) - - // Create destination table (without index, trigger, or constraints) - // Note: In the real PeerDB flow, CreateNormalizedTable creates the destination table automatically - // before SyncIndexesAndTriggers is called. Here we create it manually to simulate that state. 
_, err = s.destConn.conn.Exec(ctx, fmt.Sprintf(` CREATE TABLE %s ( id INT PRIMARY KEY, @@ -623,8 +514,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { created_at TIMESTAMPTZ DEFAULT NOW() )`, destTable)) require.NoError(s.t, err) - - // Sync everything tableMappings := []*protos.TableMapping{ { SourceTableIdentifier: sourceTable, @@ -634,8 +523,6 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { err = s.destConn.SyncIndexesAndTriggers(ctx, tableMappings, s.sourceConn) require.NoError(s.t, err) - - // Verify index was created destTableParsed, err := utils.ParseSchemaTable(destTable) require.NoError(s.t, err) destIndexes, err := s.destConn.getIndexesForTable(ctx, destTableParsed) @@ -646,14 +533,10 @@ func (s IndexTriggerSyncTestSuite) TestSyncAllTogether() { indexNames[idx.IndexName] = true } require.True(s.t, indexNames["idx_name"], "idx_name should be synced") - - // Verify trigger was created destTriggers, err := s.destConn.getTriggersForTable(ctx, destTableParsed) require.NoError(s.t, err) require.Len(s.t, destTriggers, 1, "Should have 1 trigger") require.Equal(s.t, "audit_trigger", destTriggers[0].TriggerName) - - // Verify constraint was created destConstraints, err := s.destConn.getConstraintsForTable(ctx, destTableParsed) require.NoError(s.t, err) diff --git a/flow/e2e/postgres_test.go b/flow/e2e/postgres_test.go index 16b1b952ea..1490e30c70 100644 --- a/flow/e2e/postgres_test.go +++ b/flow/e2e/postgres_test.go @@ -1287,15 +1287,12 @@ func (s PeerFlowE2ETestSuitePG) TestResync(tableName string) { RequireEnvCanceled(s.t, env) } -// Test_Indexes_Triggers_Constraints_PG tests that indexes, triggers, and constraints -// are synced from source to destination during initial setup func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { tc := NewTemporalClient(s.t) srcTableName := s.attachSchemaSuffix("test_idx_trig_const") dstTableName := s.attachSchemaSuffix("test_idx_trig_const_dst") - // Create a trigger function on source _, err := s.Conn().Exec(s.t.Context(), fmt.Sprintf(` CREATE OR REPLACE FUNCTION %s.update_timestamp() RETURNS TRIGGER AS $$ @@ -1306,7 +1303,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { $$ LANGUAGE plpgsql`, Schema(s))) require.NoError(s.t, err) - // Create source table with indexes, triggers, and constraints _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` CREATE TABLE IF NOT EXISTS %s ( id SERIAL PRIMARY KEY, @@ -1317,8 +1313,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { updated_at TIMESTAMPTZ DEFAULT NOW() )`, srcTableName)) require.NoError(s.t, err) - - // Create indexes on source _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` CREATE INDEX idx_name ON %s(name)`, srcTableName)) require.NoError(s.t, err) @@ -1330,16 +1324,12 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` CREATE INDEX idx_created_at ON %s(created_at DESC)`, srcTableName)) require.NoError(s.t, err) - - // Create trigger on source _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` CREATE TRIGGER update_updated_at BEFORE UPDATE ON %s FOR EACH ROW EXECUTE FUNCTION %s.update_timestamp()`, srcTableName, Schema(s))) require.NoError(s.t, err) - - // Create check constraints on source _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` ALTER TABLE %s ADD CONSTRAINT check_name_length CHECK (char_length(name) >= 3)`, srcTableName)) @@ -1367,10 +1357,7 @@ func (s PeerFlowE2ETestSuitePG) 
Test_Indexes_Triggers_Constraints_PG() { env := ExecutePeerflow(s.t, tc, flowConnConfig) SetupCDCFlowStatusQuery(s.t, env, flowConnConfig) - - // Wait for initial setup to complete (this includes index/trigger/constraint sync) EnvWaitFor(s.t, env, 3*time.Minute, "waiting for initial setup", func() bool { - // Check if destination table exists var exists bool err := s.Conn().QueryRow(s.t.Context(), ` SELECT EXISTS ( @@ -1379,8 +1366,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { )`, Schema(s), "test_idx_trig_const_dst").Scan(&exists) return err == nil && exists }) - - // Verify indexes were synced s.t.Log("Verifying indexes were synced...") indexQuery := ` SELECT indexname @@ -1403,7 +1388,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { require.True(s.t, indexNames["idx_email"], "idx_email should be synced") require.True(s.t, indexNames["idx_created_at"], "idx_created_at should be synced") - // Verify trigger was synced s.t.Log("Verifying trigger was synced...") triggerQuery := ` SELECT t.tgname @@ -1419,7 +1403,6 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { require.NoError(s.t, err) require.Equal(s.t, "update_updated_at", triggerName, "update_updated_at trigger should be synced") - // Verify constraints were synced s.t.Log("Verifying constraints were synced...") constraintQuery := ` SELECT con.conname @@ -1448,12 +1431,10 @@ func (s PeerFlowE2ETestSuitePG) Test_Indexes_Triggers_Constraints_PG() { s.t.Log("All indexes, triggers, and constraints were successfully synced!") - // Insert some data to verify the trigger works _, err = s.Conn().Exec(s.t.Context(), fmt.Sprintf(` INSERT INTO %s(name, email, age) VALUES ('test', 'test@example.com', 25)`, srcTableName)) EnvNoError(s.t, env, err) - // Wait for data to sync EnvWaitForEqualTablesWithNames(env, s, "waiting for data sync", srcTableName, dstTableName, "id,name,email,age") env.Cancel(s.t.Context())