diff --git a/internal/api/setup/router.go b/internal/api/setup/router.go index 07e2fc479..62e27e6b3 100644 --- a/internal/api/setup/router.go +++ b/internal/api/setup/router.go @@ -91,7 +91,7 @@ func InitHandlers(svcs *Services, cfg *platform.Config, logger *slog.Logger) *Ha Event: httphandlers.NewEventHandler(svcs.Event), Volume: httphandlers.NewVolumeHandler(svcs.Volume), LB: httphandlers.NewLBHandler(svcs.LB), - Dashboard: httphandlers.NewDashboardHandler(svcs.Dashboard), + Dashboard: httphandlers.NewDashboardHandler(svcs.Dashboard, cfg.WSAllowedOrigins), RBAC: httphandlers.NewRBACHandler(svcs.RBAC), Snapshot: httphandlers.NewSnapshotHandler(svcs.Snapshot), Stack: httphandlers.NewStackHandler(svcs.Stack), diff --git a/internal/handlers/dashboard_handler.go b/internal/handlers/dashboard_handler.go index 380d0bc4a..287b991ac 100644 --- a/internal/handlers/dashboard_handler.go +++ b/internal/handlers/dashboard_handler.go @@ -4,6 +4,7 @@ package httphandlers import ( "net/http" "strconv" + "strings" "time" "github.com/gin-gonic/gin" @@ -13,12 +14,30 @@ import ( // DashboardHandler handles dashboard API endpoints. type DashboardHandler struct { - svc ports.DashboardService + svc ports.DashboardService + allowedOrigins []string } // NewDashboardHandler creates a new dashboard handler. -func NewDashboardHandler(svc ports.DashboardService) *DashboardHandler { - return &DashboardHandler{svc: svc} +// +// allowedOrigins is the explicit allowlist of cross-origin Origin values +// permitted to subscribe to the SSE stream. Accepts either repeated entries +// or a single comma-separated string (matching the WS_ALLOWED_ORIGINS form). +// An empty list means same-origin only — no Access-Control-Allow-Origin +// header is emitted, so browsers fall back to same-origin enforcement. The +// literal "*" entry opts into permissive mode for non-browser clients; even +// then the response echoes the request Origin rather than "*", because the +// SSE EventSource API ignores wildcards when credentials are involved. See #347. +func NewDashboardHandler(svc ports.DashboardService, allowedOrigins ...string) *DashboardHandler { + cleaned := make([]string, 0, len(allowedOrigins)) + for _, raw := range allowedOrigins { + for _, o := range strings.Split(raw, ",") { + if trimmed := strings.TrimSpace(o); trimmed != "" { + cleaned = append(cleaned, trimmed) + } + } + } + return &DashboardHandler{svc: svc, allowedOrigins: cleaned} } // GetSummary returns resource counts and overview metrics. @@ -98,10 +117,16 @@ func (h *DashboardHandler) GetStats(c *gin.Context) { // @Security APIKeyAuth // @Router /api/dashboard/stream [get] func (h *DashboardHandler) StreamEvents(c *gin.Context) { + // Enforce CORS before any streaming headers are written. SSE inherits the + // caller's cookies/API key, so accepting "*" would let an attacker-hosted + // page receive a logged-in user's events. + if !h.applyCORS(c) { + return + } + c.Header("Content-Type", "text/event-stream") c.Header("Cache-Control", "no-cache") c.Header("Connection", "keep-alive") - c.Header("Access-Control-Allow-Origin", "*") ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -127,3 +152,47 @@ func (h *DashboardHandler) StreamEvents(c *gin.Context) { } } } + +// applyCORS validates the request Origin and, if accepted, writes the +// CORS response headers. Returns false (after writing 403) when the origin +// is set but not allowed; the caller must abort. +// +// This endpoint takes authoritative control of its CORS headers, clearing +// the permissive wildcard set by httputil.CORS() upstream so that an +// authenticated SSE stream can never be subscribed to by a third-party origin. +// +// Behaviour: +// - Empty Origin → same-origin or non-browser request; clear wildcard +// defaults and emit no Access-Control-* headers (browsers default to +// same-origin). +// - Origin in allowlist → echo it back with Vary: Origin and credentials. +// - "*" in allowlist → echo the request Origin (never literal "*", which +// EventSource rejects when credentials are involved). +// - Otherwise → 403. +func (h *DashboardHandler) applyCORS(c *gin.Context) bool { + headers := c.Writer.Header() + headers.Del("Access-Control-Allow-Origin") + headers.Del("Access-Control-Allow-Credentials") + + origin := c.GetHeader("Origin") + if origin == "" { + return true + } + + allowed := false + for _, o := range h.allowedOrigins { + if o == "*" || o == origin { + allowed = true + break + } + } + if !allowed { + c.AbortWithStatus(http.StatusForbidden) + return false + } + + headers.Set("Access-Control-Allow-Origin", origin) + headers.Set("Access-Control-Allow-Credentials", "true") + headers.Set("Vary", "Origin") + return true +} diff --git a/internal/handlers/dashboard_handler_test.go b/internal/handlers/dashboard_handler_test.go index 503c91f6e..58f424fad 100644 --- a/internal/handlers/dashboard_handler_test.go +++ b/internal/handlers/dashboard_handler_test.go @@ -152,6 +152,100 @@ func TestDashboardHandlerStreamEvents(t *testing.T) { assert.Contains(t, w.Header().Get("Content-Type"), "text/event-stream") assert.Contains(t, w.Body.String(), "event:summary") + // Same-origin (no Origin header) → must not emit a wildcard CORS header. See #347. + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) +} + +func TestDashboardHandlerStreamEventsCORS(t *testing.T) { + t.Parallel() + + runStream := func(t *testing.T, h *DashboardHandler, origin string) *httptest.ResponseRecorder { + t.Helper() + gin.SetMode(gin.TestMode) + r := gin.New() + r.GET("/stream", h.StreamEvents) + + req, err := http.NewRequest("GET", "/stream", nil) + require.NoError(t, err) + if origin != "" { + req.Header.Set("Origin", origin) + } + ctx, cancel := context.WithCancel(context.Background()) + req = req.WithContext(ctx) + + w := httptest.NewRecorder() + done := make(chan struct{}) + go func() { + r.ServeHTTP(w, req) + close(done) + }() + time.Sleep(50 * time.Millisecond) + cancel() + <-done + return w + } + + t.Run("AllowedOriginIsEchoed", func(t *testing.T) { + t.Parallel() + mockSvc := new(dashboardServiceMock) + mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil) + h := NewDashboardHandler(mockSvc, "https://dash.example.com,https://other.example.com") + + w := runStream(t, h, "https://dash.example.com") + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "https://dash.example.com", w.Header().Get("Access-Control-Allow-Origin")) + assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) + assert.Equal(t, "Origin", w.Header().Get("Vary")) + }) + + t.Run("DisallowedOriginRejected", func(t *testing.T) { + t.Parallel() + mockSvc := new(dashboardServiceMock) + h := NewDashboardHandler(mockSvc, "https://dash.example.com") + + w := runStream(t, h, "https://attacker.example.com") + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + mockSvc.AssertNotCalled(t, "GetSummary", mock.Anything) + }) + + t.Run("EmptyAllowlistRejectsCrossOrigin", func(t *testing.T) { + t.Parallel() + mockSvc := new(dashboardServiceMock) + h := NewDashboardHandler(mockSvc) + + w := runStream(t, h, "https://attacker.example.com") + + assert.Equal(t, http.StatusForbidden, w.Code) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + mockSvc.AssertNotCalled(t, "GetSummary", mock.Anything) + }) + + t.Run("WildcardEchoesOriginNotStar", func(t *testing.T) { + t.Parallel() + mockSvc := new(dashboardServiceMock) + mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil) + h := NewDashboardHandler(mockSvc, "*") + + w := runStream(t, h, "https://anything.example.com") + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "https://anything.example.com", w.Header().Get("Access-Control-Allow-Origin")) + }) + + t.Run("SameOriginEmitsNoCORSHeader", func(t *testing.T) { + t.Parallel() + mockSvc := new(dashboardServiceMock) + mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil) + h := NewDashboardHandler(mockSvc, "https://dash.example.com") + + w := runStream(t, h, "") + + assert.Equal(t, http.StatusOK, w.Code) + assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + }) } func TestDashboardHandlerGetRecentEventsLimits(t *testing.T) {