Skip to content

Commit 81b3f7f

Browse files
authored
refactor: stop rewriting paths in proxy (#15)
This change also properly handles redirects from the downstream MCP server and rewrites them to point to the proxy. Signed-off-by: Donnie Adams <donnie@obot.ai>
1 parent 8333411 commit 81b3f7f

File tree

6 files changed

+72
-38
lines changed

6 files changed

+72
-38
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
271271
OAUTH_CLIENT_ID: "your-oauth-client-id"
272272
OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
273273
OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
274-
MCP_SERVER_URL: "http://localhost:3000/mcp"
274+
MCP_SERVER_URL: "http://localhost:3000"
275275
ENCRYPTION_KEY: "your-base64-encoded-32-byte-key"
276276
ports:
277277
- "8080:8080"
@@ -295,7 +295,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
295295
OAUTH_CLIENT_ID: "your-oauth-client-id"
296296
OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
297297
OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
298-
MCP_SERVER_URL: "http://localhost:3000/mcp"
298+
MCP_SERVER_URL: "http://localhost:3000"
299299
volumes:
300300
- ./data:/app/data # Persist SQLite database
301301
ports:
@@ -342,7 +342,7 @@ You can customize the scopes based on your needs. Common additional scopes inclu
342342

343343
### MCP Proxy
344344

345-
- `ANY /mcp/*` - Proxies requests to MCP server with user context headers
345+
- `ANY /*` - Proxies any request not mentioned above to MCP server with user context headers
346346

347347
## OAuth Flow
348348

docker-compose.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ services:
1919
# OAUTH_CLIENT_ID: "your-oauth-client-id"
2020
# OAUTH_CLIENT_SECRET: "your-oauth-client-secret"
2121
# OAUTH_AUTHORIZE_URL: "https://your-oauth-provider.com/oauth/authorize"
22-
# MCP_SERVER_URL: "http://localhost:3000/mcp"
22+
# MCP_SERVER_URL: "http://localhost:3000"
2323
# ports:
2424
# - "8080:8080"
2525
# depends_on:

main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@ import (
99

1010
func main() {
1111
// Load configuration from environment variables
12-
config := proxy.LoadConfigFromEnv()
12+
config, err := proxy.LoadConfigFromEnv()
13+
if err != nil {
14+
log.Fatalf("Failed to load configuration: %v", err)
15+
}
1316

1417
proxy, err := proxy.NewOAuthProxy(config)
1518
if err != nil {

main_test.go

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package main
22

33
import (
44
"context"
5+
"log"
56
"net/http"
67
"net/http/httptest"
78
"os"
@@ -61,7 +62,10 @@ func TestIntegrationFlow(t *testing.T) {
6162
}()
6263

6364
// Create OAuth proxy
64-
config := proxy.LoadConfigFromEnv()
65+
config, err := proxy.LoadConfigFromEnv()
66+
if err != nil {
67+
log.Fatalf("Failed to load configuration: %v", err)
68+
}
6569
oauthProxy, err := proxy.NewOAuthProxy(config)
6670
if err != nil {
6771
t.Skipf("Skipping test due to database connection error: %v", err)
@@ -99,7 +103,7 @@ func TestIntegrationFlow(t *testing.T) {
99103
// Test protected resource metadata
100104
t.Run("ProtectedResourceMetadata", func(t *testing.T) {
101105
w := httptest.NewRecorder()
102-
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource/mcp", nil)
106+
req := httptest.NewRequest("GET", "/.well-known/oauth-protected-resource", nil)
103107
handler.ServeHTTP(w, req)
104108

105109
assert.Equal(t, http.StatusOK, w.Code)
@@ -165,7 +169,10 @@ func TestOAuthProxyCreation(t *testing.T) {
165169
}()
166170

167171
// Create OAuth proxy
168-
config := proxy.LoadConfigFromEnv()
172+
config, err := proxy.LoadConfigFromEnv()
173+
if err != nil {
174+
log.Fatalf("Failed to load configuration: %v", err)
175+
}
169176
oauthProxy, err := proxy.NewOAuthProxy(config)
170177
require.NoError(t, err, "Should be able to create OAuth proxy with valid environment")
171178
require.NotNil(t, oauthProxy, "OAuth proxy should not be nil")
@@ -215,7 +222,10 @@ func TestOAuthProxyStart(t *testing.T) {
215222
}()
216223

217224
// Create OAuth proxy
218-
config := proxy.LoadConfigFromEnv()
225+
config, err := proxy.LoadConfigFromEnv()
226+
if err != nil {
227+
log.Fatalf("Failed to load configuration: %v", err)
228+
}
219229
oauthProxy, err := proxy.NewOAuthProxy(config)
220230
require.NoError(t, err)
221231
defer func() {

pkg/oauth/validate/validatetoken.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
2828
authHeader := r.Header.Get("Authorization")
2929
if authHeader == "" {
3030
// Return 401 with proper WWW-Authenticate header
31-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
31+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
3232
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Missing Authorization header", resource_metadata="%s"`, resourceMetadataUrl)
3333
w.Header().Set("WWW-Authenticate", wwwAuthValue)
3434
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -41,7 +41,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
4141
// Parse Authorization header
4242
parts := strings.SplitN(authHeader, " ", 2)
4343
if len(parts) != 2 || strings.ToLower(parts[0]) != "bearer" || parts[1] == "" {
44-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
44+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
4545
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Invalid Authorization header format, expected 'Bearer TOKEN'", resource_metadata="%s"`, resourceMetadataUrl)
4646
w.Header().Set("WWW-Authenticate", wwwAuthValue)
4747
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -55,7 +55,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
5555

5656
tokenInfo, err := p.tokenManager.GetTokenInfo(token)
5757
if err != nil {
58-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
58+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
5959
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Invalid or expired token", resource_metadata="%s"`, resourceMetadataUrl)
6060
w.Header().Set("WWW-Authenticate", wwwAuthValue)
6161
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{
@@ -69,7 +69,7 @@ func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.Handler
6969
if tokenInfo.Props != nil {
7070
decryptedProps, err := encryption.DecryptPropsIfNeeded(p.encryptionKey, tokenInfo.Props)
7171
if err != nil {
72-
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource/mcp", handlerutils.GetBaseURL(r))
72+
resourceMetadataUrl := fmt.Sprintf("%s/.well-known/oauth-protected-resource", handlerutils.GetBaseURL(r))
7373
wwwAuthValue := fmt.Sprintf(`Bearer error="invalid_token", error_description="Failed to decrypt token data", resource_metadata="%s"`, resourceMetadataUrl)
7474
w.Header().Set("WWW-Authenticate", wwwAuthValue)
7575
handlerutils.JSON(w, http.StatusUnauthorized, map[string]string{

pkg/proxy/proxy.go

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,8 @@ type OAuthProxy struct {
4848
}
4949

5050
// LoadConfigFromEnv loads configuration from environment variables
51-
func LoadConfigFromEnv() *types.Config {
52-
return &types.Config{
51+
func LoadConfigFromEnv() (*types.Config, error) {
52+
config := &types.Config{
5353
DatabaseDSN: os.Getenv("DATABASE_DSN"),
5454
OAuthClientID: os.Getenv("OAUTH_CLIENT_ID"),
5555
OAuthClientSecret: os.Getenv("OAUTH_CLIENT_SECRET"),
@@ -58,6 +58,14 @@ func LoadConfigFromEnv() *types.Config {
5858
EncryptionKey: os.Getenv("ENCRYPTION_KEY"),
5959
MCPServerURL: os.Getenv("MCP_SERVER_URL"),
6060
}
61+
62+
if u, err := url.Parse(config.MCPServerURL); err != nil || u.Scheme != "http" && u.Scheme != "https" {
63+
return nil, fmt.Errorf("invalid MCP server URL: %w", err)
64+
} else if u.Path != "" && u.Path != "/" || u.RawQuery != "" || u.Fragment != "" {
65+
return nil, fmt.Errorf("MCP server URL must not contain a path, query, or fragment")
66+
}
67+
68+
return config, nil
6169
}
6270

6371
func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) {
@@ -194,22 +202,21 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
194202
revokeHandler := revoke.NewHandler(p.db)
195203
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.encryptionKey)
196204

197-
mux.HandleFunc("/health", p.withCORS(p.healthHandler))
205+
mux.HandleFunc("GET /health", p.withCORS(p.healthHandler))
198206

199207
// OAuth endpoints
200-
mux.HandleFunc("/authorize", p.withCORS(p.withRateLimit(authorizeHandler)))
201-
mux.HandleFunc("/callback", p.withCORS(p.withRateLimit(callbackHandler)))
202-
mux.HandleFunc("/token", p.withCORS(p.withRateLimit(tokenHandler)))
203-
mux.HandleFunc("/revoke", p.withCORS(p.withRateLimit(revokeHandler)))
204-
mux.HandleFunc("/register", p.withCORS(p.withRateLimit(register.NewHandler(p.db))))
208+
mux.HandleFunc("GET /authorize", p.withCORS(p.withRateLimit(authorizeHandler)))
209+
mux.HandleFunc("GET /callback", p.withCORS(p.withRateLimit(callbackHandler)))
210+
mux.HandleFunc("POST /token", p.withCORS(p.withRateLimit(tokenHandler)))
211+
mux.HandleFunc("POST /revoke", p.withCORS(p.withRateLimit(revokeHandler)))
212+
mux.HandleFunc("POST /register", p.withCORS(p.withRateLimit(register.NewHandler(p.db))))
205213

206214
// Metadata endpoints
207-
mux.HandleFunc("/.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler))
208-
mux.HandleFunc("/.well-known/oauth-protected-resource/mcp", p.withCORS(p.protectedResourceMetadataHandler))
215+
mux.HandleFunc("GET /.well-known/oauth-authorization-server", p.withCORS(p.oauthMetadataHandler))
216+
mux.HandleFunc("GET /.well-known/oauth-protected-resource", p.withCORS(p.protectedResourceMetadataHandler))
209217

210-
// Protected resource endpoints
211-
mux.HandleFunc("/mcp", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
212-
mux.HandleFunc("/mcp/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
218+
// Protect everything else
219+
mux.HandleFunc("/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
213220
}
214221

215222
// GetHandler returns an http.Handler for the OAuth proxy
@@ -289,9 +296,10 @@ func (p *OAuthProxy) oauthMetadataHandler(w http.ResponseWriter, r *http.Request
289296
}
290297

291298
func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r *http.Request) {
299+
baseURL := handlerutils.GetBaseURL(r)
292300
metadata := types.OAuthProtectedResourceMetadata{
293-
Resource: fmt.Sprintf("%s/mcp", handlerutils.GetBaseURL(r)),
294-
AuthorizationServers: []string{handlerutils.GetBaseURL(r)},
301+
Resource: baseURL,
302+
AuthorizationServers: []string{baseURL},
295303
Scopes: p.metadata.ScopesSupported,
296304
ResourceName: p.resourceName,
297305
ResourceDocumentation: p.metadata.ServiceDocumentation,
@@ -387,15 +395,7 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
387395
}
388396

389397
// Create target URL
390-
var targetURL string
391-
if path == "" {
392-
// If no path is provided, use the MCP server URL directly
393-
targetURL = p.GetMCPServerURL()
394-
} else {
395-
// If path is provided, append it to the MCP server URL
396-
targetURL = p.GetMCPServerURL() + "/" + path
397-
}
398-
398+
targetURL := p.GetMCPServerURL() + "/" + path
399399
// Log the proxy request for debugging
400400
log.Printf("Proxying request: %s %s -> %s", r.Method, r.URL.Path, targetURL)
401401

@@ -404,12 +404,12 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
404404
Director: func(req *http.Request) {
405405
req.Header.Del("Authorization")
406406
req.Header.Set("X-Forwarded-Host", req.Host)
407+
req.Header.Set("X-Forwarded-Proto", req.URL.Scheme)
407408

408409
newURL, _ := url.Parse(targetURL)
409410
req.URL.Scheme = newURL.Scheme
410411
req.URL.Host = newURL.Host
411412
req.Host = newURL.Host
412-
req.URL.Path = newURL.Path
413413

414414
// Add forwarded headers from token props
415415
if tokenInfo.Props != nil {
@@ -427,6 +427,27 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
427427
}
428428
}
429429
},
430+
ModifyResponse: func(resp *http.Response) error {
431+
// Rewrite Location header to use proxy host instead of downstream server host
432+
if location := resp.Header.Get("Location"); location != "" {
433+
if locationURL, err := url.Parse(location); err == nil {
434+
// Get the original request to extract proxy host
435+
proxyHost := resp.Request.Header.Get("X-Forwarded-Host")
436+
if proxyHost != "" {
437+
// Parse downstream server URL to get scheme
438+
downstreamURL, _ := url.Parse(p.GetMCPServerURL())
439+
440+
// Only rewrite if the location points to the downstream server
441+
if locationURL.Host == downstreamURL.Host {
442+
locationURL.Scheme = resp.Request.URL.Scheme
443+
locationURL.Host = proxyHost
444+
resp.Header.Set("Location", locationURL.String())
445+
}
446+
}
447+
}
448+
}
449+
return nil
450+
},
430451
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, err error) {
431452
log.Printf("Proxy error: %v", err)
432453
rw.WriteHeader(http.StatusBadGateway)

0 commit comments

Comments
 (0)