diff --git a/src/__tests__/schema-drift.test.ts b/src/__tests__/schema-drift.test.ts index f3f28e2..48287c0 100644 --- a/src/__tests__/schema-drift.test.ts +++ b/src/__tests__/schema-drift.test.ts @@ -13,6 +13,10 @@ import { describe, it, expect, beforeEach, vi } from 'vitest'; import { validateSchema, type SchemaField } from '../utils/schema-validator'; import { SchemaDriftError } from '../errors'; import { AnthropicProvider } from '../providers/anthropic'; +import { OpenAIProvider } from '../providers/openai'; +import { GroqProvider } from '../providers/groq'; +import { CerebrasProvider } from '../providers/cerebras'; +import type { BaseProvider } from '../providers/base'; import { LLMProviderFactory } from '../factory'; import { defaultCircuitBreakerManager } from '../utils/circuit-breaker'; import { defaultExhaustionRegistry } from '../utils/exhaustion'; @@ -669,3 +673,196 @@ describe('Anthropic nested content-block validation (H-2 / #42)', () => { expect(res.content).toBe('hi'); }); }); + +// ── OpenAI-compat provider schema validation ──────────────────────────── +// +// OpenAI, Groq, and Cerebras all serve the /chat/completions envelope. +// Driven through describe.each so drift-parity is enforced by construction — +// if one provider's schema diverges, its tests break loudly. + +interface OpenAICompatCase { + name: string; + factory: () => BaseProvider; + model: string; +} + +const openAiCompatCases: OpenAICompatCase[] = [ + { + name: 'openai', + factory: () => new OpenAIProvider({ apiKey: 'test-key', maxRetries: 0 }), + model: 'gpt-4o-mini', + }, + { + name: 'groq', + factory: () => new GroqProvider({ apiKey: 'test-key', maxRetries: 0 }), + model: 'llama-3.1-8b-instant', + }, + { + name: 'cerebras', + factory: () => new CerebrasProvider({ apiKey: 'test-key', maxRetries: 0 }), + model: 'llama-3.1-8b', + }, +]; + +describe.each(openAiCompatCases)('$name response schema validation', ({ name, factory, model }) => { + let provider: BaseProvider; + + beforeEach(() => { + vi.clearAllMocks(); + defaultCircuitBreakerManager.resetAll(); + provider = factory(); + }); + + const validResponse = { + id: 'chatcmpl_1', + object: 'chat.completion', + created: 1700000000, + model, + choices: [{ + index: 0, + message: { role: 'assistant', content: 'hello' }, + finish_reason: 'stop', + }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + }; + + it('passes through a well-formed response', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => validResponse, + headers: new Headers({ 'content-type': 'application/json' }), + }); + + const res = await provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + }); + + expect(res.content).toBe('hello'); + expect(res.usage.inputTokens).toBe(10); + }); + + it('throws SchemaDriftError when usage.prompt_tokens is renamed', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + ...validResponse, + usage: { input_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + }), + headers: new Headers({ 'content-type': 'application/json' }), + }); + + await expect(provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + })).rejects.toMatchObject({ + code: 'SCHEMA_DRIFT', + provider: name, + path: 'usage.prompt_tokens', + }); + }); + + it('throws SchemaDriftError when choices field is removed', async () => { + const { choices: _choices, ...rest } = validResponse; + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => rest, + headers: new Headers({ 'content-type': 'application/json' }), + }); + + await expect(provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + })).rejects.toMatchObject({ code: 'SCHEMA_DRIFT', path: 'choices' }); + }); + + it('throws SchemaDriftError when choices is empty (routes through drift, not bare throw)', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ ...validResponse, choices: [] }), + headers: new Headers({ 'content-type': 'application/json' }), + }); + + await expect(provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + })).rejects.toMatchObject({ + code: 'SCHEMA_DRIFT', + provider: name, + path: 'choices[0]', + }); + }); + + it('throws SchemaDriftError when tool_call function.arguments is not a string', async () => { + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + ...validResponse, + choices: [{ + index: 0, + message: { + role: 'assistant', + content: null, + tool_calls: [{ + id: 'call_1', + type: 'function', + function: { name: 'my_tool', arguments: { already: 'parsed' } }, + }], + }, + finish_reason: 'tool_calls', + }], + }), + headers: new Headers({ 'content-type': 'application/json' }), + }); + + await expect(provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + })).rejects.toMatchObject({ + code: 'SCHEMA_DRIFT', + provider: name, + path: 'choices[0].message.tool_calls[0].function.arguments', + expected: 'string', + actual: 'object', + }); + }); + + it('accepts unknown tool_call type without surfacing it as a function call (forward-compat)', async () => { + // Schema's discriminator skips unknown `type` values (forward-compat for + // additive upstream changes). The provider's formatResponse must not + // dereference the function-shaped payload on a skipped variant, or an + // unknown shape becomes a bare TypeError (bypassing drift/fallback) or + // gets mis-surfaced as a normal function call. Mock omits `function` + // entirely to exercise the TypeError path specifically. + mockFetch.mockResolvedValueOnce({ + ok: true, + json: async () => ({ + ...validResponse, + choices: [{ + index: 0, + message: { + role: 'assistant', + content: 'hi', + tool_calls: [{ + id: 'call_1', + type: 'code_interpreter', // hypothetical future tool type + // intentionally no `function` field — unknown variants may have a + // different shape upstream + }], + }, + finish_reason: 'stop', + }], + }), + headers: new Headers({ 'content-type': 'application/json' }), + }); + + const res = await provider.generateResponse({ + messages: [{ role: 'user', content: 'hi' }], + model, + }); + expect(res.content).toBe('hi'); + // Critical: the unknown variant must be dropped, not mis-surfaced as a + // function call and not crashed through. + expect(res.toolCalls).toBeUndefined(); + }); +}); diff --git a/src/__tests__/tool-call-validation.test.ts b/src/__tests__/tool-call-validation.test.ts index aecbb93..cbc86e5 100644 --- a/src/__tests__/tool-call-validation.test.ts +++ b/src/__tests__/tool-call-validation.test.ts @@ -164,7 +164,13 @@ describe('Tool call validation at provider boundary', () => { expect(res.toolCalls).toBeUndefined(); }); - it('should drop tool call with non-string arguments', async () => { + it('routes non-string arguments to SchemaDriftError (envelope contract violation)', async () => { + // `function.arguments` is a non-string: this is an envelope-shape + // violation (OpenAI's contract says stringified JSON), not a within- + // envelope semantic issue. Schema validation upstream of + // validateToolCalls catches it and routes through drift/fallback, + // rather than silently dropping the tool_call. Behavior upgrade from + // PR #23 once #39 slice 2 landed. mockFetch.mockResolvedValueOnce({ ok: true, json: async () => ({ @@ -188,12 +194,14 @@ describe('Tool call validation at provider boundary', () => { headers: new Headers({ 'content-type': 'application/json' }) }); - const res = await provider.generateResponse({ + await expect(provider.generateResponse({ messages: [{ role: 'user', content: 'hi' }], model: 'gpt-4o-mini' + })).rejects.toMatchObject({ + code: 'SCHEMA_DRIFT', + provider: 'openai', + path: 'choices[0].message.tool_calls[0].function.arguments', }); - - expect(res.toolCalls).toBeUndefined(); }); it('should keep valid tool calls and drop invalid ones', async () => { diff --git a/src/providers/cerebras.ts b/src/providers/cerebras.ts index cd0337c..4f2da33 100644 --- a/src/providers/cerebras.ts +++ b/src/providers/cerebras.ts @@ -8,8 +8,48 @@ import { BaseProvider } from './base'; import { LLMErrorFactory, AuthenticationError, - ConfigurationError + ConfigurationError, + SchemaDriftError } from '../errors'; +import { validateSchema, type SchemaField } from '../utils/schema-validator'; + +// Cerebras serves the OpenAI /chat/completions contract. See groq.ts for the +// rationale on keeping each OpenAI-compat provider's schema as its own +// constant rather than a shared import. +const CEREBRAS_RESPONSE_SCHEMA: SchemaField[] = [ + { path: 'id', type: 'string' }, + { path: 'model', type: 'string' }, + { + path: 'choices', + type: 'array', + items: { + shape: [ + { path: 'message', type: 'object' }, + { path: 'message.content', type: 'string-or-null' }, + { path: 'finish_reason', type: 'string' }, + { + path: 'message.tool_calls', + type: 'array', + optional: true, + items: { + discriminator: 'type', + variants: { + function: [ + { path: 'id', type: 'string' }, + { path: 'function.name', type: 'string' }, + { path: 'function.arguments', type: 'string' }, + ], + }, + }, + }, + ], + }, + }, + { path: 'usage', type: 'object' }, + { path: 'usage.prompt_tokens', type: 'number' }, + { path: 'usage.completion_tokens', type: 'number' }, + { path: 'usage.total_tokens', type: 'number' }, +]; interface CerebrasMessage { role: 'system' | 'user' | 'assistant' | 'tool'; @@ -110,8 +150,9 @@ export class CerebrasProvider extends BaseProvider { throw await LLMErrorFactory.fromFetchResponse('cerebras', httpResponse); } - const data: CerebrasResponse = await httpResponse.json(); - return this.formatResponse(data, Date.now() - startTime); + const data = await httpResponse.json() as unknown; + validateSchema('cerebras', data, CEREBRAS_RESPONSE_SCHEMA); + return this.formatResponse(data as CerebrasResponse, Date.now() - startTime); }); this.updateMetrics(response.responseTime, true, response.usage.cost); @@ -379,7 +420,7 @@ export class CerebrasProvider extends BaseProvider { ): LLMResponse { const choice = data.choices[0]; if (!choice) { - throw new Error('No choices returned from Cerebras'); + throw new SchemaDriftError('cerebras', 'choices[0]', 'object', 'undefined'); } const content = choice.message.content || ''; @@ -394,10 +435,14 @@ export class CerebrasProvider extends BaseProvider { ) }; - // Extract tool calls if present (validated at provider boundary) + // Extract tool calls if present (validated at provider boundary). + // See groq.ts for the rationale on filtering unknown `type` variants + // before dereferencing `tc.function` — keeps forward-compat discriminator + // skips from becoming bare TypeErrors. let toolCalls: ToolCall[] | undefined; - if (choice.message.tool_calls && choice.message.tool_calls.length > 0) { - const raw: ToolCall[] = choice.message.tool_calls.map(tc => ({ + const functionCalls = choice.message.tool_calls?.filter(tc => tc.type === 'function'); + if (functionCalls && functionCalls.length > 0) { + const raw: ToolCall[] = functionCalls.map(tc => ({ id: tc.id, type: 'function' as const, function: { name: tc.function.name, arguments: tc.function.arguments } diff --git a/src/providers/groq.ts b/src/providers/groq.ts index 2cae1df..6598e2e 100644 --- a/src/providers/groq.ts +++ b/src/providers/groq.ts @@ -8,8 +8,49 @@ import { BaseProvider } from './base'; import { LLMErrorFactory, AuthenticationError, - ConfigurationError + ConfigurationError, + SchemaDriftError } from '../errors'; +import { validateSchema, type SchemaField } from '../utils/schema-validator'; + +// Groq serves the OpenAI /chat/completions contract — same envelope shape as +// OpenAI. Kept as a separate constant (not imported from openai.ts) because +// each provider's envelope is an independent API surface; shared drift would +// be a correlated outage signal, not a single bug. +const GROQ_RESPONSE_SCHEMA: SchemaField[] = [ + { path: 'id', type: 'string' }, + { path: 'model', type: 'string' }, + { + path: 'choices', + type: 'array', + items: { + shape: [ + { path: 'message', type: 'object' }, + { path: 'message.content', type: 'string-or-null' }, + { path: 'finish_reason', type: 'string' }, + { + path: 'message.tool_calls', + type: 'array', + optional: true, + items: { + discriminator: 'type', + variants: { + function: [ + { path: 'id', type: 'string' }, + { path: 'function.name', type: 'string' }, + { path: 'function.arguments', type: 'string' }, + ], + }, + }, + }, + ], + }, + }, + { path: 'usage', type: 'object' }, + { path: 'usage.prompt_tokens', type: 'number' }, + { path: 'usage.completion_tokens', type: 'number' }, + { path: 'usage.total_tokens', type: 'number' }, +]; interface GroqMessage { role: 'system' | 'user' | 'assistant' | 'tool'; @@ -110,8 +151,9 @@ export class GroqProvider extends BaseProvider { throw await LLMErrorFactory.fromFetchResponse('groq', httpResponse); } - const data: GroqResponse = await httpResponse.json(); - return this.formatResponse(data, Date.now() - startTime); + const data = await httpResponse.json() as unknown; + validateSchema('groq', data, GROQ_RESPONSE_SCHEMA); + return this.formatResponse(data as GroqResponse, Date.now() - startTime); }); this.updateMetrics(response.responseTime, true, response.usage.cost); @@ -384,7 +426,7 @@ export class GroqProvider extends BaseProvider { ): LLMResponse { const choice = data.choices[0]; if (!choice) { - throw new Error('No choices returned from Groq'); + throw new SchemaDriftError('groq', 'choices[0]', 'object', 'undefined'); } const content = choice.message.content || ''; @@ -399,10 +441,17 @@ export class GroqProvider extends BaseProvider { ) }; - // Extract tool calls if present (validated at provider boundary) + // Extract tool calls if present (validated at provider boundary). + // Filter to function-type variants before dereferencing `tc.function`: + // the schema discriminator treats unknown `type` values as forward-compat + // (skipped, not drift), so a future `code_interpreter`-shaped variant + // may arrive without the `function` field we expect. Dropping at the map + // boundary keeps unknown variants invisible rather than surfacing a bare + // TypeError that bypasses the drift/fallback machinery. let toolCalls: ToolCall[] | undefined; - if (choice.message.tool_calls && choice.message.tool_calls.length > 0) { - const raw: ToolCall[] = choice.message.tool_calls.map(tc => ({ + const functionCalls = choice.message.tool_calls?.filter(tc => tc.type === 'function'); + if (functionCalls && functionCalls.length > 0) { + const raw: ToolCall[] = functionCalls.map(tc => ({ id: tc.id, type: 'function' as const, function: { name: tc.function.name, arguments: tc.function.arguments } diff --git a/src/providers/openai.ts b/src/providers/openai.ts index ca20956..0430cba 100755 --- a/src/providers/openai.ts +++ b/src/providers/openai.ts @@ -9,8 +9,50 @@ import { LLMErrorFactory, AuthenticationError, ModelNotFoundError, - RateLimitError + RateLimitError, + SchemaDriftError } from '../errors'; +import { validateSchema, type SchemaField } from '../utils/schema-validator'; + +// Minimum envelope `formatResponse` reads. `tool_calls` uses a discriminated +// union (single `function` variant today) so an additive new tool type upstream +// is forward-compat rather than drift. Empty `choices` is surfaced as drift at +// the `choices[0]` path rather than a bare throw, so it routes through the +// fallback/hook machinery like every other envelope failure. +const OPENAI_RESPONSE_SCHEMA: SchemaField[] = [ + { path: 'id', type: 'string' }, + { path: 'model', type: 'string' }, + { + path: 'choices', + type: 'array', + items: { + shape: [ + { path: 'message', type: 'object' }, + { path: 'message.content', type: 'string-or-null' }, + { path: 'finish_reason', type: 'string' }, + { + path: 'message.tool_calls', + type: 'array', + optional: true, + items: { + discriminator: 'type', + variants: { + function: [ + { path: 'id', type: 'string' }, + { path: 'function.name', type: 'string' }, + { path: 'function.arguments', type: 'string' }, + ], + }, + }, + }, + ], + }, + }, + { path: 'usage', type: 'object' }, + { path: 'usage.prompt_tokens', type: 'number' }, + { path: 'usage.completion_tokens', type: 'number' }, + { path: 'usage.total_tokens', type: 'number' }, +]; interface OpenAIToolCall { id: string; @@ -124,8 +166,9 @@ export class OpenAIProvider extends BaseProvider { throw await LLMErrorFactory.fromFetchResponse('openai', httpResponse); } - const data: OpenAIResponse = await httpResponse.json(); - return this.formatResponse(data, Date.now() - startTime); + const data = await httpResponse.json() as unknown; + validateSchema('openai', data, OPENAI_RESPONSE_SCHEMA); + return this.formatResponse(data as OpenAIResponse, Date.now() - startTime); }); this.updateMetrics(response.responseTime, true, response.usage.cost); @@ -357,7 +400,7 @@ export class OpenAIProvider extends BaseProvider { ): LLMResponse { const choice = data.choices[0]; if (!choice) { - throw new Error('No choices returned from OpenAI'); + throw new SchemaDriftError('openai', 'choices[0]', 'object', 'undefined'); } const content = choice.message.content || '';