Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions src/__tests__/schema-drift.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import { describe, it, expect, beforeEach, vi } from 'vitest';
import { validateSchema, type SchemaField } from '../utils/schema-validator';
import { SchemaDriftError } from '../errors';
import { AnthropicProvider } from '../providers/anthropic';
import { OpenAIProvider } from '../providers/openai';
import { GroqProvider } from '../providers/groq';
import { CerebrasProvider } from '../providers/cerebras';
import type { BaseProvider } from '../providers/base';
import { LLMProviderFactory } from '../factory';
import { defaultCircuitBreakerManager } from '../utils/circuit-breaker';
import { defaultExhaustionRegistry } from '../utils/exhaustion';
Expand Down Expand Up @@ -669,3 +673,196 @@ describe('Anthropic nested content-block validation (H-2 / #42)', () => {
expect(res.content).toBe('hi');
});
});

// ── OpenAI-compat provider schema validation ────────────────────────────
//
// OpenAI, Groq, and Cerebras all serve the /chat/completions envelope.
// Driven through describe.each so drift-parity is enforced by construction —
// if one provider's schema diverges, its tests break loudly.

/**
 * One row of the OpenAI-compatible provider test matrix consumed by the
 * `describe.each` suite below.
 */
interface OpenAICompatCase {
// Case label: interpolated into the suite title via `$name` and compared
// against the `provider` field of the thrown drift error.
name: string;
// Builds the provider under test; invoked in `beforeEach`, so every test
// gets a new instance.
factory: () => BaseProvider;
// Model id sent in the request and echoed in the mocked response body.
model: string;
}

// Providers exercised by the shared OpenAI-compat schema-drift suite.
// Each entry supplies a factory rather than an instance so the suite can
// construct a fresh provider per test.
// NOTE(review): `apiKey: 'test-key'` is a placeholder (fetch is mocked in the
// suite); `maxRetries: 0` presumably keeps a single queued mock response from
// being retried past — confirm against the provider config semantics.
const openAiCompatCases: OpenAICompatCase[] = [
{
name: 'openai',
factory: () => new OpenAIProvider({ apiKey: 'test-key', maxRetries: 0 }),
model: 'gpt-4o-mini',
},
{
name: 'groq',
factory: () => new GroqProvider({ apiKey: 'test-key', maxRetries: 0 }),
model: 'llama-3.1-8b-instant',
},
{
name: 'cerebras',
factory: () => new CerebrasProvider({ apiKey: 'test-key', maxRetries: 0 }),
model: 'llama-3.1-8b',
},
];

/**
 * Shared response-schema suite, run once per entry in `openAiCompatCases`.
 *
 * Every test queues exactly one mocked HTTP response on `mockFetch` (the
 * fetch mock declared outside this block) and then issues the same minimal
 * chat request, so the only variable between tests is the response body.
 */
