From 2b868e962e0353655dfee5bbb74716549bd0a14f Mon Sep 17 00:00:00 2001
From: Claude
Date: Wed, 5 Nov 2025 17:33:40 +0000
Subject: [PATCH] Improve model download management and add context awareness

This commit fixes model download issues and implements context-aware AI
interactions:

Model Download Improvements:
- Add retry logic with exponential backoff (max 3 retries)
- Handle network errors properly (ECONNRESET, ETIMEDOUT, etc.)
- Add heartbeat monitoring to detect stalled downloads (2-minute timeout)
- Add socket keep-alive to prevent connection drops
- Track active downloads to prevent duplicates
- Provide user-friendly error messages
- Add download cancellation support

Context Awareness Implementation:
- Add PageContext type for the current page URL, title, content, and selected text
- Add AIContext type combining page context, browsing history, and bookmarks
- Update GenerateOptions and ChatOptions to accept an optional context
- Implement buildContextualSystemPrompt() to inject context into AI prompts
- Update generate() and chat() methods to use contextual prompts
- Add a page:getContext IPC handler to capture current page info
- Update IPC handlers to validate and pass context to the Ollama service
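
For example, a renderer could capture the page and attach it to a chat
request (illustrative sketch: the window.api preload wrapper shown here is
hypothetical; only the 'page:getContext' channel itself is added by this
patch):

    const page = await window.api.invoke('page:getContext', {
      url: location.href,
      title: document.title,
      selectedText: window.getSelection()?.toString(),
    });
    // ...then pass { page } as the optional `context` in ChatOptions
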
Technical Details:
- Add error field to PullProgress for better error reporting
- Add a 'retrying' progress status so retry attempts are shown to users
- Use HTTP/HTTPS agents with keep-alive for stable connections
- Truncate page content to 5000 chars to avoid token limits
- Validate all context fields for security

This enables the AI to be aware of:
- The current browsing page (URL, title, content)
- Selected text on the page
- Recent browsing history
- User bookmarks

All changes maintain backward compatibility.
---
 src/main/ipc/handlers.ts    |  61 +++++-
 src/main/preload.ts         |   1 +
 src/main/services/ollama.ts | 357 ++++++++++++++++++++++++++++++++----
 src/shared/types.ts         |  17 ++
 4 files changed, 400 insertions(+), 36 deletions(-)

diff --git a/src/main/ipc/handlers.ts b/src/main/ipc/handlers.ts
index bbf78f7..9a9f8b8 100644
--- a/src/main/ipc/handlers.ts
+++ b/src/main/ipc/handlers.ts
@@ -181,7 +181,7 @@ export function registerIpcHandlers() {
     }
   });
 
-  // AI context handler
+  // AI context handlers
   ipcMain.handle('browsing:getContext', async (event, limit?: number) => {
     try {
       if (limit !== undefined) {
@@ -194,6 +194,43 @@
     }
   });
 
+  // Get current page context (URL, title, selected text, etc.)
+  ipcMain.handle('page:getContext', async (event, pageInfo: any) => {
+    try {
+      // Validate page info
+      if (!pageInfo || typeof pageInfo !== 'object') {
+        return { url: '', title: '' }; // Return empty context if invalid
+      }
+
+      const context: any = {};
+
+      if (pageInfo.url) {
+        validateString(pageInfo.url, 'Page URL', 2048);
+        context.url = pageInfo.url;
+      }
+
+      if (pageInfo.title) {
+        validateString(pageInfo.title, 'Page title', 1024);
+        context.title = pageInfo.title;
+      }
+
+      if (pageInfo.selectedText) {
+        validateString(pageInfo.selectedText, 'Selected text', 50000);
+        context.selectedText = pageInfo.selectedText;
+      }
+
+      if (pageInfo.content) {
+        validateString(pageInfo.content, 'Page content', 100000);
+        context.content = pageInfo.content;
+      }
+
+      return context;
+    } catch (error: any) {
+      console.error('page:getContext validation error:', error.message);
+      throw error;
+    }
+  });
+
   // Tab session handlers
   ipcMain.handle('tabs:save', async (event, tabs: Tab[]) => {
     try {
@@ -322,12 +359,23 @@
         validateString(options.system, 'System prompt', 10000);
       }
 
+      // Validate context if provided
+      if (options.context) {
+        if (options.context.page?.url) {
+          validateString(options.context.page.url, 'Page URL', 2048);
+        }
+        if (options.context.page?.title) {
+          validateString(options.context.page.title, 'Page title', 1024);
+        }
+      }
+
       // Stream response tokens back to renderer
       const generator = ollamaService.generate({
         model: options.model,
         prompt: options.prompt,
         images: options.images,
         system: options.system,
+        context: options.context,
         stream: true,
       });
 
@@ -365,10 +413,21 @@
         }
       }
 
+      // Validate context if provided
+      if (options.context) {
+        if (options.context.page?.url) {
+          validateString(options.context.page.url, 'Page URL', 2048);
+        }
+        if (options.context.page?.title) {
+          validateString(options.context.page.title, 'Page title', 1024);
+        }
+      }
+
       // Stream response tokens back to renderer
       const generator = ollamaService.chat({
         model: options.model,
         messages: options.messages,
+        context: options.context,
         stream: true,
       });
 
diff --git a/src/main/preload.ts b/src/main/preload.ts
index 36739d1..0839835 100644
--- a/src/main/preload.ts
+++ b/src/main/preload.ts
@@ -15,6 +15,7 @@ const ALLOWED_INVOKE_CHANNELS = [
   'bookmark:deleteByUrl',
   'bookmark:update',
   'browsing:getContext',
+  'page:getContext',
   'tabs:save',
   'tabs:load',
   'tabs:clear',
diff --git a/src/main/services/ollama.ts b/src/main/services/ollama.ts
index 636a12d..af8c3d8 100644
--- a/src/main/services/ollama.ts
+++ b/src/main/services/ollama.ts
@@ -16,6 +16,21 @@ export interface PullProgress {
   completed?: number;
   total?: number;
   digest?: string;
+  error?: string;
+}
+
+export interface PageContext {
+  url?: string;
+  title?: string;
+  content?: string;
+  selectedText?: string;
+  screenshot?: string;
+}
+
+export interface AIContext {
+  page?: PageContext;
+  browsingHistory?: any[];
+  bookmarks?: any[];
 }
 
 export interface GenerateRequest {
@@ -24,6 +39,7 @@
   images?: string[];
   stream?: boolean;
   system?: string;
+  context?: AIContext;
 }
 
 export interface GenerateResponse {
@@ -48,6 +64,7 @@ export interface ChatRequest {
   model: string;
   messages: ChatMessage[];
   stream?: boolean;
+  context?: AIContext;
 }
 
 export class OllamaService {
@@ -55,6 +72,7 @@
   private client: AxiosInstance;
   private process: ChildProcess | null = null;
   private isServerRunning = false;
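+  // Model names with an in-flight pull; used to reject duplicate downloads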
+  private activePulls: Map<string, boolean> = new Map();
 
   constructor(baseURL = 'http://localhost:11434') {
     this.baseURL = baseURL;
@@ -64,6 +82,106 @@
     });
   }
 
+  /**
+   * Check if an error is retryable
+   */
+  private isRetryableError(error: any): boolean {
+    if (!error) return false;
+
+    const retryableCodes = ['ECONNRESET', 'ETIMEDOUT', 'ECONNREFUSED', 'EPIPE', 'ENOTFOUND'];
+    const retryableStatusCodes = [408, 429, 500, 502, 503, 504];
+
+    // Check error code
+    if (error.code && retryableCodes.includes(error.code)) {
+      return true;
+    }
+
+    // Check HTTP status code
+    if (error.response?.status && retryableStatusCodes.includes(error.response.status)) {
+      return true;
+    }
+
+    // Check if the error message indicates a network issue
+    const errorMessage = error.message?.toLowerCase() || '';
+    if (
+      errorMessage.includes('network') ||
+      errorMessage.includes('timeout') ||
+      errorMessage.includes('socket') ||
+      errorMessage.includes('aborted')
+    ) {
+      return true;
+    }
+
+    return false;
+  }
+
+  /**
+   * Sleep for a specified duration
+   */
+  private sleep(ms: number): Promise<void> {
+    return new Promise((resolve) => setTimeout(resolve, ms));
+  }
+
+  /**
+   * Build context-aware system prompt
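+   *
+   * The assembled prompt looks roughly like:
+   *
+   *   (base system prompt, if any)
+   *   ## Current Page Context
+   *   URL: ... / Page Title: ... / Selected Text: ...
+   *   Page Content: (first 5000 characters)
+   *   ## Recent Browsing History (up to 10 entries)
+   *   ## Bookmarks (up to 10 entries)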
+   */
+  private buildContextualSystemPrompt(baseSystem: string | undefined, context?: AIContext): string {
+    if (!context) {
+      return baseSystem || '';
+    }
+
+    const contextParts: string[] = [];
+
+    if (baseSystem) {
+      contextParts.push(baseSystem);
+    }
+
+    // Add page context
+    if (context.page) {
+      const { url, title, content, selectedText } = context.page;
+
+      contextParts.push('\n## Current Page Context');
+
+      if (url) {
+        contextParts.push(`URL: ${url}`);
+      }
+
+      if (title) {
+        contextParts.push(`Page Title: ${title}`);
+      }
+
+      if (selectedText) {
+        contextParts.push(`\nSelected Text:\n${selectedText}`);
+      }
+
+      if (content) {
+        // Limit content to the first 5000 characters to avoid token limits
+        const truncatedContent =
+          content.length > 5000 ? content.substring(0, 5000) + '...' : content;
+        contextParts.push(`\nPage Content:\n${truncatedContent}`);
+      }
+    }
+
+    // Add browsing history context
+    if (context.browsingHistory && context.browsingHistory.length > 0) {
+      contextParts.push('\n## Recent Browsing History');
+      const historyItems = context.browsingHistory
+        .slice(0, 10)
+        .map((h: any) => `- ${h.title || 'Untitled'} (${h.url})`);
+      contextParts.push(historyItems.join('\n'));
+    }
+
+    // Add bookmarks context
+    if (context.bookmarks && context.bookmarks.length > 0) {
+      contextParts.push('\n## Bookmarks');
+      const bookmarkItems = context.bookmarks
+        .slice(0, 10)
+        .map((b: any) => `- ${b.title || 'Untitled'} (${b.url})`);
+      contextParts.push(bookmarkItems.join('\n'));
+    }
+
+    return contextParts.join('\n');
+  }
+
   /**
    * Get the path to the bundled Ollama executable
    */
@@ -233,52 +351,173 @@
   }
 
   /**
-   * Pull/download a model from Ollama library
+   * Pull/download a model from Ollama library with retry logic
    * Returns an async generator for progress updates
    */
-  async *pullModel(modelName: string): AsyncGenerator<PullProgress> {
+  async *pullModel(modelName: string, maxRetries = 3): AsyncGenerator<PullProgress> {
     await this.ensureRunning();
 
-    try {
-      const response = await this.client.post(
-        '/api/pull',
-        { name: modelName },
-        {
-          responseType: 'stream',
-          timeout: 0, // No timeout for downloads
-        }
-      );
-
-      const stream = response.data;
-      let buffer = '';
-
-      for await (const chunk of stream) {
-        buffer += chunk.toString();
-        const lines = buffer.split('\n');
-        buffer = lines.pop() || '';
-
-        for (const line of lines) {
-          if (line.trim()) {
-            try {
-              const progress: PullProgress = JSON.parse(line);
-              yield progress;
-
-              // Check if pull is complete
-              if (progress.status === 'success' || progress.status === 'complete') {
-                return;
-              }
-            } catch (_e) {
-              console.warn('Failed to parse progress line:', line);
-            }
-          }
-        }
-      }
-    } catch (error) {
-      console.error('Failed to pull model:', error);
-      throw new Error(`Failed to pull model ${modelName}`);
-    }
-  }
+    // Track active pull to prevent duplicates
+    if (this.activePulls.get(modelName)) {
+      throw new Error(`Model ${modelName} is already being downloaded`);
+    }
+
+    this.activePulls.set(modelName, true);
+
+    try {
+      let attempt = 0;
+      let lastError: any = null;
+
+      while (attempt <= maxRetries) {
+        try {
+          // Yield retry status if this is a retry
+          if (attempt > 0) {
+            yield {
+              status: 'retrying',
+              error: `Retrying download (attempt ${attempt + 1}/${maxRetries + 1})...`,
+            };
+            // Exponential backoff: 2s, 4s, 8s
+            const backoffMs = Math.min(2000 * Math.pow(2, attempt - 1), 8000);
+            await this.sleep(backoffMs);
+          }
+
+          const response = await this.client.post(
+            '/api/pull',
+            { name: modelName },
+            {
+              responseType: 'stream',
+              timeout: 0, // No timeout for downloads
+              // Add socket timeouts to detect stalled connections
+              httpAgent: new (require('http').Agent)({
+                keepAlive: true,
+                timeout: 60000, // 60 second socket timeout
+              }),
+              httpsAgent: new (require('https').Agent)({
+                keepAlive: true,
+                timeout: 60000, // 60 second socket timeout
+              }),
+            }
+          );
+
+          const stream = response.data;
+          let buffer = '';
+          let lastProgressTime = Date.now();
+          const heartbeatTimeout = 120000; // 2 minutes without progress = stalled
+
+          // Set up heartbeat check
+          const heartbeatInterval = setInterval(() => {
+            const timeSinceProgress = Date.now() - lastProgressTime;
+            if (timeSinceProgress > heartbeatTimeout) {
+              console.warn('Download stalled, no progress for', heartbeatTimeout / 1000, 'seconds');
+              stream.destroy(new Error('Download stalled - no progress'));
+            }
+          }, 10000); // Check every 10 seconds
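+          // If the heartbeat fires, stream.destroy() makes the for-await
+          // loop below reject; that error is surfaced to the caller as a
+          // user-friendly "stalled" message rather than retried.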
+
+          try {
+            for await (const chunk of stream) {
+              lastProgressTime = Date.now(); // Update heartbeat
+              buffer += chunk.toString();
+              const lines = buffer.split('\n');
+              buffer = lines.pop() || '';
+
+              for (const line of lines) {
+                if (line.trim()) {
+                  try {
+                    const progress: PullProgress = JSON.parse(line);
+                    yield progress;
+
+                    // Check if pull is complete
+                    if (progress.status === 'success' || progress.status === 'complete') {
+                      clearInterval(heartbeatInterval);
+                      this.activePulls.delete(modelName);
+                      return;
+                    }
+
+                    // Check for an error status reported by the server
+                    if (progress.status === 'error') {
+                      clearInterval(heartbeatInterval);
+                      throw new Error(progress.error || 'Unknown error during download');
+                    }
+                  } catch (parseError) {
+                    // Re-throw real errors (e.g. the server-reported error
+                    // above); only swallow JSON parse failures
+                    if (!(parseError instanceof SyntaxError)) {
+                      throw parseError;
+                    }
+                    console.warn('Failed to parse progress line:', line);
+                  }
+                }
+              }
+            }
+
+            clearInterval(heartbeatInterval);
+
+            // If we reach here, the stream ended without a success status
+            console.warn('Stream ended without completion status');
+            throw new Error('Download stream ended unexpectedly');
+          } catch (streamError) {
+            clearInterval(heartbeatInterval);
+            throw streamError;
+          }
+        } catch (error: any) {
+          lastError = error;
+          const errorCode = error.code || 'UNKNOWN';
+          const errorMessage = error.message || 'Unknown error';
+
+          console.error(
+            `Pull model attempt ${attempt + 1} failed:`,
+            errorCode,
+            errorMessage
+          );
+
+          // If not retryable or max retries reached, throw
+          if (!this.isRetryableError(error) || attempt >= maxRetries) {
+            throw error;
+          }
+
+          // Yield error status before retrying
+          yield {
+            status: 'error',
+            error: `Network error (${errorCode}). Will retry...`,
+          };
+
+          attempt++;
+        }
+      }
+
+      // If we exhausted all retries
+      throw lastError || new Error('Failed to pull model after maximum retries');
+    } catch (error: any) {
+      const errorCode = error.code || 'UNKNOWN';
+      const errorMessage = error.message || 'Unknown error';
+
+      console.error('Failed to pull model:', errorCode, errorMessage, error);
+
+      // Provide user-friendly error messages
+      let friendlyMessage = `Failed to download model ${modelName}`;
+
+      if (errorCode === 'ECONNRESET') {
+        friendlyMessage += ': Connection was reset. Please check your internet connection.';
+      } else if (errorCode === 'ETIMEDOUT') {
+        friendlyMessage += ': Connection timed out. Please check your internet connection.';
+      } else if (errorCode === 'ECONNREFUSED') {
+        friendlyMessage += ': Cannot connect to Ollama server.';
+      } else if (errorMessage.includes('stalled')) {
+        friendlyMessage += ': Download stalled. Please try again.';
+      } else {
+        friendlyMessage += `: ${errorMessage}`;
+      }
+
+      throw new Error(friendlyMessage);
+    } finally {
+      // Always clear the in-flight flag, on success, error, or cancellation
+      this.activePulls.delete(modelName);
+    }
+  }
+
+  /**
+   * Cancel an active model download
+   */
+  cancelPull(modelName: string): void {
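+    // Note: this only clears the in-flight flag so the model can be pulled
+    // again; it does not abort the underlying download stream.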
+    this.activePulls.delete(modelName);
+  }
+
   /**
    * Delete a model
    */
@@ -296,14 +535,26 @@
   }
 
   /**
-   * Generate text with optional vision input
+   * Generate text with optional vision input and context awareness
    * Returns an async generator for streaming responses
    */
   async *generate(request: GenerateRequest): AsyncGenerator<GenerateResponse> {
     await this.ensureRunning();
 
     try {
-      const response = await this.client.post('/api/generate', request, {
+      // Build contextual system prompt
+      const contextualSystem = this.buildContextualSystemPrompt(request.system, request.context);
+
+      // Build the request with context
+      const ollamaRequest = {
+        model: request.model,
+        prompt: request.prompt,
+        images: request.images,
+        stream: request.stream,
+        system: contextualSystem || undefined,
+      };
+
+      const response = await this.client.post('/api/generate', ollamaRequest, {
         responseType: 'stream',
         timeout: 0, // No timeout for generation
       });
@@ -341,14 +592,50 @@
   }
 
   /**
-   * Chat completion with conversation history
+   * Chat completion with conversation history and context awareness
    * Returns an async generator for streaming responses
    */
   async *chat(request: ChatRequest): AsyncGenerator {
     await this.ensureRunning();
 
     try {
-      const response = await this.client.post('/api/chat', request, {
+      let messages = [...request.messages];
+
+      // If context is provided, prepend it as a system message or enhance the existing one
+      if (request.context) {
+        const contextualSystem = this.buildContextualSystemPrompt('', request.context);
+
+        if (contextualSystem) {
+          // Check if there's already a system message
+          const systemMessageIndex = messages.findIndex((m) => m.role === 'system');
+
+          if (systemMessageIndex >= 0) {
+            // Enhance existing system message
+            messages[systemMessageIndex] = {
+              ...messages[systemMessageIndex],
+              content: messages[systemMessageIndex].content + '\n\n' + contextualSystem,
+            };
+          } else {
+            // Add new system message at the beginning
+            messages = [
+              {
+                role: 'system',
+                content: contextualSystem,
+              },
+              ...messages,
+            ];
+          }
+        }
+      }
+
+      // Build the request with enhanced messages
+      const ollamaRequest = {
+        model: request.model,
+        messages: messages,
+        stream: request.stream,
+      };
+
+      const response = await this.client.post('/api/chat', ollamaRequest, {
         responseType: 'stream',
         timeout: 0,
       });
diff --git a/src/shared/types.ts b/src/shared/types.ts
index 189c7e0..5aedc9a 100644
--- a/src/shared/types.ts
+++ b/src/shared/types.ts
@@ -86,6 +86,7 @@ export interface PullProgress {
   completed?: number;
   total?: number;
   digest?: string;
+  error?: string;
 }
 
 export interface ChatMessage {
@@ -103,14 +104,30 @@
   updatedAt: number;
 }
 
+export interface PageContext {
+  url?: string;
+  title?: string;
+  content?: string;
+  selectedText?: string;
+  screenshot?: string;
+}
+
+export interface AIContext {
+  page?: PageContext;
+  browsingHistory?: HistoryEntry[];
+  bookmarks?: Bookmark[];
+}
+
 export interface GenerateOptions {
   model: string;
   prompt: string;
   images?: string[];
   system?: string;
+  context?: AIContext;
 }
 
 export interface ChatOptions {
   model: string;
   messages: ChatMessage[];
+  context?: AIContext;
 }