Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lib/chat/__tests__/handleChatStream.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ vi.mock("@/lib/chat/handleChatCompletion", () => ({
handleChatCompletion: vi.fn(),
}));

vi.mock("@/lib/credits/handleChatCredits", () => ({
handleChatCredits: vi.fn(),
}));

vi.mock("@/lib/const", () => ({
DEFAULT_MODEL: "openai/gpt-5-mini",
}));

vi.mock("ai", () => ({
createUIMessageStream: vi.fn(),
createUIMessageStreamResponse: vi.fn(),
Expand Down
8 changes: 6 additions & 2 deletions lib/chat/__tests__/integration/chatEndToEnd.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -556,8 +556,9 @@ describe("Chat Integration Tests", () => {
expect(mockDeductCredits).not.toHaveBeenCalled();
});

it("handles zero cost gracefully", async () => {
it("deducts minimum 1 credit when cost is zero", async () => {
mockGetCreditUsage.mockResolvedValue(0);
mockDeductCredits.mockResolvedValue({ success: true, newBalance: 332 });

await handleChatCredits({
usage: { promptTokens: 10, completionTokens: 5 },
Expand All @@ -566,7 +567,10 @@ describe("Chat Integration Tests", () => {
});

expect(mockGetCreditUsage).toHaveBeenCalled();
expect(mockDeductCredits).not.toHaveBeenCalled();
expect(mockDeductCredits).toHaveBeenCalledWith({
accountId: "account-123",
creditsToDeduct: 1,
});
});

it("catches credit deduction errors without breaking chat flow", async () => {
Expand Down
17 changes: 13 additions & 4 deletions lib/chat/handleChatStream.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { NextRequest, NextResponse } from "next/server";
import { createUIMessageStream, createUIMessageStreamResponse } from "ai";
import { handleChatCompletion } from "./handleChatCompletion";
import { handleChatCredits } from "@/lib/credits/handleChatCredits";
import { validateChatRequest } from "./validateChatRequest";
import { setupChatRequest } from "./setupChatRequest";
import { getCorsHeaders } from "@/lib/networking/getCorsHeaders";
import generateUUID from "@/lib/uuid/generateUUID";
import { DEFAULT_MODEL } from "@/lib/const";

/**
* Handles a streaming chat request.
Expand All @@ -28,15 +30,15 @@ export async function handleChatStream(request: NextRequest): Promise<Response>
const chatConfig = await setupChatRequest(body);
const { agent } = chatConfig;

let streamResult: Awaited<ReturnType<typeof agent.stream>> | undefined;

const stream = createUIMessageStream({
originalMessages: body.messages,
generateId: generateUUID,
execute: async options => {
const { writer } = options;
const result = await agent.stream(chatConfig);
writer.merge(result.toUIMessageStream());
// Note: Credit handling and chat completion handling will be added
// as part of the handleChatCredits and handleChatCompletion migrations
streamResult = await agent.stream(chatConfig);
writer.merge(streamResult.toUIMessageStream());
},
onFinish: async event => {
if (event.isAborted) {
Expand All @@ -46,6 +48,13 @@ export async function handleChatStream(request: NextRequest): Promise<Response>
const responseMessages =
assistantMessages.length > 0 ? assistantMessages : [event.responseMessage];
await handleChatCompletion(body, responseMessages);
if (streamResult) {
await handleChatCredits({
usage: await streamResult.usage,
model: body.model ?? DEFAULT_MODEL,
accountId: body.accountId,
});
}
},
onError: e => {
console.error("/api/chat onError:", e);
Expand Down
8 changes: 6 additions & 2 deletions lib/credits/__tests__/handleChatCredits.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,9 @@ describe("handleChatCredits", () => {
expect(mockDeductCredits).not.toHaveBeenCalled();
});

it("skips credit deduction when usage cost is 0", async () => {
it("deducts minimum 1 credit when usage cost is 0", async () => {
mockGetCreditUsage.mockResolvedValue(0);
mockDeductCredits.mockResolvedValue({ success: true, newBalance: 332 });

await handleChatCredits({
usage: { promptTokens: 0, completionTokens: 0 },
Expand All @@ -101,7 +102,10 @@ describe("handleChatCredits", () => {
});

expect(mockGetCreditUsage).toHaveBeenCalled();
expect(mockDeductCredits).not.toHaveBeenCalled();
expect(mockDeductCredits).toHaveBeenCalledWith({
accountId: "account-123",
creditsToDeduct: 1,
});
});
});

Expand Down
15 changes: 6 additions & 9 deletions lib/credits/handleChatCredits.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ interface HandleChatCreditsParams {

/**
* Handles credit deduction after chat completion.
* Calculates usage cost and deducts appropriate credits from the user's account.
 * Always deducts at least 1 credit when accountId is present (usage cost is rounded to the nearest credit, with a minimum of 1).
* @param usage - The language model usage data
* @param model - The model ID used for the chat
* @param accountId - The account ID to deduct credits from (optional)
Expand All @@ -27,15 +27,12 @@ export const handleChatCredits = async ({

try {
const usageCost = await getCreditUsage(usage, model);
const creditsToDeduct = Math.max(1, Math.round(usageCost * 100));

if (usageCost > 0) {
const creditsToDeduct = Math.max(1, Math.round(usageCost * 100));

await deductCredits({
accountId,
creditsToDeduct,
});
}
await deductCredits({
accountId,
creditsToDeduct,
});
} catch (error) {
console.error("Failed to handle chat credits:", error);
// Don't throw error to avoid breaking the chat flow
Expand Down
Loading