@telemetry.record_call(exclude={"screenshot"})
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def locate_all_elements(
    self,
    screenshot: Optional[InputSource] = None,
    model: ModelComposition | None = None,
) -> list[DetectedElement]:
    """Detect every element visible on the screen or in a given image.

    Args:
        screenshot (InputSource | None, optional): The image to analyze. Can
            be a path to an image file, a PIL Image object or a data URL.
            If `None`, a screenshot is taken through the agent OS.
        model (ModelComposition | None, optional): The model composition to
            use for the detection. If falsy, `ModelName.ASKUI` is used.

    Returns:
        list[DetectedElement]: The detected elements as returned by the
            model router.

    Example:
        ```python
        from askui import VisionAgent

        with VisionAgent() as agent:
            detected_elements = agent.locate_all_elements()
            print(f"Found {len(detected_elements)} elements: {detected_elements}")
        ```
    """
    # Only grab a live screenshot when the caller did not supply an image.
    if screenshot is None:
        source = self._agent_os.screenshot()
    else:
        source = screenshot
    image = load_image_source(source)

    selected_model = model or ModelName.ASKUI
    return self._model_router.locate_all_elements(
        image=image,
        model=selected_model,
    )
@override
def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Detect all elements in `image` through the inference API.

    Posts the image (encoded as a data URL) together with the fixed
    instruction "get all elements" to the `/inference` endpoint and converts
    each entry of the response's `data.detected_elements` list into a
    `DetectedElement`.

    Args:
        image (ImageSource): The image to analyze.
        model (ModelComposition | str): When a `ModelComposition` is given it
            is serialized (by alias) into the request as `modelComposition`;
            a plain string model name adds nothing to the request body.

    Returns:
        list[DetectedElement]: One entry per element reported by the API.

    Raises:
        AssertionError: If the response `type` is not `"DETECTED_ELEMENTS"`.
    """
    request_body: dict[str, Any] = {
        "image": image.to_data_url(),
        "instruction": "get all elements",
    }

    if isinstance(model, ModelComposition):
        composition = model.model_dump(by_alias=True)
        request_body["modelComposition"] = composition
        logger.debug(
            "Model composition",
            extra={"modelComposition": json_lib.dumps(composition)},
        )

    response = self._inference_api.post(path="/inference", json=request_body)
    content = response.json()
    content_type = content["type"]
    # Raise explicitly instead of using an `assert` statement: asserts are
    # removed when Python runs with -O, which would let an unexpected payload
    # slip through unchecked. The exception type is kept as AssertionError so
    # existing callers observe the same error.
    if content_type != "DETECTED_ELEMENTS":
        error_msg = f"Received unknown content type {content_type}"
        raise AssertionError(error_msg)

    return [
        DetectedElement.from_json(element)
        for element in content["data"]["detected_elements"]
    ]
class DetectedElement(BaseModel):
    """An element detected in an image: its name, text and bounding box.

    Unknown fields passed to the model are ignored (`extra="ignore"`).
    """

    model_config = ConfigDict(
        extra="ignore",
    )

    # Name of the detected element.
    name: str
    # Text value of the detected element.
    text: str
    # Rectangle enclosing the element.
    bounding_box: BoundingBox

    @classmethod
    def from_json(
        cls, data: dict[str, str | float | dict[str, float]]
    ) -> "DetectedElement":
        """Build an instance from a raw detection dictionary.

        Implemented as a classmethod (alternate constructor) so subclasses
        construct instances of themselves; calling it on the class works
        exactly as before.

        Args:
            data: Mapping with the keys `"name"`, `"text"` and `"bndbox"`.
                `"name"` and `"text"` are converted with `str()`; `"bndbox"`
                is handed to `BoundingBox.from_json`.

        Returns:
            DetectedElement: The populated element.

        Raises:
            KeyError: If one of the required keys is missing.
        """
        return cls(
            name=str(data["name"]),
            text=str(data["text"]),
            bounding_box=BoundingBox.from_json(data["bndbox"]),  # type: ignore
        )

    def __str__(self) -> str:
        return f"[name={self.name}, text={self.text}, bndbox={str(self.bounding_box)}]"

    @property
    def center(self) -> Point:
        """The center point of the detected element."""
        return self.bounding_box.center

    @property
    def width(self) -> int:
        """The width of the detected element."""
        return self.bounding_box.width

    @property
    def height(self) -> int:
        """The height of the detected element."""
        return self.bounding_box.height
@override
def locate_all_elements(
    self,
    image: ImageSource,
    model: ModelComposition | str,
) -> list[DetectedElement]:
    """Forward the request to the wrapped locate model.

    Args:
        image (ImageSource): The image to analyze.
        model (ModelComposition | str): Model name or composition, passed
            through unchanged.

    Returns:
        list[DetectedElement]: Whatever the wrapped locate model returns.
    """
    delegate = self._locate_model
    return delegate.locate_all_elements(image, model)
+ elements (list[DetectedElement]): A list of detected elements to annotate + on the image. Each element should have a name, text, and bounding box. + """ + + def __init__( + self, + image: InputSource, + elements: list[DetectedElement], + ): + self._encoded_image = load_image_source(image).to_data_url() + self._elements = elements + + def _get_style(self) -> str: + """ + Generate the CSS styles for the annotation HTML. + + Returns: + str: CSS styles as a string for styling bounding boxes, tooltips, + and the copy message. + """ + return """ + + """ + + def _get_script(self) -> str: + """ + Generate the JavaScript code for interactive features. + + Returns: + str: JavaScript code as a string for handling copy-to-clipboard + functionality, tooltip positioning, and copy message display. + """ + return """ + + """ + + def _get_elements_html(self) -> str: + """ + Generate HTML for all detected elements. + + Returns: + str: Concatenated HTML string for all bounding box elements. + """ + return "".join(self._get_box_html(element) for element in self._elements) + + def _get_box_html(self, element: DetectedElement) -> str: + """ + Generate HTML for a single detected element's bounding box. + + Args: + element (DetectedElement): The detected element to generate HTML for. + + Returns: + str: HTML string for a single bounding box with tooltip and click + handler. + """ + bbox = element.bounding_box + + escaped_text = escape(element.text or "") + escaped_name = escape(element.name or "") + + # safe HTML text for tooltips + tooltip_text = escaped_name + if escaped_text: + tooltip_text += f": {escaped_text}" + + style = ( + f"left:{bbox.xmin}px; " + f"top:{bbox.ymin}px; " + f"width:{bbox.width}px; " + f"height:{bbox.height}px;" + ) + + return f""" +
+ {tooltip_text} +
+ """ + + def _get_full_html(self) -> str: + """ + Generate the complete HTML document with all annotations. + + Returns: + str: Complete HTML document as a string including DOCTYPE, head, + styles, scripts, and body with annotated image. + """ + return f""" + + + + + Image Annotated By AskUI + + {self._get_style()} + + {self._get_script()} + + + + +
📋 Copied to clipboard!
+ +
+ + {self._get_elements_html()} +
+ + + + """ + + def save_to_dir(self, annotation_dir: Path | str) -> Path: + """ + Write the annotated HTML file to the annotation directory. + + Creates the annotation directory if it doesn't exist and generates a + timestamped filename for the HTML file. + + Args: + annotation_dir (Path | str): The directory where the + HTML file will be saved. + + Returns: + Path: The path to the written HTML file. + """ + if isinstance(annotation_dir, str): + annotation_dir = Path(annotation_dir) + if not annotation_dir.exists(): + annotation_dir.mkdir(parents=True, exist_ok=True) + + current_timestamp = datetime.now(timezone.utc).strftime("%Y%m%d%H%M%S") + file_name = f"annotated_image_{current_timestamp}.html" + file_path = annotation_dir / file_name + html_content = self._get_full_html() + + with file_path.open("w", encoding="utf-8") as f: + f.write(html_content) + + return file_path