diff --git a/droidrun/agent/droid/droid_agent.py b/droidrun/agent/droid/droid_agent.py index a5ec8c0c..c46cf2b2 100644 --- a/droidrun/agent/droid/droid_agent.py +++ b/droidrun/agent/droid/droid_agent.py @@ -87,6 +87,7 @@ flush, ) from droidrun.tools.driver.android import AndroidDriver +from droidrun.tools.driver.android_ssh import AndroidSSHDriver from droidrun.tools.driver.base import DeviceDisconnectedError from droidrun.tools.driver.ios import IOSDriver from droidrun.tools.driver.recording import RecordingDriver @@ -407,23 +408,36 @@ async def start_handler( driver = IOSDriver(url=ios_url) await driver.connect() else: - device_serial = self.resolved_device_config.serial - if device_serial is None: - devices = await adb.list() - if not devices: - raise ValueError("No connected Android devices found.") - device_serial = devices[0].serial - - # Auto-setup portal if enabled - if self.config.device.auto_setup: - device_obj = await adb.device(serial=device_serial) - await ensure_portal_ready(device_obj, debug=self.config.logging.debug) - - driver = AndroidDriver( - serial=device_serial, - use_tcp=self.resolved_device_config.use_tcp, - ) - await driver.connect() + # Check if using SSH driver + if self.resolved_device_config.driver_type == "ssh": + driver = AndroidSSHDriver( + target=self.resolved_device_config.ssh_target, + portal_url=self.resolved_device_config.portal_url, + portal_token=self.resolved_device_config.portal_token, + su_path=self.resolved_device_config.su_path, + ) + await driver.connect() + else: + # ADB driver (default) + device_serial = self.resolved_device_config.serial + if device_serial is None: + devices = await adb.list() + if not devices: + raise ValueError("No connected Android devices found.") + device_serial = devices[0].serial + + # Auto-setup portal if enabled (skip for SSH mode) + if self.config.device.auto_setup: + device_obj = await adb.device(serial=device_serial) + await ensure_portal_ready( + device_obj, 
debug=self.config.logging.debug + ) + + driver = AndroidDriver( + serial=device_serial, + use_tcp=self.resolved_device_config.use_tcp, + ) + await driver.connect() # Wrap with StealthDriver if stealth mode enabled stealth_enabled = self.config.tools and self.config.tools.stealth diff --git a/droidrun/config_example.yaml b/droidrun/config_example.yaml index 1d3161aa..aa4db675 100644 --- a/droidrun/config_example.yaml +++ b/droidrun/config_example.yaml @@ -157,6 +157,25 @@ device: platform: android # Auto-install/update Portal APK and enable accessibility before each run auto_setup: true + + # Driver type: "adb" (default) or "ssh" + # ADB mode: Requires ADB connection (USB or TCP) + # SSH mode: Connects via SSH to device, requires Portal HTTP server + driver_type: adb + + # SSH driver configuration (only used when driver_type: ssh) + # SSH target: hostname from ~/.ssh/config or user@ip + ssh_target: null # e.g., "redmi9" or "user@192.168.1.100" + + # Portal HTTP server URL (only used when driver_type: ssh) + portal_url: null # e.g., "http://192.168.1.100:8080" + + # Portal Bearer token for authentication (only used when driver_type: ssh) + portal_token: "" # e.g., "YOUR_TOKEN" + + # Path to su binary on device (only used when driver_type: ssh) + # Default: "/debug_ramdisk/su" for rooted devices with debug_ramdisk + su_path: /debug_ramdisk/su # === Telemetry Settings === telemetry: diff --git a/droidrun/config_manager/config_manager.py b/droidrun/config_manager/config_manager.py index 561f9228..6a36221e 100644 --- a/droidrun/config_manager/config_manager.py +++ b/droidrun/config_manager/config_manager.py @@ -124,6 +124,17 @@ class DeviceConfig: platform: str = "android" # "android" or "ios" auto_setup: bool = True # auto-install/fix portal before each run + # SSH driver configuration + driver_type: str = "adb" # "adb" or "ssh" + ssh_target: Optional[str] = ( + None # SSH target (e.g., "redmi9" or "user@192.168.1.100") + ) + portal_url: Optional[str] = ( + None # 
Portal HTTP URL (e.g., "http://192.168.1.100:8080") + ) + portal_token: str = "" # Portal Bearer token + su_path: str = "/debug_ramdisk/su" # Path to su binary on device + @dataclass class TelemetryConfig: diff --git a/droidrun/tools/__init__.py b/droidrun/tools/__init__.py index ba03239f..ce12d14a 100644 --- a/droidrun/tools/__init__.py +++ b/droidrun/tools/__init__.py @@ -4,13 +4,19 @@ from droidrun.tools import AndroidDriver, RecordingDriver, UIState, StateProvider """ -from droidrun.tools.driver import AndroidDriver, DeviceDriver, RecordingDriver +from droidrun.tools.driver import ( + AndroidDriver, + DeviceDriver, + RecordingDriver, + AndroidSSHDriver, +) from droidrun.tools.ui import AndroidStateProvider, StateProvider, UIState __all__ = [ "DeviceDriver", "AndroidDriver", "RecordingDriver", + "AndroidSSHDriver", "UIState", "StateProvider", "AndroidStateProvider", diff --git a/droidrun/tools/android/__init__.py b/droidrun/tools/android/__init__.py index 22a60da5..3d9ac134 100644 --- a/droidrun/tools/android/__init__.py +++ b/droidrun/tools/android/__init__.py @@ -1,5 +1,6 @@ """Android tools.""" from .portal_client import PortalClient +from .portal_client_http import PortalClientHTTP -__all__ = ["PortalClient"] +__all__ = ["PortalClient", "PortalClientHTTP"] diff --git a/droidrun/tools/android/portal_client_http.py b/droidrun/tools/android/portal_client_http.py new file mode 100644 index 00000000..4f8dc17f --- /dev/null +++ b/droidrun/tools/android/portal_client_http.py @@ -0,0 +1,335 @@ +""" +Portal Client HTTP - Direct HTTP communication layer for DroidRun Portal app. + +Simplified client that communicates directly via HTTP without ADB or TCP port forwarding. +Requires a Bearer token for authentication. 
import base64
import json
import logging
from typing import Any, Dict, List, Optional

