diff --git a/src/gateway/openai-compatible.ts b/src/gateway/openai-compatible.ts index cdc6e802..a111129f 100644 --- a/src/gateway/openai-compatible.ts +++ b/src/gateway/openai-compatible.ts @@ -54,6 +54,10 @@ import { sendOpenAICompatibleStreamError, writeOpenAICompatibleStreamChunk, } from './openai-compatible-response.js'; +import { + callWithProviderFallback, + loadFallbackChainFromEnv, +} from './provider-fallback.js'; function isResponseWritable(res: ServerResponse): boolean { return !res.writableEnded && !res.destroyed; @@ -495,12 +499,18 @@ async function handleOpenAICompatibleToolChat( ): Promise { const runtime = await resolveToolAwareRuntime({ prepared }); const messages = await buildToolAwareMessages({ input, prepared }); - const result = await callOpenAICompatibleModel({ - runtime, - model: prepared.model, - messages, - tools: input.tools, - toolChoice: input.toolChoice, + const result = await callWithProviderFallback({ + primaryRuntime: runtime, + primaryModel: prepared.model, + chain: loadFallbackChainFromEnv(), + invoke: (activeRuntime, activeModel) => + callOpenAICompatibleModel({ + runtime: activeRuntime, + model: activeModel, + messages, + tools: input.tools, + toolChoice: input.toolChoice, + }), }); const choice = result.choices[0]; const payload = buildOpenAICompatibleCompletionResponse({ @@ -694,24 +704,38 @@ async function handleOpenAICompatibleStreamingToolChat( try { const runtime = await resolveToolAwareRuntime({ prepared }); const messages = await buildToolAwareMessages({ input, prepared }); - const result = await callOpenAICompatibleModelStream({ - runtime, - model: prepared.model, - messages, - tools: input.tools, - toolChoice: input.toolChoice, - onTextDelta: (delta) => { - if (!isResponseWritable(res) || !delta) return; - writeOpenAICompatibleStreamChunk( - res, - buildOpenAICompatibleStreamTextChunk({ - completionId, - created, - model: prepared.responseModel, - content: delta, - }), - ); - }, + let streamStarted = false; + const result = await callWithProviderFallback({ + primaryRuntime: runtime, + primaryModel: prepared.model, + chain: loadFallbackChainFromEnv(), + // Once the SSE stream has emitted bytes we cannot safely switch + // providers — a fallback would duplicate or interleave content. Suppress + // further retries so the original provider error propagates instead of a + // generic "Stream already started" placeholder. + shouldFallback: () => !streamStarted, + invoke: (activeRuntime, activeModel) => + callOpenAICompatibleModelStream({ + runtime: activeRuntime, + model: activeModel, + messages, + tools: input.tools, + toolChoice: input.toolChoice, + onTextDelta: (delta) => { + if (!delta) return; + streamStarted = true; + if (!isResponseWritable(res)) return; + writeOpenAICompatibleStreamChunk( + res, + buildOpenAICompatibleStreamTextChunk({ + completionId, + created, + model: prepared.responseModel, + content: delta, + }), + ); + }, + }), }); if (!isResponseWritable(res)) return; const choice = result.choices[0]; diff --git a/src/gateway/provider-fallback.ts b/src/gateway/provider-fallback.ts new file mode 100644 index 00000000..2a6e31ae --- /dev/null +++ b/src/gateway/provider-fallback.ts @@ -0,0 +1,259 @@ +import { performance } from 'node:perf_hooks'; + +import { resolveModelRuntimeCredentials } from '../providers/factory.js'; +import type { ResolvedModelRuntimeCredentials } from '../providers/types.js'; +import { isRecord } from '../utils/type-guards.js'; + +export interface FallbackChainEntry { + model: string; + baseUrl?: string; + keyEnv?: string; + chatbotId?: string; + agentId?: string; +} + +export type FallbackReason = 'auth' | 'rate_limit' | 'other'; + +export interface FallbackActivation { + runtime: ResolvedModelRuntimeCredentials; + model: string; + entry: FallbackChainEntry; +} + +const DEFAULT_COOLDOWN_MS = 60_000; + +const cooldownMap = new Map(); + +export function loadFallbackChainFromEnv( + raw: string | undefined = process.env.HYBRIDAI_FALLBACK_CHAIN, +): FallbackChainEntry[] { + const text = String(raw || '').trim(); + if (!text) return []; + let parsed: unknown; + try { + parsed = JSON.parse(text); + } catch { + return []; + } + if (!Array.isArray(parsed)) return []; + const entries: FallbackChainEntry[] = []; + for (const item of parsed) { + if (!isRecord(item)) continue; + const model = typeof item.model === 'string' ? item.model.trim() : ''; + if (!model) continue; + const entry: FallbackChainEntry = { model }; + if (typeof item.baseUrl === 'string' && item.baseUrl.trim()) { + entry.baseUrl = item.baseUrl.trim(); + } + if (typeof item.keyEnv === 'string' && item.keyEnv.trim()) { + entry.keyEnv = item.keyEnv.trim(); + } + if (typeof item.chatbotId === 'string' && item.chatbotId.trim()) { + entry.chatbotId = item.chatbotId.trim(); + } + if (typeof item.agentId === 'string' && item.agentId.trim()) { + entry.agentId = item.agentId.trim(); + } + entries.push(entry); + } + return entries; +} + +export function classifyProviderError(err: unknown): FallbackReason { + const text = err instanceof Error ? err.message : String(err); + if (/(^|\D)401(\D|$)|(^|\D)403(\D|$)/.test(text)) return 'auth'; + if (/unauthorized|forbidden|invalid api key|permission denied/i.test(text)) { + return 'auth'; + } + if (/(^|\D)429(\D|$)/.test(text)) return 'rate_limit'; + if (/rate[- ]?limit|too many requests|quota|billing/i.test(text)) { + return 'rate_limit'; + } + return 'other'; +} + +export function isProviderCooledDown( + providerId: string, + now: number = performance.now(), +): boolean { + const until = cooldownMap.get(providerId); + return typeof until === 'number' && until > now; +} + +export function markProviderCooldown( + providerId: string, + durationMs: number = DEFAULT_COOLDOWN_MS, + now: number = performance.now(), +): void { + if (!providerId) return; + cooldownMap.set(providerId, now + Math.max(0, durationMs)); +} + +export function clearProviderCooldown(providerId?: string): void { + if (!providerId) { + cooldownMap.clear(); + return; + } + cooldownMap.delete(providerId); +} + +async function resolveEntry( + entry: FallbackChainEntry, +): Promise { + let runtime: ResolvedModelRuntimeCredentials; + try { + runtime = await resolveModelRuntimeCredentials({ + model: entry.model, + ...(entry.agentId ? { agentId: entry.agentId } : {}), + ...(entry.chatbotId ? { chatbotId: entry.chatbotId } : {}), + }); + } catch { + return null; + } + let apiKey = runtime.apiKey; + if (entry.keyEnv) { + const envKey = String(process.env[entry.keyEnv] || '').trim(); + if (envKey) apiKey = envKey; + } + if (!apiKey && !runtime.isLocal) return null; + return { + ...runtime, + ...(entry.baseUrl ? { baseUrl: entry.baseUrl } : {}), + apiKey, + ...(entry.chatbotId ? { chatbotId: entry.chatbotId } : {}), + }; +} + +export interface ProviderFallbackControllerOptions { + chain: FallbackChainEntry[]; + primaryProvider: string; + cooldownMs?: number; +} + +export class ProviderFallbackController { + private readonly chain: FallbackChainEntry[]; + private readonly primaryProvider: string; + private readonly cooldownMs: number; + private index = 0; + private activated = false; + + constructor(opts: ProviderFallbackControllerOptions) { + this.chain = opts.chain; + this.primaryProvider = String(opts.primaryProvider || '') + .trim() + .toLowerCase(); + this.cooldownMs = opts.cooldownMs ?? DEFAULT_COOLDOWN_MS; + } + + hasRemaining(): boolean { + return this.index < this.chain.length; + } + + isActivated(): boolean { + return this.activated; + } + + async tryActivate( + reason: FallbackReason, + currentProvider: string, + options: { markCooldown?: boolean } = {}, + ): Promise { + const markCooldown = options.markCooldown !== false; + if (markCooldown && reason === 'rate_limit' && this.primaryProvider) { + const current = String(currentProvider || '') + .trim() + .toLowerCase(); + const leavingPrimary = + !this.activated || current === this.primaryProvider; + if (leavingPrimary) { + markProviderCooldown(this.primaryProvider, this.cooldownMs); + } + } + while (this.index < this.chain.length) { + const entry = this.chain[this.index]; + this.index += 1; + if (!entry) continue; + const runtime = await resolveEntry(entry); + if (!runtime) continue; + this.activated = true; + return { runtime, model: entry.model, entry }; + } + return null; + } +} + +export interface CallWithFallbackParams { + primaryRuntime: ResolvedModelRuntimeCredentials; + primaryModel: string; + chain: FallbackChainEntry[]; + cooldownMs?: number; + invoke: ( + runtime: ResolvedModelRuntimeCredentials, + model: string, + ) => Promise; + onFallback?: (activation: FallbackActivation, reason: FallbackReason) => void; + /** + * Optional gate consulted before each fallback retry. Receives the original + * error and the classified reason; return `false` to suppress further + * fallback attempts and re-throw the original error. Useful for callers that + * have begun emitting bytes (e.g. SSE streams) and cannot safely retry on a + * different provider mid-response. + */ + shouldFallback?: (err: unknown, reason: FallbackReason) => boolean; +} + +export async function callWithProviderFallback( + params: CallWithFallbackParams, +): Promise { + const controller = new ProviderFallbackController({ + chain: params.chain, + primaryProvider: params.primaryRuntime.provider, + ...(params.cooldownMs !== undefined + ? { cooldownMs: params.cooldownMs } + : {}), + }); + + let runtime = params.primaryRuntime; + let model = params.primaryModel; + + if ( + params.chain.length > 0 && + isProviderCooledDown(params.primaryRuntime.provider) + ) { + // Primary is already in cooldown from a previous request — skip straight to + // the first fallback. Pass `markCooldown: false` so we do NOT extend the + // existing deadline; otherwise steady traffic would push the cooldown + // forward on every request and the primary would never come back. + const activation = await controller.tryActivate( + 'rate_limit', + params.primaryRuntime.provider, + { markCooldown: false }, + ); + if (activation) { + runtime = activation.runtime; + model = activation.model; + params.onFallback?.(activation, 'rate_limit'); + } + } + + const maxAttempts = params.chain.length + 1; + let lastError: unknown; + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + try { + return await params.invoke(runtime, model); + } catch (err) { + lastError = err; + const reason = classifyProviderError(err); + if (reason === 'other' || !controller.hasRemaining()) throw err; + if (params.shouldFallback && !params.shouldFallback(err, reason)) { + throw err; + } + const activation = await controller.tryActivate(reason, runtime.provider); + if (!activation) throw err; + runtime = activation.runtime; + model = activation.model; + params.onFallback?.(activation, reason); + } + } + throw lastError; +} diff --git a/tests/provider-fallback.test.ts b/tests/provider-fallback.test.ts new file mode 100644 index 00000000..69a49f58 --- /dev/null +++ b/tests/provider-fallback.test.ts @@ -0,0 +1,362 @@ +import { performance } from 'node:perf_hooks'; + +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; + +import type { ResolvedModelRuntimeCredentials } from '../src/providers/types.js'; + +const resolveModelRuntimeCredentials = vi.fn(); + +vi.mock('../src/providers/factory.js', () => ({ + resolveModelRuntimeCredentials: ( + params: { model: string; chatbotId?: string; agentId?: string } | undefined, + ) => resolveModelRuntimeCredentials(params), +})); + +async function importModule() { + return import('../src/gateway/provider-fallback.js'); +} + +function runtimeFixture( + provider: string, + overrides: Partial = {}, +): ResolvedModelRuntimeCredentials { + return { + provider: provider as ResolvedModelRuntimeCredentials['provider'], + apiKey: `${provider}-key`, + baseUrl: `https://${provider}.example.com/v1`, + chatbotId: '', + enableRag: false, + requestHeaders: {}, + agentId: 'main', + isLocal: false, + ...overrides, + }; +} + +beforeEach(() => { + resolveModelRuntimeCredentials.mockReset(); + vi.resetModules(); +}); + +afterEach(() => { + vi.restoreAllMocks(); +}); + +describe('loadFallbackChainFromEnv', () => { + test('returns empty list for missing or invalid values', async () => { + const mod = await importModule(); + expect(mod.loadFallbackChainFromEnv(undefined)).toEqual([]); + expect(mod.loadFallbackChainFromEnv('')).toEqual([]); + expect(mod.loadFallbackChainFromEnv('not json')).toEqual([]); + expect(mod.loadFallbackChainFromEnv('{}')).toEqual([]); + }); + + test('parses well-formed entries and drops invalid ones', async () => { + const mod = await importModule(); + const chain = mod.loadFallbackChainFromEnv( + JSON.stringify([ + { model: 'gpt-4o-mini' }, + { model: ' ' }, + { baseUrl: 'https://x' }, + { + model: 'claude-3-5-haiku', + baseUrl: 'https://anthropic.example/v1', + keyEnv: 'ANTHROPIC_KEY', + chatbotId: 'cb-2', + }, + ]), + ); + expect(chain).toEqual([ + { model: 'gpt-4o-mini' }, + { + model: 'claude-3-5-haiku', + baseUrl: 'https://anthropic.example/v1', + keyEnv: 'ANTHROPIC_KEY', + chatbotId: 'cb-2', + }, + ]); + }); +}); + +describe('classifyProviderError', () => { + test('identifies auth, rate-limit, and unknown failures', async () => { + const mod = await importModule(); + expect(mod.classifyProviderError(new Error('failed with 401: bad'))).toBe( + 'auth', + ); + expect(mod.classifyProviderError(new Error('Forbidden: blocked'))).toBe( + 'auth', + ); + expect(mod.classifyProviderError(new Error('HTTP 429 too many'))).toBe( + 'rate_limit', + ); + expect(mod.classifyProviderError(new Error('daily quota exhausted'))).toBe( + 'rate_limit', + ); + expect(mod.classifyProviderError(new Error('500 internal'))).toBe('other'); + }); +}); + +describe('ProviderFallbackController', () => { + test('advances through chain and skips unresolvable entries', async () => { + const mod = await importModule(); + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockRejectedValueOnce(new Error('no credentials')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const controller = new mod.ProviderFallbackController({ + chain: [ + { model: 'openrouter/a' }, + { model: 'broken/b' }, + { model: 'mistral/c' }, + ], + primaryProvider: 'openai', + }); + + const first = await controller.tryActivate('auth', 'openai'); + expect(first?.runtime.provider).toBe('openrouter'); + expect(controller.isActivated()).toBe(true); + + const second = await controller.tryActivate('auth', 'openrouter'); + expect(second?.runtime.provider).toBe('mistral'); + + const third = await controller.tryActivate('auth', 'mistral'); + expect(third).toBeNull(); + expect(controller.hasRemaining()).toBe(false); + }); + + test('rate-limit cooldown only set when leaving primary', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const controller = new mod.ProviderFallbackController({ + chain: [{ model: 'openrouter/a' }, { model: 'mistral/b' }], + primaryProvider: 'openai', + cooldownMs: 500, + }); + + await controller.tryActivate('rate_limit', 'openai'); + expect(mod.isProviderCooledDown('openai')).toBe(true); + + mod.clearProviderCooldown('openai'); + + await controller.tryActivate('rate_limit', 'openrouter'); + expect(mod.isProviderCooledDown('openai')).toBe(false); + }); + + test('keyEnv override wins over provider credentials', async () => { + const mod = await importModule(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter', { apiKey: 'fallback-key' }), + ); + vi.stubEnv('CUSTOM_KEY_ENV', 'env-override'); + + const controller = new mod.ProviderFallbackController({ + chain: [{ model: 'openrouter/a', keyEnv: 'CUSTOM_KEY_ENV' }], + primaryProvider: 'openai', + }); + + const activation = await controller.tryActivate('auth', 'openai'); + expect(activation?.runtime.apiKey).toBe('env-override'); + vi.unstubAllEnvs(); + }); +}); + +describe('callWithProviderFallback', () => { + test('invokes primary only when chain is empty', async () => { + const mod = await importModule(); + const invoke = vi.fn().mockResolvedValue({ id: 'ok' }); + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [], + invoke, + }); + expect(result).toEqual({ id: 'ok' }); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('falls back on auth error and returns fallback result', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('Provider returned 401')) + .mockResolvedValueOnce({ id: 'fallback' }); + + const onFallback = vi.fn(); + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + onFallback, + }); + + expect(result).toEqual({ id: 'fallback' }); + expect(invoke).toHaveBeenCalledTimes(2); + expect(invoke.mock.calls[1]?.[0].provider).toBe('openrouter'); + expect(onFallback).toHaveBeenCalledWith( + expect.objectContaining({ entry: { model: 'openrouter/a' } }), + 'auth', + ); + }); + + test('non-auth / non-rate-limit errors surface without fallback', async () => { + const mod = await importModule(); + const invoke = vi.fn().mockRejectedValueOnce(new Error('500 server down')); + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + }), + ).rejects.toThrow('500 server down'); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('primary cooldown skips straight to fallback on next request', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + mod.markProviderCooldown('openai', 5_000); + + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + const invoke = vi.fn().mockResolvedValueOnce({ id: 'fallback' }); + + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + }); + + expect(result).toEqual({ id: 'fallback' }); + expect(invoke).toHaveBeenCalledTimes(1); + expect(invoke.mock.calls[0]?.[0].provider).toBe('openrouter'); + mod.clearProviderCooldown(); + }); + + test('cooled-down primary does not extend its own deadline on subsequent requests', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + mod.markProviderCooldown('openai', 5_000); + + const cooledUntil = (() => { + // Snapshot the deadline by sniffing isProviderCooledDown across times. + // We don't have a getter, so probe at now and now+4000ms. + return { + atStart: mod.isProviderCooledDown('openai'), + }; + })(); + expect(cooledUntil.atStart).toBe(true); + + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('openrouter')); + + const invoke = vi.fn().mockResolvedValue({ id: 'fb' }); + + // Three back-to-back requests while primary is cooled down. Each request + // creates a fresh controller that sees `!activated`. If we marked + // cooldown, the deadline would be pushed forward on every call and the + // primary would never come back. + for (let i = 0; i < 3; i += 1) { + await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + cooldownMs: 5_000, + invoke, + }); + } + + // Far enough past the original 5 s deadline to expose any extension. + expect(mod.isProviderCooledDown('openai', performance.now() + 6_000)).toBe( + false, + ); + mod.clearProviderCooldown(); + }); + + test('shouldFallback=false re-throws the original error without retrying', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('upstream 401 from primary')); + + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + shouldFallback: () => false, + }), + ).rejects.toThrow('upstream 401 from primary'); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('shouldFallback=true allows fallback to proceed', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('upstream 401 from primary')) + .mockResolvedValueOnce({ id: 'ok' }); + + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + shouldFallback: () => true, + }); + expect(result).toEqual({ id: 'ok' }); + expect(invoke).toHaveBeenCalledTimes(2); + }); + + test('exhausted chain re-throws the last error', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('HTTP 401 A')) + .mockRejectedValueOnce(new Error('HTTP 401 B')) + .mockRejectedValueOnce(new Error('HTTP 401 C')); + + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }, { model: 'mistral/b' }], + invoke, + }), + ).rejects.toThrow('HTTP 401 C'); + expect(invoke).toHaveBeenCalledTimes(3); + }); +});