diff --git a/test_unstructured_inference/inference/test_layout_element.py b/test_unstructured_inference/inference/test_layout_element.py index b0cd2966..b12a409a 100644 --- a/test_unstructured_inference/inference/test_layout_element.py +++ b/test_unstructured_inference/inference/test_layout_element.py @@ -1,4 +1,5 @@ -from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion +from unstructured_inference.inference.layoutelement import LayoutElement, LayoutElements, TextRegion +import numpy as np def test_layout_element_do_dict(mock_layout_element): @@ -18,3 +19,32 @@ def test_layout_element_from_region(mock_rectangle): region = TextRegion(bbox=mock_rectangle) assert LayoutElement.from_region(region) == expected + + +def test_layout_elements_iter_support(): + coords = np.array([[0, 0, 100, 100]]) + texts = np.array(["sample"]) + probs = np.array([0.9]) + class_ids = np.array([0]) + class_id_map = {0: "Text"} + sources = np.array(["test_source"]) + text_as_html = np.array(["
sample
"]) + table_as_cells = np.array([None]) + + layout_elements = LayoutElements( + element_coords=coords, + texts=texts, + element_probs=probs, + element_class_ids=class_ids, + element_class_id_map=class_id_map, + sources=sources, + text_as_html=text_as_html, + table_as_cells=table_as_cells, + ) + + # New feature test: __iter__() works + elements = list(layout_elements) + assert len(elements) == 1 + assert isinstance(elements[0], LayoutElement) + assert elements[0].text == "sample" + assert elements[0].type == "Text" diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index 81647ced..1c27be9d 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -229,6 +229,9 @@ def __post_init__(self): def __getitem__(self, indices) -> TextRegions: return self.slice(indices) + def __iter__(self): + return self.iter_elements() + def slice(self, indices) -> TextRegions: """slice text regions based on indices""" return TextRegions( diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index 5b4c6fda..a68fe5aa 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -78,6 +78,9 @@ def __eq__(self, other: object) -> bool: def __getitem__(self, indices): return self.slice(indices) + def __iter__(self): + return self.iter_elements() + def slice(self, indices) -> LayoutElements: """slice and return only selected indices""" return LayoutElements(