logger = logging.getLogger("droidrun")


class PortalClientHTTP:
    """HTTP-only client for DroidRun Portal communication.

    Talks directly to the Portal HTTP server using a base URL and a Bearer
    token; no ADB or TCP port forwarding is involved.

    Usage::

        client = PortalClientHTTP("http://192.168.1.100:8080", token="YOUR_TOKEN")
        await client.connect()
        state = await client.get_state()
    """

    def __init__(self, base_url: str, token: str) -> None:
        """Initialize the Portal HTTP client.

        Args:
            base_url: Base URL of the Portal HTTP server,
                e.g. "http://192.168.1.100:8080". A trailing slash is stripped.
            token: Bearer token sent in the ``Authorization`` header.

        Raises:
            ValueError: If ``base_url`` is empty or ``None``.
        """
        # Fail early with a clear message instead of an AttributeError on
        # ``None.rstrip`` when the URL was left unset in the configuration.
        if not base_url:
            raise ValueError(
                "Portal base_url is required, e.g. 'http://192.168.1.100:8080'"
            )
        self.base_url = base_url.rstrip("/")
        self.token = token
        self._headers = {"Authorization": f"Bearer {token}"}
        self._connected = False

    # -- transport -----------------------------------------------------------

    async def _request(
        self, method: str, path: str, *, timeout: float, **kwargs: Any
    ) -> "httpx.Response":
        """Send one authenticated request to the Portal and return the response.

        Args:
            method: HTTP method, e.g. "GET" or "POST".
            path: Path (optionally with a query string) appended to ``base_url``.
            timeout: Request timeout in seconds.
            **kwargs: Extra arguments forwarded to ``httpx.AsyncClient.request``
                (for example ``json=``).
        """
        # Imported here so this class stays importable on its own; httpx is
        # only needed once a request is actually sent.
        import httpx

        async with httpx.AsyncClient() as client:
            return await client.request(
                method,
                f"{self.base_url}{path}",
                headers=self._headers,
                timeout=timeout,
                **kwargs,
            )

    # -- lifecycle -----------------------------------------------------------

    async def connect(self) -> None:
        """Verify the Portal is reachable; a no-op when already connected.

        Raises:
            ConnectionError: If the ``/ping`` probe does not return HTTP 200.
        """
        if self._connected:
            return

        if not await self._test_connection():
            raise ConnectionError(
                f"Failed to connect to Portal at {self.base_url}. "
                "Check the URL and token."
            )

        self._connected = True
        logger.debug("✓ Connected to Portal HTTP server: %s", self.base_url)

    async def _ensure_connected(self) -> None:
        """Connect if not already connected."""
        if not self._connected:
            await self.connect()

    async def _test_connection(self) -> bool:
        """Return True when ``GET /ping`` answers with HTTP 200."""
        try:
            response = await self._request("GET", "/ping", timeout=5)
            return response.status_code == 200
        except Exception as error:
            # Best-effort probe: any failure simply means "not reachable".
            logger.debug("Portal connection test failed: %s", error)
            return False

    # -- response helpers ----------------------------------------------------

    def _extract_inner_value(self, data: Dict[str, Any]) -> Any:
        """Unwrap the Portal response envelope.

        Portal wraps responses in either ``{"result": ...}`` (new format) or
        ``{"data": ...}`` (legacy format); ``result`` wins when both exist.
        A string payload is parsed as JSON when possible and otherwise
        returned unchanged. A dict without either key is returned as-is.
        """
        for key in ("result", "data"):
            if key not in data:
                continue
            inner_value = data[key]
            if isinstance(inner_value, str):
                try:
                    return json.loads(inner_value)
                except json.JSONDecodeError:
                    return inner_value
            return inner_value
        return data

    @staticmethod
    def _decode_screenshot(data: Any) -> bytes:
        """Decode the base64 image out of a ``/screenshot`` JSON response.

        Raises:
            RuntimeError: If the response is not a successful envelope with a
                ``result`` or ``data`` payload.
        """
        if isinstance(data, dict) and data.get("status") == "success":
            for key in ("result", "data"):
                if key in data:
                    return base64.b64decode(data[key])

        # Keep the error readable: a malformed response may still carry a
        # very large payload.
        summary = repr(data)
        if len(summary) > 200:
            summary = summary[:200] + "..."
        raise RuntimeError(f"Invalid screenshot response: {summary}")

    @staticmethod
    def _normalize_packages(payload: Any) -> Optional[List[Any]]:
        """Return the package list from a ``/packages`` payload, or None.

        Accepts either a bare list or a dict with a ``packages`` list.
        """
        if isinstance(payload, list):
            return payload
        if isinstance(payload, dict):
            packages = payload.get("packages")
            if isinstance(packages, list):
                return packages
        return None

    # -- API -----------------------------------------------------------------

    async def get_state(self) -> Dict[str, Any]:
        """Get device state from ``GET /state_full``.

        Returns:
            The unwrapped state dictionary on success. On any failure a
            dictionary with ``error`` and ``message`` keys is returned
            instead of raising (connection setup may still raise
            ``ConnectionError``).
        """
        await self._ensure_connected()
        try:
            response = await self._request("GET", "/state_full", timeout=10)
            if response.status_code != 200:
                return {
                    "error": f"HTTP {response.status_code}",
                    "message": response.text,
                }
            data = response.json()
        except Exception as error:
            return {"error": "HTTP Error", "message": str(error)}

        state = self._extract_inner_value(data) if isinstance(data, dict) else data
        if not isinstance(state, dict):
            # Callers index the state by key; never hand them a list/string.
            return {
                "error": "Invalid response",
                "message": (
                    "Expected a JSON object for device state, "
                    f"got {type(state).__name__}"
                ),
            }
        return state

    async def input_text(self, text: str, clear: bool = False) -> bool:
        """Input text via the Portal keyboard (``POST /keyboard/input``).

        Args:
            text: Text to input; sent base64-encoded (UTF-8).
            clear: Whether to clear existing text first.

        Returns:
            True if the Portal answered HTTP 200, False otherwise.
        """
        await self._ensure_connected()
        try:
            encoded = base64.b64encode(text.encode("utf-8")).decode()
            payload = {"base64_text": encoded, "clear": clear}
            response = await self._request(
                "POST", "/keyboard/input", timeout=10, json=payload
            )
        except Exception as error:
            logger.error("input_text error: %s", error)
            return False

        if response.status_code == 200:
            logger.debug("input_text successful")
            return True
        logger.warning("input_text failed: HTTP %s", response.status_code)
        return False

    async def take_screenshot(self, hide_overlay: bool = True) -> bytes:
        """Take a screenshot of the device via ``GET /screenshot``.

        Args:
            hide_overlay: When False, ``hideOverlay=false`` is added to the
                request so the Portal overlay stays visible.

        Returns:
            The base64-decoded image bytes from the response.

        Raises:
            RuntimeError: On a non-200 status, an invalid response, or any
                transport/decoding failure.
        """
        await self._ensure_connected()
        path = "/screenshot" if hide_overlay else "/screenshot?hideOverlay=false"
        try:
            response = await self._request("GET", path, timeout=10.0)
            if response.status_code != 200:
                raise RuntimeError(f"Screenshot failed: HTTP {response.status_code}")
            image = self._decode_screenshot(response.json())
        except RuntimeError:
            # Already the documented exception type; do not wrap it twice.
            raise
        except Exception as error:
            raise RuntimeError(f"take_screenshot error: {error}") from error

        logger.debug("Screenshot taken via HTTP")
        return image

    async def get_apps(self, include_system: bool = True) -> List[Dict[str, str]]:
        """Get installed apps with package name and label (``GET /packages``).

        Args:
            include_system: When False, entries flagged ``isSystemApp`` are
                skipped.

        Returns:
            List of dicts with ``package`` and ``label`` keys. Empty when the
            response carries no recognisable package list.

        Raises:
            ValueError: On a non-200 status or any transport/parsing failure.
        """
        await self._ensure_connected()
        try:
            response = await self._request("GET", "/packages", timeout=15)
            if response.status_code != 200:
                raise RuntimeError(f"get_apps failed: HTTP {response.status_code}")
            data = response.json()
        except Exception as error:
            logger.error("get_apps error: %s", error)
            raise ValueError(f"Error getting apps: {error}") from error

        payload = self._extract_inner_value(data) if isinstance(data, dict) else data
        packages = self._normalize_packages(payload)
        if packages is None:
            # Only warn when the shape is wrong, not when the list is empty.
            logger.warning("Could not extract packages list from response")
            return []

        apps = [
            {
                "package": info.get("packageName", ""),
                "label": info.get("label", ""),
            }
            for info in packages
            if isinstance(info, dict)
            and (include_system or not info.get("isSystemApp", False))
        ]
        logger.debug("Found %d apps", len(apps))
        return apps

    async def get_version(self) -> str:
        """Get the Portal app version from ``GET /version``.

        Returns:
            The unwrapped value when it is a string; otherwise the response's
            ``status`` field; ``"unknown"`` on any failure.
        """
        await self._ensure_connected()
        try:
            response = await self._request("GET", "/version", timeout=5.0)
            if response.status_code == 200:
                data = response.json()
                if isinstance(data, dict):
                    inner = self._extract_inner_value(data)
                    if isinstance(inner, str):
                        return inner
                    return data.get("status", "unknown")
        except Exception as error:
            logger.debug("get_version error: %s", error)

        return "unknown"

    async def ping(self) -> Dict[str, Any]:
        """Test the Portal connection and verify state availability.

        Sends ``GET /ping`` and then checks that ``get_state()`` contains the
        ``a11y_tree``, ``phone_state`` and ``device_context`` keys.

        Returns:
            ``{"status": "success", "method": "http", "url": ...,
            "response": ...}`` on success, otherwise
            ``{"status": "error", "method": "http", "message": ...}``.
        """
        await self._ensure_connected()
        try:
            response = await self._request("GET", "/ping", timeout=5.0)
            if response.status_code != 200:
                return {
                    "status": "error",
                    "method": "http",
                    "message": f"HTTP {response.status_code}: {response.text}",
                }

            try:
                ping_response = response.json() if response.content else {}
            except json.JSONDecodeError:
                ping_response = response.text
        except Exception as error:
            return {"status": "error", "method": "http", "message": str(error)}

        result: Dict[str, Any] = {
            "status": "success",
            "method": "http",
            "url": self.base_url,
            "response": ping_response,
        }

        # Verify state has the required keys
        try:
            state = await self.get_state()
        except Exception as error:
            return {
                "status": "error",
                "method": "http",
                "message": f"state check failed: {error}",
            }

        required = ("a11y_tree", "phone_state", "device_context")
        missing = [key for key in required if key not in state]
        if missing and "error" in state:
            # get_state() reports failures as an error dict; surface the real
            # cause instead of mislabelling it as an incompatible portal.
            return {
                "status": "error",
                "method": "http",
                "message": (
                    f"state check failed: {state['error']}: "
                    f"{state.get('message', '')}"
                ),
            }
        if missing:
            return {
                "status": "error",
                "method": "http",
                "message": f"incompatible portal — missing {', '.join(missing)}",
            }

        return result
