Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ func (c *CacheConfig) Label() string {
type Config struct {
Port int
AdminSecret string
MaxRequestBodySize int
Database DatabaseConfig
Cache CacheConfig
UseWebsocket bool // 是否启用 WebSocket 传输
Expand All @@ -81,7 +82,10 @@ func Load(envPath string) (*Config, error) {
}
_ = godotenv.Load(envPath)

cfg := &Config{Port: 8080}
cfg := &Config{
Port: 8080,
MaxRequestBodySize: 32 * 1024 * 1024,
}

// Web服务端口
if port := os.Getenv("CODEX_PORT"); port != "" {
Expand All @@ -90,6 +94,11 @@ func Load(envPath string) (*Config, error) {
fmt.Sscanf(port, "%d", &cfg.Port)
}
cfg.AdminSecret = strings.TrimSpace(os.Getenv("ADMIN_SECRET"))
if v := strings.TrimSpace(os.Getenv("CODEX_MAX_REQUEST_BODY_SIZE_MB")); v != "" {
if mb, err := strconv.Atoi(v); err == nil && mb > 0 {
cfg.MaxRequestBodySize = mb * 1024 * 1024
}
}

// WebSocket 配置
if v := strings.ToLower(strings.TrimSpace(os.Getenv("USE_WEBSOCKET"))); v == "true" || v == "1" {
Expand Down
43 changes: 43 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "testing"
func TestLoadDefaultsToPostgresAndRedis(t *testing.T) {
keys := []string{
"CODEX_PORT",
"CODEX_MAX_REQUEST_BODY_SIZE_MB",
"PORT",
"ADMIN_SECRET",
"DATABASE_DRIVER",
Expand Down Expand Up @@ -48,11 +49,15 @@ func TestLoadDefaultsToPostgresAndRedis(t *testing.T) {
if got := cfg.Port; got != 8080 {
t.Fatalf("Port = %d, want %d", got, 8080)
}
if got := cfg.MaxRequestBodySize; got != 32*1024*1024 {
t.Fatalf("MaxRequestBodySize = %d, want %d", got, 32*1024*1024)
}
}

func TestLoadAllowsExplicitSQLiteAndMemory(t *testing.T) {
keys := []string{
"CODEX_PORT",
"CODEX_MAX_REQUEST_BODY_SIZE_MB",
"PORT",
"ADMIN_SECRET",
"DATABASE_DRIVER",
Expand Down Expand Up @@ -95,6 +100,7 @@ func TestLoadAllowsExplicitSQLiteAndMemory(t *testing.T) {
func TestLoadReadsAdminSecretFromEnv(t *testing.T) {
keys := []string{
"CODEX_PORT",
"CODEX_MAX_REQUEST_BODY_SIZE_MB",
"PORT",
"ADMIN_SECRET",
"DATABASE_DRIVER",
Expand Down Expand Up @@ -127,3 +133,40 @@ func TestLoadReadsAdminSecretFromEnv(t *testing.T) {
t.Fatalf("AdminSecret = %q, want %q", got, "from-env-secret")
}
}

func TestLoadReadsMaxRequestBodySizeFromEnv(t *testing.T) {
keys := []string{
"CODEX_PORT",
"CODEX_MAX_REQUEST_BODY_SIZE_MB",
"PORT",
"ADMIN_SECRET",
"DATABASE_DRIVER",
"DATABASE_PATH",
"DATABASE_HOST",
"DATABASE_PORT",
"DATABASE_USER",
"DATABASE_PASSWORD",
"DATABASE_NAME",
"DATABASE_SSLMODE",
"CACHE_DRIVER",
"REDIS_ADDR",
"REDIS_PASSWORD",
"REDIS_DB",
}
for _, key := range keys {
t.Setenv(key, "")
}

t.Setenv("DATABASE_HOST", "postgres")
t.Setenv("REDIS_ADDR", "redis:6379")
t.Setenv("CODEX_MAX_REQUEST_BODY_SIZE_MB", "64")

cfg, err := Load("__not_exists__.env")
if err != nil {
t.Fatalf("Load() 返回错误: %v", err)
}

if got := cfg.MaxRequestBodySize; got != 64*1024*1024 {
t.Fatalf("MaxRequestBodySize = %d, want %d", got, 64*1024*1024)
}
}
3 changes: 2 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,12 +166,13 @@ func main() {
r.Use(api.RecoveryMiddleware())
r.Use(api.RequestContextMiddleware())
r.Use(api.VersionMiddleware())
security.MaxRequestBodySize = cfg.MaxRequestBodySize
r.Use(security.RequestSizeLimiter(int64(security.MaxRequestBodySize)))
r.Use(api.BodyCacheMiddleware())
r.Use(api.CORSMiddleware())
r.Use(api.SecurityHeadersMiddleware())
r.Use(loggerMiddleware())
r.Use(security.SecurityHeadersMiddleware())
r.Use(security.RequestSizeLimiter(security.MaxRequestBodySize))

// handler 不再接收 cfg.APIKeys
// 从环境变量读取 Codex 画像与 Beta 配置。
Expand Down
2 changes: 1 addition & 1 deletion proxy/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,7 @@ func convertAnthropicTools(tools []anthropicTool) []any {
if len(t.InputSchema) > 0 {
var params map[string]any
if json.Unmarshal(t.InputSchema, &params) == nil {
stripUnsupportedSchemaKeys(params)
sanitizeSchemaForUpstream(params)
item["parameters"] = params
}
}
Expand Down
2 changes: 1 addition & 1 deletion proxy/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ func (h *Handler) ResponsesCompact(c *gin.Context) {
rawBody, _ = sjson.SetBytes(rawBody, "stream", false)

// 准备上游请求体
codexBody, _ := PrepareResponsesBody(rawBody)
codexBody, _ := PrepareCompactResponsesBody(rawBody)

// 带重试的上游请求
maxRetries := h.getMaxRetries()
Expand Down
80 changes: 74 additions & 6 deletions proxy/translator.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ type openAIRequest struct {

// openAIMessage 表示一条 OpenAI 消息
type openAIMessage struct {
Role string `json:"role"`
Content json.RawMessage `json:"content"` // string 或 []contentPart
Role string `json:"role"`
Content json.RawMessage `json:"content"` // string 或 []contentPart
ToolCalls []openAIToolCall `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}

// openAIToolCall 表示 assistant 消息中的工具调用
Expand Down Expand Up @@ -339,9 +339,9 @@ func PrepareResponsesBody(rawBody []byte) ([]byte, string) {
}
}
}
// 递归清理不支持的 JSON Schema 关键字
// 递归清理不支持的 JSON Schema 关键字,并修正上游要求的结构
if params, ok := toolMap["parameters"].(map[string]any); ok {
stripUnsupportedSchemaKeys(params)
sanitizeSchemaForUpstream(params)
}
}
}
Expand Down Expand Up @@ -389,6 +389,16 @@ func PrepareResponsesBody(rawBody []byte) ([]byte, string) {
return result, expandedInputRaw
}

// PrepareCompactResponsesBody 将 /responses/compact 请求转换为上游可接受的格式。
// 它复用通用 Responses 预处理,但会移除 compact 端点不接受的自动注入字段。
func PrepareCompactResponsesBody(rawBody []byte) ([]byte, string) {
body, expandedInputRaw := PrepareResponsesBody(rawBody)
body, _ = sjson.DeleteBytes(body, "include")
body, _ = sjson.DeleteBytes(body, "store")
body, _ = sjson.DeleteBytes(body, "stream")
return body, expandedInputRaw
}

// normalizeReasoningEffort 将 reasoning_effort 钳位到上游支持的值
func normalizeReasoningEffort(effort string) string {
if effort == "" {
Expand Down Expand Up @@ -577,7 +587,7 @@ func convertToolsToCodexFormat(rawTools []json.RawMessage) []any {
if len(parsed.Function.Parameters) > 0 {
var params map[string]any
if json.Unmarshal(parsed.Function.Parameters, &params) == nil {
stripUnsupportedSchemaKeys(params)
sanitizeSchemaForUpstream(params)
item["parameters"] = params
}
}
Expand Down Expand Up @@ -692,6 +702,64 @@ func stripUnsupportedSchemaKeys(schema map[string]interface{}) {
}
}

func sanitizeSchemaForUpstream(schema map[string]interface{}) {
stripUnsupportedSchemaKeys(schema)
ensureArrayItems(schema)
}

// ensureArrayItems 递归为缺失 items 的数组 schema 补上空 schema,
// 兼容上游对 array 必须声明 items 的校验。
func ensureArrayItems(schema map[string]interface{}) {
if schemaDeclaresArray(schema) {
if _, ok := schema["items"]; !ok {
schema["items"] = map[string]interface{}{}
}
}
if props, ok := schema["properties"].(map[string]interface{}); ok {
for _, v := range props {
if sub, ok := v.(map[string]interface{}); ok {
ensureArrayItems(sub)
}
}
}
if items, ok := schema["items"].(map[string]interface{}); ok {
ensureArrayItems(items)
}
for _, key := range []string{"allOf", "anyOf", "oneOf"} {
if arr, ok := schema[key].([]interface{}); ok {
for _, item := range arr {
if sub, ok := item.(map[string]interface{}); ok {
ensureArrayItems(sub)
}
}
}
}
if addProps, ok := schema["additionalProperties"].(map[string]interface{}); ok {
ensureArrayItems(addProps)
}
if defs, ok := schema["$defs"].(map[string]interface{}); ok {
for _, v := range defs {
if sub, ok := v.(map[string]interface{}); ok {
ensureArrayItems(sub)
}
}
}
}

func schemaDeclaresArray(schema map[string]interface{}) bool {
switch t := schema["type"].(type) {
case string:
return t == "array"
case []interface{}:
for _, item := range t {
if s, ok := item.(string); ok && s == "array" {
return true
}
}
}
return false
}

// ==================== 响应翻译: Codex SSE → OpenAI SSE ====================

// UsageInfo token 使用统计
Expand Down
113 changes: 113 additions & 0 deletions proxy/translator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,119 @@ func TestTranslateRequest_PreservesSupportedServiceTier(t *testing.T) {
}
}

func TestTranslateRequest_FillsMissingArrayItemsInToolSchema(t *testing.T) {
raw := []byte(`{
"model":"gpt-5.4",
"messages":[{"role":"user","content":"test"}],
"tools":[
{
"type":"function",
"function":{
"name":"godot-mcp_node_signal",
"parameters":{
"type":"object",
"properties":{
"args":{"type":"array"}
}
}
}
}
]
}`)

got, err := TranslateRequest(raw)
if err != nil {
t.Fatalf("TranslateRequest returned error: %v", err)
}

items := gjson.GetBytes(got, "tools.0.parameters.properties.args.items")
if !items.Exists() || items.Type != gjson.JSON {
t.Fatalf("expected array schema items object to be injected, got %s", items.Raw)
}
}

func TestPrepareResponsesBody_FillsMissingArrayItemsInToolSchema(t *testing.T) {
raw := []byte(`{
"model":"gpt-5.4",
"input":"test",
"tools":[
{
"type":"function",
"name":"godot-mcp_node_signal",
"parameters":{
"type":"object",
"properties":{
"args":{"type":"array"}
}
}
}
]
}`)

got, _ := PrepareResponsesBody(raw)

items := gjson.GetBytes(got, "tools.0.parameters.properties.args.items")
if !items.Exists() || items.Type != gjson.JSON {
t.Fatalf("expected array schema items object to be injected, got %s", items.Raw)
}
}

func TestPrepareResponsesBody_DefaultsIncludeForResponses(t *testing.T) {
raw := []byte(`{
"model":"gpt-5.4",
"input":"test"
}`)

got, _ := PrepareResponsesBody(raw)

include := gjson.GetBytes(got, "include")
if !include.Exists() || len(include.Array()) != 1 || include.Array()[0].String() != "reasoning.encrypted_content" {
t.Fatalf("expected default include for responses, got %s", include.Raw)
}
if stream := gjson.GetBytes(got, "stream"); !stream.Exists() || !stream.Bool() {
t.Fatalf("expected stream to be forced for responses, got %s", stream.Raw)
}
if store := gjson.GetBytes(got, "store"); !store.Exists() || store.Bool() {
t.Fatalf("expected store=false for responses, got %s", store.Raw)
}
}

func TestPrepareCompactResponsesBody_RemovesUnsupportedInjectedFields(t *testing.T) {
raw := []byte(`{
"model":"gpt-5.4",
"input":"test"
}`)

got, _ := PrepareCompactResponsesBody(raw)

for _, field := range []string{"include", "store", "stream"} {
if gjson.GetBytes(got, field).Exists() {
t.Fatalf("expected %s to be removed for compact body", field)
}
}
input := gjson.GetBytes(got, "input")
if !input.Exists() || !input.IsArray() || len(input.Array()) != 1 {
t.Fatalf("expected compact input to remain normalized, got %s", input.Raw)
}
if input.Array()[0].Get("content").String() != "test" {
t.Fatalf("expected compact input content to be preserved, got %s", input.Raw)
}
}

func TestPrepareCompactResponsesBody_RemovesClientSuppliedInclude(t *testing.T) {
raw := []byte(`{
"model":"gpt-5.4",
"input":"test",
"include":["reasoning.encrypted_content"]
}`)

got, _ := PrepareCompactResponsesBody(raw)

if gjson.GetBytes(got, "include").Exists() {
t.Fatalf("expected client-supplied include to be removed for compact body, got %s", string(got))
}
}

// ==================== Function Calling 测试 ====================

func TestConvertMessagesToInput_ToolRole(t *testing.T) {
Expand Down
5 changes: 4 additions & 1 deletion security/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ const (
MaxEmailLength = 255
MaxProxyURLLength = 500
MaxTokenLength = 8192
MaxRequestBodySize = 10 * 1024 * 1024 // 10MB
MaxHeaderSize = 16 * 1024 // 16KB
AllowedModelPattern = `^[a-zA-Z0-9._-]+$`
AllowedEndpointPattern = `^[a-zA-Z0-9/_-]+$`
)

const DefaultMaxRequestBodySize = 32 * 1024 * 1024 // 32MB

var MaxRequestBodySize = DefaultMaxRequestBodySize

// Dangerous patterns for XSS prevention
var (
xssPatterns = []*regexp.Regexp{
Expand Down
Loading