diff --git a/docs/mcp.md b/docs/mcp.md index 4ee7b5781..481b674a5 100644 --- a/docs/mcp.md +++ b/docs/mcp.md @@ -179,6 +179,34 @@ The constructor accepts additional options: - `max_retry_attempts` and `retry_backoff_seconds_base` add automatic retries for `list_tools()` and `call_tool()`. - `tool_filter` lets you expose only a subset of tools (see [Tool filtering](#tool-filtering)). +### Handling MCP tool errors + +Agents convert each MCP tool into an SDK `FunctionTool`. If a remote tool raises an exception the default behaviour is to bubble +the error up and abort the run. When you want to return a string error message to the model instead, set +`failure_error_function` in the agent's `mcp_config`. The callback receives the tool run context and the raised exception (with +the original exception attached as `__cause__`) and should return the string that the model sees. + +```python +from typing import Any + +from agents import Agent, RunContextWrapper + + +def format_mcp_error(ctx: RunContextWrapper[Any], error: Exception) -> str: + # Unwrap the root cause to keep the message concise. + root_error = error.__cause__ if getattr(error, "__cause__", None) else error + return f"The MCP tool failed: {root_error}" + + +agent = Agent( + name="Assistant", + mcp_servers=[server], + mcp_config={"failure_error_function": format_mcp_error}, +) +``` + +If you prefer to let the exception terminate the run, omit `failure_error_function` or set it to `None`. + ## 3. HTTP with SSE MCP servers If the MCP server implements the HTTP with SSE transport, instantiate diff --git a/src/agents/agent.py b/src/agents/agent.py index a061926b1..c79370367 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -25,7 +25,13 @@ from .models.interface import Model from .prompts import DynamicPromptFunction, Prompt, PromptUtil from .run_context import RunContextWrapper, TContext -from .tool import FunctionTool, FunctionToolResult, Tool, function_tool +from .tool import ( + FunctionTool, + FunctionToolResult, + Tool, + ToolErrorFunction, + function_tool, +) from .util import _transforms from .util._types import MaybeAwaitable @@ -72,6 +78,13 @@ class MCPConfig(TypedDict): best-effort conversion, so some schemas may not be convertible. Defaults to False. """ + failure_error_function: NotRequired[ToolErrorFunction | None] + """Optional function used to generate an error response when an MCP tool invocation fails. + + If provided, exceptions raised while calling a tool will be converted to the string returned by + this function instead of propagating to the agent. + """ + @dataclass class AgentBase(Generic[TContext]): @@ -104,8 +117,13 @@ class AgentBase(Generic[TContext]): async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: """Fetches the available tools from the MCP servers.""" convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) + failure_error_function = self.mcp_config.get("failure_error_function") return await MCPUtil.get_all_function_tools( - self.mcp_servers, convert_schemas_to_strict, run_context, self + self.mcp_servers, + convert_schemas_to_strict, + run_context, + self, + failure_error_function=failure_error_function, ) async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]: diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 6cfe5c96d..b842e7cc1 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -1,4 +1,4 @@ -import functools +import inspect import json from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union @@ -11,8 +11,9 @@ from ..logger import logger from ..run_context import RunContextWrapper from ..strict_schema import ensure_strict_json_schema -from ..tool import FunctionTool, Tool -from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span +from ..tool import FunctionTool, Tool, ToolErrorFunction +from ..tracing import FunctionSpanData, SpanError, get_current_span, mcp_tools_span +from ..util import _error_tracing from ..util._types import MaybeAwaitable if TYPE_CHECKING: @@ -116,13 +117,18 @@ async def get_all_function_tools( convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], agent: "AgentBase", + failure_error_function: ToolErrorFunction | None = None, ) -> list[Tool]: """Get all function tools from a list of MCP servers.""" tools = [] tool_names: set[str] = set() for server in servers: server_tools = await cls.get_function_tools( - server, convert_schemas_to_strict, run_context, agent + server, + convert_schemas_to_strict, + run_context, + agent, + failure_error_function=failure_error_function, ) server_tool_names = {tool.name for tool in server_tools} if len(server_tool_names & tool_names) > 0: @@ -142,6 +148,7 @@ async def get_function_tools( convert_schemas_to_strict: bool, run_context: RunContextWrapper[Any], agent: "AgentBase", + failure_error_function: ToolErrorFunction | None = None, ) -> list[Tool]: """Get all function tools from a single MCP server.""" @@ -149,14 +156,46 @@ async def get_function_tools( tools = await server.list_tools(run_context, agent) span.span_data.result = [tool.name for tool in tools] - return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools] + return [ + cls.to_function_tool( + tool, + server, + convert_schemas_to_strict, + failure_error_function=failure_error_function, + ) + for tool in tools + ] @classmethod def to_function_tool( - cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool + cls, + tool: "MCPTool", + server: "MCPServer", + convert_schemas_to_strict: bool, + failure_error_function: ToolErrorFunction | None = None, ) -> FunctionTool: """Convert an MCP tool to an Agents SDK function tool.""" - invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool) + async def invoke_func(context: RunContextWrapper[Any], input_json: str) -> str: + try: + return await cls.invoke_mcp_tool(server, tool, context, input_json) + except Exception as exc: + if failure_error_function is None: + raise + + result = failure_error_function(context, exc) + if inspect.isawaitable(result): + result = await result + + _error_tracing.attach_error_to_current_span( + SpanError( + message="Error running MCP tool (non-fatal)", + data={ + "tool_name": tool.name, + "error": str(exc), + }, + ) + ) + return result schema, is_strict = tool.inputSchema, False # MCP spec doesn't require the inputSchema to have `properties`, but OpenAI spec does. diff --git a/tests/mcp/test_mcp_util.py b/tests/mcp/test_mcp_util.py index e434f7542..3f6f69f02 100644 --- a/tests/mcp/test_mcp_util.py +++ b/tests/mcp/test_mcp_util.py @@ -9,6 +9,7 @@ from agents import Agent, FunctionTool, RunContextWrapper from agents.exceptions import AgentsException, ModelBehaviorError from agents.mcp import MCPServer, MCPUtil +from agents.tool_context import ToolContext from .helpers import FakeMCPServer @@ -130,6 +131,34 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur assert "Error invoking MCP tool test_tool_1" in caplog.text +@pytest.mark.asyncio +async def test_mcp_failure_error_function_handles_exceptions(): + """Agent-level failure error functions should convert MCP invocation errors to strings.""" + + server = CrashingFakeMCPServer() + server.add_tool("test_tool_1", {}) + + def failure_handler(ctx: RunContextWrapper[Any], error: Exception) -> str: + root = error.__cause__ if getattr(error, "__cause__", None) else error + return f"custom error: {root}" + + agent = Agent( + name="test_agent", + mcp_servers=[server], + mcp_config={"failure_error_function": failure_handler}, + ) + run_context = RunContextWrapper(context=None) + tools = await agent.get_mcp_tools(run_context) + + tool = next(tool for tool in tools if tool.name == "test_tool_1") + tool_context = ToolContext( + context=None, tool_name=tool.name, tool_call_id="1", tool_arguments="{}" + ) + + result = await tool.on_invoke_tool(tool_context, "{}") + assert result == "custom error: Crash!" + + @pytest.mark.asyncio async def test_agent_convert_schemas_true(): """Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.