Skip to content

Commit 95fc784

Browse files
authored
feat: add forward_auth support (#16)
Signed-off-by: Donnie Adams <donnie@obot.ai>
1 parent 81b3f7f commit 95fc784

File tree

7 files changed

+857
-65
lines changed

7 files changed

+857
-65
lines changed

main.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,6 @@ func main() {
2828
handler := proxy.GetHandler()
2929

3030
// Start server
31-
log.Printf("Starting OAuth proxy server on localhost:8080")
32-
log.Fatal(http.ListenAndServe(":8080", handler))
31+
log.Printf("Starting OAuth proxy server on localhost:" + config.Port)
32+
log.Fatal(http.ListenAndServe(":"+config.Port, handler))
3333
}

main_test.go

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,17 @@ func TestIntegrationFlow(t *testing.T) {
121121
assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer")
122122
})
123123

124+
// Test protected resource metadata with path parameter
125+
t.Run("ProtectedResourceMetadataWithPath", func(t *testing.T) {
126+
w := httptest.NewRecorder()
127+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource/test/path", nil)
128+
handler.ServeHTTP(w, req)
129+
130+
assert.Equal(t, http.StatusOK, w.Code)
131+
assert.Contains(t, w.Body.String(), "resource")
132+
assert.Contains(t, w.Body.String(), "authorization_servers")
133+
})
134+
124135
// Test authorization endpoint redirects (basic validation)
125136
t.Run("AuthorizationEndpointRedirect", func(t *testing.T) {
126137
w := httptest.NewRecorder()
@@ -259,3 +270,161 @@ func TestOAuthProxyStart(t *testing.T) {
259270
t.Fatal("Server did not stop within timeout")
260271
}
261272
}
273+
274+
func TestForwardAuthIntegrationFlow(t *testing.T) {
275+
// Skip if running in short mode
276+
if testing.Short() {
277+
t.Skip("Skipping integration tests in short mode")
278+
}
279+
280+
// Set required environment variables for forward auth testing
281+
oldVars := map[string]string{
282+
"OAUTH_CLIENT_ID": os.Getenv("OAUTH_CLIENT_ID"),
283+
"OAUTH_CLIENT_SECRET": os.Getenv("OAUTH_CLIENT_SECRET"),
284+
"OAUTH_AUTHORIZE_URL": os.Getenv("OAUTH_AUTHORIZE_URL"),
285+
"SCOPES_SUPPORTED": os.Getenv("SCOPES_SUPPORTED"),
286+
"MCP_SERVER_URL": os.Getenv("MCP_SERVER_URL"),
287+
"DATABASE_DSN": os.Getenv("DATABASE_DSN"),
288+
"PROXY_MODE": os.Getenv("PROXY_MODE"),
289+
"PORT": os.Getenv("PORT"),
290+
}
291+
292+
// Set test environment variables for forward auth mode
293+
testEnvVars := map[string]string{
294+
"OAUTH_CLIENT_ID": "test_client_id",
295+
"OAUTH_CLIENT_SECRET": "test_client_secret",
296+
"OAUTH_AUTHORIZE_URL": "https://accounts.google.com",
297+
"SCOPES_SUPPORTED": "openid,profile,email",
298+
"PROXY_MODE": "forward_auth",
299+
"PORT": "8082", // Different port to avoid conflicts
300+
"DATABASE_DSN": os.Getenv("TEST_DATABASE_DSN"), // Use test database if available
301+
}
302+
303+
for key, value := range testEnvVars {
304+
if value != "" {
305+
if err := os.Setenv(key, value); err != nil {
306+
t.Logf("Failed to set %s: %v", key, err)
307+
}
308+
}
309+
}
310+
311+
// Restore environment variables after test
312+
defer func() {
313+
for key, value := range oldVars {
314+
if value != "" {
315+
_ = os.Setenv(key, value)
316+
} else {
317+
_ = os.Unsetenv(key)
318+
}
319+
}
320+
}()
321+
322+
// Create OAuth proxy in forward auth mode
323+
config, err := proxy.LoadConfigFromEnv()
324+
if err != nil {
325+
log.Fatalf("Failed to load configuration: %v", err)
326+
}
327+
require.Equal(t, "forward_auth", config.Mode)
328+
require.Equal(t, "8082", config.Port)
329+
330+
oauthProxy, err := proxy.NewOAuthProxy(config)
331+
if err != nil {
332+
t.Skipf("Skipping test due to database connection error: %v", err)
333+
}
334+
defer func() {
335+
if err := oauthProxy.Close(); err != nil {
336+
t.Logf("Error closing OAuth proxy: %v", err)
337+
}
338+
}()
339+
340+
// Get HTTP handler
341+
handler := oauthProxy.GetHandler()
342+
343+
// Test health endpoint works in forward auth mode
344+
t.Run("ForwardAuthHealthEndpoint", func(t *testing.T) {
345+
w := httptest.NewRecorder()
346+
req := httptest.NewRequest("GET", "/health", nil)
347+
handler.ServeHTTP(w, req)
348+
349+
assert.Equal(t, http.StatusOK, w.Code)
350+
assert.Contains(t, w.Body.String(), "ok")
351+
})
352+
353+
// Test OAuth metadata endpoints work in forward auth mode
354+
t.Run("ForwardAuthOAuthMetadata", func(t *testing.T) {
355+
w := httptest.NewRecorder()
356+
req := httptest.NewRequest("GET", "/.well-known/oauth-authorization-server", nil)
357+
handler.ServeHTTP(w, req)
358+
359+
assert.Equal(t, http.StatusOK, w.Code)
360+
assert.Contains(t, w.Body.String(), "authorization_endpoint")
361+
assert.Contains(t, w.Body.String(), "token_endpoint")
362+
})
363+
364+
// Test protected resource metadata with path works in forward auth mode
365+
t.Run("ForwardAuthProtectedResourceMetadataWithPath", func(t *testing.T) {
366+
w := httptest.NewRecorder()
367+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource/api/v1", nil)
368+
handler.ServeHTTP(w, req)
369+
370+
assert.Equal(t, http.StatusOK, w.Code)
371+
assert.Contains(t, w.Body.String(), "resource")
372+
assert.Contains(t, w.Body.String(), "authorization_servers")
373+
})
374+
375+
// Test that forward auth mode requires authorization for protected endpoints
376+
t.Run("ForwardAuthRequiresAuth", func(t *testing.T) {
377+
testPaths := []string{"/api", "/data", "/protected", "/mcp", "/test"}
378+
379+
for _, path := range testPaths {
380+
t.Run("Path_"+path, func(t *testing.T) {
381+
w := httptest.NewRecorder()
382+
req := httptest.NewRequest("GET", path, nil)
383+
handler.ServeHTTP(w, req)
384+
385+
assert.Equal(t, http.StatusUnauthorized, w.Code)
386+
assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer")
387+
})
388+
}
389+
})
390+
391+
// Test that forward auth mode doesn't proxy to MCP server (no proxying behavior)
392+
t.Run("ForwardAuthNoProxying", func(t *testing.T) {
393+
// In forward auth mode, there should be no attempt to proxy to an MCP server
394+
// Instead, the proxy should validate tokens and set headers for downstream services
395+
w := httptest.NewRecorder()
396+
req := httptest.NewRequest("GET", "/api/test", nil)
397+
handler.ServeHTTP(w, req)
398+
399+
// Should get unauthorized (no proxying attempt)
400+
assert.Equal(t, http.StatusUnauthorized, w.Code)
401+
assert.Contains(t, w.Header().Get("WWW-Authenticate"), "Bearer")
402+
403+
// Should not have any proxy-related error messages
404+
assert.NotContains(t, w.Body.String(), "proxy")
405+
assert.NotContains(t, w.Body.String(), "502")
406+
assert.NotContains(t, w.Body.String(), "Bad Gateway")
407+
})
408+
409+
// Test authorization endpoint redirects work in forward auth mode
410+
t.Run("ForwardAuthAuthorizationEndpointRedirect", func(t *testing.T) {
411+
w := httptest.NewRecorder()
412+
req := httptest.NewRequest("GET", "/authorize?response_type=code&client_id=test&redirect_uri=http://localhost:8082/callback&scope=openid", nil)
413+
handler.ServeHTTP(w, req)
414+
415+
// Should get a redirect to the OAuth provider or an error about invalid client
416+
assert.True(t, w.Code == http.StatusFound || w.Code == http.StatusBadRequest)
417+
})
418+
419+
// Test CORS headers work in forward auth mode
420+
t.Run("ForwardAuthCORSHeaders", func(t *testing.T) {
421+
w := httptest.NewRecorder()
422+
req := httptest.NewRequest("OPTIONS", "/api", nil)
423+
req.Header.Set("Origin", "https://example.com")
424+
req.Header.Set("Access-Control-Request-Method", "GET")
425+
handler.ServeHTTP(w, req)
426+
427+
// Should handle CORS preflight
428+
assert.Contains(t, w.Header().Get("Access-Control-Allow-Origin"), "*")
429+
})
430+
}

pkg/handlerutils/handlerutils.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ func JSON(w http.ResponseWriter, statusCode int, obj any) {
2020
"error_description": "Failed to encode JSON response",
2121
"error_detail": err.Error(),
2222
})
23-
w.WriteHeader(http.StatusInternalServerError)
2423
_, _ = w.Write(errText)
2524
}
2625
}
@@ -52,6 +51,10 @@ func GetClientIP(r *http.Request) string {
5251
// GetBaseURL returns the URL of the request without the path and
5352
// infers the scheme (http or https)
5453
func GetBaseURL(r *http.Request) string {
54+
if url := r.Header.Get("X-Mcp-Oauth-Proxy-URL"); url != "" {
55+
return url
56+
}
57+
5558
scheme := "http"
5659
if r.TLS != nil || r.Header.Get("X-Forwarded-Proto") == "https" {
5760
scheme = "https"

pkg/oauth/register/register.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
135135
response["client_secret_expires_at"] = 0 // Never expires
136136
}
137137

138-
handlerutils.JSON(w, http.StatusCreated, response)
138+
handlerutils.JSON(w, http.StatusOK, response)
139139
}
140140

141141
func (p *Handler) validateClientMetadata(metadata map[string]any) (*types.ClientInfo, error) {

0 commit comments

Comments
 (0)