diff --git a/.gitignore b/.gitignore index ef2d935a..9e730c98 100644 --- a/.gitignore +++ b/.gitignore @@ -15,6 +15,7 @@ pgstore/* gitstore/* objectstore/* static/* +refs/* # Authentication data auths/* @@ -30,3 +31,7 @@ GEMINI.md .vscode/* .claude/* .serena/* + +# macOS +.DS_Store +._* diff --git a/internal/api/server.go b/internal/api/server.go index ab9c0354..b3c3bd26 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -921,6 +921,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { geminiAPIKeyCount := len(cfg.GeminiKey) claudeAPIKeyCount := len(cfg.ClaudeKey) codexAPIKeyCount := len(cfg.CodexKey) + vertexAICompatCount := len(cfg.VertexCompatAPIKey) openAICompatCount := 0 for i := range cfg.OpenAICompatibility { entry := cfg.OpenAICompatibility[i] @@ -931,13 +932,14 @@ func (s *Server) UpdateClients(cfg *config.Config) { openAICompatCount += len(entry.APIKeys) } - total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n", + total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount + fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n", total, authFiles, geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, + vertexAICompatCount, openAICompatCount, ) } diff --git a/internal/config/config.go b/internal/config/config.go index 97b5a0c2..473a0553 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -64,6 +64,10 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers. + // Used for services that use Vertex AI-style paths but with simple API key authentication. + VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. @@ -325,6 +329,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Gemini API key configuration and migrate legacy entries. cfg.SanitizeGeminiKeys() + // Sanitize Vertex-compatible API keys: drop entries without base-url + cfg.SanitizeVertexCompatKeys() + // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() @@ -813,6 +820,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { switch key { case "generative-language-api-key", "gemini-api-key", + "vertex-api-key", "claude-api-key", "codex-api-key", "openai-compatibility": diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go new file mode 100644 index 00000000..a8d94ccb --- /dev/null +++ b/internal/config/vertex_compat.go @@ -0,0 +1,84 @@ +package config + +import "strings" + +// VertexCompatKey represents the configuration for Vertex AI-compatible API keys. +// This supports third-party services that use Vertex AI-style endpoint paths +// (/publishers/google/models/{model}:streamGenerateContent) but authenticate +// with simple API keys instead of Google Cloud service account credentials. +// +// Example services: zenmux.ai and similar Vertex-compatible providers. +type VertexCompatKey struct { + // APIKey is the authentication key for accessing the Vertex-compatible API. + // Maps to the x-goog-api-key header. + APIKey string `yaml:"api-key" json:"api-key"` + + // BaseURL is the base URL for the Vertex-compatible API endpoint. + // The executor will append "/v1/publishers/google/models/{model}:action" to this. + // Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." + BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` + + // ProxyURL optionally overrides the global proxy for this API key. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // Headers optionally adds extra HTTP headers for requests sent with this key. + // Commonly used for cookies, user-agent, and other authentication headers. + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // Models defines the model configurations including aliases for routing. + Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"` +} + +// VertexCompatModel represents a model configuration for Vertex compatibility, +// including the actual model name and its alias for API routing. +type VertexCompatModel struct { + // Name is the actual model name used by the external provider. + Name string `yaml:"name" json:"name"` + + // Alias is the model name alias that clients will use to reference this model. + Alias string `yaml:"alias" json:"alias"` +} + +// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. +func (cfg *Config) SanitizeVertexCompatKeys() { + if cfg == nil { + return + } + + seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey)) + out := cfg.VertexCompatAPIKey[:0] + for i := range cfg.VertexCompatAPIKey { + entry := cfg.VertexCompatAPIKey[i] + entry.APIKey = strings.TrimSpace(entry.APIKey) + if entry.APIKey == "" { + continue + } + entry.BaseURL = strings.TrimSpace(entry.BaseURL) + if entry.BaseURL == "" { + // BaseURL is required for vertex-compat keys + continue + } + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = NormalizeHeaders(entry.Headers) + + // Sanitize models: remove entries without valid alias + sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models)) + for _, model := range entry.Models { + model.Alias = strings.TrimSpace(model.Alias) + model.Name = strings.TrimSpace(model.Name) + if model.Alias != "" && model.Name != "" { + sanitizedModels = append(sanitizedModels, model) + } + } + entry.Models = sanitizedModels + + // Use API key + base URL as uniqueness key + uniqueKey := entry.APIKey + "|" + entry.BaseURL + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + out = append(out, entry) + } + cfg.VertexCompatAPIKey = out +} diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index bd4242a1..eeb7356e 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -44,6 +44,22 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { // Identifier returns provider key for manager routing. func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } +// GeminiVertexCompatExecutor is a thin wrapper around GeminiVertexExecutor +// that provides the correct identifier for vertex-compat routing. +type GeminiVertexCompatExecutor struct { + *GeminiVertexExecutor +} + +// NewGeminiVertexCompatExecutor constructs the Vertex-compatible executor. +func NewGeminiVertexCompatExecutor(cfg *config.Config) *GeminiVertexCompatExecutor { + return &GeminiVertexCompatExecutor{ + GeminiVertexExecutor: NewGeminiVertexExecutor(cfg), + } +} + +// Identifier returns provider key for manager routing. +func (e *GeminiVertexCompatExecutor) Identifier() string { return "vertex-compat" } + // PrepareRequest is a no-op for Vertex. func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil @@ -51,11 +67,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A // Execute handles non-streaming requests. func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return resp, errCreds + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return resp, errCreds + } + return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// ExecuteStream handles SSE streaming for Vertex. +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return nil, errCreds + } + return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// CountTokens calls Vertex countTokens endpoint. +func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return cliproxyexecutor.Response{}, errCreds + } + return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// countTokensWithServiceAccount handles token counting using service account credentials. +func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + baseURL := vertexBaseURL(location) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } else if errTok != nil { + log.Errorf("vertex executor: access token error: %v", errTok) + return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// countTokensWithAPIKey handles token counting using API key credentials. +func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// Refresh is a no-op for service account based credentials. +func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + return auth, nil +} + +// executeWithServiceAccount handles authentication using service account credentials. +// This method contains the original service account authentication logic. +func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -149,13 +392,105 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, nil } -// ExecuteStream handles SSE streaming for Vertex. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return nil, errCreds +// executeWithAPIKey handles authentication using API key credentials. +// This method follows the vertex-compat pattern for API key authentication. +func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + + action := "generateContent" + if req.Metadata != nil { + if a, _ := req.Metadata["action"].(string); a == "countTokens" { + action = "countTokens" + } + } + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action) + if opts.Alt != "" && action != "countTokens" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") + + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if errNewReq != nil { + return resp, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return resp, errDo } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + reporter.publish(ctx, parseGeminiUsage(data)) + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} +// executeStreamWithServiceAccount handles streaming authentication using service account credentials. +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -266,42 +601,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return stream, nil } -// CountTokens calls Vertex countTokens endpoint. -func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return cliproxyexecutor.Response{}, errCreds - } +// executeStreamWithAPIKey handles streaming authentication using API key credentials. +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) - baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent") + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq + return nil, errNewReq } httpReq.Header.Set("Content-Type", "application/json") - if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } else if errTok != nil { - log.Errorf("vertex executor: access token error: %v", errTok) - return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) } applyGeminiHeaders(httpReq, auth) @@ -315,7 +652,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: translatedReq, + Body: body, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -327,38 +664,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo + return nil, errDo } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} -// Refresh is a no-op for service account based credentials. -func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - return auth, nil + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseGeminiStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + }() + return stream, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. @@ -401,6 +753,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou return projectID, location, saJSON, nil } +// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern. +func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { + if a == nil { + return "", "" + } + if a.Attributes != nil { + apiKey = a.Attributes["api_key"] + baseURL = a.Attributes["base_url"] + } + if apiKey == "" && a.Metadata != nil { + if v, ok := a.Metadata["access_token"].(string); ok { + apiKey = v + } + } + return +} + func vertexBaseURL(location string) string { loc := strings.TrimSpace(location) if loc == "" { diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a284541a..6ecf88a3 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -496,6 +496,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str return hex.EncodeToString(sum[:]) } +func computeVertexCompatModelsHash(models []config.VertexCompatModel) string { + if len(models) == 0 { + return "" + } + data, err := json.Marshal(models) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // computeClaudeModelsHash returns a stable hash for Claude model aliases. func computeClaudeModelsHash(models []config.ClaudeModel) string { if len(models) == 0 { @@ -902,8 +914,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string // no legacy clients to unregister // Create new API key clients based on the new config - geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) + totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount log.Debugf("loaded %d API key clients", totalAPIKeyClients) var authFileCount int @@ -946,7 +958,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.clientsMutex.Unlock() } - totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount // Ensure consumers observe the new configuration before auth updates dispatch. if w.reloadCallback != nil { @@ -956,10 +968,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState() - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex-compat keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, authFileCount, geminiAPIKeyCount, + vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount, @@ -1074,6 +1087,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey") out = append(out, a) } + // Claude API keys -> synthesize auths for i := range cfg.ClaudeKey { ck := cfg.ClaudeKey[i] @@ -1240,6 +1254,42 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } } } + + // Process Vertex compatibility providers + for i := range cfg.VertexCompatAPIKey { + compat := &cfg.VertexCompatAPIKey[i] + providerName := "vertex-compat" + base := strings.TrimSpace(compat.BaseURL) + + key := strings.TrimSpace(compat.APIKey) + proxyURL := strings.TrimSpace(compat.ProxyURL) + idKind := fmt.Sprintf("vertex-compatibility:%s", base) + id, token := idGen.next(idKind, key, base, proxyURL) + attrs := map[string]string{ + "source": fmt.Sprintf("config:vertex-compatibility[%s]", token), + "base_url": base, + "provider_key": providerName, + } + if key != "" { + attrs["api_key"] = key + } + if hash := computeVertexCompatModelsHash(compat.Models); hash != "" { + attrs["models_hash"] = hash + } + addConfigHeadersToAttrs(compat.Headers, attrs) + a := &coreauth.Auth{ + ID: id, + Provider: providerName, + Label: "Vertex Compatibility", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + out = append(out, a) + } + // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) entries, _ := os.ReadDir(w.authDir) for _, e := range entries { @@ -1456,8 +1506,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int { return authFileCount } -func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { +func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { geminiAPIKeyCount := 0 + vertexCompatAPIKeyCount := 0 claudeAPIKeyCount := 0 codexAPIKeyCount := 0 openAICompatCount := 0 @@ -1466,6 +1517,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { // Stateless executor handles Gemini API keys; avoid constructing legacy clients. geminiAPIKeyCount += len(cfg.GeminiKey) } + if len(cfg.VertexCompatAPIKey) > 0 { + vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) + } if len(cfg.ClaudeKey) > 0 { claudeAPIKeyCount += len(cfg.ClaudeKey) } @@ -1483,7 +1537,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { } } } - return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount + return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount } func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { diff --git a/sdk/cliproxy/providers.go b/sdk/cliproxy/providers.go index a5810336..401885f5 100644 --- a/sdk/cliproxy/providers.go +++ b/sdk/cliproxy/providers.go @@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider { type apiKeyClientProvider struct{} func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) { - geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) + geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) if ctx != nil { select { case <-ctx.Done(): @@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A } } return &APIKeyClientResult{ - GeminiKeyCount: geminiCount, - ClaudeKeyCount: claudeCount, - CodexKeyCount: codexCount, - OpenAICompatCount: openAICompat, + GeminiKeyCount: geminiCount, + VertexCompatKeyCount: vertexCompatCount, + ClaudeKeyCount: claudeCount, + CodexKeyCount: codexCount, + OpenAICompatCount: openAICompat, }, nil } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index c2ebba8d..f0b6bf53 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -324,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName if len(a.Attributes) > 0 { providerKey = strings.TrimSpace(a.Attributes["provider_key"]) compatName = strings.TrimSpace(a.Attributes["compat_name"]) - if providerKey != "" || compatName != "" { + if compatName != "" { if providerKey == "" { providerKey = compatName } @@ -362,6 +362,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) case "vertex": s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg)) + case "vertex-compat": + s.coreManager.RegisterExecutor(executor.NewGeminiVertexCompatExecutor(s.cfg)) case "gemini-cli": s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "aistudio": @@ -498,7 +500,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Println("API server started successfully") + fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) @@ -680,6 +682,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() models = applyExcludedModels(models, excluded) + case "vertex-compat": + // Handle Vertex AI compatibility providers with custom model definitions + if s.cfg != nil && len(s.cfg.VertexCompatAPIKey) > 0 { + // Create models for all Vertex compatibility providers + allModels := make([]*ModelInfo, 0) + for i := range s.cfg.VertexCompatAPIKey { + compat := &s.cfg.VertexCompatAPIKey[i] + for j := range compat.Models { + m := compat.Models[j] + // Use alias as model ID, fallback to name if alias is empty + modelID := m.Alias + if modelID == "" { + modelID = m.Name + } + if modelID != "" { + allModels = append(allModels, &ModelInfo{ + ID: modelID, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: "vertex-compat", + Type: "vertex-compat", + DisplayName: m.Name, + }) + } + } + } + models = allModels + } + case "gemini-cli": models = registry.GetGeminiCLIModels() models = applyExcludedModels(models, excluded) diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index b44185d1..42c7c488 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -49,19 +49,21 @@ type APIKeyClientProvider interface { Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) } -// APIKeyClientResult contains API key based clients along with type counts. -// It provides metadata about the number of clients loaded for each provider type. +// APIKeyClientResult is returned by APIKeyClientProvider.Load() type APIKeyClientResult struct { - // GeminiKeyCount is the number of Gemini API key clients loaded. + // GeminiKeyCount is the number of Gemini API keys loaded GeminiKeyCount int - // ClaudeKeyCount is the number of Claude API key clients loaded. + // VertexCompatKeyCount is the number of Vertex-compatible API keys loaded + VertexCompatKeyCount int + + // ClaudeKeyCount is the number of Claude API keys loaded ClaudeKeyCount int - // CodexKeyCount is the number of Codex API key clients loaded. + // CodexKeyCount is the number of Codex API keys loaded CodexKeyCount int - // OpenAICompatCount is the number of OpenAI-compatible API key clients loaded. + // OpenAICompatCount is the number of OpenAI compatibility API keys loaded OpenAICompatCount int }