From 96ff503361a786c19f2a502c3fcb5fccc05959e1 Mon Sep 17 00:00:00 2001 From: Aniketh Dev Date: Tue, 24 Mar 2026 22:45:03 +0530 Subject: [PATCH] feat: add cross-provider fallback routing with request/response transformers --- examples/8-cross-provider-fallback.js | 88 ++++++++++++ package-lock.json | 7 +- src/index.ts | 25 +++- src/interceptors/fetchInterceptor.ts | 166 ++++++++++++++++++++-- src/router/apiKeyManager.ts | 55 ++++++++ src/router/modelRouter.ts | 15 ++ src/router/providerDetector.ts | 87 ++++++++++++ src/router/providerHeaders.ts | 65 +++++++++ src/router/requestTransformer.ts | 193 ++++++++++++++++++++++++++ src/router/responseTransformer.ts | 161 +++++++++++++++++++++ src/router/types.ts | 16 +++ 11 files changed, 865 insertions(+), 13 deletions(-) create mode 100644 examples/8-cross-provider-fallback.js create mode 100644 src/router/apiKeyManager.ts create mode 100644 src/router/providerDetector.ts create mode 100644 src/router/providerHeaders.ts create mode 100644 src/router/requestTransformer.ts create mode 100644 src/router/responseTransformer.ts diff --git a/examples/8-cross-provider-fallback.js b/examples/8-cross-provider-fallback.js new file mode 100644 index 0000000..fa5a451 --- /dev/null +++ b/examples/8-cross-provider-fallback.js @@ -0,0 +1,88 @@ +/** + * Example 8: Cross-Provider Fallback + * + * Demonstrates automatic fallback across different LLM providers. + * If GPT-4o fails, automatically retries with Claude, then Gemini — + * all transparent to the caller with a unified budget. + * + * Prerequisites: + * npm install tokenfirewall + * Set environment variables: OPENAI_API_KEY, ANTHROPIC_API_KEY, GEMINI_API_KEY + */ + +const { + createBudgetGuard, + createModelRouter, + registerApiKeys, + patchGlobalFetch, + getBudgetStatus, + isCrossProviderEnabled, +} = require("tokenfirewall"); + +// 1. 
Register API keys for all the providers you want to fallback between +registerApiKeys({ + openai: process.env.OPENAI_API_KEY, + anthropic: process.env.ANTHROPIC_API_KEY, + gemini: process.env.GEMINI_API_KEY, +}); + +// 2. Create budget guard — costs are tracked across ALL providers +createBudgetGuard({ + monthlyLimit: 50, // $50 USD total budget + mode: "block", +}); + +// 3. Create model router with cross-provider fallback chains +createModelRouter({ + strategy: "fallback", + fallbackMap: { + // If GPT-4o fails → try Claude 3.5 Sonnet → then Gemini 2.5 Pro + "gpt-4o": ["claude-3-5-sonnet-20241022", "gemini-2.5-pro"], + // If Claude fails → try GPT-4o-mini → then Gemini + "claude-3-5-sonnet-20241022": ["gpt-4o-mini", "gemini-2.5-pro"], + // If Gemini fails → try GPT-4o-mini → then Claude Haiku + "gemini-2.5-pro": ["gpt-4o-mini", "claude-3-5-haiku-20241022"], + }, + maxRetries: 2, + enableCrossProvider: true, // <-- enable cross-provider fallback +}); + +// 4. Patch global fetch +patchGlobalFetch(); + +console.log("Cross-provider enabled:", isCrossProviderEnabled()); + +// 5. Make a normal API call — fallback is fully transparent +async function main() { + try { + const response = await fetch("https://api.openai.com/v1/chat/completions", { + method: "POST", + headers: { + Authorization: `Bearer ${process.env.OPENAI_API_KEY}`, + "Content-Type": "application/json", + }, + body: JSON.stringify({ + model: "gpt-4o", + messages: [{ role: "user", content: "What is the capital of France?" }], + max_tokens: 100, + }), + }); + + const data = await response.json(); + console.log("\nResponse:", data.choices?.[0]?.message?.content); + + // Budget is tracked across all providers + const status = getBudgetStatus(); + console.log("\nBudget status:", status); + + // If GPT-4o failed, the response was automatically: + // 1. Transformed to Claude/Gemini format + // 2. Sent to the fallback provider + // 3. Response transformed back to OpenAI format + // 4. 
Returned as if GPT-4o answered — fully transparent! + } catch (error) { + console.error("All providers failed:", error.message); + } +} + +main(); diff --git a/package-lock.json b/package-lock.json index 9bd5195..e9cce39 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,16 +1,19 @@ { "name": "tokenfirewall", - "version": "1.0.0", + "version": "2.0.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "tokenfirewall", - "version": "1.0.0", + "version": "2.0.1", "license": "MIT", "devDependencies": { "@types/node": "^20.0.0", "typescript": "^5.0.0" + }, + "engines": { + "node": ">=16.0.0" } }, "node_modules/@types/node": { diff --git a/src/index.ts b/src/index.ts index 5066e86..741823a 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,7 +7,8 @@ import { patchProvider } from "./interceptors/sdkInterceptor"; import { listAvailableModels, ModelInfo, ListModelsOptions } from "./introspection/modelLister"; import { contextRegistry } from "./introspection/contextRegistry"; import { ModelRouter } from "./router/modelRouter"; -import { ModelRouterOptions } from "./router/types"; +import { ModelRouterOptions, ApiKeyConfig } from "./router/types"; +import { apiKeyManager } from "./router/apiKeyManager"; let globalBudgetManager: BudgetManager | null = null; let globalModelRouter: ModelRouter | null = null; @@ -135,6 +136,25 @@ export function registerModels( } } +/** + * Register API keys for cross-provider fallback + * @param keys - Object mapping provider names to API keys + */ +export function registerApiKeys(keys: ApiKeyConfig): void { + if (!keys || typeof keys !== 'object') { + throw new Error('TokenFirewall: Keys must be an object mapping provider names to API keys'); + } + apiKeyManager.registerKeys(keys); +} + +/** + * Check if cross-provider fallback is enabled + * @returns true if a model router exists with cross-provider enabled + */ +export function isCrossProviderEnabled(): boolean { + return 
globalModelRouter?.isCrossProviderEnabled() ?? false; +} + /** * Get current budget status * @returns Budget status or null if no budget guard is active @@ -228,7 +248,8 @@ export type { FailureType, FailureContext, RoutingDecision, - RouterEvent + RouterEvent, + ApiKeyConfig } from "./router/types"; /** diff --git a/src/interceptors/fetchInterceptor.ts b/src/interceptors/fetchInterceptor.ts index eb2bb7d..4713888 100644 --- a/src/interceptors/fetchInterceptor.ts +++ b/src/interceptors/fetchInterceptor.ts @@ -3,6 +3,11 @@ import { calculateCost } from "../core/costEngine"; import { logger } from "../logger"; import { BudgetManager } from "../core/budgetManager"; import { ModelRouter } from "../router/modelRouter"; +import { detectProvider, buildProviderUrl, isCrossProviderSwitch } from "../router/providerDetector"; +import { apiKeyManager } from "../router/apiKeyManager"; +import { buildProviderHeaders, appendApiKeyToUrl } from "../router/providerHeaders"; +import { transformRequest } from "../router/requestTransformer"; +import { transformResponse } from "../router/responseTransformer"; let isPatched = false; let budgetManager: BudgetManager | null = null; @@ -98,7 +103,7 @@ async function standardFetch( } /** - * Fetch with automatic retry and model switching + * Fetch with automatic retry and model switching (including cross-provider) */ async function fetchWithRetry( input: Parameters[0], @@ -127,12 +132,22 @@ async function fetchWithRetry( let currentInput = input; // Extract original model and provider from request - const { originalModel, provider } = extractModelInfo(input, init); + const { originalModel, provider: originalProvider } = extractModelInfo(input, init); if (originalModel) { attemptedModels.push(originalModel); } + // Parse the original request body for potential cross-provider transformations + let originalRequestBody: any = null; + if (init?.body) { + try { + originalRequestBody = JSON.parse(init.body as string); + } catch { + // Not JSON + } 
+ } + while (retryCount <= (modelRouter?.getMaxRetries() || 0)) { try { // Make the request @@ -157,12 +172,40 @@ async function fetchWithRetry( throw new Error(JSON.stringify(errorObj)); } + // If this was a cross-provider retry, transform the response back + // to the original provider's format for transparency + if (originalProvider && retryCount > 0) { + const currentModel = extractCurrentModel(currentInput, currentInit); + const currentProvider = currentModel ? detectProvider(currentModel) : null; + + if (currentProvider && currentProvider !== originalProvider) { + try { + const clonedForTransform = response.clone(); + const responseData = await clonedForTransform.json(); + const transformed = transformResponse( + responseData, + currentProvider, // source: the provider that actually responded + originalProvider, // target: what the caller expects + currentModel || '' + ); + return new Response(JSON.stringify(transformed), { + status: response.status, + statusText: response.statusText, + headers: { 'Content-Type': 'application/json' }, + }); + } catch { + // If transformation fails, return original response + return response; + } + } + } + return response; } catch (error) { lastError = error; // If no router or no model info, throw immediately - if (!modelRouter || !originalModel || !provider) { + if (!modelRouter || !originalModel || !originalProvider) { throw error; } @@ -187,7 +230,7 @@ async function fetchWithRetry( return {}; } })() : {}, - provider, + provider: originalProvider, retryCount, attemptedModels }); @@ -206,11 +249,38 @@ async function fetchWithRetry( maxRetries: modelRouter.getMaxRetries() }); - // Update request with new model attemptedModels.push(decision.nextModel); - const updated = updateRequestModel(currentInput, currentInit, decision.nextModel, provider); - currentInput = updated.input; - currentInit = updated.init; + + // Check if this is a cross-provider switch + const nextProvider = detectProvider(decision.nextModel); + const 
isCrossProvider = nextProvider + && originalProvider + && isCrossProviderSwitch(originalModel, decision.nextModel) + && modelRouter.isCrossProviderEnabled(); + + if (isCrossProvider && nextProvider && originalRequestBody) { + // --- Cross-provider fallback --- + const updated = buildCrossProviderRequest( + originalRequestBody, + originalProvider, + nextProvider, + decision.nextModel + ); + + if (updated) { + currentInput = updated.url; + currentInit = updated.init; + } else { + // Could not build cross-provider request, throw + throw error; + } + } else { + // --- Same-provider fallback (existing behavior) --- + const updated = updateRequestModel(currentInput, currentInit, decision.nextModel, originalProvider); + currentInput = updated.input; + currentInit = updated.init; + } + retryCount++; } } @@ -221,6 +291,85 @@ async function fetchWithRetry( ); } +/** + * Build a completely new request for a different provider (cross-provider fallback) + */ +function buildCrossProviderRequest( + originalBody: any, + sourceProvider: string, + targetProvider: string, + targetModel: string +): { url: string; init: RequestInit } | null { + // Get API key for target provider + const apiKey = apiKeyManager.getKey(targetProvider); + if (!apiKey) { + console.warn( + `TokenFirewall Router: No API key registered for provider "${targetProvider}". ` + + `Cross-provider fallback skipped. 
Register keys with registerApiKeys().` + ); + return null; + } + + // Transform request body + const transformedBody = transformRequest( + originalBody, + sourceProvider, + targetProvider, + targetModel + ); + + // Build target URL + let targetUrl = buildProviderUrl(targetProvider, targetModel); + if (!targetUrl) { + console.warn( + `TokenFirewall Router: Unknown endpoint for provider "${targetProvider}".` + ); + return null; + } + + // Append API key to URL if needed (Gemini) + targetUrl = appendApiKeyToUrl(targetUrl, targetProvider, apiKey); + + // Build headers + const headers = buildProviderHeaders(targetProvider, apiKey); + + return { + url: targetUrl, + init: { + method: 'POST', + headers, + body: JSON.stringify(transformedBody), + }, + }; +} + +/** + * Extract current model from the (possibly updated) request + */ +function extractCurrentModel( + input: Parameters[0], + init?: Parameters[1] +): string | null { + // Check URL for Gemini-style model + const url = typeof input === 'string' ? input : (input instanceof Request ? 
input.url : String(input)); + const geminiMatch = url.match(/\/models\/([^:?]+)/); + if (geminiMatch) { + return geminiMatch[1]; + } + + // Check body for model field + if (init?.body) { + try { + const body = JSON.parse(init.body as string); + return body.model || null; + } catch { + return null; + } + } + + return null; +} + /** * Extract model and provider information from request */ @@ -321,7 +470,6 @@ function updateRequestModel( body: init?.body || null, mode: input.mode, credentials: input.credentials, - cache: input.cache, redirect: input.redirect, referrer: input.referrer, integrity: input.integrity diff --git a/src/router/apiKeyManager.ts b/src/router/apiKeyManager.ts new file mode 100644 index 0000000..12e5d23 --- /dev/null +++ b/src/router/apiKeyManager.ts @@ -0,0 +1,55 @@ +/** + * API Key Manager + * Stores and retrieves API keys for different LLM providers + */ + +export class ApiKeyManager { + private keys: Map = new Map(); + + /** + * Register an API key for a provider + */ + public registerKey(provider: string, apiKey: string): void { + if (!provider || typeof provider !== 'string' || provider.trim() === '') { + throw new Error('TokenFirewall: Provider must be a non-empty string'); + } + if (!apiKey || typeof apiKey !== 'string' || apiKey.trim() === '') { + throw new Error(`TokenFirewall: Invalid API key for provider "${provider}"`); + } + this.keys.set(provider.toLowerCase(), apiKey); + } + + /** + * Get API key for a provider + */ + public getKey(provider: string): string | undefined { + return this.keys.get(provider.toLowerCase()); + } + + /** + * Check if a key exists for a provider + */ + public hasKey(provider: string): boolean { + return this.keys.has(provider.toLowerCase()); + } + + /** + * Register multiple keys at once + */ + public registerKeys(keys: Record): void { + for (const [provider, key] of Object.entries(keys)) { + if (key !== undefined && key !== null && key !== '') { + this.registerKey(provider, key); + } + } + } + + /** + * 
Get list of providers that have keys registered + */ + public getRegisteredProviders(): string[] { + return Array.from(this.keys.keys()); + } +} + +export const apiKeyManager = new ApiKeyManager(); diff --git a/src/router/modelRouter.ts b/src/router/modelRouter.ts index 9ab349f..4b6f8ea 100644 --- a/src/router/modelRouter.ts +++ b/src/router/modelRouter.ts @@ -6,6 +6,7 @@ import { } from "./types"; import { errorDetector } from "./errorDetector"; import { fallbackStrategy, contextStrategy, costStrategy } from "./routingStrategies"; +import { apiKeyManager } from "./apiKeyManager"; /** * Intelligent Model Router @@ -15,11 +16,18 @@ export class ModelRouter { private strategy: RoutingStrategy; private fallbackMap: Record; private maxRetries: number; + private crossProviderEnabled: boolean; constructor(options: ModelRouterOptions) { this.strategy = options.strategy; this.fallbackMap = options.fallbackMap || {}; this.maxRetries = options.maxRetries ?? 1; + this.crossProviderEnabled = options.enableCrossProvider ?? 
false; + + // Register API keys if provided + if (options.apiKeys) { + apiKeyManager.registerKeys(options.apiKeys); + } this.validateOptions(); } @@ -149,4 +157,11 @@ export class ModelRouter { public getStrategy(): RoutingStrategy { return this.strategy; } + + /** + * Check if cross-provider fallback is enabled + */ + public isCrossProviderEnabled(): boolean { + return this.crossProviderEnabled; + } } diff --git a/src/router/providerDetector.ts b/src/router/providerDetector.ts new file mode 100644 index 0000000..f87936c --- /dev/null +++ b/src/router/providerDetector.ts @@ -0,0 +1,87 @@ +/** + * Provider Detection Module + * Detects which LLM provider a model belongs to and resolves API endpoints + */ + +const MODEL_PREFIX_MAP: Record = { + // OpenAI + 'gpt-': 'openai', + 'o1': 'openai', + 'o3': 'openai', + 'o4': 'openai', + 'chatgpt-': 'openai', + // Anthropic + 'claude-': 'anthropic', + // Gemini + 'gemini-': 'gemini', + // Grok + 'grok-': 'grok', + 'llama-': 'grok', + // Kimi + 'moonshot-': 'kimi', +}; + +const PROVIDER_ENDPOINTS: Record = { + 'openai': 'https://api.openai.com/v1/chat/completions', + 'anthropic': 'https://api.anthropic.com/v1/messages', + 'gemini': 'https://generativelanguage.googleapis.com/v1beta/models', + 'grok': 'https://api.x.ai/v1/chat/completions', + 'kimi': 'https://api.moonshot.cn/v1/chat/completions', +}; + +/** + * Detect provider from model name + */ +export function detectProvider(modelName: string): string | null { + if (!modelName || typeof modelName !== 'string') { + return null; + } + + const lower = modelName.toLowerCase(); + + for (const [prefix, provider] of Object.entries(MODEL_PREFIX_MAP)) { + if (lower.startsWith(prefix)) { + return provider; + } + } + + return null; +} + +/** + * Get API endpoint URL for a provider + * For Gemini, the model name must be appended: endpoint/{model}:generateContent + */ +export function getProviderEndpoint(provider: string): string { + return PROVIDER_ENDPOINTS[provider] || ''; +} + +/** + 
* Build the full request URL for a provider + model + */ +export function buildProviderUrl(provider: string, model: string): string { + const base = getProviderEndpoint(provider); + if (!base) { + return ''; + } + + if (provider === 'gemini') { + return `${base}/${model}:generateContent`; + } + + return base; +} + +/** + * Check if two models belong to different providers + */ +export function isCrossProviderSwitch(modelA: string, modelB: string): boolean { + const providerA = detectProvider(modelA); + const providerB = detectProvider(modelB); + + if (!providerA || !providerB) { + return false; + } + + return providerA !== providerB; +} diff --git a/src/router/providerHeaders.ts b/src/router/providerHeaders.ts new file mode 100644 index 0000000..37109d9 --- /dev/null +++ b/src/router/providerHeaders.ts @@ -0,0 +1,65 @@ +/** + * Provider Headers Builder + * Builds correct authentication and content headers for each LLM provider + */ + +/** + * Build request headers for a specific provider + */ +export function buildProviderHeaders( + provider: string, + apiKey: string +): Record { + const baseHeaders: Record = { + 'Content-Type': 'application/json', + }; + + switch (provider) { + case 'openai': + return { + ...baseHeaders, + 'Authorization': `Bearer ${apiKey}`, + }; + + case 'anthropic': + return { + ...baseHeaders, + 'x-api-key': apiKey, + 'anthropic-version': '2023-06-01', + }; + + case 'gemini': + // Gemini uses API key as URL query parameter, not in headers + return baseHeaders; + + case 'grok': + return { + ...baseHeaders, + 'Authorization': `Bearer ${apiKey}`, + }; + + case 'kimi': + return { + ...baseHeaders, + 'Authorization': `Bearer ${apiKey}`, + }; + + default: + // Default to Bearer token for unknown providers + return { + ...baseHeaders, + 'Authorization': `Bearer ${apiKey}`, + }; + } +} + +/** + * Build the full URL for a provider, appending API key if needed (e.g., Gemini) + */ +export function appendApiKeyToUrl(url: string, provider: string, apiKey: 
string): string { + if (provider === 'gemini') { + const separator = url.includes('?') ? '&' : '?'; + return `${url}${separator}key=${encodeURIComponent(apiKey)}`; // percent-encode so reserved characters in the key cannot alter the query string + } + return url; +} diff --git a/src/router/requestTransformer.ts b/src/router/requestTransformer.ts new file mode 100644 index 0000000..18961b9 --- /dev/null +++ b/src/router/requestTransformer.ts @@ -0,0 +1,193 @@ +/** + * Request Transformer + * Converts LLM request formats between providers + * Supports: OpenAI ↔ Anthropic ↔ Gemini ↔ Grok ↔ Kimi + * + * Note: This covers basic text chat completions (Option 1 MVP). + * Streaming, function calling, and vision are not yet supported. + */ + +interface Message { + role: string; + content: string; +} + +/** + * Normalize any provider's request to an internal common format (OpenAI-shaped) + * This makes it easy to then convert to any target format. + */ +function normalizeToCommon(request: any, sourceProvider: string): { + messages: Message[]; + model: string; + temperature?: number; + max_tokens?: number; + top_p?: number; +} { + switch (sourceProvider) { + case 'anthropic': { + const messages: Message[] = []; + // Anthropic has a separate `system` field + if (request.system) { + messages.push({ role: 'system', content: request.system }); + } + if (Array.isArray(request.messages)) { + for (const msg of request.messages) { + messages.push({ + role: msg.role === 'assistant' ? 'assistant' : 'user', + content: typeof msg.content === 'string' ? msg.content : JSON.stringify(msg.content), + }); + } + } + return { + messages, + model: request.model || '', + temperature: request.temperature, + max_tokens: request.max_tokens, + top_p: request.top_p, + }; + } + + case 'gemini': { + const messages: Message[] = []; + if (Array.isArray(request.contents)) { + for (const content of request.contents) { + const role = content.role === 'model' ? 'assistant' : 'user'; + const text = Array.isArray(content.parts) + ? 
content.parts.map((p: any) => p.text || '').join('') + : ''; + messages.push({ role, content: text }); + } + } + return { + messages, + model: '', + temperature: request.generationConfig?.temperature, + max_tokens: request.generationConfig?.maxOutputTokens, + top_p: request.generationConfig?.topP, + }; + } + + // OpenAI, Grok, Kimi all use OpenAI-compatible format + default: + return { + messages: (request.messages || []).map((m: any) => ({ + role: m.role, + content: typeof m.content === 'string' ? m.content : JSON.stringify(m.content), + })), + model: request.model || '', + temperature: request.temperature, + max_tokens: request.max_tokens, + top_p: request.top_p, + }; + } +} + +/** + * Convert the common (OpenAI-shaped) format to a target provider's format + */ +function commonToTarget( + common: ReturnType, + targetProvider: string, + targetModel: string +): any { + switch (targetProvider) { + case 'anthropic': { + // Merge every system message (a request may carry several) into Anthropic's single `system` string + const systemText = common.messages.filter(m => m.role === 'system').map(m => m.content).join('\n\n'); + const nonSystemMsgs = common.messages.filter(m => m.role !== 'system'); + + const result: any = { + model: targetModel, + messages: nonSystemMsgs.map(m => ({ + role: m.role === 'assistant' ? 'assistant' : 'user', + content: m.content, + })), + max_tokens: common.max_tokens || 1024, + }; + + if (systemText) { + result.system = systemText; + } + if (common.temperature !== undefined) { + result.temperature = common.temperature; + } + if (common.top_p !== undefined) { + result.top_p = common.top_p; + } + + return result; + } + + case 'gemini': { + // Gemini uses `contents` array; system messages become a leading user turn + const contents: any[] = []; + + for (const msg of common.messages) { + if (msg.role === 'system') { + // Gemini doesn't have a system role; prepend as a user message + contents.push({ + role: 'user', + parts: [{ text: msg.content }], + }); + } else { + contents.push({ + role: msg.role === 'assistant' ? 
'model' : 'user', + parts: [{ text: msg.content }], + }); + } + } + + const result: any = { contents }; + + const genConfig: any = {}; + if (common.temperature !== undefined) genConfig.temperature = common.temperature; + if (common.max_tokens !== undefined) genConfig.maxOutputTokens = common.max_tokens; + if (common.top_p !== undefined) genConfig.topP = common.top_p; + + if (Object.keys(genConfig).length > 0) { + result.generationConfig = genConfig; + } + + return result; + } + + // OpenAI, Grok, Kimi — all use OpenAI-compatible format + default: + const result: any = { + model: targetModel, + messages: common.messages.map(m => ({ + role: m.role, + content: m.content, + })), + }; + + if (common.temperature !== undefined) result.temperature = common.temperature; + if (common.max_tokens !== undefined) result.max_tokens = common.max_tokens; + if (common.top_p !== undefined) result.top_p = common.top_p; + + return result; + } +} + +/** + * Transform a request from one provider's format to another + */ +export function transformRequest( + originalRequest: any, + sourceProvider: string, + targetProvider: string, + targetModel: string +): any { + // Same provider — just swap model name + if (sourceProvider === targetProvider) { + if (targetProvider === 'gemini') { + // Gemini model is in URL, not body — just return the body as-is + return { ...originalRequest }; + } + return { ...originalRequest, model: targetModel }; + } + + // Cross-provider: normalize → convert + const common = normalizeToCommon(originalRequest, sourceProvider); + return commonToTarget(common, targetProvider, targetModel); +} diff --git a/src/router/responseTransformer.ts b/src/router/responseTransformer.ts new file mode 100644 index 0000000..28a704c --- /dev/null +++ b/src/router/responseTransformer.ts @@ -0,0 +1,161 @@ +/** + * Response Transformer + * Converts LLM responses back to the caller's expected provider format + * + * The caller originally made a request to Provider A, but a fallback sent 
it + * to Provider B. We need to transform Provider B's response to look like + * Provider A's response so the caller's code works transparently. + */ + +/** + * Normalize any provider response to OpenAI-shaped format (internal common format) + */ +function normalizeResponseToCommon(response: any, provider: string, model: string): any { + switch (provider) { + case 'anthropic': + return { + id: response.id || `chatcmpl-${Date.now()}`, + object: 'chat.completion', + created: Math.floor(Date.now() / 1000), + model: response.model || model, + choices: [{ + index: 0, + message: { + role: 'assistant', + content: extractAnthropicContent(response), + }, + finish_reason: mapAnthropicStopReason(response.stop_reason), + }], + usage: { + prompt_tokens: response.usage?.input_tokens || 0, + completion_tokens: response.usage?.output_tokens || 0, + total_tokens: (response.usage?.input_tokens || 0) + (response.usage?.output_tokens || 0), + }, + }; + + case 'gemini': + const candidate = response.candidates?.[0]; + return { + id: `chatcmpl-${Date.now()}`, + object: 'chat.completion', + created: Math.floor(Date.now() / 1000), + model: model, + choices: [{ + index: 0, + message: { + role: 'assistant', + content: Array.isArray(candidate?.content?.parts) ? candidate.content.parts.map((p: any) => p?.text || '').join('') : '', // join every text part, not only the first + }, + finish_reason: mapGeminiFinishReason(candidate?.finishReason), + }], + usage: { + prompt_tokens: response.usageMetadata?.promptTokenCount || 0, + completion_tokens: response.usageMetadata?.candidatesTokenCount || 0, + total_tokens: response.usageMetadata?.totalTokenCount || 0, + }, + }; + + // OpenAI, Grok, Kimi — already in OpenAI-compatible format + default: + return response; + } +} + +/** + * Convert OpenAI-shaped common format to a target provider's response format + */ +function commonToTargetResponse(common: any, targetProvider: string): any { + switch (targetProvider) { + case 'anthropic': + return { + id: common.id, + type: 'message', + role: 'assistant', + content: [{ + type: 'text', + text: 
common.choices?.[0]?.message?.content || '', + }], + model: common.model, + stop_reason: common.choices?.[0]?.finish_reason === 'stop' ? 'end_turn' : common.choices?.[0]?.finish_reason === 'length' ? 'max_tokens' : common.choices?.[0]?.finish_reason, // inverse of mapAnthropicStopReason: 'length' is not an Anthropic stop_reason + usage: { + input_tokens: common.usage?.prompt_tokens || 0, + output_tokens: common.usage?.completion_tokens || 0, + }, + }; + + case 'gemini': + return { + candidates: [{ + content: { + parts: [{ + text: common.choices?.[0]?.message?.content || '', + }], + role: 'model', + }, + finishReason: common.choices?.[0]?.finish_reason === 'length' ? 'MAX_TOKENS' : common.choices?.[0]?.finish_reason === 'content_filter' ? 'SAFETY' : 'STOP', // inverse of mapGeminiFinishReason so truncation/safety stops are not reported as STOP + }], + usageMetadata: { + promptTokenCount: common.usage?.prompt_tokens || 0, + candidatesTokenCount: common.usage?.completion_tokens || 0, + totalTokenCount: common.usage?.total_tokens || 0, + }, + }; + + // OpenAI, Grok, Kimi — already OpenAI-compatible + default: + return common; + } +} + +/** + * Transform a response from one provider's format to another + * Used when a cross-provider fallback occurred and the caller expects + * the original provider's response format. + */ +export function transformResponse( + response: any, + sourceProvider: string, + targetProvider: string, + targetModel: string +): any { + // Same provider — return as-is + if (sourceProvider === targetProvider) { + return response; + } + + // Normalize the actual response (from sourceProvider) to common OpenAI format + const common = normalizeResponseToCommon(response, sourceProvider, targetModel); + + // Convert from common format to what the caller expects (targetProvider format) + return commonToTargetResponse(common, targetProvider); +} + +// --- Helpers --- + +function extractAnthropicContent(response: any): string { + if (Array.isArray(response.content)) { + return response.content + .filter((block: any) => block.type === 'text') + .map((block: any) => block.text) + .join(''); + } + return ''; +} + +function mapAnthropicStopReason(reason: string | undefined): string { + switch (reason) { + case 'end_turn': return 'stop'; + case 'max_tokens': return 'length'; + case 'stop_sequence': return 'stop'; + 
default: return reason || 'stop'; + } +} + +function mapGeminiFinishReason(reason: string | undefined): string { + switch (reason) { + case 'STOP': return 'stop'; + case 'MAX_TOKENS': return 'length'; + case 'SAFETY': return 'content_filter'; + default: return 'stop'; + } +} diff --git a/src/router/types.ts b/src/router/types.ts index 2788a0a..8e333f4 100644 --- a/src/router/types.ts +++ b/src/router/types.ts @@ -17,6 +17,18 @@ export type FailureType = | "access_denied" | "unknown"; +/** + * API key configuration for cross-provider fallback + */ +export interface ApiKeyConfig { + openai?: string; + anthropic?: string; + gemini?: string; + grok?: string; + kimi?: string; + [key: string]: string | undefined; +} + /** * Configuration options for model router */ @@ -27,6 +39,10 @@ export interface ModelRouterOptions { fallbackMap?: Record; /** Maximum number of retry attempts (default: 1) */ maxRetries?: number; + /** API keys for cross-provider fallback */ + apiKeys?: ApiKeyConfig; + /** Enable cross-provider fallback (default: false) */ + enableCrossProvider?: boolean; } /**