Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions dimos/agents/mcp/fixtures/test_tool_stream_agent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
{
"responses": [
{
"content": "",
"tool_calls": [
{
"name": "start_streaming",
"args": {
"count": 3
},
"id": "call_toolstream_001",
"type": "tool_call"
}
]
},
{
"content": "I've started the streaming process. You'll receive 3 updates shortly.",
"tool_calls": []
},
{
"content": "Received the first streaming update: Update 1 of 3.",
"tool_calls": []
},
{
"content": "Received the second streaming update: Update 2 of 3.",
"tool_calls": []
},
{
"content": "Received the third and final streaming update: Update 3 of 3. All updates complete.",
"tool_calls": []
}
]
}
58 changes: 56 additions & 2 deletions dimos/agents/mcp/mcp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import json
from queue import Empty, Queue
from threading import Event, RLock, Thread
import time
Expand Down Expand Up @@ -99,6 +100,59 @@ def _mcp_request(self, method: str, params: dict[str, Any] | None = None) -> dic
result: dict[str, Any] = data.get("result")
return result

def _mcp_tool_call(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]:
"""Execute a tool call, handling both JSON and SSE streaming responses."""
body: dict[str, Any] = {
"jsonrpc": "2.0",
"id": self._seq_ids.next(),
"method": "tools/call",
"params": {"name": name, "arguments": arguments},
}

with self._http_client.stream(
"POST",
self.config.mcp_server_url,
json=body,
headers={"Accept": "application/json, text/event-stream"},
) as resp:
resp.raise_for_status()
content_type = resp.headers.get("content-type", "")

if "text/event-stream" in content_type:
return self._consume_sse_tool_response(resp, name)

data = json.loads(resp.read())
if "error" in data:
raise RuntimeError(f"MCP error {data['error']['code']}: {data['error']['message']}")
result: dict[str, Any] = data.get("result", {})
return result

def _consume_sse_tool_response(
self, response: httpx.Response, tool_name: str
) -> dict[str, Any]:
"""Parse an SSE tool response, injecting notifications as HumanMessages."""
result: dict[str, Any] | None = None
for line in response.iter_lines():
if not line.startswith("data: "):
continue
try:
data = json.loads(line[6:])
except json.JSONDecodeError:
continue

if data.get("method") == "notifications/message":
text = data.get("params", {}).get("data", "")
if text:
self._message_queue.put(
HumanMessage(content=f"[Tool stream update from '{tool_name}']: {text}")
)
elif "result" in data:
result = data["result"]

if result is None:
return {"content": [{"type": "text", "text": "Stream ended without result."}]}
return result

def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[StructuredTool]:
result = self._try_fetch_tools(timeout=timeout, interval=interval)
if result is None:
Expand Down Expand Up @@ -138,7 +192,7 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool:
input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}})

def call_tool(**kwargs: Any) -> str:
result = self._mcp_request("tools/call", {"name": name, "arguments": kwargs})
result = self._mcp_tool_call(name, kwargs)
content = result.get("content", [])
parts = [c.get("text", "") for c in content if c.get("type") == "text"]
text = "\n".join(parts)
Expand Down Expand Up @@ -241,7 +295,7 @@ def dispatch_continuation(
tool_args[key] = continuation_context[context_key]

try:
result = self._mcp_request("tools/call", {"name": tool_name, "arguments": tool_args})
result = self._mcp_tool_call(tool_name, tool_args)
content = result.get("content", [])
parts = [c.get("text", "") for c in content if c.get("type") == "text"]
text = "\n".join(parts)
Expand Down
144 changes: 137 additions & 7 deletions dimos/agents/mcp/mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
from __future__ import annotations

import asyncio
from collections.abc import AsyncGenerator
import concurrent.futures
from dataclasses import dataclass
import json
import os
import time
Expand All @@ -23,14 +25,17 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from reactivex.disposable import Disposable
from starlette.requests import Request
from starlette.responses import Response
from starlette.responses import Response, StreamingResponse
import uvicorn

from dimos.agents.annotation import skill
from dimos.agents.mcp.tool_stream import ToolStreamEvent
from dimos.core.core import rpc
from dimos.core.module import Module
from dimos.core.rpc_client import RpcCall, RPCClient
from dimos.core.stream import In
from dimos.utils.logging_config import setup_logger

if TYPE_CHECKING:
Expand All @@ -43,11 +48,12 @@
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_methods=["POST", "GET"],
allow_headers=["*"],
)
app.state.skills = []
app.state.rpc_calls = {}
app.state.sse_queues = [] # list[asyncio.Queue] for SSE clients


