Skip to content

Commit 0f634ac

Browse files
Merge pull request #12 from obot-platform/refactoring
Enhance: refactoring code to address security
2 parents 4f00635 + 88dc411 commit 0f634ac

File tree

10 files changed

+250
-82
lines changed

10 files changed

+250
-82
lines changed

main.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ import (
88
)
99

1010
func main() {
11-
proxy, err := proxy.NewOAuthProxy()
11+
// Load configuration from environment variables
12+
config := proxy.LoadConfigFromEnv()
13+
14+
proxy, err := proxy.NewOAuthProxy(config)
1215
if err != nil {
1316
log.Fatalf("Failed to create OAuth proxy: %v", err)
1417
}

main_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,8 @@ func TestIntegrationFlow(t *testing.T) {
6161
}()
6262

6363
// Create OAuth proxy
64-
oauthProxy, err := proxy.NewOAuthProxy()
64+
config := proxy.LoadConfigFromEnv()
65+
oauthProxy, err := proxy.NewOAuthProxy(config)
6566
if err != nil {
6667
t.Skipf("Skipping test due to database connection error: %v", err)
6768
}
@@ -164,7 +165,8 @@ func TestOAuthProxyCreation(t *testing.T) {
164165
}()
165166

166167
// Create OAuth proxy
167-
oauthProxy, err := proxy.NewOAuthProxy()
168+
config := proxy.LoadConfigFromEnv()
169+
oauthProxy, err := proxy.NewOAuthProxy(config)
168170
require.NoError(t, err, "Should be able to create OAuth proxy with valid environment")
169171
require.NotNil(t, oauthProxy, "OAuth proxy should not be nil")
170172

@@ -213,7 +215,8 @@ func TestOAuthProxyStart(t *testing.T) {
213215
}()
214216

215217
// Create OAuth proxy
216-
oauthProxy, err := proxy.NewOAuthProxy()
218+
config := proxy.LoadConfigFromEnv()
219+
oauthProxy, err := proxy.NewOAuthProxy(config)
217220
require.NoError(t, err)
218221
defer func() {
219222
_ = oauthProxy.Close()

pkg/db/db.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ func (d *Store) setupSchema() error {
7979
&types.Grant{},
8080
&types.AuthorizationCode{},
8181
&types.TokenData{},
82+
&types.StoredAuthRequest{},
8283
)
8384
if err != nil {
8485
return fmt.Errorf("failed to auto-migrate database schema: %w", err)
@@ -334,6 +335,50 @@ func (d *Store) CleanupExpiredTokens() error {
334335
fmt.Printf("Deleted %d expired grants\n", result.RowsAffected)
335336
}
336337

338+
// Delete expired auth requests
339+
if err := d.CleanupExpiredAuthRequests(); err != nil {
340+
return fmt.Errorf("failed to cleanup expired auth requests: %w", err)
341+
}
342+
343+
return nil
344+
}
345+
346+
// StoreAuthRequest stores an authorization request with a 15-minute TTL
347+
func (d *Store) StoreAuthRequest(key string, data map[string]interface{}) error {
348+
authRequest := &types.StoredAuthRequest{
349+
Key: key,
350+
Data: types.JSON(data),
351+
ExpiresAt: time.Now().Add(15 * time.Minute), // 15-minute TTL
352+
}
353+
return d.db.Create(authRequest).Error
354+
}
355+
356+
// GetAuthRequest retrieves an authorization request by key and checks TTL
357+
func (d *Store) GetAuthRequest(key string) (map[string]interface{}, error) {
358+
var authRequest types.StoredAuthRequest
359+
err := d.db.First(&authRequest, "key = ? AND expires_at > ?", key, time.Now()).Error
360+
if err != nil {
361+
return nil, err
362+
}
363+
364+
// Convert JSON back to map
365+
return map[string]interface{}(authRequest.Data), nil
366+
}
367+
368+
// DeleteAuthRequest deletes an authorization request by key
369+
func (d *Store) DeleteAuthRequest(key string) error {
370+
return d.db.Delete(&types.StoredAuthRequest{}, "key = ?", key).Error
371+
}
372+
373+
// CleanupExpiredAuthRequests removes expired authorization requests
374+
func (d *Store) CleanupExpiredAuthRequests() error {
375+
result := d.db.Where("expires_at < ?", time.Now()).Delete(&types.StoredAuthRequest{})
376+
if result.Error != nil {
377+
return fmt.Errorf("failed to cleanup expired auth requests: %w", result.Error)
378+
}
379+
if result.RowsAffected > 0 {
380+
fmt.Printf("Deleted %d expired auth requests\n", result.RowsAffected)
381+
}
337382
return nil
338383
}
339384

pkg/encryption/random.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,5 +12,5 @@ func GenerateRandomString(length int) string {
1212
if _, err := rand.Read(bytes); err != nil {
1313
panic(fmt.Errorf("failed to generate random string: %w", err))
1414
}
15-
return base64.URLEncoding.EncodeToString(bytes)
15+
return base64.RawURLEncoding.EncodeToString(bytes)
1616
}

pkg/oauth/authorize/authorize.go

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,37 @@
11
package authorize
22

33
import (
4-
"encoding/base64"
5-
"encoding/json"
64
"fmt"
75
"net/http"
86
"net/url"
9-
"os"
107
"strings"
118

9+
"github.com/obot-platform/mcp-oauth-proxy/pkg/encryption"
1210
"github.com/obot-platform/mcp-oauth-proxy/pkg/handlerutils"
1311
"github.com/obot-platform/mcp-oauth-proxy/pkg/providers"
1412
"github.com/obot-platform/mcp-oauth-proxy/pkg/types"
1513
)
1614

1715
type AuthorizationStore interface {
1816
GetClient(clientID string) (*types.ClientInfo, error)
17+
StoreAuthRequest(key string, data map[string]interface{}) error
1918
}
2019

2120
type Handler struct {
2221
db AuthorizationStore
2322
provider providers.Provider
2423
scopesSupported []string
24+
clientID string
25+
clientSecret string
2526
}
2627

27-
func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string) http.Handler {
28+
func NewHandler(db AuthorizationStore, provider providers.Provider, scopesSupported []string, clientID, clientSecret string) http.Handler {
2829
return &Handler{
2930
db: db,
3031
provider: provider,
3132
scopesSupported: scopesSupported,
33+
clientID: clientID,
34+
clientSecret: clientSecret,
3235
}
3336
}
3437

@@ -107,37 +110,45 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
107110
return
108111
}
109112

110-
// Get the provider's client ID and secret
111-
clientID := os.Getenv("OAUTH_CLIENT_ID")
112-
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
113-
114113
// Check if provider is configured
115-
if clientID == "" || clientSecret == "" {
114+
if p.clientID == "" || p.clientSecret == "" {
116115
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
117116
Error: "invalid_request",
118117
ErrorDescription: "OAuth provider not configured",
119118
})
120119
return
121120
}
122121

