Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions lib/oauthProxy.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,42 @@ function extractPartialImage(data) {
return { b64, index, eventType: data.type };
}

function makeOAuthError(message, { status, code = "OAUTH_UPSTREAM_ERROR", upstreamBodyChars, eventType } = {}) {
function makeOAuthError(message, { status, code = "OAUTH_UPSTREAM_ERROR", upstreamBodyChars, eventType, upstreamType, upstreamParam } = {}) {
const err = new Error(message);
err.code = code;
if (status) err.status = status;
if (typeof upstreamBodyChars === "number") err.upstreamBodyChars = upstreamBodyChars;
if (eventType) err.eventType = eventType;
if (upstreamType) err.upstreamType = upstreamType;
if (upstreamParam) err.upstreamParam = upstreamParam;
return err;
}

function makeOAuthHttpError(prefix, status, bodyText) {
let upstreamError = null;
try {
const parsed = JSON.parse(bodyText);
if (parsed && typeof parsed === "object" && parsed.error && typeof parsed.error === "object") {
upstreamError = parsed.error;
}
} catch {}

if (status >= 400 && status < 500 && typeof upstreamError?.message === "string" && upstreamError.message.trim()) {
return makeOAuthError(upstreamError.message, {
status,
code: upstreamError.code || "OAUTH_UPSTREAM_ERROR",
upstreamBodyChars: bodyText.length,
upstreamType: upstreamError.type,
upstreamParam: upstreamError.param,
});
}

return makeOAuthError(`${prefix} ${status}`, {
status,
upstreamBodyChars: bodyText.length,
});
}

async function readImageStream(res, { requestId = null, scope = "oauth", onPartialImage = null } = {}) {
const reader = res.body.getReader();
const decoder = new TextDecoder();
Expand Down Expand Up @@ -204,10 +231,7 @@ export async function generateViaOAuth(
if (!res.ok) {
const text = await res.text();
logEvent("oauth", "error_response", { requestId, status: res.status, errorChars: text.length });
throw makeOAuthError(`OAuth proxy returned ${res.status}`, {
status: res.status,
upstreamBodyChars: text.length,
});
throw makeOAuthHttpError("OAuth proxy returned", res.status, text);
}

const contentType = res.headers.get("content-type") || "";
Expand Down Expand Up @@ -299,10 +323,7 @@ export async function editViaOAuth(prompt, imageB64, quality, size, moderation =
if (!res.ok) {
const text = await res.text();
logEvent("oauth-edit", "error_response", { requestId, status: res.status, errorChars: text.length });
throw makeOAuthError(`OAuth edit returned ${res.status}`, {
status: res.status,
upstreamBodyChars: text.length,
});
throw makeOAuthHttpError("OAuth edit returned", res.status, text);
}

const { imageB64: resultB64, usage, revisedPrompt } = await readImageStream(res, {
Expand Down
12 changes: 12 additions & 0 deletions routes/generate.js
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ function validateModeration(ctx, moderation) {
return { moderation };
}

function isNonRetryableGenerationError(err) {
return Number.isInteger(err?.status) && err.status >= 400 && err.status < 500;
}

export function registerGenerateRoutes(app, ctx) {
app.post("/api/generate", async (req, res) => {
const requestId = typeof req.body?.requestId === "string" ? req.body.requestId : null;
Expand Down Expand Up @@ -145,6 +149,7 @@ export function registerGenerateRoutes(app, ctx) {
lastErr = new Error("Empty response (safety refusal)");
} catch (e) {
lastErr = e;
if (isNonRetryableGenerationError(e)) throw e;
}
if (attempt < MAX_RETRIES) {
logEvent("generate", "retry", { requestId, attempt: attempt + 1, errorCode: lastErr?.code });
Expand Down Expand Up @@ -209,6 +214,13 @@ export function registerGenerateRoutes(app, ctx) {
finishErrorCode = "SAFETY_REFUSAL";
return res.status(422).json({ error: firstErr.message, code: "SAFETY_REFUSAL" });
}
if (firstErr?.status && firstErr.status >= 400 && firstErr.status < 500) {
const code = firstErr.code || classifyUpstreamError(firstErr.message);
finishStatus = "error";
finishHttpStatus = firstErr.status;
finishErrorCode = code;
return res.status(firstErr.status).json({ error: firstErr.message, code });
}
finishStatus = "error";
finishHttpStatus = 500;
finishErrorCode = "GENERATE_ALL_FAILED";
Expand Down
17 changes: 13 additions & 4 deletions routes/nodes.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ function writeNodeError(res, status, code, message, parentNodeId) {
});
}

function isNonRetryableGenerationError(err) {
return Number.isInteger(err?.status) && err.status >= 400 && err.status < 500;
}

function dataUrlFromB64(format, b64) {
return `data:image/${format === "jpeg" ? "jpeg" : format};base64,${b64}`;
}
Expand Down Expand Up @@ -237,20 +241,25 @@ export function registerNodeRoutes(app, ctx) {
lastErr = new Error("Empty response (safety refusal)");
} catch (e) {
lastErr = e;
if (isNonRetryableGenerationError(e)) break;
}
if (attempt < MAX_RETRIES) {
logEvent("node", "retry", { requestId, attempt: attempt + 1, errorCode: lastErr?.code });
}
}

if (!b64) {
const status = isNonRetryableGenerationError(lastErr) ? lastErr.status : 422;
const code = isNonRetryableGenerationError(lastErr)
? (lastErr.code || classifyUpstreamError(lastErr.message))
: "SAFETY_REFUSAL";
finishStatus = "error";
finishHttpStatus = 422;
finishErrorCode = "SAFETY_REFUSAL";
finishHttpStatus = status;
finishErrorCode = code;
return writeNodeError(
res,
422,
"SAFETY_REFUSAL",
status,
code,
lastErr?.message || "Empty response after retry",
parentNodeId,
);
Expand Down
116 changes: 116 additions & 0 deletions tests/generate-route-validation-error.test.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import { after, before, describe, it } from "node:test";
import assert from "node:assert/strict";
import express from "express";
import { createServer } from "node:http";
import { mkdtemp, mkdir, rm } from "node:fs/promises";
import { join } from "node:path";
import { tmpdir } from "node:os";
import { registerGenerateRoutes } from "../routes/generate.js";
import { registerNodeRoutes } from "../routes/nodes.js";

const upstreamMessage = "Invalid size '512x512'. Requested resolution is below the current minimum pixel budget.";

function writeInvalidSize(res) {
res.writeHead(400, { "Content-Type": "application/json" });
res.end(JSON.stringify({
error: {
message: upstreamMessage,
type: "image_generation_user_error",
param: "tools",
code: "invalid_value",
},
}));
}

describe("generation route validation errors", () => {
let rootDir;
let oauthServer;
let appServer;
let baseUrl;
let upstreamRequests = 0;

before(async () => {
rootDir = await mkdtemp(join(tmpdir(), "ima2-generate-validation-"));
await mkdir(join(rootDir, "generated"), { recursive: true });

oauthServer = createServer((req, res) => {
if (req.method === "POST" && req.url === "/v1/responses") {
upstreamRequests += 1;
writeInvalidSize(res);
return;
}
res.writeHead(404).end();
});
await new Promise((resolve) => oauthServer.listen(0, "127.0.0.1", resolve));
const oauthAddress = oauthServer.address();

const app = express();
app.use(express.json({ limit: "2mb" }));
const ctx = {
rootDir,
oauthUrl: `http://127.0.0.1:${oauthAddress.port}`,
config: {
ids: { generatedHexBytes: 4 },
oauth: { validModeration: new Set(["auto", "low"]) },
storage: { generatedDir: join(rootDir, "generated") },
},
};
registerGenerateRoutes(app, ctx);
registerNodeRoutes(app, ctx);
await new Promise((resolve) => {
appServer = app.listen(0, "127.0.0.1", resolve);
});
const appAddress = appServer.address();
baseUrl = `http://127.0.0.1:${appAddress.port}`;
});

after(async () => {
await new Promise((resolve, reject) => appServer.close((err) => (err ? reject(err) : resolve())));
await new Promise((resolve, reject) => oauthServer.close((err) => (err ? reject(err) : resolve())));
await rm(rootDir, { recursive: true, force: true });
});

it("returns invalid image size instead of safety refusal", async () => {
upstreamRequests = 0;
const res = await fetch(`${baseUrl}/api/generate`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
prompt: "a simple red circle",
quality: "low",
size: "512x512",
moderation: "low",
requestId: "req_invalid_size_route",
}),
});

const body = await res.json();
assert.equal(res.status, 400);
assert.equal(body.error, upstreamMessage);
assert.equal(body.code, "invalid_value");
assert.equal(upstreamRequests, 1, "non-retryable validation errors should not be retried");
});


it("returns invalid image size from node generation instead of safety refusal", async () => {
upstreamRequests = 0;
const res = await fetch(`${baseUrl}/api/node/generate`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
parentNodeId: null,
prompt: "a simple red circle",
quality: "low",
size: "512x512",
moderation: "low",
requestId: "req_invalid_size_node",
}),
});

const body = await res.json();
assert.equal(res.status, 400);
assert.equal(body.error.code, "invalid_value");
assert.equal(body.error.message, upstreamMessage);
assert.equal(upstreamRequests, 1, "non-retryable node validation errors should not be retried");
});
});
35 changes: 35 additions & 0 deletions tests/oauth-proxy-error-safety.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,38 @@ test("OAuth non-ok responses do not expose raw upstream body in logs or errors",
await new Promise((resolve) => server.close(resolve));
}
});

test("OAuth 4xx validation responses preserve actionable upstream message", async () => {
const upstreamMessage = "Invalid size '512x512'. Requested resolution is below the current minimum pixel budget.";
const server = createServer((_req, res) => {
res.writeHead(400, { "Content-Type": "application/json" });
res.end(JSON.stringify({
error: {
message: upstreamMessage,
type: "image_generation_user_error",
param: "tools",
code: "invalid_value",
},
}));
});
await new Promise((resolve) => server.listen(0, "127.0.0.1", resolve));
const port = server.address().port;

try {
await assert.rejects(
generateViaOAuth("safe test", "low", "512x512", "low", [], "req_invalid_size", "auto", {
oauthUrl: `http://127.0.0.1:${port}`,
}),
(err) => {
assert.equal(err.message, upstreamMessage);
assert.equal(err.status, 400);
assert.equal(err.code, "invalid_value");
assert.equal(err.upstreamType, "image_generation_user_error");
assert.equal(err.upstreamParam, "tools");
return true;
},
);
} finally {
await new Promise((resolve) => server.close(resolve));
}
});