From 994d4ddd3749b90306bc314964ef6b0ba057c0f3 Mon Sep 17 00:00:00 2001 From: Prince Roshan Date: Mon, 24 Nov 2025 19:18:22 +0530 Subject: [PATCH] fix(auth): Fix unnecesary header sanitization in tansport --- mcp_fuzzer/cli/runner.py | 7 ++++++- mcp_fuzzer/client/__init__.py | 1 + mcp_fuzzer/transport/http.py | 23 +++++++++++++++++------ mcp_fuzzer/transport/sse.py | 21 ++++++++++++++++----- mcp_fuzzer/transport/streamable_http.py | 21 ++++++++++++++++----- 5 files changed, 56 insertions(+), 17 deletions(-) diff --git a/mcp_fuzzer/cli/runner.py b/mcp_fuzzer/cli/runner.py index 8995381..1155378 100644 --- a/mcp_fuzzer/cli/runner.py +++ b/mcp_fuzzer/cli/runner.py @@ -47,7 +47,12 @@ def create_transport_with_auth(args, client_args: dict[str, Any]): logger.debug("No auth headers found for default tool mapping") factory_kwargs = {"timeout": args.timeout} - + + # Get safety settings + safety_enabled = client_args.get("safety_enabled", True) + if safety_enabled: + factory_kwargs["safety_enabled"] = safety_enabled + # Apply auth headers to HTTP-based protocols if args.protocol in ("http", "https", "streamablehttp", "sse") and auth_headers: factory_kwargs["auth_headers"] = auth_headers diff --git a/mcp_fuzzer/client/__init__.py b/mcp_fuzzer/client/__init__.py index d326c91..3e009dc 100644 --- a/mcp_fuzzer/client/__init__.py +++ b/mcp_fuzzer/client/__init__.py @@ -56,6 +56,7 @@ def __init__(self, protocol, endpoint, timeout): client_args = { "auth_manager": config.get("auth_manager"), + "safety_enabled": config.get("safety_enabled", True), } transport = create_transport_with_auth(args, client_args) diff --git a/mcp_fuzzer/transport/http.py b/mcp_fuzzer/transport/http.py index b67f7a8..9a92afa 100644 --- a/mcp_fuzzer/transport/http.py +++ b/mcp_fuzzer/transport/http.py @@ -48,19 +48,30 @@ def __init__( url: str, timeout: float = 30.0, auth_headers: dict[str, str | None] | None = None, + safety_enabled: bool = True, ): self.url = url self.timeout = timeout + self.safety_enabled = safety_enabled self.headers = { "Accept": DEFAULT_HTTP_ACCEPT, "Content-Type": JSON_CONTENT_TYPE, } - if auth_headers: - self.headers.update(auth_headers) + self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} # Track last activity for process management self._last_activity = time.time() + def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: + """Prepare headers with optional safety sanitization and auth headers.""" + if self.safety_enabled: + safe_headers = self._prepare_safe_headers(headers) + else: + safe_headers = headers.copy() + # Add auth headers after sanitization (they are user-configured and safe) + safe_headers.update(self.auth_headers) + return safe_headers + # Initialize process manager for any subprocesses (like proxy servers) watchdog_config = WatchdogConfig( check_interval=1.0, @@ -118,7 +129,7 @@ async def send_request( # Use shared network functionality self._validate_network_request(self.url) - safe_headers = self._prepare_safe_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) async with self._create_http_client(self.timeout) as client: response = await client.post(self.url, json=payload, headers=safe_headers) @@ -162,7 +173,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: # Use shared network functionality self._validate_network_request(self.url) - safe_headers = self._prepare_safe_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) async with self._create_http_client(self.timeout) as client: response = await client.post(self.url, json=payload, headers=safe_headers) @@ -202,7 +213,7 @@ async def send_notification( # Use shared network functionality self._validate_network_request(self.url) - safe_headers = self._prepare_safe_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) async with self._create_http_client(self.timeout) as client: response = await client.post(self.url, json=payload, headers=safe_headers) @@ -242,7 +253,7 @@ async def _stream_request( # Use shared network functionality self._validate_network_request(self.url) - safe_headers = self._prepare_safe_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) async with self._create_http_client(self.timeout) as client: # First request diff --git a/mcp_fuzzer/transport/sse.py b/mcp_fuzzer/transport/sse.py index 4a528df..8b3375d 100644 --- a/mcp_fuzzer/transport/sse.py +++ b/mcp_fuzzer/transport/sse.py @@ -14,15 +14,26 @@ def __init__( url: str, timeout: float = 30.0, auth_headers: dict[str, str | None] | None = None, + safety_enabled: bool = True, ): self.url = url self.timeout = timeout + self.safety_enabled = safety_enabled self.headers = { "Accept": "text/event-stream", "Content-Type": "application/json", } - if auth_headers: - self.headers.update(auth_headers) + self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} + + def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: + """Prepare headers with optional safety sanitization and auth headers.""" + if self.safety_enabled: + safe_headers = sanitize_headers(headers) + else: + safe_headers = headers.copy() + # Add auth headers after sanitization (they are user-configured and safe) + safe_headers.update(self.auth_headers) + return safe_headers async def send_request( self, method: str, params: dict[str, Any | None] | None = None @@ -42,7 +53,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: "Network to non-local host is disallowed by safety policy", context={"url": self.url}, ) - safe_headers = sanitize_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) response = await client.post(self.url, json=payload, headers=safe_headers) response.raise_for_status() buffer: list[str] = [] @@ -107,7 +118,7 @@ async def send_notification( "Network to non-local host is disallowed by safety policy", context={"url": self.url}, ) - safe_headers = sanitize_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) response = await client.post(self.url, json=payload, headers=safe_headers) response.raise_for_status() @@ -130,7 +141,7 @@ async def _stream_request(self, payload: dict[str, Any]): "Network to non-local host is disallowed by safety policy", context={"url": self.url}, ) - safe_headers = sanitize_headers(self.headers) + safe_headers = self._prepare_headers_with_auth(self.headers) async with client.stream( "POST", self.url, diff --git a/mcp_fuzzer/transport/streamable_http.py b/mcp_fuzzer/transport/streamable_http.py index 2e3be62..c6abc84 100644 --- a/mcp_fuzzer/transport/streamable_http.py +++ b/mcp_fuzzer/transport/streamable_http.py @@ -51,17 +51,28 @@ def __init__( url: str, timeout: float = DEFAULT_TIMEOUT, auth_headers: dict[str, str | None] = None, + safety_enabled: bool = True, ): self.url = url self.timeout = timeout + self.safety_enabled = safety_enabled self.headers: dict[str, str] = { "Accept": DEFAULT_HTTP_ACCEPT, "Content-Type": JSON_CT, } - if auth_headers: - self.headers.update(auth_headers) + self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} self._logger = logging.getLogger(__name__) + + def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: + """Prepare headers with optional safety sanitization and auth headers.""" + if self.safety_enabled: + safe_headers = self._prepare_headers_with_auth(headers) + else: + safe_headers = headers.copy() + # Add auth headers after sanitization (they are user-configured and safe) + safe_headers.update(self.auth_headers) + return safe_headers self.session_id: str | None = None self.protocol_version: str | None = None self._initialized: bool = False @@ -211,7 +222,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: ) as client: self._ensure_host_allowed() response = await self._post_with_retries( - client, self.url, payload, sanitize_headers(headers) + client, self.url, payload, self._prepare_headers_with_auth(headers) ) # Handle redirect by retrying once with provided Location or trailing slash redirect_url = self._resolve_redirect(response) @@ -314,7 +325,7 @@ async def send_notification( timeout=self.timeout, follow_redirects=False, trust_env=False ) as client: self._ensure_host_allowed() - safe_headers = sanitize_headers(headers) + safe_headers = self._prepare_headers_with_auth(headers) response = await self._post_with_retries( client, self.url, payload, safe_headers ) @@ -422,7 +433,7 @@ async def _stream_request(self, payload: dict[str, Any]): timeout=self.timeout, follow_redirects=False, trust_env=False ) as client: self._ensure_host_allowed() - safe_headers = sanitize_headers(headers) + safe_headers = self._prepare_headers_with_auth(headers) response = await client.stream( "POST", self.url, json=payload, headers=safe_headers )