From cac4ee99f2736ab995e71fe1e1b7e66057110a9e Mon Sep 17 00:00:00 2001
From: kallebysantos
Date: Sat, 10 May 2025 19:40:54 +0100
Subject: [PATCH 01/14] feat: creating llm session abstraction

- LLM Session is a wrapper to handle LLM inference based on the selected
  provider
---
 ext/ai/js/llm/llm_session.ts | 67 ++++++++++++++++++++++++++++++++++++
 ext/ai/lib.rs                |  3 +-
 2 files changed, 69 insertions(+), 1 deletion(-)
 create mode 100644 ext/ai/js/llm/llm_session.ts

diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts
new file mode 100644
index 000000000..7c9ce60fe
--- /dev/null
+++ b/ext/ai/js/llm/llm_session.ts
@@ -0,0 +1,67 @@
+// @ts-ignore deno_core environment
+const core = globalThis.Deno.core;
+
+export type LLMRunInput = {
+  /**
+   * Stream response from model. Applies only for LLMs like `mistral` (default: false)
+   */
+  stream?: boolean;
+
+  /**
+   * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60)
+   */
+  timeout?: number;
+
+  prompt: string;
+
+  signal?: AbortSignal;
+};
+
+export interface ILLMProviderOptions {
+  inferenceAPIHost: string;
+  model: string;
+}
+
+export interface ILLMProvider {
+  // TODO:(kallebysantos) remove 'any'
+  getStream(prompt: string, signal: AbortSignal): Promise<AsyncGenerator<any>>;
+  getText(prompt: string, signal: AbortSignal): Promise<any>;
+}
+
+export class LLMSession {
+  #inner: ILLMProvider;
+
+  constructor(provider: ILLMProvider) {
+    this.#inner = provider;
+  }
+
+  static fromProvider(name: string, opts: ILLMProviderOptions) {
+    const ProviderType = providers[name];
+    if (!ProviderType) throw new Error('invalid provider');
+
+    const provider = new ProviderType(opts);
+
+    return new LLMSession(provider);
+  }
+
+  run(
+    opts: LLMRunInput,
+  ): Promise<AsyncGenerator<any>> | Promise<any> {
+    const isStream = opts.stream ?? false;
+
+    const timeoutSeconds = typeof opts.timeout === 'number' ?
opts.timeout : 60; + const timeoutMs = timeoutSeconds * 1000; + + const timeoutSignal = AbortSignal.timeout(timeoutMs); + const abortSignals = [opts.signal, timeoutSignal] + .filter((it) => it instanceof AbortSignal); + const signal = AbortSignal.any(abortSignals); + + if (isStream) { + return this.#inner.getStream(opts.prompt, signal); + } + + return this.#inner.getText(opts.prompt, signal); + } +} + diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 224b0450f..dd8bd33fb 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -55,7 +55,8 @@ deno_core::extension!( "util/event_stream_parser.mjs", "util/event_source_stream.mjs", "onnxruntime/onnx.js", - "onnxruntime/cache_adapter.js" + "onnxruntime/cache_adapter.js", + "llm/llm_session.ts", ] ); From f311d25f1a08dd4533ee232631085a17b670f8d2 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sat, 10 May 2025 19:55:08 +0100 Subject: [PATCH 02/14] stamp: moving streaming utils to `llm` folder - Extracting json parsers to a separated file - Moving LLM stream related code to a separated folder --- ext/ai/js/ai.js | 84 +------------------ ext/ai/js/llm/utils/event_source_stream.mjs | 33 ++++++++ ext/ai/js/llm/utils/event_stream_parser.mjs | 92 +++++++++++++++++++++ ext/ai/js/llm/utils/json_parser.ts | 82 ++++++++++++++++++ ext/ai/lib.rs | 5 +- 5 files changed, 212 insertions(+), 84 deletions(-) create mode 100644 ext/ai/js/llm/utils/event_source_stream.mjs create mode 100644 ext/ai/js/llm/utils/event_stream_parser.mjs create mode 100644 ext/ai/js/llm/utils/json_parser.ts diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js index c2a306927..8bb473678 100644 --- a/ext/ai/js/ai.js +++ b/ext/ai/js/ai.js @@ -1,88 +1,8 @@ -import "ext:ai/onnxruntime/onnx.js"; -import EventSourceStream from "ext:ai/util/event_source_stream.mjs"; +import 'ext:ai/onnxruntime/onnx.js'; +import { parseJSON, parseJSONOverEventStream } from './llm/utils/json_parser.ts'; const core = globalThis.Deno.core; -/** - * @param {ReadableStream p !== "")) { - try { - yield JSON.parse(part); - } catch (error) { - yield { error }; - } - } -}; - -/** - * @param {ReadableStream} */ - const reader = decoder.readable.getReader(); - - while (true) { - try { - if (signal.aborted) { - reader.cancel(signal.reason); - reader.releaseLock(); - return { error: signal.reason }; - } - - const { done, value } = await reader.read(); - - if (done) { - break; - } - - yield JSON.parse(value.data); - } catch (error) { - yield { error }; - } - } -}; - class Session { model; init; diff --git a/ext/ai/js/llm/utils/event_source_stream.mjs b/ext/ai/js/llm/utils/event_source_stream.mjs new file mode 100644 index 000000000..fa355da05 --- /dev/null +++ b/ext/ai/js/llm/utils/event_source_stream.mjs @@ -0,0 +1,33 @@ +import EventStreamParser from './event_stream_parser.mjs'; +/** + * A Web stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API. + * Implements the TransformStream interface, and can be used with the Streams API as such. + */ +class EventSourceStream { + constructor() { + // Two important things to note here: + // 1. The SSE spec allows for an optional UTF-8 BOM. + // 2. We have to use a *streaming* decoder, in case two adjacent data chunks are split up in the middle of a + // multibyte Unicode character. Trying to parse the two separately would result in data corruption. 
+ const decoder = new TextDecoderStream('utf-8'); + let parser; + const sseStream = new TransformStream({ + start(controller) { + parser = new EventStreamParser((data, eventType, lastEventId) => { + controller.enqueue( + new MessageEvent(eventType, { data, lastEventId }), + ); + }); + }, + transform(chunk) { + parser.push(chunk); + }, + }); + + decoder.readable.pipeThrough(sseStream); + + this.readable = sseStream.readable; + this.writable = decoder.writable; + } +} +export default EventSourceStream; diff --git a/ext/ai/js/llm/utils/event_stream_parser.mjs b/ext/ai/js/llm/utils/event_stream_parser.mjs new file mode 100644 index 000000000..263229a60 --- /dev/null +++ b/ext/ai/js/llm/utils/event_stream_parser.mjs @@ -0,0 +1,92 @@ +// https://github.com/valadaptive/server-sent-stream + +/** + * A parser for the server-sent events stream format. + * + * Note that this parser does not handle text decoding! To do it correctly, use a streaming text decoder, since the + * stream could be split up mid-Unicode character, and decoding each chunk at once could lead to incorrect results. + * + * This parser is used by streaming chunks in using the {@link push} method, and then calling the {@link end} method + * when the stream has ended. + */ +class EventStreamParser { + /** + * Construct a new parser for a single stream. + * @param onEvent A callback which will be called for each new event parsed. The parameters in order are the + * event data, the event type, and the last seen event ID. This may be called none, once, or many times per push() + * call, and may be called from the end() call. + */ + constructor(onEvent) { + this.streamBuffer = ""; + this.lastEventId = ""; + this.onEvent = onEvent; + } + /** + * Process a single incoming chunk of the event stream. + */ + _processChunk() { + // Events are separated by two newlines + const events = this.streamBuffer.split(/\r\n\r\n|\r\r|\n\n/g); + if (events.length === 0) { + return; + } + // The leftover text to remain in the buffer is whatever doesn't have two newlines after it. If the buffer ended + // with two newlines, this will be an empty string. + this.streamBuffer = events.pop(); + for (const eventChunk of events) { + let eventType = ""; + // Split up by single newlines. + const lines = eventChunk.split(/\n|\r|\r\n/g); + let eventData = ""; + for (const line of lines) { + const lineMatch = /([^:]+)(?:: ?(.*))?/.exec(line); + if (lineMatch) { + const field = lineMatch[1]; + const value = lineMatch[2] || ""; + switch (field) { + case "event": + eventType = value; + break; + case "data": + eventData += value; + eventData += "\n"; + break; + case "id": + // The ID field cannot contain null, per the spec + if (!value.includes("\0")) { + this.lastEventId = value; + } + break; + // We do nothing for the `delay` type, and other types are explicitly ignored + } + } + } + // https://html.spec.whatwg.org/multipage/server-sent-events.html#dispatchMessage + // Skip the event if the data buffer is the empty string. + if (eventData === "") { + continue; + } + if (eventData[eventData.length - 1] === "\n") { + eventData = eventData.slice(0, -1); + } + // Trim the *last* trailing newline only. + this.onEvent(eventData, eventType || "message", this.lastEventId); + } + } + /** + * Push a new chunk of data to the parser. This may cause the {@link onEvent} callback to be called, possibly + * multiple times depending on the number of events contained within the chunk. + * @param chunk The incoming chunk of data. 
+ */ + push(chunk) { + this.streamBuffer += chunk; + this._processChunk(); + } + /** + * Indicate that the stream has ended. + */ + end() { + // This is a no-op + } +} +export default EventStreamParser; diff --git a/ext/ai/js/llm/utils/json_parser.ts b/ext/ai/js/llm/utils/json_parser.ts new file mode 100644 index 000000000..636ab4292 --- /dev/null +++ b/ext/ai/js/llm/utils/json_parser.ts @@ -0,0 +1,82 @@ +import EventSourceStream from './event_source_stream.mjs'; + +// Adapted from https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L262 +// TODO:(kallebysantos) need to simplify it +export async function* parseJSON( + itr: ReadableStream, + signal: AbortSignal, +) { + let buffer = ''; + + const decoder = new TextDecoder('utf-8'); + const reader = itr.getReader(); + + while (true) { + try { + if (signal.aborted) { + reader.cancel(signal.reason); + reader.releaseLock(); + return { error: signal.reason }; + } + + const { done, value } = await reader.read(); + + if (done) { + break; + } + + buffer += decoder.decode(value); + + const parts = buffer.split('\n'); + + buffer = parts.pop() ?? ''; + + for (const part of parts) { + yield JSON.parse(part) as T; + } + } catch (error) { + yield { error }; + } + } + + for (const part of buffer.split('\n').filter((p) => p !== '')) { + try { + yield JSON.parse(part) as T; + } catch (error) { + yield { error }; + } + } +} + +// TODO:(kallebysantos) need to simplify it +export async function* parseJSONOverEventStream( + itr: ReadableStream, + signal: AbortSignal, +) { + const decoder = new EventSourceStream(); + + itr.pipeThrough(decoder); + + const reader: ReadableStreamDefaultReader = decoder.readable + .getReader(); + + while (true) { + try { + if (signal.aborted) { + reader.cancel(signal.reason); + reader.releaseLock(); + return { error: signal.reason }; + } + + const { done, value } = await reader.read(); + + if (done) { + break; + } + + yield JSON.parse(value.data); + } catch (error) { + yield { error }; + } + } +} diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index dd8bd33fb..2f78cc90f 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -52,11 +52,12 @@ deno_core::extension!( esm = [ dir "js", "ai.js", - "util/event_stream_parser.mjs", - "util/event_source_stream.mjs", "onnxruntime/onnx.js", "onnxruntime/cache_adapter.js", "llm/llm_session.ts", + "llm/utils/json_parser.ts", + "llm/utils/event_stream_parser.mjs", + "llm/utils/event_source_stream.mjs", ] ); From 0bf30ca4cbc6674700b40690ff7b5c498fea156d Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sat, 10 May 2025 20:18:04 +0100 Subject: [PATCH 03/14] feat: implementing Ollama LLM provider - Applying LLM provider interfaces to implement the Ollama provider --- ext/ai/js/ai.js | 78 ++++++++++---------- ext/ai/js/llm/llm_session.ts | 11 ++- ext/ai/js/llm/providers/ollama.ts | 115 ++++++++++++++++++++++++++++++ ext/ai/lib.rs | 1 + 4 files changed, 164 insertions(+), 41 deletions(-) create mode 100644 ext/ai/js/llm/providers/ollama.ts diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js index 8bb473678..f8f3da8e2 100644 --- a/ext/ai/js/ai.js +++ b/ext/ai/js/ai.js @@ -1,5 +1,6 @@ import 'ext:ai/onnxruntime/onnx.js'; import { parseJSON, parseJSONOverEventStream } from './llm/utils/json_parser.ts'; +import { LLMSession } from './llm/llm_session.ts'; const core = globalThis.Deno.core; @@ -13,10 +14,10 @@ class Session { this.model = model; this.is_ext_inference_api = false; - if (model === "gte-small") { + if (model === 'gte-small') { this.init = 
core.ops.op_ai_init_model(model); } else { - this.inferenceAPIHost = core.ops.op_get_env("AI_INFERENCE_API_HOST"); + this.inferenceAPIHost = core.ops.op_get_env('AI_INFERENCE_API_HOST'); this.is_ext_inference_api = !!this.inferenceAPIHost; // only enable external inference API if env variable is set } } @@ -26,16 +27,30 @@ class Session { if (this.is_ext_inference_api) { const stream = opts.stream ?? false; + /** @type {'ollama' | 'openaicompatible'} */ + const mode = opts.mode ?? 'ollama'; + + if (mode === 'ollama') { + // Using the new LLMSession API + const llmSession = LLMSession.fromProvider('ollama', { + inferenceAPIHost: this.inferenceAPIHost, + model: this.model, + }); + + return await llmSession.run({ + prompt, + stream, + signal: opts.signal, + timeout: opts.timeout, + }); + } + // default timeout 60s - const timeout = typeof opts.timeout === "number" ? opts.timeout : 60; + const timeout = typeof opts.timeout === 'number' ? opts.timeout : 60; const timeoutMs = timeout * 1000; - /** @type {'ollama' | 'openaicompatible'} */ - const mode = opts.mode ?? "ollama"; - switch (mode) { - case "ollama": - case "openaicompatible": + case 'openaicompatible': break; default: @@ -48,15 +63,15 @@ class Session { const signal = AbortSignal.any(signals); - const path = mode === "ollama" ? "/api/generate" : "/v1/chat/completions"; - const body = mode === "ollama" ? { prompt } : prompt; + const path = '/v1/chat/completions'; + const body = prompt; const res = await fetch( new URL(path, this.inferenceAPIHost), { - method: "POST", + method: 'POST', headers: { - "Content-Type": "application/json", + 'Content-Type': 'application/json', }, body: JSON.stringify({ model: this.model, @@ -74,20 +89,16 @@ class Session { } if (!res.body) { - throw new Error("Missing body"); + throw new Error('Missing body'); } - const parseGenFn = mode === "ollama" - ? parseJSON - : stream === true - ? parseJSONOverEventStream - : parseJSON; + const parseGenFn = stream === true ? parseJSONOverEventStream : parseJSON; const itr = parseGenFn(res.body, signal); if (stream) { return (async function* () { for await (const message of itr) { - if ("error" in message) { + if ('error' in message) { if (message.error instanceof Error) { throw message.error; } else { @@ -98,20 +109,12 @@ class Session { yield message; switch (mode) { - case "ollama": { - if (message.done) { - return; - } - - break; - } - - case "openaicompatible": { + case 'openaicompatible': { const finishReason = message.choices[0].finish_reason; if (finishReason) { - if (finishReason !== "stop") { - throw new Error("Expected a completed response."); + if (finishReason !== 'stop') { + throw new Error('Expected a completed response.'); } return; @@ -121,18 +124,18 @@ class Session { } default: - throw new Error("unreachable"); + throw new Error('unreachable'); } } throw new Error( - "Did not receive done or success response in stream.", + 'Did not receive done or success response in stream.', ); })(); } else { const message = await itr.next(); - if (message.value && "error" in message.value) { + if (message.value && 'error' in message.value) { const error = message.value.error; if (error instanceof Error) { @@ -142,12 +145,10 @@ class Session { } } - const finish = mode === "ollama" - ? 
message.value.done - : message.value.choices[0].finish_reason === "stop"; + const finish = message.value.choices[0].finish_reason === 'stop'; if (finish !== true) { - throw new Error("Expected a completed response."); + throw new Error('Expected a completed response.'); } return message.value; @@ -172,8 +173,7 @@ class Session { } const MAIN_WORKER_API = { - tryCleanupUnusedSession: () => - /* async */ core.ops.op_ai_try_cleanup_unused_session(), + tryCleanupUnusedSession: () => /* async */ core.ops.op_ai_try_cleanup_unused_session(), }; const USER_WORKER_API = { diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index 7c9ce60fe..c6bfcf283 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -1,3 +1,5 @@ +import { OllamaLLMSession } from './providers/ollama.ts'; + // @ts-ignore deno_core environment const core = globalThis.Deno.core; @@ -28,6 +30,12 @@ export interface ILLMProvider { getText(prompt: string, signal: AbortSignal): Promise; } +export const providers = { + 'ollama': OllamaLLMSession, +} satisfies Record ILLMProvider>; + +export type LLMProviderName = keyof typeof providers; + export class LLMSession { #inner: ILLMProvider; @@ -35,7 +43,7 @@ export class LLMSession { this.#inner = provider; } - static fromProvider(name: string, opts: ILLMProviderOptions) { + static fromProvider(name: LLMProviderName, opts: ILLMProviderOptions) { const ProviderType = providers[name]; if (!ProviderType) throw new Error('invalid provider'); @@ -64,4 +72,3 @@ export class LLMSession { return this.#inner.getText(opts.prompt, signal); } } - diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts new file mode 100644 index 000000000..a901d5725 --- /dev/null +++ b/ext/ai/js/llm/providers/ollama.ts @@ -0,0 +1,115 @@ +import { ILLMProvider, ILLMProviderOptions } from '../llm_session.ts'; +import { parseJSON } from '../utils/json_parser.ts'; + +export type OllamaProviderOptions = ILLMProviderOptions; + +export type OllamaMessage = { + model: string; + created_at: Date; + response: string; + done: boolean; + context: number[]; + total_duration: number; + load_duration: number; + prompt_eval_count: number; + prompt_eval_duration: number; + eval_count: number; + eval_duration: number; +}; + +export class OllamaLLMSession implements ILLMProvider { + opts: OllamaProviderOptions; + + constructor(opts: OllamaProviderOptions) { + this.opts = opts; + } + + // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26 + async getStream( + prompt: string, + signal: AbortSignal, + ): Promise> { + const generator = await this.generate(prompt, signal, true); + + const stream = async function* () { + for await (const message of generator) { + if ('error' in message) { + if (message.error instanceof Error) { + throw message.error; + } else { + throw new Error(message.error as string); + } + } + + yield message; + if (message.done) { + return; + } + } + + throw new Error( + 'Did not receive done or success response in stream.', + ); + }; + + return stream(); + } + + async getText(prompt: string, signal: AbortSignal): Promise { + const generator = await this.generate(prompt, signal); + + const message = await generator.next(); + + if (message.value && 'error' in message.value) { + const error = message.value.error; + + if (error instanceof Error) { + throw error; + } else { + throw new Error(error); + } + } + + const response = message.value; + + if (!response?.done) { + throw new Error('Expected a 
completed response.'); + } + + return response; + } + + private async generate( + prompt: string, + signal: AbortSignal, + stream: boolean = false, + ) { + const res = await fetch( + new URL('/api/generate', this.opts.inferenceAPIHost), + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + model: this.opts.model, + stream, + prompt, + }), + signal, + }, + ); + + if (!res.ok) { + throw new Error( + `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`, + ); + } + + if (!res.body) { + throw new Error('Missing body'); + } + + return parseJSON(res.body, signal); + } +} diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 2f78cc90f..3957e2c6d 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -55,6 +55,7 @@ deno_core::extension!( "onnxruntime/onnx.js", "onnxruntime/cache_adapter.js", "llm/llm_session.ts", + "llm/providers/ollama.ts", "llm/utils/json_parser.ts", "llm/utils/event_stream_parser.mjs", "llm/utils/event_source_stream.mjs", From fedda2817774c1c8e68208ba38ed5b79803104c4 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sun, 11 May 2025 13:47:28 +0100 Subject: [PATCH 04/14] feat: implementing 'OpenAI compatible' provider - Applying LLM provider interfaces to implement the 'openaicompatible' mode --- ext/ai/js/ai.js | 145 ++++------------------- ext/ai/js/llm/llm_session.ts | 19 ++- ext/ai/js/llm/providers/ollama.ts | 42 +++---- ext/ai/js/llm/providers/openai.ts | 191 ++++++++++++++++++++++++++++++ ext/ai/lib.rs | 1 + 5 files changed, 248 insertions(+), 150 deletions(-) create mode 100644 ext/ai/js/llm/providers/openai.ts diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js index f8f3da8e2..55794890e 100644 --- a/ext/ai/js/ai.js +++ b/ext/ai/js/ai.js @@ -1,6 +1,5 @@ import 'ext:ai/onnxruntime/onnx.js'; -import { parseJSON, parseJSONOverEventStream } from './llm/utils/json_parser.ts'; -import { LLMSession } from './llm/llm_session.ts'; +import { LLMSession, providers } from './llm/llm_session.ts'; const core = globalThis.Deno.core; @@ -9,11 +8,15 @@ class Session { init; is_ext_inference_api; inferenceAPIHost; + extraOpts; - constructor(model) { + // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run + constructor(model, opts = {}) { this.model = model; this.is_ext_inference_api = false; + this.extraOpts = opts; + // TODO:(kallebysantos) do we still need gte-small? if (model === 'gte-small') { this.init = core.ops.op_ai_init_model(model); } else { @@ -28,131 +31,25 @@ class Session { const stream = opts.stream ?? false; /** @type {'ollama' | 'openaicompatible'} */ + // TODO:(kallebysantos) get mode from 'new' and apply type checking based on that const mode = opts.mode ?? 'ollama'; - if (mode === 'ollama') { - // Using the new LLMSession API - const llmSession = LLMSession.fromProvider('ollama', { - inferenceAPIHost: this.inferenceAPIHost, - model: this.model, - }); - - return await llmSession.run({ - prompt, - stream, - signal: opts.signal, - timeout: opts.timeout, - }); - } - - // default timeout 60s - const timeout = typeof opts.timeout === 'number' ? 
opts.timeout : 60; - const timeoutMs = timeout * 1000; - - switch (mode) { - case 'openaicompatible': - break; - - default: - throw new TypeError(`invalid mode: ${mode}`); - } - - const timeoutSignal = AbortSignal.timeout(timeoutMs); - const signals = [opts.signal, timeoutSignal] - .filter((it) => it instanceof AbortSignal); - - const signal = AbortSignal.any(signals); - - const path = '/v1/chat/completions'; - const body = prompt; - - const res = await fetch( - new URL(path, this.inferenceAPIHost), - { - method: 'POST', - headers: { - 'Content-Type': 'application/json', - }, - body: JSON.stringify({ - model: this.model, - stream, - ...body, - }), - }, - { signal }, - ); - - if (!res.ok) { - throw new Error( - `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`, - ); + if (!Object.keys(providers).includes(mode)) { + throw new TypeError(`invalid mode: ${mode}`); } - if (!res.body) { - throw new Error('Missing body'); - } - - const parseGenFn = stream === true ? parseJSONOverEventStream : parseJSON; - const itr = parseGenFn(res.body, signal); - - if (stream) { - return (async function* () { - for await (const message of itr) { - if ('error' in message) { - if (message.error instanceof Error) { - throw message.error; - } else { - throw new Error(message.error); - } - } - - yield message; - - switch (mode) { - case 'openaicompatible': { - const finishReason = message.choices[0].finish_reason; - - if (finishReason) { - if (finishReason !== 'stop') { - throw new Error('Expected a completed response.'); - } - - return; - } - - break; - } - - default: - throw new Error('unreachable'); - } - } - - throw new Error( - 'Did not receive done or success response in stream.', - ); - })(); - } else { - const message = await itr.next(); - - if (message.value && 'error' in message.value) { - const error = message.value.error; - - if (error instanceof Error) { - throw error; - } else { - throw new Error(error); - } - } - - const finish = message.value.choices[0].finish_reason === 'stop'; - - if (finish !== true) { - throw new Error('Expected a completed response.'); - } - - return message.value; - } + const llmSession = LLMSession.fromProvider(mode, { + inferenceAPIHost: this.inferenceAPIHost, + model: this.model, + ...this.extraOpts, // allows custom provider initialization like 'apiKey' + }); + + return await llmSession.run({ + prompt, + stream, + signal: opts.signal, + timeout: opts.timeout, + }); } if (this.init) { diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index c6bfcf283..475b99ddf 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -1,4 +1,5 @@ import { OllamaLLMSession } from './providers/ollama.ts'; +import { OpenAILLMSession } from './providers/openai.ts'; // @ts-ignore deno_core environment const core = globalThis.Deno.core; @@ -20,18 +21,25 @@ export type LLMRunInput = { }; export interface ILLMProviderOptions { - inferenceAPIHost: string; model: string; + inferenceAPIHost: string; +} + +export interface ILLMProviderInput { + prompt: string | object; + signal: AbortSignal; } export interface ILLMProvider { // TODO:(kallebysantos) remove 'any' - getStream(prompt: string, signal: AbortSignal): Promise>; - getText(prompt: string, signal: AbortSignal): Promise; + // TODO: (kallebysantos) standardised output format + getStream(input: ILLMProviderInput): Promise>; + getText(input: ILLMProviderInput): Promise; } export const providers = { 'ollama': OllamaLLMSession, + 'openaicompatible': OpenAILLMSession, } satisfies 
Record ILLMProvider>; export type LLMProviderName = keyof typeof providers; @@ -65,10 +73,11 @@ export class LLMSession { .filter((it) => it instanceof AbortSignal); const signal = AbortSignal.any(abortSignals); + const llmInput: ILLMProviderInput = { prompt: opts.prompt, signal }; if (isStream) { - return this.#inner.getStream(opts.prompt, signal); + return this.#inner.getStream(llmInput); } - return this.#inner.getText(opts.prompt, signal); + return this.#inner.getText(llmInput); } } diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index a901d5725..f714b8762 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -1,7 +1,10 @@ -import { ILLMProvider, ILLMProviderOptions } from '../llm_session.ts'; +import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts'; import { parseJSON } from '../utils/json_parser.ts'; export type OllamaProviderOptions = ILLMProviderOptions; +export type OllamaProviderInput = ILLMProviderInput & { + prompt: string; +}; export type OllamaMessage = { model: string; @@ -26,10 +29,13 @@ export class OllamaLLMSession implements ILLMProvider { // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26 async getStream( - prompt: string, - signal: AbortSignal, + { prompt, signal }: OllamaProviderInput, ): Promise> { - const generator = await this.generate(prompt, signal, true); + const generator = await this.generate( + prompt, + signal, + true, + ) as AsyncGenerator; const stream = async function* () { for await (const message of generator) { @@ -55,22 +61,10 @@ export class OllamaLLMSession implements ILLMProvider { return stream(); } - async getText(prompt: string, signal: AbortSignal): Promise { - const generator = await this.generate(prompt, signal); - - const message = await generator.next(); - - if (message.value && 'error' in message.value) { - const error = message.value.error; - - if (error instanceof Error) { - throw error; - } else { - throw new Error(error); - } - } - - const response = message.value; + async getText( + { prompt, signal }: OllamaProviderInput, + ): Promise { + const response = await this.generate(prompt, signal) as OllamaMessage; if (!response?.done) { throw new Error('Expected a completed response.'); @@ -110,6 +104,12 @@ export class OllamaLLMSession implements ILLMProvider { throw new Error('Missing body'); } - return parseJSON(res.body, signal); + if (stream) { + return parseJSON(res.body, signal); + } + + const result: OllamaMessage = await res.json(); + + return result; } } diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts new file mode 100644 index 000000000..c36ef8355 --- /dev/null +++ b/ext/ai/js/llm/providers/openai.ts @@ -0,0 +1,191 @@ +import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts'; +import { parseJSONOverEventStream } from '../utils/json_parser.ts'; + +export type OpenAIProviderOptions = ILLMProviderOptions & { + apiKey?: string; +}; + +// TODO:(kallebysantos) need to double check theses AI generated types +export type OpenAIRequest = { + model: string; + messages: { + role: 'system' | 'user' | 'assistant' | 'tool'; + content: string; + name?: string; + tool_call_id?: string; + function_call?: { + name: string; + arguments: string; + }; + }[]; + temperature?: number; + top_p?: number; + n?: number; + stream?: boolean; + stop?: string | string[]; + max_tokens?: number; + presence_penalty?: number; + 
frequency_penalty?: number; + logit_bias?: { [token: string]: number }; + user?: string; + tools?: { + type: 'function'; + function: { + name: string; + description?: string; + parameters: any; // Can be refined based on your function definition + }; + }[]; + tool_choice?: 'none' | 'auto' | { + type: 'function'; + function: { name: string }; + }; +}; + +export type OpenAIResponse = { + id: string; + object: 'chat.completion'; + created: number; + model: string; + system_fingerprint?: string; + choices: { + index: number; + message: { + role: 'assistant' | 'user' | 'system' | 'tool'; + content: string | null; + function_call?: { + name: string; + arguments: string; + }; + tool_calls?: { + id: string; + type: 'function'; + function: { + name: string; + arguments: string; + }; + }[]; + }; + finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | null; + }[]; + usage?: { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + }; +}; + +export type OpenAIInput = Omit; + +export type OpenAIProviderInput = Omit & { + // Use Open AI defs + prompt: OpenAIInput; +}; + +export class OpenAILLMSession implements ILLMProvider { + opts: OpenAIProviderOptions; + + constructor(opts: OpenAIProviderOptions) { + this.opts = opts; + } + + async getStream( + { prompt, signal }: OpenAIProviderInput, + ): Promise> { + const generator = await this.generate( + prompt, + signal, + true, + ) as AsyncGenerator; // TODO:(kallebysantos) remove any + + const stream = async function* () { + for await (const message of generator) { + // TODO:(kallebysantos) Simplify duplicated code for stream error checking + if ('error' in message) { + if (message.error instanceof Error) { + throw message.error; + } else { + throw new Error(message.error as string); + } + } + + yield message; + const finishReason = message.choices[0].finish_reason; + + if (finishReason) { + if (finishReason !== 'stop') { + throw new Error('Expected a completed response.'); + } + + return; + } + } + + throw new Error( + 'Did not receive done or success response in stream.', + ); + }; + + return stream(); + } + + async getText( + { prompt, signal }: OpenAIProviderInput, + ): Promise { + const response = await this.generate( + prompt, + signal, + ) as OpenAIResponse; + + const finishReason = response.choices[0].finish_reason; + + if (finishReason !== 'stop') { + throw new Error('Expected a completed response.'); + } + + return response; + } + + private async generate( + input: OpenAIInput, + signal: AbortSignal, + stream: boolean = false, + ) { + const res = await fetch( + new URL('/v1/chat/completions', this.opts.inferenceAPIHost), + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + 'Authorization': `Bearer ${this.opts.apiKey}`, + }, + body: JSON.stringify( + { + ...input, + model: this.opts.model, + stream, + } satisfies OpenAIRequest, + ), + signal, + }, + ); + + if (!res.ok) { + throw new Error( + `Failed to fetch inference API host. 
Status ${res.status}: ${res.statusText}`, + ); + } + + if (!res.body) { + throw new Error('Missing body'); + } + + if (stream) { + return parseJSONOverEventStream(res.body, signal); + } + + const result: OpenAIResponse = await res.json(); + + return result; + } +} diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 3957e2c6d..1d3edfb89 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -56,6 +56,7 @@ deno_core::extension!( "onnxruntime/cache_adapter.js", "llm/llm_session.ts", "llm/providers/ollama.ts", + "llm/providers/openai.ts", "llm/utils/json_parser.ts", "llm/utils/event_stream_parser.mjs", "llm/utils/event_source_stream.mjs", From decc40f81131bc3e61e8bef1a93257cefcc1e5a4 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Mon, 12 May 2025 16:32:32 +0100 Subject: [PATCH 05/14] break: improving typescript support and refactoring the API - Improving Typescript support for dynamic suggestion based on the selected Session type. - Break: Now LLM models must be defined inside `options` argument, it allows a better typescript checking as well makes easier to extend the API. - There's no need to check if `inferenceHost` env var is defined, since we can now switch between different LLM providers. Instead, we can enable LLM support if the given type is an allowed provider. --- ext/ai/js/ai.d.ts | 21 +++++ ext/ai/js/ai.js | 80 ----------------- ext/ai/js/ai.ts | 141 ++++++++++++++++++++++++++++++ ext/ai/js/llm/llm_session.ts | 65 ++++++++++---- ext/ai/js/llm/providers/ollama.ts | 29 +++--- ext/ai/js/llm/providers/openai.ts | 131 +++++++++++++++------------ ext/ai/lib.rs | 4 +- ext/runtime/js/namespaces.js | 2 +- types/global.d.ts | 105 +++++++++++----------- 9 files changed, 362 insertions(+), 216 deletions(-) create mode 100644 ext/ai/js/ai.d.ts delete mode 100644 ext/ai/js/ai.js create mode 100644 ext/ai/js/ai.ts diff --git a/ext/ai/js/ai.d.ts b/ext/ai/js/ai.d.ts new file mode 100644 index 000000000..81a2ec00b --- /dev/null +++ b/ext/ai/js/ai.d.ts @@ -0,0 +1,21 @@ +import { Session } from "./ai.ts"; +import { LLMSessionRunInputOptions } from "./llm/llm_session.ts"; +import { + OllamaProviderInput, + OllamaProviderOptions, +} from "./llm/providers/ollama.ts"; +import { + OpenAIProviderInput, + OpenAIProviderOptions, +} from "./llm/providers/openai.ts"; + +export namespace ai { + export { Session }; + export { + LLMSessionRunInputOptions as LLMRunOptions, + OllamaProviderInput as OllamaInput, + OllamaProviderOptions as OllamaOptions, + OpenAIProviderInput as OpenAICompatibleInput, + OpenAIProviderOptions as OpenAICompatibleOptions, + }; +} diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js deleted file mode 100644 index 55794890e..000000000 --- a/ext/ai/js/ai.js +++ /dev/null @@ -1,80 +0,0 @@ -import 'ext:ai/onnxruntime/onnx.js'; -import { LLMSession, providers } from './llm/llm_session.ts'; - -const core = globalThis.Deno.core; - -class Session { - model; - init; - is_ext_inference_api; - inferenceAPIHost; - extraOpts; - - // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run - constructor(model, opts = {}) { - this.model = model; - this.is_ext_inference_api = false; - this.extraOpts = opts; - - // TODO:(kallebysantos) do we still need gte-small? 
- if (model === 'gte-small') { - this.init = core.ops.op_ai_init_model(model); - } else { - this.inferenceAPIHost = core.ops.op_get_env('AI_INFERENCE_API_HOST'); - this.is_ext_inference_api = !!this.inferenceAPIHost; // only enable external inference API if env variable is set - } - } - - /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */ - async run(prompt, opts = {}) { - if (this.is_ext_inference_api) { - const stream = opts.stream ?? false; - - /** @type {'ollama' | 'openaicompatible'} */ - // TODO:(kallebysantos) get mode from 'new' and apply type checking based on that - const mode = opts.mode ?? 'ollama'; - - if (!Object.keys(providers).includes(mode)) { - throw new TypeError(`invalid mode: ${mode}`); - } - - const llmSession = LLMSession.fromProvider(mode, { - inferenceAPIHost: this.inferenceAPIHost, - model: this.model, - ...this.extraOpts, // allows custom provider initialization like 'apiKey' - }); - - return await llmSession.run({ - prompt, - stream, - signal: opts.signal, - timeout: opts.timeout, - }); - } - - if (this.init) { - await this.init; - } - - const mean_pool = opts.mean_pool ?? true; - const normalize = opts.normalize ?? true; - const result = await core.ops.op_ai_run_model( - this.model, - prompt, - mean_pool, - normalize, - ); - - return result; - } -} - -const MAIN_WORKER_API = { - tryCleanupUnusedSession: () => /* async */ core.ops.op_ai_try_cleanup_unused_session(), -}; - -const USER_WORKER_API = { - Session, -}; - -export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts new file mode 100644 index 000000000..2e1e47e0b --- /dev/null +++ b/ext/ai/js/ai.ts @@ -0,0 +1,141 @@ +import "./onnxruntime/onnx.js"; +import { + LLMProviderInstance, + LLMProviderName, + LLMSession, + LLMSessionRunInputOptions as LLMInputOptions, + providers, +} from "./llm/llm_session.ts"; + +// @ts-ignore deno_core environment +const core = globalThis.Deno.core; + +// NOTE:(kallebysantos) do we still need gte-small? Or maybe add another type 'embeddings' with custom model opt. +export type SessionType = LLMProviderName | "gte-small"; + +export type SessionOptions = T extends LLMProviderName + ? LLMProviderInstance["options"] + : never; + +export type SessionInput = T extends LLMProviderName + ? LLMProviderInstance["input"] + : T extends "gte-small" ? string + : never; + +export type EmbeddingInputOptions = { + /** + * Pool embeddings by taking their mean + */ + mean_pool?: boolean; + + /** + * Normalize the embeddings result + */ + normalize?: boolean; +}; + +export type SessionInputOptions = T extends + LLMProviderName ? 
LLMInputOptions + : EmbeddingInputOptions; + +export class Session { + #model?: string; + #init?: Promise; + + // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run + constructor( + public readonly type: T, + public readonly options?: SessionOptions, + ) { + if (this.isEmbeddingType()) { + this.#model = "gte-small"; // Default model + this.#init = core.ops.op_ai_init_model(this.#model); + return; + } + + if (this.isLLMType()) { + if (!Object.keys(providers).includes(type)) { + throw new TypeError(`invalid type: '${type}'`); + } + + if (!this.options || !this.options.model) { + throw new Error( + `missing required parameter 'model' for type: '${type}'`, + ); + } + + this.options.baseURL ??= core.ops.op_get_env( + "AI_INFERENCE_API_HOST", + ) as string; + + if (!this.options.baseURL) { + throw new Error( + `missing required parameter 'baseURL' for type: '${type}'`, + ); + } + } + } + + // /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */ + async run(input: SessionInput, options: SessionInputOptions) { + if (this.isLLMType()) { + const opts = options as LLMInputOptions; + const stream = opts.stream ?? false; + + const llmSession = LLMSession.fromProvider(this.type, { + // safety: We did check `options` during construction + baseURL: this.options!.baseURL, + model: this.options!.model, + ...this.options, // allows custom provider initialization like 'apiKey' + }); + + return await llmSession.run(input, { + stream, + signal: opts.signal, + timeout: opts.timeout, + }); + } + + if (this.#init) { + await this.#init; + } + + const opts = options as EmbeddingInputOptions; + + const mean_pool = opts.mean_pool ?? true; + const normalize = opts.normalize ?? true; + + const result = await core.ops.op_ai_run_model( + // @ts-ignore + this.#model, + prompt, + mean_pool, + normalize, + ); + + return result; + } + + private isEmbeddingType( + this: Session, + ): this is Session<"gte-small"> { + return this.type === "gte-small"; + } + + private isLLMType( + this: Session, + ): this is Session { + return this.type !== "gte-small"; + } +} + +const MAIN_WORKER_API = { + tryCleanupUnusedSession: () => + /* async */ core.ops.op_ai_try_cleanup_unused_session(), +}; + +const USER_WORKER_API = { + Session, +}; + +export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index 475b99ddf..64bcfeada 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -1,5 +1,5 @@ -import { OllamaLLMSession } from './providers/ollama.ts'; -import { OpenAILLMSession } from './providers/openai.ts'; +import { OllamaLLMSession } from "./providers/ollama.ts"; +import { OpenAILLMSession } from "./providers/openai.ts"; // @ts-ignore deno_core environment const core = globalThis.Deno.core; @@ -20,30 +20,59 @@ export type LLMRunInput = { signal?: AbortSignal; }; +export interface ILLMProviderMeta { + input: ILLMProviderInput; + output: unknown; + options: ILLMProviderOptions; +} + export interface ILLMProviderOptions { model: string; - inferenceAPIHost: string; + baseURL?: string; } -export interface ILLMProviderInput { - prompt: string | object; - signal: AbortSignal; -} +export type ILLMProviderInput = T extends string ? 
string + : T; export interface ILLMProvider { // TODO:(kallebysantos) remove 'any' // TODO: (kallebysantos) standardised output format - getStream(input: ILLMProviderInput): Promise>; - getText(input: ILLMProviderInput): Promise; + getStream( + input: ILLMProviderInput, + signal: AbortSignal, + ): Promise>; + getText(input: ILLMProviderInput, signal: AbortSignal): Promise; } export const providers = { - 'ollama': OllamaLLMSession, - 'openaicompatible': OpenAILLMSession, -} satisfies Record ILLMProvider>; + "ollama": OllamaLLMSession, + "openaicompatible": OpenAILLMSession, +} satisfies Record< + string, + new (opts: ILLMProviderOptions) => ILLMProvider & ILLMProviderMeta +>; export type LLMProviderName = keyof typeof providers; +export type LLMProviderClass = (typeof providers)[T]; +export type LLMProviderInstance = InstanceType< + LLMProviderClass +>; + +export type LLMSessionRunInputOptions = { + /** + * Stream response from model. Applies only for LLMs like `mistral` (default: false) + */ + stream?: boolean; + + /** + * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) + */ + timeout?: number; + + signal?: AbortSignal; +}; + export class LLMSession { #inner: ILLMProvider; @@ -53,7 +82,7 @@ export class LLMSession { static fromProvider(name: LLMProviderName, opts: ILLMProviderOptions) { const ProviderType = providers[name]; - if (!ProviderType) throw new Error('invalid provider'); + if (!ProviderType) throw new Error("invalid provider"); const provider = new ProviderType(opts); @@ -61,11 +90,12 @@ export class LLMSession { } run( - opts: LLMRunInput, + input: ILLMProviderInput, + opts: LLMSessionRunInputOptions, ): Promise> | Promise { const isStream = opts.stream ?? false; - const timeoutSeconds = typeof opts.timeout === 'number' ? opts.timeout : 60; + const timeoutSeconds = typeof opts.timeout === "number" ? 
opts.timeout : 60; const timeoutMs = timeoutSeconds * 1000; const timeoutSignal = AbortSignal.timeout(timeoutMs); @@ -73,11 +103,10 @@ export class LLMSession { .filter((it) => it instanceof AbortSignal); const signal = AbortSignal.any(abortSignals); - const llmInput: ILLMProviderInput = { prompt: opts.prompt, signal }; if (isStream) { - return this.#inner.getStream(llmInput); + return this.#inner.getStream(input, signal); } - return this.#inner.getText(llmInput); + return this.#inner.getText(input, signal); } } diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index f714b8762..05f27ea92 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -1,10 +1,13 @@ -import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts'; +import { + ILLMProvider, + ILLMProviderInput, + ILLMProviderMeta, + ILLMProviderOptions, +} from '../llm_session.ts'; import { parseJSON } from '../utils/json_parser.ts'; export type OllamaProviderOptions = ILLMProviderOptions; -export type OllamaProviderInput = ILLMProviderInput & { - prompt: string; -}; +export type OllamaProviderInput = ILLMProviderInput; export type OllamaMessage = { model: string; @@ -20,16 +23,19 @@ export type OllamaMessage = { eval_duration: number; }; -export class OllamaLLMSession implements ILLMProvider { - opts: OllamaProviderOptions; +export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { + input!: OllamaProviderInput; + output!: unknown; + options: OllamaProviderOptions; constructor(opts: OllamaProviderOptions) { - this.opts = opts; + this.options = opts; } // ref: https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L26 async getStream( - { prompt, signal }: OllamaProviderInput, + prompt: OllamaProviderInput, + signal: AbortSignal, ): Promise> { const generator = await this.generate( prompt, @@ -62,7 +68,8 @@ export class OllamaLLMSession implements ILLMProvider { } async getText( - { prompt, signal }: OllamaProviderInput, + prompt: OllamaProviderInput, + signal: AbortSignal, ): Promise { const response = await this.generate(prompt, signal) as OllamaMessage; @@ -79,14 +86,14 @@ export class OllamaLLMSession implements ILLMProvider { stream: boolean = false, ) { const res = await fetch( - new URL('/api/generate', this.opts.inferenceAPIHost), + new URL('/api/generate', this.options.baseURL), { method: 'POST', headers: { 'Content-Type': 'application/json', }, body: JSON.stringify({ - model: this.opts.model, + model: this.options.model, stream, prompt, }), diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index c36ef8355..c614ad18c 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -1,5 +1,10 @@ -import { ILLMProvider, ILLMProviderInput, ILLMProviderOptions } from '../llm_session.ts'; -import { parseJSONOverEventStream } from '../utils/json_parser.ts'; +import { + ILLMProvider, + ILLMProviderInput, + ILLMProviderMeta, + ILLMProviderOptions, +} from "../llm_session.ts"; +import { parseJSONOverEventStream } from "../utils/json_parser.ts"; export type OpenAIProviderOptions = ILLMProviderOptions & { apiKey?: string; @@ -9,7 +14,7 @@ export type OpenAIProviderOptions = ILLMProviderOptions & { export type OpenAIRequest = { model: string; messages: { - role: 'system' | 'user' | 'assistant' | 'tool'; + role: "system" | "user" | "assistant" | "tool"; content: string; name?: string; tool_call_id?: string; @@ -29,68 +34,83 @@ export 
type OpenAIRequest = { logit_bias?: { [token: string]: number }; user?: string; tools?: { - type: 'function'; + type: "function"; function: { name: string; description?: string; parameters: any; // Can be refined based on your function definition }; }[]; - tool_choice?: 'none' | 'auto' | { - type: 'function'; + tool_choice?: "none" | "auto" | { + type: "function"; function: { name: string }; }; }; +export type OpenAIResponseUsage = { + prompt_tokens: number; + completion_tokens: number; + total_tokens: number; + prompt_tokens_details: { + cached_tokens: 0; + audio_tokens: 0; + }; + completion_tokens_details: { + reasoning_tokens: 0; + audio_tokens: 0; + accepted_prediction_tokens: 0; + rejected_prediction_tokens: 0; + }; +}; + +export type OpenAIResponseChoice = { + index: number; + message: { + role: "assistant" | "user" | "system" | "tool"; + content: string | null; + function_call?: { + name: string; + arguments: string; + }; + tool_calls?: { + id: string; + type: "function"; + function: { + name: string; + arguments: string; + }; + }[]; + }; + finish_reason: "stop" | "length" | "tool_calls" | "content_filter" | null; +}; + export type OpenAIResponse = { id: string; - object: 'chat.completion'; + object: "chat.completion"; created: number; model: string; system_fingerprint?: string; - choices: { - index: number; - message: { - role: 'assistant' | 'user' | 'system' | 'tool'; - content: string | null; - function_call?: { - name: string; - arguments: string; - }; - tool_calls?: { - id: string; - type: 'function'; - function: { - name: string; - arguments: string; - }; - }[]; - }; - finish_reason: 'stop' | 'length' | 'tool_calls' | 'content_filter' | null; - }[]; - usage?: { - prompt_tokens: number; - completion_tokens: number; - total_tokens: number; - }; + choices: OpenAIResponseChoice[]; + usage?: OpenAIResponseUsage; }; -export type OpenAIInput = Omit; +export type OpenAICompatibleInput = Omit; -export type OpenAIProviderInput = Omit & { - // Use Open AI defs - prompt: OpenAIInput; -}; +export type OpenAIProviderInput = ILLMProviderInput; -export class OpenAILLMSession implements ILLMProvider { - opts: OpenAIProviderOptions; +export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { + input!: OpenAIProviderInput; + // TODO:(kallebysantos) add output types + output: unknown; + options: OpenAIProviderOptions; constructor(opts: OpenAIProviderOptions) { - this.opts = opts; + this.options = opts; } async getStream( - { prompt, signal }: OpenAIProviderInput, + prompt: OpenAIProviderInput, + signal: AbortSignal, ): Promise> { const generator = await this.generate( prompt, @@ -101,7 +121,7 @@ export class OpenAILLMSession implements ILLMProvider { const stream = async function* () { for await (const message of generator) { // TODO:(kallebysantos) Simplify duplicated code for stream error checking - if ('error' in message) { + if ("error" in message) { if (message.error instanceof Error) { throw message.error; } else { @@ -113,8 +133,8 @@ export class OpenAILLMSession implements ILLMProvider { const finishReason = message.choices[0].finish_reason; if (finishReason) { - if (finishReason !== 'stop') { - throw new Error('Expected a completed response.'); + if (finishReason !== "stop") { + throw new Error("Expected a completed response."); } return; @@ -122,7 +142,7 @@ export class OpenAILLMSession implements ILLMProvider { } throw new Error( - 'Did not receive done or success response in stream.', + "Did not receive done or success response in stream.", ); }; @@ -130,7 +150,8 @@ 
export class OpenAILLMSession implements ILLMProvider { } async getText( - { prompt, signal }: OpenAIProviderInput, + prompt: OpenAIProviderInput, + signal: AbortSignal, ): Promise { const response = await this.generate( prompt, @@ -139,30 +160,30 @@ export class OpenAILLMSession implements ILLMProvider { const finishReason = response.choices[0].finish_reason; - if (finishReason !== 'stop') { - throw new Error('Expected a completed response.'); + if (finishReason !== "stop") { + throw new Error("Expected a completed response."); } return response; } private async generate( - input: OpenAIInput, + input: OpenAICompatibleInput, signal: AbortSignal, stream: boolean = false, ) { const res = await fetch( - new URL('/v1/chat/completions', this.opts.inferenceAPIHost), + new URL("/v1/chat/completions", this.options.baseURL), { - method: 'POST', + method: "POST", headers: { - 'Content-Type': 'application/json', - 'Authorization': `Bearer ${this.opts.apiKey}`, + "Content-Type": "application/json", + "Authorization": `Bearer ${this.options.apiKey}`, }, body: JSON.stringify( { ...input, - model: this.opts.model, + model: this.options.model, stream, } satisfies OpenAIRequest, ), @@ -177,7 +198,7 @@ export class OpenAILLMSession implements ILLMProvider { } if (!res.body) { - throw new Error('Missing body'); + throw new Error("Missing body"); } if (stream) { diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 1d3edfb89..68fc77461 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -48,10 +48,10 @@ deno_core::extension!( op_ai_ort_init_session, op_ai_ort_run_session, ], - esm_entry_point = "ext:ai/ai.js", + esm_entry_point = "ext:ai/ai.ts", esm = [ dir "js", - "ai.js", + "ai.ts", "onnxruntime/onnx.js", "onnxruntime/cache_adapter.js", "llm/llm_session.ts", diff --git a/ext/runtime/js/namespaces.js b/ext/runtime/js/namespaces.js index 7bff9c577..09dd56c99 100644 --- a/ext/runtime/js/namespaces.js +++ b/ext/runtime/js/namespaces.js @@ -1,6 +1,6 @@ import { core, primordials } from "ext:core/mod.js"; -import { MAIN_WORKER_API, USER_WORKER_API } from "ext:ai/ai.js"; +import { MAIN_WORKER_API, USER_WORKER_API } from "ext:ai/ai.ts"; import { SUPABASE_USER_WORKERS } from "ext:user_workers/user_workers.js"; import { applySupabaseTag } from "ext:runtime/http.js"; import { waitUntil } from "ext:runtime/async_hook.js"; diff --git a/types/global.d.ts b/types/global.d.ts index 7810e23a9..94e6620d3 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -149,58 +149,65 @@ declare namespace EdgeRuntime { export { UserWorker as userWorkers }; } +// TODO:(kallebysantos) use some TS builder to bundle all types +import { ai as AINamespace } from "../ext/ai/js/ai.d.ts"; + declare namespace Supabase { - export namespace ai { - interface ModelOptions { - /** - * Pool embeddings by taking their mean. Applies only for `gte-small` model - */ - mean_pool?: boolean; - - /** - * Normalize the embeddings result. Applies only for `gte-small` model - */ - normalize?: boolean; - - /** - * Stream response from model. Applies only for LLMs like `mistral` (default: false) - */ - stream?: boolean; - - /** - * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) - */ - timeout?: number; - - /** - * Mode for the inference API host. 
(default: 'ollama') - */ - mode?: "ollama" | "openaicompatible"; - signal?: AbortSignal; - } - - export class Session { - /** - * Create a new model session using given model - */ - constructor(model: string); - - /** - * Execute the given prompt in model session - */ - run( - prompt: - | string - | Omit< - import("openai").OpenAI.Chat.ChatCompletionCreateParams, - "model" | "stream" - >, - modelOptions?: ModelOptions, - ): unknown; - } - } + export import ai = AINamespace; } +// declare namespace Supabase { +// export namespace ai { +// interface ModelOptions { +// /** +// * Pool embeddings by taking their mean. Applies only for `gte-small` model +// */ +// mean_pool?: boolean; +// +// /** +// * Normalize the embeddings result. Applies only for `gte-small` model +// */ +// normalize?: boolean; +// +// /** +// * Stream response from model. Applies only for LLMs like `mistral` (default: false) +// */ +// stream?: boolean; +// +// /** +// * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) +// */ +// timeout?: number; +// +// /** +// * Mode for the inference API host. (default: 'ollama') +// */ +// mode?: "ollama" | "openaicompatible"; +// signal?: AbortSignal; +// } +// +// export class Session { +// /** +// * Create a new model session using given model +// */ +// constructor(model: string); +// +// /** +// * Execute the given prompt in model session +// */ +// run( +// prompt: +// | string +// | Omit< +// import("openai").OpenAI.Chat.ChatCompletionCreateParams, +// "model" | "stream" +// >, +// modelOptions?: ModelOptions, +// ): unknown; +// } +// } +// } + declare namespace Deno { export namespace errors { class WorkerRequestCancelled extends Error {} From a9e0d411c4460d95c5d5da774ea49be71912a7dd Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Mon, 12 May 2025 20:33:44 +0100 Subject: [PATCH 06/14] stamp: creating result types and common usage interface - Improving typescript with conditional output types based on the selected provider - Defining common properties for LLM providers like `usage` metrics and simplified `value` --- ext/ai/js/ai.ts | 23 +++++++++++---- ext/ai/js/llm/llm_session.ts | 19 +++++++++++-- ext/ai/js/llm/providers/ollama.ts | 47 ++++++++++++++++++++++--------- ext/ai/js/llm/providers/openai.ts | 28 ++++++++++++++---- 4 files changed, 88 insertions(+), 29 deletions(-) diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts index 2e1e47e0b..2b2257a0b 100644 --- a/ext/ai/js/ai.ts +++ b/ext/ai/js/ai.ts @@ -34,15 +34,23 @@ export type EmbeddingInputOptions = { normalize?: boolean; }; -export type SessionInputOptions = T extends - LLMProviderName ? LLMInputOptions - : EmbeddingInputOptions; +export type SessionInputOptions = T extends "gte-small" + ? EmbeddingInputOptions + : T extends LLMProviderName ? LLMInputOptions + : never; + +export type SessionOutput = T extends "gte-small" + ? number[] + : T extends LLMProviderName + ? O extends { stream: true } + ? 
AsyncGenerator["output"]> + : LLMProviderInstance["output"] + : never; export class Session { #model?: string; #init?: Promise; - // TODO:(kallebysantos) get 'provider' type here and use type checking to suggest Inputs when run constructor( public readonly type: T, public readonly options?: SessionOptions, @@ -77,7 +85,10 @@ export class Session { } // /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */ - async run(input: SessionInput, options: SessionInputOptions) { + async run>( + input: SessionInput, + options: O, + ): Promise> { if (this.isLLMType()) { const opts = options as LLMInputOptions; const stream = opts.stream ?? false; @@ -93,7 +104,7 @@ export class Session { stream, signal: opts.signal, timeout: opts.timeout, - }); + }) as SessionOutput; } if (this.#init) { diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index 64bcfeada..b9b05de5e 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -34,14 +34,27 @@ export interface ILLMProviderOptions { export type ILLMProviderInput = T extends string ? string : T; +export interface ILLMProviderOutput { + value?: string; + usage: { + inputTokens: number; + outputTokens: number; + totalTokens: number; + }; + inner: T; +} + export interface ILLMProvider { // TODO:(kallebysantos) remove 'any' // TODO: (kallebysantos) standardised output format getStream( input: ILLMProviderInput, signal: AbortSignal, - ): Promise>; - getText(input: ILLMProviderInput, signal: AbortSignal): Promise; + ): Promise>; + getText( + input: ILLMProviderInput, + signal: AbortSignal, + ): Promise; } export const providers = { @@ -92,7 +105,7 @@ export class LLMSession { run( input: ILLMProviderInput, opts: LLMSessionRunInputOptions, - ): Promise> | Promise { + ): Promise> | Promise { const isStream = opts.stream ?? false; const timeoutSeconds = typeof opts.timeout === "number" ? 
opts.timeout : 60; diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index 05f27ea92..49e45c1c1 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -3,11 +3,13 @@ import { ILLMProviderInput, ILLMProviderMeta, ILLMProviderOptions, -} from '../llm_session.ts'; -import { parseJSON } from '../utils/json_parser.ts'; + ILLMProviderOutput, +} from "../llm_session.ts"; +import { parseJSON } from "../utils/json_parser.ts"; export type OllamaProviderOptions = ILLMProviderOptions; export type OllamaProviderInput = ILLMProviderInput; +export type OllamaProviderOutput = ILLMProviderOutput; export type OllamaMessage = { model: string; @@ -25,7 +27,7 @@ export type OllamaMessage = { export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { input!: OllamaProviderInput; - output!: unknown; + output!: OllamaProviderOutput; options: OllamaProviderOptions; constructor(opts: OllamaProviderOptions) { @@ -36,16 +38,18 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { async getStream( prompt: OllamaProviderInput, signal: AbortSignal, - ): Promise> { + ): Promise> { const generator = await this.generate( prompt, signal, true, ) as AsyncGenerator; + const parser = this.parse; + const stream = async function* () { for await (const message of generator) { - if ('error' in message) { + if ("error" in message) { if (message.error instanceof Error) { throw message.error; } else { @@ -53,14 +57,15 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { } } - yield message; + yield parser(message); + if (message.done) { return; } } throw new Error( - 'Did not receive done or success response in stream.', + "Did not receive done or success response in stream.", ); }; @@ -70,14 +75,28 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { async getText( prompt: OllamaProviderInput, signal: AbortSignal, - ): Promise { + ): Promise { const response = await this.generate(prompt, signal) as OllamaMessage; if (!response?.done) { - throw new Error('Expected a completed response.'); + throw new Error("Expected a completed response."); } - return response; + return this.parse(response); + } + + private parse(message: OllamaMessage): OllamaProviderOutput { + const { response, prompt_eval_count, eval_count } = message; + + return { + value: response, + inner: message, + usage: { + inputTokens: prompt_eval_count, + outputTokens: eval_count, + totalTokens: prompt_eval_count + eval_count, + }, + }; } private async generate( @@ -86,11 +105,11 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { stream: boolean = false, ) { const res = await fetch( - new URL('/api/generate', this.options.baseURL), + new URL("/api/generate", this.options.baseURL), { - method: 'POST', + method: "POST", headers: { - 'Content-Type': 'application/json', + "Content-Type": "application/json", }, body: JSON.stringify({ model: this.options.model, @@ -108,7 +127,7 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { } if (!res.body) { - throw new Error('Missing body'); + throw new Error("Missing body"); } if (stream) { diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index c614ad18c..3e25b2f1c 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -3,6 +3,7 @@ import { ILLMProviderInput, ILLMProviderMeta, ILLMProviderOptions, + ILLMProviderOutput, } from "../llm_session.ts"; import { 
parseJSONOverEventStream } from "../utils/json_parser.ts"; @@ -97,11 +98,11 @@ export type OpenAIResponse = { export type OpenAICompatibleInput = Omit; export type OpenAIProviderInput = ILLMProviderInput; +export type OpenAIProviderOutput = ILLMProviderOutput; export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { input!: OpenAIProviderInput; - // TODO:(kallebysantos) add output types - output: unknown; + output!: OpenAIProviderOutput; options: OpenAIProviderOptions; constructor(opts: OpenAIProviderOptions) { @@ -111,13 +112,14 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { async getStream( prompt: OpenAIProviderInput, signal: AbortSignal, - ): Promise> { + ): Promise> { const generator = await this.generate( prompt, signal, true, ) as AsyncGenerator; // TODO:(kallebysantos) remove any + const parser = this.parse; const stream = async function* () { for await (const message of generator) { // TODO:(kallebysantos) Simplify duplicated code for stream error checking @@ -129,7 +131,7 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { } } - yield message; + yield parser(message); const finishReason = message.choices[0].finish_reason; if (finishReason) { @@ -152,7 +154,7 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { async getText( prompt: OpenAIProviderInput, signal: AbortSignal, - ): Promise { + ): Promise { const response = await this.generate( prompt, signal, @@ -164,9 +166,23 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { throw new Error("Expected a completed response."); } - return response; + return this.parse(response); } + private parse(message: OpenAIResponse): OpenAIProviderOutput { + const { usage } = message; + + return { + value: message.choices.at(0)?.message.content ?? undefined, + inner: message, + usage: { + // Usage maybe 'null' while streaming, but the final message will include it + inputTokens: usage?.prompt_tokens ?? 0, + outputTokens: usage?.completion_tokens ?? 0, + totalTokens: usage?.total_tokens ?? 0, + }, + }; + } private async generate( input: OpenAICompatibleInput, signal: AbortSignal, From 0403b9ee6f79016741ce3b9df68414b362deac1a Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 13 May 2025 11:57:35 +0100 Subject: [PATCH 07/14] stamp: cleaning & polishing --- ext/ai/js/ai.ts | 2 -- ext/ai/js/llm/llm_session.ts | 5 ----- ext/ai/js/llm/providers/ollama.ts | 2 +- ext/ai/js/llm/providers/openai.ts | 25 ++++++++++++++----------- ext/ai/js/llm/utils/json_parser.ts | 16 ++++++++-------- 5 files changed, 23 insertions(+), 27 deletions(-) diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts index 2b2257a0b..92701d9b6 100644 --- a/ext/ai/js/ai.ts +++ b/ext/ai/js/ai.ts @@ -84,7 +84,6 @@ export class Session { } } - // /** @param {string | object} prompt Either a String (ollama) or an OpenAI chat completion body object (openaicompatible): https://platform.openai.com/docs/api-reference/chat/create */ async run>( input: SessionInput, options: O, @@ -117,7 +116,6 @@ export class Session { const normalize = opts.normalize ?? 
true; const result = await core.ops.op_ai_run_model( - // @ts-ignore this.#model, prompt, mean_pool, diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index b9b05de5e..17366270e 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -1,9 +1,6 @@ import { OllamaLLMSession } from "./providers/ollama.ts"; import { OpenAILLMSession } from "./providers/openai.ts"; -// @ts-ignore deno_core environment -const core = globalThis.Deno.core; - export type LLMRunInput = { /** * Stream response from model. Applies only for LLMs like `mistral` (default: false) @@ -45,8 +42,6 @@ export interface ILLMProviderOutput { } export interface ILLMProvider { - // TODO:(kallebysantos) remove 'any' - // TODO: (kallebysantos) standardised output format getStream( input: ILLMProviderInput, signal: AbortSignal, diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index 49e45c1c1..459222b58 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -45,8 +45,8 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { true, ) as AsyncGenerator; + // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* const parser = this.parse; - const stream = async function* () { for await (const message of generator) { if ("error" in message) { diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index 3e25b2f1c..ce2932857 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -11,10 +11,13 @@ export type OpenAIProviderOptions = ILLMProviderOptions & { apiKey?: string; }; +// NOTE:(kallebysantos) we define all types here for better development as well avoid `"npm:openai"` import // TODO:(kallebysantos) need to double check theses AI generated types export type OpenAIRequest = { model: string; messages: { + // NOTE:(kallebysantos) using role as union type is great for intellisense suggestions + // but at same time it forces users to `{} satisfies Supabase.ai.OpenAICompatibleInput` role: "system" | "user" | "assistant" | "tool"; content: string; name?: string; @@ -39,7 +42,7 @@ export type OpenAIRequest = { function: { name: string; description?: string; - parameters: any; // Can be refined based on your function definition + parameters: unknown; }; }[]; tool_choice?: "none" | "auto" | { @@ -53,14 +56,14 @@ export type OpenAIResponseUsage = { completion_tokens: number; total_tokens: number; prompt_tokens_details: { - cached_tokens: 0; - audio_tokens: 0; + cached_tokens: number; + audio_tokens: number; }; completion_tokens_details: { - reasoning_tokens: 0; - audio_tokens: 0; - accepted_prediction_tokens: 0; - rejected_prediction_tokens: 0; + reasoning_tokens: number; + audio_tokens: number; + accepted_prediction_tokens: number; + rejected_prediction_tokens: number; }; }; @@ -117,12 +120,12 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { prompt, signal, true, - ) as AsyncGenerator; // TODO:(kallebysantos) remove any + ) as AsyncGenerator; + // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* const parser = this.parse; const stream = async function* () { for await (const message of generator) { - // TODO:(kallebysantos) Simplify duplicated code for stream error checking if ("error" in message) { if (message.error instanceof Error) { throw message.error; @@ -176,7 +179,7 @@ export class OpenAILLMSession 
implements ILLMProvider, ILLMProviderMeta { value: message.choices.at(0)?.message.content ?? undefined, inner: message, usage: { - // Usage maybe 'null' while streaming, but the final message will include it + // NOTE:(kallebysantos) usage maybe 'null' while streaming, but the final message will include it inputTokens: usage?.prompt_tokens ?? 0, outputTokens: usage?.completion_tokens ?? 0, totalTokens: usage?.total_tokens ?? 0, @@ -218,7 +221,7 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { } if (stream) { - return parseJSONOverEventStream(res.body, signal); + return parseJSONOverEventStream(res.body, signal); } const result: OpenAIResponse = await res.json(); diff --git a/ext/ai/js/llm/utils/json_parser.ts b/ext/ai/js/llm/utils/json_parser.ts index 636ab4292..b57b9f11f 100644 --- a/ext/ai/js/llm/utils/json_parser.ts +++ b/ext/ai/js/llm/utils/json_parser.ts @@ -1,4 +1,4 @@ -import EventSourceStream from './event_source_stream.mjs'; +import EventSourceStream from "./event_source_stream.mjs"; // Adapted from https://github.com/ollama/ollama-js/blob/6a4bfe3ab033f611639dfe4249bdd6b9b19c7256/src/utils.ts#L262 // TODO:(kallebysantos) need to simplify it @@ -6,9 +6,9 @@ export async function* parseJSON( itr: ReadableStream, signal: AbortSignal, ) { - let buffer = ''; + let buffer = ""; - const decoder = new TextDecoder('utf-8'); + const decoder = new TextDecoder("utf-8"); const reader = itr.getReader(); while (true) { @@ -27,9 +27,9 @@ export async function* parseJSON( buffer += decoder.decode(value); - const parts = buffer.split('\n'); + const parts = buffer.split("\n"); - buffer = parts.pop() ?? ''; + buffer = parts.pop() ?? ""; for (const part of parts) { yield JSON.parse(part) as T; @@ -39,7 +39,7 @@ export async function* parseJSON( } } - for (const part of buffer.split('\n').filter((p) => p !== '')) { + for (const part of buffer.split("\n").filter((p) => p !== "")) { try { yield JSON.parse(part) as T; } catch (error) { @@ -49,7 +49,7 @@ export async function* parseJSON( } // TODO:(kallebysantos) need to simplify it -export async function* parseJSONOverEventStream( +export async function* parseJSONOverEventStream( itr: ReadableStream, signal: AbortSignal, ) { @@ -74,7 +74,7 @@ export async function* parseJSONOverEventStream( break; } - yield JSON.parse(value.data); + yield JSON.parse(value.data) as T; } catch (error) { yield { error }; } From 08643ea8436460939690b60b2ce274867315749a Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 13 May 2025 18:11:55 +0100 Subject: [PATCH 08/14] fix: openai streaming - OpenAI uses a different streaming alternative that ends with `[DONE]` --- ext/ai/js/llm/providers/openai.ts | 46 ++++++++++++++------- ext/ai/js/llm/utils/event_source_stream.mjs | 12 +++++- ext/ai/js/llm/utils/json_parser.ts | 1 + 3 files changed, 42 insertions(+), 17 deletions(-) diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index ce2932857..8d7fda73b 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -31,6 +31,9 @@ export type OpenAIRequest = { top_p?: number; n?: number; stream?: boolean; + stream_options: { + include_usage: boolean; + }; stop?: string | string[]; max_tokens?: number; presence_penalty?: number; @@ -69,7 +72,7 @@ export type OpenAIResponseUsage = { export type OpenAIResponseChoice = { index: number; - message: { + message?: { role: "assistant" | "user" | "system" | "tool"; content: string | null; function_call?: { @@ -85,6 +88,9 @@ export type 
OpenAIResponseChoice = { }; }[]; }; + delta?: { + content: string | null; + }; finish_reason: "stop" | "length" | "tool_calls" | "content_filter" | null; }; @@ -98,7 +104,10 @@ export type OpenAIResponse = { usage?: OpenAIResponseUsage; }; -export type OpenAICompatibleInput = Omit; +export type OpenAICompatibleInput = Omit< + OpenAIRequest, + "stream" | "stream_options" | "model" +>; export type OpenAIProviderInput = ILLMProviderInput; export type OpenAIProviderOutput = ILLMProviderOutput; @@ -126,23 +135,25 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { const parser = this.parse; const stream = async function* () { for await (const message of generator) { + // NOTE:(kallebysantos) while streaming the final message will not include 'finish_reason' + // Instead a '[DONE]' value will be returned to close the stream + if ("done" in message && message.done) { + return; + } + if ("error" in message) { if (message.error instanceof Error) { throw message.error; - } else { - throw new Error(message.error as string); } + + throw new Error(message.error as string); } yield parser(message); - const finishReason = message.choices[0].finish_reason; - if (finishReason) { - if (finishReason !== "stop") { - throw new Error("Expected a completed response."); - } - - return; + const finish_reason = message.choices.at(0)?.finish_reason; + if (finish_reason && finish_reason !== "stop") { + throw new Error("Expected a completed response."); } } @@ -172,12 +183,14 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { return this.parse(response); } - private parse(message: OpenAIResponse): OpenAIProviderOutput { - const { usage } = message; + private parse(response: OpenAIResponse): OpenAIProviderOutput { + const { usage } = response; + const choice = response.choices.at(0); return { - value: message.choices.at(0)?.message.content ?? undefined, - inner: message, + // NOTE:(kallebysantos) while streaming the 'delta' field will be used instead of 'message' + value: choice?.message?.content ?? choice?.delta?.content ?? undefined, + inner: response, usage: { // NOTE:(kallebysantos) usage maybe 'null' while streaming, but the final message will include it inputTokens: usage?.prompt_tokens ?? 0, @@ -204,6 +217,9 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { ...input, model: this.options.model, stream, + stream_options: { + include_usage: true, + }, } satisfies OpenAIRequest, ), signal, diff --git a/ext/ai/js/llm/utils/event_source_stream.mjs b/ext/ai/js/llm/utils/event_source_stream.mjs index fa355da05..0ec0b889b 100644 --- a/ext/ai/js/llm/utils/event_source_stream.mjs +++ b/ext/ai/js/llm/utils/event_source_stream.mjs @@ -1,4 +1,4 @@ -import EventStreamParser from './event_stream_parser.mjs'; +import EventStreamParser from "./event_stream_parser.mjs"; /** * A Web stream which handles Server-Sent Events from a binary ReadableStream like you get from the fetch API. * Implements the TransformStream interface, and can be used with the Streams API as such. @@ -9,11 +9,19 @@ class EventSourceStream { // 1. The SSE spec allows for an optional UTF-8 BOM. // 2. We have to use a *streaming* decoder, in case two adjacent data chunks are split up in the middle of a // multibyte Unicode character. Trying to parse the two separately would result in data corruption. 
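// Editorial aside (not part of this patch): per the commit subject above,
// OpenAI-compatible endpoints finish a `stream: true` response with a literal
// `data: [DONE]` frame rather than a final JSON object, e.g.
//
//   data: {"choices":[{"delta":{"content":"Hel"},"finish_reason":null}]}
//   data: {"choices":[{"delta":{"content":"lo"},"finish_reason":null}]}
//   data: [DONE]
//
// so the SSE layer has to swallow that sentinel before anything reaches
// JSON.parse. A minimal sketch of the check the diff below adds (the helper
// name is illustrative only):
const isDoneSentinel = (data: unknown) =>
  typeof data === "string" && data.trim() === "[DONE]";
// e.g. isDoneSentinel("[DONE]") === true; isDoneSentinel('{"id":1}') === false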
- const decoder = new TextDecoderStream('utf-8'); + const decoder = new TextDecoderStream("utf-8"); let parser; const sseStream = new TransformStream({ start(controller) { parser = new EventStreamParser((data, eventType, lastEventId) => { + // NOTE:(kallebysantos) Some providers like OpenAI send '[DONE]' + // to indicates stream terminates, so we need to check if the SSE contains "[DONE]" and close the stream + if (typeof data === "string" && data.trim() === "[DONE]") { + controller.terminate?.(); // If supported + controller.close?.(); // Fallback + return; + } + controller.enqueue( new MessageEvent(eventType, { data, lastEventId }), ); diff --git a/ext/ai/js/llm/utils/json_parser.ts b/ext/ai/js/llm/utils/json_parser.ts index b57b9f11f..2cb318662 100644 --- a/ext/ai/js/llm/utils/json_parser.ts +++ b/ext/ai/js/llm/utils/json_parser.ts @@ -71,6 +71,7 @@ export async function* parseJSONOverEventStream( const { done, value } = await reader.read(); if (done) { + yield { done }; break; } From ba13d074193cb5bd0c81cbbcc32236778c602a4a Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 14 May 2025 15:05:37 +0100 Subject: [PATCH 09/14] stamp: improved error handling - Applying 'pattern matching' and 'Result pattern' to improve error handling. It enforces that users must first check for errors before consuming the message --- ext/ai/js/ai.ts | 91 +++++++++++++-------- ext/ai/js/llm/llm_session.ts | 32 ++++++-- ext/ai/js/llm/providers/ollama.ts | 126 ++++++++++++++++++++++------- ext/ai/js/llm/providers/openai.ts | 128 +++++++++++++++++++++++------- 4 files changed, 281 insertions(+), 96 deletions(-) diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts index 92701d9b6..2002df009 100644 --- a/ext/ai/js/ai.ts +++ b/ext/ai/js/ai.ts @@ -10,6 +10,9 @@ import { // @ts-ignore deno_core environment const core = globalThis.Deno.core; +// TODO: extract to utils file +export type Result = [T, undefined] | [undefined, E]; + // NOTE:(kallebysantos) do we still need gte-small? Or maybe add another type 'embeddings' with custom model opt. export type SessionType = LLMProviderName | "gte-small"; @@ -47,6 +50,16 @@ export type SessionOutput = T extends "gte-small" : LLMProviderInstance["output"] : never; +export type SessionError = { + message: string; + inner: T; +}; + +export type SessionOutputError = T extends "gte-small" + ? SessionError + : T extends LLMProviderName ? SessionError["error"]> + : any; + export class Session { #model?: string; #init?: Promise; @@ -87,42 +100,58 @@ export class Session { async run>( input: SessionInput, options: O, - ): Promise> { - if (this.isLLMType()) { - const opts = options as LLMInputOptions; - const stream = opts.stream ?? false; - - const llmSession = LLMSession.fromProvider(this.type, { - // safety: We did check `options` during construction - baseURL: this.options!.baseURL, - model: this.options!.model, - ...this.options, // allows custom provider initialization like 'apiKey' - }); - - return await llmSession.run(input, { - stream, - signal: opts.signal, - timeout: opts.timeout, - }) as SessionOutput; - } + ): Promise< + [SessionOutput, undefined] | [undefined, SessionOutputError] + > { + try { + if (this.isLLMType()) { + const opts = options as LLMInputOptions; + const stream = opts.stream ?? 
false; + + const llmSession = LLMSession.fromProvider(this.type, { + // safety: We did check `options` during construction + baseURL: this.options!.baseURL, + model: this.options!.model, + ...this.options, // allows custom provider initialization like 'apiKey' + }); + + const [output, error] = await llmSession.run(input, { + stream, + signal: opts.signal, + timeout: opts.timeout, + }); + if (error) { + return [undefined, error as SessionOutputError]; + } + + return [output as SessionOutput, undefined]; + } - if (this.#init) { - await this.#init; - } + if (this.#init) { + await this.#init; + } - const opts = options as EmbeddingInputOptions; + const opts = options as EmbeddingInputOptions; - const mean_pool = opts.mean_pool ?? true; - const normalize = opts.normalize ?? true; + const mean_pool = opts.mean_pool ?? true; + const normalize = opts.normalize ?? true; - const result = await core.ops.op_ai_run_model( - this.#model, - prompt, - mean_pool, - normalize, - ); + const result = await core.ops.op_ai_run_model( + this.#model, + prompt, + mean_pool, + normalize, + ) as SessionOutput; - return result; + return [result, undefined]; + } catch (e: any) { + const error = (e instanceof Error) ? e : new Error(e); + + return [ + undefined, + { inner: error, message: error.message } as SessionOutputError, + ]; + } } private isEmbeddingType( diff --git a/ext/ai/js/llm/llm_session.ts b/ext/ai/js/llm/llm_session.ts index 17366270e..b9db4e157 100644 --- a/ext/ai/js/llm/llm_session.ts +++ b/ext/ai/js/llm/llm_session.ts @@ -1,3 +1,4 @@ +import { Result, SessionError } from "../ai.ts"; import { OllamaLLMSession } from "./providers/ollama.ts"; import { OpenAILLMSession } from "./providers/openai.ts"; @@ -20,6 +21,7 @@ export type LLMRunInput = { export interface ILLMProviderMeta { input: ILLMProviderInput; output: unknown; + error: unknown; options: ILLMProviderOptions; } @@ -41,15 +43,23 @@ export interface ILLMProviderOutput { inner: T; } +export interface ILLMProviderError extends SessionError { +} + export interface ILLMProvider { getStream( input: ILLMProviderInput, signal: AbortSignal, - ): Promise>; + ): Promise< + Result< + AsyncIterable>, + ILLMProviderError + > + >; getText( input: ILLMProviderInput, signal: AbortSignal, - ): Promise; + ): Promise>; } export const providers = { @@ -81,6 +91,10 @@ export type LLMSessionRunInputOptions = { signal?: AbortSignal; }; +export type LLMSessionOutput = + | AsyncIterable> + | ILLMProviderOutput; + export class LLMSession { #inner: ILLMProvider; @@ -97,10 +111,10 @@ export class LLMSession { return new LLMSession(provider); } - run( + async run( input: ILLMProviderInput, opts: LLMSessionRunInputOptions, - ): Promise> | Promise { + ): Promise> { const isStream = opts.stream ?? false; const timeoutSeconds = typeof opts.timeout === "number" ? 
opts.timeout : 60; @@ -112,7 +126,15 @@ export class LLMSession { const signal = AbortSignal.any(abortSignals); if (isStream) { - return this.#inner.getStream(input, signal); + const [stream, getStreamError] = await this.#inner.getStream( + input, + signal, + ); + if (getStreamError) { + return [undefined, getStreamError]; + } + + return [stream, undefined]; } return this.#inner.getText(input, signal); diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index 459222b58..511de9382 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -1,5 +1,7 @@ +import { Result } from "../../ai.ts"; import { ILLMProvider, + ILLMProviderError, ILLMProviderInput, ILLMProviderMeta, ILLMProviderOptions, @@ -9,7 +11,11 @@ import { parseJSON } from "../utils/json_parser.ts"; export type OllamaProviderOptions = ILLMProviderOptions; export type OllamaProviderInput = ILLMProviderInput; -export type OllamaProviderOutput = ILLMProviderOutput; +export type OllamaProviderOutput = Result< + ILLMProviderOutput, + OllamaProviderError +>; +export type OllamaProviderError = ILLMProviderError; export type OllamaMessage = { model: string; @@ -28,6 +34,7 @@ export type OllamaMessage = { export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { input!: OllamaProviderInput; output!: OllamaProviderOutput; + error!: OllamaProviderError; options: OllamaProviderOptions; constructor(opts: OllamaProviderOptions) { @@ -38,26 +45,41 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { async getStream( prompt: OllamaProviderInput, signal: AbortSignal, - ): Promise> { - const generator = await this.generate( + ): Promise< + Result, OllamaProviderError> + > { + const [generator, error] = await this.generate( prompt, signal, true, - ) as AsyncGenerator; + ) as Result, OllamaProviderError>; + + if (error) { + return [undefined, error]; + } // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* const parser = this.parse; const stream = async function* () { for await (const message of generator) { if ("error" in message) { - if (message.error instanceof Error) { - throw message.error; - } else { - throw new Error(message.error as string); - } + const error = (message.error instanceof Error) + ? message.error + : new Error(message.error as string); + + yield [ + undefined, + { + inner: { + error, + currentValue: null, + }, + message: "An unknown error was streamed from the provider.", + } satisfies OllamaProviderError, + ]; } - yield parser(message); + yield [parser(message), undefined]; if (message.done) { return; @@ -69,32 +91,52 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { ); }; - return stream(); + return [ + stream() as AsyncIterable, + undefined, + ]; } async getText( prompt: OllamaProviderInput, signal: AbortSignal, ): Promise { - const response = await this.generate(prompt, signal) as OllamaMessage; + const [generation, generationError] = await this.generate( + prompt, + signal, + ) as Result; - if (!response?.done) { - throw new Error("Expected a completed response."); + if (generationError) { + return [undefined, generationError]; } - return this.parse(response); + if (!generation?.done) { + return [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: generation, + }, + message: + `Response could not be completed successfully. 
Expected 'done'`, + }]; + } + + return [this.parse(generation), undefined]; } - private parse(message: OllamaMessage): OllamaProviderOutput { + private parse(message: OllamaMessage): ILLMProviderOutput { const { response, prompt_eval_count, eval_count } = message; + const inputTokens = isNaN(prompt_eval_count) ? 0 : prompt_eval_count; + const outputTokens = isNaN(eval_count) ? 0 : eval_count; + return { value: response, inner: message, usage: { - inputTokens: prompt_eval_count, - outputTokens: eval_count, - totalTokens: prompt_eval_count + eval_count, + inputTokens, + outputTokens, + totalTokens: inputTokens + outputTokens, }, }; } @@ -103,7 +145,9 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { prompt: string, signal: AbortSignal, stream: boolean = false, - ) { + ): Promise< + Result | OllamaMessage, OllamaProviderError> + > { const res = await fetch( new URL("/api/generate", this.options.baseURL), { @@ -120,22 +164,44 @@ export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { }, ); - if (!res.ok) { - throw new Error( - `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`, - ); - } + // try to extract the json error otherwise return any text content from the response + if (!res.ok || !res.body) { + const errorMsg = + `Failed to fetch inference API host '${this.options.baseURL}'. Status ${res.status}: ${res.statusText}`; + + if (!res.body) { + const error = { + inner: new Error("Missing response body."), + message: errorMsg, + } satisfies OllamaProviderError; - if (!res.body) { - throw new Error("Missing body"); + return [undefined, error]; + } + + // safe to extract response body cause it was checked above + try { + const error = { + inner: await res.json(), + message: errorMsg, + } satisfies OllamaProviderError; + + return [undefined, error]; + } catch (_) { + const error = { + inner: new Error(await res.text()), + message: errorMsg, + } satisfies OllamaProviderError; + + return [undefined, error]; + } } if (stream) { - return parseJSON(res.body, signal); + const stream = parseJSON(res.body, signal); + return [stream as AsyncGenerator, undefined]; } const result: OllamaMessage = await res.json(); - - return result; + return [result, undefined]; } } diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index 8d7fda73b..7ecf1f826 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -1,5 +1,7 @@ +import { Result } from "../../ai.ts"; import { ILLMProvider, + ILLMProviderError, ILLMProviderInput, ILLMProviderMeta, ILLMProviderOptions, @@ -110,11 +112,16 @@ export type OpenAICompatibleInput = Omit< >; export type OpenAIProviderInput = ILLMProviderInput; -export type OpenAIProviderOutput = ILLMProviderOutput; +export type OpenAIProviderOutput = Result< + ILLMProviderOutput, + OpenAIProviderError +>; +export type OpenAIProviderError = ILLMProviderError; export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { input!: OpenAIProviderInput; output!: OpenAIProviderOutput; + error!: OpenAIProviderError; options: OpenAIProviderOptions; constructor(opts: OpenAIProviderOptions) { @@ -124,12 +131,18 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { async getStream( prompt: OpenAIProviderInput, signal: AbortSignal, - ): Promise> { - const generator = await this.generate( + ): Promise< + Result, OpenAIProviderError> + > { + const [generator, error] = await this.generate( prompt, signal, true, - ) as AsyncGenerator; + ) as 
Result, OpenAIProviderError>; + + if (error) { + return [undefined, error]; + } // NOTE:(kallebysantos) we need to clone the lambda parser to avoid `this` conflicts inside the local function* const parser = this.parse; @@ -142,18 +155,34 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { } if ("error" in message) { - if (message.error instanceof Error) { - throw message.error; - } + const error = (message.error instanceof Error) + ? message.error + : new Error(message.error as string); - throw new Error(message.error as string); + yield [ + undefined, + { + inner: { + error, + currentValue: null, + }, + message: "An unknown error was streamed from the provider.", + } satisfies OpenAIProviderError, + ]; } - yield parser(message); + yield [parser(message), undefined]; - const finish_reason = message.choices.at(0)?.finish_reason; - if (finish_reason && finish_reason !== "stop") { - throw new Error("Expected a completed response."); + const finishReason = message.choices.at(0)?.finish_reason; + if (finishReason && finishReason !== "stop") { + yield [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: message, + }, + message: + `Response could not be completed successfully. Expected 'stop' finish reason got '${finishReason}'`, + }]; } } @@ -162,28 +191,42 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { ); }; - return stream(); + return [ + stream() as AsyncIterable, + undefined, + ]; } async getText( prompt: OpenAIProviderInput, signal: AbortSignal, ): Promise { - const response = await this.generate( + const [generation, generationError] = await this.generate( prompt, signal, - ) as OpenAIResponse; + ) as Result; - const finishReason = response.choices[0].finish_reason; + if (generationError) { + return [undefined, generationError]; + } + + const finishReason = generation.choices[0].finish_reason; if (finishReason !== "stop") { - throw new Error("Expected a completed response."); + return [undefined, { + inner: { + error: new Error("Expected a completed response."), + currentValue: generation, + }, + message: + `Response could not be completed successfully. Expected 'stop' finish reason got '${finishReason}'`, + }]; } - return this.parse(response); + return [this.parse(generation), undefined]; } - private parse(response: OpenAIResponse): OpenAIProviderOutput { + private parse(response: OpenAIResponse): ILLMProviderOutput { const { usage } = response; const choice = response.choices.at(0); @@ -199,11 +242,14 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { }, }; } + private async generate( input: OpenAICompatibleInput, signal: AbortSignal, stream: boolean = false, - ) { + ): Promise< + Result | OpenAIResponse, OpenAIProviderError> + > { const res = await fetch( new URL("/v1/chat/completions", this.options.baseURL), { @@ -226,22 +272,44 @@ export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { }, ); - if (!res.ok) { - throw new Error( - `Failed to fetch inference API host. Status ${res.status}: ${res.statusText}`, - ); - } + // try to extract the json error otherwise return any text content from the response + if (!res.ok || !res.body) { + const errorMsg = + `Failed to fetch inference API host '${this.options.baseURL}'. 
Status ${res.status}: ${res.statusText}`; + + if (!res.body) { + const error = { + inner: new Error("Missing response body."), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } - if (!res.body) { - throw new Error("Missing body"); + // safe to extract response body cause it was checked above + try { + const error = { + inner: await res.json(), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } catch (_) { + const error = { + inner: new Error(await res.text()), + message: errorMsg, + } satisfies OpenAIProviderError; + + return [undefined, error]; + } } if (stream) { - return parseJSONOverEventStream(res.body, signal); + const stream = parseJSONOverEventStream(res.body, signal); + return [stream as AsyncGenerator, undefined]; } const result: OpenAIResponse = await res.json(); - - return result; + return [result, undefined]; } } From 611645cdca5b1175d8e0db59c95ceef39822968d Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Wed, 14 May 2025 15:41:35 +0100 Subject: [PATCH 10/14] fix: `SessionOutput` type defs --- ext/ai/js/ai.ts | 7 ++++++- ext/ai/js/llm/providers/ollama.ts | 2 +- ext/ai/js/llm/providers/openai.ts | 2 +- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts index 2002df009..8dd6f00d6 100644 --- a/ext/ai/js/ai.ts +++ b/ext/ai/js/ai.ts @@ -46,7 +46,12 @@ export type SessionOutput = T extends "gte-small" ? number[] : T extends LLMProviderName ? O extends { stream: true } - ? AsyncGenerator["output"]> + ? AsyncGenerator< + Result< + LLMProviderInstance["output"], + LLMProviderInstance["error"] + > + > : LLMProviderInstance["output"] : never; diff --git a/ext/ai/js/llm/providers/ollama.ts b/ext/ai/js/llm/providers/ollama.ts index 511de9382..0956f1cd7 100644 --- a/ext/ai/js/llm/providers/ollama.ts +++ b/ext/ai/js/llm/providers/ollama.ts @@ -33,7 +33,7 @@ export type OllamaMessage = { export class OllamaLLMSession implements ILLMProvider, ILLMProviderMeta { input!: OllamaProviderInput; - output!: OllamaProviderOutput; + output!: ILLMProviderOutput; error!: OllamaProviderError; options: OllamaProviderOptions; diff --git a/ext/ai/js/llm/providers/openai.ts b/ext/ai/js/llm/providers/openai.ts index 7ecf1f826..e085f9348 100644 --- a/ext/ai/js/llm/providers/openai.ts +++ b/ext/ai/js/llm/providers/openai.ts @@ -120,7 +120,7 @@ export type OpenAIProviderError = ILLMProviderError; export class OpenAILLMSession implements ILLMProvider, ILLMProviderMeta { input!: OpenAIProviderInput; - output!: OpenAIProviderOutput; + output!: ILLMProviderOutput; error!: OpenAIProviderError; options: OpenAIProviderOptions; From 8053e600dd90b84e101c6c47dbc8b4cc4c40a554 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 15 May 2025 11:57:46 +0100 Subject: [PATCH 11/14] fix: validating empty texts before run gte - It ensures that only valid strings with content can be embeded --- ext/ai/lib.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 68fc77461..b58280a76 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -280,6 +280,10 @@ async fn run_gte( mean_pool: bool, normalize: bool, ) -> Result, Error> { + if prompt.is_empty() { + bail!("must provide a valid prompt value, got 'empty'") + } + let req_tx; { let op_state = state.borrow(); From 8a605bbf67cad9d539a15c040aa024794f133f5c Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 15 May 2025 11:59:42 +0100 Subject: [PATCH 12/14] fix: embedding inference for `gte-small` type - Fix wrong input 
variable name. - Accepting 'opts' param as optinal, applying null safes. --- ext/ai/js/ai.ts | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/ext/ai/js/ai.ts b/ext/ai/js/ai.ts index 8dd6f00d6..23ad212de 100644 --- a/ext/ai/js/ai.ts +++ b/ext/ai/js/ai.ts @@ -44,9 +44,7 @@ export type SessionInputOptions = T extends "gte-small" export type SessionOutput = T extends "gte-small" ? number[] - : T extends LLMProviderName - ? O extends { stream: true } - ? AsyncGenerator< + : T extends LLMProviderName ? O extends { stream: true } ? AsyncGenerator< Result< LLMProviderInstance["output"], LLMProviderInstance["error"] @@ -104,7 +102,7 @@ export class Session { async run>( input: SessionInput, - options: O, + options?: O, ): Promise< [SessionOutput, undefined] | [undefined, SessionOutputError] > { @@ -136,14 +134,14 @@ export class Session { await this.#init; } - const opts = options as EmbeddingInputOptions; + const opts = options as EmbeddingInputOptions | undefined; - const mean_pool = opts.mean_pool ?? true; - const normalize = opts.normalize ?? true; + const mean_pool = opts?.mean_pool ?? true; + const normalize = opts?.normalize ?? true; const result = await core.ops.op_ai_run_model( this.#model, - prompt, + input, mean_pool, normalize, ) as SessionOutput; From 303c99217bd17ac5349c610a94731d9a12fee3d0 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 15 May 2025 12:01:29 +0100 Subject: [PATCH 13/14] test: fix `gte-small` tests to pass with the new `ai` refactors - Improving tests by checking the result types: success or errors - Testing invalid `gte-small` type name --- crates/base/test_cases/supabase-ai/index.ts | 72 ++++++++++++++++----- 1 file changed, 56 insertions(+), 16 deletions(-) diff --git a/crates/base/test_cases/supabase-ai/index.ts b/crates/base/test_cases/supabase-ai/index.ts index 51d05b87a..6411bf528 100644 --- a/crates/base/test_cases/supabase-ai/index.ts +++ b/crates/base/test_cases/supabase-ai/index.ts @@ -1,6 +1,18 @@ -import { assertGreater, assertLessOrEqual } from "jsr:@std/assert"; +import { + assertEquals, + assertExists, + assertGreater, + assertIsError, + assertLessOrEqual, + assertStringIncludes, + assertThrows, +} from 'jsr:@std/assert'; -const session = new Supabase.ai.Session("gte-small"); +const session = new Supabase.ai.Session('gte-small'); + +assertThrows(() => { + const _ = new Supabase.ai.Session('gte-small_wrong_name'); +}, "invalid 'Session' type"); function dotProduct(a: number[], b: number[]) { let result = 0; @@ -15,27 +27,55 @@ export default { async fetch() { // Generate embedding // @ts-ignore unkwnow type - const meow: number[] = await session.run("meow", { - mean_pool: true, - normalize: true, - }); + const [meow, meowError] = await session.run('meow') as [ + number[], + undefined, + ]; + console.log('cat', meow, meowError); // @ts-ignore unkwnow type - const love: number[] = await session.run("I love cats", { + const [love, loveError] = await session.run('I love cats', { mean_pool: true, normalize: true, - }); + }) as [number[], undefined]; + + // "Valid input should result in ok value" + { + assertExists(meow); + assertExists(love); + + assertEquals(meowError, undefined); + assertEquals(loveError, undefined); + } + + // "Invalid input should result in error value" + { + const [notCat, notCatError] = await session.run({ + bad_input: { 'not a cat': 'let fail' }, + }) as [undefined, { message: string; inner: Error }]; + + assertEquals(notCat, undefined); + + assertExists(notCatError); + 
assertIsError(notCatError.inner); + assertStringIncludes( + notCatError.message, + 'must provide a valid prompt value', + ); + } - // Ensures `mean_pool` and `normalize` - const sameScore = dotProduct(meow, meow); - const diffScore = dotProduct(meow, love); + // "Ensures `mean_pool` and `normalize`" + { + const sameScore = dotProduct(meow, meow); + const diffScore = dotProduct(meow, love); - assertGreater(sameScore, 0.9); - assertGreater(diffScore, 0.5); - assertGreater(sameScore, diffScore); + assertGreater(sameScore, 0.9); + assertGreater(diffScore, 0.5); + assertGreater(sameScore, diffScore); - assertLessOrEqual(sameScore, 1); - assertLessOrEqual(diffScore, 1); + assertLessOrEqual(sameScore, 1); + assertLessOrEqual(diffScore, 1); + } return new Response( null, From 8aedd15d0eb717fc6c61b726939898985f910ab0 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 15 May 2025 12:04:01 +0100 Subject: [PATCH 14/14] stamp: format :) --- crates/base/test_cases/supabase-ai/index.ts | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/crates/base/test_cases/supabase-ai/index.ts b/crates/base/test_cases/supabase-ai/index.ts index 6411bf528..97e32d6b0 100644 --- a/crates/base/test_cases/supabase-ai/index.ts +++ b/crates/base/test_cases/supabase-ai/index.ts @@ -6,12 +6,12 @@ import { assertLessOrEqual, assertStringIncludes, assertThrows, -} from 'jsr:@std/assert'; +} from "jsr:@std/assert"; -const session = new Supabase.ai.Session('gte-small'); +const session = new Supabase.ai.Session("gte-small"); assertThrows(() => { - const _ = new Supabase.ai.Session('gte-small_wrong_name'); + const _ = new Supabase.ai.Session("gte-small_wrong_name"); }, "invalid 'Session' type"); function dotProduct(a: number[], b: number[]) { @@ -27,14 +27,13 @@ export default { async fetch() { // Generate embedding // @ts-ignore unkwnow type - const [meow, meowError] = await session.run('meow') as [ + const [meow, meowError] = await session.run("meow") as [ number[], undefined, ]; - console.log('cat', meow, meowError); // @ts-ignore unkwnow type - const [love, loveError] = await session.run('I love cats', { + const [love, loveError] = await session.run("I love cats", { mean_pool: true, normalize: true, }) as [number[], undefined]; @@ -51,7 +50,7 @@ export default { // "Invalid input should result in error value" { const [notCat, notCatError] = await session.run({ - bad_input: { 'not a cat': 'let fail' }, + bad_input: { "not a cat": "let fail" }, }) as [undefined, { message: string; inner: Error }]; assertEquals(notCat, undefined); @@ -60,7 +59,7 @@ export default { assertIsError(notCatError.inner); assertStringIncludes( notCatError.message, - 'must provide a valid prompt value', + "must provide a valid prompt value", ); }
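      // Editorial aside (not part of this patch): the LLM providers added in
      // this series follow the same [value, error] contract these assertions
      // exercise for "gte-small". A rough usage sketch — the provider key
      // "ollama", the model name and the baseURL are placeholders assumed
      // from the provider types in this series, not values taken from the repo:
      const llm = new Supabase.ai.Session("ollama", {
        model: "mistral",
        baseURL: "http://localhost:11434",
      });

      const [reply, replyError] = await llm.run("Say hello", { stream: false });
      if (replyError) {
        console.error(replyError.message, replyError.inner);
      } else {
        // `value` is the simplified text, `inner` the raw provider message,
        // and `usage` carries the token metrics introduced in patch 06.
        console.log(reply?.value, reply?.usage.totalTokens);
      }
      // With `{ stream: true }` the same call yields an AsyncGenerator of
      // [chunk, error] tuples that can be consumed with `for await`.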