Skip to content
50 changes: 49 additions & 1 deletion src/lib/river/server.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import z from 'zod';
import type {
AgentRouter,
CreateAgentRouter,
CreateAiSdkRiverAgent,
CreateCustomRiverAgent,
ServerHook,
LifecycleHooks,
ServerEndpointHandler,
ServerSideAgentRunner
} from './types.js';
Expand Down Expand Up @@ -70,7 +73,7 @@ const createServerSideAgentRunner: ServerSideAgentRunner = (router) => {
};
};

const createServerEndpointHandler: ServerEndpointHandler = (router) => {
const createServerEndpointHandler: ServerEndpointHandler = (router, hooks) => {
const runner = createServerSideAgentRunner(router);
return {
POST: async (event) => {
Expand All @@ -95,10 +98,26 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => {
const error = new RiverError('Invalid body', bodyResult.error);
return new Response(JSON.stringify(error), { status: 400 });
}
const { agentId, input } = bodyResult.data;

const stream = new ReadableStream<Uint8Array>({
async start(streamController) {
// TODO: make it so that you can do some wait until and piping shit in here

const defaultErrorHandler = async (err: unknown) => {
if (hooks?.onError) {
const error =
err instanceof RiverError
? err
: new RiverError(`[RIVER:${agentId}] - Run Failed`, err);
await callServerHook(hooks.onError, { event, agentId, input, error });
} else {
console.error('Unhandled error during agent run:', err);
}
};

await callServerHook(hooks?.beforeAgentRun, { event, agentId, input, abortController });

try {
await runner.runAgent({
agentId: bodyResult.data.agentId,
Expand All @@ -111,13 +130,18 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => {
streamController.close();
} else {
streamController.error(error);

await defaultErrorHandler(error);
}
} finally {
streamController.close();
await callServerHook(hooks?.afterAgentRun, { event, agentId, input });
}
},
cancel(reason) {
abortController.abort(reason);

callServerHook(hooks?.onAbort, { event, agentId, input, reason });
}
});

Expand All @@ -126,6 +150,30 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => {
};
};

async function callServerHook<T>(
hook: ServerHook<T> | undefined,
args: T,
globalOnError?: (err: unknown) => Promise<void>
) {
if (!hook) return;

try {
if (typeof hook === 'function') {
await hook(args);
} else {
await hook.try(args);
}
} catch (err) {
if (hook && typeof hook !== 'function' && hook.catch) {
await hook.catch(err, { ...args });
} else if (globalOnError) {
await globalOnError(err);
} else {
console.error('Unhandled hook error:', err);
}
}
}

export const RIVER_SERVER = {
createAgentRouter,
createAiSdkAgent,
Expand Down
56 changes: 54 additions & 2 deletions src/lib/river/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,59 @@ type ServerSideAgentRunner = <T extends AgentRouter>(
};

type ServerEndpointHandler = <T extends AgentRouter>(
router: DecoratedAgentRouter<T>
router: DecoratedAgentRouter<T>,
hooks?: LifecycleHooks<T>
) => { POST: (event: RequestEvent) => Promise<Response> };

type ServerHook<T> =
| ((args: T) => Promise<void> | void)
| {
try: (args: T) => Promise<void> | void;
catch: (error: unknown, args: T) => Promise<void> | void;
};

type AllBeforeAgentRunArgs<T extends AgentRouter> = {
[K in keyof T]: {
event: RequestEvent;
agentId: K;
input: InferRiverAgentInputType<T[K]>;
abortController: AbortController;
};
}[keyof T];

type AllAfterAgentRunArgs<T extends AgentRouter> = {
[K in keyof T]: {
event: RequestEvent;
agentId: K;
input: InferRiverAgentInputType<T[K]>;
};
}[keyof T];

type AllOnAbortArgs<T extends AgentRouter> = {
[K in keyof T]: {
event: RequestEvent;
agentId: K;
input: InferRiverAgentInputType<T[K]>;
reason?: unknown;
};
}[keyof T];

type AllOnErrorArgs<T extends AgentRouter> = {
[K in keyof T]: {
event: RequestEvent;
error: RiverError;
agentId: K;
input: InferRiverAgentInputType<T[K]>;
};
}[keyof T];

type LifecycleHooks<T extends AgentRouter> = {
beforeAgentRun?: ServerHook<AllBeforeAgentRunArgs<T>>;
afterAgentRun?: ServerHook<AllAfterAgentRunArgs<T>>;
onAbort?: ServerHook<AllOnAbortArgs<T>>;
onError?: (ctx: AllOnErrorArgs<T>) => void | Promise<void>;
};

// CLIENT CALLER SECTION
type OnCompleteCallback = (data: { totalChunks: number; duration: number }) => void | Promise<void>;
type OnErrorCallback = (error: RiverError) => void | Promise<void>;
Expand Down Expand Up @@ -127,5 +177,7 @@ export type {
ServerSideAgentRunner,
ServerEndpointHandler,
ClientSideCaller,
AgentRouter
AgentRouter,
LifecycleHooks,
ServerHook
};
32 changes: 30 additions & 2 deletions src/routes/examples/river/+server.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,34 @@
import { RIVER_SERVER } from '$lib/index.js';
import { RIVER_SERVER, RiverError } from '$lib/index.js';
import { z } from 'zod/mini';
import { exampleRouter } from './router.js';

// in the real world this should probably be in src/routes/api/river/+server.ts

export const { POST } = RIVER_SERVER.createServerEndpointHandler(exampleRouter);
export const { POST } = RIVER_SERVER.createServerEndpointHandler(exampleRouter, {
beforeAgentRun: {
try: ({ input, agentId }) => {
console.log('[HOOK] - beforeAgentRun', { input });
if (agentId === 'exampleCustomAgent') {
// With support for type narrowing
console.log('Narrowed Input', input);
}
throw new RiverError('Example Throw');
},
catch: (error, { input }) => {
// Allows for per hook error handling
console.error('[HOOK ERROR] - beforeAgentRun', error);
}
},
afterAgentRun: async ({ event, input, agentId }) => {
if (agentId === 'chatAgent') {
input;
}
console.log('[HOOK] - afterAgentRun', { agentId });
},
onAbort: ({ event, input, agentId, reason }) => {
console.log('[HOOK] - onAbort', { reason });
},
onError: async ({ event, input, agentId, error }) => {
console.log('[HOOK] - onError', { error });
}
});