diff --git a/docling_ibm_models/layoutmodel/layout_predictor.py b/docling_ibm_models/layoutmodel/layout_predictor.py
index 60ab1a5..b0e6a27 100644
--- a/docling_ibm_models/layoutmodel/layout_predictor.py
+++ b/docling_ibm_models/layoutmodel/layout_predictor.py
@@ -5,7 +5,7 @@
 import logging
 import os
 from collections.abc import Iterable
-from typing import Set, Union
+from typing import List, Set, Union
 
 import numpy as np
 import torch
@@ -133,8 +133,8 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
             page_img = Image.fromarray(orig_img).convert("RGB")
         else:
             raise TypeError("Not supported input image format")
-
         resize = {"height": self._image_size, "width": self._image_size}
+
         inputs = self._image_processor(
             images=page_img,
             return_tensors="pt",
@@ -175,3 +175,102 @@ def predict(self, orig_img: Union[Image.Image, np.ndarray]) -> Iterable[dict]:
                 "label": label_str,
                 "confidence": score,
             }
+
+    @torch.inference_mode()
+    def predict_batch(
+        self, orig_img: List[Union[Image.Image, np.ndarray]]
+    ) -> Iterable[Iterable[dict]]:
+        """
+        Predict bounding boxes for a batch of page images.
+        The origin (0, 0) is the top-left corner and the predicted bbox coords are provided as:
+        [left, top, right, bottom]
+
+        Parameters
+        ----------
+        orig_img: List of images to be predicted, each one either a PIL Image object or a
+            numpy array. PIL images and numpy arrays may be mixed within the same batch.
+
+        Yields
+        ------
+        One iterable per page, in input order. Each iterable produces the bounding boxes of
+        that page as dicts with the keys: "label", "confidence", "l", "t", "r", "b".
+        Nothing is yielded for an empty batch.
+
+        Raises
+        ------
+        TypeError when any input image is not supported
+        """
+        # Convert every image to RGB. Each element is checked on its own, so a batch is not
+        # accepted or rejected based on its first element only.
+        page_img = []
+        for img in orig_img:
+            if isinstance(img, Image.Image):
+                page_img.append(img.convert("RGB"))
+            elif isinstance(img, np.ndarray):
+                page_img.append(Image.fromarray(img).convert("RGB"))
+            else:
+                raise TypeError("Not supported input image format")
+
+        # An empty batch has nothing to predict; skip the model call entirely.
+        if not page_img:
+            return
+
+        resize = {"height": self._image_size, "width": self._image_size}
+        inputs = self._image_processor(
+            images=page_img,
+            return_tensors="pt",
+            size=resize,
+        ).to(self._device)
+
+        # PIL's size is (width, height); reverse it to (height, width) for target_sizes.
+        target_sizes = torch.tensor([img.size[::-1] for img in page_img])
+
+        outputs = self._model(**inputs)
+
+        results = self._image_processor.post_process_object_detection(
+            outputs,
+            target_sizes=target_sizes,
+            threshold=self._threshold,
+        )
+
+        for img, result in zip(page_img, results):
+            w, h = img.size
+            yield self.postprocess_result(result, w, h)
+
+    def postprocess_result(self, result: dict, w: int, h: int) -> Iterable[dict]:
+        """
+        Postprocess the result of the layout prediction for a single page.
+
+        Parameters
+        ----------
+        result: The result of the layout prediction, with the entries "scores", "labels"
+            and "boxes" holding one tensor each.
+        w: The width of the image.
+        h: The height of the image.
+
+        Yields
+        ------
+        Bounding box as a dict with the keys: "label", "confidence", "l", "t", "r", "b".
+        Coordinates are clamped to the image area and blacklisted classes are skipped.
+        """
+        # Convert each tensor to plain Python numbers once instead of element by element.
+        scores = result["scores"].tolist()
+        label_ids = result["labels"].tolist()
+        boxes = result["boxes"].tolist()
+
+        for score, label_id, box in zip(scores, label_ids, boxes):
+            label_str = self._classes_map[int(label_id) + 1]  # Advance the label_id
+
+            # Filter out blacklisted classes
+            if label_str in self._black_classes:
+                continue
+
+            left, top, right, bottom = box
+            yield {
+                "l": min(w, max(0, left)),
+                "t": min(h, max(0, top)),
+                "r": min(w, max(0, right)),
+                "b": min(h, max(0, bottom)),
+                "label": label_str,
+                "confidence": score,
+            }