diff --git a/lib/oauthProxy.js b/lib/oauthProxy.js index c7c8f3b..e6b781f 100644 --- a/lib/oauthProxy.js +++ b/lib/oauthProxy.js @@ -64,15 +64,42 @@ function extractPartialImage(data) { return { b64, index, eventType: data.type }; } -function makeOAuthError(message, { status, code = "OAUTH_UPSTREAM_ERROR", upstreamBodyChars, eventType } = {}) { +function makeOAuthError(message, { status, code = "OAUTH_UPSTREAM_ERROR", upstreamBodyChars, eventType, upstreamType, upstreamParam } = {}) { const err = new Error(message); err.code = code; if (status) err.status = status; if (typeof upstreamBodyChars === "number") err.upstreamBodyChars = upstreamBodyChars; if (eventType) err.eventType = eventType; + if (upstreamType) err.upstreamType = upstreamType; + if (upstreamParam) err.upstreamParam = upstreamParam; return err; } +function makeOAuthHttpError(prefix, status, bodyText) { + let upstreamError = null; + try { + const parsed = JSON.parse(bodyText); + if (parsed && typeof parsed === "object" && parsed.error && typeof parsed.error === "object") { + upstreamError = parsed.error; + } + } catch {} + + if (status >= 400 && status < 500 && typeof upstreamError?.message === "string" && upstreamError.message.trim()) { + return makeOAuthError(upstreamError.message, { + status, + code: upstreamError.code || "OAUTH_UPSTREAM_ERROR", + upstreamBodyChars: bodyText.length, + upstreamType: upstreamError.type, + upstreamParam: upstreamError.param, + }); + } + + return makeOAuthError(`${prefix} ${status}`, { + status, + upstreamBodyChars: bodyText.length, + }); +} + async function readImageStream(res, { requestId = null, scope = "oauth", onPartialImage = null } = {}) { const reader = res.body.getReader(); const decoder = new TextDecoder(); @@ -204,10 +231,7 @@ export async function generateViaOAuth( if (!res.ok) { const text = await res.text(); logEvent("oauth", "error_response", { requestId, status: res.status, errorChars: text.length }); - throw makeOAuthError(`OAuth proxy returned ${res.status}`, { - status: res.status, - upstreamBodyChars: text.length, - }); + throw makeOAuthHttpError("OAuth proxy returned", res.status, text); } const contentType = res.headers.get("content-type") || ""; @@ -299,10 +323,7 @@ export async function editViaOAuth(prompt, imageB64, quality, size, moderation = if (!res.ok) { const text = await res.text(); logEvent("oauth-edit", "error_response", { requestId, status: res.status, errorChars: text.length }); - throw makeOAuthError(`OAuth edit returned ${res.status}`, { - status: res.status, - upstreamBodyChars: text.length, - }); + throw makeOAuthHttpError("OAuth edit returned", res.status, text); } const { imageB64: resultB64, usage, revisedPrompt } = await readImageStream(res, { diff --git a/routes/generate.js b/routes/generate.js index 4b0e403..5970898 100644 --- a/routes/generate.js +++ b/routes/generate.js @@ -18,6 +18,10 @@ function validateModeration(ctx, moderation) { return { moderation }; } +function isNonRetryableGenerationError(err) { + return Number.isInteger(err?.status) && err.status >= 400 && err.status < 500; +} + export function registerGenerateRoutes(app, ctx) { app.post("/api/generate", async (req, res) => { const requestId = typeof req.body?.requestId === "string" ? req.body.requestId : null; @@ -145,6 +149,7 @@ export function registerGenerateRoutes(app, ctx) { lastErr = new Error("Empty response (safety refusal)"); } catch (e) { lastErr = e; + if (isNonRetryableGenerationError(e)) throw e; } if (attempt < MAX_RETRIES) { logEvent("generate", "retry", { requestId, attempt: attempt + 1, errorCode: lastErr?.code }); @@ -209,6 +214,13 @@ export function registerGenerateRoutes(app, ctx) { finishErrorCode = "SAFETY_REFUSAL"; return res.status(422).json({ error: firstErr.message, code: "SAFETY_REFUSAL" }); } + if (firstErr?.status && firstErr.status >= 400 && firstErr.status < 500) { + const code = firstErr.code || classifyUpstreamError(firstErr.message); + finishStatus = "error"; + finishHttpStatus = firstErr.status; + finishErrorCode = code; + return res.status(firstErr.status).json({ error: firstErr.message, code }); + } finishStatus = "error"; finishHttpStatus = 500; finishErrorCode = "GENERATE_ALL_FAILED"; diff --git a/routes/nodes.js b/routes/nodes.js index 13c7ec4..fbc6617 100644 --- a/routes/nodes.js +++ b/routes/nodes.js @@ -49,6 +49,10 @@ function writeNodeError(res, status, code, message, parentNodeId) { }); } +function isNonRetryableGenerationError(err) { + return Number.isInteger(err?.status) && err.status >= 400 && err.status < 500; +} + function dataUrlFromB64(format, b64) { return `data:image/${format === "jpeg" ? "jpeg" : format};base64,${b64}`; } @@ -237,6 +241,7 @@ export function registerNodeRoutes(app, ctx) { lastErr = new Error("Empty response (safety refusal)"); } catch (e) { lastErr = e; + if (isNonRetryableGenerationError(e)) break; } if (attempt < MAX_RETRIES) { logEvent("node", "retry", { requestId, attempt: attempt + 1, errorCode: lastErr?.code }); @@ -244,13 +249,17 @@ export function registerNodeRoutes(app, ctx) { } if (!b64) { + const status = isNonRetryableGenerationError(lastErr) ? lastErr.status : 422; + const code = isNonRetryableGenerationError(lastErr) + ? (lastErr.code || classifyUpstreamError(lastErr.message)) + : "SAFETY_REFUSAL"; finishStatus = "error"; - finishHttpStatus = 422; - finishErrorCode = "SAFETY_REFUSAL"; + finishHttpStatus = status; + finishErrorCode = code; return writeNodeError( res, - 422, - "SAFETY_REFUSAL", + status, + code, lastErr?.message || "Empty response after retry", parentNodeId, ); diff --git a/tests/generate-route-validation-error.test.js b/tests/generate-route-validation-error.test.js new file mode 100644 index 0000000..36f1b04 --- /dev/null +++ b/tests/generate-route-validation-error.test.js @@ -0,0 +1,116 @@ +import { after, before, describe, it } from "node:test"; +import assert from "node:assert/strict"; +import express from "express"; +import { createServer } from "node:http"; +import { mkdtemp, mkdir, rm } from "node:fs/promises"; +import { join } from "node:path"; +import { tmpdir } from "node:os"; +import { registerGenerateRoutes } from "../routes/generate.js"; +import { registerNodeRoutes } from "../routes/nodes.js"; + +const upstreamMessage = "Invalid size '512x512'. Requested resolution is below the current minimum pixel budget."; + +function writeInvalidSize(res) { + res.writeHead(400, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + error: { + message: upstreamMessage, + type: "image_generation_user_error", + param: "tools", + code: "invalid_value", + }, + })); +} + +describe("generation route validation errors", () => { + let rootDir; + let oauthServer; + let appServer; + let baseUrl; + let upstreamRequests = 0; + + before(async () => { + rootDir = await mkdtemp(join(tmpdir(), "ima2-generate-validation-")); + await mkdir(join(rootDir, "generated"), { recursive: true }); + + oauthServer = createServer((req, res) => { + if (req.method === "POST" && req.url === "/v1/responses") { + upstreamRequests += 1; + writeInvalidSize(res); + return; + } + res.writeHead(404).end(); + }); + await new Promise((resolve) => oauthServer.listen(0, "127.0.0.1", resolve)); + const oauthAddress = oauthServer.address(); + + const app = express(); + app.use(express.json({ limit: "2mb" })); + const ctx = { + rootDir, + oauthUrl: `http://127.0.0.1:${oauthAddress.port}`, + config: { + ids: { generatedHexBytes: 4 }, + oauth: { validModeration: new Set(["auto", "low"]) }, + storage: { generatedDir: join(rootDir, "generated") }, + }, + }; + registerGenerateRoutes(app, ctx); + registerNodeRoutes(app, ctx); + await new Promise((resolve) => { + appServer = app.listen(0, "127.0.0.1", resolve); + }); + const appAddress = appServer.address(); + baseUrl = `http://127.0.0.1:${appAddress.port}`; + }); + + after(async () => { + await new Promise((resolve, reject) => appServer.close((err) => (err ? reject(err) : resolve()))); + await new Promise((resolve, reject) => oauthServer.close((err) => (err ? reject(err) : resolve()))); + await rm(rootDir, { recursive: true, force: true }); + }); + + it("returns invalid image size instead of safety refusal", async () => { + upstreamRequests = 0; + const res = await fetch(`${baseUrl}/api/generate`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + prompt: "a simple red circle", + quality: "low", + size: "512x512", + moderation: "low", + requestId: "req_invalid_size_route", + }), + }); + + const body = await res.json(); + assert.equal(res.status, 400); + assert.equal(body.error, upstreamMessage); + assert.equal(body.code, "invalid_value"); + assert.equal(upstreamRequests, 1, "non-retryable validation errors should not be retried"); + }); + + + it("returns invalid image size from node generation instead of safety refusal", async () => { + upstreamRequests = 0; + const res = await fetch(`${baseUrl}/api/node/generate`, { + method: "POST", + headers: { "Content-Type": "application/json" }, + body: JSON.stringify({ + parentNodeId: null, + prompt: "a simple red circle", + quality: "low", + size: "512x512", + moderation: "low", + requestId: "req_invalid_size_node", + }), + }); + + const body = await res.json(); + assert.equal(res.status, 400); + assert.equal(body.error.code, "invalid_value"); + assert.equal(body.error.message, upstreamMessage); + assert.equal(upstreamRequests, 1, "non-retryable node validation errors should not be retried"); + }); +}); diff --git a/tests/oauth-proxy-error-safety.test.js b/tests/oauth-proxy-error-safety.test.js index 519b1e2..000b52a 100644 --- a/tests/oauth-proxy-error-safety.test.js +++ b/tests/oauth-proxy-error-safety.test.js @@ -34,3 +34,38 @@ test("OAuth non-ok responses do not expose raw upstream body in logs or errors", await new Promise((resolve) => server.close(resolve)); } }); + +test("OAuth 4xx validation responses preserve actionable upstream message", async () => { + const upstreamMessage = "Invalid size '512x512'. Requested resolution is below the current minimum pixel budget."; + const server = createServer((_req, res) => { + res.writeHead(400, { "Content-Type": "application/json" }); + res.end(JSON.stringify({ + error: { + message: upstreamMessage, + type: "image_generation_user_error", + param: "tools", + code: "invalid_value", + }, + })); + }); + await new Promise((resolve) => server.listen(0, "127.0.0.1", resolve)); + const port = server.address().port; + + try { + await assert.rejects( + generateViaOAuth("safe test", "low", "512x512", "low", [], "req_invalid_size", "auto", { + oauthUrl: `http://127.0.0.1:${port}`, + }), + (err) => { + assert.equal(err.message, upstreamMessage); + assert.equal(err.status, 400); + assert.equal(err.code, "invalid_value"); + assert.equal(err.upstreamType, "image_generation_user_error"); + assert.equal(err.upstreamParam, "tools"); + return true; + }, + ); + } finally { + await new Promise((resolve) => server.close(resolve)); + } +});