diff --git a/server/gti/gti_mcp/server.py b/server/gti/gti_mcp/server.py index 8c4d4f50..4f6e258d 100644 --- a/server/gti/gti_mcp/server.py +++ b/server/gti/gti_mcp/server.py @@ -1,4 +1,5 @@ # Copyright 2025 Google LLC +# Modifications Copyright (c) 2025-2026 Deep Kanaparthi # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -18,9 +19,11 @@ import logging import os +from typing import Dict import vt from mcp.server.fastmcp import FastMCP, Context +from mcp.server.transport_security import TransportSecuritySettings logging.basicConfig(level=logging.ERROR) @@ -30,12 +33,46 @@ if os.getenv("STATELESS") == "1": stateless = True +# --------------------------------------------------------------------------- +# HTTP header access for multi-client credential passthrough +# --------------------------------------------------------------------------- +try: + from fastmcp.server.dependencies import get_http_headers +except ImportError: + get_http_headers = None + + +def _get_request_headers() -> Dict[str, str]: + """Get HTTP headers for the current MCP session via FastMCP's built-in context.""" + if get_http_headers is None: + return {} + try: + headers = get_http_headers() + if headers: + return headers + except Exception: + pass + return {} + + +def _get_gti_config() -> Dict[str, str]: + """Resolve VT config from HTTP headers with env var fallback.""" + headers = _get_request_headers() + h = {k.lower(): v for k, v in headers.items()} + return { + "api_key": h.get("x-vt-apikey", os.getenv("VT_APIKEY", "")), + "verify_ssl": os.getenv("VT_VERIFY_SSL", "true").lower() != "false", + } + def _vt_client_factory(unused_ctx) -> vt.Client: - api_key = os.getenv("VT_APIKEY") + cfg = _get_gti_config() + api_key = cfg["api_key"] if not api_key: - raise ValueError("VT_APIKEY environment variable is required") - return vt.Client(api_key) + raise ValueError( + "VT_APIKEY not configured. Set VT_APIKEY env var or send X-VT-ApiKey header." + ) + return vt.Client(api_key, verify_ssl=cfg["verify_ssl"]) vt_client_factory = _vt_client_factory @@ -54,7 +91,11 @@ async def vt_client(ctx: Context) -> AsyncIterator[vt.Client]: server = FastMCP( "Google Threat Intelligence MCP server", dependencies=["vt-py"], - stateless_http=stateless) + stateless_http=stateless, + transport_security=TransportSecuritySettings( + enable_dns_rebinding_protection=False, + ), +) # Load tools. from gti_mcp.tools import * @@ -65,4 +106,16 @@ def main(): if __name__ == '__main__': - main() + import sys + transport = sys.argv[1] if len(sys.argv) > 1 else os.getenv("MCP_TRANSPORT", "stdio") + + if transport == "streamable-http": + import uvicorn + app = server.streamable_http_app() + uvicorn.run( + app, + host=os.getenv("FASTMCP_HOST", "0.0.0.0"), + port=int(os.getenv("FASTMCP_PORT", "8003")), + ) + else: + main() diff --git a/server/gti/pyproject.toml b/server/gti/pyproject.toml index 8d6937dd..8e35ea75 100644 --- a/server/gti/pyproject.toml +++ b/server/gti/pyproject.toml @@ -18,8 +18,10 @@ classifiers = [ "Topic :: Security", ] dependencies = [ - "mcp", - "vt-py" + "mcp>=1.26.0", + "vt-py", + "fastmcp>=2.11.1", + "uvicorn>=0.34.0", ] [project.urls]