123-
stateData, err := json.Marshal(authReq)
124-
if err != nil {
122+
// Generate a random state key
123+
stateKey := encryption.GenerateRandomString(32)
124+
125+
// Store the auth request data in the database
126+
authData := map[string]interface{}{
127+
"response_type": authReq.ResponseType,
128+
"client_id": authReq.ClientID,
129+
"redirect_uri": authReq.RedirectURI,
130+
"scope": authReq.Scope,
131+
"state": authReq.State,
132+
"code_challenge": authReq.CodeChallenge,
133+
"code_challenge_method": authReq.CodeChallengeMethod,
134+
}
135+
136+
if err := p.db.StoreAuthRequest(stateKey, authData); err != nil {
125137
handlerutils.JSON(w, http.StatusInternalServerError, types.OAuthError{
126138
Error: "server_error",
127-
ErrorDescription: "Failed to marshal state data",
139+
ErrorDescription: "Failed to store authorization request",
128140
})
129141
return
130142
}
131143

132-
encodedState := base64.URLEncoding.EncodeToString(stateData)
133144
redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r))
134145

135146
// Generate authorization URL with the provider
136147
authURL := p.provider.GetAuthorizationURL(
137-
clientID,
148+
p.clientID,
138149
redirectURI,
139150
authReq.Scope,
140-
encodedState,
151+
stateKey,
141152
)
142153

143154
// Redirect to the provider's authorization URL

pkg/oauth/callback/callback.go

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,10 @@
11
package callback
22

