You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

257 lines
8.3 KiB
Python

# -*- encoding: utf-8 -*-
# @Author: SWHL
# @Contact: liekkaskono@163.com
import argparse
import copy
import importlib
import time
from dataclasses import asdict, dataclass
from enum import Enum
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import cv2
import numpy as np
from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
from .table_matcher import TableMatch
from .table_structure import TableStructurer, TableStructureUnitable
from markdownify import markdownify as md
# Absolute directory that contains this module (symlinks resolved).
root_dir = Path(__file__).resolve().parents[0]
# Module-wide logger built by rapid_table's Logger helper.
logger = Logger(logger_name=__name__).get_log()
class ModelType(Enum):
    """Table-structure models supported by ``RapidTable``.

    The ``value`` strings are the keys of ``KEY_TO_MODEL_URL`` and therefore the
    accepted values of ``RapidTableInput.model_type`` and ``--model_type``.
    """

    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"


# Base download URL. It already ends with "/", so file names are appended
# directly below; the previous f"{ROOT_URL}/..." form produced "master//<file>".
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"

# model_type value -> download location. Single-file models map to one URL;
# UniTable needs several files and maps to a dict keyed by component name.
KEY_TO_MODEL_URL = {
    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}en_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}ch_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.SLANETPLUS.value: f"{ROOT_URL}slanet-plus.onnx",
    ModelType.UNITABLE.value: {
        "encoder": f"{ROOT_URL}unitable/encoder.pth",
        "decoder": f"{ROOT_URL}unitable/decoder.pth",
        "vocab": f"{ROOT_URL}unitable/vocab.json",
    },
}
@dataclass
class RapidTableInput:
    """Configuration for ``RapidTable``.

    ``RapidTable`` forwards all of these fields (via ``asdict``) to the selected
    table-structure engine.
    """

    # One of the ModelType values / KEY_TO_MODEL_URL keys.
    model_type: Optional[str] = ModelType.SLANETPLUS.value
    # Explicit model location. When None, RapidTable resolves the default download
    # for model_type (a dict of per-file paths for "unitable").
    model_path: Union[str, Path, None, Dict[str, str]] = None
    # NOTE(review): use_cuda/device are only forwarded to the structure engine;
    # how they are interpreted is decided there, not in this file.
    use_cuda: bool = False
    device: str = "cpu"
@dataclass
class RapidTableOutput:
    """Result produced by ``RapidTable.__call__``."""

    # HTML table built by TableMatch from the structure tokens and OCR text.
    pred_html: Optional[str] = None
    # Predicted cell boxes with all-zero placeholder rows removed.
    cell_bboxes: Optional[np.ndarray] = None
    # Result of TableMatch.decode_logic_points for the predicted structure.
    logic_points: Optional[np.ndarray] = None
    # Seconds measured with time.perf_counter, starting after the image is loaded.
    elapse: Optional[float] = None
