diff --git a/droidrun/tools/android/portal_client.py b/droidrun/tools/android/portal_client.py index f1267036..be24d881 100644 --- a/droidrun/tools/android/portal_client.py +++ b/droidrun/tools/android/portal_client.py @@ -19,6 +19,19 @@ PORTAL_REMOTE_PORT = 8080 # Port on device where Portal HTTP server runs +# Granular timeout policy for the persistent session. +# Separating connect/read/write/pool prevents slow devices from causing +# cascading stalls when socket reuse is delayed by network jitter. +_SESSION_TIMEOUT = httpx.Timeout( + connect=5.0, # TCP handshake (only on first request per connection) + read=15.0, # Portal response time (state_full can be slow on busy devices) + write=10.0, # Request upload (keyboard input payloads are small) + pool=5.0, # Wait for a connection from the pool +) + +# Tighter timeout for lightweight probe calls (ping, version) +_PROBE_TIMEOUT = httpx.Timeout(connect=5.0, read=5.0, write=5.0, pool=5.0) + class PortalClient: """ @@ -27,15 +40,17 @@ class PortalClient: Automatically handles TCP vs Content Provider fallback with the following strategy: - On init, checks for existing port forward and reuses it - If no forward exists, creates new one - - Tests connection and sets tcp_available flag + - Fetches auth token and tests connection before enabling TCP + - Maintains a persistent httpx.AsyncClient for all TCP requests - All methods auto-select TCP or content provider based on availability - Port forwards persist until device disconnect (no explicit cleanup needed) Key features: - Reuses existing port forwards (no cleanup needed) - - Automatic fallback to content provider if TCP fails - - Zero explicit resource management - - Graceful degradation + - Persistent HTTP session eliminates per-call TCP handshake overhead + - Bearer token auth for Portal HTTP server + - Graceful degradation to content provider on any TCP failure + - Self-healing: 401/403 triggers token re-fetch and session rebuild Note: TCP mode is significantly faster but requires ADB port forwarding. Content provider mode works without port forwarding but has higher latency. @@ -47,10 +62,13 @@ def __init__(self, device: AdbDevice, prefer_tcp: bool = False): Args: device: ADB device instance - prefer_tcp: Whether to prefer TCP communication (will fallback to content provider if unavailable) + prefer_tcp: Whether to prefer TCP communication (will fallback to + content provider if unavailable) Note: - Call `await client.connect()` after initialization to establish connection. + Call `await client.connect()` after initialization to establish + connection. Call `await client.disconnect()` when done to release + the persistent HTTP session. """ self.device = device self.prefer_tcp = prefer_tcp @@ -58,11 +76,20 @@ def __init__(self, device: AdbDevice, prefer_tcp: bool = False): self.tcp_base_url = None self.local_tcp_port = None self._auth_token: Optional[str] = None + + # NOTE: + # DroidRun agent loop is sequential (no concurrent requests), + # so this shared AsyncClient is never accessed concurrently. + # This avoids thread-safety concerns while enabling connection reuse. + self._session: Optional[httpx.AsyncClient] = None self._connected = False async def connect(self) -> None: """ - Establish connection... + Establish connection to Portal. + + If prefer_tcp=True, attempts to set up TCP mode with auth. + Always falls back to content provider on any failure. """ if self._connected: return @@ -72,13 +99,32 @@ async def connect(self) -> None: self._connected = True + async def disconnect(self) -> None: + """ + Release the persistent HTTP session and reset connection state. + + Safe to call multiple times (idempotent). After disconnect(), + connect() can be called again to re-establish. + """ + if self._session is not None: + try: + await self._session.aclose() + except Exception: + pass + self._session = None + + # Enforce invariant: no session means TCP not available + self.tcp_available = False + self._connected = False + async def _ensure_connected(self) -> None: """Check if connected, raise error if not.""" if not self._connected: await self.connect() async def _fetch_auth_token(self) -> Optional[str]: - """Fetch the auth token from the Portal via the content provider. + """ + Fetch the auth token from the Portal via the content provider. The Portal HTTP server requires a Bearer token for all requests. The token is generated by the Portal app and exposed via the content @@ -117,6 +163,7 @@ async def _fetch_auth_token(self) -> Optional[str]: logger.debug(f"Auth token: unexpected response format: {data}") return None + except Exception as e: logger.debug(f"Failed to fetch auth token: {e}") return None @@ -129,17 +176,75 @@ def _tcp_headers(self) -> Dict[str, str]: headers["Authorization"] = f"Bearer {self._auth_token}" return headers + def _build_session(self) -> httpx.AsyncClient: + """ + Build a new persistent AsyncClient with current auth headers. + + Uses granular httpx.Timeout to prevent socket-reuse stalls on slow + or remote devices. Called only after TCP is confirmed working. + """ + return httpx.AsyncClient( + headers=self._tcp_headers, + timeout=_SESSION_TIMEOUT, + ) + + async def _replace_session(self) -> None: + """ + Close any existing session and build a fresh one with current token. + + Used when the token rotates (server restart, 401 recovery) so the + persistent session's default Authorization header stays current. + Safe to call whether or not a session currently exists. + """ + if self._session is not None: + try: + await self._session.aclose() + except Exception: + pass + self._session = self._build_session() + + async def _degrade_to_cp(self, reason: str) -> None: + """ + Gracefully downgrade from TCP to content provider. + + Called on unrecoverable TCP failures (connection errors, fatal status + codes). Closes the session cleanly, resets state, and logs once so + operators can diagnose without being spammed. + """ + if self._session is not None: + try: + await self._session.aclose() + except Exception: + pass + self._session = None + if self.tcp_available: + logger.warning( + f"Portal TCP degraded → falling back to content provider. Reason: {reason}" + ) + self.tcp_available = False + async def _try_enable_tcp(self) -> None: """ - Try to enable TCP communication. Fails silently and falls back to content provider. + Try to enable TCP communication. Fails silently and falls back to + content provider. Strategy: 1. Fetch auth token via content provider (secure ADB channel) 2. Check if port forward already exists → reuse 3. If not, create new forward - 4. Test connection with authenticated ping - 5. Set tcp_available flag + 4. Test connection via /version (authenticated endpoint, not /ping) + 5. Build persistent session ONLY on confirmed success + 6. Set tcp_available flag """ + # Clean up any session left over from a previous connect() call. + # Handles reconnect scenario: disconnect() then connect() again. + if self._session is not None: + try: + await self._session.aclose() + except Exception: + pass + self._session = None + try: # Step 1: Fetch auth token before any HTTP calls self._auth_token = await self._fetch_auth_token() @@ -166,17 +271,25 @@ async def _try_enable_tcp(self) -> None: f"Reusing existing forward: localhost:{local_port} -> device:{PORTAL_REMOTE_PORT}" ) - # Store local port self.local_tcp_port = local_port - - # Step 4: Test connection with auth self.tcp_base_url = f"http://localhost:{local_port}" + + # Step 4: Test connection using /version (requires auth). + # /ping does NOT require auth — a 200 from /ping only proves + # reachability, not that the token is valid. Using /version + # confirms both reachability and auth are working before we + # commit to TCP mode and build a persistent session. if await self._test_connection(): + # Step 5: Build persistent session ONLY after confirmed success. + # Enforces invariant: _session is not None ↔ tcp_available. + self._session = self._build_session() self.tcp_available = True logger.debug(f"✓ TCP mode enabled: {self.tcp_base_url}") else: # Step 4b: Try enabling the HTTP server via content provider - logger.debug("TCP ping failed, trying to enable Portal HTTP server...") + logger.debug( + "TCP auth probe failed, trying to enable Portal HTTP server..." + ) await self.device.shell( "content insert --uri content://com.droidrun.portal/toggle_socket_server --bind enabled:b:true" ) @@ -188,6 +301,8 @@ async def _try_enable_tcp(self) -> None: self._auth_token = new_token if await self._test_connection(): + # Build session only after confirmed success on retry + self._session = self._build_session() self.tcp_available = True logger.debug( f"✓ TCP mode enabled after starting server: {self.tcp_base_url}" @@ -198,6 +313,13 @@ async def _try_enable_tcp(self) -> None: except Exception as e: logger.warning(f"TCP unavailable ({e}), using content provider fallback") + # Ensure no zombie session on any exception path + if self._session is not None: + try: + await self._session.aclose() + except Exception: + pass + self._session = None self.tcp_available = False async def _find_existing_forward(self) -> Optional[int]: @@ -211,7 +333,6 @@ async def _find_existing_forward(self) -> Optional[int]: forwards = [] async for forward in self.device.forward_list(): forwards.append(forward) - # forwards is a list of ForwardItem objects with serial, local, remote attributes for forward in forwards: if ( forward.serial == self.device.serial @@ -232,48 +353,61 @@ async def _find_existing_forward(self) -> Optional[int]: async def _tcp_request( self, - client: httpx.AsyncClient, method: str, url: str, extra_headers: Optional[Dict[str, str]] = None, **kwargs, - ) -> httpx.Response: - """Make an authenticated TCP request, re-fetching the token once on 401/403. + ) -> Optional[httpx.Response]: + """ + Make an authenticated TCP request via the persistent session. + """ + if not self.tcp_available or self._session is None: + logger.debug("_tcp_request: TCP inactive → fallback") + return None - This is the single choke-point for all TCP HTTP traffic so that token - rotation is handled uniformly rather than duplicated per call-site. + headers = {**self._tcp_headers, **(extra_headers or {})} - Args: - client: Shared httpx.AsyncClient for the request. - method: HTTP method string ("GET", "POST", …). - url: Full URL to request. - extra_headers: Additional headers merged on top of auth headers - (e.g. ``{"Content-Type": "application/json"}``). - **kwargs: Passed straight through to ``client.request()``. + try: + response = await self._session.request(method, url, headers=headers, **kwargs) - Returns: - The httpx.Response (possibly from the retry attempt). - """ - headers = {**self._tcp_headers, **(extra_headers or {})} - response = await client.request(method, url, headers=headers, **kwargs) + # 🔁 AUTH RECOVERY + if response.status_code in (401, 403): + logger.debug("Auth expired → refreshing token") + new_token = await self._fetch_auth_token() + if new_token: + self._auth_token = new_token + await self._replace_session() + headers = {**self._tcp_headers, **(extra_headers or {})} + response = await self._session.request(method, url, headers=headers, **kwargs) - if response.status_code in (401, 403): - logger.debug( - f"TCP auth rejected ({response.status_code}), re-fetching token..." - ) - self._auth_token = await self._fetch_auth_token() - if self._auth_token: - headers = {**self._tcp_headers, **(extra_headers or {})} - response = await client.request(method, url, headers=headers, **kwargs) + return response - return response + except (httpx.ConnectError, httpx.RemoteProtocolError, httpx.NetworkError) as e: + await self._degrade_to_cp(str(e)) + return None + + except Exception as e: + logger.debug(f"TCP request error: {e}") + return None async def _test_connection(self) -> bool: - """Test if TCP connection to Portal is working (with auth).""" + """ + Test if TCP connection to Portal is working (authenticated). + + Uses /version which requires a valid Bearer token. This ensures + that tcp_available=True means both reachability AND auth are + confirmed — unlike /ping which returns 200 without any token. + + Intentionally uses a temporary httpx.AsyncClient (not self._session) + so that doctor.py and CLI diagnostics can call this without affecting + the persistent session lifecycle. + """ try: async with httpx.AsyncClient() as client: - response = await self._tcp_request( - client, "GET", f"{self.tcp_base_url}/ping", timeout=5 + response = await client.get( + f"{self.tcp_base_url}/version", + headers=self._tcp_headers, + timeout=_PROBE_TIMEOUT, ) return response.status_code == 200 except Exception as e: @@ -304,9 +438,7 @@ def _parse_content_provider_output( json_str = line[result_start:] try: json_data = json.loads(json_str) - # Handle nested "result" or "data" field with JSON string (backward compatible) if isinstance(json_data, dict): - # Check for 'result' first (new portal format), then 'data' (legacy) inner_key = ( "result" if "result" in json_data @@ -352,40 +484,30 @@ async def get_state(self) -> Dict[str, Any]: async def _get_state_tcp(self) -> Dict[str, Any]: """Get state via TCP.""" - try: - async with httpx.AsyncClient() as client: - response = await self._tcp_request( - client, "GET", f"{self.tcp_base_url}/state_full", timeout=10 + response = await self._tcp_request( + "GET", f"{self.tcp_base_url}/state_full", timeout=15 + ) + if response is not None and response.status_code == 200: + data = response.json() + if isinstance(data, dict): + inner_key = ( + "result" + if "result" in data + else "data" if "data" in data else None ) - if response.status_code == 200: - data = response.json() + if inner_key: + inner_value = data[inner_key] + if isinstance(inner_value, str): + try: + return json.loads(inner_value) + except json.JSONDecodeError: + pass + elif isinstance(inner_value, dict): + return inner_value + return data - # Handle nested "result" or "data" field (backward compatible) - if isinstance(data, dict): - # Check for 'result' first (new portal format), then 'data' (legacy) - inner_key = ( - "result" - if "result" in data - else "data" if "data" in data else None - ) - if inner_key: - inner_value = data[inner_key] - if isinstance(inner_value, str): - try: - return json.loads(inner_value) - except json.JSONDecodeError: - pass - elif isinstance(inner_value, dict): - return inner_value - return data - else: - logger.debug( - f"TCP get_state failed ({response.status_code}), using fallback" - ) - return await self._get_state_content_provider() - except Exception as e: - logger.debug(f"TCP get_state error: {e}, using fallback") - return await self._get_state_content_provider() + logger.debug("TCP get_state failed, using content provider fallback") + return await self._get_state_content_provider() async def _get_state_content_provider(self) -> Dict[str, Any]: """Get state via content provider (fallback).""" @@ -401,9 +523,7 @@ async def _get_state_content_provider(self) -> Dict[str, Any]: "message": "Failed to parse state data from ContentProvider", } - # Handle nested "result" or "data" field if present (backward compatible) if isinstance(state_data, dict): - # Check for 'result' first (new portal format), then 'data' (legacy) inner_key = ( "result" if "result" in state_data @@ -446,29 +566,21 @@ async def input_text(self, text: str, clear: bool = False) -> bool: async def _input_text_tcp(self, text: str, clear: bool) -> bool: """Input text via TCP.""" - try: - encoded = base64.b64encode(text.encode()).decode() - payload = {"base64_text": encoded, "clear": clear} - async with httpx.AsyncClient() as client: - response = await self._tcp_request( - client, - "POST", - f"{self.tcp_base_url}/keyboard/input", - extra_headers={"Content-Type": "application/json"}, - json=payload, - timeout=10, - ) - if response.status_code == 200: - logger.debug("TCP input_text successful") - return True - else: - logger.debug( - f"TCP input_text failed ({response.status_code}), using fallback" - ) - return await self._input_text_content_provider(text, clear) - except Exception as e: - logger.debug(f"TCP input_text error: {e}, using fallback") - return await self._input_text_content_provider(text, clear) + encoded = base64.b64encode(text.encode()).decode() + payload = {"base64_text": encoded, "clear": clear} + response = await self._tcp_request( + "POST", + f"{self.tcp_base_url}/keyboard/input", + extra_headers={"Content-Type": "application/json"}, + json=payload, + timeout=10, + ) + if response is not None and response.status_code == 200: + logger.debug("TCP input_text successful") + return True + + logger.debug("TCP input_text failed, using content provider fallback") + return await self._input_text_content_provider(text, clear) async def _input_text_content_provider(self, text: str, clear: bool) -> bool: """Input text via content provider (fallback).""" @@ -505,37 +617,26 @@ async def take_screenshot(self, hide_overlay: bool = True) -> bytes: async def _take_screenshot_tcp(self, hide_overlay: bool) -> bytes: """Take screenshot via TCP.""" - try: - url = f"{self.tcp_base_url}/screenshot" - if not hide_overlay: - url += "?hideOverlay=false" + url = f"{self.tcp_base_url}/screenshot" + if not hide_overlay: + url += "?hideOverlay=false" - async with httpx.AsyncClient() as client: - response = await self._tcp_request(client, "GET", url, timeout=10.0) - if response.status_code == 200: - data = response.json() - # Check for 'result' first (new portal format), then 'data' (legacy) - if data.get("status") == "success": - inner_key = ( - "result" - if "result" in data - else "data" if "data" in data else None - ) - if inner_key: - logger.debug("Screenshot taken via TCP") - return base64.b64decode(data[inner_key]) - logger.debug( - "TCP screenshot failed (invalid response), using fallback" - ) - return await self._take_screenshot_adb() - else: - logger.debug( - f"TCP screenshot failed ({response.status_code}), using fallback" - ) - return await self._take_screenshot_adb() - except Exception as e: - logger.debug(f"TCP screenshot error: {e}, using fallback") - return await self._take_screenshot_adb() + response = await self._tcp_request("GET", url, timeout=15.0) + + if response is not None and response.status_code == 200: + data = response.json() + if data.get("status") == "success": + inner_key = ( + "result" + if "result" in data + else "data" if "data" in data else None + ) + if inner_key: + logger.debug("Screenshot via TCP") + return base64.b64decode(data[inner_key]) + + logger.debug("TCP screenshot failed → ADB fallback") + return await self._take_screenshot_adb() async def _take_screenshot_adb(self) -> bytes: """Take screenshot via ADB screencap (fallback).""" @@ -559,7 +660,6 @@ async def get_apps(self, include_system: bool = True) -> List[Dict[str, str]]: try: logger.debug("Getting apps via content provider") - # Query content provider output = await self.device.shell( "content query --uri content://com.droidrun.portal/packages" ) @@ -569,19 +669,13 @@ async def get_apps(self, include_system: bool = True) -> List[Dict[str, str]]: logger.warning("No packages data found in content provider response") return [] - # Handle both formats: - # - New format: array directly (via RawArray -> result: [...]) - # - Legacy format: wrapped in {"packages": [...]} packages_list = None if isinstance(packages_data, list): - # New format: packages_data is already the list packages_list = packages_data elif isinstance(packages_data, dict): if "packages" in packages_data: - # Legacy format: wrapped in {"packages": [...]} packages_list = packages_data["packages"] else: - # May be wrapped in result/data inner_key = ( "result" if "result" in packages_data @@ -600,12 +694,10 @@ async def get_apps(self, include_system: bool = True) -> List[Dict[str, str]]: logger.warning("Could not extract packages list from response") return [] - # Filter and format apps apps = [] for package_info in packages_list: if not include_system and package_info.get("isSystemApp", False): continue - apps.append( { "package": package_info.get("packageName", ""), @@ -624,24 +716,19 @@ async def get_version(self) -> str: """Get Portal app version.""" await self._ensure_connected() if self.tcp_available: - try: - async with httpx.AsyncClient() as client: - response = await self._tcp_request( - client, "GET", f"{self.tcp_base_url}/version", timeout=5.0 - ) - if response.status_code == 200: - data = response.json() - # Check for 'result' first (new portal format), then 'data' (legacy) - inner_key = ( - "result" - if "result" in data - else "data" if "data" in data else None - ) - if inner_key: - return data[inner_key] - return data.get("status", "unknown") - except Exception: - pass + response = await self._tcp_request( + "GET", f"{self.tcp_base_url}/version", timeout=5.0 + ) + if response is not None and response.status_code == 200: + data = response.json() + inner_key = ( + "result" + if "result" in data + else "data" if "data" in data else None + ) + if inner_key: + return data[inner_key] + return data.get("status", "unknown") # Fallback to content provider try: @@ -650,7 +737,6 @@ async def get_version(self) -> str: ) result = self._parse_content_provider_output(output) if result: - # Check for 'result' first (new portal format), then 'data' (legacy) inner_key = ( "result" if "result" in result @@ -672,35 +758,35 @@ async def ping(self) -> Dict[str, Any]: """ await self._ensure_connected() if self.tcp_available: - try: - async with httpx.AsyncClient() as client: - response = await self._tcp_request( - client, "GET", f"{self.tcp_base_url}/ping", timeout=5.0 - ) - if response.status_code == 200: - try: - tcp_response = response.json() if response.content else {} - result = { - "status": "success", - "method": "tcp", - "url": self.tcp_base_url, - "response": tcp_response, - } - except json.JSONDecodeError: - result = { - "status": "success", - "method": "tcp", - "url": self.tcp_base_url, - "response": response.text, - } - else: - return { - "status": "error", - "method": "tcp", - "message": f"HTTP {response.status_code}: {response.text}", - } - except Exception as e: - return {"status": "error", "method": "tcp", "message": str(e)} + response = await self._tcp_request( + "GET", f"{self.tcp_base_url}/ping", timeout=5.0 + ) + if response is not None and response.status_code == 200: + try: + tcp_response = response.json() if response.content else {} + result = { + "status": "success", + "method": "tcp", + "url": self.tcp_base_url, + "response": tcp_response, + } + except json.JSONDecodeError: + result = { + "status": "success", + "method": "tcp", + "url": self.tcp_base_url, + "response": response.text if response else "", + } + else: + return { + "status": "error", + "method": "tcp", + "message": ( + f"HTTP {response.status_code}: {response.text}" + if response is not None + else "TCP not available" + ), + } else: # Test content provider try: @@ -740,4 +826,4 @@ async def ping(self) -> Dict[str, Any]: "message": f"state check failed: {e}", } - return result + return result \ No newline at end of file diff --git a/droidrun/tools/ui/provider.py b/droidrun/tools/ui/provider.py index 5caedd37..ae40b104 100644 --- a/droidrun/tools/ui/provider.py +++ b/droidrun/tools/ui/provider.py @@ -184,7 +184,11 @@ async def _recover_portal(self) -> None: new_token = await portal._fetch_auth_token() if new_token: portal._auth_token = new_token - logger.debug("Auth token refreshed after TCP server restart") + # Rebuild the persistent session so its default Authorization + # header reflects the rotated token. Without this, the session + # keeps sending the old (now-rejected) token on every request. + await portal._replace_session() + logger.debug("Auth token refreshed and session rebuilt after TCP server restart") except Exception as e: logger.debug(f"TCP server restart failed: {e}")