From 3d8ea811e72cd04a51e48d347c164f06bf7295b6 Mon Sep 17 00:00:00 2001 From: benjamin Date: Fri, 3 Apr 2026 15:08:06 +0800 Subject: [PATCH 1/4] fix: backport upstream responses, gemini, auth, and codex websocket fixes onto origin/main --- .gitignore | 1 + .../api/handlers/management/auth_files.go | 67 ++-- internal/api/modules/amp/fallback_handlers.go | 12 + internal/api/modules/amp/response_rewriter.go | 294 +++++++++++++++--- .../api/modules/amp/response_rewriter_test.go | 34 ++ internal/auth/codex/jwt_parser.go | 9 + internal/auth/codex/openai_auth.go | 17 +- internal/auth/codex/openai_auth_test.go | 4 +- internal/cmd/login.go | 38 +-- internal/misc/header_utils.go | 4 +- .../runtime/executor/aistudio_executor.go | 2 + .../runtime/executor/antigravity_executor.go | 14 +- .../antigravity_executor_client_test.go | 36 +++ internal/runtime/executor/claude_executor.go | 2 + internal/runtime/executor/codex_executor.go | 19 +- .../codex_executor_account_id_test.go | 198 ++++++++++++ .../executor/codex_websockets_executor.go | 170 ++++++++-- .../codex_websockets_executor_test.go | 247 +++++++++++++++ .../runtime/executor/gemini_cli_executor.go | 2 + internal/runtime/executor/gemini_executor.go | 2 + .../executor/gemini_vertex_executor.go | 4 + internal/runtime/executor/iflow_executor.go | 2 + internal/runtime/executor/kimi_executor.go | 2 + .../executor/openai_compat_executor.go | 2 + internal/runtime/executor/proxy_helpers.go | 124 ++++++-- .../runtime/executor/proxy_helpers_test.go | 93 ++++++ internal/runtime/executor/qwen_executor.go | 2 + .../gemini/claude/gemini_claude_request.go | 155 +++++---- sdk/api/handlers/handlers.go | 111 +++++-- .../handlers_stream_bootstrap_test.go | 81 +++++ .../openai/openai_responses_handlers.go | 205 +++++++++++- ...ai_responses_handlers_stream_error_test.go | 2 +- .../openai_responses_handlers_stream_test.go | 142 +++++++++ ...openai_responses_http_continuation_test.go | 149 +++++++++ .../openai/openai_responses_turn_state.go | 185 +++++++++++ .../openai/openai_responses_websocket.go | 104 +++++-- .../openai/openai_responses_websocket_test.go | 178 ++++++++++- sdk/cliproxy/auth/conductor.go | 24 +- .../auth/conductor_stream_retry_test.go | 216 +++++++++++++ 39 files changed, 2624 insertions(+), 329 deletions(-) create mode 100644 internal/runtime/executor/antigravity_executor_client_test.go create mode 100644 internal/runtime/executor/codex_executor_account_id_test.go create mode 100644 sdk/api/handlers/openai/openai_responses_handlers_stream_test.go create mode 100644 sdk/api/handlers/openai/openai_responses_http_continuation_test.go create mode 100644 sdk/api/handlers/openai/openai_responses_turn_state.go create mode 100644 sdk/cliproxy/auth/conductor_stream_retry_test.go diff --git a/.gitignore b/.gitignore index 90ff3a941d..80f4b2eb62 100644 --- a/.gitignore +++ b/.gitignore @@ -49,3 +49,4 @@ _bmad-output/* # macOS .DS_Store ._* +.gocache/ diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 2fcf3b087a..2519e92375 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -482,31 +482,37 @@ func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H { if !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { return nil } - idTokenRaw, ok := auth.Metadata["id_token"].(string) - if !ok { - return nil - } - idToken := strings.TrimSpace(idTokenRaw) - if idToken == "" { - return nil - } - claims, err := codex.ParseJWTToken(idToken) - if err != nil || claims == nil { - return nil - } result := gin.H{} - if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" { - result["chatgpt_account_id"] = v - } - if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" { - result["plan_type"] = v + + // Step 1: unconditionally parse id_token as the baseline source. + // Subscription date fields only exist in id_token, so this must always run. + if idTokenRaw, ok := auth.Metadata["id_token"].(string); ok { + if idToken := strings.TrimSpace(idTokenRaw); idToken != "" { + if claims, err := codex.ParseJWTToken(idToken); err == nil && claims != nil { + if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID); v != "" { + result["chatgpt_account_id"] = v + } + if v := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); v != "" { + result["plan_type"] = v + } + if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil { + result["chatgpt_subscription_active_start"] = v + } + if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil { + result["chatgpt_subscription_active_until"] = v + } + } + } } - if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveStart; v != nil { - result["chatgpt_subscription_active_start"] = v + + // Step 2: override with explicit values from the JSON file (Metadata) if present. + // These take priority because the user may have set them directly in the imported file. + if v, ok := auth.Metadata["account_id"].(string); ok && strings.TrimSpace(v) != "" { + result["chatgpt_account_id"] = strings.TrimSpace(v) } - if v := claims.CodexAuthInfo.ChatgptSubscriptionActiveUntil; v != nil { - result["chatgpt_subscription_active_until"] = v + if v, ok := auth.Metadata["plan_type"].(string); ok && strings.TrimSpace(v) != "" { + result["plan_type"] = strings.TrimSpace(v) } if len(result) == 0 { @@ -2567,23 +2573,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage finalProjectID := projectID if responseProjectID != "" { if explicitProject && !strings.EqualFold(responseProjectID, projectID) { - // Check if this is a free user (gen-lang-client projects or free/legacy tier) - isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") || - strings.EqualFold(tierID, "FREE") || - strings.EqualFold(tierID, "LEGACY") - - if isFreeUser { - // For free users, use backend project ID for preview model access - log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID) - log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID) - finalProjectID = responseProjectID - } else { - // Pro users: keep requested project ID (original behavior) - log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID) - } - } else { - finalProjectID = responseProjectID + log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID) + log.Infof("Using backend project ID: %s", responseProjectID) } + finalProjectID = responseProjectID } storage.ProjectID = strings.TrimSpace(finalProjectID) diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index 7d7f7f5f28..e4e0f8a650 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -123,6 +123,10 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc return } + // Sanitize request body: remove thinking blocks with invalid signatures + // to prevent upstream API 400 errors + bodyBytes = SanitizeAmpRequestBody(bodyBytes) + // Restore the body for the handler to read c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) @@ -249,6 +253,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter.suppressThinking = true c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) @@ -259,10 +264,17 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) + // Wrap with ResponseRewriter for local providers too, because upstream + // proxies (e.g. NewAPI) may return a different model name and lack + // Amp-required fields like thinking.signature. + rewriter := NewResponseRewriter(c.Writer, modelName) + rewriter.suppressThinking = providerName != "claude" + c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) + rewriter.Flush() } else { // No provider, no mapping, no proxy: fall back to the wrapped handler so it can return an error response c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) diff --git a/internal/api/modules/amp/response_rewriter.go b/internal/api/modules/amp/response_rewriter.go index 715034f1ca..0de95cf0cb 100644 --- a/internal/api/modules/amp/response_rewriter.go +++ b/internal/api/modules/amp/response_rewriter.go @@ -2,6 +2,7 @@ package amp import ( "bytes" + "fmt" "net/http" "strings" @@ -12,34 +13,85 @@ import ( ) // ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body -// It's used to rewrite model names in responses when model mapping is used +// It is used to rewrite model names in responses when model mapping is used +// and to keep Amp-compatible response shapes. type ResponseRewriter struct { gin.ResponseWriter - body *bytes.Buffer - originalModel string - isStreaming bool + body *bytes.Buffer + originalModel string + isStreaming bool + suppressThinking bool } -// NewResponseRewriter creates a new response rewriter for model name substitution +// NewResponseRewriter creates a new response rewriter for model name substitution. func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter { return &ResponseRewriter{ - ResponseWriter: w, - body: &bytes.Buffer{}, - originalModel: originalModel, + ResponseWriter: w, + body: &bytes.Buffer{}, + originalModel: originalModel, } } -// Write intercepts response writes and buffers them for model name replacement +const maxBufferedResponseBytes = 2 * 1024 * 1024 // 2MB safety cap + +func looksLikeSSEChunk(data []byte) bool { + for _, line := range bytes.Split(data, []byte("\n")) { + trimmed := bytes.TrimSpace(line) + if bytes.HasPrefix(trimmed, []byte("data:")) || + bytes.HasPrefix(trimmed, []byte("event:")) { + return true + } + } + return false +} + +func (rw *ResponseRewriter) enableStreaming(reason string) error { + if rw.isStreaming { + return nil + } + rw.isStreaming = true + + if rw.body != nil && rw.body.Len() > 0 { + buf := rw.body.Bytes() + toFlush := make([]byte, len(buf)) + copy(toFlush, buf) + rw.body.Reset() + + if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(toFlush)); err != nil { + return err + } + if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { + flusher.Flush() + } + } + + log.Debugf("amp response rewriter: switched to streaming (%s)", reason) + return nil +} + func (rw *ResponseRewriter) Write(data []byte) (int, error) { - // Detect streaming on first write - if rw.body.Len() == 0 && !rw.isStreaming { + if !rw.isStreaming && rw.body.Len() == 0 { contentType := rw.Header().Get("Content-Type") rw.isStreaming = strings.Contains(contentType, "text/event-stream") || strings.Contains(contentType, "stream") } + if !rw.isStreaming { + if looksLikeSSEChunk(data) { + if err := rw.enableStreaming("sse heuristic"); err != nil { + return 0, err + } + } else if rw.body.Len()+len(data) > maxBufferedResponseBytes { + log.Warnf("amp response rewriter: buffer exceeded %d bytes, switching to streaming", maxBufferedResponseBytes) + if err := rw.enableStreaming("buffer limit"); err != nil { + return 0, err + } + } + } + if rw.isStreaming { - n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)) + rewritten := rw.rewriteStreamChunk(data) + n, err := rw.ResponseWriter.Write(rewritten) if err == nil { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { flusher.Flush() @@ -50,7 +102,6 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) { return rw.body.Write(data) } -// Flush writes the buffered response with model names rewritten func (rw *ResponseRewriter) Flush() { if rw.isStreaming { if flusher, ok := rw.ResponseWriter.(http.Flusher); ok { @@ -59,40 +110,80 @@ func (rw *ResponseRewriter) Flush() { return } if rw.body.Len() > 0 { - if _, err := rw.ResponseWriter.Write(rw.rewriteModelInResponse(rw.body.Bytes())); err != nil { + rewritten := rw.rewriteModelInResponse(rw.body.Bytes()) + // Update Content-Length to match the rewritten body size, since + // signature injection and model name changes alter the payload length. + rw.ResponseWriter.Header().Set("Content-Length", fmt.Sprintf("%d", len(rewritten))) + if _, err := rw.ResponseWriter.Write(rewritten); err != nil { log.Warnf("amp response rewriter: failed to write rewritten response: %v", err) } } } -// modelFieldPaths lists all JSON paths where model name may appear var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"} -// rewriteModelInResponse replaces all occurrences of the mapped model with the original model in JSON -// It also suppresses "thinking" blocks if "tool_use" is present to ensure Amp client compatibility -func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { - // 1. Amp Compatibility: Suppress thinking blocks if tool use is detected - // The Amp client struggles when both thinking and tool_use blocks are present +// ensureAmpSignature injects empty signature fields into tool_use/thinking blocks +// in API responses so that the Amp TUI does not crash on P.signature.length. +func ensureAmpSignature(data []byte) []byte { + for index, block := range gjson.GetBytes(data, "content").Array() { + blockType := block.Get("type").String() + if blockType != "tool_use" && blockType != "thinking" { + continue + } + signaturePath := fmt.Sprintf("content.%d.signature", index) + if gjson.GetBytes(data, signaturePath).Exists() { + continue + } + var err error + data, err = sjson.SetBytes(data, signaturePath, "") + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to add empty signature to %s block: %v", blockType, err) + break + } + } + + contentBlockType := gjson.GetBytes(data, "content_block.type").String() + if (contentBlockType == "tool_use" || contentBlockType == "thinking") && !gjson.GetBytes(data, "content_block.signature").Exists() { + var err error + data, err = sjson.SetBytes(data, "content_block.signature", "") + if err != nil { + log.Warnf("Amp ResponseRewriter: failed to add empty signature to streaming %s block: %v", contentBlockType, err) + } + } + + return data +} + + +func (rw *ResponseRewriter) suppressAmpThinking(data []byte) []byte { + if !rw.suppressThinking { + return data + } if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() { filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`) if filtered.Exists() { originalCount := gjson.GetBytes(data, "content.#").Int() filteredCount := filtered.Get("#").Int() - if originalCount > filteredCount { var err error data, err = sjson.SetBytes(data, "content", filtered.Value()) if err != nil { log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err) - } else { - log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount) - // Log the result for verification - log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String()) } } } } + return data +} + +func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { + data = ensureAmpSignature(data) + data = rw.suppressAmpThinking(data) + if len(data) == 0 { + return data + } + if rw.originalModel == "" { return data } @@ -104,24 +195,151 @@ func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte { return data } -// rewriteStreamChunk rewrites model names in SSE stream chunks func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte { - if rw.originalModel == "" { - return chunk - } - - // SSE format: "data: {json}\n\n" lines := bytes.Split(chunk, []byte("\n")) - for i, line := range lines { - if bytes.HasPrefix(line, []byte("data: ")) { - jsonData := bytes.TrimPrefix(line, []byte("data: ")) + var out [][]byte + + i := 0 + for i < len(lines) { + line := lines[i] + trimmed := bytes.TrimSpace(line) + + // Case 1: "event:" line - look ahead for its "data:" line + if bytes.HasPrefix(trimmed, []byte("event: ")) { + // Scan forward past blank lines to find the data: line + dataIdx := -1 + for j := i + 1; j < len(lines); j++ { + t := bytes.TrimSpace(lines[j]) + if len(t) == 0 { + continue + } + if bytes.HasPrefix(t, []byte("data: ")) { + dataIdx = j + } + break + } + + if dataIdx >= 0 { + // Found event+data pair - process through rewriter + jsonData := bytes.TrimPrefix(bytes.TrimSpace(lines[dataIdx]), []byte("data: ")) + if len(jsonData) > 0 && jsonData[0] == '{' { + rewritten := rw.rewriteStreamEvent(jsonData) + if rewritten == nil { + i = dataIdx + 1 + continue + } + // Emit event line + out = append(out, line) + // Emit blank lines between event and data + for k := i + 1; k < dataIdx; k++ { + out = append(out, lines[k]) + } + // Emit rewritten data + out = append(out, append([]byte("data: "), rewritten...)) + i = dataIdx + 1 + continue + } + } + + // No data line found (orphan event from cross-chunk split) + // Pass it through as-is - the data will arrive in the next chunk + out = append(out, line) + i++ + continue + } + + // Case 2: standalone "data:" line (no preceding event: in this chunk) + if bytes.HasPrefix(trimmed, []byte("data: ")) { + jsonData := bytes.TrimPrefix(trimmed, []byte("data: ")) if len(jsonData) > 0 && jsonData[0] == '{' { - // Rewrite JSON in the data line - rewritten := rw.rewriteModelInResponse(jsonData) - lines[i] = append([]byte("data: "), rewritten...) + rewritten := rw.rewriteStreamEvent(jsonData) + if rewritten != nil { + out = append(out, append([]byte("data: "), rewritten...)) + } + i++ + continue } } + + // Case 3: everything else + out = append(out, line) + i++ } - return bytes.Join(lines, []byte("\n")) + return bytes.Join(out, []byte("\n")) +} + +// rewriteStreamEvent processes a single JSON event in the SSE stream. +// It rewrites model names and ensures signature fields exist. +func (rw *ResponseRewriter) rewriteStreamEvent(data []byte) []byte { + // Inject empty signature where needed + data = ensureAmpSignature(data) + + // Rewrite model name + if rw.originalModel != "" { + for _, path := range modelFieldPaths { + if gjson.GetBytes(data, path).Exists() { + data, _ = sjson.SetBytes(data, path, rw.originalModel) + } + } + } + + return data +} + +// SanitizeAmpRequestBody removes thinking blocks with empty/missing/invalid signatures +// from the messages array in a request body before forwarding to the upstream API. +// This prevents 400 errors from the API which requires valid signatures on thinking blocks. +func SanitizeAmpRequestBody(body []byte) []byte { + messages := gjson.GetBytes(body, "messages") + if !messages.Exists() || !messages.IsArray() { + return body + } + + modified := false + for msgIdx, msg := range messages.Array() { + if msg.Get("role").String() != "assistant" { + continue + } + content := msg.Get("content") + if !content.Exists() || !content.IsArray() { + continue + } + + var keepBlocks []interface{} + removedCount := 0 + + for _, block := range content.Array() { + blockType := block.Get("type").String() + if blockType == "thinking" { + sig := block.Get("signature") + if !sig.Exists() || sig.Type != gjson.String || strings.TrimSpace(sig.String()) == "" { + removedCount++ + continue + } + } + keepBlocks = append(keepBlocks, block.Value()) + } + + if removedCount > 0 { + contentPath := fmt.Sprintf("messages.%d.content", msgIdx) + var err error + if len(keepBlocks) == 0 { + body, err = sjson.SetBytes(body, contentPath, []interface{}{}) + } else { + body, err = sjson.SetBytes(body, contentPath, keepBlocks) + } + if err != nil { + log.Warnf("Amp RequestSanitizer: failed to remove thinking blocks from message %d: %v", msgIdx, err) + continue + } + modified = true + log.Debugf("Amp RequestSanitizer: removed %d thinking blocks with invalid signatures from message %d", removedCount, msgIdx) + } + } + + if modified { + log.Debugf("Amp RequestSanitizer: sanitized request body") + } + return body } diff --git a/internal/api/modules/amp/response_rewriter_test.go b/internal/api/modules/amp/response_rewriter_test.go index 114a9516fc..88d673feb7 100644 --- a/internal/api/modules/amp/response_rewriter_test.go +++ b/internal/api/modules/amp/response_rewriter_test.go @@ -100,6 +100,40 @@ func TestRewriteStreamChunk_MessageModel(t *testing.T) { } } +func TestRewriteStreamChunk_PassesThroughThinkingBlocks(t *testing.T) { + rw := &ResponseRewriter{} + + chunk := []byte("event: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":0,\"content_block\":{\"type\":\"thinking\",\"thinking\":\"\"}}\n\nevent: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"thinking_delta\",\"thinking\":\"abc\"}}\n\nevent: content_block_stop\ndata: {\"type\":\"content_block_stop\",\"index\":0}\n\nevent: content_block_start\ndata: {\"type\":\"content_block_start\",\"index\":1,\"content_block\":{\"type\":\"tool_use\",\"name\":\"bash\",\"input\":{}}}\n\n") + result := rw.rewriteStreamChunk(chunk) + + if !contains(result, []byte("\"thinking_delta\"")) { + t.Fatalf("expected thinking blocks to pass through in streaming, got %s", string(result)) + } + if !contains(result, []byte("\"tool_use\"")) { + t.Fatalf("expected tool_use content_block frame to remain, got %s", string(result)) + } + if !contains(result, []byte("\"signature\":\"\"")) { + t.Fatalf("expected tool_use content_block signature injection, got %s", string(result)) + } +} + +func TestSanitizeAmpRequestBody_RemovesWhitespaceAndNonStringSignatures(t *testing.T) { + input := []byte(`{"messages":[{"role":"assistant","content":[{"type":"thinking","thinking":"drop-whitespace","signature":" "},{"type":"thinking","thinking":"drop-number","signature":123},{"type":"thinking","thinking":"keep-valid","signature":"valid-signature"},{"type":"text","text":"keep-text"}]}]}`) + result := SanitizeAmpRequestBody(input) + + if contains(result, []byte("drop-whitespace")) { + t.Fatalf("expected whitespace-only signature block to be removed, got %s", string(result)) + } + if contains(result, []byte("drop-number")) { + t.Fatalf("expected non-string signature block to be removed, got %s", string(result)) + } + if !contains(result, []byte("keep-valid")) { + t.Fatalf("expected valid thinking block to remain, got %s", string(result)) + } + if !contains(result, []byte("keep-text")) { + t.Fatalf("expected non-thinking content to remain, got %s", string(result)) + } +} func contains(data, substr []byte) bool { for i := 0; i <= len(data)-len(substr); i++ { if string(data[i:i+len(substr)]) == string(substr) { diff --git a/internal/auth/codex/jwt_parser.go b/internal/auth/codex/jwt_parser.go index 130e86420a..db49781894 100644 --- a/internal/auth/codex/jwt_parser.go +++ b/internal/auth/codex/jwt_parser.go @@ -100,3 +100,12 @@ func (c *JWTClaims) GetUserEmail() string { func (c *JWTClaims) GetAccountID() string { return c.CodexAuthInfo.ChatgptAccountID } + +// GetClientID returns the first audience value from the JWT claims, which represents +// the OAuth client_id used during token issuance. +func (c *JWTClaims) GetClientID() string { + if len(c.Aud) > 0 { + return c.Aud[0] + } + return "" +} diff --git a/internal/auth/codex/openai_auth.go b/internal/auth/codex/openai_auth.go index 001155c77a..444a4b0a25 100644 --- a/internal/auth/codex/openai_auth.go +++ b/internal/auth/codex/openai_auth.go @@ -194,15 +194,17 @@ func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, } // RefreshTokens refreshes an access token using a refresh token. -// This method is called when an access token has expired. It makes a request to the -// token endpoint to obtain a new set of tokens. -func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*CodexTokenData, error) { +// clientID overrides the default hardcoded ClientID when non-empty. +func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken, clientID string) (*CodexTokenData, error) { if refreshToken == "" { return nil, fmt.Errorf("refresh token is required") } + if clientID == "" { + clientID = ClientID + } data := url.Values{ - "client_id": {ClientID}, + "client_id": {clientID}, "grant_type": {"refresh_token"}, "refresh_token": {refreshToken}, "scope": {"openid profile email"}, @@ -285,9 +287,8 @@ func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStora } // RefreshTokensWithRetry refreshes tokens with a built-in retry mechanism. -// It attempts to refresh the tokens up to a specified maximum number of retries, -// with an exponential backoff strategy to handle transient network errors. -func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*CodexTokenData, error) { +// clientID overrides the default hardcoded ClientID when non-empty. +func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken, clientID string, maxRetries int) (*CodexTokenData, error) { var lastErr error for attempt := 0; attempt < maxRetries; attempt++ { @@ -300,7 +301,7 @@ func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken str } } - tokenData, err := o.RefreshTokens(ctx, refreshToken) + tokenData, err := o.RefreshTokens(ctx, refreshToken, clientID) if err == nil { return tokenData, nil } diff --git a/internal/auth/codex/openai_auth_test.go b/internal/auth/codex/openai_auth_test.go index b46c857ebf..52a814b8a7 100644 --- a/internal/auth/codex/openai_auth_test.go +++ b/internal/auth/codex/openai_auth_test.go @@ -42,7 +42,7 @@ func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) { }, } - _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3) + _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", "app_EMoamEEZ73f0CkXaXp7hrann", 3) if err == nil { t.Fatalf("expected error for non-retryable refresh failure") } @@ -73,7 +73,7 @@ func TestRefreshTokensWithRetry_UnauthorizedOnlyAttemptsOnce(t *testing.T) { }, } - _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3) + _, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", "", 3) if err == nil { t.Fatalf("expected error for unauthorized refresh failure") } diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 16af718ebb..22404dac9c 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -333,42 +333,10 @@ func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage finalProjectID := projectID if responseProjectID != "" { if explicitProject && !strings.EqualFold(responseProjectID, projectID) { - // Check if this is a free user (gen-lang-client projects or free/legacy tier) - isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") || - strings.EqualFold(tierID, "FREE") || - strings.EqualFold(tierID, "LEGACY") - - if isFreeUser { - // Interactive prompt for free users - fmt.Printf("\nGoogle returned a different project ID:\n") - fmt.Printf(" Requested (frontend): %s\n", projectID) - fmt.Printf(" Returned (backend): %s\n\n", responseProjectID) - fmt.Printf(" Backend project IDs have access to preview models (gemini-3-*).\n") - fmt.Printf(" This is normal for free tier users.\n\n") - fmt.Printf("Which project ID would you like to use?\n") - fmt.Printf(" [1] Backend (recommended): %s\n", responseProjectID) - fmt.Printf(" [2] Frontend: %s\n\n", projectID) - fmt.Printf("Enter choice [1]: ") - - reader := bufio.NewReader(os.Stdin) - choice, _ := reader.ReadString('\n') - choice = strings.TrimSpace(choice) - - if choice == "2" { - log.Infof("Using frontend project ID: %s", projectID) - fmt.Println(". Warning: Frontend project IDs may not have access to preview models.") - finalProjectID = projectID - } else { - log.Infof("Using backend project ID: %s (recommended)", responseProjectID) - finalProjectID = responseProjectID - } - } else { - // Pro users: keep requested project ID (original behavior) - log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID) - } - } else { - finalProjectID = responseProjectID + log.Infof("Gemini onboarding: requested project %s maps to backend project %s", projectID, responseProjectID) + log.Infof("Using backend project ID: %s", responseProjectID) } + finalProjectID = responseProjectID } storage.ProjectID = strings.TrimSpace(finalProjectID) diff --git a/internal/misc/header_utils.go b/internal/misc/header_utils.go index 5752a26956..ac022a9627 100644 --- a/internal/misc/header_utils.go +++ b/internal/misc/header_utils.go @@ -12,7 +12,7 @@ import ( const ( // GeminiCLIVersion is the version string reported in the User-Agent for upstream requests. - GeminiCLIVersion = "0.31.0" + GeminiCLIVersion = "0.34.0" // GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream. GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0" @@ -46,7 +46,7 @@ func GeminiCLIUserAgent(model string) string { if model == "" { model = "unknown" } - return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) + return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s; terminal)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch()) } // ScrubProxyAndFingerprintHeaders removes all headers that could reveal diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index b1e23860cf..efdf74f044 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -117,6 +117,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) translatedReq, body, err := e.translateRequest(req, opts, false) if err != nil { @@ -176,6 +177,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) translatedReq, body, err := e.translateRequest(req, opts, true) if err != nil { diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index cda02d2cea..0947acdac6 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -75,8 +75,11 @@ func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor { // It is initialized once via antigravityTransportOnce to avoid leaking a new connection pool // (and the goroutines managing it) on every request. var ( - antigravityTransport *http.Transport - antigravityTransportOnce sync.Once + antigravityTransport *http.Transport + antigravityTransportOnce sync.Once + antigravityEnvironmentProxyTransport = sync.OnceValue(func() *http.Transport { + return cloneTransportWithHTTP11(newEnvironmentProxyTransport()) + }) ) func cloneTransportWithHTTP11(base *http.Transport) *http.Transport { @@ -122,6 +125,10 @@ func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cli // Preserve proxy settings from proxy-aware transports while forcing HTTP/1.1. if transport, ok := client.Transport.(*http.Transport); ok { + if transport == newEnvironmentProxyTransport() { + client.Transport = antigravityEnvironmentProxyTransport() + return client + } client.Transport = cloneTransportWithHTTP11(transport) } return client @@ -205,6 +212,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") @@ -347,6 +355,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") @@ -739,6 +748,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") diff --git a/internal/runtime/executor/antigravity_executor_client_test.go b/internal/runtime/executor/antigravity_executor_client_test.go new file mode 100644 index 0000000000..2af21fc1f2 --- /dev/null +++ b/internal/runtime/executor/antigravity_executor_client_test.go @@ -0,0 +1,36 @@ +package executor + +import ( + "context" + "net/http" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +func TestNewAntigravityHTTPClientReusesSharedEnvironmentProxyTransport(t *testing.T) { + setEnvironmentProxy(t, "http://env-proxy.example.com:8080") + + clientA := newAntigravityHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + clientB := newAntigravityHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + + transportA, okA := clientA.Transport.(*http.Transport) + if !okA { + t.Fatalf("clientA transport type = %T, want *http.Transport", clientA.Transport) + } + transportB, okB := clientB.Transport.(*http.Transport) + if !okB { + t.Fatalf("clientB transport type = %T, want *http.Transport", clientB.Transport) + } + + if transportA != transportB { + t.Fatal("expected Antigravity environment proxy transport to be shared across clients") + } + if transportA == newEnvironmentProxyTransport() { + t.Fatal("expected Antigravity transport to use its HTTP/1.1 clone, not the generic environment proxy transport") + } + if transportA.ForceAttemptHTTP2 { + t.Fatal("expected Antigravity transport to keep HTTP/2 disabled") + } +} diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 8e356f74d3..8e64d79a2f 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -103,6 +103,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. @@ -271,6 +272,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("claude") originalPayloadSource := req.Payload diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 56382b4489..10c510bab0 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -93,6 +93,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanExecute) @@ -178,6 +179,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanCompact) @@ -265,6 +267,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanExecuteStream) @@ -530,17 +533,29 @@ func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (* if auth == nil { return nil, statusErr{code: 500, msg: "codex executor: auth is nil"} } - var refreshToken string + var refreshToken, clientID string if auth.Metadata != nil { if v, ok := auth.Metadata["refresh_token"].(string); ok && v != "" { refreshToken = v } + // Prefer explicit client_id stored in metadata + if v, ok := auth.Metadata["client_id"].(string); ok && v != "" { + clientID = v + } } if refreshToken == "" { return auth, nil } + // Fall back to parsing client_id from id_token.aud[0] + if clientID == "" { + if idTokenRaw, ok := auth.Metadata["id_token"].(string); ok && idTokenRaw != "" { + if claims, err := codexauth.ParseJWTToken(idTokenRaw); err == nil && claims != nil { + clientID = claims.GetClientID() + } + } + } svc := codexauth.NewCodexAuth(e.cfg) - td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3) + td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, clientID, 3) if err != nil { return nil, err } diff --git a/internal/runtime/executor/codex_executor_account_id_test.go b/internal/runtime/executor/codex_executor_account_id_test.go new file mode 100644 index 0000000000..fb01187bb4 --- /dev/null +++ b/internal/runtime/executor/codex_executor_account_id_test.go @@ -0,0 +1,198 @@ +package executor + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/url" + "os" + "strings" + "testing" + + "github.com/google/uuid" + tls "github.com/refraction-networking/utls" + "golang.org/x/net/http2" +) + +// utlsTransport is a minimal Chrome-fingerprint TLS transport for test use. +// Supports HTTP CONNECT proxy tunneling. +type utlsTransport struct { + proxyURL string +} + +func newUtlsTransport(proxyURL string) *utlsTransport { + return &utlsTransport{proxyURL: proxyURL} +} + +func (t *utlsTransport) dial(addr string) (net.Conn, error) { + if t.proxyURL == "" { + return net.Dial("tcp", addr) + } + u, err := url.Parse(t.proxyURL) + if err != nil { + return nil, fmt.Errorf("parse proxy url: %w", err) + } + conn, err := net.Dial("tcp", u.Host) + if err != nil { + return nil, fmt.Errorf("connect to proxy: %w", err) + } + // HTTP CONNECT tunnel + req, _ := http.NewRequest(http.MethodConnect, "http://"+addr, nil) + req.Host = addr + if err = req.Write(conn); err != nil { + conn.Close() + return nil, fmt.Errorf("write CONNECT: %w", err) + } + resp, err := http.ReadResponse(bufio.NewReader(conn), req) + if err != nil { + conn.Close() + return nil, fmt.Errorf("read CONNECT response: %w", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusOK { + conn.Close() + return nil, fmt.Errorf("proxy CONNECT failed: %s", resp.Status) + } + return conn, nil +} + +func (t *utlsTransport) RoundTrip(req *http.Request) (*http.Response, error) { + host := req.URL.Hostname() + addr := host + ":443" + + conn, err := t.dial(addr) + if err != nil { + return nil, fmt.Errorf("dial: %w", err) + } + + tlsConn := tls.UClient(conn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto) + if err = tlsConn.Handshake(); err != nil { + conn.Close() + return nil, fmt.Errorf("tls handshake: %w", err) + } + + tr := &http2.Transport{} + h2Conn, err := tr.NewClientConn(tlsConn) + if err != nil { + tlsConn.Close() + return nil, fmt.Errorf("h2 conn: %w", err) + } + + return h2Conn.RoundTrip(req) +} + +// planTypePriority returns a numeric priority for a plan_type string. +// Higher value means higher priority: team > plus > free > others. +func planTypePriority(planType string) int { + switch strings.ToLower(planType) { + case "team": + return 3 + case "plus": + return 2 + case "free": + return 1 + default: + return 0 + } +} + +// pickBestAccountID selects the best account_id from the $.accounts map returned by +// the accounts/check API. Priority: team > plus > free > any other. +// Returns empty string if no accounts are found. +func pickBestAccountID(accounts map[string]any) string { + bestID := "" + bestPriority := -1 + for accountID, v := range accounts { + info, ok := v.(map[string]any) + if !ok { + continue + } + account, ok := info["account"].(map[string]any) + if !ok { + continue + } + planType, _ := account["plan_type"].(string) + p := planTypePriority(planType) + if p > bestPriority { + bestPriority = p + bestID = accountID + } + } + return bestID +} + +// TestCodexAccountCheck tests GET https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27 +// using a real access_token. Set CODEX_ACCESS_TOKEN (and optionally CODEX_PROXY_URL) to run. +// +// Example: +// +// CODEX_ACCESS_TOKEN=eyJ... go test ./internal/runtime/executor/... -run TestCodexAccountCheck -v +// CODEX_ACCESS_TOKEN=eyJ... CODEX_PROXY_URL=http://127.0.0.1:7890 go test ./internal/runtime/executor/... -run TestCodexAccountCheck -v +func TestCodexAccountCheck(t *testing.T) { + accessToken := os.Getenv("CODEX_ACCESS_TOKEN") + if accessToken == "" { + t.Skip("skipping: CODEX_ACCESS_TOKEN not set") + } + proxyURL := os.Getenv("CODEX_PROXY_URL") + deviceID := uuid.NewString() + targetURL := "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27?timezone_offset_min=-480" + + req, err := http.NewRequest(http.MethodGet, targetURL, nil) + if err != nil { + t.Fatalf("build request: %v", err) + } + + req.Header.Set("accept", "*/*") + req.Header.Set("accept-language", "zh-HK,zh;q=0.9,en-US;q=0.8,en;q=0.7") + req.Header.Set("authorization", "Bearer "+strings.TrimSpace(accessToken)) + req.Header.Set("oai-device-id", deviceID) + req.Header.Set("oai-language", "zh-HK") + req.Header.Set("referer", "https://chatgpt.com/") + req.Header.Set("user-agent", "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/141.0.0.0 Safari/537.36") + req.Header.Set("sec-ch-ua", `"Google Chrome";v="141", "Not?A_Brand";v="8", "Chromium";v="141"`) + req.Header.Set("sec-ch-ua-mobile", "?0") + req.Header.Set("sec-ch-ua-platform", `"macOS"`) + req.Header.Set("sec-fetch-dest", "empty") + req.Header.Set("sec-fetch-mode", "cors") + req.Header.Set("sec-fetch-site", "same-origin") + req.Header.Set("priority", "u=1, i") + + client := &http.Client{ + Transport: newUtlsTransport(proxyURL), + } + + resp, err := client.Do(req) + if err != nil { + t.Fatalf("request failed: %v", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + t.Fatalf("read response: %v", err) + } + + t.Logf("status: %d", resp.StatusCode) + t.Logf("device_id: %s", deviceID) + t.Logf("response: %s", string(body)) + + if resp.StatusCode != http.StatusOK { + t.Errorf("expected 200, got %d", resp.StatusCode) + return + } + + // Parse response and pick the best account_id + var parsed map[string]any + if err = json.Unmarshal(body, &parsed); err != nil { + t.Fatalf("unmarshal response: %v", err) + } + if accounts, ok := parsed["accounts"].(map[string]any); ok { + bestID := pickBestAccountID(accounts) + t.Logf("best_account_id (team>plus>free): %s", bestID) + } else { + t.Logf("no $.accounts map found in response") + } +} diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 6e8006e08a..5091b7292c 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -82,10 +82,40 @@ type codexWebsocketRead struct { err error } +func trySendCodexWebsocketRead(ch chan codexWebsocketRead, done <-chan struct{}, ev codexWebsocketRead) { + if ch == nil { + return + } + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover trySendCodexWebsocketRead panic=%v", r) + } + }() + select { + case ch <- ev: + case <-done: + default: + } +} + +func tryCloseCodexWebsocketRead(ch chan codexWebsocketRead) { + if ch == nil { + return + } + defer func() { + if r := recover(); r != nil { + log.Debugf("codex websockets executor: recover tryCloseCodexWebsocketRead panic=%v", r) + } + }() + close(ch) +} + func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCancel != nil { s.activeCancel() @@ -105,6 +135,8 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { if s == nil { return } + // 该方法仅持有 activeMu 调用避免与 connMu->activeMu 锁序冲突 + // 不要在持有 connMu 时调用避免未来引入反向锁序 s.activeMu.Lock() if s.activeCh == ch { s.activeCh = nil @@ -117,6 +149,61 @@ func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { s.activeMu.Unlock() } +func (s *codexWebsocketSession) isCurrentConn(conn *websocket.Conn) bool { + if s == nil || conn == nil { + return false + } + s.connMu.Lock() + current := s.conn + s.connMu.Unlock() + return current == conn +} + +func (s *codexWebsocketSession) activeSnapshotForCurrentConn(conn *websocket.Conn) (chan codexWebsocketRead, <-chan struct{}, bool) { + if s == nil || conn == nil { + return nil, nil, false + } + // 锁顺序固定为 connMu -> activeMu + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return nil, nil, false + } + s.activeMu.Lock() + ch := s.activeCh + done := s.activeDone + s.activeMu.Unlock() + s.connMu.Unlock() + return ch, done, true +} + +func (s *codexWebsocketSession) clearActiveForCurrentConn(conn *websocket.Conn, ch chan codexWebsocketRead) bool { + if s == nil || conn == nil || ch == nil { + return false + } + // 锁顺序固定为 connMu -> activeMu + s.connMu.Lock() + if s.conn != conn { + s.connMu.Unlock() + return false + } + s.activeMu.Lock() + if s.activeCh != ch { + s.activeMu.Unlock() + s.connMu.Unlock() + return false + } + s.activeCh = nil + if s.activeCancel != nil { + s.activeCancel() + } + s.activeCancel = nil + s.activeDone = nil + s.activeMu.Unlock() + s.connMu.Unlock() + return true +} + func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { if s == nil { return fmt.Errorf("codex websockets executor: session is nil") @@ -157,6 +244,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("codex") @@ -366,6 +454,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("codex") @@ -1076,6 +1165,7 @@ func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWeb } func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + authID = strings.TrimSpace(authID) if sess == nil { return e.dialCodexWebsocket(ctx, auth, wsURL, headers) } @@ -1083,8 +1173,16 @@ func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth * sess.connMu.Lock() conn := sess.conn readerConn := sess.readerConn + currentAuthID := strings.TrimSpace(sess.authID) sess.connMu.Unlock() + if conn != nil && currentAuthID != authID { + // 账号切换时先断开旧连接避免继续复用旧账号 + e.invalidateUpstreamConn(sess, conn, "auth_switched", nil) + conn = nil + readerConn = nil + } if conn != nil { + // 账号未变化时复用连接减少不必要重连 if readerConn != conn { sess.connMu.Lock() sess.readerConn = conn @@ -1126,21 +1224,24 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, return } for { + if !sess.isCurrentConn(conn) { + // 旧连接读循环直接退出避免误伤新请求通道 + return + } _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) msgType, payload, errRead := conn.ReadMessage() if errRead != nil { - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接读错时不触碰当前活跃通道 + return + } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errRead}: - case <-done: - default: + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errRead}) + if sess.clearActiveForCurrentConn(conn, ch) { + tryCloseCodexWebsocketRead(ch) } - sess.clearActive(ch) - close(ch) } e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) return @@ -1149,29 +1250,29 @@ func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, if msgType != websocket.TextMessage { if msgType == websocket.BinaryMessage { errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接二进制异常时不触碰当前活跃通道 + return + } if ch != nil { - select { - case ch <- codexWebsocketRead{conn: conn, err: errBinary}: - case <-done: - default: + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: errBinary}) + if sess.clearActiveForCurrentConn(conn, ch) { + tryCloseCodexWebsocketRead(ch) } - sess.clearActive(ch) - close(ch) } e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) return } continue } - - sess.activeMu.Lock() - ch := sess.activeCh - done := sess.activeDone - sess.activeMu.Unlock() + // 在同一临界区做归属校验和通道快照避免检查后竞态 + ch, done, current := sess.activeSnapshotForCurrentConn(conn) + if !current { + // 旧连接消息不再分发给新连接请求 + return + } if ch == nil { continue } @@ -1258,17 +1359,34 @@ func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSess reason = "session_closed" } + // 锁顺序固定为 connMu -> activeMu sess.connMu.Lock() conn := sess.conn authID := sess.authID wsURL := sess.wsURL + sessionID := sess.sessionID sess.conn = nil if sess.readerConn == conn { sess.readerConn = nil } - sessionID := sess.sessionID + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + if sess.activeCancel != nil { + sess.activeCancel() + } + sess.activeCh = nil + sess.activeCancel = nil + sess.activeDone = nil + sess.activeMu.Unlock() sess.connMu.Unlock() + if ch != nil { + // 会话关闭时允许主动 fail active 唤醒在途 readCodexWebsocketMessage + trySendCodexWebsocketRead(ch, done, codexWebsocketRead{conn: conn, err: fmt.Errorf("codex websockets executor: execution session closed")}) + tryCloseCodexWebsocketRead(ch) + } + if conn == nil { return } diff --git a/internal/runtime/executor/codex_websockets_executor_test.go b/internal/runtime/executor/codex_websockets_executor_test.go index e0330d7492..1319093f99 100644 --- a/internal/runtime/executor/codex_websockets_executor_test.go +++ b/internal/runtime/executor/codex_websockets_executor_test.go @@ -4,6 +4,8 @@ import ( "context" "net/http" "net/http/httptest" + "strings" + "sync" "testing" "time" @@ -438,3 +440,248 @@ func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) { t.Fatal("expected websocket proxy function to be nil for direct mode") } } + +func TestReadCodexWebsocketMessageReturnsWhenReadChannelClosed(t *testing.T) { + t.Parallel() + + sess := &codexWebsocketSession{} + conn := &websocket.Conn{} + readCh := make(chan codexWebsocketRead) + close(readCh) + + _, _, err := readCodexWebsocketMessage(context.Background(), sess, conn, readCh) + if err == nil { + t.Fatal("expected error when session read channel is closed") + } + if !strings.Contains(err.Error(), "session read channel closed") { + t.Fatalf("error = %v, want contains session read channel closed", err) + } +} + +func TestEnsureUpstreamConnReconnectsWhenAuthChanges(t *testing.T) { + var ( + mu sync.Mutex + authorizations []string + ) + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + mu.Lock() + authorizations = append(authorizations, strings.TrimSpace(r.Header.Get("Authorization"))) + mu.Unlock() + + go func() { + defer func() { + _ = conn.Close() + }() + for { + if _, _, errRead := conn.ReadMessage(); errRead != nil { + return + } + } + }() + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + executor := NewCodexWebsocketsExecutor(&config.Config{}) + sess := executor.getOrCreateSession("test-session") + if sess == nil { + t.Fatal("expected session to be created") + } + + auth1 := &cliproxyauth.Auth{ID: "auth-1"} + headers1 := http.Header{} + headers1.Set("Authorization", "Bearer token-1") + conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), auth1, sess, auth1.ID, wsURL, headers1) + if errDial1 != nil { + t.Fatalf("first ensureUpstreamConn failed: %v", errDial1) + } + if conn1 == nil { + t.Fatal("first ensureUpstreamConn returned nil connection") + } + + auth2 := &cliproxyauth.Auth{ID: "auth-2"} + headers2 := http.Header{} + headers2.Set("Authorization", "Bearer token-2") + conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), auth2, sess, auth2.ID, wsURL, headers2) + if errDial2 != nil { + t.Fatalf("second ensureUpstreamConn failed: %v", errDial2) + } + if conn2 == nil { + t.Fatal("second ensureUpstreamConn returned nil connection") + } + if conn1 == conn2 { + t.Fatal("expected auth change to force upstream reconnect") + } + + deadline := time.Now().Add(2 * time.Second) + for { + mu.Lock() + count := len(authorizations) + mu.Unlock() + if count >= 2 || time.Now().After(deadline) { + break + } + time.Sleep(10 * time.Millisecond) + } + + mu.Lock() + got := append([]string(nil), authorizations...) + mu.Unlock() + if len(got) < 2 { + t.Fatalf("handshake count = %d, want at least 2", len(got)) + } + if got[0] != "Bearer token-1" { + t.Fatalf("first Authorization = %q, want %q", got[0], "Bearer token-1") + } + if got[1] != "Bearer token-2" { + t.Fatalf("second Authorization = %q, want %q", got[1], "Bearer token-2") + } + + executor.closeExecutionSession(sess, "test_done") +} + +func TestCloseExecutionSessionUnblocksActiveRead(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + serverConnCh := make(chan *websocket.Conn, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + serverConnCh <- conn + _, _, _ = conn.ReadMessage() + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + clientConn, _, errDial := websocket.DefaultDialer.Dial(wsURL, nil) + if errDial != nil { + t.Fatalf("dial websocket: %v", errDial) + } + defer func() { _ = clientConn.Close() }() + + var serverConn *websocket.Conn + select { + case serverConn = <-serverConnCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for server websocket connection") + } + + sess := &codexWebsocketSession{ + sessionID: "session-close", + conn: serverConn, + readerConn: serverConn, + } + readCh := make(chan codexWebsocketRead, 4) + sess.setActive(readCh) + + executor := &CodexWebsocketsExecutor{ + CodexExecutor: &CodexExecutor{}, + sessions: map[string]*codexWebsocketSession{ + "session-close": sess, + }, + } + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + readErrCh := make(chan error, 1) + go func() { + _, _, err := readCodexWebsocketMessage(ctx, sess, serverConn, readCh) + readErrCh <- err + }() + + executor.CloseExecutionSession("session-close") + + select { + case err := <-readErrCh: + if err == nil { + t.Fatal("expected read error after closing execution session") + } + errText := err.Error() + if !strings.Contains(errText, "execution session closed") && !strings.Contains(errText, "session read channel closed") { + t.Fatalf("error = %v, want fast-fail error from session close path", err) + } + case <-time.After(3 * time.Second): + t.Fatal("read did not fail fast after closeExecutionSession") + } +} + +func TestEnsureUpstreamConnAuthSwitchRebuildsWebsocketConn(t *testing.T) { + t.Parallel() + + upgrader := websocket.Upgrader{CheckOrigin: func(*http.Request) bool { return true }} + authHeaderCh := make(chan string, 4) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := upgrader.Upgrade(w, r, nil) + if err != nil { + return + } + defer func() { _ = conn.Close() }() + + authHeaderCh <- strings.TrimSpace(r.Header.Get("Authorization")) + for { + _, _, errRead := conn.ReadMessage() + if errRead != nil { + return + } + } + })) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + executor := NewCodexWebsocketsExecutor(&config.Config{}) + sess := &codexWebsocketSession{sessionID: "session-auth-switch"} + + headers1 := http.Header{} + headers1.Set("Authorization", "Bearer token-1") + conn1, _, errDial1 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-1", wsURL, headers1) + if errDial1 != nil { + t.Fatalf("ensureUpstreamConn auth-1 error: %v", errDial1) + } + if conn1 == nil { + t.Fatal("ensureUpstreamConn auth-1 returned nil conn") + } + + headers2 := http.Header{} + headers2.Set("Authorization", "Bearer token-2") + conn2, _, errDial2 := executor.ensureUpstreamConn(context.Background(), nil, sess, "auth-2", wsURL, headers2) + if errDial2 != nil { + t.Fatalf("ensureUpstreamConn auth-2 error: %v", errDial2) + } + if conn2 == nil { + t.Fatal("ensureUpstreamConn auth-2 returned nil conn") + } + if conn2 == conn1 { + t.Fatal("expected new websocket conn after auth switch") + } + + defer executor.invalidateUpstreamConn(sess, conn2, "test_done", nil) + + var got1, got2 string + select { + case got1 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for first websocket handshake") + } + select { + case got2 = <-authHeaderCh: + case <-time.After(2 * time.Second): + t.Fatal("timed out waiting for second websocket handshake") + } + if got1 != "Bearer token-1" { + t.Fatalf("first Authorization = %q, want %q", got1, "Bearer token-1") + } + if got2 != "Bearer token-2" { + t.Fatalf("second Authorization = %q, want %q", got2, "Bearer token-2") + } + if got1 == got2 { + t.Fatal("expected different Authorization headers after auth switch") + } +} diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 1be245b702..5e6547ce24 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -114,6 +114,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") @@ -268,6 +269,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 7c25b8935f..dff863e0e7 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -112,6 +112,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat @@ -220,6 +221,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini") diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 84df56f995..84d29357d0 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -303,6 +303,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) var body []byte @@ -429,6 +430,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini") @@ -534,6 +536,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini") @@ -658,6 +661,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("gemini") diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 65a0b8f81e..876fb0a232 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -88,6 +88,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("openai") @@ -191,6 +192,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("openai") diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index d5e3702f48..ab34221c36 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -78,6 +78,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) to := sdktranslator.FromString("openai") originalPayloadSource := req.Payload @@ -178,6 +179,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) to := sdktranslator.FromString("openai") originalPayloadSource := req.Payload diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 623c66206a..a0d3037065 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -74,6 +74,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) baseURL, apiKey := e.resolveCredentials(auth) if baseURL == "" { @@ -181,6 +182,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) baseURL, apiKey := e.resolveCredentials(auth) if baseURL == "" { diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index 362024658c..d047484e03 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -3,6 +3,8 @@ package executor import ( "context" "net/http" + "os" + "net/url" "strings" "sync" "time" @@ -14,7 +16,9 @@ import ( ) var ( - proxyHTTPTransportCache sync.Map // map[string]*cachedProxyTransport + proxyHTTPTransportCache sync.Map // map[string]*cachedProxyTransport + environmentProxyKeys = []string{"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"} + environmentProxyTransportCache sync.Map // map[string]*http.Transport ) type cachedProxyTransport struct { @@ -25,16 +29,8 @@ type cachedProxyTransport struct { // newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority: // 1. Use auth.ProxyURL if configured (highest priority) // 2. Use cfg.ProxyURL if auth proxy is not configured -// 3. Use RoundTripper from context if neither are configured -// -// Parameters: -// - ctx: The context containing optional RoundTripper -// - cfg: The application configuration -// - auth: The authentication information -// - timeout: The client timeout (0 means no timeout) -// -// Returns: -// - *http.Client: An HTTP client with configured proxy or transport +// 3. Use environment proxy settings if neither are configured +// 4. Use RoundTripper from context if no explicit or environment proxy is configured func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client { var contextTransport http.RoundTripper if ctx != nil { @@ -43,31 +39,25 @@ func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *clip } } - // Priority 1: Use auth.ProxyURL if configured var proxyURL string if auth != nil { proxyURL = strings.TrimSpace(auth.ProxyURL) } - - // Priority 2: Use cfg.ProxyURL if auth proxy is not configured if proxyURL == "" && cfg != nil { proxyURL = strings.TrimSpace(cfg.ProxyURL) } - // Priority 3: Use RoundTripper from context (typically from RoundTripperFor) - if contextTransport != nil && proxyURL == "" { - return newProxyHTTPClient(contextTransport, timeout) - } - - // If we have a proxy URL configured, set up the transport if proxyURL != "" { if transport := cachedTransportForProxyURL(proxyURL); transport != nil { return newProxyHTTPClient(transport, timeout) } - // If proxy setup failed, fall through to context RoundTripper. log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL) } + if environmentProxyConfigured() { + return newProxyHTTPClient(newEnvironmentProxyTransport(), timeout) + } + if contextTransport != nil { return newProxyHTTPClient(contextTransport, timeout) } @@ -96,14 +86,6 @@ func newProxyHTTPClient(transport http.RoundTripper, timeout time.Duration) *htt return client } -// buildProxyTransport creates an HTTP transport configured for the given proxy URL. -// It supports SOCKS5, HTTP, and HTTPS proxy protocols. -// -// Parameters: -// - proxyURL: The proxy URL string (e.g., "socks5://user:pass@host:port", "http://host:port") -// -// Returns: -// - *http.Transport: A configured transport, or nil if the proxy URL is invalid func buildProxyTransport(proxyURL string) *http.Transport { transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyURL) if errBuild != nil { @@ -112,3 +94,87 @@ func buildProxyTransport(proxyURL string) *http.Transport { } return transport } + +func environmentProxyConfigured() bool { + for _, key := range environmentProxyKeys { + if strings.TrimSpace(os.Getenv(key)) != "" { + return true + } + } + return false +} + +func newEnvironmentProxyTransport() *http.Transport { + signature := environmentProxySignature() + if cached, ok := environmentProxyTransportCache.Load(signature); ok { + return cached.(*http.Transport) + } + + proxyFunc := environmentProxyFunc() + var transport *http.Transport + if base, ok := http.DefaultTransport.(*http.Transport); ok && base != nil { + clone := base.Clone() + clone.Proxy = proxyFunc + transport = clone + } else { + transport = &http.Transport{Proxy: proxyFunc} + } + actual, _ := environmentProxyTransportCache.LoadOrStore(signature, transport) + return actual.(*http.Transport) +} + +func environmentProxySignature() string { + var values []string + for _, key := range environmentProxyKeys { + values = append(values, key+"="+strings.TrimSpace(os.Getenv(key))) + } + return strings.Join(values, "|") +} + +func environmentProxyFunc() func(*http.Request) (*url.URL, error) { + httpProxy := firstEnvironmentValue("HTTP_PROXY", "http_proxy") + httpsProxy := firstEnvironmentValue("HTTPS_PROXY", "https_proxy") + allProxy := firstEnvironmentValue("ALL_PROXY", "all_proxy") + + return func(req *http.Request) (*url.URL, error) { + if req == nil || req.URL == nil { + return nil, nil + } + + raw := "" + switch strings.ToLower(req.URL.Scheme) { + case "https": + raw = firstNonEmpty(httpsProxy, allProxy, httpProxy) + case "http": + raw = firstNonEmpty(httpProxy, allProxy, httpsProxy) + default: + raw = firstNonEmpty(allProxy, httpsProxy, httpProxy) + } + raw = strings.TrimSpace(raw) + if raw == "" { + return nil, nil + } + if !strings.Contains(raw, "://") { + raw = "http://" + raw + } + return url.Parse(raw) + } +} + +func firstEnvironmentValue(keys ...string) string { + for _, key := range keys { + if value := strings.TrimSpace(os.Getenv(key)); value != "" { + return value + } + } + return "" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if strings.TrimSpace(value) != "" { + return value + } + } + return "" +} diff --git a/internal/runtime/executor/proxy_helpers_test.go b/internal/runtime/executor/proxy_helpers_test.go index 13df712ee6..4174dafdb9 100644 --- a/internal/runtime/executor/proxy_helpers_test.go +++ b/internal/runtime/executor/proxy_helpers_test.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "net/url" + "os" "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -11,6 +12,27 @@ import ( sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) +func setEnvironmentProxy(t *testing.T, proxyURL string) { + t.Helper() + + for _, key := range []string{"HTTP_PROXY", "HTTPS_PROXY"} { + oldValue, hadValue := os.LookupEnv(key) + if err := os.Setenv(key, proxyURL); err != nil { + t.Fatalf("Setenv(%s): %v", key, err) + } + cleanupKey := key + cleanupOldValue := oldValue + cleanupHadValue := hadValue + t.Cleanup(func() { + if cleanupHadValue { + _ = os.Setenv(cleanupKey, cleanupOldValue) + return + } + _ = os.Unsetenv(cleanupKey) + }) + } +} + func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) { t.Parallel() @@ -81,6 +103,77 @@ func TestNewProxyAwareHTTPClientProxyReusesCachedClientWithoutTimeout(t *testing } } +func TestNewProxyAwareHTTPClientFallsBackToEnvironmentProxy(t *testing.T) { + setEnvironmentProxy(t, "http://env-proxy.example.com:8080") + + client := newProxyAwareHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", client.Transport) + } + if transport.Proxy == nil { + t.Fatal("expected environment proxy transport to configure Proxy function") + } + req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errReq != nil { + t.Fatalf("NewRequest() error = %v", errReq) + } + proxyURL, errProxy := transport.Proxy(req) + if errProxy != nil { + t.Fatalf("transport.Proxy() error = %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != "http://env-proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://env-proxy.example.com:8080", proxyURL) + } +} + +func TestNewProxyAwareHTTPClientExplicitProxyWinsOverEnvironmentProxy(t *testing.T) { + setEnvironmentProxy(t, "http://env-proxy.example.com:8080") + + client := newProxyAwareHTTPClient( + context.Background(), + &config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://config-proxy.example.com:8080"}}, + nil, + 0, + ) + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", client.Transport) + } + req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errReq != nil { + t.Fatalf("NewRequest() error = %v", errReq) + } + proxyURL, errProxy := transport.Proxy(req) + if errProxy != nil { + t.Fatalf("transport.Proxy() error = %v", errProxy) + } + if proxyURL == nil || proxyURL.String() != "http://config-proxy.example.com:8080" { + t.Fatalf("proxy URL = %v, want http://config-proxy.example.com:8080", proxyURL) + } +} + +func TestNewProxyAwareHTTPClientReusesEnvironmentProxyTransport(t *testing.T) { + setEnvironmentProxy(t, "http://env-proxy.example.com:8080") + + clientA := newProxyAwareHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + clientB := newProxyAwareHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + + transportA, okA := clientA.Transport.(*http.Transport) + if !okA { + t.Fatalf("clientA transport type = %T, want *http.Transport", clientA.Transport) + } + transportB, okB := clientB.Transport.(*http.Transport) + if !okB { + t.Fatalf("clientB transport type = %T, want *http.Transport", clientB.Transport) + } + if transportA != transportB { + t.Fatal("expected environment proxy transport to be shared across clients") + } +} + func TestNewProxyAwareHTTPClientNoProxyDoesNotLeakAntigravityTransportMutation(t *testing.T) { client := newAntigravityHTTPClient(context.Background(), nil, nil, 0) transport, ok := client.Transport.(*http.Transport) diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index e7957d2918..04eb73ac22 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -230,6 +230,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("openai") @@ -333,6 +334,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) + defer reporter.ensurePublished(ctx) from := opts.SourceFormat to := sdktranslator.FromString("openai") diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index 76d3219f1e..e230f5fd0d 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -6,7 +6,7 @@ package claude import ( - "bytes" + "fmt" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" @@ -31,33 +31,31 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator" // - []byte: The transformed request in Gemini CLI format. func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := inputRawJSON - rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) - // Build output Gemini CLI request JSON - out := `{"contents":[]}` - out, _ = sjson.Set(out, "model", modelName) + out := []byte(`{"contents":[]}`) + out, _ = sjson.SetBytes(out, "model", modelName) // system instruction if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() { - systemInstruction := `{"role":"user","parts":[]}` + systemInstruction := []byte(`{"role":"user","parts":[]}`) hasSystemParts := false systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool { if systemPromptResult.Get("type").String() == "text" { textResult := systemPromptResult.Get("text") if textResult.Type == gjson.String { - part := `{"text":""}` - part, _ = sjson.Set(part, "text", textResult.String()) - systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part) + part := []byte(`{"text":""}`) + part, _ = sjson.SetBytes(part, "text", textResult.String()) + systemInstruction, _ = sjson.SetRawBytes(systemInstruction, "parts.-1", part) hasSystemParts = true } } return true }) if hasSystemParts { - out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction) + out, _ = sjson.SetRawBytes(out, "system_instruction", systemInstruction) } } else if systemResult.Type == gjson.String { - out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String()) + out, _ = sjson.SetBytes(out, "system_instruction.parts.-1.text", systemResult.String()) } // contents @@ -72,17 +70,17 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) role = "model" } - contentJSON := `{"role":"","parts":[]}` - contentJSON, _ = sjson.Set(contentJSON, "role", role) + contentJSON := []byte(`{"role":"","parts":[]}`) + contentJSON, _ = sjson.SetBytes(contentJSON, "role", role) contentsResult := messageResult.Get("content") if contentsResult.IsArray() { contentsResult.ForEach(func(_, contentResult gjson.Result) bool { switch contentResult.Get("type").String() { case "text": - part := `{"text":""}` - part, _ = sjson.Set(part, "text", contentResult.Get("text").String()) - contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) + part := []byte(`{"text":""}`) + part, _ = sjson.SetBytes(part, "text", contentResult.Get("text").String()) + contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) case "tool_use": functionName := contentResult.Get("name").String() @@ -95,11 +93,11 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) functionArgs := contentResult.Get("input").String() argsResult := gjson.Parse(functionArgs) if argsResult.IsObject() && gjson.Valid(functionArgs) { - part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}` - part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature) - part, _ = sjson.Set(part, "functionCall.name", functionName) - part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs) - contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) + part := []byte(`{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`) + part, _ = sjson.SetBytes(part, "thoughtSignature", geminiClaudeThoughtSignature) + part, _ = sjson.SetBytes(part, "functionCall.name", functionName) + part, _ = sjson.SetRawBytes(part, "functionCall.args", []byte(functionArgs)) + contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) } case "tool_result": @@ -113,10 +111,10 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) } funcName = util.SanitizeFunctionName(funcName) responseData := contentResult.Get("content").Raw - part := `{"functionResponse":{"name":"","response":{"result":""}}}` - part, _ = sjson.Set(part, "functionResponse.name", funcName) - part, _ = sjson.Set(part, "functionResponse.response.result", responseData) - contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) + part := []byte(`{"functionResponse":{"name":"","response":{"result":""}}}`) + part, _ = sjson.SetBytes(part, "functionResponse.name", funcName) + part, _ = sjson.SetBytes(part, "functionResponse.response.result", responseData) + contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) case "image": source := contentResult.Get("source") @@ -128,51 +126,84 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) if mimeType == "" || data == "" { return true } - part := `{"inline_data":{"mime_type":"","data":""}}` - part, _ = sjson.Set(part, "inline_data.mime_type", mimeType) - part, _ = sjson.Set(part, "inline_data.data", data) - contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) + part := []byte(`{"inline_data":{"mime_type":"","data":""}}`) + part, _ = sjson.SetBytes(part, "inline_data.mime_type", mimeType) + part, _ = sjson.SetBytes(part, "inline_data.data", data) + contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) } return true }) - out, _ = sjson.SetRaw(out, "contents.-1", contentJSON) + out, _ = sjson.SetRawBytes(out, "contents.-1", contentJSON) } else if contentsResult.Type == gjson.String { - part := `{"text":""}` - part, _ = sjson.Set(part, "text", contentsResult.String()) - contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part) - out, _ = sjson.SetRaw(out, "contents.-1", contentJSON) + part := []byte(`{"text":""}`) + part, _ = sjson.SetBytes(part, "text", contentsResult.String()) + contentJSON, _ = sjson.SetRawBytes(contentJSON, "parts.-1", part) + out, _ = sjson.SetRawBytes(out, "contents.-1", contentJSON) } return true }) } + // strip trailing model turn with unanswered function calls — + // Gemini returns empty responses when the last turn is a model + // functionCall with no corresponding user functionResponse. + contents := gjson.GetBytes(out, "contents") + if contents.Exists() && contents.IsArray() { + arr := contents.Array() + if len(arr) > 0 { + last := arr[len(arr)-1] + if last.Get("role").String() == "model" { + hasFC := false + last.Get("parts").ForEach(func(_, part gjson.Result) bool { + if part.Get("functionCall").Exists() { + hasFC = true + return false + } + return true + }) + if hasFC { + out, _ = sjson.DeleteBytes(out, fmt.Sprintf("contents.%d", len(arr)-1)) + } + } + } + } + // tools if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() { hasTools := false toolsResult.ForEach(func(_, toolResult gjson.Result) bool { inputSchemaResult := toolResult.Get("input_schema") if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { - inputSchema := inputSchemaResult.Raw - tool, _ := sjson.Delete(toolResult.Raw, "input_schema") - tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) - tool, _ = sjson.Delete(tool, "strict") - tool, _ = sjson.Delete(tool, "input_examples") - tool, _ = sjson.Delete(tool, "type") - tool, _ = sjson.Delete(tool, "cache_control") - tool, _ = sjson.Delete(tool, "defer_loading") - tool, _ = sjson.Set(tool, "name", util.SanitizeFunctionName(gjson.Get(tool, "name").String())) - if gjson.Valid(tool) && gjson.Parse(tool).IsObject() { + inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw) + tool := []byte(toolResult.Raw) + var err error + tool, err = sjson.DeleteBytes(tool, "input_schema") + if err != nil { + return true + } + tool, err = sjson.SetRawBytes(tool, "parametersJsonSchema", []byte(inputSchema)) + if err != nil { + return true + } + tool, _ = sjson.DeleteBytes(tool, "strict") + tool, _ = sjson.DeleteBytes(tool, "input_examples") + tool, _ = sjson.DeleteBytes(tool, "type") + tool, _ = sjson.DeleteBytes(tool, "cache_control") + tool, _ = sjson.DeleteBytes(tool, "defer_loading") + tool, _ = sjson.DeleteBytes(tool, "eager_input_streaming") + tool, _ = sjson.SetBytes(tool, "name", util.SanitizeFunctionName(gjson.GetBytes(tool, "name").String())) + if gjson.ValidBytes(tool) && gjson.ParseBytes(tool).IsObject() { if !hasTools { - out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`) + out, _ = sjson.SetRawBytes(out, "tools", []byte(`[{"functionDeclarations":[]}]`)) hasTools = true } - out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool) + out, _ = sjson.SetRawBytes(out, "tools.0.functionDeclarations.-1", tool) } } return true }) if !hasTools { - out, _ = sjson.Delete(out, "tools") + out, _ = sjson.DeleteBytes(out, "tools") } } @@ -190,15 +221,15 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) switch toolChoiceType { case "auto": - out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "AUTO") + out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "AUTO") case "none": - out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "NONE") + out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "NONE") case "any": - out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY") + out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY") case "tool": - out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY") + out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.mode", "ANY") if toolChoiceName != "" { - out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)}) + out, _ = sjson.SetBytes(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{util.SanitizeFunctionName(toolChoiceName)}) } } } @@ -210,8 +241,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) case "enabled": if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number { budget := int(b.Int()) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget) - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", budget) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.includeThoughts", true) } case "adaptive", "auto": // For adaptive thinking: @@ -223,32 +254,32 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) effort = strings.ToLower(strings.TrimSpace(v.String())) } if effort != "" { - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", effort) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingLevel", effort) } else { maxBudget := 0 if mi := registry.LookupModelInfo(modelName, "gemini"); mi != nil && mi.Thinking != nil { maxBudget = mi.Thinking.Max } if maxBudget > 0 { - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget) } else { - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high") + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingLevel", "high") } } - out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true) + out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.includeThoughts", true) } } if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.Set(out, "generationConfig.temperature", v.Num) + out, _ = sjson.SetBytes(out, "generationConfig.temperature", v.Num) } if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.Set(out, "generationConfig.topP", v.Num) + out, _ = sjson.SetBytes(out, "generationConfig.topP", v.Num) } if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number { - out, _ = sjson.Set(out, "generationConfig.topK", v.Num) + out, _ = sjson.SetBytes(out, "generationConfig.topK", v.Num) } - result := []byte(out) + result := out result = common.AttachDefaultSafetySettings(result, "safetySettings") return result diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 72baac022c..8f41850b48 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -6,6 +6,7 @@ package handlers import ( "bytes" "encoding/json" + "errors" "fmt" "net/http" "strings" @@ -586,24 +587,29 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl SourceFormat: sdktranslator.FromString(handlerType), } opts.Metadata = reqMeta - streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + streamResult, initialBootstrapRetries, err := h.executeStreamWithBootstrapRetry(ctx, providers, req, opts, maxBootstrapRetries) if err != nil { - errChan := make(chan *interfaces.ErrorMessage, 1) - status := http.StatusInternalServerError - if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { - if code := se.StatusCode(); code > 0 { - status = code + if shouldWrapImmediateStreamError(err) { + streamResult = streamResultFromError(err) + } else { + errChan := make(chan *interfaces.ErrorMessage, 1) + status := http.StatusInternalServerError + if se, ok := err.(interface{ StatusCode() int }); ok && se != nil { + if code := se.StatusCode(); code > 0 { + status = code + } } - } - var addon http.Header - if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { - if hdr := he.Headers(); hdr != nil { - addon = hdr.Clone() + var addon http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + if hdr := he.Headers(); hdr != nil { + addon = hdr.Clone() + } } + errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} + close(errChan) + return nil, nil, errChan } - errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon} - close(errChan) - return nil, nil, errChan } passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg) // Capture upstream headers from the initial connection synchronously before the goroutine starts. @@ -622,8 +628,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl defer close(dataChan) defer close(errChan) sentPayload := false - bootstrapRetries := 0 - maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg) + bootstrapRetries := initialBootstrapRetries sendErr := func(msg *interfaces.ErrorMessage) bool { if ctx == nil { @@ -651,20 +656,6 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl } } - bootstrapEligible := func(err error) bool { - status := statusFromError(err) - if status == 0 { - return true - } - switch status { - case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired, - http.StatusRequestTimeout, http.StatusTooManyRequests: - return true - default: - return status >= http.StatusInternalServerError - } - } - outer: for { for { @@ -700,7 +691,6 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl streamErr = retryErr } } - status := http.StatusInternalServerError if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil { if code := se.StatusCode(); code > 0 { @@ -734,6 +724,65 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return dataChan, upstreamHeaders, errChan } +func (h *BaseAPIHandler) executeStreamWithBootstrapRetry(ctx context.Context, providers []string, req coreexecutor.Request, opts coreexecutor.Options, maxBootstrapRetries int) (*coreexecutor.StreamResult, int, error) { + bootstrapRetries := 0 + for { + streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) + if err == nil { + return streamResult, bootstrapRetries, nil + } + if ctx != nil && ctx.Err() != nil { + return nil, bootstrapRetries, ctx.Err() + } + if bootstrapRetries >= maxBootstrapRetries || !bootstrapEligible(err) { + return nil, bootstrapRetries, err + } + bootstrapRetries++ + } +} + +func bootstrapEligible(err error) bool { + status := statusFromError(err) + if status == 0 { + return true + } + switch status { + case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired, + http.StatusRequestTimeout, http.StatusTooManyRequests: + return true + default: + return status >= http.StatusInternalServerError + } +} + +func shouldWrapImmediateStreamError(err error) bool { + if err == nil { + return false + } + status := statusFromError(err) + switch status { + case http.StatusBadRequest: + return !strings.Contains(err.Error(), "invalid_request_error") + case http.StatusUnprocessableEntity: + return false + } + var authErr *coreauth.Error + if errors.As(err, &authErr) && authErr != nil { + switch authErr.Code { + case "auth_not_found", "provider_not_found": + return false + } + } + return bootstrapEligible(err) +} + +func streamResultFromError(err error) *coreexecutor.StreamResult { + errCh := make(chan coreexecutor.StreamChunk, 1) + errCh <- coreexecutor.StreamChunk{Err: err} + close(errCh) + return &coreexecutor.StreamResult{Chunks: errCh} +} + func validateSSEDataJSON(chunk []byte) error { for _, line := range bytes.Split(chunk, []byte("\n")) { line = bytes.TrimSpace(line) diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index b08e3a99de..61c0333227 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -136,6 +136,8 @@ type authAwareStreamExecutor struct { type invalidJSONStreamExecutor struct{} +type splitResponsesEventStreamExecutor struct{} + func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" } func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -165,6 +167,36 @@ func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *corea } } +func (e *splitResponsesEventStreamExecutor) Identifier() string { return "split-sse" } + +func (e *splitResponsesEventStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *splitResponsesEventStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + ch := make(chan coreexecutor.StreamChunk, 2) + ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed")} + ch <- coreexecutor.StreamChunk{Payload: []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}")} + close(ch) + return &coreexecutor.StreamResult{Chunks: ch}, nil +} + +func (e *splitResponsesEventStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *splitResponsesEventStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *splitResponsesEventStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + func (e *authAwareStreamExecutor) Identifier() string { return "codex" } func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -607,3 +639,52 @@ func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t * t.Fatalf("expected terminal error") } } + +func TestExecuteStreamWithAuthManager_AllowsSplitOpenAIResponsesSSEEventLines(t *testing.T) { + executor := &splitResponsesEventStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "split-sse", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []string + for chunk := range dataChan { + got = append(got, string(chunk)) + } + + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if len(got) != 2 { + t.Fatalf("expected 2 forwarded chunks, got %d: %#v", len(got), got) + } + if got[0] != "event: response.completed" { + t.Fatalf("unexpected first chunk: %q", got[0]) + } + expectedData := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}" + if got[1] != expectedData { + t.Fatalf("unexpected second chunk.\nGot: %q\nWant: %q", got[1], expectedData) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_handlers.go b/sdk/api/handlers/openai/openai_responses_handlers.go index 3bca75f943..388b86bc9b 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers.go +++ b/sdk/api/handlers/openai/openai_responses_handlers.go @@ -9,7 +9,9 @@ package openai import ( "bytes" "context" + "encoding/json" "fmt" + "io" "net/http" "github.com/gin-gonic/gin" @@ -21,10 +23,181 @@ import ( "github.com/tidwall/sjson" ) +func writeResponsesSSEChunk(w io.Writer, chunk []byte) { + if w == nil || len(chunk) == 0 { + return + } + if _, err := w.Write(chunk); err != nil { + return + } + if bytes.HasSuffix(chunk, []byte("\n\n")) || bytes.HasSuffix(chunk, []byte("\r\n\r\n")) { + return + } + suffix := []byte("\n\n") + if bytes.HasSuffix(chunk, []byte("\r\n")) { + suffix = []byte("\r\n") + } else if bytes.HasSuffix(chunk, []byte("\n")) { + suffix = []byte("\n") + } + if _, err := w.Write(suffix); err != nil { + return + } +} + +type responsesSSEFramer struct { + pending []byte +} + +func (f *responsesSSEFramer) WriteChunk(w io.Writer, chunk []byte) { + if len(chunk) == 0 { + return + } + if responsesSSENeedsLineBreak(f.pending, chunk) { + f.pending = append(f.pending, '\n') + } + f.pending = append(f.pending, chunk...) + for { + frameLen := responsesSSEFrameLen(f.pending) + if frameLen == 0 { + break + } + writeResponsesSSEChunk(w, f.pending[:frameLen]) + copy(f.pending, f.pending[frameLen:]) + f.pending = f.pending[:len(f.pending)-frameLen] + } + if len(bytes.TrimSpace(f.pending)) == 0 { + f.pending = f.pending[:0] + return + } + if len(f.pending) == 0 || !responsesSSECanEmitWithoutDelimiter(f.pending) { + return + } + writeResponsesSSEChunk(w, f.pending) + f.pending = f.pending[:0] +} + +func (f *responsesSSEFramer) Flush(w io.Writer) { + if len(f.pending) == 0 { + return + } + if len(bytes.TrimSpace(f.pending)) == 0 { + f.pending = f.pending[:0] + return + } + if !responsesSSECanEmitWithoutDelimiter(f.pending) { + f.pending = f.pending[:0] + return + } + writeResponsesSSEChunk(w, f.pending) + f.pending = f.pending[:0] +} + +func responsesSSEFrameLen(chunk []byte) int { + if len(chunk) == 0 { + return 0 + } + lf := bytes.Index(chunk, []byte("\n\n")) + crlf := bytes.Index(chunk, []byte("\r\n\r\n")) + switch { + case lf < 0: + if crlf < 0 { + return 0 + } + return crlf + 4 + case crlf < 0: + return lf + 2 + case lf < crlf: + return lf + 2 + default: + return crlf + 4 + } +} + +func responsesSSENeedsMoreData(chunk []byte) bool { + trimmed := bytes.TrimSpace(chunk) + if len(trimmed) == 0 { + return false + } + return responsesSSEHasField(trimmed, []byte("event:")) && !responsesSSEHasField(trimmed, []byte("data:")) +} + +func responsesSSEHasField(chunk []byte, prefix []byte) bool { + s := chunk + for len(s) > 0 { + line := s + if i := bytes.IndexByte(s, '\n'); i >= 0 { + line = s[:i] + s = s[i+1:] + } else { + s = nil + } + line = bytes.TrimSpace(line) + if bytes.HasPrefix(line, prefix) { + return true + } + } + return false +} + +func responsesSSECanEmitWithoutDelimiter(chunk []byte) bool { + trimmed := bytes.TrimSpace(chunk) + if len(trimmed) == 0 || responsesSSENeedsMoreData(trimmed) || !responsesSSEHasField(trimmed, []byte("data:")) { + return false + } + return responsesSSEDataLinesValid(trimmed) +} + +func responsesSSEDataLinesValid(chunk []byte) bool { + s := chunk + for len(s) > 0 { + line := s + if i := bytes.IndexByte(s, '\n'); i >= 0 { + line = s[:i] + s = s[i+1:] + } else { + s = nil + } + line = bytes.TrimSpace(line) + if len(line) == 0 || !bytes.HasPrefix(line, []byte("data:")) { + continue + } + data := bytes.TrimSpace(line[len("data:"):]) + if len(data) == 0 || bytes.Equal(data, []byte("[DONE]")) { + continue + } + if !json.Valid(data) { + return false + } + } + return true +} + +func responsesSSENeedsLineBreak(pending, chunk []byte) bool { + if len(pending) == 0 || len(chunk) == 0 { + return false + } + if bytes.HasSuffix(pending, []byte("\n")) || bytes.HasSuffix(pending, []byte("\r")) { + return false + } + if chunk[0] == '\n' || chunk[0] == '\r' { + return false + } + trimmed := bytes.TrimLeft(chunk, " \t") + if len(trimmed) == 0 { + return false + } + for _, prefix := range [][]byte{[]byte("data:"), []byte("event:"), []byte("id:"), []byte("retry:"), []byte(":")} { + if bytes.HasPrefix(trimmed, prefix) { + return true + } + } + return false +} // OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints. // It holds a pool of clients to interact with the backend service. type OpenAIResponsesAPIHandler struct { *handlers.BaseAPIHandler + turnState responsesTurnStateCache } // NewOpenAIResponsesAPIHandler creates a new OpenAIResponses API handlers instance. @@ -82,6 +255,12 @@ func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) { return } + rawJSON, errMsg := h.normalizeContinuationRequest(rawJSON) + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + return + } + // Check if the client requested a streaming response. streamResult := gjson.GetBytes(rawJSON, "stream") if streamResult.Type == gjson.True { @@ -131,6 +310,7 @@ func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) { cliCancel(errMsg.Error) return } + h.rememberCompletedResponse(rawJSON, resp) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() @@ -157,6 +337,7 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r cliCancel(errMsg.Error) return } + h.rememberCompletedResponse(rawJSON, resp) handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) _, _ = c.Writer.Write(resp) cliCancel() @@ -193,6 +374,7 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ c.Header("Connection", "keep-alive") c.Header("Access-Control-Allow-Origin", "*") } + framer := &responsesSSEFramer{} // Peek at the first chunk for { @@ -230,30 +412,28 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders) // Write first chunk logic (matching forwardResponsesStream) - if bytes.HasPrefix(chunk, []byte("event:")) { - _, _ = c.Writer.Write([]byte("\n")) - } - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n")) + h.rememberCompletedResponseFromChunk(rawJSON, chunk) + framer.WriteChunk(c.Writer, chunk) flusher.Flush() // Continue - h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan) + h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, rawJSON, dataChan, errChan, framer) return } } } -func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) { +func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), requestJSON []byte, data <-chan []byte, errs <-chan *interfaces.ErrorMessage, framer *responsesSSEFramer) { + if framer == nil { + framer = &responsesSSEFramer{} + } h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{ WriteChunk: func(chunk []byte) { - if bytes.HasPrefix(chunk, []byte("event:")) { - _, _ = c.Writer.Write([]byte("\n")) - } - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n")) + h.rememberCompletedResponseFromChunk(requestJSON, chunk) + framer.WriteChunk(c.Writer, chunk) }, WriteTerminalError: func(errMsg *interfaces.ErrorMessage) { + framer.Flush(c.Writer) if errMsg == nil { return } @@ -269,6 +449,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush _, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk)) }, WriteDone: func() { + framer.Flush(c.Writer) _, _ = c.Writer.Write([]byte("\n")) }, }) diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go index dce738073c..af05e74e0f 100644 --- a/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go +++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go @@ -32,7 +32,7 @@ func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")} close(errs) - h.forwardResponsesStream(c, flusher, func(error) {}, data, errs) + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) body := recorder.Body.String() if !strings.Contains(body, `"type":"error"`) { t.Fatalf("expected responses error chunk, got: %q", body) diff --git a/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go new file mode 100644 index 0000000000..e7b08da70f --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_handlers_stream_test.go @@ -0,0 +1,142 @@ +package openai + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func newResponsesStreamTestHandler(t *testing.T) (*OpenAIResponsesAPIHandler, *httptest.ResponseRecorder, *gin.Context, http.Flusher) { + t.Helper() + + gin.SetMode(gin.TestMode) + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil) + h := NewOpenAIResponsesAPIHandler(base) + + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil) + + flusher, ok := c.Writer.(http.Flusher) + if !ok { + t.Fatalf("expected gin writer to implement http.Flusher") + } + + return h, recorder, c, flusher +} + +func TestForwardResponsesStreamSeparatesDataOnlySSEChunks(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 2) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}") + data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}") + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) + body := recorder.Body.String() + parts := strings.Split(strings.TrimSpace(body), "\n\n") + if len(parts) != 2 { + t.Fatalf("expected 2 SSE events, got %d. Body: %q", len(parts), body) + } + + expectedPart1 := "data: {\"type\":\"response.output_item.done\",\"item\":{\"type\":\"function_call\",\"arguments\":\"{}\"}}" + if parts[0] != expectedPart1 { + t.Errorf("unexpected first event.\nGot: %q\nWant: %q", parts[0], expectedPart1) + } + + expectedPart2 := "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[]}}" + if parts[1] != expectedPart2 { + t.Errorf("unexpected second event.\nGot: %q\nWant: %q", parts[1], expectedPart2) + } +} + +func TestForwardResponsesStreamReassemblesSplitSSEEventChunks(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 3) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte("event: response.created") + data <- []byte("data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}") + data <- []byte("\n") + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) + + got := strings.TrimSuffix(recorder.Body.String(), "\n") + want := "event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n" + if got != want { + t.Fatalf("unexpected split-event framing.\nGot: %q\nWant: %q", got, want) + } +} + +func TestForwardResponsesStreamPreservesValidFullSSEEventChunks(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 1) + errs := make(chan *interfaces.ErrorMessage) + chunk := []byte("event: response.created\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n") + data <- chunk + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) + + got := strings.TrimSuffix(recorder.Body.String(), "\n") + if got != string(chunk) { + t.Fatalf("unexpected full-event framing.\nGot: %q\nWant: %q", got, string(chunk)) + } +} + +func TestForwardResponsesStreamBuffersSplitDataPayloadChunks(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 2) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.created\"") + data <- []byte(",\"response\":{\"id\":\"resp-1\"}}") + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) + + got := recorder.Body.String() + want := "data: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\n\n" + if got != want { + t.Fatalf("unexpected split-data framing.\nGot: %q\nWant: %q", got, want) + } +} + +func TestResponsesSSENeedsLineBreakSkipsChunksThatAlreadyStartWithNewline(t *testing.T) { + if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\n")) { + t.Fatal("expected no injected newline before newline-only chunk") + } + if responsesSSENeedsLineBreak([]byte("event: response.created"), []byte("\r\n")) { + t.Fatal("expected no injected newline before CRLF chunk") + } +} + +func TestForwardResponsesStreamDropsIncompleteTrailingDataChunkOnFlush(t *testing.T) { + h, recorder, c, flusher := newResponsesStreamTestHandler(t) + + data := make(chan []byte, 1) + errs := make(chan *interfaces.ErrorMessage) + data <- []byte("data: {\"type\":\"response.created\"") + close(data) + close(errs) + + h.forwardResponsesStream(c, flusher, func(error) {}, nil, data, errs, nil) + + if got := recorder.Body.String(); got != "\n" { + t.Fatalf("expected incomplete trailing data to be dropped on flush.\nGot: %q", got) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_http_continuation_test.go b/sdk/api/handlers/openai/openai_responses_http_continuation_test.go new file mode 100644 index 0000000000..7248c930a2 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_http_continuation_test.go @@ -0,0 +1,149 @@ +package openai + +import ( + "context" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/tidwall/gjson" +) + +type responsesContinuationCaptureExecutor struct { + calls int + payloads [][]byte +} + +func (e *responsesContinuationCaptureExecutor) Identifier() string { return "test-provider" } + +func (e *responsesContinuationCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) { + e.calls++ + e.payloads = append(e.payloads, append([]byte(nil), req.Payload...)) + if e.calls == 1 { + return coreexecutor.Response{Payload: []byte(`{"id":"resp-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"message","id":"assistant-1"}]}`)}, nil + } + return coreexecutor.Response{Payload: []byte(`{"id":"resp-2","output":[{"type":"message","id":"assistant-2"}]}`)}, nil +} + +func (e *responsesContinuationCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + return nil, errors.New("not implemented") +} + +func (e *responsesContinuationCaptureExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *responsesContinuationCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *responsesContinuationCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func TestOpenAIResponsesHTTPContinuationMergesCachedTurn(t *testing.T) { + gin.SetMode(gin.TestMode) + executor := &responsesContinuationCaptureExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth := &coreauth.Auth{ID: "auth-http-cont", Provider: executor.Identifier(), Status: coreauth.StatusActive} + if _, err := manager.Register(context.Background(), auth); err != nil { + t.Fatalf("Register auth: %v", err) + } + registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.POST("/v1/responses", h.Responses) + + firstReq := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"test-model","input":[{"type":"message","id":"msg-1"}]}`)) + firstReq.Header.Set("Content-Type", "application/json") + firstResp := httptest.NewRecorder() + router.ServeHTTP(firstResp, firstReq) + if firstResp.Code != http.StatusOK { + t.Fatalf("first status = %d, want %d", firstResp.Code, http.StatusOK) + } + + secondReq := httptest.NewRequest(http.MethodPost, "/v1/responses", strings.NewReader(`{"model":"test-model","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)) + secondReq.Header.Set("Content-Type", "application/json") + secondResp := httptest.NewRecorder() + router.ServeHTTP(secondResp, secondReq) + if secondResp.Code != http.StatusOK { + t.Fatalf("second status = %d, want %d body=%s", secondResp.Code, http.StatusOK, secondResp.Body.String()) + } + + if executor.calls != 2 { + t.Fatalf("executor calls = %d, want 2", executor.calls) + } + if gjson.GetBytes(executor.payloads[1], "previous_response_id").Exists() { + t.Fatalf("second payload must not include previous_response_id: %s", executor.payloads[1]) + } + input := gjson.GetBytes(executor.payloads[1], "input").Array() + if len(input) != 4 { + t.Fatalf("merged input len = %d, want 4: %s", len(input), executor.payloads[1]) + } + if input[0].Get("id").String() != "msg-1" || + input[1].Get("id").String() != "fc-1" || + input[2].Get("id").String() != "assistant-1" || + input[3].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected merged input order: %s", executor.payloads[1]) + } +} + +func TestNormalizeContinuationRequestSupportsStringInputShorthand(t *testing.T) { + h := &OpenAIResponsesAPIHandler{} + h.rememberCompletedResponse( + []byte(`{"model":"test-model","input":"Use the weather tool for Paris."}`), + []byte(`{"id":"resp-str-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"message","id":"assistant-1"}]}`), + ) + + normalized, errMsg := h.normalizeContinuationRequest([]byte(`{"previous_response_id":"resp-str-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 4 { + t.Fatalf("merged input len = %d, want 4: %s", len(input), normalized) + } + if input[0].Get("role").String() != "user" { + t.Fatalf("expected normalized first item to be user message: %s", normalized) + } + if input[0].Get("content").String() != "Use the weather tool for Paris." { + t.Fatalf("unexpected normalized first item content: %s", normalized) + } +} + +func TestRememberCompletedResponseFromChunkCachesStreamingTurn(t *testing.T) { + h := &OpenAIResponsesAPIHandler{} + requestJSON := []byte(`{"model":"test-model","input":[{"type":"message","id":"msg-1"}]}`) + chunk := []byte(`event: response.completed +data: {"type":"response.completed","response":{"id":"resp-stream-1","output":[{"type":"function_call","id":"fc-1","call_id":"call-1"},{"type":"message","id":"assistant-1"}]}} + +`) + h.rememberCompletedResponseFromChunk(requestJSON, chunk) + + normalized, errMsg := h.normalizeContinuationRequest([]byte(`{"previous_response_id":"resp-stream-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 4 { + t.Fatalf("merged input len = %d, want 4: %s", len(input), normalized) + } + if input[3].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected merged payload: %s", normalized) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_turn_state.go b/sdk/api/handlers/openai/openai_responses_turn_state.go new file mode 100644 index 0000000000..b73dce6672 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_turn_state.go @@ -0,0 +1,185 @@ +package openai + +import ( + "bytes" + "fmt" + "net/http" + "strings" + "sync" + "time" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const responsesTurnStateTTL = 30 * time.Minute + +type responsesTurnStateEntry struct { + request []byte + output []byte + expire time.Time +} + +type responsesTurnStateCache struct { + entries sync.Map +} + +func (c *responsesTurnStateCache) load(responseID string) ([]byte, []byte, bool) { + responseID = strings.TrimSpace(responseID) + if responseID == "" { + return nil, nil, false + } + raw, ok := c.entries.Load(responseID) + if !ok { + return nil, nil, false + } + entry, ok := raw.(responsesTurnStateEntry) + if !ok { + c.entries.Delete(responseID) + return nil, nil, false + } + if !entry.expire.After(time.Now()) { + c.entries.Delete(responseID) + return nil, nil, false + } + return bytes.Clone(entry.request), bytes.Clone(entry.output), true +} + +func (c *responsesTurnStateCache) store(responseID string, requestJSON []byte, outputJSON []byte) { + responseID = strings.TrimSpace(responseID) + if responseID == "" || len(requestJSON) == 0 || len(outputJSON) == 0 { + return + } + c.entries.Store(responseID, responsesTurnStateEntry{ + request: bytes.Clone(requestJSON), + output: bytes.Clone(outputJSON), + expire: time.Now().Add(responsesTurnStateTTL), + }) +} + +func normalizeResponsesRequestInputRaw(rawJSON []byte) (string, *interfaces.ErrorMessage) { + input := gjson.GetBytes(rawJSON, "input") + if !input.Exists() { + return "[]", nil + } + if input.IsArray() { + return input.Raw, nil + } + if input.Type == gjson.String { + message := []byte(`{"type":"message","role":"user","content":""}`) + message, _ = sjson.SetBytes(message, "content", input.String()) + return fmt.Sprintf("[%s]", message), nil + } + return "", &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("responses request requires array or string field: input"), + } +} + +func normalizeResponsesHTTPContinuationRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, *interfaces.ErrorMessage) { + if len(lastRequest) == 0 { + return rawJSON, nil + } + + nextInput := gjson.GetBytes(rawJSON, "input") + if !nextInput.Exists() || !nextInput.IsArray() { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("responses request requires array field: input"), + } + } + + existingInputRaw, errMsg := normalizeResponsesRequestInputRaw(lastRequest) + if errMsg != nil { + return nil, errMsg + } + mergedInput, errMerge := mergeJSONArrayRaw(existingInputRaw, normalizeJSONArrayRaw(lastResponseOutput)) + if errMerge != nil { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("invalid previous response output: %w", errMerge), + } + } + mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw) + if errMerge != nil { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("invalid request input: %w", errMerge), + } + } + + normalized, errDelete := sjson.DeleteBytes(rawJSON, "previous_response_id") + if errDelete != nil { + normalized = bytes.Clone(rawJSON) + } + var errSet error + normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput)) + if errSet != nil { + return nil, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("failed to merge responses input: %w", errSet), + } + } + if !gjson.GetBytes(normalized, "model").Exists() { + modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if modelName != "" { + normalized, _ = sjson.SetBytes(normalized, "model", modelName) + } + } + if !gjson.GetBytes(normalized, "instructions").Exists() { + instructions := gjson.GetBytes(lastRequest, "instructions") + if instructions.Exists() { + normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) + } + } + return normalized, nil +} + +func (h *OpenAIResponsesAPIHandler) normalizeContinuationRequest(rawJSON []byte) ([]byte, *interfaces.ErrorMessage) { + if h == nil { + return rawJSON, nil + } + previousResponseID := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) + if previousResponseID == "" { + return rawJSON, nil + } + lastRequest, lastResponseOutput, ok := h.turnState.load(previousResponseID) + if !ok { + return rawJSON, nil + } + return normalizeResponsesHTTPContinuationRequest(rawJSON, lastRequest, lastResponseOutput) +} + +func (h *OpenAIResponsesAPIHandler) rememberCompletedResponse(requestJSON []byte, responseJSON []byte) { + if h == nil { + return + } + responseID := strings.TrimSpace(gjson.GetBytes(responseJSON, "id").String()) + if responseID == "" { + responseID = strings.TrimSpace(gjson.GetBytes(responseJSON, "response.id").String()) + } + if responseID == "" { + return + } + output := gjson.GetBytes(responseJSON, "output") + if !output.Exists() || !output.IsArray() { + output = gjson.GetBytes(responseJSON, "response.output") + } + if !output.Exists() || !output.IsArray() { + return + } + h.turnState.store(responseID, requestJSON, []byte(output.Raw)) +} + +func (h *OpenAIResponsesAPIHandler) rememberCompletedResponseFromChunk(requestJSON []byte, chunk []byte) { + if h == nil || len(chunk) == 0 { + return + } + for _, payload := range websocketJSONPayloadsFromChunk(chunk) { + if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted { + continue + } + h.rememberCompletedResponse(requestJSON, payload) + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go index 749de57018..f930020303 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket.go +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -80,6 +80,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" + forceDisableIncrementalAfterAuthReset := false for { msgType, payload, errReadMessage := conn.ReadMessage() @@ -106,16 +107,18 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { appendWebsocketEvent(&wsBodyLog, "request", payload) allowIncrementalInputWithPreviousResponseID := false - if pinnedAuthID != "" && h != nil && h.AuthManager != nil { - if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { - allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) - } - } else { - requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) - if requestModelName == "" { - requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if !forceDisableIncrementalAfterAuthReset { + if pinnedAuthID != "" && h != nil && h.AuthManager != nil { + if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + } + } else { + requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String()) + if requestModelName == "" { + requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + } + allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } - allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName) } var requestJSON []byte @@ -150,6 +153,28 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } continue } + nextSessionRequestSnapshot := updatedLastRequest + if shouldBuildResponsesWebsocketFullSnapshot(requestJSON, allowIncrementalInputWithPreviousResponseID) { + _, shadowLastRequest, shadowErr := normalizeResponsesWebsocketRequestWithMode( + payload, + lastRequest, + lastResponseOutput, + false, + ) + if shadowErr != nil { + // 影子快照失败时保留旧快照避免污染会话状态 + nextSessionRequestSnapshot = lastRequest + log.Errorf( + "responses websocket: keep previous snapshot id=%s status=%d error=%v", + passthroughSessionID, + shadowErr.StatusCode, + shadowErr.Error, + ) + } else { + // 增量模式只发 delta 同时维护完整快照供切号恢复 + nextSessionRequestSnapshot = shadowLastRequest + } + } if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) { if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil { requestJSON = updated @@ -168,8 +193,7 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } requestJSON = repairResponsesWebsocketToolCalls(toolPairState, requestJSON) updatedLastRequest = bytes.Clone(requestJSON) - lastRequest = updatedLastRequest - + nextSessionRequestSnapshot = repairResponsesWebsocketToolCalls(toolPairState, nextSessionRequestSnapshot) modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx) @@ -193,14 +217,32 @@ func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { } dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") - completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID, toolPairState) + completedOutput, terminalStatus, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID, toolPairState) if errForward != nil { wsTerminateErr = errForward appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } - lastResponseOutput = completedOutput + if shouldResetResponsesWebsocketAuthPin(terminalStatus) { + // 限额错误后解除 pin 让后续请求重新选可用账号 + log.Infof("responses websocket: reset auth pin id=%s status=%d", passthroughSessionID, terminalStatus) + pinnedAuthID = "" + // 切号恢复阶段先禁用增量模式避免沿用旧账号 response id + forceDisableIncrementalAfterAuthReset = true + if h != nil && h.AuthManager != nil { + h.AuthManager.CloseExecutionSession(passthroughSessionID) + } + } else if forceDisableIncrementalAfterAuthReset && terminalStatus == 0 { + // 仅在成功轮次恢复增量模式避免失败轮次继续透传旧 response id + forceDisableIncrementalAfterAuthReset = false + } + if terminalStatus == 0 { + // 仅在本轮成功后提交快照避免失败轮次污染会话历史 + lastRequest = nextSessionRequestSnapshot + // 仅在本轮成功后提交输出避免失败把状态推进到空输出 + lastResponseOutput = completedOutput + } } } @@ -593,6 +635,14 @@ func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest [] return generateResult.Exists() && !generateResult.Bool() } +func shouldBuildResponsesWebsocketFullSnapshot(normalizedRequestJSON []byte, allowIncrementalInputWithPreviousResponseID bool) bool { + if !allowIncrementalInputWithPreviousResponseID { + return false + } + prev := strings.TrimSpace(gjson.GetBytes(normalizedRequestJSON, "previous_response_id").String()) + return prev != "" +} + func writeResponsesWebsocketSyntheticPrewarm( c *gin.Context, conn *websocket.Conn, @@ -716,21 +766,23 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( wsBodyLog *strings.Builder, sessionID string, toolPairState *websocketToolPairState, -) ([]byte, error) { +) ([]byte, int, error) { completed := false completedOutput := []byte("[]") + terminalStatusCode := 0 for { select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) - return completedOutput, c.Request.Context().Err() + return completedOutput, terminalStatusCode, c.Request.Context().Err() case errMsg, ok := <-errs: if !ok { errs = nil continue } if errMsg != nil { + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -750,7 +802,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( // errWrite, // ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } if errMsg != nil { @@ -758,7 +810,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( } else { cancel(nil) } - return completedOutput, nil + return completedOutput, terminalStatusCode, nil case chunk, ok := <-data: if !ok { if !completed { @@ -766,6 +818,7 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( StatusCode: http.StatusRequestTimeout, Error: fmt.Errorf("stream closed before response.completed"), } + terminalStatusCode = errMsg.StatusCode h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) @@ -785,13 +838,13 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errMsg.Error) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } cancel(errMsg.Error) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } cancel(nil) - return completedOutput, nil + return completedOutput, terminalStatusCode, nil } payloads := websocketJSONPayloadsFromChunk(chunk) @@ -819,13 +872,22 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( errWrite, ) cancel(errWrite) - return completedOutput, errWrite + return completedOutput, terminalStatusCode, errWrite } } } } } +func shouldResetResponsesWebsocketAuthPin(statusCode int) bool { + switch statusCode { + case http.StatusTooManyRequests, http.StatusForbidden, http.StatusPaymentRequired, http.StatusUnauthorized: + return true + default: + return false + } +} + func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go index c0d695cabf..bd4ba62e2d 100644 --- a/sdk/api/handlers/openai/openai_responses_websocket_test.go +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -70,6 +70,27 @@ type websocketAuthCaptureExecutor struct { authIDs []string } +type websocketStatusError struct { + code int + msg string +} + +func (e websocketStatusError) Error() string { + if strings.TrimSpace(e.msg) != "" { + return e.msg + } + return fmt.Sprintf("status %d", e.code) +} + +func (e websocketStatusError) StatusCode() int { + return e.code +} + +type websocketQuotaSwitchExecutor struct { + mu sync.Mutex + authIDs []string +} + func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -107,6 +128,47 @@ func (e *websocketAuthCaptureExecutor) AuthIDs() []string { return append([]string(nil), e.authIDs...) } +func (e *websocketQuotaSwitchExecutor) Identifier() string { return "test-provider" } + +func (e *websocketQuotaSwitchExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) { + e.mu.Lock() + if auth != nil { + e.authIDs = append(e.authIDs, auth.ID) + } + e.mu.Unlock() + + if auth != nil && auth.ID == "auth-1" { + return nil, websocketStatusError{code: http.StatusTooManyRequests, msg: "quota exhausted"} + } + + chunks := make(chan coreexecutor.StreamChunk, 1) + chunks <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)} + close(chunks) + return &coreexecutor.StreamResult{Chunks: chunks}, nil +} + +func (e *websocketQuotaSwitchExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *websocketQuotaSwitchExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) { + return nil, errors.New("not implemented") +} + +func (e *websocketQuotaSwitchExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + return append([]string(nil), e.authIDs...) +} + func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" } func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -539,6 +601,32 @@ func TestSetWebsocketRequestBody(t *testing.T) { } } +func TestShouldResetResponsesWebsocketAuthPin(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + statusCode int + want bool + }{ + {name: "too_many_requests", statusCode: http.StatusTooManyRequests, want: true}, + {name: "forbidden", statusCode: http.StatusForbidden, want: true}, + {name: "payment_required", statusCode: http.StatusPaymentRequired, want: true}, + {name: "unauthorized", statusCode: http.StatusUnauthorized, want: true}, + {name: "internal_error", statusCode: http.StatusInternalServerError, want: false}, + {name: "zero", statusCode: 0, want: false}, + } + + for i := range cases { + tc := cases[i] + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + if got := shouldResetResponsesWebsocketAuthPin(tc.statusCode); got != tc.want { + t.Fatalf("shouldResetResponsesWebsocketAuthPin(%d) = %v, want %v", tc.statusCode, got, tc.want) + } + }) + } +} func TestRepairResponsesWebsocketToolCallsInsertsCachedOutput(t *testing.T) { state := newWebsocketToolPairState() @@ -672,7 +760,6 @@ func TestWebsocketToolPairStateConcurrentAccess(t *testing.T) { t.Fatal("concurrent tool pair state access timed out") } } - func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { gin.SetMode(gin.TestMode) @@ -700,7 +787,7 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { close(errCh) var bodyLog strings.Builder - completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( + completedOutput, statusCode, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket( ctx, conn, func(...interface{}) {}, @@ -714,6 +801,10 @@ func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) { serverErrCh <- err return } + if statusCode != 0 { + serverErrCh <- fmt.Errorf("status code = %d, want 0", statusCode) + return + } if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" { serverErrCh <- errors.New("completed output not captured") return @@ -947,6 +1038,89 @@ func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) { } } +func TestResponsesWebsocketClearsPinAndSwitchesAuthAfterQuotaError(t *testing.T) { + gin.SetMode(gin.TestMode) + + selector := &orderedWebsocketSelector{order: []string{"auth-1", "auth-2"}} + executor := &websocketQuotaSwitchExecutor{} + manager := coreauth.NewManager(nil, selector, nil) + manager.SetRetryConfig(0, 0, 1, 0) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth-1", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("Register auth-1: %v", err) + } + auth2 := &coreauth.Auth{ + ID: "auth-2", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Attributes: map[string]string{"websockets": "true"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("Register auth-2: %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager) + h := NewOpenAIResponsesAPIHandler(base) + router := gin.New() + router.GET("/v1/responses/ws", h.ResponsesWebsocket) + + server := httptest.NewServer(router) + defer server.Close() + + wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws" + conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil) + if err != nil { + t.Fatalf("dial websocket: %v", err) + } + defer func() { + if errClose := conn.Close(); errClose != nil { + t.Fatalf("close websocket: %v", errClose) + } + }() + + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`)); errWrite != nil { + t.Fatalf("write first websocket message: %v", errWrite) + } + _, firstPayload, errReadFirst := conn.ReadMessage() + if errReadFirst != nil { + t.Fatalf("read first websocket message: %v", errReadFirst) + } + if got := gjson.GetBytes(firstPayload, "type").String(); got != wsEventTypeError { + t.Fatalf("first payload type = %s, want %s", got, wsEventTypeError) + } + if got := gjson.GetBytes(firstPayload, "error.code").String(); got != "rate_limit_exceeded" { + t.Fatalf("first payload code = %s, want rate_limit_exceeded", got) + } + + if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-2"}]}`)); errWrite != nil { + t.Fatalf("write second websocket message: %v", errWrite) + } + _, secondPayload, errReadSecond := conn.ReadMessage() + if errReadSecond != nil { + t.Fatalf("read second websocket message: %v", errReadSecond) + } + if got := gjson.GetBytes(secondPayload, "type").String(); got != wsEventTypeCompleted { + t.Fatalf("second payload type = %s, want %s", got, wsEventTypeCompleted) + } + + if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-1" || got[1] != "auth-2" { + t.Fatalf("selected auth IDs = %v, want [auth-1 auth-2]", got) + } +} func TestResponsesWebsocketCompactionResetsTurnStateOnTranscriptReplacement(t *testing.T) { gin.SetMode(gin.TestMode) diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 142abe5797..23c69e6694 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -663,7 +663,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi return nil, &Error{Code: "executor_not_found", Message: "executor not registered"} } var lastErr error - for idx, execModel := range execModels { + for _, execModel := range execModels { resultModel := executionResultModel(routeModel, execModel, pooled) execReq := req execReq.Model = execModel @@ -704,18 +704,6 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi discardStreamChunks(streamResult.Chunks) return nil, bootstrapErr } - if idx < len(execModels)-1 { - rerr := &Error{Message: bootstrapErr.Error()} - if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { - rerr.HTTPStatus = se.StatusCode() - } - result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: rerr} - result.RetryAfter = retryAfterFromError(bootstrapErr) - m.MarkResult(ctx, result) - discardStreamChunks(streamResult.Chunks) - lastErr = bootstrapErr - continue - } rerr := &Error{Message: bootstrapErr.Error()} if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil { rerr.HTTPStatus = se.StatusCode() @@ -724,18 +712,16 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) - return nil, newStreamBootstrapError(bootstrapErr, streamResult.Headers) + lastErr = bootstrapErr + continue } if closed && len(buffered) == 0 { emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true} result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr} m.MarkResult(ctx, result) - if idx < len(execModels)-1 { - lastErr = emptyErr - continue - } - return nil, newStreamBootstrapError(emptyErr, streamResult.Headers) + lastErr = emptyErr + continue } remaining := streamResult.Chunks diff --git a/sdk/cliproxy/auth/conductor_stream_retry_test.go b/sdk/cliproxy/auth/conductor_stream_retry_test.go new file mode 100644 index 0000000000..fa599ce5bb --- /dev/null +++ b/sdk/cliproxy/auth/conductor_stream_retry_test.go @@ -0,0 +1,216 @@ +package auth + +import ( + "context" + "fmt" + "net/http" + "strings" + "sync" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// authAwareStreamExecutor is a test executor that returns different results per auth ID. +type authAwareStreamExecutor struct { + id string + + mu sync.Mutex + streamAuthIDs []string + streamErrors map[string]error // keyed by auth.ID + streamPayloads map[string][]byte // keyed by auth.ID + emptyStreamAuth map[string]struct{} // auth IDs that return empty (closed) stream +} + +func (e *authAwareStreamExecutor) Identifier() string { return e.id } + +func (e *authAwareStreamExecutor) Execute(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *authAwareStreamExecutor) ExecuteStream(_ context.Context, auth *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) { + e.mu.Lock() + e.streamAuthIDs = append(e.streamAuthIDs, auth.ID) + streamErr := e.streamErrors[auth.ID] + payload := e.streamPayloads[auth.ID] + _, isEmpty := e.emptyStreamAuth[auth.ID] + e.mu.Unlock() + + ch := make(chan cliproxyexecutor.StreamChunk, 1) + if streamErr != nil { + ch <- cliproxyexecutor.StreamChunk{Err: streamErr} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil + } + if isEmpty { + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil + } + if payload == nil { + payload = []byte(auth.ID) + } + ch <- cliproxyexecutor.StreamChunk{Payload: payload} + close(ch) + return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Auth": {auth.ID}}, Chunks: ch}, nil +} + +func (e *authAwareStreamExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *authAwareStreamExecutor) CountTokens(_ context.Context, _ *Auth, req cliproxyexecutor.Request, _ cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte(req.Model)}, nil +} + +func (e *authAwareStreamExecutor) HttpRequest(_ context.Context, _ *Auth, _ *http.Request) (*http.Response, error) { + return nil, &Error{HTTPStatus: http.StatusNotImplemented, Message: "not implemented"} +} + +func (e *authAwareStreamExecutor) StreamAuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.streamAuthIDs)) + copy(out, e.streamAuthIDs) + return out +} + +type streamRetryTestIDs struct { + provider string + model string + authIDs []string +} + +func newStreamRetryTestIDs(t *testing.T) streamRetryTestIDs { + t.Helper() + token := strings.ToLower(strings.NewReplacer("/", "-", " ", "-", ":", "-", "(", "-", ")", "-").Replace(t.Name())) + return streamRetryTestIDs{ + provider: "testprov-" + token, + model: "test-model-" + token, + authIDs: []string{fmt.Sprintf("auth-1-%s", token), fmt.Sprintf("auth-2-%s", token)}, + } +} + +func assertAttemptedAuthIDs(t *testing.T, got []string, want []string) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("stream auth IDs = %v, want exact attempts %v", got, want) + } + seen := make(map[string]int, len(got)) + for _, id := range got { + seen[id]++ + } + for _, id := range want { + if seen[id] != 1 { + t.Fatalf("stream auth IDs = %v, want exact attempts %v", got, want) + } + } +} + +func newMultiAuthTestManager(t *testing.T, model string, authIDs []string, executor *authAwareStreamExecutor) *Manager { + t.Helper() + m := NewManager(nil, nil, nil) + m.RegisterExecutor(executor) + + reg := registry.GetGlobalRegistry() + for _, id := range authIDs { + auth := &Auth{ + ID: id, + Provider: executor.id, + Status: StatusActive, + Attributes: map[string]string{ + "api_key": "key-" + id, + "provider_key": executor.id, + }, + } + if _, err := m.Register(context.Background(), auth); err != nil { + t.Fatalf("register auth %s: %v", id, err) + } + reg.RegisterClient(id, executor.id, []*registry.ModelInfo{{ID: model}}) + } + t.Cleanup(func() { + for _, id := range authIDs { + reg.UnregisterClient(id) + } + }) + return m +} + +func TestExecuteStream_RotatesAuthOnBootstrapError(t *testing.T) { + t.Parallel() + ids := newStreamRetryTestIDs(t) + executor := &authAwareStreamExecutor{ + id: ids.provider, + streamErrors: map[string]error{ + ids.authIDs[0]: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "rate limited"}, + }, + streamPayloads: map[string][]byte{ + ids.authIDs[1]: []byte("ok-from-auth-2"), + }, + } + m := newMultiAuthTestManager(t, ids.model, ids.authIDs, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{ids.provider}, cliproxyexecutor.Request{Model: ids.model}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + if string(payload) != "ok-from-auth-2" { + t.Fatalf("payload = %q, want %q", string(payload), "ok-from-auth-2") + } + assertAttemptedAuthIDs(t, executor.StreamAuthIDs(), ids.authIDs) +} + +func TestExecuteStream_RotatesAuthOnEmptyStream(t *testing.T) { + t.Parallel() + ids := newStreamRetryTestIDs(t) + executor := &authAwareStreamExecutor{ + id: ids.provider, + emptyStreamAuth: map[string]struct{}{ids.authIDs[0]: {}}, + streamPayloads: map[string][]byte{ + ids.authIDs[1]: []byte("ok-from-auth-2"), + }, + } + m := newMultiAuthTestManager(t, ids.model, ids.authIDs, executor) + + streamResult, err := m.ExecuteStream(context.Background(), []string{ids.provider}, cliproxyexecutor.Request{Model: ids.model}, cliproxyexecutor.Options{}) + if err != nil { + t.Fatalf("ExecuteStream error: %v", err) + } + var payload []byte + for chunk := range streamResult.Chunks { + if chunk.Err != nil { + t.Fatalf("unexpected stream error: %v", chunk.Err) + } + payload = append(payload, chunk.Payload...) + } + if string(payload) != "ok-from-auth-2" { + t.Fatalf("payload = %q, want %q", string(payload), "ok-from-auth-2") + } + assertAttemptedAuthIDs(t, executor.StreamAuthIDs(), ids.authIDs) +} + +func TestExecuteStream_AllAuthsFailReturnsError(t *testing.T) { + t.Parallel() + ids := newStreamRetryTestIDs(t) + executor := &authAwareStreamExecutor{ + id: ids.provider, + streamErrors: map[string]error{ + ids.authIDs[0]: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "rate limited"}, + ids.authIDs[1]: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "rate limited"}, + }, + } + m := newMultiAuthTestManager(t, ids.model, ids.authIDs, executor) + + _, err := m.ExecuteStream(context.Background(), []string{ids.provider}, cliproxyexecutor.Request{Model: ids.model}, cliproxyexecutor.Options{}) + if err == nil { + t.Fatal("expected error when all auths fail, got nil") + } + assertAttemptedAuthIDs(t, executor.StreamAuthIDs(), ids.authIDs) +} From 094abafeff0f4623bbcbe6f034f867e69c01f22a Mon Sep 17 00:00:00 2001 From: benjamin Date: Fri, 3 Apr 2026 16:20:27 +0800 Subject: [PATCH 2/4] fix(review): preserve bootstrap headers, honor NO_PROXY, and harden integration test --- .../codex_executor_account_id_test.go | 11 +- internal/runtime/executor/proxy_helpers.go | 47 ++------ .../runtime/executor/proxy_helpers_test.go | 34 ++++++ sdk/api/handlers/handlers.go | 6 +- .../handlers_stream_bootstrap_test.go | 100 ++++++++++++++++++ sdk/cliproxy/auth/conductor.go | 22 ++-- 6 files changed, 169 insertions(+), 51 deletions(-) diff --git a/internal/runtime/executor/codex_executor_account_id_test.go b/internal/runtime/executor/codex_executor_account_id_test.go index fb01187bb4..9dfdf44da8 100644 --- a/internal/runtime/executor/codex_executor_account_id_test.go +++ b/internal/runtime/executor/codex_executor_account_id_test.go @@ -1,7 +1,10 @@ +//go:build integration + package executor import ( "bufio" + "context" "encoding/json" "fmt" "io" @@ -11,6 +14,7 @@ import ( "os" "strings" "testing" + "time" "github.com/google/uuid" tls "github.com/refraction-networking/utls" @@ -136,11 +140,15 @@ func TestCodexAccountCheck(t *testing.T) { if accessToken == "" { t.Skip("skipping: CODEX_ACCESS_TOKEN not set") } + t.Helper() proxyURL := os.Getenv("CODEX_PROXY_URL") deviceID := uuid.NewString() targetURL := "https://chatgpt.com/backend-api/accounts/check/v4-2023-04-27?timezone_offset_min=-480" - req, err := http.NewRequest(http.MethodGet, targetURL, nil) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL, nil) if err != nil { t.Fatalf("build request: %v", err) } @@ -162,6 +170,7 @@ func TestCodexAccountCheck(t *testing.T) { client := &http.Client{ Transport: newUtlsTransport(proxyURL), + Timeout: 20 * time.Second, } resp, err := client.Do(req) diff --git a/internal/runtime/executor/proxy_helpers.go b/internal/runtime/executor/proxy_helpers.go index d047484e03..3503016ce0 100644 --- a/internal/runtime/executor/proxy_helpers.go +++ b/internal/runtime/executor/proxy_helpers.go @@ -12,12 +12,14 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/config" cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil" + "golang.org/x/net/http/httpproxy" log "github.com/sirupsen/logrus" ) var ( proxyHTTPTransportCache sync.Map // map[string]*cachedProxyTransport environmentProxyKeys = []string{"HTTP_PROXY", "HTTPS_PROXY", "ALL_PROXY", "http_proxy", "https_proxy", "all_proxy"} + environmentNoProxyKeys = []string{"NO_PROXY", "no_proxy"} environmentProxyTransportCache sync.Map // map[string]*http.Transport ) @@ -128,53 +130,20 @@ func environmentProxySignature() string { for _, key := range environmentProxyKeys { values = append(values, key+"="+strings.TrimSpace(os.Getenv(key))) } + for _, key := range environmentNoProxyKeys { + values = append(values, key+"="+strings.TrimSpace(os.Getenv(key))) + } return strings.Join(values, "|") } func environmentProxyFunc() func(*http.Request) (*url.URL, error) { - httpProxy := firstEnvironmentValue("HTTP_PROXY", "http_proxy") - httpsProxy := firstEnvironmentValue("HTTPS_PROXY", "https_proxy") - allProxy := firstEnvironmentValue("ALL_PROXY", "all_proxy") + cfg := httpproxy.FromEnvironment() + proxyFunc := cfg.ProxyFunc() return func(req *http.Request) (*url.URL, error) { if req == nil || req.URL == nil { return nil, nil } - - raw := "" - switch strings.ToLower(req.URL.Scheme) { - case "https": - raw = firstNonEmpty(httpsProxy, allProxy, httpProxy) - case "http": - raw = firstNonEmpty(httpProxy, allProxy, httpsProxy) - default: - raw = firstNonEmpty(allProxy, httpsProxy, httpProxy) - } - raw = strings.TrimSpace(raw) - if raw == "" { - return nil, nil - } - if !strings.Contains(raw, "://") { - raw = "http://" + raw - } - return url.Parse(raw) - } -} - -func firstEnvironmentValue(keys ...string) string { - for _, key := range keys { - if value := strings.TrimSpace(os.Getenv(key)); value != "" { - return value - } - } - return "" -} - -func firstNonEmpty(values ...string) string { - for _, value := range values { - if strings.TrimSpace(value) != "" { - return value - } + return proxyFunc(req.URL) } - return "" } diff --git a/internal/runtime/executor/proxy_helpers_test.go b/internal/runtime/executor/proxy_helpers_test.go index 4174dafdb9..e41ed58aa8 100644 --- a/internal/runtime/executor/proxy_helpers_test.go +++ b/internal/runtime/executor/proxy_helpers_test.go @@ -155,6 +155,40 @@ func TestNewProxyAwareHTTPClientExplicitProxyWinsOverEnvironmentProxy(t *testing } } +func TestNewProxyAwareHTTPClientHonorsNoProxy(t *testing.T) { + setEnvironmentProxy(t, "http://env-proxy.example.com:8080") + + oldNoProxy, hadNoProxy := os.LookupEnv("NO_PROXY") + if err := os.Setenv("NO_PROXY", "example.com"); err != nil { + t.Fatalf("Setenv(NO_PROXY): %v", err) + } + t.Cleanup(func() { + if hadNoProxy { + _ = os.Setenv("NO_PROXY", oldNoProxy) + return + } + _ = os.Unsetenv("NO_PROXY") + }) + + client := newProxyAwareHTTPClient(context.Background(), &config.Config{}, &cliproxyauth.Auth{}, 0) + + transport, ok := client.Transport.(*http.Transport) + if !ok { + t.Fatalf("transport type = %T, want *http.Transport", client.Transport) + } + req, errReq := http.NewRequest(http.MethodGet, "https://example.com", nil) + if errReq != nil { + t.Fatalf("NewRequest() error = %v", errReq) + } + proxyURL, errProxy := transport.Proxy(req) + if errProxy != nil { + t.Fatalf("transport.Proxy() error = %v", errProxy) + } + if proxyURL != nil { + t.Fatalf("proxy URL = %v, want nil for NO_PROXY match", proxyURL) + } +} + func TestNewProxyAwareHTTPClientReusesEnvironmentProxyTransport(t *testing.T) { setEnvironmentProxy(t, "http://env-proxy.example.com:8080") diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 8f41850b48..ad3a380d22 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -780,7 +780,11 @@ func streamResultFromError(err error) *coreexecutor.StreamResult { errCh := make(chan coreexecutor.StreamChunk, 1) errCh <- coreexecutor.StreamChunk{Err: err} close(errCh) - return &coreexecutor.StreamResult{Chunks: errCh} + var headers http.Header + if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil { + headers = cloneHeader(FilterUpstreamHeaders(he.Headers())) + } + return &coreexecutor.StreamResult{Headers: headers, Chunks: errCh} } func validateSSEDataJSON(chunk []byte) error { diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 61c0333227..0978b9fb2c 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -6,6 +6,7 @@ import ( "sync" "testing" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" @@ -137,6 +138,7 @@ type authAwareStreamExecutor struct { type invalidJSONStreamExecutor struct{} type splitResponsesEventStreamExecutor struct{} +type bootstrapHeaderStreamExecutor struct{} func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" } @@ -197,6 +199,48 @@ func (e *splitResponsesEventStreamExecutor) HttpRequest(ctx context.Context, aut } } +func (e *bootstrapHeaderStreamExecutor) Identifier() string { return "bootstrap-headers" } + +func (e *bootstrapHeaderStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *bootstrapHeaderStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) { + ch := make(chan coreexecutor.StreamChunk, 1) + ch <- coreexecutor.StreamChunk{ + Err: &coreauth.Error{ + Code: "rate_limited", + Message: "rate limited", + Retryable: true, + HTTPStatus: http.StatusTooManyRequests, + }, + } + close(ch) + return &coreexecutor.StreamResult{ + Headers: http.Header{ + "Retry-After": {"17"}, + "X-Upstream": {"bootstrap"}, + }, + Chunks: ch, + }, nil +} + +func (e *bootstrapHeaderStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *bootstrapHeaderStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *bootstrapHeaderStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + func (e *authAwareStreamExecutor) Identifier() string { return "codex" } func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { @@ -688,3 +732,59 @@ func TestExecuteStreamWithAuthManager_AllowsSplitOpenAIResponsesSSEEventLines(t t.Fatalf("unexpected second chunk.\nGot: %q\nWant: %q", got[1], expectedData) } } + +func TestExecuteStreamWithAuthManager_PreservesBootstrapHeadersOnImmediateWrappedError(t *testing.T) { + executor := &bootstrapHeaderStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth-bootstrap", + Provider: executor.Identifier(), + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "bootstrap@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + PassthroughHeaders: true, + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: 0, + }, + }, manager) + + dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + for chunk := range dataChan { + t.Fatalf("unexpected payload chunk: %q", string(chunk)) + } + + var gotErr *interfaces.ErrorMessage + for msg := range errChan { + if msg != nil { + gotErr = msg + } + } + if gotErr == nil || gotErr.Error == nil { + t.Fatalf("expected terminal error") + } + if gotErr.StatusCode != http.StatusTooManyRequests { + t.Fatalf("status = %d, want %d", gotErr.StatusCode, http.StatusTooManyRequests) + } + if gotErr.Addon == nil || gotErr.Addon.Get("Retry-After") != "17" { + t.Fatalf("addon headers = %#v, want Retry-After=17", gotErr.Addon) + } + if upstreamHeaders == nil || upstreamHeaders.Get("Retry-After") != "17" { + t.Fatalf("upstream headers = %#v, want Retry-After=17", upstreamHeaders) + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 23c69e6694..bac8b30b77 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -567,6 +567,16 @@ func (e *streamBootstrapError) Headers() http.Header { return cloneHTTPHeader(e.headers) } +func (e *streamBootstrapError) StatusCode() int { + if e == nil || e.cause == nil { + return 0 + } + if se, ok := e.cause.(interface{ StatusCode() int }); ok && se != nil { + return se.StatusCode() + } + return 0 +} + func streamErrorResult(headers http.Header, err error) *cliproxyexecutor.StreamResult { ch := make(chan cliproxyexecutor.StreamChunk, 1) ch <- cliproxyexecutor.StreamChunk{Err: err} @@ -712,7 +722,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi result.RetryAfter = retryAfterFromError(bootstrapErr) m.MarkResult(ctx, result) discardStreamChunks(streamResult.Chunks) - lastErr = bootstrapErr + lastErr = newStreamBootstrapError(bootstrapErr, streamResult.Headers) continue } @@ -720,7 +730,7 @@ func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor Provi emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true} result := Result{AuthID: auth.ID, Provider: provider, Model: resultModel, Success: false, Error: emptyErr} m.MarkResult(ctx, result) - lastErr = emptyErr + lastErr = newStreamBootstrapError(emptyErr, streamResult.Headers) continue } @@ -1317,10 +1327,6 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string return nil, requestInvalidErr } if lastErr != nil { - var bootstrapErr *streamBootstrapError - if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil { - return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil - } return nil, lastErr } return nil, &Error{Code: "auth_not_found", Message: "no auth available"} @@ -1331,10 +1337,6 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string return nil, requestInvalidErr } if lastErr != nil { - var bootstrapErr *streamBootstrapError - if errors.As(lastErr, &bootstrapErr) && bootstrapErr != nil { - return streamErrorResult(bootstrapErr.Headers(), bootstrapErr.cause), nil - } return nil, lastErr } return nil, errPick From 4d9fcee98fe8840844f1e79d770066bbb96b7da5 Mon Sep 17 00:00:00 2001 From: benjamin Date: Fri, 3 Apr 2026 18:15:47 +0800 Subject: [PATCH 3/4] fix(review): unify usage reporter finalization across executors --- .../runtime/executor/aistudio_executor.go | 6 +- .../runtime/executor/antigravity_executor.go | 9 +-- internal/runtime/executor/claude_executor.go | 6 +- internal/runtime/executor/codex_executor.go | 9 +-- .../executor/codex_websockets_executor.go | 6 +- .../runtime/executor/gemini_cli_executor.go | 6 +- internal/runtime/executor/gemini_executor.go | 6 +- .../executor/gemini_vertex_executor.go | 12 ++-- internal/runtime/executor/iflow_executor.go | 6 +- internal/runtime/executor/kimi_executor.go | 6 +- .../executor/openai_compat_executor.go | 6 +- internal/runtime/executor/qwen_executor.go | 6 +- internal/runtime/executor/usage_helpers.go | 14 ++++ .../runtime/executor/usage_helpers_test.go | 67 +++++++++++++++++++ 14 files changed, 109 insertions(+), 56 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index efdf74f044..9dc2891a17 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -116,8 +116,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, } baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) translatedReq, body, err := e.translateRequest(req, opts, false) if err != nil { @@ -176,8 +175,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth } baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) translatedReq, body, err := e.translateRequest(req, opts, true) if err != nil { diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 0947acdac6..2e6d10ffc5 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -211,8 +211,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") @@ -354,8 +353,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") @@ -747,8 +745,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("antigravity") diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 8e64d79a2f..0a71dc377a 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -102,8 +102,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("claude") // Use streaming translation to preserve function calling, except for claude. @@ -271,8 +270,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("claude") originalPayloadSource := req.Payload diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 10c510bab0..c37d05c7c3 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -92,8 +92,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanExecute) @@ -178,8 +177,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanCompact) @@ -266,8 +264,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat plan, err := e.prepareCodexRequestPlan(ctx, req, opts, codexPreparedRequestPlanExecuteStream) diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go index 5091b7292c..1a739e323a 100644 --- a/internal/runtime/executor/codex_websockets_executor.go +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -243,8 +243,7 @@ func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyaut } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("codex") @@ -453,8 +452,7 @@ func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *clipr } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("codex") diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 5e6547ce24..9300003075 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -113,8 +113,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") @@ -268,8 +267,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index dff863e0e7..dffae4ba3c 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -111,8 +111,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r apiKey, bearer := geminiCreds(auth) reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) // Official Gemini API via API key or OAuth bearer from := opts.SourceFormat @@ -220,8 +219,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A apiKey, bearer := geminiCreds(auth) reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini") diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 84d29357d0..8e21cc8e8e 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -302,8 +302,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) var body []byte @@ -429,8 +428,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini") @@ -535,8 +533,7 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini") @@ -660,8 +657,7 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("gemini") diff --git a/internal/runtime/executor/iflow_executor.go b/internal/runtime/executor/iflow_executor.go index 876fb0a232..85e839fb7a 100644 --- a/internal/runtime/executor/iflow_executor.go +++ b/internal/runtime/executor/iflow_executor.go @@ -87,8 +87,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("openai") @@ -191,8 +190,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("openai") diff --git a/internal/runtime/executor/kimi_executor.go b/internal/runtime/executor/kimi_executor.go index ab34221c36..0ffb0fb66e 100644 --- a/internal/runtime/executor/kimi_executor.go +++ b/internal/runtime/executor/kimi_executor.go @@ -77,8 +77,7 @@ func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req token := kimiCreds(auth) reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) to := sdktranslator.FromString("openai") originalPayloadSource := req.Payload @@ -178,8 +177,7 @@ func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut token := kimiCreds(auth) reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) to := sdktranslator.FromString("openai") originalPayloadSource := req.Payload diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index a0d3037065..30f724218c 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -73,8 +73,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) baseURL, apiKey := e.resolveCredentials(auth) if baseURL == "" { @@ -181,8 +180,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy baseModel := thinking.ParseSuffix(req.Model).ModelName reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) baseURL, apiKey := e.resolveCredentials(auth) if baseURL == "" { diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 04eb73ac22..da8fe624e4 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -229,8 +229,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("openai") @@ -333,8 +332,7 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut } reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) - defer reporter.trackFailure(ctx, &err) - defer reporter.ensurePublished(ctx) + defer reporter.finalize(ctx, &err) from := opts.SourceFormat to := sdktranslator.FromString("openai") diff --git a/internal/runtime/executor/usage_helpers.go b/internal/runtime/executor/usage_helpers.go index de2f2e527e..83111ab35a 100644 --- a/internal/runtime/executor/usage_helpers.go +++ b/internal/runtime/executor/usage_helpers.go @@ -59,6 +59,20 @@ func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) { } } +// finalize publishes exactly one terminal usage record based on the final +// function error state. Success paths emit a default record when no explicit +// usage detail was published; failure paths emit a failed record. +func (r *usageReporter) finalize(ctx context.Context, errPtr *error) { + if r == nil { + return + } + if errPtr != nil && *errPtr != nil { + r.publishFailure(ctx) + return + } + r.ensurePublished(ctx) +} + func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) { if r == nil { return diff --git a/internal/runtime/executor/usage_helpers_test.go b/internal/runtime/executor/usage_helpers_test.go index 785f72b47c..70609b8281 100644 --- a/internal/runtime/executor/usage_helpers_test.go +++ b/internal/runtime/executor/usage_helpers_test.go @@ -1,6 +1,8 @@ package executor import ( + "context" + "errors" "testing" "time" @@ -62,3 +64,68 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) { t.Fatalf("latency = %v, want <= 3s", record.Latency) } } + +type usageCapturePlugin struct { + ch chan usage.Record +} + +func (p *usageCapturePlugin) HandleUsage(_ context.Context, record usage.Record) { + select { + case p.ch <- record: + default: + } +} + +func awaitUsageRecord(t *testing.T, ch <-chan usage.Record, provider string) usage.Record { + t.Helper() + + deadline := time.After(2 * time.Second) + for { + select { + case record := <-ch: + if record.Provider == provider { + return record + } + case <-deadline: + t.Fatalf("timed out waiting for usage record for provider %q", provider) + } + } +} + +func TestUsageReporterFinalizePublishesFailureOnError(t *testing.T) { + plugin := &usageCapturePlugin{ch: make(chan usage.Record, 8)} + usage.RegisterPlugin(plugin) + + reporter := &usageReporter{ + provider: "test-finalize-failure", + model: "model", + requestedAt: time.Now(), + } + err := errors.New("boom") + + reporter.finalize(context.Background(), &err) + + record := awaitUsageRecord(t, plugin.ch, reporter.provider) + if !record.Failed { + t.Fatalf("record.Failed = false, want true") + } +} + +func TestUsageReporterFinalizePublishesSuccessWithoutError(t *testing.T) { + plugin := &usageCapturePlugin{ch: make(chan usage.Record, 8)} + usage.RegisterPlugin(plugin) + + reporter := &usageReporter{ + provider: "test-finalize-success", + model: "model", + requestedAt: time.Now(), + } + var err error + + reporter.finalize(context.Background(), &err) + + record := awaitUsageRecord(t, plugin.ch, reporter.provider) + if record.Failed { + t.Fatalf("record.Failed = true, want false") + } +} From 1e2a75966d03ef9a211ff0f7d1392abef458450b Mon Sep 17 00:00:00 2001 From: benjamin Date: Fri, 3 Apr 2026 18:34:07 +0800 Subject: [PATCH 4/4] test: isolate usage reporter tests from global manager --- internal/runtime/executor/usage_helpers.go | 14 +++++- .../runtime/executor/usage_helpers_test.go | 43 +++++-------------- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/internal/runtime/executor/usage_helpers.go b/internal/runtime/executor/usage_helpers.go index 83111ab35a..b4d9867417 100644 --- a/internal/runtime/executor/usage_helpers.go +++ b/internal/runtime/executor/usage_helpers.go @@ -24,6 +24,7 @@ type usageReporter struct { source string requestedAt time.Time once sync.Once + publishFn func(context.Context, usage.Record) } func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter { @@ -34,6 +35,7 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox requestedAt: time.Now(), apiKey: apiKey, source: resolveUsageSource(auth, apiKey), + publishFn: usage.PublishRecord, } if auth != nil { reporter.authID = auth.ID @@ -87,7 +89,7 @@ func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Det return } r.once.Do(func() { - usage.PublishRecord(ctx, r.buildRecord(detail, failed)) + publishUsageRecord(ctx, r.publishFn, r.buildRecord(detail, failed)) }) } @@ -100,10 +102,18 @@ func (r *usageReporter) ensurePublished(ctx context.Context) { return } r.once.Do(func() { - usage.PublishRecord(ctx, r.buildRecord(usage.Detail{}, false)) + publishUsageRecord(ctx, r.publishFn, r.buildRecord(usage.Detail{}, false)) }) } +func publishUsageRecord(ctx context.Context, fn func(context.Context, usage.Record), record usage.Record) { + if fn == nil { + usage.PublishRecord(ctx, record) + return + } + fn(ctx, record) +} + func (r *usageReporter) buildRecord(detail usage.Detail, failed bool) usage.Record { if r == nil { return usage.Record{Detail: detail, Failed: failed} diff --git a/internal/runtime/executor/usage_helpers_test.go b/internal/runtime/executor/usage_helpers_test.go index 70609b8281..ad9dcc6b98 100644 --- a/internal/runtime/executor/usage_helpers_test.go +++ b/internal/runtime/executor/usage_helpers_test.go @@ -65,66 +65,43 @@ func TestUsageReporterBuildRecordIncludesLatency(t *testing.T) { } } -type usageCapturePlugin struct { - ch chan usage.Record -} - -func (p *usageCapturePlugin) HandleUsage(_ context.Context, record usage.Record) { - select { - case p.ch <- record: - default: - } -} - -func awaitUsageRecord(t *testing.T, ch <-chan usage.Record, provider string) usage.Record { - t.Helper() - - deadline := time.After(2 * time.Second) - for { - select { - case record := <-ch: - if record.Provider == provider { - return record - } - case <-deadline: - t.Fatalf("timed out waiting for usage record for provider %q", provider) - } - } -} - func TestUsageReporterFinalizePublishesFailureOnError(t *testing.T) { - plugin := &usageCapturePlugin{ch: make(chan usage.Record, 8)} - usage.RegisterPlugin(plugin) + records := make(chan usage.Record, 2) reporter := &usageReporter{ provider: "test-finalize-failure", model: "model", requestedAt: time.Now(), + publishFn: func(_ context.Context, record usage.Record) { + records <- record + }, } err := errors.New("boom") reporter.finalize(context.Background(), &err) - record := awaitUsageRecord(t, plugin.ch, reporter.provider) + record := <-records if !record.Failed { t.Fatalf("record.Failed = false, want true") } } func TestUsageReporterFinalizePublishesSuccessWithoutError(t *testing.T) { - plugin := &usageCapturePlugin{ch: make(chan usage.Record, 8)} - usage.RegisterPlugin(plugin) + records := make(chan usage.Record, 2) reporter := &usageReporter{ provider: "test-finalize-success", model: "model", requestedAt: time.Now(), + publishFn: func(_ context.Context, record usage.Record) { + records <- record + }, } var err error reporter.finalize(context.Background(), &err) - record := awaitUsageRecord(t, plugin.ch, reporter.provider) + record := <-records if record.Failed { t.Fatalf("record.Failed = true, want false") }