Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,10 @@ type TokenConfig struct {
// AuthCodeIssuer is the expected issuer claim in auth code JWTs.
// Defaults to Token.Issuer when empty.
AuthCodeIssuer string `koanf:"auth_code_issuer"`

// EncryptionKey is the AES-256 key for encrypting downstream OAuth tokens.
// Loaded from ZEROID_TOKEN_ENCRYPTION_KEY env var. Required for downstream token API.
EncryptionKey string `koanf:"encryption_key"`
}

// TelemetryConfig holds OpenTelemetry settings.
Expand Down Expand Up @@ -237,6 +241,7 @@ func loadEnvVars(k *koanf.Koanf) error {
"ZEROID_BASE_URL": "token.base_url",
"ZEROID_TOKEN_TTL_SECONDS": "token.default_ttl",
"ZEROID_MAX_TOKEN_TTL_SECONDS": "token.max_ttl",
"ZEROID_TOKEN_ENCRYPTION_KEY": "token.encryption_key",

// WIMSE
"ZEROID_WIMSE_DOMAIN": "wimse_domain",
Expand Down
39 changes: 39 additions & 0 deletions domain/downstream_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
package domain

import (
"encoding/json"
"time"

"github.com/uptrace/bun"
)

// DownstreamToken stores an encrypted OAuth token for accessing a third-party
// MCP server on behalf of a specific user. Firehog fetches these tokens at
// request time and injects them into downstream MCP requests.
type DownstreamToken struct {
bun.BaseModel `bun:"table:downstream_tokens,alias:dt"`

ID string `bun:"id,pk,type:uuid,default:gen_random_uuid()" json:"id"`
AccountID string `bun:"account_id,notnull" json:"account_id"`
ProjectID string `bun:"project_id,notnull" json:"project_id"`
UserID string `bun:"user_id,notnull" json:"user_id"`
ServerSlug string `bun:"server_slug,notnull" json:"server_slug"`
AccessToken string `bun:"access_token,notnull" json:"-"`
RefreshToken string `bun:"refresh_token" json:"-"`
TokenType string `bun:"token_type,notnull,default:'Bearer'" json:"token_type"`
Scopes string `bun:"scopes" json:"scopes"`
ExpiresAt *time.Time `bun:"expires_at" json:"expires_at,omitempty"`
OAuthConfig json.RawMessage `bun:"oauth_config,type:jsonb" json:"-"`
CreatedAt time.Time `bun:"created_at,nullzero,notnull,default:current_timestamp" json:"created_at"`
UpdatedAt time.Time `bun:"updated_at,nullzero,notnull,default:current_timestamp" json:"updated_at"`
}

// DownstreamTokenStatus is a safe view of a token (no secrets).
type DownstreamTokenStatus struct {
ServerSlug string `json:"server_slug"`
UserID string `json:"user_id"`
Connected bool `json:"connected"`
TokenType string `json:"token_type"`
Scopes string `json:"scopes"`
ConnectedAt string `json:"connected_at"`
}
207 changes: 207 additions & 0 deletions internal/handler/downstream_token.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,207 @@
package handler

import (
"context"
"encoding/json"
"net/http"

"github.com/danielgtaylor/huma/v2"
"github.com/highflame-ai/zeroid/domain"
"github.com/highflame-ai/zeroid/internal/middleware"
"github.com/highflame-ai/zeroid/internal/service"
"github.com/rs/zerolog/log"
)

// --- Input/Output types ---

type StoreDownstreamTokenInput struct {
ServerSlug string `path:"server_slug" doc:"MCP server slug"`
Body struct {
AccessToken string `json:"access_token" required:"true" doc:"Downstream access token"`
RefreshToken string `json:"refresh_token,omitempty" doc:"Downstream refresh token"`
TokenType string `json:"token_type,omitempty" doc:"Token type (default: Bearer)"`
Scopes string `json:"scopes,omitempty" doc:"Space-separated scopes"`
ExpiresIn *int `json:"expires_in,omitempty" doc:"Seconds until expiry"`
OAuthConfig json.RawMessage `json:"oauth_config,omitempty" doc:"OAuth provider config for refresh"`
}
}

type StoreDownstreamTokenOutput struct {
Body struct {
Message string `json:"message"`
}
}

type GetDownstreamTokenInput struct {
ServerSlug string `path:"server_slug" doc:"MCP server slug"`
}

type GetDownstreamTokenOutput struct {
Body *service.GetTokenResponse
}

type DeleteDownstreamTokenInput struct {
ServerSlug string `path:"server_slug" doc:"MCP server slug"`
}

type DeleteDownstreamTokenOutput struct {
Body struct {
Message string `json:"message"`
}
}

type ListDownstreamTokensOutput struct {
Body struct {
Tokens []domain.DownstreamTokenStatus `json:"tokens"`
}
}

// --- Route registration ---

func (a *API) registerDownstreamTokenRoutes(api huma.API) {
huma.Register(api, huma.Operation{
OperationID: "store-downstream-token",
Method: http.MethodPost,
Path: "/api/v1/downstream-tokens/{server_slug}",
Summary: "Store a downstream OAuth token for the current user",
Tags: []string{"Downstream Tokens"},
DefaultStatus: http.StatusCreated,
}, a.storeDownstreamTokenOp)

huma.Register(api, huma.Operation{
OperationID: "get-downstream-token",
Method: http.MethodGet,
Path: "/api/v1/downstream-tokens/{server_slug}",
Summary: "Get a decrypted downstream token (for firehog injection)",
Tags: []string{"Downstream Tokens"},
}, a.getDownstreamTokenOp)

huma.Register(api, huma.Operation{
OperationID: "delete-downstream-token",
Method: http.MethodDelete,
Path: "/api/v1/downstream-tokens/{server_slug}",
Summary: "Delete a downstream token (disconnect)",
Tags: []string{"Downstream Tokens"},
}, a.deleteDownstreamTokenOp)

huma.Register(api, huma.Operation{
OperationID: "list-downstream-tokens",
Method: http.MethodGet,
Path: "/api/v1/downstream-tokens",
Summary: "List connected downstream servers for the current user",
Tags: []string{"Downstream Tokens"},
}, a.listDownstreamTokensOp)
}

// --- Operations ---

func (a *API) checkDownstreamTokenSvc() error {
if a.downstreamTokenSvc == nil {
return huma.Error503ServiceUnavailable("downstream token service not configured (set ZEROID_TOKEN_ENCRYPTION_KEY)")
}
return nil
}

func (a *API) storeDownstreamTokenOp(ctx context.Context, input *StoreDownstreamTokenInput) (*StoreDownstreamTokenOutput, error) {
if err := a.checkDownstreamTokenSvc(); err != nil {
return nil, err
}
tenant, err := middleware.GetTenant(ctx)
if err != nil {
return nil, huma.Error401Unauthorized("missing tenant context")
}
userID := middleware.GetCallerName(ctx)
if userID == "" {
return nil, huma.Error401Unauthorized("missing user context")
}

err = a.downstreamTokenSvc.StoreToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug, &service.StoreTokenRequest{
AccessToken: input.Body.AccessToken,
RefreshToken: input.Body.RefreshToken,
TokenType: input.Body.TokenType,
Scopes: input.Body.Scopes,
ExpiresIn: input.Body.ExpiresIn,
OAuthConfig: input.Body.OAuthConfig,
})
if err != nil {
log.Error().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("failed to store downstream token")
return nil, huma.Error500InternalServerError("failed to store token")
}

out := &StoreDownstreamTokenOutput{}
out.Body.Message = "Token stored successfully"
return out, nil
}