def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]:
Expand All @@ -62,12 +68,20 @@ def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]:
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}}


@dataclass(frozen=True, slots=True)
class _StreamingToolResult:
    """Marker returned when a tool starts a background stream.

    Returned by the tools/call handler in place of a JSON-RPC payload so
    the HTTP endpoint knows to answer with an SSE stream instead of plain
    JSON (or to fall back to an immediate JSON acknowledgement when the
    client does not accept SSE).
    """

    # JSON-RPC request id to echo in the final streamed result frame.
    req_id: Any
    # Name of the tool whose stream events should be forwarded to the client.
    tool_name: str


def _handle_initialize(req_id: Any) -> dict[str, Any]:
    """Answer an MCP ``initialize`` request with this server's capabilities."""
    # "logging" is advertised so clients accept the notifications/message
    # frames emitted while a tool streams updates.
    capabilities = {"tools": {}, "logging": {}}
    init_info = {
        "protocolVersion": "2025-11-25",
        "capabilities": capabilities,
        "serverInfo": {"name": "dimensional", "version": "1.0.0"},
    }
    return _jsonrpc_result(req_id, init_info)
Expand All @@ -90,7 +104,7 @@ def _handle_tools_list(req_id: Any, skills: list[SkillInfo]) -> dict[str, Any]:

async def _handle_tools_call(
req_id: Any, params: dict[str, Any], rpc_calls: dict[str, Any]
) -> dict[str, Any]:
) -> dict[str, Any] | _StreamingToolResult:
name = params.get("name", "")
args: dict[str, Any] = params.get("arguments") or {}

