Commit 5abfbaa

Support multi-modal output in LLMJudge (#3696)

Clement-Lelievre and Clement authored
Co-authored-by: Clement <clement@rayon.so>
1 parent eaedf8a, commit 5abfbaa

2 files changed (+148, -70 lines)

pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py

Lines changed: 29 additions & 26 deletions
@@ -221,39 +221,42 @@ def _stringify(value: Any) -> str:
     return repr(value)
 
 
+def _make_section(content: Any, tag: str) -> list[str | UserContent]:
+    """Create a tagged section, handling different content types, for use in the LLMJudge's prompt.
+
+    Args:
+        content (Any): content to include in the section
+        tag (str): tag name for the section
+
+    Returns:
+        list[str | UserContent]: the tagged section as a list of strings or UserContent
+    """
+    sections: list[str | UserContent] = []
+    content = content if isinstance(content, Sequence) and not isinstance(content, str) else [content]
+
+    sections.append(f'<{tag}>')
+    for item in content:
+        sections.append(item if isinstance(item, str | MultiModalContent) else _stringify(item))
+    sections.append(f'</{tag}>')
+    return sections
+
+
 def _build_prompt(
     output: Any,
     rubric: str,
     inputs: Any | None = None,
     expected_output: Any | None = None,
 ) -> str | Sequence[str | UserContent]:
-    """Build a prompt that includes input, output, and rubric."""
+    """Build a prompt that includes input, output, expected output, and rubric."""
     sections: list[str | UserContent] = []
-
     if inputs is not None:
-        if isinstance(inputs, str):
-            sections.append(f'<Input>\n{inputs}\n</Input>')
-        else:
-            sections.append('<Input>\n')
-            if isinstance(inputs, Sequence):
-                for item in inputs:  # type: ignore
-                    if isinstance(item, str | MultiModalContent):
-                        sections.append(item)
-                    else:
-                        sections.append(_stringify(item))
-            elif isinstance(inputs, MultiModalContent):
-                sections.append(inputs)
-            else:
-                sections.append(_stringify(inputs))
-            sections.append('</Input>')
-
-    sections.append(f'<Output>\n{_stringify(output)}\n</Output>')
-    sections.append(f'<Rubric>\n{rubric}\n</Rubric>')
+        sections.extend(_make_section(inputs, 'Input'))
 
-    if expected_output is not None:
-        sections.append(f'<ExpectedOutput>\n{_stringify(expected_output)}\n</ExpectedOutput>')
+    sections.extend(_make_section(output, 'Output'))
+    sections.extend(_make_section(rubric, 'Rubric'))
 
-    if inputs is None or isinstance(inputs, str):
-        return '\n\n'.join(sections)  # type: ignore[arg-type]
-    else:
-        return sections
+    if expected_output is not None:
+        sections.extend(_make_section(expected_output, 'ExpectedOutput'))
+    if all(isinstance(section, str) for section in sections):
+        return '\n'.join(sections)  # type: ignore[arg-type]
+    return sections
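
For illustration, a minimal sketch of the refactored behavior, assuming pydantic_evals and pydantic_ai are installed. _build_prompt is the private helper from the diff above, so importing it directly is for demonstration only; the PNG bytes are a placeholder.

from pydantic_ai import BinaryContent
from pydantic_evals.evaluators.llm_as_a_judge import _build_prompt  # private helper, imported for demo

# All-string sections still collapse to a single prompt string.
text_prompt = _build_prompt(output='Hello world', rubric='Output contains input', inputs='Hello')
assert text_prompt == '<Input>\nHello\n</Input>\n<Output>\nHello world\n</Output>\n<Rubric>\nOutput contains input\n</Rubric>'

# Multi-modal content anywhere, now including the output, keeps the prompt
# as a list so the BinaryContent instance reaches the judge model untouched.
image = BinaryContent(data=b'\x89PNG\r\n...', media_type='image/png')  # placeholder bytes
mm_prompt = _build_prompt(output=image, rubric='Describe the image')
assert mm_prompt == ['<Output>', image, '</Output>', '<Rubric>', 'Describe the image', '</Rubric>']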

tests/evals/test_llm_as_a_judge.py

Lines changed: 119 additions & 44 deletions
@@ -166,6 +166,26 @@ async def test_judge_input_output_binary_content_list_mock(mocker: MockerFixture
     assert image_content in raw_prompt, 'Expected the exact BinaryContent instance to be in the prompt list'
 
 
+async def test_judge_binary_output_mock(mocker: MockerFixture, image_content: BinaryContent) -> None:
+    """Test judge_output function when binary content is to be judged."""
+    # Mock the agent run method
+    mock_result = mocker.MagicMock()
+    mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0)
+    mock_run = mocker.patch('pydantic_ai.agent.AbstractAgent.run', return_value=mock_result)
+
+    result = await judge_output(output=image_content, rubric='dummy rubric')
+    assert isinstance(result, GradingOutput)
+    assert result.reason == 'Test passed'
+    assert result.pass_ is True
+    assert result.score == 1.0
+
+    # Verify the agent was called with correct prompt
+    mock_run.assert_called_once()
+    call_args, *_ = mock_run.call_args
+
+    assert call_args == snapshot((['<Output>', image_content, '</Output>', '<Rubric>', 'dummy rubric', '</Rubric>'],))
+
+
 async def test_judge_input_output_binary_content_mock(mocker: MockerFixture, image_content: BinaryContent):
     """Test judge_input_output function with mocked agent."""
     # Mock the agent run method
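
An aside on the snapshot shape in the new test above: mock_run.call_args behaves like an (args, kwargs) pair, so the positional arguments form a one-element tuple wrapping the prompt list, which is why the snapshot wraps the list in a 1-tuple. A minimal standalone sketch with unittest.mock:

from unittest.mock import MagicMock

mock_run = MagicMock()
mock_run(['<Output>', '...', '</Output>'])  # called with one positional argument

# call_args unpacks into (positional args, keyword args), so the positional
# side is a 1-tuple, matching snapshot(([...],)) in the test.
args, kwargs = mock_run.call_args
assert args == (['<Output>', '...', '</Output>'],)
assert kwargs == {}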
@@ -237,10 +257,24 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_con
 
     # Verify the agent was called with correct prompt
     call_args = mock_run.call_args[0]
-    assert '<Input>\nHello\n</Input>' in call_args[0]
-    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
-    assert '<Output>\nHello world\n</Output>' in call_args[0]
-    assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+    assert call_args == snapshot(
+        (
+            """\
+<Input>
+Hello
+</Input>
+<Output>
+Hello world
+</Output>
+<Rubric>
+Output contains input
+</Rubric>
+<ExpectedOutput>
+Hello
+</ExpectedOutput>\
+""",
+        )
+    )
 
     result = await judge_input_output_expected(image_content, 'Hello world', 'Hello', 'Output contains input')
     assert isinstance(result, GradingOutput)
@@ -249,10 +283,24 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_con
     assert result.score == 1.0
 
     call_args = mock_run.call_args[0]
-    assert image_content in call_args[0]
-    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
-    assert '<Output>\nHello world\n</Output>' in call_args[0]
-    assert '<Rubric>\nOutput contains input\n</Rubric>' in call_args[0]
+    assert call_args == snapshot(
+        (
+            [
+                '<Input>',
+                image_content,
+                '</Input>',
+                '<Output>',
+                'Hello world',
+                '</Output>',
+                '<Rubric>',
+                'Output contains input',
+                '</Rubric>',
+                '<ExpectedOutput>',
+                'Hello',
+                '</ExpectedOutput>',
+            ],
+        )
+    )
 
 
 @pytest.mark.anyio
@@ -279,10 +327,24 @@ async def test_judge_input_output_expected_with_model_settings_mock(
     assert result.score == 1.0
 
     call_args, call_kwargs = mock_run.call_args
-    assert '<Input>\nHello settings\n</Input>' in call_args[0]
-    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
-    assert '<Output>\nHello world with settings\n</Output>' in call_args[0]
-    assert '<Rubric>\nOutput contains input with settings\n</Rubric>' in call_args[0]
+    assert call_args == snapshot(
+        (
+            """\
+<Input>
+Hello settings
+</Input>
+<Output>
+Hello world with settings
+</Output>
+<Rubric>
+Output contains input with settings
+</Rubric>
+<ExpectedOutput>
+Hello
+</ExpectedOutput>\
+""",
+        )
+    )
     assert call_kwargs['model_settings'] == test_model_settings
     # Check if 'model' kwarg is passed, its value will be the default model or None
     assert 'model' in call_kwargs
@@ -301,10 +363,24 @@ async def test_judge_input_output_expected_with_model_settings_mock(
     assert result.score == 1.0
 
     call_args, call_kwargs = mock_run.call_args
-    assert image_content in call_args[0]
-    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
-    assert '<Output>\nHello world with settings\n</Output>' in call_args[0]
-    assert '<Rubric>\nOutput contains input with settings\n</Rubric>' in call_args[0]
+    assert call_args == snapshot(
+        (
+            [
+                '<Input>',
+                image_content,
+                '</Input>',
+                '<Output>',
+                'Hello world with settings',
+                '</Output>',
+                '<Rubric>',
+                'Output contains input with settings',
+                '</Rubric>',
+                '<ExpectedOutput>',
+                'Hello',
+                '</ExpectedOutput>',
+            ],
+        )
+    )
     assert call_kwargs['model_settings'] == test_model_settings
     # Check if 'model' kwarg is passed, its value will be the default model or None
     assert 'model' in call_kwargs
@@ -326,26 +402,20 @@ async def test_judge_input_output_expected_with_model_settings_mock(
 
     assert call_args == snapshot(
         (
-            [
-                '<Input>\n',
-                '123',
-                '</Input>',
-                """\
+            """\
+<Input>
+123
+</Input>
 <Output>
 Hello world with settings
-</Output>\
-""",
-                """\
+</Output>
 <Rubric>
 Output contains input with settings
-</Rubric>\
-""",
-                """\
+</Rubric>
 <ExpectedOutput>
 Hello
 </ExpectedOutput>\
 """,
-            ],
         )
     )
 
@@ -366,26 +436,20 @@ async def test_judge_input_output_expected_with_model_settings_mock(
 
     assert call_args == snapshot(
         (
-            [
-                '<Input>\n',
-                '123',
-                '</Input>',
-                """\
+            """\
+<Input>
+123
+</Input>
 <Output>
 Hello world with settings
-</Output>\
-""",
-                """\
+</Output>
 <Rubric>
 Output contains input with settings
-</Rubric>\
-""",
-                """\
+</Rubric>
 <ExpectedOutput>
 Hello
 </ExpectedOutput>\
 """,
-            ],
         )
     )
 
@@ -455,10 +519,21 @@ async def test_judge_output_expected_with_model_settings_mock(mocker: MockerFixt
     assert result.score == 1.0
 
     call_args, call_kwargs = mock_run.call_args
-    assert '<Input>' not in call_args[0]
-    assert '<ExpectedOutput>\nHello\n</ExpectedOutput>' in call_args[0]
-    assert '<Output>' in call_args[0]
-    assert '<Rubric>\nOutput contains input with settings\n</Rubric>' in call_args[0]
+    assert call_args == snapshot(
+        (
+            [
+                '<Output>',
+                image_content,
+                '</Output>',
+                '<Rubric>',
+                'Output contains input with settings',
+                '</Rubric>',
+                '<ExpectedOutput>',
+                'Hello',
+                '</ExpectedOutput>',
+            ],
+        )
+    )
     assert call_kwargs['model_settings'] == test_model_settings
     # Check if 'model' kwarg is passed, its value will be the default model or None
     assert 'model' in call_kwargs
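
To close, a hedged end-to-end sketch of what this commit enables: the judged output itself can now be multi-modal, so the judge model receives the image bytes directly instead of a repr() string. judge_output, GradingOutput, and BinaryContent all appear in the diffs above; the model string and the local chart.png file are illustrative assumptions, and running this requires credentials for the chosen model.

import asyncio
from pathlib import Path

from pydantic_ai import BinaryContent
from pydantic_evals.evaluators.llm_as_a_judge import judge_output


async def main() -> None:
    # Judge a binary (image) output against a textual rubric.
    image = BinaryContent(data=Path('chart.png').read_bytes(), media_type='image/png')  # assumed local file
    grading = await judge_output(
        image,
        rubric='The image shows a bar chart with labeled axes',
        model='openai:gpt-4o',  # example model string; any pydantic-ai model works
    )
    print(grading.pass_, grading.score, grading.reason)


asyncio.run(main())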
