Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/api/setup/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func InitHandlers(svcs *Services, cfg *platform.Config, logger *slog.Logger) *Ha
Event: httphandlers.NewEventHandler(svcs.Event),
Volume: httphandlers.NewVolumeHandler(svcs.Volume),
LB: httphandlers.NewLBHandler(svcs.LB),
Dashboard: httphandlers.NewDashboardHandler(svcs.Dashboard),
Dashboard: httphandlers.NewDashboardHandler(svcs.Dashboard, cfg.WSAllowedOrigins),
RBAC: httphandlers.NewRBACHandler(svcs.RBAC),
Snapshot: httphandlers.NewSnapshotHandler(svcs.Snapshot),
Stack: httphandlers.NewStackHandler(svcs.Stack),
Expand Down
77 changes: 73 additions & 4 deletions internal/handlers/dashboard_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package httphandlers
import (
"net/http"
"strconv"
"strings"
"time"

"github.com/gin-gonic/gin"
Expand All @@ -13,12 +14,30 @@ import (

// DashboardHandler handles dashboard API endpoints.
type DashboardHandler struct {
svc ports.DashboardService
svc ports.DashboardService
allowedOrigins []string
}

// NewDashboardHandler creates a new dashboard handler.
func NewDashboardHandler(svc ports.DashboardService) *DashboardHandler {
return &DashboardHandler{svc: svc}
//
// allowedOrigins is the explicit allowlist of cross-origin Origin values
// permitted to subscribe to the SSE stream. Accepts either repeated entries
// or a single comma-separated string (matching the WS_ALLOWED_ORIGINS form).
// An empty list means same-origin only — no Access-Control-Allow-Origin
// header is emitted, so browsers fall back to same-origin enforcement. The
// literal "*" entry opts into permissive mode for non-browser clients; even
// then the response echoes the request Origin rather than "*", because the
// SSE EventSource API ignores wildcards when credentials are involved. See #347.
Comment on lines +28 to +30
func NewDashboardHandler(svc ports.DashboardService, allowedOrigins ...string) *DashboardHandler {
cleaned := make([]string, 0, len(allowedOrigins))
for _, raw := range allowedOrigins {
for _, o := range strings.Split(raw, ",") {
if trimmed := strings.TrimSpace(o); trimmed != "" {
cleaned = append(cleaned, trimmed)
}
}
}
return &DashboardHandler{svc: svc, allowedOrigins: cleaned}
}

// GetSummary returns resource counts and overview metrics.
Expand Down Expand Up @@ -98,10 +117,16 @@ func (h *DashboardHandler) GetStats(c *gin.Context) {
// @Security APIKeyAuth
// @Router /api/dashboard/stream [get]
func (h *DashboardHandler) StreamEvents(c *gin.Context) {
// Enforce CORS before any streaming headers are written. SSE inherits the
// caller's cookies/API key, so accepting "*" would let an attacker-hosted
// page receive a logged-in user's events.
if !h.applyCORS(c) {
return
}

c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")

ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
Expand All @@ -127,3 +152,47 @@ func (h *DashboardHandler) StreamEvents(c *gin.Context) {
}
}
}

// applyCORS validates the request Origin and, if accepted, writes the
// CORS response headers. Returns false (after writing 403) when the origin
// is set but not allowed; the caller must abort.
//
// This endpoint takes authoritative control of its CORS headers, clearing
// the permissive wildcard set by httputil.CORS() upstream so that an
// authenticated SSE stream can never be subscribed to by a third-party origin.
//
// Behaviour:
// - Empty Origin → same-origin or non-browser request; clear wildcard
// defaults and emit no Access-Control-* headers (browsers default to
// same-origin).
// - Origin in allowlist → echo it back with Vary: Origin and credentials.
// - "*" in allowlist → echo the request Origin (never literal "*", which
// EventSource rejects when credentials are involved).
// - Otherwise → 403.
func (h *DashboardHandler) applyCORS(c *gin.Context) bool {
headers := c.Writer.Header()
headers.Del("Access-Control-Allow-Origin")
headers.Del("Access-Control-Allow-Credentials")

origin := c.GetHeader("Origin")
if origin == "" {
return true
}

allowed := false
for _, o := range h.allowedOrigins {
if o == "*" || o == origin {
allowed = true
break
}
}
if !allowed {
c.AbortWithStatus(http.StatusForbidden)
return false
}
Comment on lines +177 to +192

headers.Set("Access-Control-Allow-Origin", origin)
headers.Set("Access-Control-Allow-Credentials", "true")
headers.Set("Vary", "Origin")
return true
Comment on lines +194 to +197
}
94 changes: 94 additions & 0 deletions internal/handlers/dashboard_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,100 @@ func TestDashboardHandlerStreamEvents(t *testing.T) {

assert.Contains(t, w.Header().Get("Content-Type"), "text/event-stream")
assert.Contains(t, w.Body.String(), "event:summary")
// Same-origin (no Origin header) → must not emit a wildcard CORS header. See #347.
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
}

func TestDashboardHandlerStreamEventsCORS(t *testing.T) {
t.Parallel()

runStream := func(t *testing.T, h *DashboardHandler, origin string) *httptest.ResponseRecorder {
t.Helper()
gin.SetMode(gin.TestMode)
r := gin.New()
r.GET("/stream", h.StreamEvents)

req, err := http.NewRequest("GET", "/stream", nil)
require.NoError(t, err)
if origin != "" {
req.Header.Set("Origin", origin)
}
ctx, cancel := context.WithCancel(context.Background())
req = req.WithContext(ctx)

w := httptest.NewRecorder()
done := make(chan struct{})
go func() {
r.ServeHTTP(w, req)
close(done)
}()
time.Sleep(50 * time.Millisecond)
cancel()
<-done
return w
}

t.Run("AllowedOriginIsEchoed", func(t *testing.T) {
t.Parallel()
mockSvc := new(dashboardServiceMock)
mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil)
h := NewDashboardHandler(mockSvc, "https://dash.example.com,https://other.example.com")

w := runStream(t, h, "https://dash.example.com")

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "https://dash.example.com", w.Header().Get("Access-Control-Allow-Origin"))
assert.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials"))
assert.Equal(t, "Origin", w.Header().Get("Vary"))
})

t.Run("DisallowedOriginRejected", func(t *testing.T) {
t.Parallel()
mockSvc := new(dashboardServiceMock)
h := NewDashboardHandler(mockSvc, "https://dash.example.com")

w := runStream(t, h, "https://attacker.example.com")

assert.Equal(t, http.StatusForbidden, w.Code)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
mockSvc.AssertNotCalled(t, "GetSummary", mock.Anything)
})

t.Run("EmptyAllowlistRejectsCrossOrigin", func(t *testing.T) {
t.Parallel()
mockSvc := new(dashboardServiceMock)
h := NewDashboardHandler(mockSvc)

w := runStream(t, h, "https://attacker.example.com")

assert.Equal(t, http.StatusForbidden, w.Code)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
mockSvc.AssertNotCalled(t, "GetSummary", mock.Anything)
})

t.Run("WildcardEchoesOriginNotStar", func(t *testing.T) {
t.Parallel()
mockSvc := new(dashboardServiceMock)
mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil)
h := NewDashboardHandler(mockSvc, "*")

w := runStream(t, h, "https://anything.example.com")

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "https://anything.example.com", w.Header().Get("Access-Control-Allow-Origin"))
})

t.Run("SameOriginEmitsNoCORSHeader", func(t *testing.T) {
t.Parallel()
mockSvc := new(dashboardServiceMock)
mockSvc.On("GetSummary", mock.Anything).Return(&domain.ResourceSummary{}, nil)
h := NewDashboardHandler(mockSvc, "https://dash.example.com")

w := runStream(t, h, "")

assert.Equal(t, http.StatusOK, w.Code)
assert.Empty(t, w.Header().Get("Access-Control-Allow-Origin"))
})
}

func TestDashboardHandlerGetRecentEventsLimits(t *testing.T) {
Expand Down
Loading