diff --git a/lib/chat/__tests__/handleChatStream.test.ts b/lib/chat/__tests__/handleChatStream.test.ts index b98655e1..ab9b9e79 100644 --- a/lib/chat/__tests__/handleChatStream.test.ts +++ b/lib/chat/__tests__/handleChatStream.test.ts @@ -54,6 +54,14 @@ vi.mock("@/lib/chat/handleChatCompletion", () => ({ handleChatCompletion: vi.fn(), })); +vi.mock("@/lib/credits/handleChatCredits", () => ({ + handleChatCredits: vi.fn(), +})); + +vi.mock("@/lib/const", () => ({ + DEFAULT_MODEL: "openai/gpt-5-mini", +})); + vi.mock("ai", () => ({ createUIMessageStream: vi.fn(), createUIMessageStreamResponse: vi.fn(), diff --git a/lib/chat/__tests__/integration/chatEndToEnd.test.ts b/lib/chat/__tests__/integration/chatEndToEnd.test.ts index 02e758be..25841a5e 100644 --- a/lib/chat/__tests__/integration/chatEndToEnd.test.ts +++ b/lib/chat/__tests__/integration/chatEndToEnd.test.ts @@ -556,8 +556,9 @@ describe("Chat Integration Tests", () => { expect(mockDeductCredits).not.toHaveBeenCalled(); }); - it("handles zero cost gracefully", async () => { + it("deducts minimum 1 credit when cost is zero", async () => { mockGetCreditUsage.mockResolvedValue(0); + mockDeductCredits.mockResolvedValue({ success: true, newBalance: 332 }); await handleChatCredits({ usage: { promptTokens: 10, completionTokens: 5 }, @@ -566,7 +567,10 @@ describe("Chat Integration Tests", () => { }); expect(mockGetCreditUsage).toHaveBeenCalled(); - expect(mockDeductCredits).not.toHaveBeenCalled(); + expect(mockDeductCredits).toHaveBeenCalledWith({ + accountId: "account-123", + creditsToDeduct: 1, + }); }); it("catches credit deduction errors without breaking chat flow", async () => { diff --git a/lib/chat/handleChatStream.ts b/lib/chat/handleChatStream.ts index b181bcde..7319b2cb 100644 --- a/lib/chat/handleChatStream.ts +++ b/lib/chat/handleChatStream.ts @@ -1,10 +1,12 @@ import { NextRequest, NextResponse } from "next/server"; import { createUIMessageStream, createUIMessageStreamResponse } from "ai"; import { handleChatCompletion } from "./handleChatCompletion"; +import { handleChatCredits } from "@/lib/credits/handleChatCredits"; import { validateChatRequest } from "./validateChatRequest"; import { setupChatRequest } from "./setupChatRequest"; import { getCorsHeaders } from "@/lib/networking/getCorsHeaders"; import generateUUID from "@/lib/uuid/generateUUID"; +import { DEFAULT_MODEL } from "@/lib/const"; /** * Handles a streaming chat request. @@ -28,15 +30,15 @@ export async function handleChatStream(request: NextRequest): Promise const chatConfig = await setupChatRequest(body); const { agent } = chatConfig; + let streamResult: Awaited> | undefined; + const stream = createUIMessageStream({ originalMessages: body.messages, generateId: generateUUID, execute: async options => { const { writer } = options; - const result = await agent.stream(chatConfig); - writer.merge(result.toUIMessageStream()); - // Note: Credit handling and chat completion handling will be added - // as part of the handleChatCredits and handleChatCompletion migrations + streamResult = await agent.stream(chatConfig); + writer.merge(streamResult.toUIMessageStream()); }, onFinish: async event => { if (event.isAborted) { @@ -46,6 +48,13 @@ export async function handleChatStream(request: NextRequest): Promise const responseMessages = assistantMessages.length > 0 ? assistantMessages : [event.responseMessage]; await handleChatCompletion(body, responseMessages); + if (streamResult) { + await handleChatCredits({ + usage: await streamResult.usage, + model: body.model ?? DEFAULT_MODEL, + accountId: body.accountId, + }); + } }, onError: e => { console.error("/api/chat onError:", e); diff --git a/lib/credits/__tests__/handleChatCredits.test.ts b/lib/credits/__tests__/handleChatCredits.test.ts index 500b45bb..7ece222f 100644 --- a/lib/credits/__tests__/handleChatCredits.test.ts +++ b/lib/credits/__tests__/handleChatCredits.test.ts @@ -91,8 +91,9 @@ describe("handleChatCredits", () => { expect(mockDeductCredits).not.toHaveBeenCalled(); }); - it("skips credit deduction when usage cost is 0", async () => { + it("deducts minimum 1 credit when usage cost is 0", async () => { mockGetCreditUsage.mockResolvedValue(0); + mockDeductCredits.mockResolvedValue({ success: true, newBalance: 332 }); await handleChatCredits({ usage: { promptTokens: 0, completionTokens: 0 }, @@ -101,7 +102,10 @@ describe("handleChatCredits", () => { }); expect(mockGetCreditUsage).toHaveBeenCalled(); - expect(mockDeductCredits).not.toHaveBeenCalled(); + expect(mockDeductCredits).toHaveBeenCalledWith({ + accountId: "account-123", + creditsToDeduct: 1, + }); }); }); diff --git a/lib/credits/handleChatCredits.ts b/lib/credits/handleChatCredits.ts index c0462eab..b0d88438 100644 --- a/lib/credits/handleChatCredits.ts +++ b/lib/credits/handleChatCredits.ts @@ -10,7 +10,7 @@ interface HandleChatCreditsParams { /** * Handles credit deduction after chat completion. - * Calculates usage cost and deducts appropriate credits from the user's account. + * Always deducts at least 1 credit when accountId is present (round up from usage cost). * @param usage - The language model usage data * @param model - The model ID used for the chat * @param accountId - The account ID to deduct credits from (optional) @@ -27,15 +27,12 @@ export const handleChatCredits = async ({ try { const usageCost = await getCreditUsage(usage, model); + const creditsToDeduct = Math.max(1, Math.round(usageCost * 100)); - if (usageCost > 0) { - const creditsToDeduct = Math.max(1, Math.round(usageCost * 100)); - - await deductCredits({ - accountId, - creditsToDeduct, - }); - } + await deductCredits({ + accountId, + creditsToDeduct, + }); } catch (error) { console.error("Failed to handle chat credits:", error); // Don't throw error to avoid breaking the chat flow