"""AndroidSSHDriver — SSH-based device driver.

Wraps ``SSHDevice`` + ``PortalClientHTTP`` to provide clean device I/O
via SSH shell commands and direct HTTP Portal communication.
No ADB required.
"""

import asyncio
import io
import logging
import os
import shlex
import subprocess
from typing import Any, Callable, Dict, List, Optional, Union

from droidrun.tools.android.portal_client_http import PortalClientHTTP
from droidrun.tools.driver.base import DeviceDriver

logger = logging.getLogger("droidrun")


def _list_to_cmdline(args: Union[list, tuple]) -> str:
    """Join arguments into one shell command string.

    Each argument is escaped with ``shlex.quote`` (unlike
    ``subprocess.list2cmdline``, which targets Windows quoting rules).
    """
    return " ".join(map(shlex.quote, args))


async def _run_blocking(func: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Run a blocking callable in the default executor and await its result.

    ``SSHDevice`` uses ``subprocess.run``; calling it directly from a
    coroutine would stall the event loop for the whole SSH round trip.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, lambda: func(*args, **kwargs))


class SSHDevice:
    """Thin synchronous wrapper that runs device commands over ``ssh``.

    Every command is executed on the remote device through a configurable
    ``su`` binary (``<su_path> -c <command>``).
    """

    def __init__(self, target: str, su_path: str = "/debug_ramdisk/su") -> None:
        """
        Args:
            target: SSH target, e.g. "redmi9" or "user@192.168.1.100"
            su_path: Path to su binary on device (default: "/debug_ramdisk/su")
        """
        self.target = target
        self.su_path = su_path

    def shell(
        self,
        cmdargs: Union[str, list, tuple],
        encoding: Optional[str] = "utf-8",
        timeout: Optional[float] = None,
    ) -> Union[str, bytes]:
        """Execute a shell command on the remote device and return its stdout.

        Args:
            cmdargs: Command string, or a list/tuple of arguments that is
                quoted and joined into a command string.
            encoding: Output encoding. Pass None to get raw bytes.
            timeout: Optional limit in seconds for the whole ``ssh`` call;
                None (default) waits indefinitely.

        Returns:
            Command stdout as str (if encoding set) or bytes (if
            encoding=None). Returns an empty string/bytes on failure
            (non-zero exit, timeout, or ``ssh`` not runnable).
        """
        if isinstance(cmdargs, (list, tuple)):
            cmdargs = _list_to_cmdline(cmdargs)

        # ssh joins its trailing arguments with spaces and hands the result
        # to the remote login shell, which splits it again. Send a single,
        # fully quoted string so ``su -c`` receives the command as ONE
        # argument and metacharacters are interpreted under su, not before.
        remote_command = f"{shlex.quote(self.su_path)} -c {shlex.quote(cmdargs)}"
        cmd = ["ssh", self.target, remote_command]
        empty: Union[str, bytes] = "" if encoding else b""

        try:
            result = subprocess.run(
                cmd,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                timeout=timeout,
            )
        except subprocess.CalledProcessError as error:
            stderr_text = (
                error.stderr.decode("utf-8", errors="replace").strip()
                if error.stderr
                else ""
            )
            logger.debug(
                "SSH shell command failed: %s%s",
                error,
                f" | stderr: {stderr_text}" if stderr_text else "",
            )
            return empty
        except (subprocess.SubprocessError, OSError) as error:
            # Timeout, or the ssh executable itself could not be started.
            logger.debug("SSH shell command failed: %s", error)
            return empty

        if encoding:
            # Never let undecodable device output escape as an exception.
            return result.stdout.decode(encoding, errors="replace")
        return result.stdout

    def click(self, x: int, y: int) -> None:
        """Tap at ``(x, y)`` via ``input tap``."""
        logger.debug("SSHDevice click at (%s, %s)", x, y)
        self.shell(["input", "tap", str(x), str(y)])

    def swipe(
        self,
        start_x: int,
        start_y: int,
        end_x: int,
        end_y: int,
        duration_seconds: float = 1.0,
    ) -> None:
        """Swipe between two points via ``input swipe``.

        ``duration_seconds`` is converted to whole milliseconds.
        """
        logger.debug(
            "SSHDevice swipe from (%s, %s) to (%s, %s)",
            start_x,
            start_y,
            end_x,
            end_y,
        )
        duration_ms = str(int(duration_seconds * 1000))
        self.shell(
            [
                "input",
                "swipe",
                str(start_x),
                str(start_y),
                str(end_x),
                str(end_y),
                duration_ms,
            ]
        )

    def drag(
        self,
        start_x: int,
        start_y: int,
        end_x: int,
        end_y: int,
        duration_seconds: float = 1.0,
    ) -> None:
        """Drag between two points via ``input draganddrop``.

        ``duration_seconds`` is converted to whole milliseconds.
        """
        logger.debug(
            "SSHDevice drag from (%s, %s) to (%s, %s)",
            start_x,
            start_y,
            end_x,
            end_y,
        )
        duration_ms = str(int(duration_seconds * 1000))
        self.shell(
            [
                "input",
                "draganddrop",
                str(start_x),
                str(start_y),
                str(end_x),
                str(end_y),
                duration_ms,
            ]
        )

    def keyevent(self, keycode: Union[int, str]) -> None:
        """Send a key event via ``input keyevent``."""
        self.shell(["input", "keyevent", str(keycode)])

    def app_start(self, package_name: str, activity: Optional[str] = None) -> None:
        """Start an app via ``am start -n`` (with activity) or ``monkey``."""
        if activity:
            self.shell(["am", "start", "-n", f"{package_name}/{activity}"])
        else:
            self.shell(
                [
                    "monkey",
                    "-p",
                    package_name,
                    "-c",
                    "android.intent.category.LAUNCHER",
                    "1",
                ]
            )

    def screenshot_bytes(self) -> bytes:
        """Run ``screencap -p`` and return its raw output (empty on failure)."""
        png_bytes = self.shell(["screencap", "-p"], encoding=None)
        return png_bytes

    def force_stop(self, package_name: str) -> None:
        """Force stop an app via ``am force-stop``."""
        self.shell(["am", "force-stop", package_name])


