diff --git a/Makefile b/Makefile index 8bb9c6c..af2857f 100644 --- a/Makefile +++ b/Makefile @@ -82,7 +82,7 @@ clean: ## Clean temporary files @ go clean @ rm -f ./vui @ rm -f *.log - @ rm -f coverage.* + @ rm -f coverage.* ./*.test @ rm -rf ./dist/ .PHONY: clean diff --git a/README.md b/README.md index d3856e2..63de65c 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,20 @@ profiles: aws_region: "us-east-1" ``` +#### AWS SSM Parameters + +```yaml +profiles: + aws_ssm: + engine: aws/ssm + auth_method: "aws" + auth_config: + aws_access_key_id: "${AWS_ACCESS_KEY_ID}" + aws_secret_access_key: "${AWS_SECRET_ACCESS_KEY}" + aws_session_token: "${AWS_SESSION_TOKEN}" + aws_region: "us-east-1" +``` + ## Installation ### Download from Release diff --git a/configs/vui.yaml b/configs/vui.yaml index d0e40d5..175973a 100644 --- a/configs/vui.yaml +++ b/configs/vui.yaml @@ -102,3 +102,14 @@ profiles: aws_access_key_id: "${AWS_ACCESS_KEY_ID}" aws_secret_access_key: "${AWS_SECRET_ACCESS_KEY}" aws_session_token: "${AWS_SESSION_TOKEN}" + + aws-ssm: + engine: "aws/ssm" + address: "http://localhost:4566" + auth_method: "aws" + namespace: "" + auth_config: + aws_region: "us-east-1" + aws_access_key_id: "${AWS_ACCESS_KEY_ID}" + aws_secret_access_key: "${AWS_SECRET_ACCESS_KEY}" + aws_session_token: "${AWS_SESSION_TOKEN}" diff --git a/internal/engines/aws/aws_ssm.go b/internal/engines/aws/aws_ssm.go new file mode 100644 index 0000000..21e50c0 --- /dev/null +++ b/internal/engines/aws/aws_ssm.go @@ -0,0 +1,427 @@ +package aws + +import ( + "context" + "encoding/json" + "fmt" + "path/filepath" + "strings" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/aws/aws-sdk-go/service/sts" + "github.com/rvolykh/vui/internal/config" + "github.com/rvolykh/vui/internal/models" + "github.com/rvolykh/vui/internal/utils" + "github.com/sirupsen/logrus" +) + +type AWSSSMClient struct { + client ssmiface.SSMAPI + profile *config.Profile + logger *logrus.Logger + region string + address string +} + +func NewAWSSSMClient(logger *logrus.Logger, profile *config.Profile) (*AWSSSMClient, error) { + region := utils.Coalesce(profile.AuthConfig.AWSRegion, "us-east-1") + + awsConfig := aws.NewConfig().WithRegion(region) + + if profile.Address != "" { + awsConfig.WithEndpoint(profile.Address) + } + + if profile.AuthConfig.AWSAccessKeyID == "" || profile.AuthConfig.AWSSecretAccessKey == "" { + return nil, fmt.Errorf("aws_access_key_id and aws_secret_access_key are required for AWS SSM Parameter Store authentication") + } + + awsConfig.WithCredentials(credentials.NewStaticCredentials( + profile.AuthConfig.AWSAccessKeyID, + profile.AuthConfig.AWSSecretAccessKey, + profile.AuthConfig.AWSSessionToken, + )) + + sess, err := session.NewSession(awsConfig) + if err != nil { + return nil, fmt.Errorf("failed to create AWS session for static credentials: %w", err) + } + + awsRole := profile.AuthConfig.AWSRole + if awsRole != "" { + awsConfig.WithCredentials(stscreds.NewCredentials(sess, awsRole)) + + sess, err = session.NewSession(awsConfig) + if err != nil { + return nil, fmt.Errorf("failed to create AWS session for assumed role: %w", err) + } + } + + address := profile.Address + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + result, err := sts.New(sess).GetCallerIdentityWithContext(ctx, &sts.GetCallerIdentityInput{}) + if err == nil && result != nil && result.Account != nil { + address = fmt.Sprintf("aws://%s:%s", *result.Account, region) + } else if profile.Address == "" { + address = fmt.Sprintf("ssm.%s.amazonaws.com", region) + } + + return &AWSSSMClient{ + client: ssm.New(sess), + profile: profile, + logger: logger, + region: region, + address: address, + }, nil +} + +func (c *AWSSSMClient) Authenticate() error { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + // For AWS SSM, we can verify by making a simple API call + _, err := c.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + MaxResults: aws.Int64(1), + }) + if err != nil { + return fmt.Errorf("failed to authenticate with AWS SSM Parameter Store: %w", err) + } + + c.logger.Debug("AWS SSM Parameter Store authentication verified successfully") + return nil +} + +func (c *AWSSSMClient) GetAddress() string { + if c.address == "" { + return c.profile.Address + } + return c.address +} + +func (c *AWSSSMClient) GetStatus(ctx context.Context) (models.ConnectionStatus, error) { + _, err := c.client.DescribeParametersWithContext(ctx, &ssm.DescribeParametersInput{ + MaxResults: aws.Int64(1), + }) + if err != nil { + return models.ConnectionStatus{ + Status: models.StatusDisconnected, + Address: c.GetAddress(), + LastCheck: time.Now(), + Error: err.Error(), + }, nil + } + + return models.ConnectionStatus{ + Status: models.StatusConnected, + Address: c.GetAddress(), + Version: "AWS SSM Parameter Store", + ClusterID: c.region, + LastCheck: time.Now(), + }, nil +} + +func (c *AWSSSMClient) ListSecrets(path string) ([]*models.SecretNode, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + normalizedPath := strings.Trim(path, "/") + + // SSM Parameter Store uses hierarchical paths + // If path is empty, we start from root. Otherwise, we use the path as a prefix + parameterFilters := []*ssm.ParameterStringFilter{} + + if normalizedPath != "" { + // Add a prefix filter - SSM paths typically start with / + prefix := "/" + normalizedPath + parameterFilters = append(parameterFilters, &ssm.ParameterStringFilter{ + Key: aws.String("Name"), + Option: aws.String("BeginsWith"), + Values: []*string{aws.String(prefix)}, + }) + } + + input := &ssm.DescribeParametersInput{ + MaxResults: aws.Int64(10), + ParameterFilters: parameterFilters, + } + + var allParameters []*ssm.ParameterMetadata + err := c.client.DescribeParametersPagesWithContext(ctx, input, func(page *ssm.DescribeParametersOutput, lastPage bool) bool { + if page.Parameters != nil { + allParameters = append(allParameters, page.Parameters...) + } + return !lastPage + }) + if err != nil { + return nil, fmt.Errorf("failed to list parameters: %w", err) + } + + // Build a tree structure from parameter names + nodeMap := make(map[string]*models.SecretNode) + + for _, param := range allParameters { + if param.Name == nil { + continue + } + + paramName := *param.Name + // Remove leading slash if present + paramName = strings.TrimPrefix(paramName, "/") + + // Remove the prefix if we're filtering by path + var relativePath string + if normalizedPath != "" { + if !strings.HasPrefix(paramName, normalizedPath+"/") && paramName != normalizedPath { + continue + } + if paramName == normalizedPath { + // This is the exact path requested, skip it as it's not a child + continue + } + relativePath = strings.TrimPrefix(paramName, normalizedPath+"/") + } else { + relativePath = paramName + } + + // Split into parts to find immediate children + parts := strings.Split(relativePath, "/") + if len(parts) == 0 { + continue + } + + firstPart := parts[0] + fullPath := firstPart + if normalizedPath != "" { + fullPath = normalizedPath + "/" + firstPart + } + // Ensure full path starts with / + if !strings.HasPrefix(fullPath, "/") { + fullPath = "/" + fullPath + } + + // Check if this is a direct child or a nested path + if len(parts) == 1 { + // This is a direct parameter at the current path level + node := &models.SecretNode{ + Name: firstPart, + Path: "/" + paramName, // Use full parameter name with leading slash + IsSecret: true, + } + + if param.LastModifiedDate != nil { + node.Metadata = &models.SecretMetadata{ + CreatedTime: *param.LastModifiedDate, + Version: 1, + } + } + + nodeMap[firstPart] = node + } else { + // This is a nested path, create a directory node + if _, exists := nodeMap[firstPart]; !exists { + dirNode := &models.SecretNode{ + Name: firstPart, + Path: fullPath, + IsSecret: false, + Children: []*models.SecretNode{}, + } + nodeMap[firstPart] = dirNode + } + } + } + + // Convert map to slice + result := []*models.SecretNode{} + for _, node := range nodeMap { + result = append(result, node) + } + + return result, nil +} + +func (c *AWSSSMClient) GetSecret(path string) (*models.SecretNode, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // AWS SSM uses parameter name as identifier + // Ensure path starts with / + paramName := strings.Trim(path, "/") + if !strings.HasPrefix(paramName, "/") { + paramName = "/" + paramName + } + + input := &ssm.GetParameterInput{ + Name: aws.String(paramName), + WithDecryption: aws.Bool(true), // Always decrypt SecureString parameters + } + + result, err := c.client.GetParameterWithContext(ctx, input) + if err != nil { + return nil, fmt.Errorf("failed to get parameter '%s': %w", path, err) + } + + if result.Parameter == nil { + return nil, fmt.Errorf("parameter '%s' not found", path) + } + + param := result.Parameter + node := &models.SecretNode{ + Name: filepath.Base(paramName), + Path: paramName, + IsSecret: true, + Metadata: &models.SecretMetadata{}, + } + + // Parse parameter value + var secretData map[string]any + paramValue := "" + if param.Value != nil { + paramValue = *param.Value + } + + // Try to parse as JSON + if err := json.Unmarshal([]byte(paramValue), &secretData); err != nil { + // If not JSON, treat as plain string + secretData = map[string]any{ + "value": paramValue, + } + } + + node.Data = secretData + + // Get additional metadata + if param.Version != nil { + node.Metadata.Version = int(*param.Version) + } + if param.LastModifiedDate != nil { + node.Metadata.CreatedTime = *param.LastModifiedDate + } + + return node, nil +} + +func (c *AWSSSMClient) CreateSecret(path string, data map[string]any) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Ensure path starts with / + paramName := strings.Trim(path, "/") + if !strings.HasPrefix(paramName, "/") { + paramName = "/" + paramName + } + + // Convert data map to JSON string + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal secret data: %w", err) + } + + // Determine parameter type - if data contains sensitive info, use SecureString + // For simplicity, we'll use SecureString for all secrets + paramType := "SecureString" + + // Check if the data suggests it's a plain string (single "value" key) + if len(data) == 1 { + if _, ok := data["value"]; ok { + // If it's a simple value, use String type + paramType = "String" + jsonData = []byte(fmt.Sprintf("%v", data["value"])) + } + } + + input := &ssm.PutParameterInput{ + Name: aws.String(paramName), + Value: aws.String(string(jsonData)), + Type: aws.String(paramType), + Overwrite: aws.Bool(false), // Don't overwrite existing parameters + } + + _, err = c.client.PutParameterWithContext(ctx, input) + if err != nil { + return fmt.Errorf("failed to create parameter '%s': %w", path, err) + } + + c.logger.Infof("Created parameter: %s", paramName) + return nil +} + +func (c *AWSSSMClient) UpdateSecret(path string, data map[string]any) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Ensure path starts with / + paramName := strings.Trim(path, "/") + if !strings.HasPrefix(paramName, "/") { + paramName = "/" + paramName + } + + // Convert data map to JSON string + jsonData, err := json.Marshal(data) + if err != nil { + return fmt.Errorf("failed to marshal secret data: %w", err) + } + + // Get existing parameter to determine type + getInput := &ssm.GetParameterInput{ + Name: aws.String(paramName), + } + existingParam, err := c.client.GetParameterWithContext(ctx, getInput) + + paramType := "SecureString" + if err == nil && existingParam != nil && existingParam.Parameter != nil && existingParam.Parameter.Type != nil { + paramType = *existingParam.Parameter.Type + } else { + // If parameter doesn't exist or we can't determine type, check data structure + if len(data) == 1 { + if _, ok := data["value"]; ok { + paramType = "String" + jsonData = []byte(fmt.Sprintf("%v", data["value"])) + } + } + } + + input := &ssm.PutParameterInput{ + Name: aws.String(paramName), + Value: aws.String(string(jsonData)), + Type: aws.String(paramType), + Overwrite: aws.Bool(true), // Overwrite for updates + } + + _, err = c.client.PutParameterWithContext(ctx, input) + if err != nil { + return fmt.Errorf("failed to update parameter '%s': %w", path, err) + } + + c.logger.Infof("Updated parameter: %s", paramName) + return nil +} + +func (c *AWSSSMClient) DeleteSecret(path string) error { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + // Ensure path starts with / + paramName := strings.Trim(path, "/") + if !strings.HasPrefix(paramName, "/") { + paramName = "/" + paramName + } + + input := &ssm.DeleteParameterInput{ + Name: aws.String(paramName), + } + + _, err := c.client.DeleteParameterWithContext(ctx, input) + if err != nil { + return fmt.Errorf("failed to delete parameter '%s': %w", path, err) + } + + c.logger.Infof("Deleted parameter: %s", paramName) + return nil +} diff --git a/internal/engines/aws/aws_ssm_test.go b/internal/engines/aws/aws_ssm_test.go new file mode 100644 index 0000000..bd0255e --- /dev/null +++ b/internal/engines/aws/aws_ssm_test.go @@ -0,0 +1,698 @@ +package aws + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/ssm" + "github.com/aws/aws-sdk-go/service/ssm/ssmiface" + "github.com/rvolykh/vui/internal/config" + "github.com/rvolykh/vui/internal/models" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// mockSSM is a mock implementation of ssmiface.SSMAPI +type mockSSM struct { + ssmiface.SSMAPI + describeParametersOutput *ssm.DescribeParametersOutput + describeParametersError error + getParameterOutput *ssm.GetParameterOutput + getParameterError error + putParameterOutput *ssm.PutParameterOutput + putParameterError error + deleteParameterOutput *ssm.DeleteParameterOutput + deleteParameterError error +} + +func (m *mockSSM) DescribeParametersPagesWithContext(ctx aws.Context, input *ssm.DescribeParametersInput, fn func(*ssm.DescribeParametersOutput, bool) bool, opts ...request.Option) error { + if m.describeParametersError != nil { + return m.describeParametersError + } + if m.describeParametersOutput != nil { + fn(m.describeParametersOutput, true) + } + return nil +} + +func (m *mockSSM) DescribeParametersWithContext(ctx aws.Context, input *ssm.DescribeParametersInput, opts ...request.Option) (*ssm.DescribeParametersOutput, error) { + if m.describeParametersError != nil { + return nil, m.describeParametersError + } + return m.describeParametersOutput, nil +} + +func (m *mockSSM) GetParameterWithContext(ctx aws.Context, input *ssm.GetParameterInput, opts ...request.Option) (*ssm.GetParameterOutput, error) { + if m.getParameterError != nil { + return nil, m.getParameterError + } + return m.getParameterOutput, nil +} + +func (m *mockSSM) PutParameterWithContext(ctx aws.Context, input *ssm.PutParameterInput, opts ...request.Option) (*ssm.PutParameterOutput, error) { + if m.putParameterError != nil { + return nil, m.putParameterError + } + return m.putParameterOutput, nil +} + +func (m *mockSSM) DeleteParameterWithContext(ctx aws.Context, input *ssm.DeleteParameterInput, opts ...request.Option) (*ssm.DeleteParameterOutput, error) { + if m.deleteParameterError != nil { + return nil, m.deleteParameterError + } + return m.deleteParameterOutput, nil +} + +// Helper function to create a test client with a mock SSM +func createTestSSMClientWithMock(mockSSM ssmiface.SSMAPI, profile *config.Profile) *AWSSSMClient { + return &AWSSSMClient{ + client: mockSSM, + profile: profile, + logger: logrus.New(), + region: "us-east-1", + address: "https://ssm.us-east-1.amazonaws.com", + } +} + +func TestAWSSSMClient_Implements(t *testing.T) { + t.Run("client has all required methods", func(t *testing.T) { + mockSSM := &mockSSM{} + profile := &config.Profile{} + client := createTestSSMClientWithMock(mockSSM, profile) + + // Verify client implements all required methods + var _ interface { + Authenticate() error + GetAddress() string + GetStatus(context.Context) (models.ConnectionStatus, error) + ListSecrets(string) ([]*models.SecretNode, error) + GetSecret(string) (*models.SecretNode, error) + CreateSecret(string, map[string]any) error + UpdateSecret(string, map[string]any) error + DeleteSecret(string) error + } = client + }) +} + +func TestNewAWSSSMClient(t *testing.T) { + logger := logrus.New() + logger.SetLevel(logrus.ErrorLevel) + + tests := []struct { + name string + profile *config.Profile + wantError bool + errorContains string + }{ + { + name: "valid profile with credentials", + profile: &config.Profile{ + AuthConfig: config.AuthConfig{ + AWSAccessKeyID: "test-key", + AWSSecretAccessKey: "test-secret", + AWSRegion: "us-west-2", + }, + }, + wantError: false, + }, + { + name: "valid profile with default region", + profile: &config.Profile{ + AuthConfig: config.AuthConfig{ + AWSAccessKeyID: "test-key", + AWSSecretAccessKey: "test-secret", + }, + }, + wantError: false, + }, + { + name: "missing access key", + profile: &config.Profile{ + AuthConfig: config.AuthConfig{ + AWSSecretAccessKey: "test-secret", + }, + }, + wantError: true, + errorContains: "aws_access_key_id", + }, + { + name: "missing secret key", + profile: &config.Profile{ + AuthConfig: config.AuthConfig{ + AWSAccessKeyID: "test-key", + }, + }, + wantError: true, + errorContains: "aws_secret_access_key", + }, + { + name: "with custom endpoint", + profile: &config.Profile{ + Address: "http://localhost:4566", + AuthConfig: config.AuthConfig{ + AWSAccessKeyID: "test-key", + AWSSecretAccessKey: "test-secret", + }, + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + client, err := NewAWSSSMClient(logger, tt.profile) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, client) + } else { + // Note: This will fail if AWS credentials are not configured + // For production tests, you'd need actual AWS credentials or skip this test + if err != nil { + t.Skipf("Skipping test: AWS credentials not configured: %v", err) + return + } + require.NoError(t, err) + require.NotNil(t, client) + assert.Equal(t, tt.profile, client.profile) + } + }) + } +} + +func TestAWSSSMClient_Authenticate(t *testing.T) { + tests := []struct { + name string + mockSSM *mockSSM + wantError bool + errorContains string + }{ + { + name: "successful authentication", + mockSSM: &mockSSM{ + describeParametersOutput: &ssm.DescribeParametersOutput{}, + }, + wantError: false, + }, + { + name: "authentication failure", + mockSSM: &mockSSM{ + describeParametersError: errors.New("access denied"), + }, + wantError: true, + errorContains: "failed to authenticate", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + err := client.Authenticate() + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAWSSSMClient_GetAddress(t *testing.T) { + tests := []struct { + name string + client *AWSSSMClient + profile *config.Profile + want string + }{ + { + name: "returns address when set", + client: &AWSSSMClient{ + address: "https://ssm.us-east-1.amazonaws.com", + }, + profile: &config.Profile{}, + want: "https://ssm.us-east-1.amazonaws.com", + }, + { + name: "returns profile address when client address is empty", + client: &AWSSSMClient{ + address: "", + profile: &config.Profile{ + Address: "http://localhost:4566", + }, + }, + profile: &config.Profile{ + Address: "http://localhost:4566", + }, + want: "http://localhost:4566", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.client.profile = tt.profile + got := tt.client.GetAddress() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestAWSSSMClient_GetStatus(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + mockSSM *mockSSM + wantStatus models.ConnectionStatus + wantError bool + errorContains string + }{ + { + name: "connected status", + mockSSM: &mockSSM{ + describeParametersOutput: &ssm.DescribeParametersOutput{}, + }, + wantStatus: models.ConnectionStatus{ + Status: models.StatusConnected, + Address: "https://ssm.us-east-1.amazonaws.com", + Version: "AWS SSM Parameter Store", + ClusterID: "us-east-1", + }, + wantError: false, + }, + { + name: "disconnected status", + mockSSM: &mockSSM{ + describeParametersError: errors.New("network error"), + }, + wantStatus: models.ConnectionStatus{ + Status: models.StatusDisconnected, + Address: "https://ssm.us-east-1.amazonaws.com", + Error: "network error", + }, + wantError: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + status, err := client.GetStatus(ctx) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + assert.Equal(t, tt.wantStatus.Status, status.Status) + assert.Equal(t, tt.wantStatus.Address, status.Address) + assert.Equal(t, tt.wantStatus.Version, status.Version) + assert.Equal(t, tt.wantStatus.ClusterID, status.ClusterID) + if tt.wantStatus.Error != "" { + assert.Contains(t, status.Error, tt.wantStatus.Error) + } + assert.False(t, status.LastCheck.IsZero()) + } + }) + } +} + +func TestAWSSSMClient_ListSecrets(t *testing.T) { + tests := []struct { + name string + path string + mockSSM *mockSSM + wantNodes []*models.SecretNode + wantError bool + errorContains string + }{ + { + name: "list root parameters", + path: "", + mockSSM: &mockSSM{ + describeParametersOutput: &ssm.DescribeParametersOutput{ + Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/app/dev/db/password"), + LastModifiedDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + { + Name: aws.String("/app/prod/api/key"), + LastModifiedDate: aws.Time(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + }, + wantNodes: []*models.SecretNode{ + { + Name: "app", + Path: "/app", + IsSecret: false, + Children: []*models.SecretNode{}, + }, + }, + wantError: false, + }, + { + name: "list parameters with path prefix", + path: "app/dev", + mockSSM: &mockSSM{ + describeParametersOutput: &ssm.DescribeParametersOutput{ + Parameters: []*ssm.ParameterMetadata{ + { + Name: aws.String("/app/dev/db/password"), + LastModifiedDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + }, + wantNodes: []*models.SecretNode{ + { + Name: "db", + Path: "/app/dev/db", + IsSecret: false, + Children: []*models.SecretNode{}, + }, + }, + wantError: false, + }, + { + name: "list secrets error", + path: "", + mockSSM: &mockSSM{ + describeParametersError: errors.New("access denied"), + }, + wantError: true, + errorContains: "failed to list parameters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + nodes, err := client.ListSecrets(tt.path) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, nodes) + } else { + require.NoError(t, err) + require.NotNil(t, nodes) + if len(tt.wantNodes) > 0 { + assert.Equal(t, len(tt.wantNodes), len(nodes)) + for i, wantNode := range tt.wantNodes { + if i < len(nodes) { + assert.Equal(t, wantNode.Name, nodes[i].Name) + assert.Equal(t, wantNode.Path, nodes[i].Path) + assert.Equal(t, wantNode.IsSecret, nodes[i].IsSecret) + } + } + } + } + }) + } +} + +func TestAWSSSMClient_GetSecret(t *testing.T) { + tests := []struct { + name string + path string + mockSSM *mockSSM + wantNode *models.SecretNode + wantError bool + errorContains string + }{ + { + name: "get secret successfully", + path: "/app/dev/db/password", + mockSSM: &mockSSM{ + getParameterOutput: &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: aws.String("/app/dev/db/password"), + Value: aws.String(`{"password":"secret123"}`), + Type: aws.String("SecureString"), + Version: aws.Int64(1), + LastModifiedDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + wantNode: &models.SecretNode{ + Name: "password", + Path: "/app/dev/db/password", + IsSecret: true, + Data: map[string]any{ + "password": "secret123", + }, + Metadata: &models.SecretMetadata{ + Version: 1, + CreatedTime: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }, + wantError: false, + }, + { + name: "get secret with plain string value", + path: "/app/key", + mockSSM: &mockSSM{ + getParameterOutput: &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Name: aws.String("/app/key"), + Value: aws.String("simple-value"), + Type: aws.String("String"), + Version: aws.Int64(2), + LastModifiedDate: aws.Time(time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)), + }, + }, + }, + wantNode: &models.SecretNode{ + Name: "key", + Path: "/app/key", + IsSecret: true, + Data: map[string]any{ + "value": "simple-value", + }, + Metadata: &models.SecretMetadata{ + Version: 2, + CreatedTime: time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC), + }, + }, + wantError: false, + }, + { + name: "get secret error", + path: "/nonexistent", + mockSSM: &mockSSM{ + getParameterError: errors.New("parameter not found"), + }, + wantError: true, + errorContains: "failed to get parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + node, err := client.GetSecret(tt.path) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + assert.Nil(t, node) + } else { + require.NoError(t, err) + require.NotNil(t, node) + assert.Equal(t, tt.wantNode.Name, node.Name) + assert.Equal(t, tt.wantNode.Path, node.Path) + assert.Equal(t, tt.wantNode.IsSecret, node.IsSecret) + if tt.wantNode.Data != nil { + assert.Equal(t, tt.wantNode.Data, node.Data) + } + if tt.wantNode.Metadata != nil { + require.NotNil(t, node.Metadata) + assert.Equal(t, tt.wantNode.Metadata.Version, node.Metadata.Version) + assert.Equal(t, tt.wantNode.Metadata.CreatedTime, node.Metadata.CreatedTime) + } + } + }) + } +} + +func TestAWSSSMClient_CreateSecret(t *testing.T) { + tests := []struct { + name string + path string + data map[string]any + mockSSM *mockSSM + wantError bool + errorContains string + }{ + { + name: "create secret successfully", + path: "/app/dev/key", + data: map[string]any{ + "password": "secret123", + "username": "admin", + }, + mockSSM: &mockSSM{ + putParameterOutput: &ssm.PutParameterOutput{ + Version: aws.Int64(1), + }, + }, + wantError: false, + }, + { + name: "create secret error", + path: "/app/key", + data: map[string]any{ + "value": "test", + }, + mockSSM: &mockSSM{ + putParameterError: errors.New("parameter already exists"), + }, + wantError: true, + errorContains: "failed to create parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + err := client.CreateSecret(tt.path, tt.data) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAWSSSMClient_UpdateSecret(t *testing.T) { + tests := []struct { + name string + path string + data map[string]any + mockSSM *mockSSM + wantError bool + errorContains string + }{ + { + name: "update secret successfully", + path: "/app/dev/key", + data: map[string]any{ + "password": "newsecret123", + }, + mockSSM: &mockSSM{ + getParameterOutput: &ssm.GetParameterOutput{ + Parameter: &ssm.Parameter{ + Type: aws.String("SecureString"), + }, + }, + putParameterOutput: &ssm.PutParameterOutput{ + Version: aws.Int64(2), + }, + }, + wantError: false, + }, + { + name: "update secret error", + path: "/app/key", + data: map[string]any{ + "value": "test", + }, + mockSSM: &mockSSM{ + putParameterError: errors.New("parameter update failed"), + }, + wantError: true, + errorContains: "failed to update parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + err := client.UpdateSecret(tt.path, tt.data) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} + +func TestAWSSSMClient_DeleteSecret(t *testing.T) { + tests := []struct { + name string + path string + mockSSM *mockSSM + wantError bool + errorContains string + }{ + { + name: "delete secret successfully", + path: "/app/dev/key", + mockSSM: &mockSSM{ + deleteParameterOutput: &ssm.DeleteParameterOutput{}, + }, + wantError: false, + }, + { + name: "delete secret error", + path: "/nonexistent", + mockSSM: &mockSSM{ + deleteParameterError: errors.New("parameter not found"), + }, + wantError: true, + errorContains: "failed to delete parameter", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + profile := &config.Profile{} + client := createTestSSMClientWithMock(tt.mockSSM, profile) + + err := client.DeleteSecret(tt.path) + if tt.wantError { + require.Error(t, err) + if tt.errorContains != "" { + assert.Contains(t, err.Error(), tt.errorContains) + } + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/internal/engines/engines_factory.go b/internal/engines/engines_factory.go index c57557a..0cd54e3 100644 --- a/internal/engines/engines_factory.go +++ b/internal/engines/engines_factory.go @@ -25,6 +25,8 @@ func (f *EnginesFactory) SetupEngine(name string, profile *config.Profile) (Secr return vault.NewVaultClient(f.logger, profile) case "aws/secretsmanager": return aws.NewAWSSecretsManagerClient(f.logger, profile) + case "aws/ssm": + return aws.NewAWSSSMClient(f.logger, profile) default: return nil, fmt.Errorf("unknown engine: %s", name) } diff --git a/internal/engines/engines_factory_test.go b/internal/engines/engines_factory_test.go index 1286a91..13647ac 100644 --- a/internal/engines/engines_factory_test.go +++ b/internal/engines/engines_factory_test.go @@ -38,6 +38,21 @@ func TestEnginesFactory_SetupEngine(t *testing.T) { assert.NotNil(t, engine) }) + t.Run("aws_ssm", func(t *testing.T) { + engine, err := factory.SetupEngine("aws/ssm", &config.Profile{ + Engine: "aws/ssm", + Address: "http://localhost:8200", + AuthMethod: "aws", + AuthConfig: config.AuthConfig{ + AWSAccessKeyID: "test-key", + AWSSecretAccessKey: "test-secret", + AWSRegion: "us-west-2", + }, + }) + require.NoError(t, err) + assert.NotNil(t, engine) + }) + t.Run("unknown", func(t *testing.T) { engine, err := factory.SetupEngine("unknown", &config.Profile{ Engine: "unknown", diff --git a/sandbox/docker-compose.yml b/sandbox/docker-compose.yml index 18af7ce..b2e348e 100644 --- a/sandbox/docker-compose.yml +++ b/sandbox/docker-compose.yml @@ -37,7 +37,7 @@ services: ports: - "127.0.0.1:4566:4566" environment: - - SERVICES=sts,iam,secretsmanager + - SERVICES=sts,iam,secretsmanager,ssm volumes: - "./vol/localstack:/var/lib/localstack"