diff --git a/config/config.go b/config/config.go index 6601681..20c53c1 100644 --- a/config/config.go +++ b/config/config.go @@ -68,6 +68,7 @@ func (c *CacheConfig) Label() string { type Config struct { Port int AdminSecret string + MaxRequestBodySize int Database DatabaseConfig Cache CacheConfig UseWebsocket bool // 是否启用 WebSocket 传输 @@ -81,7 +82,10 @@ func Load(envPath string) (*Config, error) { } _ = godotenv.Load(envPath) - cfg := &Config{Port: 8080} + cfg := &Config{ + Port: 8080, + MaxRequestBodySize: 32 * 1024 * 1024, + } // Web服务端口 if port := os.Getenv("CODEX_PORT"); port != "" { @@ -90,6 +94,11 @@ func Load(envPath string) (*Config, error) { fmt.Sscanf(port, "%d", &cfg.Port) } cfg.AdminSecret = strings.TrimSpace(os.Getenv("ADMIN_SECRET")) + if v := strings.TrimSpace(os.Getenv("CODEX_MAX_REQUEST_BODY_SIZE_MB")); v != "" { + if mb, err := strconv.Atoi(v); err == nil && mb > 0 { + cfg.MaxRequestBodySize = mb * 1024 * 1024 + } + } // WebSocket 配置 if v := strings.ToLower(strings.TrimSpace(os.Getenv("USE_WEBSOCKET"))); v == "true" || v == "1" { diff --git a/config/config_test.go b/config/config_test.go index 3035be4..02bd076 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -5,6 +5,7 @@ import "testing" func TestLoadDefaultsToPostgresAndRedis(t *testing.T) { keys := []string{ "CODEX_PORT", + "CODEX_MAX_REQUEST_BODY_SIZE_MB", "PORT", "ADMIN_SECRET", "DATABASE_DRIVER", @@ -48,11 +49,15 @@ func TestLoadDefaultsToPostgresAndRedis(t *testing.T) { if got := cfg.Port; got != 8080 { t.Fatalf("Port = %d, want %d", got, 8080) } + if got := cfg.MaxRequestBodySize; got != 32*1024*1024 { + t.Fatalf("MaxRequestBodySize = %d, want %d", got, 32*1024*1024) + } } func TestLoadAllowsExplicitSQLiteAndMemory(t *testing.T) { keys := []string{ "CODEX_PORT", + "CODEX_MAX_REQUEST_BODY_SIZE_MB", "PORT", "ADMIN_SECRET", "DATABASE_DRIVER", @@ -95,6 +100,7 @@ func TestLoadAllowsExplicitSQLiteAndMemory(t *testing.T) { func TestLoadReadsAdminSecretFromEnv(t *testing.T) { keys := []string{ "CODEX_PORT", + "CODEX_MAX_REQUEST_BODY_SIZE_MB", "PORT", "ADMIN_SECRET", "DATABASE_DRIVER", @@ -127,3 +133,40 @@ func TestLoadReadsAdminSecretFromEnv(t *testing.T) { t.Fatalf("AdminSecret = %q, want %q", got, "from-env-secret") } } + +func TestLoadReadsMaxRequestBodySizeFromEnv(t *testing.T) { + keys := []string{ + "CODEX_PORT", + "CODEX_MAX_REQUEST_BODY_SIZE_MB", + "PORT", + "ADMIN_SECRET", + "DATABASE_DRIVER", + "DATABASE_PATH", + "DATABASE_HOST", + "DATABASE_PORT", + "DATABASE_USER", + "DATABASE_PASSWORD", + "DATABASE_NAME", + "DATABASE_SSLMODE", + "CACHE_DRIVER", + "REDIS_ADDR", + "REDIS_PASSWORD", + "REDIS_DB", + } + for _, key := range keys { + t.Setenv(key, "") + } + + t.Setenv("DATABASE_HOST", "postgres") + t.Setenv("REDIS_ADDR", "redis:6379") + t.Setenv("CODEX_MAX_REQUEST_BODY_SIZE_MB", "64") + + cfg, err := Load("__not_exists__.env") + if err != nil { + t.Fatalf("Load() 返回错误: %v", err) + } + + if got := cfg.MaxRequestBodySize; got != 64*1024*1024 { + t.Fatalf("MaxRequestBodySize = %d, want %d", got, 64*1024*1024) + } +} diff --git a/main.go b/main.go index 0b6f305..5a05477 100644 --- a/main.go +++ b/main.go @@ -166,12 +166,13 @@ func main() { r.Use(api.RecoveryMiddleware()) r.Use(api.RequestContextMiddleware()) r.Use(api.VersionMiddleware()) + security.MaxRequestBodySize = cfg.MaxRequestBodySize + r.Use(security.RequestSizeLimiter(int64(security.MaxRequestBodySize))) r.Use(api.BodyCacheMiddleware()) r.Use(api.CORSMiddleware()) r.Use(api.SecurityHeadersMiddleware()) r.Use(loggerMiddleware()) r.Use(security.SecurityHeadersMiddleware()) - r.Use(security.RequestSizeLimiter(security.MaxRequestBodySize)) // handler 不再接收 cfg.APIKeys // 从环境变量读取 Codex 画像与 Beta 配置。 diff --git a/proxy/anthropic.go b/proxy/anthropic.go index 74a9c2c..a59bcf0 100644 --- a/proxy/anthropic.go +++ b/proxy/anthropic.go @@ -438,7 +438,7 @@ func convertAnthropicTools(tools []anthropicTool) []any { if len(t.InputSchema) > 0 { var params map[string]any if json.Unmarshal(t.InputSchema, ¶ms) == nil { - stripUnsupportedSchemaKeys(params) + sanitizeSchemaForUpstream(params) item["parameters"] = params } } diff --git a/proxy/handler.go b/proxy/handler.go index ca4d6eb..4eb1a6e 100644 --- a/proxy/handler.go +++ b/proxy/handler.go @@ -812,7 +812,7 @@ func (h *Handler) ResponsesCompact(c *gin.Context) { rawBody, _ = sjson.SetBytes(rawBody, "stream", false) // 准备上游请求体 - codexBody, _ := PrepareResponsesBody(rawBody) + codexBody, _ := PrepareCompactResponsesBody(rawBody) // 带重试的上游请求 maxRetries := h.getMaxRetries() diff --git a/proxy/translator.go b/proxy/translator.go index 34990ff..009eb80 100644 --- a/proxy/translator.go +++ b/proxy/translator.go @@ -25,10 +25,10 @@ type openAIRequest struct { // openAIMessage 表示一条 OpenAI 消息 type openAIMessage struct { - Role string `json:"role"` - Content json.RawMessage `json:"content"` // string 或 []contentPart + Role string `json:"role"` + Content json.RawMessage `json:"content"` // string 或 []contentPart ToolCalls []openAIToolCall `json:"tool_calls,omitempty"` - ToolCallID string `json:"tool_call_id,omitempty"` + ToolCallID string `json:"tool_call_id,omitempty"` } // openAIToolCall 表示 assistant 消息中的工具调用 @@ -339,9 +339,9 @@ func PrepareResponsesBody(rawBody []byte) ([]byte, string) { } } } - // 递归清理不支持的 JSON Schema 关键字 + // 递归清理不支持的 JSON Schema 关键字,并修正上游要求的结构 if params, ok := toolMap["parameters"].(map[string]any); ok { - stripUnsupportedSchemaKeys(params) + sanitizeSchemaForUpstream(params) } } } @@ -389,6 +389,16 @@ func PrepareResponsesBody(rawBody []byte) ([]byte, string) { return result, expandedInputRaw } +// PrepareCompactResponsesBody 将 /responses/compact 请求转换为上游可接受的格式。 +// 它复用通用 Responses 预处理,但会移除 compact 端点不接受的自动注入字段。 +func PrepareCompactResponsesBody(rawBody []byte) ([]byte, string) { + body, expandedInputRaw := PrepareResponsesBody(rawBody) + body, _ = sjson.DeleteBytes(body, "include") + body, _ = sjson.DeleteBytes(body, "store") + body, _ = sjson.DeleteBytes(body, "stream") + return body, expandedInputRaw +} + // normalizeReasoningEffort 将 reasoning_effort 钳位到上游支持的值 func normalizeReasoningEffort(effort string) string { if effort == "" { @@ -577,7 +587,7 @@ func convertToolsToCodexFormat(rawTools []json.RawMessage) []any { if len(parsed.Function.Parameters) > 0 { var params map[string]any if json.Unmarshal(parsed.Function.Parameters, ¶ms) == nil { - stripUnsupportedSchemaKeys(params) + sanitizeSchemaForUpstream(params) item["parameters"] = params } } @@ -692,6 +702,64 @@ func stripUnsupportedSchemaKeys(schema map[string]interface{}) { } } +func sanitizeSchemaForUpstream(schema map[string]interface{}) { + stripUnsupportedSchemaKeys(schema) + ensureArrayItems(schema) +} + +// ensureArrayItems 递归为缺失 items 的数组 schema 补上空 schema, +// 兼容上游对 array 必须声明 items 的校验。 +func ensureArrayItems(schema map[string]interface{}) { + if schemaDeclaresArray(schema) { + if _, ok := schema["items"]; !ok { + schema["items"] = map[string]interface{}{} + } + } + if props, ok := schema["properties"].(map[string]interface{}); ok { + for _, v := range props { + if sub, ok := v.(map[string]interface{}); ok { + ensureArrayItems(sub) + } + } + } + if items, ok := schema["items"].(map[string]interface{}); ok { + ensureArrayItems(items) + } + for _, key := range []string{"allOf", "anyOf", "oneOf"} { + if arr, ok := schema[key].([]interface{}); ok { + for _, item := range arr { + if sub, ok := item.(map[string]interface{}); ok { + ensureArrayItems(sub) + } + } + } + } + if addProps, ok := schema["additionalProperties"].(map[string]interface{}); ok { + ensureArrayItems(addProps) + } + if defs, ok := schema["$defs"].(map[string]interface{}); ok { + for _, v := range defs { + if sub, ok := v.(map[string]interface{}); ok { + ensureArrayItems(sub) + } + } + } +} + +func schemaDeclaresArray(schema map[string]interface{}) bool { + switch t := schema["type"].(type) { + case string: + return t == "array" + case []interface{}: + for _, item := range t { + if s, ok := item.(string); ok && s == "array" { + return true + } + } + } + return false +} + // ==================== 响应翻译: Codex SSE → OpenAI SSE ==================== // UsageInfo token 使用统计 diff --git a/proxy/translator_test.go b/proxy/translator_test.go index 480278f..6800a09 100644 --- a/proxy/translator_test.go +++ b/proxy/translator_test.go @@ -68,6 +68,119 @@ func TestTranslateRequest_PreservesSupportedServiceTier(t *testing.T) { } } +func TestTranslateRequest_FillsMissingArrayItemsInToolSchema(t *testing.T) { + raw := []byte(`{ + "model":"gpt-5.4", + "messages":[{"role":"user","content":"test"}], + "tools":[ + { + "type":"function", + "function":{ + "name":"godot-mcp_node_signal", + "parameters":{ + "type":"object", + "properties":{ + "args":{"type":"array"} + } + } + } + } + ] + }`) + + got, err := TranslateRequest(raw) + if err != nil { + t.Fatalf("TranslateRequest returned error: %v", err) + } + + items := gjson.GetBytes(got, "tools.0.parameters.properties.args.items") + if !items.Exists() || items.Type != gjson.JSON { + t.Fatalf("expected array schema items object to be injected, got %s", items.Raw) + } +} + +func TestPrepareResponsesBody_FillsMissingArrayItemsInToolSchema(t *testing.T) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":"test", + "tools":[ + { + "type":"function", + "name":"godot-mcp_node_signal", + "parameters":{ + "type":"object", + "properties":{ + "args":{"type":"array"} + } + } + } + ] + }`) + + got, _ := PrepareResponsesBody(raw) + + items := gjson.GetBytes(got, "tools.0.parameters.properties.args.items") + if !items.Exists() || items.Type != gjson.JSON { + t.Fatalf("expected array schema items object to be injected, got %s", items.Raw) + } +} + +func TestPrepareResponsesBody_DefaultsIncludeForResponses(t *testing.T) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":"test" + }`) + + got, _ := PrepareResponsesBody(raw) + + include := gjson.GetBytes(got, "include") + if !include.Exists() || len(include.Array()) != 1 || include.Array()[0].String() != "reasoning.encrypted_content" { + t.Fatalf("expected default include for responses, got %s", include.Raw) + } + if stream := gjson.GetBytes(got, "stream"); !stream.Exists() || !stream.Bool() { + t.Fatalf("expected stream to be forced for responses, got %s", stream.Raw) + } + if store := gjson.GetBytes(got, "store"); !store.Exists() || store.Bool() { + t.Fatalf("expected store=false for responses, got %s", store.Raw) + } +} + +func TestPrepareCompactResponsesBody_RemovesUnsupportedInjectedFields(t *testing.T) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":"test" + }`) + + got, _ := PrepareCompactResponsesBody(raw) + + for _, field := range []string{"include", "store", "stream"} { + if gjson.GetBytes(got, field).Exists() { + t.Fatalf("expected %s to be removed for compact body", field) + } + } + input := gjson.GetBytes(got, "input") + if !input.Exists() || !input.IsArray() || len(input.Array()) != 1 { + t.Fatalf("expected compact input to remain normalized, got %s", input.Raw) + } + if input.Array()[0].Get("content").String() != "test" { + t.Fatalf("expected compact input content to be preserved, got %s", input.Raw) + } +} + +func TestPrepareCompactResponsesBody_RemovesClientSuppliedInclude(t *testing.T) { + raw := []byte(`{ + "model":"gpt-5.4", + "input":"test", + "include":["reasoning.encrypted_content"] + }`) + + got, _ := PrepareCompactResponsesBody(raw) + + if gjson.GetBytes(got, "include").Exists() { + t.Fatalf("expected client-supplied include to be removed for compact body, got %s", string(got)) + } +} + // ==================== Function Calling 测试 ==================== func TestConvertMessagesToInput_ToolRole(t *testing.T) { diff --git a/security/validator.go b/security/validator.go index 8681fee..7bf0d61 100644 --- a/security/validator.go +++ b/security/validator.go @@ -14,12 +14,15 @@ const ( MaxEmailLength = 255 MaxProxyURLLength = 500 MaxTokenLength = 8192 - MaxRequestBodySize = 10 * 1024 * 1024 // 10MB MaxHeaderSize = 16 * 1024 // 16KB AllowedModelPattern = `^[a-zA-Z0-9._-]+$` AllowedEndpointPattern = `^[a-zA-Z0-9/_-]+$` ) +const DefaultMaxRequestBodySize = 32 * 1024 * 1024 // 32MB + +var MaxRequestBodySize = DefaultMaxRequestBodySize + // Dangerous patterns for XSS prevention var ( xssPatterns = []*regexp.Regexp{