describe.each(openAiCompatCases)('$name response schema validation', ({ name, factory, model }) => {
  let provider: BaseProvider;

  beforeEach(() => {
    vi.clearAllMocks();
    defaultCircuitBreakerManager.resetAll();
    provider = factory();
  });

  // Baseline well-formed /chat/completions body; tests spread and override it.
  const validResponse = {
    id: 'chatcmpl_1',
    object: 'chat.completion',
    created: 1700000000,
    model,
    choices: [{
      index: 0,
      message: { role: 'assistant', content: 'hello' },
      finish_reason: 'stop',
    }],
    usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 },
  };

  /** Queues one successful JSON HTTP response carrying `body` on the fetch mock. */
  const mockJsonResponse = (body: unknown): void => {
    mockFetch.mockResolvedValueOnce({
      ok: true,
      json: async () => body,
      headers: new Headers({ 'content-type': 'application/json' }),
    });
  };

  /**
   * Sends the minimal single-message request used by every test.
   * Reads `provider` at call time, so it always targets the instance built
   * by the most recent `beforeEach`.
   */
  const generate = () =>
    provider.generateResponse({
      messages: [{ role: 'user', content: 'hi' }],
      model,
    });

  it('passes through a well-formed response', async () => {
    mockJsonResponse(validResponse);

    const res = await generate();

    expect(res.content).toBe('hello');
    expect(res.usage.inputTokens).toBe(10);
  });

  it('throws SchemaDriftError when usage.prompt_tokens is renamed', async () => {
    mockJsonResponse({
      ...validResponse,
      usage: { input_tokens: 10, completion_tokens: 5, total_tokens: 15 },
    });

    await expect(generate()).rejects.toMatchObject({
      code: 'SCHEMA_DRIFT',
      provider: name,
      path: 'usage.prompt_tokens',
    });
  });

  it('throws SchemaDriftError when choices field is removed', async () => {
    const { choices: _choices, ...rest } = validResponse;
    mockJsonResponse(rest);

    // `provider` is asserted here as in the sibling drift tests so that a
    // mis-attributed error cannot pass unnoticed for this case.
    await expect(generate()).rejects.toMatchObject({
      code: 'SCHEMA_DRIFT',
      provider: name,
      path: 'choices',
    });
  });

  it('throws SchemaDriftError when choices is empty (routes through drift, not bare throw)', async () => {
    mockJsonResponse({ ...validResponse, choices: [] });

    await expect(generate()).rejects.toMatchObject({
      code: 'SCHEMA_DRIFT',
      provider: name,
      path: 'choices[0]',
    });
  });

  it('throws SchemaDriftError when tool_call function.arguments is not a string', async () => {
    mockJsonResponse({
      ...validResponse,
      choices: [{
        index: 0,
        message: {
          role: 'assistant',
          content: null,
          tool_calls: [{
            id: 'call_1',
            type: 'function',
            // Already-parsed object instead of the expected JSON string.
            function: { name: 'my_tool', arguments: { already: 'parsed' } },
          }],
        },
        finish_reason: 'tool_calls',
      }],
    });

    await expect(generate()).rejects.toMatchObject({
      code: 'SCHEMA_DRIFT',
      provider: name,
      path: 'choices[0].message.tool_calls[0].function.arguments',
      expected: 'string',
      actual: 'object',
    });
  });

  it('accepts unknown tool_call type without surfacing it as a function call (forward-compat)', async () => {
    // Schema's discriminator skips unknown `type` values (forward-compat for
    // additive upstream changes). The provider's formatResponse must not
    // dereference the function-shaped payload on a skipped variant, or an
    // unknown shape becomes a bare TypeError (bypassing drift/fallback) or
    // gets mis-surfaced as a normal function call. Mock omits `function`
    // entirely to exercise the TypeError path specifically.
    mockJsonResponse({
      ...validResponse,
      choices: [{
        index: 0,
        message: {
          role: 'assistant',
          content: 'hi',
          tool_calls: [{
            id: 'call_1',
            type: 'code_interpreter', // hypothetical future tool type
            // intentionally no `function` field — unknown variants may have a
            // different shape upstream
          }],
        },
        finish_reason: 'stop',
      }],
    });

    const res = await generate();

    expect(res.content).toBe('hi');
    // Critical: the unknown variant must be dropped, not mis-surfaced as a
    // function call and not crashed through.
    expect(res.toolCalls).toBeUndefined();
  });
});
16 changes: 12 additions & 4 deletions src/__tests__/tool-call-validation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,13 @@ describe('Tool call validation at provider boundary', () => {
expect(res.toolCalls).toBeUndefined();
});

