Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 124 additions & 0 deletions apps/web/src/api/lib/uncertainty/calibration.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import { eq, sql } from "drizzle-orm";
import { uuidv7 } from "uuidv7";
import type { Database } from "../../../db/client";
import {
uncertaintyCalibrationSnapshot,
uncertaintyPrediction,
} from "../../../db/schema/uncertainty";
import { BUCKET_COUNT, bucketBounds, bucketIndex, brierScore } from "./cohort";

export type CohortRow = {
cohortKey: string;
claimedConfidence: number;
outcomeCorrectness: number | null;
state: "emitted" | "witnessed" | "orphaned" | "retired";
};

export type BucketSnapshot = {
cohortKey: string;
bucketLower: number;
bucketUpper: number;
claimedConfidence: number;
actualCorrectness: number;
predictionCount: number;
orphanCount: number;
brierScore: number;
};

/**
* Splits witnessed predictions in a cohort into 10 buckets of 0.1 width on
* claimed_confidence. Per bucket: mean claimed, mean correctness, count, and
* Brier score. Orphan count is per-cohort (not per-bucket) and attached to
* every emitted snapshot row so the dashboard can show the orphan share.
*
* Buckets with zero witnessed predictions are dropped -- a bucket only exists
* once at least one prediction has fully resolved.
*/
export function bucketCohort(rows: readonly CohortRow[]): BucketSnapshot[] {
if (rows.length === 0) return [];
const cohortKey = rows[0].cohortKey;

const witnessed = rows.filter(
(r): r is CohortRow & { outcomeCorrectness: number } =>
r.state === "witnessed" && r.outcomeCorrectness !== null,
);
const orphanCount = rows.filter((r) => r.state === "orphaned").length;

const buckets: { claimed: number[]; correctness: number[] }[] = Array.from(
{ length: BUCKET_COUNT },
() => ({ claimed: [], correctness: [] }),
);

for (const row of witnessed) {
const idx = bucketIndex(row.claimedConfidence);
buckets[idx].claimed.push(row.claimedConfidence);
buckets[idx].correctness.push(row.outcomeCorrectness);
}

const out: BucketSnapshot[] = [];
for (let i = 0; i < BUCKET_COUNT; i++) {
const b = buckets[i];
if (b.claimed.length === 0) continue;
const { lower, upper } = bucketBounds(i);
const meanClaimed = b.claimed.reduce((a, c) => a + c, 0) / b.claimed.length;
const meanCorrectness = b.correctness.reduce((a, c) => a + c, 0) / b.correctness.length;
out.push({
cohortKey,
bucketLower: lower,
bucketUpper: upper,
claimedConfidence: meanClaimed,
actualCorrectness: meanCorrectness,
predictionCount: b.claimed.length,
orphanCount,
brierScore: brierScore(b.claimed, b.correctness),
});
}
return out;
}

/**
* Recompute snapshots for every active cohort. "Active" means the cohort has
* at least one resolved (witnessed or orphaned) prediction.
*/
export async function rollCalibration(
db: Database,
now: Date = new Date(),
): Promise<{ cohorts: number; rows: number }> {
const cohortKeys = await db
.select({ cohortKey: uncertaintyPrediction.cohortKey })
.from(uncertaintyPrediction)
.where(sql`${uncertaintyPrediction.state} IN ('witnessed', 'orphaned')`)
.groupBy(uncertaintyPrediction.cohortKey);

let totalRows = 0;
for (const { cohortKey } of cohortKeys) {
const rows = (await db
.select({
cohortKey: uncertaintyPrediction.cohortKey,
claimedConfidence: uncertaintyPrediction.claimedConfidence,
outcomeCorrectness: uncertaintyPrediction.outcomeCorrectness,
state: uncertaintyPrediction.state,
})
.from(uncertaintyPrediction)
.where(eq(uncertaintyPrediction.cohortKey, cohortKey))) as CohortRow[];

const snapshots = bucketCohort(rows);

await db
.delete(uncertaintyCalibrationSnapshot)
.where(eq(uncertaintyCalibrationSnapshot.cohortKey, cohortKey));

if (snapshots.length > 0) {
await db.insert(uncertaintyCalibrationSnapshot).values(
snapshots.map((s) => ({
id: uuidv7(),
...s,
computedAt: now,
})),
);
totalRows += snapshots.length;
}
}

return { cohorts: cohortKeys.length, rows: totalRows };
}
26 changes: 18 additions & 8 deletions apps/web/src/api/routes/uncertainty.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import { createDb } from "../../db/client";
import { uncertaintyCalibrationSnapshot, uncertaintyPrediction } from "../../db/schema/uncertainty";
import { outcomeLabelSchema, surfaceSchema } from "../../db/schema/uncertainty.zod";
import { cohortKey, DAY_MS, DEFAULT_ORPHAN_TTL_DAYS } from "../lib/uncertainty/cohort";
import { rollCalibration } from "../lib/uncertainty/calibration";
import { sweepOrphans } from "../lib/uncertainty/orphan-sweep";
import { requireAuth } from "../middleware/auth";
import type { HonoEnv } from "../types";
Expand Down Expand Up @@ -142,14 +143,23 @@ const uncertaintyRoutes = new Hono<HonoEnv>()
return c.json({ surfaces: rows });
});

