diff --git a/Makefile b/Makefile index 3c7bb6c..b999e8e 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ test: ci: export MCP_SCAN_ENVIRONMENT=ci uv pip install -e .[test] - uv run pytest + uv run pytest -vv clean: rm -rf ./dist diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/mcp_scan/cli.py b/src/mcp_scan/cli.py index 275dd51..e86d030 100644 --- a/src/mcp_scan/cli.py +++ b/src/mcp_scan/cli.py @@ -126,7 +126,26 @@ def add_server_arguments(parser): choices=["oneline", "compact", "full", "none"], help="Pretty print the output (default: compact)", ) - server_group.add_argument( + + +def add_mcp_scan_server_arguments(parser): + """Add arguments related to MCP scan server.""" + mcp_scan_server_group = parser.add_argument_group("MCP Scan Server Options") + mcp_scan_server_group.add_argument( + "--port", + type=int, + default=8129, + help="Port to run the server on (default: 8129).", + metavar="PORT", + ) + mcp_scan_server_group.add_argument( + "--record", + type=str, + default=None, + help="Filename to record the proxy requests to.", + metavar="RECORD_FILE", + ) + mcp_scan_server_group.add_argument( "--install-extras", nargs="+", default=None, @@ -134,6 +153,8 @@ def add_server_arguments(parser): metavar="EXTRA", ) + return mcp_scan_server_group + def add_install_arguments(parser): parser.add_argument( @@ -374,28 +395,16 @@ def main(): # SERVER command server_parser = subparsers.add_parser("server", help="Start the MCP scan server") - server_parser.add_argument( - "--port", - type=int, - default=8129, - help="Port to run the server on (default: 8129)", - metavar="PORT", - ) add_common_arguments(server_parser) add_server_arguments(server_parser) + add_mcp_scan_server_arguments(server_parser) # PROXY command proxy_parser = subparsers.add_parser("proxy", help="Installs and proxies MCP requests, uninstalls on exit") - proxy_parser.add_argument( - "--port", - type=int, - default=8129, - help="Port to run the server on (default: 8129)", - metavar="PORT", - ) add_common_arguments(proxy_parser) add_server_arguments(proxy_parser) add_install_arguments(proxy_parser) + add_mcp_scan_server_arguments(proxy_parser) # Parse arguments (default to 'scan' if no command provided) if len(sys.argv) == 1 or sys.argv[1] not in subparsers.choices: @@ -438,7 +447,11 @@ def server(on_exit=None): sf = StorageFile(args.storage_file) guardrails_config_path = sf.create_guardrails_config() mcp_scan_server = MCPScanServer( - port=args.port, config_file_path=guardrails_config_path, on_exit=on_exit, pretty=args.pretty + port=args.port, + config_file_path=guardrails_config_path, + on_exit=on_exit, + pretty=args.pretty, + record_file=args.record, ) mcp_scan_server.run() diff --git a/src/mcp_scan_server/record_file.py b/src/mcp_scan_server/record_file.py new file mode 100644 index 0000000..436c993 --- /dev/null +++ b/src/mcp_scan_server/record_file.py @@ -0,0 +1,305 @@ +import json +import os +import uuid +from dataclasses import dataclass +from typing import Any + +import rich +from invariant_sdk.client import Client + +from .session_store import Message, Session, SessionStore + + +class TraceClientMapping: + """ + A singleton class to store the mapping between trace ids and client names. + + This is used to ensure that a trace id is generated only once for a given client and + that it is consistent, so we can append to explorer properly. + """ + + _instance = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + cls._instance.trace_id_to_client_name = {} + cls._instance.client_name_to_trace_id = {} + return cls._instance + + def get_client_name(self, trace_id: str) -> str | None: + """ + Get the client name for the given trace id. + """ + return self.trace_id_to_client_name.get(trace_id, None) + + def get_trace_id(self, client_name: str) -> str | None: + """ + Get the trace id for the given client name. + """ + return self.client_name_to_trace_id.get(client_name, None) + + def set_trace_id(self, trace_id: str, client_name: str) -> None: + """ + Set the trace id for the given client name. + """ + self.trace_id_to_client_name[trace_id] = client_name + self.client_name_to_trace_id[client_name] = trace_id + + def clear(self) -> None: + """ + Clear the trace client mapping. + """ + self.trace_id_to_client_name: dict[str, str] = {} + self.client_name_to_trace_id: dict[str, str] = {} + + def __str__(self) -> str: + return f"TraceClientMapping[{self.trace_id_to_client_name}]" + + def __repr__(self) -> str: + return self.__str__() + + +def generate_record_file_postfix() -> str: + """Generate a random postfix for the record file.""" + return str(uuid.uuid4())[:8] + + +# Initialize the trace client mapping and session store +trace_client_mapping = TraceClientMapping() + + +@dataclass(frozen=True) +class RecordFile: + """Base class for record file names.""" + + def __message_styling_wrapper__(self, message: str) -> str: + """Wrap the message in styling.""" + return f"[yellow]Recording session data to {message}[/yellow]" + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + raise NotImplementedError("Subclasses must implement this method") + + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + raise NotImplementedError("Subclasses must implement this method") + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + raise NotImplementedError("Subclasses must implement this method") + + +@dataclass(frozen=True) +class ExplorerRecordFile(RecordFile): + """Record file for explorer datasets.""" + + dataset_name: str + client: Client + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + return self.__message_styling_wrapper__(f"explorer dataset: '{self.dataset_name}'") + + def _check_is_key_set(self) -> None: + """Check if the invariant API key is set.""" + if not os.environ.get("INVARIANT_API_KEY"): + raise ValueError( + "Invariant SDK client is not initialized. Please set the INVARIANT_API_KEY environment variable." + ) + + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + """Push the session to the explorer.""" + self._check_is_key_set() + + try: + response = self.client.create_request_and_push_trace( + messages=[session_data], + dataset=self.dataset_name, + annotations=annotations, + metadata=[ + { + "hierarchy_path": [client_name], + } + ], + ) + except Exception as e: + rich.print(f"[bold red]Error pushing session to explorer: {e}[/bold red]") + return None + + trace_id = response.id[0] + trace_client_mapping.set_trace_id(trace_id, client_name) + return trace_id + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + """Append the session data to the explorer.""" + self._check_is_key_set() + + try: + self.client.create_request_and_append_messages( + messages=session_data, + trace_id=trace_id, + annotations=annotations, + ) + except Exception as e: + rich.print(f"[bold red]Error appending messages to explorer: {e}[/bold red]") + return None + + +@dataclass(frozen=True) +class LocalRecordFile(RecordFile): + """Record file for local files.""" + + directory_name: str + base_path: str + postfix: str = str(uuid.uuid4())[:8] + + def startup_message(self) -> str: + """Return a message to be printed on startup.""" + return self.__message_styling_wrapper__(f"local directory: '{self.get_directory()}'") + + def get_directory(self) -> str: + """Get the directory for the record file.""" + return os.path.join(self.base_path, self.directory_name) + + def get_filepath(self, client_name: str | None) -> str: + """Get the path to the session file for a given client.""" + client_name = client_name or "unknown" + return os.path.join(self.get_directory(), f"{client_name}-{self.postfix}.jsonl") + + def get_session_file_path(self, client_name: str | None) -> str: + """Get the path to the session file for a given client.""" + client_name = client_name or "unknown" + return self.get_filepath(client_name) + + def push( + self, session_data: list[Message], client_name: str, annotations: dict[str, Any] | None = None + ) -> str | None: + """Push the session to the local file.""" + os.makedirs(self.get_directory(), exist_ok=True) + trace_id = str(uuid.uuid4()) + trace_client_mapping.set_trace_id(trace_id, client_name) + + session_file_path = self.get_session_file_path(client_name) + + # Write each message as a JSONL line + with open(session_file_path, "w") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + + return trace_id + + def append(self, session_data: list[Message], trace_id: str, annotations: dict[str, Any] | None = None) -> None: + """Append the session data to the local file.""" + client_name = trace_client_mapping.get_client_name(trace_id) + session_file_path = self.get_session_file_path(client_name) + + # Append each message as a new line + with open(session_file_path, "a") as f: + for message in session_data: + f.write(json.dumps(message) + "\n") + + +def parse_record_file_name(record_file: str | None, base_path: str | None = None) -> RecordFile | None: + """Parse the record file name and return a RecordFile object.""" + if record_file is None: + return None + + # Check if it has the form explorer:{dataset_name} + if record_file.startswith("explorer:"): + dataset_name = record_file.split(":")[1] + return ExplorerRecordFile(dataset_name, Client()) + + # Check if it has the form local:{directory_name} + elif record_file.startswith("local:"): + directory_name = record_file.split(":")[1] + if base_path is None: + base_path = os.path.expanduser("~/.mcp-scan/sessions/") + file_object = LocalRecordFile(directory_name, base_path=base_path) + os.makedirs(file_object.get_directory(), exist_ok=True) + return file_object + + # Otherwise, unknown record file type + else: + raise ValueError( + f"Invalid record file name: {record_file}. Must be of the form explorer:{{dataset_name}} or local:{{directory_name}}" + ) + + +async def _format_session_data(session: Session, index: int) -> list[Message]: + """ + Format the session data for the record file. + Only returns new messages that haven't been recorded yet. + """ + # convert session data to a list of messages, but only include new ones + messages = [] + for i, node in enumerate(session.nodes): + if i > index: + messages.append(node.message) + + return messages + + +async def push_session_to_record_file( + session: Session, record_file: RecordFile, client_name: str, session_store: SessionStore +) -> str | None: + """ + Push a session to a record file. + + This function may be called multiple times with partially the same data. The behavior is as follows: + - The first time we try to push for a given client, we push the data (either create it for local files or push to explorer) + - The next time we try to push for a given client (for example from a different server) we instead append to the record file. + We monitor which clients have been pushed to the record file by checking the trace client mapping. + + Returns the trace id. + """ + index = session.last_pushed_index + session_data = await _format_session_data(session, index) + + # If there are no new messages, return None + if not session_data: + return None + + # If we have already pushed for this client, append to the record file + if trace_id := trace_client_mapping.get_trace_id(client_name): + await append_messages_to_record_file(trace_id, record_file, session_store=session_store) + return trace_id + + # Otherwise, push to the record file + trace_id = record_file.push(session_data, client_name) + + # Update the last pushed index + session.last_pushed_index += len(session_data) + + return trace_id + + +async def append_messages_to_record_file( + trace_id: str, + record_file: RecordFile, + session_store: SessionStore, + annotations: dict[str, Any] | None = None, +) -> None: + """ + Append messages to the record file. + """ + client_name = trace_client_mapping.get_client_name(trace_id) + if client_name is None: + rich.print(f"[bold red]Trace id {trace_id} not found in trace client mapping. Cancelling append. [/bold red]") + return + + session = session_store[client_name] + index = session.last_pushed_index + session_data = await _format_session_data(session, index) + + # If there are no new messages, return + if not session_data: + return + + # Otherwise, append to the record file + record_file.append(session_data, trace_id, annotations) + + # Update the last pushed index + session.last_pushed_index += len(session_data) diff --git a/src/mcp_scan_server/routes/policies.py b/src/mcp_scan_server/routes/policies.py index 8781ec8..9e5a8b1 100644 --- a/src/mcp_scan_server/routes/policies.py +++ b/src/mcp_scan_server/routes/policies.py @@ -17,7 +17,7 @@ from pydantic import ValidationError from mcp_scan_server.activity_logger import ActivityLogger, get_activity_logger -from mcp_scan_server.session_store import SessionStore, to_session +from mcp_scan_server.session_store import SessionStore, get_session_store, to_session from ..models import ( DEFAULT_GUARDRAIL_CONFIG, @@ -30,7 +30,6 @@ from ..parse_config import parse_config router = APIRouter() -session_store = SessionStore() async def load_guardrails_config_file(config_file_path: str) -> GuardrailConfigFile: @@ -172,7 +171,7 @@ def to_json_serializable_dict(obj): async def get_messages_from_session( - check_request: BatchCheckRequest, client_name: str, server_name: str, session_id: str + check_request: BatchCheckRequest, client_name: str, server_name: str, session_id: str, session_store: SessionStore ) -> list[Event]: """Get the messages from the session store.""" try: @@ -195,6 +194,7 @@ async def batch_check_policies( check_request: BatchCheckRequest, request: fastapi.Request, activity_logger: ActivityLogger = Depends(get_activity_logger), + session_store: SessionStore = Depends(get_session_store), ): """Check a policy using the invariant analyzer.""" metadata = check_request.parameters.get("metadata", {}) @@ -203,7 +203,7 @@ async def batch_check_policies( mcp_server = metadata.get("server", "Unknown Server") session_id = metadata.get("session_id", "") - messages = await get_messages_from_session(check_request, mcp_client, mcp_server, session_id) + messages = await get_messages_from_session(check_request, mcp_client, mcp_server, session_id, session_store) last_analysis_index = session_store[mcp_client].last_analysis_index results = await asyncio.gather( diff --git a/src/mcp_scan_server/routes/push.py b/src/mcp_scan_server/routes/push.py index aef75e4..8bdb1eb 100644 --- a/src/mcp_scan_server/routes/push.py +++ b/src/mcp_scan_server/routes/push.py @@ -1,15 +1,28 @@ +import json import uuid -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request from invariant_sdk.types.push_traces import PushTracesResponse +from ..record_file import push_session_to_record_file +from ..session_store import SessionStore, get_session_store + router = APIRouter() @router.post("/trace") -async def push_trace(request: Request) -> PushTracesResponse: +async def push_trace(request: Request, session_store: SessionStore = Depends(get_session_store)) -> PushTracesResponse: """Push a trace. For now, this is a dummy response.""" - trace_id = str(uuid.uuid4()) - # return the trace ID - return PushTracesResponse(id=[trace_id], success=True) + request_data = await request.body() + request_data = json.loads(request_data) + mcp_client = request_data.get("metadata")[0].get("client") + session = session_store[mcp_client] + + record_file = request.app.state.record_file + + # Push the session to the record file if it exists + if trace_id := await push_session_to_record_file(session, record_file, mcp_client, session_store): + return PushTracesResponse(id=[trace_id], success=True) + else: + return PushTracesResponse(id=[str(uuid.uuid4())], success=False) diff --git a/src/mcp_scan_server/routes/trace.py b/src/mcp_scan_server/routes/trace.py index 92e73dc..b019b43 100644 --- a/src/mcp_scan_server/routes/trace.py +++ b/src/mcp_scan_server/routes/trace.py @@ -1,10 +1,27 @@ -from fastapi import APIRouter, Request +import json + +from fastapi import APIRouter, Depends, Request + +from ..record_file import append_messages_to_record_file +from ..session_store import SessionStore, get_session_store router = APIRouter() @router.post("/{trace_id}/messages") -async def append_messages(trace_id: str, request: Request): +async def append_messages(trace_id: str, request: Request, session_store: SessionStore = Depends(get_session_store)): """Append messages to a trace. For now this is a dummy response.""" + request_data = await request.body() + request_data = json.loads(request_data) + + # If we are calling append, we should already have set the trace_id + if request.app.state.record_file: + await append_messages_to_record_file( + trace_id, + request.app.state.record_file, + session_store=session_store, + annotations=request_data.get("annotations"), + ) + return {"success": True} diff --git a/src/mcp_scan_server/server.py b/src/mcp_scan_server/server.py index d7e5508..e54711a 100644 --- a/src/mcp_scan_server/server.py +++ b/src/mcp_scan_server/server.py @@ -8,10 +8,12 @@ from mcp_scan_server.activity_logger import setup_activity_logger # type: ignore +from .record_file import parse_record_file_name from .routes.policies import router as policies_router # type: ignore from .routes.push import router as push_router from .routes.trace import router as dataset_trace_router from .routes.user import router as user_router +from .session_store import setup_session_store class MCPScanServer: @@ -32,6 +34,7 @@ def __init__( on_exit: Callable | None = None, log_level: str = "error", pretty: Literal["oneline", "compact", "full", "none"] = "compact", + record_file: str | None = None, ): self.port = port self.config_file_path = config_file_path @@ -41,6 +44,7 @@ def __init__( self.app = FastAPI(lifespan=self.life_span) self.app.state.config_file_path = config_file_path + self.app.state.record_file = parse_record_file_name(record_file) self.app.include_router(policies_router, prefix="/api/v1") self.app.include_router(push_router, prefix="/api/v1/push") @@ -65,9 +69,15 @@ async def on_startup(self): """Startup event for the FastAPI app.""" rich.print("[bold green]MCP-scan server started (http://localhost:" + str(self.port) + ")[/bold green]") + if self.app.state.record_file is not None: + rich.print(self.app.state.record_file.startup_message()) + # setup activity logger setup_activity_logger(self.app, pretty=self.pretty) + # setup session store + setup_session_store(self.app) + from .routes.policies import load_guardrails_config_file await load_guardrails_config_file(self.config_file_path) diff --git a/src/mcp_scan_server/session_store.py b/src/mcp_scan_server/session_store.py index 25ad521..7809bbf 100644 --- a/src/mcp_scan_server/session_store.py +++ b/src/mcp_scan_server/session_store.py @@ -3,6 +3,10 @@ from enum import Enum from typing import Any +from fastapi import FastAPI, Request + +Message = dict[str, Any] + class MergeNodeTypes(Enum): SELF = "self" @@ -24,7 +28,7 @@ class SessionNode: """ timestamp: datetime - message: dict[str, Any] + message: Message session_id: str server_name: str original_session_index: int @@ -37,6 +41,12 @@ def __lt__(self, other: "SessionNode") -> bool: """Sort by timestamp.""" return self.timestamp < other.timestamp + def to_json(self) -> Message: + """ + Convert the session node to a message. + """ + return self.message + class Session: """ @@ -49,6 +59,7 @@ def __init__( ): self.nodes: list[SessionNode] = nodes or [] self.last_analysis_index: int = -1 + self.last_pushed_index: int = -1 def _build_stack(self, other: "Session") -> list[MergeInstruction]: """ @@ -145,6 +156,12 @@ def get_sorted_nodes(self) -> list[SessionNode]: def __repr__(self): return f"Session(nodes={self.get_sorted_nodes()})" + def to_json(self) -> list[Message]: + """ + Convert the session to a list of messages. + """ + return [node.to_json() for node in self.nodes] + class SessionStore: """ @@ -179,6 +196,12 @@ def fetch_and_merge(self, client_name: str, other: Session) -> Session: session.merge(other) return session + def to_json(self) -> dict[str, dict[str, list[dict[str, Any]]]]: + """ + Convert the sessions to a dictionary. + """ + return {"sessions": {client_name: session.to_json() for client_name, session in self.sessions.items()}} + async def to_session(messages: list[dict[str, Any]], server_name: str, session_id: str) -> Session: """ @@ -198,3 +221,17 @@ async def to_session(messages: list[dict[str, Any]], server_name: str, session_i ) return Session(nodes=session_nodes) + + +def get_session_store(request: Request) -> SessionStore: + """ + Get the session store. + """ + return request.app.state.session_store + + +def setup_session_store(app: FastAPI) -> None: + """ + Setup the session store as a dependency for the given FastAPI app. + """ + app.state.session_store = SessionStore() diff --git a/tests/e2e/test_full_proxy_flow.py b/tests/e2e/test_full_proxy_flow.py index b4b5ae0..bdf1be5 100644 --- a/tests/e2e/test_full_proxy_flow.py +++ b/tests/e2e/test_full_proxy_flow.py @@ -161,6 +161,13 @@ async def test_basic(self, toy_server_add_config_file, pretty): servers = list(config.mcpServers.values()) assert len(servers) == 1 server = servers[0] + + # Debug: Print server config before starting client + print(f"[DEBUG] Server config: {server}") + print(f"[DEBUG] Server command: {server.command}") + print(f"[DEBUG] Server args: {server.args}") + print(f"[DEBUG] Server env: {server.env}") + client_program = run_toy_server_client(server) # wait for client to finish @@ -174,6 +181,18 @@ async def test_basic(self, toy_server_add_config_file, pretty): print(safe_decode(stdout)) print(safe_decode(stderr)) raise AssertionError("timed out waiting for MCP server to respond") from e + except Exception as e: + print(f"Client failed with exception: {e}") + print(f"Exception type: {type(e)}") + # Get any output from the process before terminating + stdout, stderr = process.communicate() + print("=== PROCESS OUTPUT ===") + print(safe_decode(stdout)) + print("=== PROCESS ERROR ===") + print(safe_decode(stderr)) + process.terminate() + process.wait() + raise assert int(client_output["result"]) == 3 diff --git a/tests/unit/test_record_file.py b/tests/unit/test_record_file.py new file mode 100644 index 0000000..36a047c --- /dev/null +++ b/tests/unit/test_record_file.py @@ -0,0 +1,274 @@ +import datetime +import json +import os +from unittest.mock import ANY, Mock + +import pytest + +from mcp_scan_server.record_file import ( + ExplorerRecordFile, + LocalRecordFile, + TraceClientMapping, + parse_record_file_name, + push_session_to_record_file, +) +from mcp_scan_server.session_store import Session, SessionNode, SessionStore + + +@pytest.fixture(autouse=True) +def cleanup_trace_client_mapping(): + """ + Cleanup the trace client mapping. + """ + trace_client_mapping = TraceClientMapping() + trace_client_mapping.clear() + + +@pytest.mark.parametrize("filename", ["explorer:test", "explorer:test", "explorer:test"]) +def test_parse_record_filename_explorer_valid(filename, monkeypatch): + monkeypatch.setattr("mcp_scan_server.record_file.Client", Mock()) + file = parse_record_file_name(filename) + assert file.dataset_name == filename.split(":")[1] + assert isinstance(file, ExplorerRecordFile) + + +@pytest.mark.parametrize("filename", ["local:test", "local:test"]) +def test_parse_record_filename_local_valid(filename): + file = parse_record_file_name(filename) + assert isinstance(file, LocalRecordFile) + + +@pytest.mark.parametrize("filename", ["test.something", "test"]) +def test_parse_record_filename_local_invalid(filename): + with pytest.raises(ValueError): + parse_record_file_name(filename) + + +def test_trace_client_mapping(): + """ + Test that we can set and get trace ids and client names. + """ + trace_client_mapping = TraceClientMapping() + trace_client_mapping.set_trace_id("trace_id", "client_name") + trace_client_mapping.set_trace_id("trace_id_2", "client_name_2") + + # Test that the trace id and client name are correctly set + assert trace_client_mapping.get_trace_id("client_name") == "trace_id" + assert trace_client_mapping.get_client_name("trace_id") == "client_name" + + # test that non-existent trace ids return None + assert trace_client_mapping.get_trace_id("client_name_3") is None + assert trace_client_mapping.get_client_name("trace_id_3") is None + + # test that non-existent client names return None + assert trace_client_mapping.get_trace_id("client_name_4") is None + assert trace_client_mapping.get_client_name("trace_id_4") is None + + +def test_trace_client_mapping_shares_state(): + """ + Test that we maintain the same state across multiple instances of the TraceClientMapping class. + """ + trace_ids = ["trace_id", "trace_id_2"] + client_names = ["client_name", "client_name_2"] + trace_client_mapping = TraceClientMapping() + + # Populate the first mapping + for trace_id, client_name in zip(trace_ids, client_names, strict=False): + trace_client_mapping.set_trace_id(trace_id, client_name) + + # Create a new mapping and check that the mappings are the same + trace_client_mapping_2 = TraceClientMapping() + for trace_id, client_name in zip(trace_ids, client_names, strict=False): + assert trace_client_mapping_2.get_trace_id(client_name) == trace_id + assert trace_client_mapping_2.get_client_name(trace_id) == client_name + + +def _setup_test_session_and_session_store(): + """Helper function to set up a test session.""" + # Create a session with a message + session = Session( + nodes=[ + SessionNode( + timestamp=datetime.datetime.now(), + message={"role": "user", "content": "test"}, + session_id="session_id", + server_name="server_name", + original_session_index=0, + ) + ] + ) + + # Add the session to the session store + session_store = SessionStore() + session_store["client_name"] = session + + return session_store, session + + +def _setup_test_session_and_mock(monkeypatch, mock_trace_id="test_trace_id"): + """Helper function to set up test session and mock client.""" + # Create a mock for the invariant SDK client + mock_client = Mock() + mock_client.create_request_and_push_trace.return_value = Mock(id=[mock_trace_id]) + + monkeypatch.setattr("invariant_sdk.client.Client", mock_client) + monkeypatch.setattr("mcp_scan_server.record_file.ExplorerRecordFile._check_is_key_set", Mock()) + + # Create the trace client mapping + trace_client_mapping = TraceClientMapping() + + return mock_client, trace_client_mapping + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_explorer_first_time_calls_create_request_and_push_trace(monkeypatch): + """ + Test that we call create_request_and_push_trace when pushing a session to the record file for the first time. + """ + session_store, session = _setup_test_session_and_session_store() + mock_client, trace_client_mapping = _setup_test_session_and_mock(monkeypatch) + mock_trace_id = "test_trace_id" + + # Push the session to the record file + trace_id = await push_session_to_record_file( + session, ExplorerRecordFile("test", mock_client), "client_name", session_store + ) + + # Check that the trace id is set + assert trace_id == mock_trace_id + + # Check that the trace id is in the trace client mapping + assert trace_client_mapping.get_trace_id("client_name") == mock_trace_id + + # Check that the invariant sdk client was called with the correct arguments + mock_client.create_request_and_push_trace.assert_called_once_with( + messages=[[{"role": "user", "content": "test"}]], + dataset="test", + metadata=ANY, # Ignore metadata + annotations=ANY, + ) + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_explorer_second_time_calls_append_messages(monkeypatch): + """ + Test that we call append_messages when pushing a session to the record file for the second time. + """ + session_store, session = _setup_test_session_and_session_store() + mock_client, trace_client_mapping = _setup_test_session_and_mock(monkeypatch) + mock_trace_id = "test_trace_id" + message = {"role": "assistant", "content": "response"} + + # First push to set up the trace ID + await push_session_to_record_file(session, ExplorerRecordFile("test", mock_client), "client_name", session_store) + + # Add a new message to the session + session.nodes.append( + SessionNode( + timestamp=datetime.datetime.now(), + message=message, + session_id="session_id", + server_name="server_name", + original_session_index=1, + ) + ) + # Second push should append + trace_id = await push_session_to_record_file( + session, ExplorerRecordFile("test", mock_client), "client_name", session_store + ) + + # Check that we got the same trace ID back + assert trace_id == mock_trace_id + + # Check that the trace id is in the trace client mapping + assert trace_client_mapping.get_trace_id("client_name") == mock_trace_id + + # Check that create_request_and_push_trace was only called once (from the first push) + mock_client.create_request_and_push_trace.assert_called_once() + + # Check that append_messages was called with the correct arguments + mock_client.create_request_and_append_messages.assert_called_once_with( + messages=[message], + trace_id=mock_trace_id, + annotations=ANY, + ) + + +def _setup_test_session_and_local_file(tmp_path): + """Helper function to set up test session and local file.""" + + # Create the trace client mapping + trace_client_mapping = TraceClientMapping() + + record_file = parse_record_file_name("local:test", base_path=str(tmp_path)) + assert isinstance(record_file, LocalRecordFile), "Should have LocalRecordFile" + + return trace_client_mapping, record_file + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_local_creates_file_and_writes_to_it(tmp_path): + """ + Test that we create a file and write to it when pushing a session to the record file for the first time. + Also check that the path is set correctly to the LocalRecordFile data. + """ + session_store, session = _setup_test_session_and_session_store() + trace_client_mapping, record_file = _setup_test_session_and_local_file(tmp_path) + + # Push the session to the record file + trace_id = await push_session_to_record_file(session, record_file, "client_name", session_store) + + # Check that the trace id is set + assert trace_id is not None + assert trace_client_mapping.get_trace_id("client_name") == trace_id + + # Check that the file was created with the correct path + expected_path = record_file.get_session_file_path("client_name") + assert os.path.exists(expected_path) + + # Check that the file contains the correct content + with open(expected_path) as f: + content = f.read().strip() + assert content == json.dumps({"role": "user", "content": "test"}) + + +@pytest.mark.asyncio +async def test_push_session_to_record_file_local_second_time_appends_to_file(tmp_path): + """ + Test that we append to the file when pushing a session to the record file for the second time. + """ + session_store, session = _setup_test_session_and_session_store() + trace_client_mapping, record_file = _setup_test_session_and_local_file(tmp_path) + + # First push to set up the trace ID + trace_id = await push_session_to_record_file(session, record_file, "client_name", session_store) + + # Add a new message to the session + message = {"role": "assistant", "content": "response"} + session.nodes.append( + SessionNode( + timestamp=datetime.datetime.now(), + message=message, + session_id="session_id", + server_name="server_name", + original_session_index=1, + ) + ) + + # Second push should append + new_trace_id = await push_session_to_record_file(session, record_file, "client_name", session_store) + + # Check that we got the same trace ID back + assert new_trace_id == trace_id + + # Check that the trace id is in the trace client mapping + assert trace_client_mapping.get_trace_id("client_name") == new_trace_id + + # Check that the file contains both messages + expected_path = record_file.get_session_file_path("client_name") + with open(expected_path) as f: + lines = [line.strip() for line in f.readlines() if line.strip()] # Filter out empty lines + assert len(lines) == 2 + assert json.loads(lines[0]) == {"role": "user", "content": "test"} + assert json.loads(lines[1]) == message diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 412fb98..8eedfc6 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1,4 +1,5 @@ import datetime +import json import pytest @@ -425,3 +426,62 @@ def test_session_merge_last_analysis_index_is_reset_when_other_has_nodes_before_ session2 = Session(nodes=nodes2) session1.merge(session2) assert session1.last_analysis_index == -1 + + +def test_json_serialization(): + """Test that the JSON serialization works correctly for all classes.""" + timestamp = datetime.datetime(2024, 1, 1, 12, 0, 0) + message = {"role": "user", "content": "Hello, world!"} + + # Create a session node + node = SessionNode( + timestamp=timestamp, + message=message, + session_id="test_session", + server_name="test_server", + original_session_index=0, + ) + + # Test SessionNode JSON serialization + node_dict = node.to_json() + assert node_dict == message + + # Create a session with the node + session = Session(nodes=[node]) + + # Test Session JSON serialization + session_dict = session.to_json() + assert session_dict == [message] + + # Create a session store with the session + store = SessionStore() + store["test_client"] = session + + # Test SessionStore JSON serialization + store_json = store.to_json() + assert store_json == {"sessions": {"test_client": [message]}} + + # Finally test that we can dump and load + store_json_str = json.dumps(store_json) + assert store_json_str is not None + + store_dict = json.loads(store_json_str) + assert store_dict == {"sessions": {"test_client": [message]}} + + +def test_session_node_to_json(): + session_node = SessionNode( + timestamp=datetime.datetime.now(), + message={"role": "user", "content": "Hello, world!"}, + session_id="session_id", + server_name="server_name", + original_session_index=0, + ) + + session = Session(nodes=[session_node]) + + session_store = SessionStore() + session_store["client_name"] = session + + session_store_json = session_store.to_json() + assert session_store_json is not None