Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 46 additions & 18 deletions chandra/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,29 @@ class LayoutBlock:
bbox: list[int]
label: str
content: str
table_row_bboxes: list[list[list[int]]]

def normalize_and_clean_bbox(
bbox: list[int], width: int, height: int, width_scaler: float, height_scaler: float
):
return [
max(0, int(bbox[0] * width_scaler)),
max(0, int(bbox[1] * height_scaler)),
min(int(bbox[2] * width_scaler), width),
min(int(bbox[3] * height_scaler), height),
]

def parse_bbox(bbox_str: str) -> list[float]:
try:
bbox = json.loads(bbox_str)
assert len(bbox) == 4, "Invalid bbox length"
except Exception:
try:
bbox = bbox_str.split(" ")
assert len(bbox) == 4, "Invalid bbox length"
except Exception:
bbox = [0, 0, 1, 1]
return bbox

def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
soup = BeautifulSoup(html, "html.parser")
Expand All @@ -197,28 +219,34 @@ def parse_layout(html: str, image: Image.Image, bbox_scale=settings.BBOX_SCALE):
layout_blocks = []
for div in top_level_divs:
bbox = div.get("data-bbox")

try:
bbox = json.loads(bbox)
assert len(bbox) == 4, "Invalid bbox length"
except Exception:
try:
bbox = bbox.split(" ")
assert len(bbox) == 4, "Invalid bbox length"
except Exception:
bbox = [0, 0, 1, 1]

bbox = parse_bbox(bbox)
bbox = list(map(int, bbox))
# Normalize bbox
bbox = [
max(0, int(bbox[0] * width_scaler)),
max(0, int(bbox[1] * height_scaler)),
min(int(bbox[2] * width_scaler), width),
min(int(bbox[3] * height_scaler), height),
]
bbox = normalize_and_clean_bbox(
bbox, width, height, width_scaler, height_scaler
)
label = div.get("data-label", "block")
content = str(div.decode_contents())
layout_blocks.append(LayoutBlock(bbox=bbox, label=label, content=content))

soup = BeautifulSoup(content, "html.parser")
all_table_row_bboxes = []
for table in soup.find_all("table"):
table_row_bboxes = []
for row in table.find_all("tr"):
if row_bbox := row.get("data-bbox", None):
row_bbox = parse_bbox(row_bbox)
row_bbox = list(map(int, row_bbox))
row_bbox = normalize_and_clean_bbox(
row_bbox, width, height, width_scaler, height_scaler
)
table_row_bboxes.append(row_bbox)
all_table_row_bboxes.append(table_row_bboxes)

layout_blocks.append(
LayoutBlock(
bbox=bbox, label=label, content=content, table_row_bboxes=all_table_row_bboxes
)
)
return layout_blocks


Expand Down
33 changes: 33 additions & 0 deletions chandra/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,40 @@
{PROMPT_ENDING}
""".strip()

OCR_LAYOUT_TABLE_ROW_PROMPT = f"""
OCR this image to HTML, arranged as layout blocks. Each layout block should be a div with the data-bbox attribute representing the bounding box of the block in [x0, y0, x1, y1] format. Bboxes are normalized 0-1000. The data-label attribute is the label for the block.

In addition to layout blocks, table rows must also include bounding boxes.
Each table row (<tr>) inside a Table or Table-Of-Contents must have its own data-bbox attribute and skipping this is not allowed.

Use the following labels:
- Caption
- Footnote
- Equation-Block
- List-Group
- Page-Header
- Page-Footer
- Image
- Section-Header
- Table
- Text
- Complex-Block
- Code-Block
- Form
- Table-Of-Contents
- Figure

{PROMPT_ENDING}

Table Row Guidelines:
* For every Table or Table-Of-Contents, each table row (<tr>) must include a data-bbox attribute in [x0, y0, x1, y1] format.
* The bounding box must cover the full visual extent of the row.
* Table rows without a data-bbox attribute are incomplete.
""".strip()


PROMPT_MAPPING = {
"ocr_layout": OCR_LAYOUT_PROMPT,
"ocr": OCR_PROMPT,
"ocr_layout_table_row": OCR_LAYOUT_TABLE_ROW_PROMPT,
}
12 changes: 11 additions & 1 deletion chandra/scripts/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,10 +53,11 @@ def embed_images_in_markdown(markdown: str, images: dict) -> str:
def ocr_layout(
img: Image.Image,
model=None,
prompt_type: str = "ocr_layout",
) -> (Image.Image, str):
batch = BatchInputItem(
image=img,
prompt_type="ocr_layout",
prompt_type=prompt_type,
)
result = model.generate([batch])[0]
layout = parse_layout(result.raw, img)
Expand Down Expand Up @@ -108,6 +109,14 @@ def ocr_layout(
pil_image = Image.open(in_file).convert("RGB")
page_number = None

prompt_type = st.sidebar.selectbox(
"Prompt type",
[
"ocr_layout",
"ocr",
"ocr_layout_table_row",
],
)
run_ocr = st.sidebar.button("Run OCR")

if pil_image is None:
Expand All @@ -120,6 +129,7 @@ def ocr_layout(
result, layout_image = ocr_layout(
pil_image,
model,
prompt_type,
)

# Embed images as base64 data URLs in the markdown
Expand Down
9 changes: 9 additions & 0 deletions chandra/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,13 @@ def draw_layout(image: Image.Image, layout_blocks: list[LayoutBlock]):
draw.rectangle(block.bbox, outline="red", width=2)
draw.text((block.bbox[0], block.bbox[1]), block.label, fill="blue")

for table in block.table_row_bboxes:
for row_idx, row_bbox in enumerate(table):
draw.rectangle(row_bbox, outline="green", width=2)
draw.text(
(row_bbox[0], row_bbox[1]),
f"Row {row_idx}",
fill="green",
)

return draw_image