From 828f1e7faa039e7563717c8b6303403ef57aa6cb Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Thu, 19 Feb 2026 09:17:24 +0000 Subject: [PATCH 1/4] feat: add McpPlugin protocol and entry_points discovery Plugins register via entry_points group 'ros2_medkit_mcp.plugins'. Plugin tools merged into list_tools() and call_tool() dispatch. Startup/shutdown lifecycle in server_stdio and server_http. Broken plugins logged and skipped. --- src/ros2_medkit_mcp/mcp_app.py | 28 ++++++++++++--- src/ros2_medkit_mcp/plugin.py | 50 ++++++++++++++++++++++++++ src/ros2_medkit_mcp/server_http.py | 17 ++++++++- src/ros2_medkit_mcp/server_stdio.py | 18 +++++++++- tests/test_plugin_discovery.py | 56 +++++++++++++++++++++++++++++ 5 files changed, 163 insertions(+), 6 deletions(-) create mode 100644 src/ros2_medkit_mcp/plugin.py create mode 100644 tests/test_plugin_discovery.py diff --git a/src/ros2_medkit_mcp/mcp_app.py b/src/ros2_medkit_mcp/mcp_app.py index 9c89c05..98521ca 100644 --- a/src/ros2_medkit_mcp/mcp_app.py +++ b/src/ros2_medkit_mcp/mcp_app.py @@ -14,6 +14,7 @@ from ros2_medkit_mcp.client import SovdClient, SovdClientError from ros2_medkit_mcp.config import Settings +from ros2_medkit_mcp.plugin import McpPlugin from ros2_medkit_mcp.models import ( AppIdArgs, AreaComponentsArgs, @@ -630,18 +631,19 @@ async def download_rosbags_for_fault( } -def register_tools(server: Server, client: SovdClient) -> None: +def register_tools(server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None: """Register all MCP tools on the server. Args: server: The MCP server to register tools on. client: The SOVD client for making API calls. + plugins: Optional list of plugins providing additional tools. """ @server.list_tools() async def list_tools() -> list[Tool]: """List available tools.""" - return [ + tools = [ # ==================== Discovery ==================== Tool( name="sovd_version", @@ -1491,6 +1493,14 @@ async def list_tools() -> list[Tool]: }, ), ] + # Append plugin tools + if plugins: + for plugin in plugins: + try: + tools.extend(plugin.list_tools()) + except Exception: + logger.exception("Failed to list tools from plugin: %s", plugin.name) + return tools @server.call_tool() async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: @@ -1794,6 +1804,15 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: ) else: + # Try plugins before reporting unknown tool + if plugins: + for plugin in plugins: + try: + plugin_tool_names = {t.name for t in plugin.list_tools()} + if normalized_name in plugin_tool_names: + return await plugin.call_tool(normalized_name, arguments) + except Exception: + logger.exception("Plugin %s failed to handle tool %s", plugin.name, normalized_name) return format_error(f"Unknown tool: {name}") except SovdClientError as e: @@ -1849,15 +1868,16 @@ async def read_resource(uri: str) -> list[TextContent]: raise ValueError(f"Unknown resource URI: {uri}") -def setup_mcp_app(server: Server, settings: Settings, client: SovdClient) -> None: +def setup_mcp_app(server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None: """Set up the complete MCP application. Args: server: The MCP server to configure. settings: Application settings. client: The SOVD client for API calls. + plugins: Optional list of plugins providing additional tools. """ - register_tools(server, client) + register_tools(server, client, plugins=plugins) register_resources(server) logger.info( "MCP server configured for %s", diff --git a/src/ros2_medkit_mcp/plugin.py b/src/ros2_medkit_mcp/plugin.py new file mode 100644 index 0000000..8d06fc7 --- /dev/null +++ b/src/ros2_medkit_mcp/plugin.py @@ -0,0 +1,50 @@ +"""Plugin interface for ros2_medkit_mcp. + +Third-party packages can register as plugins via entry_points: + + [project.entry-points."ros2_medkit_mcp.plugins"] + my_plugin = "my_package.plugin:MyPlugin" + +Plugins must implement the McpPlugin protocol. +""" + +from __future__ import annotations + +import logging +from importlib.metadata import entry_points +from typing import Any, Protocol + +from mcp.types import TextContent, Tool + +logger = logging.getLogger(__name__) + +PLUGIN_GROUP = "ros2_medkit_mcp.plugins" + + +class McpPlugin(Protocol): + """Interface for MCP server plugins.""" + + @property + def name(self) -> str: ... + + def list_tools(self) -> list[Tool]: ... + + async def call_tool(self, name: str, arguments: dict[str, Any]) -> list[TextContent]: ... + + async def startup(self) -> None: ... + + async def shutdown(self) -> None: ... + + +def discover_plugins() -> list[McpPlugin]: + """Discover and instantiate plugins registered via entry_points.""" + plugins: list[McpPlugin] = [] + for ep in entry_points(group=PLUGIN_GROUP): + try: + plugin_cls = ep.load() + plugin = plugin_cls() + logger.info("Discovered plugin: %s (from %s)", plugin.name, ep.value) + plugins.append(plugin) + except Exception: + logger.exception("Failed to load plugin: %s", ep.name) + return plugins diff --git a/src/ros2_medkit_mcp/server_http.py b/src/ros2_medkit_mcp/server_http.py index bead326..1c83f4f 100644 --- a/src/ros2_medkit_mcp/server_http.py +++ b/src/ros2_medkit_mcp/server_http.py @@ -18,6 +18,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app +from ros2_medkit_mcp.plugin import discover_plugins # Configure logging logging.basicConfig( @@ -36,7 +37,8 @@ def create_app() -> Starlette: settings = get_settings() mcp_server = create_mcp_server() client = SovdClient(settings) - setup_mcp_app(mcp_server, settings, client) + plugins = discover_plugins() + setup_mcp_app(mcp_server, settings, client, plugins=plugins) # Create SSE transport - path is where clients POST messages sse_transport = SseServerTransport("/mcp/messages/") @@ -81,9 +83,22 @@ async def on_startup() -> None: """Application startup handler.""" logger.info("ros2_medkit MCP server starting (HTTP transport)") logger.info("Connecting to SOVD API at %s", settings.base_url) + # Start plugins + for plugin in plugins: + try: + await plugin.startup() + logger.info("Plugin started: %s", plugin.name) + except Exception: + logger.exception("Failed to start plugin: %s", plugin.name) async def on_shutdown() -> None: """Application shutdown handler.""" + # Shutdown plugins + for plugin in plugins: + try: + await plugin.shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) await client.close() logger.info("Server shutdown complete") diff --git a/src/ros2_medkit_mcp/server_stdio.py b/src/ros2_medkit_mcp/server_stdio.py index 1f32e6c..79c8b99 100644 --- a/src/ros2_medkit_mcp/server_stdio.py +++ b/src/ros2_medkit_mcp/server_stdio.py @@ -13,6 +13,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app +from ros2_medkit_mcp.plugin import discover_plugins # Configure logging to stderr to avoid interfering with stdio transport logging.basicConfig( @@ -31,9 +32,18 @@ async def run_server() -> None: server = create_mcp_server() client = SovdClient(settings) + plugins = discover_plugins() try: - setup_mcp_app(server, settings, client) + # Start plugins + for plugin in plugins: + try: + await plugin.startup() + logger.info("Plugin started: %s", plugin.name) + except Exception: + logger.exception("Failed to start plugin: %s", plugin.name) + + setup_mcp_app(server, settings, client, plugins=plugins) async with stdio_server() as (read_stream, write_stream): await server.run( @@ -42,6 +52,12 @@ async def run_server() -> None: server.create_initialization_options(), ) finally: + # Shutdown plugins + for plugin in plugins: + try: + await plugin.shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) await client.close() logger.info("Server shutdown complete") diff --git a/tests/test_plugin_discovery.py b/tests/test_plugin_discovery.py new file mode 100644 index 0000000..8c25de1 --- /dev/null +++ b/tests/test_plugin_discovery.py @@ -0,0 +1,56 @@ +"""Tests for MCP plugin discovery and integration.""" + +from unittest.mock import MagicMock, patch + +import pytest +from mcp.types import TextContent, Tool + +from ros2_medkit_mcp.plugin import discover_plugins + + +class FakePlugin: + @property + def name(self) -> str: + return "fake" + + def list_tools(self) -> list[Tool]: + return [Tool(name="fake_tool", description="A fake tool", inputSchema={"type": "object", "properties": {}})] + + async def call_tool(self, name: str, arguments: dict) -> list[TextContent]: + if name == "fake_tool": + return [TextContent(type="text", text="fake result")] + raise ValueError(f"Unknown tool: {name}") + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + +class TestDiscoverPlugins: + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_discovers_installed_plugins(self, mock_eps: MagicMock) -> None: + mock_ep = MagicMock() + mock_ep.name = "fake" + mock_ep.value = "fake_package.plugin:FakePlugin" + mock_ep.load.return_value = FakePlugin + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert len(plugins) == 1 + assert plugins[0].name == "fake" + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_no_plugins_installed(self, mock_eps: MagicMock) -> None: + mock_eps.return_value = [] + plugins = discover_plugins() + assert plugins == [] + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_broken_plugin_skipped(self, mock_eps: MagicMock) -> None: + mock_ep = MagicMock() + mock_ep.name = "broken" + mock_ep.load.side_effect = ImportError("no such module") + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert plugins == [] From f84d24dad8394d956d068c424e50004fe78ff573 Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Thu, 19 Feb 2026 17:17:32 +0000 Subject: [PATCH 2/4] fix: address review feedback on plugin system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix ruff lint errors (import order, unused import, type annotation) - Build tool→plugin map at registration time for O(1) dispatch - Make plugin dispatch authoritative (first declaring plugin owns it) - Track successfully started plugins; only shutdown those on exit --- src/ros2_medkit_mcp/mcp_app.py | 30 ++++++++++++++++------------- src/ros2_medkit_mcp/server_http.py | 7 +++++-- src/ros2_medkit_mcp/server_stdio.py | 6 ++++-- tests/test_plugin_discovery.py | 14 ++++++++++---- 4 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/ros2_medkit_mcp/mcp_app.py b/src/ros2_medkit_mcp/mcp_app.py index 98521ca..65cc6e3 100644 --- a/src/ros2_medkit_mcp/mcp_app.py +++ b/src/ros2_medkit_mcp/mcp_app.py @@ -14,7 +14,6 @@ from ros2_medkit_mcp.client import SovdClient, SovdClientError from ros2_medkit_mcp.config import Settings -from ros2_medkit_mcp.plugin import McpPlugin from ros2_medkit_mcp.models import ( AppIdArgs, AreaComponentsArgs, @@ -59,6 +58,7 @@ UpdateExecutionArgs, filter_entities, ) +from ros2_medkit_mcp.plugin import McpPlugin logger = logging.getLogger(__name__) @@ -631,7 +631,9 @@ async def download_rosbags_for_fault( } -def register_tools(server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None: +def register_tools( + server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None +) -> None: """Register all MCP tools on the server. Args: @@ -639,6 +641,8 @@ def register_tools(server: Server, client: SovdClient, plugins: list[McpPlugin] client: The SOVD client for making API calls. plugins: Optional list of plugins providing additional tools. """ + # Tool name → plugin mapping, built during list_tools and used for dispatch + plugin_tool_map: dict[str, McpPlugin] = {} @server.list_tools() async def list_tools() -> list[Tool]: @@ -1497,7 +1501,10 @@ async def list_tools() -> list[Tool]: if plugins: for plugin in plugins: try: - tools.extend(plugin.list_tools()) + plugin_tools = plugin.list_tools() + tools.extend(plugin_tools) + for t in plugin_tools: + plugin_tool_map[t.name] = plugin except Exception: logger.exception("Failed to list tools from plugin: %s", plugin.name) return tools @@ -1804,15 +1811,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]: ) else: - # Try plugins before reporting unknown tool - if plugins: - for plugin in plugins: - try: - plugin_tool_names = {t.name for t in plugin.list_tools()} - if normalized_name in plugin_tool_names: - return await plugin.call_tool(normalized_name, arguments) - except Exception: - logger.exception("Plugin %s failed to handle tool %s", plugin.name, normalized_name) + # Check plugin tool map before reporting unknown tool + plugin = plugin_tool_map.get(normalized_name) + if plugin is not None: + return await plugin.call_tool(normalized_name, arguments) return format_error(f"Unknown tool: {name}") except SovdClientError as e: @@ -1868,7 +1870,9 @@ async def read_resource(uri: str) -> list[TextContent]: raise ValueError(f"Unknown resource URI: {uri}") -def setup_mcp_app(server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None) -> None: +def setup_mcp_app( + server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None +) -> None: """Set up the complete MCP application. Args: diff --git a/src/ros2_medkit_mcp/server_http.py b/src/ros2_medkit_mcp/server_http.py index 1c83f4f..e8c73d3 100644 --- a/src/ros2_medkit_mcp/server_http.py +++ b/src/ros2_medkit_mcp/server_http.py @@ -79,6 +79,8 @@ async def health_check(_request: Request) -> JSONResponse: } ) + started_plugins: list = [] + async def on_startup() -> None: """Application startup handler.""" logger.info("ros2_medkit MCP server starting (HTTP transport)") @@ -87,14 +89,15 @@ async def on_startup() -> None: for plugin in plugins: try: await plugin.startup() + started_plugins.append(plugin) logger.info("Plugin started: %s", plugin.name) except Exception: logger.exception("Failed to start plugin: %s", plugin.name) async def on_shutdown() -> None: """Application shutdown handler.""" - # Shutdown plugins - for plugin in plugins: + # Only shutdown plugins that started successfully + for plugin in started_plugins: try: await plugin.shutdown() except Exception: diff --git a/src/ros2_medkit_mcp/server_stdio.py b/src/ros2_medkit_mcp/server_stdio.py index 79c8b99..2c927a8 100644 --- a/src/ros2_medkit_mcp/server_stdio.py +++ b/src/ros2_medkit_mcp/server_stdio.py @@ -34,11 +34,13 @@ async def run_server() -> None: client = SovdClient(settings) plugins = discover_plugins() + started_plugins = [] try: # Start plugins for plugin in plugins: try: await plugin.startup() + started_plugins.append(plugin) logger.info("Plugin started: %s", plugin.name) except Exception: logger.exception("Failed to start plugin: %s", plugin.name) @@ -52,8 +54,8 @@ async def run_server() -> None: server.create_initialization_options(), ) finally: - # Shutdown plugins - for plugin in plugins: + # Only shutdown plugins that started successfully + for plugin in started_plugins: try: await plugin.shutdown() except Exception: diff --git a/tests/test_plugin_discovery.py b/tests/test_plugin_discovery.py index 8c25de1..cde7b24 100644 --- a/tests/test_plugin_discovery.py +++ b/tests/test_plugin_discovery.py @@ -1,8 +1,8 @@ """Tests for MCP plugin discovery and integration.""" +from typing import Any from unittest.mock import MagicMock, patch -import pytest from mcp.types import TextContent, Tool from ros2_medkit_mcp.plugin import discover_plugins @@ -14,9 +14,15 @@ def name(self) -> str: return "fake" def list_tools(self) -> list[Tool]: - return [Tool(name="fake_tool", description="A fake tool", inputSchema={"type": "object", "properties": {}})] - - async def call_tool(self, name: str, arguments: dict) -> list[TextContent]: + return [ + Tool( + name="fake_tool", + description="A fake tool", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, name: str, _arguments: dict[str, Any]) -> list[TextContent]: if name == "fake_tool": return [TextContent(type="text", text="fake result")] raise ValueError(f"Unknown tool: {name}") From f29d277915a162b61fe7e134ddbb025e3cf623d3 Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Thu, 19 Feb 2026 17:33:00 +0000 Subject: [PATCH 3/4] fix: treat empty bearer token env var as unset --- src/ros2_medkit_mcp/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/ros2_medkit_mcp/config.py b/src/ros2_medkit_mcp/config.py index 659721f..842c6a6 100644 --- a/src/ros2_medkit_mcp/config.py +++ b/src/ros2_medkit_mcp/config.py @@ -27,7 +27,7 @@ class Settings(BaseModel): description="Base URL of the ros2_medkit SOVD API", ) bearer_token: str | None = Field( - default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN"), + default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN") or None, description="Optional Bearer token for authentication", ) timeout_seconds: float = Field( From deeb7f45fae39de1017aa8060a4a06b03d33becd Mon Sep 17 00:00:00 2001 From: Bartosz Burda Date: Sun, 22 Feb 2026 08:00:12 +0000 Subject: [PATCH 4/4] fix: address review feedback - collision check, lifecycle helpers, protocol validation - Add tool name collision detection against built-in tools (TOOL_ALIASES) and inter-plugin duplicates during plugin tool registration - Extract start_plugins()/shutdown_plugins() helpers to plugin.py, removing duplicated loops from server_stdio.py and server_http.py - Add hasattr check in discover_plugins() to skip non-conforming plugins - Pass only started_plugins to setup_mcp_app (not all discovered) - Move setup_mcp_app into on_startup in server_http.py so plugin tools are only registered after successful startup - Add tests for dispatch, collision detection, and lifecycle helpers --- src/ros2_medkit_mcp/mcp_app.py | 16 ++- src/ros2_medkit_mcp/plugin.py | 25 ++++ src/ros2_medkit_mcp/server_http.py | 23 +-- src/ros2_medkit_mcp/server_stdio.py | 22 +-- tests/test_plugin_discovery.py | 215 +++++++++++++++++++++++++++- 5 files changed, 263 insertions(+), 38 deletions(-) diff --git a/src/ros2_medkit_mcp/mcp_app.py b/src/ros2_medkit_mcp/mcp_app.py index 65cc6e3..1d8d5b6 100644 --- a/src/ros2_medkit_mcp/mcp_app.py +++ b/src/ros2_medkit_mcp/mcp_app.py @@ -1502,8 +1502,22 @@ async def list_tools() -> list[Tool]: for plugin in plugins: try: plugin_tools = plugin.list_tools() - tools.extend(plugin_tools) for t in plugin_tools: + if t.name in TOOL_ALIASES: + logger.warning( + "Plugin %s: tool '%s' collides with built-in tool, skipping", + plugin.name, + t.name, + ) + continue + if t.name in plugin_tool_map: + logger.warning( + "Plugin %s: tool '%s' collides with another plugin tool, skipping", + plugin.name, + t.name, + ) + continue + tools.append(t) plugin_tool_map[t.name] = plugin except Exception: logger.exception("Failed to list tools from plugin: %s", plugin.name) diff --git a/src/ros2_medkit_mcp/plugin.py b/src/ros2_medkit_mcp/plugin.py index 8d06fc7..9917350 100644 --- a/src/ros2_medkit_mcp/plugin.py +++ b/src/ros2_medkit_mcp/plugin.py @@ -43,8 +43,33 @@ def discover_plugins() -> list[McpPlugin]: try: plugin_cls = ep.load() plugin = plugin_cls() + if not hasattr(plugin, "name") or not hasattr(plugin, "list_tools"): + logger.warning("Plugin %s does not implement McpPlugin, skipping", ep.name) + continue logger.info("Discovered plugin: %s (from %s)", plugin.name, ep.value) plugins.append(plugin) except Exception: logger.exception("Failed to load plugin: %s", ep.name) return plugins + + +async def start_plugins(plugins: list[McpPlugin]) -> list[McpPlugin]: + """Start plugins, returning only those that started successfully.""" + started: list[McpPlugin] = [] + for plugin in plugins: + try: + await plugin.startup() + started.append(plugin) + logger.info("Plugin started: %s", plugin.name) + except Exception: + logger.exception("Failed to start plugin: %s", plugin.name) + return started + + +async def shutdown_plugins(plugins: list[McpPlugin]) -> None: + """Shut down plugins, logging errors without raising.""" + for plugin in plugins: + try: + await plugin.shutdown() + except Exception: + logger.exception("Failed to shutdown plugin: %s", plugin.name) diff --git a/src/ros2_medkit_mcp/server_http.py b/src/ros2_medkit_mcp/server_http.py index e8c73d3..2691bca 100644 --- a/src/ros2_medkit_mcp/server_http.py +++ b/src/ros2_medkit_mcp/server_http.py @@ -18,7 +18,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app -from ros2_medkit_mcp.plugin import discover_plugins +from ros2_medkit_mcp.plugin import McpPlugin, discover_plugins, shutdown_plugins, start_plugins # Configure logging logging.basicConfig( @@ -38,7 +38,6 @@ def create_app() -> Starlette: mcp_server = create_mcp_server() client = SovdClient(settings) plugins = discover_plugins() - setup_mcp_app(mcp_server, settings, client, plugins=plugins) # Create SSE transport - path is where clients POST messages sse_transport = SseServerTransport("/mcp/messages/") @@ -79,29 +78,19 @@ async def health_check(_request: Request) -> JSONResponse: } ) - started_plugins: list = [] + started_plugins: list[McpPlugin] = [] async def on_startup() -> None: """Application startup handler.""" logger.info("ros2_medkit MCP server starting (HTTP transport)") logger.info("Connecting to SOVD API at %s", settings.base_url) - # Start plugins - for plugin in plugins: - try: - await plugin.startup() - started_plugins.append(plugin) - logger.info("Plugin started: %s", plugin.name) - except Exception: - logger.exception("Failed to start plugin: %s", plugin.name) + started = await start_plugins(plugins) + started_plugins.extend(started) + setup_mcp_app(mcp_server, settings, client, plugins=started_plugins) async def on_shutdown() -> None: """Application shutdown handler.""" - # Only shutdown plugins that started successfully - for plugin in started_plugins: - try: - await plugin.shutdown() - except Exception: - logger.exception("Failed to shutdown plugin: %s", plugin.name) + await shutdown_plugins(started_plugins) await client.close() logger.info("Server shutdown complete") diff --git a/src/ros2_medkit_mcp/server_stdio.py b/src/ros2_medkit_mcp/server_stdio.py index 2c927a8..ce183b5 100644 --- a/src/ros2_medkit_mcp/server_stdio.py +++ b/src/ros2_medkit_mcp/server_stdio.py @@ -13,7 +13,7 @@ from ros2_medkit_mcp.client import SovdClient from ros2_medkit_mcp.config import get_settings from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app -from ros2_medkit_mcp.plugin import discover_plugins +from ros2_medkit_mcp.plugin import discover_plugins, shutdown_plugins, start_plugins # Configure logging to stderr to avoid interfering with stdio transport logging.basicConfig( @@ -34,18 +34,9 @@ async def run_server() -> None: client = SovdClient(settings) plugins = discover_plugins() - started_plugins = [] + started_plugins = await start_plugins(plugins) try: - # Start plugins - for plugin in plugins: - try: - await plugin.startup() - started_plugins.append(plugin) - logger.info("Plugin started: %s", plugin.name) - except Exception: - logger.exception("Failed to start plugin: %s", plugin.name) - - setup_mcp_app(server, settings, client, plugins=plugins) + setup_mcp_app(server, settings, client, plugins=started_plugins) async with stdio_server() as (read_stream, write_stream): await server.run( @@ -54,12 +45,7 @@ async def run_server() -> None: server.create_initialization_options(), ) finally: - # Only shutdown plugins that started successfully - for plugin in started_plugins: - try: - await plugin.shutdown() - except Exception: - logger.exception("Failed to shutdown plugin: %s", plugin.name) + await shutdown_plugins(started_plugins) await client.close() logger.info("Server shutdown complete") diff --git a/tests/test_plugin_discovery.py b/tests/test_plugin_discovery.py index cde7b24..0a257b9 100644 --- a/tests/test_plugin_discovery.py +++ b/tests/test_plugin_discovery.py @@ -1,11 +1,13 @@ """Tests for MCP plugin discovery and integration.""" +import logging from typing import Any -from unittest.mock import MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, patch +import pytest from mcp.types import TextContent, Tool -from ros2_medkit_mcp.plugin import discover_plugins +from ros2_medkit_mcp.plugin import discover_plugins, shutdown_plugins, start_plugins class FakePlugin: @@ -60,3 +62,212 @@ def test_broken_plugin_skipped(self, mock_eps: MagicMock) -> None: mock_eps.return_value = [mock_ep] plugins = discover_plugins() assert plugins == [] + + @patch("ros2_medkit_mcp.plugin.entry_points") + def test_non_conforming_plugin_skipped(self, mock_eps: MagicMock) -> None: + """Plugin without required attributes is skipped with warning.""" + + class BadPlugin: + pass + + mock_ep = MagicMock() + mock_ep.name = "bad" + mock_ep.load.return_value = BadPlugin + mock_eps.return_value = [mock_ep] + plugins = discover_plugins() + assert plugins == [] + + +class TestPluginLifecycle: + @pytest.mark.asyncio + async def test_start_plugins_returns_started(self) -> None: + plugin = FakePlugin() + started = await start_plugins([plugin]) + assert len(started) == 1 + assert started[0] is plugin + + @pytest.mark.asyncio + async def test_start_plugins_skips_failed(self) -> None: + good = FakePlugin() + bad = MagicMock() + bad.name = "bad" + bad.startup = AsyncMock(side_effect=RuntimeError("init failed")) + started = await start_plugins([good, bad]) + assert len(started) == 1 + assert started[0] is good + + @pytest.mark.asyncio + async def test_shutdown_plugins_calls_all(self) -> None: + p1 = MagicMock() + p1.name = "p1" + p1.shutdown = AsyncMock() + p2 = MagicMock() + p2.name = "p2" + p2.shutdown = AsyncMock() + await shutdown_plugins([p1, p2]) + p1.shutdown.assert_awaited_once() + p2.shutdown.assert_awaited_once() + + @pytest.mark.asyncio + async def test_shutdown_plugins_continues_on_error(self) -> None: + p1 = MagicMock() + p1.name = "p1" + p1.shutdown = AsyncMock(side_effect=RuntimeError("boom")) + p2 = MagicMock() + p2.name = "p2" + p2.shutdown = AsyncMock() + await shutdown_plugins([p1, p2]) + p2.shutdown.assert_awaited_once() + + +class TestPluginToolRegistration: + """Tests for plugin tool registration and dispatch in mcp_app.register_tools.""" + + def _make_server_mock(self) -> tuple[MagicMock, dict[str, Any]]: + """Create a mock Server that captures registered handlers.""" + server = MagicMock() + handlers: dict[str, Any] = {} + + def list_tools_decorator(): + def wrapper(fn: Any) -> Any: + handlers["list_tools"] = fn + return fn + + return wrapper + + def call_tool_decorator(): + def wrapper(fn: Any) -> Any: + handlers["call_tool"] = fn + return fn + + return wrapper + + server.list_tools = list_tools_decorator + server.call_tool = call_tool_decorator + return server, handlers + + @pytest.mark.asyncio + async def test_plugin_tool_dispatch(self) -> None: + """Plugin tools are dispatched via plugin_tool_map.""" + from ros2_medkit_mcp.mcp_app import register_tools + + server, handlers = self._make_server_mock() + client = MagicMock() + plugin = FakePlugin() + + register_tools(server, client, plugins=[plugin]) + + # list_tools should include the plugin tool + tools = await handlers["list_tools"]() + tool_names = {t.name for t in tools} + assert "fake_tool" in tool_names + + # call_tool should dispatch to plugin + result = await handlers["call_tool"]("fake_tool", {}) + assert len(result) == 1 + assert result[0].text == "fake result" + + @pytest.mark.asyncio + async def test_builtin_collision_skipped(self, caplog: pytest.LogCaptureFixture) -> None: + """Plugin tool colliding with built-in is skipped with warning.""" + from ros2_medkit_mcp.mcp_app import register_tools + + class CollidingPlugin: + @property + def name(self) -> str: + return "colliding" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="sovd_health", + description="Collides with built-in", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="should not reach")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + server, handlers = self._make_server_mock() + client = MagicMock() + register_tools(server, client, plugins=[CollidingPlugin()]) + + with caplog.at_level(logging.WARNING): + tools = await handlers["list_tools"]() + + assert "collides with built-in tool" in caplog.text + + # sovd_health should appear exactly once (the built-in) + plugin_tool_names = [t.name for t in tools if t.name == "sovd_health"] + assert len(plugin_tool_names) == 1 + + @pytest.mark.asyncio + async def test_inter_plugin_collision_skipped(self, caplog: pytest.LogCaptureFixture) -> None: + """Second plugin declaring same tool name is skipped.""" + from ros2_medkit_mcp.mcp_app import register_tools + + class PluginA: + @property + def name(self) -> str: + return "plugin_a" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="shared_tool", + description="From A", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="from A")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + class PluginB: + @property + def name(self) -> str: + return "plugin_b" + + def list_tools(self) -> list[Tool]: + return [ + Tool( + name="shared_tool", + description="From B", + inputSchema={"type": "object", "properties": {}}, + ) + ] + + async def call_tool(self, _name: str, _arguments: dict[str, Any]) -> list[TextContent]: + return [TextContent(type="text", text="from B")] + + async def startup(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + server, handlers = self._make_server_mock() + client = MagicMock() + register_tools(server, client, plugins=[PluginA(), PluginB()]) + + with caplog.at_level(logging.WARNING): + await handlers["list_tools"]() + + assert "collides with another plugin tool" in caplog.text + + # Dispatch should go to plugin A (first registered) + result = await handlers["call_tool"]("shared_tool", {}) + assert result[0].text == "from A"