Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 95 additions & 68 deletions cmd/vulcan-aws-trusted-advisor/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import (
"github.com/adevinta/vulcan-check-sdk/helpers/awshelpers"
checkstate "github.com/adevinta/vulcan-check-sdk/state"
report "github.com/adevinta/vulcan-report"
awsRetry "github.com/aws/aws-sdk-go-v2/aws/retry"
supporttypes "github.com/aws/aws-sdk-go-v2/service/support/types"
"golang.org/x/time/rate"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/support"
Expand Down Expand Up @@ -111,7 +114,9 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log
return checkstate.ErrAssetUnreachable
}

s := support.NewFromConfig(cfg)
s := support.NewFromConfig(cfg, func(o *support.Options) {
o.Retryer = awsRetry.AddWithMaxAttempts(o.Retryer, 5)
})
// Retrieve checks list
checks, err := s.DescribeTrustedAdvisorChecks(
ctx,
Expand All @@ -123,76 +128,18 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log
}

// Refresh checks
checkIds := []*string{}
enqueued := 0
for _, check := range checks.Checks {
// Ignore results if we can't know the category
if check.Category == nil {
continue
}

// Ignore results that does are not security
if *check.Category != "security" {
continue
}
checkIds = append(checkIds, check.Id)
refreshed, err := s.RefreshTrustedAdvisorCheck(ctx, &support.RefreshTrustedAdvisorCheckInput{CheckId: check.Id})
if err != nil {
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not defined in the
// support/types package as it is for other services.
if strings.Contains(err.Error(), "InvalidParameterValueException") {
logger.Printf("check '%s' is not refreshable\n", *check.Name)
continue
}
return err
}

logger.Printf("check '%s' is refreshed with status: '%s'\n", *check.Name, *refreshed.Status.Status)
if *refreshed.Status.Status == "enqueued" {
enqueued++
}
limiter := rate.NewLimiter(rate.Every(100*time.Millisecond), 1)
toPoll, err := refreshSecurityChecks(ctx, s, checks.Checks, limiter, logger)
if err != nil {
return err
}

// If some check was enqueued for refreshing
// poll it's status and wait up until opt.RefreshTimeout
if enqueued > 0 {
t := time.NewTicker(time.Duration(opt.RefreshTimeout) * time.Second)
defer t.Stop()
// Poll refresh statuses with a timeout
if len(toPoll) > 0 {
ctxTimeout, cancel := context.WithTimeout(ctx, time.Duration(opt.RefreshTimeout)*time.Second)
defer cancel()

LOOP:
for {
select {
case <-t.C:
break LOOP
default:
checkStatus, err := s.DescribeTrustedAdvisorCheckRefreshStatuses(
ctx,
&support.DescribeTrustedAdvisorCheckRefreshStatusesInput{
CheckIds: checkIds,
},
)
// Haven't found a more elegant way to check for an
// InvalidParameterValueException. This error type is not
// defined in the support/types package as it is for other
// services.
if err != nil && !strings.Contains(err.Error(), "InvalidParameterValueException") {
return fmt.Errorf("unable to check the refresh statuses: %w", err)
}
var pending bool
for _, cs := range checkStatus.Statuses {
if *cs.Status == "enqueued" || *cs.Status == "processing" {
pending = true
break
}
}
if !pending {
break LOOP
}
logger.Infof("Waiting for checks to be refreshed. Sleeping for %v...", rfrshInterval)
time.Sleep(rfrshInterval)
}
}
pollRefreshStatuses(ctxTimeout, s, toPoll, rfrshInterval, logger)
}

// Retrieve checks summaries
Expand Down Expand Up @@ -363,3 +310,83 @@ func scanAccount(ctx context.Context, opt options, target, _ string, logger *log
}
return err
}

func refreshSecurityChecks(ctx context.Context, svc *support.Client, checks []supporttypes.TrustedAdvisorCheckDescription, limiter *rate.Limiter, logger *logrus.Entry) ([]*string, error) {
var enqueuedIDs []*string

for _, chk := range checks {
if chk.Category == nil || *chk.Category != "security" {
continue
}

if err := limiter.Wait(ctx); err != nil {
return nil, fmt.Errorf("rate limiter interrupted: %w", err)
}

out, err := svc.RefreshTrustedAdvisorCheck(ctx, &support.RefreshTrustedAdvisorCheckInput{
CheckId: chk.Id,
})
if err != nil {
if strings.Contains(err.Error(), "InvalidParameterValueException") {
logger.Warnf("check %q is not refreshable, ignoring", *chk.Name)
continue
}
return nil, fmt.Errorf("refresh %s: %w", *chk.Id, err)
}
status := aws.ToString(out.Status.Status)
logger.Infof("refresh of %q check with status %s", *chk.Name, status)
if status == "enqueued" {
enqueuedIDs = append(enqueuedIDs, chk.Id)
}
}

return enqueuedIDs, nil
}

func pollRefreshStatuses(ctx context.Context, svc *support.Client, ids []*string, maxRefreshWaitInterval time.Duration, logger *logrus.Entry) {
for {
select {
case <-ctx.Done():
logger.Warnf("maxRefreshWaitInterval reached, stop polling")
return
default:
out, err := svc.DescribeTrustedAdvisorCheckRefreshStatuses(ctx, &support.DescribeTrustedAdvisorCheckRefreshStatusesInput{
CheckIds: ids,
})
if err != nil {
logger.Errorf("DescribeTrustedAdvisorCheckRefreshStatuses failed: %v", err)
return
}

var maxSleep time.Duration
var pending bool

for _, st := range out.Statuses {
s := aws.ToString(st.Status)
if s == "enqueued" || s == "processing" {
pending = true

if st.MillisUntilNextRefreshable != 0 {
d := time.Duration(st.MillisUntilNextRefreshable) * time.Millisecond
if d > maxSleep {
maxSleep = d
}
}
}
}

if !pending {
return
}

if maxSleep <= 0 {
maxSleep = maxRefreshWaitInterval
}
logger.Infof("waiting %s until next check", maxSleep)
select {
case <-time.After(maxSleep):
case <-ctx.Done():
}
}
}
}
3 changes: 2 additions & 1 deletion cmd/vulcan-aws-trusted-advisor/manifest.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Description = "Runs an AWS Trusted Advisor check against an AWS account"
AssetTypes = ["AWSAccount"]
Timeout = 900 # 15 minutes.
RequiredVars = ["VULCAN_ASSUME_ROLE_ENDPOINT", "ROLE_NAME"]
Options = '{"refresh_timeout": 60}'
Options = '{"refresh_timeout": 600}'
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ require (
github.com/zaproxy/zap-api-go v0.0.0-20231219145106-e9ebb9695484
golang.org/x/net v0.41.0
golang.org/x/text v0.26.0
golang.org/x/time v0.12.0
)

require (
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ golang.org/x/term v0.32.0/go.mod h1:uZG1FhGx848Sqfsq4/DlJr3xGGsYMu/L5GW4abiaEPQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.26.0 h1:P42AVeLghgTYr4+xUnTRKDMqpar+PtX7KWuNQL21L8M=
golang.org/x/text v0.26.0/go.mod h1:QK15LZJUUQVJxhz7wXgxSy/CJaTFjd0G+YLonydOVQA=
golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE=
golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc=
golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI=
Expand Down