diff --git a/.env b/.env
new file mode 100644
index 0000000..fc7ebc4
--- /dev/null
+++ b/.env
@@ -0,0 +1,7 @@
+POSTGRESQL_HOST=
+POSTGRESQL_PORT=
+POSTGRESQL_USERNAME=
+POSTGRESQL_PASSWORD=
+POSTGRESQL_DATABASE=
+
+VISUAL=0
\ No newline at end of file
diff --git a/.env.dev b/.env.dev
new file mode 100644
index 0000000..817e26a
--- /dev/null
+++ b/.env.dev
@@ -0,0 +1,7 @@
+POSTGRESQL_HOST=192.168.10.137
+POSTGRESQL_PORT=54321
+POSTGRESQL_USERNAME=postgres
+POSTGRESQL_PASSWORD=123456
+POSTGRESQL_DATABASE=pdf-qa
+
+VISUAL=1
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..fe192d2
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,5 @@
+venv
+*.pdf
+.vscode
+visual_images/*.jpg
+__pycache__
\ No newline at end of file
diff --git a/README.md b/README.md
index e69de29..85fb2e9 100644
--- a/README.md
+++ b/README.md
@@ -0,0 +1,21 @@
+# Environment Setup
+
+```
+# 1. Dependencies required for converting PDFs to images
+apt install -y poppler-utils
+# 2. Work around a Paddle compatibility issue
+export FLAGS_enable_pir_api=0
+# 3. Environment configuration required after installing paddlepaddle-gpu
+apt install libcudnn8
+apt install libcudnn8-dev
+echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.11/site-packages/nvidia/cublas/lib/" >> /etc/profile
+# 4. Manually download the models used by MinerU; download paths:
+# model_dir is: /root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models
+# layoutreader_model_dir is: /root/.cache/modelscope/hub/models/ppaanngggg/layoutreader
+# The configuration file has been configured successfully, the path is: /root/magic-pdf.json
+# Link the magic-pdf.json in this project to /root/magic-pdf.json
+ln -s `pwd`/magic-pdf.json /root/magic-pdf.json
+python download_MinerU_models.py
+# 5. Dependencies required for connecting to PostgreSQL from Python
+apt install postgresql postgresql-contrib libpq-dev
+```
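Note: the `.env` / `.env.dev` files above only declare connection settings; this commit does not show how they are consumed. A minimal sketch, assuming `python-dotenv` and `psycopg2-binary` are installed (both are assumptions, not pinned by this repo):

```python
# Hypothetical usage of the .env keys added above (python-dotenv + psycopg2 assumed).
import os

import psycopg2
from dotenv import load_dotenv

load_dotenv(".env.dev")  # or ".env" for the production values

conn = psycopg2.connect(
    host=os.environ["POSTGRESQL_HOST"],
    port=int(os.environ["POSTGRESQL_PORT"]),
    user=os.environ["POSTGRESQL_USERNAME"],
    password=os.environ["POSTGRESQL_PASSWORD"],
    dbname=os.environ["POSTGRESQL_DATABASE"],
)
with conn, conn.cursor() as cur:
    cur.execute("SELECT version();")
    print(cur.fetchone())
conn.close()
```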
diff --git a/download_MinerU_models.py b/download_MinerU_models.py
new file mode 100644
index 0000000..626473d
--- /dev/null
+++ b/download_MinerU_models.py
@@ -0,0 +1,67 @@
+import json
+import shutil
+import os
+
+import requests
+from modelscope import snapshot_download
+
+
+def download_json(url):
+ # Download the JSON file
+ response = requests.get(url)
+ response.raise_for_status() # Check whether the request succeeded
+ return response.json()
+
+
+def download_and_modify_json(url, local_filename, modifications):
+ if os.path.exists(local_filename):
+ data = json.load(open(local_filename))
+ config_version = data.get('config_version', '0.0.0')
+ if config_version < '1.2.0':
+ data = download_json(url)
+ else:
+ data = download_json(url)
+
+ # Apply the requested modifications
+ for key, value in modifications.items():
+ data[key] = value
+
+ # Save the modified content
+ with open(local_filename, 'w', encoding='utf-8') as f:
+ json.dump(data, f, ensure_ascii=False, indent=4)
+
+
+if __name__ == '__main__':
+ mineru_patterns = [
+ # "models/Layout/LayoutLMv3/*",
+ "models/Layout/YOLO/*",
+ "models/MFD/YOLO/*",
+ "models/MFR/unimernet_hf_small_2503/*",
+ "models/OCR/paddleocr_torch/*",
+ # "models/TabRec/TableMaster/*",
+ # "models/TabRec/StructEqTable/*",
+ ]
+ model_dir = snapshot_download('opendatalab/PDF-Extract-Kit-1.0', allow_patterns=mineru_patterns)
+ layoutreader_model_dir = snapshot_download('ppaanngggg/layoutreader')
+ model_dir = model_dir + '/models'
+ print(f'model_dir is: {model_dir}')
+ print(f'layoutreader_model_dir is: {layoutreader_model_dir}')
+
+ # paddleocr_model_dir = model_dir + '/OCR/paddleocr'
+ # user_paddleocr_dir = os.path.expanduser('~/.paddleocr')
+ # if os.path.exists(user_paddleocr_dir):
+ # shutil.rmtree(user_paddleocr_dir)
+ # shutil.copytree(paddleocr_model_dir, user_paddleocr_dir)
+
+ json_url = 'https://gcore.jsdelivr.net/gh/opendatalab/MinerU@master/magic-pdf.template.json'
+ config_file_name = 'magic-pdf.json'
+ home_dir = os.path.expanduser('~')
+ config_file = os.path.join(home_dir, config_file_name)
+
+ json_mods = {
+ 'models-dir': model_dir,
+ 'layoutreader-model-dir': layoutreader_model_dir,
+ }
+
+ download_and_modify_json(json_url, config_file, json_mods)
+ print(f'The configuration file has been configured successfully, the path is: {config_file}')
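Note: this script mirrors MinerU's model-download helper: it pulls the weights from ModelScope and writes `models-dir` / `layoutreader-model-dir` into `~/magic-pdf.json`. A quick sanity check after running it (a sketch, not part of the script):

```python
# Verify that ~/magic-pdf.json points at directories that actually exist.
import json
import os

config_path = os.path.expanduser("~/magic-pdf.json")
with open(config_path, encoding="utf-8") as f:
    cfg = json.load(f)

for key in ("models-dir", "layoutreader-model-dir"):
    path = cfg.get(key)
    status = "ok" if path and os.path.isdir(path) else "MISSING"
    print(f"{key}: {path} [{status}]")
```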
diff --git a/helper/content_recognition/main.py b/helper/content_recognition/main.py
new file mode 100644
index 0000000..b51f2c9
--- /dev/null
+++ b/helper/content_recognition/main.py
@@ -0,0 +1,116 @@
+from typing import List
+import cv2
+from .utils import scanning_document_classify, text_rec, table_rec, scanning_document_rec, markdown_rec, assign_tables_to_titles, remove_watermark
+from tqdm import tqdm
+
+
+class LayoutRecognitionResult(object):
+
+ def __init__(self, clsid, content, box, table_title=None):
+ self.clsid = clsid
+ self.content = content
+ self.box = box
+ self.table_title = table_title
+
+ def __repr__(self):
+ return f"[{self.clsid}] {self.content}"
+
+
+expand_pixel = 10
+
+
+def rec(page_detection_results, tmp_dir) -> List[List[LayoutRecognitionResult]]:
+ page_recognition_results = []
+
+ for page_idx in tqdm(range(len(page_detection_results)), 'Text recognition'):
+ results = page_detection_results[page_idx]
+ if not results.boxes:
+ page_recognition_results.append([])
+ continue
+
+ img = cv2.imread(results.image_path)
+ h, w = img.shape[:2]
+
+ for layout in results.boxes:
+ # Expand the box slightly outward to make OCR easier
+ layout.pos[0] -= expand_pixel
+ layout.pos[1] -= expand_pixel
+ layout.pos[2] += expand_pixel
+ layout.pos[3] += expand_pixel
+
+ layout.pos[0] = max(0, layout.pos[0])
+ layout.pos[1] = max(0, layout.pos[1])
+ layout.pos[2] = min(w, layout.pos[2])
+ layout.pos[3] = min(h, layout.pos[3])
+
+ outputs = []
+
+ is_scanning_document = False
+ for layout in results.boxes:
+ x1, y1, x2, y2 = layout.pos
+ x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+ layout_img = img[y1: y2, x1: x2]
+ content = None
+ if layout.clsid == 0:
+ # text
+ content = markdown_rec(layout_img)
+ elif layout.clsid == 2:
+ # figure
+ if scanning_document_classify(layout_img):
+ # Scanned document
+ is_scanning_document = True
+ content, layout_img = scanning_document_rec(layout_img)
+ source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+ elif layout.clsid == 4:
+ # table
+ if scanning_document_classify(layout_img):
+ is_scanning_document = True
+ content, layout_img = scanning_document_rec(layout_img)
+ source_page_no_watermark_img = remove_watermark(cv2.imread(f'{tmp_dir}/{page_idx + 1}.jpg'))
+ else:
+ content = table_rec(layout_img)
+ elif layout.clsid == 5:
+ # table caption
+ ocr_results = text_rec(layout_img)
+ content = ''
+ for o in ocr_results:
+ content += f'{o}\n'
+ while content.endswith('\n'):
+ content = content[:-1]
+
+ if not content:
+ continue
+
+ result = LayoutRecognitionResult(layout.clsid, content, layout.pos)
+ outputs.append(result)
+
+ if is_scanning_document and len(outputs) == 1:
+ # For scanned documents, additionally extract the title
+ h, w = source_page_no_watermark_img.shape[:2]
+ if h > w:
+ title_img = source_page_no_watermark_img[:360, :w, ...]
+
+ # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
+ # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 0), (w, 360), (255, 255, 0), 3)
+ # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
+ else:
+ title_img = source_page_no_watermark_img[:410, :w, ...]
+
+ # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}.jpg', title_img)
+ # vis = cv2.rectangle(source_page_no_watermark_img.copy(), (0, 310), (w, 410), (255, 255, 0), 3)
+ # cv2.imwrite(f'/mnt/pdf2markdown/temp/{page_idx + 1}-vis.jpg', vis)
+ title = text_rec(title_img)
+ outputs[0].table_title = '\n'.join(title)
+ else:
+ # Automatically assign each table to its nearest caption
+ assign_tables_to_titles(outputs)
+
+ # Table captions are no longer needed after assignment
+ outputs = [_ for _ in outputs if _.clsid != 5]
+ # Map clsid 2 (figure) and 4 (table) to the database enum value 1 (table)
+ for o in outputs:
+ if o.clsid == 2 or o.clsid == 4:
+ o.clsid = 1
+ page_recognition_results.append(outputs)
+
+ return page_recognition_results
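Note: rec() expects the output of an upstream layout-detection step that is not part of this file: one result object per page carrying an `image_path` and a list of `boxes`, each with a `clsid` and a `pos` box. A hedged illustration of that contract with stand-in objects (paths and coordinates are invented, and importing the module also pulls in the recognition models via `.utils`):

```python
# Illustration only: SimpleNamespace stands in for the real detection result objects.
from types import SimpleNamespace

from helper.content_recognition.main import rec

# clsid follows the mapping used above: 0=text, 2=figure, 4=table, 5=table caption.
box = SimpleNamespace(clsid=0, pos=[35, 40, 980, 310])
page = SimpleNamespace(image_path="tmp/1.jpg", boxes=[box])  # hypothetical page image

for page_results in rec([page], tmp_dir="tmp"):
    for r in page_results:
        print(r.clsid, r.content)
```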
diff --git a/helper/content_recognition/rapid_table_pipeline/__init__.py b/helper/content_recognition/rapid_table_pipeline/__init__.py
new file mode 100644
index 0000000..702f7d0
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/__init__.py
@@ -0,0 +1,5 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from .main import RapidTable, RapidTableInput
+from .utils import VisTable
diff --git a/helper/content_recognition/rapid_table_pipeline/main.py b/helper/content_recognition/rapid_table_pipeline/main.py
new file mode 100644
index 0000000..3dcfef4
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/main.py
@@ -0,0 +1,258 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import argparse
+import copy
+import importlib
+import time
+from dataclasses import asdict, dataclass
+from enum import Enum
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import cv2
+import numpy as np
+
+from rapid_table.utils import DownloadModel, LoadImage, Logger, VisTable
+
+from .table_matcher import TableMatch
+from .table_structure import TableStructurer, TableStructureUnitable
+from markdownify import markdownify as md
+
+logger = Logger(logger_name=__name__).get_log()
+root_dir = Path(__file__).resolve().parent
+
+
+class ModelType(Enum):
+ PPSTRUCTURE_EN = "ppstructure_en"
+ PPSTRUCTURE_ZH = "ppstructure_zh"
+ SLANETPLUS = "slanet_plus"
+ UNITABLE = "unitable"
+
+
+ROOT_URL = "https://www.modelscope.cn/models/RapidAI/RapidTable/resolve/master/"
+KEY_TO_MODEL_URL = {
+ ModelType.PPSTRUCTURE_EN.value: f"{ROOT_URL}/en_ppstructure_mobile_v2_SLANet.onnx",
+ ModelType.PPSTRUCTURE_ZH.value: f"{ROOT_URL}/ch_ppstructure_mobile_v2_SLANet.onnx",
+ ModelType.SLANETPLUS.value: f"{ROOT_URL}/slanet-plus.onnx",
+ ModelType.UNITABLE.value: {
+ "encoder": f"{ROOT_URL}/unitable/encoder.pth",
+ "decoder": f"{ROOT_URL}/unitable/decoder.pth",
+ "vocab": f"{ROOT_URL}/unitable/vocab.json",
+ },
+}
+
+
+@dataclass
+class RapidTableInput:
+ model_type: Optional[str] = ModelType.SLANETPLUS.value
+ model_path: Union[str, Path, None, Dict[str, str]] = None
+ use_cuda: bool = False
+ device: str = "cpu"
+
+
+@dataclass
+class RapidTableOutput:
+ pred_html: Optional[str] = None
+ cell_bboxes: Optional[np.ndarray] = None
+ logic_points: Optional[np.ndarray] = None
+ elapse: Optional[float] = None
+
+
+class RapidTable:
+ def __init__(self, config: RapidTableInput):
+ self.model_type = config.model_type
+ if self.model_type not in KEY_TO_MODEL_URL:
+ model_list = ",".join(KEY_TO_MODEL_URL)
+ raise ValueError(
+ f"{self.model_type} is not supported. The currently supported models are {model_list}."
+ )
+
+ config.model_path = self.get_model_path(config.model_type, config.model_path)
+ if self.model_type == ModelType.UNITABLE.value:
+ self.table_structure = TableStructureUnitable(asdict(config))
+ else:
+ self.table_structure = TableStructurer(asdict(config))
+
+ self.table_matcher = TableMatch()
+
+ try:
+ self.ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+ except ModuleNotFoundError:
+ self.ocr_engine = None
+
+ self.load_img = LoadImage()
+
+ def __call__(
+ self,
+ img_content: Union[str, np.ndarray, bytes, Path],
+ ocr_result: List[Union[List[List[float]], str, str]] = None,
+ ) -> RapidTableOutput:
+ if self.ocr_engine is None and ocr_result is None:
+ raise ValueError(
+ "One of two conditions must be met: ocr_result is not empty, or rapidocr is installed."
+ )
+
+ img = self.load_img(img_content)
+
+ s = time.perf_counter()
+ h, w = img.shape[:2]
+
+ if ocr_result is None:
+ ocr_result = self.ocr_engine(img)
+ ocr_result = list(
+ zip(
+ ocr_result.boxes,
+ ocr_result.txts,
+ ocr_result.scores,
+ )
+ )
+ dt_boxes, rec_res = self.get_boxes_recs(ocr_result, h, w)
+
+ pred_structures, cell_bboxes, _ = self.table_structure(copy.deepcopy(img))
+
+ # Rescale the boxes produced by the slanet-plus model back to the original image size
+ if self.model_type == ModelType.SLANETPLUS.value:
+ cell_bboxes = self.adapt_slanet_plus(img, cell_bboxes)
+
+ pred_html = self.table_matcher(pred_structures, cell_bboxes, dt_boxes, rec_res)
+
+ # Filter out placeholder bboxes
+ mask = ~np.all(cell_bboxes == 0, axis=1)
+ cell_bboxes = cell_bboxes[mask]
+
+ logic_points = self.table_matcher.decode_logic_points(pred_structures)
+ elapse = time.perf_counter() - s
+ return RapidTableOutput(pred_html, cell_bboxes, logic_points, elapse)
+
+ def get_boxes_recs(
+ self, ocr_result: List[Union[List[List[float]], str, str]], h: int, w: int
+ ) -> Tuple[np.ndarray, Tuple[str, str]]:
+ dt_boxes, rec_res, scores = list(zip(*ocr_result))
+ rec_res = list(zip(rec_res, scores))
+
+ r_boxes = []
+ for box in dt_boxes:
+ box = np.array(box)
+ x_min = max(0, box[:, 0].min() - 1)
+ x_max = min(w, box[:, 0].max() + 1)
+ y_min = max(0, box[:, 1].min() - 1)
+ y_max = min(h, box[:, 1].max() + 1)
+ box = [x_min, y_min, x_max, y_max]
+ r_boxes.append(box)
+ dt_boxes = np.array(r_boxes)
+ return dt_boxes, rec_res
+
+ def adapt_slanet_plus(self, img: np.ndarray, cell_bboxes: np.ndarray) -> np.ndarray:
+ h, w = img.shape[:2]
+ resized = 488
+ ratio = min(resized / h, resized / w)
+ w_ratio = resized / (w * ratio)
+ h_ratio = resized / (h * ratio)
+ cell_bboxes[:, 0::2] *= w_ratio
+ cell_bboxes[:, 1::2] *= h_ratio
+ return cell_bboxes
+
+ @staticmethod
+ def get_model_path(
+ model_type: str, model_path: Union[str, Path, None]
+ ) -> Union[str, Dict[str, str]]:
+ if model_path is not None:
+ return model_path
+
+ model_url = KEY_TO_MODEL_URL.get(model_type, None)
+ if isinstance(model_url, str):
+ model_path = DownloadModel.download(model_url)
+ return model_path
+
+ if isinstance(model_url, dict):
+ model_paths = {}
+ for k, url in model_url.items():
+ model_paths[k] = DownloadModel.download(
+ url, save_model_name=f"{model_type}_{Path(url).name}"
+ )
+ return model_paths
+
+ raise ValueError(f"Model URL: {type(model_url)} is not between str and dict.")
+
+
+def parse_args(arg_list: Optional[List[str]] = None):
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "-v",
+ "--vis",
+ action="store_true",
+ default=False,
+ help="Wheter to visualize the layout results.",
+ )
+ parser.add_argument(
+ "-img", "--img_path", type=str, required=True, help="Path to image for layout."
+ )
+ parser.add_argument(
+ "-m",
+ "--model_type",
+ type=str,
+ default=ModelType.SLANETPLUS.value,
+ choices=list(KEY_TO_MODEL_URL),
+ )
+ args = parser.parse_args(arg_list)
+ return args
+
+
+try:
+ ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+except ModuleNotFoundError as exc:
+ raise ModuleNotFoundError(
+ "Please install the rapidocr by pip install rapidocr"
+ ) from exc
+
+input_args = RapidTableInput(model_type=ModelType.SLANETPLUS.value)
+table_engine = RapidTable(input_args)
+
+def table2md_pipeline(img):
+ rapid_ocr_output = ocr_engine(img)
+ ocr_result = list(
+ zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+ )
+ table_results = table_engine(img, ocr_result)
+ html_content = table_results.pred_html
+ md_content = md(html_content)
+ return md_content
+
+
+# def main(arg_list: Optional[List[str]] = None):
+# args = parse_args(arg_list)
+
+# try:
+# ocr_engine = importlib.import_module("rapidocr").RapidOCR()
+# except ModuleNotFoundError as exc:
+# raise ModuleNotFoundError(
+# "Please install the rapidocr by pip install rapidocr"
+# ) from exc
+
+# input_args = RapidTableInput(model_type=args.model_type)
+# table_engine = RapidTable(input_args)
+
+# img = cv2.imread(args.img_path)
+
+# rapid_ocr_output = ocr_engine(img)
+# ocr_result = list(
+# zip(rapid_ocr_output.boxes, rapid_ocr_output.txts, rapid_ocr_output.scores)
+# )
+# table_results = table_engine(img, ocr_result)
+# print(table_results.pred_html)
+
+# viser = VisTable()
+# if args.vis:
+# img_path = Path(args.img_path)
+
+# save_dir = img_path.resolve().parent
+# save_html_path = save_dir / f"{Path(img_path).stem}.html"
+# save_drawed_path = save_dir / f"vis_{Path(img_path).name}"
+# viser(img_path, table_results, save_html_path, save_drawed_path)
+
+
+if __name__ == "__main__":
+ res = table2md_pipeline(cv2.imread('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/11.jpg'))
+ print('*' * 50)
+ print(res)
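Note: table2md_pipeline() chains RapidOCR for cell text, the table engine for structure, and markdownify for the final conversion. A usage sketch (the image path is a placeholder):

```python
# Convert a single table crop to Markdown (placeholder input path).
import cv2

from helper.content_recognition.rapid_table_pipeline.main import table2md_pipeline

img = cv2.imread("table_crop.jpg")
if img is None:
    raise FileNotFoundError("table_crop.jpg not found")

print(table2md_pipeline(img))
```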
diff --git a/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx b/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx
new file mode 100644
index 0000000..d8b58fc
Binary files /dev/null and b/helper/content_recognition/rapid_table_pipeline/models/ch_ppstructure_mobile_v2_SLANet.onnx differ
diff --git a/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx b/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx
new file mode 100644
index 0000000..d263823
Binary files /dev/null and b/helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx differ
diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py
new file mode 100644
index 0000000..9bff7d7
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/__init__.py
@@ -0,0 +1,4 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from .matcher import TableMatch
diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py
new file mode 100644
index 0000000..f579a12
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/matcher.py
@@ -0,0 +1,199 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+import numpy as np
+
+from .utils import compute_iou, distance
+
+
+class TableMatch:
+ def __init__(self, filter_ocr_result=True, use_master=False):
+ self.filter_ocr_result = filter_ocr_result
+ self.use_master = use_master
+
+ def __call__(self, pred_structures, cell_bboxes, dt_boxes, rec_res):
+ if self.filter_ocr_result:
+ dt_boxes, rec_res = self._filter_ocr_result(cell_bboxes, dt_boxes, rec_res)
+ matched_index = self.match_result(dt_boxes, cell_bboxes)
+ pred_html, pred = self.get_pred_html(pred_structures, matched_index, rec_res)
+ return pred_html
+
+ def match_result(self, dt_boxes, cell_bboxes, min_iou=0.1**8):
+ matched = {}
+ for i, gt_box in enumerate(dt_boxes):
+ distances = []
+ for j, pred_box in enumerate(cell_bboxes):
+ if len(pred_box) == 8:
+ pred_box = [
+ np.min(pred_box[0::2]),
+ np.min(pred_box[1::2]),
+ np.max(pred_box[0::2]),
+ np.max(pred_box[1::2]),
+ ]
+ distances.append(
+ (distance(gt_box, pred_box), 1.0 - compute_iou(gt_box, pred_box))
+ ) # compute iou and l1 distance
+ sorted_distances = distances.copy()
+ # select det box by iou and l1 distance
+ sorted_distances = sorted(
+ sorted_distances, key=lambda item: (item[1], item[0])
+ )
+ # must > min_iou
+ if sorted_distances[0][1] >= 1 - min_iou:
+ continue
+
+ if distances.index(sorted_distances[0]) not in matched:
+ matched[distances.index(sorted_distances[0])] = [i]
+ else:
+ matched[distances.index(sorted_distances[0])].append(i)
+ return matched
+
+    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
+        end_html = []
+        td_index = 0
+        for tag in pred_structures:
+            if "</td>" not in tag:
+                end_html.append(tag)
+                continue
+
+            if "<td></td>" == tag:
+                end_html.extend("<td>")
+
+            if td_index in matched_index.keys():
+                b_with = False
+                if (
+                    "<b>" in ocr_contents[matched_index[td_index][0]]
+                    and len(matched_index[td_index]) > 1
+                ):
+                    b_with = True
+                    end_html.extend("<b>")
+
+                for i, td_index_index in enumerate(matched_index[td_index]):
+                    content = ocr_contents[td_index_index][0]
+                    if len(matched_index[td_index]) > 1:
+                        if len(content) == 0:
+                            continue
+
+                        if content[0] == " ":
+                            content = content[1:]
+
+                        if "<b>" in content:
+                            content = content[3:]
+
+                        if "</b>" in content:
+                            content = content[:-4]
+
+                        if len(content) == 0:
+                            continue
+
+                        if i != len(matched_index[td_index]) - 1 and " " != content[-1]:
+                            content += " "
+
+                    end_html.extend(content)
+
+                if b_with:
+                    end_html.extend("</b>")
+
+            if "<td></td>" == tag:
+                end_html.append("</td>")
+            else:
+                end_html.append(tag)
+
+            td_index += 1
+
+        # Filter out bare section tags
+        filter_elements = ["<thead>", "</thead>", "<tbody>", "</tbody>"]
+        end_html = [v for v in end_html if v not in filter_elements]
+        return "".join(end_html), end_html
+
+    def decode_logic_points(self, pred_structures):
+        logic_points = []
+        current_row = 0
+        current_col = 0
+        max_rows = 0
+        max_cols = 0
+        occupied_cells = {}  # records cells that are already occupied
+
+        def is_occupied(row, col):
+            return (row, col) in occupied_cells
+
+        def mark_occupied(row, col, rowspan, colspan):
+            for r in range(row, row + rowspan):
+                for c in range(col, col + colspan):
+                    occupied_cells[(r, c)] = True
+
+        i = 0
+        while i < len(pred_structures):
+            token = pred_structures[i]
+
+            if token == "<tr>":
+                current_col = 0  # reset the current column whenever a <tr> is seen
+            elif token == "</tr>":
+                current_row += 1  # the row has ended, move to the next row
+            elif token.startswith("<td"):
+                colspan = 1
+                rowspan = 1
+                j = i
+                if token != "<td></td>":
+                    j += 1
+                    # extract the colspan and rowspan attributes
+                    while j < len(pred_structures) and not pred_structures[
+                        j
+                    ].startswith(">"):
+                        if "colspan=" in pred_structures[j]:
+                            colspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        elif "rowspan=" in pred_structures[j]:
+                            rowspan = int(pred_structures[j].split("=")[1].strip("\"'"))
+                        j += 1
+
+                    # skip the attribute tokens that were just processed
+                    i = j
+
+                # find the next unoccupied column
+                while is_occupied(current_row, current_col):
+                    current_col += 1
+
+                # compute the logical coordinates
+                r_start = current_row
+                r_end = current_row + rowspan - 1
+                col_start = current_col
+                col_end = current_col + colspan - 1
+
+                # record the logical coordinates
+                logic_points.append([r_start, r_end, col_start, col_end])
+
+                # mark the occupied cells
+                mark_occupied(r_start, col_start, rowspan, colspan)
+
+                # advance the current column
+                current_col += colspan
+
+                # update the maximum row and column counts
+                max_rows = max(max_rows, r_end + 1)
+                max_cols = max(max_cols, col_end + 1)
+
+            i += 1
+
+        return logic_points
+
+ def _filter_ocr_result(self, cell_bboxes, dt_boxes, rec_res):
+ y1 = cell_bboxes[:, 1::2].min()
+ new_dt_boxes = []
+ new_rec_res = []
+
+ for box, rec in zip(dt_boxes, rec_res):
+ if np.max(box[1::2]) < y1:
+ continue
+ new_dt_boxes.append(box)
+ new_rec_res.append(rec)
+ return new_dt_boxes, new_rec_res
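Note: match_result() assigns each OCR box to the structure cell with the largest IoU, breaking ties by the L1 distance between box corners, and drops boxes whose best IoU is essentially zero. A tiny synthetic check of that rule (numbers invented; importing the package also pulls in its heavier OCR/table dependencies):

```python
# Two cells, two OCR boxes; each OCR box should land in its own cell.
import numpy as np

from helper.content_recognition.rapid_table_pipeline.table_matcher import TableMatch

cell_bboxes = np.array([[0, 0, 50, 20], [60, 0, 110, 20]], dtype=float)
dt_boxes = np.array([[2, 2, 48, 18], [62, 1, 108, 19]], dtype=float)

matcher = TableMatch()
print(matcher.match_result(dt_boxes, cell_bboxes))  # expected: {0: [0], 1: [1]}
```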
diff --git a/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py b/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py
new file mode 100644
index 0000000..57a613c
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_matcher/utils.py
@@ -0,0 +1,249 @@
+# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import copy
+import re
+
+
+def deal_isolate_span(thead_part):
+    """
+    Deal with isolate span cases in this function.
+    It causes by wrong prediction in structure recognition model.
+    eg. predict <td> to <td rowspan="2"></td>.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out isolate span tokens.
+    isolate_pattern = (
+        '<td></td> rowspan="(\d)+" colspan="(\d)+"></b></td>|'
+        '<td></td> colspan="(\d)+" rowspan="(\d)+"></b></td>|'
+        '<td></td> rowspan="(\d)+"></b></td>|'
+        '<td></td> colspan="(\d)+"></b></td>'
+    )
+    isolate_iter = re.finditer(isolate_pattern, thead_part)
+    isolate_list = [i.group() for i in isolate_iter]
+
+    # 2. find out span number, by step 1 results.
+    span_pattern = (
+        ' rowspan="(\d)+" colspan="(\d)+"|'
+        ' colspan="(\d)+" rowspan="(\d)+"|'
+        ' rowspan="(\d)+"|'
+        ' colspan="(\d)+"'
+    )
+    corrected_list = []
+    for isolate_item in isolate_list:
+        span_part = re.search(span_pattern, isolate_item)
+        spanStr_in_isolateItem = span_part.group()
+        # 3. merge the span number into the span token format string.
+        if spanStr_in_isolateItem is not None:
+            corrected_item = f"<td{spanStr_in_isolateItem}></td>"
+            corrected_list.append(corrected_item)
+        else:
+            corrected_list.append(None)
+
+    # 4. replace original isolated token.
+    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
+        if corrected_item is not None:
+            thead_part = thead_part.replace(isolate_item, corrected_item)
+        else:
+            pass
+    return thead_part
+
+
+def deal_duplicate_bb(thead_part):
+    """
+    Deal duplicate <b> or </b> after replace.
+    Keep one <b></b> in a <td></td> token.
+    :param thead_part:
+    :return:
+    """
+    # 1. find out <td></td> in <thead></thead>.
+    td_pattern = (
+        '<td rowspan="(\d)+" colspan="(\d)+">(.+?)</td>|'
+        '<td colspan="(\d)+" rowspan="(\d)+">(.+?)</td>|'
+        '<td rowspan="(\d)+">(.+?)</td>|'
+        '<td colspan="(\d)+">(.+?)</td>|'
+        "<td>(.*?)</td>"
+    )
+    td_iter = re.finditer(td_pattern, thead_part)
+    td_list = [t.group() for t in td_iter]
+
+    # 2. is multiply <b></b> in <td></td> or not?
+    new_td_list = []
+    for td_item in td_list:
+        if td_item.count("<b>") > 1 or td_item.count("</b>") > 1:
+            # multiply <b></b> in <td></td> case.
+            # 1. remove all <b></b>
+            td_item = td_item.replace("<b>", "").replace("</b>", "")
+            # 2. replace <td> -> <td><b>, </td> -> </b></td>.
+            td_item = td_item.replace("<td>", "<td><b>").replace("</td>", "</b></td>")
+            new_td_list.append(td_item)
+        else:
+            new_td_list.append(td_item)
+
+    # 3. replace original thead part.
+    for td_item, new_td_item in zip(td_list, new_td_list):
+        thead_part = thead_part.replace(td_item, new_td_item)
+    return thead_part
+
+
+def deal_bb(result_token):
+    """
+    In our opinion, <b></b> always occurs in <thead></thead> text's context.
+    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
+    :param result_token:
+    :return:
+    """
+    # find out <thead></thead> parts.
+    thead_pattern = "<thead>(.*?)</thead>"
+    if re.search(thead_pattern, result_token) is None:
+        return result_token
+    thead_part = re.search(thead_pattern, result_token).group()
+    origin_thead_part = copy.deepcopy(thead_part)
+
+    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
+    span_pattern = '<td rowspan="(\d)+" colspan="(\d)+">|<td colspan="(\d)+" rowspan="(\d)+">|<td rowspan="(\d)+">|<td colspan="(\d)+">'
+    span_iter = re.finditer(span_pattern, thead_part)
+    span_list = [s.group() for s in span_iter]
+    has_span_in_head = True if len(span_list) > 0 else False
+
+    if not has_span_in_head:
+        # <thead></thead> not include "rowspan" or "colspan" branch 1.
+        # 1. replace <td> to <td><b>, and </td> to </b></td>
+        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, and </b></b> to </b>
+        thead_part = (
+            thead_part.replace("<td>", "<td><b>")
+            .replace("</td>", "</b></td>")
+            .replace("<b><b>", "<b>")
+            .replace("</b></b>", "</b>")
+        )
+    else:
+        # <thead></thead> include "rowspan" or "colspan" branch 2.
+        # Firstly, we deal rowspan or colspan cases.
+        # 1. replace > to ><b>
+        # 2. replace </td> to </b></td>
+        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
+        #    so we replace <b><b> to <b>, </b><b> to </b>
+
+        # Secondly, deal ordinary cases like branch 1
+
+        # replace ">" to "><b>"
+        replaced_span_list = []
+        for sp in span_list:
+            replaced_span_list.append(sp.replace(">", "><b>"))
+        for sp, rsp in zip(span_list, replaced_span_list):
+            thead_part = thead_part.replace(sp, rsp)
+
+        # replace "</td>" to "</b></td>"
+        thead_part = thead_part.replace("</td>", "</b></td>")
+
+        # remove duplicated <b> by re.sub
+        mb_pattern = "(<b>)+"
+        single_b_string = "<b>"
+        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
+
+        mgb_pattern = "(</b>)+"
+        single_gb_string = "</b>"
+        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
+
+        # ordinary cases like branch 1
+        thead_part = thead_part.replace("<td>", "<td><b>").replace("<b><b>", "<b>")
+
+    # convert <td><b></b></td> back to <td></td>, empty cell has no <b></b>.
+    # but space cell (<td> </td>) is suitable for <td><b> </b></td>
+    thead_part = thead_part.replace("<td><b></b></td>", "<td></td>")
+    # deal with duplicated <b></b>
+    thead_part = deal_duplicate_bb(thead_part)
+    # deal with isolate span tokens, which causes by wrong predict by structure prediction.
+    # eg.PMC5994107_011_00.png
+    thead_part = deal_isolate_span(thead_part)
+    # replace original result with new thead part.
+    result_token = result_token.replace(origin_thead_part, thead_part)
+    return result_token
+
+
+def deal_eb_token(master_token):
+    """
+    post process with <eb></eb>, <eb1></eb1>, ...
+    emptyBboxTokenDict = {
+        "[]": '<eb></eb>',
+        "[' ']": '<eb1></eb1>',
+        "['<b>', ' ', '</b>']": '<eb2></eb2>',
+        "['\\u2028', '\\u2028']": '<eb3></eb3>',
+        "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
+        "['<b>', '</b>']": '<eb5></eb5>',
+        "['<i>', ' ', '</i>']": '<eb6></eb6>',
+        "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
+        "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
+        "['<i>', '</i>']": '<eb9></eb9>',
+        "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
+    }
+    :param master_token:
+    :return:
+    """
+    master_token = master_token.replace("<eb></eb>", "<td></td>")
+    master_token = master_token.replace("<eb1></eb1>", "<td> </td>")
+    master_token = master_token.replace("<eb2></eb2>", "<td><b> </b></td>")
+    master_token = master_token.replace("<eb3></eb3>", "<td>\u2028\u2028</td>")
+    master_token = master_token.replace("<eb4></eb4>", "<td><sup> </sup></td>")
+    master_token = master_token.replace("<eb5></eb5>", "<td><b></b></td>")
+    master_token = master_token.replace("<eb6></eb6>", "<td><i> </i></td>")
+    master_token = master_token.replace("<eb7></eb7>", "<td><b><i></i></b></td>")
+    master_token = master_token.replace("<eb8></eb8>", "<td><b><i> </i></b></td>")
+    master_token = master_token.replace("<eb9></eb9>", "<td><i></i></td>")
+    master_token = master_token.replace(
+        "<eb10></eb10>", "<td><b> \u2028 \u2028 </b></td>"
+    )
+    return master_token
+
+
+def distance(box_1, box_2):
+ x1, y1, x2, y2 = box_1
+ x3, y3, x4, y4 = box_2
+ dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
+ dis_2 = abs(x3 - x1) + abs(y3 - y1)
+ dis_3 = abs(x4 - x2) + abs(y4 - y2)
+ return dis + min(dis_2, dis_3)
+
+
+def compute_iou(rec1, rec2):
+ """
+ computing IoU
+ :param rec1: (y0, x0, y1, x1), which reflects
+ (top, left, bottom, right)
+ :param rec2: (y0, x0, y1, x1)
+ :return: scala value of IoU
+ """
+ # computing area of each rectangles
+ S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
+ S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
+
+ # computing the sum_area
+ sum_area = S_rec1 + S_rec2
+
+ # find the each edge of intersect rectangle
+ left_line = max(rec1[1], rec2[1])
+ right_line = min(rec1[3], rec2[3])
+ top_line = max(rec1[0], rec2[0])
+ bottom_line = min(rec1[2], rec2[2])
+
+ # judge if there is an intersect
+ if left_line >= right_line or top_line >= bottom_line:
+ return 0.0
+
+ intersect = (right_line - left_line) * (bottom_line - top_line)
+ return (intersect / (sum_area - intersect)) * 1.0
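Note: compute_iou() documents its boxes as (y0, x0, y1, x1), and distance() combines a corner-wise L1 term with the closer of the two corner offsets. A standalone numeric check (values invented; the import goes through the package `__init__`, which also loads the OCR/table models):

```python
from helper.content_recognition.rapid_table_pipeline.table_matcher.utils import (
    compute_iou,
    distance,
)

a = (0, 0, 10, 10)
b = (5, 5, 15, 15)
print(compute_iou(a, b))  # 25 / (100 + 100 - 25) ≈ 0.1429
print(distance(a, b))     # 20 + min(10, 10) = 30
```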
diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py b/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py
new file mode 100644
index 0000000..3e638d9
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_structure/__init__.py
@@ -0,0 +1,15 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from .table_structure import TableStructurer
+from .table_structure_unitable import TableStructureUnitable
diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py
new file mode 100644
index 0000000..9152603
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure.py
@@ -0,0 +1,58 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import time
+from typing import Any, Dict
+
+import numpy as np
+
+from .utils import OrtInferSession, TableLabelDecode, TablePreprocess
+
+
+class TableStructurer:
+ def __init__(self, config: Dict[str, Any]):
+ self.preprocess_op = TablePreprocess()
+
+ self.session = OrtInferSession(config)
+
+ self.character = self.session.get_metadata()
+ self.postprocess_op = TableLabelDecode(self.character)
+
+ def __call__(self, img):
+ starttime = time.time()
+ data = {"image": img}
+ data = self.preprocess_op(data)
+ img = data[0]
+ if img is None:
+ return None, 0
+ img = np.expand_dims(img, axis=0)
+ img = img.copy()
+
+ outputs = self.session([img])
+
+ preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
+
+ shape_list = np.expand_dims(data[-1], axis=0)
+ post_result = self.postprocess_op(preds, [shape_list])
+
+ bbox_list = post_result["bbox_batch_list"][0]
+
+ structure_str_list = post_result["structure_batch_list"][0]
+ structure_str_list = structure_str_list[0]
+ structure_str_list = (
+     ["<html>", "<body>", "<table>"]
+     + structure_str_list
+     + ["</table>", "</body>", "</html>"]
+ )
+ elapse = time.time() - starttime
+ return structure_str_list, bbox_list, elapse
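Note: TableStructurer wraps the SLANet ONNX model behind OrtInferSession: preprocess, run, decode structure tokens plus cell boxes. A rough usage sketch against the bundled slanet-plus.onnx (placeholder image path; used directly like this, the cell boxes still need the adapt_slanet_plus rescaling that RapidTable applies):

```python
# Run the structure model directly on a table crop (placeholder image path).
import cv2

from helper.content_recognition.rapid_table_pipeline.table_structure import TableStructurer

config = {
    "model_path": "helper/content_recognition/rapid_table_pipeline/models/slanet-plus.onnx",
    "use_cuda": False,
}
structurer = TableStructurer(config)

img = cv2.imread("table_crop.jpg")
tokens, cell_bboxes, elapse = structurer(img)
print(len(tokens), "structure tokens,", len(cell_bboxes), "cells in", f"{elapse:.3f}s")
```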
diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py
new file mode 100644
index 0000000..2b98006
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_structure/table_structure_unitable.py
@@ -0,0 +1,229 @@
+import re
+import time
+
+import cv2
+import numpy as np
+import torch
+from PIL import Image
+from tokenizers import Tokenizer
+from torchvision import transforms
+
+from .unitable_modules import Encoder, GPTFastDecoder
+
+IMG_SIZE = 448
+EOS_TOKEN = "<eos>"
+BBOX_TOKENS = [f"bbox-{i}" for i in range(IMG_SIZE + 1)]
+
+HTML_BBOX_HTML_TOKENS = [
+    "<td></td>",
+    "<td>[",
+    "]</td>",
+    "<td",
+    ">[",
+    "></td>",
+    "<tr>",
+    "</tr>",
+    "<tbody>",
+    "</tbody>",
+    "<thead>",
+    "</thead>",
+ ' rowspan="2"',
+ ' rowspan="3"',
+ ' rowspan="4"',
+ ' rowspan="5"',
+ ' rowspan="6"',
+ ' rowspan="7"',
+ ' rowspan="8"',
+ ' rowspan="9"',
+ ' rowspan="10"',
+ ' rowspan="11"',
+ ' rowspan="12"',
+ ' rowspan="13"',
+ ' rowspan="14"',
+ ' rowspan="15"',
+ ' rowspan="16"',
+ ' rowspan="17"',
+ ' rowspan="18"',
+ ' rowspan="19"',
+ ' colspan="2"',
+ ' colspan="3"',
+ ' colspan="4"',
+ ' colspan="5"',
+ ' colspan="6"',
+ ' colspan="7"',
+ ' colspan="8"',
+ ' colspan="9"',
+ ' colspan="10"',
+ ' colspan="11"',
+ ' colspan="12"',
+ ' colspan="13"',
+ ' colspan="14"',
+ ' colspan="15"',
+ ' colspan="16"',
+ ' colspan="17"',
+ ' colspan="18"',
+ ' colspan="19"',
+ ' colspan="25"',
+]
+
+VALID_HTML_BBOX_TOKENS = [EOS_TOKEN] + HTML_BBOX_HTML_TOKENS + BBOX_TOKENS
+TASK_TOKENS = [
+ "[table]",
+ "[html]",
+ "[cell]",
+ "[bbox]",
+ "[cell+bbox]",
+ "[html+bbox]",
+]
+
+
+class TableStructureUnitable:
+ def __init__(self, config):
+ # encoder_path: str, decoder_path: str, vocab_path: str, device: str
+ vocab_path = config["model_path"]["vocab"]
+ encoder_path = config["model_path"]["encoder"]
+ decoder_path = config["model_path"]["decoder"]
+ device = config.get("device", "cuda:0") if config["use_cuda"] else "cpu"
+
+ self.vocab = Tokenizer.from_file(vocab_path)
+ self.token_white_list = [
+ self.vocab.token_to_id(i) for i in VALID_HTML_BBOX_TOKENS
+ ]
+ self.bbox_token_ids = set(self.vocab.token_to_id(i) for i in BBOX_TOKENS)
+ self.bbox_close_html_token = self.vocab.token_to_id("]")
+ self.prefix_token_id = self.vocab.token_to_id("[html+bbox]")
+ self.eos_id = self.vocab.token_to_id(EOS_TOKEN)
+ self.max_seq_len = 1024
+ self.device = device
+ self.img_size = IMG_SIZE
+
+ # init encoder
+ encoder_state_dict = torch.load(encoder_path, map_location=device)
+ self.encoder = Encoder()
+ self.encoder.load_state_dict(encoder_state_dict)
+ self.encoder.eval().to(device)
+
+ # init decoder
+ decoder_state_dict = torch.load(decoder_path, map_location=device)
+ self.decoder = GPTFastDecoder()
+ self.decoder.load_state_dict(decoder_state_dict)
+ self.decoder.eval().to(device)
+
+ # define img transform
+ self.transform = transforms.Compose(
+ [
+ transforms.Resize((448, 448)),
+ transforms.ToTensor(),
+ transforms.Normalize(
+ mean=[0.86597056, 0.88463002, 0.87491087],
+ std=[0.20686628, 0.18201602, 0.18485524],
+ ),
+ ]
+ )
+
+ @torch.inference_mode()
+ def __call__(self, image: np.ndarray):
+ start_time = time.time()
+ ori_h, ori_w = image.shape[:2]
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ image = Image.fromarray(image)
+ image = self.transform(image).unsqueeze(0).to(self.device)
+ self.decoder.setup_caches(
+ max_batch_size=1,
+ max_seq_length=self.max_seq_len,
+ dtype=image.dtype,
+ device=self.device,
+ )
+ context = (
+ torch.tensor([self.prefix_token_id], dtype=torch.int32)
+ .repeat(1, 1)
+ .to(self.device)
+ )
+ eos_id_tensor = torch.tensor(self.eos_id, dtype=torch.int32).to(self.device)
+ memory = self.encoder(image)
+ context = self.loop_decode(context, eos_id_tensor, memory)
+ bboxes, html_tokens = self.decode_tokens(context)
+ bboxes = bboxes.astype(np.float32)
+
+ # rescale boxes
+ scale_h = ori_h / self.img_size
+ scale_w = ori_w / self.img_size
+ bboxes[:, 0::2] *= scale_w  # scale the x coordinates
+ bboxes[:, 1::2] *= scale_h  # scale the y coordinates
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, ori_w - 1)
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, ori_h - 1)
+ structure_str_list = (
+     ["<html>", "<body>", "<table>"]
+     + html_tokens
+     + ["</table>", "</body>", "</html>"]
+ )
+ return structure_str_list, bboxes, time.time() - start_time
+
+    def decode_tokens(self, context):
+        pred_html = context[0]
+        pred_html = pred_html.detach().cpu().numpy()
+        pred_html = self.vocab.decode(pred_html, skip_special_tokens=False)
+        seq = pred_html.split("<eos>")[0]
+        token_black_list = ["<eos>", "<pad>", *TASK_TOKENS]
+        for i in token_black_list:
+            seq = seq.replace(i, "")
+
+        tr_pattern = re.compile(r"<tr>(.*?)</tr>", re.DOTALL)
+        td_pattern = re.compile(r"<td(.*?)>(.*?)</td>", re.DOTALL)
+        bbox_pattern = re.compile(r"\[ bbox-(\d+) bbox-(\d+) bbox-(\d+) bbox-(\d+) \]")
+
+        decoded_list = []
+        bbox_coords = []
+
+        # Find all <tr> tags
+        for tr_match in tr_pattern.finditer(pred_html):
+            tr_content = tr_match.group(1)
+            decoded_list.append("<tr>")
+
+            # Find all <td> tags
+            for td_match in td_pattern.finditer(tr_content):
+                td_attrs = td_match.group(1).strip()
+                td_content = td_match.group(2).strip()
+                if td_attrs:
+                    decoded_list.append("<td")
+                    attrs_list = td_attrs.split()
+                    for attr in attrs_list:
+                        decoded_list.append(" " + attr)
+                    decoded_list.append(">")
+                    decoded_list.append("</td>")
+                else:
+                    decoded_list.append("<td></td>")
+
+                # Look up the bbox coordinates
+                bbox_match = bbox_pattern.search(td_content)
+                if bbox_match:
+                    xmin, ymin, xmax, ymax = map(int, bbox_match.groups())
+                    # Convert to corner points, clockwise from the top-left to the bottom-left
+                    coords = np.array([xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax])
+                    bbox_coords.append(coords)
+                else:
+                    # Pad with a placeholder bbox so the downstream pipeline stays uniform
+                    bbox_coords.append(np.array([0, 0, 0, 0, 0, 0, 0, 0]))
+            decoded_list.append("</tr>")
+
+        bbox_coords_array = np.array(bbox_coords)
+        return bbox_coords_array, decoded_list
+
+ def loop_decode(self, context, eos_id_tensor, memory):
+ box_token_count = 0
+ for _ in range(self.max_seq_len):
+ eos_flag = (context == eos_id_tensor).any(dim=1)
+ if torch.all(eos_flag):
+ break
+
+ next_tokens = self.decoder(memory, context)
+ if next_tokens[0] in self.bbox_token_ids:
+ box_token_count += 1
+ if box_token_count > 4:
+ next_tokens = torch.tensor(
+ [self.bbox_close_html_token], dtype=torch.int32
+ )
+ box_token_count = 0
+ context = torch.cat([context, next_tokens], dim=1)
+ return context
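Note: TableStructureUnitable expects `model_path` to be the dict form from KEY_TO_MODEL_URL (encoder/decoder/vocab). A sketch of the config it is constructed with (paths are placeholders; in practice RapidTable.get_model_path downloads and fills them in):

```python
# Shape of the config consumed by TableStructureUnitable (placeholder paths).
unitable_config = {
    "model_type": "unitable",
    "model_path": {
        "encoder": "models/unitable/encoder.pth",
        "decoder": "models/unitable/decoder.pth",
        "vocab": "models/unitable/vocab.json",
    },
    "use_cuda": False,
    "device": "cpu",
}
# TableStructureUnitable(unitable_config) would load the checkpoints and tokenizer;
# instantiation is skipped here because the weight files are not part of this diff.
```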
diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py b/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py
new file mode 100644
index 0000000..5b8dac3
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_structure/unitable_modules.py
@@ -0,0 +1,911 @@
+from dataclasses import dataclass
+from functools import partial
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import functional as F
+from torch.nn.modules.transformer import _get_activation_fn
+
+TOKEN_WHITE_LIST = [
+ 1,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ 23,
+ 24,
+ 25,
+ 26,
+ 27,
+ 28,
+ 29,
+ 30,
+ 31,
+ 32,
+ 33,
+ 34,
+ 35,
+ 36,
+ 37,
+ 38,
+ 39,
+ 40,
+ 41,
+ 42,
+ 43,
+ 44,
+ 45,
+ 46,
+ 47,
+ 48,
+ 49,
+ 50,
+ 51,
+ 52,
+ 53,
+ 54,
+ 55,
+ 56,
+ 57,
+ 58,
+ 59,
+ 60,
+ 61,
+ 62,
+ 63,
+ 64,
+ 65,
+ 66,
+ 67,
+ 68,
+ 69,
+ 70,
+ 71,
+ 72,
+ 73,
+ 74,
+ 75,
+ 76,
+ 77,
+ 78,
+ 79,
+ 80,
+ 81,
+ 82,
+ 83,
+ 84,
+ 85,
+ 86,
+ 87,
+ 88,
+ 89,
+ 90,
+ 91,
+ 92,
+ 93,
+ 94,
+ 95,
+ 96,
+ 97,
+ 98,
+ 99,
+ 100,
+ 101,
+ 102,
+ 103,
+ 104,
+ 105,
+ 106,
+ 107,
+ 108,
+ 109,
+ 110,
+ 111,
+ 112,
+ 113,
+ 114,
+ 115,
+ 116,
+ 117,
+ 118,
+ 119,
+ 120,
+ 121,
+ 122,
+ 123,
+ 124,
+ 125,
+ 126,
+ 127,
+ 128,
+ 129,
+ 130,
+ 131,
+ 132,
+ 133,
+ 134,
+ 135,
+ 136,
+ 137,
+ 138,
+ 139,
+ 140,
+ 141,
+ 142,
+ 143,
+ 144,
+ 145,
+ 146,
+ 147,
+ 148,
+ 149,
+ 150,
+ 151,
+ 152,
+ 153,
+ 154,
+ 155,
+ 156,
+ 157,
+ 158,
+ 159,
+ 160,
+ 161,
+ 162,
+ 163,
+ 164,
+ 165,
+ 166,
+ 167,
+ 168,
+ 169,
+ 170,
+ 171,
+ 172,
+ 173,
+ 174,
+ 175,
+ 176,
+ 177,
+ 178,
+ 179,
+ 180,
+ 181,
+ 182,
+ 183,
+ 184,
+ 185,
+ 186,
+ 187,
+ 188,
+ 189,
+ 190,
+ 191,
+ 192,
+ 193,
+ 194,
+ 195,
+ 196,
+ 197,
+ 198,
+ 199,
+ 200,
+ 201,
+ 202,
+ 203,
+ 204,
+ 205,
+ 206,
+ 207,
+ 208,
+ 209,
+ 210,
+ 211,
+ 212,
+ 213,
+ 214,
+ 215,
+ 216,
+ 217,
+ 218,
+ 219,
+ 220,
+ 221,
+ 222,
+ 223,
+ 224,
+ 225,
+ 226,
+ 227,
+ 228,
+ 229,
+ 230,
+ 231,
+ 232,
+ 233,
+ 234,
+ 235,
+ 236,
+ 237,
+ 238,
+ 239,
+ 240,
+ 241,
+ 242,
+ 243,
+ 244,
+ 245,
+ 246,
+ 247,
+ 248,
+ 249,
+ 250,
+ 251,
+ 252,
+ 253,
+ 254,
+ 255,
+ 256,
+ 257,
+ 258,
+ 259,
+ 260,
+ 261,
+ 262,
+ 263,
+ 264,
+ 265,
+ 266,
+ 267,
+ 268,
+ 269,
+ 270,
+ 271,
+ 272,
+ 273,
+ 274,
+ 275,
+ 276,
+ 277,
+ 278,
+ 279,
+ 280,
+ 281,
+ 282,
+ 283,
+ 284,
+ 285,
+ 286,
+ 287,
+ 288,
+ 289,
+ 290,
+ 291,
+ 292,
+ 293,
+ 294,
+ 295,
+ 296,
+ 297,
+ 298,
+ 299,
+ 300,
+ 301,
+ 302,
+ 303,
+ 304,
+ 305,
+ 306,
+ 307,
+ 308,
+ 309,
+ 310,
+ 311,
+ 312,
+ 313,
+ 314,
+ 315,
+ 316,
+ 317,
+ 318,
+ 319,
+ 320,
+ 321,
+ 322,
+ 323,
+ 324,
+ 325,
+ 326,
+ 327,
+ 328,
+ 329,
+ 330,
+ 331,
+ 332,
+ 333,
+ 334,
+ 335,
+ 336,
+ 337,
+ 338,
+ 339,
+ 340,
+ 341,
+ 342,
+ 343,
+ 344,
+ 345,
+ 346,
+ 347,
+ 348,
+ 349,
+ 350,
+ 351,
+ 352,
+ 353,
+ 354,
+ 355,
+ 356,
+ 357,
+ 358,
+ 359,
+ 360,
+ 361,
+ 362,
+ 363,
+ 364,
+ 365,
+ 366,
+ 367,
+ 368,
+ 369,
+ 370,
+ 371,
+ 372,
+ 373,
+ 374,
+ 375,
+ 376,
+ 377,
+ 378,
+ 379,
+ 380,
+ 381,
+ 382,
+ 383,
+ 384,
+ 385,
+ 386,
+ 387,
+ 388,
+ 389,
+ 390,
+ 391,
+ 392,
+ 393,
+ 394,
+ 395,
+ 396,
+ 397,
+ 398,
+ 399,
+ 400,
+ 401,
+ 402,
+ 403,
+ 404,
+ 405,
+ 406,
+ 407,
+ 408,
+ 409,
+ 410,
+ 411,
+ 412,
+ 413,
+ 414,
+ 415,
+ 416,
+ 417,
+ 418,
+ 419,
+ 420,
+ 421,
+ 422,
+ 423,
+ 424,
+ 425,
+ 426,
+ 427,
+ 428,
+ 429,
+ 430,
+ 431,
+ 432,
+ 433,
+ 434,
+ 435,
+ 436,
+ 437,
+ 438,
+ 439,
+ 440,
+ 441,
+ 442,
+ 443,
+ 444,
+ 445,
+ 446,
+ 447,
+ 448,
+ 449,
+ 450,
+ 451,
+ 452,
+ 453,
+ 454,
+ 455,
+ 456,
+ 457,
+ 458,
+ 459,
+ 460,
+ 461,
+ 462,
+ 463,
+ 464,
+ 465,
+ 466,
+ 467,
+ 468,
+ 469,
+ 470,
+ 471,
+ 472,
+ 473,
+ 474,
+ 475,
+ 476,
+ 477,
+ 478,
+ 479,
+ 480,
+ 481,
+ 482,
+ 483,
+ 484,
+ 485,
+ 486,
+ 487,
+ 488,
+ 489,
+ 490,
+ 491,
+ 492,
+ 493,
+ 494,
+ 495,
+ 496,
+ 497,
+ 498,
+ 499,
+ 500,
+ 501,
+ 502,
+ 503,
+ 504,
+ 505,
+ 506,
+ 507,
+ 508,
+ 509,
+]
+
+
+class ImgLinearBackbone(nn.Module):
+ def __init__(
+ self,
+ d_model: int,
+ patch_size: int,
+ in_chan: int = 3,
+ ) -> None:
+ super().__init__()
+
+ self.conv_proj = nn.Conv2d(
+ in_chan,
+ out_channels=d_model,
+ kernel_size=patch_size,
+ stride=patch_size,
+ )
+ self.d_model = d_model
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.conv_proj(x)
+ x = x.flatten(start_dim=-2).transpose(1, 2)
+ return x
+
+
+class Encoder(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.patch_size = 16
+ self.d_model = 768
+ self.dropout = 0
+ self.activation = "gelu"
+ self.norm_first = True
+ self.ff_ratio = 4
+ self.nhead = 12
+ self.max_seq_len = 1024
+ self.n_encoder_layer = 12
+ encoder_layer = nn.TransformerEncoderLayer(
+ self.d_model,
+ nhead=self.nhead,
+ dim_feedforward=self.ff_ratio * self.d_model,
+ dropout=self.dropout,
+ activation=self.activation,
+ batch_first=True,
+ norm_first=self.norm_first,
+ )
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
+ self.norm = norm_layer(self.d_model)
+ self.backbone = ImgLinearBackbone(
+ d_model=self.d_model, patch_size=self.patch_size
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
+ self.encoder = nn.TransformerEncoder(
+ encoder_layer, num_layers=self.n_encoder_layer, enable_nested_tensor=False
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ src_feature = self.backbone(x)
+ src_feature = self.pos_embed(src_feature)
+ memory = self.encoder(src_feature)
+ memory = self.norm(memory)
+ return memory
+
+
+class PositionEmbedding(nn.Module):
+ def __init__(self, max_seq_len: int, d_model: int, dropout: float) -> None:
+ super().__init__()
+ self.embedding = nn.Embedding(max_seq_len, d_model)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+ # assume x is batch first
+ if input_pos is None:
+ _pos = torch.arange(x.shape[1], device=x.device)
+ else:
+ _pos = input_pos
+ out = self.embedding(_pos)
+ return self.dropout(out + x)
+
+
+class TokenEmbedding(nn.Module):
+ def __init__(
+ self,
+ vocab_size: int,
+ d_model: int,
+ padding_idx: int,
+ ) -> None:
+ super().__init__()
+ assert vocab_size > 0
+ self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+
+ def forward(self, x: Tensor) -> Tensor:
+ return self.embedding(x)
+
+
+def find_multiple(n: int, k: int) -> int:
+ if n % k == 0:
+ return n
+ return n + k - (n % k)
+
+
+@dataclass
+class ModelArgs:
+ n_layer: int = 4
+ n_head: int = 12
+ dim: int = 768
+ intermediate_size: int = None
+ head_dim: int = 64
+ activation: str = "gelu"
+ norm_first: bool = True
+
+ def __post_init__(self):
+ if self.intermediate_size is None:
+ hidden_dim = 4 * self.dim
+ n_hidden = int(2 * hidden_dim / 3)
+ self.intermediate_size = find_multiple(n_hidden, 256)
+ self.head_dim = self.dim // self.n_head
+
+
+class KVCache(nn.Module):
+ def __init__(
+ self,
+ max_batch_size,
+ max_seq_length,
+ n_heads,
+ head_dim,
+ dtype=torch.bfloat16,
+ device="cpu",
+ ):
+ super().__init__()
+ cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
+ self.register_buffer(
+ "k_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
+ self.register_buffer(
+ "v_cache",
+ torch.zeros(cache_shape, dtype=dtype, device=device),
+ persistent=False,
+ )
+
+ def update(self, input_pos, k_val, v_val):
+ # input_pos: [S], k_val: [B, H, S, D]
+ # assert input_pos.shape[0] == k_val.shape[2]
+
+ bs = k_val.shape[0]
+ k_out = self.k_cache
+ v_out = self.v_cache
+ k_out[:bs, :, input_pos] = k_val
+ v_out[:bs, :, input_pos] = v_val
+
+ return k_out[:bs], v_out[:bs]
+
+
+class GPTFastDecoder(nn.Module):
+ def __init__(self) -> None:
+ super().__init__()
+
+ self.vocab_size = 960
+ self.padding_idx = 2
+ self.prefix_token_id = 11
+ self.eos_id = 1
+ self.max_seq_len = 1024
+ self.dropout = 0
+ self.d_model = 768
+ self.nhead = 12
+ self.activation = "gelu"
+ self.norm_first = True
+ self.n_decoder_layer = 4
+ config = ModelArgs(
+ n_layer=self.n_decoder_layer,
+ n_head=self.nhead,
+ dim=self.d_model,
+ intermediate_size=self.d_model * 4,
+ activation=self.activation,
+ norm_first=self.norm_first,
+ )
+ self.config = config
+ self.layers = nn.ModuleList(
+ TransformerBlock(config) for _ in range(config.n_layer)
+ )
+ self.token_embed = TokenEmbedding(
+ vocab_size=self.vocab_size,
+ d_model=self.d_model,
+ padding_idx=self.padding_idx,
+ )
+ self.pos_embed = PositionEmbedding(
+ max_seq_len=self.max_seq_len, d_model=self.d_model, dropout=self.dropout
+ )
+ self.generator = nn.Linear(self.d_model, self.vocab_size)
+ self.token_white_list = TOKEN_WHITE_LIST
+ self.mask_cache: Optional[Tensor] = None
+ self.max_batch_size = -1
+ self.max_seq_length = -1
+
+ def setup_caches(self, max_batch_size, max_seq_length, dtype, device):
+ for b in self.layers:
+ b.multihead_attn.k_cache = None
+ b.multihead_attn.v_cache = None
+
+ if (
+ self.max_seq_length >= max_seq_length
+ and self.max_batch_size >= max_batch_size
+ ):
+ return
+ head_dim = self.config.dim // self.config.n_head
+ max_seq_length = find_multiple(max_seq_length, 8)
+ self.max_seq_length = max_seq_length
+ self.max_batch_size = max_batch_size
+
+ for b in self.layers:
+ b.self_attn.kv_cache = KVCache(
+ max_batch_size,
+ max_seq_length,
+ self.config.n_head,
+ head_dim,
+ dtype,
+ device,
+ )
+ b.multihead_attn.k_cache = None
+ b.multihead_attn.v_cache = None
+
+ self.causal_mask = torch.tril(
+ torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)
+ ).to(device)
+
+ def forward(self, memory: Tensor, tgt: Tensor) -> Tensor:
+ input_pos = torch.tensor([tgt.shape[1] - 1], device=tgt.device, dtype=torch.int)
+ tgt = tgt[:, -1:]
+ tgt_feature = self.pos_embed(self.token_embed(tgt), input_pos=input_pos)
+ # tgt = self.decoder(tgt_feature, memory, input_pos)
+ with torch.backends.cuda.sdp_kernel(
+ enable_flash=False, enable_mem_efficient=False, enable_math=True
+ ):
+ logits = tgt_feature
+ tgt_mask = self.causal_mask[None, None, input_pos]
+ for i, layer in enumerate(self.layers):
+ logits = layer(logits, memory, input_pos=input_pos, tgt_mask=tgt_mask)
+ # return output
+ logits = self.generator(logits)[:, -1, :]
+ total = set([i for i in range(logits.shape[-1])])
+ black_list = list(total.difference(set(self.token_white_list)))
+ logits[..., black_list] = -1e9
+ probs = F.softmax(logits, dim=-1)
+ _, next_tokens = probs.topk(1)
+ return next_tokens
+
+
+class TransformerBlock(nn.Module):
+ def __init__(self, config: ModelArgs) -> None:
+ super().__init__()
+ self.self_attn = Attention(config)
+ self.multihead_attn = CrossAttention(config)
+
+ layer_norm_eps = 1e-5
+
+ d_model = config.dim
+ dim_feedforward = config.intermediate_size
+
+ self.linear1 = nn.Linear(d_model, dim_feedforward)
+ self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+ self.norm_first = config.norm_first
+ self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+ self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+
+ self.activation = _get_activation_fn(config.activation)
+
+ def forward(
+ self,
+ tgt: Tensor,
+ memory: Tensor,
+ tgt_mask: Tensor,
+ input_pos: Tensor,
+ ) -> Tensor:
+ if self.norm_first:
+ x = tgt
+ x = x + self.self_attn(self.norm1(x), tgt_mask, input_pos)
+ x = x + self.multihead_attn(self.norm2(x), memory)
+ x = x + self._ff_block(self.norm3(x))
+ else:
+ x = tgt
+ x = self.norm1(x + self.self_attn(x, tgt_mask, input_pos))
+ x = self.norm2(x + self.multihead_attn(x, memory))
+ x = self.norm3(x + self._ff_block(x))
+ return x
+
+ def _ff_block(self, x: Tensor) -> Tensor:
+ x = self.linear2(self.activation(self.linear1(x)))
+ return x
+
+
+class Attention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ # key, query, value projections for all heads, but in a batch
+ self.wqkv = nn.Linear(config.dim, 3 * config.dim)
+ self.wo = nn.Linear(config.dim, config.dim)
+
+ self.kv_cache: Optional[KVCache] = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+ self.dim = config.dim
+
+ def forward(
+ self,
+ x: Tensor,
+ mask: Tensor,
+ input_pos: Optional[Tensor] = None,
+ ) -> Tensor:
+ bsz, seqlen, _ = x.shape
+
+ kv_size = self.n_head * self.head_dim
+ q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)
+
+ q = q.view(bsz, seqlen, self.n_head, self.head_dim)
+ k = k.view(bsz, seqlen, self.n_head, self.head_dim)
+ v = v.view(bsz, seqlen, self.n_head, self.head_dim)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ if self.kv_cache is not None:
+ k, v = self.kv_cache.update(input_pos, k, v)
+
+ y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+ y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
+
+ y = self.wo(y)
+ return y
+
+
+class CrossAttention(nn.Module):
+ def __init__(self, config: ModelArgs):
+ super().__init__()
+ assert config.dim % config.n_head == 0
+
+ self.query = nn.Linear(config.dim, config.dim)
+ self.key = nn.Linear(config.dim, config.dim)
+ self.value = nn.Linear(config.dim, config.dim)
+ self.out = nn.Linear(config.dim, config.dim)
+
+ self.k_cache = None
+ self.v_cache = None
+
+ self.n_head = config.n_head
+ self.head_dim = config.head_dim
+
+ def get_kv(self, xa: torch.Tensor):
+ if self.k_cache is not None and self.v_cache is not None:
+ return self.k_cache, self.v_cache
+
+ k = self.key(xa)
+ v = self.value(xa)
+
+ # Reshape for correct format
+ batch_size, source_seq_len, _ = k.shape
+ k = k.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+ v = v.view(batch_size, source_seq_len, self.n_head, self.head_dim)
+
+ if self.k_cache is None:
+ self.k_cache = k
+ if self.v_cache is None:
+ self.v_cache = v
+
+ return k, v
+
+ def forward(
+ self,
+ x: Tensor,
+ xa: Tensor,
+ ):
+ q = self.query(x)
+ batch_size, target_seq_len, _ = q.shape
+ q = q.view(batch_size, target_seq_len, self.n_head, self.head_dim)
+ k, v = self.get_kv(xa)
+
+ q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))
+
+ wv = F.scaled_dot_product_attention(
+ query=q,
+ key=k,
+ value=v,
+ is_causal=False,
+ )
+ wv = wv.transpose(1, 2).reshape(
+ batch_size,
+ target_seq_len,
+ self.n_head * self.head_dim,
+ )
+
+ return self.out(wv)
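Note: KVCache pre-allocates (batch, heads, seq, head_dim) buffers and writes new keys/values at the decoded position, which is what lets GPTFastDecoder feed only the latest token at each step. A small shape check (the import goes through the package, which also loads the heavier table/OCR modules):

```python
import torch

from helper.content_recognition.rapid_table_pipeline.table_structure.unitable_modules import KVCache

cache = KVCache(max_batch_size=1, max_seq_length=8, n_heads=2, head_dim=4, dtype=torch.float32)
k_new = torch.randn(1, 2, 1, 4)   # one new decoding step: [B, H, S=1, D]
v_new = torch.randn(1, 2, 1, 4)
k_all, v_all = cache.update(torch.tensor([3]), k_new, v_new)  # write position 3
print(k_all.shape, v_all.shape)   # torch.Size([1, 2, 8, 4]) twice
```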
diff --git a/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py b/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py
new file mode 100644
index 0000000..484a63d
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/table_structure/utils.py
@@ -0,0 +1,544 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import os
+import platform
+import traceback
+from enum import Enum
+from pathlib import Path
+from typing import Any, Dict, List, Tuple, Union
+
+import cv2
+import numpy as np
+from onnxruntime import (
+ GraphOptimizationLevel,
+ InferenceSession,
+ SessionOptions,
+ get_available_providers,
+ get_device,
+)
+
+from rapid_table.utils import Logger
+
+
+class EP(Enum):
+ CPU_EP = "CPUExecutionProvider"
+ CUDA_EP = "CUDAExecutionProvider"
+ DIRECTML_EP = "DmlExecutionProvider"
+
+
+class OrtInferSession:
+ def __init__(self, config: Dict[str, Any]):
+ self.logger = Logger(logger_name=__name__).get_log()
+
+ model_path = config.get("model_path", None)
+ self._verify_model(model_path)
+
+ self.cfg_use_cuda = config.get("use_cuda", None)
+ self.cfg_use_dml = config.get("use_dml", None)
+
+ self.had_providers: List[str] = get_available_providers()
+ EP_list = self._get_ep_list()
+
+ sess_opt = self._init_sess_opts(config)
+ self.session = InferenceSession(
+ model_path,
+ sess_options=sess_opt,
+ providers=EP_list,
+ )
+ self._verify_providers()
+
+ @staticmethod
+ def _init_sess_opts(config: Dict[str, Any]) -> SessionOptions:
+ sess_opt = SessionOptions()
+ sess_opt.log_severity_level = 4
+ sess_opt.enable_cpu_mem_arena = False
+ sess_opt.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL
+
+ cpu_nums = os.cpu_count()
+ intra_op_num_threads = config.get("intra_op_num_threads", -1)
+ if intra_op_num_threads != -1 and 1 <= intra_op_num_threads <= cpu_nums:
+ sess_opt.intra_op_num_threads = intra_op_num_threads
+
+ inter_op_num_threads = config.get("inter_op_num_threads", -1)
+ if inter_op_num_threads != -1 and 1 <= inter_op_num_threads <= cpu_nums:
+ sess_opt.inter_op_num_threads = inter_op_num_threads
+
+ return sess_opt
+
+ def get_metadata(self, key: str = "character") -> list:
+ meta_dict = self.session.get_modelmeta().custom_metadata_map
+ content_list = meta_dict[key].splitlines()
+ return content_list
+
+ def _get_ep_list(self) -> List[Tuple[str, Dict[str, Any]]]:
+ cpu_provider_opts = {
+ "arena_extend_strategy": "kSameAsRequested",
+ }
+ EP_list = [(EP.CPU_EP.value, cpu_provider_opts)]
+
+ cuda_provider_opts = {
+ "device_id": 0,
+ "arena_extend_strategy": "kNextPowerOfTwo",
+ "cudnn_conv_algo_search": "EXHAUSTIVE",
+ "do_copy_in_default_stream": True,
+ }
+ self.use_cuda = self._check_cuda()
+ if self.use_cuda:
+ EP_list.insert(0, (EP.CUDA_EP.value, cuda_provider_opts))
+
+ self.use_directml = self._check_dml()
+ if self.use_directml:
+ self.logger.info(
+ "Windows 10 or above detected, try to use DirectML as primary provider"
+ )
+ directml_options = (
+ cuda_provider_opts if self.use_cuda else cpu_provider_opts
+ )
+ EP_list.insert(0, (EP.DIRECTML_EP.value, directml_options))
+ return EP_list
+
+ def _check_cuda(self) -> bool:
+ if not self.cfg_use_cuda:
+ return False
+
+ cur_device = get_device()
+ if cur_device == "GPU" and EP.CUDA_EP.value in self.had_providers:
+ return True
+
+ self.logger.warning(
+ "%s is not in available providers (%s). Use %s inference by default.",
+ EP.CUDA_EP.value,
+ self.had_providers,
+ self.had_providers[0],
+ )
+ self.logger.info("!!!Recommend to use rapidocr_paddle for inference on GPU.")
+ self.logger.info(
+ "(For reference only) If you want to use GPU acceleration, you must do:"
+ )
+ self.logger.info(
+ "First, uninstall all onnxruntime pakcages in current environment."
+ )
+ self.logger.info(
+ "Second, install onnxruntime-gpu by `pip install onnxruntime-gpu`."
+ )
+ self.logger.info(
+ "\tNote the onnxruntime-gpu version must match your cuda and cudnn version."
+ )
+ self.logger.info(
+ "\tYou can refer this link: https://onnxruntime.ai/docs/execution-providers/CUDA-EP.html"
+ )
+ self.logger.info(
+ "Third, ensure %s is in available providers list. e.g. ['CUDAExecutionProvider', 'CPUExecutionProvider']",
+ EP.CUDA_EP.value,
+ )
+ return False
+
+ def _check_dml(self) -> bool:
+ if not self.cfg_use_dml:
+ return False
+
+ cur_os = platform.system()
+ if cur_os != "Windows":
+ self.logger.warning(
+ "DirectML is only supported in Windows OS. The current OS is %s. Use %s inference by default.",
+ cur_os,
+ self.had_providers[0],
+ )
+ return False
+
+ cur_window_version = int(platform.release().split(".")[0])
+ if cur_window_version < 10:
+ self.logger.warning(
+ "DirectML is only supported in Windows 10 and above OS. The current Windows version is %s. Use %s inference by default.",
+ cur_window_version,
+ self.had_providers[0],
+ )
+ return False
+
+ if EP.DIRECTML_EP.value in self.had_providers:
+ return True
+
+ self.logger.warning(
+ "%s is not in available providers (%s). Use %s inference by default.",
+ EP.DIRECTML_EP.value,
+ self.had_providers,
+ self.had_providers[0],
+ )
+ self.logger.info("If you want to use DirectML acceleration, you must do:")
+ self.logger.info(
+ "First, uninstall all onnxruntime pakcages in current environment."
+ )
+ self.logger.info(
+ "Second, install onnxruntime-directml by `pip install onnxruntime-directml`"
+ )
+ self.logger.info(
+ "Third, ensure %s is in available providers list. e.g. ['DmlExecutionProvider', 'CPUExecutionProvider']",
+ EP.DIRECTML_EP.value,
+ )
+ return False
+
+ def _verify_providers(self):
+ session_providers = self.session.get_providers()
+ first_provider = session_providers[0]
+
+ if self.use_cuda and first_provider != EP.CUDA_EP.value:
+ self.logger.warning(
+ "%s is not avaiable for current env, the inference part is automatically shifted to be executed under %s.",
+ EP.CUDA_EP.value,
+ first_provider,
+ )
+
+ if self.use_directml and first_provider != EP.DIRECTML_EP.value:
+ self.logger.warning(
+ "%s is not available for current env, the inference part is automatically shifted to be executed under %s.",
+ EP.DIRECTML_EP.value,
+ first_provider,
+ )
+
+ def __call__(self, input_content: List[np.ndarray]) -> np.ndarray:
+ input_dict = dict(zip(self.get_input_names(), input_content))
+ try:
+ return self.session.run(None, input_dict)
+ except Exception as e:
+ error_info = traceback.format_exc()
+ raise ONNXRuntimeError(error_info) from e
+
+ def get_input_names(self) -> List[str]:
+ return [v.name for v in self.session.get_inputs()]
+
+ def get_output_names(self) -> List[str]:
+ return [v.name for v in self.session.get_outputs()]
+
+ def get_character_list(self, key: str = "character") -> List[str]:
+ meta_dict = self.session.get_modelmeta().custom_metadata_map
+ return meta_dict[key].splitlines()
+
+ def have_key(self, key: str = "character") -> bool:
+ meta_dict = self.session.get_modelmeta().custom_metadata_map
+ if key in meta_dict.keys():
+ return True
+ return False
+
+ @staticmethod
+ def _verify_model(model_path: Union[str, Path, None]):
+ if model_path is None:
+ raise ValueError("model_path is None!")
+
+ model_path = Path(model_path)
+ if not model_path.exists():
+            raise FileNotFoundError(f"{model_path} does not exist.")
+
+ if not model_path.is_file():
+ raise FileExistsError(f"{model_path} is not a file.")
+
+
+class ONNXRuntimeError(Exception):
+ pass
+
+
+class TableLabelDecode:
+ def __init__(self, dict_character, merge_no_span_structure=True, **kwargs):
+ if merge_no_span_structure:
+ if " | " not in dict_character:
+ dict_character.append(" | ")
+ if "" in dict_character:
+ dict_character.remove(" | ")
+
+ dict_character = self.add_special_char(dict_character)
+ self.dict = {}
+ for i, char in enumerate(dict_character):
+ self.dict[char] = i
+ self.character = dict_character
+ self.td_token = [" | ", " | | "]
+
+ def __call__(self, preds, batch=None):
+ structure_probs = preds["structure_probs"]
+ bbox_preds = preds["loc_preds"]
+ shape_list = batch[-1]
+ result = self.decode(structure_probs, bbox_preds, shape_list)
+ if len(batch) == 1: # only contains shape
+ return result
+
+ label_decode_result = self.decode_label(batch)
+ return result, label_decode_result
+
+ def decode(self, structure_probs, bbox_preds, shape_list):
+ """convert text-label into text-index."""
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_idx = structure_probs.argmax(axis=2)
+ structure_probs = structure_probs.max(axis=2)
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ score_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+
+ if char_idx in ignored_tokens:
+ continue
+
+ text = self.character[char_idx]
+ if text in self.td_token:
+ bbox = bbox_preds[batch_idx, idx]
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+ structure_list.append(text)
+ score_list.append(structure_probs[batch_idx, idx])
+ structure_batch_list.append([structure_list, np.mean(score_list)])
+ bbox_batch_list.append(np.array(bbox_list))
+ result = {
+ "bbox_batch_list": bbox_batch_list,
+ "structure_batch_list": structure_batch_list,
+ }
+ return result
+
+ def decode_label(self, batch):
+ """convert text-label into text-index."""
+ structure_idx = batch[1]
+ gt_bbox_list = batch[2]
+ shape_list = batch[-1]
+ ignored_tokens = self.get_ignored_tokens()
+ end_idx = self.dict[self.end_str]
+
+ structure_batch_list = []
+ bbox_batch_list = []
+ batch_size = len(structure_idx)
+ for batch_idx in range(batch_size):
+ structure_list = []
+ bbox_list = []
+ for idx in range(len(structure_idx[batch_idx])):
+ char_idx = int(structure_idx[batch_idx][idx])
+ if idx > 0 and char_idx == end_idx:
+ break
+
+ if char_idx in ignored_tokens:
+ continue
+
+ structure_list.append(self.character[char_idx])
+
+ bbox = gt_bbox_list[batch_idx][idx]
+ if bbox.sum() != 0:
+ bbox = self._bbox_decode(bbox, shape_list[batch_idx])
+ bbox_list.append(bbox)
+
+ structure_batch_list.append(structure_list)
+ bbox_batch_list.append(bbox_list)
+ result = {
+ "bbox_batch_list": bbox_batch_list,
+ "structure_batch_list": structure_batch_list,
+ }
+ return result
+
+ def _bbox_decode(self, bbox, shape):
+ h, w = shape[:2]
+ bbox[0::2] *= w
+ bbox[1::2] *= h
+ return bbox
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ return np.array(self.dict[self.beg_str])
+
+ if beg_or_end == "end":
+ return np.array(self.dict[self.end_str])
+
+        raise TypeError(f"unsupported type {beg_or_end} in get_beg_end_flag_idx")
+
+ def add_special_char(self, dict_character):
+ self.beg_str = "sos"
+ self.end_str = "eos"
+ dict_character = [self.beg_str] + dict_character + [self.end_str]
+ return dict_character
+
+
+class TablePreprocess:
+ def __init__(self):
+ self.table_max_len = 488
+ self.build_pre_process_list()
+ self.ops = self.create_operators()
+
+ def __call__(self, data):
+ """transform"""
+ if self.ops is None:
+ self.ops = []
+
+ for op in self.ops:
+ data = op(data)
+ if data is None:
+ return None
+ return data
+
+ def create_operators(
+ self,
+ ):
+ """
+ create operators based on the config
+
+ Args:
+ params(list): a dict list, used to create some operators
+ """
+ assert isinstance(
+ self.pre_process_list, list
+ ), "operator config should be a list"
+ ops = []
+ for operator in self.pre_process_list:
+ assert (
+ isinstance(operator, dict) and len(operator) == 1
+ ), "yaml format error"
+ op_name = list(operator)[0]
+ param = {} if operator[op_name] is None else operator[op_name]
+ op = eval(op_name)(**param)
+ ops.append(op)
+ return ops
+
+ def build_pre_process_list(self):
+ resize_op = {
+ "ResizeTableImage": {
+ "max_len": self.table_max_len,
+ }
+ }
+ pad_op = {
+ "PaddingTableImage": {"size": [self.table_max_len, self.table_max_len]}
+ }
+ normalize_op = {
+ "NormalizeImage": {
+ "std": [0.229, 0.224, 0.225],
+ "mean": [0.485, 0.456, 0.406],
+ "scale": "1./255.",
+ "order": "hwc",
+ }
+ }
+ to_chw_op = {"ToCHWImage": None}
+ keep_keys_op = {"KeepKeys": {"keep_keys": ["image", "shape"]}}
+ self.pre_process_list = [
+ resize_op,
+ normalize_op,
+ pad_op,
+ to_chw_op,
+ keep_keys_op,
+ ]
+
+
+class ResizeTableImage:
+ def __init__(self, max_len, resize_bboxes=False, infer_mode=False):
+ super(ResizeTableImage, self).__init__()
+ self.max_len = max_len
+ self.resize_bboxes = resize_bboxes
+ self.infer_mode = infer_mode
+
+ def __call__(self, data):
+ img = data["image"]
+ height, width = img.shape[0:2]
+ ratio = self.max_len / (max(height, width) * 1.0)
+ resize_h = int(height * ratio)
+ resize_w = int(width * ratio)
+ resize_img = cv2.resize(img, (resize_w, resize_h))
+ if self.resize_bboxes and not self.infer_mode:
+ data["bboxes"] = data["bboxes"] * ratio
+ data["image"] = resize_img
+ data["src_img"] = img
+ data["shape"] = np.array([height, width, ratio, ratio])
+ data["max_len"] = self.max_len
+ return data
+
+
+class PaddingTableImage:
+ def __init__(self, size, **kwargs):
+ super(PaddingTableImage, self).__init__()
+ self.size = size
+
+ def __call__(self, data):
+ img = data["image"]
+ pad_h, pad_w = self.size
+ padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
+ height, width = img.shape[0:2]
+ padding_img[0:height, 0:width, :] = img.copy()
+ data["image"] = padding_img
+ shape = data["shape"].tolist()
+ shape.extend([pad_h, pad_w])
+ data["shape"] = np.array(shape)
+ return data
+
+
+class NormalizeImage:
+ """normalize image such as substract mean, divide std"""
+
+ def __init__(self, scale=None, mean=None, std=None, order="chw", **kwargs):
+ if isinstance(scale, str):
+ scale = eval(scale)
+ self.scale = np.float32(scale if scale is not None else 1.0 / 255.0)
+ mean = mean if mean is not None else [0.485, 0.456, 0.406]
+ std = std if std is not None else [0.229, 0.224, 0.225]
+
+ shape = (3, 1, 1) if order == "chw" else (1, 1, 3)
+ self.mean = np.array(mean).reshape(shape).astype("float32")
+ self.std = np.array(std).reshape(shape).astype("float32")
+
+ def __call__(self, data):
+ img = np.array(data["image"])
+ assert isinstance(img, np.ndarray), "invalid input 'img' in NormalizeImage"
+ data["image"] = (img.astype("float32") * self.scale - self.mean) / self.std
+ return data
+
+
+class ToCHWImage:
+ """convert hwc image to chw image"""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def __call__(self, data):
+ img = np.array(data["image"])
+ data["image"] = img.transpose((2, 0, 1))
+ return data
+
+
+class KeepKeys:
+ def __init__(self, keep_keys, **kwargs):
+ self.keep_keys = keep_keys
+
+ def __call__(self, data):
+ data_list = []
+ for key in self.keep_keys:
+ data_list.append(data[key])
+ return data_list
+
+
+def trans_char_ocr_res(ocr_res):
+ word_result = []
+ for res in ocr_res:
+ score = res[2]
+ for word_box, word in zip(res[3], res[4]):
+ word_res = []
+ word_res.append(word_box)
+ word_res.append(word)
+ word_res.append(score)
+ word_result.append(word_res)
+ return word_result
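A rough sketch of how these pieces fit together for table structure recognition; the import path follows this diff's layout, while the ONNX model path, image path and output ordering are placeholders/assumptions:

```
import cv2
import numpy as np

from helper.content_recognition.rapid_table_pipeline.table_structure.utils import (
    OrtInferSession, TableLabelDecode, TablePreprocess,
)

session = OrtInferSession({"model_path": "models/table_structure.onnx", "use_cuda": False})
preprocess = TablePreprocess()
decoder = TableLabelDecode(session.get_character_list())   # structure vocab from ONNX metadata

img = cv2.imread("table.jpg")                  # placeholder image
image, shape = preprocess({"image": img})      # resize -> normalize -> pad -> CHW -> keep keys
outputs = session([np.expand_dims(image, 0)])
# Output order assumed to be [loc_preds, structure_probs]; check the exported model.
preds = {"loc_preds": outputs[0], "structure_probs": outputs[1]}
result = decoder(preds, batch=[np.expand_dims(shape, 0)])
print(result["structure_batch_list"][0][0])    # predicted HTML structure tokens
```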
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/__init__.py b/helper/content_recognition/rapid_table_pipeline/utils/__init__.py
new file mode 100644
index 0000000..8754555
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/__init__.py
@@ -0,0 +1,8 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from .download_model import DownloadModel
+from .load_image import LoadImage
+from .logger import Logger
+from .utils import is_url
+from .vis import VisTable
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/download_model.py b/helper/content_recognition/rapid_table_pipeline/utils/download_model.py
new file mode 100644
index 0000000..7d35a88
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/download_model.py
@@ -0,0 +1,67 @@
+import io
+from pathlib import Path
+from typing import Optional, Union
+
+import requests
+from tqdm import tqdm
+
+from .logger import Logger
+
+PROJECT_DIR = Path(__file__).resolve().parent.parent
+DEFAULT_MODEL_DIR = PROJECT_DIR / "models"
+
+
+class DownloadModel:
+ logger = Logger(logger_name=__name__).get_log()
+
+ @classmethod
+ def download(
+ cls,
+ model_full_url: Union[str, Path],
+ save_dir: Union[str, Path, None] = None,
+ save_model_name: Optional[str] = None,
+ ) -> str:
+        if save_dir is None:
+            save_dir = DEFAULT_MODEL_DIR
+
+        save_dir = Path(save_dir)  # accept both str and Path, per the type hint
+        save_dir.mkdir(parents=True, exist_ok=True)
+
+ if save_model_name is None:
+ save_model_name = Path(model_full_url).name
+
+ save_file_path = save_dir / save_model_name
+ if save_file_path.exists():
+ cls.logger.info("%s already exists", save_file_path)
+ return str(save_file_path)
+
+ try:
+ cls.logger.info("Download %s to %s", model_full_url, save_dir)
+ file = cls.download_as_bytes_with_progress(model_full_url, save_model_name)
+ cls.save_file(save_file_path, file)
+ except Exception as exc:
+ raise DownloadModelError from exc
+ return str(save_file_path)
+
+ @staticmethod
+ def download_as_bytes_with_progress(
+ url: Union[str, Path], name: Optional[str] = None
+ ) -> bytes:
+ resp = requests.get(str(url), stream=True, allow_redirects=True, timeout=180)
+ total = int(resp.headers.get("content-length", 0))
+ bio = io.BytesIO()
+ with tqdm(
+ desc=name, total=total, unit="b", unit_scale=True, unit_divisor=1024
+ ) as pbar:
+ for chunk in resp.iter_content(chunk_size=65536):
+ pbar.update(len(chunk))
+ bio.write(chunk)
+ return bio.getvalue()
+
+ @staticmethod
+ def save_file(save_path: Union[str, Path], file: bytes):
+ with open(save_path, "wb") as f:
+ f.write(file)
+
+
+class DownloadModelError(Exception):
+ pass
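Downloading a model with this helper might look as follows; the URL is purely a placeholder, and files already present in the save directory are skipped:

```
from helper.content_recognition.rapid_table_pipeline.utils.download_model import DownloadModel

model_path = DownloadModel.download(
    "https://example.com/models/table_structure.onnx",   # placeholder URL
    save_model_name="table_structure.onnx",              # saved under <package>/models by default
)
print(model_path)
```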
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/load_image.py b/helper/content_recognition/rapid_table_pipeline/utils/load_image.py
new file mode 100644
index 0000000..86ab953
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/load_image.py
@@ -0,0 +1,131 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from io import BytesIO
+from pathlib import Path
+from typing import Any, Union
+
+import cv2
+import numpy as np
+import requests
+from PIL import Image, UnidentifiedImageError
+
+from .utils import is_url
+
+root_dir = Path(__file__).resolve().parent
+InputType = Union[str, np.ndarray, bytes, Path, Image.Image]
+
+
+class LoadImage:
+ def __init__(self):
+ pass
+
+ def __call__(self, img: InputType) -> np.ndarray:
+ if not isinstance(img, InputType.__args__):
+ raise LoadImageError(
+ f"The img type {type(img)} does not in {InputType.__args__}"
+ )
+
+ origin_img_type = type(img)
+ img = self.load_img(img)
+ img = self.convert_img(img, origin_img_type)
+ return img
+
+ def load_img(self, img: InputType) -> np.ndarray:
+ if isinstance(img, (str, Path)):
+ if is_url(img):
+ img = Image.open(requests.get(img, stream=True, timeout=60).raw)
+ else:
+ self.verify_exist(img)
+ img = Image.open(img)
+
+ try:
+ img = self.img_to_ndarray(img)
+ except UnidentifiedImageError as e:
+ raise LoadImageError(f"cannot identify image file {img}") from e
+ return img
+
+ if isinstance(img, bytes):
+ img = self.img_to_ndarray(Image.open(BytesIO(img)))
+ return img
+
+ if isinstance(img, np.ndarray):
+ return img
+
+ if isinstance(img, Image.Image):
+ return self.img_to_ndarray(img)
+
+ raise LoadImageError(f"{type(img)} is not supported!")
+
+ def img_to_ndarray(self, img: Image.Image) -> np.ndarray:
+ if img.mode == "1":
+ img = img.convert("L")
+ return np.array(img)
+ return np.array(img)
+
+ def convert_img(self, img: np.ndarray, origin_img_type: Any) -> np.ndarray:
+ if img.ndim == 2:
+ return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ if img.ndim == 3:
+ channel = img.shape[2]
+ if channel == 1:
+ return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
+
+ if channel == 2:
+ return self.cvt_two_to_three(img)
+
+ if channel == 3:
+ if issubclass(origin_img_type, (str, Path, bytes, Image.Image)):
+ return cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+ return img
+
+ if channel == 4:
+ return self.cvt_four_to_three(img)
+
+ raise LoadImageError(
+ f"The channel({channel}) of the img is not in [1, 2, 3, 4]"
+ )
+
+ raise LoadImageError(f"The ndim({img.ndim}) of the img is not in [2, 3]")
+
+ @staticmethod
+ def cvt_two_to_three(img: np.ndarray) -> np.ndarray:
+ """gray + alpha → BGR"""
+ img_gray = img[..., 0]
+ img_bgr = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2BGR)
+
+ img_alpha = img[..., 1]
+ not_a = cv2.bitwise_not(img_alpha)
+ not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+ new_img = cv2.bitwise_and(img_bgr, img_bgr, mask=img_alpha)
+ new_img = cv2.add(new_img, not_a)
+ return new_img
+
+ @staticmethod
+ def cvt_four_to_three(img: np.ndarray) -> np.ndarray:
+ """RGBA → BGR"""
+ r, g, b, a = cv2.split(img)
+ new_img = cv2.merge((b, g, r))
+
+ not_a = cv2.bitwise_not(a)
+ not_a = cv2.cvtColor(not_a, cv2.COLOR_GRAY2BGR)
+
+ new_img = cv2.bitwise_and(new_img, new_img, mask=a)
+
+ mean_color = np.mean(new_img)
+ if mean_color <= 0.0:
+ new_img = cv2.add(new_img, not_a)
+ else:
+ new_img = cv2.bitwise_not(new_img)
+ return new_img
+
+ @staticmethod
+ def verify_exist(file_path: Union[str, Path]):
+ if not Path(file_path).exists():
+ raise LoadImageError(f"{file_path} does not exist.")
+
+
+class LoadImageError(Exception):
+ pass
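LoadImage normalizes every supported input type (path, URL, bytes, PIL image, ndarray) into a 3-channel BGR ndarray. A quick sanity check, assuming the same import path as above:

```
import numpy as np

from helper.content_recognition.rapid_table_pipeline.utils.load_image import LoadImage

load_img = LoadImage()
gray = np.zeros((32, 32), dtype=np.uint8)   # single-channel input
bgr = load_img(gray)                        # expanded to BGR via cv2.COLOR_GRAY2BGR
print(bgr.shape)                            # (32, 32, 3)
```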
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/logger.py b/helper/content_recognition/rapid_table_pipeline/utils/logger.py
new file mode 100644
index 0000000..d3be7fe
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/logger.py
@@ -0,0 +1,37 @@
+# -*- encoding: utf-8 -*-
+# @Author: Jocker1212
+# @Contact: xinyijianggo@gmail.com
+import logging
+
+import colorlog
+
+
+class Logger:
+ def __init__(self, log_level=logging.DEBUG, logger_name=None):
+ self.logger = logging.getLogger(logger_name)
+ self.logger.setLevel(log_level)
+ self.logger.propagate = False
+
+ formatter = colorlog.ColoredFormatter(
+ "%(log_color)s[%(levelname)s] %(asctime)s [RapidTable] %(filename)s:%(lineno)d: %(message)s",
+ log_colors={
+ "DEBUG": "cyan",
+ "INFO": "green",
+ "WARNING": "yellow",
+ "ERROR": "red",
+ "CRITICAL": "red,bg_white",
+ },
+ )
+
+ if not self.logger.handlers:
+ console_handler = logging.StreamHandler()
+ console_handler.setFormatter(formatter)
+
+ for handler in self.logger.handlers:
+ self.logger.removeHandler(handler)
+
+ console_handler.setLevel(log_level)
+ self.logger.addHandler(console_handler)
+
+ def get_log(self):
+ return self.logger
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/utils.py b/helper/content_recognition/rapid_table_pipeline/utils/utils.py
new file mode 100644
index 0000000..a182929
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/utils.py
@@ -0,0 +1,12 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+from urllib.parse import urlparse
+
+
+def is_url(url: str) -> bool:
+ try:
+ result = urlparse(url)
+ return all([result.scheme, result.netloc])
+ except Exception:
+ return False
diff --git a/helper/content_recognition/rapid_table_pipeline/utils/vis.py b/helper/content_recognition/rapid_table_pipeline/utils/vis.py
new file mode 100644
index 0000000..88fc69b
--- /dev/null
+++ b/helper/content_recognition/rapid_table_pipeline/utils/vis.py
@@ -0,0 +1,145 @@
+# -*- encoding: utf-8 -*-
+# @Author: SWHL
+# @Contact: liekkaskono@163.com
+import os
+from pathlib import Path
+from typing import Optional, Union
+
+import cv2
+import numpy as np
+
+from .load_image import LoadImage
+
+
+class VisTable:
+ def __init__(self):
+ self.load_img = LoadImage()
+
+ def __call__(
+ self,
+ img_path: Union[str, Path],
+ table_results,
+ save_html_path: Optional[str] = None,
+ save_drawed_path: Optional[str] = None,
+ save_logic_path: Optional[str] = None,
+ ):
+ if save_html_path:
+ html_with_border = self.insert_border_style(table_results.pred_html)
+ self.save_html(save_html_path, html_with_border)
+
+ table_cell_bboxes = table_results.cell_bboxes
+ if table_cell_bboxes is None:
+ return None
+
+ img = self.load_img(img_path)
+
+ dims_bboxes = table_cell_bboxes.shape[1]
+ if dims_bboxes == 4:
+ drawed_img = self.draw_rectangle(img, table_cell_bboxes)
+ elif dims_bboxes == 8:
+ drawed_img = self.draw_polylines(img, table_cell_bboxes)
+ else:
+ raise ValueError("Shape of table bounding boxes is not between in 4 or 8.")
+
+ if save_drawed_path:
+ self.save_img(save_drawed_path, drawed_img)
+
+ if save_logic_path and table_results.logic_points:
+ polygons = [[box[0], box[1], box[4], box[5]] for box in table_cell_bboxes]
+ self.plot_rec_box_with_logic_info(
+ img, save_logic_path, table_results.logic_points, polygons
+ )
+ return drawed_img
+
+    def insert_border_style(self, table_html_str: str):
+        style_res = (
+            "<style>table, th, td {border: 1px solid black; border-collapse: collapse; "
+            "padding: 8px; text-align: center;}</style>"
+        )
+        prefix_table, suffix_table = table_html_str.split("<body>")
+        html_with_border = f"{prefix_table}<body>{style_res}{suffix_table}"
+        return html_with_border
+
+ def plot_rec_box_with_logic_info(
+ self, img: np.ndarray, output_path, logic_points, sorted_polygons
+ ):
+ """
+ :param img_path
+ :param output_path
+ :param logic_points: [row_start,row_end,col_start,col_end]
+ :param sorted_polygons: [xmin,ymin,xmax,ymax]
+ :return:
+ """
+        # Pad the image with a white margin on the right for the row/col labels
+ img = cv2.copyMakeBorder(
+ img, 0, 0, 0, 100, cv2.BORDER_CONSTANT, value=[255, 255, 255]
+ )
+        # Draw a rectangle for each polygon
+ for idx, polygon in enumerate(sorted_polygons):
+ x0, y0, x1, y1 = polygon[0], polygon[1], polygon[2], polygon[3]
+ x0 = round(x0)
+ y0 = round(y0)
+ x1 = round(x1)
+ y1 = round(y1)
+ cv2.rectangle(img, (x0, y0), (x1, y1), (0, 0, 255), 1)
+            # Use a larger font size and line width
+            font_scale = 0.9  # previously 0.5
+            thickness = 1  # previously 1
+ logic_point = logic_points[idx]
+ cv2.putText(
+ img,
+ f"row: {logic_point[0]}-{logic_point[1]}",
+ (x0 + 3, y0 + 8),
+ cv2.FONT_HERSHEY_PLAIN,
+ font_scale,
+ (0, 0, 255),
+ thickness,
+ )
+ cv2.putText(
+ img,
+ f"col: {logic_point[2]}-{logic_point[3]}",
+ (x0 + 3, y0 + 18),
+ cv2.FONT_HERSHEY_PLAIN,
+ font_scale,
+ (0, 0, 255),
+ thickness,
+ )
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
+        # Save the annotated image
+ self.save_img(output_path, img)
+
+ @staticmethod
+ def draw_rectangle(img: np.ndarray, boxes: np.ndarray) -> np.ndarray:
+ img_copy = img.copy()
+ for box in boxes.astype(int):
+ x1, y1, x2, y2 = box
+ cv2.rectangle(img_copy, (x1, y1), (x2, y2), (255, 0, 0), 2)
+ return img_copy
+
+ @staticmethod
+ def draw_polylines(img: np.ndarray, points) -> np.ndarray:
+ img_copy = img.copy()
+ for point in points.astype(int):
+ point = point.reshape(4, 2)
+ cv2.polylines(img_copy, [point.astype(int)], True, (255, 0, 0), 2)
+ return img_copy
+
+ @staticmethod
+ def save_img(save_path: Union[str, Path], img: np.ndarray):
+ cv2.imwrite(str(save_path), img)
+
+ @staticmethod
+ def save_html(save_path: Union[str, Path], html: str):
+ with open(save_path, "w", encoding="utf-8") as f:
+ f.write(html)
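A debugging sketch for VisTable; `table_results` is faked with a SimpleNamespace here, since in the real pipeline it comes from rapid_table and exposes pred_html / cell_bboxes / logic_points. File paths are placeholders and the output directory must exist:

```
from types import SimpleNamespace

import numpy as np

from helper.content_recognition.rapid_table_pipeline.utils.vis import VisTable

# Stand-in for a rapid_table result: one cell with an xyxy bbox.
table_results = SimpleNamespace(
    pred_html="<html><body><table><tr><td>1</td></tr></table></body></html>",
    cell_bboxes=np.array([[10.0, 10.0, 60.0, 40.0]]),
    logic_points=[[0, 0, 0, 0]],
)

vis = VisTable()
vis(
    "table.jpg",                                # placeholder input image
    table_results,
    save_html_path="output/table.html",         # HTML with the border style injected
    save_drawed_path="output/table_cells.jpg",  # image with cell rectangles drawn
)
```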
diff --git a/helper/content_recognition/test.py b/helper/content_recognition/test.py
new file mode 100644
index 0000000..1474bb8
--- /dev/null
+++ b/helper/content_recognition/test.py
@@ -0,0 +1,18 @@
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.data.read_api import read_local_images
+from markdownify import markdownify as md
+import re
+
+
+# proc
+## Create Dataset Instance
+input_file = "/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/207.jpg"
+
+ds = read_local_images(input_file)[0]
+
+x = ds.apply(doc_analyze, ocr=True)
+x = x.pipe_ocr_mode(None)
+html = x.get_markdown(None)
+content = md(html)
+content = re.sub(r'\\([#*_`])', r'\1', content)
+print(content)
diff --git a/helper/content_recognition/utils.py b/helper/content_recognition/utils.py
new file mode 100644
index 0000000..7e9cb03
--- /dev/null
+++ b/helper/content_recognition/utils.py
@@ -0,0 +1,213 @@
+import os
+import tempfile
+import cv2
+import numpy as np
+from paddleocr import PaddleOCR
+from marker.converters.table import TableConverter
+from marker.models import create_model_dict
+from marker.output import text_from_rendered
+from .rapid_table_pipeline.main import table2md_pipeline
+from magic_pdf.model.doc_analyze_by_custom_model import doc_analyze
+from magic_pdf.data.read_api import read_local_images
+from markdownify import markdownify as md
+import re
+
+
+def scanning_document_classify(image):
+    # Decide whether the image is a scanned (stamped) document
+
+    # Convert the image from BGR to HSV color space
+ hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+
+    # Define the HSV ranges for red
+ lower_red1 = np.array([0, 70, 50])
+ upper_red1 = np.array([10, 255, 255])
+ lower_red2 = np.array([170, 70, 50])
+ upper_red2 = np.array([180, 255, 255])
+
+    # Build two masks: one for low-hue red, one for high-hue red
+ mask1 = cv2.inRange(hsv, lower_red1, upper_red1)
+ mask2 = cv2.inRange(hsv, lower_red2, upper_red2)
+
+    # Merge the two masks
+ mask = cv2.bitwise_or(mask1, mask2)
+
+    # Count the non-zero pixels in the red regions
+ non_zero_pixels = cv2.countNonZero(mask)
+ return 1 < non_zero_pixels < 1000
+
+
+def remove_watermark(image):
+    # Remove the red stamp
+ _, _, r_channel = cv2.split(image)
+ r_channel[r_channel > 210] = 255
+ r_channel = cv2.cvtColor(r_channel, cv2.COLOR_GRAY2BGR)
+ return r_channel
+
+
+def html2md(html_content):
+ md_content = md(html_content)
+ md_content = re.sub(r'\\([#*_`])', r'\1', md_content)
+ return md_content
+
+
+def markdown_rec(image):
+    # TODO: a whole directory could be passed in instead of a single image
+ image_path = f'{tempfile.mktemp()}.jpg'
+ cv2.imwrite(image_path, image)
+
+ try:
+ ds = read_local_images(image_path)[0]
+ x = ds.apply(doc_analyze, ocr=True)
+ x = x.pipe_ocr_mode(None)
+ html = x.get_markdown(None)
+ finally:
+ os.remove(image_path)
+ return html2md(html)
+
+
+ocr = PaddleOCR(lang='ch') # need to run only once to download and load model into memory
+
+
+def text_rec(image):
+ result = ocr.ocr(image, cls=True)
+ output = []
+ for idx in range(len(result)):
+ res = result[idx]
+ if not res:
+ continue
+ for line in res:
+ if not line:
+ continue
+ output.append(line[1][0])
+ return output
+
+
+def table_rec(image):
+ return table2md_pipeline(image)
+
+
+table_converter = TableConverter(artifact_dict=create_model_dict())
+
+
+def scanning_document_rec(image):
+    # TODO: the internal OCR could be replaced with PaddleOCR to improve text recognition accuracy
+ image_path = f'{tempfile.mktemp()}.jpg'
+ cv2.imwrite(image_path, image)
+
+ try:
+ no_watermark_image = remove_watermark(cv2.imread(image_path))
+ prefix, suffix = image_path.split('.')
+ new_image_path = f'{prefix}_remove_watermark.{suffix}'
+ cv2.imwrite(new_image_path, no_watermark_image)
+
+ rendered = table_converter(new_image_path)
+ text, _, _ = text_from_rendered(rendered)
+ finally:
+ os.remove(image_path)
+ return text, no_watermark_image
+
+
+def compute_box_distance(box1, box2):
+ x11, y11, x12, y12 = box1
+ x21, y21, x22, y22 = box2
+
+    # Overlap along the horizontal and vertical directions
+ x_overlap = max(0, min(x12, x22) - max(x11, x21))
+ y_overlap = max(0, min(y12, y22) - max(y11, y21))
+
+    # If the boxes overlap on both axes, return the negative overlap depth (min = smallest penetration)
+ if x_overlap > 0 and y_overlap > 0:
+ return -min(x_overlap, y_overlap)
+
+ distances = []
+
+    # If the x-projections overlap, measure the distances between the top/bottom edges
+ if x12 > x21 and x11 < x22:
+        dist_top = y21 - y12  # bottom of box1 to top of box2
+        dist_bottom = y11 - y22  # top of box1 to bottom of box2
+ if dist_top > 0:
+ distances.append(dist_top)
+ if dist_bottom > 0:
+ distances.append(dist_bottom)
+
+    # If the y-projections overlap, measure the distances between the left/right edges
+ if y12 > y21 and y11 < y22:
+        dist_left = x11 - x22  # left of box1 to right of box2
+        dist_right = x21 - x12  # right of box1 to left of box2
+ if dist_left > 0:
+ distances.append(dist_left)
+ if dist_right > 0:
+ distances.append(dist_right)
+
+    # Return the smallest valid distance; if none, the edges cannot be aligned, so return None
+ return min(distances) if distances else None
+
+
+def assign_tables_to_titles(layout_results, max_distance=200):
+ tables = [_ for _ in layout_results if _.clsid == 4]
+ titles = [_ for _ in layout_results if _.clsid == 5]
+
+ table_to_title = {}
+ title_to_table = {}
+
+ changed = True
+ while changed:
+ changed = False
+ for title in titles:
+ title_id = id(title)
+
+ best_table = None
+ min_dist = float('inf')
+
+ for table in tables:
+ table_id = id(table)
+
+ dist = compute_box_distance(title.box, table.box)
+ if dist is None or dist > max_distance:
+ continue
+
+ if dist < min_dist:
+ min_dist = dist
+ best_table = table
+
+ if best_table is None:
+ continue
+
+ table_id = id(best_table)
+
+ current_table = title_to_table.get(title_id)
+ if current_table is best_table:
+                continue  # already the best match, nothing to update
+
+ prev_title = table_to_title.get(table_id)
+ if prev_title:
+ prev_title_id = id(prev_title)
+ prev_dist = compute_box_distance(prev_title.box, best_table.box)
+ if prev_dist is not None and prev_dist <= min_dist:
+                    continue  # the previously bound title is closer, skip
+
+                # Unbind the old title
+ title_to_table.pop(prev_title_id, None)
+
+            # Record the new binding
+ title_to_table[title_id] = best_table
+ table_to_title[table_id] = title
+            changed = True  # something changed this pass
+
+    # Finally, write the bindings back onto the table boxes
+ for table in tables:
+ table_id = id(table)
+ title = table_to_title.get(table_id)
+ if title:
+ table.table_title = title.content
+ else:
+ table.table_title = None
+
+
+if __name__ == '__main__':
+ # content = text_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/5.jpg')
+ # content = markdown_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/3.jpg')
+ # content = table_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/6.jpg')
+ content = scanning_document_rec('/mnt/research/PaddleOCR/pdf2md_pipeline/s4_content_recognition/all_layouts/103.jpg')
+ print(content)
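The table-to-title matching above is a greedy, iterative assignment driven by compute_box_distance. A small worked example with made-up boxes (note that importing this module also initializes the PaddleOCR and marker models at module level, so it is heavyweight):

```
from types import SimpleNamespace

from helper.content_recognition.utils import assign_tables_to_titles, compute_box_distance

# A table and a caption 20 px above it, overlapping in x (all coordinates are xyxy).
table = SimpleNamespace(clsid=4, box=[100, 220, 500, 600], table_title=None)
title = SimpleNamespace(clsid=5, box=[120, 180, 480, 200], content="Table 1  Quarterly results")

print(compute_box_distance(title.box, table.box))   # 20: vertical gap, x-projections overlap
assign_tables_to_titles([table, title], max_distance=200)
print(table.table_title)                            # "Table 1  Quarterly results"
```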
diff --git a/helper/db_helper.py b/helper/db_helper.py
new file mode 100644
index 0000000..fe0242a
--- /dev/null
+++ b/helper/db_helper.py
@@ -0,0 +1,128 @@
+import psycopg2
+from psycopg2 import OperationalError, extras
+from loguru import logger
+import os
+import traceback
+
+
+# Create a connection (create it once and reuse it across the service)
+def create_connection():
+ try:
+ conn = psycopg2.connect(
+ dbname=os.environ['POSTGRESQL_DATABASE'],
+ user=os.environ['POSTGRESQL_USERNAME'],
+ password=os.environ['POSTGRESQL_PASSWORD'],
+ host=os.environ['POSTGRESQL_HOST'],
+ port=os.environ['POSTGRESQL_PORT']
+ )
+ conn.autocommit = False
+ with conn.cursor() as cur:
+ cur.execute("SET TIME ZONE 'Asia/Shanghai';")
+ conn.commit()
+ return conn
+ except OperationalError as e:
+ logger.error(f"连接数据库失败: {e}")
+ return None
+
+
+# Insert a single row
+def insert_data(conn, table, data_dict, pk_name='id'):
+ """
+    Insert a single row into the given table and return its primary key.
+    :param conn: psycopg2 connection object
+    :param table: table name (string)
+    :param data_dict: row data as a dict, e.g. {"column1": value1, "column2": value2}
+    :param pk_name: name of the primary key column returned via RETURNING
+ """
+ if conn is None:
+ logger.error("数据库连接无效")
+ return
+
+ try:
+ with conn.cursor() as cur:
+ columns = data_dict.keys()
+ values = [data_dict[col] for col in columns]
+
+            # Build the SQL statement
+ insert_query = f"""
+ INSERT INTO {table} ({', '.join(columns)})
+ VALUES ({', '.join(['%s'] * len(values))})
+ RETURNING {pk_name}
+ """
+ cur.execute(insert_query, values)
+ inserted_id = cur.fetchone()[0]
+ return inserted_id
+
+ except Exception as e:
+ logger.error(f"插入数据失败: {e}")
+ raise e
+
+
+def insert_multiple_data(conn, table, data_list, batch_size=100):
+ """
+    Insert multiple rows into the given table in batches (using execute_values).
+    :param conn: psycopg2 connection object
+    :param table: table name (string)
+    :param data_list: list of dicts, one per row, e.g. [{"column1": value1, "column2": value2}, ...]
+    :param batch_size: number of rows sent per execute_values call
+ """
+ if conn is None:
+ logger.error("数据库连接无效")
+ return
+
+ try:
+ with conn.cursor() as cur:
+ columns = data_list[0].keys()
+ insert_query = f"""
+ INSERT INTO {table} ({', '.join(columns)})
+ VALUES %s
+ """
+ for i in range(0, len(data_list), batch_size):
+ batch = data_list[i:i + batch_size]
+ values = [tuple(d.values()) for d in batch]
+ extras.execute_values(cur, insert_query, values)
+
+ except Exception as e:
+ logger.error(f"批量插入数据失败: {e}")
+ raise e
+
+
+conn = create_connection()
+
+
+def insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, rec_results):
+ data_dict = {
+ 'path': pdf_path,
+ 'filename': pdf_name,
+ 'process_status': process_status,
+ 'analysis_start_time': start_time,
+ 'analysis_end_time': end_time
+ }
+ try:
+ inserted_id = insert_data(conn, 'pdf_info', data_dict)
+ if process_status == 2:
+ data_list = []
+ for i in range(len(rec_results)):
+                # each page
+ page_no = i + 1
+ for j in range(len(rec_results[i])):
+                    # each box on the page
+ box = rec_results[i][j]
+ content = box.content
+ clsid = box.clsid
+ table_title = box.table_title
+ order = j
+ data_dict = {
+ 'layout_type': clsid,
+ 'content': content,
+ 'page_no': page_no,
+ 'pdf_id': inserted_id,
+ 'table_title': table_title,
+ 'display_order': order
+ }
+ data_list.append(data_dict)
+ insert_multiple_data(conn, 'pdf_analysis_output', data_list)
+ conn.commit()
+ return inserted_id
+ except Exception as e:
+ conn.rollback()
+ logger.error(f'operate database error!\n{traceback.format_exc()}')
+ raise e
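Roughly how a finished parse would be recorded with this helper. Importing helper.db_helper opens the PostgreSQL connection at module load, so the POSTGRESQL_* variables must be set; the file paths are placeholders, and process_status=2 is assumed to mean "parsed successfully" because that value triggers the detail-row insert above:

```
from datetime import datetime
from types import SimpleNamespace

from helper.db_helper import insert_pdf2md_table

# One page with a single recognized block; boxes only need content / clsid / table_title.
box = SimpleNamespace(content="# Heading", clsid=1, table_title=None)

pdf_id = insert_pdf2md_table(
    pdf_path="/data/pdfs/report.pdf",   # placeholder path
    pdf_name="report.pdf",
    process_status=2,                   # 2 also writes the pdf_analysis_output rows
    start_time=datetime.now(),
    end_time=datetime.now(),
    rec_results=[[box]],
)
print(pdf_id)
```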
diff --git a/helper/image_helper.py b/helper/image_helper.py
new file mode 100644
index 0000000..ec4c429
--- /dev/null
+++ b/helper/image_helper.py
@@ -0,0 +1,47 @@
+from typing import List
+from pdf2image import convert_from_path
+import os
+import paddleclas
+import cv2
+from .page_detection.utils import PageDetectionResult
+
+
+paddle_clas_model = paddleclas.PaddleClas(model_name="text_image_orientation")
+
+def pdf2image(pdf_path, output_dir):
+ if not os.path.isdir(output_dir):
+ os.makedirs(output_dir)
+ images = convert_from_path(pdf_path)
+ for i, image in enumerate(images):
+ image.save(f'{output_dir}/{i + 1}.jpg')
+
+
+def image_orient_cls(input_data):
+ return paddle_clas_model.predict(input_data)
+
+
+def page_detection_visual(page_detection_result: PageDetectionResult):
+ img = cv2.imread(page_detection_result.image_path)
+ for box in page_detection_result.boxes:
+ pos = box.pos
+ clsid = box.clsid
+ confidence = box.confidence
+ if clsid == 0:
+ color = (0, 0, 0)
+ text = 'text'
+ elif clsid == 1:
+ color = (255, 0, 0)
+ text = 'title'
+ elif clsid == 2:
+ color = (0, 255, 0)
+ text = 'figure'
+ elif clsid == 4:
+ color = (0, 0, 255)
+ text = 'table'
+        elif clsid == 5:
+ color = (255, 0, 255)
+ text = 'table caption'
+ text = f'{text} {confidence}'
+ img = cv2.rectangle(img, (int(pos[0]), int(pos[1])), (int(pos[2]), int(pos[3])), color, 2)
+ cv2.putText(img, text, (int(pos[0]), int(pos[1])), cv2.FONT_HERSHEY_TRIPLEX, 1, color, 2)
+ return img
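Typical use of the two helpers above: render a PDF into per-page JPEGs, then classify the text orientation of a page. The paths are placeholders; PaddleClas.predict returns a generator of prediction batches:

```
from helper.image_helper import image_orient_cls, pdf2image

pdf2image("sample.pdf", "visual_images/sample")      # writes 1.jpg, 2.jpg, ... one per page
predictions = image_orient_cls("visual_images/sample/1.jpg")
print(next(predictions))                             # orientation labels such as '0', '90', '180', '270'
```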
diff --git a/helper/page_detection/clrnet_postprocess.py b/helper/page_detection/clrnet_postprocess.py
new file mode 100644
index 0000000..efaa345
--- /dev/null
+++ b/helper/page_detection/clrnet_postprocess.py
@@ -0,0 +1,262 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+from scipy.special import softmax
+from scipy.interpolate import InterpolatedUnivariateSpline
+
+
+def line_iou(pred, target, img_w, length=15, aligned=True):
+ '''
+ Calculate the line iou value between predictions and targets
+ Args:
+ pred: lane predictions, shape: (num_pred, 72)
+ target: ground truth, shape: (num_target, 72)
+ img_w: image width
+ length: extended radius
+ aligned: True for iou loss calculation, False for pair-wise ious in assign
+ '''
+ px1 = pred - length
+ px2 = pred + length
+ tx1 = target - length
+ tx2 = target + length
+
+ if aligned:
+ invalid_mask = target
+ ovr = paddle.minimum(px2, tx2) - paddle.maximum(px1, tx1)
+ union = paddle.maximum(px2, tx2) - paddle.minimum(px1, tx1)
+ else:
+ num_pred = pred.shape[0]
+ invalid_mask = target.tile([num_pred, 1, 1])
+
+ ovr = (paddle.minimum(px2[:, None, :], tx2[None, ...]) - paddle.maximum(
+ px1[:, None, :], tx1[None, ...]))
+ union = (paddle.maximum(px2[:, None, :], tx2[None, ...]) -
+ paddle.minimum(px1[:, None, :], tx1[None, ...]))
+
+ invalid_masks = (invalid_mask < 0) | (invalid_mask >= img_w)
+
+ ovr[invalid_masks] = 0.
+ union[invalid_masks] = 0.
+ iou = ovr.sum(axis=-1) / (union.sum(axis=-1) + 1e-9)
+ return iou
+
+
+class Lane:
+ def __init__(self, points=None, invalid_value=-2., metadata=None):
+ super(Lane, self).__init__()
+ self.curr_iter = 0
+ self.points = points
+ self.invalid_value = invalid_value
+ self.function = InterpolatedUnivariateSpline(
+ points[:, 1], points[:, 0], k=min(3, len(points) - 1))
+ self.min_y = points[:, 1].min() - 0.01
+ self.max_y = points[:, 1].max() + 0.01
+ self.metadata = metadata or {}
+
+ def __repr__(self):
+ return '[Lane]\n' + str(self.points) + '\n[/Lane]'
+
+ def __call__(self, lane_ys):
+ lane_xs = self.function(lane_ys)
+
+ lane_xs[(lane_ys < self.min_y) | (lane_ys > self.max_y
+ )] = self.invalid_value
+ return lane_xs
+
+ def to_array(self, sample_y_range, img_w, img_h):
+ self.sample_y = range(sample_y_range[0], sample_y_range[1],
+ sample_y_range[2])
+ sample_y = self.sample_y
+ img_w, img_h = img_w, img_h
+ ys = np.array(sample_y) / float(img_h)
+ xs = self(ys)
+ valid_mask = (xs >= 0) & (xs < 1)
+ lane_xs = xs[valid_mask] * img_w
+ lane_ys = ys[valid_mask] * img_h
+ lane = np.concatenate(
+ (lane_xs.reshape(-1, 1), lane_ys.reshape(-1, 1)), axis=1)
+ return lane
+
+ def __iter__(self):
+ return self
+
+ def __next__(self):
+ if self.curr_iter < len(self.points):
+ self.curr_iter += 1
+ return self.points[self.curr_iter - 1]
+ self.curr_iter = 0
+ raise StopIteration
+
+
+class CLRNetPostProcess(object):
+ """
+ Args:
+ input_shape (int): network input image size
+ ori_shape (int): ori image shape of before padding
+ scale_factor (float): scale factor of ori image
+ enable_mkldnn (bool): whether to open MKLDNN
+ """
+
+ def __init__(self, img_w, ori_img_h, cut_height, conf_threshold, nms_thres,
+ max_lanes, num_points):
+ self.img_w = img_w
+ self.conf_threshold = conf_threshold
+ self.nms_thres = nms_thres
+ self.max_lanes = max_lanes
+ self.num_points = num_points
+ self.n_strips = num_points - 1
+ self.n_offsets = num_points
+ self.ori_img_h = ori_img_h
+ self.cut_height = cut_height
+
+ self.prior_ys = paddle.linspace(
+ start=1, stop=0, num=self.n_offsets).astype('float64')
+
+ def predictions_to_pred(self, predictions):
+ """
+ Convert predictions to internal Lane structure for evaluation.
+ """
+ lanes = []
+ for lane in predictions:
+ lane_xs = lane[6:].clone()
+ start = min(
+ max(0, int(round(lane[2].item() * self.n_strips))),
+ self.n_strips)
+ length = int(round(lane[5].item()))
+ end = start + length - 1
+ end = min(end, len(self.prior_ys) - 1)
+ if start > 0:
+ mask = ((lane_xs[:start] >= 0.) &
+ (lane_xs[:start] <= 1.)).cpu().detach().numpy()[::-1]
+ mask = ~((mask.cumprod()[::-1]).astype(np.bool_))
+ lane_xs[:start][mask] = -2
+ if end < len(self.prior_ys) - 1:
+ lane_xs[end + 1:] = -2
+
+ lane_ys = self.prior_ys[lane_xs >= 0].clone()
+ lane_xs = lane_xs[lane_xs >= 0]
+ lane_xs = lane_xs.flip(axis=0).astype('float64')
+ lane_ys = lane_ys.flip(axis=0)
+
+ lane_ys = (lane_ys *
+ (self.ori_img_h - self.cut_height) + self.cut_height
+ ) / self.ori_img_h
+ if len(lane_xs) <= 1:
+ continue
+ points = paddle.stack(
+ x=(lane_xs.reshape([-1, 1]), lane_ys.reshape([-1, 1])),
+ axis=1).squeeze(axis=2)
+ lane = Lane(
+ points=points.cpu().numpy(),
+ metadata={
+ 'start_x': lane[3],
+ 'start_y': lane[2],
+ 'conf': lane[1]
+ })
+ lanes.append(lane)
+ return lanes
+
+ def lane_nms(self, predictions, scores, nms_overlap_thresh, top_k):
+ """
+ NMS for lane detection.
+        predictions: paddle.Tensor [num_lanes, conf, y, x, length, 72 offsets] [12, 77]
+ scores: paddle.Tensor [num_lanes]
+ nms_overlap_thresh: float
+ top_k: int
+ """
+ # sort by scores to get idx
+ idx = scores.argsort(descending=True)
+ keep = []
+
+ condidates = predictions.clone()
+ condidates = condidates.index_select(idx)
+
+ while len(condidates) > 0:
+ keep.append(idx[0])
+ if len(keep) >= top_k or len(condidates) == 1:
+ break
+
+ ious = []
+ for i in range(1, len(condidates)):
+ ious.append(1 - line_iou(
+ condidates[i].unsqueeze(0),
+ condidates[0].unsqueeze(0),
+ img_w=self.img_w,
+ length=15))
+ ious = paddle.to_tensor(ious)
+
+ mask = ious <= nms_overlap_thresh
+ id = paddle.where(mask == False)[0]
+
+ if id.shape[0] == 0:
+ break
+ condidates = condidates[1:].index_select(id)
+ idx = idx[1:].index_select(id)
+ keep = paddle.stack(keep)
+
+ return keep
+
+ def get_lanes(self, output, as_lanes=True):
+ """
+ Convert model output to lanes.
+ """
+ softmax = nn.Softmax(axis=1)
+ decoded = []
+
+ for predictions in output:
+ if len(predictions) == 0:
+ decoded.append([])
+ continue
+ threshold = self.conf_threshold
+ scores = softmax(predictions[:, :2])[:, 1]
+ keep_inds = scores >= threshold
+ predictions = predictions[keep_inds]
+ scores = scores[keep_inds]
+
+ if predictions.shape[0] == 0:
+ decoded.append([])
+ continue
+ nms_predictions = predictions.detach().clone()
+ nms_predictions = paddle.concat(
+ x=[nms_predictions[..., :4], nms_predictions[..., 5:]], axis=-1)
+
+ nms_predictions[..., 4] = nms_predictions[..., 4] * self.n_strips
+ nms_predictions[..., 5:] = nms_predictions[..., 5:] * (
+ self.img_w - 1)
+
+ keep = self.lane_nms(
+ nms_predictions[..., 5:],
+ scores,
+ nms_overlap_thresh=self.nms_thres,
+ top_k=self.max_lanes)
+
+ predictions = predictions.index_select(keep)
+
+ if predictions.shape[0] == 0:
+ decoded.append([])
+ continue
+ predictions[:, 5] = paddle.round(predictions[:, 5] * self.n_strips)
+ if as_lanes:
+ pred = self.predictions_to_pred(predictions)
+ else:
+ pred = predictions
+ decoded.append(pred)
+ return decoded
+
+ def __call__(self, lanes_list):
+ lanes = self.get_lanes(lanes_list)
+ return lanes
diff --git a/helper/page_detection/keypoint_preprocess.py b/helper/page_detection/keypoint_preprocess.py
new file mode 100644
index 0000000..b4e50e8
--- /dev/null
+++ b/helper/page_detection/keypoint_preprocess.py
@@ -0,0 +1,243 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+this code is based on https://github.com/open-mmlab/mmpose/mmpose/core/post_processing/post_transforms.py
+"""
+import cv2
+import numpy as np
+
+
+class EvalAffine(object):
+ def __init__(self, size, stride=64):
+ super(EvalAffine, self).__init__()
+ self.size = size
+ self.stride = stride
+
+ def __call__(self, image, im_info):
+ s = self.size
+ h, w, _ = image.shape
+ trans, size_resized = get_affine_mat_kernel(h, w, s, inv=False)
+ image_resized = cv2.warpAffine(image, trans, size_resized)
+ return image_resized, im_info
+
+
+def get_affine_mat_kernel(h, w, s, inv=False):
+ if w < h:
+ w_ = s
+ h_ = int(np.ceil((s / w * h) / 64.) * 64)
+ scale_w = w
+ scale_h = h_ / w_ * w
+
+ else:
+ h_ = s
+ w_ = int(np.ceil((s / h * w) / 64.) * 64)
+ scale_h = h
+ scale_w = w_ / h_ * h
+
+ center = np.array([np.round(w / 2.), np.round(h / 2.)])
+
+ size_resized = (w_, h_)
+ trans = get_affine_transform(
+ center, np.array([scale_w, scale_h]), 0, size_resized, inv=inv)
+
+ return trans, size_resized
+
+
+def get_affine_transform(center,
+ input_size,
+ rot,
+ output_size,
+ shift=(0., 0.),
+ inv=False):
+ """Get the affine transform matrix, given the center/scale/rot/output_size.
+
+ Args:
+ center (np.ndarray[2, ]): Center of the bounding box (x, y).
+ scale (np.ndarray[2, ]): Scale of the bounding box
+ wrt [width, height].
+ rot (float): Rotation angle (degree).
+ output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+ shift (0-100%): Shift translation ratio wrt the width/height.
+ Default (0., 0.).
+ inv (bool): Option to inverse the affine transform direction.
+ (inv=False: src->dst or inv=True: dst->src)
+
+ Returns:
+ np.ndarray: The transform matrix.
+ """
+ assert len(center) == 2
+ assert len(output_size) == 2
+ assert len(shift) == 2
+ if not isinstance(input_size, (np.ndarray, list)):
+ input_size = np.array([input_size, input_size], dtype=np.float32)
+ scale_tmp = input_size
+
+ shift = np.array(shift)
+ src_w = scale_tmp[0]
+ dst_w = output_size[0]
+ dst_h = output_size[1]
+
+ rot_rad = np.pi * rot / 180
+ src_dir = rotate_point([0., src_w * -0.5], rot_rad)
+ dst_dir = np.array([0., dst_w * -0.5])
+
+ src = np.zeros((3, 2), dtype=np.float32)
+ src[0, :] = center + scale_tmp * shift
+ src[1, :] = center + src_dir + scale_tmp * shift
+ src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+ dst = np.zeros((3, 2), dtype=np.float32)
+ dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+ dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+ dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+ if inv:
+ trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+ else:
+ trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+ return trans
+
+
+def get_warp_matrix(theta, size_input, size_dst, size_target):
+ """This code is based on
+ https://github.com/open-mmlab/mmpose/blob/master/mmpose/core/post_processing/post_transforms.py
+
+ Calculate the transformation matrix under the constraint of unbiased.
+ Paper ref: Huang et al. The Devil is in the Details: Delving into Unbiased
+ Data Processing for Human Pose Estimation (CVPR 2020).
+
+ Args:
+ theta (float): Rotation angle in degrees.
+ size_input (np.ndarray): Size of input image [w, h].
+ size_dst (np.ndarray): Size of output image [w, h].
+ size_target (np.ndarray): Size of ROI in input plane [w, h].
+
+ Returns:
+ matrix (np.ndarray): A matrix for transformation.
+ """
+ theta = np.deg2rad(theta)
+ matrix = np.zeros((2, 3), dtype=np.float32)
+ scale_x = size_dst[0] / size_target[0]
+ scale_y = size_dst[1] / size_target[1]
+ matrix[0, 0] = np.cos(theta) * scale_x
+ matrix[0, 1] = -np.sin(theta) * scale_x
+ matrix[0, 2] = scale_x * (
+ -0.5 * size_input[0] * np.cos(theta) + 0.5 * size_input[1] *
+ np.sin(theta) + 0.5 * size_target[0])
+ matrix[1, 0] = np.sin(theta) * scale_y
+ matrix[1, 1] = np.cos(theta) * scale_y
+ matrix[1, 2] = scale_y * (
+ -0.5 * size_input[0] * np.sin(theta) - 0.5 * size_input[1] *
+ np.cos(theta) + 0.5 * size_target[1])
+ return matrix
+
+
+def rotate_point(pt, angle_rad):
+ """Rotate a point by an angle.
+
+ Args:
+ pt (list[float]): 2 dimensional point to be rotated
+ angle_rad (float): rotation angle by radian
+
+ Returns:
+ list[float]: Rotated point.
+ """
+ assert len(pt) == 2
+ sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+ new_x = pt[0] * cs - pt[1] * sn
+ new_y = pt[0] * sn + pt[1] * cs
+ rotated_pt = [new_x, new_y]
+
+ return rotated_pt
+
+
+def _get_3rd_point(a, b):
+ """To calculate the affine matrix, three pairs of points are required. This
+ function is used to get the 3rd point, given 2D points a & b.
+
+ The 3rd point is defined by rotating vector `a - b` by 90 degrees
+ anticlockwise, using b as the rotation center.
+
+ Args:
+ a (np.ndarray): point(x,y)
+ b (np.ndarray): point(x,y)
+
+ Returns:
+ np.ndarray: The 3rd point.
+ """
+ assert len(a) == 2
+ assert len(b) == 2
+ direction = a - b
+ third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+ return third_pt
+
+
+class TopDownEvalAffine(object):
+ """apply affine transform to image and coords
+
+ Args:
+ trainsize (list): [w, h], the standard size used to train
+ use_udp (bool): whether to use Unbiased Data Processing.
+        records (dict): the dict containing the image and coords
+
+ Returns:
+        records (dict): the image and coords after the transform is applied
+
+ """
+
+ def __init__(self, trainsize, use_udp=False):
+ self.trainsize = trainsize
+ self.use_udp = use_udp
+
+ def __call__(self, image, im_info):
+ rot = 0
+ imshape = im_info['im_shape'][::-1]
+ center = im_info['center'] if 'center' in im_info else imshape / 2.
+ scale = im_info['scale'] if 'scale' in im_info else imshape
+ if self.use_udp:
+ trans = get_warp_matrix(
+ rot, center * 2.0,
+ [self.trainsize[0] - 1.0, self.trainsize[1] - 1.0], scale)
+ image = cv2.warpAffine(
+ image,
+ trans, (int(self.trainsize[0]), int(self.trainsize[1])),
+ flags=cv2.INTER_LINEAR)
+ else:
+ trans = get_affine_transform(center, scale, rot, self.trainsize)
+ image = cv2.warpAffine(
+ image,
+ trans, (int(self.trainsize[0]), int(self.trainsize[1])),
+ flags=cv2.INTER_LINEAR)
+
+ return image, im_info
+
+
+def expand_crop(images, rect, expand_ratio=0.3):
+ imgh, imgw, c = images.shape
+ label, conf, xmin, ymin, xmax, ymax = [int(x) for x in rect.tolist()]
+ if label != 0:
+ return None, None, None
+ org_rect = [xmin, ymin, xmax, ymax]
+ h_half = (ymax - ymin) * (1 + expand_ratio) / 2.
+ w_half = (xmax - xmin) * (1 + expand_ratio) / 2.
+ if h_half > w_half * 4 / 3:
+ w_half = h_half * 0.75
+ center = [(ymin + ymax) / 2., (xmin + xmax) / 2.]
+ ymin = max(0, int(center[0] - h_half))
+ ymax = min(imgh - 1, int(center[0] + h_half))
+ xmin = max(0, int(center[1] - w_half))
+ xmax = min(imgw - 1, int(center[1] + w_half))
+ return images[ymin:ymax, xmin:xmax, :], [xmin, ymin, xmax, ymax], org_rect
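A quick numeric check of get_affine_transform: map a 640x480 region centred at (320, 240) onto a 192x256 destination plane, and confirm the source centre lands on the destination centre (the import path is assumed from the repo layout):

```
import numpy as np

from helper.page_detection.keypoint_preprocess import get_affine_transform

center = np.array([320.0, 240.0])
trans = get_affine_transform(center, np.array([640.0, 480.0]), 0, [192, 256])
src_pt = np.array([320.0, 240.0, 1.0])   # homogeneous source point (the centre)
print(trans @ src_pt)                    # ~[96. 128.], the centre of the 192x256 plane
```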
diff --git a/helper/page_detection/main.py b/helper/page_detection/main.py
new file mode 100644
index 0000000..fd23166
--- /dev/null
+++ b/helper/page_detection/main.py
@@ -0,0 +1,69 @@
+from typing import List
+from .pdf_detection import Pipeline
+from utils import non_max_suppression, merge_text_and_title_boxes, LayoutBox, PageDetectionResult
+from tqdm import tqdm
+
+
+
+"""
+ 0 - Text
+ 1 - Title
+ 2 - Figure
+ 3 - Figure caption
+ 4 - Table
+ 5 - Table caption
+ 6 - Header
+ 7 - Footer
+ 8 - Reference
+ 9 - Equation
+"""
+pipeline = Pipeline('./models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer')
+
+effective_labels = [0, 1, 2, 4, 5]
+# NMS priority: the lower the index, the lower the priority; tables are kept first when boxes overlap
+label_scores = [1, 5, 0, 2, 4]
+expand_pixel = 10
+
+
+def layout_analysis(image_paths) -> List[PageDetectionResult]:
+ layout_analysis_results = []
+ for image_path in tqdm(image_paths, '版面分析'):
+ page_detecion_outputs = pipeline(image_path)
+ layout_boxes = []
+ for i in range(len(page_detecion_outputs)):
+ clsid, box, confidence = page_detecion_outputs[i]
+ if clsid in effective_labels:
+ layout_boxes.append(LayoutBox(clsid, box, confidence))
+ page_detecion_outputs = PageDetectionResult(layout_boxes, image_path)
+
+ scores = []
+ poses = []
+ for box in page_detecion_outputs.boxes:
+            # When boxes with the same label overlap, keep the one with the larger area
+ area = (box.pos[3] - box.pos[1]) * (box.pos[2] - box.pos[0])
+ area_score = area / 5000000
+ scores.append(label_scores.index(box.clsid) + area_score)
+ poses.append(box.pos)
+ indices = non_max_suppression(poses, scores, 0.2)
+ _boxes = []
+ for i in indices:
+ _boxes.append(page_detecion_outputs.boxes[i])
+ page_detecion_outputs.boxes = _boxes
+
+ for i in range(len(page_detecion_outputs.boxes) - 1, -1, -1):
+ box = page_detecion_outputs.boxes[i]
+ if box.clsid in (0, 5):
+                # Remove Table-caption and Text boxes that fall inside Table or Figure boxes (some scans are detected as Figure)
+ for _box in page_detecion_outputs.boxes:
+ if _box.clsid != 2 and _box.clsid != 4:
+ continue
+ if box.pos[0] > _box.pos[0] and box.pos[1] > _box.pos[1] and box.pos[2] < _box.pos[2] and box.pos[3] < _box.pos[3]:
+ page_detecion_outputs.boxes.remove(box)
+
+        # Merge text and title boxes so they can be converted to Markdown more easily
+ page_detecion_outputs.boxes = merge_text_and_title_boxes(page_detecion_outputs.boxes, (0, 1))
+        # Sort the boxes into reading order (top-to-bottom, then left-to-right)
+ page_detecion_outputs.boxes.sort(key=lambda x: (x.pos[1], x.pos[0]))
+ layout_analysis_results.append(page_detecion_outputs)
+
+ return layout_analysis_results
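End to end, layout_analysis consumes the page images produced by pdf2image and returns one PageDetectionResult per page, with boxes already filtered by NMS, nested-box removal, text/title merging and reading-order sorting. A usage sketch (image paths are placeholders and the PaddleDetection model directory configured above must exist):

```
from helper.page_detection.main import layout_analysis

pages = layout_analysis(["visual_images/sample/1.jpg", "visual_images/sample/2.jpg"])
for page_no, page in enumerate(pages, start=1):
    for box in page.boxes:
        # clsid follows the mapping above: 0 text, 1 title, 2 figure, 4 table, 5 table caption
        print(page_no, box.clsid, box.pos, box.confidence)
```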
diff --git a/helper/page_detection/pdf_detection.py b/helper/page_detection/pdf_detection.py
new file mode 100644
index 0000000..44b484a
--- /dev/null
+++ b/helper/page_detection/pdf_detection.py
@@ -0,0 +1,918 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import yaml
+import glob
+import json
+from pathlib import Path
+
+import cv2
+import numpy as np
+import math
+import paddle
+from paddle.inference import Config
+from paddle.inference import create_predictor
+
+import sys
+# add deploy path of PaddleDetection to sys.path
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'])))
+sys.path.insert(0, parent_path)
+
+from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, WarpAffine, Pad, decode_image, CULaneResize
+from picodet_postprocess import PicoDetPostProcess
+from clrnet_postprocess import CLRNetPostProcess
+from visualize import visualize_box_mask, imshow_lanes
+from utils import argsparser, Timer, multiclass_nms, coco_clsid2catid
+
+# Global dictionary
+SUPPORT_MODELS = {
+ 'YOLO', 'PPYOLOE', 'RCNN', 'SSD', 'Face', 'FCOS', 'SOLOv2', 'TTFNet',
+ 'S2ANet', 'JDE', 'FairMOT', 'DeepSORT', 'GFL', 'PicoDet', 'CenterNet',
+ 'TOOD', 'RetinaNet', 'StrongBaseline', 'STGCN', 'YOLOX', 'YOLOF', 'PPHGNet',
+ 'PPLCNet', 'DETR', 'CenterTrack', 'CLRNet'
+}
+
+
+
+class Detector(object):
+ """
+ Args:
+ pred_config (object): config of model, defined by `Config(model_dir)`
+ model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
+ batch_size (int): size of pre batch in inference
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ cpu_threads (int): cpu threads
+ enable_mkldnn (bool): whether to open MKLDNN
+ enable_mkldnn_bfloat16 (bool): whether to turn on mkldnn bfloat16
+ output_dir (str): The path of output
+ threshold (float): The threshold of score for visualization
+ delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
+ Used by action model.
+ """
+
+ def __init__(self,
+ model_dir,
+ device='CPU',
+ run_mode='paddle',
+ batch_size=1,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ output_dir='output',
+ threshold=0.5,
+ delete_shuffle_pass=False,
+ use_fd_format=False):
+ self.pred_config = self.set_config(
+ model_dir, use_fd_format=use_fd_format)
+ self.predictor, self.config = load_predictor(
+ model_dir,
+ self.pred_config.arch,
+ run_mode=run_mode,
+ batch_size=batch_size,
+ min_subgraph_size=self.pred_config.min_subgraph_size,
+ device=device,
+ use_dynamic_shape=self.pred_config.use_dynamic_shape,
+ trt_min_shape=trt_min_shape,
+ trt_max_shape=trt_max_shape,
+ trt_opt_shape=trt_opt_shape,
+ trt_calib_mode=trt_calib_mode,
+ cpu_threads=cpu_threads,
+ enable_mkldnn=enable_mkldnn,
+ enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
+ delete_shuffle_pass=delete_shuffle_pass)
+ self.det_times = Timer()
+ self.batch_size = batch_size
+ self.output_dir = output_dir
+ self.threshold = threshold
+ self.device = device
+
+ def set_config(self, model_dir, use_fd_format):
+ return PredictConfig(model_dir, use_fd_format=use_fd_format)
+
+ def preprocess(self, image_list):
+ preprocess_ops = []
+ for op_info in self.pred_config.preprocess_infos:
+ new_op_info = op_info.copy()
+ op_type = new_op_info.pop('type')
+ preprocess_ops.append(eval(op_type)(**new_op_info))
+
+ input_im_lst = []
+ input_im_info_lst = []
+ for im_path in image_list:
+ im, im_info = preprocess(im_path, preprocess_ops)
+ input_im_lst.append(im)
+ input_im_info_lst.append(im_info)
+ inputs = create_inputs(input_im_lst, input_im_info_lst)
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(input_names[i])
+ if input_names[i] == 'x':
+ input_tensor.copy_from_cpu(inputs['image'])
+ else:
+ input_tensor.copy_from_cpu(inputs[input_names[i]])
+
+ return inputs
+
+ def postprocess(self, inputs, result):
+ # postprocess output of predictor
+ np_boxes_num = result['boxes_num']
+ assert isinstance(np_boxes_num, np.ndarray), \
+ '`np_boxes_num` should be a `numpy.ndarray`'
+
+ result = {k: v for k, v in result.items() if v is not None}
+ return result
+
+ def filter_box(self, result, threshold):
+ np_boxes_num = result['boxes_num']
+ boxes = result['boxes']
+ start_idx = 0
+ filter_boxes = []
+ filter_num = []
+ for i in range(len(np_boxes_num)):
+ boxes_num = np_boxes_num[i]
+ boxes_i = boxes[start_idx:start_idx + boxes_num, :]
+ idx = boxes_i[:, 1] > threshold
+ filter_boxes_i = boxes_i[idx, :]
+ filter_boxes.append(filter_boxes_i)
+ filter_num.append(filter_boxes_i.shape[0])
+ start_idx += boxes_num
+ boxes = np.concatenate(filter_boxes)
+ filter_num = np.array(filter_num)
+ filter_res = {'boxes': boxes, 'boxes_num': filter_num}
+ return filter_res
+
+ def predict(self, repeats=1, run_benchmark=False):
+ '''
+ Args:
+ repeats (int): repeats number for prediction
+ Returns:
+ result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+                           matrix element:[class, score, x_min, y_min, x_max, y_max]
+ MaskRCNN's result include 'masks': np.ndarray:
+ shape: [N, im_h, im_w]
+ '''
+ # model prediction
+ np_boxes_num, np_boxes, np_masks = np.array([0]), None, None
+
+ if run_benchmark:
+ for i in range(repeats):
+ self.predictor.run()
+ if self.device == 'GPU':
+ paddle.device.cuda.synchronize()
+ else:
+ paddle.device.synchronize(device=self.device.lower())
+
+ result = dict(
+ boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
+ return result
+
+ for i in range(repeats):
+ self.predictor.run()
+ output_names = self.predictor.get_output_names()
+ boxes_tensor = self.predictor.get_output_handle(output_names[0])
+ np_boxes = boxes_tensor.copy_to_cpu()
+ if len(output_names) == 1:
+ # some exported model can not get tensor 'bbox_num'
+ np_boxes_num = np.array([len(np_boxes)])
+ else:
+ boxes_num = self.predictor.get_output_handle(output_names[1])
+ np_boxes_num = boxes_num.copy_to_cpu()
+ if self.pred_config.mask:
+ masks_tensor = self.predictor.get_output_handle(output_names[2])
+ np_masks = masks_tensor.copy_to_cpu()
+ result = dict(boxes=np_boxes, masks=np_masks, boxes_num=np_boxes_num)
+ return result
+
+ def merge_batch_result(self, batch_result):
+ if len(batch_result) == 1:
+ return batch_result[0]
+ res_key = batch_result[0].keys()
+ results = {k: [] for k in res_key}
+ for res in batch_result:
+ for k, v in res.items():
+ results[k].append(v)
+ for k, v in results.items():
+ if k not in ['masks', 'segm']:
+ results[k] = np.concatenate(v)
+ return results
+
+ def get_timer(self):
+ return self.det_times
+
+ def predict_image(self,
+ image_list,
+ threshold=0.5,
+ visual=True):
+ batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
+ results = []
+ for i in range(batch_loop_cnt):
+ start_index = i * self.batch_size
+ end_index = min((i + 1) * self.batch_size, len(image_list))
+ batch_image_list = image_list[start_index:end_index]
+ # preprocess
+ self.det_times.preprocess_time_s.start()
+ inputs = self.preprocess(batch_image_list)
+ self.det_times.preprocess_time_s.end()
+
+ # model prediction
+ self.det_times.inference_time_s.start()
+ result = self.predict()
+ self.det_times.inference_time_s.end()
+
+ # postprocess
+ self.det_times.postprocess_time_s.start()
+ result = self.postprocess(inputs, result)
+ self.det_times.postprocess_time_s.end()
+ self.det_times.img_num += len(batch_image_list)
+
+ if visual:
+ visualize(
+ batch_image_list,
+ result,
+ self.pred_config.labels,
+ output_dir=self.output_dir,
+ threshold=self.threshold)
+            # TODO: handle batching here
+ results.append(result)
+ results = self.merge_batch_result(results)
+ boxes = results['boxes']
+ expect_boxes = (boxes[:, 1] > threshold) & (boxes[:, 0] > -1)
+ boxes = boxes[expect_boxes, :]
+ output = []
+ for dt in boxes:
+ clsid, box, confidence = int(dt[0]), dt[2:].tolist(), dt[1]
+ output.append((clsid, box, confidence))
+ return output
+
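+# A minimal usage sketch (illustrative, assuming a PaddleDetection layout model
+# exported to `model_dir` with model.pdiparams + infer_cfg.yml):
+#
+#   detector = Detector('inference_model/picodet_layout', device='GPU')
+#   outputs = detector.predict_image(['page_1.jpg'], threshold=0.5, visual=False)
+#   for clsid, box, confidence in outputs:
+#       print(clsid, box, confidence)  # box is [x_min, y_min, x_max, y_max]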
+
+class DetectorSOLOv2(Detector):
+ """
+ Args:
+ model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
+        batch_size (int): batch size used for inference
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ cpu_threads (int): cpu threads
+ enable_mkldnn (bool): whether to open MKLDNN
+ enable_mkldnn_bfloat16 (bool): Whether to turn on mkldnn bfloat16
+ output_dir (str): The path of output
+ threshold (float): The threshold of score for visualization
+
+ """
+
+ def __init__(self,
+ model_dir,
+ device='CPU',
+ run_mode='paddle',
+ batch_size=1,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ output_dir='./',
+ threshold=0.5,
+ use_fd_format=False):
+ super(DetectorSOLOv2, self).__init__(
+ model_dir=model_dir,
+ device=device,
+ run_mode=run_mode,
+ batch_size=batch_size,
+ trt_min_shape=trt_min_shape,
+ trt_max_shape=trt_max_shape,
+ trt_opt_shape=trt_opt_shape,
+ trt_calib_mode=trt_calib_mode,
+ cpu_threads=cpu_threads,
+ enable_mkldnn=enable_mkldnn,
+ enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
+ output_dir=output_dir,
+ threshold=threshold,
+ use_fd_format=use_fd_format)
+
+ def predict(self, repeats=1, run_benchmark=False):
+ '''
+ Args:
+ repeats (int): repeat number for prediction
+ Returns:
+ result (dict): 'segm': np.ndarray,shape:[N, im_h, im_w]
+ 'cate_label': label of segm, shape:[N]
+ 'cate_score': confidence score of segm, shape:[N]
+ '''
+ np_segms, np_label, np_score, np_boxes_num = None, None, None, np.array(
+ [0])
+
+ if run_benchmark:
+ for i in range(repeats):
+ self.predictor.run()
+ paddle.device.cuda.synchronize()
+ result = dict(
+ segm=np_segms,
+ label=np_label,
+ score=np_score,
+ boxes_num=np_boxes_num)
+ return result
+
+ for i in range(repeats):
+ self.predictor.run()
+ output_names = self.predictor.get_output_names()
+ np_segms = self.predictor.get_output_handle(output_names[
+ 0]).copy_to_cpu()
+ np_boxes_num = self.predictor.get_output_handle(output_names[
+ 1]).copy_to_cpu()
+ np_label = self.predictor.get_output_handle(output_names[
+ 2]).copy_to_cpu()
+ np_score = self.predictor.get_output_handle(output_names[
+ 3]).copy_to_cpu()
+
+ result = dict(
+ segm=np_segms,
+ label=np_label,
+ score=np_score,
+ boxes_num=np_boxes_num)
+ return result
+
+
+class DetectorPicoDet(Detector):
+ """
+ Args:
+ model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
+        batch_size (int): batch size used for inference
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ cpu_threads (int): cpu threads
+ enable_mkldnn (bool): whether to turn on MKLDNN
+ enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
+ """
+
+ def __init__(self,
+ model_dir,
+ device='CPU',
+ run_mode='paddle',
+ batch_size=1,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ output_dir='./',
+ threshold=0.5,
+ use_fd_format=False):
+ super(DetectorPicoDet, self).__init__(
+ model_dir=model_dir,
+ device=device,
+ run_mode=run_mode,
+ batch_size=batch_size,
+ trt_min_shape=trt_min_shape,
+ trt_max_shape=trt_max_shape,
+ trt_opt_shape=trt_opt_shape,
+ trt_calib_mode=trt_calib_mode,
+ cpu_threads=cpu_threads,
+ enable_mkldnn=enable_mkldnn,
+ enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
+ output_dir=output_dir,
+ threshold=threshold,
+ use_fd_format=use_fd_format)
+
+ def postprocess(self, inputs, result):
+ # postprocess output of predictor
+ np_score_list = result['boxes']
+ np_boxes_list = result['boxes_num']
+ postprocessor = PicoDetPostProcess(
+ inputs['image'].shape[2:],
+ inputs['im_shape'],
+ inputs['scale_factor'],
+ strides=self.pred_config.fpn_stride,
+ nms_threshold=self.pred_config.nms['nms_threshold'])
+ np_boxes, np_boxes_num = postprocessor(np_score_list, np_boxes_list)
+ result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
+ return result
+
+ def predict(self, repeats=1, run_benchmark=False):
+ '''
+ Args:
+ repeats (int): repeat number for prediction
+ Returns:
+ result (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+                            matrix element:[class, score, x_min, y_min, x_max, y_max]
+ '''
+ np_score_list, np_boxes_list = [], []
+
+ if run_benchmark:
+ for i in range(repeats):
+ self.predictor.run()
+ paddle.device.cuda.synchronize()
+ result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
+ return result
+
+ for i in range(repeats):
+ self.predictor.run()
+ np_score_list.clear()
+ np_boxes_list.clear()
+ output_names = self.predictor.get_output_names()
+ num_outs = int(len(output_names) / 2)
+ for out_idx in range(num_outs):
+ np_score_list.append(
+ self.predictor.get_output_handle(output_names[out_idx])
+ .copy_to_cpu())
+ np_boxes_list.append(
+ self.predictor.get_output_handle(output_names[
+ out_idx + num_outs]).copy_to_cpu())
+ result = dict(boxes=np_score_list, boxes_num=np_boxes_list)
+ return result
+
+
+class DetectorCLRNet(Detector):
+ """
+ Args:
+ model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16)
+        batch_size (int): batch size used for inference
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ cpu_threads (int): cpu threads
+ enable_mkldnn (bool): whether to turn on MKLDNN
+ enable_mkldnn_bfloat16 (bool): whether to turn on MKLDNN_BFLOAT16
+ """
+
+ def __init__(self,
+ model_dir,
+ device='CPU',
+ run_mode='paddle',
+ batch_size=1,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ output_dir='./',
+ threshold=0.5,
+ use_fd_format=False):
+ super(DetectorCLRNet, self).__init__(
+ model_dir=model_dir,
+ device=device,
+ run_mode=run_mode,
+ batch_size=batch_size,
+ trt_min_shape=trt_min_shape,
+ trt_max_shape=trt_max_shape,
+ trt_opt_shape=trt_opt_shape,
+ trt_calib_mode=trt_calib_mode,
+ cpu_threads=cpu_threads,
+ enable_mkldnn=enable_mkldnn,
+ enable_mkldnn_bfloat16=enable_mkldnn_bfloat16,
+ output_dir=output_dir,
+ threshold=threshold,
+ use_fd_format=use_fd_format)
+
+ deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
+ with open(deploy_file) as f:
+ yml_conf = yaml.safe_load(f)
+ self.img_w = yml_conf['img_w']
+ self.ori_img_h = yml_conf['ori_img_h']
+ self.cut_height = yml_conf['cut_height']
+ self.max_lanes = yml_conf['max_lanes']
+ self.nms_thres = yml_conf['nms_thres']
+ self.num_points = yml_conf['num_points']
+ self.conf_threshold = yml_conf['conf_threshold']
+
+ def postprocess(self, inputs, result):
+ # postprocess output of predictor
+ lanes_list = result['lanes']
+ postprocessor = CLRNetPostProcess(
+ img_w=self.img_w,
+ ori_img_h=self.ori_img_h,
+ cut_height=self.cut_height,
+ conf_threshold=self.conf_threshold,
+ nms_thres=self.nms_thres,
+ max_lanes=self.max_lanes,
+ num_points=self.num_points)
+ lanes = postprocessor(lanes_list)
+ result = dict(lanes=lanes)
+ return result
+
+ def predict(self, repeats=1, run_benchmark=False):
+ '''
+ Args:
+ repeats (int): repeat number for prediction
+ Returns:
+            result (dict): include 'lanes': list of detected lane lines for each image
+ '''
+ lanes_list = []
+
+ if run_benchmark:
+ for i in range(repeats):
+ self.predictor.run()
+ paddle.device.cuda.synchronize()
+ result = dict(lanes=lanes_list)
+ return result
+
+ for i in range(repeats):
+ # TODO: check the output of predictor
+ self.predictor.run()
+ lanes_list.clear()
+ output_names = self.predictor.get_output_names()
+ num_outs = int(len(output_names) / 2)
+ if num_outs == 0:
+ lanes_list.append([])
+ for out_idx in range(num_outs):
+ lanes_list.append(
+ self.predictor.get_output_handle(output_names[out_idx])
+ .copy_to_cpu())
+ result = dict(lanes=lanes_list)
+ return result
+
+
+def create_inputs(imgs, im_info):
+ """generate input for different model type
+ Args:
+ imgs (list(numpy)): list of images (np.ndarray)
+ im_info (list(dict)): list of image info
+ Returns:
+ inputs (dict): input of model
+ """
+ inputs = {}
+
+ im_shape = []
+ scale_factor = []
+ if len(imgs) == 1:
+ inputs['image'] = np.array((imgs[0], )).astype('float32')
+ inputs['im_shape'] = np.array(
+ (im_info[0]['im_shape'], )).astype('float32')
+ inputs['scale_factor'] = np.array(
+ (im_info[0]['scale_factor'], )).astype('float32')
+ return inputs
+
+ for e in im_info:
+ im_shape.append(np.array((e['im_shape'], )).astype('float32'))
+ scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))
+
+ inputs['im_shape'] = np.concatenate(im_shape, axis=0)
+ inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)
+
+ imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
+ max_shape_h = max([e[0] for e in imgs_shape])
+ max_shape_w = max([e[1] for e in imgs_shape])
+ padding_imgs = []
+ for img in imgs:
+ im_c, im_h, im_w = img.shape[:]
+ padding_im = np.zeros(
+ (im_c, max_shape_h, max_shape_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = img
+ padding_imgs.append(padding_im)
+ inputs['image'] = np.stack(padding_imgs, axis=0)
+ return inputs
+
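+# Batching note (illustrative shapes): for images of different sizes, create_inputs
+# zero-pads each CHW image to the batch-wide max height/width, e.g. (3, 800, 600) and
+# (3, 640, 900) both become (3, 800, 900), so inputs['image'] stacks to (2, 3, 800, 900).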
+
+class PredictConfig():
+ """set config of preprocess, postprocess and visualize
+ Args:
+ model_dir (str): root path of model.yml
+ """
+
+ def __init__(self, model_dir, use_fd_format=False):
+ # parsing Yaml config for Preprocess
+ fd_deploy_file = os.path.join(model_dir, 'inference.yml')
+ ppdet_deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
+ if use_fd_format:
+ if not os.path.exists(fd_deploy_file) and os.path.exists(
+ ppdet_deploy_file):
+ raise RuntimeError(
+ "Non-FD format model detected. Please set `use_fd_format` to False."
+ )
+ deploy_file = fd_deploy_file
+ else:
+ if not os.path.exists(ppdet_deploy_file) and os.path.exists(
+ fd_deploy_file):
+ raise RuntimeError(
+                    "FD format model detected. Please set `use_fd_format` to True."
+ )
+ deploy_file = ppdet_deploy_file
+ with open(deploy_file) as f:
+ yml_conf = yaml.safe_load(f)
+ self.check_model(yml_conf)
+ self.arch = yml_conf['arch']
+ self.preprocess_infos = yml_conf['Preprocess']
+ self.min_subgraph_size = yml_conf['min_subgraph_size']
+ self.labels = yml_conf['label_list']
+ self.mask = False
+ self.use_dynamic_shape = yml_conf['use_dynamic_shape']
+ if 'mask' in yml_conf:
+ self.mask = yml_conf['mask']
+ self.tracker = None
+ if 'tracker' in yml_conf:
+ self.tracker = yml_conf['tracker']
+ if 'NMS' in yml_conf:
+ self.nms = yml_conf['NMS']
+ if 'fpn_stride' in yml_conf:
+ self.fpn_stride = yml_conf['fpn_stride']
+ if self.arch == 'RCNN' and yml_conf.get('export_onnx', False):
+ print(
+ 'The RCNN export model is used for ONNX and it only supports batch_size = 1'
+ )
+ self.print_config()
+
+ def check_model(self, yml_conf):
+ """
+ Raises:
+ ValueError: loaded model not in supported model type
+ """
+ for support_model in SUPPORT_MODELS:
+ if support_model in yml_conf['arch']:
+ return True
+ raise ValueError("Unsupported arch: {}, expect {}".format(yml_conf[
+ 'arch'], SUPPORT_MODELS))
+
+ def print_config(self):
+ print('----------- Model Configuration -----------')
+ print('%s: %s' % ('Model Arch', self.arch))
+ print('%s: ' % ('Transform Order'))
+ for op_info in self.preprocess_infos:
+ print('--%s: %s' % ('transform op', op_info['type']))
+ print('--------------------------------------------')
+
+
+def load_predictor(model_dir,
+ arch,
+ run_mode='paddle',
+ batch_size=1,
+ device='CPU',
+ min_subgraph_size=3,
+ use_dynamic_shape=False,
+ trt_min_shape=1,
+ trt_max_shape=1280,
+ trt_opt_shape=640,
+ trt_calib_mode=False,
+ cpu_threads=1,
+ enable_mkldnn=False,
+ enable_mkldnn_bfloat16=False,
+ delete_shuffle_pass=False):
+ """set AnalysisConfig, generate AnalysisPredictor
+ Args:
+ model_dir (str): root path of __model__ and __params__
+ device (str): Choose the device you want to run, it can be: CPU/GPU/XPU, default is CPU
+ run_mode (str): mode of running(paddle/trt_fp32/trt_fp16/trt_int8)
+ use_dynamic_shape (bool): use dynamic shape or not
+ trt_min_shape (int): min shape for dynamic shape in trt
+ trt_max_shape (int): max shape for dynamic shape in trt
+ trt_opt_shape (int): opt shape for dynamic shape in trt
+ trt_calib_mode (bool): If the model is produced by TRT offline quantitative
+ calibration, trt_calib_mode need to set True
+ delete_shuffle_pass (bool): whether to remove shuffle_channel_detect_pass in TensorRT.
+ Used by action model.
+ Returns:
+ predictor (PaddlePredictor): AnalysisPredictor
+ Raises:
+ ValueError: predict by TensorRT need device == 'GPU'.
+ """
+ if device != 'GPU' and run_mode != 'paddle':
+ raise ValueError(
+ "Predict by TensorRT mode: {}, expect device=='GPU', but device == {}"
+ .format(run_mode, device))
+
+ if paddle.__version__ >= '3.0.0' or paddle.__version__ == '0.0.0':
+ model_path = model_dir
+ model_prefix = 'model'
+ infer_param = os.path.join(model_dir, 'model.pdiparams')
+ if not os.path.exists(infer_param):
+ if paddle.framework.use_pir_api():
+ infer_model = os.path.join(model_dir, 'inference.pdmodel')
+ else:
+ infer_model = os.path.join(model_dir, 'inference.json')
+ if not os.path.exists(infer_model):
+ raise ValueError(
+ "Cannot find any inference model in dir: {}.".format(model_dir))
+
+ config = Config(model_path, model_prefix)
+
+ else:
+ infer_model = os.path.join(model_dir, 'model.pdmodel')
+ infer_params = os.path.join(model_dir, 'model.pdiparams')
+ if not os.path.exists(infer_model):
+ infer_model = os.path.join(model_dir, 'inference.pdmodel')
+ infer_params = os.path.join(model_dir, 'inference.pdiparams')
+ if not os.path.exists(infer_model):
+ raise ValueError(
+ "Cannot find any inference model in dir: {},".format(model_dir))
+ config = Config(infer_model, infer_params)
+
+ if device == 'GPU':
+ # initial GPU memory(M), device ID
+ config.enable_use_gpu(200, 0)
+ # optimize graph and fuse op
+ config.switch_ir_optim(True)
+ else:
+ config.disable_gpu()
+ config.set_cpu_math_library_num_threads(cpu_threads)
+ if enable_mkldnn:
+ try:
+ # cache 10 different shapes for mkldnn to avoid memory leak
+ config.set_mkldnn_cache_capacity(10)
+ config.enable_mkldnn()
+ if enable_mkldnn_bfloat16:
+ config.enable_mkldnn_bfloat16()
+ except Exception as e:
+ print(
+ "The current environment does not support `mkldnn`, so disable mkldnn."
+ )
+ pass
+
+ precision_map = {
+ 'trt_int8': Config.Precision.Int8,
+ 'trt_fp32': Config.Precision.Float32,
+ 'trt_fp16': Config.Precision.Half
+ }
+ if run_mode in precision_map.keys():
+ config.enable_tensorrt_engine(
+ workspace_size=(1 << 25) * batch_size,
+ max_batch_size=batch_size,
+ min_subgraph_size=min_subgraph_size,
+ precision_mode=precision_map[run_mode],
+ use_static=False,
+ use_calib_mode=trt_calib_mode)
+ if FLAGS.collect_trt_shape_info:
+ config.collect_shape_range_info(FLAGS.tuned_trt_shape_file)
+ elif os.path.exists(FLAGS.tuned_trt_shape_file):
+ print(f'Use dynamic shape file: '
+ f'{FLAGS.tuned_trt_shape_file} for TRT...')
+ config.enable_tuned_tensorrt_dynamic_shape(
+ FLAGS.tuned_trt_shape_file, True)
+
+ if use_dynamic_shape:
+ min_input_shape = {
+ 'image': [batch_size, 3, trt_min_shape, trt_min_shape],
+ 'scale_factor': [batch_size, 2]
+ }
+ max_input_shape = {
+ 'image': [batch_size, 3, trt_max_shape, trt_max_shape],
+ 'scale_factor': [batch_size, 2]
+ }
+ opt_input_shape = {
+ 'image': [batch_size, 3, trt_opt_shape, trt_opt_shape],
+ 'scale_factor': [batch_size, 2]
+ }
+ config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
+ opt_input_shape)
+ print('trt set dynamic shape done!')
+
+ # disable print log when predict
+ config.disable_glog_info()
+ # enable shared memory
+ config.enable_memory_optim()
+ # disable feed, fetch OP, needed by zero_copy_run
+ config.switch_use_feed_fetch_ops(False)
+ if delete_shuffle_pass:
+ config.delete_pass("shuffle_channel_detect_pass")
+ predictor = create_predictor(config)
+ return predictor, config
+
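+# A minimal sketch of calling load_predictor directly (Detector.__init__ normally does
+# this; the model path and arch below are illustrative):
+#
+#   predictor, config = load_predictor(
+#       'inference_model/picodet_layout', arch='PicoDet', run_mode='paddle', device='GPU')
+#   # input/output handles are then driven via predictor.get_input_handle(...) / run().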
+
+def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
+ # visualize the predict result
+ if 'lanes' in result:
+ for idx, image_file in enumerate(image_list):
+ lanes = result['lanes'][idx]
+ img = cv2.imread(image_file)
+ out_file = os.path.join(output_dir, os.path.basename(image_file))
+ # hard code
+ lanes = [lane.to_array([], ) for lane in lanes]
+ imshow_lanes(img, lanes, out_file=out_file)
+ return
+ start_idx = 0
+ for idx, image_file in enumerate(image_list):
+ im_bboxes_num = result['boxes_num'][idx]
+ im_results = {}
+ if 'boxes' in result:
+ im_results['boxes'] = result['boxes'][start_idx:start_idx +
+ im_bboxes_num, :]
+ if 'masks' in result:
+ im_results['masks'] = result['masks'][start_idx:start_idx +
+ im_bboxes_num, :]
+ if 'segm' in result:
+ im_results['segm'] = result['segm'][start_idx:start_idx +
+ im_bboxes_num, :]
+ if 'label' in result:
+ im_results['label'] = result['label'][start_idx:start_idx +
+ im_bboxes_num]
+ if 'score' in result:
+ im_results['score'] = result['score'][start_idx:start_idx +
+ im_bboxes_num]
+
+ start_idx += im_bboxes_num
+ im = visualize_box_mask(
+ image_file, im_results, labels, threshold=threshold)
+ img_name = os.path.split(image_file)[-1]
+ if not os.path.exists(output_dir):
+ os.makedirs(output_dir)
+ out_path = os.path.join(output_dir, img_name)
+ im.save(out_path, quality=95)
+ print("save result to: " + out_path)
+
+
+def print_arguments(args):
+ print('----------- Running Arguments -----------')
+ for arg, value in sorted(vars(args).items()):
+ print('%s: %s' % (arg, value))
+ print('------------------------------------------')
+
+
+class Pipeline(object):
+
+ def __init__(self, model_dir):
+ if FLAGS.use_fd_format:
+ deploy_file = os.path.join(model_dir, 'inference.yml')
+ else:
+ deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
+ with open(deploy_file) as f:
+ yml_conf = yaml.safe_load(f)
+ arch = yml_conf['arch']
+ detector_func = 'Detector'
+ if arch == 'SOLOv2':
+ detector_func = 'DetectorSOLOv2'
+ elif arch == 'PicoDet':
+ detector_func = 'DetectorPicoDet'
+ elif arch == "CLRNet":
+ detector_func = 'DetectorCLRNet'
+
+ self.detector = eval(detector_func)(
+ model_dir,
+ device=FLAGS.device,
+ run_mode=FLAGS.run_mode,
+ batch_size=FLAGS.batch_size,
+ trt_min_shape=FLAGS.trt_min_shape,
+ trt_max_shape=FLAGS.trt_max_shape,
+ trt_opt_shape=FLAGS.trt_opt_shape,
+ trt_calib_mode=FLAGS.trt_calib_mode,
+ cpu_threads=FLAGS.cpu_threads,
+ enable_mkldnn=FLAGS.enable_mkldnn,
+ enable_mkldnn_bfloat16=FLAGS.enable_mkldnn_bfloat16,
+ threshold=FLAGS.threshold,
+ output_dir=FLAGS.output_dir,
+ use_fd_format=FLAGS.use_fd_format)
+
+ def __call__(self, image_path):
+ if FLAGS.image_dir is None and image_path is not None:
+ assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
+ if isinstance(image_path, str):
+ image_path = [image_path]
+ results = self.detector.predict_image(
+ image_path,
+ visual=FLAGS.save_images)
+ return results
+
+
+paddle.enable_static()
+parser = argsparser()
+FLAGS = parser.parse_args()
+print_arguments(FLAGS)
+FLAGS.device = 'GPU'
+FLAGS.save_images = False
+FLAGS.device = FLAGS.device.upper()
+assert FLAGS.device in ['CPU', 'GPU', 'XPU', 'NPU', 'MLU', 'GCU'
+ ], "device should be CPU, GPU, XPU, MLU, NPU or GCU"
+assert not FLAGS.use_gpu, "use_gpu has been deprecated, please use --device"
+
+assert not (
+ FLAGS.enable_mkldnn == False and FLAGS.enable_mkldnn_bfloat16 == True
+), 'To enable mkldnn bfloat, please turn on both enable_mkldnn and enable_mkldnn_bfloat16'
diff --git a/helper/page_detection/picodet_postprocess.py b/helper/page_detection/picodet_postprocess.py
new file mode 100644
index 0000000..7df13f8
--- /dev/null
+++ b/helper/page_detection/picodet_postprocess.py
@@ -0,0 +1,227 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from scipy.special import softmax
+
+
+def hard_nms(box_scores, iou_threshold, top_k=-1, candidate_size=200):
+ """
+ Args:
+ box_scores (N, 5): boxes in corner-form and probabilities.
+ iou_threshold: intersection over union threshold.
+ top_k: keep top_k results. If k <= 0, keep all the results.
+ candidate_size: only consider the candidates with the highest scores.
+ Returns:
+ picked: a list of indexes of the kept boxes
+ """
+ scores = box_scores[:, -1]
+ boxes = box_scores[:, :-1]
+ picked = []
+ indexes = np.argsort(scores)
+ indexes = indexes[-candidate_size:]
+ while len(indexes) > 0:
+ current = indexes[-1]
+ picked.append(current)
+ if 0 < top_k == len(picked) or len(indexes) == 1:
+ break
+ current_box = boxes[current, :]
+ indexes = indexes[:-1]
+ rest_boxes = boxes[indexes, :]
+ iou = iou_of(
+ rest_boxes,
+ np.expand_dims(
+ current_box, axis=0), )
+ indexes = indexes[iou <= iou_threshold]
+
+ return box_scores[picked, :]
+
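+# Illustrative input: box_scores is an (N, 5) array of [x1, y1, x2, y2, score] rows, e.g.
+#   hard_nms(np.array([[0., 0., 10., 10., 0.9], [1., 1., 11., 11., 0.8]]), iou_threshold=0.5)
+# keeps only the higher-scoring box, since the pair overlaps with IoU ~ 0.68.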
+
+def iou_of(boxes0, boxes1, eps=1e-5):
+ """Return intersection-over-union (Jaccard index) of boxes.
+ Args:
+ boxes0 (N, 4): ground truth boxes.
+ boxes1 (N or 1, 4): predicted boxes.
+ eps: a small number to avoid 0 as denominator.
+ Returns:
+ iou (N): IoU values.
+ """
+ overlap_left_top = np.maximum(boxes0[..., :2], boxes1[..., :2])
+ overlap_right_bottom = np.minimum(boxes0[..., 2:], boxes1[..., 2:])
+
+ overlap_area = area_of(overlap_left_top, overlap_right_bottom)
+ area0 = area_of(boxes0[..., :2], boxes0[..., 2:])
+ area1 = area_of(boxes1[..., :2], boxes1[..., 2:])
+ return overlap_area / (area0 + area1 - overlap_area + eps)
+
+
+def area_of(left_top, right_bottom):
+ """Compute the areas of rectangles given two corners.
+ Args:
+ left_top (N, 2): left top corner.
+ right_bottom (N, 2): right bottom corner.
+ Returns:
+ area (N): return the area.
+ """
+ hw = np.clip(right_bottom - left_top, 0.0, None)
+ return hw[..., 0] * hw[..., 1]
+
+
+class PicoDetPostProcess(object):
+ """
+ Args:
+ input_shape (int): network input image size
+ ori_shape (int): ori image shape of before padding
+ scale_factor (float): scale factor of ori image
+        strides (list): downsampling strides of the detection heads
+        score_threshold (float): score threshold used to filter low-confidence boxes
+        nms_threshold (float): IoU threshold used in NMS
+ """
+
+ def __init__(self,
+ input_shape,
+ ori_shape,
+ scale_factor,
+ strides=[8, 16, 32, 64],
+ score_threshold=0.4,
+ nms_threshold=0.5,
+ nms_top_k=1000,
+ keep_top_k=100):
+ self.ori_shape = ori_shape
+ self.input_shape = input_shape
+ self.scale_factor = scale_factor
+ self.strides = strides
+ self.score_threshold = score_threshold
+ self.nms_threshold = nms_threshold
+ self.nms_top_k = nms_top_k
+ self.keep_top_k = keep_top_k
+
+ def warp_boxes(self, boxes, ori_shape):
+ """Apply transform to boxes
+ """
+ width, height = ori_shape[1], ori_shape[0]
+ n = len(boxes)
+ if n:
+ # warp points
+ xy = np.ones((n * 4, 3))
+ xy[:, :2] = boxes[:, [0, 1, 2, 3, 0, 3, 2, 1]].reshape(
+ n * 4, 2) # x1y1, x2y2, x1y2, x2y1
+ # xy = xy @ M.T # transform
+ xy = (xy[:, :2] / xy[:, 2:3]).reshape(n, 8) # rescale
+ # create new boxes
+ x = xy[:, [0, 2, 4, 6]]
+ y = xy[:, [1, 3, 5, 7]]
+ xy = np.concatenate(
+ (x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+ # clip boxes
+ xy[:, [0, 2]] = xy[:, [0, 2]].clip(0, width)
+ xy[:, [1, 3]] = xy[:, [1, 3]].clip(0, height)
+ return xy.astype(np.float32)
+ else:
+ return boxes
+
+ def __call__(self, scores, raw_boxes):
+ batch_size = raw_boxes[0].shape[0]
+ reg_max = int(raw_boxes[0].shape[-1] / 4 - 1)
+ out_boxes_num = []
+ out_boxes_list = []
+ for batch_id in range(batch_size):
+ # generate centers
+ decode_boxes = []
+ select_scores = []
+ for stride, box_distribute, score in zip(self.strides, raw_boxes,
+ scores):
+ box_distribute = box_distribute[batch_id]
+ score = score[batch_id]
+ # centers
+ fm_h = self.input_shape[0] / stride
+ fm_w = self.input_shape[1] / stride
+ h_range = np.arange(fm_h)
+ w_range = np.arange(fm_w)
+ ww, hh = np.meshgrid(w_range, h_range)
+ ct_row = (hh.flatten() + 0.5) * stride
+ ct_col = (ww.flatten() + 0.5) * stride
+ center = np.stack((ct_col, ct_row, ct_col, ct_row), axis=1)
+
+ # box distribution to distance
+ reg_range = np.arange(reg_max + 1)
+ box_distance = box_distribute.reshape((-1, reg_max + 1))
+ box_distance = softmax(box_distance, axis=1)
+ box_distance = box_distance * np.expand_dims(reg_range, axis=0)
+ box_distance = np.sum(box_distance, axis=1).reshape((-1, 4))
+ box_distance = box_distance * stride
+
+ # top K candidate
+ topk_idx = np.argsort(score.max(axis=1))[::-1]
+ topk_idx = topk_idx[:self.nms_top_k]
+ center = center[topk_idx]
+ score = score[topk_idx]
+ box_distance = box_distance[topk_idx]
+
+ # decode box
+ decode_box = center + [-1, -1, 1, 1] * box_distance
+
+ select_scores.append(score)
+ decode_boxes.append(decode_box)
+
+ # nms
+ bboxes = np.concatenate(decode_boxes, axis=0)
+ confidences = np.concatenate(select_scores, axis=0)
+ picked_box_probs = []
+ picked_labels = []
+ for class_index in range(0, confidences.shape[1]):
+ probs = confidences[:, class_index]
+ mask = probs > self.score_threshold
+ probs = probs[mask]
+ if probs.shape[0] == 0:
+ continue
+ subset_boxes = bboxes[mask, :]
+ box_probs = np.concatenate(
+ [subset_boxes, probs.reshape(-1, 1)], axis=1)
+ box_probs = hard_nms(
+ box_probs,
+ iou_threshold=self.nms_threshold,
+ top_k=self.keep_top_k, )
+ picked_box_probs.append(box_probs)
+ picked_labels.extend([class_index] * box_probs.shape[0])
+
+ if len(picked_box_probs) == 0:
+                out_boxes_list.append(np.empty((0, 6)))
+ out_boxes_num.append(0)
+
+ else:
+ picked_box_probs = np.concatenate(picked_box_probs)
+
+ # resize output boxes
+ picked_box_probs[:, :4] = self.warp_boxes(
+ picked_box_probs[:, :4], self.ori_shape[batch_id])
+ im_scale = np.concatenate([
+ self.scale_factor[batch_id][::-1],
+ self.scale_factor[batch_id][::-1]
+ ])
+ picked_box_probs[:, :4] /= im_scale
+ # clas score box
+ out_boxes_list.append(
+ np.concatenate(
+ [
+ np.expand_dims(
+ np.array(picked_labels),
+ axis=-1), np.expand_dims(
+ picked_box_probs[:, 4], axis=-1),
+ picked_box_probs[:, :4]
+ ],
+ axis=1))
+ out_boxes_num.append(len(picked_labels))
+
+ out_boxes_list = np.concatenate(out_boxes_list, axis=0)
+ out_boxes_num = np.asarray(out_boxes_num).astype(np.int32)
+ return out_boxes_list, out_boxes_num
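+# Illustrative call (the real wiring lives in DetectorPicoDet.postprocess in
+# pdf_detection.py; shapes depend on the exported model):
+#
+#   post = PicoDetPostProcess((800, 608), ori_shapes, scale_factors, strides=[8, 16, 32, 64])
+#   np_boxes, np_boxes_num = post(np_score_list, np_boxes_list)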
diff --git a/helper/page_detection/preprocess.py b/helper/page_detection/preprocess.py
new file mode 100644
index 0000000..1936d3e
--- /dev/null
+++ b/helper/page_detection/preprocess.py
@@ -0,0 +1,549 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+import imgaug.augmenters as iaa
+from keypoint_preprocess import get_affine_transform
+from PIL import Image
+
+
+def decode_image(im_file, im_info):
+ """read rgb image
+ Args:
+ im_file (str|np.ndarray): input can be image path or np.ndarray
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ if isinstance(im_file, str):
+ with open(im_file, 'rb') as f:
+ im_read = f.read()
+ data = np.frombuffer(im_read, dtype='uint8')
+ im = cv2.imdecode(data, 1) # BGR mode, but need RGB mode
+ im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
+ else:
+ im = im_file
+ im_info['im_shape'] = np.array(im.shape[:2], dtype=np.float32)
+ im_info['scale_factor'] = np.array([1., 1.], dtype=np.float32)
+ return im, im_info
+
+
+class Resize_Mult32(object):
+ """resize image by target_size and max_size
+ Args:
+ target_size (int): the target size of image
+ keep_ratio (bool): whether keep_ratio or not, default true
+ interp (int): method of resize
+ """
+
+ def __init__(self, limit_side_len, limit_type, interp=cv2.INTER_LINEAR):
+ self.limit_side_len = limit_side_len
+ self.limit_type = limit_type
+ self.interp = interp
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im_channel = im.shape[2]
+ im_scale_y, im_scale_x = self.generate_scale(im)
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+ im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+ im_info['scale_factor'] = np.array(
+ [im_scale_y, im_scale_x]).astype('float32')
+ return im, im_info
+
+ def generate_scale(self, img):
+ """
+ Args:
+ img (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ limit_side_len = self.limit_side_len
+ h, w, c = img.shape
+
+ # limit the max side
+ if self.limit_type == 'max':
+ if h > w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ elif self.limit_type == 'min':
+ if h < w:
+ ratio = float(limit_side_len) / h
+ else:
+ ratio = float(limit_side_len) / w
+ elif self.limit_type == 'resize_long':
+ ratio = float(limit_side_len) / max(h, w)
+ else:
+            raise Exception('Unsupported limit_type: {}'.format(self.limit_type))
+ resize_h = int(h * ratio)
+ resize_w = int(w * ratio)
+
+ resize_h = max(int(round(resize_h / 32) * 32), 32)
+ resize_w = max(int(round(resize_w / 32) * 32), 32)
+
+ im_scale_y = resize_h / float(h)
+ im_scale_x = resize_w / float(w)
+ return im_scale_y, im_scale_x
+
+
+class Resize(object):
+ """resize image by target_size and max_size
+ Args:
+ target_size (int): the target size of image
+ keep_ratio (bool): whether keep_ratio or not, default true
+ interp (int): method of resize
+ """
+
+ def __init__(self, target_size, keep_ratio=True, interp=cv2.INTER_LINEAR):
+ if isinstance(target_size, int):
+ target_size = [target_size, target_size]
+ self.target_size = target_size
+ self.keep_ratio = keep_ratio
+ self.interp = interp
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ assert len(self.target_size) == 2
+ assert self.target_size[0] > 0 and self.target_size[1] > 0
+ im_channel = im.shape[2]
+ im_scale_y, im_scale_x = self.generate_scale(im)
+ im = cv2.resize(
+ im,
+ None,
+ None,
+ fx=im_scale_x,
+ fy=im_scale_y,
+ interpolation=self.interp)
+ im_info['im_shape'] = np.array(im.shape[:2]).astype('float32')
+ im_info['scale_factor'] = np.array(
+ [im_scale_y, im_scale_x]).astype('float32')
+ return im, im_info
+
+ def generate_scale(self, im):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ Returns:
+ im_scale_x: the resize ratio of X
+ im_scale_y: the resize ratio of Y
+ """
+ origin_shape = im.shape[:2]
+ im_c = im.shape[2]
+ if self.keep_ratio:
+ im_size_min = np.min(origin_shape)
+ im_size_max = np.max(origin_shape)
+ target_size_min = np.min(self.target_size)
+ target_size_max = np.max(self.target_size)
+ im_scale = float(target_size_min) / float(im_size_min)
+ if np.round(im_scale * im_size_max) > target_size_max:
+ im_scale = float(target_size_max) / float(im_size_max)
+ im_scale_x = im_scale
+ im_scale_y = im_scale
+ else:
+ resize_h, resize_w = self.target_size
+ im_scale_y = resize_h / float(origin_shape[0])
+ im_scale_x = resize_w / float(origin_shape[1])
+ return im_scale_y, im_scale_x
+
+
+class ShortSizeScale(object):
+ """
+ Scale images by short size.
+ Args:
+ short_size(float | int): Short size of an image will be scaled to the short_size.
+ fixed_ratio(bool): Set whether to zoom according to a fixed ratio. default: True
+ do_round(bool): Whether to round up when calculating the zoom ratio. default: False
+ backend(str): Choose pillow or cv2 as the graphics processing backend. default: 'pillow'
+ """
+
+ def __init__(self,
+ short_size,
+ fixed_ratio=True,
+ keep_ratio=None,
+ do_round=False,
+ backend='pillow'):
+ self.short_size = short_size
+ assert (fixed_ratio and not keep_ratio) or (
+ not fixed_ratio
+ ), "fixed_ratio and keep_ratio cannot be true at the same time"
+ self.fixed_ratio = fixed_ratio
+ self.keep_ratio = keep_ratio
+ self.do_round = do_round
+
+ assert backend in [
+ 'pillow', 'cv2'
+        ], f"Scale's backend must be pillow or cv2, but got {backend}"
+
+ self.backend = backend
+
+ def __call__(self, img):
+ """
+ Performs resize operations.
+ Args:
+ img (PIL.Image): a PIL.Image.
+ return:
+ resized_img: a PIL.Image after scaling.
+ """
+
+ result_img = None
+
+ if isinstance(img, np.ndarray):
+ h, w, _ = img.shape
+ elif isinstance(img, Image.Image):
+ w, h = img.size
+ else:
+ raise NotImplementedError
+
+ if w <= h:
+ ow = self.short_size
+ if self.fixed_ratio: # default is True
+ oh = int(self.short_size * 4.0 / 3.0)
+ elif not self.keep_ratio: # no
+ oh = self.short_size
+ else:
+ scale_factor = self.short_size / w
+ oh = int(h * float(scale_factor) +
+ 0.5) if self.do_round else int(h * self.short_size / w)
+ ow = int(w * float(scale_factor) +
+ 0.5) if self.do_round else int(w * self.short_size / h)
+ else:
+ oh = self.short_size
+ if self.fixed_ratio:
+ ow = int(self.short_size * 4.0 / 3.0)
+ elif not self.keep_ratio: # no
+ ow = self.short_size
+ else:
+ scale_factor = self.short_size / h
+ oh = int(h * float(scale_factor) +
+ 0.5) if self.do_round else int(h * self.short_size / w)
+ ow = int(w * float(scale_factor) +
+ 0.5) if self.do_round else int(w * self.short_size / h)
+
+ if type(img) == np.ndarray:
+ img = Image.fromarray(img, mode='RGB')
+
+ if self.backend == 'pillow':
+ result_img = img.resize((ow, oh), Image.BILINEAR)
+ elif self.backend == 'cv2' and (self.keep_ratio is not None):
+ result_img = cv2.resize(
+ img, (ow, oh), interpolation=cv2.INTER_LINEAR)
+ else:
+ result_img = Image.fromarray(
+ cv2.resize(
+ np.asarray(img), (ow, oh), interpolation=cv2.INTER_LINEAR))
+
+ return result_img
+
+
+class NormalizeImage(object):
+ """normalize image
+ Args:
+ mean (list): im - mean
+ std (list): im / std
+ is_scale (bool): whether need im / 255
+ norm_type (str): type in ['mean_std', 'none']
+ """
+
+ def __init__(self, mean, std, is_scale=True, norm_type='mean_std'):
+ self.mean = mean
+ self.std = std
+ self.is_scale = is_scale
+ self.norm_type = norm_type
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.astype(np.float32, copy=False)
+ if self.is_scale:
+ scale = 1.0 / 255.0
+ im *= scale
+
+ if self.norm_type == 'mean_std':
+ mean = np.array(self.mean)[np.newaxis, np.newaxis, :]
+ std = np.array(self.std)[np.newaxis, np.newaxis, :]
+ im -= mean
+ im /= std
+ return im, im_info
+
+
+class Permute(object):
+ """permute image
+ Args:
+ to_bgr (bool): whether convert RGB to BGR
+ channel_first (bool): whether convert HWC to CHW
+ """
+
+ def __init__(self, ):
+ super(Permute, self).__init__()
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ im = im.transpose((2, 0, 1)).copy()
+ return im, im_info
+
+
+class PadStride(object):
+    """ padding image for models with FPN, used instead of PadBatch(pad_to_stride) in the original config
+    Args:
+        stride (int): models with FPN require image shape % stride == 0
+ """
+
+ def __init__(self, stride=0):
+ self.coarsest_stride = stride
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ coarsest_stride = self.coarsest_stride
+ if coarsest_stride <= 0:
+ return im, im_info
+ im_c, im_h, im_w = im.shape
+ pad_h = int(np.ceil(float(im_h) / coarsest_stride) * coarsest_stride)
+ pad_w = int(np.ceil(float(im_w) / coarsest_stride) * coarsest_stride)
+ padding_im = np.zeros((im_c, pad_h, pad_w), dtype=np.float32)
+ padding_im[:, :im_h, :im_w] = im
+ return padding_im, im_info
+
+
+class LetterBoxResize(object):
+ def __init__(self, target_size):
+ """
+ Resize image to target size, convert normalized xywh to pixel xyxy
+ format ([x_center, y_center, width, height] -> [x0, y0, x1, y1]).
+ Args:
+ target_size (int|list): image target size.
+ """
+ super(LetterBoxResize, self).__init__()
+ if isinstance(target_size, int):
+ target_size = [target_size, target_size]
+ self.target_size = target_size
+
+ def letterbox(self, img, height, width, color=(127.5, 127.5, 127.5)):
+ # letterbox: resize a rectangular image to a padded rectangular
+ shape = img.shape[:2] # [height, width]
+ ratio_h = float(height) / shape[0]
+ ratio_w = float(width) / shape[1]
+ ratio = min(ratio_h, ratio_w)
+ new_shape = (round(shape[1] * ratio),
+ round(shape[0] * ratio)) # [width, height]
+ padw = (width - new_shape[0]) / 2
+ padh = (height - new_shape[1]) / 2
+ top, bottom = round(padh - 0.1), round(padh + 0.1)
+ left, right = round(padw - 0.1), round(padw + 0.1)
+
+ img = cv2.resize(
+ img, new_shape, interpolation=cv2.INTER_AREA) # resized, no border
+ img = cv2.copyMakeBorder(
+ img, top, bottom, left, right, cv2.BORDER_CONSTANT,
+ value=color) # padded rectangular
+ return img, ratio, padw, padh
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ assert len(self.target_size) == 2
+ assert self.target_size[0] > 0 and self.target_size[1] > 0
+ height, width = self.target_size
+ h, w = im.shape[:2]
+ im, ratio, padw, padh = self.letterbox(im, height=height, width=width)
+
+ new_shape = [round(h * ratio), round(w * ratio)]
+ im_info['im_shape'] = np.array(new_shape, dtype=np.float32)
+ im_info['scale_factor'] = np.array([ratio, ratio], dtype=np.float32)
+ return im, im_info
+
+
+class Pad(object):
+ def __init__(self, size, fill_value=[114.0, 114.0, 114.0]):
+ """
+ Pad image to a specified size.
+ Args:
+ size (list[int]): image target size
+ fill_value (list[float]): rgb value of pad area, default (114.0, 114.0, 114.0)
+ """
+ super(Pad, self).__init__()
+ if isinstance(size, int):
+ size = [size, size]
+ self.size = size
+ self.fill_value = fill_value
+
+ def __call__(self, im, im_info):
+ im_h, im_w = im.shape[:2]
+ h, w = self.size
+ if h == im_h and w == im_w:
+ im = im.astype(np.float32)
+ return im, im_info
+
+ canvas = np.ones((h, w, 3), dtype=np.float32)
+ canvas *= np.array(self.fill_value, dtype=np.float32)
+ canvas[0:im_h, 0:im_w, :] = im.astype(np.float32)
+ im = canvas
+ return im, im_info
+
+
+class WarpAffine(object):
+ """Warp affine the image
+ """
+
+ def __init__(self,
+ keep_res=False,
+ pad=31,
+ input_h=512,
+ input_w=512,
+ scale=0.4,
+ shift=0.1,
+ down_ratio=4):
+ self.keep_res = keep_res
+ self.pad = pad
+ self.input_h = input_h
+ self.input_w = input_w
+ self.scale = scale
+ self.shift = shift
+ self.down_ratio = down_ratio
+
+ def __call__(self, im, im_info):
+ """
+ Args:
+ im (np.ndarray): image (np.ndarray)
+ im_info (dict): info of image
+ Returns:
+ im (np.ndarray): processed image (np.ndarray)
+ im_info (dict): info of processed image
+ """
+ img = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+
+ h, w = img.shape[:2]
+
+ if self.keep_res:
+ # True in detection eval/infer
+ input_h = (h | self.pad) + 1
+ input_w = (w | self.pad) + 1
+ s = np.array([input_w, input_h], dtype=np.float32)
+ c = np.array([w // 2, h // 2], dtype=np.float32)
+
+ else:
+ # False in centertrack eval_mot/eval_mot
+ s = max(h, w) * 1.0
+ input_h, input_w = self.input_h, self.input_w
+ c = np.array([w / 2., h / 2.], dtype=np.float32)
+
+ trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
+ img = cv2.resize(img, (w, h))
+ inp = cv2.warpAffine(
+ img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR)
+
+ if not self.keep_res:
+ out_h = input_h // self.down_ratio
+ out_w = input_w // self.down_ratio
+ trans_output = get_affine_transform(c, s, 0, [out_w, out_h])
+
+ im_info.update({
+ 'center': c,
+ 'scale': s,
+ 'out_height': out_h,
+ 'out_width': out_w,
+ 'inp_height': input_h,
+ 'inp_width': input_w,
+ 'trans_input': trans_input,
+ 'trans_output': trans_output,
+ })
+ return inp, im_info
+
+
+class CULaneResize(object):
+ def __init__(self, img_h, img_w, cut_height, prob=0.5):
+ super(CULaneResize, self).__init__()
+ self.img_h = img_h
+ self.img_w = img_w
+ self.cut_height = cut_height
+ self.prob = prob
+
+ def __call__(self, im, im_info):
+ # cut
+ im = im[self.cut_height:, :, :]
+ # resize
+ transform = iaa.Sometimes(self.prob,
+ iaa.Resize({
+ "height": self.img_h,
+ "width": self.img_w
+ }))
+ im = transform(image=im.copy().astype(np.uint8))
+
+ im = im.astype(np.float32) / 255.
+        # TODO: check whether this transpose is needed, i.e. whether decode_image yields the same layout as CULaneDataSet's cv2.imread
+ im = im.transpose(2, 0, 1)
+
+ return im, im_info
+
+
+def preprocess(im, preprocess_ops):
+ # process image by preprocess_ops
+ im_info = {
+ 'scale_factor': np.array(
+ [1., 1.], dtype=np.float32),
+ 'im_shape': None,
+ }
+ im, im_info = decode_image(im, im_info)
+ for operator in preprocess_ops:
+ im, im_info = operator(im, im_info)
+ return im, im_info
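+# Illustrative use: the ops list mirrors the `Preprocess` section of infer_cfg.yml
+# (values below are made up, not taken from a specific exported model):
+#
+#   ops = [Resize(target_size=[800, 608], keep_ratio=False),
+#          NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+#          Permute(), PadStride(32)]
+#   im, im_info = preprocess('page_1.jpg', ops)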
diff --git a/helper/page_detection/start.sh b/helper/page_detection/start.sh
new file mode 100755
index 0000000..014d60c
--- /dev/null
+++ b/helper/page_detection/start.sh
@@ -0,0 +1,6 @@
+export FLAGS_enable_pir_api=0
+
+python3 pdf_detection.py \
+ --model_dir=/mnt/research/PaddleOCR/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer \
+ --image_file=/mnt/research/PaddleOCR/demo-75-images/12.jpg \
+ --device=GPU
diff --git a/helper/page_detection/utils.py b/helper/page_detection/utils.py
new file mode 100644
index 0000000..17184ef
--- /dev/null
+++ b/helper/page_detection/utils.py
@@ -0,0 +1,629 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import os
+import ast
+import argparse
+from typing import List, Tuple
+import numpy as np
+
+
+class LayoutBox(object):
+ def __init__(self, clsid: int, pos: List[float], confidence: float):
+ self.clsid = clsid
+ self.pos = pos
+ self.confidence = confidence
+
+
+class PageDetectionResult(object):
+ def __init__(self, boxes: List[LayoutBox], image_path: str):
+ self.boxes = boxes
+ self.image_path = image_path
+
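+# Construction sketch (class id and coordinates are illustrative):
+#
+#   boxes = [LayoutBox(clsid=1, pos=[50.0, 80.0, 500.0, 120.0], confidence=0.97)]
+#   page_result = PageDetectionResult(boxes, image_path='visual_images/page_1.jpg')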
+
+def argsparser():
+ parser = argparse.ArgumentParser(description=__doc__)
+ parser.add_argument(
+ "--image_dir",
+ type=str,
+ default=None,
+ help="Dir of image file, `image_file` has a higher priority.")
+ parser.add_argument(
+ "--batch_size", type=int, default=1, help="batch_size for inference.")
+ parser.add_argument(
+ "--video_file",
+ type=str,
+ default=None,
+        help="Path of video file, `video_file` or `camera_id` has the highest priority."
+ )
+ parser.add_argument(
+ "--camera_id",
+ type=int,
+ default=-1,
+ help="device id of camera to predict.")
+ parser.add_argument(
+ "--threshold", type=float, default=0.5, help="Threshold of score.")
+ parser.add_argument(
+ "--output_dir",
+ type=str,
+ default="output",
+ help="Directory of output visualization files.")
+ parser.add_argument(
+ "--run_mode",
+ type=str,
+ default='paddle',
+ help="mode of running(paddle/trt_fp32/trt_fp16/trt_int8)")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default='cpu',
+ help="Choose the device you want to run, it can be: CPU/GPU/XPU/NPU, default is CPU."
+ )
+ parser.add_argument(
+ "--use_gpu",
+ type=ast.literal_eval,
+ default=False,
+ help="Deprecated, please use `--device`.")
+ parser.add_argument(
+ "--run_benchmark",
+ type=ast.literal_eval,
+ default=False,
+        help="Whether to predict an image_file repeatedly for benchmark.")
+ parser.add_argument(
+ "--enable_mkldnn",
+ type=ast.literal_eval,
+ default=False,
+        help="Whether to use MKLDNN with CPU.")
+ parser.add_argument(
+ "--enable_mkldnn_bfloat16",
+ type=ast.literal_eval,
+ default=False,
+        help="Whether to use MKLDNN bfloat16 inference with CPU.")
+ parser.add_argument(
+ "--cpu_threads", type=int, default=1, help="Num of threads with CPU.")
+ parser.add_argument(
+ "--trt_min_shape", type=int, default=1, help="min_shape for TensorRT.")
+ parser.add_argument(
+ "--trt_max_shape",
+ type=int,
+ default=1280,
+ help="max_shape for TensorRT.")
+ parser.add_argument(
+ "--trt_opt_shape",
+ type=int,
+ default=640,
+ help="opt_shape for TensorRT.")
+ parser.add_argument(
+ "--trt_calib_mode",
+ type=bool,
+ default=False,
+ help="If the model is produced by TRT offline quantitative "
+ "calibration, trt_calib_mode need to set True.")
+ parser.add_argument(
+ '--save_images',
+ type=ast.literal_eval,
+ default=True,
+ help='Save visualization image results.')
+ parser.add_argument(
+ '--save_mot_txts',
+ action='store_true',
+ help='Save tracking results (txt).')
+ parser.add_argument(
+ '--save_mot_txt_per_img',
+ action='store_true',
+ help='Save tracking results (txt) for each image.')
+ parser.add_argument(
+ '--scaled',
+ type=bool,
+ default=False,
+        help="Whether coords after detector outputs are scaled; False for JDE YOLOv3, "
+        "True for general detectors.")
+ parser.add_argument(
+        "--tracker_config", type=str, default=None, help=("tracker config"))
+ parser.add_argument(
+ "--reid_model_dir",
+ type=str,
+ default=None,
+ help=("Directory include:'model.pdiparams', 'model.pdmodel', "
+ "'infer_cfg.yml', created by tools/export_model.py."))
+ parser.add_argument(
+ "--reid_batch_size",
+ type=int,
+ default=50,
+ help="max batch_size for reid model inference.")
+ parser.add_argument(
+ '--use_dark',
+ type=ast.literal_eval,
+ default=True,
+        help='Whether to use DarkPose to get more accurate keypoint position predictions.')
+ parser.add_argument(
+ "--action_file",
+ type=str,
+ default=None,
+ help="Path of input file for action recognition.")
+ parser.add_argument(
+ "--window_size",
+ type=int,
+ default=50,
+ help="Temporal size of skeleton feature for action recognition.")
+ parser.add_argument(
+ "--random_pad",
+ type=ast.literal_eval,
+ default=False,
+ help="Whether do random padding for action recognition.")
+ parser.add_argument(
+ "--save_results",
+ action='store_true',
+ default=False,
+ help="Whether save detection result to file using coco format")
+ parser.add_argument(
+ '--use_coco_category',
+ action='store_true',
+ default=False,
+ help='Whether to use the coco format dictionary `clsid2catid`')
+ parser.add_argument(
+ "--slice_infer",
+ action='store_true',
+ help="Whether to slice the image and merge the inference results for small object detection."
+ )
+ parser.add_argument(
+ '--slice_size',
+ nargs='+',
+ type=int,
+ default=[640, 640],
+ help="Height of the sliced image.")
+ parser.add_argument(
+ "--overlap_ratio",
+ nargs='+',
+ type=float,
+ default=[0.25, 0.25],
+ help="Overlap height ratio of the sliced image.")
+ parser.add_argument(
+ "--combine_method",
+ type=str,
+ default='nms',
+ help="Combine method of the sliced images' detection results, choose in ['nms', 'nmm', 'concat']."
+ )
+ parser.add_argument(
+ "--match_threshold",
+ type=float,
+ default=0.6,
+ help="Combine method matching threshold.")
+ parser.add_argument(
+ "--match_metric",
+ type=str,
+ default='ios',
+ help="Combine method matching metric, choose in ['iou', 'ios'].")
+ parser.add_argument(
+ "--collect_trt_shape_info",
+ action='store_true',
+ default=False,
+ help="Whether to collect dynamic shape before using tensorrt.")
+ parser.add_argument(
+ "--tuned_trt_shape_file",
+ type=str,
+ default="shape_range_info.pbtxt",
+ help="Path of a dynamic shape file for tensorrt.")
+ parser.add_argument("--use_fd_format", action="store_true")
+ parser.add_argument(
+ "--task_type",
+ type=str,
+ default='Detection',
+        help="How to save the coco result; it only works with save_results==True. Optional inputs are Rotate or Detection, default is Detection."
+ )
+ return parser
+
+
+class Times(object):
+ def __init__(self):
+ self.time = 0.
+ # start time
+ self.st = 0.
+ # end time
+ self.et = 0.
+
+ def start(self):
+ self.st = time.time()
+
+ def end(self, repeats=1, accumulative=True):
+ self.et = time.time()
+ if accumulative:
+ self.time += (self.et - self.st) / repeats
+ else:
+ self.time = (self.et - self.st) / repeats
+
+ def reset(self):
+ self.time = 0.
+ self.st = 0.
+ self.et = 0.
+
+ def value(self):
+ return round(self.time, 4)
+
+
+class Timer(Times):
+ def __init__(self, with_tracker=False):
+ super(Timer, self).__init__()
+ self.with_tracker = with_tracker
+ self.preprocess_time_s = Times()
+ self.inference_time_s = Times()
+ self.postprocess_time_s = Times()
+ self.tracking_time_s = Times()
+ self.img_num = 0
+
+ def info(self, average=False):
+ pre_time = self.preprocess_time_s.value()
+ infer_time = self.inference_time_s.value()
+ post_time = self.postprocess_time_s.value()
+ track_time = self.tracking_time_s.value()
+
+ total_time = pre_time + infer_time + post_time
+ if self.with_tracker:
+ total_time = total_time + track_time
+ total_time = round(total_time, 4)
+ print("------------------ Inference Time Info ----------------------")
+ print("total_time(ms): {}, img_num: {}".format(total_time * 1000,
+ self.img_num))
+ preprocess_time = round(pre_time / max(1, self.img_num),
+ 4) if average else pre_time
+ postprocess_time = round(post_time / max(1, self.img_num),
+ 4) if average else post_time
+ inference_time = round(infer_time / max(1, self.img_num),
+ 4) if average else infer_time
+ tracking_time = round(track_time / max(1, self.img_num),
+ 4) if average else track_time
+
+ average_latency = total_time / max(1, self.img_num)
+ qps = 0
+ if total_time > 0:
+ qps = 1 / average_latency
+ print("average latency time(ms): {:.2f}, QPS: {:2f}".format(
+ average_latency * 1000, qps))
+ if self.with_tracker:
+ print(
+ "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}, tracking_time(ms): {:.2f}".
+ format(preprocess_time * 1000, inference_time * 1000,
+ postprocess_time * 1000, tracking_time * 1000))
+ else:
+ print(
+ "preprocess_time(ms): {:.2f}, inference_time(ms): {:.2f}, postprocess_time(ms): {:.2f}".
+ format(preprocess_time * 1000, inference_time * 1000,
+ postprocess_time * 1000))
+
+ def report(self, average=False):
+ dic = {}
+ pre_time = self.preprocess_time_s.value()
+ infer_time = self.inference_time_s.value()
+ post_time = self.postprocess_time_s.value()
+ track_time = self.tracking_time_s.value()
+
+ dic['preprocess_time_s'] = round(pre_time / max(1, self.img_num),
+ 4) if average else pre_time
+ dic['inference_time_s'] = round(infer_time / max(1, self.img_num),
+ 4) if average else infer_time
+ dic['postprocess_time_s'] = round(post_time / max(1, self.img_num),
+ 4) if average else post_time
+ dic['img_num'] = self.img_num
+ total_time = pre_time + infer_time + post_time
+ if self.with_tracker:
+ dic['tracking_time_s'] = round(track_time / max(1, self.img_num),
+ 4) if average else track_time
+ total_time = total_time + track_time
+ dic['total_time_s'] = round(total_time, 4)
+ return dic
+
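+# Illustrative Timer usage (a sketch; preprocess(), predictor.run() and batch_size
+# are placeholders, not names defined in this file):
+#   timer = Timer()
+#   timer.preprocess_time_s.start(); preprocess(); timer.preprocess_time_s.end()
+#   timer.inference_time_s.start(); predictor.run(); timer.inference_time_s.end(repeats=10)
+#   timer.img_num += batch_size
+#   timer.info(average=True)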
+
+
+def multiclass_nms(bboxs, num_classes, match_threshold=0.6, match_metric='iou'):
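+    """Apply per-class NMS.
+    Args:
+        bboxs: shape [N, 6], [class, score, x1, y1, x2, y2]
+        num_classes: number of classes
+        match_threshold: overlap threshold for the match metric
+        match_metric: 'iou' or 'ios'
+    Returns:
+        list with one array per class that has detections, each of shape
+        [M, 6] with the class id restored in column 0.
+    """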
+ final_boxes = []
+ for c in range(num_classes):
+ idxs = bboxs[:, 0] == c
+ if np.count_nonzero(idxs) == 0: continue
+ r = nms(bboxs[idxs, 1:], match_threshold, match_metric)
+ final_boxes.append(np.concatenate([np.full((r.shape[0], 1), c), r], 1))
+ return final_boxes
+
+
+def nms(dets, match_threshold=0.6, match_metric='iou'):
+ """ Apply NMS to avoid detecting too many overlapping bounding boxes.
+ Args:
+ dets: shape [N, 5], [score, x1, y1, x2, y2]
+ match_metric: 'iou' or 'ios'
+ match_threshold: overlap thresh for match metric.
+ """
+ if dets.shape[0] == 0:
+ return dets[[], :]
+ scores = dets[:, 0]
+ x1 = dets[:, 1]
+ y1 = dets[:, 2]
+ x2 = dets[:, 3]
+ y2 = dets[:, 4]
+ areas = (x2 - x1 + 1) * (y2 - y1 + 1)
+ order = scores.argsort()[::-1]
+
+ ndets = dets.shape[0]
+ suppressed = np.zeros((ndets), dtype=np.int32)
+
+ for _i in range(ndets):
+ i = order[_i]
+ if suppressed[i] == 1:
+ continue
+ ix1 = x1[i]
+ iy1 = y1[i]
+ ix2 = x2[i]
+ iy2 = y2[i]
+ iarea = areas[i]
+ for _j in range(_i + 1, ndets):
+ j = order[_j]
+ if suppressed[j] == 1:
+ continue
+ xx1 = max(ix1, x1[j])
+ yy1 = max(iy1, y1[j])
+ xx2 = min(ix2, x2[j])
+ yy2 = min(iy2, y2[j])
+ w = max(0.0, xx2 - xx1 + 1)
+ h = max(0.0, yy2 - yy1 + 1)
+ inter = w * h
+ if match_metric == 'iou':
+ union = iarea + areas[j] - inter
+ match_value = inter / union
+ elif match_metric == 'ios':
+ smaller = min(iarea, areas[j])
+ match_value = inter / smaller
+ else:
+                raise ValueError("match_metric must be 'iou' or 'ios'")
+ if match_value >= match_threshold:
+ suppressed[j] = 1
+ keep = np.where(suppressed == 0)[0]
+ dets = dets[keep, :]
+ return dets
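+
+# Illustrative call (a sketch): two heavily overlapping boxes, the lower-scoring one
+# is suppressed at the default IoU threshold of 0.6:
+#   dets = np.array([[0.9, 10, 10, 50, 50], [0.8, 12, 12, 52, 52]])
+#   nms(dets)  # -> array([[ 0.9, 10., 10., 50., 50.]])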
+
+
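+# Maps the contiguous 0-79 training class ids to the original COCO category ids
+# (which skip some numbers, e.g. 12, 26 and 29).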
+coco_clsid2catid = {
+ 0: 1,
+ 1: 2,
+ 2: 3,
+ 3: 4,
+ 4: 5,
+ 5: 6,
+ 6: 7,
+ 7: 8,
+ 8: 9,
+ 9: 10,
+ 10: 11,
+ 11: 13,
+ 12: 14,
+ 13: 15,
+ 14: 16,
+ 15: 17,
+ 16: 18,
+ 17: 19,
+ 18: 20,
+ 19: 21,
+ 20: 22,
+ 21: 23,
+ 22: 24,
+ 23: 25,
+ 24: 27,
+ 25: 28,
+ 26: 31,
+ 27: 32,
+ 28: 33,
+ 29: 34,
+ 30: 35,
+ 31: 36,
+ 32: 37,
+ 33: 38,
+ 34: 39,
+ 35: 40,
+ 36: 41,
+ 37: 42,
+ 38: 43,
+ 39: 44,
+ 40: 46,
+ 41: 47,
+ 42: 48,
+ 43: 49,
+ 44: 50,
+ 45: 51,
+ 46: 52,
+ 47: 53,
+ 48: 54,
+ 49: 55,
+ 50: 56,
+ 51: 57,
+ 52: 58,
+ 53: 59,
+ 54: 60,
+ 55: 61,
+ 56: 62,
+ 57: 63,
+ 58: 64,
+ 59: 65,
+ 60: 67,
+ 61: 70,
+ 62: 72,
+ 63: 73,
+ 64: 74,
+ 65: 75,
+ 66: 76,
+ 67: 77,
+ 68: 78,
+ 69: 79,
+ 70: 80,
+ 71: 81,
+ 72: 82,
+ 73: 84,
+ 74: 85,
+ 75: 86,
+ 76: 87,
+ 77: 88,
+ 78: 89,
+ 79: 90
+}
+
+
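+# gaussian_radius follows the CornerNet/CenterNet heuristic: it solves three quadratic
+# cases and returns the smallest of the three radii, i.e. a corner offset within which
+# a shifted box is expected to keep at least `min_overlap` IoU with the original box.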
+def gaussian_radius(bbox_size, min_overlap):
+ height, width = bbox_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = np.sqrt(b1**2 - 4 * a1 * c1)
+ radius1 = (b1 + sq1) / (2 * a1)
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = np.sqrt(b2**2 - 4 * a2 * c2)
+ radius2 = (b2 + sq2) / 2
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = np.sqrt(b3**2 - 4 * a3 * c3)
+ radius3 = (b3 + sq3) / 2
+ return min(radius1, radius2, radius3)
+
+
+def gaussian2D(shape, sigma_x=1, sigma_y=1):
+ m, n = [(ss - 1.) / 2. for ss in shape]
+ y, x = np.ogrid[-m:m + 1, -n:n + 1]
+
+ h = np.exp(-(x * x / (2 * sigma_x * sigma_x) + y * y / (2 * sigma_y *
+ sigma_y)))
+ h[h < np.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def draw_umich_gaussian(heatmap, center, radius, k=1):
+ """
+ draw_umich_gaussian, refer to https://github.com/xingyizhou/CenterNet/blob/master/src/lib/utils/image.py#L126
+ """
+ diameter = 2 * radius + 1
+ gaussian = gaussian2D(
+ (diameter, diameter), sigma_x=diameter / 6, sigma_y=diameter / 6)
+
+ x, y = int(center[0]), int(center[1])
+
+ height, width = heatmap.shape[0:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian[radius - top:radius + bottom, radius - left:
+ radius + right]
+ if min(masked_gaussian.shape) > 0 and min(masked_heatmap.shape) > 0:
+ np.maximum(masked_heatmap, masked_gaussian * k, out=masked_heatmap)
+ return heatmap
+
+
+def iou(box1, box2):
+ """计算两个框的 IoU(交并比)"""
+ x1 = max(box1[0], box2[0])
+ y1 = max(box1[1], box2[1])
+ x2 = min(box1[2], box2[2])
+ y2 = min(box1[3], box2[3])
+
+ inter_area = max(0, x2 - x1) * max(0, y2 - y1)
+ box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+ box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+
+ union_area = box1_area + box2_area - inter_area
+ return inter_area / union_area if union_area != 0 else 0
+
+def non_max_suppression(boxes, scores, iou_threshold):
+ """非极大值抑制"""
+ if not boxes:
+ return []
+ indices = np.argsort(scores)[::-1]
+ selected_boxes = []
+
+ while len(indices) > 0:
+ current = indices[0]
+ selected_boxes.append(current)
+ remaining = indices[1:]
+
+ filtered_indices = []
+ for i in remaining:
+ if iou(boxes[current], boxes[i]) <= iou_threshold:
+ filtered_indices.append(i)
+
+ indices = np.array(filtered_indices)
+
+ return selected_boxes
+
+
+def is_box_inside(inner, outer):
+ """判断inner box是否完全在outer box内"""
+ return (outer[0] <= inner[0] and outer[1] <= inner[1] and
+ outer[2] >= inner[2] and outer[3] >= inner[3])
+
+
+def boxes_overlap(box1: List[float], box2: List[float]) -> bool:
+ """判断两个框是否有交集"""
+ x1_max = max(box1[0], box2[0])
+ y1_max = max(box1[1], box2[1])
+ x2_min = min(box1[2], box2[2])
+ y2_min = min(box1[3], box2[3])
+ return x1_max < x2_min and y1_max < y2_min
+
+
+def merge_boxes(boxes: List[List[float]]) -> List[float]:
+ x1 = min(box[0] for box in boxes)
+ y1 = min(box[1] for box in boxes)
+ x2 = max(box[2] for box in boxes)
+ y2 = max(box[3] for box in boxes)
+ return [x1, y1, x2, y2]
+
+
+def merge_text_and_title_boxes(data: List[LayoutBox], merged_labels: Tuple[int]) -> List[LayoutBox]:
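+    """Greedily merge vertically adjacent boxes whose clsid is in merged_labels
+    (typically text/title), scanning top-to-bottom by y1. A candidate is absorbed
+    into the current group only if the merged rectangle would not overlap any box
+    with clsid in (2, 4, 5), which are treated as blocking regions. Merged groups
+    are emitted as new LayoutBox instances with clsid 0 and the maximum confidence
+    of the group; all other boxes are returned unchanged.
+    """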
+ text_title_boxes = [(i, box) for i, box in enumerate(data) if box.clsid in merged_labels]
+ other_boxes = [box.pos for box in data if box.clsid in (2, 4, 5)]
+
+ text_title_boxes.sort(key=lambda x: x[1].pos[1]) # sort by y1
+
+ merged = []
+ skip_indices = set()
+ i = 0
+ while i < len(text_title_boxes):
+ if i in skip_indices:
+ i += 1
+ continue
+ current_group = [text_title_boxes[i][1].pos]
+ group_confidences = [text_title_boxes[i][1].confidence]
+ j = i + 1
+ while j < len(text_title_boxes):
+ candidate_box = text_title_boxes[j][1].pos
+ tentative_merge = merge_boxes(current_group + [candidate_box])
+ has_intruder = any(boxes_overlap(other, tentative_merge) for other in other_boxes)
+ if has_intruder:
+ break
+ else:
+ current_group.append(candidate_box)
+ group_confidences.append(text_title_boxes[j][1].confidence)
+ skip_indices.add(j)
+ j += 1
+ if len(current_group) > 1:
+ merged_box = LayoutBox(0, merge_boxes(current_group), max(group_confidences))
+ merged.append(merged_box)
+ else:
+ idx = text_title_boxes[i][0]
+ merged.append(data[idx])
+ i += 1
+
+    # skip_indices holds positions within text_title_boxes, not indices into data,
+    # so it must not be used to filter data here; every non-merged-label box is kept.
+    remaining = [box for box in data if box.clsid not in merged_labels]
+ return merged + remaining
diff --git a/helper/page_detection/visualize.py b/helper/page_detection/visualize.py
new file mode 100644
index 0000000..9e96c29
--- /dev/null
+++ b/helper/page_detection/visualize.py
@@ -0,0 +1,649 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+
+import os
+import cv2
+import math
+import numpy as np
+import PIL
+from PIL import Image, ImageDraw, ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+def imagedraw_textsize_c(draw, text):
+ if int(PIL.__version__.split('.')[0]) < 10:
+ tw, th = draw.textsize(text)
+ else:
+ left, top, right, bottom = draw.textbbox((0, 0), text)
+ tw, th = right - left, bottom - top
+
+ return tw, th
+
+
+def visualize_box_mask(im, results, labels, threshold=0.5):
+ """
+ Args:
+ im (str/np.ndarray): path of image/np.ndarray read by cv2
+        results (dict): include 'boxes': np.ndarray: shape:[N,6], N: number of box,
+                        matrix element:[class, score, x_min, y_min, x_max, y_max]
+ MaskRCNN's results include 'masks': np.ndarray:
+ shape:[N, im_h, im_w]
+ labels (list): labels:['class1', ..., 'classn']
+ threshold (float): Threshold of score.
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ if isinstance(im, str):
+ im = Image.open(im).convert('RGB')
+ elif isinstance(im, np.ndarray):
+ im = Image.fromarray(im)
+ if 'masks' in results and 'boxes' in results and len(results['boxes']) > 0:
+ im = draw_mask(
+ im, results['boxes'], results['masks'], labels, threshold=threshold)
+ if 'boxes' in results and len(results['boxes']) > 0:
+ im = draw_box(im, results['boxes'], labels, threshold=threshold)
+ if 'segm' in results:
+ im = draw_segm(
+ im,
+ results['segm'],
+ results['label'],
+ results['score'],
+ labels,
+ threshold=threshold)
+ return im
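+
+# Illustrative call (a sketch; 'results' and 'labels' must come from the detector,
+# e.g. the layout model's label_list):
+#   vis = visualize_box_mask('page_1.jpg', {'boxes': np_boxes}, labels, threshold=0.5)
+#   vis.save('page_1_vis.jpg')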
+
+
+def get_color_map_list(num_classes):
+ """
+ Args:
+ num_classes (int): number of class
+ Returns:
+ color_map (list): RGB color list
+ """
+ color_map = num_classes * [0, 0, 0]
+ for i in range(0, num_classes):
+ j = 0
+ lab = i
+ while lab:
+ color_map[i * 3] |= (((lab >> 0) & 1) << (7 - j))
+ color_map[i * 3 + 1] |= (((lab >> 1) & 1) << (7 - j))
+ color_map[i * 3 + 2] |= (((lab >> 2) & 1) << (7 - j))
+ j += 1
+ lab >>= 3
+ color_map = [color_map[i:i + 3] for i in range(0, len(color_map), 3)]
+ return color_map
+
+
+def draw_mask(im, np_boxes, np_masks, labels, threshold=0.5):
+ """
+ Args:
+ im (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+                               matrix element:[class, score, x_min, y_min, x_max, y_max]
+ np_masks (np.ndarray): shape:[N, im_h, im_w]
+ labels (list): labels:['class1', ..., 'classn']
+ threshold (float): threshold of mask
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ color_list = get_color_map_list(len(labels))
+ w_ratio = 0.4
+ alpha = 0.7
+ im = np.array(im).astype('float32')
+ clsid2color = {}
+ expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+ np_boxes = np_boxes[expect_boxes, :]
+ np_masks = np_masks[expect_boxes, :, :]
+ im_h, im_w = im.shape[:2]
+ np_masks = np_masks[:, :im_h, :im_w]
+ for i in range(len(np_masks)):
+ clsid, score = int(np_boxes[i][0]), np_boxes[i][1]
+ mask = np_masks[i]
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color_mask = clsid2color[clsid]
+ for c in range(3):
+ color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+ idx = np.nonzero(mask)
+ color_mask = np.array(color_mask)
+ im[idx[0], idx[1], :] *= 1.0 - alpha
+ im[idx[0], idx[1], :] += alpha * color_mask
+ return Image.fromarray(im.astype('uint8'))
+
+
+def draw_box(im, np_boxes, labels, threshold=0.5):
+ """
+ Args:
+ im (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+                               matrix element:[class, score, x_min, y_min, x_max, y_max]
+ labels (list): labels:['class1', ..., 'classn']
+ threshold (float): threshold of box
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+ clsid2color = {}
+ color_list = get_color_map_list(len(labels))
+ expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+ np_boxes = np_boxes[expect_boxes, :]
+
+ for dt in np_boxes:
+ clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color = tuple(clsid2color[clsid])
+
+ if len(bbox) == 4:
+ xmin, ymin, xmax, ymax = bbox
+ print('class_id:{:d}, confidence:{:.4f}, left_top:[{:.2f},{:.2f}],'
+ 'right_bottom:[{:.2f},{:.2f}]'.format(
+ int(clsid), score, xmin, ymin, xmax, ymax))
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=draw_thickness,
+ fill=color)
+ elif len(bbox) == 8:
+ x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+ draw.line(
+ [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+ width=2,
+ fill=color)
+ xmin = min(x1, x2, x3, x4)
+ ymin = min(y1, y2, y3, y4)
+
+ # draw label
+ text = "{} {:.4f}".format(labels[clsid], score)
+ tw, th = imagedraw_textsize_c(draw, text)
+ draw.rectangle(
+ [(xmin + 1, ymin - th), (xmin + tw + 1, ymin)], fill=color)
+ draw.text((xmin + 1, ymin - th), text, fill=(255, 255, 255))
+ return im
+
+
+def draw_segm(im,
+ np_segms,
+ np_label,
+ np_score,
+ labels,
+ threshold=0.5,
+ alpha=0.7):
+ """
+ Draw segmentation on image
+ """
+ mask_color_id = 0
+ w_ratio = .4
+ color_list = get_color_map_list(len(labels))
+ im = np.array(im).astype('float32')
+ clsid2color = {}
+ np_segms = np_segms.astype(np.uint8)
+ for i in range(np_segms.shape[0]):
+ mask, score, clsid = np_segms[i], np_score[i], np_label[i]
+ if score < threshold:
+ continue
+
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color_mask = clsid2color[clsid]
+ for c in range(3):
+ color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+ idx = np.nonzero(mask)
+ color_mask = np.array(color_mask)
+ idx0 = np.minimum(idx[0], im.shape[0] - 1)
+ idx1 = np.minimum(idx[1], im.shape[1] - 1)
+ im[idx0, idx1, :] *= 1.0 - alpha
+ im[idx0, idx1, :] += alpha * color_mask
+ sum_x = np.sum(mask, axis=0)
+ x = np.where(sum_x > 0.5)[0]
+ sum_y = np.sum(mask, axis=1)
+ y = np.where(sum_y > 0.5)[0]
+ x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
+ cv2.rectangle(im, (x0, y0), (x1, y1),
+ tuple(color_mask.astype('int32').tolist()), 1)
+ bbox_text = '%s %.2f' % (labels[clsid], score)
+ t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
+ cv2.rectangle(im, (x0, y0), (x0 + t_size[0], y0 - t_size[1] - 3),
+ tuple(color_mask.astype('int32').tolist()), -1)
+ cv2.putText(
+ im,
+ bbox_text, (x0, y0 - 2),
+ cv2.FONT_HERSHEY_SIMPLEX,
+ 0.3, (0, 0, 0),
+ 1,
+ lineType=cv2.LINE_AA)
+ return Image.fromarray(im.astype('uint8'))
+
+
+def get_color(idx):
+ idx = idx * 3
+ color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
+ return color
+
+
+def visualize_pose(imgfile,
+ results,
+ visual_thresh=0.6,
+ save_name='pose.jpg',
+ save_dir='output',
+ returnimg=False,
+ ids=None):
+ try:
+ import matplotlib.pyplot as plt
+ import matplotlib
+ plt.switch_backend('agg')
+ except Exception as e:
+        print('Matplotlib not found, please install matplotlib, '
+              'for example: `pip install matplotlib`.')
+ raise e
+ skeletons, scores = results['keypoint']
+ skeletons = np.array(skeletons)
+ kpt_nums = 17
+ if len(skeletons) > 0:
+ kpt_nums = skeletons.shape[1]
+ if kpt_nums == 17: #plot coco keypoint
+ EDGES = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (5, 7), (6, 8),
+ (7, 9), (8, 10), (5, 11), (6, 12), (11, 13), (12, 14),
+ (13, 15), (14, 16), (11, 12)]
+ else: #plot mpii keypoint
+ EDGES = [(0, 1), (1, 2), (3, 4), (4, 5), (2, 6), (3, 6), (6, 7), (7, 8),
+ (8, 9), (10, 11), (11, 12), (13, 14), (14, 15), (8, 12),
+ (8, 13)]
+ NUM_EDGES = len(EDGES)
+
+ colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
+ cmap = matplotlib.cm.get_cmap('hsv')
+ plt.figure()
+
+ img = cv2.imread(imgfile) if type(imgfile) == str else imgfile
+
+ color_set = results['colors'] if 'colors' in results else None
+
+ if 'bbox' in results and ids is None:
+ bboxs = results['bbox']
+ for j, rect in enumerate(bboxs):
+ xmin, ymin, xmax, ymax = rect
+ color = colors[0] if color_set is None else colors[color_set[j] %
+ len(colors)]
+ cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 1)
+
+ canvas = img.copy()
+ for i in range(kpt_nums):
+ for j in range(len(skeletons)):
+ if skeletons[j][i, 2] < visual_thresh:
+ continue
+ if ids is None:
+ color = colors[i] if color_set is None else colors[color_set[j]
+ %
+ len(colors)]
+ else:
+ color = get_color(ids[j])
+
+ cv2.circle(
+ canvas,
+ tuple(skeletons[j][i, 0:2].astype('int32')),
+ 2,
+ color,
+ thickness=-1)
+
+ to_plot = cv2.addWeighted(img, 0.3, canvas, 0.7, 0)
+ fig = matplotlib.pyplot.gcf()
+
+ stickwidth = 2
+
+ for i in range(NUM_EDGES):
+ for j in range(len(skeletons)):
+ edge = EDGES[i]
+ if skeletons[j][edge[0], 2] < visual_thresh or skeletons[j][edge[
+ 1], 2] < visual_thresh:
+ continue
+
+ cur_canvas = canvas.copy()
+ X = [skeletons[j][edge[0], 1], skeletons[j][edge[1], 1]]
+ Y = [skeletons[j][edge[0], 0], skeletons[j][edge[1], 0]]
+ mX = np.mean(X)
+ mY = np.mean(Y)
+ length = ((X[0] - X[1])**2 + (Y[0] - Y[1])**2)**0.5
+ angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
+ polygon = cv2.ellipse2Poly((int(mY), int(mX)),
+ (int(length / 2), stickwidth),
+ int(angle), 0, 360, 1)
+ if ids is None:
+ color = colors[i] if color_set is None else colors[color_set[j]
+ %
+ len(colors)]
+ else:
+ color = get_color(ids[j])
+ cv2.fillConvexPoly(cur_canvas, polygon, color)
+ canvas = cv2.addWeighted(canvas, 0.4, cur_canvas, 0.6, 0)
+ if returnimg:
+ return canvas
+ save_name = os.path.join(
+ save_dir, os.path.splitext(os.path.basename(imgfile))[0] + '_vis.jpg')
+ plt.imsave(save_name, canvas[:, :, ::-1])
+ print("keypoint visualize image saved to: " + save_name)
+ plt.close()
+
+
+def visualize_attr(im, results, boxes=None, is_mtmct=False):
+ if isinstance(im, str):
+ im = Image.open(im)
+ im = np.ascontiguousarray(np.copy(im))
+ im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+ else:
+ im = np.ascontiguousarray(np.copy(im))
+
+ im_h, im_w = im.shape[:2]
+ text_scale = max(0.5, im.shape[0] / 3000.)
+ text_thickness = 1
+
+ line_inter = im.shape[0] / 40.
+ for i, res in enumerate(results):
+ if boxes is None:
+ text_w = 3
+ text_h = 1
+ elif is_mtmct:
+ box = boxes[i] # multi camera, bbox shape is x,y, w,h
+ text_w = int(box[0]) + 3
+ text_h = int(box[1])
+ else:
+ box = boxes[i] # single camera, bbox shape is 0, 0, x,y, w,h
+ text_w = int(box[2]) + 3
+ text_h = int(box[3])
+ for text in res:
+ text_h += int(line_inter)
+ text_loc = (text_w, text_h)
+ cv2.putText(
+ im,
+ text,
+ text_loc,
+ cv2.FONT_ITALIC,
+ text_scale, (0, 255, 255),
+ thickness=text_thickness)
+ return im
+
+
+def visualize_action(im,
+ mot_boxes,
+ action_visual_collector=None,
+ action_text="",
+ video_action_score=None,
+ video_action_text=""):
+ im = cv2.imread(im) if isinstance(im, str) else im
+ im_h, im_w = im.shape[:2]
+
+ text_scale = max(1, im.shape[1] / 400.)
+ text_thickness = 2
+
+ if action_visual_collector:
+ id_action_dict = {}
+ for collector, action_type in zip(action_visual_collector, action_text):
+ id_detected = collector.get_visualize_ids()
+ for pid in id_detected:
+ id_action_dict[pid] = id_action_dict.get(pid, [])
+ id_action_dict[pid].append(action_type)
+ for mot_box in mot_boxes:
+ # mot_box is a format with [mot_id, class, score, xmin, ymin, w, h]
+ if mot_box[0] in id_action_dict:
+ text_position = (int(mot_box[3] + mot_box[5] * 0.75),
+ int(mot_box[4] - 10))
+ display_text = ', '.join(id_action_dict[mot_box[0]])
+ cv2.putText(im, display_text, text_position,
+ cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255), 2)
+
+ if video_action_score:
+ cv2.putText(
+ im,
+ video_action_text + ': %.2f' % video_action_score,
+ (int(im_w / 2), int(15 * text_scale) + 5),
+ cv2.FONT_ITALIC,
+ text_scale, (0, 0, 255),
+ thickness=text_thickness)
+
+ return im
+
+
+def visualize_vehicleplate(im, results, boxes=None):
+ if isinstance(im, str):
+ im = Image.open(im)
+ im = np.ascontiguousarray(np.copy(im))
+ im = cv2.cvtColor(im, cv2.COLOR_RGB2BGR)
+ else:
+ im = np.ascontiguousarray(np.copy(im))
+
+ im_h, im_w = im.shape[:2]
+ text_scale = max(1.0, im.shape[0] / 400.)
+ text_thickness = 2
+
+ line_inter = im.shape[0] / 40.
+ for i, res in enumerate(results):
+ if boxes is None:
+ text_w = 3
+ text_h = 1
+ else:
+ box = boxes[i]
+ text = res
+ if text == "":
+ continue
+ text_w = int(box[2])
+ text_h = int(box[5] + box[3])
+ text_loc = (text_w, text_h)
+ cv2.putText(
+ im,
+ "LP: " + text,
+ text_loc,
+ cv2.FONT_ITALIC,
+ text_scale, (0, 255, 255),
+ thickness=text_thickness)
+ return im
+
+
+def draw_press_box_lanes(im, np_boxes, labels, threshold=0.5):
+ """
+ Args:
+ im (PIL.Image.Image): PIL image
+        np_boxes (np.ndarray): shape:[N,6], N: number of box,
+                               matrix element:[class, score, x_min, y_min, x_max, y_max]
+ labels (list): labels:['class1', ..., 'classn']
+ threshold (float): threshold of box
+ Returns:
+ im (PIL.Image.Image): visualized image
+ """
+
+ if isinstance(im, str):
+ im = Image.open(im).convert('RGB')
+ elif isinstance(im, np.ndarray):
+ im = Image.fromarray(im)
+
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+ clsid2color = {}
+ color_list = get_color_map_list(len(labels))
+
+ if np_boxes.shape[1] == 7:
+ np_boxes = np_boxes[:, 1:]
+
+ expect_boxes = (np_boxes[:, 1] > threshold) & (np_boxes[:, 0] > -1)
+ np_boxes = np_boxes[expect_boxes, :]
+
+ for dt in np_boxes:
+ clsid, bbox, score = int(dt[0]), dt[2:], dt[1]
+ if clsid not in clsid2color:
+ clsid2color[clsid] = color_list[clsid]
+ color = tuple(clsid2color[clsid])
+
+ if len(bbox) == 4:
+ xmin, ymin, xmax, ymax = bbox
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=draw_thickness,
+ fill=(0, 0, 255))
+ elif len(bbox) == 8:
+ x1, y1, x2, y2, x3, y3, x4, y4 = bbox
+ draw.line(
+ [(x1, y1), (x2, y2), (x3, y3), (x4, y4), (x1, y1)],
+ width=2,
+ fill=color)
+ xmin = min(x1, x2, x3, x4)
+ ymin = min(y1, y2, y3, y4)
+
+ # draw label
+ text = "{}".format(labels[clsid])
+ tw, th = imagedraw_textsize_c(draw, text)
+ draw.rectangle(
+ [(xmin + 1, ymax - th), (xmin + tw + 1, ymax)], fill=color)
+ draw.text((xmin + 1, ymax - th), text, fill=(0, 0, 255))
+ return im
+
+
+def visualize_vehiclepress(im, results, threshold=0.5):
+ results = np.array(results)
+ labels = ['violation']
+ im = draw_press_box_lanes(im, results, labels, threshold=threshold)
+ return im
+
+
+def visualize_lane(im, lanes):
+ if isinstance(im, str):
+ im = Image.open(im).convert('RGB')
+ elif isinstance(im, np.ndarray):
+ im = Image.fromarray(im)
+
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+
+ if len(lanes) > 0:
+ for lane in lanes:
+ draw.line(
+ [(lane[0], lane[1]), (lane[2], lane[3])],
+ width=draw_thickness,
+ fill=(0, 0, 255))
+
+ return im
+
+
+def visualize_vehicle_retrograde(im, mot_res, vehicle_retrograde_res):
+ if isinstance(im, str):
+ im = Image.open(im).convert('RGB')
+ elif isinstance(im, np.ndarray):
+ im = Image.fromarray(im)
+
+ draw_thickness = min(im.size) // 320
+ draw = ImageDraw.Draw(im)
+
+ lane = vehicle_retrograde_res['fence_line']
+ if lane is not None:
+ draw.line(
+ [(lane[0], lane[1]), (lane[2], lane[3])],
+ width=draw_thickness,
+ fill=(0, 0, 0))
+
+ mot_id = vehicle_retrograde_res['output']
+ if mot_id is None or len(mot_id) == 0:
+ return im
+
+ if mot_res is None:
+ return im
+ np_boxes = mot_res['boxes']
+
+ if np_boxes is not None:
+ for dt in np_boxes:
+ if dt[0] not in mot_id:
+ continue
+ bbox = dt[3:]
+ if len(bbox) == 4:
+ xmin, ymin, xmax, ymax = bbox
+ # draw bbox
+ draw.line(
+ [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin),
+ (xmin, ymin)],
+ width=draw_thickness,
+ fill=(0, 255, 0))
+
+ # draw label
+ text = "retrograde"
+ tw, th = imagedraw_textsize_c(draw, text)
+ draw.rectangle(
+ [(xmax + 1, ymin - th), (xmax + tw + 1, ymin)],
+ fill=(0, 255, 0))
+ draw.text((xmax + 1, ymin - th), text, fill=(0, 255, 0))
+
+ return im
+
+
+COLORS = [
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (255, 255, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (128, 255, 0),
+ (255, 128, 0),
+ (128, 0, 255),
+ (255, 0, 128),
+ (0, 128, 255),
+ (0, 255, 128),
+ (128, 255, 255),
+ (255, 128, 255),
+ (255, 255, 128),
+ (60, 180, 0),
+ (180, 60, 0),
+ (0, 60, 180),
+ (0, 180, 60),
+ (60, 0, 180),
+ (180, 0, 60),
+ (255, 0, 0),
+ (0, 255, 0),
+ (0, 0, 255),
+ (255, 255, 0),
+ (255, 0, 255),
+ (0, 255, 255),
+ (128, 255, 0),
+ (255, 128, 0),
+ (128, 0, 255),
+]
+
+
+def imshow_lanes(img, lanes, show=False, out_file=None, width=4):
+ lanes_xys = []
+ for _, lane in enumerate(lanes):
+ xys = []
+ for x, y in lane:
+ if x <= 0 or y <= 0:
+ continue
+ x, y = int(x), int(y)
+ xys.append((x, y))
+ lanes_xys.append(xys)
+ lanes_xys.sort(key=lambda xys: xys[0][0] if len(xys) > 0 else 0)
+
+ for idx, xys in enumerate(lanes_xys):
+ for i in range(1, len(xys)):
+ cv2.line(img, xys[i - 1], xys[i], COLORS[idx], thickness=width)
+
+ if show:
+ cv2.imshow('view', img)
+ cv2.waitKey(0)
+
+ if out_file:
+ if not os.path.exists(os.path.dirname(out_file)):
+ os.makedirs(os.path.dirname(out_file))
+ cv2.imwrite(out_file, img)
\ No newline at end of file
diff --git a/magic-pdf.json b/magic-pdf.json
new file mode 100644
index 0000000..574153e
--- /dev/null
+++ b/magic-pdf.json
@@ -0,0 +1,62 @@
+{
+ "bucket_info": {
+ "bucket-name-1": [
+ "ak",
+ "sk",
+ "endpoint"
+ ],
+ "bucket-name-2": [
+ "ak",
+ "sk",
+ "endpoint"
+ ]
+ },
+ "models-dir": "/root/.cache/modelscope/hub/models/opendatalab/PDF-Extract-Kit-1___0/models",
+ "layoutreader-model-dir": "/root/.cache/modelscope/hub/models/ppaanngggg/layoutreader",
+ "device-mode": "cuda",
+ "layout-config": {
+ "model": "doclayout_yolo"
+ },
+ "formula-config": {
+ "mfd_model": "yolo_v8_mfd",
+ "mfr_model": "unimernet_small",
+ "enable": false
+ },
+ "table-config": {
+ "model": "rapid_table",
+ "sub_model": "slanet_plus",
+ "enable": false,
+ "max_time": 400
+ },
+ "latex-delimiter-config": {
+ "display": {
+ "left": "$$",
+ "right": "$$"
+ },
+ "inline": {
+ "left": "$",
+ "right": "$"
+ }
+ },
+ "llm-aided-config": {
+ "formula_aided": {
+ "api_key": "your_api_key",
+ "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ "model": "qwen2.5-7b-instruct",
+ "enable": false
+ },
+ "text_aided": {
+ "api_key": "your_api_key",
+ "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ "model": "qwen2.5-7b-instruct",
+ "enable": false
+ },
+ "title_aided": {
+ "api_key": "your_api_key",
+ "base_url": "https://dashscope.aliyuncs.com/compatible-mode/v1",
+ "model": "qwen2.5-32b-instruct",
+ "enable": false
+ }
+ },
+ "config_version": "1.2.1"
+}
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml
new file mode 100644
index 0000000..01a1a7c
--- /dev/null
+++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/infer_cfg.yml
@@ -0,0 +1,48 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: PicoDet
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+ keep_ratio: false
+ target_size:
+ - 800
+ - 608
+ type: Resize
+- is_scale: true
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ type: NormalizeImage
+- type: Permute
+- stride: 32
+ type: PadStride
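+# Preprocess summary: resize to 800x608 (aspect ratio not kept), normalize with
+# ImageNet mean/std, HWC -> CHW (Permute), then pad both sides to a multiple of 32.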
+label_list:
+- Text
+- Title
+- Figure
+- Figure caption
+- Table
+- Table caption
+- Header
+- Footer
+- Reference
+- Equation
+NMS:
+ keep_top_k: 100
+ name: MultiClassNMS
+ nms_threshold: 0.5
+ nms_top_k: 1000
+ score_threshold: 0.3
+fpn_stride:
+- 8
+- 16
+- 32
+- 64
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams
new file mode 100644
index 0000000..956d52f
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info
new file mode 100644
index 0000000..960eaba
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdiparams.info differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel
new file mode 100644
index 0000000..2ba4d31
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_cdla_infer/model.pdmodel differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml
new file mode 100644
index 0000000..d80c7c2
--- /dev/null
+++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/infer_cfg.yml
@@ -0,0 +1,43 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: PicoDet
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+ keep_ratio: false
+ target_size:
+ - 800
+ - 608
+ type: Resize
+- is_scale: true
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ type: NormalizeImage
+- type: Permute
+- stride: 32
+ type: PadStride
+label_list:
+- text
+- title
+- list
+- table
+- figure
+NMS:
+ keep_top_k: 100
+ name: MultiClassNMS
+ nms_threshold: 0.5
+ nms_top_k: 1000
+ score_threshold: 0.3
+fpn_stride:
+- 8
+- 16
+- 32
+- 64
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams
new file mode 100644
index 0000000..64be068
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info
new file mode 100644
index 0000000..960eaba
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdiparams.info differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel
new file mode 100644
index 0000000..6f6d382
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_infer/model.pdmodel differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml
new file mode 100644
index 0000000..bc6bba7
--- /dev/null
+++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/infer_cfg.yml
@@ -0,0 +1,39 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: PicoDet
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+ keep_ratio: false
+ target_size:
+ - 800
+ - 608
+ type: Resize
+- is_scale: true
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ type: NormalizeImage
+- type: Permute
+- stride: 32
+ type: PadStride
+label_list:
+- table
+NMS:
+ keep_top_k: 100
+ name: MultiClassNMS
+ nms_threshold: 0.5
+ nms_top_k: 1000
+ score_threshold: 0.3
+fpn_stride:
+- 8
+- 16
+- 32
+- 64
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams
new file mode 100644
index 0000000..cfd220e
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info
new file mode 100644
index 0000000..960eaba
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdiparams.info differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel
new file mode 100644
index 0000000..50135d1
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_fgd_layout_table_infer/model.pdmodel differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml
new file mode 100644
index 0000000..d80c7c2
--- /dev/null
+++ b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/infer_cfg.yml
@@ -0,0 +1,43 @@
+mode: paddle
+draw_threshold: 0.5
+metric: COCO
+use_dynamic_shape: false
+arch: PicoDet
+min_subgraph_size: 3
+Preprocess:
+- interp: 2
+ keep_ratio: false
+ target_size:
+ - 800
+ - 608
+ type: Resize
+- is_scale: true
+ mean:
+ - 0.485
+ - 0.456
+ - 0.406
+ std:
+ - 0.229
+ - 0.224
+ - 0.225
+ type: NormalizeImage
+- type: Permute
+- stride: 32
+ type: PadStride
+label_list:
+- text
+- title
+- list
+- table
+- figure
+NMS:
+ keep_top_k: 100
+ name: MultiClassNMS
+ nms_threshold: 0.5
+ nms_top_k: 1000
+ score_threshold: 0.3
+fpn_stride:
+- 8
+- 16
+- 32
+- 64
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams
new file mode 100644
index 0000000..84a3e8b
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info
new file mode 100644
index 0000000..960eaba
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdiparams.info differ
diff --git a/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel
new file mode 100644
index 0000000..56a692c
Binary files /dev/null and b/models/PaddleDetection/inference_model/picodet_lcnet_x1_0_layout_infer/model.pdmodel differ
diff --git a/pipeline.py b/pipeline.py
new file mode 100644
index 0000000..ccf62e4
--- /dev/null
+++ b/pipeline.py
@@ -0,0 +1,94 @@
+from dotenv import load_dotenv
+import os
+
+env = os.environ.get('env', 'dev')
+load_dotenv(dotenv_path='.env.dev' if env == 'dev' else '.env', override=True)
+
+import time
+import traceback
+import cv2
+from helper.image_helper import pdf2image, image_orient_cls, page_detection_visual
+from helper.page_detection.main import layout_analysis
+from helper.content_recognition.main import rec
+from helper.db_helper import insert_pdf2md_table
+import tempfile
+from loguru import logger
+import datetime
+import shutil
+
+
+def _pdf2markdown_pipeline(pdf_path, tmp_dir):
+ start_time = time.time()
+ # 1. pdf -> images
+ t1 = time.time()
+ pdf2image(pdf_path, tmp_dir)
+ t2 = time.time()
+
+    # 2. Image orientation classification
+ t3 = time.time()
+ orient_cls_results = image_orient_cls(tmp_dir)
+ t4 = time.time()
+ for r in orient_cls_results:
+ clsid = r[0]['class_ids'][0]
+ filename = r[0]['filename']
+ if clsid == 1 or clsid == 3:
+ img = cv2.imread(filename)
+ img = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
+ cv2.imwrite(filename, img)
+
+ filepaths = os.listdir(tmp_dir)
+ filepaths.sort(key=lambda x: int(x.split('/')[-1].split('.')[0]))
+ filepaths = [f'{tmp_dir}/{_}' for _ in filepaths]
+
+ # filepaths = filepaths[:75]
+
+    # 3. Layout analysis
+ t5 = time.time()
+ layout_detection_results = layout_analysis(filepaths)
+ t6 = time.time()
+
+    # 3.1 Visualization (optional, controlled by the VISUAL env var)
+ if int(os.environ['VISUAL']):
+ visual_dir = './visual_images'
+ for f in os.listdir(visual_dir):
+ if f.endswith('.jpg'):
+ os.remove(f'{visual_dir}/{f}')
+ for i in range(len(layout_detection_results)):
+ vis_img = page_detection_visual(layout_detection_results[i])
+ cv2.imwrite(f'{visual_dir}/{i + 1}.jpg', vis_img)
+
+    # 4. Content recognition
+ t7 = time.time()
+ layout_recognition_results = rec(layout_detection_results, tmp_dir)
+ t8 = time.time()
+
+ end_time = time.time()
+    logger.info(f'{pdf_path} analysis completed in {round(end_time - start_time, 3)} seconds, including {round(t2 - t1, 3)} seconds for pdf to image, {round(t4 - t3, 3)} seconds for image orientation classification, {round(t6 - t5, 3)} seconds for page detection, and {round(t8 - t7, 3)} seconds for content recognition, page number: {len(filepaths)}')
+
+ return layout_recognition_results
+
+
+def pdf2markdown_pipeline(pdf_path: str):
+ pdf_name = pdf_path.split('/')[-1]
+ start_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
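+    # process_status codes as used below: 0 = not processed yet, 2 = success, 3 = failed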
+ process_status = 0
+ tmp_dir = tempfile.mkdtemp()
+ try:
+ results = _pdf2markdown_pipeline(pdf_path, tmp_dir)
+ except Exception:
+        logger.error(f'pdf analysis error!\n{traceback.format_exc()}')
+ process_status = 3
+ end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+ insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, None)
+ pdf_id = None
+ else:
+ process_status = 2
+ end_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
+ pdf_id = insert_pdf2md_table(pdf_path, pdf_name, process_status, start_time, end_time, results)
+ finally:
+ shutil.rmtree(tmp_dir)
+ return process_status, pdf_id
+
+
+if __name__ == '__main__':
+ pdf2markdown_pipeline('/mnt/pdf2markdown/demo.pdf')
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..db03014
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,189 @@
+albucore==0.0.24
+albumentations==2.0.6
+annotated-types==0.7.0
+anthropic==0.46.0
+antlr4-python3-runtime==4.9.3
+anyio==4.9.0
+astor==0.8.1
+babel==2.17.0
+bce-python-sdk==0.9.29
+beautifulsoup4==4.13.4
+blinker==1.9.0
+boto3==1.38.8
+botocore==1.38.8
+Brotli==1.1.0
+cachetools==5.5.2
+certifi==2025.4.26
+cffi==1.17.1
+cfgv==3.4.0
+charset-normalizer==3.4.1
+click==8.1.8
+cobble==0.1.4
+coloredlogs==15.0.1
+colorlog==6.9.0
+contourpy==1.3.2
+cryptography==44.0.3
+cssselect2==0.8.0
+cycler==0.12.1
+Cython==3.0.12
+decorator==5.2.1
+dill==0.4.0
+distlib==0.3.9
+distro==1.9.0
+doclayout_yolo==0.0.2b1
+easydict==1.13
+EbookLib==0.18
+et_xmlfile==2.0.0
+faiss-cpu==1.8.0.post1
+fast-langdetect==0.2.5
+fasttext-predict==0.9.2.4
+filelock==3.18.0
+filetype==1.2.0
+fire==0.7.0
+Flask==3.1.0
+flask-babel==4.0.0
+flatbuffers==25.2.10
+fonttools==4.57.0
+fsspec==2025.3.2
+ftfy==6.3.1
+future==1.0.0
+gast==0.3.3
+google-auth==2.39.0
+google-genai==1.13.0
+h11==0.16.0
+httpcore==1.0.9
+httpx==0.28.1
+huggingface-hub==0.30.2
+humanfriendly==10.0
+identify==2.6.10
+idna==3.10
+imageio==2.37.0
+imgaug==0.4.0
+itsdangerous==2.2.0
+Jinja2==3.1.6
+jiter==0.9.0
+jmespath==1.0.1
+joblib==1.4.2
+kiwisolver==1.4.8
+lazy_loader==0.4
+lmdb==1.6.2
+loguru==0.7.3
+lxml==5.4.0
+magic-pdf==1.3.10
+mammoth==1.9.0
+markdown2==2.5.3
+markdownify==0.13.1
+marker-pdf==1.6.2
+MarkupSafe==3.0.2
+matplotlib==3.10.1
+modelscope==1.25.0
+mpmath==1.3.0
+networkx==3.4.2
+nodeenv==1.9.1
+numpy==1.24.4
+nvidia-cublas-cu12==12.4.5.8
+nvidia-cuda-cupti-cu12==12.4.127
+nvidia-cuda-nvrtc-cu12==12.4.127
+nvidia-cuda-runtime-cu12==12.4.127
+nvidia-cudnn-cu12==9.1.0.70
+nvidia-cufft-cu12==11.2.1.3
+nvidia-curand-cu12==10.3.5.147
+nvidia-cusolver-cu12==11.6.1.9
+nvidia-cusparse-cu12==12.3.1.170
+nvidia-cusparselt-cu12==0.6.2
+nvidia-nccl-cu12==2.21.5
+nvidia-nvjitlink-cu12==12.4.127
+nvidia-nvtx-cu12==12.4.127
+omegaconf==2.3.0
+onnxruntime==1.21.1
+openai==1.77.0
+opencv-contrib-python==4.11.0.86
+opencv-python==4.6.0.66
+opencv-python-headless==4.11.0.86
+openpyxl==3.1.5
+opt-einsum==3.3.0
+packaging==25.0
+paddleclas==2.5.2
+paddleocr==2.10.0
+paddlepaddle-gpu==2.6.2
+pandas==2.2.3
+pdf2image==1.17.0
+pdfminer.six==20250324
+pdftext==0.6.2
+pillow==10.4.0
+platformdirs==4.3.7
+pre_commit==4.2.0
+prettytable==3.16.0
+protobuf==6.30.2
+psutil==7.0.0
+psycopg2==2.9.10
+py-cpuinfo==9.0.0
+pyasn1==0.6.1
+pyasn1_modules==0.4.2
+pyclipper==1.3.0.post6
+pycparser==2.22
+pycryptodome==3.22.0
+pydantic==2.10.6
+pydantic-settings==2.9.1
+pydantic_core==2.27.2
+pydyf==0.11.0
+PyMuPDF==1.24.14
+pyparsing==3.2.3
+pypdfium2==4.30.0
+pyphen==0.17.2
+python-dateutil==2.9.0.post0
+python-docx==1.1.2
+python-dotenv==1.1.0
+python-pptx==1.0.2
+pytz==2025.2
+PyYAML==6.0.2
+rapid-table==1.0.5
+RapidFuzz==3.13.0
+rapidocr==2.0.7
+rarfile==4.2
+regex==2024.11.6
+requests==2.32.3
+robust-downloader==0.0.2
+rsa==4.9.1
+s3transfer==0.12.0
+safetensors==0.5.3
+scikit-image==0.25.2
+scikit-learn==1.6.1
+scipy==1.15.2
+seaborn==0.13.2
+shapely==2.1.0
+simsimd==6.2.1
+six==1.17.0
+sniffio==1.3.1
+soupsieve==2.7
+stringzilla==3.12.5
+surya-ocr==0.13.1
+sympy==1.13.1
+termcolor==3.1.0
+thop==0.1.1.post2209072238
+threadpoolctl==3.6.0
+tifffile==2025.3.30
+tinycss2==1.4.0
+tinyhtml5==2.0.0
+tokenizers==0.21.1
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+transformers==4.51.3
+triton==3.2.0
+typing-inspection==0.4.0
+typing_extensions==4.13.2
+tzdata==2025.2
+ujson==5.10.0
+ultralytics==8.3.127
+ultralytics-thop==2.0.14
+urllib3==2.4.0
+virtualenv==20.31.1
+visualdl==2.5.3
+wcwidth==0.2.13
+weasyprint==63.1
+webencodings==0.5.1
+websockets==15.0.1
+Werkzeug==3.1.3
+XlsxWriter==3.2.3
+zopfli==0.2.3.post1
diff --git a/server.py b/server.py
new file mode 100644
index 0000000..17d28ae
--- /dev/null
+++ b/server.py
@@ -0,0 +1,16 @@
+from flask import Flask, request
+import requests
+from pipeline import pdf2markdown_pipeline
+
+
+app = Flask(__name__)
+
+
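+# Illustrative request (a sketch; field names taken from the handler below, the URL
+# and path are placeholders):
+#   POST /pdf-qa-server/pdf-to-md
+#   {"pathList": ["/mnt/pdf2markdown/demo.pdf"], "webhookUrl": "http://callback.example/pdf-done"}
+# For each path, the handler POSTs {'pdfId': ..., 'processStatus': ...} to webhookUrl.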
+@app.route('/pdf-qa-server/pdf-to-md', methods=['POST'])
+def pdf2markdown():
+ data = request.json
+ pdf_paths = data['pathList']
+ callback_url = data['webhookUrl']
+ for pdf_path in pdf_paths:
+ process_status, pdf_id = pdf2markdown_pipeline(pdf_path)
+        requests.post(callback_url, json={'pdfId': pdf_id, 'processStatus': process_status})
+    # Flask view functions must return a response; results are delivered via the callbacks above.
+    return ''
diff --git a/visual_images/.gitkeep b/visual_images/.gitkeep
new file mode 100644
index 0000000..e69de29