diff --git a/src/lib/river/server.ts b/src/lib/river/server.ts index 74e4d35..13e1ec3 100644 --- a/src/lib/river/server.ts +++ b/src/lib/river/server.ts @@ -1,8 +1,11 @@ import z from 'zod'; import type { + AgentRouter, CreateAgentRouter, CreateAiSdkRiverAgent, CreateCustomRiverAgent, + ServerHook, + LifecycleHooks, ServerEndpointHandler, ServerSideAgentRunner } from './types.js'; @@ -70,7 +73,7 @@ const createServerSideAgentRunner: ServerSideAgentRunner = (router) => { }; }; -const createServerEndpointHandler: ServerEndpointHandler = (router) => { +const createServerEndpointHandler: ServerEndpointHandler = (router, hooks) => { const runner = createServerSideAgentRunner(router); return { POST: async (event) => { @@ -95,10 +98,26 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => { const error = new RiverError('Invalid body', bodyResult.error); return new Response(JSON.stringify(error), { status: 400 }); } + const { agentId, input } = bodyResult.data; const stream = new ReadableStream({ async start(streamController) { // TODO: make it so that you can do some wait until and piping shit in here + + const defaultErrorHandler = async (err: unknown) => { + if (hooks?.onError) { + const error = + err instanceof RiverError + ? err + : new RiverError(`[RIVER:${agentId}] - Run Failed`, err); + await callServerHook(hooks.onError, { event, agentId, input, error }); + } else { + console.error('Unhandled error during agent run:', err); + } + }; + + await callServerHook(hooks?.beforeAgentRun, { event, agentId, input, abortController }); + try { await runner.runAgent({ agentId: bodyResult.data.agentId, @@ -111,13 +130,18 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => { streamController.close(); } else { streamController.error(error); + + await defaultErrorHandler(error); } } finally { streamController.close(); + await callServerHook(hooks?.afterAgentRun, { event, agentId, input }); } }, cancel(reason) { abortController.abort(reason); + + callServerHook(hooks?.onAbort, { event, agentId, input, reason }); } }); @@ -126,6 +150,30 @@ const createServerEndpointHandler: ServerEndpointHandler = (router) => { }; }; +async function callServerHook( + hook: ServerHook | undefined, + args: T, + globalOnError?: (err: unknown) => Promise +) { + if (!hook) return; + + try { + if (typeof hook === 'function') { + await hook(args); + } else { + await hook.try(args); + } + } catch (err) { + if (hook && typeof hook !== 'function' && hook.catch) { + await hook.catch(err, { ...args }); + } else if (globalOnError) { + await globalOnError(err); + } else { + console.error('Unhandled hook error:', err); + } + } +} + export const RIVER_SERVER = { createAgentRouter, createAiSdkAgent, diff --git a/src/lib/river/types.ts b/src/lib/river/types.ts index 83aefed..6310160 100644 --- a/src/lib/river/types.ts +++ b/src/lib/river/types.ts @@ -77,9 +77,59 @@ type ServerSideAgentRunner = ( }; type ServerEndpointHandler = ( - router: DecoratedAgentRouter + router: DecoratedAgentRouter, + hooks?: LifecycleHooks ) => { POST: (event: RequestEvent) => Promise }; +type ServerHook = + | ((args: T) => Promise | void) + | { + try: (args: T) => Promise | void; + catch: (error: unknown, args: T) => Promise | void; + }; + +type AllBeforeAgentRunArgs = { + [K in keyof T]: { + event: RequestEvent; + agentId: K; + input: InferRiverAgentInputType; + abortController: AbortController; + }; +}[keyof T]; + +type AllAfterAgentRunArgs = { + [K in keyof T]: { + event: RequestEvent; + agentId: K; + input: InferRiverAgentInputType; + }; +}[keyof T]; + +type AllOnAbortArgs = { + [K in keyof T]: { + event: RequestEvent; + agentId: K; + input: InferRiverAgentInputType; + reason?: unknown; + }; +}[keyof T]; + +type AllOnErrorArgs = { + [K in keyof T]: { + event: RequestEvent; + error: RiverError; + agentId: K; + input: InferRiverAgentInputType; + }; +}[keyof T]; + +type LifecycleHooks = { + beforeAgentRun?: ServerHook>; + afterAgentRun?: ServerHook>; + onAbort?: ServerHook>; + onError?: (ctx: AllOnErrorArgs) => void | Promise; +}; + // CLIENT CALLER SECTION type OnCompleteCallback = (data: { totalChunks: number; duration: number }) => void | Promise; type OnErrorCallback = (error: RiverError) => void | Promise; @@ -127,5 +177,7 @@ export type { ServerSideAgentRunner, ServerEndpointHandler, ClientSideCaller, - AgentRouter + AgentRouter, + LifecycleHooks, + ServerHook }; diff --git a/src/routes/examples/river/+server.ts b/src/routes/examples/river/+server.ts index 53ed8bd..5952b10 100644 --- a/src/routes/examples/river/+server.ts +++ b/src/routes/examples/river/+server.ts @@ -1,6 +1,34 @@ -import { RIVER_SERVER } from '$lib/index.js'; +import { RIVER_SERVER, RiverError } from '$lib/index.js'; +import { z } from 'zod/mini'; import { exampleRouter } from './router.js'; // in the real world this should probably be in src/routes/api/river/+server.ts -export const { POST } = RIVER_SERVER.createServerEndpointHandler(exampleRouter); +export const { POST } = RIVER_SERVER.createServerEndpointHandler(exampleRouter, { + beforeAgentRun: { + try: ({ input, agentId }) => { + console.log('[HOOK] - beforeAgentRun', { input }); + if (agentId === 'exampleCustomAgent') { + // With support for type narrowing + console.log('Narrowed Input', input); + } + throw new RiverError('Example Throw'); + }, + catch: (error, { input }) => { + // Allows for per hook error handling + console.error('[HOOK ERROR] - beforeAgentRun', error); + } + }, + afterAgentRun: async ({ event, input, agentId }) => { + if (agentId === 'chatAgent') { + input; + } + console.log('[HOOK] - afterAgentRun', { agentId }); + }, + onAbort: ({ event, input, agentId, reason }) => { + console.log('[HOOK] - onAbort', { reason }); + }, + onError: async ({ event, input, agentId, error }) => { + console.log('[HOOK] - onError', { error }); + } +});