func (a *API) getDownstreamTokenOp(ctx context.Context, input *GetDownstreamTokenInput) (*GetDownstreamTokenOutput, error) {
if err := a.checkDownstreamTokenSvc(); err != nil {
return nil, err
}
tenant, err := middleware.GetTenant(ctx)
if err != nil {
return nil, huma.Error401Unauthorized("missing tenant context")
}
userID := middleware.GetCallerName(ctx)
if userID == "" {
return nil, huma.Error401Unauthorized("missing user context")
}

resp, err := a.downstreamTokenSvc.GetToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug)
if err != nil {
log.Warn().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("downstream token not found")
return nil, huma.Error404NotFound("downstream token not found")
}

return &GetDownstreamTokenOutput{Body: resp}, nil
}

func (a *API) deleteDownstreamTokenOp(ctx context.Context, input *DeleteDownstreamTokenInput) (*DeleteDownstreamTokenOutput, error) {
if err := a.checkDownstreamTokenSvc(); err != nil {
return nil, err
}
tenant, err := middleware.GetTenant(ctx)
if err != nil {
return nil, huma.Error401Unauthorized("missing tenant context")
}
userID := middleware.GetCallerName(ctx)
if userID == "" {
return nil, huma.Error401Unauthorized("missing user context")
}

if err := a.downstreamTokenSvc.DeleteToken(ctx, tenant.AccountID, tenant.ProjectID, userID, input.ServerSlug); err != nil {
log.Error().Err(err).Str("server", input.ServerSlug).Str("user", userID).Msg("failed to delete downstream token")
return nil, huma.Error500InternalServerError("failed to delete token")
}

out := &DeleteDownstreamTokenOutput{}
out.Body.Message = "Token deleted successfully"
return out, nil
}

