From 0d7f780a816a19d7f8fe78507bd28b75e7b059e7 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sun, 26 Jan 2025 21:11:23 +0000 Subject: [PATCH 1/8] feat: creating `inference_api` - Exposing an user friendly interface to consume the `onnx` backend --- examples/ort-raw-session/index.ts | 41 ++++++++++++++ ext/ai/js/ai.js | 8 +-- ext/ai/js/onnxruntime/inference_api.js | 75 ++++++++++++++++++++++++++ ext/ai/js/onnxruntime/onnx.js | 4 +- ext/ai/lib.rs | 3 +- 5 files changed, 125 insertions(+), 6 deletions(-) create mode 100644 examples/ort-raw-session/index.ts create mode 100644 ext/ai/js/onnxruntime/inference_api.js diff --git a/examples/ort-raw-session/index.ts b/examples/ort-raw-session/index.ts new file mode 100644 index 000000000..91ff66c2d --- /dev/null +++ b/examples/ort-raw-session/index.ts @@ -0,0 +1,41 @@ +const { Tensor, RawSession } = Supabase.ai; + +const session = await RawSession.fromHuggingFace('kallebysantos/vehicle-emission', { + path: { + modelFile: 'model.onnx', + }, +}); + +Deno.serve(async (_req: Request) => { + // sample data could be a JSON request + const carsBatchInput = [{ + 'Model_Year': 2021, + 'Engine_Size': 2.9, + 'Cylinders': 6, + 'Fuel_Consumption_in_City': 13.9, + 'Fuel_Consumption_in_City_Hwy': 10.3, + 'Fuel_Consumption_comb': 12.3, + 'Smog_Level': 3, + }, { + 'Model_Year': 2023, + 'Engine_Size': 2.4, + 'Cylinders': 4, + 'Fuel_Consumption_in_City': 9.9, + 'Fuel_Consumption_in_City_Hwy': 7.0, + 'Fuel_Consumption_comb': 8.6, + 'Smog_Level': 3, + }]; + + // Parsing objects to tensor input + const inputTensors = {}; + session.inputs.forEach((inputKey) => { + const values = carsBatchInput.map((item) => item[inputKey]); + + inputTensors[inputKey] = new Tensor('float32', values, [values.length, 1]); + }); + + const { emissions } = await session.run(inputTensors); + // [ 289.01, 199.53] + + return Response.json({ result: emissions }); +}); diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js index c2a306927..c6d53889c 100644 --- a/ext/ai/js/ai.js 
+++ b/ext/ai/js/ai.js @@ -1,5 +1,6 @@ -import "ext:ai/onnxruntime/onnx.js"; -import EventSourceStream from "ext:ai/util/event_source_stream.mjs"; +import 'ext:ai/onnxruntime/onnx.js'; +import InferenceAPI from 'ext:ai/onnxruntime/inference_api.js'; +import EventSourceStream from 'ext:ai/util/event_source_stream.mjs'; const core = globalThis.Deno.core; @@ -257,7 +258,8 @@ const MAIN_WORKER_API = { }; const USER_WORKER_API = { - Session, + Session, + ...InferenceAPI }; export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js new file mode 100644 index 000000000..195bf1a86 --- /dev/null +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -0,0 +1,75 @@ +import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js'; + +const DEFAULT_HUGGING_FACE_OPTIONS = { + hostname: 'https://huggingface.co', + path: { + template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', + revision: 'main', + modelFile: 'model_quantized.onnx', + }, +}; + +/** + * An user friendly API for onnx backend + */ +class UserInferenceSession { + inner; + + id; + inputs; + outputs; + + constructor(session) { + this.inner = session; + + this.id = session.sessionId; + this.inputs = session.inputNames; + this.outputs = session.outputNames; + } + + static async fromUrl(modelUrl) { + if (modelUrl instanceof URL) { + modelUrl = modelUrl.toString(); + } + + const encoder = new TextEncoder(); + const modelUrlBuffer = encoder.encode(modelUrl); + const session = await InferenceSession.fromBuffer(modelUrlBuffer); + + return new UserInferenceSession(session); + } + + static async fromHuggingFace(repoId, opts = {}) { + const hostname = opts?.hostname ?? 
DEFAULT_HUGGING_FACE_OPTIONS.hostname; + const pathOpts = { + ...DEFAULT_HUGGING_FACE_OPTIONS.path, + ...opts?.path, + }; + + const modelPath = pathOpts.template + .replaceAll('{REPO_ID}', repoId) + .replaceAll('{REVISION}', pathOpts.revision) + .replaceAll('{MODEL_FILE}', pathOpts.modelFile); + + if (!URL.canParse(modelPath, hostname)) { + throw Error(`[Invalid URL] Couldn't parse the model path: "${modelPath}"`); + } + + return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); + } + + async run(inputs) { + return await this.inner.run(inputs); + } +} + +class UserTensor extends Tensor { + constructor(type, data, dim) { + super(type, data, dim); + } +} + +export default { + RawSession: UserInferenceSession, + Tensor: UserTensor, +}; diff --git a/ext/ai/js/onnxruntime/onnx.js b/ext/ai/js/onnxruntime/onnx.js index 2e0e4548a..643de198a 100644 --- a/ext/ai/js/onnxruntime/onnx.js +++ b/ext/ai/js/onnxruntime/onnx.js @@ -31,7 +31,7 @@ class TensorProxy { } } -class Tensor { +export class Tensor { /** @type {DataType} Type of the tensor. 
*/ type; @@ -67,7 +67,7 @@ class Tensor { } } -class InferenceSession { +export class InferenceSession { sessionId; inputNames; outputNames; diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 224b0450f..fb449441e 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -55,7 +55,8 @@ deno_core::extension!( "util/event_stream_parser.mjs", "util/event_source_stream.mjs", "onnxruntime/onnx.js", - "onnxruntime/cache_adapter.js" + "onnxruntime/cache_adapter.js", + "onnxruntime/inference_api.js" ] ); From c906bf1fd284d811917963f48977eec521379a31 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sun, 26 Jan 2025 21:27:07 +0000 Subject: [PATCH 2/8] stamp: add typescript defs for 'InferenceAPI' --- examples/ort-raw-session/types.d.ts | 161 ++++++++++++++++++++++++++++ 1 file changed, 161 insertions(+) create mode 100644 examples/ort-raw-session/types.d.ts diff --git a/examples/ort-raw-session/types.d.ts b/examples/ort-raw-session/types.d.ts new file mode 100644 index 000000000..f0f4c2f11 --- /dev/null +++ b/examples/ort-raw-session/types.d.ts @@ -0,0 +1,161 @@ +declare namespace Supabase { + /** + * Provides AI related APIs + */ + export interface Ai { + /** Provides an user friendly interface for the low level *onnx backend API*. + * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models, where you need mode control of model execution and pre/pos-processing. + * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`. 
+ * + * **Example:** + * ```typescript + * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); + * // const session = await RawSession.fromUrl("https://example.com/model.onnx"); + * + * // Prepare the input tensors + * const inputs = { + * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), + * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), + * }; + * + * // Run the model + * const outputs = await session.run(inputs); + * + * console.log(outputs.output1); // Output tensor + * ``` + */ + readonly RawSession: typeof RawSession; + + /** A low level representation of model input/output. + * Supabase's `Tensor` is totally compatible with `@huggingface/transformers.js`'s `Tensor`. It means that you can use its high-level API to apply some common operations like `sum()`, `min()`, `max()`, `normalize()` etc... + * + * **Example: Generating embeddings from scratch** + * ```typescript + * import { Tensor as HFTensor } from "@huggingface/transformers.js"; + * const { Tensor, RawSession } = Supabase.ai; + * + * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); + * + * // Example only, in real 'feature-extraction' tensors are given from the tokenizer step. 
+ * const inputs = { + * input_ids: new Tensor('float32', [...], [n, 2]), + * attention_mask: new Tensor('float32', [...], [n, 2]), + * token_types_ids: new Tensor('float32', [...], [n, 2]) + * }; + * + * const { last_hidden_state } = await session.run(inputs); + * + * // Using `transformers.js` APIs + * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize(); + * + * return hfTensor.tolist(); + * + * ``` + */ + readonly Tensor: typeof Tensor; + } + + /** + * Provides AI related APIs + */ + export const ai: Ai; + + export type TensorDataTypeMap = { + float32: Float32Array | number[]; + float64: Float64Array | number[]; + string: string[]; + int8: Int8Array | number[]; + uint8: Uint8Array | number[]; + int16: Int16Array | number[]; + uint16: Uint16Array | number[]; + int32: Int32Array | number[]; + uint32: Uint32Array | number[]; + int64: BigInt64Array | number[]; + uint64: BigUint64Array | number[]; + bool: Uint8Array | number[]; + }; + + export type TensorMap = { [key: string]: Tensor }; + + export class Tensor { + /** Type of the tensor. */ + type: T; + + /** The data stored in the tensor. */ + data: TensorDataTypeMap[T]; + + /** Dimensions of the tensor. */ + dims: number[]; + + /** The total number of elements in the tensor. */ + size: number; + + constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); + } + + export class RawSession { + /** The underline session's ID. + * Session's ID are unique for each loaded model, it means that even if a session is constructed twice its will share the same ID. + */ + id: string; + + /** A list of all input keys the model expects. */ + inputs: string[]; + + /** A list of all output keys the model will result. */ + outputs: string[]; + + /** Loads a ONNX model session from source URL. + * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromUrl(source: string | URL): Promise; + + /** Loads a ONNX model session from **HuggingFace** repository. 
+ * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromHuggingFace(repoId: string, opts?: { + /** + * @default 'https://huggingface.co' + */ + hostname?: string | URL; + path?: { + /** + * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true' + */ + template?: string; + /** + * @default 'main' + */ + revision?: string; + /** + * @default 'model_quantized.onnx' + */ + modelFile?: string; + }; + }): Promise; + + /** Run the current session with the given inputs. + * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. + * + * @param inputs The input tensors required by the model. + * @returns The output tensors generated by the model. + * + * @example + * ```typescript + * const session = await RawSession.fromUrl("https://example.com/model.onnx"); + * + * // Prepare the input tensors + * const inputs = { + * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), + * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), + * }; + * + * // Run the model + * const outputs = await session.run(inputs); + * + * console.log(outputs.output1); // Output tensor + * ``` + */ + run(inputs: TensorMap): Promise; + } +} From 9b2c48cc706e385f963009562af0a301fca15c45 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 30 Jan 2025 20:20:40 +0000 Subject: [PATCH 3/8] feat: adding `text-to-audio` example - using `InferenceAPI` to perform `text-to-audio`. 
- encoding `wave` audio tensors from the rust land --- examples/ort-raw-session/index.ts | 76 +++++++++- examples/ort-raw-session/types.d.ts | 2 + examples/text-to-audio/index.ts | 103 +++++++++++++ examples/text-to-audio/phonemizer.js | 198 +++++++++++++++++++++++++ ext/ai/js/onnxruntime/inference_api.js | 17 ++- ext/ai/lib.rs | 1 + ext/ai/onnxruntime/mod.rs | 43 ++++++ 7 files changed, 438 insertions(+), 2 deletions(-) create mode 100644 examples/text-to-audio/index.ts create mode 100644 examples/text-to-audio/phonemizer.js diff --git a/examples/ort-raw-session/index.ts b/examples/ort-raw-session/index.ts index 91ff66c2d..b50c4357e 100644 --- a/examples/ort-raw-session/index.ts +++ b/examples/ort-raw-session/index.ts @@ -1,3 +1,76 @@ +/// /// + +/* +const modelUrl = 'https://huggingface.co/kalleby/hp-to-miles/resolve/main/model.onnx?download=true'; +const modelConfigUrl = + 'https://huggingface.co/kalleby/hp-to-miles/resolve/main/config.json?download=true'; + +const model = await Supabase.ai.RawSession.fromUrl(modelUrl); +const modelConfig = await fetch(modelConfigUrl).then((r) => r.json()); + +Deno.serve(async (req: Request) => { + const params = new URL(req.url).searchParams; + const inputValue = parseInt(params.get('value')); + + const input = new Supabase.ai.RawTensor('float32', [inputValue], [1, 1]); + .minMaxNormalize(modelConfig.input.min, modelConfig.input.max); + + const output = await model.run({ + 'dense_dense1_input': input, + }); + + console.log('output', output); + + const outputTensor = output['dense_Dense4'] + .minMaxUnnormalize(modelConfig.label.min, modelConfig.label.max); + + return Response.json({ result: outputTensor.data }); +}); +*/ + +// transformers.js Compatible: +// import { Tensor } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.2'; +// const rawTensor = new Supabase.ai.RawTensor('string', urls, [urls.length]); +// console.log('raw tensor', rawTensor ); +// +// const tensor = new Tensor(rawTensor); +// console.log('hf 
tensor', tensor); +// +// 'hf tensor operations' +// tensor.min(); tensor.max(); tensor.norm() .... + +// const modelUrl = +// 'https://huggingface.co/pirocheto/phishing-url-detection/resolve/main/model.onnx?download=true'; + +/* +const { Tensor, RawSession } = Supabase.ai; + +const model = await RawSession.fromHuggingFace('pirocheto/phishing-url-detection', { + path: { + template: `{REPO_ID}/resolve/{REVISION}/{MODEL_FILE}?donwload=true`, + modelFile: 'model.onnx', + }, +}); + +console.log('session', model); + +Deno.serve(async (_req: Request) => { + const urls = [ + 'https://clubedemilhagem.com/home.php', + 'http://www.medicalnewstoday.com/articles/188939.php', + 'https://magalu-crediarioluiza.com/Produto_20203/produto.php?sku=1', + ]; + + const inputs = new Tensor('string', urls, [urls.length]); + console.log('tensor', inputs.data); + + const output = await model.run({ inputs }); + console.log(output); + + return Response.json({ result: output.probabilities }); +}); +*/ + const { Tensor, RawSession } = Supabase.ai; const session = await RawSession.fromHuggingFace('kallebysantos/vehicle-emission', { @@ -27,7 +100,7 @@ Deno.serve(async (_req: Request) => { }]; // Parsing objects to tensor input - const inputTensors = {}; + const inputTensors: Record> = {}; session.inputs.forEach((inputKey) => { const values = carsBatchInput.map((item) => item[inputKey]); @@ -35,6 +108,7 @@ Deno.serve(async (_req: Request) => { }); const { emissions } = await session.run(inputTensors); + console.log(emissions); // [ 289.01, 199.53] return Response.json({ result: emissions }); diff --git a/examples/ort-raw-session/types.d.ts b/examples/ort-raw-session/types.d.ts index f0f4c2f11..eafc051b7 100644 --- a/examples/ort-raw-session/types.d.ts +++ b/examples/ort-raw-session/types.d.ts @@ -91,6 +91,8 @@ declare namespace Supabase { size: number; constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); + + tryEncodeAudio(sampleRate: number): Promise; } export class RawSession { diff 
--git a/examples/text-to-audio/index.ts b/examples/text-to-audio/index.ts new file mode 100644 index 000000000..e3e68d463 --- /dev/null +++ b/examples/text-to-audio/index.ts @@ -0,0 +1,103 @@ +// Setup type definitions for built-in Supabase Runtime APIs +import 'jsr:@supabase/functions-js/edge-runtime.d.ts'; +import { PreTrainedTokenizer } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.1'; + +// import 'phonemize' code from Kokoro.js repo +import { phonemize } from './phonemizer.js'; + +const { Tensor, RawSession } = Supabase.ai; + +const STYLE_DIM = 256; +const SAMPLE_RATE = 24000; +const MODEL_ID = 'onnx-community/Kokoro-82M-ONNX'; + +// https://huggingface.co/onnx-community/Kokoro-82M-ONNX#samples +const ALLOWED_VOICES = [ + 'af_bella', + 'af_nicole', + 'af_sarah', + 'af_sky', + 'am_adam', + 'am_michael', + 'bf_emma', + 'bf_isabella', + 'bm_george', + 'bm_lewis', +]; + +const session = await RawSession.fromHuggingFace(MODEL_ID); + +Deno.serve(async (req) => { + const params = new URL(req.url).searchParams; + const text = params.get('text') ?? 'Hello from Supabase!'; + const voice = params.get('voice') ?? 
'af_bella'; + + if (!ALLOWED_VOICES.includes(voice)) { + return Response.json({ + error: `invalid voice '${voice}'`, + must_be_one_of: ALLOWED_VOICES, + }, { status: 400 }); + } + + const tokenizer = await loadTokenizer(); + const language = voice.at(0); // 'a'merican | 'b'ritish + const phonemes = await phonemize(text, language); + const { input_ids } = tokenizer(phonemes, { + truncation: true, + }); + + // Select voice style based on number of input tokens + const num_tokens = Math.max( + input_ids.dims.at(-1) - 2, // Without padding; + 0, + ); + + const voiceStyle = await loadVoiceStyle(voice, num_tokens); + + const { waveform } = await session.run({ + input_ids, + style: voiceStyle, + speed: new Tensor('float32', [1], [1]), + }); + + // Do `wave` encoding from rust backend + const audio = await waveform.tryEncodeAudio(SAMPLE_RATE); + + return new Response(audio, { + headers: { + 'Content-Type': 'audio/wav', + }, + }); +}); + +async function loadVoiceStyle(voice: string, num_tokens: number) { + const voice_url = + `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`; + + console.log('loading voice:', voice_url); + + const voiceBuffer = await fetch(voice_url).then(async (res) => await res.arrayBuffer()); + + const offset = num_tokens * STYLE_DIM; + const voiceData = new Float32Array(voiceBuffer).slice( + offset, + offset + STYLE_DIM, + ); + + return new Tensor('float32', voiceData, [1, STYLE_DIM]); +} + +async function loadTokenizer() { + // BUG: invalid 'h' not JSON. 
That's why we need to manually fetch the assets + // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID); + + const tokenizerData = await fetch( + 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true', + ).then(async (res) => await res.json()); + + const tokenizerConfig = await fetch( + 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true', + ).then(async (res) => await res.json()); + + return new PreTrainedTokenizer(tokenizerData, tokenizerConfig); +} diff --git a/examples/text-to-audio/phonemizer.js b/examples/text-to-audio/phonemizer.js new file mode 100644 index 000000000..e694e801f --- /dev/null +++ b/examples/text-to-audio/phonemizer.js @@ -0,0 +1,198 @@ +// Source code from https://github.com/hexgrad/kokoro/blob/a09db51873211d76a3e49b058b18873a4e002b81/kokoro.js/src/phonemize.js +// BUG: Don't now why, but if import it from cdnjs will cause runtime stack overflow +import { phonemize as espeakng } from 'npm:phonemizer@1.2.1'; + +/** + * Helper function to split a string on a regex, but keep the delimiters. + * This is required, because the JavaScript `.split()` method does not keep the delimiters, + * and wrapping in a capturing group causes issues with existing capturing groups (due to nesting). + * @param {string} text The text to split. + * @param {RegExp} regex The regex to split on. + * @returns {{match: boolean; text: string}[]} The split string. 
+ */ +function split(text, regex) { + const result = []; + let prev = 0; + for (const match of text.matchAll(regex)) { + const fullMatch = match[0]; + if (prev < match.index) { + result.push({ match: false, text: text.slice(prev, match.index) }); + } + if (fullMatch.length > 0) { + result.push({ match: true, text: fullMatch }); + } + prev = match.index + fullMatch.length; + } + if (prev < text.length) { + result.push({ match: false, text: text.slice(prev) }); + } + return result; +} + +/** + * Helper function to split numbers into phonetic equivalents + * @param {string} match The matched number + * @returns {string} The phonetic equivalent + */ +function split_num(match) { + if (match.includes('.')) { + return match; + } else if (match.includes(':')) { + let [h, m] = match.split(':').map(Number); + if (m === 0) { + return `${h} o'clock`; + } else if (m < 10) { + return `${h} oh ${m}`; + } + return `${h} ${m}`; + } + let year = parseInt(match.slice(0, 4), 10); + if (year < 1100 || year % 1000 < 10) { + return match; + } + let left = match.slice(0, 2); + let right = parseInt(match.slice(2, 4), 10); + let suffix = match.endsWith('s') ? 's' : ''; + if (year % 1000 >= 100 && year % 1000 <= 999) { + if (right === 0) { + return `${left} hundred${suffix}`; + } else if (right < 10) { + return `${left} oh ${right}${suffix}`; + } + } + return `${left} ${right}${suffix}`; +} + +/** + * Helper function to format monetary values + * @param {string} match The matched currency + * @returns {string} The formatted currency + */ +function flip_money(match) { + const bill = match[0] === '$' ? 'dollar' : 'pound'; + if (isNaN(Number(match.slice(1)))) { + return `${match.slice(1)} ${bill}s`; + } else if (!match.includes('.')) { + let suffix = match.slice(1) === '1' ? '' : 's'; + return `${match.slice(1)} ${bill}${suffix}`; + } + const [b, c] = match.slice(1).split('.'); + const d = parseInt(c.padEnd(2, '0'), 10); + let coins = match[0] === '$' ? (d === 1 ? 'cent' : 'cents') : d === 1 ? 
'penny' : 'pence'; + return `${b} ${bill}${b === '1' ? '' : 's'} and ${d} ${coins}`; +} + +/** + * Helper function to process decimal numbers + * @param {string} match The matched number + * @returns {string} The formatted number + */ +function point_num(match) { + let [a, b] = match.split('.'); + return `${a} point ${b.split('').join(' ')}`; +} + +/** + * Normalize text for phonemization + * @param {string} text The text to normalize + * @returns {string} The normalized text + */ +function normalize_text(text) { + return ( + text + // 1. Handle quotes and brackets + .replace(/[‘’]/g, "'") + .replace(/«/g, '“') + .replace(/»/g, '”') + .replace(/[“”]/g, '"') + .replace(/\(/g, '«') + .replace(/\)/g, '»') + // 2. Replace uncommon punctuation marks + .replace(/、/g, ', ') + .replace(/。/g, '. ') + .replace(/!/g, '! ') + .replace(/,/g, ', ') + .replace(/:/g, ': ') + .replace(/;/g, '; ') + .replace(/?/g, '? ') + // 3. Whitespace normalization + .replace(/[^\S \n]/g, ' ') + .replace(/ +/, ' ') + .replace(/(?<=\n) +(?=\n)/g, '') + // 4. Abbreviations + .replace(/\bD[Rr]\.(?= [A-Z])/g, 'Doctor') + .replace(/\b(?:Mr\.|MR\.(?= [A-Z]))/g, 'Mister') + .replace(/\b(?:Ms\.|MS\.(?= [A-Z]))/g, 'Miss') + .replace(/\b(?:Mrs\.|MRS\.(?= [A-Z]))/g, 'Mrs') + .replace(/\betc\.(?! [A-Z])/gi, 'etc') + // 5. Normalize casual words + .replace(/\b(y)eah?\b/gi, "$1e'a") + // 5. Handle numbers and currencies + .replace(/\d*\.\d+|\b\d{4}s?\b|(? m.replace(/\./g, '-')) + .replace(/(?<=[A-Z])\.(?=[A-Z])/gi, '-') + // 8. Strip leading and trailing whitespace + .trim() + ); +} + +/** + * Escapes regular expression special characters from a string by replacing them with their escaped counterparts. + * + * @param {string} string The string to escape. + * @returns {string} The escaped string. 
+ */ +function escapeRegExp(string) { + return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string +} + +const PUNCTUATION = ';:,.!?¡¿—…"«»“”(){}[]'; +const PUNCTUATION_PATTERN = new RegExp(`(\\s*[${escapeRegExp(PUNCTUATION)}]+\\s*)+`, 'g'); + +export async function phonemize(text, language = 'a', norm = true) { + // 1. Normalize text + if (norm) { + text = normalize_text(text); + } + + // 2. Split into chunks, to ensure we preserve punctuation + const sections = split(text, PUNCTUATION_PATTERN); + + // 3. Convert each section to phonemes + const lang = language === 'a' ? 'en-us' : 'en'; + const ps = (await Promise.all( + sections.map(async ( + { match, text }, + ) => (match ? text : (await espeakng(text, lang)).join(' '))), + )).join(''); + + // 4. Post-process phonemes + let processed = ps + // https://en.wiktionary.org/wiki/kokoro#English + .replace(/kəkˈoːɹoʊ/g, 'kˈoʊkəɹoʊ') + .replace(/kəkˈɔːɹəʊ/g, 'kˈəʊkəɹəʊ') + .replace(/ʲ/g, 'j') + .replace(/r/g, 'ɹ') + .replace(/x/g, 'k') + .replace(/ɬ/g, 'l') + .replace(/(?<=[a-zɹː])(?=hˈʌndɹɪd)/g, ' ') + .replace(/ z(?=[;:,.!?¡¿—…"«»“” ]|$)/g, 'z'); + + // 5. 
Additional post-processing for American English + if (language === 'a') { + processed = processed.replace(/(?<=nˈaɪn)ti(?!ː)/g, 'di'); + } + return processed.trim(); +} diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js index 195bf1a86..295d9e5c0 100644 --- a/ext/ai/js/onnxruntime/inference_api.js +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -59,7 +59,18 @@ class UserInferenceSession { } async run(inputs) { - return await this.inner.run(inputs); + const outputs = await core.ops.op_sb_ai_ort_run_session(this.id, inputs); + + // Parse to Tensor + for (const key in outputs) { + if (Object.hasOwn(outputs, key)) { + const { type, data, dims } = outputs[key]; + + outputs[key] = new UserTensor(type, data.buffer, dims); + } + } + + return outputs; } } @@ -67,6 +78,10 @@ class UserTensor extends Tensor { constructor(type, data, dim) { super(type, data, dim); } + + async tryEncodeAudio(sampleRate) { + return await core.ops.op_sb_ai_ort_encode_tensor_audio(this.data, sampleRate); + } } export default { diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index fb449441e..715d9b963 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -47,6 +47,7 @@ deno_core::extension!( op_ai_try_cleanup_unused_session, op_ai_ort_init_session, op_ai_ort_run_session, + op_ai_ort_encode_tensor_audio, ], esm_entry_point = "ext:ai/ai.js", esm = [ diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index b1dd136f8..0f6842a4a 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -19,6 +19,7 @@ use deno_core::error::AnyError; use deno_core::op2; use deno_core::JsBuffer; use deno_core::JsRuntime; +use deno_core::ToJsBuffer; use deno_core::OpState; use deno_core::V8CrossThreadTaskSpawner; @@ -31,6 +32,7 @@ use tensor::ToJsTensor; use tokio::sync::oneshot; use tracing::debug; use tracing::trace; +use tokio_util::bytes::BufMut; #[op2(async)] #[to_v8] @@ -133,3 +135,44 @@ pub async fn op_ai_ort_run_session( rx.await.context("failed to get 
inference result")? } + +// REF: https://youtu.be/qqjvB_VxMRM?si=7lnYdgbhOC_K7P6S +// http://soundfile.sapp.org/doc/WaveFormat/ +#[op2] +#[serde] +pub fn op_ai_ort_encode_tensor_audio( + #[serde] tensor: JsBuffer, + sample_rate: u32, +) -> Result { + // let copy for now + let data_buffer = tensor.into_iter().as_slice(); + + let sample_size = 4; // f32 tensor + let data_chunk_size = data_buffer.len() as u32 * sample_size; + let total_riff_size = 36 + data_chunk_size; // 36 is the total of bytes until write data + + let mut audio_wav = Vec::new(); + + // RIFF HEADER + audio_wav.extend_from_slice(b"RIFF"); + audio_wav.put_u32_le(total_riff_size); + audio_wav.extend_from_slice(b"WAVE"); + + // FORMAT CHUNK + audio_wav.extend_from_slice(b"fmt "); // whitespace needed "fmt" + " " + audio_wav.put_u32_le(16); // PCM chunk size + audio_wav.put_u16_le(3); // RAW audio format + audio_wav.put_u16_le(1); // Number of channels + audio_wav.put_u32_le(sample_rate); + audio_wav.put_u32_le(sample_rate * sample_size); // Byte rate + audio_wav.put_u16_le(sample_size as u16); // Block align + audio_wav.put_u16_le(32); // f32 Bits per sample + + // DATA Chunk + audio_wav.extend_from_slice(b"data"); // chunk ID + audio_wav.put_u32_le(data_chunk_size); + + audio_wav.extend_from_slice(data_buffer); + + Ok(ToJsBuffer::from(audio_wav)) +} From b8951cb2e032a3a0466a44bb1d6b7ce1d5252621 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sun, 2 Feb 2025 09:57:06 +0000 Subject: [PATCH 4/8] stamp: adding paper references for model magic numbers Documenting the "magic numbers" of the `text-to-audio` exmaple, [original paper](https://arxiv.org/pdf/2306.07691) --- examples/text-to-audio/index.ts | 148 +++++++++++++++++--------------- 1 file changed, 78 insertions(+), 70 deletions(-) diff --git a/examples/text-to-audio/index.ts b/examples/text-to-audio/index.ts index e3e68d463..7d2b2711f 100644 --- a/examples/text-to-audio/index.ts +++ b/examples/text-to-audio/index.ts @@ -7,97 +7,105 @@ import { 
phonemize } from './phonemizer.js'; const { Tensor, RawSession } = Supabase.ai; +/* NOTE: Reference [original paper](https://arxiv.org/pdf/2306.07691#Model%20Training): +> All datasets were resampled to 24 kHz to match LibriTTS, and the texts +> were converted into phonemes using phonemizer' +*/ +const SAMPLE_RATE = 24000; // 24 kHz + +/* NOTE: Reference [original paper](https://arxiv.org/pdf/2306.07691#Detailed%20Model%20Architectures): +> The size of s and c is 256 × 1 +*/ const STYLE_DIM = 256; -const SAMPLE_RATE = 24000; const MODEL_ID = 'onnx-community/Kokoro-82M-ONNX'; // https://huggingface.co/onnx-community/Kokoro-82M-ONNX#samples const ALLOWED_VOICES = [ - 'af_bella', - 'af_nicole', - 'af_sarah', - 'af_sky', - 'am_adam', - 'am_michael', - 'bf_emma', - 'bf_isabella', - 'bm_george', - 'bm_lewis', + 'af_bella', + 'af_nicole', + 'af_sarah', + 'af_sky', + 'am_adam', + 'am_michael', + 'bf_emma', + 'bf_isabella', + 'bm_george', + 'bm_lewis', ]; const session = await RawSession.fromHuggingFace(MODEL_ID); Deno.serve(async (req) => { - const params = new URL(req.url).searchParams; - const text = params.get('text') ?? 'Hello from Supabase!'; - const voice = params.get('voice') ?? 
'af_bella'; - - if (!ALLOWED_VOICES.includes(voice)) { - return Response.json({ - error: `invalid voice '${voice}'`, - must_be_one_of: ALLOWED_VOICES, - }, { status: 400 }); - } - - const tokenizer = await loadTokenizer(); - const language = voice.at(0); // 'a'merican | 'b'ritish - const phonemes = await phonemize(text, language); - const { input_ids } = tokenizer(phonemes, { - truncation: true, - }); - - // Select voice style based on number of input tokens - const num_tokens = Math.max( - input_ids.dims.at(-1) - 2, // Without padding; - 0, - ); - - const voiceStyle = await loadVoiceStyle(voice, num_tokens); - - const { waveform } = await session.run({ - input_ids, - style: voiceStyle, - speed: new Tensor('float32', [1], [1]), - }); - - // Do `wave` encoding from rust backend - const audio = await waveform.tryEncodeAudio(SAMPLE_RATE); - - return new Response(audio, { - headers: { - 'Content-Type': 'audio/wav', - }, - }); + const params = new URL(req.url).searchParams; + const text = params.get('text') ?? 'Hello from Supabase!'; + const voice = params.get('voice') ?? 
'af_bella'; + + if (!ALLOWED_VOICES.includes(voice)) { + return Response.json({ + error: `invalid voice '${voice}'`, + must_be_one_of: ALLOWED_VOICES, + }, { status: 400 }); + } + + const tokenizer = await loadTokenizer(); + const language = voice.at(0); // 'a'merican | 'b'ritish + const phonemes = await phonemize(text, language); + const { input_ids } = tokenizer(phonemes, { + truncation: true, + }); + + // Select voice style based on number of input tokens + const num_tokens = Math.max( + input_ids.dims.at(-1) - 2, // Without padding; + 0, + ); + + const voiceStyle = await loadVoiceStyle(voice, num_tokens); + + const { waveform } = await session.run({ + input_ids, + style: voiceStyle, + speed: new Tensor('float32', [1], [1]), + }); + + // Do `wave` encoding from rust backend + const audio = await waveform.tryEncodeAudio(SAMPLE_RATE); + + return new Response(audio, { + headers: { + 'Content-Type': 'audio/wav', + }, + }); }); async function loadVoiceStyle(voice: string, num_tokens: number) { - const voice_url = - `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`; + const voice_url = + `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`; - console.log('loading voice:', voice_url); + console.log('loading voice:', voice_url); - const voiceBuffer = await fetch(voice_url).then(async (res) => await res.arrayBuffer()); + const voiceBuffer = await fetch(voice_url).then(async (res) => await res.arrayBuffer()); - const offset = num_tokens * STYLE_DIM; - const voiceData = new Float32Array(voiceBuffer).slice( - offset, - offset + STYLE_DIM, - ); + const offset = num_tokens * STYLE_DIM; + const voiceData = new Float32Array(voiceBuffer).slice( + offset, + offset + STYLE_DIM, + ); - return new Tensor('float32', voiceData, [1, STYLE_DIM]); + return new Tensor('float32', voiceData, [1, STYLE_DIM]); } async function loadTokenizer() { - // BUG: invalid 'h' not JSON. 
That's why we need to manually fetch the assets - // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID); + // BUG: invalid 'h' not JSON. That's why we need to manually fetch the assets + // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID); - const tokenizerData = await fetch( - 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true', - ).then(async (res) => await res.json()); + const tokenizerData = await fetch( + 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true', + ).then(async (res) => await res.json()); - const tokenizerConfig = await fetch( - 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true', - ).then(async (res) => await res.json()); + const tokenizerConfig = await fetch( + 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true', + ).then(async (res) => await res.json()); - return new PreTrainedTokenizer(tokenizerData, tokenizerConfig); + return new PreTrainedTokenizer(tokenizerData, tokenizerConfig); } From b17079ba4bb0095c79dd0e070d67e6c9ac653391 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Thu, 6 Feb 2025 15:25:50 +0000 Subject: [PATCH 5/8] fix(ext/ai): moving types to `global.d.ts` --- examples/ort-raw-session/index.ts | 74 ++++---- examples/ort-raw-session/types.d.ts | 163 ---------------- examples/text-to-audio/index.ts | 4 +- ext/ai/js/onnxruntime/inference_api.js | 116 ++++++------ types/global.d.ts | 250 ++++++++++++++++++++----- 5 files changed, 298 insertions(+), 309 deletions(-) delete mode 100644 examples/ort-raw-session/types.d.ts diff --git a/examples/ort-raw-session/index.ts b/examples/ort-raw-session/index.ts index b50c4357e..f87b20c87 100644 --- a/examples/ort-raw-session/index.ts +++ b/examples/ort-raw-session/index.ts @@ -1,5 +1,3 @@ -/// /// - /* const modelUrl = 
'https://huggingface.co/kalleby/hp-to-miles/resolve/main/model.onnx?download=true'; const modelConfigUrl = @@ -71,45 +69,45 @@ Deno.serve(async (_req: Request) => { }); */ -const { Tensor, RawSession } = Supabase.ai; +const { RawTensor, RawSession } = Supabase.ai; const session = await RawSession.fromHuggingFace('kallebysantos/vehicle-emission', { - path: { - modelFile: 'model.onnx', - }, + path: { + modelFile: 'model.onnx', + }, }); Deno.serve(async (_req: Request) => { - // sample data could be a JSON request - const carsBatchInput = [{ - 'Model_Year': 2021, - 'Engine_Size': 2.9, - 'Cylinders': 6, - 'Fuel_Consumption_in_City': 13.9, - 'Fuel_Consumption_in_City_Hwy': 10.3, - 'Fuel_Consumption_comb': 12.3, - 'Smog_Level': 3, - }, { - 'Model_Year': 2023, - 'Engine_Size': 2.4, - 'Cylinders': 4, - 'Fuel_Consumption_in_City': 9.9, - 'Fuel_Consumption_in_City_Hwy': 7.0, - 'Fuel_Consumption_comb': 8.6, - 'Smog_Level': 3, - }]; - - // Parsing objects to tensor input - const inputTensors: Record> = {}; - session.inputs.forEach((inputKey) => { - const values = carsBatchInput.map((item) => item[inputKey]); - - inputTensors[inputKey] = new Tensor('float32', values, [values.length, 1]); - }); - - const { emissions } = await session.run(inputTensors); - console.log(emissions); - // [ 289.01, 199.53] - - return Response.json({ result: emissions }); + // sample data could be a JSON request + const carsBatchInput = [{ + 'Model_Year': 2021, + 'Engine_Size': 2.9, + 'Cylinders': 6, + 'Fuel_Consumption_in_City': 13.9, + 'Fuel_Consumption_in_City_Hwy': 10.3, + 'Fuel_Consumption_comb': 12.3, + 'Smog_Level': 3, + }, { + 'Model_Year': 2023, + 'Engine_Size': 2.4, + 'Cylinders': 4, + 'Fuel_Consumption_in_City': 9.9, + 'Fuel_Consumption_in_City_Hwy': 7.0, + 'Fuel_Consumption_comb': 8.6, + 'Smog_Level': 3, + }]; + + // Parsing objects to tensor input + const inputTensors: Record> = {}; + session.inputs.forEach((inputKey) => { + const values = carsBatchInput.map((item) => item[inputKey]); + + 
inputTensors[inputKey] = new RawTensor('float32', values, [values.length, 1]); + }); + + const { emissions } = await session.run(inputTensors); + console.log(emissions); + // [ 289.01, 199.53] + + return Response.json({ result: emissions }); }); diff --git a/examples/ort-raw-session/types.d.ts b/examples/ort-raw-session/types.d.ts deleted file mode 100644 index eafc051b7..000000000 --- a/examples/ort-raw-session/types.d.ts +++ /dev/null @@ -1,163 +0,0 @@ -declare namespace Supabase { - /** - * Provides AI related APIs - */ - export interface Ai { - /** Provides an user friendly interface for the low level *onnx backend API*. - * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models, where you need mode control of model execution and pre/pos-processing. - * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`. - * - * **Example:** - * ```typescript - * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); - * // const session = await RawSession.fromUrl("https://example.com/model.onnx"); - * - * // Prepare the input tensors - * const inputs = { - * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), - * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), - * }; - * - * // Run the model - * const outputs = await session.run(inputs); - * - * console.log(outputs.output1); // Output tensor - * ``` - */ - readonly RawSession: typeof RawSession; - - /** A low level representation of model input/output. - * Supabase's `Tensor` is totally compatible with `@huggingface/transformers.js`'s `Tensor`. It means that you can use its high-level API to apply some common operations like `sum()`, `min()`, `max()`, `normalize()` etc... 
- * - * **Example: Generating embeddings from scratch** - * ```typescript - * import { Tensor as HFTensor } from "@huggingface/transformers.js"; - * const { Tensor, RawSession } = Supabase.ai; - * - * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); - * - * // Example only, in real 'feature-extraction' tensors are given from the tokenizer step. - * const inputs = { - * input_ids: new Tensor('float32', [...], [n, 2]), - * attention_mask: new Tensor('float32', [...], [n, 2]), - * token_types_ids: new Tensor('float32', [...], [n, 2]) - * }; - * - * const { last_hidden_state } = await session.run(inputs); - * - * // Using `transformers.js` APIs - * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize(); - * - * return hfTensor.tolist(); - * - * ``` - */ - readonly Tensor: typeof Tensor; - } - - /** - * Provides AI related APIs - */ - export const ai: Ai; - - export type TensorDataTypeMap = { - float32: Float32Array | number[]; - float64: Float64Array | number[]; - string: string[]; - int8: Int8Array | number[]; - uint8: Uint8Array | number[]; - int16: Int16Array | number[]; - uint16: Uint16Array | number[]; - int32: Int32Array | number[]; - uint32: Uint32Array | number[]; - int64: BigInt64Array | number[]; - uint64: BigUint64Array | number[]; - bool: Uint8Array | number[]; - }; - - export type TensorMap = { [key: string]: Tensor }; - - export class Tensor { - /** Type of the tensor. */ - type: T; - - /** The data stored in the tensor. */ - data: TensorDataTypeMap[T]; - - /** Dimensions of the tensor. */ - dims: number[]; - - /** The total number of elements in the tensor. */ - size: number; - - constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); - - tryEncodeAudio(sampleRate: number): Promise; - } - - export class RawSession { - /** The underline session's ID. - * Session's ID are unique for each loaded model, it means that even if a session is constructed twice its will share the same ID. 
- */ - id: string; - - /** A list of all input keys the model expects. */ - inputs: string[]; - - /** A list of all output keys the model will result. */ - outputs: string[]; - - /** Loads a ONNX model session from source URL. - * Sessions are loaded once, then will keep warm cross worker's requests - */ - static fromUrl(source: string | URL): Promise; - - /** Loads a ONNX model session from **HuggingFace** repository. - * Sessions are loaded once, then will keep warm cross worker's requests - */ - static fromHuggingFace(repoId: string, opts?: { - /** - * @default 'https://huggingface.co' - */ - hostname?: string | URL; - path?: { - /** - * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true' - */ - template?: string; - /** - * @default 'main' - */ - revision?: string; - /** - * @default 'model_quantized.onnx' - */ - modelFile?: string; - }; - }): Promise; - - /** Run the current session with the given inputs. - * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. - * - * @param inputs The input tensors required by the model. - * @returns The output tensors generated by the model. 
- * - * @example - * ```typescript - * const session = await RawSession.fromUrl("https://example.com/model.onnx"); - * - * // Prepare the input tensors - * const inputs = { - * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), - * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), - * }; - * - * // Run the model - * const outputs = await session.run(inputs); - * - * console.log(outputs.output1); // Output tensor - * ``` - */ - run(inputs: TensorMap): Promise; - } -} diff --git a/examples/text-to-audio/index.ts b/examples/text-to-audio/index.ts index 7d2b2711f..9e70b0d4a 100644 --- a/examples/text-to-audio/index.ts +++ b/examples/text-to-audio/index.ts @@ -1,11 +1,9 @@ -// Setup type definitions for built-in Supabase Runtime APIs -import 'jsr:@supabase/functions-js/edge-runtime.d.ts'; import { PreTrainedTokenizer } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.1'; // import 'phonemize' code from Kokoro.js repo import { phonemize } from './phonemizer.js'; -const { Tensor, RawSession } = Supabase.ai; +const { RawTensor, RawSession } = Supabase.ai; /* NOTE: Reference [original paper](https://arxiv.org/pdf/2306.07691#Model%20Training): > All datasets were resampled to 24 kHz to match LibriTTS, and the texts diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js index 295d9e5c0..92e0f51da 100644 --- a/ext/ai/js/onnxruntime/inference_api.js +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -1,90 +1,90 @@ import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js'; const DEFAULT_HUGGING_FACE_OPTIONS = { - hostname: 'https://huggingface.co', - path: { - template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', - revision: 'main', - modelFile: 'model_quantized.onnx', - }, + hostname: 'https://huggingface.co', + path: { + template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', + revision: 'main', + modelFile: 'model_quantized.onnx', + }, }; /** * An user friendly API for 
onnx backend */ class UserInferenceSession { - inner; + inner; - id; - inputs; - outputs; + id; + inputs; + outputs; - constructor(session) { - this.inner = session; + constructor(session) { + this.inner = session; - this.id = session.sessionId; - this.inputs = session.inputNames; - this.outputs = session.outputNames; - } - - static async fromUrl(modelUrl) { - if (modelUrl instanceof URL) { - modelUrl = modelUrl.toString(); + this.id = session.sessionId; + this.inputs = session.inputNames; + this.outputs = session.outputNames; } - const encoder = new TextEncoder(); - const modelUrlBuffer = encoder.encode(modelUrl); - const session = await InferenceSession.fromBuffer(modelUrlBuffer); + static async fromUrl(modelUrl) { + if (modelUrl instanceof URL) { + modelUrl = modelUrl.toString(); + } + + const encoder = new TextEncoder(); + const modelUrlBuffer = encoder.encode(modelUrl); + const session = await InferenceSession.fromBuffer(modelUrlBuffer); + + return new UserInferenceSession(session); + } - return new UserInferenceSession(session); - } + static async fromHuggingFace(repoId, opts = {}) { + const hostname = opts?.hostname ?? DEFAULT_HUGGING_FACE_OPTIONS.hostname; + const pathOpts = { + ...DEFAULT_HUGGING_FACE_OPTIONS.path, + ...opts?.path, + }; - static async fromHuggingFace(repoId, opts = {}) { - const hostname = opts?.hostname ?? 
DEFAULT_HUGGING_FACE_OPTIONS.hostname; - const pathOpts = { - ...DEFAULT_HUGGING_FACE_OPTIONS.path, - ...opts?.path, - }; + const modelPath = pathOpts.template + .replaceAll('{REPO_ID}', repoId) + .replaceAll('{REVISION}', pathOpts.revision) + .replaceAll('{MODEL_FILE}', pathOpts.modelFile); - const modelPath = pathOpts.template - .replaceAll('{REPO_ID}', repoId) - .replaceAll('{REVISION}', pathOpts.revision) - .replaceAll('{MODEL_FILE}', pathOpts.modelFile); + if (!URL.canParse(modelPath, hostname)) { + throw Error(`[Invalid URL] Couldn't parse the model path: "${modelPath}"`); + } - if (!URL.canParse(modelPath, hostname)) { - throw Error(`[Invalid URL] Couldn't parse the model path: "${modelPath}"`); + return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); } - return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); - } + async run(inputs) { + const outputs = await core.ops.op_sb_ai_ort_run_session(this.id, inputs); - async run(inputs) { - const outputs = await core.ops.op_sb_ai_ort_run_session(this.id, inputs); + // Parse to Tensor + for (const key in outputs) { + if (Object.hasOwn(outputs, key)) { + const { type, data, dims } = outputs[key]; - // Parse to Tensor - for (const key in outputs) { - if (Object.hasOwn(outputs, key)) { - const { type, data, dims } = outputs[key]; + outputs[key] = new UserTensor(type, data.buffer, dims); + } + } - outputs[key] = new UserTensor(type, data.buffer, dims); - } + return outputs; } - - return outputs; - } } class UserTensor extends Tensor { - constructor(type, data, dim) { - super(type, data, dim); - } + constructor(type, data, dim) { + super(type, data, dim); + } - async tryEncodeAudio(sampleRate) { - return await core.ops.op_sb_ai_ort_encode_tensor_audio(this.data, sampleRate); - } + async tryEncodeAudio(sampleRate) { + return await core.ops.op_sb_ai_ort_encode_tensor_audio(this.data, sampleRate); + } } export default { - RawSession: UserInferenceSession, - Tensor: UserTensor, + 
RawSession: UserInferenceSession, + RawTensor: UserTensor, }; diff --git a/types/global.d.ts b/types/global.d.ts index 7810e23a9..89129a3d5 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -12,6 +12,14 @@ declare interface WindowEventMap { "drain": Event; } +type DecoratorType = 'tc39' | 'typescript' | 'typescript_with_metadata'; + +interface JsxImportBaseConfig { + defaultSpecifier?: string | null; + module?: string | null; + baseUrl?: string | null; +} + // TODO(Nyannyacha): These two type defs will be provided later. // deno-lint-ignore no-explicit-any @@ -150,55 +158,203 @@ declare namespace EdgeRuntime { } declare namespace Supabase { - export namespace ai { - interface ModelOptions { - /** - * Pool embeddings by taking their mean. Applies only for `gte-small` model - */ - mean_pool?: boolean; - - /** - * Normalize the embeddings result. Applies only for `gte-small` model - */ - normalize?: boolean; - - /** - * Stream response from model. Applies only for LLMs like `mistral` (default: false) - */ - stream?: boolean; - - /** - * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) - */ - timeout?: number; - - /** - * Mode for the inference API host. (default: 'ollama') - */ - mode?: "ollama" | "openaicompatible"; - signal?: AbortSignal; - } +export namespace ai { + interface ModelOptions { + /** + * Pool embeddings by taking their mean. Applies only for `gte-small` model + */ + mean_pool?: boolean; + + /** + * Normalize the embeddings result. Applies only for `gte-small` model + */ + normalize?: boolean; + + /** + * Stream response from model. Applies only for LLMs like `mistral` (default: false) + */ + stream?: boolean; + + /** + * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) + */ + timeout?: number; + + /** + * Mode for the inference API host. 
(default: 'ollama') + */ + mode?: 'ollama' | 'openaicompatible'; + signal?: AbortSignal; + } + + export type TensorDataTypeMap = { + float32: Float32Array | number[]; + float64: Float64Array | number[]; + string: string[]; + int8: Int8Array | number[]; + uint8: Uint8Array | number[]; + int16: Int16Array | number[]; + uint16: Uint16Array | number[]; + int32: Int32Array | number[]; + uint32: Uint32Array | number[]; + int64: BigInt64Array | number[]; + uint64: BigUint64Array | number[]; + bool: Uint8Array | number[]; + }; + + export type TensorMap = { [key: string]: RawTensor }; + + export class Session { + /** + * Create a new model session using given model + */ + constructor(model: string); + + /** + * Execute the given prompt in model session + */ + run( + prompt: + | string + | Omit< + import('openai').OpenAI.Chat.ChatCompletionCreateParams, + 'model' | 'stream' + >, + modelOptions?: ModelOptions, + ): unknown; + } - export class Session { - /** - * Create a new model session using given model - */ - constructor(model: string); - - /** - * Execute the given prompt in model session - */ - run( - prompt: - | string - | Omit< - import("openai").OpenAI.Chat.ChatCompletionCreateParams, - "model" | "stream" - >, - modelOptions?: ModelOptions, - ): unknown; + /** Provides an user friendly interface for the low level *onnx backend API*. + * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models, where you need mode control of model execution and pre/pos-processing. + * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`. 
+ * + * **Example:** + * ```typescript + * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); + * // const session = await RawSession.fromUrl("https://example.com/model.onnx"); + * + * // Prepare the input tensors + * const inputs = { + * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), + * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), + * }; + * + * // Run the model + * const outputs = await session.run(inputs); + * + * console.log(outputs.output1); // Output tensor + * ``` + */ + export class RawSession { + /** The underline session's ID. + * Session's ID are unique for each loaded model, it means that even if a session is constructed twice its will share the same ID. + */ + id: string; + + /** A list of all input keys the model expects. */ + inputs: string[]; + + /** A list of all output keys the model will result. */ + outputs: string[]; + + /** Loads a ONNX model session from source URL. + * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromUrl(source: string | URL): Promise; + + /** Loads a ONNX model session from **HuggingFace** repository. + * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromHuggingFace(repoId: string, opts?: { + /** + * @default 'https://huggingface.co' + */ + hostname?: string | URL; + path?: { + /** + * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true' + */ + template?: string; + /** + * @default 'main' + */ + revision?: string; + /** + * @default 'model_quantized.onnx' + */ + modelFile?: string; + }; + }): Promise; + + /** Run the current session with the given inputs. + * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. + * + * @param inputs The input tensors required by the model. + * @returns The output tensors generated by the model. 
+ * + * @example + * ```typescript + * const session = await RawSession.fromUrl("https://example.com/model.onnx"); + * + * // Prepare the input tensors + * const inputs = { + * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), + * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), + * }; + * + * // Run the model + * const outputs = await session.run(inputs); + * + * console.log(outputs.output1); // Output tensor + * ``` + */ + run(inputs: TensorMap): Promise; + } + + /** A low level representation of model input/output. + * Supabase's `Tensor` is totally compatible with `@huggingface/transformers.js`'s `Tensor`. It means that you can use its high-level API to apply some common operations like `sum()`, `min()`, `max()`, `normalize()` etc... + * + * **Example: Generating embeddings from scratch** + * ```typescript + * import { Tensor as HFTensor } from "@huggingface/transformers.js"; + * const { Tensor, RawSession } = Supabase.ai; + * + * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); + * + * // Example only, in real 'feature-extraction' tensors are given from the tokenizer step. + * const inputs = { + * input_ids: new Tensor('float32', [...], [n, 2]), + * attention_mask: new Tensor('float32', [...], [n, 2]), + * token_types_ids: new Tensor('float32', [...], [n, 2]) + * }; + * + * const { last_hidden_state } = await session.run(inputs); + * + * // Using `transformers.js` APIs + * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize(); + * + * return hfTensor.tolist(); + * + * ``` + */ + export class RawTensor { + /** Type of the tensor. */ + type: T; + + /** The data stored in the tensor. */ + data: TensorDataTypeMap[T]; + + /** Dimensions of the tensor. */ + dims: number[]; + + /** The total number of elements in the tensor. 
*/ + size: number; + + constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); + + tryEncodeAudio(sampleRate: number): Promise; + } } - } } declare namespace Deno { From eb1bb49222de0a4dde1a56dd666b2047a55fa31b Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Sat, 15 Feb 2025 17:22:22 +0000 Subject: [PATCH 6/8] feat: support for authorization header on model fetch --- ext/ai/js/onnxruntime/inference_api.js | 125 +++++++++++++------------ ext/ai/js/onnxruntime/onnx.js | 13 ++- ext/ai/lib.rs | 3 +- ext/ai/onnxruntime/mod.rs | 30 +++--- ext/ai/onnxruntime/model.rs | 8 +- ext/ai/onnxruntime/session.rs | 18 ++-- ext/ai/utils.rs | 14 +++ 7 files changed, 122 insertions(+), 89 deletions(-) diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js index 92e0f51da..48ff6b57e 100644 --- a/ext/ai/js/onnxruntime/inference_api.js +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -1,90 +1,93 @@ +const core = globalThis.Deno.core; import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js'; const DEFAULT_HUGGING_FACE_OPTIONS = { - hostname: 'https://huggingface.co', - path: { - template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', - revision: 'main', - modelFile: 'model_quantized.onnx', - }, + hostname: 'https://huggingface.co', + path: { + template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', + revision: 'main', + modelFile: 'model_quantized.onnx', + }, }; /** * An user friendly API for onnx backend */ class UserInferenceSession { - inner; + inner; - id; - inputs; - outputs; + id; + inputs; + outputs; - constructor(session) { - this.inner = session; + constructor(session) { + this.inner = session; - this.id = session.sessionId; - this.inputs = session.inputNames; - this.outputs = session.outputNames; - } - - static async fromUrl(modelUrl) { - if (modelUrl instanceof URL) { - modelUrl = modelUrl.toString(); - } - - const encoder = new TextEncoder(); - const modelUrlBuffer = 
encoder.encode(modelUrl); - const session = await InferenceSession.fromBuffer(modelUrlBuffer); + this.id = session.sessionId; + this.inputs = session.inputNames; + this.outputs = session.outputNames; + } - return new UserInferenceSession(session); + static async fromUrl(modelUrl) { + if (modelUrl instanceof URL) { + modelUrl = modelUrl.toString(); } - static async fromHuggingFace(repoId, opts = {}) { - const hostname = opts?.hostname ?? DEFAULT_HUGGING_FACE_OPTIONS.hostname; - const pathOpts = { - ...DEFAULT_HUGGING_FACE_OPTIONS.path, - ...opts?.path, - }; - - const modelPath = pathOpts.template - .replaceAll('{REPO_ID}', repoId) - .replaceAll('{REVISION}', pathOpts.revision) - .replaceAll('{MODEL_FILE}', pathOpts.modelFile); - - if (!URL.canParse(modelPath, hostname)) { - throw Error(`[Invalid URL] Couldn't parse the model path: "${modelPath}"`); - } - - return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); + const encoder = new TextEncoder(); + const modelUrlBuffer = encoder.encode(modelUrl); + const session = await InferenceSession.fromBuffer(modelUrlBuffer); + + return new UserInferenceSession(session); + } + + static async fromHuggingFace(repoId, opts = {}) { + const hostname = opts?.hostname ?? 
DEFAULT_HUGGING_FACE_OPTIONS.hostname; + const pathOpts = { + ...DEFAULT_HUGGING_FACE_OPTIONS.path, + ...opts?.path, + }; + + const modelPath = pathOpts.template + .replaceAll('{REPO_ID}', repoId) + .replaceAll('{REVISION}', pathOpts.revision) + .replaceAll('{MODEL_FILE}', pathOpts.modelFile); + + if (!URL.canParse(modelPath, hostname)) { + throw Error( + `[Invalid URL] Couldn't parse the model path: "${modelPath}"`, + ); } - async run(inputs) { - const outputs = await core.ops.op_sb_ai_ort_run_session(this.id, inputs); + return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); + } - // Parse to Tensor - for (const key in outputs) { - if (Object.hasOwn(outputs, key)) { - const { type, data, dims } = outputs[key]; + async run(inputs) { + const outputs = await core.ops.op_ai_ort_run_session(this.id, inputs); - outputs[key] = new UserTensor(type, data.buffer, dims); - } - } + // Parse to Tensor + for (const key in outputs) { + if (Object.hasOwn(outputs, key)) { + const { type, data, dims } = outputs[key]; - return outputs; + outputs[key] = new UserTensor(type, data.buffer, dims); + } } + + return outputs; + } } class UserTensor extends Tensor { - constructor(type, data, dim) { - super(type, data, dim); - } + constructor(type, data, dim) { + super(type, data, dim); + } - async tryEncodeAudio(sampleRate) { - return await core.ops.op_sb_ai_ort_encode_tensor_audio(this.data, sampleRate); - } + async tryEncodeAudio(sampleRate) { + return await core.ops.op_ai_ort_encode_tensor_audio(this.data, sampleRate); + } } export default { - RawSession: UserInferenceSession, - RawTensor: UserTensor, + RawSession: UserInferenceSession, + RawTensor: UserTensor, }; diff --git a/ext/ai/js/onnxruntime/onnx.js b/ext/ai/js/onnxruntime/onnx.js index 643de198a..16ffeeb56 100644 --- a/ext/ai/js/onnxruntime/onnx.js +++ b/ext/ai/js/onnxruntime/onnx.js @@ -18,7 +18,7 @@ const DataTypeMap = Object.freeze({ class TensorProxy { get(target, property) { switch (property) { - case 
"data": + case 'data': return target.data?.c ?? target.data; default: @@ -86,6 +86,15 @@ export class InferenceSession { return new InferenceSession(id, inputs, outputs); } + static async fromRequest(modelUrl, authorization) { + const [id, inputs, outputs] = await core.ops.op_ai_ort_init_session( + modelUrl, + authorization, + ); + + return new InferenceSession(id, inputs, outputs); + } + async run(inputs) { const sessionInputs = {}; @@ -125,4 +134,4 @@ const onnxruntime = { }, }; -globalThis[Symbol.for("onnxruntime")] = onnxruntime; +globalThis[Symbol.for('onnxruntime')] = onnxruntime; diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index 715d9b963..c9af09d77 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -118,7 +118,7 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { let handle = handle.clone(); move || { handle.block_on(async move { - load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap()) + load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap(), None) .await }) } @@ -143,6 +143,7 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { "tokenizer", Url::parse(consts::GTE_SMALL_TOKENIZER_URL).unwrap(), None, + None ) .map_err(AnyError::from) .and_then(|it| { diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index 0f6842a4a..f28026dc6 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -37,26 +37,28 @@ use tokio_util::bytes::BufMut; #[op2(async)] #[to_v8] pub async fn op_ai_ort_init_session( - state: Rc>, - #[buffer] model_bytes: JsBuffer, + state: Rc>, + #[buffer] model_bytes: JsBuffer, + // Maybe improve the code style to enum payload or something else + #[string] req_authorization: Option, ) -> Result { let model_bytes = model_bytes.into_parts().to_boxed_slice(); let model_bytes_or_url = str::from_utf8(&model_bytes) .map_err(AnyError::from) .and_then(|utf8_str| Url::parse(utf8_str).map_err(AnyError::from)); - let model = match model_bytes_or_url { - Ok(model_url) => { - trace!(kind = 
"url", url = %model_url); - Model::from_url(model_url).await? - } - Err(_) => { - trace!(kind = "bytes", len = model_bytes.len()); - Model::from_bytes(&model_bytes).await? - } - }; - - let mut state = state.borrow_mut(); + let model = match model_bytes_or_url { + Ok(model_url) => { + trace!(kind = "url", url = %model_url); + Model::from_url(model_url, req_authorization).await? + } + Err(_) => { + trace!(kind = "bytes", len = model_bytes.len()); + Model::from_bytes(&model_bytes).await? + } + }; + + let mut state = state.borrow_mut(); let mut sessions = { state.try_take::>>().unwrap_or_default() }; diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index f3a17e6e4..f1575d06e 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -71,9 +71,11 @@ impl Model { .map(Self::new) } - pub async fn from_url(model_url: Url) -> Result { - load_session_from_url(model_url).await.map(Self::new) - } + pub async fn from_url(model_url: Url, authorization: Option) -> Result { + load_session_from_url(model_url, authorization) + .await + .map(Self::new) + } pub async fn from_bytes(model_bytes: &[u8]) -> Result { load_session_from_bytes(model_bytes).await.map(Self::new) diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index 6205e8550..ca670792b 100644 --- a/ext/ai/onnxruntime/session.rs +++ b/ext/ai/onnxruntime/session.rs @@ -154,9 +154,10 @@ pub(crate) async fn load_session_from_bytes( #[instrument(level = "debug", fields(%model_url), err)] pub(crate) async fn load_session_from_url( - model_url: Url, + model_url: Url, + authorization: Option, ) -> Result { - let session_id = fxhash::hash(model_url.as_str()).to_string(); + let session_id = fxhash::hash(model_url.as_str()).to_string(); let mut sessions = SESSIONS.lock().await; @@ -165,12 +166,13 @@ pub(crate) async fn load_session_from_url( return Ok((session_id, session.clone()).into()); } - let model_file_path = crate::utils::fetch_and_cache_from_url( - "model", - 
model_url, - Some(session_id.to_string()), - ) - .await?; + let model_file_path = crate::utils::fetch_and_cache_from_url( + "model", + model_url, + Some(session_id.to_string()), + authorization, + ) + .await?; let model_bytes = tokio::fs::read(model_file_path).await?; let session = create_session(model_bytes.as_slice())?; diff --git a/ext/ai/utils.rs b/ext/ai/utils.rs index f9e188ba3..f57eb0f16 100644 --- a/ext/ai/utils.rs +++ b/ext/ai/utils.rs @@ -20,6 +20,7 @@ pub async fn fetch_and_cache_from_url( kind: &'static str, url: Url, cache_id: Option, + authorization: Option, ) -> Result { let cache_id = cache_id.unwrap_or(fxhash::hash(url.as_str()).to_string()); let download_dir = std::env::var("EXT_AI_CACHE_DIR") @@ -91,13 +92,26 @@ pub async fn fetch_and_cache_from_url( use reqwest::*; + let mut headers = header::HeaderMap::new(); + + if let Some(authorization) = authorization { + let mut authorization = + header::HeaderValue::from_str(authorization.as_str())?; + authorization.set_sensitive(true); + + headers.insert(header::AUTHORIZATION, authorization); + }; + let resp = Client::builder() .http1_only() + .default_headers(headers) .build() .context("failed to create http client")? .get(url.clone()) .send() .await + .context("failed to download")? + .error_for_status() .context("failed to download")?; let file = tokio::fs::File::create(&filepath) From 71077118bcffcd6b42da4ba6b7d065cc1c5504e7 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 25 Feb 2025 17:51:23 +0000 Subject: [PATCH 7/8] stamp: add model loading fromStorage - Adding `fromStorage` method to InferenceAPI, its allows model loadingfrom Supabase Storage with public/private bucket support. 
--- ext/ai/js/onnxruntime/inference_api.js | 34 ++++++++++++++++++++++++-- types/global.d.ts | 13 ++++++++++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js index 48ff6b57e..13d2cad25 100644 --- a/ext/ai/js/onnxruntime/inference_api.js +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -10,6 +10,13 @@ const DEFAULT_HUGGING_FACE_OPTIONS = { }, }; +const DEFAULT_STORAGE_OPTIONS = () => ({ + hostname: Deno.env.get('SUPABASE_URL'), + mode: { + authorization: Deno.env.get('SUPABASE_SERVICE_ROLE_KEY'), + }, +}); + /** * An user friendly API for onnx backend */ @@ -28,14 +35,17 @@ class UserInferenceSession { this.outputs = session.outputNames; } - static async fromUrl(modelUrl) { + static async fromUrl(modelUrl, authorization) { if (modelUrl instanceof URL) { modelUrl = modelUrl.toString(); } const encoder = new TextEncoder(); const modelUrlBuffer = encoder.encode(modelUrl); - const session = await InferenceSession.fromBuffer(modelUrlBuffer); + const session = await InferenceSession.fromRequest( + modelUrlBuffer, + authorization, + ); return new UserInferenceSession(session); } @@ -61,6 +71,26 @@ class UserInferenceSession { return await UserInferenceSession.fromUrl(new URL(modelPath, hostname)); } + static async fromStorage(modelPath, opts = {}) { + const defaultOpts = DEFAULT_STORAGE_OPTIONS(); + const hostname = opts?.hostname ?? defaultOpts.hostname; + const mode = opts?.mode ?? defaultOpts.mode; + + const assetPath = mode === 'public' ? 
`public/${modelPath}` : `authenticated/${modelPath}`; + + const storageUrl = `/storage/v1/object/${assetPath}`; + + if (!URL.canParse(storageUrl, hostname)) { + throw Error( + `[Invalid URL] Couldn't parse the model path: "${storageUrl}"`, + ); + } + + return await UserInferenceSession.fromUrl( + new URL(storageUrl, hostname), + mode?.authorization, + ); + } async run(inputs) { const outputs = await core.ops.op_ai_ort_run_session(this.id, inputs); diff --git a/types/global.d.ts b/types/global.d.ts index 89129a3d5..47c33612f 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -286,6 +286,19 @@ export namespace ai { }; }): Promise; + /** Loads a ONNX model session from **Storage**. + * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromStorage(repoId: string, opts?: { + /** + * @default 'env SUPABASE_URL' + */ + hostname?: string | URL; + mode?: 'public' | { + authorization: string; + }; + }): Promise; + /** Run the current session with the given inputs. * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. 
* From bbcda7a313caabce8aa00d62f08d37ed2dfda8f0 Mon Sep 17 00:00:00 2001 From: kallebysantos Date: Tue, 25 Feb 2025 20:48:28 +0000 Subject: [PATCH 8/8] stamp: clippy and format --- examples/ort-raw-session/index.ts | 76 ++--- examples/text-to-audio/index.ts | 146 ++++----- examples/text-to-audio/phonemizer.js | 122 +++---- ext/ai/js/ai.js | 10 +- ext/ai/js/onnxruntime/inference_api.js | 24 +- ext/ai/js/onnxruntime/onnx.js | 4 +- ext/ai/lib.rs | 9 +- ext/ai/onnxruntime/mod.rs | 88 +++--- ext/ai/onnxruntime/model.rs | 13 +- ext/ai/onnxruntime/session.rs | 20 +- types/global.d.ts | 420 +++++++++++++------------ 11 files changed, 480 insertions(+), 452 deletions(-) diff --git a/examples/ort-raw-session/index.ts b/examples/ort-raw-session/index.ts index f87b20c87..a0b06c414 100644 --- a/examples/ort-raw-session/index.ts +++ b/examples/ort-raw-session/index.ts @@ -71,43 +71,49 @@ Deno.serve(async (_req: Request) => { const { RawTensor, RawSession } = Supabase.ai; -const session = await RawSession.fromHuggingFace('kallebysantos/vehicle-emission', { +const session = await RawSession.fromHuggingFace( + "kallebysantos/vehicle-emission", + { path: { - modelFile: 'model.onnx', + modelFile: "model.onnx", }, -}); + }, +); Deno.serve(async (_req: Request) => { - // sample data could be a JSON request - const carsBatchInput = [{ - 'Model_Year': 2021, - 'Engine_Size': 2.9, - 'Cylinders': 6, - 'Fuel_Consumption_in_City': 13.9, - 'Fuel_Consumption_in_City_Hwy': 10.3, - 'Fuel_Consumption_comb': 12.3, - 'Smog_Level': 3, - }, { - 'Model_Year': 2023, - 'Engine_Size': 2.4, - 'Cylinders': 4, - 'Fuel_Consumption_in_City': 9.9, - 'Fuel_Consumption_in_City_Hwy': 7.0, - 'Fuel_Consumption_comb': 8.6, - 'Smog_Level': 3, - }]; - - // Parsing objects to tensor input - const inputTensors: Record> = {}; - session.inputs.forEach((inputKey) => { - const values = carsBatchInput.map((item) => item[inputKey]); - - inputTensors[inputKey] = new RawTensor('float32', values, [values.length, 1]); - }); - - 
const { emissions } = await session.run(inputTensors); - console.log(emissions); - // [ 289.01, 199.53] - - return Response.json({ result: emissions }); + // sample data could be a JSON request + const carsBatchInput = [{ + "Model_Year": 2021, + "Engine_Size": 2.9, + "Cylinders": 6, + "Fuel_Consumption_in_City": 13.9, + "Fuel_Consumption_in_City_Hwy": 10.3, + "Fuel_Consumption_comb": 12.3, + "Smog_Level": 3, + }, { + "Model_Year": 2023, + "Engine_Size": 2.4, + "Cylinders": 4, + "Fuel_Consumption_in_City": 9.9, + "Fuel_Consumption_in_City_Hwy": 7.0, + "Fuel_Consumption_comb": 8.6, + "Smog_Level": 3, + }]; + + // Parsing objects to tensor input + const inputTensors: Record> = {}; + session.inputs.forEach((inputKey) => { + const values = carsBatchInput.map((item) => item[inputKey]); + + inputTensors[inputKey] = new RawTensor("float32", values, [ + values.length, + 1, + ]); + }); + + const { emissions } = await session.run(inputTensors); + console.log(emissions); + // [ 289.01, 199.53] + + return Response.json({ result: emissions }); }); diff --git a/examples/text-to-audio/index.ts b/examples/text-to-audio/index.ts index 9e70b0d4a..5014ac170 100644 --- a/examples/text-to-audio/index.ts +++ b/examples/text-to-audio/index.ts @@ -1,7 +1,7 @@ -import { PreTrainedTokenizer } from 'https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.1'; +import { PreTrainedTokenizer } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.3.1"; // import 'phonemize' code from Kokoro.js repo -import { phonemize } from './phonemizer.js'; +import { phonemize } from "./phonemizer.js"; const { RawTensor, RawSession } = Supabase.ai; @@ -15,95 +15,97 @@ const SAMPLE_RATE = 24000; // 24 kHz > The size of s and c is 256 × 1 */ const STYLE_DIM = 256; -const MODEL_ID = 'onnx-community/Kokoro-82M-ONNX'; +const MODEL_ID = "onnx-community/Kokoro-82M-ONNX"; // https://huggingface.co/onnx-community/Kokoro-82M-ONNX#samples const ALLOWED_VOICES = [ - 'af_bella', - 'af_nicole', - 'af_sarah', - 
'af_sky', - 'am_adam', - 'am_michael', - 'bf_emma', - 'bf_isabella', - 'bm_george', - 'bm_lewis', + "af_bella", + "af_nicole", + "af_sarah", + "af_sky", + "am_adam", + "am_michael", + "bf_emma", + "bf_isabella", + "bm_george", + "bm_lewis", ]; const session = await RawSession.fromHuggingFace(MODEL_ID); Deno.serve(async (req) => { - const params = new URL(req.url).searchParams; - const text = params.get('text') ?? 'Hello from Supabase!'; - const voice = params.get('voice') ?? 'af_bella'; - - if (!ALLOWED_VOICES.includes(voice)) { - return Response.json({ - error: `invalid voice '${voice}'`, - must_be_one_of: ALLOWED_VOICES, - }, { status: 400 }); - } - - const tokenizer = await loadTokenizer(); - const language = voice.at(0); // 'a'merican | 'b'ritish - const phonemes = await phonemize(text, language); - const { input_ids } = tokenizer(phonemes, { - truncation: true, - }); - - // Select voice style based on number of input tokens - const num_tokens = Math.max( - input_ids.dims.at(-1) - 2, // Without padding; - 0, - ); - - const voiceStyle = await loadVoiceStyle(voice, num_tokens); - - const { waveform } = await session.run({ - input_ids, - style: voiceStyle, - speed: new Tensor('float32', [1], [1]), - }); - - // Do `wave` encoding from rust backend - const audio = await waveform.tryEncodeAudio(SAMPLE_RATE); - - return new Response(audio, { - headers: { - 'Content-Type': 'audio/wav', - }, - }); + const params = new URL(req.url).searchParams; + const text = params.get("text") ?? "Hello from Supabase!"; + const voice = params.get("voice") ?? 
"af_bella"; + + if (!ALLOWED_VOICES.includes(voice)) { + return Response.json({ + error: `invalid voice '${voice}'`, + must_be_one_of: ALLOWED_VOICES, + }, { status: 400 }); + } + + const tokenizer = await loadTokenizer(); + const language = voice.at(0); // 'a'merican | 'b'ritish + const phonemes = await phonemize(text, language); + const { input_ids } = tokenizer(phonemes, { + truncation: true, + }); + + // Select voice style based on number of input tokens + const num_tokens = Math.max( + input_ids.dims.at(-1) - 2, // Without padding; + 0, + ); + + const voiceStyle = await loadVoiceStyle(voice, num_tokens); + + const { waveform } = await session.run({ + input_ids, + style: voiceStyle, + speed: new Tensor("float32", [1], [1]), + }); + + // Do `wave` encoding from rust backend + const audio = await waveform.tryEncodeAudio(SAMPLE_RATE); + + return new Response(audio, { + headers: { + "Content-Type": "audio/wav", + }, + }); }); async function loadVoiceStyle(voice: string, num_tokens: number) { - const voice_url = - `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`; + const voice_url = + `https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/voices/${voice}.bin?download=true`; - console.log('loading voice:', voice_url); + console.log("loading voice:", voice_url); - const voiceBuffer = await fetch(voice_url).then(async (res) => await res.arrayBuffer()); + const voiceBuffer = await fetch(voice_url).then(async (res) => + await res.arrayBuffer() + ); - const offset = num_tokens * STYLE_DIM; - const voiceData = new Float32Array(voiceBuffer).slice( - offset, - offset + STYLE_DIM, - ); + const offset = num_tokens * STYLE_DIM; + const voiceData = new Float32Array(voiceBuffer).slice( + offset, + offset + STYLE_DIM, + ); - return new Tensor('float32', voiceData, [1, STYLE_DIM]); + return new Tensor("float32", voiceData, [1, STYLE_DIM]); } async function loadTokenizer() { - // BUG: invalid 'h' not JSON. 
That's why we need to manually fetch the assets - // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID); + // BUG: invalid 'h' not JSON. That's why we need to manually fetch the assets + // const tokenizer = await AutoTokenizer.from_pretrained(MODEL_ID); - const tokenizerData = await fetch( - 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true', - ).then(async (res) => await res.json()); + const tokenizerData = await fetch( + "https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer.json?download=true", + ).then(async (res) => await res.json()); - const tokenizerConfig = await fetch( - 'https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true', - ).then(async (res) => await res.json()); + const tokenizerConfig = await fetch( + "https://huggingface.co/onnx-community/Kokoro-82M-ONNX/resolve/main/tokenizer_config.json?download=true", + ).then(async (res) => await res.json()); - return new PreTrainedTokenizer(tokenizerData, tokenizerConfig); + return new PreTrainedTokenizer(tokenizerData, tokenizerConfig); } diff --git a/examples/text-to-audio/phonemizer.js b/examples/text-to-audio/phonemizer.js index e694e801f..3a2b8028d 100644 --- a/examples/text-to-audio/phonemizer.js +++ b/examples/text-to-audio/phonemizer.js @@ -1,6 +1,6 @@ // Source code from https://github.com/hexgrad/kokoro/blob/a09db51873211d76a3e49b058b18873a4e002b81/kokoro.js/src/phonemize.js // BUG: Don't now why, but if import it from cdnjs will cause runtime stack overflow -import { phonemize as espeakng } from 'npm:phonemizer@1.2.1'; +import { phonemize as espeakng } from "npm:phonemizer@1.2.1"; /** * Helper function to split a string on a regex, but keep the delimiters. 
@@ -35,10 +35,10 @@ function split(text, regex) { * @returns {string} The phonetic equivalent */ function split_num(match) { - if (match.includes('.')) { + if (match.includes(".")) { return match; - } else if (match.includes(':')) { - let [h, m] = match.split(':').map(Number); + } else if (match.includes(":")) { + let [h, m] = match.split(":").map(Number); if (m === 0) { return `${h} o'clock`; } else if (m < 10) { @@ -52,7 +52,7 @@ function split_num(match) { } let left = match.slice(0, 2); let right = parseInt(match.slice(2, 4), 10); - let suffix = match.endsWith('s') ? 's' : ''; + let suffix = match.endsWith("s") ? "s" : ""; if (year % 1000 >= 100 && year % 1000 <= 999) { if (right === 0) { return `${left} hundred${suffix}`; @@ -69,17 +69,21 @@ function split_num(match) { * @returns {string} The formatted currency */ function flip_money(match) { - const bill = match[0] === '$' ? 'dollar' : 'pound'; + const bill = match[0] === "$" ? "dollar" : "pound"; if (isNaN(Number(match.slice(1)))) { return `${match.slice(1)} ${bill}s`; - } else if (!match.includes('.')) { - let suffix = match.slice(1) === '1' ? '' : 's'; + } else if (!match.includes(".")) { + let suffix = match.slice(1) === "1" ? "" : "s"; return `${match.slice(1)} ${bill}${suffix}`; } - const [b, c] = match.slice(1).split('.'); - const d = parseInt(c.padEnd(2, '0'), 10); - let coins = match[0] === '$' ? (d === 1 ? 'cent' : 'cents') : d === 1 ? 'penny' : 'pence'; - return `${b} ${bill}${b === '1' ? '' : 's'} and ${d} ${coins}`; + const [b, c] = match.slice(1).split("."); + const d = parseInt(c.padEnd(2, "0"), 10); + let coins = match[0] === "$" + ? (d === 1 ? "cent" : "cents") + : d === 1 + ? "penny" + : "pence"; + return `${b} ${bill}${b === "1" ? 
"" : "s"} and ${d} ${coins}`; } /** @@ -88,8 +92,8 @@ function flip_money(match) { * @returns {string} The formatted number */ function point_num(match) { - let [a, b] = match.split('.'); - return `${a} point ${b.split('').join(' ')}`; + let [a, b] = match.split("."); + return `${a} point ${b.split("").join(" ")}`; } /** @@ -102,47 +106,50 @@ function normalize_text(text) { text // 1. Handle quotes and brackets .replace(/[‘’]/g, "'") - .replace(/«/g, '“') - .replace(/»/g, '”') + .replace(/«/g, "“") + .replace(/»/g, "”") .replace(/[“”]/g, '"') - .replace(/\(/g, '«') - .replace(/\)/g, '»') + .replace(/\(/g, "«") + .replace(/\)/g, "»") // 2. Replace uncommon punctuation marks - .replace(/、/g, ', ') - .replace(/。/g, '. ') - .replace(/!/g, '! ') - .replace(/,/g, ', ') - .replace(/:/g, ': ') - .replace(/;/g, '; ') - .replace(/?/g, '? ') + .replace(/、/g, ", ") + .replace(/。/g, ". ") + .replace(/!/g, "! ") + .replace(/,/g, ", ") + .replace(/:/g, ": ") + .replace(/;/g, "; ") + .replace(/?/g, "? ") // 3. Whitespace normalization - .replace(/[^\S \n]/g, ' ') - .replace(/ +/, ' ') - .replace(/(?<=\n) +(?=\n)/g, '') + .replace(/[^\S \n]/g, " ") + .replace(/ +/, " ") + .replace(/(?<=\n) +(?=\n)/g, "") // 4. Abbreviations - .replace(/\bD[Rr]\.(?= [A-Z])/g, 'Doctor') - .replace(/\b(?:Mr\.|MR\.(?= [A-Z]))/g, 'Mister') - .replace(/\b(?:Ms\.|MS\.(?= [A-Z]))/g, 'Miss') - .replace(/\b(?:Mrs\.|MRS\.(?= [A-Z]))/g, 'Mrs') - .replace(/\betc\.(?! [A-Z])/gi, 'etc') + .replace(/\bD[Rr]\.(?= [A-Z])/g, "Doctor") + .replace(/\b(?:Mr\.|MR\.(?= [A-Z]))/g, "Mister") + .replace(/\b(?:Ms\.|MS\.(?= [A-Z]))/g, "Miss") + .replace(/\b(?:Mrs\.|MRS\.(?= [A-Z]))/g, "Mrs") + .replace(/\betc\.(?! [A-Z])/gi, "etc") // 5. Normalize casual words .replace(/\b(y)eah?\b/gi, "$1e'a") // 5. Handle numbers and currencies - .replace(/\d*\.\d+|\b\d{4}s?\b|(? 
m.replace(/\./g, '-')) - .replace(/(?<=[A-Z])\.(?=[A-Z])/gi, '-') + .replace(/(?:[A-Za-z]\.){2,} [a-z]/g, (m) => m.replace(/\./g, "-")) + .replace(/(?<=[A-Z])\.(?=[A-Z])/gi, "-") // 8. Strip leading and trailing whitespace .trim() ); @@ -155,13 +162,16 @@ function normalize_text(text) { * @returns {string} The escaped string. */ function escapeRegExp(string) { - return string.replace(/[.*+?^${}()|[\]\\]/g, '\\$&'); // $& means the whole matched string + return string.replace(/[.*+?^${}()|[\]\\]/g, "\\$&"); // $& means the whole matched string } const PUNCTUATION = ';:,.!?¡¿—…"«»“”(){}[]'; -const PUNCTUATION_PATTERN = new RegExp(`(\\s*[${escapeRegExp(PUNCTUATION)}]+\\s*)+`, 'g'); +const PUNCTUATION_PATTERN = new RegExp( + `(\\s*[${escapeRegExp(PUNCTUATION)}]+\\s*)+`, + "g", +); -export async function phonemize(text, language = 'a', norm = true) { +export async function phonemize(text, language = "a", norm = true) { // 1. Normalize text if (norm) { text = normalize_text(text); @@ -171,28 +181,28 @@ export async function phonemize(text, language = 'a', norm = true) { const sections = split(text, PUNCTUATION_PATTERN); // 3. Convert each section to phonemes - const lang = language === 'a' ? 'en-us' : 'en'; + const lang = language === "a" ? "en-us" : "en"; const ps = (await Promise.all( sections.map(async ( { match, text }, - ) => (match ? text : (await espeakng(text, lang)).join(' '))), - )).join(''); + ) => (match ? text : (await espeakng(text, lang)).join(" "))), + )).join(""); // 4. 
Post-process phonemes let processed = ps // https://en.wiktionary.org/wiki/kokoro#English - .replace(/kəkˈoːɹoʊ/g, 'kˈoʊkəɹoʊ') - .replace(/kəkˈɔːɹəʊ/g, 'kˈəʊkəɹəʊ') - .replace(/ʲ/g, 'j') - .replace(/r/g, 'ɹ') - .replace(/x/g, 'k') - .replace(/ɬ/g, 'l') - .replace(/(?<=[a-zɹː])(?=hˈʌndɹɪd)/g, ' ') - .replace(/ z(?=[;:,.!?¡¿—…"«»“” ]|$)/g, 'z'); + .replace(/kəkˈoːɹoʊ/g, "kˈoʊkəɹoʊ") + .replace(/kəkˈɔːɹəʊ/g, "kˈəʊkəɹəʊ") + .replace(/ʲ/g, "j") + .replace(/r/g, "ɹ") + .replace(/x/g, "k") + .replace(/ɬ/g, "l") + .replace(/(?<=[a-zɹː])(?=hˈʌndɹɪd)/g, " ") + .replace(/ z(?=[;:,.!?¡¿—…"«»“” ]|$)/g, "z"); // 5. Additional post-processing for American English - if (language === 'a') { - processed = processed.replace(/(?<=nˈaɪn)ti(?!ː)/g, 'di'); + if (language === "a") { + processed = processed.replace(/(?<=nˈaɪn)ti(?!ː)/g, "di"); } return processed.trim(); } diff --git a/ext/ai/js/ai.js b/ext/ai/js/ai.js index c6d53889c..330095947 100644 --- a/ext/ai/js/ai.js +++ b/ext/ai/js/ai.js @@ -1,6 +1,6 @@ -import 'ext:ai/onnxruntime/onnx.js'; -import InferenceAPI from 'ext:ai/onnxruntime/inference_api.js'; -import EventSourceStream from 'ext:ai/util/event_source_stream.mjs'; +import "ext:ai/onnxruntime/onnx.js"; +import InferenceAPI from "ext:ai/onnxruntime/inference_api.js"; +import EventSourceStream from "ext:ai/util/event_source_stream.mjs"; const core = globalThis.Deno.core; @@ -258,8 +258,8 @@ const MAIN_WORKER_API = { }; const USER_WORKER_API = { - Session, - ...InferenceAPI + Session, + ...InferenceAPI, }; export { MAIN_WORKER_API, USER_WORKER_API }; diff --git a/ext/ai/js/onnxruntime/inference_api.js b/ext/ai/js/onnxruntime/inference_api.js index 13d2cad25..f9b85b003 100644 --- a/ext/ai/js/onnxruntime/inference_api.js +++ b/ext/ai/js/onnxruntime/inference_api.js @@ -1,19 +1,19 @@ const core = globalThis.Deno.core; -import { InferenceSession, Tensor } from 'ext:ai/onnxruntime/onnx.js'; +import { InferenceSession, Tensor } from "ext:ai/onnxruntime/onnx.js"; const 
DEFAULT_HUGGING_FACE_OPTIONS = { - hostname: 'https://huggingface.co', + hostname: "https://huggingface.co", path: { - template: '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true', - revision: 'main', - modelFile: 'model_quantized.onnx', + template: "{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true", + revision: "main", + modelFile: "model_quantized.onnx", }, }; const DEFAULT_STORAGE_OPTIONS = () => ({ - hostname: Deno.env.get('SUPABASE_URL'), + hostname: Deno.env.get("SUPABASE_URL"), mode: { - authorization: Deno.env.get('SUPABASE_SERVICE_ROLE_KEY'), + authorization: Deno.env.get("SUPABASE_SERVICE_ROLE_KEY"), }, }); @@ -58,9 +58,9 @@ class UserInferenceSession { }; const modelPath = pathOpts.template - .replaceAll('{REPO_ID}', repoId) - .replaceAll('{REVISION}', pathOpts.revision) - .replaceAll('{MODEL_FILE}', pathOpts.modelFile); + .replaceAll("{REPO_ID}", repoId) + .replaceAll("{REVISION}", pathOpts.revision) + .replaceAll("{MODEL_FILE}", pathOpts.modelFile); if (!URL.canParse(modelPath, hostname)) { throw Error( @@ -76,7 +76,9 @@ class UserInferenceSession { const hostname = opts?.hostname ?? defaultOpts.hostname; const mode = opts?.mode ?? defaultOpts.mode; - const assetPath = mode === 'public' ? `public/${modelPath}` : `authenticated/${modelPath}`; + const assetPath = mode === "public" + ? `public/${modelPath}` + : `authenticated/${modelPath}`; const storageUrl = `/storage/v1/object/${assetPath}`; diff --git a/ext/ai/js/onnxruntime/onnx.js b/ext/ai/js/onnxruntime/onnx.js index 16ffeeb56..0f1138b1d 100644 --- a/ext/ai/js/onnxruntime/onnx.js +++ b/ext/ai/js/onnxruntime/onnx.js @@ -18,7 +18,7 @@ const DataTypeMap = Object.freeze({ class TensorProxy { get(target, property) { switch (property) { - case 'data': + case "data": return target.data?.c ?? 
target.data; default: @@ -134,4 +134,4 @@ const onnxruntime = { }, }; -globalThis[Symbol.for('onnxruntime')] = onnxruntime; +globalThis[Symbol.for("onnxruntime")] = onnxruntime; diff --git a/ext/ai/lib.rs b/ext/ai/lib.rs index c9af09d77..266e92058 100644 --- a/ext/ai/lib.rs +++ b/ext/ai/lib.rs @@ -118,8 +118,11 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { let handle = handle.clone(); move || { handle.block_on(async move { - load_session_from_url(Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap(), None) - .await + load_session_from_url( + Url::parse(consts::GTE_SMALL_MODEL_URL).unwrap(), + None, + ) + .await }) } }) @@ -143,7 +146,7 @@ async fn init_gte(state: Rc>) -> Result<(), Error> { "tokenizer", Url::parse(consts::GTE_SMALL_TOKENIZER_URL).unwrap(), None, - None + None, ) .map_err(AnyError::from) .and_then(|it| { diff --git a/ext/ai/onnxruntime/mod.rs b/ext/ai/onnxruntime/mod.rs index f28026dc6..c37de4f92 100644 --- a/ext/ai/onnxruntime/mod.rs +++ b/ext/ai/onnxruntime/mod.rs @@ -19,8 +19,8 @@ use deno_core::error::AnyError; use deno_core::op2; use deno_core::JsBuffer; use deno_core::JsRuntime; -use deno_core::ToJsBuffer; use deno_core::OpState; +use deno_core::ToJsBuffer; use deno_core::V8CrossThreadTaskSpawner; use model::Model; @@ -30,35 +30,35 @@ use reqwest::Url; use tensor::JsTensor; use tensor::ToJsTensor; use tokio::sync::oneshot; +use tokio_util::bytes::BufMut; use tracing::debug; use tracing::trace; -use tokio_util::bytes::BufMut; #[op2(async)] #[to_v8] pub async fn op_ai_ort_init_session( - state: Rc>, - #[buffer] model_bytes: JsBuffer, - // Maybe improve the code style to enum payload or something else - #[string] req_authorization: Option, + state: Rc>, + #[buffer] model_bytes: JsBuffer, + // Maybe improve the code style to enum payload or something else + #[string] req_authorization: Option, ) -> Result { let model_bytes = model_bytes.into_parts().to_boxed_slice(); let model_bytes_or_url = str::from_utf8(&model_bytes) 
.map_err(AnyError::from) .and_then(|utf8_str| Url::parse(utf8_str).map_err(AnyError::from)); - let model = match model_bytes_or_url { - Ok(model_url) => { - trace!(kind = "url", url = %model_url); - Model::from_url(model_url, req_authorization).await? - } - Err(_) => { - trace!(kind = "bytes", len = model_bytes.len()); - Model::from_bytes(&model_bytes).await? - } - }; - - let mut state = state.borrow_mut(); + let model = match model_bytes_or_url { + Ok(model_url) => { + trace!(kind = "url", url = %model_url); + Model::from_url(model_url, req_authorization).await? + } + Err(_) => { + trace!(kind = "bytes", len = model_bytes.len()); + Model::from_bytes(&model_bytes).await? + } + }; + + let mut state = state.borrow_mut(); let mut sessions = { state.try_take::>>().unwrap_or_default() }; @@ -143,38 +143,38 @@ pub async fn op_ai_ort_run_session( #[op2] #[serde] pub fn op_ai_ort_encode_tensor_audio( - #[serde] tensor: JsBuffer, - sample_rate: u32, + #[serde] tensor: JsBuffer, + sample_rate: u32, ) -> Result { - // let copy for now - let data_buffer = tensor.into_iter().as_slice(); + // let copy for now + let data_buffer = tensor.iter().as_slice(); - let sample_size = 4; // f32 tensor - let data_chunk_size = data_buffer.len() as u32 * sample_size; - let total_riff_size = 36 + data_chunk_size; // 36 is the total of bytes until write data + let sample_size = 4; // f32 tensor + let data_chunk_size = data_buffer.len() as u32 * sample_size; + let total_riff_size = 36 + data_chunk_size; // 36 is the total of bytes until write data - let mut audio_wav = Vec::new(); + let mut audio_wav = Vec::new(); - // RIFF HEADER - audio_wav.extend_from_slice(b"RIFF"); - audio_wav.put_u32_le(total_riff_size); - audio_wav.extend_from_slice(b"WAVE"); + // RIFF HEADER + audio_wav.extend_from_slice(b"RIFF"); + audio_wav.put_u32_le(total_riff_size); + audio_wav.extend_from_slice(b"WAVE"); - // FORMAT CHUNK - audio_wav.extend_from_slice(b"fmt "); // whitespace needed "fmt" + " " - 
audio_wav.put_u32_le(16); // PCM chunk size - audio_wav.put_u16_le(3); // RAW audio format - audio_wav.put_u16_le(1); // Number of channels - audio_wav.put_u32_le(sample_rate); - audio_wav.put_u32_le(sample_rate * sample_size); // Byte rate - audio_wav.put_u16_le(sample_size as u16); // Block align - audio_wav.put_u16_le(32); // f32 Bits per sample + // FORMAT CHUNK + audio_wav.extend_from_slice(b"fmt "); // whitespace needed "fmt" + " " + audio_wav.put_u32_le(16); // PCM chunk size + audio_wav.put_u16_le(3); // RAW audio format + audio_wav.put_u16_le(1); // Number of channels + audio_wav.put_u32_le(sample_rate); + audio_wav.put_u32_le(sample_rate * sample_size); // Byte rate + audio_wav.put_u16_le(sample_size as u16); // Block align + audio_wav.put_u16_le(32); // f32 Bits per sample - // DATA Chunk - audio_wav.extend_from_slice(b"data"); // chunk ID - audio_wav.put_u32_le(data_chunk_size); + // DATA Chunk + audio_wav.extend_from_slice(b"data"); // chunk ID + audio_wav.put_u32_le(data_chunk_size); - audio_wav.extend_from_slice(data_buffer); + audio_wav.extend_from_slice(data_buffer); - Ok(ToJsBuffer::from(audio_wav)) + Ok(ToJsBuffer::from(audio_wav)) } diff --git a/ext/ai/onnxruntime/model.rs b/ext/ai/onnxruntime/model.rs index f1575d06e..a8907b479 100644 --- a/ext/ai/onnxruntime/model.rs +++ b/ext/ai/onnxruntime/model.rs @@ -71,11 +71,14 @@ impl Model { .map(Self::new) } - pub async fn from_url(model_url: Url, authorization: Option) -> Result { - load_session_from_url(model_url, authorization) - .await - .map(Self::new) - } + pub async fn from_url( + model_url: Url, + authorization: Option, + ) -> Result { + load_session_from_url(model_url, authorization) + .await + .map(Self::new) + } pub async fn from_bytes(model_bytes: &[u8]) -> Result { load_session_from_bytes(model_bytes).await.map(Self::new) diff --git a/ext/ai/onnxruntime/session.rs b/ext/ai/onnxruntime/session.rs index ca670792b..e407ee948 100644 --- a/ext/ai/onnxruntime/session.rs +++ 
b/ext/ai/onnxruntime/session.rs @@ -154,10 +154,10 @@ pub(crate) async fn load_session_from_bytes( #[instrument(level = "debug", fields(%model_url), err)] pub(crate) async fn load_session_from_url( - model_url: Url, - authorization: Option, + model_url: Url, + authorization: Option, ) -> Result { - let session_id = fxhash::hash(model_url.as_str()).to_string(); + let session_id = fxhash::hash(model_url.as_str()).to_string(); let mut sessions = SESSIONS.lock().await; @@ -166,13 +166,13 @@ pub(crate) async fn load_session_from_url( return Ok((session_id, session.clone()).into()); } - let model_file_path = crate::utils::fetch_and_cache_from_url( - "model", - model_url, - Some(session_id.to_string()), - authorization, - ) - .await?; + let model_file_path = crate::utils::fetch_and_cache_from_url( + "model", + model_url, + Some(session_id.to_string()), + authorization, + ) + .await?; let model_bytes = tokio::fs::read(model_file_path).await?; let session = create_session(model_bytes.as_slice())?; diff --git a/types/global.d.ts b/types/global.d.ts index 47c33612f..48aa5a302 100644 --- a/types/global.d.ts +++ b/types/global.d.ts @@ -12,12 +12,12 @@ declare interface WindowEventMap { "drain": Event; } -type DecoratorType = 'tc39' | 'typescript' | 'typescript_with_metadata'; +type DecoratorType = "tc39" | "typescript" | "typescript_with_metadata"; interface JsxImportBaseConfig { - defaultSpecifier?: string | null; - module?: string | null; - baseUrl?: string | null; + defaultSpecifier?: string | null; + module?: string | null; + baseUrl?: string | null; } // TODO(Nyannyacha): These two type defs will be provided later. @@ -158,216 +158,218 @@ declare namespace EdgeRuntime { } declare namespace Supabase { -export namespace ai { - interface ModelOptions { - /** - * Pool embeddings by taking their mean. Applies only for `gte-small` model - */ - mean_pool?: boolean; - - /** - * Normalize the embeddings result. 
Applies only for `gte-small` model - */ - normalize?: boolean; - - /** - * Stream response from model. Applies only for LLMs like `mistral` (default: false) - */ - stream?: boolean; - - /** - * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) - */ - timeout?: number; - - /** - * Mode for the inference API host. (default: 'ollama') - */ - mode?: 'ollama' | 'openaicompatible'; - signal?: AbortSignal; - } - - export type TensorDataTypeMap = { - float32: Float32Array | number[]; - float64: Float64Array | number[]; - string: string[]; - int8: Int8Array | number[]; - uint8: Uint8Array | number[]; - int16: Int16Array | number[]; - uint16: Uint16Array | number[]; - int32: Int32Array | number[]; - uint32: Uint32Array | number[]; - int64: BigInt64Array | number[]; - uint64: BigUint64Array | number[]; - bool: Uint8Array | number[]; - }; + export namespace ai { + interface ModelOptions { + /** + * Pool embeddings by taking their mean. Applies only for `gte-small` model + */ + mean_pool?: boolean; + + /** + * Normalize the embeddings result. Applies only for `gte-small` model + */ + normalize?: boolean; + + /** + * Stream response from model. Applies only for LLMs like `mistral` (default: false) + */ + stream?: boolean; + + /** + * Automatically abort the request to the model after specified time (in seconds). Applies only for LLMs like `mistral` (default: 60) + */ + timeout?: number; + + /** + * Mode for the inference API host. 
(default: 'ollama') + */ + mode?: "ollama" | "openaicompatible"; + signal?: AbortSignal; + } - export type TensorMap = { [key: string]: RawTensor }; - - export class Session { - /** - * Create a new model session using given model - */ - constructor(model: string); - - /** - * Execute the given prompt in model session - */ - run( - prompt: - | string - | Omit< - import('openai').OpenAI.Chat.ChatCompletionCreateParams, - 'model' | 'stream' - >, - modelOptions?: ModelOptions, - ): unknown; - } - - /** Provides an user friendly interface for the low level *onnx backend API*. - * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models, where you need mode control of model execution and pre/pos-processing. - * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`. - * - * **Example:** - * ```typescript - * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); - * // const session = await RawSession.fromUrl("https://example.com/model.onnx"); - * - * // Prepare the input tensors - * const inputs = { - * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), - * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), - * }; - * - * // Run the model - * const outputs = await session.run(inputs); - * - * console.log(outputs.output1); // Output tensor - * ``` + export type TensorDataTypeMap = { + float32: Float32Array | number[]; + float64: Float64Array | number[]; + string: string[]; + int8: Int8Array | number[]; + uint8: Uint8Array | number[]; + int16: Int16Array | number[]; + uint16: Uint16Array | number[]; + int32: Int32Array | number[]; + uint32: Uint32Array | number[]; + int64: BigInt64Array | number[]; + uint64: BigUint64Array | number[]; + bool: Uint8Array | number[]; + }; + + export type TensorMap = { + [key: string]: RawTensor; + }; + + export class Session { + /** + * Create a new model session using given model + */ + 
constructor(model: string);
+
+    /**
+      * Execute the given prompt in model session
+      */
+      run(
+        prompt:
+          | string
+          | Omit<
+            import("openai").OpenAI.Chat.ChatCompletionCreateParams,
+            "model" | "stream"
+          >,
+        modelOptions?: ModelOptions,
+      ): unknown;
+    }
+
+    /** Provides a user-friendly interface for the low level *onnx backend API*.
+      * A `RawSession` can execute any *onnx* model, but we only recommend it for `tabular` or *self-made* models, where you need more control of model execution and pre/post-processing.
+      * Consider a high-level implementation like `@huggingface/transformers.js` for generic tasks like `nlp`, `computer-vision` or `audio`.
+      *
+      * **Example:**
+      * ```typescript
+      * const session = await RawSession.fromHuggingFace('Supabase/gte-small');
+      * // const session = await RawSession.fromUrl("https://example.com/model.onnx");
+      *
+      * // Prepare the input tensors
+      * const inputs = {
+      *   input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]),
+      *   input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]),
+      * };
+      *
+      * // Run the model
+      * const outputs = await session.run(inputs);
+      *
+      * console.log(outputs.output1); // Output tensor
+      * ```
+      */
+    export class RawSession {
+      /** The underlying session's ID.
+        * Session IDs are unique for each loaded model; this means that even if a session is constructed twice it will share the same ID.
+        */
+      id: string;
+
+      /** A list of all input keys the model expects. */
+      inputs: string[];
+
+      /** A list of all output keys the model will return. */
+      outputs: string[];
+
+      /** Loads an ONNX model session from a source URL.
+        * Sessions are loaded once, then will stay warm across worker requests
+        */
+      static fromUrl(source: string | URL): Promise;
+
+      /** Loads an ONNX model session from a **HuggingFace** repository. 
+ * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromHuggingFace(repoId: string, opts?: { + /** + * @default 'https://huggingface.co' */ - export class RawSession { - /** The underline session's ID. - * Session's ID are unique for each loaded model, it means that even if a session is constructed twice its will share the same ID. - */ - id: string; - - /** A list of all input keys the model expects. */ - inputs: string[]; - - /** A list of all output keys the model will result. */ - outputs: string[]; - - /** Loads a ONNX model session from source URL. - * Sessions are loaded once, then will keep warm cross worker's requests - */ - static fromUrl(source: string | URL): Promise; - - /** Loads a ONNX model session from **HuggingFace** repository. - * Sessions are loaded once, then will keep warm cross worker's requests - */ - static fromHuggingFace(repoId: string, opts?: { - /** - * @default 'https://huggingface.co' - */ - hostname?: string | URL; - path?: { - /** - * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true' - */ - template?: string; - /** - * @default 'main' - */ - revision?: string; - /** - * @default 'model_quantized.onnx' - */ - modelFile?: string; - }; - }): Promise; - - /** Loads a ONNX model session from **Storage**. - * Sessions are loaded once, then will keep warm cross worker's requests - */ - static fromStorage(repoId: string, opts?: { - /** - * @default 'env SUPABASE_URL' - */ - hostname?: string | URL; - mode?: 'public' | { - authorization: string; - }; - }): Promise; - - /** Run the current session with the given inputs. - * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. - * - * @param inputs The input tensors required by the model. - * @returns The output tensors generated by the model. 
- * - * @example - * ```typescript - * const session = await RawSession.fromUrl("https://example.com/model.onnx"); - * - * // Prepare the input tensors - * const inputs = { - * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), - * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), - * }; - * - * // Run the model - * const outputs = await session.run(inputs); - * - * console.log(outputs.output1); // Output tensor - * ``` - */ - run(inputs: TensorMap): Promise; - } - - /** A low level representation of model input/output. - * Supabase's `Tensor` is totally compatible with `@huggingface/transformers.js`'s `Tensor`. It means that you can use its high-level API to apply some common operations like `sum()`, `min()`, `max()`, `normalize()` etc... - * - * **Example: Generating embeddings from scratch** - * ```typescript - * import { Tensor as HFTensor } from "@huggingface/transformers.js"; - * const { Tensor, RawSession } = Supabase.ai; - * - * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); - * - * // Example only, in real 'feature-extraction' tensors are given from the tokenizer step. - * const inputs = { - * input_ids: new Tensor('float32', [...], [n, 2]), - * attention_mask: new Tensor('float32', [...], [n, 2]), - * token_types_ids: new Tensor('float32', [...], [n, 2]) - * }; - * - * const { last_hidden_state } = await session.run(inputs); - * - * // Using `transformers.js` APIs - * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize(); - * - * return hfTensor.tolist(); - * - * ``` + hostname?: string | URL; + path?: { + /** + * @default '{REPO_ID}/resolve/{REVISION}/onnx/{MODEL_FILE}?donwload=true' + */ + template?: string; + /** + * @default 'main' + */ + revision?: string; + /** + * @default 'model_quantized.onnx' + */ + modelFile?: string; + }; + }): Promise; + + /** Loads a ONNX model session from **Storage**. 
+ * Sessions are loaded once, then will keep warm cross worker's requests + */ + static fromStorage(repoId: string, opts?: { + /** + * @default 'env SUPABASE_URL' */ - export class RawTensor { - /** Type of the tensor. */ - type: T; - - /** The data stored in the tensor. */ - data: TensorDataTypeMap[T]; - - /** Dimensions of the tensor. */ - dims: number[]; - - /** The total number of elements in the tensor. */ - size: number; - - constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); + hostname?: string | URL; + mode?: "public" | { + authorization: string; + }; + }): Promise; + + /** Run the current session with the given inputs. + * Use `inputs` and `outputs` properties to know the required inputs and expected results for the model session. + * + * @param inputs The input tensors required by the model. + * @returns The output tensors generated by the model. + * + * @example + * ```typescript + * const session = await RawSession.fromUrl("https://example.com/model.onnx"); + * + * // Prepare the input tensors + * const inputs = { + * input1: new Tensor("float32", [1.0, 2.0, 3.0], [3]), + * input2: new Tensor("float32", [4.0, 5.0, 6.0], [3]), + * }; + * + * // Run the model + * const outputs = await session.run(inputs); + * + * console.log(outputs.output1); // Output tensor + * ``` + */ + run(inputs: TensorMap): Promise; + } - tryEncodeAudio(sampleRate: number): Promise; - } + /** A low level representation of model input/output. + * Supabase's `Tensor` is totally compatible with `@huggingface/transformers.js`'s `Tensor`. It means that you can use its high-level API to apply some common operations like `sum()`, `min()`, `max()`, `normalize()` etc... 
+ * + * **Example: Generating embeddings from scratch** + * ```typescript + * import { Tensor as HFTensor } from "@huggingface/transformers.js"; + * const { Tensor, RawSession } = Supabase.ai; + * + * const session = await RawSession.fromHuggingFace('Supabase/gte-small'); + * + * // Example only, in real 'feature-extraction' tensors are given from the tokenizer step. + * const inputs = { + * input_ids: new Tensor('float32', [...], [n, 2]), + * attention_mask: new Tensor('float32', [...], [n, 2]), + * token_types_ids: new Tensor('float32', [...], [n, 2]) + * }; + * + * const { last_hidden_state } = await session.run(inputs); + * + * // Using `transformers.js` APIs + * const hfTensor = HFTensor.mean_pooling(last_hidden_state, inputs.attention_mask).normalize(); + * + * return hfTensor.tolist(); + * + * ``` + */ + export class RawTensor { + /** Type of the tensor. */ + type: T; + + /** The data stored in the tensor. */ + data: TensorDataTypeMap[T]; + + /** Dimensions of the tensor. */ + dims: number[]; + + /** The total number of elements in the tensor. */ + size: number; + + constructor(type: T, data: TensorDataTypeMap[T], dims: number[]); + + tryEncodeAudio(sampleRate: number): Promise; } + } } declare namespace Deno {