
Commit 7bb085c

Add generate_dataset_from_agent
1 parent 92289ce commit 7bb085c

2 files changed: +118 -9 lines changed

pydantic_evals/pydantic_evals/generation.py

Lines changed: 109 additions & 9 deletions
@@ -6,26 +6,25 @@
 
 from __future__ import annotations
 
+import json
+import re
 from collections.abc import Sequence
 from pathlib import Path
-from typing import Any
+from typing import Any, TypeVar
 
-from pydantic import ValidationError
-from typing_extensions import TypeVar
+from pydantic import BaseModel, ValidationError
 
 from pydantic_ai import Agent, models
 from pydantic_ai._utils import strip_markdown_fences
-from pydantic_evals import Dataset
+from pydantic_evals import Case, Dataset
 from pydantic_evals.evaluators.evaluator import Evaluator
 
-__all__ = ('generate_dataset',)
+__all__ = ('generate_dataset', 'generate_dataset_for_agent')
 
 InputsT = TypeVar('InputsT', default=Any)
 """Generic type for the inputs to the task being evaluated."""
-
 OutputT = TypeVar('OutputT', default=Any)
 """Generic type for the expected output of the task being evaluated."""
-
 MetadataT = TypeVar('MetadataT', default=Any)
 """Generic type for the metadata associated with the task being evaluated."""
@@ -40,7 +39,6 @@ async def generate_dataset(
     extra_instructions: str | None = None,
 ) -> Dataset[InputsT, OutputT, MetadataT]:
     """Use an LLM to generate a dataset of test cases, each consisting of input, expected output, and metadata.
-
     This function creates a properly structured dataset with the specified input, output, and metadata types.
     It uses an LLM to attempt to generate realistic test cases that conform to the types' schemas.
@@ -72,7 +70,6 @@ async def generate_dataset(
         output_type=str,
         retries=1,
     )
-
     result = await agent.run(extra_instructions or 'Please generate the object.')
     output = strip_markdown_fences(result.output)
     try:
@@ -83,3 +80,106 @@ async def generate_dataset(
     if path is not None:
         result.to_file(path, custom_evaluator_types=custom_evaluator_types)  # pragma: no cover
     return result
+
+
+async def generate_dataset_for_agent(
+    agent: Agent[Any, OutputT],
+    *,
+    inputs_type: type[InputsT] = str,  # type: ignore
+    metadata_type: type[MetadataT] | None = None,
+    path: Path | str | None = None,
+    model: models.Model | models.KnownModelName = 'openai:gpt-4o',
+    n_examples: int = 3,
+    extra_instructions: str | None = None,
+) -> Dataset[InputsT, OutputT, MetadataT]:
+    """Generate evaluation cases by running inputs through a target agent.
+
+    Generates diverse inputs and metadata using an LLM, then runs them through the agent
+    to produce realistic expected outputs for evaluation.
+
+    Args:
+        agent: Pydantic AI agent to extract outputs from.
+        inputs_type: Type of inputs the agent expects. Defaults to str.
+        metadata_type: Type for metadata. Defaults to None (uses NoneType).
+        path: Optional path to save the generated dataset.
+        model: Pydantic AI model to use for generation. Defaults to 'gpt-4o'.
+        n_examples: Number of examples to generate. Defaults to 3.
+        extra_instructions: Optional additional instructions for the LLM.
+
+    Returns:
+        A properly structured Dataset object with generated test cases.
+
+    Raises:
+        ValidationError: If the LLM's response cannot be parsed.
+    """
+    # Get output schema with proper type handling
+    # Check if it's a Pydantic model class (not an instance) before calling model_json_schema
+    output_schema: str
+    if isinstance(agent.output_type, type) and issubclass(agent.output_type, BaseModel):
+        output_schema = str(agent.output_type.model_json_schema())
+    else:
+        # For other types (str, custom output specs, etc.), just use string representation
+        output_schema = str(agent.output_type)
+
+    # Get inputs schema with proper type handling
+    inputs_schema: str
+    if issubclass(inputs_type, BaseModel):
+        inputs_schema = str(inputs_type.model_json_schema())
+    else:
+        inputs_schema = str(inputs_type)
+
+    generation_prompt = (
+        f'Generate {n_examples} test case inputs for an agent.\n\n'
+        f'The agent accepts inputs of type: {inputs_schema}\n'
+        f'The agent produces outputs of type: {output_schema}\n\n'
+        f'Return a JSON array of objects with "name" (optional string), "inputs" (matching the input type), '
+        f'and "metadata" (optional, any additional context).\n'
+        f'You must not include any characters in your response before the opening [ of the JSON array, or after the closing ].'
+        + (f'\n\n{extra_instructions}' if extra_instructions else '')
+    )
+
+    gen_agent = Agent(
+        model,
+        system_prompt=generation_prompt,
+        output_type=str,
+        retries=1,
+    )
+
+    result = await gen_agent.run('Please generate the test case inputs and metadata.')
+    output = strip_markdown_fences(result.output).strip()
+
+    try:
+        if not output:
+            raise ValueError('Empty output after stripping markdown fences')
+
+        # Additional cleanup in case strip_markdown_fences didn't catch everything
+        # Remove markdown code blocks with optional language identifier
+        output = re.sub(r'^```(?:json)?\s*\n?', '', output)
+        output = re.sub(r'\n?```\s*$', '', output)
+        output = output.strip()
+
+        inputs_metadata = json.loads(output)
+        cases = []
+
+        for i, item in enumerate(inputs_metadata):
+            agent_result = await agent.run(item['inputs'])
+            cases.append(
+                Case(
+                    name=item.get('name', f'case-{i}'),
+                    inputs=item['inputs'],
+                    expected_output=agent_result.output,
+                    metadata=item.get('metadata'),
+                )
+            )
+
+        result_dataset = Dataset(cases=cases, evaluators=[])
+
+    except (json.JSONDecodeError, KeyError, ValidationError) as e:  # pragma: no cover
+        print(f'Raw response from model:\n{result.output}\n')
+        print(f'After stripping markdown fences:\n{output}')
+        raise e
+
+    if path is not None:
+        result_dataset.to_file(path)
+
+    return result_dataset

tests/evals/test_dataset.py

Lines changed: 9 additions & 0 deletions
@@ -1470,6 +1470,15 @@ def test_import_generate_dataset():
     assert generate_dataset
 
 
+def test_import_generate_dataset_from_agent():
+    # this function is tough to test in an interesting way outside an example...
+    # this at least ensures importing it doesn't fail.
+    # TODO: Add an "example" that actually makes use of this functionality
+    from pydantic_evals.generation import generate_dataset_for_agent
+
+    assert generate_dataset_for_agent
+
+
 def test_evaluate_non_serializable_inputs():
     @dataclass
     class MyInputs:
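A rough sketch of a test that would address the TODO above, assuming pydantic_ai's FunctionModel and TestModel plus an anyio-based pytest setup (the pytest.mark.anyio marker is an assumption, not something this commit adds), could stub out both the generation model and the target agent:

import pytest

from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_evals.generation import generate_dataset_for_agent


@pytest.mark.anyio
async def test_generate_dataset_for_agent_with_stub_models():
    def generation_model(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Stub generation model: always return a fixed JSON array of two case inputs.
        return ModelResponse(
            parts=[TextPart(content='[{"name": "greeting", "inputs": "hello"}, {"inputs": "goodbye"}]')]
        )

    target_agent = Agent('test', output_type=str)  # 'test' resolves to TestModel, which returns canned text

    dataset = await generate_dataset_for_agent(
        target_agent,
        model=FunctionModel(generation_model),
        n_examples=2,
    )

    assert len(dataset.cases) == 2
    assert dataset.cases[0].name == 'greeting'
    assert dataset.cases[1].name == 'case-1'  # falls back to the generated default name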
