diff --git a/mcp_fuzzer/cli/validators.py b/mcp_fuzzer/cli/validators.py index e581a4f..52f1135 100644 --- a/mcp_fuzzer/cli/validators.py +++ b/mcp_fuzzer/cli/validators.py @@ -12,7 +12,7 @@ from ..exceptions import ArgumentValidationError from ..config import load_config_file -from ..transport.factory import create_transport +from ..transport.catalog import build_driver from ..exceptions import MCPError, TransportError from ..env import ENVIRONMENT_VARIABLES, ValidationType @@ -25,9 +25,10 @@ def __init__(self): def validate_arguments(self, args: argparse.Namespace) -> None: """Validate CLI arguments for fuzzing operations.""" - is_utility_command = getattr(args, "check_env", False) or getattr( - args, "validate_config", None - ) is not None + is_utility_command = ( + getattr(args, "check_env", False) + or getattr(args, "validate_config", None) is not None + ) if not is_utility_command and not getattr(args, "endpoint", None): raise ArgumentValidationError( @@ -67,17 +68,14 @@ def validate_config_file(self, path: str) -> None: """Validate a config file and print success message.""" load_config_file(path) success_msg = ( - "[green]:heavy_check_mark: Configuration file " - f"'{path}' is valid[/green]" + f"[green]:heavy_check_mark: Configuration file '{path}' is valid[/green]" ) self.console.print(emoji.emojize(success_msg, language="alias")) - def check_environment_variables(self) -> bool: """Print environment variable status and return validation result.""" self.console.print("[bold]Environment variables check:[/bold]") - all_valid = True for env_var in ENVIRONMENT_VARIABLES: name = env_var["name"] @@ -154,17 +152,13 @@ def _get_validation_error_msg( ) elif validation_type == ValidationType.NUMERIC: return ( - "[red]:heavy_multiplication_x: " - f"{name}={value} (must be numeric)[/red]" + f"[red]:heavy_multiplication_x: {name}={value} (must be numeric)[/red]" ) - return ( - "[red]:heavy_multiplication_x: " - f"{name}={value} (invalid value)[/red]" - ) + return f"[red]:heavy_multiplication_x: {name}={value} (invalid value)[/red]" def validate_transport(self, args: Any) -> None: try: - _ = create_transport( + _ = build_driver( args.protocol, args.endpoint, timeout=args.timeout, diff --git a/mcp_fuzzer/client/main.py b/mcp_fuzzer/client/main.py index e11c311..64a9667 100644 --- a/mcp_fuzzer/client/main.py +++ b/mcp_fuzzer/client/main.py @@ -13,7 +13,7 @@ from ..exceptions import MCPError from .settings import ClientSettings from .base import MCPFuzzerClient -from .transport import create_transport_with_auth +from .transport import build_driver_with_auth # For backward compatibility UnifiedMCPFuzzerClient = MCPFuzzerClient @@ -47,7 +47,7 @@ def __init__(self, protocol, endpoint, timeout): "auth_manager": config.get("auth_manager"), } - transport = create_transport_with_auth(args, client_args) + transport = build_driver_with_auth(args, client_args) safety_enabled = config.get("safety_enabled", True) safety_system = None diff --git a/mcp_fuzzer/client/tool_client.py b/mcp_fuzzer/client/tool_client.py index 9dd6786..a9dd148 100644 --- a/mcp_fuzzer/client/tool_client.py +++ b/mcp_fuzzer/client/tool_client.py @@ -18,6 +18,8 @@ DEFAULT_MAX_TOTAL_FUZZING_TIME, DEFAULT_FORCE_KILL_TIMEOUT, ) +from ..transport.interfaces import JsonRpcAdapter + class ToolClient: """Client for fuzzing MCP tools.""" @@ -40,6 +42,7 @@ def __init__( max_concurrency: Maximum number of concurrent operations """ self.transport = transport + self._rpc = JsonRpcAdapter(transport) self.auth_manager = auth_manager or AuthManager() self.enable_safety = enable_safety if not enable_safety: @@ -60,7 +63,7 @@ async def _get_tools_from_server(self) -> list[dict[str, Any]]: List of tool definitions or empty list if failed. """ try: - tools = await self.transport.get_tools() + tools = await self._rpc.get_tools() if not tools: self._logger.warning("Server returned an empty list of tools.") return [] @@ -93,7 +96,7 @@ async def _fuzz_single_tool_with_timeout( try: tool_task = asyncio.create_task( self.fuzz_tool(tool, runs_per_tool, tool_timeout=tool_timeout), - name=f"fuzz_tool_{tool_name}" + name=f"fuzz_tool_{tool_name}", ) try: @@ -180,9 +183,7 @@ async def fuzz_tool( # Call the tool with the generated arguments try: - result = await self.transport.call_tool( - tool_name, args_for_call - ) + result = await self._rpc.call_tool(tool_name, args_for_call) results.append( { "args": sanitized_args, @@ -311,11 +312,11 @@ async def _process_fuzz_results( fuzz_results: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Process fuzz results with safety checks and tool calls. - + Args: tool_name: Name of the tool being fuzzed fuzz_results: List of fuzz results from the fuzzer - + Returns: List of processed results with tool call outcomes """ diff --git a/mcp_fuzzer/client/transport/__init__.py b/mcp_fuzzer/client/transport/__init__.py index 9be3bbd..c87f710 100644 --- a/mcp_fuzzer/client/transport/__init__.py +++ b/mcp_fuzzer/client/transport/__init__.py @@ -1,3 +1,3 @@ -from .factory import create_transport_with_auth +from .factory import build_driver_with_auth -__all__ = ["create_transport_with_auth"] +__all__ = ["build_driver_with_auth"] diff --git a/mcp_fuzzer/client/transport/factory.py b/mcp_fuzzer/client/transport/factory.py index ac26a51..1a0c30c 100644 --- a/mcp_fuzzer/client/transport/factory.py +++ b/mcp_fuzzer/client/transport/factory.py @@ -9,11 +9,12 @@ from rich.console import Console -from ...transport.factory import create_transport as base_create_transport +from ...transport.catalog import build_driver as base_build_driver + logger = logging.getLogger(__name__) -def create_transport_with_auth(args: Any, client_args: dict[str, Any]): +def build_driver_with_auth(args: Any, client_args: dict[str, Any]): """Create a transport with authentication headers when available.""" try: auth_headers = None @@ -46,7 +47,7 @@ def create_transport_with_auth(args: Any, client_args: dict[str, Any]): args.protocol.upper(), args.endpoint, ) - transport = base_create_transport( + transport = base_build_driver( args.protocol, args.endpoint, **factory_kwargs, @@ -59,4 +60,4 @@ def create_transport_with_auth(args: Any, client_args: dict[str, Any]): sys.exit(1) -__all__ = ["create_transport_with_auth"] +__all__ = ["build_driver_with_auth"] diff --git a/mcp_fuzzer/config/loader.py b/mcp_fuzzer/config/loader.py index fbc21ee..5e40ee1 100644 --- a/mcp_fuzzer/config/loader.py +++ b/mcp_fuzzer/config/loader.py @@ -13,12 +13,13 @@ from .manager import config from ..exceptions import ConfigFileError, MCPError -from ..transport.custom import register_custom_transport -from ..transport.base import TransportProtocol +from ..transport.catalog.custom_catalog import register_custom_driver +from ..transport.interfaces import TransportDriver import importlib logger = logging.getLogger(__name__) + def find_config_file( config_path: str | None = None, search_paths: list[str] | None = None, @@ -61,6 +62,7 @@ def find_config_file( return None + def load_config_file(file_path: str) -> dict[str, Any]: """Load configuration from a YAML file. @@ -99,6 +101,7 @@ def load_config_file(file_path: str) -> dict[str, Any]: f"Unexpected error reading configuration file {file_path}: {str(e)}" ) + def apply_config_file( config_path: str | None = None, search_paths: list[str] | None = None, @@ -130,6 +133,7 @@ def apply_config_file( config.update(config_data) return True + def get_config_schema() -> dict[str, Any]: """Get the configuration schema. @@ -240,7 +244,7 @@ def get_config_schema() -> dict[str, Any]: "factory": { "type": "string", "description": "Dotted path to factory function " - "(e.g., pkg.mod.create_transport)", + "(e.g., pkg.mod.build_driver)", }, "config_schema": { "type": "object", @@ -322,6 +326,7 @@ def get_config_schema() -> dict[str, Any]: }, } + def load_custom_transports(config_data: dict[str, Any]) -> None: """Load and register custom transports from configuration. @@ -340,9 +345,9 @@ def load_custom_transports(config_data: dict[str, Any]) -> None: transport_class = getattr(module, class_name) if not isinstance(transport_class, type): raise ConfigFileError(f"{module_path}.{class_name} is not a class") - if not issubclass(transport_class, TransportProtocol): + if not issubclass(transport_class, TransportDriver): raise ConfigFileError( - f"{module_path}.{class_name} must subclass TransportProtocol" + f"{module_path}.{class_name} must subclass TransportDriver" ) # Register the transport @@ -362,7 +367,7 @@ def load_custom_transports(config_data: dict[str, Any]) -> None: if not callable(factory_fn): raise ConfigFileError(f"Factory '{factory_path}' is not callable") - register_custom_transport( + register_custom_driver( name=transport_name, transport_class=transport_class, description=description, diff --git a/mcp_fuzzer/events.py b/mcp_fuzzer/events.py index 950a056..1a7efcf 100644 --- a/mcp_fuzzer/events.py +++ b/mcp_fuzzer/events.py @@ -3,7 +3,7 @@ Shared event contract definitions for runtime components. This module exposes a lightweight protocol that both the runtime `ProcessManager` -and transport `TransportManager` use to dispatch lifecycle events. Observers can +and transport `ProcessSupervisor` use to dispatch lifecycle events. Observers can subscribe to these events to monitor state changes without depending on internal implementation details. @@ -20,9 +20,9 @@ - ``signal`` / ``signal_all`` (ProcessManager): emitted when signals are sent. payload keys: ``pid``, ``signal``, ``process_name``, ``result`` (``signal_all`` adds ``results`` and ``failures``). - - ``signal_failed`` (TransportManager): emitted when signal dispatch fails. + - ``signal_failed`` (ProcessSupervisor): emitted when signal dispatch fails. payload keys: ``pid`` and ``error``. -- ``oversized_output`` (TransportManager): emitted whenever stdio output exceeds +- ``oversized_output`` (ProcessSupervisor): emitted whenever stdio output exceeds the configured cap. Payload keys: ``pid``, ``size``, ``limit``. Future event producers should keep the payloads shallow (``dict[str, Any]``) to diff --git a/mcp_fuzzer/transport/__init__.py b/mcp_fuzzer/transport/__init__.py index d6a9271..b59ab58 100644 --- a/mcp_fuzzer/transport/__init__.py +++ b/mcp_fuzzer/transport/__init__.py @@ -1,28 +1,64 @@ -from .base import TransportProtocol -from .http import HTTPTransport -from .sse import SSETransport -from .stdio import StdioTransport -from .streamable_http import StreamableHTTPTransport -from .factory import create_transport -from .manager import TransportManager, TransportProcessState -from .custom import ( - CustomTransportRegistry, - register_custom_transport, - create_custom_transport, - list_custom_transports, +"""Transport subsystem composed of interfaces, drivers, catalogs, and controllers.""" + +from .interfaces import ( + TransportDriver, + DriverState, + ParsedEndpoint, + DriverBaseBehavior, + HttpClientBehavior, + ResponseParserBehavior, + LifecycleBehavior, + TransportError, + NetworkError, + PayloadValidationError, + JsonRpcAdapter, +) +from .drivers import ( + HttpDriver, + SseDriver, + StdioDriver, + StreamHttpDriver, +) +from .catalog import ( + DriverCatalog, + driver_catalog, + build_driver, + EndpointResolver, + CustomDriverCatalog, + register_custom_driver, + build_custom_driver, + list_custom_drivers, + custom_driver_catalog, ) +from .controller.coordinator import TransportCoordinator +from .controller.process_supervisor import ProcessSupervisor, ProcessState __all__ = [ - "TransportProtocol", - "HTTPTransport", - "SSETransport", - "StdioTransport", - "StreamableHTTPTransport", - "TransportManager", - "TransportProcessState", - "create_transport", - "CustomTransportRegistry", - "register_custom_transport", - "create_custom_transport", - "list_custom_transports", + "TransportDriver", + "DriverState", + "ParsedEndpoint", + "DriverBaseBehavior", + "HttpClientBehavior", + "ResponseParserBehavior", + "LifecycleBehavior", + "TransportError", + "NetworkError", + "PayloadValidationError", + "JsonRpcAdapter", + "HttpDriver", + "SseDriver", + "StdioDriver", + "StreamHttpDriver", + "DriverCatalog", + "driver_catalog", + "build_driver", + "EndpointResolver", + "CustomDriverCatalog", + "register_custom_driver", + "build_custom_driver", + "list_custom_drivers", + "custom_driver_catalog", + "TransportCoordinator", + "ProcessSupervisor", + "ProcessState", ] diff --git a/mcp_fuzzer/transport/catalog/__init__.py b/mcp_fuzzer/transport/catalog/__init__.py new file mode 100644 index 0000000..0017b1c --- /dev/null +++ b/mcp_fuzzer/transport/catalog/__init__.py @@ -0,0 +1,24 @@ +"""Driver catalog, builders, and custom driver helpers.""" + +from .catalog import DriverCatalog +from .builder import driver_catalog, build_driver +from .resolver import EndpointResolver +from .custom_catalog import ( + CustomDriverCatalog, + register_custom_driver, + build_custom_driver, + list_custom_drivers, + custom_driver_catalog, +) + +__all__ = [ + "DriverCatalog", + "driver_catalog", + "build_driver", + "EndpointResolver", + "CustomDriverCatalog", + "register_custom_driver", + "build_custom_driver", + "list_custom_drivers", + "custom_driver_catalog", +] diff --git a/mcp_fuzzer/transport/catalog/builder.py b/mcp_fuzzer/transport/catalog/builder.py new file mode 100644 index 0000000..1940129 --- /dev/null +++ b/mcp_fuzzer/transport/catalog/builder.py @@ -0,0 +1,105 @@ +"""Transport factory for creating transport instances. + +This module provides a simplified factory that uses the unified registry +and URL parser to create transport instances. +""" + +from __future__ import annotations + +from ..interfaces.driver import TransportDriver +from .catalog import driver_catalog +from .resolver import EndpointResolver +from ...exceptions import TransportRegistrationError + + +endpoint_resolver = EndpointResolver(driver_catalog) + + +def build_driver( + url_or_protocol: str, endpoint: str | None = None, **kwargs +) -> TransportDriver: + """Create a transport from either a full URL or protocol + endpoint. + + This factory function supports two calling patterns: + 1. Single URL: build_driver("http://localhost:8080/api") + 2. Protocol + endpoint: build_driver("http", "localhost:8080/api") + + The function automatically detects custom transports and handles URL parsing. + + Args: + url_or_protocol: Full URL or protocol name + endpoint: Optional endpoint (for protocol+endpoint pattern) + **kwargs: Additional arguments to pass to transport constructor + + Returns: + Transport instance + + Raises: + TransportRegistrationError: If protocol/scheme is not supported + + Examples: + >>> transport = build_driver("http://localhost:8080") + >>> transport = build_driver("http", "localhost:8080") + >>> transport = build_driver("sse://localhost:8080/events") + >>> transport = build_driver("stdio:python server.py") + """ + # Parse URL or protocol+endpoint + parsed = endpoint_resolver.parse(url_or_protocol, endpoint) + + if not parsed.scheme: + raise TransportRegistrationError( + f"Could not determine transport scheme from: {url_or_protocol}" + ) + + # Check if transport is registered + if not driver_catalog.is_registered(parsed.scheme): + # List available transports for error message + builtin = list(driver_catalog.list_builtin_transports().keys()) + custom = list(driver_catalog.list_custom_drivers().keys()) + + error_msg = f"Unsupported transport scheme: '{parsed.scheme}'" + if builtin: + error_msg += f"\nBuilt-in transports: {', '.join(builtin)}" + if custom: + error_msg += f"\nCustom transports: {', '.join(custom)}" + + raise TransportRegistrationError(error_msg) + + # Create transport using registry + try: + return driver_catalog.build_driver(parsed.scheme, parsed.endpoint, **kwargs) + except Exception as e: + raise TransportRegistrationError( + f"Failed to create transport '{parsed.scheme}': {e}" + ) from e + + +# Register built-in transports with the global unified registry +def _register_builtin_transports(): + """Register all built-in transport types.""" + from ..drivers.http_driver import HttpDriver + from ..drivers.sse_driver import SseDriver + from ..drivers.stdio_driver import StdioDriver + from ..drivers.stream_http_driver import StreamHttpDriver + + # Only register if not already registered (allow tests to override) + transports = { + "http": HttpDriver, + "https": HttpDriver, + "sse": SseDriver, + "stdio": StdioDriver, + "streamablehttp": StreamHttpDriver, + } + + for name, cls in transports.items(): + if not driver_catalog.is_registered(name): + driver_catalog.register( + name, + cls, + description=f"Built-in {name.upper()} driver", + is_custom=False, + ) + + +# Register built-in transports on module import +_register_builtin_transports() diff --git a/mcp_fuzzer/transport/catalog/catalog.py b/mcp_fuzzer/transport/catalog/catalog.py new file mode 100644 index 0000000..a6361c1 --- /dev/null +++ b/mcp_fuzzer/transport/catalog/catalog.py @@ -0,0 +1,297 @@ +"""Unified transport registry system. + +This module provides a single registry for both built-in and custom transports, +replacing the previous dual registry system. +""" + +from __future__ import annotations + +import logging +from typing import Any, Callable, Type + +from ..interfaces.driver import TransportDriver +from ...exceptions import TransportRegistrationError + +logger = logging.getLogger(__name__) + + +class TransportMetadata: + """Metadata for a registered transport.""" + + def __init__( + self, + transport_class: Type[TransportDriver], + description: str = "", + config_schema: dict[str, Any] | None = None, + factory_function: Callable | None = None, + is_custom: bool = False, + ): + self.transport_class = transport_class + self.description = description + self.config_schema = config_schema + self.factory_function = factory_function + self.is_custom = is_custom + + +class DriverCatalog: + """Unified registry for both built-in and custom transports. + + This registry replaces the previous dual registry system (TransportRegistry + and CustomDriverCatalog) with a single, consistent interface. + """ + + def __init__(self): + """Initialize the unified transport registry.""" + self._transports: dict[str, TransportMetadata] = {} + + def register( + self, + name: str, + transport_class: Type[TransportDriver], + description: str = "", + config_schema: dict[str, Any] | None = None, + factory_function: Callable | None = None, + is_custom: bool = False, + allow_override: bool = False, + ) -> None: + """Register a transport with the registry. + + Args: + name: Unique name for the transport (case-insensitive) + transport_class: The transport class to register + description: Human-readable description + config_schema: JSON schema for transport configuration + factory_function: Optional factory function for creating instances + is_custom: Whether this is a custom (non-built-in) transport + allow_override: Whether to allow overriding existing registration + + Raises: + TransportRegistrationError: If name already registered and + override not allowed + """ + key = name.strip().lower() + + # Check for existing registration + if key in self._transports and not allow_override: + existing = self._transports[key] + existing_type = "custom" if existing.is_custom else "built-in" + new_type = "custom" if is_custom else "built-in" + raise TransportRegistrationError( + f"Transport '{name}' is already registered as " + f"{existing_type} transport. " + f"Cannot register as {new_type} transport without " + "allow_override=True." + ) + + # Validate transport class + if not issubclass(transport_class, TransportDriver): + raise TransportRegistrationError( + f"Transport class {transport_class} must inherit from TransportDriver" + ) + + # Register the transport + self._transports[key] = TransportMetadata( + transport_class=transport_class, + description=description, + config_schema=config_schema, + factory_function=factory_function, + is_custom=is_custom, + ) + + transport_type = "custom" if is_custom else "built-in" + logger.info(f"Registered {transport_type} transport: {key}") + + def unregister(self, name: str) -> None: + """Unregister a transport. + + Args: + name: Name of the transport to unregister + + Raises: + TransportRegistrationError: If transport is not registered + """ + key = name.strip().lower() + if key not in self._transports: + raise TransportRegistrationError(f"Transport '{name}' is not registered") + + metadata = self._transports[key] + transport_type = "custom" if metadata.is_custom else "built-in" + del self._transports[key] + logger.info(f"Unregistered {transport_type} transport: {key}") + + def is_registered(self, name: str) -> bool: + """Check if a transport is registered. + + Args: + name: Transport name to check + + Returns: + True if registered, False otherwise + """ + return name.strip().lower() in self._transports + + def is_custom_transport(self, name: str) -> bool: + """Check if a registered transport is custom. + + Args: + name: Transport name to check + + Returns: + True if registered and custom, False otherwise + """ + key = name.strip().lower() + if key not in self._transports: + return False + return self._transports[key].is_custom + + def get_transport_class(self, name: str) -> Type[TransportDriver]: + """Get the transport class for a registered transport. + + Args: + name: Name of the registered transport + + Returns: + The transport class + + Raises: + TransportRegistrationError: If transport is not registered + """ + key = name.strip().lower() + if key not in self._transports: + raise TransportRegistrationError(f"Transport '{name}' is not registered") + return self._transports[key].transport_class + + def get_transport_info(self, name: str) -> dict[str, Any]: + """Get information about a registered transport. + + Args: + name: Name of the registered transport + + Returns: + Dictionary containing transport information + + Raises: + TransportRegistrationError: If transport is not registered + """ + key = name.strip().lower() + if key not in self._transports: + raise TransportRegistrationError(f"Transport '{name}' is not registered") + + metadata = self._transports[key] + return { + "class": metadata.transport_class, + "description": metadata.description, + "config_schema": metadata.config_schema, + "factory": metadata.factory_function, + "is_custom": metadata.is_custom, + } + + def list_transports(self, include_custom: bool = True) -> dict[str, dict[str, Any]]: + """List registered transports. + + Args: + include_custom: Whether to include custom transports + + Returns: + Dictionary mapping transport names to their information + """ + result = {} + for name, metadata in self._transports.items(): + if not include_custom and metadata.is_custom: + continue + result[name] = { + "class": metadata.transport_class, + "description": metadata.description, + "config_schema": metadata.config_schema, + "factory": metadata.factory_function, + "is_custom": metadata.is_custom, + } + return result + + def list_builtin_transports(self) -> dict[str, dict[str, Any]]: + """List only built-in transports. + + Returns: + Dictionary mapping transport names to their information + """ + return { + name: info + for name, metadata in self._transports.items() + if not metadata.is_custom + for info in [ + { + "class": metadata.transport_class, + "description": metadata.description, + "config_schema": metadata.config_schema, + "factory": metadata.factory_function, + "is_custom": metadata.is_custom, + } + ] + } + + def list_custom_drivers(self) -> dict[str, dict[str, Any]]: + """List only custom transports. + + Returns: + Dictionary mapping transport names to their information + """ + return { + name: info + for name, metadata in self._transports.items() + if metadata.is_custom + for info in [ + { + "class": metadata.transport_class, + "description": metadata.description, + "config_schema": metadata.config_schema, + "factory": metadata.factory_function, + "is_custom": metadata.is_custom, + } + ] + } + + def build_driver(self, name: str, *args, **kwargs) -> TransportDriver: + """Create an instance of a registered transport. + + Args: + name: Name of the registered transport + *args: Positional arguments to pass to transport constructor/factory + **kwargs: Keyword arguments to pass to transport constructor/factory + + Returns: + Transport instance + + Raises: + TransportRegistrationError: If transport is not registered + """ + metadata = self.get_transport_info(name) + transport_class = metadata["class"] + factory = metadata.get("factory") + + # Use factory if provided + if factory is not None: + return factory(*args, **kwargs) + + # Handle custom transport URL shorthand (e.g., "custom://endpoint") + if ( + metadata["is_custom"] + and args + and len(args) == 1 + and isinstance(args[0], str) + ): + url = args[0] + if f"{name}://" in url: + endpoint = url.split(f"{name}://", 1)[1] + args = (endpoint,) + args[1:] + + # Use class constructor + return transport_class(*args, **kwargs) + + def clear(self) -> None: + """Clear all registered transports. Useful for testing.""" + self._transports.clear() + logger.debug("Cleared all registered transports") + + +# Global registry instance +driver_catalog = DriverCatalog() diff --git a/mcp_fuzzer/transport/catalog/custom_catalog.py b/mcp_fuzzer/transport/catalog/custom_catalog.py new file mode 100644 index 0000000..622e6ba --- /dev/null +++ b/mcp_fuzzer/transport/catalog/custom_catalog.py @@ -0,0 +1,233 @@ +"""Custom transport registry and utilities. + +This module provides support for registering and managing custom transport +implementations. It now uses the unified registry system internally for +consistency across all transport types. +""" + +from __future__ import annotations + +import logging +from typing import Type, Any, Callable + +from ..interfaces.driver import TransportDriver +from .catalog import driver_catalog +from ...exceptions import TransportRegistrationError + +logger = logging.getLogger(__name__) + + +class CustomDriverCatalog: + """Registry for custom transport implementations. + + This class now wraps the unified registry, marking transports as custom. + It maintains backward compatibility with the previous API while using + the unified registry internally. + """ + + def __init__(self): + """Initialize custom transport registry.""" + # Use the unified registry internally + self._registry = driver_catalog + + def clear(self) -> None: + """Clear all registered custom transports. Useful for testing.""" + # Only clear custom transports + custom_transports = list(self._registry.list_custom_drivers().keys()) + for name in custom_transports: + try: + self._registry.unregister(name) + except TransportRegistrationError: + pass + + def register( + self, + name: str, + transport_class: Type[TransportDriver], + description: str = "", + config_schema: dict[str, Any] | None = None, + factory_function: Callable | None = None, + ) -> None: + """Register a custom transport. + + Args: + name: Unique name for the transport + transport_class: The transport class to register + description: Human-readable description + config_schema: JSON schema for transport configuration + factory_function: Optional factory function to create transport instances + + Raises: + TransportRegistrationError: If transport name is already registered + """ + # Use unified registry with custom flag + self._registry.register( + name=name, + transport_class=transport_class, + description=description or f"Custom {name} transport", + config_schema=config_schema, + factory_function=factory_function, + is_custom=True, + allow_override=False, + ) + + def unregister(self, name: str) -> None: + """Unregister a custom transport. + + Args: + name: Name of the transport to unregister + + Raises: + TransportRegistrationError: If transport is not registered or not custom + """ + key = name.strip().lower() + + # Verify it's a custom transport before unregistering + if not self._registry.is_custom_transport(key): + if self._registry.is_registered(key): + raise TransportRegistrationError( + f"Transport '{name}' is a built-in transport and " + "cannot be unregistered" + ) + else: + raise TransportRegistrationError( + f"Transport '{name}' is not registered" + ) + + self._registry.unregister(name) + + def get_transport_class(self, name: str) -> Type[TransportDriver]: + """Get the transport class for a registered custom transport. + + Args: + name: Name of the registered transport + + Returns: + The transport class + + Raises: + TransportRegistrationError: If transport is not registered or not custom + """ + key = name.strip().lower() + + if not self._registry.is_custom_transport(key): + if self._registry.is_registered(key): + raise TransportRegistrationError( + f"Transport '{name}' is a built-in transport, not custom" + ) + else: + raise TransportRegistrationError( + f"Transport '{name}' is not registered" + ) + + return self._registry.get_transport_class(name) + + def get_transport_info(self, name: str) -> dict[str, Any]: + """Get information about a registered custom transport. + + Args: + name: Name of the registered transport + + Returns: + Dictionary containing transport information + + Raises: + TransportRegistrationError: If transport is not registered or not custom + """ + key = name.strip().lower() + + if not self._registry.is_custom_transport(key): + if self._registry.is_registered(key): + raise TransportRegistrationError( + f"Transport '{name}' is a built-in transport, not custom" + ) + else: + raise TransportRegistrationError( + f"Transport '{name}' is not registered" + ) + + return self._registry.get_transport_info(name) + + def list_transports(self) -> dict[str, dict[str, Any]]: + """List all registered custom transports. + + Returns: + Dictionary mapping transport names to their information + """ + return self._registry.list_custom_drivers() + + def build_driver(self, name: str, *args, **kwargs) -> TransportDriver: + """Create an instance of a registered custom transport. + + Args: + name: Name of the registered transport + *args: Positional arguments to pass to transport constructor + **kwargs: Keyword arguments to pass to transport constructor + + Returns: + Transport instance + + Raises: + TransportRegistrationError: If transport is not registered or not custom + """ + key = name.strip().lower() + + if not self._registry.is_custom_transport(key): + if self._registry.is_registered(key): + raise TransportRegistrationError( + f"Transport '{name}' is a built-in transport, not custom" + ) + else: + raise TransportRegistrationError( + f"Transport '{name}' is not registered" + ) + + return self._registry.build_driver(name, *args, **kwargs) + + +# Global registry instance +custom_driver_catalog = CustomDriverCatalog() + + +def register_custom_driver( + name: str, + transport_class: Type[TransportDriver], + description: str = "", + config_schema: dict[str, Any] | None = None, + factory_function: Callable | None = None, +) -> None: + """Register a custom transport with the global registry. + + Args: + name: Unique name for the transport + transport_class: The transport class to register + description: Human-readable description + config_schema: JSON schema for transport configuration + factory_function: Optional factory function to create transport instances + """ + custom_driver_catalog.register( + name, transport_class, description, config_schema, factory_function + ) + + +def build_custom_driver(name: str, *args, **kwargs) -> TransportDriver: + """Create an instance of a registered custom transport. + + Args: + name: Name of the registered transport + *args: Positional arguments to pass to transport constructor + **kwargs: Keyword arguments to pass to transport constructor + + Returns: + Transport instance + """ + return custom_driver_catalog.build_driver(name, *args, **kwargs) + + +def list_custom_drivers() -> dict[str, dict[str, Any]]: + """List all registered custom transports. + + Returns: + Dictionary mapping transport names to their information + """ + return custom_driver_catalog.list_transports() diff --git a/mcp_fuzzer/transport/catalog/resolver.py b/mcp_fuzzer/transport/catalog/resolver.py new file mode 100644 index 0000000..47f0882 --- /dev/null +++ b/mcp_fuzzer/transport/catalog/resolver.py @@ -0,0 +1,161 @@ +"""URL parsing utilities for transport creation. + +This module provides URL parsing functionality extracted from the factory, +supporting both standard URLs and custom transport schemes. +""" + +from __future__ import annotations + +from urllib.parse import urlparse, urlunparse +from typing import TYPE_CHECKING + +from ..interfaces.states import ParsedEndpoint + +if TYPE_CHECKING: + from .catalog import DriverCatalog + + +class EndpointResolver: + """Parser for transport URLs and protocol+endpoint patterns. + + Handles both standard URLs (http://..., https://...) and custom + transport schemes (sse://..., stdio:..., streamablehttp://..., etc.) + """ + + def __init__(self, registry: DriverCatalog | None = None): + """Initialize URL parser. + + Args: + registry: Optional registry to check for custom schemes + """ + self._registry = registry + + def set_registry(self, registry: DriverCatalog) -> None: + """Set or update the registry used for custom scheme lookup. + + Args: + registry: Registry instance + """ + self._registry = registry + + def parse( + self, url_or_protocol: str, endpoint: str | None = None + ) -> ParsedEndpoint: + """Parse a URL or protocol+endpoint into structured components. + + Supports two calling patterns: + 1. Single URL: parse("http://localhost:8080/api") + 2. Protocol + endpoint: parse("http", "localhost:8080/api") + + Args: + url_or_protocol: Full URL or protocol name + endpoint: Optional endpoint (for protocol+endpoint pattern) + + Returns: + ParsedEndpoint with structured components + """ + url_or_protocol = url_or_protocol.strip() + endpoint = endpoint.strip() if isinstance(endpoint, str) else endpoint + # Handle protocol+endpoint pattern + if endpoint is not None: + return self._parse_protocol_endpoint(url_or_protocol, endpoint) + + # Handle full URL pattern + return self._parse_url(url_or_protocol) + + def _parse_protocol_endpoint(self, protocol: str, endpoint: str) -> ParsedEndpoint: + """Parse protocol and endpoint into ParsedEndpoint. + + Args: + protocol: Transport protocol name + endpoint: Endpoint string + + Returns: + ParsedEndpoint with protocol as scheme and endpoint + """ + protocol_lower = protocol.strip().lower() + + # Check if it's a custom transport + is_custom = False + if self._registry: + is_custom = self._registry.is_registered(protocol_lower) + + return ParsedEndpoint( + scheme=protocol_lower, + endpoint=endpoint, + is_custom=is_custom, + original_url=f"{protocol}://{endpoint}", + ) + + def _parse_url(self, url: str) -> ParsedEndpoint: + """Parse a full URL into ParsedEndpoint. + + Args: + url: Full URL string + + Returns: + ParsedEndpoint with parsed components + """ + parsed = urlparse(url) + scheme = (parsed.scheme or "").lower() + + # Handle schemes that urlparse doesn't recognize (e.g., custom drivers) + if not scheme and "://" in url: + scheme = url.split("://", 1)[0].strip().lower() + + # Check if custom transport + is_custom = False + if self._registry and scheme: + is_custom = self._registry.is_custom_transport(scheme) + + # Determine endpoint based on scheme + endpoint = self._resolve_endpoint(url, parsed, scheme) + + return ParsedEndpoint( + scheme=scheme, + endpoint=endpoint, + is_custom=is_custom, + original_url=url, + netloc=parsed.netloc, + path=parsed.path, + params=parsed.params, + query=parsed.query, + fragment=parsed.fragment, + ) + + def _resolve_endpoint(self, original_url: str, parsed, scheme: str) -> str: + """Resolve the endpoint from parsed URL components. + + Args: + original_url: Original URL string + parsed: Result from urlparse + scheme: URL scheme + + Returns: + Resolved endpoint string + """ + # For stdio, extract command + if scheme == "stdio": + has_parts = parsed.netloc or parsed.path + cmd_source = (parsed.netloc + parsed.path) if has_parts else "" + return cmd_source.lstrip("/") + + # For SSE and StreamableHTTP, convert to HTTP URL + if scheme in ("sse", "streamablehttp"): + return urlunparse( + ( + "http", + parsed.netloc, + parsed.path, + parsed.params, + parsed.query, + parsed.fragment, + ) + ) + + # For HTTP/HTTPS, return original URL + if scheme in ("http", "https"): + return original_url + + # For custom transports, return original URL + return original_url diff --git a/mcp_fuzzer/transport/controller/__init__.py b/mcp_fuzzer/transport/controller/__init__.py new file mode 100644 index 0000000..3c12855 --- /dev/null +++ b/mcp_fuzzer/transport/controller/__init__.py @@ -0,0 +1,23 @@ +"""Transport controllers for driver coordination and process supervision.""" + +__all__ = [ + "TransportCoordinator", + "ProcessSupervisor", + "ProcessState", +] + + +def __getattr__(name: str): + if name == "TransportCoordinator": + from .coordinator import TransportCoordinator + + return TransportCoordinator + if name == "ProcessSupervisor": + from .process_supervisor import ProcessSupervisor + + return ProcessSupervisor + if name == "ProcessState": + from .process_supervisor import ProcessState + + return ProcessState + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/mcp_fuzzer/transport/controller/coordinator.py b/mcp_fuzzer/transport/controller/coordinator.py new file mode 100644 index 0000000..b85d8ec --- /dev/null +++ b/mcp_fuzzer/transport/controller/coordinator.py @@ -0,0 +1,242 @@ +"""Transport subsystem manager for coordinating transport operations. + +This module provides the TransportCoordinator which coordinates all +transport-related operations including creation, lifecycle, and JSON-RPC operations. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict + +from ..interfaces.driver import TransportDriver +from ..interfaces.rpc_adapter import JsonRpcAdapter +from ..catalog.catalog import driver_catalog +from ...exceptions import TransportError + + +class TransportCoordinator: + """Manager for the transport subsystem. + + Coordinates transport creation, lifecycle, JSON-RPC operations, and error handling. + This manager acts as the single point of coordination for all transport-related + operations within the transport module. + """ + + def __init__(self, config: Dict[str, Any] | None = None): + """Initialize transport subsystem manager. + + Args: + config: Optional configuration dictionary for transport settings + """ + self._logger = logging.getLogger(__name__) + self._config = config or {} + self._active_transports: dict[str, TransportDriver] = {} + self._jsonrpc_helper = JsonRpcAdapter() + + def get_jsonrpc_helper(self) -> JsonRpcAdapter: + """Get the JSON-RPC helper for transport operations. + + Returns: + JsonRpcAdapter instance + """ + return self._jsonrpc_helper + + async def build_driver( + self, + url_or_protocol: str, + endpoint: str | None = None, + transport_id: str | None = None, + **kwargs, + ) -> TransportDriver: + """Create a transport instance. + + Args: + url_or_protocol: Full URL or protocol name + endpoint: Optional endpoint (for protocol+endpoint pattern) + transport_id: Optional ID to track this transport + **kwargs: Additional arguments for transport constructor + + Returns: + Transport instance + + Raises: + TransportError: If transport creation fails + """ + try: + # Local import avoids circular dependency during module initialization + from ..catalog.builder import build_driver as build_driver_fn + + transport = build_driver_fn(url_or_protocol, endpoint, **kwargs) + + # Track active transport if ID provided + if transport_id: + self._active_transports[transport_id] = transport + + # Set transport in JSON-RPC helper for use + self._jsonrpc_helper.set_transport(transport) + + self._logger.debug( + f"Created transport: {url_or_protocol}" + + (f" (id: {transport_id})" if transport_id else "") + ) + + return transport + except Exception as e: + self._logger.error(f"Failed to create transport: {e}") + raise TransportError(f"Transport creation failed: {e}") from e + + async def connect( + self, transport: TransportDriver, transport_id: str | None = None + ) -> None: + """Connect a transport. + + Args: + transport: Transport to connect + transport_id: Optional ID for tracking + + Raises: + TransportError: If connection fails + """ + try: + await transport.connect() + + if transport_id and transport_id not in self._active_transports: + self._active_transports[transport_id] = transport + + msg = "Connected transport" + if transport_id: + msg += f" (id: {transport_id})" + self._logger.debug(msg) + except Exception as e: + self._logger.error(f"Failed to connect transport: {e}") + raise TransportError(f"Transport connection failed: {e}") from e + + async def disconnect( + self, transport: TransportDriver, transport_id: str | None = None + ) -> None: + """Disconnect a transport. + + Args: + transport: Transport to disconnect + transport_id: Optional ID for tracking + + Raises: + TransportError: If disconnection fails + """ + try: + await transport.disconnect() + + if transport_id and transport_id in self._active_transports: + del self._active_transports[transport_id] + + msg = "Disconnected transport" + if transport_id: + msg += f" (id: {transport_id})" + self._logger.debug(msg) + except Exception as e: + self._logger.warning(f"Error disconnecting transport: {e}") + # Don't raise on disconnect errors, just log + + async def send_request( + self, + transport: TransportDriver, + method: str, + params: dict[str, Any] | None = None, + ) -> Any: + """Send a JSON-RPC request through a transport. + + Args: + transport: Transport to use + method: JSON-RPC method name + params: Optional parameters + + Returns: + Response from server + + Raises: + TransportError: If request fails + """ + try: + return await transport.send_request(method, params) + except Exception as e: + self._logger.error(f"Request failed ({method}): {e}") + raise TransportError(f"Request failed: {e}") from e + + async def send_raw( + self, transport: TransportDriver, payload: dict[str, Any] + ) -> Any: + """Send a raw payload through a transport. + + Args: + transport: Transport to use + payload: Raw payload to send + + Returns: + Response from server + + Raises: + TransportError: If request fails + """ + try: + return await transport.send_raw(payload) + except Exception as e: + self._logger.error(f"Raw request failed: {e}") + raise TransportError(f"Raw request failed: {e}") from e + + async def get_tools(self, transport: TransportDriver) -> list[dict[str, Any]]: + """Get tools from server using JSON-RPC helper. + + Args: + transport: Transport to use + + Returns: + List of tools from server + """ + self._jsonrpc_helper.set_transport(transport) + return await self._jsonrpc_helper.get_tools() + + async def call_tool( + self, transport: TransportDriver, tool_name: str, arguments: dict[str, Any] + ) -> Any: + """Call a tool using JSON-RPC helper. + + Args: + transport: Transport to use + tool_name: Name of tool to call + arguments: Tool arguments + + Returns: + Tool execution result + """ + self._jsonrpc_helper.set_transport(transport) + return await self._jsonrpc_helper.call_tool(tool_name, arguments) + + def get_active_transports(self) -> dict[str, TransportDriver]: + """Get all active tracked transports. + + Returns: + Dictionary mapping transport IDs to transport instances + """ + return dict(self._active_transports) + + def list_available_transports(self) -> dict[str, dict[str, Any]]: + """List all available transport types. + + Returns: + Dictionary of available transports (built-in and custom) + """ + return driver_catalog.list_transports() + + async def cleanup(self) -> None: + """Clean up all active transports.""" + self._logger.debug("Cleaning up transport subsystem") + + for transport_id, transport in list(self._active_transports.items()): + try: + await self.disconnect(transport, transport_id) + except Exception as e: + self._logger.warning(f"Error cleaning up transport {transport_id}: {e}") + + self._active_transports.clear() + self._logger.debug("Transport subsystem cleanup complete") diff --git a/mcp_fuzzer/transport/manager.py b/mcp_fuzzer/transport/controller/process_supervisor.py similarity index 91% rename from mcp_fuzzer/transport/manager.py rename to mcp_fuzzer/transport/controller/process_supervisor.py index e344bcb..61523a8 100644 --- a/mcp_fuzzer/transport/manager.py +++ b/mcp_fuzzer/transport/controller/process_supervisor.py @@ -4,16 +4,16 @@ from dataclasses import dataclass from typing import Any, Callable -from ..events import ProcessEventObserver -from ..exceptions import TransportError +from ...events import ProcessEventObserver +from ...exceptions import TransportError @dataclass -class TransportProcessState: +class ProcessState: """Lightweight state container for a transport-managed process. ``restart_count`` tracks the lifetime number of restarts (used for reporting), - while :class:`TransportManager._restart_attempts` counts consecutive restart + while :class:`ProcessSupervisor._restart_attempts` counts consecutive restart trials within the current backoff window. """ @@ -53,7 +53,7 @@ def record_stderr_tail(self, text: str) -> None: self.last_stderr_tail = text -class TransportManager: +class ProcessSupervisor: """Helper that tracks transport process state, backoff, and observer events. ``_restart_attempts`` counts the consecutive restart trials used for the @@ -65,12 +65,12 @@ class TransportManager: def __init__( self, *, - max_read_bytes: int = 256 * 1024, # 256KB - backoff_base: float = 0.2, # 0.2 seconds - backoff_cap: float = 2.0, # 2 seconds + max_read_bytes: int = 256 * 1024, # 256KB + backoff_base: float = 0.2, # 0.2 seconds + backoff_cap: float = 2.0, # 2 seconds logger: logging.Logger | None = None, ) -> None: - self.state = TransportProcessState() + self.state = ProcessState() self._observers: list[ProcessEventObserver] = [] self._max_read_bytes = max_read_bytes self._backoff_base = backoff_base diff --git a/mcp_fuzzer/transport/custom.py b/mcp_fuzzer/transport/custom.py deleted file mode 100644 index d33f3fe..0000000 --- a/mcp_fuzzer/transport/custom.py +++ /dev/null @@ -1,196 +0,0 @@ -""" -Custom transport registry and utilities. - -This module provides support for registering and managing custom transport -implementations that can be used alongside built-in transports. -""" - -import logging -from typing import Type, Any, Callable -from .base import TransportProtocol -from ..exceptions import TransportRegistrationError - -logger = logging.getLogger(__name__) - -class CustomTransportRegistry: - """Registry for custom transport implementations.""" - - def __init__(self): - self._transports: dict[str, dict[str, Any]] = {} - - def clear(self) -> None: - """Clear all registered transports. Useful for testing.""" - self._transports.clear() - - def register( - self, - name: str, - transport_class: Type[TransportProtocol], - description: str = "", - config_schema: dict[str, Any] | None = None, - factory_function: Callable | None = None, - ) -> None: - """Register a custom transport. - - Args: - name: Unique name for the transport - transport_class: The transport class to register - description: Human-readable description - config_schema: JSON schema for transport configuration - factory_function: Optional factory function to create transport instances - - Raises: - TransportRegistrationError: If transport name is already registered - """ - key = name.strip().lower() - if key in self._transports: - raise TransportRegistrationError( - f"Transport '{name}' is already registered" - ) - - if not issubclass(transport_class, TransportProtocol): - raise TransportRegistrationError( - f"Transport class {transport_class} must inherit from TransportProtocol" - ) - - self._transports[key] = { - "class": transport_class, - "description": description, - "config_schema": config_schema, - "factory": factory_function, - } - - logger.info(f"Registered custom transport: {key}") - - def unregister(self, name: str) -> None: - """Unregister a custom transport. - - Args: - name: Name of the transport to unregister - - Raises: - TransportRegistrationError: If transport is not registered - """ - key = name.strip().lower() - if key not in self._transports: - raise TransportRegistrationError(f"Transport '{name}' is not registered") - - del self._transports[key] - logger.info(f"Unregistered custom transport: {key}") - - def get_transport_class(self, name: str) -> Type[TransportProtocol]: - """Get the transport class for a registered transport. - - Args: - name: Name of the registered transport - - Returns: - The transport class - - Raises: - TransportRegistrationError: If transport is not registered - """ - key = name.strip().lower() - if key not in self._transports: - raise TransportRegistrationError(f"Transport '{name}' is not registered") - return self._transports[key]["class"] - - def get_transport_info(self, name: str) -> dict[str, Any]: - """Get information about a registered transport. - - Args: - name: Name of the registered transport - - Returns: - Dictionary containing transport information - - Raises: - TransportRegistrationError: If transport is not registered - """ - key = name.strip().lower() - if key not in self._transports: - raise TransportRegistrationError(f"Transport '{name}' is not registered") - return self._transports[key].copy() - - def list_transports(self) -> dict[str, dict[str, Any]]: - """List all registered custom transports. - - Returns: - Dictionary mapping transport names to their information - """ - return self._transports.copy() - - def create_transport(self, name: str, *args, **kwargs) -> TransportProtocol: - """Create an instance of a registered transport. - - Args: - name: Name of the registered transport - *args: Positional arguments to pass to transport constructor - **kwargs: Keyword arguments to pass to transport constructor - - Returns: - Transport instance - - Raises: - TransportRegistrationError: If transport is not registered - """ - transport_info = self.get_transport_info(name) - transport_class = transport_info["class"] - factory = transport_info.get("factory") - - # If no factory is provided, support "name://endpoint" shorthand by - # rewriting the first positional arg to just the endpoint. - if factory is None: - if args and len(args) == 1 and isinstance(args[0], str): - url = args[0] - if f"{name}://" in url: - endpoint = url.split(f"{name}://", 1)[1] - args = (endpoint,) + args[1:] - return transport_class(*args, **kwargs) - else: - # Pass through as-is; factories can handle full URLs or endpoints. - return factory(*args, **kwargs) - -# Global registry instance -registry = CustomTransportRegistry() - -def register_custom_transport( - name: str, - transport_class: Type[TransportProtocol], - description: str = "", - config_schema: dict[str, Any] | None = None, - factory_function: Callable | None = None, -) -> None: - """Register a custom transport with the global registry. - - Args: - name: Unique name for the transport - transport_class: The transport class to register - description: Human-readable description - config_schema: JSON schema for transport configuration - factory_function: Optional factory function to create transport instances - """ - registry.register( - name, transport_class, description, config_schema, factory_function - ) - -def create_custom_transport(name: str, *args, **kwargs) -> TransportProtocol: - """Create an instance of a registered custom transport. - - Args: - name: Name of the registered transport - *args: Positional arguments to pass to transport constructor - **kwargs: Keyword arguments to pass to transport constructor - - Returns: - Transport instance - """ - return registry.create_transport(name, *args, **kwargs) - -def list_custom_transports() -> dict[str, dict[str, Any]]: - """List all registered custom transports. - - Returns: - Dictionary mapping transport names to their information - """ - return registry.list_transports() diff --git a/mcp_fuzzer/transport/drivers/__init__.py b/mcp_fuzzer/transport/drivers/__init__.py new file mode 100644 index 0000000..aaa1f8d --- /dev/null +++ b/mcp_fuzzer/transport/drivers/__init__.py @@ -0,0 +1,13 @@ +"""Concrete driver implementations for the transport subsystem.""" + +from .http_driver import HttpDriver +from .sse_driver import SseDriver +from .stdio_driver import StdioDriver +from .stream_http_driver import StreamHttpDriver + +__all__ = [ + "HttpDriver", + "SseDriver", + "StdioDriver", + "StreamHttpDriver", +] diff --git a/mcp_fuzzer/transport/http.py b/mcp_fuzzer/transport/drivers/http_driver.py similarity index 93% rename from mcp_fuzzer/transport/http.py rename to mcp_fuzzer/transport/drivers/http_driver.py index 417def1..8f0be57 100644 --- a/mcp_fuzzer/transport/http.py +++ b/mcp_fuzzer/transport/drivers/http_driver.py @@ -7,38 +7,48 @@ import httpx -from .base import TransportProtocol -from .mixins import NetworkTransportMixin, ResponseParsingMixin -from ..fuzz_engine.runtime import ProcessManager, WatchdogConfig -from ..config import ( +from ..interfaces.driver import TransportDriver +from ..interfaces.behaviors import ( + HttpClientBehavior, + ResponseParserBehavior, + LifecycleBehavior, +) +from ...fuzz_engine.runtime import ProcessManager, WatchdogConfig +from ...config import ( JSON_CONTENT_TYPE, DEFAULT_HTTP_ACCEPT, ) -from ..safety_system.policy import ( +from ...safety_system.policy import ( resolve_redirect_safely, ) -class HTTPTransport(TransportProtocol, NetworkTransportMixin, ResponseParsingMixin): + +class HttpDriver( + TransportDriver, + HttpClientBehavior, + ResponseParserBehavior, + LifecycleBehavior, +): """ HTTP transport implementation with reduced code duplication. This implementation uses mixins to provide shared functionality, addressing the code duplication issues identified in GitHub issue #41. - Mixin Composition: - - TransportProtocol (ABC): Defines the core interface (send_request, send_raw, etc.) - - NetworkTransportMixin: Provides shared network functionality including: + Behavior Composition: + - TransportDriver (ABC): Defines the core interface (send_request, send_raw, etc.) + - HttpClientBehavior: Provides shared network functionality including: - Connection management and HTTP client creation - Header preparation and validation - Timeout handling and activity tracking - Network request validation and error handling - - ResponseParsingMixin: Handles HTTP-specific response processing: + - ResponseParserBehavior: Handles HTTP-specific response processing: - JSON-RPC payload creation and validation - HTTP response error handling (status codes, timeouts) - Redirect resolution with safety policies - Response parsing and serialization checks - This composition allows HTTPTransport to focus on HTTP-specific logic while + This composition allows HttpDriver to focus on HTTP-specific logic while reusing common network and response handling code. Future HTTP transports (e.g., WebSocket over HTTP) can inherit from the same mixins to stay consistent. """ diff --git a/mcp_fuzzer/transport/sse.py b/mcp_fuzzer/transport/drivers/sse_driver.py similarity index 52% rename from mcp_fuzzer/transport/sse.py rename to mcp_fuzzer/transport/drivers/sse_driver.py index 4a528df..7ce8d1b 100644 --- a/mcp_fuzzer/transport/sse.py +++ b/mcp_fuzzer/transport/drivers/sse_driver.py @@ -1,20 +1,45 @@ +"""Server-Sent Events (SSE) transport implementation. + +This transport implementation uses mixins to provide shared functionality, +reducing code duplication significantly (~100 lines). +""" + +from __future__ import annotations + import json -import logging -from typing import Any +from typing import Any, AsyncIterator + +from ..interfaces.driver import TransportDriver +from ..interfaces.behaviors import HttpClientBehavior, ResponseParserBehavior +from ...exceptions import TransportError + -import httpx +class SseDriver(TransportDriver, HttpClientBehavior, ResponseParserBehavior): + """SSE transport implementation using mixins. -from .base import TransportProtocol -from ..exceptions import NetworkPolicyViolation, ServerError, TransportError -from ..safety_system.policy import is_host_allowed, sanitize_headers + This implementation leverages mixins to share common network and parsing + functionality with other HTTP-based transports, reducing code duplication. + + Mixin Composition: + - TransportDriver: Core interface + - HttpClientBehavior: Network validation, header sanitization, HTTP client + - ResponseParserBehavior: SSE event parsing + """ -class SSETransport(TransportProtocol): def __init__( self, url: str, timeout: float = 30.0, auth_headers: dict[str, str | None] | None = None, ): + """Initialize SSE transport. + + Args: + url: Server URL for SSE connection + timeout: Request timeout in seconds + auth_headers: Optional authentication headers + """ + super().__init__() self.url = url self.timeout = timeout self.headers = { @@ -27,46 +52,56 @@ def __init__( async def send_request( self, method: str, params: dict[str, Any | None] | None = None ) -> dict[str, Any]: - # SSE transport does not support non-streaming requests via send_request. - # Use stream-based APIs instead (e.g., _stream_request). - raise NotImplementedError("SSETransport does not support send_request") + """SSE transport does not support non-streaming requests. + + Use stream-based APIs instead (e.g., _stream_request). + + Raises: + NotImplementedError: SSE transport only supports streaming + """ + raise NotImplementedError("SseDriver does not support send_request") async def send_raw(self, payload: dict[str, Any]) -> Any: - async with httpx.AsyncClient( - timeout=self.timeout, - follow_redirects=False, - trust_env=False, - ) as client: - if not is_host_allowed(self.url): - raise NetworkPolicyViolation( - "Network to non-local host is disallowed by safety policy", - context={"url": self.url}, - ) - safe_headers = sanitize_headers(self.headers) + """Send a raw payload and return the first SSE response. + + Args: + payload: JSON-RPC payload to send + + Returns: + Parsed response from server + + Raises: + TransportError: If request fails or no valid response + ServerError: If server returns an error + """ + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(self.headers) + + async with self._create_http_client(self.timeout) as client: response = await client.post(self.url, json=payload, headers=safe_headers) - response.raise_for_status() + self._handle_http_response_error(response) + + # Process response text as SSE stream buffer: list[str] = [] - def flush_once() -> dict[str, Any | None]: + def flush_once() -> dict[str, Any] | None: + """Flush buffer and parse as SSE event.""" if not buffer: return None event_text = "\n".join(buffer) buffer.clear() try: - data = SSETransport._parse_sse_event(event_text) + data = self.parse_sse_event(event_text) except json.JSONDecodeError: - logging.error("Failed to parse SSE data as JSON") + self._logger.error("Failed to parse SSE data as JSON") return None if data is None: return None - if "error" in data: - raise ServerError( - "Server returned error", - context={"url": self.url, "error": data["error"]}, - ) - result = data.get("result", data) - return result if isinstance(result, dict) else {"result": result} + # Use shared result extraction + return self._extract_result_from_response(data) + # Parse response text line by line for line in response.text.splitlines(): if not line.strip(): result = flush_once() @@ -74,20 +109,19 @@ def flush_once() -> dict[str, Any | None]: return result continue buffer.append(line) + + # Flush remaining buffer result = flush_once() if result is not None: return result + + # Try parsing entire response as JSON try: data = response.json() - if "error" in data: - raise ServerError( - "Server returned error", - context={"url": self.url, "error": data["error"]}, - ) - result = data.get("result", data) - return result if isinstance(result, dict) else {"result": result} + return self._extract_result_from_response(data) except json.JSONDecodeError: pass + raise TransportError( "No valid SSE response received", context={"url": self.url}, @@ -96,22 +130,28 @@ def flush_once() -> dict[str, Any | None]: async def send_notification( self, method: str, params: dict[str, Any | None] | None = None ) -> None: - payload = {"jsonrpc": "2.0", "method": method, "params": params or {}} - async with httpx.AsyncClient( - timeout=self.timeout, - follow_redirects=False, - trust_env=False, - ) as client: - if not is_host_allowed(self.url): - raise NetworkPolicyViolation( - "Network to non-local host is disallowed by safety policy", - context={"url": self.url}, - ) - safe_headers = sanitize_headers(self.headers) + """Send a JSON-RPC notification via SSE. + + Args: + method: Method name + params: Optional parameters + + Raises: + TransportError: If notification fails to send + """ + payload = self._create_jsonrpc_notification(method, params) + + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(self.headers) + + async with self._create_http_client(self.timeout) as client: response = await client.post(self.url, json=payload, headers=safe_headers) - response.raise_for_status() + self._handle_http_response_error(response) - async def _stream_request(self, payload: dict[str, Any]): + async def _stream_request( + self, payload: dict[str, Any] + ) -> AsyncIterator[dict[str, Any]]: """Stream a request via SSE and yield parsed events. Args: @@ -119,25 +159,22 @@ async def _stream_request(self, payload: dict[str, Any]): Yields: Parsed JSON objects from SSE events + + Raises: + TransportError: If streaming fails """ - async with httpx.AsyncClient( - timeout=self.timeout, - follow_redirects=False, - trust_env=False, - ) as client: - if not is_host_allowed(self.url): - raise NetworkPolicyViolation( - "Network to non-local host is disallowed by safety policy", - context={"url": self.url}, - ) - safe_headers = sanitize_headers(self.headers) + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(self.headers) + + async with self._create_http_client(self.timeout) as client: async with client.stream( "POST", self.url, json=payload, headers=safe_headers, ) as response: - response.raise_for_status() + self._handle_http_response_error(response) chunks = response.aiter_text() buffer = [] # Buffer to accumulate SSE event data @@ -158,13 +195,11 @@ async def _stream_request(self, payload: dict[str, Any]): if buffer: try: event_text = "\n".join(buffer) - parsed = SSETransport._parse_sse_event( - event_text - ) + parsed = self.parse_sse_event(event_text) if parsed is not None: yield parsed except json.JSONDecodeError: - logging.error( + self._logger.error( "Failed to parse SSE event payload as JSON" ) finally: @@ -184,13 +219,11 @@ async def _stream_request(self, payload: dict[str, Any]): if buffer: try: event_text = "\n".join(buffer) - parsed = SSETransport._parse_sse_event( - event_text - ) + parsed = self.parse_sse_event(event_text) if parsed is not None: yield parsed except json.JSONDecodeError: - logging.error( + self._logger.error( "Failed to parse SSE event payload as JSON" ) finally: @@ -200,38 +233,8 @@ async def _stream_request(self, payload: dict[str, Any]): if buffer: try: event_text = "\n".join(buffer) - parsed = SSETransport._parse_sse_event(event_text) + parsed = self.parse_sse_event(event_text) if parsed is not None: yield parsed except json.JSONDecodeError: - logging.error("Failed to parse SSE event payload as JSON") - - @staticmethod - def _parse_sse_event(event_text: str) -> dict[str, Any | None]: - """Parse a single SSE event text into a JSON object. - - The input may contain multiple lines such as "event:", "data:", or - control fields like "retry:". Only the JSON payload from one or more - "data:" lines is considered. Multiple data lines are concatenated. - - Returns None when there is no data payload. Raises JSONDecodeError when - a data payload is present but cannot be parsed as JSON. - """ - if not event_text: - return None - - data_parts: list[str] = [] - for raw_line in event_text.splitlines(): - line = raw_line.strip() - if not line: - continue - if line.startswith("data:"): - data_parts.append(line[len("data:") :].strip()) - # Ignore other fields such as "event:" and "retry:" - - if not data_parts: - return None - - data_str = "\n".join(data_parts) - # May raise JSONDecodeError if invalid, as intended by tests - return json.loads(data_str) + self._logger.error("Failed to parse SSE event payload as JSON") diff --git a/mcp_fuzzer/transport/stdio.py b/mcp_fuzzer/transport/drivers/stdio_driver.py similarity index 97% rename from mcp_fuzzer/transport/stdio.py rename to mcp_fuzzer/transport/drivers/stdio_driver.py index c9aa0a2..ab6b087 100644 --- a/mcp_fuzzer/transport/stdio.py +++ b/mcp_fuzzer/transport/drivers/stdio_driver.py @@ -11,19 +11,20 @@ import time from typing import Any, Callable -from .base import TransportProtocol -from ..exceptions import ( +from ..interfaces.driver import TransportDriver +from ...exceptions import ( ProcessSignalError, ProcessStartError, ServerError, TransportError, ) -from ..fuzz_engine.runtime import ProcessManager, WatchdogConfig -from ..safety_system.policy import sanitize_subprocess_env -from ..config.constants import PROCESS_WAIT_TIMEOUT -from .manager import TransportManager +from ...fuzz_engine.runtime import ProcessManager, WatchdogConfig +from ...safety_system.policy import sanitize_subprocess_env +from ...config.constants import PROCESS_WAIT_TIMEOUT +from ..controller.process_supervisor import ProcessSupervisor -class StdioTransport(TransportProtocol): + +class StdioDriver(TransportDriver): def __init__( self, command: str, @@ -42,7 +43,7 @@ def __init__( self._initialized = False self._last_activity = time.time() self.process_manager = process_manager - self.manager = TransportManager(logger=logging.getLogger(__name__)) + self.manager = ProcessSupervisor(logger=logging.getLogger(__name__)) def _get_lock(self): """Get or create the lock lazily.""" @@ -445,9 +446,9 @@ async def send_timeout_signal(self, signal_type: str = "timeout") -> bool: logging.info( ( "Sent terminate timeout signal to process " - f"{self.process.pid}" - ) + f"{self.process.pid}" ) + ) elif signal_type == "force": # Send SIGKILL (force kill) if os.name != "nt": @@ -499,8 +500,8 @@ async def send_timeout_signal(self, signal_type: str = "timeout") -> bool: ( "Sent terminate interrupt signal to process " f"{self.process.pid}" - ) ) + ) else: logging.warning(f"Unknown signal type: {signal_type}") return False diff --git a/mcp_fuzzer/transport/streamable_http.py b/mcp_fuzzer/transport/drivers/stream_http_driver.py similarity index 64% rename from mcp_fuzzer/transport/streamable_http.py rename to mcp_fuzzer/transport/drivers/stream_http_driver.py index 2e3be62..9289bd0 100644 --- a/mcp_fuzzer/transport/streamable_http.py +++ b/mcp_fuzzer/transport/drivers/stream_http_driver.py @@ -1,10 +1,25 @@ +"""Stream HTTP driver with SSE support and session headers. + +This transport implementation uses mixins to reduce code duplication significantly +(~150 lines), sharing network validation, header handling, and response parsing +with other HTTP-based transports. +""" + +from __future__ import annotations + import asyncio import json -import logging -from typing import Any +from typing import Any, AsyncIterator import httpx -from ..config import ( + +from ..interfaces.driver import TransportDriver +from ..interfaces.behaviors import ( + HttpClientBehavior, + ResponseParserBehavior, + NetworkError as DriverNetworkError, +) +from ...config import ( DEFAULT_PROTOCOL_VERSION, CONTENT_TYPE_HEADER, JSON_CONTENT_TYPE, @@ -13,7 +28,7 @@ MCP_PROTOCOL_VERSION_HEADER, DEFAULT_HTTP_ACCEPT, ) -from ..types import ( +from ...types import ( HTTP_ACCEPTED, HTTP_REDIRECT_TEMPORARY, HTTP_REDIRECT_PERMANENT, @@ -21,14 +36,7 @@ DEFAULT_TIMEOUT, RETRY_DELAY, ) - -from .base import TransportProtocol -from ..exceptions import NetworkPolicyViolation, ServerError, TransportError -from ..safety_system.policy import ( - is_host_allowed, - resolve_redirect_safely, - sanitize_headers, -) +from ...exceptions import TransportError # Back-compat local aliases (referenced by tests) MCP_SESSION_ID = MCP_SESSION_ID_HEADER @@ -37,13 +45,19 @@ JSON_CT = JSON_CONTENT_TYPE SSE_CT = SSE_CONTENT_TYPE -class StreamableHTTPTransport(TransportProtocol): - """Streamable HTTP transport with basic SSE support and session headers. - This mirrors the MCP SDK's StreamableHTTP semantics enough for fuzzing: +class StreamHttpDriver(TransportDriver, HttpClientBehavior, ResponseParserBehavior): + """Streamable HTTP transport with MCP session management. + + This mirrors the MCP SDK's StreamableHTTP semantics for fuzzing: - Sends Accept: application/json, text/event-stream - Parses JSON or SSE responses - Tracks and propagates mcp-session-id and mcp-protocol-version headers + + Mixin Composition: + - TransportDriver: Core interface + - HttpClientBehavior: Network validation, header sanitization, HTTP client + - ResponseParserBehavior: Response parsing (JSON and SSE) """ def __init__( @@ -52,6 +66,14 @@ def __init__( timeout: float = DEFAULT_TIMEOUT, auth_headers: dict[str, str | None] = None, ): + """Initialize streamable HTTP transport. + + Args: + url: Server URL + timeout: Request timeout in seconds + auth_headers: Optional authentication headers + """ + super().__init__() self.url = url self.timeout = timeout self.headers: dict[str, str] = { @@ -61,7 +83,6 @@ def __init__( if auth_headers: self.headers.update(auth_headers) - self._logger = logging.getLogger(__name__) self.session_id: str | None = None self.protocol_version: str | None = None self._initialized: bool = False @@ -69,6 +90,11 @@ def __init__( self._initializing: bool = False def _prepare_headers(self) -> dict[str, str]: + """Prepare headers with session information. + + Returns: + Headers dict with session information + """ headers = dict(self.headers) if self.session_id: headers[MCP_SESSION_ID] = self.session_id @@ -76,35 +102,28 @@ def _prepare_headers(self) -> dict[str, str]: headers[MCP_PROTOCOL_VERSION] = self.protocol_version return headers - def _ensure_host_allowed(self) -> None: - """Raise if the destination host violates safety policy.""" - if not is_host_allowed(self.url): - raise NetworkPolicyViolation( - "Network to non-local host is disallowed by safety policy", - context={"url": self.url}, - ) - - def _raise_http_status_error( - self, error: httpx.HTTPStatusError, *, method: str | None = None - ) -> None: - """Convert httpx HTTP status errors into TransportError instances.""" - request_url = str(error.request.url) if error.request else self.url - status = error.response.status_code if error.response else None - context: dict[str, Any] = {"url": request_url, "status": status} - if method: - context["method"] = method - raise TransportError( - f"HTTP error while communicating with {request_url}", context=context - ) from error - def _maybe_extract_session_headers(self, response: httpx.Response) -> None: + """Extract session ID from response headers. + + Args: + response: HTTP response to extract from + """ sid = response.headers.get(MCP_SESSION_ID) if sid: - # Update session id if server sends one self.session_id = sid self._logger.debug("Received session id: %s", sid) + protocol_header = response.headers.get(MCP_PROTOCOL_VERSION) + if protocol_header: + self.protocol_version = protocol_header + self._logger.debug("Received protocol version header: %s", protocol_header) + def _maybe_extract_protocol_version_from_result(self, result: Any) -> None: + """Extract protocol version from result. + + Args: + result: Result dict that may contain protocolVersion + """ try: if isinstance(result, dict) and "protocolVersion" in result: pv = result.get("protocolVersion") @@ -114,8 +133,55 @@ def _maybe_extract_protocol_version_from_result(self, result: Any) -> None: except Exception: pass - async def _parse_sse_response(self, response: httpx.Response) -> Any: - """Parse SSE stream and return on first JSON-RPC response/error.""" + def _resolve_redirect(self, response: httpx.Response) -> str | None: + """Resolve redirect target with safety checks. + + Args: + response: HTTP response to check for redirects + + Returns: + Resolved redirect URL or None + """ + redirect_codes = (HTTP_REDIRECT_TEMPORARY, HTTP_REDIRECT_PERMANENT) + if response.status_code not in redirect_codes: + return None + + location = response.headers.get("location") + if not location and not self.url.endswith("/"): + location = self.url + "/" + if not location: + return None + + # Use the base mixin's redirect resolution (imported from safety policy) + from ...safety_system.policy import resolve_redirect_safely + + resolved = resolve_redirect_safely(self.url, location) + if not resolved: + self._logger.warning( + "Refusing redirect that violates policy from %s", self.url + ) + return resolved + + def _extract_content_type(self, response: httpx.Response) -> str: + """Extract content type from response. + + Args: + response: HTTP response + + Returns: + Content type string (lowercase) + """ + return response.headers.get(CONTENT_TYPE, "").lower() + + async def _parse_sse_response_for_result(self, response: httpx.Response) -> Any: + """Parse SSE stream and return first JSON-RPC response/error. + + Args: + response: HTTP response with SSE content + + Returns: + First parsed result from SSE stream + """ # Basic SSE parser: accumulate fields until blank line event: dict[str, Any] = {"event": "message", "data": []} async for line in response.aiter_lines(): @@ -159,43 +225,103 @@ async def _parse_sse_response(self, response: httpx.Response) -> Any: # If we exit loop without a response, return None return None - def _resolve_redirect(self, response: httpx.Response) -> str | None: - redirect_codes = (HTTP_REDIRECT_TEMPORARY, HTTP_REDIRECT_PERMANENT) - if response.status_code not in redirect_codes: - return None - location = response.headers.get("location") - if not location and not self.url.endswith("/"): - location = self.url + "/" - if not location: - return None - resolved = resolve_redirect_safely(self.url, location) - if not resolved: - self._logger.warning( - "Refusing redirect that violates policy from %s", self.url - ) - return resolved + async def _post_with_retries( + self, + client: httpx.AsyncClient, + url: str, + json: dict[str, Any], + headers: dict[str, str], + retries: int = 2, + ) -> httpx.Response: + """POST with exponential backoff for transient network errors. - def _extract_content_type(self, response: httpx.Response) -> str: - return response.headers.get(CONTENT_TYPE, "").lower() + Args: + client: HTTP client + url: URL to post to + json: JSON payload + headers: Request headers + retries: Maximum retry attempts + + Returns: + HTTP response + + Raises: + TransportError: If all retries fail + """ + delay = RETRY_DELAY + attempt = 0 + while True: + try: + return await client.post(url, json=json, headers=headers) + except (httpx.ConnectError, httpx.ReadTimeout) as e: + # Only retry for safe, idempotent, or initialization-like methods + method = None + try: + method = json.get("method") + except Exception: + pass + safe = method in ( + "initialize", + "notifications/initialized", + "tools/list", + "prompts/list", + "resources/list", + ) + if attempt >= retries or not safe: + context = { + "url": url, + "error_type": type(e).__name__, + "attempts": attempt + 1, + } + if method: + context["method"] = method + raise TransportError( + "Connection failed while contacting server", context=context + ) from e + self._logger.debug( + "POST retry %d for %s due to %s", + attempt + 1, + url, + type(e).__name__, + ) + await asyncio.sleep(delay) + delay *= 2 + attempt += 1 async def send_request( self, method: str, params: dict[str, Any] | None = None ) -> Any: + """Send a JSON-RPC request and return the response. + + Args: + method: Method name + params: Optional parameters + + Returns: + Response from server + """ request_id = str(asyncio.get_running_loop().time()) - payload = { - "jsonrpc": "2.0", - "id": request_id, - "method": method, - "params": params or {}, - } + payload = self._create_jsonrpc_request(method, params, request_id) return await self.send_raw(payload) async def send_raw(self, payload: dict[str, Any]) -> Any: + """Send raw payload and return the response. + + Args: + payload: Raw JSON-RPC payload + + Returns: + Response from server + + Raises: + TransportError: If request fails + """ # Ensure MCP initialization handshake once per session try: method = payload.get("method") except AttributeError: method = None + if not self._initialized and method != "initialize": async with self._init_lock: if not self._initialized and not self._initializing: @@ -206,24 +332,28 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: self._initializing = False headers = self._prepare_headers() - async with httpx.AsyncClient( - timeout=self.timeout, follow_redirects=False, trust_env=False - ) as client: - self._ensure_host_allowed() + + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(headers) + + async with self._create_http_client(self.timeout) as client: response = await self._post_with_retries( - client, self.url, payload, sanitize_headers(headers) + client, self.url, payload, safe_headers ) - # Handle redirect by retrying once with provided Location or trailing slash + + # Handle redirect redirect_url = self._resolve_redirect(response) if redirect_url: self._logger.debug("Following redirect to %s", redirect_url) response = await self._post_with_retries( - client, redirect_url, payload, headers + client, redirect_url, payload, safe_headers ) + # Update session headers if available self._maybe_extract_session_headers(response) - # Handle status codes similar to SDK + # Handle special status codes if response.status_code == HTTP_ACCEPTED: return {} if response.status_code == HTTP_NOT_FOUND: @@ -232,68 +362,30 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: context={"url": self.url, "status": response.status_code}, ) + # Use shared error handling try: - response.raise_for_status() - except httpx.HTTPStatusError as exc: - self._raise_http_status_error(exc, method=method) + self._handle_http_response_error(response) + except DriverNetworkError as exc: + context = { + "url": self.url, + "status": response.status_code, + } + raise TransportError(str(exc), context=context) from exc + ct = self._extract_content_type(response) if ct.startswith(JSON_CT): - # Try to get the JSON response - try: - data = response.json() - except json.JSONDecodeError: - # Fallback: parse first JSON object from raw stream - data = {} - if hasattr(response, "aread"): - try: - content = await response.aread() - content_str = content.decode("utf-8").strip() - decoder = json.JSONDecoder() - pos = 0 - # Limit attempts to prevent infinite loops - max_attempts = 1000 - attempts = 0 - while pos < len(content_str) and attempts < max_attempts: - attempts += 1 - try: - parsed, new_pos = decoder.raw_decode( - content_str, pos - ) - data = parsed - break - except json.JSONDecodeError: - pos += 1 - # Skip whitespace - while ( - pos < len(content_str) - and content_str[pos].isspace() - ): - pos += 1 - except Exception: - pass - - if isinstance(data, dict): - if "error" in data: - raise ServerError( - "Server returned error", - context={"url": self.url, "error": data["error"]}, - ) - if "result" in data: - # Extract protocol version if present (initialize) - self._maybe_extract_protocol_version_from_result(data["result"]) - # Mark initialized if this was an explicit initialize call - if method == "initialize": - self._initialized = True - result = data["result"] - return ( - result if isinstance(result, dict) else {"result": result} - ) - # Normalize non-dict payloads + # Use shared JSON parsing (returns JSON-RPC result payload) + data = self._parse_http_response_json(response, fallback_to_sse=False) + + self._maybe_extract_protocol_version_from_result(data) + if method == "initialize": + self._initialized = True + return data if isinstance(data, dict) else {"result": data} if ct.startswith(SSE_CT): - parsed = await self._parse_sse_response(response) + parsed = await self._parse_sse_response_for_result(response) if method == "initialize": self._initialized = True if parsed is None: @@ -308,13 +400,20 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: async def send_notification( self, method: str, params: dict[str, Any] | None = None ) -> None: - payload = {"jsonrpc": "2.0", "method": method, "params": params or {}} + """Send a JSON-RPC notification. + + Args: + method: Method name + params: Optional parameters + """ + payload = self._create_jsonrpc_notification(method, params) headers = self._prepare_headers() - async with httpx.AsyncClient( - timeout=self.timeout, follow_redirects=False, trust_env=False - ) as client: - self._ensure_host_allowed() - safe_headers = sanitize_headers(headers) + + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(headers) + + async with self._create_http_client(self.timeout) as client: response = await self._post_with_retries( client, self.url, payload, safe_headers ) @@ -323,13 +422,10 @@ async def send_notification( response = await self._post_with_retries( client, redirect_url, payload, safe_headers ) - try: - response.raise_for_status() - except httpx.HTTPStatusError as exc: - self._raise_http_status_error(exc, method=method) + self._handle_http_response_error(response) async def _do_initialize(self) -> None: - """Perform a minimal MCP initialize + initialized notification.""" + """Perform MCP initialize + initialized notification.""" init_payload = { "jsonrpc": "2.0", "id": str(asyncio.get_running_loop().time()), @@ -357,72 +453,24 @@ async def _do_initialize(self) -> None: # Surface the failure; leave _initialized False raise - async def _post_with_retries( - self, - client: httpx.AsyncClient, - url: str, - json: dict[str, Any], - headers: dict[str, str], - retries: int = 2, # Default max retries - ) -> httpx.Response: - """POST with simple exponential backoff for transient network errors.""" - delay = RETRY_DELAY - attempt = 0 - while True: - try: - return await client.post(url, json=json, headers=headers) - except (httpx.ConnectError, httpx.ReadTimeout) as e: - # Only retry for safe, idempotent, or initialization-like methods - method = None - try: - method = json.get("method") - except Exception: - pass - safe = method in ( - "initialize", - "notifications/initialized", - "tools/list", - "prompts/list", - "resources/list", - ) - if attempt >= retries or not safe: - context = { - "url": url, - "error_type": type(e).__name__, - "attempts": attempt + 1, - } - if method: - context["method"] = method - raise TransportError( - "Connection failed while contacting server", context=context - ) from e - self._logger.debug( - "POST retry %d for %s due to %s", - attempt + 1, - url, - type(e).__name__, - ) - await asyncio.sleep(delay) - delay *= 2 - attempt += 1 + async def _stream_request( + self, payload: dict[str, Any] + ) -> AsyncIterator[dict[str, Any]]: + """Stream a request and yield parsed data lines. - async def _stream_request(self, payload: dict[str, Any]): - """Stream a request and yield parsed JSON or SSE data lines. + Args: + payload: Request payload - This mirrors the logic used in HTTPTransport._stream_request but adapted - for the streamable transport and its header/session handling. + Yields: + Parsed JSON objects from stream """ headers = self._prepare_headers() - method = None - try: - method = payload.get("method") - except AttributeError: - method = None - async with httpx.AsyncClient( - timeout=self.timeout, follow_redirects=False, trust_env=False - ) as client: - self._ensure_host_allowed() - safe_headers = sanitize_headers(headers) + + # Use shared network functionality + self._validate_network_request(self.url) + safe_headers = self._prepare_safe_headers(headers) + + async with self._create_http_client(self.timeout) as client: response = await client.stream( "POST", self.url, json=payload, headers=safe_headers ) @@ -435,9 +483,10 @@ async def _stream_request(self, payload: dict[str, Any]): ) try: - response.raise_for_status() + self._handle_http_response_error(response) # Update session headers from streaming response self._maybe_extract_session_headers(response) + async for line in response.aiter_lines(): if not line.strip(): continue @@ -452,7 +501,5 @@ async def _stream_request(self, payload: dict[str, Any]): except json.JSONDecodeError: self._logger.error("Failed to parse SSE data as JSON") continue - except httpx.HTTPStatusError as exc: - self._raise_http_status_error(exc, method=method) finally: await response.aclose() diff --git a/mcp_fuzzer/transport/factory.py b/mcp_fuzzer/transport/factory.py deleted file mode 100644 index 7c7a855..0000000 --- a/mcp_fuzzer/transport/factory.py +++ /dev/null @@ -1,125 +0,0 @@ - -from .base import TransportProtocol -from .http import HTTPTransport -from .sse import SSETransport -from .stdio import StdioTransport -from .streamable_http import StreamableHTTPTransport -from .custom import registry as custom_registry -from urllib.parse import urlparse, urlunparse -from ..exceptions import TransportRegistrationError - -class TransportRegistry: - """Registry for transport classes.""" - - def __init__(self): - self._transports: dict[str, type[TransportProtocol]] = {} - - def register(self, name: str, cls: type[TransportProtocol]) -> None: - """Register a transport class by name.""" - self._transports[name.lower()] = cls - - def list_transports(self) -> dict[str, type[TransportProtocol]]: - """List all registered transports.""" - return self._transports.copy() - - def create_transport(self, name: str, *args, **kwargs) -> TransportProtocol: - """Create a transport instance by name.""" - name_lower = name.lower() - if name_lower not in self._transports: - raise TransportRegistrationError(f"Unknown transport: {name}") - cls = self._transports[name_lower] - return cls(*args, **kwargs) - - -# Global registry -registry = TransportRegistry() - -# Register built-in transports -registry.register("http", HTTPTransport) -registry.register("https", HTTPTransport) -registry.register("sse", SSETransport) -registry.register("stdio", StdioTransport) -registry.register("streamablehttp", StreamableHTTPTransport) - -def create_transport( - url_or_protocol: str, endpoint: str | None = None, **kwargs -) -> TransportProtocol: - """Create a transport from either a full URL or protocol + endpoint. - - Backward-compatible with previous signature (protocol, endpoint). - """ - # Back-compat path: two-argument usage - if endpoint is not None: - key = url_or_protocol.strip().lower() - # Try custom transports first - try: - return custom_registry.create_transport(key, endpoint, **kwargs) - except TransportRegistrationError: - pass - # Try built-in registry - try: - return registry.create_transport(key, endpoint, **kwargs) - except TransportRegistrationError: - raise TransportRegistrationError( - f"Unsupported protocol: {url_or_protocol}. " - f"Supported: {', '.join(registry.list_transports().keys())}; " - f"custom: {', '.join(sorted(custom_registry.list_transports().keys()))}" - ) - - # Single-URL usage - parsed = urlparse(url_or_protocol) - scheme = (parsed.scheme or "").lower() - - # Handle custom schemes that urlparse doesn't recognize - if not scheme and "://" in url_or_protocol: - # Extract scheme manually for custom transports - scheme_part = url_or_protocol.split("://", 1)[0].strip().lower() - if custom_registry.list_transports().get(scheme_part): - scheme = scheme_part - - # Check for custom transport schemes first - if scheme: - try: - return custom_registry.create_transport(scheme, url_or_protocol, **kwargs) - except TransportRegistrationError: - pass # Fall through to built-in schemes - - if scheme in ("http", "https"): - return registry.create_transport("http", url_or_protocol, **kwargs) - if scheme == "sse": - # Convert sse://host/path to http://host/path (preserve params/query/fragment) - http_url = urlunparse( - ( - "http", - parsed.netloc, - parsed.path, - parsed.params, - parsed.query, - parsed.fragment, - ) - ) - return registry.create_transport("sse", http_url, **kwargs) - if scheme == "stdio": - # Allow stdio:cmd or stdio://cmd; default empty if none - has_parts = parsed.netloc or parsed.path - cmd_source = (parsed.netloc + parsed.path) if has_parts else "" - cmd = cmd_source.lstrip("/") - return registry.create_transport("stdio", cmd, **kwargs) - if scheme == "streamablehttp": - http_url = urlunparse( - ( - "http", - parsed.netloc, - parsed.path, - parsed.params, - parsed.query, - parsed.fragment, - ) - ) - return registry.create_transport("streamablehttp", http_url, **kwargs) - - raise TransportRegistrationError( - f"Unsupported URL scheme: {scheme or 'none'}. " - f"Supported: {', '.join(registry.list_transports().keys())}, " - f"custom: {', '.join(sorted(custom_registry.list_transports().keys()))}" - ) diff --git a/mcp_fuzzer/transport/interfaces/__init__.py b/mcp_fuzzer/transport/interfaces/__init__.py new file mode 100644 index 0000000..935d60b --- /dev/null +++ b/mcp_fuzzer/transport/interfaces/__init__.py @@ -0,0 +1,28 @@ +"""Core driver interfaces, states, shared behaviors, and RPC adapters.""" + +from .driver import TransportDriver +from .states import DriverState, ParsedEndpoint +from .behaviors import ( + DriverBaseBehavior, + HttpClientBehavior, + ResponseParserBehavior, + LifecycleBehavior, + TransportError, + NetworkError, + PayloadValidationError, +) +from .rpc_adapter import JsonRpcAdapter + +__all__ = [ + "TransportDriver", + "DriverState", + "ParsedEndpoint", + "DriverBaseBehavior", + "HttpClientBehavior", + "ResponseParserBehavior", + "LifecycleBehavior", + "TransportError", + "NetworkError", + "PayloadValidationError", + "JsonRpcAdapter", +] diff --git a/mcp_fuzzer/transport/mixins.py b/mcp_fuzzer/transport/interfaces/behaviors.py similarity index 70% rename from mcp_fuzzer/transport/mixins.py rename to mcp_fuzzer/transport/interfaces/behaviors.py index 076945c..0754555 100644 --- a/mcp_fuzzer/transport/mixins.py +++ b/mcp_fuzzer/transport/interfaces/behaviors.py @@ -7,6 +7,7 @@ import json import logging +import time from abc import ABC from typing import ( Any, @@ -22,7 +23,9 @@ except ImportError: # pragma: no cover from typing_extensions import NotRequired -from ..safety_system.policy import is_host_allowed, sanitize_headers +from ...safety_system.policy import is_host_allowed, sanitize_headers +from .states import DriverState + class JSONRPCRequest(TypedDict): """Type definition for JSON-RPC request structure.""" @@ -32,6 +35,7 @@ class JSONRPCRequest(TypedDict): params: NotRequired[list[Any] | dict[str, Any]] id: str | int | None + class JSONRPCNotification(TypedDict): """Type definition for JSON-RPC notification structure.""" @@ -39,6 +43,7 @@ class JSONRPCNotification(TypedDict): method: str params: NotRequired[list[Any] | dict[str, Any]] + class JSONRPCErrorObject(TypedDict): """Type definition for JSON-RPC error object.""" @@ -46,6 +51,7 @@ class JSONRPCErrorObject(TypedDict): message: str data: NotRequired[Any] + class JSONRPCSuccessResponse(TypedDict): """Type definition for JSON-RPC success response.""" @@ -53,6 +59,7 @@ class JSONRPCSuccessResponse(TypedDict): result: Any id: str | int | None + class JSONRPCErrorResponse(TypedDict): """Type definition for JSON-RPC error response.""" @@ -60,23 +67,28 @@ class JSONRPCErrorResponse(TypedDict): error: JSONRPCErrorObject id: str | int | None + JSONRPCResponse = JSONRPCSuccessResponse | JSONRPCErrorResponse + class TransportError(Exception): - """Base exception for transport-related errors.""" + """Base exception for transport driver errors.""" pass + class NetworkError(TransportError): """Exception raised for network-related errors.""" pass + class PayloadValidationError(TransportError): """Exception raised for invalid payload validation.""" pass + class ResponseParser(Protocol): """Protocol for response parsing functionality.""" @@ -88,8 +100,9 @@ def parse_sse_response(self, response_text: str) -> dict[str, Any | None]: """Parse SSE response and extract JSON data.""" ... -class BaseTransportMixin(ABC): - """Base mixin providing common transport functionality.""" + +class DriverBaseBehavior(ABC): + """Base behavior providing common transport functionality.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -179,27 +192,36 @@ def _validate_jsonrpc_payload( raise PayloadValidationError("'method' must be a non-empty string") if "params" in payload and not isinstance(payload["params"], (list, dict)): raise PayloadValidationError("'params' must be array or object") - if "id" in payload and not isinstance(payload["id"], (str, int)) \ - and payload["id"] is not None: + if ( + "id" in payload + and not isinstance(payload["id"], (str, int)) + and payload["id"] is not None + ): raise PayloadValidationError("'id' must be string, number, or null") if strict and "id" not in payload: # In strict mode treat request-like payloads without id as invalid raise PayloadValidationError("Missing required field: id") else: if has_result == has_error: - raise PayloadValidationError("Response must have exactly one of \ -result or error") + raise PayloadValidationError( + "Response must have exactly one of \ +result or error" + ) if "id" not in payload: raise PayloadValidationError("Response must include 'id'") if not isinstance(payload["id"], (str, int)) and payload["id"] is not None: raise PayloadValidationError("'id' must be string, number, or null") if has_error: err = payload["error"] - if not isinstance(err, dict) or "code" not in err \ - or "message" not in err: + if ( + not isinstance(err, dict) + or "code" not in err + or "message" not in err + ): raise PayloadValidationError("Invalid error object") - if not isinstance(err["code"], int) \ - or not isinstance(err["message"], str): + if not isinstance(err["code"], int) or not isinstance( + err["message"], str + ): raise PayloadValidationError("Invalid error fields") def _validate_payload_serializable(self, payload: dict[str, Any]) -> None: @@ -260,8 +282,9 @@ def _extract_result_from_response( return data -class NetworkTransportMixin(BaseTransportMixin): - """Mixin for network-based transports (HTTP, SSE, WebSocket).""" + +class HttpClientBehavior(DriverBaseBehavior): + """Behavior mix-in for network-based transports (HTTP, SSE, WebSocket).""" def _validate_network_request(self, url: str) -> None: """Validate network request against safety policies. @@ -364,8 +387,9 @@ def _parse_http_response_json( raise TransportError("No valid JSON data found in response") -class ResponseParsingMixin(BaseTransportMixin): - """Mixin providing shared response parsing functionality.""" + +class ResponseParserBehavior(DriverBaseBehavior): + """Behavior providing shared response parsing functionality.""" def parse_sse_event(self, event_text: str) -> dict[str, Any | None]: """Parse a single SSE event text into a JSON object. @@ -453,3 +477,135 @@ def parse_streaming_response( logging.getLogger(self.__class__.__name__).error( "Failed to parse SSE event payload as JSON" ) + + +class LifecycleBehavior(DriverBaseBehavior): + """Behavior providing connection lifecycle management and activity tracking. + + This mixin adds connection state management, activity tracking for timeouts, + and resource cleanup patterns that can be used by all transport types. + """ + + def __init__(self, *args, **kwargs): + """Initialize connection lifecycle management.""" + super().__init__(*args, **kwargs) + self._connection_state = DriverState.INIT + self._last_activity = time.time() + self._connection_start_time: float | None = None + self._connection_end_time: float | None = None + self._activity_callbacks: list = [] + + @property + def connection_state(self) -> DriverState: + """Get current connection state.""" + return self._connection_state + + @property + def last_activity(self) -> float: + """Get timestamp of last activity.""" + return self._last_activity + + @property + def connection_duration(self) -> float | None: + """Get duration of connection in seconds, or None if not connected.""" + if self._connection_start_time is None: + return None + end_time = self._connection_end_time or time.time() + return end_time - self._connection_start_time + + def is_connected(self) -> bool: + """Check if transport is in connected state.""" + return self._connection_state == DriverState.CONNECTED + + def is_closed(self) -> bool: + """Check if transport is in closed state.""" + return self._connection_state == DriverState.CLOSED + + def is_error(self) -> bool: + """Check if transport is in error state.""" + return self._connection_state == DriverState.ERROR + + def _update_activity(self) -> None: + """Update the last activity timestamp.""" + self._last_activity = time.time() + + # Notify activity callbacks + for callback in self._activity_callbacks: + try: + callback(self._last_activity) + except Exception as e: + self._logger.warning(f"Activity callback failed: {e}") + + def _set_connection_state(self, state: DriverState) -> None: + """Set the connection state. + + Args: + state: New connection state + """ + old_state = self._connection_state + self._connection_state = state + + # Track connection timing + if state == DriverState.CONNECTED and old_state != DriverState.CONNECTED: + self._connection_start_time = time.time() + self._connection_end_time = None + elif ( + state in (DriverState.CLOSED, DriverState.ERROR) + and old_state == DriverState.CONNECTED + ): + self._connection_end_time = time.time() + + self._logger.debug( + f"Connection state changed: {old_state.value} -> {state.value}" + ) + + def register_activity_callback(self, callback) -> None: + """Register a callback to be notified on activity updates. + + Args: + callback: Callable that takes timestamp as argument + """ + self._activity_callbacks.append(callback) + + def time_since_last_activity(self) -> float: + """Get time in seconds since last activity.""" + return time.time() - self._last_activity + + async def _lifecycle_connect(self) -> None: + """Mark connection as starting and update state.""" + self._set_connection_state(DriverState.CONNECTING) + self._update_activity() + + async def _lifecycle_connected(self) -> None: + """Mark connection as established and update state.""" + self._set_connection_state(DriverState.CONNECTED) + self._update_activity() + + async def _lifecycle_disconnect(self) -> None: + """Mark connection as disconnecting and update state.""" + if self._connection_state == DriverState.CONNECTED: + self._set_connection_state(DriverState.DISCONNECTING) + self._update_activity() + + async def _lifecycle_closed(self) -> None: + """Mark connection as closed and update state.""" + self._set_connection_state(DriverState.CLOSED) + self._update_activity() + + async def _lifecycle_error(self, error: Exception | None = None) -> None: + """Mark connection as in error state. + + Args: + error: Optional exception that caused the error + """ + self._set_connection_state(DriverState.ERROR) + if error: + self._logger.error(f"Connection error: {error}") + self._update_activity() + + async def _cleanup_resources(self) -> None: + """Cleanup any resources held by the transport. + + Subclasses should override this to implement specific cleanup logic. + """ + pass diff --git a/mcp_fuzzer/transport/interfaces/driver.py b/mcp_fuzzer/transport/interfaces/driver.py new file mode 100644 index 0000000..a85a3ee --- /dev/null +++ b/mcp_fuzzer/transport/interfaces/driver.py @@ -0,0 +1,116 @@ +"""Core driver interface for all transport implementations. + +The TransportDriver describes the contract that every concrete driver +must follow. JSON-RPC specific helpers live in the rpc_adapter module. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, AsyncIterator + + +class TransportDriver(ABC): + """Abstract base class for transport drivers. + + This interface defines the core methods that all transports must implement + for sending requests, notifications, and streaming data. JSON-RPC specific + operations (get_tools, call_tool, etc.) have been moved to JsonRpcAdapter. + """ + + @abstractmethod + async def send_request( + self, method: str, params: dict[str, Any] | None = None + ) -> Any: + """Send a JSON-RPC request and return the response. + + Args: + method: The method name to call + params: Optional parameters for the method + + Returns: + Response data from the server + + Raises: + TransportError: If the request fails + """ + pass + + @abstractmethod + async def send_raw(self, payload: dict[str, Any]) -> Any: + """Send a raw payload and return the response. + + Args: + payload: Raw payload to send (should be JSON-RPC compatible) + + Returns: + Response data from the server + + Raises: + TransportError: If the request fails + """ + pass + + @abstractmethod + async def send_notification( + self, method: str, params: dict[str, Any] | None = None + ) -> None: + """Send a JSON-RPC notification (fire-and-forget). + + Args: + method: The method name to call + params: Optional parameters for the method + + Raises: + TransportError: If the notification fails to send + """ + pass + + async def connect(self) -> None: + """Connect to the transport. + + Default implementation does nothing. Transports that require + explicit connection setup should override this method. + """ + pass + + async def disconnect(self) -> None: + """Disconnect from the transport. + + Default implementation does nothing. Transports that require + explicit connection teardown should override this method. + """ + pass + + async def stream_request( + self, payload: dict[str, Any] + ) -> AsyncIterator[dict[str, Any]]: + """Stream a request to the transport. + + This is the public interface for streaming. The actual implementation + is delegated to _stream_request which subclasses must implement. + + Args: + payload: The request payload + + Yields: + Response chunks from the transport + """ + async for response in self._stream_request(payload): + yield response + + @abstractmethod + async def _stream_request( + self, payload: dict[str, Any] + ) -> AsyncIterator[dict[str, Any]]: + """Stream a request to the transport (implementation). + + Subclasses must implement this method to provide streaming support. + + Args: + payload: The request payload + + Yields: + Response chunks from the transport + """ + pass diff --git a/mcp_fuzzer/transport/base.py b/mcp_fuzzer/transport/interfaces/rpc_adapter.py similarity index 53% rename from mcp_fuzzer/transport/base.py rename to mcp_fuzzer/transport/interfaces/rpc_adapter.py index ce7e1c8..00f5727 100644 --- a/mcp_fuzzer/transport/base.py +++ b/mcp_fuzzer/transport/interfaces/rpc_adapter.py @@ -1,113 +1,132 @@ +"""JSON-RPC helper utilities for transport layer. + +This module provides shared JSON-RPC functionality that was previously +embedded in the TransportDriver base class. The JsonRpcAdapter can be +composed into transports or used standalone. +""" + from __future__ import annotations -from abc import ABC, abstractmethod import logging -from typing import Any, AsyncIterator +from typing import Any, TYPE_CHECKING +if TYPE_CHECKING: + from .driver import TransportDriver -class TransportProtocol(ABC): - @abstractmethod - async def send_request( - self, method: str, params: dict[str, Any] | None = None - ) -> Any: - pass - @abstractmethod - async def send_raw(self, payload: dict[str, Any]) -> Any: - pass +class JsonRpcAdapter: + """Helper class providing JSON-RPC operations for transports. - @abstractmethod - async def send_notification( - self, method: str, params: dict[str, Any] | None = None - ) -> None: - pass + This class can be composed into transport implementations or used + standalone to perform common JSON-RPC operations like fetching tools, + calling tools, and handling batch requests. + """ - async def connect(self) -> None: - """Connect to the transport. Default implementation does nothing.""" - pass + def __init__(self, transport: TransportDriver | None = None): + """Initialize the JSON-RPC helper. - async def disconnect(self) -> None: - """Disconnect from the transport. Default implementation does nothing.""" - pass + Args: + transport: Optional transport to use for requests. Can be set later. + """ + self._transport = transport + self._logger = logging.getLogger(__name__) - async def stream_request( - self, payload: dict[str, Any] - ) -> AsyncIterator[dict[str, Any]]: - """Stream a request to the transport. + def set_transport(self, transport: TransportDriver) -> None: + """Set or update the transport used for requests. Args: - payload: The request payload - - Yields: - Response chunks from the transport + transport: Transport instance to use """ - async for response in self._stream_request(payload): - yield response - - @abstractmethod - async def _stream_request( - self, payload: dict[str, Any] - ) -> AsyncIterator[dict[str, Any]]: - """Subclasses must implement streaming of requests.""" - pass + self._transport = transport async def get_tools(self) -> list[dict[str, Any]]: + """Fetch the list of available tools from the server. + + Returns: + List of tool definitions from the server + + Raises: + RuntimeError: If no transport is set + """ + if not self._transport: + raise RuntimeError("No transport set for JsonRpcAdapter") + try: - response = await self.send_request("tools/list") - logging.debug("Raw server response: %s", response) + response = await self._transport.send_request("tools/list") + self._logger.debug("Raw server response: %s", response) + if not isinstance(response, dict): - logging.warning( + self._logger.warning( "Server response is not a dictionary. Got type: %s", type(response), ) return [] + if "tools" not in response: - logging.warning( + self._logger.warning( "Server response missing 'tools' key. Keys present: %s", list(response.keys()), ) return [] + tools = response["tools"] - logging.info("Found %d tools from server", len(tools)) + self._logger.info("Found %d tools from server", len(tools)) return tools except Exception as e: - logging.exception("Failed to fetch tools from server: %s", e) + self._logger.exception("Failed to fetch tools from server: %s", e) return [] async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any: - """ - Call a tool on the server with the given arguments. + """Call a tool on the server with the given arguments. Note: Safety checks and sanitization are handled at the client layer, NOT in the transport. This keeps the transport layer focused on communication concerns only. + + Args: + tool_name: Name of the tool to call + arguments: Arguments to pass to the tool + + Returns: + Tool execution result from the server + + Raises: + RuntimeError: If no transport is set """ + if not self._transport: + raise RuntimeError("No transport set for JsonRpcAdapter") + params = {"name": tool_name, "arguments": arguments} - return await self.send_request("tools/call", params) + return await self._transport.send_request("tools/call", params) async def send_batch_request( self, batch: list[dict[str, Any]] ) -> list[dict[str, Any]]: - """ - Send a batch of JSON-RPC requests/notifications. + """Send a batch of JSON-RPC requests/notifications. Args: batch: List of JSON-RPC requests/notifications Returns: List of responses (may be out of order or incomplete) + + Raises: + RuntimeError: If no transport is set """ + if not self._transport: + raise RuntimeError("No transport set for JsonRpcAdapter") + # Default implementation sends each request individually - # Subclasses can override for true batch support + # Transports can override for true batch support responses = [] for request in batch: try: if "id" not in request or request["id"] is None: # Notification - no response expected - await self.send_raw(request) + await self._transport.send_raw(request) else: # Request - response expected - response = await self.send_raw(request) + response = await self._transport.send_raw(request) # Normalize to dict if not isinstance(response, dict): response = {"result": response} @@ -117,7 +136,7 @@ async def send_batch_request( response["id"] = req_id responses.append(response) except Exception as e: - logging.warning(f"Failed to send batch request: {e}") + self._logger.warning(f"Failed to send batch request: {e}") responses.append({"error": str(e), "id": request.get("id")}) return responses @@ -125,8 +144,7 @@ async def send_batch_request( def collate_batch_responses( self, requests: list[dict[str, Any]], responses: list[dict[str, Any]] ) -> dict[Any, dict[str, Any]]: - """ - Collate batch responses by ID, handling out-of-order and missing responses. + """Collate batch responses by ID, handling out-of-order and missing responses. Args: requests: Original batch requests @@ -149,7 +167,9 @@ def collate_batch_responses( collated[response_id] = response else: # Unmatched response - could be error or notification response - logging.warning(f"Received response with unmatched ID: {response_id}") + self._logger.warning( + f"Received response with unmatched ID: {response_id}" + ) # Check for missing responses for req_id, request in expected_responses.items(): diff --git a/mcp_fuzzer/transport/interfaces/states.py b/mcp_fuzzer/transport/interfaces/states.py new file mode 100644 index 0000000..865c903 --- /dev/null +++ b/mcp_fuzzer/transport/interfaces/states.py @@ -0,0 +1,58 @@ +"""Transport type definitions and enums. + +This module defines shared types, enums, and data structures used across +the transport layer. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from enum import Enum + + +class DriverState(Enum): + """Lifecycle states for transport drivers.""" + + INIT = "init" + CONNECTING = "connecting" + CONNECTED = "connected" + DISCONNECTING = "disconnecting" + CLOSED = "closed" + ERROR = "error" + + +@dataclass +class ParsedEndpoint: + """Structured result of endpoint resolution.""" + + scheme: str + """URL scheme (e.g., 'http', 'https', 'sse', 'stdio', 'streamablehttp')""" + + endpoint: str + """Endpoint/URL after scheme processing""" + + is_custom: bool = False + """Whether this is a custom transport scheme""" + + original_url: str = "" + """Original URL before parsing""" + + netloc: str = "" + """Network location from URL parsing""" + + path: str = "" + """Path component from URL parsing""" + + params: str = "" + """Parameters from URL parsing""" + + query: str = "" + """Query string from URL parsing""" + + fragment: str = "" + """Fragment from URL parsing""" + + def __post_init__(self): + """Ensure endpoint is set if not provided.""" + if not self.endpoint and self.original_url: + self.endpoint = self.original_url diff --git a/tests/integration/test_client_transport.py b/tests/integration/test_client_transport.py index 8d01094..d03fea8 100644 --- a/tests/integration/test_client_transport.py +++ b/tests/integration/test_client_transport.py @@ -7,7 +7,7 @@ import httpx import pytest from mcp_fuzzer.client.protocol_client import ProtocolClient -from mcp_fuzzer.transport.streamable_http import StreamableHTTPTransport +from mcp_fuzzer.transport.drivers.stream_http_driver import StreamHttpDriver pytestmark = [pytest.mark.integration, pytest.mark.client, pytest.mark.transport] @@ -16,7 +16,7 @@ def client_setup(): """Fixture for client and transport setup.""" base_url = "http://localhost:8000" - transport = StreamableHTTPTransport(base_url) + transport = StreamHttpDriver(base_url) # Skip initialize handshake in tests to avoid mocking extra POSTs try: transport._initialized = True @@ -31,6 +31,6 @@ async def test_client_transport_integration(client_setup): """Test client and transport integration.""" # This is a basic test to verify the client and transport can be instantiated assert isinstance(client_setup["client"], ProtocolClient) - assert isinstance(client_setup["transport"], StreamableHTTPTransport) + assert isinstance(client_setup["transport"], StreamHttpDriver) # Verify the transport was created with the correct URL assert client_setup["transport"].url == "http://localhost:8000" diff --git a/tests/integration/test_custom_transport.py b/tests/integration/test_custom_transport.py index c1af8e7..9606f2b 100644 --- a/tests/integration/test_custom_transport.py +++ b/tests/integration/test_custom_transport.py @@ -7,12 +7,13 @@ from mcp_fuzzer.config import apply_config_file, load_custom_transports from mcp_fuzzer.exceptions import ConfigFileError, TransportRegistrationError -from mcp_fuzzer.transport import create_transport, register_custom_transport -from mcp_fuzzer.transport.base import TransportProtocol +from mcp_fuzzer.transport import build_driver, register_custom_driver +from mcp_fuzzer.transport.interfaces import TransportDriver, JsonRpcAdapter +from mcp_fuzzer.transport.catalog.custom_catalog import custom_driver_catalog from typing import Any, Dict, Optional, AsyncIterator -class IntegrationTestTransport(TransportProtocol): +class IntegrationTestTransport(TransportDriver): """Test transport for integration testing.""" def __init__(self, endpoint: str, **kwargs): @@ -72,10 +73,8 @@ class TestCustomTransportConfiguration: def setup_method(self): """Clear any existing custom transports.""" - from mcp_fuzzer.transport.custom import registry - - for name in list(registry.list_transports().keys()): - registry.unregister(name) + for name in list(custom_driver_catalog.list_transports().keys()): + custom_driver_catalog.unregister(name) def test_config_file_custom_transport_loading(self): """Test loading custom transports from configuration file.""" @@ -100,13 +99,13 @@ def test_config_file_custom_transport_loading(self): assert apply_config_file(config_path=config_path) is True # Test that transport was loaded - from mcp_fuzzer.transport import list_custom_transports + from mcp_fuzzer.transport import list_custom_drivers - transports = list_custom_transports() + transports = list_custom_drivers() assert "integration_test" in transports # Test creating transport instance - transport = create_transport("integration_test://test-endpoint") + transport = build_driver("integration_test://test-endpoint") assert isinstance(transport, IntegrationTestTransport) assert transport.endpoint == "test-endpoint" @@ -134,22 +133,20 @@ class TestCustomTransportLifecycle: def setup_method(self): """Clear any existing custom transports.""" - from mcp_fuzzer.transport.custom import registry - - for name in list(registry.list_transports().keys()): - registry.unregister(name) + for name in list(custom_driver_catalog.list_transports().keys()): + custom_driver_catalog.unregister(name) async def test_full_transport_lifecycle(self): """Test complete transport lifecycle from registration to usage.""" # Register custom transport - register_custom_transport( + register_custom_driver( name="lifecycle_test", transport_class=IntegrationTestTransport, description="Lifecycle test transport", ) # Create transport instance - transport = create_transport("lifecycle_test://test-server") + transport = build_driver("lifecycle_test://test-server") # Test connection await transport.connect() @@ -177,8 +174,8 @@ async def test_full_transport_lifecycle(self): assert response["result"]["payload"] == {"stream": "test"} break # Only test first response - # Test tools listing (inherited method) - # Mock the send_request for tools/list + # Test tools listing through JsonRpcAdapter + rpc_helper = JsonRpcAdapter(transport) original_send_request = transport.send_request async def mock_tools_request(method, params=None): @@ -188,7 +185,7 @@ async def mock_tools_request(method, params=None): transport.send_request = mock_tools_request try: - tools = await transport.get_tools() + tools = await rpc_helper.get_tools() assert tools == [{"name": "integration_tool"}] finally: transport.send_request = original_send_request @@ -203,10 +200,8 @@ class TestCustomTransportErrorHandling: def setup_method(self): """Clear any existing custom transports.""" - from mcp_fuzzer.transport.custom import registry - - for name in list(registry.list_transports().keys()): - registry.unregister(name) + for name in list(custom_driver_catalog.list_transports().keys()): + custom_driver_catalog.unregister(name) def test_invalid_registration(self): """Test error handling for invalid transport registration.""" @@ -215,27 +210,27 @@ class InvalidTransport: pass with pytest.raises( - TransportRegistrationError, match="must inherit from TransportProtocol" + TransportRegistrationError, match="must inherit from TransportDriver" ): - register_custom_transport(name="invalid", transport_class=InvalidTransport) + register_custom_driver(name="invalid", transport_class=InvalidTransport) def test_duplicate_registration(self): """Test error handling for duplicate transport registration.""" - register_custom_transport( + register_custom_driver( name="duplicate_test", transport_class=IntegrationTestTransport ) with pytest.raises(TransportRegistrationError, match="already registered"): - register_custom_transport( + register_custom_driver( name="duplicate_test", transport_class=IntegrationTestTransport ) def test_unknown_transport_creation(self): """Test error handling for unknown transport creation.""" with pytest.raises( - TransportRegistrationError, match="Unsupported URL scheme" + TransportRegistrationError, match="Unsupported transport scheme" ): - create_transport("unknown_transport://endpoint") + build_driver("unknown_transport://endpoint") class TestCustomTransportWithClient: @@ -243,20 +238,18 @@ class TestCustomTransportWithClient: def setup_method(self): """Clear any existing custom transports.""" - from mcp_fuzzer.transport.custom import registry - - for name in list(registry.list_transports().keys()): - registry.unregister(name) + for name in list(custom_driver_catalog.list_transports().keys()): + custom_driver_catalog.unregister(name) async def test_transport_with_mcp_client(self): """Test using custom transport with MCP client.""" # Register custom transport - register_custom_transport( + register_custom_driver( name="client_test", transport_class=IntegrationTestTransport ) # Create transport - transport = create_transport("client_test://mcp-server") + transport = build_driver("client_test://mcp-server") # Import and create MCP client (this would normally be done) # This is a simplified test - in real usage, you'd use the full client @@ -269,6 +262,8 @@ async def test_transport_with_mcp_client(self): # Test basic functionality await transport.connect() - # Test tool calling through transport - result = await transport.call_tool("test_tool", {"arg": "value"}) + rpc_helper = JsonRpcAdapter(transport) + + # Test tool calling through transport via RPC helper + result = await rpc_helper.call_tool("test_tool", {"arg": "value"}) assert "result" in result diff --git a/tests/integration/test_standardized_output.py b/tests/integration/test_standardized_output.py index 5881e61..833a1a8 100644 --- a/tests/integration/test_standardized_output.py +++ b/tests/integration/test_standardized_output.py @@ -11,7 +11,7 @@ from mcp_fuzzer.client import MCPFuzzerClient from mcp_fuzzer.reports import FuzzerReporter -from mcp_fuzzer.transport import create_transport +from mcp_fuzzer.transport import build_driver class TestStandardizedOutputIntegration: @@ -24,6 +24,7 @@ def setup_method(self): def teardown_method(self): """Clean up test fixtures.""" import shutil + shutil.rmtree(self.temp_dir) @pytest.mark.asyncio @@ -42,18 +43,13 @@ async def test_full_fuzzing_workflow_with_standardized_output(self): { "success": False, "exception": "ValueError: invalid input", - "args": {"param": "value2"} + "args": {"param": "value2"}, }, - {"success": True, "args": {"param": "value3"}} + {"success": True, "args": {"param": "value3"}}, ] } - protocol_results = { - "InitializeRequest": [ - {"success": True}, - {"success": True} - ] - } + protocol_results = {"InitializeRequest": [{"success": True}, {"success": True}]} # Add results to reporter reporter.add_tool_results("test_tool", tool_results["test_tool"]) @@ -67,7 +63,7 @@ async def test_full_fuzzing_workflow_with_standardized_output(self): protocol="http", endpoint="http://test-server:8000", runs=3, - runs_per_type=2 + runs_per_type=2, ) # Generate standardized reports @@ -126,14 +122,16 @@ def test_configuration_driven_output(self): from mcp_fuzzer.config import config # Set output configuration - config.update({ - "output": { - "format": "json", - "directory": self.temp_dir, - "compress": False, - "types": ["fuzzing_results", "error_report"] + config.update( + { + "output": { + "format": "json", + "directory": self.temp_dir, + "compress": False, + "types": ["fuzzing_results", "error_report"], + } } - }) + ) # Create reporter (should pick up config) reporter = FuzzerReporter() @@ -156,21 +154,21 @@ def test_error_report_generation(self): "tool_name": "dangerous_tool", "severity": "high", "message": "Command injection detected", - "arguments": {"cmd": "rm -rf /"} + "arguments": {"cmd": "rm -rf /"}, }, { "type": "protocol_error", "protocol_type": "InitializeRequest", "severity": "medium", "message": "Invalid JSON in request", - "details": {"field": "jsonrpc", "expected": "2.0", "got": "1.0"} + "details": {"field": "jsonrpc", "expected": "2.0", "got": "1.0"}, }, { "type": "system_error", "severity": "low", "message": "Network timeout", - "context": {"endpoint": "http://test.com", "timeout": 30} - } + "context": {"endpoint": "http://test.com", "timeout": 30}, + }, ] filepath = manager.save_error_report(errors=errors) @@ -199,14 +197,14 @@ def test_safety_summary_generation(self): "tool_name": "file_operations", "reason": "File system access blocked", "arguments": {"path": "/etc/passwd"}, - "timestamp": "2024-01-01T10:00:00Z" + "timestamp": "2024-01-01T10:00:00Z", }, { "tool_name": "network_tools", "reason": "Network access blocked", "arguments": {"url": "http://malicious.com"}, - "timestamp": "2024-01-01T10:01:00Z" - } + "timestamp": "2024-01-01T10:01:00Z", + }, ] safety_data = { @@ -214,10 +212,10 @@ def test_safety_summary_generation(self): "statistics": { "total_operations_blocked": 3, "unique_tools_blocked": 2, - "risk_assessment": "medium" + "risk_assessment": "medium", }, "blocked_operations": blocked_operations, - "risk_assessment": "medium" + "risk_assessment": "medium", } filepath = manager.save_safety_summary(safety_data) @@ -242,15 +240,15 @@ async def test_multiple_output_types_generation(self): reporter = FuzzerReporter(output_dir=self.temp_dir) # Add some mock data with errors to ensure error report is generated - reporter.add_tool_results("test_tool", [ - {"success": True}, - {"success": False, "exception": "ValueError: test error"} - ]) + reporter.add_tool_results( + "test_tool", + [ + {"success": True}, + {"success": False, "exception": "ValueError: test error"}, + ], + ) reporter.set_fuzzing_metadata( - mode="tools", - protocol="http", - endpoint="http://test.com", - runs=2 + mode="tools", protocol="http", endpoint="http://test.com", runs=2 ) # Generate multiple output types @@ -291,7 +289,7 @@ def test_output_file_naming_convention(self): # Verify filename format: timestamp_output_type.json filename = Path(filepath).name - pattern = r'^\d{8}_\d{6}_fuzzing_results\.json$' + pattern = r"^\d{8}_\d{6}_fuzzing_results\.json$" assert re.match(pattern, filename) def test_session_isolation(self): diff --git a/tests/unit/cli/test_cli.py b/tests/unit/cli/test_cli.py index d255a14..a49175e 100644 --- a/tests/unit/cli/test_cli.py +++ b/tests/unit/cli/test_cli.py @@ -22,7 +22,7 @@ from mcp_fuzzer.client.runtime.async_runner import execute_inner_client from mcp_fuzzer.client.runtime.retry import run_with_retry_on_interrupt from mcp_fuzzer.client.safety import SafetyController -from mcp_fuzzer.client.transport.factory import create_transport_with_auth +from mcp_fuzzer.client.transport.factory import build_driver_with_auth from mcp_fuzzer.exceptions import ArgumentValidationError, MCPError @@ -140,6 +140,7 @@ def test_validate_arguments_errors(): with pytest.raises(ArgumentValidationError): validator.validate_arguments(args) + def test_validate_arguments_runs_per_type_invalid_type(): validator = ValidationManager() args = argparse.Namespace( @@ -155,6 +156,7 @@ def test_validate_arguments_runs_per_type_invalid_type(): with pytest.raises(ArgumentValidationError): validator.validate_arguments(args) + def test_validate_arguments_timeout_negative(): validator = ValidationManager() args = argparse.Namespace( @@ -218,6 +220,7 @@ def test_validate_arguments_allows_utility_without_endpoint(): ) validator.validate_arguments(args) + def test_validate_arguments_protocol_type_wrong_mode_with_endpoint(): validator = ValidationManager() args = argparse.Namespace( @@ -233,6 +236,7 @@ def test_validate_arguments_protocol_type_wrong_mode_with_endpoint(): with pytest.raises(ArgumentValidationError): validator.validate_arguments(args) + def test_validate_arguments_runs_not_int(): validator = ValidationManager() args = argparse.Namespace( @@ -267,9 +271,7 @@ def test_build_cli_config_merges_and_returns_cli_config(): def test_handle_validate_config(monkeypatch): validator = ValidationManager() - with patch( - "mcp_fuzzer.cli.validators.load_config_file" - ) as mock_load: + with patch("mcp_fuzzer.cli.validators.load_config_file") as mock_load: validator.validate_config_file("config.yml") mock_load.assert_called_once_with("config.yml") @@ -308,10 +310,8 @@ def test_transport_factory_applies_auth_headers(): args = MagicMock(protocol="http", endpoint="http://example.com", timeout=10.0) auth_manager = MagicMock() auth_manager.get_default_auth_headers.return_value = {"Authorization": "x"} - with patch( - "mcp_fuzzer.client.transport.factory.base_create_transport" - ) as mock_create: - create_transport_with_auth(args, {"auth_manager": auth_manager}) + with patch("mcp_fuzzer.client.transport.factory.base_build_driver") as mock_create: + build_driver_with_auth(args, {"auth_manager": auth_manager}) mock_create.assert_called_once_with( "http", "http://example.com", @@ -334,8 +334,10 @@ def test_safety_controller(): def test_execute_inner_client_pytest_branch(monkeypatch): monkeypatch.setenv("PYTEST_CURRENT_TEST", "1") + async def dummy_main(): return None + with patch("mcp_fuzzer.client.runtime.async_runner.asyncio.run") as mock_run: execute_inner_client(argparse.Namespace(), dummy_main, ["prog"]) mock_run.assert_called_once() @@ -395,7 +397,7 @@ def test_validate_transport_errors(): validator = ValidationManager() args = argparse.Namespace(protocol="http", endpoint="http://x", timeout=1) with patch( - "mcp_fuzzer.cli.validators.create_transport", + "mcp_fuzzer.cli.validators.build_driver", side_effect=Exception("boom"), ): with pytest.raises(Exception): @@ -406,7 +408,7 @@ def test_validate_transport_mcp_error_passthrough(): validator = ValidationManager() args = argparse.Namespace(protocol="http", endpoint="http://x", timeout=1) with patch( - "mcp_fuzzer.cli.validators.create_transport", + "mcp_fuzzer.cli.validators.build_driver", side_effect=MCPError("err", code="X"), ): with pytest.raises(MCPError): @@ -557,6 +559,7 @@ def test_build_cli_config_uses_config_file(monkeypatch): assert cli_config.merged["runs"] == 42 assert cli_config.merged["allow_hosts"] == ["a.local"] + def test_build_cli_config_handles_apply_config_error(caplog): caplog.set_level(logging.DEBUG) args = _base_args(config=None) @@ -568,6 +571,7 @@ def test_build_cli_config_handles_apply_config_error(caplog): assert cli_config.merged["endpoint"] == "http://localhost" assert "fail" in "".join(caplog.messages) + def test_build_cli_config_raises_config_error(): args = _base_args(config="bad.yml") with patch( diff --git a/tests/unit/client/test_client_main.py b/tests/unit/client/test_client_main.py index cfe9556..fc7f2b8 100644 --- a/tests/unit/client/test_client_main.py +++ b/tests/unit/client/test_client_main.py @@ -54,7 +54,7 @@ def test_unified_client_main_tools_mode(): with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=mock_transport, ) as mock_transport_factory, patch("mcp_fuzzer.client.main.SafetyFilter", return_value=mock_safety), @@ -76,7 +76,7 @@ def test_unified_client_main_unknown_mode_logs_error_and_returns_nonzero(): client_instance.cleanup = AsyncMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -98,7 +98,7 @@ def test_unified_client_main_sets_fs_root_when_provided(): with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=mock_safety), @@ -118,7 +118,7 @@ def test_unified_client_main_protocol_and_both_modes(): client_instance.cleanup = AsyncMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -136,7 +136,7 @@ def test_unified_client_main_protocol_and_both_modes(): client_instance2.cleanup = AsyncMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -163,7 +163,7 @@ def test_unified_client_main_exports_reports_and_handles_errors(): client_instance.reporter = reporter with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -174,6 +174,7 @@ def test_unified_client_main_exports_reports_and_handles_errors(): reporter.export_csv.assert_called_once() reporter.export_markdown.assert_called_once() + def test_unified_client_main_exports_html_xml(): settings = _settings( mode="tool", @@ -189,7 +190,7 @@ def test_unified_client_main_exports_html_xml(): client_instance.reporter = reporter with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -209,7 +210,7 @@ def test_unified_client_main_safety_disabled(): client_instance.cleanup = AsyncMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter") as mock_safety, @@ -231,7 +232,7 @@ def test_unified_client_main_tool_results_summary(monkeypatch): client_instance.print_tool_summary = MagicMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), @@ -250,7 +251,7 @@ def test_unified_client_main_returns_one_on_exception(): client_instance.cleanup = AsyncMock() with ( patch( - "mcp_fuzzer.client.main.create_transport_with_auth", + "mcp_fuzzer.client.main.build_driver_with_auth", return_value=MagicMock(), ), patch("mcp_fuzzer.client.main.SafetyFilter", return_value=MagicMock()), diff --git a/tests/unit/transport/test_custom.py b/tests/unit/transport/test_custom.py index 4dbbf7c..4dbe19e 100644 --- a/tests/unit/transport/test_custom.py +++ b/tests/unit/transport/test_custom.py @@ -4,18 +4,19 @@ from unittest.mock import Mock from typing import Any, Dict, Optional, AsyncIterator -from mcp_fuzzer.transport.base import TransportProtocol -from mcp_fuzzer.transport.custom import ( - CustomTransportRegistry, - register_custom_transport, - create_custom_transport, - list_custom_transports, +from mcp_fuzzer.transport.interfaces import TransportDriver, JsonRpcAdapter +from mcp_fuzzer.transport.catalog import build_driver +from mcp_fuzzer.transport.catalog.custom_catalog import ( + CustomDriverCatalog, + register_custom_driver, + build_custom_driver, + list_custom_drivers, + custom_driver_catalog, ) -from mcp_fuzzer.transport.factory import create_transport from mcp_fuzzer.exceptions import ConnectionError, TransportRegistrationError -class MockTransport(TransportProtocol): +class MockTransport(TransportDriver): """Mock transport for testing.""" def __init__(self, endpoint: str, **kwargs): @@ -41,17 +42,20 @@ async def _stream_request( yield {"result": "stream_response"} -class TestCustomTransportRegistry: +class TestCustomDriverCatalog: """Test the custom transport registry functionality.""" + def setup_method(self): + custom_driver_catalog.clear() + def test_registry_initialization(self): """Test that registry initializes correctly.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() assert registry.list_transports() == {} def test_register_transport(self): """Test registering a custom transport.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -68,7 +72,7 @@ def test_register_transport(self): def test_register_duplicate_transport(self): """Test that registering a duplicate transport raises an error.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -88,14 +92,14 @@ def test_register_duplicate_transport(self): def test_register_invalid_transport_class(self): """Test that registering an invalid transport class raises an error.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() class InvalidTransport: pass with pytest.raises( TransportRegistrationError, - match="Transport class .* must inherit from TransportProtocol", + match="Transport class .* must inherit from TransportDriver", ): registry.register( name="invalid_transport", @@ -105,7 +109,7 @@ class InvalidTransport: def test_unregister_transport(self): """Test unregistering a custom transport.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -118,7 +122,7 @@ def test_unregister_transport(self): def test_unregister_nonexistent_transport(self): """Test that unregistering a non-existent transport raises an error.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() with pytest.raises( TransportRegistrationError, @@ -128,7 +132,7 @@ def test_unregister_nonexistent_transport(self): def test_get_transport_class(self): """Test getting transport class from registry.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -141,7 +145,7 @@ def test_get_transport_class(self): def test_get_transport_info(self): """Test getting transport info from registry.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -153,9 +157,9 @@ def test_get_transport_info(self): assert info["class"] == MockTransport assert info["description"] == "Mock transport for testing" - def test_create_transport(self): + def test_build_driver(self): """Test creating transport instance from registry.""" - registry = CustomTransportRegistry() + registry = CustomDriverCatalog() registry.register( name="mock_transport", @@ -163,9 +167,7 @@ def test_create_transport(self): description="Mock transport for testing", ) - transport = registry.create_transport( - "mock_transport", "test-endpoint", timeout=30 - ) + transport = registry.build_driver("mock_transport", "test-endpoint", timeout=30) assert isinstance(transport, MockTransport) assert transport.endpoint == "test-endpoint" assert transport.kwargs == {"timeout": 30} @@ -176,30 +178,30 @@ class TestCustomTransportFunctions: def setup_method(self): """Clear the global registry before each test.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog registry.clear() - def test_register_custom_transport(self): - """Test the global register_custom_transport function.""" - register_custom_transport( + def test_register_custom_driver(self): + """Test the global register_custom_driver function.""" + register_custom_driver( name="global_mock", transport_class=MockTransport, description="Global mock transport", ) - transports = list_custom_transports() + transports = list_custom_drivers() assert "global_mock" in transports - def test_create_custom_transport(self): - """Test the global create_custom_transport function.""" - register_custom_transport( + def test_build_custom_driver(self): + """Test the global build_custom_driver function.""" + register_custom_driver( name="global_mock", transport_class=MockTransport, description="Global mock transport", ) - transport = create_custom_transport("global_mock", "test-endpoint") + transport = build_custom_driver("global_mock", "test-endpoint") assert isinstance(transport, MockTransport) assert transport.endpoint == "test-endpoint" @@ -209,39 +211,39 @@ class TestTransportFactoryIntegration: def setup_method(self): """Clear the global registry before each test.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog registry.clear() def test_custom_transport_via_factory(self): """Test creating custom transport via factory.""" - register_custom_transport( + register_custom_driver( name="factory_mock", transport_class=MockTransport, description="Factory mock transport", ) - transport = create_transport("factory_mock://test-endpoint") + transport = build_driver("factory_mock://test-endpoint") assert isinstance(transport, MockTransport) assert transport.endpoint == "test-endpoint" def test_custom_transport_via_factory_two_args(self): """Back-compat: (protocol, endpoint) for custom transports.""" - register_custom_transport( + register_custom_driver( name="factory_mock", transport_class=MockTransport, description="Factory mock transport", ) - transport = create_transport("factory_mock", "test-endpoint") + transport = build_driver("factory_mock", "test-endpoint") assert isinstance(transport, MockTransport) assert transport.endpoint == "test-endpoint" def test_unknown_custom_transport(self): """Test that unknown custom transport raises error.""" with pytest.raises( - TransportRegistrationError, match="Unsupported URL scheme: unknown" + TransportRegistrationError, match="Unsupported transport scheme" ): - create_transport("unknown://test-endpoint") + build_driver("unknown://test-endpoint") def test_custom_transport_with_config_schema(self): """Test custom transport with configuration schema.""" @@ -252,20 +254,20 @@ def test_custom_transport_with_config_schema(self): }, } - register_custom_transport( + register_custom_driver( name="schema_mock", transport_class=MockTransport, description="Schema mock transport", config_schema=schema, ) - transports = list_custom_transports() + transports = list_custom_drivers() assert "schema_mock" in transports assert transports["schema_mock"]["config_schema"] == schema -class TestTransportProtocolCompliance: - """Test that custom transports comply with TransportProtocol.""" +class TestTransportDriverCompliance: + """Test that custom transports comply with TransportDriver.""" async def test_mock_transport_compliance(self): """Test that MockTransport implements all required methods.""" @@ -300,15 +302,19 @@ async def mock_send_request(method, params=None): return await original_send_request(method, params) transport.send_request = mock_send_request + rpc_helper = JsonRpcAdapter(transport) - tools = await transport.get_tools() - assert tools == [{"name": "test_tool"}] + try: + tools = await rpc_helper.get_tools() + assert tools == [{"name": "test_tool"}] + finally: + transport.send_request = original_send_request class TestCustomTransportCloseMethod: """Test that custom transports can implement and have their close method invoked.""" - class CloseableTransport(TransportProtocol): + class CloseableTransport(TransportDriver): """Mock transport that tracks close() calls.""" def __init__(self, endpoint: str, **kwargs): @@ -373,7 +379,7 @@ async def test_custom_transport_close_method_invoked(self): async def test_custom_transport_close_with_resources(self): """Test that custom transport can clean up resources in close method.""" - class ResourceManagedTransport(TransportProtocol): + class ResourceManagedTransport(TransportDriver): """Transport that manages resources.""" def __init__(self, endpoint: str): @@ -423,15 +429,15 @@ class TestCustomTransportSelfRegistration: def setup_method(self): """Clear the global registry before each test.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog registry.clear() def test_self_registration_with_registry(self): """Test that custom transport can self-register using registry.register.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog - class SelfRegisteringTransport(TransportProtocol): + class SelfRegisteringTransport(TransportDriver): """Transport that self-registers on import.""" def __init__(self, endpoint: str, **kwargs): @@ -473,10 +479,10 @@ async def _stream_request( def test_self_registration_at_module_level(self): """Test self-registration pattern at module level.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog # Simulate module-level registration - class ModuleLevelTransport(TransportProtocol): + class ModuleLevelTransport(TransportDriver): """Transport registered at module level.""" def __init__(self, endpoint: str): @@ -509,9 +515,9 @@ async def _stream_request( def test_self_registration_with_schema(self): """Test self-registration with configuration schema.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog - class ConfigurableTransport(TransportProtocol): + class ConfigurableTransport(TransportDriver): """Transport with configuration schema.""" def __init__(self, endpoint: str, timeout: float = 30.0): @@ -563,15 +569,15 @@ class TestSelfRegisteredTransportInstantiation: def setup_method(self): """Clear the global registry before each test.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog registry.clear() async def test_instantiate_self_registered_transport(self): """Test that self-registered transport can be instantiated and used.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog - class UsableTransport(TransportProtocol): + class UsableTransport(TransportDriver): """Fully functional self-registered transport.""" def __init__(self, endpoint: str, **kwargs): @@ -610,7 +616,7 @@ async def _stream_request( registry.register("usable", UsableTransport, description="Usable transport") # Instantiate via registry - transport = registry.create_transport("usable", "test-server", timeout=60) + transport = registry.build_driver("usable", "test-server", timeout=60) # Verify instantiation assert isinstance(transport, UsableTransport) @@ -643,9 +649,9 @@ async def _stream_request( async def test_instantiate_via_factory(self): """Test that self-registered transport can be instantiated via factory.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog - class FactoryUsableTransport(TransportProtocol): + class FactoryUsableTransport(TransportDriver): """Transport usable via factory.""" def __init__(self, endpoint: str): @@ -675,7 +681,7 @@ async def _stream_request( ) # Instantiate via factory using URL format - transport = create_transport("factory_usable://my-endpoint") + transport = build_driver("factory_usable://my-endpoint") # Verify it works assert isinstance(transport, FactoryUsableTransport) @@ -687,9 +693,9 @@ async def _stream_request( async def test_instantiate_with_custom_factory(self): """Test instantiation with custom factory function.""" - from mcp_fuzzer.transport.custom import registry + registry = custom_driver_catalog - class FactoryManagedTransport(TransportProtocol): + class FactoryManagedTransport(TransportDriver): """Transport created via factory function.""" def __init__(self, endpoint: str, custom_arg: str): @@ -733,7 +739,7 @@ def custom_factory(url_or_endpoint: str, **kwargs): ) # Instantiate via registry with factory - transport = registry.create_transport( + transport = registry.build_driver( "factory_managed", "test-endpoint", custom_arg="custom_value" ) diff --git a/tests/unit/transport/test_stdio.py b/tests/unit/transport/test_stdio.py index 52af0e0..a9c0cb1 100644 --- a/tests/unit/transport/test_stdio.py +++ b/tests/unit/transport/test_stdio.py @@ -7,22 +7,22 @@ import pytest # Import the class to test -from mcp_fuzzer.transport.stdio import StdioTransport +from mcp_fuzzer.transport.drivers.stdio_driver import StdioDriver from mcp_fuzzer.fuzz_engine.runtime import ProcessManager, WatchdogConfig from mcp_fuzzer.exceptions import MCPError, ServerError, TransportError -class TestStdioTransport: +class TestStdioDriver: def setup_method(self): """Set up test fixtures.""" self.command = "test_command" self.timeout = 10.0 - self.transport = StdioTransport(self.command, self.timeout) + self.transport = StdioDriver(self.command, self.timeout) self.transport.process_manager = AsyncMock(spec=ProcessManager) self.transport._lock = AsyncMock(spec=asyncio.Lock) def test_init(self): - """Test initialization of StdioTransport.""" + """Test initialization of StdioDriver.""" assert self.transport.command == self.command assert self.transport.timeout == self.timeout assert self.transport.process is None @@ -35,8 +35,8 @@ def test_init(self): ) @pytest.mark.asyncio - @patch("mcp_fuzzer.transport.stdio.asyncio.create_subprocess_exec") - @patch("mcp_fuzzer.transport.stdio.shlex.split") + @patch("mcp_fuzzer.transport.drivers.stdio_driver.asyncio.create_subprocess_exec") + @patch("mcp_fuzzer.transport.drivers.stdio_driver.shlex.split") async def test_ensure_connection_new_process( self, mock_shlex_split, mock_create_subprocess ): @@ -68,7 +68,7 @@ async def test_ensure_connection_new_process( ) @pytest.mark.asyncio - @patch("mcp_fuzzer.transport.stdio.asyncio.create_subprocess_exec") + @patch("mcp_fuzzer.transport.drivers.stdio_driver.asyncio.create_subprocess_exec") async def test_ensure_connection_existing_process_alive( self, mock_create_subprocess ): @@ -85,7 +85,7 @@ async def test_ensure_connection_existing_process_alive( assert self.transport._initialized is True @pytest.mark.asyncio - @patch("mcp_fuzzer.transport.stdio.asyncio.create_subprocess_exec") + @patch("mcp_fuzzer.transport.drivers.stdio_driver.asyncio.create_subprocess_exec") async def test_ensure_connection_existing_process_dead( self, mock_create_subprocess ): @@ -210,7 +210,7 @@ async def test_send_request(self): with patch.object( self.transport, "_send_message", new=AsyncMock() ) as mock_send: - with patch("mcp_fuzzer.transport.stdio.uuid") as mock_uuid: + with patch("mcp_fuzzer.transport.drivers.stdio_driver.uuid") as mock_uuid: # Force the request_id to a known value mock_uuid.uuid4.return_value = "test_id" @@ -237,7 +237,7 @@ async def test_send_request_error_response(self): with patch.object( self.transport, "_send_message", new=AsyncMock() ) as mock_send: - with patch("mcp_fuzzer.transport.stdio.uuid") as mock_uuid: + with patch("mcp_fuzzer.transport.drivers.stdio_driver.uuid") as mock_uuid: # Force the request_id to a known value mock_uuid.uuid4.return_value = "test_id" @@ -257,13 +257,16 @@ async def test_send_request_error_response(self): mock_send.assert_awaited_once() mock_receive.assert_awaited_once() + @pytest.mark.asyncio async def test_send_request_no_response(self): """send_request should raise TransportError when no response arrives.""" - with patch.object( - self.transport, "_send_message", new=AsyncMock() - ), patch("mcp_fuzzer.transport.stdio.uuid") as mock_uuid, patch.object( - self.transport, "_receive_message", new=AsyncMock(return_value=None) + with ( + patch.object(self.transport, "_send_message", new=AsyncMock()), + patch("mcp_fuzzer.transport.drivers.stdio_driver.uuid") as mock_uuid, + patch.object( + self.transport, "_receive_message", new=AsyncMock(return_value=None) + ), ): mock_uuid.uuid4.return_value = "test_id" with pytest.raises(TransportError): @@ -389,8 +392,10 @@ async def test_send_timeout_signal_process_not_registered_timeout(self): self.transport.process = mock_process self.transport.process_manager.is_process_registered.return_value = False - with patch("mcp_fuzzer.transport.stdio.logging.info") as mock_log: - with patch("mcp_fuzzer.transport.stdio.os") as mock_os: + with patch( + "mcp_fuzzer.transport.drivers.stdio_driver.logging.info" + ) as mock_log: + with patch("mcp_fuzzer.transport.drivers.stdio_driver.os") as mock_os: # Mock getpgid to avoid OS errors mock_os.name = "posix" mock_os.getpgid.return_value = 123 @@ -410,8 +415,10 @@ async def test_send_timeout_signal_process_not_registered_force(self): self.transport.process = mock_process self.transport.process_manager.is_process_registered.return_value = False - with patch("mcp_fuzzer.transport.stdio.logging.info") as mock_log: - with patch("mcp_fuzzer.transport.stdio.os") as mock_os: + with patch( + "mcp_fuzzer.transport.drivers.stdio_driver.logging.info" + ) as mock_log: + with patch("mcp_fuzzer.transport.drivers.stdio_driver.os") as mock_os: # Mock kill to avoid OS errors mock_os.name = "posix" @@ -430,8 +437,10 @@ async def test_send_timeout_signal_process_not_registered_interrupt(self): self.transport.process = mock_process self.transport.process_manager.is_process_registered.return_value = False - with patch("mcp_fuzzer.transport.stdio.logging.info") as mock_log: - with patch("mcp_fuzzer.transport.stdio.os") as mock_os: + with patch( + "mcp_fuzzer.transport.drivers.stdio_driver.logging.info" + ) as mock_log: + with patch("mcp_fuzzer.transport.drivers.stdio_driver.os") as mock_os: # Mock kill to avoid OS errors mock_os.name = "posix" @@ -458,13 +467,15 @@ async def test_send_timeout_signal_no_process(self): self.transport.process = None result = await self.transport.send_timeout_signal("timeout") assert result is False + @pytest.mark.asyncio async def test_send_raw_no_response(self): """send_raw should raise TransportError when no message arrives.""" - with patch.object( - self.transport, "_send_message", new=AsyncMock() - ), patch.object( - self.transport, "_receive_message", new=AsyncMock(return_value=None) + with ( + patch.object(self.transport, "_send_message", new=AsyncMock()), + patch.object( + self.transport, "_receive_message", new=AsyncMock(return_value=None) + ), ): with pytest.raises(TransportError): await self.transport.send_raw({"raw": "data"}) diff --git a/tests/unit/transport/test_streamable_http.py b/tests/unit/transport/test_streamable_http.py index 6c079fa..36691c9 100644 --- a/tests/unit/transport/test_streamable_http.py +++ b/tests/unit/transport/test_streamable_http.py @@ -3,8 +3,8 @@ from typing import Any, Dict, List, Optional, Union import pytest -from mcp_fuzzer.transport.streamable_http import ( - StreamableHTTPTransport, +from mcp_fuzzer.transport.drivers.stream_http_driver import ( + StreamHttpDriver, CONTENT_TYPE, ) from mcp_fuzzer.config import DEFAULT_PROTOCOL_VERSION @@ -109,7 +109,7 @@ async def test_streamable_http_json_initialize(monkeypatch): monkeypatch.setattr(httpx, "AsyncClient", lambda *a, **k: fake) - t = StreamableHTTPTransport("http://test/mcp", timeout=1) + t = StreamHttpDriver("http://test/mcp", timeout=1) # Act: initialize result = await t.send_request("initialize", {"params": {}}) @@ -146,7 +146,7 @@ async def test_streamable_http_sse_response(monkeypatch): monkeypatch.setattr(httpx, "AsyncClient", lambda *a, **k: fake) - t = StreamableHTTPTransport("http://test/mcp", timeout=1) + t = StreamHttpDriver("http://test/mcp", timeout=1) # Act result = await t.send_request("initialize", {}) @@ -165,7 +165,7 @@ async def test_streamable_http_wraps_http_status_error(monkeypatch): fake = _FakeAsyncClient([bad_response]) monkeypatch.setattr(httpx, "AsyncClient", lambda *a, **k: fake) - transport = StreamableHTTPTransport("http://test/mcp", timeout=1) + transport = StreamHttpDriver("http://test/mcp", timeout=1) transport._initialized = True with pytest.raises(TransportError) as excinfo: @@ -186,7 +186,7 @@ async def test_streamable_http_wraps_connect_errors(monkeypatch): fake = _FakeAsyncClient([connect_exc]) monkeypatch.setattr(httpx, "AsyncClient", lambda *a, **k: fake) - transport = StreamableHTTPTransport("http://test/mcp", timeout=1) + transport = StreamHttpDriver("http://test/mcp", timeout=1) transport._initialized = True with pytest.raises(TransportError) as excinfo: diff --git a/tests/unit/transport/test_transport.py b/tests/unit/transport/test_transport.py index d218feb..1b82f21 100644 --- a/tests/unit/transport/test_transport.py +++ b/tests/unit/transport/test_transport.py @@ -13,16 +13,16 @@ import pytest from mcp_fuzzer.transport import ( - HTTPTransport, - SSETransport, - StdioTransport, - TransportProtocol, - create_transport, + HttpDriver, + SseDriver, + StdioDriver, + TransportDriver, + build_driver, ) -from mcp_fuzzer.transport.mixins import ( - BaseTransportMixin, - NetworkTransportMixin, - ResponseParsingMixin, +from mcp_fuzzer.transport.interfaces.behaviors import ( + DriverBaseBehavior, + HttpClientBehavior, + ResponseParserBehavior, JSONRPCRequest, JSONRPCNotification, TransportError, @@ -34,21 +34,21 @@ pytestmark = [pytest.mark.unit, pytest.mark.transport] -# Test cases for TransportProtocol class +# Test cases for TransportDriver class @pytest.mark.asyncio async def test_transport_protocol_abstract(): - """Test that TransportProtocol is properly abstract.""" - # Should not be able to instantiate TransportProtocol directly + """Test that TransportDriver is properly abstract.""" + # Should not be able to instantiate TransportDriver directly with pytest.raises(TypeError): - TransportProtocol() + TransportDriver() @pytest.mark.asyncio async def test_transport_protocol_connection_methods(): - """Test TransportProtocol connection management methods.""" + """Test TransportDriver connection management methods.""" # Create a concrete implementation - class TestTransport(TransportProtocol): + class TestTransport(TransportDriver): async def send_request(self, method, params=None): return {"test": "response"} @@ -75,10 +75,10 @@ async def _stream_request(self, payload): @pytest.mark.asyncio async def test_transport_protocol_send_request(): - """Test TransportProtocol send_request method.""" + """Test TransportDriver send_request method.""" # Create a concrete implementation with mocked _send_request - class TestTransport(TransportProtocol): + class TestTransport(TransportDriver): async def send_request(self, method, params=None): payload = {"method": method} if params: @@ -112,10 +112,10 @@ async def _stream_request(self, payload): @pytest.mark.asyncio async def test_transport_protocol_stream_request(): - """Test TransportProtocol stream_request method.""" + """Test TransportDriver stream_request method.""" # Create a concrete implementation with mocked _stream_request - class TestTransport(TransportProtocol): + class TestTransport(TransportDriver): async def send_request(self, method, params=None): payload = {"method": method} if params: @@ -154,16 +154,16 @@ async def _stream_request(self, payload): assert transport.last_payload == test_payload -# Test cases for HTTPTransport class +# Test cases for HttpDriver class @pytest.fixture def http_transport(): - """Fixture for HTTPTransport test cases.""" - return HTTPTransport("https://example.com/api") + """Fixture for HttpDriver test cases.""" + return HttpDriver("https://example.com/api") @pytest.mark.asyncio async def test_http_transport_init(http_transport): - """Test HTTPTransport initialization.""" + """Test HttpDriver initialization.""" assert http_transport.url == "https://example.com/api" assert http_transport.timeout == 30.0 assert "Accept" in http_transport.headers @@ -172,7 +172,7 @@ async def test_http_transport_init(http_transport): @pytest.mark.asyncio async def test_http_transport_send_request(http_transport): - """Test HTTPTransport send_request method.""" + """Test HttpDriver send_request method.""" test_payload = {"method": "test.method", "params": {"key": "value"}} test_response = {"result": "success"} @@ -196,7 +196,7 @@ async def test_http_transport_send_request(http_transport): @pytest.mark.asyncio async def test_http_transport_send_request_error(http_transport): - """Test HTTPTransport send_request with error response.""" + """Test HttpDriver send_request with error response.""" test_payload = {"method": "test.method", "params": {"key": "value"}} with patch.object(httpx.AsyncClient, "post") as mock_post: @@ -209,7 +209,7 @@ async def test_http_transport_send_request_error(http_transport): @pytest.mark.asyncio async def test_http_transport_stream_request(http_transport): - """Test HTTPTransport stream_request method.""" + """Test HttpDriver stream_request method.""" test_payload = {"method": "test.method", "params": {"key": "value"}} test_responses = [ {"id": 1, "result": "streaming"}, @@ -271,7 +271,7 @@ async def __anext__(self): @pytest.mark.asyncio async def test_http_transport_stream_request_error(http_transport): - """Test HTTPTransport stream_request with error.""" + """Test HttpDriver stream_request with error.""" test_payload = {"method": "test.method", "params": {"key": "value"}} with patch.object(httpx.AsyncClient, "post") as mock_post: @@ -285,22 +285,22 @@ async def test_http_transport_stream_request_error(http_transport): @pytest.mark.asyncio async def test_http_transport_connect_disconnect(http_transport): - """Test HTTPTransport connect and disconnect methods.""" + """Test HttpDriver connect and disconnect methods.""" # These should not raise any exceptions await http_transport.connect() await http_transport.disconnect() -# Test cases for SSETransport class +# Test cases for SseDriver class @pytest.fixture def sse_transport(): - """Fixture for SSETransport test cases.""" - return SSETransport("https://example.com/events") + """Fixture for SseDriver test cases.""" + return SseDriver("https://example.com/events") @pytest.mark.asyncio async def test_sse_transport_init(sse_transport): - """Test SSETransport initialization.""" + """Test SseDriver initialization.""" assert sse_transport.url == "https://example.com/events" assert sse_transport.timeout == 30.0 assert "Accept" in sse_transport.headers @@ -309,14 +309,14 @@ async def test_sse_transport_init(sse_transport): @pytest.mark.asyncio async def test_sse_transport_send_request_not_implemented(sse_transport): - """Test SSETransport send_request is not implemented.""" + """Test SseDriver send_request is not implemented.""" with pytest.raises(NotImplementedError): await sse_transport.send_request("test") @pytest.mark.asyncio async def test_sse_transport_stream_request(sse_transport): - """Test SSETransport stream_request method.""" + """Test SseDriver stream_request method.""" test_payload = {"method": "test.method", "params": {"key": "value"}} # Create mock SSE events - each event needs to end with a blank line @@ -353,7 +353,7 @@ async def test_sse_transport_stream_request(sse_transport): @pytest.mark.asyncio async def test_sse_transport_stream_request_error(sse_transport): - """Test SSETransport stream_request with error.""" + """Test SseDriver stream_request with error.""" test_payload = {"method": "test.method", "params": {"key": "value"}} with patch.object(httpx.AsyncClient, "stream") as mock_stream: @@ -367,50 +367,52 @@ async def test_sse_transport_stream_request_error(sse_transport): @pytest.mark.asyncio async def test_sse_transport_parse_sse_event(): - """Test SSETransport _parse_sse_event method.""" + """Test SseDriver parse_sse_event method.""" + driver = SseDriver("http://test", timeout=1) + # Standard SSE event sse_event = 'event: message\ndata: {"id": 1, "result": "success"}' - result = SSETransport._parse_sse_event(sse_event) + result = driver.parse_sse_event(sse_event) assert result == {"id": 1, "result": "success"} # Multiline data sse_event = 'event: message\ndata: {"id": 1,\ndata: "result": "multiline"}' - result = SSETransport._parse_sse_event(sse_event) + result = driver.parse_sse_event(sse_event) assert result == {"id": 1, "result": "multiline"} # With retry field (should ignore) sse_event = 'retry: 3000\nevent: message\ndata: {"id": 1}' - result = SSETransport._parse_sse_event(sse_event) + result = driver.parse_sse_event(sse_event) assert result == {"id": 1} # Empty event - assert SSETransport._parse_sse_event("") is None + assert driver.parse_sse_event("") is None # Invalid JSON sse_event = "event: message\ndata: not_json" with pytest.raises(json.JSONDecodeError): - SSETransport._parse_sse_event(sse_event) + driver.parse_sse_event(sse_event) -# Test cases for StdioTransport class +# Test cases for StdioDriver class @pytest.fixture def stdio_transport(): - """Fixture for StdioTransport test cases.""" - with patch("mcp_fuzzer.transport.stdio.sys") as mock_sys: - transport = StdioTransport("test_command") + """Fixture for StdioDriver test cases.""" + with patch("mcp_fuzzer.transport.drivers.stdio_driver.sys") as mock_sys: + transport = StdioDriver("test_command") transport._sys = mock_sys # Attach the mock to the transport yield transport @pytest.mark.asyncio async def test_stdio_transport_init(stdio_transport): - """Test StdioTransport initialization.""" + """Test StdioDriver initialization.""" assert stdio_transport.request_id == 1 @pytest.mark.asyncio async def test_stdio_transport_send_request(stdio_transport): - """Test StdioTransport send_request method.""" + """Test StdioDriver send_request method.""" test_payload = {"method": "test.method", "params": {"key": "value"}} test_response = {"id": 1, "result": "success"} @@ -432,7 +434,7 @@ async def test_stdio_transport_send_request(stdio_transport): @pytest.mark.asyncio async def test_stdio_transport_send_request_error(stdio_transport): - """Test StdioTransport send_request with error response.""" + """Test StdioDriver send_request with error response.""" test_payload = {"method": "test.method", "params": {"key": "value"}} test_error = {"id": 1, "error": {"code": -32600, "message": "Invalid Request"}} @@ -448,7 +450,7 @@ async def test_stdio_transport_send_request_error(stdio_transport): @pytest.mark.asyncio async def test_stdio_transport_send_request_invalid_json(stdio_transport): - """Test StdioTransport send_request with invalid JSON response.""" + """Test StdioDriver send_request with invalid JSON response.""" test_payload = {"method": "test.method", "params": {"key": "value"}} # Set up the mocks @@ -461,7 +463,7 @@ async def test_stdio_transport_send_request_invalid_json(stdio_transport): @pytest.mark.asyncio async def test_stdio_transport_stream_request(stdio_transport): - """Test StdioTransport stream_request method.""" + """Test StdioDriver stream_request method.""" test_payload = {"method": "test.method", "params": {"key": "value"}} test_responses = [ {"id": 1, "result": "streaming"}, @@ -486,42 +488,42 @@ async def test_stdio_transport_stream_request(stdio_transport): assert stdio_transport._sys.stdout.write.call_count == 1 -# Test cases for create_transport function -def test_create_transport_http(): - """Test create_transport with HTTP URL.""" - transport = create_transport("http://example.com/api") - assert isinstance(transport, HTTPTransport) +# Test cases for build_driver function +def test_build_driver_http(): + """Test build_driver with HTTP URL.""" + transport = build_driver("http://example.com/api") + assert isinstance(transport, HttpDriver) assert transport.url == "http://example.com/api" -def test_create_transport_https(): - """Test create_transport with HTTPS URL.""" - transport = create_transport("https://example.com/api") - assert isinstance(transport, HTTPTransport) +def test_build_driver_https(): + """Test build_driver with HTTPS URL.""" + transport = build_driver("https://example.com/api") + assert isinstance(transport, HttpDriver) assert transport.url == "https://example.com/api" -def test_create_transport_sse(): - """Test create_transport with SSE URL.""" - transport = create_transport("sse://example.com/events") - assert isinstance(transport, SSETransport) +def test_build_driver_sse(): + """Test build_driver with SSE URL.""" + transport = build_driver("sse://example.com/events") + assert isinstance(transport, SseDriver) assert transport.url == "http://example.com/events" -def test_create_transport_stdio(): - """Test create_transport with stdio URL.""" - transport = create_transport("stdio:") - assert isinstance(transport, StdioTransport) +def test_build_driver_stdio(): + """Test build_driver with stdio URL.""" + transport = build_driver("stdio:") + assert isinstance(transport, StdioDriver) -def test_create_transport_protocol_and_endpoint_builtin(): +def test_build_driver_protocol_and_endpoint_builtin(): """Ensure built-in transports work with protocol+endpoint usage.""" - transport = create_transport("stdio", "node server.js") - assert isinstance(transport, StdioTransport) + transport = build_driver("stdio", "node server.js") + assert isinstance(transport, StdioDriver) assert transport.command == "node server.js" -def test_create_transport_invalid_scheme(): - """Test create_transport with invalid URL scheme.""" +def test_build_driver_invalid_scheme(): + """Test build_driver with invalid URL scheme.""" with pytest.raises(TransportRegistrationError): - create_transport("invalid://example.com") + build_driver("invalid://example.com")