33
import (
4-
"encoding/base64"
5-
"encoding/json"
64
"fmt"
75
"log"
86
"net/http"
97
"net/url"
10-
"os"
118
"strings"
129
"time"
1310

@@ -20,22 +17,48 @@ import (
2017
type Store interface {
2118
StoreGrant(grant *types.Grant) error
2219
StoreAuthCode(code, grantID, userID string) error
20+
GetAuthRequest(key string) (map[string]interface{}, error)
21+
DeleteAuthRequest(key string) error
2322
}
2423

2524
type Handler struct {
2625
db Store
2726
provider providers.Provider
2827
encryptionKey []byte
28+
clientID string
29+
clientSecret string
2930
}
3031

31-
func NewHandler(db Store, provider providers.Provider, encryptionKey []byte) http.Handler {
32+
func NewHandler(db Store, provider providers.Provider, encryptionKey []byte, clientID, clientSecret string) http.Handler {
3233
return &Handler{
3334
db: db,
3435
provider: provider,
3536
encryptionKey: encryptionKey,
37+
clientID: clientID,
38+
clientSecret: clientSecret,
3639
}
3740
}
3841

42+
// scopeContainsProfileOrEmail checks if the given scopes contain profile or email
43+
func (p *Handler) scopeContainsProfileOrEmail(scopes []string) bool {
44+
for _, scope := range scopes {
45+
if scope == "profile" || scope == "email" {
46+
return true
47+
}
48+
}
49+
return false
50+
}
51+
52+
// getStringFromMap safely extracts a string value from a map[string]interface{}
53+
func getStringFromMap(data map[string]interface{}, key string) string {
54+
if value, ok := data[key]; ok {
55+
if str, ok := value.(string); ok {
56+
return str
57+
}
58+
}
59+
return ""
60+
}
61+
3962
func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4063
// Handle OAuth callback from external providers
4164
code := r.URL.Query().Get("code")
@@ -61,30 +84,39 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
6184
return
6285
}
6386

64-
var authReq types.AuthRequest
65-
stateData, err := base64.URLEncoding.DecodeString(state)
87+
// Retrieve auth request data from database using state as key
88+
authData, err := p.db.GetAuthRequest(state)
6689
if err != nil {
6790
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
6891
Error: "invalid_request",
69-
ErrorDescription: "Invalid state parameter",
92+
ErrorDescription: "Invalid or expired state parameter",
7093
})
7194
return
7295
}
73-
if err := json.Unmarshal(stateData, &authReq); err != nil {
74-
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
75-
Error: "invalid_request",
76-
ErrorDescription: "Invalid state parameter",
77-
})
78-
return
96+
97+
// Convert auth data back to AuthRequest struct
98+
authReq := types.AuthRequest{
99+
ResponseType: getStringFromMap(authData, "response_type"),
100+
ClientID: getStringFromMap(authData, "client_id"),
101+
RedirectURI: getStringFromMap(authData, "redirect_uri"),
102+
Scope: getStringFromMap(authData, "scope"),
103+
State: getStringFromMap(authData, "state"),
104+
CodeChallenge: getStringFromMap(authData, "code_challenge"),
105+
CodeChallengeMethod: getStringFromMap(authData, "code_challenge_method"),
79106
}
80107

108+
// Clean up the auth request data after successful retrieval
109+
defer func() {
110+
if err := p.db.DeleteAuthRequest(state); err != nil {
111+
log.Printf("Failed to delete auth request: %v", err)
112+
}
113+
}()
114+
81115
// Get provider credentials
82-
clientID := os.Getenv("OAUTH_CLIENT_ID")
83-
clientSecret := os.Getenv("OAUTH_CLIENT_SECRET")
84116
redirectURI := fmt.Sprintf("%s/callback", handlerutils.GetBaseURL(r))
85117

86118
// Exchange code for tokens
87-
tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, clientID, clientSecret, redirectURI)
119+
tokenInfo, err := p.provider.ExchangeCodeForToken(r.Context(), code, p.clientID, p.clientSecret, redirectURI)
88120
if err != nil {
89121
log.Printf("Failed to exchange code for token: %v", err)
90122
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
@@ -94,15 +126,22 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
94126
return
95127
}
96128

97-
// Get user info from the provider
98-
userInfo, err := p.provider.GetUserInfo(r.Context(), tokenInfo.AccessToken)
99-
if err != nil {
100-
log.Printf("Failed to get user info: %v", err)
101-
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
102-
Error: "invalid_grant",
103-
ErrorDescription: "Failed to get user information",
104-
})
105-
return
129+
// Check if scope includes profile or email before getting user info
130+
scopes := strings.Fields(authReq.Scope)
131+
needsUserInfo := p.scopeContainsProfileOrEmail(scopes)
132+
133+
userInfo := &providers.UserInfo{}
134+
if needsUserInfo {
135+
// Get user info from the provider
136+
userInfo, err = p.provider.GetUserInfo(r.Context(), tokenInfo.AccessToken)
137+
if err != nil {
138+
log.Printf("Failed to get user info: %v", err)
139+
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
140+
Error: "invalid_grant",
141+
ErrorDescription: "Failed to get user information",
142+
})
143+
return
144+
}
106145
}
107146

108147
// Create a grant for this user
@@ -111,13 +150,17 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
111150

112151
// Prepare sensitive props data
113152
sensitiveProps := map[string]interface{}{
114-
"email": userInfo.Email,
115-
"name": userInfo.Name,
116153
"access_token": tokenInfo.AccessToken,
117154
"refresh_token": tokenInfo.RefreshToken,
118155
"expires_at": tokenInfo.ExpireAt,
119156
}
120157

158+
// Only add user info if we have it
159+
if needsUserInfo {
160+
sensitiveProps["email"] = userInfo.Email
161+
sensitiveProps["name"] = userInfo.Name
162+
}
163+
121164
// Initialize props map
122165
props := make(map[string]interface{})
123166

@@ -138,17 +181,13 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
138181
props["algorithm"] = encryptedProps.Algorithm
139182
props["encrypted"] = true
140183

141-
// Add non-sensitive data
142-
props["user_id"] = userInfo.ID
143-
144184
grant := &types.Grant{
145185
ID: grantID,
146186
ClientID: authReq.ClientID,
147187
UserID: userInfo.ID,
148-
Scope: strings.Fields(authReq.Scope),
188+
Scope: scopes,
149189
Metadata: map[string]interface{}{
150190
"provider": p.provider,
151-
"label": userInfo.Name,
152191
},
153192
Props: props,
154193
CreatedAt: now,

0 commit comments

Comments
 (0)