Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/ros2_medkit_mcp/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class Settings(BaseModel):
description="Base URL of the ros2_medkit SOVD API",
)
bearer_token: str | None = Field(
default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN"),
default_factory=lambda: os.getenv("ROS2_MEDKIT_BEARER_TOKEN") or None,
description="Optional Bearer token for authentication",
)
timeout_seconds: float = Field(
Expand Down
46 changes: 42 additions & 4 deletions src/ros2_medkit_mcp/mcp_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
UpdateExecutionArgs,
filter_entities,
)
from ros2_medkit_mcp.plugin import McpPlugin

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -630,18 +631,23 @@ async def download_rosbags_for_fault(
}


def register_tools(server: Server, client: SovdClient) -> None:
def register_tools(
server: Server, client: SovdClient, plugins: list[McpPlugin] | None = None
) -> None:
"""Register all MCP tools on the server.

Args:
server: The MCP server to register tools on.
client: The SOVD client for making API calls.
plugins: Optional list of plugins providing additional tools.
"""
# Tool name → plugin mapping, built during list_tools and used for dispatch
plugin_tool_map: dict[str, McpPlugin] = {}

@server.list_tools()
async def list_tools() -> list[Tool]:
"""List available tools."""
return [
tools = [
# ==================== Discovery ====================
Tool(
name="sovd_version",
Expand Down Expand Up @@ -1491,6 +1497,31 @@ async def list_tools() -> list[Tool]:
},
),
]
# Append plugin tools
if plugins:
for plugin in plugins:
try:
plugin_tools = plugin.list_tools()
for t in plugin_tools:
if t.name in TOOL_ALIASES:
logger.warning(
"Plugin %s: tool '%s' collides with built-in tool, skipping",
plugin.name,
t.name,
)
continue
if t.name in plugin_tool_map:
logger.warning(
"Plugin %s: tool '%s' collides with another plugin tool, skipping",
plugin.name,
t.name,
)
continue
tools.append(t)
plugin_tool_map[t.name] = plugin
except Exception:
logger.exception("Failed to list tools from plugin: %s", plugin.name)
return tools

@server.call_tool()
async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
Expand Down Expand Up @@ -1794,6 +1825,10 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[TextContent]:
)

else:
# Check plugin tool map before reporting unknown tool
plugin = plugin_tool_map.get(normalized_name)
if plugin is not None:
return await plugin.call_tool(normalized_name, arguments)
return format_error(f"Unknown tool: {name}")

except SovdClientError as e:
Expand Down Expand Up @@ -1849,15 +1884,18 @@ async def read_resource(uri: str) -> list[TextContent]:
raise ValueError(f"Unknown resource URI: {uri}")


