diff --git a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md index 624e580ad1..da4b4b8157 100644 --- a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md +++ b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md @@ -443,3 +443,301 @@ gepa = dspy.GEPA( auto="medium" ) ``` + +## Tool Description Optimization + +### What is optimize_tool_descriptions? + +The `optimize_tool_descriptions` parameter enables GEPA to optimize tool descriptions in addition to signature instructions. This is particularly valuable for ReAct agents and other tool-using systems, where the quality of tool descriptions directly impacts the agent's ability to select appropriate tools for each task. + +Unlike signature instructions that guide reasoning strategies, tool descriptions serve a different purpose: they help agents decide **which tool to use** in a given situation. GEPA applies a specialized reflection prompt tailored for tool selection decisions. + +### Tool-Specific Reflection Prompt + +GEPA uses a dedicated prompt for optimizing tool descriptions. The prompt receives the complete ReAct trajectory (all thoughts, actions, observations) from executions that used the tool being optimized: + +```python +class GenerateImprovedToolDescriptionFromFeedback(dspy.Signature): + """You are refining a tool description that the assistant currently uses. + + Review the current description along with examples of the assistant's tool decisions + and the feedback those decisions received. + + Read them together and refine the description. + So the agent understands when this tool actually helps, what argument or result matters, + and what misuse the feedback exposed. Keep the tool's voice and only change what the + evidence justifies. + + Return a refined description that helps the assistant quickly recognize good + opportunities for the tool.""" + + current_tool_description = dspy.InputField(desc="The current description of the tool") + examples_with_feedback = dspy.InputField( + desc="Examples showing tool usage decisions and feedback on correctness" + ) + + improved_tool_description = dspy.OutputField( + desc="An improved description that guides correct tool selection and usage" + ) +``` + +The `examples_with_feedback` contains full ReAct trajectories showing the complete context in which each tool was selected and used, enabling the reflection LM to understand tool selection patterns. + +**Example: Writing Tool-Aware Metrics** + +To provide effective feedback for tool optimization, write metrics that examine the trajectory: + +```python +def tool_feedback_metric(example, prediction, trace=None, pred_name=None, pred_trace=None): + """Metric that provides tool-specific feedback for GEPA optimization.""" + correct = prediction.answer == example.answer + score = 1.0 if correct else 0.0 + + # Generate tool-specific feedback if available + if hasattr(prediction, 'trajectory'): + tools_used = [ + prediction.trajectory[key] + for key in prediction.trajectory + if key.startswith('tool_name_') and prediction.trajectory[key] != 'finish' + ] + feedback = f"{'Correct' if correct else 'Wrong'}. Tools: {', '.join(tools_used)}" + else: + feedback = "Correct" if correct else "Wrong" + + return dspy.Prediction(score=score, feedback=feedback) +``` + +This produces feedback like: +``` +[Tool 'calculator' from 'agent'] Correct. Tools: calculator +[Tool 'search' from 'agent'] Wrong. 
Tools: search, calculator +``` + +The tool-specific prefix `[Tool 'calculator' from 'agent']` is automatically added by GEPA to focus the reflection LM on optimizing that particular tool's description. + +**Note:** Tool descriptions are treated as components in GEPA's optimization process. The `component_selector` parameter applies to both signature instructions and tool descriptions. For example, `component_selector="all"` optimizes all signatures and tools together, while `component_selector="round_robin"` cycles through them one at a time. + +### Default Behavior + +By default, GEPA only optimizes signature instructions (`optimize_tool_descriptions=False`): + +```python +# Default behavior: only signature optimization +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key), + # optimize_tool_descriptions=False # This is the default + auto="medium" +) +optimized_program = gepa.compile(student, trainset=examples) +``` + +### How It Works + +When enabled, GEPA: + +1. **Discovers all tools**: Traverses your program including nested sub-modules to find all `dspy.Tool` instances +2. **Categorizes components**: Separates tools (identified by `tool:` prefix) from signature instructions +3. **Routes components appropriately**: + - Signature instructions → Default or custom instruction proposer + - Tool descriptions → ToolProposer (receives ReAct's reflective data with tool-specific annotation) +4. **Optimizes holistically**: Treats tool descriptions as first-class components in the optimization process + +### Implementation Details + +**Reflective Dataset Construction:** + +GEPA constructs the reflective dataset for tool optimization in two passes: + +**Pass 1: Build reflective examples for predictors (used by instruction proposer)** + +For each predictor (including ReAct modules), GEPA creates reflective examples containing: +- **Inputs**: The predictor's input fields (e.g., `{"question": "..."}`) +- **Generated Outputs**: ALL of the predictor's output fields converted to strings + - For ReAct: This includes both `answer` AND `trajectory` fields + - The trajectory contains the complete execution trace with all thoughts, actions, and observations +- **Feedback**: Text feedback returned by your metric function + +These examples are used by the instruction proposer to optimize signature instructions. + +**Pass 2: Copy reflective examples to tools with annotation (used by tool proposer)** + +For each tool being optimized, GEPA: +- Identifies ALL ReAct predictors (across all nested modules) that have this tool in their toolset +- Takes ALL reflective examples from those predictors and makes a deep copy for the tool +- Annotates the feedback: `[Tool 'tool_name' from 'predictor_key'] {original_feedback}` +- If multiple ReAct modules use the same tool, their reflective examples are aggregated together + +These annotated examples are used by the tool proposer (with the tool-specific reflection prompt shown above) to optimize tool descriptions. + +This means: +- A tool receives the FULL ReAct trajectory (thoughts, actions, observations) in the "Generated Outputs" field +- The metric can optionally examine the trajectory and include tool-specific insights in the feedback text +- The reflection LM sees complete context about how and when the tool was used + +**Component Identification & Proposer Routing:** + +GEPA discovers tools by traversing ReAct modules and extracting their associated `dspy.Tool` instances. 
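+Concretely, discovery amounts to walking `named_sub_modules()` and registering each tool's `desc` under a `tool:`-prefixed key, roughly as in the `gepa.py` change later in this diff:
+
+```python
+tool_descriptions = {}
+for _, module in student.named_sub_modules():
+    if hasattr(module, "tools"):  # e.g. dspy.ReAct exposes its tools as a dict here
+        for tool_name, tool in module.tools.items():
+            tool_key = f"tool:{tool_name}"  # prefix distinguishes tools from signature components
+            if tool_key not in tool_descriptions:
+                tool_descriptions[tool_key] = tool.desc
+
+# tool descriptions join the signature instructions as optimizable components
+base_program.update(tool_descriptions)
+```
+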
Once identified, GEPA routes components to appropriate proposers: +- **Signature instructions** → Custom instruction proposer (if provided) OR default GEPA proposer +- **Tool descriptions** → Built-in `ToolProposer` (always used, not customizable) + +The custom instruction proposer affects ONLY signature instructions. Tools always use the specialized `ToolProposer` with the tool-specific reflection prompt, regardless of whether you provide a custom instruction proposer. + +### When to Use optimize_tool_descriptions + +Enable `optimize_tool_descriptions=True` when you use `dspy.Tool` in your program and need better tool selection. Here are common scenarios: + +1. **ReAct agents with multiple tools** - Agent with `search` and `calculator` tools keeps searching when it should calculate, or vice versa. GEPA learns from execution feedback to clarify "use search for factual queries, calculator for numerical analysis." + +2. **Multi-agent systems with delegation** - Parent agent has delegation tools to specialized sub-agents but doesn't understand when to use each. GEPA optimizes both delegation tools and sub-agent internal tools holistically. + +3. **Sequential tool workflows** - Tools like `query_database` → `analyze_results` have dependencies but descriptions don't capture this. GEPA learns the sequence and timing from successful executions. + +4. **Domain-specific tools** - Tools like legal vs. medical document search have overlapping but domain-specific purposes. GEPA discovers usage patterns and adds context: "for legal precedents" vs. "for patient records." + +5. **Tools with limitations** - Initial description "Does calculations" is too vague. GEPA adds specificity from observed usage: "Use for arithmetic (+, -, *, /, **). Not for date math or string operations." + +See the usage examples below for implementations of scenarios 1 and 2. + +### Usage Examples + +#### Basic ReAct Agent + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +# Create ReAct agent with tools (poor initial descriptions) +search_tool = dspy.Tool(search_web, name="search", desc="Finds things") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Does calculations") + +agent = dspy.ReAct("question -> answer", tools=[search_tool, calc_tool]) + +# Enable tool optimization +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_tool_descriptions=True, + component_selector="all", # Optimize all components together + auto="medium" +) + +optimized_agent = gepa.compile(agent, trainset=train_examples, valset=val_examples) + +# View optimized tool descriptions +print("Optimized search tool:", optimized_agent.tools["search"].desc) +print("Optimized calculator tool:", optimized_agent.tools["calculator"].desc) +``` + +**Example output after optimization:** +``` +Optimized search tool: Use when you need to find current information, facts, or data + from external sources. Provide specific search queries to get relevant results. + +Optimized calculator tool: Use for arithmetic operations and mathematical expressions. + Accepts Python-compatible expressions with numbers and operators (+, -, *, /, **). + Do not use for date calculations or string manipulations. 
+``` + +#### Multi-Agent System + +GEPA automatically discovers and optimizes tools in nested agents: + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +search_tool = dspy.Tool(search_web, name="search", desc="Searches") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Computes") + +class ResearchAssistant(dspy.Module): + def __init__(self): + super().__init__() + self.researcher = dspy.ReAct("query -> findings", tools=[search_tool]) + + def delegate_research(query: str) -> str: + return self.researcher(query=query).findings + + research_tool = dspy.Tool(delegate_research, name="research", desc="Helps with questions") + self.assistant = dspy.ReAct("question -> answer", tools=[research_tool, calc_tool]) + + def forward(self, question): + return self.assistant(question=question) + +# Optimizes ALL tools: calculator, research, search +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_tool_descriptions=True, + component_selector="all", + auto="medium" +) + +optimized_system = gepa.compile(ResearchAssistant(), trainset=train, valset=val) + +# View optimized nested tool descriptions +print(optimized_system.researcher.tools["search"].desc) +print(optimized_system.assistant.tools["research"].desc) +print(optimized_system.assistant.tools["calculator"].desc) +``` + +### Inspecting Optimized Tool Descriptions + +After optimization, tool descriptions are automatically updated in your program. Access them directly through your module structure: + +```python +optimized_agent = gepa.compile(agent, trainset=train, valset=val) + +# Access tools directly - descriptions are already updated +print(optimized_agent.tools["search"].desc) +print(optimized_agent.tools["calculator"].desc) +``` + +For multi-agent systems, access nested tools through your module hierarchy: + +```python +optimized_system = gepa.compile(ResearchAssistant(), trainset=train, valset=val) + +# Access tools at different levels +print(optimized_system.researcher.tools["search"].desc) # Sub-agent tool +print(optimized_system.assistant.tools["research"].desc) # Main agent tool +print(optimized_system.assistant.tools["calculator"].desc) +``` + +### Compatibility with Custom Instruction Proposers + +Tool optimization works seamlessly with custom instruction proposers. When you provide a custom instruction proposer AND enable `optimize_tool_descriptions=True`: + +**Component routing:** +- **Signature instructions** → Your custom instruction proposer +- **Tool descriptions** → Built-in `ToolProposer` with specialized tool reflection prompt + +**Key points:** +- Both operate independently during the same GEPA run +- Tools receive domain-appropriate optimization guidance (tool selection patterns, usage context) +- Signatures use your custom logic (task-specific reasoning, formatting, etc.) +- The built-in tool proposer is not customizable - it always uses `GenerateImprovedToolDescriptionFromFeedback` + +This separation ensures tools and signatures get appropriate optimization strategies without interference. 
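+
+For example, pairing a custom multimodal proposer for signatures with the built-in tool proposer: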
+ +```python +from dspy.teleprompt.gepa.instruction_proposal import MultiModalInstructionProposer + +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5", temperature=1.0, max_tokens=32000, api_key=api_key), + instruction_proposer=MultiModalInstructionProposer(), # For signatures + optimize_tool_descriptions=True, # Enables ToolProposer for tools + auto="medium" +) +``` diff --git a/docs/docs/api/optimizers/GEPA/overview.md b/docs/docs/api/optimizers/GEPA/overview.md index 0125702bea..b830ee3985 100644 --- a/docs/docs/api/optimizers/GEPA/overview.md +++ b/docs/docs/api/optimizers/GEPA/overview.md @@ -117,6 +117,12 @@ Practical Recipe for GEPA-Friendly Feedback: - **Multi-Objective Tasks** (e.g., PUPA): Decompose aggregate scores to reveal contributions from each objective, highlighting tradeoffs (e.g., quality vs. privacy). - **Stacked Pipelines** (e.g., code generation: parse → compile → run → profile → evaluate): Expose stage-specific failures; natural-language traces often suffice for LLM self-correction. +## Tool Description Optimization + +GEPA can optimize tool descriptions for ReAct agents. When `optimize_tool_descriptions=True`, GEPA discovers all tools in your program (including nested multi-agent systems) and applies a specialized reflection prompt to improve how tools are described. This helps agents make better tool selection decisions by learning from execution traces which tools work well in which contexts. + +For details on how tool optimization works, when to use it, and usage examples, see [Tool Description Optimization](GEPA_Advanced.md#tool-description-optimization) in the Advanced Features guide. + ## Custom Instruction Proposal For advanced customization of GEPA's instruction proposal mechanism, including custom instruction proposers and component selectors, see [Advanced Features](GEPA_Advanced.md). diff --git a/dspy/teleprompt/gepa/gepa.py b/dspy/teleprompt/gepa/gepa.py index 87cbbf80a5..e4c4d40862 100644 --- a/dspy/teleprompt/gepa/gepa.py +++ b/dspy/teleprompt/gepa/gepa.py @@ -273,6 +273,11 @@ def metric( warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when called with and without the pred_name. This flag (defaults to True) determines whether a warning is raised if a mismatch in module-level and predictor-level score is detected. + optimize_tool_descriptions: Whether to optimize tool descriptions for modules with tools + (e.g., ReAct agents). When enabled, tool descriptions are included in the optimization + process alongside signature instructions. See the + [Tool Description Optimization guide](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#tool-description-optimization) + for details on when to use this feature and how it works. Default is False. seed: The random seed to use for reproducibility. Default is 0. 
gepa_kwargs: (Optional) provide additional kwargs to be passed to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method @@ -328,6 +333,7 @@ def __init__( wandb_init_kwargs: dict[str, Any] | None = None, track_best_outputs: bool = False, warn_on_score_mismatch: bool = True, + optimize_tool_descriptions: bool = False, use_mlflow: bool = False, # Reproducibility seed: int | None = 0, @@ -390,6 +396,7 @@ def __init__( self.wandb_api_key = wandb_api_key self.wandb_init_kwargs = wandb_init_kwargs self.warn_on_score_mismatch = warn_on_score_mismatch + self.optimize_tool_descriptions = optimize_tool_descriptions self.use_mlflow = use_mlflow if track_best_outputs: @@ -518,11 +525,25 @@ def feedback_fn( rng=rng, reflection_lm=self.reflection_lm, custom_instruction_proposer=self.custom_instruction_proposer, - warn_on_score_mismatch=self.warn_on_score_mismatch + warn_on_score_mismatch=self.warn_on_score_mismatch, + optimize_tool_descriptions=self.optimize_tool_descriptions ) # Instantiate GEPA with the simpler adapter-based API base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} + + if self.optimize_tool_descriptions: + tool_descriptions = {} + for _, module in student.named_sub_modules(): + if hasattr(module, "tools"): + for tool_name, tool in module.tools.items(): + tool_key = f"tool:{tool_name}" + if tool_key not in tool_descriptions: + tool_descriptions[tool_key] = tool.desc + if tool_descriptions: + logger.info(f"Including {len(tool_descriptions)} tool descriptions for optimization") + base_program.update(tool_descriptions) + gepa_result: GEPAResult = optimize( seed_candidate=base_program, trainset=trainset, diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py index 844afe8b00..6e6eaf4a7d 100644 --- a/dspy/teleprompt/gepa/gepa_utils.py +++ b/dspy/teleprompt/gepa/gepa_utils.py @@ -1,5 +1,7 @@ import logging import random +from collections import defaultdict +from copy import deepcopy from typing import Any, Callable, Protocol, TypedDict from gepa import EvaluationBatch, GEPAAdapter @@ -15,6 +17,7 @@ logger = logging.getLogger(__name__) + class LoggerAdapter: def __init__(self, logger: logging.Logger): self.logger = logger @@ -22,6 +25,7 @@ def __init__(self, logger: logging.Logger): def log(self, x: str): self.logger.info(x) + DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]] @@ -31,15 +35,17 @@ class ReflectiveExample(TypedDict): Each example contains the predictor inputs, generated outputs, and feedback from evaluation. """ - Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) - Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string - Feedback: str # Always a string - from metric function or parsing error message + + Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) + Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string + Feedback: str # Always a string - from metric function or parsing error message class ScoreWithFeedback(Prediction): score: float feedback: str + class PredictorFeedbackFn(Protocol): def __call__( predictor_output: dict[str, Any], @@ -64,6 +70,7 @@ def __call__( """ ... 
+ class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]): def __init__( self, @@ -76,7 +83,8 @@ def __init__( rng: random.Random | None = None, reflection_lm=None, custom_instruction_proposer: "ProposalFn | None" = None, - warn_on_score_mismatch: bool = True + warn_on_score_mismatch: bool = True, + optimize_tool_descriptions: bool = False, ): self.student = student_module self.metric_fn = metric_fn @@ -88,42 +96,118 @@ def __init__( self.reflection_lm = reflection_lm self.custom_instruction_proposer = custom_instruction_proposer self.warn_on_score_mismatch = warn_on_score_mismatch - - if self.custom_instruction_proposer is not None: - # We are only overriding the propose_new_texts method when a custom - # instruction proposer is provided. Otherwise, we use the GEPA - # default propose_new_texts. - - def custom_propose_new_texts( + self.optimize_tool_descriptions = optimize_tool_descriptions + + def build_propose_new_texts(): + instruction_proposer = None + + # Init Signature Proposer if custom proposer is provided. + # Otherwise, use GEPA default proposer. + if self.custom_instruction_proposer is not None: + instruction_proposer = self.custom_instruction_proposer + else: + from gepa.strategies.instruction_proposal import InstructionProposalSignature + + def default_signature_proposer( + candidate: dict[str, str], + reflective_dataset: dict[str, list[dict[str, Any]]], + components_to_update: list[str], + ) -> dict[str, str]: + lm = self.reflection_lm if self.reflection_lm is not None else dspy.settings.lm + sig_texts: dict[str, str] = {} + for name in components_to_update: + base_instruction = candidate[name] + dataset_with_feedback = reflective_dataset[name] + sig_texts[name] = InstructionProposalSignature.run( + lm=(lambda x: lm(x)[0]), + input_dict={ + "current_instruction_doc": base_instruction, + "dataset_with_feedback": dataset_with_feedback, + }, + )["new_instruction"] + return sig_texts + + instruction_proposer = default_signature_proposer + + # Init Tool Proposer if tool optimization is enabled. + tool_proposer = None + if self.optimize_tool_descriptions is not None: + from .instruction_proposal import ToolProposer + + tool_proposer = ToolProposer() + + def propose_component_texts( candidate: dict[str, str], reflective_dataset: dict[str, list[dict[str, Any]]], - components_to_update: list[str] + components_to_update: list[str], ) -> dict[str, str]: + tool_components = [c for c in components_to_update if c.startswith("tool:")] + instruction_components = [c for c in components_to_update if not c.startswith("tool:")] + + results: dict[str, str] = {} + + # Handle signature components. if self.reflection_lm is not None: with dspy.context(lm=self.reflection_lm): - return self.custom_instruction_proposer( + results.update( + instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=instruction_components, + ) + ) + else: + results.update( + instruction_proposer( candidate=candidate, reflective_dataset=reflective_dataset, - components_to_update=components_to_update + components_to_update=instruction_components, ) - else: - return self.custom_instruction_proposer( - candidate=candidate, - reflective_dataset=reflective_dataset, - components_to_update=components_to_update ) - self.propose_new_texts = custom_propose_new_texts + # Handle tool if tool proposer is provided. 
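+                # Tool components ("tool:" keys) never reach the instruction proposer;
+                # they always go through the built-in ToolProposer, inside the
+                # reflection LM context when one is configured.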
+ if tool_proposer is not None: + if self.reflection_lm is not None: + with dspy.context(lm=self.reflection_lm): + results.update( + tool_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=tool_components, + ) + ) + else: + results.update( + tool_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=tool_components, + ) + ) + + return results + + return propose_component_texts + + self.propose_new_texts = build_propose_new_texts() # Cache predictor names/signatures self.named_predictors = list(self.student.named_predictors()) - def build_program(self, candidate: dict[str, str]): new_prog = self.student.deepcopy() for name, pred in new_prog.named_predictors(): if name in candidate: pred.signature = pred.signature.with_instructions(candidate[name]) + + if self.optimize_tool_descriptions: + for _, module in new_prog.named_sub_modules(): + if hasattr(module, "tools"): + for tool_name, tool in module.tools.items(): + tool_key = f"tool:{tool_name}" + if tool_key in candidate: + tool.desc = candidate[tool_key] + return new_prog def evaluate(self, batch, candidate, capture_traces=False): @@ -165,7 +249,7 @@ def evaluate(self, batch, candidate, capture_traces=False): return_all_scores=True, failure_score=self.failure_score, provide_traceback=True, - max_errors=len(batch) * 100 + max_errors=len(batch) * 100, ) res = evaluator(program) outputs = [r[1] for r in res.results] @@ -173,12 +257,20 @@ def evaluate(self, batch, candidate, capture_traces=False): scores = [s["score"] if hasattr(s, "score") else s for s in scores] return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) - def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: + def make_reflective_dataset( + self, candidate, eval_batch, components_to_update + ) -> dict[str, list[ReflectiveExample]]: from dspy.teleprompt.bootstrap_trace import FailedPrediction + program = self.build_program(candidate) ret_d: dict[str, list[ReflectiveExample]] = {} + + # First pass: Process all non-tool components (predictors) for pred_name in components_to_update: + if pred_name.startswith("tool:"): + continue # Skip tools in first pass (tools are processed in the second pass) + module = None for name, m in program.named_predictors(): if name == pred_name: @@ -273,7 +365,6 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - d["Feedback"] = fb["feedback"] if fb["score"] != module_score: if self.warn_on_score_mismatch: - logger.warning("The score returned by the metric with pred_name is different from the overall metric score. This can indicate 2 things: Either the metric is non-deterministic (e.g., LLM-as-judge, Semantic score, etc.) or the metric returned a score specific to pred_name that differs from the module level score. Currently, GEPA does not support predictor level scoring (support coming soon), and only requires a feedback text to be provided, which can be specific to the predictor or program level. GEPA will ignore the differing score returned, and instead use module level score. You can safely ignore this warning if using a semantic metric, however, if this mismatch is caused due to predictor scoring, please return module-level scores. 
To disable this warning, set warn_on_score_mismatch=False.") self.warn_on_score_mismatch = False fb["score"] = module_score @@ -284,11 +375,125 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - continue ret_d[pred_name] = items + # Add tool examples to the reflective dataset + tool_examples = defaultdict(list) + + if self.optimize_tool_descriptions: + # Design Decision: Full ReAct Trajectory Sharing for Tools + # + # Each tool receives the COMPLETE ReAct trajectory (all thoughts, actions, observations) + # rather than only the segments where that tool was used. This trades token efficiency + # for richer optimization context. + # + # Rationale: + # 1. Tools are interdependent: search results inform calculator usage, API responses + # guide follow-up queries. Full trajectory shows these dependencies. + # 2. Reflection LM needs context to understand tool SELECTION patterns: + # - Why did the agent choose this tool over alternatives? + # - When in the reasoning process is this tool most useful? + # - What prior information typically triggers this tool's usage? + # 3. Goal is descriptions that guide "when to use" not just "what it does" + # + # Trade-offs: + # - Cost: N tools = N copies of same trajectory (5 tools = 5x duplication) + # - Benefit: Descriptions capture tool's role in multi-step workflows + # Example: "Use after search when numerical analysis is needed" vs "Does math" + # + for module_path, sub_module in program.named_sub_modules(): + # Walk each sub-module to locate its tools and remember the predictor scope + # so we can share those reflections with the tool descriptions below + tools = getattr(sub_module, "tools", None) + if not tools: + continue + + prefix = module_path.removeprefix("self.") if module_path != "self" else "" + + tool_entries = list(tools.items()) + + for child_name, _ in sub_module.named_predictors(): + predictor_key = child_name if not prefix else f"{prefix}.{child_name}" + reflections = ret_d.get(predictor_key) + if not reflections: + continue + + # Share the FULL ReAct trajectory with each tool + for tool_name, _ in tool_entries: + tool_key = f"tool:{tool_name}" + for item in reflections: + annotated = deepcopy(item) + annotated["Feedback"] = f"[Tool '{tool_name}' from '{predictor_key}'] {item['Feedback']}" + tool_examples[tool_key].append(annotated) + + # Merge tool examples into main dataset (shared tools get examples from all predictors) + ret_d.update(tool_examples) + if len(ret_d) == 0: raise Exception("No valid predictions found for any module.") return ret_d + # Future Work: Joint Tool Optimization with ReAct for Token Efficiency + # =========================================================== + # Current approach duplicates the same trajectory N times for N tools in a ReAct module. + # For multi-tool agents, we could optimize all tools simultaneously to reduce token usage. + # + # Assumption: + # - ReAct module is the only module that uses the tools + # - When optimizing tool descriptions of ReAct, reflection LM would capture general pattern of tools and ReAct's decision making process + # - It's probably better to holistically optimize all tools and ReAct together + + # Proposed Architecture (Exact details may change): + # 1. 
During reflective dataset construction, group tools by their parent ReAct module: + # - Walk program.named_sub_modules() to find ReAct predictors + # - Extract tools from each ReAct module via getattr(module, "tools", None) + # - Build mapping: {module_path: [tool_name1, tool_name2, ...]} + # - Detect when a module has multiple tools + # + # 2. For multi-tool ReAct modules, choose architectural approach: + # + # Option A: Separate tool-specific proposer signature + # - Create custom signature extending GenerateImprovedToolDescriptionFromFeedback + # - Use dspy.Signature.append_field() to add one output field per tool + # - Example: For 3 tools, add fields "improved_search_desc", "improved_calc_desc", "improved_api_desc" + # - Pro: Clean separation between instruction and tool optimization + # - Con: Separate LM call from ReAct instruction optimization + # + # Option B: Extend ReAct instruction proposer directly + # - Append tool description fields to existing ReAct instruction proposer + # - Update proposer instructions/docstring to include tool optimization guidance + # - Use dspy.Signature's helper functions to add output fields for each tool + # - Aggregate all tools' input/output fields expected to be updated from that ReAct module + # - Pro: Single LM call optimizes ReAct instructions AND tool descriptions together + # - Pro: Reflection LM sees relationship between instructions and tools holistically + # - Con: More complex signature modification, harder to maintain separation of concerns + # + # 3. Pass the ReAct trajectory ONCE to generate all tool descriptions and ReAct instruction simultaneously: + # - Single LM call with multi-field output instead of N separate calls + # - Proposer prompt instructs LM to consider tool interactions + # + # 4. Parse the multi-field output and update each tool's description: + # - Extract each field from the prediction + # - Map back to tool names using the grouping from step 1 + # - Handle parsing errors with fallback to current one-at-a-time approach + # + # Benefits: + # - Eliminates trajectory duplication: 1x token cost instead of Nx + # - Reflection LM sees all tools holistically, can coordinate descriptions + # - Tool descriptions can complement each other ("use search before calculator") + # - Scales better for agents with 10+ tools + # + # Challenges: + # - Signature modification at runtime may require careful field naming/parsing + # - More output fields → higher chance of LM parsing errors (but user will likely to use powerful LMs for ReAct + tools prompts optimization) + # - Need robust fallback when multi-field output fails (DSPy natively implemented fallback logic for this?) + # - Requires refactoring GEPA's "one component at a time" architecture (but we can treat ReAct + tools as "one component") + # + # Implementation Notes (Ignore if it's too overengineering): + # - Start with simple case: all tools from one ReAct module + # - Add retry logic for malformed multi-field outputs + # - Consider hybrid approach: joint optimization for <5 tools, separate for more + # - May need different proposer prompt template for joint vs. individual optimization + # TODO: The current DSPyAdapter implementation uses the GEPA default propose_new_texts. # We can potentially override this, to use the instruction proposal similar to MIPROv2. 
diff --git a/dspy/teleprompt/gepa/instruction_proposal.py b/dspy/teleprompt/gepa/instruction_proposal.py index 23810b9a02..5429cbf23f 100644 --- a/dspy/teleprompt/gepa/instruction_proposal.py +++ b/dspy/teleprompt/gepa/instruction_proposal.py @@ -310,3 +310,99 @@ def __call__( updated_components[component_name] = new_instruction return updated_components + + +class GenerateImprovedToolDescriptionFromFeedback(dspy.Signature): + """You are refining a tool description that the assistant currently uses. + + Review the current description along with examples of the assistant's tool decisions and the feedback those decisions received. + + Read them together and refine the description. + So the agent understands when this tool actually helps, what argument or result matters, and what misuse the feedback exposed. + Keep the tool's voice and only change what the evidence justifies. + + Return a refined description that helps the assistant quickly recognize good opportunities for the tool.""" + + current_tool_description = dspy.InputField(desc="The current description of the tool") + examples_with_feedback = dspy.InputField(desc="Examples showing tool usage decisions and feedback on correctness") + + improved_tool_description = dspy.OutputField( + desc="An improved description that guides correct tool selection and usage" + ) + + +class ToolProposer(ProposalFn): + """GEPA-compatible tool description proposer. + + Formats reflective examples into structured markdown and calls + `GenerateImprovedToolDescriptionFromFeedback` to produce updated descriptions + for each tool that requires refinement. + """ + + def __init__(self): + self.propose_description = dspy.Predict(GenerateImprovedToolDescriptionFromFeedback) + + def __call__( + self, + candidate: dict[str, str], + reflective_dataset: dict[str, list[ReflectiveExample]], + components_to_update: list[str], + ) -> dict[str, str]: + """GEPA-compatible proposal function. 
+ + Args: + candidate: Current component name -> description mapping + reflective_dataset: Component name -> list of reflective examples + components_to_update: List of component names to update + + Returns: + dict: Component name -> new description mapping + """ + updated_components = {} + + for component_name in components_to_update: + if component_name in candidate and component_name in reflective_dataset: + current_description = candidate[component_name] + component_reflective_data = reflective_dataset[component_name] + + formatted_examples = self._format_examples(component_reflective_data) + result = self.propose_description( + current_tool_description=current_description, + examples_with_feedback=formatted_examples, + ) + + updated_components[component_name] = result.improved_tool_description + + return updated_components + + def _format_examples(self, reflective_dataset: list[ReflectiveExample]) -> str: + """Format reflective examples using GEPA's markdown structure.""" + + def render_value(value, level=3): + if isinstance(value, dict): + s = "" + for key, val in value.items(): + s += f"{'#' * level} {key}\n" + s += render_value(val, min(level + 1, 6)) + if not value: + s += "\n" + return s + if isinstance(value, (list, tuple)): + s = "" + for index, item in enumerate(value): + s += f"{'#' * level} Item {index + 1}\n" + s += render_value(item, min(level + 1, 6)) + if not value: + s += "\n" + return s + return f"{str(value).strip()}\n\n" + + def convert_sample_to_markdown(sample, example_num): + s = f"# Example {example_num}\n" + for key, val in sample.items(): + s += f"## {key}\n" + s += render_value(val, level=3) + return s + + formatted_parts = [convert_sample_to_markdown(example, i + 1) for i, example in enumerate(reflective_dataset)] + return "\n\n".join(formatted_parts) diff --git a/tests/teleprompt/test_gepa_tool_optimization.py b/tests/teleprompt/test_gepa_tool_optimization.py new file mode 100644 index 0000000000..d82d0faa80 --- /dev/null +++ b/tests/teleprompt/test_gepa_tool_optimization.py @@ -0,0 +1,318 @@ +from types import SimpleNamespace + +import dspy +from dspy import Example +from dspy.teleprompt.gepa import gepa_utils +from dspy.utils.dummies import DummyLM + + +def calculator(expression: str) -> str: + try: + return str(eval(expression)) + except Exception: + return "Error" + + +def search(query: str) -> str: + return f"Results for: {query}" + + +def simple_metric(example, prediction, trace=None, pred_name=None, pred_trace=None): + pred_str = str(prediction.answer).strip() + expected = str(example.answer).strip() + score = 1.0 if pred_str == expected else 0.0 + return dspy.Prediction(score=score, feedback="Correct" if score == 1.0 else "Wrong") + + +def make_example(question: str, answer: str) -> Example: + return Example(question=question, answer=answer).with_inputs("question") + + +def make_reflection_entry(question: str, answer: str, feedback: str, score: float = 1.0) -> dict: + return { + "Inputs": {"question": question}, + "Generated Outputs": {"answer": answer}, + "Feedback": f"Score: {score}.\n{feedback}", + } + + +def make_react_module(tool_specs, *, max_iters: int = 3): + class SimpleReact(dspy.Module): + def __init__(self): + super().__init__() + tools = [dspy.Tool(fn, name=name, desc=desc) for name, desc, fn in tool_specs] + self.agent = dspy.ReAct( + "question -> answer", + tools=tools, + max_iters=max_iters, + ) + + def forward(self, question: str): + return self.agent(question=question) + + return SimpleReact() + + +def 
make_nested_react_module(main_tool_specs, *, nested_tool_specs, max_iters: int = 3): + class NestedReact(dspy.Module): + def __init__(self): + super().__init__() + nested_tools = [dspy.Tool(fn, name=name, desc=desc) for name, desc, fn in nested_tool_specs] + self.subagent = dspy.ReAct( + "task -> result", + tools=nested_tools, + max_iters=max_iters, + ) + + def spawn_subagent(task: str) -> str: + return self.subagent(task=task).result + + spawn_tool = dspy.Tool(spawn_subagent, name="spawn_subagent", desc="Spawns helper agent.") + main_tools = [dspy.Tool(fn, name=name, desc=desc) for name, desc, fn in main_tool_specs] + self.agent = dspy.ReAct( + "question -> answer", + tools=[spawn_tool, *main_tools], + max_iters=max_iters, + ) + + def forward(self, question: str): + return self.agent(question=question) + + return NestedReact() + + +def build_adapter_for_program( + program, + *, + custom_instruction_proposer=None, + reflection_lm=None, + optimize_tool_descriptions: bool = True, +): + predictor_names = sorted(name for name, _ in program.named_predictors()) + if not predictor_names: + raise ValueError("program must expose at least one predictor") + + def metric_fn(example, prediction, trace=None, pred_name=None, pred_trace=None): + return dspy.Prediction(score=1.0, feedback="ok") + + feedback_map = {} + for name in predictor_names: + feedback_map[name] = lambda *args, _name=name, **kwargs: dspy.Prediction( + score=1.0, feedback=f"{_name}-fb" + ) + + adapter = gepa_utils.DspyAdapter( + student_module=program, + metric_fn=metric_fn, + feedback_map=feedback_map, + failure_score=0.0, + reflection_lm=reflection_lm, + custom_instruction_proposer=custom_instruction_proposer, + optimize_tool_descriptions=optimize_tool_descriptions, + ) + + return adapter, predictor_names + + +def stub_optimize(monkeypatch, *, new_descs, captured_seed): + def fake_optimize(*, seed_candidate, **kwargs): + captured_seed.update(seed_candidate) + best_candidate = dict(seed_candidate) + for tool_name, desc in new_descs.items(): + best_candidate[f"tool:{tool_name}"] = desc + return SimpleNamespace(best_candidate=best_candidate) + + monkeypatch.setattr("gepa.optimize", fake_optimize) + + +def test_gepa_updates_nested_agent_tools(monkeypatch): + program = make_nested_react_module( + main_tool_specs=[("calculator", "Does math", calculator)], + nested_tool_specs=[("search", "Searches", search)], + max_iters=1, + ) + + original_descs = { + "calculator": program.agent.tools["calculator"].desc, + "spawn_subagent": program.agent.tools["spawn_subagent"].desc, + "search": program.subagent.tools["search"].desc, + } + + new_descs = { + "calculator": "Clarify how to perform arithmetic precisely.", + "spawn_subagent": "Explain when to spawn a helper agent.", + "search": "Improve how search guidance is presented.", + } + + captured_seed: dict[str, str] = {} + dspy.settings.configure(lm=DummyLM([{"q": "question", "a": "answer"}])) + reflection_lm = DummyLM([{"improved_instruction": "unused"}]) + + stub_optimize(monkeypatch, new_descs=new_descs, captured_seed=captured_seed) + optimizer = dspy.GEPA( + metric=simple_metric, + reflection_lm=reflection_lm, + max_metric_calls=3, + optimize_tool_descriptions=True, + ) + trainset = [ + make_example("What is 2 + 2?", "4"), + ] + optimized = optimizer.compile(program, trainset=trainset) + + for tool_name, original in original_descs.items(): + assert captured_seed[f"tool:{tool_name}"] == original + + assert optimized.agent.tools["calculator"].desc == new_descs["calculator"] + assert 
optimized.agent.tools["spawn_subagent"].desc == new_descs["spawn_subagent"] + assert optimized.subagent.tools["search"].desc == new_descs["search"] + + +def test_reflective_dataset_shares_feedback_across_shared_tools(): + shared_tool = dspy.Tool(calculator, name="calculator", desc="Original calculator guidance") + + class DualReact(dspy.Module): + def __init__(self): + super().__init__() + self.agent_a = dspy.ReAct("question -> answer", tools=[shared_tool], max_iters=1) + self.agent_b = dspy.ReAct("question -> answer", tools=[shared_tool], max_iters=1) + + def forward(self, question: str): + return dspy.Prediction(answer="unused") + + program = DualReact() + adapter, predictor_names = build_adapter_for_program( + program, + reflection_lm=DummyLM([{"improved_instruction": "Better"}]), + ) + + candidate = {} + for name in predictor_names: + candidate[name] = f"{name}-instruction" + candidate["tool:calculator"] = shared_tool.desc + + program = adapter.build_program(candidate) + predictor_lookup = {name: pred for name, pred in program.named_predictors()} + + trajectories: list[dict] = [] + for index, name in enumerate(predictor_names): + predictor = predictor_lookup[name] + trace_entry = ( + predictor, + {"question": f"Request {index + 1}"}, + dspy.Prediction(answer=f"Response {index + 1}"), + ) + trajectories.append( + { + "trace": [trace_entry], + "example": make_example( + f"Request {index + 1}", + f"Response {index + 1}", + ), + "prediction": dspy.Prediction(answer=f"Response {index + 1}"), + "score": 1.0, + } + ) + + eval_batch = SimpleNamespace(outputs=[], scores=[], trajectories=trajectories) + components_to_update = [*predictor_names, "tool:calculator"] + + reflective_dataset = adapter.make_reflective_dataset(candidate, eval_batch, components_to_update) + + for name in predictor_names: + assert name in reflective_dataset + assert "tool:calculator" in reflective_dataset + assert len(reflective_dataset["tool:calculator"]) == len(predictor_names) + + feedback_texts = [item["Feedback"] for item in reflective_dataset["tool:calculator"]] + for name in predictor_names: + assert any(name in feedback for feedback in feedback_texts) + + +def test_dspy_adapter_uses_custom_instruction_and_tool_proposers(monkeypatch): + program = make_react_module([("toolA", "Original tool desc", lambda arg: arg)]) + + tool_calls: list[tuple[dict, list[str]]] = [] + + class MockToolProposer: + def __call__(self, *, candidate, reflective_dataset, components_to_update): + tool_calls.append((dict(candidate), list(components_to_update))) + return {component: f"tool-new-{component}" for component in components_to_update} + + monkeypatch.setattr( + "dspy.teleprompt.gepa.instruction_proposal.ToolProposer", + MockToolProposer, + ) + + class MockInstructionProposer: + def __init__(self): + self.calls: list[list[str]] = [] + + def __call__(self, *, candidate, reflective_dataset, components_to_update): + self.calls.append(list(components_to_update)) + return {name: f"instr-new-{name}" for name in components_to_update} + + instruction_proposer = MockInstructionProposer() + + adapter, predictor_names = build_adapter_for_program( + program, + custom_instruction_proposer=instruction_proposer, + reflection_lm=DummyLM([{"improved_instruction": "Better"}]), + ) + + predictor_name = predictor_names[0] + tool_key = "tool:toolA" + candidate = { + predictor_name: "Base instruction", + tool_key: program.agent.tools["toolA"].desc, + } + reflective_dataset = { + predictor_name: [ + make_reflection_entry( + "When should I ask for help?", + 
"Use toolA when delegation unblocks progress.", + "Clarify the decision boundary.", + ) + ], + tool_key: [ + make_reflection_entry( + "When should I ask for help?", + "Use toolA when delegation unblocks progress.", + "Highlight the tool's specialty.", + ) + ], + } + + updated = adapter.propose_new_texts(candidate, reflective_dataset, [predictor_name, tool_key]) + + assert instruction_proposer.calls == [[predictor_name]] + assert tool_calls == [(candidate, [tool_key])] + assert updated[predictor_name] == f"instr-new-{predictor_name}" + assert updated[tool_key] == f"tool-new-{tool_key}" + + +def test_gepa_overwrites_single_react_tool_description(monkeypatch): + program = make_react_module([("calculator", "Does math", calculator)], max_iters=1) + original_desc = program.agent.tools["calculator"].desc + + new_descs = {"calculator": "Clarify how to perform arithmetic precisely."} + captured_seed: dict[str, str] = {} + + dspy.settings.configure(lm=DummyLM([{"q": "question", "a": "answer"}])) + reflection_lm = DummyLM([{"improved_instruction": "unused"}]) + + stub_optimize(monkeypatch, new_descs=new_descs, captured_seed=captured_seed) + optimizer = dspy.GEPA( + metric=simple_metric, + reflection_lm=reflection_lm, + max_metric_calls=3, + optimize_tool_descriptions=True, + ) + trainset = [ + make_example("Compute 3 + 5.", "8"), + ] + optimized = optimizer.compile(program, trainset=trainset) + + assert captured_seed["tool:calculator"] == original_desc + assert optimized.agent.tools["calculator"].desc == new_descs["calculator"] + assert optimized.agent.tools["calculator"].desc != original_desc