@@ -48,8 +48,8 @@ type OAuthProxy struct {
4848}
4949
5050// LoadConfigFromEnv loads configuration from environment variables
51- func LoadConfigFromEnv () * types.Config {
52- return & types.Config {
51+ func LoadConfigFromEnv () ( * types.Config , error ) {
52+ config := & types.Config {
5353 DatabaseDSN : os .Getenv ("DATABASE_DSN" ),
5454 OAuthClientID : os .Getenv ("OAUTH_CLIENT_ID" ),
5555 OAuthClientSecret : os .Getenv ("OAUTH_CLIENT_SECRET" ),
@@ -58,6 +58,14 @@ func LoadConfigFromEnv() *types.Config {
5858 EncryptionKey : os .Getenv ("ENCRYPTION_KEY" ),
5959 MCPServerURL : os .Getenv ("MCP_SERVER_URL" ),
6060 }
61+
62+ if u , err := url .Parse (config .MCPServerURL ); err != nil || u .Scheme != "http" && u .Scheme != "https" {
63+ return nil , fmt .Errorf ("invalid MCP server URL: %w" , err )
64+ } else if u .Path != "" && u .Path != "/" || u .RawQuery != "" || u .Fragment != "" {
65+ return nil , fmt .Errorf ("MCP server URL must not contain a path, query, or fragment" )
66+ }
67+
68+ return config , nil
6169}
6270
6371func NewOAuthProxy (config * types.Config ) (* OAuthProxy , error ) {
@@ -194,22 +202,21 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
194202 revokeHandler := revoke .NewHandler (p .db )
195203 tokenValidator := validate .NewTokenValidator (p .tokenManager , p .encryptionKey )
196204
197- mux .HandleFunc ("/health" , p .withCORS (p .healthHandler ))
205+ mux .HandleFunc ("GET /health" , p .withCORS (p .healthHandler ))
198206
199207 // OAuth endpoints
200- mux .HandleFunc ("/authorize" , p .withCORS (p .withRateLimit (authorizeHandler )))
201- mux .HandleFunc ("/callback" , p .withCORS (p .withRateLimit (callbackHandler )))
202- mux .HandleFunc ("/token" , p .withCORS (p .withRateLimit (tokenHandler )))
203- mux .HandleFunc ("/revoke" , p .withCORS (p .withRateLimit (revokeHandler )))
204- mux .HandleFunc ("/register" , p .withCORS (p .withRateLimit (register .NewHandler (p .db ))))
208+ mux .HandleFunc ("GET /authorize" , p .withCORS (p .withRateLimit (authorizeHandler )))
209+ mux .HandleFunc ("GET /callback" , p .withCORS (p .withRateLimit (callbackHandler )))
210+ mux .HandleFunc ("POST /token" , p .withCORS (p .withRateLimit (tokenHandler )))
211+ mux .HandleFunc ("POST /revoke" , p .withCORS (p .withRateLimit (revokeHandler )))
212+ mux .HandleFunc ("POST /register" , p .withCORS (p .withRateLimit (register .NewHandler (p .db ))))
205213
206214 // Metadata endpoints
207- mux .HandleFunc ("/.well-known/oauth-authorization-server" , p .withCORS (p .oauthMetadataHandler ))
208- mux .HandleFunc ("/.well-known/oauth-protected-resource/mcp " , p .withCORS (p .protectedResourceMetadataHandler ))
215+ mux .HandleFunc ("GET /.well-known/oauth-authorization-server" , p .withCORS (p .oauthMetadataHandler ))
216+ mux .HandleFunc ("GET /.well-known/oauth-protected-resource" , p .withCORS (p .protectedResourceMetadataHandler ))
209217
210- // Protected resource endpoints
211- mux .HandleFunc ("/mcp" , p .withCORS (p .withRateLimit (tokenValidator .WithTokenValidation (p .mcpProxyHandler ))))
212- mux .HandleFunc ("/mcp/{path...}" , p .withCORS (p .withRateLimit (tokenValidator .WithTokenValidation (p .mcpProxyHandler ))))
218+ // Protect everything else
219+ mux .HandleFunc ("/{path...}" , p .withCORS (p .withRateLimit (tokenValidator .WithTokenValidation (p .mcpProxyHandler ))))
213220}
214221
215222// GetHandler returns an http.Handler for the OAuth proxy
@@ -289,9 +296,10 @@ func (p *OAuthProxy) oauthMetadataHandler(w http.ResponseWriter, r *http.Request
289296}
290297
291298func (p * OAuthProxy ) protectedResourceMetadataHandler (w http.ResponseWriter , r * http.Request ) {
299+ baseURL := handlerutils .GetBaseURL (r )
292300 metadata := types.OAuthProtectedResourceMetadata {
293- Resource : fmt . Sprintf ( "%s/mcp" , handlerutils . GetBaseURL ( r )) ,
294- AuthorizationServers : []string {handlerutils . GetBaseURL ( r ) },
301+ Resource : baseURL ,
302+ AuthorizationServers : []string {baseURL },
295303 Scopes : p .metadata .ScopesSupported ,
296304 ResourceName : p .resourceName ,
297305 ResourceDocumentation : p .metadata .ServiceDocumentation ,
@@ -387,15 +395,7 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
387395 }
388396
389397 // Create target URL
390- var targetURL string
391- if path == "" {
392- // If no path is provided, use the MCP server URL directly
393- targetURL = p .GetMCPServerURL ()
394- } else {
395- // If path is provided, append it to the MCP server URL
396- targetURL = p .GetMCPServerURL () + "/" + path
397- }
398-
398+ targetURL := p .GetMCPServerURL () + "/" + path
399399 // Log the proxy request for debugging
400400 log .Printf ("Proxying request: %s %s -> %s" , r .Method , r .URL .Path , targetURL )
401401
@@ -404,12 +404,12 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
404404 Director : func (req * http.Request ) {
405405 req .Header .Del ("Authorization" )
406406 req .Header .Set ("X-Forwarded-Host" , req .Host )
407+ req .Header .Set ("X-Forwarded-Proto" , req .URL .Scheme )
407408
408409 newURL , _ := url .Parse (targetURL )
409410 req .URL .Scheme = newURL .Scheme
410411 req .URL .Host = newURL .Host
411412 req .Host = newURL .Host
412- req .URL .Path = newURL .Path
413413
414414 // Add forwarded headers from token props
415415 if tokenInfo .Props != nil {
@@ -427,6 +427,27 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
427427 }
428428 }
429429 },
430+ ModifyResponse : func (resp * http.Response ) error {
431+ // Rewrite Location header to use proxy host instead of downstream server host
432+ if location := resp .Header .Get ("Location" ); location != "" {
433+ if locationURL , err := url .Parse (location ); err == nil {
434+ // Get the original request to extract proxy host
435+ proxyHost := resp .Request .Header .Get ("X-Forwarded-Host" )
436+ if proxyHost != "" {
437+ // Parse downstream server URL to get scheme
438+ downstreamURL , _ := url .Parse (p .GetMCPServerURL ())
439+
440+ // Only rewrite if the location points to the downstream server
441+ if locationURL .Host == downstreamURL .Host {
442+ locationURL .Scheme = resp .Request .URL .Scheme
443+ locationURL .Host = proxyHost
444+ resp .Header .Set ("Location" , locationURL .String ())
445+ }
446+ }
447+ }
448+ }
449+ return nil
450+ },
430451 ErrorHandler : func (rw http.ResponseWriter , req * http.Request , err error ) {
431452 log .Printf ("Proxy error: %v" , err )
432453 rw .WriteHeader (http .StatusBadGateway )
0 commit comments