Skip to content

Commit e4d621a

Browse files
Merge pull request #25 from ibuildthecloud/middleware
enhance: add middleware mode
2 parents c625a66 + 6d423a4 commit e4d621a

File tree

7 files changed

+103
-68
lines changed

7 files changed

+103
-68
lines changed

pkg/oauth/callback/callback.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package callback
22

33
import (
4+
"encoding/json"
45
"fmt"
56
"log"
67
"net/http"
@@ -198,6 +199,8 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
198199
if needsUserInfo {
199200
sensitiveProps["email"] = userInfo.Email
200201
sensitiveProps["name"] = userInfo.Name
202+
infoJSON, _ := json.Marshal(userInfo)
203+
sensitiveProps["info"] = string(infoJSON)
201204
}
202205

203206
// Initialize props map

pkg/oauth/validate/validatetoken.go

Lines changed: 38 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,16 @@ import (
1919
)
2020

2121
type TokenValidator struct {
22-
tokenManager *tokens.TokenManager
23-
encryptionKey []byte
24-
mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling
25-
db TokenStore // Database for refresh operations
26-
provider providers.Provider // OAuth provider for generating auth URLs
27-
clientID string // OAuth client ID
28-
clientSecret string // OAuth client secret
29-
scopesSupported []string // Supported OAuth scopes
30-
routePrefix string
22+
tokenManager *tokens.TokenManager
23+
encryptionKey []byte
24+
mcpUIManager *mcpui.Manager // Optional MCP UI manager for JWT handling
25+
db TokenStore // Database for refresh operations
26+
provider providers.Provider // OAuth provider for generating auth URLs
27+
clientID string // OAuth client ID
28+
clientSecret string // OAuth client secret
29+
scopesSupported []string // Supported OAuth scopes
30+
routePrefix string
31+
requiredAuthPaths []string
3132
}
3233

3334
// TokenStore interface for database operations needed by validator
@@ -39,17 +40,18 @@ type TokenStore interface {
3940
StoreAuthRequest(key string, data map[string]any) error
4041
}
4142

42-
func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string) *TokenValidator {
43+
func NewTokenValidator(tokenManager *tokens.TokenManager, mcpUIManager *mcpui.Manager, encryptionKey []byte, db TokenStore, provider providers.Provider, clientID, clientSecret string, scopesSupported []string, routePrefix string, requiredAuthPaths []string) *TokenValidator {
4344
return &TokenValidator{
44-
mcpUIManager: mcpUIManager,
45-
tokenManager: tokenManager,
46-
encryptionKey: encryptionKey,
47-
db: db,
48-
provider: provider,
49-
clientID: clientID,
50-
clientSecret: clientSecret,
51-
scopesSupported: scopesSupported,
52-
routePrefix: routePrefix,
45+
mcpUIManager: mcpUIManager,
46+
tokenManager: tokenManager,
47+
encryptionKey: encryptionKey,
48+
db: db,
49+
provider: provider,
50+
clientID: clientID,
51+
clientSecret: clientSecret,
52+
scopesSupported: scopesSupported,
53+
routePrefix: routePrefix,
54+
requiredAuthPaths: requiredAuthPaths,
5355
}
5456
}
5557

@@ -181,6 +183,21 @@ func (p *TokenValidator) setCookiesForRefresh(w http.ResponseWriter, r *http.Req
181183
func (p *TokenValidator) WithTokenValidation(next http.HandlerFunc) http.HandlerFunc {
182184
return func(w http.ResponseWriter, r *http.Request) {
183185
authHeader := r.Header.Get("Authorization")
186+
if authHeader == "" && len(p.requiredAuthPaths) > 0 {
187+
matches := false
188+
for _, path := range p.requiredAuthPaths {
189+
if strings.HasPrefix(r.URL.Path, path) {
190+
matches = true
191+
break
192+
}
193+
}
194+
if !matches {
195+
// Not a protected path, skip validation
196+
next.ServeHTTP(w, r)
197+
return
198+
}
199+
}
200+
184201
if authHeader == "" {
185202
// Try cookie-based authentication with refresh capability
186203
var bearerTokenFromCookie string
@@ -351,7 +368,8 @@ func (p *TokenValidator) handleOauthFlow(w http.ResponseWriter, r *http.Request)
351368
}
352369

353370
func GetTokenInfo(r *http.Request) *tokens.TokenInfo {
354-
return r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo)
371+
v, _ := r.Context().Value(tokenInfoKey{}).(*tokens.TokenInfo)
372+
return v
355373
}
356374

357375
type tokenInfoKey struct{}

pkg/providers/generic.go

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -175,24 +175,25 @@ func (p *GenericProvider) GetUserInfo(ctx context.Context, accessToken string) (
175175
return nil, fmt.Errorf("failed to decode user info response: %w", err)
176176
}
177177

178-
var userInfo *UserInfo
179-
if p.metadata.UserinfoEndpoint == "https://api.github.com/user" {
180-
userInfo = &UserInfo{
181-
ID: getString(userInfoResp, "login"),
182-
Email: getString(userInfoResp, "email"),
183-
Name: getString(userInfoResp, "name"),
184-
}
185-
} else {
186-
userInfo = &UserInfo{
187-
ID: getString(userInfoResp, "sub"),
188-
Email: getString(userInfoResp, "email"),
189-
Name: getString(userInfoResp, "name"),
190-
}
178+
userInfo := &UserInfo{
179+
ID: getString(userInfoResp, "id"),
180+
Sub: getString(userInfoResp, "sub"),
181+
Login: getString(userInfoResp, "login"),
182+
Email: getString(userInfoResp, "email"),
183+
EmailVerified: getBool(userInfoResp, "email_verified"),
184+
Name: getString(userInfoResp, "name"),
185+
Picture: getString(userInfoResp, "picture"),
186+
GivenName: getString(userInfoResp, "given_name"),
187+
FamilyName: getString(userInfoResp, "family_name"),
188+
Locale: getString(userInfoResp, "locale"),
191189
}
192190

193-
// If sub is not available, try other common ID fields
194191
if userInfo.ID == "" {
195-
userInfo.ID = getString(userInfoResp, "id")
192+
userInfo.ID = userInfo.Sub
193+
}
194+
195+
if userInfo.ID == "" && p.metadata.UserinfoEndpoint == "https://api.github.com/user" {
196+
userInfo.ID = userInfo.Login
196197
}
197198

198199
return userInfo, nil
@@ -231,12 +232,13 @@ func (p *GenericProvider) GetName() string {
231232
return "generic"
232233
}
233234

235+
func getBool(m map[string]any, key string) bool {
236+
b, _ := m[key].(bool)
237+
return b
238+
}
239+
234240
// Helper functions
235241
func getString(m map[string]any, key string) string {
236-
if val, ok := m[key]; ok {
237-
if str, ok := val.(string); ok {
238-
return str
239-
}
240-
}
241-
return ""
242+
str, _ := m[key].(string)
243+
return str
242244
}

pkg/providers/provider.go

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,16 @@ import (
88

99
// UserInfo represents user information from OAuth provider
1010
type UserInfo struct {
11-
ID string `json:"id"`
12-
Email string `json:"email"`
13-
Name string `json:"name"`
11+
ID string `json:"id"`
12+
Sub string `json:"sub"`
13+
Login string `json:"login"`
14+
Email string `json:"email"`
15+
EmailVerified bool `json:"email_verified"`
16+
Name string `json:"name"`
17+
GivenName string `json:"given_name"`
18+
FamilyName string `json:"family_name"`
19+
Picture string `json:"picture"`
20+
Locale string `json:"locale"`
1421
}
1522

1623
// TokenInfo represents token information from OAuth provider

pkg/proxy/proxy.go

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@ import (
55
"encoding/base64"
66
"fmt"
77
"log"
8+
"maps"
89
"net/http"
910
"net/http/httputil"
1011
"net/url"
1112
"os"
1213
"strconv"
1314
"strings"
14-
"sync"
1515
"time"
1616

1717
"github.com/gorilla/handlers"
@@ -43,7 +43,6 @@ type OAuthProxy struct {
4343
provider string
4444
encryptionKey []byte
4545
resourceName string
46-
lock sync.Mutex
4746
config *types.Config
4847

4948
ctx context.Context
@@ -53,6 +52,7 @@ type OAuthProxy struct {
5352
const (
5453
ModeProxy = "proxy"
5554
ModeForwardAuth = "forward_auth"
55+
Middleware = "middleware"
5656
)
5757

5858
func NewOAuthProxy(config *types.Config) (*OAuthProxy, error) {
@@ -206,7 +206,7 @@ func (p *OAuthProxy) Start(ctx context.Context) error {
206206
return nil
207207
}
208208

209-
func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
209+
func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux, next http.Handler) {
210210
provider, err := p.providers.GetProvider(p.provider)
211211
if err != nil {
212212
log.Fatalf("Failed to get provider: %v", err)
@@ -216,7 +216,7 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
216216
tokenHandler := token.NewHandler(p.db)
217217
callbackHandler := callback.NewHandler(p.db, provider, p.encryptionKey, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.config.RoutePrefix, p.mcpUIManager)
218218
revokeHandler := revoke.NewHandler(p.db)
219-
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix)
219+
tokenValidator := validate.NewTokenValidator(p.tokenManager, p.mcpUIManager, p.encryptionKey, p.db, provider, p.GetOAuthClientID(), p.GetOAuthClientSecret(), p.metadata.ScopesSupported, p.config.RoutePrefix, p.config.RequiredAuthPaths)
220220
successHandler := success.NewHandler()
221221

222222
// Get route prefix from config
@@ -239,13 +239,15 @@ func (p *OAuthProxy) SetupRoutes(mux *http.ServeMux) {
239239
mux.HandleFunc("GET "+prefix+"/auth/mcp-ui/success", p.withCORS(p.withRateLimit(successHandler)))
240240

241241
// Protect everything else
242-
mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(p.mcpProxyHandler))))
242+
mux.HandleFunc(prefix+"/{path...}", p.withCORS(p.withRateLimit(tokenValidator.WithTokenValidation(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
243+
p.mcpProxyHandler(w, r, next)
244+
})))))
243245
}
244246

245247
// GetHandler returns an http.Handler for the OAuth proxy
246248
func (p *OAuthProxy) GetHandler() http.Handler {
247249
mux := http.NewServeMux()
248-
p.SetupRoutes(mux)
250+
p.SetupRoutes(mux, nil)
249251

250252
// Wrap with logging middleware
251253
loggedHandler := handlers.LoggingHandler(os.Stdout, mux)
@@ -335,20 +337,17 @@ func (p *OAuthProxy) protectedResourceMetadataHandler(w http.ResponseWriter, r *
335337
handlerutils.JSON(w, http.StatusOK, metadata)
336338
}
337339

338-
func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
340+
func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request, next http.Handler) {
339341
tokenInfo := validate.GetTokenInfo(r)
340342
path := r.PathValue("path")
341343

342344
// Check if the access token is expired and refresh if needed
343-
if tokenInfo.Props != nil {
345+
if tokenInfo != nil && tokenInfo.Props != nil {
344346
if _, ok := tokenInfo.Props["access_token"].(string); ok {
345347
// Check if token is expired (with a 5-minute buffer)
346348
expiresAt, ok := tokenInfo.Props["expires_at"].(float64)
347349
if ok && expiresAt > 0 {
348350
if time.Now().Add(5 * time.Minute).After(time.Unix(int64(expiresAt), 0)) {
349-
// when refreshing token, we need to lock the database to avoid race conditions
350-
// otherwise we could get save the old access token into the database when another refresh process is running
351-
p.lock.Lock()
352351
log.Printf("Access token is expired or will expire soon, attempting to refresh")
353352

354353
// Get the refresh token
@@ -359,7 +358,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
359358
"error": "invalid_token",
360359
"error_description": "Access token expired and no refresh token available",
361360
})
362-
p.lock.Unlock()
363361
return
364362
}
365363

@@ -371,7 +369,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
371369
"error": "server_error",
372370
"error_description": "Failed to refresh token",
373371
})
374-
p.lock.Unlock()
375372
return
376373
}
377374

@@ -384,7 +381,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
384381
"error": "server_error",
385382
"error_description": "OAuth credentials not configured",
386383
})
387-
p.lock.Unlock()
388384
return
389385
}
390386

@@ -396,7 +392,6 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
396392
"error": "invalid_token",
397393
"error_description": "Failed to refresh access token",
398394
})
399-
p.lock.Unlock()
400395
return
401396
}
402397

