diff --git a/internal/core/domain/jobs.go b/internal/core/domain/jobs.go index 525c0316b..ad59358c3 100644 --- a/internal/core/domain/jobs.go +++ b/internal/core/domain/jobs.go @@ -49,6 +49,7 @@ type BuildJob struct { BuildID uuid.UUID `json:"build_id"` PipelineID uuid.UUID `json:"pipeline_id"` UserID uuid.UUID `json:"user_id"` + TenantID uuid.UUID `json:"tenant_id"` CommitHash string `json:"commit_hash,omitempty"` Trigger BuildTriggerType `json:"trigger"` } diff --git a/internal/core/services/notify.go b/internal/core/services/notify.go index babfc7aa3..517121499 100644 --- a/internal/core/services/notify.go +++ b/internal/core/services/notify.go @@ -1,297 +1,316 @@ -// Package services implements core business workflows. -package services - -import ( - "bytes" - "context" - "fmt" - "log/slog" - "net/http" - "time" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/ports" -) - -// NotifyServiceParams defines the dependencies for NotifyService. -type NotifyServiceParams struct { - Repo ports.NotifyRepository - RBACSvc ports.RBACService - QueueSvc ports.QueueService - EventSvc ports.EventService - AuditSvc ports.AuditService - Logger *slog.Logger -} - -// NotifyService manages topics, subscriptions, and message delivery. -type NotifyService struct { - repo ports.NotifyRepository - rbacSvc ports.RBACService - queueSvc ports.QueueService - eventSvc ports.EventService - auditSvc ports.AuditService - logger *slog.Logger -} - -// NewNotifyService constructs a NotifyService with its dependencies. -func NewNotifyService(params NotifyServiceParams) ports.NotifyService { - return &NotifyService{ - repo: params.Repo, - rbacSvc: params.RBACSvc, - queueSvc: params.QueueSvc, - eventSvc: params.EventSvc, - auditSvc: params.AuditSvc, - logger: params.Logger, - } -} - -func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { - return nil, err - } - - existing, _ := s.repo.GetTopicByName(ctx, name, userID) - if existing != nil { - return nil, fmt.Errorf("topic with name %s already exists", name) - } - - id := uuid.New() - topic := &domain.Topic{ - ID: id, - UserID: userID, - Name: name, - ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateTopic(ctx, topic); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return topic, nil -} - -func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { - return nil, err - } - - return s.repo.ListTopics(ctx, userID) -} - -func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, id, userID) - if err != nil { - return err - } - if topic == nil { - return fmt.Errorf("topic not found") - } - - if err := s.repo.DeleteTopic(ctx, id); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) - - _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ - "name": topic.Name, - }) - - return nil -} - -func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return nil, err - } - - // Verify topic exists and belongs to user - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - sub := &domain.Subscription{ - ID: uuid.New(), - UserID: userID, - TopicID: topic.ID, - Protocol: protocol, - Endpoint: endpoint, - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - } - - if err := s.repo.CreateSubscription(ctx, sub); err != nil { - return nil, err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": topicID.String(), - "protocol": protocol, - "endpoint": endpoint, - }) - - return sub, nil -} - -func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { - return nil, err - } - - // Verify topic ownership - _, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return nil, err - } - - return s.repo.ListSubscriptions(ctx, topicID) -} - -func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { - return err - } - - sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) - if err != nil { - return err - } - - if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { - return err - } - - _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) - - _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ - "topic_id": sub.TopicID.String(), - }) - - return nil -} - -func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { - userID := appcontext.UserIDFromContext(ctx) - tenantID := appcontext.TenantIDFromContext(ctx) - - if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { - return err - } - - topic, err := s.repo.GetTopicByID(ctx, topicID, userID) - if err != nil { - return err - } - - msg := &domain.NotifyMessage{ - ID: uuid.New(), - TopicID: topic.ID, - Body: body, - CreatedAt: time.Now(), - } - - if err := s.repo.SaveMessage(ctx, msg); err != nil { - return err - } - - subs, err := s.repo.ListSubscriptions(ctx, topicID) - if err != nil { - return err - } - - // Delivery logic - for _, sub := range subs { - // Create a background context for async delivery to avoid request cancellation - // but keep it separate for each subscriber to avoid shared timeout issues - go func(c context.Context, sub *domain.Subscription) { - deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) - defer cancel() - - // Carry over potential trace IDs or other metadata if needed - // (Simplified for now, but avoids using request-scoped ctx) - s.deliver(deliveryCtx, sub, body) - }(ctx, sub) - } - - if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { - s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) - } - - if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ - "message_id": msg.ID.String(), - }); err != nil { - s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) - } - - return nil -} - -func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { - switch sub.Protocol { - case domain.ProtocolQueue: - s.deliverToQueue(ctx, sub, body) - case domain.ProtocolWebhook: - s.deliverToWebhook(ctx, sub, body) - } -} - -func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { - // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. - // For now let's assume endpoint is the Queue UUID string. - qID, err := uuid.Parse(sub.Endpoint) - if err != nil { - s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) - return - } - // We need to bypass user check or use sub.UserID context - deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) - if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { - s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) - } -} - -func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { - req, _ := http.NewRequestWithContext(ctx, "POST", sub.Endpoint, bytes.NewBufferString(body)) - req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) - if err != nil { - s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) - return - } - _ = resp.Body.Close() -} +// Package services implements core business workflows. +package services + +import ( + "bytes" + "context" + "fmt" + "io" + "log/slog" + "net/http" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/ports" +) + +var webhookHTTPClient = &http.Client{Timeout: 15 * time.Second} + +// NotifyServiceParams defines the dependencies for NotifyService. +type NotifyServiceParams struct { + Repo ports.NotifyRepository + RBACSvc ports.RBACService + QueueSvc ports.QueueService + EventSvc ports.EventService + AuditSvc ports.AuditService + Logger *slog.Logger +} + +// NotifyService manages topics, subscriptions, and message delivery. +type NotifyService struct { + repo ports.NotifyRepository + rbacSvc ports.RBACService + queueSvc ports.QueueService + eventSvc ports.EventService + auditSvc ports.AuditService + logger *slog.Logger +} + +// NewNotifyService constructs a NotifyService with its dependencies. +func NewNotifyService(params NotifyServiceParams) ports.NotifyService { + return &NotifyService{ + repo: params.Repo, + rbacSvc: params.RBACSvc, + queueSvc: params.QueueSvc, + eventSvc: params.EventSvc, + auditSvc: params.AuditSvc, + logger: params.Logger, + } +} + +func (s *NotifyService) CreateTopic(ctx context.Context, name string) (*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyCreate, "*"); err != nil { + return nil, err + } + + existing, _ := s.repo.GetTopicByName(ctx, name, userID) + if existing != nil { + return nil, fmt.Errorf("topic with name %s already exists", name) + } + + id := uuid.New() + topic := &domain.Topic{ + ID: id, + UserID: userID, + Name: name, + ARN: fmt.Sprintf("arn:thecloud:notify:local:%s:topic/%s", userID, name), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateTopic(ctx, topic); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_CREATED", topic.ID.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_create", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return topic, nil +} + +func (s *NotifyService) ListTopics(ctx context.Context) ([]*domain.Topic, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, "*"); err != nil { + return nil, err + } + + return s.repo.ListTopics(ctx, userID) +} + +func (s *NotifyService) DeleteTopic(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, id, userID) + if err != nil { + return err + } + if topic == nil { + return fmt.Errorf("topic not found") + } + + if err := s.repo.DeleteTopic(ctx, id); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "TOPIC_DELETED", id.String(), "TOPIC", nil) + + _ = s.auditSvc.Log(ctx, topic.UserID, "notify.topic_delete", "topic", topic.ID.String(), map[string]interface{}{ + "name": topic.Name, + }) + + return nil +} + +func (s *NotifyService) Subscribe(ctx context.Context, topicID uuid.UUID, protocol domain.SubscriptionProtocol, endpoint string) (*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return nil, err + } + + // Verify topic exists and belongs to user + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topic.ID, + Protocol: protocol, + Endpoint: endpoint, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err := s.repo.CreateSubscription(ctx, sub); err != nil { + return nil, err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_CREATED", sub.ID.String(), "SUBSCRIPTION", map[string]interface{}{"topic_id": topicID}) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.subscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": topicID.String(), + "protocol": protocol, + "endpoint": endpoint, + }) + + return sub, nil +} + +func (s *NotifyService) ListSubscriptions(ctx context.Context, topicID uuid.UUID) ([]*domain.Subscription, error) { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyRead, topicID.String()); err != nil { + return nil, err + } + + // Verify topic ownership + _, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return nil, err + } + + return s.repo.ListSubscriptions(ctx, topicID) +} + +func (s *NotifyService) Unsubscribe(ctx context.Context, id uuid.UUID) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyDelete, id.String()); err != nil { + return err + } + + sub, err := s.repo.GetSubscriptionByID(ctx, id, userID) + if err != nil { + return err + } + + if err := s.repo.DeleteSubscription(ctx, sub.ID); err != nil { + return err + } + + _ = s.eventSvc.RecordEvent(ctx, "SUBSCRIPTION_DELETED", id.String(), "SUBSCRIPTION", nil) + + _ = s.auditSvc.Log(ctx, sub.UserID, "notify.unsubscribe", "subscription", sub.ID.String(), map[string]interface{}{ + "topic_id": sub.TopicID.String(), + }) + + return nil +} + +func (s *NotifyService) Publish(ctx context.Context, topicID uuid.UUID, body string) error { + userID := appcontext.UserIDFromContext(ctx) + tenantID := appcontext.TenantIDFromContext(ctx) + + if err := s.rbacSvc.Authorize(ctx, userID, tenantID, domain.PermissionNotifyWrite, topicID.String()); err != nil { + return err + } + + topic, err := s.repo.GetTopicByID(ctx, topicID, userID) + if err != nil { + return err + } + + msg := &domain.NotifyMessage{ + ID: uuid.New(), + TopicID: topic.ID, + Body: body, + CreatedAt: time.Now(), + } + + if err := s.repo.SaveMessage(ctx, msg); err != nil { + return err + } + + subs, err := s.repo.ListSubscriptions(ctx, topicID) + if err != nil { + return err + } + + // Delivery logic + for _, sub := range subs { + // Create a background context for async delivery to avoid request cancellation + // but keep it separate for each subscriber to avoid shared timeout issues + go func(c context.Context, sub *domain.Subscription) { + deliveryCtx, cancel := context.WithTimeout(c, 30*time.Second) + defer cancel() + + // Carry over potential trace IDs or other metadata if needed + // (Simplified for now, but avoids using request-scoped ctx) + s.deliver(deliveryCtx, sub, body) + }(ctx, sub) + } + + if err := s.eventSvc.RecordEvent(ctx, "TOPIC_PUBLISHED", topic.ID.String(), "TOPIC", map[string]interface{}{"message_id": msg.ID}); err != nil { + s.logger.Warn("failed to record topic publish event", "topic_id", topic.ID, "error", err) + } + + if err := s.auditSvc.Log(ctx, topic.UserID, "notify.publish", "topic", topic.ID.String(), map[string]interface{}{ + "message_id": msg.ID.String(), + }); err != nil { + s.logger.Warn("failed to log topic publish audit event", "topic_id", topic.ID, "error", err) + } + + return nil +} + +func (s *NotifyService) deliver(ctx context.Context, sub *domain.Subscription, body string) { + switch sub.Protocol { + case domain.ProtocolQueue: + s.deliverToQueue(ctx, sub, body) + case domain.ProtocolWebhook: + s.deliverToWebhook(ctx, sub, body) + } +} + +func (s *NotifyService) deliverToQueue(ctx context.Context, sub *domain.Subscription, body string) { + // Endpoint is Queue ARN or ID. Let's assume ID for simplicity or parse ARN. + // For now let's assume endpoint is the Queue UUID string. + qID, err := uuid.Parse(sub.Endpoint) + if err != nil { + s.logger.Warn("invalid queue ID in subscription", "endpoint", sub.Endpoint, "error", err) + return + } + // We need to bypass user check or use sub.UserID context + deliveryCtx := appcontext.WithUserID(ctx, sub.UserID) + if _, err = s.queueSvc.SendMessage(deliveryCtx, qID, body); err != nil { + s.logger.Warn("failed to deliver to queue", "queue_id", qID, "error", err) + } +} + +func (s *NotifyService) deliverToWebhook(ctx context.Context, sub *domain.Subscription, body string) { + req, err := http.NewRequestWithContext(ctx, http.MethodPost, sub.Endpoint, bytes.NewBufferString(body)) + if err != nil { + s.logger.Warn("failed to build webhook request", "endpoint", sub.Endpoint, "error", err) + return + } + req.Header.Set("Content-Type", "application/json") + + resp, err := webhookHTTPClient.Do(req) + if err != nil { + s.logger.Warn("failed to deliver to webhook", "endpoint", sub.Endpoint, "error", err) + return + } + defer func() { + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + }() + + if resp.StatusCode >= 400 { + s.logger.Warn("webhook delivery failed", + "endpoint", sub.Endpoint, + "subscription_id", sub.ID, + "status", resp.StatusCode) + return + } +} diff --git a/internal/core/services/notify_unit_test.go b/internal/core/services/notify_unit_test.go index 20024ce81..424b3f450 100644 --- a/internal/core/services/notify_unit_test.go +++ b/internal/core/services/notify_unit_test.go @@ -1,432 +1,518 @@ -package services_test - -import ( - "context" - "fmt" - "log/slog" - "testing" - - "github.com/google/uuid" - appcontext "github.com/poyrazk/thecloud/internal/core/context" - "github.com/poyrazk/thecloud/internal/core/domain" - "github.com/poyrazk/thecloud/internal/core/services" - "github.com/poyrazk/thecloud/internal/errors" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/mock" - "github.com/stretchr/testify/require" -) - -const testTopicName = "my-topic" - -func TestNotifyServiceUnit(t *testing.T) { - t.Run("CRUD", testNotifyServiceUnitCRUD) - t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) - t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) - t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) -} - -func testNotifyServiceUnitCRUD(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - topic, err := svc.CreateTopic(ctx, testTopicName) - require.NoError(t, err) - assert.NotNil(t, topic) - assert.Equal(t, testTopicName, topic.Name) - }) - - t.Run("ListTopics", func(t *testing.T) { - mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() - - topics, err := svc.ListTopics(ctx) - require.NoError(t, err) - assert.Len(t, topics, 1) - }) - - t.Run("DeleteTopic", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.NoError(t, err) - }) - - t.Run("Subscribe", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.NoError(t, err) - assert.NotNil(t, sub) - }) - - t.Run("ListSubscriptions", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() - - subs, err := svc.ListSubscriptions(ctx, topicID) - require.NoError(t, err) - assert.Len(t, subs, 1) - }) - - t.Run("Unsubscribe", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() - - err := svc.Unsubscribe(ctx, subID) - require.NoError(t, err) - }) - - t.Run("Publish", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) -} - -func testNotifyServiceUnitRbacErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - type rbacCase struct { - name string - permission domain.Permission - resourceID string - invoke func(id string) error - } - - cases := []rbacCase{ - { - name: "CreateTopic_Unauthorized", - permission: domain.PermissionNotifyCreate, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.CreateTopic(ctx, "my-topic") - return err - }, - }, - { - name: "ListTopics_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: "*", - invoke: func(id string) error { - _, err := svc.ListTopics(ctx) - return err - }, - }, - { - name: "DeleteTopic_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.DeleteTopic(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Subscribe_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") - return err - }, - }, - { - name: "ListSubscriptions_Unauthorized", - permission: domain.PermissionNotifyRead, - resourceID: uuid.New().String(), - invoke: func(id string) error { - _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) - return err - }, - }, - { - name: "Unsubscribe_Unauthorized", - permission: domain.PermissionNotifyDelete, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Unsubscribe(ctx, uuid.MustParse(id)) - }, - }, - { - name: "Publish_Unauthorized", - permission: domain.PermissionNotifyWrite, - resourceID: uuid.New().String(), - invoke: func(id string) error { - return svc.Publish(ctx, uuid.MustParse(id), "hello") - }, - }, - } - - authErr := errors.New(errors.Forbidden, "permission denied") - for _, c := range cases { - t.Run(c.name, func(t *testing.T) { - rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() - err := c.invoke(c.resourceID) - require.Error(t, err) - assert.True(t, errors.Is(err, errors.Forbidden)) - }) - } -} - -func testNotifyServiceUnitRepoErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("CreateTopic_DuplicateName", func(t *testing.T) { - existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "already exists") - }) - - t.Run("CreateTopic_RepoError", func(t *testing.T) { - mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() - mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.CreateTopic(ctx, testTopicName) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("DeleteTopic_NotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "not found") - }) - - t.Run("DeleteTopic_RepoError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() - - err := svc.DeleteTopic(ctx, topicID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Subscribe_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - }) - - t.Run("Subscribe_RepoError", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() - mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("Unsubscribe_NotFound", func(t *testing.T) { - subID := uuid.New() - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - }) - - t.Run("Unsubscribe_RepoError", func(t *testing.T) { - subID := uuid.New() - topicID := uuid.New() - sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} - mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() - mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() - - err := svc.Unsubscribe(ctx, subID) - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) - - t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - _, err := svc.ListSubscriptions(ctx, topicID) - require.Error(t, err) - }) - - t.Run("Publish_TopicNotFound", func(t *testing.T) { - topicID := uuid.New() - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - }) - - t.Run("Publish_SaveMessageError", func(t *testing.T) { - topicID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.Error(t, err) - assert.Contains(t, err.Error(), "db error") - }) -} - -func testNotifyServiceUnitPublishErrors(t *testing.T) { - mockRepo := new(MockNotifyRepo) - mockQueueSvc := new(MockQueueService) - mockEventSvc := new(MockEventService) - mockAuditSvc := new(MockAuditService) - rbacSvc := new(MockRBACService) - rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) - - svc := services.NewNotifyService(services.NotifyServiceParams{ - Repo: mockRepo, - RBACSvc: rbacSvc, - QueueSvc: mockQueueSvc, - EventSvc: mockEventSvc, - AuditSvc: mockAuditSvc, - Logger: slog.Default(), - }) - - ctx := context.Background() - userID := uuid.New() - tenantID := uuid.New() - ctx = appcontext.WithUserID(ctx, userID) - ctx = appcontext.WithTenantID(ctx, tenantID) - - t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolWebhook, - Endpoint: "http://localhost:9999/nonexistent", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) - - t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { - done := make(chan struct{}) - topicID := uuid.New() - subID := uuid.New() - topic := &domain.Topic{ID: topicID, UserID: userID} - sub := &domain.Subscription{ - ID: subID, - UserID: userID, - TopicID: topicID, - Protocol: domain.ProtocolQueue, - Endpoint: "not-a-valid-uuid", - } - mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() - mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() - mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() - mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() - mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() - - err := svc.Publish(ctx, topicID, "hello") - require.NoError(t, err) - <-done - }) +package services_test + +import ( + "bytes" + "context" + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" + + "github.com/google/uuid" + appcontext "github.com/poyrazk/thecloud/internal/core/context" + "github.com/poyrazk/thecloud/internal/core/domain" + "github.com/poyrazk/thecloud/internal/core/services" + "github.com/poyrazk/thecloud/internal/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +const testTopicName = "my-topic" + +func TestNotifyServiceUnit(t *testing.T) { + t.Run("CRUD", testNotifyServiceUnitCRUD) + t.Run("RBACErrors", testNotifyServiceUnitRbacErrors) + t.Run("RepoErrors", testNotifyServiceUnitRepoErrors) + t.Run("PublishErrors", testNotifyServiceUnitPublishErrors) +} + +func testNotifyServiceUnitCRUD(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_CREATED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_create", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + topic, err := svc.CreateTopic(ctx, testTopicName) + require.NoError(t, err) + assert.NotNil(t, topic) + assert.Equal(t, testTopicName, topic.Name) + }) + + t.Run("ListTopics", func(t *testing.T) { + mockRepo.On("ListTopics", mock.Anything, userID).Return([]*domain.Topic{{ID: uuid.New(), Name: "topic1"}}, nil).Once() + + topics, err := svc.ListTopics(ctx) + require.NoError(t, err) + assert.Len(t, topics, 1) + }) + + t.Run("DeleteTopic", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "to-delete"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_DELETED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.topic_delete", "topic", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.NoError(t, err) + }) + + t.Run("Subscribe", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_CREATED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.subscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + sub, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.NoError(t, err) + assert.NotNil(t, sub) + }) + + t.Run("ListSubscriptions", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{{ID: uuid.New()}}, nil).Once() + + subs, err := svc.ListSubscriptions(ctx, topicID) + require.NoError(t, err) + assert.Len(t, subs, 1) + }) + + t.Run("Unsubscribe", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "SUBSCRIPTION_DELETED", mock.Anything, "SUBSCRIPTION", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.unsubscribe", "subscription", mock.Anything, mock.Anything).Return(nil).Once() + + err := svc.Unsubscribe(ctx, subID) + require.NoError(t, err) + }) + + t.Run("Publish", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +func testNotifyServiceUnitRbacErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + type rbacCase struct { + name string + permission domain.Permission + resourceID string + invoke func(id string) error + } + + cases := []rbacCase{ + { + name: "CreateTopic_Unauthorized", + permission: domain.PermissionNotifyCreate, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.CreateTopic(ctx, "my-topic") + return err + }, + }, + { + name: "ListTopics_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: "*", + invoke: func(id string) error { + _, err := svc.ListTopics(ctx) + return err + }, + }, + { + name: "DeleteTopic_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.DeleteTopic(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Subscribe_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.Subscribe(ctx, uuid.MustParse(id), domain.ProtocolWebhook, "https://example.com/hook") + return err + }, + }, + { + name: "ListSubscriptions_Unauthorized", + permission: domain.PermissionNotifyRead, + resourceID: uuid.New().String(), + invoke: func(id string) error { + _, err := svc.ListSubscriptions(ctx, uuid.MustParse(id)) + return err + }, + }, + { + name: "Unsubscribe_Unauthorized", + permission: domain.PermissionNotifyDelete, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Unsubscribe(ctx, uuid.MustParse(id)) + }, + }, + { + name: "Publish_Unauthorized", + permission: domain.PermissionNotifyWrite, + resourceID: uuid.New().String(), + invoke: func(id string) error { + return svc.Publish(ctx, uuid.MustParse(id), "hello") + }, + }, + } + + authErr := errors.New(errors.Forbidden, "permission denied") + for _, c := range cases { + t.Run(c.name, func(t *testing.T) { + rbacSvc.On("Authorize", mock.Anything, userID, tenantID, c.permission, c.resourceID).Return(authErr).Once() + err := c.invoke(c.resourceID) + require.Error(t, err) + assert.True(t, errors.Is(err, errors.Forbidden)) + }) + } +} + +func testNotifyServiceUnitRepoErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("CreateTopic_DuplicateName", func(t *testing.T) { + existing := &domain.Topic{ID: uuid.New(), Name: testTopicName} + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(existing, nil).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "already exists") + }) + + t.Run("CreateTopic_RepoError", func(t *testing.T) { + mockRepo.On("GetTopicByName", mock.Anything, testTopicName, userID).Return(nil, nil).Once() + mockRepo.On("CreateTopic", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.CreateTopic(ctx, testTopicName) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("DeleteTopic_NotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, nil).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteTopic_RepoError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID, Name: "test"} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("DeleteTopic", mock.Anything, topicID).Return(fmt.Errorf("db error")).Once() + + err := svc.DeleteTopic(ctx, topicID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Subscribe_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + }) + + t.Run("Subscribe_RepoError", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(&domain.Topic{ID: topicID}, nil).Once() + mockRepo.On("CreateSubscription", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + _, err := svc.Subscribe(ctx, topicID, domain.ProtocolWebhook, "https://example.com/hook") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("Unsubscribe_NotFound", func(t *testing.T) { + subID := uuid.New() + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + }) + + t.Run("Unsubscribe_RepoError", func(t *testing.T) { + subID := uuid.New() + topicID := uuid.New() + sub := &domain.Subscription{ID: subID, UserID: userID, TopicID: topicID} + mockRepo.On("GetSubscriptionByID", mock.Anything, subID, userID).Return(sub, nil).Once() + mockRepo.On("DeleteSubscription", mock.Anything, subID).Return(fmt.Errorf("db error")).Once() + + err := svc.Unsubscribe(ctx, subID) + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) + + t.Run("ListSubscriptions_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + _, err := svc.ListSubscriptions(ctx, topicID) + require.Error(t, err) + }) + + t.Run("Publish_TopicNotFound", func(t *testing.T) { + topicID := uuid.New() + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(nil, errors.New(errors.NotFound, "not found")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + }) + + t.Run("Publish_SaveMessageError", func(t *testing.T) { + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(fmt.Errorf("db error")).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.Error(t, err) + assert.Contains(t, err.Error(), "db error") + }) +} + +func testNotifyServiceUnitPublishErrors(t *testing.T) { + mockRepo := new(MockNotifyRepo) + mockQueueSvc := new(MockQueueService) + mockEventSvc := new(MockEventService) + mockAuditSvc := new(MockAuditService) + rbacSvc := new(MockRBACService) + rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo, + RBACSvc: rbacSvc, + QueueSvc: mockQueueSvc, + EventSvc: mockEventSvc, + AuditSvc: mockAuditSvc, + Logger: slog.Default(), + }) + + ctx := context.Background() + userID := uuid.New() + tenantID := uuid.New() + ctx = appcontext.WithUserID(ctx, userID) + ctx = appcontext.WithTenantID(ctx, tenantID) + + t.Run("Publish_WebhookDeliveryError", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: "http://localhost:9999/nonexistent", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) + + t.Run("Publish_WebhookNon2xxStatus", func(t *testing.T) { + // Issue #338: webhook delivery must surface non-2xx HTTP responses. + var ( + mu sync.Mutex + received bool + ) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + mu.Lock() + received = true + mu.Unlock() + w.WriteHeader(http.StatusInternalServerError) + })) + defer server.Close() + + var logBuf bytes.Buffer + var logMu sync.Mutex + capturingLogger := slog.New(slog.NewTextHandler(&lockedWriter{w: &logBuf, mu: &logMu}, &slog.HandlerOptions{Level: slog.LevelDebug})) + + mockRepo2 := new(MockNotifyRepo) + mockQueueSvc2 := new(MockQueueService) + mockEventSvc2 := new(MockEventService) + mockAuditSvc2 := new(MockAuditService) + rbacSvc2 := new(MockRBACService) + rbacSvc2.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil) + + svc2 := services.NewNotifyService(services.NotifyServiceParams{ + Repo: mockRepo2, + RBACSvc: rbacSvc2, + QueueSvc: mockQueueSvc2, + EventSvc: mockEventSvc2, + AuditSvc: mockAuditSvc2, + Logger: capturingLogger, + }) + + done := make(chan struct{}) + topicID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: uuid.New(), + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolWebhook, + Endpoint: server.URL, + } + mockRepo2.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo2.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo2.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc2.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc2.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc2.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + + // Allow the async webhook goroutine a moment to fire and log. + require.Eventually(t, func() bool { + mu.Lock() + defer mu.Unlock() + return received + }, 2*time.Second, 10*time.Millisecond, "webhook server never received request") + + require.Eventually(t, func() bool { + logMu.Lock() + defer logMu.Unlock() + return strings.Contains(logBuf.String(), "webhook delivery failed") && + strings.Contains(logBuf.String(), "status=500") + }, 2*time.Second, 10*time.Millisecond, "expected webhook delivery failure log with status=500") + }) + + t.Run("Publish_QueueInvalidUUID", func(t *testing.T) { + done := make(chan struct{}) + topicID := uuid.New() + subID := uuid.New() + topic := &domain.Topic{ID: topicID, UserID: userID} + sub := &domain.Subscription{ + ID: subID, + UserID: userID, + TopicID: topicID, + Protocol: domain.ProtocolQueue, + Endpoint: "not-a-valid-uuid", + } + mockRepo.On("GetTopicByID", mock.Anything, topicID, userID).Return(topic, nil).Once() + mockRepo.On("SaveMessage", mock.Anything, mock.Anything).Return(nil).Once() + mockRepo.On("ListSubscriptions", mock.Anything, topicID).Return([]*domain.Subscription{sub}, nil).Once() + mockEventSvc.On("RecordEvent", mock.Anything, "TOPIC_PUBLISHED", mock.Anything, "TOPIC", mock.Anything).Return(nil).Once() + mockAuditSvc.On("Log", mock.Anything, userID, "notify.publish", "topic", topicID.String(), mock.Anything).Return(nil).Run(func(mock.Arguments) { close(done) }).Once() + + err := svc.Publish(ctx, topicID, "hello") + require.NoError(t, err) + <-done + }) +} + +type lockedWriter struct { + w *bytes.Buffer + mu *sync.Mutex +} + +func (l *lockedWriter) Write(p []byte) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() + return l.w.Write(p) } \ No newline at end of file diff --git a/internal/core/services/password_reset.go b/internal/core/services/password_reset.go index 725a84b0d..c534661ba 100644 --- a/internal/core/services/password_reset.go +++ b/internal/core/services/password_reset.go @@ -75,10 +75,9 @@ func (s *PasswordResetService) RequestReset(ctx context.Context, email string) e return err } - // Note: EmailService integration is pending. - // For MVP/Demo: Log the token so we can test it manually. - // Future: Inject and use EmailService here. - s.logger.Debug("password reset token", "email", email, "token", token) + // TODO: deliver `token` via an injected EmailService once available. + // Never log or persist the plaintext token — its only safe destination is the user. + s.logger.Info("password reset token issued", "user_id", user.ID, "token_id", resetToken.ID) return nil } diff --git a/internal/core/services/pipeline.go b/internal/core/services/pipeline.go index 90d596afa..037398a04 100644 --- a/internal/core/services/pipeline.go +++ b/internal/core/services/pipeline.go @@ -237,6 +237,7 @@ func (s *PipelineService) TriggerBuildWebhook(ctx context.Context, opts ports.We } webhookCtx := appcontext.WithUserID(ctx, pipeline.UserID) + webhookCtx = appcontext.WithTenantID(webhookCtx, pipeline.TenantID) return s.createAndQueueBuild(webhookCtx, pipeline, commitHash, domain.BuildTriggerWebhook) } @@ -263,6 +264,7 @@ func (s *PipelineService) createAndQueueBuild(ctx context.Context, pipeline *dom BuildID: build.ID, PipelineID: build.PipelineID, UserID: build.UserID, + TenantID: build.TenantID, CommitHash: build.CommitHash, Trigger: build.TriggerType, } diff --git a/internal/handlers/storage_handler.go b/internal/handlers/storage_handler.go index 31151aa71..42f810e89 100644 --- a/internal/handlers/storage_handler.go +++ b/internal/handlers/storage_handler.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net/http" + "path" "strconv" "strings" "time" @@ -38,6 +39,80 @@ const ( headerContentSha256 = "X-Content-Sha256" ) +// contentDispositionAttachment builds a safe `Content-Disposition: attachment` +// header for a stored object. +// +// Object keys can contain path segments, non-ASCII characters, control +// characters, quotes, backslashes, or CRLF that — if interpolated naively — +// would either corrupt the header (HTTP response splitting) or let an attacker +// inject additional headers. The output therefore emits two parameters per +// RFC 6266: +// +// - `filename="..."` ASCII-only fallback for legacy clients. All bytes +// outside the safe printable range and the two +// characters that are special inside a quoted-string +// (`"` and `\`) are replaced with `_`. +// - `filename*=UTF-8''…` RFC 5987 percent-encoded form preserving the +// original Unicode basename for modern clients. +// +// `path.Base` is used to discard any path segments embedded in the key. If the +// resulting name is empty we fall back to "download". +func contentDispositionAttachment(key string) string { + name := path.Base(key) + if name == "." || name == "/" || name == "" { + name = "download" + } + + return fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, + asciiFilenameFallback(name), rfc5987Encode(name)) +} + +// asciiFilenameFallback returns an ASCII-only sanitized copy of name suitable +// for the legacy `filename=` parameter. Any byte that is a control character +// (<0x20 or 0x7f), non-ASCII (>=0x80), or special inside a quoted-string is +// replaced with `_` so the value can be safely wrapped in double quotes. +func asciiFilenameFallback(name string) string { + out := make([]byte, 0, len(name)) + for i := 0; i < len(name); i++ { + c := name[i] + switch { + case c < 0x20, c == 0x7f, c >= 0x80, c == '"', c == '\\': + out = append(out, '_') + default: + out = append(out, c) + } + } + if len(out) == 0 { + return "download" + } + return string(out) +} + +// rfc5987Encode percent-encodes a value per RFC 5987 attr-char rules so it can +// be safely placed in a `filename*` parameter. +func rfc5987Encode(s string) string { + const hex = "0123456789ABCDEF" + var b strings.Builder + b.Grow(len(s)) + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case c >= 'A' && c <= 'Z', + c >= 'a' && c <= 'z', + c >= '0' && c <= '9', + c == '!', c == '#', c == '$', c == '&', c == '+', + c == '-', c == '.', c == '^', c == '_', c == '`', + c == '|', c == '~': + b.WriteByte(c) + default: + b.WriteByte('%') + b.WriteByte(hex[c>>4]) + b.WriteByte(hex[c&0x0f]) + } + } + return b.String() +} + // Upload uploads an object to a bucket // @Summary Upload an object // @Description Uploads a file/object to the specified bucket and key @@ -105,7 +180,7 @@ func (h *StorageHandler) Download(c *gin.Context) { defer func() { _ = reader.Close() }() // Set headers - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) @@ -461,7 +536,7 @@ func (h *StorageHandler) ServePresignedDownload(c *gin.Context) { } defer func() { _ = reader.Close() }() - c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=%s", key)) + c.Header("Content-Disposition", contentDispositionAttachment(key)) c.Header("Content-Type", obj.ContentType) c.Header("Content-Length", fmt.Sprintf("%d", obj.SizeBytes)) _, _ = io.Copy(c.Writer, reader) diff --git a/internal/handlers/storage_handler_content_disposition_test.go b/internal/handlers/storage_handler_content_disposition_test.go new file mode 100644 index 000000000..64a4aec1d --- /dev/null +++ b/internal/handlers/storage_handler_content_disposition_test.go @@ -0,0 +1,86 @@ +package httphandlers + +import ( + "strings" + "testing" +) + +func TestContentDispositionAttachment(t *testing.T) { + cases := []struct { + name string + key string + wantFilename string + wantFilenameStar string + }{ + { + name: "simple ascii filename", + key: "report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "nested key uses basename", + key: "exports/2026/q1/report.pdf", + wantFilename: `filename="report.pdf"`, + wantFilenameStar: `filename*=UTF-8''report.pdf`, + }, + { + name: "CRLF response splitting attempt is sanitized", + key: "evil\r\nSet-Cookie: pwned=1", + wantFilename: `filename="evil__Set-Cookie: pwned=1"`, + wantFilenameStar: `filename*=UTF-8''evil%0D%0ASet-Cookie%3A%20pwned%3D1`, + }, + { + name: "embedded quote and backslash are sanitized", + key: `bad"name\file.txt`, + wantFilename: `filename="bad_name_file.txt"`, + wantFilenameStar: `filename*=UTF-8''bad%22name%5Cfile.txt`, + }, + { + name: "non-ASCII falls back in legacy filename, preserved in filename*", + key: "résumé.pdf", + wantFilename: `filename="r__sum__.pdf"`, // 2 bytes per accented char both replaced + wantFilenameStar: `filename*=UTF-8''r%C3%A9sum%C3%A9.pdf`, + }, + { + name: "empty key falls back to download", + key: "", + wantFilename: `filename="download"`, + wantFilenameStar: `filename*=UTF-8''download`, + }, + { + name: "trailing slash falls back to download", + key: "folder/", + wantFilename: `filename="folder"`, // path.Base normalizes + wantFilenameStar: `filename*=UTF-8''folder`, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := contentDispositionAttachment(tc.key) + if !strings.HasPrefix(got, "attachment; ") { + t.Errorf("missing attachment prefix: %q", got) + } + if !strings.Contains(got, tc.wantFilename) { + t.Errorf("missing legacy filename param\n got: %q\n want: %q", got, tc.wantFilename) + } + if !strings.Contains(got, tc.wantFilenameStar) { + t.Errorf("missing filename* param\n got: %q\n want: %q", got, tc.wantFilenameStar) + } + if strings.ContainsAny(got, "\r\n") { + t.Errorf("output must not contain CR/LF: %q", got) + } + }) + } +} + +func TestContentDispositionAttachment_NoCRLFEverEscapes(t *testing.T) { + for _, c := range []byte{'\r', '\n', 0, 0x1f, 0x7f} { + key := "x" + string(c) + "y" + got := contentDispositionAttachment(key) + if strings.ContainsAny(got, "\r\n") { + t.Fatalf("control byte 0x%02x leaked into header: %q", c, got) + } + } +} diff --git a/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql new file mode 100644 index 000000000..741e1fde7 --- /dev/null +++ b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.down.sql @@ -0,0 +1,9 @@ +-- +goose Down + +ALTER TABLE pipelines DROP CONSTRAINT IF EXISTS pipelines_tenant_id_name_key; +ALTER TABLE caches DROP CONSTRAINT IF EXISTS caches_tenant_id_name_key; +ALTER TABLE queues DROP CONSTRAINT IF EXISTS queues_tenant_id_name_key; + +ALTER TABLE pipelines ADD CONSTRAINT pipelines_user_id_name_key UNIQUE (user_id, name); +ALTER TABLE caches ADD CONSTRAINT caches_user_id_name_key UNIQUE (user_id, name); +ALTER TABLE queues ADD CONSTRAINT queues_user_id_name_key UNIQUE (user_id, name); diff --git a/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql new file mode 100644 index 000000000..0601ae318 --- /dev/null +++ b/internal/repositories/postgres/migrations/108_tenant_scoped_resource_uniqueness.up.sql @@ -0,0 +1,69 @@ +-- +goose Up + +-- Pipelines, caches, and queues were created with UNIQUE(user_id, name), which +-- allowed two users in the same tenant to register identically named +-- resources. The application layer scopes lookups by tenant_id, so the +-- service-layer name check sees only one row while the duplicate slips past +-- the user-scoped DB constraint. Replace with tenant-scoped uniqueness. + +-- Disambiguate any preexisting duplicates within the same tenant by appending +-- a short user-id suffix to all but the oldest collider. Required so the new +-- UNIQUE(tenant_id, name) constraint can be added without conflicts. +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM pipelines + WHERE tenant_id IS NOT NULL +) +UPDATE pipelines p +SET name = p.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE p.id = r.id AND r.rn > 1; + +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM caches + WHERE tenant_id IS NOT NULL +) +UPDATE caches c +SET name = c.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE c.id = r.id AND r.rn > 1; + +WITH ranked AS ( + SELECT id, + tenant_id, + name, + user_id, + ROW_NUMBER() OVER ( + PARTITION BY tenant_id, name + ORDER BY created_at, id + ) AS rn + FROM queues + WHERE tenant_id IS NOT NULL +) +UPDATE queues q +SET name = q.name || '-' || substr(replace(r.user_id::text, '-', ''), 1, 8) +FROM ranked r +WHERE q.id = r.id AND r.rn > 1; + +ALTER TABLE pipelines DROP CONSTRAINT IF EXISTS pipelines_user_id_name_key; +ALTER TABLE caches DROP CONSTRAINT IF EXISTS caches_user_id_name_key; +ALTER TABLE queues DROP CONSTRAINT IF EXISTS queues_user_id_name_key; + +ALTER TABLE pipelines ADD CONSTRAINT pipelines_tenant_id_name_key UNIQUE (tenant_id, name); +ALTER TABLE caches ADD CONSTRAINT caches_tenant_id_name_key UNIQUE (tenant_id, name); +ALTER TABLE queues ADD CONSTRAINT queues_tenant_id_name_key UNIQUE (tenant_id, name); diff --git a/internal/repositories/postgres/pipeline_repo.go b/internal/repositories/postgres/pipeline_repo.go index 258618c09..dfe557fa1 100644 --- a/internal/repositories/postgres/pipeline_repo.go +++ b/internal/repositories/postgres/pipeline_repo.go @@ -168,9 +168,9 @@ func (r *PipelineRepository) UpdateBuild(ctx context.Context, build *domain.Buil started_at = $2, finished_at = $3, updated_at = NOW() - WHERE id = $4 AND user_id = $5 + WHERE id = $4 AND tenant_id = $5 ` - _, err := r.db.Exec(ctx, query, build.Status, build.StartedAt, build.FinishedAt, build.ID, build.UserID) + _, err := r.db.Exec(ctx, query, build.Status, build.StartedAt, build.FinishedAt, build.ID, build.TenantID) return err } @@ -195,15 +195,15 @@ func (r *PipelineRepository) CreateBuildStep(ctx context.Context, step *domain.B return err } -func (r *PipelineRepository) ListBuildSteps(ctx context.Context, buildID, userID uuid.UUID) ([]*domain.BuildStep, error) { +func (r *PipelineRepository) ListBuildSteps(ctx context.Context, buildID, tenantID uuid.UUID) ([]*domain.BuildStep, error) { query := ` SELECT s.id, s.build_id, s.name, s.image, s.commands, s.status, s.exit_code, s.started_at, s.finished_at, s.created_at, s.updated_at FROM build_steps s INNER JOIN builds b ON b.id = s.build_id - WHERE s.build_id = $1 AND b.user_id = $2 + WHERE s.build_id = $1 AND b.tenant_id = $2 ORDER BY s.created_at ASC ` - rows, err := r.db.Query(ctx, query, buildID, userID) + rows, err := r.db.Query(ctx, query, buildID, tenantID) if err != nil { return nil, err } @@ -233,7 +233,7 @@ func (r *PipelineRepository) AppendBuildLog(ctx context.Context, log *domain.Bui return err } -func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, userID uuid.UUID, limit int) ([]*domain.BuildLog, error) { +func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, tenantID uuid.UUID, limit int) ([]*domain.BuildLog, error) { if limit <= 0 { limit = 200 } @@ -242,11 +242,11 @@ func (r *PipelineRepository) ListBuildLogs(ctx context.Context, buildID, userID SELECT l.id, l.build_id, l.step_id, l.content, l.created_at FROM build_logs l INNER JOIN builds b ON b.id = l.build_id - WHERE l.build_id = $1 AND b.user_id = $2 + WHERE l.build_id = $1 AND b.tenant_id = $2 ORDER BY l.created_at ASC LIMIT $3 ` - rows, err := r.db.Query(ctx, query, buildID, userID, limit) + rows, err := r.db.Query(ctx, query, buildID, tenantID, limit) if err != nil { return nil, err } diff --git a/internal/workers/pipeline_worker.go b/internal/workers/pipeline_worker.go index c232d39fa..3cf60c3fc 100644 --- a/internal/workers/pipeline_worker.go +++ b/internal/workers/pipeline_worker.go @@ -138,6 +138,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl ctx, cancel := context.WithTimeout(workerCtx, 30*time.Minute) defer cancel() ctx = appcontext.WithUserID(ctx, job.UserID) + ctx = appcontext.WithTenantID(ctx, job.TenantID) build, pipeline, err := w.loadBuildAndPipeline(ctx, job) if err != nil { @@ -213,7 +214,7 @@ func (w *PipelineWorker) processJob(workerCtx context.Context, msg *ports.Durabl } func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.BuildJob) (*domain.Build, *domain.Pipeline, error) { - build, err := w.repo.GetBuild(ctx, job.BuildID, job.UserID) + build, err := w.repo.GetBuild(ctx, job.BuildID, job.TenantID) if err != nil { w.logger.Error("failed to load build", "build_id", job.BuildID, "error", err) return nil, nil, err @@ -222,7 +223,7 @@ func (w *PipelineWorker) loadBuildAndPipeline(ctx context.Context, job domain.Bu return nil, nil, nil } - pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.UserID) + pipeline, err := w.repo.GetPipeline(ctx, job.PipelineID, job.TenantID) if err != nil { w.logger.Error("failed to load pipeline", "pipeline_id", job.PipelineID, "error", err) w.failBuild(ctx, build, "pipeline load error: "+err.Error()) diff --git a/internal/workers/pipeline_worker_test.go b/internal/workers/pipeline_worker_test.go index 3b305c049..e5ac62e2b 100644 --- a/internal/workers/pipeline_worker_test.go +++ b/internal/workers/pipeline_worker_test.go @@ -181,11 +181,12 @@ func TestPipelineWorker_processJob(t *testing.T) { buildID := uuid.New() pipelineID := uuid.New() userID := uuid.New() - job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID} + tenantID := uuid.New() + job := domain.BuildJob{BuildID: buildID, PipelineID: pipelineID, UserID: userID, TenantID: tenantID} msg := &ports.DurableMessage{ID: "1-0", Queue: pipelineQueueName} t.Run("Success", func(t *testing.T) { - build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID} + build := &domain.Build{ID: buildID, PipelineID: pipelineID, UserID: userID, TenantID: tenantID} pipeline := &domain.Pipeline{ ID: pipelineID, Config: domain.PipelineConfig{ @@ -200,8 +201,8 @@ func TestPipelineWorker_processJob(t *testing.T) { }, } - repo.On("GetBuild", mock.Anything, buildID, userID).Return(build, nil).Once() - repo.On("GetPipeline", mock.Anything, pipelineID, userID).Return(pipeline, nil).Once() + repo.On("GetBuild", mock.Anything, buildID, tenantID).Return(build, nil).Once() + repo.On("GetPipeline", mock.Anything, pipelineID, tenantID).Return(pipeline, nil).Once() repo.On("UpdateBuild", mock.Anything, mock.MatchedBy(func(b *domain.Build) bool { return b.Status == domain.BuildStatusRunning })).Return(nil).Once()