Expand All @@ -111,8 +125,8 @@ async def _handle_tools_call(
duration = f"{time.monotonic() - t0:.3f}s"

if result is None:
logger.info("MCP tool done (async)", tool=name, duration=duration)
return _jsonrpc_result_text(req_id, "It has started. You will be updated later.")
logger.info("MCP tool streaming", tool=name, duration=duration)
return _StreamingToolResult(req_id=req_id, tool_name=name)

response = str(result)[:200]
if hasattr(result, "agent_encode"):
Expand All @@ -127,7 +141,7 @@ async def handle_request(
request: dict[str, Any],
skills: list[SkillInfo],
rpc_calls: dict[str, Any],
) -> dict[str, Any] | None:
) -> dict[str, Any] | _StreamingToolResult | None:
"""Handle a single MCP JSON-RPC request.

Returns None for JSON-RPC notifications (no ``id``), which must not
Expand Down Expand Up @@ -161,13 +175,119 @@ async def mcp_endpoint(request: Request) -> Response:
{"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}},
status_code=400,
)

# Pre-register a queue for tool stream events when the client accepts SSE.
accept = request.headers.get("accept", "")
is_tool_call = body.get("method") == "tools/call"
client_accepts_sse = "text/event-stream" in accept

stream_queue: asyncio.Queue[dict[str, Any]] | None = None
if is_tool_call and client_accepts_sse:
stream_queue = asyncio.Queue()
app.state.sse_queues.append(stream_queue)

result = await handle_request(body, request.app.state.skills, request.app.state.rpc_calls)

# Streaming tool: return SSE response if client supports it.
if isinstance(result, _StreamingToolResult):
if stream_queue is not None:
return _streaming_tool_response(result, stream_queue)
# Client doesn't support SSE — fall back to immediate JSON.
return JSONResponse(
_jsonrpc_result_text(result.req_id, "It has started. You will be updated later.")
)

# Non-streaming: remove the pre-registered queue if any.
if stream_queue is not None:
try:
app.state.sse_queues.remove(stream_queue)
except ValueError:
pass

if result is None:
return Response(status_code=204)
return JSONResponse(result)


_STREAM_TIMEOUT = 300.0 # seconds


def _sse_event(data: dict[str, Any]) -> str:
"""Format a JSON-RPC message as an SSE ``event: message`` frame."""
return f"event: message\ndata: {json.dumps(data)}\n\n"


def _streaming_tool_response(
    streaming: _StreamingToolResult,
    queue: asyncio.Queue[dict[str, Any]],
) -> StreamingResponse:
    """Build an SSE response that forwards ToolStream events as MCP log notifications.

    The response streams ``notifications/message`` for each update and ends
    with the JSON-RPC result carrying the accumulated text.

    Args:
        streaming: Marker describing the originating tools/call request
            (request id to answer, tool name to filter events on).
        queue: This client's pre-registered queue of raw ToolStream event
            dicts (fed by the module's tool_streams subscription).
    """

    async def event_generator() -> AsyncGenerator[str, None]:
        # First stream_id observed for this tool; used to ignore events
        # from any other concurrent stream of the same tool.
        stream_id: str | None = None
        # Update texts accumulated so far; joined into the final result text.
        collected: list[str] = []
        try:
            while True:
                try:
                    data = await asyncio.wait_for(queue.get(), timeout=_STREAM_TIMEOUT)
                except asyncio.TimeoutError:
                    # No event within the window: finish with whatever was
                    # collected so the client is never left without a result.
                    text = "\n".join(collected) if collected else "No updates received."
                    yield _sse_event(_jsonrpc_result_text(streaming.req_id, text))
                    return

                try:
                    event = ToolStreamEvent(**data)
                except (TypeError, KeyError):
                    # Dict doesn't match the event schema — drop it.
                    continue

                # Filter: match by tool_name, lock onto the first stream_id seen.
                if event.tool_name != streaming.tool_name:
                    continue
                if stream_id is None:
                    stream_id = event.stream_id
                elif event.stream_id != stream_id:
                    continue

                if event.type == "update" and event.text:
                    collected.append(event.text)
                    # Forward the update as an MCP logging notification frame.
                    yield _sse_event(
                        {
                            "jsonrpc": "2.0",
                            "method": "notifications/message",
                            "params": {
                                "level": "info",
                                "logger": event.tool_name,
                                "data": event.text,
                            },
                        }
                    )

                elif event.type == "close":
                    # Stream finished: emit the JSON-RPC result and end the body.
                    text = "\n".join(collected) if collected else "Stream completed."
                    yield _sse_event(_jsonrpc_result_text(streaming.req_id, text))
                    return
        except asyncio.CancelledError:
            # Client disconnected mid-stream; fall through to cleanup.
            pass
        finally:
            # Deregister this client's queue so the publisher stops fanning
            # out to it; it may already have been removed elsewhere.
            try:
                app.state.sse_queues.remove(queue)
            except ValueError:
                pass

    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
    )


class McpServer(Module):
tool_streams: In[dict[str, Any]]

_uvicorn_server: uvicorn.Server | None = None
_serve_future: concurrent.futures.Future[None] | None = None

Expand All @@ -176,6 +296,16 @@ def start(self) -> None:
super().start()
self._start_server()

loop = self._loop

def _on_tool_stream_message(msg: dict[str, Any]) -> None:
if loop is None:
return
for queue in list(app.state.sse_queues):
asyncio.run_coroutine_threadsafe(queue.put(msg), loop)

self._disposables.add(Disposable(self.tool_streams.subscribe(_on_tool_stream_message)))

@rpc
def stop(self) -> None:
if self._uvicorn_server:
Expand Down
17 changes: 17 additions & 0 deletions dimos/agents/mcp/test_mcp_client_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,28 @@ def _mock_post(url: str, **kwargs: object) -> MagicMock:
return resp


def _mock_stream(_method: str, _url: str, **kwargs: object) -> MagicMock:
    """Adapt ``_mock_post``'s canned reply into an httpx ``stream()`` double.

    Returns a context manager whose entered response reports a JSON
    content type and serves the post response's JSON body via ``read()``.
    """
    canned = _mock_post(_url, **kwargs)
    body_bytes = json.dumps(canned.json.return_value).encode()

    streamed_resp = MagicMock()
    streamed_resp.raise_for_status = MagicMock()
    streamed_resp.headers = {"content-type": "application/json"}
    streamed_resp.read.return_value = body_bytes

    context = MagicMock()
    context.__enter__ = MagicMock(return_value=streamed_resp)
    context.__exit__ = MagicMock(return_value=False)
    return context


@pytest.fixture
def mcp_client() -> McpClient:
"""Build an McpClient wired to the mock MCP post handler."""
mock_http = MagicMock()
mock_http.post.side_effect = _mock_post
mock_http.stream.side_effect = _mock_stream

with patch("dimos.agents.mcp.mcp_client.httpx.Client", return_value=mock_http):
client = McpClient.__new__(McpClient)
Expand Down
Loading
Loading