it('should drop tool call with non-string arguments', async () => {
it('routes non-string arguments to SchemaDriftError (envelope contract violation)', async () => {
// `function.arguments` is a non-string: this is an envelope-shape
// violation (OpenAI's contract says stringified JSON), not a within-
// envelope semantic issue. Schema validation upstream of
// validateToolCalls catches it and routes through drift/fallback,
// rather than silently dropping the tool_call. Behavior upgrade from
// PR #23 once #39 slice 2 landed.
mockFetch.mockResolvedValueOnce({
ok: true,
json: async () => ({
Expand All @@ -188,12 +194,14 @@ describe('Tool call validation at provider boundary', () => {
headers: new Headers({ 'content-type': 'application/json' })
});

const res = await provider.generateResponse({
await expect(provider.generateResponse({
messages: [{ role: 'user', content: 'hi' }],
model: 'gpt-4o-mini'
})).rejects.toMatchObject({
code: 'SCHEMA_DRIFT',
provider: 'openai',
path: 'choices[0].message.tool_calls[0].function.arguments',
});

expect(res.toolCalls).toBeUndefined();
});

it('should keep valid tool calls and drop invalid ones', async () => {
Expand Down
59 changes: 52 additions & 7 deletions src/providers/cerebras.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,48 @@ import { BaseProvider } from './base';
import {
LLMErrorFactory,
AuthenticationError,
ConfigurationError
ConfigurationError,
SchemaDriftError
} from '../errors';
import { validateSchema, type SchemaField } from '../utils/schema-validator';

// Cerebras serves the OpenAI /chat/completions contract. See groq.ts for the
// rationale on keeping each OpenAI-compat provider's schema as its own
// constant rather than a shared import.
// Field-level schema for a non-streaming Cerebras chat-completion response,
// passed to `validateSchema('cerebras', ...)` before the body is formatted.
// NOTE(review): paths inside `items.shape` / `variants` appear to be relative
// to each array element (error paths elsewhere read like
// `choices[0].message.tool_calls[0].function.arguments`) — confirm against
// the schema-validator implementation.
const CEREBRAS_RESPONSE_SCHEMA: SchemaField[] = [
{ path: 'id', type: 'string' },
{ path: 'model', type: 'string' },
{
path: 'choices',
type: 'array',
items: {
shape: [
{ path: 'message', type: 'object' },
// `content` may be null (e.g. on tool-call-only messages).
{ path: 'message.content', type: 'string-or-null' },
{ path: 'finish_reason', type: 'string' },
{
// `tool_calls` is optional; when present each entry is checked by its
// `type` tag.
path: 'message.tool_calls',
type: 'array',
optional: true,
items: {
discriminator: 'type',
variants: {
// Only the `function` variant is described here.
// NOTE(review): entries with other `type` values are presumably
// skipped by the validator rather than rejected — confirm.
function: [
{ path: 'id', type: 'string' },
{ path: 'function.name', type: 'string' },
// Arguments must arrive as a string, not a pre-parsed object.
{ path: 'function.arguments', type: 'string' },
],
},
},
},
],
},
},
// Token accounting: all three counters are required numbers.
{ path: 'usage', type: 'object' },
{ path: 'usage.prompt_tokens', type: 'number' },
{ path: 'usage.completion_tokens', type: 'number' },
{ path: 'usage.total_tokens', type: 'number' },
];

interface CerebrasMessage {
role: 'system' | 'user' | 'assistant' | 'tool';
Expand Down Expand Up @@ -110,8 +150,9 @@ export class CerebrasProvider extends BaseProvider {
throw await LLMErrorFactory.fromFetchResponse('cerebras', httpResponse);
}

const data: CerebrasResponse = await httpResponse.json();
return this.formatResponse(data, Date.now() - startTime);
const data = await httpResponse.json() as unknown;
validateSchema('cerebras', data, CEREBRAS_RESPONSE_SCHEMA);
return this.formatResponse(data as CerebrasResponse, Date.now() - startTime);
});

this.updateMetrics(response.responseTime, true, response.usage.cost);
Expand Down Expand Up @@ -379,7 +420,7 @@ export class CerebrasProvider extends BaseProvider {
): LLMResponse {
const choice = data.choices[0];
if (!choice) {
throw new Error('No choices returned from Cerebras');
throw new SchemaDriftError('cerebras', 'choices[0]', 'object', 'undefined');
}

const content = choice.message.content || '';
Expand All @@ -394,10 +435,14 @@ export class CerebrasProvider extends BaseProvider {
)
};

// Extract tool calls if present (validated at provider boundary)
// Extract tool calls if present (validated at provider boundary).
// See groq.ts for the rationale on filtering unknown `type` variants
// before dereferencing `tc.function` — keeps forward-compat discriminator
// skips from becoming bare TypeErrors.
let toolCalls: ToolCall[] | undefined;
if (choice.message.tool_calls && choice.message.tool_calls.length > 0) {
const raw: ToolCall[] = choice.message.tool_calls.map(tc => ({
const functionCalls = choice.message.tool_calls?.filter(tc => tc.type === 'function');
if (functionCalls && functionCalls.length > 0) {
const raw: ToolCall[] = functionCalls.map(tc => ({
id: tc.id,
type: 'function' as const,
function: { name: tc.function.name, arguments: tc.function.arguments }
Expand Down
Loading
Loading