diff --git a/cmd/vulcan-aws-trusted-advisor/main.go b/cmd/vulcan-aws-trusted-advisor/main.go index f82ce4161..6a05f9c02 100644 --- a/cmd/vulcan-aws-trusted-advisor/main.go +++ b/cmd/vulcan-aws-trusted-advisor/main.go @@ -17,6 +17,9 @@ import ( "github.com/adevinta/vulcan-check-sdk/helpers/awshelpers" checkstate "github.com/adevinta/vulcan-check-sdk/state" report "github.com/adevinta/vulcan-report" + awsRetry "github.com/aws/aws-sdk-go-v2/aws/retry" + supporttypes "github.com/aws/aws-sdk-go-v2/service/support/types" + "golang.org/x/time/rate" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/support" @@ -111,7 +114,9 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log return checkstate.ErrAssetUnreachable } - s := support.NewFromConfig(cfg) + s := support.NewFromConfig(cfg, func(o *support.Options) { + o.Retryer = awsRetry.AddWithMaxAttempts(o.Retryer, 5) + }) // Retrieve checks list checks, err := s.DescribeTrustedAdvisorChecks( ctx, @@ -123,76 +128,18 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log } // Refresh checks - checkIds := []*string{} - enqueued := 0 - for _, check := range checks.Checks { - // Ignore results if we can't know the category - if check.Category == nil { - continue - } - - // Ignore results that does are not security - if *check.Category != "security" { - continue - } - checkIds = append(checkIds, check.Id) - refreshed, err := s.RefreshTrustedAdvisorCheck(ctx, &support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id}) - if err != nil { - // Haven't found a more elegant way to check for an - // InvalidParameterValueException. This error type is not defined in the - // support/types package as it is for other services. - if strings.Contains(err.Error(), "InvalidParameterValueException") { - logger.Printf("check '%s' is not refreshable\n", *check.Name) - continue - } - return err - } - - logger.Printf("check '%s' is refreshed with status: '%s'\n", *check.Name, *refreshed.Status.Status) - if *refreshed.Status.Status == "enqueued" { - enqueued++ - } + limiter := rate.NewLimiter(rate.Every(100*time.Millisecond), 1) + toPoll, err := refreshSecurityChecks(ctx, s, checks.Checks, limiter, logger) + if err != nil { + return err } - // If some check was enqueued for refreshing - // poll it's status and wait up until opt.RefreshTimeout - if enqueued > 0 { - t := time.NewTicker(time.Duration(opt.RefreshTimeout) * time.Second) - defer t.Stop() + // Poll refresh statuses with a timeout + if len(toPoll) > 0 { + ctxTimeout, cancel := context.WithTimeout(ctx, time.Duration(opt.RefreshTimeout)*time.Second) + defer cancel() - LOOP: - for { - select { - case <-t.C: - break LOOP - default: - checkStatus, err := s.DescribeTrustedAdvisorCheckRefreshStatuses( - ctx, - &support.DescribeTrustedAdvisorCheckRefreshStatusesInput{ - CheckIds: checkIds, - }, - ) - // Haven't found a more elegant way to check for an - // InvalidParameterValueException. This error type is not - // defined in the support/types package as it is for other - // services. - if err != nil && !strings.Contains(err.Error(), "InvalidParameterValueException") { - return fmt.Errorf("unable to check the refresh statuses: %w", err) - } - var pending bool - for _, cs := range checkStatus.Statuses { - if *cs.Status == "enqueued" || *cs.Status == "processing" { - pending = true - break - } - } - if !pending { - break LOOP - } - logger.Infof("Waiting for checks to be refreshed. Sleeping for %v...", rfrshInterval) - time.Sleep(rfrshInterval) - } - } + pollRefreshStatuses(ctxTimeout, s, toPoll, rfrshInterval, logger) } // Retrieve checks summaries @@ -363,3 +310,83 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log } return err } + +func refreshSecurityChecks(ctx context.Context, svc *support.Client, checks []supporttypes.TrustedAdvisorCheckDescription, limiter *rate.Limiter, logger *logrus.Entry) ([]*string, error) { + var enqueuedIDs []*string + + for _, chk := range checks { + if chk.Category == nil || *chk.Category != "security" { + continue + } + + if err := limiter.Wait(ctx); err != nil { + return nil, fmt.Errorf("rate limiter interrupted: %w", err) + } + + out, err := svc.RefreshTrustedAdvisorCheck(ctx, &support.RefreshTrustedAdvisorCheckInput{ + CheckId: chk.Id, + }) + if err != nil { + if strings.Contains(err.Error(), "InvalidParameterValueException") { + logger.Warnf("check %q is not refreshable, ignoring", *chk.Name) + continue + } + return nil, fmt.Errorf("refresh %s: %w", *chk.Id, err) + } + status := aws.ToString(out.Status.Status) + logger.Infof("refresh of %q check with status %s", *chk.Name, status) + if status == "enqueued" { + enqueuedIDs = append(enqueuedIDs, chk.Id) + } + } + + return enqueuedIDs, nil +} + +func pollRefreshStatuses(ctx context.Context, svc *support.Client, ids []*string, maxRefreshWaitInterval time.Duration, logger *logrus.Entry) { + for { + select { + case <-ctx.Done(): + logger.Warnf("maxRefreshWaitInterval reached, stop polling") + return + default: + out, err := svc.DescribeTrustedAdvisorCheckRefreshStatuses(ctx, &support.DescribeTrustedAdvisorCheckRefreshStatusesInput{ + CheckIds: ids, + }) + if err != nil { + logger.Errorf("DescribeTrustedAdvisorCheckRefreshStatuses failed: %v", err) + return + } + + var maxSleep time.Duration + var pending bool + + for _, st := range out.Statuses { + s := aws.ToString(st.Status) + if s == "enqueued" || s == "processing" { + pending = true + + if st.MillisUntilNextRefreshable != 0 { + d := time.Duration(st.MillisUntilNextRefreshable) * time.Millisecond + if d > maxSleep { + maxSleep = d + } + } + } + } + + if !pending { + return + } + + if maxSleep <= 0 { + maxSleep = maxRefreshWaitInterval + } + logger.Infof("waiting %s until next check", maxSleep) + select { + case <-time.After(maxSleep): + case <-ctx.Done(): + } + } + } +} diff --git a/cmd/vulcan-aws-trusted-advisor/manifest.toml b/cmd/vulcan-aws-trusted-advisor/manifest.toml index bf9b1d9f0..7445f0988 100644 --- a/cmd/vulcan-aws-trusted-advisor/manifest.toml +++ b/cmd/vulcan-aws-trusted-advisor/manifest.toml @@ -1,4 +1,5 @@ Description = "Runs an AWS Trusted Advisor check against an AWS account" AssetTypes = ["AWSAccount"] +Timeout = 900 # 15 minutes. RequiredVars = ["VULCAN_ASSUME_ROLE_ENDPOINT", "ROLE_NAME"] -Options = '{"refresh_timeout": 60}' +Options = '{"refresh_timeout": 600}' diff --git a/go.mod b/go.mod index 76b927bbf..e47cbd4f4 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/zaproxy/zap-api-go v0.0.0-20231219145106-e9ebb9695484 golang.org/x/net v0.41.0 golang.org/x/text v0.26.0 + golang.org/x/time v0.12.0 ) require ( diff --git a/go.sum b/go.sum index 785d5298f..ce2ac77c9 100644 --- a/go.sum +++ b/go.sum @@ -215,6 +215,8 @@ golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M= golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=