From 507924a8d53ac26efef1e16b7b850ad77ad330ce Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:32:08 +0300 Subject: [PATCH 1/4] fix(security): stop logging plaintext password reset tokens The password reset flow logged the plaintext token alongside the user's email at Debug level so it could be picked up during MVP testing. Anyone with log access could mint a password reset and take over the account. Replace the Debug log with an Info entry that records only the user_id and reset-token row id. The plaintext token must reach the user via email; once an EmailService is wired in it should be the sole consumer. Closes #293 --- internal/core/services/password_reset.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/internal/core/services/password_reset.go b/internal/core/services/password_reset.go index 725a84b0d..c534661ba 100644 --- a/internal/core/services/password_reset.go +++ b/internal/core/services/password_reset.go @@ -75,10 +75,9 @@ func (s *PasswordResetService) RequestReset(ctx context.Context, email string) e return err } - // Note: EmailService integration is pending. - // For MVP/Demo: Log the token so we can test it manually. - // Future: Inject and use EmailService here. - s.logger.Debug("password reset token", "email", email, "token", token) + // TODO: deliver `token` via an injected EmailService once available. + // Never log or persist the plaintext token — its only safe destination is the user. + s.logger.Info("password reset token issued", "user_id", user.ID, "token_id", resetToken.ID) return nil } From 8e802aa4f46cadcb0381fbc8c914e3ae07f53cef Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:35:12 +0300 Subject: [PATCH 2/4] fix(security): sanitize Content-Disposition filename to prevent header injection Object keys flowed straight into the Content-Disposition header via fmt.Sprintf("attachment; filename=%s", key), so any key containing CRLF, double quotes, or backslashes could split the response or inject arbitrary headers. Path-bearing keys also leaked the full bucket path to the client as the suggested filename. Replace both call sites (authenticated download and presigned download) with a single helper, contentDispositionAttachment, that: - reduces the key to its basename via path.Base - emits an ASCII-only `filename="..."` fallback with control bytes, non-ASCII bytes, quotes, and backslashes mapped to `_` - emits the full Unicode basename in `filename*=UTF-8''...` per RFC 5987 with proper attr-char percent-encoding - guarantees CR/LF can never reach the wire Add table-driven unit tests covering response-splitting, quote/backslash injection, non-ASCII names, nested keys, and empty keys. Closes #225, #226 --- internal/handlers/storage_handler.go | 79 ++++++++++++++++- ...torage_handler_content_disposition_test.go | 86 +++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 internal/handlers/storage_handler_content_disposition_test.go diff --git a/internal/handlers/storage_handler.go b/internal/handlers/storage_handler.go index 31151aa71..42f810e89 100644 --- a/internal/handlers/storage_handler.go +++ b/internal/handlers/storage_handler.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "path" "strconv" "strings" "time" @@ -38,6 +39,80 @@ const ( headerContentSha256 = "X-Content-Sha256" ) +// contentDispositionAttachment builds a safe `Content-Disposition: attachment` +// header for a stored object. +// +// Object keys can contain path segments, non-ASCII characters, control +// characters, quotes, backslashes, or CRLF that — if interpolated naively — +// would either corrupt the header (HTTP response splitting) or let an attacker +// inject additional headers. The output therefore emits two parameters per +// RFC 6266: +// +// - `filename="..."` ASCII-only fallback for legacy clients. All bytes +// outside the safe printable range and the two +// characters that are special inside a quoted-string +// (`"` and `\`) are replaced with `_`. +// - `filename*=UTF-8''…` RFC 5987 percent-encoded form preserving the +// original Unicode basename for modern clients. +// +// `path.Base` is used to discard any path segments embedded in the key. If the +// resulting name is empty we fall back to "download". +func contentDispositionAttachment(key string) string { + name := path.Base(key) + if name == "." || name == "/" || name == "" { + name = "download" + } + + return fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + asciiFilenameFallback(name), rfc5987Encode(name)) +} + +// asciiFilenameFallback returns an ASCII-only sanitized copy of name suitable +// for the legacy `filename=` parameter. Any byte that is a control character +// (<0x20 or 0x7f), non-ASCII (>=0x80), or special inside a quoted-string is +// replaced with `_` so the value can be safely wrapped in double quotes. +func asciiFilenameFallback(name string) string { + out := make([]byte, 0, len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + switch { + case c < 0x20, c == 0x7f, c >= 0x80, c == '"', c == '\\': + out = append(out, '_') + default: + out = append(out, c) + } + } + if len(out) == 0 { + return "download" + } + return string(out) +} + +// rfc5987Encode percent-encodes a value per RFC 5987 attr-char rules so it can +// be safely placed in a `filename*` parameter. +func rfc5987Encode(s string) string { + const hex = "0123456789ABCDEF" + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case c >= 'A' && c <= 'Z', + c >= 'a' && c <= 'z', + c >= '0' && c <= '9', + c == '!', c == '#', c == '$', c == '&', c == '+', + c == '-', c == '.', c == '^', c == '_', c == '`', + c == '|', c == '~': + b.WriteByte(c) + default: + b.WriteByte('%') + b.WriteByte(hex[c>>4]) + b.WriteByte(hex[c&0x0f]) + } + } + return b.String() +} + // Upload uploads an object to a bucket // @Summary Upload an object // @Description Uploads a file/object to the specified bucket and key @@ -105,7 +180,7 @@ func (h *StorageHandler) Download(c *gin.Context) { defer func() { _ = reader.Close() }() // Set headers - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) @@ -461,7 +536,7 @@ func (h *StorageHandler) ServePresignedDownload(c *gin.Context) { } defer func() { _ = reader.Close() }() - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) _, _ = io.Copy(c.Writer, reader) diff --git a/internal/handlers/storage_handler_content_disposition_test.go b/internal/handlers/storage_handler_content_disposition_test.go new file mode 100644 index 000000000..64a4aec1d --- /dev/null +++ b/internal/handlers/storage_handler_content_disposition_test.go @@ -0,0 +1,86 @@ +package httphandlers + +import ( + "strings" + "testing" +) + +func TestContentDispositionAttachment(t *testing.T) { + cases := []struct { + name string + key string + wantFilename string + wantFilenameStar string + }{ + { + name: "simple ascii filename", + key: "report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "nested key uses basename", + key: "exports/2026/q1/report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "CRLF response splitting attempt is sanitized", + key: "evil\r\nSet-Cookie: pwned=1", + wantFilename: `filename="evil__Set-Cookie: pwned=1"`, + wantFilenameStar: `filename*=UTF-8''evil%0D%0ASet-Cookie%3A%20pwned%3D1`, + }, + { + name: "embedded quote and backslash are sanitized", + key: `bad"name\file.txt`, + wantFilename: `filename="bad_name_file.txt"`, + wantFilenameStar: `filename*=UTF-8''bad%22name%5Cfile.txt`, + }, + { + name: "non-ASCII falls back in legacy filename, preserved in filename*", + key: "résumé.pdf", + wantFilename: `filename="r__sum__.pdf"`, // 2 bytes per accented char both replaced + wantFilenameStar: `filename*=UTF-8''r%C3%A9sum%C3%A9.pdf`, + }, + { + name: "empty key falls back to download", + key: "", + wantFilename: `filename="download"`, + wantFilenameStar: `filename*=UTF-8''download`, + }, + { + name: "trailing slash falls back to download", + key: "folder/", + wantFilename: `filename="folder"`, // path.Base normalizes + wantFilenameStar: `filename*=UTF-8''folder`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := contentDispositionAttachment(tc.key) + if !strings.HasPrefix(got, "attachment; ") { + t.Errorf("missing attachment prefix: %q", got) + } + if !strings.Contains(got, tc.wantFilename) { + t.Errorf("missing legacy filename param\n got: %q\n want: %q", got, tc.wantFilename) + } + if !strings.Contains(got, tc.wantFilenameStar) { + t.Errorf("missing filename* param\n got: %q\n want: %q", got, tc.wantFilenameStar) + } + if strings.ContainsAny(got, "\r\n") { + t.Errorf("output must not contain CR/LF: %q", got) + } + }) + } +} + +func TestContentDispositionAttachment_NoCRLFEverEscapes(t *testing.T) { + for _, c := range []byte{'\r', '\n', 0, 0x1f, 0x7f} { + key := "x" + string(c) + "y" + got := contentDispositionAttachment(key) + if strings.ContainsAny(got, "\r\n") { + t.Fatalf("control byte 0x%02x leaked into header: %q", c, got) + } + } +} From 2cfa64c292a070145832553b9860d7d730a923d2 Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:30:36 +0300 Subject: [PATCH 3/4] fix(notify): surface non-2xx webhook delivery responses Webhook delivery in NotifyService.deliverToWebhook used http.DefaultClient and only inspected transport errors, so 4xx/5xx responses were silently treated as successful deliveries. Subscribers returning failures (bad endpoints, transient 5xx) had messages dropped without any log signal. Check resp.StatusCode and emit a structured warn log on >=400 responses, add a 15s timeout via a dedicated http.Client, drain the response body before close, and stop discarding the error from NewRequestWithContext. Closes #338 Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/core/services/notify.go | 613 ++++++------- internal/core/services/notify_unit_test.go | 948 +++++++++++---------- 2 files changed, 833 insertions(+), 728 deletions(-) diff --git a/internal/core/services/notify.go b/internal/core/services/notify.go index babfc7aa3..517121499 100644 --- a/internal/core/services/notify.go +++ b/internal/core/services/notify.go @@ -1,297 +1,316 @@ -// Package services implements core business workflows. -package services - -import ( - "bytes" - "context" - "fmt" - "log/slog" - "net/http" - "time" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/ports" -) - -// NotifyServiceParams defines the dependencies for NotifyService. -type NotifyServiceParams struct { - Repo ports.NotifyRepository - RBACSvc ports.RBACService - QueueSvc ports.QueueService - EventSvc ports.EventService - AuditSvc ports.AuditService - Logger *slog.Logger -} - -// NotifyService manages topics, subscriptions, and message delivery. -type NotifyService struct { - repo ports.NotifyRepository - rbacSvc ports.RBACService - queueSvc ports.QueueService - eventSvc ports.EventService - auditSvc ports.AuditService - logger *slog.Logger -} - -// NewNotifyService constructs a NotifyService with its dependencies. -func NewNotifyService(params NotifyServiceParams) ports.NotifyService { - return &NotifyService{ - repo: params.Repo, - rbacSvc: params.RBACSvc, - queueSvc: params.QueueSvc, - eventSvc: params.EventSvc, - auditSvc: params.AuditSvc, - logger: params.Logger, - } -} - -func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { - return nil, err - } - - existing, _ := s.repo.GetTopicByName(ctx, name, userID) - if existing != nil { - return nil, fmt.Errorf("topic with name %s already exists", name) - } - - id := uuid.New() - topic := &domain.Topic{ - ID: id, - UserID: userID, - Name: name, - ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateTopic(ctx, topic); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return topic, nil -} - -func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { - return nil, err - } - - return s.repo.ListTopics(ctx, userID) -} - -func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, id, userID) - if err != nil { - return err - } - if topic == nil { - return fmt.Errorf("topic not found") - } - - if err := s.repo.DeleteTopic(ctx, id); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return nil -} - -func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return nil, err - } - - // Verify topic exists and belongs to user - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - sub := &domain.Subscription{ - ID: uuid.New(), - UserID: userID, - TopicID: topic.ID, - Protocol: protocol, - Endpoint: endpoint, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateSubscription(ctx, sub); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": topicID.String(), - "protocol": protocol, - "endpoint": endpoint, - }) - - return sub, nil -} - -func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { - return nil, err - } - - // Verify topic ownership - _, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - return s.repo.ListSubscriptions(ctx, topicID) -} - -func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) - if err != nil { - return err - } - - if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": sub.TopicID.String(), - }) - - return nil -} - -func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return err - } - - msg := &domain.NotifyMessage{ - ID: uuid.New(), - TopicID: topic.ID, - Body: body, - CreatedAt: time.Now(), - } - - if err := s.repo.SaveMessage(ctx, msg); err != nil { - return err - } - - subs, err := s.repo.ListSubscriptions(ctx, topicID) - if err != nil { - return err - } - - // Delivery logic - for _, sub := range subs { - // Create a background context for async delivery to avoid request cancellation - // but keep it separate for each subscriber to avoid shared timeout issues - go func(c context.Context, sub *domain.Subscription) { - deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) - defer cancel() - - // Carry over potential trace IDs or other metadata if needed - // (Simplified for now, but avoids using request-scoped ctx) - s.deliver(deliveryCtx, sub, body) - }(ctx, sub) - } - - if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { - s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) - } - - if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ - "message_id": msg.ID.String(), - }); err != nil { - s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) - } - - return nil -} - -func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { - switch sub.Protocol { - case domain.ProtocolQueue: - s.deliverToQueue(ctx, sub, body) - case domain.ProtocolWebhook: - s.deliverToWebhook(ctx, sub, body) - } -} - -func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { - // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. - // For now let's assume endpoint is the Queue UUID string. - qID, err := uuid.Parse(sub.Endpoint) - if err != nil { - s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) - return - } - // We need to bypass user check or use sub.UserID context - deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) - if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { - s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) - } -} - -func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { - req, _ := http.NewRequestWithContext(ctx, "POST", sub.Endpoint, bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) - return - } - _ = resp.Body.Close() -} +// Package services implements core business workflows. +package services + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/ports" +) + +var webhookHTTPClient = &http.Client{Timeout: 15 * time.Second} + +// NotifyServiceParams defines the dependencies for NotifyService. +type NotifyServiceParams struct { + Repo ports.NotifyRepository + RBACSvc ports.RBACService + QueueSvc ports.QueueService + EventSvc ports.EventService + AuditSvc ports.AuditService + Logger *slog.Logger +} + +// NotifyService manages topics, subscriptions, and message delivery. +type NotifyService struct { + repo ports.NotifyRepository + rbacSvc ports.RBACService + queueSvc ports.QueueService + eventSvc ports.EventService + auditSvc ports.AuditService + logger *slog.Logger +} + +// NewNotifyService constructs a NotifyService with its dependencies. +func NewNotifyService(params NotifyServiceParams) ports.NotifyService { + return &NotifyService{ + repo: params.Repo, + rbacSvc: params.RBACSvc, + queueSvc: params.QueueSvc, + eventSvc: params.EventSvc, + auditSvc: params.AuditSvc, + logger: params.Logger, + } +} + +func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { + return nil, err + } + + existing, _ := s.repo.GetTopicByName(ctx, name, userID) + if existing != nil { + return nil, fmt.Errorf("topic with name %s already exists", name) + } + + id := uuid.New() + topic := &domain.Topic{ + ID: id, + UserID: userID, + Name: name, + ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateTopic(ctx, topic); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return topic, nil +} + +func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { + return nil, err + } + + return s.repo.ListTopics(ctx, userID) +} + +func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, id, userID) + if err != nil { + return err + } + if topic == nil { + return fmt.Errorf("topic not found") + } + + if err := s.repo.DeleteTopic(ctx, id); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return nil +} + +func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return nil, err + } + + // Verify topic exists and belongs to user + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topic.ID, + Protocol: protocol, + Endpoint: endpoint, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateSubscription(ctx, sub); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": topicID.String(), + "protocol": protocol, + "endpoint": endpoint, + }) + + return sub, nil +} + +func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { + return nil, err + } + + // Verify topic ownership + _, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + return s.repo.ListSubscriptions(ctx, topicID) +} + +func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) + if err != nil { + return err + } + + if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": sub.TopicID.String(), + }) + + return nil +} + +func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return err + } + + msg := &domain.NotifyMessage{ + ID: uuid.New(), + TopicID: topic.ID, + Body: body, + CreatedAt: time.Now(), + } + + if err := s.repo.SaveMessage(ctx, msg); err != nil { + return err + } + + subs, err := s.repo.ListSubscriptions(ctx, topicID) + if err != nil { + return err + } + + // Delivery logic + for _, sub := range subs { + // Create a background context for async delivery to avoid request cancellation + // but keep it separate for each subscriber to avoid shared timeout issues + go func(c context.Context, sub *domain.Subscription) { + deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) + defer cancel() + + // Carry over potential trace IDs or other metadata if needed + // (Simplified for now, but avoids using request-scoped ctx) + s.deliver(deliveryCtx, sub, body) + }(ctx, sub) + } + + if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { + s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) + } + + if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ + "message_id": msg.ID.String(), + }); err != nil { + s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) + } + + return nil +} + +func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { + switch sub.Protocol { + case domain.ProtocolQueue: + s.deliverToQueue(ctx, sub, body) + case domain.ProtocolWebhook: + s.deliverToWebhook(ctx, sub, body) + } +} + +func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { + // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. + // For now let's assume endpoint is the Queue UUID string. + qID, err := uuid.Parse(sub.Endpoint) + if err != nil { + s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) + return + } + // We need to bypass user check or use sub.UserID context + deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) + if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { + s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) + } +} + +func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, sub.Endpoint, bytes.NewBufferString(body)) + if err != nil { + s.logger.Warn("failed to build webhook request", "endpoint", sub.Endpoint, "error", err) + return + } + req.Header.Set("Content-Type", "application/json") + + resp, err := webhookHTTPClient.Do(req) + if err != nil { + s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) + return + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + + if resp.StatusCode >= 400 { + s.logger.Warn("webhook delivery failed", + "endpoint", sub.Endpoint, + "subscription_id", sub.ID, + "status", resp.StatusCode) + return + } +} diff --git a/internal/core/services/notify_unit_test.go b/internal/core/services/notify_unit_test.go index 20024ce81..424b3f450 100644 --- a/internal/core/services/notify_unit_test.go +++ b/internal/core/services/notify_unit_test.go @@ -1,432 +1,518 @@ -package services_test - -import ( - "context" - "fmt" - "log/slog" - "testing" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/services" - "github.com/poyrazk/thecloud/internal/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const testTopicName = "my-topic" - -func TestNotifyServiceUnit(t *testing.T) { - t.Run("CRUD", testNotifyServiceUnitCRUD) - t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) - t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) - t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) -} - -func testNotifyServiceUnitCRUD(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - topic, err := svc.CreateTopic(ctx, testTopicName) - require.NoError(t, err) - assert.NotNil(t, topic) - assert.Equal(t, testTopicName, topic.Name) - }) - - t.Run("ListTopics", func(t *testing.T) { - mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() - - topics, err := svc.ListTopics(ctx) - require.NoError(t, err) - assert.Len(t, topics, 1) - }) - - t.Run("DeleteTopic", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.NoError(t, err) - }) - - t.Run("Subscribe", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.NoError(t, err) - assert.NotNil(t, sub) - }) - - t.Run("ListSubscriptions", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() - - subs, err := svc.ListSubscriptions(ctx, topicID) - require.NoError(t, err) - assert.Len(t, subs, 1) - }) - - t.Run("Unsubscribe", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.Unsubscribe(ctx, subID) - require.NoError(t, err) - }) - - t.Run("Publish", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) -} - -func testNotifyServiceUnitRbacErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - type rbacCase struct { - name string - permission domain.Permission - resourceID string - invoke func(id string) error - } - - cases := []rbacCase{ - { - name: "CreateTopic_Unauthorized", - permission: domain.PermissionNotifyCreate, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.CreateTopic(ctx, "my-topic") - return err - }, - }, - { - name: "ListTopics_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.ListTopics(ctx) - return err - }, - }, - { - name: "DeleteTopic_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.DeleteTopic(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Subscribe_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") - return err - }, - }, - { - name: "ListSubscriptions_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) - return err - }, - }, - { - name: "Unsubscribe_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Unsubscribe(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Publish_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Publish(ctx, uuid.MustParse(id), "hello") - }, - }, - } - - authErr := errors.New(errors.Forbidden, "permission denied") - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() - err := c.invoke(c.resourceID) - require.Error(t, err) - assert.True(t, errors.Is(err, errors.Forbidden)) - }) - } -} - -func testNotifyServiceUnitRepoErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic_DuplicateName", func(t *testing.T) { - existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "already exists") - }) - - t.Run("CreateTopic_RepoError", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("DeleteTopic_NotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("DeleteTopic_RepoError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Subscribe_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - }) - - t.Run("Subscribe_RepoError", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Unsubscribe_NotFound", func(t *testing.T) { - subID := uuid.New() - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - }) - - t.Run("Unsubscribe_RepoError", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.ListSubscriptions(ctx, topicID) - require.Error(t, err) - }) - - t.Run("Publish_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - }) - - t.Run("Publish_SaveMessageError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) -} - -func testNotifyServiceUnitPublishErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolWebhook, - Endpoint: "http://localhost:9999/nonexistent", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) - - t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolQueue, - Endpoint: "not-a-valid-uuid", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) +package services_test + +import ( + "bytes" + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/services" + "github.com/poyrazk/thecloud/internal/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const testTopicName = "my-topic" + +func TestNotifyServiceUnit(t *testing.T) { + t.Run("CRUD", testNotifyServiceUnitCRUD) + t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) + t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) + t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) +} + +func testNotifyServiceUnitCRUD(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + topic, err := svc.CreateTopic(ctx, testTopicName) + require.NoError(t, err) + assert.NotNil(t, topic) + assert.Equal(t, testTopicName, topic.Name) + }) + + t.Run("ListTopics", func(t *testing.T) { + mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() + + topics, err := svc.ListTopics(ctx) + require.NoError(t, err) + assert.Len(t, topics, 1) + }) + + t.Run("DeleteTopic", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.NoError(t, err) + }) + + t.Run("Subscribe", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.NoError(t, err) + assert.NotNil(t, sub) + }) + + t.Run("ListSubscriptions", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() + + subs, err := svc.ListSubscriptions(ctx, topicID) + require.NoError(t, err) + assert.Len(t, subs, 1) + }) + + t.Run("Unsubscribe", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.Unsubscribe(ctx, subID) + require.NoError(t, err) + }) + + t.Run("Publish", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +func testNotifyServiceUnitRbacErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + type rbacCase struct { + name string + permission domain.Permission + resourceID string + invoke func(id string) error + } + + cases := []rbacCase{ + { + name: "CreateTopic_Unauthorized", + permission: domain.PermissionNotifyCreate, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.CreateTopic(ctx, "my-topic") + return err + }, + }, + { + name: "ListTopics_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.ListTopics(ctx) + return err + }, + }, + { + name: "DeleteTopic_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.DeleteTopic(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Subscribe_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") + return err + }, + }, + { + name: "ListSubscriptions_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) + return err + }, + }, + { + name: "Unsubscribe_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Unsubscribe(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Publish_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Publish(ctx, uuid.MustParse(id), "hello") + }, + }, + } + + authErr := errors.New(errors.Forbidden, "permission denied") + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() + err := c.invoke(c.resourceID) + require.Error(t, err) + assert.True(t, errors.Is(err, errors.Forbidden)) + }) + } +} + +func testNotifyServiceUnitRepoErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic_DuplicateName", func(t *testing.T) { + existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + }) + + t.Run("CreateTopic_RepoError", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("DeleteTopic_NotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteTopic_RepoError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Subscribe_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + }) + + t.Run("Subscribe_RepoError", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Unsubscribe_NotFound", func(t *testing.T) { + subID := uuid.New() + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + }) + + t.Run("Unsubscribe_RepoError", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.ListSubscriptions(ctx, topicID) + require.Error(t, err) + }) + + t.Run("Publish_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + }) + + t.Run("Publish_SaveMessageError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) +} + +func testNotifyServiceUnitPublishErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: "http://localhost:9999/nonexistent", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) + + t.Run("Publish_WebhookNon2xxStatus", func(t *testing.T) { + // Issue #338: webhook delivery must surface non-2xx HTTP responses. + var ( + mu sync.Mutex + received bool + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mu.Lock() + received = true + mu.Unlock() + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + var logBuf bytes.Buffer + var logMu sync.Mutex + capturingLogger := slog.New(slog.NewTextHandler(&lockedWriter{w: &logBuf, mu: &logMu}, &slog.HandlerOptions{Level: slog.LevelDebug})) + + mockRepo2 := new(MockNotifyRepo) + mockQueueSvc2 := new(MockQueueService) + mockEventSvc2 := new(MockEventService) + mockAuditSvc2 := new(MockAuditService) + rbacSvc2 := new(MockRBACService) + rbacSvc2.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc2 := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo2, + RBACSvc: rbacSvc2, + QueueSvc: mockQueueSvc2, + EventSvc: mockEventSvc2, + AuditSvc: mockAuditSvc2, + Logger: capturingLogger, + }) + + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: server.URL, + } + mockRepo2.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo2.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo2.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc2.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc2.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc2.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + + // Allow the async webhook goroutine a moment to fire and log. + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return received + }, 2*time.Second, 10*time.Millisecond, "webhook server never received request") + + require.Eventually(t, func() bool { + logMu.Lock() + defer logMu.Unlock() + return strings.Contains(logBuf.String(), "webhook delivery failed") && + strings.Contains(logBuf.String(), "status=500") + }, 2*time.Second, 10*time.Millisecond, "expected webhook delivery failure log with status=500") + }) + + t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolQueue, + Endpoint: "not-a-valid-uuid", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +type lockedWriter struct { + w *bytes.Buffer + mu *sync.Mutex +} + +func (l *lockedWriter) Write(p []byte) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + return l.w.Write(p) } \ No newline at end of file From d978ba67ebdf13ebaac7906ae9f4380d31c431a9 Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:18:24 +0300 Subject: [PATCH 4/4] fix(k8s): handle ignored repo.Update errors in provisioner Closes #310. Previously four repo.Update() calls in the kubeadm provisioner discarded their error with `_ =`, hiding silent state inconsistency when the database write failed: - Provision: returning success after Status flip to Running. - provisionControlPlane: persisting ControlPlaneIPs and KubeconfigEncrypted. - failCluster: persisting ClusterStatusFailed. Now: - Provision wraps and returns the persistence error so the caller knows state was not durably committed. - provisionControlPlane routes update failures through failCluster so the cluster is consistently marked Failed. - failCluster logs (rather than silently drops) a secondary update failure while preserving the original failure error for the caller. Adds a regression test ensuring failCluster surfaces the original error even when the status persistence fails. --- internal/repositories/k8s/provisioner.go | 19 +++++++-- .../k8s/provisioner_extra_test.go | 42 +++++++++++++------ 2 files changed, 45 insertions(+), 16 deletions(-) diff --git a/internal/repositories/k8s/provisioner.go b/internal/repositories/k8s/provisioner.go index 1c8a7db08..339baddd6 100644 --- a/internal/repositories/k8s/provisioner.go +++ b/internal/repositories/k8s/provisioner.go @@ -108,7 +108,9 @@ func (p *KubeadmProvisioner) Provision(ctx context.Context, cluster *domain.Clus } cluster.Status = domain.ClusterStatusRunning - _ = p.repo.Update(ctx, cluster) + if err := p.repo.Update(ctx, cluster); err != nil { + return fmt.Errorf("provisioning succeeded but failed to persist running status for cluster %s: %w", cluster.ID, err) + } return nil } @@ -185,7 +187,9 @@ func (p *KubeadmProvisioner) provisionControlPlane(ctx context.Context, cluster return p.failCluster(ctx, cluster, "control plane node failed to get an IP", nil) } cluster.ControlPlaneIPs = append(cluster.ControlPlaneIPs, masterIP) - _ = p.repo.Update(ctx, cluster) + if err := p.repo.Update(ctx, cluster); err != nil { + return p.failCluster(ctx, cluster, "failed to persist control plane IP", err) + } // Wait for kubeadm init to finish and kubeconfig to be available via SSH p.logger.Info("waiting for kubeadm init to complete", "ip", masterIP) @@ -202,7 +206,9 @@ func (p *KubeadmProvisioner) provisionControlPlane(ctx context.Context, cluster } cluster.KubeconfigEncrypted = encryptedKubeconfig - _ = p.repo.Update(ctx, cluster) + if err := p.repo.Update(ctx, cluster); err != nil { + return p.failCluster(ctx, cluster, "failed to persist encrypted kubeconfig", err) + } return nil } @@ -462,7 +468,12 @@ func (p *KubeadmProvisioner) Deprovision(ctx context.Context, cluster *domain.Cl func (p *KubeadmProvisioner) failCluster(ctx context.Context, cluster *domain.Cluster, msg string, err error) error { cluster.Status = domain.ClusterStatusFailed - _ = p.repo.Update(ctx, cluster) + if updateErr := p.repo.Update(ctx, cluster); updateErr != nil { + // We're already returning a failure; log so the persistence + // gap is visible in operations rather than silently dropped. + p.logger.Error("failed to persist failed cluster status", + "cluster_id", cluster.ID, "error", updateErr, "original_error", err) + } p.logger.Error(msg, "cluster_id", cluster.ID, "error", err) return fmt.Errorf("%s: %w", msg, err) } diff --git a/internal/repositories/k8s/provisioner_extra_test.go b/internal/repositories/k8s/provisioner_extra_test.go index 5c89ed11a..2cdd1b519 100644 --- a/internal/repositories/k8s/provisioner_extra_test.go +++ b/internal/repositories/k8s/provisioner_extra_test.go @@ -106,22 +106,40 @@ func TestCreateBackup_Extra(t *testing.T) { func TestFailCluster(t *testing.T) { ctx := context.Background() - repo := new(mockClusterRepo) logger := slog.New(slog.NewTextHandler(io.Discard, nil)) - cluster := &domain.Cluster{ID: uuid.New()} - p := &KubeadmProvisioner{ - repo: repo, - logger: logger, - } + t.Run("UpdateSucceeds", func(t *testing.T) { + repo := new(mockClusterRepo) + cluster := &domain.Cluster{ID: uuid.New()} + p := &KubeadmProvisioner{repo: repo, logger: logger} + + repo.On("Update", ctx, mock.MatchedBy(func(c *domain.Cluster) bool { + return c.Status == domain.ClusterStatusFailed + })).Return(nil).Once() - repo.On("Update", ctx, mock.MatchedBy(func(c *domain.Cluster) bool { - return c.Status == domain.ClusterStatusFailed - })).Return(nil).Once() + err := p.failCluster(ctx, cluster, "test error", fmt.Errorf("underlying")) + require.Error(t, err) + assert.Contains(t, err.Error(), "test error") + repo.AssertExpectations(t) + }) - err := p.failCluster(ctx, cluster, "test error", fmt.Errorf("underlying")) - require.Error(t, err) - assert.Contains(t, err.Error(), "test error") + t.Run("UpdateFails_StillReturnsOriginalError", func(t *testing.T) { + repo := new(mockClusterRepo) + cluster := &domain.Cluster{ID: uuid.New()} + p := &KubeadmProvisioner{repo: repo, logger: logger} + + repo.On("Update", ctx, mock.Anything).Return(fmt.Errorf("db down")).Once() + + err := p.failCluster(ctx, cluster, "test error", fmt.Errorf("underlying")) + require.Error(t, err) + // Original failure must surface, not the persistence error. + assert.Contains(t, err.Error(), "test error") + assert.Contains(t, err.Error(), "underlying") + assert.NotContains(t, err.Error(), "db down") + // In-memory status still flipped to Failed even if persistence failed. + assert.Equal(t, domain.ClusterStatusFailed, cluster.Status) + repo.AssertExpectations(t) + }) } func TestDeprovision(t *testing.T) {