diff --git a/README.md b/README.md index d61cbf9c..1b8a38e7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ An on-device search engine for everything you need to remember. Index your markdown notes, meeting transcripts, documentation, and knowledge bases. Search with keywords or natural language. Ideal for your agentic flows. -QMD combines BM25 full-text search, vector semantic search, and LLM re-ranking—all running locally via node-llama-cpp with GGUF models. +QMD combines BM25 full-text search, vector semantic search, and LLM re-ranking. By default it runs locally via node-llama-cpp with GGUF models, with optional OpenRouter-backed inference. ![QMD Architecture](assets/qmd-architecture.png) @@ -280,6 +280,36 @@ Supported model families: > since vectors are not cross-compatible between models. The prompt format is > automatically adjusted for each model family. +### OpenRouter Mode (Optional) + +QMD defaults to `local` inference. To use OpenRouter for embeddings, query expansion, and reranking: + +```sh +export QMD_LLM_PROVIDER=openrouter +export QMD_OPENROUTER_API_KEY="sk-or-..." +``` + +When OpenRouter mode is active, QMD prints a single remote-inference notice per process. + +You can also store the key in a file (default path): + +```sh +mkdir -p ~/.config/qmd +chmod 700 ~/.config/qmd +printf '%s\n' 'sk-or-...' 
> ~/.config/qmd/openrouter.key +chmod 600 ~/.config/qmd/openrouter.key +export QMD_LLM_PROVIDER=openrouter +``` + +Supported key env vars: +- `QMD_OPENROUTER_API_KEY` (preferred) +- `OPENROUTER_API_KEY` +- `QMD_OPENROUTER_API_KEY_FILE` (custom key file path) + +Optional model overrides: +- `QMD_OPENROUTER_EMBED_MODEL` +- `QMD_OPENROUTER_GENERATE_MODEL` +- `QMD_OPENROUTER_RERANK_MODEL` ## Installation ```sh diff --git a/src/llm.openrouter.test.ts b/src/llm.openrouter.test.ts new file mode 100644 index 00000000..f56c06c1 --- /dev/null +++ b/src/llm.openrouter.test.ts @@ -0,0 +1,192 @@ +import { describe, test, expect, beforeEach, afterEach } from "bun:test"; +import { mkdtemp, rm, writeFile } from "node:fs/promises"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { + OpenRouterLLM, + getDefaultLLM, + getDefaultLLMProvider, + disposeDefaultLLM, + resetDefaultLLMForTests, +} from "./llm.js"; + +function jsonResponse(body: unknown, status: number = 200): Response { + return new Response(JSON.stringify(body), { + status, + headers: { "content-type": "application/json" }, + }); +} + +const originalFetch = globalThis.fetch; + +describe("OpenRouter provider", () => { + beforeEach(() => { + delete process.env.QMD_LLM_PROVIDER; + delete process.env.QMD_OPENROUTER_API_KEY; + delete process.env.OPENROUTER_API_KEY; + delete process.env.QMD_OPENROUTER_API_KEY_FILE; + delete process.env.QMD_OPENROUTER_BASE_URL; + delete process.env.QMD_OPENROUTER_EMBED_MODEL; + delete process.env.QMD_OPENROUTER_GENERATE_MODEL; + delete process.env.QMD_OPENROUTER_RERANK_MODEL; + resetDefaultLLMForTests(); + }); + + afterEach(async () => { + globalThis.fetch = originalFetch; + await disposeDefaultLLM(); + resetDefaultLLMForTests(); + }); + + test("uses OpenRouter provider when QMD_LLM_PROVIDER=openrouter", () => { + process.env.QMD_LLM_PROVIDER = "openrouter"; + process.env.QMD_OPENROUTER_API_KEY = "test-key"; + + const llm = getDefaultLLM(); + + 
expect(getDefaultLLMProvider()).toBe("openrouter"); + expect(llm).toBeInstanceOf(OpenRouterLLM); + }); + + test("prints remote notice only once per process", () => { + process.env.QMD_LLM_PROVIDER = "openrouter"; + process.env.QMD_OPENROUTER_API_KEY = "test-key"; + + const stderrAny = process.stderr as any; + const originalWrite = stderrAny.write; + const writes: string[] = []; + stderrAny.write = (chunk: any) => { + writes.push(String(chunk)); + return true; + }; + + try { + getDefaultLLM(); + getDefaultLLM(); + } finally { + stderrAny.write = originalWrite; + } + + const notices = writes.filter(line => line.includes("OpenRouter")); + expect(notices.length).toBe(1); + }); + + test("embed sends OpenRouter embeddings request", async () => { + const calls: Array<{ url: string; body: any; headers: Record }> = []; + globalThis.fetch = async (url: string | URL | Request, init?: RequestInit): Promise => { + const parsedBody = JSON.parse(String(init?.body || "{}")); + calls.push({ + url: String(url), + body: parsedBody, + headers: init?.headers as Record, + }); + return jsonResponse({ + data: [{ index: 0, embedding: [0.5, 0.25, -0.1] }], + }); + }; + + const llm = new OpenRouterLLM({ + apiKey: "abc123", + baseUrl: "https://openrouter.ai/api/v1", + embedModel: "openai/text-embedding-3-small", + }); + + const result = await llm.embed("hello world"); + + expect(result).not.toBeNull(); + expect(result!.embedding).toEqual([0.5, 0.25, -0.1]); + expect(calls).toHaveLength(1); + expect(calls[0]!.url).toBe("https://openrouter.ai/api/v1/embeddings"); + expect(calls[0]!.body.model).toBe("openai/text-embedding-3-small"); + expect(calls[0]!.body.input).toBe("hello world"); + expect(calls[0]!.headers.Authorization).toBe("Bearer abc123"); + }); + + test("embedBatch maps embeddings by index order", async () => { + globalThis.fetch = async () => jsonResponse({ + data: [ + { index: 1, embedding: [0, 1] }, + { index: 0, embedding: [1, 0] }, + ], + }); + + const llm = new OpenRouterLLM({ 
apiKey: "test-key" }); + const results = await llm.embedBatch(["first", "second"]); + + expect(results).toHaveLength(2); + expect(results[0]!.embedding).toEqual([1, 0]); + expect(results[1]!.embedding).toEqual([0, 1]); + }); + + test("expandQuery parses typed query lines and filters lexical when disabled", async () => { + globalThis.fetch = async () => jsonResponse({ + choices: [ + { + message: { + content: "lex: deploy auth service\nvec: deploy authentication stack\nhyde: documentation for deploying auth service", + }, + }, + ], + }); + + const llm = new OpenRouterLLM({ apiKey: "test-key" }); + const queryables = await llm.expandQuery("deploy auth service", { includeLexical: false }); + + expect(queryables.some(q => q.type === "lex")).toBe(false); + expect(queryables.some(q => q.type === "vec")).toBe(true); + expect(queryables.some(q => q.type === "hyde")).toBe(true); + }); + + test("rerank uses embedding similarity and sorts descending", async () => { + globalThis.fetch = async (_url: string | URL | Request, init?: RequestInit): Promise => { + const body = JSON.parse(String(init?.body || "{}")); + + if (typeof body.input === "string") { + return jsonResponse({ + data: [{ index: 0, embedding: [1, 0] }], + }); + } + + return jsonResponse({ + data: [ + { index: 0, embedding: [0.9, 0] }, + { index: 1, embedding: [0, 1] }, + ], + }); + }; + + const llm = new OpenRouterLLM({ apiKey: "test-key", rerankModel: "openai/text-embedding-3-small" }); + const reranked = await llm.rerank("auth query", [ + { file: "a.md", text: "authentication docs" }, + { file: "b.md", text: "gardening notes" }, + ]); + + expect(reranked.results).toHaveLength(2); + expect(reranked.results[0]!.file).toBe("a.md"); + expect(reranked.results[0]!.score).toBeGreaterThan(reranked.results[1]!.score); + }); + + test("loads API key from file when env var is not set", async () => { + const tempDir = await mkdtemp(join(tmpdir(), "qmd-openrouter-test-")); + const keyFile = join(tempDir, "openrouter.key"); + 
await writeFile(keyFile, "file-key-123\n", "utf-8"); + + try { + const authHeaders: string[] = []; + globalThis.fetch = async (_url: string | URL | Request, init?: RequestInit): Promise => { + const headers = init?.headers as Record; + authHeaders.push(headers.Authorization || ""); + return jsonResponse({ + data: [{ index: 0, embedding: [0.1, 0.2] }], + }); + }; + + const llm = new OpenRouterLLM({ apiKeyFile: keyFile }); + await llm.embed("test"); + + expect(authHeaders[0]).toBe("Bearer file-key-123"); + } finally { + await rm(tempDir, { recursive: true, force: true }); + } + }); +}); diff --git a/src/llm.ts b/src/llm.ts index 100a1ec7..80de81c4 100644 --- a/src/llm.ts +++ b/src/llm.ts @@ -1,7 +1,8 @@ /** - * llm.ts - LLM abstraction layer for QMD using node-llama-cpp + * llm.ts - LLM abstraction layer for QMD * - * Provides embeddings, text generation, and reranking using local GGUF models. + * Provides embeddings, text generation, and reranking using local GGUF models + * or OpenRouter-hosted models. */ import { @@ -186,6 +187,11 @@ export type RerankDocument = { title?: string; }; +/** + * Backing inference provider + */ +export type LLMProvider = "local" | "openrouter"; + // ============================================================================= // Model Configuration // ============================================================================= @@ -197,6 +203,13 @@ const DEFAULT_EMBED_MODEL = process.env.QMD_EMBED_MODEL ?? 
"hf:ggml-org/embeddin const DEFAULT_RERANK_MODEL = "hf:ggml-org/Qwen3-Reranker-0.6B-Q8_0-GGUF/qwen3-reranker-0.6b-q8_0.gguf"; // const DEFAULT_GENERATE_MODEL = "hf:ggml-org/Qwen3-0.6B-GGUF/Qwen3-0.6B-Q8_0.gguf"; const DEFAULT_GENERATE_MODEL = "hf:tobil/qmd-query-expansion-1.7B-gguf/qmd-query-expansion-1.7B-q4_k_m.gguf"; +const DEFAULT_PROVIDER: LLMProvider = "local"; + +const DEFAULT_OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1"; +const DEFAULT_OPENROUTER_EMBED_MODEL = "openai/text-embedding-3-small"; +const DEFAULT_OPENROUTER_GENERATE_MODEL = "openai/gpt-4o-mini"; +const DEFAULT_OPENROUTER_RERANK_MODEL = "openai/text-embedding-3-small"; +const DEFAULT_OPENROUTER_API_KEY_FILE = join(homedir(), ".config", "qmd", "openrouter.key"); // Alternative generation models for query expansion: // LiquidAI LFM2 - hybrid architecture optimized for edge/on-device inference @@ -207,6 +220,8 @@ export const LFM2_INSTRUCT_MODEL = "hf:LiquidAI/LFM2.5-1.2B-Instruct-GGUF/LFM2.5 export const DEFAULT_EMBED_MODEL_URI = DEFAULT_EMBED_MODEL; export const DEFAULT_RERANK_MODEL_URI = DEFAULT_RERANK_MODEL; export const DEFAULT_GENERATE_MODEL_URI = DEFAULT_GENERATE_MODEL; +export const DEFAULT_OPENROUTER_BASE_URL_URI = DEFAULT_OPENROUTER_BASE_URL; +export const DEFAULT_OPENROUTER_API_KEY_PATH = DEFAULT_OPENROUTER_API_KEY_FILE; // Local model cache directory const MODEL_CACHE_DIR = join(homedir(), ".cache", "qmd", "models"); @@ -304,6 +319,93 @@ export async function pullModels( return results; } +function normalizeProvider(provider: string | undefined): LLMProvider { + const value = (provider || DEFAULT_PROVIDER).trim().toLowerCase(); + if (value === "local" || value === "openrouter") { + return value; + } + console.error(`Unknown QMD_LLM_PROVIDER="${provider}". Falling back to "local".`); + return "local"; +} + +function stripTrailingSlash(url: string): string { + return url.endsWith("/") ? 
url.slice(0, -1) : url; +} + +function trimSingleLine(value: string): string { + const firstLine = value.split(/\r?\n/, 1)[0] ?? ""; + return firstLine.trim(); +} + +function loadOpenRouterApiKey(config: { apiKey?: string; apiKeyFile?: string } = {}): string { + const directKey = config.apiKey?.trim(); + if (directKey) return directKey; + + const envKey = process.env.QMD_OPENROUTER_API_KEY?.trim() || process.env.OPENROUTER_API_KEY?.trim(); + if (envKey) return envKey; + + const keyFile = config.apiKeyFile || process.env.QMD_OPENROUTER_API_KEY_FILE || DEFAULT_OPENROUTER_API_KEY_FILE; + if (existsSync(keyFile)) { + const fileKey = trimSingleLine(readFileSync(keyFile, "utf-8")); + if (fileKey) return fileKey; + throw new Error(`OpenRouter API key file exists but is empty: ${keyFile}`); + } + + throw new Error( + `OpenRouter API key missing. Set QMD_OPENROUTER_API_KEY (or OPENROUTER_API_KEY), ` + + `or write the key to ${keyFile}` + ); +} + +function parseExpandedQueryLines(raw: string, query: string, includeLexical: boolean): Queryable[] { + const lines = raw.trim().split("\n").map(line => line.trim()).filter(Boolean); + const queryLower = query.toLowerCase(); + const queryTerms = queryLower.replace(/[^a-z0-9\s]/g, " ").split(/\s+/).filter(Boolean); + + const hasQueryTerm = (text: string): boolean => { + const lower = text.toLowerCase(); + if (queryTerms.length === 0) return true; + return queryTerms.some(term => lower.includes(term)); + }; + + const queryables: Queryable[] = lines.map(line => { + const colonIdx = line.indexOf(":"); + if (colonIdx === -1) return null; + const type = line.slice(0, colonIdx).trim(); + if (type !== "lex" && type !== "vec" && type !== "hyde") return null; + const text = line.slice(colonIdx + 1).trim(); + if (!hasQueryTerm(text)) return null; + return { type: type as QueryType, text }; + }).filter((q): q is Queryable => q !== null); + + const filtered = includeLexical ? 
queryables : queryables.filter(q => q.type !== "lex"); + if (filtered.length > 0) return filtered; + + const fallback: Queryable[] = [ + { type: "hyde", text: `Information about ${query}` }, + { type: "lex", text: query }, + { type: "vec", text: query }, + ]; + return includeLexical ? fallback : fallback.filter(q => q.type !== "lex"); +} + +function cosineSimilarity(a: number[], b: number[]): number { + if (a.length === 0 || b.length === 0 || a.length !== b.length) return 0; + + let dot = 0; + let normA = 0; + let normB = 0; + for (let i = 0; i < a.length; i++) { + const va = a[i] ?? 0; + const vb = b[i] ?? 0; + dot += va * vb; + normA += va * va; + normB += vb * vb; + } + if (normA === 0 || normB === 0) return 0; + return dot / (Math.sqrt(normA) * Math.sqrt(normB)); +} + // ============================================================================= // LLM Interface // ============================================================================= @@ -317,6 +419,11 @@ export interface LLM { */ embed(text: string, options?: EmbedOptions): Promise; + /** + * Batch embed multiple texts + */ + embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]>; + /** * Generate text completion */ @@ -345,6 +452,253 @@ export interface LLM { dispose(): Promise; } +// ============================================================================= +// OpenRouter Implementation +// ============================================================================= + +type OpenRouterEmbeddingResponse = { + data?: Array<{ + embedding?: number[]; + index?: number; + }>; +}; + +type OpenRouterChatResponse = { + choices?: Array<{ + message?: { + content?: string | Array<{ text?: string }>; + }; + }>; +}; + +export type OpenRouterConfig = { + apiKey?: string; + apiKeyFile?: string; + baseUrl?: string; + embedModel?: string; + generateModel?: string; + rerankModel?: string; + requestTimeoutMs?: number; +}; + +export class OpenRouterLLM implements LLM { + private apiKey: string; + private 
baseUrl: string; + private embedModelUri: string; + private generateModelUri: string; + private rerankModelUri: string; + private requestTimeoutMs: number; + + constructor(config: OpenRouterConfig = {}) { + this.apiKey = loadOpenRouterApiKey(config); + this.baseUrl = stripTrailingSlash(config.baseUrl || process.env.QMD_OPENROUTER_BASE_URL || DEFAULT_OPENROUTER_BASE_URL); + this.embedModelUri = config.embedModel || process.env.QMD_OPENROUTER_EMBED_MODEL || DEFAULT_OPENROUTER_EMBED_MODEL; + this.generateModelUri = config.generateModel || process.env.QMD_OPENROUTER_GENERATE_MODEL || DEFAULT_OPENROUTER_GENERATE_MODEL; + this.rerankModelUri = config.rerankModel || process.env.QMD_OPENROUTER_RERANK_MODEL || DEFAULT_OPENROUTER_RERANK_MODEL; + this.requestTimeoutMs = config.requestTimeoutMs ?? 60_000; + } + + private async postJson(path: string, payload: unknown): Promise { + const url = `${this.baseUrl}${path}`; + const controller = new AbortController(); + const timeout = setTimeout(() => controller.abort(), this.requestTimeoutMs); + + try { + const response = await fetch(url, { + method: "POST", + headers: { + Authorization: `Bearer ${this.apiKey}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(payload), + signal: controller.signal, + }); + + const raw = await response.text(); + if (!response.ok) { + const body = raw.slice(0, 500); + throw new Error(`OpenRouter ${path} failed (${response.status}): ${body}`); + } + + return JSON.parse(raw) as T; + } finally { + clearTimeout(timeout); + } + } + + private static contentToString(content: string | Array<{ text?: string }> | undefined): string { + if (typeof content === "string") return content; + if (Array.isArray(content)) { + return content + .map(part => (typeof part?.text === "string" ? 
part.text : "")) + .join(""); + } + return ""; + } + + private async requestEmbeddings(input: string | string[], model: string): Promise { + const response = await this.postJson("/embeddings", { + model, + input, + encoding_format: "float", + }); + + const rows = Array.isArray(response.data) ? [...response.data] : []; + rows.sort((a, b) => (a.index ?? 0) - (b.index ?? 0)); + + return rows.map((row) => { + if (!Array.isArray(row.embedding)) return []; + return row.embedding; + }); + } + + async embed(text: string, options: EmbedOptions = {}): Promise { + try { + const model = options.model || this.embedModelUri; + const vectors = await this.requestEmbeddings(text, model); + const vector = vectors[0]; + if (!vector || vector.length === 0) { + throw new Error("OpenRouter embedding response missing embedding vector"); + } + + return { embedding: vector, model }; + } catch (error) { + console.error("OpenRouter embedding error:", error); + return null; + } + } + + async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> { + if (texts.length === 0) return []; + try { + const data = await this.requestEmbeddings(texts, this.embedModelUri); + + return texts.map((_, i) => { + const vector = data[i]; + if (!vector || vector.length === 0) return null; + return { embedding: vector, model: this.embedModelUri }; + }); + } catch (error) { + console.error("OpenRouter batch embedding error:", error); + return texts.map(() => null); + } + } + + async generate(prompt: string, options: GenerateOptions = {}): Promise { + try { + const model = options.model || this.generateModelUri; + const response = await this.postJson("/chat/completions", { + model, + temperature: options.temperature ?? 0.7, + max_tokens: options.maxTokens ?? 
150, + messages: [ + { role: "user", content: prompt }, + ], + }); + + const content = OpenRouterLLM.contentToString(response.choices?.[0]?.message?.content); + if (!content) { + throw new Error("OpenRouter completion response missing content"); + } + + return { + text: content, + model, + done: true, + }; + } catch (error) { + console.error("OpenRouter generation error:", error); + return null; + } + } + + async modelExists(model: string): Promise { + return { name: model, exists: true }; + } + + async expandQuery(query: string, options: { context?: string; includeLexical?: boolean } = {}): Promise { + const includeLexical = options.includeLexical ?? true; + const context = options.context; + const contextBlock = context ? `Context: ${context}\n` : ""; + const lexicalRule = includeLexical + ? "You may use lex, vec, and hyde query types." + : "Use only vec and hyde query types (no lex entries)."; + + const prompt = [ + "Expand the search query into short retrieval variants.", + "Output only lines in this exact format: type: text", + "Allowed type values: lex, vec, hyde.", + lexicalRule, + "Keep at least one important term from the original query in each line.", + contextBlock, + `Original query: ${query}`, + ].filter(Boolean).join("\n"); + + try { + const response = await this.postJson("/chat/completions", { + model: this.generateModelUri, + temperature: 0.2, + max_tokens: 300, + messages: [{ role: "user", content: prompt }], + }); + + const content = OpenRouterLLM.contentToString(response.choices?.[0]?.message?.content); + return parseExpandedQueryLines(content, query, includeLexical); + } catch (error) { + console.error("OpenRouter query expansion error:", error); + const fallback: Queryable[] = [{ type: "vec", text: query }]; + if (includeLexical) fallback.unshift({ type: "lex", text: query }); + return fallback; + } + } + + async rerank( + query: string, + documents: RerankDocument[], + options: RerankOptions = {} + ): Promise { + if (documents.length === 0) { 
+ return { results: [], model: options.model || this.rerankModelUri }; + } + + try { + const model = options.model || this.rerankModelUri; + const queryVectors = await this.requestEmbeddings(query, model); + const queryEmbedding = queryVectors[0]; + if (!queryEmbedding || queryEmbedding.length === 0) { + throw new Error("Failed to embed rerank query"); + } + + const docEmbeddings = await this.requestEmbeddings(documents.map(doc => doc.text), model); + const scored: RerankDocumentResult[] = documents.map((doc, index) => { + const emb = docEmbeddings[index]; + const rawCosine = emb ? cosineSimilarity(queryEmbedding, emb) : 0; + return { + file: doc.file, + index, + score: (rawCosine + 1) / 2, // Normalize cosine [-1,1] -> [0,1] + }; + }); + + scored.sort((a, b) => b.score - a.score); + return { + results: scored, + model, + }; + } catch (error) { + console.error("OpenRouter rerank error:", error); + return { + results: documents.map((doc, index) => ({ file: doc.file, index, score: 0 })), + model: options.model || this.rerankModelUri, + }; + } + } + + async dispose(): Promise { + // No local resources to dispose. 
+ } +} + // ============================================================================= // node-llama-cpp Implementation // ============================================================================= @@ -840,7 +1194,7 @@ export class LlamaCpp implements LLM { const embedding = await context.getEmbeddingFor(text); return { - embedding: Array.from(embedding.vector), + embedding: Array.from(embedding.vector) as number[], model: this.embedModelUri, }; } catch (error) { @@ -1017,36 +1371,7 @@ export class LlamaCpp implements LLM { }, }); - const lines = result.trim().split("\n"); - const queryLower = query.toLowerCase(); - const queryTerms = queryLower.replace(/[^a-z0-9\s]/g, " ").split(/\s+/).filter(Boolean); - - const hasQueryTerm = (text: string): boolean => { - const lower = text.toLowerCase(); - if (queryTerms.length === 0) return true; - return queryTerms.some(term => lower.includes(term)); - }; - - const queryables: Queryable[] = lines.map(line => { - const colonIdx = line.indexOf(":"); - if (colonIdx === -1) return null; - const type = line.slice(0, colonIdx).trim(); - if (type !== 'lex' && type !== 'vec' && type !== 'hyde') return null; - const text = line.slice(colonIdx + 1).trim(); - if (!hasQueryTerm(text)) return null; - return { type: type as QueryType, text }; - }).filter((q): q is Queryable => q !== null); - - // Filter out lex entries if not requested - const filtered = includeLexical ? queryables : queryables.filter(q => q.type !== 'lex'); - if (filtered.length > 0) return filtered; - - const fallback: Queryable[] = [ - { type: 'hyde', text: `Information about ${query}` }, - { type: 'lex', text: query }, - { type: 'vec', text: query }, - ]; - return includeLexical ? 
fallback : fallback.filter(q => q.type !== 'lex'); + return parseExpandedQueryLines(result, query, includeLexical); } catch (error) { console.error("Structured query expansion failed:", error); // Fallback to original query @@ -1232,11 +1557,11 @@ export class LlamaCpp implements LLM { * Coordinates with LlamaCpp idle timeout to prevent disposal during active sessions. */ class LLMSessionManager { - private llm: LlamaCpp; + private llm: LLM; private _activeSessionCount = 0; private _inFlightOperations = 0; - constructor(llm: LlamaCpp) { + constructor(llm: LLM) { this.llm = llm; } @@ -1272,7 +1597,7 @@ class LLMSessionManager { this._inFlightOperations = Math.max(0, this._inFlightOperations - 1); } - getLlamaCpp(): LlamaCpp { + getLLM(): LLM { return this.llm; } } @@ -1375,18 +1700,18 @@ class LLMSession implements ILLMSession { } async embed(text: string, options?: EmbedOptions): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().embed(text, options)); + return this.withOperation(() => this.manager.getLLM().embed(text, options)); } async embedBatch(texts: string[]): Promise<(EmbeddingResult | null)[]> { - return this.withOperation(() => this.manager.getLlamaCpp().embedBatch(texts)); + return this.withOperation(() => this.manager.getLLM().embedBatch(texts)); } async expandQuery( query: string, options?: { context?: string; includeLexical?: boolean } ): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().expandQuery(query, options)); + return this.withOperation(() => this.manager.getLLM().expandQuery(query, options)); } async rerank( @@ -1394,19 +1719,77 @@ class LLMSession implements ILLMSession { documents: RerankDocument[], options?: RerankOptions ): Promise { - return this.withOperation(() => this.manager.getLlamaCpp().rerank(query, documents, options)); + return this.withOperation(() => this.manager.getLLM().rerank(query, documents, options)); } } -// Session manager for the default LlamaCpp instance +// Session manager for 
the default LLM instance let defaultSessionManager: LLMSessionManager | null = null; +let defaultLlamaCpp: LlamaCpp | null = null; +let defaultOpenRouterLLM: OpenRouterLLM | null = null; +let defaultLLM: LLM | null = null; +let defaultLLMProvider: LLMProvider | null = null; +let didWarnOpenRouterRemote = false; + +/** + * Emit the remote-provider notice once per process. + */ +function warnOpenRouterOnce(): void { + if (didWarnOpenRouterRemote) return; + didWarnOpenRouterRemote = true; + process.stderr.write( + "Notice: QMD is using OpenRouter (remote inference over HTTPS) for model operations.\n" + ); +} /** - * Get the session manager for the default LlamaCpp instance. + * Resolve the default provider from environment. + * Defaults to local so remote inference is always opt-in. + */ +export function getDefaultLLMProvider(): LLMProvider { + return defaultLLMProvider ?? normalizeProvider(process.env.QMD_LLM_PROVIDER); +} + +function getDefaultOpenRouterLLM(): OpenRouterLLM { + if (!defaultOpenRouterLLM) { + defaultOpenRouterLLM = new OpenRouterLLM(); + } + return defaultOpenRouterLLM; +} + +/** + * Get the default LLM instance (local or OpenRouter based on QMD_LLM_PROVIDER). + */ +export function getDefaultLLM(): LLM { + const provider = normalizeProvider(process.env.QMD_LLM_PROVIDER); + if (defaultLLM && defaultLLMProvider === provider) { + return defaultLLM; + } + + if (defaultLLM && defaultLLMProvider !== provider) { + defaultSessionManager = null; + void defaultLLM.dispose().catch(() => {}); + defaultLLM = null; + } + + if (provider === "openrouter") { + warnOpenRouterOnce(); + defaultLLM = getDefaultOpenRouterLLM(); + defaultLLMProvider = "openrouter"; + return defaultLLM; + } + + defaultLLM = getDefaultLlamaCpp(); + defaultLLMProvider = "local"; + return defaultLLM; +} + +/** + * Get the session manager for the default LLM instance. 
*/ function getSessionManager(): LLMSessionManager { - const llm = getDefaultLlamaCpp(); - if (!defaultSessionManager || defaultSessionManager.getLlamaCpp() !== llm) { + const llm = getDefaultLLM(); + if (!defaultSessionManager || defaultSessionManager.getLLM() !== llm) { defaultSessionManager = new LLMSessionManager(llm); } return defaultSessionManager; @@ -1450,11 +1833,9 @@ export function canUnloadLLM(): boolean { } // ============================================================================= -// Singleton for default LlamaCpp instance +// Singleton accessors // ============================================================================= -let defaultLlamaCpp: LlamaCpp | null = null; - /** * Get the default LlamaCpp instance (creates one if needed) */ @@ -1471,6 +1852,10 @@ export function getDefaultLlamaCpp(): LlamaCpp { */ export function setDefaultLlamaCpp(llm: LlamaCpp | null): void { defaultLlamaCpp = llm; + if (defaultLLMProvider === "local") { + defaultLLM = llm; + defaultSessionManager = null; + } } /** @@ -1479,7 +1864,47 @@ export function setDefaultLlamaCpp(llm: LlamaCpp | null): void { */ export async function disposeDefaultLlamaCpp(): Promise { if (defaultLlamaCpp) { + const existing = defaultLlamaCpp; await defaultLlamaCpp.dispose(); defaultLlamaCpp = null; + if (defaultLLM === existing) { + defaultLLM = null; + defaultLLMProvider = null; + defaultSessionManager = null; + } } } + +/** + * Dispose the active default LLM instance (provider-aware). 
+ */ +export async function disposeDefaultLLM(): Promise { + const disposed = new Set(); + const disposeOne = async (llm: LLM | null): Promise => { + if (!llm || disposed.has(llm)) return; + disposed.add(llm); + await llm.dispose(); + }; + + await disposeOne(defaultLLM); + await disposeOne(defaultLlamaCpp); + await disposeOne(defaultOpenRouterLLM); + + defaultLLM = null; + defaultLLMProvider = null; + defaultLlamaCpp = null; + defaultOpenRouterLLM = null; + defaultSessionManager = null; +} + +/** + * Test helper: clears default singleton state without disposing native resources. + */ +export function resetDefaultLLMForTests(): void { + defaultLLM = null; + defaultLLMProvider = null; + defaultLlamaCpp = null; + defaultOpenRouterLLM = null; + defaultSessionManager = null; + didWarnOpenRouterRemote = false; +} diff --git a/src/qmd.ts b/src/qmd.ts index 9446ab9d..7069f218 100755 --- a/src/qmd.ts +++ b/src/qmd.ts @@ -71,7 +71,7 @@ import { createStore, getDefaultDbPath, } from "./store.js"; -import { disposeDefaultLlamaCpp, getDefaultLlamaCpp, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR } from "./llm.js"; +import { getDefaultLlamaCpp, getDefaultLLM, disposeDefaultLLM, withLLMSession, pullModels, DEFAULT_EMBED_MODEL_URI, DEFAULT_GENERATE_MODEL_URI, DEFAULT_RERANK_MODEL_URI, DEFAULT_MODEL_CACHE_DIR, type ILLMSession, type RerankDocument } from "./llm.js"; import { formatSearchResults, formatDocuments, @@ -247,6 +247,28 @@ function computeDisplayPath( return filepath; } +// Rerank documents using node-llama-cpp cross-encoder model +async function rerank(query: string, documents: { file: string; text: string }[], _model: string = DEFAULT_RERANK_MODEL, _db?: Database, session?: ILLMSession): Promise<{ file: string; score: number }[]> { + if (documents.length === 0) return []; + + const total = documents.length; + process.stderr.write(`Reranking ${total} documents...\n`); + 
progress.indeterminate(); + + const rerankDocs: RerankDocument[] = documents.map((doc) => ({ + file: doc.file, + text: doc.text.slice(0, 4000), // Truncate to context limit + })); + + const result = session + ? await session.rerank(query, rerankDocs) + : await getDefaultLLM().rerank(query, rerankDocs); + + progress.clear(); + process.stderr.write("\n"); + + return result.results.map((r) => ({ file: r.file, score: r.score })); +} function formatTimeAgo(date: Date): string { const seconds = Math.floor((Date.now() - date.getTime()) / 1000); @@ -1581,7 +1603,7 @@ async function vectorIndex(model: string = DEFAULT_EMBED_MODEL, force: boolean = const title = extractTitle(item.body, item.path); const displayName = item.path; - const chunks = await chunkDocumentByTokens(item.body); // Uses actual tokenizer + const chunks = await chunkDocumentByTokens(item.body); // Falls back to char chunking for remote providers if (chunks.length > 1) multiChunkDocs++; @@ -2527,6 +2549,7 @@ function showHelp(): void { console.log(""); console.log("Global options:"); console.log(" --index - Use a named index (default: index)"); + console.log(" Env: QMD_LLM_PROVIDER=openrouter - Use OpenRouter instead of local GGUF models"); console.log(""); console.log("Search options:"); console.log(" -n - Max results (default 5, or 20 for --files/--json)"); @@ -2544,6 +2567,11 @@ function showHelp(): void { console.log(" --max-bytes - Skip files larger than N bytes (default 10240)"); console.log(" --json/--csv/--md/--xml/--files - Same formats as search"); console.log(""); + console.log("Local models (auto-downloaded from HuggingFace when provider=local):"); + console.log(" Embedding: embeddinggemma-300M-Q8_0"); + console.log(" Reranking: qwen3-reranker-0.6b-q8_0"); + console.log(" Generation: Qwen3-0.6B-Q8_0"); + console.log(""); console.log(`Index: ${getDbPath()}`); } @@ -3019,7 +3047,7 @@ if (isMain) { } if (cli.command !== "mcp") { - await disposeDefaultLlamaCpp(); + await disposeDefaultLLM(); 
process.exit(0);
   }
diff --git a/src/store.ts b/src/store.ts
index 4c7f8a0f..83da712c 100644
--- a/src/store.ts
+++ b/src/store.ts
@@ -17,8 +17,9 @@ import picomatch from "picomatch";
 import { createHash } from "crypto";
 import { realpathSync, statSync, mkdirSync } from "node:fs";
 import {
-  LlamaCpp,
   getDefaultLlamaCpp,
+  getDefaultLLM,
+  getDefaultLLMProvider,
   formatQueryForEmbedding,
   formatDocForEmbedding,
   type RerankDocument,
@@ -1475,6 +1476,16 @@ export async function chunkDocumentByTokens(
   overlapTokens: number = CHUNK_OVERLAP_TOKENS,
   windowTokens: number = CHUNK_WINDOW_TOKENS
 ): Promise<{ text: string; pos: number; tokens: number }[]> {
+  if (getDefaultLLMProvider() !== "local") {
+    // Remote providers do not expose tokenizer APIs; fall back to char chunking.
+    const approxChunks = chunkDocument(content, CHUNK_SIZE_CHARS, CHUNK_OVERLAP_CHARS);
+    return approxChunks.map(chunk => ({
+      text: chunk.text,
+      pos: chunk.pos,
+      tokens: Math.max(1, Math.ceil(chunk.text.length / 4)),
+    }));
+  }
+
   const llm = getDefaultLlamaCpp();
 
   // Use moderate chars/token estimate (prose ~4, code ~2, mixed ~3)
@@ -2290,10 +2301,16 @@ export async function searchVec(db: Database, query: string, model: string, limi
 
 async function getEmbedding(text: string, model: string, isQuery: boolean, session?: ILLMSession): Promise<number[] | null> {
   // Format text using the appropriate prompt template
-  const formattedText = isQuery ? formatQueryForEmbedding(text, model) : formatDocForEmbedding(text, undefined, model);
+  const provider = getDefaultLLMProvider();
+  const formattedText = isQuery
+    ? formatQueryForEmbedding(text, provider === "local" ? model : undefined)
+    : formatDocForEmbedding(text, undefined, provider === "local" ? model : undefined);
+  const options = provider === "local"
+    ? { model, isQuery }
+    : { isQuery };
   const result = session
-    ? 
await session.embed(formattedText, options)
+    : await getDefaultLLM().embed(formattedText, options);
   return result?.embedding || null;
 }
@@ -2347,8 +2364,9 @@ export function insertEmbedding(
 // =============================================================================
 
 export async function expandQuery(query: string, model: string = DEFAULT_QUERY_MODEL, db: Database, intent?: string): Promise<ExpandedQuery[]> {
+  const provider = getDefaultLLMProvider();
   // Check cache first — stored as JSON preserving types
-  const cacheKey = getCacheKey("expandQuery", { query, model, ...(intent && { intent }) });
+  const cacheKey = getCacheKey("expandQuery", { query, model, provider, ...(intent && { intent }) });
   const cached = getCachedResult(db, cacheKey);
   if (cached) {
     try {
@@ -2358,9 +2376,9 @@
     }
   }
 
-  const llm = getDefaultLlamaCpp();
-  // Note: LlamaCpp uses hardcoded model, model parameter is ignored
-  const results = await llm.expandQuery(query, { intent });
+  const results = provider === "local"
+    ? await getDefaultLlamaCpp().expandQuery(query, { intent })
+    : await getDefaultLLM().expandQuery(query);
 
   // Map Queryable[] → ExpandedQuery[] (same shape, decoupled from llm.ts internals).
   // Filter out entries that duplicate the original query text.
@@ -2382,7 +2400,7 @@
 export async function rerank(query: string, documents: { file: string; text: string }[], model: string = DEFAULT_RERANK_MODEL, db: Database, intent?: string): Promise<{ file: string; score: number }[]> {
   // Prepend intent to rerank query so the reranker scores with domain context
   const rerankQuery = intent ? 
`${intent}\n\n${query}` : query;
-
+  const provider = getDefaultLLMProvider();
   const cachedResults: Map<string, number> = new Map();
   const uncachedDocsByChunk: Map<string, RerankDocument> = new Map();
@@ -2392,9 +2410,16 @@
   // File path is excluded from the new cache key because the reranker score
   // depends on the chunk content, not where it came from.
   for (const doc of documents) {
-    const cacheKey = getCacheKey("rerank", { query: rerankQuery, model, chunk: doc.text });
-    const legacyCacheKey = getCacheKey("rerank", { query, file: doc.file, model, chunk: doc.text });
-    const cached = getCachedResult(db, cacheKey) ?? getCachedResult(db, legacyCacheKey);
+    const cacheKey = getCacheKey("rerank", { query: rerankQuery, model, provider, chunk: doc.text });
+    const legacyProviderAgnosticKey = provider === "local"
+      ? getCacheKey("rerank", { query: rerankQuery, model, chunk: doc.text })
+      : null;
+    const legacyFileKey = provider === "local"
+      ? getCacheKey("rerank", { query, file: doc.file, model, chunk: doc.text })
+      : null;
+    const cached = getCachedResult(db, cacheKey)
+      ?? (legacyProviderAgnosticKey ? getCachedResult(db, legacyProviderAgnosticKey) : null)
+      ?? (legacyFileKey ? getCachedResult(db, legacyFileKey) : null);
     if (cached !== null) {
       cachedResults.set(doc.text, parseFloat(cached));
     } else {
@@ -2402,17 +2427,17 @@
     }
   }
 
-  // Rerank uncached documents using LlamaCpp
   if (uncachedDocsByChunk.size > 0) {
-    const llm = getDefaultLlamaCpp();
     const uncachedDocs = [...uncachedDocsByChunk.values()];
-    const rerankResult = await llm.rerank(rerankQuery, uncachedDocs, { model });
+    const rerankResult = provider === "local"
+      ? await getDefaultLlamaCpp().rerank(rerankQuery, uncachedDocs, { model })
+      : await getDefaultLLM().rerank(rerankQuery, uncachedDocs);
 
     // Cache results by chunk text so identical chunks across files are scored once.
    const textByFile = new Map<string, string>(uncachedDocs.map(d => [d.file, d.text]));
     for (const result of rerankResult.results) {
       const chunk = textByFile.get(result.file) || "";
-      const cacheKey = getCacheKey("rerank", { query: rerankQuery, model, chunk });
+      const cacheKey = getCacheKey("rerank", { query: rerankQuery, model, provider, chunk });
       setCachedResult(db, cacheKey, result.score.toString());
       cachedResults.set(chunk, result.score);
     }