You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
257 lines
8.3 KiB
Python
257 lines
8.3 KiB
Python
# -*- encoding: utf-8 -*-
|
|
# @Author: SWHL
|
|
# @Contact: liekkaskono@163.com
|
|
import argparse
|
|
import copy
|
|
import importlib
|
|
import time
|
|
from dataclasses import asdict, dataclass
|
|
from enum import Enum
|
|
from pathlib import Path
|
|
from typing import Dict, List, Optional, Tuple, Union
|
|
|
|
import cv2
|
|
import numpy as np
|
|
|
|
from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
|
|
|
|
from .table_matcher import TableMatch
|
|
from .table_structure import TableStructurer, TableStructureUnitable
|
|
from markdownify import markdownify as md
|
|
|
|
# Logger for this module, created through the project's Logger helper.
logger = Logger(logger_name=__name__).get_log()

# Absolute path of the directory holding this source file.
root_dir = Path(__file__).resolve().parents[0]
|
|
|
|
|
|
class ModelType(Enum):
    """Table-structure models supported by RapidTable.

    Each member's ``value`` is the string key used in ``KEY_TO_MODEL_URL``
    and accepted as ``RapidTableInput.model_type``.
    """

    PPSTRUCTURE_EN = "ppstructure_en"
    PPSTRUCTURE_ZH = "ppstructure_zh"
    SLANETPLUS = "slanet_plus"
    UNITABLE = "unitable"


# Base URL of the model repository. It already ends with "/", so the file
# names below are appended WITHOUT a leading slash (otherwise every URL
# would contain "master//<file>").
ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"

# ModelType value -> download location. The ONNX models are a single file;
# UniTable needs several files (encoder/decoder weights and a vocab), so it
# maps to a dict of component name -> URL.
KEY_TO_MODEL_URL = {
    ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}en_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}ch_ppstructure_mobile_v2_SLANet.onnx",
    ModelType.SLANETPLUS.value: f"{ROOT_URL}slanet-plus.onnx",
    ModelType.UNITABLE.value: {
        "encoder": f"{ROOT_URL}unitable/encoder.pth",
        "decoder": f"{ROOT_URL}unitable/decoder.pth",
        "vocab": f"{ROOT_URL}unitable/vocab.json",
    },
}
|
|
|
|
|
|
@dataclass
class RapidTableInput:
    """Settings used to construct a RapidTable engine.

    RapidTable converts this object with ``dataclasses.asdict`` and passes the
    resulting dict to its table-structure model.
    """

    # Key into KEY_TO_MODEL_URL choosing the structure model.
    model_type: Optional[str] = ModelType.SLANETPLUS.value
    # Local weights: one path for single-file models, or a dict of paths for
    # multi-file models. When None, RapidTable resolves/downloads them.
    model_path: Optional[Union[str, Path, Dict[str, str]]] = None
    # Device options forwarded to the structure model as-is
    # (NOTE(review): how use_cuda and device interact is decided there).
    use_cuda: bool = False
    device: str = "cpu"
|
|
|
|
|
@dataclass
class RapidTableOutput:
    """Result returned by one RapidTable call."""

    # HTML produced by TableMatch from the predicted structure and OCR text.
    pred_html: Union[str, None] = None
    # Predicted cell boxes with all-zero placeholder rows removed.
    cell_bboxes: Union[np.ndarray, None] = None
    # Value returned by TableMatch.decode_logic_points for the structure.
    logic_points: Union[np.ndarray, None] = None
    # Seconds spent after image loading (OCR, structure, matching).
    elapse: Union[float, None] = None
