diff --git a/config.go b/config.go index 366d201..5be0e02 100644 --- a/config.go +++ b/config.go @@ -80,6 +80,10 @@ type TokenConfig struct { // AuthCodeIssuer is the expected issuer claim in auth code JWTs. // Defaults to Token.Issuer when empty. AuthCodeIssuer string `koanf:"auth_code_issuer"` + + // EncryptionKey is the AES-256 key for encrypting downstream OAuth tokens. + // Loaded from ZEROID_TOKEN_ENCRYPTION_KEY env var. Required for downstream token API. + EncryptionKey string `koanf:"encryption_key"` } // TelemetryConfig holds OpenTelemetry settings. @@ -237,6 +241,7 @@ func loadEnvVars(k *koanf.Koanf) error { "ZEROID_BASE_URL": "token.base_url", "ZEROID_TOKEN_TTL_SECONDS": "token.default_ttl", "ZEROID_MAX_TOKEN_TTL_SECONDS": "token.max_ttl", + "ZEROID_TOKEN_ENCRYPTION_KEY": "token.encryption_key", // WIMSE "ZEROID_WIMSE_DOMAIN": "wimse_domain", diff --git a/domain/downstream_token.go b/domain/downstream_token.go new file mode 100644 index 0000000..d9a0eab --- /dev/null +++ b/domain/downstream_token.go @@ -0,0 +1,39 @@ +package domain + +import ( + "encoding/json" + "time" + + "github.com/uptrace/bun" +) + +// DownstreamToken stores an encrypted OAuth token for accessing a third-party +// MCP server on behalf of a specific user. Firehog fetches these tokens at +// request time and injects them into downstream MCP requests. +type DownstreamToken struct { + bun.BaseModel `bun:"table:downstream_tokens,alias:dt"` + + ID string `bun:"id,pk,type:uuid,default:gen_random_uuid()" json:"id"` + AccountID string `bun:"account_id,notnull" json:"account_id"` + ProjectID string `bun:"project_id,notnull" json:"project_id"` + UserID string `bun:"user_id,notnull" json:"user_id"` + ServerSlug string `bun:"server_slug,notnull" json:"server_slug"` + AccessToken string `bun:"access_token,notnull" json:"-"` + RefreshToken string `bun:"refresh_token" json:"-"` + TokenType string `bun:"token_type,notnull,default:'Bearer'" json:"token_type"` + Scopes string `bun:"scopes" json:"scopes"` + ExpiresAt *time.Time `bun:"expires_at" json:"expires_at,omitempty"` + OAuthConfig json.RawMessage `bun:"oauth_config,type:jsonb" json:"-"` + CreatedAt time.Time `bun:"created_at,nullzero,notnull,default:current_timestamp" json:"created_at"` + UpdatedAt time.Time `bun:"updated_at,nullzero,notnull,default:current_timestamp" json:"updated_at"` +} + +// DownstreamTokenStatus is a safe view of a token (no secrets). +type DownstreamTokenStatus struct { + ServerSlug string `json:"server_slug"` + UserID string `json:"user_id"` + Connected bool `json:"connected"` + TokenType string `json:"token_type"` + Scopes string `json:"scopes"` + ConnectedAt string `json:"connected_at"` +} diff --git a/internal/handler/downstream_token.go b/internal/handler/downstream_token.go new file mode 100644 index 0000000..4a08d93 --- /dev/null +++ b/internal/handler/downstream_token.go @@ -0,0 +1,207 @@ +package handler + +import ( + "context" + "encoding/json" + "net/http" + + "github.com/danielgtaylor/huma/v2" + "github.com/highflame-ai/zeroid/domain" + "github.com/highflame-ai/zeroid/internal/middleware" + "github.com/highflame-ai/zeroid/internal/service" + "github.com/rs/zerolog/log" +) + +// --- Input/Output types --- + +type StoreDownstreamTokenInput struct { + ServerSlug string `path:"server_slug" doc:"MCP server slug"` + Body struct { + AccessToken string `json:"access_token" required:"true" doc:"Downstream access token"` + RefreshToken string `json:"refresh_token,omitempty" doc:"Downstream refresh token"` + TokenType string `json:"token_type,omitempty" doc:"Token type (default: Bearer)"` + Scopes string `json:"scopes,omitempty" doc:"Space-separated scopes"` + ExpiresIn *int `json:"expires_in,omitempty" doc:"Seconds until expiry"` + OAuthConfig json.RawMessage `json:"oauth_config,omitempty" doc:"OAuth provider config for refresh"` + } +} + +type StoreDownstreamTokenOutput struct { + Body struct { + Message string `json:"message"` + } +} + +type GetDownstreamTokenInput struct { + ServerSlug string `path:"server_slug" doc:"MCP server slug"` +} + +type GetDownstreamTokenOutput struct { + Body *service.GetTokenResponse +} + +type DeleteDownstreamTokenInput struct { + ServerSlug string `path:"server_slug" doc:"MCP server slug"` +} + +type DeleteDownstreamTokenOutput struct { + Body struct { + Message string `json:"message"` + } +} + +type ListDownstreamTokensOutput struct { + Body struct { + Tokens []domain.DownstreamTokenStatus `json:"tokens"` + } +} + +// --- Route registration --- + +func (a *API) registerDownstreamTokenRoutes(api huma.API) { + huma.Register(api, huma.Operation{ + OperationID: "store-downstream-token", + Method: http.MethodPost, + Path: "/api/v1/downstream-tokens/{server_slug}", + Summary: "Store a downstream OAuth token for the current user", + Tags: []string{"Downstream Tokens"}, + DefaultStatus: http.StatusCreated, + }, a.storeDownstreamTokenOp) + + huma.Register(api, huma.Operation{ + OperationID: "get-downstream-token", + Method: http.MethodGet, + Path: "/api/v1/downstream-tokens/{server_slug}", + Summary: "Get a decrypted downstream token (for firehog injection)", + Tags: []string{"Downstream Tokens"}, + }, a.getDownstreamTokenOp) + + huma.Register(api, huma.Operation{ + OperationID: "delete-downstream-token", + Method: http.MethodDelete, + Path: "/api/v1/downstream-tokens/{server_slug}", + Summary: "Delete a downstream token (disconnect)", + Tags: []string{"Downstream Tokens"}, + }, a.deleteDownstreamTokenOp) + + huma.Register(api, huma.Operation{ + OperationID: "list-downstream-tokens", + Method: http.MethodGet, + Path: "/api/v1/downstream-tokens", + Summary: "List connected downstream servers for the current user", + Tags: []string{"Downstream Tokens"}, + }, a.listDownstreamTokensOp) +} + +// --- Operations --- + +func (a *API) checkDownstreamTokenSvc() error { + if a.downstreamTokenSvc == nil { + return huma.Error503ServiceUnavailable("downstream token service not configured (set ZEROID_TOKEN_ENCRYPTION_KEY)") + } + return nil +} + +func (a *API) storeDownstreamTokenOp(ctx context.Context, input *StoreDownstreamTokenInput) (*StoreDownstreamTokenOutput, error) { + if err := a.checkDownstreamTokenSvc(); err != nil { + return nil, err + } + tenant, err := middleware.GetTenant(ctx) + if err != nil { + return nil, huma.Error401Unauthorized("missing tenant context") + } + userID := middleware.GetCallerName(ctx) + if userID == "" { + return nil, huma.Error401Unauthorized("missing user context") + } + + err = a.downstreamTokenSvc.StoreToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug, &service.StoreTokenRequest{ + AccessToken: input.Body.AccessToken, + RefreshToken: input.Body.RefreshToken, + TokenType: input.Body.TokenType, + Scopes: input.Body.Scopes, + ExpiresIn: input.Body.ExpiresIn, + OAuthConfig: input.Body.OAuthConfig, + }) + if err != nil { + log.Error().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("failed to store downstream token") + return nil, huma.Error500InternalServerError("failed to store token") + } + + out := &StoreDownstreamTokenOutput{} + out.Body.Message = "Token stored successfully" + return out, nil +} + +func (a *API) getDownstreamTokenOp(ctx context.Context, input *GetDownstreamTokenInput) (*GetDownstreamTokenOutput, error) { + if err := a.checkDownstreamTokenSvc(); err != nil { + return nil, err + } + tenant, err := middleware.GetTenant(ctx) + if err != nil { + return nil, huma.Error401Unauthorized("missing tenant context") + } + userID := middleware.GetCallerName(ctx) + if userID == "" { + return nil, huma.Error401Unauthorized("missing user context") + } + + resp, err := a.downstreamTokenSvc.GetToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug) + if err != nil { + log.Warn().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("downstream token not found") + return nil, huma.Error404NotFound("downstream token not found") + } + + return &GetDownstreamTokenOutput{Body: resp}, nil +} + +func (a *API) deleteDownstreamTokenOp(ctx context.Context, input *DeleteDownstreamTokenInput) (*DeleteDownstreamTokenOutput, error) { + if err := a.checkDownstreamTokenSvc(); err != nil { + return nil, err + } + tenant, err := middleware.GetTenant(ctx) + if err != nil { + return nil, huma.Error401Unauthorized("missing tenant context") + } + userID := middleware.GetCallerName(ctx) + if userID == "" { + return nil, huma.Error401Unauthorized("missing user context") + } + + if err := a.downstreamTokenSvc.DeleteToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug); err != nil { + log.Error().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("failed to delete downstream token") + return nil, huma.Error500InternalServerError("failed to delete token") + } + + out := &DeleteDownstreamTokenOutput{} + out.Body.Message = "Token deleted successfully" + return out, nil +} + +func (a *API) listDownstreamTokensOp(ctx context.Context, _ *struct{}) (*ListDownstreamTokensOutput, error) { + if err := a.checkDownstreamTokenSvc(); err != nil { + return nil, err + } + tenant, err := middleware.GetTenant(ctx) + if err != nil { + return nil, huma.Error401Unauthorized("missing tenant context") + } + userID := middleware.GetCallerName(ctx) + if userID == "" { + return nil, huma.Error401Unauthorized("missing user context") + } + + statuses, err := a.downstreamTokenSvc.ListByUser(ctx, tenant.AccountID, tenant.ProjectID, userID) + if err != nil { + log.Error().Err(err).Str("user", userID).Msg("failed to list downstream tokens") + return nil, huma.Error500InternalServerError("failed to list tokens") + } + + if statuses == nil { + statuses = []domain.DownstreamTokenStatus{} + } + + out := &ListDownstreamTokensOutput{} + out.Body.Tokens = statuses + return out, nil +} diff --git a/internal/handler/routes.go b/internal/handler/routes.go index c0fb45d..32abb6d 100644 --- a/internal/handler/routes.go +++ b/internal/handler/routes.go @@ -29,6 +29,7 @@ type API struct { oauthClientSvc *service.OAuthClientService signalSvc *service.SignalService apiKeySvc *service.APIKeyService + downstreamTokenSvc *service.DownstreamTokenService agentSvc *service.AgentService jwksSvc *signing.JWKSService db *bun.DB @@ -48,6 +49,7 @@ func NewAPI( oauthClientSvc *service.OAuthClientService, signalSvc *service.SignalService, apiKeySvc *service.APIKeyService, + downstreamTokenSvc *service.DownstreamTokenService, agentSvc *service.AgentService, jwksSvc *signing.JWKSService, db *bun.DB, @@ -63,6 +65,7 @@ func NewAPI( oauthClientSvc: oauthClientSvc, signalSvc: signalSvc, apiKeySvc: apiKeySvc, + downstreamTokenSvc: downstreamTokenSvc, agentSvc: agentSvc, jwksSvc: jwksSvc, db: db, @@ -109,6 +112,7 @@ func (a *API) RegisterAdmin(api huma.API, router chi.Router) { a.registerAttestationRoutes(api) a.registerOAuthClientRoutes(api) a.registerAPIKeyRoutes(api) + a.registerDownstreamTokenRoutes(api) a.registerAgentRoutes(api) a.registerSignalRoutes(api, router) a.registerProofVerifyRoute(api) diff --git a/internal/service/downstream_token.go b/internal/service/downstream_token.go new file mode 100644 index 0000000..286bee7 --- /dev/null +++ b/internal/service/downstream_token.go @@ -0,0 +1,283 @@ +package service + +import ( + "context" + "crypto/aes" + "crypto/cipher" + "crypto/rand" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" + + "github.com/highflame-ai/zeroid/domain" + "github.com/highflame-ai/zeroid/internal/store/postgres" + "github.com/rs/zerolog/log" +) + +type DownstreamTokenService struct { + repo *postgres.DownstreamTokenRepository + encryptionKey []byte +} + +func NewDownstreamTokenService(repo *postgres.DownstreamTokenRepository, encryptionKey []byte) *DownstreamTokenService { + return &DownstreamTokenService{ + repo: repo, + encryptionKey: encryptionKey, + } +} + +// StoreTokenRequest is the input for storing a downstream token. +type StoreTokenRequest struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token,omitempty"` + TokenType string `json:"token_type"` + Scopes string `json:"scopes,omitempty"` + ExpiresIn *int `json:"expires_in,omitempty"` + OAuthConfig json.RawMessage `json:"oauth_config,omitempty"` +} + +// GetTokenResponse is the decrypted token returned to firehog. +type GetTokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` +} + +// StoreToken encrypts and persists a downstream token. +func (s *DownstreamTokenService) StoreToken( + ctx context.Context, + accountID, projectID, userID, serverSlug string, + req *StoreTokenRequest, +) error { + encAccess, err := encryptGCM(req.AccessToken, s.encryptionKey) + if err != nil { + return fmt.Errorf("failed to encrypt access token: %w", err) + } + + var encRefresh string + if req.RefreshToken != "" { + encRefresh, err = encryptGCM(req.RefreshToken, s.encryptionKey) + if err != nil { + return fmt.Errorf("failed to encrypt refresh token: %w", err) + } + } + + var expiresAt *time.Time + if req.ExpiresIn != nil && *req.ExpiresIn > 0 { + t := time.Now().Add(time.Duration(*req.ExpiresIn) * time.Second) + expiresAt = &t + } + + tokenType := req.TokenType + if tokenType == "" { + tokenType = "Bearer" + } + + token := &domain.DownstreamToken{ + AccountID: accountID, + ProjectID: projectID, + UserID: userID, + ServerSlug: serverSlug, + AccessToken: encAccess, + RefreshToken: encRefresh, + TokenType: tokenType, + Scopes: req.Scopes, + ExpiresAt: expiresAt, + OAuthConfig: req.OAuthConfig, + UpdatedAt: time.Now(), + } + + return s.repo.Upsert(ctx, token) +} + +// GetToken retrieves and decrypts a downstream token. Auto-refreshes if expired. +func (s *DownstreamTokenService) GetToken( + ctx context.Context, + accountID, projectID, userID, serverSlug string, +) (*GetTokenResponse, error) { + token, err := s.repo.Get(ctx, accountID, projectID, userID, serverSlug) + if err != nil { + return nil, err + } + + // Auto-refresh if expired and refresh token exists. + // Note: concurrent refresh is mitigated by firehog's 5-minute token cache — + // only one request per 5 minutes reaches here per user+server pair. + if isExpired(token) && token.RefreshToken != "" { + if refreshErr := s.tryRefresh(ctx, token); refreshErr != nil { + log.Warn().Err(refreshErr). + Str("server", serverSlug). + Str("user", userID). + Msg("token refresh failed") + } + } + + accessToken, err := decryptGCM(token.AccessToken, s.encryptionKey) + if err != nil { + return nil, fmt.Errorf("failed to decrypt access token: %w", err) + } + + return &GetTokenResponse{ + AccessToken: accessToken, + TokenType: token.TokenType, + }, nil +} + +// DeleteToken removes a downstream token. +func (s *DownstreamTokenService) DeleteToken( + ctx context.Context, + accountID, projectID, userID, serverSlug string, +) error { + return s.repo.Delete(ctx, accountID, projectID, userID, serverSlug) +} + +// ListByUser returns token statuses (no secrets) for a user. +func (s *DownstreamTokenService) ListByUser( + ctx context.Context, + accountID, projectID, userID string, +) ([]domain.DownstreamTokenStatus, error) { + tokens, err := s.repo.ListByUser(ctx, accountID, projectID, userID) + if err != nil { + return nil, err + } + + statuses := make([]domain.DownstreamTokenStatus, len(tokens)) + for i, t := range tokens { + statuses[i] = domain.DownstreamTokenStatus{ + ServerSlug: t.ServerSlug, + UserID: t.UserID, + Connected: true, + TokenType: t.TokenType, + Scopes: t.Scopes, + ConnectedAt: t.CreatedAt.Format(time.RFC3339), + } + } + return statuses, nil +} + +func isExpired(token *domain.DownstreamToken) bool { + if token.ExpiresAt == nil { + return false + } + return time.Now().After(token.ExpiresAt.Add(-60 * time.Second)) +} + +func (s *DownstreamTokenService) tryRefresh(ctx context.Context, token *domain.DownstreamToken) error { + if len(token.OAuthConfig) == 0 { + return fmt.Errorf("no oauth config for refresh") + } + + var oauthCfg struct { + TokenURL string `json:"token_url"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + } + if err := json.Unmarshal(token.OAuthConfig, &oauthCfg); err != nil { + return fmt.Errorf("failed to parse oauth config: %w", err) + } + if oauthCfg.TokenURL == "" { + return fmt.Errorf("missing token_url in oauth config") + } + + refreshToken, err := decryptGCM(token.RefreshToken, s.encryptionKey) + if err != nil { + return fmt.Errorf("failed to decrypt refresh token: %w", err) + } + + // Exchange refresh token + data := url.Values{ + "grant_type": {"refresh_token"}, + "refresh_token": {refreshToken}, + "client_id": {oauthCfg.ClientID}, + "client_secret": {oauthCfg.ClientSecret}, + } + + httpClient := &http.Client{Timeout: 10 * time.Second} + resp, err := httpClient.PostForm(oauthCfg.TokenURL, data) + if err != nil { + return fmt.Errorf("refresh request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("refresh returned status %d", resp.StatusCode) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + ExpiresIn int `json:"expires_in"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return fmt.Errorf("failed to parse refresh response: %w", err) + } + + // Encrypt new tokens + encAccess, err := encryptGCM(tokenResp.AccessToken, s.encryptionKey) + if err != nil { + return err + } + token.AccessToken = encAccess + + if tokenResp.RefreshToken != "" { + encRefresh, err := encryptGCM(tokenResp.RefreshToken, s.encryptionKey) + if err != nil { + return err + } + token.RefreshToken = encRefresh + } + + if tokenResp.ExpiresIn > 0 { + t := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second) + token.ExpiresAt = &t + } + + return s.repo.Update(ctx, token) +} + +// AES-256-GCM encryption (compatible with admin's EncryptStringGCM) +func encryptGCM(plaintext string, key []byte) (string, error) { + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonce := make([]byte, aesGCM.NonceSize()) + if _, err = io.ReadFull(rand.Reader, nonce); err != nil { + return "", err + } + ciphertext := aesGCM.Seal(nonce, nonce, []byte(plaintext), nil) + return base64.URLEncoding.EncodeToString(ciphertext), nil +} + +func decryptGCM(encrypted string, key []byte) (string, error) { + enc, err := base64.URLEncoding.DecodeString(encrypted) + if err != nil { + return "", err + } + block, err := aes.NewCipher(key) + if err != nil { + return "", err + } + aesGCM, err := cipher.NewGCM(block) + if err != nil { + return "", err + } + nonceSize := aesGCM.NonceSize() + if len(enc) < nonceSize { + return "", fmt.Errorf("ciphertext too short") + } + nonce, ciphertext := enc[:nonceSize], enc[nonceSize:] + plaintext, err := aesGCM.Open(nil, nonce, ciphertext, nil) + if err != nil { + return "", err + } + return string(plaintext), nil +} + diff --git a/internal/store/postgres/downstream_token.go b/internal/store/postgres/downstream_token.go new file mode 100644 index 0000000..c7c948f --- /dev/null +++ b/internal/store/postgres/downstream_token.go @@ -0,0 +1,105 @@ +package postgres + +import ( + "context" + "fmt" + "time" + + "github.com/highflame-ai/zeroid/domain" + "github.com/uptrace/bun" +) + +type DownstreamTokenRepository struct { + db *bun.DB +} + +func NewDownstreamTokenRepository(db *bun.DB) *DownstreamTokenRepository { + return &DownstreamTokenRepository{db: db} +} + +// Upsert stores or updates a downstream token for a user+server pair. +func (r *DownstreamTokenRepository) Upsert(ctx context.Context, token *domain.DownstreamToken) error { + _, err := r.db.NewInsert(). + Model(token). + On("CONFLICT (user_id, server_slug, account_id, project_id) DO UPDATE"). + Set("access_token = EXCLUDED.access_token"). + Set("refresh_token = EXCLUDED.refresh_token"). + Set("token_type = EXCLUDED.token_type"). + Set("scopes = EXCLUDED.scopes"). + Set("expires_at = EXCLUDED.expires_at"). + Set("oauth_config = EXCLUDED.oauth_config"). + Set("updated_at = EXCLUDED.updated_at"). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to upsert downstream token: %w", err) + } + return nil +} + +// Get retrieves a downstream token by user+server within a tenant. +func (r *DownstreamTokenRepository) Get(ctx context.Context, accountID, projectID, userID, serverSlug string) (*domain.DownstreamToken, error) { + token := new(domain.DownstreamToken) + err := r.db.NewSelect(). + Model(token). + Where("account_id = ?", accountID). + Where("project_id = ?", projectID). + Where("user_id = ?", userID). + Where("server_slug = ?", serverSlug). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("downstream token not found: %w", err) + } + return token, nil +} + +// Delete removes a downstream token. +func (r *DownstreamTokenRepository) Delete(ctx context.Context, accountID, projectID, userID, serverSlug string) error { + result, err := r.db.NewDelete(). + Model((*domain.DownstreamToken)(nil)). + Where("account_id = ?", accountID). + Where("project_id = ?", projectID). + Where("user_id = ?", userID). + Where("server_slug = ?", serverSlug). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to delete downstream token: %w", err) + } + rows, _ := result.RowsAffected() + if rows == 0 { + return fmt.Errorf("downstream token not found") + } + return nil +} + +// ListByUser returns all downstream tokens for a user (no secrets). +func (r *DownstreamTokenRepository) ListByUser(ctx context.Context, accountID, projectID, userID string) ([]*domain.DownstreamToken, error) { + var tokens []*domain.DownstreamToken + err := r.db.NewSelect(). + Model(&tokens). + Column("id", "server_slug", "user_id", "token_type", "scopes", "created_at", "updated_at"). + Where("account_id = ?", accountID). + Where("project_id = ?", projectID). + Where("user_id = ?", userID). + OrderExpr("created_at DESC"). + Scan(ctx) + if err != nil { + return nil, fmt.Errorf("failed to list downstream tokens: %w", err) + } + return tokens, nil +} + +// Update updates the access/refresh token and expiry (for token refresh). +func (r *DownstreamTokenRepository) Update(ctx context.Context, token *domain.DownstreamToken) error { + _, err := r.db.NewUpdate(). + Model(token). + Set("access_token = ?", token.AccessToken). + Set("refresh_token = ?", token.RefreshToken). + Set("expires_at = ?", token.ExpiresAt). + Set("updated_at = ?", time.Now()). + Where("id = ?", token.ID). + Exec(ctx) + if err != nil { + return fmt.Errorf("failed to update downstream token: %w", err) + } + return nil +} diff --git a/migrations/007_downstream_tokens.down.sql b/migrations/007_downstream_tokens.down.sql new file mode 100644 index 0000000..d3d9a96 --- /dev/null +++ b/migrations/007_downstream_tokens.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS downstream_tokens; diff --git a/migrations/007_downstream_tokens.up.sql b/migrations/007_downstream_tokens.up.sql new file mode 100644 index 0000000..07661d8 --- /dev/null +++ b/migrations/007_downstream_tokens.up.sql @@ -0,0 +1,21 @@ +CREATE TABLE IF NOT EXISTS downstream_tokens ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + account_id VARCHAR(255) NOT NULL, + project_id VARCHAR(255) NOT NULL, + user_id VARCHAR(255) NOT NULL, + server_slug VARCHAR(255) NOT NULL, + access_token TEXT NOT NULL, + refresh_token TEXT DEFAULT '', + token_type VARCHAR(50) NOT NULL DEFAULT 'Bearer', + scopes TEXT DEFAULT '', + expires_at TIMESTAMPTZ, + oauth_config JSONB DEFAULT '{}', + created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS idx_downstream_tokens_unique + ON downstream_tokens (user_id, server_slug, account_id, project_id); + +CREATE INDEX IF NOT EXISTS idx_downstream_tokens_lookup + ON downstream_tokens (account_id, project_id, user_id); diff --git a/server.go b/server.go index e422305..4ce6213 100644 --- a/server.go +++ b/server.go @@ -2,6 +2,7 @@ package zeroid import ( "context" + "crypto/sha256" "database/sql" "fmt" "net/http" @@ -65,6 +66,7 @@ type Server struct { oauthClientSvc *service.OAuthClientService signalSvc *service.SignalService apiKeySvc *service.APIKeyService + downstreamTokenSvc *service.DownstreamTokenService agentSvc *service.AgentService jwksSvc *signing.JWKSService refreshTokenSvc *service.RefreshTokenService @@ -151,6 +153,23 @@ func NewServer(cfg Config) (*Server, error) { attestationSvc := service.NewAttestationService(attestationRepo, credentialSvc, identitySvc) oauthClientSvc := service.NewOAuthClientService(oauthClientRepo) apiKeySvc := service.NewAPIKeyService(apiKeyRepo, credentialPolicySvc, identitySvc) + + // Downstream token service (OAuth broker for third-party MCP servers) + downstreamTokenRepo := postgres.NewDownstreamTokenRepository(db) + var downstreamTokenSvc *service.DownstreamTokenService + if cfg.Token.EncryptionKey != "" { + encKey := []byte(cfg.Token.EncryptionKey) + // AES requires 16, 24, or 32 byte key — normalize via SHA-256 + if len(encKey) != 16 && len(encKey) != 24 && len(encKey) != 32 { + h := sha256.Sum256(encKey) + encKey = h[:32] + } + downstreamTokenSvc = service.NewDownstreamTokenService(downstreamTokenRepo, encKey) + log.Info().Msg("Downstream token service initialized") + } else { + log.Warn().Msg("ZEROID_TOKEN_ENCRYPTION_KEY not set — downstream token API disabled") + } + agentSvc := service.NewAgentService(identitySvc, apiKeySvc, apiKeyRepo) refreshTokenSvc := service.NewRefreshTokenService(refreshTokenRepo, db) authCodeIssuer := cfg.Token.AuthCodeIssuer @@ -170,7 +189,7 @@ func NewServer(cfg Config) (*Server, error) { apiHandler := handler.NewAPI( identitySvc, credentialSvc, credentialPolicySvc, attestationSvc, proofSvc, oauthSvc, oauthClientSvc, - signalSvc, apiKeySvc, agentSvc, jwksSvc, db, + signalSvc, apiKeySvc, downstreamTokenSvc, agentSvc, jwksSvc, db, cfg.Token.Issuer, cfg.Token.BaseURL, ) @@ -266,6 +285,7 @@ func NewServer(cfg Config) (*Server, error) { oauthClientSvc: oauthClientSvc, signalSvc: signalSvc, apiKeySvc: apiKeySvc, + downstreamTokenSvc: downstreamTokenSvc, agentSvc: agentSvc, jwksSvc: jwksSvc, refreshTokenSvc: refreshTokenSvc,