From 1be72ec5a663ac6b3f6c01bc35f807b5c97a103c Mon Sep 17 00:00:00 2001 From: Lionello Lunesu Date: Thu, 19 Feb 2026 09:17:32 -0800 Subject: [PATCH 1/2] Reapply "Logs refactoring using Go iterators (#1845)" This reverts commit 0ebd3d7e9a1d96a6b537a132a6e88c5d8ff04a18. --- src/cmd/cli/command/commands.go | 8 +- src/pkg/cli/client/byoc/aws/alb_logs.go | 236 +++++++++++ src/pkg/cli/client/byoc/aws/alb_logs_test.go | 79 ++++ src/pkg/cli/client/byoc/aws/byoc.go | 165 ++++---- .../client/byoc/aws/byoc_integration_test.go | 23 +- src/pkg/cli/client/byoc/aws/byoc_test.go | 395 +++++++++++++++--- src/pkg/cli/client/byoc/aws/stream.go | 78 +--- src/pkg/cli/client/byoc/aws/stream_test.go | 4 +- src/pkg/cli/client/byoc/aws/subscribe.go | 139 +++--- src/pkg/cli/client/byoc/aws/subscribe_test.go | 245 +++++++---- ...260206T0000Z_44.233.47.227_7tj887d8.log.gz | Bin 0 -> 1485 bytes ...60206T0010Z_34.217.170.253_o3mv43zx.log.gz | Bin 0 -> 801 bytes ...260206T0010Z_44.233.47.227_2eihuvci.log.gz | Bin 0 -> 2745 bytes src/pkg/cli/client/byoc/do/byoc.go | 108 ++--- src/pkg/cli/client/byoc/do/stream.go | 128 +++--- src/pkg/cli/client/byoc/gcp/byoc.go | 40 +- src/pkg/cli/client/byoc/gcp/byoc_test.go | 13 +- src/pkg/cli/client/byoc/gcp/stream.go | 232 +++++----- src/pkg/cli/client/byoc/gcp/stream_test.go | 18 +- src/pkg/cli/client/mock.go | 73 +++- src/pkg/cli/client/playground.go | 36 +- src/pkg/cli/client/provider.go | 7 +- src/pkg/cli/composeUp_test.go | 15 +- src/pkg/cli/safe_closer.go | 33 -- src/pkg/cli/subscribe.go | 31 +- src/pkg/cli/subscribe_test.go | 288 ++++++------- src/pkg/cli/tail.go | 55 ++- src/pkg/cli/tailAndMonitor_test.go | 147 ++++--- src/pkg/cli/tail_test.go | 75 ++-- src/pkg/cli/waitForCdTaskExit.go | 6 +- src/pkg/cli/waitForCdTaskExit_test.go | 6 +- src/pkg/clouds/aws/common.go | 9 +- src/pkg/clouds/aws/cw/logs.go | 336 ++++++--------- src/pkg/clouds/aws/cw/logs_test.go | 50 +-- src/pkg/clouds/aws/cw/merge.go | 131 ++++-- src/pkg/clouds/aws/cw/merge_test.go | 199 
+++++++++ src/pkg/clouds/aws/cw/stream.go | 177 ++++---- src/pkg/clouds/aws/cw/stream_test.go | 31 +- src/pkg/clouds/aws/ecs/status.go | 27 +- src/pkg/clouds/aws/ecs/tail.go | 56 +-- src/pkg/clouds/gcp/cloudbuild.go | 12 +- src/pkg/track/track.go | 1 + 42 files changed, 2218 insertions(+), 1494 deletions(-) create mode 100644 src/pkg/cli/client/byoc/aws/alb_logs.go create mode 100644 src/pkg/cli/client/byoc/aws/alb_logs_test.go create mode 100644 src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0000Z_44.233.47.227_7tj887d8.log.gz create mode 100644 src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0010Z_34.217.170.253_o3mv43zx.log.gz create mode 100644 src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0010Z_44.233.47.227_2eihuvci.log.gz delete mode 100644 src/pkg/cli/safe_closer.go create mode 100644 src/pkg/clouds/aws/cw/merge_test.go diff --git a/src/cmd/cli/command/commands.go b/src/cmd/cli/command/commands.go index 0f6ddbefe..4891799eb 100644 --- a/src/cmd/cli/command/commands.go +++ b/src/cmd/cli/command/commands.go @@ -49,6 +49,7 @@ func Execute(ctx context.Context) error { if err := RootCmd.ExecuteContext(ctx); err != nil { if !errors.Is(err, context.Canceled) && !errors.Is(err, context.DeadlineExceeded) { term.Error("Error:", client.PrettyError(err)) + track.Evt("CLI Error", P("err", err)) } if err == dryrun.ErrDryRun { @@ -378,12 +379,7 @@ var RootCmd = &cobra.Command{ // Use "defer" to track any errors that occur during the command defer func() { - var errString = "" - if err != nil { - errString = err.Error() - } - - track.Cmd(cmd, "Invoked", P("args", args), P("err", errString), P("non-interactive", global.NonInteractive), P("provider", global.Stack.Provider)) + 
track.Cmd(cmd, "Invoked", P("args", args), P("err", err), P("non-interactive", global.NonInteractive), P("provider", global.Stack.Provider)) }() // Do this first, since any errors will be printed to the console diff --git a/src/pkg/cli/client/byoc/aws/alb_logs.go b/src/pkg/cli/client/byoc/aws/alb_logs.go new file mode 100644 index 000000000..3228cbdcf --- /dev/null +++ b/src/pkg/cli/client/byoc/aws/alb_logs.go @@ -0,0 +1,236 @@ +package aws + +import ( + "bufio" + "compress/gzip" + "context" + "errors" + "fmt" + "io" + "iter" + "slices" + "strings" + "time" + + "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" + "github.com/DefangLabs/defang/src/pkg/term" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" +) + +func (b *ByocAws) fetchAndStreamAlbLogs(ctx context.Context, projectName string, since, end time.Time, pattern string) (iter.Seq2[cw.LogEvent, error], error) { + cfg, err := b.driver.LoadConfig(ctx) + if err != nil { + return nil, err + } + + s3Client := s3.NewFromConfig(cfg) + bucketsOutput, err := s3Client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + return nil, err + } + + bucketPrefix := fmt.Sprintf("%s-%s-alb-logs", projectName, b.PulumiStack) + if b.Prefix != "" { + bucketPrefix = b.Prefix + "-" + bucketPrefix + } + term.Debug("Query ALB logs", bucketPrefix) + if len(bucketPrefix) > 31 { + // HACK: AWS CD truncates the ALB name to 31 characters (because of the long Terraform suffix) + bucketPrefix = bucketPrefix[:31] + } + bucketPrefix = strings.ToLower(bucketPrefix) + + // First, find bucket with the given prefix for the project/stack + var bucketName string + for _, bucket := range bucketsOutput.Buckets { + if strings.HasPrefix(*bucket.Name, bucketPrefix) { + // TODO: inspect the bucket tags to ensure it belongs to the right org/project/stack + bucketName = *bucket.Name + break + } + } + + if bucketName == "" { + return nil, fmt.Errorf("no bucket found with prefix %q", 
bucketPrefix) + } + + return func(yield func(cw.LogEvent, error) bool) { + for logs, err := range b.fetchAndStreamAlbLogsFromBucket(ctx, bucketName, since, end, s3Client, pattern) { + if err != nil { + yield(cw.LogEvent{}, err) + return + } + for _, log := range logs { + timestamp := log.Timestamp.UnixMilli() // FIXME: this destroys the original timestamp precision + if !yield(cw.LogEvent{ + Message: &log.Message, + Timestamp: ×tamp, + }, nil) { + return + } + } + } + }, nil +} + +type s3Lister interface { + s3.ListObjectsV2APIClient + GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) +} + +func getAlbLogObjectGroupKey(objName string) string { + // 123456789012_elasticloadbalancing_us-test-2_app.defang-project-stack-alb.d850f5ca299e222a_20260207T0120Z_11.22.33.44_2khrazuh.log.gz + key, _, _ := strings.Cut(objName, "Z_") + return key +} + +func (b *ByocAws) fetchAndStreamAlbLogsFromBucket(ctx context.Context, bucketName string, since, end time.Time, s3Client s3Lister, pattern string) iter.Seq2[[]ALBLogEntry, error] { + return func(yield func([]ALBLogEntry, error) bool) { + if end.IsZero() { + end = time.Now() + } + if since.IsZero() { + since = end.Add(-60 * time.Minute) + } + // If the end time is 00:00:01Z, we should still consider log files modified at 00:05:03Z + // because each file has ~5 minutes of logs and writing the file will have take a few seconds. + lastModifiedEnd := end.Add(5*time.Minute + 5*time.Second) + + // Use a single listing with the region-level prefix instead of iterating day-by-day. + // StartAfter skips to the start date, so empty buckets complete in a single API call. 
+ objectPrefix := fmt.Sprintf("AWSLogs/%s/elasticloadbalancing/%s/", b.driver.AccountID, b.driver.Region) + year, month, day := since.UTC().Date() + startAfter := fmt.Sprintf("AWSLogs/%s/elasticloadbalancing/%s/%04d/%02d/%02d/", b.driver.AccountID, b.driver.Region, year, month, day) + + listInput := s3.ListObjectsV2Input{ + Bucket: &bucketName, + Prefix: &objectPrefix, + StartAfter: &startAfter, + } + var groupKey string + var group []s3types.Object + done: + for { + list, err := s3Client.ListObjectsV2(ctx, &listInput) + if err != nil { + yield(nil, err) + return + } + for _, obj := range list.Contents { + // LastModified is time of latest record. Skip objects with events older than the since-time + if obj.LastModified.Before(since) { + continue + } + // Check end-time, but consider that each object has ~5 minutes of logs + if obj.LastModified.After(lastModifiedEnd) { + break done + } + if key := getAlbLogObjectGroupKey(*obj.Key); key == groupKey { + // Same timespan as the previous object, so add to group for merging. + group = append(group, obj) + } else { + // New timespan, so stream logs from the previous group(s) before starting a new group. 
+ logs, err := readAlbLogsGroup(ctx, bucketName, group, since, end, s3Client, pattern) + if len(logs) > 0 || err != nil { + if !yield(logs, err) { + return + } + } + group = []s3types.Object{obj} + groupKey = key + } + } + if list.NextContinuationToken == nil { + break + } + listInput.ContinuationToken = list.NextContinuationToken + } + // Flush remaining group + logs, err := readAlbLogsGroup(ctx, bucketName, group, since, end, s3Client, pattern) + if len(logs) > 0 || err != nil { + yield(logs, err) + } + } +} + +type ALBLogEntry struct { + Message string + Timestamp time.Time +} + +func readAlbLogsGroup(ctx context.Context, bucketName string, group []s3types.Object, since, end time.Time, s3Client s3Lister, pattern string) ([]ALBLogEntry, error) { + var allEntries []ALBLogEntry + for _, obj := range group { + content, err := s3Client.GetObject(ctx, &s3.GetObjectInput{ + Bucket: &bucketName, + Key: obj.Key, + }) + if err != nil { + return nil, err // or continue with other objects? + } + entries, err := readAlbLogs(content.Body, since, end, pattern) + if err != nil { + return nil, err // or continue with other objects? + } + if allEntries == nil { + allEntries = entries + } else { + allEntries = append(allEntries, entries...) + } + } + // Always need to sort, because log entries within each object are not in order. 
+ slices.SortFunc(allEntries, func(a, b ALBLogEntry) int { + return a.Timestamp.Compare(b.Timestamp) + }) + return allEntries, nil +} + +var errMalformedALBLogLine = errors.New("malformed ALB log line") + +func parseAlbLogTime(logLine string) (time.Time, error) { + // https 2026-02-05T23:58:32.578204Z app/defang-project-stack7d0286/c9b3756e8ef89456 11.22.33.44:34025 - -1 -1 -1 404 - 842 1023 "POST https://11.22.33.44:443/ HTTP/1.1" "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/115.0.0.0 Safari/537.36 Edg/115.0.1901.203" ECDHE-RSA-AES128-GCM-SHA256 TLSv1.2 - "Root=1-69852ea8-7429674e211c223e3c211c6d" "-" "arn:aws:acm:us-test-2:123456789012:certificate/be524858-3414-4e98-be52-240358d85b1c" 0 2026-02-05T23:58:32.493000Z "fixed-response" "-" "-" "-" "-" "-" "-" TID_ba88f3bfb4f5c249b7d9f74348a70697 "-" "-" "-" + timestampStart := strings.IndexByte(logLine, ' ') + 1 // will be 0 if not found + timestampEnd := strings.IndexByte(logLine[timestampStart:], ' ') + timestampStart + if timestampEnd <= timestampStart { + return time.Time{}, errMalformedALBLogLine + } + return time.Parse(time.RFC3339Nano, logLine[timestampStart:timestampEnd]) +} + +func readAlbLogs(body io.ReadCloser, since, end time.Time, pattern string) ([]ALBLogEntry, error) { + defer body.Close() + gzipReader, err := gzip.NewReader(body) + if err != nil { + return nil, err + } + var entries []ALBLogEntry + lineScanner := bufio.NewScanner(gzipReader) + for lineScanner.Scan() { + logLine := lineScanner.Text() + if !strings.Contains(logLine, pattern) { + continue + } + timestamp, err := parseAlbLogTime(logLine) + if err != nil { + continue // malformed timestamp: ignore + } + if timestamp.Before(since) { + continue + } + if timestamp.After(end) { + continue // can't break, because there can be out-of-order timestamps + } + entries = append(entries, ALBLogEntry{ + Message: logLine, + Timestamp: timestamp, + }) + } + if err := lineScanner.Err(); err != nil { + return nil, 
err + } + if err := gzipReader.Close(); err != nil { + return nil, err // only returns err on failed checksum after io.EOF + } + return entries, nil +} diff --git a/src/pkg/cli/client/byoc/aws/alb_logs_test.go b/src/pkg/cli/client/byoc/aws/alb_logs_test.go new file mode 100644 index 000000000..c634d17b5 --- /dev/null +++ b/src/pkg/cli/client/byoc/aws/alb_logs_test.go @@ -0,0 +1,79 @@ +package aws + +import ( + "context" + "os" + "path/filepath" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/smithy-go/ptr" + "github.com/stretchr/testify/require" +) + +func Test_readAlbLogs(t *testing.T) { + gz, err := os.Open("testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0000Z_44.233.47.227_7tj887d8.log.gz") + require.NoError(t, err) + entries, err := readAlbLogs(gz, time.Time{}, time.Now(), "") + require.NoError(t, err) + for _, entry := range entries { + t.Logf("%s: %s", entry.Timestamp, entry.Message) + } +} + +type mockS3Lister struct{} + +func (m mockS3Lister) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + entries, err := os.ReadDir(filepath.Join(".", *params.Bucket)) + contents := make([]s3types.Object, len(entries)) + for i, entry := range entries { + contents[i].Key = ptr.String(entry.Name()) + contents[i].LastModified = ptr.Time(time.Now()) + } + return &s3.ListObjectsV2Output{ + Contents: contents, + }, err +} + +func (m mockS3Lister) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + body, err := os.Open(*params.Key) + return &s3.GetObjectOutput{ + Body: body, + }, err +} + +func Test_streamAlbLogGroup(t *testing.T) { + s3Client := mockS3Lister{} + + t.Run("empty group", func(t *testing.T) { + entries, err := readAlbLogsGroup(t.Context(), "testdata", nil, 
time.Time{}, time.Now(), s3Client, "") + require.NoError(t, err) + require.Empty(t, entries) + }) + + t.Run("with test files", func(t *testing.T) { + files, err := os.ReadDir("testdata") + require.NoError(t, err) + var objects []s3types.Object + for _, f := range files { + if filepath.Ext(f.Name()) == ".gz" { + objects = append(objects, s3types.Object{ + Key: ptr.String(filepath.Join("testdata", f.Name())), + LastModified: ptr.Time(time.Now()), + }) + } + } + entries, err := readAlbLogsGroup(t.Context(), "testdata", objects, time.Time{}, time.Now(), s3Client, "") + require.NoError(t, err) + for _, entry := range entries { + t.Logf("%s: %s", entry.Timestamp, entry.Message) + } + require.NotEmpty(t, entries) + // Verify entries are sorted by timestamp + for i := 1; i < len(entries); i++ { + require.False(t, entries[i].Timestamp.Before(entries[i-1].Timestamp), "entries not sorted at index %d", i) + } + }) +} diff --git a/src/pkg/cli/client/byoc/aws/byoc.go b/src/pkg/cli/client/byoc/aws/byoc.go index 39807bba8..e97d541e3 100644 --- a/src/pkg/cli/client/byoc/aws/byoc.go +++ b/src/pkg/cli/client/byoc/aws/byoc.go @@ -12,7 +12,6 @@ import ( "os" "strconv" "strings" - "sync" "time" "github.com/DefangLabs/defang/src/pkg" @@ -22,7 +21,6 @@ import ( "github.com/DefangLabs/defang/src/pkg/cli/compose" "github.com/DefangLabs/defang/src/pkg/clouds" "github.com/DefangLabs/defang/src/pkg/clouds/aws" - "github.com/DefangLabs/defang/src/pkg/clouds/aws/codebuild" "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs/cfn" @@ -37,7 +35,6 @@ import ( defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" awssdk "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" - "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" cwTypes "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" "github.com/aws/aws-sdk-go-v2/service/route53" 
"github.com/aws/aws-sdk-go-v2/service/s3" @@ -56,12 +53,9 @@ type ByocAws struct { driver *cfn.AwsEcsCfn // TODO: ecs is stateful, contains the output of the cd cfn stack after SetUpCD - ecsEventHandlers []ECSEventHandler - codebuildEventHandlers []CodebuildEventHandler - handlersLock sync.RWMutex - cdEtag types.ETag - cdStart time.Time - cdTaskArn ecs.TaskArn + cdEtag types.ETag + cdStart time.Time + cdTaskArn ecs.TaskArn needDockerHubCreds bool } @@ -176,15 +170,16 @@ func (b *ByocAws) SetUpCD(ctx context.Context) error { return nil } -func (b *ByocAws) GetDeploymentStatus(ctx context.Context) error { - if err := ecs.GetTaskStatus(ctx, b.cdTaskArn); err != nil { +func (b *ByocAws) GetDeploymentStatus(ctx context.Context) (bool, error) { + done, err := ecs.GetTaskStatus(ctx, b.cdTaskArn) + if err != nil { // check if the task failed; if so, return the a ErrDeploymentFailed error if taskErr := new(ecs.TaskFailure); errors.As(err, taskErr) { - return client.ErrDeploymentFailed{Message: taskErr.Error()} + return done, client.ErrDeploymentFailed{Message: taskErr.Error()} } - return err + return done, err } - return nil + return done, nil } func (b *ByocAws) Deploy(ctx context.Context, req *client.DeployRequest) (*defangv1.DeployResponse, error) { @@ -551,7 +546,7 @@ func (b *ByocAws) runCdCommand(ctx context.Context, cmd cdCommand) (ecs.TaskArn, if os.Getenv("DEFANG_PULUMI_DIR") != "" { // Convert the environment to a human-readable array of KEY=VALUE strings for debugging - debugEnv := []string{"AWS_REGION=" + b.driver.Region.String()} + debugEnv := []string{"AWS_REGION=" + string(b.driver.Region)} if awsProfile := os.Getenv("AWS_PROFILE"); awsProfile != "" { debugEnv = append(debugEnv, "AWS_PROFILE="+awsProfile) } @@ -679,7 +674,7 @@ func (b *ByocAws) CreateUploadURL(ctx context.Context, req *defangv1.UploadURLRe }, nil } -func (b *ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (client.ServerStream[defangv1.TailResponse], error) { +func (b 
*ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { // FillOutputs is needed to get the CD task ARN or the LogGroup ARNs // if the cloud formation stack has been destroyed, we can still query // logs for builds and services @@ -699,23 +694,52 @@ func (b *ByocAws) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (cli // * No Etag, service: tail all tasks/services with that service name // * Etag, no services: tail all tasks/services with that Etag // * Etag, service: tail that task/service - var tailStream cw.LiveTailStream + var logSeq iter.Seq2[cw.LogEvent, error] etag, err := types.ParseEtag(req.Etag) if err != nil && req.Etag != "" { // Assume invalid "etag" is the task ID of the CD task - tailStream, err = b.queryCdLogs(ctx, cwClient, req) - // no need to filter events by etag because we only show logs from the specified task ID + cdSeq, err := b.queryOrTailCdLogs(ctx, cwClient, req) + if err != nil { + return nil, AnnotateAwsError(err) + } + logSeq = cw.Flatten(cdSeq) + // No need to filter events by etag because we only show logs from the specified task ID } else { - tailStream, err = b.queryLogs(ctx, cwClient, req) + logSeq, err = b.queryOrTailLogs(ctx, cwClient, req) + if err != nil { + return nil, AnnotateAwsError(err) + } } - if err != nil { - return nil, AnnotateAwsError(err) + parser := &logEventParser{ + etag: etag, + services: req.Services, } - return newByocServerStream(tailStream, etag, req.Services, b, b), nil + return func(yield func(*defangv1.TailResponse, error) bool) { + for event, err := range logSeq { + if err != nil { + // Ignore ResourceNotFoundException errors which can only happen if a log stream is missing during Query + var resourceNotFound *cwTypes.ResourceNotFoundException + if errors.As(err, &resourceNotFound) { + term.Debugf("Log stream not found while tailing, skipping: %v", err) + continue + } + if !yield(nil, AnnotateAwsError(err)) { + return + } + 
continue + } + resp := parser.parseEvent(event) + if resp != nil { + if !yield(resp, nil) { + return + } + } + } + }, nil } -func (b *ByocAws) queryCdLogs(ctx context.Context, cwClient *cloudwatchlogs.Client, req *defangv1.TailRequest) (cw.LiveTailStream, error) { +func (b *ByocAws) queryOrTailCdLogs(ctx context.Context, cwClient cw.LogsClient, req *defangv1.TailRequest) (iter.Seq2[[]cw.LogEvent, error], error) { var err error b.cdTaskArn, err = b.driver.GetTaskArn(req.Etag) // only fails on missing task ID if err != nil { @@ -730,7 +754,7 @@ func (b *ByocAws) queryCdLogs(ctx context.Context, cwClient *cloudwatchlogs.Clie } } -func (b *ByocAws) queryLogs(ctx context.Context, cwClient *cloudwatchlogs.Client, req *defangv1.TailRequest) (cw.LiveTailStream, error) { +func (b *ByocAws) queryOrTailLogs(ctx context.Context, cwClient cw.LogsClient, req *defangv1.TailRequest) (iter.Seq2[cw.LogEvent, error], error) { start := timeutils.AsTime(req.Since, time.Time{}) end := timeutils.AsTime(req.Until, time.Time{}) @@ -740,15 +764,19 @@ func (b *ByocAws) queryLogs(ctx context.Context, cwClient *cloudwatchlogs.Client } lgis := b.getLogGroupInputs(req.Etag, req.Project, service, req.Pattern, logs.LogType(req.LogType)) if req.Follow { - return cw.QueryAndTailLogGroups( + logSeq, err := cw.QueryAndTailLogGroups( ctx, cwClient, start, end, lgis..., ) + if err != nil { + return nil, err + } + return cw.Flatten(logSeq), nil } else { - evtsChan, errsChan := cw.QueryLogGroups( + logSeq, err := cw.QueryLogGroups( ctx, cwClient, start, @@ -756,15 +784,26 @@ func (b *ByocAws) queryLogs(ctx context.Context, cwClient *cloudwatchlogs.Client req.Limit, lgis..., ) - if evtsChan == nil { - var errs []error - for err := range errsChan { - errs = append(errs, err) + if err != nil { + return nil, err + } + if len(req.Services) == 0 { + albIter, err := b.fetchAndStreamAlbLogs(ctx, req.Project, start, end, req.Pattern) + if err != nil { + term.Debugf("Failed to fetch ALB logs: %v", err) + } else { + 
logSeq = cw.MergeLogEvents(logSeq, albIter) + if req.Limit > 0 { + // take the first/last n events only from the merged stream + if start.IsZero() { + logSeq = cw.TakeLastN(logSeq, int(req.Limit)) + } else { + logSeq = cw.TakeFirstN(logSeq, int(req.Limit)) + } + } } - return nil, errors.Join(errs...) } - // TODO: any errors from errsChan should be reported but get dropped - return cw.NewStaticLogStream(evtsChan, func() {}), nil + return logSeq, nil } } @@ -890,53 +929,33 @@ func (b *ByocAws) CdList(ctx context.Context, allRegions bool) (iter.Seq[state.I } } -type ECSEventHandler interface { - HandleECSEvent(evt ecs.Event) -} - -type CodebuildEventHandler interface { - HandleCodebuildEvent(evt codebuild.Event) -} - -func (b *ByocAws) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (client.ServerStream[defangv1.SubscribeResponse], error) { - s := &byocSubscribeServerStream{ - services: req.Services, - etag: req.Etag, - - ch: make(chan *defangv1.SubscribeResponse), - done: make(chan struct{}), +func (b *ByocAws) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { + if err := b.driver.FillOutputs(ctx); err != nil { + term.Warnf("Unable to get log group ARNs: %v", err) } - b.AddEcsEventHandler(s) - b.AddCodebuildEventHandler(s) - return s, nil -} -func (b *ByocAws) HandleECSEvent(evt ecs.Event) { - b.handlersLock.RLock() - defer b.handlersLock.RUnlock() - for _, handler := range b.ecsEventHandlers { - handler.HandleECSEvent(evt) + cwClient, err := cw.NewCloudWatchLogsClient(ctx, b.driver.Region) + if err != nil { + return nil, AnnotateAwsError(err) } -} -func (b *ByocAws) HandleCodebuildEvent(evt codebuild.Event) { - b.handlersLock.RLock() - defer b.handlersLock.RUnlock() - for _, handler := range b.codebuildEventHandlers { - handler.HandleCodebuildEvent(evt) + lgis := b.getSubscribeLogGroupInputs(req.Project) + logSeq, err := cw.QueryAndTailLogGroups(ctx, cwClient, b.cdStart, 
time.Time{}, lgis...) + if err != nil { + return nil, AnnotateAwsError(err) } -} -func (b *ByocAws) AddEcsEventHandler(handler ECSEventHandler) { - b.handlersLock.Lock() - defer b.handlersLock.Unlock() - b.ecsEventHandlers = append(b.ecsEventHandlers, handler) + etag, _ := types.ParseEtag(req.Etag) + return parseSubscribeEvents(logSeq, etag, req.Services), nil } -func (b *ByocAws) AddCodebuildEventHandler(handler CodebuildEventHandler) { - b.handlersLock.Lock() - defer b.handlersLock.Unlock() - b.codebuildEventHandlers = append(b.codebuildEventHandlers, handler) +func (b *ByocAws) getSubscribeLogGroupInputs(projectName string) []cw.LogGroupInput { + var groups []cw.LogGroupInput + buildsARN := b.makeLogGroupARN(b.StackDir(projectName, "builds")) + groups = append(groups, cw.LogGroupInput{LogGroupARN: buildsARN}) + ecsARN := b.makeLogGroupARN(b.StackDir(projectName, "ecs")) + groups = append(groups, cw.LogGroupInput{LogGroupARN: ecsARN}) + return groups } func (b *ByocAws) GetPrivateDomain(projectName string) string { diff --git a/src/pkg/cli/client/byoc/aws/byoc_integration_test.go b/src/pkg/cli/client/byoc/aws/byoc_integration_test.go index b8933798b..2d12d6c14 100644 --- a/src/pkg/cli/client/byoc/aws/byoc_integration_test.go +++ b/src/pkg/cli/client/byoc/aws/byoc_integration_test.go @@ -34,7 +34,7 @@ func TestDeploy(t *testing.T) { func TestTail(t *testing.T) { b := NewByocProvider(ctx, "TestTail", "") - ss, err := b.QueryLogs(context.Background(), &defangv1.TailRequest{Project: "byoc_integration_test"}) + logs, err := b.QueryLogs(context.Background(), &defangv1.TailRequest{Project: "byoc_integration_test"}) if err != nil { // the only acceptable error is "unauthorized" if connect.CodeOf(err) == connect.CodeUnauthenticated { @@ -42,18 +42,17 @@ func TestTail(t *testing.T) { } t.Fatalf("unexpected error: %v", err) } - defer ss.Close() - // First we expect "true" (the "start" event) - if ss.Receive() != true { - t.Error("expected Receive() to return true") - } - if 
len(ss.Msg().Entries) != 0 { - t.Error("expected empty entries") - } - err = ss.Err() - if err != nil { - t.Error(err) + for msg, err := range logs { + if err != nil { + t.Errorf("unexpected error: %v", err) + break + } + // First message should have empty entries (the "start" event) + if len(msg.Entries) != 0 { + t.Error("expected empty entries") + } + break // only check the first message } } diff --git a/src/pkg/cli/client/byoc/aws/byoc_test.go b/src/pkg/cli/client/byoc/aws/byoc_test.go index 87f2d79a1..be50d6618 100644 --- a/src/pkg/cli/client/byoc/aws/byoc_test.go +++ b/src/pkg/cli/client/byoc/aws/byoc_test.go @@ -6,24 +6,30 @@ import ( "context" "embed" "encoding/json" + "fmt" "io" "os" "path/filepath" "strings" - "sync" "testing" + "time" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" "github.com/DefangLabs/defang/src/pkg/clouds/aws" - "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs" + "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs/cfn" "github.com/DefangLabs/defang/src/pkg/dns" + "github.com/DefangLabs/defang/src/pkg/logs" "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/types" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" + cwTypes "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" composeTypes "github.com/compose-spec/compose-go/v2/types" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" ) func TestDomainMultipleProjectSupport(t *testing.T) { @@ -105,12 +111,15 @@ var testDir embed.FS var expectedDir embed.FS func TestSubscribe(t *testing.T) { - t.Skip("Pending test") + t.Skip("Pending test") // TODO: requires CW mock or real AWS credentials tests, err := testDir.ReadDir("testdata") if err != nil { t.Fatalf("failed to load ecs events test files: 
%v", err) } for _, tt := range tests { + if !strings.HasSuffix(tt.Name(), ".json") { + continue + } t.Run(tt.Name(), func(t *testing.T) { start := strings.LastIndex(tt.Name(), "-") end := strings.LastIndex(tt.Name(), ".") @@ -120,60 +129,53 @@ func TestSubscribe(t *testing.T) { name := tt.Name()[:start] etag := tt.Name()[start+1 : end] - byoc := &ByocAws{} - - resp, err := byoc.Subscribe(t.Context(), &defangv1.SubscribeRequest{ - Etag: etag, - Services: []string{"api", "web"}, - }) - if err != nil { - t.Fatalf("Subscribe() failed: %v", err) - } - - var wg sync.WaitGroup - wg.Add(1) - go func() { - defer wg.Done() - - filename := filepath.Join("testdata", name+".events") - ef, _ := expectedDir.ReadFile(filename) - dec := json.NewDecoder(bytes.NewReader(ef)) - - for { - if !resp.Receive() { - if resp.Err() != nil { - t.Errorf("Receive() failed: %v", resp.Err()) - } - break - } - msg := resp.Msg() - var expected defangv1.SubscribeResponse - if err := dec.Decode(&expected); err == io.EOF { - t.Errorf("unexpected message: %v", msg) - } else if err != nil { - t.Errorf("error unmarshaling expected ECS event: %v", err) - } else if msg.Name != expected.Name || msg.Status != expected.Status || msg.State != expected.State { - t.Errorf("expected message-, got+\n-%v\n+%v", &expected, msg) - } - } - }() - data, err := testDir.ReadFile(filepath.Join("testdata", tt.Name())) if err != nil { t.Fatalf("failed to read test file: %v", err) } + + // Build CW log events from the ECS event JSON lines + ecsLogGroup := "arn:aws:logs:us-west-2:123:log-group:/ecs" lines := bufio.NewScanner(bytes.NewReader(data)) + var cwEvents []cw.LogEvent + var ts int64 for lines.Scan() { - ecsEvt, err := ecs.ParseECSEvent([]byte(lines.Text())) - if err != nil { - t.Fatalf("error parsing ECS event: %v", err) - } + line := lines.Text() + cwEvents = append(cwEvents, cw.LogEvent{ + LogGroupIdentifier: &ecsLogGroup, + LogStreamName: awssdk.String("some-stream"), + Message: awssdk.String(line), + Timestamp: &ts, 
+ }) + } - byoc.HandleECSEvent(ecsEvt) + // Feed through parseSubscribeEvents + evtIter := func(yield func([]cw.LogEvent, error) bool) { + for _, evt := range cwEvents { + if !yield([]cw.LogEvent{evt}, nil) { + return + } + } } - resp.Close() - wg.Wait() + filename := filepath.Join("testdata", name+".events") + ef, _ := expectedDir.ReadFile(filename) + dec := json.NewDecoder(bytes.NewReader(ef)) + + for msg, err := range parseSubscribeEvents(evtIter, etag, []string{"api", "web"}) { + if err != nil { + t.Errorf("unexpected error: %v", err) + break + } + var expected defangv1.SubscribeResponse + if err := dec.Decode(&expected); err == io.EOF { + t.Errorf("unexpected message: %v", msg) + } else if err != nil { + t.Errorf("error unmarshaling expected event: %v", err) + } else if msg.Name != expected.Name || msg.Status != expected.Status || msg.State != expected.State { + t.Errorf("expected message-, got+\n-%v\n+%v", &expected, msg) + } + } }) } } @@ -407,3 +409,300 @@ aws_secret_access_key = wJalrXUtnFEMI/KDEFANG/bPxRfiCYEXAMPLEKEY }) } } + +// mockCWClient implements cw.LogsClient for testing queryLogs and queryCdLogs. +type mockCWClient struct { + events []cwTypes.FilteredLogEvent +} + +func (m *mockCWClient) FilterLogEvents(ctx context.Context, input *cloudwatchlogs.FilterLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.FilterLogEventsOutput, error) { + events := m.events + if input.Limit != nil && int(*input.Limit) < len(events) { + events = events[:*input.Limit] + } + return &cloudwatchlogs.FilterLogEventsOutput{ + Events: events, + }, nil +} + +func (m *mockCWClient) StartLiveTail(ctx context.Context, input *cloudwatchlogs.StartLiveTailInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.StartLiveTailOutput, error) { + return nil, &cwTypes.ResourceNotFoundException{ + Message: awssdk.String("mock: log group does not exist"), + } +} + +// makeMockEvents creates n FilteredLogEvents with sequential timestamps and messages. 
+// The log stream name follows the awslogs format: "/_/" +func makeMockEvents(n int, service, etag string) []cwTypes.FilteredLogEvent { + events := make([]cwTypes.FilteredLogEvent, n) + for i := range events { + ts := int64((i + 1) * 1000) // 1000, 2000, 3000, ... + events[i] = cwTypes.FilteredLogEvent{ + Message: awssdk.String(fmt.Sprintf("log message %d", i+1)), + Timestamp: &ts, + LogStreamName: awssdk.String(fmt.Sprintf("%s/%s_%s/task%d", service, service, etag, i)), + } + } + return events +} + +func newTestByocAws() *ByocAws { + b := &ByocAws{ + driver: cfn.New(byoc.CdTaskPrefix, aws.Region("us-test-2")), + } + b.driver.AccountID = "123456789012" + b.driver.LogGroupARN = "arn:aws:logs:us-test-2:123456789012:log-group:defang-cd-LogGroup:*" + b.driver.ClusterName = "test-cluster" + b.ByocBaseClient = byoc.NewByocBaseClient("tenant1", b, "beta") + return b +} + +func collectEvents(t *testing.T, iter func(func(cw.LogEvent, error) bool)) []cw.LogEvent { + t.Helper() + var events []cw.LogEvent + for evt, err := range iter { + require.NoError(t, err) + events = append(events, evt) + } + return events +} + +func TestQueryLogs(t *testing.T) { + const etag = "hg2xsgvsldqk" + baseTime := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC) + + tests := []struct { + name string + req *defangv1.TailRequest + numEvents int // how many mock events to create + wantCount int // expected number of events returned + wantFirst string + wantLast string + }{ + { + name: "query, no limit, no times", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 5, + wantFirst: "log message 1", + wantLast: "log message 5", + }, + { + name: "query, no limit, with start", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 5, + wantFirst: "log message 1", + 
wantLast: "log message 5", + }, + { + name: "query, no limit, with start and end", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Until: timestamppb.New(baseTime.Add(time.Hour)), + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 5, + wantFirst: "log message 1", + wantLast: "log message 5", + }, + { + name: "query, limit 3, with start (TakeFirstN)", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Limit: 3, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 3, + wantFirst: "log message 1", + wantLast: "log message 3", + }, + { + name: "query, limit 3, no start (TakeLastN)", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Limit: 3, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 3, + wantFirst: "log message 3", + wantLast: "log message 5", + }, + { + name: "query, limit 3, with start and end", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Until: timestamppb.New(baseTime.Add(time.Hour)), + Limit: 3, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 5, + wantCount: 3, + wantFirst: "log message 1", + wantLast: "log message 3", + }, + { + name: "query, limit exceeds events", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Limit: 10, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 3, + wantCount: 3, + wantFirst: "log message 1", + wantLast: "log message 3", + }, + { + name: "query, zero events", + req: &defangv1.TailRequest{ + Services: []string{"app"}, + Since: timestamppb.New(baseTime), + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + }, + numEvents: 0, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b 
:= newTestByocAws() + mock := &mockCWClient{ + events: makeMockEvents(tt.numEvents, "app", etag), + } + + logSeq, err := b.queryOrTailLogs(t.Context(), mock, tt.req) + require.NoError(t, err) + + events := collectEvents(t, logSeq) + assert.Len(t, events, tt.wantCount) + + if tt.wantCount > 0 { + assert.Equal(t, tt.wantFirst, *events[0].Message) + assert.Equal(t, tt.wantLast, *events[len(events)-1].Message) + + // Verify ascending timestamp order + for i := 1; i < len(events); i++ { + assert.LessOrEqual(t, *events[i-1].Timestamp, *events[i].Timestamp, "events not in ascending order at index %d", i) + } + } + }) + } +} + +func TestQueryLogs_FollowMode(t *testing.T) { + b := newTestByocAws() + mock := &mockCWClient{ + events: makeMockEvents(3, "app", "hg2xsgvsldqk"), + } + + ctx, cancel := context.WithCancel(t.Context()) + req := &defangv1.TailRequest{ + Services: []string{"app"}, + Follow: true, + Project: "testproject", + LogType: uint32(logs.LogTypeRun), + } + + logSeq, err := b.queryOrTailLogs(ctx, mock, req) + require.NoError(t, err) + + // Cancel immediately to stop the polling/tailing + cancel() + + // Should not panic; may yield context.Canceled errors + for _, err := range logSeq { + if err != nil { + assert.ErrorIs(t, err, context.Canceled) + } + } +} + +func TestQueryCdLogs(t *testing.T) { + baseTime := time.Date(2026, 1, 15, 10, 0, 0, 0, time.UTC) + taskID := "abc123def456" + + tests := []struct { + name string + req *defangv1.TailRequest + numEvents int + wantCount int + }{ + { + name: "query mode, no limit", + req: &defangv1.TailRequest{ + Etag: taskID, + Since: timestamppb.New(baseTime), + }, + numEvents: 5, + wantCount: 5, + }, + { + name: "query mode, with limit", + req: &defangv1.TailRequest{ + Etag: taskID, + Since: timestamppb.New(baseTime), + Limit: 3, + }, + numEvents: 5, + wantCount: 3, + }, + { + name: "query mode, with start and end", + req: &defangv1.TailRequest{ + Etag: taskID, + Since: timestamppb.New(baseTime), + Until: 
timestamppb.New(baseTime.Add(time.Hour)), + }, + numEvents: 5, + wantCount: 5, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b := newTestByocAws() + mock := &mockCWClient{ + events: makeMockEvents(tt.numEvents, "crun", ""), + } + + batchSeq, err := b.queryOrTailCdLogs(t.Context(), mock, tt.req) + require.NoError(t, err) + + // Flatten and collect + logSeq := cw.Flatten(batchSeq) + events := collectEvents(t, logSeq) + assert.Len(t, events, tt.wantCount) + }) + } +} + +// TestQueryCdLogs_FollowMode is skipped because TailTaskID polls getTaskStatus +// (real AWS ECS API) when StartLiveTail returns ResourceNotFoundException. +// Testing follow mode for CD logs requires mocking the ECS DescribeTasks API. +func TestQueryCdLogs_FollowMode(t *testing.T) { + t.Skip("requires ECS API mock for getTaskStatus") +} diff --git a/src/pkg/cli/client/byoc/aws/stream.go b/src/pkg/cli/client/byoc/aws/stream.go index 69ac5cbb8..489eb8bc8 100644 --- a/src/pkg/cli/client/byoc/aws/stream.go +++ b/src/pkg/cli/client/byoc/aws/stream.go @@ -2,15 +2,12 @@ package aws import ( "encoding/json" - "io" "regexp" "slices" "strings" "time" - "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/cli/client/byoc" - "github.com/DefangLabs/defang/src/pkg/clouds/aws/codebuild" "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs" "github.com/DefangLabs/defang/src/pkg/logs" @@ -22,65 +19,18 @@ import ( var codeBuildPrefixRegex = regexp.MustCompile(`^[a-zA-Z0-9_-]+-image/`) -// byocServerStream is a wrapper around awsecs.EventStream that implements connect-like ServerStream -type byocServerStream struct { - err error +// logEventParser converts raw CW log events into TailResponse protos. 
+type logEventParser struct { etag string - response *defangv1.TailResponse services []string - stream cw.LiveTailStream - - ecsEventsHandler ECSEventHandler - codebuildEventHandler CodebuildEventHandler -} - -func newByocServerStream(stream cw.LiveTailStream, etag string, services []string, ecsEventHandler ECSEventHandler, codebuildEventHandler CodebuildEventHandler) *byocServerStream { - return &byocServerStream{ - etag: etag, - stream: stream, - services: services, - - ecsEventsHandler: ecsEventHandler, - codebuildEventHandler: codebuildEventHandler, - } } -var _ client.ServerStream[defangv1.TailResponse] = (*byocServerStream)(nil) - -func (bs *byocServerStream) Close() error { - return bs.stream.Close() -} - -func (bs *byocServerStream) Err() error { - if bs.err == io.EOF { - return nil // same as the original gRPC/connect server stream - } - return bs.err +func (p *logEventParser) parseEvent(event cw.LogEvent) *defangv1.TailResponse { + return p.parseEvents([]cw.LogEvent{event}) } -func (bs *byocServerStream) Msg() *defangv1.TailResponse { - return bs.response -} - -func (bs *byocServerStream) Receive() bool { - e := <-bs.stream.Events() - if err := bs.stream.Err(); err != nil { - bs.err = AnnotateAwsError(err) - return false - } - evts, err := cw.GetLogEvents(e) - if err != nil { - bs.err = err - return false - } - bs.response = bs.parseEvents(evts) - return true -} - -func (bs *byocServerStream) parseEvents(events []cw.LogEvent) *defangv1.TailResponse { +func (p *logEventParser) parseEvents(events []cw.LogEvent) *defangv1.TailResponse { if len(events) == 0 { - // The original gRPC/connect server stream would never send an empty response. - // We could loop around the select, but returning an empty response updates the spinner. 
return nil } @@ -91,6 +41,9 @@ func (bs *byocServerStream) parseEvents(events []cw.LogEvent) *defangv1.TailResp // Get the Etag/Host/Service from the first entry (should be the same for all events in this batch) first := events[0] switch { + case first.LogGroupIdentifier == nil || first.LogStreamName == nil: + response.Service = "alb" + // response.Host = TODO: we can get the ALB IP from the bucket object name case strings.HasSuffix(*first.LogGroupIdentifier, "/ecs"): // ECS lifecycle events. LogStreams: "f0b805a8-fa74-3212-b6ce-a981c011d337" parseECSEventRecords = true @@ -142,10 +95,10 @@ func (bs *byocServerStream) parseEvents(events []cw.LogEvent) *defangv1.TailResp } // Client-side filtering on etag and service (if provided) - if response.Etag != "" && bs.etag != "" && bs.etag != response.Etag { + if response.Etag != "" && p.etag != "" && p.etag != response.Etag { return nil // TODO: filter these out using the AWS StartLiveTail API } - if len(bs.services) > 0 && !slices.Contains(bs.services, response.GetService()) { + if len(p.services) > 0 && !slices.Contains(p.services, response.GetService()) { return nil // TODO: filter these out using the AWS StartLiveTail API } @@ -172,9 +125,6 @@ func (bs *byocServerStream) parseEvents(events []cw.LogEvent) *defangv1.TailResp if err != nil { term.Debugf("error parsing ECS event, output raw event log: %v", err) } else { - if bs.ecsEventsHandler != nil { - bs.ecsEventsHandler.HandleECSEvent(evt) - } entry.Service = evt.Service() entry.Etag = evt.Etag() entry.Host = evt.Host() @@ -184,17 +134,13 @@ func (bs *byocServerStream) parseEvents(events []cw.LogEvent) *defangv1.TailResp entry.Service = response.Service entry.Etag = response.Etag entry.Host = response.Host - evt := codebuild.ParseCodebuildEvent(entry) - if bs.codebuildEventHandler != nil && evt.State() != defangv1.ServiceState_NOT_SPECIFIED { - bs.codebuildEventHandler.HandleCodebuildEvent(evt) - } } else if (response.Service == "cd") && 
(strings.HasPrefix(entry.Message, logs.ErrorPrefix) || strings.Contains(strings.ToLower(entry.Message), "error:")) { entry.Stderr = true } - if entry.Etag != "" && bs.etag != "" && entry.Etag != bs.etag { + if entry.Etag != "" && p.etag != "" && entry.Etag != p.etag { continue } - if entry.Service != "" && len(bs.services) > 0 && !slices.Contains(bs.services, entry.Service) { + if entry.Service != "" && len(p.services) > 0 && !slices.Contains(p.services, entry.Service) { continue } diff --git a/src/pkg/cli/client/byoc/aws/stream_test.go b/src/pkg/cli/client/byoc/aws/stream_test.go index 2d9926279..1555b1191 100644 --- a/src/pkg/cli/client/byoc/aws/stream_test.go +++ b/src/pkg/cli/client/byoc/aws/stream_test.go @@ -143,10 +143,10 @@ func TestStreamToLogEvent(t *testing.T) { }, } - var byocServiceStream = newByocServerStream(nil, testEtag, []string{"cd", "app", "django", "django-image"}, nil, nil) + parser := &logEventParser{etag: testEtag, services: []string{"cd", "app", "django", "django-image"}} for _, td := range testdata { - tailResp := byocServiceStream.parseEvents([]cw.LogEvent{*td.event}) + tailResp := parser.parseEvents([]cw.LogEvent{*td.event}) if (td.wantResp == nil) != (tailResp == nil) { t.Errorf("nil mismatch: expected %v, got %v", td.wantResp, tailResp) continue diff --git a/src/pkg/cli/client/byoc/aws/subscribe.go b/src/pkg/cli/client/byoc/aws/subscribe.go index a48055726..9d5aa29aa 100644 --- a/src/pkg/cli/client/byoc/aws/subscribe.go +++ b/src/pkg/cli/client/byoc/aws/subscribe.go @@ -1,81 +1,118 @@ package aws import ( + "iter" "slices" "strings" "github.com/DefangLabs/defang/src/pkg/clouds/aws/codebuild" + "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs" + "github.com/DefangLabs/defang/src/pkg/term" "github.com/DefangLabs/defang/src/pkg/types" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" ) -type byocSubscribeServerStream struct { - services []string - etag types.ETag +// 
parseSubscribeEvents converts CW log events from ECS and builds log groups +// into SubscribeResponse, filtering by etag and services. +func parseSubscribeEvents(logSeq iter.Seq2[[]cw.LogEvent, error], etag types.ETag, services []string) iter.Seq2[*defangv1.SubscribeResponse, error] { + return func(yield func(*defangv1.SubscribeResponse, error) bool) { + for events, err := range logSeq { + for _, event := range events { + if resp := parseSubscribeEvent(event, etag, services); resp != nil { + if !yield(resp, nil) { + return + } + } + } + if err != nil { + if !yield(nil, err) { + return + } + } + } + } +} - ch chan *defangv1.SubscribeResponse - resp *defangv1.SubscribeResponse - err error - done chan struct{} +func parseSubscribeEvent(evt cw.LogEvent, etag types.ETag, services []string) *defangv1.SubscribeResponse { + if evt.LogGroupIdentifier == nil || evt.Message == nil { + return nil + } + + switch { + case strings.HasSuffix(*evt.LogGroupIdentifier, "/ecs"): + return parseECSSubscribeEvent(evt, etag, services) + case strings.HasSuffix(*evt.LogGroupIdentifier, "/builds") && + evt.LogStreamName != nil && + codeBuildPrefixRegex.MatchString(*evt.LogStreamName): + return parseCodebuildSubscribeEvent(evt, etag, services) + default: + return nil + } } -func (s *byocSubscribeServerStream) HandleCodebuildEvent(evt codebuild.Event) { - if etag := evt.Etag(); etag == "" || etag != s.etag { - return +func parseECSSubscribeEvent(evt cw.LogEvent, etag types.ETag, services []string) *defangv1.SubscribeResponse { + ecsEvt, err := ecs.ParseECSEvent([]byte(*evt.Message)) + if err != nil { + term.Debugf("error parsing ECS event: %v", err) + return nil } - service := strings.TrimSuffix(evt.Service(), "-image") - if len(s.services) > 0 && !slices.Contains(s.services, service) { - return + + if e := ecsEvt.Etag(); e == "" || (etag != "" && e != etag) { + return nil } - resp := defangv1.SubscribeResponse{ - Name: evt.Service(), - Status: evt.Status(), - State: evt.State(), + if service 
:= ecsEvt.Service(); len(services) > 0 && !slices.Contains(services, service) { + return nil } - select { - case s.ch <- &resp: - case <-s.done: + + return &defangv1.SubscribeResponse{ + Name: ecsEvt.Service(), + Status: ecsEvt.Status(), + State: ecsEvt.State(), } } -func (s *byocSubscribeServerStream) HandleECSEvent(evt ecs.Event) { - if etag := evt.Etag(); etag == "" || etag != s.etag { - return +func parseCodebuildSubscribeEvent(evt cw.LogEvent, etag types.ETag, services []string) *defangv1.SubscribeResponse { + // Extract service/etag from log stream name: "-image/_/" + if evt.LogStreamName == nil { + return nil } - if service := evt.Service(); len(s.services) > 0 && !slices.Contains(s.services, service) { - return + parts := strings.Split(*evt.LogStreamName, "/") + if len(parts) != 3 { + return nil } - resp := defangv1.SubscribeResponse{ - Name: evt.Service(), - Status: evt.Status(), - State: evt.State(), + underscore := strings.LastIndexByte(parts[1], '_') + if underscore < 0 { + return nil } - select { - case s.ch <- &resp: - case <-s.done: - } -} -func (s *byocSubscribeServerStream) Close() error { - close(s.done) - return nil -} + cbEtag := parts[1][underscore+1:] + cbService := parts[0] // -image + cbHost := parts[2] // build id -func (s *byocSubscribeServerStream) Receive() bool { - select { - case resp := <-s.ch: - s.resp = resp - return true - case <-s.done: - return false + if etag != "" && cbEtag != etag { + return nil } -} -func (s *byocSubscribeServerStream) Msg() *defangv1.SubscribeResponse { - return s.resp -} + service := strings.TrimSuffix(cbService, "-image") + if len(services) > 0 && !slices.Contains(services, service) { + return nil + } -func (s *byocSubscribeServerStream) Err() error { - return s.err + entry := &defangv1.LogEntry{ + Message: *evt.Message, + Service: cbService, + Etag: cbEtag, + Host: cbHost, + } + cbEvt := codebuild.ParseCodebuildEvent(entry) + if cbEvt.State() == defangv1.ServiceState_NOT_SPECIFIED { + return nil + } + + 
return &defangv1.SubscribeResponse{ + Name: service, + Status: cbEvt.Status(), + State: cbEvt.State(), + } } diff --git a/src/pkg/cli/client/byoc/aws/subscribe_test.go b/src/pkg/cli/client/byoc/aws/subscribe_test.go index fa9053b6b..90510e152 100644 --- a/src/pkg/cli/client/byoc/aws/subscribe_test.go +++ b/src/pkg/cli/client/byoc/aws/subscribe_test.go @@ -3,124 +3,191 @@ package aws import ( "testing" + "github.com/DefangLabs/defang/src/pkg/clouds/aws/cw" defangv1 "github.com/DefangLabs/defang/src/protos/io/defang/v1" + awssdk "github.com/aws/aws-sdk-go-v2/aws" ) -type mockECSEvent struct { - etag string - service string - state defangv1.ServiceState +func makeCWLogEvent(logGroup, streamName, message string) cw.LogEvent { + var ts int64 + return cw.LogEvent{ + LogGroupIdentifier: awssdk.String(logGroup), + LogStreamName: awssdk.String(streamName), + Message: awssdk.String(message), + Timestamp: &ts, + } } -func (e *mockECSEvent) Service() string { return e.service } -func (e *mockECSEvent) Etag() string { return e.etag } -func (e *mockECSEvent) Host() string { return "" } -func (e *mockECSEvent) Status() string { return "" } -func (e *mockECSEvent) State() defangv1.ServiceState { return e.state } +// ECS Task State Change with etag in container override name: service1_etag1 +const ecsTaskStateChange = `{"version":"0","id":"abc","detail-type":"ECS Task State Change","source":"aws.ecs","account":"123","time":"2024-01-01T00:00:00Z","region":"us-west-2","resources":["arn:aws:ecs:us-west-2:123:task/cluster/taskid"],"detail":{"lastStatus":"DEACTIVATING","stoppedReason":"","taskArn":"arn:aws:ecs:us-west-2:123:task/cluster/taskid","containers":[],"overrides":{"containerOverrides":[{"name":"service1_etag1","command":[]}]},"startedBy":"ecs-svc/deploy1"}}` -func TestByocSubscribeServerStream(t *testing.T) { - t.Run("ignore event with different etag", func(t *testing.T) { - ss := &byocSubscribeServerStream{ - services: []string{"service1", "service2"}, - etag: "etag1", - ch: 
make(chan *defangv1.SubscribeResponse), - done: make(chan struct{}), +func TestParseSubscribeEvent_ECS(t *testing.T) { + t.Run("matches etag and service", func(t *testing.T) { + evt := makeCWLogEvent("arn:aws:logs:us-west-2:123:log-group:/ecs", "stream-id", ecsTaskStateChange) + resp := parseSubscribeEvent(evt, "etag1", []string{"service1"}) + if resp == nil { + t.Fatal("expected a response") } + if resp.Name != "service1" { + t.Errorf("expected service1, got %s", resp.Name) + } + if resp.State != defangv1.ServiceState_DEPLOYMENT_FAILED { + t.Errorf("expected DEPLOYMENT_FAILED, got %s", resp.State) + } + }) - go func() { - ss.HandleECSEvent(&mockECSEvent{etag: "different-etag", service: "service1"}) - ss.Close() - }() - if ss.Receive() { - t.Errorf("expected no message, but got one: %v", ss.Msg()) + t.Run("no etag filter passes matching events", func(t *testing.T) { + evt := makeCWLogEvent("arn:aws:logs:us-west-2:123:log-group:/ecs", "stream-id", ecsTaskStateChange) + resp := parseSubscribeEvent(evt, "", nil) + if resp == nil { + t.Fatal("expected a response") + } + if resp.Name != "service1" { + t.Errorf("expected service1, got %s", resp.Name) } }) - t.Run("ignore event from a different service", func(t *testing.T) { - ss := &byocSubscribeServerStream{ - services: []string{"service1", "service2"}, - etag: "etag1", - ch: make(chan *defangv1.SubscribeResponse), - done: make(chan struct{}), + t.Run("filters by etag", func(t *testing.T) { + evt := makeCWLogEvent("arn:aws:logs:us-west-2:123:log-group:/ecs", "stream-id", ecsTaskStateChange) + resp := parseSubscribeEvent(evt, "different-etag", nil) + if resp != nil { + t.Errorf("expected nil for different etag, got %v", resp) } + }) - go func() { - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service3"}) - ss.Close() - }() - if ss.Receive() { - t.Errorf("expected no message, but got one: %v", ss.Msg()) + t.Run("filters by service", func(t *testing.T) { + evt := 
makeCWLogEvent("arn:aws:logs:us-west-2:123:log-group:/ecs", "stream-id", ecsTaskStateChange) + resp := parseSubscribeEvent(evt, "etag1", []string{"other-service"}) + if resp != nil { + t.Errorf("expected nil for different service, got %v", resp) } }) - t.Run("receive event from correct service and etag", func(t *testing.T) { - ss := &byocSubscribeServerStream{ - services: []string{"service1", "service2"}, - etag: "etag1", - ch: make(chan *defangv1.SubscribeResponse), - done: make(chan struct{}), + t.Run("non-ecs log group ignored", func(t *testing.T) { + evt := makeCWLogEvent("arn:aws:logs:us-west-2:123:log-group:/logs", "stream-id", ecsTaskStateChange) + resp := parseSubscribeEvent(evt, "etag1", nil) + if resp != nil { + t.Errorf("expected nil for non-ecs log group, got %v", resp) } + }) +} - go func() { - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service2", state: defangv1.ServiceState_DEPLOYMENT_COMPLETED}) - ss.Close() - }() - if !ss.Receive() { - t.Errorf("expected a message, but got none") +func TestParseSubscribeEvent_Codebuild(t *testing.T) { + t.Run("build activating", func(t *testing.T) { + evt := makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag123/build-id-1", + "Running on CodeBuild", + ) + resp := parseSubscribeEvent(evt, "", nil) + if resp == nil { + t.Fatal("expected a response") } - if ss.Msg().Name != "service2" { - t.Errorf("expected service2, but got: %s", ss.Msg().Name) + if resp.Name != "worker" { + t.Errorf("expected worker, got %s", resp.Name) } - if ss.Msg().State != defangv1.ServiceState_DEPLOYMENT_COMPLETED { - t.Errorf("expected state RUNNING, but got: %s", ss.Msg().State) + if resp.State != defangv1.ServiceState_BUILD_ACTIVATING { + t.Errorf("expected BUILD_ACTIVATING, got %s", resp.State) } }) - t.Run("multiple events", func(t *testing.T) { - ss := &byocSubscribeServerStream{ - services: []string{"service1", "service2"}, - etag: "etag1", - ch: make(chan 
*defangv1.SubscribeResponse), - done: make(chan struct{}), - } - - go func() { - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service2", state: defangv1.ServiceState_DEPLOYMENT_PENDING}) - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service1", state: defangv1.ServiceState_BUILD_ACTIVATING}) - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service2", state: defangv1.ServiceState_DEPLOYMENT_COMPLETED}) - ss.Close() - }() - count := 0 - for ss.Receive() { - msg := ss.Msg() - if count == 0 && (msg.Name != "service2" || msg.State != defangv1.ServiceState_DEPLOYMENT_PENDING) { - t.Errorf("first message mismatch, got: %v", msg) - } - if count == 1 && (msg.Name != "service1" || msg.State != defangv1.ServiceState_BUILD_ACTIVATING) { - t.Errorf("second message mismatch, got: %v", msg) - } - if count == 2 && (msg.Name != "service2" || msg.State != defangv1.ServiceState_DEPLOYMENT_COMPLETED) { - t.Errorf("third message mismatch, got: %v", msg) - } - count++ + t.Run("build failed", func(t *testing.T) { + evt := makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "api-image/api_etag456/build-id-2", + "Phase complete: BUILD State: FAILED", + ) + resp := parseSubscribeEvent(evt, "", nil) + if resp == nil { + t.Fatal("expected a response") + } + if resp.Name != "api" { + t.Errorf("expected api, got %s", resp.Name) } - if count != 3 { - t.Errorf("expected 3 messages, but got %d", count) + if resp.State != defangv1.ServiceState_BUILD_FAILED { + t.Errorf("expected BUILD_FAILED, got %s", resp.State) } }) - t.Run("event after close", func(t *testing.T) { - ss := &byocSubscribeServerStream{ - services: []string{"service1"}, - etag: "etag1", - ch: make(chan *defangv1.SubscribeResponse), - done: make(chan struct{}), + t.Run("filters by etag", func(t *testing.T) { + evt := makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag123/build-id-1", + "Running on CodeBuild", + ) + resp 
:= parseSubscribeEvent(evt, "different-etag", nil) + if resp != nil { + t.Errorf("expected nil for different etag, got %v", resp) } + }) - ss.Close() - ss.HandleECSEvent(&mockECSEvent{etag: "etag1", service: "service1", state: defangv1.ServiceState_DEPLOYMENT_COMPLETED}) - if ss.Receive() { - t.Errorf("expected no message after close, but got one: %v", ss.Msg()) + t.Run("filters by service", func(t *testing.T) { + evt := makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag123/build-id-1", + "Running on CodeBuild", + ) + resp := parseSubscribeEvent(evt, "", []string{"api"}) + if resp != nil { + t.Errorf("expected nil for different service, got %v", resp) } }) + + t.Run("unrecognized message ignored", func(t *testing.T) { + evt := makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag123/build-id-1", + "Some random log line", + ) + resp := parseSubscribeEvent(evt, "", nil) + if resp != nil { + t.Errorf("expected nil for NOT_SPECIFIED state, got %v", resp) + } + }) +} + +func TestParseSubscribeEvents(t *testing.T) { + events := []cw.LogEvent{ + makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag1/build-id", + "Running on CodeBuild", + ), + makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:tenant/proj/builds", + "worker-image/worker_etag1/build-id", + "Some random log line", // NOT_SPECIFIED state, should be filtered + ), + makeCWLogEvent( + "arn:aws:logs:us-west-2:123:log-group:/ecs", + "stream-id", + ecsTaskStateChange, + ), + } + + iter := func(yield func([]cw.LogEvent, error) bool) { + for _, evt := range events { + if !yield([]cw.LogEvent{evt}, nil) { + return + } + } + } + + var results []*defangv1.SubscribeResponse + for resp, err := range parseSubscribeEvents(iter, "", nil) { + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + results = append(results, resp) + } + + if len(results) != 2 { + 
t.Fatalf("expected 2 results, got %d", len(results)) + } + if results[0].Name != "worker" || results[0].State != defangv1.ServiceState_BUILD_ACTIVATING { + t.Errorf("first result mismatch: %v", results[0]) + } + if results[1].Name != "service1" || results[1].State != defangv1.ServiceState_DEPLOYMENT_FAILED { + t.Errorf("second result mismatch: %v", results[1]) + } } diff --git a/src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0000Z_44.233.47.227_7tj887d8.log.gz b/src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0000Z_44.233.47.227_7tj887d8.log.gz new file mode 100644 index 0000000000000000000000000000000000000000..e567919107c59862f70e1739dc3d4534abd6e7cb GIT binary patch literal 1485 zcmV;;1v2^{iwFP!0000419VYMYuqppz4uoPJ|(1*rO}5ahaMKk?3QHP5*q@!6xo({ zO}w_T-EQcwuNJmY3T0*v^YwW1=D4oY0?1GkhJ>jSMXGeF5}H`elz0UyH}HynsN z_+jm8vaGWknuWOYk~9pBj%!;xE$`u7%Ujq?Q}17V_1LYk5(yO=ZXS2#{^=w1-J6H4uiwU7 z$d2=P@-gQD@nTT8*3I3YmApAzd)zWkk%|vLa=VC)PdBM?;ZMG-HZi7LK@~QY4 z#=^bP^Ej@bIMLQ9^wJX4$QoZcCLFd}wL&XX**XbVkq~Yk(jao`>Qnk-A!om=1XE$S zw1}J~2}^4~uU*^KZuN2HRnS@)A|#hYdTYp~L{KuJOk-5V>j>C?-mMs8uMoA}yKl(s zmuVapeKEJrOZ)? 
zT$%*Ilu`7_;3OM03VA80000000aZwl1ppc zFc5(6{S{(Q38`dhG0m=`9V$%(0uw?+iL=HOrA9$0p>cTW$T4oEQbcf&WO2UE$!;G^nk0w}q-1Cq zfJqT$a$F%=jsZ z3XP}aS>rKT4nrI6LUGZ|iIRp%O^z-$`PJoH(l!r)oQLwEKPIc)*gu9ui0E0R#p|FToC@NdUMY2$)awaLq*3h4ShF-Lm>NLev>Bca?v=iTV zm@w4pzn7_16O#rAm%w@M;XKs3fk0sh3W8=t9MBC~QErTOcn5ylSWL5D?}@)pYh41_ zYw>H`iGSMrh{75c+$4K#l1;O9N}ER1RB)t{Q_&AbNNs@t;V)C7Xh$IsBuD7pUytz6 zj0}uP)pR}OJyN0NV3EPiINmZ&c?aE1I-N4iC7*E>>?Gt;(+|2ekZ+vz1w%tVv%)xg_k2&jY)zcl=f+r< zBZW|x?@RsJ9Up3=`;>Ei%ZYJTORfdcu6Hr|-uS4anw@sRD-%TXPkt|aEpfz_cvn(f zN(!|-K2kXh{}+o=lKsQ|!5;hg;p(fGb`nXeS%)diDtc$*6wyZ^gw?kH>!S(Y`~n9I zhU)ltMuZs}(bvnSIH5NklNw029EPj%qo1F_KjWOI z04cA29uoLVG6NH~p~3ai=mss0)+yZ-MqLZeYRPM1+D19n5Fo&BM=@2dLL0&ho>cd{ zj=Wr@w3JF!kf!E_)k4cwDM`t`>mDetc}1!{9*>cXYdnz$U%JgPNAeBxC9XfyM>0cE1BGdf-%@9L^5Z8=v^-_p@;hAqVb%-arQckyj-r@UFz(@~; zW1g--ggm`Y`Bz1D+q?Y^9uFJX><(O5xZi$-!+s-52Hls#Yq4L9R*x~}PaGO&wd7iW zX2Kiftg)zi$+7V$8+5ApgHnE-+Y)Z;!>K(dxM0d4T4}j*qNHjJYXh!%8mHhh(jK)? zR$GvYD^TbxTt^@jliChe_dHaD{cp{Tu_scE(;E)3Vp?MHdJ}RL^G|i3?|%5P521J3 fIN{Z(nAIG|F*u_gV{UMNgFgHQZeCl~V*&sG2VH^o literal 0 HcmV?d00001 diff --git a/src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0010Z_44.233.47.227_2eihuvci.log.gz b/src/pkg/cli/client/byoc/aws/testdata/123456789012_elasticloadbalancing_us-west-2_app.defang-agentic-strands-aws7d0286.c9b3756e8ef89456_20260206T0010Z_44.233.47.227_2eihuvci.log.gz new file mode 100644 index 0000000000000000000000000000000000000000..8dc3326876019bc4479312e808eabcf4b764a551 GIT binary patch literal 2745 zcmV;q3P$xGiwFP!0000418q@VYuqppeDALieo9CsucXz79r|!w%w0*YEwLfcmm=Hp zor%vjw%ZH&^_AfYr4(r&cBL6+W>3p9&P0IF6a)ob0HmN&flMKAfbYbOV_f^j^+)QC zzF*pk&dcQbdZz9&n;L}Gag~(PDDACpY=TM?ZWRlxjDukUO^QN^giu2H^&((srBZ~0 zkR;k=1v#x^(l|zBLP~}KDgqQ`e7DWFU$eWo8;(QIPRqF?`=a<3bH*bQ9ftR|>s+iD zkeipbuZPP_z7&C(!958BjrZhD<1N{YW9MId`PeS8l7>l5ZXWl=;prpk+E-6@zIq*Q z$@Vl2XCHGO7}rB`*SM+u1Ig;+^^GUMgHIWeZ2Pd!=ySfIn=I$T(%tqz^ZiDIFp8)A 
zSNInEM$f~reBx9mR!Kmls;O&l6>`P7DVtK%wKkwsa79D8sZT@7sjJTE&zWBQyik!! z%P|omWNRR;{Is-9Te-!@rB?#2vQ#2RiaxP)9Z`WGm94ERc@+`(|Gs4c>%OC=ee*S) z{5%f*?5{#rG5=KY`QZm((@0Z!P&&v)UuwaP@TC-GWlZJ%27UMqziHVsYytoPABzY8 z000001OvsBO>5jR5Qgvl6~ddn76d}eL7tdP)rXvC{365pQ%Zgnm6=hKn z^2);s2tedMYj$8-N?RFOh66?ECnAq4$dD0?sM_v620p}b*Kf!8K0Pl4iCRbmC#a!XA{no> z#;9Zr!IKQ?H-a7jm+0;~Ap`&bABzY8000001Os(Y+e+;)5Pd&iG4x5nB$G*+OT`CK zQPB;e1;rQ1PE&Vr+m`Nnj-0Qz^$sF)V_??ITC*C*9tg=WB!;F$EYK9nb0j5meG8%Q zMV0E%-m`F@+OdgjGY+AxHZ1(x*b2!R5&crx+$5K3=XGwtYVL?@%c;l-9RUL--(BI{ z5=u~nkR2bDu$ovDLTJrNsaF`;!l9_c-zIh~pAwvu<%N)3W{@qq|IK*gWBaf%P!HXN*m?lB2T#Bnyb z3UwHocStsKo-DiUx^Yi3<6W-M)eNj3}&ABzY8000001Os(YOH0Hs5Wf2>hMx40Wb$Z|cJZdJ9u$#+uop?wWVg7qOIvmq zf4%AY0!3yrA(_YbeY2US9S9VfpdhF##1c)Z6g9#~ZCBuSJJy8gx;1g@&`qr;<21Og z8HqcLra@sf^F=L9u7eHH7An_Z6t#jXLq(|yX@LOYpDt0PC$|!$kdQ5}E7&9wC1VLy zNO|fKrnJb`p&!Gr4+A@dn!#;VJu*&t2HAc8)^4|sD2C4UM1nf-+^-__mKi`3EA(iTD$ntNsI zMp$%ptzs~$mhqcZYB+lb7uBxV+W`OoABzY8000001Os)CO>4t242JLf6()DvvZcs& zptp@}hrt**^)i&iZUdz)O<=HJKbL(S#*kiwKzj6QzpiKH!LtUhfqQ@sQb%kkkjJ=H zaz5KB7dh^KUDuz2|NaGlt!gfLnlE`=>+sdn^2Zbr+F#l?^;-MK z+gFUGl-w2!d6GaJNv1I1=$(X71b%GPg_{p;&ASGE0RR9WiwFP!000041KpHOZ`?Kz zhVT6q#GVqMq&UNG89f;G!VVI*MRtJ{xfmmn+EpV<3rVhn{`zt^i@LUB)GgW`3Xy_1 zlF0dZ-r3IcZel3Wa*15pwUng{B?RGvbJBifwA&RiHPml7ZBjqCHJ|2@`Z#gAn_QIW zZBd6xIb)MgjSt#brUdvrtU%>UrIZB*Qb+;L2&+pYC9u4f@^HuFGv@VSA~Nf>XN!yF znr-(!mPPTE<8Pz%hG^gD*^^^GGugI>KTC0LxuGc>nAYj23)Q*DMVyda{ z+kV5YhF{yRqrwQuUR_b$_VX}pPuQD&P94hz_QQ((l%2l@bM0QUv+X$Cqyo_9X!6BY z)X>-#MmeGEf-TR^FPHrN>Xe@@SAd>hoL%wNKjmk9Ol!2ydEP0D5k+0)ELr<&xr)l2*b_|N@o9AMv(>S+HThpA1DjC$? 
zc&;>Pu2b-QZ{nz>GCq1!LA_w`44RoA`H?M}VZ5Vp?=(No$B<=n2I~HU^_%l+>myVU zRVAy6Djhw^oULf1O5RAF?swyhZw~N1^+Ic*TyWt2u0X({0L+5$LAb0CuN>oy!*OV- z^Ae5BQBQx{gCDIlzz61n*t8+AjnL*R4)?* ze*~!Z9LNV%=>*9KX?&J|jIPq*SjvBewte$?)Cj-S`AA#)IYmCEX*cv!x@YQl_>K_n z8IxI4P4XHeYMV58ry6QpY)I)~jC&DdMktQ~{~^Y}I{P)ow0aR^{~x#l2QO` 0 && time.Since(start) > 10*time.Millisecond { - s.queryHead(query, 0) + lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) + if err != nil { + yield(nil, err) + return + } + if !s.yieldList(yield, lister, 0) { + return + } } // Start tailing logs after all older logs are processed - if err := s.tailer.Start(s.ctx, query); err != nil { - s.errCh <- err + if err := tailer.Start(s.ctx, query); err != nil { + yield(nil, err) return } for { - entry, err := s.tailer.Next(s.ctx) + entry, err := tailer.Next(s.ctx) if err != nil { - s.errCh <- err + if context.Cause(s.ctx) == io.EOF || errors.Is(err, io.EOF) { + return + } + if isContextCanceledError(err) { + if cause := context.Cause(s.ctx); cause != nil { + yield(nil, cause) + } + return + } + yield(nil, err) return } resps, err := s.parseAndFilter(entry) if err != nil { - s.errCh <- err + yield(nil, err) return } for _, resp := range resps { - s.respCh <- resp + if !yield(resp, nil) { + return + } } } - }() + }, nil } -func (s *ServerStream[T]) StartHead(limit int32) { +// Head returns an iterator that queries logs in ascending order. +func (s *ServerStream[T]) Head(limit int32) iter.Seq2[*T, error] { query := s.query.GetQuery() term.Debugf("Query logs with query: \n%v", query) - go func() { - s.queryHead(query, limit) - }() + return func(yield func(*T, error) bool) { + lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) + if err != nil { + yield(nil, err) + return + } + s.yieldList(yield, lister, limit) + } } -func (s *ServerStream[T]) StartTail(limit int32) { +// Tail returns an iterator that queries logs in descending order, reversing if a limit is set. 
+func (s *ServerStream[T]) Tail(limit int32) iter.Seq2[*T, error] { query := s.query.GetQuery() term.Debugf("Query logs with query: \n%v", query) - go func() { - s.queryTail(query, limit) - }() -} - -func (s *ServerStream[T]) queryHead(query string, limit int32) { - lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) - if err != nil { - s.errCh <- err - return - } - if limit == 0 { - err = s.listToChannel(lister) - if err != nil && !errors.Is(err, io.EOF) { // Ignore EOF for listing older logs, to proceed to tailing - s.errCh <- err - return - } - } else { - buffer, err := s.listToBuffer(lister, limit) + return func(yield func(*T, error) bool) { + lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderDescending) if err != nil { - s.errCh <- err + yield(nil, err) + return } - for i := range buffer { - s.respCh <- buffer[i] + if limit == 0 { + s.yieldList(yield, lister, 0) + } else { + buffer, err := s.listToBuffer(lister, limit) + if err != nil { + yield(nil, err) + return + } + // iterate over the buffer in reverse order to send the oldest resps first + for i := len(buffer) - 1; i >= 0; i-- { + if !yield(buffer[i], nil) { + return + } + } } - s.errCh <- io.EOF } } -func (s *ServerStream[T]) queryTail(query string, limit int32) { - lister, err := s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderDescending) - if err != nil { - s.errCh <- err - return - } - if limit == 0 { - err = s.listToChannel(lister) +// yieldList yields items from lister to yield. Returns true if iteration completed +// (EOF or limit reached), false if the consumer stopped or an error was yielded. 
+func (s *ServerStream[T]) yieldList(yield func(*T, error) bool, lister gcp.Lister, limit int32) bool { + count := int32(0) + for { + if limit > 0 && count >= limit { + return true + } + entry, err := lister.Next() if err != nil { - s.errCh <- err - return + if errors.Is(err, io.EOF) { + return true + } + yield(nil, err) + return false } - } else { - buffer, err := s.listToBuffer(lister, limit) + resps, err := s.parseAndFilter(entry) if err != nil { - s.errCh <- err + yield(nil, err) + return false } - // iterate over the buffer in reverse order to send the oldest resps first - for i := len(buffer) - 1; i >= 0; i-- { - s.respCh <- buffer[i] + for _, resp := range resps { + count++ + if !yield(resp, nil) { + return false + } } - s.errCh <- io.EOF } } func (s *ServerStream[T]) listToBuffer(lister gcp.Lister, limit int32) ([]*T, error) { - received := 0 buffer := make([]*T, 0, limit) for range limit { entry, err := lister.Next() @@ -215,27 +205,10 @@ func (s *ServerStream[T]) listToBuffer(lister gcp.Lister, limit int32) ([]*T, er return nil, err } buffer = append(buffer, resps...) 
- received += len(resps) } return buffer, nil } -func (s *ServerStream[T]) listToChannel(lister gcp.Lister) error { - for { - entry, err := lister.Next() - if err != nil { - return err - } - resps, err := s.parseAndFilter(entry) - if err != nil { - return err - } - for _, resp := range resps { - s.respCh <- resp - } - } -} - func (s *ServerStream[T]) parseAndFilter(entry *loggingpb.LogEntry) ([]*T, error) { resps, err := s.parse(entry) if err != nil { @@ -265,19 +238,11 @@ func (s *ServerStream[T]) GetQuery() string { return s.query.GetQuery() } -func (s *ServerStream[T]) Err() error { - return s.lastErr -} - -func (s *ServerStream[T]) Msg() *T { - return s.lastResp -} - type LogStream struct { *ServerStream[defangv1.TailResponse] } -func NewLogStream(ctx context.Context, gcpLogsClient GcpLogsClient, services []string) (*LogStream, error) { +func NewLogStream(ctx context.Context, gcpLogsClient GcpLogsClient, services []string) *LogStream { restoreServiceName := getServiceNameRestorer(services, gcp.SafeLabelValue, func(entry *defangv1.TailResponse) string { return entry.Service }, func(entry *defangv1.TailResponse, name string) *defangv1.TailResponse { @@ -285,13 +250,9 @@ func NewLogStream(ctx context.Context, gcpLogsClient GcpLogsClient, services []s return entry }) - ss, err := NewServerStream(ctx, gcpLogsClient, getLogEntryParser(ctx, gcpLogsClient), restoreServiceName) - if err != nil { - return nil, err - } - + ss := NewServerStream(ctx, gcpLogsClient, getLogEntryParser(ctx, gcpLogsClient), restoreServiceName) ss.query = NewLogQuery(gcpLogsClient.GetProjectID()) - return &LogStream{ServerStream: ss}, nil + return &LogStream{ServerStream: ss} } func (s *LogStream) AddJobExecutionLog(executionName string) { @@ -341,7 +302,7 @@ func getServiceNameRestorer[T any](services []string, encode func(string) string } } -func NewSubscribeStream(ctx context.Context, driver GcpLogsClient, waitForCD bool, etag string, services []string, filters 
...LogFilter[*defangv1.SubscribeResponse]) (*SubscribeStream, error) { +func NewSubscribeStream(ctx context.Context, driver GcpLogsClient, waitForCD bool, etag string, services []string, filters ...LogFilter[*defangv1.SubscribeResponse]) *SubscribeStream { filters = append(filters, getServiceNameRestorer(services, gcp.SafeLabelValue, func(entry *defangv1.SubscribeResponse) string { return entry.Name }, func(entry *defangv1.SubscribeResponse, name string) *defangv1.SubscribeResponse { @@ -350,12 +311,9 @@ func NewSubscribeStream(ctx context.Context, driver GcpLogsClient, waitForCD boo }), ) - ss, err := NewServerStream(ctx, driver, getActivityParser(ctx, driver, waitForCD, etag), filters...) - if err != nil { - return nil, err - } + ss := NewServerStream(ctx, driver, getActivityParser(ctx, driver, waitForCD, etag), filters...) ss.query = NewSubscribeQuery() - return &SubscribeStream{ServerStream: ss}, nil + return &SubscribeStream{ServerStream: ss} } func (s *SubscribeStream) AddJobExecutionUpdate(executionName string) { diff --git a/src/pkg/cli/client/byoc/gcp/stream_test.go b/src/pkg/cli/client/byoc/gcp/stream_test.go index 92468effc..0ffc17009 100644 --- a/src/pkg/cli/client/byoc/gcp/stream_test.go +++ b/src/pkg/cli/client/byoc/gcp/stream_test.go @@ -1,6 +1,7 @@ package gcp import ( + "iter" "strconv" "testing" @@ -150,28 +151,27 @@ func TestServerStream_Start(t *testing.T) { tailer: &MockGcpLoggingTailer{}, } - stream, err := NewServerStream( + stream := NewServerStream( ctx, mockGcpLogsClient, getLogEntryParser(ctx, mockGcpLogsClient), restoreServiceName, ) - assert.NoError(t, err) stream.query = NewLogQuery(projectId) + var logs iter.Seq2[*defangv1.TailResponse, error] if tt.direction == head { - stream.StartHead(tt.limit) + logs = stream.Head(tt.limit) } else { - stream.StartTail(tt.limit) + logs = stream.Tail(tt.limit) } - collectedMessages := []string{} - for { - if !stream.Receive() { - assert.NoError(t, stream.Err()) + var collectedMessages []string + for 
response, err := range logs { + assert.NoError(t, err) + if err != nil { break } - response := stream.Msg() collectedMessages = append(collectedMessages, response.Entries[0].Message) } assert.Equal(t, len(tt.expectedMsgs), len(collectedMessages)) diff --git a/src/pkg/cli/client/mock.go b/src/pkg/cli/client/mock.go index bf1f38e21..49e5e31aa 100644 --- a/src/pkg/cli/client/mock.go +++ b/src/pkg/cli/client/mock.go @@ -3,9 +3,11 @@ package client import ( "context" "errors" + "iter" "net/http" "net/url" "path" + "sync" "sync/atomic" "github.com/DefangLabs/defang/src/pkg/dns" @@ -17,6 +19,24 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +// MockIter creates an iter.Seq2 from a pre-populated list of responses and a final error. +// A nil response in the list acts as a stream-end marker. +func MockIter[T any](resps []*T, finalErr error) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + for _, resp := range resps { + if resp == nil { + return + } + if !yield(resp, nil) { + return + } + } + if finalErr != nil { + yield(nil, finalErr) + } + } +} + type MockProvider struct { Provider UploadUrl string @@ -89,14 +109,16 @@ func (m *MockServerStream[T]) Err() error { // returns messages and errors from channels. It blocks until the channels are // closed or an error is received. It is used for testing purposes. type MockWaitStream[T any] struct { - msg *T - err error - msgCh chan *T + msg *T + err error + msgCh chan *T + done chan struct{} + closeOnce sync.Once } // NewMockWaitStream returns a ServerStream that will block until closed. 
func NewMockWaitStream[T any]() *MockWaitStream[T] { - return &MockWaitStream[T]{msgCh: make(chan *T)} + return &MockWaitStream[T]{msgCh: make(chan *T), done: make(chan struct{})} } func (m *MockWaitStream[T]) Send(msg *T, err error) { @@ -105,9 +127,13 @@ func (m *MockWaitStream[T]) Send(msg *T, err error) { } func (m *MockWaitStream[T]) Receive() bool { - msg, ok := <-m.msgCh - m.msg = msg - return ok && msg != nil + select { + case msg, ok := <-m.msgCh: + m.msg = msg + return ok && msg != nil + case <-m.done: + return false + } } func (m *MockWaitStream[T]) Msg() *T { @@ -119,10 +145,41 @@ func (m *MockWaitStream[T]) Err() error { } func (m *MockWaitStream[T]) Close() error { - close(m.msgCh) + m.closeOnce.Do(func() { close(m.done) }) return nil } +// ServerStreamIterCtx adapts a ServerStream to iter.Seq2, closing the stream when the +// context is canceled. This is needed for blocking streams (e.g. MockWaitStream) +// where Receive() blocks on a channel and won't return until Close() is called. 
+func ServerStreamIterCtx[T any](ctx context.Context, stream ServerStream[T]) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + var closeOnce sync.Once + closeStream := func() { closeOnce.Do(func() { stream.Close() }) } + + // Close the stream when context is canceled to unblock Receive() + done := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + closeStream() + case <-done: + } + }() + defer close(done) + defer closeStream() + + for stream.Receive() { + if !yield(stream.Msg(), nil) { + return + } + } + if err := stream.Err(); err != nil { + yield(nil, err) + } + } +} + type MockFabricClient struct { FabricClient DelegateDomain string diff --git a/src/pkg/cli/client/playground.go b/src/pkg/cli/client/playground.go index 3b6ec9a02..777d9cfb3 100644 --- a/src/pkg/cli/client/playground.go +++ b/src/pkg/cli/client/playground.go @@ -42,8 +42,8 @@ func (g *PlaygroundProvider) Deploy(ctx context.Context, req *DeployRequest) (*d return getMsg(g.GetFabricClient().Deploy(ctx, connect.NewRequest(&req.DeployRequest))) } -func (g *PlaygroundProvider) GetDeploymentStatus(ctx context.Context) error { - return io.EOF // TODO: implement on fabric, for now assume service is deployed +func (g *PlaygroundProvider) GetDeploymentStatus(ctx context.Context) (bool, error) { + return true, io.EOF // TODO: implement on fabric, for now assume service is deployed } func (g *PlaygroundProvider) Preview(ctx context.Context, req *DeployRequest) (*defangv1.DeployResponse, error) { @@ -85,12 +85,36 @@ func (g *PlaygroundProvider) CreateUploadURL(ctx context.Context, req *defangv1. 
return getMsg(g.GetFabricClient().CreateUploadURL(ctx, connect.NewRequest(req))) } -func (g *PlaygroundProvider) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (ServerStream[defangv1.SubscribeResponse], error) { - return g.GetFabricClient().Subscribe(ctx, connect.NewRequest(req)) +func (g *PlaygroundProvider) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { + stream, err := g.GetFabricClient().Subscribe(ctx, connect.NewRequest(req)) + if err != nil { + return nil, err + } + return serverStreamIter(stream), nil } -func (g *PlaygroundProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (ServerStream[defangv1.TailResponse], error) { - return g.GetFabricClient().Tail(ctx, connect.NewRequest(req)) +func (g *PlaygroundProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { + stream, err := g.GetFabricClient().Tail(ctx, connect.NewRequest(req)) + if err != nil { + return nil, err + } + return serverStreamIter(stream), nil +} + +// serverStreamIter adapts any ServerStream[T] (including connect-go ServerStreamForClient) +// to iter.Seq2. The stream is closed when the consumer stops iterating or the stream ends. 
+func serverStreamIter[T any](stream ServerStream[T]) iter.Seq2[*T, error] { + return func(yield func(*T, error) bool) { + defer stream.Close() + for stream.Receive() { + if !yield(stream.Msg(), nil) { + return + } + } + if err := stream.Err(); err != nil { + yield(nil, err) + } + } } func (g *PlaygroundProvider) CdCommand(ctx context.Context, req CdCommandRequest) (types.ETag, error) { diff --git a/src/pkg/cli/client/provider.go b/src/pkg/cli/client/provider.go index ddaa56d8e..328004e50 100644 --- a/src/pkg/cli/client/provider.go +++ b/src/pkg/cli/client/provider.go @@ -57,6 +57,7 @@ type PrepareDomainDelegationResponse struct { DelegationSetId string } +// Deprecated: use iter.Seq or iter.Seq2 instead type ServerStream[Res any] interface { Close() error Receive() bool @@ -73,7 +74,7 @@ type Provider interface { DelayBeforeRetry(context.Context) error DeleteConfig(context.Context, *defangv1.Secrets) error Deploy(context.Context, *DeployRequest) (*defangv1.DeployResponse, error) - GetDeploymentStatus(context.Context) error // nil means deployment is pending/running; io.EOF means deployment is done + GetDeploymentStatus(context.Context) (bool, error) GetProjectUpdate(context.Context, string) (*defangv1.ProjectUpdate, error) GetService(context.Context, *defangv1.GetRequest) (*defangv1.ServiceInfo, error) GetServices(context.Context, *defangv1.GetServicesRequest) (*defangv1.GetServicesResponse, error) @@ -83,12 +84,12 @@ type Provider interface { PrepareDomainDelegation(context.Context, PrepareDomainDelegationRequest) (*PrepareDomainDelegationResponse, error) Preview(context.Context, *DeployRequest) (*defangv1.DeployResponse, error) PutConfig(context.Context, *defangv1.PutConfigRequest) error - QueryLogs(context.Context, *defangv1.TailRequest) (ServerStream[defangv1.TailResponse], error) + QueryLogs(context.Context, *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) // Deprecated: should use stacks instead of ProjectName fallback. 
RemoteProjectName(context.Context) (string, error) SetCanIUseConfig(*defangv1.CanIUseResponse) SetUpCD(context.Context) error - Subscribe(context.Context, *defangv1.SubscribeRequest) (ServerStream[defangv1.SubscribeResponse], error) + Subscribe(context.Context, *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) TearDownCD(context.Context) error } diff --git a/src/pkg/cli/composeUp_test.go b/src/pkg/cli/composeUp_test.go index ae9a0a240..846e4ec8a 100644 --- a/src/pkg/cli/composeUp_test.go +++ b/src/pkg/cli/composeUp_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "io" + "iter" "net/http" "net/http/httptest" "sync" @@ -59,30 +60,30 @@ func (*mockDeployProvider) Preview(ctx context.Context, req *client.DeployReques return &defangv1.DeployResponse{Services: services, Etag: etag}, ctx.Err() } -func (m *mockDeployProvider) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (client.ServerStream[defangv1.SubscribeResponse], error) { +func (m *mockDeployProvider) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { m.lock.Lock() defer m.lock.Unlock() m.subscribeStream = client.NewMockWaitStream[defangv1.SubscribeResponse]() - return m.subscribeStream, ctx.Err() + return client.ServerStreamIterCtx[defangv1.SubscribeResponse](ctx, m.subscribeStream), ctx.Err() } -func (m *mockDeployProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (client.ServerStream[defangv1.TailResponse], error) { +func (m *mockDeployProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { m.lock.Lock() defer m.lock.Unlock() m.tailStream = client.NewMockWaitStream[defangv1.TailResponse]() - return m.tailStream, ctx.Err() + return client.ServerStreamIterCtx(ctx, m.tailStream), ctx.Err() } func (m *mockDeployProvider) GetProjectUpdate(ctx context.Context, projectName string) (*defangv1.ProjectUpdate, error) { 
return m.prevProjectUpdate, ctx.Err() } -func (m *mockDeployProvider) GetDeploymentStatus(ctx context.Context) error { +func (m *mockDeployProvider) GetDeploymentStatus(ctx context.Context) (bool, error) { select { case <-ctx.Done(): - return context.Cause(ctx) + return true, context.Cause(ctx) default: - return m.deploymentStatus + return m.deploymentStatus != nil, m.deploymentStatus } } diff --git a/src/pkg/cli/safe_closer.go b/src/pkg/cli/safe_closer.go deleted file mode 100644 index e57752c0b..000000000 --- a/src/pkg/cli/safe_closer.go +++ /dev/null @@ -1,33 +0,0 @@ -package cli - -import ( - "io" - "sync/atomic" -) - -// SafeCloser atomically tracks a stream and closes the old one on Swap. -type SafeCloser struct { - ptr atomic.Pointer[struct{ io.Closer }] -} - -func NewSafeCloser(stream io.Closer) *SafeCloser { - var a SafeCloser - a.ptr.Store(&struct{ io.Closer }{stream}) - return &a -} - -// Swap atomically replaces the stream and closes the old one. -func (a *SafeCloser) Swap(stream io.Closer) error { - if old := a.ptr.Swap(&struct{ io.Closer }{stream}); old != nil { - return old.Close() - } - return nil -} - -// Close atomically removes and closes the current stream. 
-func (a *SafeCloser) Close() error { - if stream := a.ptr.Swap(nil); stream != nil { - return stream.Close() - } - return nil -} diff --git a/src/pkg/cli/subscribe.go b/src/pkg/cli/subscribe.go index 6b2c52acc..be00a1e37 100644 --- a/src/pkg/cli/subscribe.go +++ b/src/pkg/cli/subscribe.go @@ -3,6 +3,7 @@ package cli import ( "context" "errors" + "iter" "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/term" @@ -30,19 +31,13 @@ func WaitServiceState( // Assume "services" are normalized service names subscribeRequest := defangv1.SubscribeRequest{Project: projectName, Etag: etag, Services: services} - serverStream, err := provider.Subscribe(ctx, &subscribeRequest) + logs, err := provider.Subscribe(ctx, &subscribeRequest) if err != nil { return nil, err } - ctx, cancel := context.WithCancel(ctx) - defer cancel() // to ensure we close the stream and clean-up this context - - safeCloser := NewSafeCloser(serverStream) - go func() { - <-ctx.Done() - safeCloser.Close() - }() + next, stop := iter.Pull2(logs) + defer stop() serviceStates := make(ServiceStates, len(services)) // Make sure all services are in the map or `allInState` might return true too early @@ -52,23 +47,27 @@ func WaitServiceState( // Monitor for when all services are completed to end this command for { - if !serverStream.Receive() { - // Reconnect on Error: internal: stream error: stream ID 5; INTERNAL_ERROR; received from peer - if isTransientError(serverStream.Err()) { + msg, err, ok := next() + if !ok { + return serviceStates, nil + } + if err != nil { + // Reconnect on transient errors + if isTransientError(err) { if err := provider.DelayBeforeRetry(ctx); err != nil { return serviceStates, err } - serverStream, err = provider.Subscribe(ctx, &subscribeRequest) + stop() // stop the old iterator + logs, err = provider.Subscribe(ctx, &subscribeRequest) if err != nil { return serviceStates, err } - safeCloser.Swap(serverStream) // closes the old stream + next, stop = 
iter.Pull2(logs) continue } - return serviceStates, serverStream.Err() + return serviceStates, err } - msg := serverStream.Msg() if msg == nil { continue } diff --git a/src/pkg/cli/subscribe_test.go b/src/pkg/cli/subscribe_test.go index e1e8eacc7..9db318530 100644 --- a/src/pkg/cli/subscribe_test.go +++ b/src/pkg/cli/subscribe_test.go @@ -3,8 +3,8 @@ package cli import ( "context" "errors" + "iter" "reflect" - "sync/atomic" "testing" "time" @@ -14,36 +14,17 @@ import ( "github.com/bufbuild/connect-go" ) -// MockSubscribeServerStream mocks the stream response for Subscribe. -type MockSubscribeServerStream = client.MockServerStream[defangv1.SubscribeResponse] - // mockSubscribeProvider mocks the provider for Subscribe. type mockSubscribeProvider struct { client.MockProvider reqs []*defangv1.SubscribeRequest - resps map[types.ETag]*MockSubscribeServerStream -} - -type mockSubscribeServerStream struct { - *MockSubscribeServerStream - closed atomic.Bool -} - -func (m *mockSubscribeServerStream) Close() error { - if m.closed.Swap(true) { - panic("mockSubscribeServerStream already closed") - } - return nil -} - -func (m *mockSubscribeServerStream) Receive() bool { - return !m.closed.Load() && m.MockSubscribeServerStream.Receive() + resps map[types.ETag][]*defangv1.SubscribeResponse } func (m *mockSubscribeProvider) Subscribe( _ context.Context, req *defangv1.SubscribeRequest, -) (client.ServerStream[defangv1.SubscribeResponse], error) { +) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { m.reqs = append(m.reqs, req) resps, ok := m.resps[req.Etag] @@ -51,127 +32,117 @@ func (m *mockSubscribeProvider) Subscribe( panic("unexpected etag; not in resps map") } - return &mockSubscribeServerStream{MockSubscribeServerStream: resps}, nil + return client.MockIter(resps, nil), nil } func TestWaitServiceState(t *testing.T) { ctx := t.Context() provider := &mockSubscribeProvider{ - resps: map[string]*MockSubscribeServerStream{ + resps: map[string][]*defangv1.SubscribeResponse{ 
"etag1": { - Resps: []*defangv1.SubscribeResponse{ - { - Name: "service1", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service1", - State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, - }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service1", + State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, }, }, "etag2": { - Resps: []*defangv1.SubscribeResponse{ - { - Name: "service1", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service1", - State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, - }, - { - Name: "service2", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service2", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service2", - State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, - }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service1", + State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, + }, + { + Name: "service2", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service2", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service2", + State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, }, }, "etag3": { - Resps: []*defangv1.SubscribeResponse{ - { - Name: "service1", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_FAILED, - }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service1", + State: 
defangv1.ServiceState_BUILD_FAILED, }, }, "etag4": { - Resps: []*defangv1.SubscribeResponse{ - { - Name: "service1", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service1", - State: defangv1.ServiceState_DEPLOYMENT_FAILED, - }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service1", + State: defangv1.ServiceState_DEPLOYMENT_FAILED, }, }, "etag5": { - Resps: []*defangv1.SubscribeResponse{ - { - Name: "service1", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service1", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service1", - State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, - }, - { - Name: "service2", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service2", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service2", - State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, - }, - { - Name: "service3", - State: defangv1.ServiceState_BUILD_QUEUED, - }, - { - Name: "service3", - State: defangv1.ServiceState_BUILD_PROVISIONING, - }, - { - Name: "service3", - State: defangv1.ServiceState_DEPLOYMENT_FAILED, - }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service1", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service1", + State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, + }, + { + Name: "service2", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service2", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service2", + State: defangv1.ServiceState_DEPLOYMENT_COMPLETED, + }, + { + Name: "service3", + State: defangv1.ServiceState_BUILD_QUEUED, + }, + { + Name: "service3", + State: defangv1.ServiceState_BUILD_PROVISIONING, + }, + { + Name: "service3", + State: 
defangv1.ServiceState_DEPLOYMENT_FAILED, }, }, }, @@ -268,69 +239,48 @@ func TestWaitServiceState(t *testing.T) { } } -type MockSubscribeServerStreamForReconnectTest struct { - Error error - retry int -} - -func (*MockSubscribeServerStreamForReconnectTest) Close() error { - return nil -} - -func (m *MockSubscribeServerStreamForReconnectTest) Receive() bool { - return false -} - -func (m *MockSubscribeServerStreamForReconnectTest) Msg() *defangv1.SubscribeResponse { - return nil -} - -func (m *MockSubscribeServerStreamForReconnectTest) Err() error { - if m.retry < 5 { - m.retry++ - return m.Error - } - return connect.NewError(connect.CodeCanceled, errors.New("cancel connect error")) // cancel the connection after 5 retries to avoid infinite loop -} - type mockSubscribeProviderForReconnectTest struct { client.MockProvider - stream *MockSubscribeServerStreamForReconnectTest + err error + retry int client.RetryDelayer } func (m *mockSubscribeProviderForReconnectTest) Subscribe( _ context.Context, _ *defangv1.SubscribeRequest, -) (client.ServerStream[defangv1.SubscribeResponse], error) { - return m.stream, nil +) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { + var err error + if m.retry < 5 { + m.retry++ + err = m.err + } else { + err = connect.NewError(connect.CodeCanceled, errors.New("cancel connect error")) + } + return func(yield func(*defangv1.SubscribeResponse, error) bool) { + yield(nil, err) + }, nil } func TestWaitServiceStateStreamReceive(t *testing.T) { tests := []struct { name string - stream *MockSubscribeServerStreamForReconnectTest + err error expectRetry bool }{ { - name: "stream receive returns permission denied error and not retry to connect", - stream: &MockSubscribeServerStreamForReconnectTest{ - Error: connect.NewError(connect.CodePermissionDenied, errors.New("Not Transient Error")), - }, + name: "stream receive returns permission denied error and not retry to connect", + err: connect.NewError(connect.CodePermissionDenied, errors.New("Not 
Transient Error")), expectRetry: false, }, { - name: "stream receive returns unavailable error and retry to connect", - stream: &MockSubscribeServerStreamForReconnectTest{ - Error: connect.NewError(connect.CodeUnavailable, errors.New("stream error")), - }, + name: "stream receive returns unavailable error and retry to connect", + err: connect.NewError(connect.CodeUnavailable, errors.New("stream error")), expectRetry: true, }, { - name: "stream receive returns internal error and retry to connect", - stream: &MockSubscribeServerStreamForReconnectTest{ - Error: connect.NewError(connect.CodeInternal, errors.New("internal error")), - }, + name: "stream receive returns internal error and retry to connect", + err: connect.NewError(connect.CodeInternal, errors.New("internal error")), expectRetry: true, }, } @@ -338,7 +288,7 @@ func TestWaitServiceStateStreamReceive(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { ctx := t.Context() - provider := &mockSubscribeProviderForReconnectTest{stream: tt.stream, RetryDelayer: client.RetryDelayer{Delay: 1 * time.Millisecond}} + provider := &mockSubscribeProviderForReconnectTest{err: tt.err, RetryDelayer: client.RetryDelayer{Delay: 1 * time.Millisecond}} _, err := WaitServiceState( ctx, provider, defangv1.ServiceState_DEPLOYMENT_COMPLETED, @@ -346,10 +296,10 @@ func TestWaitServiceStateStreamReceive(t *testing.T) { "EtagSomething", []string{"service1"}, ) - if !tt.expectRetry && isTransientError(err) && provider.stream.retry > 5 { + if !tt.expectRetry && isTransientError(err) && provider.retry > 5 { t.Errorf("unexpected error: %v", err) } - if tt.expectRetry && err == nil && provider.stream.retry < 5 { + if tt.expectRetry && err == nil && provider.retry < 5 { t.Error("expected error but got nil") } }) diff --git a/src/pkg/cli/tail.go b/src/pkg/cli/tail.go index ef8716d6b..79b3d2620 100644 --- a/src/pkg/cli/tail.go +++ b/src/pkg/cli/tail.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "iter" "net" "os" 
"regexp" @@ -234,13 +235,13 @@ func streamLogs(ctx context.Context, provider client.Provider, projectName strin term.Debug("Tail request:", tailRequest) - serverStream, err := provider.QueryLogs(ctx, tailRequest) + logSeq, err := provider.QueryLogs(ctx, tailRequest) if err != nil { return err } ctx, cancel := context.WithCancel(ctx) - defer cancel() // to ensure we close the stream and clean-up this context + defer cancel() // to ensure we clean-up this context spin := spinner.New() doSpinner := !options.Raw && term.StdoutCanColor() && term.IsTerminal() @@ -293,7 +294,7 @@ func streamLogs(ctx context.Context, provider client.Provider, projectName strin } } - return receiveLogs(ctx, provider, projectName, tailRequest, serverStream, &options, doSpinner, handler) + return receiveLogs(ctx, provider, projectName, tailRequest, logSeq, &options, doSpinner, handler) } func makeHeadBookendOptions(options *TailOptions, firstLogTime time.Time) *TailOptions { @@ -332,29 +333,33 @@ func printTailBookend(options *TailOptions, lastLogTime time.Time) { } } -func receiveLogs(ctx context.Context, provider client.Provider, projectName string, tailRequest *defangv1.TailRequest, serverStream client.ServerStream[defangv1.TailResponse], options *TailOptions, doSpinner bool, handler LogEntryHandler) error { - safeCloser := NewSafeCloser(serverStream) - go func() { - <-ctx.Done() - safeCloser.Close() - }() +func receiveLogs(ctx context.Context, provider client.Provider, projectName string, tailRequest *defangv1.TailRequest, logSeq iter.Seq2[*defangv1.TailResponse, error], options *TailOptions, doSpinner bool, handler LogEntryHandler) error { + next, stop := iter.Pull2(logSeq) + defer stop() headBookendPrinted := false lastLogTime := time.Time{} skipDuplicate := false - var err error for { - if !serverStream.Receive() { - if errors.Is(serverStream.Err(), context.Canceled) || errors.Is(serverStream.Err(), context.DeadlineExceeded) { - return &CancelError{TailOptions: *options, error: 
serverStream.Err(), ProjectName: projectName} + msg, err, ok := next() + if !ok { + // Iterator finished normally + if options.PrintBookends { + printTailBookend(options, lastLogTime) } - if errors.Is(serverStream.Err(), io.EOF) { - return serverStream.Err() + return nil + } + if err != nil { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return &CancelError{TailOptions: *options, error: err, ProjectName: projectName} + } + if errors.Is(err, io.EOF) { + return err } - // Reconnect on Error: internal: stream error: stream ID 5; INTERNAL_ERROR; received from peer - if isTransientError(serverStream.Err()) { - term.Debug("Disconnected:", serverStream.Err()) + // Reconnect on transient errors + if isTransientError(err) { + term.Debug("Disconnected:", err) var spaces int if !options.Raw { spaces, _ = term.Warnf("Reconnecting...\r") // overwritten below @@ -363,12 +368,13 @@ func receiveLogs(ctx context.Context, provider client.Provider, projectName stri return err } tailRequest.Since = timestamppb.New(options.Since) - serverStream, err = provider.QueryLogs(ctx, tailRequest) + stop() // stop the old iterator + newLogSeq, err := provider.QueryLogs(ctx, tailRequest) if err != nil { term.Debug("Reconnect failed:", err) return err } - safeCloser.Swap(serverStream) // closes the old stream + next, stop = iter.Pull2(newLogSeq) if !options.Raw { term.Printf("%*s", spaces, "\r") // clear the "reconnecting" message } @@ -376,15 +382,8 @@ func receiveLogs(ctx context.Context, provider client.Provider, projectName stri continue } - if serverStream.Err() == nil { // returns nil on EOF - if options.PrintBookends { - printTailBookend(options, lastLogTime) - } - return nil - } - return serverStream.Err() + return err } - msg := serverStream.Msg() if msg == nil { continue diff --git a/src/pkg/cli/tailAndMonitor_test.go b/src/pkg/cli/tailAndMonitor_test.go index 83e0ff2cf..2e5360771 100644 --- a/src/pkg/cli/tailAndMonitor_test.go +++ 
b/src/pkg/cli/tailAndMonitor_test.go @@ -3,6 +3,7 @@ package cli import ( "context" "io" + "iter" "testing" "time" @@ -13,20 +14,46 @@ import ( "github.com/stretchr/testify/require" ) +type mockSubscribeData struct { + index int + resps []*defangv1.SubscribeResponse + err error +} + type mockTailAndMonitorProvider struct { - mockSubscribeProvider + client.MockProvider getDeploymentStatusErr error + subs map[types.ETag]*mockSubscribeData } -func (m *mockTailAndMonitorProvider) GetDeploymentStatus(ctx context.Context) error { +func (m *mockTailAndMonitorProvider) GetDeploymentStatus(ctx context.Context) (bool, error) { if err := ctx.Err(); err != nil { - return err + return false, err } - return m.getDeploymentStatusErr + return false, m.getDeploymentStatusErr +} + +func (m *mockTailAndMonitorProvider) Subscribe(_ context.Context, req *defangv1.SubscribeRequest) (iter.Seq2[*defangv1.SubscribeResponse, error], error) { + data := m.subs[req.Etag] + return func(yield func(*defangv1.SubscribeResponse, error) bool) { + for data.index < len(data.resps) { + resp := data.resps[data.index] + data.index++ + if resp == nil { + if data.err != nil { + yield(nil, data.err) + } + return + } + if !yield(resp, nil) { + return + } + } + }, nil } -func (m *mockTailAndMonitorProvider) QueryLogs(ctx context.Context, r *defangv1.TailRequest) (client.ServerStream[defangv1.TailResponse], error) { - return client.NewMockWaitStream[defangv1.TailResponse](), ctx.Err() +func (m *mockTailAndMonitorProvider) QueryLogs(ctx context.Context, r *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { + return client.ServerStreamIterCtx(ctx, client.NewMockWaitStream[defangv1.TailResponse]()), ctx.Err() } func (m *mockTailAndMonitorProvider) DelayBeforeRetry(ctx context.Context) error { @@ -36,61 +63,59 @@ func (m *mockTailAndMonitorProvider) DelayBeforeRetry(ctx context.Context) error func TestTailAndMonitor(t *testing.T) { mockProvider := &mockTailAndMonitorProvider{ 
getDeploymentStatusErr: io.EOF, //client.ErrDeploymentFailed{}, // done - mockSubscribeProvider: mockSubscribeProvider{ - resps: map[types.ETag]*MockSubscribeServerStream{ - "deployment12": { - Error: io.ErrUnexpectedEOF, // reconnection - Resps: []*defangv1.SubscribeResponse{ - nil, // reconnect - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PROVISIONING - {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PROVISIONING - {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PENDING - {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PENDING - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING - {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING - {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING - {Service: nil, Name: "web", Status: "", State: 
defangv1.ServiceState_DEPLOYMENT_COMPLETED}, - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_COMPLETED}, - {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, - {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, - {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, - {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING - {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "auth", Status: "", State: 
defangv1.ServiceState_DEPLOYMENT_COMPLETED}, - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, - {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, - {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, - {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING - {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, // simulate TASK_STOPPED - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_DEPLOYMENT_COMPLETED}, - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE - {Service: nil, 
Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE - }, + subs: map[types.ETag]*mockSubscribeData{ + "deployment12": { + err: io.ErrUnexpectedEOF, // reconnection + resps: []*defangv1.SubscribeResponse{ + nil, // reconnect + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PROVISIONING + {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PROVISIONING + {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PENDING + {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_PENDING + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING + {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING + {Service: nil, Name: "web", Status: " : 346b2dbd236b4a24ab86abcfafda4eef", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING + {Service: nil, Name: "web", Status: "", State: 
defangv1.ServiceState_DEPLOYMENT_COMPLETED}, + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "api", Status: " : 5d5a308a19fd48f3972ae9aa74768f29", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_COMPLETED}, + {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, + {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, + {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, + {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING + {Service: nil, Name: "auth", Status: " : 0b54ed2ba5fa4ec7bbc2abd658c5684c", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "auth", Status: "", State: 
defangv1.ServiceState_DEPLOYMENT_COMPLETED}, + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "web", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, + {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, + {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, + {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_ACTIVATING + {Service: nil, Name: "hasura", Status: " : c7ed06d1bd824a97a6a2b1435f20511b", State: defangv1.ServiceState_NOT_SPECIFIED}, // TASK_RUNNING + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_DEPLOYMENT_PENDING}, // simulate TASK_STOPPED + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_DEPLOYMENT_COMPLETED}, + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "hasura", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE + {Service: nil, Name: "api", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // CAPACITY_PROVIDER_STEADY_STATE + {Service: nil, 
Name: "auth", Status: "", State: defangv1.ServiceState_NOT_SPECIFIED}, // SERVICE_STEADY_STATE }, }, }, diff --git a/src/pkg/cli/tail_test.go b/src/pkg/cli/tail_test.go index 2ab50b0d9..754f03035 100644 --- a/src/pkg/cli/tail_test.go +++ b/src/pkg/cli/tail_test.go @@ -12,6 +12,8 @@ import ( "testing" "time" + "iter" + "github.com/DefangLabs/defang/src/pkg/cli/client" "github.com/DefangLabs/defang/src/pkg/clouds/aws/ecs" "github.com/DefangLabs/defang/src/pkg/logs" @@ -51,8 +53,8 @@ func TestIsProgressDot(t *testing.T) { type mockTailProvider struct { client.Provider - ServerStreams []client.ServerStream[defangv1.TailResponse] - Reqs []*defangv1.TailRequest + Iters []iter.Seq2[*defangv1.TailResponse, error] + Reqs []*defangv1.TailRequest } func (mockTailProvider) DelayBeforeRetry(ctx context.Context) error { @@ -60,29 +62,25 @@ func (mockTailProvider) DelayBeforeRetry(ctx context.Context) error { return ctx.Err() } -func (m *mockTailProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (client.ServerStream[defangv1.TailResponse], error) { +func (m *mockTailProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { dup, _ := proto.Clone(req).(*defangv1.TailRequest) m.Reqs = append(m.Reqs, dup) - if len(m.ServerStreams) == 0 { + if len(m.Iters) == 0 { return nil, errors.New("no server stream provided") } - ss := m.ServerStreams[0] - m.ServerStreams = m.ServerStreams[1:] - return ss, nil + it := m.Iters[0] + m.Iters = m.Iters[1:] + return it, nil } -type mockTailStream = client.MockServerStream[defangv1.TailResponse] - func (m *mockTailProvider) MockTimestamp(timestamp time.Time) *mockTailProvider { return &mockTailProvider{ - ServerStreams: []client.ServerStream[defangv1.TailResponse]{ - &mockTailStream{ - Resps: []*defangv1.TailResponse{ - {Entries: []*defangv1.LogEntry{ - {Timestamp: timestamppb.New(timestamp)}, - }}, - }, - }, &mockTailStream{Error: io.EOF}, + Iters: 
[]iter.Seq2[*defangv1.TailResponse, error]{ + client.MockIter([]*defangv1.TailResponse{ + {Entries: []*defangv1.LogEntry{ + {Timestamp: timestamppb.New(timestamp)}, + }}, + }, nil), }, } } @@ -100,27 +98,24 @@ func TestTail(t *testing.T) { const projectName = "project1" p := &mockTailProvider{ - ServerStreams: []client.ServerStream[defangv1.TailResponse]{ - &mockTailStream{ - Resps: []*defangv1.TailResponse{ - {Service: "service1", Etag: "SOMEETAG", Host: "SOMEHOST", Entries: []*defangv1.LogEntry{ - {Message: "e1msg1", Timestamp: timestamppb.Now()}, - {Message: "e1msg2", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG"}, // Test event etag override the response etag - {Message: "e1msg3", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST"}, // override both etag and host - {Message: "e1msg4", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST", Service: "service2"}, // override both etag, host and service - {Message: "e1err1", Timestamp: timestamppb.Now(), Stderr: true}, // Error message should be in stdout too when not raw - }}, - {Service: "service1", Etag: "SOMEETAG", Host: "SOMEHOST", Entries: []*defangv1.LogEntry{ // Test entry etag does not affect the default values from response - {Message: "e2err1", Timestamp: timestamppb.Now(), Stderr: true, Etag: "SOMEOTHERETAG"}, // Error message should be in stdout too when not raw - {Message: "e2msg1", Timestamp: timestamppb.Now(), Etag: "ENTRIES2ETAG"}, - {Message: "e2msg2", Timestamp: timestamppb.Now()}, - {Message: "e2msg3", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST", Service: "service2"}, // override both etag, host and service - {Message: "e2msg4", Timestamp: timestamppb.Now()}, - }}, - }, - Error: connect.NewError(connect.CodeInternal, &cwTypes.SessionStreamingException{}), // to test retries - }, - &mockTailStream{Error: io.EOF}, + Iters: []iter.Seq2[*defangv1.TailResponse, error]{ + client.MockIter([]*defangv1.TailResponse{ + 
{Service: "service1", Etag: "SOMEETAG", Host: "SOMEHOST", Entries: []*defangv1.LogEntry{ + {Message: "e1msg1", Timestamp: timestamppb.Now()}, + {Message: "e1msg2", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG"}, // Test event etag override the response etag + {Message: "e1msg3", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST"}, // override both etag and host + {Message: "e1msg4", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST", Service: "service2"}, // override both etag, host and service + {Message: "e1err1", Timestamp: timestamppb.Now(), Stderr: true}, // Error message should be in stdout too when not raw + }}, + {Service: "service1", Etag: "SOMEETAG", Host: "SOMEHOST", Entries: []*defangv1.LogEntry{ // Test entry etag does not affect the default values from response + {Message: "e2err1", Timestamp: timestamppb.Now(), Stderr: true, Etag: "SOMEOTHERETAG"}, // Error message should be in stdout too when not raw + {Message: "e2msg1", Timestamp: timestamppb.Now(), Etag: "ENTRIES2ETAG"}, + {Message: "e2msg2", Timestamp: timestamppb.Now()}, + {Message: "e2msg3", Timestamp: timestamppb.Now(), Etag: "SOMEOTHERETAG2", Host: "SOMEOTHERHOST", Service: "service2"}, // override both etag, host and service + {Message: "e2msg4", Timestamp: timestamppb.Now()}, + }}, + }, connect.NewError(connect.CodeInternal, &cwTypes.SessionStreamingException{})), // to test retries + client.MockIter[defangv1.TailResponse](nil, io.EOF), }, } @@ -292,8 +287,8 @@ type mockQueryErrorProvider struct { TailStreamError error } -func (m mockQueryErrorProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (client.ServerStream[defangv1.TailResponse], error) { - return &mockTailStream{Error: m.TailStreamError}, nil +func (m mockQueryErrorProvider) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { + return client.MockIter[defangv1.TailResponse](nil, m.TailStreamError), nil 
} func TestTailError(t *testing.T) { diff --git a/src/pkg/cli/waitForCdTaskExit.go b/src/pkg/cli/waitForCdTaskExit.go index e164f69d1..3ac733fc9 100644 --- a/src/pkg/cli/waitForCdTaskExit.go +++ b/src/pkg/cli/waitForCdTaskExit.go @@ -18,9 +18,9 @@ func WaitForCdTaskExit(ctx context.Context, provider client.Provider) error { for { select { case <-ticker.C: - err := provider.GetDeploymentStatus(ctx) + done, err := provider.GetDeploymentStatus(ctx) // End condition: EOF indicates that the task has completed successfully - if errors.Is(err, io.EOF) { + if done || errors.Is(err, io.EOF) { return nil } // Retry on transient errors @@ -28,10 +28,10 @@ func WaitForCdTaskExit(ctx context.Context, provider client.Provider) error { // If it's a transient error, we can retry at the next tick continue } - // nil means the task is still running and we continue polling if err != nil { return err } + // nil means the task is still running and we continue polling case <-ctx.Done(): // Stop the loop when the context is cancelled return ctx.Err() } diff --git a/src/pkg/cli/waitForCdTaskExit_test.go b/src/pkg/cli/waitForCdTaskExit_test.go index b7a238dba..7a9be51de 100644 --- a/src/pkg/cli/waitForCdTaskExit_test.go +++ b/src/pkg/cli/waitForCdTaskExit_test.go @@ -16,13 +16,13 @@ type mockCdWaiter struct { getDeploymentStatusErr error } -func (m *mockCdWaiter) GetDeploymentStatus(ctx context.Context) error { +func (m *mockCdWaiter) GetDeploymentStatus(ctx context.Context) (bool, error) { err := m.getDeploymentStatusErr // This logic was copied from AWS provider, to ensure the errs work correctly if taskErr := new(ecs.TaskFailure); errors.As(err, taskErr) { - return client.ErrDeploymentFailed{Message: taskErr.Error()} + return false, client.ErrDeploymentFailed{Message: taskErr.Error()} } - return err + return false, err } func TestWaitForCdTaskExit(t *testing.T) { diff --git a/src/pkg/clouds/aws/common.go b/src/pkg/clouds/aws/common.go index 891c205f4..4ac11019c 100644 --- 
a/src/pkg/clouds/aws/common.go +++ b/src/pkg/clouds/aws/common.go @@ -9,19 +9,20 @@ import ( "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/processcreds" + r53types "github.com/aws/aws-sdk-go-v2/service/route53/types" "github.com/aws/aws-sdk-go-v2/service/sts" ) -type Region string +type Region = r53types.VPCRegion type Aws struct { AccountID string Region Region } -func (r Region) String() string { - return string(r) -} +// func (r Region) String() string { +// return string(r) +// } func (a *Aws) LoadConfig(ctx context.Context) (aws.Config, error) { cfg, err := LoadDefaultConfig(ctx, a.Region) diff --git a/src/pkg/clouds/aws/cw/logs.go b/src/pkg/clouds/aws/cw/logs.go index df9320224..5d9673f49 100644 --- a/src/pkg/clouds/aws/cw/logs.go +++ b/src/pkg/clouds/aws/cw/logs.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" "io" + "iter" "strings" - "sync" "time" "github.com/DefangLabs/defang/src/pkg/clouds/aws" @@ -27,49 +27,10 @@ func getLogGroupIdentifier(arnOrId string) string { } type LogsClient interface { - FilterLogEventsAPI + FilterLogEventsAPIClient StartLiveTailAPI } -func QueryAndTailLogGroups(ctx context.Context, cwClient LogsClient, start, end time.Time, logGroups ...LogGroupInput) (LiveTailStream, error) { - ctx, cancel := context.WithCancel(ctx) - - e := &eventStream{ - cancel: cancel, - ch: make(chan types.StartLiveTailResponseStream), - } - - // We must close the channel when all log groups are done - var wg sync.WaitGroup - var err error - for _, lgi := range logGroups { - var es LiveTailStream - es, err = QueryAndTailLogGroup(ctx, cwClient, lgi, start, end) - if err != nil { - continue - } - wg.Add(1) - go func() { - defer es.Close() - defer wg.Done() - // FIXME: this should *merge* the events from all log groups - e.err = e.pipeEvents(ctx, es) - }() - } - - go func() { - wg.Wait() - close(e.ch) - }() - - if err != nil { - cancel() // abort any goroutines (caller won't call Close) - return 
nil, err - } - - return e, nil -} - // LogGroupInput is like cloudwatchlogs.StartLiveTailInput but with only one LogGroup and one LogStream prefix. type LogGroupInput struct { LogGroupARN string @@ -82,7 +43,7 @@ type StartLiveTailAPI interface { StartLiveTail(ctx context.Context, params *cloudwatchlogs.StartLiveTailInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.StartLiveTailOutput, error) } -func TailLogGroup(ctx context.Context, cwClient StartLiveTailAPI, input LogGroupInput) (LiveTailStream, error) { +func TailLogGroup(ctx context.Context, cwClient StartLiveTailAPI, input LogGroupInput) (iter.Seq2[[]LogEvent, error], error) { if input.LogGroupARN == "" { return nil, errors.New("LogGroupARN is required") } @@ -107,159 +68,156 @@ func TailLogGroup(ctx context.Context, cwClient StartLiveTailAPI, input LogGroup return nil, err } - return slto.GetStream(), nil + stream := slto.GetStream() + return func(yield func([]LogEvent, error) bool) { + defer stream.Close() + for { + select { + case e := <-stream.Events(): + if err := stream.Err(); err != nil { + yield(nil, err) + return + } + events, err := getLogEvents(e) + if err != nil { + if !yield(nil, err) { + return + } + } + if !yield(events, nil) { + return + } + case <-ctx.Done(): + yield(nil, ctx.Err()) + return + } + } + }, nil } -type FilterLogEventsAPI interface { - FilterLogEvents(ctx context.Context, params *cloudwatchlogs.FilterLogEventsInput, optFns ...func(*cloudwatchlogs.Options)) (*cloudwatchlogs.FilterLogEventsOutput, error) -} +type FilterLogEventsAPIClient = cloudwatchlogs.FilterLogEventsAPIClient -func QueryLogGroups(ctx context.Context, cwClient FilterLogEventsAPI, start, end time.Time, limit int32, logGroups ...LogGroupInput) (<-chan LogEvent, <-chan error) { - var evtsChan chan LogEvent - errChan := make(chan error, len(logGroups)) - var wg sync.WaitGroup - for _, lgi := range logGroups { - wg.Add(1) - lgEvtChan := make(chan LogEvent) - // Start a go routine for each log group - go 
func(lgi LogGroupInput) { - defer close(lgEvtChan) - defer wg.Done() - // CloudWatch only supports querying a LogGroup from a timestamp in - // ascending order. After we query each LogGroup, we merge the results - // and take the last N events. Because we can't tell in advance which - // LogGroup will have the most recent events, we have to query all - // log groups without limit, and then apply the limit after merging. - // TODO: optimize this by simulating a descending query by doing - // multiple queries with time windows, starting from the end time - // and moving backwards until we have enough events. - err := QueryLogGroup(ctx, cwClient, lgi, start, end, 0, func(logEvents []LogEvent) error { - for _, event := range logEvents { - lgEvtChan <- event +// Flatten converts an iterator of batches into an iterator of individual items. +func Flatten[T any](seq iter.Seq2[[]T, error]) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + for items, err := range seq { + for _, item := range items { + if !yield(item, nil) { + return } - return nil - }) - if err != nil { - errChan <- fmt.Errorf("error querying log group %q: %w", lgi.LogGroupARN, err) } - }(lgi) - evtsChan = mergeLogEventChan(evtsChan, lgEvtChan) // Merge sort the log events based on timestamp - // take the last n events only - if limit > 0 { - if start.IsZero() { - evtsChan = takeLastN(evtsChan, int(limit)) - } else { - evtsChan = takeFirstN(evtsChan, int(limit)) + if err != nil { + var zero T + if !yield(zero, err) { + return + } } } } - go func() { - wg.Wait() - close(errChan) - }() - return evtsChan, errChan -} - -func QueryLogGroup(ctx context.Context, cwClient FilterLogEventsAPI, input LogGroupInput, start, end time.Time, limit int32, cb func([]LogEvent) error) error { - return filterLogEvents(ctx, cwClient, input, start, end, limit, cb) } -func QueryLogGroupStream(ctx context.Context, cwClient FilterLogEventsAPI, input LogGroupInput, start, end time.Time, limit int32) 
(EventStream[types.StartLiveTailResponseStream], error) { - ctx, cancel := context.WithCancel(ctx) - es := newEventStream(cancel) // calling Close on the stream will cancel the context - - go func() { - defer close(es.ch) - // TODO: this QueryLogGroup function doesn't return until all logs are fetched, so returning a stream is not very useful - if err := QueryLogGroup(ctx, cwClient, input, start, end, limit, func(events []LogEvent) error { - select { - case <-ctx.Done(): - return ctx.Err() - case es.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{Value: types.LiveTailSessionUpdate{SessionResults: events}}: - return nil +func QueryLogGroups(ctx context.Context, cwClient FilterLogEventsAPIClient, start, end time.Time, limit int32, logGroups ...LogGroupInput) (iter.Seq2[LogEvent, error], error) { + if len(logGroups) == 0 { + return nil, errors.New("at least one LogGroupInput is required") + } + var merged iter.Seq2[LogEvent, error] + for _, lgi := range logGroups { + logSeq, err := QueryLogGroup(ctx, cwClient, lgi, start, end, limit) + if err != nil { + // This only happens if there's a missing LogGroupARN, in which case we can't proceed at all + return nil, err + } + merged = MergeLogEvents(merged, Flatten(logSeq)) // Merge sort the log events based on timestamp + if limit > 0 { + // take the first/last n events only from the merged stream + if start.IsZero() { + merged = TakeLastN(merged, int(limit)) + } else { + merged = TakeFirstN(merged, int(limit)) } - }); err != nil { - es.err = err } - }() - - return es, nil + } + return merged, nil } -func filterLogEvents(ctx context.Context, cw FilterLogEventsAPI, lgi LogGroupInput, start, end time.Time, limit int32, cb func([]LogEvent) error) error { +func QueryLogGroup(ctx context.Context, cw FilterLogEventsAPIClient, lgi LogGroupInput, start, end time.Time, limit int32) (iter.Seq2[[]LogEvent, error], error) { if lgi.LogGroupARN == "" { - return errors.New("LogGroupARN is required") + return nil, 
errors.New("LogGroupARN is required") } var pattern *string if lgi.LogEventFilterPattern != "" { pattern = &lgi.LogEventFilterPattern } logGroupIdentifier := getLogGroupIdentifier(lgi.LogGroupARN) + if end.IsZero() { + end = time.Now() + } + if start.IsZero() { + // CloudWatch only supports querying a LogGroup from a timestamp in + // ascending order. After we query each LogGroup, we merge the results + // and take the last N events. Because we can't tell in advance which + // LogGroup will have the most recent events, we have to query all + // log groups without limit, and then apply the limit after merging. + // TODO: optimize this by simulating a descending query by doing + // multiple queries with time windows, starting from the end time + // and moving backwards until we have enough events. + start = end.Add(-60 * time.Minute) + limit = 0 + } + params := &cloudwatchlogs.FilterLogEventsInput{ LogGroupIdentifier: &logGroupIdentifier, LogStreamNames: lgi.LogStreamNames, FilterPattern: pattern, - } - - if limit != 0 { - params.Limit = &limit - } - if !start.IsZero() { - params.StartTime = ptr.Int64(start.UnixMilli()) - } - if !end.IsZero() { - params.EndTime = ptr.Int64(end.UnixMilli()) - } - if start.IsZero() && end.IsZero() { - // If no time range is specified, limit to the last 60 minutes - now := time.Now() - start = now.Add(-60 * time.Minute) - params.StartTime = ptr.Int64(start.UnixMilli()) - params.EndTime = ptr.Int64(now.UnixMilli()) + StartTime: ptr.Int64(start.UnixMilli()), // rounds down + EndTime: ptr.Int64(end.UnixMilli() + 1), // round up } if lgi.LogStreamNamePrefix != "" { params.LogStreamNamePrefix = &lgi.LogStreamNamePrefix } - for { - if limit > 0 { - // Specifying the limit parameter only guarantees that a single page doesn't return more log events than the - // specified limit, but it might return fewer events than the limit. This is the expected API behavior. 
- params.Limit = ptr.Int32(limit) - } - fleo, err := cw.FilterLogEvents(ctx, params) - if err != nil { - return err - } - events := make([]LogEvent, len(fleo.Events)) - for i, event := range fleo.Events { - events[i] = LogEvent{ - IngestionTime: event.IngestionTime, - LogGroupIdentifier: &logGroupIdentifier, - Message: event.Message, - Timestamp: event.Timestamp, - LogStreamName: event.LogStreamName, + return func(yield func([]LogEvent, error) bool) { + for { + if limit > 0 { + // Specifying the limit parameter only guarantees that a single page doesn't return more log events than the + // specified limit, but it might return fewer events than the limit. This is the expected API behavior. + params.Limit = ptr.Int32(limit) } - } - if err := cb(events); err != nil { - return err - } - if fleo.NextToken == nil { - return nil - } - if limit > 0 { - if len(events) < int(limit) { // this handles len(events) == 0 as well - limit -= int32(len(events)) // #nosec G115 - always safe because len(events) < limit - } else if lastTS := events[len(events)-1].Timestamp; lastTS != nil && time.UnixMilli(*lastTS).Equal(start) { - // If the last event timestamp is equal to the start time, we risk getting stuck in a loop - // where the agent keeps asking for logs since the last timestamp, but ends up fetching the same logs - // over and over. To avoid this, we ignore the limit and keep going, until the timestamp changes. 
- limit = 10 // arbitrary small number to make some progress; could be smarter - } else { - return nil + fleo, err := cw.FilterLogEvents(ctx, params) + if err != nil { + yield(nil, err) + return } + events := make([]LogEvent, len(fleo.Events)) + for i, event := range fleo.Events { + events[i] = LogEvent{ + IngestionTime: event.IngestionTime, + LogGroupIdentifier: &logGroupIdentifier, + Message: event.Message, + Timestamp: event.Timestamp, + LogStreamName: event.LogStreamName, + } + } + if !yield(events, nil) { + return + } + if fleo.NextToken == nil { + return + } + if limit > 0 { + if len(events) < int(limit) { // this handles len(events) == 0 as well + limit -= int32(len(events)) // #nosec G115 - always safe because len(events) < limit + } else if lastTS := events[len(events)-1].Timestamp; lastTS != nil && time.UnixMilli(*lastTS).Equal(start) { + // If the last event timestamp is equal to the start time, we risk getting stuck in a loop + // where the agent keeps asking for logs since the last timestamp, but ends up fetching the same logs + // over and over. To avoid this, we ignore the limit and keep going, until the timestamp changes. 
+ limit = 10 // arbitrary small number to make some progress; could be smarter + } else { + return + } + } + params.NextToken = fleo.NextToken } - params.NextToken = fleo.NextToken - } + }, nil } func NewCloudWatchLogsClient(ctx context.Context, region aws.Region) (*cloudwatchlogs.Client, error) { @@ -272,17 +230,7 @@ func NewCloudWatchLogsClient(ctx context.Context, region aws.Region) (*cloudwatc type LogEvent = types.LiveTailSessionLogEvent -// EventStream is a generic interface that represents a stream of events -type EventStream[T any] interface { - Events() <-chan T - Close() error - Err() error -} - -// Deprecated: LiveTailStream is a stream of events from a call to AWS StartLiveTail -type LiveTailStream = EventStream[types.StartLiveTailResponseStream] - -func GetLogEvents(e types.StartLiveTailResponseStream) ([]LogEvent, error) { +func getLogEvents(e types.StartLiveTailResponseStream) ([]LogEvent, error) { switch ev := e.(type) { case *types.StartLiveTailResponseStreamMemberSessionStart: // fmt.Println("session start:", ev.Value.SessionId) @@ -296,43 +244,3 @@ func GetLogEvents(e types.StartLiveTailResponseStream) ([]LogEvent, error) { return nil, fmt.Errorf("unexpected event: %T", ev) } } - -func takeLastN[T any](input chan T, n int) chan T { - if n <= 0 { - return input - } - out := make(chan T) - go func() { - defer close(out) - var buffer []T - for evt := range input { - buffer = append(buffer, evt) - if len(buffer) > n { - buffer = buffer[1:] // remove oldest - } - } - for _, evt := range buffer { - out <- evt - } - }() - return out -} - -func takeFirstN[T any](input chan T, n int) chan T { - if n <= 0 { - return input - } - out := make(chan T) - go func() { - defer close(out) - count := 0 - for evt := range input { - out <- evt - count++ - if count >= n { - break - } - } - }() - return out -} diff --git a/src/pkg/clouds/aws/cw/logs_test.go b/src/pkg/clouds/aws/cw/logs_test.go index 3c0ab9ee0..3b8490cfa 100644 --- a/src/pkg/clouds/aws/cw/logs_test.go 
+++ b/src/pkg/clouds/aws/cw/logs_test.go @@ -10,6 +10,7 @@ import ( "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" "github.com/aws/smithy-go/ptr" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestLogGroupIdentifier(t *testing.T) { @@ -46,7 +47,7 @@ func (m *mockFiltererTailer) StartLiveTail(ctx context.Context, input *cloudwatc // This is a pretty bad test. It ends up only testing that we can poll for the log group. // because our mock StartLiveTail always returns ResourceNotFoundException. -// That means the test repeatedly tries to open a stream until we call Close() on it. +// That means the test repeatedly tries to open a stream until context is canceled. func TestQueryAndTailLogGroups(t *testing.T) { logGroups := []LogGroupInput{ { @@ -54,20 +55,17 @@ func TestQueryAndTailLogGroups(t *testing.T) { }, } mockFiltererTailer := &mockFiltererTailer{} - e, err := QueryAndTailLogGroups(t.Context(), mockFiltererTailer, time.Now(), time.Time{}, logGroups...) + ctx, cancel := context.WithCancel(t.Context()) + evts, err := QueryAndTailLogGroups(ctx, mockFiltererTailer, time.Now(), time.Time{}, logGroups...) 
if err != nil { t.Errorf("Expected no error, but got: %v", err) } - if e.Err() != nil { - t.Errorf("Expected no error, but got: %v", e.Err()) - } - err = e.Close() - if err != nil { - t.Errorf("Expected no error, but got: %v", err) - } - _, ok := <-e.Events() - if ok { - t.Error("Expected channel to be closed") + // Cancel context to stop the polling + cancel() + for _, err := range evts { + if err != nil && err != context.Canceled { + t.Errorf("Expected no error or context.Canceled, but got: %v", err) + } } } @@ -91,9 +89,9 @@ func TestQueryLogGroups(t *testing.T) { }{ { limit: 2, - since: time.Time{}, + since: time.Now().Add(-time.Hour), until: time.Time{}, - expectedMessages: []string{"Log event 2", "Log event 3"}, + expectedMessages: []string{"Log event 1", "Log event 2"}, }, { limit: 2, @@ -103,9 +101,9 @@ func TestQueryLogGroups(t *testing.T) { }, { limit: 2, - since: time.Time{}, + since: time.Now().Add(-time.Hour), until: time.Now(), - expectedMessages: []string{"Log event 2", "Log event 3"}, + expectedMessages: []string{"Log event 1", "Log event 2"}, }, } @@ -119,7 +117,7 @@ func TestQueryLogGroups(t *testing.T) { mockFiltererTailer := &mockFiltererTailer{ filteredLogEvents: logEvents, } - eventsCh, errsCh := QueryLogGroups( + logSeq, err := QueryLogGroups( t.Context(), mockFiltererTailer, tt.since, @@ -128,27 +126,19 @@ func TestQueryLogGroups(t *testing.T) { int32(tt.limit), logGroups..., ) + require.NoError(t, err) collectedMessages := make([]string, 0) - for { - event, ok := <-eventsCh - if !ok { + for evt, err := range logSeq { + if err != nil { + t.Errorf("Expected no error, but got: %v", err) break } - collectedMessages = append(collectedMessages, *event.Message) + collectedMessages = append(collectedMessages, *evt.Message) } assert.Len(t, collectedMessages, tt.limit) for i, expectedMsg := range tt.expectedMessages { assert.Equal(t, expectedMsg, collectedMessages[i]) } - - select { - case err, ok := <-errsCh: - if ok && err != nil { - 
t.Errorf("Expected no error, but got: %v", err) - } - default: - // No error received, as expected - } } } diff --git a/src/pkg/clouds/aws/cw/merge.go b/src/pkg/clouds/aws/cw/merge.go index 76672f0cd..8c091328b 100644 --- a/src/pkg/clouds/aws/cw/merge.go +++ b/src/pkg/clouds/aws/cw/merge.go @@ -1,39 +1,114 @@ package cw -// Inspired by https://dev.to/vinaygo/concurrency-merge-sort-using-channels-and-goroutines-in-golang-35f7 -func Mergech[T any](left chan T, right chan T, c chan T, less func(T, T) bool) { - defer close(c) - val, ok := <-left - val2, ok2 := <-right - for ok && ok2 { - if less(val, val2) { - c <- val - val, ok = <-left - } else { - c <- val2 - val2, ok2 = <-right +import ( + "iter" +) + +// MergeLogEvents merge-sorts two ascending iterators by Timestamp. +// Uses iter.Pull2 internally for two-pointer merge. +func MergeLogEvents(left, right iter.Seq2[LogEvent, error]) iter.Seq2[LogEvent, error] { + if left == nil { + return right + } + if right == nil { + return left + } + return func(yield func(LogEvent, error) bool) { + nextL, stopL := iter.Pull2(left) + defer stopL() + nextR, stopR := iter.Pull2(right) + defer stopR() + + lVal, lErr, lOk := nextL() + rVal, rErr, rOk := nextR() + + for lOk && rOk { + if lErr != nil { + if !yield(lVal, lErr) { + return + } + lVal, lErr, lOk = nextL() + continue + } + if rErr != nil { + if !yield(rVal, rErr) { + return + } + rVal, rErr, rOk = nextR() + continue + } + if *lVal.Timestamp <= *rVal.Timestamp { + if !yield(lVal, nil) { + return + } + lVal, lErr, lOk = nextL() + } else { + if !yield(rVal, nil) { + return + } + rVal, rErr, rOk = nextR() + } + } + + for lOk { + if !yield(lVal, lErr) { + return + } + lVal, lErr, lOk = nextL() + } + for rOk { + if !yield(rVal, rErr) { + return + } + rVal, rErr, rOk = nextR() } } - for ok { - c <- val - val, ok = <-left +} + +// TakeFirstN yields at most n items from the iterator, then stops. 
+func TakeFirstN(seq iter.Seq2[LogEvent, error], n int) iter.Seq2[LogEvent, error] { + if n <= 0 { + return seq } - for ok2 { - c <- val2 - val2, ok2 = <-right + return func(yield func(LogEvent, error) bool) { + count := 0 + for evt, err := range seq { + if !yield(evt, err) { + return + } + if err == nil { + count++ + if count >= n { + return + } + } + } } } -func mergeLogEventChan(left, right chan LogEvent) chan LogEvent { - if left == nil { - return right +// TakeLastN buffers the entire input, then yields the last n items. +func TakeLastN(seq iter.Seq2[LogEvent, error], n int) iter.Seq2[LogEvent, error] { + if n <= 0 { + return seq } - if right == nil { - return left + return func(yield func(LogEvent, error) bool) { + var buffer []LogEvent + for evt, err := range seq { + if err != nil { + if !yield(evt, err) { + return + } + continue + } + buffer = append(buffer, evt) + if len(buffer) > n { + buffer = buffer[1:] + } + } + for _, evt := range buffer { + if !yield(evt, nil) { + return + } + } } - out := make(chan LogEvent) - go Mergech(left, right, out, func(i1, i2 LogEvent) bool { - return *i1.Timestamp < *i2.Timestamp - }) - return out } diff --git a/src/pkg/clouds/aws/cw/merge_test.go b/src/pkg/clouds/aws/cw/merge_test.go new file mode 100644 index 000000000..c6b3f96ba --- /dev/null +++ b/src/pkg/clouds/aws/cw/merge_test.go @@ -0,0 +1,199 @@ +package cw + +import ( + "errors" + "iter" + "testing" + + "github.com/aws/smithy-go/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func logEvents(timestamps ...int64) iter.Seq2[LogEvent, error] { + return func(yield func(LogEvent, error) bool) { + for _, ts := range timestamps { + if !yield(LogEvent{Timestamp: ptr.Int64(ts)}, nil) { + return + } + } + } +} + +func collect(seq iter.Seq2[LogEvent, error]) ([]int64, error) { + var timestamps []int64 + for evt, err := range seq { + if err != nil { + return timestamps, err + } + timestamps = append(timestamps, *evt.Timestamp) + } + 
return timestamps, nil +} + +func TestMergeLogEvents(t *testing.T) { + tests := []struct { + name string + left iter.Seq2[LogEvent, error] + right iter.Seq2[LogEvent, error] + expected []int64 + }{ + { + name: "both empty", + left: logEvents(), + right: logEvents(), + expected: nil, + }, + { + name: "left nil", + left: nil, + right: logEvents(1, 3, 5), + expected: []int64{1, 3, 5}, + }, + { + name: "right nil", + left: logEvents(2, 4, 6), + right: nil, + expected: []int64{2, 4, 6}, + }, + { + name: "both nil", + left: nil, + right: nil, + expected: nil, + }, + { + name: "interleaved", + left: logEvents(1, 3, 5), + right: logEvents(2, 4, 6), + expected: []int64{1, 2, 3, 4, 5, 6}, + }, + { + name: "left before right", + left: logEvents(1, 2, 3), + right: logEvents(4, 5, 6), + expected: []int64{1, 2, 3, 4, 5, 6}, + }, + { + name: "right before left", + left: logEvents(4, 5, 6), + right: logEvents(1, 2, 3), + expected: []int64{1, 2, 3, 4, 5, 6}, + }, + { + name: "equal timestamps", + left: logEvents(1, 2, 3), + right: logEvents(1, 2, 3), + expected: []int64{1, 1, 2, 2, 3, 3}, + }, + { + name: "left empty", + left: logEvents(), + right: logEvents(1, 2, 3), + expected: []int64{1, 2, 3}, + }, + { + name: "right empty", + left: logEvents(1, 2, 3), + right: logEvents(), + expected: []int64{1, 2, 3}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + merged := MergeLogEvents(tt.left, tt.right) + if merged == nil { + assert.Nil(t, tt.expected) + return + } + got, err := collect(merged) + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestMergeLogEvents_Error(t *testing.T) { + testErr := errors.New("test error") + errSeq := func(yield func(LogEvent, error) bool) { + yield(LogEvent{}, testErr) + } + + t.Run("left error", func(t *testing.T) { + merged := MergeLogEvents(errSeq, logEvents(1, 2)) + _, err := collect(merged) + assert.Equal(t, testErr, err) + }) + + t.Run("right error", func(t *testing.T) { + merged := 
MergeLogEvents(logEvents(1, 2), errSeq) + _, err := collect(merged) + assert.Equal(t, testErr, err) + }) +} + +func TestMergeLogEvents_EarlyStop(t *testing.T) { + merged := MergeLogEvents(logEvents(1, 3, 5, 7, 9), logEvents(2, 4, 6, 8, 10)) + got := TakeFirstN(merged, 4) + ts, err := collect(got) + require.NoError(t, err) + assert.Equal(t, []int64{1, 2, 3, 4}, ts) +} + +func TestTakeFirstN(t *testing.T) { + testTakeN(t, TakeFirstN, []struct { + name string + input []int64 + n int + expected []int64 + }{ + {"take 3 of 5", []int64{1, 2, 3, 4, 5}, 3, []int64{1, 2, 3}}, + {"take 5 of 3", []int64{1, 2, 3}, 5, []int64{1, 2, 3}}, + {"take 0", []int64{1, 2, 3}, 0, []int64{1, 2, 3}}, + {"take negative", []int64{1, 2, 3}, -1, []int64{1, 2, 3}}, + {"empty input", nil, 3, nil}, + }) +} + +func TestTakeLastN(t *testing.T) { + testTakeN(t, TakeLastN, []struct { + name string + input []int64 + n int + expected []int64 + }{ + {"last 3 of 5", []int64{1, 2, 3, 4, 5}, 3, []int64{3, 4, 5}}, + {"last 5 of 3", []int64{1, 2, 3}, 5, []int64{1, 2, 3}}, + {"last 0", []int64{1, 2, 3}, 0, []int64{1, 2, 3}}, + {"last negative", []int64{1, 2, 3}, -1, []int64{1, 2, 3}}, + {"empty input", nil, 3, nil}, + }) +} + +func testTakeN(t *testing.T, takeFn func(iter.Seq2[LogEvent, error], int) iter.Seq2[LogEvent, error], tests []struct { + name string + input []int64 + n int + expected []int64 +}) { + t.Helper() + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := collect(takeFn(logEvents(tt.input...), tt.n)) + require.NoError(t, err) + assert.Equal(t, tt.expected, got) + }) + } +} + +func TestTakeLastN_Error(t *testing.T) { + testErr := errors.New("test error") + seq := func(yield func(LogEvent, error) bool) { + if yield(LogEvent{Timestamp: ptr.Int64(1)}, nil) { + yield(LogEvent{}, testErr) + } + } + _, err := collect(TakeLastN(seq, 5)) + assert.Equal(t, testErr, err) +} diff --git a/src/pkg/clouds/aws/cw/stream.go b/src/pkg/clouds/aws/cw/stream.go index bc279b217..e6774f3c8 
100644 --- a/src/pkg/clouds/aws/cw/stream.go +++ b/src/pkg/clouds/aws/cw/stream.go @@ -3,26 +3,17 @@ package cw import ( "context" "errors" + "iter" + "sync" "time" "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs/types" ) -// QueryAndTailLogGroup queries the log group from the give start time and initiates a Live Tail session. +// QueryAndTailLogGroup queries the log group from the given start time and initiates a Live Tail session. // This function also handles the case where the log group does not exist yet. -// The caller should call `Close()` on the returned EventStream when done. -func QueryAndTailLogGroup(ctx context.Context, cw LogsClient, lgi LogGroupInput, start, end time.Time) (LiveTailStream, error) { - ctx, cancel := context.WithCancel(ctx) - - es := &eventStream{ - cancel: cancel, - ch: make(chan types.StartLiveTailResponseStream), - } - - var tailStream LiveTailStream - // First call TailLogGroup once to check if the log group exists or we have another error - var err error - tailStream, err = TailLogGroup(ctx, cw, lgi) +func QueryAndTailLogGroup(ctx context.Context, cwClient LogsClient, lgi LogGroupInput, start, end time.Time) (iter.Seq2[[]LogEvent, error], error) { + tailSeq, err := TailLogGroup(ctx, cwClient, lgi) if err != nil { var resourceNotFound *types.ResourceNotFoundException if !errors.As(err, &resourceNotFound) { @@ -31,46 +22,47 @@ func QueryAndTailLogGroup(ctx context.Context, cw LogsClient, lgi LogGroupInput, // Doesn't exist yet, continue to poll for it } - // Start goroutine to wait for the log group to be created and then tail it - go func() { - defer close(es.ch) - + return func(yield func([]LogEvent, error) bool) { // If the log group does not exist yet, poll until it does - if tailStream == nil { + if tailSeq == nil { var err error - tailStream, err = pollTailLogGroup(ctx, cw, lgi) + tailSeq, err = pollTailLogGroup(ctx, cwClient, lgi) if err != nil { - es.err = err + yield(nil, err) return } } - defer tailStream.Close() + // 
Live tail started. Query historical logs if !start.IsZero() { if end.IsZero() { end = time.Now() } - // Query the logs between the start time and now; TODO: could use a single CloudWatch client for all queries in same region - if err := QueryLogGroup(ctx, cw, lgi, start, end, 0, func(events []LogEvent) error { - es.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{ - Value: types.LiveTailSessionUpdate{SessionResults: events}, + querySeq, err := QueryLogGroup(ctx, cwClient, lgi, start, end, 0) + if err != nil { + if !yield(nil, err) { + return + } + } else { + for events, err := range querySeq { + if !yield(events, err) { + return + } } - return nil - }); err != nil { - es.err = err - return // the caller will likely cancel the context } } - // Pipe the events from the tail stream to the internal channel - es.err = es.pipeEvents(ctx, tailStream) - }() - - return es, nil + // Tail live logs + for events, err := range tailSeq { + if !yield(events, err) { + return + } + } + }, nil } // pollTailLogGroup polls the log group and starts the Live Tail session once it's available -func pollTailLogGroup(ctx context.Context, cw StartLiveTailAPI, lgi LogGroupInput) (LiveTailStream, error) { +func pollTailLogGroup(ctx context.Context, cw StartLiveTailAPI, lgi LogGroupInput) (iter.Seq2[[]LogEvent, error], error) { ticker := time.NewTicker(time.Second) defer ticker.Stop() @@ -80,82 +72,69 @@ func pollTailLogGroup(ctx context.Context, cw StartLiveTailAPI, lgi LogGroupInpu case <-ctx.Done(): return nil, ctx.Err() case <-ticker.C: - eventStream, err := TailLogGroup(ctx, cw, lgi) + logIter, err := TailLogGroup(ctx, cw, lgi) if errors.As(err, &resourceNotFound) { continue // keep trying } - return eventStream, err + return logIter, err } } } -// eventStream is an bare implementation of the EventStream interface. 
-type eventStream struct { - cancel context.CancelFunc - ch chan types.StartLiveTailResponseStream - err error -} - -var _ LiveTailStream = (*eventStream)(nil) - -func (es *eventStream) Close() error { - es.cancel() - return nil -} - -func (es *eventStream) Err() error { - return es.err -} - -func (es *eventStream) Events() <-chan types.StartLiveTailResponseStream { - return es.ch -} +// QueryAndTailLogGroups queries and tails multiple log groups concurrently. +// Events from different groups are interleaved (not merge-sorted). +func QueryAndTailLogGroups(ctx context.Context, cwClient LogsClient, start, end time.Time, lgis ...LogGroupInput) (iter.Seq2[[]LogEvent, error], error) { + ctx, cancel := context.WithCancel(ctx) -// pipeEvents copies events from the given EventStream to the internal channel, -// until the context is canceled or an error occurs in the given EventStream. -func (es *eventStream) pipeEvents(ctx context.Context, tailStream LiveTailStream) error { - for { - // Double select to make sure context cancellation is not blocked by either the receive or send - // See: https://stackoverflow.com/questions/60030756/what-does-it-mean-when-one-channel-uses-two-arrows-to-write-to-another-channel - select { - case event := <-tailStream.Events(): // blocking - if err := tailStream.Err(); err != nil { - return err - } - if event == nil { - return nil - } - select { - case es.ch <- event: - case <-ctx.Done(): - return ctx.Err() - } - case <-ctx.Done(): // blocking - return ctx.Err() + type result struct { + events []LogEvent + err error + } + ch := make(chan result) + + var wg sync.WaitGroup + var lastErr error + for _, lgi := range lgis { + logSeq, err := QueryAndTailLogGroup(ctx, cwClient, lgi, start, end) + if err != nil { + lastErr = err + continue } + wg.Add(1) + go func() { + defer wg.Done() + for events, err := range logSeq { + select { + case ch <- result{events, err}: + case <-ctx.Done(): + return + } + if err != nil { + return + } + } + }() } -} -func 
newEventStream(cancel func()) *eventStream { - return &eventStream{ - cancel: cancel, - ch: make(chan types.StartLiveTailResponseStream), - } -} + go func() { + wg.Wait() + close(ch) + }() -func NewStaticLogStream(ch <-chan LogEvent, cancel func()) EventStream[types.StartLiveTailResponseStream] { - es := newEventStream(cancel) + if lastErr != nil { + cancel() + return nil, lastErr + } - go func() { - defer close(es.ch) - for evt := range ch { - es.ch <- &types.StartLiveTailResponseStreamMemberSessionUpdate{ - Value: types.LiveTailSessionUpdate{ - SessionResults: []types.LiveTailSessionLogEvent{evt}, - }, + return func(yield func([]LogEvent, error) bool) { + defer cancel() + for r := range ch { + if !yield(r.events, r.err) { + return + } + if r.err != nil { + return } } - }() - - return es + }, nil } diff --git a/src/pkg/clouds/aws/cw/stream_test.go b/src/pkg/clouds/aws/cw/stream_test.go index a699da413..8b297d00b 100644 --- a/src/pkg/clouds/aws/cw/stream_test.go +++ b/src/pkg/clouds/aws/cw/stream_test.go @@ -4,12 +4,13 @@ package cw import ( "context" + "fmt" "testing" "time" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/cloudwatchlogs" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestPendingStream(t *testing.T) { @@ -23,26 +24,22 @@ func TestPendingStream(t *testing.T) { t.Skipf("Failed to load AWS config: %v", err) } + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + cw := cloudwatchlogs.NewFromConfig(cfg) - ps, err := QueryAndTailLogGroup(context.Background(), cw, LogGroupInput{ + evts, err := QueryAndTailLogGroup(ctx, cw, LogGroupInput{ LogGroupARN: "arn:aws:logs:us-west-2:532501343364:log-group:/ecs/lio/logss:*", }, time.Now().Add(-time.Minute), time.Time{}) - assert.NoError(t, err) - - go func() { - time.Sleep(5 * time.Second) - ps.Close() - }() + require.NoError(t, err) - if ps.Err() != nil { - t.Errorf("Error: %v", ps.Err()) - } - - for e := range 
ps.Events() { - if e == nil { - t.Errorf("Error: %v", ps.Err()) + for evt, err := range evts { + if err != nil { + t.Logf("Stream ended: %v", err) + break + } + for _, evt := range evt { + fmt.Println(*evt.Message) } - println(e) } - t.Error(ps.Err()) } diff --git a/src/pkg/clouds/aws/ecs/status.go b/src/pkg/clouds/aws/ecs/status.go index c011a2713..7d8f3b0e6 100644 --- a/src/pkg/clouds/aws/ecs/status.go +++ b/src/pkg/clouds/aws/ecs/status.go @@ -14,7 +14,7 @@ import ( ) // GetTaskStatus returns nil if the task is still running, io.EOF if the task is stopped successfully, or an error if the task failed. -func GetTaskStatus(ctx context.Context, taskArn TaskArn) error { +func GetTaskStatus(ctx context.Context, taskArn TaskArn) (bool, error) { region := region.FromArn(*taskArn) cluster, taskID := SplitClusterTask(taskArn) return getTaskStatus(ctx, region, cluster, taskID) @@ -31,10 +31,10 @@ func isTaskTerminalStatus(status string) bool { } // getTaskStatus returns nil if the task is still running, io.EOF if the task is stopped successfully, or an error if the task failed. 
-func getTaskStatus(ctx context.Context, region aws.Region, cluster, taskId string) error { +func getTaskStatus(ctx context.Context, region aws.Region, cluster, taskId string) (bool, error) { cfg, err := aws.LoadDefaultConfig(ctx, region) if err != nil { - return err + return false, err } ecsClient := ecs.NewFromConfig(cfg) @@ -44,26 +44,33 @@ func getTaskStatus(ctx context.Context, region aws.Region, cluster, taskId strin Tasks: []string{taskId}, }) if ti == nil || len(ti.Tasks) == 0 { - return nil // task doesn't exist yet; TODO: check the actual error from DescribeTasks + return false, nil // task doesn't exist (yet); TODO: check the actual error from DescribeTasks } task := ti.Tasks[0] if task.LastStatus == nil || !isTaskTerminalStatus(*task.LastStatus) { - return nil // still running + return false, nil // still running } + var stoppedReason string + if task.StoppedReason != nil { + stoppedReason = *task.StoppedReason + } switch task.StopCode { default: - return TaskFailure{task.StopCode, *task.StoppedReason} + return true, TaskFailure{task.StopCode, stoppedReason} case ecsTypes.TaskStopCodeEssentialContainerExited: for _, c := range task.Containers { if c.ExitCode != nil && *c.ExitCode != 0 { - reason := fmt.Sprintf("%s with code %d", *task.StoppedReason, *c.ExitCode) - return TaskFailure{task.StopCode, reason} + if stoppedReason == "" { + stoppedReason = "essential container exited" + } + reason := fmt.Sprintf("%s with code %d", stoppedReason, *c.ExitCode) + return true, TaskFailure{task.StopCode, reason} } } fallthrough case "": // TODO: shouldn't happen - return io.EOF // Success + return true, io.EOF // Success; EOF returned for backward compatibility } } @@ -91,7 +98,7 @@ func WaitForTask(ctx context.Context, taskArn TaskArn, poll time.Duration) error // Handle cancellation return ctx.Err() case <-ticker.C: - if err := GetTaskStatus(ctx, taskArn); err != nil { + if done, err := GetTaskStatus(ctx, taskArn); done || err != nil { return err } } diff --git 
a/src/pkg/clouds/aws/ecs/tail.go b/src/pkg/clouds/aws/ecs/tail.go index 869745d5d..de98b454f 100644 --- a/src/pkg/clouds/aws/ecs/tail.go +++ b/src/pkg/clouds/aws/ecs/tail.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "iter" "path" "time" @@ -22,36 +23,33 @@ func (a *AwsEcs) Tail(ctx context.Context, taskArn TaskArn) error { } taskId := GetTaskID(taskArn) a.Region = region.FromArn(*taskArn) - es, err := a.TailTaskID(ctx, cwClient, taskId) + + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + tailIter, err := a.TailTaskID(ctx, cwClient, taskId) if err != nil { return err } - ctx, cancel := context.WithCancel(ctx) - defer cancel() - taskch := make(chan error) - defer close(taskch) + taskch := make(chan error, 1) go func() { taskch <- WaitForTask(ctx, taskArn, time.Second*3) + cancel() // stop tailing when task finishes }() - for { - select { - case e := <-es.Events(): // blocking - events, err := cw.GetLogEvents(e) - // Print before checking for errors, so we don't lose any logs in case of EOF - for _, event := range events { - fmt.Println(*event.Message) - } - if err != nil { + for batch, err := range tailIter { + if err != nil { + if !errors.Is(err, context.Canceled) { return err } - case <-ctx.Done(): - return ctx.Err() - case err := <-taskch: - return err + break + } + for _, evt := range batch { + fmt.Println(*evt.Message) } } + return <-taskch } func (a *AwsEcs) GetTaskArn(taskID string) (TaskArn, error) { @@ -65,16 +63,20 @@ func (a *AwsEcs) GetTaskArn(taskID string) (TaskArn, error) { return &taskArn, nil } -func (a *AwsEcs) QueryTaskID(ctx context.Context, cwClient cw.FilterLogEventsAPI, taskID string, start, end time.Time, limit int32) (cw.LiveTailStream, error) { +func (a *AwsEcs) QueryTaskID(ctx context.Context, cwClient cw.FilterLogEventsAPIClient, taskID string, start, end time.Time, limit int32) (iter.Seq2[[]cw.LogEvent, error], error) { if taskID == "" { return nil, errors.New("taskID is empty") } lgi := cw.LogGroupInput{LogGroupARN: 
a.LogGroupARN, LogStreamNames: []string{GetCDLogStreamForTaskID(taskID)}} - return cw.QueryLogGroupStream(ctx, cwClient, lgi, start, end, limit) + logSeq, err := cw.QueryLogGroup(ctx, cwClient, lgi, start, end, limit) + if err != nil { + return nil, err + } + return logSeq, nil } -func (a *AwsEcs) TailTaskID(ctx context.Context, cwClient cw.StartLiveTailAPI, taskID string) (cw.LiveTailStream, error) { +func (a *AwsEcs) TailTaskID(ctx context.Context, cwClient cw.StartLiveTailAPI, taskID string) (iter.Seq2[[]cw.LogEvent, error], error) { if taskID == "" { return nil, errors.New("taskID is required") } @@ -87,16 +89,16 @@ func (a *AwsEcs) TailTaskID(ctx context.Context, cwClient cw.StartLiveTailAPI, t lgi := cw.LogGroupInput{LogGroupARN: a.LogGroupARN, LogStreamNames: []string{GetCDLogStreamForTaskID(taskID)}} for { - stream, err := cw.TailLogGroup(ctx, cwClient, lgi) + logSeq, err := cw.TailLogGroup(ctx, cwClient, lgi) if err != nil { var resourceNotFound *cwTypes.ResourceNotFoundException if !errors.As(err, &resourceNotFound) { return nil, err } // The log stream doesn't exist yet, so wait for it to be created, but bail out if the task is stopped - err := getTaskStatus(ctx, a.Region, a.ClusterName, taskID) - if err != nil { - return nil, err + done, err := getTaskStatus(ctx, a.Region, a.ClusterName, taskID) + if done || err != nil { + return nil, err // TODO: handle transient errors } // continue loop, waiting for the log stream to be created; sleep to avoid throttling if err := pkg.SleepWithContext(ctx, time.Second); err != nil { @@ -104,8 +106,8 @@ func (a *AwsEcs) TailTaskID(ctx context.Context, cwClient cw.StartLiveTailAPI, t } continue } - // TODO: should wrap this stream so we can return io.EOF on task stop - return stream, nil + // TODO: should wrap this iter so we can return io.EOF on task stop + return logSeq, nil } } diff --git a/src/pkg/clouds/gcp/cloudbuild.go b/src/pkg/clouds/gcp/cloudbuild.go index 60ce9bd97..b2ccf026c 100644 --- 
a/src/pkg/clouds/gcp/cloudbuild.go +++ b/src/pkg/clouds/gcp/cloudbuild.go @@ -184,25 +184,25 @@ func (gcp Gcp) RunCloudBuild(ctx context.Context, args CloudBuildArgs) (string, return op.Name(), nil } -func (gcp Gcp) GetBuildStatus(ctx context.Context, startBuildOpName string) error { +func (gcp Gcp) GetBuildStatus(ctx context.Context, startBuildOpName string) (bool, error) { svc, err := cloudbuild.NewClient(ctx) if err != nil { - return fmt.Errorf("failed to create Cloud Build client: %w", err) + return false, fmt.Errorf("failed to create Cloud Build client: %w", err) } defer svc.Close() op := svc.CreateBuildOperation(startBuildOpName) build, err := op.Poll(ctx) if err != nil { - return fmt.Errorf("failed to poll build operation: %w", err) + return false, fmt.Errorf("failed to poll build operation: %w", err) } if build != nil { if build.Status == cloudbuildpb.Build_SUCCESS { - return io.EOF + return true, io.EOF // success; EOF is returned for backward compatibility } - return client.ErrDeploymentFailed{Message: fmt.Sprintf("build failed with status: %v", build.Status)} + return true, client.ErrDeploymentFailed{Message: fmt.Sprintf("build failed with status: %v", build.Status)} } - return nil + return false, nil } func GetMachineType(machineType *string) MachineType { diff --git a/src/pkg/track/track.go b/src/pkg/track/track.go index 3da23b161..ce0aad35e 100644 --- a/src/pkg/track/track.go +++ b/src/pkg/track/track.go @@ -84,5 +84,6 @@ func Cmd(cmd *cobra.Command, verb string, props ...Property) { props = append(props, P(f.Name, f.Value)) }) } + // This was supposed to be strings.Title but that got deprecated, so now we're stuck with strings.ToTitle which makes everything uppercase. Oh well. Evt(strings.ToTitle(command+" "+verb), props...) 
} From ef0dbaa2ba851d85261d68bc5a5e4e3a16fad90a Mon Sep 17 00:00:00 2001 From: Edward J Date: Thu, 19 Feb 2026 12:31:48 -0800 Subject: [PATCH 2/2] Do not query older subscribe events --- src/pkg/cli/client/byoc/gcp/byoc.go | 5 ++++- src/pkg/cli/client/byoc/gcp/stream.go | 3 ++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/pkg/cli/client/byoc/gcp/byoc.go b/src/pkg/cli/client/byoc/gcp/byoc.go index f82f7071e..3b9dfb3f0 100644 --- a/src/pkg/cli/client/byoc/gcp/byoc.go +++ b/src/pkg/cli/client/byoc/gcp/byoc.go @@ -582,7 +582,10 @@ func (b *ByocGcp) Subscribe(ctx context.Context, req *defangv1.SubscribeRequest) subscribeStream := NewSubscribeStream(ctx, b.driver, true, req.Etag, req.Services, ignoreCdSuccess) subscribeStream.AddJobStatusUpdate(b.PulumiStack, req.Project, req.Etag, req.Services) subscribeStream.AddServiceStatusUpdate(b.PulumiStack, req.Project, req.Etag, req.Services) - return subscribeStream.Follow(time.Now()) + + now := time.Now() + subscribeStream.query.AddSince(now) // Do not query historical events + return subscribeStream.Follow(now) } func (b *ByocGcp) QueryLogs(ctx context.Context, req *defangv1.TailRequest) (iter.Seq2[*defangv1.TailResponse, error], error) { diff --git a/src/pkg/cli/client/byoc/gcp/stream.go b/src/pkg/cli/client/byoc/gcp/stream.go index 84970d1a9..e9b1e1abf 100644 --- a/src/pkg/cli/client/byoc/gcp/stream.go +++ b/src/pkg/cli/client/byoc/gcp/stream.go @@ -69,11 +69,12 @@ func (s *ServerStream[T]) Follow(start time.Time) (iter.Seq2[*T, error], error) return nil, err } query := s.query.GetQuery() + shouldList := !start.IsZero() && start.Unix() > 0 && time.Since(start) > 10*time.Millisecond term.Debugf("Query and tail logs since %v with query: \n%v", start, query) return func(yield func(*T, error) bool) { defer tailer.Close() // Only query older logs if start time is more than 10ms ago - if !start.IsZero() && start.Unix() > 0 && time.Since(start) > 10*time.Millisecond { + if shouldList { lister, err := 
s.gcpLogsClient.ListLogEntries(s.ctx, query, gcp.OrderAscending) if err != nil { yield(nil, err)