func (a *API) listDownstreamTokensOp(ctx context.Context, _ *struct{}) (*ListDownstreamTokensOutput, error) {
if err := a.checkDownstreamTokenSvc(); err != nil {
return nil, err
}
tenant, err := middleware.GetTenant(ctx)
if err != nil {
return nil, huma.Error401Unauthorized("missing tenant context")
}
userID := middleware.GetCallerName(ctx)
if userID == "" {
return nil, huma.Error401Unauthorized("missing user context")
}

statuses, err := a.downstreamTokenSvc.ListByUser(ctx, tenant.AccountID, tenant.ProjectID, userID)
if err != nil {
log.Error().Err(err).Str("user", userID).Msg("failed to list downstream tokens")
return nil, huma.Error500InternalServerError("failed to list tokens")
}

if statuses == nil {
statuses = []domain.DownstreamTokenStatus{}
}

out := &ListDownstreamTokensOutput{}
out.Body.Tokens = statuses
return out, nil
}
4 changes: 4 additions & 0 deletions internal/handler/routes.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ type API struct {
oauthClientSvc *service.OAuthClientService
signalSvc *service.SignalService
apiKeySvc *service.APIKeyService
downstreamTokenSvc *service.DownstreamTokenService
agentSvc *service.AgentService
jwksSvc *signing.JWKSService
db *bun.DB
Expand All @@ -48,6 +49,7 @@ func NewAPI(
oauthClientSvc *service.OAuthClientService,
signalSvc *service.SignalService,
apiKeySvc *service.APIKeyService,
downstreamTokenSvc *service.DownstreamTokenService,
agentSvc *service.AgentService,
jwksSvc *signing.JWKSService,
db *bun.DB,
Expand All @@ -63,6 +65,7 @@ func NewAPI(
oauthClientSvc: oauthClientSvc,
signalSvc: signalSvc,
apiKeySvc: apiKeySvc,
downstreamTokenSvc: downstreamTokenSvc,
agentSvc: agentSvc,
jwksSvc: jwksSvc,
db: db,
Expand Down Expand Up @@ -109,6 +112,7 @@ func (a *API) RegisterAdmin(api huma.API, router chi.Router) {
a.registerAttestationRoutes(api)
a.registerOAuthClientRoutes(api)
a.registerAPIKeyRoutes(api)
a.registerDownstreamTokenRoutes(api)
a.registerAgentRoutes(api)
a.registerSignalRoutes(api, router)
a.registerProofVerifyRoute(api)
Expand Down
Loading