-
-
Notifications
You must be signed in to change notification settings - Fork 2
fix(auth): Fix unnecesary header sanitization in tansport #137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -48,19 +48,30 @@ def __init__( | |||||||||
| url: str, | ||||||||||
| timeout: float = 30.0, | ||||||||||
| auth_headers: dict[str, str | None] | None = None, | ||||||||||
| safety_enabled: bool = True, | ||||||||||
| ): | ||||||||||
| self.url = url | ||||||||||
| self.timeout = timeout | ||||||||||
| self.safety_enabled = safety_enabled | ||||||||||
| self.headers = { | ||||||||||
| "Accept": DEFAULT_HTTP_ACCEPT, | ||||||||||
| "Content-Type": JSON_CONTENT_TYPE, | ||||||||||
| } | ||||||||||
| if auth_headers: | ||||||||||
| self.headers.update(auth_headers) | ||||||||||
| self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line exceeds maximum length. Line 60 exceeds the project's 88-character limit (92 characters). Consider breaking it into multiple lines for better readability. Apply this diff: - self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None}
+ self.auth_headers = {
+ k: v for k, v in (auth_headers or {}).items() if v is not None
+ }📝 Committable suggestion
Suggested change
🧰 Tools🪛 GitHub Actions: Lint[error] 60-60: E501 Line too long (92 > 88) 🤖 Prompt for AI Agents |
||||||||||
|
|
||||||||||
| # Track last activity for process management | ||||||||||
| self._last_activity = time.time() | ||||||||||
|
|
||||||||||
| def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: | ||||||||||
| """Prepare headers with optional safety sanitization and auth headers.""" | ||||||||||
| if self.safety_enabled: | ||||||||||
| safe_headers = self._prepare_safe_headers(headers) | ||||||||||
| else: | ||||||||||
| safe_headers = headers.copy() | ||||||||||
| # Add auth headers after sanitization (they are user-configured and safe) | ||||||||||
| safe_headers.update(self.auth_headers) | ||||||||||
| return safe_headers | ||||||||||
|
|
||||||||||
| # Initialize process manager for any subprocesses (like proxy servers) | ||||||||||
| watchdog_config = WatchdogConfig( | ||||||||||
| check_interval=1.0, | ||||||||||
|
|
@@ -118,7 +129,7 @@ async def send_request( | |||||||||
|
|
||||||||||
| # Use shared network functionality | ||||||||||
| self._validate_network_request(self.url) | ||||||||||
| safe_headers = self._prepare_safe_headers(self.headers) | ||||||||||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||||||||||
|
|
||||||||||
| async with self._create_http_client(self.timeout) as client: | ||||||||||
| response = await client.post(self.url, json=payload, headers=safe_headers) | ||||||||||
|
|
@@ -162,7 +173,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: | |||||||||
|
|
||||||||||
| # Use shared network functionality | ||||||||||
| self._validate_network_request(self.url) | ||||||||||
| safe_headers = self._prepare_safe_headers(self.headers) | ||||||||||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||||||||||
|
|
||||||||||
| async with self._create_http_client(self.timeout) as client: | ||||||||||
| response = await client.post(self.url, json=payload, headers=safe_headers) | ||||||||||
|
|
@@ -202,7 +213,7 @@ async def send_notification( | |||||||||
|
|
||||||||||
| # Use shared network functionality | ||||||||||
| self._validate_network_request(self.url) | ||||||||||
| safe_headers = self._prepare_safe_headers(self.headers) | ||||||||||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||||||||||
|
|
||||||||||
| async with self._create_http_client(self.timeout) as client: | ||||||||||
| response = await client.post(self.url, json=payload, headers=safe_headers) | ||||||||||
|
|
@@ -242,7 +253,7 @@ async def _stream_request( | |||||||||
|
|
||||||||||
| # Use shared network functionality | ||||||||||
| self._validate_network_request(self.url) | ||||||||||
| safe_headers = self._prepare_safe_headers(self.headers) | ||||||||||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||||||||||
|
|
||||||||||
| async with self._create_http_client(self.timeout) as client: | ||||||||||
| # First request | ||||||||||
|
|
||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,15 +14,26 @@ def __init__( | |
| url: str, | ||
| timeout: float = 30.0, | ||
| auth_headers: dict[str, str | None] | None = None, | ||
| safety_enabled: bool = True, | ||
| ): | ||
| self.url = url | ||
| self.timeout = timeout | ||
| self.safety_enabled = safety_enabled | ||
| self.headers = { | ||
| "Accept": "text/event-stream", | ||
| "Content-Type": "application/json", | ||
| } | ||
| if auth_headers: | ||
| self.headers.update(auth_headers) | ||
| self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line exceeds maximum length. Line 26 exceeds the project's 88-character limit (92 characters). Apply this diff: - self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None}
+ self.auth_headers = {
+ k: v for k, v in (auth_headers or {}).items() if v is not None
+ }
🧰 Tools🪛 GitHub Actions: Lint[error] 26-26: E501 Line too long (92 > 88) 🤖 Prompt for AI Agents |
||
|
|
||
| def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: | ||
| """Prepare headers with optional safety sanitization and auth headers.""" | ||
| if self.safety_enabled: | ||
| safe_headers = sanitize_headers(headers) | ||
| else: | ||
| safe_headers = headers.copy() | ||
| # Add auth headers after sanitization (they are user-configured and safe) | ||
| safe_headers.update(self.auth_headers) | ||
| return safe_headers | ||
|
|
||
| async def send_request( | ||
| self, method: str, params: dict[str, Any | None] | None = None | ||
|
|
@@ -42,7 +53,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: | |
| "Network to non-local host is disallowed by safety policy", | ||
| context={"url": self.url}, | ||
| ) | ||
| safe_headers = sanitize_headers(self.headers) | ||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||
| response = await client.post(self.url, json=payload, headers=safe_headers) | ||
| response.raise_for_status() | ||
| buffer: list[str] = [] | ||
|
|
@@ -107,7 +118,7 @@ async def send_notification( | |
| "Network to non-local host is disallowed by safety policy", | ||
| context={"url": self.url}, | ||
| ) | ||
| safe_headers = sanitize_headers(self.headers) | ||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||
| response = await client.post(self.url, json=payload, headers=safe_headers) | ||
| response.raise_for_status() | ||
|
|
||
|
|
@@ -130,7 +141,7 @@ async def _stream_request(self, payload: dict[str, Any]): | |
| "Network to non-local host is disallowed by safety policy", | ||
| context={"url": self.url}, | ||
| ) | ||
| safe_headers = sanitize_headers(self.headers) | ||
| safe_headers = self._prepare_headers_with_auth(self.headers) | ||
| async with client.stream( | ||
| "POST", | ||
| self.url, | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -51,17 +51,28 @@ def __init__( | |||||||||||||||||||||||||||||||||||||
| url: str, | ||||||||||||||||||||||||||||||||||||||
| timeout: float = DEFAULT_TIMEOUT, | ||||||||||||||||||||||||||||||||||||||
| auth_headers: dict[str, str | None] = None, | ||||||||||||||||||||||||||||||||||||||
| safety_enabled: bool = True, | ||||||||||||||||||||||||||||||||||||||
| ): | ||||||||||||||||||||||||||||||||||||||
| self.url = url | ||||||||||||||||||||||||||||||||||||||
| self.timeout = timeout | ||||||||||||||||||||||||||||||||||||||
| self.safety_enabled = safety_enabled | ||||||||||||||||||||||||||||||||||||||
| self.headers: dict[str, str] = { | ||||||||||||||||||||||||||||||||||||||
| "Accept": DEFAULT_HTTP_ACCEPT, | ||||||||||||||||||||||||||||||||||||||
| "Content-Type": JSON_CT, | ||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||
| if auth_headers: | ||||||||||||||||||||||||||||||||||||||
| self.headers.update(auth_headers) | ||||||||||||||||||||||||||||||||||||||
| self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None} | ||||||||||||||||||||||||||||||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Line exceeds maximum length. Line 63 exceeds the project's 88-character limit (92 characters). Apply this diff: - self.auth_headers = {k: v for k, v in (auth_headers or {}).items() if v is not None}
+ self.auth_headers = {
+ k: v for k, v in (auth_headers or {}).items() if v is not None
+ }📝 Committable suggestion
Suggested change
🧰 Tools🪛 GitHub Actions: Lint[error] 63-63: E501 Line too long (92 > 88) 🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| self._logger = logging.getLogger(__name__) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]: | ||||||||||||||||||||||||||||||||||||||
| """Prepare headers with optional safety sanitization and auth headers.""" | ||||||||||||||||||||||||||||||||||||||
| if self.safety_enabled: | ||||||||||||||||||||||||||||||||||||||
| safe_headers = self._prepare_headers_with_auth(headers) | ||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||
| safe_headers = headers.copy() | ||||||||||||||||||||||||||||||||||||||
| # Add auth headers after sanitization (they are user-configured and safe) | ||||||||||||||||||||||||||||||||||||||
| safe_headers.update(self.auth_headers) | ||||||||||||||||||||||||||||||||||||||
| return safe_headers | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+67
to
+75
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Critical: Infinite recursion in Line 70 calls Apply this diff to call def _prepare_headers_with_auth(self, headers: dict[str, str]) -> dict[str, str]:
"""Prepare headers with optional safety sanitization and auth headers."""
if self.safety_enabled:
- safe_headers = self._prepare_headers_with_auth(headers)
+ safe_headers = sanitize_headers(headers)
else:
safe_headers = headers.copy()
# Add auth headers after sanitization (they are user-configured and safe)
safe_headers.update(self.auth_headers)
return safe_headers
- self.session_id: str | None = None
- self.protocol_version: str | None = None
- self._initialized: bool = False
- self._init_lock: asyncio.Lock = asyncio.Lock()
- self._initializing: bool = FalseNote: The code snippet also shows lines 76-80 that appear to be misplaced (they should be in 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
| self.session_id: str | None = None | ||||||||||||||||||||||||||||||||||||||
| self.protocol_version: str | None = None | ||||||||||||||||||||||||||||||||||||||
| self._initialized: bool = False | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -211,7 +222,7 @@ async def send_raw(self, payload: dict[str, Any]) -> Any: | |||||||||||||||||||||||||||||||||||||
| ) as client: | ||||||||||||||||||||||||||||||||||||||
| self._ensure_host_allowed() | ||||||||||||||||||||||||||||||||||||||
| response = await self._post_with_retries( | ||||||||||||||||||||||||||||||||||||||
| client, self.url, payload, sanitize_headers(headers) | ||||||||||||||||||||||||||||||||||||||
| client, self.url, payload, self._prepare_headers_with_auth(headers) | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
| # Handle redirect by retrying once with provided Location or trailing slash | ||||||||||||||||||||||||||||||||||||||
| redirect_url = self._resolve_redirect(response) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -314,7 +325,7 @@ async def send_notification( | |||||||||||||||||||||||||||||||||||||
| timeout=self.timeout, follow_redirects=False, trust_env=False | ||||||||||||||||||||||||||||||||||||||
| ) as client: | ||||||||||||||||||||||||||||||||||||||
| self._ensure_host_allowed() | ||||||||||||||||||||||||||||||||||||||
| safe_headers = sanitize_headers(headers) | ||||||||||||||||||||||||||||||||||||||
| safe_headers = self._prepare_headers_with_auth(headers) | ||||||||||||||||||||||||||||||||||||||
| response = await self._post_with_retries( | ||||||||||||||||||||||||||||||||||||||
| client, self.url, payload, safe_headers | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
@@ -422,7 +433,7 @@ async def _stream_request(self, payload: dict[str, Any]): | |||||||||||||||||||||||||||||||||||||
| timeout=self.timeout, follow_redirects=False, trust_env=False | ||||||||||||||||||||||||||||||||||||||
| ) as client: | ||||||||||||||||||||||||||||||||||||||
| self._ensure_host_allowed() | ||||||||||||||||||||||||||||||||||||||
| safe_headers = sanitize_headers(headers) | ||||||||||||||||||||||||||||||||||||||
| safe_headers = self._prepare_headers_with_auth(headers) | ||||||||||||||||||||||||||||||||||||||
| response = await client.stream( | ||||||||||||||||||||||||||||||||||||||
| "POST", self.url, json=payload, headers=safe_headers | ||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix the conditional logic for safety_enabled.
The current implementation only adds
safety_enabledtofactory_kwargswhen it'sTrue. If a user explicitly setssafety_enabled=False, it won't be passed to the transport, and the transport will default toTrue, ignoring the user's preference.Apply this diff to always pass the safety_enabled value:
factory_kwargs = {"timeout": args.timeout} - - # Get safety settings - safety_enabled = client_args.get("safety_enabled", True) - if safety_enabled: - factory_kwargs["safety_enabled"] = safety_enabled - + + # Pass safety_enabled to transport (defaults to True if not provided) + factory_kwargs["safety_enabled"] = client_args.get("safety_enabled", True) +📝 Committable suggestion
🤖 Prompt for AI Agents