Skip to content

Commit b7fa030

Browse files
committed
chore: switch to oauth2 package for token handling
Signed-off-by: Donnie Adams <donnie@acorn.io>
1 parent 0f634ac commit b7fa030

File tree

18 files changed

+134
-319
lines changed

18 files changed

+134
-319
lines changed

go.mod

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ require (
2424
github.com/pmezard/go-difflib v1.0.0 // indirect
2525
github.com/rogpeppe/go-internal v1.8.0 // indirect
2626
golang.org/x/crypto v0.41.0 // indirect
27+
golang.org/x/oauth2 v0.30.0 // indirect
2728
golang.org/x/sync v0.16.0 // indirect
2829
golang.org/x/text v0.28.0 // indirect
2930
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOf
3737
github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
3838
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
3939
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
40+
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
41+
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
4042
golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw=
4143
golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA=
4244
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=

pkg/db/db.go

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ func (d *Store) StoreClient(client *types.ClientInfo) error {
105105

106106
// StoreGrant stores a new grant
107107
func (d *Store) StoreGrant(grant *types.Grant) error {
108-
// Convert []string to StringSlice and map[string]interface{} to JSON for GORM
108+
// Convert []string to StringSlice and map[string]any to JSON for GORM
109109
gormGrant := &types.Grant{
110110
ID: grant.ID,
111111
ClientID: grant.ClientID,
@@ -157,8 +157,8 @@ func (d *Store) GetGrant(grantID, userID string) (*types.Grant, error) {
157157
ClientID: grant.ClientID,
158158
UserID: grant.UserID,
159159
Scope: []string(grant.Scope),
160-
Metadata: map[string]interface{}(grant.Metadata),
161-
Props: map[string]interface{}(grant.Props),
160+
Metadata: map[string]any(grant.Metadata),
161+
Props: map[string]any(grant.Props),
162162
CreatedAt: grant.CreatedAt,
163163
ExpiresAt: grant.ExpiresAt,
164164
CodeChallenge: grant.CodeChallenge,
@@ -275,7 +275,7 @@ func (d *Store) RevokeToken(token string) error {
275275
now := time.Now()
276276

277277
// First try to revoke as access token
278-
result := d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedToken).Updates(map[string]interface{}{
278+
result := d.db.Model(&types.TokenData{}).Where("access_token = ?", hashedToken).Updates(map[string]any{
279279
"revoked": true,
280280
"revoked_at": &now,
281281
})
@@ -288,7 +288,7 @@ func (d *Store) RevokeToken(token string) error {
288288
}
289289

290290
// If not found as access token, try as refresh token
291-
result = d.db.Model(&types.TokenData{}).Where("refresh_token = ?", hashedToken).Updates(map[string]interface{}{
291+
result = d.db.Model(&types.TokenData{}).Where("refresh_token = ?", hashedToken).Updates(map[string]any{
292292
"revoked": true,
293293
"revoked_at": &now,
294294
})
@@ -344,7 +344,7 @@ func (d *Store) CleanupExpiredTokens() error {
344344
}
345345

346346
// StoreAuthRequest stores an authorization request with a 15-minute TTL
347-
func (d *Store) StoreAuthRequest(key string, data map[string]interface{}) error {
347+
func (d *Store) StoreAuthRequest(key string, data map[string]any) error {
348348
authRequest := &types.StoredAuthRequest{
349349
Key: key,
350350
Data: types.JSON(data),
@@ -354,15 +354,15 @@ func (d *Store) StoreAuthRequest(key string, data map[string]interface{}) error
354354
}
355355

356356
// GetAuthRequest retrieves an authorization request by key and checks TTL
357-
func (d *Store) GetAuthRequest(key string) (map[string]interface{}, error) {
357+
func (d *Store) GetAuthRequest(key string) (map[string]any, error) {
358358
var authRequest types.StoredAuthRequest
359359
err := d.db.First(&authRequest, "key = ? AND expires_at > ?", key, time.Now()).Error
360360
if err != nil {
361361
return nil, err
362362
}
363363

364364
// Convert JSON back to map
365-
return map[string]interface{}(authRequest.Data), nil
365+
return map[string]any(authRequest.Data), nil
366366
}
367367

368368
// DeleteAuthRequest deletes an authorization request by key

pkg/db/db_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ func testGrantOperations(t *testing.T, db *Store) {
128128
ClientID: clientID,
129129
UserID: userID,
130130
Scope: []string{"read", "write", "admin"},
131-
Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"},
132-
Props: map[string]interface{}{"email": "test@example.com", "name": "Test User"},
131+
Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"},
132+
Props: map[string]any{"email": "test@example.com", "name": "Test User"},
133133
CreatedAt: time.Now().Unix(),
134134
ExpiresAt: time.Now().Add(10 * time.Minute).Unix(),
135135
CodeChallenge: "test_challenge",
@@ -174,7 +174,7 @@ func testTokenOperations(t *testing.T, db *Store) {
174174
ClientID: "test_client_db",
175175
UserID: "test_user_123",
176176
Scope: []string{"read", "write", "admin"},
177-
Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"},
177+
Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"},
178178
}
179179

180180
err = db.StoreGrant(grant)
@@ -262,7 +262,7 @@ func testAuthCodeOperations(t *testing.T, db *Store) {
262262
ClientID: "test_client_db",
263263
UserID: userID,
264264
Scope: []string{"read", "write", "admin"},
265-
Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"},
265+
Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"},
266266
}
267267

268268
err = db.StoreGrant(grant)
@@ -298,7 +298,7 @@ func testCleanupOperations(t *testing.T, db *Store) {
298298
ClientID: "test_client_db",
299299
UserID: userID,
300300
Scope: []string{"read", "write", "admin"},
301-
Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"},
301+
Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"},
302302
}
303303

304304
err = db.StoreGrant(grant)
@@ -382,7 +382,7 @@ func testRefreshTokenExpiration(t *testing.T, db *Store) {
382382
ClientID: clientID,
383383
UserID: "test_user_123",
384384
Scope: []string{"read", "write", "admin"},
385-
Metadata: map[string]interface{}{"provider": "test", "ip": "127.0.0.1"},
385+
Metadata: map[string]any{"provider": "test", "ip": "127.0.0.1"},
386386
ExpiresAt: time.Now().Add(1 * time.Hour).Unix(),
387387
}
388388

pkg/db/sqlite_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ func TestSQLiteDatabase(t *testing.T) {
6363
ClientID: "test_client_sqlite",
6464
UserID: "test_user",
6565
Scope: []string{"openid", "profile"},
66-
Metadata: map[string]interface{}{"test": "value"},
67-
Props: map[string]interface{}{"prop": "value"},
66+
Metadata: map[string]any{"test": "value"},
67+
Props: map[string]any{"prop": "value"},
6868
CreatedAt: time.Now().Unix(),
6969
ExpiresAt: time.Now().Add(time.Hour).Unix(),
7070
}

pkg/encryption/encryption.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ type EncryptedProps struct {
1717
}
1818

1919
// EncryptData encrypts sensitive data using AES-256-GCM
20-
func EncryptData(data map[string]interface{}, encryptionKey []byte) (*EncryptedProps, error) {
20+
func EncryptData(data map[string]any, encryptionKey []byte) (*EncryptedProps, error) {
2121
// Convert data to JSON
2222
jsonData, err := json.Marshal(data)
2323
if err != nil {
@@ -53,7 +53,7 @@ func EncryptData(data map[string]interface{}, encryptionKey []byte) (*EncryptedP
5353
}
5454

5555
// DecryptData decrypts encrypted data using AES-256-GCM
56-
func DecryptData(encryptedProps *EncryptedProps, encryptionKey []byte) (map[string]interface{}, error) {
56+
func DecryptData(encryptedProps *EncryptedProps, encryptionKey []byte) (map[string]any, error) {
5757
// Decode base64 data
5858
ciphertext, err := base64.StdEncoding.DecodeString(encryptedProps.Data)
5959
if err != nil {
@@ -84,7 +84,7 @@ func DecryptData(encryptedProps *EncryptedProps, encryptionKey []byte) (map[stri
8484
}
8585

8686
// Unmarshal JSON data
87-
var data map[string]interface{}
87+
var data map[string]any
8888
if err := json.Unmarshal(plaintext, &data); err != nil {
8989
return nil, fmt.Errorf("failed to unmarshal decrypted data: %w", err)
9090
}
@@ -93,7 +93,7 @@ func DecryptData(encryptedProps *EncryptedProps, encryptionKey []byte) (map[stri
9393
}
9494

9595
// DecryptPropsIfNeeded decrypts props data if it's encrypted, otherwise returns the original data
96-
func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]interface{}) (map[string]interface{}, error) {
96+
func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]any) (map[string]any, error) {
9797
// Check if data is encrypted
9898
encrypted, ok := props["encrypted"].(bool)
9999
if !ok || !encrypted {
@@ -131,7 +131,7 @@ func DecryptPropsIfNeeded(encryptionKey []byte, props map[string]interface{}) (m
131131
}
132132

133133
// Merge decrypted data with non-sensitive props
134-
result := make(map[string]interface{})
134+
result := make(map[string]any)
135135
for key, value := range props {
136136
if key != "encrypted_data" && key != "iv" && key != "algorithm" && key != "encrypted" {
137137
result[key] = value

pkg/handlerutils/handlerutils.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ import (
88
"strings"
99
)
1010

11-
func JSON(w http.ResponseWriter, statusCode int, obj interface{}) {
11+
func JSON(w http.ResponseWriter, statusCode int, obj any) {
1212
w.Header().Set("Content-Type", "application/json")
1313
w.WriteHeader(statusCode)
1414
if obj != nil {

pkg/oauth/authorize/authorize.go

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"fmt"
55
"net/http"
66
"net/url"
7+
"slices"
78
"strings"
89

910
"github.com/obot-platform/mcp-oauth-proxy/pkg/encryption"
@@ -14,7 +15,7 @@ import (
1415

1516
type AuthorizationStore interface {
1617
GetClient(clientID string) (*types.ClientInfo, error)
17-
StoreAuthRequest(key string, data map[string]interface{}) error
18+
StoreAuthRequest(key string, data map[string]any) error
1819
}
1920

2021
type Handler struct {
@@ -94,15 +95,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
9495
return
9596
}
9697

97-
// Validate redirect URI
98-
validRedirect := false
99-
for _, uri := range clientInfo.RedirectUris {
100-
if uri == authReq.RedirectURI {
101-
validRedirect = true
102-
break
103-
}
104-
}
105-
if !validRedirect {
98+
if !slices.Contains(clientInfo.RedirectUris, authReq.RedirectURI) {
10699
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
107100
Error: "invalid_request",
108101
ErrorDescription: "Invalid redirect URI",
@@ -123,7 +116,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
123116
stateKey := encryption.GenerateRandomString(32)
124117

125118
// Store the auth request data in the database
126-
authData := map[string]interface{}{
119+
authData := map[string]any{
127120
"response_type": authReq.ResponseType,
128121
"client_id": authReq.ClientID,
129122
"redirect_uri": authReq.RedirectURI,

pkg/oauth/callback/callback.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
type Store interface {
1818
StoreGrant(grant *types.Grant) error
1919
StoreAuthCode(code, grantID, userID string) error
20-
GetAuthRequest(key string) (map[string]interface{}, error)
20+
GetAuthRequest(key string) (map[string]any, error)
2121
DeleteAuthRequest(key string) error
2222
}
2323

@@ -49,8 +49,8 @@ func (p *Handler) scopeContainsProfileOrEmail(scopes []string) bool {
4949
return false
5050
}
5151

52-
// getStringFromMap safely extracts a string value from a map[string]interface{}
53-
func getStringFromMap(data map[string]interface{}, key string) string {
52+
// getStringFromMap safely extracts a string value from a map[string]any
53+
func getStringFromMap(data map[string]any, key string) string {
5454
if value, ok := data[key]; ok {
5555
if str, ok := value.(string); ok {
5656
return str
@@ -149,10 +149,10 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
149149
now := time.Now().Unix()
150150

151151
// Prepare sensitive props data
152-
sensitiveProps := map[string]interface{}{
152+
sensitiveProps := map[string]any{
153153
"access_token": tokenInfo.AccessToken,
154154
"refresh_token": tokenInfo.RefreshToken,
155-
"expires_at": tokenInfo.ExpireAt,
155+
"expires_at": tokenInfo.Expiry.Unix(),
156156
}
157157

158158
// Only add user info if we have it
@@ -162,7 +162,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
162162
}
163163

164164
// Initialize props map
165-
props := make(map[string]interface{})
165+
props := make(map[string]any)
166166

167167
// Encrypt the sensitive props data
168168
encryptedProps, err := encryption.EncryptData(sensitiveProps, p.encryptionKey)
@@ -186,7 +186,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
186186
ClientID: authReq.ClientID,
187187
UserID: userInfo.ID,
188188
Scope: scopes,
189-
Metadata: map[string]interface{}{
189+
Metadata: map[string]any{
190190
"provider": p.provider,
191191
},
192192
Props: props,

pkg/oauth/register/register.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
4949
}
5050

5151
// Parse request.JSON body
52-
var clientMetadata map[string]interface{}
52+
var clientMetadata map[string]any
5353
if err := json.NewDecoder(io.LimitReader(r.Body, 1024*1024)).Decode(&clientMetadata); err != nil {
5454
handlerutils.JSON(w, http.StatusBadRequest, types.OAuthError{
5555
Error: "invalid_request",
@@ -112,7 +112,7 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112112

113113
// Build response
114114
baseURL := handlerutils.GetBaseURL(r)
115-
response := map[string]interface{}{
115+
response := map[string]any{
116116
"client_id": clientInfo.ClientID,
117117
"redirect_uris": clientInfo.RedirectUris,
118118
"client_name": clientInfo.ClientName,
@@ -138,9 +138,9 @@ func (p *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
138138
handlerutils.JSON(w, http.StatusCreated, response)
139139
}
140140

141-
func (p *Handler) validateClientMetadata(metadata map[string]interface{}) (*types.ClientInfo, error) {
141+
func (p *Handler) validateClientMetadata(metadata map[string]any) (*types.ClientInfo, error) {
142142
// Helper function to validate string fields
143-
validateStringField := func(field interface{}, name string) (string, error) {
143+
validateStringField := func(field any, name string) (string, error) {
144144
if field == nil {
145145
return "", nil
146146
}
@@ -151,11 +151,11 @@ func (p *Handler) validateClientMetadata(metadata map[string]interface{}) (*type
151151
}
152152

153153
// Helper function to validate string arrays
154-
validateStringArray := func(arr interface{}, name string) ([]string, error) {
154+
validateStringArray := func(arr any, name string) ([]string, error) {
155155
if arr == nil {
156156
return nil, nil
157157
}
158-
if array, ok := arr.([]interface{}); ok {
158+
if array, ok := arr.([]any); ok {
159159
result := make([]string, len(array))
160160
for i, item := range array {
161161
if str, ok := item.(string); ok {

0 commit comments

Comments
 (0)