From 507924a8d53ac26efef1e16b7b850ad77ad330ce Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:32:08 +0300 Subject: [PATCH 1/4] fix(security): stop logging plaintext password reset tokens The password reset flow logged the plaintext token alongside the user's email at Debug level so it could be picked up during MVP testing. Anyone with log access could mint a password reset and take over the account. Replace the Debug log with an Info entry that records only the user_id and reset-token row id. The plaintext token must reach the user via email; once an EmailService is wired in it should be the sole consumer. Closes #293 --- internal/core/services/password_reset.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/internal/core/services/password_reset.go b/internal/core/services/password_reset.go index 725a84b0d..c534661ba 100644 --- a/internal/core/services/password_reset.go +++ b/internal/core/services/password_reset.go @@ -75,10 +75,9 @@ func (s *PasswordResetService) RequestReset(ctx context.Context, email string) e return err } - // Note: EmailService integration is pending. - // For MVP/Demo: Log the token so we can test it manually. - // Future: Inject and use EmailService here. - s.logger.Debug("password reset token", "email", email, "token", token) + // TODO: deliver `token` via an injected EmailService once available. + // Never log or persist the plaintext token — its only safe destination is the user. + s.logger.Info("password reset token issued", "user_id", user.ID, "token_id", resetToken.ID) return nil } From 8e802aa4f46cadcb0381fbc8c914e3ae07f53cef Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Tue, 28 Apr 2026 15:35:12 +0300 Subject: [PATCH 2/4] fix(security): sanitize Content-Disposition filename to prevent header injection Object keys flowed straight into the Content-Disposition header via fmt.Sprintf("attachment; filename=%s", key), so any key containing CRLF, double quotes, or backslashes could split the response or inject arbitrary headers. Path-bearing keys also leaked the full bucket path to the client as the suggested filename. Replace both call sites (authenticated download and presigned download) with a single helper, contentDispositionAttachment, that: - reduces the key to its basename via path.Base - emits an ASCII-only `filename="..."` fallback with control bytes, non-ASCII bytes, quotes, and backslashes mapped to `_` - emits the full Unicode basename in `filename*=UTF-8''...` per RFC 5987 with proper attr-char percent-encoding - guarantees CR/LF can never reach the wire Add table-driven unit tests covering response-splitting, quote/backslash injection, non-ASCII names, nested keys, and empty keys. Closes #225, #226 --- internal/handlers/storage_handler.go | 79 ++++++++++++++++- ...torage_handler_content_disposition_test.go | 86 +++++++++++++++++++ 2 files changed, 163 insertions(+), 2 deletions(-) create mode 100644 internal/handlers/storage_handler_content_disposition_test.go diff --git a/internal/handlers/storage_handler.go b/internal/handlers/storage_handler.go index 31151aa71..42f810e89 100644 --- a/internal/handlers/storage_handler.go +++ b/internal/handlers/storage_handler.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "path" "strconv" "strings" "time" @@ -38,6 +39,80 @@ const ( headerContentSha256 = "X-Content-Sha256" ) +// contentDispositionAttachment builds a safe `Content-Disposition: attachment` +// header for a stored object. +// +// Object keys can contain path segments, non-ASCII characters, control +// characters, quotes, backslashes, or CRLF that — if interpolated naively — +// would either corrupt the header (HTTP response splitting) or let an attacker +// inject additional headers. The output therefore emits two parameters per +// RFC 6266: +// +// - `filename="..."` ASCII-only fallback for legacy clients. All bytes +// outside the safe printable range and the two +// characters that are special inside a quoted-string +// (`"` and `\`) are replaced with `_`. +// - `filename*=UTF-8''…` RFC 5987 percent-encoded form preserving the +// original Unicode basename for modern clients. +// +// `path.Base` is used to discard any path segments embedded in the key. If the +// resulting name is empty we fall back to "download". +func contentDispositionAttachment(key string) string { + name := path.Base(key) + if name == "." || name == "/" || name == "" { + name = "download" + } + + return fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + asciiFilenameFallback(name), rfc5987Encode(name)) +} + +// asciiFilenameFallback returns an ASCII-only sanitized copy of name suitable +// for the legacy `filename=` parameter. Any byte that is a control character +// (<0x20 or 0x7f), non-ASCII (>=0x80), or special inside a quoted-string is +// replaced with `_` so the value can be safely wrapped in double quotes. +func asciiFilenameFallback(name string) string { + out := make([]byte, 0, len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + switch { + case c < 0x20, c == 0x7f, c >= 0x80, c == '"', c == '\\': + out = append(out, '_') + default: + out = append(out, c) + } + } + if len(out) == 0 { + return "download" + } + return string(out) +} + +// rfc5987Encode percent-encodes a value per RFC 5987 attr-char rules so it can +// be safely placed in a `filename*` parameter. +func rfc5987Encode(s string) string { + const hex = "0123456789ABCDEF" + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case c >= 'A' && c <= 'Z', + c >= 'a' && c <= 'z', + c >= '0' && c <= '9', + c == '!', c == '#', c == '$', c == '&', c == '+', + c == '-', c == '.', c == '^', c == '_', c == '`', + c == '|', c == '~': + b.WriteByte(c) + default: + b.WriteByte('%') + b.WriteByte(hex[c>>4]) + b.WriteByte(hex[c&0x0f]) + } + } + return b.String() +} + // Upload uploads an object to a bucket // @Summary Upload an object // @Description Uploads a file/object to the specified bucket and key @@ -105,7 +180,7 @@ func (h *StorageHandler) Download(c *gin.Context) { defer func() { _ = reader.Close() }() // Set headers - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) @@ -461,7 +536,7 @@ func (h *StorageHandler) ServePresignedDownload(c *gin.Context) { } defer func() { _ = reader.Close() }() - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) _, _ = io.Copy(c.Writer, reader) diff --git a/internal/handlers/storage_handler_content_disposition_test.go b/internal/handlers/storage_handler_content_disposition_test.go new file mode 100644 index 000000000..64a4aec1d --- /dev/null +++ b/internal/handlers/storage_handler_content_disposition_test.go @@ -0,0 +1,86 @@ +package httphandlers + +import ( + "strings" + "testing" +) + +func TestContentDispositionAttachment(t *testing.T) { + cases := []struct { + name string + key string + wantFilename string + wantFilenameStar string + }{ + { + name: "simple ascii filename", + key: "report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "nested key uses basename", + key: "exports/2026/q1/report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "CRLF response splitting attempt is sanitized", + key: "evil\r\nSet-Cookie: pwned=1", + wantFilename: `filename="evil__Set-Cookie: pwned=1"`, + wantFilenameStar: `filename*=UTF-8''evil%0D%0ASet-Cookie%3A%20pwned%3D1`, + }, + { + name: "embedded quote and backslash are sanitized", + key: `bad"name\file.txt`, + wantFilename: `filename="bad_name_file.txt"`, + wantFilenameStar: `filename*=UTF-8''bad%22name%5Cfile.txt`, + }, + { + name: "non-ASCII falls back in legacy filename, preserved in filename*", + key: "résumé.pdf", + wantFilename: `filename="r__sum__.pdf"`, // 2 bytes per accented char both replaced + wantFilenameStar: `filename*=UTF-8''r%C3%A9sum%C3%A9.pdf`, + }, + { + name: "empty key falls back to download", + key: "", + wantFilename: `filename="download"`, + wantFilenameStar: `filename*=UTF-8''download`, + }, + { + name: "trailing slash falls back to download", + key: "folder/", + wantFilename: `filename="folder"`, // path.Base normalizes + wantFilenameStar: `filename*=UTF-8''folder`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := contentDispositionAttachment(tc.key) + if !strings.HasPrefix(got, "attachment; ") { + t.Errorf("missing attachment prefix: %q", got) + } + if !strings.Contains(got, tc.wantFilename) { + t.Errorf("missing legacy filename param\n got: %q\n want: %q", got, tc.wantFilename) + } + if !strings.Contains(got, tc.wantFilenameStar) { + t.Errorf("missing filename* param\n got: %q\n want: %q", got, tc.wantFilenameStar) + } + if strings.ContainsAny(got, "\r\n") { + t.Errorf("output must not contain CR/LF: %q", got) + } + }) + } +} + +func TestContentDispositionAttachment_NoCRLFEverEscapes(t *testing.T) { + for _, c := range []byte{'\r', '\n', 0, 0x1f, 0x7f} { + key := "x" + string(c) + "y" + got := contentDispositionAttachment(key) + if strings.ContainsAny(got, "\r\n") { + t.Fatalf("control byte 0x%02x leaked into header: %q", c, got) + } + } +} From 2cfa64c292a070145832553b9860d7d730a923d2 Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Wed, 29 Apr 2026 22:30:36 +0300 Subject: [PATCH 3/4] fix(notify): surface non-2xx webhook delivery responses Webhook delivery in NotifyService.deliverToWebhook used http.DefaultClient and only inspected transport errors, so 4xx/5xx responses were silently treated as successful deliveries. Subscribers returning failures (bad endpoints, transient 5xx) had messages dropped without any log signal. Check resp.StatusCode and emit a structured warn log on >=400 responses, add a 15s timeout via a dedicated http.Client, drain the response body before close, and stop discarding the error from NewRequestWithContext. Closes #338 Co-Authored-By: Claude Opus 4.7 (1M context) --- internal/core/services/notify.go | 613 ++++++------- internal/core/services/notify_unit_test.go | 948 +++++++++++---------- 2 files changed, 833 insertions(+), 728 deletions(-) diff --git a/internal/core/services/notify.go b/internal/core/services/notify.go index babfc7aa3..517121499 100644 --- a/internal/core/services/notify.go +++ b/internal/core/services/notify.go @@ -1,297 +1,316 @@ -// Package services implements core business workflows. -package services - -import ( - "bytes" - "context" - "fmt" - "log/slog" - "net/http" - "time" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/ports" -) - -// NotifyServiceParams defines the dependencies for NotifyService. -type NotifyServiceParams struct { - Repo ports.NotifyRepository - RBACSvc ports.RBACService - QueueSvc ports.QueueService - EventSvc ports.EventService - AuditSvc ports.AuditService - Logger *slog.Logger -} - -// NotifyService manages topics, subscriptions, and message delivery. -type NotifyService struct { - repo ports.NotifyRepository - rbacSvc ports.RBACService - queueSvc ports.QueueService - eventSvc ports.EventService - auditSvc ports.AuditService - logger *slog.Logger -} - -// NewNotifyService constructs a NotifyService with its dependencies. -func NewNotifyService(params NotifyServiceParams) ports.NotifyService { - return &NotifyService{ - repo: params.Repo, - rbacSvc: params.RBACSvc, - queueSvc: params.QueueSvc, - eventSvc: params.EventSvc, - auditSvc: params.AuditSvc, - logger: params.Logger, - } -} - -func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { - return nil, err - } - - existing, _ := s.repo.GetTopicByName(ctx, name, userID) - if existing != nil { - return nil, fmt.Errorf("topic with name %s already exists", name) - } - - id := uuid.New() - topic := &domain.Topic{ - ID: id, - UserID: userID, - Name: name, - ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateTopic(ctx, topic); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return topic, nil -} - -func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { - return nil, err - } - - return s.repo.ListTopics(ctx, userID) -} - -func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, id, userID) - if err != nil { - return err - } - if topic == nil { - return fmt.Errorf("topic not found") - } - - if err := s.repo.DeleteTopic(ctx, id); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return nil -} - -func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return nil, err - } - - // Verify topic exists and belongs to user - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - sub := &domain.Subscription{ - ID: uuid.New(), - UserID: userID, - TopicID: topic.ID, - Protocol: protocol, - Endpoint: endpoint, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateSubscription(ctx, sub); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": topicID.String(), - "protocol": protocol, - "endpoint": endpoint, - }) - - return sub, nil -} - -func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { - return nil, err - } - - // Verify topic ownership - _, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - return s.repo.ListSubscriptions(ctx, topicID) -} - -func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) - if err != nil { - return err - } - - if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": sub.TopicID.String(), - }) - - return nil -} - -func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return err - } - - msg := &domain.NotifyMessage{ - ID: uuid.New(), - TopicID: topic.ID, - Body: body, - CreatedAt: time.Now(), - } - - if err := s.repo.SaveMessage(ctx, msg); err != nil { - return err - } - - subs, err := s.repo.ListSubscriptions(ctx, topicID) - if err != nil { - return err - } - - // Delivery logic - for _, sub := range subs { - // Create a background context for async delivery to avoid request cancellation - // but keep it separate for each subscriber to avoid shared timeout issues - go func(c context.Context, sub *domain.Subscription) { - deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) - defer cancel() - - // Carry over potential trace IDs or other metadata if needed - // (Simplified for now, but avoids using request-scoped ctx) - s.deliver(deliveryCtx, sub, body) - }(ctx, sub) - } - - if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { - s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) - } - - if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ - "message_id": msg.ID.String(), - }); err != nil { - s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) - } - - return nil -} - -func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { - switch sub.Protocol { - case domain.ProtocolQueue: - s.deliverToQueue(ctx, sub, body) - case domain.ProtocolWebhook: - s.deliverToWebhook(ctx, sub, body) - } -} - -func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { - // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. - // For now let's assume endpoint is the Queue UUID string. - qID, err := uuid.Parse(sub.Endpoint) - if err != nil { - s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) - return - } - // We need to bypass user check or use sub.UserID context - deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) - if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { - s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) - } -} - -func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { - req, _ := http.NewRequestWithContext(ctx, "POST", sub.Endpoint, bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) - return - } - _ = resp.Body.Close() -} +// Package services implements core business workflows. +package services + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/ports" +) + +var webhookHTTPClient = &http.Client{Timeout: 15 * time.Second} + +// NotifyServiceParams defines the dependencies for NotifyService. +type NotifyServiceParams struct { + Repo ports.NotifyRepository + RBACSvc ports.RBACService + QueueSvc ports.QueueService + EventSvc ports.EventService + AuditSvc ports.AuditService + Logger *slog.Logger +} + +// NotifyService manages topics, subscriptions, and message delivery. +type NotifyService struct { + repo ports.NotifyRepository + rbacSvc ports.RBACService + queueSvc ports.QueueService + eventSvc ports.EventService + auditSvc ports.AuditService + logger *slog.Logger +} + +// NewNotifyService constructs a NotifyService with its dependencies. +func NewNotifyService(params NotifyServiceParams) ports.NotifyService { + return &NotifyService{ + repo: params.Repo, + rbacSvc: params.RBACSvc, + queueSvc: params.QueueSvc, + eventSvc: params.EventSvc, + auditSvc: params.AuditSvc, + logger: params.Logger, + } +} + +func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { + return nil, err + } + + existing, _ := s.repo.GetTopicByName(ctx, name, userID) + if existing != nil { + return nil, fmt.Errorf("topic with name %s already exists", name) + } + + id := uuid.New() + topic := &domain.Topic{ + ID: id, + UserID: userID, + Name: name, + ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateTopic(ctx, topic); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return topic, nil +} + +func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { + return nil, err + } + + return s.repo.ListTopics(ctx, userID) +} + +func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, id, userID) + if err != nil { + return err + } + if topic == nil { + return fmt.Errorf("topic not found") + } + + if err := s.repo.DeleteTopic(ctx, id); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return nil +} + +func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return nil, err + } + + // Verify topic exists and belongs to user + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topic.ID, + Protocol: protocol, + Endpoint: endpoint, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateSubscription(ctx, sub); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": topicID.String(), + "protocol": protocol, + "endpoint": endpoint, + }) + + return sub, nil +} + +func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { + return nil, err + } + + // Verify topic ownership + _, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + return s.repo.ListSubscriptions(ctx, topicID) +} + +func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) + if err != nil { + return err + } + + if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": sub.TopicID.String(), + }) + + return nil +} + +func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return err + } + + msg := &domain.NotifyMessage{ + ID: uuid.New(), + TopicID: topic.ID, + Body: body, + CreatedAt: time.Now(), + } + + if err := s.repo.SaveMessage(ctx, msg); err != nil { + return err + } + + subs, err := s.repo.ListSubscriptions(ctx, topicID) + if err != nil { + return err + } + + // Delivery logic + for _, sub := range subs { + // Create a background context for async delivery to avoid request cancellation + // but keep it separate for each subscriber to avoid shared timeout issues + go func(c context.Context, sub *domain.Subscription) { + deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) + defer cancel() + + // Carry over potential trace IDs or other metadata if needed + // (Simplified for now, but avoids using request-scoped ctx) + s.deliver(deliveryCtx, sub, body) + }(ctx, sub) + } + + if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { + s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) + } + + if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ + "message_id": msg.ID.String(), + }); err != nil { + s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) + } + + return nil +} + +func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { + switch sub.Protocol { + case domain.ProtocolQueue: + s.deliverToQueue(ctx, sub, body) + case domain.ProtocolWebhook: + s.deliverToWebhook(ctx, sub, body) + } +} + +func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { + // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. + // For now let's assume endpoint is the Queue UUID string. + qID, err := uuid.Parse(sub.Endpoint) + if err != nil { + s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) + return + } + // We need to bypass user check or use sub.UserID context + deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) + if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { + s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) + } +} + +func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, sub.Endpoint, bytes.NewBufferString(body)) + if err != nil { + s.logger.Warn("failed to build webhook request", "endpoint", sub.Endpoint, "error", err) + return + } + req.Header.Set("Content-Type", "application/json") + + resp, err := webhookHTTPClient.Do(req) + if err != nil { + s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) + return + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + + if resp.StatusCode >= 400 { + s.logger.Warn("webhook delivery failed", + "endpoint", sub.Endpoint, + "subscription_id", sub.ID, + "status", resp.StatusCode) + return + } +} diff --git a/internal/core/services/notify_unit_test.go b/internal/core/services/notify_unit_test.go index 20024ce81..424b3f450 100644 --- a/internal/core/services/notify_unit_test.go +++ b/internal/core/services/notify_unit_test.go @@ -1,432 +1,518 @@ -package services_test - -import ( - "context" - "fmt" - "log/slog" - "testing" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/services" - "github.com/poyrazk/thecloud/internal/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const testTopicName = "my-topic" - -func TestNotifyServiceUnit(t *testing.T) { - t.Run("CRUD", testNotifyServiceUnitCRUD) - t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) - t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) - t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) -} - -func testNotifyServiceUnitCRUD(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - topic, err := svc.CreateTopic(ctx, testTopicName) - require.NoError(t, err) - assert.NotNil(t, topic) - assert.Equal(t, testTopicName, topic.Name) - }) - - t.Run("ListTopics", func(t *testing.T) { - mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() - - topics, err := svc.ListTopics(ctx) - require.NoError(t, err) - assert.Len(t, topics, 1) - }) - - t.Run("DeleteTopic", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.NoError(t, err) - }) - - t.Run("Subscribe", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.NoError(t, err) - assert.NotNil(t, sub) - }) - - t.Run("ListSubscriptions", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() - - subs, err := svc.ListSubscriptions(ctx, topicID) - require.NoError(t, err) - assert.Len(t, subs, 1) - }) - - t.Run("Unsubscribe", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.Unsubscribe(ctx, subID) - require.NoError(t, err) - }) - - t.Run("Publish", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) -} - -func testNotifyServiceUnitRbacErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - type rbacCase struct { - name string - permission domain.Permission - resourceID string - invoke func(id string) error - } - - cases := []rbacCase{ - { - name: "CreateTopic_Unauthorized", - permission: domain.PermissionNotifyCreate, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.CreateTopic(ctx, "my-topic") - return err - }, - }, - { - name: "ListTopics_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.ListTopics(ctx) - return err - }, - }, - { - name: "DeleteTopic_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.DeleteTopic(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Subscribe_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") - return err - }, - }, - { - name: "ListSubscriptions_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) - return err - }, - }, - { - name: "Unsubscribe_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Unsubscribe(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Publish_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Publish(ctx, uuid.MustParse(id), "hello") - }, - }, - } - - authErr := errors.New(errors.Forbidden, "permission denied") - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() - err := c.invoke(c.resourceID) - require.Error(t, err) - assert.True(t, errors.Is(err, errors.Forbidden)) - }) - } -} - -func testNotifyServiceUnitRepoErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic_DuplicateName", func(t *testing.T) { - existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "already exists") - }) - - t.Run("CreateTopic_RepoError", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("DeleteTopic_NotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("DeleteTopic_RepoError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Subscribe_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - }) - - t.Run("Subscribe_RepoError", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Unsubscribe_NotFound", func(t *testing.T) { - subID := uuid.New() - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - }) - - t.Run("Unsubscribe_RepoError", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.ListSubscriptions(ctx, topicID) - require.Error(t, err) - }) - - t.Run("Publish_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - }) - - t.Run("Publish_SaveMessageError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) -} - -func testNotifyServiceUnitPublishErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolWebhook, - Endpoint: "http://localhost:9999/nonexistent", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) - - t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolQueue, - Endpoint: "not-a-valid-uuid", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) +package services_test + +import ( + "bytes" + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/services" + "github.com/poyrazk/thecloud/internal/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const testTopicName = "my-topic" + +func TestNotifyServiceUnit(t *testing.T) { + t.Run("CRUD", testNotifyServiceUnitCRUD) + t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) + t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) + t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) +} + +func testNotifyServiceUnitCRUD(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + topic, err := svc.CreateTopic(ctx, testTopicName) + require.NoError(t, err) + assert.NotNil(t, topic) + assert.Equal(t, testTopicName, topic.Name) + }) + + t.Run("ListTopics", func(t *testing.T) { + mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() + + topics, err := svc.ListTopics(ctx) + require.NoError(t, err) + assert.Len(t, topics, 1) + }) + + t.Run("DeleteTopic", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.NoError(t, err) + }) + + t.Run("Subscribe", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.NoError(t, err) + assert.NotNil(t, sub) + }) + + t.Run("ListSubscriptions", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() + + subs, err := svc.ListSubscriptions(ctx, topicID) + require.NoError(t, err) + assert.Len(t, subs, 1) + }) + + t.Run("Unsubscribe", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.Unsubscribe(ctx, subID) + require.NoError(t, err) + }) + + t.Run("Publish", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +func testNotifyServiceUnitRbacErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + type rbacCase struct { + name string + permission domain.Permission + resourceID string + invoke func(id string) error + } + + cases := []rbacCase{ + { + name: "CreateTopic_Unauthorized", + permission: domain.PermissionNotifyCreate, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.CreateTopic(ctx, "my-topic") + return err + }, + }, + { + name: "ListTopics_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.ListTopics(ctx) + return err + }, + }, + { + name: "DeleteTopic_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.DeleteTopic(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Subscribe_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") + return err + }, + }, + { + name: "ListSubscriptions_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) + return err + }, + }, + { + name: "Unsubscribe_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Unsubscribe(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Publish_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Publish(ctx, uuid.MustParse(id), "hello") + }, + }, + } + + authErr := errors.New(errors.Forbidden, "permission denied") + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() + err := c.invoke(c.resourceID) + require.Error(t, err) + assert.True(t, errors.Is(err, errors.Forbidden)) + }) + } +} + +func testNotifyServiceUnitRepoErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic_DuplicateName", func(t *testing.T) { + existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + }) + + t.Run("CreateTopic_RepoError", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("DeleteTopic_NotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteTopic_RepoError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Subscribe_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + }) + + t.Run("Subscribe_RepoError", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Unsubscribe_NotFound", func(t *testing.T) { + subID := uuid.New() + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + }) + + t.Run("Unsubscribe_RepoError", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.ListSubscriptions(ctx, topicID) + require.Error(t, err) + }) + + t.Run("Publish_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + }) + + t.Run("Publish_SaveMessageError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) +} + +func testNotifyServiceUnitPublishErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: "http://localhost:9999/nonexistent", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) + + t.Run("Publish_WebhookNon2xxStatus", func(t *testing.T) { + // Issue #338: webhook delivery must surface non-2xx HTTP responses. + var ( + mu sync.Mutex + received bool + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mu.Lock() + received = true + mu.Unlock() + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + var logBuf bytes.Buffer + var logMu sync.Mutex + capturingLogger := slog.New(slog.NewTextHandler(&lockedWriter{w: &logBuf, mu: &logMu}, &slog.HandlerOptions{Level: slog.LevelDebug})) + + mockRepo2 := new(MockNotifyRepo) + mockQueueSvc2 := new(MockQueueService) + mockEventSvc2 := new(MockEventService) + mockAuditSvc2 := new(MockAuditService) + rbacSvc2 := new(MockRBACService) + rbacSvc2.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc2 := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo2, + RBACSvc: rbacSvc2, + QueueSvc: mockQueueSvc2, + EventSvc: mockEventSvc2, + AuditSvc: mockAuditSvc2, + Logger: capturingLogger, + }) + + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: server.URL, + } + mockRepo2.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo2.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo2.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc2.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc2.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc2.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + + // Allow the async webhook goroutine a moment to fire and log. + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return received + }, 2*time.Second, 10*time.Millisecond, "webhook server never received request") + + require.Eventually(t, func() bool { + logMu.Lock() + defer logMu.Unlock() + return strings.Contains(logBuf.String(), "webhook delivery failed") && + strings.Contains(logBuf.String(), "status=500") + }, 2*time.Second, 10*time.Millisecond, "expected webhook delivery failure log with status=500") + }) + + t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolQueue, + Endpoint: "not-a-valid-uuid", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +type lockedWriter struct { + w *bytes.Buffer + mu *sync.Mutex +} + +func (l *lockedWriter) Write(p []byte) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + return l.w.Write(p) } \ No newline at end of file From 05215036d456083c6f429c3b694d8ab1e75963be Mon Sep 17 00:00:00 2001 From: jackthepunished <107313375+jackthepunished@users.noreply.github.com> Date: Wed, 29 Apr 2026 23:16:42 +0300 Subject: [PATCH 4/4] fix(#299): scope pipelines, queues, and caches by tenant Three pipeline_repo SQL statements (UpdateBuild, ListBuildSteps, ListBuildLogs) still filtered by user_id even though their port declares the second parameter as tenantID. Result: a teammate viewing another user's build saw no steps or logs, and the worker's UpdateBuild silently scoped to the wrong column. The pipeline worker also passed job.UserID where the repo expects tenantID, so builds loaded only when user_id happened to equal tenant_id. - Rewrite the three SQL statements to scope by builds.tenant_id. - Add TenantID to BuildJob; populate it in createAndQueueBuild and consume it in the worker (loadBuildAndPipeline + ctx enrichment). - New migration 108 drops UNIQUE(user_id, name) on pipelines, caches, and queues and replaces it with UNIQUE(tenant_id, name), which is what the service-layer name checks have always assumed. A defensive CTE renames pre-existing collisions before the constraint flip so the migration is safe on populated databases. UserID stays on rows for attribution; only the scope of queries and uniqueness changes, matching the issue's "tenant ownership with user attribution" prescription. --- internal/core/domain/jobs.go | 1 + internal/core/services/pipeline.go | 2 + ...tenant_scoped_resource_uniqueness.down.sql | 9 +++ ...8_tenant_scoped_resource_uniqueness.up.sql | 69 +++++++++++++++++++ .../repositories/postgres/pipeline_repo.go | 16 ++--- internal/workers/pipeline_worker.go | 5 +- internal/workers/pipeline_worker_test.go | 9 +-- 7 files changed, 97 insertions(+), 14 deletions(-) create mode 100644 internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql create mode 100644 internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql diff --git a/internal/core/domain/jobs.go b/internal/core/domain/jobs.go index 525c0316b..ad59358c3 100644 --- a/internal/core/domain/jobs.go +++ b/internal/core/domain/jobs.go @@ -49,6 +49,7 @@ type BuildJob struct { BuildID uuid.UUID `json:"build_id"` PipelineID uuid.UUID `json:"pipeline_id"` UserID uuid.UUID `json:"user_id"` + TenantID uuid.UUID `json:"tenant_id"` CommitHash string `json:"commit_hash,omitempty"` Trigger BuildTriggerType `json:"trigger"` } diff --git a/internal/core/services/pipeline.go b/internal/core/services/pipeline.go index 90d596afa..037398a04 100644 --- a/internal/core/services/pipeline.go +++ b/internal/core/services/pipeline.go @@ -237,6 +237,7 @@ func (s *PipelineService) TriggerBuildWebhook(ctx context.Context, opts ports.We } webhookCtx := appcontext.WithUserID(ctx, pipeline.UserID) + webhookCtx = appcontext.WithTenantID(webhookCtx, pipeline.TenantID) return s.createAndQueueBuild(webhookCtx, pipeline, commitHash, domain.BuildTriggerWebhook) } @@ -263,6 +264,7 @@ func (s *PipelineService) createAndQueueBuild(ctx context.Context, pipeline *dom BuildID: build.ID, PipelineID: build.PipelineID, UserID: build.UserID, + TenantID: build.TenantID, CommitHash: build.CommitHash, Trigger: build.TriggerType, } diff --git a/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql new file mode 100644 index 000000000..741e1fde7 --- /dev/null +++ b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql @@ -0,0 +1,9 @@ +-- +goose Down + +ALTER TABLE pipelines DROP CONSTRAINT IF EXISTS pipelines_tenant_id_name_key; +ALTER TABLE caches DROP CONSTRAINT IF EXISTS caches_tenant_id_name_key; +ALTER TABLE queues DROP CONSTRAINT IF EXISTS queues_tenant_id_name_key; + +ALTER TABLE pipelines ADD CONSTRAINT pipelines_user_id_name_key UNIQUE (user_id, name); +ALTER TABLE caches ADD CONSTRAINT caches_user_id_name_key UNIQUE (user_id, name); +ALTER TABLE queues ADD CONSTRAINT queues_user_id_name_key UNIQUE (user_id, name); diff --git a/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql new file mode 100644 index 000000000..0601ae318 --- /dev/null +++ b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql @@ -0,0 +1,69 @@ +-- +goose Up + +-- Pipelines, caches, and queues were created with UNIQUE(user_id, name), which +-- allowed two users in the same tenant to register identically named +-- resources. The application layer scopes lookups by tenant_id, so the +-- service-layer name check sees only one row while the duplicate slips past +-- the user-scoped DB constraint. Replace with tenant-scoped uniqueness. + +-- Disambiguate any preexisting duplicates within the same tenant by appending +-- a short user-id suffix to all but the oldest collider. Required so the new +-- UNIQUE(tenant_id, name) constraint can be added without conflicts. +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM pipelines + WHERE tenant_id IS NOT NULL +) +UPDATE pipelines p +SET name = p.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE p.id = r.id AND r.rn > 1; + +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM caches + WHERE tenant_id IS NOT NULL +) +UPDATE caches c +SET name = c.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE c.id = r.id AND r.rn > 1; + +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM queues + WHERE tenant_id IS NOT NULL +) +UPDATE queues q +SET name = q.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE q.id = r.id AND r.rn > 1; + +ALTER TABLE pipelines DROP CONSTRAINT IF EXISTS pipelines_user_id_name_key; +ALTER TABLE caches DROP CONSTRAINT IF EXISTS caches_user_id_name_key; +ALTER TABLE queues DROP CONSTRAINT IF EXISTS queues_user_id_name_key; + +ALTER TABLE pipelines ADD CONSTRAINT pipelines_tenant_id_name_key UNIQUE (tenant_id, name); +ALTER TABLE caches ADD CONSTRAINT caches_tenant_id_name_key UNIQUE (tenant_id, name); +ALTER TABLE queues ADD CONSTRAINT queues_tenant_id_name_key UNIQUE (tenant_id, name); diff --git a/internal/repositories/postgres/pipeline_repo.go b/internal/repositories/postgres/pipeline_repo.go index 258618c09..dfe557fa1 100644 --- a/internal/repositories/postgres/pipeline_repo.go +++ b/internal/repositories/postgres/pipeline_repo.go @@ -168,9 +168,9 @@ func (r *PipelineRepository) UpdateBuild(ctx context.Context, build *domain.Buil started_at = $2, finished_at = $3, updated_at = NOW() - WHERE id = $4 AND user_id = $5 + WHERE id = $4 AND tenant_id = $5 ` - _, err := r.db.Exec(ctx, query, build.Status, build.StartedAt, build.FinishedAt, build.ID, build.UserID) + _, err := r.db.Exec(ctx, query, build.Status, build.StartedAt, build.FinishedAt, build.ID, build.TenantID) return err } @@ -195,15 +195,15 @@ func (r *PipelineRepository) CreateBuildStep(ctx context.Context, step *domain.B return err } -func (r *PipelineRepository) ListBuildSteps(ctx context.Context, buildID, userID uuid.UUID) ([]*domain.BuildStep, error) { +func (r *PipelineRepository) ListBuildSteps(ctx context.Context, buildID, tenantID uuid.UUID) ([]*domain.BuildStep, error) { query := ` SELECT s.id, s.build_id, s.name, s.image, s.commands, s.status, s.exit_code, s.started_at, s.finished_at, s.created_at, s.updated_at FROM build_steps s INNER JOIN builds b ON b.id = s.build_id - WHERE s.build_id = $1 AND b.user_id = $2 + WHERE s.build_id = $1 AND b.tenant_id = $2 ORDER BY s.created_at ASC ` - rows, err := r.db.Query(ctx, query, buildID, userID) + rows, err := r.db.Query(ctx, query, buildID, tenantID) if err != nil { return nil, err } @@ -233,7 +233,7 @@ func (r *PipelineRepository) AppendBuildLog(ctx context.Context, log *domain.Bui return err } -func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, userID uuid.UUID, limit int) ([]*domain.BuildLog, error) { +func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, tenantID uuid.UUID, limit int) ([]*domain.BuildLog, error) { if limit <= 0 { limit = 200 } @@ -242,11 +242,11 @@ func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, userID SELECT l.id, l.build_id, l.step_id, l.content, l.created_at FROM build_logs l INNER JOIN builds b ON b.id = l.build_id - WHERE l.build_id = $1 AND b.user_id = $2 + WHERE l.build_id = $1 AND b.tenant_id = $2 ORDER BY l.created_at ASC LIMIT $3 ` - rows, err := r.db.Query(ctx, query, buildID, userID, limit) + rows, err := r.db.Query(ctx, query, buildID, tenantID, limit) if err != nil { return nil, err } diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index c232d39fa..3cf60c3fc 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -138,6 +138,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl ctx, cancel := context.WithTimeout(workerCtx, 30*time.Minute) defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) + ctx = appcontext.WithTenantID(ctx, job.TenantID) build, pipeline, err := w.loadBuildAndPipeline(ctx, job) if err != nil { @@ -213,7 +214,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl } func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline, error) { - build, err := w.repo.GetBuild(ctx, job.BuildID, job.UserID) + build, err := w.repo.GetBuild(ctx, job.BuildID, job.TenantID) if err != nil { w.logger.Error("failed to load build", "build_id", job.BuildID, "error", err) return nil, nil, err @@ -222,7 +223,7 @@ func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.Bu return nil, nil, nil } - pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.UserID) + pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.TenantID) if err != nil { w.logger.Error("failed to load pipeline", "pipeline_id", job.PipelineID, "error", err) w.failBuild(ctx, build, "pipeline load error: "+err.Error()) diff --git a/internal/workers/pipeline_worker_test.go b/internal/workers/pipeline_worker_test.go index 3b305c049..e5ac62e2b 100644 --- a/internal/workers/pipeline_worker_test.go +++ b/internal/workers/pipeline_worker_test.go @@ -181,11 +181,12 @@ func TestPipelineWorker_processJob(t *testing.T) { buildID := uuid.New() pipelineID := uuid.New() userID := uuid.New() - job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID} + tenantID := uuid.New() + job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID, TenantID: tenantID} msg := &ports.DurableMessage{ID: "1-0", Queue: pipelineQueueName} t.Run("Success", func(t *testing.T) { - build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID} + build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID, TenantID: tenantID} pipeline := &domain.Pipeline{ ID: pipelineID, Config: domain.PipelineConfig{ @@ -200,8 +201,8 @@ func TestPipelineWorker_processJob(t *testing.T) { }, } - repo.On("GetBuild", mock.Anything, buildID, userID).Return(build, nil).Once() - repo.On("GetPipeline", mock.Anything, pipelineID, userID).Return(pipeline, nil).Once() + repo.On("GetBuild", mock.Anything, buildID, tenantID).Return(build, nil).Once() + repo.On("GetPipeline", mock.Anything, pipelineID, tenantID).Return(pipeline, nil).Once() repo.On("UpdateBuild", mock.Anything, mock.MatchedBy(func(b *domain.Build) bool { return b.Status == domain.BuildStatusRunning })).Return(nil).Once()