diff --git a/wired_table_rec/main.py b/wired_table_rec/main.py index 161cb45..5fdf89a 100644 --- a/wired_table_rec/main.py +++ b/wired_table_rec/main.py @@ -6,19 +6,16 @@ import logging import time import traceback -from dataclasses import dataclass, asdict -from enum import Enum from pathlib import Path -from typing import List, Optional, Union, Dict, Any +from typing import List, Optional, Tuple, Union, Dict, Any import numpy as np import cv2 -from wired_table_rec.table_structure_cycle_center_net import TSRCycleCenterNet -from wired_table_rec.table_structure_unet import TSRUnet -from wired_table_rec.utils.download_model import DownloadModel +from wired_table_rec.table_line_rec import TableLineRecognition +from wired_table_rec.table_line_rec_plus import TableLineRecognitionPlus from .table_recover import TableRecover -from .utils.utils import InputType, LoadImage -from wired_table_rec.utils.utils_table_recover import ( +from .utils import InputType, LoadImage +from .utils_table_recover import ( match_ocr_cell, plot_html_table, box_4_2_poly_to_box_4_1, @@ -27,73 +24,54 @@ gather_ocr_list_by_row, ) - -class ModelType(Enum): - CYCLE_CENTER_NET = "cycle_center_net" - UNET = "unet" - - -ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/" -KEY_TO_MODEL_URL = { - ModelType.CYCLE_CENTER_NET.value: f"{ROOT_URL}/cycle_center_net.onnx", - ModelType.UNET.value: f"{ROOT_URL}/unet.onnx", -} - - -@dataclass -class WiredTableInput: - model_type: Optional[str] = ModelType.UNET.value - model_path: Union[str, Path, None, Dict[str, str]] = None - use_cuda: bool = False - device: str = "cpu" - - -@dataclass -class WiredTableOutput: - pred_html: Optional[str] = None - cell_bboxes: Optional[np.ndarray] = None - logic_points: Optional[np.ndarray] = None - elapse: Optional[float] = None +cur_dir = Path(__file__).resolve().parent +default_model_path = cur_dir / "models" / "cycle_center_net_v1.onnx" +default_model_path_v2 = cur_dir / "models" / "cycle_center_net_v2.onnx" class WiredTableRecognition: - def __init__(self, config: WiredTableInput): - self.model_type = config.model_type - if self.model_type not in KEY_TO_MODEL_URL: - model_list = ",".join(KEY_TO_MODEL_URL) - raise ValueError( - f"{self.model_type} is not supported. The currently supported models are {model_list}." - ) - - config.model_path = self.get_model_path(config.model_type, config.model_path) - if self.model_type == ModelType.CYCLE_CENTER_NET.value: - self.table_structure = TSRCycleCenterNet(asdict(config)) - else: - self.table_structure = TSRUnet(asdict(config)) - + def __init__(self, table_model_path: Union[str, Path] = None, version="v2"): self.load_img = LoadImage() + if version == "v2": + model_path = table_model_path if table_model_path else default_model_path_v2 + self.table_line_rec = TableLineRecognitionPlus(str(model_path)) + else: + model_path = table_model_path if table_model_path else default_model_path + self.table_line_rec = TableLineRecognition(str(model_path)) self.table_recover = TableRecover() + try: + self.ocr = importlib.import_module("rapidocr_onnxruntime").RapidOCR() + except ModuleNotFoundError: + self.ocr = None + def __call__( self, img: InputType, ocr_result: Optional[List[Union[List[List[float]], str, str]]] = None, **kwargs, - ) -> WiredTableOutput: + ) -> Tuple[str, float, Any, Any, Any]: + if self.ocr is None and ocr_result is None: + raise ValueError( + "One of two conditions must be met: ocr_result is not empty, or rapidocr_onnxruntime is installed." + ) + s = time.perf_counter() + rec_again = True need_ocr = True col_threshold = 15 row_threshold = 10 if kwargs: + rec_again = kwargs.get("rec_again", True) need_ocr = kwargs.get("need_ocr", True) col_threshold = kwargs.get("col_threshold", 15) row_threshold = kwargs.get("row_threshold", 10) img = self.load_img(img) - polygons, rotated_polygons = self.table_structure(img, **kwargs) + polygons, rotated_polygons = self.table_line_rec(img, **kwargs) if polygons is None: logging.warning("polygons is None.") - return WiredTableOutput("", None, None, 0.0) + return "", 0.0, None, None, None try: table_res, logi_points = self.table_recover( @@ -108,34 +86,52 @@ def __call__( sorted_polygons, idx_list = sorted_ocr_boxes( [box_4_2_poly_to_box_4_1(box) for box in polygons] ) - return WiredTableOutput( + return ( "", + time.perf_counter() - s, sorted_polygons, logi_points[idx_list], - time.perf_counter() - s, + [], ) + if ocr_result is None and need_ocr: + ocr_result, _ = self.ocr(img) cell_box_det_map, not_match_orc_boxes = match_ocr_cell(ocr_result, polygons) # 如果有识别框没有ocr结果,直接进行rec补充 - cell_box_det_map = self.fill_blank_rec(img, polygons, cell_box_det_map) + cell_box_det_map = self.re_rec(img, polygons, cell_box_det_map, rec_again) # 转换为中间格式,修正识别框坐标,将物理识别框,逻辑识别框,ocr识别框整合为dict,方便后续处理 - t_rec_ocr_list = self.transform_res(cell_box_det_map, polygons, logi_points) + t_rec_ocr_list_dict = self.transform_res(cell_box_det_map, polygons, logi_points) + # 第一行或者第一列为空时,调整代码 + #adjust_dict = self.adjust_table_cells(t_rec_ocr_list_dict) + adjust_dict = self.process_ocr_result(t_rec_ocr_list_dict) # 将每个单元格中的ocr识别结果排序和同行合并,输出的html能完整保留文字的换行格式 - t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list) + t_rec_ocr_list = self.sort_and_gather_ocr_res(t_rec_ocr_list_dict) # cell_box_map = logi_points = [t_box_ocr["t_logic_box"] for t_box_ocr in t_rec_ocr_list] cell_box_det_map = { i: [ocr_box_and_text[1] for ocr_box_and_text in t_box_ocr["t_ocr_res"]] for i, t_box_ocr in enumerate(t_rec_ocr_list) } - pred_html = plot_html_table(logi_points, cell_box_det_map) - polygons = np.array(polygons).reshape(-1, 8) - logi_points = np.array(logi_points) - elapse = time.perf_counter() - s + table_str = plot_html_table(logi_points, cell_box_det_map) + ocr_boxes_res = [ + box_4_2_poly_to_box_4_1(ori_ocr[0]) for ori_ocr in ocr_result + ] + sorted_ocr_boxes_res, _ = sorted_ocr_boxes(ocr_boxes_res) + sorted_polygons = [box_4_2_poly_to_box_4_1(box) for box in polygons] + sorted_logi_points = logi_points + table_elapse = time.perf_counter() - s except Exception: logging.warning(traceback.format_exc()) - return WiredTableOutput("", None, None, 0.0) - return WiredTableOutput(pred_html, polygons, logi_points, elapse) + return "", 0.0, None, None, None + return ( + table_str, + table_elapse, + sorted_polygons, + sorted_logi_points, + sorted_ocr_boxes_res, + adjust_dict + + ) def transform_res( self, @@ -166,6 +162,102 @@ def transform_res( res.append(dict_res) return res + def process_ocr_result(self, ocr_result): + # 删除第一行的字典,并调整其余字典的行数 + first_row_empty = [entry for entry in ocr_result if + entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0 and entry['t_ocr_res'][0][ + 1] == ''] + + if len(first_row_empty) == len( + [entry for entry in ocr_result if entry['t_logic_box'][0] == 0 and entry['t_logic_box'][1] == 0]): + # 如果第一行的所有单元格都为空,删除第一行 + ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][0] != 0 or entry['t_logic_box'][1] != 0] + # 调整剩余字典的行数 + for entry in ocr_result: + entry['t_logic_box'][0] -= 1 + entry['t_logic_box'][1] -= 1 + + # 删除第一列的字典,并调整其余字典的列数 + first_col_empty = [entry for entry in ocr_result if + entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0 and entry['t_ocr_res'][0][ + 1] == ''] + + if len(first_col_empty) == len( + [entry for entry in ocr_result if entry['t_logic_box'][2] == 0 and entry['t_logic_box'][3] == 0]): + # 如果第一列的所有单元格都为空,删除第一列 + ocr_result = [entry for entry in ocr_result if entry['t_logic_box'][2] != 0 or entry['t_logic_box'][3] != 0] + # 调整剩余字典的列数 + for entry in ocr_result: + entry['t_logic_box'][2] -= 1 + entry['t_logic_box'][3] -= 1 + + return ocr_result + + def adjust_table_cells(self, t_rec_ocr_list_dict): + """ + 调整表格单元格,去掉第一行和/或第一列的单元格, + 并更新剩余单元格的行列起始和结束位置。 + + 参数: + t_rec_ocr_list_dict (list): 原始表格单元格识别结果,格式为 + [ + { + "t_box": [xmin, ymin, xmax, ymax], + "t_logic_box": [row_start, row_end, col_start, col_end], + "t_ocr_res": [[box, text], ...] + }, + ... + ] + + 返回: + list: 调整后的表格单元格识别结果,格式与输入相同。 + """ + # 新的结果列表 + adjusted_result = [] + + # 记录是否第一行和第一列的单元格已被删除 + remove_first_row = False + remove_first_col = False + + # 检查并移除第一行 + if all(cell and not cell[1] for cell in t_rec_ocr_list_dict[0].get("t_ocr_res", [])): + remove_first_row = True + + # 检查并移除第一列 + if all(row.get("t_ocr_res") and not row["t_ocr_res"][0][1] for row in t_rec_ocr_list_dict): + remove_first_col = True + + # 遍历原始结果进行调整 + for i, row in enumerate(t_rec_ocr_list_dict): + adjusted_row = [] + + # 如果是第一行并且需要删除,跳过这行 + if remove_first_row and i == 0: + continue + + for j, cell in enumerate(row.get("t_ocr_res", [])): + # 如果是第一列并且需要删除,跳过这一列 + if remove_first_col and j == 0: + continue + + # 更新当前单元格的逻辑位置 + adjusted_cell = { + "t_box": row.get("t_box"), + "t_logic_box": [ + row["t_logic_box"][0] - 1 if i > 0 else row["t_logic_box"][0], + row["t_logic_box"][1] - 1 if i > 0 else row["t_logic_box"][1], + row["t_logic_box"][2] - 1 if j > 0 else row["t_logic_box"][2], + row["t_logic_box"][3] - 1 if j > 0 else row["t_logic_box"][3] + ], + "t_ocr_res": cell + } + adjusted_row.append(adjusted_cell) + + if adjusted_row: + adjusted_result.append(adjusted_row) + + return adjusted_result + def sort_and_gather_ocr_res(self, res): for i, dict_res in enumerate(res): _, sorted_idx = sorted_ocr_boxes( @@ -177,19 +269,30 @@ def sort_and_gather_ocr_res(self, res): ) return res - def fill_blank_rec( + def re_rec( self, img: np.ndarray, sorted_polygons: np.ndarray, cell_box_map: Dict[int, List[str]], + rec_again=True, ) -> Dict[int, List[Any]]: """找到poly对应为空的框,尝试将直接将poly框直接送到识别中""" for i in range(sorted_polygons.shape[0]): if cell_box_map.get(i): continue + if not rec_again: + box = sorted_polygons[i] + cell_box_map[i] = [[box, "", 1]] + continue + crop_img = get_rotate_crop_image(img, sorted_polygons[i]) + pad_img = cv2.copyMakeBorder( + crop_img, 5, 5, 100, 100, cv2.BORDER_CONSTANT, value=(255, 255, 255) + ) + rec_res, _ = self.ocr(pad_img, use_det=False, use_cls=True, use_rec=True) box = sorted_polygons[i] - cell_box_map[i] = [[box, "", 1]] - continue + text = [rec[0] for rec in rec_res] + scores = [rec[1] for rec in rec_res] + cell_box_map[i] = [[box, "".join(text), min(scores)]] return cell_box_map def re_rec_high_precise( @@ -222,28 +325,6 @@ def re_rec_high_precise( ] return cell_box_map - @staticmethod - def get_model_path( - model_type: str, model_path: Union[str, Path, None] - ) -> Union[str, Dict[str, str]]: - if model_path is not None: - return model_path - - model_url = KEY_TO_MODEL_URL.get(model_type, None) - if isinstance(model_url, str): - model_path = DownloadModel.download(model_url) - return model_path - - if isinstance(model_url, dict): - model_paths = {} - for k, url in model_url.items(): - model_paths[k] = DownloadModel.download( - url, save_model_name=f"{model_type}_{Path(url).name}" - ) - return model_paths - - raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.") - def main(): parser = argparse.ArgumentParser() @@ -251,17 +332,17 @@ def main(): args = parser.parse_args() try: - ocr_engine = importlib.import_module("rapidocr").RapidOCR() + ocr_engine = importlib.import_module("rapidocr_onnxruntime").RapidOCR() except ModuleNotFoundError as exc: raise ModuleNotFoundError( - "Please install the rapidocr by pip install rapidocr." + "Please install the rapidocr_onnxruntime by pip install rapidocr_onnxruntime." ) from exc - input_args = WiredTableInput() - table_rec = WiredTableRecognition(input_args) + + table_rec = WiredTableRecognition() ocr_result, _ = ocr_engine(args.img_path) - table_results = table_rec(args.img_path, ocr_result) - print(table_results.pred_html) - print(f"cost: {table_results.elapse:.5f}") + table_str, elapse = table_rec(args.img_path, ocr_result) + print(table_str) + print(f"cost: {elapse:.5f}") if __name__ == "__main__":