From 888cf3f4ca03474a2defb2528e198fe6cd9f1ecf Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Fri, 13 Mar 2026 02:53:54 +0200 Subject: [PATCH 1/2] feat(tool-streams): add tool streams --- .../mcp/fixtures/test_tool_stream_agent.json | 33 ++++ dimos/agents/mcp/mcp_client.py | 60 ++++++ dimos/agents/mcp/mcp_server.py | 50 ++++- dimos/agents/mcp/test_tool_stream.py | 181 ++++++++++++++++++ dimos/agents/mcp/tool_stream.py | 99 ++++++++++ dimos/agents/skills/person_follow.py | 29 ++- dimos/perception/perceive_loop_skill.py | 21 +- 7 files changed, 462 insertions(+), 11 deletions(-) create mode 100644 dimos/agents/mcp/fixtures/test_tool_stream_agent.json create mode 100644 dimos/agents/mcp/test_tool_stream.py create mode 100644 dimos/agents/mcp/tool_stream.py diff --git a/dimos/agents/mcp/fixtures/test_tool_stream_agent.json b/dimos/agents/mcp/fixtures/test_tool_stream_agent.json new file mode 100644 index 0000000000..0c334c5d53 --- /dev/null +++ b/dimos/agents/mcp/fixtures/test_tool_stream_agent.json @@ -0,0 +1,33 @@ +{ + "responses": [ + { + "content": "", + "tool_calls": [ + { + "name": "start_streaming", + "args": { + "count": 3 + }, + "id": "call_toolstream_001", + "type": "tool_call" + } + ] + }, + { + "content": "I've started the streaming process. You'll receive 3 updates shortly.", + "tool_calls": [] + }, + { + "content": "Received the first streaming update: Update 1 of 3.", + "tool_calls": [] + }, + { + "content": "Received the second streaming update: Update 2 of 3.", + "tool_calls": [] + }, + { + "content": "Received the third and final streaming update: Update 3 of 3. All updates complete.", + "tool_calls": [] + } + ] +} diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index e0200a6323..8b671c7c1b 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import json from queue import Empty, Queue from threading import Event, RLock, Thread import time @@ -26,6 +27,7 @@ from langgraph.graph.state import CompiledStateGraph from reactivex.disposable import Disposable +from dimos.agents.mcp.tool_stream import ToolStreamEvent from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.agents.utils import pretty_print_langchain_message from dimos.core.core import rpc @@ -60,6 +62,8 @@ class McpClient(Module[McpClientConfig]): _stop_event: Event _http_client: httpx.Client _seq_ids: SequentialIds + _sse_thread: Thread | None + _sse_client: httpx.Client | None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -76,6 +80,8 @@ def __init__(self, **kwargs: Any) -> None: self._stop_event = Event() self._http_client = httpx.Client(timeout=120.0) self._seq_ids = SequentialIds() + self._sse_thread = None + self._sse_client = None def __reduce__(self) -> Any: return (self.__class__, (), {}) @@ -188,11 +194,18 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None: ) self._thread.start() + self._start_sse_listener() + @rpc def stop(self) -> None: self._stop_event.set() + with self._lock: + if self._sse_client is not None: + self._sse_client.close() if self._thread.is_alive(): self._thread.join(timeout=2.0) + if self._sse_thread is not None and self._sse_thread.is_alive(): + self._sse_thread.join(timeout=2.0) self._http_client.close() super().stop() @@ -289,6 +302,53 @@ def _process_message( if self._message_queue.empty(): self.agent_idle.publish(True) + def _start_sse_listener(self) -> None: + """Connect to the MCP server SSE endpoint to receive tool stream updates.""" + self._sse_thread = Thread(target=self._sse_loop, name="McpClient-SSE", daemon=True) + self._sse_thread.start() + + def _sse_loop(self) -> None: + base_url = self.config.mcp_server_url.rsplit("/mcp", 1)[0] + sse_url = f"{base_url}/mcp/streams" + + while not self._stop_event.is_set(): + try: + self._sse_connect(sse_url) + except Exception: + if not self._stop_event.is_set(): + # Try reconnecting after a short delay + time.sleep(1.0) + + def _sse_connect(self, sse_url: str) -> None: + client = httpx.Client(timeout=None) + with self._lock: + self._sse_client = client + try: + with client.stream("GET", sse_url) as response: + self._sse_consume(response) + finally: + with self._lock: + self._sse_client = None + client.close() + + def _sse_consume(self, response: httpx.Response) -> None: + for line in response.iter_lines(): + if self._stop_event.is_set(): + return + if not line.startswith("data: "): + continue + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + event = ToolStreamEvent(**data) + if event.type == "update": + self._message_queue.put( + HumanMessage( + content=f"[Tool stream update from '{event.tool_name}']: {event.text}" + ) + ) + def _append_image_to_history( mcp_client: McpClient, func_name: str, uuid_: str, result: Any diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 9149de06ec..2431bf8231 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -14,6 +14,7 @@ from __future__ import annotations import asyncio +from collections.abc import AsyncGenerator import concurrent.futures import json import os @@ -23,14 +24,16 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse +from reactivex.disposable import Disposable from starlette.requests import Request -from starlette.responses import Response +from starlette.responses import Response, StreamingResponse import uvicorn from dimos.agents.annotation import skill from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.rpc_client import RpcCall, RPCClient +from dimos.core.stream import In from dimos.utils.logging_config import setup_logger if TYPE_CHECKING: @@ -43,11 +46,12 @@ app.add_middleware( CORSMiddleware, allow_origins=["*"], - allow_methods=["POST"], + allow_methods=["POST", "GET"], allow_headers=["*"], ) app.state.skills = [] app.state.rpc_calls = {} +app.state.sse_queues = [] # list[asyncio.Queue] for SSE clients def _jsonrpc_result(req_id: Any, result: Any) -> dict[str, Any]: @@ -167,7 +171,39 @@ async def mcp_endpoint(request: Request) -> Response: return JSONResponse(result) +@app.get("/mcp/streams") +async def streams_sse_endpoint() -> StreamingResponse: + """Server-Sent Events endpoint for tool stream updates. + + Clients subscribe here to receive real-time updates from long-running + skills that use ``ToolStream``. + """ + queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() + app.state.sse_queues.append(queue) + + async def event_generator() -> AsyncGenerator[str, None]: + try: + while True: + data = await queue.get() + yield f"data: {json.dumps(data)}\n\n" + except asyncio.CancelledError: + pass + finally: + try: + app.state.sse_queues.remove(queue) + except ValueError: + pass + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + + class McpServer(Module): + tool_streams: In[dict[str, Any]] + _uvicorn_server: uvicorn.Server | None = None _serve_future: concurrent.futures.Future[None] | None = None @@ -176,6 +212,16 @@ def start(self) -> None: super().start() self._start_server() + loop = self._loop + + def _on_tool_stream_message(msg: dict[str, Any]) -> None: + if loop is None: + return + for queue in list(app.state.sse_queues): + asyncio.run_coroutine_threadsafe(queue.put(msg), loop) + + self._disposables.add(Disposable(self.tool_streams.subscribe(_on_tool_stream_message))) + @rpc def stop(self) -> None: if self._uvicorn_server: diff --git a/dimos/agents/mcp/test_tool_stream.py b/dimos/agents/mcp/test_tool_stream.py new file mode 100644 index 0000000000..a1ea103091 --- /dev/null +++ b/dimos/agents/mcp/test_tool_stream.py @@ -0,0 +1,181 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from threading import Event, Thread +import time + +import httpx +from langchain_core.messages import HumanMessage +import pytest + +from dimos.agents.annotation import skill +from dimos.agents.mcp.mcp_adapter import McpAdapter +from dimos.agents.mcp.mcp_server import McpServer +from dimos.agents.mcp.tool_stream import ToolStream, ToolStreamEvent +from dimos.core.blueprints import autoconnect +from dimos.core.global_config import global_config +from dimos.core.module import Module + + +class StreamingModule(Module): + """Test module that uses ToolStream to send multiple updates.""" + + @skill + def start_streaming(self, count: int) -> None: + """Starts streaming count updates back to the agent.""" + stream = ToolStream("start_streaming") + stream.start() + + def _stream_loop() -> None: + try: + for i in range(count): + time.sleep(0.1) + stream.send(f"Update {i + 1} of {count}") + finally: + stream.stop() + + Thread(target=_stream_loop, daemon=True).start() + + +@pytest.fixture +def mcp_server(): + """Start a blueprint with StreamingModule + McpServer, wait for readiness.""" + global_config.update(viewer="none") + blueprint = autoconnect(StreamingModule.blueprint(), McpServer.blueprint()) + coordinator = blueprint.build() + + adapter = McpAdapter() + if not adapter.wait_for_ready(timeout=15): + coordinator.stop() + pytest.fail("MCP server did not become ready") + + yield adapter + + coordinator.stop() + + +@pytest.mark.slow +def test_tool_stream_sse(mcp_server: McpAdapter) -> None: + """ToolStream updates flow through the HTTP SSE endpoint.""" + adapter = mcp_server + base_url = adapter.url.rsplit("/mcp", 1)[0] + sse_url = f"{base_url}/mcp/streams" + + events: list[ToolStreamEvent] = [] + sse_connected = Event() + sse_done = Event() + + def sse_reader() -> None: + try: + with httpx.Client(timeout=None) as client: + with client.stream("GET", sse_url) as resp: + sse_connected.set() + for line in resp.iter_lines(): + if not line.startswith("data: "): + continue + event = ToolStreamEvent(**json.loads(line[6:])) + events.append(event) + if event.type == "close": + sse_done.set() + return + except Exception: + pass # connection closed during teardown + + reader = Thread(target=sse_reader, daemon=True) + reader.start() + assert sse_connected.wait(5), "SSE connection was not established" + time.sleep(0.5) # let the server register the SSE queue + + # Call the streaming tool via MCP. + result = adapter.call_tool("start_streaming", {"count": 3}) + assert "It has started" in result["content"][0]["text"] + + # Wait for all SSE events (updates + close). + assert sse_done.wait(10), "Timed out waiting for tool stream close event" + + updates = [e for e in events if e.type == "update"] + closes = [e for e in events if e.type == "close"] + + assert len(updates) == 3 + assert updates[0].text == "Update 1 of 3" + assert updates[1].text == "Update 2 of 3" + assert updates[2].text == "Update 3 of 3" + + assert len(closes) == 1 + + # All events share the same stream id and tool name. + stream_ids = {e.stream_id for e in events} + assert len(stream_ids) == 1 + assert all(e.tool_name == "start_streaming" for e in events) + + +@pytest.mark.slow +def test_tool_stream_agent(agent_setup) -> None: # type: ignore[no-untyped-def] + """Tool stream updates arrive at the agent as HumanMessages.""" + history = agent_setup( + blueprints=[StreamingModule.blueprint()], + messages=[ + HumanMessage("Start streaming 3 updates using the start_streaming tool with count=3.") + ], + ) + + # agent_setup returns after the initial tool call round-trip. The tool + # stream updates arrive asynchronously afterwards — poll the history list + # (which is still being mutated by the /agent transport callback). + deadline = time.monotonic() + 10.0 + while time.monotonic() < deadline: + stream_updates = [ + m + for m in history + if isinstance(m, HumanMessage) and "[Tool stream update" in str(m.content) + ] + if len(stream_updates) >= 3: + break + time.sleep(0.2) + + stream_updates = [ + m + for m in history + if isinstance(m, HumanMessage) and "[Tool stream update" in str(m.content) + ] + assert len(stream_updates) == 3 + assert "Update 1 of 3" in stream_updates[0].content + assert "Update 2 of 3" in stream_updates[1].content + assert "Update 3 of 3" in stream_updates[2].content + + +@pytest.fixture() +def make_stream(mocker): + """Create a ToolStream with a mocked transport.""" + mock_transport = mocker.MagicMock() + mocker.patch("dimos.agents.mcp.tool_stream.pLCMTransport", return_value=mock_transport) + stream = ToolStream("test_tool") + stream.start() + return stream, mock_transport + + +def test_send_after_stop_does_not_raise(make_stream) -> None: + stream, _ = make_stream + stream.stop() + # Must not raise even though transport is None after stop. + stream.send("should be ignored") + + +def test_double_stop_is_safe(make_stream) -> None: + stream, mock_transport = make_stream + stream.stop() + stream.stop() + # Transport stop called only once. + mock_transport.stop.assert_called_once() diff --git a/dimos/agents/mcp/tool_stream.py b/dimos/agents/mcp/tool_stream.py new file mode 100644 index 0000000000..a4f5053a09 --- /dev/null +++ b/dimos/agents/mcp/tool_stream.py @@ -0,0 +1,99 @@ +# Copyright 2026 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from dataclasses import asdict, dataclass +import threading +from typing import Any, Literal +import uuid + +from dimos.core.transport import pLCMTransport +from dimos.utils.logging_config import setup_logger + +logger = setup_logger() + +_TOOL_STREAM_TOPIC = "/tool_streams" + + +@dataclass(frozen=True, slots=True) +class ToolStreamEvent: + """A single event published on a tool stream channel.""" + + stream_id: str + tool_name: str + type: Literal["update", "close"] + text: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Serialize to a dict, omitting ``None`` values.""" + return {k: v for k, v in asdict(self).items() if v is not None} + + +class ToolStream: + """A streaming channel for sending updates from a running skill to the agent. + + Each `ToolStream` publishes messages on a shared LCM topic. The agent + (or the MCP server SSE endpoint) subscribes once and receives all updates. + """ + + def __init__(self, tool_name: str) -> None: + self.tool_name: str = tool_name + self.id: str = str(uuid.uuid4()) + self._closed: threading.Event = threading.Event() + self._transport: pLCMTransport[dict[str, Any]] | None = None + self._lock = threading.Lock() + + def start(self) -> None: + self._transport = pLCMTransport(_TOOL_STREAM_TOPIC) + self._transport.start() + + def send(self, message: str) -> None: + with self._lock: + if self._closed.is_set(): + logger.error("Attempted to send on closed ToolStream", stream_id=self.id) + return + if self._transport is None: + logger.error("ToolStream transport not initialized", stream_id=self.id) + return + self._transport.publish( + ToolStreamEvent( + stream_id=self.id, + tool_name=self.tool_name, + type="update", + text=message, + ).to_dict(), + ) + + def stop(self) -> None: + with self._lock: + if self._closed.is_set(): + return + self._closed.set() + if self._transport is not None: + try: + self._transport.publish( + ToolStreamEvent( + stream_id=self.id, + tool_name=self.tool_name, + type="close", + ).to_dict(), + ) + finally: + self._transport.stop() + self._transport = None + + @property + def is_closed(self) -> bool: + return self._closed.is_set() diff --git a/dimos/agents/skills/person_follow.py b/dimos/agents/skills/person_follow.py index 9f97a23d53..2d4bab0f65 100644 --- a/dimos/agents/skills/person_follow.py +++ b/dimos/agents/skills/person_follow.py @@ -17,13 +17,12 @@ import time from typing import Any -from langchain_core.messages import HumanMessage import numpy as np from reactivex.disposable import Disposable from turbojpeg import TurboJPEG -from dimos.agents.agent_spec import AgentSpec from dimos.agents.annotation import skill +from dimos.agents.mcp.tool_stream import ToolStream from dimos.core.core import rpc from dimos.core.module import Module, ModuleConfig from dimos.core.stream import In, Out @@ -65,7 +64,6 @@ class PersonFollowSkillContainer(Module[Config]): global_map: In[PointCloud2] cmd_vel: Out[Twist] - _agent_spec: AgentSpec _frequency: float = 20.0 # Hz - control loop frequency _max_lost_frames: int = 15 # number of frames to wait before declaring person lost _patrolling_module_spec: PatrollingModuleSpec @@ -79,6 +77,7 @@ def __init__(self, **kwargs: Any) -> None: self._thread: Thread | None = None self._should_stop: Event = Event() self._lock = RLock() + self._tool_stream: ToolStream | None = None # Use MuJoCo camera intrinsics in simulation mode camera_info = self.config.camera_info @@ -101,7 +100,15 @@ def start(self) -> None: def stop(self) -> None: self._stop_following() + thread = self._thread + if thread is not None: + thread.join(timeout=2) + self._thread = None + with self._lock: + if self._tool_stream is not None: + self._tool_stream.stop() + self._tool_stream = None if self._tracker is not None: self._tracker.stop() self._tracker = None @@ -140,6 +147,10 @@ def follow_person( self._stop_following() + if self._thread is not None: + self._thread.join(timeout=2) + self._thread = None + self._should_stop.clear() with self._lock: @@ -227,12 +238,15 @@ def _follow_person( logger.info(f"EdgeTAM initialized with {len(initial_detections)} detections") + with self._lock: + self._tool_stream = ToolStream("follow_person") + self._tool_stream.start() self._thread = Thread(target=self._follow_loop, args=(tracker, query), daemon=True) self._thread.start() message = ( "Found the person. Starting to follow. You can stop following by calling " - "the 'stop_following' tool." + "the 'stop_following' tool. You will receive streaming updates." ) if self._patrolling_module_spec.is_patrolling(): @@ -304,7 +318,12 @@ def _stop_following(self) -> None: def _send_stop_reason(self, query: str, reason: str) -> None: self.cmd_vel.publish(Twist.zero()) message = f"Person follow stopped for '{query}'. Reason: {reason}." - self._agent_spec.add_message(HumanMessage(message)) + with self._lock: + stream = self._tool_stream + self._tool_stream = None + if stream is not None: + stream.send(message) + stream.stop() logger.info("Person follow stopped", query=query, reason=reason) diff --git a/dimos/perception/perceive_loop_skill.py b/dimos/perception/perceive_loop_skill.py index ac649f512f..571ce0fa6e 100644 --- a/dimos/perception/perceive_loop_skill.py +++ b/dimos/perception/perceive_loop_skill.py @@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Any import cv2 -from langchain_core.messages import HumanMessage from dimos.agents.agent_spec import AgentSpec from dimos.agents.annotation import skill +from dimos.agents.mcp.tool_stream import ToolStream from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.stream import In @@ -57,6 +57,7 @@ def __init__(self, **kwargs: Any) -> None: self._lookout_subscription: DisposableBase | None = None self._model_started: bool = False self._lock = RLock() + self._tool_stream: ToolStream | None = None @rpc def start(self) -> None: @@ -118,6 +119,8 @@ def look_out_for( self._model_started = True self._active_lookout = tuple(description_of_things) self._then = then + self._tool_stream = ToolStream("look_out_for") + self._tool_stream.start() self._lookout_subscription = sharpest.subscribe( on_next=self._on_image, on_error=lambda e: logger.exception("Error in perceive loop", exc_info=e), @@ -163,13 +166,20 @@ def _on_image(self, image: Image) -> None: self._then = None self._vl_model.stop() self._model_started = False + tool_stream = self._tool_stream + self._tool_stream = None if then is None: - self._agent_spec.add_message( - HumanMessage(f"Found a match for {active_lookout_str}. Please announce audibly.") - ) + if tool_stream is not None: + tool_stream.send( + f"Found a match for {active_lookout_str}. Please announce audibly." + ) + tool_stream.stop() return + if tool_stream is not None: + tool_stream.stop() + best = max(detections.detections, key=lambda d: d.bbox_2d_volume()) continuation_context: dict[str, Any] = { "bbox": list(best.bbox), @@ -194,6 +204,9 @@ def _stop_lookout(self) -> None: if self._model_started: self._vl_model.stop() self._model_started = False + if self._tool_stream is not None: + self._tool_stream.stop() + self._tool_stream = None def _write_debug_image(image: Image, detections: ImageDetections2D[Detection2DBBox]) -> None: From 2b07827e33c1e0215a8bcaa4023ec289adc5959b Mon Sep 17 00:00:00 2001 From: Paul Nechifor Date: Tue, 31 Mar 2026 04:56:36 +0300 Subject: [PATCH 2/2] rework how streams work --- dimos/agents/mcp/mcp_client.py | 116 +++++++++++------------ dimos/agents/mcp/mcp_server.py | 112 +++++++++++++++++++--- dimos/agents/mcp/test_mcp_client_unit.py | 17 ++++ dimos/agents/mcp/test_tool_stream.py | 92 ++++++++---------- 4 files changed, 209 insertions(+), 128 deletions(-) diff --git a/dimos/agents/mcp/mcp_client.py b/dimos/agents/mcp/mcp_client.py index 8b671c7c1b..51e05ebbd8 100644 --- a/dimos/agents/mcp/mcp_client.py +++ b/dimos/agents/mcp/mcp_client.py @@ -27,7 +27,6 @@ from langgraph.graph.state import CompiledStateGraph from reactivex.disposable import Disposable -from dimos.agents.mcp.tool_stream import ToolStreamEvent from dimos.agents.system_prompt import SYSTEM_PROMPT from dimos.agents.utils import pretty_print_langchain_message from dimos.core.core import rpc @@ -62,8 +61,6 @@ class McpClient(Module[McpClientConfig]): _stop_event: Event _http_client: httpx.Client _seq_ids: SequentialIds - _sse_thread: Thread | None - _sse_client: httpx.Client | None def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -80,8 +77,6 @@ def __init__(self, **kwargs: Any) -> None: self._stop_event = Event() self._http_client = httpx.Client(timeout=120.0) self._seq_ids = SequentialIds() - self._sse_thread = None - self._sse_client = None def __reduce__(self) -> Any: return (self.__class__, (), {}) @@ -105,6 +100,59 @@ def _mcp_request(self, method: str, params: dict[str, Any] | None = None) -> dic result: dict[str, Any] = data.get("result") return result + def _mcp_tool_call(self, name: str, arguments: dict[str, Any]) -> dict[str, Any]: + """Execute a tool call, handling both JSON and SSE streaming responses.""" + body: dict[str, Any] = { + "jsonrpc": "2.0", + "id": self._seq_ids.next(), + "method": "tools/call", + "params": {"name": name, "arguments": arguments}, + } + + with self._http_client.stream( + "POST", + self.config.mcp_server_url, + json=body, + headers={"Accept": "application/json, text/event-stream"}, + ) as resp: + resp.raise_for_status() + content_type = resp.headers.get("content-type", "") + + if "text/event-stream" in content_type: + return self._consume_sse_tool_response(resp, name) + + data = json.loads(resp.read()) + if "error" in data: + raise RuntimeError(f"MCP error {data['error']['code']}: {data['error']['message']}") + result: dict[str, Any] = data.get("result", {}) + return result + + def _consume_sse_tool_response( + self, response: httpx.Response, tool_name: str + ) -> dict[str, Any]: + """Parse an SSE tool response, injecting notifications as HumanMessages.""" + result: dict[str, Any] | None = None + for line in response.iter_lines(): + if not line.startswith("data: "): + continue + try: + data = json.loads(line[6:]) + except json.JSONDecodeError: + continue + + if data.get("method") == "notifications/message": + text = data.get("params", {}).get("data", "") + if text: + self._message_queue.put( + HumanMessage(content=f"[Tool stream update from '{tool_name}']: {text}") + ) + elif "result" in data: + result = data["result"] + + if result is None: + return {"content": [{"type": "text", "text": "Stream ended without result."}]} + return result + def _fetch_tools(self, timeout: float = 60.0, interval: float = 1.0) -> list[StructuredTool]: result = self._try_fetch_tools(timeout=timeout, interval=interval) if result is None: @@ -144,7 +192,7 @@ def _mcp_tool_to_langchain(self, mcp_tool: dict[str, Any]) -> StructuredTool: input_schema = mcp_tool.get("inputSchema", {"type": "object", "properties": {}}) def call_tool(**kwargs: Any) -> str: - result = self._mcp_request("tools/call", {"name": name, "arguments": kwargs}) + result = self._mcp_tool_call(name, kwargs) content = result.get("content", []) parts = [c.get("text", "") for c in content if c.get("type") == "text"] text = "\n".join(parts) @@ -194,18 +242,11 @@ def on_system_modules(self, _modules: list[RPCClient]) -> None: ) self._thread.start() - self._start_sse_listener() - @rpc def stop(self) -> None: self._stop_event.set() - with self._lock: - if self._sse_client is not None: - self._sse_client.close() if self._thread.is_alive(): self._thread.join(timeout=2.0) - if self._sse_thread is not None and self._sse_thread.is_alive(): - self._sse_thread.join(timeout=2.0) self._http_client.close() super().stop() @@ -254,7 +295,7 @@ def dispatch_continuation( tool_args[key] = continuation_context[context_key] try: - result = self._mcp_request("tools/call", {"name": tool_name, "arguments": tool_args}) + result = self._mcp_tool_call(tool_name, tool_args) content = result.get("content", []) parts = [c.get("text", "") for c in content if c.get("type") == "text"] text = "\n".join(parts) @@ -302,53 +343,6 @@ def _process_message( if self._message_queue.empty(): self.agent_idle.publish(True) - def _start_sse_listener(self) -> None: - """Connect to the MCP server SSE endpoint to receive tool stream updates.""" - self._sse_thread = Thread(target=self._sse_loop, name="McpClient-SSE", daemon=True) - self._sse_thread.start() - - def _sse_loop(self) -> None: - base_url = self.config.mcp_server_url.rsplit("/mcp", 1)[0] - sse_url = f"{base_url}/mcp/streams" - - while not self._stop_event.is_set(): - try: - self._sse_connect(sse_url) - except Exception: - if not self._stop_event.is_set(): - # Try reconnecting after a short delay - time.sleep(1.0) - - def _sse_connect(self, sse_url: str) -> None: - client = httpx.Client(timeout=None) - with self._lock: - self._sse_client = client - try: - with client.stream("GET", sse_url) as response: - self._sse_consume(response) - finally: - with self._lock: - self._sse_client = None - client.close() - - def _sse_consume(self, response: httpx.Response) -> None: - for line in response.iter_lines(): - if self._stop_event.is_set(): - return - if not line.startswith("data: "): - continue - try: - data = json.loads(line[6:]) - except json.JSONDecodeError: - continue - event = ToolStreamEvent(**data) - if event.type == "update": - self._message_queue.put( - HumanMessage( - content=f"[Tool stream update from '{event.tool_name}']: {event.text}" - ) - ) - def _append_image_to_history( mcp_client: McpClient, func_name: str, uuid_: str, result: Any diff --git a/dimos/agents/mcp/mcp_server.py b/dimos/agents/mcp/mcp_server.py index 2431bf8231..439a51ec95 100644 --- a/dimos/agents/mcp/mcp_server.py +++ b/dimos/agents/mcp/mcp_server.py @@ -16,6 +16,7 @@ import asyncio from collections.abc import AsyncGenerator import concurrent.futures +from dataclasses import dataclass import json import os import time @@ -30,6 +31,7 @@ import uvicorn from dimos.agents.annotation import skill +from dimos.agents.mcp.tool_stream import ToolStreamEvent from dimos.core.core import rpc from dimos.core.module import Module from dimos.core.rpc_client import RpcCall, RPCClient @@ -66,12 +68,20 @@ def _jsonrpc_error(req_id: Any, code: int, message: str) -> dict[str, Any]: return {"jsonrpc": "2.0", "id": req_id, "error": {"code": code, "message": message}} +@dataclass(frozen=True, slots=True) +class _StreamingToolResult: + """Marker returned when a tool starts a background stream.""" + + req_id: Any + tool_name: str + + def _handle_initialize(req_id: Any) -> dict[str, Any]: return _jsonrpc_result( req_id, { "protocolVersion": "2025-11-25", - "capabilities": {"tools": {}}, + "capabilities": {"tools": {}, "logging": {}}, "serverInfo": {"name": "dimensional", "version": "1.0.0"}, }, ) @@ -94,7 +104,7 @@ def _handle_tools_list(req_id: Any, skills: list[SkillInfo]) -> dict[str, Any]: async def _handle_tools_call( req_id: Any, params: dict[str, Any], rpc_calls: dict[str, Any] -) -> dict[str, Any]: +) -> dict[str, Any] | _StreamingToolResult: name = params.get("name", "") args: dict[str, Any] = params.get("arguments") or {} @@ -115,8 +125,8 @@ async def _handle_tools_call( duration = f"{time.monotonic() - t0:.3f}s" if result is None: - logger.info("MCP tool done (async)", tool=name, duration=duration) - return _jsonrpc_result_text(req_id, "It has started. You will be updated later.") + logger.info("MCP tool streaming", tool=name, duration=duration) + return _StreamingToolResult(req_id=req_id, tool_name=name) response = str(result)[:200] if hasattr(result, "agent_encode"): @@ -131,7 +141,7 @@ async def handle_request( request: dict[str, Any], skills: list[SkillInfo], rpc_calls: dict[str, Any], -) -> dict[str, Any] | None: +) -> dict[str, Any] | _StreamingToolResult | None: """Handle a single MCP JSON-RPC request. Returns None for JSON-RPC notifications (no ``id``), which must not @@ -165,27 +175,101 @@ async def mcp_endpoint(request: Request) -> Response: {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}, status_code=400, ) + + # Pre-register a queue for tool stream events when the client accepts SSE. + accept = request.headers.get("accept", "") + is_tool_call = body.get("method") == "tools/call" + client_accepts_sse = "text/event-stream" in accept + + stream_queue: asyncio.Queue[dict[str, Any]] | None = None + if is_tool_call and client_accepts_sse: + stream_queue = asyncio.Queue() + app.state.sse_queues.append(stream_queue) + result = await handle_request(body, request.app.state.skills, request.app.state.rpc_calls) + + # Streaming tool: return SSE response if client supports it. + if isinstance(result, _StreamingToolResult): + if stream_queue is not None: + return _streaming_tool_response(result, stream_queue) + # Client doesn't support SSE — fall back to immediate JSON. + return JSONResponse( + _jsonrpc_result_text(result.req_id, "It has started. You will be updated later.") + ) + + # Non-streaming: remove the pre-registered queue if any. + if stream_queue is not None: + try: + app.state.sse_queues.remove(stream_queue) + except ValueError: + pass + if result is None: return Response(status_code=204) return JSONResponse(result) -@app.get("/mcp/streams") -async def streams_sse_endpoint() -> StreamingResponse: - """Server-Sent Events endpoint for tool stream updates. +_STREAM_TIMEOUT = 300.0 # seconds + + +def _sse_event(data: dict[str, Any]) -> str: + """Format a JSON-RPC message as an SSE ``event: message`` frame.""" + return f"event: message\ndata: {json.dumps(data)}\n\n" + + +def _streaming_tool_response( + streaming: _StreamingToolResult, + queue: asyncio.Queue[dict[str, Any]], +) -> StreamingResponse: + """Build an SSE response that forwards ToolStream events as MCP log notifications. - Clients subscribe here to receive real-time updates from long-running - skills that use ``ToolStream``. + The response streams ``notifications/message`` for each update and ends + with the JSON-RPC result carrying the accumulated text. """ - queue: asyncio.Queue[dict[str, Any]] = asyncio.Queue() - app.state.sse_queues.append(queue) async def event_generator() -> AsyncGenerator[str, None]: + stream_id: str | None = None + collected: list[str] = [] try: while True: - data = await queue.get() - yield f"data: {json.dumps(data)}\n\n" + try: + data = await asyncio.wait_for(queue.get(), timeout=_STREAM_TIMEOUT) + except asyncio.TimeoutError: + text = "\n".join(collected) if collected else "No updates received." + yield _sse_event(_jsonrpc_result_text(streaming.req_id, text)) + return + + try: + event = ToolStreamEvent(**data) + except (TypeError, KeyError): + continue + + # Filter: match by tool_name, lock onto the first stream_id seen. + if event.tool_name != streaming.tool_name: + continue + if stream_id is None: + stream_id = event.stream_id + elif event.stream_id != stream_id: + continue + + if event.type == "update" and event.text: + collected.append(event.text) + yield _sse_event( + { + "jsonrpc": "2.0", + "method": "notifications/message", + "params": { + "level": "info", + "logger": event.tool_name, + "data": event.text, + }, + } + ) + + elif event.type == "close": + text = "\n".join(collected) if collected else "Stream completed." + yield _sse_event(_jsonrpc_result_text(streaming.req_id, text)) + return except asyncio.CancelledError: pass finally: diff --git a/dimos/agents/mcp/test_mcp_client_unit.py b/dimos/agents/mcp/test_mcp_client_unit.py index 8cd888f851..bc1a4bfcaa 100644 --- a/dimos/agents/mcp/test_mcp_client_unit.py +++ b/dimos/agents/mcp/test_mcp_client_unit.py @@ -91,11 +91,28 @@ def _mock_post(url: str, **kwargs: object) -> MagicMock: return resp +def _mock_stream(_method: str, _url: str, **kwargs: object) -> MagicMock: + """Wrap ``_mock_post`` result as an httpx streaming context manager.""" + post_resp = _mock_post(_url, **kwargs) + json_bytes = json.dumps(post_resp.json.return_value).encode() + + resp = MagicMock() + resp.raise_for_status = MagicMock() + resp.headers = {"content-type": "application/json"} + resp.read.return_value = json_bytes + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=resp) + cm.__exit__ = MagicMock(return_value=False) + return cm + + @pytest.fixture def mcp_client() -> McpClient: """Build an McpClient wired to the mock MCP post handler.""" mock_http = MagicMock() mock_http.post.side_effect = _mock_post + mock_http.stream.side_effect = _mock_stream with patch("dimos.agents.mcp.mcp_client.httpx.Client", return_value=mock_http): client = McpClient.__new__(McpClient) diff --git a/dimos/agents/mcp/test_tool_stream.py b/dimos/agents/mcp/test_tool_stream.py index a1ea103091..64cd1b08af 100644 --- a/dimos/agents/mcp/test_tool_stream.py +++ b/dimos/agents/mcp/test_tool_stream.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from threading import Event, Thread +from threading import Thread import time import httpx @@ -23,7 +23,7 @@ from dimos.agents.annotation import skill from dimos.agents.mcp.mcp_adapter import McpAdapter from dimos.agents.mcp.mcp_server import McpServer -from dimos.agents.mcp.tool_stream import ToolStream, ToolStreamEvent +from dimos.agents.mcp.tool_stream import ToolStream from dimos.core.blueprints import autoconnect from dimos.core.global_config import global_config from dimos.core.module import Module @@ -67,58 +67,44 @@ def mcp_server(): @pytest.mark.slow -def test_tool_stream_sse(mcp_server: McpAdapter) -> None: - """ToolStream updates flow through the HTTP SSE endpoint.""" +def test_tool_stream_inline_sse(mcp_server: McpAdapter) -> None: + """Streaming tool returns inline SSE notifications when client accepts SSE.""" adapter = mcp_server - base_url = adapter.url.rsplit("/mcp", 1)[0] - sse_url = f"{base_url}/mcp/streams" - - events: list[ToolStreamEvent] = [] - sse_connected = Event() - sse_done = Event() - - def sse_reader() -> None: - try: - with httpx.Client(timeout=None) as client: - with client.stream("GET", sse_url) as resp: - sse_connected.set() - for line in resp.iter_lines(): - if not line.startswith("data: "): - continue - event = ToolStreamEvent(**json.loads(line[6:])) - events.append(event) - if event.type == "close": - sse_done.set() - return - except Exception: - pass # connection closed during teardown - - reader = Thread(target=sse_reader, daemon=True) - reader.start() - assert sse_connected.wait(5), "SSE connection was not established" - time.sleep(0.5) # let the server register the SSE queue - - # Call the streaming tool via MCP. - result = adapter.call_tool("start_streaming", {"count": 3}) - assert "It has started" in result["content"][0]["text"] - - # Wait for all SSE events (updates + close). - assert sse_done.wait(10), "Timed out waiting for tool stream close event" - - updates = [e for e in events if e.type == "update"] - closes = [e for e in events if e.type == "close"] - - assert len(updates) == 3 - assert updates[0].text == "Update 1 of 3" - assert updates[1].text == "Update 2 of 3" - assert updates[2].text == "Update 3 of 3" - - assert len(closes) == 1 - - # All events share the same stream id and tool name. - stream_ids = {e.stream_id for e in events} - assert len(stream_ids) == 1 - assert all(e.tool_name == "start_streaming" for e in events) + adapter.initialize() + + body = { + "jsonrpc": "2.0", + "id": 42, + "method": "tools/call", + "params": {"name": "start_streaming", "arguments": {"count": 3}}, + } + + events = [] + with httpx.Client(timeout=30.0) as client: + with client.stream( + "POST", + adapter.url, + json=body, + headers={"Accept": "application/json, text/event-stream"}, + ) as response: + assert response.headers["content-type"].startswith("text/event-stream") + for line in response.iter_lines(): + if line.startswith("data: "): + events.append(json.loads(line[6:])) + + notifications = [e for e in events if e.get("method") == "notifications/message"] + results = [e for e in events if "result" in e] + + assert len(notifications) == 3 + assert notifications[0]["params"]["data"] == "Update 1 of 3" + assert notifications[1]["params"]["data"] == "Update 2 of 3" + assert notifications[2]["params"]["data"] == "Update 3 of 3" + + assert len(results) == 1 + assert results[0]["id"] == 42 + content_text = results[0]["result"]["content"][0]["text"] + assert "Update 1 of 3" in content_text + assert "Update 3 of 3" in content_text @pytest.mark.slow