class AndroidSSHDriver(DeviceDriver):
    """Android device driver using SSH for shell commands and HTTP for Portal."""

    supported = {
        "tap",
        "swipe",
        "input_text",
        "press_key",
        "start_app",
        "screenshot",
        "get_ui_tree",
        "get_date",
        "get_apps",
        "list_packages",
        "install_app",
        "drag",
    }

    def __init__(
        self,
        target: str,
        portal_url: str,
        portal_token: str,
        su_path: str = "/debug_ramdisk/su",
    ) -> None:
        """
        Args:
            target: SSH target, e.g. "redmi9" or "user@192.168.1.100"
            portal_url: Base URL of the Portal HTTP server, e.g. "http://192.168.1.100:8080"
            portal_token: Bearer token for Portal authentication
            su_path: Path to su binary on device (default: "/debug_ramdisk/su")

        Raises:
            ValueError: If ``target`` or ``portal_url`` is empty or ``None``.
        """
        # Reject unset values up front rather than failing later with an
        # opaque error deep inside ssh or the HTTP client.
        if not target:
            raise ValueError(
                "SSH driver requires a target, e.g. 'redmi9' or 'user@192.168.1.100'"
            )
        if not portal_url:
            raise ValueError(
                "SSH driver requires a portal_url, e.g. 'http://192.168.1.100:8080'"
            )

        self._target = target
        self._portal_url = portal_url
        self._portal_token = portal_token
        self._su_path = su_path
        self.device: Optional[SSHDevice] = None
        self.portal: Optional[PortalClientHTTP] = None
        self._connected = False
        # No background connect here: scheduling a task from __init__ needs a
        # running event loop, races an explicit ``await connect()``, and hides
        # connection errors in an unretrieved task. Every public method calls
        # ``ensure_connected()``, so the driver still connects lazily.

    # -- lifecycle -----------------------------------------------------------

    async def connect(self) -> None:
        """Create the SSH device wrapper and connect the Portal HTTP client.

        A no-op when already connected. Errors from the Portal connection
        attempt propagate to the caller.
        """
        if self._connected:
            return

        self.device = SSHDevice(self._target, self._su_path)
        self.portal = PortalClientHTTP(self._portal_url, self._portal_token)
        await self.portal.connect()

        self._connected = True
        logger.debug(
            "AndroidSSHDriver connected: target=%s, portal=%s, su_path=%s",
            self._target,
            self._portal_url,
            self._su_path,
        )

    async def ensure_connected(self) -> None:
        """Connect if not already connected."""
        if not self._connected:
            await self.connect()

    # -- input actions -------------------------------------------------------

    async def tap(self, x: int, y: int) -> None:
        """Tap at screen coordinates ``(x, y)``."""
        await self.ensure_connected()
        await _run_blocking(self.device.click, x, y)

    async def swipe(
        self,
        x1: int,
        y1: int,
        x2: int,
        y2: int,
        duration_ms: float = 1000,
    ) -> None:
        """Swipe from ``(x1, y1)`` to ``(x2, y2)`` over ``duration_ms`` ms."""
        await self.ensure_connected()
        await _run_blocking(self.device.swipe, x1, y1, x2, y2, duration_ms / 1000)
        # NOTE(review): the remote ``input swipe`` presumably already takes
        # about ``duration_ms`` to return; confirm this extra settle delay
        # is intended.
        await asyncio.sleep(duration_ms / 1000)

    async def input_text(self, text: str, clear: bool = False) -> bool:
        """Type ``text`` through the Portal keyboard; returns success."""
        await self.ensure_connected()
        return await self.portal.input_text(text, clear)

    async def press_key(self, keycode: int) -> None:
        """Send an Android key event."""
        await self.ensure_connected()
        await _run_blocking(self.device.keyevent, keycode)

    async def drag(
        self,
        x1: int,
        y1: int,
        x2: int,
        y2: int,
        duration: float = 3.0,
    ) -> None:
        """Drag from ``(x1, y1)`` to ``(x2, y2)`` over ``duration`` seconds."""
        await self.ensure_connected()
        await _run_blocking(self.device.drag, x1, y1, x2, y2, duration)

    # -- app management ------------------------------------------------------

    async def start_app(self, package: str, activity: Optional[str] = None) -> str:
        """Start an app, resolving its launcher activity when not given.

        Returns:
            A human-readable success or failure message (never raises for
            launch problems).
        """
        await self.ensure_connected()
        try:
            logger.debug("Starting app %s with activity %s", package, activity)
            if not activity:
                # Pass as a list so the package name is shell-quoted.
                resolve_output = await _run_blocking(
                    self.device.shell,
                    ["cmd", "package", "resolve-activity", "--brief", package],
                )
                lines = resolve_output.strip().splitlines()
                if len(lines) < 2:
                    raise ValueError(
                        f"Unexpected resolve-activity output: {resolve_output!r}"
                    )
                # The second line is expected to be "<package>/<activity>".
                _, separator, activity = lines[1].strip().partition("/")
                if not separator or not activity:
                    raise ValueError(
                        f"Unexpected resolve-activity output: {resolve_output!r}"
                    )

            logger.debug("Activity: %s", activity)
            await _run_blocking(self.device.app_start, package, activity)
            logger.debug("App started: %s with activity %s", package, activity)
            return f"App started: {package} with activity {activity}"
        except Exception as error:
            return f"Failed to start app {package}: {error}"

    async def install_app(self, path: str, **kwargs) -> str:
        """Install an APK from a local path by copying it to the device via SCP then installing.

        Args:
            path: Local path to the APK file
            kwargs:
                reinstall (bool): Reinstall if already installed (default False)
                grant_permissions (bool): Grant all permissions (default True)

        Returns:
            The stripped ``pm install`` output, or a failure message.
        """
        await self.ensure_connected()
        if not os.path.exists(path):
            return f"Failed to install app: APK file not found at {path}"

        reinstall = kwargs.get("reinstall", False)
        grant_permissions = kwargs.get("grant_permissions", True)

        remote_path = f"/data/local/tmp/{os.path.basename(path)}"

        # Copy APK to device via SCP
        logger.debug(
            "Copying APK to device: %s -> %s:%s", path, self._target, remote_path
        )
        scp_cmd = ["scp", path, f"{self._target}:{remote_path}"]
        try:
            await _run_blocking(
                subprocess.run,
                scp_cmd,
                check=True,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
        except (subprocess.CalledProcessError, OSError) as error:
            # OSError covers a missing/unrunnable scp executable.
            return f"Failed to copy APK to device: {error}"

        # Build pm install as an argument list so the remote path is quoted
        # (APK file names may contain spaces or shell metacharacters).
        install_args: List[str] = ["pm", "install"]
        if reinstall:
            install_args.append("-r")
        if grant_permissions:
            install_args.append("-g")
        install_args.append(remote_path)

        logger.debug("Installing APK: %s", _list_to_cmdline(install_args))
        try:
            result = await _run_blocking(self.device.shell, install_args)
        finally:
            # Clean up remote APK even if the install step blew up.
            await _run_blocking(self.device.shell, ["rm", "-f", remote_path])

        logger.debug("Install result: %s", result)
        output = result.strip()
        if not output:
            # shell() returns "" on any failure; make that visible.
            return "Failed to install app: pm install produced no output"
        return output

    async def get_apps(self, include_system: bool = True) -> List[Dict[str, str]]:
        """Return installed apps (package + label) as reported by the Portal."""
        await self.ensure_connected()
        return await self.portal.get_apps(include_system)

    async def list_packages(self, include_system: bool = False) -> List[str]:
        """List package names via ``pm list packages``.

        Passes ``-3`` to ``pm`` unless ``include_system`` is True.
        """
        await self.ensure_connected()
        command = ["pm", "list", "packages"]
        if not include_system:
            command.append("-3")
        output = await _run_blocking(self.device.shell, command)

        prefix = "package:"
        return [
            line.strip()[len(prefix) :]
            for line in output.strip().splitlines()
            if line.strip().startswith(prefix)
        ]

    # -- state / observation -------------------------------------------------

    async def screenshot(self, hide_overlay: bool = True) -> bytes:
        """Return screenshot bytes fetched through the Portal."""
        await self.ensure_connected()
        return await self.portal.take_screenshot(hide_overlay)

    async def get_ui_tree(self) -> Dict[str, Any]:
        """Return the device state dictionary fetched through the Portal."""
        await self.ensure_connected()
        return await self.portal.get_state()

    async def get_date(self) -> str:
        """Return the stripped output of the device's ``date`` command."""
        await self.ensure_connected()
        result = await _run_blocking(self.device.shell, "date")
        return result.strip()

    # -- element search ------------------------------------------------------

    def _match_text(self, text: str, pattern: str, match_type: str) -> bool:
        """Match text against pattern with specified match type.

        Args:
            text: Text to match
            pattern: Pattern to match against
            match_type: Match type ("contains", "exact", "startswith",
                "endswith"); anything else is treated as "contains".

        Returns:
            True if match successful, False otherwise. Empty/None text or
            pattern never matches.
        """
        if not text or not pattern:
            return False
        text = str(text)
        pattern = str(pattern)

        if match_type == "exact":
            return text == pattern
        if match_type == "startswith":
            return text.startswith(pattern)
        if match_type == "endswith":
            return text.endswith(pattern)
        return pattern in text  # contains (default)

    def _search_node(
        self, node: Dict[str, Any], search_conditions: Dict[str, Any], match_type: str
    ) -> List[Dict[str, Any]]:
        """Recursively collect a node and its descendants that match.

        A node matches when every key in ``search_conditions`` is present
        (non-None) on the node and satisfies its pattern: string values use
        ``_match_text``; any other value must equal the pattern exactly.

        Args:
            node: A11y node to search
            search_conditions: Dictionary of conditions to match
            match_type: Match type for string comparisons

        Returns:
            Matching nodes, parent before children.
        """

        def satisfies(key: str, pattern: Any) -> bool:
            node_value = node.get(key)
            if node_value is None:
                return False
            if isinstance(node_value, str):
                return self._match_text(node_value, pattern, match_type)
            # For non-string values, use exact match
            return node_value == pattern

        matched: List[Dict[str, Any]] = []
        if all(satisfies(key, pattern) for key, pattern in search_conditions.items()):
            matched.append(node)

        # Recursively search children
        children = node.get("children", [])
        if isinstance(children, list):
            for child in children:
                if isinstance(child, dict):
                    matched.extend(
                        self._search_node(child, search_conditions, match_type)
                    )

        return matched

    async def search_element(
        self,
        search_conditions: Dict[str, Any],
        match_type: str = "contains",
    ) -> List[Dict[str, Any]]:
        """Search for UI elements matching conditions.

        Args:
            search_conditions: Dictionary of conditions to match, e.g.:
                {"text": "Submit", "className": "android.widget.Button"}
            match_type: Match type for string comparisons:
                - "contains": Contains pattern (default)
                - "exact": Exact match
                - "startswith": Starts with pattern
                - "endswith": Ends with pattern

        Returns:
            List of matching UI elements; empty when the state has no usable
            ``a11y_tree`` or the search fails.
        """
        await self.ensure_connected()

        matched_elements: List[Dict[str, Any]] = []

        try:
            state = await self.portal.get_state()
            tree = state.get("a11y_tree") if isinstance(state, dict) else None

            # Accept a single root node or a list of root nodes.
            if isinstance(tree, dict):
                roots = [tree]
            elif isinstance(tree, list):
                roots = [node for node in tree if isinstance(node, dict)]
            else:
                logger.warning(
                    "search_element: state has no usable 'a11y_tree' (%s)",
                    state.get("error", "missing") if isinstance(state, dict) else state,
                )
                return matched_elements

            for root in roots:
                matched_elements.extend(
                    self._search_node(root, search_conditions, match_type)
                )

            logger.info(f"Found {len(matched_elements)} matching elements")
            if not matched_elements:
                logger.debug("%s", tree)

        except Exception as error:
            logger.error("解析 a11y_tree 失败:%s", error)

        return matched_elements

    # -- element actions -----------------------------------------------------

    async def tap_element_relative(
        self,
        element: Dict[str, Any],
        position: str = "center",
    ) -> str:
        """Tap on a UI element at a specific relative position.

        Args:
            element: UI element dict containing boundsInScreen and other properties
            position: One of "center" (default), "top", "bottom", "left", "right"

        Returns:
            Result message describing the tap action, or an "Error: ..."
            message (this method does not raise for tap failures).
        """
        await self.ensure_connected()

        try:
            # Get bounds from boundsInScreen
            bounds_in_screen = element.get("boundsInScreen")
            if not bounds_in_screen:
                return "Error: Element has no boundsInScreen and cannot be tapped."

            # Extract coordinates from boundsInScreen dict
            left = bounds_in_screen.get("left")
            top = bounds_in_screen.get("top")
            right = bounds_in_screen.get("right")
            bottom = bounds_in_screen.get("bottom")

            if any(v is None for v in (left, top, right, bottom)):
                return f"Error: Invalid boundsInScreen format for element: {bounds_in_screen}"

            if position not in ("center", "top", "bottom", "left", "right"):
                return f"Error: Unknown position '{position}', must be one of center/top/bottom/left/right"

            # Start from the centre, then move one axis to the requested
            # edge with a 2px inset so the tap stays inside the element.
            x = (left + right) // 2
            y = (top + bottom) // 2
            if position == "top":
                y = top + 2
            elif position == "bottom":
                y = bottom - 2
            elif position == "left":
                x = left + 2
            elif position == "right":
                x = right - 2

            logger.debug(
                "Tapping element at position '%s' (coordinates: %s, %s)",
                position,
                x,
                y,
            )

            # Tap the element
            await self.tap(x, y)
            await asyncio.sleep(0.5)

            response_parts = [
                "Tapped element",
                f"Position: {position}",
                f"Coordinates: ({x}, {y})",
                f"Text: '{element.get('text', 'No text')}'",
                f"Class: {element.get('className', 'Unknown class')}",
            ]
            return " | ".join(response_parts)

        except Exception as e:
            return f"Error: {str(e)}"