From dcb281b03af64613e073a0e479c962f243d01c85 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Thu, 22 May 2025 17:44:17 +0200 Subject: [PATCH 1/9] add session store --- src/mcp_scan/StorageFile.py | 3 +- src/mcp_scan/cli.py | 2 +- src/mcp_scan_server/routes/policies.py | 57 +++++++-- src/mcp_scan_server/server.py | 4 +- src/mcp_scan_server/session_store.py | 114 +++++++++++++++++ tests/e2e/test_full_proxy_flow.py | 4 +- tests/unit/test_session.py | 171 +++++++++++++++++++++++++ 7 files changed, 341 insertions(+), 14 deletions(-) create mode 100644 src/mcp_scan_server/session_store.py create mode 100644 tests/unit/test_session.py diff --git a/src/mcp_scan/StorageFile.py b/src/mcp_scan/StorageFile.py index 2b4acb0..1cf72de 100644 --- a/src/mcp_scan/StorageFile.py +++ b/src/mcp_scan/StorageFile.py @@ -7,9 +7,10 @@ import rich import yaml # type: ignore -from mcp_scan_server.models import DEFAULT_GUARDRAIL_CONFIG, GuardrailConfigFile from pydantic import ValidationError +from mcp_scan_server.models import DEFAULT_GUARDRAIL_CONFIG, GuardrailConfigFile + from .models import Entity, ScannedEntities, ScannedEntity, entity_type_to_str, hash_entity from .utils import upload_whitelist_entry diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 3101ab6..95b27eb 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -7,10 +7,10 @@ import psutil import rich from invariant.__main__ import add_extra -from mcp_scan_server.server import MCPScanServer from rich.logging import RichHandler from mcp_scan.gateway import MCPGatewayConfig, MCPGatewayInstaller +from mcp_scan_server.server import MCPScanServer from .MCPScanner import MCPScanner from .paths import WELL_KNOWN_MCP_PATHS, client_shorthands_to_paths diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py index c2496ca..8781ec8 100644 --- a/src/mcp_scan_server/routes/policies.py +++ b/src/mcp_scan_server/routes/policies.py @@ -1,12 +1,14 @@ # type: ignore import asyncio import os +from typing import Any import fastapi import rich import yaml # type: ignore from fastapi import APIRouter, Depends, Request from invariant.analyzer.policy import LocalPolicy +from invariant.analyzer.runtime.nodes import Event from invariant.analyzer.runtime.runtime_errors import ( ExcessivePolicyError, InvariantAttributeError, @@ -15,6 +17,7 @@ from pydantic import ValidationError from mcp_scan_server.activity_logger import ActivityLogger, get_activity_logger +from mcp_scan_server.session_store import SessionStore, to_session from ..models import ( DEFAULT_GUARDRAIL_CONFIG, @@ -27,6 +30,7 @@ from ..parse_config import parse_config router = APIRouter() +session_store = SessionStore() async def load_guardrails_config_file(config_file_path: str) -> GuardrailConfigFile: @@ -106,7 +110,9 @@ async def get_policy( return {"policies": policies} -async def check_policy(policy_str: str, messages: list[dict], parameters: dict | None = None) -> PolicyCheckResult: +async def check_policy( + policy_str: str, messages: list[dict[str, Any]], parameters: dict | None = None, from_index: int = -1 +) -> PolicyCheckResult: """ Check a policy using the invariant analyzer. @@ -118,6 +124,10 @@ async def check_policy(policy_str: str, messages: list[dict], parameters: dict | Returns: A PolicyCheckResult object. 
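+
+    Note: ``from_index`` marks the first message that has not yet been analyzed;
+    the default of -1 is interpreted as "all but the last message".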
""" + + # If from_index is not provided, assume all but the last message have been analyzed + from_index = from_index if from_index != -1 else len(messages) - 1 + try: policy = LocalPolicy.from_string(policy_str) @@ -127,8 +137,7 @@ async def check_policy(policy_str: str, messages: list[dict], parameters: dict | success=False, error_message=str(policy), ) - - result = await policy.a_analyze_pending(messages[:-1], [messages[-1]], **(parameters or {})) + result = await policy.a_analyze_pending(messages[:from_index], messages[from_index:], **(parameters or {})) return PolicyCheckResult( policy=policy_str, @@ -162,6 +171,25 @@ def to_json_serializable_dict(obj): return type(obj).__name__ + "(" + str(obj) + ")" +async def get_messages_from_session( + check_request: BatchCheckRequest, client_name: str, server_name: str, session_id: str +) -> list[Event]: + """Get the messages from the session store.""" + try: + session = await to_session(check_request.messages, server_name, session_id) + session = session_store.fetch_and_merge(client_name, session) + messages = [node.message for node in session.get_sorted_nodes()] + except Exception as e: + rich.print( + f"[bold red]Error parsing messages for client {client_name} and server {server_name}: {e}[/bold red]" + ) + + # If we fail to parse the session, return the original messages + messages = check_request.messages + + return messages + + @router.post("/policy/check/batch", response_model=BatchCheckResponse) async def batch_check_policies( check_request: BatchCheckRequest, @@ -169,20 +197,33 @@ async def batch_check_policies( activity_logger: ActivityLogger = Depends(get_activity_logger), ): """Check a policy using the invariant analyzer.""" + metadata = check_request.parameters.get("metadata", {}) + + mcp_client = metadata.get("client", "Unknown Client") + mcp_server = metadata.get("server", "Unknown Server") + session_id = metadata.get("session_id", "") + + messages = await get_messages_from_session(check_request, mcp_client, mcp_server, session_id) + last_analysis_index = session_store[mcp_client].last_analysis_index + results = await asyncio.gather( - *[check_policy(policy, check_request.messages, check_request.parameters) for policy in check_request.policies] + *[ + check_policy(policy, messages, check_request.parameters, last_analysis_index) + for policy in check_request.policies + ] ) - metadata = check_request.parameters.get("metadata", {}) + # Update the last analysis index + session_store[mcp_client].last_analysis_index = len(messages) guardrails_action = check_request.parameters.get("action", "block") await activity_logger.log( check_request.messages, { - "client": metadata.get("client", "Unknown Client"), - "mcp_server": metadata.get("server", "Unknown Server"), + "client": mcp_client, + "mcp_server": mcp_server, "user": metadata.get("system_user", None), - "session_id": metadata.get("session_id", ""), + "session_id": session_id, }, results, guardrails_action, diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index a551849..d7e5508 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -8,7 +8,6 @@ from mcp_scan_server.activity_logger import setup_activity_logger # type: ignore -from .routes.policies import load_guardrails_config_file from .routes.policies import router as policies_router # type: ignore from .routes.push import router as push_router from .routes.trace import router as dataset_trace_router @@ -69,7 +68,8 @@ async def on_startup(self): # setup activity logger 
setup_activity_logger(self.app, pretty=self.pretty) - # load config file to validate + from .routes.policies import load_guardrails_config_file + await load_guardrails_config_file(self.config_file_path) async def life_span(self, app: FastAPI): diff --git a/src/mcp_scan_server/session_store.py b/src/mcp_scan_server/session_store.py new file mode 100644 index 0000000..5a9ded7 --- /dev/null +++ b/src/mcp_scan_server/session_store.py @@ -0,0 +1,114 @@ +import heapq +from dataclasses import dataclass +from datetime import datetime +from typing import Any + + +@dataclass(frozen=True) +class SessionNode: + """ + Represents a single event in a session. + """ + + timestamp: datetime + message: dict[str, Any] + session_id: str + server_name: str + original_session_index: int + + def __hash__(self) -> int: + """Assume uniqueness by session_id, index in session and time of event.""" + return hash((self.session_id, self.original_session_index, self.timestamp)) + + def __lt__(self, other: "SessionNode") -> bool: + """Sort by timestamp.""" + return self.timestamp < other.timestamp + + +class Session: + """ + Represents a sequence of SessionNodes, sorted by timestamp. + """ + + def __init__( + self, + nodes: list[SessionNode] | None = None, + ): + self.nodes: list[SessionNode] = nodes or [] + self.last_analysis_index: int = -1 + + def merge(self, other: "Session") -> None: + """ + Merge two session objects into a joint session. + This assumes the precondition that both sessions are sorted and has + the postcondition that the merged session is sorted and has no duplicates. + """ + merged_nodes = heapq.merge(self.nodes, other.nodes) + combined_nodes: list[SessionNode] = [] + seen: set[SessionNode] = set() + + for node in merged_nodes: + if node not in seen: + seen.add(node) + combined_nodes.append(node) + self.nodes = combined_nodes + + def get_sorted_nodes(self) -> list[SessionNode]: + return list(self.nodes) + + def __repr__(self): + return f"Session(nodes={self.get_sorted_nodes()})" + + +class SessionStore: + """ + Stores sessions by client_name. Optionally expires sessions after TTL seconds. + """ + + def __init__(self): + self.sessions: dict[str, Session] = {} + + def _default_session(self) -> Session: + return Session() + + def __str__(self): + return f"SessionStore(sessions={self.sessions})" + + def __getitem__(self, client_name: str) -> Session: + if client_name not in self.sessions: + self.sessions[client_name] = self._default_session() + return self.sessions[client_name] + + def __setitem__(self, client_name: str, session: Session) -> None: + self.sessions[client_name] = session + + def __repr__(self): + return self.__str__() + + def fetch_and_merge(self, client_name: str, other: Session) -> Session: + """ + Fetch the session for the given client_name and merge it with the other session, returning the merged session. + """ + session = self[client_name] + session.merge(other) + return session + + +async def to_session(messages: list[dict[str, Any]], server_name: str, session_id: str) -> Session: + """ + Convert a list of messages to a session. 
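+
+    Each message must carry an ISO-8601 "timestamp" field; it is parsed with
+    datetime.fromisoformat and stored on the SessionNode, which orders by timestamp.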
+ """ + session_nodes: list[SessionNode] = [] + for i, message in enumerate(messages): + timestamp = datetime.fromisoformat(message["timestamp"]) + session_nodes.append( + SessionNode( + server_name=server_name, + message=message, + original_session_index=i, + session_id=session_id, + timestamp=timestamp, + ) + ) + + return Session(nodes=session_nodes) diff --git a/tests/e2e/test_full_proxy_flow.py b/tests/e2e/test_full_proxy_flow.py index 153db82..b4b5ae0 100644 --- a/tests/e2e/test_full_proxy_flow.py +++ b/tests/e2e/test_full_proxy_flow.py @@ -166,14 +166,14 @@ async def test_basic(self, toy_server_add_config_file, pretty): # wait for client to finish try: client_output = await asyncio.wait_for(client_program, timeout=20) - except asyncio.TimeoutError: + except asyncio.TimeoutError as e: print("Client timed out") process.terminate() process.wait() stdout, stderr = process.communicate() print(safe_decode(stdout)) print(safe_decode(stderr)) - raise AssertionError("timed out waiting for MCP server to respond") + raise AssertionError("timed out waiting for MCP server to respond") from e assert int(client_output["result"]) == 3 diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py new file mode 100644 index 0000000..5e574f9 --- /dev/null +++ b/tests/unit/test_session.py @@ -0,0 +1,171 @@ +import datetime + +import pytest + +from src.mcp_scan_server.session_store import Session, SessionNode, SessionStore, to_session + + +def create_timestamped_node(timestamp: datetime.datetime): + return SessionNode(timestamp=timestamp, message={}, session_id="", server_name="", original_session_index=0) + + +@pytest.fixture +def some_date(): + return datetime.datetime(2021, 1, 1, 12, 0, 0) + + +def test_session_node_ordering(some_date: datetime.datetime): + """Make sure session nodes are sorted by timestamp""" + session_nodes = [ + create_timestamped_node(some_date), + create_timestamped_node(some_date - datetime.timedelta(seconds=1)), + create_timestamped_node(some_date - datetime.timedelta(seconds=2)), + ] + session_nodes.sort() + assert session_nodes[0].timestamp < session_nodes[1].timestamp + assert session_nodes[1].timestamp < session_nodes[2].timestamp + + +def test_session_class_merge_function_ignore_duplicates(some_date: datetime.datetime): + session1_nodes = [ + create_timestamped_node(some_date), + create_timestamped_node(some_date + datetime.timedelta(seconds=1)), + create_timestamped_node(some_date + datetime.timedelta(seconds=2)), # duplicate node + ] + session2_nodes = [ + create_timestamped_node(some_date + datetime.timedelta(seconds=2)), # duplicate node + create_timestamped_node(some_date + datetime.timedelta(seconds=3)), + ] + session1 = Session(nodes=session1_nodes) + session2 = Session(nodes=session2_nodes) + + # Check that the nodes are sorted (precondition for merge) + assert session1_nodes == session1.nodes + assert session2_nodes == session2.nodes + + session1.merge(session2) + + # Check that duplicate is ignored + assert len(session1.nodes) == 4, "Duplicate nodes should be ignored" + + # Check that the nodes are sorted and dates are correct + assert session1.nodes[0].timestamp == some_date + assert session1.nodes[1].timestamp == some_date + datetime.timedelta(seconds=1) + assert session1.nodes[2].timestamp == some_date + datetime.timedelta(seconds=2) + assert session1.nodes[3].timestamp == some_date + datetime.timedelta(seconds=3) + + +def test_session_store_missing_client_name(some_date: datetime.datetime): + """Test that the session store returns a default session if 
the client name is not found""" + session_store = SessionStore() + session_store["client_name"] = Session(nodes=[]) + assert session_store["client_name"] is not None + + # Check that the default session is returned if the client name is not found + assert session_store["missing_client_name"] is not None + assert session_store["missing_client_name"].nodes == [] + + +def test_session_store_fetch_and_merge_only_relevant_sessions_is_updated(some_date: datetime.datetime): + session_store = SessionStore() + + # Create two clients with some nodes + client1_nodes = [ + create_timestamped_node(some_date), + create_timestamped_node(some_date + datetime.timedelta(seconds=1)), + ] + client2_nodes = [ + create_timestamped_node(some_date + datetime.timedelta(seconds=2)), + create_timestamped_node(some_date + datetime.timedelta(seconds=3)), + ] + + # Add the clients to the session store + session_store["client_name_1"] = Session(nodes=client1_nodes) + session_store["client_name_2"] = Session(nodes=client2_nodes) + + # Create new nodes for client 1 + new_nodes = [ + create_timestamped_node(some_date + datetime.timedelta(seconds=4)), + create_timestamped_node(some_date + datetime.timedelta(seconds=5)), + ] + new_nodes_session = Session(nodes=new_nodes) + + session_store.fetch_and_merge("client_name_1", new_nodes_session) + + # Check that the new nodes are merged with the old nodes + assert session_store["client_name_1"].nodes == [ + client1_nodes[0], + client1_nodes[1], + new_nodes[0], + new_nodes[1], + ] + + # Check that the other client's session is not affected + assert session_store["client_name_2"].nodes == client2_nodes + + +@pytest.mark.asyncio +async def test_original_session_index_server_name_and_session_id_are_maintained_during_merge(): + session_nodes = [ + {"role": "user", "content": "msg1", "timestamp": "2021-01-01T12:00:00Z"}, + {"role": "assistant", "content": "msg2", "timestamp": "2021-01-01T12:00:01Z"}, + ] + server1_name = "server_name1" + session_id1 = "session_id1" + session = await to_session(session_nodes, server1_name, session_id1) + assert session.nodes[0].original_session_index == 0 + assert session.nodes[1].original_session_index == 1 + + new_nodes = [ + {"role": "user", "content": "msg1", "timestamp": "2021-01-01T12:00:02Z"}, + {"role": "assistant", "content": "msg2", "timestamp": "2021-01-01T12:00:03Z"}, + ] + server2_name = "server_name2" + session_id2 = "session_id2" + new_nodes_session = await to_session(new_nodes, server2_name, session_id2) + session.merge(new_nodes_session) + + # Assert original session index is maintained + assert session.nodes[0].original_session_index == 0 + assert session.nodes[1].original_session_index == 1 + assert session.nodes[2].original_session_index == 0 + assert session.nodes[3].original_session_index == 1 + + # Assert server name and session id are maintained + assert session.nodes[0].server_name == server1_name + assert session.nodes[1].server_name == server1_name + + assert session.nodes[2].server_name == server2_name + assert session.nodes[3].server_name == server2_name + + assert session.nodes[0].session_id == session_id1 + assert session.nodes[1].session_id == session_id1 + + assert session.nodes[2].session_id == session_id2 + assert session.nodes[3].session_id == session_id2 + + +@pytest.mark.asyncio +async def test_to_session_function(): + """Test that the to_session function creates a session with the correct nodes""" + messages = [ + {"role": "user", "content": "Hello, world!", "timestamp": "2021-01-01T12:00:00Z"}, + {"role": 
"assistant", "content": "Hello, world!", "timestamp": "2021-01-01T12:00:01Z"}, + ] + session = await to_session(messages, "server_name", "session_id") + assert session.nodes == [ + SessionNode( + timestamp=datetime.datetime.fromisoformat(messages[0]["timestamp"]), + message=messages[0], + session_id="session_id", + server_name="server_name", + original_session_index=0, + ), + SessionNode( + timestamp=datetime.datetime.fromisoformat(messages[1]["timestamp"]), + message=messages[1], + session_id="session_id", + server_name="server_name", + original_session_index=1, + ), + ] From 4d8c1142c54596f00a3941b040447a30437063f4 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Mon, 26 May 2025 12:33:29 +0100 Subject: [PATCH 2/9] add client-level custom guardrails --- src/mcp_scan_server/models.py | 58 ++++--- src/mcp_scan_server/parse_config.py | 28 +++- src/mcp_scan_server/session_store.py | 2 +- tests/unit/test_mcp_scan_server.py | 219 ++++++++++++++++++--------- 4 files changed, 203 insertions(+), 104 deletions(-) diff --git a/src/mcp_scan_server/models.py b/src/mcp_scan_server/models.py index 4633738..5f0d96a 100644 --- a/src/mcp_scan_server/models.py +++ b/src/mcp_scan_server/models.py @@ -143,6 +143,13 @@ class ServerGuardrailConfig(BaseModel): model_config = ConfigDict(extra="forbid") +class ClientGuardrailConfig(BaseModel): + custom_guardrails: list[DatasetPolicy] | None = Field(default=None) + servers: dict[str, ServerGuardrailConfig] = Field(default_factory=dict) + + model_config = ConfigDict(extra="forbid") + + class GuardrailConfigFile: """ The guardrail config file model. @@ -158,31 +165,40 @@ class GuardrailConfigFile: A tool can also be disabled by setting enabled to False. Example config file: - ``` + ```yaml cursor: # The client - whatsapp: # The server - guardrails: - pii: block # Shorthand guardrail - moderated: paused - - custom_guardrails: # List of custom guardrails - - name: "Custom Guardrail" - id: "custom_guardrail_1" - action: block - content: | - raise "Error" if: - (msg: Message) - "error" in msg.content - - tools: # Dictionary of tools - send_message: - enabled: false # Disable the send_message tool - read_messages: - secrets: block # Block secrets + custom_guardrails: # List of client-wide custom guardrails + - name: "Custom Guardrail" + id: "custom_guardrail_1" + action: block + content: | + raise "Error" if: + (msg: Message) + "error" in msg.content + servers: + whatsapp: # The server + guardrails: + pii: block # Shorthand guardrail + moderated: paused + + custom_guardrails: # List of custom guardrails + - name: "Custom Guardrail" + id: "custom_guardrail_1" + action: block + content: | + raise "Error" if: + (msg: Message) + "error" in msg.content + + tools: # Dictionary of tools + send_message: + enabled: false # Disable the send_message tool + read_messages: + secrets: block # Block secrets ``` """ - ConfigFileStructure = dict[str, dict[str, ServerGuardrailConfig]] + ConfigFileStructure = dict[str, ClientGuardrailConfig] _config_validator = TypeAdapter(ConfigFileStructure) def __init__(self, clients: ConfigFileStructure | None = None): diff --git a/src/mcp_scan_server/parse_config.py b/src/mcp_scan_server/parse_config.py index 810b3fb..5a3a487 100644 --- a/src/mcp_scan_server/parse_config.py +++ b/src/mcp_scan_server/parse_config.py @@ -13,6 +13,7 @@ whitelist_tool_from_guardrail, ) from mcp_scan_server.models import ( + ClientGuardrailConfig, DatasetPolicy, GuardrailConfigFile, GuardrailMode, @@ -313,6 +314,20 @@ def parse_tool_shorthand_guardrails( return result, 
disabled_tools +def parse_client_guardrails( + config: ClientGuardrailConfig, +) -> list[DatasetPolicy]: + """Parse client-specific guardrails from the client config. + + Args: + config: The client guardrail config. + + Returns: + A list of DatasetPolicy objects from client guardrails. + """ + return config.custom_guardrails or [] + + @lru_cache async def parse_config( config: GuardrailConfigFile, @@ -330,17 +345,17 @@ async def parse_config( A list of DatasetPolicy objects with all guardrails. """ policies: list[DatasetPolicy] = [] - found = False for client, client_config in config.items(): if (client_name and client != client_name) or not client_config: continue - for server, server_config in client_config.items(): + # Add client-level (custom) guardrails directly to the policies + policies.extend(parse_client_guardrails(client_config)) + + for server, server_config in client_config.servers.items(): if server_name and server != server_name: continue - # mark as found - found = True # Parse guardrails for this client-server pair server_shorthands = parse_server_shorthand_guardrails(server_config) @@ -351,13 +366,12 @@ async def parse_config( policies.extend(collect_guardrails(server_shorthands, tool_shorthands, disabled_tools, client, server)) policies.extend(custom_guardrails) - if not found: + # Create all default guardrails if no guardrails are configured + if len(policies) == 0: logger.warning( "No guardrails found for client '%s' and server '%s'. Using default guardrails.", client_name, server_name ) - # Create all default guardrails if no guardrails are configured - if len(policies) == 0: for name in get_available_templates(): policies.append(generate_policy(name, GuardrailMode.log, client_name, server_name)) diff --git a/src/mcp_scan_server/session_store.py b/src/mcp_scan_server/session_store.py index 5a9ded7..b9109bc 100644 --- a/src/mcp_scan_server/session_store.py +++ b/src/mcp_scan_server/session_store.py @@ -62,7 +62,7 @@ def __repr__(self): class SessionStore: """ - Stores sessions by client_name. Optionally expires sessions after TTL seconds. + Stores sessions by client_name. 
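+    Accessing an unknown client via __getitem__ lazily creates an empty Session for it.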
""" def __init__(self): diff --git a/tests/unit/test_mcp_scan_server.py b/tests/unit/test_mcp_scan_server.py index 56c3bf6..bc0ef25 100644 --- a/tests/unit/test_mcp_scan_server.py +++ b/tests/unit/test_mcp_scan_server.py @@ -16,6 +16,7 @@ whitelist_tool_from_guardrail, ) from mcp_scan_server.models import ( + ClientGuardrailConfig, DatasetPolicy, GuardrailConfig, GuardrailConfigFile, @@ -86,36 +87,38 @@ def valid_guardrail_config_file(tmp_path): config_file.write_text( """ cursor: - server1: - guardrails: - pii: "block" - moderated: "block" - links: "block" - secrets: "block" - - custom_guardrails: - - name: "Guardrail 1" - id: "guardrail_1" - enabled: true - action: "block" - content: | - raise "error" if: - (msg: ToolOutput) - "Test1" in msg.content + servers: + server1: + guardrails: + pii: "block" + moderated: "block" + links: "block" + secrets: "block" + + custom_guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + enabled: true + action: "block" + content: | + raise "error" if: + (msg: ToolOutput) + "Test1" in msg.content - tools: + tools: tool_name: - enabled: true - pii: "block" - moderated: "block" - links: "block" - secrets: "block" - server2: - guardrails: - pii: "block" - moderated: "block" - links: "block" - secrets: "block" + enabled: true + pii: "block" + moderated: "block" + links: "block" + secrets: "block" + + server2: + guardrails: + pii: "block" + moderated: "block" + links: "block" + secrets: "block" """ ) return str(config_file) @@ -482,14 +485,16 @@ async def test_server_shorthands_override_default_guardrails(): """Test that server shorthands override default guardrails.""" config = GuardrailConfigFile( { - "cursor": { - "server1": ServerGuardrailConfig( - guardrails=GuardrailConfig( - pii=GuardrailMode.block, - moderated=GuardrailMode.paused, + "cursor": ClientGuardrailConfig( + servers={ + "server1": ServerGuardrailConfig( + guardrails=GuardrailConfig( + pii=GuardrailMode.block, + moderated=GuardrailMode.paused, + ), ), - ), - } + }, + ) } ) policies = await parse_config(config, "cursor", "server1") @@ -511,15 +516,17 @@ async def test_tools_partially_override_default_guardrails(mock_get_templates): """Test that tools partially override default guardrails.""" config = GuardrailConfigFile( { - "cursor": { - "server1": ServerGuardrailConfig( - tools={ - "tool_name": ToolGuardrailConfig( - pii=GuardrailMode.block, - ), - }, - ) - } + "cursor": ClientGuardrailConfig( + servers={ + "server1": ServerGuardrailConfig( + tools={ + "tool_name": ToolGuardrailConfig( + pii=GuardrailMode.block, + ), + }, + ), + }, + ) } ) @@ -551,40 +558,43 @@ async def test_tools_partially_override_default_guardrails(mock_get_templates): @pytest.mark.asyncio async def test_parse_config(): """Test that the parse_config function parses the config file correctly.""" + config = GuardrailConfigFile( { - "cursor": { - "server1": ServerGuardrailConfig( - guardrails=GuardrailConfig( - pii=GuardrailMode.block, - moderated=GuardrailMode.log, - secrets=GuardrailMode.paused, - ), - tools={ - "tool_name": ToolGuardrailConfig( + "cursor": ClientGuardrailConfig( + servers={ + "server1": ServerGuardrailConfig( + guardrails=GuardrailConfig( pii=GuardrailMode.block, - moderated=GuardrailMode.paused, - links=GuardrailMode.log, - enabled=True, - ), - "tool_name2": ToolGuardrailConfig( - pii=GuardrailMode.block, - moderated=GuardrailMode.block, - enabled=True, + moderated=GuardrailMode.log, + secrets=GuardrailMode.paused, ), - }, - ) - } + tools={ + "tool_name": ToolGuardrailConfig( + pii=GuardrailMode.block, 
+ moderated=GuardrailMode.paused, + links=GuardrailMode.log, + enabled=True, + ), + "tool_name2": ToolGuardrailConfig( + pii=GuardrailMode.block, + moderated=GuardrailMode.block, + enabled=True, + ), + }, + ) + } + ) } ) - config = await parse_config(config) + policies = await parse_config(config) # We should have 7 policies since: # pii creates 1 policy because the action (block) of all shorthands match # moderated creates 3 policies (one for each action) # secrets creates 1 policy because it is defined as a server shorthand # links creates 2 policies -- one for the tool_name shorthand and one default - assert len(config) == 7 + assert len(policies) == 7 @pytest.mark.asyncio @@ -592,15 +602,17 @@ async def test_disable_tool(): """Test that the disable_tool function disables a tool correctly.""" config = GuardrailConfigFile( { - "cursor": { - "server1": ServerGuardrailConfig( - tools={ - "tool_name": ToolGuardrailConfig( - enabled=False, - ), - }, - ), - } + "cursor": ClientGuardrailConfig( + servers={ + "server1": ServerGuardrailConfig( + tools={ + "tool_name": ToolGuardrailConfig( + enabled=False, + ), + }, + ), + }, + ) } ) policies = await parse_config(config, "cursor", "server1") @@ -634,3 +646,60 @@ async def test_disable_tool(): ] ) assert len(result.errors) == 1, "Tool should be blocked" + + +@pytest.mark.asyncio +@patch("mcp_scan_server.parse_config.get_available_templates", return_value=()) +async def test_server_level_guardrails_are_applied_to_all_servers(mock_get_templates): + """Test that server level guardrails are applied to all servers.""" + config = GuardrailConfigFile( + { + "cursor": ClientGuardrailConfig( + custom_guardrails=[ + { + "name": "Guardrail 1", + "id": "guardrail_1", + "enabled": True, + "action": "block", + "content": "raise 'error' if: (msg: Message) 'error' in msg.content", + } + ] + ) + } + ) + + # Test that regardless of the server, the guardrail is applied when the client is cursor + policies = await parse_config(config, "cursor", "server1") + assert len(policies) == 1 + assert "error" in policies[0].content + + policies = await parse_config(config, "cursor", "server2") + assert len(policies) == 1 + assert "error" in policies[0].content + + # Test that it is not applied when the client is not cursor + policies = await parse_config(config, "not_cursor", "server1") + assert len(policies) == 0 + + +@pytest.mark.asyncio +@patch("mcp_scan_server.parse_config.get_available_templates", return_value=()) +async def test_server_level_guardrails(mock_get_templates): + """Test that server level guardrails are applied correctly.""" + + config = """ +cursor: + custom_guardrails: + - name: "Guardrail 1" + id: "guardrail_1" + enabled: true + action: "block" + content: | + raise "this is a custom error" if: + (msg: Message) + "error" in msg.content +""" + config = GuardrailConfigFile.model_validate(yaml.safe_load(config)) + policies = await parse_config(config) + assert len(policies) == 1 + assert "this is a custom error" in policies[0].content From 618b35719681da16642f89db98d2908645279276 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Mon, 26 May 2025 13:10:26 +0100 Subject: [PATCH 3/9] fix: default guardrails --- src/mcp_scan_server/models.py | 4 +++ src/mcp_scan_server/parse_config.py | 44 +++++++++++++------------- tests/unit/test_mcp_scan_server.py | 48 ++++++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 26 deletions(-) diff --git a/src/mcp_scan_server/models.py b/src/mcp_scan_server/models.py index 5f0d96a..acf3253 100644 --- 
a/src/mcp_scan_server/models.py +++ b/src/mcp_scan_server/models.py @@ -1,6 +1,7 @@ import datetime from collections.abc import ItemsView from enum import Enum +from typing import Any import yaml # type: ignore from invariant.analyzer.policy import AnalysisResult @@ -235,6 +236,9 @@ def model_dump_yaml(self) -> str: def __getitem__(self, key: str) -> dict[str, ServerGuardrailConfig]: return self.clients[key] + def get(self, key: str, default: Any = None) -> dict[str, ServerGuardrailConfig]: + return self.clients.get(key, default) + def __getattr__(self, key: str) -> dict[str, ServerGuardrailConfig]: return self.clients[key] diff --git a/src/mcp_scan_server/parse_config.py b/src/mcp_scan_server/parse_config.py index 5a3a487..477b5c4 100644 --- a/src/mcp_scan_server/parse_config.py +++ b/src/mcp_scan_server/parse_config.py @@ -101,8 +101,8 @@ def get_available_templates(directory: Path = DEFAULT_GUARDRAIL_DIR) -> tuple[st def generate_disable_tool_policy( tool_name: str, - client_name: str, - server_name: str, + client_name: str | None, + server_name: str | None, ) -> DatasetPolicy: """Generate a guardrail policy to disable a tool. @@ -176,8 +176,8 @@ def collect_guardrails( server_shorthand_guardrails: dict[str, GuardrailMode], tool_shorthand_guardrails: dict[str, dict[str, GuardrailMode]], disabled_tools: list[str], - client: str, - server: str, + client: str | None, + server: str | None, ) -> list[DatasetPolicy]: """Collect all guardrails and resolve conflicts. @@ -248,7 +248,9 @@ def collect_guardrails( return policies -def parse_custom_guardrails(config: ServerGuardrailConfig, client: str, server: str) -> list[DatasetPolicy]: +def parse_custom_guardrails( + config: ServerGuardrailConfig, client: str | None, server: str | None +) -> list[DatasetPolicy]: """Parse custom guardrails from the server config. Args: @@ -344,35 +346,33 @@ async def parse_config( Returns: A list of DatasetPolicy objects with all guardrails. """ - policies: list[DatasetPolicy] = [] - - for client, client_config in config.items(): - if (client_name and client != client_name) or not client_config: - continue + client_policies: list[DatasetPolicy] = [] + server_policies: list[DatasetPolicy] = [] + client_config = config.get(client_name) + if client_config: # Add client-level (custom) guardrails directly to the policies - policies.extend(parse_client_guardrails(client_config)) - - for server, server_config in client_config.servers.items(): - if server_name and server != server_name: - continue + client_policies.extend(parse_client_guardrails(client_config)) + server_config = client_config.servers.get(server_name) + if server_config: # Parse guardrails for this client-server pair server_shorthands = parse_server_shorthand_guardrails(server_config) tool_shorthands, disabled_tools = parse_tool_shorthand_guardrails(server_config) - custom_guardrails = parse_custom_guardrails(server_config, client, server) + custom_guardrails = parse_custom_guardrails(server_config, client_name, server_name) - # Collect and resolve guardrails - policies.extend(collect_guardrails(server_shorthands, tool_shorthands, disabled_tools, client, server)) - policies.extend(custom_guardrails) + server_policies.extend( + collect_guardrails(server_shorthands, tool_shorthands, disabled_tools, client_name, server_name) + ) + server_policies.extend(custom_guardrails) # Create all default guardrails if no guardrails are configured - if len(policies) == 0: + if len(server_policies) == 0: logger.warning( "No guardrails found for client '%s' and server '%s'. 
Using default guardrails.", client_name, server_name ) for name in get_available_templates(): - policies.append(generate_policy(name, GuardrailMode.log, client_name, server_name)) + server_policies.append(generate_policy(name, GuardrailMode.log, client_name, server_name)) - return policies + return client_policies + server_policies diff --git a/tests/unit/test_mcp_scan_server.py b/tests/unit/test_mcp_scan_server.py index bc0ef25..80ab945 100644 --- a/tests/unit/test_mcp_scan_server.py +++ b/tests/unit/test_mcp_scan_server.py @@ -470,14 +470,27 @@ async def test_empty_string_config_generates_default_guardrails(mock_get_templat """Test that the parse_config function generates the correct policies.""" config_str = """ """ + + number_of_templates = get_number_of_guardrail_templates() + config = GuardrailConfigFile.model_validate(config_str) policies = await parse_config(config) - assert len(policies) == get_number_of_guardrail_templates() + assert len(policies) == number_of_templates config_str = None config = GuardrailConfigFile.model_validate(config_str) policies = await parse_config(config) - assert len(policies) == get_number_of_guardrail_templates() + assert len(policies) == number_of_templates + + # Check that parsing in client and server args still works + policies = await parse_config(config, "cursor", None) + assert len(policies) == number_of_templates + + policies = await parse_config(config, None, "server1") + assert len(policies) == number_of_templates + + policies = await parse_config(config, "cursor", "server1") + assert len(policies) == number_of_templates @pytest.mark.asyncio @@ -587,7 +600,7 @@ async def test_parse_config(): ) } ) - policies = await parse_config(config) + policies = await parse_config(config, "cursor", "server1") # We should have 7 policies since: # pii creates 1 policy because the action (block) of all shorthands match @@ -700,6 +713,33 @@ async def test_server_level_guardrails(mock_get_templates): "error" in msg.content """ config = GuardrailConfigFile.model_validate(yaml.safe_load(config)) - policies = await parse_config(config) + policies = await parse_config(config, "cursor", "server1") assert len(policies) == 1 assert "this is a custom error" in policies[0].content + + +@pytest.mark.asyncio +@patch("mcp_scan_server.parse_config.get_available_templates", return_value=("pii",)) +async def test_defaults_are_added_with_client_level_guardrails(mock_get_templates): + """Test that defaults are added with client level guardrails.""" + config = GuardrailConfigFile( + { + "cursor": ClientGuardrailConfig( + custom_guardrails=[ + { + "name": "Guardrail 1", + "id": "guardrail_1", + "enabled": True, + "content": "raise 'custom error' if: (msg: Message)", + } + ] + ) + } + ) + + policies = await parse_config(config, "cursor", "server1") + assert len(policies) == 2 + + policy_ids = [policy.id for policy in policies] + assert "guardrail_1" in policy_ids + assert "cursor-server1-pii-default" in policy_ids From 2dd1081c8ee39ab1ee84ec99544735593366c6c5 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 27 May 2025 16:43:28 +0100 Subject: [PATCH 4/9] feat: add mcp proxy --record --- src/mcp_scan/cli.py | 45 ++-- src/mcp_scan_server/record_file.py | 294 +++++++++++++++++++++++++++ src/mcp_scan_server/routes/push.py | 21 +- src/mcp_scan_server/routes/trace.py | 10 + src/mcp_scan_server/server.py | 6 + src/mcp_scan_server/session_store.py | 41 +++- tests/unit/test_record_file.py | 278 +++++++++++++++++++++++++ tests/unit/test_session.py | 95 +++++++++ 8 files changed, 767 
insertions(+), 23 deletions(-) create mode 100644 src/mcp_scan_server/record_file.py create mode 100644 tests/unit/test_record_file.py diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 95b27eb..d525fea 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -125,7 +125,26 @@ def add_server_arguments(parser): choices=["oneline", "compact", "full", "none"], help="Pretty print the output (default: compact)", ) - server_group.add_argument( + + +def add_mcp_scan_server_arguments(parser): + """Add arguments related to MCP scan server.""" + mcp_scan_server_group = parser.add_argument_group("MCP Scan Server Options") + mcp_scan_server_group.add_argument( + "--port", + type=int, + default=8129, + help="Port to run the server on (default: 8129).", + metavar="PORT", + ) + mcp_scan_server_group.add_argument( + "--record", + type=str, + default=None, + help="Filename to record the proxy requests to.", + metavar="RECORD_FILE", + ) + mcp_scan_server_group.add_argument( "--install-extras", nargs="+", default=None, @@ -133,6 +152,8 @@ def add_server_arguments(parser): metavar="EXTRA", ) + return mcp_scan_server_group + def add_install_arguments(parser): parser.add_argument( @@ -352,28 +373,16 @@ def main(): # SERVER command server_parser = subparsers.add_parser("server", help="Start the MCP scan server") - server_parser.add_argument( - "--port", - type=int, - default=8129, - help="Port to run the server on (default: 8129)", - metavar="PORT", - ) add_common_arguments(server_parser) add_server_arguments(server_parser) + add_mcp_scan_server_arguments(server_parser) # PROXY command proxy_parser = subparsers.add_parser("proxy", help="Installs and proxies MCP requests, uninstalls on exit") - proxy_parser.add_argument( - "--port", - type=int, - default=8129, - help="Port to run the server on (default: 8129)", - metavar="PORT", - ) add_common_arguments(proxy_parser) add_server_arguments(proxy_parser) add_install_arguments(proxy_parser) + add_mcp_scan_server_arguments(proxy_parser) # Parse arguments (default to 'scan' if no command provided) args = parser.parse_args(["scan"] if len(sys.argv) == 1 else None) @@ -414,7 +423,11 @@ def server(on_exit=None): sf = StorageFile(args.storage_file) guardrails_config_path = sf.create_guardrails_config() mcp_scan_server = MCPScanServer( - port=args.port, config_file_path=guardrails_config_path, on_exit=on_exit, pretty=args.pretty + port=args.port, + config_file_path=guardrails_config_path, + on_exit=on_exit, + pretty=args.pretty, + record_file=args.record, ) mcp_scan_server.run() diff --git a/src/mcp_scan_server/record_file.py b/src/mcp_scan_server/record_file.py new file mode 100644 index 0000000..2c9a63a --- /dev/null +++ b/src/mcp_scan_server/record_file.py @@ -0,0 +1,294 @@ +import json +import os +import uuid +from dataclasses import dataclass + +import rich +from invariant_sdk.client import Client + +from .session_store import Message, Session, SessionStore + + +class TraceClientMapping: + """ + A singleton class to store the mapping between trace ids and client names. + + This is used to ensure that a trace id is generated only once for a given client and + that it is consistent, so we can append to explorer properly. + """ + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.trace_id_to_client_name = {} + cls._instance.client_name_to_trace_id = {} + return cls._instance + + def get_client_name(self, trace_id: str) -> str | None: + """ + Get the client name for the given trace id. 
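+        Returns None if the trace id is unknown.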
+ """ + return self.trace_id_to_client_name.get(trace_id, None) + + def get_trace_id(self, client_name: str) -> str | None: + """ + Get the trace id for the given client name. + """ + return self.client_name_to_trace_id.get(client_name, None) + + def set_trace_id(self, trace_id: str, client_name: str) -> None: + """ + Set the trace id for the given client name. + """ + self.trace_id_to_client_name[trace_id] = client_name + self.client_name_to_trace_id[client_name] = trace_id + + def clear(self) -> None: + """ + Clear the trace client mapping. + """ + self.trace_id_to_client_name: dict[str, str] = {} + self.client_name_to_trace_id: dict[str, str] = {} + + def __str__(self) -> str: + return f"TraceClientMapping[{self.trace_id_to_client_name}]" + + def __repr__(self) -> str: + return self.__str__() + + +def generate_record_file_postfix() -> str: + """Generate a random postfix for the record file.""" + return str(uuid.uuid4())[:8] + + +# Initialize the invariant sdk client if the API key is set +invariant_sdk_client = Client() if os.environ.get("INVARIANT_API_KEY") else None + +# Initialize the trace client mapping and session store +trace_client_mapping = TraceClientMapping() +session_store = SessionStore() +record_post_fix = generate_record_file_postfix() + + +@dataclass(frozen=True) +class RecordFile: + """Base class for record file names.""" + + def __message_styling_wrapper__(self, message: str) -> str: + """Wrap the message in styling.""" + return f"[yellow]Recording session data to {message}[/yellow]" + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + raise NotImplementedError("Subclasses must implement this method") + + +@dataclass(frozen=True) +class ExplorerRecordFile(RecordFile): + """Record file for explorer datasets.""" + + dataset_name: str + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + return self.__message_styling_wrapper__(f"explorer dataset: '{self.dataset_name}'") + + +@dataclass(frozen=True) +class LocalRecordFile(RecordFile): + """Record file for local files.""" + + filename: str + base_path: str = os.path.expanduser("~/.mcp-scan/sessions") + postfix: str = "" + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + return self.__message_styling_wrapper__(f"local file: '{os.path.join(self.base_path, self.filename)}'") + + def get_session_file_path(self, client_name: str | None) -> str: + """Get the path to the session file for a given client.""" + # Use client name as the filename, with .jsonl extension + client_name = client_name or "unknown" + return os.path.join(self.base_path, f"{client_name}-{self.postfix}.jsonl") + + +def parse_record_file_name(record_file: str | None) -> RecordFile | None: + """Parse the record file name and return a RecordFile object.""" + if record_file is None: + return None + + # Check if it has the form explorer:{dataset_name} + if record_file.startswith("explorer:"): + dataset_name = record_file.split(":")[1] + return ExplorerRecordFile(dataset_name) + + # Check that it ends with .json or .jsonl + if not record_file.endswith(".json") and not record_file.endswith(".jsonl"): + raise ValueError(f"Record file must end with .json or .jsonl: {record_file}") + + return LocalRecordFile(record_file, postfix=record_post_fix) + + +async def _push_session_to_explorer( + session_data: list[Message], record_file: ExplorerRecordFile, client_name: str +) -> str | None: + """ + Push the session to the explorer. 
+ """ + if invariant_sdk_client is None: + raise ValueError( + "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." + ) + + try: + response = invariant_sdk_client.create_request_and_push_trace( + messages=[session_data], + dataset=record_file.dataset_name, + metadata=[ + { + "hierarchy_path": [client_name], + } + ], + ) + + trace_id = response.id[0] + trace_client_mapping.set_trace_id(trace_id, client_name) + return trace_id + except Exception as e: + rich.print(f"[bold red]Error pushing session to explorer: {e}[/bold red]") + return None + + +async def _push_session_to_local_file( + session_data: list[Message], record_file: LocalRecordFile, client_name: str +) -> str | None: + """ + Push the session to the local file. + """ + os.makedirs(record_file.base_path, exist_ok=True) + trace_id = str(uuid.uuid4()) + trace_client_mapping.set_trace_id(trace_id, client_name) + + session_file_path = record_file.get_session_file_path(client_name) + + # Write each message as a JSONL line + with open(session_file_path, "a") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + + return trace_id + + +async def _format_session_data(session: Session, index: int) -> list[Message]: + """ + Format the session data for the record file. + Only returns new messages that haven't been recorded yet. + """ + # convert session data to a list of messages, but only include new ones + messages = [] + for i, node in enumerate(session.nodes): + if i > index: + messages.append(node.message) + + return messages + + +async def _append_messages_to_explorer( + trace_id: str, session_data: list[Message], record_file: ExplorerRecordFile +) -> None: + """ + Append messages to the explorer. + """ + if invariant_sdk_client is None: + raise ValueError( + "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." + ) + + invariant_sdk_client.create_request_and_append_messages( + messages=session_data, + trace_id=trace_id, + ) + + +async def _append_messages_to_local_file( + trace_id: str, session_data: list[Message], record_file: LocalRecordFile +) -> None: + """ + Append messages to the local file. + """ + client_name = trace_client_mapping.get_client_name(trace_id) + session_file_path = record_file.get_session_file_path(client_name) + + # Append each message as a new line + with open(session_file_path, "a") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + + +async def push_session_to_record_file(session: Session, record_file: RecordFile, client_name: str) -> str | None: + """ + Push a session to a record file. + + This function may be called multiple times with partially the same data. The behavior is as follows: + - The first time we try to push for a given client, we push the data (either create it for local files or push to explorer) + - The next time we try to push for a given client (for example from a different server) we instead append to the record file. + We monitor which clients have been pushed to the record file by checking the trace client mapping. + + Returns the trace id. 
+ """ + index = session.last_pushed_index + session_data = await _format_session_data(session, index) + + # If there are no new messages, return None + if not session_data: + return None + + # If we have already pushed for this client, append to the record file + if trace_id := trace_client_mapping.get_trace_id(client_name): + await append_messages_to_record_file(trace_id, record_file) + return trace_id + + # Otherwise, push to the record file + if isinstance(record_file, ExplorerRecordFile): + trace_id = await _push_session_to_explorer(session_data, record_file, client_name) + elif isinstance(record_file, LocalRecordFile): + trace_id = await _push_session_to_local_file(session_data, record_file, client_name) + else: + raise ValueError(f"Invalid record file: {record_file}") + + # Update the last pushed index + session.last_pushed_index += len(session_data) + + return trace_id + + +async def append_messages_to_record_file(trace_id: str, record_file: RecordFile) -> None: + """ + Append messages to the record file. + """ + client_name = trace_client_mapping.get_client_name(trace_id) + if client_name is None: + raise ValueError(f"Trace id {trace_id} not found in trace client mapping") + + session = session_store[client_name] + index = session.last_pushed_index + session_data = await _format_session_data(session, index) + + # If there are no new messages, return + if not session_data: + return + + # Otherwise, append to the record file + if isinstance(record_file, ExplorerRecordFile): + await _append_messages_to_explorer(trace_id, session_data, record_file) + elif isinstance(record_file, LocalRecordFile): + await _append_messages_to_local_file(trace_id, session_data, record_file) + else: + raise ValueError(f"Invalid record file: {record_file}") + + # Update the last pushed index + session.last_pushed_index += len(session_data) diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index aef75e4..0b7549a 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -1,15 +1,30 @@ +import json import uuid from fastapi import APIRouter, Request from invariant_sdk.types.push_traces import PushTracesResponse +from ..record_file import push_session_to_record_file +from ..session_store import SessionStore + router = APIRouter() +session_store = SessionStore() + @router.post("/trace") async def push_trace(request: Request) -> PushTracesResponse: """Push a trace. 
For now, this is a dummy response.""" - trace_id = str(uuid.uuid4()) - # return the trace ID - return PushTracesResponse(id=[trace_id], success=True) + body = await request.body() + client = json.loads(body) + mcp_client = client.get("metadata")[0].get("client") + session = session_store[mcp_client] + + record_file = request.app.state.record_file + + # Push the session to the record file if it exists + if trace_id := await push_session_to_record_file(session, record_file, mcp_client): + return PushTracesResponse(id=[trace_id], success=True) + else: + return PushTracesResponse(id=[str(uuid.uuid4())], success=False) diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index 92e73dc..59fde53 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -1,10 +1,20 @@ from fastapi import APIRouter, Request +from ..record_file import append_messages_to_record_file +from ..session_store import SessionStore + router = APIRouter() +session_store = SessionStore() + + @router.post("/{trace_id}/messages") async def append_messages(trace_id: str, request: Request): """Append messages to a trace. For now this is a dummy response.""" + # If we are calling append, we should already have set the trace_id + if request.app.state.record_file: + await append_messages_to_record_file(trace_id, request.app.state.record_file) + return {"success": True} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index d7e5508..62dada7 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -8,6 +8,7 @@ from mcp_scan_server.activity_logger import setup_activity_logger # type: ignore +from .record_file import parse_record_file_name from .routes.policies import router as policies_router # type: ignore from .routes.push import router as push_router from .routes.trace import router as dataset_trace_router @@ -32,6 +33,7 @@ def __init__( on_exit: Callable | None = None, log_level: str = "error", pretty: Literal["oneline", "compact", "full", "none"] = "compact", + record_file: str | None = None, ): self.port = port self.config_file_path = config_file_path @@ -41,6 +43,7 @@ def __init__( self.app = FastAPI(lifespan=self.life_span) self.app.state.config_file_path = config_file_path + self.app.state.record_file = parse_record_file_name(record_file) self.app.include_router(policies_router, prefix="/api/v1") self.app.include_router(push_router, prefix="/api/v1/push") @@ -65,6 +68,9 @@ async def on_startup(self): """Startup event for the FastAPI app.""" rich.print("[bold green]MCP-scan server started (http://localhost:" + str(self.port) + ")[/bold green]") + if self.app.state.record_file is not None: + rich.print(self.app.state.record_file.startup_message()) + # setup activity logger setup_activity_logger(self.app, pretty=self.pretty) diff --git a/src/mcp_scan_server/session_store.py b/src/mcp_scan_server/session_store.py index b9109bc..a44abc2 100644 --- a/src/mcp_scan_server/session_store.py +++ b/src/mcp_scan_server/session_store.py @@ -3,6 +3,8 @@ from datetime import datetime from typing import Any +Message = dict[str, Any] + @dataclass(frozen=True) class SessionNode: @@ -11,7 +13,7 @@ class SessionNode: """ timestamp: datetime - message: dict[str, Any] + message: Message session_id: str server_name: str original_session_index: int @@ -24,6 +26,12 @@ def __lt__(self, other: "SessionNode") -> bool: """Sort by timestamp.""" return self.timestamp < other.timestamp + def to_json(self) -> Message: + """ + Convert the 
session node to a message. + """ + return self.message + class Session: """ @@ -36,6 +44,7 @@ def __init__( ): self.nodes: list[SessionNode] = nodes or [] self.last_analysis_index: int = -1 + self.last_pushed_index: int = -1 def merge(self, other: "Session") -> None: """ @@ -59,15 +68,27 @@ def get_sorted_nodes(self) -> list[SessionNode]: def __repr__(self): return f"Session(nodes={self.get_sorted_nodes()})" + def to_json(self) -> list[Message]: + """ + Convert the session to a list of messages. + """ + return [node.to_json() for node in self.nodes] + class SessionStore: """ Stores sessions by client_name. """ - def __init__(self): - self.sessions: dict[str, Session] = {} + _instance = None + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.sessions = {} + return cls._instance + + @classmethod def _default_session(self) -> Session: return Session() @@ -93,8 +114,20 @@ def fetch_and_merge(self, client_name: str, other: Session) -> Session: session.merge(other) return session + def to_json(self) -> dict[str, dict[str, list[dict[str, Any]]]]: + """ + Convert the sessions to a dictionary. + """ + return {"sessions": {client_name: session.to_json() for client_name, session in self.sessions.items()}} + + def clear(self) -> None: + """ + Clear the session store. + """ + self.sessions: dict[str, Session] = {} + -async def to_session(messages: list[dict[str, Any]], server_name: str, session_id: str) -> Session: +async def to_session(messages: list[Message], server_name: str, session_id: str) -> Session: """ Convert a list of messages to a session. """ diff --git a/tests/unit/test_record_file.py b/tests/unit/test_record_file.py new file mode 100644 index 0000000..0e2b2f3 --- /dev/null +++ b/tests/unit/test_record_file.py @@ -0,0 +1,278 @@ +import datetime +import json +import os +from unittest.mock import ANY, Mock + +import pytest + +from mcp_scan_server.record_file import ( + ExplorerRecordFile, + LocalRecordFile, + TraceClientMapping, + parse_record_file_name, + push_session_to_record_file, +) +from mcp_scan_server.session_store import Session, SessionNode, SessionStore + + +@pytest.fixture(autouse=True) +def cleanup_trace_client_mapping(): + """ + Cleanup the trace client mapping. + """ + trace_client_mapping = TraceClientMapping() + trace_client_mapping.clear() + + +@pytest.fixture(autouse=True) +def cleanup_session_store(): + """ + Cleanup the session store. + """ + session_store = SessionStore() + session_store.clear() + + +@pytest.mark.parametrize("filename", ["explorer:test", "explorer:test.json", "explorer:test.jsonl"]) +def test_parse_record_filename_explorer_valid(filename): + file = parse_record_file_name(filename) + assert file.dataset_name == filename.split(":")[1] + assert isinstance(file, ExplorerRecordFile) + + +@pytest.mark.parametrize("filename", ["test.json", "test.jsonl"]) +def test_parse_record_filename_local_valid(filename): + file = parse_record_file_name(filename) + assert file.filename == filename + assert isinstance(file, LocalRecordFile) + + +@pytest.mark.parametrize("filename", ["test.something", "test"]) +def test_parse_record_filename_local_invalid(filename): + with pytest.raises(ValueError): + parse_record_file_name(filename) + + +def test_trace_client_mapping(): + """ + Test that we can set and get trace ids and client names. 
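+    Unknown trace ids and client names should resolve to None in both directions.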
+ """ + trace_client_mapping = TraceClientMapping() + trace_client_mapping.set_trace_id("trace_id", "client_name") + trace_client_mapping.set_trace_id("trace_id_2", "client_name_2") + + # Test that the trace id and client name are correctly set + assert trace_client_mapping.get_trace_id("client_name") == "trace_id" + assert trace_client_mapping.get_client_name("trace_id") == "client_name" + + # test that non-existent trace ids return None + assert trace_client_mapping.get_trace_id("client_name_3") is None + assert trace_client_mapping.get_client_name("trace_id_3") is None + + # test that non-existent client names return None + assert trace_client_mapping.get_trace_id("client_name_4") is None + assert trace_client_mapping.get_client_name("trace_id_4") is None + + +def test_trace_client_mapping_shares_state(): + """ + Test that we maintain the same state across multiple instances of the TraceClientMapping class. + """ + trace_ids = ["trace_id", "trace_id_2"] + client_names = ["client_name", "client_name_2"] + trace_client_mapping = TraceClientMapping() + + # Populate the first mapping + for trace_id, client_name in zip(trace_ids, client_names, strict=False): + trace_client_mapping.set_trace_id(trace_id, client_name) + + # Create a new mapping and check that the mappings are the same + trace_client_mapping_2 = TraceClientMapping() + for trace_id, client_name in zip(trace_ids, client_names, strict=False): + assert trace_client_mapping_2.get_trace_id(client_name) == trace_id + assert trace_client_mapping_2.get_client_name(trace_id) == client_name + + +def _setup_test_session_and_session_store(): + """Helper function to set up a test session.""" + # Create a session with a message + session = Session( + nodes=[ + SessionNode( + timestamp=datetime.datetime.now(), + message={"role": "user", "content": "test"}, + session_id="session_id", + server_name="server_name", + original_session_index=0, + ) + ] + ) + + # Add the session to the session store + session_store = SessionStore() + session_store["client_name"] = session + + return session_store, session + + +def _setup_test_session_and_mock(monkeypatch, mock_trace_id="test_trace_id"): + """Helper function to set up test session and mock client.""" + # Create a mock for the invariant SDK client + mock_client = Mock() + mock_client.create_request_and_push_trace.return_value = Mock(id=[mock_trace_id]) + + # Patch the invariant SDK client + monkeypatch.setattr("mcp_scan_server.record_file.invariant_sdk_client", mock_client) + + # Create the trace client mapping + trace_client_mapping = TraceClientMapping() + + return mock_client, trace_client_mapping + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_explorer_first_time_calls_create_request_and_push_trace(monkeypatch): + """ + Test that we call create_request_and_push_trace when pushing a session to the record file for the first time. 
+    """
+    _, session = _setup_test_session_and_session_store()
+    mock_client, trace_client_mapping = _setup_test_session_and_mock(monkeypatch)
+    mock_trace_id = "test_trace_id"
+
+    # Push the session to the record file
+    trace_id = await push_session_to_record_file(session, ExplorerRecordFile("test"), "client_name")
+
+    # Check that the trace id is set
+    assert trace_id == mock_trace_id
+
+    # Check that the trace id is in the trace client mapping
+    assert trace_client_mapping.get_trace_id("client_name") == mock_trace_id
+
+    # Check that the Invariant SDK client was called with the correct arguments
+    mock_client.create_request_and_push_trace.assert_called_once_with(
+        messages=[[{"role": "user", "content": "test"}]],
+        dataset="test",
+        metadata=ANY,  # Ignore metadata
+    )
+
+
+@pytest.mark.asyncio
+async def test_push_session_to_record_file_explorer_second_time_calls_append_messages(monkeypatch):
+    """
+    Test that we call append_messages when pushing a session to the record file for the second time.
+    """
+    _, session = _setup_test_session_and_session_store()
+    mock_client, trace_client_mapping = _setup_test_session_and_mock(monkeypatch)
+    mock_trace_id = "test_trace_id"
+    message = {"role": "assistant", "content": "response"}
+
+    # First push to set up the trace ID
+    await push_session_to_record_file(session, ExplorerRecordFile("test"), "client_name")
+
+    # Add a new message to the session
+    session.nodes.append(
+        SessionNode(
+            timestamp=datetime.datetime.now(),
+            message=message,
+            session_id="session_id",
+            server_name="server_name",
+            original_session_index=1,
+        )
+    )
+
+    # Second push should append
+    trace_id = await push_session_to_record_file(session, ExplorerRecordFile("test"), "client_name")
+
+    # Check that we got the same trace ID back
+    assert trace_id == mock_trace_id
+
+    # Check that the trace id is in the trace client mapping
+    assert trace_client_mapping.get_trace_id("client_name") == mock_trace_id
+
+    # Check that create_request_and_push_trace was only called once (from the first push)
+    mock_client.create_request_and_push_trace.assert_called_once()
+
+    # Check that create_request_and_append_messages was called with the correct arguments
+    mock_client.create_request_and_append_messages.assert_called_once_with(
+        messages=[message],
+        trace_id=mock_trace_id,
+    )
+
+
+def _setup_test_session_and_local_file(tmp_path):
+    """Helper function to set up the trace client mapping and a local record file."""
+
+    # Create the trace client mapping
+    trace_client_mapping = TraceClientMapping()
+
+    # Create a LocalRecordFile with the temp path
+    record_file = LocalRecordFile("test.jsonl", base_path=str(tmp_path))
+
+    return trace_client_mapping, record_file
+
+
+@pytest.mark.asyncio
+async def test_push_session_to_record_file_local_creates_file_and_writes_to_it(tmp_path):
+    """
+    Test that we create a file and write to it when pushing a session to the record file for the first time.
+    Also check that the file path is derived correctly from the LocalRecordFile configuration.
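+
+    The local record file is JSONL: every message in the session is written as
+    one JSON object per line, which is why the assertion below compares the
+    file content against json.dumps of the message.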
+ """ + _, session = _setup_test_session_and_session_store() + trace_client_mapping, record_file = _setup_test_session_and_local_file(tmp_path) + + # Push the session to the record file + trace_id = await push_session_to_record_file(session, record_file, "client_name") + + # Check that the trace id is set + assert trace_id is not None + assert trace_client_mapping.get_trace_id("client_name") == trace_id + + # Check that the file was created with the correct path + expected_path = record_file.get_session_file_path("client_name") + assert os.path.exists(expected_path) + + # Check that the file contains the correct content + with open(expected_path) as f: + content = f.read().strip() + assert content == json.dumps({"role": "user", "content": "test"}) + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_local_second_time_appends_to_file(tmp_path): + """ + Test that we append to the file when pushing a session to the record file for the second time. + """ + _, session = _setup_test_session_and_session_store() + trace_client_mapping, record_file = _setup_test_session_and_local_file(tmp_path) + + # First push to set up the trace ID + trace_id = await push_session_to_record_file(session, record_file, "client_name") + + # Add a new message to the session + message = {"role": "assistant", "content": "response"} + session.nodes.append( + SessionNode( + timestamp=datetime.datetime.now(), + message=message, + session_id="session_id", + server_name="server_name", + original_session_index=1, + ) + ) + + # Second push should append + new_trace_id = await push_session_to_record_file(session, record_file, "client_name") + + # Check that we got the same trace ID back + assert new_trace_id == trace_id + + # Check that the trace id is in the trace client mapping + assert trace_client_mapping.get_trace_id("client_name") == new_trace_id + + # Check that the file contains both messages + expected_path = record_file.get_session_file_path("client_name") + with open(expected_path) as f: + lines = f.readlines() + assert len(lines) == 2 + assert json.loads(lines[0].strip()) == {"role": "user", "content": "test"} + assert json.loads(lines[1].strip()) == message diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 5e574f9..77fa0a4 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1,4 +1,5 @@ import datetime +import json import pytest @@ -9,6 +10,18 @@ def create_timestamped_node(timestamp: datetime.datetime): return SessionNode(timestamp=timestamp, message={}, session_id="", server_name="", original_session_index=0) +# create a cleanup function to delete the session store and make it run after each test +def cleanup_session_store(): + session_store = SessionStore() + session_store.clear() + + +@pytest.fixture(autouse=True) +def cleanup_session_store_after_test(): + yield + cleanup_session_store() + + @pytest.fixture def some_date(): return datetime.datetime(2021, 1, 1, 12, 0, 0) @@ -169,3 +182,85 @@ async def test_to_session_function(): original_session_index=1, ), ] + + +def test_session_node_to_json(): + session_node = SessionNode( + timestamp=datetime.datetime.now(), + message={"role": "user", "content": "Hello, world!"}, + session_id="session_id", + server_name="server_name", + original_session_index=0, + ) + + session = Session(nodes=[session_node]) + + session_store = SessionStore() + session_store["client_name"] = session + + session_store_json = session_store.to_json() + assert session_store_json is not None + + +def test_json_serialization(): + 
"""Test that the JSON serialization works correctly for all classes.""" + timestamp = datetime.datetime(2024, 1, 1, 12, 0, 0) + message = {"role": "user", "content": "Hello, world!"} + + # Create a session node + node = SessionNode( + timestamp=timestamp, + message=message, + session_id="test_session", + server_name="test_server", + original_session_index=0, + ) + + # Test SessionNode JSON serialization + node_dict = node.to_json() + assert node_dict == message + + # Create a session with the node + session = Session(nodes=[node]) + + # Test Session JSON serialization + session_dict = session.to_json() + assert session_dict == [message] + + # Create a session store with the session + store = SessionStore() + store["test_client"] = session + + # Test SessionStore JSON serialization + store_json = store.to_json() + assert store_json == {"sessions": {"test_client": [message]}} + + # Finally test that we can dump and load + store_json_str = json.dumps(store_json) + assert store_json_str is not None + + store_dict = json.loads(store_json_str) + assert store_dict == {"sessions": {"test_client": [message]}} + + +def test_session_store_shares_state(): + """Test that the session store shares state between multiple instances.""" + + # Create and populate the session store + session_store = SessionStore() + session_store["client_name_1"] = Session( + nodes=[ + SessionNode( + timestamp=datetime.datetime.now(), + message={"role": "user", "content": "Hello, world!"}, + session_id="session_id", + server_name="server_name", + original_session_index=0, + ) + ] + ) + + # Create new session store and check that the session is shared + session_store_2 = SessionStore() + assert session_store_2["client_name_1"] is not None + assert session_store_2["client_name_1"].nodes == session_store["client_name_1"].nodes From a7a3734f5fca8584192c6f2fec99cf463375a792 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 27 May 2025 17:16:28 +0100 Subject: [PATCH 5/9] add annotations to explorer push --- src/mcp_scan_server/record_file.py | 28 +++++++++++++++++++++------- src/mcp_scan_server/routes/push.py | 6 +++--- src/mcp_scan_server/routes/trace.py | 11 ++++++++++- tests/unit/test_record_file.py | 2 ++ 4 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/mcp_scan_server/record_file.py b/src/mcp_scan_server/record_file.py index 2c9a63a..ff9cad0 100644 --- a/src/mcp_scan_server/record_file.py +++ b/src/mcp_scan_server/record_file.py @@ -2,6 +2,7 @@ import os import uuid from dataclasses import dataclass +from typing import Any import rich from invariant_sdk.client import Client @@ -134,7 +135,10 @@ def parse_record_file_name(record_file: str | None) -> RecordFile | None: async def _push_session_to_explorer( - session_data: list[Message], record_file: ExplorerRecordFile, client_name: str + session_data: list[Message], + record_file: ExplorerRecordFile, + client_name: str, + annotations: dict[str, Any] | None = None, ) -> str | None: """ Push the session to the explorer. 
@@ -148,6 +152,7 @@ async def _push_session_to_explorer( response = invariant_sdk_client.create_request_and_push_trace( messages=[session_data], dataset=record_file.dataset_name, + annotations=annotations, metadata=[ { "hierarchy_path": [client_name], @@ -164,7 +169,10 @@ async def _push_session_to_explorer( async def _push_session_to_local_file( - session_data: list[Message], record_file: LocalRecordFile, client_name: str + session_data: list[Message], + record_file: LocalRecordFile, + client_name: str, + annotations: dict[str, Any] | None = None, ) -> str | None: """ Push the session to the local file. @@ -198,7 +206,10 @@ async def _format_session_data(session: Session, index: int) -> list[Message]: async def _append_messages_to_explorer( - trace_id: str, session_data: list[Message], record_file: ExplorerRecordFile + trace_id: str, + session_data: list[Message], + record_file: ExplorerRecordFile, + annotations: dict[str, Any] | None = None, ) -> None: """ Append messages to the explorer. @@ -211,11 +222,12 @@ async def _append_messages_to_explorer( invariant_sdk_client.create_request_and_append_messages( messages=session_data, trace_id=trace_id, + annotations=annotations, ) async def _append_messages_to_local_file( - trace_id: str, session_data: list[Message], record_file: LocalRecordFile + trace_id: str, session_data: list[Message], record_file: LocalRecordFile, annotations: dict[str, Any] | None = None ) -> None: """ Append messages to the local file. @@ -266,7 +278,9 @@ async def push_session_to_record_file(session: Session, record_file: RecordFile, return trace_id -async def append_messages_to_record_file(trace_id: str, record_file: RecordFile) -> None: +async def append_messages_to_record_file( + trace_id: str, record_file: RecordFile, annotations: dict[str, Any] | None = None +) -> None: """ Append messages to the record file. """ @@ -284,9 +298,9 @@ async def append_messages_to_record_file(trace_id: str, record_file: RecordFile) # Otherwise, append to the record file if isinstance(record_file, ExplorerRecordFile): - await _append_messages_to_explorer(trace_id, session_data, record_file) + await _append_messages_to_explorer(trace_id, session_data, record_file, annotations) elif isinstance(record_file, LocalRecordFile): - await _append_messages_to_local_file(trace_id, session_data, record_file) + await _append_messages_to_local_file(trace_id, session_data, record_file, annotations) else: raise ValueError(f"Invalid record file: {record_file}") diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index 0b7549a..a46a05a 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -16,9 +16,9 @@ async def push_trace(request: Request) -> PushTracesResponse: """Push a trace. 
For now, this is a dummy response.""" - body = await request.body() - client = json.loads(body) - mcp_client = client.get("metadata")[0].get("client") + request_data = await request.body() + request_data = json.loads(request_data) + mcp_client = request_data.get("metadata")[0].get("client") session = session_store[mcp_client] record_file = request.app.state.record_file diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index 59fde53..66087f1 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -1,3 +1,5 @@ +import json + from fastapi import APIRouter, Request from ..record_file import append_messages_to_record_file @@ -13,8 +15,15 @@ async def append_messages(trace_id: str, request: Request): """Append messages to a trace. For now this is a dummy response.""" + request_data = await request.body() + request_data = json.loads(request_data) + # If we are calling append, we should already have set the trace_id if request.app.state.record_file: - await append_messages_to_record_file(trace_id, request.app.state.record_file) + await append_messages_to_record_file( + trace_id, + request.app.state.record_file, + annotations=request_data.get("annotations"), + ) return {"success": True} diff --git a/tests/unit/test_record_file.py b/tests/unit/test_record_file.py index 0e2b2f3..cd80137 100644 --- a/tests/unit/test_record_file.py +++ b/tests/unit/test_record_file.py @@ -153,6 +153,7 @@ async def test_push_session_to_record_file_explorer_first_time_calls_create_requ messages=[[{"role": "user", "content": "test"}]], dataset="test", metadata=ANY, # Ignore metadata + annotations=ANY, ) @@ -196,6 +197,7 @@ async def test_push_session_to_record_file_explorer_second_time_calls_append_mes mock_client.create_request_and_append_messages.assert_called_once_with( messages=[message], trace_id=mock_trace_id, + annotations=ANY, ) From 875cc5720efa35cd9b3163656aec70aea9317dc3 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 3 Jun 2025 14:04:32 +0200 Subject: [PATCH 6/9] refactor --- src/mcp_scan_server/record_file.py | 202 +++++++++++++---------------- tests/unit/test_record_file.py | 26 ++-- 2 files changed, 105 insertions(+), 123 deletions(-) diff --git a/src/mcp_scan_server/record_file.py b/src/mcp_scan_server/record_file.py index 5d311eb..436c993 100644 --- a/src/mcp_scan_server/record_file.py +++ b/src/mcp_scan_server/record_file.py @@ -65,9 +65,6 @@ def generate_record_file_postfix() -> str: return str(uuid.uuid4())[:8] -# Initialize the invariant sdk client if the API key is set -invariant_sdk_client = Client() if os.environ.get("INVARIANT_API_KEY") else None - # Initialize the trace client mapping and session store trace_client_mapping = TraceClientMapping() @@ -84,24 +81,79 @@ def startup_message(self) -> str: """Return a message to be printed on startup.""" raise NotImplementedError("Subclasses must implement this method") + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + raise NotImplementedError("Subclasses must implement this method") + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + raise NotImplementedError("Subclasses must implement this method") + @dataclass(frozen=True) class ExplorerRecordFile(RecordFile): """Record file for explorer datasets.""" dataset_name: str + client: Client def startup_message(self) -> str: """Return a message to be printed on startup.""" return 
self.__message_styling_wrapper__(f"explorer dataset: '{self.dataset_name}'") + def _check_is_key_set(self) -> None: + """Check if the invariant API key is set.""" + if not os.environ.get("INVARIANT_API_KEY"): + raise ValueError( + "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." + ) + + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + """Push the session to the explorer.""" + self._check_is_key_set() + + try: + response = self.client.create_request_and_push_trace( + messages=[session_data], + dataset=self.dataset_name, + annotations=annotations, + metadata=[ + { + "hierarchy_path": [client_name], + } + ], + ) + except Exception as e: + rich.print(f"[bold red]Error pushing session to explorer: {e}[/bold red]") + return None + + trace_id = response.id[0] + trace_client_mapping.set_trace_id(trace_id, client_name) + return trace_id + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + """Append the session data to the explorer.""" + self._check_is_key_set() + + try: + self.client.create_request_and_append_messages( + messages=session_data, + trace_id=trace_id, + annotations=annotations, + ) + except Exception as e: + rich.print(f"[bold red]Error appending messages to explorer: {e}[/bold red]") + return None + @dataclass(frozen=True) class LocalRecordFile(RecordFile): """Record file for local files.""" directory_name: str - base_path: str = os.path.expanduser("~/.mcp-scan/sessions/") + base_path: str postfix: str = str(uuid.uuid4())[:8] def startup_message(self) -> str: @@ -122,8 +174,35 @@ def get_session_file_path(self, client_name: str | None) -> str: client_name = client_name or "unknown" return self.get_filepath(client_name) + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + """Push the session to the local file.""" + os.makedirs(self.get_directory(), exist_ok=True) + trace_id = str(uuid.uuid4()) + trace_client_mapping.set_trace_id(trace_id, client_name) + + session_file_path = self.get_session_file_path(client_name) + + # Write each message as a JSONL line + with open(session_file_path, "w") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + + return trace_id + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + """Append the session data to the local file.""" + client_name = trace_client_mapping.get_client_name(trace_id) + session_file_path = self.get_session_file_path(client_name) + + # Append each message as a new line + with open(session_file_path, "a") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + -def parse_record_file_name(record_file: str | None) -> RecordFile | None: +def parse_record_file_name(record_file: str | None, base_path: str | None = None) -> RecordFile | None: """Parse the record file name and return a RecordFile object.""" if record_file is None: return None @@ -131,12 +210,14 @@ def parse_record_file_name(record_file: str | None) -> RecordFile | None: # Check if it has the form explorer:{dataset_name} if record_file.startswith("explorer:"): dataset_name = record_file.split(":")[1] - return ExplorerRecordFile(dataset_name) + return ExplorerRecordFile(dataset_name, Client()) # Check if it has the form local:{directory_name} elif record_file.startswith("local:"): directory_name = 
record_file.split(":")[1] - file_object = LocalRecordFile(directory_name) + if base_path is None: + base_path = os.path.expanduser("~/.mcp-scan/sessions/") + file_object = LocalRecordFile(directory_name, base_path=base_path) os.makedirs(file_object.get_directory(), exist_ok=True) return file_object @@ -147,63 +228,6 @@ def parse_record_file_name(record_file: str | None) -> RecordFile | None: ) -async def _push_session_to_explorer( - session_data: list[Message], - record_file: ExplorerRecordFile, - client_name: str, - annotations: dict[str, Any] | None = None, -) -> str | None: - """ - Push the session to the explorer. - """ - if invariant_sdk_client is None: - raise ValueError( - "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." - ) - - try: - response = invariant_sdk_client.create_request_and_push_trace( - messages=[session_data], - dataset=record_file.dataset_name, - annotations=annotations, - metadata=[ - { - "hierarchy_path": [client_name], - } - ], - ) - - trace_id = response.id[0] - trace_client_mapping.set_trace_id(trace_id, client_name) - return trace_id - except Exception as e: - rich.print(f"[bold red]Error pushing session to explorer: {e}[/bold red]") - return None - - -async def _push_session_to_local_file( - session_data: list[Message], - record_file: LocalRecordFile, - client_name: str, - annotations: dict[str, Any] | None = None, -) -> str | None: - """ - Push the session to the local file. - """ - os.makedirs(record_file.base_path, exist_ok=True) - trace_id = str(uuid.uuid4()) - trace_client_mapping.set_trace_id(trace_id, client_name) - - session_file_path = record_file.get_session_file_path(client_name) - - # Write each message as a JSONL line - with open(session_file_path, "w") as f: - for message in session_data: - f.write(json.dumps(message) + "\n") - - return trace_id - - async def _format_session_data(session: Session, index: int) -> list[Message]: """ Format the session data for the record file. @@ -218,42 +242,6 @@ async def _format_session_data(session: Session, index: int) -> list[Message]: return messages -async def _append_messages_to_explorer( - trace_id: str, - session_data: list[Message], - record_file: ExplorerRecordFile, - annotations: dict[str, Any] | None = None, -) -> None: - """ - Append messages to the explorer. - """ - if invariant_sdk_client is None: - raise ValueError( - "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." - ) - - invariant_sdk_client.create_request_and_append_messages( - messages=session_data, - trace_id=trace_id, - annotations=annotations, - ) - - -async def _append_messages_to_local_file( - trace_id: str, session_data: list[Message], record_file: LocalRecordFile, annotations: dict[str, Any] | None = None -) -> None: - """ - Append messages to the local file. 
- """ - client_name = trace_client_mapping.get_client_name(trace_id) - session_file_path = record_file.get_session_file_path(client_name) - - # Append each message as a new line - with open(session_file_path, "a") as f: - for message in session_data: - f.write(json.dumps(message) + "\n") - - async def push_session_to_record_file( session: Session, record_file: RecordFile, client_name: str, session_store: SessionStore ) -> str | None: @@ -280,12 +268,7 @@ async def push_session_to_record_file( return trace_id # Otherwise, push to the record file - if isinstance(record_file, ExplorerRecordFile): - trace_id = await _push_session_to_explorer(session_data, record_file, client_name) - elif isinstance(record_file, LocalRecordFile): - trace_id = await _push_session_to_local_file(session_data, record_file, client_name) - else: - raise ValueError(f"Invalid record file: {record_file}") + trace_id = record_file.push(session_data, client_name) # Update the last pushed index session.last_pushed_index += len(session_data) @@ -316,12 +299,7 @@ async def append_messages_to_record_file( return # Otherwise, append to the record file - if isinstance(record_file, ExplorerRecordFile): - await _append_messages_to_explorer(trace_id, session_data, record_file, annotations) - elif isinstance(record_file, LocalRecordFile): - await _append_messages_to_local_file(trace_id, session_data, record_file, annotations) - else: - raise ValueError(f"Invalid record file: {record_file}") + record_file.append(session_data, trace_id, annotations) # Update the last pushed index session.last_pushed_index += len(session_data) diff --git a/tests/unit/test_record_file.py b/tests/unit/test_record_file.py index bbe9e09..36a047c 100644 --- a/tests/unit/test_record_file.py +++ b/tests/unit/test_record_file.py @@ -25,7 +25,8 @@ def cleanup_trace_client_mapping(): @pytest.mark.parametrize("filename", ["explorer:test", "explorer:test", "explorer:test"]) -def test_parse_record_filename_explorer_valid(filename): +def test_parse_record_filename_explorer_valid(filename, monkeypatch): + monkeypatch.setattr("mcp_scan_server.record_file.Client", Mock()) file = parse_record_file_name(filename) assert file.dataset_name == filename.split(":")[1] assert isinstance(file, ExplorerRecordFile) @@ -111,8 +112,8 @@ def _setup_test_session_and_mock(monkeypatch, mock_trace_id="test_trace_id"): mock_client = Mock() mock_client.create_request_and_push_trace.return_value = Mock(id=[mock_trace_id]) - # Patch the invariant SDK client - monkeypatch.setattr("mcp_scan_server.record_file.invariant_sdk_client", mock_client) + monkeypatch.setattr("invariant_sdk.client.Client", mock_client) + monkeypatch.setattr("mcp_scan_server.record_file.ExplorerRecordFile._check_is_key_set", Mock()) # Create the trace client mapping trace_client_mapping = TraceClientMapping() @@ -130,7 +131,9 @@ async def test_push_session_to_record_file_explorer_first_time_calls_create_requ mock_trace_id = "test_trace_id" # Push the session to the record file - trace_id = await push_session_to_record_file(session, ExplorerRecordFile("test"), "client_name", session_store) + trace_id = await push_session_to_record_file( + session, ExplorerRecordFile("test", mock_client), "client_name", session_store + ) # Check that the trace id is set assert trace_id == mock_trace_id @@ -158,7 +161,7 @@ async def test_push_session_to_record_file_explorer_second_time_calls_append_mes message = {"role": "assistant", "content": "response"} # First push to set up the trace ID - await push_session_to_record_file(session, 
ExplorerRecordFile("test"), "client_name", session_store) + await push_session_to_record_file(session, ExplorerRecordFile("test", mock_client), "client_name", session_store) # Add a new message to the session session.nodes.append( @@ -171,7 +174,9 @@ async def test_push_session_to_record_file_explorer_second_time_calls_append_mes ) ) # Second push should append - trace_id = await push_session_to_record_file(session, ExplorerRecordFile("test"), "client_name", session_store) + trace_id = await push_session_to_record_file( + session, ExplorerRecordFile("test", mock_client), "client_name", session_store + ) # Check that we got the same trace ID back assert trace_id == mock_trace_id @@ -195,9 +200,8 @@ def _setup_test_session_and_local_file(tmp_path): # Create the trace client mapping trace_client_mapping = TraceClientMapping() - print(tmp_path) - record_file = parse_record_file_name("local:test") + record_file = parse_record_file_name("local:test", base_path=str(tmp_path)) assert isinstance(record_file, LocalRecordFile), "Should have LocalRecordFile" return trace_client_mapping, record_file @@ -264,7 +268,7 @@ async def test_push_session_to_record_file_local_second_time_appends_to_file(tmp # Check that the file contains both messages expected_path = record_file.get_session_file_path("client_name") with open(expected_path) as f: - lines = f.readlines() + lines = [line.strip() for line in f.readlines() if line.strip()] # Filter out empty lines assert len(lines) == 2 - assert json.loads(lines[0].strip()) == {"role": "user", "content": "test"} - assert json.loads(lines[1].strip()) == message + assert json.loads(lines[0]) == {"role": "user", "content": "test"} + assert json.loads(lines[1]) == message From 1afb355f6ce66b677190e4bb8674a41a98e5b0a8 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 12 Aug 2025 16:38:40 +0200 Subject: [PATCH 7/9] debug: verbose ci --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 3c7bb6c..b999e8e 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ test: ci: export MCP_SCAN_ENVIRONMENT=ci uv pip install -e .[test] - uv run pytest + uv run pytest -vv clean: rm -rf ./dist From 8eb8889d0ad04c6a273a6c4f6c7628e845c203c4 Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 12 Aug 2025 16:45:35 +0200 Subject: [PATCH 8/9] debug: add more logging --- tests/e2e/test_full_proxy_flow.py | 92 +++++++++++++++++++++++++------ 1 file changed, 75 insertions(+), 17 deletions(-) diff --git a/tests/e2e/test_full_proxy_flow.py b/tests/e2e/test_full_proxy_flow.py index b4b5ae0..ae02ce5 100644 --- a/tests/e2e/test_full_proxy_flow.py +++ b/tests/e2e/test_full_proxy_flow.py @@ -25,23 +25,31 @@ def safe_decode(bytes_output, encoding="utf-8", errors="replace"): async def run_toy_server_client(config): - async with get_client(config) as (read, write): - async with ClientSession(read, write) as session: - print("[Client] Initializing connection") - await session.initialize() - print("[Client] Listing tools") - tools = await session.list_tools() - print("[Client] Tools: ", tools.tools) - - print("[Client] Calling tool add") - result = await session.call_tool("add", arguments={"a": 1, "b": 2}) - result = result.content[0].text - print("[Client] Result: ", result) - - return { - "result": result, - "tools": tools.tools, - } + try: + async with get_client(config) as (read, write): + async with ClientSession(read, write) as session: + print("[Client] Initializing connection") + await session.initialize() + print("[Client] Listing tools") + 
tools = await session.list_tools() + print("[Client] Tools: ", tools.tools) + + print("[Client] Calling tool add") + result = await session.call_tool("add", arguments={"a": 1, "b": 2}) + result = result.content[0].text + print("[Client] Result: ", result) + + return { + "result": result, + "tools": tools.tools, + } + except Exception as e: + print(f"[Client] Error during MCP communication: {e}") + print(f"[Client] Error type: {type(e)}") + import traceback + + print(f"[Client] Traceback: {traceback.format_exc()}") + raise return result @@ -151,16 +159,54 @@ async def test_basic(self, toy_server_add_config_file, pretty): + safe_decode(stderr) ) + # Debug: Check process status and get any output so far + print(f"[DEBUG] Process status: {process.poll()}") + + # Add a small delay to let the server start up properly + print("[DEBUG] Waiting for server to start up...") + await asyncio.sleep(2) + + # Check if process is still running + if process.poll() is not None: + stdout, stderr = process.communicate() + print("[DEBUG] Process terminated unexpectedly") + print(f"[DEBUG] stdout: {safe_decode(stdout)}") + print(f"[DEBUG] stderr: {safe_decode(stderr)}") + raise AssertionError("Process terminated unexpectedly during startup") + with open(toy_server_add_config_file) as f: # assert that 'invariant-gateway' is in the file content = f.read() print(content) + # Debug: Try to check if the server is actually responding + print("[DEBUG] Checking if server is responding...") + try: + import aiohttp + + async with aiohttp.ClientSession() as session: + async with session.get(f"http://localhost:{self.PORT}/") as response: + print(f"[DEBUG] Server health check response: {response.status}") + if response.status == 200: + print("[DEBUG] Server is responding to HTTP requests") + else: + print(f"[DEBUG] Server returned unexpected status: {response.status}") + except Exception as e: + print(f"[DEBUG] Server health check failed: {e}") + print("[DEBUG] This might be expected if server is still starting up") + # start client config = await scan_mcp_config_file(toy_server_add_config_file) servers = list(config.mcpServers.values()) assert len(servers) == 1 server = servers[0] + + # Debug: Print server config before starting client + print(f"[DEBUG] Server config: {server}") + print(f"[DEBUG] Server command: {server.command}") + print(f"[DEBUG] Server args: {server.args}") + print(f"[DEBUG] Server env: {server.env}") + client_program = run_toy_server_client(server) # wait for client to finish @@ -174,6 +220,18 @@ async def test_basic(self, toy_server_add_config_file, pretty): print(safe_decode(stdout)) print(safe_decode(stderr)) raise AssertionError("timed out waiting for MCP server to respond") from e + except Exception as e: + print(f"Client failed with exception: {e}") + print(f"Exception type: {type(e)}") + # Get any output from the process before terminating + stdout, stderr = process.communicate() + print("=== PROCESS OUTPUT ===") + print(safe_decode(stdout)) + print("=== PROCESS ERROR ===") + print(safe_decode(stderr)) + process.terminate() + process.wait() + raise assert int(client_output["result"]) == 3 From f0934d363db9c3f94388d04da0447cfc1aafb39e Mon Sep 17 00:00:00 2001 From: knielsen404 Date: Tue, 12 Aug 2025 16:48:37 +0200 Subject: [PATCH 9/9] debug: revert change --- tests/e2e/test_full_proxy_flow.py | 73 +++++++------------------------ 1 file changed, 17 insertions(+), 56 deletions(-) diff --git a/tests/e2e/test_full_proxy_flow.py b/tests/e2e/test_full_proxy_flow.py index ae02ce5..bdf1be5 100644 --- 
a/tests/e2e/test_full_proxy_flow.py +++ b/tests/e2e/test_full_proxy_flow.py @@ -25,31 +25,23 @@ def safe_decode(bytes_output, encoding="utf-8", errors="replace"): async def run_toy_server_client(config): - try: - async with get_client(config) as (read, write): - async with ClientSession(read, write) as session: - print("[Client] Initializing connection") - await session.initialize() - print("[Client] Listing tools") - tools = await session.list_tools() - print("[Client] Tools: ", tools.tools) - - print("[Client] Calling tool add") - result = await session.call_tool("add", arguments={"a": 1, "b": 2}) - result = result.content[0].text - print("[Client] Result: ", result) - - return { - "result": result, - "tools": tools.tools, - } - except Exception as e: - print(f"[Client] Error during MCP communication: {e}") - print(f"[Client] Error type: {type(e)}") - import traceback - - print(f"[Client] Traceback: {traceback.format_exc()}") - raise + async with get_client(config) as (read, write): + async with ClientSession(read, write) as session: + print("[Client] Initializing connection") + await session.initialize() + print("[Client] Listing tools") + tools = await session.list_tools() + print("[Client] Tools: ", tools.tools) + + print("[Client] Calling tool add") + result = await session.call_tool("add", arguments={"a": 1, "b": 2}) + result = result.content[0].text + print("[Client] Result: ", result) + + return { + "result": result, + "tools": tools.tools, + } return result @@ -159,42 +151,11 @@ async def test_basic(self, toy_server_add_config_file, pretty): + safe_decode(stderr) ) - # Debug: Check process status and get any output so far - print(f"[DEBUG] Process status: {process.poll()}") - - # Add a small delay to let the server start up properly - print("[DEBUG] Waiting for server to start up...") - await asyncio.sleep(2) - - # Check if process is still running - if process.poll() is not None: - stdout, stderr = process.communicate() - print("[DEBUG] Process terminated unexpectedly") - print(f"[DEBUG] stdout: {safe_decode(stdout)}") - print(f"[DEBUG] stderr: {safe_decode(stderr)}") - raise AssertionError("Process terminated unexpectedly during startup") - with open(toy_server_add_config_file) as f: # assert that 'invariant-gateway' is in the file content = f.read() print(content) - # Debug: Try to check if the server is actually responding - print("[DEBUG] Checking if server is responding...") - try: - import aiohttp - - async with aiohttp.ClientSession() as session: - async with session.get(f"http://localhost:{self.PORT}/") as response: - print(f"[DEBUG] Server health check response: {response.status}") - if response.status == 200: - print("[DEBUG] Server is responding to HTTP requests") - else: - print(f"[DEBUG] Server returned unexpected status: {response.status}") - except Exception as e: - print(f"[DEBUG] Server health check failed: {e}") - print("[DEBUG] This might be expected if server is still starting up") - # start client config = await scan_mcp_config_file(toy_server_add_config_file) servers = list(config.mcpServers.values())
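
With the debugging output reverted, the e2e client is back to its original
shape and failures once again surface through the test's timeout handling. A
minimal sketch of that guard, assuming the client coroutine is awaited with
asyncio.wait_for (the 30-second budget is illustrative, not the value used in
the suite):

    import asyncio

    async def wait_for_client(client_program):
        try:
            return await asyncio.wait_for(client_program, timeout=30)
        except asyncio.TimeoutError as e:
            raise AssertionError("timed out waiting for MCP server to respond") from e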