diff --git a/.gitignore b/.gitignore index ef2d935a..c51d4a2b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,5 @@ # Binaries -cli-proxy-api +/cli-proxy-api* *.exe # Configuration @@ -30,3 +30,6 @@ GEMINI.md .vscode/* .claude/* .serena/* + +refs/* +.DS_Store \ No newline at end of file diff --git a/internal/api/server.go b/internal/api/server.go index ab9c0354..b3c3bd26 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -921,6 +921,7 @@ func (s *Server) UpdateClients(cfg *config.Config) { geminiAPIKeyCount := len(cfg.GeminiKey) claudeAPIKeyCount := len(cfg.ClaudeKey) codexAPIKeyCount := len(cfg.CodexKey) + vertexAICompatCount := len(cfg.VertexCompatAPIKey) openAICompatCount := 0 for i := range cfg.OpenAICompatibility { entry := cfg.OpenAICompatibility[i] @@ -931,13 +932,14 @@ func (s *Server) UpdateClients(cfg *config.Config) { openAICompatCount += len(entry.APIKeys) } - total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount - fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n", + total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount + fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n", total, authFiles, geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, + vertexAICompatCount, openAICompatCount, ) } diff --git a/internal/config/config.go b/internal/config/config.go index 97b5a0c2..473a0553 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -64,6 +64,10 @@ type Config struct { // GeminiKey defines Gemini API key configurations with optional routing overrides. GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"` + // VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers. + // Used for services that use Vertex AI-style paths but with simple API key authentication. + VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"` + // RequestRetry defines the retry times when the request failed. RequestRetry int `yaml:"request-retry" json:"request-retry"` // MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential. @@ -325,6 +329,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Sanitize Gemini API key configuration and migrate legacy entries. cfg.SanitizeGeminiKeys() + // Sanitize Vertex-compatible API keys: drop entries without base-url + cfg.SanitizeVertexCompatKeys() + // Sanitize Codex keys: drop entries without base-url cfg.SanitizeCodexKeys() @@ -813,6 +820,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool { switch key { case "generative-language-api-key", "gemini-api-key", + "vertex-api-key", "claude-api-key", "codex-api-key", "openai-compatibility": diff --git a/internal/config/vertex_compat.go b/internal/config/vertex_compat.go new file mode 100644 index 00000000..a8d94ccb --- /dev/null +++ b/internal/config/vertex_compat.go @@ -0,0 +1,84 @@ +package config + +import "strings" + +// VertexCompatKey represents the configuration for Vertex AI-compatible API keys. +// This supports third-party services that use Vertex AI-style endpoint paths +// (/publishers/google/models/{model}:streamGenerateContent) but authenticate +// with simple API keys instead of Google Cloud service account credentials. +// +// Example services: zenmux.ai and similar Vertex-compatible providers. +type VertexCompatKey struct { + // APIKey is the authentication key for accessing the Vertex-compatible API. + // Maps to the x-goog-api-key header. + APIKey string `yaml:"api-key" json:"api-key"` + + // BaseURL is the base URL for the Vertex-compatible API endpoint. + // The executor will append "/v1/publishers/google/models/{model}:action" to this. + // Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..." + BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"` + + // ProxyURL optionally overrides the global proxy for this API key. + ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"` + + // Headers optionally adds extra HTTP headers for requests sent with this key. + // Commonly used for cookies, user-agent, and other authentication headers. + Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"` + + // Models defines the model configurations including aliases for routing. + Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"` +} + +// VertexCompatModel represents a model configuration for Vertex compatibility, +// including the actual model name and its alias for API routing. +type VertexCompatModel struct { + // Name is the actual model name used by the external provider. + Name string `yaml:"name" json:"name"` + + // Alias is the model name alias that clients will use to reference this model. + Alias string `yaml:"alias" json:"alias"` +} + +// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials. +func (cfg *Config) SanitizeVertexCompatKeys() { + if cfg == nil { + return + } + + seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey)) + out := cfg.VertexCompatAPIKey[:0] + for i := range cfg.VertexCompatAPIKey { + entry := cfg.VertexCompatAPIKey[i] + entry.APIKey = strings.TrimSpace(entry.APIKey) + if entry.APIKey == "" { + continue + } + entry.BaseURL = strings.TrimSpace(entry.BaseURL) + if entry.BaseURL == "" { + // BaseURL is required for vertex-compat keys + continue + } + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = NormalizeHeaders(entry.Headers) + + // Sanitize models: remove entries without valid alias + sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models)) + for _, model := range entry.Models { + model.Alias = strings.TrimSpace(model.Alias) + model.Name = strings.TrimSpace(model.Name) + if model.Alias != "" && model.Name != "" { + sanitizedModels = append(sanitizedModels, model) + } + } + entry.Models = sanitizedModels + + // Use API key + base URL as uniqueness key + uniqueKey := entry.APIKey + "|" + entry.BaseURL + if _, exists := seen[uniqueKey]; exists { + continue + } + seen[uniqueKey] = struct{}{} + out = append(out, entry) + } + cfg.VertexCompatAPIKey = out +} diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index bd4242a1..eeb7356e 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -44,6 +44,22 @@ func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor { // Identifier returns provider key for manager routing. func (e *GeminiVertexExecutor) Identifier() string { return "vertex" } +// GeminiVertexCompatExecutor is a thin wrapper around GeminiVertexExecutor +// that provides the correct identifier for vertex-compat routing. +type GeminiVertexCompatExecutor struct { + *GeminiVertexExecutor +} + +// NewGeminiVertexCompatExecutor constructs the Vertex-compatible executor. +func NewGeminiVertexCompatExecutor(cfg *config.Config) *GeminiVertexCompatExecutor { + return &GeminiVertexCompatExecutor{ + GeminiVertexExecutor: NewGeminiVertexExecutor(cfg), + } +} + +// Identifier returns provider key for manager routing. +func (e *GeminiVertexCompatExecutor) Identifier() string { return "vertex-compat" } + // PrepareRequest is a no-op for Vertex. func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil @@ -51,11 +67,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A // Execute handles non-streaming requests. func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return resp, errCreds + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return resp, errCreds + } + return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// ExecuteStream handles SSE streaming for Vertex. +func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return nil, errCreds + } + return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// CountTokens calls Vertex countTokens endpoint. +func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + // Try API key authentication first + apiKey, baseURL := vertexAPICreds(auth) + + // If no API key found, fall back to service account authentication + if apiKey == "" { + projectID, location, saJSON, errCreds := vertexCreds(auth) + if errCreds != nil { + return cliproxyexecutor.Response{}, errCreds + } + return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON) + } + + // Use API key authentication + return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL) +} + +// countTokensWithServiceAccount handles token counting using service account credentials. +func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + baseURL := vertexBaseURL(location) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { + httpReq.Header.Set("Authorization", "Bearer "+token) + } else if errTok != nil { + log.Errorf("vertex executor: access token error: %v", errTok) + return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// countTokensWithAPIKey handles token counting using API key credentials. +func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) { + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + } + translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) + translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") + translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens") + + httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + if errNewReq != nil { + return cliproxyexecutor.Response{}, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: translatedReq, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return cliproxyexecutor.Response{}, errDo + } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return cliproxyexecutor.Response{}, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) + return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + } + count := gjson.GetBytes(data, "totalTokens").Int() + out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) + return cliproxyexecutor.Response{Payload: []byte(out)}, nil +} + +// Refresh is a no-op for service account based credentials. +func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + return auth, nil +} + +// executeWithServiceAccount handles authentication using service account credentials. +// This method contains the original service account authentication logic. +func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -149,13 +392,105 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, nil } -// ExecuteStream handles SSE streaming for Vertex. -func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return nil, errCreds +// executeWithAPIKey handles authentication using API key credentials. +// This method follows the vertex-compat pattern for API key authentication. +func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { + if budgetOverride != nil { + norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) + budgetOverride = &norm + } + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) + + action := "generateContent" + if req.Metadata != nil { + if a, _ := req.Metadata["action"].(string); a == "countTokens" { + action = "countTokens" + } + } + + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action) + if opts.Alt != "" && action != "countTokens" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") + + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) + if errNewReq != nil { + return resp, errNewReq + } + httpReq.Header.Set("Content-Type", "application/json") + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) + } + applyGeminiHeaders(httpReq, auth) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: url, + Method: http.MethodPost, + Headers: httpReq.Header.Clone(), + Body: body, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) + httpResp, errDo := httpClient.Do(httpReq) + if errDo != nil { + recordAPIResponseError(ctx, e.cfg, errDo) + return resp, errDo } + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { + b, _ := io.ReadAll(httpResp.Body) + appendAPIResponseChunk(ctx, e.cfg, b) + log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + err = statusErr{code: httpResp.StatusCode, msg: string(b)} + return resp, err + } + data, errRead := io.ReadAll(httpResp.Body) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead + } + appendAPIResponseChunk(ctx, e.cfg, data) + reporter.publish(ctx, parseGeminiUsage(data)) + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil +} +// executeStreamWithServiceAccount handles streaming authentication using service account credentials. +func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) { reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) defer reporter.trackFailure(ctx, &err) @@ -266,42 +601,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return stream, nil } -// CountTokens calls Vertex countTokens endpoint. -func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { - projectID, location, saJSON, errCreds := vertexCreds(auth) - if errCreds != nil { - return cliproxyexecutor.Response{}, errCreds - } +// executeStreamWithAPIKey handles streaming authentication using API key credentials. +func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth) + defer reporter.trackFailure(ctx, &err) + from := opts.SourceFormat to := sdktranslator.FromString("gemini") - translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) { if budgetOverride != nil { norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride) budgetOverride = &norm } - translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride) + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) } - translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq) - translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq) - respCtx := context.WithValue(ctx, "alt", opts.Alt) - translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig") - translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings") + body = util.StripThinkingConfigIfUnsupported(req.Model, body) + body = fixGeminiImageAspectRatio(req.Model, body) + body = applyPayloadConfig(e.cfg, req.Model, body) - baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens") + // For API key auth, use simpler URL format without project/location + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent") + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + body, _ = sjson.DeleteBytes(body, "session_id") - httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq)) + httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body)) if errNewReq != nil { - return cliproxyexecutor.Response{}, errNewReq + return nil, errNewReq } httpReq.Header.Set("Content-Type", "application/json") - if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" { - httpReq.Header.Set("Authorization", "Bearer "+token) - } else if errTok != nil { - log.Errorf("vertex executor: access token error: %v", errTok) - return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"} + if apiKey != "" { + httpReq.Header.Set("x-goog-api-key", apiKey) } applyGeminiHeaders(httpReq, auth) @@ -315,7 +652,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau URL: url, Method: http.MethodPost, Headers: httpReq.Header.Clone(), - Body: translatedReq, + Body: body, Provider: e.Identifier(), AuthID: authID, AuthLabel: authLabel, @@ -327,38 +664,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau httpResp, errDo := httpClient.Do(httpReq) if errDo != nil { recordAPIResponseError(ctx, e.cfg, errDo) - return cliproxyexecutor.Response{}, errDo + return nil, errDo } - defer func() { - if errClose := httpResp.Body.Close(); errClose != nil { - log.Errorf("vertex executor: close response body error: %v", errClose) - } - }() recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)} - } - data, errRead := io.ReadAll(httpResp.Body) - if errRead != nil { - recordAPIResponseError(ctx, e.cfg, errRead) - return cliproxyexecutor.Response{}, errRead - } - appendAPIResponseChunk(ctx, e.cfg, data) - if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data)) - return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)} + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + return nil, statusErr{code: httpResp.StatusCode, msg: string(b)} } - count := gjson.GetBytes(data, "totalTokens").Int() - out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data) - return cliproxyexecutor.Response{Payload: []byte(out)}, nil -} -// Refresh is a no-op for service account based credentials. -func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { - return auth, nil + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + defer close(out) + defer func() { + if errClose := httpResp.Body.Close(); errClose != nil { + log.Errorf("vertex executor: close response body error: %v", errClose) + } + }() + scanner := bufio.NewScanner(httpResp.Body) + scanner.Buffer(nil, 20_971_520) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseGeminiStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + } + lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m) + for i := range lines { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])} + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + }() + return stream, nil } // vertexCreds extracts project, location and raw service account JSON from auth metadata. @@ -401,6 +753,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou return projectID, location, saJSON, nil } +// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern. +func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) { + if a == nil { + return "", "" + } + if a.Attributes != nil { + apiKey = a.Attributes["api_key"] + baseURL = a.Attributes["base_url"] + } + if apiKey == "" && a.Metadata != nil { + if v, ok := a.Metadata["access_token"].(string); ok { + apiKey = v + } + } + return +} + func vertexBaseURL(location string) string { loc := strings.TrimSpace(location) if loc == "" { diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 55ec6dc9..e3421c76 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "encoding/json" "fmt" "io" "net/http" @@ -59,8 +60,23 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) - url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) + // Check if this is a web search request (has special marker we added in translator) + isWebSearch := isWebSearchRequest(translated) + + // Store the marker flag but clean the payload before sending + sendPayload := translated + if isWebSearch { + sendPayload = pickWebSearchFields(sendPayload) + } + + var url string + if isWebSearch { + url = strings.TrimSuffix(baseURL, "/") + "/chat/retrieve" + } else { + url = strings.TrimSuffix(baseURL, "/") + "/chat/completions" + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(sendPayload)) if err != nil { return resp, err } @@ -104,10 +120,11 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A } }() recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + log.Debugf("OpenAICompatExecutor Execute: HTTP Response status: %d, headers: %v", httpResp.StatusCode, httpResp.Header) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + log.Debugf("OpenAICompatExecutor Execute: request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) err = statusErr{code: httpResp.StatusCode, msg: string(b)} return resp, err } @@ -117,12 +134,27 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A return resp, err } appendAPIResponseChunk(ctx, e.cfg, body) - reporter.publish(ctx, parseOpenAIUsage(body)) - // Ensure we at least record the request even if upstream doesn't return usage - reporter.ensurePublished(ctx) - // Translate response back to source format when needed + + // Handle web search responses differently from standard OpenAI responses + var out string var param any - out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m) + if isWebSearch { + log.Debugf("OpenAICompatExecutor Execute: Web search response received, request model: %s, raw response: %s", req.Model, string(body)) + // For web search responses, we need to format them properly for Claude + // The /chat/retrieve endpoint returns a different format than OpenAI + translatedOut := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m) + log.Debugf("OpenAICompatExecutor Execute: Web search response translated to: %s", translatedOut) + out = translatedOut + } else { + // Standard OpenAI response handling + reporter.publish(ctx, parseOpenAIUsage(body)) + // Ensure we at least record the request even if upstream doesn't return usage + reporter.ensurePublished(ctx) + // Translate response back to source format when needed + out = sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m) + } + log.Debugf("OpenAICompatExecutor Execute: Response translated to: %s", out) + resp = cliproxyexecutor.Response{Payload: []byte(out)} return resp, nil } @@ -144,8 +176,23 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy } translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated) - url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" - httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated)) + // Check if this is a web search request (has special marker we added in translator) + isWebSearch := isWebSearchRequest(translated) + + // Store the marker flag but clean the payload before sending + sendPayload := translated + if isWebSearch { + sendPayload = pickWebSearchFields(sendPayload) + } + + var url string + if isWebSearch { + url = strings.TrimSuffix(baseURL, "/") + "/chat/retrieve" + } else { + url = strings.TrimSuffix(baseURL, "/") + "/chat/completions" + } + + httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(sendPayload)) if err != nil { return nil, err } @@ -159,8 +206,12 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy attrs = auth.Attributes } util.ApplyCustomHeadersFromAttrs(httpReq, attrs) - httpReq.Header.Set("Accept", "text/event-stream") - httpReq.Header.Set("Cache-Control", "no-cache") + + // For web search, we don't want stream headers as it returns a complete response + if !isWebSearch { + httpReq.Header.Set("Accept", "text/event-stream") + httpReq.Header.Set("Cache-Control", "no-cache") + } var authID, authLabel, authType, authValue string if auth != nil { authID = auth.ID @@ -186,16 +237,18 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return nil, err } recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone()) + log.Debugf("OpenAICompatExecutor ExecuteStream: HTTP Response status: %d, headers: %v", httpResp.StatusCode, httpResp.Header) if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 { b, _ := io.ReadAll(httpResp.Body) appendAPIResponseChunk(ctx, e.cfg, b) - log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) + log.Debugf("OpenAICompatExecutor ExecuteStream: request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b)) if errClose := httpResp.Body.Close(); errClose != nil { log.Errorf("openai compat executor: close response body error: %v", errClose) } err = statusErr{code: httpResp.StatusCode, msg: string(b)} return nil, err } + out := make(chan cliproxyexecutor.StreamChunk) stream = out go func() { @@ -205,32 +258,59 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy log.Errorf("openai compat executor: close response body error: %v", errClose) } }() - scanner := bufio.NewScanner(httpResp.Body) - scanner.Buffer(nil, 20_971_520) - var param any - for scanner.Scan() { - line := scanner.Bytes() - appendAPIResponseChunk(ctx, e.cfg, line) - if detail, ok := parseOpenAIStreamUsage(line); ok { - reporter.publish(ctx, detail) - } - if len(line) == 0 { - continue + + // For web search requests, the response is a single JSON rather than an SSE stream + if isWebSearch { + // Read the complete response body at once, since /chat/retrieve returns complete JSON + body, err := io.ReadAll(httpResp.Body) + if err != nil { + recordAPIResponseError(ctx, e.cfg, err) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: err} + return } - // OpenAI-compatible streams are SSE: lines typically prefixed with "data: ". - // Pass through translator; it yields one or more chunks for the target schema. - chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m) + + log.Debugf("OpenAICompatExecutor ExecuteStream: Web search response received, raw response: %s", string(body)) + appendAPIResponseChunk(ctx, e.cfg, body) + + // Translate the single web search response to SSE events + // The response translator should handle web search response format and generate SSE events + var param any + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m) for i := range chunks { + log.Debugf("OpenAICompatExecutor ExecuteStream: Web search SSE event chunk: %s", chunks[i]) out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} } + } else { + // For regular OpenAI-compatible streaming responses + scanner := bufio.NewScanner(httpResp.Body) + buf := make([]byte, 20_971_520) + scanner.Buffer(buf, 20_971_520) + var param any + for scanner.Scan() { + line := scanner.Bytes() + appendAPIResponseChunk(ctx, e.cfg, line) + if detail, ok := parseOpenAIStreamUsage(line); ok { + reporter.publish(ctx, detail) + } + if len(line) == 0 { + continue + } + // OpenAI-compatible streams are SSE: lines typically prefixed with "data: ". + // Pass through translator; it yields one or more chunks for the target schema. + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m) + for i := range chunks { + out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])} + } + } + if errScan := scanner.Err(); errScan != nil { + recordAPIResponseError(ctx, e.cfg, errScan) + reporter.publishFailure(ctx) + out <- cliproxyexecutor.StreamChunk{Err: errScan} + } + // Ensure we record the request if no usage chunk was ever seen + reporter.ensurePublished(ctx) } - if errScan := scanner.Err(); errScan != nil { - recordAPIResponseError(ctx, e.cfg, errScan) - reporter.publishFailure(ctx) - out <- cliproxyexecutor.StreamChunk{Err: errScan} - } - // Ensure we record the request if no usage chunk was ever seen - reporter.ensurePublished(ctx) }() return stream, nil } @@ -354,3 +434,71 @@ func (e statusErr) Error() string { } func (e statusErr) StatusCode() int { return e.code } func (e statusErr) RetryAfter() *time.Duration { return e.retryAfter } + +// isWebSearchRequest checks if the translated request is a web search request +// by checking if it has exactly one tool that matches /^web_search/ or if it has the special marker +func isWebSearchRequest(translated []byte) bool { + // First check for the special marker that the translator adds + if bytes.Contains(translated, []byte("\"_web_search_request\":true")) { + return true + } + + var req map[string]interface{} + if err := json.Unmarshal(translated, &req); err != nil { + return false + } + + // Check if tools exist and is an array + tools, ok := req["tools"].([]interface{}) + if !ok || len(tools) != 1 { + return false + } + + // Check if the single tool has a type that matches /^web_search/ + if tool, ok := tools[0].(map[string]interface{}); ok { + if toolType, ok := tool["type"].(string); ok { + return strings.HasPrefix(toolType, "web_search") + } + } + + return false +} + +// pickWebSearchFields extracts only the required fields for /chat/retrieve endpoint +func pickWebSearchFields(payload []byte) []byte { + var data map[string]interface{} + if err := json.Unmarshal(payload, &data); err != nil { + return payload + } + + // Create new map with only the 6 required fields for /chat/retrieve + cleaned := make(map[string]interface{}) + + // Only extract these specific fields (model is required, enableIntention and enableQueryRewrite should be false) + if model, ok := data["model"].(string); ok { + cleaned["model"] = model + } + if phase, ok := data["phase"].(string); ok { + cleaned["phase"] = phase + } + if query, ok := data["query"].(string); ok { + cleaned["query"] = query + } + if enableIntention, ok := data["enableIntention"].(bool); ok { + cleaned["enableIntention"] = enableIntention + } + if appCode, ok := data["appCode"].(string); ok { + cleaned["appCode"] = appCode + } + if enableQueryRewrite, ok := data["enableQueryRewrite"].(bool); ok { + cleaned["enableQueryRewrite"] = enableQueryRewrite + } + + // Re-encode with only the required fields + result, err := json.Marshal(cleaned) + if err != nil { + return payload + } + + return result +} diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index bff306cc..a5746c1f 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -8,6 +8,7 @@ package claude import ( "bytes" "encoding/json" + "strings" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -18,11 +19,24 @@ import ( // from the raw JSON request and returns them in the format expected by the OpenAI API. func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte { rawJSON := bytes.Clone(inputRawJSON) - // Base OpenAI Chat Completions API template - out := `{"model":"","messages":[]}` - root := gjson.ParseBytes(rawJSON) + // Check if this is a web search request first + if isWebSearchRequest(root) { + // For web search requests, return the format needed for the /chat/retrieve endpoint + // Add a special indicator that executors can use to route this differently + webSearchRequest := createWebSearchRequestJSON(root) + // Add a metadata field to indicate this is a special web search request + result := make([]byte, len(webSearchRequest)+30) + copy(result, webSearchRequest[:len(webSearchRequest)-1]) // Copy everything except the closing brace + metadata := `,"_web_search_request":true}` + copy(result[len(webSearchRequest)-1:], metadata) + return result + } + + // Base OpenAI Chat Completions API template for non-web-search requests + out := `{"model":"","messages":[]}` + // Model mapping out, _ = sjson.Set(out, "model", modelName) @@ -286,3 +300,150 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) { return "", false } } + +// isWebSearchRequest checks if the Claude request is for a web search by checking for web search tools. +func isWebSearchRequest(root gjson.Result) bool { + tools := root.Get("tools") + if !tools.Exists() || !tools.IsArray() { + return false + } + + found := false + tools.ForEach(func(_, tool gjson.Result) bool { + if tool.Get("type").String() == "web_search_20250305" { + // Found a web search tool + found = true + return false // stop iteration + } + return true + }) + + return found +} + +// createWebSearchRequestJSON creates the JSON for /chat/retrieve endpoint +func createWebSearchRequestJSON(root gjson.Result) []byte { + query := extractWebSearchQuery(root) + if query == "" { + // Default query if extraction fails + query = "web search" + } + + // Create the JSON structure for the chat/retrieve endpoint (enableIntention and enableQueryRewrite should be false) + webSearchJSON := `{"phase":"UNIFY","query":"","enableIntention":false,"appCode":"COMPLEX_CHATBOT","enableQueryRewrite":false}` + webSearchJSON, _ = sjson.Set(webSearchJSON, "query", query) + + return []byte(webSearchJSON) +} + +// extractWebSearchQuery extracts the search query from Claude messages +func extractWebSearchQuery(root gjson.Result) string { + messages := root.Get("messages") + if !messages.Exists() || !messages.IsArray() { + return "" + } + + query := "" + messages.ForEach(func(_, message gjson.Result) bool { + // Only look for the first user message that might contain the query + role := message.Get("role").String() + if role != "user" { + return true // continue to next message + } + + content := message.Get("content") + if !content.Exists() || !content.IsArray() { + return true + } + + content.ForEach(func(_, part gjson.Result) bool { + if part.Get("type").String() == "text" { + text := part.Get("text").String() + // Extract query from message like "Perform a web search for the query: " + if strings.Contains(text, "web search for the query:") { + parts := strings.SplitN(text, "web search for the query:", 2) + if len(parts) > 1 { + query = strings.TrimSpace(parts[1]) + return false // stop iteration + } + } + // Alternative extraction: if the entire text looks like a search query, use it + if query == "" { + // Try to find text after common search phrases + searchPhrases := []string{ + "perform a web search for the query:", + "perform a web search for:", + "web search for the query:", + "web search for:", + "search for the query:", + "search for:", + "query:", + "search query:", + } + for _, phrase := range searchPhrases { + phraseLower := strings.ToLower(phrase) + if idx := strings.Index(strings.ToLower(text), phraseLower); idx >= 0 { + query = strings.TrimSpace(text[idx+len(phrase):]) + // Remove any trailing punctuation that might be part of the instruction + query = strings.TrimRight(query, ".!?") + if query != "" { + return false // stop iteration + } + } + } + + // If still no query found, check if the entire text is a search-like query + if query == "" && (strings.Contains(strings.ToLower(text), "search") || + strings.Contains(strings.ToLower(text), "find") || + strings.Contains(strings.ToLower(text), "what") || + strings.Contains(strings.ToLower(text), "how") || + strings.Contains(strings.ToLower(text), "why") || + strings.Contains(strings.ToLower(text), "when") || + strings.Contains(strings.ToLower(text), "where")) { + trimmed := strings.TrimSpace(text) + if len(trimmed) > 5 { // Basic sanity check + query = trimmed + return false // stop iteration + } + } + } + } + return true + }) + + // Stop after processing user message if a query was found + return query == "" + }) + + // Fallback: if no query found but this is marked as a web search, + // try to use any user message content as the query + if query == "" { + messages.ForEach(func(_, message gjson.Result) bool { + role := message.Get("role").String() + if role != "user" { + return true + } + + content := message.Get("content") + if content.Exists() && content.IsArray() { + content.ForEach(func(_, part gjson.Result) bool { + if part.Get("type").String() == "text" && query == "" { + text := part.Get("text").String() + if text != "" { + query = strings.TrimSpace(text) + // Limit query length to be reasonable + if len(query) > 200 { + query = query[:200] + } + return false // stop iteration + } + } + return true + }) + } + return query == "" + }) + } + + return query +} diff --git a/internal/translator/openai/claude/openai_claude_response.go b/internal/translator/openai/claude/openai_claude_response.go index dac4c970..4373c0c7 100644 --- a/internal/translator/openai/claude/openai_claude_response.go +++ b/internal/translator/openai/claude/openai_claude_response.go @@ -11,6 +11,7 @@ import ( "encoding/json" "fmt" "strings" + "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" @@ -73,7 +74,7 @@ type ToolCallAccumulator struct { // // Returns: // - []string: A slice of strings, each containing an Anthropic-compatible JSON response. -func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { +func ConvertOpenAIResponseToClaude(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string { if *param == nil { *param = &ConvertOpenAIResponseToAnthropicParams{ MessageID: "", @@ -93,6 +94,38 @@ func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestR } } + // Check if this is a web search response (non-streaming case) + // When handling streaming, the web search response should come as a single chunk + if isWebSearchResponse(rawJSON) || (bytes.HasPrefix(rawJSON, dataTag) && isWebSearchResponse(bytes.TrimSpace(rawJSON[5:]))) { + // For web search responses in streaming context, we need to generate SSE events + // If this is a data-tag prefixed response, process as a streaming chunk + if bytes.HasPrefix(rawJSON, dataTag) { + webSearchRaw := bytes.TrimSpace(rawJSON[5:]) + // Also check the original request was streaming to determine format + streamResult := gjson.GetBytes(originalRequestRawJSON, "stream") + isStream := streamResult.Exists() && streamResult.Type != gjson.False + if isStream { + return convertWebSearchResponseToClaudeSSE(webSearchRaw, modelName, (*param).(*ConvertOpenAIResponseToAnthropicParams)) + } else { + // Non-streaming context, return the complete Claude message + converted := convertWebSearchResponseToClaude(webSearchRaw, modelName) + return []string{converted} + } + } else { + // This is unprefixed web search response - check if original request was streaming + streamResult := gjson.GetBytes(originalRequestRawJSON, "stream") + isStream := streamResult.Exists() && streamResult.Type != gjson.False + if isStream { + // Original request was streaming, convert to SSE events + return convertWebSearchResponseToClaudeSSE(rawJSON, modelName, (*param).(*ConvertOpenAIResponseToAnthropicParams)) + } else { + // Non-streaming context, return the complete Claude message + converted := convertWebSearchResponseToClaude(rawJSON, modelName) + return []string{converted} + } + } + } + if !bytes.HasPrefix(rawJSON, dataTag) { return []string{} } @@ -863,3 +896,196 @@ func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, origina func ClaudeTokenCount(ctx context.Context, count int64) string { return fmt.Sprintf(`{"input_tokens":%d}`, count) } + +// isWebSearchResponse checks if the response is a web search response by looking at its structure +func isWebSearchResponse(rawJSON []byte) bool { + root := gjson.ParseBytes(rawJSON) + + // Check for web search response structure (different from OpenAI format) + if root.Get("query").Exists() && root.Get("status").Exists() && root.Get("results").Exists() { + return true + } + + // Check for data array field which contains the search results + if root.Get("data").Exists() && root.Get("data").IsArray() { + return true + } + + // Check for result message field which contains the final answer + if root.Get("result_message").Exists() { + return true + } + + return false +} + +// extractWebSearchResult extracts the web search result as a JSON string +func extractWebSearchResult(rawJSON []byte) string { + root := gjson.ParseBytes(rawJSON) + + // Try to extract from the data array first (most common format) + if data := root.Get("data"); data.Exists() && data.IsArray() { + // Return the raw JSON string of the data array + return data.String() + } + + // Fallback: try other formats + if resultMessage := root.Get("result_message"); resultMessage.Exists() { + return resultMessage.String() + } + + // Last resort: check if there's a result string + if resultText := root.Get("result"); resultText.Exists() && resultText.Type == gjson.String { + return resultText.String() + } + + // Return the raw JSON as a string if nothing else matches + return string(rawJSON) +} + +// convertWebSearchResponseToClaude converts a web search response to Claude format +func convertWebSearchResponseToClaude(rawJSON []byte, modelName string) string { + resultText := extractWebSearchResult(rawJSON) + + // Build Claude response + response := map[string]interface{}{ + "id": generateMessageID(), + "type": "message", + "role": "assistant", + "model": modelName, + "stop_reason": "end_turn", + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": 0, + "output_tokens": 0, + }, + } + + // Create content blocks with the result + contentBlocks := []interface{}{ + map[string]interface{}{ + "type": "text", + "text": resultText, + }, + } + + response["content"] = contentBlocks + + // Marshal to JSON + responseJSON, err := json.Marshal(response) + if err != nil { + return "" + } + return string(responseJSON) +} + +// convertWebSearchResponseToClaudeSSE simulates SSE events for web search responses +// This is necessary because /chat/retrieve returns complete JSON, but Claude Code expects SSE format +func convertWebSearchResponseToClaudeSSE(rawJSON []byte, modelName string, param *ConvertOpenAIResponseToAnthropicParams) []string { + var results []string + + // Generate message ID and model if not set + if param.MessageID == "" { + param.MessageID = generateMessageID() + } + if param.Model == "" { + param.Model = modelName + } + + // Send message_start event + messageStart := map[string]interface{}{ + "type": "message_start", + "message": map[string]interface{}{ + "id": param.MessageID, + "type": "message", + "role": "assistant", + "model": param.Model, + "content": []interface{}{}, + "stop_reason": nil, + "stop_sequence": nil, + "usage": map[string]interface{}{ + "input_tokens": 0, + "output_tokens": 0, + }, + }, + } + messageStartJSON, _ := json.Marshal(messageStart) + results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n") + + // Extract the result from web search response + resultText := extractWebSearchResult(rawJSON) + + // Split the result into chunks for streaming simulation + if resultText != "" { + // Start content block + param.TextContentBlockIndex = param.NextContentBlockIndex + param.NextContentBlockIndex++ + + contentBlockStart := map[string]interface{}{ + "type": "content_block_start", + "index": param.TextContentBlockIndex, + "content_block": map[string]interface{}{ + "type": "text", + "text": "", + }, + } + contentBlockStartJSON, _ := json.Marshal(contentBlockStart) + results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n") + + // Send the result text in chunks (simulate streaming) + chunkSize := 200 // Characters per chunk for simulation + for i := 0; i < len(resultText); i += chunkSize { + end := i + chunkSize + if end > len(resultText) { + end = len(resultText) + } + chunk := resultText[i:end] + + contentDelta := map[string]interface{}{ + "type": "content_block_delta", + "index": param.TextContentBlockIndex, + "delta": map[string]interface{}{ + "type": "text_delta", + "text": chunk, + }, + } + contentDeltaJSON, _ := json.Marshal(contentDelta) + results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n") + } + + // End content block + contentBlockStop := map[string]interface{}{ + "type": "content_block_stop", + "index": param.TextContentBlockIndex, + } + contentBlockStopJSON, _ := json.Marshal(contentBlockStop) + results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n") + } + + // Send message_delta with stop reason + messageDelta := map[string]interface{}{ + "type": "message_delta", + "delta": map[string]interface{}{ + "stop_reason": "end_turn", + "stop_sequence": nil, + }, + "usage": map[string]interface{}{ + "input_tokens": 0, + "output_tokens": 0, + }, + } + messageDeltaJSON, _ := json.Marshal(messageDelta) + results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n") + + // Send message_stop + results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n") + + return results +} + +// generateMessageID generates a unique message ID +func generateMessageID() string { + // Simple ID generation - using timestamp should be sufficient + // In production, you might want to use UUID or better ID generation + return fmt.Sprintf("msg_%d", time.Now().UnixNano())[:18] +} diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index a284541a..6ecf88a3 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -496,6 +496,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str return hex.EncodeToString(sum[:]) } +func computeVertexCompatModelsHash(models []config.VertexCompatModel) string { + if len(models) == 0 { + return "" + } + data, err := json.Marshal(models) + if err != nil || len(data) == 0 { + return "" + } + sum := sha256.Sum256(data) + return hex.EncodeToString(sum[:]) +} + // computeClaudeModelsHash returns a stable hash for Claude model aliases. func computeClaudeModelsHash(models []config.ClaudeModel) string { if len(models) == 0 { @@ -902,8 +914,8 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string // no legacy clients to unregister // Create new API key clients based on the new config - geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) - totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg) + totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount log.Debugf("loaded %d API key clients", totalAPIKeyClients) var authFileCount int @@ -946,7 +958,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.clientsMutex.Unlock() } - totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount + totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount // Ensure consumers observe the new configuration before auth updates dispatch. if w.reloadCallback != nil { @@ -956,10 +968,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string w.refreshAuthState() - log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", + log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex-compat keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)", totalNewClients, authFileCount, geminiAPIKeyCount, + vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount, @@ -1074,6 +1087,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey") out = append(out, a) } + // Claude API keys -> synthesize auths for i := range cfg.ClaudeKey { ck := cfg.ClaudeKey[i] @@ -1240,6 +1254,42 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } } } + + // Process Vertex compatibility providers + for i := range cfg.VertexCompatAPIKey { + compat := &cfg.VertexCompatAPIKey[i] + providerName := "vertex-compat" + base := strings.TrimSpace(compat.BaseURL) + + key := strings.TrimSpace(compat.APIKey) + proxyURL := strings.TrimSpace(compat.ProxyURL) + idKind := fmt.Sprintf("vertex-compatibility:%s", base) + id, token := idGen.next(idKind, key, base, proxyURL) + attrs := map[string]string{ + "source": fmt.Sprintf("config:vertex-compatibility[%s]", token), + "base_url": base, + "provider_key": providerName, + } + if key != "" { + attrs["api_key"] = key + } + if hash := computeVertexCompatModelsHash(compat.Models); hash != "" { + attrs["models_hash"] = hash + } + addConfigHeadersToAttrs(compat.Headers, attrs) + a := &coreauth.Auth{ + ID: id, + Provider: providerName, + Label: "Vertex Compatibility", + Status: coreauth.StatusActive, + ProxyURL: proxyURL, + Attributes: attrs, + CreatedAt: now, + UpdatedAt: now, + } + out = append(out, a) + } + // Also synthesize auth entries directly from auth files (for OAuth/file-backed providers) entries, _ := os.ReadDir(w.authDir) for _, e := range entries { @@ -1456,8 +1506,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int { return authFileCount } -func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { +func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) { geminiAPIKeyCount := 0 + vertexCompatAPIKeyCount := 0 claudeAPIKeyCount := 0 codexAPIKeyCount := 0 openAICompatCount := 0 @@ -1466,6 +1517,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { // Stateless executor handles Gemini API keys; avoid constructing legacy clients. geminiAPIKeyCount += len(cfg.GeminiKey) } + if len(cfg.VertexCompatAPIKey) > 0 { + vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey) + } if len(cfg.ClaudeKey) > 0 { claudeAPIKeyCount += len(cfg.ClaudeKey) } @@ -1483,7 +1537,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) { } } } - return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount + return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount } func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string { diff --git a/sdk/cliproxy/providers.go b/sdk/cliproxy/providers.go index a5810336..401885f5 100644 --- a/sdk/cliproxy/providers.go +++ b/sdk/cliproxy/providers.go @@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider { type apiKeyClientProvider struct{} func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) { - geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) + geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg) if ctx != nil { select { case <-ctx.Done(): @@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A } } return &APIKeyClientResult{ - GeminiKeyCount: geminiCount, - ClaudeKeyCount: claudeCount, - CodexKeyCount: codexCount, - OpenAICompatCount: openAICompat, + GeminiKeyCount: geminiCount, + VertexCompatKeyCount: vertexCompatCount, + ClaudeKeyCount: claudeCount, + CodexKeyCount: codexCount, + OpenAICompatCount: openAICompat, }, nil } diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index c2ebba8d..f0b6bf53 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -324,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName if len(a.Attributes) > 0 { providerKey = strings.TrimSpace(a.Attributes["provider_key"]) compatName = strings.TrimSpace(a.Attributes["compat_name"]) - if providerKey != "" || compatName != "" { + if compatName != "" { if providerKey == "" { providerKey = compatName } @@ -362,6 +362,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg)) case "vertex": s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg)) + case "vertex-compat": + s.coreManager.RegisterExecutor(executor.NewGeminiVertexCompatExecutor(s.cfg)) case "gemini-cli": s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg)) case "aistudio": @@ -498,7 +500,7 @@ func (s *Service) Run(ctx context.Context) error { }() time.Sleep(100 * time.Millisecond) - fmt.Println("API server started successfully") + fmt.Printf("API server started successfully on: %d\n", s.cfg.Port) if s.hooks.OnAfterStart != nil { s.hooks.OnAfterStart(s) @@ -680,6 +682,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { // Vertex AI Gemini supports the same model identifiers as Gemini. models = registry.GetGeminiVertexModels() models = applyExcludedModels(models, excluded) + case "vertex-compat": + // Handle Vertex AI compatibility providers with custom model definitions + if s.cfg != nil && len(s.cfg.VertexCompatAPIKey) > 0 { + // Create models for all Vertex compatibility providers + allModels := make([]*ModelInfo, 0) + for i := range s.cfg.VertexCompatAPIKey { + compat := &s.cfg.VertexCompatAPIKey[i] + for j := range compat.Models { + m := compat.Models[j] + // Use alias as model ID, fallback to name if alias is empty + modelID := m.Alias + if modelID == "" { + modelID = m.Name + } + if modelID != "" { + allModels = append(allModels, &ModelInfo{ + ID: modelID, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: "vertex-compat", + Type: "vertex-compat", + DisplayName: m.Name, + }) + } + } + } + models = allModels + } + case "gemini-cli": models = registry.GetGeminiCLIModels() models = applyExcludedModels(models, excluded) diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index b44185d1..42c7c488 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -49,19 +49,21 @@ type APIKeyClientProvider interface { Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) } -// APIKeyClientResult contains API key based clients along with type counts. -// It provides metadata about the number of clients loaded for each provider type. +// APIKeyClientResult is returned by APIKeyClientProvider.Load() type APIKeyClientResult struct { - // GeminiKeyCount is the number of Gemini API key clients loaded. + // GeminiKeyCount is the number of Gemini API keys loaded GeminiKeyCount int - // ClaudeKeyCount is the number of Claude API key clients loaded. + // VertexCompatKeyCount is the number of Vertex-compatible API keys loaded + VertexCompatKeyCount int + + // ClaudeKeyCount is the number of Claude API keys loaded ClaudeKeyCount int - // CodexKeyCount is the number of Codex API key clients loaded. + // CodexKeyCount is the number of Codex API keys loaded CodexKeyCount int - // OpenAICompatCount is the number of OpenAI-compatible API key clients loaded. + // OpenAICompatCount is the number of OpenAI compatibility API keys loaded OpenAICompatCount int }