|
|
|
|
|
|
class RapidTable:
    """Recognise a table image and return its HTML plus cell geometry.

    Steps: OCR (caller-supplied, or via the optional ``rapidocr`` package),
    table-structure prediction with the configured model, then matching of
    OCR text boxes to the predicted cells with ``TableMatch``.
    """

    def __init__(self, config: RapidTableInput):
        """Build the structure model, the matcher and, if available, OCR.

        Args:
            config: Model selection and device settings. The object is not
                modified; its model path is resolved on a copy.

        Raises:
            ValueError: if ``config.model_type`` is not a key of
                ``KEY_TO_MODEL_URL``.
        """
        self.model_type = config.model_type
        if self.model_type not in KEY_TO_MODEL_URL:
            model_list = ",".join(KEY_TO_MODEL_URL)
            raise ValueError(
                f"{self.model_type} is not supported. The currently supported models are {model_list}."
            )

        # Resolve model_path on a copy so the caller's config stays untouched.
        config = copy.copy(config)
        config.model_path = self.get_model_path(config.model_type, config.model_path)
        if self.model_type == ModelType.UNITABLE.value:
            self.table_structure = TableStructureUnitable(asdict(config))
        else:
            self.table_structure = TableStructurer(asdict(config))

        self.table_matcher = TableMatch()

        # rapidocr is optional; without it callers must pass ocr_result to
        # __call__ (enforced there).
        try:
            self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
        except ModuleNotFoundError:
            self.ocr_engine = None

        self.load_img = LoadImage()

    def __call__(
        self,
        img_content: Union[str, np.ndarray, bytes, Path],
        ocr_result: Optional[List[Tuple[List[List[float]], str, float]]] = None,
    ) -> RapidTableOutput:
        """Run table recognition on a single image.

        Args:
            img_content: Image in any form LoadImage accepts (str/Path,
                ndarray or bytes per the annotation).
            ocr_result: Optional pre-computed OCR as ``(box, text, score)``
                triples, where ``box`` is a sequence of ``[x, y]`` points.
                When None, the rapidocr engine is run on the image.

        Returns:
            RapidTableOutput with the HTML, non-placeholder cell boxes,
            logical cell coordinates and elapsed seconds (image loading is
            not timed).

        Raises:
            ValueError: if ``ocr_result`` is None and rapidocr is unavailable.
        """
        if self.ocr_engine is None and ocr_result is None:
            raise ValueError(
                "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
            )

        img = self.load_img(img_content)

        s = time.perf_counter()
        h, w = img.shape[:2]

        if ocr_result is None:
            ocr_result = self._run_ocr(img)
        # With no OCR text, dt_boxes is a (0, 4) array and rec_res is empty
        # (NOTE(review): confirm TableMatch copes with zero text boxes).
        dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)

        pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))

        # SLANet-plus boxes need rescaling back to the original image size.
        if self.model_type == ModelType.SLANETPLUS.value:
            cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)

        pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)

        # Drop placeholder boxes (rows that are entirely zero).
        mask = ~np.all(cell_bboxes == 0, axis=1)
        cell_bboxes = cell_bboxes[mask]

        logic_points = self.table_matcher.decode_logic_points(pred_structures)
        elapse = time.perf_counter() - s
        return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)

    def _run_ocr(self, img: np.ndarray) -> List[tuple]:
        """Run the rapidocr engine and return ``(box, text, score)`` triples."""
        ocr_output = self.ocr_engine(img)
        # rapidocr may report boxes as None when it detects no text
        # (NOTE(review): confirm for the pinned rapidocr version); treat that
        # as "no OCR results" instead of crashing inside zip().
        if ocr_output.boxes is None:
            return []
        return list(zip(ocr_output.boxes, ocr_output.txts, ocr_output.scores))

    def get_boxes_recs(
        self, ocr_result: List[Tuple[List[List[float]], str, float]], h: int, w: int
    ) -> Tuple[np.ndarray, List[Tuple[str, float]]]:
        """Turn OCR triples into padded bounding boxes and (text, score) pairs.

        Each box's points are reduced to their bounding rectangle, padded by
        1 px on every side and clipped to the ``w`` x ``h`` image.

        Args:
            ocr_result: ``(box, text, score)`` triples; may be empty.
            h: Image height in pixels.
            w: Image width in pixels.

        Returns:
            ``(dt_boxes, rec_res)``: an (N, 4) array of
            ``[x_min, y_min, x_max, y_max]`` and a list of ``(text, score)``.
            Empty input yields a (0, 4) array and ``[]``.
        """
        ocr_items = list(ocr_result)
        if not ocr_items:
            # zip(*[]) would fail to unpack into three names.
            return np.zeros((0, 4)), []

        dt_boxes, texts, scores = list(zip(*ocr_items))
        rec_res = list(zip(texts, scores))

        r_boxes = []
        for box in dt_boxes:
            box = np.array(box)
            x_min = max(0, box[:, 0].min() - 1)
            x_max = min(w, box[:, 0].max() + 1)
            y_min = max(0, box[:, 1].min() - 1)
            y_max = min(h, box[:, 1].max() + 1)
            r_boxes.append([x_min, y_min, x_max, y_max])
        return np.array(r_boxes), rec_res

    def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
        """Rescale SLANet-plus cell boxes into original-image coordinates.

        NOTE(review): assumes the structure model resized the image so that
        it fits a 488 px square (scale ``ratio``); the factors below undo
        that. ``cell_bboxes`` is modified in place and also returned.
        """
        h, w = img.shape[:2]
        resized = 488
        ratio = min(resized / h, resized / w)
        w_ratio = resized / (w * ratio)
        h_ratio = resized / (h * ratio)
        cell_bboxes[:, 0::2] *= w_ratio  # even columns: x values
        cell_bboxes[:, 1::2] *= h_ratio  # odd columns: y values
        return cell_bboxes

    @staticmethod
    def get_model_path(
        model_type: str, model_path: Union[str, Path, None, Dict[str, str]]
    ) -> Union[str, Path, Dict[str, str]]:
        """Return local model file(s), downloading them when not given.

        Args:
            model_type: Key into ``KEY_TO_MODEL_URL``.
            model_path: Caller-supplied path(s); returned unchanged if not None.

        Returns:
            A single path for single-URL models, or a dict of component
            name -> path for multi-file models.

        Raises:
            ValueError: if ``model_type`` has no str or dict URL entry.
        """
        if model_path is not None:
            return model_path

        model_url = KEY_TO_MODEL_URL.get(model_type, None)
        if isinstance(model_url, str):
            return DownloadModel.download(model_url)

        if isinstance(model_url, dict):
            # Prefix file names with the model type so components of
            # different models cannot overwrite each other.
            return {
                k: DownloadModel.download(
                    url, save_model_name=f"{model_type}_{Path(url).name}"
                )
                for k, url in model_url.items()
            }

        raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
