diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index f21ffc4..bbfdf78 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -36,7 +36,7 @@ jobs: run: bunx prettier --check . - name: Type check (tsc) - run: bunx tsc --noEmit + run: bunx tsgo --noEmit - name: Run tests run: bun run test diff --git a/src/zod/event.ts b/src/zod/event.ts index d975186..5a11b22 100644 --- a/src/zod/event.ts +++ b/src/zod/event.ts @@ -6,6 +6,37 @@ import { eq } from "drizzle-orm"; import { EventError } from "../errors/event"; import { tagCache } from "../utils/tagCache"; +const fetchTagAmount = async ( + tag: string, + notFoundMessage: string +): Promise => { + const cachedAmount = tagCache.get(tag); + if (cachedAmount !== undefined) { + return cachedAmount; + } + + const db = getPostgresDB(); + try { + const [tagRow] = await db + .select() + .from(tagsTable) + .where(eq(tagsTable.tag, tag)) + .limit(1); + + if (!tagRow) { + throw EventError.validationFailed(notFoundMessage); + } + + tagCache.set(tag, tagRow.amount); + return tagRow.amount; + } catch (e) { + if (e instanceof EventError) { + throw e; + } + throw EventError.unknown(e as Error); + } +}; + const BaseEvent = z.object({ type: z.number(), // overwritten later by discriminators userId: USER_ID_CONFIG.validator, @@ -38,37 +69,14 @@ const SDKCallEvent = BaseEvent.extend({ ]), }) .transform(async (v) => { - // If a tag is provided, fetch the integer value for the tag and store it into debitAmount if (v.debit.case === "tag") { - const cachedAmount = tagCache.get(v.debit.value); - if (cachedAmount !== undefined) { - return { sdkCallType: v.sdkCallType, debitAmount: cachedAmount }; - } - - const db = getPostgresDB(); - try { - const [tagRow] = await db - .select() - .from(tagsTable) - .where(eq(tagsTable.tag, v.debit.value)) - .limit(1); - - if (!tagRow) { - throw EventError.validationFailed( - `Tag not found: ${v.debit.value}` - ); - } - - tagCache.set(v.debit.value, tagRow.amount); - return { sdkCallType: v.sdkCallType, debitAmount: tagRow.amount }; - } catch (e) { - if (e instanceof EventError) { - throw e; - } - throw EventError.unknown(e as Error); - } + const debitAmount = await fetchTagAmount( + v.debit.value, + `Tag not found: ${v.debit.value}` + ); + return { sdkCallType: v.sdkCallType, debitAmount }; } - // Otherwise use provided debitAmount (apply original transformation behavior) + return { sdkCallType: v.sdkCallType, debitAmount: Math.floor(v.debit.value * 100), @@ -112,37 +120,13 @@ const AITokenUsageEvent = BaseEvent.extend({ ]), }) .transform(async (v) => { - const db = getPostgresDB(); - // Process input debit let inputDebitAmount: number; if (v.inputDebit.case === "inputTag") { - const cachedAmount = tagCache.get(v.inputDebit.value); - if (cachedAmount !== undefined) { - inputDebitAmount = cachedAmount; - } else { - try { - const [tagRow] = await db - .select() - .from(tagsTable) - .where(eq(tagsTable.tag, v.inputDebit.value)) - .limit(1); - - if (!tagRow) { - throw EventError.validationFailed( - `Input tag not found: ${v.inputDebit.value}`, - ); - } - - tagCache.set(v.inputDebit.value, tagRow.amount); - inputDebitAmount = tagRow.amount; - } catch (e) { - if (e instanceof EventError) { - throw e; - } - throw EventError.unknown(e as Error); - } - } + inputDebitAmount = await fetchTagAmount( + v.inputDebit.value, + `Input tag not found: ${v.inputDebit.value}` + ); } else { inputDebitAmount = Math.floor(v.inputDebit.value * 100); } @@ -150,32 +134,10 @@ const AITokenUsageEvent = BaseEvent.extend({ // Process output debit let outputDebitAmount: number; if (v.outputDebit.case === "outputTag") { - const cachedAmount = tagCache.get(v.outputDebit.value); - if (cachedAmount !== undefined) { - outputDebitAmount = cachedAmount; - } else { - try { - const [tagRow] = await db - .select() - .from(tagsTable) - .where(eq(tagsTable.tag, v.outputDebit.value)) - .limit(1); - - if (!tagRow) { - throw EventError.validationFailed( - `Output tag not found: ${v.outputDebit.value}`, - ); - } - - tagCache.set(v.outputDebit.value, tagRow.amount); - outputDebitAmount = tagRow.amount; - } catch (e) { - if (e instanceof EventError) { - throw e; - } - throw EventError.unknown(e as Error); - } - } + outputDebitAmount = await fetchTagAmount( + v.outputDebit.value, + `Output tag not found: ${v.outputDebit.value}` + ); } else { outputDebitAmount = Math.floor(v.outputDebit.value * 100); }