const internalRoutes = new Hono<HonoEnv>().post("/orphan-sweep", async (c) => {
if (!authorizeCron({ req: { header: (n) => c.req.header(n) }, env: c.env })) {
return c.json({ error: "Unauthorized" }, 401);
}
const db = createDb(c.env.DB);
const result = await sweepOrphans(db);
return c.json(result);
});
const internalRoutes = new Hono<HonoEnv>()
.post("/orphan-sweep", async (c) => {
if (!authorizeCron({ req: { header: (n) => c.req.header(n) }, env: c.env })) {
return c.json({ error: "Unauthorized" }, 401);
}
const db = createDb(c.env.DB);
const result = await sweepOrphans(db);
return c.json(result);
})
.post("/calibration-roll", async (c) => {
if (!authorizeCron({ req: { header: (n) => c.req.header(n) }, env: c.env })) {
return c.json({ error: "Unauthorized" }, 401);
}
const db = createDb(c.env.DB);
const result = await rollCalibration(db);
return c.json(result);
});

uncertaintyRoutes.route("/internal", internalRoutes);

Expand Down
97 changes: 97 additions & 0 deletions apps/web/tests/api/lib/uncertainty/calibration.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import { describe, expect, it } from "vitest";
import { bucketCohort, type CohortRow } from "../../../../src/api/lib/uncertainty/calibration";

const COHORT = "rafters.color|claude-sonnet-4-7|2026-04-01|0.8";

function row(partial: Partial<CohortRow>): CohortRow {
return {
cohortKey: COHORT,
claimedConfidence: 0.85,
outcomeCorrectness: 1,
state: "witnessed",
...partial,
};
}

describe("bucketCohort", () => {
it("returns no buckets for an empty cohort", () => {
expect(bucketCohort([])).toEqual([]);
});

it("drops buckets with zero witnessed predictions", () => {
// Single witnessed prediction at 0.85 -> bucket 8 only
const out = bucketCohort([row({ claimedConfidence: 0.85, outcomeCorrectness: 1 })]);
expect(out).toHaveLength(1);
expect(out[0].bucketLower).toBeCloseTo(0.8, 5);
expect(out[0].bucketUpper).toBeCloseTo(0.9, 5);
});

it("means claimed and correctness within a bucket", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.82, outcomeCorrectness: 1 }),
row({ claimedConfidence: 0.88, outcomeCorrectness: 0.5 }),
]);
expect(out).toHaveLength(1);
expect(out[0].claimedConfidence).toBeCloseTo(0.85, 5);
expect(out[0].actualCorrectness).toBeCloseTo(0.75, 5);
expect(out[0].predictionCount).toBe(2);
});

it("computes Brier as mean squared error within the bucket", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.8, outcomeCorrectness: 1 }), // (0.8-1)^2 = 0.04
row({ claimedConfidence: 0.85, outcomeCorrectness: 0 }), // (0.85-0)^2 = 0.7225
]);
expect(out[0].brierScore).toBeCloseTo((0.04 + 0.7225) / 2, 5);
});

it("ignores witnessed rows with null correctness", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.85, outcomeCorrectness: null, state: "witnessed" }),
row({ claimedConfidence: 0.85, outcomeCorrectness: 1 }),
]);
expect(out).toHaveLength(1);
expect(out[0].predictionCount).toBe(1);
});

it("ignores emitted-but-not-witnessed rows but does not double-count", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.85, outcomeCorrectness: null, state: "emitted" }),
row({ claimedConfidence: 0.85, outcomeCorrectness: 1 }),
]);
expect(out).toHaveLength(1);
expect(out[0].predictionCount).toBe(1);
});

it("attaches per-cohort orphan_count to every emitted bucket row", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.25, outcomeCorrectness: 1 }),
row({ claimedConfidence: 0.85, outcomeCorrectness: 0.5 }),
row({ claimedConfidence: 0.85, outcomeCorrectness: null, state: "orphaned" }),
row({ claimedConfidence: 0.85, outcomeCorrectness: null, state: "orphaned" }),
]);
expect(out).toHaveLength(2);
for (const snapshot of out) {
expect(snapshot.orphanCount).toBe(2);
}
});

it("places confidence=1.0 in the top bucket", () => {
const out = bucketCohort([row({ claimedConfidence: 1, outcomeCorrectness: 1 })]);
expect(out).toHaveLength(1);
expect(out[0].bucketLower).toBeCloseTo(0.9, 5);
});

it("returns buckets sorted by lower bound", () => {
const out = bucketCohort([
row({ claimedConfidence: 0.95, outcomeCorrectness: 1 }),
row({ claimedConfidence: 0.15, outcomeCorrectness: 0 }),
row({ claimedConfidence: 0.55, outcomeCorrectness: 0.5 }),
]);
expect(out.map((s) => s.bucketLower)).toEqual([
expect.closeTo(0.1, 5),
expect.closeTo(0.5, 5),
expect.closeTo(0.9, 5),
]);
});
});
11 changes: 6 additions & 5 deletions apps/web/wrangler.jsonc
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,12 @@
"vars": {
"CRON_SECRET": "",
},
// Hourly orphan sweep. The Astro adapter does not expose a scheduled()
// entrypoint, so the trigger calls POST /api/uncertainty/internal/orphan-sweep
// via an external scheduler (or a future companion worker) using the
// CRON_SECRET bearer token.
// Hourly orphan sweep + nightly calibration roll. The Astro adapter does
// not expose a scheduled() entrypoint, so triggers call the internal
// endpoints (/api/uncertainty/internal/orphan-sweep at :00 every hour,
// /api/uncertainty/internal/calibration-roll at 03:15 UTC daily) via an
// external scheduler using the CRON_SECRET bearer token.
"triggers": {
"crons": ["0 * * * *"],
"crons": ["0 * * * *", "15 3 * * *"],
},
}
Loading