def setup_mcp_app(server: Server, settings: Settings, client: SovdClient) -> None:
def setup_mcp_app(
server: Server, settings: Settings, client: SovdClient, plugins: list[McpPlugin] | None = None
) -> None:
"""Set up the complete MCP application.

Args:
server: The MCP server to configure.
settings: Application settings.
client: The SOVD client for API calls.
plugins: Optional list of plugins providing additional tools.
"""
register_tools(server, client)
register_tools(server, client, plugins=plugins)
register_resources(server)
logger.info(
"MCP server configured for %s",
Expand Down
75 changes: 75 additions & 0 deletions src/ros2_medkit_mcp/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Plugin interface for ros2_medkit_mcp.

Third-party packages can register as plugins via entry_points:

[project.entry-points."ros2_medkit_mcp.plugins"]
my_plugin = "my_package.plugin:MyPlugin"

Plugins must implement the McpPlugin protocol.
"""

from __future__ import annotations

import logging
from importlib.metadata import entry_points
from typing import Any, Protocol

from mcp.types import TextContent, Tool

logger = logging.getLogger(__name__)

PLUGIN_GROUP = "ros2_medkit_mcp.plugins"


class McpPlugin(Protocol):
"""Interface for MCP server plugins."""

@property
def name(self) -> str: ...

def list_tools(self) -> list[Tool]: ...

async def call_tool(self, name: str, arguments: dict[str, Any]) -> list[TextContent]: ...

async def startup(self) -> None: ...

async def shutdown(self) -> None: ...


def discover_plugins() -> list[McpPlugin]:
"""Discover and instantiate plugins registered via entry_points."""
plugins: list[McpPlugin] = []
for ep in entry_points(group=PLUGIN_GROUP):
try:
plugin_cls = ep.load()
plugin = plugin_cls()
if not hasattr(plugin, "name") or not hasattr(plugin, "list_tools"):
logger.warning("Plugin %s does not implement McpPlugin, skipping", ep.name)
continue
logger.info("Discovered plugin: %s (from %s)", plugin.name, ep.value)
plugins.append(plugin)
except Exception:
logger.exception("Failed to load plugin: %s", ep.name)
return plugins


async def start_plugins(plugins: list[McpPlugin]) -> list[McpPlugin]:
"""Start plugins, returning only those that started successfully."""
started: list[McpPlugin] = []
for plugin in plugins:
try:
await plugin.startup()
started.append(plugin)
logger.info("Plugin started: %s", plugin.name)
except Exception:
logger.exception("Failed to start plugin: %s", plugin.name)
return started


async def shutdown_plugins(plugins: list[McpPlugin]) -> None:
"""Shut down plugins, logging errors without raising."""
for plugin in plugins:
try:
await plugin.shutdown()
except Exception:
logger.exception("Failed to shutdown plugin: %s", plugin.name)
9 changes: 8 additions & 1 deletion src/ros2_medkit_mcp/server_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ros2_medkit_mcp.client import SovdClient
from ros2_medkit_mcp.config import get_settings
from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app
from ros2_medkit_mcp.plugin import McpPlugin, discover_plugins, shutdown_plugins, start_plugins

# Configure logging
logging.basicConfig(
Expand All @@ -36,7 +37,7 @@ def create_app() -> Starlette:
settings = get_settings()
mcp_server = create_mcp_server()
client = SovdClient(settings)
setup_mcp_app(mcp_server, settings, client)
plugins = discover_plugins()

# Create SSE transport - path is where clients POST messages
sse_transport = SseServerTransport("/mcp/messages/")
Expand Down Expand Up @@ -77,13 +78,19 @@ async def health_check(_request: Request) -> JSONResponse:
}
)

started_plugins: list[McpPlugin] = []

async def on_startup() -> None:
"""Application startup handler."""
logger.info("ros2_medkit MCP server starting (HTTP transport)")
logger.info("Connecting to SOVD API at %s", settings.base_url)
started = await start_plugins(plugins)
started_plugins.extend(started)
setup_mcp_app(mcp_server, settings, client, plugins=started_plugins)

async def on_shutdown() -> None:
"""Application shutdown handler."""
await shutdown_plugins(started_plugins)
await client.close()
logger.info("Server shutdown complete")

Expand Down
6 changes: 5 additions & 1 deletion src/ros2_medkit_mcp/server_stdio.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from ros2_medkit_mcp.client import SovdClient
from ros2_medkit_mcp.config import get_settings
from ros2_medkit_mcp.mcp_app import create_mcp_server, setup_mcp_app
from ros2_medkit_mcp.plugin import discover_plugins, shutdown_plugins, start_plugins

# Configure logging to stderr to avoid interfering with stdio transport
logging.basicConfig(
Expand All @@ -31,9 +32,11 @@ async def run_server() -> None:

server = create_mcp_server()
client = SovdClient(settings)
plugins = discover_plugins()

started_plugins = await start_plugins(plugins)
try:
setup_mcp_app(server, settings, client)
setup_mcp_app(server, settings, client, plugins=started_plugins)

async with stdio_server() as (read_stream, write_stream):
await server.run(
Expand All @@ -42,6 +45,7 @@ async def run_server() -> None:
server.create_initialization_options(),
)
finally:
await shutdown_plugins(started_plugins)
await client.close()
logger.info("Server shutdown complete")

Expand Down
Loading