diff --git a/pb/c1/connector/v2/resource.pb.go b/pb/c1/connector/v2/resource.pb.go index f09a6418b..277f2129f 100644 --- a/pb/c1/connector/v2/resource.pb.go +++ b/pb/c1/connector/v2/resource.pb.go @@ -131,8 +131,13 @@ type ResourceType struct { Annotations []*anypb.Any `protobuf:"bytes,4,rep,name=annotations,proto3" json:"annotations,omitempty"` Description string `protobuf:"bytes,5,opt,name=description,proto3" json:"description,omitempty"` SourcedExternally bool `protobuf:"varint,6,opt,name=sourced_externally,json=sourcedExternally,proto3" json:"sourced_externally,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Sync bucketing configuration for parallel processing + // Resource types with the same bucket name will be processed sequentially within that bucket + // Resource types with different bucket names can be processed in parallel + // If not specified, the default bucket from ParallelSyncConfig will be used + SyncBucket string `protobuf:"bytes,7,opt,name=sync_bucket,json=syncBucket,proto3" json:"sync_bucket,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ResourceType) Reset() { @@ -202,6 +207,13 @@ func (x *ResourceType) GetSourcedExternally() bool { return false } +func (x *ResourceType) GetSyncBucket() string { + if x != nil { + return x.SyncBucket + } + return "" +} + func (x *ResourceType) SetId(v string) { x.Id = v } @@ -226,6 +238,10 @@ func (x *ResourceType) SetSourcedExternally(v bool) { x.SourcedExternally = v } +func (x *ResourceType) SetSyncBucket(v string) { + x.SyncBucket = v +} + type ResourceType_builder struct { _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. @@ -235,6 +251,11 @@ type ResourceType_builder struct { Annotations []*anypb.Any Description string SourcedExternally bool + // Sync bucketing configuration for parallel processing + // Resource types with the same bucket name will be processed sequentially within that bucket + // Resource types with different bucket names can be processed in parallel + // If not specified, the default bucket from ParallelSyncConfig will be used + SyncBucket string } func (b0 ResourceType_builder) Build() *ResourceType { @@ -247,6 +268,7 @@ func (b0 ResourceType_builder) Build() *ResourceType { x.Annotations = b.Annotations x.Description = b.Description x.SourcedExternally = b.SourcedExternally + x.SyncBucket = b.SyncBucket return m0 } @@ -4182,7 +4204,7 @@ var File_c1_connector_v2_resource_proto protoreflect.FileDescriptor const file_c1_connector_v2_resource_proto_rawDesc = "" + "\n" + - "\x1ec1/connector/v2/resource.proto\x12\x0fc1.connector.v2\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x17validate/validate.proto\"\xb4\x03\n" + + "\x1ec1/connector/v2/resource.proto\x12\x0fc1.connector.v2\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x17validate/validate.proto\"\xd5\x03\n" + "\fResourceType\x12\x1a\n" + "\x02id\x18\x01 \x01(\tB\n" + "\xfaB\ar\x05 \x01(\x80\bR\x02id\x120\n" + @@ -4192,7 +4214,9 @@ const file_c1_connector_v2_resource_proto_rawDesc = "" + "\vannotations\x18\x04 \x03(\v2\x14.google.protobuf.AnyR\vannotations\x12/\n" + "\vdescription\x18\x05 \x01(\tB\r\xfaB\n" + "r\b \x01(\x80 \xd0\x01\x01R\vdescription\x12-\n" + - "\x12sourced_externally\x18\x06 \x01(\bR\x11sourcedExternally\"p\n" + + "\x12sourced_externally\x18\x06 \x01(\bR\x11sourcedExternally\x12\x1f\n" + + "\vsync_bucket\x18\a \x01(\tR\n" + + "syncBucket\"p\n" + "\x05Trait\x12\x15\n" + 
"\x11TRAIT_UNSPECIFIED\x10\x00\x12\x0e\n" + "\n" + diff --git a/pb/c1/connector/v2/resource.pb.validate.go b/pb/c1/connector/v2/resource.pb.validate.go index 48fa35152..98c4eca67 100644 --- a/pb/c1/connector/v2/resource.pb.validate.go +++ b/pb/c1/connector/v2/resource.pb.validate.go @@ -165,6 +165,8 @@ func (m *ResourceType) validate(all bool) error { // no validation rules for SourcedExternally + // no validation rules for SyncBucket + if len(errors) > 0 { return ResourceTypeMultiError(errors) } diff --git a/pb/c1/connector/v2/resource_protoopaque.pb.go b/pb/c1/connector/v2/resource_protoopaque.pb.go index f02851454..a5279d686 100644 --- a/pb/c1/connector/v2/resource_protoopaque.pb.go +++ b/pb/c1/connector/v2/resource_protoopaque.pb.go @@ -131,6 +131,7 @@ type ResourceType struct { xxx_hidden_Annotations *[]*anypb.Any `protobuf:"bytes,4,rep,name=annotations,proto3"` xxx_hidden_Description string `protobuf:"bytes,5,opt,name=description,proto3"` xxx_hidden_SourcedExternally bool `protobuf:"varint,6,opt,name=sourced_externally,json=sourcedExternally,proto3"` + xxx_hidden_SyncBucket string `protobuf:"bytes,7,opt,name=sync_bucket,json=syncBucket,proto3"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -204,6 +205,13 @@ func (x *ResourceType) GetSourcedExternally() bool { return false } +func (x *ResourceType) GetSyncBucket() string { + if x != nil { + return x.xxx_hidden_SyncBucket + } + return "" +} + func (x *ResourceType) SetId(v string) { x.xxx_hidden_Id = v } @@ -228,6 +236,10 @@ func (x *ResourceType) SetSourcedExternally(v bool) { x.xxx_hidden_SourcedExternally = v } +func (x *ResourceType) SetSyncBucket(v string) { + x.xxx_hidden_SyncBucket = v +} + type ResourceType_builder struct { _ [0]func() // Prevents comparability and use of unkeyed literals for the builder. 
@@ -237,6 +249,11 @@ type ResourceType_builder struct { Annotations []*anypb.Any Description string SourcedExternally bool + // Sync bucketing configuration for parallel processing + // Resource types with the same bucket name will be processed sequentially within that bucket + // Resource types with different bucket names can be processed in parallel + // If not specified, the default bucket from ParallelSyncConfig will be used + SyncBucket string } func (b0 ResourceType_builder) Build() *ResourceType { @@ -249,6 +266,7 @@ func (b0 ResourceType_builder) Build() *ResourceType { x.xxx_hidden_Annotations = &b.Annotations x.xxx_hidden_Description = b.Description x.xxx_hidden_SourcedExternally = b.SourcedExternally + x.xxx_hidden_SyncBucket = b.SyncBucket return m0 } @@ -4177,7 +4195,7 @@ var File_c1_connector_v2_resource_proto protoreflect.FileDescriptor const file_c1_connector_v2_resource_proto_rawDesc = "" + "\n" + - "\x1ec1/connector/v2/resource.proto\x12\x0fc1.connector.v2\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x17validate/validate.proto\"\xb4\x03\n" + + "\x1ec1/connector/v2/resource.proto\x12\x0fc1.connector.v2\x1a\x19google/protobuf/any.proto\x1a\x1cgoogle/protobuf/struct.proto\x1a\x17validate/validate.proto\"\xd5\x03\n" + "\fResourceType\x12\x1a\n" + "\x02id\x18\x01 \x01(\tB\n" + "\xfaB\ar\x05 \x01(\x80\bR\x02id\x120\n" + @@ -4187,7 +4205,9 @@ const file_c1_connector_v2_resource_proto_rawDesc = "" + "\vannotations\x18\x04 \x03(\v2\x14.google.protobuf.AnyR\vannotations\x12/\n" + "\vdescription\x18\x05 \x01(\tB\r\xfaB\n" + "r\b \x01(\x80 \xd0\x01\x01R\vdescription\x12-\n" + - "\x12sourced_externally\x18\x06 \x01(\bR\x11sourcedExternally\"p\n" + + "\x12sourced_externally\x18\x06 \x01(\bR\x11sourcedExternally\x12\x1f\n" + + "\vsync_bucket\x18\a \x01(\tR\n" + + "syncBucket\"p\n" + "\x05Trait\x12\x15\n" + "\x11TRAIT_UNSPECIFIED\x10\x00\x12\x0e\n" + "\n" + diff --git a/pkg/cli/commands.go b/pkg/cli/commands.go index 883632d9d..388cb87f0 100644 --- a/pkg/cli/commands.go +++ b/pkg/cli/commands.go @@ -333,6 +333,10 @@ func MakeMainCommand[T field.Configurable]( } } + if v.GetBool("parallel-sync") { + opts = append(opts, connectorrunner.WithParallelSyncEnabled()) + } + if v.GetString("c1z-temp-dir") != "" { c1zTmpDir := v.GetString("c1z-temp-dir") if _, err := os.Stat(c1zTmpDir); os.IsNotExist(err) { diff --git a/pkg/connectorrunner/runner.go b/pkg/connectorrunner/runner.go index 1f3945971..51ccced79 100644 --- a/pkg/connectorrunner/runner.go +++ b/pkg/connectorrunner/runner.go @@ -341,6 +341,7 @@ type runnerConfig struct { syncDifferConfig *syncDifferConfig syncCompactorConfig *syncCompactorConfig skipFullSync bool + parallelSync bool targetedSyncResourceIDs []string externalResourceC1Z string externalResourceEntitlementIdFilter string @@ -552,6 +553,13 @@ func WithFullSyncDisabled() Option { } } +func WithParallelSyncEnabled() Option { + return func(ctx context.Context, cfg *runnerConfig) error { + cfg.parallelSync = true + return nil + } +} + func WithTargetedSyncResourceIDs(resourceIDs []string) Option { return func(ctx context.Context, cfg *runnerConfig) error { cfg.targetedSyncResourceIDs = resourceIDs @@ -803,6 +811,7 @@ func NewConnectorRunner(ctx context.Context, c types.ConnectorServer, opts ...Op local.WithSkipEntitlementsAndGrants(cfg.skipEntitlementsAndGrants), local.WithSkipGrants(cfg.skipGrants), local.WithSyncResourceTypeIDs(cfg.syncResourceTypeIDs), + local.WithParallelSyncEnabled(cfg.parallelSync), ) if err != nil { return nil, err @@ -815,7 
+824,8 @@ func NewConnectorRunner(ctx context.Context, c types.ConnectorServer, opts ...Op return runner, nil } - tm, err := c1api.NewC1TaskManager(ctx, + tm, err := c1api.NewC1TaskManager( + ctx, cfg.clientID, cfg.clientSecret, cfg.tempDir, @@ -824,6 +834,7 @@ func NewConnectorRunner(ctx context.Context, c types.ConnectorServer, opts ...Op cfg.externalResourceEntitlementIdFilter, cfg.targetedSyncResourceIDs, cfg.syncResourceTypeIDs, + cfg.parallelSync, ) if err != nil { return nil, err diff --git a/pkg/dotc1z/c1file.go b/pkg/dotc1z/c1file.go index 3e922ec61..1f1963270 100644 --- a/pkg/dotc1z/c1file.go +++ b/pkg/dotc1z/c1file.go @@ -7,10 +7,13 @@ import ( "fmt" "os" "path/filepath" + "strings" "sync" "time" "github.com/doug-martin/goqu/v9" + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.uber.org/zap" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -49,6 +52,14 @@ type C1File struct { slowQueryLogTimesMu sync.Mutex slowQueryThreshold time.Duration slowQueryLogFrequency time.Duration + + // WAL checkpointing + checkpointTicker *time.Ticker + checkpointStop chan struct{} + checkpointDone chan struct{} + checkpointOnce sync.Once + checkpointMu sync.RWMutex // Prevents DB activity during WAL checkpoint to avoid WAL file growth under heavy load + checkpointEnabled bool // Whether WAL checkpointing is enabled } var _ connectorstore.Writer = (*C1File)(nil) @@ -67,6 +78,12 @@ func WithC1FPragma(name string, value string) C1FOption { } } +func WithC1FWALCheckpoint(enable bool) C1FOption { + return func(o *C1File) { + o.checkpointEnabled = enable + } +} + // Returns a C1File instance for the given db filepath. func NewC1File(ctx context.Context, dbFilePath string, opts ...C1FOption) (*C1File, error) { ctx, span := tracer.Start(ctx, "NewC1File") @@ -87,6 +104,8 @@ func NewC1File(ctx context.Context, dbFilePath string, opts ...C1FOption) (*C1Fi slowQueryLogTimes: make(map[string]time.Time), slowQueryThreshold: 5 * time.Second, slowQueryLogFrequency: 1 * time.Minute, + checkpointStop: make(chan struct{}), + checkpointDone: make(chan struct{}), } for _, opt := range opts { @@ -107,9 +126,10 @@ func NewC1File(ctx context.Context, dbFilePath string, opts ...C1FOption) (*C1Fi } type c1zOptions struct { - tmpDir string - pragmas []pragma - decoderOptions []DecoderOption + tmpDir string + pragmas []pragma + decoderOptions []DecoderOption + enableWALCheckpoint bool } type C1ZOption func(*c1zOptions) @@ -131,6 +151,12 @@ func WithDecoderOptions(opts ...DecoderOption) C1ZOption { } } +func WithWALCheckpoint(enable bool) C1ZOption { + return func(o *c1zOptions) { + o.enableWALCheckpoint = enable + } +} + // Returns a new C1File instance with its state stored at the provided filename. func NewC1ZFile(ctx context.Context, outputFilePath string, opts ...C1ZOption) (*C1File, error) { ctx, span := tracer.Start(ctx, "NewC1ZFile") @@ -150,6 +176,9 @@ func NewC1ZFile(ctx context.Context, outputFilePath string, opts ...C1ZOption) ( for _, pragma := range options.pragmas { c1fopts = append(c1fopts, WithC1FPragma(pragma.name, pragma.value)) } + if options.enableWALCheckpoint { + c1fopts = append(c1fopts, WithC1FWALCheckpoint(true)) + } c1File, err := NewC1File(ctx, dbFilePath, c1fopts...) 
if err != nil { @@ -174,6 +203,15 @@ func cleanupDbDir(dbFilePath string, err error) error { func (c *C1File) Close() error { var err error + // Stop WAL checkpointing if it's running + if c.checkpointTicker != nil { + c.checkpointTicker.Stop() + c.checkpointOnce.Do(func() { + close(c.checkpointStop) + }) + <-c.checkpointDone // Wait for goroutine to finish + } + if c.rawDb != nil { err = c.rawDb.Close() if err != nil { @@ -223,6 +261,11 @@ func (c *C1File) init(ctx context.Context) error { } } + // Start WAL checkpointing if enabled, journal mode is WAL, and checkpointing is enabled + if c.checkpointEnabled && c.isWALMode(ctx) { + c.startWALCheckpointing() + } + return nil } @@ -412,3 +455,70 @@ func (c *C1File) GrantStats(ctx context.Context, syncType connectorstore.SyncTyp return stats, nil } + +// isWALMode checks if the database is using WAL mode. +func (c *C1File) isWALMode(ctx context.Context) bool { + for _, pragma := range c.pragmas { + if pragma.name == "journal_mode" && strings.EqualFold(pragma.value, "wal") { + return true + } + } + + var mode string + if err := c.rawDb.QueryRowContext(ctx, "PRAGMA journal_mode").Scan(&mode); err == nil { + return strings.EqualFold(mode, "wal") + } + + return false +} + +// startWALCheckpointing starts a background goroutine to perform WAL checkpoints every 5 minutes. +func (c *C1File) startWALCheckpointing() { + c.checkpointTicker = time.NewTicker(5 * time.Minute) + + go func() { + defer close(c.checkpointDone) + for { + select { + case <-c.checkpointTicker.C: + c.performWALCheckpoint() + case <-c.checkpointStop: + return + } + } + }() +} + +// acquireCheckpointLock acquires a read lock for database operations. +func (c *C1File) acquireCheckpointLock() { + if c.checkpointEnabled { + c.checkpointMu.RLock() + } +} + +// releaseCheckpointLock releases the read lock for database operations. +func (c *C1File) releaseCheckpointLock() { + if c.checkpointEnabled { + c.checkpointMu.RUnlock() + } +} + +// performWALCheckpoint performs a WAL checkpoint using SQLITE_CHECKPOINT_RESTART or SQLITE_CHECKPOINT_TRUNCATE. 
+func (c *C1File) performWALCheckpoint() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Acquire write lock to pause all database operations during checkpoint + c.checkpointMu.Lock() + defer c.checkpointMu.Unlock() + + // First try SQLITE_CHECKPOINT_TRUNCATE + _, err := c.rawDb.ExecContext(ctx, "PRAGMA wal_checkpoint(TRUNCATE)") + if err != nil { + // If TRUNCATE fails, try RESTART + _, err = c.rawDb.ExecContext(ctx, "PRAGMA wal_checkpoint(RESTART)") + if err != nil { + ctxzap.Extract(ctx).Error("failed to perform WAL checkpoint", zap.Error(err)) + } + } +} diff --git a/pkg/dotc1z/manager/local/local.go b/pkg/dotc1z/manager/local/local.go index 6ccbdd9c0..5e273e4fa 100644 --- a/pkg/dotc1z/manager/local/local.go +++ b/pkg/dotc1z/manager/local/local.go @@ -16,10 +16,11 @@ import ( var tracer = otel.Tracer("baton-sdk/pkg.dotc1z.manager.local") type localManager struct { - filePath string - tmpPath string - tmpDir string - decoderOptions []dotc1z.DecoderOption + filePath string + tmpPath string + tmpDir string + decoderOptions []dotc1z.DecoderOption + enableWALCheckpoint bool } type Option func(*localManager) @@ -36,6 +37,12 @@ func WithDecoderOptions(opts ...dotc1z.DecoderOption) Option { } } +func WithWALCheckpoint(enable bool) Option { + return func(o *localManager) { + o.enableWALCheckpoint = enable + } +} + func (l *localManager) copyFileToTmp(ctx context.Context) error { _, span := tracer.Start(ctx, "localManager.copyFileToTmp") defer span.End() @@ -112,6 +119,9 @@ func (l *localManager) LoadC1Z(ctx context.Context) (*dotc1z.C1File, error) { if len(l.decoderOptions) > 0 { opts = append(opts, dotc1z.WithDecoderOptions(l.decoderOptions...)) } + if l.enableWALCheckpoint { + opts = append(opts, dotc1z.WithWALCheckpoint(true)) + } return dotc1z.NewC1ZFile(ctx, l.tmpPath, opts...) } diff --git a/pkg/dotc1z/manager/manager.go b/pkg/dotc1z/manager/manager.go index e94665251..9887d756a 100644 --- a/pkg/dotc1z/manager/manager.go +++ b/pkg/dotc1z/manager/manager.go @@ -18,8 +18,9 @@ type Manager interface { } type managerOptions struct { - tmpDir string - decoderOptions []dotc1z.DecoderOption + tmpDir string + decoderOptions []dotc1z.DecoderOption + enableWALCheckpoint bool } type ManagerOption func(*managerOptions) @@ -36,6 +37,12 @@ func WithDecoderOptions(opts ...dotc1z.DecoderOption) ManagerOption { } } +func WithWALCheckpoint(enable bool) ManagerOption { + return func(o *managerOptions) { + o.enableWALCheckpoint = enable + } +} + // Given a file path, return a Manager that can read and write files to that path. // // The first thing we do is check if the file path starts with "s3://". If it does, we return a new @@ -56,6 +63,9 @@ func New(ctx context.Context, filePath string, opts ...ManagerOption) (Manager, if len(options.decoderOptions) > 0 { s3Opts = append(s3Opts, s3.WithDecoderOptions(options.decoderOptions...)) } + if options.enableWALCheckpoint { + s3Opts = append(s3Opts, s3.WithWALCheckpoint(true)) + } return s3.NewS3Manager(ctx, filePath, s3Opts...) default: var localOpts []local.Option @@ -65,6 +75,9 @@ func New(ctx context.Context, filePath string, opts ...ManagerOption) (Manager, if len(options.decoderOptions) > 0 { localOpts = append(localOpts, local.WithDecoderOptions(options.decoderOptions...)) } + if options.enableWALCheckpoint { + localOpts = append(localOpts, local.WithWALCheckpoint(true)) + } return local.New(ctx, filePath, localOpts...) 
} } diff --git a/pkg/dotc1z/manager/s3/s3.go b/pkg/dotc1z/manager/s3/s3.go index 385b1bc47..0b707d0cc 100644 --- a/pkg/dotc1z/manager/s3/s3.go +++ b/pkg/dotc1z/manager/s3/s3.go @@ -19,11 +19,12 @@ import ( var tracer = otel.Tracer("baton-sdk/pkg.dotc1z.manager.s3") type s3Manager struct { - client *us3.S3Client - fileName string - tmpFile string - tmpDir string - decoderOptions []dotc1z.DecoderOption + client *us3.S3Client + fileName string + tmpFile string + tmpDir string + decoderOptions []dotc1z.DecoderOption + enableWALCheckpoint bool } type Option func(*s3Manager) @@ -40,6 +41,12 @@ func WithDecoderOptions(opts ...dotc1z.DecoderOption) Option { } } +func WithWALCheckpoint(enable bool) Option { + return func(o *s3Manager) { + o.enableWALCheckpoint = enable + } +} + func (s *s3Manager) copyToTempFile(ctx context.Context, r io.Reader) error { _, span := tracer.Start(ctx, "s3Manager.copyToTempFile") defer span.End() @@ -130,6 +137,9 @@ func (s *s3Manager) LoadC1Z(ctx context.Context) (*dotc1z.C1File, error) { if len(s.decoderOptions) > 0 { opts = append(opts, dotc1z.WithDecoderOptions(s.decoderOptions...)) } + if s.enableWALCheckpoint { + opts = append(opts, dotc1z.WithWALCheckpoint(true)) + } return dotc1z.NewC1ZFile(ctx, s.tmpFile, opts...) } diff --git a/pkg/dotc1z/sql_helpers.go b/pkg/dotc1z/sql_helpers.go index 07b045138..f1b447d16 100644 --- a/pkg/dotc1z/sql_helpers.go +++ b/pkg/dotc1z/sql_helpers.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "strconv" + "strings" "time" "github.com/doug-martin/goqu/v9" @@ -225,6 +226,10 @@ func (c *C1File) listConnectorObjects(ctx context.Context, tableName string, req // Start timing the query execution queryStartTime := time.Now() + // Acquire checkpoint lock to coordinate with WAL checkpointing + c.acquireCheckpointLock() + defer c.releaseCheckpointLock() + // Execute the query rows, err := c.db.QueryContext(ctx, query, args...) if err != nil { @@ -306,7 +311,10 @@ func prepareConnectorObjectRows[T proto.Message]( return rows, nil } -// executeChunkedInsert executes the insert query in chunks. +func isSQLiteBusy(err error) bool { + return strings.Contains(err.Error(), "database is locked") || strings.Contains(err.Error(), "SQLITE_BUSY") +} + func executeChunkedInsert( ctx context.Context, c *C1File, @@ -320,13 +328,6 @@ func executeChunkedInsert( chunks++ } - tx, err := c.db.BeginTx(ctx, nil) - if err != nil { - return err - } - - var txError error - for i := 0; i < chunks; i++ { start := i * chunkSize end := (i + 1) * chunkSize @@ -335,40 +336,104 @@ func executeChunkedInsert( } chunkedRows := rows[start:end] - // Create the base insert dataset + err := executeChunkWithRetry(ctx, c, tableName, chunkedRows, buildQueryFn) + if err != nil { + return err + } + } + + return nil +} + +// executeChunkWithRetry executes a single chunk with retry logic for SQLITE_BUSY errors. 
+func executeChunkWithRetry( + ctx context.Context, + c *C1File, + tableName string, + chunkedRows []*goqu.Record, + buildQueryFn func(*goqu.InsertDataset, []*goqu.Record) (*goqu.InsertDataset, error), +) error { + maxRetries := 5 + baseDelay := 10 * time.Millisecond + + for attempt := 0; attempt < maxRetries; attempt++ { + // Acquire checkpoint lock to coordinate with WAL checkpointing + c.acquireCheckpointLock() + + tx, err := c.db.BeginTx(ctx, nil) + if err != nil { + c.releaseCheckpointLock() + if isSQLiteBusy(err) && attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Duration(attempt+1) * baseDelay): + continue + } + } + return err + } + insertDs := tx.Insert(tableName) - // Apply the custom query building function insertDs, err = buildQueryFn(insertDs, chunkedRows) if err != nil { - txError = err - break + c.releaseCheckpointLock() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + return err } // Generate the SQL query, args, err := insertDs.ToSQL() if err != nil { - txError = err - break + c.releaseCheckpointLock() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + return err } - // Execute the query _, err = tx.ExecContext(ctx, query, args...) if err != nil { - txError = err - break + c.releaseCheckpointLock() + rollbackErr := tx.Rollback() + if rollbackErr != nil { + return errors.Join(err, rollbackErr) + } + if isSQLiteBusy(err) && attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Duration(attempt+1) * baseDelay): + continue + } + } + return err } - } - if txError != nil { - if rollbackErr := tx.Rollback(); rollbackErr != nil { - return errors.Join(rollbackErr, txError) + err = tx.Commit() + if err != nil { + c.releaseCheckpointLock() + if isSQLiteBusy(err) && attempt < maxRetries-1 { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(time.Duration(attempt+1) * baseDelay): + continue + } + } + return err } - return fmt.Errorf("error executing chunked insert: %w", txError) + c.releaseCheckpointLock() + return nil } - return tx.Commit() + return fmt.Errorf("failed to execute chunk after %d retries", maxRetries) } func bulkPutConnectorObject[T proto.Message]( diff --git a/pkg/field/defaults.go b/pkg/field/defaults.go index 8a05c7930..56c02ae9b 100644 --- a/pkg/field/defaults.go +++ b/pkg/field/defaults.go @@ -89,6 +89,7 @@ var ( WithPersistent(true), WithExportTarget(ExportTargetOps)) skipFullSync = BoolField("skip-full-sync", WithDescription("This must be set to skip a full sync"), WithPersistent(true), WithExportTarget(ExportTargetNone)) + parallelSync = BoolField("parallel-sync", WithDescription("This must be set to enable parallel sync"), WithPersistent(true), WithExportTarget(ExportTargetNone)) targetedSyncResourceIDs = StringSliceField("sync-resources", WithDescription("The resource IDs to sync"), WithPersistent(true), WithExportTarget(ExportTargetNone)) skipEntitlementsAndGrants = BoolField("skip-entitlements-and-grants", WithDescription("This must be set to skip syncing of entitlements and grants"), @@ -303,6 +304,7 @@ var DefaultFields = []SchemaField{ invokeActionField, invokeActionArgsField, ServerSessionStoreMaximumSizeField, + parallelSync, otelCollectorEndpoint, otelCollectorEndpointTLSCertPath, diff --git a/pkg/sync/parallel_syncer.go b/pkg/sync/parallel_syncer.go new file mode 100644 index 000000000..c3a312b2b --- /dev/null +++ 
b/pkg/sync/parallel_syncer.go @@ -0,0 +1,1983 @@ +package sync + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/attribute" + "go.uber.org/zap" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2" + reader_v2 "github.com/conductorone/baton-sdk/pb/c1/reader/v2" + "github.com/conductorone/baton-sdk/pkg/annotations" +) + +var _ Syncer = (*parallelSyncer)(nil) + +var taskRetryLimit = 5 +var errTaskQueueFull = errors.New("task queue is full") +var parallelTracer = otel.Tracer("baton-sdk/parallel-sync") + +const ( + nextPageAction = "next_page" + finishAction = "finish" +) + +// min returns the smaller of two integers. +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +// addTaskWithRetry adds a task to the queue with retry logic for queue full errors. +func (ps *parallelSyncer) addTaskWithRetry(ctx context.Context, task *task, maxRetries int) error { + for attempt := 0; attempt <= maxRetries; attempt++ { + err := ps.taskQueue.AddTask(ctx, task) + if err == nil { + return nil + } + + if !errors.Is(err, errTaskQueueFull) { + return err + } + + // If this is the last attempt, return the error + if attempt == maxRetries { + return fmt.Errorf("failed to add task after %d retries: %w", maxRetries, err) + } + + // Wait before retrying, with true exponential backoff + backoffDuration := time.Duration(1< 0 { + c.WorkerCount = count + } + return c +} + +// WithDefaultBucket sets the default bucket for resource types that don't specify a sync_bucket. +func (c *ParallelSyncConfig) WithDefaultBucket(bucket string) *ParallelSyncConfig { + c.DefaultBucket = bucket + return c +} + +// task represents a unit of work for the parallel syncer. +type task struct { + Action Action + ResourceID string + ResourceType *v2.ResourceType // The resource type for this task +} + +// TaskResult contains tasks that should be created after completing a task. +type TaskResult struct { + Tasks []*task + Error error +} + +// DeferredTaskAdder collects tasks during processing and adds them after completion. +type DeferredTaskAdder struct { + pendingTasks []*task + sync.RWMutex +} + +func NewDeferredTaskAdder() *DeferredTaskAdder { + return &DeferredTaskAdder{ + pendingTasks: make([]*task, 0), + } +} + +func (dta *DeferredTaskAdder) AddPendingTask(task *task) { + dta.Lock() + defer dta.Unlock() + dta.pendingTasks = append(dta.pendingTasks, task) +} + +func (dta *DeferredTaskAdder) GetPendingTasks() []*task { + dta.RLock() + defer dta.RUnlock() + return dta.pendingTasks +} + +func (dta *DeferredTaskAdder) Clear() { + dta.Lock() + defer dta.Unlock() + dta.pendingTasks = dta.pendingTasks[:0] // Reuse slice +} + +// taskQueue manages the distribution of tasks to workers using dynamic bucketing. +type taskQueue struct { + bucketQueues map[string]chan *task // Map of bucket name to task channel + parallelSyncer *parallelSyncer + mu sync.RWMutex + closed bool +} + +// newTaskQueue creates a new task queue. 
+func newTaskQueue(parallelSyncer *parallelSyncer) *taskQueue { + // Initialize with an empty map of bucket queues + // Buckets will be created dynamically as tasks are added + return &taskQueue{ + bucketQueues: make(map[string]chan *task), + parallelSyncer: parallelSyncer, + } +} + +func (q *taskQueue) getOrCreateBucketChannel(bucket string) (chan *task, error) { + q.mu.Lock() + defer q.mu.Unlock() + + if q.closed { + return nil, errors.New("task queue is closed") + } + + // Create the bucket queue if it doesn't exist + queue, exists := q.bucketQueues[bucket] + if !exists { + queueSize := q.parallelSyncer.config.WorkerCount * 10 + queue = make(chan *task, queueSize) + q.bucketQueues[bucket] = queue + } + + return queue, nil +} + +// AddTask adds a task to the appropriate queue. +func (q *taskQueue) AddTask(ctx context.Context, t *task) error { + bucket := q.getBucketForTask(t) + queue, err := q.getOrCreateBucketChannel(bucket) + if err != nil { + return err + } + + // Add the task to the appropriate bucket queue with timeout + // This prevents indefinite blocking while still allowing graceful handling of full queues + timeout := 30 * time.Second + select { + case queue <- t: + // Log task addition for debugging + l := ctxzap.Extract(ctx) + l.Info("task added to queue", + zap.String("bucket", bucket), + zap.String("operation", t.Action.Op.String()), + zap.String("resource_type", t.Action.ResourceTypeID), + zap.Int("queue_length", len(queue))) + return nil + case <-time.After(timeout): + return errTaskQueueFull + case <-ctx.Done(): + return ctx.Err() + } +} + +// AddTaskWithTimeout adds a task with a custom timeout and dynamic queue expansion. +func (q *taskQueue) AddTaskWithTimeout(ctx context.Context, t *task, timeout time.Duration) error { + bucket := q.getBucketForTask(t) + queue, err := q.getOrCreateBucketChannel(bucket) + if err != nil { + return err + } + + // Try to add the task + select { + case queue <- t: + return nil + case <-time.After(timeout): + // Queue is full, try to expand it + return q.expandQueueAndRetry(bucket, t, timeout) + case <-ctx.Done(): + return ctx.Err() + } +} + +// expandQueueAndRetry attempts to expand the queue and retry adding the task. 
+func (q *taskQueue) expandQueueAndRetry(bucket string, t *task, timeout time.Duration) error { + q.mu.Lock() + defer q.mu.Unlock() + + if q.closed { + return errors.New("task queue is closed") + } + + l := ctxzap.Extract(context.Background()) + + // Get current queue + currentQueue := q.bucketQueues[bucket] + currentSize := cap(currentQueue) + currentLen := len(currentQueue) + + // Only expand if queue is nearly full + if currentLen < currentSize-1 { + return errTaskQueueFull + } + + // Calculate new size (double it, but cap at reasonable limit) + newSize := minInt(currentSize*2, 50000) // Cap at 50k tasks per bucket + + if newSize <= currentSize { + l.Warn("queue expansion blocked - already at maximum size", + zap.String("bucket", bucket), + zap.Int("current_size", currentSize)) + return errTaskQueueFull + } + + l.Info("expanding queue due to pressure", + zap.String("bucket", bucket), + zap.Int("old_size", currentSize), + zap.Int("new_size", newSize), + zap.Int("current_length", currentLen)) + + // Create new larger queue + newQueue := make(chan *task, newSize) + + // Copy existing tasks to new queue + for len(currentQueue) > 0 { + task := <-currentQueue + select { + case newQueue <- task: + default: + // This should never happen since new queue is larger + l.Error("failed to copy task to expanded queue") + return errTaskQueueFull + } + } + + // Replace the queue + q.bucketQueues[bucket] = newQueue + + // Try to add the new task + select { + case newQueue <- t: + return nil + default: + // This should never happen since we just expanded + l.Error("failed to add task to expanded queue") + return errTaskQueueFull + } +} + +// getBucketForTask determines the bucket for a task based on the resource type's sync_bucket. +func (q *taskQueue) getBucketForTask(t *task) string { + // If the resource type has an explicit sync_bucket, use it + if t.ResourceType != nil && t.ResourceType.SyncBucket != "" { + return t.ResourceType.SyncBucket + } + + // If no explicit bucket and default is empty, create a unique bucket per resource type + if q.parallelSyncer.config.DefaultBucket == "" { + return fmt.Sprintf("resource-type-%s", t.Action.ResourceTypeID) + } + + // Otherwise use the configured default bucket + return q.parallelSyncer.config.DefaultBucket +} + +// GetTask retrieves the next task with intelligent bucket selection. 
+func (q *taskQueue) GetTask(ctx context.Context) (*task, error) { + q.mu.Lock() // Use write lock to make the operation atomic + defer q.mu.Unlock() + + // Debug logging + l := ctxzap.Extract(ctx) + l.Debug("GetTask called", + zap.Int("total_buckets", len(q.bucketQueues)), + zap.Strings("bucket_names", getMapKeys(q.bucketQueues))) + + if len(q.bucketQueues) == 0 { + l.Debug("no buckets available") + return nil, errors.New("no buckets available") + } + + // First, try to find a bucket with available tasks + var availableBuckets []string + for bucketName, queue := range q.bucketQueues { + queueLen := len(queue) + l.Debug("checking bucket", zap.String("bucket", bucketName), zap.Int("queue_length", queueLen)) + if queueLen > 0 { + availableBuckets = append(availableBuckets, bucketName) + } + } + + l.Debug("available buckets", zap.Strings("buckets", availableBuckets)) + + if len(availableBuckets) == 0 { + l.Debug("no tasks available in any bucket") + return nil, errors.New("no tasks available") + } + + // Try to get a task from each available bucket in round-robin order + // Use a more robust approach that handles the case where a queue becomes empty + for _, bucketName := range availableBuckets { + queue := q.bucketQueues[bucketName] + + // Double-check the queue still has items before trying to read + if len(queue) == 0 { + l.Debug("bucket queue became empty", zap.String("bucket", bucketName)) + continue + } + + select { + case t := <-queue: + l.Debug("retrieved task from bucket", zap.String("bucket", bucketName)) + return t, nil + default: + l.Debug("bucket queue empty when trying to read", zap.String("bucket", bucketName)) + continue + } + } + + l.Debug("failed to get task from any available bucket") + return nil, errors.New("no tasks available") +} + +// getMapKeys returns the keys of a map as a slice. +func getMapKeys(m map[string]chan *task) []string { + keys := make([]string, 0, len(m)) + for k := range m { + keys = append(keys, k) + } + return keys +} + +// GetTaskFromBucket retrieves a task from a specific bucket. +func (q *taskQueue) GetTaskFromBucket(bucketName string) (*task, error) { + q.mu.Lock() // Use write lock to make the operation atomic + defer q.mu.Unlock() + + queue, exists := q.bucketQueues[bucketName] + if !exists { + return nil, fmt.Errorf("bucket '%s' does not exist", bucketName) + } + + select { + case t := <-queue: + return t, nil + default: + return nil, errors.New("no tasks available in bucket") + } +} + +// GetBucketStats returns statistics about each bucket. +func (q *taskQueue) GetBucketStats() map[string]int { + q.mu.RLock() + defer q.mu.RUnlock() + + stats := make(map[string]int) + for bucketName, queue := range q.bucketQueues { + stats[bucketName] = len(queue) + } + return stats +} + +// Close closes the task queue. +func (q *taskQueue) Close() { + q.mu.Lock() + defer q.mu.Unlock() + + for _, w := range q.parallelSyncer.workers { + w.cancel() + } + q.closed = true +} + +// worker represents a worker goroutine that processes tasks. +type worker struct { + id int + taskQueue *taskQueue + syncer *parallelSyncer + ctx context.Context + cancel context.CancelFunc + wg *sync.WaitGroup + rateLimited atomic.Bool + isProcessing atomic.Bool +} + +// newWorker creates a new worker. 
+func newWorker(id int, taskQueue *taskQueue, syncer *parallelSyncer, ctx context.Context, wg *sync.WaitGroup) *worker { + workerCtx, cancel := context.WithCancel(ctx) + return &worker{ + id: id, + taskQueue: taskQueue, + syncer: syncer, + ctx: workerCtx, + cancel: cancel, + wg: wg, + } +} + +// Start starts the worker with bucket-aware task processing and work-stealing. +func (w *worker) Start() { + defer w.wg.Done() + + l := ctxzap.Extract(w.ctx) + l.Debug("worker started", zap.Int("worker_id", w.id)) + + // Track which bucket this worker is currently working on + currentBucket := "" + consecutiveFailures := 0 + maxConsecutiveFailures := 3 + + for { + select { + case <-w.ctx.Done(): + l.Debug("worker stopped", zap.Int("worker_id", w.id)) + return + default: + // Try to get a task, with preference for the current bucket if we're making progress + task, err := w.taskQueue.GetTask(w.ctx) + if err != nil { + // No tasks available, wait a bit + l.Debug("no tasks available, waiting", zap.Int("worker_id", w.id), zap.Error(err)) + select { + case <-w.ctx.Done(): + l.Debug("worker context cancelled, stopping", zap.Int("worker_id", w.id)) + return + case <-time.After(100 * time.Millisecond): + } + continue + } + l.Debug("worker got task", zap.Int("worker_id", w.id), zap.String("task_op", task.Action.Op.String())) + + // Track which bucket we're working on + taskBucket := w.taskQueue.getBucketForTask(task) + if taskBucket != currentBucket { + l.Debug("worker switching buckets", + zap.Int("worker_id", w.id), + zap.String("from_bucket", currentBucket), + zap.String("to_bucket", taskBucket)) + currentBucket = taskBucket + consecutiveFailures = 0 + } + + // Add detailed task information logging + l.Debug("processing task details", + zap.Int("worker_id", w.id), + zap.String("task_op", task.Action.Op.String()), + zap.String("resource_type", task.Action.ResourceTypeID), + zap.String("page_token", task.Action.PageToken), + zap.String("bucket", taskBucket)) + + // Set processing flag + w.isProcessing.Store(true) + + // Process the task + taskResult, err := w.processTask(task) + if err != nil { + // Add pending tasks after task completion (even if failed, they might be valid) + if taskResult != nil && len(taskResult.Tasks) > 0 { + err = w.addTasksAfterCompletion(taskResult.Tasks) + if err != nil { + l.Error("failed to add tasks after completion", + zap.Int("worker_id", w.id), + zap.String("bucket", taskBucket), + zap.Error(err)) + w.taskQueue.Close() + return + } + } + l.Error("failed to process task", + zap.Int("worker_id", w.id), + zap.String("bucket", taskBucket), + zap.String("operation", task.Action.Op.String()), + zap.String("resource_type", task.Action.ResourceTypeID), + zap.Error(err)) + + consecutiveFailures++ + + // Check if this is a rate limit error + if w.isRateLimitError(err) { + w.rateLimited.Store(true) + + // If we're hitting rate limits in the current bucket, consider switching + if consecutiveFailures >= maxConsecutiveFailures { + l.Info("worker hitting rate limits in bucket, will try other buckets", + zap.Int("worker_id", w.id), + zap.String("bucket", taskBucket), + zap.Int("consecutive_failures", consecutiveFailures)) + + // Force bucket switch on next iteration + currentBucket = "" + consecutiveFailures = 0 + } + + // Wait before retrying with bucket-specific delay + delay := w.getBucketRateLimitDelay(taskBucket) + select { + case <-w.ctx.Done(): + return + case <-time.After(delay): + } + } else { + // Non-rate-limit error, reset rate limit flag + w.rateLimited.Store(false) + } + } else 
{ + // Task succeeded, add any pending tasks after completion + if taskResult != nil && len(taskResult.Tasks) > 0 { + err = w.addTasksAfterCompletion(taskResult.Tasks) + if err != nil { + l.Error("failed to add tasks after completion", + zap.Int("worker_id", w.id), + zap.String("bucket", taskBucket), + zap.Error(err)) + w.taskQueue.Close() + return + } + } + + // Reset failure counters + w.rateLimited.Store(false) + consecutiveFailures = 0 + } + + // Reset processing flag + w.isProcessing.Store(false) + } + } +} + +// processTask processes a single task and returns any tasks that should be created after completion. +func (w *worker) processTask(t *task) (*TaskResult, error) { + ctx, span := parallelTracer.Start(w.ctx, "worker.processTask") + defer span.End() + + span.SetAttributes( + attribute.Int("worker_id", w.id), + attribute.String("operation", t.Action.Op.String()), + attribute.String("resource_type", t.Action.ResourceTypeID), + ) + + switch t.Action.Op { + case SyncResourcesOp: + tasks, err := w.syncer.syncResourcesCollectTasks(ctx, t.Action) + return &TaskResult{ + Tasks: tasks, + Error: err, + }, err + case SyncEntitlementsOp: + if t.Action.ResourceID != "" { + err := w.syncer.syncEntitlementsForResource(ctx, t.Action) + return &TaskResult{Tasks: []*task{}, Error: err}, err + } else { + err := w.syncer.syncEntitlementsForResourceType(ctx, t.Action) + return &TaskResult{Tasks: []*task{}, Error: err}, err + } + case SyncGrantsOp: + if t.Action.ResourceID != "" { + err := w.syncer.syncGrantsForResource(ctx, t.Action) + return &TaskResult{Tasks: []*task{}, Error: err}, err + } else { + err := w.syncer.syncGrantsForResourceType(ctx, t.Action) + return &TaskResult{Tasks: []*task{}, Error: err}, err + } + case CollectEntitlementsAndGrantsTasksOp: + tasks, err := w.syncer.collectEntitlementsAndGrantsTasks(ctx, t.Action) + return &TaskResult{ + Tasks: tasks, + Error: err, + }, err + default: + return nil, fmt.Errorf("unsupported operation: %s", t.Action.Op.String()) + } +} + +// isRateLimitError checks if an error is a rate limit error. +func (w *worker) isRateLimitError(err error) bool { + // Check for rate limit annotations in the error + if err == nil { + return false + } + + // This is a simplified check - in practice, you'd want to check the actual + // error type returned by the connector for rate limiting + return status.Code(err) == codes.ResourceExhausted +} + +// getBucketRateLimitDelay returns the appropriate delay for a bucket based on rate limiting. +func (w *worker) getBucketRateLimitDelay(bucket string) time.Duration { + // Different buckets can have different rate limit delays + // This allows for bucket-specific rate limiting strategies + + switch { + case strings.Contains(bucket, "rate-limited"): + return 2 * time.Second // Longer delay for rate-limited buckets + case strings.Contains(bucket, "fast-apis"): + return 100 * time.Millisecond // Shorter delay for fast APIs + default: + return 1 * time.Second // Default delay + } +} + +// Stop stops the worker. +func (w *worker) Stop() { + w.cancel() +} + +// parallelSyncer extends the base syncer with parallel processing capabilities. +type parallelSyncer struct { + syncer *SequentialSyncer + config *ParallelSyncConfig + taskQueue *taskQueue + workers []*worker + workerWg sync.WaitGroup + mu sync.RWMutex +} + +// NewParallelSyncer creates a new parallel syncer. 
+func NewParallelSyncer(baseSyncer *SequentialSyncer, config *ParallelSyncConfig) *parallelSyncer {
+	if config == nil {
+		config = DefaultParallelSyncConfig()
+	}
+
+	// Enable WAL checkpointing for parallel sync to prevent checkpoint failures under high concurrency
+	baseSyncer.enableWALCheckpoint = true
+
+	return &parallelSyncer{
+		syncer: baseSyncer,
+		config: config,
+	}
+}
+
+// Sync implements the Syncer interface using parallel processing.
+func (ps *parallelSyncer) Sync(ctx context.Context) error {
+	ctx, span := parallelTracer.Start(ctx, "parallelSyncer.Sync")
+	defer span.End()
+
+	l := ctxzap.Extract(ctx)
+
+	// Initialize the sync
+	if err := ps.initializeSync(ctx); err != nil {
+		return err
+	}
+
+	// Create task queue
+	ps.taskQueue = newTaskQueue(ps)
+	defer ps.taskQueue.Close()
+
+	// Start workers
+	if err := ps.startWorkers(ctx); err != nil {
+		return err
+	}
+	defer ps.stopWorkers()
+
+	// Generate initial tasks
+	if err := ps.generateInitialTasks(ctx); err != nil {
+		return err
+	}
+
+	// Wait for all tasks to complete
+	if err := ps.waitForCompletion(ctx); err != nil {
+		return err
+	}
+
+	// Now that all parallel processing is complete, run grant expansion sequentially
+	if err := ps.syncGrantExpansion(ctx); err != nil {
+		l.Error("failed to run grant expansion", zap.Error(err))
+		return fmt.Errorf("failed to run grant expansion: %w", err)
+	}
+
+	// Run external resources sync if configured
+	if ps.syncer.externalResourceReader != nil {
+		if err := ps.syncExternalResources(ctx); err != nil {
+			l.Error("failed to run external resources sync", zap.Error(err))
+			return fmt.Errorf("failed to run external resources sync: %w", err)
+		}
+	}
+
+	// Finalize sync
+	if err := ps.finalizeSync(ctx); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// initializeSync performs the initial sync setup.
+func (ps *parallelSyncer) initializeSync(ctx context.Context) error {
+	// Load store and validate connector (reuse existing logic)
+	if err := ps.syncer.loadStore(ctx); err != nil {
+		return err
+	}
+
+	_, err := ps.syncer.connector.Validate(ctx, &v2.ConnectorServiceValidateRequest{})
+	if err != nil {
+		return err
+	}
+
+	// Start or resume sync
+	_, _, err = ps.syncer.startOrResumeSync(ctx)
+	if err != nil {
+		return err
+	}
+
+	// Set up state
+	currentStep, err := ps.syncer.store.CurrentSyncStep(ctx)
+	if err != nil {
+		return err
+	}
+
+	state := &state{}
+	if err := state.Unmarshal(currentStep); err != nil {
+		return err
+	}
+	ps.syncer.state = state
+
+	// Set progress counts to parallel mode for thread safety
+	if ps.syncer.counts != nil {
+		ps.syncer.counts.SetSequentialMode(false)
+	}
+
+	return nil
+}
+
+// startWorkers starts all worker goroutines.
+func (ps *parallelSyncer) startWorkers(ctx context.Context) error {
+	ps.workers = make([]*worker, ps.config.WorkerCount)
+
+	for i := 0; i < ps.config.WorkerCount; i++ {
+		worker := newWorker(i, ps.taskQueue, ps, ctx, &ps.workerWg)
+		ps.workers[i] = worker
+		ps.workerWg.Add(1)
+		go worker.Start()
+	}
+
+	return nil
+}
+
+// stopWorkers stops all workers.
+func (ps *parallelSyncer) stopWorkers() {
+	for _, worker := range ps.workers {
+		worker.Stop()
+	}
+	ps.workerWg.Wait()
+}
+
+// areWorkersIdle checks if all workers are currently idle (not processing tasks).
+func (ps *parallelSyncer) areWorkersIdle() bool { + ps.mu.RLock() + defer ps.mu.RUnlock() + + for _, worker := range ps.workers { + if worker.isProcessing.Load() { + return false + } + } + return true +} + +// generateInitialTasks creates the initial set of tasks following the original sync workflow. +func (ps *parallelSyncer) generateInitialTasks(ctx context.Context) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.generateInitialTasks") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Follow the exact same workflow as the original sync + // 1. Start with resource types + // 2. Then resources for each resource type (sequentially within each resource type) + // 3. Then entitlements for each resource type (sequentially within each resource type) + // 4. Then grants for each resource type (sequentially within each resource type) + // 5. Then grant expansion and external resources + + // First, sync resource types + if err := ps.syncResourceTypes(ctx); err != nil { + l.Error("failed to sync resource types", zap.Error(err)) + return err + } + + // Get all resource types and create resource sync tasks + resp, err := ps.syncer.store.ListResourceTypes(ctx, &v2.ResourceTypesServiceListResourceTypesRequest{}) + if err != nil { + l.Error("failed to list resource types", zap.Error(err)) + return err + } + + // Group resource types by their buckets for better task organization + bucketGroups := make(map[string][]*v2.ResourceType) + for _, rt := range resp.List { + bucket := ps.getBucketForResourceType(rt) + bucketGroups[bucket] = append(bucketGroups[bucket], rt) + } + + // Create tasks for each bucket, ensuring sequential processing within each bucket + for _, resourceTypes := range bucketGroups { + l := ctxzap.Extract(ctx) + + // Create tasks for this bucket + for _, rt := range resourceTypes { + // Create task to sync resources for this resource type + task := &task{ + Action: Action{ + Op: SyncResourcesOp, + ResourceTypeID: rt.Id, + }, + ResourceType: rt, // Include the resource type for bucket determination + } + + if err := ps.addTaskWithRetry(ctx, task, taskRetryLimit); err != nil { + l.Error("failed to add resource sync task", zap.Error(err)) + return fmt.Errorf("failed to add resource sync task for resource type %s: %w", rt.Id, err) + } + } + } + + // Note: Grant expansion and external resources tasks are NOT added here + // They are added after ALL resource types are completely processed + // This ensures the correct order: resources → entitlements → grants → grant expansion → external resources + + return nil +} + +// getBucketForResourceType determines the bucket for a resource type. +func (ps *parallelSyncer) getBucketForResourceType(rt *v2.ResourceType) string { + // If the resource type has an explicit sync_bucket, use it + if rt.SyncBucket != "" { + return rt.SyncBucket + } + + // If no explicit bucket and default is empty, create a unique bucket per resource type + if ps.config.DefaultBucket == "" { + return fmt.Sprintf("resource-type-%s", rt.Id) + } + + // Otherwise use the configured default bucket + return ps.config.DefaultBucket +} + +// waitForCompletion waits for all tasks to complete with bucket-aware monitoring. 
+func (ps *parallelSyncer) waitForCompletion(ctx context.Context) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.waitForCompletion") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Monitor task completion with periodic status updates + ticker := time.NewTicker(5 * time.Second) + defer ticker.Stop() + + lastTaskCount := 0 + noProgressCount := 0 + maxNoProgressCount := 6 // 30 seconds without progress + + for { + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + // Get current bucket statistics + bucketStats := ps.taskQueue.GetBucketStats() + totalTasks := 0 + for _, count := range bucketStats { + totalTasks += count + } + + // Log progress + if len(bucketStats) > 0 { + // Debug: Log which buckets still have active tasks + activeBuckets := make([]string, 0) + for bucketName, taskCount := range bucketStats { + if taskCount > 0 { + activeBuckets = append(activeBuckets, fmt.Sprintf("%s:%d", bucketName, taskCount)) + } + } + l.Debug("active buckets", zap.Strings("active_buckets", activeBuckets)) + } + + // Check if we're making progress + if totalTasks == lastTaskCount { + noProgressCount++ + if noProgressCount >= maxNoProgressCount { + l.Warn("no task progress detected", + zap.Int("no_progress_count", noProgressCount), + zap.Int("last_task_count", lastTaskCount), + zap.Int("total_tasks", totalTasks)) + } + } else { + noProgressCount = 0 + lastTaskCount = totalTasks + } + + // Check if all resource-specific tasks are complete + // We need to ensure ALL resource types have finished processing + if totalTasks == 0 { + // Double-check that we're truly done with resource processing + // Look for any active resource processing in the bucket stats + allResourceProcessingComplete := true + for _, taskCount := range bucketStats { + if taskCount > 0 { + allResourceProcessingComplete = false + break + } + } + + if allResourceProcessingComplete { + // Additional safety check: wait a bit more to ensure workers are truly idle + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(2 * time.Second): + } + + // Check one more time to ensure no new tasks appeared + finalBucketStats := ps.taskQueue.GetBucketStats() + finalTotalTasks := 0 + for _, count := range finalBucketStats { + finalTotalTasks += count + } + + if finalTotalTasks == 0 { + // Final check: ensure all workers are actually idle + if ps.areWorkersIdle() { + return nil + } else { + // Reset progress counters since we're not done yet + noProgressCount = 0 + lastTaskCount = finalTotalTasks + } + } else { + // Reset progress counters since we're not done yet + noProgressCount = 0 + lastTaskCount = finalTotalTasks + } + } + } + } + } +} + +// syncGrantExpansion handles grant expansion by delegating to the base syncer. 
+func (ps *parallelSyncer) syncGrantExpansion(ctx context.Context) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncGrantExpansion") + defer span.End() + + // The base syncer's SyncGrantExpansion expects to have actions in its state stack + // We need to set up the proper state context before calling it + ps.syncer.state.PushAction(ctx, Action{ + Op: SyncGrantExpansionOp, + }) + + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + } + + currentAction := ps.syncer.state.Current() + if currentAction == nil || currentAction.Op != SyncGrantExpansionOp { + break + } + + // Delegate to the base syncer's grant expansion logic + // This ensures we get the exact same behavior as the sequential sync + err := ps.syncer.SyncGrantExpansion(ctx) + if err != nil { + return err + } + } + + return nil +} + +// syncExternalResources handles external resources by delegating to the base syncer. +func (ps *parallelSyncer) syncExternalResources(ctx context.Context) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncExternalResources") + defer span.End() + + // The base syncer's SyncExternalResources expects to have actions in its state stack + // We need to set up the proper state context before calling it + ps.syncer.state.PushAction(ctx, Action{ + Op: SyncExternalResourcesOp, + }) + + // Delegate to the base syncer's external resources logic + // This ensures we get the exact same behavior as the sequential sync + err := ps.syncer.SyncExternalResources(ctx) + + // Clean up the state + ps.syncer.state.FinishAction(ctx) + + return err +} + +// finalizeSync performs final sync cleanup. +func (ps *parallelSyncer) finalizeSync(ctx context.Context) error { + // End sync + if err := ps.syncer.store.EndSync(ctx); err != nil { + return err + } + + // Cleanup + if err := ps.syncer.store.Cleanup(ctx); err != nil { + return err + } + + _, err := ps.syncer.connector.Cleanup(ctx, &v2.ConnectorServiceCleanupRequest{}) + if err != nil { + ctxzap.Extract(ctx).Error("error clearing connector caches", zap.Error(err)) + } + + return nil +} + +// syncResourceTypes syncs resource types (equivalent to SyncResourceTypes). +func (ps *parallelSyncer) syncResourceTypes(ctx context.Context) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncResourceTypes") + defer span.End() + + // This replicates the exact logic from the original SyncResourceTypes + resp, err := ps.syncer.connector.ListResourceTypes(ctx, &v2.ResourceTypesServiceListResourceTypesRequest{}) + if err != nil { + return err + } + + err = ps.syncer.store.PutResourceTypes(ctx, resp.List...) + if err != nil { + return err + } + + ps.syncer.counts.AddResourceTypes(len(resp.List)) + + return nil +} + +// syncResourcesCollectTasks does the same work as syncResources but collects tasks instead of adding them immediately. 
+func (ps *parallelSyncer) syncResourcesCollectTasks(ctx context.Context, action Action) ([]*task, error) {
+	ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncResourcesCollectTasks")
+	defer span.End()
+
+	l := ctxzap.Extract(ctx)
+	var collectedTasks []*task
+
+	// Add panic recovery to catch any unexpected errors
+	defer func() {
+		if r := recover(); r != nil {
+			l.Error("panic in syncResourcesCollectTasks",
+				zap.String("resource_type", action.ResourceTypeID),
+				zap.Any("panic", r))
+		}
+	}()
+
+	// This replicates the exact logic from the original SyncResources
+	req := &v2.ResourcesServiceListResourcesRequest{
+		ResourceTypeId: action.ResourceTypeID,
+		PageToken:      action.PageToken,
+	}
+
+	// If this is a child resource task, set the parent resource ID
+	if action.ParentResourceID != "" {
+		req.ParentResourceId = &v2.ResourceId{
+			ResourceType: action.ParentResourceTypeID,
+			Resource:     action.ParentResourceID,
+		}
+	}
+
+	resp, err := ps.syncer.connector.ListResources(ctx, req)
+	if err != nil {
+		l.Error("failed to list resources", zap.Error(err))
+		return nil, err
+	}
+
+	// Store resources
+	if len(resp.List) > 0 {
+		err = ps.syncer.store.PutResources(ctx, resp.List...)
+		if err != nil {
+			l.Error("failed to store resources", zap.Error(err))
+			return nil, err
+		}
+	}
+
+	// Update progress counts
+	resourceTypeId := action.ResourceTypeID
+	ps.syncer.counts.AddResources(resourceTypeId, len(resp.List))
+
+	// Log progress
+	ps.syncer.counts.LogResourcesProgress(ctx, resourceTypeId)
+
+	// Process each resource (handle sub-resources)
+	for _, r := range resp.List {
+		if err := ps.syncer.getSubResources(ctx, r); err != nil {
+			l.Error("failed to process sub-resources", zap.Error(err))
+			return nil, err
+		}
+	}
+
+	// Handle pagination - if there are more pages, collect the task for next page
+	if resp.NextPageToken != "" {
+		nextPageTask := &task{
+			Action: Action{
+				Op:             SyncResourcesOp,
+				ResourceTypeID: action.ResourceTypeID,
+				PageToken:      resp.NextPageToken,
+			},
+		}
+		collectedTasks = append(collectedTasks, nextPageTask)
+		return collectedTasks, nil // Don't create entitlement/grant tasks yet, wait for all pages
+	}
+
+	// Check if this resource type has child resource types that need to be processed
+	if err := ps.processChildResourceTypes(ctx, action.ResourceTypeID); err != nil {
+		l.Error("failed to process child resource types", zap.Error(err))
+		return nil, err
+	}
+
+	actionForEntitlementsAndGrants := Action{
+		Op:             CollectEntitlementsAndGrantsTasksOp,
+		ResourceTypeID: action.ResourceTypeID,
+		PageToken:      "",
+	}
+	entitlementsAndGrantsTasks, err := ps.collectEntitlementsAndGrantsTasks(ctx, actionForEntitlementsAndGrants)
+	if err != nil {
+		l.Error("failed to collect entitlements and grants tasks", zap.Error(err))
+		return nil, err
+	}
+
+	collectedTasks = append(collectedTasks, entitlementsAndGrantsTasks...)
+
+	return collectedTasks, nil
+}
+
+// collectEntitlementsAndGrantsTasks creates entitlement and grant sync tasks for each stored resource of a resource type, paging through resources as needed.
+func (ps *parallelSyncer) collectEntitlementsAndGrantsTasks(ctx context.Context, action Action) ([]*task, error) { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.collectEntitlementsAndGrantsTasks") + defer span.End() + + l := ctxzap.Extract(ctx) + var collectedTasks []*task + + // Add panic recovery to catch any unexpected errors + defer func() { + if r := recover(); r != nil { + l.Error("panic in collectEntitlementsAndGrantsTasks", + zap.String("resource_type", action.ResourceTypeID), + zap.Any("panic", r)) + } + }() + + allResourcesResp, err := ps.syncer.store.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: action.ResourceTypeID, + PageToken: action.PageToken, + }) + if err != nil { + l.Error("failed to list resources for task creation", zap.Error(err)) + return nil, err + } + + // Create individual tasks for each resource's entitlements and grants + for _, resource := range allResourcesResp.List { + // Check if we should skip entitlements and grants for this resource + shouldSkip, err := ps.shouldSkipEntitlementsAndGrants(ctx, resource) + if err != nil { + l.Error("failed to check if resource should be skipped", zap.Error(err)) + return nil, err + } + if shouldSkip { + continue + } + + // Create task to sync entitlements for this specific resource + entitlementsTask := &task{ + Action: Action{ + Op: SyncEntitlementsOp, + ResourceTypeID: action.ResourceTypeID, + ResourceID: resource.Id.Resource, + }, + } + collectedTasks = append(collectedTasks, entitlementsTask) + + // Create task to sync grants for this specific resource + grantsTask := &task{ + Action: Action{ + Op: SyncGrantsOp, + ResourceTypeID: action.ResourceTypeID, + ResourceID: resource.Id.Resource, + }, + } + collectedTasks = append(collectedTasks, grantsTask) + } + if allResourcesResp.NextPageToken != "" { + collectedTasks = append(collectedTasks, &task{ + Action: Action{ + Op: CollectEntitlementsAndGrantsTasksOp, + ResourceTypeID: action.ResourceTypeID, + PageToken: allResourcesResp.NextPageToken, + }, + }) + } + + return collectedTasks, nil +} + +// processChildResourceTypes processes child resource types for a given parent resource type. +func (ps *parallelSyncer) processChildResourceTypes(ctx context.Context, parentResourceTypeID string) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.processChildResourceTypes") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Get all resources of the parent resource type + resp, err := ps.syncer.store.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: parentResourceTypeID, + PageToken: "", + }) + if err != nil { + l.Error("failed to list parent resources", zap.Error(err)) + return err + } + + // For each parent resource, check if it has child resource types + for _, parentResource := range resp.List { + if err := ps.processChildResourcesForParent(ctx, parentResource); err != nil { + l.Error("failed to process child resources for parent", + zap.Error(err), + zap.String("parent_resource_id", parentResource.Id.Resource), + zap.String("parent_resource_type", parentResource.Id.ResourceType)) + return err + } + } + + return nil +} + +// processChildResourcesForParent processes child resources for a specific parent resource. 
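+// It inspects the parent resource's ChildResourceType annotations and queues a SyncResourcesOp task, scoped to that parent, for each child resource type it finds.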
+func (ps *parallelSyncer) processChildResourcesForParent(ctx context.Context, parentResource *v2.Resource) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.processChildResourcesForParent") + defer span.End() + + // Check for ChildResourceType annotations + for _, annotation := range parentResource.Annotations { + var childResourceType v2.ChildResourceType + if err := annotation.UnmarshalTo(&childResourceType); err != nil { + // Not a ChildResourceType annotation, skip + continue + } + + childResourceTypeID := childResourceType.ResourceTypeId + + // Create a task to sync child resources for this parent + childResourcesTask := &task{ + Action: Action{ + Op: SyncResourcesOp, + ResourceTypeID: childResourceTypeID, + ParentResourceTypeID: parentResource.Id.ResourceType, + ParentResourceID: parentResource.Id.Resource, + }, + } + + if err := ps.addTaskWithRetry(ctx, childResourcesTask, taskRetryLimit); err != nil { + return fmt.Errorf("failed to add child resources task for %s under parent %s: %w", + childResourceTypeID, parentResource.Id.Resource, err) + } + } + + return nil +} + +// syncEntitlementsForResourceType processes entitlements for all resources of a resource type. +func (ps *parallelSyncer) syncEntitlementsForResourceType(ctx context.Context, action Action) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncEntitlementsForResourceType") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Get all resources for this resource type + resp, err := ps.syncer.store.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: action.ResourceTypeID, + PageToken: action.PageToken, + }) + if err != nil { + l.Error("failed to list resources for entitlements", zap.Error(err)) + return err + } + + // Process each resource's entitlements sequentially + for _, r := range resp.List { + // Check if we should skip entitlements for this resource + shouldSkip, err := ps.shouldSkipEntitlementsAndGrants(ctx, r) + if err != nil { + return err + } + if shouldSkip { + continue + } + + // Create local state context for this resource + localState := NewLocalStateContext(r.Id) + + // Use our state-agnostic method to sync entitlements for this specific resource + decision, err := ps.syncEntitlementsForResourceLogic(ctx, r.Id, localState) + if err != nil { + l.Error("failed to sync entitlements for resource", + zap.String("resource_type", r.Id.ResourceType), + zap.String("resource_id", r.Id.Resource), + zap.Error(err)) + return err + } + + // Handle pagination if needed + for decision.ShouldContinue && decision.Action == nextPageAction { + // Update the local state with the new page token before continuing + if err := localState.NextPage(ctx, decision.NextPageToken); err != nil { + l.Error("failed to update local state with next page token", + zap.String("resource_type", r.Id.ResourceType), + zap.String("page_token", decision.NextPageToken), + zap.Error(err)) + return err + } + + // Continue with next page + decision, err = ps.syncEntitlementsForResourceLogic(ctx, r.Id, localState) + if err != nil { + l.Error("failed to sync entitlements for resource on next page", + zap.String("resource_type", r.Id.ResourceType), + zap.String("resource_id", r.Id.Resource), + zap.Error(err)) + return err + } + } + } + + return nil +} + +// syncEntitlementsForResource processes entitlements for a specific resource. 
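+// It builds the resource ID from the task's action and drives syncEntitlementsForResourceLogic with a local state context, following next-page decisions until no pages remain.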
+func (ps *parallelSyncer) syncEntitlementsForResource(ctx context.Context, action Action) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncEntitlementsForResource") + defer span.End() + + l := ctxzap.Extract(ctx) + // Create resource ID from action + resourceID := &v2.ResourceId{ + ResourceType: action.ResourceTypeID, + Resource: action.ResourceID, + } + + // Create local state context for this resource + localState := NewLocalStateContext(resourceID) + + // Use existing logic but for single resource + decision, err := ps.syncEntitlementsForResourceLogic(ctx, resourceID, localState) + if err != nil { + l.Error("failed to sync entitlements for resource", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.Error(err)) + return err + } + + // Handle pagination if needed + for decision.ShouldContinue && decision.Action == nextPageAction { + // Update the local state with the new page token before continuing + if err := localState.NextPage(ctx, decision.NextPageToken); err != nil { + l.Error("failed to update local state with next page token", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.String("page_token", decision.NextPageToken), + zap.Error(err)) + return err + } + + // Continue with next page + decision, err = ps.syncEntitlementsForResourceLogic(ctx, resourceID, localState) + if err != nil { + l.Error("failed to sync entitlements for resource on next page", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.Error(err)) + return err + } + } + + return nil +} + +// syncGrantsForResource processes grants for a specific resource. +func (ps *parallelSyncer) syncGrantsForResource(ctx context.Context, action Action) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncGrantsForResource") + defer span.End() + + l := ctxzap.Extract(ctx) + // Create resource ID from action + resourceID := &v2.ResourceId{ + ResourceType: action.ResourceTypeID, + Resource: action.ResourceID, + } + + // Create local state context for this resource + localState := NewLocalStateContext(resourceID) + + // Use existing logic but for single resource + decision, err := ps.syncGrantsForResourceLogic(ctx, resourceID, localState) + if err != nil { + l.Error("failed to sync grants for resource", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.Error(err)) + return err + } + + // Handle pagination if needed + for decision.ShouldContinue && decision.Action == nextPageAction { + // Update the local state with the new page token before continuing + if err := localState.NextPage(ctx, decision.NextPageToken); err != nil { + l.Error("failed to update local state with next page token", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.String("page_token", decision.NextPageToken), + zap.Error(err)) + return err + } + + // Continue with next page + decision, err = ps.syncGrantsForResourceLogic(ctx, resourceID, localState) + if err != nil { + l.Error("failed to sync grants for resource on next page", + zap.String("resource_type", action.ResourceTypeID), + zap.String("resource_id", action.ResourceID), + zap.Error(err)) + return err + } + } + + return nil +} + +// syncGrantsForResourceType processes grants for all resources of a resource type. 
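+// It lists the stored resources of the resource type, skips any marked with SkipEntitlementsAndGrants, and pages through syncGrantsForResourceLogic for each remaining resource.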
+func (ps *parallelSyncer) syncGrantsForResourceType(ctx context.Context, action Action) error { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncGrantsForResourceType") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Get all resources for this resource type + resp, err := ps.syncer.store.ListResources(ctx, &v2.ResourcesServiceListResourcesRequest{ + ResourceTypeId: action.ResourceTypeID, + PageToken: action.PageToken, + }) + if err != nil { + l.Error("failed to list resources for grants", zap.Error(err)) + return err + } + + // Process each resource's grants sequentially + for _, r := range resp.List { + // Check if we should skip grants for this resource + shouldSkip, err := ps.shouldSkipEntitlementsAndGrants(ctx, r) + if err != nil { + return err + } + if shouldSkip { + continue + } + + // Create local state context for this resource + localState := NewLocalStateContext(r.Id) + + // Use our state-agnostic method to sync grants for this specific resource + decision, err := ps.syncGrantsForResourceLogic(ctx, r.Id, localState) + if err != nil { + l.Error("failed to sync grants for resource", + zap.String("resource_type", r.Id.ResourceType), + zap.String("resource_id", r.Id.Resource), + zap.Error(err)) + return err + } + + // Handle pagination if needed + for decision.ShouldContinue && decision.Action == nextPageAction { + // Update the local state with the new page token before continuing + if err := localState.NextPage(ctx, decision.NextPageToken); err != nil { + l.Error("failed to update local state with next page token", + zap.String("resource_type", r.Id.ResourceType), + zap.String("resource_id", r.Id.Resource), + zap.String("page_token", decision.NextPageToken), + zap.Error(err)) + return err + } + + // Continue with next page + decision, err = ps.syncGrantsForResourceLogic(ctx, r.Id, localState) + if err != nil { + l.Error("failed to sync grants for resource on next page", + zap.String("resource_type", r.Id.ResourceType), + zap.String("resource_id", r.Id.Resource), + zap.Error(err)) + return err + } + } + } + + return nil +} + +// syncGrantsForResourceLogic contains the core logic for syncing grants for a resource. +// This method is state-agnostic and returns an ActionDecision for the caller to handle. 
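+// It merges any etag-matched grants from the previous sync with the connector response, records expansion and external-resource flags on the supplied state, persists the grants, and reports whether another page is needed.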
+func (ps *parallelSyncer) syncGrantsForResourceLogic(ctx context.Context, resourceID *v2.ResourceId, state StateInterface) (*ActionDecision, error) { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncGrantsForResourceLogic") + defer span.End() + + l := ctxzap.Extract(ctx) + + // Get the resource from the store + resourceResponse, err := ps.syncer.store.GetResource(ctx, &reader_v2.ResourcesReaderServiceGetResourceRequest{ + ResourceId: resourceID, + }) + if err != nil { + return nil, fmt.Errorf("error fetching resource '%s': %w", resourceID.Resource, err) + } + + resource := resourceResponse.Resource + + var prevSyncID string + var prevEtag *v2.ETag + var etagMatch bool + var grants []*v2.Grant + + resourceAnnos := annotations.Annotations(resource.GetAnnotations()) + pageToken := state.PageToken(ctx) + + prevSyncID, prevEtag, err = ps.syncer.fetchResourceForPreviousSync(ctx, resourceID) + if err != nil { + return nil, err + } + resourceAnnos.Update(prevEtag) + resource.Annotations = resourceAnnos + + resp, err := ps.syncer.connector.ListGrants(ctx, &v2.GrantsServiceListGrantsRequest{Resource: resource, PageToken: pageToken}) + if err != nil { + return nil, err + } + + // Fetch any etagged grants for this resource + var etaggedGrants []*v2.Grant + etaggedGrants, etagMatch, err = ps.syncer.fetchEtaggedGrantsForResource(ctx, resource, prevEtag, prevSyncID, resp) + if err != nil { + return nil, err + } + grants = append(grants, etaggedGrants...) + + // We want to process any grants from the previous sync first so that if there is a conflict, the newer data takes precedence + grants = append(grants, resp.List...) + + // Process grants and collect state information + needsExpansion := false + hasExternalResources := false + shouldFetchRelated := state.ShouldFetchRelatedResources() + + for _, grant := range grants { + grantAnnos := annotations.Annotations(grant.GetAnnotations()) + if grantAnnos.Contains(&v2.GrantExpandable{}) { + needsExpansion = true + state.SetNeedsExpansion() + } + if grantAnnos.ContainsAny(&v2.ExternalResourceMatchAll{}, &v2.ExternalResourceMatch{}, &v2.ExternalResourceMatchID{}) { + hasExternalResources = true + state.SetHasExternalResourcesGrants() + } + + if !shouldFetchRelated { + continue + } + // Some connectors emit grants for other resources. If we're doing a partial sync, check if it exists and queue a fetch if not. + entitlementResource := grant.GetEntitlement().GetResource() + _, err := ps.syncer.store.GetResource(ctx, &reader_v2.ResourcesReaderServiceGetResourceRequest{ + ResourceId: entitlementResource.GetId(), + }) + if err != nil { + if !errors.Is(err, sql.ErrNoRows) { + return nil, err + } + + erId := entitlementResource.GetId() + prId := entitlementResource.GetParentResourceId() + resource, err := ps.syncer.getResourceFromConnector(ctx, erId, prId) + if err != nil { + l.Error("error fetching entitlement resource", zap.Error(err)) + return nil, err + } + if resource == nil { + continue + } + if err := ps.syncer.store.PutResources(ctx, resource); err != nil { + return nil, err + } + } + } + + // Store the grants + err = ps.syncer.store.PutGrants(ctx, grants...) + if err != nil { + return nil, err + } + + // Update progress counts + ps.syncer.counts.AddGrantsProgress(resourceID.ResourceType, 1) + ps.syncer.counts.LogGrantsProgress(ctx, resourceID.ResourceType) + + // We may want to update the etag on the resource. If we matched a previous etag, then we should use that. + // Otherwise, we should use the etag from the response if provided. 
+ var updatedETag *v2.ETag + + if etagMatch { + updatedETag = prevEtag + } else { + newETag := &v2.ETag{} + respAnnos := annotations.Annotations(resp.GetAnnotations()) + ok, err := respAnnos.Pick(newETag) + if err != nil { + return nil, err + } + if ok { + updatedETag = newETag + } + } + + if updatedETag != nil { + resourceAnnos.Update(updatedETag) + resource.Annotations = resourceAnnos + err = ps.syncer.store.PutResources(ctx, resource) + if err != nil { + return nil, err + } + } + + // Check if we need to continue with pagination + if resp.NextPageToken != "" { + return &ActionDecision{ + ShouldContinue: true, + NextPageToken: resp.NextPageToken, + Action: nextPageAction, + NeedsExpansion: needsExpansion, + HasExternalResources: hasExternalResources, + ShouldFetchRelated: shouldFetchRelated, + }, nil + } + + // No more pages, action is complete + return &ActionDecision{ + ShouldContinue: false, + Action: finishAction, + NeedsExpansion: needsExpansion, + HasExternalResources: hasExternalResources, + ShouldFetchRelated: shouldFetchRelated, + }, nil +} + +// syncEntitlementsForResourceLogic contains the core logic for syncing entitlements for a resource. +// This method is state-agnostic and returns an ActionDecision for the caller to handle. +func (ps *parallelSyncer) syncEntitlementsForResourceLogic(ctx context.Context, resourceID *v2.ResourceId, state StateInterface) (*ActionDecision, error) { + ctx, span := parallelTracer.Start(ctx, "parallelSyncer.syncEntitlementsForResourceLogic") + defer span.End() + + // Get the resource from the store + resourceResponse, err := ps.syncer.store.GetResource(ctx, &reader_v2.ResourcesReaderServiceGetResourceRequest{ + ResourceId: resourceID, + }) + if err != nil { + return nil, fmt.Errorf("error fetching resource '%s': %w", resourceID.Resource, err) + } + + resource := resourceResponse.Resource + pageToken := state.PageToken(ctx) + + // Call the connector to list entitlements for this resource + resp, err := ps.syncer.connector.ListEntitlements(ctx, &v2.EntitlementsServiceListEntitlementsRequest{ + Resource: resource, + PageToken: pageToken, + }) + if err != nil { + return nil, err + } + + // Store the entitlements + err = ps.syncer.store.PutEntitlements(ctx, resp.List...) + if err != nil { + return nil, err + } + + // Update progress counts + ps.syncer.counts.AddEntitlementsProgress(resourceID.ResourceType, 1) + ps.syncer.counts.LogEntitlementsProgress(ctx, resourceID.ResourceType) + + // Check if we need to continue with pagination + if resp.NextPageToken != "" { + return &ActionDecision{ + ShouldContinue: true, + NextPageToken: resp.NextPageToken, + Action: nextPageAction, + }, nil + } + + // No more pages, action is complete + return &ActionDecision{ + ShouldContinue: false, + Action: finishAction, + }, nil +} + +// shouldSkipEntitlementsAndGrants checks if entitlements and grants should be skipped for a resource. +func (ps *parallelSyncer) shouldSkipEntitlementsAndGrants(ctx context.Context, r *v2.Resource) (bool, error) { + // This replicates the logic from the original shouldSkipEntitlementsAndGrants method + // Check if the resource has the SkipEntitlementsAndGrants annotation + + for _, a := range r.Annotations { + if a.MessageIs((*v2.SkipEntitlementsAndGrants)(nil)) { + return true, nil + } + } + + return false, nil +} + +// Close implements the Syncer interface. 
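+// It stops the workers, closes the task queue, and then closes the underlying base syncer.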
+func (ps *parallelSyncer) Close(ctx context.Context) error { + // Stop all workers + ps.stopWorkers() + + // Close the task queue + if ps.taskQueue != nil { + ps.taskQueue.Close() + } + + // Call the base syncer's Close method + return ps.syncer.Close(ctx) +} + +// GetBucketStats returns statistics about all buckets. +func (ps *parallelSyncer) GetBucketStats() map[string]int { + if ps.taskQueue == nil { + return make(map[string]int) + } + return ps.taskQueue.GetBucketStats() +} + +// GetWorkerStatus returns the status of all workers. +func (ps *parallelSyncer) GetWorkerStatus() []map[string]interface{} { + ps.mu.RLock() + defer ps.mu.RUnlock() + + status := make([]map[string]interface{}, len(ps.workers)) + for i, worker := range ps.workers { + status[i] = map[string]interface{}{ + "worker_id": worker.id, + "rate_limited": worker.rateLimited.Load(), + } + } + return status +} + +// NewParallelSyncerFromSyncer creates a parallel syncer from an existing syncer. +func NewParallelSyncerFromSyncer(s Syncer, config *ParallelSyncConfig) (*parallelSyncer, error) { + // Try to cast to the concrete syncer type + if baseSyncer, ok := s.(*SequentialSyncer); ok { + return NewParallelSyncer(baseSyncer, config), nil + } + + return nil, fmt.Errorf("cannot create parallel syncer from syncer type: %T", s) +} diff --git a/pkg/sync/state.go b/pkg/sync/state.go index 71e14911b..184c3b096 100644 --- a/pkg/sync/state.go +++ b/pkg/sync/state.go @@ -66,6 +66,8 @@ func (s ActionOp) String() string { return "targeted-resource-sync" case SyncStaticEntitlementsOp: return "list-static-entitlements" + case CollectEntitlementsAndGrantsTasksOp: + return "collect-entitlements-and-grants-tasks" default: return "unknown" } @@ -113,6 +115,8 @@ func newActionOp(str string) ActionOp { return SyncStaticEntitlementsOp case ListResourcesForEntitlementsOp.String(): return ListResourcesForEntitlementsOp + case CollectEntitlementsAndGrantsTasksOp.String(): + return CollectEntitlementsAndGrantsTasksOp default: return UnknownOp } @@ -133,6 +137,7 @@ const ( SyncGrantExpansionOp SyncTargetedResourceOp SyncStaticEntitlementsOp + CollectEntitlementsAndGrantsTasksOp ) // Action stores the current operation, page token, and optional fields for which resource is being worked with. 
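Taken together, the additions above are wired up by the callers later in this change (full_sync.go and the local syncer). The following is a minimal sketch of that wiring, not part of the diff: the import paths and the origin of the connector client cc are assumptions, while NewSyncer, WithC1ZPath, DefaultParallelSyncConfig, WithWorkerCount, NewParallelSyncer, Sync, and Close are the names used elsewhere in this change.

package example

import (
	"context"

	sdkSync "github.com/conductorone/baton-sdk/pkg/sync"
	"github.com/conductorone/baton-sdk/pkg/types"
)

// runParallelSync builds the sequential syncer as before and wraps it with the
// parallel syncer. Import paths are assumed; cc comes from the connector runtime.
func runParallelSync(ctx context.Context, cc types.ConnectorClient, c1zPath string) error {
	baseSyncer, err := sdkSync.NewSyncer(ctx, cc, sdkSync.WithC1ZPath(c1zPath))
	if err != nil {
		return err
	}

	// Ten workers; resource types that share a sync_bucket are serialized within
	// that bucket, while different buckets may be processed concurrently.
	cfg := sdkSync.DefaultParallelSyncConfig().WithWorkerCount(10)
	syncer := sdkSync.NewParallelSyncer(baseSyncer, cfg)

	if err := syncer.Sync(ctx); err != nil {
		_ = syncer.Close(ctx)
		return err
	}
	return syncer.Close(ctx)
}

Callers that only hold the Syncer interface can instead use NewParallelSyncerFromSyncer, which type-asserts the underlying *SequentialSyncer and returns an error for any other implementation.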
diff --git a/pkg/sync/syncer.go b/pkg/sync/syncer.go index 48fb33b74..a7e3502b6 100644 --- a/pkg/sync/syncer.go +++ b/pkg/sync/syncer.go @@ -12,6 +12,7 @@ import ( "slices" "strconv" "strings" + "sync" "time" "github.com/Masterminds/semver/v3" @@ -63,6 +64,8 @@ type ProgressCounts struct { GrantsProgress map[string]int LastGrantLog map[string]time.Time LastActionLog time.Time + mu sync.RWMutex // Protect concurrent access to maps + sequentialMode bool // Disable mutex protection for sequential sync } const maxLogFrequency = 10 * time.Second @@ -76,6 +79,7 @@ func NewProgressCounts() *ProgressCounts { GrantsProgress: make(map[string]int), LastGrantLog: make(map[string]time.Time), LastActionLog: time.Time{}, + sequentialMode: true, // Default to sequential mode for backward compatibility } } @@ -85,24 +89,50 @@ func (p *ProgressCounts) LogResourceTypesProgress(ctx context.Context) { } func (p *ProgressCounts) LogResourcesProgress(ctx context.Context, resourceType string) { + var resources int + if p.sequentialMode { + resources = p.Resources[resourceType] + } else { + p.mu.RLock() + resources = p.Resources[resourceType] + p.mu.RUnlock() + } + l := ctxzap.Extract(ctx) - resources := p.Resources[resourceType] l.Info("Synced resources", zap.String("resource_type_id", resourceType), zap.Int("count", resources)) } func (p *ProgressCounts) LogEntitlementsProgress(ctx context.Context, resourceType string) { - entitlementsProgress := p.EntitlementsProgress[resourceType] - resources := p.Resources[resourceType] + var entitlementsProgress, resources int + var lastLogTime time.Time + + if p.sequentialMode { + entitlementsProgress = p.EntitlementsProgress[resourceType] + resources = p.Resources[resourceType] + lastLogTime = p.LastEntitlementLog[resourceType] + } else { + p.mu.RLock() + entitlementsProgress = p.EntitlementsProgress[resourceType] + resources = p.Resources[resourceType] + lastLogTime = p.LastEntitlementLog[resourceType] + p.mu.RUnlock() + } l := ctxzap.Extract(ctx) if resources == 0 { // if resuming sync, resource counts will be zero, so don't calculate percentage. just log every 10 seconds. 
- if time.Since(p.LastEntitlementLog[resourceType]) > maxLogFrequency { + if time.Since(lastLogTime) > maxLogFrequency { l.Info("Syncing entitlements", zap.String("resource_type_id", resourceType), zap.Int("synced", entitlementsProgress), ) - p.LastEntitlementLog[resourceType] = time.Now() + if !p.sequentialMode { + p.mu.Lock() + p.LastEntitlementLog[resourceType] = time.Now() + p.mu.Unlock() + } else { + p.LastEntitlementLog[resourceType] = time.Now() + } } return } @@ -116,8 +146,14 @@ func (p *ProgressCounts) LogEntitlementsProgress(ctx context.Context, resourceTy zap.Int("count", entitlementsProgress), zap.Int("total", resources), ) - p.LastEntitlementLog[resourceType] = time.Time{} - case time.Since(p.LastEntitlementLog[resourceType]) > maxLogFrequency: + if !p.sequentialMode { + p.mu.Lock() + p.LastEntitlementLog[resourceType] = time.Time{} + p.mu.Unlock() + } else { + p.LastEntitlementLog[resourceType] = time.Time{} + } + case time.Since(lastLogTime) > maxLogFrequency: if entitlementsProgress > resources { l.Warn("more entitlement resources than resources", zap.String("resource_type_id", resourceType), @@ -132,23 +168,47 @@ func (p *ProgressCounts) LogEntitlementsProgress(ctx context.Context, resourceTy zap.Int("percent_complete", percentComplete), ) } - p.LastEntitlementLog[resourceType] = time.Now() + if !p.sequentialMode { + p.mu.Lock() + p.LastEntitlementLog[resourceType] = time.Now() + p.mu.Unlock() + } else { + p.LastEntitlementLog[resourceType] = time.Now() + } } } func (p *ProgressCounts) LogGrantsProgress(ctx context.Context, resourceType string) { - grantsProgress := p.GrantsProgress[resourceType] - resources := p.Resources[resourceType] + var grantsProgress, resources int + var lastLogTime time.Time + + if p.sequentialMode { + grantsProgress = p.GrantsProgress[resourceType] + resources = p.Resources[resourceType] + lastLogTime = p.LastGrantLog[resourceType] + } else { + p.mu.RLock() + grantsProgress = p.GrantsProgress[resourceType] + resources = p.Resources[resourceType] + lastLogTime = p.LastGrantLog[resourceType] + p.mu.RUnlock() + } l := ctxzap.Extract(ctx) if resources == 0 { // if resuming sync, resource counts will be zero, so don't calculate percentage. just log every 10 seconds. 
- if time.Since(p.LastGrantLog[resourceType]) > maxLogFrequency { + if time.Since(lastLogTime) > maxLogFrequency { l.Info("Syncing grants", zap.String("resource_type_id", resourceType), zap.Int("synced", grantsProgress), ) - p.LastGrantLog[resourceType] = time.Now() + if !p.sequentialMode { + p.mu.Lock() + p.LastGrantLog[resourceType] = time.Now() + p.mu.Unlock() + } else { + p.LastGrantLog[resourceType] = time.Now() + } } return } @@ -162,8 +222,14 @@ func (p *ProgressCounts) LogGrantsProgress(ctx context.Context, resourceType str zap.Int("count", grantsProgress), zap.Int("total", resources), ) - p.LastGrantLog[resourceType] = time.Time{} - case time.Since(p.LastGrantLog[resourceType]) > maxLogFrequency: + if !p.sequentialMode { + p.mu.Lock() + p.LastGrantLog[resourceType] = time.Time{} + p.mu.Unlock() + } else { + p.LastGrantLog[resourceType] = time.Time{} + } + case time.Since(lastLogTime) > maxLogFrequency: if grantsProgress > resources { l.Warn("more grant resources than resources", zap.String("resource_type_id", resourceType), @@ -178,23 +244,75 @@ func (p *ProgressCounts) LogGrantsProgress(ctx context.Context, resourceType str zap.Int("percent_complete", percentComplete), ) } - p.LastGrantLog[resourceType] = time.Now() + if !p.sequentialMode { + p.mu.Lock() + p.LastGrantLog[resourceType] = time.Now() + p.mu.Unlock() + } else { + p.LastGrantLog[resourceType] = time.Now() + } } } func (p *ProgressCounts) LogExpandProgress(ctx context.Context, actions []*expand.EntitlementGraphAction) { actionsLen := len(actions) - if time.Since(p.LastActionLog) < maxLogFrequency { - return + + if p.sequentialMode { + if time.Since(p.LastActionLog) < maxLogFrequency { + return + } + p.LastActionLog = time.Now() + } else { + p.mu.Lock() + if time.Since(p.LastActionLog) < maxLogFrequency { + p.mu.Unlock() + return + } + p.LastActionLog = time.Now() + p.mu.Unlock() } - p.LastActionLog = time.Now() l := ctxzap.Extract(ctx) l.Info("Expanding grants", zap.Int("actions_remaining", actionsLen)) } -// syncer orchestrates a connector sync and stores the results using the provided datasource.Writer. -type syncer struct { +// Thread-safe methods for parallel syncer + +// AddResourceTypes safely adds to the resource types count. +func (p *ProgressCounts) AddResourceTypes(count int) { + p.mu.Lock() + defer p.mu.Unlock() + p.ResourceTypes += count +} + +// AddResources safely adds to the resources count for a specific resource type. +func (p *ProgressCounts) AddResources(resourceType string, count int) { + p.mu.Lock() + defer p.mu.Unlock() + p.Resources[resourceType] += count +} + +// AddEntitlementsProgress safely adds to the entitlements progress count for a specific resource type. +func (p *ProgressCounts) AddEntitlementsProgress(resourceType string, count int) { + p.mu.Lock() + defer p.mu.Unlock() + p.EntitlementsProgress[resourceType] += count +} + +// AddGrantsProgress safely adds to the grants progress count for a specific resource type. +func (p *ProgressCounts) AddGrantsProgress(resourceType string, count int) { + p.mu.Lock() + defer p.mu.Unlock() + p.GrantsProgress[resourceType] += count +} + +// SetSequentialMode enables/disables mutex protection for sequential sync. +func (p *ProgressCounts) SetSequentialMode(sequential bool) { + p.sequentialMode = sequential +} + +// SequentialSyncer orchestrates a connector sync and stores the results using the provided datasource.Writer. 
+type SequentialSyncer struct { c1zManager manager.Manager c1zPath string externalResourceC1ZPath string @@ -222,12 +340,13 @@ type syncer struct { injectSyncIDAnnotation bool setSessionStore sessions.SetSessionStore syncResourceTypes []string + enableWALCheckpoint bool } const minCheckpointInterval = 10 * time.Second // Checkpoint marshals the current state and stores it. -func (s *syncer) Checkpoint(ctx context.Context, force bool) error { +func (s *SequentialSyncer) Checkpoint(ctx context.Context, force bool) error { if !force && !s.lastCheckPointTime.IsZero() && time.Since(s.lastCheckPointTime) < minCheckpointInterval { return nil } @@ -247,13 +366,13 @@ func (s *syncer) Checkpoint(ctx context.Context, force bool) error { return nil } -func (s *syncer) handleInitialActionForStep(ctx context.Context, a Action) { +func (s *SequentialSyncer) handleInitialActionForStep(ctx context.Context, a Action) { if s.transitionHandler != nil { s.transitionHandler(a) } } -func (s *syncer) handleProgress(ctx context.Context, a *Action, c int) { +func (s *SequentialSyncer) handleProgress(ctx context.Context, a *Action, c int) { if s.progressHandler != nil { //nolint:gosec // No risk of overflow because `c` is a slice length. count := uint32(c) @@ -273,7 +392,7 @@ func isWarning(ctx context.Context, err error) bool { return false } -func (s *syncer) startOrResumeSync(ctx context.Context) (string, bool, error) { +func (s *SequentialSyncer) startOrResumeSync(ctx context.Context) (string, bool, error) { // Sync resuming logic: // If we know our sync ID, set it as the current sync and return (resuming that sync). // If targetedSyncResourceIDs is not set, find the most recent unfinished sync of our desired sync type & resume it (regardless of partial or full). @@ -320,7 +439,7 @@ func (s *syncer) startOrResumeSync(ctx context.Context) (string, bool, error) { return syncID, newSync, nil } -func (s *syncer) getActiveSyncID() string { +func (s *SequentialSyncer) getActiveSyncID() string { if s.injectSyncIDAnnotation { return s.syncID } @@ -331,7 +450,7 @@ func (s *syncer) getActiveSyncID() string { // For each page of data that is required to be fetched from the connector, a new action is pushed on to the stack. Once // an action is completed, it is popped off of the queue. Before processing each action, we checkpoint the state object // into the datasource. This allows for graceful resumes if a sync is interrupted. -func (s *syncer) Sync(ctx context.Context) error { +func (s *SequentialSyncer) Sync(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.Sync") defer span.End() @@ -689,7 +808,7 @@ func (s *syncer) Sync(ctx context.Context) error { return nil } -func (s *syncer) SkipSync(ctx context.Context) error { +func (s *SequentialSyncer) SkipSync(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SkipSync") defer span.End() @@ -733,7 +852,7 @@ func (s *syncer) SkipSync(ctx context.Context) error { return nil } -func (s *syncer) listAllResourceTypes(ctx context.Context) iter.Seq2[[]*v2.ResourceType, error] { +func (s *SequentialSyncer) listAllResourceTypes(ctx context.Context) iter.Seq2[[]*v2.ResourceType, error] { return func(yield func([]*v2.ResourceType, error) bool) { pageToken := "" for { @@ -757,7 +876,7 @@ func (s *syncer) listAllResourceTypes(ctx context.Context) iter.Seq2[[]*v2.Resou } // SyncResourceTypes calls the ListResourceType() connector endpoint and persists the results in to the datasource. 
-func (s *syncer) SyncResourceTypes(ctx context.Context) error { +func (s *SequentialSyncer) SyncResourceTypes(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncResourceTypes") defer span.End() @@ -844,7 +963,7 @@ func validateSyncResourceTypesFilter(resourceTypesFilter []string, validResource } // getSubResources fetches the sub resource types from a resources' annotations. -func (s *syncer) getSubResources(ctx context.Context, parent *v2.Resource) error { +func (s *SequentialSyncer) getSubResources(ctx context.Context, parent *v2.Resource) error { ctx, span := tracer.Start(ctx, "syncer.getSubResources") defer span.End() @@ -869,7 +988,7 @@ func (s *syncer) getSubResources(ctx context.Context, parent *v2.Resource) error return nil } -func (s *syncer) getResourceFromConnector(ctx context.Context, resourceID *v2.ResourceId, parentResourceID *v2.ResourceId) (*v2.Resource, error) { +func (s *SequentialSyncer) getResourceFromConnector(ctx context.Context, resourceID *v2.ResourceId, parentResourceID *v2.ResourceId) (*v2.Resource, error) { ctx, span := tracer.Start(ctx, "syncer.getResource") defer span.End() @@ -895,7 +1014,7 @@ func (s *syncer) getResourceFromConnector(ctx context.Context, resourceID *v2.Re return nil, err } -func (s *syncer) SyncTargetedResource(ctx context.Context) error { +func (s *SequentialSyncer) SyncTargetedResource(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncTargetedResource") defer span.End() @@ -973,7 +1092,7 @@ func (s *syncer) SyncTargetedResource(ctx context.Context) error { // SyncResources handles fetching all of the resources from the connector given the provided resource types. For each // resource, we gather any child resource types it may emit, and traverse the resource tree. -func (s *syncer) SyncResources(ctx context.Context) error { +func (s *SequentialSyncer) SyncResources(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncResources") defer span.End() @@ -1017,7 +1136,7 @@ func (s *syncer) SyncResources(ctx context.Context) error { } // syncResources fetches a given resource from the connector, and returns a slice of new child resources to fetch. -func (s *syncer) syncResources(ctx context.Context) error { +func (s *SequentialSyncer) syncResources(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.syncResources") defer span.End() @@ -1093,7 +1212,7 @@ func (s *syncer) syncResources(ctx context.Context) error { return nil } -func (s *syncer) validateResourceTraits(ctx context.Context, r *v2.Resource) error { +func (s *SequentialSyncer) validateResourceTraits(ctx context.Context, r *v2.Resource) error { ctx, span := tracer.Start(ctx, "syncer.validateResourceTraits") defer span.End() @@ -1144,7 +1263,7 @@ func (s *syncer) validateResourceTraits(ctx context.Context, r *v2.Resource) err // shouldSkipEntitlementsAndGrants determines if we should sync entitlements for a given resource. We cache the // result of this function for each resource type to avoid constant lookups in the database. 
-func (s *syncer) shouldSkipEntitlementsAndGrants(ctx context.Context, r *v2.Resource) (bool, error) { +func (s *SequentialSyncer) shouldSkipEntitlementsAndGrants(ctx context.Context, r *v2.Resource) (bool, error) { ctx, span := tracer.Start(ctx, "syncer.shouldSkipEntitlementsAndGrants") defer span.End() @@ -1177,7 +1296,7 @@ func (s *syncer) shouldSkipEntitlementsAndGrants(ctx context.Context, r *v2.Reso return skipEntitlements, nil } -func (s *syncer) shouldSkipGrants(ctx context.Context, r *v2.Resource) (bool, error) { +func (s *SequentialSyncer) shouldSkipGrants(ctx context.Context, r *v2.Resource) (bool, error) { annos := annotations.Annotations(r.GetAnnotations()) if annos.Contains(&v2.SkipGrants{}) { return true, nil @@ -1192,7 +1311,7 @@ func (s *syncer) shouldSkipGrants(ctx context.Context, r *v2.Resource) (bool, er // SyncEntitlements fetches the entitlements from the connector. It first lists each resource from the datastore, // and pushes an action to fetch the entitlements for each resource. -func (s *syncer) SyncEntitlements(ctx context.Context) error { +func (s *SequentialSyncer) SyncEntitlements(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncEntitlements") defer span.End() @@ -1245,7 +1364,7 @@ func (s *syncer) SyncEntitlements(ctx context.Context) error { } // syncEntitlementsForResource fetches the entitlements for a specific resource from the connector. -func (s *syncer) syncEntitlementsForResource(ctx context.Context, resourceID *v2.ResourceId) error { +func (s *SequentialSyncer) syncEntitlementsForResource(ctx context.Context, resourceID *v2.ResourceId) error { ctx, span := tracer.Start(ctx, "syncer.syncEntitlementsForResource") defer span.End() @@ -1281,7 +1400,13 @@ func (s *syncer) syncEntitlementsForResource(ctx context.Context, resourceID *v2 return err } } else { - s.counts.EntitlementsProgress[resourceID.GetResourceType()] += 1 + if s.counts.sequentialMode { + s.counts.EntitlementsProgress[resourceID.ResourceType] += 1 + } else { + s.counts.mu.Lock() + s.counts.EntitlementsProgress[resourceID.ResourceType] += 1 + s.counts.mu.Unlock() + } s.counts.LogEntitlementsProgress(ctx, resourceID.GetResourceType()) s.state.FinishAction(ctx) @@ -1290,7 +1415,7 @@ func (s *syncer) syncEntitlementsForResource(ctx context.Context, resourceID *v2 return nil } -func (s *syncer) SyncStaticEntitlements(ctx context.Context) error { +func (s *SequentialSyncer) SyncStaticEntitlements(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncStaticEntitlements") defer span.End() @@ -1315,7 +1440,7 @@ func (s *syncer) SyncStaticEntitlements(ctx context.Context) error { return nil } -func (s *syncer) syncStaticEntitlementsForResourceType(ctx context.Context, resourceTypeID string) error { +func (s *SequentialSyncer) syncStaticEntitlementsForResourceType(ctx context.Context, resourceTypeID string) error { ctx, span := tracer.Start(ctx, "syncer.syncStaticEntitlementsForResource") defer span.End() @@ -1396,7 +1521,7 @@ func (s *syncer) syncStaticEntitlementsForResourceType(ctx context.Context, reso // syncAssetsForResource looks up a resource given the input ID. From there it looks to see if there are any traits that // include references to an asset. For each AssetRef, we then call GetAsset on the connector and stream the asset from the connector. // Once we have the entire asset, we put it in the database. 
-func (s *syncer) syncAssetsForResource(ctx context.Context, resourceID *v2.ResourceId) error { +func (s *SequentialSyncer) syncAssetsForResource(ctx context.Context, resourceID *v2.ResourceId) error { ctx, span := tracer.Start(ctx, "syncer.syncAssetsForResource") defer span.End() @@ -1504,7 +1629,7 @@ func (s *syncer) syncAssetsForResource(ctx context.Context, resourceID *v2.Resou } // SyncAssets iterates each resource in the data store, and adds an action to fetch all of the assets for that resource. -func (s *syncer) SyncAssets(ctx context.Context) error { +func (s *SequentialSyncer) SyncAssets(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncAssets") defer span.End() @@ -1551,7 +1676,7 @@ func (s *syncer) SyncAssets(ctx context.Context) error { } // SyncGrantExpansion documentation pending. -func (s *syncer) SyncGrantExpansion(ctx context.Context) error { +func (s *SequentialSyncer) SyncGrantExpansion(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncGrantExpansion") defer span.End() @@ -1687,7 +1812,7 @@ func (s *syncer) SyncGrantExpansion(ctx context.Context) error { // SyncGrants fetches the grants for each resource from the connector. It iterates each resource // from the datastore, and pushes a new action to sync the grants for each individual resource. -func (s *syncer) SyncGrants(ctx context.Context) error { +func (s *SequentialSyncer) SyncGrants(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncGrants") defer span.End() @@ -1743,7 +1868,7 @@ type latestSyncFetcher interface { LatestFinishedSync(ctx context.Context, syncType connectorstore.SyncType) (string, error) } -func (s *syncer) fetchResourceForPreviousSync(ctx context.Context, resourceID *v2.ResourceId) (string, *v2.ETag, error) { +func (s *SequentialSyncer) fetchResourceForPreviousSync(ctx context.Context, resourceID *v2.ResourceId) (string, *v2.ETag, error) { ctx, span := tracer.Start(ctx, "syncer.fetchResourceForPreviousSync") defer span.End() @@ -1797,7 +1922,7 @@ func (s *syncer) fetchResourceForPreviousSync(ctx context.Context, resourceID *v return previousSyncID, nil, nil } -func (s *syncer) fetchEtaggedGrantsForResource( +func (s *SequentialSyncer) fetchEtaggedGrantsForResource( ctx context.Context, resource *v2.Resource, prevEtag *v2.ETag, @@ -1864,7 +1989,7 @@ func (s *syncer) fetchEtaggedGrantsForResource( } // syncGrantsForResource fetches the grants for a specific resource from the connector. 
-func (s *syncer) syncGrantsForResource(ctx context.Context, resourceID *v2.ResourceId) error { +func (s *SequentialSyncer) syncGrantsForResource(ctx context.Context, resourceID *v2.ResourceId) error { ctx, span := tracer.Start(ctx, "syncer.syncGrantsForResource") defer span.End() @@ -2024,7 +2149,7 @@ func (s *syncer) syncGrantsForResource(ctx context.Context, resourceID *v2.Resou return nil } -func (s *syncer) SyncExternalResources(ctx context.Context) error { +func (s *SequentialSyncer) SyncExternalResources(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncExternalResources") defer span.End() @@ -2038,7 +2163,7 @@ func (s *syncer) SyncExternalResources(ctx context.Context) error { } } -func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context, entitlementId string) error { +func (s *SequentialSyncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context, entitlementId string) error { ctx, span := tracer.Start(ctx, "syncer.SyncExternalResourcesWithGrantToEntitlement") defer span.End() @@ -2175,7 +2300,7 @@ func (s *syncer) SyncExternalResourcesWithGrantToEntitlement(ctx context.Context return nil } -func (s *syncer) SyncExternalResourcesUsersAndGroups(ctx context.Context) error { +func (s *SequentialSyncer) SyncExternalResourcesUsersAndGroups(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.SyncExternalResourcesUsersAndGroups") defer span.End() @@ -2289,7 +2414,7 @@ func (s *syncer) SyncExternalResourcesUsersAndGroups(ctx context.Context) error return nil } -func (s *syncer) listExternalResourcesForResourceType(ctx context.Context, resourceTypeId string) ([]*v2.Resource, error) { +func (s *SequentialSyncer) listExternalResourcesForResourceType(ctx context.Context, resourceTypeId string) ([]*v2.Resource, error) { resources := make([]*v2.Resource, 0) pageToken := "" for { @@ -2309,7 +2434,7 @@ func (s *syncer) listExternalResourcesForResourceType(ctx context.Context, resou return resources, nil } -func (s *syncer) listExternalEntitlementsForResource(ctx context.Context, resource *v2.Resource) ([]*v2.Entitlement, error) { +func (s *SequentialSyncer) listExternalEntitlementsForResource(ctx context.Context, resource *v2.Resource) ([]*v2.Entitlement, error) { ents := make([]*v2.Entitlement, 0) entitlementToken := "" @@ -2330,7 +2455,7 @@ func (s *syncer) listExternalEntitlementsForResource(ctx context.Context, resour return ents, nil } -func (s *syncer) listExternalGrantsForEntitlement(ctx context.Context, ent *v2.Entitlement) iter.Seq2[[]*v2.Grant, error] { +func (s *SequentialSyncer) listExternalGrantsForEntitlement(ctx context.Context, ent *v2.Entitlement) iter.Seq2[[]*v2.Grant, error] { return func(yield func([]*v2.Grant, error) bool) { pageToken := "" for { @@ -2356,7 +2481,7 @@ func (s *syncer) listExternalGrantsForEntitlement(ctx context.Context, ent *v2.E } } -func (s *syncer) listExternalResourceTypes(ctx context.Context) ([]*v2.ResourceType, error) { +func (s *SequentialSyncer) listExternalResourceTypes(ctx context.Context) ([]*v2.ResourceType, error) { resourceTypes := make([]*v2.ResourceType, 0) rtPageToken := "" for { @@ -2375,7 +2500,7 @@ func (s *syncer) listExternalResourceTypes(ctx context.Context) ([]*v2.ResourceT return resourceTypes, nil } -func (s *syncer) listAllGrants(ctx context.Context) iter.Seq2[[]*v2.Grant, error] { +func (s *SequentialSyncer) listAllGrants(ctx context.Context) iter.Seq2[[]*v2.Grant, error] { return func(yield func([]*v2.Grant, error) bool) { pageToken := "" for { @@ -2400,7 
+2525,7 @@ func (s *syncer) listAllGrants(ctx context.Context) iter.Seq2[[]*v2.Grant, error } } -func (s *syncer) processGrantsWithExternalPrincipals(ctx context.Context, principals []*v2.Resource) error { +func (s *SequentialSyncer) processGrantsWithExternalPrincipals(ctx context.Context, principals []*v2.Resource) error { ctx, span := tracer.Start(ctx, "processGrantsWithExternalPrincipals") defer span.End() @@ -2707,7 +2832,7 @@ func GetExpandableAnnotation(annos annotations.Annotations) (*v2.GrantExpandable return expandableAnno, nil } -func (s *syncer) runGrantExpandActions(ctx context.Context) (bool, error) { +func (s *SequentialSyncer) runGrantExpandActions(ctx context.Context) (bool, error) { ctx, span := tracer.Start(ctx, "syncer.runGrantExpandActions") defer span.End() @@ -2862,7 +2987,7 @@ func (s *syncer) runGrantExpandActions(ctx context.Context) (bool, error) { return false, nil } -func (s *syncer) newExpandedGrant(_ context.Context, descEntitlement *v2.Entitlement, principal *v2.Resource) (*v2.Grant, error) { +func (s *SequentialSyncer) newExpandedGrant(_ context.Context, descEntitlement *v2.Entitlement, principal *v2.Resource) (*v2.Grant, error) { enResource := descEntitlement.GetResource() if enResource == nil { return nil, fmt.Errorf("newExpandedGrant: entitlement has no resource") @@ -2887,7 +3012,7 @@ func (s *syncer) newExpandedGrant(_ context.Context, descEntitlement *v2.Entitle } // expandGrantsForEntitlements expands grants for the given entitlement. -func (s *syncer) expandGrantsForEntitlements(ctx context.Context) error { +func (s *SequentialSyncer) expandGrantsForEntitlements(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.expandGrantsForEntitlements") defer span.End() @@ -2971,7 +3096,7 @@ func (s *syncer) expandGrantsForEntitlements(ctx context.Context) error { return nil } -func (s *syncer) loadStore(ctx context.Context) error { +func (s *SequentialSyncer) loadStore(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.loadStore") defer span.End() @@ -2980,7 +3105,12 @@ func (s *syncer) loadStore(ctx context.Context) error { } if s.c1zManager == nil { - m, err := manager.New(ctx, s.c1zPath, manager.WithTmpDir(s.tmpDir)) + opts := []manager.ManagerOption{manager.WithTmpDir(s.tmpDir)} + // Enable WAL checkpointing for parallel sync to prevent checkpoint failures under high concurrency + if s.enableWALCheckpoint { + opts = append(opts, manager.WithWALCheckpoint(true)) + } + m, err := manager.New(ctx, s.c1zPath, opts...) if err != nil { return err } @@ -3001,7 +3131,7 @@ func (s *syncer) loadStore(ctx context.Context) error { } // Close closes the datastorage to ensure it is updated on disk. -func (s *syncer) Close(ctx context.Context) error { +func (s *SequentialSyncer) Close(ctx context.Context) error { ctx, span := tracer.Start(ctx, "syncer.Close") defer span.End() @@ -3035,12 +3165,12 @@ func (s *syncer) Close(ctx context.Context) error { return nil } -type SyncOpt func(s *syncer) +type SyncOpt func(s *SequentialSyncer) // WithRunDuration sets a `time.Duration` for `NewSyncer` Options. // `d` represents a duration. The elapsed time between two instants as an int64 nanosecond count. func WithRunDuration(d time.Duration) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { if d > 0 { s.runDuration = d } @@ -3049,7 +3179,7 @@ func WithRunDuration(d time.Duration) SyncOpt { // WithTransitionHandler sets a `transitionHandler` for `NewSyncer` Options. 
func WithTransitionHandler(f func(s Action)) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { if f != nil { s.transitionHandler = f } @@ -3058,7 +3188,7 @@ func WithTransitionHandler(f func(s Action)) SyncOpt { // WithProgress sets a `progressHandler` for `NewSyncer` Options. func WithProgressHandler(f func(s *Progress)) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { if f != nil { s.progressHandler = f } @@ -3066,43 +3196,43 @@ func WithProgressHandler(f func(s *Progress)) SyncOpt { } func WithConnectorStore(store connectorstore.Writer) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.store = store } } func WithC1ZPath(path string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.c1zPath = path } } func WithTmpDir(path string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.tmpDir = path } } func WithSkipFullSync() SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.skipFullSync = true } } func WithExternalResourceC1ZPath(path string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.externalResourceC1ZPath = path } } func WithExternalResourceEntitlementIdFilter(entitlementId string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.externalResourceEntitlementIdFilter = entitlementId } } func WithTargetedSyncResourceIDs(resourceIDs []string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.targetedSyncResourceIDs = resourceIDs if len(resourceIDs) > 0 { s.syncType = connectorstore.SyncTypePartial @@ -3114,36 +3244,36 @@ func WithTargetedSyncResourceIDs(resourceIDs []string) SyncOpt { } func WithSessionStore(sessionStore sessions.SetSessionStore) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.setSessionStore = sessionStore } } func WithSyncResourceTypes(resourceTypeIDs []string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.syncResourceTypes = resourceTypeIDs } } func WithOnlyExpandGrants() SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.onlyExpandGrants = true } } func WithDontExpandGrants() SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.dontExpandGrants = true } } func WithSyncID(syncID string) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.syncID = syncID } } func WithSkipEntitlementsAndGrants(skip bool) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.skipEntitlementsAndGrants = skip // Partial syncs can skip entitlements and grants, so don't update the sync type in that case. if s.syncType == connectorstore.SyncTypePartial { @@ -3158,14 +3288,14 @@ func WithSkipEntitlementsAndGrants(skip bool) SyncOpt { } func WithSkipGrants(skip bool) SyncOpt { - return func(s *syncer) { + return func(s *SequentialSyncer) { s.skipGrants = skip } } // NewSyncer returns a new syncer object. 
-func NewSyncer(ctx context.Context, c types.ConnectorClient, opts ...SyncOpt) (Syncer, error) { - s := &syncer{ +func NewSyncer(ctx context.Context, c types.ConnectorClient, opts ...SyncOpt) (*SequentialSyncer, error) { + s := &SequentialSyncer{ connector: c, skipEGForResourceType: make(map[string]bool), resourceTypeTraits: make(map[string][]v2.ResourceType_Trait), diff --git a/pkg/tasks/c1api/full_sync.go b/pkg/tasks/c1api/full_sync.go index 24266a6f7..8cd81e2dd 100644 --- a/pkg/tasks/c1api/full_sync.go +++ b/pkg/tasks/c1api/full_sync.go @@ -35,6 +35,7 @@ type fullSyncTaskHandler struct { externalResourceEntitlementIdFilter string targetedSyncResourceIDs []string syncResourceTypeIDs []string + parallelSync bool } func (c *fullSyncTaskHandler) sync(ctx context.Context, c1zPath string) error { @@ -87,12 +88,20 @@ func (c *fullSyncTaskHandler) sync(ctx context.Context, c1zPath string) error { syncOpts = append(syncOpts, sdkSync.WithSessionStore(setSessionStore)) } - syncer, err := sdkSync.NewSyncer(ctx, cc, syncOpts...) + var syncer sdkSync.Syncer + baseSyncer, err := sdkSync.NewSyncer(ctx, c.helpers.ConnectorClient(), syncOpts...) if err != nil { l.Error("failed to create syncer", zap.Error(err)) return err } + if c.parallelSync { + config := sdkSync.DefaultParallelSyncConfig().WithWorkerCount(10) + syncer = sdkSync.NewParallelSyncer(baseSyncer, config) + } else { + syncer = baseSyncer + } + // TODO(jirwin): Should we attempt to retry at all before failing the task? err = syncer.Sync(ctx) if err != nil { @@ -194,6 +203,7 @@ func newFullSyncTaskHandler( externalResourceEntitlementIdFilter string, targetedSyncResourceIDs []string, syncResourceTypeIDs []string, + parallelSync bool, ) tasks.TaskHandler { return &fullSyncTaskHandler{ task: task, @@ -203,6 +213,7 @@ func newFullSyncTaskHandler( externalResourceEntitlementIdFilter: externalResourceEntitlementIdFilter, targetedSyncResourceIDs: targetedSyncResourceIDs, syncResourceTypeIDs: syncResourceTypeIDs, + parallelSync: parallelSync, } } diff --git a/pkg/tasks/c1api/manager.go b/pkg/tasks/c1api/manager.go index 97a036a13..0ba8d02d0 100644 --- a/pkg/tasks/c1api/manager.go +++ b/pkg/tasks/c1api/manager.go @@ -54,6 +54,7 @@ type c1ApiTaskManager struct { externalResourceEntitlementIdFilter string targetedSyncResourceIDs []string syncResourceTypeIDs []string + parallelSync bool } // getHeartbeatInterval returns an appropriate heartbeat interval. If the interval is 0, it will return the default heartbeat interval. @@ -250,6 +251,7 @@ func (c *c1ApiTaskManager) Process(ctx context.Context, task *v1.Task, cc types. c.externalResourceEntitlementIdFilter, c.targetedSyncResourceIDs, c.syncResourceTypeIDs, + c.parallelSync, ) case taskTypes.HelloType: handler = newHelloTaskHandler(task, tHelpers) @@ -299,9 +301,16 @@ func (c *c1ApiTaskManager) Process(ctx context.Context, task *v1.Task, cc types. 
} func NewC1TaskManager( - ctx context.Context, clientID string, clientSecret string, tempDir string, skipFullSync bool, - externalC1Z string, externalResourceEntitlementIdFilter string, targetedSyncResourceIDs []string, + ctx context.Context, + clientID string, + clientSecret string, + tempDir string, + skipFullSync bool, + externalC1Z string, + externalResourceEntitlementIdFilter string, + targetedSyncResourceIDs []string, syncResourceTypeIDs []string, + parallelSync bool, ) (tasks.Manager, error) { serviceClient, err := newServiceClient(ctx, clientID, clientSecret) if err != nil { @@ -316,5 +325,6 @@ func NewC1TaskManager( externalResourceEntitlementIdFilter: externalResourceEntitlementIdFilter, targetedSyncResourceIDs: targetedSyncResourceIDs, syncResourceTypeIDs: syncResourceTypeIDs, + parallelSync: parallelSync, }, nil } diff --git a/pkg/tasks/local/syncer.go b/pkg/tasks/local/syncer.go index 15148c48c..5a1c5cbcf 100644 --- a/pkg/tasks/local/syncer.go +++ b/pkg/tasks/local/syncer.go @@ -25,6 +25,7 @@ type localSyncer struct { skipEntitlementsAndGrants bool skipGrants bool syncResourceTypeIDs []string + parallelSync bool } type Option func(*localSyncer) @@ -71,6 +72,12 @@ func WithSkipGrants(skip bool) Option { } } +func WithParallelSyncEnabled(parallel bool) Option { + return func(m *localSyncer) { + m.parallelSync = parallel + } +} + func (m *localSyncer) GetTempDir() string { return "" } @@ -97,7 +104,9 @@ func (m *localSyncer) Process(ctx context.Context, task *v1.Task, cc types.Conne if ssetSessionStore, ok := cc.(session.SetSessionStore); ok { setSessionStore = ssetSessionStore } - syncer, err := sdkSync.NewSyncer(ctx, cc, + + var syncer sdkSync.Syncer + baseSyncer, err := sdkSync.NewSyncer(ctx, cc, sdkSync.WithC1ZPath(m.dbPath), sdkSync.WithTmpDir(m.tmpDir), sdkSync.WithExternalResourceC1ZPath(m.externalResourceC1Z), @@ -112,6 +121,13 @@ func (m *localSyncer) Process(ctx context.Context, task *v1.Task, cc types.Conne return err } + if m.parallelSync { + config := sdkSync.DefaultParallelSyncConfig().WithWorkerCount(10) + syncer = sdkSync.NewParallelSyncer(baseSyncer, config) + } else { + syncer = baseSyncer + } + err = syncer.Sync(ctx) if err != nil { if closeErr := syncer.Close(ctx); closeErr != nil { diff --git a/proto/c1/connector/v2/resource.proto b/proto/c1/connector/v2/resource.proto index 2528030a4..754c8b96a 100644 --- a/proto/c1/connector/v2/resource.proto +++ b/proto/c1/connector/v2/resource.proto @@ -44,6 +44,12 @@ message ResourceType { ignore_empty: true }]; bool sourced_externally = 6; + + // Sync bucketing configuration for parallel processing + // Resource types with the same bucket name will be processed sequentially within that bucket + // Resource types with different bucket names can be processed in parallel + // If not specified, the default bucket from ParallelSyncConfig will be used + string sync_bucket = 7; } message ResourceTypesServiceListResourceTypesRequest {
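For connector authors, the new sync_bucket field is the hook into this bucketing: resource types that share a bucket are synced sequentially within it, distinct buckets can be worked in parallel, and an unset value falls back to the default bucket from ParallelSyncConfig. A rough illustration under the open-struct Go bindings follows; the import path and the example resource type IDs are assumptions, not part of this change.

package example

import (
	v2 "github.com/conductorone/baton-sdk/pb/c1/connector/v2"
)

// resourceTypes shows hypothetical bucket assignments for a connector.
var resourceTypes = []*v2.ResourceType{
	// Same bucket: users and groups are synced sequentially relative to each other.
	{Id: "user", DisplayName: "User", SyncBucket: "identity"},
	{Id: "group", DisplayName: "Group", SyncBucket: "identity"},

	// A separate bucket lets repositories be processed in parallel with the identity bucket.
	{Id: "repository", DisplayName: "Repository", SyncBucket: "repos"},

	// No bucket set: this type uses the default bucket from ParallelSyncConfig.
	{Id: "license", DisplayName: "License"},
}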