From 7f1d64001ba1ec58d0e2b24c92c96fbb7641991f Mon Sep 17 00:00:00 2001 From: Benedikt Koehler Date: Sun, 26 Apr 2026 22:34:23 +0200 Subject: [PATCH 1/3] feat: add provider fallback chain with primary-leave cooldown MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds a resilience layer modeled on hermes-agent's `_try_activate_fallback` pattern: an ordered fallback chain swaps in a backup provider on auth (401/403) or rate-limit (429/quota) failures, while a 60s cooldown clock is set only when a rate-limit failure forces a switch away from the primary — auth failures and chain-internal switches don't arm it. Subsequent requests skip a cooled-down primary and go straight to the first fallback entry whose credentials resolve. - New `src/gateway/provider-fallback.ts` with `ProviderFallbackController`, module-level cooldown map, error classifier, and `callWithProviderFallback` wrapper. - Wraps both tool-chat and streaming tool-chat handlers in `openai-compatible.ts`. Streaming retries refuse mid-stream switches to avoid duplicated text deltas. - Configured via `HYBRIDAI_FALLBACK_CHAIN` env var (JSON array of `{model, baseUrl?, keyEnv?, chatbotId?, agentId?}` entries). - 11 new unit tests covering chain parsing, error classification, primary-leave cooldown semantics, key-env override, primary-cooldown skip, and exhausted-chain re-throw. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/gateway/openai-compatible.ts | 71 +++++--- src/gateway/provider-fallback.ts | 241 +++++++++++++++++++++++++++ tests/provider-fallback.test.ts | 272 +++++++++++++++++++++++++++++++ 3 files changed, 561 insertions(+), 23 deletions(-) create mode 100644 src/gateway/provider-fallback.ts create mode 100644 tests/provider-fallback.test.ts diff --git a/src/gateway/openai-compatible.ts b/src/gateway/openai-compatible.ts index cdc6e8024..54440de73 100644 --- a/src/gateway/openai-compatible.ts +++ b/src/gateway/openai-compatible.ts @@ -39,6 +39,10 @@ import { callOpenAICompatibleModelStream, mapOpenAICompatibleUsageToTokenStats, } from './openai-compatible-model.js'; +import { + callWithProviderFallback, + loadFallbackChainFromEnv, +} from './provider-fallback.js'; import { OpenAICompatibleRequestError, readOpenAICompatibleChatRequest, @@ -495,12 +499,18 @@ async function handleOpenAICompatibleToolChat( ): Promise { const runtime = await resolveToolAwareRuntime({ prepared }); const messages = await buildToolAwareMessages({ input, prepared }); - const result = await callOpenAICompatibleModel({ - runtime, - model: prepared.model, - messages, - tools: input.tools, - toolChoice: input.toolChoice, + const result = await callWithProviderFallback({ + primaryRuntime: runtime, + primaryModel: prepared.model, + chain: loadFallbackChainFromEnv(), + invoke: (activeRuntime, activeModel) => + callOpenAICompatibleModel({ + runtime: activeRuntime, + model: activeModel, + messages, + tools: input.tools, + toolChoice: input.toolChoice, + }), }); const choice = result.choices[0]; const payload = buildOpenAICompatibleCompletionResponse({ @@ -694,23 +704,38 @@ async function handleOpenAICompatibleStreamingToolChat( try { const runtime = await resolveToolAwareRuntime({ prepared }); const messages = await buildToolAwareMessages({ input, prepared }); - const result = await callOpenAICompatibleModelStream({ - runtime, - model: prepared.model, - 
messages, - tools: input.tools, - toolChoice: input.toolChoice, - onTextDelta: (delta) => { - if (!isResponseWritable(res) || !delta) return; - writeOpenAICompatibleStreamChunk( - res, - buildOpenAICompatibleStreamTextChunk({ - completionId, - created, - model: prepared.responseModel, - content: delta, - }), - ); + let streamStarted = false; + const result = await callWithProviderFallback({ + primaryRuntime: runtime, + primaryModel: prepared.model, + chain: loadFallbackChainFromEnv(), + invoke: async (activeRuntime, activeModel) => { + if (streamStarted) { + throw new Error( + 'Stream already started; cannot retry provider fallback mid-stream.', + ); + } + return callOpenAICompatibleModelStream({ + runtime: activeRuntime, + model: activeModel, + messages, + tools: input.tools, + toolChoice: input.toolChoice, + onTextDelta: (delta) => { + if (!delta) return; + streamStarted = true; + if (!isResponseWritable(res)) return; + writeOpenAICompatibleStreamChunk( + res, + buildOpenAICompatibleStreamTextChunk({ + completionId, + created, + model: prepared.responseModel, + content: delta, + }), + ); + }, + }); }, }); if (!isResponseWritable(res)) return; diff --git a/src/gateway/provider-fallback.ts b/src/gateway/provider-fallback.ts new file mode 100644 index 000000000..8211b58c4 --- /dev/null +++ b/src/gateway/provider-fallback.ts @@ -0,0 +1,241 @@ +import { performance } from 'node:perf_hooks'; + +import { resolveModelRuntimeCredentials } from '../providers/factory.js'; +import type { ResolvedModelRuntimeCredentials } from '../providers/types.js'; + +export interface FallbackChainEntry { + model: string; + baseUrl?: string; + keyEnv?: string; + chatbotId?: string; + agentId?: string; +} + +export type FallbackReason = 'auth' | 'rate_limit' | 'other'; + +export interface FallbackActivation { + runtime: ResolvedModelRuntimeCredentials; + model: string; + entry: FallbackChainEntry; +} + +const DEFAULT_COOLDOWN_MS = 60_000; + +const cooldownMap = new Map(); + +function 
isRecord(value: unknown): value is Record { + return !!value && typeof value === 'object' && !Array.isArray(value); +} + +export function loadFallbackChainFromEnv( + raw: string | undefined = process.env.HYBRIDAI_FALLBACK_CHAIN, +): FallbackChainEntry[] { + const text = String(raw || '').trim(); + if (!text) return []; + let parsed: unknown; + try { + parsed = JSON.parse(text); + } catch { + return []; + } + if (!Array.isArray(parsed)) return []; + const entries: FallbackChainEntry[] = []; + for (const item of parsed) { + if (!isRecord(item)) continue; + const model = typeof item.model === 'string' ? item.model.trim() : ''; + if (!model) continue; + const entry: FallbackChainEntry = { model }; + if (typeof item.baseUrl === 'string' && item.baseUrl.trim()) { + entry.baseUrl = item.baseUrl.trim(); + } + if (typeof item.keyEnv === 'string' && item.keyEnv.trim()) { + entry.keyEnv = item.keyEnv.trim(); + } + if (typeof item.chatbotId === 'string' && item.chatbotId.trim()) { + entry.chatbotId = item.chatbotId.trim(); + } + if (typeof item.agentId === 'string' && item.agentId.trim()) { + entry.agentId = item.agentId.trim(); + } + entries.push(entry); + } + return entries; +} + +export function classifyProviderError(err: unknown): FallbackReason { + const text = err instanceof Error ? 
err.message : String(err); + if (/(^|\D)401(\D|$)|(^|\D)403(\D|$)/.test(text)) return 'auth'; + if (/unauthorized|forbidden|invalid api key|permission denied/i.test(text)) { + return 'auth'; + } + if (/(^|\D)429(\D|$)/.test(text)) return 'rate_limit'; + if (/rate[- ]?limit|too many requests|quota|billing/i.test(text)) { + return 'rate_limit'; + } + return 'other'; +} + +export function isProviderCooledDown( + providerId: string, + now: number = performance.now(), +): boolean { + const until = cooldownMap.get(providerId); + return typeof until === 'number' && until > now; +} + +export function markProviderCooldown( + providerId: string, + durationMs: number = DEFAULT_COOLDOWN_MS, + now: number = performance.now(), +): void { + if (!providerId) return; + cooldownMap.set(providerId, now + Math.max(0, durationMs)); +} + +export function clearProviderCooldown(providerId?: string): void { + if (!providerId) { + cooldownMap.clear(); + return; + } + cooldownMap.delete(providerId); +} + +async function resolveEntry( + entry: FallbackChainEntry, +): Promise { + let runtime: ResolvedModelRuntimeCredentials; + try { + runtime = await resolveModelRuntimeCredentials({ + model: entry.model, + ...(entry.agentId ? { agentId: entry.agentId } : {}), + ...(entry.chatbotId ? { chatbotId: entry.chatbotId } : {}), + }); + } catch { + return null; + } + let apiKey = runtime.apiKey; + if (entry.keyEnv) { + const envKey = String(process.env[entry.keyEnv] || '').trim(); + if (envKey) apiKey = envKey; + } + if (!apiKey && !runtime.isLocal) return null; + return { + ...runtime, + ...(entry.baseUrl ? { baseUrl: entry.baseUrl } : {}), + apiKey, + ...(entry.chatbotId ? 
{ chatbotId: entry.chatbotId } : {}), + }; +} + +export interface ProviderFallbackControllerOptions { + chain: FallbackChainEntry[]; + primaryProvider: string; + cooldownMs?: number; +} + +export class ProviderFallbackController { + private readonly chain: FallbackChainEntry[]; + private readonly primaryProvider: string; + private readonly cooldownMs: number; + private index = 0; + private activated = false; + + constructor(opts: ProviderFallbackControllerOptions) { + this.chain = opts.chain; + this.primaryProvider = String(opts.primaryProvider || '') + .trim() + .toLowerCase(); + this.cooldownMs = opts.cooldownMs ?? DEFAULT_COOLDOWN_MS; + } + + hasRemaining(): boolean { + return this.index < this.chain.length; + } + + isActivated(): boolean { + return this.activated; + } + + async tryActivate( + reason: FallbackReason, + currentProvider: string, + ): Promise { + if (reason === 'rate_limit' && this.primaryProvider) { + const current = String(currentProvider || '') + .trim() + .toLowerCase(); + const leavingPrimary = !this.activated || current === this.primaryProvider; + if (leavingPrimary) { + markProviderCooldown(this.primaryProvider, this.cooldownMs); + } + } + while (this.index < this.chain.length) { + const entry = this.chain[this.index]; + this.index += 1; + if (!entry) continue; + const runtime = await resolveEntry(entry); + if (!runtime) continue; + this.activated = true; + return { runtime, model: entry.model, entry }; + } + return null; + } +} + +export interface CallWithFallbackParams { + primaryRuntime: ResolvedModelRuntimeCredentials; + primaryModel: string; + chain: FallbackChainEntry[]; + cooldownMs?: number; + invoke: ( + runtime: ResolvedModelRuntimeCredentials, + model: string, + ) => Promise; + onFallback?: (activation: FallbackActivation, reason: FallbackReason) => void; +} + +export async function callWithProviderFallback( + params: CallWithFallbackParams, +): Promise { + const controller = new ProviderFallbackController({ + chain: params.chain, 
+ primaryProvider: params.primaryRuntime.provider, + ...(params.cooldownMs !== undefined ? { cooldownMs: params.cooldownMs } : {}), + }); + + let runtime = params.primaryRuntime; + let model = params.primaryModel; + + if ( + params.chain.length > 0 && + isProviderCooledDown(params.primaryRuntime.provider) + ) { + const activation = await controller.tryActivate( + 'rate_limit', + params.primaryRuntime.provider, + ); + if (activation) { + runtime = activation.runtime; + model = activation.model; + params.onFallback?.(activation, 'rate_limit'); + } + } + + const maxAttempts = params.chain.length + 1; + let lastError: unknown; + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + try { + return await params.invoke(runtime, model); + } catch (err) { + lastError = err; + const reason = classifyProviderError(err); + if (reason === 'other' || !controller.hasRemaining()) throw err; + const activation = await controller.tryActivate(reason, runtime.provider); + if (!activation) throw err; + runtime = activation.runtime; + model = activation.model; + params.onFallback?.(activation, reason); + } + } + throw lastError; +} diff --git a/tests/provider-fallback.test.ts b/tests/provider-fallback.test.ts new file mode 100644 index 000000000..f66ddf18f --- /dev/null +++ b/tests/provider-fallback.test.ts @@ -0,0 +1,272 @@ +import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; + +import type { ResolvedModelRuntimeCredentials } from '../src/providers/types.js'; + +const resolveModelRuntimeCredentials = vi.fn(); + +vi.mock('../src/providers/factory.js', () => ({ + resolveModelRuntimeCredentials: ( + params: { model: string; chatbotId?: string; agentId?: string } | undefined, + ) => resolveModelRuntimeCredentials(params), +})); + +async function importModule() { + return import('../src/gateway/provider-fallback.js'); +} + +function runtimeFixture( + provider: string, + overrides: Partial = {}, +): ResolvedModelRuntimeCredentials { + return { + provider: 
provider as ResolvedModelRuntimeCredentials['provider'], + apiKey: `${provider}-key`, + baseUrl: `https://${provider}.example.com/v1`, + chatbotId: '', + enableRag: false, + requestHeaders: {}, + agentId: 'main', + isLocal: false, + ...overrides, + }; +} + +beforeEach(() => { + resolveModelRuntimeCredentials.mockReset(); + vi.resetModules(); +}); + +afterEach(() => { + vi.restoreAllMocks(); +}); + +describe('loadFallbackChainFromEnv', () => { + test('returns empty list for missing or invalid values', async () => { + const mod = await importModule(); + expect(mod.loadFallbackChainFromEnv(undefined)).toEqual([]); + expect(mod.loadFallbackChainFromEnv('')).toEqual([]); + expect(mod.loadFallbackChainFromEnv('not json')).toEqual([]); + expect(mod.loadFallbackChainFromEnv('{}')).toEqual([]); + }); + + test('parses well-formed entries and drops invalid ones', async () => { + const mod = await importModule(); + const chain = mod.loadFallbackChainFromEnv( + JSON.stringify([ + { model: 'gpt-4o-mini' }, + { model: ' ' }, + { baseUrl: 'https://x' }, + { + model: 'claude-3-5-haiku', + baseUrl: 'https://anthropic.example/v1', + keyEnv: 'ANTHROPIC_KEY', + chatbotId: 'cb-2', + }, + ]), + ); + expect(chain).toEqual([ + { model: 'gpt-4o-mini' }, + { + model: 'claude-3-5-haiku', + baseUrl: 'https://anthropic.example/v1', + keyEnv: 'ANTHROPIC_KEY', + chatbotId: 'cb-2', + }, + ]); + }); +}); + +describe('classifyProviderError', () => { + test('identifies auth, rate-limit, and unknown failures', async () => { + const mod = await importModule(); + expect(mod.classifyProviderError(new Error('failed with 401: bad'))).toBe( + 'auth', + ); + expect(mod.classifyProviderError(new Error('Forbidden: blocked'))).toBe( + 'auth', + ); + expect(mod.classifyProviderError(new Error('HTTP 429 too many'))).toBe( + 'rate_limit', + ); + expect( + mod.classifyProviderError(new Error('daily quota exhausted')), + ).toBe('rate_limit'); + expect(mod.classifyProviderError(new Error('500 
internal'))).toBe('other'); + }); +}); + +describe('ProviderFallbackController', () => { + test('advances through chain and skips unresolvable entries', async () => { + const mod = await importModule(); + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockRejectedValueOnce(new Error('no credentials')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const controller = new mod.ProviderFallbackController({ + chain: [ + { model: 'openrouter/a' }, + { model: 'broken/b' }, + { model: 'mistral/c' }, + ], + primaryProvider: 'openai', + }); + + const first = await controller.tryActivate('auth', 'openai'); + expect(first?.runtime.provider).toBe('openrouter'); + expect(controller.isActivated()).toBe(true); + + const second = await controller.tryActivate('auth', 'openrouter'); + expect(second?.runtime.provider).toBe('mistral'); + + const third = await controller.tryActivate('auth', 'mistral'); + expect(third).toBeNull(); + expect(controller.hasRemaining()).toBe(false); + }); + + test('rate-limit cooldown only set when leaving primary', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const controller = new mod.ProviderFallbackController({ + chain: [{ model: 'openrouter/a' }, { model: 'mistral/b' }], + primaryProvider: 'openai', + cooldownMs: 500, + }); + + await controller.tryActivate('rate_limit', 'openai'); + expect(mod.isProviderCooledDown('openai')).toBe(true); + + mod.clearProviderCooldown('openai'); + + await controller.tryActivate('rate_limit', 'openrouter'); + expect(mod.isProviderCooledDown('openai')).toBe(false); + }); + + test('keyEnv override wins over provider credentials', async () => { + const mod = await importModule(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter', { apiKey: 'fallback-key' }), + 
); + vi.stubEnv('CUSTOM_KEY_ENV', 'env-override'); + + const controller = new mod.ProviderFallbackController({ + chain: [{ model: 'openrouter/a', keyEnv: 'CUSTOM_KEY_ENV' }], + primaryProvider: 'openai', + }); + + const activation = await controller.tryActivate('auth', 'openai'); + expect(activation?.runtime.apiKey).toBe('env-override'); + vi.unstubAllEnvs(); + }); +}); + +describe('callWithProviderFallback', () => { + test('invokes primary only when chain is empty', async () => { + const mod = await importModule(); + const invoke = vi.fn().mockResolvedValue({ id: 'ok' }); + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [], + invoke, + }); + expect(result).toEqual({ id: 'ok' }); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('falls back on auth error and returns fallback result', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('Provider returned 401')) + .mockResolvedValueOnce({ id: 'fallback' }); + + const onFallback = vi.fn(); + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + onFallback, + }); + + expect(result).toEqual({ id: 'fallback' }); + expect(invoke).toHaveBeenCalledTimes(2); + expect(invoke.mock.calls[1]?.[0].provider).toBe('openrouter'); + expect(onFallback).toHaveBeenCalledWith( + expect.objectContaining({ entry: { model: 'openrouter/a' } }), + 'auth', + ); + }); + + test('non-auth / non-rate-limit errors surface without fallback', async () => { + const mod = await importModule(); + const invoke = vi.fn().mockRejectedValueOnce(new Error('500 server down')); + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + 
primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + }), + ).rejects.toThrow('500 server down'); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('primary cooldown skips straight to fallback on next request', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + mod.markProviderCooldown('openai', 5_000); + + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + const invoke = vi.fn().mockResolvedValueOnce({ id: 'fallback' }); + + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + }); + + expect(result).toEqual({ id: 'fallback' }); + expect(invoke).toHaveBeenCalledTimes(1); + expect(invoke.mock.calls[0]?.[0].provider).toBe('openrouter'); + mod.clearProviderCooldown(); + }); + + test('exhausted chain re-throws the last error', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('mistral')); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('HTTP 401 A')) + .mockRejectedValueOnce(new Error('HTTP 401 B')) + .mockRejectedValueOnce(new Error('HTTP 401 C')); + + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }, { model: 'mistral/b' }], + invoke, + }), + ).rejects.toThrow('HTTP 401 C'); + expect(invoke).toHaveBeenCalledTimes(3); + }); +}); From 281be06b87b173b7e679d3da9b8e0331f496bed5 Mon Sep 17 00:00:00 2001 From: Benedikt Koehler Date: Sun, 26 Apr 2026 22:34:34 +0200 Subject: [PATCH 2/3] style: apply biome formatting to provider-fallback files Co-Authored-By: Claude Opus 4.7 (1M context) --- src/gateway/openai-compatible.ts | 8 ++++---- src/gateway/provider-fallback.ts | 7 +++++-- 
tests/provider-fallback.test.ts | 6 +++--- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/src/gateway/openai-compatible.ts b/src/gateway/openai-compatible.ts index 54440de73..290c8d606 100644 --- a/src/gateway/openai-compatible.ts +++ b/src/gateway/openai-compatible.ts @@ -39,10 +39,6 @@ import { callOpenAICompatibleModelStream, mapOpenAICompatibleUsageToTokenStats, } from './openai-compatible-model.js'; -import { - callWithProviderFallback, - loadFallbackChainFromEnv, -} from './provider-fallback.js'; import { OpenAICompatibleRequestError, readOpenAICompatibleChatRequest, @@ -58,6 +54,10 @@ import { sendOpenAICompatibleStreamError, writeOpenAICompatibleStreamChunk, } from './openai-compatible-response.js'; +import { + callWithProviderFallback, + loadFallbackChainFromEnv, +} from './provider-fallback.js'; function isResponseWritable(res: ServerResponse): boolean { return !res.writableEnded && !res.destroyed; diff --git a/src/gateway/provider-fallback.ts b/src/gateway/provider-fallback.ts index 8211b58c4..c26f7b5fd 100644 --- a/src/gateway/provider-fallback.ts +++ b/src/gateway/provider-fallback.ts @@ -164,7 +164,8 @@ export class ProviderFallbackController { const current = String(currentProvider || '') .trim() .toLowerCase(); - const leavingPrimary = !this.activated || current === this.primaryProvider; + const leavingPrimary = + !this.activated || current === this.primaryProvider; if (leavingPrimary) { markProviderCooldown(this.primaryProvider, this.cooldownMs); } @@ -200,7 +201,9 @@ export async function callWithProviderFallback( const controller = new ProviderFallbackController({ chain: params.chain, primaryProvider: params.primaryRuntime.provider, - ...(params.cooldownMs !== undefined ? { cooldownMs: params.cooldownMs } : {}), + ...(params.cooldownMs !== undefined + ? 
{ cooldownMs: params.cooldownMs } + : {}), }); let runtime = params.primaryRuntime; diff --git a/tests/provider-fallback.test.ts b/tests/provider-fallback.test.ts index f66ddf18f..52d74a445 100644 --- a/tests/provider-fallback.test.ts +++ b/tests/provider-fallback.test.ts @@ -88,9 +88,9 @@ describe('classifyProviderError', () => { expect(mod.classifyProviderError(new Error('HTTP 429 too many'))).toBe( 'rate_limit', ); - expect( - mod.classifyProviderError(new Error('daily quota exhausted')), - ).toBe('rate_limit'); + expect(mod.classifyProviderError(new Error('daily quota exhausted'))).toBe( + 'rate_limit', + ); expect(mod.classifyProviderError(new Error('500 internal'))).toBe('other'); }); }); From 818a01a498c6a26ede75b7a832dbcf0314cae56a Mon Sep 17 00:00:00 2001 From: Benedikt Koehler Date: Sun, 26 Apr 2026 22:41:01 +0200 Subject: [PATCH 3/3] fix: address PR #413 review feedback - Reuse the shared `isRecord` helper from `src/utils/type-guards.ts` instead of redeclaring it locally. - Stop re-arming the primary cooldown on the cooled-down skip path: `tryActivate` now accepts `{ markCooldown }`, and the initial skip in `callWithProviderFallback` passes `false`. Without this, steady traffic while the primary was cooling down would push its deadline forward on every request and the primary would never recover. Covered by a new test that fires three back-to-back requests against a cooled-down primary and asserts the original 5 s deadline is honored. - Add an optional `shouldFallback(err, reason)` callback to `callWithProviderFallback`. The streaming tool-chat handler passes `() => !streamStarted`, so a mid-stream provider failure now re-throws the original 401/429 error instead of being masked by a generic "Stream already started" placeholder. Covered by tests for both the suppress and allow paths. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- src/gateway/openai-compatible.ts | 17 +++--- src/gateway/provider-fallback.ts | 25 +++++++-- tests/provider-fallback.test.ts | 90 ++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 14 deletions(-) diff --git a/src/gateway/openai-compatible.ts b/src/gateway/openai-compatible.ts index 290c8d606..a111129f0 100644 --- a/src/gateway/openai-compatible.ts +++ b/src/gateway/openai-compatible.ts @@ -709,13 +709,13 @@ async function handleOpenAICompatibleStreamingToolChat( primaryRuntime: runtime, primaryModel: prepared.model, chain: loadFallbackChainFromEnv(), - invoke: async (activeRuntime, activeModel) => { - if (streamStarted) { - throw new Error( - 'Stream already started; cannot retry provider fallback mid-stream.', - ); - } - return callOpenAICompatibleModelStream({ + // Once the SSE stream has emitted bytes we cannot safely switch + // providers — a fallback would duplicate or interleave content. Suppress + // further retries so the original provider error propagates instead of a + // generic "Stream already started" placeholder. 
+ shouldFallback: () => !streamStarted, + invoke: (activeRuntime, activeModel) => + callOpenAICompatibleModelStream({ runtime: activeRuntime, model: activeModel, messages, @@ -735,8 +735,7 @@ async function handleOpenAICompatibleStreamingToolChat( }), ); }, - }); - }, + }), }); if (!isResponseWritable(res)) return; const choice = result.choices[0]; diff --git a/src/gateway/provider-fallback.ts b/src/gateway/provider-fallback.ts index c26f7b5fd..2a6e31aeb 100644 --- a/src/gateway/provider-fallback.ts +++ b/src/gateway/provider-fallback.ts @@ -2,6 +2,7 @@ import { performance } from 'node:perf_hooks'; import { resolveModelRuntimeCredentials } from '../providers/factory.js'; import type { ResolvedModelRuntimeCredentials } from '../providers/types.js'; +import { isRecord } from '../utils/type-guards.js'; export interface FallbackChainEntry { model: string; @@ -23,10 +24,6 @@ const DEFAULT_COOLDOWN_MS = 60_000; const cooldownMap = new Map(); -function isRecord(value: unknown): value is Record { - return !!value && typeof value === 'object' && !Array.isArray(value); -} - export function loadFallbackChainFromEnv( raw: string | undefined = process.env.HYBRIDAI_FALLBACK_CHAIN, ): FallbackChainEntry[] { @@ -159,8 +156,10 @@ export class ProviderFallbackController { async tryActivate( reason: FallbackReason, currentProvider: string, + options: { markCooldown?: boolean } = {}, ): Promise { - if (reason === 'rate_limit' && this.primaryProvider) { + const markCooldown = options.markCooldown !== false; + if (markCooldown && reason === 'rate_limit' && this.primaryProvider) { const current = String(currentProvider || '') .trim() .toLowerCase(); @@ -193,6 +192,14 @@ export interface CallWithFallbackParams { model: string, ) => Promise; onFallback?: (activation: FallbackActivation, reason: FallbackReason) => void; + /** + * Optional gate consulted before each fallback retry. 
Receives the original + * error and the classified reason; return `false` to suppress further + * fallback attempts and re-throw the original error. Useful for callers that + * have begun emitting bytes (e.g. SSE streams) and cannot safely retry on a + * different provider mid-response. + */ + shouldFallback?: (err: unknown, reason: FallbackReason) => boolean; } export async function callWithProviderFallback( @@ -213,9 +220,14 @@ export async function callWithProviderFallback( params.chain.length > 0 && isProviderCooledDown(params.primaryRuntime.provider) ) { + // Primary is already in cooldown from a previous request — skip straight to + // the first fallback. Pass `markCooldown: false` so we do NOT extend the + // existing deadline; otherwise steady traffic would push the cooldown + // forward on every request and the primary would never come back. const activation = await controller.tryActivate( 'rate_limit', params.primaryRuntime.provider, + { markCooldown: false }, ); if (activation) { runtime = activation.runtime; @@ -233,6 +245,9 @@ export async function callWithProviderFallback( lastError = err; const reason = classifyProviderError(err); if (reason === 'other' || !controller.hasRemaining()) throw err; + if (params.shouldFallback && !params.shouldFallback(err, reason)) { + throw err; + } const activation = await controller.tryActivate(reason, runtime.provider); if (!activation) throw err; runtime = activation.runtime; diff --git a/tests/provider-fallback.test.ts b/tests/provider-fallback.test.ts index 52d74a445..69a49f580 100644 --- a/tests/provider-fallback.test.ts +++ b/tests/provider-fallback.test.ts @@ -1,3 +1,5 @@ +import { performance } from 'node:perf_hooks'; + import { afterEach, beforeEach, describe, expect, test, vi } from 'vitest'; import type { ResolvedModelRuntimeCredentials } from '../src/providers/types.js'; @@ -246,6 +248,94 @@ describe('callWithProviderFallback', () => { mod.clearProviderCooldown(); }); + test('cooled-down primary does not 
extend its own deadline on subsequent requests', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + mod.markProviderCooldown('openai', 5_000); + + const cooledUntil = (() => { + // Snapshot the deadline by sniffing isProviderCooledDown across times. + // We don't have a getter, so probe at now and now+4000ms. + return { + atStart: mod.isProviderCooledDown('openai'), + }; + })(); + expect(cooledUntil.atStart).toBe(true); + + resolveModelRuntimeCredentials + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('openrouter')) + .mockResolvedValueOnce(runtimeFixture('openrouter')); + + const invoke = vi.fn().mockResolvedValue({ id: 'fb' }); + + // Three back-to-back requests while primary is cooled down. Each request + // creates a fresh controller that sees `!activated`. If we marked + // cooldown, the deadline would be pushed forward on every call and the + // primary would never come back. + for (let i = 0; i < 3; i += 1) { + await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + cooldownMs: 5_000, + invoke, + }); + } + + // Far enough past the original 5 s deadline to expose any extension. 
+ expect(mod.isProviderCooledDown('openai', performance.now() + 6_000)).toBe( + false, + ); + mod.clearProviderCooldown(); + }); + + test('shouldFallback=false re-throws the original error without retrying', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('upstream 401 from primary')); + + await expect( + mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + shouldFallback: () => false, + }), + ).rejects.toThrow('upstream 401 from primary'); + expect(invoke).toHaveBeenCalledTimes(1); + }); + + test('shouldFallback=true allows fallback to proceed', async () => { + const mod = await importModule(); + mod.clearProviderCooldown(); + resolveModelRuntimeCredentials.mockResolvedValueOnce( + runtimeFixture('openrouter'), + ); + + const invoke = vi + .fn() + .mockRejectedValueOnce(new Error('upstream 401 from primary')) + .mockResolvedValueOnce({ id: 'ok' }); + + const result = await mod.callWithProviderFallback({ + primaryRuntime: runtimeFixture('openai'), + primaryModel: 'gpt-4o', + chain: [{ model: 'openrouter/a' }], + invoke, + shouldFallback: () => true, + }); + expect(result).toEqual({ id: 'ok' }); + expect(invoke).toHaveBeenCalledTimes(2); + }); + test('exhausted chain re-throws the last error', async () => { const mod = await importModule(); mod.clearProviderCooldown();