|
|
|
|
|
|
def parse_args(arg_list: Optional[List[str]] = None):
    """Parse command-line options for the table recognition demo.

    Args:
        arg_list: Arguments to parse; None makes argparse read ``sys.argv``.

    Returns:
        argparse.Namespace with ``vis``, ``img_path`` and ``model_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--vis",
        action="store_true",
        default=False,
        help="Whether to visualize the table recognition results.",
    )
    parser.add_argument(
        "-img", "--img_path", type=str, required=True, help="Path to the table image."
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        default=ModelType.SLANETPLUS.value,
        choices=list(KEY_TO_MODEL_URL),
        help="Table structure model to use.",
    )
    return parser.parse_args(arg_list)
|
|
|
|
|
|
# Shared engine used by table2md_pipeline. It is created at import time, so
# importing this module builds the SLANet-plus structure model (downloading
# its weights if no local path is configured) and, when the rapidocr package
# is installed, an OCR engine as well.
input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
table_engine = RapidTable(config=input_args)
|
|
|
|
def table2md_pipeline(img, ocr_result=None):
    """Recognise a table image with the shared engine and return Markdown.

    Args:
        img: Image in any form RapidTable accepts (path, bytes or ndarray).
        ocr_result: Optional pre-computed OCR as ``(box, text, score)``
            triples. When None (the default), ``table_engine`` runs rapidocr
            itself, which then has to be installed.

    Returns:
        The predicted table HTML converted to Markdown with markdownify.
    """
    table_results = table_engine(img, ocr_result)
    return md(table_results.pred_html)
|
|
|
|
|
|
# def main(arg_list: Optional[List[str]] = None):
|
|
# args = parse_args(arg_list)
|
|
|
|
# try:
|
|
# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
|
|
# except ModuleNotFoundError as exc:
|
|
# raise ModuleNotFoundError(
|
|
# "Please install the rapidocr by pip install rapidocr"
|
|
# ) from exc
|
|
|
|
# input_args = RapidTableInput(model_type=args.model_type)
|
|
# table_engine = RapidTable(input_args)
|
|
|
|
# img = cv2.imread(args.img_path)
|
|
|
|
# rapid_ocr_output = ocr_engine(img)
|
|
# ocr_result = list(
|
|
# zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
|
|
# )
|
|
# table_results = table_engine(img, ocr_result)
|
|
# print(table_results.pred_html)
|
|
|
|
# viser = VisTable()
|
|
# if args.vis:
|
|
# img_path = Path(args.img_path)
|
|
|
|
# save_dir = img_path.resolve().parent
|
|
# save_html_path = save_dir / f"{Path(img_path).stem}.html"
|
|
# save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
|
|
# viser(img_path, table_results, save_html_path, save_drawed_path)
|
|
|
|
|
|
if __name__ == "__main__":
    # Usage: python <this file> -img path/to/table.jpg
    # Only --img_path is used: table2md_pipeline always runs the module-level
    # SLANet-plus engine and does no visualisation, so --model_type and
    # --vis are accepted but ignored here.
    cli_args = parse_args()
    image = cv2.imread(cli_args.img_path)
    if image is None:
        # cv2.imread returns None instead of raising for missing/unreadable files.
        raise SystemExit(f"Failed to read image: {cli_args.img_path}")
    # No OCR supplied: the engine falls back to rapidocr.
    res = table2md_pipeline(image, None)
    print("*" * 50)
    print(res)
|