From f7ca6ab251ea8a9347819c17704b67e166386e8c Mon Sep 17 00:00:00 2001 From: michaelhtm <98621731+michaelhtm@users.noreply.github.com> Date: Tue, 29 Jul 2025 14:52:12 -0700 Subject: [PATCH] feat: retrieve aws partition from callerIdentity --- apis/core/v1alpha1/common.go | 3 ++ apis/core/v1alpha1/resource_metadata.go | 2 + mocks/pkg/types/aws_resource_identifiers.go | 20 +++++++++ .../pkg/types/aws_resource_manager_factory.go | 38 +++++++++++++---- mocks/pkg/types/service_controller.go | 31 ++++++++------ pkg/runtime/adoption_reconciler.go | 7 +++- pkg/runtime/config.go | 33 ++++++++++++--- pkg/runtime/reconciler.go | 41 ++++++++++++++----- pkg/types/aws_resource_identifiers.go | 2 + pkg/types/aws_resource_manager.go | 8 ++++ pkg/types/service_controller.go | 3 +- 11 files changed, 148 insertions(+), 40 deletions(-) diff --git a/apis/core/v1alpha1/common.go b/apis/core/v1alpha1/common.go index b4519e4e..41acf8bf 100644 --- a/apis/core/v1alpha1/common.go +++ b/apis/core/v1alpha1/common.go @@ -16,6 +16,9 @@ package v1alpha1 // AWSRegion represents an AWS regional identifier type AWSRegion string +// AWSPartition represents an AWS partition identifier +type AWSPartition string + // AWSAccountID represents an AWS account identifier type AWSAccountID string diff --git a/apis/core/v1alpha1/resource_metadata.go b/apis/core/v1alpha1/resource_metadata.go index 7b55d5cb..751db567 100644 --- a/apis/core/v1alpha1/resource_metadata.go +++ b/apis/core/v1alpha1/resource_metadata.go @@ -32,4 +32,6 @@ type ResourceMetadata struct { OwnerAccountID *AWSAccountID `json:"ownerAccountID"` // Region is the AWS region in which the resource exists or will exist. Region *AWSRegion `json:"region"` + // Partition is the AWS partition in which the resource exists or will exist + Partition *AWSPartition `json:"partition"` } diff --git a/mocks/pkg/types/aws_resource_identifiers.go b/mocks/pkg/types/aws_resource_identifiers.go index e13cd46e..644b1882 100644 --- a/mocks/pkg/types/aws_resource_identifiers.go +++ b/mocks/pkg/types/aws_resource_identifiers.go @@ -53,6 +53,26 @@ func (_m *AWSResourceIdentifiers) OwnerAccountID() *v1alpha1.AWSAccountID { return r0 } +// Partition provides a mock function with no fields +func (_m *AWSResourceIdentifiers) Partition() *v1alpha1.AWSPartition { + ret := _m.Called() + + if len(ret) == 0 { + panic("no return value specified for Partition") + } + + var r0 *v1alpha1.AWSPartition + if rf, ok := ret.Get(0).(func() *v1alpha1.AWSPartition); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*v1alpha1.AWSPartition) + } + } + + return r0 +} + // Region provides a mock function with no fields func (_m *AWSResourceIdentifiers) Region() *v1alpha1.AWSRegion { ret := _m.Called() diff --git a/mocks/pkg/types/aws_resource_manager_factory.go b/mocks/pkg/types/aws_resource_manager_factory.go index 589f0782..92a32c9a 100644 --- a/mocks/pkg/types/aws_resource_manager_factory.go +++ b/mocks/pkg/types/aws_resource_manager_factory.go @@ -21,6 +21,26 @@ type AWSResourceManagerFactory struct { mock.Mock } +// GetCachedManager provides a mock function with given fields: _a0, _a1, _a2 +func (_m *AWSResourceManagerFactory) GetCachedManager(_a0 v1alpha1.AWSAccountID, _a1 v1alpha1.AWSRegion, _a2 v1alpha1.AWSResourceName) types.AWSResourceManager { + ret := _m.Called(_a0, _a1, _a2) + + if len(ret) == 0 { + panic("no return value specified for GetCachedManager") + } + + var r0 types.AWSResourceManager + if rf, ok := ret.Get(0).(func(v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSResourceName) types.AWSResourceManager); ok { + r0 = rf(_a0, _a1, _a2) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(types.AWSResourceManager) + } + } + + return r0 +} + // IsAdoptable provides a mock function with no fields func (_m *AWSResourceManagerFactory) IsAdoptable() bool { ret := _m.Called() @@ -39,9 +59,9 @@ func (_m *AWSResourceManagerFactory) IsAdoptable() bool { return r0 } -// ManagerFor provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7 -func (_m *AWSResourceManagerFactory) ManagerFor(_a0 config.Config, _a1 aws.Config, _a2 logr.Logger, _a3 *metrics.Metrics, _a4 types.Reconciler, _a5 v1alpha1.AWSAccountID, _a6 v1alpha1.AWSRegion, _a7 v1alpha1.AWSResourceName) (types.AWSResourceManager, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) +// ManagerFor provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8 +func (_m *AWSResourceManagerFactory) ManagerFor(_a0 config.Config, _a1 aws.Config, _a2 logr.Logger, _a3 *metrics.Metrics, _a4 types.Reconciler, _a5 v1alpha1.AWSAccountID, _a6 v1alpha1.AWSRegion, _a7 v1alpha1.AWSPartition, _a8 v1alpha1.AWSResourceName) (types.AWSResourceManager, error) { + ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) if len(ret) == 0 { panic("no return value specified for ManagerFor") @@ -49,19 +69,19 @@ func (_m *AWSResourceManagerFactory) ManagerFor(_a0 config.Config, _a1 aws.Confi var r0 types.AWSResourceManager var r1 error - if rf, ok := ret.Get(0).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSResourceName) (types.AWSResourceManager, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(0).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSPartition, v1alpha1.AWSResourceName) (types.AWSResourceManager, error)); ok { + return rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } - if rf, ok := ret.Get(0).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSResourceName) types.AWSResourceManager); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(0).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSPartition, v1alpha1.AWSResourceName) types.AWSResourceManager); ok { + r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(types.AWSResourceManager) } } - if rf, ok := ret.Get(1).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSResourceName) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7) + if rf, ok := ret.Get(1).(func(config.Config, aws.Config, logr.Logger, *metrics.Metrics, types.Reconciler, v1alpha1.AWSAccountID, v1alpha1.AWSRegion, v1alpha1.AWSPartition, v1alpha1.AWSResourceName) error); ok { + r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5, _a6, _a7, _a8) } else { r1 = ret.Error(1) } diff --git a/mocks/pkg/types/service_controller.go b/mocks/pkg/types/service_controller.go index 9714355b..2971d828 100644 --- a/mocks/pkg/types/service_controller.go +++ b/mocks/pkg/types/service_controller.go @@ -104,32 +104,39 @@ func (_m *ServiceController) GetResourceManagerFactories() map[string]types.AWSR return r0 } -// NewAWSConfig provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4 -func (_m *ServiceController) NewAWSConfig(_a0 context.Context, _a1 v1alpha1.AWSRegion, _a2 *string, _a3 v1alpha1.AWSResourceName, _a4 schema.GroupVersionKind) (aws.Config, error) { - ret := _m.Called(_a0, _a1, _a2, _a3, _a4) +// NewAWSConfig provides a mock function with given fields: _a0, _a1, _a2, _a3, _a4, _a5 +func (_m *ServiceController) NewAWSConfig(_a0 context.Context, _a1 v1alpha1.AWSRegion, _a2 *string, _a3 v1alpha1.AWSResourceName, _a4 schema.GroupVersionKind, _a5 string) (aws.Config, string, error) { + ret := _m.Called(_a0, _a1, _a2, _a3, _a4, _a5) if len(ret) == 0 { panic("no return value specified for NewAWSConfig") } var r0 aws.Config - var r1 error - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind) (aws.Config, error)); ok { - return rf(_a0, _a1, _a2, _a3, _a4) + var r1 string + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind, string) (aws.Config, string, error)); ok { + return rf(_a0, _a1, _a2, _a3, _a4, _a5) } - if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind) aws.Config); ok { - r0 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(0).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind, string) aws.Config); ok { + r0 = rf(_a0, _a1, _a2, _a3, _a4, _a5) } else { r0 = ret.Get(0).(aws.Config) } - if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind) error); ok { - r1 = rf(_a0, _a1, _a2, _a3, _a4) + if rf, ok := ret.Get(1).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind, string) string); ok { + r1 = rf(_a0, _a1, _a2, _a3, _a4, _a5) } else { - r1 = ret.Error(1) + r1 = ret.Get(1).(string) } - return r0, r1 + if rf, ok := ret.Get(2).(func(context.Context, v1alpha1.AWSRegion, *string, v1alpha1.AWSResourceName, schema.GroupVersionKind, string) error); ok { + r2 = rf(_a0, _a1, _a2, _a3, _a4, _a5) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 } // WithLogger provides a mock function with given fields: _a0 diff --git a/pkg/runtime/adoption_reconciler.go b/pkg/runtime/adoption_reconciler.go index 279de76f..097db409 100644 --- a/pkg/runtime/adoption_reconciler.go +++ b/pkg/runtime/adoption_reconciler.go @@ -148,8 +148,11 @@ func (r *adoptionReconciler) reconcile(ctx context.Context, req ctrlrt.Request) targetDescriptor := rmf.ResourceDescriptor() endpointURL := r.getEndpointURL(res) gvk := targetDescriptor.GroupVersionKind() + partition := "" - awsconfig, err := r.sc.NewAWSConfig(ctx, region, &endpointURL, roleARN, gvk) + // The config pivot to the roleARN will happen if it is not empty. + // in the NewResourceManager + awsconfig, partition, err := r.sc.NewAWSConfig(ctx, region, &endpointURL, roleARN, gvk, partition) if err != nil { return err } @@ -157,7 +160,7 @@ func (r *adoptionReconciler) reconcile(ctx context.Context, req ctrlrt.Request) ackrtlog.InfoAdoptedResource(r.log, res, "starting adoption reconciliation") rm, err := rmf.ManagerFor( - r.cfg, awsconfig, r.log, r.metrics, r, acctID, region, roleARN, + r.cfg, awsconfig, r.log, r.metrics, r, acctID, region, ackv1alpha1.AWSPartition(partition), roleARN, ) if err != nil { return err diff --git a/pkg/runtime/config.go b/pkg/runtime/config.go index 82ad3a2d..24dba499 100644 --- a/pkg/runtime/config.go +++ b/pkg/runtime/config.go @@ -20,6 +20,7 @@ import ( "strings" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/credentials/stscreds" "github.com/aws/aws-sdk-go-v2/service/sts" @@ -47,7 +48,8 @@ func (c *serviceController) NewAWSConfig( endpointURL *string, roleARN ackv1alpha1.AWSResourceName, groupVersionKind schema.GroupVersionKind, -) (aws.Config, error) { + partition string, +) (aws.Config, string, error) { val := formatUserAgent( appName, @@ -69,19 +71,25 @@ func (c *serviceController) NewAWSConfig( config.WithHTTPClient(client), ) if err != nil { - return awsCfg, err + return awsCfg, partition, err } if endpointURL != nil && *endpointURL != "" { awsCfg.BaseEndpoint = endpointURL } + stsclient := sts.NewFromConfig(awsCfg) if roleARN != "" { - client := sts.NewFromConfig(awsCfg) - creds := stscreds.NewAssumeRoleProvider(client, string(roleARN)) + creds := stscreds.NewAssumeRoleProvider(stsclient, string(roleARN)) awsCfg.Credentials = aws.NewCredentialsCache(creds) } - return awsCfg, nil + if partition == "" { + partition, err = c.getPartition(ctx, stsclient) + if err != nil { + return awsCfg, partition, nil + } + } + return awsCfg, partition, nil } func formatUserAgent(name, version string, extra ...string) string { @@ -91,3 +99,18 @@ func formatUserAgent(name, version string, extra ...string) string { } return ua } + +// getPartition gets the partition of the caller identity +func (c *serviceController) getPartition(ctx context.Context, client *sts.Client) (string, error) { + identity, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + // what od we do if ARN is nil? + if err != nil || identity.Arn == nil { + return "", err + } + clientArn, err := arn.Parse(*identity.Arn) + if err != nil { + return "", err + } + + return clientArn.Partition, nil +} diff --git a/pkg/runtime/reconciler.go b/pkg/runtime/reconciler.go index ddc94f46..db74cc26 100644 --- a/pkg/runtime/reconciler.go +++ b/pkg/runtime/reconciler.go @@ -262,12 +262,7 @@ func (r *resourceReconciler) Reconcile(ctx context.Context, req ctrlrt.Request) region := r.getRegion(desired) endpointURL := r.getEndpointURL(desired) gvk := r.rd.GroupVersionKind() - // The config pivot to the roleARN will happen if it is not empty. - // in the NewResourceManager - clientConfig, err := r.sc.NewAWSConfig(ctx, region, &endpointURL, roleARN, gvk) - if err != nil { - return ctrlrt.Result{}, err - } + partition := r.getPartition(desired) rlog.WithValues( "account", acctID, @@ -275,12 +270,22 @@ func (r *resourceReconciler) Reconcile(ctx context.Context, req ctrlrt.Request) "region", region, ) - rm, err := r.rmf.ManagerFor( - r.cfg, clientConfig, r.log, r.metrics, r, acctID, region, roleARN, - ) - if err != nil { - return ctrlrt.Result{}, err + rm := r.rmf.GetCachedManager(acctID, region, roleARN) + if rm == nil { + // The config pivot to the roleARN will happen if it is not empty. + // in the NewResourceManager + clientConfig, partition, err := r.sc.NewAWSConfig(ctx, region, &endpointURL, roleARN, gvk, partition) + if err != nil { + return ctrlrt.Result{}, err + } + rm, err = r.rmf.ManagerFor( + r.cfg, clientConfig, r.log, r.metrics, r, acctID, region, ackv1alpha1.AWSPartition(partition), roleARN, + ) + if err != nil { + return ctrlrt.Result{}, err + } } + latest, err := r.reconcile(ctx, rm, desired) return r.HandleReconcileError(ctx, desired, latest, err) } @@ -1313,6 +1318,20 @@ func (r *resourceReconciler) getRegion( return ackv1alpha1.AWSRegion(r.cfg.Region) } +// getPartition attempts getting the partition from the resource status +// if it exists +func (r *resourceReconciler) getPartition( + res acktypes.AWSResource, +) string { + // first try to get the region from the status.resourceMetadata + metadataRegion := res.Identifiers().Partition() + if metadataRegion != nil { + return string(*metadataRegion) + } + + return "" +} + // getDeletionPolicy returns the resource's deletion policy based on the default // behaviour or any other overriding annotations. // diff --git a/pkg/types/aws_resource_identifiers.go b/pkg/types/aws_resource_identifiers.go index d30cf6c9..083981e0 100644 --- a/pkg/types/aws_resource_identifiers.go +++ b/pkg/types/aws_resource_identifiers.go @@ -29,4 +29,6 @@ type AWSResourceIdentifiers interface { ARN() *ackv1alpha1.AWSResourceName // Region is the AWS region in which the resource exists or will exist. Region() *ackv1alpha1.AWSRegion + // Partition is the AWS partition in which the resource exists or will exist. + Partition() *ackv1alpha1.AWSPartition } diff --git a/pkg/types/aws_resource_manager.go b/pkg/types/aws_resource_manager.go index 4ca5740e..8579efc0 100644 --- a/pkg/types/aws_resource_manager.go +++ b/pkg/types/aws_resource_manager.go @@ -112,8 +112,16 @@ type AWSResourceManagerFactory interface { Reconciler, ackv1alpha1.AWSAccountID, ackv1alpha1.AWSRegion, + ackv1alpha1.AWSPartition, ackv1alpha1.AWSResourceName, ) (AWSResourceManager, error) + // GetCachedManager returns an AWSResourceManager if it has previously been created + // and cahced, or returns nil if not + GetCachedManager( + ackv1alpha1.AWSAccountID, + ackv1alpha1.AWSRegion, + ackv1alpha1.AWSResourceName, + ) AWSResourceManager // IsAdoptable returns true if the resource is able to be adopted IsAdoptable() bool // RequeueOnSuccessSeconds returns true if the resource should be requeued after specified seconds diff --git a/pkg/types/service_controller.go b/pkg/types/service_controller.go index d8ed9d57..e19582a2 100644 --- a/pkg/types/service_controller.go +++ b/pkg/types/service_controller.go @@ -85,7 +85,8 @@ type ServiceController interface { *string, ackv1alpha1.AWSResourceName, schema.GroupVersionKind, - ) (aws.Config, error) + string, + ) (aws.Config, string, error) // GetMetadata returns the metadata associated with the service controller. GetMetadata() ServiceControllerMetadata