class RapidTable:
    """Table recognizer: predicts table structure and fills the cells with OCR text.

    ``__call__`` runs the structure model chosen by ``RapidTableInput.model_type``,
    matches OCR boxes/texts onto the predicted cells with ``TableMatch`` and returns
    a ``RapidTableOutput``.
    """

    def __init__(self, config: "RapidTableInput"):
        """Load the structure model, the cell matcher and, if installed, rapidocr.

        Args:
            config: model selection and runtime options. The caller's object is
                not modified; a shallow copy receives the resolved ``model_path``.

        Raises:
            ValueError: if ``config.model_type`` is not a key of ``KEY_TO_MODEL_URL``.
        """
        self.model_type = config.model_type
        if self.model_type not in KEY_TO_MODEL_URL:
            model_list = ",".join(KEY_TO_MODEL_URL)
            raise ValueError(
                f"{self.model_type} is not supported. The currently supported models are {model_list}."
            )

        # Work on a shallow copy: writing the resolved path back into the caller's
        # config meant that reusing it with another model_type silently loaded the
        # previously resolved model file.
        config = copy.copy(config)
        config.model_path = self.get_model_path(config.model_type, config.model_path)
        if self.model_type == ModelType.UNITABLE.value:
            self.table_structure = TableStructureUnitable(asdict(config))
        else:
            self.table_structure = TableStructurer(asdict(config))
        self.table_matcher = TableMatch()

        try:
            # rapidocr is optional; without it callers must always pass ocr_result.
            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
        except ModuleNotFoundError:
            self.ocr_engine = None

        self.load_img = LoadImage()

    def __call__(
        self,
        img_content: Union[str, np.ndarray, bytes, Path],
        ocr_result: Optional[List[Tuple[List[List[float]], str, float]]] = None,
    ) -> "RapidTableOutput":
        """Recognize the table in ``img_content``.

        Args:
            img_content: image path, raw bytes or decoded array (read by ``LoadImage``).
            ocr_result: optional OCR triples ``(box_points, text, score)``. When None,
                the built-in rapidocr engine is run on the image.

        Returns:
            RapidTableOutput with the predicted HTML, the non-placeholder cell boxes,
            the decoded logic points and the elapsed seconds (image loading excluded).

        Raises:
            ValueError: if ``ocr_result`` is None and rapidocr is not installed.
        """
        if self.ocr_engine is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
            )

        img = self.load_img(img_content)

        s = time.perf_counter()
        h, w = img.shape[:2]

        if ocr_result is None:
            ocr_output = self.ocr_engine(img)
            # NOTE(review): assumes the OCR output leaves ``boxes`` as None when no
            # text is found (zip(None, ...) used to raise TypeError) — confirm
            # against the installed rapidocr version.
            if ocr_output.boxes is None:
                ocr_result = []
            else:
                ocr_result = list(
                    zip(ocr_output.boxes, ocr_output.txts, ocr_output.scores)
                )
        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)

        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))

        # Restore the scaling of the cell boxes output by the slanet-plus model.
        if self.model_type == ModelType.SLANETPLUS.value:
            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)

        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)

        # Filter out placeholder bboxes (rows that are entirely zero).
        mask = ~np.all(cell_bboxes == 0, axis=1)
        cell_bboxes = cell_bboxes[mask]

        logic_points = self.table_matcher.decode_logic_points(pred_structures)
        elapse = time.perf_counter() - s
        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)

    def get_boxes_recs(
        self,
        ocr_result: List[Tuple[List[List[float]], str, float]],
        h: int,
        w: int,
    ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
        """Convert OCR triples into axis-aligned boxes plus (text, score) pairs.

        Each polygon is reduced to ``[x_min, y_min, x_max, y_max]``, padded by one
        pixel on every side and clipped to ``[0, w]`` x ``[0, h]``.

        Args:
            ocr_result: sequence of ``(box_points, text, score)``; ``box_points`` is
                indexed as ``[:, 0]`` (x) and ``[:, 1]`` (y) once turned into an array.
            h: image height used to clip y coordinates.
            w: image width used to clip x coordinates.

        Returns:
            ``(boxes, rec_res)``: an ``(N, 4)`` array and a list of ``(text, score)``.
            Empty input yields a ``(0, 4)`` array and an empty list instead of the
            former "not enough values to unpack" error.
        """
        ocr_result = list(ocr_result)
        if not ocr_result:
            # NOTE(review): assumes TableMatch copes with zero OCR boxes — confirm.
            return np.zeros((0, 4)), []

        dt_boxes, texts, scores = zip(*ocr_result)
        rec_res = list(zip(texts, scores))

        r_boxes = []
        for box in dt_boxes:
            box = np.array(box)
            x_min = max(0, box[:, 0].min() - 1)
            x_max = min(w, box[:, 0].max() + 1)
            y_min = max(0, box[:, 1].min() - 1)
            y_max = min(h, box[:, 1].max() + 1)
            r_boxes.append([x_min, y_min, x_max, y_max])
        return np.array(r_boxes), rec_res

    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
        """Rescale SLANet-plus cell boxes in place and return them.

        With ``ratio = min(488 / h, 488 / w)``, x coordinates (even columns) are
        multiplied by ``488 / (w * ratio)`` and y coordinates (odd columns) by
        ``488 / (h * ratio)``.
        """
        h, w = img.shape[:2]
        # NOTE(review): 488 presumably is the model's square input size — confirm.
        resized = 488
        ratio = min(resized / h, resized / w)
        w_ratio = resized / (w * ratio)
        h_ratio = resized / (h * ratio)
        cell_bboxes[:, 0::2] *= w_ratio
        cell_bboxes[:, 1::2] *= h_ratio
        return cell_bboxes

    @staticmethod
    def get_model_path(
        model_type: str, model_path: Union[str, Path, None, Dict[str, str]]
    ) -> Union[str, Path, Dict[str, str]]:
        """Return ``model_path`` unchanged, or resolve the default for ``model_type``.

        Single-URL models return the result of ``DownloadModel.download``. Models
        with several files (UniTable) return a dict mapping each component key to
        its downloaded path, saved as ``"<model_type>_<file name>"``.

        Raises:
            ValueError: if the ``KEY_TO_MODEL_URL`` entry is neither str nor dict
                (including a missing entry).
        """
        if model_path is not None:
            return model_path

        model_url = KEY_TO_MODEL_URL.get(model_type, None)
        if isinstance(model_url, str):
            return DownloadModel.download(model_url)

        if isinstance(model_url, dict):
            return {
                k: DownloadModel.download(
                    url, save_model_name=f"{model_type}_{Path(url).name}"
                )
                for k, url in model_url.items()
            }

        raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
def parse_args(arg_list: Optional[List[str]] = None):
    """Parse the command-line options of the table-recognition demo.

    Args:
        arg_list: arguments to parse; None lets argparse read ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``vis``, ``img_path`` and ``model_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--vis",
        action="store_true",
        default=False,
        help="Whether to visualize the table recognition results.",
    )
    parser.add_argument(
        "-img", "--img_path", type=str, required=True, help="Path to the table image."
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        default=ModelType.SLANETPLUS.value,
        choices=list(KEY_TO_MODEL_URL),
        help="Table-structure model to use.",
    )
    return parser.parse_args(arg_list)
# Disabled: explicit rapidocr engine creation. RapidTable now tries to create a
# rapidocr engine itself and falls back to requiring ocr_result when it is absent.
# try:
#     ocr_engine = importlib.import_module("rapidocr").RapidOCR()
# except ModuleNotFoundError as exc:
#     raise ModuleNotFoundError(
#         "Please install the rapidocr by pip install rapidocr"
#     ) from exc
# Shared engine used by table2md_pipeline, built once at import time. Because
# model_path is left as None, RapidTable resolves the SLANet-plus model through
# DownloadModel.download during this import (caching behavior is not visible here).
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(input_args)
def table2md_pipeline(img, ocr_result=None):
    """Recognize the table in ``img`` and return it as Markdown.

    Args:
        img: anything ``RapidTable.__call__`` accepts (path, bytes or ndarray).
        ocr_result: optional precomputed OCR triples ``(box, text, score)``. When
            None (the default), ``table_engine`` runs rapidocr itself and raises
            ValueError if rapidocr is not installed.

    Returns:
        Markdown converted from the predicted HTML table, or ``""`` when no HTML
        was produced (``RapidTableOutput.pred_html`` is Optional).
    """
    table_results = table_engine(img, ocr_result)
    html_content = table_results.pred_html
    if not html_content:
        return ""
    return md(html_content)
# def main(arg_list: Optional[List[str]] = None):
# args = parse_args(arg_list)
# try:
# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
# except ModuleNotFoundError as exc:
# raise ModuleNotFoundError(
# "Please install the rapidocr by pip install rapidocr"
# ) from exc
# input_args = RapidTableInput(model_type=args.model_type)
# table_engine = RapidTable(input_args)
# img = cv2.imread(args.img_path)
# rapid_ocr_output = ocr_engine(img)
# ocr_result = list(
# zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
# )
# table_results = table_engine(img, ocr_result)
# print(table_results.pred_html)
# viser = VisTable()
# if args.vis:
# img_path = Path(args.img_path)
# save_dir = img_path.resolve().parent
# save_html_path = save_dir / f"{Path(img_path).stem}.html"
# save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
# viser(img_path, table_results, save_html_path, save_drawed_path)
if __name__ == "__main__":
    import sys

    # Image to convert: first CLI argument, falling back to the original sample
    # path so running without arguments behaves as before.
    img_path = sys.argv[1] if len(sys.argv) > 1 else (
        "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/11.jpg"
    )
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread reports an unreadable/missing file by returning None.
        raise SystemExit(f"Could not read image: {img_path}")
    # ocr_result=None lets RapidTable run its own OCR; passing it explicitly keeps
    # this call valid even when table2md_pipeline requires the argument.
    res = table2md_pipeline(img, None)
    print('*' * 50)
    print(res)