Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
run: bunx prettier --check .

- name: Type check (tsc)
run: bunx tsc --noEmit
run: bunx tsgo --noEmit

- name: Run tests
run: bun run test
128 changes: 45 additions & 83 deletions src/zod/event.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,37 @@ import { eq } from "drizzle-orm";
import { EventError } from "../errors/event";
import { tagCache } from "../utils/tagCache";

const fetchTagAmount = async (
tag: string,
notFoundMessage: string
): Promise<number> => {
const cachedAmount = tagCache.get(tag);
if (cachedAmount !== undefined) {
return cachedAmount;
}

const db = getPostgresDB();
try {
const [tagRow] = await db
.select()
.from(tagsTable)
.where(eq(tagsTable.tag, tag))
.limit(1);

if (!tagRow) {
throw EventError.validationFailed(notFoundMessage);
}

tagCache.set(tag, tagRow.amount);
return tagRow.amount;
} catch (e) {
if (e instanceof EventError) {
throw e;
}
throw EventError.unknown(e as Error);
}
};

const BaseEvent = z.object({
type: z.number(), // overwritten later by discriminators
userId: USER_ID_CONFIG.validator,
Expand Down Expand Up @@ -38,37 +69,14 @@ const SDKCallEvent = BaseEvent.extend({
]),
})
.transform(async (v) => {
// If a tag is provided, fetch the integer value for the tag and store it into debitAmount
if (v.debit.case === "tag") {
const cachedAmount = tagCache.get(v.debit.value);
if (cachedAmount !== undefined) {
return { sdkCallType: v.sdkCallType, debitAmount: cachedAmount };
}

const db = getPostgresDB();
try {
const [tagRow] = await db
.select()
.from(tagsTable)
.where(eq(tagsTable.tag, v.debit.value))
.limit(1);

if (!tagRow) {
throw EventError.validationFailed(
`Tag not found: ${v.debit.value}`
);
}

tagCache.set(v.debit.value, tagRow.amount);
return { sdkCallType: v.sdkCallType, debitAmount: tagRow.amount };
} catch (e) {
if (e instanceof EventError) {
throw e;
}
throw EventError.unknown(e as Error);
}
const debitAmount = await fetchTagAmount(
v.debit.value,
`Tag not found: ${v.debit.value}`
);
return { sdkCallType: v.sdkCallType, debitAmount };
}
// Otherwise use provided debitAmount (apply original transformation behavior)

return {
sdkCallType: v.sdkCallType,
debitAmount: Math.floor(v.debit.value * 100),
Expand Down Expand Up @@ -112,70 +120,24 @@ const AITokenUsageEvent = BaseEvent.extend({
]),
})
.transform(async (v) => {
const db = getPostgresDB();

// Process input debit
let inputDebitAmount: number;
if (v.inputDebit.case === "inputTag") {
const cachedAmount = tagCache.get(v.inputDebit.value);
if (cachedAmount !== undefined) {
inputDebitAmount = cachedAmount;
} else {
try {
const [tagRow] = await db
.select()
.from(tagsTable)
.where(eq(tagsTable.tag, v.inputDebit.value))
.limit(1);

if (!tagRow) {
throw EventError.validationFailed(
`Input tag not found: ${v.inputDebit.value}`,
);
}

tagCache.set(v.inputDebit.value, tagRow.amount);
inputDebitAmount = tagRow.amount;
} catch (e) {
if (e instanceof EventError) {
throw e;
}
throw EventError.unknown(e as Error);
}
}
inputDebitAmount = await fetchTagAmount(
v.inputDebit.value,
`Input tag not found: ${v.inputDebit.value}`
);
} else {
inputDebitAmount = Math.floor(v.inputDebit.value * 100);
}

// Process output debit
let outputDebitAmount: number;
if (v.outputDebit.case === "outputTag") {
const cachedAmount = tagCache.get(v.outputDebit.value);
if (cachedAmount !== undefined) {
outputDebitAmount = cachedAmount;
} else {
try {
const [tagRow] = await db
.select()
.from(tagsTable)
.where(eq(tagsTable.tag, v.outputDebit.value))
.limit(1);

if (!tagRow) {
throw EventError.validationFailed(
`Output tag not found: ${v.outputDebit.value}`,
);
}

tagCache.set(v.outputDebit.value, tagRow.amount);
outputDebitAmount = tagRow.amount;
} catch (e) {
if (e instanceof EventError) {
throw e;
}
throw EventError.unknown(e as Error);
}
}
outputDebitAmount = await fetchTagAmount(
v.outputDebit.value,
`Output tag not found: ${v.outputDebit.value}`
);
} else {
outputDebitAmount = Math.floor(v.outputDebit.value * 100);
}
Expand Down
Loading