
 from __future__ import annotations

+import json
+import re
 from collections.abc import Sequence
 from pathlib import Path
 from typing import Any

-from pydantic import ValidationError
-from typing_extensions import TypeVar
+from pydantic import BaseModel, ValidationError
+from typing_extensions import TypeVar  # TypeVar(default=...) requires typing_extensions on Python < 3.13

 from pydantic_ai import Agent, models
 from pydantic_ai._utils import strip_markdown_fences
-from pydantic_evals import Dataset
+from pydantic_evals import Case, Dataset
 from pydantic_evals.evaluators.evaluator import Evaluator

-__all__ = ('generate_dataset',)
+__all__ = ('generate_dataset', 'generate_dataset_for_agent')

 InputsT = TypeVar('InputsT', default=Any)
 """Generic type for the inputs to the task being evaluated."""
-
 OutputT = TypeVar('OutputT', default=Any)
 """Generic type for the expected output of the task being evaluated."""
-
 MetadataT = TypeVar('MetadataT', default=Any)
 """Generic type for the metadata associated with the task being evaluated."""

@@ -40,7 +39,6 @@ async def generate_dataset(
     extra_instructions: str | None = None,
 ) -> Dataset[InputsT, OutputT, MetadataT]:
     """Use an LLM to generate a dataset of test cases, each consisting of input, expected output, and metadata.
-
     This function creates a properly structured dataset with the specified input, output, and metadata types.
     It uses an LLM to attempt to generate realistic test cases that conform to the types' schemas.

@@ -72,7 +70,6 @@ async def generate_dataset(
         output_type=str,
         retries=1,
     )
-
     result = await agent.run(extra_instructions or 'Please generate the object.')
     output = strip_markdown_fences(result.output)
     try:
@@ -83,3 +80,106 @@ async def generate_dataset(
     if path is not None:
         result.to_file(path, custom_evaluator_types=custom_evaluator_types)  # pragma: no cover
     return result
+
+
+async def generate_dataset_for_agent(
+    agent: Agent[Any, OutputT],
+    *,
+    inputs_type: type[InputsT] = str,  # type: ignore
+    metadata_type: type[MetadataT] | None = None,
+    path: Path | str | None = None,
+    model: models.Model | models.KnownModelName = 'openai:gpt-4o',
+    n_examples: int = 3,
+    extra_instructions: str | None = None,
+) -> Dataset[InputsT, OutputT, MetadataT]:
+    """Generate evaluation cases by running inputs through a target agent.
+
+    Generates diverse inputs and metadata using an LLM, then runs them through the agent
+    to produce realistic expected outputs for evaluation.
+
+    Args:
+        agent: Pydantic AI agent to run the generated inputs through.
+        inputs_type: Type of inputs the agent expects. Defaults to `str`.
+        metadata_type: Type for case metadata. Defaults to None.
+        path: Optional path to save the generated dataset to.
+        model: Pydantic AI model to use for generation. Defaults to 'openai:gpt-4o'.
+        n_examples: Number of examples to generate. Defaults to 3.
+        extra_instructions: Optional additional instructions for the LLM.
+
+    Returns:
+        A properly structured Dataset object with generated test cases.
+
+    Raises:
+        ValidationError: If the generated cases fail validation.
+        json.JSONDecodeError: If the LLM's response cannot be parsed as JSON.
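+
+    Example (an illustrative sketch; assumes an agent with plain-string inputs and outputs):
+
+        agent = Agent('openai:gpt-4o', output_type=str)
+        dataset = await generate_dataset_for_agent(
+            agent,
+            n_examples=5,
+            path='test_cases.yaml',
+        )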
+    """
+    # Get the output schema, handling both Pydantic models and plain types.
+    # Check for a Pydantic model class (not an instance) before calling model_json_schema.
+    output_schema: str
+    if isinstance(agent.output_type, type) and issubclass(agent.output_type, BaseModel):
+        output_schema = str(agent.output_type.model_json_schema())
+    else:
+        # For other types (str, custom output specs, etc.), just use the string representation.
+        output_schema = str(agent.output_type)
+
+    # Get the inputs schema with the same handling; the isinstance guard matters because
+    # generic aliases such as list[str] are not classes and would make issubclass raise.
+    inputs_schema: str
+    if isinstance(inputs_type, type) and issubclass(inputs_type, BaseModel):
+        inputs_schema = str(inputs_type.model_json_schema())
+    else:
+        inputs_schema = str(inputs_type)
+
+    generation_prompt = (
+        f'Generate {n_examples} test case inputs for an agent.\n\n'
+        f'The agent accepts inputs of type: {inputs_schema}\n'
+        f'The agent produces outputs of type: {output_schema}\n\n'
+        'Return a JSON array of objects with "name" (optional string), "inputs" (matching the input type), '
+        'and "metadata" (optional, any additional context).\n'
+        'You must not include any characters in your response before the opening [ of the JSON array, or after the closing ].'
+        + (f'\n\n{extra_instructions}' if extra_instructions else '')
+    )
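+    # The model is expected to reply with a bare JSON array, e.g. (illustrative):
+    # [
+    #   {"name": "simple-question", "inputs": "What is 2+2?", "metadata": {"difficulty": "easy"}},
+    #   ...
+    # ]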
+
+    gen_agent = Agent(
+        model,
+        system_prompt=generation_prompt,
+        output_type=str,
+        retries=1,
+    )
+
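+    # The generator uses output_type=str rather than structured output, so the raw text
+    # can be fence-stripped and JSON-parsed manually below.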
+    result = await gen_agent.run('Please generate the test case inputs and metadata.')
+    output = strip_markdown_fences(result.output).strip()
+
+    try:
+        if not output:
+            raise ValueError('Empty output after stripping markdown fences')
+
+        # Additional cleanup in case strip_markdown_fences didn't catch everything:
+        # remove markdown code fences with an optional language identifier.
+        output = re.sub(r'^```(?:json)?\s*\n?', '', output)
+        output = re.sub(r'\n?```\s*$', '', output)
+        output = output.strip()
+
+        inputs_metadata = json.loads(output)
+        cases = []
+
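+        # Run each generated input through the target agent, one at a time;
+        # the agent's actual output becomes the case's expected_output.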
+        for i, item in enumerate(inputs_metadata):
+            agent_result = await agent.run(item['inputs'])
+            cases.append(
+                Case(
+                    name=item.get('name', f'case-{i}'),
+                    inputs=item['inputs'],
+                    expected_output=agent_result.output,
+                    metadata=item.get('metadata'),
+                )
+            )
+
+        result_dataset = Dataset(cases=cases, evaluators=[])
+
+    except (json.JSONDecodeError, KeyError, ValidationError):  # pragma: no cover
+        print(f'Raw response from model:\n{result.output}\n')
+        print(f'After stripping markdown fences:\n{output}')
+        raise
+
+    if path is not None:
+        result_dataset.to_file(path)
+
+    return result_dataset