diff --git a/app/api/web-search/route.ts b/app/api/web-search/route.ts index 5a9708260..b81ee6b1e 100644 --- a/app/api/web-search/route.ts +++ b/app/api/web-search/route.ts @@ -2,21 +2,25 @@ * Web Search API * * POST /api/web-search - * Simple JSON request/response using Tavily search. + * Supports multiple search providers (Tavily, Claude). */ import { NextRequest } from 'next/server'; import { callLLM } from '@/lib/ai/llm'; import { searchWithTavily, formatSearchResultsAsContext } from '@/lib/web-search/tavily'; +import { searchWithClaude } from '@/lib/web-search/claude'; +import { WEB_SEARCH_PROVIDERS } from '@/lib/web-search/constants'; import { resolveWebSearchApiKey } from '@/lib/server/provider-config'; +import { validateUrlForSSRF } from '@/lib/server/ssrf-guard'; import { createLogger } from '@/lib/logger'; -import { apiError, apiSuccess } from '@/lib/server/api-response'; +import { API_ERROR_CODES, apiError, apiSuccess } from '@/lib/server/api-response'; import { buildSearchQuery, SEARCH_QUERY_REWRITE_EXCERPT_LENGTH, } from '@/lib/server/search-query-builder'; import { resolveModelFromHeaders } from '@/lib/server/resolve-model'; import type { AICallFn } from '@/lib/generation/pipeline-types'; +import type { WebSearchProviderId } from '@/lib/web-search/types'; const log = createLogger('WebSearch'); @@ -28,23 +32,47 @@ export async function POST(req: NextRequest) { query: requestQuery, pdfText, apiKey: clientApiKey, + providerId: requestProviderId, + providerConfig, } = body as { query?: string; pdfText?: string; apiKey?: string; + providerId?: WebSearchProviderId; + providerConfig?: { + modelId?: string; + baseUrl?: string; + tools?: Array<{ type: string; name: string }>; + }; }; query = requestQuery; + // Provider must be explicitly specified + const providerId: WebSearchProviderId | null = requestProviderId ?? null; + if (!query || !query.trim()) { return apiError('MISSING_REQUIRED_FIELD', 400, 'query is required'); } - const apiKey = resolveWebSearchApiKey(clientApiKey); + if (!providerId) { + return apiError( + 'MISSING_PROVIDER', + 400, + 'Web search provider is not selected. Please select a provider in the toolbar.', + ); + } + + if (!(providerId in WEB_SEARCH_PROVIDERS)) { + return apiError('INVALID_REQUEST', 400, `Unknown web search provider: ${providerId}`); + } + + const apiKey = resolveWebSearchApiKey(providerId, clientApiKey); if (!apiKey) { + const envVar = providerId === 'claude' ? 'ANTHROPIC_API_KEY' : 'TAVILY_API_KEY'; return apiError( 'MISSING_API_KEY', 400, - 'Tavily API key is not configured. Set it in Settings → Web Search or set TAVILY_API_KEY env var.', + `${providerId} API key is not configured. Set it in Settings → Web Search or set ${envVar} env var.`, ); } @@ -75,13 +103,45 @@ export async function POST(req: NextRequest) { const searchQuery = await buildSearchQuery(query, boundedPdfText, aiCall); log.info('Running web search API request', { + provider: providerId, hasPdfContext: searchQuery.hasPdfContext, rawRequirementLength: searchQuery.rawRequirementLength, rewriteAttempted: searchQuery.rewriteAttempted, finalQueryLength: searchQuery.finalQueryLength, }); - const result = await searchWithTavily({ query: searchQuery.query, apiKey }); + // Validate client-supplied base URL against SSRF in all environments + if (providerConfig?.baseUrl) { + const ssrfError = await validateUrlForSSRF(providerConfig.baseUrl); + if (ssrfError) { + return apiError(API_ERROR_CODES.INVALID_URL, 400, ssrfError); + } + } + + const baseUrl = + providerConfig?.baseUrl || WEB_SEARCH_PROVIDERS[providerId].defaultBaseUrl || ''; + + let result; + switch (providerId) { + case 'claude': { + result = await searchWithClaude({ + query: searchQuery.query, + apiKey, + baseUrl, + modelId: providerConfig?.modelId, + tools: providerConfig?.tools, + }); + break; + } + case 'tavily': { + result = await searchWithTavily({ + query: searchQuery.query, + apiKey, + baseUrl, + }); + break; + } + } const context = formatSearchResultsAsContext(result); return apiSuccess({ diff --git a/app/generation-preview/page.tsx b/app/generation-preview/page.tsx index c4dfd12a3..839a72324 100644 --- a/app/generation-preview/page.tsx +++ b/app/generation-preview/page.tsx @@ -300,15 +300,25 @@ function GenerationPreviewContent() { setWebSearchSources([]); const wsSettings = useSettingsStore.getState(); - const wsApiKey = - wsSettings.webSearchProvidersConfig?.[wsSettings.webSearchProviderId]?.apiKey; + const wsProviderId = wsSettings.webSearchProviderId; + const wsProviderConfig = wsProviderId + ? wsSettings.webSearchProvidersConfig?.[wsProviderId] + : null; const res = await fetch('/api/web-search', { method: 'POST', headers: getApiHeaders(), body: JSON.stringify({ query: currentSession.requirements.requirement, pdfText: currentSession.pdfText || undefined, - apiKey: wsApiKey || undefined, + providerId: wsProviderId || undefined, + apiKey: wsProviderConfig?.apiKey || undefined, + providerConfig: wsProviderId + ? { + modelId: wsProviderConfig?.modelId, + baseUrl: wsProviderConfig?.baseUrl, + tools: wsProviderConfig?.tools, + } + : undefined, }), signal, }); diff --git a/app/page.tsx b/app/page.tsx index 4c86d52af..2f49be494 100644 --- a/app/page.tsx +++ b/app/page.tsx @@ -55,21 +55,18 @@ import { useImportClassroom } from '@/lib/import/use-import-classroom'; const log = createLogger('Home'); -const WEB_SEARCH_STORAGE_KEY = 'webSearchEnabled'; const RECENT_OPEN_STORAGE_KEY = 'recentClassroomsOpen'; const INTERACTIVE_MODE_STORAGE_KEY = 'interactiveModeEnabled'; interface FormState { pdfFile: File | null; requirement: string; - webSearch: boolean; interactiveMode: boolean; } const initialFormState: FormState = { pdfFile: null, requirement: '', - webSearch: false, interactiveMode: false, }; @@ -89,6 +86,8 @@ function HomePage() { // Model setup state const currentModelId = useSettingsStore((s) => s.modelId); + const webSearchEnabled = useSettingsStore((s) => s.webSearchEnabled); + const setWebSearchEnabled = useSettingsStore((s) => s.setWebSearchEnabled); const [recentOpen, setRecentOpen] = useState(true); // Hydrate client-only state after mount (avoids SSR mismatch) @@ -101,16 +100,20 @@ function HomePage() { /* localStorage unavailable */ } try { - const savedWebSearch = localStorage.getItem(WEB_SEARCH_STORAGE_KEY); - const savedInteractiveMode = localStorage.getItem(INTERACTIVE_MODE_STORAGE_KEY); - const updates: Partial = {}; - if (savedWebSearch === 'true') updates.webSearch = true; - if (savedInteractiveMode === 'true') updates.interactiveMode = true; - if (Object.keys(updates).length > 0) { - setForm((prev) => ({ ...prev, ...updates })); + // Migrate webSearchEnabled from old localStorage key into the Zustand store + const oldWebSearch = localStorage.getItem('webSearchEnabled'); + if (oldWebSearch === 'true' && !useSettingsStore.getState().webSearchEnabled) { + const store = useSettingsStore.getState(); + if (!store.webSearchProviderId) { + store.setWebSearchProvider('tavily'); + } + store.setWebSearchEnabled(true); } + if (oldWebSearch !== null) localStorage.removeItem('webSearchEnabled'); + const savedInteractiveMode = localStorage.getItem(INTERACTIVE_MODE_STORAGE_KEY); + if (savedInteractiveMode === 'true') setForm((prev) => ({ ...prev, interactiveMode: true })); } catch { - /* localStorage unavailable */ + /* ignore */ } }, []); /* eslint-enable react-hooks/set-state-in-effect */ @@ -204,7 +207,6 @@ function HomePage() { const updateForm = (field: K, value: FormState[K]) => { setForm((prev) => ({ ...prev, [field]: value })); try { - if (field === 'webSearch') localStorage.setItem(WEB_SEARCH_STORAGE_KEY, String(value)); if (field === 'interactiveMode') localStorage.setItem(INTERACTIVE_MODE_STORAGE_KEY, String(value)); if (field === 'requirement') updateRequirementCache(value as string); @@ -268,7 +270,7 @@ function HomePage() { requirement: form.requirement, userNickname: userProfile.nickname || undefined, userBio: userProfile.bio || undefined, - webSearch: form.webSearch || undefined, + webSearch: webSearchEnabled || undefined, interactiveMode: form.interactiveMode, }; @@ -513,8 +515,8 @@ function HomePage() {
updateForm('webSearch', v)} + webSearch={webSearchEnabled} + onWebSearchChange={setWebSearchEnabled} onSettingsOpen={(section) => { setSettingsSection(section); setSettingsOpen(true); diff --git a/components/generation/generation-toolbar.tsx b/components/generation/generation-toolbar.tsx index e12ba3180..97ef03564 100644 --- a/components/generation/generation-toolbar.tsx +++ b/components/generation/generation-toolbar.tsx @@ -1,7 +1,16 @@ 'use client'; import { useState, useRef, useMemo } from 'react'; -import { Bot, Check, ChevronLeft, Paperclip, FileText, X, Globe2 } from 'lucide-react'; +import { + Bot, + Check, + ChevronLeft, + ChevronRight, + Paperclip, + FileText, + X, + Globe2, +} from 'lucide-react'; import { Popover, PopoverContent, PopoverTrigger } from '@/components/ui/popover'; import { Select, @@ -11,6 +20,7 @@ import { SelectValue, } from '@/components/ui/select'; import { Tooltip, TooltipContent, TooltipTrigger } from '@/components/ui/tooltip'; +import { Switch } from '@/components/ui/switch'; import { cn } from '@/lib/utils'; import { useI18n } from '@/lib/hooks/use-i18n'; import { useSettingsStore } from '@/lib/store/settings'; @@ -57,17 +67,19 @@ export function GenerationToolbar({ const webSearchProviderId = useSettingsStore((s) => s.webSearchProviderId); const webSearchProvidersConfig = useSettingsStore((s) => s.webSearchProvidersConfig); const setWebSearchProvider = useSettingsStore((s) => s.setWebSearchProvider); + const setWebSearchProviderConfig = useSettingsStore((s) => s.setWebSearchProviderConfig); const fileInputRef = useRef(null); const [isDragging, setIsDragging] = useState(false); + const [webSearchPopoverOpen, setWebSearchPopoverOpen] = useState(false); + const [drillWebSearchProvider, setDrillWebSearchProvider] = useState( + null, + ); - // Check if the selected web search provider has a valid config (API key or server-configured) - const webSearchProvider = WEB_SEARCH_PROVIDERS[webSearchProviderId]; - const webSearchConfig = webSearchProvidersConfig[webSearchProviderId]; - const webSearchAvailable = webSearchProvider - ? !webSearchProvider.requiresApiKey || - !!webSearchConfig?.apiKey || - !!webSearchConfig?.isServerConfigured - : false; + // Check if any web search provider has a valid config (API key or server-configured) + const webSearchAvailable = Object.values(WEB_SEARCH_PROVIDERS).some((provider) => { + const config = webSearchProvidersConfig[provider.id]; + return !provider.requiresApiKey || !!config?.apiKey || !!config?.isServerConfigured; + }); // Configured LLM providers (only those with valid credentials + models + endpoint) const configuredProviders = providersConfig @@ -268,90 +280,242 @@ export function GenerationToolbar({ {/* ── Web Search ── */} - {webSearchAvailable ? ( - - - - - - {/* Toggle */} - + + + {/* Level 1: Provider list */} + {!drillWebSearchProvider && ( +
+
+ + + {t('toolbar.webSearchProvider')} + + { + onWebSearchChange(enabled); + }} + className="scale-[0.85] origin-right" + /> +
+
+ {Object.values(WEB_SEARCH_PROVIDERS) + .filter((provider) => { + const cfg = webSearchProvidersConfig[provider.id]; + return ( + !provider.requiresApiKey || !!cfg?.apiKey || !!cfg?.isServerConfigured + ); + }) + .map((provider) => { + const cfg = webSearchProvidersConfig[provider.id as WebSearchProviderId]; + const isActive = webSearchProviderId === provider.id && webSearch; + const builtinModels = + WEB_SEARCH_PROVIDERS[provider.id as WebSearchProviderId]?.models || []; + const effectiveModels = cfg?.models?.length ? cfg.models : builtinModels; + const hasModels = effectiveModels.length > 0; + + return ( + + ); + })} +
+
)} - /> -
-

- {webSearch ? t('toolbar.webSearchOn') : t('toolbar.webSearchOff')} -

-

- {t('toolbar.webSearchDesc')} -

-
- - {/* Provider selector */} -
- - {t('toolbar.webSearchProvider')} - - -
-
-
- ) : ( - - - - + })()} + + + + + {!webSearchAvailable ? ( {t('toolbar.webSearchNoProvider')} - - )} + ) : webSearch && webSearchProviderId ? ( + + {(() => { + const providerName = + WEB_SEARCH_PROVIDERS[webSearchProviderId]?.name || webSearchProviderId; + const cfg = webSearchProvidersConfig[webSearchProviderId]; + const builtinModels = + WEB_SEARCH_PROVIDERS[webSearchProviderId as WebSearchProviderId]?.models || []; + const effectiveModels = cfg?.models?.length ? cfg.models : builtinModels; + if (cfg?.modelId) { + const modelName = + effectiveModels.find((m) => m.id === cfg.modelId)?.name || cfg.modelId; + return `${providerName} / ${modelName}`; + } + return providerName; + })()} + + ) : ( + {t('toolbar.webSearchProvider')} + )} + {/* ── Separator ── */}
diff --git a/components/settings/index.tsx b/components/settings/index.tsx index 552455873..44bcf8e99 100644 --- a/components/settings/index.tsx +++ b/components/settings/index.tsx @@ -235,7 +235,7 @@ export function SettingsDialog({ open, onOpenChange, initialSection }: SettingsD const [selectedProviderId, setSelectedProviderId] = useState(providerId); const [selectedPdfProviderId, setSelectedPdfProviderId] = useState(pdfProviderId); const [selectedWebSearchProviderId, setSelectedWebSearchProviderId] = - useState(webSearchProviderId); + useState(webSearchProviderId ?? 'tavily'); const [selectedImageProviderId, setSelectedImageProviderId] = useState(imageProviderId); const [selectedVideoProviderId, setSelectedVideoProviderId] = @@ -588,7 +588,9 @@ export function SettingsDialog({ open, onOpenChange, initialSection }: SettingsD ); } case 'web-search': { - const wsProvider = WEB_SEARCH_PROVIDERS[selectedWebSearchProviderId]; + const wsProvider = selectedWebSearchProviderId + ? WEB_SEARCH_PROVIDERS[selectedWebSearchProviderId] + : null; if (!wsProvider) return null; return ( <> @@ -863,7 +865,7 @@ export function SettingsDialog({ open, onOpenChange, initialSection }: SettingsD )} - {activeSection === 'web-search' && ( - + {activeSection === 'web-search' && selectedWebSearchProviderId && ( + )} {activeSection === 'image' && ( diff --git a/components/settings/tool-edit-dialog.tsx b/components/settings/tool-edit-dialog.tsx new file mode 100644 index 000000000..6a3375636 --- /dev/null +++ b/components/settings/tool-edit-dialog.tsx @@ -0,0 +1,72 @@ +'use client'; + +import { Dialog, DialogContent, DialogTitle, DialogDescription } from '@/components/ui/dialog'; +import { Button } from '@/components/ui/button'; +import { Input } from '@/components/ui/input'; +import { Label } from '@/components/ui/label'; +import { useI18n } from '@/lib/hooks/use-i18n'; + +interface Tool { + type: string; + name: string; +} + +interface ToolEditDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + tool: Tool | null; + setTool: (tool: Tool | null) => void; + onSave: () => void; +} + +export function ToolEditDialog({ open, onOpenChange, tool, setTool, onSave }: ToolEditDialogProps) { + const { t } = useI18n(); + + const handleClose = () => { + onOpenChange(false); + setTool(null); + }; + + if (!tool) return null; + + const canSave = !!(tool.type.trim() && tool.name.trim()); + + return ( + + + + {tool.type === '' + ? t('settings.webSearchAddToolTitle') + : t('settings.webSearchEditToolTitle')} + + {t('settings.webSearchToolDialogDesc')} +
+
+ + setTool({ ...tool, type: e.target.value })} + /> +
+
+ + setTool({ ...tool, name: e.target.value })} + /> +
+
+ + +
+
+
+
+ ); +} diff --git a/components/settings/web-search-model-dialog.tsx b/components/settings/web-search-model-dialog.tsx new file mode 100644 index 000000000..272988326 --- /dev/null +++ b/components/settings/web-search-model-dialog.tsx @@ -0,0 +1,169 @@ +'use client'; + +import { useState, useCallback } from 'react'; +import { Label } from '@/components/ui/label'; +import { Input } from '@/components/ui/input'; +import { Button } from '@/components/ui/button'; +import { + Dialog, + DialogContent, + DialogHeader, + DialogTitle, + DialogDescription, + DialogFooter, +} from '@/components/ui/dialog'; +import { Zap, Loader2, CheckCircle2, XCircle } from 'lucide-react'; +import { cn } from '@/lib/utils'; +import { useI18n } from '@/lib/hooks/use-i18n'; + +interface WebSearchModelDialogProps { + open: boolean; + onOpenChange: (open: boolean) => void; + model: { id: string; name: string } | null; + setModel: (model: { id: string; name: string } | null) => void; + isModelValid: ( + providerId: string, + modelId: string, + ) => Promise<{ status: boolean; error?: string }>; + onSave: () => void; + isEditing: boolean; + apiKey?: string; + providerId: string; + isServerConfigured?: boolean; +} + +export function WebSearchModelDialog({ + open, + model, + isEditing, + providerId, + apiKey, + isServerConfigured, + onOpenChange, + setModel, + onSave, + isModelValid, +}: WebSearchModelDialogProps) { + const { t } = useI18n(); + const [testStatus, setTestStatus] = useState<'idle' | 'testing' | 'success' | 'error'>('idle'); + const [testMessage, setTestMessage] = useState(''); + + const canTest = !!(model?.id?.trim() && model?.name?.trim() && (apiKey || isServerConfigured)); + + const handleTest = useCallback(async () => { + if (!canTest || !model) return; + try { + setTestStatus('testing'); + setTestMessage(''); + const { status, error } = await isModelValid(providerId, model.id); + if (status) { + setTestStatus('success'); + setTestMessage(t('settings.connectionSuccess')); + } else { + setTestStatus('error'); + setTestMessage(error || t('settings.connectionFailed')); + } + } catch { + setTestStatus('error'); + setTestMessage(t('settings.connectionFailed')); + } + }, [canTest, model, providerId, t, isModelValid]); + + const handleOpenChange = (open: boolean) => { + if (!open) { + setTestStatus('idle'); + setTestMessage(''); + } + onOpenChange(open); + }; + + return ( + + + + + {isEditing + ? t('settings.webSearchEditModelTitle') + : t('settings.webSearchAddModelTitle')} + + {t('settings.webSearchModelDialogDesc')} + +
+
+ + { + setModel(model ? { ...model, id: e.target.value } : null); + setTestStatus('idle'); + setTestMessage(''); + }} + placeholder="claude-opus-4-6" + className="font-mono text-sm" + /> +
+
+ + { + setModel(model ? { ...model, name: e.target.value } : null); + setTestStatus('idle'); + setTestMessage(''); + }} + placeholder="Claude Opus 4.6" + className="text-sm" + /> +
+ + {/* Test connection */} +
+
+ + +
+ {testMessage && ( +
+
+ {testStatus === 'success' && } + {testStatus === 'error' && } +

{testMessage}

+
+
+ )} +
+
+ + + + +
+
+ ); +} diff --git a/components/settings/web-search-settings.tsx b/components/settings/web-search-settings.tsx index d5cf37761..0ee01c444 100644 --- a/components/settings/web-search-settings.tsx +++ b/components/settings/web-search-settings.tsx @@ -1,13 +1,28 @@ 'use client'; -import { useState } from 'react'; +import { useState, useCallback } from 'react'; import { Label } from '@/components/ui/label'; import { Input } from '@/components/ui/input'; +import { Button } from '@/components/ui/button'; import { useI18n } from '@/lib/hooks/use-i18n'; import { useSettingsStore } from '@/lib/store/settings'; import { WEB_SEARCH_PROVIDERS } from '@/lib/web-search/constants'; import type { WebSearchProviderId } from '@/lib/web-search/types'; -import { Eye, EyeOff } from 'lucide-react'; +import { + Eye, + EyeOff, + Trash2, + Settings2, + Plus, + Zap, + Loader2, + CheckCircle2, + XCircle, +} from 'lucide-react'; +import { cn } from '@/lib/utils'; +import { toast } from 'sonner'; +import { ToolEditDialog } from './tool-edit-dialog'; +import { WebSearchModelDialog } from './web-search-model-dialog'; interface WebSearchSettingsProps { selectedProviderId: WebSearchProviderId; @@ -16,6 +31,14 @@ interface WebSearchSettingsProps { export function WebSearchSettings({ selectedProviderId }: WebSearchSettingsProps) { const { t } = useI18n(); const [showApiKey, setShowApiKey] = useState(false); + const [isToolDialogOpen, setIsToolDialogOpen] = useState(false); + const [editingTool, setEditingTool] = useState<{ type: string; name: string } | null>(null); + const [editingToolIndex, setEditingToolIndex] = useState(null); + const [isModelDialogOpen, setIsModelDialogOpen] = useState(false); + const [editingModel, setEditingModel] = useState<{ id: string; name: string } | null>(null); + const [editingModelIndex, setEditingModelIndex] = useState(null); + const [testStatus, setTestStatus] = useState<'idle' | 'testing' | 'success' | 'error'>('idle'); + const [testMessage, setTestMessage] = useState(''); const webSearchProvidersConfig = useSettingsStore((state) => state.webSearchProvidersConfig); const setWebSearchProviderConfig = useSettingsStore((state) => state.setWebSearchProviderConfig); @@ -23,29 +46,185 @@ export function WebSearchSettings({ selectedProviderId }: WebSearchSettingsProps const provider = WEB_SEARCH_PROVIDERS[selectedProviderId]; const isServerConfigured = !!webSearchProvidersConfig[selectedProviderId]?.isServerConfigured; - // Reset showApiKey when provider changes (derived state pattern) - const [prevSelectedProviderId, setPrevSelectedProviderId] = useState(selectedProviderId); - if (selectedProviderId !== prevSelectedProviderId) { - setPrevSelectedProviderId(selectedProviderId); - setShowApiKey(false); + const isModelValid = useCallback( + async (providerId: string, modelId?: string): Promise<{ status: boolean; error?: string }> => { + const config = webSearchProvidersConfig[selectedProviderId]; + const apiKey = config?.apiKey || ''; + const baseUrl = + config?.baseUrl || WEB_SEARCH_PROVIDERS[selectedProviderId]?.defaultBaseUrl || ''; + + switch (providerId) { + case 'tavily': { + const response = await fetch('/api/web-search', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: 'test connection', + baseUrl, + apiKey, + providerId: providerId, + providerConfig: { baseUrl: baseUrl || undefined }, + }), + }); + const data = await response.json(); + return Promise.resolve({ status: data.success || response.ok, error: data.error }); + } + case 'claude': { + // Use verify-model endpoint with the selected (or default) model + const model = modelId || config?.modelId || 'claude-haiku-4-5'; + const response = await fetch('/api/web-search', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + query: 'test connection', + apiKey, + providerType: 'anthropic', + providerId: providerId, + requiresApiKey: !isServerConfigured, + providerConfig: { + baseUrl, + modelId: model, + tools: [], + }, + }), + }); + const data = await response.json(); + return Promise.resolve({ status: data.success, error: data.error }); + } + default: { + return Promise.reject({ status: false }); + } + } + }, + [webSearchProvidersConfig, isServerConfigured, selectedProviderId], + ); + + const handleTestConnection = useCallback(async () => { + try { + setTestStatus('testing'); + setTestMessage(''); + const { status, error } = await isModelValid(selectedProviderId); + if (status) { + setTestStatus('success'); + setTestMessage(t('settings.connectionSuccess')); + } else { + setTestStatus('error'); + setTestMessage(error || t('settings.connectionFailed')); + } + } catch { + setTestStatus('error'); + setTestMessage(t('settings.connectionFailed')); + } + }, [selectedProviderId, isModelValid, t]); + + // Guard against undefined provider + if (!provider) { + return null; } + const tools = webSearchProvidersConfig[selectedProviderId]?.tools || []; + const models = webSearchProvidersConfig[selectedProviderId]?.models || []; + + const handleAddTool = () => { + setEditingTool({ type: '', name: '' }); + setEditingToolIndex(null); + setIsToolDialogOpen(true); + }; + + const handleEditTool = (tool: { type: string; name: string }, index: number) => { + setEditingTool({ ...tool }); + setEditingToolIndex(index); + setIsToolDialogOpen(true); + }; + + const handleSaveTool = () => { + if (!editingTool) return; + + const newTools = [...tools]; + if (editingToolIndex !== null) { + newTools[editingToolIndex] = { type: editingTool.type.trim(), name: editingTool.name.trim() }; + } else { + newTools.push({ type: editingTool.type.trim(), name: editingTool.name.trim() }); + } + setWebSearchProviderConfig(selectedProviderId, { tools: newTools }); + setIsToolDialogOpen(false); + setEditingTool(null); + setEditingToolIndex(null); + }; + + const handleDeleteTool = (index: number) => { + const newTools = tools.filter((_, i) => i !== index); + setWebSearchProviderConfig(selectedProviderId, { tools: newTools }); + }; + + const handleAddModel = () => { + setEditingModel({ id: '', name: '' }); + setEditingModelIndex(null); + setIsModelDialogOpen(true); + }; + + const handleEditModel = (model: { id: string; name: string }, index: number) => { + setEditingModel({ ...model }); + setEditingModelIndex(index); + setIsModelDialogOpen(true); + }; + + const handleSaveModel = () => { + if (!editingModel) return; + + const trimmedId = editingModel.id.trim(); + const trimmedName = editingModel.name.trim(); + const isDuplicate = models.some((m, i) => m.id === trimmedId && i !== editingModelIndex); + if (isDuplicate) { + toast.error(t('settings.webSearchModelIdDuplicate')); + return; + } + + const newModels = [...models]; + if (editingModelIndex !== null) { + newModels[editingModelIndex] = { id: trimmedId, name: trimmedName }; + } else { + newModels.push({ id: trimmedId, name: trimmedName }); + } + setWebSearchProviderConfig(selectedProviderId, { models: newModels }); + setIsModelDialogOpen(false); + setEditingModel(null); + setEditingModelIndex(null); + }; + + const handleDeleteModel = (index: number) => { + const newModels = models.filter((_, i) => i !== index); + setWebSearchProviderConfig(selectedProviderId, { models: newModels }); + }; + + const getApiKeyLabel = (selectedProviderId: string) => { + switch (selectedProviderId) { + case 'claude': { + return t('settings.webSearchClaudeApiKey'); + } + case 'tavily': { + return t('settings.webSearchTavilyApiKey'); + } + default: + return t('settings.webSearchApiKey'); + } + }; + return (
- {/* Server-configured notice */} {isServerConfigured && (
{t('settings.serverConfiguredNotice')}
)} - {/* API Key + Base URL Configuration */} {(provider.requiresApiKey || isServerConfigured) && ( <> -
-
- -
+ {/* API Key */} +
+ +
+
: }
-

{t('settings.webSearchApiKeyHint')}

-
- -
- - - setWebSearchProviderConfig(selectedProviderId, { - baseUrl: e.target.value, - }) +
+ {testMessage && ( +
+
+ {testStatus === 'success' && } + {testStatus === 'error' && } +

{testMessage}

+
+
+ )} +

{t('settings.webSearchApiKeyHint')}

+
+ + {/* Base URL */} +
+ + + setWebSearchProviderConfig(selectedProviderId, { + baseUrl: e.target.value, + }) + } + className="text-sm" + /> + {(() => { + const effectiveBaseUrl = + webSearchProvidersConfig[selectedProviderId]?.baseUrl || + provider.defaultBaseUrl || + ''; + if (!effectiveBaseUrl) return null; + const endpointPath = WEB_SEARCH_PROVIDERS[selectedProviderId]?.path || ''; + return ( +

+ {t('settings.requestUrl')}: {effectiveBaseUrl} + {endpointPath} +

+ ); + })()}
- {/* Request URL Preview */} - {(() => { - const effectiveBaseUrl = - webSearchProvidersConfig[selectedProviderId]?.baseUrl || - provider.defaultBaseUrl || - ''; - if (!effectiveBaseUrl) return null; - const fullUrl = effectiveBaseUrl + '/search'; - return ( -

- {t('settings.requestUrl')}: {fullUrl} -

- ); - })()} + {selectedProviderId === 'claude' && ( +
+
+
+ + +
+
+ {models.length === 0 ? ( +

+ {t('settings.webSearchNoModels')} {t('settings.webSearchNoModelsHint')} +

+ ) : ( + models.map((model, index) => ( +
+
+
{model.name}
+
{model.id}
+
+
+ + +
+
+ )) + )} +
+
+ +
+
+ + +
+
+ {tools.length === 0 ? ( +

+ {t('settings.webSearchNoTools')} +

+ ) : ( + tools.map((tool, index) => ( +
+
+ {tool.name} + + {tool.type} + +
+
+ + +
+
+ )) + )} +
+
+
+ )} )} + + + +
); } diff --git a/lib/i18n/locales/ar-SA.json b/lib/i18n/locales/ar-SA.json index cb2eca2aa..02594bd50 100644 --- a/lib/i18n/locales/ar-SA.json +++ b/lib/i18n/locales/ar-SA.json @@ -884,7 +884,29 @@ "webSearchApiKeyHint": "احصل على مفتاح API من tavily.com للبحث في الإنترنت", "webSearchBaseUrl": "العنوان الأساسي", "webSearchServerConfigured": "تم تكوين مفتاح API لـ Tavily على الخادم", - "optional": "اختياري" + "optional": "اختياري", + "webSearchTavilyApiKey": "مفتاح Tavily API", + "webSearchClaudeApiKey": "مفتاح Claude API", + "webSearchModelId": "معرّف النموذج", + "webSearchApiVersion": "إصدار API", + "webSearchToolsConfiguration": "الأدوات", + "webSearchNewTool": "أداة جديدة", + "webSearchNoTools": "لا توجد أدوات مُعدّة.", + "webSearchAddToolTitle": "إضافة أداة", + "webSearchEditToolTitle": "تعديل الأداة", + "webSearchToolDialogDesc": "قم بتكوين نوع الأداة واسمها لبحث Claude.", + "webSearchToolType": "النوع", + "webSearchToolName": "الاسم", + "webSearchModelsConfiguration": "النماذج", + "webSearchNewModel": "نموذج جديد", + "webSearchNoModels": "لا توجد نماذج مُعدّة.", + "webSearchNoModelsHint": "أضف نموذجًا واحدًا على الأقل لاستخدام بحث Claude.", + "webSearchAddModelTitle": "إضافة نموذج", + "webSearchEditModelTitle": "تعديل النموذج", + "webSearchModelDialogDesc": "قم بتكوين معرّف النموذج واسمه لبحث Claude.", + "webSearchModelIdField": "معرّف النموذج", + "webSearchModelNameField": "اسم النموذج", + "webSearchModelIdDuplicate": "يوجد نموذج بهذا المعرّف بالفعل" }, "profile": { "title": "الملف الشخصي", diff --git a/lib/i18n/locales/en-US.json b/lib/i18n/locales/en-US.json index c110806fb..50f43cb39 100644 --- a/lib/i18n/locales/en-US.json +++ b/lib/i18n/locales/en-US.json @@ -16,7 +16,7 @@ "webSearchOn": "Enabled", "webSearchOff": "Click to enable", "webSearchDesc": "Search the web for up-to-date information before generation", - "webSearchProvider": "Search engine", + "webSearchProvider": "Web Search", "webSearchNoProvider": "Configure search API key in Settings", "selectProvider": "Select provider", "configureProvider": "Set up model", @@ -879,13 +879,35 @@ "clearCacheSuccess": "Cache cleared, page will refresh shortly", "clearCacheFailed": "Failed to clear cache, please try again", "webSearchSettings": "Web Search", - "webSearchApiKey": "Tavily API Key", - "webSearchApiKeyPlaceholder": "Enter your Tavily API Key", + "webSearchApiKey": "API Key", + "webSearchApiKeyPlaceholder": "Enter API Key", "webSearchApiKeyPlaceholderServer": "Server key configured, optionally override", - "webSearchApiKeyHint": "Get an API key from tavily.com for web search", + "webSearchApiKeyHint": "Enter the API key provided by the search service", "webSearchBaseUrl": "Base URL", - "webSearchServerConfigured": "Server-side Tavily API key is configured", - "optional": "Optional" + "webSearchServerConfigured": "Server-side API key is configured", + "optional": "Optional", + "webSearchTavilyApiKey": "Tavily API Key", + "webSearchClaudeApiKey": "Claude API Key", + "webSearchModelId": "Model ID", + "webSearchApiVersion": "API Version", + "webSearchToolsConfiguration": "Tools", + "webSearchNewTool": "New Tool", + "webSearchNoTools": "No tools configured.", + "webSearchAddToolTitle": "Add Tool", + "webSearchEditToolTitle": "Edit Tool", + "webSearchToolDialogDesc": "Configure the tool's type and name for Claude search.", + "webSearchToolType": "Type", + "webSearchToolName": "Name", + "webSearchModelsConfiguration": "Models", + "webSearchNewModel": "New Model", + "webSearchNoModels": "No models configured.", + "webSearchNoModelsHint": "Add at least one model to use Claude search.", + "webSearchAddModelTitle": "Add Model", + "webSearchEditModelTitle": "Edit Model", + "webSearchModelDialogDesc": "Configure the model's ID and name for Claude search.", + "webSearchModelIdField": "Model ID", + "webSearchModelNameField": "Model Name", + "webSearchModelIdDuplicate": "A model with this ID already exists" }, "profile": { "title": "Profile", diff --git a/lib/i18n/locales/ja-JP.json b/lib/i18n/locales/ja-JP.json index 70a353f18..650f5d28e 100644 --- a/lib/i18n/locales/ja-JP.json +++ b/lib/i18n/locales/ja-JP.json @@ -879,13 +879,35 @@ "clearCacheSuccess": "キャッシュをクリアしました。まもなくページが更新されます", "clearCacheFailed": "キャッシュのクリアに失敗しました。もう一度お試しください", "webSearchSettings": "ウェブ検索", - "webSearchApiKey": "Tavily APIキー", - "webSearchApiKeyPlaceholder": "Tavily APIキーを入力", + "webSearchApiKey": "APIキー", + "webSearchApiKeyPlaceholder": "APIキーを入力", "webSearchApiKeyPlaceholderServer": "サーバーキー設定済み、任意で上書き", - "webSearchApiKeyHint": "ウェブ検索用のAPIキーをtavily.comで取得してください", + "webSearchApiKeyHint": "検索サービスが提供するAPIキーを入力してください", "webSearchBaseUrl": "ベースURL", - "webSearchServerConfigured": "サーバー側でTavily APIキーが設定済みです", - "optional": "任意" + "webSearchServerConfigured": "サーバー側のAPIキーが設定済みです", + "optional": "任意", + "webSearchTavilyApiKey": "Tavily APIキー", + "webSearchClaudeApiKey": "Claude APIキー", + "webSearchModelId": "モデルID", + "webSearchApiVersion": "APIバージョン", + "webSearchToolsConfiguration": "ツール", + "webSearchNewTool": "新規ツール", + "webSearchNoTools": "ツールが設定されていません。", + "webSearchAddToolTitle": "ツールを追加", + "webSearchEditToolTitle": "ツールを編集", + "webSearchToolDialogDesc": "Claude検索用のツールタイプと名前を設定します。", + "webSearchToolType": "タイプ", + "webSearchToolName": "名前", + "webSearchModelsConfiguration": "モデル", + "webSearchNewModel": "新規モデル", + "webSearchNoModels": "モデルが設定されていません。", + "webSearchNoModelsHint": "Claude検索を使用するには、少なくとも1つのモデルを追加してください。", + "webSearchAddModelTitle": "モデルを追加", + "webSearchEditModelTitle": "モデルを編集", + "webSearchModelDialogDesc": "Claude検索用のモデルIDと名前を設定します。", + "webSearchModelIdField": "モデルID", + "webSearchModelNameField": "モデル名", + "webSearchModelIdDuplicate": "このIDのモデルはすでに存在します" }, "profile": { "title": "プロフィール", diff --git a/lib/i18n/locales/ru-RU.json b/lib/i18n/locales/ru-RU.json index c6a0291b5..c7fd8608e 100644 --- a/lib/i18n/locales/ru-RU.json +++ b/lib/i18n/locales/ru-RU.json @@ -879,13 +879,35 @@ "clearCacheSuccess": "Кэш очищен, страница скоро обновится", "clearCacheFailed": "Не удалось очистить кэш, попробуйте снова", "webSearchSettings": "Веб-поиск", - "webSearchApiKey": "Tavily API-ключ", - "webSearchApiKeyPlaceholder": "Введите ваш Tavily API-ключ", + "webSearchApiKey": "API-ключ", + "webSearchApiKeyPlaceholder": "Введите API-ключ", "webSearchApiKeyPlaceholderServer": "Серверный ключ настроен, можно ввести свой", - "webSearchApiKeyHint": "Получите API-ключ на tavily.com для веб-поиска", + "webSearchApiKeyHint": "Введите API-ключ, предоставленный сервисом поиска", "webSearchBaseUrl": "Base URL", - "webSearchServerConfigured": "Серверный Tavily API-ключ настроен", - "optional": "Необязательно" + "webSearchServerConfigured": "API-ключ настроен на стороне сервера", + "optional": "Необязательно", + "webSearchTavilyApiKey": "Tavily API-ключ", + "webSearchClaudeApiKey": "Claude API-ключ", + "webSearchModelId": "ID модели", + "webSearchApiVersion": "Версия API", + "webSearchToolsConfiguration": "Инструменты", + "webSearchNewTool": "Новый инструмент", + "webSearchNoTools": "Инструменты не настроены.", + "webSearchAddToolTitle": "Добавить инструмент", + "webSearchEditToolTitle": "Редактировать инструмент", + "webSearchToolDialogDesc": "Настройте тип и название инструмента для поиска Claude.", + "webSearchToolType": "Тип", + "webSearchToolName": "Название", + "webSearchModelsConfiguration": "Модели", + "webSearchNewModel": "Новая модель", + "webSearchNoModels": "Модели не настроены.", + "webSearchNoModelsHint": "Добавьте хотя бы одну модель для использования поиска Claude.", + "webSearchAddModelTitle": "Добавить модель", + "webSearchEditModelTitle": "Редактировать модель", + "webSearchModelDialogDesc": "Настройте ID и название модели для поиска Claude.", + "webSearchModelIdField": "ID модели", + "webSearchModelNameField": "Название модели", + "webSearchModelIdDuplicate": "Модель с таким ID уже существует" }, "profile": { "title": "Профиль", diff --git a/lib/i18n/locales/zh-CN.json b/lib/i18n/locales/zh-CN.json index 1c114c833..e8faa37c9 100644 --- a/lib/i18n/locales/zh-CN.json +++ b/lib/i18n/locales/zh-CN.json @@ -879,13 +879,35 @@ "clearCacheSuccess": "缓存已清空,页面即将刷新", "clearCacheFailed": "清空缓存失败,请重试", "webSearchSettings": "网络搜索", - "webSearchApiKey": "Tavily API Key", - "webSearchApiKeyPlaceholder": "输入你的 Tavily API Key", + "webSearchApiKey": "API 密钥", + "webSearchApiKeyPlaceholder": "输入 API Key", "webSearchApiKeyPlaceholderServer": "已配置服务端密钥,可选填覆盖", - "webSearchApiKeyHint": "从 tavily.com 获取 API Key,用于网络搜索", + "webSearchApiKeyHint": "请输入相应服务商提供的 API Key,用于网络搜索", "webSearchBaseUrl": "Base URL", - "webSearchServerConfigured": "服务端已配置 Tavily API Key", - "optional": "可选" + "webSearchServerConfigured": "服务端已配置 API 密钥", + "optional": "可选", + "webSearchTavilyApiKey": "Tavily API 密钥", + "webSearchClaudeApiKey": "Claude API 密钥", + "webSearchModelId": "模型 ID", + "webSearchApiVersion": "API 版本", + "webSearchToolsConfiguration": "工具", + "webSearchNewTool": "新建工具", + "webSearchNoTools": "暂无已配置的工具", + "webSearchAddToolTitle": "添加工具", + "webSearchEditToolTitle": "编辑工具", + "webSearchToolDialogDesc": "配置 Claude 搜索工具的类型和名称", + "webSearchToolType": "类型", + "webSearchToolName": "名称", + "webSearchModelsConfiguration": "模型", + "webSearchNewModel": "新建模型", + "webSearchNoModels": "暂无已配置的模型。", + "webSearchNoModelsHint": "请添加至少一个模型才能使用 Claude 搜索。", + "webSearchAddModelTitle": "添加模型", + "webSearchEditModelTitle": "编辑模型", + "webSearchModelDialogDesc": "配置 Claude 搜索模型的 ID 和名称", + "webSearchModelIdField": "模型 ID", + "webSearchModelNameField": "模型名称", + "webSearchModelIdDuplicate": "该模型 ID 已存在" }, "profile": { "title": "个人资料", diff --git a/lib/server/api-response.ts b/lib/server/api-response.ts index 07d2b6d68..2aa1651f9 100644 --- a/lib/server/api-response.ts +++ b/lib/server/api-response.ts @@ -2,6 +2,7 @@ import { NextResponse } from 'next/server'; export const API_ERROR_CODES = { MISSING_REQUIRED_FIELD: 'MISSING_REQUIRED_FIELD', + MISSING_PROVIDER: 'MISSING_PROVIDER', MISSING_API_KEY: 'MISSING_API_KEY', INVALID_REQUEST: 'INVALID_REQUEST', INVALID_URL: 'INVALID_URL', diff --git a/lib/server/classroom-generation.ts b/lib/server/classroom-generation.ts index 48c644bb6..d4f88b5ee 100644 --- a/lib/server/classroom-generation.ts +++ b/lib/server/classroom-generation.ts @@ -20,6 +20,9 @@ import { resolveWebSearchApiKey } from '@/lib/server/provider-config'; import { resolveModel } from '@/lib/server/resolve-model'; import { buildSearchQuery } from '@/lib/server/search-query-builder'; import { searchWithTavily, formatSearchResultsAsContext } from '@/lib/web-search/tavily'; +import { searchWithClaude } from '@/lib/web-search/claude'; +import { WEB_SEARCH_PROVIDERS } from '@/lib/web-search/constants'; +import { validateUrlForSSRF } from '@/lib/server/ssrf-guard'; import { persistClassroom } from '@/lib/server/classroom-storage'; import { generateMediaForClassroom, @@ -36,6 +39,10 @@ export interface GenerateClassroomInput { requirement: string; pdfContent?: { text: string; images: string[] }; enableWebSearch?: boolean; + webSearchProviderId?: string; + webSearchBaseUrl?: string; + webSearchModelId?: string; + webSearchTools?: Array<{ type: string; name: string }>; enableImageGeneration?: boolean; enableVideoGeneration?: boolean; enableTTS?: boolean; @@ -234,31 +241,64 @@ export async function generateClassroom( // Web search (optional, graceful degradation) let researchContext: string | undefined; if (input.enableWebSearch) { - const tavilyKey = resolveWebSearchApiKey(); - if (tavilyKey) { - try { - const searchQuery = await buildSearchQuery(requirement, pdfText, searchQueryAiCall); - - log.info('Running web search for classroom generation', { - hasPdfContext: searchQuery.hasPdfContext, - rawRequirementLength: searchQuery.rawRequirementLength, - rewriteAttempted: searchQuery.rewriteAttempted, - finalQueryLength: searchQuery.finalQueryLength, - }); - - const searchResult = await searchWithTavily({ - query: searchQuery.query, - apiKey: tavilyKey, - }); - researchContext = formatSearchResultsAsContext(searchResult); - if (researchContext) { - log.info(`Web search returned ${searchResult.sources.length} sources`); + // Validate and resolve the provider ID; unknown values are treated as 'tavily' (safe default). + const rawProviderId = input.webSearchProviderId || 'tavily'; + const providerId = + rawProviderId in WEB_SEARCH_PROVIDERS + ? (rawProviderId as keyof typeof WEB_SEARCH_PROVIDERS) + : ('tavily' as const); + if (rawProviderId !== providerId) { + log.warn(`Unknown webSearchProviderId "${rawProviderId}", falling back to tavily`); + } + const searchKey = resolveWebSearchApiKey(providerId); + if (searchKey) { + const ssrfError = input.webSearchBaseUrl ? validateUrlForSSRF(input.webSearchBaseUrl) : null; + if (ssrfError) { + log.warn(`webSearchBaseUrl rejected by SSRF guard (${ssrfError}), skipping web search`); + } else { + try { + const searchQuery = await buildSearchQuery(requirement, pdfText, searchQueryAiCall); + + log.info('Running web search for classroom generation', { + provider: providerId, + hasPdfContext: searchQuery.hasPdfContext, + rawRequirementLength: searchQuery.rawRequirementLength, + rewriteAttempted: searchQuery.rewriteAttempted, + finalQueryLength: searchQuery.finalQueryLength, + }); + + const effectiveBaseUrl = + input.webSearchBaseUrl || WEB_SEARCH_PROVIDERS[providerId]?.defaultBaseUrl || ''; + + let searchResult; + if (providerId === 'claude') { + searchResult = await searchWithClaude({ + query: searchQuery.query, + apiKey: searchKey, + baseUrl: effectiveBaseUrl, + modelId: input.webSearchModelId, + tools: input.webSearchTools, + }); + } else { + searchResult = await searchWithTavily({ + query: searchQuery.query, + apiKey: searchKey, + baseUrl: effectiveBaseUrl, + }); + } + + researchContext = formatSearchResultsAsContext(searchResult); + if (researchContext) { + log.info(`Web search returned ${searchResult.sources.length} sources`); + } + } catch (e) { + log.warn('Web search failed, continuing without search context:', e); } - } catch (e) { - log.warn('Web search failed, continuing without search context:', e); } } else { - log.warn('enableWebSearch is true but no Tavily API key configured, skipping web search'); + log.warn( + `enableWebSearch is true but no API key configured for ${providerId}, skipping web search`, + ); } } diff --git a/lib/server/provider-config.ts b/lib/server/provider-config.ts index c66835912..0ebcc1eef 100644 --- a/lib/server/provider-config.ts +++ b/lib/server/provider-config.ts @@ -92,6 +92,9 @@ const VIDEO_ENV_MAP: Record = { const WEB_SEARCH_ENV_MAP: Record = { TAVILY: 'tavily', + CLAUDE: 'claude', + // Also recognise ANTHROPIC_API_KEY so server-config detection aligns with resolveWebSearchApiKey + ANTHROPIC: 'claude', }; // --------------------------------------------------------------------------- @@ -404,9 +407,9 @@ export function resolveVideoBaseUrl( // --------------------------------------------------------------------------- /** Returns server-configured web search providers (no apiKeys exposed) */ -export function getServerWebSearchProviders(): Record { +export function getServerWebSearchProviders(): Record { const cfg = getConfig(); - const result: Record = {}; + const result: Record = {}; for (const [id, entry] of Object.entries(cfg.webSearch)) { result[id] = {}; if (entry.baseUrl) result[id].baseUrl = entry.baseUrl; @@ -414,10 +417,13 @@ export function getServerWebSearchProviders(): Record server key > TAVILY_API_KEY env > empty */ -export function resolveWebSearchApiKey(clientKey?: string): string { +/** Resolve Web Search API key: client key > server key > env fallback > empty */ +export function resolveWebSearchApiKey(providerId: string, clientKey?: string): string { if (clientKey) return clientKey; - const serverKey = getConfig().webSearch.tavily?.apiKey; + const serverKey = getConfig().webSearch[providerId]?.apiKey; if (serverKey) return serverKey; - return process.env.TAVILY_API_KEY || ''; + // Claude web search reuses the standard Anthropic API key + const envVar = + providerId === 'claude' ? 'ANTHROPIC_API_KEY' : `${providerId.toUpperCase()}_API_KEY`; + return process.env[envVar] || ''; } diff --git a/lib/server/ssrf-guard.ts b/lib/server/ssrf-guard.ts index e40bb8142..a100b872f 100644 --- a/lib/server/ssrf-guard.ts +++ b/lib/server/ssrf-guard.ts @@ -166,7 +166,7 @@ export function isPrivateIP(ip: string): boolean { * Validate a URL against SSRF attacks. * Returns null if the URL is safe, or an error message string if blocked. */ -export async function validateUrlForSSRF(url: string): Promise { +export async function validateUrlForSSRF(url: string): Promise { let parsed: URL; try { parsed = new URL(url); @@ -181,7 +181,7 @@ export async function validateUrlForSSRF(url: string): Promise { // Self-hosted deployments can set ALLOW_LOCAL_NETWORKS=true to skip private-IP checks const allowLocal = process.env.ALLOW_LOCAL_NETWORKS; if (allowLocal === 'true' || allowLocal === '1') { - return null; + return ''; } const hostname = normalizeAddress(parsed.hostname); @@ -196,7 +196,7 @@ export async function validateUrlForSSRF(url: string): Promise { } if (isIP(hostname)) { - return null; + return ''; } let resolvedAddresses: Array<{ address: string; family: number }>; @@ -214,5 +214,5 @@ export async function validateUrlForSSRF(url: string): Promise { return 'Local/private network URLs are not allowed'; } - return null; + return ''; } diff --git a/lib/store/settings.ts b/lib/store/settings.ts index aaa1c3831..0e3f68031 100644 --- a/lib/store/settings.ts +++ b/lib/store/settings.ts @@ -133,7 +133,8 @@ export interface SettingsState { videoGenerationEnabled: boolean; // Web Search settings - webSearchProviderId: WebSearchProviderId; + webSearchEnabled: boolean; + webSearchProviderId: WebSearchProviderId | null; webSearchProvidersConfig: Record< WebSearchProviderId, { @@ -142,6 +143,9 @@ export interface SettingsState { enabled: boolean; isServerConfigured?: boolean; serverBaseUrl?: string; + modelId?: string; + tools?: Array<{ type: string; name: string }>; + models?: Array<{ id: string; name: string }>; } >; @@ -275,10 +279,18 @@ export interface SettingsState { setVideoGenerationEnabled: (enabled: boolean) => void; // Web Search actions - setWebSearchProvider: (providerId: WebSearchProviderId) => void; + setWebSearchEnabled: (enabled: boolean) => void; + setWebSearchProvider: (providerId: WebSearchProviderId | null) => void; setWebSearchProviderConfig: ( providerId: WebSearchProviderId, - config: Partial<{ apiKey: string; baseUrl: string; enabled: boolean }>, + config: Partial<{ + apiKey: string; + baseUrl: string; + enabled: boolean; + modelId: string; + tools: Array<{ type: string; name: string }>; + models: Array<{ id: string; name: string }>; + }>, ) => void; // Server provider actions @@ -371,10 +383,28 @@ const getDefaultVideoConfig = () => ({ // Initialize default Web Search config const getDefaultWebSearchConfig = () => ({ - webSearchProviderId: 'tavily' as WebSearchProviderId, + webSearchProviderId: null as WebSearchProviderId | null, webSearchProvidersConfig: { - tavily: { apiKey: '', baseUrl: '', enabled: true }, - } as Record, + tavily: { apiKey: '', baseUrl: WEB_SEARCH_PROVIDERS.tavily.defaultBaseUrl, enabled: true }, + claude: { + apiKey: '', + baseUrl: WEB_SEARCH_PROVIDERS.claude.defaultBaseUrl, + enabled: true, + modelId: '', + tools: [{ type: 'web_search_20260209', name: 'web_search' }], + models: WEB_SEARCH_PROVIDERS.claude?.models?.map((m) => ({ id: m.id, name: m.name })) ?? [], + }, + } as Record< + WebSearchProviderId, + { + apiKey: string; + baseUrl: string; + enabled: boolean; + modelId?: string; + tools?: Array<{ type: string; name: string }>; + models?: Array<{ id: string; name: string }>; + } + >, }); /** @@ -400,7 +430,10 @@ function ensureValidProviderSelections(state: Partial): void { state.pdfProviderId = defaultPdfConfig.pdfProviderId; } - if (!hasProviderId(WEB_SEARCH_PROVIDERS, state.webSearchProviderId)) { + if ( + state.webSearchProviderId !== null && + !hasProviderId(WEB_SEARCH_PROVIDERS, state.webSearchProviderId) + ) { state.webSearchProviderId = defaultWebSearchConfig.webSearchProviderId; } @@ -519,6 +552,30 @@ function ensureBuiltInVideoProviders(state: Partial): void { }); } +/** + * Ensure webSearchProvidersConfig includes all built-in web search providers. + * Called on every rehydrate so newly added providers appear automatically. + */ +function ensureBuiltInWebSearchProviders(state: Partial): void { + if (!state.webSearchProvidersConfig) return; + const defaultConfig = getDefaultWebSearchConfig().webSearchProvidersConfig; + Object.keys(defaultConfig).forEach((pid) => { + const providerId = pid as WebSearchProviderId; + if (!state.webSearchProvidersConfig![providerId]) { + state.webSearchProvidersConfig![providerId] = defaultConfig[providerId]; + } else { + const entry = state.webSearchProvidersConfig![providerId]; + if (!entry.baseUrl) { + entry.baseUrl = defaultConfig[providerId].baseUrl; + } + if (!entry.models?.length && defaultConfig[providerId]?.models?.length) { + // Initialize models from defaults if not yet set for this provider + entry.models = defaultConfig[providerId].models; + } + } + }); +} + // Migrate from old localStorage format const migrateFromOldStorage = () => { if (typeof window === 'undefined') return null; @@ -636,6 +693,9 @@ export const useSettingsStore = create()( imageGenerationEnabled: false, videoGenerationEnabled: false, + // Web search toggle (off by default) + webSearchEnabled: false, + // Audio feature toggles (on by default) ttsEnabled: true, asrEnabled: true, @@ -872,17 +932,80 @@ export const useSettingsStore = create()( }), // Web Search actions - setWebSearchProvider: (providerId) => set({ webSearchProviderId: providerId }), + setWebSearchEnabled: (enabled) => { + if (enabled) { + const state = get(); + const cfg = state.webSearchProvidersConfig; + const firstUsableProviderId = (Object.keys(cfg) as WebSearchProviderId[]).find( + (id) => cfg[id].isServerConfigured || cfg[id].apiKey, + ); + if (!firstUsableProviderId) return; + set({ webSearchEnabled: true }); + // Auto-select a provider when none is selected yet + if (!state.webSearchProviderId) { + get().setWebSearchProvider(firstUsableProviderId); + } + } else { + set({ webSearchEnabled: false }); + // Also deselect provider (which clears modelId per setWebSearchProvider logic) + get().setWebSearchProvider(null); + } + }, + setWebSearchProvider: (providerId) => + set((state) => { + if (providerId !== null) return { webSearchProviderId: providerId }; + // Deselect: clear modelId for the previously selected provider + const prev = state.webSearchProviderId; + if (!prev || !state.webSearchProvidersConfig[prev]?.modelId) { + return { webSearchProviderId: null }; + } + return { + webSearchProviderId: null, + webSearchProvidersConfig: { + ...state.webSearchProvidersConfig, + [prev]: { ...state.webSearchProvidersConfig[prev], modelId: '' }, + }, + }; + }), setWebSearchProviderConfig: (providerId, config) => - set((state) => ({ - webSearchProvidersConfig: { - ...state.webSearchProvidersConfig, - [providerId]: { - ...state.webSearchProvidersConfig[providerId], - ...config, + set((state) => { + const updatedProviderConfig = { + ...state.webSearchProvidersConfig[providerId], + ...config, + }; + const apiKeyRemoved = + 'apiKey' in config && !config.apiKey && !updatedProviderConfig.isServerConfigured; + const isSelected = state.webSearchProviderId === providerId; + + // When the selected provider loses its key, try to switch to another usable provider + // or disable web search entirely + let extraUpdates: Record = {}; + if (apiKeyRemoved && isSelected) { + const updatedConfig = { + ...state.webSearchProvidersConfig, + [providerId]: updatedProviderConfig, + }; + const otherUsableProviderId = ( + Object.keys(updatedConfig) as WebSearchProviderId[] + ).find( + (id) => + id !== providerId && + (updatedConfig[id].isServerConfigured || updatedConfig[id].apiKey), + ); + extraUpdates = { + webSearchProviderId: otherUsableProviderId ?? null, + ...(otherUsableProviderId ? {} : { webSearchEnabled: false }), + }; + } + + return { + webSearchProvidersConfig: { + ...state.webSearchProvidersConfig, + [providerId]: updatedProviderConfig, }, - }, - })), + ...extraUpdates, + }; + }), // Fetch server-configured providers and merge into local state fetchServerProviders: async () => { @@ -1354,9 +1477,10 @@ export const useSettingsStore = create()( ensureBuiltInProviders(state); promoteLegacyCustomProviderBaseUrls(state); - // Ensure image/video configs have all built-in providers + // Ensure image/video/web-search configs have all built-in providers ensureBuiltInImageProviders(state); ensureBuiltInVideoProviders(state); + ensureBuiltInWebSearchProviders(state); // Migrate from old ttsModel to new ttsProviderId if (state.ttsModel && !state.ttsProviderId) { @@ -1464,6 +1588,10 @@ export const useSettingsStore = create()( const oldIsServerConfigured = (stateRecord.webSearchIsServerConfigured as boolean) || false; state.webSearchProviderId = 'tavily' as WebSearchProviderId; + // Enable web search if old user had a configured provider + if (oldApiKey || oldIsServerConfigured) { + state.webSearchEnabled = true; + } state.webSearchProvidersConfig = { tavily: { apiKey: oldApiKey, @@ -1471,6 +1599,13 @@ export const useSettingsStore = create()( enabled: true, isServerConfigured: oldIsServerConfigured, }, + claude: { + apiKey: '', + baseUrl: '', + enabled: true, + modelId: '', + tools: [{ type: 'web_search_20260209', name: 'web_search' }], + }, } as SettingsState['webSearchProvidersConfig']; delete stateRecord.webSearchApiKey; delete stateRecord.webSearchIsServerConfigured; @@ -1488,6 +1623,7 @@ export const useSettingsStore = create()( promoteLegacyCustomProviderBaseUrls(merged as Partial); ensureBuiltInImageProviders(merged as Partial); ensureBuiltInVideoProviders(merged as Partial); + ensureBuiltInWebSearchProviders(merged as Partial); ensureValidProviderSelections(merged as Partial); return merged as SettingsState; }, diff --git a/lib/types/web-search.ts b/lib/types/web-search.ts index ba2624b86..10729b552 100644 --- a/lib/types/web-search.ts +++ b/lib/types/web-search.ts @@ -2,7 +2,7 @@ export interface WebSearchSource { title: string; url: string; content: string; - score: number; + score?: number; } export interface WebSearchResult { diff --git a/lib/web-search/claude.ts b/lib/web-search/claude.ts new file mode 100644 index 000000000..68a81f24d --- /dev/null +++ b/lib/web-search/claude.ts @@ -0,0 +1,166 @@ +/** + * Claude Web Search Integration + * + * Uses the AI SDK Anthropic provider with the native web_search_20260209 tool. + */ + +import { generateText } from 'ai'; +import { createAnthropic, type AnthropicProvider } from '@ai-sdk/anthropic'; +import { proxyFetch } from '@/lib/server/proxy-fetch'; +import { validateUrlForSSRF } from '@/lib/server/ssrf-guard'; +import { createLogger } from '@/lib/logger'; +import type { WebSearchResult, WebSearchSource } from '@/lib/types/web-search'; + +type ToolDef = { type: string; name: string }; + +const DEFAULT_TOOLS: ToolDef[] = [{ type: 'web_search_20260209', name: 'web_search' }]; + +function buildTools(provider: AnthropicProvider, configuredTools?: ToolDef[]) { + const defs = configuredTools?.length ? configuredTools : DEFAULT_TOOLS; + return Object.fromEntries( + defs.map((t) => + t.type === 'web_search_20250305' + ? [t.name, provider.tools.webSearch_20250305()] + : [t.name, provider.tools.webSearch_20260209()], + ), + ); +} + +/** + * Wraps proxyFetch to inject `allowed_callers: ["direct"]` on every tool in outgoing + * Anthropic API requests. The AI SDK's provider-defined tool serialisation hard-codes the + * tool object and never emits `allowed_callers`, so we must patch it at the fetch layer. + */ +async function fetchWithAllowedCallers(url: string, init?: RequestInit): Promise { + if (init?.method === 'POST' && typeof init.body === 'string') { + try { + const body = JSON.parse(init.body); + if (Array.isArray(body?.tools)) { + const before = body.tools.map((t: Record) => t.allowed_callers); + body.tools = body.tools.map((tool: Record) => + tool.allowed_callers ? tool : { ...tool, allowed_callers: ['direct'] }, + ); + const after = body.tools.map((t: Record) => t.allowed_callers); + log.debug( + `fetchWithAllowedCallers: injecting allowed_callers url: ${url}, before: ${before}, after: ${after}`, + ); + init = { ...init, body: JSON.stringify(body) }; + } else { + log.debug(`fetchWithAllowedCallers: POST to ${url} — no tools array in body`); + } + log.debug(`final payload: ${JSON.stringify(body)}`); + } catch { + /* leave body unchanged if it can't be parsed */ + } + } else { + log.info( + `fetchWithAllowedCallers: called [method=${init?.method} bodyType=${typeof init?.body}]`, + ); + } + return proxyFetch(url, init); +} + +const PAGE_CONTENT_MAX_LENGTH = 2000; +const PAGE_FETCH_TIMEOUT_MS = 5000; + +const log = createLogger('ClaudeSearch'); + +/** Fetch a URL and return plain text extracted from its HTML. Returns empty string on any failure. */ +async function fetchPageContent(url: string): Promise { + const ssrfError = await validateUrlForSSRF(url); + if (ssrfError) { + log.warn(`Blocked page fetch due to SSRF check [url="${url}" reason="${ssrfError}"]`); + return ''; + } + log.info(`Fetching page content: ${url}`); + try { + const res = await proxyFetch(url, { + headers: { Accept: 'text/html', 'User-Agent': 'Mozilla/5.0 (compatible; OpenMAIC/1.0)' }, + signal: AbortSignal.timeout(PAGE_FETCH_TIMEOUT_MS), + }); + if (!res.ok) { + log.warn(`Failed to fetch page content [url="${url}" status=${res.status}]`); + return ''; + } + const html = await res.text(); + // Strip scripts, styles, and all tags; collapse whitespace + const text = html + .replace(//gi, ' ') + .replace(//gi, ' ') + .replace(/<[^>]+>/g, ' ') + .replace(/\s+/g, ' ') + .trim(); + const content = text.slice(0, PAGE_CONTENT_MAX_LENGTH); + log.info(`Fetched page content [url="${url}" chars=${content.length}]`); + return content; + } catch (e) { + log.warn(`Error fetching page content [url="${url}"]:`, e); + return ''; + } +} + +/** + * Search the web using Claude's native web search tool via the AI SDK. + */ +export async function searchWithClaude(params: { + query: string; + apiKey: string; + modelId?: string; + baseUrl: string; + tools?: ToolDef[]; +}): Promise { + const { query, apiKey, modelId: rawModelId, baseUrl, tools } = params; + const modelId = rawModelId?.trim() || 'claude-sonnet-4-6'; + + const provider = createAnthropic({ + apiKey, + baseURL: baseUrl, + fetch: fetchWithAllowedCallers as typeof fetch, + }); + + try { + const startTime = Date.now(); + + const result = await generateText({ + model: provider(modelId), + messages: [ + { + role: 'user', + content: `Search for the following and provide a comprehensive summary with source links: ${query}.`, + }, + ], + maxOutputTokens: 4096, + tools: buildTools(provider, tools), + }); + + // The AI SDK surfaces web search results as sources (url + title only; no snippet content). + // We fetch each page to populate content, then drop any that fail. + const sources: WebSearchSource[] = result.sources.flatMap((s) => { + if (s.sourceType !== 'url') return []; + return [{ url: s.url, title: s.title || s.url, content: '' }]; + }); + + await Promise.all( + sources.map(async (s) => { + s.content = await fetchPageContent(s.url); + }), + ); + + const sourcesWithContent = sources.filter((s) => s.content); + + return { + answer: result.text, + sources: sourcesWithContent, + query, + responseTime: Date.now() - startTime, + }; + } catch (e) { + log.error('Claude search failed', e); + throw e; + } +} + +/** + * Reuse formatting logic from Tavily. + */ +export { formatSearchResultsAsContext } from './tavily'; diff --git a/lib/web-search/constants.ts b/lib/web-search/constants.ts index 6542bbb2a..7c56fe274 100644 --- a/lib/web-search/constants.ts +++ b/lib/web-search/constants.ts @@ -13,6 +13,23 @@ export const WEB_SEARCH_PROVIDERS: Record { - const { query, apiKey, maxResults = 5 } = params; + const { query, apiKey, baseUrl, maxResults = 5 } = params; // Tavily rejects queries over 400 characters with a 400 error const truncatedQuery = query.slice(0, TAVILY_MAX_QUERY_LENGTH); - const res = await proxyFetch(TAVILY_API_URL, { + const res = await proxyFetch(`${baseUrl}/search`, { method: 'POST', headers: { 'Content-Type': 'application/json', diff --git a/lib/web-search/types.ts b/lib/web-search/types.ts index f83822c7c..f3f0e787f 100644 --- a/lib/web-search/types.ts +++ b/lib/web-search/types.ts @@ -5,15 +5,22 @@ /** * Web Search Provider IDs */ -export type WebSearchProviderId = 'tavily'; +export type WebSearchProviderId = 'tavily' | 'claude'; /** * Web Search Provider Configuration */ +export interface WebSearchModel { + id: string; + name: string; +} + export interface WebSearchProviderConfig { id: WebSearchProviderId; name: string; requiresApiKey: boolean; defaultBaseUrl?: string; icon?: string; + path?: string; + models?: WebSearchModel[]; } diff --git a/public/logos/tavily.svg b/public/logos/tavily.svg new file mode 100644 index 000000000..08e491352 --- /dev/null +++ b/public/logos/tavily.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/tests/server/provider-config.test.ts b/tests/server/provider-config.test.ts index 58d05c942..cd6a64f7f 100644 --- a/tests/server/provider-config.test.ts +++ b/tests/server/provider-config.test.ts @@ -155,15 +155,31 @@ providers: }); describe('resolveWebSearchApiKey', () => { - it('returns client key first', async () => { + it('returns client key first for tavily', async () => { const { resolveWebSearchApiKey } = await import('@/lib/server/provider-config'); - expect(resolveWebSearchApiKey('client-key')).toBe('client-key'); + expect(resolveWebSearchApiKey('tavily', 'client-key')).toBe('client-key'); }); it('falls back to TAVILY_API_KEY env var', async () => { vi.stubEnv('TAVILY_API_KEY', 'tvly-bare-env'); const { resolveWebSearchApiKey } = await import('@/lib/server/provider-config'); - expect(resolveWebSearchApiKey()).toBe('tvly-bare-env'); + expect(resolveWebSearchApiKey('tavily')).toBe('tvly-bare-env'); + }); + + it('returns client key first for claude', async () => { + const { resolveWebSearchApiKey } = await import('@/lib/server/provider-config'); + expect(resolveWebSearchApiKey('claude', 'sk-client')).toBe('sk-client'); + }); + + it('falls back to ANTHROPIC_API_KEY env var for claude', async () => { + vi.stubEnv('ANTHROPIC_API_KEY', 'sk-anthropic-env'); + const { resolveWebSearchApiKey } = await import('@/lib/server/provider-config'); + expect(resolveWebSearchApiKey('claude')).toBe('sk-anthropic-env'); + }); + + it('returns empty string when no key configured', async () => { + const { resolveWebSearchApiKey } = await import('@/lib/server/provider-config'); + expect(resolveWebSearchApiKey('tavily')).toBe(''); }); }); diff --git a/tests/server/ssrf-guard.test.ts b/tests/server/ssrf-guard.test.ts index 9aa95d813..108b3a926 100644 --- a/tests/server/ssrf-guard.test.ts +++ b/tests/server/ssrf-guard.test.ts @@ -21,21 +21,21 @@ describe('validateUrlForSSRF', () => { const { validateUrlForSSRF } = await import('@/lib/server/ssrf-guard'); - await expect(validateUrlForSSRF('https://api.openai.com')).resolves.toBeNull(); + await expect(validateUrlForSSRF('https://api.openai.com')).resolves.toBe(''); expect(lookupMock).toHaveBeenCalledWith('api.openai.com', { all: true, verbatim: true }); }); it('allows a public IP literal without DNS lookup', async () => { const { validateUrlForSSRF } = await import('@/lib/server/ssrf-guard'); - await expect(validateUrlForSSRF('https://8.8.8.8')).resolves.toBeNull(); + await expect(validateUrlForSSRF('https://8.8.8.8')).resolves.toBe(''); expect(lookupMock).not.toHaveBeenCalled(); }); it('allows a public IPv6 literal without DNS lookup', async () => { const { validateUrlForSSRF } = await import('@/lib/server/ssrf-guard'); - await expect(validateUrlForSSRF('https://[2606:4700:4700::1111]')).resolves.toBeNull(); + await expect(validateUrlForSSRF('https://[2606:4700:4700::1111]')).resolves.toBe(''); expect(lookupMock).not.toHaveBeenCalled(); }); @@ -126,7 +126,7 @@ describe('validateUrlForSSRF', () => { const { validateUrlForSSRF } = await import('@/lib/server/ssrf-guard'); // 2002:0808:0808:: embeds 8.8.8.8 - await expect(validateUrlForSSRF('http://[2002:0808:0808::]')).resolves.toBeNull(); + await expect(validateUrlForSSRF('http://[2002:0808:0808::]')).resolves.toBe(''); expect(lookupMock).not.toHaveBeenCalled(); }); @@ -146,7 +146,7 @@ describe('validateUrlForSSRF', () => { // Client IPv4 8.8.8.8 XOR 0xFFFFFFFF = 0xF7F7F7F7 → hextets f7f7:f7f7 await expect( validateUrlForSSRF('http://[2001:0000:4136:e378:8000:63bf:f7f7:f7f7]'), - ).resolves.toBeNull(); + ).resolves.toBe(''); expect(lookupMock).not.toHaveBeenCalled(); }); diff --git a/tests/store/settings-web-search.test.ts b/tests/store/settings-web-search.test.ts new file mode 100644 index 000000000..13cbe8a27 --- /dev/null +++ b/tests/store/settings-web-search.test.ts @@ -0,0 +1,142 @@ +/** + * Tests for web search settings store behaviour: + * - Default tools pre-populated for the claude provider + * - ensureBuiltInWebSearchProviders fills missing providers on rehydrate + */ + +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +vi.mock('@/lib/ai/providers', () => ({ + PROVIDERS: { + openai: { + id: 'openai', + name: 'OpenAI', + type: 'openai', + defaultBaseUrl: 'https://api.openai.com/v1', + requiresApiKey: true, + icon: '', + models: [{ id: 'gpt-4o', name: 'GPT-4o' }], + }, + }, +})); + +vi.mock('@/lib/audio/constants', () => ({ + TTS_PROVIDERS: { + 'browser-native-tts': { + id: 'browser-native-tts', + name: 'Browser Native TTS', + requiresApiKey: false, + defaultModelId: '', + models: [], + voices: [{ id: 'default', name: 'Default', language: 'en', gender: 'neutral' }], + supportedFormats: ['browser'], + }, + }, + ASR_PROVIDERS: { + 'browser-native': { + id: 'browser-native', + name: 'Browser Native ASR', + requiresApiKey: false, + defaultModelId: '', + models: [], + supportedLanguages: ['en'], + supportedFormats: ['browser'], + }, + }, + DEFAULT_TTS_VOICES: { 'browser-native-tts': 'default' }, +})); + +vi.mock('@/lib/audio/types', () => ({})); + +vi.mock('@/lib/pdf/constants', () => ({ + PDF_PROVIDERS: { unpdf: { id: 'unpdf', requiresApiKey: false } }, +})); + +vi.mock('@/lib/media/image-providers', () => ({ + IMAGE_PROVIDERS: {}, +})); + +vi.mock('@/lib/media/video-providers', () => ({ + VIDEO_PROVIDERS: {}, +})); + +vi.mock('@/lib/logger', () => ({ + createLogger: () => ({ + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }), +})); + +vi.stubGlobal('fetch', vi.fn().mockResolvedValue({ ok: false })); + +const storage = new Map(); +vi.stubGlobal('localStorage', { + getItem: (key: string) => storage.get(key) ?? null, + setItem: (key: string, value: string) => storage.set(key, value), + removeItem: (key: string) => storage.delete(key), +}); + +describe('web search store defaults', () => { + beforeEach(() => { + vi.resetModules(); + storage.clear(); + }); + + async function getStore() { + const { useSettingsStore } = await import('@/lib/store/settings'); + return useSettingsStore; + } + + it('pre-populates the default web_search tool for the claude provider', async () => { + const store = await getStore(); + const claudeConfig = store.getState().webSearchProvidersConfig.claude; + + expect(claudeConfig.tools).toContainEqual({ + type: 'web_search_20260209', + name: 'web_search', + }); + }); + + it('has at least one tool in the claude provider default config', async () => { + const store = await getStore(); + const tools = store.getState().webSearchProvidersConfig.claude.tools ?? []; + expect(tools.length).toBeGreaterThan(0); + }); + + it('populates claude provider config on rehydrate when missing from persisted state', async () => { + // Simulate persisted state that only has tavily (old format before claude was added) + storage.set( + 'openmaic-settings', + JSON.stringify({ + state: { + webSearchProviderId: 'tavily', + webSearchProvidersConfig: { + tavily: { apiKey: '', baseUrl: '', enabled: true }, + // claude missing intentionally + }, + }, + version: 0, + }), + ); + + const store = await getStore(); + const claudeConfig = store.getState().webSearchProvidersConfig.claude; + + expect(claudeConfig).toBeDefined(); + expect(claudeConfig.tools?.length).toBeGreaterThan(0); + }); + + it('setWebSearchProviderConfig persists tool changes', async () => { + const store = await getStore(); + const newTools = [ + { type: 'web_search_20260209', name: 'web_search' }, + { type: 'custom_tool', name: 'my_tool' }, + ]; + + store.getState().setWebSearchProviderConfig('claude', { tools: newTools }); + + expect(store.getState().webSearchProvidersConfig.claude.tools).toEqual(newTools); + }); +}); diff --git a/tests/web-search/claude.test.ts b/tests/web-search/claude.test.ts new file mode 100644 index 000000000..faa3f1297 --- /dev/null +++ b/tests/web-search/claude.test.ts @@ -0,0 +1,423 @@ +import { describe, it, expect, vi, beforeEach } from 'vitest'; + +// ── AI SDK mocks ────────────────────────────────────────────────────────────── + +const { mockGenerateText, mockProvider, mockTool, mockCreateAnthropic } = vi.hoisted(() => { + const mockTool = {}; + const mockModel = {}; + const mockProvider = Object.assign( + vi.fn(() => mockModel), + { + tools: { + webSearch_20260209: vi.fn(() => mockTool), + webSearch_20250305: vi.fn(() => mockTool), + }, + }, + ); + const mockCreateAnthropic = vi.fn(() => mockProvider); + const mockGenerateText = vi.fn(); + return { mockGenerateText, mockProvider, mockTool, mockCreateAnthropic }; +}); + +vi.mock('ai', () => ({ generateText: mockGenerateText })); +vi.mock('@ai-sdk/anthropic', () => ({ createAnthropic: mockCreateAnthropic })); + +// ── Infrastructure mocks ────────────────────────────────────────────────────── + +vi.mock('@/lib/server/proxy-fetch', () => ({ proxyFetch: vi.fn() })); + +vi.mock('@/lib/server/ssrf-guard', () => ({ + validateUrlForSSRF: async (url: string): Promise => { + let parsed: URL; + try { + parsed = new URL(url); + } catch { + return 'Invalid URL'; + } + if (!['http:', 'https:'].includes(parsed.protocol)) return 'Only HTTP(S) URLs are allowed'; + const hostname = parsed.hostname.replace(/^\[|\]$/g, ''); + const privatePatterns = [ + /^localhost$/i, + /^127\./, + /^10\./, + /^172\.(1[6-9]|2\d|3[01])\./, + /^192\.168\./, + /^169\.254\./, + /^::1$/, + ]; + if (privatePatterns.some((p) => p.test(hostname))) + return 'Local/private network URLs are not allowed'; + return ''; + }, +})); + +vi.mock('@/lib/logger', () => ({ + createLogger: () => ({ info: vi.fn(), warn: vi.fn(), error: vi.fn(), debug: vi.fn() }), +})); + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +import { proxyFetch } from '@/lib/server/proxy-fetch'; +import { searchWithClaude } from '@/lib/web-search/claude'; + +const mockProxyFetch = proxyFetch as ReturnType; + +type UrlSource = { sourceType: 'url'; type: 'source'; id: string; url: string; title?: string }; + +function mockAIResponse(text = 'Search result', sources: UrlSource[] = []) { + mockGenerateText.mockResolvedValueOnce({ text, sources }); +} + +function mockPageResponse(html: string) { + mockProxyFetch.mockResolvedValueOnce({ ok: true, text: async () => html }); +} + +function mockPageFailure() { + mockProxyFetch.mockResolvedValueOnce({ ok: false, status: 404 }); +} + +// ── Tests ───────────────────────────────────────────────────────────────────── + +describe('searchWithClaude', () => { + beforeEach(() => { + mockProxyFetch.mockReset(); + mockGenerateText.mockReset(); + mockCreateAnthropic.mockClear(); + }); + + // ── fetch interceptor: allowed_callers injection ────────────────────────── + + it('injects allowed_callers=["direct"] on tools that omit it', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + + const calls = mockCreateAnthropic.mock.calls as unknown as [options: Record][]; + const wrappedFetch = calls[0]![0]!['fetch'] as ( + url: string, + init?: RequestInit, + ) => Promise; + + const body = JSON.stringify({ tools: [{ type: 'web_search_20260209', name: 'web_search' }] }); + mockProxyFetch.mockResolvedValueOnce(new Response('{}', { status: 200 })); + await wrappedFetch('https://api.anthropic.com/v1/messages', { method: 'POST', body }); + + const proxyCalls = mockProxyFetch.mock.calls as unknown as [string, RequestInit][]; + const sentBody = JSON.parse(proxyCalls[0]![1]!.body as string); + expect(sentBody.tools[0].allowed_callers).toEqual(['direct']); + }); + + it('does not overwrite allowed_callers when already set', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + + const calls2 = mockCreateAnthropic.mock.calls as unknown as [ + options: Record, + ][]; + const wrappedFetch = calls2[0]![0]!['fetch'] as ( + url: string, + init?: RequestInit, + ) => Promise; + + const body = JSON.stringify({ + tools: [{ type: 'web_search_20260209', name: 'web_search', allowed_callers: ['agent'] }], + }); + mockProxyFetch.mockResolvedValueOnce(new Response('{}', { status: 200 })); + await wrappedFetch('https://api.anthropic.com/v1/messages', { method: 'POST', body }); + + const proxyCalls2 = mockProxyFetch.mock.calls as unknown as [string, RequestInit][]; + const sentBody = JSON.parse(proxyCalls2[0]![1]!.body as string); + expect(sentBody.tools[0].allowed_callers).toEqual(['agent']); + }); + + // ── provider setup ──────────────────────────────────────────────────────── + + it('passes baseUrl and apiKey to createAnthropic', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockCreateAnthropic).toHaveBeenCalledWith( + expect.objectContaining({ baseURL: 'https://api.anthropic.com', apiKey: 'sk-test' }), + ); + }); + + it('calls generateText with the web_search tool', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockGenerateText).toHaveBeenCalledWith( + expect.objectContaining({ tools: expect.objectContaining({ web_search: mockTool }) }), + ); + }); + + it('falls back to claude-sonnet-4-6 when no modelId provided', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockProvider).toHaveBeenCalledWith('claude-sonnet-4-6'); + }); + + it('uses the provided modelId', async () => { + mockAIResponse(); + await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + modelId: 'claude-opus-4-7', + }); + expect(mockProvider).toHaveBeenCalledWith('claude-opus-4-7'); + }); + + // ── answer ──────────────────────────────────────────────────────────────── + + it('returns the answer text from generateText', async () => { + mockAIResponse('Comprehensive answer about the topic'); + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(result.answer).toBe('Comprehensive answer about the topic'); + expect(result.query).toBe('test'); + }); + + // ── page content fetching ───────────────────────────────────────────────── + + it('fetches page content for each source URL', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://example.com', title: 'Example' }, + ]); + mockPageResponse('

Page content here

'); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + + expect(mockProxyFetch).toHaveBeenCalledWith('https://example.com', expect.any(Object)); + expect(result.sources).toHaveLength(1); + expect(result.sources[0].content).toBe('Page content here'); + }); + + it('strips HTML tags and collapses whitespace from fetched page content', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://example.com', title: 'Ex' }, + ]); + mockPageResponse(` + + + +

Title

Some content

+ + `); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + + expect(result.sources[0].content).not.toContain('<'); + expect(result.sources[0].content).not.toContain('alert'); + expect(result.sources[0].content).not.toContain('color: red'); + expect(result.sources[0].content).toContain('Title'); + expect(result.sources[0].content).toContain('Some content'); + }); + + it('fetches multiple sources in parallel', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://a.com', title: 'A' }, + { sourceType: 'url', type: 'source', id: '2', url: 'https://b.com', title: 'B' }, + ]); + mockPageResponse('

Content A

'); + mockPageResponse('

Content B

'); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + + expect(result.sources).toHaveLength(2); + const fetchedUrls = mockProxyFetch.mock.calls.map((call: unknown[]) => call[0]); + expect(fetchedUrls).toContain('https://a.com'); + expect(fetchedUrls).toContain('https://b.com'); + expect(result.sources.find((s) => s.url === 'https://a.com')?.content).toContain('Content A'); + expect(result.sources.find((s) => s.url === 'https://b.com')?.content).toContain('Content B'); + }); + + it('filters out sources where page fetch returns non-ok response', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://dead.com', title: 'Dead' }, + ]); + mockPageFailure(); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(result.sources).toHaveLength(0); + }); + + it('filters out sources where page fetch throws (network error)', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://dead.com', title: 'Dead' }, + ]); + mockProxyFetch.mockRejectedValueOnce(new Error('Network timeout')); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(result.sources).toHaveLength(0); + }); + + it('keeps sources with content and drops sources without after mixed page fetches', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'https://good.com', title: 'Good' }, + { sourceType: 'url', type: 'source', id: '2', url: 'https://dead.com', title: 'Dead' }, + ]); + mockPageResponse('

Good content

'); + mockPageFailure(); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(result.sources).toHaveLength(1); + expect(result.sources[0].url).toBe('https://good.com'); + }); + + it('ignores non-url sources (document sources)', async () => { + mockGenerateText.mockResolvedValueOnce({ + text: 'Answer', + sources: [ + { sourceType: 'document', type: 'source', id: '1', mediaType: 'text/plain', title: 'Doc' }, + { sourceType: 'url', type: 'source', id: '2', url: 'https://example.com', title: 'Web' }, + ], + }); + mockPageResponse('

Web content

'); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(result.sources).toHaveLength(1); + expect(result.sources[0].url).toBe('https://example.com'); + }); + + // ── SSRF protection ─────────────────────────────────────────────────────── + + it('skips page fetch for localhost URLs (SSRF protection)', async () => { + mockAIResponse('Answer', [ + { + sourceType: 'url', + type: 'source', + id: '1', + url: 'http://localhost/secret', + title: 'Local', + }, + ]); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockProxyFetch).not.toHaveBeenCalled(); + expect(result.sources).toHaveLength(0); + }); + + it('skips page fetch for private IP URLs (SSRF protection)', async () => { + mockAIResponse('Answer', [ + { + sourceType: 'url', + type: 'source', + id: '1', + url: 'http://192.168.1.1/admin', + title: 'Private', + }, + ]); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockProxyFetch).not.toHaveBeenCalled(); + expect(result.sources).toHaveLength(0); + }); + + it('skips page fetch for non-HTTP(S) URLs (SSRF protection)', async () => { + mockAIResponse('Answer', [ + { sourceType: 'url', type: 'source', id: '1', url: 'file:///etc/passwd', title: 'File' }, + ]); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockProxyFetch).not.toHaveBeenCalled(); + expect(result.sources).toHaveLength(0); + }); + + it('skips page fetch for cloud metadata endpoint URLs (SSRF protection)', async () => { + mockAIResponse('Answer', [ + { + sourceType: 'url', + type: 'source', + id: '1', + url: 'http://169.254.169.254/latest/meta-data/', + title: 'Meta', + }, + ]); + + const result = await searchWithClaude({ + query: 'test', + apiKey: 'sk-test', + baseUrl: 'https://api.anthropic.com', + }); + expect(mockProxyFetch).not.toHaveBeenCalled(); + expect(result.sources).toHaveLength(0); + }); + + // ── error propagation ───────────────────────────────────────────────────── + + it('throws when generateText rejects', async () => { + mockGenerateText.mockRejectedValueOnce(new Error('Claude API error (401): invalid x-api-key')); + + await expect( + searchWithClaude({ query: 'test', apiKey: 'bad-key', baseUrl: 'https://api.anthropic.com' }), + ).rejects.toThrow('Claude API error (401)'); + }); + + it('throws when generateText rejects with a network error', async () => { + mockGenerateText.mockRejectedValueOnce(new Error('Network failure')); + + await expect( + searchWithClaude({ query: 'test', apiKey: 'sk-test', baseUrl: 'https://api.anthropic.com' }), + ).rejects.toThrow(); + }); +});