diff --git a/examples/structure.rs b/examples/structure.rs index 29256d2..2d15f08 100644 --- a/examples/structure.rs +++ b/examples/structure.rs @@ -674,49 +674,62 @@ fn main() -> Result<(), Box> { // Collect all results for potential concatenation let mut all_results: Vec = Vec::new(); - // Process each input source - for (idx, source) in std::mem::take(&mut input_sources).into_iter().enumerate() { + // Collect images and metadata for batch processing (cross-page formula batching) + let mut images: Vec = Vec::new(); + let mut source_meta: Vec<(String, String)> = Vec::new(); // (source_path, source_stem) + + for source in std::mem::take(&mut input_sources) { let source_path = source.path(); - let source_stem = { - match &source { - InputSource::ImageFile(p) => p - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("result") - .to_string(), - InputSource::PdfPage { - pdf_path, - page_number, - .. - } => { - format!( - "{}_page_{:03}", - pdf_path - .file_stem() - .and_then(|s| s.to_str()) - .unwrap_or("pdf"), - page_number - ) - } + let source_stem = match &source { + InputSource::ImageFile(p) => p + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("result") + .to_string(), + InputSource::PdfPage { + pdf_path, + page_number, + .. + } => { + format!( + "{}_page_{:03}", + pdf_path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("pdf"), + page_number + ) } }; - info!("\nProcessing input {}: {}", idx + 1, source_path); - - let image = match source.into_image() { - Ok(img) => img, + match source.into_image() { + Ok(img) => { + images.push(img); + source_meta.push((source_path, source_stem)); + } Err(err) => { - error!("Failed to load image: {}", err); - continue; + error!("Failed to load image {}: {}", source_path, err); } - }; + } + } + + info!( + "Batch processing {} image(s) with cross-page formula batching", + images.len() + ); + let batch_results = analyzer.predict_images(images); - let mut result = match analyzer.predict_image(image) { + // Process each result: assign metadata, save, visualize, log + for (idx, (page_result, (source_path, source_stem))) in + batch_results.into_iter().zip(source_meta).enumerate() + { + let mut result = match page_result { Ok(res) => res, Err(err) => { error!("Failed to analyze {}: {}", source_path, err); continue; } }; + info!("\nProcessed input {}: {}", idx + 1, source_path); result.input_path = std::sync::Arc::from(source_path.clone()); // Always collect results for potential concatenation diff --git a/oar-ocr-core/src/domain/structure.rs b/oar-ocr-core/src/domain/structure.rs index 4fcf687..9da5ff9 100644 --- a/oar-ocr-core/src/domain/structure.rs +++ b/oar-ocr-core/src/domain/structure.rs @@ -68,7 +68,7 @@ fn semantic_title_level_and_format(cleaned: &str) -> Option<(usize, String)> { keyword.as_str(), "ABSTRACT" | "INTRODUCTION" | "REFERENCES" | "REFERENCE" ) { - return Some((1, trimmed.to_string())); + return Some((2, trimmed.to_string())); } if let Some(captures) = TITLE_NUMBERING_REGEX.captures(cleaned) { @@ -494,8 +494,9 @@ impl StructureResult { let mut md = String::new(); let elements = &self.layout_elements; let paragraph_title_levels = infer_paragraph_title_levels(elements); - let mut last_label: Option = None; - let mut prev_element: Option<&LayoutElement> = None; + // Track the most recent Text/ReferenceContent element so paragraph + // continuation works across intervening figures/tables. + let mut prev_text_element: Option<&LayoutElement> = None; for (idx, element) in elements.iter().enumerate() { // PP-StructureV3 markdown ignores auxiliary labels. @@ -530,10 +531,10 @@ impl StructureResult { // Determine seg_start_flag for paragraph continuity (PaddleX get_seg_flag). // When both current and previous are "text" and seg_start_flag is false, // they belong to the same paragraph — join without \n\n separator. - let seg_start_flag = get_seg_flag(element, prev_element); + let seg_start_flag = get_seg_flag(element, prev_text_element); let is_continuation = element.element_type == LayoutElementType::Text - && last_label == Some(LayoutElementType::Text) + && prev_text_element.is_some() && !seg_start_flag; // Add separator between elements @@ -547,9 +548,18 @@ impl StructureResult { if !md.is_empty() { md.push_str("\n\n"); } - md.push_str("# "); if let Some(text) = &element.text { let cleaned = clean_ocr_text(text); + // Downgrade section-level keywords to ## when misclassified as DocTitle + let keyword = cleaned.trim().trim_end_matches(':').to_ascii_uppercase(); + if matches!( + keyword.as_str(), + "ABSTRACT" | "INTRODUCTION" | "REFERENCES" | "REFERENCE" + ) { + md.push_str("## "); + } else { + md.push_str("# "); + } md.push_str(&cleaned); } } @@ -621,27 +631,40 @@ impl StructureResult { }; // Check if this formula is on the same line as adjacent text elements - // to determine if it's an inline formula or display formula + // to determine if it's an inline formula or display formula. + // Only consider the nearest non-formula/non-formula-number neighbor + // on each side, and require BOTH sides to have text on the same line. + // This prevents display formulas from being misclassified as inline + // when they happen to be vertically aligned with a distant text block. let is_inline = { - // Look for previous non-formula text element on the same line - let has_prev_text = (0..idx).rev().any(|i| { - let prev = &elements[i]; - !prev.element_type.is_formula() - && (prev.element_type == LayoutElementType::Text + let has_prev_text = (0..idx) + .rev() + .find(|&i| { + let t = elements[i].element_type; + !t.is_formula() && t != LayoutElementType::FormulaNumber + }) + .is_some_and(|i| { + let prev = &elements[i]; + (prev.element_type == LayoutElementType::Text || prev.element_type == LayoutElementType::ReferenceContent) - && is_same_line(&element.bbox, &prev.bbox) - }); - - // Look for next non-formula text element on the same line - let has_next_text = ((idx + 1)..elements.len()).any(|i| { - let next = &elements[i]; - !next.element_type.is_formula() - && (next.element_type == LayoutElementType::Text + && is_same_line(&element.bbox, &prev.bbox) + }); + + let has_next_text = ((idx + 1)..elements.len()) + .find(|&i| { + let t = elements[i].element_type; + !t.is_formula() && t != LayoutElementType::FormulaNumber + }) + .is_some_and(|i| { + let next = &elements[i]; + (next.element_type == LayoutElementType::Text || next.element_type == LayoutElementType::ReferenceContent) - && is_same_line(&element.bbox, &next.bbox) - }); + && is_same_line(&element.bbox, &next.bbox) + }); - has_prev_text || has_next_text + // Require text on BOTH sides for inline — a formula with text + // only on one side is almost always a display equation. + has_prev_text && has_next_text }; if is_inline { @@ -788,8 +811,13 @@ impl StructureResult { // Default text elements - following PaddleX's text handling _ => { if let Some(text) = &element.text { - // For text continuation (same paragraph), join directly - if is_continuation { + let cleaned = clean_ocr_text(text); + if has_bullet_markers(&cleaned) { + if !md.is_empty() { + md.push_str("\n\n"); + } + format_as_bullet_list(&cleaned, &mut md); + } else if is_continuation { let formatted = format_text_block(text); md.push_str(&formatted); } else { @@ -803,8 +831,11 @@ impl StructureResult { } } - last_label = Some(element.element_type); - prev_element = Some(element); + if element.element_type == LayoutElementType::Text + || element.element_type == LayoutElementType::ReferenceContent + { + prev_text_element = Some(element); + } } md.trim().to_string() } @@ -1342,6 +1373,29 @@ fn format_vision_footnote_block(text: &str) -> String { step1.replace('\n', "\n\n") } +/// Bullet marker characters commonly found in OCR text. +const BULLET_MARKERS: &[char] = &['•', '●', '◦', '▪', '◆']; + +/// Checks if text contains bullet markers that should be formatted as a list. +fn has_bullet_markers(text: &str) -> bool { + BULLET_MARKERS.iter().any(|&m| text.contains(m)) +} + +/// Formats text with bullet markers as a markdown list. +/// +/// Splits on any bullet marker character so mixed markers (e.g. `• item1 ▪ item2`) +/// are all handled correctly. +fn format_as_bullet_list(text: &str, md: &mut String) { + for item in text.split(|c: char| BULLET_MARKERS.contains(&c)) { + let item = item.trim(); + if !item.is_empty() { + md.push_str("- "); + md.push_str(item); + md.push('\n'); + } + } +} + /// Checks if a character is a Chinese character. /// /// Used to determine spacing rules when concatenating pages. @@ -2602,11 +2656,11 @@ mod tests { #[test] fn test_format_title_with_level_keywords() { let (level, text) = format_title_with_level("Abstract", None); - assert_eq!(level, 1); + assert_eq!(level, 2); assert_eq!(text, "Abstract"); let (level, text) = format_title_with_level("References:", None); - assert_eq!(level, 1); + assert_eq!(level, 2); assert_eq!(text, "References:"); } diff --git a/oar-ocr-core/src/processors/formula_preprocess.rs b/oar-ocr-core/src/processors/formula_preprocess.rs index ace0cab..13e8d55 100644 --- a/oar-ocr-core/src/processors/formula_preprocess.rs +++ b/oar-ocr-core/src/processors/formula_preprocess.rs @@ -169,7 +169,7 @@ impl FormulaPreprocessor { let final_width = new_width.min(target_width); let final_height = new_height.min(target_height); - let resized = resize(img, final_width, final_height, FilterType::Lanczos3); + let resized = resize(img, final_width, final_height, FilterType::Triangle); // Calculate padding to center the image let delta_width = target_width - final_width; diff --git a/src/oarocr/stitching.rs b/src/oarocr/stitching.rs index 426bfb8..1b72df9 100644 --- a/src/oarocr/stitching.rs +++ b/src/oarocr/stitching.rs @@ -939,7 +939,7 @@ impl ResultStitcher { continue; } if i != candidate_indices.len() - 1 && !content.ends_with(' ') { - content.push(' '); + content.push_str("
"); } } joined.push_str(&content); @@ -965,30 +965,19 @@ impl ResultStitcher { } // --- Sort cells into rows --- - // When detected bboxes are available we sort them (better spatial accuracy) - // to pick the IoA bbox for OCR matching. We also independently sort the - // structure cells so that the td→cell text-assignment step uses a valid - // index into `cells[]`. Without this separation the det-bbox sort indices - // are silently reused as structure-cell indices, misassigning OCR to wrong - // cells whenever the two orderings differ. - let (match_sorted_indices, cell_sorted_indices, match_row_flags) = - if let Some(det_bboxes) = cell_bboxes_override { - let temp_cells: Vec = det_bboxes - .iter() - .map(|b| TableCell::new(b.clone(), 0.5)) - .collect(); - let (det_sorted, row_flags) = - Self::sort_table_cells_boxes(&temp_cells, row_y_tolerance); - // Sort structure cells independently so their indices stay valid. - let (cell_sorted, _) = Self::sort_table_cells_boxes(cells, row_y_tolerance); - (det_sorted, cell_sorted, row_flags) - } else { - let (sorted, row_flags) = Self::sort_table_cells_boxes(cells, row_y_tolerance); - // When there is no override the two index lists are identical. - (sorted.clone(), sorted, row_flags) - }; - - if match_sorted_indices.is_empty() || match_row_flags.is_empty() { + // Sort structure cells — their bboxes drive both IoA matching and the + // td→cell text-assignment step. Detected-cell bboxes (cell_bboxes_override) + // are intentionally NOT used for IoA because the detected model can produce + // a different cell count per row than the structure tokens, causing local_idx + // to diverge from td_index and corrupting OCR-to-cell assignments. + // + // When cell_bboxes_override is present, cross-row OCR deduplication is + // enabled downstream to prevent large detected cells spanning multiple + // structure rows from duplicating content. + let (cell_sorted_indices, cell_row_flags) = + Self::sort_table_cells_boxes(cells, row_y_tolerance); + + if cell_sorted_indices.is_empty() || cell_row_flags.is_empty() { return None; } @@ -997,9 +986,10 @@ impl ResultStitcher { return None; } - // Align match row flags with structure token row boundaries - let mut match_aligned = Self::map_and_get_max(&match_row_flags, &row_start_index); - match_aligned.push(match_sorted_indices.len()); + // Align structure-cell row flags with structure-token row boundaries. + // cell_aligned is used both for IoA matching (correct space) and td→cell mapping. + let mut cell_aligned = Self::map_and_get_max(&cell_row_flags, &row_start_index); + cell_aligned.push(cell_sorted_indices.len()); row_start_index.push( structure_tokens .iter() @@ -1009,27 +999,36 @@ impl ResultStitcher { // --- Per-row matching: cell → OCR (PaddleX style) --- // For each cell in the row, collect ALL OCR boxes with IoA > 0.7. - // No cross-row deduplication — each row independently checks all OCR boxes, - // matching PaddleX v2 behavior. The 0.7 IoA threshold naturally prevents - // false cross-row matches. + // When using detected cell bboxes (cell_bboxes_override is Some), apply + // cross-row deduplication: an OCR box already claimed by an earlier row is + // not re-matched in a later row. This prevents large detected cells that + // span multiple structure rows from duplicating their content across those rows. + // In pure E2E mode (cell_bboxes_override is None) the PaddleX v2 behavior of + // independent per-row matching is preserved. + let use_cross_row_dedup = cell_bboxes_override.is_some(); + let mut globally_matched_ocr: std::collections::HashSet = + std::collections::HashSet::new(); let mut all_matched: Vec>> = Vec::new(); - for k in 0..match_aligned.len().saturating_sub(1) { - let row_start = match_aligned[k].min(match_sorted_indices.len()); - let row_end = match_aligned[k + 1].min(match_sorted_indices.len()); + for k in 0..cell_aligned.len().saturating_sub(1) { + let row_start = cell_aligned[k].min(cell_sorted_indices.len()); + let row_end = cell_aligned[k + 1].min(cell_sorted_indices.len()); let mut matched: std::collections::HashMap> = std::collections::HashMap::new(); - for (local_idx, &bbox_idx) in - match_sorted_indices[row_start..row_end].iter().enumerate() + for (local_idx, &cell_idx) in cell_sorted_indices[row_start..row_end].iter().enumerate() { - // Use detected bbox directly when available, else structure cell bbox - let cell_box = cell_bboxes_override - .and_then(|bbs| bbs.get(bbox_idx)) - .unwrap_or_else(|| &cells[bbox_idx.min(cells.len() - 1)].bbox); + // Always use structure cell bbox for IoA matching. Detected-cell bboxes + // (cell_bboxes_override) are not used here because their cell count per + // row can differ from the structure td count, causing local_idx to + // diverge from td_index and corrupt the OCR-to-cell assignment. + let cell_box = &cells[cell_idx.min(cells.len() - 1)].bbox; for (ocr_idx, (_, ocr_region)) in ocr_candidates.iter().enumerate() { + if use_cross_row_dedup && globally_matched_ocr.contains(&ocr_idx) { + continue; + } // IoA = intersection / OCR_area (PaddleX compute_inter > 0.7) let ioa = ocr_region.bounding_box.ioa(cell_box); if ioa > 0.7 { @@ -1038,6 +1037,12 @@ impl ResultStitcher { } } + if use_cross_row_dedup { + for indices in matched.values() { + globally_matched_ocr.extend(indices.iter().copied()); + } + } + all_matched.push(matched); } @@ -1070,11 +1075,11 @@ impl ResultStitcher { } // Map td position to the original cell index via sorted ordering. - // match_aligned[matched_row_idx] + td_index gives the position in the - // sorted cell list. Use cell_sorted_indices (indices into cells[]) - // rather than match_sorted_indices (which may be indices into det_bboxes - // when cell_bboxes_override is active). - let mapped_cell_idx = match_aligned + // Use cell_aligned (derived from structure-cell row flags) rather than + // match_aligned (derived from detected-cell row flags). When the two + // models disagree on cell count per row, using match_aligned here would + // offset into the wrong row of cell_sorted_indices. + let mapped_cell_idx = cell_aligned .get(matched_row_idx) .copied() .and_then(|row_start| { @@ -1315,7 +1320,7 @@ impl ResultStitcher { continue; } if i != matched_indices.len() - 1 && !content.ends_with(' ') { - content.push(' '); + content.push_str("
"); } } diff --git a/src/oarocr/structure.rs b/src/oarocr/structure.rs index a1b0e31..8783a2f 100644 --- a/src/oarocr/structure.rs +++ b/src/oarocr/structure.rs @@ -1169,6 +1169,17 @@ pub struct OARStructure { pipeline: StructurePipeline, } +/// Intermediate result from preprocessing and layout detection for a single page. +/// Produced by `OARStructure::prepare_page` and consumed by `complete_page`. +struct PreparedPage { + current_image: std::sync::Arc, + orientation_angle: Option, + rectified_img: Option>, + rotation: Option, + layout_elements: Vec, + detected_region_blocks: Option>, +} + impl OARStructure { /// Refinement of overall OCR results using layout boxes. /// @@ -2242,19 +2253,9 @@ impl OARStructure { Ok(result) } - /// Analyzes the structure of a document image. - /// - /// This method is the core implementation for structure analysis and can be called - /// directly with an in-memory image. - /// - /// # Arguments - /// - /// * `image` - The input RGB image - /// - /// # Returns - /// - /// A `StructureResult` containing detected layout elements, tables, formulas, and text. - pub fn predict_image(&self, image: image::RgbImage) -> Result { + /// Preprocesses a page image and runs layout detection, returning intermediate + /// results ready for formula recognition and downstream processing. + fn prepare_page(&self, image: image::RgbImage) -> Result { use crate::oarocr::preprocess::DocumentPreprocessor; use std::sync::Arc; @@ -2268,11 +2269,38 @@ impl OARStructure { let rectified_img = preprocess.rectified_img; let rotation = preprocess.rotation; - let (mut layout_elements, mut detected_region_blocks) = + let (layout_elements, detected_region_blocks) = self.detect_layout_and_regions(¤t_image)?; + Ok(PreparedPage { + current_image, + orientation_angle, + rectified_img, + rotation, + layout_elements, + detected_region_blocks, + }) + } + + /// Completes page analysis given a `PreparedPage` and pre-computed formula results. + /// Runs seal detection, OCR, table analysis, stitching, and coordinate transforms. + fn complete_page( + &self, + prepared: PreparedPage, + mut formulas: Vec, + ) -> Result { + use std::sync::Arc; + + let PreparedPage { + current_image, + orientation_angle, + rectified_img, + rotation, + mut layout_elements, + mut detected_region_blocks, + } = prepared; + let mut tables = Vec::new(); - let mut formulas = self.recognize_formulas(¤t_image, &layout_elements)?; self.detect_seal_text(¤t_image, &mut layout_elements)?; @@ -2481,6 +2509,123 @@ impl OARStructure { Ok(result) } + + /// Analyzes the structure of a single document image. + pub fn predict_image(&self, image: image::RgbImage) -> Result { + let prepared = self.prepare_page(image)?; + let formulas = + self.recognize_formulas(&prepared.current_image, &prepared.layout_elements)?; + self.complete_page(prepared, formulas) + } + + /// Analyzes multiple document page images with cross-page formula batching. + /// + /// All formula crops from every page are collected first and forwarded to the + /// formula adapter in a single `execute` call, reducing ONNX inference overhead + /// compared to calling [`predict_image`] sequentially. Layout detection and all + /// other per-page steps are still performed independently per page. + /// + /// Per-page errors are returned individually so that a failure on one page does + /// not abort the remaining pages. + pub fn predict_images( + &self, + images: Vec, + ) -> Vec> { + use oar_ocr_core::core::traits::task::ImageTaskInput; + use oar_ocr_core::domain::structure::FormulaResult; + use oar_ocr_core::utils::BBoxCrop; + + if images.is_empty() { + return Vec::new(); + } + + // Phase 1: Preprocessing + layout detection for every page. + // Pages that fail preparation are recorded as Err and skipped in later phases. + let prepared_pages: Vec> = images + .into_iter() + .map(|image| self.prepare_page(image)) + .collect(); + + // Phase 2: Batch formula recognition across all successfully prepared pages. + let num_pages = prepared_pages.len(); + let mut per_page_formulas: Vec> = + (0..num_pages).map(|_| Vec::new()).collect(); + + if let Some(ref formula_adapter) = self.pipeline.formula_recognition_adapter { + let mut all_crops: Vec = Vec::new(); + let mut crop_meta: Vec<(usize, oar_ocr_core::processors::BoundingBox)> = Vec::new(); + + for (page_idx, prepared) in prepared_pages.iter().enumerate() { + let prepared = match prepared { + Ok(p) => p, + Err(_) => continue, + }; + for elem in prepared + .layout_elements + .iter() + .filter(|e| e.element_type.is_formula()) + { + match BBoxCrop::crop_bounding_box(&prepared.current_image, &elem.bbox) { + Ok(crop) => { + all_crops.push(crop); + crop_meta.push((page_idx, elem.bbox.clone())); + } + Err(err) => { + tracing::warn!("Formula region crop failed (batch): {}", err); + } + } + } + } + + if !all_crops.is_empty() { + let batch_size = formula_adapter.recommended_batch_size().max(1); + let mut remaining_crops = all_crops; + let mut meta_offset = 0; + + while !remaining_crops.is_empty() { + let chunk_len = batch_size.min(remaining_crops.len()); + let rest = remaining_crops.split_off(chunk_len); + let chunk_vec = remaining_crops; + remaining_crops = rest; + + let chunk_meta = &crop_meta[meta_offset..meta_offset + chunk_len]; + match formula_adapter.execute(ImageTaskInput::new(chunk_vec), None) { + Ok(formula_output) => { + for ((page_idx, bbox), (formula_text, score)) in + chunk_meta.iter().cloned().zip( + formula_output + .formulas + .into_iter() + .zip(formula_output.scores), + ) + { + let width = bbox.x_max() - bbox.x_min(); + let height = bbox.y_max() - bbox.y_min(); + if width > 0.0 && height > 0.0 { + per_page_formulas[page_idx].push(FormulaResult { + bbox, + latex: formula_text, + confidence: score.unwrap_or(0.0), + }); + } + } + } + Err(err) => { + tracing::warn!("Batch formula recognition failed: {}", err); + } + } + meta_offset += chunk_len; + } + } + } + + // Phase 3: Complete each page with its pre-computed formula results. + prepared_pages + .into_iter() + .zip(per_page_formulas) + .map(|(prepared, formulas)| self.complete_page(prepared?, formulas)) + .collect() + } } #[cfg(test)] diff --git a/src/oarocr/table_analyzer.rs b/src/oarocr/table_analyzer.rs index eeda3a6..4a60991 100644 --- a/src/oarocr/table_analyzer.rs +++ b/src/oarocr/table_analyzer.rs @@ -80,7 +80,6 @@ fn cluster_positions(mut positions: Vec, tolerance: f32) -> Vec { return Vec::new(); } positions.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); - positions.dedup_by(|a, b| (*a - *b).abs() < 0.1); let mut clustered = Vec::new(); let mut current_cluster = vec![positions[0]];