diff --git a/tests/common/test_viz.py b/tests/common/test_jupyter.py similarity index 63% rename from tests/common/test_viz.py rename to tests/common/test_jupyter.py index c74d671..70f7b9c 100644 --- a/tests/common/test_viz.py +++ b/tests/common/test_jupyter.py @@ -1,9 +1,8 @@ import pytest from PIL import Image -from vlmrun.common.viz import ( +from vlmrun.common.jupyter import ( DisplayOptions, - xywh_to_xyxy, - extract_bbox, + BoundingBox, get_boxes_from_response, ensure_image, render_bbox_image, @@ -82,39 +81,79 @@ def test_display_options_validation(): DisplayOptions(image_width=0).validate_image_width() -def test_xywh_to_xyxy(): - """Test conversion from XYWH to XYXY format.""" - xywh = (0.1, 0.2, 0.3, 0.4) - xyxy = xywh_to_xyxy(xywh) - expected = (0.1, 0.2, 0.4, 0.6) - assert all(pytest.approx(a) == b for a, b in zip(xyxy, expected)) +def test_bounding_box_creation(): + """Test BoundingBox creation and validation.""" + # Test direct XYXY creation + bbox = BoundingBox(0.1, 0.2, 0.3, 0.4) + assert bbox.x1 == pytest.approx(0.1) + assert bbox.y1 == pytest.approx(0.2) + assert bbox.x2 == pytest.approx(0.3) + assert bbox.y2 == pytest.approx(0.4) - xywh = (10, 20, 30, 40) - xyxy = xywh_to_xyxy(xywh) - expected = (10, 20, 40, 60) - assert all(pytest.approx(a) == b for a, b in zip(xyxy, expected)) + # Test coordinate normalization + bbox = BoundingBox(0.3, 0.2, 0.1, 0.4) + assert bbox.x1 == pytest.approx(0.1) # Should swap to ensure x1 < x2 + assert bbox.x2 == pytest.approx(0.3) + # Test invalid coordinates + with pytest.raises(ValueError): + BoundingBox(-0.1, 0.2, 0.3, 0.4) -def test_extract_bbox(): - """Test bounding box extraction from various formats.""" - result = extract_bbox([0.1, 0.2, 0.3, 0.4]) - expected = (0.1, 0.2, 0.3, 0.4) - assert all(pytest.approx(a) == b for a, b in zip(result, expected)) - - result = extract_bbox({"bbox": [0.1, 0.2, 0.3, 0.4]}) - expected = (0.1, 0.2, 0.3, 0.4) - assert all(pytest.approx(a) == b for a, b in zip(result, expected)) - - result = extract_bbox({"xywh": [0.1, 0.2, 0.3, 0.4]}) - expected = (0.1, 0.2, 0.4, 0.6) - assert all(pytest.approx(a) == b for a, b in zip(result, expected)) - - result = extract_bbox({"bbox": {"xywh": [0.1, 0.2, 0.3, 0.4]}}) - expected = (0.1, 0.2, 0.4, 0.6) - assert all(pytest.approx(a) == b for a, b in zip(result, expected)) - - assert extract_bbox({"invalid": "format"}) is None - assert extract_bbox([1, 2, 3]) is None + with pytest.raises(ValueError): + BoundingBox(0.1, 0.2, 1.1, 0.4) + + # Test from_xywh + bbox = BoundingBox.from_xywh(0.1, 0.2, 0.3, 0.4) + assert bbox.x1 == pytest.approx(0.1) + assert bbox.y1 == pytest.approx(0.2) + assert bbox.x2 == pytest.approx(0.4) # x + width + assert bbox.y2 == pytest.approx(0.6) # y + height + + # Test from_dict with various formats + bbox = BoundingBox.from_dict([0.1, 0.2, 0.3, 0.4]) + assert bbox.x1 == pytest.approx(0.1) + assert bbox.y1 == pytest.approx(0.2) + assert bbox.x2 == pytest.approx(0.3) + assert bbox.y2 == pytest.approx(0.4) + + bbox = BoundingBox.from_dict({"bbox": [0.1, 0.2, 0.3, 0.4]}) + assert bbox.x1 == pytest.approx(0.1) + assert bbox.y1 == pytest.approx(0.2) + assert bbox.x2 == pytest.approx(0.3) + assert bbox.y2 == pytest.approx(0.4) + + bbox = BoundingBox.from_dict({"bbox": {"xywh": [0.1, 0.2, 0.3, 0.4]}}) + assert bbox.x1 == pytest.approx(0.1) + assert bbox.y1 == pytest.approx(0.2) + assert bbox.x2 == pytest.approx(0.4) + assert bbox.y2 == pytest.approx(0.6) + + # Test properties + bbox = BoundingBox(0.1, 0.2, 0.4, 0.6) + assert bbox.width == pytest.approx(0.3) + assert bbox.height == pytest.approx(0.4) + assert bbox.area == pytest.approx(0.12) + assert bbox.center[0] == pytest.approx(0.25) + assert bbox.center[1] == pytest.approx(0.4) + + # Test pixel coordinates + coords = bbox.to_pixel_coords(100, 100) + assert coords == ( + 10, + 20, + 40, + 60, + ) # Integer coordinates, so exact comparison is fine + + # Test to_dict + bbox = BoundingBox( + 0.1, 0.2, 0.3, 0.4, content="test", confidence=0.9, field="field1" + ) + d = bbox.to_dict() + assert all(pytest.approx(a) == b for a, b in zip(d["bbox"], [0.1, 0.2, 0.3, 0.4])) + assert d["content"] == "test" + assert d["confidence"] == pytest.approx(0.9) + assert d["field"] == "field1" def test_get_boxes_from_response(sample_response): @@ -123,32 +162,40 @@ def test_get_boxes_from_response(sample_response): assert len(boxes) == 4 # street, city, dob, full_name metadata # Check street metadata extraction - street_box = next(box for box in boxes if box.get("field") == "address.street") - expected = (0.349, 0.588, 0.755, 0.634) # xywh converted to xyxy - assert all(pytest.approx(a) == b for a, b in zip(street_box["bbox"], expected)) - assert street_box["content"] == "10 WONDERFUL DRIVE" - assert pytest.approx(street_box["confidence"]) == 0.9 + street_box = next(box for box in boxes if box.field == "address.street") + assert street_box.x1 == pytest.approx(0.349) + assert street_box.y1 == pytest.approx(0.588) + assert street_box.x2 == pytest.approx(0.755) # x + width + assert street_box.y2 == pytest.approx(0.634) # y + height + assert street_box.content == "10 WONDERFUL DRIVE" + assert street_box.confidence == pytest.approx(0.9) # Check city metadata extraction - city_box = next(box for box in boxes if box.get("field") == "address.city") - expected = (0.347, 0.640, 0.534, 0.681) # xywh converted to xyxy - assert all(pytest.approx(a) == b for a, b in zip(city_box["bbox"], expected)) - assert city_box["content"] == "MONTGOMERY" - assert pytest.approx(city_box["confidence"]) == 1.0 + city_box = next(box for box in boxes if box.field == "address.city") + assert city_box.x1 == pytest.approx(0.347) + assert city_box.y1 == pytest.approx(0.640) + assert city_box.x2 == pytest.approx(0.534) # x + width + assert city_box.y2 == pytest.approx(0.681) # y + height + assert city_box.content == "MONTGOMERY" + assert city_box.confidence == pytest.approx(1.0) # Check date of birth metadata extraction - dob_box = next(box for box in boxes if box.get("field") == "date_of_birth") - expected = (0.349, 0.431, 0.484, 0.480) # xywh converted to xyxy - assert all(pytest.approx(a) == b for a, b in zip(dob_box["bbox"], expected)) - assert dob_box["content"] == "01-05-1948" - assert pytest.approx(dob_box["confidence"]) == 1.0 + dob_box = next(box for box in boxes if box.field == "date_of_birth") + assert dob_box.x1 == pytest.approx(0.349) + assert dob_box.y1 == pytest.approx(0.431) + assert dob_box.x2 == pytest.approx(0.484) # x + width + assert dob_box.y2 == pytest.approx(0.480) # y + height + assert dob_box.content == "01-05-1948" + assert dob_box.confidence == pytest.approx(1.0) # Check full name metadata extraction - name_box = next(box for box in boxes if box.get("field") == "full_name") - expected = (0.398, 0.783, 0.853, 0.955) # xywh converted to xyxy - assert all(pytest.approx(a) == b for a, b in zip(name_box["bbox"], expected)) - assert name_box["content"] == "Connor Sample" - assert pytest.approx(name_box["confidence"]) == 1.0 + name_box = next(box for box in boxes if box.field == "full_name") + assert name_box.x1 == pytest.approx(0.398) + assert name_box.y1 == pytest.approx(0.783) + assert name_box.x2 == pytest.approx(0.853) # x + width + assert name_box.y2 == pytest.approx(0.955) # y + height + assert name_box.content == "Connor Sample" + assert name_box.confidence == pytest.approx(1.0) def test_ensure_image(sample_image, tmp_path): diff --git a/vlmrun/common/viz.py b/vlmrun/common/jupyter.py similarity index 76% rename from vlmrun/common/viz.py rename to vlmrun/common/jupyter.py index 129b86e..c02dc81 100644 --- a/vlmrun/common/viz.py +++ b/vlmrun/common/jupyter.py @@ -27,9 +27,165 @@ "tensor", } -Coordinates4 = Tuple[float, float, float, float] -BoundingBox = Coordinates4 # (x1, y1, x2, y2) -XYWHBox = Coordinates4 # (x, y, width, height) + +class BoundingBox: + """A bounding box with normalized [0,1] coordinates in XYXY format.""" + + def __init__( + self, + x1: float, + y1: float, + x2: float, + y2: float, + *, + content: Optional[str] = None, + confidence: Optional[float] = None, + field: Optional[str] = None, + ): + """Initialize a bounding box. + + Args: + x1, y1: Top-left coordinates (normalized) + x2, y2: Bottom-right coordinates (normalized) + content: Optional content/label + confidence: Optional confidence score + field: Optional field name + """ + # Validate coordinates are in [0,1] range + for coord in (x1, y1, x2, y2): + if not 0 <= coord <= 1: + raise ValueError( + f"Coordinates must be normalized to [0,1] range, got {coord}" + ) + + # Ensure x1,y1 is top-left and x2,y2 is bottom-right + self.x1 = min(x1, x2) + self.y1 = min(y1, y2) + self.x2 = max(x1, x2) + self.y2 = max(y1, y2) + + self.content = content + self.confidence = confidence + self.field = field + + @classmethod + def from_xywh( + cls, x: float, y: float, width: float, height: float, **kwargs + ) -> "BoundingBox": + """Create a bounding box from XYWH format. + + Args: + x: Left coordinate + y: Top coordinate + width: Width of box + height: Height of box + **kwargs: Additional arguments passed to __init__ + + Returns: + BoundingBox instance + """ + return cls(x, y, x + width, y + height, **kwargs) + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> Optional["BoundingBox"]: + """Create a bounding box from a dictionary representation. + + Handles various formats: + - Direct bbox list/tuple + - Dict with 'bbox' key + - Dict with 'xywh' key + - Dict with 'bbox' containing nested 'xywh' + + Args: + data: Dictionary containing bounding box data + + Returns: + BoundingBox instance or None if no valid bbox found + """ + if isinstance(data, (list, tuple)) and len(data) == 4: + return cls(*data) + + if not isinstance(data, dict): + return None + + kwargs = { + "content": data.get("bbox_content") or data.get("content"), + "confidence": data.get("confidence"), + "field": data.get("field"), + } + + if "bbox" in data: + bbox = data["bbox"] + if isinstance(bbox, dict) and "xywh" in bbox: + return cls.from_xywh(*bbox["xywh"], **kwargs) + elif isinstance(bbox, (list, tuple)) and len(bbox) == 4: + return cls(*bbox, **kwargs) + elif "xywh" in data: + return cls.from_xywh(*data["xywh"], **kwargs) + + return None + + @property + def width(self) -> float: + """Get width of bounding box.""" + return self.x2 - self.x1 + + @property + def height(self) -> float: + """Get height of bounding box.""" + return self.y2 - self.y1 + + @property + def area(self) -> float: + """Get area of bounding box.""" + return self.width * self.height + + @property + def center(self) -> Tuple[float, float]: + """Get center point of bounding box.""" + return ((self.x1 + self.x2) / 2, (self.y1 + self.y2) / 2) + + def to_pixel_coords( + self, image_width: int, image_height: int + ) -> Tuple[int, int, int, int]: + """Convert normalized coordinates to pixel coordinates. + + Args: + image_width: Width of the image in pixels + image_height: Height of the image in pixels + + Returns: + Tuple of (x1, y1, x2, y2) in pixel coordinates + """ + return ( + int(self.x1 * image_width), + int(self.y1 * image_height), + int(self.x2 * image_width), + int(self.y2 * image_height), + ) + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary representation.""" + d = {"bbox": [self.x1, self.y1, self.x2, self.y2]} + if self.content is not None: + d["content"] = self.content + if self.confidence is not None: + d["confidence"] = self.confidence + if self.field is not None: + d["field"] = self.field + return d + + def __repr__(self) -> str: + attrs = [f"bbox=[{self.x1:.3f}, {self.y1:.3f}, {self.x2:.3f}, {self.y2:.3f}]"] + if self.content: + attrs.append(f"content='{self.content}'") + if self.confidence is not None: + attrs.append(f"confidence={self.confidence:.3f}") + if self.field: + attrs.append(f"field='{self.field}'") + return f"BoundingBox({', '.join(attrs)})" + + ImageType = Union[str, Path, Image.Image] ResultType = Union[Dict[str, Any], BaseModel] ImageInfoType = Dict[str, Any] @@ -60,51 +216,7 @@ def validate_image_width(self) -> None: raise ValueError("image_width must be a positive integer") -def xywh_to_xyxy(box: XYWHBox) -> BoundingBox: - """Convert bounding box from (x, y, width, height) to (x1, y1, x2, y2) format. - - Args: - box: Tuple of (x, y, width, height) coordinates - - Returns: - Tuple of (x1, y1, x2, y2) coordinates - """ - x, y, w, h = box - return (x, y, x + w, y + h) - - -def extract_bbox(value: Union[Dict, List]) -> Optional[BoundingBox]: - """Extract bounding box from various formats. - - Handles: - - Direct bbox list/tuple - - Dict with 'bbox' key - - Dict with 'xywh' key - - Dict with 'bbox' containing nested 'xywh' - - Args: - value: Value containing bounding box information - - Returns: - Bounding box as (x1, y1, x2, y2) or None if no valid bbox found - """ - if isinstance(value, (list, tuple)) and len(value) == 4: - return tuple(value) - - if isinstance(value, dict): - if "bbox" in value: - bbox = value["bbox"] - if isinstance(bbox, dict) and "xywh" in bbox: - return xywh_to_xyxy(tuple(bbox["xywh"])) - elif isinstance(bbox, (list, tuple)) and len(bbox) == 4: - return tuple(bbox) - elif "xywh" in value: - return xywh_to_xyxy(tuple(value["xywh"])) - - return None - - -def get_boxes_from_response(response: Union[Dict, Any]) -> List[Dict[str, BoundingBox]]: +def get_boxes_from_response(response: Union[Dict, Any]) -> List[BoundingBox]: """Extract bounding boxes from VLM Run response. Handles various response formats including: @@ -117,7 +229,7 @@ def get_boxes_from_response(response: Union[Dict, Any]) -> List[Dict[str, Boundi response: Raw response dictionary or object with response attribute Returns: - List of dictionaries containing bounding box coordinates in (x1, y1, x2, y2) format + List of BoundingBox instances """ if hasattr(response, "response"): response = response.response @@ -129,16 +241,9 @@ def get_boxes_from_response(response: Union[Dict, Any]) -> List[Dict[str, Boundi def process_metadata(metadata: Dict, field_name: str) -> None: if isinstance(metadata, dict): - bbox = extract_bbox(metadata) + bbox = BoundingBox.from_dict({**metadata, "field": field_name}) if bbox: - boxes.append( - { - "bbox": bbox, - "field": field_name, - "content": metadata.get("bbox_content"), - "confidence": metadata.get("confidence"), - } - ) + boxes.append(bbox) def process_dict(d: Dict, prefix: str = "") -> None: for key, value in d.items(): @@ -153,14 +258,14 @@ def process_dict(d: Dict, prefix: str = "") -> None: for key in ["bounding_boxes", "boxes"]: if key in response and isinstance(response[key], list): for box in response[key]: - bbox = extract_bbox(box) + bbox = BoundingBox.from_dict(box) if bbox: - boxes.append({"bbox": bbox}) + boxes.append(bbox) if "bbox" in response: - bbox = extract_bbox(response["bbox"]) + bbox = BoundingBox.from_dict(response) if bbox: - boxes.append({"bbox": bbox}) + boxes.append(bbox) # Process all metadata fields, including nested ones process_dict(response) @@ -247,29 +352,22 @@ def render_bbox_image( # Draw boxes for box in boxes: try: - x1, y1, x2, y2 = box["bbox"] - if not all(isinstance(coord, (int, float)) for coord in (x1, y1, x2, y2)): - raise ValueError(f"Invalid box coordinates: {box['bbox']}") - - # Convert normalized coordinates to pixel coordinates - x1_px = int(x1 * img_width) - y1_px = int(y1 * img_height) - x2_px = int(x2 * img_width) - y2_px = int(y2 * img_height) + # Get pixel coordinates + x1_px, y1_px, x2_px, y2_px = box.to_pixel_coords(img_width, img_height) # Draw rectangle cv2.rectangle(img, (x1_px, y1_px), (x2_px, y2_px), box_color, box_thickness) # Draw label if enabled and available - if (show_content and "content" in box and box["content"]) or ( - show_confidence and "confidence" in box + if (show_content and box.content) or ( + show_confidence and box.confidence is not None ): # Build label text label_parts = [] - if show_content and box.get("content"): - label_parts.append(box["content"]) - if show_confidence and box.get("confidence") is not None: - label_parts.append(f"({box['confidence']:.2f})") + if show_content and box.content: + label_parts.append(box.content) + if show_confidence and box.confidence is not None: + label_parts.append(f"({box.confidence:.2f})") if label_parts: # Only draw if we have something to show label = " ".join(label_parts) @@ -303,7 +401,7 @@ def render_bbox_image( thickness, ) - except (KeyError, ValueError, TypeError) as e: + except (ValueError, TypeError) as e: raise ValueError(f"Invalid bounding box format: {e}") # Convert back to PIL