@@ -407,21 +402,20 @@ func (p *OAuthProxy) mcpProxyHandler(w http.ResponseWriter, r *http.Request) {
407402
"error": "server_error",
408403
"error_description": "Failed to update grant with new token",
409404
})
410-
p.lock.Unlock()
411405
return
412406
}
413407

414408
// Update the token info with the new access token for the current request
415409
tokenInfo.Props["access_token"] = newTokenInfo.AccessToken
416-
p.lock.Unlock()
417-
418410
log.Printf("Successfully refreshed access token")
419411
}
420412
}
421413
}
422414
}
423415

424416
switch p.config.Mode {
417+
case Middleware:
418+
next.ServeHTTP(w, r)
425419
case ModeForwardAuth:
426420
setHeaders(w.Header(), tokenInfo.Props)
427421
case ModeProxy:
@@ -508,13 +502,17 @@ func (p *OAuthProxy) updateGrant(grantID, userID string, oldTokenInfo *tokens.To
508502
return fmt.Errorf("failed to get grant: %w", err)
509503
}
510504

511-
// Prepare sensitive props data
512-
sensitiveProps := map[string]any{
513-
"access_token": newTokenInfo.AccessToken,
514-
"refresh_token": newTokenInfo.RefreshToken,
515-
"expires_at": newTokenInfo.Expiry.Unix(),
505+
sensitiveProps := map[string]any{}
506+
if oldTokenInfo.Props != nil {
507+
// keep all the old props, that include a lot of the user info
508+
maps.Copy(sensitiveProps, oldTokenInfo.Props)
516509
}
517510

511+
// Prepare sensitive props data
512+
sensitiveProps["access_token"] = newTokenInfo.AccessToken
513+
sensitiveProps["refresh_token"] = newTokenInfo.RefreshToken
514+
sensitiveProps["expires_at"] = newTokenInfo.Expiry.Unix()
515+
518516
// Add existing user info if available
519517
if grant.Props != nil {
520518
if email, ok := grant.Props["email"].(string); ok {

pkg/ratelimit/ratelimiter.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
package ratelimit
22

3-
import "time"
3+
import (
4+
"sync"
5+
"time"
6+
)
47

58
// RateLimiter simple in-memory rate limiter
69
type RateLimiter struct {
710
requests map[string][]time.Time
11+
lock sync.Mutex
812
window time.Duration
913
max int
1014
}
@@ -18,6 +22,8 @@ func NewRateLimiter(window time.Duration, max int) *RateLimiter {
1822
}
1923

2024
func (rl *RateLimiter) Allow(key string) bool {
25+
rl.lock.Lock()
26+
defer rl.lock.Unlock()
2127
now := time.Now()
2228
windowStart := now.Add(-rl.window)
2329

pkg/types/types.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ type Config struct {
1616
MCPServerURL string
1717
Mode string
1818
RoutePrefix string
19+
RequiredAuthPaths []string
1920
}
2021

2122
// TokenData represents stored token data for OAuth 2.1 compliance

0 commit comments

Comments
 (0)