From 8b206b8235ecef99d1ae48e5426cacc82d55ea43 Mon Sep 17 00:00:00 2001 From: yang Date: Thu, 16 Apr 2026 22:58:18 -0400 Subject: [PATCH 1/4] refactor(embedding): extract EmbeddingProvider layer with OpenAI + Ollama implementations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The embedding service was a monolithic OpenAI-specific module. This extracts a provider interface so new backends (Ollama, vLLM, LiteLLM, Voyage) slot in without touching callers. Changes: - Add src/core/embedding/provider.ts — EmbeddingProvider interface + ProviderConfig type - Add src/core/embedding/providers/openai.ts — OpenAIProvider with Matryoshka dim param gated to text-embedding-3 family - Add src/core/embedding/providers/ollama.ts — OllamaProvider over /v1/embeddings, infers dim from known model registry, normalizes errors for retry - Add src/core/embedding/factory.ts — createProvider(config) + resolveConfig that merges explicit config > EMBEDDING_* env vars > defaults - Add src/core/embedding/service.ts — provider-agnostic batching, retry, truncation - Add src/core/embedding/index.ts — public surface - Keep src/core/embedding.ts as a thin re-export shim so existing imports work unchanged - Add test/embedding/provider.test.ts — 15 tests covering both providers, factory, env resolution Default behavior is preserved: no flags, no env vars → OpenAI text-embedding-3-large at 1536 dimensions. The full existing test suite (861 tests) passes without changes. The schema still hardcodes vector(1536); provider-driven schema templating lands in the follow-up commit. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/core/embedding.ts | 103 +++-------------- src/core/embedding/factory.ts | 56 +++++++++ src/core/embedding/index.ts | 19 ++++ src/core/embedding/provider.ts | 43 +++++++ src/core/embedding/providers/ollama.ts | 121 ++++++++++++++++++++ src/core/embedding/providers/openai.ts | 71 ++++++++++++ src/core/embedding/service.ts | 108 ++++++++++++++++++ test/embedding/provider.test.ts | 150 +++++++++++++++++++++++++ 8 files changed, 581 insertions(+), 90 deletions(-) create mode 100644 src/core/embedding/factory.ts create mode 100644 src/core/embedding/index.ts create mode 100644 src/core/embedding/provider.ts create mode 100644 src/core/embedding/providers/ollama.ts create mode 100644 src/core/embedding/providers/openai.ts create mode 100644 src/core/embedding/service.ts create mode 100644 test/embedding/provider.test.ts diff --git a/src/core/embedding.ts b/src/core/embedding.ts index 4689ccd1..0a975fc7 100644 --- a/src/core/embedding.ts +++ b/src/core/embedding.ts @@ -1,94 +1,17 @@ /** - * Embedding Service - * Ported from production Ruby implementation (embedding_service.rb, 190 LOC) + * BACKWARD-COMPATIBILITY SHIM * - * OpenAI text-embedding-3-large at 1536 dimensions. - * Retry with exponential backoff (4s base, 120s cap, 5 retries). - * 8000 character input truncation. + * The embedding implementation moved to `src/core/embedding/` as a provider layer + * (OpenAIProvider, OllamaProvider, factory, service). This file re-exports the + * public surface so existing imports keep working without churn: + * + * import { embed, embedBatch } from '../core/embedding.ts'; + * + * New code should import from `./embedding/index.ts` directly to access + * createProvider, EmbeddingProvider, OllamaProvider, etc. 
+ *
+ * Test mocks (`mock.module('../src/core/embedding.ts', () => ({ embedBatch }))`)
+ * continue to intercept the call chain at this shim, so existing tests work unchanged.
  */
-import OpenAI from 'openai';
-
-const MODEL = 'text-embedding-3-large';
-const DIMENSIONS = 1536;
-const MAX_CHARS = 8000;
-const MAX_RETRIES = 5;
-const BASE_DELAY_MS = 4000;
-const MAX_DELAY_MS = 120000;
-const BATCH_SIZE = 100;
-
-let client: OpenAI | null = null;
-
-function getClient(): OpenAI {
-  if (!client) {
-    client = new OpenAI();
-  }
-  return client;
-}
-
-export async function embed(text: string): Promise<Float32Array> {
-  const truncated = text.slice(0, MAX_CHARS);
-  const result = await embedBatch([truncated]);
-  return result[0];
-}
-
-export async function embedBatch(texts: string[]): Promise<Float32Array[]> {
-  const truncated = texts.map(t => t.slice(0, MAX_CHARS));
-  const results: Float32Array[] = [];
-
-  // Process in batches of BATCH_SIZE
-  for (let i = 0; i < truncated.length; i += BATCH_SIZE) {
-    const batch = truncated.slice(i, i + BATCH_SIZE);
-    const batchResults = await embedBatchWithRetry(batch);
-    results.push(...batchResults);
-  }
-
-  return results;
-}
-
-async function embedBatchWithRetry(texts: string[]): Promise<Float32Array[]> {
-  for (let attempt = 0; attempt < MAX_RETRIES; attempt++) {
-    try {
-      const response = await getClient().embeddings.create({
-        model: MODEL,
-        input: texts,
-        dimensions: DIMENSIONS,
-      });
-
-      // Sort by index to maintain order
-      const sorted = response.data.sort((a, b) => a.index - b.index);
-      return sorted.map(d => new Float32Array(d.embedding));
-    } catch (e: unknown) {
-      if (attempt === MAX_RETRIES - 1) throw e;
-
-      // Check for rate limit with Retry-After header
-      let delay = exponentialDelay(attempt);
-
-      if (e instanceof OpenAI.APIError && e.status === 429) {
-        const retryAfter = e.headers?.['retry-after'];
-        if (retryAfter) {
-          const parsed = parseInt(retryAfter, 10);
-          if (!isNaN(parsed)) {
-            delay = parsed * 1000;
-          }
-        }
-      }
-
-      await sleep(delay);
-    }
-  }
-
-  // Should not reach here
-  throw new Error('Embedding failed after all retries');
-}
-
-function exponentialDelay(attempt: number): number {
-  const delay = BASE_DELAY_MS * Math.pow(2, attempt);
-  return Math.min(delay, MAX_DELAY_MS);
-}
-
-function sleep(ms: number): Promise<void> {
-  return new Promise(resolve => setTimeout(resolve, ms));
-}
-
-export { MODEL as EMBEDDING_MODEL, DIMENSIONS as EMBEDDING_DIMENSIONS };
+export { embed, embedBatch } from './embedding/service.ts';
diff --git a/src/core/embedding/factory.ts b/src/core/embedding/factory.ts
new file mode 100644
index 00000000..9c9af41d
--- /dev/null
+++ b/src/core/embedding/factory.ts
@@ -0,0 +1,56 @@
+/**
+ * Provider factory — resolves a ProviderConfig to a concrete EmbeddingProvider.
+ *
+ * Resolution order (most specific wins):
+ *   1. Explicit ProviderConfig argument (from CLI flags or `.gbrain.config.json`)
+ *   2. Env vars: EMBEDDING_PROVIDER, EMBEDDING_MODEL, EMBEDDING_DIMENSIONS, EMBEDDING_BASE_URL
+ *   3. Defaults: OpenAI text-embedding-3-large at 1536 dimensions
+ *
+ * Callers should pass an explicit config when they have one. Env-var fallback exists for
+ * scripts and tests that don't go through `gbrain init`.
+ */
+
+import type { EmbeddingProvider, ProviderConfig } from './provider.ts';
+import { OpenAIProvider } from './providers/openai.ts';
+import { OllamaProvider } from './providers/ollama.ts';
+
+const REGISTRY: Record<string, new (config: ProviderConfig) => EmbeddingProvider> = {
+  openai: OpenAIProvider,
+  ollama: OllamaProvider,
+};
+
+export function createProvider(config?: Partial<ProviderConfig>): EmbeddingProvider {
+  const resolved = resolveConfig(config);
+  const ProviderClass = REGISTRY[resolved.provider];
+  if (!ProviderClass) {
+    const known = Object.keys(REGISTRY).join(', ');
+    throw new Error(`Unknown embedding provider '${resolved.provider}'. Known: ${known}.`);
+  }
+  return new ProviderClass(resolved);
+}
+
+export function resolveConfig(override?: Partial<ProviderConfig>): ProviderConfig {
+  const fromEnv: Partial<ProviderConfig> = {
+    provider: process.env.EMBEDDING_PROVIDER,
+    model: process.env.EMBEDDING_MODEL,
+    dimensions: process.env.EMBEDDING_DIMENSIONS
+      ? parseInt(process.env.EMBEDDING_DIMENSIONS, 10)
+      : undefined,
+    baseUrl: process.env.EMBEDDING_BASE_URL ?? process.env.OPENAI_BASE_URL,
+    apiKey: process.env.OPENAI_API_KEY,
+  };
+
+  // Override > env > defaults (defaults filled per-provider in the constructor)
+  const provider = override?.provider ?? fromEnv.provider ?? 'openai';
+  return {
+    provider,
+    model: override?.model ?? fromEnv.model,
+    dimensions: override?.dimensions ?? fromEnv.dimensions,
+    baseUrl: override?.baseUrl ?? fromEnv.baseUrl,
+    apiKey: override?.apiKey ?? fromEnv.apiKey,
+  };
+}
+
+export function listProviders(): string[] {
+  return Object.keys(REGISTRY);
+}
diff --git a/src/core/embedding/index.ts b/src/core/embedding/index.ts
new file mode 100644
index 00000000..b05ecef4
--- /dev/null
+++ b/src/core/embedding/index.ts
@@ -0,0 +1,19 @@
+/**
+ * Public surface of the embedding layer.
+ *
+ * Most callers want `embed` / `embedBatch` from `./service.ts`.
+ * `gbrain doctor` and `gbrain init` use `createProvider` + `getActiveProvider` to introspect.
+ */
+
+export type { EmbeddingProvider, ProviderConfig, HealthCheckResult } from './provider.ts';
+export { createProvider, resolveConfig, listProviders } from './factory.ts';
+export {
+  embed,
+  embedBatch,
+  setProvider,
+  getActiveProvider,
+  getEmbeddingModel,
+  getEmbeddingDimensions,
+} from './service.ts';
+export { OpenAIProvider } from './providers/openai.ts';
+export { OllamaProvider, OllamaError } from './providers/ollama.ts';
diff --git a/src/core/embedding/provider.ts b/src/core/embedding/provider.ts
new file mode 100644
index 00000000..1ffc98f2
--- /dev/null
+++ b/src/core/embedding/provider.ts
@@ -0,0 +1,43 @@
+/**
+ * EmbeddingProvider — the contract every embedding backend implements.
+ *
+ * Provider quirks (Matryoshka dim param, error shapes, auth) live behind this interface
+ * so callers (service, init, embed command) never branch on provider name.
+ *
+ * All providers MUST return vectors of exactly `dimensions` length per call.
+ * If a provider's model returns a different size, the provider implementation
+ * must reject (not silently truncate/pad).
+ */
+
+export interface EmbeddingProvider {
+  readonly name: string;          // 'openai' | 'ollama' | future
+  readonly model: string;         // 'text-embedding-3-large' | 'nomic-embed-text' | ...
+  readonly dimensions: number;    // fixed for the lifetime of this instance
+  readonly maxInputChars: number; // truncation budget per text
+
+  embed(texts: string[]): Promise<Float32Array[]>;
+
+  /** Lightweight liveness check — used by `gbrain doctor` and init. */
+  healthCheck(): Promise<HealthCheckResult>;
+}
+
+export interface HealthCheckResult {
+  ok: boolean;
+  reason?: string;
+  // Optional metadata for `gbrain doctor --json`
+  latencyMs?: number;
+  detectedDimensions?: number;
+}
+
+export interface ProviderConfig {
+  /** Provider name. Currently 'openai' or 'ollama'. */
+  provider: string;
+  /** Model identifier. Required for non-default providers; optional for openai (defaults to text-embedding-3-large). */
+  model?: string;
+  /** Output dimension. Required if it cannot be inferred from (provider, model). */
+  dimensions?: number;
+  /** Override base URL (for self-hosted vLLM, LiteLLM proxy, custom Ollama port). */
+  baseUrl?: string;
+  /** API key. Optional for local providers; required for OpenAI proper. */
+  apiKey?: string;
+}
diff --git a/src/core/embedding/providers/ollama.ts b/src/core/embedding/providers/ollama.ts
new file mode 100644
index 00000000..2454a682
--- /dev/null
+++ b/src/core/embedding/providers/ollama.ts
@@ -0,0 +1,121 @@
+/**
+ * OllamaProvider — embeddings via Ollama's OpenAI-compatible /v1/embeddings endpoint.
+ *
+ * Differences from OpenAI:
+ * - No `dimensions` parameter (Matryoshka not supported)
+ * - Output dim is fixed by the model (nomic-embed-text=768, mxbai-embed-large=1024, bge-m3=1024)
+ * - No API key required (ignored if sent)
+ * - Errors don't follow OpenAI's shape — we normalize them here so the service's
+ *   retry loop sees consistent error types.
+ */
+
+import OpenAI from 'openai';
+import type { EmbeddingProvider, HealthCheckResult, ProviderConfig } from '../provider.ts';
+
+const DEFAULT_BASE_URL = 'http://localhost:11434/v1';
+const DEFAULT_MAX_CHARS = 8000;
+
+/** Known Ollama embedding models and their native output dimensions. */
+const KNOWN_DIMENSIONS: Record<string, number> = {
+  'nomic-embed-text': 768,
+  'mxbai-embed-large': 1024,
+  'bge-m3': 1024,
+  'snowflake-arctic-embed:large': 1024,
+  'all-minilm': 384,
+};
+
+export class OllamaProvider implements EmbeddingProvider {
+  readonly name = 'ollama';
+  readonly model: string;
+  readonly dimensions: number;
+  readonly maxInputChars = DEFAULT_MAX_CHARS;
+  private readonly client: OpenAI;
+
+  constructor(config: ProviderConfig) {
+    if (!config.model) {
+      throw new Error("OllamaProvider requires `model` in ProviderConfig (e.g. 'nomic-embed-text').");
+    }
+    this.model = config.model;
+    this.dimensions = config.dimensions ?? KNOWN_DIMENSIONS[config.model] ?? 0;
+
+    if (!this.dimensions) {
+      throw new Error(
+        `OllamaProvider: cannot infer dimensions for model '${config.model}'. ` +
+        `Pass --dimensions explicitly or add it to KNOWN_DIMENSIONS in providers/ollama.ts.`
+      );
+    }
+
+    this.client = new OpenAI({
+      apiKey: config.apiKey ?? 'ollama-no-key',
+      baseURL: config.baseUrl ?? DEFAULT_BASE_URL,
+    });
+  }
+
+  async embed(texts: string[]): Promise<Float32Array[]> {
+    if (texts.length === 0) return [];
+    let response;
+    try {
+      response = await this.client.embeddings.create({
+        model: this.model,
+        input: texts,
+      });
+    } catch (e: unknown) {
+      // Normalize Ollama errors so service-layer retry can distinguish transient vs fatal.
+      throw normalizeOllamaError(e);
+    }
+    const sorted = response.data.sort((a, b) => a.index - b.index);
+    return sorted.map(d => {
+      const v = new Float32Array(d.embedding);
+      if (v.length !== this.dimensions) {
+        throw new Error(
+          `OllamaProvider: expected ${this.dimensions}-dim vector, got ${v.length}. ` +
+          `Model ${this.model} may not match its declared dimensions — check ollama pull output.`
+        );
+      }
+      return v;
+    });
+  }
+
+  async healthCheck(): Promise<HealthCheckResult> {
+    const start = Date.now();
+    try {
+      const result = await this.embed(['health check']);
+      return {
+        ok: true,
+        latencyMs: Date.now() - start,
+        detectedDimensions: result[0]?.length,
+      };
+    } catch (e: unknown) {
+      const reason = e instanceof Error ? e.message : String(e);
+      return { ok: false, reason, latencyMs: Date.now() - start };
+    }
+  }
+}
+
+class OllamaError extends Error {
+  constructor(message: string, readonly status?: number, readonly transient = false) {
+    super(message);
+    this.name = 'OllamaError';
+  }
+}
+
+function normalizeOllamaError(e: unknown): Error {
+  if (e instanceof OpenAI.APIError) {
+    // Ollama may return 404 if model not pulled, 503 if loading, 500 transient.
+    const transient = e.status === 503 || e.status === 500 || e.status === 429;
+    let hint = '';
+    if (e.status === 404) hint = ` (model not pulled? try: ollama pull ${e.message.match(/model "([^"]+)"/)?.[1] ?? 'MODEL'})`;
+    if (e.status === 503) hint = ' (Ollama is loading the model — retry shortly)';
+    return new OllamaError(`Ollama API ${e.status}: ${e.message}${hint}`, e.status, transient);
+  }
+  if (e instanceof Error && /ECONNREFUSED|fetch failed|ENOTFOUND/.test(e.message)) {
+    return new OllamaError(
+      `Ollama not reachable at the configured base URL. Is the daemon running? Try: ollama serve`,
+      undefined,
+      true
+    );
+  }
+  return e instanceof Error ? e : new Error(String(e));
+}
+
+export { OllamaError };
diff --git a/src/core/embedding/providers/openai.ts b/src/core/embedding/providers/openai.ts
new file mode 100644
index 00000000..d6b0cc2b
--- /dev/null
+++ b/src/core/embedding/providers/openai.ts
@@ -0,0 +1,71 @@
+/**
+ * OpenAIProvider — embeddings via OpenAI's API or any OpenAI-compatible endpoint
+ * that supports the Matryoshka `dimensions` parameter (text-embedding-3 family).
+ */
+
+import OpenAI from 'openai';
+import type { EmbeddingProvider, HealthCheckResult, ProviderConfig } from '../provider.ts';
+
+const DEFAULT_MODEL = 'text-embedding-3-large';
+const DEFAULT_DIMENSIONS = 1536;
+const DEFAULT_MAX_CHARS = 8000;
+
+/** Models in the text-embedding-3 family accept the `dimensions` param (Matryoshka). */
+function supportsMatryoshka(model: string): boolean {
+  return model.startsWith('text-embedding-3');
+}
+
+export class OpenAIProvider implements EmbeddingProvider {
+  readonly name = 'openai';
+  readonly model: string;
+  readonly dimensions: number;
+  readonly maxInputChars = DEFAULT_MAX_CHARS;
+  private readonly client: OpenAI;
+  private readonly useDimensionsParam: boolean;
+
+  constructor(config: ProviderConfig) {
+    this.model = config.model ?? DEFAULT_MODEL;
+    this.dimensions = config.dimensions ?? DEFAULT_DIMENSIONS;
+    this.useDimensionsParam = supportsMatryoshka(this.model);
+
+    this.client = new OpenAI({
+      apiKey: config.apiKey ?? process.env.OPENAI_API_KEY ?? '',
+      ...(config.baseUrl ? { baseURL: config.baseUrl } : {}),
+    });
+  }
+
+  async embed(texts: string[]): Promise<Float32Array[]> {
+    if (texts.length === 0) return [];
+    const response = await this.client.embeddings.create({
+      model: this.model,
+      input: texts,
+      ...(this.useDimensionsParam ? { dimensions: this.dimensions } : {}),
+    });
+    const sorted = response.data.sort((a, b) => a.index - b.index);
+    return sorted.map(d => {
+      const v = new Float32Array(d.embedding);
+      if (v.length !== this.dimensions) {
+        throw new Error(
+          `OpenAIProvider: expected ${this.dimensions}-dim vector, got ${v.length}. ` +
+          `Model ${this.model} may not support requested dimensions.`
+        );
+      }
+      return v;
+    });
+  }
+
+  async healthCheck(): Promise<HealthCheckResult> {
+    const start = Date.now();
+    try {
+      const result = await this.embed(['health check']);
+      return {
+        ok: true,
+        latencyMs: Date.now() - start,
+        detectedDimensions: result[0]?.length,
+      };
+    } catch (e: unknown) {
+      const reason = e instanceof Error ? e.message : String(e);
+      return { ok: false, reason, latencyMs: Date.now() - start };
+    }
+  }
+}
diff --git a/src/core/embedding/service.ts b/src/core/embedding/service.ts
new file mode 100644
index 00000000..bf2eec32
--- /dev/null
+++ b/src/core/embedding/service.ts
@@ -0,0 +1,108 @@
+/**
+ * Embedding Service — provider-agnostic batching, retry, truncation.
+ *
+ * Owns the cross-cutting concerns: chunked batching to respect provider rate limits,
+ * exponential backoff on retryable errors, input truncation to provider's max chars.
+ *
+ * Delegates the actual API call to a provider instance from `./factory.ts`.
+ */
+
+import OpenAI from 'openai';
+import type { EmbeddingProvider } from './provider.ts';
+import { createProvider } from './factory.ts';
+import { OllamaError } from './providers/ollama.ts';
+
+const MAX_RETRIES = 5;
+const BASE_DELAY_MS = 4000;
+const MAX_DELAY_MS = 120000;
+const BATCH_SIZE = 100;
+
+let defaultProvider: EmbeddingProvider | null = null;
+
+/** Lazy-init: build the default provider on first use. Override in tests via setProvider. */
+function getProvider(): EmbeddingProvider {
+  if (!defaultProvider) {
+    defaultProvider = createProvider();
+  }
+  return defaultProvider;
+}
+
+/** Replace the singleton provider — for tests, or after config reload. */
+export function setProvider(provider: EmbeddingProvider | null): void {
+  defaultProvider = provider;
+}
+
+/** Returns the active provider's metadata without re-creating it. */
+export function getActiveProvider(): EmbeddingProvider {
+  return getProvider();
+}
+
+export async function embed(text: string): Promise<Float32Array> {
+  const provider = getProvider();
+  const truncated = text.slice(0, provider.maxInputChars);
+  const result = await embedBatchInternal(provider, [truncated]);
+  return result[0];
+}
+
+export async function embedBatch(texts: string[]): Promise<Float32Array[]> {
+  const provider = getProvider();
+  const truncated = texts.map(t => t.slice(0, provider.maxInputChars));
+  const results: Float32Array[] = [];
+
+  for (let i = 0; i < truncated.length; i += BATCH_SIZE) {
+    const batch = truncated.slice(i, i + BATCH_SIZE);
+    const batchResults = await embedBatchInternal(provider, batch);
+    results.push(...batchResults);
+  }
+  return results;
+}
+
+async function embedBatchInternal(provider: EmbeddingProvider, texts: string[]): Promise<Float32Array[]> {
+  for (let attempt = 0; attempt < MAX_RETRIES; attempt++) {
+    try {
+      return await provider.embed(texts);
+    } catch (e: unknown) {
+      if (attempt === MAX_RETRIES - 1) throw e;
+      if (!isRetryable(e)) throw e;
+      await sleep(retryDelay(e, attempt));
+    }
+  }
+  throw new Error('Embedding failed after all retries');
+}
+
+function isRetryable(e: unknown): boolean {
+  if (e instanceof OpenAI.APIError) {
+    return e.status === 429 || e.status === 500 || e.status === 502 || e.status === 503;
+  }
+  if (e instanceof OllamaError) return e.transient;
+  // Network / DNS / fetch failures from any provider — retry
+  if (e instanceof Error && /ECONNREFUSED|fetch failed|ENOTFOUND|ETIMEDOUT/.test(e.message)) return true;
+  return false;
+}
+
+function retryDelay(e: unknown, attempt: number): number {
+  // Honor Retry-After if the provider sent one (OpenAI 429s).
+  if (e instanceof OpenAI.APIError && e.status === 429) {
+    const retryAfter = e.headers?.['retry-after'];
+    if (retryAfter) {
+      const parsed = parseInt(retryAfter, 10);
+      if (!isNaN(parsed)) return parsed * 1000;
+    }
+  }
+  return Math.min(BASE_DELAY_MS * Math.pow(2, attempt), MAX_DELAY_MS);
+}
+
+function sleep(ms: number): Promise<void> {
+  return new Promise(resolve => setTimeout(resolve, ms));
+}
+
+// Backward-compat accessors — the old src/core/embedding.ts exported EMBEDDING_MODEL and
+// EMBEDDING_DIMENSIONS as constants. The provider is now chosen at runtime, so callers
+// read the active provider's values through these getters instead.
+export function getEmbeddingModel(): string { + return getProvider().model; +} + +export function getEmbeddingDimensions(): number { + return getProvider().dimensions; +} diff --git a/test/embedding/provider.test.ts b/test/embedding/provider.test.ts new file mode 100644 index 00000000..f8d9476f --- /dev/null +++ b/test/embedding/provider.test.ts @@ -0,0 +1,150 @@ +import { describe, test, expect, beforeEach, afterEach, mock } from 'bun:test'; +import { OpenAIProvider } from '../../src/core/embedding/providers/openai.ts'; +import { OllamaProvider } from '../../src/core/embedding/providers/ollama.ts'; +import { createProvider, resolveConfig, listProviders } from '../../src/core/embedding/factory.ts'; + +const mockCreate = mock(async (_args: any) => ({ + data: [{ index: 0, embedding: new Array(1536).fill(0.1) }], +})); + +mock.module('openai', () => { + class MockOpenAI { + embeddings = { create: mockCreate }; + constructor(public config: any) {} + } + return { + default: MockOpenAI, + APIError: class APIError extends Error { + constructor(public status: number, message: string, public headers?: any) { + super(message); + } + }, + }; +}); + +beforeEach(() => { mockCreate.mockClear(); }); + +afterEach(() => { + delete process.env.EMBEDDING_PROVIDER; + delete process.env.EMBEDDING_MODEL; + delete process.env.EMBEDDING_DIMENSIONS; + delete process.env.EMBEDDING_BASE_URL; + delete process.env.OPENAI_BASE_URL; +}); + +describe('OpenAIProvider', () => { + test('sends Matryoshka dimensions param for text-embedding-3-large', async () => { + const p = new OpenAIProvider({ provider: 'openai' }); + expect(p.name).toBe('openai'); + expect(p.model).toBe('text-embedding-3-large'); + expect(p.dimensions).toBe(1536); + await p.embed(['hello']); + expect(mockCreate).toHaveBeenCalledWith({ + model: 'text-embedding-3-large', + input: ['hello'], + dimensions: 1536, + }); + }); + + test('omits dimensions param for non-text-embedding-3 models', async () => { + const p = new OpenAIProvider({ provider: 'openai', model: 'text-embedding-ada-002', dimensions: 1536 }); + await p.embed(['hello']); + const call = mockCreate.mock.calls[0][0]; + expect(call).not.toHaveProperty('dimensions'); + expect(call.model).toBe('text-embedding-ada-002'); + }); + + test('rejects vectors of unexpected dimension', async () => { + mockCreate.mockImplementationOnce(async () => ({ + data: [{ index: 0, embedding: new Array(768).fill(0.1) }], + })); + const p = new OpenAIProvider({ provider: 'openai', dimensions: 1536 }); + await expect(p.embed(['x'])).rejects.toThrow(/expected 1536-dim vector, got 768/); + }); +}); + +describe('OllamaProvider', () => { + test('infers dimensions from known model registry', () => { + const p = new OllamaProvider({ provider: 'ollama', model: 'nomic-embed-text' }); + expect(p.name).toBe('ollama'); + expect(p.model).toBe('nomic-embed-text'); + expect(p.dimensions).toBe(768); + }); + + test('throws when model is missing', () => { + expect(() => new OllamaProvider({ provider: 'ollama' })).toThrow(/requires `model`/); + }); + + test('throws when dimensions cannot be inferred for unknown model', () => { + expect(() => new OllamaProvider({ provider: 'ollama', model: 'mystery-model' })).toThrow(/cannot infer dimensions/); + }); + + test('omits dimensions param in API call', async () => { + mockCreate.mockImplementationOnce(async () => ({ + data: [{ index: 0, embedding: new Array(768).fill(0.1) }], + })); + const p = new OllamaProvider({ provider: 'ollama', model: 'nomic-embed-text' }); + await p.embed(['hello']); + 
const call = mockCreate.mock.calls[0][0]; + expect(call).not.toHaveProperty('dimensions'); + expect(call.model).toBe('nomic-embed-text'); + }); + + test('uses default base URL http://localhost:11434/v1', () => { + const p = new OllamaProvider({ provider: 'ollama', model: 'nomic-embed-text' }); + expect((p as any).client.config.baseURL).toBe('http://localhost:11434/v1'); + }); + + test('rejects vectors of unexpected dimension', async () => { + mockCreate.mockImplementationOnce(async () => ({ + data: [{ index: 0, embedding: new Array(1024).fill(0.1) }], + })); + const p = new OllamaProvider({ provider: 'ollama', model: 'nomic-embed-text' }); + await expect(p.embed(['x'])).rejects.toThrow(/expected 768-dim vector, got 1024/); + }); +}); + +describe('factory', () => { + test('listProviders returns known names', () => { + const names = listProviders(); + expect(names).toContain('openai'); + expect(names).toContain('ollama'); + }); + + test('createProvider with explicit ollama config', () => { + const p = createProvider({ provider: 'ollama', model: 'nomic-embed-text' }); + expect(p.name).toBe('ollama'); + expect(p.dimensions).toBe(768); + }); + + test('createProvider defaults to OpenAI when nothing specified', () => { + const p = createProvider(); + expect(p.name).toBe('openai'); + expect(p.model).toBe('text-embedding-3-large'); + expect(p.dimensions).toBe(1536); + }); + + test('createProvider throws on unknown provider', () => { + expect(() => createProvider({ provider: 'fictional' })).toThrow(/Unknown embedding provider/); + }); + + test('resolveConfig pulls from EMBEDDING_* env vars', () => { + process.env.EMBEDDING_PROVIDER = 'ollama'; + process.env.EMBEDDING_MODEL = 'mxbai-embed-large'; + process.env.EMBEDDING_DIMENSIONS = '1024'; + process.env.EMBEDDING_BASE_URL = 'http://example.com/v1'; + const cfg = resolveConfig(); + expect(cfg).toMatchObject({ + provider: 'ollama', + model: 'mxbai-embed-large', + dimensions: 1024, + baseUrl: 'http://example.com/v1', + }); + }); + + test('resolveConfig override beats env', () => { + process.env.EMBEDDING_MODEL = 'env-model'; + const cfg = resolveConfig({ provider: 'openai', model: 'override-model' }); + expect(cfg.model).toBe('override-model'); + }); +}); From 0abd57cdb1683a4e01a185584a81f808d7717d87 Mon Sep 17 00:00:00 2001 From: yang Date: Thu, 16 Apr 2026 23:05:58 -0400 Subject: [PATCH 2/4] feat(schema): provider-driven embedding dimensions + init flags MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Until now the PGLite and Postgres schemas hardcoded vector(1536) and text-embedding-3-large — the result of the v0.6 env-var shim stopping at embedding.ts without reaching the schema layer. This patch finishes the abstraction: a brain's embedding dim and default model are chosen at init time from the resolved EmbeddingProvider, templated into the schema, and persisted to ~/.gbrain/config.json. Changes: - Convert PGLITE_SCHEMA_SQL const to pgliteSchema({dimensions, defaultModel}) function; keep the const as a backward-compat alias that evaluates defaults. - Same shape for postgresSchema in src/core/schema-embedded.ts; SCHEMA_SQL alias preserved. - Engine.initSchema() now takes optional opts (same shape), passes through to the schema function. Default behavior unchanged when called with no args. - Add embedding: {provider, model, dimensions, base_url} field to GBrainConfig. 
- init.ts: parse --provider / --model / --dimensions / --base-url; resolve via
  createProvider() (validates + infers Ollama dims); dim-mismatch guard refuses
  re-init against an existing brain with different dimensions; pass opts to
  initSchema; persist the chosen provider to config.
- cli.ts: --version also prints active provider when a config is loadable.
- test/schema-templating.test.ts — 11 new unit tests covering default fallback,
  partial opts, Postgres dollar-quote preservation, and const-alias parity.

Example usage:

  gbrain init --provider=ollama --model=nomic-embed-text   # 768d brain
  gbrain init --provider=openai                            # 1536d brain (default)
  gbrain init --provider=openai --dimensions=3072          # full text-embedding-3-large
  gbrain init                                              # defaults (openai 1536d)

All 861 existing tests still pass; 11 new schema tests added (872 total).

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/cli.ts                     |  11 ++++
 src/commands/init.ts           | 101 ++++++++++++++++++++++++++++++---
 src/core/config.ts             |  17 ++++++
 src/core/db.ts                 |   6 +-
 src/core/engine.ts             |   2 +-
 src/core/pglite-engine.ts      |   6 +-
 src/core/pglite-schema.ts      |  36 ++++++++++--
 src/core/postgres-engine.ts    |   6 +-
 src/core/schema-embedded.ts    |  35 ++++++++++--
 test/schema-templating.test.ts |  94 ++++++++++++++++++++++++++++++
 10 files changed, 285 insertions(+), 29 deletions(-)
 create mode 100644 test/schema-templating.test.ts

diff --git a/src/cli.ts b/src/cli.ts
index 33f149f4..55c1dab5 100644
--- a/src/cli.ts
+++ b/src/cli.ts
@@ -31,6 +31,17 @@ async function main() {
 
   if (command === '--version' || command === 'version') {
     console.log(`gbrain ${VERSION}`);
+    // Surface the active embedding provider so users running in multiple shells
+    // notice when they're on a non-default brain (local Ollama vs OpenAI).
+    try {
+      const cfg = loadConfig();
+      if (cfg?.embedding) {
+        const { provider, model, dimensions } = cfg.embedding;
+        console.log(`embedding: ${provider} / ${model} (${dimensions}d)`);
+      }
+    } catch {
+      // Config not readable — fine, --version shouldn't fail on that
+    }
     return;
   }
diff --git a/src/commands/init.ts b/src/commands/init.ts
index 6380912f..68317f4e 100644
--- a/src/commands/init.ts
+++ b/src/commands/init.ts
@@ -6,8 +6,50 @@ import { homedir } from 'os';
 const __filename = fileURLToPath(import.meta.url);
 const __dirname = dirname(__filename);
 
-import { saveConfig, type GBrainConfig } from '../core/config.ts';
+import { loadConfig, saveConfig, type GBrainConfig } from '../core/config.ts';
 import { createEngine } from '../core/engine-factory.ts';
+import { createProvider, resolveConfig as resolveEmbeddingConfig } from '../core/embedding/index.ts';
+import type { EmbeddingProvider, ProviderConfig } from '../core/embedding/index.ts';
+
+/**
+ * Parse --provider / --model / --dimensions / --base-url flags.
+ * Falls back to EMBEDDING_* env vars (handled inside resolveEmbeddingConfig).
+ */
+function parseEmbeddingFlags(args: string[]): Partial<ProviderConfig> {
+  const flag = (name: string): string | undefined => {
+    const i = args.indexOf(name);
+    return i !== -1 ? args[i + 1] : undefined;
+  };
+  const dims = flag('--dimensions');
+  return {
+    provider: flag('--provider'),
+    model: flag('--model'),
+    dimensions: dims ? parseInt(dims, 10) : undefined,
+    baseUrl: flag('--base-url'),
+  };
+}
+
+/**
+ * Resolve the embedding provider for this init, and guard against dim-mismatch
+ * with any existing brain config. Returns the instantiated provider.
+ */ +function resolveProviderWithGuard(args: string[]): { provider: EmbeddingProvider; resolved: ProviderConfig } { + const resolved = resolveEmbeddingConfig(parseEmbeddingFlags(args)); + const provider = createProvider(resolved); // validates config + infers dims + + const existing = loadConfig(); + if (existing?.embedding && existing.embedding.dimensions !== provider.dimensions) { + console.error(''); + console.error('Cannot re-init: existing brain has a different embedding dimension.'); + console.error(` Existing: ${existing.embedding.provider} / ${existing.embedding.model} (${existing.embedding.dimensions}d)`); + console.error(` Requested: ${provider.name} / ${provider.model} (${provider.dimensions}d)`); + console.error(''); + console.error('Switching providers requires regenerating all embeddings.'); + console.error('To start fresh: delete ~/.gbrain/config.json and the brain data directory, then rerun gbrain init.'); + process.exit(1); + } + return { provider, resolved }; +} export async function runInit(args: string[]) { const isSupabase = args.includes('--supabase'); @@ -21,6 +63,10 @@ export async function runInit(args: string[]) { const pathIndex = args.indexOf('--path'); const customPath = pathIndex !== -1 ? args[pathIndex + 1] : null; + // Resolve embedding provider up front. Fails fast on bad provider/model/dims + // or dim-mismatch with an existing brain — before any engine state is created. + const { provider, resolved: providerResolved } = resolveProviderWithGuard(args); + // Explicit PGLite mode if (isPGLite || (!isSupabase && !manualUrl && !isNonInteractive)) { // Smart detection: scan for .md files unless --pglite flag forces it @@ -37,7 +83,7 @@ export async function runInit(args: string[]) { } } - return initPGLite({ jsonOutput, apiKey, customPath }); + return initPGLite({ jsonOutput, apiKey, customPath, provider, providerResolved }); } // Supabase/Postgres mode @@ -56,20 +102,33 @@ export async function runInit(args: string[]) { databaseUrl = await supabaseWizard(); } - return initPostgres({ databaseUrl, jsonOutput, apiKey }); + return initPostgres({ databaseUrl, jsonOutput, apiKey, provider, providerResolved }); } -async function initPGLite(opts: { jsonOutput: boolean; apiKey: string | null; customPath: string | null }) { +async function initPGLite(opts: { + jsonOutput: boolean; + apiKey: string | null; + customPath: string | null; + provider: EmbeddingProvider; + providerResolved: ProviderConfig; +}) { const dbPath = opts.customPath || join(homedir(), '.gbrain', 'brain.pglite'); console.log(`Setting up local brain with PGLite (no server needed)...`); + console.log(`Embedding: ${opts.provider.name} / ${opts.provider.model} (${opts.provider.dimensions}d)`); const engine = await createEngine({ engine: 'pglite' }); await engine.connect({ database_path: dbPath, engine: 'pglite' }); - await engine.initSchema(); + await engine.initSchema({ dimensions: opts.provider.dimensions, defaultModel: opts.provider.model }); const config: GBrainConfig = { engine: 'pglite', database_path: dbPath, + embedding: { + provider: opts.provider.name, + model: opts.provider.model, + dimensions: opts.provider.dimensions, + ...(opts.providerResolved.baseUrl ? { base_url: opts.providerResolved.baseUrl } : {}), + }, ...(opts.apiKey ? 
{ openai_api_key: opts.apiKey } : {}), }; saveConfig(config); @@ -78,7 +137,13 @@ async function initPGLite(opts: { jsonOutput: boolean; apiKey: string | null; cu await engine.disconnect(); if (opts.jsonOutput) { - console.log(JSON.stringify({ status: 'success', engine: 'pglite', path: dbPath, pages: stats.page_count })); + console.log(JSON.stringify({ + status: 'success', + engine: 'pglite', + path: dbPath, + pages: stats.page_count, + embedding: config.embedding, + })); } else { console.log(`\nBrain ready at ${dbPath}`); console.log(`${stats.page_count} pages. Engine: PGLite (local Postgres).`); @@ -89,7 +154,13 @@ async function initPGLite(opts: { jsonOutput: boolean; apiKey: string | null; cu } } -async function initPostgres(opts: { databaseUrl: string; jsonOutput: boolean; apiKey: string | null }) { +async function initPostgres(opts: { + databaseUrl: string; + jsonOutput: boolean; + apiKey: string | null; + provider: EmbeddingProvider; + providerResolved: ProviderConfig; +}) { const { databaseUrl } = opts; // Detect Supabase direct connection URLs and warn about IPv6 @@ -137,11 +208,18 @@ async function initPostgres(opts: { databaseUrl: string; jsonOutput: boolean; ap } console.log('Running schema migration...'); - await engine.initSchema(); + console.log(`Embedding: ${opts.provider.name} / ${opts.provider.model} (${opts.provider.dimensions}d)`); + await engine.initSchema({ dimensions: opts.provider.dimensions, defaultModel: opts.provider.model }); const config: GBrainConfig = { engine: 'postgres', database_url: databaseUrl, + embedding: { + provider: opts.provider.name, + model: opts.provider.model, + dimensions: opts.provider.dimensions, + ...(opts.providerResolved.baseUrl ? { base_url: opts.providerResolved.baseUrl } : {}), + }, ...(opts.apiKey ? { openai_api_key: opts.apiKey } : {}), }; saveConfig(config); @@ -151,7 +229,12 @@ async function initPostgres(opts: { databaseUrl: string; jsonOutput: boolean; ap await engine.disconnect(); if (opts.jsonOutput) { - console.log(JSON.stringify({ status: 'success', engine: 'postgres', pages: stats.page_count })); + console.log(JSON.stringify({ + status: 'success', + engine: 'postgres', + pages: stats.page_count, + embedding: config.embedding, + })); } else { console.log(`\nBrain ready. ${stats.page_count} pages. Engine: Postgres (Supabase).`); console.log('Next: gbrain import '); diff --git a/src/core/config.ts b/src/core/config.ts index dcc7a14b..811515e2 100644 --- a/src/core/config.ts +++ b/src/core/config.ts @@ -13,6 +13,23 @@ export interface GBrainConfig { database_path?: string; openai_api_key?: string; anthropic_api_key?: string; + /** + * Embedding provider config, persisted at `gbrain init` and frozen for the + * brain's life. Presence indicates a provider was chosen explicitly; absence + * means legacy behavior (OpenAI text-embedding-3-large 1536d via env vars). + */ + embedding?: EmbeddingConfig; +} + +export interface EmbeddingConfig { + /** Provider name. Currently 'openai' or 'ollama'. */ + provider: string; + /** Model identifier. */ + model: string; + /** Output vector dimension — MUST match the pgvector schema column. */ + dimensions: number; + /** Optional base URL override for OpenAI-compatible endpoints. 
*/
+  base_url?: string;
 }
 
 /**
diff --git a/src/core/db.ts b/src/core/db.ts
index 2edaa811..1244ed64 100644
--- a/src/core/db.ts
+++ b/src/core/db.ts
@@ -1,6 +1,6 @@
 import postgres from 'postgres';
 import { GBrainError, type EngineConfig } from './types.ts';
-import { SCHEMA_SQL } from './schema-embedded.ts';
+import { postgresSchema } from './schema-embedded.ts';
 
 let sql: ReturnType<typeof postgres> | null = null;
 let connectedUrl: string | null = null;
@@ -68,12 +68,12 @@ export async function disconnect(): Promise<void> {
   }
 }
 
-export async function initSchema(): Promise<void> {
+export async function initSchema(opts?: { dimensions?: number; defaultModel?: string }): Promise<void> {
   const conn = getConnection();
   // Advisory lock prevents concurrent initSchema() calls from deadlocking
   await conn`SELECT pg_advisory_lock(42)`;
   try {
-    await conn.unsafe(SCHEMA_SQL);
+    await conn.unsafe(postgresSchema(opts));
   } finally {
     await conn`SELECT pg_advisory_unlock(42)`;
   }
diff --git a/src/core/engine.ts b/src/core/engine.ts
index 63abf3e3..1237ce0e 100644
--- a/src/core/engine.ts
+++ b/src/core/engine.ts
@@ -25,7 +25,7 @@ export interface BrainEngine {
   // Lifecycle
   connect(config: EngineConfig): Promise<void>;
   disconnect(): Promise<void>;
-  initSchema(): Promise<void>;
+  initSchema(opts?: { dimensions?: number; defaultModel?: string }): Promise<void>;
   transaction<T>(fn: (engine: BrainEngine) => Promise<T>): Promise<T>;
 
   // Pages CRUD
diff --git a/src/core/pglite-engine.ts b/src/core/pglite-engine.ts
index cc1ca310..d779ec5f 100644
--- a/src/core/pglite-engine.ts
+++ b/src/core/pglite-engine.ts
@@ -5,7 +5,7 @@ import type { Transaction } from '@electric-sql/pglite';
 import type { BrainEngine } from './engine.ts';
 import { MAX_SEARCH_LIMIT, clampSearchLimit } from './engine.ts';
 import { runMigrations } from './migrate.ts';
-import { PGLITE_SCHEMA_SQL } from './pglite-schema.ts';
+import { pgliteSchema } from './pglite-schema.ts';
 import { acquireLock, releaseLock, type LockHandle } from './pglite-lock.ts';
 import type {
   Page, PageInput, PageFilters, PageType,
@@ -60,8 +60,8 @@ export class PGLiteEngine implements BrainEngine {
     }
   }
 
-  async initSchema(): Promise<void> {
-    await this.db.exec(PGLITE_SCHEMA_SQL);
+  async initSchema(opts?: { dimensions?: number; defaultModel?: string }): Promise<void> {
+    await this.db.exec(pgliteSchema(opts));
 
     const { applied } = await runMigrations(this);
     if (applied > 0) {
diff --git a/src/core/pglite-schema.ts b/src/core/pglite-schema.ts
index 13fad56d..c751f0be 100644
--- a/src/core/pglite-schema.ts
+++ b/src/core/pglite-schema.ts
@@ -1,6 +1,10 @@
 /**
  * PGLite schema — derived from schema-embedded.ts (Postgres schema).
  *
+ * The schema is templated by embedding dimensions and default model so the brain
+ * can be initialized for any provider (OpenAI 1536d, Ollama nomic 768d, etc.)
+ * without editing this file.
+ *
  * Differences from Postgres:
  * - No RLS block (no role system in embedded PGLite)
  * - No access_tokens / mcp_request_log (local-only, no remote auth)
@@ -13,7 +17,21 @@
  * test/edge-bundle.test.ts has a drift detection test.
  */
 
-export const PGLITE_SCHEMA_SQL = `
+export interface SchemaOpts {
+  /** pgvector column dimension. Defaults to 1536 (OpenAI text-embedding-3-large). */
+  dimensions?: number;
+  /** Default model string written into the `model` column and config rows. */
+  defaultModel?: string;
+}
+
+const DEFAULT_DIMENSIONS = 1536;
+const DEFAULT_MODEL = 'text-embedding-3-large';
+
+export function pgliteSchema(opts: SchemaOpts = {}): string {
+  const dims = opts.dimensions ?? DEFAULT_DIMENSIONS;
+  const model = opts.defaultModel ?? DEFAULT_MODEL;
+
+  return `
 -- GBrain PGLite schema (local embedded Postgres)
 
 CREATE EXTENSION IF NOT EXISTS vector;
@@ -48,8 +66,8 @@ CREATE TABLE IF NOT EXISTS content_chunks (
   chunk_index INTEGER NOT NULL,
   chunk_text TEXT NOT NULL,
   chunk_source TEXT NOT NULL DEFAULT 'compiled_truth',
-  embedding vector(1536),
-  model TEXT NOT NULL DEFAULT 'text-embedding-3-large',
+  embedding vector(${dims}),
+  model TEXT NOT NULL DEFAULT '${model}',
   token_count INTEGER,
   embedded_at TIMESTAMPTZ,
   created_at TIMESTAMPTZ NOT NULL DEFAULT now()
@@ -154,8 +172,8 @@ CREATE TABLE IF NOT EXISTS config (
 INSERT INTO config (key, value) VALUES
   ('version', '1'),
   ('engine', 'pglite'),
-  ('embedding_model', 'text-embedding-3-large'),
-  ('embedding_dimensions', '1536'),
+  ('embedding_model', '${model}'),
+  ('embedding_dimensions', '${dims}'),
   ('chunk_strategy', 'semantic')
 ON CONFLICT (key) DO NOTHING;
@@ -207,3 +225,11 @@ CREATE TRIGGER trg_timeline_search_vector
   FOR EACH ROW EXECUTE FUNCTION update_page_search_vector_from_timeline();
 `;
+}
+
+/**
+ * Backward-compat constant alias. Evaluates `pgliteSchema()` with defaults
+ * (OpenAI text-embedding-3-large at 1536 dimensions) — same SQL as before the
+ * schema-templating change.
+ */
+export const PGLITE_SCHEMA_SQL = pgliteSchema();
diff --git a/src/core/postgres-engine.ts b/src/core/postgres-engine.ts
index dc536c73..341d4aba 100644
--- a/src/core/postgres-engine.ts
+++ b/src/core/postgres-engine.ts
@@ -2,7 +2,7 @@ import postgres from 'postgres';
 import type { BrainEngine } from './engine.ts';
 import { MAX_SEARCH_LIMIT, clampSearchLimit } from './engine.ts';
 import { runMigrations } from './migrate.ts';
-import { SCHEMA_SQL } from './schema-embedded.ts';
+import { postgresSchema } from './schema-embedded.ts';
 import type {
   Page, PageInput, PageFilters,
   Chunk, ChunkInput,
@@ -56,13 +56,13 @@ export class PostgresEngine implements BrainEngine {
     }
   }
 
-  async initSchema(): Promise<void> {
+  async initSchema(opts?: { dimensions?: number; defaultModel?: string }): Promise<void> {
     const conn = this.sql;
     // Advisory lock prevents concurrent initSchema() calls from deadlocking
     // on DDL statements (DROP TRIGGER + CREATE TRIGGER acquire AccessExclusiveLock)
     await conn`SELECT pg_advisory_lock(42)`;
     try {
-      await conn.unsafe(SCHEMA_SQL);
+      await conn.unsafe(postgresSchema(opts));
 
       // Run any pending migrations automatically
       const { applied } = await runMigrations(this);
diff --git a/src/core/schema-embedded.ts b/src/core/schema-embedded.ts
index eb0759e8..1de65b5b 100644
--- a/src/core/schema-embedded.ts
+++ b/src/core/schema-embedded.ts
@@ -1,7 +1,24 @@
 // AUTO-GENERATED — do not edit. Run: bun run build:schema
 // Source: src/schema.sql
+//
+// Schema templated by embedding dimensions and default model so the brain
+// can be initialized for any provider (OpenAI 1536d, Ollama nomic 768d, etc.).
 
-export const SCHEMA_SQL = `
+export interface SchemaOpts {
+  /** pgvector column dimension. Defaults to 1536 (OpenAI text-embedding-3-large). */
+  dimensions?: number;
+  /** Default model string written into the `model` column and config rows. */
+  defaultModel?: string;
+}
+
+const DEFAULT_DIMENSIONS = 1536;
+const DEFAULT_MODEL = 'text-embedding-3-large';
+
+export function postgresSchema(opts: SchemaOpts = {}): string {
+  const dims = opts.dimensions ?? DEFAULT_DIMENSIONS;
+  const model = opts.defaultModel ??
DEFAULT_MODEL; + + return ` -- GBrain Postgres + pgvector schema CREATE EXTENSION IF NOT EXISTS vector; @@ -36,8 +53,8 @@ CREATE TABLE IF NOT EXISTS content_chunks ( chunk_index INTEGER NOT NULL, chunk_text TEXT NOT NULL, chunk_source TEXT NOT NULL DEFAULT 'compiled_truth', - embedding vector(1536), - model TEXT NOT NULL DEFAULT 'text-embedding-3-large', + embedding vector(${dims}), + model TEXT NOT NULL DEFAULT '${model}', token_count INTEGER, embedded_at TIMESTAMPTZ, created_at TIMESTAMPTZ NOT NULL DEFAULT now() @@ -141,8 +158,8 @@ CREATE TABLE IF NOT EXISTS config ( INSERT INTO config (key, value) VALUES ('version', '1'), - ('embedding_model', 'text-embedding-3-large'), - ('embedding_dimensions', '1536'), + ('embedding_model', '${model}'), + ('embedding_dimensions', '${dims}'), ('chunk_strategy', 'semantic') ON CONFLICT (key) DO NOTHING; @@ -277,3 +294,11 @@ BEGIN END IF; END \$\$; `; +} + +/** + * Backward-compat constant alias. Evaluates `postgresSchema()` with defaults + * (OpenAI text-embedding-3-large at 1536 dimensions) — same SQL as before the + * schema-templating change. + */ +export const SCHEMA_SQL = postgresSchema(); diff --git a/test/schema-templating.test.ts b/test/schema-templating.test.ts new file mode 100644 index 00000000..ef0d12a8 --- /dev/null +++ b/test/schema-templating.test.ts @@ -0,0 +1,94 @@ +import { describe, test, expect } from 'bun:test'; +import { pgliteSchema, PGLITE_SCHEMA_SQL } from '../src/core/pglite-schema.ts'; +import { postgresSchema, SCHEMA_SQL } from '../src/core/schema-embedded.ts'; + +/** + * Schema templating must: + * 1) default to (1536, 'text-embedding-3-large') — backward compat + * 2) substitute vector(dim) + DEFAULT 'model' when opts given + * 3) not leak template placeholder strings into the SQL output + * 4) keep the const aliases identical to the default-function output + */ + +describe('pgliteSchema', () => { + test('defaults to 1536d + text-embedding-3-large', () => { + const sql = pgliteSchema(); + expect(sql).toContain('vector(1536)'); + expect(sql).toContain("DEFAULT 'text-embedding-3-large'"); + expect(sql).toContain("('embedding_model', 'text-embedding-3-large')"); + expect(sql).toContain("('embedding_dimensions', '1536')"); + }); + + test('templates to Ollama dims + model when opts given', () => { + const sql = pgliteSchema({ dimensions: 768, defaultModel: 'nomic-embed-text' }); + expect(sql).toContain('vector(768)'); + expect(sql).not.toContain('vector(1536)'); + expect(sql).toContain("DEFAULT 'nomic-embed-text'"); + expect(sql).not.toContain("DEFAULT 'text-embedding-3-large'"); + expect(sql).toContain("('embedding_model', 'nomic-embed-text')"); + expect(sql).toContain("('embedding_dimensions', '768')"); + }); + + test('no template placeholder strings leak into output', () => { + const sql = pgliteSchema({ dimensions: 1024, defaultModel: 'mxbai-embed-large' }); + expect(sql).not.toContain('${'); + expect(sql).not.toContain('undefined'); + expect(sql).not.toContain('null'); + }); + + test('const alias matches default function output', () => { + expect(PGLITE_SCHEMA_SQL).toBe(pgliteSchema()); + }); + + test('partial opts — only dimensions', () => { + const sql = pgliteSchema({ dimensions: 3072 }); + expect(sql).toContain('vector(3072)'); + expect(sql).toContain("DEFAULT 'text-embedding-3-large'"); // falls back to default model + }); + + test('partial opts — only model', () => { + const sql = pgliteSchema({ defaultModel: 'voyage-3' }); + expect(sql).toContain('vector(1536)'); // falls back to default dim + 
+    expect(sql).toContain("DEFAULT 'voyage-3'");
+  });
+});
+
+describe('postgresSchema', () => {
+  test('defaults to 1536d + text-embedding-3-large', () => {
+    const sql = postgresSchema();
+    expect(sql).toContain('vector(1536)');
+    expect(sql).toContain("DEFAULT 'text-embedding-3-large'");
+  });
+
+  test('templates to Ollama dims + model', () => {
+    const sql = postgresSchema({ dimensions: 768, defaultModel: 'nomic-embed-text' });
+    expect(sql).toContain('vector(768)');
+    expect(sql).toContain("DEFAULT 'nomic-embed-text'");
+    expect(sql).toContain("('embedding_model', 'nomic-embed-text')");
+    expect(sql).toContain("('embedding_dimensions', '768')");
+  });
+
+  test('preserves Postgres dollar-quoted functions after templating', () => {
+    const sql = postgresSchema({ dimensions: 768, defaultModel: 'nomic-embed-text' });
+    // Dollar-quoted plpgsql function bodies must survive the template (they use $$ markers)
+    expect(sql).toContain('CREATE OR REPLACE FUNCTION update_page_search_vector()');
+    expect(sql).toContain('LANGUAGE plpgsql');
+    // Two function definitions
+    expect((sql.match(/\$\$ LANGUAGE plpgsql/g) || []).length).toBe(2);
+  });
+
+  test('const alias matches default function output', () => {
+    expect(SCHEMA_SQL).toBe(postgresSchema());
+  });
+});
+
+describe('schema drift between PGLite and Postgres', () => {
+  test('both schemas have matching embedding dim when called with same opts', () => {
+    const p = pgliteSchema({ dimensions: 768, defaultModel: 'nomic-embed-text' });
+    const g = postgresSchema({ dimensions: 768, defaultModel: 'nomic-embed-text' });
+    expect(p).toContain('vector(768)');
+    expect(g).toContain('vector(768)');
+    expect(p).toContain("DEFAULT 'nomic-embed-text'");
+    expect(g).toContain("DEFAULT 'nomic-embed-text'");
+  });
+});

From 7081e12ef7c48fe0712acde206b10619d72479f8 Mon Sep 17 00:00:00 2001
From: yang
Date: Thu, 16 Apr 2026 23:08:31 -0400
Subject: [PATCH 3/4] fix(init): support --flag=value form for embedding flags

`gbrain init --provider=ollama --model=nomic-embed-text --base-url=http://...`
was silently falling through to defaults because parseEmbeddingFlags only
handled `--flag value` (space-separated) form. Supporting both forms is
standard CLI behavior.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/cli.ts           | 0
 src/commands/init.ts | 7 +++++++
 2 files changed, 7 insertions(+)
 mode change 100644 => 100755 src/cli.ts

diff --git a/src/cli.ts b/src/cli.ts
old mode 100644
new mode 100755
diff --git a/src/commands/init.ts b/src/commands/init.ts
index 68317f4e..e1e7aa93 100644
--- a/src/commands/init.ts
+++ b/src/commands/init.ts
@@ -13,10 +13,17 @@ import type { EmbeddingProvider, ProviderConfig } from '../core/embedding/index.
 
 /**
  * Parse --provider / --model / --dimensions / --base-url flags.
+ * Supports both `--flag value` (space) and `--flag=value` (equals) forms.
  * Falls back to EMBEDDING_* env vars (handled inside resolveEmbeddingConfig).
  */
 function parseEmbeddingFlags(args: string[]): Partial<ProviderConfig> {
   const flag = (name: string): string | undefined => {
+    // `--flag=value` form
+    const prefix = name + '=';
+    for (const a of args) {
+      if (a.startsWith(prefix)) return a.slice(prefix.length);
+    }
+    // `--flag value` form
     const i = args.indexOf(name);
     return i !== -1 ? args[i + 1] : undefined;
   };

From 41469e15318a10455ebcde71913e74056ab37a00 Mon Sep 17 00:00:00 2001
From: yang
Date: Thu, 16 Apr 2026 23:10:25 -0400
Subject: [PATCH 4/4] fix(cli): hydrate EmbeddingProvider from config on engine connect

connectEngine() only loaded the database config.
Commands that trigger embedding (embed, import, query, search) fell back to the
service's default provider (OpenAI) regardless of what the brain was initialized
with, causing 401s when the brain was configured for Ollama.

Now connectEngine reads config.embedding, builds the matching provider, and
installs it via setProvider before any command runs.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 src/cli.ts | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/cli.ts b/src/cli.ts
index 55c1dab5..ec92f941 100755
--- a/src/cli.ts
+++ b/src/cli.ts
@@ -390,6 +390,20 @@ async function connectEngine(): Promise<BrainEngine> {
     console.error('No brain configured. Run: gbrain init');
     process.exit(1);
   }
+
+  // Hydrate the embedding provider from the brain's persisted config so all
+  // commands (embed, import, query) use the provider the brain was initialized
+  // with — not whatever EMBEDDING_* env vars happen to be set.
+  if (config.embedding) {
+    const { createProvider, setProvider } = await import('./core/embedding/index.ts');
+    setProvider(createProvider({
+      provider: config.embedding.provider,
+      model: config.embedding.model,
+      dimensions: config.embedding.dimensions,
+      baseUrl: config.embedding.base_url,
+    }));
+  }
+
   const { createEngine } = await import('./core/engine-factory.ts');
   const engine = await createEngine(toEngineConfig(config));
   await engine.connect(toEngineConfig(config));
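
For reference, a minimal sketch of how downstream code would consume the provider
layer this series adds — illustrative only, not part of the patches; the import
path and sample texts are assumptions:

    import { createProvider, setProvider, embedBatch } from './src/core/embedding/index.ts';

    // Install a local Ollama provider (768d, inferred from KNOWN_DIMENSIONS),
    // then embed through the provider-agnostic batching/retry path from patch 1.
    setProvider(createProvider({ provider: 'ollama', model: 'nomic-embed-text' }));
    const vectors = await embedBatch(['first note', 'second note']);
    console.log(vectors.length, vectors[0].length); // 2 768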