Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions environments/mcp_env/mcp_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
EXA_FETCH_TOOLS = [
{
"name": "exa",
"transport": "stdio",
"command": "npx",
"args": [
"-y",
Expand All @@ -31,12 +32,22 @@
},
{
"name": "fetch",
"transport": "stdio",
"command": "uvx",
"args": ["mcp-server-fetch"],
"description": "Fetch MCP server",
},
]

BROWSERBASE_TOOLS = [
{
"name": "browserbase",
"transport": "http",
"url": os.getenv("BROWSERBASE_URL"),
"description": "Browserbase MCP",
},
]


class MCPEnv(ToolEnv):
"""Environment for MCP-based tools using the official MCP SDK."""
Expand Down
59 changes: 44 additions & 15 deletions environments/mcp_env/src/mcp_server_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.types import TextContent, Tool

from .models import MCPServerConfig
Expand Down Expand Up @@ -35,27 +36,55 @@ async def connect(self):

async def _get_connection(self):
try:
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args or [],
env=self.config.env,
)
if self.config.transport == "stdio":
if not self.config.command:
raise ValueError("stdio transport requires 'command'")
server_params = StdioServerParameters(
command=self.config.command,
args=self.config.args or [],
env=self.config.env,
)

async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
self.session = session
async with stdio_client(server_params) as (read, write):
async with ClientSession(read, write) as session:
self.session = session

await session.initialize()
await session.initialize()

tools_response = await session.list_tools()
tools_response = await session.list_tools()

for tool in tools_response.tools:
self.tools[tool.name] = tool
for tool in tools_response.tools:
self.tools[tool.name] = tool

self._ready.set()
self._ready.set()

while True:
await asyncio.sleep(1)
while True:
await asyncio.sleep(1)

elif self.config.transport == "http":
if not self.config.url:
raise ValueError("http transport requires 'url'")

async with streamablehttp_client(
self.config.url,
headers=self.config.headers or {},
) as (read, write, _get_session_id):
async with ClientSession(read, write) as session:
self.session = session

await session.initialize()

tools_response = await session.list_tools()

for tool in tools_response.tools:
self.tools[tool.name] = tool

self._ready.set()

while True:
await asyncio.sleep(1)
else:
raise ValueError(f"Unknown transport: {self.config.transport}")

except asyncio.CancelledError:
raise
Expand Down
13 changes: 9 additions & 4 deletions environments/mcp_env/src/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from dataclasses import dataclass
from typing import Dict, List
from typing import Dict, List, Literal, Optional


@dataclass
class MCPServerConfig:
name: str
command: str
args: List[str] | None = None
env: Dict[str, str] | None = None
transport: Literal["stdio", "http"] = "stdio"
description: str = ""
# stdio params
command: Optional[str] = None
args: Optional[List[str]] = None
env: Optional[Dict[str, str]] = None
# http params
url: Optional[str] = None
headers: Optional[Dict[str, str]] = None
Loading