diff --git a/multi_ssh_mcp.py b/multi_ssh_mcp.py index 8a3a497..1ad71de 100644 --- a/multi_ssh_mcp.py +++ b/multi_ssh_mcp.py @@ -20,6 +20,7 @@ import paramiko from fastmcp import FastMCP +from fastmcp.server.auth.providers.jwt import StaticTokenVerifier import jc # Import security utilities @@ -424,8 +425,23 @@ def main(): logger.info(f"Loaded {len(ssh_manager.servers_config)} server(s): {', '.join(ssh_manager.servers_config.keys())}") # Create FastMCP server - mcp = FastMCP("Multi-SSH Server") + bearer_token = os.environ.get("MCP_TOKEN") + + if bearer_token: + verifier = StaticTokenVerifier( + tokens={ + bearer_token: { + "client_id": "superuser", + "scopes": ["read:data", "write:data", "admin:users"] + } + }, + required_scopes=["read:data"] + ) + mcp = FastMCP("Multi-SSH Server", auth=verifier) + else: + mcp = FastMCP("Multi-SSH Server") + @mcp.tool() def list_servers() -> str: """List all configured SSH servers with their details""" @@ -752,19 +768,30 @@ def network_diagnostics(command_type: str, destination: str, server_name: str = return f"{command_type} failed: {result['error']}" # Run the FastMCP server with transport selection - transport = os.environ.get("MCP_TRANSPORT", "stdio") + transport = os.environ.get("MCP_TRANSPORT", "stdio").lower() if transport == "sse": - # Server-Sent Events mode for HTTP streaming + # Server-Sent Events mode for HTTP streaming (not recommended) import uvicorn - from fastmcp.sse import create_sse_transport + sse_transport = mcp.sse_app() port = int(os.environ.get("MCP_PORT", "8080")) host = os.environ.get("MCP_HOST", "0.0.0.0") logger.info(f"Starting SSE transport on {host}:{port}") - sse_transport = create_sse_transport(mcp, host=host, port=port) uvicorn.run(sse_transport, host=host, port=port, log_level="info") + + elif transport == "http": + # Streamable HTTP transport (recommended) + import uvicorn + http_transport = mcp.http_app() + + port = int(os.environ.get("MCP_PORT", "8080")) + host = os.environ.get("MCP_HOST", "0.0.0.0") + + logger.info(f"Starting Streamable HTTP transport on {host}:{port}") + uvicorn.run(http_transport, host=host, port=port, log_level="info") + else: # Default stdio transport logger.info("Starting stdio transport") @@ -772,4 +799,4 @@ def network_diagnostics(command_type: str, destination: str, server_name: str = if __name__ == "__main__": - main() \ No newline at end of file + main()