flow/activities/flowable.go — 309 additions, 0 deletions
@@ -281,6 +281,315 @@ func (a *FlowableActivity) CreateNormalizedTable(
}, nil
}

// CreateNormalizedTableIndexes creates indexes on normalized tables (Postgres-specific).
func (a *FlowableActivity) CreateNormalizedTableIndexes(
ctx context.Context,
config *protos.SetupNormalizedTableBatchInput,
) error {
numTablesProcessed := atomic.Uint32{}
numTablesToProcess := atomic.Int32{}

shutdown := heartbeatRoutine(ctx, func() string {
return fmt.Sprintf("creating indexes - %d of %d tables processed", numIndexesCreated.Load(), numTablesToProcess.Load())
})
defer shutdown()

ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
logger := internal.LoggerFromCtx(ctx)

conn, connClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName)
if err != nil {
return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get connector: %w", err))
}
defer connClose(ctx)

// Check if this is a Postgres connector - only Postgres supports index migration
pgConn, ok := conn.(*connpostgres.PostgresConnector)
if !ok {
logger.Info("Connector does not support index migration, skipping",
slog.String("flowName", config.FlowName))
return nil
}

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, config.FlowName)
if err != nil {
return err
}

numTablesToProcess.Store(int32(len(config.TableMappings)))

// Process each table's indexes in a separate transaction for fault isolation
for _, tableMapping := range config.TableMappings {
tableIdentifier := tableMapping.DestinationTableIdentifier
tableSchema, ok := tableNameSchemaMapping[tableIdentifier]
if !ok {
	logger.Warn("No schema found for table, skipping index creation", slog.String("table", tableIdentifier))
	numTablesProcessed.Add(1)
	continue
}

if len(tableSchema.Indexes) == 0 {
	logger.Info("No indexes to create for table", slog.String("table", tableIdentifier))
	numTablesProcessed.Add(1)
	continue
}

// Each table gets its own transaction for index creation
tx, err := pgConn.Conn().Begin(ctx)
if err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to start transaction for index creation on table %s: %w", tableIdentifier, err))
continue
}

parsedTable, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
shared.RollbackTx(tx, logger)
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to parse table identifier %s: %w", tableIdentifier, err))
continue
}

logger.Info("Creating indexes on destination table",
slog.String("table", tableIdentifier),
slog.Int("indexCount", len(tableSchema.Indexes)))

if err := pgConn.CreateTableIndexesFromSchema(ctx, tx, tableSchema, parsedTable); err != nil {
shared.RollbackTx(tx, logger)
// Log warning but don't fail - indexes are optional
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to create indexes for table %s: %w", tableIdentifier, err))
} else {
if err := tx.Commit(ctx); err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to commit index creation for table %s: %w", tableIdentifier, err))
} else {
logger.Info("Successfully created indexes for table", slog.String("table", tableIdentifier))
}
}

numTablesProcessed.Add(1)
}

a.Alerter.LogFlowInfo(ctx, config.FlowName, "Index creation completed for all tables")
return nil
}
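
// Hedged sketch, not the PR's implementation: CreateTableIndexesFromSchema is
// defined elsewhere in this change; based on its call site above, it plausibly
// replays each captured index definition inside the caller's per-table
// transaction. The IndexDescription field names (IndexName, IndexDefinition)
// are assumptions made for this illustration only.
func createIndexesSketch(
	ctx context.Context,
	tx pgx.Tx,
	tableSchema *protos.TableSchema,
	table *utils.SchemaTable,
) error {
	for _, idx := range tableSchema.Indexes {
		// Replaying the stored DDL inside the per-table transaction keeps index
		// creation atomic: either all of a table's indexes land or none do.
		if _, err := tx.Exec(ctx, idx.IndexDefinition); err != nil {
			return fmt.Errorf("creating index %s on %s: %w", idx.IndexName, table, err)
		}
	}
	return nil
}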

// CreateNormalizedTableFunctions creates functions on normalized table schemas (Postgres-specific).
func (a *FlowableActivity) CreateNormalizedTableFunctions(
ctx context.Context,
config *protos.SetupNormalizedTableBatchInput,
) error {
numSchemasProcessed := atomic.Uint32{}

shutdown := heartbeatRoutine(ctx, func() string {
return fmt.Sprintf("creating functions - %d schemas processed", numSchemasProcessed.Load())
})
defer shutdown()

ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
logger := internal.LoggerFromCtx(ctx)

conn, connClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName)
if err != nil {
return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get connector: %w", err))
}
defer connClose(ctx)

// Check if this is a Postgres connector - only Postgres supports function migration
pgConn, ok := conn.(*connpostgres.PostgresConnector)
if !ok {
logger.Info("Connector does not support function migration, skipping",
slog.String("flowName", config.FlowName))
return nil
}

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, config.FlowName)
if err != nil {
return err
}

// Group tables by schema since functions are schema-level objects
schemaToTables := make(map[string][]*protos.TableSchema)
for _, tableSchema := range tableNameSchemaMapping {
parsedTable, err := utils.ParseSchemaTable(tableSchema.TableIdentifier)
if err != nil {
logger.Warn("Failed to parse table identifier",
slog.String("table", tableSchema.TableIdentifier),
slog.Any("error", err))
continue
}
schemaToTables[parsedTable.Schema] = append(schemaToTables[parsedTable.Schema], tableSchema)
}

// Process each schema's functions in a separate transaction
for schemaName, tables := range schemaToTables {
// Collect all functions from all tables in this schema
allFunctions := make(map[string]*protos.FunctionDescription)
var representativeTable *protos.TableSchema

for _, tableSchema := range tables {
if representativeTable == nil {
representativeTable = tableSchema
}
for _, fn := range tableSchema.Functions {
// Use function name as key to deduplicate (same function might be on multiple tables)
allFunctions[fn.FunctionName] = fn
}
}

if len(allFunctions) == 0 {
logger.Info("No functions to create for schema", slog.String("schema", schemaName))
numSchemasProcessed.Add(1)
continue
}

// Build a synthetic TableSchema carrying the deduplicated functions for this schema
schemaFunctions := &protos.TableSchema{
TableIdentifier: representativeTable.TableIdentifier,
Functions: make([]*protos.FunctionDescription, 0, len(allFunctions)),
}
for _, fn := range allFunctions {
schemaFunctions.Functions = append(schemaFunctions.Functions, fn)
}

// Each schema gets its own transaction for function creation
tx, err := pgConn.Conn().Begin(ctx)
if err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to start transaction for function creation on schema %s: %w", schemaName, err))
continue
}

parsedTable, err := utils.ParseSchemaTable(schemaFunctions.TableIdentifier)
if err != nil {
shared.RollbackTx(tx, logger)
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to parse table identifier for schema %s: %w", schemaName, err))
continue
}

logger.Info("Creating functions on destination schema",
slog.String("schema", schemaName),
slog.Int("functionCount", len(schemaFunctions.Functions)))

createdFunctions := pgConn.CreateTableFunctionsAndTrack(ctx, tx, schemaFunctions, parsedTable)

successCount := 0
for _, success := range createdFunctions {
if success {
successCount++
}
}

if err := tx.Commit(ctx); err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to commit function creation for schema %s: %w", schemaName, err))
} else {
logger.Info("Function creation summary",
slog.String("schema", schemaName),
slog.Int("total", len(schemaFunctions.Functions)),
slog.Int("successful", successCount),
slog.Int("failed", len(schemaFunctions.Functions)-successCount))
}

numSchemasProcessed.Add(1)
}

a.Alerter.LogFlowInfo(ctx, config.FlowName, "Function creation completed for all schemas")
return nil
}
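
// Hedged sketch, not the PR's implementation: CreateTableFunctionsAndTrack is
// defined elsewhere in this change; its call site above implies it attempts each
// CREATE FUNCTION and reports per-function success rather than failing the whole
// batch. One way to do that safely inside a single transaction is a savepoint per
// statement, since a failed statement would otherwise abort the surrounding tx.
// The FunctionDefinition field name is an assumption for this illustration.
func createFunctionsSketch(
	ctx context.Context,
	tx pgx.Tx,
	tableSchema *protos.TableSchema,
	logger *slog.Logger,
) map[string]bool {
	created := make(map[string]bool, len(tableSchema.Functions))
	for _, fn := range tableSchema.Functions {
		// A savepoint isolates each CREATE FUNCTION so one failure does not
		// poison the transaction for the remaining functions.
		if _, err := tx.Exec(ctx, "SAVEPOINT create_fn"); err != nil {
			logger.Warn("failed to create savepoint", slog.Any("error", err))
			created[fn.FunctionName] = false
			continue
		}
		if _, err := tx.Exec(ctx, fn.FunctionDefinition); err != nil {
			logger.Warn("failed to create function",
				slog.String("function", fn.FunctionName), slog.Any("error", err))
			_, _ = tx.Exec(ctx, "ROLLBACK TO SAVEPOINT create_fn")
			created[fn.FunctionName] = false
			continue
		}
		_, _ = tx.Exec(ctx, "RELEASE SAVEPOINT create_fn")
		created[fn.FunctionName] = true
	}
	return created
}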

// CreateNormalizedTableTriggers creates triggers on normalized tables (Postgres-specific).
// It must run AFTER CreateNormalizedTableFunctions since triggers may depend on functions.
func (a *FlowableActivity) CreateNormalizedTableTriggers(
ctx context.Context,
config *protos.SetupNormalizedTableBatchInput,
) error {
numTablesProcessed := atomic.Uint32{}
numTablesToProcess := atomic.Int32{}

shutdown := heartbeatRoutine(ctx, func() string {
return fmt.Sprintf("creating triggers - %d of %d tables processed", numTriggersCreated.Load(), numTablesToProcess.Load())
})
defer shutdown()

ctx = context.WithValue(ctx, shared.FlowNameKey, config.FlowName)
logger := internal.LoggerFromCtx(ctx)

conn, connClose, err := connectors.GetByNameAs[connectors.NormalizedTablesConnector](ctx, config.Env, a.CatalogPool, config.PeerName)
if err != nil {
return a.Alerter.LogFlowError(ctx, config.FlowName, fmt.Errorf("failed to get connector: %w", err))
}
defer connClose(ctx)

// Check if this is a Postgres connector - only Postgres supports trigger migration
pgConn, ok := conn.(*connpostgres.PostgresConnector)
if !ok {
logger.Info("Connector does not support trigger migration, skipping",
slog.String("flowName", config.FlowName))
return nil
}

tableNameSchemaMapping, err := a.getTableNameSchemaMapping(ctx, config.FlowName)
if err != nil {
return err
}

numTablesToProcess.Store(int32(len(config.TableMappings)))

// Build a map of function names captured from the source schemas; triggers
// referencing a function outside this set are skipped by the dependency check
createdFunctions := make(map[string]bool)
for _, tableSchema := range tableNameSchemaMapping {
for _, fn := range tableSchema.Functions {
createdFunctions[fn.FunctionName] = true
}
}

for _, tableMapping := range config.TableMappings {
tableIdentifier := tableMapping.DestinationTableIdentifier
tableSchema, ok := tableNameSchemaMapping[tableIdentifier]
if !ok {
	logger.Warn("No schema found for table, skipping trigger creation", slog.String("table", tableIdentifier))
	numTablesProcessed.Add(1)
	continue
}

if len(tableSchema.Triggers) == 0 {
	logger.Info("No triggers to create for table", slog.String("table", tableIdentifier))
	numTablesProcessed.Add(1)
	continue
}

tx, err := pgConn.Conn().Begin(ctx)
if err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to start transaction for trigger creation on table %s: %w", tableIdentifier, err))
continue
}

parsedTable, err := utils.ParseSchemaTable(tableIdentifier)
if err != nil {
shared.RollbackTx(tx, logger)
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to parse table identifier %s: %w", tableIdentifier, err))
continue
}

logger.Info("Creating triggers on destination table",
slog.String("table", tableIdentifier),
slog.Int("triggerCount", len(tableSchema.Triggers)))

if err := pgConn.CreateTableTriggersWithDependencyCheck(ctx, tx, tableSchema, parsedTable, createdFunctions); err != nil {
shared.RollbackTx(tx, logger)
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to create triggers for table %s: %w", tableIdentifier, err))
} else {
if err := tx.Commit(ctx); err != nil {
a.Alerter.LogFlowWarning(ctx, config.FlowName,
fmt.Errorf("failed to commit trigger creation for table %s: %w", tableIdentifier, err))
} else {
logger.Info("Successfully created triggers for table", slog.String("table", tableIdentifier))
}
}

numTablesProcessed.Add(1)
}

a.Alerter.LogFlowInfo(ctx, config.FlowName, "Trigger creation completed for all tables")
return nil
}
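
// Hedged sketch, not the PR's implementation: CreateTableTriggersWithDependencyCheck
// is defined elsewhere in this change; its signature implies it consults the
// createdFunctions set and skips triggers whose backing function was never
// captured, since CREATE TRIGGER would otherwise fail at the destination. The
// trigger field names (TriggerName, FunctionName, TriggerDefinition) are
// assumptions for this illustration.
func createTriggersSketch(
	ctx context.Context,
	tx pgx.Tx,
	tableSchema *protos.TableSchema,
	createdFunctions map[string]bool,
	logger *slog.Logger,
) error {
	for _, trg := range tableSchema.Triggers {
		if !createdFunctions[trg.FunctionName] {
			// Dependency check: skip rather than fail the whole table's batch.
			logger.Warn("skipping trigger whose function was not migrated",
				slog.String("trigger", trg.TriggerName),
				slog.String("function", trg.FunctionName))
			continue
		}
		if _, err := tx.Exec(ctx, trg.TriggerDefinition); err != nil {
			return fmt.Errorf("creating trigger %s: %w", trg.TriggerName, err)
		}
	}
	return nil
}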

func (a *FlowableActivity) SyncFlow(
ctx context.Context,
config *protos.FlowConnectionConfigsCore,
flow/connectors/postgres/postgres.go — 38 additions, 0 deletions
@@ -1041,18 +1041,54 @@ func (c *PostgresConnector) getTableSchemaForTable(
if err := rows.Err(); err != nil {
return nil, fmt.Errorf("error iterating over table schema: %w", err)
}

rows.Close()

// if we have no pkey, we will use all columns as the pkey for the MERGE statement
if replicaIdentityType == ReplicaIdentityFull && len(pKeyCols) == 0 {
pKeyCols = columnNames
}

// Fetch indexes for the table
indexes, err := c.getTableIndexes(ctx, schemaTable)
if err != nil {
	c.logger.Warn("Failed to fetch indexes for table, continuing without indexes",
		slog.String("table", tm.SourceTableIdentifier),
		slog.Any("error", err))
	indexes = nil
} else {
	c.logger.Info("fetched indexes", slog.String("table", schemaTable.String()), slog.Any("indexes", indexes))
}

// Fetch functions for the table schema
functions, err := c.getTableFunctions(ctx, schemaTable)
if err != nil {
	c.logger.Warn("Failed to fetch functions for table, continuing without functions",
		slog.String("table", tm.SourceTableIdentifier),
		slog.Any("error", err))
	functions = nil
} else {
	c.logger.Info("fetched functions", slog.String("table", schemaTable.String()), slog.Any("functions", functions))
}

// Fetch triggers for the table schema
triggers, err := c.getTableTriggers(ctx, schemaTable)
if err != nil {
	c.logger.Warn("Failed to fetch triggers for table, continuing without triggers",
		slog.String("table", tm.SourceTableIdentifier),
		slog.Any("error", err))
	triggers = nil
} else {
	c.logger.Info("fetched triggers", slog.String("table", schemaTable.String()), slog.Any("triggers", triggers))
}

return &protos.TableSchema{
TableIdentifier: tm.SourceTableIdentifier,
PrimaryKeyColumns: pKeyCols,
IsReplicaIdentityFull: replicaIdentityType == ReplicaIdentityFull,
Columns: columns,
NullableEnabled: nullableEnabled,
System: system,
Indexes: indexes,
Functions: functions,
Triggers: triggers,
}, nil
}
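
// Hedged sketch, not the PR's implementation: getTableIndexes is defined
// elsewhere in this change. Postgres exposes ready-to-replay index DDL in the
// pg_indexes catalog view, so a minimal version could look like the following;
// the IndexDescription message and the c.conn field name are assumptions for
// this illustration.
func (c *PostgresConnector) getTableIndexesSketch(
	ctx context.Context,
	schemaTable *utils.SchemaTable,
) ([]*protos.IndexDescription, error) {
	rows, err := c.conn.Query(ctx,
		`SELECT indexname, indexdef FROM pg_indexes WHERE schemaname = $1 AND tablename = $2`,
		schemaTable.Schema, schemaTable.Table)
	if err != nil {
		return nil, fmt.Errorf("querying pg_indexes for %s: %w", schemaTable, err)
	}
	defer rows.Close()

	var indexes []*protos.IndexDescription
	for rows.Next() {
		var name, def string
		if err := rows.Scan(&name, &def); err != nil {
			return nil, err
		}
		// indexdef is a complete CREATE INDEX statement, so it can be replayed
		// as-is on a destination table with the same schema and table name.
		indexes = append(indexes, &protos.IndexDescription{
			IndexName:       name,
			IndexDefinition: def,
		})
	}
	return indexes, rows.Err()
}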

@@ -1109,6 +1145,8 @@ func (c *PostgresConnector) SetupNormalizedTable(
return false, fmt.Errorf("error while creating normalized table: %w", err)
}

c.logger.Info("Successfully created normalized table", slog.String("table", tableIdentifier))

return false, nil
}
