Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/mcp.md
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,34 @@ The constructor accepts additional options:
- `max_retry_attempts` and `retry_backoff_seconds_base` add automatic retries for `list_tools()` and `call_tool()`.
- `tool_filter` lets you expose only a subset of tools (see [Tool filtering](#tool-filtering)).

### Handling MCP tool errors

Agents convert each MCP tool into an SDK `FunctionTool`. If a remote tool raises an exception the default behaviour is to bubble
the error up and abort the run. When you want to return a string error message to the model instead, set
`failure_error_function` in the agent's `mcp_config`. The callback receives the tool run context and the raised exception (with
the original exception attached as `__cause__`) and should return the string that the model sees.

```python
from typing import Any

from agents import Agent, RunContextWrapper


def format_mcp_error(ctx: RunContextWrapper[Any], error: Exception) -> str:
# Unwrap the root cause to keep the message concise.
root_error = error.__cause__ if getattr(error, "__cause__", None) else error
return f"The MCP tool failed: {root_error}"


agent = Agent(
name="Assistant",
mcp_servers=[server],
mcp_config={"failure_error_function": format_mcp_error},
)
```

If you prefer to let the exception terminate the run, omit `failure_error_function` or set it to `None`.

## 3. HTTP with SSE MCP servers

If the MCP server implements the HTTP with SSE transport, instantiate
Expand Down
22 changes: 20 additions & 2 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@
from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .tool import (
FunctionTool,
FunctionToolResult,
Tool,
ToolErrorFunction,
function_tool,
)
from .util import _transforms
from .util._types import MaybeAwaitable

Expand Down Expand Up @@ -72,6 +78,13 @@ class MCPConfig(TypedDict):
best-effort conversion, so some schemas may not be convertible. Defaults to False.
"""

failure_error_function: NotRequired[ToolErrorFunction | None]
"""Optional function used to generate an error response when an MCP tool invocation fails.

If provided, exceptions raised while calling a tool will be converted to the string returned by
this function instead of propagating to the agent.
"""


@dataclass
class AgentBase(Generic[TContext]):
Expand Down Expand Up @@ -104,8 +117,13 @@ class AgentBase(Generic[TContext]):
async def get_mcp_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
"""Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
failure_error_function = self.mcp_config.get("failure_error_function")
return await MCPUtil.get_all_function_tools(
self.mcp_servers, convert_schemas_to_strict, run_context, self
self.mcp_servers,
convert_schemas_to_strict,
run_context,
self,
failure_error_function=failure_error_function,
)

async def get_all_tools(self, run_context: RunContextWrapper[TContext]) -> list[Tool]:
Expand Down
53 changes: 46 additions & 7 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import functools
import inspect
import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union
Expand All @@ -11,8 +11,9 @@
from ..logger import logger
from ..run_context import RunContextWrapper
from ..strict_schema import ensure_strict_json_schema
from ..tool import FunctionTool, Tool
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
from ..tool import FunctionTool, Tool, ToolErrorFunction
from ..tracing import FunctionSpanData, SpanError, get_current_span, mcp_tools_span
from ..util import _error_tracing
from ..util._types import MaybeAwaitable

if TYPE_CHECKING:
Expand Down Expand Up @@ -116,13 +117,18 @@ async def get_all_function_tools(
convert_schemas_to_strict: bool,
run_context: RunContextWrapper[Any],
agent: "AgentBase",
failure_error_function: ToolErrorFunction | None = None,
) -> list[Tool]:
"""Get all function tools from a list of MCP servers."""
tools = []
tool_names: set[str] = set()
for server in servers:
server_tools = await cls.get_function_tools(
server, convert_schemas_to_strict, run_context, agent
server,
convert_schemas_to_strict,
run_context,
agent,
failure_error_function=failure_error_function,
)
server_tool_names = {tool.name for tool in server_tools}
if len(server_tool_names & tool_names) > 0:
Expand All @@ -142,21 +148,54 @@ async def get_function_tools(
convert_schemas_to_strict: bool,
run_context: RunContextWrapper[Any],
agent: "AgentBase",
failure_error_function: ToolErrorFunction | None = None,
) -> list[Tool]:
"""Get all function tools from a single MCP server."""

with mcp_tools_span(server=server.name) as span:
tools = await server.list_tools(run_context, agent)
span.span_data.result = [tool.name for tool in tools]

return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
return [
cls.to_function_tool(
tool,
server,
convert_schemas_to_strict,
failure_error_function=failure_error_function,
)
for tool in tools
]

@classmethod
def to_function_tool(
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
cls,
tool: "MCPTool",
server: "MCPServer",
convert_schemas_to_strict: bool,
failure_error_function: ToolErrorFunction | None = None,
) -> FunctionTool:
"""Convert an MCP tool to an Agents SDK function tool."""
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
async def invoke_func(context: RunContextWrapper[Any], input_json: str) -> str:
try:
return await cls.invoke_mcp_tool(server, tool, context, input_json)
except Exception as exc:
if failure_error_function is None:
raise

result = failure_error_function(context, exc)
if inspect.isawaitable(result):
result = await result

_error_tracing.attach_error_to_current_span(
SpanError(
message="Error running MCP tool (non-fatal)",
data={
"tool_name": tool.name,
"error": str(exc),
},
)
)
return result
schema, is_strict = tool.inputSchema, False

# MCP spec doesn't require the inputSchema to have `properties`, but OpenAI spec does.
Expand Down
29 changes: 29 additions & 0 deletions tests/mcp/test_mcp_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from agents import Agent, FunctionTool, RunContextWrapper
from agents.exceptions import AgentsException, ModelBehaviorError
from agents.mcp import MCPServer, MCPUtil
from agents.tool_context import ToolContext

from .helpers import FakeMCPServer

Expand Down Expand Up @@ -130,6 +131,34 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
assert "Error invoking MCP tool test_tool_1" in caplog.text


@pytest.mark.asyncio
async def test_mcp_failure_error_function_handles_exceptions():
"""Agent-level failure error functions should convert MCP invocation errors to strings."""

server = CrashingFakeMCPServer()
server.add_tool("test_tool_1", {})

def failure_handler(ctx: RunContextWrapper[Any], error: Exception) -> str:
root = error.__cause__ if getattr(error, "__cause__", None) else error
return f"custom error: {root}"

agent = Agent(
name="test_agent",
mcp_servers=[server],
mcp_config={"failure_error_function": failure_handler},
)
run_context = RunContextWrapper(context=None)
tools = await agent.get_mcp_tools(run_context)

tool = next(tool for tool in tools if tool.name == "test_tool_1")
tool_context = ToolContext(
context=None, tool_name=tool.name, tool_call_id="1", tool_arguments="{}"
)

result = await tool.on_invoke_tool(tool_context, "{}")
assert result == "custom error: Crash!"


@pytest.mark.asyncio
async def test_agent_convert_schemas_true():
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
Expand Down
Loading